candle_core/
shape.rs

1//! The shape of a tensor is a tuple with the size of each of its dimensions.
2#![allow(clippy::redundant_closure_call)]
3use crate::{Error, Result};
4
/// The shape of a tensor: the size of each dimension, listed from the first (outermost) to the
/// last (innermost, stride-1 in a contiguous layout) dimension.
#[derive(Clone, Eq, PartialEq)]
pub struct Shape(Vec<usize>);
7
8pub const SCALAR: Shape = Shape(vec![]);
9
10impl std::fmt::Debug for Shape {
11    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
12        write!(f, "{:?}", &self.dims())
13    }
14}
15
16impl<const C: usize> From<&[usize; C]> for Shape {
17    fn from(dims: &[usize; C]) -> Self {
18        Self(dims.to_vec())
19    }
20}
21
22impl From<&[usize]> for Shape {
23    fn from(dims: &[usize]) -> Self {
24        Self(dims.to_vec())
25    }
26}
27
28impl From<&Shape> for Shape {
29    fn from(shape: &Shape) -> Self {
30        Self(shape.0.to_vec())
31    }
32}
33
34impl From<()> for Shape {
35    fn from(_: ()) -> Self {
36        Self(vec![])
37    }
38}
39
40impl From<usize> for Shape {
41    fn from(d1: usize) -> Self {
42        Self(vec![d1])
43    }
44}
45
46impl From<(usize,)> for Shape {
47    fn from(d1: (usize,)) -> Self {
48        Self(vec![d1.0])
49    }
50}
51
52impl From<(usize, usize)> for Shape {
53    fn from(d12: (usize, usize)) -> Self {
54        Self(vec![d12.0, d12.1])
55    }
56}
57
58impl From<(usize, usize, usize)> for Shape {
59    fn from(d123: (usize, usize, usize)) -> Self {
60        Self(vec![d123.0, d123.1, d123.2])
61    }
62}
63
64impl From<(usize, usize, usize, usize)> for Shape {
65    fn from(d1234: (usize, usize, usize, usize)) -> Self {
66        Self(vec![d1234.0, d1234.1, d1234.2, d1234.3])
67    }
68}
69
70impl From<(usize, usize, usize, usize, usize)> for Shape {
71    fn from(d12345: (usize, usize, usize, usize, usize)) -> Self {
72        Self(vec![d12345.0, d12345.1, d12345.2, d12345.3, d12345.4])
73    }
74}
75
76impl From<(usize, usize, usize, usize, usize, usize)> for Shape {
77    fn from(d123456: (usize, usize, usize, usize, usize, usize)) -> Self {
78        Self(vec![
79            d123456.0, d123456.1, d123456.2, d123456.3, d123456.4, d123456.5,
80        ])
81    }
82}
83
84impl From<Vec<usize>> for Shape {
85    fn from(dims: Vec<usize>) -> Self {
86        Self(dims)
87    }
88}
89
// Generates, for a fixed number of dimensions `$cnt`, the following items, all of which return
// the dimensions as `$out_type`:
// - a free function named `$fn_name`;
// - a `Shape` method and a `crate::Tensor` method, both named `$fn_name`;
// - a `TryFrom<Shape>` conversion for `$out_type`.
// `$dims` is only applied once the slice length has been checked against `$cnt`, so it may index
// the slice directly.
macro_rules! extract_dims {
    ($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
        #[doc = concat!(
            "Destructures `dims` into `",
            stringify!($out_type),
            "`, returning an error unless it has exactly ",
            stringify!($cnt),
            " dimension(s)."
        )]
        pub fn $fn_name(dims: &[usize]) -> Result<$out_type> {
            if dims.len() != $cnt {
                Err(Error::UnexpectedNumberOfDims {
                    expected: $cnt,
                    got: dims.len(),
                    shape: Shape::from(dims),
                }
                .bt())
            } else {
                Ok($dims(dims))
            }
        }

        impl Shape {
            #[doc = concat!(
                "Returns the dimensions as `",
                stringify!($out_type),
                "`, or an error unless the shape has exactly ",
                stringify!($cnt),
                " dimension(s)."
            )]
            pub fn $fn_name(&self) -> Result<$out_type> {
                $fn_name(self.0.as_slice())
            }
        }

        impl crate::Tensor {
            #[doc = concat!(
                "Returns the dimensions of this tensor's shape as `",
                stringify!($out_type),
                "`, or an error unless it has exactly ",
                stringify!($cnt),
                " dimension(s)."
            )]
            pub fn $fn_name(&self) -> Result<$out_type> {
                self.shape().$fn_name()
            }
        }

        // Implementing `TryFrom` (rather than `TryInto` directly) is the idiomatic direction:
        // the standard library's blanket impl still provides `Shape: TryInto<$out_type>` with the
        // same error type, so existing `shape.try_into()` callers keep working.
        impl std::convert::TryFrom<Shape> for $out_type {
            type Error = crate::Error;
            fn try_from(shape: Shape) -> std::result::Result<Self, crate::Error> {
                shape.$fn_name()
            }
        }
    };
}
125
126impl Shape {
127    pub fn from_dims(dims: &[usize]) -> Self {
128        Self(dims.to_vec())
129    }
130
131    /// The rank is the number of dimensions, 0 for a scalar value, 1 for a vector, etc.
132    pub fn rank(&self) -> usize {
133        self.0.len()
134    }
135
136    pub fn into_dims(self) -> Vec<usize> {
137        self.0
138    }
139
140    /// The dimensions as a slice of `usize`.
141    pub fn dims(&self) -> &[usize] {
142        &self.0
143    }
144
145    /// The dimension size for a specified dimension index.
146    pub fn dim<D: Dim>(&self, dim: D) -> Result<usize> {
147        let dim = dim.to_index(self, "dim")?;
148        Ok(self.dims()[dim])
149    }
150
151    /// The total number of elements, this is the product of all dimension sizes.
152    pub fn elem_count(&self) -> usize {
153        self.0.iter().product()
154    }
155
156    /// The strides given in number of elements for a contiguous n-dimensional
157    /// arrays using this shape.
158    pub(crate) fn stride_contiguous(&self) -> Vec<usize> {
159        let mut stride: Vec<_> = self
160            .0
161            .iter()
162            .rev()
163            .scan(1, |prod, u| {
164                let prod_pre_mult = *prod;
165                *prod *= u;
166                Some(prod_pre_mult)
167            })
168            .collect();
169        stride.reverse();
170        stride
171    }
172
173    /// Returns true if the strides are C contiguous (aka row major).
174    pub fn is_contiguous(&self, stride: &[usize]) -> bool {
175        if self.0.len() != stride.len() {
176            return false;
177        }
178        let mut acc = 1;
179        for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
180            if dim > 1 && stride != acc {
181                return false;
182            }
183            acc *= dim;
184        }
185        true
186    }
187
188    /// Returns true if the strides are Fortran contiguous (aka column major).
189    pub fn is_fortran_contiguous(&self, stride: &[usize]) -> bool {
190        if self.0.len() != stride.len() {
191            return false;
192        }
193        let mut acc = 1;
194        for (&stride, &dim) in stride.iter().zip(self.0.iter()) {
195            if dim > 1 && stride != acc {
196                return false;
197            }
198            acc *= dim;
199        }
200        true
201    }
202
203    /// Modifies the shape by adding a list of additional dimensions at the end of the existing
204    /// dimensions.
205    pub fn extend(mut self, additional_dims: &[usize]) -> Self {
206        self.0.extend(additional_dims);
207        self
208    }
209
210    /// Check whether the two shapes are compatible for broadcast, and if it is the case return the
211    /// broadcasted shape. This is to be used for binary pointwise ops.
212    pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
213        let lhs = self;
214        let lhs_dims = lhs.dims();
215        let rhs_dims = rhs.dims();
216        let lhs_ndims = lhs_dims.len();
217        let rhs_ndims = rhs_dims.len();
218        let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
219        let mut bcast_dims = vec![0; bcast_ndims];
220        for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
221            let rev_idx = bcast_ndims - idx;
222            let l_value = if lhs_ndims < rev_idx {
223                1
224            } else {
225                lhs_dims[lhs_ndims - rev_idx]
226            };
227            let r_value = if rhs_ndims < rev_idx {
228                1
229            } else {
230                rhs_dims[rhs_ndims - rev_idx]
231            };
232            *bcast_value = if l_value == r_value {
233                l_value
234            } else if l_value == 1 {
235                r_value
236            } else if r_value == 1 {
237                l_value
238            } else {
239                Err(Error::ShapeMismatchBinaryOp {
240                    lhs: lhs.clone(),
241                    rhs: rhs.clone(),
242                    op,
243                }
244                .bt())?
245            }
246        }
247        Ok(Shape::from(bcast_dims))
248    }
249
250    pub(crate) fn broadcast_shape_matmul(&self, rhs: &Self) -> Result<(Shape, Shape)> {
251        let lhs = self;
252        let lhs_dims = lhs.dims();
253        let rhs_dims = rhs.dims();
254        if lhs_dims.len() < 2 || rhs_dims.len() < 2 {
255            crate::bail!("only 2d matrixes are supported {lhs:?} {rhs:?}")
256        }
257        let (m, lhs_k) = (lhs_dims[lhs_dims.len() - 2], lhs_dims[lhs_dims.len() - 1]);
258        let (rhs_k, n) = (rhs_dims[rhs_dims.len() - 2], rhs_dims[rhs_dims.len() - 1]);
259        if lhs_k != rhs_k {
260            crate::bail!("different inner dimensions in broadcast matmul {lhs:?} {rhs:?}")
261        }
262
263        let lhs_b = Self::from(&lhs_dims[..lhs_dims.len() - 2]);
264        let rhs_b = Self::from(&rhs_dims[..rhs_dims.len() - 2]);
265        let bcast = lhs_b.broadcast_shape_binary_op(&rhs_b, "broadcast_matmul")?;
266        let bcast_dims = bcast.dims();
267
268        let bcast_lhs = [bcast_dims, &[m, lhs_k]].concat();
269        let bcast_rhs = [bcast_dims, &[rhs_k, n]].concat();
270        Ok((Shape::from(bcast_lhs), Shape::from(bcast_rhs)))
271    }
272}
273
274pub trait Dim {
275    fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize>;
276    fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize>;
277}
278
279impl Dim for usize {
280    fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
281        let dim = *self;
282        if dim >= shape.dims().len() {
283            Err(Error::DimOutOfRange {
284                shape: shape.clone(),
285                dim: dim as i32,
286                op,
287            }
288            .bt())?
289        } else {
290            Ok(dim)
291        }
292    }
293
294    fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
295        let dim = *self;
296        if dim > shape.dims().len() {
297            Err(Error::DimOutOfRange {
298                shape: shape.clone(),
299                dim: dim as i32,
300                op,
301            }
302            .bt())?
303        } else {
304            Ok(dim)
305        }
306    }
307}
308
/// A dimension index counted from the end of a shape.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum D {
    /// The last dimension.
    Minus1,
    /// The second-to-last dimension.
    Minus2,
    /// The `n`-th dimension from the end (`Minus(1)` is the last one). `Minus(0)` is rejected
    /// when resolved.
    Minus(usize),
}
315
316impl D {
317    fn out_of_range(&self, shape: &Shape, op: &'static str) -> Error {
318        let dim = match self {
319            Self::Minus1 => -1,
320            Self::Minus2 => -2,
321            Self::Minus(u) => -(*u as i32),
322        };
323        Error::DimOutOfRange {
324            shape: shape.clone(),
325            dim,
326            op,
327        }
328        .bt()
329    }
330}
331
332impl Dim for D {
333    fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
334        let rank = shape.rank();
335        match self {
336            Self::Minus1 if rank >= 1 => Ok(rank - 1),
337            Self::Minus2 if rank >= 2 => Ok(rank - 2),
338            Self::Minus(u) if *u > 0 && rank >= *u => Ok(rank - *u),
339            _ => Err(self.out_of_range(shape, op)),
340        }
341    }
342
343    fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
344        let rank = shape.rank();
345        match self {
346            Self::Minus1 => Ok(rank),
347            Self::Minus2 if rank >= 1 => Ok(rank - 1),
348            Self::Minus(u) if *u > 0 && rank + 1 >= *u => Ok(rank + 1 - *u),
349            _ => Err(self.out_of_range(shape, op)),
350        }
351    }
352}
353
354pub trait Dims: Sized {
355    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>>;
356
357    fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
358        let dims = self.to_indexes_internal(shape, op)?;
359        for (i, &dim) in dims.iter().enumerate() {
360            if dims[..i].contains(&dim) {
361                Err(Error::DuplicateDimIndex {
362                    shape: shape.clone(),
363                    dims: dims.clone(),
364                    op,
365                }
366                .bt())?
367            }
368            if dim >= shape.rank() {
369                Err(Error::DimOutOfRange {
370                    shape: shape.clone(),
371                    dim: dim as i32,
372                    op,
373                }
374                .bt())?
375            }
376        }
377        Ok(dims)
378    }
379}
380
381impl Dims for Vec<usize> {
382    fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
383        Ok(self)
384    }
385}
386
387impl<const N: usize> Dims for [usize; N] {
388    fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
389        Ok(self.to_vec())
390    }
391}
392
393impl Dims for &[usize] {
394    fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
395        Ok(self.to_vec())
396    }
397}
398
399impl Dims for () {
400    fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
401        Ok(vec![])
402    }
403}
404
405impl<D: Dim + Sized> Dims for D {
406    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
407        let dim = self.to_index(shape, op)?;
408        Ok(vec![dim])
409    }
410}
411
412impl<D: Dim> Dims for (D,) {
413    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
414        let dim = self.0.to_index(shape, op)?;
415        Ok(vec![dim])
416    }
417}
418
419impl<D1: Dim, D2: Dim> Dims for (D1, D2) {
420    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
421        let d0 = self.0.to_index(shape, op)?;
422        let d1 = self.1.to_index(shape, op)?;
423        Ok(vec![d0, d1])
424    }
425}
426
427impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
428    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
429        let d0 = self.0.to_index(shape, op)?;
430        let d1 = self.1.to_index(shape, op)?;
431        let d2 = self.2.to_index(shape, op)?;
432        Ok(vec![d0, d1, d2])
433    }
434}
435
436impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) {
437    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
438        let d0 = self.0.to_index(shape, op)?;
439        let d1 = self.1.to_index(shape, op)?;
440        let d2 = self.2.to_index(shape, op)?;
441        let d3 = self.3.to_index(shape, op)?;
442        Ok(vec![d0, d1, d2, d3])
443    }
444}
445
446impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5) {
447    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
448        let d0 = self.0.to_index(shape, op)?;
449        let d1 = self.1.to_index(shape, op)?;
450        let d2 = self.2.to_index(shape, op)?;
451        let d3 = self.3.to_index(shape, op)?;
452        let d4 = self.4.to_index(shape, op)?;
453        Ok(vec![d0, d1, d2, d3, d4])
454    }
455}
456
457impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim, D6: Dim> Dims for (D1, D2, D3, D4, D5, D6) {
458    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
459        let d0 = self.0.to_index(shape, op)?;
460        let d1 = self.1.to_index(shape, op)?;
461        let d2 = self.2.to_index(shape, op)?;
462        let d3 = self.3.to_index(shape, op)?;
463        let d4 = self.4.to_index(shape, op)?;
464        let d5 = self.5.to_index(shape, op)?;
465        Ok(vec![d0, d1, d2, d3, d4, d5])
466    }
467}
468
469extract_dims!(dims0, 0, |_: &[usize]| (), ());
470extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
471extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
472extract_dims!(
473    dims3,
474    3,
475    |d: &[usize]| (d[0], d[1], d[2]),
476    (usize, usize, usize)
477);
478extract_dims!(
479    dims4,
480    4,
481    |d: &[usize]| (d[0], d[1], d[2], d[3]),
482    (usize, usize, usize, usize)
483);
484extract_dims!(
485    dims5,
486    5,
487    |d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
488    (usize, usize, usize, usize, usize)
489);
490
491pub trait ShapeWithOneHole {
492    fn into_shape(self, el_count: usize) -> Result<Shape>;
493}
494
495impl<S: Into<Shape>> ShapeWithOneHole for S {
496    fn into_shape(self, _el_count: usize) -> Result<Shape> {
497        Ok(self.into())
498    }
499}
500
501impl ShapeWithOneHole for ((),) {
502    fn into_shape(self, el_count: usize) -> Result<Shape> {
503        Ok(el_count.into())
504    }
505}
506
507fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
508    if prod_d == 0 {
509        crate::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
510    }
511    if el_count % prod_d != 0 {
512        crate::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
513    }
514    Ok(el_count / prod_d)
515}
516
517impl ShapeWithOneHole for ((), usize) {
518    fn into_shape(self, el_count: usize) -> Result<Shape> {
519        let ((), d1) = self;
520        Ok((hole_size(el_count, d1, &self)?, d1).into())
521    }
522}
523
524impl ShapeWithOneHole for (usize, ()) {
525    fn into_shape(self, el_count: usize) -> Result<Shape> {
526        let (d1, ()) = self;
527        Ok((d1, hole_size(el_count, d1, &self)?).into())
528    }
529}
530
531impl ShapeWithOneHole for ((), usize, usize) {
532    fn into_shape(self, el_count: usize) -> Result<Shape> {
533        let ((), d1, d2) = self;
534        Ok((hole_size(el_count, d1 * d2, &self)?, d1, d2).into())
535    }
536}
537
538impl ShapeWithOneHole for (usize, (), usize) {
539    fn into_shape(self, el_count: usize) -> Result<Shape> {
540        let (d1, (), d2) = self;
541        Ok((d1, hole_size(el_count, d1 * d2, &self)?, d2).into())
542    }
543}
544
545impl ShapeWithOneHole for (usize, usize, ()) {
546    fn into_shape(self, el_count: usize) -> Result<Shape> {
547        let (d1, d2, ()) = self;
548        Ok((d1, d2, hole_size(el_count, d1 * d2, &self)?).into())
549    }
550}
551
552impl ShapeWithOneHole for ((), usize, usize, usize) {
553    fn into_shape(self, el_count: usize) -> Result<Shape> {
554        let ((), d1, d2, d3) = self;
555        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
556        Ok((d, d1, d2, d3).into())
557    }
558}
559
560impl ShapeWithOneHole for (usize, (), usize, usize) {
561    fn into_shape(self, el_count: usize) -> Result<Shape> {
562        let (d1, (), d2, d3) = self;
563        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
564        Ok((d1, d, d2, d3).into())
565    }
566}
567
568impl ShapeWithOneHole for (usize, usize, (), usize) {
569    fn into_shape(self, el_count: usize) -> Result<Shape> {
570        let (d1, d2, (), d3) = self;
571        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
572        Ok((d1, d2, d, d3).into())
573    }
574}
575
576impl ShapeWithOneHole for (usize, usize, usize, ()) {
577    fn into_shape(self, el_count: usize) -> Result<Shape> {
578        let (d1, d2, d3, ()) = self;
579        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
580        Ok((d1, d2, d3, d).into())
581    }
582}
583
584impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
585    fn into_shape(self, el_count: usize) -> Result<Shape> {
586        let ((), d1, d2, d3, d4) = self;
587        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
588        Ok((d, d1, d2, d3, d4).into())
589    }
590}
591
592impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
593    fn into_shape(self, el_count: usize) -> Result<Shape> {
594        let (d1, (), d2, d3, d4) = self;
595        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
596        Ok((d1, d, d2, d3, d4).into())
597    }
598}
599
600impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
601    fn into_shape(self, el_count: usize) -> Result<Shape> {
602        let (d1, d2, (), d3, d4) = self;
603        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
604        Ok((d1, d2, d, d3, d4).into())
605    }
606}
607
608impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
609    fn into_shape(self, el_count: usize) -> Result<Shape> {
610        let (d1, d2, d3, (), d4) = self;
611        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
612        Ok((d1, d2, d3, d, d4).into())
613    }
614}
615
616impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
617    fn into_shape(self, el_count: usize) -> Result<Shape> {
618        let (d1, d2, d3, d4, ()) = self;
619        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
620        Ok((d1, d2, d3, d4, d).into())
621    }
622}
623
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn stride() {
        let shape = Shape::from(());
        assert_eq!(shape.stride_contiguous(), Vec::<usize>::new());
        let shape = Shape::from(42);
        assert_eq!(shape.stride_contiguous(), [1]);
        let shape = Shape::from((42, 1337));
        assert_eq!(shape.stride_contiguous(), [1337, 1]);
        let shape = Shape::from((299, 792, 458));
        assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
    }

    #[test]
    fn contiguity() {
        let shape = Shape::from((2, 3, 4));
        assert!(shape.is_contiguous(&[12, 4, 1]));
        assert!(!shape.is_contiguous(&[1, 2, 6]));
        assert!(shape.is_fortran_contiguous(&[1, 2, 6]));
        assert!(!shape.is_fortran_contiguous(&[12, 4, 1]));
        // A stride list of the wrong length is never contiguous.
        assert!(!shape.is_contiguous(&[4, 1]));
        // The stride of a size-1 dimension is irrelevant.
        assert!(Shape::from((2, 1, 3)).is_contiguous(&[3, 42, 1]));
    }

    #[test]
    fn broadcast_binary() {
        let lhs = Shape::from((2, 1, 4));
        let rhs = Shape::from((3, 1));
        let bcast = lhs.broadcast_shape_binary_op(&rhs, "add").unwrap();
        assert_eq!(bcast, Shape::from((2, 3, 4)));
        assert!(Shape::from((2, 3))
            .broadcast_shape_binary_op(&Shape::from((4,)), "add")
            .is_err());
    }

    #[test]
    fn broadcast_matmul() {
        let lhs = Shape::from((5, 1, 2, 3));
        let rhs = Shape::from((4, 3, 7));
        let (l, r) = lhs.broadcast_shape_matmul(&rhs).unwrap();
        assert_eq!(l, Shape::from((5, 4, 2, 3)));
        assert_eq!(r, Shape::from((5, 4, 3, 7)));
        // Mismatched inner dimensions are rejected.
        assert!(Shape::from((2, 3))
            .broadcast_shape_matmul(&Shape::from((4, 5)))
            .is_err());
    }

    #[test]
    fn negative_dims() {
        let shape = Shape::from((2, 3, 4));
        assert_eq!(D::Minus1.to_index(&shape, "test").unwrap(), 2);
        assert_eq!(D::Minus2.to_index(&shape, "test").unwrap(), 1);
        assert_eq!(D::Minus(3).to_index(&shape, "test").unwrap(), 0);
        assert!(D::Minus(4).to_index(&shape, "test").is_err());
        assert!(D::Minus(0).to_index(&shape, "test").is_err());
        assert_eq!(D::Minus1.to_index_plus_one(&shape, "test").unwrap(), 3);
        assert_eq!(shape.dim(D::Minus1).unwrap(), 4);
    }

    #[test]
    fn dims_validation() {
        let shape = Shape::from((2, 3, 4));
        assert_eq!(
            (0usize, D::Minus1).to_indexes(&shape, "test").unwrap(),
            vec![0, 2]
        );
        assert!((1usize, 1usize).to_indexes(&shape, "test").is_err());
        assert!(vec![3usize].to_indexes(&shape, "test").is_err());
    }

    #[test]
    fn extract_dims() {
        let shape = Shape::from((2, 3));
        assert_eq!(shape.dims2().unwrap(), (2, 3));
        assert!(shape.dims3().is_err());
    }

    #[test]
    fn shape_with_hole() {
        assert_eq!(((), 4usize).into_shape(24).unwrap(), Shape::from((6, 4)));
        assert_eq!(
            (2usize, (), 3usize).into_shape(24).unwrap(),
            Shape::from((2, 4, 3))
        );
        assert!(((), 5usize).into_shape(24).is_err());
        assert!(((), 0usize).into_shape(24).is_err());
    }
}