Skip to main content

ferray_core/dimension/
mod.rs

1// ferray-core: Dimension trait and concrete dimension types
2//
3// These types mirror ndarray's Ix1..Ix6 and IxDyn but live in ferray-core's
4// namespace so that ndarray never appears in the public API.
5
6#[cfg(feature = "std")]
7pub mod broadcast;
8// static_shape depends on Array, Vec, format!, and the ndarray-gated methods
9// on Dimension, so it's only available when std is in scope.
10#[cfg(all(feature = "const_shapes", feature = "std"))]
11pub mod static_shape;
12
13use core::fmt;
14
15#[cfg(not(feature = "std"))]
16extern crate alloc;
17#[cfg(not(feature = "std"))]
18use alloc::vec::Vec;
19
20// We need ndarray's Dimension trait in scope for `as_array_view()` etc.
21#[cfg(feature = "std")]
22use ndarray::Dimension as NdDimension;
23
/// Trait for types that describe the dimensionality of an array.
///
/// Each dimension type knows its number of axes at the type level
/// (except [`IxDyn`] which carries it at runtime).
pub trait Dimension: Clone + PartialEq + Eq + fmt::Debug + Send + Sync + 'static {
    /// The number of axes, or `None` for dynamic-rank arrays.
    const NDIM: Option<usize>;

    /// The corresponding `ndarray` dimension type (private, not exposed in public API).
    #[doc(hidden)]
    #[cfg(feature = "std")]
    type NdarrayDim: ndarray::Dimension;

    /// Dimension type produced by removing one axis (#349). Mirrors
    /// ndarray's `RemoveAxis::Smaller`. Used to express the rank of
    /// `index_axis` and `remove_axis` results at the type level so
    /// `Array<T, Ix2>::index_axis(...)` returns `ArrayView<T, Ix1>`
    /// instead of `ArrayView<T, IxDyn>`.
    ///
    /// Saturation rules:
    /// - `Ix0::Smaller = Ix0` (no negative ranks; calling
    ///   `remove_axis` on a scalar is a runtime error).
    /// - `IxDyn::Smaller = IxDyn` (closed under the operation).
    type Smaller: Dimension;

    /// Dimension type produced by inserting one axis (#349). Mirrors
    /// ndarray's `Larger`. Used by `insert_axis` to preserve compile-
    /// time rank.
    ///
    /// Saturation rules:
    /// - `Ix6::Larger = IxDyn` (we don't expose Ix7 / Ix8 / etc.; the
    ///   insertion point hops to dynamic rank for higher dimensions).
    /// - `IxDyn::Larger = IxDyn` (closed).
    type Larger: Dimension;

    /// Return the shape (one length per axis) as a slice.
    fn as_slice(&self) -> &[usize];

    /// Return the shape as a mutable slice so axis lengths can be changed
    /// in place. The number of axes cannot change through this slice.
    fn as_slice_mut(&mut self) -> &mut [usize];

    /// Number of dimensions (axes).
    fn ndim(&self) -> usize {
        self.as_slice().len()
    }

    /// Total number of elements (product of all axis lengths).
    ///
    /// A shape with no axes has size 1 (the empty product).
    ///
    /// The product is **not** overflow-checked. Like any integer
    /// multiplication, it panics when overflow checks are enabled (the
    /// default for debug builds) and wraps silently otherwise. Prefer
    /// [`Dimension::size_checked`] for shapes that may come from untrusted
    /// input.
    fn size(&self) -> usize {
        self.as_slice().iter().product()
    }

    /// Total number of elements, or `None` if the count does not fit in
    /// `usize`.
    ///
    /// A zero-length axis makes the size `0` regardless of the other axes.
    /// For example, `[usize::MAX, 2, 0]` yields `Some(0)` instead of
    /// reporting an overflow caused only by an intermediate product.
    fn size_checked(&self) -> Option<usize> {
        let shape = self.as_slice();
        // Short-circuit: an empty axis means an empty array, even if the
        // other lengths would overflow when multiplied together.
        if shape.contains(&0) {
            return Some(0);
        }
        shape
            .iter()
            .try_fold(1usize, |acc, &len| acc.checked_mul(len))
    }

    /// Convert to the internal ndarray dimension type.
    #[doc(hidden)]
    #[cfg(feature = "std")]
    fn to_ndarray_dim(&self) -> Self::NdarrayDim;

    /// Create from the internal ndarray dimension type.
    #[doc(hidden)]
    #[cfg(feature = "std")]
    fn from_ndarray_dim(dim: &Self::NdarrayDim) -> Self;

    /// Construct a dimension from a slice of axis lengths.
    ///
    /// Returns `None` if the slice length does not match `Self::NDIM`
    /// for fixed-rank dimensions. Always succeeds for [`IxDyn`].
    ///
    /// This is the inverse of [`Dimension::as_slice`].
    fn from_dim_slice(shape: &[usize]) -> Option<Self>;
}
93
94// ---------------------------------------------------------------------------
95// Fixed-rank dimension types
96// ---------------------------------------------------------------------------
97
98macro_rules! impl_fixed_dimension {
99    ($name:ident, $n:expr, $ndarray_ty:ty, $smaller:ty, $larger:ty) => {
100        /// A fixed-rank dimension with
101        #[doc = concat!(stringify!($n), " axes.")]
102        #[derive(Clone, PartialEq, Eq, Hash)]
103        pub struct $name {
104            shape: [usize; $n],
105        }
106
107        impl $name {
108            /// Create a new dimension from a fixed-size array.
109            #[inline]
110            pub const fn new(shape: [usize; $n]) -> Self {
111                Self { shape }
112            }
113        }
114
115        impl fmt::Debug for $name {
116            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
117                write!(f, "{:?}", &self.shape[..])
118            }
119        }
120
121        impl From<[usize; $n]> for $name {
122            #[inline]
123            fn from(shape: [usize; $n]) -> Self {
124                Self::new(shape)
125            }
126        }
127
128        impl Dimension for $name {
129            const NDIM: Option<usize> = Some($n);
130
131            #[cfg(feature = "std")]
132            type NdarrayDim = $ndarray_ty;
133
134            type Smaller = $smaller;
135            type Larger = $larger;
136
137            #[inline]
138            fn as_slice(&self) -> &[usize] {
139                &self.shape
140            }
141
142            #[inline]
143            fn as_slice_mut(&mut self) -> &mut [usize] {
144                &mut self.shape
145            }
146
147            #[cfg(feature = "std")]
148            fn to_ndarray_dim(&self) -> Self::NdarrayDim {
149                // ndarray::Dim implements From<[usize; N]> for N=1..6
150                ndarray::Dim(self.shape)
151            }
152
153            #[cfg(feature = "std")]
154            fn from_ndarray_dim(dim: &Self::NdarrayDim) -> Self {
155                let view = dim.as_array_view();
156                let s = view.as_slice().expect("ndarray dim should be contiguous");
157                let mut shape = [0usize; $n];
158                shape.copy_from_slice(s);
159                Self { shape }
160            }
161
162            fn from_dim_slice(shape: &[usize]) -> Option<Self> {
163                if shape.len() != $n {
164                    return None;
165                }
166                let mut arr = [0usize; $n];
167                arr.copy_from_slice(shape);
168                Some(Self { shape: arr })
169            }
170        }
171    };
172}
173
174// Smaller / Larger relationships per #349:
175//   Ix1 → Smaller=Ix0,   Larger=Ix2
176//   Ix2 → Smaller=Ix1,   Larger=Ix3
177//   ...
178//   Ix6 → Smaller=Ix5,   Larger=IxDyn (no Ix7 type; saturate to dyn rank)
179impl_fixed_dimension!(Ix1, 1, ndarray::Ix1, Ix0, Ix2);
180impl_fixed_dimension!(Ix2, 2, ndarray::Ix2, Ix1, Ix3);
181impl_fixed_dimension!(Ix3, 3, ndarray::Ix3, Ix2, Ix4);
182impl_fixed_dimension!(Ix4, 4, ndarray::Ix4, Ix3, Ix5);
183impl_fixed_dimension!(Ix5, 5, ndarray::Ix5, Ix4, Ix6);
184impl_fixed_dimension!(Ix6, 6, ndarray::Ix6, Ix5, IxDyn);
185
186// ---------------------------------------------------------------------------
187// Ix0: scalar (0-dimensional)
188// ---------------------------------------------------------------------------
189
/// The shape of a zero-dimensional (scalar) array: it has no axes.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Ix0;

