opt_einsum_path/paths/greedy.rs

1use crate::*;
2
3pub type GreedyChooseFn = Box<
4    dyn FnMut(
5        &mut BinaryHeap<GreedyContractionType>,
6        &BTreeMap<ArrayIndexType, usize>,
7    ) -> Option<GreedyContractionType>,
8>;
9
10/// Type representing the cost of a greedy contraction.
11///
12/// Please note that order of cost is not reversed (greater is better).
13#[derive(Debug, Clone, PartialEq, PartialOrd)]
14pub struct GreedyCostType {
15    /// The cost of the contraction.
16    ///
17    /// Cost is defined as the size of the resulting array after the contraction,
18    /// minus the sizes of the two input arrays being contracted:
19    /// `size(final) - size(input1) - size(input2)`.
20    pub cost: SizeType,
21    /// The ID of the first input array being contracted.
22    pub id1: usize,
23    /// The ID of the second input array being contracted.
24    pub id2: usize,
25}
26
27impl Eq for GreedyCostType {}
28
29#[allow(clippy::derive_ord_xor_partial_ord)]
30impl Ord for GreedyCostType {
31    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
32        self.partial_cmp(other).unwrap()
33    }
34}
35
36/// Type representing a greedy contraction candidate.
37///
38/// Order of cost is reversed (less is better).
39#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord)]
40pub struct GreedyContractionType {
41    /// The cost of the contraction, wrapped in `Reverse` to reverse the order.
42    pub cost: Reverse<GreedyCostType>,
43    /// The first input array indices being contracted.
44    pub k1: ArrayIndexType,
45    /// The second input array indices being contracted.
46    pub k2: ArrayIndexType,
47    /// The resulting array indices after the contraction.
48    pub k12: ArrayIndexType,
49}
50
51/// Given k1 and k2 tensors, compute the resulting indices k12 and the cost of the contraction.
52fn get_candidate(
53    output: &ArrayIndexType,
54    size_dict: &SizeDictType,
55    remaining: &BTreeMap<ArrayIndexType, usize>,
56    footprints: &BTreeMap<ArrayIndexType, SizeType>,
57    dim_ref_counts: &BTreeMap<usize, BTreeSet<char>>,
58    k1: &ArrayIndexType,
59    k2: &ArrayIndexType,
60    cost_fn: paths::CostFn,
61) -> GreedyContractionType {
62    let either = k1 | k2;
63    let two = k1 & k2;
64    let one = &either - &two;
65
66    // k12 = (either & output) | (two & dim_ref_counts[3]) | (one & dim_ref_counts[2])
67    // indices in output must kept
68    let part1 = either.intersection(output);
69    // remaining indices kept if referenced by other tensors
70    let part2 = two.intersection(&dim_ref_counts[&3]);
71    let part3 = one.intersection(&dim_ref_counts[&2]);
72    let k12: ArrayIndexType = part1.chain(part2).chain(part3).cloned().collect();
73
74    let size12 = helpers::compute_size_by_dict(k12.iter(), size_dict);
75    let footprint1 = footprints[k1];
76    let footprint2 = footprints[k2];
77    let cost = cost_fn(size12, footprint1, footprint2, 0, 0, 0);
78
79    let id1 = remaining[k1];
80    let id2 = remaining[k2];
81    let (k1, id1, k2, id2) =
82        if id1 > id2 { (k1.clone(), id1, k2.clone(), id2) } else { (k2.clone(), id2, k1.clone(), id1) };
83
84    GreedyContractionType { cost: Reverse(GreedyCostType { cost, id1, id2 }), k1, k2, k12 }
85}
86
87/// Given k1 and its candidate k2s, push the best candidates to the queue.
88fn push_candidate(
89    output: &ArrayIndexType,
90    size_dict: &SizeDictType,
91    remaining: &BTreeMap<ArrayIndexType, usize>,
92    footprints: &BTreeMap<ArrayIndexType, SizeType>,
93    dim_ref_counts: &BTreeMap<usize, BTreeSet<char>>,
94    k1: &ArrayIndexType,
95    k2s: &[ArrayIndexType],
96    queue: &mut BinaryHeap<GreedyContractionType>,
97    push_all: bool,
98    cost_fn: paths::CostFn,
99) {
100    let candidates: Vec<GreedyContractionType> = k2s
101        .iter()
102        .map(|k2| get_candidate(output, size_dict, remaining, footprints, dim_ref_counts, k1, k2, cost_fn))
103        .collect();
104
105    if push_all {
106        candidates.into_iter().for_each(|c| queue.push(c));
107    } else if let Some(max_cand) = candidates.into_iter().max() {
108        queue.push(max_cand);
109    }
110}
111
112/// Update the reference counts for dimensions in `dims` based on their presence in `dim_to_keys`.
113///
114/// Note on `dim_ref_counts`: This is a mapping of
115/// - 0, 1, 2: the indices that appear in exactly that many remaining tensors (excluding output)
116/// - 3: the indices that appear in 3 or more remaining tensors (excluding output)
117fn update_ref_counts(
118    dim_to_keys: &BTreeMap<char, BTreeSet<ArrayIndexType>>,
119    dim_ref_counts: &mut BTreeMap<usize, BTreeSet<char>>,
120    dims: &ArrayIndexType,
121    output: &ArrayIndexType,
122) {
123    for dim in dims {
124        if output.contains(dim) {
125            continue;
126        }
127        let count = dim_to_keys.get(dim).map(|s| s.len()).unwrap_or(0);
128
129        match count {
130            0..=1 => {
131                dim_ref_counts.get_mut(&2).unwrap().remove(dim);
132                dim_ref_counts.get_mut(&3).unwrap().remove(dim);
133            },
134            2 => {
135                dim_ref_counts.get_mut(&2).unwrap().insert(*dim);
136                dim_ref_counts.get_mut(&3).unwrap().remove(dim);
137            },
138            3.. => {
139                dim_ref_counts.get_mut(&2).unwrap().insert(*dim);
140                dim_ref_counts.get_mut(&3).unwrap().insert(*dim);
141            },
142        }
143    }
144}
145
146/// Default contraction chooser that simply takes the minimum cost option.
147///
148/// This function will pop candidates only when they are valid (both k1 and k2 must be present in
149/// `remaining`).
150pub fn simple_chooser(
151    queue: &mut BinaryHeap<GreedyContractionType>,
152    remaining: &BTreeMap<ArrayIndexType, usize>,
153) -> Option<GreedyContractionType> {
154    while let Some(cand) = queue.pop() {
155        if remaining.contains_key(&cand.k1) && remaining.contains_key(&cand.k2) {
156            return Some(cand);
157        }
158    }
159    None
160}
161
162/// This is the core function for [`greedy`] but produces a path with static single assignment
163/// ids rather than recycled linear ids. SSA ids are cheaper to work with and easier to reason
164/// about.
165pub fn ssa_greedy_optimize(
166    inputs: &[&ArrayIndexType],
167    output: &ArrayIndexType,
168    size_dict: &SizeDictType,
169    choose_fn: Option<&mut GreedyChooseFn>,
170    cost_fn: Option<paths::CostFn>,
171) -> PathType {
172    if inputs.is_empty() {
173        return vec![];
174    }
175
176    if inputs.len() == 1 {
177        // Perform a single contraction to match output shape.
178        return vec![vec![0]];
179    }
180
181    // set the function that chooses which contraction to take
182    let push_all = choose_fn.is_none();
183    let mut default_chooser: GreedyChooseFn = Box::new(simple_chooser);
184    let choose_fn: &mut GreedyChooseFn = if let Some(choose_fn) = choose_fn { choose_fn } else { &mut default_chooser };
185
186    // set the function that assigns a heuristic cost to a possible contraction
187    let cost_fn = cost_fn.unwrap_or(paths::util::memory_removed(false));
188
189    // A dim that is common to all tensors might as well be an output dim, since it cannot be contracted
190    // until the final step. This avoids an expensive all-pairs comparison to search for possible
191    // contractions at each step, leading to speedup in many practical problems where all tensors share
192    // a common batch dimension.
193    let common_dims = inputs.iter().skip(1).fold(inputs[0].clone(), |acc, s| &acc & s);
194    let output = output | &common_dims;
195
196    // Deduplicate shapes by eagerly computing Hadamard products.
197    let mut remaining = BTreeMap::new(); // key -> ssa_id
198    let mut ssa_ids = inputs.len();
199    let mut ssa_path = Vec::new();
200
201    for (ssa_id, &key) in inputs.iter().enumerate() {
202        let key = key.clone();
203        if let Some(&existing_id) = remaining.get(&key) {
204            ssa_path.push(vec![existing_id, ssa_id]);
205            remaining.insert(key, ssa_ids);
206            ssa_ids += 1;
207        } else {
208            remaining.insert(key, ssa_id);
209        }
210    }
211
212    // Keep track of possible contraction dims.
213    let mut dim_to_keys: BTreeMap<char, BTreeSet<ArrayIndexType>> = BTreeMap::new();
214    for key in remaining.keys() {
215        for dim in key - &output {
216            dim_to_keys.entry(dim).or_default().insert(key.clone());
217        }
218    }
219
220    // Keep track of the number of tensors using each dim; when the dim is no longer used it can be
221    // contracted. Since we specialize to binary ops, we only care about ref counts of >=2 or >=3.
222    let mut dim_ref_counts = BTreeMap::from([(2, BTreeSet::new()), (3, BTreeSet::new())]);
223    for (&dim, keys) in &dim_to_keys {
224        if keys.len() >= 2 {
225            dim_ref_counts.get_mut(&2).unwrap().insert(dim);
226        }
227        if keys.len() >= 3 {
228            dim_ref_counts.get_mut(&3).unwrap().insert(dim);
229        }
230    }
231    output.iter().for_each(|dim| {
232        dim_ref_counts.get_mut(&2).unwrap().remove(dim);
233        dim_ref_counts.get_mut(&3).unwrap().remove(dim);
234    });
235
236    // Compute separable part of the objective function for contractions.
237    let mut footprints: BTreeMap<ArrayIndexType, SizeType> =
238        remaining.keys().map(|k| (k.clone(), helpers::compute_size_by_dict(k.iter(), size_dict))).collect();
239
240    // Find initial candidate contractions.
241    let mut queue = BinaryHeap::new();
242    for dim_keys in dim_to_keys.values() {
243        let mut dim_keys_list = dim_keys.iter().cloned().collect_vec();
244        dim_keys_list.sort_by_key(|k| remaining[k]);
245        for i in 0..dim_keys_list.len().saturating_sub(1) {
246            let k1 = &dim_keys_list[i];
247            let k2s_guess = &dim_keys_list[i + 1..];
248            push_candidate(
249                &output,
250                size_dict,
251                &remaining,
252                &footprints,
253                &dim_ref_counts,
254                k1,
255                k2s_guess,
256                &mut queue,
257                push_all,
258                cost_fn,
259            );
260        }
261    }
262
263    // Greedily contract pairs of tensors.
264    while !queue.is_empty() {
265        let Some(con) = choose_fn(&mut queue, &remaining) else {
266            continue; // allow choose_fn to flag all candidates obsolete
267        };
268        let GreedyContractionType { k1, k2, k12, .. } = con;
269
270        let ssa_id1 = remaining.remove(&k1).unwrap();
271        let ssa_id2 = remaining.remove(&k2).unwrap();
272
273        for dim in &k1 - &output {
274            dim_to_keys.get_mut(&dim).unwrap().remove(&k1);
275        }
276        for dim in &k2 - &output {
277            dim_to_keys.get_mut(&dim).unwrap().remove(&k2);
278        }
279
280        ssa_path.push(vec![ssa_id1, ssa_id2]);
281
282        if remaining.contains_key(&k12) {
283            ssa_path.push(vec![remaining[&k12], ssa_ids]);
284            ssa_ids += 1;
285        } else {
286            for dim in &k12 - &output {
287                dim_to_keys.get_mut(&dim).unwrap().insert(k12.clone());
288            }
289        }
290        remaining.insert(k12.clone(), ssa_ids);
291        ssa_ids += 1;
292
293        let updated_dims = &(&k1 | &k2) - &output;
294        update_ref_counts(&dim_to_keys, &mut dim_ref_counts, &updated_dims, &output);
295
296        footprints.insert(k12.clone(), helpers::compute_size_by_dict(k12.iter(), size_dict));
297
298        // Find new candidate contractions.
299        let k1 = k12;
300        let k2s: BTreeSet<ArrayIndexType> =
301            (&k1 - &output).into_iter().flat_map(|dim| dim_to_keys[&dim].clone()).filter(|k| k != &k1).collect();
302
303        if !k2s.is_empty() {
304            push_candidate(
305                &output,
306                size_dict,
307                &remaining,
308                &footprints,
309                &dim_ref_counts,
310                &k1,
311                &k2s.into_iter().collect_vec(),
312                &mut queue,
313                push_all,
314                cost_fn,
315            );
316        }
317    }
318
319    // Greedily compute pairwise outer products.
320    #[derive(Clone, Debug, PartialEq, PartialOrd)]
321    struct FinalEntry {
322        size: SizeType,
323        ssa_id: usize,
324        key: ArrayIndexType,
325    }
326    impl Eq for FinalEntry {}
327    #[allow(clippy::derive_ord_xor_partial_ord)]
328    impl Ord for FinalEntry {
329        fn cmp(&self, other: &Self) -> std::cmp::Ordering {
330            self.partial_cmp(other).unwrap()
331        }
332    }
333
334    // Greedily compute pairwise outer products.
335    let mut final_queue: BinaryHeap<Reverse<FinalEntry>> = remaining
336        .into_iter()
337        .map(|(key, ssa_id)| {
338            let size = helpers::compute_size_by_dict((&key & &output).iter(), size_dict);
339            Reverse(FinalEntry { size, ssa_id, key })
340        })
341        .collect();
342
343    let Some(Reverse(FinalEntry { ssa_id: ssa_id1, key: k1, .. })) = final_queue.pop() else {
344        return ssa_path;
345    };
346
347    let mut current_id = ssa_id1;
348    let mut current_k = k1;
349
350    while let Some(Reverse(FinalEntry { ssa_id: ssa_id2, key: k2, .. })) = final_queue.pop() {
351        ssa_path.push(vec![current_id.min(ssa_id2), current_id.max(ssa_id2)]);
352        let k12: ArrayIndexType = &(&current_k | &k2) & &output;
353        let cost = helpers::compute_size_by_dict(k12.iter(), size_dict);
354        let new_ssa_id = ssa_ids;
355        ssa_ids += 1;
356
357        final_queue.push(Reverse(FinalEntry { size: cost, ssa_id: new_ssa_id, key: k12.clone() }));
358        let Reverse(FinalEntry { ssa_id: new_id, key: new_k, .. }) = final_queue.pop().unwrap();
359        current_id = new_id;
360        current_k = new_k;
361    }
362
363    ssa_path
364}
365
366/// Finds the path by a three stage greedy algorithm.
367///
368/// 1. Eagerly compute Hadamard products.
369/// 2. Greedily compute contractions to maximize `removed_size`.
370/// 3. Greedily compute outer products.
371///
372/// This algorithm scales quadratically with respect to the maximum number of elements sharing a
373/// common dim.
374///
375/// # Parameters
376///
377/// - **inputs** - List of sets that represent the lhs side of the einsum subscript
378/// - **output** - Set that represents the rhs side of the overall einsum subscript
379/// - **size_dict** - Dictionary of index sizes
380/// - **memory_limit** - The maximum number of elements in a temporary array
381/// - **choose_fn** - A function that chooses which contraction to perform from the queue
382/// - **cost_fn** - A function that assigns a potential contraction a cost.
383///
384/// # Returns
385///
386/// - **path** - The contraction order (a list of tuples of ints).
387pub fn greedy(
388    inputs: &[&ArrayIndexType],
389    output: &ArrayIndexType,
390    size_dict: &SizeDictType,
391    memory_limit: Option<SizeType>,
392    choose_fn: Option<&mut GreedyChooseFn>,
393    cost_fn: Option<paths::CostFn>,
394) -> Result<PathType, String> {
395    if memory_limit.is_some() {
396        let mut branch_optimizer = paths::branch_bound::BranchBound::from("branch-1");
397        return branch_optimizer.optimize_path(inputs, output, size_dict, memory_limit);
398    }
399
400    let ssa_path = ssa_greedy_optimize(inputs, output, size_dict, choose_fn, cost_fn);
401    Ok(paths::util::ssa_to_linear(&ssa_path))
402}
403
404#[derive(Default)]
405pub struct Greedy {
406    cost_fn: Option<paths::CostFn>,
407    choose_fn: Option<GreedyChooseFn>,
408}
409
410impl std::fmt::Debug for Greedy {
411    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
412        f.debug_struct("Greedy").field("cost_fn", &self.cost_fn).field("choose_fn", &self.choose_fn.is_some()).finish()
413    }
414}
415
416impl PathOptimizer for Greedy {
417    fn optimize_path(
418        &mut self,
419        inputs: &[&ArrayIndexType],
420        output: &ArrayIndexType,
421        size_dict: &SizeDictType,
422        memory_limit: Option<SizeType>,
423    ) -> Result<PathType, String> {
424        greedy(inputs, output, size_dict, memory_limit, self.choose_fn.as_mut(), self.cost_fn)
425    }
426}