Skip to main content

ferray_core/dimension/
mod.rs

1// ferray-core: Dimension trait and concrete dimension types
2//
3// These types mirror ndarray's Ix1..Ix6 and IxDyn but live in ferray-core's
4// namespace so that ndarray never appears in the public API.
5
6#[cfg(not(feature = "no_std"))]
7pub mod broadcast;
8#[cfg(feature = "const_shapes")]
9pub mod static_shape;
10
11use core::fmt;
12
13#[cfg(feature = "no_std")]
14extern crate alloc;
15#[cfg(feature = "no_std")]
16use alloc::vec::Vec;
17
18// We need ndarray's Dimension trait in scope for `as_array_view()` etc.
19#[cfg(not(feature = "no_std"))]
20use ndarray::Dimension as NdDimension;
21
22/// Trait for types that describe the dimensionality of an array.
23///
24/// Each dimension type knows its number of axes at the type level
25/// (except [`IxDyn`] which carries it at runtime).
26pub trait Dimension: Clone + PartialEq + Eq + fmt::Debug + Send + Sync + 'static {
27    /// The number of axes, or `None` for dynamic-rank arrays.
28    const NDIM: Option<usize>;
29
30    /// The corresponding `ndarray` dimension type (private, not exposed in public API).
31    #[doc(hidden)]
32    #[cfg(not(feature = "no_std"))]
33    type NdarrayDim: ndarray::Dimension;
34
35    /// Return the shape as a slice.
36    fn as_slice(&self) -> &[usize];
37
38    /// Return the shape as a mutable slice.
39    fn as_slice_mut(&mut self) -> &mut [usize];
40
41    /// Number of dimensions.
42    fn ndim(&self) -> usize {
43        self.as_slice().len()
44    }
45
46    /// Total number of elements (product of all dimension sizes).
47    fn size(&self) -> usize {
48        self.as_slice().iter().product()
49    }
50
51    /// Convert to the internal ndarray dimension type.
52    #[doc(hidden)]
53    #[cfg(not(feature = "no_std"))]
54    fn to_ndarray_dim(&self) -> Self::NdarrayDim;
55
56    /// Create from the internal ndarray dimension type.
57    #[doc(hidden)]
58    #[cfg(not(feature = "no_std"))]
59    fn from_ndarray_dim(dim: &Self::NdarrayDim) -> Self;
60}
61
62// ---------------------------------------------------------------------------
63// Fixed-rank dimension types
64// ---------------------------------------------------------------------------
65
66macro_rules! impl_fixed_dimension {
67    ($name:ident, $n:expr, $ndarray_ty:ty) => {
68        /// A fixed-rank dimension with
69        #[doc = concat!(stringify!($n), " axes.")]
70        #[derive(Clone, PartialEq, Eq, Hash)]
71        pub struct $name {
72            shape: [usize; $n],
73        }
74
75        impl $name {
76            /// Create a new dimension from a fixed-size array.
77            #[inline]
78            pub fn new(shape: [usize; $n]) -> Self {
79                Self { shape }
80            }
81        }
82
83        impl fmt::Debug for $name {
84            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85                write!(f, "{:?}", &self.shape[..])
86            }
87        }
88
89        impl From<[usize; $n]> for $name {
90            #[inline]
91            fn from(shape: [usize; $n]) -> Self {
92                Self::new(shape)
93            }
94        }
95
96        impl Dimension for $name {
97            const NDIM: Option<usize> = Some($n);
98
99            #[cfg(not(feature = "no_std"))]
100            type NdarrayDim = $ndarray_ty;
101
102            #[inline]
103            fn as_slice(&self) -> &[usize] {
104                &self.shape
105            }
106
107            #[inline]
108            fn as_slice_mut(&mut self) -> &mut [usize] {
109                &mut self.shape
110            }
111
112            #[cfg(not(feature = "no_std"))]
113            fn to_ndarray_dim(&self) -> Self::NdarrayDim {
114                // ndarray::Dim implements From<[usize; N]> for N=1..6
115                ndarray::Dim(self.shape)
116            }
117
118            #[cfg(not(feature = "no_std"))]
119            fn from_ndarray_dim(dim: &Self::NdarrayDim) -> Self {
120                let view = dim.as_array_view();
121                let s = view.as_slice().expect("ndarray dim should be contiguous");
122                let mut shape = [0usize; $n];
123                shape.copy_from_slice(s);
124                Self { shape }
125            }
126        }
127    };
128}
129
130impl_fixed_dimension!(Ix1, 1, ndarray::Ix1);
131impl_fixed_dimension!(Ix2, 2, ndarray::Ix2);
132impl_fixed_dimension!(Ix3, 3, ndarray::Ix3);
133impl_fixed_dimension!(Ix4, 4, ndarray::Ix4);
134impl_fixed_dimension!(Ix5, 5, ndarray::Ix5);
135impl_fixed_dimension!(Ix6, 6, ndarray::Ix6);
136
137// ---------------------------------------------------------------------------
138// Ix0: scalar (0-dimensional)
139// ---------------------------------------------------------------------------
140
141/// A zero-dimensional (scalar) dimension.
142#[derive(Clone, PartialEq, Eq, Hash)]
143pub struct Ix0;
144
145impl fmt::Debug for Ix0 {
146    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
147        write!(f, "[]")
148    }
149}
150
151impl Dimension for Ix0 {
152    const NDIM: Option<usize> = Some(0);
153
154    #[cfg(not(feature = "no_std"))]
155    type NdarrayDim = ndarray::Ix0;
156
157    #[inline]
158    fn as_slice(&self) -> &[usize] {
159        &[]
160    }
161
162    #[inline]
163    fn as_slice_mut(&mut self) -> &mut [usize] {
164        &mut []
165    }
166
167    #[cfg(not(feature = "no_std"))]
168    fn to_ndarray_dim(&self) -> Self::NdarrayDim {
169        ndarray::Dim(())
170    }
171
172    #[cfg(not(feature = "no_std"))]
173    fn from_ndarray_dim(_dim: &Self::NdarrayDim) -> Self {
174        Ix0
175    }
176}
177
178// ---------------------------------------------------------------------------
179// IxDyn: dynamic-rank dimension
180// ---------------------------------------------------------------------------
181
182/// A dynamic-rank dimension whose number of axes is determined at runtime.
183#[derive(Clone, PartialEq, Eq, Hash)]
184pub struct IxDyn {
185    shape: Vec<usize>,
186}
187
188impl IxDyn {
189    /// Create a new dynamic dimension from a slice.
190    pub fn new(shape: &[usize]) -> Self {
191        Self {
192            shape: shape.to_vec(),
193        }
194    }
195}
196
197impl fmt::Debug for IxDyn {
198    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199        write!(f, "{:?}", &self.shape[..])
200    }
201}
202
203impl From<Vec<usize>> for IxDyn {
204    fn from(shape: Vec<usize>) -> Self {
205        Self { shape }
206    }
207}
208
209impl From<&[usize]> for IxDyn {
210    fn from(shape: &[usize]) -> Self {
211        Self::new(shape)
212    }
213}
214
215impl Dimension for IxDyn {
216    const NDIM: Option<usize> = None;
217
218    #[cfg(not(feature = "no_std"))]
219    type NdarrayDim = ndarray::IxDyn;
220
221    #[inline]
222    fn as_slice(&self) -> &[usize] {
223        &self.shape
224    }
225
226    #[inline]
227    fn as_slice_mut(&mut self) -> &mut [usize] {
228        &mut self.shape
229    }
230
231    #[cfg(not(feature = "no_std"))]
232    fn to_ndarray_dim(&self) -> Self::NdarrayDim {
233        ndarray::IxDyn(&self.shape)
234    }
235
236    #[cfg(not(feature = "no_std"))]
237    fn from_ndarray_dim(dim: &Self::NdarrayDim) -> Self {
238        let view = dim.as_array_view();
239        let s = view.as_slice().expect("ndarray IxDyn should be contiguous");
240        Self { shape: s.to_vec() }
241    }
242}
243
/// Newtype for axis indices used throughout ferray.
///
/// Wrapping the raw `usize` keeps axis arguments visually distinct from
/// lengths and element indices at call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Axis(pub usize);

