Skip to main content

lumen_core/
shape.rs

1use std::{fmt::Display, vec};
2
3use crate::{Error, Result};
4
5#[derive(Debug, Clone, PartialEq, Eq)]
6pub struct Shape(pub(crate) Vec<usize>);
7
8impl Shape {
9    pub fn scalar() -> Self {
10        Self(vec![])
11    }
12
13    pub fn is_scalar(&self) -> bool {
14        self.0.is_empty() || (self.0.len() == 1 && self.0[0] == 1)
15    } 
16
17    pub fn rank(&self) -> usize {
18        self.0.len()
19    }
20
21    pub fn dims(&self) -> &[usize] {
22        &self.0
23    }
24
25    pub fn into_dims(self) -> Vec<usize> {
26        self.0
27    }
28
29    pub fn dim(&self, dim: impl Dim) -> Result<usize> {
30        let index = dim.to_index(self, "get dim")?;
31        Ok(self.dims()[index])
32    }
33
34    pub fn element_count(&self) -> usize {
35        self.dims().iter().product()
36    }
37
38    pub fn is_contiguous(&self, stride: &[usize]) -> bool {
39        if self.rank() != stride.len() {
40            return false;
41        }
42        let mut acc = 1;
43        for (&stride, &dim) in stride.iter().zip(self.dims().iter()).rev() {
44            if dim > 1 && stride != acc {
45                return false;
46            }
47            acc *= dim;
48        }
49        true
50    }
51
52    pub fn extend(mut self, additional_dims: &[usize]) -> Self {
53        self.0.extend(additional_dims);
54        self
55    }
56
57
58    /// Check whether the two shapes are compatible for broadcast, and if it is the case return the
59    /// broadcasted shape. This is to be used for binary pointwise ops.
60    /// Copy from https://github.com/huggingface/candle
61    pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
62        let lhs = self;
63        let lhs_dims = lhs.dims();
64        let rhs_dims = rhs.dims();
65        let lhs_ndims = lhs_dims.len();
66        let rhs_ndims = rhs_dims.len();
67        let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
68        let mut bcast_dims = vec![0; bcast_ndims];
69        for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
70            let rev_idx = bcast_ndims - idx;
71            let l_value = if lhs_ndims < rev_idx {
72                1
73            } else {
74                lhs_dims[lhs_ndims - rev_idx]
75            };
76            let r_value = if rhs_ndims < rev_idx {
77                1
78            } else {
79                rhs_dims[rhs_ndims - rev_idx]
80            };
81            *bcast_value = if l_value == r_value {
82                l_value
83            } else if l_value == 1 {
84                r_value
85            } else if r_value == 1 {
86                l_value
87            } else {
88                Err(Error::ShapeMismatchBinaryOp {
89                    lhs: lhs.clone(),
90                    rhs: rhs.clone(),
91                    op,
92                })?
93            }
94        }
95        Ok(Shape::from(bcast_dims))
96    }
97
98
99    /// Returns an iterator over **dimension coordinates**.
100    ///
101    /// This iterator yields the multi-dimensional coordinates
102    /// (e.g., `[i, j, k, ...]`) of each element in the array, independent
103    /// of the physical storage layout.
104    ///
105    /// Example for shape = (2, 2):
106    /// yields: `[0, 0], [0, 1], [1, 0], [1, 1]`
107    pub fn dim_coordinates(&self) -> DimCoordinates {
108        DimCoordinates::from_shape(self)
109    }
110
111    pub fn dims_coordinates<const N: usize>(&self) -> Result<DimNCoordinates<N>> {
112        DimNCoordinates::<N>::from_shape(self)
113    }
114
115    pub fn dim2_coordinates(&self) -> Result<DimNCoordinates<2>> {
116        DimNCoordinates::<2>::from_shape(self)
117    }
118
119    pub fn dim3_coordinates(&self) -> Result<DimNCoordinates<3>> {
120        DimNCoordinates::<3>::from_shape(self)
121    }
122
123    pub fn dim4_coordinates(&self) -> Result<DimNCoordinates<4>> {
124        DimNCoordinates::<4>::from_shape(self)
125    }
126
127    pub fn dim5_coordinates(&self) -> Result<DimNCoordinates<5>> {
128        DimNCoordinates::<5>::from_shape(self)
129    }
130
131    pub(crate) fn stride_contiguous(&self) -> Vec<usize> {
132        let mut stride = self.dims()
133            .iter()
134            .rev()
135            .scan(1, |prod, u| {
136                let prod_pre_mult = *prod;
137                *prod *= u;
138                Some(prod_pre_mult)
139            })
140            .collect::<Vec<_>>();
141        stride.reverse();
142        stride
143    }
144}
145
146
147//////////////////////////////////////////////////////////////////////////////////////
148///                  DimCoordinates
149//////////////////////////////////////////////////////////////////////////////////////
150
151pub struct DimCoordinates {
152    shape: Vec<usize>,
153    current: Vec<usize>,
154    done: bool,
155}
156
157impl DimCoordinates {
158    pub fn from_shape(shape: &Shape) -> Self {
159        let rank = shape.rank();
160        Self {
161            shape: shape.dims().to_vec(),
162            current: vec![0; rank],
163            done: shape.is_scalar(),
164        }
165    }
166}
167
168impl Iterator for DimCoordinates {
169    type Item = Vec<usize>;
170
171    fn next(&mut self) -> Option<Self::Item> {
172        if self.done {
173            return None;
174        }
175
176        let result = self.current.clone();
177
178        for i in (0..self.current.len()).rev() {
179            self.current[i] += 1;
180            if self.current[i] < self.shape[i] {
181                break; 
182            } else {
183                self.current[i] = 0;
184                if i == 0 {
185                    self.done = true;
186                }
187            }
188        }
189
190        Some(result)
191    }
192}
193
194pub struct DimNCoordinates<const N: usize> {
195    shape: [usize; N],
196    current: [usize; N],
197    done: bool,
198}
199
200impl<const N: usize> DimNCoordinates<N> {
201    pub fn from_shape(from_shape: &Shape) -> Result<Self> {
202        if from_shape.rank() == N {
203            let mut shape = [0usize; N];
204            for i in 0..N {
205                shape[i] = from_shape.dims()[i];
206            }
207
208            let current = [0usize; N];
209            
210            Ok(Self {
211                shape,
212                current,
213                done: N == 0
214            })
215        } else {
216            Err(Error::UnexpectedNumberOfDims {
217                expected: N,
218                got: from_shape.rank(),
219                shape: Shape::from(from_shape.dims()),
220            })?
221        }
222    }
223}
224
225impl<const N: usize> Iterator for DimNCoordinates<N> {
226    type Item = [usize; N];
227    fn next(&mut self) -> Option<Self::Item> {
228        if self.done {
229            return None;
230        }
231
232        let result = self.current;
233
234        for i in (0..N).rev() {
235            self.current[i] += 1;
236            if self.current[i] < self.shape[i] {
237                break; 
238            } else {
239                self.current[i] = 0;
240                if i == 0 {
241                    self.done = true;
242                }
243            }
244        }
245
246        Some(result)
247    }
248}
249
250impl<const C: usize> From<&[usize; C]> for Shape {
251    fn from(dims: &[usize; C]) -> Self {
252        Self(dims.to_vec())
253    }
254}
255
256impl From<Vec<usize>> for Shape {
257    fn from(dims: Vec<usize>) -> Self {
258        Self(dims)
259    }
260}
261
262impl From<&Vec<usize>> for Shape {
263    fn from(dims: &Vec<usize>) -> Self {
264        Self(dims.clone())
265    }
266}
267
268impl From<&[usize]> for Shape {
269    fn from(dims: &[usize]) -> Self {
270        Self(dims.to_vec())
271    }
272}
273
274impl From<&Shape> for Shape {
275    fn from(shape: &Shape) -> Self {
276        Self(shape.0.to_vec())
277    }
278}
279
280impl From<usize> for Shape {
281    fn from(d1: usize) -> Self {
282        Self([d1].to_vec())
283    }
284}
285
286impl From<()> for Shape {
287    fn from(_: ()) -> Self {
288        Self(vec![])
289    }
290}
291
292impl std::fmt::Display for Shape {
293    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
294        write!(f, "(")?;
295        for (i, dim) in self.0.iter().enumerate() {
296            if i > 0 {
297                write!(f, ", ")?;
298            }
299            write!(f, "{}", dim)?;
300        }
301        if self.0.len() == 1 {
302            write!(f, ",")?;
303        }
304        write!(f, ")")
305    }
306}
307
308macro_rules! impl_from_tuple {
309    ($tuple:ty, $($index:tt),+) => {
310        impl From<$tuple> for Shape {
311            fn from(d: $tuple) -> Self {
312                Self([$(d.$index,)+].to_vec())
313            }
314        }
315    };
316}
317
318impl_from_tuple!((usize,), 0);
319impl_from_tuple!((usize, usize), 0, 1);
320impl_from_tuple!((usize, usize, usize), 0, 1, 2);
321impl_from_tuple!((usize, usize, usize, usize), 0, 1, 2, 3);
322impl_from_tuple!((usize, usize, usize, usize, usize), 0, 1, 2, 3, 4);
323impl_from_tuple!((usize, usize, usize, usize, usize, usize), 0, 1, 2, 3, 4, 5);
324
325#[derive(Debug, Clone, Copy)]
326pub enum D {
327    Minus1,
328    Minus2,
329    Minus(usize),
330    Index(usize),
331}
332
333impl Display for D {
334    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
335        match self {
336            Self::Minus(n) => writeln!(f, "-{}", n),
337            Self::Minus1 => writeln!(f, "-1"),
338            Self::Minus2 => writeln!(f, "-2"),
339            Self::Index(n) => writeln!(f, "{}", n),
340        }
341    }
342}
343
344impl D {
345    fn out_of_range(&self, shape: &Shape, op: &'static str) -> Error {
346        let dim = match self {
347            Self::Minus1 => -1,
348            Self::Minus2 => -2,
349            Self::Minus(u) => -(*u as i32),
350            Self::Index(u) => *u as i32,
351        };
352        Error::DimOutOfRange {
353            shape: shape.clone(),
354            dim,
355            op,
356        }
357    }
358}
359
360
/// Generates, for a fixed rank `$cnt`:
/// - a free function `$fn_name(&[usize]) -> Result<$out_type>`, which checks the slice
///   length and then maps the slice through `$dims`;
/// - `Shape::$fn_name` and `Tensor::$fn_name` wrappers that delegate to it;
/// - `TryInto<$out_type> for Shape`.
///
/// `$dims` is an expression, in practice a closure, applied to the length-checked slice.
macro_rules! extract_dims {
    ($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
        pub fn $fn_name(dims: &[usize]) -> Result<$out_type> {
            // Length check first, so the `$dims` closure can index without panicking.
            if dims.len() != $cnt {
                Err(Error::UnexpectedNumberOfDims {
                    expected: $cnt,
                    got: dims.len(),
                    shape: Shape::from(dims),
                })?
            } else {
                Ok($dims(dims))
            }
        }

        impl Shape {
            pub fn $fn_name(&self) -> Result<$out_type> {
                $fn_name(self.0.as_slice())
            }
        }

        // NOTE(review): presumably `Tensor::shape()` returns `&Shape`; `Tensor` and
        // `WithDType` are defined elsewhere in the crate — confirm.
        impl<T: crate::WithDType> crate::Tensor<T> {
            pub fn $fn_name(&self) -> Result<$out_type> {
                self.shape().$fn_name()
            }
        }

        // NOTE(review): std recommends implementing `TryFrom` rather than `TryInto`
        // directly; kept as-is for compatibility with existing callers.
        impl std::convert::TryInto<$out_type> for Shape {
            type Error = crate::Error;
            fn try_into(self) -> crate::Result<$out_type> {
                self.$fn_name()
            }
        }
    };
}

