kn_cuda_eval/
shape.rs

1use std::cmp::Reverse;
2use std::fmt::{Debug, Formatter};
3use std::iter::zip;
4
5use itertools::{zip_eq, Itertools};
6use kn_cuda_sys::bindings::cudnnDataType_t;
7
8use kn_cuda_sys::wrapper::descriptor::{FilterDescriptor, MatrixLayout, TensorDescriptor};
9use kn_graph::graph::SliceRange;
10
/// A tensor shape together with the stride of every axis.
///
/// Besides the raw `shape` and `strides`, two derived properties are computed once in
/// [`StridedShape::new`] and cached, so querying them later is free.
#[derive(Clone, Eq, PartialEq)]
pub struct StridedShape {
    /// Size of each axis, outermost axis first.
    shape: Vec<usize>,
    /// Index step taken when moving one position along the matching axis; may be zero or negative.
    strides: Vec<isize>,
    /// Cached: whether `strides` is exactly what `simple_strides(&shape)` returns.
    has_simple_strides: bool,
    /// Cached: result of the free function `has_dense_strides(&shape, &strides)`.
    has_dense_strides: bool,
}
18
/// Error returned by [`StridedShape::view`] when the requested shape cannot be expressed
/// with strides over the existing layout.
#[derive(Clone, Eq, PartialEq)]
pub struct ViewError {
    /// The strided shape that was asked to be viewed differently.
    old: StridedShape,
    /// The requested shape that could not be satisfied.
    new: Vec<usize>,
}
24
25impl StridedShape {
26    pub fn new(shape: Vec<usize>, strides: Vec<isize>) -> Self {
27        assert_eq!(shape.len(), strides.len(), "Shape and stride rank mismatch");
28
29        let has_simple_strides = &strides == &simple_strides(&shape);
30        let has_dense_strides = has_dense_strides(&shape, &strides);
31
32        if has_simple_strides {
33            assert!(
34                has_dense_strides,
35                "Simple should imply dense, {{ shape: {:?}, strides: {:?} }}",
36                shape, strides
37            );
38        }
39
40        let result = StridedShape {
41            shape,
42            strides,
43            has_simple_strides,
44            has_dense_strides,
45        };
46
47        result
48    }
49
50    pub fn new_simple(shape: Vec<usize>) -> Self {
51        let strides = simple_strides(&shape);
52        StridedShape::new(shape, strides)
53    }
54
55    pub fn shape(&self) -> &[usize] {
56        &self.shape
57    }
58
59    pub fn strides(&self) -> &[isize] {
60        &self.strides
61    }
62
63    pub fn rank(&self) -> usize {
64        self.shape.len()
65    }
66
67    pub fn has_simple_strides(&self) -> bool {
68        self.has_simple_strides
69    }
70
71    pub fn has_dense_strides(&self) -> bool {
72        self.has_dense_strides
73    }
74
75    pub fn visit_strided_indices(&self, mut f: impl FnMut(isize)) {
76        visit_strided_indices_impl(0, &self.shape, &self.strides, &mut f)
77    }
78
79    pub fn size(&self) -> usize {
80        self.shape.iter().copied().product()
81    }
82
83    pub fn slice(&self, axis: usize, range: SliceRange) -> StridedShape {
84        assert!(axis < self.rank(), "Rank {} out of bounds for {:?}", self.rank(), self);
85        range.assert_in_bounds(self.shape[axis]);
86
87        let mut new_shape = self.shape.clone();
88        let mut new_strides = self.strides.clone();
89
90        let SliceRange { start, end, step } = range;
91
92        new_shape[axis] = (end - start) / step;
93        new_strides[axis] *= step as isize;
94
95        StridedShape::new(new_shape, new_strides)
96    }
97
98    pub fn flip(&self, axis: usize) -> StridedShape {
99        let new_shape = self.shape.clone();
100        let mut new_strides = self.strides.clone();
101
102        // just flip the stride of the axis
103        new_strides[axis] *= -1;
104
105        StridedShape::new(new_shape, new_strides)
106    }
107
108    pub fn broadcast(&self, new_shape: Vec<usize>) -> StridedShape {
109        assert_eq!(
110            self.rank(),
111            new_shape.len(),
112            "Can only broadcast to same rank, got {:?} and {:?}",
113            self,
114            new_shape
115        );
116
117        let new_strides = (0..self.rank())
118            .map(|i| {
119                if new_shape[i] == self.shape[i] {
120                    self.strides[i]
121                } else {
122                    assert_eq!(
123                        self.shape[i], 1,
124                        "Broadcast mismatch between {:?} and {:?} at axis {}",
125                        self, new_shape, i
126                    );
127                    0
128                }
129            })
130            .collect_vec();
131
132        StridedShape::new(new_shape, new_strides)
133    }
134
135    pub fn view(&self, new_shape: Vec<usize>) -> Result<StridedShape, ViewError> {
136        // implementation originally based on pytorch computeStride_impl:
137        // https://github.com/pytorch/pytorch/blob/560cd881956bbf425251d63f0ff0f9085a759447/aten/src/ATen/TensorUtils.cpp#L335-L346
138
139        let new_size = new_shape.iter().copied().product::<usize>();
140        assert_eq!(
141            self.size(),
142            new_size,
143            "Size cannot change during view, cannot go from {:?} to {:?}",
144            self,
145            new_shape
146        );
147
148        if self.size() == 0 || self.rank() == 0 {
149            return Ok(StridedShape::new_simple(new_shape));
150        }
151
152        let mut new_strides = vec![0; new_shape.len()];
153        let mut next_d = 0;
154
155        let mut failed = false;
156
157        self.for_each_continuous_group(|group_size, group_stride| {
158            if failed {
159                return;
160            };
161
162            let mut left_group_size = group_size;
163            while left_group_size > 1 {
164                if left_group_size % new_shape[next_d] == 0 {
165                    left_group_size /= new_shape[next_d];
166                    new_strides[next_d] = left_group_size as isize * group_stride;
167                    next_d += 1;
168                } else {
169                    failed = true;
170                    return;
171                }
172            }
173        });
174
175        if failed {
176            Err(ViewError {
177                old: self.clone(),
178                new: new_shape,
179            })
180        } else {
181            // complete the strides for trailing 1-sized dims
182            for d in next_d..new_shape.len() {
183                assert_eq!(new_shape[d], 1);
184                new_strides[d] = 1;
185            }
186
187            Ok(StridedShape::new(new_shape, new_strides))
188        }
189    }
190
191    fn for_each_continuous_group(&self, mut f: impl FnMut(usize, isize)) {
192        if self.size() == 0 || self.rank() == 0 {
193            f(0, 1);
194            return;
195        }
196
197        let mut group_size = 1;
198        let mut prev_stride = None;
199
200        for (&d_size, &d_stride) in zip_eq(&self.shape, &self.strides) {
201            if let Some(prev_stride) = prev_stride {
202                if prev_stride != d_size as isize * d_stride {
203                    //finish previous group
204                    f(group_size, prev_stride);
205                    group_size = 1;
206                }
207            }
208
209            group_size *= d_size;
210            prev_stride = Some(d_stride)
211        }
212
213        if let Some(prev_stride) = prev_stride {
214            //finish last group
215            f(group_size, prev_stride)
216        }
217    }
218
219    pub fn permute(&self, permutation: &[usize]) -> StridedShape {
220        assert_eq!(permutation.len(), self.rank());
221        assert!(permutation.iter().all_unique());
222
223        // just permute the shape and strides
224        let new_shape = permutation.iter().map(|&i| self.shape()[i]).collect();
225        let new_strides = permutation.iter().map(|&i| self.strides()[i]).collect();
226
227        StridedShape::new(new_shape, new_strides)
228    }
229
230    pub fn repeat_unary(&self, axis: usize, count: usize) -> StridedShape {
231        assert!(axis < self.rank());
232        assert_eq!(self.shape[axis], 1);
233
234        let mut new_shape = self.shape.clone();
235        let mut new_strides = self.strides.clone();
236
237        new_shape[axis] = count;
238        new_strides[axis] = 0;
239
240        StridedShape::new(new_shape, new_strides)
241    }
242
243    pub fn descriptor(&self, dtype: cudnnDataType_t) -> TensorDescriptor {
244        let mut shape = self.shape.iter().map(|&x| x as i32).collect_vec();
245        let mut strides = self.strides.iter().map(|&x| x as i32).collect_vec();
246
247        // tensor descriptors and some cudnn operations seem to break with ranks < 4,
248        //   so pad the rank until it's large enough
249        while shape.len() < 4 {
250            shape.push(1);
251            strides.push(1);
252        }
253
254        TensorDescriptor::new(shape, strides, dtype)
255    }
256
257    pub fn filter_descriptor(&self, dtype: cudnnDataType_t) -> FilterDescriptor {
258        assert_eq!(4, self.rank(), "Filter must have rank 4");
259        assert!(self.has_simple_strides(), "Filter must have simple strides");
260
261        let dims = self.shape();
262        FilterDescriptor::new(dims[0] as i32, dims[1] as i32, dims[2] as i32, dims[3] as i32, dtype)
263    }
264
265    pub fn matrix_layout(&self) -> MatrixLayout {
266        assert_eq!(3, self.rank(), "Matrix must have rank 3");
267
268        let shape = [self.shape[0], self.shape[1], self.shape[2]];
269        let strides = [self.strides[0], self.strides[1], self.strides[2]];
270
271        MatrixLayout::new(shape, strides).unwrap_or_else(|| panic!("Failed to convert {:?} to MatrixLayout", self))
272    }
273
274    pub fn remove(&self, axis: usize) -> StridedShape {
275        assert!(axis < self.rank(), "Axis {} out of bounds for {:?}", axis, self);
276
277        let mut new_shape = self.shape.clone();
278        let mut new_strides = self.strides.clone();
279
280        new_shape.remove(axis);
281        new_strides.remove(axis);
282
283        StridedShape::new(new_shape, new_strides)
284    }
285}
286
287impl Debug for StridedShape {
288    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
289        f.debug_struct("StridedShape")
290            .field("shape", &self.shape)
291            .field("strides", &self.strides)
292            .finish()
293    }
294}
295
/// Computes the strides of a densely packed layout where the last axis varies fastest:
/// the last axis gets stride 1 and each earlier axis gets the product of all later sizes.
fn simple_strides(shape: &[usize]) -> Vec<isize> {
    let mut strides = vec![0isize; shape.len()];
    let mut running = 1;

    // walk from the innermost axis outwards, filling each slot in place
    for (slot, &extent) in strides.iter_mut().zip(shape).rev() {
        *slot = running as isize;
        running *= extent;
    }

    strides
}
308
309/// Whether the given shape covers every value within its data range.
310/// This is equivalent to asking whether any possible permutation of the shape with abs strides has simple strides.
311fn has_dense_strides(shape: &[usize], strides: &[isize]) -> bool {
312    assert_eq!(shape.len(), strides.len());
313
314    if shape.iter().copied().product::<usize>() == 0 {
315        return true;
316    }
317
318    let pairs = zip(shape.iter().copied(), strides.iter().copied().map(|x| x.abs()))
319        .sorted_by_key(|x| Reverse(x.1))
320        .collect_vec();
321
322    let sorted_shape = pairs.iter().map(|&x| x.0).collect_vec();
323    let sorted_strides = pairs.iter().map(|&x| x.1).collect_vec();
324
325    simple_strides(&sorted_shape) == sorted_strides
326}
327
/// Recursively walks every position of `shape`, last axis fastest, and calls `f` with
/// `start` plus the sum of `position * stride` over all axes.
fn visit_strided_indices_impl(start: isize, shape: &[usize], strides: &[isize], f: &mut impl FnMut(isize)) {
    match shape.split_first() {
        // no axes left: `start` is the index of a single element
        None => f(start),
        Some((&extent, inner_shape)) => {
            for position in 0..extent {
                let offset = start + position as isize * strides[0];
                visit_strided_indices_impl(offset, inner_shape, &strides[1..], f);
            }
        }
    }
}
339
340impl Debug for ViewError {
341    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
342        write!(f, "Cannot view shape {:?} as {:?}", self.old, self.new)
343    }
344}
345
#[cfg(test)]
mod test {
    use kn_graph::graph::SliceRange;