impl fmt::Debug for Ix0 {
    /// Prints as an empty list, `[]`, since a scalar shape has no axis lengths.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("[]")
    }
}
199
200impl Dimension for Ix0 {
201    const NDIM: Option<usize> = Some(0);
202
203    #[cfg(feature = "std")]
204    type NdarrayDim = ndarray::Ix0;
205
206    // Saturation: removing an axis from a scalar is a runtime error
207    // anyway, but the type still has to map somewhere. Stay at Ix0.
208    type Smaller = Ix0;
209    type Larger = Ix1;
210
211    #[inline]
212    fn as_slice(&self) -> &[usize] {
213        &[]
214    }
215
216    #[inline]
217    fn as_slice_mut(&mut self) -> &mut [usize] {
218        &mut []
219    }
220
221    #[cfg(feature = "std")]
222    fn to_ndarray_dim(&self) -> Self::NdarrayDim {
223        ndarray::Dim(())
224    }
225
226    #[cfg(feature = "std")]
227    fn from_ndarray_dim(_dim: &Self::NdarrayDim) -> Self {
228        Self
229    }
230
231    fn from_dim_slice(shape: &[usize]) -> Option<Self> {
232        if shape.is_empty() { Some(Self) } else { None }
233    }
234}
235
236// ---------------------------------------------------------------------------
237// IxDyn: dynamic-rank dimension
238// ---------------------------------------------------------------------------
239
/// A dimension whose number of axes is only known at runtime.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct IxDyn {
    shape: Vec<usize>,
}

