ndarray_einsum/contractors/
singleton_contractors.rs1use ndarray::prelude::*;
24use ndarray::LinalgScalar;
25
26use super::{SingletonContractor, SingletonViewer};
27use crate::{Contraction, SizedContraction};
28
29#[derive(Clone, Debug)]
33pub struct Identity {}
34
35impl Identity {
36 pub fn new(_sc: &SizedContraction) -> Self {
37 Identity {}
38 }
39}
40
41impl<A> SingletonViewer<A> for Identity {
42 fn view_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayViewD<'b, A>
43 where
44 'a: 'b,
45 A: Clone + LinalgScalar,
46 {
47 tensor.view()
48 }
49}
50
51impl<A> SingletonContractor<A> for Identity {
52 fn contract_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayD<A>
53 where
54 'a: 'b,
55 A: Clone + LinalgScalar,
56 {
57 tensor.to_owned()
58 }
59}
60
61#[derive(Clone, Debug)]
65pub struct Permutation {
66 permutation: Vec<usize>,
67}
68
69impl Permutation {
70 pub fn new(sc: &SizedContraction) -> Self {
71 let SizedContraction {
72 contraction:
73 Contraction {
74 operand_indices,
75 output_indices,
76 ..
77 },
78 ..
79 } = sc;
80
81 assert_eq!(operand_indices.len(), 1);
82 assert_eq!(operand_indices[0].len(), output_indices.len());
83
84 let mut permutation = Vec::new();
85 for &c in output_indices.iter() {
86 permutation.push(operand_indices[0].iter().position(|&x| x == c).unwrap());
87 }
88
89 Permutation { permutation }
90 }
91
92 pub fn from_indices(permutation: &[usize]) -> Self {
93 Permutation {
94 permutation: permutation.to_vec(),
95 }
96 }
97}
98
99impl<A> SingletonViewer<A> for Permutation {
100 fn view_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayViewD<'b, A>
101 where
102 'a: 'b,
103 A: Clone + LinalgScalar,
104 {
105 tensor.view().permuted_axes(IxDyn(&self.permutation))
106 }
107}
108
109impl<A> SingletonContractor<A> for Permutation {
110 fn contract_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayD<A>
111 where
112 'a: 'b,
113 A: Clone + LinalgScalar,
114 {
115 tensor
116 .view()
117 .permuted_axes(IxDyn(&self.permutation))
118 .to_owned()
119 }
120}
121
122#[derive(Clone, Debug)]
126pub struct Summation {
127 adjusted_axis_list: Vec<usize>,
128}
129
130impl Summation {
131 pub fn new(sc: &SizedContraction) -> Self {
132 let output_indices = &sc.contraction.output_indices;
133 let input_indices = &sc.contraction.operand_indices[0];
134
135 Summation::from_sizes(
136 output_indices.len(),
137 input_indices.len() - output_indices.len(),
138 )
139 }
140
141 fn from_sizes(start_index: usize, num_summed_axes: usize) -> Self {
142 assert!(num_summed_axes >= 1);
143 let adjusted_axis_list = (0..num_summed_axes).map(|_| start_index).collect();
144
145 Summation { adjusted_axis_list }
146 }
147}
148
149impl<A> SingletonContractor<A> for Summation {
150 fn contract_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayD<A>
151 where
152 'a: 'b,
153 A: Clone + LinalgScalar,
154 {
155 let mut result = tensor.sum_axis(Axis(self.adjusted_axis_list[0]));
156 for &axis in self.adjusted_axis_list[1..].iter() {
157 result = result.sum_axis(Axis(axis));
158 }
159 result
160 }
161}
162
163#[derive(Clone, Debug)]
171pub struct Diagonalization {
172 input_to_output_mapping: Vec<usize>,
173 output_shape: Vec<usize>,
174}
175
176impl Diagonalization {
177 pub fn new(sc: &SizedContraction) -> Self {
178 let SizedContraction {
179 contraction:
180 Contraction {
181 operand_indices,
182 output_indices,
183 ..
184 },
185 output_size,
186 } = sc;
187 assert_eq!(operand_indices.len(), 1);
188
189 let mut adjusted_output_indices = output_indices.clone();
190 let mut input_to_output_mapping = Vec::new();
191 for &c in operand_indices[0].iter() {
192 let current_length = adjusted_output_indices.len();
193 match adjusted_output_indices.iter().position(|&x| x == c) {
194 Some(pos) => {
195 input_to_output_mapping.push(pos);
196 }
197 None => {
198 adjusted_output_indices.push(c);
199 input_to_output_mapping.push(current_length);
200 }
201 }
202 }
203 let output_shape = adjusted_output_indices
204 .iter()
205 .map(|c| output_size[c])
206 .collect();
207
208 Diagonalization {
209 input_to_output_mapping,
210 output_shape,
211 }
212 }
213}
214
215impl<A> SingletonViewer<A> for Diagonalization {
216 fn view_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayViewD<'b, A>
217 where
218 'a: 'b,
219 A: Clone + LinalgScalar,
220 {
221 let mut strides = vec![0; self.output_shape.len()];
224 for (idx, &stride) in tensor.strides().iter().enumerate() {
225 assert!(stride > 0);
226 strides[self.input_to_output_mapping[idx]] += stride as usize;
227 }
228
229 let data_slice = tensor.as_slice_memory_order().unwrap();
232 ArrayView::from_shape(
233 IxDyn(&self.output_shape).strides(IxDyn(&strides)),
234 data_slice,
235 )
236 .unwrap()
237 }
238}
239
240impl<A> SingletonContractor<A> for Diagonalization {
241 fn contract_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayD<A>
242 where
243 'a: 'b,
244 A: Clone + LinalgScalar,
245 {
246 let cloned_tensor: ArrayD<A> =
249 Array::from_shape_vec(tensor.raw_dim(), tensor.iter().cloned().collect()).unwrap();
250 self.view_singleton(&cloned_tensor.view()).into_owned()
251 }
252}
253
254#[derive(Clone, Debug)]
258pub struct PermutationAndSummation {
259 permutation: Permutation,
260 summation: Summation,
261}
262
263impl PermutationAndSummation {
264 pub fn new(sc: &SizedContraction) -> Self {
265 let mut output_order: Vec<usize> = Vec::new();
266
267 for &output_char in sc.contraction.output_indices.iter() {
268 let input_pos = sc.contraction.operand_indices[0]
269 .iter()
270 .position(|&input_char| input_char == output_char)
271 .unwrap();
272 output_order.push(input_pos);
273 }
274 for (i, &input_char) in sc.contraction.operand_indices[0].iter().enumerate() {
275 if !sc
276 .contraction
277 .output_indices
278 .iter()
279 .any(|&output_char| output_char == input_char)
280 {
281 output_order.push(i);
282 }
283 }
284
285 let permutation = Permutation::from_indices(&output_order);
286 let summation = Summation::new(sc);
287
288 PermutationAndSummation {
289 permutation,
290 summation,
291 }
292 }
293}
294
295impl<A> SingletonContractor<A> for PermutationAndSummation {
296 fn contract_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayD<A>
297 where
298 'a: 'b,
299 A: Clone + LinalgScalar,
300 {
301 let permuted_singleton = self.permutation.view_singleton(tensor);
302 self.summation.contract_singleton(&permuted_singleton)
303 }
304}
305
306#[derive(Clone, Debug)]
314pub struct DiagonalizationAndSummation {
315 diagonalization: Diagonalization,
316 summation: Summation,
317}
318
319impl DiagonalizationAndSummation {
320 pub fn new(sc: &SizedContraction) -> Self {
321 let diagonalization = Diagonalization::new(sc);
322 let summation = Summation::from_sizes(
323 sc.contraction.output_indices.len(),
324 diagonalization.output_shape.len() - sc.contraction.output_indices.len(),
325 );
326
327 DiagonalizationAndSummation {
328 diagonalization,
329 summation,
330 }
331 }
332}
333
334impl<A> SingletonContractor<A> for DiagonalizationAndSummation {
335 fn contract_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayD<A>
336 where
337 'a: 'b,
338 A: Clone + LinalgScalar,
339 {
340 let contracted_singleton;
348 let viewed_singleton = if tensor.as_slice_memory_order().is_some()
349 && tensor.strides().iter().all(|&stride| stride > 0)
350 {
351 self.diagonalization.view_singleton(tensor)
352 } else {
353 contracted_singleton = self.diagonalization.contract_singleton(tensor);
354 contracted_singleton.view()
355 };
356
357 self.summation.contract_singleton(&viewed_singleton)
358 }
359}