burn_tensor/tensor/
shape.rs1use alloc::vec::Vec;
2
/// The size of a tensor along each of its axes.
///
/// Entry `i` of [`Shape::dims`] is the extent of axis `i`; the number of
/// entries is the tensor's rank.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Shape {
    /// Extent of each dimension, one entry per axis.
    pub dims: Vec<usize>,
}
9
10impl Shape {
11 pub fn num_elements(&self) -> usize {
13 self.dims.iter().product()
14 }
15
16 pub fn num_dims(&self) -> usize {
18 self.dims.len()
19 }
20
21 pub fn new<const D: usize>(dims: [usize; D]) -> Self {
23 Self {
25 dims: dims.to_vec(),
26 }
27 }
28
29 pub fn dims<const D: usize>(&self) -> [usize; D] {
32 let mut dims = [1; D];
33 dims[..D].copy_from_slice(&self.dims[..D]);
34 dims
35 }
36
37 pub fn flatten(&self) -> Self {
39 Self {
40 dims: [self.dims.iter().product()].into(),
41 }
42 }
43}
44
45impl<const D: usize> From<[usize; D]> for Shape {
46 fn from(dims: [usize; D]) -> Self {
47 Shape::new(dims)
48 }
49}
50
51impl From<Vec<i64>> for Shape {
52 fn from(shape: Vec<i64>) -> Self {
53 Self {
54 dims: shape.into_iter().map(|d| d as usize).collect(),
55 }
56 }
57}
58
59impl From<Vec<u64>> for Shape {
60 fn from(shape: Vec<u64>) -> Self {
61 Self {
62 dims: shape.into_iter().map(|d| d as usize).collect(),
63 }
64 }
65}
66
67impl From<Vec<usize>> for Shape {
68 fn from(shape: Vec<usize>) -> Self {
69 Self { dims: shape }
70 }
71}
72
73impl From<&Vec<usize>> for Shape {
74 fn from(shape: &Vec<usize>) -> Self {
75 Self {
76 dims: shape.clone(),
77 }
78 }
79}
80
81impl From<Shape> for Vec<usize> {
82 fn from(shape: Shape) -> Self {
83 shape.dims
84 }
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn num_elements() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(120, shape.num_elements());
    }

    /// A rank-0 shape holds exactly one element (the empty product).
    #[test]
    fn num_elements_of_scalar_shape_is_one() {
        let shape = Shape::new([]);
        assert_eq!(1, shape.num_elements());
        assert_eq!(0, shape.num_dims());
    }

    #[test]
    fn num_dims() {
        assert_eq!(4, Shape::new([2, 3, 4, 5]).num_dims());
    }

    #[test]
    fn dims_returns_leading_dimensions() {
        let shape = Shape::new([2, 3, 4]);
        assert_eq!([2, 3, 4], shape.dims::<3>());
        assert_eq!([2, 3], shape.dims::<2>());
    }

    #[test]
    fn flatten_collapses_to_single_dimension() {
        let shape = Shape::new([2, 3, 4]).flatten();
        assert_eq!(Shape::new([24]), shape);
    }

    // Only non-negative, in-range inputs are used so these checks hold for
    // every conversion path regardless of how out-of-range values are treated.
    #[test]
    fn conversions_produce_same_shape() {
        let dims = Vec::from([2usize, 3]);
        let expected = Shape::new([2, 3]);
        assert_eq!(expected, Shape::from(dims.clone()));
        assert_eq!(expected, Shape::from(&dims));
        assert_eq!(expected, Shape::from([2usize, 3]));
        assert_eq!(expected, Shape::from(Vec::from([2i64, 3])));
        assert_eq!(expected, Shape::from(Vec::from([2u64, 3])));

        let back: Vec<usize> = expected.into();
        assert_eq!(dims, back);
    }
}