cubecl_core/frontend/container/
shared_memory.rs1use core::marker::PhantomData;
2
3use crate::{
4 self as cubecl,
5 prelude::{Lined, LinedExpand},
6 unexpanded,
7};
8use cubecl_ir::{Marker, VariableKind};
9use cubecl_macros::{cube, intrinsic};
10
11use crate::{
12 frontend::{CubePrimitive, CubeType, ExpandElementTyped, IntoMut, indexation::Index},
13 ir::{Scope, Type},
14 prelude::{
15 Line, List, ListExpand, ListMut, ListMutExpand, index, index_assign, index_unchecked,
16 },
17};
18
19pub type SharedMemoryExpand<T> = ExpandElementTyped<SharedMemory<T>>;
20pub type SharedExpand<T> = ExpandElementTyped<Shared<T>>;
21
22#[derive(Clone, Copy)]
23pub struct Shared<E: CubePrimitive> {
24 _val: PhantomData<E>,
25}
26
27#[derive(Clone, Copy)]
28pub struct SharedMemory<E: CubePrimitive> {
29 _val: PhantomData<E>,
30}
31
32impl<T: CubePrimitive> IntoMut for ExpandElementTyped<SharedMemory<T>> {
33 fn into_mut(self, _scope: &mut Scope) -> Self {
34 self
35 }
36}
37
38impl<T: CubePrimitive> CubeType for SharedMemory<T> {
39 type ExpandType = ExpandElementTyped<SharedMemory<T>>;
40}
41
42impl<T: CubePrimitive> IntoMut for ExpandElementTyped<Shared<T>> {
43 fn into_mut(self, _scope: &mut Scope) -> Self {
44 self
45 }
46}
47
48impl<T: CubePrimitive> CubeType for Shared<T> {
49 type ExpandType = ExpandElementTyped<Shared<T>>;
50}
51
52impl<T: CubePrimitive + Clone> SharedMemory<T> {
53 pub fn new<S: Index>(_size: S) -> Self {
54 SharedMemory { _val: PhantomData }
55 }
56
57 pub fn new_lined<S: Index>(_size: S, _vectorization_factor: u32) -> SharedMemory<Line<T>> {
58 SharedMemory { _val: PhantomData }
59 }
60
61 #[allow(clippy::len_without_is_empty)]
62 pub fn len(&self) -> u32 {
63 unexpanded!()
64 }
65
66 pub fn buffer_len(&self) -> u32 {
67 unexpanded!()
68 }
69
70 pub fn __expand_new_lined(
71 scope: &mut Scope,
72 size: ExpandElementTyped<u32>,
73 line_size: u32,
74 ) -> <SharedMemory<Line<T>> as CubeType>::ExpandType {
75 let size = size
76 .constant()
77 .expect("Shared memory need constant initialization value")
78 .as_u32();
79 let var =
80 scope.create_shared_array(Type::new(T::as_type(scope)).line(line_size), size, None);
81 ExpandElementTyped::new(var)
82 }
83
84 pub fn vectorized<S: Index>(_size: S, _vectorization_factor: u32) -> Self {
85 SharedMemory { _val: PhantomData }
86 }
87
88 pub fn __expand_vectorized(
89 scope: &mut Scope,
90 size: ExpandElementTyped<u32>,
91 line_size: u32,
92 ) -> <Self as CubeType>::ExpandType {
93 let size = size
94 .constant()
95 .expect("Shared memory need constant initialization value")
96 .as_u32();
97 let var =
98 scope.create_shared_array(Type::new(T::as_type(scope)).line(line_size), size, None);
99 ExpandElementTyped::new(var)
100 }
101
102 pub fn __expand_new(
103 scope: &mut Scope,
104 size: ExpandElementTyped<u32>,
105 ) -> <Self as CubeType>::ExpandType {
106 let size = size
107 .constant()
108 .expect("Shared memory need constant initialization value")
109 .as_u32();
110 let var = scope.create_shared_array(Type::new(T::as_type(scope)), size, None);
111 ExpandElementTyped::new(var)
112 }
113
114 pub fn __expand_len(
115 scope: &mut Scope,
116 this: ExpandElementTyped<Self>,
117 ) -> ExpandElementTyped<u32> {
118 this.__expand_len_method(scope)
119 }
120
121 pub fn __expand_buffer_len(
122 scope: &mut Scope,
123 this: ExpandElementTyped<Self>,
124 ) -> ExpandElementTyped<u32> {
125 this.__expand_buffer_len_method(scope)
126 }
127}
128
129#[cube]
130impl<T: CubePrimitive> Shared<T> {
131 pub fn new() -> Self {
132 intrinsic!(|scope| {
133 let var = scope.create_shared(Type::new(T::as_type(scope)));
134 ExpandElementTyped::new(var)
135 })
136 }
137}
138
139pub trait AsRefExpand<T: CubeType> {
140 fn __expand_as_ref_method(self, scope: &mut Scope) -> T::ExpandType;
142}
143impl<T: CubePrimitive> AsRefExpand<T> for ExpandElementTyped<T> {
144 fn __expand_as_ref_method(self, _scope: &mut Scope) -> ExpandElementTyped<T> {
145 self
146 }
147}
148pub trait AsMutExpand<T: CubeType> {
149 fn __expand_as_mut_method(self, scope: &mut Scope) -> T::ExpandType;
151}
152impl<T: CubePrimitive> AsMutExpand<T> for ExpandElementTyped<T> {
153 fn __expand_as_mut_method(self, _scope: &mut Scope) -> <T as CubeType>::ExpandType {
154 self
155 }
156}
157
158impl<T: CubePrimitive> AsRef<T> for Shared<T> {
162 fn as_ref(&self) -> &T {
163 unexpanded!()
164 }
165}
166impl<T: CubePrimitive> AsRefExpand<T> for SharedExpand<T> {
167 fn __expand_as_ref_method(self, _scope: &mut Scope) -> <T as CubeType>::ExpandType {
168 self.expand.into()
169 }
170}
171
172impl<T: CubePrimitive> AsMut<T> for Shared<T> {
173 fn as_mut(&mut self) -> &mut T {
174 unexpanded!()
175 }
176}
177impl<T: CubePrimitive> AsMutExpand<T> for SharedExpand<T> {
178 fn __expand_as_mut_method(self, _scope: &mut Scope) -> <T as CubeType>::ExpandType {
179 self.expand.into()
180 }
181}
182
183impl<T: CubePrimitive> Default for Shared<T> {
184 fn default() -> Self {
185 Self::new()
186 }
187}
188impl<T: CubePrimitive> Shared<T> {
189 pub fn __expand_default(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
190 Self::__expand_new(scope)
191 }
192}
193
194#[cube]
195impl<T: CubePrimitive> Shared<Line<T>> {
196 #[allow(unused_variables)]
197 pub fn new_lined(#[comptime] line_size: u32) -> SharedMemory<Line<T>> {
198 intrinsic!(|scope| {
199 let var = scope.create_shared(Type::new(T::as_type(scope)).line(line_size));
200 ExpandElementTyped::new(var)
201 })
202 }
203}
204
205impl<T: CubePrimitive> ExpandElementTyped<SharedMemory<T>> {
206 pub fn __expand_len_method(self, _scope: &mut Scope) -> ExpandElementTyped<u32> {
207 len_static(&self)
208 }
209
210 pub fn __expand_buffer_len_method(self, scope: &mut Scope) -> ExpandElementTyped<u32> {
211 self.__expand_len_method(scope)
212 }
213}
214
215#[cube]
216impl<T: CubePrimitive + Clone> SharedMemory<T> {
217 #[allow(unused_variables)]
218 pub fn new_aligned(
219 #[comptime] size: u32,
220 #[comptime] line_size: u32,
221 #[comptime] alignment: u32,
222 ) -> SharedMemory<Line<T>> {
223 intrinsic!(|scope| {
224 let var = scope.create_shared_array(
225 Type::new(T::as_type(scope)).line(line_size),
226 size,
227 Some(alignment),
228 );
229 ExpandElementTyped::new(var)
230 })
231 }
232
233 pub unsafe fn free(self) {
239 intrinsic!(|scope| { scope.register(Marker::Free(*self.expand)) })
240 }
241}
242
243fn len_static<T: CubePrimitive>(
244 shared: &ExpandElementTyped<SharedMemory<T>>,
245) -> ExpandElementTyped<u32> {
246 let VariableKind::SharedArray { length, .. } = shared.expand.kind else {
247 unreachable!("Kind of shared memory is always shared memory")
248 };
249 length.into()
250}
251
252mod indexation {
254 use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};
255
256 use crate::ir::Instruction;
257
258 use super::*;
259
260 type SharedMemoryExpand<E> = ExpandElementTyped<SharedMemory<E>>;
261
262 #[cube]
263 impl<E: CubePrimitive> SharedMemory<E> {
264 #[allow(unused_variables)]
270 pub unsafe fn index_unchecked(&self, i: u32) -> &E {
271 intrinsic!(|scope| {
272 let out = scope.create_local(self.expand.ty);
273 scope.register(Instruction::new(
274 Operator::UncheckedIndex(IndexOperator {
275 list: *self.expand,
276 index: i.expand.consume(),
277 line_size: 0,
278 unroll_factor: 1,
279 }),
280 *out,
281 ));
282 out.into()
283 })
284 }
285
286 #[allow(unused_variables)]
292 pub unsafe fn index_assign_unchecked(&mut self, i: u32, value: E) {
293 intrinsic!(|scope| {
294 scope.register(Instruction::new(
295 Operator::UncheckedIndexAssign(IndexAssignOperator {
296 index: i.expand.consume(),
297 value: value.expand.consume(),
298 line_size: 0,
299 unroll_factor: 1,
300 }),
301 *self.expand,
302 ));
303 })
304 }
305 }
306}
307
308impl<T: CubePrimitive> List<T> for SharedMemory<T> {
309 fn __expand_read(
310 scope: &mut Scope,
311 this: ExpandElementTyped<SharedMemory<T>>,
312 idx: ExpandElementTyped<u32>,
313 ) -> ExpandElementTyped<T> {
314 index::expand(scope, this, idx)
315 }
316}
317
318impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<SharedMemory<T>> {
319 fn __expand_read_method(
320 &self,
321 scope: &mut Scope,
322 idx: ExpandElementTyped<u32>,
323 ) -> ExpandElementTyped<T> {
324 index::expand(scope, self.clone(), idx)
325 }
326 fn __expand_read_unchecked_method(
327 &self,
328 scope: &mut Scope,
329 idx: ExpandElementTyped<u32>,
330 ) -> ExpandElementTyped<T> {
331 index_unchecked::expand(scope, self.clone(), idx)
332 }
333
334 fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
335 Self::__expand_len_method(self.clone(), scope)
336 }
337}
338
339impl<T: CubePrimitive> Lined for SharedMemory<T> {}
340impl<T: CubePrimitive> LinedExpand for ExpandElementTyped<SharedMemory<T>> {
341 fn line_size(&self) -> u32 {
342 self.expand.ty.line_size()
343 }
344}
345
346impl<T: CubePrimitive> ListMut<T> for SharedMemory<T> {
347 fn __expand_write(
348 scope: &mut Scope,
349 this: ExpandElementTyped<SharedMemory<T>>,
350 idx: ExpandElementTyped<u32>,
351 value: ExpandElementTyped<T>,
352 ) {
353 index_assign::expand(scope, this, idx, value);
354 }
355}
356
357impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<SharedMemory<T>> {
358 fn __expand_write_method(
359 &self,
360 scope: &mut Scope,
361 idx: ExpandElementTyped<u32>,
362 value: ExpandElementTyped<T>,
363 ) {
364 index_assign::expand(scope, self.clone(), idx, value);
365 }
366}