opensrdk_linear_algebra/tensor/sparse/operations/
direct.rs1use crate::sparse::RankIndex;
2use crate::tensor::Tensor;
3use crate::{sparse::SparseTensor, Number};
4use rand::prelude::*;
5use std::collections::HashMap;
6
7pub trait DirectProduct<T>
8where
9 T: Number,
10{
11 fn direct_product(self) -> SparseTensor<T>;
12}
13
14impl<'a, I, T> DirectProduct<T> for I
15where
16 I: Iterator<Item = &'a SparseTensor<T>>,
17 T: Number + 'a,
18{
19 fn direct_product(self) -> SparseTensor<T> {
20 let terms = self.collect::<Vec<_>>();
21 let rhs_size = &terms[terms.len() - 1].sizes;
22 let new_sizes = terms.iter().fold(vec![], |mut acc, &next| {
23 if acc.len() < next.sizes.len() {
24 for i in 0..acc.len() {
25 acc[i] *= next.size(i);
26 }
27 acc.extend(next.sizes[acc.len()..].iter());
28 } else {
29 for i in 0..next.sizes.len() {
30 acc[i] *= next.size(i);
31 }
32 }
33 acc
34 });
35
36 let new_elems = terms
37 .iter()
38 .enumerate()
39 .fold(
40 Vec::<Vec<(usize, &Vec<usize>)>>::new(),
41 |accum, (term_index, &next_term)| {
42 if accum.is_empty() {
43 return next_term
44 .elems
45 .keys()
46 .map(|indices| vec![(term_index, indices)])
47 .collect::<Vec<_>>();
48 };
49 accum
50 .into_iter()
51 .flat_map(|acc| {
52 next_term
53 .elems
54 .keys()
55 .map(|indices| [&acc[..], &[(term_index, indices)]].concat())
56 .collect::<Vec<_>>()
57 })
58 .collect::<Vec<_>>()
59 },
60 )
61 .into_iter()
62 .map(|combination| {
63 combination.into_iter().fold(
64 (Vec::<usize>::new(), T::default()),
65 |(mut accum_indices, mut accum_value), (term_index, indices)| {
66 if accum_indices.is_empty() {
67 return (indices.clone(), terms[term_index].elem(&indices).clone());
68 }
69
70 if accum_indices.len() < indices.len() {
71 for i in 0..accum_indices.len() {
72 accum_indices[i] = accum_indices[i] * rhs_size[i] + indices[i];
73 }
74 accum_indices.extend(indices[accum_indices.len()..].iter());
75 } else {
76 for i in 0..indices.len() {
77 accum_indices[i] = accum_indices[i] * rhs_size[i] + indices[i];
78 }
79 }
80 accum_value *= terms[term_index].elem(&indices).clone();
81
82 (accum_indices, accum_value)
83 },
84 )
85 })
86 .collect();
87
88 SparseTensor::<T>::from(new_sizes, new_elems).unwrap()
89 }
90}
91
92impl<T> SparseTensor<T>
93where
94 T: Number,
95{
96 pub fn direct(&self, rhs: &Self) -> Self {
97 vec![self, rhs].into_iter().direct_product()
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn direct() {
        let mut a = SparseTensor::<f64>::new(vec![2, 2]);
        a[&[0, 0]] = 1.0;
        a[&[0, 1]] = 2.0;
        a[&[1, 0]] = 3.0;
        a[&[1, 1]] = 4.0;

        let mut b = SparseTensor::<f64>::new(vec![2, 2]);
        b[&[0, 0]] = 5.0;
        b[&[0, 1]] = 6.0;
        b[&[1, 0]] = 7.0;
        b[&[1, 1]] = 8.0;

        let mut c = SparseTensor::<f64>::new(vec![4, 4]);

        c[&[0, 0]] = 5.0;
        c[&[0, 1]] = 6.0;
        c[&[1, 0]] = 7.0;
        c[&[1, 1]] = 8.0;
        c[&[0, 2]] = 10.0;
        c[&[0, 3]] = 12.0;
        c[&[1, 2]] = 14.0;
        c[&[1, 3]] = 16.0;
        c[&[2, 0]] = 15.0;
        c[&[2, 1]] = 18.0;
        c[&[3, 0]] = 21.0;
        c[&[3, 1]] = 24.0;
        c[&[2, 2]] = 20.0;
        c[&[2, 3]] = 24.0;
        c[&[3, 2]] = 28.0;
        c[&[3, 3]] = 32.0;

        assert_eq!(a.direct(&b), c);
    }

    #[test]
    fn direct_product_two() {
        let mut a = SparseTensor::<f64>::new(vec![2, 2]);
        a[&[0, 0]] = 1.0;
        a[&[0, 1]] = 2.0;
        a[&[1, 0]] = 3.0;
        a[&[1, 1]] = 4.0;

        let mut b = SparseTensor::<f64>::new(vec![3, 3]);
        b[&[0, 0]] = 5.0;
        b[&[0, 1]] = 6.0;
        b[&[1, 0]] = 7.0;
        b[&[1, 1]] = 8.0;

        let c = a.direct(&b);

        // Kronecker layout with a 3x3 right operand:
        // (a ⊗ b)[i * 3 + k, j * 3 + l] = a[i, j] * b[k, l]
        let a_vals = [[1.0, 2.0], [3.0, 4.0]];
        let b_entries = [(0, 0, 5.0), (0, 1, 6.0), (1, 0, 7.0), (1, 1, 8.0)];
        let mut expected = SparseTensor::<f64>::new(vec![6, 6]);
        for (i, row) in a_vals.iter().enumerate() {
            for (j, &av) in row.iter().enumerate() {
                for &(k, l, bv) in &b_entries {
                    expected[&[i * 3 + k, j * 3 + l]] = av * bv;
                }
            }
        }

        assert_eq!(c, expected);
    }

    #[test]
    fn direct_three_dimensional() {
        let mut a = SparseTensor::<f64>::new(vec![2, 2, 2]);
        a[&[0, 0, 0]] = 1.0;
        a[&[0, 0, 1]] = 2.0;
        a[&[0, 1, 0]] = 3.0;
        a[&[0, 1, 1]] = 4.0;
        a[&[1, 0, 0]] = 5.0;
        a[&[1, 0, 1]] = 6.0;
        a[&[1, 1, 0]] = 7.0;
        a[&[1, 1, 1]] = 8.0;

        let mut b = SparseTensor::<f64>::new(vec![2, 2, 2]);
        b[&[0, 0, 0]] = 9.0;
        b[&[0, 0, 1]] = 10.0;
        b[&[0, 1, 0]] = 11.0;
        b[&[0, 1, 1]] = 12.0;
        b[&[1, 0, 0]] = 13.0;
        b[&[1, 0, 1]] = 14.0;
        b[&[1, 1, 0]] = 15.0;
        b[&[1, 1, 1]] = 16.0;

        let c = a.direct(&b);

        assert_eq!(c[&[0, 0, 0]], 9.0);

        assert_eq!(c[&[0, 2, 0]], 27.0);
        assert_eq!(c[&[0, 3, 1]], 36.0);

        assert_eq!(c[&[0, 2, 2]], 36.0);
        assert_eq!(c[&[0, 2, 3]], 40.0);
        assert_eq!(c[&[0, 3, 2]], 44.0);
        assert_eq!(c[&[0, 3, 3]], 48.0);

        assert_eq!(c[&[1, 0, 0]], 13.0);
        assert_eq!(c[&[1, 2, 3]], 56.0);
        assert_eq!(c[&[1, 3, 3]], 64.0);

        assert_eq!(c[&[2, 0, 0]], 45.0);
        assert_eq!(c[&[2, 2, 2]], 72.0);
        assert_eq!(c[&[2, 3, 3]], 96.0);

        assert_eq!(c[&[3, 0, 0]], 65.0);
        assert_eq!(c[&[3, 2, 2]], 104.0);
        assert_eq!(c[&[3, 3, 3]], 128.0);
    }
}