ndarray_einsum/contractors/
singleton_contractors.rs1use ndarray::prelude::*;
24use ndarray::LinalgScalar;
25
26use super::{SingletonContractor, SingletonViewer};
27use crate::{Contraction, SizedContraction};
28
29#[derive(Clone, Debug)]
33pub struct Identity {}
34
35impl Identity {
36 pub fn new(_sc: &SizedContraction) -> Self {
37 Identity {}
38 }
39}
40
41impl<A> SingletonViewer<A> for Identity {
42 fn view_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayViewD<'b, A>
43 where
44 'a: 'b,
45 A: Clone + LinalgScalar,
46 {
47 tensor.view()
48 }
49}
50
51impl<A> SingletonContractor<A> for Identity {
52 fn contract_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayD<A>
53 where
54 'a: 'b,
55 A: Clone + LinalgScalar,
56 {
57 tensor.to_owned()
58 }
59}
60
61#[derive(Clone, Debug)]
65pub struct Permutation {
66 permutation: Vec<usize>,
67}
68
69impl Permutation {
70 pub fn new(sc: &SizedContraction) -> Self {
71 let SizedContraction {
72 contraction:
73 Contraction {
74 operand_indices,
75 output_indices,
76 ..
77 },
78 ..
79 } = sc;
80
81 assert_eq!(operand_indices.len(), 1);
82 assert_eq!(operand_indices[0].len(), output_indices.len());
83
84 let mut permutation = Vec::new();
85 for &c in output_indices.iter() {
86 permutation.push(operand_indices[0].iter().position(|&x| x == c).unwrap());
87 }
88
89 Permutation { permutation }
90 }
91
92 pub fn from_indices(permutation: &[usize]) -> Self {
93 Permutation {
94 permutation: permutation.to_vec(),
95 }
96 }
97}
98
99impl<A> SingletonViewer<A> for Permutation {
100 fn view_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayViewD<'b, A>
101 where
102 'a: 'b,
103 A: Clone + LinalgScalar,
104 {
105 tensor.view().permuted_axes(IxDyn(&self.permutation))
106 }
107}
108
109impl<A> SingletonContractor<A> for Permutation {
110 fn contract_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayD<A>
111 where
112 'a: 'b,
113 A: Clone + LinalgScalar,
114 {
115 tensor
116 .view()
117 .permuted_axes(IxDyn(&self.permutation))
118 .to_owned()
119 }
120}
121
122#[derive(Clone, Debug)]
126pub struct Summation {
127 adjusted_axis_list: Vec<usize>,
128}
129
130impl Summation {
131 pub fn new(sc: &SizedContraction) -> Self {
132 let output_indices = &sc.contraction.output_indices;
133 let input_indices = &sc.contraction.operand_indices[0];
134
135 Summation::from_sizes(
136 output_indices.len(),
137 input_indices.len() - output_indices.len(),
138 )
139 }
140
141 fn from_sizes(start_index: usize, num_summed_axes: usize) -> Self {
142 assert!(num_summed_axes >= 1);
143 let adjusted_axis_list = (0..num_summed_axes).map(|_| start_index).collect();
144
145 Summation { adjusted_axis_list }
146 }
147}
148
149impl<A> SingletonContractor<A> for Summation {
150 fn contract_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayD<A>
151 where
152 'a: 'b,
153 A: Clone + LinalgScalar,
154 {
155 let mut result = tensor.sum_axis(Axis(self.adjusted_axis_list[0]));
156 for &axis in self.adjusted_axis_list[1..].iter() {
157 result = result.sum_axis(Axis(axis));
158 }
159 result
160 }
161}
162
163#[derive(Clone, Debug)]
171pub struct Diagonalization {
172 input_to_output_mapping: Vec<usize>,
173 output_shape: Vec<usize>,
174}
175
176impl Diagonalization {
177 pub fn new(sc: &SizedContraction) -> Self {
178 let SizedContraction {
179 contraction:
180 Contraction {
181 operand_indices,
182 output_indices,
183 ..
184 },
185 output_size,
186 } = sc;
187 assert_eq!(operand_indices.len(), 1);
188
189 let mut adjusted_output_indices = output_indices.clone();
190 let mut input_to_output_mapping = Vec::new();
191 for &c in operand_indices[0].iter() {
192 let current_length = adjusted_output_indices.len();
193 match adjusted_output_indices.iter().position(|&x| x == c) {
194 Some(pos) => {
195 input_to_output_mapping.push(pos);
196 }
197 None => {
198 adjusted_output_indices.push(c);
199 input_to_output_mapping.push(current_length);
200 }
201 }
202 }
203 let output_shape = adjusted_output_indices
204 .iter()
205 .map(|c| output_size[c])
206 .collect();
207
208 Diagonalization {
209 input_to_output_mapping,
210 output_shape,
211 }
212 }
213}
214
215impl<A> SingletonViewer<A> for Diagonalization {
216 fn view_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayViewD<'b, A>
217 where
218 'a: 'b,
219 A: Clone + LinalgScalar,
220 {
221 let mut strides = vec![0; self.output_shape.len()];
224 for (idx, &stride) in tensor.strides().iter().enumerate() {
225 assert!(stride > 0);
226 strides[self.input_to_output_mapping[idx]] += stride as usize;
227 }
228
229 let data_slice = tensor.as_slice_memory_order().unwrap();
232 ArrayView::from_shape(
233 IxDyn(&self.output_shape).strides(IxDyn(&strides)),
234 data_slice,
235 )
236 .unwrap()
237 }
238}
239
240impl<A> SingletonContractor<A> for Diagonalization {
241 fn contract_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayD<A>
242 where
243 'a: 'b,
244 A: Clone + LinalgScalar,
245 {
246 let cloned_tensor: ArrayD<A> =
249 Array::from_shape_vec(tensor.raw_dim(), tensor.iter().cloned().collect()).unwrap();
250 self.view_singleton(&cloned_tensor.view()).into_owned()
251 }
252}
253
254#[derive(Clone, Debug)]
258pub struct PermutationAndSummation {
259 permutation: Permutation,
260 summation: Summation,
261}
262
263impl PermutationAndSummation {
264 pub fn new(sc: &SizedContraction) -> Self {
265 let mut output_order: Vec<usize> = Vec::new();
266
267 for &output_char in sc.contraction.output_indices.iter() {
268 let input_pos = sc.contraction.operand_indices[0]
269 .iter()
270 .position(|&input_char| input_char == output_char)
271 .unwrap();
272 output_order.push(input_pos);
273 }
274 for (i, &input_char) in sc.contraction.operand_indices[0].iter().enumerate() {
275 if !sc.contraction.output_indices.contains(&input_char) {
276 output_order.push(i);
277 }
278 }
279
280 let permutation = Permutation::from_indices(&output_order);
281 let summation = Summation::new(sc);
282
283 PermutationAndSummation {
284 permutation,
285 summation,
286 }
287 }
288}
289
290impl<A> SingletonContractor<A> for PermutationAndSummation {
291 fn contract_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayD<A>
292 where
293 'a: 'b,
294 A: Clone + LinalgScalar,
295 {
296 let permuted_singleton = self.permutation.view_singleton(tensor);
297 self.summation.contract_singleton(&permuted_singleton)
298 }
299}
300
301#[derive(Clone, Debug)]
309pub struct DiagonalizationAndSummation {
310 diagonalization: Diagonalization,
311 summation: Summation,
312}
313
314impl DiagonalizationAndSummation {
315 pub fn new(sc: &SizedContraction) -> Self {
316 let diagonalization = Diagonalization::new(sc);
317 let summation = Summation::from_sizes(
318 sc.contraction.output_indices.len(),
319 diagonalization.output_shape.len() - sc.contraction.output_indices.len(),
320 );
321
322 DiagonalizationAndSummation {
323 diagonalization,
324 summation,
325 }
326 }
327}
328
329impl<A> SingletonContractor<A> for DiagonalizationAndSummation {
330 fn contract_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayD<A>
331 where
332 'a: 'b,
333 A: Clone + LinalgScalar,
334 {
335 let contracted_singleton;
343 let viewed_singleton = if tensor.as_slice_memory_order().is_some()
344 && tensor.strides().iter().all(|&stride| stride > 0)
345 {
346 self.diagonalization.view_singleton(tensor)
347 } else {
348 contracted_singleton = self.diagonalization.contract_singleton(tensor);
349 contracted_singleton.view()
350 };
351
352 self.summation.contract_singleton(&viewed_singleton)
353 }
354}