cubecl_std/tensor/view/operations/
slice.rs1use super::*;
2use crate::{CubeOption, CubeOptionExpand, tensor::layout::Coords1d};
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, io::read_masked, prelude::barrier::BarrierExpand};
5
6impl<T: CubePrimitive, IO: SliceVisibility> ViewOperations<T, Coords1d> for Slice<T, IO> {}
7impl<T: CubePrimitive, IO: SliceVisibility> ViewOperationsExpand<T, Coords1d>
8 for SliceExpand<T, IO>
9{
10 fn __expand_read_method(
11 &self,
12 scope: &mut Scope,
13 pos: ExpandElementTyped<u32>,
14 ) -> <T>::ExpandType {
15 <Self as ListExpand<T>>::__expand_read_method(self, scope, pos)
16 }
17
18 fn __expand_read_checked_method(
19 &self,
20 scope: &mut Scope,
21 pos: ExpandElementTyped<u32>,
22 ) -> <T>::ExpandType {
23 let len = self.__expand_len_method(scope);
24 let in_bounds = lt::expand(scope, pos.clone(), len);
25 let slice = self.clone().__expand_to_slice_method(scope);
26 let zero = T::__expand_cast_from(scope, 0.into());
27 read_masked::expand::<T>(scope, in_bounds, slice, pos, zero)
28 }
29
30 fn __expand_read_masked_method(
31 &self,
32 scope: &mut Scope,
33 pos: ExpandElementTyped<u32>,
34 mask_value: <T>::ExpandType,
35 ) -> <T>::ExpandType {
36 let len = self.__expand_len_method(scope);
37 let in_bounds = lt::expand(scope, pos.clone(), len);
38 let slice = self.clone().__expand_to_slice_method(scope);
39 read_masked::expand::<T>(scope, in_bounds, slice, pos, mask_value)
40 }
41
42 fn __expand_read_unchecked_method(
43 &self,
44 scope: &mut Scope,
45 pos: ExpandElementTyped<u32>,
46 ) -> <T>::ExpandType {
47 <Self as ListExpand<T>>::__expand_read_unchecked_method(self, scope, pos)
48 }
49
50 fn __expand_to_linear_slice_method(
51 &self,
52 scope: &mut Scope,
53 pos: ExpandElementTyped<u32>,
54 end: ExpandElementTyped<u32>,
55 ) -> SliceExpand<T, ReadOnly> {
56 let end = add::expand(scope, end, 1u32.into());
58 let start = Min::__expand_min(scope, pos, end.clone());
61 <Self as SliceOperatorExpand<T>>::__expand_slice_method(self, scope, start, end)
62 }
63
64 fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> CubeOptionExpand<TensorMap<T>> {
65 CubeOption::__expand_new_None(scope)
66 }
67
68 fn __expand_shape_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
69 <Self as ListExpand<T>>::__expand_len_method(self, scope)
70 }
71
72 fn __expand_is_in_bounds_method(
73 &self,
74 scope: &mut Scope,
75 pos: ExpandElementTyped<u32>,
76 ) -> ExpandElementTyped<bool> {
77 let len = self.__expand_shape_method(scope);
78 lt::expand(scope, pos, len)
79 }
80
81 fn __expand_tensor_map_load_method(
82 &self,
83 _scope: &mut Scope,
84 _barrier: BarrierExpand,
85 _shared_memory: SliceExpand<T, ReadWrite>,
86 _pos: ExpandElementTyped<u32>,
87 ) {
88 unimplemented!("Not a tensor map");
89 }
90}
91
92impl<T: CubePrimitive> ViewOperationsMut<T, Coords1d> for Slice<T, ReadWrite> {}
93impl<T: CubePrimitive> ViewOperationsMutExpand<T, Coords1d> for SliceExpand<T, ReadWrite> {
94 fn __expand_write_method(
95 &self,
96 scope: &mut Scope,
97 pos: ExpandElementTyped<u32>,
98 value: <T>::ExpandType,
99 ) {
100 <Self as ListMutExpand<T>>::__expand_write_method(self, scope, pos, value)
101 }
102
103 fn __expand_write_checked_method(
104 &self,
105 scope: &mut Scope,
106 pos: ExpandElementTyped<u32>,
107 value: <T>::ExpandType,
108 ) {
109 let len = <Self as ListExpand<T>>::__expand_len_method(self, scope);
110 let in_bounds = lt::expand(scope, pos.clone(), len);
111 if_expand(scope, in_bounds.into(), |scope| {
112 <Self as ListMutExpand<T>>::__expand_write_method(self, scope, pos, value)
113 })
114 }
115
116 fn __expand_to_linear_slice_mut_method(
117 &self,
118 scope: &mut Scope,
119 pos: ExpandElementTyped<u32>,
120 end: ExpandElementTyped<u32>,
121 ) -> SliceExpand<T, ReadWrite> {
122 let end = add::expand(scope, end, 1u32.into());
124 let start = Min::__expand_min(scope, pos, end.clone());
127 <Self as SliceMutOperatorExpand<T>>::__expand_slice_mut_method(self, scope, start, end)
128 }
129
130 fn __expand_tensor_map_store_method(
131 &self,
132 _scope: &mut Scope,
133 _shared_memory: SliceExpand<T, ReadOnly>,
134 _pos: <Coords1d as CubeType>::ExpandType,
135 ) {
136 unimplemented!("Not a tensor map");
137 }
138}