    use crate::shape::StridedShape;

    /// Simple, dense-but-permuted and gapped layouts with positive strides.
    #[test]
    fn properties_positive() {
        let row_major = StridedShape::new(vec![2, 3], vec![3, 1]);
        assert!(row_major.has_simple_strides());
        assert!(row_major.has_dense_strides());

        let transposed = StridedShape::new(vec![3, 2], vec![1, 3]);
        assert!(!transposed.has_simple_strides());
        assert!(transposed.has_dense_strides());

        let gapped = StridedShape::new(vec![3, 2], vec![8, 10]);
        assert!(!gapped.has_simple_strides());
        assert!(!gapped.has_dense_strides());
    }

    /// A negative stride is never simple, but can still be dense.
    #[test]
    fn properties_negative() {
        let flipped = StridedShape::new(vec![2, 3], vec![3, -1]);
        assert!(!flipped.has_simple_strides());
        assert!(flipped.has_dense_strides());
    }

    /// Records every group reported by `for_each_continuous_group` as parallel
    /// vectors of group sizes and group strides.
    fn collect_groups(shape: &StridedShape) -> (Vec<usize>, Vec<isize>) {
        let mut groups = vec![];
        shape.for_each_continuous_group(|size, stride| groups.push((size, stride)));
        groups.into_iter().unzip()
    }

