1use super::{Slice, SliceArg};
4use alloc::vec::Vec;
5use core::ops::Range;
6
7pub use crate::errors::ExpressionError;
8
9pub use cubecl_zspace::{MetadataError, Shape, calculate_matmul_output, shape};
10
11pub trait SliceOps: Sized {
13 fn into_ranges(self) -> Vec<Range<usize>>;
15 fn into_slices<S>(self, slices: S) -> Vec<Slice>
75 where
76 S: SliceArg;
77 fn slice(self, slices: &[Slice]) -> Result<Self, MetadataError>;
79}
80
81impl SliceOps for Shape {
82 fn into_ranges(self) -> Vec<Range<usize>> {
83 self.iter().map(|&d| 0..d).collect()
84 }
85
86 fn into_slices<S>(self, slices: S) -> Vec<Slice>
87 where
88 S: SliceArg,
89 {
90 slices.into_slices(&self)
91 }
92
93 fn slice(mut self, slices: &[Slice]) -> Result<Self, MetadataError> {
94 if slices.len() > self.rank() {
95 return Err(MetadataError::RankMismatch {
96 left: self.rank(),
97 right: slices.len(),
98 });
99 }
100
101 slices
102 .iter()
103 .zip(self.iter_mut())
104 .for_each(|(slice, dim_size)| *dim_size = slice.output_size(*dim_size));
105
106 Ok(self)
107 }
108}
109
// Unit tests for `SliceOps` on `Shape`.
#[cfg(test)]
#[allow(clippy::identity_op, reason = "useful for clarity")]
mod tests {
    use super::*;
    use crate::s;
    use alloc::vec;

    #[test]
    fn test_into_ranges() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(shape.into_ranges(), vec![0..2, 0..3, 0..4, 0..5]);
    }

    #[allow(clippy::single_range_in_vec_init)]
    #[test]
    fn test_into_slices() {
        let slices = Shape::new([3]).into_slices(1..4);
        assert_eq!(slices[0].to_range(3), 1..3);

        let slices = Shape::new([3, 4]).into_slices(s![1..4, 0..2]);
        assert_eq!(slices[0].to_range(3), 1..3);
        assert_eq!(slices[1].to_range(4), 0..2);

        let slices = Shape::new([3]).into_slices(..-2);
        assert_eq!(slices[0].to_range(3), 0..1);

        let slices = Shape::new([2, 3, 4]).into_slices(s![.., 1..-1]);
        assert_eq!(slices[0].to_range(2), 0..2);
        assert_eq!(slices[1].to_range(3), 1..2);

        let slices = Shape::new([2, 3, 4]).into_slices(s![..20, 2]);
        assert_eq!(slices[0].to_range(2), 0..2);
        assert_eq!(slices[1].to_range(3), 2..3);
    }

    #[test]
    fn test_shape_as_slice() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);

        assert_eq!(shape.as_slice(), dims.as_slice());

        // `Shape` coerces to `&[usize]` through deref.
        let shape_slice: &[usize] = &shape;
        assert_eq!(shape_slice, [2, 3, 4, 5]);
    }

    #[test]
    fn test_shape_as_mut_slice() {
        let mut dims = [2, 3, 4, 5];
        let mut shape = Shape::new(dims);

        let shape_mut = shape.as_mut_slice();
        assert_eq!(shape_mut, dims.as_mut_slice());
        shape_mut[1] = 6;

        assert_eq!(shape_mut, &[2, 6, 4, 5]);

        let mut shape = Shape::new(dims);
        let shape = &mut shape[..];
        shape[1] = 6;

        assert_eq!(shape, shape_mut)
    }

    #[test]
    fn test_shape_slice_output_shape_basic() {
        let slices = [
            Slice::new(0, Some(5), 1),
            Slice::new(2, Some(8), 1),
        ];
        let original_shape = Shape::new([10, 10, 10]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([5, 6, 10]));
    }

    #[test]
    fn test_shape_slice_output_shape_with_positive_steps() {
        let slices = [
            Slice::new(0, Some(10), 2),
            Slice::new(1, Some(9), 3),
            Slice::new(0, Some(7), 4),
        ];
        let original_shape = Shape::new([20, 20, 20, 30]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([5, 3, 2, 30]));
    }

    #[test]
    fn test_shape_slice_output_shape_with_negative_steps() {
        let slices = [
            Slice::new(0, Some(10), -1),
            Slice::new(2, Some(8), -2),
        ];
        let original_shape = Shape::new([20, 20, 20]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([10, 3, 20]));
    }

    #[test]
    fn test_shape_slice_output_shape_mixed_steps() {
        let slices = [
            Slice::from_range_stepped(1..6, 1),
            Slice::from_range_stepped(0..10, -3),
            Slice::from_range_stepped(2..14, 4),
        ];
        let original_shape = Shape::new([20, 20, 20]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([5, 4, 3]));
    }

    #[test]
    fn test_shape_slice_output_shape_partial_dims() {
        let slices = [
            Slice::from_range_stepped(2..7, 2),
        ];
        let original_shape = Shape::new([10, 20, 30, 40]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([3, 20, 30, 40]));
    }

    #[test]
    fn test_shape_slice_output_shape_edge_cases() {
        let slices = [
            Slice::from_range_stepped(0..1, 1),
            Slice::from_range_stepped(0..10, 100),
            Slice::from_range_stepped(5..5, 1),
        ];
        let original_shape = Shape::new([10, 20, 30]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([1, 1, 0]));
    }

    #[test]
    fn test_shape_slice_output_shape_empty() {
        let slices = [];
        let original_shape = Shape::new([10, 20, 30]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([10, 20, 30]));
    }

    #[test]
    fn test_shape_slice_output_shape_uneven_division() {
        let slices = [
            Slice::from_range_stepped(0..7, 3),
            Slice::from_range_stepped(0..11, 4),
            Slice::from_range_stepped(1..10, 5),
        ];
        let original_shape = Shape::new([20, 20, 20]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([3, 3, 2]));
    }

    #[test]
    fn test_shape_slice_rank_mismatch() {
        // Three slices cannot be applied to a rank-2 shape.
        let slices = [
            Slice::from_range_stepped(0..1, 1),
            Slice::from_range_stepped(0..1, 1),
            Slice::from_range_stepped(0..1, 1),
        ];
        let original_shape = Shape::new([10, 20]);
        let result = original_shape.slice(&slices);
        assert!(matches!(
            result,
            Err(MetadataError::RankMismatch { left: 2, right: 3 })
        ));
    }
}