cubecl_core/frontend/container/
shared_memory.rs1use std::{marker::PhantomData, num::NonZero};
2
3use crate::{
4 frontend::{CubePrimitive, CubeType, ExpandElementTyped, Init, indexation::Index},
5 ir::{Item, Scope},
6 prelude::{Line, List, ListExpand, ListMut, ListMutExpand, index, index_assign},
7};
8
9#[derive(Clone, Copy)]
10pub struct SharedMemory<T: CubeType> {
11 _val: PhantomData<T>,
12}
13
14impl<T: CubePrimitive> Init for ExpandElementTyped<SharedMemory<T>> {
15 fn init(self, _scope: &mut Scope) -> Self {
16 self
17 }
18}
19
20impl<T: CubePrimitive> CubeType for SharedMemory<T> {
21 type ExpandType = ExpandElementTyped<SharedMemory<T>>;
22}
23
24impl<T: CubePrimitive + Clone> SharedMemory<T> {
25 pub fn new<S: Index>(_size: S) -> Self {
26 SharedMemory { _val: PhantomData }
27 }
28
29 pub fn new_lined<S: Index>(_size: S, _vectorization_factor: u32) -> SharedMemory<Line<T>> {
30 SharedMemory { _val: PhantomData }
31 }
32
33 pub fn new_aligned<S: Index>(
34 _size: S,
35 _vectorization_factor: u32,
36 _alignment: u32,
37 ) -> SharedMemory<Line<T>> {
38 SharedMemory { _val: PhantomData }
39 }
40
41 pub fn __expand_new_lined(
42 scope: &mut Scope,
43 size: ExpandElementTyped<u32>,
44 vectorization_factor: u32,
45 ) -> <SharedMemory<Line<T>> as CubeType>::ExpandType {
46 let size = size
47 .constant()
48 .expect("Shared memory need constant initialization value")
49 .as_u32();
50 let var = scope.create_shared(
51 Item::vectorized(T::as_elem(scope), NonZero::new(vectorization_factor as u8)),
52 size,
53 None,
54 );
55 ExpandElementTyped::new(var)
56 }
57
58 pub fn __expand_new_aligned(
59 scope: &mut Scope,
60 size: ExpandElementTyped<u32>,
61 vectorization_factor: u32,
62 alignment: u32,
63 ) -> <SharedMemory<Line<T>> as CubeType>::ExpandType {
64 let size = size
65 .constant()
66 .expect("Shared memory need constant initialization value")
67 .as_u32();
68 let var = scope.create_shared(
69 Item::vectorized(T::as_elem(scope), NonZero::new(vectorization_factor as u8)),
70 size,
71 Some(alignment),
72 );
73 ExpandElementTyped::new(var)
74 }
75
76 pub fn vectorized<S: Index>(_size: S, _vectorization_factor: u32) -> Self {
77 SharedMemory { _val: PhantomData }
78 }
79
80 pub fn __expand_vectorized(
81 scope: &mut Scope,
82 size: ExpandElementTyped<u32>,
83 vectorization_factor: u32,
84 ) -> <Self as CubeType>::ExpandType {
85 let size = size
86 .constant()
87 .expect("Shared memory need constant initialization value")
88 .as_u32();
89 let var = scope.create_shared(
90 Item::vectorized(T::as_elem(scope), NonZero::new(vectorization_factor as u8)),
91 size,
92 None,
93 );
94 ExpandElementTyped::new(var)
95 }
96
97 pub fn __expand_new(
98 scope: &mut Scope,
99 size: ExpandElementTyped<u32>,
100 ) -> <Self as CubeType>::ExpandType {
101 let size = size
102 .constant()
103 .expect("Shared memory need constant initialization value")
104 .as_u32();
105 let var = scope.create_shared(Item::new(T::as_elem(scope)), size, None);
106 ExpandElementTyped::new(var)
107 }
108}
109
110mod indexation {
112 use cubecl_ir::Operator;
113
114 use crate::{
115 ir::{BinaryOperator, Instruction},
116 prelude::{CubeIndex, CubeIndexMut},
117 unexpanded,
118 };
119
120 use super::*;
121
122 impl<E: CubePrimitive> SharedMemory<E> {
123 pub unsafe fn index_unchecked<I: Index>(&self, _i: I) -> &E
129 where
130 Self: CubeIndex<I>,
131 {
132 unexpanded!()
133 }
134
135 pub unsafe fn index_assign_unchecked<I: Index>(&mut self, _i: I, _value: E)
141 where
142 Self: CubeIndexMut<I>,
143 {
144 unexpanded!()
145 }
146 }
147
148 impl<E: CubePrimitive> ExpandElementTyped<SharedMemory<E>> {
149 pub fn __expand_index_unchecked_method(
150 self,
151 scope: &mut Scope,
152 i: ExpandElementTyped<u32>,
153 ) -> ExpandElementTyped<E> {
154 let out = scope.create_local(self.expand.item);
155 scope.register(Instruction::new(
156 Operator::UncheckedIndex(BinaryOperator {
157 lhs: *self.expand,
158 rhs: i.expand.consume(),
159 }),
160 *out,
161 ));
162 out.into()
163 }
164
165 pub fn __expand_index_assign_unchecked_method(
166 self,
167 scope: &mut Scope,
168 i: ExpandElementTyped<u32>,
169 value: ExpandElementTyped<E>,
170 ) {
171 scope.register(Instruction::new(
172 Operator::UncheckedIndexAssign(BinaryOperator {
173 lhs: i.expand.consume(),
174 rhs: value.expand.consume(),
175 }),
176 *self.expand,
177 ));
178 }
179 }
180}
181
182impl<T: CubePrimitive> List<T> for SharedMemory<T> {
183 fn __expand_read(
184 scope: &mut Scope,
185 this: ExpandElementTyped<SharedMemory<T>>,
186 idx: ExpandElementTyped<u32>,
187 ) -> ExpandElementTyped<T> {
188 index::expand(scope, this, idx)
189 }
190}
191
192impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<SharedMemory<T>> {
193 fn __expand_read_method(
194 self,
195 scope: &mut Scope,
196 idx: ExpandElementTyped<u32>,
197 ) -> ExpandElementTyped<T> {
198 index::expand(scope, self, idx)
199 }
200}
201
202impl<T: CubePrimitive> ListMut<T> for SharedMemory<T> {
203 fn __expand_write(
204 scope: &mut Scope,
205 this: ExpandElementTyped<SharedMemory<T>>,
206 idx: ExpandElementTyped<u32>,
207 value: ExpandElementTyped<T>,
208 ) {
209 index_assign::expand(scope, this, idx, value);
210 }
211}
212
213impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<SharedMemory<T>> {
214 fn __expand_write_method(
215 self,
216 scope: &mut Scope,
217 idx: ExpandElementTyped<u32>,
218 value: ExpandElementTyped<T>,
219 ) {
220 index_assign::expand(scope, self, idx, value);
221 }
222}