opensrdk_linear_algebra/tensor/sparse/operations/
dot.rs1use crate::sparse::RankIndex;
2use crate::tensor::Tensor;
3use crate::{generate_rank_combinations, RankCombinationId};
4use crate::{sparse::SparseTensor, Number};
5use std::collections::HashMap;
6
7pub trait DotProduct<T>
8where
9 T: Number,
10{
11 fn dot_product(
12 self,
13 rank_combinations: &[HashMap<RankIndex, RankCombinationId>],
14 ) -> SparseTensor<T>;
15}
16
17impl<'a, I, T> DotProduct<T> for I
18where
19 I: Iterator<Item = &'a SparseTensor<T>>,
20 T: Number + 'a,
21{
22 fn dot_product(
23 self,
24 rank_combinations: &[HashMap<RankIndex, RankCombinationId>],
25 ) -> SparseTensor<T> {
26 let terms = self.collect::<Vec<_>>();
27 let max_rank = terms.iter().map(|t| t.rank()).max().unwrap();
28 let mut new_sizes = vec![1; max_rank];
29 let mut _rank_combination0 = 0;
30 let mut _rank_combination1 = 0;
31
32 for (i, t) in terms.iter().enumerate() {
33 for (j, &dim) in t.sizes.iter().enumerate() {
34 if rank_combinations[i].get(&j).is_none() && dim > 1 {
35 if new_sizes[j] == 1 {
36 new_sizes[j] = dim;
37 } else {
38 panic!("The tensor whose a rank that is not aggregated and has a dimension greater than 1 can't be included.")
39 }
40 } else if i == 0 && rank_combinations[i].get(&j).is_some() {
41 _rank_combination0 = j;
42 } else if i == 1 && rank_combinations[i].get(&j).is_some() {
43 _rank_combination1 = j;
44 }
45 }
46 }
47
48 let mut result = SparseTensor::<T>::new(new_sizes.clone());
49
50 fn create_indices(dimensions: &[usize]) -> Vec<Vec<usize>> {
51 let mut indices = Vec::new();
52 if dimensions.len() == 1 {
53 for i in 0..dimensions[0] {
54 indices.push(vec![i]);
55 }
56 } else {
57 for i in 0..dimensions[0] {
58 let sub_array = create_indices(&dimensions[1..]);
59 for j in 0..sub_array.len() {
60 let mut elem = sub_array[j].clone();
61 elem.insert(0, i);
62 indices.push(elem);
63 }
64 }
65 }
66 indices
67 }
68
69 let indices = create_indices(&new_sizes);
70
71 if terms[0].sizes[_rank_combination0] != terms[1].sizes[_rank_combination1] {
72 panic!("The dimensions of the rank to be aggregated must be the same.");
73 }
74 for index in indices.iter() {
75 for k in 0..terms[0].sizes[_rank_combination0] {
76 let mut first_index = index.clone();
77 first_index[_rank_combination0] = k;
78 let mut second_index = index.clone();
79 second_index[_rank_combination1] = k;
80
81 result[&index] += terms[0][&first_index] * terms[1][&second_index];
82 }
83 }
84 result
85 }
86}
87
88impl<T> SparseTensor<T>
89where
90 T: Number,
91{
92 pub fn dot(&self, rhs: &Self, rank_pairs: &[[RankIndex; 2]]) -> Self {
93 let rank_combinations = generate_rank_combinations(rank_pairs);
94
95 vec![self, rhs].into_iter().dot_product(&rank_combinations)
96 }
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::SparseTensor;

    /// Builds a 2x2 sparse tensor from row-major values.
    fn matrix_2x2(rows: [[f64; 2]; 2]) -> SparseTensor<f64> {
        let mut m = SparseTensor::<f64>::new(vec![2, 2]);
        for (i, row) in rows.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                m[&[i, j]] = value;
            }
        }
        m
    }

    /// Contracting rank 1 of `a` with rank 0 of `b` is the matrix product.
    #[test]
    fn test_dot_product() {
        let a = matrix_2x2([[1.0, 2.0], [3.0, 4.0]]);
        let b = matrix_2x2([[2.0, 4.0], [6.0, 8.0]]);
        let c = matrix_2x2([[14.0, 20.0], [30.0, 44.0]]);

        let rank_pairs = [[1, 0]];
        let rank_combinations = generate_rank_combinations(&rank_pairs);

        let result = vec![&a, &b].into_iter().dot_product(&rank_combinations);
        assert_eq!(result, c);
    }

    /// The inherent `dot` wrapper must agree with the trait method.
    #[test]
    fn test_dot() {
        let a = matrix_2x2([[1.0, 2.0], [3.0, 4.0]]);
        let b = matrix_2x2([[2.0, 4.0], [6.0, 8.0]]);
        let c = matrix_2x2([[14.0, 20.0], [30.0, 44.0]]);

        let rank_pairs = [[1, 0]];

        assert_eq!(a.dot(&b, &rank_pairs), c);
    }
}
149}