opt_einsum_path/paths/
mod.rs

1//! Contains the path technology behind opt_einsum in addition to several path helpers.
2
3pub mod branch_bound;
4pub mod dp;
5pub mod greedy;
6pub mod greedy_random;
7pub mod no_optimize;
8pub mod optimal;
9pub mod util;
10
11use crate::*;
12use paths::util::*;
13use std::str::FromStr;
14
/// A strategy that decides the order in which an einsum's input tensors are contracted.
///
/// The result is a contraction path: a sequence of steps, each listing the positions (among
/// the tensors still alive at that point) of the tensors to contract together at that step.
pub trait PathOptimizer {
    /// Computes a contraction path for `inputs` -> `output`.
    ///
    /// * `inputs` - the index description of each input tensor, in operand order.
    /// * `output` - the index description of the result.
    /// * `size_dict` - presumably the dimension size of each index (TODO confirm against
    ///   `SizeDictType`).
    /// * `memory_limit` - optional limit passed through to the strategy; `None` means no limit
    ///   was requested (how it is enforced is up to each implementor).
    ///
    /// # Errors
    /// Returns `Err(String)` with a human-readable reason when no path can be produced.
    fn optimize_path(
        &mut self,
        inputs: &[&ArrayIndexType],
        output: &ArrayIndexType,
        size_dict: &SizeDictType,
        memory_limit: Option<SizeType>,
    ) -> Result<PathType, String>;
}
24
/// Every path-finding strategy this module can dispatch to, one variant per optimizer.
///
/// Usually built from a name via `FromStr` / `From<&str>` (e.g. `"greedy"`, `"dp"`,
/// `"auto-hq"`) or from a `bool` (`true` -> `"auto-hq"`, `false` -> `"no-optimize"`).
/// The enum implements [`PathOptimizer`] itself by forwarding to the wrapped optimizer.
#[non_exhaustive]
#[derive(Debug)]
pub enum OptimizeKind {
    /// Selected by `"optimal"` or `"optimized"`.
    Optimal(paths::optimal::Optimal),
    /// Selected by `"no-optimize"`.
    NoOptimize(paths::no_optimize::NoOptimize),
    /// Selected by `"branch-all"`, `"branch-2"` or `"branch-1"`.
    BranchBound(paths::branch_bound::BranchBound),
    /// Selected by `"greedy"`, `"eager"` or `"opportunistic"`.
    Greedy(paths::greedy::Greedy),
    /// Selected by `"dp"`, `"dynamic-programming"`, or any name starting with `"dp-"`.
    DynamicProgramming(paths::dp::DynamicProgramming),
    /// Selected by any name starting with `"random-greedy"`.
    RandomGreedy(paths::greedy_random::RandomGreedy),
    /// Selected by `"auto"`: picks a strategy from the number of inputs.
    Auto(Auto),
    /// Selected by `"auto-hq"`: alternative input-count-based preset.
    AutoHq(AutoHq),
}
37
/// The `"auto"` preset: chooses an optimizer from the number of inputs
/// (`optimal` below 5 inputs, `branch-all` for 5-6, `branch-2` for 7-8,
/// `branch-1` for 9-14 and `greedy` from 15 on).
#[derive(Debug, Default, Clone)]
pub struct Auto {}
40
/// The `"auto-hq"` preset: chooses between `optimal`, dynamic programming and
/// `random-greedy-128` from the number of inputs; the exact cut-offs live in its
/// [`PathOptimizer`] impl.
#[derive(Debug, Default, Clone)]
pub struct AutoHq {}
43
44impl PathOptimizer for Auto {
45    fn optimize_path(
46        &mut self,
47        inputs: &[&ArrayIndexType],
48        output: &ArrayIndexType,
49        size_dict: &SizeDictType,
50        memory_limit: Option<SizeType>,
51    ) -> Result<PathType, String> {
52        let mut optimizer: Box<dyn PathOptimizer> = match inputs.len() {
53            ..5 => Box::new(paths::optimal::Optimal::default()),
54            5..7 => Box::new(paths::branch_bound::BranchBound::from("branch-all")),
55            7..9 => Box::new(paths::branch_bound::BranchBound::from("branch-2")),
56            9..15 => Box::new(paths::branch_bound::BranchBound::from("branch-1")),
57            15.. => Box::new(paths::greedy::Greedy::default()),
58        };
59        optimizer.optimize_path(inputs, output, size_dict, memory_limit)
60    }
61}
62
63impl PathOptimizer for AutoHq {
64    fn optimize_path(
65        &mut self,
66        inputs: &[&ArrayIndexType],
67        output: &ArrayIndexType,
68        size_dict: &SizeDictType,
69        memory_limit: Option<SizeType>,
70    ) -> Result<PathType, String> {
71        let mut optimizer: Box<dyn PathOptimizer> = match inputs.len() {
72            ..6 => Box::new(paths::optimal::Optimal::default()),
73            6..20 => Box::new(paths::dp::DynamicProgramming::default()),
74            20.. => Box::new(paths::greedy_random::RandomGreedy::from("random-greedy-128")),
75        };
76        optimizer.optimize_path(inputs, output, size_dict, memory_limit)
77    }
78}
79
80impl OptimizeKind {
81    pub fn optimizer(&mut self) -> &mut dyn PathOptimizer {
82        use OptimizeKind::*;
83        match self {
84            Optimal(optimizer) => optimizer,
85            NoOptimize(optimizer) => optimizer,
86            BranchBound(optimizer) => optimizer,
87            Greedy(optimizer) => optimizer,
88            DynamicProgramming(optimizer) => optimizer,
89            RandomGreedy(optimizer) => optimizer,
90            Auto(optimizer) => optimizer,
91            AutoHq(optimizer) => optimizer,
92        }
93    }
94}
95
96impl PathOptimizer for OptimizeKind {
97    fn optimize_path(
98        &mut self,
99        inputs: &[&ArrayIndexType],
100        output: &ArrayIndexType,
101        size_dict: &SizeDictType,
102        memory_limit: Option<SizeType>,
103    ) -> Result<PathType, String> {
104        // capture panics from the optimizer and convert to error
105        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
106            self.optimizer().optimize_path(inputs, output, size_dict, memory_limit)
107        }))
108        .unwrap_or_else(|err| Err(format!("Optimizer panicked: {err:?}")))
109    }
110}
111
112impl FromStr for OptimizeKind {
113    type Err = String;
114    fn from_str(s: &str) -> Result<Self, Self::Err> {
115        use OptimizeKind::*;
116
117        // special handling for dp;
118        if s.starts_with("dp-") {
119            return Ok(DynamicProgramming(s.into()));
120        }
121        if s.starts_with("random-greedy") {
122            return Ok(RandomGreedy(s.into()));
123        }
124
125        // general handling
126        let optimizer = match s.replace(['_', ' '], "-").to_lowercase().as_str() {
127            "optimal" | "optimized" => Optimal(Default::default()),
128            "no-optimize" => NoOptimize(Default::default()),
129            "branch-all" => BranchBound(Default::default()),
130            "branch-2" => BranchBound("branch-2".into()),
131            "branch-1" => BranchBound("branch-1".into()),
132            "greedy" | "eager" | "opportunistic" => Greedy(Default::default()),
133            "dp" | "dynamic-programming" => DynamicProgramming(Default::default()),
134            "auto" => Auto(Default::default()),
135            "auto-hq" => AutoHq(Default::default()),
136            _ => Err(format!("Unknown optimization kind: {s}"))?,
137        };
138        Ok(optimizer)
139    }
140}
141
142impl From<&str> for OptimizeKind {
143    fn from(s: &str) -> Self {
144        OptimizeKind::from_str(s).unwrap()
145    }
146}
147
148impl From<bool> for OptimizeKind {
149    fn from(b: bool) -> Self {
150        match b {
151            true => "auto-hq".into(),
152            false => "no-optimize".into(),
153        }
154    }
155}
156
157impl From<Option<bool>> for OptimizeKind {
158    fn from(b: Option<bool>) -> Self {
159        b.unwrap_or(true).into()
160    }
161}
162
163impl PathOptimizer for &str {
164    fn optimize_path(
165        &mut self,
166        inputs: &[&ArrayIndexType],
167        output: &ArrayIndexType,
168        size_dict: &SizeDictType,
169        memory_limit: Option<SizeType>,
170    ) -> Result<PathType, String> {
171        let mut optimizer = OptimizeKind::from(*self);
172        optimizer.optimize_path(inputs, output, size_dict, memory_limit)
173    }
174}
175
176impl PathOptimizer for bool {
177    fn optimize_path(
178        &mut self,
179        inputs: &[&ArrayIndexType],
180        output: &ArrayIndexType,
181        size_dict: &SizeDictType,
182        memory_limit: Option<SizeType>,
183    ) -> Result<PathType, String> {
184        let mut optimizer = OptimizeKind::from(*self);
185        optimizer.optimize_path(inputs, output, size_dict, memory_limit)
186    }
187}
188
189/* #region special impl for PathType (Vec<Vec<usize>> or related) */
190
191impl PathOptimizer for PathType {
192    fn optimize_path(
193        &mut self,
194        inputs: &[&ArrayIndexType],
195        _output: &ArrayIndexType,
196        _size_dict: &SizeDictType,
197        _memory_limit: Option<SizeType>,
198    ) -> Result<PathType, String> {
199        // simple validation
200        {
201            let mut n = inputs.len();
202            for indices in self.iter() {
203                if indices.is_empty() {
204                    return Err("Empty path step found".to_string());
205                }
206                let mut indices = indices.to_vec();
207                indices.sort_unstable();
208                // check largest index is less than n
209                if indices.last().unwrap() >= &n {
210                    return Err(format!("Path step index {} out of bounds for {n} inputs", indices.last().unwrap()));
211                }
212                // update n by removing the contracted indices
213                n -= indices.len() - 1;
214            }
215            if n != 1 {
216                return Err(format!("Path does not reduce to single output, ended with {n} tensors"));
217            }
218        }
219        Ok(self.clone())
220    }
221}
222
223#[duplicate::duplicate_item(ImplType; [&[Vec<usize>]]; [&[&[usize]]]; [Vec<[usize; 2]>]; [&[[usize; 2]]])]
224impl PathOptimizer for ImplType {
225    fn optimize_path(
226        &mut self,
227        inputs: &[&ArrayIndexType],
228        output: &ArrayIndexType,
229        size_dict: &SizeDictType,
230        memory_limit: Option<SizeType>,
231    ) -> Result<PathType, String> {
232        let mut path = self.iter().map(|step| step.to_vec()).collect::<PathType>();
233        path.optimize_path(inputs, output, size_dict, memory_limit)
234    }
235}
236
237#[duplicate::duplicate_item(ImplType; [[Vec<usize>; N]]; [[[usize; 2]; N]])]
238impl<const N: usize> PathOptimizer for ImplType {
239    fn optimize_path(
240        &mut self,
241        inputs: &[&ArrayIndexType],
242        output: &ArrayIndexType,
243        size_dict: &SizeDictType,
244        memory_limit: Option<SizeType>,
245    ) -> Result<PathType, String> {
246        let mut path = self.iter().map(|step| step.to_vec()).collect::<PathType>();
247        path.optimize_path(inputs, output, size_dict, memory_limit)
248    }
249}
250
251#[duplicate::duplicate_item(ImplType; [Vec<(usize, usize)>]; [&[(usize, usize)]])]
252impl PathOptimizer for ImplType {
253    fn optimize_path(
254        &mut self,
255        inputs: &[&ArrayIndexType],
256        output: &ArrayIndexType,
257        size_dict: &SizeDictType,
258        memory_limit: Option<SizeType>,
259    ) -> Result<PathType, String> {
260        let mut path = self.iter().map(|step| vec![step.0, step.1]).collect::<PathType>();
261        path.optimize_path(inputs, output, size_dict, memory_limit)
262    }
263}
264
265#[duplicate::duplicate_item(ImplType; [[(usize, usize); N]])]
266impl<const N: usize> PathOptimizer for ImplType {
267    fn optimize_path(
268        &mut self,
269        inputs: &[&ArrayIndexType],
270        output: &ArrayIndexType,
271        size_dict: &SizeDictType,
272        memory_limit: Option<SizeType>,
273    ) -> Result<PathType, String> {
274        let mut path = self.iter().map(|step| vec![step.0, step.1]).collect::<PathType>();
275        path.optimize_path(inputs, output, size_dict, memory_limit)
276    }
277}
278
279/* #endregion */