//! Array shape types.
//!
//! Defines the [`Dimension`] trait, the fixed-rank shapes [`Ix0`] through
//! [`Ix6`], the dynamic-rank [`IxDyn`], and the [`Axis`] index newtype.

// Source path: ferray_core/dimension/mod.rs

// Only compiled when the `std` feature is enabled.
#[cfg(feature = "std")]
pub mod broadcast;

// Requires both the `const_shapes` and `std` features.
#[cfg(all(feature = "const_shapes", feature = "std"))]
pub mod static_shape;

use core::fmt;

#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;

#[cfg(feature = "std")]
use ndarray::Dimension as NdDimension;

/// Shape of an array: one `usize` extent per axis.
///
/// Fixed-rank implementors report their axis count through
/// [`Dimension::NDIM`]; dynamic-rank implementors report `None` and carry the
/// rank at run time.
pub trait Dimension: Clone + PartialEq + Eq + fmt::Debug + Send + Sync + 'static {
    /// Number of axes if fixed at compile time, `None` for dynamic rank.
    const NDIM: Option<usize>;

    /// Equivalent `ndarray` dimension type, used to bridge to `ndarray`.
    #[doc(hidden)]
    #[cfg(feature = "std")]
    type NdarrayDim: ndarray::Dimension;

    /// Extents of every axis, in axis order.
    fn as_slice(&self) -> &[usize];

    /// Mutable access to the axis extents (the number of axes cannot change
    /// through a slice).
    fn as_slice_mut(&mut self) -> &mut [usize];

    /// Number of axes (the length of [`Dimension::as_slice`]).
    fn ndim(&self) -> usize {
        self.as_slice().len()
    }

    /// Total number of elements: the product of all axis extents.
    ///
    /// A shape with zero axes has size 1 (the empty product).
    ///
    /// # Panics
    ///
    /// Computed with [`Iterator::product`], so an overflowing product panics
    /// when overflow checks are enabled (e.g. debug builds) and silently wraps
    /// otherwise. Use [`Dimension::size_checked`] when the extents may be
    /// large or come from untrusted input.
    fn size(&self) -> usize {
        self.as_slice().iter().product()
    }

    /// Overflow-checked total number of elements.
    ///
    /// Returns `Some(0)` whenever any axis has extent 0 (the true element
    /// count, even if the product of the remaining extents would overflow),
    /// `None` if the product does not fit in `usize`, and `Some(product)`
    /// otherwise. A shape with zero axes gives `Some(1)`.
    fn size_checked(&self) -> Option<usize> {
        let extents = self.as_slice();
        // An empty axis makes the array empty regardless of the other extents.
        if extents.contains(&0) {
            return Some(0);
        }
        extents
            .iter()
            .try_fold(1usize, |acc, &len| acc.checked_mul(len))
    }

    /// Converts this shape into the equivalent `ndarray` dimension.
    #[doc(hidden)]
    #[cfg(feature = "std")]
    fn to_ndarray_dim(&self) -> Self::NdarrayDim;

    /// Builds this shape from the equivalent `ndarray` dimension.
    #[doc(hidden)]
    #[cfg(feature = "std")]
    fn from_ndarray_dim(dim: &Self::NdarrayDim) -> Self;

    /// Builds a shape from a slice of axis extents.
    ///
    /// Returns `None` if `shape` is not a valid shape for `Self`. The
    /// fixed-rank shapes in this module reject slices whose length differs
    /// from their rank; `IxDyn` accepts any length.
    fn from_dim_slice(shape: &[usize]) -> Option<Self>;
}

