1use crate::optimizers::{
40 generate_optimized_order, ContractionOrder, OperandNumber, OptimizationMethod,
41};
42use crate::{ArrayLike, SizedContraction};
43use hashbrown::HashSet;
44use ndarray::prelude::*;
45use ndarray::LinalgScalar;
46use std::fmt::Debug;
47
48mod singleton_contractors;
49use singleton_contractors::{
50 Diagonalization, DiagonalizationAndSummation, Identity, Permutation, PermutationAndSummation,
51 Summation,
52};
53
54mod pair_contractors;
55pub use pair_contractors::TensordotGeneral;
56use pair_contractors::{
57 BroadcastProductGeneral, HadamardProduct, HadamardProductGeneral, MatrixScalarProduct,
58 MatrixScalarProductGeneral, ScalarMatrixProduct, ScalarMatrixProductGeneral,
59 StackedTensordotGeneral, TensordotFixedPosition,
60};
61
62mod strategies;
63use strategies::{PairMethod, PairSummary, SingletonMethod, SingletonSummary};
64
/// An operation on a single tensor whose result is returned as a view.
///
/// Implementors must also implement [`Debug`].
pub trait SingletonViewer<A>: Debug {
    /// Produces a view derived from `tensor`.
    ///
    /// The returned view carries the lifetime `'b` of the borrow of `tensor`,
    /// and the underlying data lifetime `'a` is required to outlive it.
    fn view_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayViewD<'b, A>
    where
        'a: 'b,
        A: Clone + LinalgScalar;
}
79
/// An operation on a single tensor whose result is returned as an owned array.
///
/// Implementors must also implement [`Debug`].
pub trait SingletonContractor<A>: Debug {
    /// Contracts `tensor` and returns the result as an owned `ArrayD<A>`.
    fn contract_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayD<A>
    where
        'a: 'b,
        A: Clone + LinalgScalar;
}
89
90pub trait PairContractor<A>: Debug {
96 fn contract_pair<'a, 'b, 'c, 'd>(
97 &self,
98 lhs: &'b ArrayViewD<'a, A>,
99 rhs: &'d ArrayViewD<'c, A>,
100 ) -> ArrayD<A>
101 where
102 'a: 'b,
103 'c: 'd,
104 A: Clone + LinalgScalar;
105
106 fn contract_and_assign_pair<'a, 'b, 'c, 'd, 'e, 'f>(
107 &self,
108 lhs: &'b ArrayViewD<'a, A>,
109 rhs: &'d ArrayViewD<'c, A>,
110 out: &'f mut ArrayViewMutD<'e, A>,
111 ) where
112 'a: 'b,
113 'c: 'd,
114 'e: 'f,
115 A: Clone + LinalgScalar,
116 {
117 let result = self.contract_pair(lhs, rhs);
118 out.assign(&result);
119 }
120}
121
/// A single-operand contraction: the chosen strategy together with the
/// boxed contractor that carries it out.
pub struct SingletonContraction<A> {
    /// Strategy selected for this contraction (kept for `Debug` output).
    method: SingletonMethod,
    /// Contractor built for `method`; does the actual work.
    op: Box<dyn SingletonContractor<A>>,
}
135
136impl<A> SingletonContraction<A> {
137 pub fn new(sc: &SizedContraction) -> Self {
138 let singleton_summary = SingletonSummary::new(sc);
139 let method = singleton_summary.get_strategy();
140
141 SingletonContraction {
142 method,
143 op: match method {
144 SingletonMethod::Identity => Box::new(Identity::new(sc)),
145 SingletonMethod::Permutation => Box::new(Permutation::new(sc)),
146 SingletonMethod::Summation => Box::new(Summation::new(sc)),
147 SingletonMethod::Diagonalization => Box::new(Diagonalization::new(sc)),
148 SingletonMethod::PermutationAndSummation => {
149 Box::new(PermutationAndSummation::new(sc))
150 }
151 SingletonMethod::DiagonalizationAndSummation => {
152 Box::new(DiagonalizationAndSummation::new(sc))
153 }
154 },
155 }
156 }
157}
158
159impl<A> SingletonContractor<A> for SingletonContraction<A> {
160 fn contract_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayD<A>
161 where
162 'a: 'b,
163 A: Clone + LinalgScalar,
164 {
165 self.op.contract_singleton(tensor)
166 }
167}
168
169impl<A> Debug for SingletonContraction<A> {
170 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
171 write!(
172 f,
173 "SingletonContraction {{ method: {:?}, op: {:?} }}",
174 self.method, self.op
175 )
176 }
177}
178
/// A singleton contraction applied to one operand of a pair before the pair
/// contraction itself, together with the indices it leaves behind.
struct SimplificationMethodAndOutput<A> {
    /// Strategy of the simplifying singleton contraction (kept for `Debug` output).
    method: SingletonMethod,
    /// Contractor that performs the simplification.
    op: Box<dyn SingletonContractor<A>>,
    /// Indices of the operand after the simplification has been applied.
    new_indices: Vec<char>,
    /// Einsum string of the simplifying contraction (kept for `Debug` output).
    einsum_string: String,
}
186
187impl<A> SimplificationMethodAndOutput<A> {
188 fn from_indices_and_sizes(
192 this_input_indices: &[char],
193 other_input_indices: &[char],
194 output_indices: &[char],
195 orig_contraction: &SizedContraction,
196 ) -> Option<Self> {
197 let this_input_uniques: HashSet<char> = this_input_indices.iter().cloned().collect();
198 let other_input_uniques: HashSet<char> = other_input_indices.iter().cloned().collect();
199 let output_uniques: HashSet<char> = output_indices.iter().cloned().collect();
200
201 let other_and_output: HashSet<char> = other_input_uniques
202 .union(&output_uniques)
203 .cloned()
204 .collect();
205 let desired_uniques: HashSet<char> = this_input_uniques
206 .intersection(&other_and_output)
207 .cloned()
208 .collect();
209 let new_indices: Vec<char> = desired_uniques.iter().cloned().collect();
210
211 let simplification_sc = orig_contraction
212 .subset(&[this_input_indices.to_vec()], &new_indices)
213 .unwrap();
214
215 let SingletonContraction { method, op } = SingletonContraction::new(&simplification_sc);
216
217 match method {
218 SingletonMethod::Identity | SingletonMethod::Permutation => None,
219 _ => Some(SimplificationMethodAndOutput {
220 method,
221 op,
222 new_indices,
223 einsum_string: simplification_sc.as_einsum_string(),
224 }),
225 }
226 }
227}
228
229impl<A> Debug for SimplificationMethodAndOutput<A> {
230 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
231 write!(
232 f,
233 "SingletonContraction {{ method: {:?}, op: {:?}, new_indices: {:?}, einsum_string: {:?} }}",
234 self.method, self.op, self.new_indices, self.einsum_string
235 )
236 }
237}
238
/// A two-operand contraction, optionally preceded by a singleton
/// simplification of either operand.
pub struct PairContraction<A> {
    /// Simplification applied to the left operand first, if any.
    lhs_simplification: Option<SimplificationMethodAndOutput<A>>,
    /// Simplification applied to the right operand first, if any.
    rhs_simplification: Option<SimplificationMethodAndOutput<A>>,
    /// Strategy selected for the pair contraction (kept for `Debug` output).
    method: PairMethod,
    /// Contractor built for `method`; operates on the (possibly simplified) operands.
    op: Box<dyn PairContractor<A>>,
    /// Einsum string of the contraction after simplification (kept for `Debug` output).
    simplified_einsum_string: String,
}
264
265impl<A> PairContraction<A> {
266 pub fn new(sc: &SizedContraction) -> Self {
267 assert_eq!(sc.contraction.operand_indices.len(), 2);
268 let lhs_indices = &sc.contraction.operand_indices[0];
269 let rhs_indices = &sc.contraction.operand_indices[1];
270 let output_indices = &sc.contraction.output_indices;
271
272 let lhs_simplification = SimplificationMethodAndOutput::from_indices_and_sizes(
273 lhs_indices,
274 rhs_indices,
275 output_indices,
276 sc,
277 );
278 let rhs_simplification = SimplificationMethodAndOutput::from_indices_and_sizes(
279 rhs_indices,
280 lhs_indices,
281 output_indices,
282 sc,
283 );
284 let new_lhs_indices = match &lhs_simplification {
285 Some(s) => s.new_indices.clone(),
286 None => lhs_indices.clone(),
287 };
288 let new_rhs_indices = match &rhs_simplification {
289 Some(s) => s.new_indices.clone(),
290 None => rhs_indices.clone(),
291 };
292
293 let reduced_sc = sc
294 .subset(&[new_lhs_indices, new_rhs_indices], output_indices)
295 .unwrap();
296
297 let pair_summary = PairSummary::new(&reduced_sc);
298 let method = pair_summary.get_strategy();
299
300 let op: Box<dyn PairContractor<A>> = match method {
301 PairMethod::HadamardProduct => {
302 Box::new(HadamardProduct::new(&reduced_sc))
304 }
305 PairMethod::HadamardProductGeneral => {
306 Box::new(HadamardProductGeneral::new(&reduced_sc))
307 }
308 PairMethod::ScalarMatrixProduct => {
309 Box::new(ScalarMatrixProduct::new(&reduced_sc))
311 }
312 PairMethod::ScalarMatrixProductGeneral => {
313 Box::new(ScalarMatrixProductGeneral::new(&reduced_sc))
314 }
315 PairMethod::MatrixScalarProduct => {
316 Box::new(MatrixScalarProduct::new(&reduced_sc))
318 }
319 PairMethod::MatrixScalarProductGeneral => {
320 Box::new(MatrixScalarProductGeneral::new(&reduced_sc))
321 }
322 PairMethod::TensordotFixedPosition => {
323 Box::new(TensordotFixedPosition::new(&reduced_sc))
325 }
326 PairMethod::TensordotGeneral => Box::new(TensordotGeneral::new(&reduced_sc)),
327 PairMethod::StackedTensordotGeneral => {
328 Box::new(StackedTensordotGeneral::new(&reduced_sc))
329 }
330 PairMethod::BroadcastProductGeneral => {
331 Box::new(BroadcastProductGeneral::new(&reduced_sc))
333 }
334 };
335 PairContraction {
336 lhs_simplification,
337 rhs_simplification,
338 method,
339 op,
340 simplified_einsum_string: reduced_sc.as_einsum_string(),
341 }
342 }
343}
344
345impl<A> PairContractor<A> for PairContraction<A> {
346 fn contract_pair<'a, 'b, 'c, 'd>(
347 &self,
348 lhs: &'b ArrayViewD<'a, A>,
349 rhs: &'d ArrayViewD<'c, A>,
350 ) -> ArrayD<A>
351 where
352 'a: 'b,
353 'c: 'd,
354 A: Clone + LinalgScalar,
355 {
356 match (&self.lhs_simplification, &self.rhs_simplification) {
357 (None, None) => self.op.contract_pair(lhs, rhs),
358 (Some(lhs_contraction), None) => self
359 .op
360 .contract_pair(&lhs_contraction.op.contract_singleton(lhs).view(), rhs),
361 (None, Some(rhs_contraction)) => self
362 .op
363 .contract_pair(lhs, &rhs_contraction.op.contract_singleton(rhs).view()),
364 (Some(lhs_contraction), Some(rhs_contraction)) => self.op.contract_pair(
365 &lhs_contraction.op.contract_singleton(lhs).view(),
366 &rhs_contraction.op.contract_singleton(rhs).view(),
367 ),
368 }
369 }
370}
371
372impl<A> Debug for PairContraction<A> {
373 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
374 write!(
375 f,
376 "PairContraction {{ \
377 lhs_simplification: {:?}, \
378 rhs_simplification: {:?}, \
379 method: {:?}, \
380 op: {:?}, \
381 simplified_einsum_string: {:?}",
382 self.lhs_simplification,
383 self.rhs_simplification,
384 self.method,
385 self.op,
386 self.simplified_einsum_string
387 )
388 }
389}
390
/// The executable steps of an [`EinsumPath`].
#[derive(Debug)]
pub enum EinsumPathSteps<A> {
    /// The whole contraction is one single-operand step.
    SingletonContraction(SingletonContraction<A>),

