opt_einsum_path/paths/
branch_bound.rs

1use crate::*;
2
3pub type BetterFn = fn(SizeType, SizeType, SizeType, SizeType) -> bool;
4
/// Selects which quantity a path comparison looks at first.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MinimizeStrategy {
    /// Compare FLOP counts first; size only breaks ties.
    FlopsFirst,
    /// Compare sizes first; FLOP count only breaks ties.
    SizeFirst,
}
10
11/// functions for comparing which of two paths is 'better'.
12pub fn get_better_fn(key: MinimizeStrategy) -> BetterFn {
13    match key {
14        MinimizeStrategy::FlopsFirst => {
15            |flops, size, best_flops, best_size| flops < best_flops || (flops == best_flops && size < best_size)
16        },
17        MinimizeStrategy::SizeFirst => {
18            |flops, size, best_flops, best_size| size < best_size || (size == best_size && flops < best_flops)
19        },
20    }
21}
22
23#[derive(Debug, Clone)]
24pub struct BranchBoundBest {
25    pub flops: SizeType,
26    pub size: SizeType,
27    pub ssa_path: Option<PathType>,
28}
29
30impl Default for BranchBoundBest {
31    fn default() -> Self {
32        Self { flops: SizeType::MAX, size: SizeType::MAX, ssa_path: None }
33    }
34}
35
36#[derive(Debug, Clone, PartialEq, PartialOrd)]
37pub struct BranchBoundCandidate {
38    cost: SizeType,
39    flops12: SizeType,
40    new_flops: SizeType,
41    new_size: SizeType,
42    pair: (usize, usize),
43    k12: ArrayIndexType,
44}
45
46impl Eq for BranchBoundCandidate {}
47
48#[allow(clippy::derive_ord_xor_partial_ord)]
49impl Ord for BranchBoundCandidate {
50    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
51        self.cost.partial_cmp(&other.cost).unwrap()
52    }
53}
54
55#[derive(Debug, Clone)]
56pub struct BranchBound {
57    // inputs
58    inputs: Vec<ArrayIndexType>,
59    output: ArrayIndexType,
60    size_dict: SizeDictType,
61    memory_limit: Option<SizeType>,
62    // parameters
63    pub nbranch: Option<usize>,
64    pub cutoff_flops_factor: SizeType,
65    pub better_fn: BetterFn,
66    pub cost_fn: paths::CostFn,
67    // caches
68    pub best: BranchBoundBest,
69    pub best_progress: BTreeMap<usize, SizeType>,
70    size_cache: BTreeMap<ArrayIndexType, SizeType>,
71}
72
73impl Default for BranchBound {
74    fn default() -> Self {
75        Self {
76            // inputs
77            inputs: Vec::new(),
78            output: ArrayIndexType::default(),
79            size_dict: SizeDictType::default(),
80            memory_limit: None,
81            // parameters
82            nbranch: None,
83            cutoff_flops_factor: SizeType::from_f64(4.0).unwrap(),
84            better_fn: get_better_fn(MinimizeStrategy::FlopsFirst),
85            cost_fn: paths::util::memory_removed(false),
86            // caches
87            best: BranchBoundBest::default(),
88            best_progress: BTreeMap::new(),
89            size_cache: BTreeMap::new(),
90        }
91    }
92}
93
94impl BranchBound {
95    pub fn path(&self) -> PathType {
96        paths::util::ssa_to_linear(self.best.ssa_path.as_ref().unwrap_or(&Vec::new()))
97    }
98}
99
100impl BranchBound {
101    #[allow(clippy::too_many_arguments)]
102    fn assess_candidate(
103        &mut self,
104        k1: &ArrayIndexType,
105        k2: &ArrayIndexType,
106        i: usize,
107        j: usize,
108        path: &[TensorShapeType],
109        inputs: &[&ArrayIndexType],
110        remaining: &[usize],
111        flops: SizeType,
112        size: SizeType,
113    ) -> Option<BranchBoundCandidate> {
114        // find resulting indices and flops
115        let (k12, flops12) = paths::util::calc_k12_flops(inputs, &self.output, remaining, i, j, &self.size_dict);
116
117        let size12 = *self
118            .size_cache
119            .entry(k12.clone())
120            .or_insert_with(|| helpers::compute_size_by_dict(k12.iter(), &self.size_dict));
121
122        let new_flops = flops + flops12;
123        let new_size = size.max(size12);
124
125        // sieve based on current best i.e. check flops and size still better
126        if !(self.better_fn)(new_flops, new_size, self.best.flops, self.best.size) {
127            return None;
128        }
129
130        let inputs_len = inputs.len();
131        let best_progress = self.best_progress.entry(inputs_len).or_insert(SizeType::MAX);
132        if new_flops < *best_progress {
133            // compare to how the best method was doing as this point
134            *best_progress = new_flops;
135        } else if new_flops > self.cutoff_flops_factor * *best_progress {
136            // sieve based on current progress relative to best
137            return None;
138        }
139
140        // sieve based on memory limit
141        if let Some(limit) = self.memory_limit {
142            if size12 > limit {
143                // terminate path here, but check all-terms contract first
144                let oversize_flops =
145                    flops + paths::util::compute_oversize_flops(inputs, remaining, &self.output, &self.size_dict);
146                if oversize_flops < self.best.flops {
147                    self.best.flops = oversize_flops;
148                    let mut new_path = path.to_vec();
149                    new_path.push(remaining.to_vec());
150                    self.best.ssa_path = Some(new_path);
151                }
152                return None;
153            }
154        }
155
156        // Calculate cost heuristic
157        let size1 = self.size_cache[k1];
158        let size2 = self.size_cache[k2];
159        let cost = (self.cost_fn)(size12, size1, size2, 0, 0, 0);
160
161        Some(BranchBoundCandidate { cost, flops12, new_flops, new_size, pair: (i, j), k12 })
162    }
163
164    #[allow(clippy::too_many_arguments)]
165    #[allow(clippy::type_complexity)]
166    fn branch_iterate(
167        &mut self,
168        path: &[TensorShapeType],
169        inputs: &[&ArrayIndexType],
170        remaining: Vec<usize>,
171        flops: SizeType,
172        size: SizeType,
173    ) {
174        // Reached end of path (only get here if flops is best found so far)
175        if remaining.len() == 1 {
176            self.best.flops = flops;
177            self.best.size = size;
178            self.best.ssa_path = Some(path.to_vec());
179            return;
180        }
181
182        // Check all possible remaining paths
183        let mut candidates = BinaryHeap::new();
184        for (i, j) in remaining.iter().tuple_combinations() {
185            let (i, j) = if i < j { (*i, *j) } else { (*j, *i) };
186            let k1 = &inputs[i];
187            let k2 = &inputs[j];
188
189            // Initially ignore outer products
190            if k1.is_disjoint(k2) {
191                continue;
192            }
193
194            if let Some(candidate) = self.assess_candidate(k1, k2, i, j, path, inputs, &remaining, flops, size) {
195                candidates.push(Reverse(candidate));
196            }
197        }
198
199        // Assess outer products if nothing left
200        if candidates.is_empty() {
201            for (i, j) in remaining.iter().tuple_combinations() {
202                let (i, j) = if i < j { (*i, *j) } else { (*j, *i) };
203                let k1 = &inputs[i];
204                let k2 = &inputs[j];
205
206                if let Some(candidate) = self.assess_candidate(k1, k2, i, j, path, inputs, &remaining, flops, size) {
207                    candidates.push(Reverse(candidate));
208                }
209            }
210        }
211
212        // Recurse into all or some of the best candidate contractions
213        let mut bi = 0;
214        while (self.nbranch.is_none() || bi < self.nbranch.unwrap()) && !candidates.is_empty() {
215            let Reverse(candidate) = candidates.pop().unwrap();
216            let BranchBoundCandidate { new_flops, new_size, pair: (i, j), k12, .. } = candidate;
217
218            let mut new_remaining = remaining.clone();
219            new_remaining.retain(|&x| x != i && x != j);
220            new_remaining.push(inputs.len());
221
222            let mut new_inputs = inputs.to_vec();
223            new_inputs.push(&k12);
224
225            let mut new_path = path.to_vec();
226            new_path.push(vec![i, j]);
227
228            self.branch_iterate(&new_path, &new_inputs, new_remaining, new_flops, new_size);
229
230            bi += 1;
231        }
232    }
233
234    fn branch_bound(
235        &mut self,
236        inputs: &[&ArrayIndexType],
237        output: &ArrayIndexType,
238        size_dict: &SizeDictType,
239        memory_limit: Option<SizeType>,
240    ) -> Result<PathType, String> {
241        // Reset best state for new optimization
242        self.best = BranchBoundBest::default();
243        self.best_progress.clear();
244
245        // Prepare caches
246        self.size_cache =
247            inputs.iter().map(|&k| (k.clone(), helpers::compute_size_by_dict(k.iter(), size_dict))).collect();
248
249        // Convert inputs to Vec of owned sets for easier manipulation
250        self.inputs = inputs.iter().map(|s| (*s).clone()).collect();
251        self.output = output.clone();
252        self.size_dict = size_dict.clone();
253        self.memory_limit = memory_limit;
254
255        // Start the recursive process
256        let inputs_len = inputs.len();
257        self.branch_iterate(&Vec::new(), inputs, (0..inputs_len).collect(), SizeType::zero(), SizeType::zero());
258
259        Ok(self.path())
260    }
261}
262
263impl PathOptimizer for BranchBound {
264    fn optimize_path(
265        &mut self,
266        inputs: &[&ArrayIndexType],
267        output: &ArrayIndexType,
268        size_dict: &SizeDictType,
269        memory_limit: Option<SizeType>,
270    ) -> Result<PathType, String> {
271        self.branch_bound(inputs, output, size_dict, memory_limit)
272    }
273}
274
275impl From<&str> for BranchBound {
276    fn from(s: &str) -> Self {
277        match s.replace(['_', ' '], "-").to_lowercase().as_str() {
278            "branch-all" => BranchBound::default(),
279            "branch-1" => BranchBound { nbranch: Some(1), ..BranchBound::default() },
280            "branch-2" => BranchBound { nbranch: Some(2), ..BranchBound::default() },
281            _ => panic!("Unknown branch bound kind: {s}"),
282        }
283    }
284}
285
286/* #endregion */