/// Defines a fixed-rank shape type `$name` with `$n` axes, stored as
/// `[usize; $n]`, and implements [`Dimension`] for it.
///
/// `$ndarray_ty` is the matching `ndarray` dimension type; it is only
/// referenced when the `std` feature is enabled.
macro_rules! impl_fixed_dimension {
    ($name:ident, $n:expr, $ndarray_ty:ty) => {
        #[doc = concat!("Fixed-rank shape with exactly ", stringify!($n), " axes.")]
        #[derive(Clone, PartialEq, Eq, Hash)]
        pub struct $name {
            shape: [usize; $n],
        }

        impl $name {
            /// Creates the shape from its per-axis extents.
            #[inline]
            #[must_use]
            pub const fn new(shape: [usize; $n]) -> Self {
                Self { shape }
            }
        }

        impl fmt::Debug for $name {
            // Delegate to the slice's `Debug` so formatter flags such as
            // `{:#?}` (pretty) and `{:x?}` (hex) are honoured; re-formatting
            // through `write!(f, "{:?}", ..)` would discard them.
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                fmt::Debug::fmt(&self.shape[..], f)
            }
        }

        impl From<[usize; $n]> for $name {
            #[inline]
            fn from(shape: [usize; $n]) -> Self {
                Self::new(shape)
            }
        }

        impl Dimension for $name {
            const NDIM: Option<usize> = Some($n);

            #[cfg(feature = "std")]
            type NdarrayDim = $ndarray_ty;

            #[inline]
            fn as_slice(&self) -> &[usize] {
                &self.shape
            }

            #[inline]
            fn as_slice_mut(&mut self) -> &mut [usize] {
                &mut self.shape
            }

            #[cfg(feature = "std")]
            fn to_ndarray_dim(&self) -> Self::NdarrayDim {
                ndarray::Dim(self.shape)
            }

            #[cfg(feature = "std")]
            fn from_ndarray_dim(dim: &Self::NdarrayDim) -> Self {
                // Expects `$ndarray_ty` to have exactly `$n` axes (true for
                // every invocation paired with the same-rank ndarray type).
                debug_assert_eq!(dim.ndim(), $n);
                // Copy element-wise from the 1-D view: no contiguity
                // `expect` is needed.
                let view = dim.as_array_view();
                let mut shape = [0usize; $n];
                for (dst, &src) in shape.iter_mut().zip(view.iter()) {
                    *dst = src;
                }
                Self::new(shape)
            }

            fn from_dim_slice(shape: &[usize]) -> Option<Self> {
                // `copy_from_slice` panics on a length mismatch, so reject
                // slices of the wrong rank first.
                if shape.len() != $n {
                    return None;
                }
                let mut extents = [0usize; $n];
                extents.copy_from_slice(shape);
                Some(Self::new(extents))
            }
        }
    };
}

149impl_fixed_dimension!(Ix1, 1, ndarray::Ix1);
150impl_fixed_dimension!(Ix2, 2, ndarray::Ix2);
151impl_fixed_dimension!(Ix3, 3, ndarray::Ix3);
152impl_fixed_dimension!(Ix4, 4, ndarray::Ix4);
153impl_fixed_dimension!(Ix5, 5, ndarray::Ix5);
154impl_fixed_dimension!(Ix6, 6, ndarray::Ix6);
155
/// Shape with zero axes: a 0-dimensional shape holding a single element.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Ix0;

impl fmt::Debug for Ix0 {
    // No extents to list, so the shape always renders as an empty list.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("[]")
    }
}

170impl Dimension for Ix0 {
171 const NDIM: Option<usize> = Some(0);
172
173 #[cfg(feature = "std")]
174 type NdarrayDim = ndarray::Ix0;
175
176 #[inline]
177 fn as_slice(&self) -> &[usize] {
178 &[]
179 }
180
181 #[inline]
182 fn as_slice_mut(&mut self) -> &mut [usize] {
183 &mut []
184 }
185
186 #[cfg(feature = "std")]
187 fn to_ndarray_dim(&self) -> Self::NdarrayDim {
188 ndarray::Dim(())
189 }
190
191 #[cfg(feature = "std")]
192 fn from_ndarray_dim(_dim: &Self::NdarrayDim) -> Self {
193 Self
194 }
195
196 fn from_dim_slice(shape: &[usize]) -> Option<Self> {
197 if shape.is_empty() { Some(Self) } else { None }
198 }
199}
200
/// Shape whose number of axes is chosen at run time.
///
/// The extents are stored in a `Vec<usize>`, one entry per axis.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct IxDyn {
    shape: Vec<usize>,
}

impl IxDyn {
    /// Creates a dynamic-rank shape by copying the extents in `shape`.
    #[must_use]
    pub fn new(shape: &[usize]) -> Self {
        Self {
            shape: shape.to_vec(),
        }
    }
}

impl fmt::Debug for IxDyn {
    // Delegate to the slice's `Debug` so formatter flags such as `{:#?}`
    // (pretty) and `{:x?}` (hex) are honoured; re-formatting through
    // `write!(f, "{:?}", ..)` would discard them.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(self.shape.as_slice(), f)
    }
}

impl From<Vec<usize>> for IxDyn {
    /// Takes ownership of `shape` without copying it.
    fn from(shape: Vec<usize>) -> Self {
        Self { shape }
    }
}

