1pub mod branch_bound;
4pub mod dp;
5pub mod greedy;
6pub mod greedy_random;
7pub mod no_optimize;
8pub mod optimal;
9pub mod util;
10
11use crate::*;
12use paths::util::*;
13use std::str::FromStr;
14
/// A strategy that produces a contraction path for a set of operands.
///
/// * `inputs` – one index description per input operand; the validating
///   implementation for `PathType` in this module treats path entries as
///   positions into this list.
/// * `output` – index description of the requested result.
/// * `size_dict` – presumably the size of each index label (confirm against
///   `SizeDictType`).
/// * `memory_limit` – optional limit handed to the optimizer; how (or whether)
///   it is enforced is up to each implementation.
///
/// Returns the computed path, or a human-readable error message.
pub trait PathOptimizer {
    fn optimize_path(
        &mut self,
        inputs: &[&ArrayIndexType],
        output: &ArrayIndexType,
        size_dict: &SizeDictType,
        memory_limit: Option<SizeType>,
    ) -> Result<PathType, String>;
}
24
/// Every path-optimization strategy this module can dispatch to; each variant
/// owns its configured optimizer.
///
/// Usually built from a name via `FromStr` / `From<&str>`, or from a `bool`
/// (`true` selects `AutoHq`, `false` selects `NoOptimize`).
#[non_exhaustive]
#[derive(Debug)]
pub enum OptimizeKind {
    /// Selected by `"optimal"` or `"optimized"`.
    Optimal(paths::optimal::Optimal),
    /// Selected by `"no-optimize"`.
    NoOptimize(paths::no_optimize::NoOptimize),
    /// Selected by `"branch-all"`, `"branch-2"` or `"branch-1"`.
    BranchBound(paths::branch_bound::BranchBound),
    /// Selected by `"greedy"`, `"eager"` or `"opportunistic"`.
    Greedy(paths::greedy::Greedy),
    /// Selected by `"dp"`, `"dynamic-programming"`, or any name starting with `"dp-"`.
    DynamicProgramming(paths::dp::DynamicProgramming),
    /// Selected by any name starting with `"random-greedy"`.
    RandomGreedy(paths::greedy_random::RandomGreedy),
    /// Selected by `"auto"`; chooses an optimizer from the number of inputs.
    Auto(Auto),
    /// Selected by `"auto-hq"`; chooses an optimizer from the number of inputs.
    AutoHq(AutoHq),
}
37
/// Automatic strategy (name `"auto"`): its `PathOptimizer` impl uses `Optimal`
/// below 5 inputs, branch-and-bound variants up to 14 inputs, and `Greedy`
/// from 15 inputs on.
#[derive(Debug, Default, Clone)]
pub struct Auto {}
40
/// Automatic strategy (name `"auto-hq"`, also chosen by `true`): its
/// `PathOptimizer` impl uses `Optimal` below 6 inputs, `DynamicProgramming`
/// below 20 inputs, and `RandomGreedy` (`"random-greedy-128"`) otherwise.
#[derive(Debug, Default, Clone)]
pub struct AutoHq {}
43
44impl PathOptimizer for Auto {
45 fn optimize_path(
46 &mut self,
47 inputs: &[&ArrayIndexType],
48 output: &ArrayIndexType,
49 size_dict: &SizeDictType,
50 memory_limit: Option<SizeType>,
51 ) -> Result<PathType, String> {
52 let mut optimizer: Box<dyn PathOptimizer> = match inputs.len() {
53 ..5 => Box::new(paths::optimal::Optimal::default()),
54 5..7 => Box::new(paths::branch_bound::BranchBound::from("branch-all")),
55 7..9 => Box::new(paths::branch_bound::BranchBound::from("branch-2")),
56 9..15 => Box::new(paths::branch_bound::BranchBound::from("branch-1")),
57 15.. => Box::new(paths::greedy::Greedy::default()),
58 };
59 optimizer.optimize_path(inputs, output, size_dict, memory_limit)
60 }
61}
62
63impl PathOptimizer for AutoHq {
64 fn optimize_path(
65 &mut self,
66 inputs: &[&ArrayIndexType],
67 output: &ArrayIndexType,
68 size_dict: &SizeDictType,
69 memory_limit: Option<SizeType>,
70 ) -> Result<PathType, String> {
71 let mut optimizer: Box<dyn PathOptimizer> = match inputs.len() {
72 ..6 => Box::new(paths::optimal::Optimal::default()),
73 6..20 => Box::new(paths::dp::DynamicProgramming::default()),
74 20.. => Box::new(paths::greedy_random::RandomGreedy::from("random-greedy-128")),
75 };
76 optimizer.optimize_path(inputs, output, size_dict, memory_limit)
77 }
78}
79
80impl OptimizeKind {
81 pub fn optimizer(&mut self) -> &mut dyn PathOptimizer {
82 use OptimizeKind::*;
83 match self {
84 Optimal(optimizer) => optimizer,
85 NoOptimize(optimizer) => optimizer,
86 BranchBound(optimizer) => optimizer,
87 Greedy(optimizer) => optimizer,
88 DynamicProgramming(optimizer) => optimizer,
89 RandomGreedy(optimizer) => optimizer,
90 Auto(optimizer) => optimizer,
91 AutoHq(optimizer) => optimizer,
92 }
93 }
94}
95
96impl PathOptimizer for OptimizeKind {
97 fn optimize_path(
98 &mut self,
99 inputs: &[&ArrayIndexType],
100 output: &ArrayIndexType,
101 size_dict: &SizeDictType,
102 memory_limit: Option<SizeType>,
103 ) -> Result<PathType, String> {
104 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
106 self.optimizer().optimize_path(inputs, output, size_dict, memory_limit)
107 }))
108 .unwrap_or_else(|err| Err(format!("Optimizer panicked: {err:?}")))
109 }
110}
111
112impl FromStr for OptimizeKind {
113 type Err = String;
114 fn from_str(s: &str) -> Result<Self, Self::Err> {
115 use OptimizeKind::*;
116
117 if s.starts_with("dp-") {
119 return Ok(DynamicProgramming(s.into()));
120 }
121 if s.starts_with("random-greedy") {
122 return Ok(RandomGreedy(s.into()));
123 }
124
125 let optimizer = match s.replace(['_', ' '], "-").to_lowercase().as_str() {
127 "optimal" | "optimized" => Optimal(Default::default()),
128 "no-optimize" => NoOptimize(Default::default()),
129 "branch-all" => BranchBound(Default::default()),
130 "branch-2" => BranchBound("branch-2".into()),
131 "branch-1" => BranchBound("branch-1".into()),
132 "greedy" | "eager" | "opportunistic" => Greedy(Default::default()),
133 "dp" | "dynamic-programming" => DynamicProgramming(Default::default()),
134 "auto" => Auto(Default::default()),
135 "auto-hq" => AutoHq(Default::default()),
136 _ => Err(format!("Unknown optimization kind: {s}"))?,
137 };
138 Ok(optimizer)
139 }
140}
141
142impl From<&str> for OptimizeKind {
143 fn from(s: &str) -> Self {
144 OptimizeKind::from_str(s).unwrap()
145 }
146}
147
148impl From<bool> for OptimizeKind {
149 fn from(b: bool) -> Self {
150 match b {
151 true => "auto-hq".into(),
152 false => "no-optimize".into(),
153 }
154 }
155}
156
157impl From<Option<bool>> for OptimizeKind {
158 fn from(b: Option<bool>) -> Self {
159 b.unwrap_or(true).into()
160 }
161}
162
163impl PathOptimizer for &str {
164 fn optimize_path(
165 &mut self,
166 inputs: &[&ArrayIndexType],
167 output: &ArrayIndexType,
168 size_dict: &SizeDictType,
169 memory_limit: Option<SizeType>,
170 ) -> Result<PathType, String> {
171 let mut optimizer = OptimizeKind::from(*self);
172 optimizer.optimize_path(inputs, output, size_dict, memory_limit)
173 }
174}
175
176impl PathOptimizer for bool {
177 fn optimize_path(
178 &mut self,
179 inputs: &[&ArrayIndexType],
180 output: &ArrayIndexType,
181 size_dict: &SizeDictType,
182 memory_limit: Option<SizeType>,
183 ) -> Result<PathType, String> {
184 let mut optimizer = OptimizeKind::from(*self);
185 optimizer.optimize_path(inputs, output, size_dict, memory_limit)
186 }
187}
188
189impl PathOptimizer for PathType {
192 fn optimize_path(
193 &mut self,
194 inputs: &[&ArrayIndexType],
195 _output: &ArrayIndexType,
196 _size_dict: &SizeDictType,
197 _memory_limit: Option<SizeType>,
198 ) -> Result<PathType, String> {
199 {
201 let mut n = inputs.len();
202 for indices in self.iter() {
203 if indices.is_empty() {
204 return Err("Empty path step found".to_string());
205 }
206 let mut indices = indices.to_vec();
207 indices.sort_unstable();
208 if indices.last().unwrap() >= &n {
210 return Err(format!("Path step index {} out of bounds for {n} inputs", indices.last().unwrap()));
211 }
212 n -= indices.len() - 1;
214 }
215 if n != 1 {
216 return Err(format!("Path does not reduce to single output, ended with {n} tensors"));
217 }
218 }
219 Ok(self.clone())
220 }
221}
222
223#[duplicate::duplicate_item(ImplType; [&[Vec<usize>]]; [&[&[usize]]]; [Vec<[usize; 2]>]; [&[[usize; 2]]])]
224impl PathOptimizer for ImplType {
225 fn optimize_path(
226 &mut self,
227 inputs: &[&ArrayIndexType],
228 output: &ArrayIndexType,
229 size_dict: &SizeDictType,
230 memory_limit: Option<SizeType>,
231 ) -> Result<PathType, String> {
232 let mut path = self.iter().map(|step| step.to_vec()).collect::<PathType>();
233 path.optimize_path(inputs, output, size_dict, memory_limit)
234 }
235}
236
237#[duplicate::duplicate_item(ImplType; [[Vec<usize>; N]]; [[[usize; 2]; N]])]
238impl<const N: usize> PathOptimizer for ImplType {
239 fn optimize_path(
240 &mut self,
241 inputs: &[&ArrayIndexType],
242 output: &ArrayIndexType,
243 size_dict: &SizeDictType,
244 memory_limit: Option<SizeType>,
245 ) -> Result<PathType, String> {
246 let mut path = self.iter().map(|step| step.to_vec()).collect::<PathType>();
247 path.optimize_path(inputs, output, size_dict, memory_limit)
248 }
249}
250
251#[duplicate::duplicate_item(ImplType; [Vec<(usize, usize)>]; [&[(usize, usize)]])]
252impl PathOptimizer for ImplType {
253 fn optimize_path(
254 &mut self,
255 inputs: &[&ArrayIndexType],
256 output: &ArrayIndexType,
257 size_dict: &SizeDictType,
258 memory_limit: Option<SizeType>,
259 ) -> Result<PathType, String> {
260 let mut path = self.iter().map(|step| vec![step.0, step.1]).collect::<PathType>();
261 path.optimize_path(inputs, output, size_dict, memory_limit)
262 }
263}
264
265#[duplicate::duplicate_item(ImplType; [[(usize, usize); N]])]
266impl<const N: usize> PathOptimizer for ImplType {
267 fn optimize_path(
268 &mut self,
269 inputs: &[&ArrayIndexType],
270 output: &ArrayIndexType,
271 size_dict: &SizeDictType,
272 memory_limit: Option<SizeType>,
273 ) -> Result<PathType, String> {
274 let mut path = self.iter().map(|step| vec![step.0, step.1]).collect::<PathType>();
275 path.optimize_path(inputs, output, size_dict, memory_limit)
276 }
277}
278
279