1use crate::*;
2
3pub type GreedyChooseFn = Box<
4 dyn FnMut(
5 &mut BinaryHeap<GreedyContractionType>,
6 &BTreeMap<ArrayIndexType, usize>,
7 ) -> Option<GreedyContractionType>,
8>;
9
10#[derive(Debug, Clone, PartialEq, PartialOrd)]
14pub struct GreedyCostType {
15 pub cost: SizeType,
21 pub id1: usize,
23 pub id2: usize,
25}
26
27impl Eq for GreedyCostType {}
28
29#[allow(clippy::derive_ord_xor_partial_ord)]
30impl Ord for GreedyCostType {
31 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
32 self.partial_cmp(other).unwrap()
33 }
34}
35
36#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord)]
40pub struct GreedyContractionType {
41 pub cost: Reverse<GreedyCostType>,
43 pub k1: ArrayIndexType,
45 pub k2: ArrayIndexType,
47 pub k12: ArrayIndexType,
49}
50
51fn get_candidate(
53 output: &ArrayIndexType,
54 size_dict: &SizeDictType,
55 remaining: &BTreeMap<ArrayIndexType, usize>,
56 footprints: &BTreeMap<ArrayIndexType, SizeType>,
57 dim_ref_counts: &BTreeMap<usize, BTreeSet<char>>,
58 k1: &ArrayIndexType,
59 k2: &ArrayIndexType,
60 cost_fn: paths::CostFn,
61) -> GreedyContractionType {
62 let either = k1 | k2;
63 let two = k1 & k2;
64 let one = &either - &two;
65
66 let part1 = either.intersection(output);
69 let part2 = two.intersection(&dim_ref_counts[&3]);
71 let part3 = one.intersection(&dim_ref_counts[&2]);
72 let k12: ArrayIndexType = part1.chain(part2).chain(part3).cloned().collect();
73
74 let size12 = helpers::compute_size_by_dict(k12.iter(), size_dict);
75 let footprint1 = footprints[k1];
76 let footprint2 = footprints[k2];
77 let cost = cost_fn(size12, footprint1, footprint2, 0, 0, 0);
78
79 let id1 = remaining[k1];
80 let id2 = remaining[k2];
81 let (k1, id1, k2, id2) =
82 if id1 > id2 { (k1.clone(), id1, k2.clone(), id2) } else { (k2.clone(), id2, k1.clone(), id1) };
83
84 GreedyContractionType { cost: Reverse(GreedyCostType { cost, id1, id2 }), k1, k2, k12 }
85}
86
87fn push_candidate(
89 output: &ArrayIndexType,
90 size_dict: &SizeDictType,
91 remaining: &BTreeMap<ArrayIndexType, usize>,
92 footprints: &BTreeMap<ArrayIndexType, SizeType>,
93 dim_ref_counts: &BTreeMap<usize, BTreeSet<char>>,
94 k1: &ArrayIndexType,
95 k2s: &[ArrayIndexType],
96 queue: &mut BinaryHeap<GreedyContractionType>,
97 push_all: bool,
98 cost_fn: paths::CostFn,
99) {
100 let candidates: Vec<GreedyContractionType> = k2s
101 .iter()
102 .map(|k2| get_candidate(output, size_dict, remaining, footprints, dim_ref_counts, k1, k2, cost_fn))
103 .collect();
104
105 if push_all {
106 candidates.into_iter().for_each(|c| queue.push(c));
107 } else if let Some(max_cand) = candidates.into_iter().max() {
108 queue.push(max_cand);
109 }
110}
111
112fn update_ref_counts(
118 dim_to_keys: &BTreeMap<char, BTreeSet<ArrayIndexType>>,
119 dim_ref_counts: &mut BTreeMap<usize, BTreeSet<char>>,
120 dims: &ArrayIndexType,
121 output: &ArrayIndexType,
122) {
123 for dim in dims {
124 if output.contains(dim) {
125 continue;
126 }
127 let count = dim_to_keys.get(dim).map(|s| s.len()).unwrap_or(0);
128
129 match count {
130 0..=1 => {
131 dim_ref_counts.get_mut(&2).unwrap().remove(dim);
132 dim_ref_counts.get_mut(&3).unwrap().remove(dim);
133 },
134 2 => {
135 dim_ref_counts.get_mut(&2).unwrap().insert(*dim);
136 dim_ref_counts.get_mut(&3).unwrap().remove(dim);
137 },
138 3.. => {
139 dim_ref_counts.get_mut(&2).unwrap().insert(*dim);
140 dim_ref_counts.get_mut(&3).unwrap().insert(*dim);
141 },
142 }
143 }
144}
145
146pub fn simple_chooser(
151 queue: &mut BinaryHeap<GreedyContractionType>,
152 remaining: &BTreeMap<ArrayIndexType, usize>,
153) -> Option<GreedyContractionType> {
154 while let Some(cand) = queue.pop() {
155 if remaining.contains_key(&cand.k1) && remaining.contains_key(&cand.k2) {
156 return Some(cand);
157 }
158 }
159 None
160}
161
162pub fn ssa_greedy_optimize(
166 inputs: &[&ArrayIndexType],
167 output: &ArrayIndexType,
168 size_dict: &SizeDictType,
169 choose_fn: Option<&mut GreedyChooseFn>,
170 cost_fn: Option<paths::CostFn>,
171) -> PathType {
172 if inputs.is_empty() {
173 return vec![];
174 }
175
176 if inputs.len() == 1 {
177 return vec![vec![0]];
179 }
180
181 let push_all = choose_fn.is_none();
183 let mut default_chooser: GreedyChooseFn = Box::new(simple_chooser);
184 let choose_fn: &mut GreedyChooseFn = if let Some(choose_fn) = choose_fn { choose_fn } else { &mut default_chooser };
185
186 let cost_fn = cost_fn.unwrap_or(paths::util::memory_removed(false));
188
189 let common_dims = inputs.iter().skip(1).fold(inputs[0].clone(), |acc, s| &acc & s);
194 let output = output | &common_dims;
195
196 let mut remaining = BTreeMap::new(); let mut ssa_ids = inputs.len();
199 let mut ssa_path = Vec::new();
200
201 for (ssa_id, &key) in inputs.iter().enumerate() {
202 let key = key.clone();
203 if let Some(&existing_id) = remaining.get(&key) {
204 ssa_path.push(vec![existing_id, ssa_id]);
205 remaining.insert(key, ssa_ids);
206 ssa_ids += 1;
207 } else {
208 remaining.insert(key, ssa_id);
209 }
210 }
211
212 let mut dim_to_keys: BTreeMap<char, BTreeSet<ArrayIndexType>> = BTreeMap::new();
214 for key in remaining.keys() {
215 for dim in key - &output {
216 dim_to_keys.entry(dim).or_default().insert(key.clone());
217 }
218 }
219
220 let mut dim_ref_counts = BTreeMap::from([(2, BTreeSet::new()), (3, BTreeSet::new())]);
223 for (&dim, keys) in &dim_to_keys {
224 if keys.len() >= 2 {
225 dim_ref_counts.get_mut(&2).unwrap().insert(dim);
226 }
227 if keys.len() >= 3 {
228 dim_ref_counts.get_mut(&3).unwrap().insert(dim);
229 }
230 }
231 output.iter().for_each(|dim| {
232 dim_ref_counts.get_mut(&2).unwrap().remove(dim);
233 dim_ref_counts.get_mut(&3).unwrap().remove(dim);
234 });
235
236 let mut footprints: BTreeMap<ArrayIndexType, SizeType> =
238 remaining.keys().map(|k| (k.clone(), helpers::compute_size_by_dict(k.iter(), size_dict))).collect();
239
240 let mut queue = BinaryHeap::new();
242 for dim_keys in dim_to_keys.values() {
243 let mut dim_keys_list = dim_keys.iter().cloned().collect_vec();
244 dim_keys_list.sort_by_key(|k| remaining[k]);
245 for i in 0..dim_keys_list.len().saturating_sub(1) {
246 let k1 = &dim_keys_list[i];
247 let k2s_guess = &dim_keys_list[i + 1..];
248 push_candidate(
249 &output,
250 size_dict,
251 &remaining,
252 &footprints,
253 &dim_ref_counts,
254 k1,
255 k2s_guess,
256 &mut queue,
257 push_all,
258 cost_fn,
259 );
260 }
261 }
262
263 while !queue.is_empty() {
265 let Some(con) = choose_fn(&mut queue, &remaining) else {
266 continue; };
268 let GreedyContractionType { k1, k2, k12, .. } = con;
269
270 let ssa_id1 = remaining.remove(&k1).unwrap();
271 let ssa_id2 = remaining.remove(&k2).unwrap();
272
273 for dim in &k1 - &output {
274 dim_to_keys.get_mut(&dim).unwrap().remove(&k1);
275 }
276 for dim in &k2 - &output {
277 dim_to_keys.get_mut(&dim).unwrap().remove(&k2);
278 }
279
280 ssa_path.push(vec![ssa_id1, ssa_id2]);
281
282 if remaining.contains_key(&k12) {
283 ssa_path.push(vec![remaining[&k12], ssa_ids]);
284 ssa_ids += 1;
285 } else {
286 for dim in &k12 - &output {
287 dim_to_keys.get_mut(&dim).unwrap().insert(k12.clone());
288 }
289 }
290 remaining.insert(k12.clone(), ssa_ids);
291 ssa_ids += 1;
292
293 let updated_dims = &(&k1 | &k2) - &output;
294 update_ref_counts(&dim_to_keys, &mut dim_ref_counts, &updated_dims, &output);
295
296 footprints.insert(k12.clone(), helpers::compute_size_by_dict(k12.iter(), size_dict));
297
298 let k1 = k12;
300 let k2s: BTreeSet<ArrayIndexType> =
301 (&k1 - &output).into_iter().flat_map(|dim| dim_to_keys[&dim].clone()).filter(|k| k != &k1).collect();
302
303 if !k2s.is_empty() {
304 push_candidate(
305 &output,
306 size_dict,
307 &remaining,
308 &footprints,
309 &dim_ref_counts,
310 &k1,
311 &k2s.into_iter().collect_vec(),
312 &mut queue,
313 push_all,
314 cost_fn,
315 );
316 }
317 }
318
319 #[derive(Clone, Debug, PartialEq, PartialOrd)]
321 struct FinalEntry {
322 size: SizeType,
323 ssa_id: usize,
324 key: ArrayIndexType,
325 }
326 impl Eq for FinalEntry {}
327 #[allow(clippy::derive_ord_xor_partial_ord)]
328 impl Ord for FinalEntry {
329 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
330 self.partial_cmp(other).unwrap()
331 }
332 }
333
334 let mut final_queue: BinaryHeap<Reverse<FinalEntry>> = remaining
336 .into_iter()
337 .map(|(key, ssa_id)| {
338 let size = helpers::compute_size_by_dict((&key & &output).iter(), size_dict);
339 Reverse(FinalEntry { size, ssa_id, key })
340 })
341 .collect();
342
343 let Some(Reverse(FinalEntry { ssa_id: ssa_id1, key: k1, .. })) = final_queue.pop() else {
344 return ssa_path;
345 };
346
347 let mut current_id = ssa_id1;
348 let mut current_k = k1;
349
350 while let Some(Reverse(FinalEntry { ssa_id: ssa_id2, key: k2, .. })) = final_queue.pop() {
351 ssa_path.push(vec![current_id.min(ssa_id2), current_id.max(ssa_id2)]);
352 let k12: ArrayIndexType = &(¤t_k | &k2) & &output;
353 let cost = helpers::compute_size_by_dict(k12.iter(), size_dict);
354 let new_ssa_id = ssa_ids;
355 ssa_ids += 1;
356
357 final_queue.push(Reverse(FinalEntry { size: cost, ssa_id: new_ssa_id, key: k12.clone() }));
358 let Reverse(FinalEntry { ssa_id: new_id, key: new_k, .. }) = final_queue.pop().unwrap();
359 current_id = new_id;
360 current_k = new_k;
361 }
362
363 ssa_path
364}
365
366pub fn greedy(
388 inputs: &[&ArrayIndexType],
389 output: &ArrayIndexType,
390 size_dict: &SizeDictType,
391 memory_limit: Option<SizeType>,
392 choose_fn: Option<&mut GreedyChooseFn>,
393 cost_fn: Option<paths::CostFn>,
394) -> Result<PathType, String> {
395 if memory_limit.is_some() {
396 let mut branch_optimizer = paths::branch_bound::BranchBound::from("branch-1");
397 return branch_optimizer.optimize_path(inputs, output, size_dict, memory_limit);
398 }
399
400 let ssa_path = ssa_greedy_optimize(inputs, output, size_dict, choose_fn, cost_fn);
401 Ok(paths::util::ssa_to_linear(&ssa_path))
402}
403
404#[derive(Default)]
405pub struct Greedy {
406 cost_fn: Option<paths::CostFn>,
407 choose_fn: Option<GreedyChooseFn>,
408}
409
410impl std::fmt::Debug for Greedy {
411 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
412 f.debug_struct("Greedy").field("cost_fn", &self.cost_fn).field("choose_fn", &self.choose_fn.is_some()).finish()
413 }
414}
415
416impl PathOptimizer for Greedy {
417 fn optimize_path(
418 &mut self,
419 inputs: &[&ArrayIndexType],
420 output: &ArrayIndexType,
421 size_dict: &SizeDictType,
422 memory_limit: Option<SizeType>,
423 ) -> Result<PathType, String> {
424 greedy(inputs, output, size_dict, memory_limit, self.choose_fn.as_mut(), self.cost_fn)
425 }
426}