cubecl_core/frontend/container/
shared_memory.rs1use std::{marker::PhantomData, num::NonZero};
2
3use crate::{
4 frontend::{
5 indexation::Index, CubeContext, CubePrimitive, CubeType, ExpandElementTyped, Init,
6 IntoRuntime,
7 },
8 ir::Item,
9 prelude::Line,
10};
11
12#[derive(Clone, Copy)]
13pub struct SharedMemory<T: CubeType> {
14 _val: PhantomData<T>,
15}
16
17impl<T: CubePrimitive> Init for ExpandElementTyped<SharedMemory<T>> {
18 fn init(self, _context: &mut CubeContext) -> Self {
19 self
20 }
21}
22
23impl<T: CubePrimitive> IntoRuntime for SharedMemory<T> {
24 fn __expand_runtime_method(self, _context: &mut CubeContext) -> ExpandElementTyped<Self> {
25 unimplemented!("Shared memory can't exist at comptime");
26 }
27}
28
29impl<T: CubePrimitive> CubeType for SharedMemory<T> {
30 type ExpandType = ExpandElementTyped<SharedMemory<T>>;
31}
32
33impl<T: CubePrimitive + Clone> SharedMemory<T> {
34 pub fn new<S: Index>(_size: S) -> Self {
35 SharedMemory { _val: PhantomData }
36 }
37
38 pub fn new_lined<S: Index>(_size: S, _vectorization_factor: u32) -> SharedMemory<Line<T>> {
39 SharedMemory { _val: PhantomData }
40 }
41
42 pub fn __expand_new_lined(
43 context: &mut CubeContext,
44 size: ExpandElementTyped<u32>,
45 vectorization_factor: u32,
46 ) -> <SharedMemory<Line<T>> as CubeType>::ExpandType {
47 let size = size
48 .constant()
49 .expect("Shared memory need constant initialization value")
50 .as_u32();
51 let var = context.create_shared(
52 Item::vectorized(
53 T::as_elem(context),
54 NonZero::new(vectorization_factor as u8),
55 ),
56 size,
57 );
58 ExpandElementTyped::new(var)
59 }
60 pub fn vectorized<S: Index>(_size: S, _vectorization_factor: u32) -> Self {
61 SharedMemory { _val: PhantomData }
62 }
63
64 pub fn __expand_vectorized(
65 context: &mut CubeContext,
66 size: ExpandElementTyped<u32>,
67 vectorization_factor: u32,
68 ) -> <Self as CubeType>::ExpandType {
69 let size = size
70 .constant()
71 .expect("Shared memory need constant initialization value")
72 .as_u32();
73 let var = context.create_shared(
74 Item::vectorized(
75 T::as_elem(context),
76 NonZero::new(vectorization_factor as u8),
77 ),
78 size,
79 );
80 ExpandElementTyped::new(var)
81 }
82
83 pub fn __expand_new(
84 context: &mut CubeContext,
85 size: ExpandElementTyped<u32>,
86 ) -> <Self as CubeType>::ExpandType {
87 let size = size
88 .constant()
89 .expect("Shared memory need constant initialization value")
90 .as_u32();
91 let var = context.create_shared(Item::new(T::as_elem(context)), size);
92 ExpandElementTyped::new(var)
93 }
94}
95
96mod indexation {
98 use crate::{
99 ir::{BinaryOperator, Instruction, Operator},
100 prelude::{CubeIndex, CubeIndexMut},
101 unexpanded,
102 };
103
104 use super::*;
105
106 impl<E: CubePrimitive> SharedMemory<E> {
107 pub unsafe fn index_unchecked<I: Index>(&self, _i: I) -> &E
113 where
114 Self: CubeIndex<I>,
115 {
116 unexpanded!()
117 }
118
119 pub unsafe fn index_assign_unchecked<I: Index>(&mut self, _i: I, _value: E)
125 where
126 Self: CubeIndexMut<I>,
127 {
128 unexpanded!()
129 }
130 }
131
132 impl<E: CubePrimitive> ExpandElementTyped<SharedMemory<E>> {
133 pub fn __expand_index_unchecked_method(
134 self,
135 context: &mut CubeContext,
136 i: ExpandElementTyped<u32>,
137 ) -> ExpandElementTyped<E> {
138 let out = context.create_local(self.expand.item);
139 context.register(Instruction::new(
140 Operator::UncheckedIndex(BinaryOperator {
141 lhs: *self.expand,
142 rhs: i.expand.consume(),
143 }),
144 *out,
145 ));
146 out.into()
147 }
148
149 pub fn __expand_index_assign_unchecked_method(
150 self,
151 context: &mut CubeContext,
152 i: ExpandElementTyped<u32>,
153 value: ExpandElementTyped<E>,
154 ) {
155 context.register(Instruction::new(
156 Operator::UncheckedIndexAssign(BinaryOperator {
157 lhs: i.expand.consume(),
158 rhs: value.expand.consume(),
159 }),
160 *self.expand,
161 ));
162 }
163 }
164}