1use ndarray::ShapeBuilder;
4
5use crate::dimension::Dimension;
6use crate::dtype::Element;
7use crate::error::{FerrayError, FerrayResult};
8use crate::layout::MemoryLayout;
9
10pub struct Array<T: Element, D: Dimension> {
22 pub(crate) inner: ndarray::Array<T, D::NdarrayDim>,
24 pub(crate) dim: D,
26}
27
28impl<T: Element, D: Dimension> Array<T, D> {
29 pub(crate) fn from_ndarray(inner: ndarray::Array<T, D::NdarrayDim>) -> Self {
33 let dim = D::from_ndarray_dim(&inner.raw_dim());
34 Self { inner, dim }
35 }
36
37 pub(crate) fn into_ndarray(self) -> ndarray::Array<T, D::NdarrayDim> {
39 self.inner
40 }
41
42 pub fn from_elem(dim: D, elem: T) -> FerrayResult<Self> {
50 let nd_dim = dim.to_ndarray_dim();
51 let inner = ndarray::Array::from_elem(nd_dim, elem);
52 Ok(Self { inner, dim })
53 }
54
55 pub fn zeros(dim: D) -> FerrayResult<Self> {
57 Self::from_elem(dim, T::zero())
58 }
59
60 pub fn ones(dim: D) -> FerrayResult<Self> {
62 Self::from_elem(dim, T::one())
63 }
64
65 pub fn from_vec(dim: D, data: Vec<T>) -> FerrayResult<Self> {
71 let expected = dim.size();
72 if data.len() != expected {
73 return Err(FerrayError::shape_mismatch(format!(
74 "data length {} does not match shape {:?} (expected {})",
75 data.len(),
76 dim.as_slice(),
77 expected,
78 )));
79 }
80 let nd_dim = dim.to_ndarray_dim();
81 let inner = ndarray::Array::from_shape_vec(nd_dim, data)
82 .map_err(|e| FerrayError::shape_mismatch(format!("ndarray shape error: {e}")))?;
83 Ok(Self { inner, dim })
84 }
85
86 pub fn from_vec_f(dim: D, data: Vec<T>) -> FerrayResult<Self> {
91 let expected = dim.size();
92 if data.len() != expected {
93 return Err(FerrayError::shape_mismatch(format!(
94 "data length {} does not match shape {:?} (expected {})",
95 data.len(),
96 dim.as_slice(),
97 expected,
98 )));
99 }
100 let nd_dim = dim.to_ndarray_dim();
101 let inner = ndarray::Array::from_shape_vec(nd_dim.f(), data)
102 .map_err(|e| FerrayError::shape_mismatch(format!("ndarray shape error: {e}")))?;
103 let dim = D::from_ndarray_dim(&inner.raw_dim());
104 Ok(Self { inner, dim })
105 }
106
107 pub fn from_iter_1d(iter: impl IntoIterator<Item = T>) -> FerrayResult<Self>
112 where
113 D: Dimension<NdarrayDim = ndarray::Ix1>,
114 {
115 let inner = ndarray::Array::from_iter(iter);
116 let dim = D::from_ndarray_dim(&inner.raw_dim());
117 Ok(Self { inner, dim })
118 }
119
120 pub fn layout(&self) -> MemoryLayout {
122 if self.inner.is_standard_layout() {
123 MemoryLayout::C
124 } else {
125 let shape = self.dim.as_slice();
127 let strides = self.strides_isize();
128 crate::layout::detect_layout(shape, &strides)
129 }
130 }
131
132 pub(crate) fn strides_isize(&self) -> Vec<isize> {
134 self.inner.strides().to_vec()
135 }
136
137 #[inline]
139 pub fn ndim(&self) -> usize {
140 self.dim.ndim()
141 }
142
143 #[inline]
145 pub fn shape(&self) -> &[usize] {
146 self.inner.shape()
147 }
148
149 #[inline]
151 pub fn strides(&self) -> &[isize] {
152 self.inner.strides()
153 }
154
155 #[inline]
157 pub fn size(&self) -> usize {
158 self.inner.len()
159 }
160
161 #[inline]
163 pub fn is_empty(&self) -> bool {
164 self.inner.is_empty()
165 }
166
167 #[inline]
169 pub fn as_ptr(&self) -> *const T {
170 self.inner.as_ptr()
171 }
172
173 #[inline]
175 pub fn as_mut_ptr(&mut self) -> *mut T {
176 self.inner.as_mut_ptr()
177 }
178
179 pub fn as_slice(&self) -> Option<&[T]> {
181 self.inner.as_slice()
182 }
183
184 pub fn as_slice_mut(&mut self) -> Option<&mut [T]> {
186 self.inner.as_slice_mut()
187 }
188
189 #[inline]
191 pub fn dim(&self) -> &D {
192 &self.dim
193 }
194}
195
196impl<T: Element, D: Dimension> From<ndarray::Array<T, D::NdarrayDim>> for Array<T, D> {
198 fn from(inner: ndarray::Array<T, D::NdarrayDim>) -> Self {
199 Self::from_ndarray(inner)
200 }
201}
202
203impl<T: Element, D: Dimension> From<Array<T, D>> for ndarray::Array<T, D::NdarrayDim> {
204 fn from(arr: Array<T, D>) -> Self {
205 arr.into_ndarray()
206 }
207}
208
209impl<T: Element, D: Dimension> Clone for Array<T, D> {
210 fn clone(&self) -> Self {
211 Self {
212 inner: self.inner.clone(),
213 dim: self.dim.clone(),
214 }
215 }
216}
217
218impl<T: Element + PartialEq, D: Dimension> PartialEq for Array<T, D> {
219 fn eq(&self, other: &Self) -> bool {
220 self.inner == other.inner
221 }
222}
223
224impl<T: Element + Eq, D: Dimension> Eq for Array<T, D> {}
225
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2, IxDyn};

    // A zero-filled 3x4 array reports the expected shape, size and rank.
    #[test]
    fn create_zeros() {
        let zeros: Array<f64, Ix2> = Array::zeros(Ix2::new([3, 4])).unwrap();
        assert_eq!(zeros.shape(), &[3, 4]);
        assert_eq!(zeros.size(), 12);
        assert_eq!(zeros.ndim(), 2);
        assert!(!zeros.is_empty());
    }

    // A 1-D array built from a vector keeps the data in order.
    #[test]
    fn create_from_vec() {
        let values = vec![1, 2, 3, 4];
        let arr: Array<i32, Ix1> = Array::from_vec(Ix1::new([4]), values).unwrap();
        assert_eq!(arr.shape(), &[4]);
        assert_eq!(arr.as_slice().unwrap(), &[1, 2, 3, 4]);
    }

    // Two elements cannot fill a 2x3 shape, so construction must fail.
    #[test]
    fn create_from_vec_shape_mismatch() {
        let too_short = vec![1.0, 2.0];
        let outcome = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), too_short);
        assert!(outcome.is_err());
    }

    // Collecting five items yields a 1-D array of length five.
    #[test]
    fn from_iter_1d() {
        let source = (0..5).map(|i| i as f64);
        let arr: Array<f64, Ix1> = Array::from_iter_1d(source).unwrap();
        assert_eq!(arr.shape(), &[5]);
    }

    // Arrays created through `zeros` are reported as C layout.
    #[test]
    fn layout_c_contiguous() {
        let zeros: Array<f64, Ix2> = Array::zeros(Ix2::new([3, 4])).unwrap();
        assert_eq!(zeros.layout(), MemoryLayout::C);
    }

    // `from_vec_f` keeps the requested shape and reports Fortran layout.
    #[test]
    fn from_vec_f_order() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let arr: Array<f64, Ix2> = Array::from_vec_f(Ix2::new([2, 3]), values).unwrap();
        assert_eq!(arr.shape(), &[2, 3]);
        assert_eq!(arr.layout(), MemoryLayout::Fortran);
    }

    // A clone compares equal to the array it was cloned from.
    #[test]
    fn clone_array() {
        let first: Array<f64, Ix1> =
            Array::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let second = first.clone();
        assert_eq!(first, second);
    }

    // Converting to `ndarray` and back preserves the element data.
    #[test]
    fn ndarray_roundtrip() {
        let original = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let wrapped = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), original.clone()).unwrap();
        let raw = ndarray::Array::<f64, ndarray::Ix2>::from(wrapped);
        let rewrapped = Array::<f64, Ix2>::from(raw);
        assert_eq!(rewrapped.as_slice().unwrap(), original.as_slice());
    }

    // A dynamically ranked array reports the rank and shape it was given.
    #[test]
    fn dynamic_rank() {
        let data = vec![1.0; 6];
        let arr: Array<f64, IxDyn> = Array::from_vec(IxDyn::new(&[2, 3]), data).unwrap();
        assert_eq!(arr.ndim(), 2);
        assert_eq!(arr.shape(), &[2, 3]);
    }
}