ferray_core/dimension/
mod.rs1#[cfg(feature = "std")]
7pub mod broadcast;
8#[cfg(all(feature = "const_shapes", feature = "std"))]
11pub mod static_shape;
12
13use core::fmt;
14
15#[cfg(not(feature = "std"))]
16extern crate alloc;
17#[cfg(not(feature = "std"))]
18use alloc::vec::Vec;
19
20#[cfg(feature = "std")]
22use ndarray::Dimension as NdDimension;
23
/// Abstraction over the shape of an N-dimensional array.
///
/// An implementor stores one extent per axis. The number of axes is either fixed at
/// compile time (`NDIM == Some(n)`) or only known at run time (`NDIM == None`).
pub trait Dimension: Clone + PartialEq + Eq + fmt::Debug + Send + Sync + 'static {
    /// Number of axes when it is known at compile time; `None` for dynamic ranks.
    const NDIM: Option<usize>;

    /// The `ndarray` dimension type used for interop (internal detail).
    #[doc(hidden)]
    #[cfg(feature = "std")]
    type NdarrayDim: ndarray::Dimension;

    /// Dimension type with one axis fewer (`Ix0` maps to itself, `IxDyn` to `IxDyn`).
    type Smaller: Dimension;

    /// Dimension type with one axis more (`Ix6` and `IxDyn` map to `IxDyn`).
    type Larger: Dimension;

    /// Borrows the per-axis extents.
    fn as_slice(&self) -> &[usize];

    /// Mutably borrows the per-axis extents.
    fn as_slice_mut(&mut self) -> &mut [usize];

    /// Number of axes, i.e. the length of [`as_slice`](Dimension::as_slice).
    fn ndim(&self) -> usize {
        let extents = self.as_slice();
        extents.len()
    }

    /// Total element count: the product of all extents.
    ///
    /// A shape with no axes yields `1` (the empty product).
    fn size(&self) -> usize {
        self.as_slice().iter().fold(1, |acc, &extent| acc * extent)
    }

    /// Converts into the matching `ndarray` dimension (internal detail).
    #[doc(hidden)]
    #[cfg(feature = "std")]
    fn to_ndarray_dim(&self) -> Self::NdarrayDim;

    /// Builds this dimension from the matching `ndarray` dimension (internal detail).
    #[doc(hidden)]
    #[cfg(feature = "std")]
    fn from_ndarray_dim(dim: &Self::NdarrayDim) -> Self;

    /// Builds a dimension from a slice of extents.
    ///
    /// Returns `None` when the slice length is not valid for this dimension type
    /// (for example, a length other than `N` for a fixed rank `N`).
    fn from_dim_slice(shape: &[usize]) -> Option<Self>;
}
93
/// Generates a fixed-rank dimension type backed by an inline `[usize; $n]`.
///
/// Arguments, in order:
/// - `$name`: the type to define.
/// - `$n`: the number of axes.
/// - `$ndarray_ty`: the matching `ndarray` dimension type.
/// - `$smaller` / `$larger`: the `Smaller` / `Larger` associated types.
macro_rules! impl_fixed_dimension {
    ($name:ident, $n:expr, $ndarray_ty:ty, $smaller:ty, $larger:ty) => {
        #[doc = concat!(
            "Static ", stringify!($n), "-dimensional shape, stored inline as `[usize; ",
            stringify!($n), "]`."
        )]
        #[derive(Clone, PartialEq, Eq, Hash)]
        pub struct $name {
            shape: [usize; $n],
        }

        impl $name {
            /// Creates the dimension from its per-axis extents.
            #[inline]
            #[must_use]
            pub const fn new(shape: [usize; $n]) -> Self {
                Self { shape }
            }
        }

        impl fmt::Debug for $name {
            /// Formats the shape as a list of extents, e.g. `[3, 4]`.
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                // Delegate to the slice impl (instead of re-entering `write!`) so that
                // caller-supplied flags such as the alternate `{:#?}` form are honoured.
                fmt::Debug::fmt(&self.shape[..], f)
            }
        }

        impl From<[usize; $n]> for $name {
            #[inline]
            fn from(shape: [usize; $n]) -> Self {
                Self::new(shape)
            }
        }

        impl Dimension for $name {
            const NDIM: Option<usize> = Some($n);

            #[cfg(feature = "std")]
            type NdarrayDim = $ndarray_ty;

            type Smaller = $smaller;
            type Larger = $larger;

            #[inline]
            fn as_slice(&self) -> &[usize] {
                &self.shape
            }

            #[inline]
            fn as_slice_mut(&mut self) -> &mut [usize] {
                &mut self.shape
            }

            #[cfg(feature = "std")]
            fn to_ndarray_dim(&self) -> Self::NdarrayDim {
                ndarray::Dim(self.shape)
            }

            #[cfg(feature = "std")]
            fn from_ndarray_dim(dim: &Self::NdarrayDim) -> Self {
                let view = dim.as_array_view();
                let s = view.as_slice().expect("ndarray dim should be contiguous");
                let mut shape = [0usize; $n];
                // The ndarray type has the same fixed rank, so the lengths match.
                shape.copy_from_slice(s);
                Self { shape }
            }

            /// Returns `None` unless `shape` has exactly `$n` extents.
            fn from_dim_slice(shape: &[usize]) -> Option<Self> {
                if shape.len() != $n {
                    return None;
                }
                let mut arr = [0usize; $n];
                arr.copy_from_slice(shape);
                Some(Self::new(arr))
            }
        }
    };
}
173
// Fixed-rank dimensions 1 through 6. Each type's `Smaller` is the next lower rank
// (bottoming out at `Ix0`) and its `Larger` is the next higher rank; `Ix6` grows
// into the dynamic-rank `IxDyn`.
impl_fixed_dimension!(Ix1, 1, ndarray::Ix1, Ix0, Ix2);
impl_fixed_dimension!(Ix2, 2, ndarray::Ix2, Ix1, Ix3);
impl_fixed_dimension!(Ix3, 3, ndarray::Ix3, Ix2, Ix4);
impl_fixed_dimension!(Ix4, 4, ndarray::Ix4, Ix3, Ix5);
impl_fixed_dimension!(Ix5, 5, ndarray::Ix5, Ix4, Ix6);
impl_fixed_dimension!(Ix6, 6, ndarray::Ix6, Ix5, IxDyn);
185
/// Zero-dimensional (scalar) shape: it has no axes and holds a single element.
#[derive(Hash, Eq, PartialEq, Clone)]
pub struct Ix0;
193
194impl fmt::Debug for Ix0 {
195 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196 write!(f, "[]")
197 }
198}
199
200impl Dimension for Ix0 {
201 const NDIM: Option<usize> = Some(0);
202
203 #[cfg(feature = "std")]
204 type NdarrayDim = ndarray::Ix0;
205
206 type Smaller = Ix0;
209 type Larger = Ix1;
210
211 #[inline]
212 fn as_slice(&self) -> &[usize] {
213 &[]
214 }
215
216 #[inline]
217 fn as_slice_mut(&mut self) -> &mut [usize] {
218 &mut []
219 }
220
221 #[cfg(feature = "std")]
222 fn to_ndarray_dim(&self) -> Self::NdarrayDim {
223 ndarray::Dim(())
224 }
225
226 #[cfg(feature = "std")]
227 fn from_ndarray_dim(_dim: &Self::NdarrayDim) -> Self {
228 Self
229 }
230
231 fn from_dim_slice(shape: &[usize]) -> Option<Self> {
232 if shape.is_empty() { Some(Self) } else { None }
233 }
234}
235
/// Shape whose number of axes is only known at run time; extents live in a `Vec`.
#[derive(Hash, Eq, PartialEq, Clone)]
pub struct IxDyn {
    shape: Vec<usize>,
}

