ferray_core/dimension/
mod.rs1#[cfg(not(feature = "no_std"))]
7pub mod broadcast;
8#[cfg(feature = "const_shapes")]
9pub mod static_shape;
10
11use core::fmt;
12
13#[cfg(feature = "no_std")]
14extern crate alloc;
15#[cfg(feature = "no_std")]
16use alloc::vec::Vec;
17
18#[cfg(not(feature = "no_std"))]
20use ndarray::Dimension as NdDimension;
21
22pub trait Dimension: Clone + PartialEq + Eq + fmt::Debug + Send + Sync + 'static {
27 const NDIM: Option<usize>;
29
30 #[doc(hidden)]
32 #[cfg(not(feature = "no_std"))]
33 type NdarrayDim: ndarray::Dimension;
34
35 fn as_slice(&self) -> &[usize];
37
38 fn as_slice_mut(&mut self) -> &mut [usize];
40
41 fn ndim(&self) -> usize {
43 self.as_slice().len()
44 }
45
46 fn size(&self) -> usize {
48 self.as_slice().iter().product()
49 }
50
51 #[doc(hidden)]
53 #[cfg(not(feature = "no_std"))]
54 fn to_ndarray_dim(&self) -> Self::NdarrayDim;
55
56 #[doc(hidden)]
58 #[cfg(not(feature = "no_std"))]
59 fn from_ndarray_dim(dim: &Self::NdarrayDim) -> Self;
60}
61
62macro_rules! impl_fixed_dimension {
67 ($name:ident, $n:expr, $ndarray_ty:ty) => {
68 #[doc = concat!(stringify!($n), " axes.")]
70 #[derive(Clone, PartialEq, Eq, Hash)]
71 pub struct $name {
72 shape: [usize; $n],
73 }
74
75 impl $name {
76 #[inline]
78 pub fn new(shape: [usize; $n]) -> Self {
79 Self { shape }
80 }
81 }
82
83 impl fmt::Debug for $name {
84 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85 write!(f, "{:?}", &self.shape[..])
86 }
87 }
88
89 impl From<[usize; $n]> for $name {
90 #[inline]
91 fn from(shape: [usize; $n]) -> Self {
92 Self::new(shape)
93 }
94 }
95
96 impl Dimension for $name {
97 const NDIM: Option<usize> = Some($n);
98
99 #[cfg(not(feature = "no_std"))]
100 type NdarrayDim = $ndarray_ty;
101
102 #[inline]
103 fn as_slice(&self) -> &[usize] {
104 &self.shape
105 }
106
107 #[inline]
108 fn as_slice_mut(&mut self) -> &mut [usize] {
109 &mut self.shape
110 }
111
112 #[cfg(not(feature = "no_std"))]
113 fn to_ndarray_dim(&self) -> Self::NdarrayDim {
114 ndarray::Dim(self.shape)
116 }
117
118 #[cfg(not(feature = "no_std"))]
119 fn from_ndarray_dim(dim: &Self::NdarrayDim) -> Self {
120 let view = dim.as_array_view();
121 let s = view.as_slice().expect("ndarray dim should be contiguous");
122 let mut shape = [0usize; $n];
123 shape.copy_from_slice(s);
124 Self { shape }
125 }
126 }
127 };
128}
129
130impl_fixed_dimension!(Ix1, 1, ndarray::Ix1);
131impl_fixed_dimension!(Ix2, 2, ndarray::Ix2);
132impl_fixed_dimension!(Ix3, 3, ndarray::Ix3);
133impl_fixed_dimension!(Ix4, 4, ndarray::Ix4);
134impl_fixed_dimension!(Ix5, 5, ndarray::Ix5);
135impl_fixed_dimension!(Ix6, 6, ndarray::Ix6);
136
/// The rank-0 (scalar) shape: no axes and exactly one element.
#[derive(Hash, Eq, PartialEq, Clone)]
pub struct Ix0;
144
145impl fmt::Debug for Ix0 {
146 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
147 write!(f, "[]")
148 }
149}
150
151impl Dimension for Ix0 {
152 const NDIM: Option<usize> = Some(0);
153
154 #[cfg(not(feature = "no_std"))]
155 type NdarrayDim = ndarray::Ix0;
156
157 #[inline]
158 fn as_slice(&self) -> &[usize] {
159 &[]
160 }
161
162 #[inline]
163 fn as_slice_mut(&mut self) -> &mut [usize] {
164 &mut []
165 }
166
167 #[cfg(not(feature = "no_std"))]
168 fn to_ndarray_dim(&self) -> Self::NdarrayDim {
169 ndarray::Dim(())
170 }
171
172 #[cfg(not(feature = "no_std"))]
173 fn from_ndarray_dim(_dim: &Self::NdarrayDim) -> Self {
174 Ix0
175 }
176}
177
/// A shape whose rank is chosen at runtime; the axis lengths live in a `Vec`.
#[derive(Hash, Eq, PartialEq, Clone)]
pub struct IxDyn {
    shape: Vec<usize>,
}
187
188impl IxDyn {
189 pub fn new(shape: &[usize]) -> Self {
191 Self {
192 shape: shape.to_vec(),
193 }
194 }
195}
196
197impl fmt::Debug for IxDyn {
198 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199 write!(f, "{:?}", &self.shape[..])
200 }
201}
202
203impl From<Vec<usize>> for IxDyn {
204 fn from(shape: Vec<usize>) -> Self {
205 Self { shape }
206 }
207}
208
209impl From<&[usize]> for IxDyn {
210 fn from(shape: &[usize]) -> Self {
211 Self::new(shape)
212 }
213}
214
215impl Dimension for IxDyn {
216 const NDIM: Option<usize> = None;
217
218 #[cfg(not(feature = "no_std"))]
219 type NdarrayDim = ndarray::IxDyn;
220
221 #[inline]
222 fn as_slice(&self) -> &[usize] {
223 &self.shape
224 }
225
226 #[inline]
227 fn as_slice_mut(&mut self) -> &mut [usize] {
228 &mut self.shape
229 }
230
231 #[cfg(not(feature = "no_std"))]
232 fn to_ndarray_dim(&self) -> Self::NdarrayDim {
233 ndarray::IxDyn(&self.shape)
234 }
235
236 #[cfg(not(feature = "no_std"))]
237 fn from_ndarray_dim(dim: &Self::NdarrayDim) -> Self {
238 let view = dim.as_array_view();
239 let s = view.as_slice().expect("ndarray IxDyn should be contiguous");
240 Self { shape: s.to_vec() }
241 }
242}
243
/// A newtype around an axis number: `Axis(n)` wraps `n`.
///
/// Ordered by the wrapped number, so axis lists can be sorted or compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Axis(pub usize);

