1use crate::*;
2
/// Function-pointer type of the "is this better than the current best?" predicate.
///
/// The predicates built by [`get_better_fn`] take their arguments in the order
/// `(flops, size, best_flops, best_size)` and return `true` when the
/// `(flops, size)` pair should replace `(best_flops, best_size)`.
pub type BetterFn = fn(SizeType, SizeType, SizeType, SizeType) -> bool;
4
/// Selects which quantity [`get_better_fn`] compares first.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MinimizeStrategy {
    /// Compare flops first; equal flops are broken by the smaller size.
    FlopsFirst,
    /// Compare size first; equal sizes are broken by the smaller flops.
    SizeFirst,
}
10
11pub fn get_better_fn(key: MinimizeStrategy) -> BetterFn {
13 match key {
14 MinimizeStrategy::FlopsFirst => {
15 |flops, size, best_flops, best_size| flops < best_flops || (flops == best_flops && size < best_size)
16 },
17 MinimizeStrategy::SizeFirst => {
18 |flops, size, best_flops, best_size| size < best_size || (size == best_size && flops < best_flops)
19 },
20 }
21}
22
/// Best contraction found so far by a [`BranchBound`] search.
#[derive(Debug, Clone)]
pub struct BranchBoundBest {
    /// Flops of the best path recorded so far.
    pub flops: SizeType,
    /// Size recorded together with the best complete path.
    pub size: SizeType,
    /// The best path in SSA form, or `None` while no path has been recorded.
    pub ssa_path: Option<PathType>,
}
29
30impl Default for BranchBoundBest {
31 fn default() -> Self {
32 Self { flops: SizeType::MAX, size: SizeType::MAX, ssa_path: None }
33 }
34}
35
/// One assessed pairwise contraction, queued for exploration by the search.
///
/// NOTE: the derived `PartialEq`/`PartialOrd` compare every field in
/// declaration order, whereas the hand-written `Ord` impl compares `cost` only.
#[derive(Debug, Clone, PartialEq, PartialOrd)]
pub struct BranchBoundCandidate {
    /// Value returned by the optimizer's `cost_fn` for this contraction.
    cost: SizeType,
    /// Flops of this single contraction, as returned by `calc_k12_flops`.
    flops12: SizeType,
    /// Accumulated flops of the path including this contraction.
    new_flops: SizeType,
    /// Maximum of the path's previous size and this contraction's result size.
    new_size: SizeType,
    /// The contracted input positions `(i, j)`, with `i < j`.
    pair: (usize, usize),
    /// Indices of the contraction result.
    k12: ArrayIndexType,
}
45
46impl Eq for BranchBoundCandidate {}
47
48#[allow(clippy::derive_ord_xor_partial_ord)]
49impl Ord for BranchBoundCandidate {
50 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
51 self.cost.partial_cmp(&other.cost).unwrap()
52 }
53}
54
/// Branch-and-bound contraction-path optimizer.
///
/// The private problem fields (`inputs`, `output`, `size_dict`, `memory_limit`)
/// and the search state (`best`, `best_progress`, `size_cache`) are overwritten
/// at the start of every `branch_bound` run.
#[derive(Debug, Clone)]
pub struct BranchBound {
    /// Owned copy of the input index sets of the current problem.
    inputs: Vec<ArrayIndexType>,
    /// Output index set of the current problem.
    output: ArrayIndexType,
    /// Index-size dictionary of the current problem.
    size_dict: SizeDictType,
    /// Optional upper bound on the size of an intermediate result.
    memory_limit: Option<SizeType>,
    /// Maximum number of candidates explored at each search level; `None`
    /// explores every candidate.
    pub nbranch: Option<usize>,
    /// A candidate is dropped when its accumulated flops exceed this factor
    /// times the lowest flops seen so far at the same search depth.
    pub cutoff_flops_factor: SizeType,
    /// Predicate deciding whether `(flops, size)` beats the current best.
    pub better_fn: BetterFn,
    /// Heuristic used to rank candidates; the lowest cost is explored first.
    pub cost_fn: paths::CostFn,
    /// Best path found so far.
    pub best: BranchBoundBest,
    /// Lowest accumulated flops seen so far, keyed by the length of the
    /// `inputs` slice at the time of assessment.
    pub best_progress: BTreeMap<usize, SizeType>,
    /// Cache of sizes keyed by index set, filled for the inputs and for every
    /// intermediate that has been assessed.
    size_cache: BTreeMap<ArrayIndexType, SizeType>,
}
72
73impl Default for BranchBound {
74 fn default() -> Self {
75 Self {
76 inputs: Vec::new(),
78 output: ArrayIndexType::default(),
79 size_dict: SizeDictType::default(),
80 memory_limit: None,
81 nbranch: None,
83 cutoff_flops_factor: SizeType::from_f64(4.0).unwrap(),
84 better_fn: get_better_fn(MinimizeStrategy::FlopsFirst),
85 cost_fn: paths::util::memory_removed(false),
86 best: BranchBoundBest::default(),
88 best_progress: BTreeMap::new(),
89 size_cache: BTreeMap::new(),
90 }
91 }
92}
93
94impl BranchBound {
95 pub fn path(&self) -> PathType {
96 paths::util::ssa_to_linear(self.best.ssa_path.as_ref().unwrap_or(&Vec::new()))
97 }
98}
99
100impl BranchBound {
101 #[allow(clippy::too_many_arguments)]
102 fn assess_candidate(
103 &mut self,
104 k1: &ArrayIndexType,
105 k2: &ArrayIndexType,
106 i: usize,
107 j: usize,
108 path: &[TensorShapeType],
109 inputs: &[&ArrayIndexType],
110 remaining: &[usize],
111 flops: SizeType,
112 size: SizeType,
113 ) -> Option<BranchBoundCandidate> {
114 let (k12, flops12) = paths::util::calc_k12_flops(inputs, &self.output, remaining, i, j, &self.size_dict);
116
117 let size12 = *self
118 .size_cache
119 .entry(k12.clone())
120 .or_insert_with(|| helpers::compute_size_by_dict(k12.iter(), &self.size_dict));
121
122 let new_flops = flops + flops12;
123 let new_size = size.max(size12);
124
125 if !(self.better_fn)(new_flops, new_size, self.best.flops, self.best.size) {
127 return None;
128 }
129
130 let inputs_len = inputs.len();
131 let best_progress = self.best_progress.entry(inputs_len).or_insert(SizeType::MAX);
132 if new_flops < *best_progress {
133 *best_progress = new_flops;
135 } else if new_flops > self.cutoff_flops_factor * *best_progress {
136 return None;
138 }
139
140 if let Some(limit) = self.memory_limit {
142 if size12 > limit {
143 let oversize_flops =
145 flops + paths::util::compute_oversize_flops(inputs, remaining, &self.output, &self.size_dict);
146 if oversize_flops < self.best.flops {
147 self.best.flops = oversize_flops;
148 let mut new_path = path.to_vec();
149 new_path.push(remaining.to_vec());
150 self.best.ssa_path = Some(new_path);
151 }
152 return None;
153 }
154 }
155
156 let size1 = self.size_cache[k1];
158 let size2 = self.size_cache[k2];
159 let cost = (self.cost_fn)(size12, size1, size2, 0, 0, 0);
160
161 Some(BranchBoundCandidate { cost, flops12, new_flops, new_size, pair: (i, j), k12 })
162 }
163
164 #[allow(clippy::too_many_arguments)]
165 #[allow(clippy::type_complexity)]
166 fn branch_iterate(
167 &mut self,
168 path: &[TensorShapeType],
169 inputs: &[&ArrayIndexType],
170 remaining: Vec<usize>,
171 flops: SizeType,
172 size: SizeType,
173 ) {
174 if remaining.len() == 1 {
176 self.best.flops = flops;
177 self.best.size = size;
178 self.best.ssa_path = Some(path.to_vec());
179 return;
180 }
181
182 let mut candidates = BinaryHeap::new();
184 for (i, j) in remaining.iter().tuple_combinations() {
185 let (i, j) = if i < j { (*i, *j) } else { (*j, *i) };
186 let k1 = &inputs[i];
187 let k2 = &inputs[j];
188
189 if k1.is_disjoint(k2) {
191 continue;
192 }
193
194 if let Some(candidate) = self.assess_candidate(k1, k2, i, j, path, inputs, &remaining, flops, size) {
195 candidates.push(Reverse(candidate));
196 }
197 }
198
199 if candidates.is_empty() {
201 for (i, j) in remaining.iter().tuple_combinations() {
202 let (i, j) = if i < j { (*i, *j) } else { (*j, *i) };
203 let k1 = &inputs[i];
204 let k2 = &inputs[j];
205
206 if let Some(candidate) = self.assess_candidate(k1, k2, i, j, path, inputs, &remaining, flops, size) {
207 candidates.push(Reverse(candidate));
208 }
209 }
210 }
211
212 let mut bi = 0;
214 while (self.nbranch.is_none() || bi < self.nbranch.unwrap()) && !candidates.is_empty() {
215 let Reverse(candidate) = candidates.pop().unwrap();
216 let BranchBoundCandidate { new_flops, new_size, pair: (i, j), k12, .. } = candidate;
217
218 let mut new_remaining = remaining.clone();
219 new_remaining.retain(|&x| x != i && x != j);
220 new_remaining.push(inputs.len());
221
222 let mut new_inputs = inputs.to_vec();
223 new_inputs.push(&k12);
224
225 let mut new_path = path.to_vec();
226 new_path.push(vec![i, j]);
227
228 self.branch_iterate(&new_path, &new_inputs, new_remaining, new_flops, new_size);
229
230 bi += 1;
231 }
232 }
233
234 fn branch_bound(
235 &mut self,
236 inputs: &[&ArrayIndexType],
237 output: &ArrayIndexType,
238 size_dict: &SizeDictType,
239 memory_limit: Option<SizeType>,
240 ) -> Result<PathType, String> {
241 self.best = BranchBoundBest::default();
243 self.best_progress.clear();
244
245 self.size_cache =
247 inputs.iter().map(|&k| (k.clone(), helpers::compute_size_by_dict(k.iter(), size_dict))).collect();
248
249 self.inputs = inputs.iter().map(|s| (*s).clone()).collect();
251 self.output = output.clone();
252 self.size_dict = size_dict.clone();
253 self.memory_limit = memory_limit;
254
255 let inputs_len = inputs.len();
257 self.branch_iterate(&Vec::new(), inputs, (0..inputs_len).collect(), SizeType::zero(), SizeType::zero());
258
259 Ok(self.path())
260 }
261}
262
263impl PathOptimizer for BranchBound {
264 fn optimize_path(
265 &mut self,
266 inputs: &[&ArrayIndexType],
267 output: &ArrayIndexType,
268 size_dict: &SizeDictType,
269 memory_limit: Option<SizeType>,
270 ) -> Result<PathType, String> {
271 self.branch_bound(inputs, output, size_dict, memory_limit)
272 }
273}
274
275impl From<&str> for BranchBound {
276 fn from(s: &str) -> Self {
277 match s.replace(['_', ' '], "-").to_lowercase().as_str() {
278 "branch-all" => BranchBound::default(),
279 "branch-1" => BranchBound { nbranch: Some(1), ..BranchBound::default() },
280 "branch-2" => BranchBound { nbranch: Some(2), ..BranchBound::default() },
281 _ => panic!("Unknown branch bound kind: {s}"),
282 }
283 }
284}
285
286