// opt_einsum_path/paths/optimal.rs

1use crate::*;
2
3#[derive(Debug, Clone, Default)]
4pub struct Optimal {
5    // const after build
6    output: ArrayIndexType,
7    size_dict: SizeDictType,
8    memory_limit: Option<SizeType>,
9    // mutable during iteration
10    best_flops: Option<SizeType>,
11    best_ssa_path: PathType,
12    size_cache: BTreeMap<ArrayIndexType, SizeType>,
13}
14
15impl Optimal {
16    fn optimal_iterate(&mut self, path: PathType, remaining: &[usize], inputs: &[&ArrayIndexType], flops: SizeType) {
17        // Reached end of path (only get here if flops is best found so far)
18        if remaining.len() == 1 {
19            self.best_flops = Some(flops);
20            self.best_ssa_path = path;
21            return;
22        }
23
24        // Generate all possible pairs
25        for i in 0..remaining.len() {
26            for j in (i + 1)..remaining.len() {
27                let a = remaining[i];
28                let b = remaining[j];
29                let (i, j) = if a < b { (a, b) } else { (b, a) };
30
31                let (k12, flops12) =
32                    paths::util::calc_k12_flops(inputs, &self.output, remaining, i, j, &self.size_dict);
33
34                // Sieve based on current best flops
35                let new_flops = flops + flops12;
36                if self.best_flops.is_some_and(|best| new_flops >= best) {
37                    continue;
38                }
39
40                // Sieve based on memory limit
41                if let Some(limit) = self.memory_limit {
42                    let size12 = self
43                        .size_cache
44                        .entry(k12.clone())
45                        .or_insert_with(|| helpers::compute_size_by_dict(k12.iter(), &self.size_dict));
46
47                    // Possibly terminate this path with an all-terms einsum
48                    if *size12 > limit {
49                        let oversize_flops = flops
50                            + paths::util::compute_oversize_flops(inputs, remaining, &self.output, &self.size_dict);
51                        if self.best_flops.is_none_or(|best| oversize_flops < best) {
52                            self.best_flops = Some(oversize_flops);
53                            let mut new_path = path.clone();
54                            new_path.push(remaining.to_vec());
55                            self.best_ssa_path = new_path;
56                        }
57                        continue;
58                    }
59                }
60
61                // Add contraction and recurse
62                let mut new_remaining = remaining.to_vec();
63                new_remaining.retain(|&x| x != i && x != j);
64                new_remaining.push(inputs.len());
65                let mut new_inputs = inputs.to_vec();
66                new_inputs.push(&k12);
67
68                let mut new_path = path.clone();
69                new_path.push(vec![i, j]);
70
71                self.optimal_iterate(new_path, &new_remaining, &new_inputs, new_flops);
72            }
73        }
74    }
75}
76
77/// Computes all possible pair contractions in a depth-first recursive manner,
78/// sieving results based on `memory_limit` and the best path found so far.
79///
80/// # Parameters
81///
82/// - `inputs`: List of sets that represent the lhs side of the einsum subscript
83/// - `output`: Set that represents the rhs side of the overall einsum subscript
84/// - `size_dict`: Dictionary of index sizes
85/// - `memory_limit`: The maximum number of elements in a temporary array
86///
87/// # Returns
88///
89/// The optimal contraction order within the memory limit constraint
90///
91/// # Example
92///
93/// ```rust
94/// # use std::collections::BTreeMap;
95/// # use opt_einsum_path::typing::*;
96/// # use num::FromPrimitive;
97/// # use opt_einsum_path::paths::optimal::optimal;
98/// use opt_einsum_path::helpers::setify;
99/// let inputs = [&setify("abd"), &setify("ac"), &setify("bdc")];
100/// let output = setify("");
101/// let size_dict = BTreeMap::from([('a', 1), ('b', 2), ('c', 3), ('d', 4)]);
102/// let path = optimal(&inputs, &output, &size_dict, Some(5000.0)).unwrap();
103/// assert_eq!(path, vec![vec![0, 2], vec![0, 1]]);
104/// ```
105///
106/// Python equivalent:
107///
108/// ```python
109/// >>> from opt_einsum.paths import optimal
110/// >>> isets = [set('abd'), set('ac'), set('bdc')]
111/// >>> oset = set('')
112/// >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
113/// >>> optimal(isets, oset, idx_sizes, 5000)
114/// [(0, 2), (0, 1)]
115/// ```
116pub fn optimal(
117    inputs: &[&ArrayIndexType],
118    output: &ArrayIndexType,
119    size_dict: &SizeDictType,
120    memory_limit: Option<SizeType>,
121) -> Result<PathType, String> {
122    let mut optimizer =
123        Optimal { output: output.clone(), size_dict: size_dict.clone(), memory_limit, ..Default::default() };
124
125    let path = Vec::new();
126    let inputs_indices: Vec<usize> = (0..inputs.len()).collect();
127    let flops = SizeType::zero();
128    optimizer.optimal_iterate(path, &inputs_indices, inputs, flops);
129    Ok(paths::util::ssa_to_linear(&optimizer.best_ssa_path))
130}
131
132impl PathOptimizer for Optimal {
133    fn optimize_path(
134        &mut self,
135        inputs: &[&ArrayIndexType],
136        output: &ArrayIndexType,
137        size_dict: &SizeDictType,
138        memory_limit: Option<SizeType>,
139    ) -> Result<PathType, String> {
140        optimal(inputs, output, size_dict, memory_limit)
141    }
142}
143
/// Checks the three-operand example `abd,ac,bdc->` with a memory limit of 5000
/// and prints how long the search took.
#[test]
fn playground() {
    use std::collections::BTreeMap;
    let time = std::time::Instant::now();

    // Operand index sets, built from their subscript strings.
    let abd: ArrayIndexType = "abd".chars().collect();
    let ac: ArrayIndexType = "ac".chars().collect();
    let bdc: ArrayIndexType = "bdc".chars().collect();
    let inputs = [&abd, &ac, &bdc];
    let output: ArrayIndexType = "".chars().collect();
    let size_dict = BTreeMap::from([('a', 1), ('b', 2), ('c', 3), ('d', 4)]);

    let limit = SizeType::from_usize(5000).unwrap();
    let path = optimal(&inputs, &output, &size_dict, Some(limit)).unwrap();
    assert_eq!(path, vec![vec![0, 2], vec![0, 1]]);

    let duration = time.elapsed();
    println!("Optimal path found in: {duration:?}");
}
156
/// Runs the search for `bgk,bkd,bk->bgd` with large index sizes and no memory limit.
///
/// Besides printing the result and the elapsed time, the test checks the shape of the
/// returned path so that it can fail on something other than a panic.
#[test]
fn playground_issue() {
    use std::collections::BTreeMap;
    let time = std::time::Instant::now();

    let bgk: ArrayIndexType = "bgk".chars().collect();
    let bkd: ArrayIndexType = "bkd".chars().collect();
    let bk: ArrayIndexType = "bk".chars().collect();
    let inputs = [&bgk, &bkd, &bk];
    let output: ArrayIndexType = "bgd".chars().collect();
    let size_dict = BTreeMap::from([('b', 64), ('g', 8), ('k', 4096), ('d', 128)]);

    let path = optimal(&inputs, &output, &size_dict, None);
    println!("{path:?}");

    // Without a memory limit the search can only finish through pairwise contractions,
    // so three operands must be reduced in exactly two two-operand steps.
    let steps = path.expect("optimal search should succeed when no memory limit is set");
    assert_eq!(steps.len(), 2, "three operands need exactly two pairwise contractions");
    assert!(steps.iter().all(|step| step.len() == 2), "every step must contract exactly two operands");

    let duration = time.elapsed();
    println!("Optimal path found in: {duration:?}");
}