// cubecl_std/tensor/view/operations/slice.rs

1use super::*;
2use crate::{CubeOption, CubeOptionExpand, tensor::layout::Coords1d};
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, io::read_masked, prelude::barrier::BarrierExpand};
5
6impl<T: CubePrimitive, IO: SliceVisibility> ViewOperations<T, Coords1d> for Slice<T, IO> {}
7impl<T: CubePrimitive, IO: SliceVisibility> ViewOperationsExpand<T, Coords1d>
8    for SliceExpand<T, IO>
9{
10    fn __expand_read_method(
11        &self,
12        scope: &mut Scope,
13        pos: ExpandElementTyped<u32>,
14    ) -> <T>::ExpandType {
15        <Self as ListExpand<T>>::__expand_read_method(self, scope, pos)
16    }
17
18    fn __expand_read_checked_method(
19        &self,
20        scope: &mut Scope,
21        pos: ExpandElementTyped<u32>,
22    ) -> <T>::ExpandType {
23        let len = self.__expand_len_method(scope);
24        let in_bounds = lt::expand(scope, pos.clone(), len);
25        let slice = self.clone().__expand_to_slice_method(scope);
26        let zero = T::__expand_cast_from(scope, 0.into());
27        read_masked::expand::<T>(scope, in_bounds, slice, pos, zero)
28    }
29
30    fn __expand_read_masked_method(
31        &self,
32        scope: &mut Scope,
33        pos: ExpandElementTyped<u32>,
34        mask_value: <T>::ExpandType,
35    ) -> <T>::ExpandType {
36        let len = self.__expand_len_method(scope);
37        let in_bounds = lt::expand(scope, pos.clone(), len);
38        let slice = self.clone().__expand_to_slice_method(scope);
39        read_masked::expand::<T>(scope, in_bounds, slice, pos, mask_value)
40    }
41
42    fn __expand_read_unchecked_method(
43        &self,
44        scope: &mut Scope,
45        pos: ExpandElementTyped<u32>,
46    ) -> <T>::ExpandType {
47        <Self as ListExpand<T>>::__expand_read_unchecked_method(self, scope, pos)
48    }
49
50    fn __expand_to_linear_slice_method(
51        &self,
52        scope: &mut Scope,
53        pos: ExpandElementTyped<u32>,
54        end: ExpandElementTyped<u32>,
55    ) -> SliceExpand<T, ReadOnly> {
56        // Convert to exclusive end
57        let end = add::expand(scope, end, 1u32.into());
58        // Handling for shapes that are 0 in at least one dim, ensures the slice is not
59        // negative length.
60        let start = Min::__expand_min(scope, pos, end.clone());
61        <Self as SliceOperatorExpand<T>>::__expand_slice_method(self, scope, start, end)
62    }
63
64    fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> CubeOptionExpand<TensorMap<T>> {
65        CubeOption::__expand_new_None(scope)
66    }
67
68    fn __expand_shape_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
69        <Self as ListExpand<T>>::__expand_len_method(self, scope)
70    }
71
72    fn __expand_is_in_bounds_method(
73        &self,
74        scope: &mut Scope,
75        pos: ExpandElementTyped<u32>,
76    ) -> ExpandElementTyped<bool> {
77        let len = self.__expand_shape_method(scope);
78        lt::expand(scope, pos, len)
79    }
80
81    fn __expand_tensor_map_load_method(
82        &self,
83        _scope: &mut Scope,
84        _barrier: BarrierExpand,
85        _shared_memory: SliceExpand<T, ReadWrite>,
86        _pos: ExpandElementTyped<u32>,
87    ) {
88        unimplemented!("Not a tensor map");
89    }
90}
91
92impl<T: CubePrimitive> ViewOperationsMut<T, Coords1d> for Slice<T, ReadWrite> {}
93impl<T: CubePrimitive> ViewOperationsMutExpand<T, Coords1d> for SliceExpand<T, ReadWrite> {
94    fn __expand_write_method(
95        &self,
96        scope: &mut Scope,
97        pos: ExpandElementTyped<u32>,
98        value: <T>::ExpandType,
99    ) {
100        <Self as ListMutExpand<T>>::__expand_write_method(self, scope, pos, value)
101    }
102
103    fn __expand_write_checked_method(
104        &self,
105        scope: &mut Scope,
106        pos: ExpandElementTyped<u32>,
107        value: <T>::ExpandType,
108    ) {
109        let len = <Self as ListExpand<T>>::__expand_len_method(self, scope);
110        let in_bounds = lt::expand(scope, pos.clone(), len);
111        if_expand(scope, in_bounds.into(), |scope| {
112            <Self as ListMutExpand<T>>::__expand_write_method(self, scope, pos, value)
113        })
114    }
115
116    fn __expand_to_linear_slice_mut_method(
117        &self,
118        scope: &mut Scope,
119        pos: ExpandElementTyped<u32>,
120        end: ExpandElementTyped<u32>,
121    ) -> SliceExpand<T, ReadWrite> {
122        // Convert to exclusive end
123        let end = add::expand(scope, end, 1u32.into());
124        // Handling for shapes that are 0 in at least one dim, ensures the slice is not
125        // negative length.
126        let start = Min::__expand_min(scope, pos, end.clone());
127        <Self as SliceMutOperatorExpand<T>>::__expand_slice_mut_method(self, scope, start, end)
128    }
129
130    fn __expand_tensor_map_store_method(
131        &self,
132        _scope: &mut Scope,
133        _shared_memory: SliceExpand<T, ReadOnly>,
134        _pos: <Coords1d as CubeType>::ExpandType,
135    ) {
136        unimplemented!("Not a tensor map");
137    }
138}