Skip to main content

cubecl_std/tensor/view/operations/
slice.rs

1use super::*;
2use crate::tensor::layout::Coords1d;
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, io::read_masked, prelude::barrier::BarrierExpand};
5
6impl<T: CubePrimitive, IO: SliceVisibility> ViewOperations<T, Coords1d> for Slice<T, IO> {}
7impl<T: CubePrimitive, IO: SliceVisibility> ViewOperationsExpand<T, Coords1d>
8    for SliceExpand<T, IO>
9{
10    fn __expand_read_method(&self, scope: &mut Scope, pos: NativeExpand<usize>) -> <T>::ExpandType {
11        <Self as ListExpand<T>>::__expand_read_method(self, scope, pos)
12    }
13
14    fn __expand_read_checked_method(
15        &self,
16        scope: &mut Scope,
17        pos: NativeExpand<usize>,
18    ) -> <T>::ExpandType {
19        let len = self.__expand_len_method(scope);
20        let in_bounds = lt::expand(scope, pos.clone(), len);
21        let slice = self.clone().__expand_to_slice_method(scope);
22        let zero = T::__expand_cast_from(scope, 0.into());
23        read_masked::expand::<T>(scope, in_bounds, slice, pos, zero)
24    }
25
26    fn __expand_read_masked_method(
27        &self,
28        scope: &mut Scope,
29        pos: NativeExpand<usize>,
30        mask_value: <T>::ExpandType,
31    ) -> <T>::ExpandType {
32        let len = self.__expand_len_method(scope);
33        let in_bounds = lt::expand(scope, pos.clone(), len);
34        let slice = self.clone().__expand_to_slice_method(scope);
35        read_masked::expand::<T>(scope, in_bounds, slice, pos, mask_value)
36    }
37
38    fn __expand_read_unchecked_method(
39        &self,
40        scope: &mut Scope,
41        pos: NativeExpand<usize>,
42    ) -> <T>::ExpandType {
43        <Self as ListExpand<T>>::__expand_read_unchecked_method(self, scope, pos)
44    }
45
46    fn __expand_to_linear_slice_method(
47        &self,
48        scope: &mut Scope,
49        pos: NativeExpand<usize>,
50        end: NativeExpand<usize>,
51    ) -> SliceExpand<T, ReadOnly> {
52        // Convert to exclusive end
53        let end = add::expand(scope, end, 1usize.into());
54        // Handling for shapes that are 0 in at least one dim, ensures the slice is not
55        // negative length.
56        let start = clamp_max::expand(scope, pos, end.clone());
57        <Self as SliceOperatorExpand<T>>::__expand_slice_method(self, scope, start, end)
58    }
59
60    fn __expand_shape_method(&self, scope: &mut Scope) -> NativeExpand<usize> {
61        <Self as ListExpand<T>>::__expand_len_method(self, scope)
62    }
63
64    fn __expand_is_in_bounds_method(
65        &self,
66        scope: &mut Scope,
67        pos: NativeExpand<usize>,
68    ) -> NativeExpand<bool> {
69        let len = self.__expand_shape_method(scope);
70        lt::expand(scope, pos, len)
71    }
72
73    fn __expand_tensor_map_load_method(
74        &self,
75        _scope: &mut Scope,
76        _barrier: BarrierExpand,
77        _shared_memory: SliceExpand<T, ReadWrite>,
78        _pos: NativeExpand<usize>,
79    ) {
80        unimplemented!("Not a tensor map");
81    }
82}
83
84impl<T: CubePrimitive> ViewOperationsMut<T, Coords1d> for Slice<T, ReadWrite> {}
85impl<T: CubePrimitive> ViewOperationsMutExpand<T, Coords1d> for SliceExpand<T, ReadWrite> {
86    fn __expand_write_method(
87        &self,
88        scope: &mut Scope,
89        pos: NativeExpand<usize>,
90        value: <T>::ExpandType,
91    ) {
92        <Self as ListMutExpand<T>>::__expand_write_method(self, scope, pos, value)
93    }
94
95    fn __expand_write_checked_method(
96        &self,
97        scope: &mut Scope,
98        pos: NativeExpand<usize>,
99        value: <T>::ExpandType,
100    ) {
101        let len = <Self as ListExpand<T>>::__expand_len_method(self, scope);
102        let in_bounds = lt::expand(scope, pos.clone(), len);
103        if_expand(scope, in_bounds, |scope| {
104            <Self as ListMutExpand<T>>::__expand_write_method(self, scope, pos, value)
105        })
106    }
107
108    fn __expand_to_linear_slice_mut_method(
109        &self,
110        scope: &mut Scope,
111        pos: NativeExpand<usize>,
112        end: NativeExpand<usize>,
113    ) -> SliceExpand<T, ReadWrite> {
114        // Convert to exclusive end
115        let end = add::expand(scope, end, 1usize.into());
116        // Handling for shapes that are 0 in at least one dim, ensures the slice is not
117        // negative length.
118        let start = clamp_max::expand(scope, pos, end.clone());
119        <Self as SliceMutOperatorExpand<T>>::__expand_slice_mut_method(self, scope, start, end)
120    }
121
122    fn __expand_tensor_map_store_method(
123        &self,
124        _scope: &mut Scope,
125        _shared_memory: SliceExpand<T, ReadOnly>,
126        _pos: <Coords1d as CubeType>::ExpandType,
127    ) {
128        unimplemented!("Not a tensor map");
129    }
130}