Skip to main content

ferray_core/dimension/
mod.rs

// ferray-core: Dimension trait and concrete dimension types
//
// These types mirror ndarray's Ix0..Ix6 and IxDyn but live in ferray-core's
// namespace so that ndarray never appears in the public API.
5
6#[cfg(feature = "std")]
7pub mod broadcast;
8// static_shape depends on Array, Vec, format!, and the ndarray-gated methods
9// on Dimension, so it's only available when std is in scope.
10#[cfg(all(feature = "const_shapes", feature = "std"))]
11pub mod static_shape;
12
13use core::fmt;
14
15#[cfg(not(feature = "std"))]
16extern crate alloc;
17#[cfg(not(feature = "std"))]
18use alloc::vec::Vec;
19
20// We need ndarray's Dimension trait in scope for `as_array_view()` etc.
21#[cfg(feature = "std")]
22use ndarray::Dimension as NdDimension;
23
/// Describes the shape (the list of axis lengths) of an array.
///
/// Fixed-rank implementors advertise their number of axes at the type level
/// through [`Dimension::NDIM`]; [`IxDyn`] is the exception and only knows its
/// rank at runtime.
pub trait Dimension: Clone + PartialEq + Eq + fmt::Debug + Send + Sync + 'static {
    /// Rank fixed by the type, or `None` when the rank is only known at runtime.
    const NDIM: Option<usize>;

    /// The matching `ndarray` dimension type (internal detail, hidden from docs).
    #[doc(hidden)]
    #[cfg(feature = "std")]
    type NdarrayDim: ndarray::Dimension;

    /// Borrow the axis lengths, outermost axis first.
    fn as_slice(&self) -> &[usize];

    /// Mutably borrow the axis lengths.
    fn as_slice_mut(&mut self) -> &mut [usize];

    /// Number of axes (the rank): the length of [`Dimension::as_slice`].
    fn ndim(&self) -> usize {
        let axes = self.as_slice();
        axes.len()
    }

    /// Total number of elements: the product of every axis length.
    ///
    /// A rank-0 shape yields `1` (the empty product), and any zero-length axis
    /// yields `0`. The multiplication is not overflow-checked beyond Rust's
    /// usual integer rules (panic in debug builds, wrap in release builds).
    fn size(&self) -> usize {
        self.as_slice().iter().fold(1, |total, &len| total * len)
    }

    /// Convert into the internal ndarray dimension type.
    #[doc(hidden)]
    #[cfg(feature = "std")]
    fn to_ndarray_dim(&self) -> Self::NdarrayDim;

    /// Build from the internal ndarray dimension type.
    #[doc(hidden)]
    #[cfg(feature = "std")]
    fn from_ndarray_dim(dim: &Self::NdarrayDim) -> Self;

    /// Build a dimension from a slice of axis lengths.
    ///
    /// Fixed-rank implementations return `None` when `shape.len()` differs
    /// from `Self::NDIM`; [`IxDyn`] accepts any length.
    ///
    /// This is the inverse of [`Dimension::as_slice`].
    fn from_dim_slice(shape: &[usize]) -> Option<Self>;
}
71
// ---------------------------------------------------------------------------
// Fixed-rank dimension types
// ---------------------------------------------------------------------------

