cubecl_core/frontend/container/
shared_memory.rs1use core::marker::PhantomData;
2use core::ops::{Deref, DerefMut};
3
4use crate::{
5 self as cubecl,
6 prelude::{Vectorized, VectorizedExpand},
7 unexpanded,
8};
9use cubecl_ir::{Marker, VariableKind, VectorSize};
10use cubecl_macros::{cube, intrinsic};
11
12use crate::{
13 frontend::{CubePrimitive, CubeType, IntoMut, NativeExpand},
14 ir::Scope,
15 prelude::*,
16};
17
/// Expand type of a [`SharedMemory`] array (see its `CubeType` impl).
pub type SharedMemoryExpand<T> = NativeExpand<SharedMemory<T>>;
/// Expand type of a single [`Shared`] value (see its `CubeType` impl).
pub type SharedExpand<T> = NativeExpand<Shared<T>>;
20
21#[derive(Clone, Copy)]
22pub struct Shared<E: CubePrimitive> {
23 _val: PhantomData<E>,
24}
25
26#[derive(Clone, Copy)]
27pub struct SharedMemory<E: CubePrimitive> {
28 _val: PhantomData<E>,
29}
30
31impl<T: CubePrimitive> IntoMut for NativeExpand<SharedMemory<T>> {
32 fn into_mut(self, _scope: &mut Scope) -> Self {
33 self
34 }
35}
36
37impl<T: CubePrimitive> CubeType for SharedMemory<T> {
38 type ExpandType = NativeExpand<SharedMemory<T>>;
39}
40
41impl<T: CubePrimitive> IntoMut for NativeExpand<Shared<T>> {
42 fn into_mut(self, _scope: &mut Scope) -> Self {
43 self
44 }
45}
46
47impl<T: CubePrimitive> CubeType for Shared<T> {
48 type ExpandType = NativeExpand<Shared<T>>;
49}
50
51#[cube]
52impl<T: CubePrimitive + Clone> SharedMemory<T> {
53 #[allow(unused_variables)]
54 pub fn new(#[comptime] size: usize) -> Self {
55 intrinsic!(|scope| {
56 scope
57 .create_shared_array(T::as_type(scope), size, None)
58 .into()
59 })
60 }
61
62 #[allow(clippy::len_without_is_empty)]
63 pub fn len(&self) -> usize {
64 intrinsic!(|_| len_static(&self))
65 }
66
67 pub fn buffer_len(&self) -> usize {
68 self.len()
69 }
70}
71
72#[cube]
73impl<T: CubePrimitive> Shared<T> {
74 pub fn new() -> Self {
75 intrinsic!(|scope| {
76 let var = scope.create_shared(T::as_type(scope));
77 NativeExpand::new(var)
78 })
79 }
80}
81
82pub trait AsRefExpand<T: CubeType> {
83 fn __expand_as_ref_method(self, scope: &mut Scope) -> T::ExpandType;
85}
86impl<T: CubePrimitive> AsRefExpand<T> for NativeExpand<T> {
87 fn __expand_as_ref_method(self, _scope: &mut Scope) -> NativeExpand<T> {
88 self
89 }
90}
91pub trait AsMutExpand<T: CubeType> {
92 fn __expand_as_mut_method(self, scope: &mut Scope) -> T::ExpandType;
94}
95impl<T: CubePrimitive> AsMutExpand<T> for NativeExpand<T> {
96 fn __expand_as_mut_method(self, _scope: &mut Scope) -> <T as CubeType>::ExpandType {
97 self
98 }
99}
100
101impl<T: CubePrimitive> AsRef<T> for Shared<T> {
105 fn as_ref(&self) -> &T {
106 unexpanded!()
107 }
108}
109impl<T: CubePrimitive> AsRefExpand<T> for SharedExpand<T> {
110 fn __expand_as_ref_method(self, _scope: &mut Scope) -> <T as CubeType>::ExpandType {
111 self.expand.into()
112 }
113}
114
115impl<T: CubePrimitive> AsMut<T> for Shared<T> {
116 fn as_mut(&mut self) -> &mut T {
117 unexpanded!()
118 }
119}
120impl<T: CubePrimitive> AsMutExpand<T> for SharedExpand<T> {
121 fn __expand_as_mut_method(self, _scope: &mut Scope) -> <T as CubeType>::ExpandType {
122 self.expand.into()
123 }
124}
125
126impl<T: CubePrimitive> Default for Shared<T> {
127 fn default() -> Self {
128 Self::new()
129 }
130}
131impl<T: CubePrimitive> Shared<T> {
132 pub fn __expand_default(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
133 Self::__expand_new(scope)
134 }
135}
136
137#[cube]
138impl<T: CubePrimitive + Clone> SharedMemory<T> {
139 #[allow(unused_variables)]
140 pub fn new_aligned(#[comptime] size: usize, #[comptime] alignment: usize) -> SharedMemory<T> {
141 intrinsic!(|scope| {
142 let var = scope.create_shared_array(T::as_type(scope), size, Some(alignment));
143 NativeExpand::new(var)
144 })
145 }
146
147 pub unsafe fn free(self) {
153 intrinsic!(|scope| { scope.register(Marker::Free(*self.expand)) })
154 }
155}
156
157fn len_static<T: CubePrimitive>(shared: &NativeExpand<SharedMemory<T>>) -> NativeExpand<usize> {
158 let VariableKind::SharedArray { length, .. } = shared.expand.kind else {
159 unreachable!("Kind of shared memory is always shared memory")
160 };
161 length.into()
162}
163
164mod indexation {
166 use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};
167
168 use crate::ir::Instruction;
169
170 use super::*;
171
172 type SharedMemoryExpand<E> = NativeExpand<SharedMemory<E>>;
173
174 #[cube]
175 impl<E: CubePrimitive> SharedMemory<E> {
176 #[allow(unused_variables)]
182 pub unsafe fn index_unchecked(&self, i: usize) -> &E {
183 intrinsic!(|scope| {
184 let out = scope.create_local(self.expand.ty);
185 scope.register(Instruction::new(
186 Operator::UncheckedIndex(IndexOperator {
187 list: *self.expand,
188 index: i.expand.consume(),
189 vector_size: 0,
190 unroll_factor: 1,
191 }),
192 *out,
193 ));
194 out.into()
195 })
196 }
197
198 #[allow(unused_variables)]
204 pub unsafe fn index_assign_unchecked(&mut self, i: usize, value: E) {
205 intrinsic!(|scope| {
206 scope.register(Instruction::new(
207 Operator::UncheckedIndexAssign(IndexAssignOperator {
208 index: i.expand.consume(),
209 value: value.expand.consume(),
210 vector_size: 0,
211 unroll_factor: 1,
212 }),
213 *self.expand,
214 ));
215 })
216 }
217 }
218}
219
220impl<T: CubePrimitive> List<T> for SharedMemory<T> {
221 fn __expand_read(
222 scope: &mut Scope,
223 this: NativeExpand<SharedMemory<T>>,
224 idx: NativeExpand<usize>,
225 ) -> NativeExpand<T> {
226 index::expand(scope, this, idx)
227 }
228}
229
230impl<T: CubePrimitive> Deref for SharedMemory<T> {
231 type Target = [T];
232
233 fn deref(&self) -> &Self::Target {
234 unexpanded!()
235 }
236}
237
238impl<T: CubePrimitive> DerefMut for SharedMemory<T> {
239 fn deref_mut(&mut self) -> &mut Self::Target {
240 unexpanded!()
241 }
242}
243
244impl<T: CubePrimitive> ListExpand<T> for NativeExpand<SharedMemory<T>> {
245 fn __expand_read_method(&self, scope: &mut Scope, idx: NativeExpand<usize>) -> NativeExpand<T> {
246 index::expand(scope, self.clone(), idx)
247 }
248 fn __expand_read_unchecked_method(
249 &self,
250 scope: &mut Scope,
251 idx: NativeExpand<usize>,
252 ) -> NativeExpand<T> {
253 index_unchecked::expand(scope, self.clone(), idx)
254 }
255
256 fn __expand_len_method(&self, scope: &mut Scope) -> NativeExpand<usize> {
257 Self::__expand_len_method(self.clone(), scope)
258 }
259}
260
261impl<T: CubePrimitive> Vectorized for SharedMemory<T> {}
262impl<T: CubePrimitive> VectorizedExpand for NativeExpand<SharedMemory<T>> {
263 fn vector_size(&self) -> VectorSize {
264 self.expand.ty.vector_size()
265 }
266}
267
268impl<T: CubePrimitive> ListMut<T> for SharedMemory<T> {
269 fn __expand_write(
270 scope: &mut Scope,
271 this: NativeExpand<SharedMemory<T>>,
272 idx: NativeExpand<usize>,
273 value: NativeExpand<T>,
274 ) {
275 index_assign::expand(scope, this, idx, value);
276 }
277}
278
279impl<T: CubePrimitive> ListMutExpand<T> for NativeExpand<SharedMemory<T>> {
280 fn __expand_write_method(
281 &self,
282 scope: &mut Scope,
283 idx: NativeExpand<usize>,
284 value: NativeExpand<T>,
285 ) {
286 index_assign::expand(scope, self.clone(), idx, value);
287 }
288}