impl IxDyn {
    /// Build a dynamic-rank dimension by copying the given axis lengths.
    #[must_use]
    pub fn new(shape: &[usize]) -> Self {
        let owned: Vec<usize> = shape.to_vec();
        Self::from(owned)
    }
}

impl fmt::Debug for IxDyn {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Print like a plain list of axis lengths, e.g. `[2, 3]`.
        write!(f, "{:?}", self.shape.as_slice())
    }
}

impl From<Vec<usize>> for IxDyn {
    /// Take ownership of the axis lengths without copying them.
    fn from(shape: Vec<usize>) -> Self {
        IxDyn { shape }
    }
}

impl From<&[usize]> for IxDyn {
    /// Copy the borrowed axis lengths; equivalent to [`IxDyn::new`].
    fn from(shape: &[usize]) -> Self {
        IxDyn::new(shape)
    }
}
273
274impl Dimension for IxDyn {
275    const NDIM: Option<usize> = None;
276
277    #[cfg(feature = "std")]
278    type NdarrayDim = ndarray::IxDyn;
279
280    // Closed under Smaller / Larger: IxDyn handles any rank at runtime.
281    type Smaller = IxDyn;
282    type Larger = IxDyn;
283
284    #[inline]
285    fn as_slice(&self) -> &[usize] {
286        &self.shape
287    }
288
289    #[inline]
290    fn as_slice_mut(&mut self) -> &mut [usize] {
291        &mut self.shape
292    }
293
294    #[cfg(feature = "std")]
295    fn to_ndarray_dim(&self) -> Self::NdarrayDim {
296        ndarray::IxDyn(&self.shape)
297    }
298
299    #[cfg(feature = "std")]
300    fn from_ndarray_dim(dim: &Self::NdarrayDim) -> Self {
301        let view = dim.as_array_view();
302        let s = view.as_slice().expect("ndarray IxDyn should be contiguous");
303        Self { shape: s.to_vec() }
304    }
305
306    fn from_dim_slice(shape: &[usize]) -> Option<Self> {
307        Some(Self {
308            shape: shape.to_vec(),
309        })
310    }
311}
312
/// Strongly-typed axis index used throughout ferray.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Axis(pub usize);

