kn_graph/
shape.rs

1use std::convert::TryInto;
2use std::fmt::{Debug, Display, Formatter};
3use std::ops::ControlFlow;
4
5use itertools::Itertools;
6
/// Build a `Shape` from a comma-separated list of dimension expressions.
///
/// Every element is converted with `Size::from`, so plain `usize` values and `Size` values can be mixed.
/// A single `*` token in front of an element and a trailing comma are both accepted by the pattern.
#[macro_export]
macro_rules! shape {
    [$($(*)? $dim:expr),* $(,)?] => {
        $crate::shape::Shape::new(vec![
            $($crate::shape::Size::from($dim)),*
        ])
    };
}
13
14/// A shape with each dimension corresponding to [Size].
15///
16/// Use one of the constructor [Shape::new], one of the utilities [Shape::single], [Shape::fixed], [Shape::ones] or
17/// the the `shape!` macro to conveniently construct one:
18/// ```
19/// # use kn_graph::shape;
20/// # use kn_graph::shape::{Shape, Size};
21/// // these are all equivalent
22/// shape![Size::BATCH, 16, 8, 8];
23/// Shape::new(vec![Size::BATCH, 16.into(), 8.into(), 8.into()]);
24/// Shape::new(vec![Size::BATCH, Size::fixed(16), Size::fixed(8), Size::fixed(8)]);
25/// ```
26#[derive(Clone, Eq, PartialEq, Hash)]
27pub struct Shape {
28    pub dims: Vec<Size>,
29}
30
/// A symbolic size of the form `fixed_factor * pow(batch_size, batch_exp)`.
///
/// With `batch_exp == 0` this is a plain fixed size; with a higher exponent it scales with the
/// corresponding power of the batch size.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct Size {
    /// The power to which the batch size is raised.
    batch_exp: u32,
    /// The constant multiplier.
    fixed_factor: usize,
}
39
/// A fully evaluated shape: every dimension is a plain `usize`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ConcreteShape {
    /// The size of each axis, outermost axis first.
    pub dims: Vec<usize>,
}
45
46// TODO unify all shape types into a generic "Shape<S>"
47//  what about strides?
48impl Shape {
49    pub const SCALAR: Shape = Shape { dims: vec![] };
50
51    pub fn new(dims: Vec<Size>) -> Shape {
52        Shape { dims }
53    }
54
55    pub fn single(size: Size) -> Shape {
56        Shape { dims: vec![size] }
57    }
58
59    pub fn fixed(dims: &[usize]) -> Shape {
60        let dims = dims.iter().map(|&d| Size::fixed(d)).collect_vec();
61        Shape { dims }
62    }
63
64    pub fn ones(rank: usize) -> Shape {
65        Shape::new(vec![Size::ONE; rank])
66    }
67
68    pub fn zeros(rank: usize) -> Shape {
69        Shape::new(vec![Size::ZERO; rank])
70    }
71
72    pub fn rank(&self) -> usize {
73        self.dims.len()
74    }
75
76    pub fn assert_has_axis(&self, axis: usize) {
77        assert!(axis < self.rank(), "Axis {} out of bounds for {:?}", axis, self);
78    }
79
80    pub fn as_fixed(&self) -> Option<ConcreteShape> {
81        self.dims
82            .iter()
83            .map(|d| d.try_unwrap_fixed().ok_or(()))
84            .try_collect()
85            .ok()
86            .map(ConcreteShape::new)
87    }
88
89    pub fn unwrap_fixed(&self, what: &str) -> ConcreteShape {
90        let dims = self.dims.iter().map(|d| d.unwrap_fixed(what)).collect_vec();
91        ConcreteShape { dims }
92    }
93
94    pub fn eval(&self, batch_size: usize) -> ConcreteShape {
95        let dims = self.dims.iter().map(|d| d.eval(batch_size)).collect_vec();
96        ConcreteShape { dims }
97    }
98
99    pub fn size(&self) -> Size {
100        self.dims.iter().copied().product()
101    }
102
103    pub fn unwrap_1(&self) -> Size {
104        assert_eq!(1, self.dims.len(), "Expected rank 1 shape");
105        self.dims[0]
106    }
107
108    pub fn unwrap_2(&self) -> [Size; 2] {
109        self.dims
110            .as_slice()
111            .try_into()
112            .unwrap_or_else(|_| panic!("Expected rank 2 shape, got {:?}", self))
113    }
114
115    pub fn unwrap_3(&self) -> [Size; 3] {
116        self.dims
117            .as_slice()
118            .try_into()
119            .unwrap_or_else(|_| panic!("Expected rank 3 shape, got {:?}", self))
120    }
121
122    pub fn unwrap_4(&self) -> [Size; 4] {
123        self.dims
124            .as_slice()
125            .try_into()
126            .unwrap_or_else(|_| panic!("Expected rank 4 shape, got {:?}", self))
127    }
128
129    pub fn concat(mut self, other: &Shape) -> Shape {
130        self.dims.extend_from_slice(&other.dims);
131        self
132    }
133
134    pub fn batched(&self) -> Shape {
135        shape![Size::BATCH].concat(self)
136    }
137
138    /// Build a new shape with the shape at `axis` replaced by `replacement`, the rest are kept as-is.
139    pub fn replace(&self, axis: usize, replacement: Shape) -> Shape {
140        self.replace_all(&[axis], replacement)
141    }
142
143    pub fn replace_all(&self, axes: &[usize], replacement: Shape) -> Shape {
144        // validate axes
145        assert!(axes.iter().all_unique(), "Axes must be unique, got {:?}", axes);
146
147        for &axis in axes {
148            self.assert_has_axis(axis);
149        }
150
151        // construct new shape
152        let mut dims = vec![];
153        for i in 0..self.rank() {
154            if axes.contains(&i) {
155                dims.extend_from_slice(&replacement.dims);
156            } else {
157                dims.push(self[i])
158            }
159        }
160
161        Shape::new(dims)
162    }
163
164    /// Build a new shape with the shape at `axis` kept and all other axes replaced by `rest`.
165    pub fn keep(&self, axis: usize, rest: Size) -> Shape {
166        self.assert_has_axis(axis);
167
168        let mut dims = self.dims.clone();
169        for i in 0..self.rank() {
170            if i != axis {
171                dims[i] = rest;
172            }
173        }
174        Shape::new(dims)
175    }
176
177    pub fn repeat_unary(&self, axis: usize, new_size: Size) -> Shape {
178        self.assert_has_axis(axis);
179
180        assert_eq!(
181            self.dims[axis],
182            Size::ONE,
183            "Repeated axis {} must have length 1 for {:?}",
184            axis,
185            self
186        );
187
188        let mut dims = self.dims.clone();
189        dims[axis] = new_size;
190        Shape::new(dims)
191    }
192
193    pub fn insert(&self, axis: usize, size: Size) -> Shape {
194        assert!(
195            axis <= self.rank(),
196            "Axis {} out of bounds for inserting into {:?}",
197            axis,
198            self
199        );
200
201        let mut dims = self.dims.clone();
202        dims.insert(axis, size);
203        Shape::new(dims)
204    }
205
206    pub fn split(&self, index: usize) -> (Shape, Shape) {
207        assert!(
208            index <= self.rank(),
209            "Split index {} out of bounds for {:?}",
210            index,
211            self
212        );
213
214        let body = self.dims[..index].to_vec();
215        let tail = self.dims[index..].to_vec();
216
217        (Shape::new(body), Shape::new(tail))
218    }
219}
220
221impl From<usize> for Size {
222    fn from(fixed_factor: usize) -> Self {
223        Size::fixed(fixed_factor)
224    }
225}
226
227#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
228pub enum DivResult {
229    Exact(Size),
230    Remainder(usize),
231    Impossible,
232}
233
234impl Size {
235    pub const ZERO: Size = Size::new(0, 0);
236    pub const ONE: Size = Size::new(0, 1);
237    pub const BATCH: Size = Size::new(1, 1);
238
239    pub const fn new(batch_exp: u32, fixed_factor: usize) -> Size {
240        if fixed_factor == 0 {
241            Size {
242                batch_exp: 0,
243                fixed_factor: 0,
244            }
245        } else {
246            Size {
247                batch_exp,
248                fixed_factor,
249            }
250        }
251    }
252
253    pub const fn fixed(size: usize) -> Size {
254        Size {
255            batch_exp: 0,
256            fixed_factor: size,
257        }
258    }
259
260    pub const fn is_zero(&self) -> bool {
261        matches!(
262            self,
263            Size {
264                batch_exp: 0,
265                fixed_factor: 0
266            }
267        )
268    }
269
270    pub const fn components_factor_exp(self) -> (usize, u32) {
271        (self.fixed_factor, self.batch_exp)
272    }
273
274    pub fn eval(self, batch_size: usize) -> usize {
275        batch_size.pow(self.batch_exp) * self.fixed_factor
276    }
277
278    pub fn try_unwrap_fixed(self) -> Option<usize> {
279        if self.batch_exp == 0 {
280            Some(self.fixed_factor)
281        } else {
282            None
283        }
284    }
285
286    #[track_caller]
287    pub fn unwrap_fixed(self, what: &str) -> usize {
288        assert_eq!(0, self.batch_exp, "{} must be fixed, but got size {:?}", what, self);
289        self.fixed_factor
290    }
291
292    pub fn floor_div(self, rhs: Self) -> Option<Self> {
293        if self.batch_exp < rhs.batch_exp {
294            None
295        } else {
296            Some(Size::new(
297                self.batch_exp - rhs.batch_exp,
298                self.fixed_factor / rhs.fixed_factor,
299            ))
300        }
301    }
302
303    pub fn div_rem(self, rhs: impl Into<Size>) -> DivResult {
304        let rhs = rhs.into();
305        let fixed_rem = self.fixed_factor % rhs.fixed_factor;
306        if self.batch_exp < rhs.batch_exp {
307            DivResult::Impossible
308        } else if fixed_rem != 0 {
309            DivResult::Remainder(fixed_rem)
310        } else {
311            DivResult::Exact(Size::new(
312                self.batch_exp - rhs.batch_exp,
313                self.fixed_factor / rhs.fixed_factor,
314            ))
315        }
316    }
317}
318
319impl ConcreteShape {
320    pub fn new(dims: Vec<usize>) -> Self {
321        ConcreteShape { dims }
322    }
323
324    pub fn rank(&self) -> usize {
325        self.dims.len()
326    }
327
328    pub fn size(&self) -> usize {
329        self.dims.iter().product()
330    }
331
332    pub fn unwrap_2(&self) -> [usize; 2] {
333        self.dims.as_slice().try_into().expect("Expected rank 2 shape")
334    }
335
336    pub fn unwrap_3(&self) -> [usize; 3] {
337        self.dims.as_slice().try_into().expect("Expected rank 2 shape")
338    }
339
340    pub fn unwrap_4(&self) -> [usize; 4] {
341        self.dims.as_slice().try_into().expect("Expected rank 4 shape")
342    }
343}
344
/// The reason why [infer_batch_size] could not match the expected shapes to the actual ones.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ShapeMismatch {
    /// The expected and actual inputs do not contain the same number of dimensions.
    DifferentLength,
    /// A dimension that does not depend on the batch size differs from the actual value.
    ConstantMismatch,
    /// A batch-dependent dimension does not agree with the batch size inferred from an earlier one.
    BatchConflict,
    /// No batch size was found that makes a batch-dependent dimension equal to the actual value.
    ImpossibleBatchValue,
}
352
353/// Infer what the batch size should be for the given input shapes.
354///
355/// Returns:
356/// * `Ok(Some(batch_size))` if the batch size can be inferred
357/// * `Ok(None)` if any batch size would fit
358/// * `Err(ShapeMismatch)` if no batch size would fit,
359/// either because there are conflicting requirements or because there is some other shape mismatch
360pub fn infer_batch_size(expected: &[Shape], actual: &[ConcreteShape]) -> Result<Option<usize>, ShapeMismatch> {
361    infer_batch_size_dims(
362        expected.iter().flat_map(|s| s.dims.iter().copied()),
363        actual.iter().flat_map(|s| s.dims.iter().copied()),
364    )
365}
366
367/// Same as [infer_batch_size], except for individual dimensions.
368pub fn infer_batch_size_dims(
369    expected: impl IntoIterator<Item=Size>,
370    actuals: impl IntoIterator<Item=usize>,
371) -> Result<Option<usize>, ShapeMismatch> {
372    let mut shapes = expected.into_iter();
373    let mut actuals = actuals.into_iter();
374
375    let mut batch_size = None;
376
377    loop {
378        let (expected, actual) = match (shapes.next(), actuals.next()) {
379            (Some(shape), Some(actual)) => (shape, actual),
380            (None, None) => return Ok(batch_size),
381            _ => return Err(ShapeMismatch::DifferentLength),
382        };
383
384        let (factor, exp) = expected.components_factor_exp();
385
386        if exp == 0 {
387            // constant dim, check match
388            if actual != factor {
389                return Err(ShapeMismatch::ConstantMismatch);
390            }
391        } else {
392            // dim containing batch
393            if let Some(batch_size) = batch_size {
394                // we already know the batch size, check match
395                if actual != expected.eval(batch_size) {
396                    return Err(ShapeMismatch::BatchConflict);
397                }
398            } else {
399                // we don't know the batch size, compute it
400                let batch_size_approx = (actual as f64 / factor as f64).powf(1.0 / exp as f64) as usize;
401
402                let deltas = [0, 1, -1, 2, -2];
403                let batch_size_exact = deltas
404                    .iter()
405                    .find_map(|&delta| {
406                        let cand = batch_size_approx.checked_add_signed(delta).unwrap();
407                        if factor * cand.pow(exp) == actual {
408                            Some(cand)
409                        } else {
410                            None
411                        }
412                    })
413                    .ok_or(ShapeMismatch::ImpossibleBatchValue)?;
414
415                batch_size = Some(batch_size_exact);
416            }
417        }
418    }
419}
420
421impl<R: Into<Size>> std::ops::Add<R> for Size {
422    type Output = Option<Size>;
423
424    fn add(self, rhs: R) -> Self::Output {
425        let rhs = rhs.into();
426        if self == Size::ZERO {
427            return Some(rhs);
428        }
429        if rhs == Size::ZERO {
430            return Some(self);
431        }
432        if self.batch_exp != rhs.batch_exp {
433            return None;
434        }
435
436        Some(Size::new(self.batch_exp, self.fixed_factor + rhs.fixed_factor))
437    }
438}
439
440impl<R: Into<Size>> std::ops::Sub<R> for Size {
441    type Output = Option<Size>;
442
443    fn sub(self, rhs: R) -> Self::Output {
444        let rhs = rhs.into();
445        if rhs == Size::ZERO {
446            return Some(self);
447        }
448
449        if self.batch_exp != rhs.batch_exp || self.fixed_factor < rhs.fixed_factor {
450            return None;
451        }
452
453        Some(Size::new(self.batch_exp, self.fixed_factor - rhs.fixed_factor))
454    }
455}
456
457impl<R: Into<Size>> std::ops::Mul<R> for Size {
458    type Output = Size;
459
460    fn mul(self, rhs: R) -> Self::Output {
461        let rhs = rhs.into();
462        Size::new(self.batch_exp + rhs.batch_exp, self.fixed_factor * rhs.fixed_factor)
463    }
464}
465
466impl<R: Into<Size>> std::ops::Div<R> for Size {
467    type Output = Option<Size>;
468
469    fn div(self, rhs: R) -> Self::Output {
470        match self.div_rem(rhs) {
471            DivResult::Exact(s) => Some(s),
472            DivResult::Remainder(_) | DivResult::Impossible => None,
473        }
474    }
475}
476
477impl<R: Into<Size>> std::ops::Rem<R> for Size {
478    type Output = Option<usize>;
479
480    fn rem(self, rhs: R) -> Self::Output {
481        match self.div_rem(rhs) {
482            DivResult::Exact(_) => Some(0),
483            DivResult::Remainder(r) => Some(r),
484            DivResult::Impossible => None,
485        }
486    }
487}
488
489impl std::iter::Sum<Size> for Option<Size> {
490    fn sum<I: Iterator<Item = Size>>(mut iter: I) -> Self {
491        let result = iter.try_fold(Size::ZERO, |a, s| match a + s {
492            Some(v) => ControlFlow::Continue(v),
493            None => ControlFlow::Break(()),
494        });
495
496        match result {
497            ControlFlow::Continue(v) => Some(v),
498            ControlFlow::Break(()) => None,
499        }
500    }
501}
502
503impl std::iter::Product for Size {
504    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
505        iter.fold(Size::fixed(1), |a, s| a * s)
506    }
507}
508
509impl std::ops::Index<usize> for Shape {
510    type Output = Size;
511
512    fn index(&self, axis: usize) -> &Self::Output {
513        self.assert_has_axis(axis);
514        &self.dims[axis]
515    }
516}
517
518impl Debug for Shape {
519    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
520        write!(f, "Shape{}", self)
521    }
522}
523
524impl Debug for Size {
525    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
526        write!(f, "Size({})", self)
527    }
528}
529
530impl Display for Shape {
531    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
532        fmt_shape_impl(f, &self.dims)
533    }
534}
535
536impl Display for Size {
537    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
538        match (self.fixed_factor, self.batch_exp) {
539            (a, 0) => write!(f, "{}", a),
540            (1, 1) => write!(f, "B"),
541            (a, 1) => write!(f, "{}B", a),
542            (1, b) => write!(f, "B^{}", b),
543            (a, b) => write!(f, "{}B^{}", a, b),
544        }
545    }
546}
547
548impl Display for ConcreteShape {
549    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
550        fmt_shape_impl(f, &self.dims)
551    }
552}
553
/// Write `dims` between parentheses, with the items separated by `" x "`.
///
/// An empty slice is written as `()`.
fn fmt_shape_impl(f: &mut Formatter, dims: &[impl Display]) -> Result<(), std::fmt::Error> {
    f.write_str("(")?;

    for (index, dim) in dims.iter().enumerate() {
        // the separator goes between items, so not before the first one
        if index > 0 {
            f.write_str(" x ")?;
        }
        write!(f, "{}", dim)?;
    }

    f.write_str(")")
}