1use std::cmp::Reverse;
2use std::fmt::{Debug, Formatter};
3use std::iter::zip;
4
5use itertools::{zip_eq, Itertools};
6use kn_cuda_sys::bindings::cudnnDataType_t;
7
8use kn_cuda_sys::wrapper::descriptor::{FilterDescriptor, MatrixLayout, TensorDescriptor};
9use kn_graph::graph::SliceRange;
10
/// A tensor shape paired with one stride per axis, plus two cached layout flags.
///
/// The flags are computed once in [`StridedShape::new`] from `shape` and `strides`,
/// so every constructor and transformation in this file goes through `new`.
#[derive(Clone, Eq, PartialEq)]
pub struct StridedShape {
    /// Size of each axis.
    shape: Vec<usize>,
    /// Stride of each axis. Signed: `flip` negates a stride and `broadcast` uses `0`.
    strides: Vec<isize>,
    /// Cached: `strides` is exactly `simple_strides(&shape)`.
    has_simple_strides: bool,
    /// Cached: result of the free function `has_dense_strides(&shape, &strides)`.
    has_dense_strides: bool,
}
18
/// Error returned by [`StridedShape::view`] when the requested shape cannot be
/// produced from the existing shape and strides.
#[derive(Clone, Eq, PartialEq)]
pub struct ViewError {
    /// The strided shape the view was attempted on.
    old: StridedShape,
    /// The requested target shape.
    new: Vec<usize>,
}
24
25impl StridedShape {
26 pub fn new(shape: Vec<usize>, strides: Vec<isize>) -> Self {
27 assert_eq!(shape.len(), strides.len(), "Shape and stride rank mismatch");
28
29 let has_simple_strides = &strides == &simple_strides(&shape);
30 let has_dense_strides = has_dense_strides(&shape, &strides);
31
32 if has_simple_strides {
33 assert!(
34 has_dense_strides,
35 "Simple should imply dense, {{ shape: {:?}, strides: {:?} }}",
36 shape, strides
37 );
38 }
39
40 let result = StridedShape {
41 shape,
42 strides,
43 has_simple_strides,
44 has_dense_strides,
45 };
46
47 result
48 }
49
50 pub fn new_simple(shape: Vec<usize>) -> Self {
51 let strides = simple_strides(&shape);
52 StridedShape::new(shape, strides)
53 }
54
/// Returns the size of every axis.
pub fn shape(&self) -> &[usize] {
    &self.shape
}
58
/// Returns the stride of every axis, in the same order as [`Self::shape`].
pub fn strides(&self) -> &[isize] {
    &self.strides
}
62
/// Returns the number of axes.
pub fn rank(&self) -> usize {
    self.shape.len()
}
66
/// Returns the flag cached at construction: whether the strides are exactly
/// the ones `simple_strides` computes for this shape.
pub fn has_simple_strides(&self) -> bool {
    self.has_simple_strides
}
70
/// Returns the flag cached at construction: the result of the free function
/// `has_dense_strides` for this shape and these strides.
pub fn has_dense_strides(&self) -> bool {
    self.has_dense_strides
}
74
75 pub fn visit_strided_indices(&self, mut f: impl FnMut(isize)) {
76 visit_strided_indices_impl(0, &self.shape, &self.strides, &mut f)
77 }
78
79 pub fn size(&self) -> usize {
80 self.shape.iter().copied().product()
81 }
82
83 pub fn slice(&self, axis: usize, range: SliceRange) -> StridedShape {
84 assert!(axis < self.rank(), "Rank {} out of bounds for {:?}", self.rank(), self);
85 range.assert_in_bounds(self.shape[axis]);
86
87 let mut new_shape = self.shape.clone();
88 let mut new_strides = self.strides.clone();
89
90 let SliceRange { start, end, step } = range;
91
92 new_shape[axis] = (end - start) / step;
93 new_strides[axis] *= step as isize;
94
95 StridedShape::new(new_shape, new_strides)
96 }
97
98 pub fn flip(&self, axis: usize) -> StridedShape {
99 let new_shape = self.shape.clone();
100 let mut new_strides = self.strides.clone();
101
102 new_strides[axis] *= -1;
104
105 StridedShape::new(new_shape, new_strides)
106 }
107
108 pub fn broadcast(&self, new_shape: Vec<usize>) -> StridedShape {
109 assert_eq!(
110 self.rank(),
111 new_shape.len(),
112 "Can only broadcast to same rank, got {:?} and {:?}",
113 self,
114 new_shape
115 );
116
117 let new_strides = (0..self.rank())
118 .map(|i| {
119 if new_shape[i] == self.shape[i] {
120 self.strides[i]
121 } else {
122 assert_eq!(
123 self.shape[i], 1,
124 "Broadcast mismatch between {:?} and {:?} at axis {}",
125 self, new_shape, i
126 );
127 0
128 }
129 })
130 .collect_vec();
131
132 StridedShape::new(new_shape, new_strides)
133 }
134
135 pub fn view(&self, new_shape: Vec<usize>) -> Result<StridedShape, ViewError> {
136 let new_size = new_shape.iter().copied().product::<usize>();
140 assert_eq!(
141 self.size(),
142 new_size,
143 "Size cannot change during view, cannot go from {:?} to {:?}",
144 self,
145 new_shape
146 );
147
148 if self.size() == 0 || self.rank() == 0 {
149 return Ok(StridedShape::new_simple(new_shape));
150 }
151
152 let mut new_strides = vec![0; new_shape.len()];
153 let mut next_d = 0;
154
155 let mut failed = false;
156
157 self.for_each_continuous_group(|group_size, group_stride| {
158 if failed {
159 return;
160 };
161
162 let mut left_group_size = group_size;
163 while left_group_size > 1 {
164 if left_group_size % new_shape[next_d] == 0 {
165 left_group_size /= new_shape[next_d];
166 new_strides[next_d] = left_group_size as isize * group_stride;
167 next_d += 1;
168 } else {
169 failed = true;
170 return;
171 }
172 }
173 });
174
175 if failed {
176 Err(ViewError {
177 old: self.clone(),
178 new: new_shape,
179 })
180 } else {
181 for d in next_d..new_shape.len() {
183 assert_eq!(new_shape[d], 1);
184 new_strides[d] = 1;
185 }
186
187 Ok(StridedShape::new(new_shape, new_strides))
188 }
189 }
190
191 fn for_each_continuous_group(&self, mut f: impl FnMut(usize, isize)) {
192 if self.size() == 0 || self.rank() == 0 {
193 f(0, 1);
194 return;
195 }
196
197 let mut group_size = 1;
198 let mut prev_stride = None;
199
200 for (&d_size, &d_stride) in zip_eq(&self.shape, &self.strides) {
201 if let Some(prev_stride) = prev_stride {
202 if prev_stride != d_size as isize * d_stride {
203 f(group_size, prev_stride);
205 group_size = 1;
206 }
207 }
208
209 group_size *= d_size;
210 prev_stride = Some(d_stride)
211 }
212
213 if let Some(prev_stride) = prev_stride {
214 f(group_size, prev_stride)
216 }
217 }
218
219 pub fn permute(&self, permutation: &[usize]) -> StridedShape {
220 assert_eq!(permutation.len(), self.rank());
221 assert!(permutation.iter().all_unique());
222
223 let new_shape = permutation.iter().map(|&i| self.shape()[i]).collect();
225 let new_strides = permutation.iter().map(|&i| self.strides()[i]).collect();
226
227 StridedShape::new(new_shape, new_strides)
228 }
229
230 pub fn repeat_unary(&self, axis: usize, count: usize) -> StridedShape {
231 assert!(axis < self.rank());
232 assert_eq!(self.shape[axis], 1);
233
234 let mut new_shape = self.shape.clone();
235 let mut new_strides = self.strides.clone();
236
237 new_shape[axis] = count;
238 new_strides[axis] = 0;
239
240 StridedShape::new(new_shape, new_strides)
241 }
242
243 pub fn descriptor(&self, dtype: cudnnDataType_t) -> TensorDescriptor {
244 let mut shape = self.shape.iter().map(|&x| x as i32).collect_vec();
245 let mut strides = self.strides.iter().map(|&x| x as i32).collect_vec();
246
247 while shape.len() < 4 {
250 shape.push(1);
251 strides.push(1);
252 }
253
254 TensorDescriptor::new(shape, strides, dtype)
255 }
256
257 pub fn filter_descriptor(&self, dtype: cudnnDataType_t) -> FilterDescriptor {
258 assert_eq!(4, self.rank(), "Filter must have rank 4");
259 assert!(self.has_simple_strides(), "Filter must have simple strides");
260
261 let dims = self.shape();
262 FilterDescriptor::new(dims[0] as i32, dims[1] as i32, dims[2] as i32, dims[3] as i32, dtype)
263 }
264
265 pub fn matrix_layout(&self) -> MatrixLayout {
266 assert_eq!(3, self.rank(), "Matrix must have rank 3");
267
268 let shape = [self.shape[0], self.shape[1], self.shape[2]];
269 let strides = [self.strides[0], self.strides[1], self.strides[2]];
270
271 MatrixLayout::new(shape, strides).unwrap_or_else(|| panic!("Failed to convert {:?} to MatrixLayout", self))
272 }
273
274 pub fn remove(&self, axis: usize) -> StridedShape {
275 assert!(axis < self.rank(), "Axis {} out of bounds for {:?}", axis, self);
276
277 let mut new_shape = self.shape.clone();
278 let mut new_strides = self.strides.clone();
279
280 new_shape.remove(axis);
281 new_strides.remove(axis);
282
283 StridedShape::new(new_shape, new_strides)
284 }
285}
286
287impl Debug for StridedShape {
288 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
289 f.debug_struct("StridedShape")
290 .field("shape", &self.shape)
291 .field("strides", &self.strides)
292 .finish()
293 }
294}
295
/// Computes the strides where the last axis has stride 1 and every earlier
/// axis has a stride equal to the product of all later axis sizes.
fn simple_strides(shape: &[usize]) -> Vec<isize> {
    let mut strides = vec![0isize; shape.len()];
    let mut running_product = 1usize;

    // Walk from the last axis to the first, filling each slot in place.
    for (slot, &size) in strides.iter_mut().zip(shape).rev() {
        *slot = running_product as isize;
        running_product *= size;
    }

    strides
}
308
309fn has_dense_strides(shape: &[usize], strides: &[isize]) -> bool {
312 assert_eq!(shape.len(), strides.len());
313
314 if shape.iter().copied().product::<usize>() == 0 {
315 return true;
316 }
317
318 let pairs = zip(shape.iter().copied(), strides.iter().copied().map(|x| x.abs()))
319 .sorted_by_key(|x| Reverse(x.1))
320 .collect_vec();
321
322 let sorted_shape = pairs.iter().map(|&x| x.0).collect_vec();
323 let sorted_strides = pairs.iter().map(|&x| x.1).collect_vec();
324
325 simple_strides(&sorted_shape) == sorted_strides
326}
327
/// Recursively visits every index combination of `shape`, calling `f` with
/// `start` plus the sum of `index * stride` over all axes. The last axis
/// varies fastest; an empty `shape` yields exactly one call with `start`.
fn visit_strided_indices_impl(start: isize, shape: &[usize], strides: &[isize], f: &mut impl FnMut(isize)) {
    if let Some((&outer_size, inner_shape)) = shape.split_first() {
        for index in 0..outer_size {
            // `strides` is only indexed when there is at least one element to visit.
            let offset = start + index as isize * strides[0];
            visit_strided_indices_impl(offset, inner_shape, &strides[1..], f);
        }
    } else {
        f(start);
    }
}
339
340impl Debug for ViewError {
341 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
342 write!(f, "Cannot view shape {:?} as {:?}", self.old, self.new)
343 }
344}
345
/// Unit tests for the layout flags, the continuous-group detection, `view` and `slice`.
#[cfg(test)]
mod test {
    use kn_graph::graph::SliceRange;