    /// The contraction is a sequence of pairwise steps, executed in order.
    PairContractions(Vec<PairContraction<A>>),
}
404
/// A contraction order together with the prepared steps that execute it.
pub struct EinsumPath<A> {
    /// The order the steps were built from; during execution it supplies the
    /// operand numbers for each pairwise step.
    pub contraction_order: ContractionOrder,

    /// The prepared contraction steps; expected to be of the same kind
    /// (singleton or pairs) as `contraction_order`.
    pub steps: EinsumPathSteps<A>,
}
418
419impl<A> EinsumPath<A> {
420 pub fn new(sc: &SizedContraction) -> Self {
421 let contraction_order = generate_optimized_order(sc, OptimizationMethod::Naive);
422
423 EinsumPath::from_path(&contraction_order)
424 }
425
426 pub fn from_path(contraction_order: &ContractionOrder) -> Self {
427 match contraction_order {
428 ContractionOrder::Singleton(sized_contraction) => EinsumPath {
429 contraction_order: contraction_order.clone(),
430 steps: EinsumPathSteps::SingletonContraction(SingletonContraction::new(
431 sized_contraction,
432 )),
433 },
434 ContractionOrder::Pairs(order_steps) => {
435 let mut steps = Vec::new();
436
437 for step in order_steps.iter() {
438 steps.push(PairContraction::new(&step.sized_contraction));
439 }
440
441 EinsumPath {
442 contraction_order: contraction_order.clone(),
443 steps: EinsumPathSteps::PairContractions(steps),
444 }
445 }
446 }
447 }
448}
449
450impl<A> EinsumPath<A> {
451 pub fn contract_operands(&self, operands: &[&dyn ArrayLike<A>]) -> ArrayD<A>
452 where
453 A: Clone + LinalgScalar,
454 {
455 match (&self.steps, &self.contraction_order) {
458 (EinsumPathSteps::SingletonContraction(c), ContractionOrder::Singleton(_)) => {
459 c.contract_singleton(&operands[0].into_dyn_view())
460 }
461 (EinsumPathSteps::PairContractions(steps), ContractionOrder::Pairs(order_steps)) => {
462 let mut intermediate_results: Vec<ArrayD<A>> = Vec::new();
463 for (step, order_step) in steps.iter().zip(order_steps.iter()) {
464 let lhs = match order_step.operand_nums.lhs {
465 OperandNumber::Input(pos) => operands[pos].into_dyn_view(),
466 OperandNumber::IntermediateResult(pos) => intermediate_results[pos].view(),
467 };
468 let rhs = match order_step.operand_nums.rhs {
469 OperandNumber::Input(pos) => operands[pos].into_dyn_view(),
470 OperandNumber::IntermediateResult(pos) => intermediate_results[pos].view(),
471 };
472 let intermediate_result = step.contract_pair(&lhs, &rhs);
473 intermediate_results.push(intermediate_result);
475 }
476 intermediate_results.pop().unwrap()
477 }
478 _ => panic!(), }
480 }
481}
482
483impl<A> Debug for EinsumPath<A> {
484 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
485 match &self.steps {
486 EinsumPathSteps::SingletonContraction(step) => write!(f, "only_step: {:?}", step),
487 EinsumPathSteps::PairContractions(steps) => write!(f, "steps: {:?}", steps),
488 }
489 }
490}