impl From<&[usize]> for IxDyn {
    /// Copies the extents in `shape`.
    fn from(shape: &[usize]) -> Self {
        Self::new(shape)
    }
}

239impl Dimension for IxDyn {
240 const NDIM: Option<usize> = None;
241
242 #[cfg(feature = "std")]
243 type NdarrayDim = ndarray::IxDyn;
244
245 #[inline]
246 fn as_slice(&self) -> &[usize] {
247 &self.shape
248 }
249
250 #[inline]
251 fn as_slice_mut(&mut self) -> &mut [usize] {
252 &mut self.shape
253 }
254
255 #[cfg(feature = "std")]
256 fn to_ndarray_dim(&self) -> Self::NdarrayDim {
257 ndarray::IxDyn(&self.shape)
258 }
259
260 #[cfg(feature = "std")]
261 fn from_ndarray_dim(dim: &Self::NdarrayDim) -> Self {
262 let view = dim.as_array_view();
263 let s = view.as_slice().expect("ndarray IxDyn should be contiguous");
264 Self { shape: s.to_vec() }
265 }
266
267 fn from_dim_slice(shape: &[usize]) -> Option<Self> {
268 Some(Self {
269 shape: shape.to_vec(),
270 })
271 }
272}
273
/// Newtype wrapping the index of an array axis.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Axis(pub usize);

impl Axis {
    /// Returns the wrapped axis index.
    #[inline]
    #[must_use]
    pub const fn index(self) -> usize {
        let Axis(position) = self;
        position
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ix1_basics() {
        let d = Ix1::new([5]);
        assert_eq!(d.ndim(), 1);
        assert_eq!(d.size(), 5);
        assert_eq!(d.as_slice(), &[5]);
    }

    #[test]
    fn ix2_basics() {
        let d = Ix2::new([3, 4]);
        assert_eq!(d.ndim(), 2);
        assert_eq!(d.size(), 12);
    }

    #[test]
    fn ix0_basics() {
        let d = Ix0;
        assert_eq!(d.ndim(), 0);
        // Empty product: a rank-0 shape holds exactly one element.
        assert_eq!(d.size(), 1);
    }

    #[test]
    fn ixdyn_basics() {
        let d = IxDyn::new(&[2, 3, 4]);
        assert_eq!(d.ndim(), 3);
        assert_eq!(d.size(), 24);
    }

    #[test]
    fn ndim_constants() {
        assert_eq!(Ix0::NDIM, Some(0));
        assert_eq!(Ix1::NDIM, Some(1));
        assert_eq!(Ix6::NDIM, Some(6));
        assert_eq!(IxDyn::NDIM, None);
    }

    #[test]
    fn from_dim_slice_enforces_rank() {
        assert_eq!(Ix2::from_dim_slice(&[3, 4]), Some(Ix2::new([3, 4])));
        assert_eq!(Ix2::from_dim_slice(&[3]), None);
        assert_eq!(Ix2::from_dim_slice(&[3, 4, 5]), None);
        assert_eq!(Ix0::from_dim_slice(&[]), Some(Ix0));
        assert_eq!(Ix0::from_dim_slice(&[1]), None);
        assert_eq!(IxDyn::from_dim_slice(&[]), Some(IxDyn::new(&[])));
        assert_eq!(IxDyn::from_dim_slice(&[7, 8]), Some(IxDyn::new(&[7, 8])));
    }

    #[test]
    fn as_slice_mut_edits_extents() {
        let mut d = Ix3::new([1, 2, 3]);
        d.as_slice_mut()[2] = 9;
        assert_eq!(d, Ix3::new([1, 2, 9]));
    }

    // The ndarray bridge (`to_ndarray_dim` / `from_ndarray_dim`) only exists
    // with the `std` feature, so these tests must be gated the same way.
    #[cfg(feature = "std")]
    #[test]
    fn roundtrip_ix2_ndarray() {
        let d = Ix2::new([3, 7]);
        let nd = d.to_ndarray_dim();
        let d2 = Ix2::from_ndarray_dim(&nd);
        assert_eq!(d, d2);
    }

    #[cfg(feature = "std")]
    #[test]
    fn roundtrip_ixdyn_ndarray() {
        let d = IxDyn::new(&[2, 5, 3]);
        let nd = d.to_ndarray_dim();
        let d2 = IxDyn::from_ndarray_dim(&nd);
        assert_eq!(d, d2);
    }

    #[test]
    fn axis_index() {
        let a = Axis(2);
        assert_eq!(a.index(), 2);
    }
}