// Generates one fixed-rank dimension type.
//
// * `$name`       - name of the generated struct (e.g. `Ix2`).
// * `$n`          - number of axes; the length of the backing `[usize; $n]`.
// * `$ndarray_ty` - the matching ndarray dimension type; it is only referenced
//                   by items compiled under the `std` feature.
macro_rules! impl_fixed_dimension {
    ($name:ident, $n:expr, $ndarray_ty:ty) => {
        // "of rank N" reads correctly for every N (the old "with N axes"
        // produced "with 1 axes" for `Ix1`).
        #[doc = concat!("A fixed-rank dimension of rank ", stringify!($n), ".")]
        ///
        /// The axis lengths are stored inline in a fixed-size `[usize; N]` array.
        #[derive(Clone, PartialEq, Eq, Hash)]
        pub struct $name {
            shape: [usize; $n],
        }

        impl $name {
            /// Create a new dimension from a fixed-size array of axis lengths.
            #[inline]
            #[must_use]
            pub const fn new(shape: [usize; $n]) -> Self {
                Self { shape }
            }
        }

        impl fmt::Debug for $name {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                // Print only the axis lengths, e.g. `[3, 4]`.
                write!(f, "{:?}", &self.shape[..])
            }
        }

        impl From<[usize; $n]> for $name {
            #[inline]
            fn from(shape: [usize; $n]) -> Self {
                Self::new(shape)
            }
        }

        impl Dimension for $name {
            const NDIM: Option<usize> = Some($n);

            #[cfg(feature = "std")]
            type NdarrayDim = $ndarray_ty;

            #[inline]
            fn as_slice(&self) -> &[usize] {
                &self.shape
            }

            #[inline]
            fn as_slice_mut(&mut self) -> &mut [usize] {
                &mut self.shape
            }

            #[cfg(feature = "std")]
            fn to_ndarray_dim(&self) -> Self::NdarrayDim {
                // `ndarray::Dim` here is ndarray's constructor function; given a
                // `[usize; N]` it produces the matching `Dim<[usize; N]>`.
                ndarray::Dim(self.shape)
            }

            #[cfg(feature = "std")]
            fn from_ndarray_dim(dim: &Self::NdarrayDim) -> Self {
                // ndarray dimension types implement `Index<usize, Output = usize>`,
                // so each axis length can be read directly. This avoids building an
                // array view and an `expect` on its contiguity.
                let mut shape = [0usize; $n];
                for (axis, len) in shape.iter_mut().enumerate() {
                    *len = dim[axis];
                }
                Self { shape }
            }

            fn from_dim_slice(shape: &[usize]) -> Option<Self> {
                // Fixed rank: reject slices whose length differs from `$n`.
                if shape.len() != $n {
                    return None;
                }
                let mut arr = [0usize; $n];
                arr.copy_from_slice(shape);
                Some(Self::new(arr))
            }
        }
    };
}
148
149impl_fixed_dimension!(Ix1, 1, ndarray::Ix1);
150impl_fixed_dimension!(Ix2, 2, ndarray::Ix2);
151impl_fixed_dimension!(Ix3, 3, ndarray::Ix3);
152impl_fixed_dimension!(Ix4, 4, ndarray::Ix4);
153impl_fixed_dimension!(Ix5, 5, ndarray::Ix5);
154impl_fixed_dimension!(Ix6, 6, ndarray::Ix6);
155
// ---------------------------------------------------------------------------
// Ix0: scalar (0-dimensional)
// ---------------------------------------------------------------------------

/// The dimension of a zero-dimensional (scalar) array, which has no axes.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Ix0;

impl fmt::Debug for Ix0 {
    /// Formats as an empty list, `[]`, since a scalar has no axis lengths.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("[]")
    }
}
169
170impl Dimension for Ix0 {
171    const NDIM: Option<usize> = Some(0);
172
173    #[cfg(feature = "std")]
174    type NdarrayDim = ndarray::Ix0;
175
176    #[inline]
177    fn as_slice(&self) -> &[usize] {
178        &[]
179    }
180
181    #[inline]
182    fn as_slice_mut(&mut self) -> &mut [usize] {
183        &mut []
184    }
185
186    #[cfg(feature = "std")]
187    fn to_ndarray_dim(&self) -> Self::NdarrayDim {
188        ndarray::Dim(())
189    }
190
191    #[cfg(feature = "std")]
192    fn from_ndarray_dim(_dim: &Self::NdarrayDim) -> Self {
193        Self
194    }
195
196    fn from_dim_slice(shape: &[usize]) -> Option<Self> {
197        if shape.is_empty() { Some(Self) } else { None }
198    }
199}
200
// ---------------------------------------------------------------------------
// IxDyn: dynamic-rank dimension
// ---------------------------------------------------------------------------

/// A dimension whose rank (number of axes) is only known at runtime.
///
/// The axis lengths are stored in an owned `Vec<usize>`.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct IxDyn {
    shape: Vec<usize>,
}

impl IxDyn {
    /// Build a dynamic-rank dimension by copying the given axis lengths.
    #[must_use]
    pub fn new(shape: &[usize]) -> Self {
        let shape = Vec::from(shape);
        Self { shape }
    }
}

impl fmt::Debug for IxDyn {
    /// Prints only the axis lengths, e.g. `[2, 3, 4]`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{:?}", self.shape.as_slice())
    }
}

impl From<Vec<usize>> for IxDyn {
    /// Takes ownership of `shape` as the axis lengths, without copying.
    fn from(shape: Vec<usize>) -> Self {
        IxDyn { shape }
    }
}

impl From<&[usize]> for IxDyn {
    /// Copies the borrowed axis lengths; equivalent to [`IxDyn::new`].
    fn from(shape: &[usize]) -> Self {
        Self {
            shape: shape.to_owned(),
        }
    }
}
238
239impl Dimension for IxDyn {
240    const NDIM: Option<usize> = None;
241
242    #[cfg(feature = "std")]
243    type NdarrayDim = ndarray::IxDyn;
244
245    #[inline]
246    fn as_slice(&self) -> &[usize] {
247        &self.shape
248    }
249
250    #[inline]
251    fn as_slice_mut(&mut self) -> &mut [usize] {
252        &mut self.shape
253    }
254
255    #[cfg(feature = "std")]
256    fn to_ndarray_dim(&self) -> Self::NdarrayDim {
257        ndarray::IxDyn(&self.shape)
258    }
259
260    #[cfg(feature = "std")]
261    fn from_ndarray_dim(dim: &Self::NdarrayDim) -> Self {
262        let view = dim.as_array_view();
263        let s = view.as_slice().expect("ndarray IxDyn should be contiguous");
264        Self { shape: s.to_vec() }
265    }
266
267    fn from_dim_slice(shape: &[usize]) -> Option<Self> {
268        Some(Self {
269            shape: shape.to_vec(),
270        })
271    }
272}
273
/// Newtype wrapping an axis index, used throughout ferray.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Axis(pub usize);

