Skip to main content

burn_std/tensor/
shape.rs

1//! Tensor shape definition.
2
3use super::{Slice, SliceArg};
4use alloc::vec::Vec;
5use core::ops::Range;
6
7pub use crate::errors::ExpressionError;
8
9pub use cubecl_zspace::{MetadataError, Shape, calculate_matmul_output, shape};
10
11/// Slice-related ops on [`Shape`]
12pub trait SliceOps: Sized {
13    /// Convert shape dimensions to full covering ranges (0..dim) for each dimension.
14    fn into_ranges(self) -> Vec<Range<usize>>;
15    /// Converts slice arguments into an array of slice specifications for the shape.
16    ///
17    /// This method returns an array of `Slice` objects that can be used for slicing operations.
18    /// The slices are clamped to the shape's dimensions. Similar to `into_ranges()`, but
19    /// allows custom slice specifications instead of full ranges.
20    /// For creating complex slice specifications, use the [`s!`] macro.
21    ///
22    /// # Arguments
23    ///
24    /// * `slices` - An array of slice specifications, where each element can be:
25    ///   - A range (e.g., `2..5`)
26    ///   - An index
27    ///   - A `Slice` object
28    ///   - The output of the [`s!`] macro for advanced slicing
29    ///
30    /// # Behavior
31    ///
32    /// - Supports partial and full slicing in any number of dimensions.
33    /// - Missing ranges are treated as full slices if D > D2.
34    /// - Handles negative indices by wrapping around from the end of the dimension.
35    /// - Clamps ranges to the shape's dimensions if they exceed the bounds.
36    ///
37    /// # Returns
38    ///
39    /// An array of `Slice` objects corresponding to the provided slice specifications,
40    /// clamped to the shape's actual dimensions.
41    ///
42    /// # Examples
43    ///
44    /// ```rust
45    /// use burn_std::{Shape, Slice, s, SliceOps};
46    ///
47    /// fn example() {
48    ///     // 1D slicing
49    ///     let slices = Shape::new([4]).into_slices(1..4);
50    ///     assert_eq!(slices[0].to_range(4), 1..3);
51    ///
52    ///     // 2D slicing
53    ///     let slices = Shape::new([3, 4]).into_slices(s![1..4, 0..2]);
54    ///     assert_eq!(slices[0].to_range(3), 1..3);
55    ///     assert_eq!(slices[1].to_range(4), 0..2);
56    ///
57    ///     // Using negative indices
58    ///     let slices = Shape::new([3]).into_slices(..-2);
59    ///     assert_eq!(slices[0].to_range(3), 0..1);
60    ///
61    ///     // Using the slice macro to select different ranges
62    ///     let slices = Shape::new([2, 3, 4]).into_slices(s![.., 1..-1]);
63    ///     assert_eq!(slices[0].to_range(2), 0..2);
64    ///     assert_eq!(slices[1].to_range(3), 1..2);
65    /// }
66    /// ```
67    ///
68    /// # See Also
69    ///
70    /// - [`s!`] - The recommended macro for creating slice specifications
71    /// - [`Shape::into_ranges`] - Convert to full covering ranges
72    ///
73    /// [`s!`]: crate::s!
74    fn into_slices<S>(self, slices: S) -> Vec<Slice>
75    where
76        S: SliceArg;
77    /// Compute the output shape from the given slices.
78    fn slice(self, slices: &[Slice]) -> Result<Self, MetadataError>;
79}
80
81impl SliceOps for Shape {
82    fn into_ranges(self) -> Vec<Range<usize>> {
83        self.iter().map(|&d| 0..d).collect()
84    }
85
86    fn into_slices<S>(self, slices: S) -> Vec<Slice>
87    where
88        S: SliceArg,
89    {
90        slices.into_slices(&self)
91    }
92
93    fn slice(mut self, slices: &[Slice]) -> Result<Self, MetadataError> {
94        if slices.len() > self.rank() {
95            return Err(MetadataError::RankMismatch {
96                left: self.rank(),
97                right: slices.len(),
98            });
99        }
100
101        slices
102            .iter()
103            .zip(self.iter_mut())
104            .for_each(|(slice, dim_size)| *dim_size = slice.output_size(*dim_size));
105
106        Ok(self)
107    }
108}
109
/// Unit tests for [`SliceOps`] on [`Shape`] and for the slice views of a shape.
#[cfg(test)]
#[allow(clippy::identity_op, reason = "useful for clarity")]
mod tests {
    use super::*;
    use crate::s;
    use alloc::vec;

