opt_einsum_path/paths/
greedy_random.rs

1// src/paths/random.rs
2use crate::paths::branch_bound::*;
3use crate::paths::greedy::*;
4use crate::*;
5use rand::prelude::*;
6use rand::rngs::StdRng;
7use std::time::{Duration, Instant};
8
9/// A contraction 'chooser' that weights possible contractions using a
10/// Boltzmann distribution. Explicitly, given costs `c_i` (with `c_0` the
11/// smallest), the relative weights, `w_i`, are computed as:
12///
13/// w_i = exp( -(c_i - c_0) / temperature)
14///
15/// Additionally, if `rel_temperature` is set, scale `temperature` by
16/// `abs(c_0)` to account for likely fluctuating cost magnitudes during the
17/// course of a contraction.
18///
19/// # Parameters
20///
21/// * `queue` - The heapified list of candidate contractions.
22/// * `remaining` - Mapping of remaining inputs' indices to the ssa id.
23/// * `rng` - Random number generator.
24/// * `nbranch` - How many potential paths to calculate probability for and choose from at each
25///   step.
26/// * `temperature` - When choosing a possible contraction, its relative probability will be
27///   proportional to `exp(-cost / temperature)`. Thus the larger `temperature` is, the further
28///   random paths will stray from the normal 'greedy' path. Conversely, if set to zero, only paths
29///   with exactly the same cost as the best at each step will be explored.
30/// * `rel_temperature` - Whether to normalize the `temperature` at each step to the scale of the
31///   best cost. This is generally beneficial as the magnitude of costs can vary significantly
32///   throughout a contraction.
33///
34/// # Returns
35///
36/// `Option<GreedyContractionType>` where Some contains the chosen contraction
37pub fn thermal_chooser(
38    queue: &mut BinaryHeap<GreedyContractionType>,
39    remaining: &BTreeMap<ArrayIndexType, usize>,
40    rng: &mut StdRng,
41    nbranch: usize,
42    temperature: f64,
43    rel_temperature: bool,
44) -> Option<GreedyContractionType> {
45    let mut choices = Vec::new();
46    let mut n = 0;
47
48    // Extract up to nbranch valid choices from the queue
49    while n < nbranch && !queue.is_empty() {
50        if let Some(candidate) = queue.pop() {
51            if remaining.contains_key(&candidate.k1) && remaining.contains_key(&candidate.k2) {
52                choices.push(candidate);
53                n += 1;
54            }
55        }
56    }
57
58    if choices.is_empty() {
59        return None;
60    }
61
62    if choices.len() == 1 {
63        return Some(choices.remove(0));
64    }
65
66    // Extract costs from choices
67    let costs: Vec<f64> = choices.iter().map(|c| c.cost.0.cost.to_f64().unwrap()).collect();
68
69    let cmin = costs.iter().cloned().fold(f64::INFINITY, |a, b| a.min(b));
70
71    // Adjust by the overall scale to account for fluctuating absolute costs
72    let effective_temperature = if rel_temperature { temperature * cmin.abs().max(1.0) } else { temperature };
73
74    // Compute relative probability for each potential contraction
75    let energies: Vec<f64> = if effective_temperature == 0.0 {
76        costs.iter().map(|&c| if c == cmin { 1.0 } else { 0.0 }).collect()
77    } else {
78        costs.iter().map(|&c| (-(c - cmin) / effective_temperature).exp()).collect()
79    };
80
81    // Randomly choose a contraction based on energies
82    let chosen_index = if energies.iter().sum::<f64>() > 0.0 {
83        let mut cumulative = 0.0;
84        let total: f64 = energies.iter().sum();
85        let rand_val: f64 = rng.random_range(0.0..total);
86
87        for (i, &energy) in energies.iter().enumerate() {
88            cumulative += energy;
89            if cumulative >= rand_val {
90                return Some(choices.remove(i));
91            }
92        }
93        0 // fallback
94    } else {
95        0
96    };
97
98    // Put the other choices back in the heap
99    for (i, choice) in choices.clone().into_iter().enumerate() {
100        if i != chosen_index {
101            queue.push(choice);
102        }
103    }
104
105    Some(choices.remove(chosen_index))
106}
107
108/// Compute the flops and max size of an ssa path.
109pub fn ssa_path_compute_cost(
110    ssa_path: &PathType,
111    inputs: &[&ArrayIndexType],
112    output: &ArrayIndexType,
113    size_dict: &SizeDictType,
114) -> (SizeType, SizeType) {
115    let mut inputs = inputs.iter().map(|x| (*x).clone()).collect_vec();
116    let mut remaining: BTreeSet<usize> = (0..inputs.len()).collect();
117    let mut total_cost = SizeType::zero();
118    let mut max_size = SizeType::zero();
119
120    for contraction in ssa_path {
121        if contraction.len() < 2 {
122            continue;
123        }
124
125        let i = contraction[0];
126        let j = contraction[1];
127
128        let inputs_ref = inputs.iter().collect_vec();
129        let (k12, flops12) =
130            paths::util::calc_k12_flops(&inputs_ref, output, &remaining.iter().cloned().collect_vec(), i, j, size_dict);
131
132        let size12 = helpers::compute_size_by_dict(k12.iter(), size_dict);
133        total_cost += flops12;
134        max_size = max_size.max(size12);
135
136        remaining.remove(&i);
137        remaining.remove(&j);
138        remaining.insert(inputs.len());
139        inputs.push(k12);
140    }
141
142    (total_cost, max_size)
143}
144
145/// Configuration for random greedy optimization
146#[derive(Debug, Clone)]
147pub struct RandomGreedyConfig {
148    pub max_repeats: usize,
149    pub max_time: Option<Duration>,
150    pub minimize: MinimizeStrategy,
151    pub cost_fn: &'static str,
152    pub temperature: f64,
153    pub rel_temperature: bool,
154    pub nbranch: usize,
155}
156
157impl Default for RandomGreedyConfig {
158    fn default() -> Self {
159        Self {
160            max_repeats: 32,
161            max_time: None,
162            minimize: MinimizeStrategy::FlopsFirst,
163            cost_fn: "memory-removed-jitter",
164            temperature: 1.0,
165            rel_temperature: true,
166            nbranch: 8,
167        }
168    }
169}
170
171/// Random greedy path optimizer
172#[derive(Debug, Clone)]
173pub struct RandomGreedy {
174    pub config: RandomGreedyConfig,
175    pub best_flops: SizeType,
176    pub best_size: SizeType,
177    pub best_ssa_path: Option<PathType>,
178    pub costs: Vec<SizeType>,
179    pub sizes: Vec<SizeType>,
180    pub repeats_start: usize,
181}
182
183impl Default for RandomGreedy {
184    fn default() -> Self {
185        Self {
186            config: RandomGreedyConfig::default(),
187            best_flops: SizeType::MAX,
188            best_size: SizeType::MAX,
189            best_ssa_path: None,
190            costs: Vec::new(),
191            sizes: Vec::new(),
192            repeats_start: 0,
193        }
194    }
195}
196
197impl RandomGreedy {
198    /// Create a new RandomGreedy optimizer with custom configuration
199    pub fn new(config: RandomGreedyConfig) -> Self {
200        Self { config, ..Default::default() }
201    }
202
203    /// Get the best path found so far
204    pub fn path(&self) -> PathType {
205        self.best_ssa_path.as_ref().map_or_else(Vec::new, |p| paths::util::ssa_to_linear(p))
206    }
207
208    /// Run a single trial of greedy optimization
209    fn run_trial(
210        config: &RandomGreedyConfig,
211        r: usize,
212        inputs: &[&ArrayIndexType],
213        output: &ArrayIndexType,
214        size_dict: &SizeDictType,
215    ) -> (PathType, SizeType, SizeType) {
216        let mut rng = StdRng::seed_from_u64(r as u64);
217        // For the first trial, use standard greedy approach
218        let nbranch = config.nbranch;
219        let temperature = config.temperature;
220        let rel_temperature = config.rel_temperature;
221        let thermal_chooser_fn: GreedyChooseFn = Box::new({
222            move |queue, remaining| thermal_chooser(queue, remaining, &mut rng, nbranch, temperature, rel_temperature)
223        });
224        let mut choose_fn = if r == 0 { Some(thermal_chooser_fn) } else { None };
225
226        let cost_fn = match config.cost_fn {
227            "memory-removed-jitter" => Some(paths::util::memory_removed(true)),
228            _ => Some(paths::util::memory_removed(false)),
229        };
230
231        let ssa_path = paths::greedy::ssa_greedy_optimize(inputs, output, size_dict, choose_fn.as_mut(), cost_fn);
232
233        let (cost, size) = ssa_path_compute_cost(&ssa_path, inputs, output, size_dict);
234
235        (ssa_path, cost, size)
236    }
237}
238
239impl PathOptimizer for RandomGreedy {
240    fn optimize_path(
241        &mut self,
242        inputs: &[&ArrayIndexType],
243        output: &ArrayIndexType,
244        size_dict: &SizeDictType,
245        memory_limit: Option<SizeType>,
246    ) -> Result<PathType, String> {
247        // Handle memory limit by falling back to branch bound
248        if memory_limit.is_some() {
249            let mut branch_optimizer = paths::branch_bound::BranchBound::from("branch-1");
250            return branch_optimizer.optimize_path(inputs, output, size_dict, memory_limit);
251        }
252
253        let start_time = Instant::now();
254        let better_fn = paths::branch_bound::get_better_fn(self.config.minimize);
255
256        let r_start = self.repeats_start + self.costs.len();
257        let r_end = r_start + self.config.max_repeats;
258
259        #[cfg(feature = "par_rand")]
260        use rayon::prelude::*;
261        #[cfg(feature = "par_rand")]
262        let r_iter = (r_start..r_end).into_par_iter();
263        #[cfg(not(feature = "par_rand"))]
264        let r_iter = r_start..r_end;
265
266        let trials: Vec<_> = r_iter
267            .map(|r| {
268                // Check if we've run out of time
269                if self.config.max_time.is_some_and(|max_time| start_time.elapsed() > max_time) {
270                    None
271                } else {
272                    Some(RandomGreedy::run_trial(&self.config, r, inputs, output, size_dict))
273                }
274            })
275            .collect();
276
277        for (ssa_path, cost, size) in trials.into_iter().flatten() {
278            // Keep track of all costs and sizes
279            self.costs.push(cost);
280            self.sizes.push(size);
281
282            // Check if we have found a new best
283            let found_new_best = better_fn(
284                cost.to_f64().unwrap(),
285                size.to_f64().unwrap(),
286                self.best_flops.to_f64().unwrap(),
287                self.best_size.to_f64().unwrap(),
288            );
289
290            if found_new_best {
291                self.best_flops = cost;
292                self.best_size = size;
293                self.best_ssa_path = Some(ssa_path);
294            }
295        }
296
297        Ok(self.path())
298    }
299}
300
301/// Convenience function for random greedy optimization
302pub fn random_greedy(
303    inputs: &[&ArrayIndexType],
304    output: &ArrayIndexType,
305    size_dict: &SizeDictType,
306    memory_limit: Option<SizeType>,
307    config: RandomGreedyConfig,
308) -> Result<PathType, String> {
309    let mut optimizer = RandomGreedy::new(config);
310    optimizer.optimize_path(inputs, output, size_dict, memory_limit)
311}
312
313/// Pre-configured random greedy with 128 repeats
314pub fn random_greedy_128(
315    inputs: &[&ArrayIndexType],
316    output: &ArrayIndexType,
317    size_dict: &SizeDictType,
318    memory_limit: Option<SizeType>,
319) -> Result<PathType, String> {
320    let config = RandomGreedyConfig { max_repeats: 128, ..RandomGreedyConfig::default() };
321    random_greedy(inputs, output, size_dict, memory_limit, config)
322}
323
324impl From<&str> for RandomGreedy {
325    fn from(s: &str) -> Self {
326        let s = s.trim().replace(['_', ' '], "-").to_lowercase();
327        assert!(s.starts_with("random-greedy"), "RandomGreedy must start with 'random-greedy'");
328        let v = s.strip_prefix("random-greedy").unwrap();
329        if v.is_empty() {
330            RandomGreedy::default()
331        } else {
332            let max_repeats = v.replace("-", "").parse::<usize>().unwrap_or_else(|_| {
333                panic!("Invalid RandomGreedy configuration: {s}. Expected format: 'random-greedy-<max_repeats>'")
334            });
335            let config = RandomGreedyConfig { max_repeats, ..RandomGreedyConfig::default() };
336            RandomGreedy::new(config)
337        }
338    }
339}