cubecl_core/frontend/container/
shared_memory.rs1use std::marker::PhantomData;
2
3use crate::{
4 self as cubecl,
5 prelude::{Lined, LinedExpand},
6 unexpanded,
7};
8use cubecl_ir::{Instruction, Operation, VariableKind};
9use cubecl_macros::{cube, intrinsic};
10
11use crate::{
12 frontend::{CubePrimitive, CubeType, ExpandElementTyped, IntoMut, indexation::Index},
13 ir::{Scope, Type},
14 prelude::{
15 Line, List, ListExpand, ListMut, ListMutExpand, index, index_assign, index_unchecked,
16 },
17};
18
19type SharedMemoryExpand<T> = ExpandElementTyped<SharedMemory<T>>;
20
21#[derive(Clone, Copy)]
22pub struct SharedMemory<T: CubeType> {
23 _val: PhantomData<T>,
24}
25
26impl<T: CubePrimitive> IntoMut for ExpandElementTyped<SharedMemory<T>> {
27 fn into_mut(self, _scope: &mut Scope) -> Self {
28 self
29 }
30}
31
32impl<T: CubePrimitive> CubeType for SharedMemory<T> {
33 type ExpandType = ExpandElementTyped<SharedMemory<T>>;
34}
35
36impl<T: CubePrimitive + Clone> SharedMemory<T> {
37 pub fn new<S: Index>(_size: S) -> Self {
38 SharedMemory { _val: PhantomData }
39 }
40
41 pub fn new_lined<S: Index>(_size: S, _vectorization_factor: u32) -> SharedMemory<Line<T>> {
42 SharedMemory { _val: PhantomData }
43 }
44
45 #[allow(clippy::len_without_is_empty)]
46 pub fn len(&self) -> u32 {
47 unexpanded!()
48 }
49
50 pub fn buffer_len(&self) -> u32 {
51 unexpanded!()
52 }
53
54 pub fn __expand_new_lined(
55 scope: &mut Scope,
56 size: ExpandElementTyped<u32>,
57 line_size: u32,
58 ) -> <SharedMemory<Line<T>> as CubeType>::ExpandType {
59 let size = size
60 .constant()
61 .expect("Shared memory need constant initialization value")
62 .as_u32();
63 let var = scope.create_shared(Type::new(T::as_type(scope)).line(line_size), size, None);
64 ExpandElementTyped::new(var)
65 }
66
67 pub fn vectorized<S: Index>(_size: S, _vectorization_factor: u32) -> Self {
68 SharedMemory { _val: PhantomData }
69 }
70
71 pub fn __expand_vectorized(
72 scope: &mut Scope,
73 size: ExpandElementTyped<u32>,
74 line_size: u32,
75 ) -> <Self as CubeType>::ExpandType {
76 let size = size
77 .constant()
78 .expect("Shared memory need constant initialization value")
79 .as_u32();
80 let var = scope.create_shared(Type::new(T::as_type(scope)).line(line_size), size, None);
81 ExpandElementTyped::new(var)
82 }
83
84 pub fn __expand_new(
85 scope: &mut Scope,
86 size: ExpandElementTyped<u32>,
87 ) -> <Self as CubeType>::ExpandType {
88 let size = size
89 .constant()
90 .expect("Shared memory need constant initialization value")
91 .as_u32();
92 let var = scope.create_shared(Type::new(T::as_type(scope)), size, None);
93 ExpandElementTyped::new(var)
94 }
95
96 pub fn __expand_len(
97 scope: &mut Scope,
98 this: ExpandElementTyped<Self>,
99 ) -> ExpandElementTyped<u32> {
100 this.__expand_len_method(scope)
101 }
102
103 pub fn __expand_buffer_len(
104 scope: &mut Scope,
105 this: ExpandElementTyped<Self>,
106 ) -> ExpandElementTyped<u32> {
107 this.__expand_buffer_len_method(scope)
108 }
109}
110
111impl<T: CubePrimitive> ExpandElementTyped<SharedMemory<T>> {
112 pub fn __expand_len_method(self, _scope: &mut Scope) -> ExpandElementTyped<u32> {
113 len_static(&self)
114 }
115
116 pub fn __expand_buffer_len_method(self, scope: &mut Scope) -> ExpandElementTyped<u32> {
117 self.__expand_len_method(scope)
118 }
119}
120
#[cube]
impl<T: CubePrimitive + Clone> SharedMemory<T> {
    /// Declares a lined shared-memory buffer with an explicit alignment.
    ///
    /// Allocates `size` lines of `line_size` elements of `T`. The `alignment` value is
    /// forwarded as `Some(alignment)` to `Scope::create_shared`. All three arguments
    /// are comptime.
    // NOTE(review): the unit of `alignment` (bytes vs. elements) is defined by
    // `create_shared` / the backend, not here — confirm before relying on it.
    // Presumably `unused_variables` is allowed because the arguments are only
    // consumed inside `intrinsic!`.
    #[allow(unused_variables)]
    pub fn new_aligned(
        #[comptime] size: u32,
        #[comptime] line_size: u32,
        #[comptime] alignment: u32,
    ) -> SharedMemory<Line<T>> {
        intrinsic!(|scope| {
            let var = scope.create_shared(
                Type::new(T::as_type(scope)).line(line_size),
                size,
                Some(alignment),
            );
            ExpandElementTyped::new(var)
        })
    }