    #[test]
    fn test_into_ranges() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(shape.into_ranges(), vec![0..2, 0..3, 0..4, 0..5]);
    }

    #[allow(clippy::single_range_in_vec_init)]
    #[test]
    fn test_into_slices() {
        // Range end beyond the dimension is clamped
        let slices = Shape::new([3]).into_slices(1..4);
        assert_eq!(slices[0].to_range(3), 1..3);

        let slices = Shape::new([3, 4]).into_slices(s![1..4, 0..2]);
        assert_eq!(slices[0].to_range(3), 1..3);
        assert_eq!(slices[1].to_range(4), 0..2);

        // Negative end index counts from the end of the dimension
        let slices = Shape::new([3]).into_slices(..-2);
        assert_eq!(slices[0].to_range(3), 0..1);

        let slices = Shape::new([2, 3, 4]).into_slices(s![.., 1..-1]);
        assert_eq!(slices[0].to_range(2), 0..2);
        assert_eq!(slices[1].to_range(3), 1..2);

        // Out-of-bounds end and a plain index
        let slices = Shape::new([2, 3, 4]).into_slices(s![..20, 2]);
        assert_eq!(slices[0].to_range(2), 0..2);
        assert_eq!(slices[1].to_range(3), 2..3);
    }

    #[test]
    fn test_shape_as_slice() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);

        assert_eq!(shape.as_slice(), dims.as_slice());

        // Deref coercion
        let shape_slice: &[usize] = &shape;
        assert_eq!(shape_slice, [2, 3, 4, 5]);
    }

    #[test]
    fn test_shape_as_mut_slice() {
        let mut dims = [2, 3, 4, 5];
        let mut shape = Shape::new(dims);

        let shape_mut = shape.as_mut_slice();
        assert_eq!(shape_mut, dims.as_mut_slice());
        shape_mut[1] = 6;

        assert_eq!(shape_mut, &[2, 6, 4, 5]);

        // Mutating through an index-range view must give the same result
        let mut shape = Shape::new(dims);
        let shape = &mut shape[..];
        shape[1] = 6;

        assert_eq!(shape, shape_mut);
    }

    #[test]
    fn test_shape_slice_output_shape_basic() {
        // Test basic slicing with step=1
        let slices = [
            Slice::new(0, Some(5), 1), // 5 elements
            Slice::new(2, Some(8), 1), // 6 elements
        ];
        let original_shape = Shape::new([10, 10, 10]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([5, 6, 10]));
    }

    #[test]
    fn test_shape_slice_output_shape_with_positive_steps() {
        // Test slicing with various positive steps
        let slices = [
            Slice::new(0, Some(10), 2), // [0,2,4,6,8] -> 5 elements
            Slice::new(1, Some(9), 3),  // [1,4,7] -> 3 elements
            Slice::new(0, Some(7), 4),  // [0,4] -> 2 elements
        ];
        let original_shape = Shape::new([20, 20, 20, 30]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([5, 3, 2, 30]));
    }

    #[test]
    fn test_shape_slice_output_shape_with_negative_steps() {
        // Test slicing with negative steps (backward iteration)
        let slices = [
            Slice::new(0, Some(10), -1), // 10 elements traversed backward
            Slice::new(2, Some(8), -2),  // [7,5,3] -> 3 elements
        ];
        let original_shape = Shape::new([20, 20, 20]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([10, 3, 20]));
    }

    #[test]
    fn test_shape_slice_output_shape_mixed_steps() {
        // Test with a mix of positive, negative, and unit steps
        let slices = [
            Slice::from_range_stepped(1..6, 1),   // 5 elements
            Slice::from_range_stepped(0..10, -3), // [9,6,3,0] -> 4 elements
            Slice::from_range_stepped(2..14, 4),  // [2,6,10] -> 3 elements
        ];
        let original_shape = Shape::new([20, 20, 20]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([5, 4, 3]));
    }

    #[test]
    fn test_shape_slice_output_shape_partial_dims() {
        // Test when slices has fewer dimensions than original shape
        let slices = [
            Slice::from_range_stepped(2..7, 2), // [2,4,6] -> 3 elements
        ];
        let original_shape = Shape::new([10, 20, 30, 40]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([3, 20, 30, 40]));
    }

    #[test]
    fn test_shape_slice_output_shape_edge_cases() {
        // Test edge cases with small ranges and large steps
        let slices = [
            Slice::from_range_stepped(0..1, 1),    // Single element
            Slice::from_range_stepped(0..10, 100), // Step larger than range -> 1 element
            Slice::from_range_stepped(5..5, 1),    // Empty range -> 0 elements
        ];
        let original_shape = Shape::new([10, 20, 30]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([1, 1, 0]));
    }

    #[test]
    fn test_shape_slice_output_shape_empty() {
        // Test with no slice infos (should return original shape)
        let slices = [];
        let original_shape = Shape::new([10, 20, 30]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([10, 20, 30]));
    }

    #[test]
    fn test_shape_slice_output_shape_uneven_division() {
        // Test cases where range size doesn't divide evenly by step
        let slices = [
            Slice::from_range_stepped(0..7, 3), // ceil(7/3) = 3 elements: [0,3,6]
            Slice::from_range_stepped(0..11, 4), // ceil(11/4) = 3 elements: [0,4,8]
            Slice::from_range_stepped(1..10, 5), // ceil(9/5) = 2 elements: [1,6]
        ];
        let original_shape = Shape::new([20, 20, 20]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([3, 3, 2]));
    }

    #[test]
    fn test_shape_slice_rank_mismatch() {
        // More slices than dimensions must be rejected instead of silently ignored
        let slices = [Slice::new(0, Some(1), 1), Slice::new(0, Some(1), 1)];
        let result = Shape::new([4]).slice(&slices);
        assert!(matches!(
            result,
            Err(MetadataError::RankMismatch { left: 1, right: 2 })
        ));
    }
}
271}