opensrdk_linear_algebra/tensor/sparse/operations/direct.rs

1use crate::sparse::RankIndex;
2use crate::tensor::Tensor;
3use crate::{sparse::SparseTensor, Number};
4use rand::prelude::*;
5use std::collections::HashMap;
6
7pub trait DirectProduct<T>
8where
9    T: Number,
10{
11    fn direct_product(self) -> SparseTensor<T>;
12}
13
14impl<'a, I, T> DirectProduct<T> for I
15where
16    I: Iterator<Item = &'a SparseTensor<T>>,
17    T: Number + 'a,
18{
19    fn direct_product(self) -> SparseTensor<T> {
20        let terms = self.collect::<Vec<_>>();
21        let rhs_size = &terms[terms.len() - 1].sizes;
22        let new_sizes = terms.iter().fold(vec![], |mut acc, &next| {
23            if acc.len() < next.sizes.len() {
24                for i in 0..acc.len() {
25                    acc[i] *= next.size(i);
26                }
27                acc.extend(next.sizes[acc.len()..].iter());
28            } else {
29                for i in 0..next.sizes.len() {
30                    acc[i] *= next.size(i);
31                }
32            }
33            acc
34        });
35
36        let new_elems = terms
37            .iter()
38            .enumerate()
39            .fold(
40                Vec::<Vec<(usize, &Vec<usize>)>>::new(),
41                |accum, (term_index, &next_term)| {
42                    if accum.is_empty() {
43                        return next_term
44                            .elems
45                            .keys()
46                            .map(|indices| vec![(term_index, indices)])
47                            .collect::<Vec<_>>();
48                    };
49                    accum
50                        .into_iter()
51                        .flat_map(|acc| {
52                            next_term
53                                .elems
54                                .keys()
55                                .map(|indices| [&acc[..], &[(term_index, indices)]].concat())
56                                .collect::<Vec<_>>()
57                        })
58                        .collect::<Vec<_>>()
59                },
60            )
61            .into_iter()
62            .map(|combination| {
63                combination.into_iter().fold(
64                    (Vec::<usize>::new(), T::default()),
65                    |(mut accum_indices, mut accum_value), (term_index, indices)| {
66                        if accum_indices.is_empty() {
67                            return (indices.clone(), terms[term_index].elem(&indices).clone());
68                        }
69
70                        if accum_indices.len() < indices.len() {
71                            for i in 0..accum_indices.len() {
72                                accum_indices[i] = accum_indices[i] * rhs_size[i] + indices[i];
73                            }
74                            accum_indices.extend(indices[accum_indices.len()..].iter());
75                        } else {
76                            for i in 0..indices.len() {
77                                accum_indices[i] = accum_indices[i] * rhs_size[i] + indices[i];
78                            }
79                        }
80                        accum_value *= terms[term_index].elem(&indices).clone();
81
82                        (accum_indices, accum_value)
83                    },
84                )
85            })
86            .collect();
87
88        SparseTensor::<T>::from(new_sizes, new_elems).unwrap()
89    }
90}
91
92impl<T> SparseTensor<T>
93where
94    T: Number,
95{
96    pub fn direct(&self, rhs: &Self) -> Self {
97        vec![self, rhs].into_iter().direct_product()
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn direct() {
        let mut a = SparseTensor::<f64>::new(vec![2, 2]);
        a[&[0, 0]] = 1.0;
        a[&[0, 1]] = 2.0;
        a[&[1, 0]] = 3.0;
        a[&[1, 1]] = 4.0;

        let mut b = SparseTensor::<f64>::new(vec![2, 2]);
        b[&[0, 0]] = 5.0;
        b[&[0, 1]] = 6.0;
        b[&[1, 0]] = 7.0;
        b[&[1, 1]] = 8.0;

        let mut c = SparseTensor::<f64>::new(vec![4, 4]);

        c[&[0, 0]] = 5.0;
        c[&[0, 1]] = 6.0;
        c[&[1, 0]] = 7.0;
        c[&[1, 1]] = 8.0;
        c[&[0, 2]] = 10.0;
        c[&[0, 3]] = 12.0;
        c[&[1, 2]] = 14.0;
        c[&[1, 3]] = 16.0;
        c[&[2, 0]] = 15.0;
        c[&[2, 1]] = 18.0;
        c[&[3, 0]] = 21.0;
        c[&[3, 1]] = 24.0;
        c[&[2, 2]] = 20.0;
        c[&[2, 3]] = 24.0;
        c[&[3, 2]] = 28.0;
        c[&[3, 3]] = 32.0;

        assert_eq!(a.direct(&b), c);
    }

    /// Factors with different sizes: `b` is 3x3 with only its top-left 2x2 block stored,
    /// so the product is 6x6 and entry `a[i, j] * b[k, l]` lands at `[i * 3 + k, j * 3 + l]`.
    #[test]
    fn direct_product_two() {
        let mut a = SparseTensor::<f64>::new(vec![2, 2]);
        a[&[0, 0]] = 1.0;
        a[&[0, 1]] = 2.0;
        a[&[1, 0]] = 3.0;
        a[&[1, 1]] = 4.0;

        let mut b = SparseTensor::<f64>::new(vec![3, 3]);
        b[&[0, 0]] = 5.0;
        b[&[0, 1]] = 6.0;
        b[&[1, 0]] = 7.0;
        b[&[1, 1]] = 8.0;

        let expected_entries: [(usize, usize, f64); 16] = [
            (0, 0, 5.0),
            (0, 1, 6.0),
            (1, 0, 7.0),
            (1, 1, 8.0),
            (0, 3, 10.0),
            (0, 4, 12.0),
            (1, 3, 14.0),
            (1, 4, 16.0),
            (3, 0, 15.0),
            (3, 1, 18.0),
            (4, 0, 21.0),
            (4, 1, 24.0),
            (3, 3, 20.0),
            (3, 4, 24.0),
            (4, 3, 28.0),
            (4, 4, 32.0),
        ];
        let mut expected = SparseTensor::<f64>::new(vec![6, 6]);
        for &(row, col, value) in expected_entries.iter() {
            expected[&[row, col]] = value;
        }

        let c = a.direct(&b);
        assert_eq!(c.sizes, vec![6, 6]);
        assert_eq!(c, expected);
    }

    #[test]
    fn direct_three_dimensional() {
        let mut a = SparseTensor::<f64>::new(vec![2, 2, 2]);
        a[&[0, 0, 0]] = 1.0;
        a[&[0, 0, 1]] = 2.0;
        a[&[0, 1, 0]] = 3.0;
        a[&[0, 1, 1]] = 4.0;
        a[&[1, 0, 0]] = 5.0;
        a[&[1, 0, 1]] = 6.0;
        a[&[1, 1, 0]] = 7.0;
        a[&[1, 1, 1]] = 8.0;

        let mut b = SparseTensor::<f64>::new(vec![2, 2, 2]);
        b[&[0, 0, 0]] = 9.0;
        b[&[0, 0, 1]] = 10.0;
        b[&[0, 1, 0]] = 11.0;
        b[&[0, 1, 1]] = 12.0;
        b[&[1, 0, 0]] = 13.0;
        b[&[1, 0, 1]] = 14.0;
        b[&[1, 1, 0]] = 15.0;
        b[&[1, 1, 1]] = 16.0;

        let c = a.direct(&b);

        assert_eq!(c.sizes, vec![4, 4, 4]);

        assert_eq!(c[&[0, 0, 0]], 9.0);

        assert_eq!(c[&[0, 2, 0]], 27.0);
        assert_eq!(c[&[0, 3, 1]], 36.0);

        assert_eq!(c[&[0, 2, 2]], 36.0);
        assert_eq!(c[&[0, 2, 3]], 40.0);
        assert_eq!(c[&[0, 3, 2]], 44.0);
        assert_eq!(c[&[0, 3, 3]], 48.0);

        assert_eq!(c[&[1, 0, 0]], 13.0);
        assert_eq!(c[&[1, 2, 3]], 56.0);
        assert_eq!(c[&[1, 3, 3]], 64.0);

        assert_eq!(c[&[2, 0, 0]], 45.0);
        assert_eq!(c[&[2, 2, 2]], 72.0);
        assert_eq!(c[&[2, 3, 3]], 96.0);

        assert_eq!(c[&[3, 0, 0]], 65.0);
        assert_eq!(c[&[3, 2, 2]], 104.0);
        assert_eq!(c[&[3, 3, 3]], 128.0);
    }
}