1use ndarray::ShapeBuilder;
4
5use crate::dimension::Dimension;
6use crate::dtype::Element;
7use crate::error::{FerrayError, FerrayResult};
8use crate::layout::MemoryLayout;
9
10pub struct Array<T: Element, D: Dimension> {
22 pub(crate) inner: ndarray::Array<T, D::NdarrayDim>,
24 pub(crate) dim: D,
26}
27
28impl<T: Element, D: Dimension> Array<T, D> {
29 pub(crate) fn from_ndarray(inner: ndarray::Array<T, D::NdarrayDim>) -> Self {
33 let dim = D::from_ndarray_dim(&inner.raw_dim());
34 Self { inner, dim }
35 }
36
37 pub fn into_ndarray(self) -> ndarray::Array<T, D::NdarrayDim> {
46 self.inner
47 }
48
49 pub fn from_elem(dim: D, elem: T) -> FerrayResult<Self> {
57 let nd_dim = dim.to_ndarray_dim();
58 let inner = ndarray::Array::from_elem(nd_dim, elem);
59 Ok(Self { inner, dim })
60 }
61
62 pub fn zeros(dim: D) -> FerrayResult<Self> {
64 Self::from_elem(dim, T::zero())
65 }
66
67 pub fn ones(dim: D) -> FerrayResult<Self> {
69 Self::from_elem(dim, T::one())
70 }
71
72 pub fn from_vec(dim: D, data: Vec<T>) -> FerrayResult<Self> {
78 let expected = dim.size();
79 if data.len() != expected {
80 return Err(FerrayError::shape_mismatch(format!(
81 "data length {} does not match shape {:?} (expected {})",
82 data.len(),
83 dim.as_slice(),
84 expected,
85 )));
86 }
87 let nd_dim = dim.to_ndarray_dim();
88 let inner = ndarray::Array::from_shape_vec(nd_dim, data)
89 .map_err(|e| FerrayError::shape_mismatch(format!("ndarray shape error: {e}")))?;
90 Ok(Self { inner, dim })
91 }
92
93 pub fn from_vec_f(dim: D, data: Vec<T>) -> FerrayResult<Self> {
98 let expected = dim.size();
99 if data.len() != expected {
100 return Err(FerrayError::shape_mismatch(format!(
101 "data length {} does not match shape {:?} (expected {})",
102 data.len(),
103 dim.as_slice(),
104 expected,
105 )));
106 }
107 let nd_dim = dim.to_ndarray_dim();
108 let inner = ndarray::Array::from_shape_vec(nd_dim.f(), data)
109 .map_err(|e| FerrayError::shape_mismatch(format!("ndarray shape error: {e}")))?;
110 let dim = D::from_ndarray_dim(&inner.raw_dim());
111 Ok(Self { inner, dim })
112 }
113
114 pub fn from_iter_1d(iter: impl IntoIterator<Item = T>) -> FerrayResult<Self>
119 where
120 D: Dimension<NdarrayDim = ndarray::Ix1>,
121 {
122 let inner = ndarray::Array::from_iter(iter);
123 let dim = D::from_ndarray_dim(&inner.raw_dim());
124 Ok(Self { inner, dim })
125 }
126
127 pub fn layout(&self) -> MemoryLayout {
129 crate::layout::classify_layout(
130 self.inner.is_standard_layout(),
131 self.dim.as_slice(),
132 self.inner.strides(),
133 )
134 }
135
136 #[inline]
138 pub fn ndim(&self) -> usize {
139 self.dim.ndim()
140 }
141
142 #[inline]
144 pub fn shape(&self) -> &[usize] {
145 self.inner.shape()
146 }
147
148 #[inline]
150 pub fn strides(&self) -> &[isize] {
151 self.inner.strides()
152 }
153
154 #[inline]
156 pub fn size(&self) -> usize {
157 self.inner.len()
158 }
159
160 #[inline]
162 pub fn is_empty(&self) -> bool {
163 self.inner.is_empty()
164 }
165
166 #[inline]
168 pub fn as_ptr(&self) -> *const T {
169 self.inner.as_ptr()
170 }
171
172 #[inline]
174 pub fn as_mut_ptr(&mut self) -> *mut T {
175 self.inner.as_mut_ptr()
176 }
177
178 pub fn as_slice(&self) -> Option<&[T]> {
180 self.inner.as_slice()
181 }
182
183 pub fn as_slice_mut(&mut self) -> Option<&mut [T]> {
185 self.inner.as_slice_mut()
186 }
187
188 #[inline]
190 pub fn dim(&self) -> &D {
191 &self.dim
192 }
193}
194
195impl<T: Element, D: Dimension> Clone for Array<T, D> {
208 fn clone(&self) -> Self {
209 Self {
210 inner: self.inner.clone(),
211 dim: self.dim.clone(),
212 }
213 }
214}
215
216impl<T: Element + PartialEq, D: Dimension> PartialEq for Array<T, D> {
217 fn eq(&self, other: &Self) -> bool {
218 self.inner == other.inner
219 }
220}
221
222impl<T: Element + Eq, D: Dimension> Eq for Array<T, D> {}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2, IxDyn};

    /// A zero-filled 3x4 array reports the expected shape, size and rank.
    #[test]
    fn create_zeros() {
        let shape = Ix2::new([3, 4]);
        let zeros = Array::<f64, Ix2>::zeros(shape).unwrap();

        assert_eq!(zeros.shape(), &[3, 4]);
        assert_eq!(zeros.size(), 12);
        assert_eq!(zeros.ndim(), 2);
        assert!(!zeros.is_empty());
    }

    /// A 1-D array built from a vector keeps the elements in order.
    #[test]
    fn create_from_vec() {
        let data = vec![1, 2, 3, 4];
        let built = Array::<i32, Ix1>::from_vec(Ix1::new([4]), data).unwrap();

        assert_eq!(built.shape(), &[4]);
        assert_eq!(built.as_slice().unwrap(), &[1, 2, 3, 4]);
    }

    /// Two elements cannot fill a 2x3 shape, so construction must fail.
    #[test]
    fn create_from_vec_shape_mismatch() {
        let too_short = vec![1.0, 2.0];
        let outcome = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), too_short);

        assert!(outcome.is_err());
    }

    /// Collecting five items yields a 1-D array of length five.
    #[test]
    fn from_iter_1d() {
        let values = (0..5).map(|x| x as f64);
        let collected = Array::<f64, Ix1>::from_iter_1d(values).unwrap();

        assert_eq!(collected.shape(), &[5]);
    }

    /// Arrays created through `zeros` are classified as C layout.
    #[test]
    fn layout_c_contiguous() {
        let zeros = Array::<f64, Ix2>::zeros(Ix2::new([3, 4])).unwrap();

        assert_eq!(zeros.layout(), MemoryLayout::C);
    }

    /// `from_vec_f` keeps the requested shape and yields a Fortran layout.
    #[test]
    fn from_vec_f_order() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let column_major = Array::<f64, Ix2>::from_vec_f(Ix2::new([2, 3]), data).unwrap();

        assert_eq!(column_major.shape(), &[2, 3]);
        assert_eq!(column_major.layout(), MemoryLayout::Fortran);
    }

    /// A clone compares equal to the array it was cloned from.
    #[test]
    fn clone_array() {
        let source = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let copy = source.clone();

        assert_eq!(source, copy);
    }

    /// Converting to `ndarray` and back preserves the element data.
    #[test]
    fn ndarray_roundtrip() {
        let original = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let wrapped = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), original.clone()).unwrap();

        let raw: ndarray::Array<f64, ndarray::Ix2> = wrapped.into_ndarray();
        let rewrapped: Array<f64, Ix2> = Array::from_ndarray(raw);

        assert_eq!(rewrapped.as_slice().unwrap(), &original[..]);
    }

    /// Dynamic-rank arrays report rank and shape from the runtime dimension.
    #[test]
    fn dynamic_rank() {
        let shape = IxDyn::new(&[2, 3]);
        let dynamic = Array::<f64, IxDyn>::from_vec(shape, vec![1.0; 6]).unwrap();

        assert_eq!(dynamic.ndim(), 2);
        assert_eq!(dynamic.shape(), &[2, 3]);
    }
}