impl Axis {
    /// Unwrap and return the plain `usize` axis position.
    #[inline]
    #[must_use]
    pub const fn index(self) -> usize {
        let Self(position) = self;
        position
    }
}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `A` and `B` are the same concrete type. Used to
    /// check the `Smaller` / `Larger` associated-type mapping.
    fn same_type<A: 'static, B: 'static>() -> bool {
        core::any::TypeId::of::<A>() == core::any::TypeId::of::<B>()
    }

    #[test]
    fn ix1_basics() {
        let d = Ix1::new([5]);
        assert_eq!(d.ndim(), 1);
        assert_eq!(d.size(), 5);
        assert_eq!(d.as_slice(), &[5]);
    }

    #[test]
    fn ix2_basics() {
        let d = Ix2::new([3, 4]);
        assert_eq!(d.ndim(), 2);
        assert_eq!(d.size(), 12);
    }

    #[test]
    fn ix0_basics() {
        let d = Ix0;
        assert_eq!(d.ndim(), 0);
        assert_eq!(d.size(), 1);
    }

    #[test]
    fn ixdyn_basics() {
        let d = IxDyn::new(&[2, 3, 4]);
        assert_eq!(d.ndim(), 3);
        assert_eq!(d.size(), 24);
    }

    #[test]
    fn ndim_constants() {
        assert_eq!(Ix0::NDIM, Some(0));
        assert_eq!(Ix1::NDIM, Some(1));
        assert_eq!(Ix6::NDIM, Some(6));
        assert_eq!(IxDyn::NDIM, None);
    }

    #[test]
    fn from_dim_slice_checks_rank() {
        assert_eq!(Ix2::from_dim_slice(&[3, 4]), Some(Ix2::new([3, 4])));
        assert_eq!(Ix2::from_dim_slice(&[3]), None);
        assert_eq!(Ix2::from_dim_slice(&[3, 4, 5]), None);
        assert_eq!(Ix0::from_dim_slice(&[]), Some(Ix0));
        assert_eq!(Ix0::from_dim_slice(&[1]), None);
        assert_eq!(
            IxDyn::from_dim_slice(&[1, 2, 3]),
            Some(IxDyn::new(&[1, 2, 3]))
        );
    }

    #[test]
    fn as_slice_mut_updates_shape() {
        let mut d = Ix3::new([1, 2, 3]);
        d.as_slice_mut()[1] = 5;
        assert_eq!(d.as_slice(), &[1, 5, 3]);
        assert_eq!(d.size(), 15);
    }

    #[test]
    fn zero_length_axis_has_zero_size() {
        assert_eq!(Ix2::new([0, 7]).size(), 0);
        assert_eq!(IxDyn::new(&[3, 0, 2]).size(), 0);
    }

    #[test]
    fn smaller_larger_mapping() {
        assert!(same_type::<<Ix0 as Dimension>::Smaller, Ix0>());
        assert!(same_type::<<Ix0 as Dimension>::Larger, Ix1>());
        assert!(same_type::<<Ix1 as Dimension>::Smaller, Ix0>());
        assert!(same_type::<<Ix2 as Dimension>::Smaller, Ix1>());
        assert!(same_type::<<Ix2 as Dimension>::Larger, Ix3>());
        assert!(same_type::<<Ix6 as Dimension>::Smaller, Ix5>());
        assert!(same_type::<<Ix6 as Dimension>::Larger, IxDyn>());
        assert!(same_type::<<IxDyn as Dimension>::Smaller, IxDyn>());
        assert!(same_type::<<IxDyn as Dimension>::Larger, IxDyn>());
    }

    // `to_ndarray_dim` / `from_ndarray_dim` only exist with the `std`
    // feature. Without these gates the test module fails to compile when
    // `std` is disabled.
    #[cfg(feature = "std")]
    #[test]
    fn roundtrip_ix2_ndarray() {
        let d = Ix2::new([3, 7]);
        let nd = d.to_ndarray_dim();
        let d2 = Ix2::from_ndarray_dim(&nd);
        assert_eq!(d, d2);
    }

    #[cfg(feature = "std")]
    #[test]
    fn roundtrip_ixdyn_ndarray() {
        let d = IxDyn::new(&[2, 5, 3]);
        let nd = d.to_ndarray_dim();
        let d2 = IxDyn::from_ndarray_dim(&nd);
        assert_eq!(d, d2);
    }

    #[cfg(feature = "std")]
    #[test]
    fn roundtrip_ix0_ndarray() {
        let nd = Ix0.to_ndarray_dim();
        assert_eq!(Ix0::from_ndarray_dim(&nd), Ix0);
    }

    #[test]
    fn axis_index() {
        let a = Axis(2);
        assert_eq!(a.index(), 2);
    }
}