1use crate::dimension::{Axis, Dimension, IxDyn};
6use crate::dtype::Element;
7use crate::error::{FerrayError, FerrayResult};
8
9use super::cow::CowArray;
10use super::owned::Array;
11use super::view::ArrayView;
12
13impl<T: Element, D: Dimension> Array<T, D> {
14 #[must_use]
19 pub fn mapv(&self, f: impl Fn(T) -> T) -> Self {
20 let inner = self.inner.mapv(&f);
21 Self::from_ndarray(inner)
22 }
23
24 pub fn mapv_inplace(&mut self, f: impl Fn(T) -> T) {
26 self.inner.mapv_inplace(&f);
27 }
28
29 pub fn zip_mut_with(&mut self, other: &Self, f: impl Fn(&mut T, &T)) -> FerrayResult<()> {
38 if self.shape() != other.shape() {
39 return Err(FerrayError::shape_mismatch(format!(
40 "cannot zip arrays with shapes {:?} and {:?}",
41 self.shape(),
42 other.shape(),
43 )));
44 }
45 self.inner.zip_mut_with(&other.inner, |a, b| f(a, b));
46 Ok(())
47 }
48
49 pub fn fold_axis(
61 &self,
62 axis: Axis,
63 init: T,
64 fold: impl FnMut(&T, &T) -> T,
65 ) -> FerrayResult<Array<T, IxDyn>>
66 where
67 D::NdarrayDim: ndarray::RemoveAxis,
68 {
69 let ndim = self.ndim();
70 if axis.index() >= ndim {
71 return Err(FerrayError::axis_out_of_bounds(axis.index(), ndim));
72 }
73 let nd_axis = ndarray::Axis(axis.index());
74 let mut fold = fold;
75 let result = self.inner.fold_axis(nd_axis, init, |acc, x| fold(acc, x));
76 let dyn_result = result.into_dyn();
77 Ok(Array::from_ndarray(dyn_result))
78 }
79
80 pub fn map_to<U: Element>(&self, f: impl Fn(T) -> U) -> Array<U, D> {
85 let inner = self.inner.mapv(&f);
86 Array::from_ndarray(inner)
87 }
88
89 pub fn to_dyn(&self) -> Array<T, IxDyn> {
95 let dyn_inner = self.inner.clone().into_dyn();
96 Array::<T, IxDyn>::from_ndarray(dyn_inner)
97 }
98
99 pub fn into_dyn(self) -> Array<T, IxDyn> {
103 let dyn_inner = self.inner.into_dyn();
104 Array::<T, IxDyn>::from_ndarray(dyn_inner)
105 }
106
107 pub fn as_standard_layout(&self) -> CowArray<'_, T, D> {
115 if self.inner.is_standard_layout() {
120 CowArray::Borrowed(self.view())
121 } else {
122 let data: Vec<T> = self.iter().cloned().collect();
126 let owned = Self::from_vec(self.dim().clone(), data)
127 .expect("from_vec: data length was just built from self.iter()");
128 CowArray::Owned(owned)
129 }
130 }
131
132 pub fn as_fortran_layout(&self) -> CowArray<'_, T, D> {
139 if self.inner.t().is_standard_layout() {
143 CowArray::Borrowed(self.view())
144 } else {
145 let data: Vec<T> = self.inner.t().iter().cloned().collect();
150 let owned = Self::from_vec_f(self.dim().clone(), data)
151 .expect("from_vec_f: data length was just built from self.inner.t().iter()");
152 CowArray::Owned(owned)
153 }
154 }
155}
156
157impl<T: Element, D: Dimension> ArrayView<'_, T, D> {
162 pub fn mapv(&self, f: impl Fn(T) -> T) -> Array<T, D> {
164 let inner = self.inner.mapv(&f);
165 Array::from_ndarray(inner)
166 }
167
168 pub fn fold_axis(
170 &self,
171 axis: Axis,
172 init: T,
173 fold: impl FnMut(&T, &T) -> T,
174 ) -> FerrayResult<Array<T, IxDyn>>
175 where
176 D::NdarrayDim: ndarray::RemoveAxis,
177 {
178 let ndim = self.ndim();
179 if axis.index() >= ndim {
180 return Err(FerrayError::axis_out_of_bounds(axis.index(), ndim));
181 }
182 let nd_axis = ndarray::Axis(axis.index());
183 let mut fold = fold;
184 let result = self.inner.fold_axis(nd_axis, init, |acc, x| fold(acc, x));
185 let dyn_result = result.into_dyn();
186 Ok(Array::from_ndarray(dyn_result))
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};
    use crate::layout::MemoryLayout;

    /// Builds a 1-D `f64` array holding exactly `values`.
    fn vector(values: Vec<f64>) -> Array<f64, Ix1> {
        let len = values.len();
        Array::<f64, Ix1>::from_vec(Ix1::new([len]), values).unwrap()
    }

    /// 2x3 matrix built with `from_vec` from the values 1..=6.
    fn c_order_2x3() -> Array<f64, Ix2> {
        Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap()
    }

    /// 2x3 matrix built with `from_vec_f` from column-ordered data.
    fn f_order_2x3() -> Array<f64, Ix2> {
        Array::<f64, Ix2>::from_vec_f(Ix2::new([2, 3]), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0])
            .unwrap()
    }

    #[test]
    fn mapv_double() {
        let source = vector(vec![1.0, 2.0, 3.0, 4.0]);
        let result = source.mapv(|x| x * 2.0);
        assert_eq!(result.as_slice().unwrap(), &[2.0, 4.0, 6.0, 8.0]);
        // The source must be left untouched by the non-mutating map.
        assert_eq!(source.as_slice().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn mapv_inplace_negate() {
        let mut values = vector(vec![1.0, -2.0, 3.0]);
        values.mapv_inplace(|x| -x);
        assert_eq!(values.as_slice().unwrap(), &[-1.0, 2.0, -3.0]);
    }

    #[test]
    fn zip_mut_with_add() {
        let mut target = vector(vec![1.0, 2.0, 3.0]);
        let addend = vector(vec![10.0, 20.0, 30.0]);
        target.zip_mut_with(&addend, |x, y| *x += y).unwrap();
        assert_eq!(target.as_slice().unwrap(), &[11.0, 22.0, 33.0]);
    }

    #[test]
    fn zip_mut_with_shape_mismatch() {
        let mut short = Array::<f64, Ix1>::zeros(Ix1::new([3])).unwrap();
        let long = Array::<f64, Ix1>::zeros(Ix1::new([5])).unwrap();
        let outcome = short.zip_mut_with(&long, |_, _| {});
        assert!(outcome.is_err());
    }

    #[test]
    fn as_standard_layout_borrows_when_already_c_contig() {
        let matrix = c_order_2x3();
        assert_eq!(matrix.layout(), MemoryLayout::C);

        let cow = matrix.as_standard_layout();
        assert!(cow.is_borrowed(), "C-contig input must borrow, not copy");
        assert_eq!(cow.shape(), &[2, 3]);
        assert_eq!(cow.layout(), MemoryLayout::C);
    }

    #[test]
    fn as_standard_layout_copies_f_contig_input_to_c() {
        let matrix = f_order_2x3();
        assert_eq!(matrix.layout(), MemoryLayout::Fortran);

        let cow = matrix.as_standard_layout();
        assert!(cow.is_owned(), "F-contig input must be copied to C-contig");
        assert_eq!(cow.shape(), &[2, 3]);
        assert_eq!(cow.layout(), MemoryLayout::C);

        // After conversion the backing slice is expected in row-major order.
        let converted = cow.into_owned();
        assert_eq!(
            converted.as_slice().unwrap(),
            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        );
    }

    #[test]
    fn as_fortran_layout_borrows_when_already_f_contig() {
        let matrix = f_order_2x3();
        assert_eq!(matrix.layout(), MemoryLayout::Fortran);

        let cow = matrix.as_fortran_layout();
        assert!(cow.is_borrowed(), "F-contig input must borrow, not copy");
        assert_eq!(cow.shape(), &[2, 3]);
        assert_eq!(cow.layout(), MemoryLayout::Fortran);
    }

    #[test]
    fn as_fortran_layout_copies_c_contig_input_to_f() {
        let matrix = c_order_2x3();
        assert_eq!(matrix.layout(), MemoryLayout::C);

        let cow = matrix.as_fortran_layout();
        assert!(cow.is_owned(), "C-contig input must be copied to F-contig");
        assert_eq!(cow.shape(), &[2, 3]);
        assert_eq!(cow.layout(), MemoryLayout::Fortran);

        // Iteration order must still yield the original element sequence.
        let converted = cow.into_owned();
        let logical: Vec<f64> = converted.iter().copied().collect();
        assert_eq!(logical, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn layout_roundtrip_preserves_values() {
        let original = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12i32).collect()).unwrap();

        // C -> Fortran
        let as_f = original.as_fortran_layout().into_owned();
        assert_eq!(as_f.layout(), MemoryLayout::Fortran);

        // Fortran -> C
        let back_to_c = as_f.as_standard_layout().into_owned();
        assert_eq!(back_to_c.layout(), MemoryLayout::C);
        assert_eq!(back_to_c.as_slice().unwrap(), original.as_slice().unwrap());
    }

    #[test]
    fn as_standard_layout_1d_always_borrows() {
        let line = vector(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        // Neither conversion should copy this 1-D array.
        assert!(line.as_standard_layout().is_borrowed());
        assert!(line.as_fortran_layout().is_borrowed());
    }

    #[test]
    fn fold_axis_sum_rows() {
        let matrix = c_order_2x3();
        let row_sums = matrix
            .fold_axis(Axis(1), 0.0, |acc, &x| *acc + x)
            .unwrap();
        assert_eq!(row_sums.shape(), &[2]);
        let collected: Vec<f64> = row_sums.iter().copied().collect();
        assert_eq!(collected, vec![6.0, 15.0]);
    }

    #[test]
    fn fold_axis_sum_cols() {
        let matrix = c_order_2x3();
        let col_sums = matrix
            .fold_axis(Axis(0), 0.0, |acc, &x| *acc + x)
            .unwrap();
        assert_eq!(col_sums.shape(), &[3]);
        let collected: Vec<f64> = col_sums.iter().copied().collect();
        assert_eq!(collected, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn fold_axis_out_of_bounds() {
        let matrix = Array::<f64, Ix2>::zeros(Ix2::new([2, 3])).unwrap();
        // A 2-D array only has axes 0 and 1.
        let outcome = matrix.fold_axis(Axis(2), 0.0, |a, _| *a);
        assert!(outcome.is_err());
    }

    #[test]
    fn map_to_different_type() {
        let floats = vector(vec![1.5, 2.7, 3.1]);
        let ints: Array<i32, Ix1> = floats.map_to(|x| x as i32);
        assert_eq!(ints.as_slice().unwrap(), &[1, 2, 3]);
    }

    #[test]
    fn view_mapv() {
        let base = vector(vec![1.0, 2.0, 3.0]);
        let view = base.view();
        let result = view.mapv(|x| x * 2.0);
        assert_eq!(result.as_slice().unwrap(), &[2.0, 4.0, 6.0]);
    }
}