candle_core/
shape.rs

1//! The shape of a tensor is a tuple with the size of each of its dimensions.
2#![allow(clippy::redundant_closure_call)]
3use crate::{Error, Result};
4
/// The shape of a tensor: the size of each of its dimensions, in order.
///
/// A thin wrapper around a `Vec<usize>`; an empty vector is the shape of a scalar.
#[derive(Clone, PartialEq, Eq)]
pub struct Shape(
    /// Dimension sizes, first dimension first.
    Vec<usize>,
);
7
8pub const SCALAR: Shape = Shape(vec![]);
9
10impl std::fmt::Debug for Shape {
11    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
12        write!(f, "{:?}", &self.dims())
13    }
14}
15
16impl<const C: usize> From<&[usize; C]> for Shape {
17    fn from(dims: &[usize; C]) -> Self {
18        Self(dims.to_vec())
19    }
20}
21
22impl From<&[usize]> for Shape {
23    fn from(dims: &[usize]) -> Self {
24        Self(dims.to_vec())
25    }
26}
27
28impl From<&Shape> for Shape {
29    fn from(shape: &Shape) -> Self {
30        Self(shape.0.to_vec())
31    }
32}
33
34impl From<()> for Shape {
35    fn from(_: ()) -> Self {
36        Self(vec![])
37    }
38}
39
40impl From<usize> for Shape {
41    fn from(d1: usize) -> Self {
42        Self(vec![d1])
43    }
44}
45
// Generates `From<$tuple> for Shape` for a tuple type; the trailing list gives the tuple
// field indexes to read, in the order they become dimensions.
macro_rules! impl_from_tuple {
    ($tuple:ty, $($index:tt),+) => {
        impl From<$tuple> for Shape {
            fn from(dims: $tuple) -> Self {
                Shape(vec![$(dims.$index),+])
            }
        }
    };
}
55
56impl_from_tuple!((usize,), 0);
57impl_from_tuple!((usize, usize), 0, 1);
58impl_from_tuple!((usize, usize, usize), 0, 1, 2);
59impl_from_tuple!((usize, usize, usize, usize), 0, 1, 2, 3);
60impl_from_tuple!((usize, usize, usize, usize, usize), 0, 1, 2, 3, 4);
61impl_from_tuple!((usize, usize, usize, usize, usize, usize), 0, 1, 2, 3, 4, 5);
62
63impl From<Vec<usize>> for Shape {
64    fn from(dims: Vec<usize>) -> Self {
65        Self(dims)
66    }
67}
68
// Generates a family of "give me exactly N dimensions" accessors:
//   * a free function `$fn_name(dims)`,
//   * `Shape::$fn_name` and `Tensor::$fn_name` methods forwarding to it,
//   * a `TryInto<$out_type>` conversion for `Shape`.
// `$cnt` is the required number of dimensions and `$dims` is a closure turning the slice
// into the `$out_type` value once the length has been checked.
macro_rules! extract_dims {
    ($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
        /// Extracts the dimensions from a slice, failing if the count is not the expected one.
        pub fn $fn_name(dims: &[usize]) -> Result<$out_type> {
            if dims.len() == $cnt {
                return Ok($dims(dims));
            }
            Err(Error::UnexpectedNumberOfDims {
                expected: $cnt,
                got: dims.len(),
                shape: Shape::from(dims),
            }
            .bt())
        }

        impl Shape {
            /// Extracts the dimensions of this shape, failing if the rank is not the expected one.
            pub fn $fn_name(&self) -> Result<$out_type> {
                $fn_name(self.dims())
            }
        }

        impl crate::Tensor {
            /// Extracts the dimensions of this tensor's shape, failing on an unexpected rank.
            pub fn $fn_name(&self) -> Result<$out_type> {
                let shape = self.shape();
                shape.$fn_name()
            }
        }

        impl std::convert::TryInto<$out_type> for Shape {
            type Error = crate::Error;
            fn try_into(self) -> std::result::Result<$out_type, Self::Error> {
                Shape::$fn_name(&self)
            }
        }
    };
}
104
105impl Shape {
106    pub fn from_dims(dims: &[usize]) -> Self {
107        Self(dims.to_vec())
108    }
109
110    /// The rank is the number of dimensions, 0 for a scalar value, 1 for a vector, etc.
111    pub fn rank(&self) -> usize {
112        self.0.len()
113    }
114
115    pub fn into_dims(self) -> Vec<usize> {
116        self.0
117    }
118
119    /// The dimensions as a slice of `usize`.
120    pub fn dims(&self) -> &[usize] {
121        &self.0
122    }
123
124    /// The dimension size for a specified dimension index.
125    pub fn dim<D: Dim>(&self, dim: D) -> Result<usize> {
126        let dim = dim.to_index(self, "dim")?;
127        Ok(self.dims()[dim])
128    }
129
130    /// The total number of elements, this is the product of all dimension sizes.
131    pub fn elem_count(&self) -> usize {
132        self.0.iter().product()
133    }
134
135    /// The strides given in number of elements for a contiguous n-dimensional
136    /// arrays using this shape.
137    pub(crate) fn stride_contiguous(&self) -> Vec<usize> {
138        let mut stride: Vec<_> = self
139            .0
140            .iter()
141            .rev()
142            .scan(1, |prod, u| {
143                let prod_pre_mult = *prod;
144                *prod *= u;
145                Some(prod_pre_mult)
146            })
147            .collect();
148        stride.reverse();
149        stride
150    }
151
152    /// Returns true if the strides are C contiguous (aka row major).
153    pub fn is_contiguous(&self, stride: &[usize]) -> bool {
154        if self.0.len() != stride.len() {
155            return false;
156        }
157        let mut acc = 1;
158        for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
159            if dim > 1 && stride != acc {
160                return false;
161            }
162            acc *= dim;
163        }
164        true
165    }
166
167    /// Returns true if the strides are Fortran contiguous (aka column major).
168    pub fn is_fortran_contiguous(&self, stride: &[usize]) -> bool {
169        if self.0.len() != stride.len() {
170            return false;
171        }
172        let mut acc = 1;
173        for (&stride, &dim) in stride.iter().zip(self.0.iter()) {
174            if dim > 1 && stride != acc {
175                return false;
176            }
177            acc *= dim;
178        }
179        true
180    }
181
182    /// Modifies the shape by adding a list of additional dimensions at the end of the existing
183    /// dimensions.
184    pub fn extend(mut self, additional_dims: &[usize]) -> Self {
185        self.0.extend(additional_dims);
186        self
187    }
188
189    /// Check whether the two shapes are compatible for broadcast, and if it is the case return the
190    /// broadcasted shape. This is to be used for binary pointwise ops.
191    pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
192        let lhs = self;
193        let lhs_dims = lhs.dims();
194        let rhs_dims = rhs.dims();
195        let lhs_ndims = lhs_dims.len();
196        let rhs_ndims = rhs_dims.len();
197        let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
198        let mut bcast_dims = vec![0; bcast_ndims];
199        for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
200            let rev_idx = bcast_ndims - idx;
201            let l_value = if lhs_ndims < rev_idx {
202                1
203            } else {
204                lhs_dims[lhs_ndims - rev_idx]
205            };
206            let r_value = if rhs_ndims < rev_idx {
207                1
208            } else {
209                rhs_dims[rhs_ndims - rev_idx]
210            };
211            *bcast_value = if l_value == r_value {
212                l_value
213            } else if l_value == 1 {
214                r_value
215            } else if r_value == 1 {
216                l_value
217            } else {
218                Err(Error::ShapeMismatchBinaryOp {
219                    lhs: lhs.clone(),
220                    rhs: rhs.clone(),
221                    op,
222                }
223                .bt())?
224            }
225        }
226        Ok(Shape::from(bcast_dims))
227    }
228
229    pub(crate) fn broadcast_shape_matmul(&self, rhs: &Self) -> Result<(Shape, Shape)> {
230        let lhs = self;
231        let lhs_dims = lhs.dims();
232        let rhs_dims = rhs.dims();
233        if lhs_dims.len() < 2 || rhs_dims.len() < 2 {
234            crate::bail!("only 2d matrixes are supported {lhs:?} {rhs:?}")
235        }
236        let (m, lhs_k) = (lhs_dims[lhs_dims.len() - 2], lhs_dims[lhs_dims.len() - 1]);
237        let (rhs_k, n) = (rhs_dims[rhs_dims.len() - 2], rhs_dims[rhs_dims.len() - 1]);
238        if lhs_k != rhs_k {
239            crate::bail!("different inner dimensions in broadcast matmul {lhs:?} {rhs:?}")
240        }
241
242        let lhs_b = Self::from(&lhs_dims[..lhs_dims.len() - 2]);
243        let rhs_b = Self::from(&rhs_dims[..rhs_dims.len() - 2]);
244        let bcast = lhs_b.broadcast_shape_binary_op(&rhs_b, "broadcast_matmul")?;
245        let bcast_dims = bcast.dims();
246
247        let bcast_lhs = [bcast_dims, &[m, lhs_k]].concat();
248        let bcast_rhs = [bcast_dims, &[rhs_k, n]].concat();
249        Ok((Shape::from(bcast_lhs), Shape::from(bcast_rhs)))
250    }
251}
252
253pub trait Dim {
254    fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize>;
255    fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize>;
256}
257
258impl Dim for usize {
259    fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
260        let dim = *self;
261        if dim >= shape.dims().len() {
262            Err(Error::DimOutOfRange {
263                shape: shape.clone(),
264                dim: dim as i32,
265                op,
266            }
267            .bt())?
268        } else {
269            Ok(dim)
270        }
271    }
272
273    fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
274        let dim = *self;
275        if dim > shape.dims().len() {
276            Err(Error::DimOutOfRange {
277                shape: shape.clone(),
278                dim: dim as i32,
279                op,
280            }
281            .bt())?
282        } else {
283            Ok(dim)
284        }
285    }
286}
287
/// A dimension designated by its position counted from the end of the shape.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum D {
    /// The last dimension.
    Minus1,
    /// The second to last dimension.
    Minus2,
    /// The n-th dimension from the end, `Minus(1)` being the last one.
    Minus(usize),
}
294
295impl D {
296    fn out_of_range(&self, shape: &Shape, op: &'static str) -> Error {
297        let dim = match self {
298            Self::Minus1 => -1,
299            Self::Minus2 => -2,
300            Self::Minus(u) => -(*u as i32),
301        };
302        Error::DimOutOfRange {
303            shape: shape.clone(),
304            dim,
305            op,
306        }
307        .bt()
308    }
309}
310
311impl Dim for D {
312    fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
313        let rank = shape.rank();
314        match self {
315            Self::Minus1 if rank >= 1 => Ok(rank - 1),
316            Self::Minus2 if rank >= 2 => Ok(rank - 2),
317            Self::Minus(u) if *u > 0 && rank >= *u => Ok(rank - *u),
318            _ => Err(self.out_of_range(shape, op)),
319        }
320    }
321
322    fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
323        let rank = shape.rank();
324        match self {
325            Self::Minus1 => Ok(rank),
326            Self::Minus2 if rank >= 1 => Ok(rank - 1),
327            Self::Minus(u) if *u > 0 && rank + 1 >= *u => Ok(rank + 1 - *u),
328            _ => Err(self.out_of_range(shape, op)),
329        }
330    }
331}
332
333pub trait Dims: Sized {
334    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>>;
335
336    fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
337        let dims = self.to_indexes_internal(shape, op)?;
338        for (i, &dim) in dims.iter().enumerate() {
339            if dims[..i].contains(&dim) {
340                Err(Error::DuplicateDimIndex {
341                    shape: shape.clone(),
342                    dims: dims.clone(),
343                    op,
344                }
345                .bt())?
346            }
347            if dim >= shape.rank() {
348                Err(Error::DimOutOfRange {
349                    shape: shape.clone(),
350                    dim: dim as i32,
351                    op,
352                }
353                .bt())?
354            }
355        }
356        Ok(dims)
357    }
358}
359
360impl Dims for Vec<usize> {
361    fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
362        Ok(self)
363    }
364}
365
366impl<const N: usize> Dims for [usize; N] {
367    fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
368        Ok(self.to_vec())
369    }
370}
371
372impl Dims for &[usize] {
373    fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
374        Ok(self.to_vec())
375    }
376}
377
378impl Dims for () {
379    fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
380        Ok(vec![])
381    }
382}
383
384impl<D: Dim + Sized> Dims for D {
385    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
386        let dim = self.to_index(shape, op)?;
387        Ok(vec![dim])
388    }
389}
390
391impl<D: Dim> Dims for (D,) {
392    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
393        let dim = self.0.to_index(shape, op)?;
394        Ok(vec![dim])
395    }
396}
397
398impl<D1: Dim, D2: Dim> Dims for (D1, D2) {
399    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
400        let d0 = self.0.to_index(shape, op)?;
401        let d1 = self.1.to_index(shape, op)?;
402        Ok(vec![d0, d1])
403    }
404}
405
406impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
407    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
408        let d0 = self.0.to_index(shape, op)?;
409        let d1 = self.1.to_index(shape, op)?;
410        let d2 = self.2.to_index(shape, op)?;
411        Ok(vec![d0, d1, d2])
412    }
413}
414
415impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) {
416    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
417        let d0 = self.0.to_index(shape, op)?;
418        let d1 = self.1.to_index(shape, op)?;
419        let d2 = self.2.to_index(shape, op)?;
420        let d3 = self.3.to_index(shape, op)?;
421        Ok(vec![d0, d1, d2, d3])
422    }
423}
424
425impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5) {
426    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
427        let d0 = self.0.to_index(shape, op)?;
428        let d1 = self.1.to_index(shape, op)?;
429        let d2 = self.2.to_index(shape, op)?;
430        let d3 = self.3.to_index(shape, op)?;
431        let d4 = self.4.to_index(shape, op)?;
432        Ok(vec![d0, d1, d2, d3, d4])
433    }
434}
435
436impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim, D6: Dim> Dims for (D1, D2, D3, D4, D5, D6) {
437    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
438        let d0 = self.0.to_index(shape, op)?;
439        let d1 = self.1.to_index(shape, op)?;
440        let d2 = self.2.to_index(shape, op)?;
441        let d3 = self.3.to_index(shape, op)?;
442        let d4 = self.4.to_index(shape, op)?;
443        let d5 = self.5.to_index(shape, op)?;
444        Ok(vec![d0, d1, d2, d3, d4, d5])
445    }
446}
447
448extract_dims!(dims0, 0, |_: &[usize]| (), ());
449extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
450extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
451extract_dims!(
452    dims3,
453    3,
454    |d: &[usize]| (d[0], d[1], d[2]),
455    (usize, usize, usize)
456);
457extract_dims!(
458    dims4,
459    4,
460    |d: &[usize]| (d[0], d[1], d[2], d[3]),
461    (usize, usize, usize, usize)
462);
463extract_dims!(
464    dims5,
465    5,
466    |d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
467    (usize, usize, usize, usize, usize)
468);
469
470pub trait ShapeWithOneHole {
471    fn into_shape(self, el_count: usize) -> Result<Shape>;
472}
473
474impl<S: Into<Shape>> ShapeWithOneHole for S {
475    fn into_shape(self, _el_count: usize) -> Result<Shape> {
476        Ok(self.into())
477    }
478}
479
480impl ShapeWithOneHole for ((),) {
481    fn into_shape(self, el_count: usize) -> Result<Shape> {
482        Ok(el_count.into())
483    }
484}
485
486fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
487    if prod_d == 0 {
488        crate::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
489    }
490    if el_count % prod_d != 0 {
491        crate::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
492    }
493    Ok(el_count / prod_d)
494}
495
496impl ShapeWithOneHole for ((), usize) {
497    fn into_shape(self, el_count: usize) -> Result<Shape> {
498        let ((), d1) = self;
499        Ok((hole_size(el_count, d1, &self)?, d1).into())
500    }
501}
502
503impl ShapeWithOneHole for (usize, ()) {
504    fn into_shape(self, el_count: usize) -> Result<Shape> {
505        let (d1, ()) = self;
506        Ok((d1, hole_size(el_count, d1, &self)?).into())
507    }
508}
509
510impl ShapeWithOneHole for ((), usize, usize) {
511    fn into_shape(self, el_count: usize) -> Result<Shape> {
512        let ((), d1, d2) = self;
513        Ok((hole_size(el_count, d1 * d2, &self)?, d1, d2).into())
514    }
515}
516
517impl ShapeWithOneHole for (usize, (), usize) {
518    fn into_shape(self, el_count: usize) -> Result<Shape> {
519        let (d1, (), d2) = self;
520        Ok((d1, hole_size(el_count, d1 * d2, &self)?, d2).into())
521    }
522}
523
524impl ShapeWithOneHole for (usize, usize, ()) {
525    fn into_shape(self, el_count: usize) -> Result<Shape> {
526        let (d1, d2, ()) = self;
527        Ok((d1, d2, hole_size(el_count, d1 * d2, &self)?).into())
528    }
529}
530
531impl ShapeWithOneHole for ((), usize, usize, usize) {
532    fn into_shape(self, el_count: usize) -> Result<Shape> {
533        let ((), d1, d2, d3) = self;
534        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
535        Ok((d, d1, d2, d3).into())
536    }
537}
538
539impl ShapeWithOneHole for (usize, (), usize, usize) {
540    fn into_shape(self, el_count: usize) -> Result<Shape> {
541        let (d1, (), d2, d3) = self;
542        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
543        Ok((d1, d, d2, d3).into())
544    }
545}
546
547impl ShapeWithOneHole for (usize, usize, (), usize) {
548    fn into_shape(self, el_count: usize) -> Result<Shape> {
549        let (d1, d2, (), d3) = self;
550        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
551        Ok((d1, d2, d, d3).into())
552    }
553}
554
555impl ShapeWithOneHole for (usize, usize, usize, ()) {
556    fn into_shape(self, el_count: usize) -> Result<Shape> {
557        let (d1, d2, d3, ()) = self;
558        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
559        Ok((d1, d2, d3, d).into())
560    }
561}
562
563impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
564    fn into_shape(self, el_count: usize) -> Result<Shape> {
565        let ((), d1, d2, d3, d4) = self;
566        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
567        Ok((d, d1, d2, d3, d4).into())
568    }
569}
570
571impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
572    fn into_shape(self, el_count: usize) -> Result<Shape> {
573        let (d1, (), d2, d3, d4) = self;
574        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
575        Ok((d1, d, d2, d3, d4).into())
576    }
577}
578
579impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
580    fn into_shape(self, el_count: usize) -> Result<Shape> {
581        let (d1, d2, (), d3, d4) = self;
582        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
583        Ok((d1, d2, d, d3, d4).into())
584    }
585}
586
587impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
588    fn into_shape(self, el_count: usize) -> Result<Shape> {
589        let (d1, d2, d3, (), d4) = self;
590        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
591        Ok((d1, d2, d3, d, d4).into())
592    }
593}
594
595impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
596    fn into_shape(self, el_count: usize) -> Result<Shape> {
597        let (d1, d2, d3, d4, ()) = self;
598        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
599        Ok((d1, d2, d3, d4, d).into())
600    }
601}
602
#[cfg(test)]
mod tests {
    use super::*;

    // Contiguous strides for ranks 0 to 3.
    #[test]
    fn stride() {
        assert!(Shape::from(()).stride_contiguous().is_empty());
        assert_eq!(Shape::from(42).stride_contiguous(), vec![1]);
        assert_eq!(Shape::from((42, 1337)).stride_contiguous(), vec![1337, 1]);
        assert_eq!(
            Shape::from((299, 792, 458)).stride_contiguous(),
            vec![458 * 792, 458, 1]
        );
    }

    // Every supported tuple arity converts to the matching list of dimensions.
    #[test]
    fn test_from_tuple() {
        assert_eq!(Shape::from((2,)).dims(), &[2]);
        assert_eq!(Shape::from((2, 3)).dims(), &[2, 3]);
        assert_eq!(Shape::from((2, 3, 4)).dims(), &[2, 3, 4]);
        assert_eq!(Shape::from((2, 3, 4, 5)).dims(), &[2, 3, 4, 5]);
        assert_eq!(Shape::from((2, 3, 4, 5, 6)).dims(), &[2, 3, 4, 5, 6]);
        assert_eq!(Shape::from((2, 3, 4, 5, 6, 7)).dims(), &[2, 3, 4, 5, 6, 7]);
    }
}