jax_rs/
shape.rs

1//! Shape and stride utilities for n-dimensional arrays.
2
3use std::fmt;
4
/// Shape of an n-dimensional array.
///
/// Stored as an ordered list of axis lengths. An empty list represents a
/// scalar (rank 0, exactly one element).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Shape {
    dims: Vec<usize>,
}

impl Shape {
    /// Create a new shape from dimensions.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::Shape;
    /// let shape = Shape::new(vec![2, 3, 4]);
    /// assert_eq!(shape.ndim(), 3);
    /// assert_eq!(shape.size(), 24);
    /// ```
    pub fn new(dims: Vec<usize>) -> Self {
        Self { dims }
    }

    /// Create a scalar shape (empty dimensions).
    pub fn scalar() -> Self {
        Self { dims: Vec::new() }
    }

    /// Returns the number of dimensions.
    #[inline]
    pub fn ndim(&self) -> usize {
        self.dims.len()
    }

    /// Returns the total number of elements.
    ///
    /// This is the product of all axis lengths. A scalar yields 1 (the empty
    /// product) and any shape containing a zero-length axis yields 0.
    pub fn size(&self) -> usize {
        // `Iterator::product` over an empty iterator is 1, so scalars need no
        // special case.
        self.dims.iter().product()
    }

    /// Returns a slice of the dimensions.
    #[inline]
    pub fn as_slice(&self) -> &[usize] {
        &self.dims
    }

    /// Returns true if this is a scalar shape.
    #[inline]
    pub fn is_scalar(&self) -> bool {
        self.dims.is_empty()
    }

    /// Get a specific dimension, or None if out of bounds.
    pub fn get(&self, index: usize) -> Option<usize> {
        self.dims.get(index).copied()
    }

    /// Compute default row-major (C-order) strides for this shape.
    ///
    /// The stride of the last axis is 1; the stride of every other axis is the
    /// product of the lengths of all axes after it. Strides are expressed in
    /// elements, not bytes. A scalar shape yields an empty vector.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::Shape;
    /// let shape = Shape::new(vec![2, 3, 4]);
    /// let strides = shape.default_strides();
    /// assert_eq!(strides, vec![12, 4, 1]);
    /// ```
    pub fn default_strides(&self) -> Vec<usize> {
        let ndim = self.dims.len();
        let mut strides = vec![1; ndim];
        // Walk from the innermost axis outwards, accumulating the product of
        // the axis lengths already passed. `1..ndim` is empty for rank 0 and 1.
        for axis in (1..ndim).rev() {
            strides[axis - 1] = strides[axis] * self.dims[axis];
        }
        strides
    }

    /// Check if two shapes are broadcast-compatible and return the result shape.
    ///
    /// Following NumPy broadcasting rules: shapes are aligned on their trailing
    /// axes, the shorter shape is implicitly padded with leading 1s, and two
    /// axis lengths are compatible if they are equal or one of them is 1. A
    /// length-1 axis stretches to the other length, **including 0**, so
    /// broadcasting `(0,)` with `(1,)` gives `(0,)`.
    ///
    /// Returns `None` when any aligned pair of axes is incompatible.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::Shape;
    /// let a = Shape::new(vec![3, 1]);
    /// let b = Shape::new(vec![1, 4]);
    /// assert_eq!(a.broadcast_with(&b), Some(Shape::new(vec![3, 4])));
    /// ```
    pub fn broadcast_with(&self, other: &Shape) -> Option<Shape> {
        // Broadcasting is symmetric, so start from the higher-rank shape and
        // merge the lower-rank one into its trailing axes.
        let (longer, shorter) = if self.ndim() >= other.ndim() {
            (&self.dims, &other.dims)
        } else {
            (&other.dims, &self.dims)
        };
        let offset = longer.len() - shorter.len();

        // Leading axes of the longer shape pair with an implicit 1 and are
        // therefore carried over unchanged (a leading 0 stays 0).
        let mut result = longer.clone();
        for (out, &rhs) in result[offset..].iter_mut().zip(shorter) {
            let lhs = *out;
            *out = match (lhs, rhs) {
                // A length-1 axis takes on the other length, whatever it is.
                // Using `max` here would wrongly turn (0, 1) into 1.
                (1, d) | (d, 1) => d,
                (a, b) if a == b => a,
                _ => return None, // Incompatible shapes
            };
        }

        Some(Shape::new(result))
    }
}
114
115impl From<Vec<usize>> for Shape {
116    fn from(dims: Vec<usize>) -> Self {
117        Shape::new(dims)
118    }
119}
120
121impl From<&[usize]> for Shape {
122    fn from(dims: &[usize]) -> Self {
123        Shape::new(dims.to_vec())
124    }
125}
126
127impl fmt::Display for Shape {
128    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129        write!(f, "(")?;
130        for (i, dim) in self.dims.iter().enumerate() {
131            if i > 0 {
132                write!(f, ", ")?;
133            }
134            write!(f, "{}", dim)?;
135        }
136        if self.dims.len() == 1 {
137            write!(f, ",")?;
138        }
139        write!(f, ")")
140    }
141}
142
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built shape reports its rank, element count and dimensions.
    #[test]
    fn test_shape_creation() {
        let s = Shape::new(vec![2, 3, 4]);
        assert_eq!(s.ndim(), 3);
        assert_eq!(s.size(), 24);
        assert_eq!(s.as_slice(), &[2, 3, 4]);
    }

    /// The scalar shape has rank 0 and exactly one element.
    #[test]
    fn test_scalar_shape() {
        let s = Shape::scalar();
        assert_eq!(s.ndim(), 0);
        assert_eq!(s.size(), 1);
        assert!(s.is_scalar());
    }

    /// Row-major strides for rank 3, rank 1 and rank 0.
    #[test]
    fn test_default_strides() {
        let rank3 = Shape::new(vec![2, 3, 4]);
        assert_eq!(rank3.default_strides(), vec![12, 4, 1]);

        let rank1 = Shape::new(vec![5]);
        assert_eq!(rank1.default_strides(), vec![1]);

        let rank0 = Shape::scalar();
        assert_eq!(rank0.default_strides(), Vec::<usize>::new());
    }

    /// Broadcasting: mutual stretching, rank padding, and an incompatible pair.
    #[test]
    fn test_broadcast() {
        let cases: Vec<(Vec<usize>, Vec<usize>, Option<Vec<usize>>)> = vec![
            (vec![3, 1], vec![1, 4], Some(vec![3, 4])),
            (vec![2, 3], vec![3], Some(vec![2, 3])),
            (vec![2, 3], vec![4], None), // Incompatible
        ];

        for (lhs, rhs, expected) in cases {
            let lhs = Shape::new(lhs);
            let rhs = Shape::new(rhs);
            assert_eq!(lhs.broadcast_with(&rhs), expected.map(Shape::new));
        }
    }

    /// `Display` renders shapes as Python-style tuples.
    #[test]
    fn test_display() {
        let cases = vec![
            (Shape::new(vec![2, 3, 4]), "(2, 3, 4)"),
            (Shape::new(vec![5]), "(5,)"),
            (Shape::scalar(), "()"),
        ];

        for (shape, rendered) in cases {
            assert_eq!(shape.to_string(), rendered);
        }
    }
}