Skip to main content

morok_schedule/rangeify/
context.rs

//! Rangeify context for tracking state during transformation.
2
3use morok_ir::{UOp, UOpKey};
4use std::collections::HashMap;
5use std::sync::Arc;
6
7/// Context for rangeify transformation.
8///
9/// Tracks state during the rangeify transformation, including:
10/// - Mapping from original UOps to their rangeified versions
11/// - Counter for generating unique range IDs
12#[derive(Default)]
13pub struct RangeifyContext {
14    /// Maps old UOps to their rangeified versions.
15    ///
16    /// This allows us to track how each node in the original graph
17    /// has been transformed during the rangeify process.
18    ///
19    /// Uses UOpKey for HashMap keys since Arc<UOp> doesn't implement Hash/Eq.
20    pub range_map: HashMap<UOpKey, Arc<UOp>>,
21
22    /// Counter for generating unique range IDs.
23    ///
24    /// Each RANGE operation needs a unique axis_id to distinguish
25    /// different loop dimensions. This counter ensures we never
26    /// reuse IDs within a single transformation.
27    pub range_counter: usize,
28}
29
30impl RangeifyContext {
31    /// Create a new empty rangeify context.
32    pub fn new() -> Self {
33        Self::default()
34    }
35
36    /// Get the next available range ID.
37    ///
38    /// Increments the internal counter and returns the previous value.
39    /// This ensures each range gets a unique ID.
40    pub fn next_range_id(&mut self) -> usize {
41        let id = self.range_counter;
42        self.range_counter += 1;
43        id
44    }
45
46    /// Record that a UOp has been transformed.
47    ///
48    /// Maps the original UOp to its rangeified version so we can
49    /// track the transformation.
50    pub fn record_transform(&mut self, original: Arc<UOp>, rangeified: Arc<UOp>) {
51        self.range_map.insert(UOpKey(original), rangeified);
52    }
53
54    /// Get the rangeified version of a UOp, if it exists.
55    pub fn get_rangeified(&self, original: &Arc<UOp>) -> Option<&Arc<UOp>> {
56        self.range_map.get(&UOpKey(original.clone()))
57    }
58}