1use crate::dimension::{Axis, Dimension, IxDyn};
6use crate::dtype::Element;
7use crate::error::{FerrayError, FerrayResult};
8
9use super::cow::CowArray;
10use super::owned::Array;
11use super::view::ArrayView;
12
13impl<T: Element, D: Dimension> Array<T, D> {
14 pub fn mapv(&self, f: impl Fn(T) -> T) -> Self {
19 let inner = self.inner.mapv(&f);
20 Self::from_ndarray(inner)
21 }
22
23 pub fn mapv_inplace(&mut self, f: impl Fn(T) -> T) {
25 self.inner.mapv_inplace(&f);
26 }
27
28 pub fn zip_mut_with(
37 &mut self,
38 other: &Array<T, D>,
39 f: impl Fn(&mut T, &T),
40 ) -> FerrayResult<()> {
41 if self.shape() != other.shape() {
42 return Err(FerrayError::shape_mismatch(format!(
43 "cannot zip arrays with shapes {:?} and {:?}",
44 self.shape(),
45 other.shape(),
46 )));
47 }
48 self.inner.zip_mut_with(&other.inner, |a, b| f(a, b));
49 Ok(())
50 }
51
52 pub fn fold_axis(
64 &self,
65 axis: Axis,
66 init: T,
67 fold: impl FnMut(&T, &T) -> T,
68 ) -> FerrayResult<Array<T, IxDyn>>
69 where
70 D::NdarrayDim: ndarray::RemoveAxis,
71 {
72 let ndim = self.ndim();
73 if axis.index() >= ndim {
74 return Err(FerrayError::axis_out_of_bounds(axis.index(), ndim));
75 }
76 let nd_axis = ndarray::Axis(axis.index());
77 let mut fold = fold;
78 let result = self.inner.fold_axis(nd_axis, init, |acc, x| fold(acc, x));
79 let dyn_result = result.into_dyn();
80 Ok(Array::from_ndarray(dyn_result))
81 }
82
83 pub fn map_to<U: Element>(&self, f: impl Fn(T) -> U) -> Array<U, D> {
88 let inner = self.inner.mapv(&f);
89 Array::from_ndarray(inner)
90 }
91
92 pub fn to_dyn(&self) -> Array<T, IxDyn> {
98 let dyn_inner = self.inner.clone().into_dyn();
99 Array::<T, IxDyn>::from_ndarray(dyn_inner)
100 }
101
102 pub fn into_dyn(self) -> Array<T, IxDyn> {
106 let dyn_inner = self.inner.into_dyn();
107 Array::<T, IxDyn>::from_ndarray(dyn_inner)
108 }
109
110 pub fn as_standard_layout(&self) -> CowArray<'_, T, D> {
118 if self.inner.is_standard_layout() {
123 CowArray::Borrowed(self.view())
124 } else {
125 let data: Vec<T> = self.iter().cloned().collect();
129 let owned = Array::from_vec(self.dim().clone(), data)
130 .expect("from_vec: data length was just built from self.iter()");
131 CowArray::Owned(owned)
132 }
133 }
134
135 pub fn as_fortran_layout(&self) -> CowArray<'_, T, D> {
142 if self.inner.t().is_standard_layout() {
146 CowArray::Borrowed(self.view())
147 } else {
148 let data: Vec<T> = self.inner.t().iter().cloned().collect();
153 let owned = Array::from_vec_f(self.dim().clone(), data)
154 .expect("from_vec_f: data length was just built from self.inner.t().iter()");
155 CowArray::Owned(owned)
156 }
157 }
158}
159
160impl<T: Element, D: Dimension> ArrayView<'_, T, D> {
165 pub fn mapv(&self, f: impl Fn(T) -> T) -> Array<T, D> {
167 let inner = self.inner.mapv(&f);
168 Array::from_ndarray(inner)
169 }
170
171 pub fn fold_axis(
173 &self,
174 axis: Axis,
175 init: T,
176 fold: impl FnMut(&T, &T) -> T,
177 ) -> FerrayResult<Array<T, IxDyn>>
178 where
179 D::NdarrayDim: ndarray::RemoveAxis,
180 {
181 let ndim = self.ndim();
182 if axis.index() >= ndim {
183 return Err(FerrayError::axis_out_of_bounds(axis.index(), ndim));
184 }
185 let nd_axis = ndarray::Axis(axis.index());
186 let mut fold = fold;
187 let result = self.inner.fold_axis(nd_axis, init, |acc, x| fold(acc, x));
188 let dyn_result = result.into_dyn();
189 Ok(Array::from_ndarray(dyn_result))
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};
    use crate::layout::MemoryLayout;

    #[test]
    fn mapv_double() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let doubled = arr.mapv(|x| x * 2.0);
        assert_eq!(doubled.as_slice().unwrap(), &[2.0, 4.0, 6.0, 8.0]);
        // The source array must be left untouched.
        assert_eq!(arr.as_slice().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn mapv_inplace_negate() {
        let mut arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, -2.0, 3.0]).unwrap();
        arr.mapv_inplace(|x| -x);
        assert_eq!(arr.as_slice().unwrap(), &[-1.0, 2.0, -3.0]);
    }

    #[test]
    fn zip_mut_with_add() {
        let mut a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
        a.zip_mut_with(&b, |x, y| *x += y).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[11.0, 22.0, 33.0]);
    }

    #[test]
    fn zip_mut_with_shape_mismatch() {
        let mut a = Array::<f64, Ix1>::zeros(Ix1::new([3])).unwrap();
        let b = Array::<f64, Ix1>::zeros(Ix1::new([5])).unwrap();
        assert!(a.zip_mut_with(&b, |_, _| {}).is_err());
    }

    #[test]
    fn as_standard_layout_borrows_when_already_c_contig() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        assert_eq!(a.layout(), MemoryLayout::C);
        let cow = a.as_standard_layout();
        assert!(cow.is_borrowed(), "C-contig input must borrow, not copy");
        assert_eq!(cow.shape(), &[2, 3]);
        assert_eq!(cow.layout(), MemoryLayout::C);
    }

    #[test]
    fn as_standard_layout_copies_f_contig_input_to_c() {
        let a = Array::<f64, Ix2>::from_vec_f(Ix2::new([2, 3]), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0])
            .unwrap();
        assert_eq!(a.layout(), MemoryLayout::Fortran);
        let cow = a.as_standard_layout();
        assert!(cow.is_owned(), "F-contig input must be copied to C-contig");
        assert_eq!(cow.shape(), &[2, 3]);
        assert_eq!(cow.layout(), MemoryLayout::C);
        let owned = cow.into_owned();
        assert_eq!(owned.as_slice().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn as_fortran_layout_borrows_when_already_f_contig() {
        let a = Array::<f64, Ix2>::from_vec_f(Ix2::new([2, 3]), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0])
            .unwrap();
        assert_eq!(a.layout(), MemoryLayout::Fortran);
        let cow = a.as_fortran_layout();
        assert!(cow.is_borrowed(), "F-contig input must borrow, not copy");
        assert_eq!(cow.shape(), &[2, 3]);
        assert_eq!(cow.layout(), MemoryLayout::Fortran);
    }

    #[test]
    fn as_fortran_layout_copies_c_contig_input_to_f() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        assert_eq!(a.layout(), MemoryLayout::C);
        let cow = a.as_fortran_layout();
        assert!(cow.is_owned(), "C-contig input must be copied to F-contig");
        assert_eq!(cow.shape(), &[2, 3]);
        assert_eq!(cow.layout(), MemoryLayout::Fortran);
        let owned = cow.into_owned();
        // Logical (row-major) iteration order is unaffected by memory layout.
        let logical: Vec<f64> = owned.iter().copied().collect();
        assert_eq!(logical, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn layout_roundtrip_preserves_values() {
        let original = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12i32).collect()).unwrap();
        let f_cow = original.as_fortran_layout();
        let f_owned = f_cow.into_owned();
        assert_eq!(f_owned.layout(), MemoryLayout::Fortran);
        let c_cow = f_owned.as_standard_layout();
        let c_owned = c_cow.into_owned();
        assert_eq!(c_owned.layout(), MemoryLayout::C);
        assert_eq!(c_owned.as_slice().unwrap(), original.as_slice().unwrap());
    }

    #[test]
    fn as_standard_layout_1d_always_borrows() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        assert!(a.as_standard_layout().is_borrowed());
        assert!(a.as_fortran_layout().is_borrowed());
    }

    #[test]
    fn fold_axis_sum_rows() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let sums = arr.fold_axis(Axis(1), 0.0, |acc, &x| *acc + x).unwrap();
        assert_eq!(sums.shape(), &[2]);
        let data: Vec<f64> = sums.iter().copied().collect();
        assert_eq!(data, vec![6.0, 15.0]);
    }

    #[test]
    fn fold_axis_sum_cols() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let sums = arr.fold_axis(Axis(0), 0.0, |acc, &x| *acc + x).unwrap();
        assert_eq!(sums.shape(), &[3]);
        let data: Vec<f64> = sums.iter().copied().collect();
        assert_eq!(data, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn fold_axis_out_of_bounds() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([2, 3])).unwrap();
        assert!(arr.fold_axis(Axis(2), 0.0, |a, _| *a).is_err());
    }

    #[test]
    fn view_fold_axis_sum_cols() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let v = arr.view();
        let sums = v.fold_axis(Axis(0), 0.0, |acc, &x| *acc + x).unwrap();
        assert_eq!(sums.shape(), &[3]);
        let data: Vec<f64> = sums.iter().copied().collect();
        assert_eq!(data, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn view_fold_axis_out_of_bounds() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([2, 3])).unwrap();
        assert!(arr.view().fold_axis(Axis(2), 0.0, |a, _| *a).is_err());
    }

    #[test]
    fn to_dyn_preserves_shape_and_values() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let d = arr.to_dyn();
        assert_eq!(d.shape(), &[2, 3]);
        let data: Vec<f64> = d.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        // `to_dyn` only borrows, so the original must remain usable.
        assert_eq!(arr.shape(), &[2, 3]);
    }

    #[test]
    fn into_dyn_preserves_shape_and_values() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let d = arr.into_dyn();
        assert_eq!(d.shape(), &[2, 3]);
        let data: Vec<f64> = d.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn map_to_different_type() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.5, 2.7, 3.1]).unwrap();
        let ints: Array<i32, Ix1> = arr.map_to(|x| x as i32);
        assert_eq!(ints.as_slice().unwrap(), &[1, 2, 3]);
    }

    #[test]
    fn view_mapv() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let v = arr.view();
        let doubled = v.mapv(|x| x * 2.0);
        assert_eq!(doubled.as_slice().unwrap(), &[2.0, 4.0, 6.0]);
    }
}