    use crate::shape::StridedShape;

    /// Flags for simple, dense-but-not-simple, and neither, using positive strides.
    #[test]
    fn properties_positive() {
        let simple = StridedShape::new(vec![2, 3], vec![3, 1]);
        assert!(simple.has_simple_strides);
        assert!(simple.has_dense_strides);

        // Transposed layout: not simple, but still dense.
        let dense = StridedShape::new(vec![3, 2], vec![1, 3]);
        assert!(!dense.has_simple_strides);
        assert!(dense.has_dense_strides);

        let neither = StridedShape::new(vec![3, 2], vec![8, 10]);
        assert!(!neither.has_simple_strides);
        assert!(!neither.has_dense_strides);
    }

    /// A negated stride is no longer simple, but dense ignores the sign.
    #[test]
    fn properties_negative() {
        let simple = StridedShape::new(vec![2, 3], vec![3, -1]);
        assert!(!simple.has_simple_strides);
        assert!(simple.has_dense_strides);
    }

    /// Collects the `(size, stride)` pairs reported by `for_each_continuous_group`
    /// into two parallel vectors.
    fn collect_groups(shape: &StridedShape) -> (Vec<usize>, Vec<isize>) {
        let mut sizes = vec![];
        let mut strides = vec![];
        shape.for_each_continuous_group(|group_size, group_stride| {
            sizes.push(group_size);
            strides.push(group_stride);
        });
        (sizes, strides)
    }

