opensrdk_linear_algebra/tensor/sparse/operations/
dot.rs

1use crate::sparse::RankIndex;
2use crate::tensor::Tensor;
3use crate::{generate_rank_combinations, RankCombinationId};
4use crate::{sparse::SparseTensor, Number};
5use std::collections::HashMap;
6
/// Tensor contraction ("dot product") over a sequence of sparse tensors.
pub trait DotProduct<T>
where
    T: Number,
{
    /// Contracts the tensors yielded by `self`.
    ///
    /// `rank_combinations[i]` maps a rank index of the `i`-th tensor to a
    /// combination id; presumably ranks sharing the same id across tensors
    /// are summed over together (produced by `generate_rank_combinations`
    /// from `[lhs_rank, rhs_rank]` pairs — TODO confirm against that helper).
    ///
    /// Returns a new [`SparseTensor`] holding the contracted result.
    fn dot_product(
        self,
        rank_combinations: &[HashMap<RankIndex, RankCombinationId>],
    ) -> SparseTensor<T>;
}
16
17impl<'a, I, T> DotProduct<T> for I
18where
19    I: Iterator<Item = &'a SparseTensor<T>>,
20    T: Number + 'a,
21{
22    fn dot_product(
23        self,
24        rank_combinations: &[HashMap<RankIndex, RankCombinationId>],
25    ) -> SparseTensor<T> {
26        let terms = self.collect::<Vec<_>>();
27        let max_rank = terms.iter().map(|t| t.rank()).max().unwrap();
28        let mut new_sizes = vec![1; max_rank];
29        let mut _rank_combination0 = 0;
30        let mut _rank_combination1 = 0;
31
32        for (i, t) in terms.iter().enumerate() {
33            for (j, &dim) in t.sizes.iter().enumerate() {
34                if rank_combinations[i].get(&j).is_none() && dim > 1 {
35                    if new_sizes[j] == 1 {
36                        new_sizes[j] = dim;
37                    } else {
38                        panic!("The tensor whose a rank that is not aggregated and has a dimension greater than 1 can't be included.")
39                    }
40                } else if i == 0 && rank_combinations[i].get(&j).is_some() {
41                    _rank_combination0 = j;
42                } else if i == 1 && rank_combinations[i].get(&j).is_some() {
43                    _rank_combination1 = j;
44                }
45            }
46        }
47
48        let mut result = SparseTensor::<T>::new(new_sizes.clone());
49
50        fn create_indices(dimensions: &[usize]) -> Vec<Vec<usize>> {
51            let mut indices = Vec::new();
52            if dimensions.len() == 1 {
53                for i in 0..dimensions[0] {
54                    indices.push(vec![i]);
55                }
56            } else {
57                for i in 0..dimensions[0] {
58                    let sub_array = create_indices(&dimensions[1..]);
59                    for j in 0..sub_array.len() {
60                        let mut elem = sub_array[j].clone();
61                        elem.insert(0, i);
62                        indices.push(elem);
63                    }
64                }
65            }
66            indices
67        }
68
69        let indices = create_indices(&new_sizes);
70
71        if terms[0].sizes[_rank_combination0] != terms[1].sizes[_rank_combination1] {
72            panic!("The dimensions of the rank to be aggregated must be the same.");
73        }
74        for index in indices.iter() {
75            for k in 0..terms[0].sizes[_rank_combination0] {
76                let mut first_index = index.clone();
77                first_index[_rank_combination0] = k;
78                let mut second_index = index.clone();
79                second_index[_rank_combination1] = k;
80
81                result[&index] += terms[0][&first_index] * terms[1][&second_index];
82            }
83        }
84        result
85    }
86}
87
88impl<T> SparseTensor<T>
89where
90    T: Number,
91{
92    pub fn dot(&self, rhs: &Self, rank_pairs: &[[RankIndex; 2]]) -> Self {
93        let rank_combinations = generate_rank_combinations(rank_pairs);
94
95        vec![self, rhs].into_iter().dot_product(&rank_combinations)
96    }
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::SparseTensor;

    /// Contracting rank 1 of `a` with rank 0 of `b` on 2x2 tensors must
    /// reproduce ordinary matrix multiplication.
    ///
    /// (The unused rank-3 fixtures `d`/`e` from the previous version were
    /// dead code and have been removed.)
    #[test]
    fn test_dot_product() {
        let mut a = SparseTensor::<f64>::new(vec![2, 2]);
        a[&[0, 0]] = 1.0;
        a[&[0, 1]] = 2.0;
        a[&[1, 0]] = 3.0;
        a[&[1, 1]] = 4.0;

        let mut b = SparseTensor::<f64>::new(vec![2, 2]);
        b[&[0, 0]] = 2.0;
        b[&[0, 1]] = 4.0;
        b[&[1, 0]] = 6.0;
        b[&[1, 1]] = 8.0;

        // Expected value of the matrix product a · b.
        let mut c = SparseTensor::<f64>::new(vec![2, 2]);
        c[&[0, 0]] = 14.0;
        c[&[0, 1]] = 20.0;
        c[&[1, 0]] = 30.0;
        c[&[1, 1]] = 44.0;

        let rank_pairs = [[1, 0]];
        let rank_combinations = generate_rank_combinations(&rank_pairs);

        let result = vec![&a, &b].into_iter().dot_product(&rank_combinations);
        assert_eq!(result, c);

        // The public convenience wrapper must agree with the trait method.
        assert_eq!(a.dot(&b, &rank_pairs), c);
    }
}