opt_einsum_path/paths/
dp.rs

1// src/paths/dp.rs
2use crate::*;
3use std::collections::VecDeque;
4
/// Nested description of the order in which tensors are contracted.
///
/// A [`ContractionTree::Leaf`] carries the index of a single tensor, while a
/// [`ContractionTree::Node`] groups the subtrees that are contracted together.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ContractionTree {
    /// A single tensor, identified by its index.
    Leaf(usize),
    /// A group of subtrees that are contracted with each other.
    Node(Vec<ContractionTree>),
}

impl From<usize> for ContractionTree {
    /// Wraps a tensor index as a leaf.
    fn from(index: usize) -> Self {
        Self::Leaf(index)
    }
}

impl From<Vec<ContractionTree>> for ContractionTree {
    /// Groups already-built subtrees into one node.
    fn from(subtrees: Vec<ContractionTree>) -> Self {
        Self::Node(subtrees)
    }
}

impl From<Vec<usize>> for ContractionTree {
    /// Builds a node whose children are leaves, one per index and in the given order.
    fn from(indices: Vec<usize>) -> Self {
        let leaves: Vec<ContractionTree> = indices.into_iter().map(Self::Leaf).collect();
        Self::Node(leaves)
    }
}
29
30/// Converts a contraction tree to a contraction path.
31///
32/// A contraction tree can either be a leaf node containing an integer (representingno contraction)
33/// or a node containing a sequence of subtrees to be contracted.Contractions are commutative
34/// (order-independent) and solutions are not unique.
35///
36/// # Parameters
37///
38/// * `tree` - The contraction tree to convert, represented as a `ContractionTree` enum where leaves
39///   are integers and nodes contain sequences of subtrees.
40///
41/// # Returns
42///
43/// A [`PathType`] (`Vec<Vec<usize>>`) representing the contraction path, where each inner
44/// `Vec<usize>` represents a single contraction step with the indices of tensors to contract.
45///
46/// The conversion process works by:
47/// 1. Processing leaf nodes (integers) first to determine their positions
48/// 2. Building the contraction sequence by tracking elementary tensors and remaining contractions
49/// 3. Maintaining proper index accounting throughout the conversion
50///
51/// Note: This implementation matches the behavior of Python's `_tree_to_sequence` function
52/// from opt_einsum, producing equivalent output for equivalent input trees.
53pub fn tree_to_sequence(tree: &ContractionTree) -> PathType {
54    // Handle leaf case (equivalent to Python's int case)
55    if let ContractionTree::Leaf(_) = tree {
56        return Vec::new();
57    }
58
59    let mut c: VecDeque<&ContractionTree> = VecDeque::new(); // list of remaining contractions
60    c.push_back(tree);
61
62    let mut t: Vec<usize> = Vec::new(); // list of elementary tensors
63    let mut s: VecDeque<Vec<usize>> = VecDeque::new(); // resulting contraction sequence
64
65    while !c.is_empty() {
66        let j = c.pop_back().unwrap();
67        s.push_front(Vec::new());
68
69        // First process the integer leaves
70        if let ContractionTree::Node(children) = j {
71            // Collect integer leaves first
72            let mut int_children: Vec<usize> = children
73                .iter()
74                .filter_map(|child| match child {
75                    ContractionTree::Leaf(i) => Some(*i),
76                    _ => None,
77                })
78                .collect();
79
80            // Sort them as in Python
81            int_children.sort_unstable();
82
83            for i in int_children {
84                // Calculate the position as in Python: sum(1 for q in t if q < i)
85                let pos = t.iter().filter(|&&q| q < i).count();
86                s[0].push(pos);
87                t.insert(pos, i);
88            }
89
90            // Then process the non-integer children (other nodes)
91            for i_tup in children.iter().filter(|child| matches!(child, ContractionTree::Node(_))) {
92                let pos = t.len() + c.len();
93                s[0].push(pos);
94                c.push_back(i_tup);
95            }
96        }
97    }
98
99    s.into_iter().collect()
100}
101
102/// Finds disconnected subgraphs in a list of input tensor dimensions.
103///
104/// Input tensors are considered connected if they share summation indices (indices not
105/// present in the output). Disconnected subgraphs can be contracted independently
106/// before forming outer products, which is useful for optimization.
107///
108/// # Parameters
109///
110/// * `inputs` - Slice of sets representing input tensor dimensions (lhs of einsum)
111/// * `output` - Set representing output tensor dimensions (rhs of einsum)
112///
113/// # Returns
114///
115/// Vector of sets where each set contains indices of connected input tensors.
116///
117/// # Note
118///
119/// - Summation indices are determined as `(union of all inputs) \ output`
120/// - The order of returned subgraphs is implementation-defined
121/// - Within each subgraph, the order of tensor indices is sorted
122pub fn find_disconnected_subgraphs(inputs: &[ArrayIndexType], output: &ArrayIndexType) -> Vec<BTreeSet<usize>> {
123    let mut subgraphs = Vec::new();
124    let mut unused_inputs: BTreeSet<usize> = (0..inputs.len()).collect();
125
126    // Calculate all summation indices (union of all inputs minus output)
127    let input_indices: ArrayIndexType = inputs.iter().flat_map(|set| set.iter()).cloned().collect();
128    let i_sum = &input_indices - output;
129
130    while !unused_inputs.is_empty() {
131        let mut g = BTreeSet::new();
132        let mut queue = VecDeque::new();
133
134        // Start with any remaining input
135        queue.push_back(*unused_inputs.iter().next().unwrap());
136        unused_inputs.remove(&queue[0]);
137
138        while !queue.is_empty() {
139            let j = queue.pop_front().unwrap();
140            g.insert(j);
141
142            // Get summation indices for current input
143            let i_tmp: ArrayIndexType = &i_sum & &inputs[j];
144
145            // Find connected inputs
146            let neighbors = unused_inputs.iter().filter(|&&k| !inputs[k].is_disjoint(&i_tmp)).cloned().collect_vec();
147
148            for neighbor in neighbors {
149                queue.push_back(neighbor);
150                unused_inputs.remove(&neighbor);
151            }
152        }
153        subgraphs.push(g);
154    }
155    subgraphs
156}
157
158/// Select elements of `seq` which are marked by the bitmap `s`.
159///
160/// # Parameters
161///
162/// * `s` - Bitmap where each bit represents whether to select the corresponding element
163/// * `seq` - Sequence of items to select from
164///
165/// # Returns
166///
167/// An iterator yielding selected elements from `seq` where the corresponding bit in `s` is set.
168pub fn bitmap_select<'t, T>(s: &'t BigUint, seq: &'t [T]) -> impl Iterator<Item = &'t T> {
169    let uint_1 = BigUint::from_u32(1).unwrap();
170    seq.iter().enumerate().filter(move |(i, _)| (s >> i) & &uint_1 == uint_1).map(move |(_, x)| x)
171}
172
173// Calculates the effective outer indices of the intermediate tensor
174/// corresponding to the subgraph `s`.
175///
176/// # Parameters
177///
178/// * `g` - Bitmap representing all tensors in the current graph
179/// * `all_tensors` - Bitmap representing all possible tensors
180/// * `s` - Bitmap representing the subgraph to calculate legs for
181/// * `inputs` - Slice of input tensor dimension sets
182/// * `i1_cut_i2_wo_output` - Precomputed intersection of indices (i1 ∩ i2) \ output
183/// * `i1_union_i2` - Precomputed union of indices (i1 ∪ i2)
184///
185/// # Returns
186///
187/// The effective outer indices of the intermediate tensor
188pub fn dp_calc_legs(
189    g: &BigUint,
190    all_tensors: &BigUint,
191    s: &BigUint,
192    inputs: &[&ArrayIndexType],
193    i1_cut_i2_wo_output: &ArrayIndexType,
194    i1_union_i2: &ArrayIndexType,
195) -> ArrayIndexType {
196    // set of remaining tensors (= g & (!s))
197    let r = g & (all_tensors ^ s);
198
199    // indices of remaining indices:
200    let i_r = if r != BigUint::ZERO {
201        bitmap_select(&r, inputs).flat_map(|x| x.iter()).collect_vec().into_iter().copied().collect()
202    } else {
203        ArrayIndexType::new()
204    };
205
206    // contraction indices:
207    let i_contract = i1_cut_i2_wo_output - &i_r;
208    i1_union_i2 - &i_contract
209}
210
211#[derive(Debug, Clone)]
212pub struct DpTerm {
213    pub indices: ArrayIndexType,
214    pub cost: SizeType,
215    pub contract: ContractionTree,
216}
217
218pub struct DpCompareArgs<'a> {
219    // parameters
220    pub minimize: &'a str,
221    pub combo_factor: SizeType,
222    // inputs
223    pub inputs: &'a [&'a ArrayIndexType],
224    pub size_dict: &'a SizeDictType,
225    pub all_tensors: BigUint,
226    pub memory_limit: Option<SizeType>,
227    pub cost_cap: SizeType,
228    pub bitmap_g: BigUint,
229}
230
231impl<'a> DpCompareArgs<'a> {
232    /// Performs the inner comparison of whether the two subgraphs (the bitmaps `s1` and `s2`)
233    /// should be merged and added to the dynamic programming search. Will skip for a number of
234    /// reasons:
235    /// 1. If the number of operations to form `s = s1 | s2` including previous contractions is
236    ///    above the cost-cap.
237    /// 2. If we've already found a better way of making `s`.
238    /// 3. If the intermediate tensor corresponding to `s` is going to break the memory limit.
239    pub fn compare_flops(
240        &self,
241        xn: &mut BTreeMap<BigUint, DpTerm>,
242        s1: &BigUint,
243        s2: &BigUint,
244        term1: &DpTerm,
245        term2: &DpTerm,
246        i1_cut_i2_wo_output: &ArrayIndexType,
247    ) {
248        let DpTerm { indices: i1, cost: cost1, contract: contract1 } = term1;
249        let DpTerm { indices: i2, cost: cost2, contract: contract2 } = term2;
250        let i1_union_i2 = i1 | i2;
251
252        let cost = cost1 + cost2 + helpers::compute_size_by_dict(i1_union_i2.iter(), self.size_dict);
253        if cost <= self.cost_cap {
254            let s = s1 | s2;
255            if xn.get(&s).is_none_or(|term| cost < term.cost) {
256                let indices =
257                    dp_calc_legs(&self.bitmap_g, &self.all_tensors, &s, self.inputs, i1_cut_i2_wo_output, &i1_union_i2);
258                let mem = helpers::compute_size_by_dict(indices.iter(), self.size_dict);
259                if self.memory_limit.is_none_or(|limit| mem <= limit) {
260                    let contract = vec![contract1.clone(), contract2.clone()].into();
261                    xn.insert(s, DpTerm { indices, cost, contract });
262                }
263            }
264        }
265    }
266
267    /// Like `compare_flops` but sieves the potential contraction based on the size of the
268    /// intermediate tensor created, rather than the number of operations, and so calculates that
269    /// first.
270    pub fn compare_size(
271        &self,
272        xn: &mut BTreeMap<BigUint, DpTerm>,
273        s1: &BigUint,
274        s2: &BigUint,
275        term1: &DpTerm,
276        term2: &DpTerm,
277        i1_cut_i2_wo_output: &ArrayIndexType,
278    ) {
279        let DpTerm { indices: i1, cost: cost1, contract: contract1 } = term1;
280        let DpTerm { indices: i2, cost: cost2, contract: contract2 } = term2;
281        let i1_union_i2 = i1 | i2;
282        let s = s1 | s2;
283        let indices =
284            dp_calc_legs(&self.bitmap_g, &self.all_tensors, &s, self.inputs, i1_cut_i2_wo_output, &i1_union_i2);
285
286        let mem = helpers::compute_size_by_dict(indices.iter(), self.size_dict);
287        let cost = (*cost1).max(*cost2).max(mem);
288        if cost <= self.cost_cap
289            && xn.get(&s).is_none_or(|term| cost < term.cost)
290            && self.memory_limit.is_none_or(|limit| mem <= limit)
291        {
292            let contract = vec![contract1.clone(), contract2.clone()].into();
293            xn.insert(s, DpTerm { indices, cost, contract });
294        }
295    }
296    /// Like `compare_flops` but sieves the potential contraction based on the total size of memory
297    /// created, rather than the number of operations, and so calculates that first.
298    pub fn compare_write(
299        &self,
300        xn: &mut BTreeMap<BigUint, DpTerm>,
301        s1: &BigUint,
302        s2: &BigUint,
303        term1: &DpTerm,
304        term2: &DpTerm,
305        i1_cut_i2_wo_output: &ArrayIndexType,
306    ) {
307        let DpTerm { indices: i1, cost: cost1, contract: contract1 } = term1;
308        let DpTerm { indices: i2, cost: cost2, contract: contract2 } = term2;
309        let i1_union_i2 = i1 | i2;
310        let s = s1 | s2;
311        let indices =
312            dp_calc_legs(&self.bitmap_g, &self.all_tensors, &s, self.inputs, i1_cut_i2_wo_output, &i1_union_i2);
313
314        let mem = helpers::compute_size_by_dict(indices.iter(), self.size_dict);
315        let cost = cost1 + cost2 + mem;
316
317        if cost <= self.cost_cap
318            && xn.get(&s).is_none_or(|term| cost < term.cost)
319            && self.memory_limit.is_none_or(|limit| mem <= limit)
320        {
321            let contract = vec![contract1.clone(), contract2.clone()].into();
322            xn.insert(s, DpTerm { indices, cost, contract });
323        }
324    }
325
326    /// Like `compare_flops` but sieves the potential contraction based
327    /// on some combination of both the flops and size.
328    pub fn compare_combo(
329        &self,
330        xn: &mut BTreeMap<BigUint, DpTerm>,
331        s1: &BigUint,
332        s2: &BigUint,
333        term1: &DpTerm,
334        term2: &DpTerm,
335        i1_cut_i2_wo_output: &ArrayIndexType,
336    ) {
337        let DpTerm { indices: i1, cost: cost1, contract: contract1 } = term1;
338        let DpTerm { indices: i2, cost: cost2, contract: contract2 } = term2;
339        let i1_union_i2 = i1 | i2;
340        let s = s1 | s2;
341        let indices =
342            dp_calc_legs(&self.bitmap_g, &self.all_tensors, &s, self.inputs, i1_cut_i2_wo_output, &i1_union_i2);
343
344        let mem = helpers::compute_size_by_dict(indices.iter(), self.size_dict);
345        let f = helpers::compute_size_by_dict(i1_union_i2.iter(), self.size_dict);
346
347        // Hardcoded to sum: f + self.combo_factor * mem
348        let combined = match self.minimize {
349            "combo" => f + self.combo_factor * mem,
350            "limit" => f.max(self.combo_factor * mem),
351            _ => panic!("Unknown minimize type for combo mode: {}", self.minimize),
352        };
353        let cost = cost1 + cost2 + combined;
354
355        if cost <= self.cost_cap
356            && xn.get(&s).is_none_or(|term| cost < term.cost)
357            && self.memory_limit.is_none_or(|limit| mem <= limit)
358        {
359            let contract = vec![contract1.clone(), contract2.clone()].into();
360            xn.insert(s, DpTerm { indices, cost, contract });
361        }
362    }
363
364    pub fn scale(&self) -> SizeType {
365        get_scale_from_minimize(self.minimize)
366    }
367
368    pub fn compare(
369        &self,
370        xn: &mut BTreeMap<BigUint, DpTerm>,
371        s1: &BigUint,
372        s2: &BigUint,
373        term1: &DpTerm,
374        term2: &DpTerm,
375        i1_cut_i2_wo_output: &ArrayIndexType,
376    ) {
377        let minimize_split = self.minimize.split('-').collect_vec();
378        if minimize_split.is_empty() {
379            panic!("Unknown minimize type: {}", self.minimize);
380        }
381        match minimize_split[0] {
382            "flops" => self.compare_flops(xn, s1, s2, term1, term2, i1_cut_i2_wo_output),
383            "size" => self.compare_size(xn, s1, s2, term1, term2, i1_cut_i2_wo_output),
384            "write" => self.compare_write(xn, s1, s2, term1, term2, i1_cut_i2_wo_output),
385            "combo" | "limit" => self.compare_combo(xn, s1, s2, term1, term2, i1_cut_i2_wo_output),
386            _ => panic!("Unknown minimize type: {}", self.minimize),
387        }
388    }
389}
390
391pub fn get_scale_from_minimize(minimize: &str) -> SizeType {
392    match minimize {
393        "flops" | "size" | "write" => SizeType::one(),
394        "combo" | "limit" => SizeType::MAX,
395        _ => panic!("Unknown minimize type: {minimize}"),
396    }
397}
398
399/// Makes a simple left-to-right binary tree out of a sequence of terms.
400///
401/// # Arguments
402/// * `seq` - Sequence of terms to nest
403///
404/// # Returns
405/// A `ContractionTree` representing the left-nested binary tree
406pub fn simple_tree_tuple(seq: &[ContractionTree]) -> ContractionTree {
407    seq.iter().cloned().reduce(|left, right| ContractionTree::Node(vec![left, right])).unwrap()
408}
409use std::collections::{BTreeMap, BTreeSet};
410
411/// Parses inputs for single term index operations (indices appearing on one tensor).
412///
413/// Returns:
414/// - Parsed inputs with single indices removed
415/// - Inputs that were reduced to scalars
416/// - Contractions needed for the reductions
417pub fn dp_parse_out_single_term_ops(
418    inputs: &[&ArrayIndexType],
419    all_inds: &[char],
420    ind_counts: &SizeDictType,
421) -> (Vec<ArrayIndexType>, Vec<ContractionTree>, Vec<ContractionTree>) {
422    let i_single: BTreeSet<char> = all_inds.iter().filter(|&c| ind_counts.get(c) == Some(&1)).cloned().collect();
423
424    let mut inputs_parsed = Vec::new();
425    let mut inputs_done = Vec::new();
426    let mut inputs_contractions = Vec::new();
427
428    for (j, input) in inputs.iter().enumerate() {
429        let i_reduced: ArrayIndexType = *input - &i_single;
430        if i_reduced.is_empty() && !input.is_empty() {
431            // Input reduced to scalar - remove
432            inputs_done.push(vec![j].into());
433        } else {
434            // Add single contraction if indices were reduced
435            inputs_contractions.push(if i_reduced.len() != input.len() { vec![j].into() } else { j.into() });
436            inputs_parsed.push(i_reduced);
437        }
438    }
439
440    (inputs_parsed, inputs_done, inputs_contractions)
441}
442
443#[derive(Debug, Clone)]
444pub struct DynamicProgramming {
445    pub minimize: String,
446    pub search_outer: bool,
447    pub cost_cap: SizeLimitType,
448    pub combo_factor: SizeType,
449}
450
451impl Default for DynamicProgramming {
452    fn default() -> Self {
453        Self {
454            minimize: "flops".into(),
455            search_outer: false,
456            cost_cap: true.into(),
457            combo_factor: SizeType::from_usize(64).unwrap(),
458        }
459    }
460}
461
462impl DynamicProgramming {
463    pub fn find_optimal_path(
464        &self,
465        inputs: &[&ArrayIndexType],
466        output: &ArrayIndexType,
467        size_dict: &SizeDictType,
468        memory_limit: Option<SizeType>,
469    ) -> Result<PathType, String> {
470        let uint_1 = BigUint::from(1u32);
471        let uint_0 = BigUint::from(0u32);
472
473        // Count index occurrences
474        let ind_counts: BTreeMap<char, usize> =
475            inputs.iter().flat_map(|inds| inds.iter()).chain(output.iter()).fold(BTreeMap::new(), |mut counts, &c| {
476                *counts.entry(c).or_default() += 1;
477                counts
478            });
479
480        let all_inds: Vec<char> = ind_counts.keys().copied().collect();
481
482        // Parse single-term operations
483        let (inputs, inputs_done, inputs_contractions) = dp_parse_out_single_term_ops(inputs, &all_inds, &ind_counts);
484        let inputs_ref = inputs.iter().collect_vec();
485
486        if inputs.is_empty() {
487            return Ok(tree_to_sequence(&simple_tree_tuple(&inputs_done)));
488        }
489
490        // Initialize subgraph tracking
491        let mut subgraph_contractions = inputs_done;
492        let mut subgraph_sizes: Vec<SizeType> = vec![SizeType::one(); subgraph_contractions.len()];
493
494        // Find disconnected subgraphs
495        let subgraphs = if self.search_outer {
496            vec![(0..inputs.len()).collect_vec()]
497        } else {
498            find_disconnected_subgraphs(&inputs, output).into_iter().map(|s| s.into_iter().collect()).collect()
499        };
500
501        let all_tensors = (&uint_1 << inputs.len()) - &uint_1;
502        let naive_scale = get_scale_from_minimize(&self.minimize);
503        let naive_cost = naive_scale
504            * SizeType::from_usize(inputs.len()).unwrap()
505            * size_dict.values().map(|v| SizeType::from_usize(*v).unwrap()).product::<SizeType>();
506
507        for g in subgraphs {
508            let bitmap_g = g.iter().fold(uint_0.clone(), |acc, &j| acc | (&uint_1 << j));
509
510            // Initialize DP table
511            let mut x: Vec<BTreeMap<BigUint, DpTerm>> = vec![BTreeMap::new(); g.len() + 1];
512            x[1] = g
513                .iter()
514                .map(|&j| {
515                    (&uint_1 << j, DpTerm {
516                        indices: inputs[j].clone(),
517                        cost: SizeType::zero(),
518                        contract: inputs_contractions[j].clone(),
519                    })
520                })
521                .collect();
522
523            // Initialize cost cap
524            let subgraph_inds = bitmap_select(&bitmap_g, &inputs).flat_map(|inds| inds.iter().copied()).collect();
525
526            let mut cost_cap = match self.cost_cap {
527                SizeLimitType::Size(cap) => cap,
528                SizeLimitType::None => SizeType::MAX,
529                SizeLimitType::MaxInput => helpers::compute_size_by_dict((&subgraph_inds & output).iter(), size_dict),
530            };
531
532            let cost_increment = if subgraph_inds.is_empty() {
533                SizeType::from_usize(2).unwrap()
534            } else {
535                subgraph_inds
536                    .iter()
537                    .map(|c| size_dict[c] as SizeType)
538                    .fold(SizeType::MAX, SizeType::min)
539                    .max(SizeType::from_usize(2).unwrap())
540            };
541
542            let mut dp_comp_args = DpCompareArgs {
543                inputs: &inputs_ref,
544                size_dict,
545                all_tensors: all_tensors.clone(),
546                memory_limit,
547                cost_cap,
548                bitmap_g,
549                combo_factor: self.combo_factor,
550                minimize: &self.minimize,
551            };
552
553            fn has_common_bits(s1: &BigUint, s2: &BigUint) -> bool {
554                let digits1 = s1.iter_u64_digits();
555                let digits2 = s2.iter_u64_digits();
556                digits1.zip(digits2).any(|(d1, d2)| d1 & d2 != 0)
557            }
558
559            while x.last().unwrap().is_empty() {
560                for n in 2..=g.len() {
561                    let (xn_left, xn_right) = x.split_at_mut(n);
562                    let xn = &mut xn_right[0];
563                    for m in 1..=(n / 2) {
564                        for (s1, term1) in &xn_left[m] {
565                            for (s2, term2) in &xn_left[n - m] {
566                                // EFFICIENCY: `s1 & s2 != 0` changes to `!has_common_bits(s1, s2)`
567                                if !has_common_bits(s1, s2) && (m != n - m || s1 < s2) {
568                                    let i1 = &term1.indices;
569                                    let i2 = &term2.indices;
570                                    // EFFICIENCY: use iterators instead of `&` and `-` for set operations
571                                    // let i1_cut_i2_wo_output = &(i1 & i2) - output;
572                                    let i1_cut_i2_wo_output: ArrayIndexType = i1
573                                        .iter()
574                                        .filter(|&&c| i2.contains(&c) && !output.contains(&c))
575                                        .cloned()
576                                        .collect();
577                                    if self.search_outer || !i1_cut_i2_wo_output.is_empty() {
578                                        dp_comp_args.compare(xn, s1, s2, term1, term2, &i1_cut_i2_wo_output);
579                                    }
580                                }
581                            }
582                        }
583                    }
584                }
585
586                // avoid overflow
587                cost_cap = match cost_cap >= SizeType::MAX / cost_increment {
588                    true => SizeType::MAX,
589                    false => cost_cap * cost_increment,
590                };
591                dp_comp_args.cost_cap = cost_cap;
592
593                if cost_cap > naive_cost && x.last().unwrap().is_empty() {
594                    return Err("No contraction found for given memory_limit".into());
595                }
596            }
597
598            let (_, term) = x.last().unwrap().iter().next().unwrap();
599            subgraph_contractions.push(term.contract.clone());
600            subgraph_sizes.push(helpers::compute_size_by_dict(term.indices.iter(), size_dict));
601        }
602
603        // Sort subgraphs by size
604        let sorted_indices =
605            (0..subgraph_sizes.len()).sorted_by(|&a, &b| subgraph_sizes[a].partial_cmp(&subgraph_sizes[b]).unwrap());
606        let sorted_contractions = sorted_indices.map(|i| subgraph_contractions[i].clone()).collect_vec();
607
608        Ok(tree_to_sequence(&simple_tree_tuple(&sorted_contractions)))
609    }
610}
611
612impl PathOptimizer for DynamicProgramming {
613    fn optimize_path(
614        &mut self,
615        inputs: &[&ArrayIndexType],
616        output: &ArrayIndexType,
617        size_dict: &SizeDictType,
618        memory_limit: Option<SizeType>,
619    ) -> Result<PathType, String> {
620        self.find_optimal_path(inputs, output, size_dict, memory_limit)
621    }
622}
623
624impl From<&str> for DynamicProgramming {
625    fn from(s: &str) -> Self {
626        let s = s.replace(['_', ' '], "-").to_lowercase();
627        if s == "dp" || s == "dynamic-programming" {
628            return DynamicProgramming::default();
629        }
630        if s.starts_with("dp-") {
631            let minimize = s.strip_prefix("dp-").unwrap();
632            // sanity of minimize
633            if minimize.starts_with("combo") || minimize.starts_with("limit") {
634                let minimize_split = minimize.split('-').collect_vec();
635                if minimize_split.len() > 2 {
636                    panic!("Unknown dynamic programming optimizer: {s}");
637                }
638                match minimize_split.len() {
639                    1 => {
640                        let minimize = minimize_split[0];
641                        if minimize != "combo" && minimize != "limit" {
642                            panic!("Unknown dynamic programming optimizer: {s}");
643                        }
644                        return DynamicProgramming { minimize: minimize.into(), ..Default::default() };
645                    },
646                    2 => {
647                        let minimize = minimize_split[0];
648                        if minimize != "combo" && minimize != "limit" {
649                            panic!("Unknown dynamic programming optimizer: {s}");
650                        }
651                        let combo_factor = match minimize_split[1].parse::<SizeType>() {
652                            Ok(factor) => factor,
653                            Err(_) => panic!("Invalid combo factor in dynamic programming optimizer: {s}"),
654                        };
655                        return DynamicProgramming { minimize: minimize.into(), combo_factor, ..Default::default() };
656                    },
657                    _ => panic!("Unknown dynamic programming optimizer: {s}"),
658                };
659            } else if minimize == "flops" || minimize == "size" || minimize == "write" {
660                return DynamicProgramming { minimize: minimize.into(), ..Default::default() };
661            } else {
662                panic!("Unknown dynamic programming optimizer: {s}");
663            }
664        }
665        panic!("Unknown dynamic programming optimizer: {s}");
666    }
667}
668
/// Checks the reference example `((1, 2), (0, (4, 5, 3)))`.
#[test]
fn test_tree_to_sequence() {
    let left = ContractionTree::from(vec![1, 2]);
    let right = ContractionTree::from(vec![ContractionTree::from(0), ContractionTree::from(vec![4, 5, 3])]);
    let tree = ContractionTree::from(vec![left, right]);

    let path = tree_to_sequence(&tree);
    println!("{path:?}");

    let expected: PathType = vec![vec![1, 2], vec![1, 2, 3], vec![0, 2], vec![0, 1]];
    assert_eq!(path, expected);
}
680
/// The same inputs split differently depending on which indices the output keeps.
#[test]
fn test_find_disconnected_subgraphs() {
    use crate::helpers::setify;
    let inputs = vec![setify("ab"), setify("c"), setify("ad")];

    // `a` is summed over, which ties inputs 0 and 2 together.
    let output_without_a = setify("bd");
    let connected = find_disconnected_subgraphs(&inputs, &output_without_a);
    assert_eq!(connected, vec![setify([0, 2]), setify([1])]);

    // `a` stays in the output, so no summation index links any two inputs.
    let output_with_a = setify("abd");
    let separate = find_disconnected_subgraphs(&inputs, &output_with_a);
    assert_eq!(separate, vec![setify([0]), setify([1]), setify([2])]);
}
696
/// Bits of the bitmap select the elements at the matching positions.
#[test]
fn test_bitmap_select() {
    use crate::helpers::setify;
    let seq = vec![setify("A"), setify("B"), setify("C"), setify("D"), setify("E")];

    // Test case from Python example: bits 1, 3 and 4 are set
    let mask = BigUint::from(0b11010_u32);
    let picked = bitmap_select(&mask, &seq).collect_vec();
    assert_eq!(picked, vec![&setify("B"), &setify("D"), &setify("E")]);

    // Nothing selected, everything selected, only the first element selected
    let none = BigUint::from(0b00000_u32);
    assert_eq!(bitmap_select(&none, &seq).count(), 0);

    let all = BigUint::from(0b11111_u32);
    assert_eq!(bitmap_select(&all, &seq).count(), 5);

    let first = BigUint::from(0b00001_u32);
    assert_eq!(bitmap_select(&first, &seq).collect_vec(), vec![&setify("A")]);
}
712
/// Four terms nest as `(((1, 2), 3), 4)`.
#[test]
fn test_simple_tree_tuple() {
    let terms: [ContractionTree; 4] = [1.into(), 2.into(), 3.into(), 4.into()];
    let tree = simple_tree_tuple(&terms);

    let innermost = ContractionTree::Node(vec![1.into(), 2.into()]);
    let middle = ContractionTree::Node(vec![innermost, 3.into()]);
    let expected = ContractionTree::Node(vec![middle, 4.into()]);

    assert_eq!(tree, expected);
}
723}