cubecl_core/frontend/container/shared_memory.rs

1use std::{marker::PhantomData, num::NonZero};
2
3use crate::{
4    frontend::{
5        indexation::Index, CubeContext, CubePrimitive, CubeType, ExpandElementTyped, Init,
6        IntoRuntime,
7    },
8    ir::Item,
9    prelude::Line,
10};
11
12#[derive(Clone, Copy)]
13pub struct SharedMemory<T: CubeType> {
14    _val: PhantomData<T>,
15}
16
17impl<T: CubePrimitive> Init for ExpandElementTyped<SharedMemory<T>> {
18    fn init(self, _context: &mut CubeContext) -> Self {
19        self
20    }
21}
22
23impl<T: CubePrimitive> IntoRuntime for SharedMemory<T> {
24    fn __expand_runtime_method(self, _context: &mut CubeContext) -> ExpandElementTyped<Self> {
25        unimplemented!("Shared memory can't exist at comptime");
26    }
27}
28
29impl<T: CubePrimitive> CubeType for SharedMemory<T> {
30    type ExpandType = ExpandElementTyped<SharedMemory<T>>;
31}
32
33impl<T: CubePrimitive + Clone> SharedMemory<T> {
34    pub fn new<S: Index>(_size: S) -> Self {
35        SharedMemory { _val: PhantomData }
36    }
37
38    pub fn new_lined<S: Index>(_size: S, _vectorization_factor: u32) -> SharedMemory<Line<T>> {
39        SharedMemory { _val: PhantomData }
40    }
41
42    pub fn __expand_new_lined(
43        context: &mut CubeContext,
44        size: ExpandElementTyped<u32>,
45        vectorization_factor: u32,
46    ) -> <SharedMemory<Line<T>> as CubeType>::ExpandType {
47        let size = size
48            .constant()
49            .expect("Shared memory need constant initialization value")
50            .as_u32();
51        let var = context.create_shared(
52            Item::vectorized(
53                T::as_elem(context),
54                NonZero::new(vectorization_factor as u8),
55            ),
56            size,
57        );
58        ExpandElementTyped::new(var)
59    }
60    pub fn vectorized<S: Index>(_size: S, _vectorization_factor: u32) -> Self {
61        SharedMemory { _val: PhantomData }
62    }
63
64    pub fn __expand_vectorized(
65        context: &mut CubeContext,
66        size: ExpandElementTyped<u32>,
67        vectorization_factor: u32,
68    ) -> <Self as CubeType>::ExpandType {
69        let size = size
70            .constant()
71            .expect("Shared memory need constant initialization value")
72            .as_u32();
73        let var = context.create_shared(
74            Item::vectorized(
75                T::as_elem(context),
76                NonZero::new(vectorization_factor as u8),
77            ),
78            size,
79        );
80        ExpandElementTyped::new(var)
81    }
82
83    pub fn __expand_new(
84        context: &mut CubeContext,
85        size: ExpandElementTyped<u32>,
86    ) -> <Self as CubeType>::ExpandType {
87        let size = size
88            .constant()
89            .expect("Shared memory need constant initialization value")
90            .as_u32();
91        let var = context.create_shared(Item::new(T::as_elem(context)), size);
92        ExpandElementTyped::new(var)
93    }
94}
95
96/// Module that contains the implementation details of the index functions.
97mod indexation {
98    use crate::{
99        ir::{BinaryOperator, Instruction, Operator},
100        prelude::{CubeIndex, CubeIndexMut},
101        unexpanded,
102    };
103
104    use super::*;
105
106    impl<E: CubePrimitive> SharedMemory<E> {
107        /// Perform an unchecked index into the array
108        ///
109        /// # Safety
110        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
111        /// always in bounds
112        pub unsafe fn index_unchecked<I: Index>(&self, _i: I) -> &E
113        where
114            Self: CubeIndex<I>,
115        {
116            unexpanded!()
117        }
118
119        /// Perform an unchecked index assignment into the array
120        ///
121        /// # Safety
122        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
123        /// always in bounds
124        pub unsafe fn index_assign_unchecked<I: Index>(&mut self, _i: I, _value: E)
125        where
126            Self: CubeIndexMut<I>,
127        {
128            unexpanded!()
129        }
130    }
131
132    impl<E: CubePrimitive> ExpandElementTyped<SharedMemory<E>> {
133        pub fn __expand_index_unchecked_method(
134            self,
135            context: &mut CubeContext,
136            i: ExpandElementTyped<u32>,
137        ) -> ExpandElementTyped<E> {
138            let out = context.create_local(self.expand.item);
139            context.register(Instruction::new(
140                Operator::UncheckedIndex(BinaryOperator {
141                    lhs: *self.expand,
142                    rhs: i.expand.consume(),
143                }),
144                *out,
145            ));
146            out.into()
147        }
148
149        pub fn __expand_index_assign_unchecked_method(
150            self,
151            context: &mut CubeContext,
152            i: ExpandElementTyped<u32>,
153            value: ExpandElementTyped<E>,
154        ) {
155            context.register(Instruction::new(
156                Operator::UncheckedIndexAssign(BinaryOperator {
157                    lhs: i.expand.consume(),
158                    rhs: value.expand.consume(),
159                }),
160                *self.expand,
161            ));
162        }
163    }
164}