cubecl_core/frontend/container/
shared_memory.rs1use std::{marker::PhantomData, num::NonZero};
2
3use crate as cubecl;
4use cubecl_macros::{cube, intrinsic};
5
6use crate::{
7 frontend::{CubePrimitive, CubeType, ExpandElementTyped, IntoMut, indexation::Index},
8 ir::{Item, Scope},
9 prelude::{
10 Line, List, ListExpand, ListMut, ListMutExpand, index, index_assign, index_unchecked,
11 },
12};
13
14#[derive(Clone, Copy)]
15pub struct SharedMemory<T: CubeType> {
16 _val: PhantomData<T>,
17}
18
19impl<T: CubePrimitive> IntoMut for ExpandElementTyped<SharedMemory<T>> {
20 fn into_mut(self, _scope: &mut Scope) -> Self {
21 self
22 }
23}
24
25impl<T: CubePrimitive> CubeType for SharedMemory<T> {
26 type ExpandType = ExpandElementTyped<SharedMemory<T>>;
27}
28
29impl<T: CubePrimitive + Clone> SharedMemory<T> {
30 pub fn new<S: Index>(_size: S) -> Self {
31 SharedMemory { _val: PhantomData }
32 }
33
34 pub fn new_lined<S: Index>(_size: S, _vectorization_factor: u32) -> SharedMemory<Line<T>> {
35 SharedMemory { _val: PhantomData }
36 }
37
38 pub fn __expand_new_lined(
39 scope: &mut Scope,
40 size: ExpandElementTyped<u32>,
41 vectorization_factor: u32,
42 ) -> <SharedMemory<Line<T>> as CubeType>::ExpandType {
43 let size = size
44 .constant()
45 .expect("Shared memory need constant initialization value")
46 .as_u32();
47 let var = scope.create_shared(
48 Item::vectorized(T::as_elem(scope), NonZero::new(vectorization_factor as u8)),
49 size,
50 None,
51 );
52 ExpandElementTyped::new(var)
53 }
54
55 pub fn vectorized<S: Index>(_size: S, _vectorization_factor: u32) -> Self {
56 SharedMemory { _val: PhantomData }
57 }
58
59 pub fn __expand_vectorized(
60 scope: &mut Scope,
61 size: ExpandElementTyped<u32>,
62 vectorization_factor: u32,
63 ) -> <Self as CubeType>::ExpandType {
64 let size = size
65 .constant()
66 .expect("Shared memory need constant initialization value")
67 .as_u32();
68 let var = scope.create_shared(
69 Item::vectorized(T::as_elem(scope), NonZero::new(vectorization_factor as u8)),
70 size,
71 None,
72 );
73 ExpandElementTyped::new(var)
74 }
75
76 pub fn __expand_new(
77 scope: &mut Scope,
78 size: ExpandElementTyped<u32>,
79 ) -> <Self as CubeType>::ExpandType {
80 let size = size
81 .constant()
82 .expect("Shared memory need constant initialization value")
83 .as_u32();
84 let var = scope.create_shared(Item::new(T::as_elem(scope)), size, None);
85 ExpandElementTyped::new(var)
86 }
87}
88
89#[cube]
90impl<T: CubePrimitive + Clone> SharedMemory<T> {
91 #[allow(unused_variables)]
92 pub fn new_aligned(
93 #[comptime] size: u32,
94 #[comptime] vectorization_factor: u32,
95 #[comptime] alignment: u32,
96 ) -> SharedMemory<Line<T>> {
97 intrinsic!(|scope| {
98 let var = scope.create_shared(
99 Item::vectorized(T::as_elem(scope), NonZero::new(vectorization_factor as u8)),
100 size,
101 Some(alignment),
102 );
103 ExpandElementTyped::new(var)
104 })
105 }
106}
107
108mod indexation {
110 use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};
111
112 use crate::ir::Instruction;
113
114 use super::*;
115
116 type SharedMemoryExpand<E> = ExpandElementTyped<SharedMemory<E>>;
117
118 #[cube]
119 impl<E: CubePrimitive> SharedMemory<E> {
120 #[allow(unused_variables)]
126 pub unsafe fn index_unchecked(&self, i: u32) -> &E {
127 intrinsic!(|scope| {
128 let out = scope.create_local(self.expand.item);
129 scope.register(Instruction::new(
130 Operator::UncheckedIndex(IndexOperator {
131 list: *self.expand,
132 index: i.expand.consume(),
133 line_size: 0,
134 }),
135 *out,
136 ));
137 out.into()
138 })
139 }
140
141 #[allow(unused_variables)]
147 pub unsafe fn index_assign_unchecked(&mut self, i: u32, value: E) {
148 intrinsic!(|scope| {
149 scope.register(Instruction::new(
150 Operator::UncheckedIndexAssign(IndexAssignOperator {
151 index: i.expand.consume(),
152 value: value.expand.consume(),
153 line_size: 0,
154 }),
155 *self.expand,
156 ));
157 })
158 }
159 }
160}
161
162impl<T: CubePrimitive> List<T> for SharedMemory<T> {
163 fn __expand_read(
164 scope: &mut Scope,
165 this: ExpandElementTyped<SharedMemory<T>>,
166 idx: ExpandElementTyped<u32>,
167 ) -> ExpandElementTyped<T> {
168 index::expand(scope, this, idx)
169 }
170}
171
172impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<SharedMemory<T>> {
173 fn __expand_read_method(
174 &self,
175 scope: &mut Scope,
176 idx: ExpandElementTyped<u32>,
177 ) -> ExpandElementTyped<T> {
178 index::expand(scope, self.clone(), idx)
179 }
180 fn __expand_read_unchecked_method(
181 &self,
182 scope: &mut Scope,
183 idx: ExpandElementTyped<u32>,
184 ) -> ExpandElementTyped<T> {
185 index_unchecked::expand(scope, self.clone(), idx)
186 }
187}
188
189impl<T: CubePrimitive> ListMut<T> for SharedMemory<T> {
190 fn __expand_write(
191 scope: &mut Scope,
192 this: ExpandElementTyped<SharedMemory<T>>,
193 idx: ExpandElementTyped<u32>,
194 value: ExpandElementTyped<T>,
195 ) {
196 index_assign::expand(scope, this, idx, value);
197 }
198}
199
200impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<SharedMemory<T>> {
201 fn __expand_write_method(
202 &self,
203 scope: &mut Scope,
204 idx: ExpandElementTyped<u32>,
205 value: ExpandElementTyped<T>,
206 ) {
207 index_assign::expand(scope, self.clone(), idx, value);
208 }
209}