impl Axis {
    /// Return the wrapped axis index.
    #[inline]
    pub fn index(self) -> usize {
        let Axis(axis) = self;
        axis
    }
}
255
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ix1_basics() {
        let d = Ix1::new([5]);
        assert_eq!(d.ndim(), 1);
        assert_eq!(d.size(), 5);
        assert_eq!(d.as_slice(), &[5]);
    }

    #[test]
    fn ix2_basics() {
        let d = Ix2::new([3, 4]);
        assert_eq!(d.ndim(), 2);
        assert_eq!(d.size(), 12);
    }

    #[test]
    fn ix0_basics() {
        let d = Ix0;
        assert_eq!(d.ndim(), 0);
        assert_eq!(d.size(), 1);
    }

    #[test]
    fn ixdyn_basics() {
        let d = IxDyn::new(&[2, 3, 4]);
        assert_eq!(d.ndim(), 3);
        assert_eq!(d.size(), 24);
    }

    #[test]
    fn ixdyn_rank_zero_has_one_element() {
        let d = IxDyn::new(&[]);
        assert_eq!(d.ndim(), 0);
        assert_eq!(d.size(), 1);
    }

    #[test]
    fn zero_length_axis_gives_zero_size() {
        let d = Ix3::new([2, 0, 4]);
        assert_eq!(d.ndim(), 3);
        assert_eq!(d.size(), 0);
    }

    #[test]
    fn ndim_constants() {
        assert_eq!(Ix0::NDIM, Some(0));
        assert_eq!(Ix1::NDIM, Some(1));
        assert_eq!(Ix6::NDIM, Some(6));
        assert_eq!(IxDyn::NDIM, None);
        assert_eq!(Ix6::new([1, 2, 3, 4, 5, 6]).ndim(), 6);
    }

    #[test]
    fn as_slice_mut_writes_through() {
        let mut d = Ix2::new([3, 4]);
        d.as_slice_mut()[1] = 9;
        assert_eq!(d.as_slice(), &[3, 9]);
    }

    #[test]
    fn from_conversions() {
        assert_eq!(Ix2::from([3, 4]), Ix2::new([3, 4]));
        let s: &[usize] = &[1, 2];
        assert_eq!(IxDyn::from(s), IxDyn::new(&[1, 2]));
    }

    // The ndarray conversions and `format!` only exist when the `no_std`
    // feature is off; gate these so `cargo test --features no_std` compiles.
    #[cfg(not(feature = "no_std"))]
    #[test]
    fn roundtrip_ix2_ndarray() {
        let d = Ix2::new([3, 7]);
        let nd = d.to_ndarray_dim();
        let d2 = Ix2::from_ndarray_dim(&nd);
        assert_eq!(d, d2);
    }

    #[cfg(not(feature = "no_std"))]
    #[test]
    fn roundtrip_ix0_ndarray() {
        let nd = Ix0.to_ndarray_dim();
        assert_eq!(Ix0::from_ndarray_dim(&nd), Ix0);
    }

    #[cfg(not(feature = "no_std"))]
    #[test]
    fn roundtrip_ixdyn_ndarray() {
        let d = IxDyn::new(&[2, 5, 3]);
        let nd = d.to_ndarray_dim();
        let d2 = IxDyn::from_ndarray_dim(&nd);
        assert_eq!(d, d2);
    }

    #[cfg(not(feature = "no_std"))]
    #[test]
    fn debug_formats_like_a_slice() {
        assert_eq!(format!("{:?}", Ix2::new([3, 4])), "[3, 4]");
        assert_eq!(format!("{:?}", Ix0), "[]");
        assert_eq!(format!("{:?}", IxDyn::new(&[2, 5])), "[2, 5]");
    }

    #[test]
    fn axis_index() {
        let a = Axis(2);
        assert_eq!(a.index(), 2);
    }
}