1use crate::paths::branch_bound::*;
3use crate::paths::greedy::*;
4use crate::*;
5use rand::prelude::*;
6use rand::rngs::StdRng;
7use std::time::{Duration, Instant};
8
9pub fn thermal_chooser(
38 queue: &mut BinaryHeap<GreedyContractionType>,
39 remaining: &BTreeMap<ArrayIndexType, usize>,
40 rng: &mut StdRng,
41 nbranch: usize,
42 temperature: f64,
43 rel_temperature: bool,
44) -> Option<GreedyContractionType> {
45 let mut choices = Vec::new();
46 let mut n = 0;
47
48 while n < nbranch && !queue.is_empty() {
50 if let Some(candidate) = queue.pop() {
51 if remaining.contains_key(&candidate.k1) && remaining.contains_key(&candidate.k2) {
52 choices.push(candidate);
53 n += 1;
54 }
55 }
56 }
57
58 if choices.is_empty() {
59 return None;
60 }
61
62 if choices.len() == 1 {
63 return Some(choices.remove(0));
64 }
65
66 let costs: Vec<f64> = choices.iter().map(|c| c.cost.0.cost.to_f64().unwrap()).collect();
68
69 let cmin = costs.iter().cloned().fold(f64::INFINITY, |a, b| a.min(b));
70
71 let effective_temperature = if rel_temperature { temperature * cmin.abs().max(1.0) } else { temperature };
73
74 let energies: Vec<f64> = if effective_temperature == 0.0 {
76 costs.iter().map(|&c| if c == cmin { 1.0 } else { 0.0 }).collect()
77 } else {
78 costs.iter().map(|&c| (-(c - cmin) / effective_temperature).exp()).collect()
79 };
80
81 let chosen_index = if energies.iter().sum::<f64>() > 0.0 {
83 let mut cumulative = 0.0;
84 let total: f64 = energies.iter().sum();
85 let rand_val: f64 = rng.random_range(0.0..total);
86
87 for (i, &energy) in energies.iter().enumerate() {
88 cumulative += energy;
89 if cumulative >= rand_val {
90 return Some(choices.remove(i));
91 }
92 }
93 0 } else {
95 0
96 };
97
98 for (i, choice) in choices.clone().into_iter().enumerate() {
100 if i != chosen_index {
101 queue.push(choice);
102 }
103 }
104
105 Some(choices.remove(chosen_index))
106}
107
108pub fn ssa_path_compute_cost(
110 ssa_path: &PathType,
111 inputs: &[&ArrayIndexType],
112 output: &ArrayIndexType,
113 size_dict: &SizeDictType,
114) -> (SizeType, SizeType) {
115 let mut inputs = inputs.iter().map(|x| (*x).clone()).collect_vec();
116 let mut remaining: BTreeSet<usize> = (0..inputs.len()).collect();
117 let mut total_cost = SizeType::zero();
118 let mut max_size = SizeType::zero();
119
120 for contraction in ssa_path {
121 if contraction.len() < 2 {
122 continue;
123 }
124
125 let i = contraction[0];
126 let j = contraction[1];
127
128 let inputs_ref = inputs.iter().collect_vec();
129 let (k12, flops12) =
130 paths::util::calc_k12_flops(&inputs_ref, output, &remaining.iter().cloned().collect_vec(), i, j, size_dict);
131
132 let size12 = helpers::compute_size_by_dict(k12.iter(), size_dict);
133 total_cost += flops12;
134 max_size = max_size.max(size12);
135
136 remaining.remove(&i);
137 remaining.remove(&j);
138 remaining.insert(inputs.len());
139 inputs.push(k12);
140 }
141
142 (total_cost, max_size)
143}
144
/// Settings for the [`RandomGreedy`] path optimizer.
#[derive(Debug, Clone)]
pub struct RandomGreedyConfig {
    /// Number of trials attempted by each `optimize_path` call.
    pub max_repeats: usize,
    /// Optional time budget; a trial that would start after this much time has
    /// elapsed since the beginning of the `optimize_path` call is skipped.
    pub max_time: Option<Duration>,
    /// Criterion handed to `get_better_fn` to decide whether a trial beats the
    /// current best.
    pub minimize: MinimizeStrategy,
    /// Name of the greedy cost function. `"memory-removed-jitter"` selects
    /// `memory_removed(true)`; any other value selects `memory_removed(false)`.
    pub cost_fn: &'static str,
    /// Temperature passed to `thermal_chooser`.
    pub temperature: f64,
    /// Passed to `thermal_chooser`; when set, the temperature is scaled by the
    /// lowest candidate cost.
    pub rel_temperature: bool,
    /// Maximum number of candidates `thermal_chooser` samples from per step.
    pub nbranch: usize,
}
156
157impl Default for RandomGreedyConfig {
158 fn default() -> Self {
159 Self {
160 max_repeats: 32,
161 max_time: None,
162 minimize: MinimizeStrategy::FlopsFirst,
163 cost_fn: "memory-removed-jitter",
164 temperature: 1.0,
165 rel_temperature: true,
166 nbranch: 8,
167 }
168 }
169}
170
/// Random-greedy path optimizer: runs repeated greedy trials and keeps the
/// best path found, together with the history of every trial.
#[derive(Debug, Clone)]
pub struct RandomGreedy {
    /// Settings controlling the trials.
    pub config: RandomGreedyConfig,
    /// Cost of the best trial so far.
    pub best_flops: SizeType,
    /// Largest intermediate size of the best trial so far.
    pub best_size: SizeType,
    /// Best path found so far, in SSA form; `None` until a trial has been
    /// accepted.
    pub best_ssa_path: Option<PathType>,
    /// Cost of every completed trial, in the order the results were recorded.
    pub costs: Vec<SizeType>,
    /// Largest intermediate size of every completed trial, parallel to `costs`.
    pub sizes: Vec<SizeType>,
    /// Offset added to `costs.len()` to obtain the index of the next trial;
    /// the trial index is also used as that trial's RNG seed.
    pub repeats_start: usize,
}
182
183impl Default for RandomGreedy {
184 fn default() -> Self {
185 Self {
186 config: RandomGreedyConfig::default(),
187 best_flops: SizeType::MAX,
188 best_size: SizeType::MAX,
189 best_ssa_path: None,
190 costs: Vec::new(),
191 sizes: Vec::new(),
192 repeats_start: 0,
193 }
194 }
195}
196
197impl RandomGreedy {
198 pub fn new(config: RandomGreedyConfig) -> Self {
200 Self { config, ..Default::default() }
201 }
202
203 pub fn path(&self) -> PathType {
205 self.best_ssa_path.as_ref().map_or_else(Vec::new, |p| paths::util::ssa_to_linear(p))
206 }
207
208 fn run_trial(
210 config: &RandomGreedyConfig,
211 r: usize,
212 inputs: &[&ArrayIndexType],
213 output: &ArrayIndexType,
214 size_dict: &SizeDictType,
215 ) -> (PathType, SizeType, SizeType) {
216 let mut rng = StdRng::seed_from_u64(r as u64);
217 let nbranch = config.nbranch;
219 let temperature = config.temperature;
220 let rel_temperature = config.rel_temperature;
221 let thermal_chooser_fn: GreedyChooseFn = Box::new({
222 move |queue, remaining| thermal_chooser(queue, remaining, &mut rng, nbranch, temperature, rel_temperature)
223 });
224 let mut choose_fn = if r == 0 { Some(thermal_chooser_fn) } else { None };
225
226 let cost_fn = match config.cost_fn {
227 "memory-removed-jitter" => Some(paths::util::memory_removed(true)),
228 _ => Some(paths::util::memory_removed(false)),
229 };
230
231 let ssa_path = paths::greedy::ssa_greedy_optimize(inputs, output, size_dict, choose_fn.as_mut(), cost_fn);
232
233 let (cost, size) = ssa_path_compute_cost(&ssa_path, inputs, output, size_dict);
234
235 (ssa_path, cost, size)
236 }
237}
238
239impl PathOptimizer for RandomGreedy {
240 fn optimize_path(
241 &mut self,
242 inputs: &[&ArrayIndexType],
243 output: &ArrayIndexType,
244 size_dict: &SizeDictType,
245 memory_limit: Option<SizeType>,
246 ) -> Result<PathType, String> {
247 if memory_limit.is_some() {
249 let mut branch_optimizer = paths::branch_bound::BranchBound::from("branch-1");
250 return branch_optimizer.optimize_path(inputs, output, size_dict, memory_limit);
251 }
252
253 let start_time = Instant::now();
254 let better_fn = paths::branch_bound::get_better_fn(self.config.minimize);
255
256 let r_start = self.repeats_start + self.costs.len();
257 let r_end = r_start + self.config.max_repeats;
258
259 #[cfg(feature = "par_rand")]
260 use rayon::prelude::*;
261 #[cfg(feature = "par_rand")]
262 let r_iter = (r_start..r_end).into_par_iter();
263 #[cfg(not(feature = "par_rand"))]
264 let r_iter = r_start..r_end;
265
266 let trials: Vec<_> = r_iter
267 .map(|r| {
268 if self.config.max_time.is_some_and(|max_time| start_time.elapsed() > max_time) {
270 None
271 } else {
272 Some(RandomGreedy::run_trial(&self.config, r, inputs, output, size_dict))
273 }
274 })
275 .collect();
276
277 for (ssa_path, cost, size) in trials.into_iter().flatten() {
278 self.costs.push(cost);
280 self.sizes.push(size);
281
282 let found_new_best = better_fn(
284 cost.to_f64().unwrap(),
285 size.to_f64().unwrap(),
286 self.best_flops.to_f64().unwrap(),
287 self.best_size.to_f64().unwrap(),
288 );
289
290 if found_new_best {
291 self.best_flops = cost;
292 self.best_size = size;
293 self.best_ssa_path = Some(ssa_path);
294 }
295 }
296
297 Ok(self.path())
298 }
299}
300
301pub fn random_greedy(
303 inputs: &[&ArrayIndexType],
304 output: &ArrayIndexType,
305 size_dict: &SizeDictType,
306 memory_limit: Option<SizeType>,
307 config: RandomGreedyConfig,
308) -> Result<PathType, String> {
309 let mut optimizer = RandomGreedy::new(config);
310 optimizer.optimize_path(inputs, output, size_dict, memory_limit)
311}
312
313pub fn random_greedy_128(
315 inputs: &[&ArrayIndexType],
316 output: &ArrayIndexType,
317 size_dict: &SizeDictType,
318 memory_limit: Option<SizeType>,
319) -> Result<PathType, String> {
320 let config = RandomGreedyConfig { max_repeats: 128, ..RandomGreedyConfig::default() };
321 random_greedy(inputs, output, size_dict, memory_limit, config)
322}
323
324impl From<&str> for RandomGreedy {
325 fn from(s: &str) -> Self {
326 let s = s.trim().replace(['_', ' '], "-").to_lowercase();
327 assert!(s.starts_with("random-greedy"), "RandomGreedy must start with 'random-greedy'");
328 let v = s.strip_prefix("random-greedy").unwrap();
329 if v.is_empty() {
330 RandomGreedy::default()
331 } else {
332 let max_repeats = v.replace("-", "").parse::<usize>().unwrap_or_else(|_| {
333 panic!("Invalid RandomGreedy configuration: {s}. Expected format: 'random-greedy-<max_repeats>'")
334 });
335 let config = RandomGreedyConfig { max_repeats, ..RandomGreedyConfig::default() };
336 RandomGreedy::new(config)
337 }
338 }
339}