ghostflow_core/shape.rs

//! Shape and stride handling for tensors

use crate::error::{GhostError, Result};
use smallvec::SmallVec;

/// Maximum dimensions for stack allocation (most tensors are <= 6D)
const MAX_INLINE_DIMS: usize = 6;

/// Shape of a tensor - dimensions along each axis
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Shape(SmallVec<[usize; MAX_INLINE_DIMS]>);

impl Shape {
    /// Create a new shape from dimensions
    pub fn new(dims: &[usize]) -> Self {
        Shape(SmallVec::from_slice(dims))
    }

    /// Create a scalar shape (0 dimensions)
    pub fn scalar() -> Self {
        Shape(SmallVec::new())
    }

    /// Number of dimensions
    pub fn ndim(&self) -> usize {
        self.0.len()
    }

    /// Total number of elements
    pub fn numel(&self) -> usize {
        self.0.iter().product()
    }

    /// Get dimension at index
    pub fn dim(&self, idx: usize) -> Option<usize> {
        self.0.get(idx).copied()
    }

    /// Get dimensions as slice
    pub fn dims(&self) -> &[usize] {
        &self.0
    }

    /// Check if this is a scalar (0D tensor)
    pub fn is_scalar(&self) -> bool {
        self.0.is_empty()
    }

    /// Compute the broadcast shape of two shapes, following NumPy-style
    /// broadcasting: trailing dimensions must match or one of them must be 1.
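    ///
    /// # Example
    ///
    /// A minimal usage sketch; the `ghostflow_core::shape` import path is an
    /// assumption (it presumes this module is public at the crate root).
    ///
    /// ```
    /// use ghostflow_core::shape::Shape; // assumed public module path
    ///
    /// // Shapes are aligned from their trailing dimensions, so
    /// // [3, 1] vs [4] becomes [3, 1] vs [1, 4] -> [3, 4].
    /// let a = Shape::new(&[3, 1]);
    /// let b = Shape::new(&[4]);
    /// assert_eq!(a.broadcast_with(&b).unwrap().dims(), &[3, 4]);
    /// ```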
    pub fn broadcast_with(&self, other: &Shape) -> Result<Shape> {
        let max_ndim = self.ndim().max(other.ndim());
        let mut result = SmallVec::with_capacity(max_ndim);

        for i in 0..max_ndim {
            let a = if i < self.ndim() {
                self.0[self.ndim() - 1 - i]
            } else {
                1
            };
            let b = if i < other.ndim() {
                other.0[other.ndim() - 1 - i]
            } else {
                1
            };

            if a == b {
                result.push(a);
            } else if a == 1 {
                result.push(b);
            } else if b == 1 {
                result.push(a);
            } else {
                return Err(GhostError::BroadcastError {
                    a: self.0.to_vec(),
                    b: other.0.to_vec(),
                });
            }
        }

        result.reverse();
        Ok(Shape(result))
    }

    /// Compute default (contiguous, row-major) strides for this shape
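    ///
    /// # Example
    ///
    /// A minimal sketch; the import path is assumed as in `broadcast_with`.
    ///
    /// ```
    /// use ghostflow_core::shape::Shape; // assumed public module path
    ///
    /// // Row-major layout: the last axis is densest (stride 1 element).
    /// let strides = Shape::new(&[2, 3]).default_strides();
    /// assert_eq!(strides.as_slice(), &[3, 1]);
    /// ```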
    pub fn default_strides(&self) -> Strides {
        if self.is_scalar() {
            return Strides::new(&[]);
        }

        let mut strides = SmallVec::with_capacity(self.ndim());
        let mut stride = 1usize;

        for &dim in self.0.iter().rev() {
            strides.push(stride);
            stride *= dim;
        }

        strides.reverse();
        Strides(strides)
    }
}

impl From<&[usize]> for Shape {
    fn from(dims: &[usize]) -> Self {
        Shape::new(dims)
    }
}

impl From<Vec<usize>> for Shape {
    fn from(dims: Vec<usize>) -> Self {
        Shape(SmallVec::from_vec(dims))
    }
}

impl std::fmt::Display for Shape {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "[")?;
        for (i, d) in self.0.iter().enumerate() {
            if i > 0 {
                write!(f, ", ")?;
            }
            write!(f, "{}", d)?;
        }
        write!(f, "]")
    }
}

/// Strides for memory layout - elements to skip for each dimension
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Strides(SmallVec<[usize; MAX_INLINE_DIMS]>);

impl Strides {
    /// Create new strides
    pub fn new(strides: &[usize]) -> Self {
        Strides(SmallVec::from_slice(strides))
    }

    /// Get stride at index
    pub fn stride(&self, idx: usize) -> Option<usize> {
        self.0.get(idx).copied()
    }

    /// Get strides as slice
    pub fn as_slice(&self) -> &[usize] {
        &self.0
    }

    /// Check if strides represent contiguous memory
    pub fn is_contiguous(&self, shape: &Shape) -> bool {
        if shape.is_scalar() {
            return true;
        }

        let expected = shape.default_strides();
        self.0 == expected.0
    }

    /// Compute the linear element offset from multi-dimensional indices
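    ///
    /// # Example
    ///
    /// A minimal sketch (import path assumed as above); the offset is the dot
    /// product of indices and strides.
    ///
    /// ```
    /// use ghostflow_core::shape::Shape; // assumed public module path
    ///
    /// // Default strides for [2, 3, 4] are [12, 4, 1], so index [1, 2, 3]
    /// // maps to 1*12 + 2*4 + 3*1 = 23 (the last element).
    /// let strides = Shape::new(&[2, 3, 4]).default_strides();
    /// assert_eq!(strides.offset(&[1, 2, 3]), 23);
    /// ```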
    pub fn offset(&self, indices: &[usize]) -> usize {
        indices
            .iter()
            .zip(self.0.iter())
            .map(|(&idx, &stride)| idx * stride)
            .sum()
    }
}

impl From<&[usize]> for Strides {
    fn from(strides: &[usize]) -> Self {
        Strides::new(strides)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shape_numel() {
        assert_eq!(Shape::new(&[2, 3, 4]).numel(), 24);
        assert_eq!(Shape::new(&[1]).numel(), 1);
        assert_eq!(Shape::scalar().numel(), 1);
    }

    #[test]
    fn test_broadcast() {
        let a = Shape::new(&[3, 1]);
        let b = Shape::new(&[1, 4]);
        let c = a.broadcast_with(&b).unwrap();
        assert_eq!(c.dims(), &[3, 4]);
    }
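
    // A sketch of the failure path: trailing dimensions that differ and are
    // both != 1 cannot be broadcast, so broadcast_with must return an error.
    #[test]
    fn test_broadcast_mismatch() {
        let a = Shape::new(&[2, 3]);
        let b = Shape::new(&[4, 5]);
        assert!(a.broadcast_with(&b).is_err());
    }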

    #[test]
    fn test_strides() {
        let shape = Shape::new(&[2, 3, 4]);
        let strides = shape.default_strides();
        assert_eq!(strides.as_slice(), &[12, 4, 1]);
    }
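
    // A sketch of contiguity checking: default strides are contiguous by
    // construction, while transposed strides ([1, 3] for a [2, 3] view) are not.
    #[test]
    fn test_is_contiguous() {
        let shape = Shape::new(&[2, 3]);
        assert!(shape.default_strides().is_contiguous(&shape));
        assert!(!Strides::new(&[1, 3]).is_contiguous(&shape));
    }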
}