Skip to main content

radiate_utils/array/
shape.rs

1#[cfg(feature = "serde")]
2use serde::{Deserialize, Serialize};
3use std::sync::Arc;
4
/// Per-axis element strides of an array, held in a shared `Arc<[usize]>`.
///
/// Cloning a `Strides` only bumps the reference count of the inner `Arc`;
/// the stride values themselves are never copied.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[repr(transparent)]
pub struct Strides(Arc<[usize]>);

impl Strides {
    /// Borrows the strides as a plain slice, one entry per axis.
    pub fn as_slice(&self) -> &[usize] {
        &self.0[..]
    }

    /// Returns the stride of axis `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is not less than the number of strides.
    pub fn stride_at(&self, index: usize) -> usize {
        let all = self.as_slice();
        all[index]
    }

    /// Borrows the strides as a slice; equivalent to [`Strides::as_slice`].
    pub fn strides(&self) -> &[usize] {
        self.as_slice()
    }
}
23
24impl From<&[usize]> for Strides {
25    fn from(strides: &[usize]) -> Self {
26        Self(Arc::from(strides))
27    }
28}
29
30impl From<Vec<usize>> for Strides {
31    fn from(strides: Vec<usize>) -> Self {
32        Self(Arc::from(strides))
33    }
34}
35
36impl From<&Shape> for Strides {
37    fn from(shape: &Shape) -> Self {
38        let rank = shape.dimensions();
39        if rank == 0 {
40            return Self(Arc::from(Vec::<usize>::new()));
41        }
42
43        let mut strides = vec![1usize; rank];
44        if rank >= 2 {
45            for i in (0..rank - 1).rev() {
46                let next = shape.dim_at(i + 1);
47                strides[i] = strides[i + 1].saturating_mul(next);
48            }
49        }
50
51        Self(Arc::from(strides))
52    }
53}
54
/// The extent of each axis of an array, held in a shared `Arc<[usize]>`.
///
/// Cloning a `Shape` only bumps the reference count of the inner `Arc`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[repr(transparent)]
pub struct Shape(Arc<[usize]>);

impl Shape {
    /// Creates a shape from anything convertible into `Arc<[usize]>`
    /// (for example `Vec<usize>` or `&[usize]`).
    pub fn new(dims: impl Into<Arc<[usize]>>) -> Self {
        Shape(dims.into())
    }

    /// Total number of elements implied by this shape.
    ///
    /// Uses saturating multiplication to avoid overflow in release builds: a product
    /// that does not fit in `usize` is reported as `usize::MAX`. Any zero-length axis
    /// yields 0. A rank-0 shape has size 1 (the empty product).
    pub fn size(&self) -> usize {
        self.0.iter().fold(1usize, |acc, &d| acc.saturating_mul(d))
    }

    /// Checked total element count.
    ///
    /// Returns `None` only when the true product of the dimensions exceeds
    /// `usize::MAX`. A shape containing a zero-length axis always yields `Some(0)`,
    /// even if the other dimensions alone would overflow. This matches [`Shape::size`].
    pub fn try_size(&self) -> Option<usize> {
        // Check for a zero axis first. Otherwise an intermediate product of the
        // preceding axes could overflow and wrongly report `None` for a shape that
        // actually holds zero elements.
        if self.0.contains(&0) {
            return Some(0);
        }
        self.0.iter().try_fold(1usize, |acc, &d| acc.checked_mul(d))
    }

    /// Number of axes; identical to [`Shape::rank`].
    pub fn dimensions(&self) -> usize {
        self.0.len()
    }

    /// Returns `true` if any axis has extent exactly `dim`.
    pub fn contains_dim(&self, dim: usize) -> bool {
        self.0.contains(&dim)
    }

    /// Extent of axis `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is not less than [`Shape::rank`].
    pub fn dim_at(&self, index: usize) -> usize {
        self.0[index]
    }

    /// Number of axes; identical to [`Shape::dimensions`].
    pub fn rank(&self) -> usize {
        self.0.len()
    }

    /// Returns `true` if the shape has no axes (rank 0).
    ///
    /// This refers to the number of axes, not the element count: a shape such as
    /// `[0, 3]` is not empty, yet has `size() == 0`.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Returns `true` only for the one-axis, one-element shape `[1]`.
    ///
    /// A rank-0 shape is *not* reported as scalar by this method.
    pub fn is_scalar(&self) -> bool {
        self.0.len() == 1 && self.0[0] == 1
    }

    /// Returns `true` if the shape has exactly one axis (this includes `[1]`).
    pub fn is_vector(&self) -> bool {
        self.0.len() == 1
    }

    /// Returns `true` if the shape has exactly two axes.
    pub fn is_matrix(&self) -> bool {
        self.0.len() == 2
    }

    /// Returns `true` if the shape has more than two axes.
    pub fn is_tensor(&self) -> bool {
        self.0.len() > 2
    }

