cubecl_ir/
allocator.rs

1use alloc::{rc::Rc, vec::Vec};
2use core::cell::RefCell;
3
4use hashbrown::HashMap;
5use portable_atomic::{AtomicU32, Ordering};
6
7use crate::{BarrierLevel, SemanticType};
8
9use super::{Matrix, Type, Variable, VariableKind};
10
11/// An allocator for local variables of a kernel.
12///
13/// A local variable is unique to a unit. That is, each unit have their own copy of a local variable.
14/// There are three types of local variables based on their capabilities.
15///     - An immutable local variable is obtained by calling [Allocator::create_local].
16///     - A mutable local variable is obtained by calling [Allocator::create_local_mut]. The allocator will reuse
17///       previously defined mutable variables if possible.
18///     - A restricted mutable local variable is obtained by calling [Allocator::create_local_restricted]. This a is
19///       mutable variable that cannot be reused. This is mostly used for loop indices.
20///
21/// # Performance tips
22///
23/// In order, prefer immutable local variables, then mutable, then restricted.
24///
25/// To enable many compiler optimizations, it is preferred to use the [static single-assignment] strategy for immutable variables.
26/// That is, each variable must be declared and used exactly once.
27///
28/// [static single-assignment](https://en.wikipedia.org/wiki/Static_single-assignment_form)
29#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
30#[derive(Clone, Debug, Default, TypeHash)]
31pub struct Allocator {
32    #[cfg_attr(feature = "serde", serde(skip))]
33    local_mut_pool: Rc<RefCell<HashMap<Type, Vec<ExpandElement>>>>,
34    next_id: Rc<AtomicU32>,
35}
36
37impl PartialEq for Allocator {
38    fn eq(&self, other: &Self) -> bool {
39        Rc::ptr_eq(&self.local_mut_pool, &other.local_mut_pool)
40            && Rc::ptr_eq(&self.next_id, &other.next_id)
41    }
42}
43impl Eq for Allocator {}
44
45impl Allocator {
46    /// Create a new immutable local variable of type specified by `item`.
47    pub fn create_local(&self, item: Type) -> ExpandElement {
48        let id = self.new_local_index();
49        let local = VariableKind::LocalConst { id };
50        ExpandElement::Plain(Variable::new(local, item))
51    }
52
53    /// Create a new mutable local variable of type specified by `item`.
54    /// Try to reuse a previously defined but unused mutable variable if possible.
55    /// Else, this define a new variable.
56    pub fn create_local_mut(&self, item: Type) -> ExpandElement {
57        if item.is_atomic() {
58            self.create_local_restricted(item)
59        } else {
60            self.reuse_local_mut(item)
61                .unwrap_or_else(|| ExpandElement::Managed(self.add_local_mut(item)))
62        }
63    }
64
65    /// Create a new mutable restricted local variable of type specified by `item`.
66    pub fn create_local_restricted(&self, item: Type) -> ExpandElement {
67        let id = self.new_local_index();
68        let local = VariableKind::LocalMut { id };
69        ExpandElement::Plain(Variable::new(local, item))
70    }
71
72    pub fn create_local_array(&self, item: Type, array_size: u32) -> ExpandElement {
73        let id = self.new_local_index();
74        let local_array = Variable::new(
75            VariableKind::LocalArray {
76                id,
77                length: array_size,
78                unroll_factor: 1,
79            },
80            item,
81        );
82        ExpandElement::Plain(local_array)
83    }
84
85    /// Create a matrix variable
86    pub fn create_matrix(&self, matrix: Matrix) -> ExpandElement {
87        let id = self.new_local_index();
88        let variable = Variable::new(
89            VariableKind::Matrix { id, mat: matrix },
90            Type::new(matrix.storage),
91        );
92        ExpandElement::Plain(variable)
93    }
94
95    pub fn create_pipeline(&self, num_stages: u8) -> ExpandElement {
96        let id = self.new_local_index();
97        let variable = Variable::new(
98            VariableKind::Pipeline { id, num_stages },
99            SemanticType::Pipeline.into(),
100        );
101        ExpandElement::Plain(variable)
102    }
103
104    pub fn create_barrier(&self, level: BarrierLevel) -> ExpandElement {
105        let id = self.new_local_index();
106        // Dummy elem for now, awaiting a rework to item to include non-native conceptual types
107        let variable = Variable::new(
108            VariableKind::Barrier { id, level },
109            SemanticType::Barrier.into(),
110        );
111        ExpandElement::Plain(variable)
112    }
113
114    // Try to return a reusable mutable variable for the given `item` or `None` otherwise.
115    pub fn reuse_local_mut(&self, item: Type) -> Option<ExpandElement> {
116        // Among the candidates, take a variable if it's only referenced by the pool.
117        // Arbitrarily takes the first it finds in reversed order.
118        self.local_mut_pool.borrow().get(&item).and_then(|vars| {
119            vars.iter()
120                .rev()
121                .find(|var| matches!(var, ExpandElement::Managed(v) if Rc::strong_count(v) == 1))
122                .cloned()
123        })
124    }
125
126    /// Add a new variable to the pool with type specified by `item` for the given `scope`.
127    pub fn add_local_mut(&self, item: Type) -> Rc<Variable> {
128        let id = self.new_local_index();
129        let local = Variable::new(VariableKind::LocalMut { id }, item);
130        let var = Rc::new(local);
131        let expand = ExpandElement::Managed(var.clone());
132        let mut pool = self.local_mut_pool.borrow_mut();
133        let variables = pool.entry(item).or_default();
134        variables.push(expand);
135        var
136    }
137
138    pub fn new_local_index(&self) -> u32 {
139        self.next_id.fetch_add(1, Ordering::Release)
140    }
141
142    pub fn take_variables(&self) -> Vec<Variable> {
143        self.local_mut_pool
144            .borrow_mut()
145            .drain()
146            .flat_map(|it| it.1)
147            .map(|it| *it)
148            .collect()
149    }
150}
151
152use cubecl_macros_internal::TypeHash;
153pub use expand_element::*;
154
155mod expand_element {
156    use cubecl_common::{e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32, tf32, ue8m0};
157    use half::{bf16, f16};
158
159    use super::*;
160
161    /// Reference to a JIT variable
162    #[derive(Clone, Debug, TypeHash)]
163    pub enum ExpandElement {
164        /// Variable kept in the variable pool.
165        Managed(Rc<Variable>),
166        /// Variable not kept in the variable pool.
167        Plain(Variable),
168    }
169
170    impl core::ops::Deref for ExpandElement {
171        type Target = Variable;
172
173        fn deref(&self) -> &Self::Target {
174            match self {
175                ExpandElement::Managed(var) => var.as_ref(),
176                ExpandElement::Plain(var) => var,
177            }
178        }
179    }
180
181    impl From<ExpandElement> for Variable {
182        fn from(value: ExpandElement) -> Self {
183            match value {
184                ExpandElement::Managed(var) => *var,
185                ExpandElement::Plain(var) => var,
186            }
187        }
188    }
189
190    impl ExpandElement {
191        /// If the element can be mutated inplace, potentially reusing the register.
192        pub fn can_mut(&self) -> bool {
193            match self {
194                ExpandElement::Managed(var) => {
195                    if let VariableKind::LocalMut { .. } = var.as_ref().kind {
196                        Rc::strong_count(var) <= 2
197                    } else {
198                        false
199                    }
200                }
201                ExpandElement::Plain(_) => false,
202            }
203        }
204
205        /// Explicitly consume the element, freeing it for reuse if no other copies exist.
206        pub fn consume(self) -> Variable {
207            *self
208        }
209    }
210
211    macro_rules! impl_into_expand_element {
212        ($type:ty) => {
213            impl From<$type> for ExpandElement {
214                fn from(value: $type) -> Self {
215                    ExpandElement::Plain(Variable::from(value))
216                }
217            }
218        };
219    }
220
221    impl_into_expand_element!(u8);
222    impl_into_expand_element!(u16);
223    impl_into_expand_element!(u32);
224    impl_into_expand_element!(u64);
225    impl_into_expand_element!(usize);
226    impl_into_expand_element!(bool);
227    impl_into_expand_element!(e2m1);
228    impl_into_expand_element!(e2m1x2);
229    impl_into_expand_element!(e2m3);
230    impl_into_expand_element!(e3m2);
231    impl_into_expand_element!(e4m3);
232    impl_into_expand_element!(e5m2);
233    impl_into_expand_element!(ue8m0);
234    impl_into_expand_element!(flex32);
235    impl_into_expand_element!(f16);
236    impl_into_expand_element!(bf16);
237    impl_into_expand_element!(tf32);
238    impl_into_expand_element!(f32);
239    impl_into_expand_element!(i8);
240    impl_into_expand_element!(i16);
241    impl_into_expand_element!(i32);
242    impl_into_expand_element!(i64);
243}