cubecl_core/frontend/
list.rs1use std::ops::{Deref, DerefMut};
2
3use super::{CubeType, ExpandElementTyped};
4use crate as cubecl;
5use crate::{prelude::*, unexpanded};
6use cubecl_ir::{LineSize, Scope};
7
/// Read-only list of `T` values usable from cube functions.
///
/// Provides element reads (`read`, `read_unchecked`) and a length query (`len`).
/// For a writable variant, see [`ListMut`], which extends this trait with `write`.
///
/// The method bodies below are `unexpanded!()` placeholders; the `#[cube]`
/// attribute generates the expand-time counterparts (e.g. `__expand_read`,
/// which the reference impls in this file override).
// `len` is provided without an `is_empty` companion, hence this allow.
#[allow(clippy::len_without_is_empty)]
// NOTE(review): presumably `self_type = "ref"` makes the generated expand methods
// take the receiver by reference and `expand_base_traits` makes the generated
// expand trait extend `SliceOperatorExpand<T>` — confirm against the `cube` macro.
#[cube(self_type = "ref", expand_base_traits = "SliceOperatorExpand<T>")]
pub trait List<T: CubePrimitive>: SliceOperator<T> + Lined + Deref<Target = [T]> {
    /// Reads the element at position `index`.
    // `index` is unused in the placeholder body.
    #[allow(unused)]
    fn read(&self, index: usize) -> T {
        unexpanded!()
    }

    /// Like [`List::read`], but — judging by the name — without bounds checking;
    /// confirm the exact contract against the implementors.
    // `index` is unused in the placeholder body.
    #[allow(unused)]
    fn read_unchecked(&self, index: usize) -> T {
        unexpanded!()
    }

    /// Number of elements in the list.
    #[allow(unused)]
    fn len(&self) -> usize {
        unexpanded!();
    }
}
28
/// Readable and writable list of `T` values usable from cube functions.
///
/// Extends [`List`] with `write`. For the read-only view, see [`List`].
// NOTE(review): presumably `expand_base_traits` makes the generated expand trait
// extend `SliceMutOperatorExpand<T>` — confirm against the `cube` macro.
#[cube(self_type = "ref", expand_base_traits = "SliceMutOperatorExpand<T>")]
pub trait ListMut<T: CubePrimitive>:
    List<T> + SliceMutOperator<T> + DerefMut<Target = [T]>
{
    /// Writes `value` at position `index`.
    ///
    /// Takes `&self` rather than `&mut self`; presumably the mutation is recorded
    /// in the expanded kernel rather than through Rust borrows — confirm.
    // Parameters are unused in the placeholder body.
    #[allow(unused)]
    fn write(&self, index: usize, value: T) {
        unexpanded!()
    }
}
40
41impl<'a, T: CubePrimitive, L: List<T>> List<T> for &'a L
43where
44 &'a L: CubeType<ExpandType = L::ExpandType>,
45 &'a L: Deref<Target = [T]>,
46{
47 fn read(&self, index: usize) -> T {
48 L::read(self, index)
49 }
50
51 fn __expand_read(
52 scope: &mut Scope,
53 this: Self::ExpandType,
54 index: ExpandElementTyped<usize>,
55 ) -> <T as CubeType>::ExpandType {
56 L::__expand_read(scope, this, index)
57 }
58}
59
60impl<'a, T: CubePrimitive, L: List<T>> List<T> for &'a mut L
62where
63 &'a mut L: CubeType<ExpandType = L::ExpandType>,
64 &'a mut L: Deref<Target = [T]>,
65{
66 fn read(&self, index: usize) -> T {
67 L::read(self, index)
68 }
69
70 fn __expand_read(
71 scope: &mut Scope,
72 this: Self::ExpandType,
73 index: ExpandElementTyped<usize>,
74 ) -> <T as CubeType>::ExpandType {
75 L::__expand_read(scope, this, index)
76 }
77}
78
79impl<'a, T: CubePrimitive, L: ListMut<T>> ListMut<T> for &'a L
81where
82 &'a L: CubeType<ExpandType = L::ExpandType>,
83 &'a L: DerefMut<Target = [T]>,
84{
85 fn write(&self, index: usize, value: T) {
86 L::write(self, index, value);
87 }
88
89 fn __expand_write(
90 scope: &mut Scope,
91 this: Self::ExpandType,
92 index: ExpandElementTyped<usize>,
93 value: T::ExpandType,
94 ) {
95 L::__expand_write(scope, this, index, value);
96 }
97}
98
99impl<'a, T: CubePrimitive, L: ListMut<T>> ListMut<T> for &'a mut L
101where
102 &'a mut L: CubeType<ExpandType = L::ExpandType>,
103 &'a mut L: DerefMut<Target = [T]>,
104{
105 fn write(&self, index: usize, value: T) {
106 L::write(self, index, value);
107 }
108
109 fn __expand_write(
110 scope: &mut Scope,
111 this: Self::ExpandType,
112 index: ExpandElementTyped<usize>,
113 value: T::ExpandType,
114 ) {
115 L::__expand_write(scope, this, index, value);
116 }
117}
118
119pub trait Lined: CubeType<ExpandType: LinedExpand> {
120 fn line_size(&self) -> LineSize {
121 unexpanded!()
122 }
123 fn __expand_line_size(_scope: &mut Scope, this: Self::ExpandType) -> LineSize {
124 this.line_size()
125 }
126}
127
128pub trait LinedExpand {
129 fn line_size(&self) -> LineSize;
130 fn __expand_line_size_method(&self, _scope: &mut Scope) -> LineSize {
131 self.line_size()
132 }
133}
134
135impl<'a, L: Lined> Lined for &'a L where &'a L: CubeType<ExpandType: LinedExpand> {}
136impl<'a, L: Lined> Lined for &'a mut L where &'a mut L: CubeType<ExpandType: LinedExpand> {}