    #[test]
    fn view_rank_zero() {
        let scalar = StridedShape::new(vec![], vec![]);
        assert_eq!(collect_groups(&scalar), (vec![0], vec![1]));

        let expected = StridedShape::new(vec![1, 1, 1], vec![1, 1, 1]);
        assert_eq!(scalar.view(vec![1, 1, 1]), Ok(expected));
    }

    #[test]
    fn view_size_zero() {
        let empty = StridedShape::new(vec![2, 3, 0, 5], vec![0, 0, 0, 2]);
        assert_eq!(collect_groups(&empty), (vec![0], vec![1]));

        let flat = StridedShape::new(vec![0], vec![1]);
        assert_eq!(empty.view(vec![0]), Ok(flat));

        let two_axes = StridedShape::new(vec![12, 0], vec![0, 1]);
        assert_eq!(empty.view(vec![12, 0]), Ok(two_axes));
    }

    #[test]
    fn view_simple() {
        let packed = StridedShape::new(vec![2, 3, 4, 3, 2], vec![72, 24, 6, 2, 1]);
        assert!(packed.has_simple_strides());
        assert_eq!(collect_groups(&packed), (vec![144], vec![1]));

        let flat = StridedShape::new(vec![144], vec![1]);
        assert_eq!(packed.view(vec![144]), Ok(flat));

        let pairs = StridedShape::new(vec![72, 2], vec![2, 1]);
        assert_eq!(packed.view(vec![72, 2]), Ok(pairs));

        let padded = StridedShape::new(vec![72, 2, 1, 1, 1], vec![2, 1, 1, 1, 1]);
        assert_eq!(packed.view(vec![72, 2, 1, 1, 1]), Ok(padded));
    }

    #[test]
    fn view_split() {
        let gapped = StridedShape::new(vec![2, 3, 4], vec![24, 8, 1]);
        assert_eq!(collect_groups(&gapped), (vec![6, 4], vec![8, 1]));

        let merged = StridedShape::new(vec![6, 4], vec![8, 1]);
        assert_eq!(gapped.view(vec![6, 4]), Ok(merged));

        // the two groups cannot be merged into a single axis
        assert!(gapped.view(vec![24]).is_err());
    }

    #[test]
    fn slice_simple() {
        let original = StridedShape::new(vec![2, 3, 4], vec![24, 8, 1]);
        let sliced = original.slice(1, SliceRange::new(0, 4, 2));
        let expected = StridedShape::new(vec![2, 2, 4], vec![24, 16, 1]);
        assert_eq!(sliced, expected);
    }
}