burn_tensor/tensor/
shape.rs

1use alloc::vec::Vec;
2
/// Describes the extent of a tensor along each of its axes.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Shape {
    /// Size of every dimension, in axis order; the length of this vector is
    /// the number of dimensions.
    pub dims: Vec<usize>,
}
9
10impl Shape {
11    /// Returns the total number of elements of a tensor having this shape
12    pub fn num_elements(&self) -> usize {
13        self.dims.iter().product()
14    }
15
16    /// Returns the number of dimensions.
17    pub fn num_dims(&self) -> usize {
18        self.dims.len()
19    }
20
21    /// Constructs a new `Shape`.
22    pub fn new<const D: usize>(dims: [usize; D]) -> Self {
23        // For backward compat
24        Self {
25            dims: dims.to_vec(),
26        }
27    }
28
29    // For compat with dims: [usize; D]
30    /// Returns the dimensions of the tensor as an array.
31    pub fn dims<const D: usize>(&self) -> [usize; D] {
32        let mut dims = [1; D];
33        dims[..D].copy_from_slice(&self.dims[..D]);
34        dims
35    }
36
37    /// Change the shape to one dimensional with the same number of elements.
38    pub fn flatten(&self) -> Self {
39        Self {
40            dims: [self.dims.iter().product()].into(),
41        }
42    }
43}
44
45impl<const D: usize> From<[usize; D]> for Shape {
46    fn from(dims: [usize; D]) -> Self {
47        Shape::new(dims)
48    }
49}
50
51impl From<Vec<i64>> for Shape {
52    fn from(shape: Vec<i64>) -> Self {
53        Self {
54            dims: shape.into_iter().map(|d| d as usize).collect(),
55        }
56    }
57}
58
59impl From<Vec<u64>> for Shape {
60    fn from(shape: Vec<u64>) -> Self {
61        Self {
62            dims: shape.into_iter().map(|d| d as usize).collect(),
63        }
64    }
65}
66
67impl From<Vec<usize>> for Shape {
68    fn from(shape: Vec<usize>) -> Self {
69        Self { dims: shape }
70    }
71}
72
73impl From<&Vec<usize>> for Shape {
74    fn from(shape: &Vec<usize>) -> Self {
75        Self {
76            dims: shape.clone(),
77        }
78    }
79}
80
81impl From<Shape> for Vec<usize> {
82    fn from(shape: Shape) -> Self {
83        shape.dims
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// The element count of a shape is the product of its dimensions.
    #[test]
    fn num_elements() {
        let shape = Shape::new([2, 3, 4, 5]);
        let expected = 2 * 3 * 4 * 5;
        assert_eq!(expected, shape.num_elements());
    }
}
97}