1use alloc::{rc::Rc, vec::Vec};
2use core::cell::RefCell;
3
4use hashbrown::HashMap;
5use portable_atomic::{AtomicU32, Ordering};
6
7use crate::BarrierLevel;
8
9use super::{Item, Matrix, Variable, VariableKind};
10
11#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
30#[derive(Clone, Debug, Default, TypeHash)]
31pub struct Allocator {
32 #[cfg_attr(feature = "serde", serde(skip))]
33 local_mut_pool: Rc<RefCell<HashMap<Item, Vec<ExpandElement>>>>,
34 next_id: Rc<AtomicU32>,
35}
36
37impl PartialEq for Allocator {
38 fn eq(&self, other: &Self) -> bool {
39 Rc::ptr_eq(&self.local_mut_pool, &other.local_mut_pool)
40 && Rc::ptr_eq(&self.next_id, &other.next_id)
41 }
42}
43impl Eq for Allocator {}
44
45impl Allocator {
46 pub fn create_local(&self, item: Item) -> ExpandElement {
48 let id = self.new_local_index();
49 let local = VariableKind::LocalConst { id };
50 ExpandElement::Plain(Variable::new(local, item))
51 }
52
53 pub fn create_local_mut(&self, item: Item) -> ExpandElement {
57 if item.elem.is_atomic() {
58 self.create_local_restricted(item)
59 } else {
60 self.reuse_local_mut(item)
61 .unwrap_or_else(|| ExpandElement::Managed(self.add_local_mut(item)))
62 }
63 }
64
65 pub fn create_local_restricted(&self, item: Item) -> ExpandElement {
67 let id = self.new_local_index();
68 let local = VariableKind::LocalMut { id };
69 ExpandElement::Plain(Variable::new(local, item))
70 }
71
72 pub fn create_local_array(&self, item: Item, array_size: u32) -> ExpandElement {
73 let id = self.new_local_index();
74 let local_array = Variable::new(
75 VariableKind::LocalArray {
76 id,
77 length: array_size,
78 },
79 item,
80 );
81 ExpandElement::Plain(local_array)
82 }
83
84 pub fn create_slice(&self, item: Item) -> ExpandElement {
86 let id = self.new_local_index();
87 let variable = Variable::new(VariableKind::Slice { id }, item);
88 ExpandElement::Plain(variable)
89 }
90
91 pub fn create_matrix(&self, matrix: Matrix) -> ExpandElement {
93 let id = self.new_local_index();
94 let variable = Variable::new(
95 VariableKind::Matrix { id, mat: matrix },
96 Item::new(matrix.elem),
97 );
98 ExpandElement::Plain(variable)
99 }
100
101 pub fn create_pipeline(&self, item: Item, num_stages: u8) -> ExpandElement {
102 let id = self.new_local_index();
103 let variable = Variable::new(
104 VariableKind::Pipeline {
105 id,
106 item,
107 num_stages,
108 },
109 item,
110 );
111 ExpandElement::Plain(variable)
112 }
113
114 pub fn create_barrier(&self, item: Item, level: BarrierLevel) -> ExpandElement {
115 let id = self.new_local_index();
116 let variable = Variable::new(VariableKind::Barrier { id, item, level }, item);
117 ExpandElement::Plain(variable)
118 }
119
120 pub fn reuse_local_mut(&self, item: Item) -> Option<ExpandElement> {
122 self.local_mut_pool.borrow().get(&item).and_then(|vars| {
125 vars.iter()
126 .rev()
127 .find(|var| matches!(var, ExpandElement::Managed(v) if Rc::strong_count(v) == 1))
128 .cloned()
129 })
130 }
131
132 pub fn add_local_mut(&self, item: Item) -> Rc<Variable> {
134 let id = self.new_local_index();
135 let local = Variable::new(VariableKind::LocalMut { id }, item);
136 let var = Rc::new(local);
137 let expand = ExpandElement::Managed(var.clone());
138 let mut pool = self.local_mut_pool.borrow_mut();
139 let variables = pool.entry(item).or_default();
140 variables.push(expand);
141 var
142 }
143
144 pub fn new_local_index(&self) -> u32 {
145 self.next_id.fetch_add(1, Ordering::Release)
146 }
147
148 pub fn take_variables(&self) -> Vec<Variable> {
149 self.local_mut_pool
150 .borrow_mut()
151 .drain()
152 .flat_map(|it| it.1)
153 .map(|it| *it)
154 .collect()
155 }
156}
157
158use cubecl_macros_internal::TypeHash;
159pub use expand_element::*;
160
161mod expand_element {
162 use cubecl_common::{flex32, tf32};
163 use half::{bf16, f16};
164
165 use super::*;
166
167 #[derive(Clone, Debug, TypeHash)]
169 pub enum ExpandElement {
170 Managed(Rc<Variable>),
172 Plain(Variable),
174 }
175
176 impl core::ops::Deref for ExpandElement {
177 type Target = Variable;
178
179 fn deref(&self) -> &Self::Target {
180 match self {
181 ExpandElement::Managed(var) => var.as_ref(),
182 ExpandElement::Plain(var) => var,
183 }
184 }
185 }
186
187 impl From<ExpandElement> for Variable {
188 fn from(value: ExpandElement) -> Self {
189 match value {
190 ExpandElement::Managed(var) => *var,
191 ExpandElement::Plain(var) => var,
192 }
193 }
194 }
195
196 impl ExpandElement {
197 pub fn can_mut(&self) -> bool {
199 match self {
200 ExpandElement::Managed(var) => {
201 if let VariableKind::LocalMut { .. } = var.as_ref().kind {
202 Rc::strong_count(var) <= 2
203 } else {
204 false
205 }
206 }
207 ExpandElement::Plain(_) => false,
208 }
209 }
210
211 pub fn consume(self) -> Variable {
213 *self
214 }
215 }
216
/// Implements `From<$source>` for `ExpandElement` by converting the value
/// into a `Variable` and wrapping it in `ExpandElement::Plain`.
///
/// `Variable` must already implement `From<$source>`.
macro_rules! impl_into_expand_element {
    ($source:ty) => {
        impl From<$source> for ExpandElement {
            fn from(value: $source) -> Self {
                Self::Plain(Variable::from(value))
            }
        }
    };
}
226
227 impl_into_expand_element!(u8);
228 impl_into_expand_element!(u16);
229 impl_into_expand_element!(u32);
230 impl_into_expand_element!(u64);
231 impl_into_expand_element!(usize);
232 impl_into_expand_element!(bool);
233 impl_into_expand_element!(flex32);
234 impl_into_expand_element!(f16);
235 impl_into_expand_element!(bf16);
236 impl_into_expand_element!(tf32);
237 impl_into_expand_element!(f32);
238 impl_into_expand_element!(i8);
239 impl_into_expand_element!(i16);
240 impl_into_expand_element!(i32);
241 impl_into_expand_element!(i64);
242}