1use crate::*;
3use std::collections::VecDeque;
4
/// Nested description of the order in which tensors are contracted.
///
/// A tree is either a single input tensor or a group of sub-trees that are
/// contracted together.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ContractionTree {
    /// A single input tensor, identified by its position in the input list.
    Leaf(usize),
    /// A group of sub-trees that are contracted together.
    Node(Vec<ContractionTree>),
}
11
12impl From<usize> for ContractionTree {
13 fn from(value: usize) -> Self {
14 ContractionTree::Leaf(value)
15 }
16}
17
18impl From<Vec<ContractionTree>> for ContractionTree {
19 fn from(value: Vec<ContractionTree>) -> Self {
20 ContractionTree::Node(value)
21 }
22}
23
24impl From<Vec<usize>> for ContractionTree {
25 fn from(value: Vec<usize>) -> Self {
26 ContractionTree::Node(value.into_iter().map(ContractionTree::Leaf).collect())
27 }
28}
29
30pub fn tree_to_sequence(tree: &ContractionTree) -> PathType {
54 if let ContractionTree::Leaf(_) = tree {
56 return Vec::new();
57 }
58
59 let mut c: VecDeque<&ContractionTree> = VecDeque::new(); c.push_back(tree);
61
62 let mut t: Vec<usize> = Vec::new(); let mut s: VecDeque<Vec<usize>> = VecDeque::new(); while !c.is_empty() {
66 let j = c.pop_back().unwrap();
67 s.push_front(Vec::new());
68
69 if let ContractionTree::Node(children) = j {
71 let mut int_children: Vec<usize> = children
73 .iter()
74 .filter_map(|child| match child {
75 ContractionTree::Leaf(i) => Some(*i),
76 _ => None,
77 })
78 .collect();
79
80 int_children.sort_unstable();
82
83 for i in int_children {
84 let pos = t.iter().filter(|&&q| q < i).count();
86 s[0].push(pos);
87 t.insert(pos, i);
88 }
89
90 for i_tup in children.iter().filter(|child| matches!(child, ContractionTree::Node(_))) {
92 let pos = t.len() + c.len();
93 s[0].push(pos);
94 c.push_back(i_tup);
95 }
96 }
97 }
98
99 s.into_iter().collect()
100}
101
102pub fn find_disconnected_subgraphs(inputs: &[ArrayIndexType], output: &ArrayIndexType) -> Vec<BTreeSet<usize>> {
123 let mut subgraphs = Vec::new();
124 let mut unused_inputs: BTreeSet<usize> = (0..inputs.len()).collect();
125
126 let input_indices: ArrayIndexType = inputs.iter().flat_map(|set| set.iter()).cloned().collect();
128 let i_sum = &input_indices - output;
129
130 while !unused_inputs.is_empty() {
131 let mut g = BTreeSet::new();
132 let mut queue = VecDeque::new();
133
134 queue.push_back(*unused_inputs.iter().next().unwrap());
136 unused_inputs.remove(&queue[0]);
137
138 while !queue.is_empty() {
139 let j = queue.pop_front().unwrap();
140 g.insert(j);
141
142 let i_tmp: ArrayIndexType = &i_sum & &inputs[j];
144
145 let neighbors = unused_inputs.iter().filter(|&&k| !inputs[k].is_disjoint(&i_tmp)).cloned().collect_vec();
147
148 for neighbor in neighbors {
149 queue.push_back(neighbor);
150 unused_inputs.remove(&neighbor);
151 }
152 }
153 subgraphs.push(g);
154 }
155 subgraphs
156}
157
158pub fn bitmap_select<'t, T>(s: &'t BigUint, seq: &'t [T]) -> impl Iterator<Item = &'t T> {
169 let uint_1 = BigUint::from_u32(1).unwrap();
170 seq.iter().enumerate().filter(move |(i, _)| (s >> i) & &uint_1 == uint_1).map(move |(_, x)| x)
171}
172
173pub fn dp_calc_legs(
189 g: &BigUint,
190 all_tensors: &BigUint,
191 s: &BigUint,
192 inputs: &[&ArrayIndexType],
193 i1_cut_i2_wo_output: &ArrayIndexType,
194 i1_union_i2: &ArrayIndexType,
195) -> ArrayIndexType {
196 let r = g & (all_tensors ^ s);
198
199 let i_r = if r != BigUint::ZERO {
201 bitmap_select(&r, inputs).flat_map(|x| x.iter()).collect_vec().into_iter().copied().collect()
202 } else {
203 ArrayIndexType::new()
204 };
205
206 let i_contract = i1_cut_i2_wo_output - &i_r;
208 i1_union_i2 - &i_contract
209}
210
211#[derive(Debug, Clone)]
212pub struct DpTerm {
213 pub indices: ArrayIndexType,
214 pub cost: SizeType,
215 pub contract: ContractionTree,
216}
217
218pub struct DpCompareArgs<'a> {
219 pub minimize: &'a str,
221 pub combo_factor: SizeType,
222 pub inputs: &'a [&'a ArrayIndexType],
224 pub size_dict: &'a SizeDictType,
225 pub all_tensors: BigUint,
226 pub memory_limit: Option<SizeType>,
227 pub cost_cap: SizeType,
228 pub bitmap_g: BigUint,
229}
230
231impl<'a> DpCompareArgs<'a> {
232 pub fn compare_flops(
240 &self,
241 xn: &mut BTreeMap<BigUint, DpTerm>,
242 s1: &BigUint,
243 s2: &BigUint,
244 term1: &DpTerm,
245 term2: &DpTerm,
246 i1_cut_i2_wo_output: &ArrayIndexType,
247 ) {
248 let DpTerm { indices: i1, cost: cost1, contract: contract1 } = term1;
249 let DpTerm { indices: i2, cost: cost2, contract: contract2 } = term2;
250 let i1_union_i2 = i1 | i2;
251
252 let cost = cost1 + cost2 + helpers::compute_size_by_dict(i1_union_i2.iter(), self.size_dict);
253 if cost <= self.cost_cap {
254 let s = s1 | s2;
255 if xn.get(&s).is_none_or(|term| cost < term.cost) {
256 let indices =
257 dp_calc_legs(&self.bitmap_g, &self.all_tensors, &s, self.inputs, i1_cut_i2_wo_output, &i1_union_i2);
258 let mem = helpers::compute_size_by_dict(indices.iter(), self.size_dict);
259 if self.memory_limit.is_none_or(|limit| mem <= limit) {
260 let contract = vec![contract1.clone(), contract2.clone()].into();
261 xn.insert(s, DpTerm { indices, cost, contract });
262 }
263 }
264 }
265 }
266
267 pub fn compare_size(
271 &self,
272 xn: &mut BTreeMap<BigUint, DpTerm>,
273 s1: &BigUint,
274 s2: &BigUint,
275 term1: &DpTerm,
276 term2: &DpTerm,
277 i1_cut_i2_wo_output: &ArrayIndexType,
278 ) {
279 let DpTerm { indices: i1, cost: cost1, contract: contract1 } = term1;
280 let DpTerm { indices: i2, cost: cost2, contract: contract2 } = term2;
281 let i1_union_i2 = i1 | i2;
282 let s = s1 | s2;
283 let indices =
284 dp_calc_legs(&self.bitmap_g, &self.all_tensors, &s, self.inputs, i1_cut_i2_wo_output, &i1_union_i2);
285
286 let mem = helpers::compute_size_by_dict(indices.iter(), self.size_dict);
287 let cost = (*cost1).max(*cost2).max(mem);
288 if cost <= self.cost_cap
289 && xn.get(&s).is_none_or(|term| cost < term.cost)
290 && self.memory_limit.is_none_or(|limit| mem <= limit)
291 {
292 let contract = vec![contract1.clone(), contract2.clone()].into();
293 xn.insert(s, DpTerm { indices, cost, contract });
294 }
295 }
296 pub fn compare_write(
299 &self,
300 xn: &mut BTreeMap<BigUint, DpTerm>,
301 s1: &BigUint,
302 s2: &BigUint,
303 term1: &DpTerm,
304 term2: &DpTerm,
305 i1_cut_i2_wo_output: &ArrayIndexType,
306 ) {
307 let DpTerm { indices: i1, cost: cost1, contract: contract1 } = term1;
308 let DpTerm { indices: i2, cost: cost2, contract: contract2 } = term2;
309 let i1_union_i2 = i1 | i2;
310 let s = s1 | s2;
311 let indices =
312 dp_calc_legs(&self.bitmap_g, &self.all_tensors, &s, self.inputs, i1_cut_i2_wo_output, &i1_union_i2);
313
314 let mem = helpers::compute_size_by_dict(indices.iter(), self.size_dict);
315 let cost = cost1 + cost2 + mem;
316
317 if cost <= self.cost_cap
318 && xn.get(&s).is_none_or(|term| cost < term.cost)
319 && self.memory_limit.is_none_or(|limit| mem <= limit)
320 {
321 let contract = vec![contract1.clone(), contract2.clone()].into();
322 xn.insert(s, DpTerm { indices, cost, contract });
323 }
324 }
325
326 pub fn compare_combo(
329 &self,
330 xn: &mut BTreeMap<BigUint, DpTerm>,
331 s1: &BigUint,
332 s2: &BigUint,
333 term1: &DpTerm,
334 term2: &DpTerm,
335 i1_cut_i2_wo_output: &ArrayIndexType,
336 ) {
337 let DpTerm { indices: i1, cost: cost1, contract: contract1 } = term1;
338 let DpTerm { indices: i2, cost: cost2, contract: contract2 } = term2;
339 let i1_union_i2 = i1 | i2;
340 let s = s1 | s2;
341 let indices =
342 dp_calc_legs(&self.bitmap_g, &self.all_tensors, &s, self.inputs, i1_cut_i2_wo_output, &i1_union_i2);
343
344 let mem = helpers::compute_size_by_dict(indices.iter(), self.size_dict);
345 let f = helpers::compute_size_by_dict(i1_union_i2.iter(), self.size_dict);
346
347 let combined = match self.minimize {
349 "combo" => f + self.combo_factor * mem,
350 "limit" => f.max(self.combo_factor * mem),
351 _ => panic!("Unknown minimize type for combo mode: {}", self.minimize),
352 };
353 let cost = cost1 + cost2 + combined;
354
355 if cost <= self.cost_cap
356 && xn.get(&s).is_none_or(|term| cost < term.cost)
357 && self.memory_limit.is_none_or(|limit| mem <= limit)
358 {
359 let contract = vec![contract1.clone(), contract2.clone()].into();
360 xn.insert(s, DpTerm { indices, cost, contract });
361 }
362 }
363
364 pub fn scale(&self) -> SizeType {
365 get_scale_from_minimize(self.minimize)
366 }
367
368 pub fn compare(
369 &self,
370 xn: &mut BTreeMap<BigUint, DpTerm>,
371 s1: &BigUint,
372 s2: &BigUint,
373 term1: &DpTerm,
374 term2: &DpTerm,
375 i1_cut_i2_wo_output: &ArrayIndexType,
376 ) {
377 let minimize_split = self.minimize.split('-').collect_vec();
378 if minimize_split.is_empty() {
379 panic!("Unknown minimize type: {}", self.minimize);
380 }
381 match minimize_split[0] {
382 "flops" => self.compare_flops(xn, s1, s2, term1, term2, i1_cut_i2_wo_output),
383 "size" => self.compare_size(xn, s1, s2, term1, term2, i1_cut_i2_wo_output),
384 "write" => self.compare_write(xn, s1, s2, term1, term2, i1_cut_i2_wo_output),
385 "combo" | "limit" => self.compare_combo(xn, s1, s2, term1, term2, i1_cut_i2_wo_output),
386 _ => panic!("Unknown minimize type: {}", self.minimize),
387 }
388 }
389}
390
391pub fn get_scale_from_minimize(minimize: &str) -> SizeType {
392 match minimize {
393 "flops" | "size" | "write" => SizeType::one(),
394 "combo" | "limit" => SizeType::MAX,
395 _ => panic!("Unknown minimize type: {minimize}"),
396 }
397}
398
399pub fn simple_tree_tuple(seq: &[ContractionTree]) -> ContractionTree {
407 seq.iter().cloned().reduce(|left, right| ContractionTree::Node(vec![left, right])).unwrap()
408}
409use std::collections::{BTreeMap, BTreeSet};
410
411pub fn dp_parse_out_single_term_ops(
418 inputs: &[&ArrayIndexType],
419 all_inds: &[char],
420 ind_counts: &SizeDictType,
421) -> (Vec<ArrayIndexType>, Vec<ContractionTree>, Vec<ContractionTree>) {
422 let i_single: BTreeSet<char> = all_inds.iter().filter(|&c| ind_counts.get(c) == Some(&1)).cloned().collect();
423
424 let mut inputs_parsed = Vec::new();
425 let mut inputs_done = Vec::new();
426 let mut inputs_contractions = Vec::new();
427
428 for (j, input) in inputs.iter().enumerate() {
429 let i_reduced: ArrayIndexType = *input - &i_single;
430 if i_reduced.is_empty() && !input.is_empty() {
431 inputs_done.push(vec![j].into());
433 } else {
434 inputs_contractions.push(if i_reduced.len() != input.len() { vec![j].into() } else { j.into() });
436 inputs_parsed.push(i_reduced);
437 }
438 }
439
440 (inputs_parsed, inputs_done, inputs_contractions)
441}
442
443#[derive(Debug, Clone)]
444pub struct DynamicProgramming {
445 pub minimize: String,
446 pub search_outer: bool,
447 pub cost_cap: SizeLimitType,
448 pub combo_factor: SizeType,
449}
450
451impl Default for DynamicProgramming {
452 fn default() -> Self {
453 Self {
454 minimize: "flops".into(),
455 search_outer: false,
456 cost_cap: true.into(),
457 combo_factor: SizeType::from_usize(64).unwrap(),
458 }
459 }
460}
461
462impl DynamicProgramming {
463 pub fn find_optimal_path(
464 &self,
465 inputs: &[&ArrayIndexType],
466 output: &ArrayIndexType,
467 size_dict: &SizeDictType,
468 memory_limit: Option<SizeType>,
469 ) -> Result<PathType, String> {
470 let uint_1 = BigUint::from(1u32);
471 let uint_0 = BigUint::from(0u32);
472
473 let ind_counts: BTreeMap<char, usize> =
475 inputs.iter().flat_map(|inds| inds.iter()).chain(output.iter()).fold(BTreeMap::new(), |mut counts, &c| {
476 *counts.entry(c).or_default() += 1;
477 counts
478 });
479
480 let all_inds: Vec<char> = ind_counts.keys().copied().collect();
481
482 let (inputs, inputs_done, inputs_contractions) = dp_parse_out_single_term_ops(inputs, &all_inds, &ind_counts);
484 let inputs_ref = inputs.iter().collect_vec();
485
486 if inputs.is_empty() {
487 return Ok(tree_to_sequence(&simple_tree_tuple(&inputs_done)));
488 }
489
490 let mut subgraph_contractions = inputs_done;
492 let mut subgraph_sizes: Vec<SizeType> = vec![SizeType::one(); subgraph_contractions.len()];
493
494 let subgraphs = if self.search_outer {
496 vec![(0..inputs.len()).collect_vec()]
497 } else {
498 find_disconnected_subgraphs(&inputs, output).into_iter().map(|s| s.into_iter().collect()).collect()
499 };
500
501 let all_tensors = (&uint_1 << inputs.len()) - &uint_1;
502 let naive_scale = get_scale_from_minimize(&self.minimize);
503 let naive_cost = naive_scale
504 * SizeType::from_usize(inputs.len()).unwrap()
505 * size_dict.values().map(|v| SizeType::from_usize(*v).unwrap()).product::<SizeType>();
506
507 for g in subgraphs {
508 let bitmap_g = g.iter().fold(uint_0.clone(), |acc, &j| acc | (&uint_1 << j));
509
510 let mut x: Vec<BTreeMap<BigUint, DpTerm>> = vec![BTreeMap::new(); g.len() + 1];
512 x[1] = g
513 .iter()
514 .map(|&j| {
515 (&uint_1 << j, DpTerm {
516 indices: inputs[j].clone(),
517 cost: SizeType::zero(),
518 contract: inputs_contractions[j].clone(),
519 })
520 })
521 .collect();
522
523 let subgraph_inds = bitmap_select(&bitmap_g, &inputs).flat_map(|inds| inds.iter().copied()).collect();
525
526 let mut cost_cap = match self.cost_cap {
527 SizeLimitType::Size(cap) => cap,
528 SizeLimitType::None => SizeType::MAX,
529 SizeLimitType::MaxInput => helpers::compute_size_by_dict((&subgraph_inds & output).iter(), size_dict),
530 };
531
532 let cost_increment = if subgraph_inds.is_empty() {
533 SizeType::from_usize(2).unwrap()
534 } else {
535 subgraph_inds
536 .iter()
537 .map(|c| size_dict[c] as SizeType)
538 .fold(SizeType::MAX, SizeType::min)
539 .max(SizeType::from_usize(2).unwrap())
540 };
541
542 let mut dp_comp_args = DpCompareArgs {
543 inputs: &inputs_ref,
544 size_dict,
545 all_tensors: all_tensors.clone(),
546 memory_limit,
547 cost_cap,
548 bitmap_g,
549 combo_factor: self.combo_factor,
550 minimize: &self.minimize,
551 };
552
553 fn has_common_bits(s1: &BigUint, s2: &BigUint) -> bool {
554 let digits1 = s1.iter_u64_digits();
555 let digits2 = s2.iter_u64_digits();
556 digits1.zip(digits2).any(|(d1, d2)| d1 & d2 != 0)
557 }
558
559 while x.last().unwrap().is_empty() {
560 for n in 2..=g.len() {
561 let (xn_left, xn_right) = x.split_at_mut(n);
562 let xn = &mut xn_right[0];
563 for m in 1..=(n / 2) {
564 for (s1, term1) in &xn_left[m] {
565 for (s2, term2) in &xn_left[n - m] {
566 if !has_common_bits(s1, s2) && (m != n - m || s1 < s2) {
568 let i1 = &term1.indices;
569 let i2 = &term2.indices;
570 let i1_cut_i2_wo_output: ArrayIndexType = i1
573 .iter()
574 .filter(|&&c| i2.contains(&c) && !output.contains(&c))
575 .cloned()
576 .collect();
577 if self.search_outer || !i1_cut_i2_wo_output.is_empty() {
578 dp_comp_args.compare(xn, s1, s2, term1, term2, &i1_cut_i2_wo_output);
579 }
580 }
581 }
582 }
583 }
584 }
585
586 cost_cap = match cost_cap >= SizeType::MAX / cost_increment {
588 true => SizeType::MAX,
589 false => cost_cap * cost_increment,
590 };
591 dp_comp_args.cost_cap = cost_cap;
592
593 if cost_cap > naive_cost && x.last().unwrap().is_empty() {
594 return Err("No contraction found for given memory_limit".into());
595 }
596 }
597
598 let (_, term) = x.last().unwrap().iter().next().unwrap();
599 subgraph_contractions.push(term.contract.clone());
600 subgraph_sizes.push(helpers::compute_size_by_dict(term.indices.iter(), size_dict));
601 }
602
603 let sorted_indices =
605 (0..subgraph_sizes.len()).sorted_by(|&a, &b| subgraph_sizes[a].partial_cmp(&subgraph_sizes[b]).unwrap());
606 let sorted_contractions = sorted_indices.map(|i| subgraph_contractions[i].clone()).collect_vec();
607
608 Ok(tree_to_sequence(&simple_tree_tuple(&sorted_contractions)))
609 }
610}
611
612impl PathOptimizer for DynamicProgramming {
613 fn optimize_path(
614 &mut self,
615 inputs: &[&ArrayIndexType],
616 output: &ArrayIndexType,
617 size_dict: &SizeDictType,
618 memory_limit: Option<SizeType>,
619 ) -> Result<PathType, String> {
620 self.find_optimal_path(inputs, output, size_dict, memory_limit)
621 }
622}
623
624impl From<&str> for DynamicProgramming {
625 fn from(s: &str) -> Self {
626 let s = s.replace(['_', ' '], "-").to_lowercase();
627 if s == "dp" || s == "dynamic-programming" {
628 return DynamicProgramming::default();
629 }
630 if s.starts_with("dp-") {
631 let minimize = s.strip_prefix("dp-").unwrap();
632 if minimize.starts_with("combo") || minimize.starts_with("limit") {
634 let minimize_split = minimize.split('-').collect_vec();
635 if minimize_split.len() > 2 {
636 panic!("Unknown dynamic programming optimizer: {s}");
637 }
638 match minimize_split.len() {
639 1 => {
640 let minimize = minimize_split[0];
641 if minimize != "combo" && minimize != "limit" {
642 panic!("Unknown dynamic programming optimizer: {s}");
643 }
644 return DynamicProgramming { minimize: minimize.into(), ..Default::default() };
645 },
646 2 => {
647 let minimize = minimize_split[0];
648 if minimize != "combo" && minimize != "limit" {
649 panic!("Unknown dynamic programming optimizer: {s}");
650 }
651 let combo_factor = match minimize_split[1].parse::<SizeType>() {
652 Ok(factor) => factor,
653 Err(_) => panic!("Invalid combo factor in dynamic programming optimizer: {s}"),
654 };
655 return DynamicProgramming { minimize: minimize.into(), combo_factor, ..Default::default() };
656 },
657 _ => panic!("Unknown dynamic programming optimizer: {s}"),
658 };
659 } else if minimize == "flops" || minimize == "size" || minimize == "write" {
660 return DynamicProgramming { minimize: minimize.into(), ..Default::default() };
661 } else {
662 panic!("Unknown dynamic programming optimizer: {s}");
663 }
664 }
665 panic!("Unknown dynamic programming optimizer: {s}");
666 }
667}
668
/// Checks the conversion of `((1, 2), (0, (4, 5, 3)))` into a contraction path.
#[test]
fn test_tree_to_sequence() {
    let left = ContractionTree::from(vec![1, 2]);
    let right: ContractionTree = vec![ContractionTree::from(0), ContractionTree::from(vec![4, 5, 3])].into();
    let tree = ContractionTree::from(vec![left, right]);

    let path = tree_to_sequence(&tree);
    println!("{path:?}");

    let expected = vec![vec![1, 2], vec![1, 2, 3], vec![0, 2], vec![0, 1]];
    assert_eq!(path, expected);
}
680
/// Checks the grouping of inputs for two different outputs.
#[test]
fn test_find_disconnected_subgraphs() {
    use crate::helpers::setify;

    // `a` is summed over, so the first and third input are connected.
    let inputs1 = vec![setify("ab"), setify("c"), setify("ad")];
    let output1 = setify("bd");
    let expected1 = vec![setify([0, 2]), setify([1])];
    assert_eq!(find_disconnected_subgraphs(&inputs1, &output1), expected1);

    // `a` is kept in the output, so no two inputs are connected.
    let inputs2 = vec![setify("ab"), setify("c"), setify("ad")];
    let output2 = setify("abd");
    let expected2 = vec![setify([0]), setify([1]), setify([2])];
    assert_eq!(find_disconnected_subgraphs(&inputs2, &output2), expected2);
}
696
/// Checks that exactly the elements at the set bit positions are selected.
#[test]
fn test_bitmap_select() {
    use crate::helpers::setify;

    let seq = vec![setify("A"), setify("B"), setify("C"), setify("D"), setify("E")];

    // Bits 1, 3 and 4 are set.
    let mask = BigUint::from(0b11010_u32);
    let picked = bitmap_select(&mask, &seq).collect_vec();
    assert_eq!(picked, vec![&setify("B"), &setify("D"), &setify("E")]);

    let none = BigUint::from(0b00000_u32);
    assert_eq!(bitmap_select(&none, &seq).count(), 0);

    let all = BigUint::from(0b11111_u32);
    assert_eq!(bitmap_select(&all, &seq).count(), 5);

    let first = BigUint::from(0b00001_u32);
    assert_eq!(bitmap_select(&first, &seq).collect_vec(), vec![&setify("A")]);
}
712
/// Checks that four leaves are chained into `(((1, 2), 3), 4)`.
#[test]
fn test_simple_tree_tuple() {
    let leaves: Vec<ContractionTree> = (1..=4).map(ContractionTree::Leaf).collect();
    let tree = simple_tree_tuple(&leaves);

    let one_two = ContractionTree::Node(vec![ContractionTree::Leaf(1), ContractionTree::Leaf(2)]);
    let one_two_three = ContractionTree::Node(vec![one_two, ContractionTree::Leaf(3)]);
    let expected = ContractionTree::Node(vec![one_two_three, ContractionTree::Leaf(4)]);

    assert_eq!(tree, expected);
}