impl IxDyn {
    /// Copies `shape` into a new dynamic-rank dimension.
    #[must_use]
    pub fn new(shape: &[usize]) -> Self {
        let owned = Vec::from(shape);
        Self { shape: owned }
    }
}
255
256impl fmt::Debug for IxDyn {
257 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
258 write!(f, "{:?}", &self.shape[..])
259 }
260}
261
262impl From<Vec<usize>> for IxDyn {
263 fn from(shape: Vec<usize>) -> Self {
264 Self { shape }
265 }
266}
267
268impl From<&[usize]> for IxDyn {
269 fn from(shape: &[usize]) -> Self {
270 Self::new(shape)
271 }
272}
273
274impl Dimension for IxDyn {
275 const NDIM: Option<usize> = None;
276
277 #[cfg(feature = "std")]
278 type NdarrayDim = ndarray::IxDyn;
279
280 type Smaller = IxDyn;
282 type Larger = IxDyn;
283
284 #[inline]
285 fn as_slice(&self) -> &[usize] {
286 &self.shape
287 }
288
289 #[inline]
290 fn as_slice_mut(&mut self) -> &mut [usize] {
291 &mut self.shape
292 }
293
294 #[cfg(feature = "std")]
295 fn to_ndarray_dim(&self) -> Self::NdarrayDim {
296 ndarray::IxDyn(&self.shape)
297 }
298
299 #[cfg(feature = "std")]
300 fn from_ndarray_dim(dim: &Self::NdarrayDim) -> Self {
301 let view = dim.as_array_view();
302 let s = view.as_slice().expect("ndarray IxDyn should be contiguous");
303 Self { shape: s.to_vec() }
304 }
305
306 fn from_dim_slice(shape: &[usize]) -> Option<Self> {
307 Some(Self {
308 shape: shape.to_vec(),
309 })
310 }
311}
312
/// Newtype identifying an array axis by its number.
#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
pub struct Axis(pub usize);