impl Axis {
    /// Returns the wrapped axis number.
    ///
    /// `const` so it can be used in constant expressions.
    #[inline]
    #[must_use]
    pub const fn index(self) -> usize {
        self.0
    }
}
255
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ix1_basics() {
        let d = Ix1::new([5]);
        assert_eq!(d.ndim(), 1);
        assert_eq!(d.size(), 5);
        assert_eq!(d.as_slice(), &[5]);
    }

    #[test]
    fn ix2_basics() {
        let d = Ix2::new([3, 4]);
        assert_eq!(d.ndim(), 2);
        assert_eq!(d.size(), 12);
    }

    #[test]
    fn ix0_basics() {
        // A rank-0 shape has no axes but one element (the empty product).
        let d = Ix0;
        assert_eq!(d.ndim(), 0);
        assert_eq!(d.size(), 1);
    }

    #[test]
    fn ixdyn_basics() {
        let d = IxDyn::new(&[2, 3, 4]);
        assert_eq!(d.ndim(), 3);
        assert_eq!(d.size(), 24);
    }

    #[test]
    fn empty_ixdyn_has_one_element() {
        let d = IxDyn::new(&[]);
        assert_eq!(d.ndim(), 0);
        assert_eq!(d.size(), 1);
    }

    #[test]
    fn zero_length_axis_gives_zero_size() {
        assert_eq!(Ix3::new([4, 0, 7]).size(), 0);
        assert_eq!(IxDyn::new(&[4, 0, 7]).size(), 0);
    }

    #[test]
    fn ndim_constants() {
        assert_eq!(Ix0::NDIM, Some(0));
        assert_eq!(Ix3::NDIM, Some(3));
        assert_eq!(Ix6::NDIM, Some(6));
        assert_eq!(IxDyn::NDIM, None);
    }

    #[test]
    fn conversions_match_constructors() {
        assert_eq!(Ix3::from([1, 2, 3]), Ix3::new([1, 2, 3]));
        let owned: Vec<usize> = [2, 3].to_vec();
        assert_eq!(IxDyn::from(owned), IxDyn::new(&[2, 3]));
        let borrowed: &[usize] = &[4, 5];
        assert_eq!(IxDyn::from(borrowed), IxDyn::new(&[4, 5]));
    }

    #[test]
    fn as_slice_mut_edits_shape() {
        let mut d = Ix2::new([3, 4]);
        d.as_slice_mut()[1] = 5;
        assert_eq!(d.as_slice(), &[3, 5]);
        assert_eq!(d.size(), 15);
    }

    // The ndarray interop methods only exist without `no_std`, so these tests
    // are gated the same way; otherwise `--features no_std` test builds fail.
    #[cfg(not(feature = "no_std"))]
    #[test]
    fn roundtrip_ix2_ndarray() {
        let d = Ix2::new([3, 7]);
        let nd = d.to_ndarray_dim();
        let d2 = Ix2::from_ndarray_dim(&nd);
        assert_eq!(d, d2);
    }

    #[cfg(not(feature = "no_std"))]
    #[test]
    fn roundtrip_ix0_and_ix6_ndarray() {
        assert_eq!(Ix0::from_ndarray_dim(&Ix0.to_ndarray_dim()), Ix0);
        let d = Ix6::new([1, 2, 3, 4, 5, 6]);
        assert_eq!(Ix6::from_ndarray_dim(&d.to_ndarray_dim()), d);
    }

    #[cfg(not(feature = "no_std"))]
    #[test]
    fn roundtrip_ixdyn_ndarray() {
        let d = IxDyn::new(&[2, 5, 3]);
        let nd = d.to_ndarray_dim();
        let d2 = IxDyn::from_ndarray_dim(&nd);
        assert_eq!(d, d2);
    }

    // Presumably `format!` is not in scope under `no_std` (the module imports
    // `Vec` from `alloc` there), so this test is std-only as well.
    #[cfg(not(feature = "no_std"))]
    #[test]
    fn debug_formats_as_list() {
        assert_eq!(format!("{:?}", Ix2::new([3, 4])), "[3, 4]");
        assert_eq!(format!("{:?}", Ix0), "[]");
        assert_eq!(format!("{:?}", IxDyn::new(&[2, 5])), "[2, 5]");
    }

    #[test]
    fn axis_index() {
        let a = Axis(2);
        assert_eq!(a.index(), 2);
    }
}