opt_einsum_path/paths/
optimal.rs1use crate::*;
2
3#[derive(Debug, Clone, Default)]
4pub struct Optimal {
5 output: ArrayIndexType,
7 size_dict: SizeDictType,
8 memory_limit: Option<SizeType>,
9 best_flops: Option<SizeType>,
11 best_ssa_path: PathType,
12 size_cache: BTreeMap<ArrayIndexType, SizeType>,
13}
14
15impl Optimal {
16 fn optimal_iterate(&mut self, path: PathType, remaining: &[usize], inputs: &[&ArrayIndexType], flops: SizeType) {
17 if remaining.len() == 1 {
19 self.best_flops = Some(flops);
20 self.best_ssa_path = path;
21 return;
22 }
23
24 for i in 0..remaining.len() {
26 for j in (i + 1)..remaining.len() {
27 let a = remaining[i];
28 let b = remaining[j];
29 let (i, j) = if a < b { (a, b) } else { (b, a) };
30
31 let (k12, flops12) =
32 paths::util::calc_k12_flops(inputs, &self.output, remaining, i, j, &self.size_dict);
33
34 let new_flops = flops + flops12;
36 if self.best_flops.is_some_and(|best| new_flops >= best) {
37 continue;
38 }
39
40 if let Some(limit) = self.memory_limit {
42 let size12 = self
43 .size_cache
44 .entry(k12.clone())
45 .or_insert_with(|| helpers::compute_size_by_dict(k12.iter(), &self.size_dict));
46
47 if *size12 > limit {
49 let oversize_flops = flops
50 + paths::util::compute_oversize_flops(inputs, remaining, &self.output, &self.size_dict);
51 if self.best_flops.is_none_or(|best| oversize_flops < best) {
52 self.best_flops = Some(oversize_flops);
53 let mut new_path = path.clone();
54 new_path.push(remaining.to_vec());
55 self.best_ssa_path = new_path;
56 }
57 continue;
58 }
59 }
60
61 let mut new_remaining = remaining.to_vec();
63 new_remaining.retain(|&x| x != i && x != j);
64 new_remaining.push(inputs.len());
65 let mut new_inputs = inputs.to_vec();
66 new_inputs.push(&k12);
67
68 let mut new_path = path.clone();
69 new_path.push(vec![i, j]);
70
71 self.optimal_iterate(new_path, &new_remaining, &new_inputs, new_flops);
72 }
73 }
74 }
75}
76
77pub fn optimal(
117 inputs: &[&ArrayIndexType],
118 output: &ArrayIndexType,
119 size_dict: &SizeDictType,
120 memory_limit: Option<SizeType>,
121) -> Result<PathType, String> {
122 let mut optimizer =
123 Optimal { output: output.clone(), size_dict: size_dict.clone(), memory_limit, ..Default::default() };
124
125 let path = Vec::new();
126 let inputs_indices: Vec<usize> = (0..inputs.len()).collect();
127 let flops = SizeType::zero();
128 optimizer.optimal_iterate(path, &inputs_indices, inputs, flops);
129 Ok(paths::util::ssa_to_linear(&optimizer.best_ssa_path))
130}
131
132impl PathOptimizer for Optimal {
133 fn optimize_path(
134 &mut self,
135 inputs: &[&ArrayIndexType],
136 output: &ArrayIndexType,
137 size_dict: &SizeDictType,
138 memory_limit: Option<SizeType>,
139 ) -> Result<PathType, String> {
140 optimal(inputs, output, size_dict, memory_limit)
141 }
142}
143
/// Small four-label example with a memory limit of 5000; pins the expected contraction order.
#[test]
fn playground() {
    use std::collections::BTreeMap;

    let timer = std::time::Instant::now();
    let abd: ArrayIndexType = "abd".chars().collect();
    let ac: ArrayIndexType = "ac".chars().collect();
    let bdc: ArrayIndexType = "bdc".chars().collect();
    let output: ArrayIndexType = "".chars().collect();
    let size_dict = BTreeMap::from([('a', 1), ('b', 2), ('c', 3), ('d', 4)]);
    let limit = SizeType::from_usize(5000).unwrap();

    let path = optimal(&[&abd, &ac, &bdc], &output, &size_dict, Some(limit)).unwrap();
    assert_eq!(path, vec![vec![0, 2], vec![0, 1]]);

    let duration = timer.elapsed();
    println!("Optimal path found in: {duration:?}");
}
156
/// Regression case with one large contracted index (`k`).
///
/// Previously this test only printed the result, so it could never fail. It now requires
/// the search to succeed and to return the cheapest order.
#[test]
fn playground_issue() {
    use std::collections::BTreeMap;
    let time = std::time::Instant::now();
    let inputs = [&"bgk".chars().collect(), &"bkd".chars().collect(), &"bk".chars().collect()];
    let output = "bgd".chars().collect();
    let size_dict = BTreeMap::from([('b', 64), ('g', 8), ('k', 4096), ('d', 128)]);
    let path = optimal(&inputs, &output, &size_dict, None).expect("optimal path search failed");
    println!("{path:?}");
    // Under the opt_einsum flop model (the same model that yields the `playground` expectation),
    // the costs are:
    //   `bgk`·`bk` first, then the `bkd` product: ~5.39e8 flops
    //   `bkd`·`bk` first:                          ~5.70e8 flops
    //   `bgk`·`bkd` first:                         ~8.05e8 flops
    // NOTE(review): confirm this matches the issue the test was added for.
    assert_eq!(path, vec![vec![0, 2], vec![0, 1]]);
    let duration = time.elapsed();
    println!("Optimal path found in: {duration:?}");
}