impl Axis {
    /// Return the wrapped axis index.
    #[inline]
    #[must_use]
    pub const fn index(self) -> usize {
        let Self(axis) = self;
        axis
    }
}
286
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ix1_basics() {
        let d = Ix1::new([5]);
        assert_eq!(d.ndim(), 1);
        assert_eq!(d.size(), 5);
        assert_eq!(d.as_slice(), &[5]);
    }

    #[test]
    fn ix2_basics() {
        let d = Ix2::new([3, 4]);
        assert_eq!(d.ndim(), 2);
        assert_eq!(d.size(), 12);
    }

    #[test]
    fn ix0_basics() {
        let d = Ix0;
        assert_eq!(d.ndim(), 0);
        // Empty product: a scalar holds exactly one element.
        assert_eq!(d.size(), 1);
    }

    #[test]
    fn ixdyn_basics() {
        let d = IxDyn::new(&[2, 3, 4]);
        assert_eq!(d.ndim(), 3);
        assert_eq!(d.size(), 24);
    }

    #[test]
    fn zero_length_axis_gives_zero_size() {
        assert_eq!(Ix3::new([2, 0, 5]).size(), 0);
        assert_eq!(IxDyn::new(&[4, 0]).size(), 0);
    }

    #[test]
    fn ndim_constants() {
        assert_eq!(Ix0::NDIM, Some(0));
        assert_eq!(Ix1::NDIM, Some(1));
        assert_eq!(Ix6::NDIM, Some(6));
        assert_eq!(IxDyn::NDIM, None);
    }

    #[test]
    fn as_slice_mut_updates_shape() {
        let mut d = Ix2::new([3, 4]);
        d.as_slice_mut()[1] = 9;
        assert_eq!(d.as_slice(), &[3, 9]);
        assert_eq!(d.size(), 27);
    }

    #[test]
    fn fixed_from_dim_slice_checks_rank() {
        assert_eq!(Ix2::from_dim_slice(&[3, 4]), Some(Ix2::new([3, 4])));
        assert_eq!(Ix2::from_dim_slice(&[3]), None);
        assert_eq!(Ix2::from_dim_slice(&[3, 4, 5]), None);
    }

    #[test]
    fn ix0_from_dim_slice() {
        assert_eq!(Ix0::from_dim_slice(&[]), Some(Ix0));
        assert_eq!(Ix0::from_dim_slice(&[1]), None);
    }

    #[test]
    fn ixdyn_from_dim_slice_accepts_any_rank() {
        assert_eq!(IxDyn::from_dim_slice(&[]), Some(IxDyn::new(&[])));
        assert_eq!(IxDyn::from_dim_slice(&[2, 3]), Some(IxDyn::new(&[2, 3])));
    }

    #[test]
    fn axis_index() {
        let a = Axis(2);
        assert_eq!(a.index(), 2);
    }

    // The ndarray conversion methods only exist with the `std` feature, so the
    // roundtrip tests must be gated too; otherwise `cargo test
    // --no-default-features` fails to compile this module.
    #[cfg(feature = "std")]
    #[test]
    fn roundtrip_ix2_ndarray() {
        let d = Ix2::new([3, 7]);
        let nd = d.to_ndarray_dim();
        let d2 = Ix2::from_ndarray_dim(&nd);
        assert_eq!(d, d2);
    }

    #[cfg(feature = "std")]
    #[test]
    fn roundtrip_ix6_ndarray() {
        let d = Ix6::new([1, 2, 3, 4, 5, 6]);
        assert_eq!(Ix6::from_ndarray_dim(&d.to_ndarray_dim()), d);
    }

    #[cfg(feature = "std")]
    #[test]
    fn roundtrip_ix0_ndarray() {
        let nd = Ix0.to_ndarray_dim();
        assert_eq!(Ix0::from_ndarray_dim(&nd), Ix0);
    }

    #[cfg(feature = "std")]
    #[test]
    fn roundtrip_ixdyn_ndarray() {
        let d = IxDyn::new(&[2, 5, 3]);
        let nd = d.to_ndarray_dim();
        let d2 = IxDyn::from_ndarray_dim(&nd);
        assert_eq!(d, d2);

        let empty = IxDyn::new(&[]);
        assert_eq!(IxDyn::from_ndarray_dim(&empty.to_ndarray_dim()), empty);
    }
}