cubecl_core/frontend/container/
shared_memory.rs1use core::marker::PhantomData;
2use std::ops::{Deref, DerefMut};
3
4use crate::{
5 self as cubecl,
6 prelude::{Lined, LinedExpand},
7 unexpanded,
8};
9use cubecl_ir::{LineSize, Marker, VariableKind};
10use cubecl_macros::{cube, intrinsic};
11
12use crate::{
13 frontend::{CubePrimitive, CubeType, ExpandElementTyped, IntoMut},
14 ir::{Scope, Type},
15 prelude::{
16 Line, List, ListExpand, ListMut, ListMutExpand, index, index_assign, index_unchecked,
17 },
18};
19
20pub type SharedMemoryExpand<T> = ExpandElementTyped<SharedMemory<T>>;
21pub type SharedExpand<T> = ExpandElementTyped<Shared<T>>;
22
23#[derive(Clone, Copy)]
24pub struct Shared<E: CubePrimitive> {
25 _val: PhantomData<E>,
26}
27
28#[derive(Clone, Copy)]
29pub struct SharedMemory<E: CubePrimitive> {
30 _val: PhantomData<E>,
31}
32
33impl<T: CubePrimitive> IntoMut for ExpandElementTyped<SharedMemory<T>> {
34 fn into_mut(self, _scope: &mut Scope) -> Self {
35 self
36 }
37}
38
39impl<T: CubePrimitive> CubeType for SharedMemory<T> {
40 type ExpandType = ExpandElementTyped<SharedMemory<T>>;
41}
42
43impl<T: CubePrimitive> IntoMut for ExpandElementTyped<Shared<T>> {
44 fn into_mut(self, _scope: &mut Scope) -> Self {
45 self
46 }
47}
48
49impl<T: CubePrimitive> CubeType for Shared<T> {
50 type ExpandType = ExpandElementTyped<Shared<T>>;
51}
52
53#[cube]
54impl<T: CubePrimitive + Clone> SharedMemory<T> {
55 #[allow(unused_variables)]
56 pub fn new(#[comptime] size: usize) -> Self {
57 intrinsic!(|scope| {
58 scope
59 .create_shared_array(Type::new(T::as_type(scope)), size, None)
60 .into()
61 })
62 }
63
64 #[allow(unused_variables)]
65 pub fn new_lined(
66 #[comptime] size: usize,
67 #[comptime] line_size: LineSize,
68 ) -> SharedMemory<Line<T>> {
69 intrinsic!(|scope| {
70 scope
71 .create_shared_array(Type::new(T::as_type(scope)).line(line_size), size, None)
72 .into()
73 })
74 }
75
76 #[allow(clippy::len_without_is_empty)]
77 pub fn len(&self) -> usize {
78 intrinsic!(|_| len_static(&self))
79 }
80
81 pub fn buffer_len(&self) -> usize {
82 self.len()
83 }
84}
85
86#[cube]
87impl<T: CubePrimitive> Shared<T> {
88 pub fn new() -> Self {
89 intrinsic!(|scope| {
90 let var = scope.create_shared(Type::new(T::as_type(scope)));
91 ExpandElementTyped::new(var)
92 })
93 }
94}
95
96pub trait AsRefExpand<T: CubeType> {
97 fn __expand_as_ref_method(self, scope: &mut Scope) -> T::ExpandType;
99}
100impl<T: CubePrimitive> AsRefExpand<T> for ExpandElementTyped<T> {
101 fn __expand_as_ref_method(self, _scope: &mut Scope) -> ExpandElementTyped<T> {
102 self
103 }
104}
105pub trait AsMutExpand<T: CubeType> {
106 fn __expand_as_mut_method(self, scope: &mut Scope) -> T::ExpandType;
108}
109impl<T: CubePrimitive> AsMutExpand<T> for ExpandElementTyped<T> {
110 fn __expand_as_mut_method(self, _scope: &mut Scope) -> <T as CubeType>::ExpandType {
111 self
112 }
113}
114
115impl<T: CubePrimitive> AsRef<T> for Shared<T> {
119 fn as_ref(&self) -> &T {
120 unexpanded!()
121 }
122}
123impl<T: CubePrimitive> AsRefExpand<T> for SharedExpand<T> {
124 fn __expand_as_ref_method(self, _scope: &mut Scope) -> <T as CubeType>::ExpandType {
125 self.expand.into()
126 }
127}
128
129impl<T: CubePrimitive> AsMut<T> for Shared<T> {
130 fn as_mut(&mut self) -> &mut T {
131 unexpanded!()
132 }
133}
134impl<T: CubePrimitive> AsMutExpand<T> for SharedExpand<T> {
135 fn __expand_as_mut_method(self, _scope: &mut Scope) -> <T as CubeType>::ExpandType {
136 self.expand.into()
137 }
138}
139
140impl<T: CubePrimitive> Default for Shared<T> {
141 fn default() -> Self {
142 Self::new()
143 }
144}
145impl<T: CubePrimitive> Shared<T> {
146 pub fn __expand_default(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
147 Self::__expand_new(scope)
148 }
149}
150
151#[cube]
152impl<T: CubePrimitive> Shared<Line<T>> {
153 #[allow(unused_variables)]
154 pub fn new_lined(#[comptime] line_size: LineSize) -> SharedMemory<Line<T>> {
155 intrinsic!(|scope| {
156 let var = scope.create_shared(Type::new(T::as_type(scope)).line(line_size));
157 ExpandElementTyped::new(var)
158 })
159 }
160}
161
162#[cube]
163impl<T: CubePrimitive + Clone> SharedMemory<T> {
164 #[allow(unused_variables)]
165 pub fn new_aligned(
166 #[comptime] size: usize,
167 #[comptime] line_size: LineSize,
168 #[comptime] alignment: usize,
169 ) -> SharedMemory<Line<T>> {
170 intrinsic!(|scope| {
171 let var = scope.create_shared_array(
172 Type::new(T::as_type(scope)).line(line_size),
173 size,
174 Some(alignment),
175 );
176 ExpandElementTyped::new(var)
177 })
178 }
179
180 pub unsafe fn free(self) {
186 intrinsic!(|scope| { scope.register(Marker::Free(*self.expand)) })
187 }
188}
189
190fn len_static<T: CubePrimitive>(
191 shared: &ExpandElementTyped<SharedMemory<T>>,
192) -> ExpandElementTyped<usize> {
193 let VariableKind::SharedArray { length, .. } = shared.expand.kind else {
194 unreachable!("Kind of shared memory is always shared memory")
195 };
196 length.into()
197}
198
199mod indexation {
201 use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};
202
203 use crate::ir::Instruction;
204
205 use super::*;
206
207 type SharedMemoryExpand<E> = ExpandElementTyped<SharedMemory<E>>;
208
209 #[cube]
210 impl<E: CubePrimitive> SharedMemory<E> {
211 #[allow(unused_variables)]
217 pub unsafe fn index_unchecked(&self, i: usize) -> &E {
218 intrinsic!(|scope| {
219 let out = scope.create_local(self.expand.ty);
220 scope.register(Instruction::new(
221 Operator::UncheckedIndex(IndexOperator {
222 list: *self.expand,
223 index: i.expand.consume(),
224 line_size: 0,
225 unroll_factor: 1,
226 }),
227 *out,
228 ));
229 out.into()
230 })
231 }
232
233 #[allow(unused_variables)]
239 pub unsafe fn index_assign_unchecked(&mut self, i: usize, value: E) {
240 intrinsic!(|scope| {
241 scope.register(Instruction::new(
242 Operator::UncheckedIndexAssign(IndexAssignOperator {
243 index: i.expand.consume(),
244 value: value.expand.consume(),
245 line_size: 0,
246 unroll_factor: 1,
247 }),
248 *self.expand,
249 ));
250 })
251 }
252 }
253}
254
255impl<T: CubePrimitive> List<T> for SharedMemory<T> {
256 fn __expand_read(
257 scope: &mut Scope,
258 this: ExpandElementTyped<SharedMemory<T>>,
259 idx: ExpandElementTyped<usize>,
260 ) -> ExpandElementTyped<T> {
261 index::expand(scope, this, idx)
262 }
263}
264
265impl<T: CubePrimitive> Deref for SharedMemory<T> {
266 type Target = [T];
267
268 fn deref(&self) -> &Self::Target {
269 unexpanded!()
270 }
271}
272
273impl<T: CubePrimitive> DerefMut for SharedMemory<T> {
274 fn deref_mut(&mut self) -> &mut Self::Target {
275 unexpanded!()
276 }
277}
278
279impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<SharedMemory<T>> {
280 fn __expand_read_method(
281 &self,
282 scope: &mut Scope,
283 idx: ExpandElementTyped<usize>,
284 ) -> ExpandElementTyped<T> {
285 index::expand(scope, self.clone(), idx)
286 }
287 fn __expand_read_unchecked_method(
288 &self,
289 scope: &mut Scope,
290 idx: ExpandElementTyped<usize>,
291 ) -> ExpandElementTyped<T> {
292 index_unchecked::expand(scope, self.clone(), idx)
293 }
294
295 fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
296 Self::__expand_len_method(self.clone(), scope)
297 }
298}
299
300impl<T: CubePrimitive> Lined for SharedMemory<T> {}
301impl<T: CubePrimitive> LinedExpand for ExpandElementTyped<SharedMemory<T>> {
302 fn line_size(&self) -> LineSize {
303 self.expand.ty.line_size()
304 }
305}
306
307impl<T: CubePrimitive> ListMut<T> for SharedMemory<T> {
308 fn __expand_write(
309 scope: &mut Scope,
310 this: ExpandElementTyped<SharedMemory<T>>,
311 idx: ExpandElementTyped<usize>,
312 value: ExpandElementTyped<T>,
313 ) {
314 index_assign::expand(scope, this, idx, value);
315 }
316}
317
318impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<SharedMemory<T>> {
319 fn __expand_write_method(
320 &self,
321 scope: &mut Scope,
322 idx: ExpandElementTyped<usize>,
323 value: ExpandElementTyped<T>,
324 ) {
325 index_assign::expand(scope, self.clone(), idx, value);
326 }
327}