    /// Returns `true` for a two-axis shape whose axes have equal extent.
    pub fn is_square(&self) -> bool {
        self.0.len() == 2 && self.0[0] == self.0[1]
    }

    /// Iterates over the axis extents in order.
    pub fn iter(&self) -> impl Iterator<Item = &usize> {
        self.0.iter()
    }

    /// Borrows the axis extents as a plain slice.
    pub fn as_slice(&self) -> &[usize] {
        &self.0
    }
}
130
131impl AsRef<[usize]> for Shape {
132    fn as_ref(&self) -> &[usize] {
133        self.as_slice()
134    }
135}
136
137impl AsRef<[usize]> for Strides {
138    fn as_ref(&self) -> &[usize] {
139        self.as_slice()
140    }
141}
142
143impl From<&Shape> for Shape {
144    fn from(shape: &Shape) -> Self {
145        Shape::new(Arc::clone(&shape.0))
146    }
147}
148
149impl From<Vec<i32>> for Shape {
150    fn from(dims: Vec<i32>) -> Self {
151        Shape::new(dims.into_iter().map(|d| d as usize).collect::<Vec<usize>>())
152    }
153}
154
155impl From<Vec<usize>> for Shape {
156    fn from(dims: Vec<usize>) -> Self {
157        Shape::new(dims)
158    }
159}
160
161impl From<usize> for Shape {
162    fn from(value: usize) -> Shape {
163        Shape::new(vec![value])
164    }
165}
166
167impl From<(usize, usize)> for Shape {
168    fn from(value: (usize, usize)) -> Shape {
169        Shape::new(vec![value.0, value.1])
170    }
171}
172
173impl From<(usize, usize, usize)> for Shape {
174    fn from(value: (usize, usize, usize)) -> Shape {
175        Shape::new(vec![value.0, value.1, value.2])
176    }
177}
178
179impl From<(usize, usize, usize, usize)> for Shape {
180    fn from(value: (usize, usize, usize, usize)) -> Shape {
181        Shape::new(vec![value.0, value.1, value.2, value.3])
182    }
183}
184
185impl From<(usize, usize, usize, usize, usize)> for Shape {
186    fn from(value: (usize, usize, usize, usize, usize)) -> Shape {
187        Shape::new(vec![value.0, value.1, value.2, value.3, value.4])
188    }
189}
190
191impl From<(usize, usize, usize, usize, usize, usize)> for Shape {
192    fn from(value: (usize, usize, usize, usize, usize, usize)) -> Shape {
193        Shape::new(vec![value.0, value.1, value.2, value.3, value.4, value.5])
194    }
195}
196
197impl From<(usize, usize, usize, usize, usize, usize, usize)> for Shape {
198    fn from(value: (usize, usize, usize, usize, usize, usize, usize)) -> Shape {
199        Shape::new(vec![
200            value.0, value.1, value.2, value.3, value.4, value.5, value.6,
201        ])
202    }
203}
204
205impl From<(usize, usize, usize, usize, usize, usize, usize, usize)> for Shape {
206    fn from(value: (usize, usize, usize, usize, usize, usize, usize, usize)) -> Shape {
207        Shape::new(vec![
208            value.0, value.1, value.2, value.3, value.4, value.5, value.6, value.7,
209        ])
210    }
211}
212
213impl From<&[usize]> for Shape {
214    fn from(dims: &[usize]) -> Self {
215        Shape::new(dims.to_vec())
216    }
217}
218
219// /// Compute the row-major flat index for a full N-D index (panics on mismatch/OOB).
220// #[inline]
221// pub(crate) fn flat_index_of(shape: &Shape, strides: &Strides, index: &[usize]) -> usize {
222//     assert_eq!(index.len(), shape.dimensions(), "rank mismatch");
223//     let mut flat = 0usize;
224//     for i in 0..index.len() {
225//         let dim = shape.dim_at(i);
226//         let idx = index[i];
227//         assert!(
228//             idx < dim,
229//             "index out of bounds: axis {i} idx={idx} dim={dim}"
230//         );
231//         flat = flat.saturating_add(idx.saturating_mul(strides.stride_at(i)));
232//     }
233//     flat
234// }
235
236// /// Fallible version of `flat_index_of`.
237// #[inline]
238// pub(crate) fn try_flat_index_of(
239//     shape: &Shape,
240//     strides: &Strides,
241//     index: &[usize],
242// ) -> Option<usize> {
243//     if index.len() != shape.dimensions() {
244//         return None;
245//     }
246//     let mut flat = 0usize;
247//     for i in 0..index.len() {
248//         let dim = shape.dim_at(i);
249//         let idx = index[i];
250
251//         if idx >= dim {
252//             return None;
253//         }
254
255//         flat = flat.saturating_add(idx.saturating_mul(strides.stride_at(i)));
256//     }
257
258//     Some(flat)
259// }