cubecl_core/frontend/list.rs

1use super::{CubeType, ExpandElementTyped};
2use crate::unexpanded;
3use cubecl_ir::Scope;
4
5/// Type from which we can read values in cube functions.
6/// For a mutable version, see [ListMut].
7pub trait List<T: CubeType>: CubeType<ExpandType: ListExpand<T>> {
8    #[allow(unused)]
9    fn read(&self, index: u32) -> T {
10        unexpanded!()
11    }
12
13    fn __expand_read(
14        scope: &mut Scope,
15        this: Self::ExpandType,
16        index: ExpandElementTyped<u32>,
17    ) -> T::ExpandType;
18}
19
20/// Expand version of [CubeRead].
21pub trait ListExpand<T: CubeType> {
22    fn __expand_read_method(
23        self,
24        scope: &mut Scope,
25        index: ExpandElementTyped<u32>,
26    ) -> T::ExpandType;
27}
28
29/// Type for which we can read and write values in cube functions.
30/// For an immutable version, see [List].
31pub trait ListMut<T: CubeType>: CubeType<ExpandType: ListMutExpand<T>> + List<T> {
32    #[allow(unused)]
33    fn write(&self, index: u32, value: T) {
34        unexpanded!()
35    }
36
37    fn __expand_write(
38        scope: &mut Scope,
39        this: Self::ExpandType,
40        index: ExpandElementTyped<u32>,
41        value: T::ExpandType,
42    );
43}
44
45/// Expand version of [CubeWrite].
46pub trait ListMutExpand<T: CubeType> {
47    fn __expand_write_method(
48        self,
49        scope: &mut Scope,
50        index: ExpandElementTyped<u32>,
51        value: T::ExpandType,
52    );
53}
54
55// Automatic implementation for mutable references to List.
56impl<'a, T: CubeType, L: List<T>> List<T> for &'a L
57where
58    &'a L: CubeType<ExpandType = L::ExpandType>,
59{
60    fn read(&self, index: u32) -> T {
61        L::read(self, index)
62    }
63
64    fn __expand_read(
65        scope: &mut Scope,
66        this: Self::ExpandType,
67        index: ExpandElementTyped<u32>,
68    ) -> <T as CubeType>::ExpandType {
69        L::__expand_read(scope, this, index)
70    }
71}
72
73// Automatic implementation for mutable references to List.
74impl<'a, T: CubeType, L: List<T>> List<T> for &'a mut L
75where
76    &'a mut L: CubeType<ExpandType = L::ExpandType>,
77{
78    fn read(&self, index: u32) -> T {
79        L::read(self, index)
80    }
81
82    fn __expand_read(
83        scope: &mut Scope,
84        this: Self::ExpandType,
85        index: ExpandElementTyped<u32>,
86    ) -> <T as CubeType>::ExpandType {
87        L::__expand_read(scope, this, index)
88    }
89}
90
91// Automatic implementation for references to ListMut.
92impl<'a, T: CubeType, L: ListMut<T>> ListMut<T> for &'a L
93where
94    &'a L: CubeType<ExpandType = L::ExpandType>,
95{
96    fn write(&self, index: u32, value: T) {
97        L::write(self, index, value);
98    }
99
100    fn __expand_write(
101        scope: &mut Scope,
102        this: Self::ExpandType,
103        index: ExpandElementTyped<u32>,
104        value: T::ExpandType,
105    ) {
106        L::__expand_write(scope, this, index, value);
107    }
108}
109
110// Automatic implementation for references to ListMut.
111impl<'a, T: CubeType, L: ListMut<T>> ListMut<T> for &'a mut L
112where
113    &'a mut L: CubeType<ExpandType = L::ExpandType>,
114{
115    fn write(&self, index: u32, value: T) {
116        L::write(self, index, value);
117    }
118
119    fn __expand_write(
120        scope: &mut Scope,
121        this: Self::ExpandType,
122        index: ExpandElementTyped<u32>,
123        value: T::ExpandType,
124    ) {
125        L::__expand_write(scope, this, index, value);
126    }
127}