impl Axis {
    /// Returns the wrapped axis number.
    #[must_use]
    #[inline]
    pub const fn index(self) -> usize {
        let Axis(position) = self;
        position
    }
}
325
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ix1_basics() {
        let d = Ix1::new([5]);
        assert_eq!(d.ndim(), 1);
        assert_eq!(d.size(), 5);
        assert_eq!(d.as_slice(), &[5]);
    }

    #[test]
    fn ix2_basics() {
        let d = Ix2::new([3, 4]);
        assert_eq!(d.ndim(), 2);
        assert_eq!(d.size(), 12);
    }

    #[test]
    fn ix0_basics() {
        let d = Ix0;
        assert_eq!(d.ndim(), 0);
        assert_eq!(d.size(), 1);
    }

    #[test]
    fn ixdyn_basics() {
        let d = IxDyn::new(&[2, 3, 4]);
        assert_eq!(d.ndim(), 3);
        assert_eq!(d.size(), 24);
    }

    // The ndarray interop methods only exist with the `std` feature, so these
    // tests must be gated too or `no_std` test builds fail to compile.
    #[test]
    #[cfg(feature = "std")]
    fn roundtrip_ix2_ndarray() {
        let d = Ix2::new([3, 7]);
        let nd = d.to_ndarray_dim();
        let d2 = Ix2::from_ndarray_dim(&nd);
        assert_eq!(d, d2);
    }

    #[test]
    #[cfg(feature = "std")]
    fn roundtrip_ixdyn_ndarray() {
        let d = IxDyn::new(&[2, 5, 3]);
        let nd = d.to_ndarray_dim();
        let d2 = IxDyn::from_ndarray_dim(&nd);
        assert_eq!(d, d2);
    }

    #[test]
    fn axis_index() {
        let a = Axis(2);
        assert_eq!(a.index(), 2);
    }

    #[test]
    fn from_dim_slice_checks_rank() {
        assert_eq!(Ix2::from_dim_slice(&[3, 4]), Some(Ix2::new([3, 4])));
        assert_eq!(Ix2::from_dim_slice(&[3]), None);
        assert_eq!(Ix0::from_dim_slice(&[]), Some(Ix0));
        assert_eq!(Ix0::from_dim_slice(&[1]), None);
        assert_eq!(IxDyn::from_dim_slice(&[1, 2]), Some(IxDyn::new(&[1, 2])));
    }

    #[test]
    fn ndim_constants() {
        assert_eq!(Ix0::NDIM, Some(0));
        assert_eq!(Ix6::NDIM, Some(6));
        assert_eq!(IxDyn::NDIM, None);
    }

    // `format!` / `vec!` are only in the macro prelude with `std`.
    #[test]
    #[cfg(feature = "std")]
    fn debug_renders_extent_list() {
        assert_eq!(format!("{:?}", Ix2::new([3, 4])), "[3, 4]");
        assert_eq!(format!("{:?}", Ix0), "[]");
        assert_eq!(format!("{:?}", IxDyn::new(&[2, 5])), "[2, 5]");
    }

    #[test]
    #[cfg(feature = "std")]
    fn ixdyn_conversions() {
        assert_eq!(IxDyn::from(vec![1, 2, 3]), IxDyn::new(&[1, 2, 3]));
        let slice: &[usize] = &[4, 5];
        assert_eq!(IxDyn::from(slice), IxDyn::new(&[4, 5]));
    }
}