cubecl_core/frontend/container/
shared_memory.rs1use core::marker::PhantomData;
2
3use crate::{
4 self as cubecl,
5 prelude::{Lined, LinedExpand},
6 unexpanded,
7};
8use cubecl_ir::{LineSize, Marker, VariableKind};
9use cubecl_macros::{cube, intrinsic};
10
11use crate::{
12 frontend::{CubePrimitive, CubeType, ExpandElementTyped, IntoMut},
13 ir::{Scope, Type},
14 prelude::{
15 Line, List, ListExpand, ListMut, ListMutExpand, index, index_assign, index_unchecked,
16 },
17};
18
19pub type SharedMemoryExpand<T> = ExpandElementTyped<SharedMemory<T>>;
20pub type SharedExpand<T> = ExpandElementTyped<Shared<T>>;
21
22#[derive(Clone, Copy)]
23pub struct Shared<E: CubePrimitive> {
24 _val: PhantomData<E>,
25}
26
27#[derive(Clone, Copy)]
28pub struct SharedMemory<E: CubePrimitive> {
29 _val: PhantomData<E>,
30}
31
32impl<T: CubePrimitive> IntoMut for ExpandElementTyped<SharedMemory<T>> {
33 fn into_mut(self, _scope: &mut Scope) -> Self {
34 self
35 }
36}
37
38impl<T: CubePrimitive> CubeType for SharedMemory<T> {
39 type ExpandType = ExpandElementTyped<SharedMemory<T>>;
40}
41
42impl<T: CubePrimitive> IntoMut for ExpandElementTyped<Shared<T>> {
43 fn into_mut(self, _scope: &mut Scope) -> Self {
44 self
45 }
46}
47
48impl<T: CubePrimitive> CubeType for Shared<T> {
49 type ExpandType = ExpandElementTyped<Shared<T>>;
50}
51
52#[cube]
53impl<T: CubePrimitive + Clone> SharedMemory<T> {
54 #[allow(unused_variables)]
55 pub fn new(#[comptime] size: usize) -> Self {
56 intrinsic!(|scope| {
57 scope
58 .create_shared_array(Type::new(T::as_type(scope)), size, None)
59 .into()
60 })
61 }
62
63 #[allow(unused_variables)]
64 pub fn new_lined(
65 #[comptime] size: usize,
66 #[comptime] line_size: LineSize,
67 ) -> SharedMemory<Line<T>> {
68 intrinsic!(|scope| {
69 scope
70 .create_shared_array(Type::new(T::as_type(scope)).line(line_size), size, None)
71 .into()
72 })
73 }
74
75 #[allow(clippy::len_without_is_empty)]
76 pub fn len(&self) -> usize {
77 intrinsic!(|_| len_static(&self))
78 }
79
80 pub fn buffer_len(&self) -> usize {
81 self.len()
82 }
83}
84
85#[cube]
86impl<T: CubePrimitive> Shared<T> {
87 pub fn new() -> Self {
88 intrinsic!(|scope| {
89 let var = scope.create_shared(Type::new(T::as_type(scope)));
90 ExpandElementTyped::new(var)
91 })
92 }
93}
94
95pub trait AsRefExpand<T: CubeType> {
96 fn __expand_as_ref_method(self, scope: &mut Scope) -> T::ExpandType;
98}
99impl<T: CubePrimitive> AsRefExpand<T> for ExpandElementTyped<T> {
100 fn __expand_as_ref_method(self, _scope: &mut Scope) -> ExpandElementTyped<T> {
101 self
102 }
103}
104pub trait AsMutExpand<T: CubeType> {
105 fn __expand_as_mut_method(self, scope: &mut Scope) -> T::ExpandType;
107}
108impl<T: CubePrimitive> AsMutExpand<T> for ExpandElementTyped<T> {
109 fn __expand_as_mut_method(self, _scope: &mut Scope) -> <T as CubeType>::ExpandType {
110 self
111 }
112}
113
114impl<T: CubePrimitive> AsRef<T> for Shared<T> {
118 fn as_ref(&self) -> &T {
119 unexpanded!()
120 }
121}
122impl<T: CubePrimitive> AsRefExpand<T> for SharedExpand<T> {
123 fn __expand_as_ref_method(self, _scope: &mut Scope) -> <T as CubeType>::ExpandType {
124 self.expand.into()
125 }
126}
127
128impl<T: CubePrimitive> AsMut<T> for Shared<T> {
129 fn as_mut(&mut self) -> &mut T {
130 unexpanded!()
131 }
132}
133impl<T: CubePrimitive> AsMutExpand<T> for SharedExpand<T> {
134 fn __expand_as_mut_method(self, _scope: &mut Scope) -> <T as CubeType>::ExpandType {
135 self.expand.into()
136 }
137}
138
139impl<T: CubePrimitive> Default for Shared<T> {
140 fn default() -> Self {
141 Self::new()
142 }
143}
144impl<T: CubePrimitive> Shared<T> {
145 pub fn __expand_default(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
146 Self::__expand_new(scope)
147 }
148}
149
150#[cube]
151impl<T: CubePrimitive> Shared<Line<T>> {
152 #[allow(unused_variables)]
153 pub fn new_lined(#[comptime] line_size: LineSize) -> SharedMemory<Line<T>> {
154 intrinsic!(|scope| {
155 let var = scope.create_shared(Type::new(T::as_type(scope)).line(line_size));
156 ExpandElementTyped::new(var)
157 })
158 }
159}
160
161#[cube]
162impl<T: CubePrimitive + Clone> SharedMemory<T> {
163 #[allow(unused_variables)]
164 pub fn new_aligned(
165 #[comptime] size: usize,
166 #[comptime] line_size: LineSize,
167 #[comptime] alignment: usize,
168 ) -> SharedMemory<Line<T>> {
169 intrinsic!(|scope| {
170 let var = scope.create_shared_array(
171 Type::new(T::as_type(scope)).line(line_size),
172 size,
173 Some(alignment),
174 );
175 ExpandElementTyped::new(var)
176 })
177 }
178
179 pub unsafe fn free(self) {
185 intrinsic!(|scope| { scope.register(Marker::Free(*self.expand)) })
186 }
187}
188
189fn len_static<T: CubePrimitive>(
190 shared: &ExpandElementTyped<SharedMemory<T>>,
191) -> ExpandElementTyped<usize> {
192 let VariableKind::SharedArray { length, .. } = shared.expand.kind else {
193 unreachable!("Kind of shared memory is always shared memory")
194 };
195 length.into()
196}
197
198mod indexation {
200 use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};
201
202 use crate::ir::Instruction;
203
204 use super::*;
205
206 type SharedMemoryExpand<E> = ExpandElementTyped<SharedMemory<E>>;
207
208 #[cube]
209 impl<E: CubePrimitive> SharedMemory<E> {
210 #[allow(unused_variables)]
216 pub unsafe fn index_unchecked(&self, i: usize) -> &E {
217 intrinsic!(|scope| {
218 let out = scope.create_local(self.expand.ty);
219 scope.register(Instruction::new(
220 Operator::UncheckedIndex(IndexOperator {
221 list: *self.expand,
222 index: i.expand.consume(),
223 line_size: 0,
224 unroll_factor: 1,
225 }),
226 *out,
227 ));
228 out.into()
229 })
230 }
231
232 #[allow(unused_variables)]
238 pub unsafe fn index_assign_unchecked(&mut self, i: usize, value: E) {
239 intrinsic!(|scope| {
240 scope.register(Instruction::new(
241 Operator::UncheckedIndexAssign(IndexAssignOperator {
242 index: i.expand.consume(),
243 value: value.expand.consume(),
244 line_size: 0,
245 unroll_factor: 1,
246 }),
247 *self.expand,
248 ));
249 })
250 }
251 }
252}
253
254impl<T: CubePrimitive> List<T> for SharedMemory<T> {
255 fn __expand_read(
256 scope: &mut Scope,
257 this: ExpandElementTyped<SharedMemory<T>>,
258 idx: ExpandElementTyped<usize>,
259 ) -> ExpandElementTyped<T> {
260 index::expand(scope, this, idx)
261 }
262}
263
264impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<SharedMemory<T>> {
265 fn __expand_read_method(
266 &self,
267 scope: &mut Scope,
268 idx: ExpandElementTyped<usize>,
269 ) -> ExpandElementTyped<T> {
270 index::expand(scope, self.clone(), idx)
271 }
272 fn __expand_read_unchecked_method(
273 &self,
274 scope: &mut Scope,
275 idx: ExpandElementTyped<usize>,
276 ) -> ExpandElementTyped<T> {
277 index_unchecked::expand(scope, self.clone(), idx)
278 }
279
280 fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
281 Self::__expand_len_method(self.clone(), scope)
282 }
283}
284
285impl<T: CubePrimitive> Lined for SharedMemory<T> {}
286impl<T: CubePrimitive> LinedExpand for ExpandElementTyped<SharedMemory<T>> {
287 fn line_size(&self) -> LineSize {
288 self.expand.ty.line_size()
289 }
290}
291
292impl<T: CubePrimitive> ListMut<T> for SharedMemory<T> {
293 fn __expand_write(
294 scope: &mut Scope,
295 this: ExpandElementTyped<SharedMemory<T>>,
296 idx: ExpandElementTyped<usize>,
297 value: ExpandElementTyped<T>,
298 ) {
299 index_assign::expand(scope, this, idx, value);
300 }
301}
302
303impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<SharedMemory<T>> {
304 fn __expand_write_method(
305 &self,
306 scope: &mut Scope,
307 idx: ExpandElementTyped<usize>,
308 value: ExpandElementTyped<T>,
309 ) {
310 index_assign::expand(scope, self.clone(), idx, value);
311 }
312}