// Rank-0 shape -> `()`.
extract_dims!(dims0, 0, |_: &[usize]| (), ());
// Rank-1 shape -> `usize`.
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
// Ranks 2..=5 -> tuples of extents, outermost first.
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
extract_dims!(
    dims3,
    3,
    |d: &[usize]| (d[0], d[1], d[2]),
    (usize, usize, usize)
);
extract_dims!(
    dims4,
    4,
    |d: &[usize]| (d[0], d[1], d[2], d[3]),
    (usize, usize, usize, usize)
);
extract_dims!(
    dims5,
    5,
    |d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
    (usize, usize, usize, usize, usize)
);
417
418
419pub trait Dim : Copy {
420    fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize>;
421    fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize>;
422}
423
424impl Dim for usize {
425    fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
426        let dim = *self;
427        if dim >= shape.rank() {
428            Err(Error::DimOutOfRange {
429                shape: shape.clone(),
430                dim: dim as i32,
431                op,
432            })?
433        } else {
434            Ok(dim)
435        }
436    }
437
438    fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
439        let dim = *self;
440        if dim > shape.rank() {
441            Err(Error::DimOutOfRange {
442                shape: shape.clone(),
443                dim: dim as i32,
444                op,
445            })?
446        } else {
447            Ok(dim)
448        }
449    }
450}
451
452impl Dim for D {
453    fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
454        let rank = shape.rank();
455        match self {
456            Self::Minus1 if rank >= 1 => Ok(rank - 1),
457            Self::Minus2 if rank >= 2 => Ok(rank - 2),
458            Self::Minus(u) if *u > 0 && rank >= *u => Ok(rank - *u),
459            Self::Index(u) => u.to_index(shape, op),
460            _ => Err(self.out_of_range(shape, op))?,
461        }
462    }
463
464    fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
465        let rank = shape.rank();
466        match self {
467            Self::Minus1 => Ok(rank),
468            Self::Minus2 if rank >= 1 => Ok(rank - 1),
469            Self::Minus(u) if *u > 0 && rank + 1 >= *u => Ok(rank + 1 - *u),
470            Self::Index(u) => u.to_index_plus_one(shape, op),
471            _ => Err(self.out_of_range(shape, op))?,
472        }
473    }
474}
475
476pub trait Dims {
477    fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>>;
478    fn check_indexes(dims: &[usize], shape: &Shape, op: &'static str) -> Result<()> {
479        for (i, &dim) in dims.iter().enumerate() {
480            if dims[..i].contains(&dim) {
481                return Err(Error::DuplicateDimIndex {
482                    shape: shape.clone(),
483                    dims: dims.to_vec(),
484                    op,
485                })?;
486            }
487            if dim >= shape.rank() {
488                return Err(Error::DimOutOfRange {
489                    shape: shape.clone(),
490                    dim: dim as i32,
491                    op,
492                })?;
493            }
494        }
495        Ok(())
496    }
497}
498
499impl Dims for Vec<usize> {
500    fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
501        Self::check_indexes(&self, shape, op)?;
502        Ok(self)
503    }
504}
505
506impl<const N: usize> Dims for [usize; N] {
507    fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
508        Self::check_indexes(&self, shape, op)?;
509        Ok(self.to_vec())
510    }
511}
512
513impl Dims for &[usize] {
514    fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
515        Self::check_indexes(&self, shape, op)?;
516        Ok(self.to_vec())
517    }
518}
519
520impl Dims for () {
521    fn to_indexes(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
522        Ok(vec![])
523    }
524}
525
526impl<D: Dim + Sized> Dims for D {
527    fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
528        let dim = self.to_index(shape, op)?;
529        Ok([dim].to_vec())
530    }
531}
532
533impl<D: Dim> Dims for (D,) {
534    fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
535        let dim = self.0.to_index(shape, op)?;
536        Ok([dim].to_vec())
537    }
538}
539
540impl<D1: Dim, D2: Dim> Dims for (D1, D2) {
541    fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
542        let d0 = self.0.to_index(shape, op)?;
543        let d1 = self.1.to_index(shape, op)?;
544        Ok([d0, d1].to_vec())
545    }
546}
547
548impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
549    fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
550        let d0 = self.0.to_index(shape, op)?;
551        let d1 = self.1.to_index(shape, op)?;
552        let d2 = self.2.to_index(shape, op)?;
553        Ok([d0, d1, d2].to_vec())
554    }
555}
556
557impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) {
558    fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
559        let d0 = self.0.to_index(shape, op)?;
560        let d1 = self.1.to_index(shape, op)?;
561        let d2 = self.2.to_index(shape, op)?;
562        let d3 = self.3.to_index(shape, op)?;
563        Ok([d0, d1, d2, d3].to_vec())
564    }
565}
566
567impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5) {
568    fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
569        let d0 = self.0.to_index(shape, op)?;
570        let d1 = self.1.to_index(shape, op)?;
571        let d2 = self.2.to_index(shape, op)?;
572        let d3 = self.3.to_index(shape, op)?;
573        let d4 = self.4.to_index(shape, op)?;
574        Ok([d0, d1, d2, d3, d4].to_vec())
575    }
576}
577
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn stride() {
        let cases: [(Shape, Vec<usize>); 4] = [
            (Shape::from(()), vec![]),
            (Shape::from(42), vec![1]),
            (Shape::from((42, 1337)), vec![1337, 1]),
            (Shape::from((299, 792, 458)), vec![458 * 792, 458, 1]),
        ];
        for (shape, expected) in cases {
            assert_eq!(shape.stride_contiguous(), expected);
        }
    }

    #[test]
    fn test_from_tuple() {
        assert_eq!(Shape::from((2,)).dims(), &[2]);
        assert_eq!(Shape::from((2, 3)).dims(), &[2, 3]);
        assert_eq!(Shape::from((2, 3, 4)).dims(), &[2, 3, 4]);
        assert_eq!(Shape::from((2, 3, 4, 5)).dims(), &[2, 3, 4, 5]);
        assert_eq!(Shape::from((2, 3, 4, 5, 6)).dims(), &[2, 3, 4, 5, 6]);
        assert_eq!(Shape::from((2, 3, 4, 5, 6, 7)).dims(), &[2, 3, 4, 5, 6, 7]);
    }

    #[test]
    fn test_dim_coordinates_2d() {
        let got: Vec<Vec<usize>> = Shape(vec![2, 2]).dim_coordinates().collect();
        assert_eq!(got, vec![vec![0, 0], vec![0, 1], vec![1, 0], vec![1, 1]]);
    }

    #[test]
    fn test_dim_coordinates_2d_varied() {
        let got: Vec<Vec<usize>> = Shape(vec![3, 1]).dim_coordinates().collect();
        assert_eq!(got, vec![vec![0, 0], vec![1, 0], vec![2, 0]]);
    }

    #[test]
    fn test_dim_coordinates_3d() {
        let got: Vec<Vec<usize>> = Shape(vec![2, 2, 2]).dim_coordinates().collect();
        let expected = vec![
            vec![0, 0, 0],
            vec![0, 0, 1],
            vec![0, 1, 0],
            vec![0, 1, 1],
            vec![1, 0, 0],
            vec![1, 0, 1],
            vec![1, 1, 0],
            vec![1, 1, 1],
        ];
        assert_eq!(got, expected);
    }

    #[test]
    fn test_dim_n_coordinates_2d() {
        let got: Vec<[usize; 2]> = Shape(vec![2, 2]).dim2_coordinates().unwrap().collect();
        assert_eq!(got, vec![[0, 0], [0, 1], [1, 0], [1, 1]]);
    }

    #[test]
    fn test_dim_n_coordinates_3d() {
        let got: Vec<[usize; 3]> = Shape(vec![2, 2, 2]).dim3_coordinates().unwrap().collect();
        let expected = vec![
            [0, 0, 0],
            [0, 0, 1],
            [0, 1, 0],
            [0, 1, 1],
            [1, 0, 0],
            [1, 0, 1],
            [1, 1, 0],
            [1, 1, 1],
        ];
        assert_eq!(got, expected);
    }

    #[test]
    fn test_dim_n_coordinates_wrong_dim() {
        let shape = Shape(vec![2, 2]);
        // A rank-2 shape cannot produce rank-3 coordinates.
        assert!(shape.dim3_coordinates().is_err());
        assert!(shape.dims_coordinates::<3>().is_err());
    }

    #[test]
    fn test_dim_n_coordinates_empty_shape() {
        let first = Shape(vec![]).dims_coordinates::<0>().unwrap().next();
        assert_eq!(first, None);
    }
}