1use alloc::{rc::Rc, vec::Vec};
2use core::cell::RefCell;
3
4use hashbrown::HashMap;
5use portable_atomic::{AtomicU32, Ordering};
6
7use crate::SemanticType;
8
9use super::{Matrix, Type, Variable, VariableKind};
10
/// Hands out ids for local variables and keeps a pool of managed mutable locals so they can be
/// reused once nothing else references them.
///
/// Both fields live behind `Rc`, so every clone produced by the derived `Clone` shares the same id
/// counter and the same reuse pool. `Rc` also makes this type `!Send`/`!Sync`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Debug, Default, TypeHash)]
pub struct Allocator {
    /// Managed mutable locals grouped by their `Type`; candidates for `reuse_local_mut`.
    /// Skipped by serde, so a deserialized allocator starts from the field's `Default` (an empty
    /// pool).
    #[cfg_attr(feature = "serde", serde(skip))]
    local_mut_pool: Rc<RefCell<HashMap<Type, Vec<ExpandElement>>>>,
    /// Next id to hand out; incremented by `new_local_index`.
    next_id: Rc<AtomicU32>,
}
36
37impl PartialEq for Allocator {
38 fn eq(&self, other: &Self) -> bool {
39 Rc::ptr_eq(&self.local_mut_pool, &other.local_mut_pool)
40 && Rc::ptr_eq(&self.next_id, &other.next_id)
41 }
42}
43impl Eq for Allocator {}
44
45impl Allocator {
46 pub fn create_local(&self, item: Type) -> ExpandElement {
48 let id = self.new_local_index();
49 let local = VariableKind::LocalConst { id };
50 ExpandElement::Plain(Variable::new(local, item))
51 }
52
53 pub fn create_local_mut(&self, item: Type) -> ExpandElement {
57 if item.is_atomic() {
58 self.create_local_restricted(item)
59 } else {
60 self.reuse_local_mut(item)
61 .unwrap_or_else(|| ExpandElement::Managed(self.add_local_mut(item)))
62 }
63 }
64
65 pub fn create_local_restricted(&self, item: Type) -> ExpandElement {
67 let id = self.new_local_index();
68 let local = VariableKind::LocalMut { id };
69 ExpandElement::Plain(Variable::new(local, item))
70 }
71
72 pub fn create_local_array(&self, item: Type, array_size: u32) -> ExpandElement {
73 let id = self.new_local_index();
74 let local_array = Variable::new(
75 VariableKind::LocalArray {
76 id,
77 length: array_size,
78 unroll_factor: 1,
79 },
80 item,
81 );
82 ExpandElement::Plain(local_array)
83 }
84
85 pub fn create_matrix(&self, matrix: Matrix) -> ExpandElement {
87 let id = self.new_local_index();
88 let variable = Variable::new(
89 VariableKind::Matrix { id, mat: matrix },
90 Type::new(matrix.storage),
91 );
92 ExpandElement::Plain(variable)
93 }
94
95 pub fn create_pipeline(&self, num_stages: u8) -> ExpandElement {
96 let id = self.new_local_index();
97 let variable = Variable::new(
98 VariableKind::Pipeline { id, num_stages },
99 SemanticType::Pipeline.into(),
100 );
101 ExpandElement::Plain(variable)
102 }
103
104 pub fn reuse_local_mut(&self, item: Type) -> Option<ExpandElement> {
106 self.local_mut_pool.borrow().get(&item).and_then(|vars| {
109 vars.iter()
110 .rev()
111 .find(|var| matches!(var, ExpandElement::Managed(v) if Rc::strong_count(v) == 1))
112 .cloned()
113 })
114 }
115
116 pub fn add_local_mut(&self, item: Type) -> Rc<Variable> {
118 let id = self.new_local_index();
119 let local = Variable::new(VariableKind::LocalMut { id }, item);
120 let var = Rc::new(local);
121 let expand = ExpandElement::Managed(var.clone());
122 let mut pool = self.local_mut_pool.borrow_mut();
123 let variables = pool.entry(item).or_default();
124 variables.push(expand);
125 var
126 }
127
128 pub fn new_local_index(&self) -> u32 {
129 self.next_id.fetch_add(1, Ordering::Release)
130 }
131
132 pub fn take_variables(&self) -> Vec<Variable> {
133 self.local_mut_pool
134 .borrow_mut()
135 .drain()
136 .flat_map(|it| it.1)
137 .map(|it| *it)
138 .collect()
139 }
140}
141
142use cubecl_macros_internal::TypeHash;
143pub use expand_element::*;
144
145mod expand_element {
146 use cubecl_common::{e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32, tf32, ue8m0};
147 use half::{bf16, f16};
148
149 use super::*;
150
151 #[derive(Clone, Debug, TypeHash)]
153 pub enum ExpandElement {
154 Managed(Rc<Variable>),
156 Plain(Variable),
158 }
159
160 impl core::ops::Deref for ExpandElement {
161 type Target = Variable;
162
163 fn deref(&self) -> &Self::Target {
164 match self {
165 ExpandElement::Managed(var) => var.as_ref(),
166 ExpandElement::Plain(var) => var,
167 }
168 }
169 }
170
171 impl From<ExpandElement> for Variable {
172 fn from(value: ExpandElement) -> Self {
173 match value {
174 ExpandElement::Managed(var) => *var,
175 ExpandElement::Plain(var) => var,
176 }
177 }
178 }
179
180 impl ExpandElement {
181 pub fn can_mut(&self) -> bool {
183 match self {
184 ExpandElement::Managed(var) => {
185 if let VariableKind::LocalMut { .. } = var.as_ref().kind {
186 Rc::strong_count(var) <= 2
187 } else {
188 false
189 }
190 }
191 ExpandElement::Plain(_) => false,
192 }
193 }
194
195 pub fn consume(self) -> Variable {
197 *self
198 }
199 }
200
201 macro_rules! impl_into_expand_element {
202 ($type:ty) => {
203 impl From<$type> for ExpandElement {
204 fn from(value: $type) -> Self {
205 ExpandElement::Plain(Variable::from(value))
206 }
207 }
208 };
209 }
210
211 impl_into_expand_element!(u8);
212 impl_into_expand_element!(u16);
213 impl_into_expand_element!(u32);
214 impl_into_expand_element!(u64);
215 impl_into_expand_element!(usize);
216 impl_into_expand_element!(bool);
217 impl_into_expand_element!(e2m1);
218 impl_into_expand_element!(e2m1x2);
219 impl_into_expand_element!(e2m3);
220 impl_into_expand_element!(e3m2);
221 impl_into_expand_element!(e4m3);
222 impl_into_expand_element!(e5m2);
223 impl_into_expand_element!(ue8m0);
224 impl_into_expand_element!(flex32);
225 impl_into_expand_element!(f16);
226 impl_into_expand_element!(bf16);
227 impl_into_expand_element!(tf32);
228 impl_into_expand_element!(f32);
229 impl_into_expand_element!(i8);
230 impl_into_expand_element!(i16);
231 impl_into_expand_element!(i32);
232 impl_into_expand_element!(i64);
233}