    /// A rank-0 shape reports the single `(0, 1)` group and can be viewed as all-ones.
    #[test]
    fn view_rank_zero() {
        let shape = StridedShape::new(vec![], vec![]);
        assert_eq!(collect_groups(&shape), (vec![0], vec![1]),);
        assert_eq!(
            shape.view(vec![1, 1, 1]),
            Ok(StridedShape::new(vec![1, 1, 1], vec![1, 1, 1])),
        );
    }

    /// A shape with zero elements views to simple strides regardless of its own strides.
    #[test]
    fn view_size_zero() {
        let shape = StridedShape::new(vec![2, 3, 0, 5], vec![0, 0, 0, 2]);
        assert_eq!(collect_groups(&shape), (vec![0], vec![1]));
        assert_eq!(shape.view(vec![0]), Ok(StridedShape::new(vec![0], vec![1])));
        assert_eq!(shape.view(vec![12, 0]), Ok(StridedShape::new(vec![12, 0], vec![0, 1])),);
    }

    /// Simple strides form one group, which can be flattened, re-split, and padded with ones.
    #[test]
    fn view_simple() {
        let shape = StridedShape::new(vec![2, 3, 4, 3, 2], vec![72, 24, 6, 2, 1]);
        assert!(shape.has_simple_strides());
        assert_eq!(collect_groups(&shape), (vec![144], vec![1]));
        assert_eq!(shape.view(vec![144]), Ok(StridedShape::new(vec![144], vec![1])),);
        assert_eq!(shape.view(vec![72, 2]), Ok(StridedShape::new(vec![72, 2], vec![2, 1])),);
        assert_eq!(
            shape.view(vec![72, 2, 1, 1, 1]),
            Ok(StridedShape::new(vec![72, 2, 1, 1, 1], vec![2, 1, 1, 1, 1])),
        );
    }

    /// With a gap before the last axis there are two groups; flattening across them fails.
    #[test]
    fn view_split() {
        let shape = StridedShape::new(vec![2, 3, 4], vec![24, 8, 1]);
        assert_eq!(collect_groups(&shape), (vec![6, 4], vec![8, 1]));
        assert_eq!(shape.view(vec![6, 4]), Ok(StridedShape::new(vec![6, 4], vec![8, 1])),);
        assert!(shape.view(vec![24]).is_err());
    }

    /// Slicing with step 2 shortens the axis and doubles its stride.
    #[test]
    fn slice_simple() {
        let shape = StridedShape::new(vec![2, 3, 4], vec![24, 8, 1]);
        assert_eq!(
            shape.slice(1, SliceRange::new(0, 4, 2)),
            StridedShape::new(vec![2, 2, 4], vec![24, 16, 1])
        )
    }
}