    /// Registers an `Operation::Free` instruction (with no output) for this buffer.
    ///
    /// # Safety
    /// NOTE(review): presumably the buffer must not be accessed after this call.
    /// Confirm the exact contract against the backends that lower `Operation::Free`.
    pub unsafe fn free(self) {
        intrinsic!(|scope| { scope.register(Instruction::no_out(Operation::Free(*self.expand))) })
    }
}
148
149fn len_static<T: CubePrimitive>(
150 shared: &ExpandElementTyped<SharedMemory<T>>,
151) -> ExpandElementTyped<u32> {
152 let VariableKind::SharedMemory { length, .. } = shared.expand.kind else {
153 unreachable!("Kind of shared memory is always shared memory")
154 };
155 length.into()
156}
157
/// Unchecked element access (`index_unchecked` / `index_assign_unchecked`) for
/// [`SharedMemory`].
mod indexation {
    use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};

    use crate::ir::Instruction;

    use super::*;

    // Local alias for the expanded shared-memory handle.
    type SharedMemoryExpand<E> = ExpandElementTyped<SharedMemory<E>>;

    #[cube]
    impl<E: CubePrimitive> SharedMemory<E> {
        /// Reads the element at position `i` without a bounds check.
        ///
        /// Expands to an `Operator::UncheckedIndex` on this buffer. The result is
        /// written into a fresh local that has the buffer variable's type.
        ///
        /// # Safety
        /// The caller must ensure `i` is within the buffer's length; no bounds check
        /// is emitted. NOTE(review): the effect of an out-of-range index is
        /// backend-defined — presumably undefined behavior; confirm.
        #[allow(unused_variables)]
        pub unsafe fn index_unchecked(&self, i: u32) -> &E {
            intrinsic!(|scope| {
                let out = scope.create_local(self.expand.ty);
                scope.register(Instruction::new(
                    Operator::UncheckedIndex(IndexOperator {
                        list: *self.expand,
                        index: i.expand.consume(),
                        // NOTE(review): 0 presumably means "use the variable's own
                        // line size"; verify against the IR lowering.
                        line_size: 0,
                        unroll_factor: 1,
                    }),
                    *out,
                ));
                out.into()
            })
        }

        /// Writes `value` at position `i` without a bounds check.
        ///
        /// Expands to an `Operator::UncheckedIndexAssign` whose instruction output is
        /// the shared-memory variable itself.
        ///
        /// # Safety
        /// The caller must ensure `i` is within the buffer's length; no bounds check
        /// is emitted. NOTE(review): the effect of an out-of-range index is
        /// backend-defined — presumably undefined behavior; confirm.
        #[allow(unused_variables)]
        pub unsafe fn index_assign_unchecked(&mut self, i: u32, value: E) {
            intrinsic!(|scope| {
                scope.register(Instruction::new(
                    Operator::UncheckedIndexAssign(IndexAssignOperator {
                        index: i.expand.consume(),
                        value: value.expand.consume(),
                        // NOTE(review): same `line_size: 0` convention as the
                        // unchecked read above; verify against the IR lowering.
                        line_size: 0,
                        unroll_factor: 1,
                    }),
                    *self.expand,
                ));
            })
        }
    }
}
213
214impl<T: CubePrimitive> List<T> for SharedMemory<T> {
215 fn __expand_read(
216 scope: &mut Scope,
217 this: ExpandElementTyped<SharedMemory<T>>,
218 idx: ExpandElementTyped<u32>,
219 ) -> ExpandElementTyped<T> {
220 index::expand(scope, this, idx)
221 }
222}
223
224impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<SharedMemory<T>> {
225 fn __expand_read_method(
226 &self,
227 scope: &mut Scope,
228 idx: ExpandElementTyped<u32>,
229 ) -> ExpandElementTyped<T> {
230 index::expand(scope, self.clone(), idx)
231 }
232 fn __expand_read_unchecked_method(
233 &self,
234 scope: &mut Scope,
235 idx: ExpandElementTyped<u32>,
236 ) -> ExpandElementTyped<T> {
237 index_unchecked::expand(scope, self.clone(), idx)
238 }
239
240 fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
241 Self::__expand_len_method(self.clone(), scope)
242 }
243}
244
245impl<T: CubePrimitive> Lined for SharedMemory<T> {}
246impl<T: CubePrimitive> LinedExpand for ExpandElementTyped<SharedMemory<T>> {
247 fn line_size(&self) -> u32 {
248 self.expand.ty.line_size()
249 }
250}
251
252impl<T: CubePrimitive> ListMut<T> for SharedMemory<T> {
253 fn __expand_write(
254 scope: &mut Scope,
255 this: ExpandElementTyped<SharedMemory<T>>,
256 idx: ExpandElementTyped<u32>,
257 value: ExpandElementTyped<T>,
258 ) {
259 index_assign::expand(scope, this, idx, value);
260 }
261}
262
263impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<SharedMemory<T>> {
264 fn __expand_write_method(
265 &self,
266 scope: &mut Scope,
267 idx: ExpandElementTyped<u32>,
268 value: ExpandElementTyped<T>,
269 ) {
270 index_assign::expand(scope, self.clone(), idx, value);
271 }
272}