1use crate::dimension::{Axis, Dimension, IxDyn};
5use crate::dtype::Element;
6use crate::error::{FerrayError, FerrayResult};
7
8use super::owned::Array;
9use super::view::ArrayView;
10
11impl<T: Element, D: Dimension> Array<T, D> {
12 pub fn mapv(&self, f: impl Fn(T) -> T) -> Self {
17 let inner = self.inner.mapv(&f);
18 Self::from_ndarray(inner)
19 }
20
21 pub fn mapv_inplace(&mut self, f: impl Fn(T) -> T) {
23 self.inner.mapv_inplace(&f);
24 }
25
26 pub fn zip_mut_with(
35 &mut self,
36 other: &Array<T, D>,
37 f: impl Fn(&mut T, &T),
38 ) -> FerrayResult<()> {
39 if self.shape() != other.shape() {
40 return Err(FerrayError::shape_mismatch(format!(
41 "cannot zip arrays with shapes {:?} and {:?}",
42 self.shape(),
43 other.shape(),
44 )));
45 }
46 self.inner.zip_mut_with(&other.inner, |a, b| f(a, b));
47 Ok(())
48 }
49
50 pub fn fold_axis(
62 &self,
63 axis: Axis,
64 init: T,
65 fold: impl FnMut(&T, &T) -> T,
66 ) -> FerrayResult<Array<T, IxDyn>>
67 where
68 D::NdarrayDim: ndarray::RemoveAxis,
69 {
70 let ndim = self.ndim();
71 if axis.index() >= ndim {
72 return Err(FerrayError::axis_out_of_bounds(axis.index(), ndim));
73 }
74 let nd_axis = ndarray::Axis(axis.index());
75 let mut fold = fold;
76 let result = self.inner.fold_axis(nd_axis, init, |acc, x| fold(acc, x));
77 let dyn_result = result.into_dyn();
78 Ok(Array::from_ndarray(dyn_result))
79 }
80
81 pub fn map_to<U: Element>(&self, f: impl Fn(T) -> U) -> Array<U, D> {
86 let inner = self.inner.mapv(&f);
87 Array::from_ndarray(inner)
88 }
89}
90
91impl<T: Element, D: Dimension> ArrayView<'_, T, D> {
96 pub fn mapv(&self, f: impl Fn(T) -> T) -> Array<T, D> {
98 let inner = self.inner.mapv(&f);
99 Array::from_ndarray(inner)
100 }
101
102 pub fn fold_axis(
104 &self,
105 axis: Axis,
106 init: T,
107 fold: impl FnMut(&T, &T) -> T,
108 ) -> FerrayResult<Array<T, IxDyn>>
109 where
110 D::NdarrayDim: ndarray::RemoveAxis,
111 {
112 let ndim = self.ndim();
113 if axis.index() >= ndim {
114 return Err(FerrayError::axis_out_of_bounds(axis.index(), ndim));
115 }
116 let nd_axis = ndarray::Axis(axis.index());
117 let mut fold = fold;
118 let result = self.inner.fold_axis(nd_axis, init, |acc, x| fold(acc, x));
119 let dyn_result = result.into_dyn();
120 Ok(Array::from_ndarray(dyn_result))
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    #[test]
    fn mapv_double() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let doubled = arr.mapv(|x| x * 2.0);
        assert_eq!(doubled.as_slice().unwrap(), &[2.0, 4.0, 6.0, 8.0]);
        // The source array must be left untouched.
        assert_eq!(arr.as_slice().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn mapv_inplace_negate() {
        let mut arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, -2.0, 3.0]).unwrap();
        arr.mapv_inplace(|x| -x);
        assert_eq!(arr.as_slice().unwrap(), &[-1.0, 2.0, -3.0]);
    }

    #[test]
    fn zip_mut_with_add() {
        let mut a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
        a.zip_mut_with(&b, |x, y| *x += y).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[11.0, 22.0, 33.0]);
    }

    #[test]
    fn zip_mut_with_shape_mismatch() {
        let mut a = Array::<f64, Ix1>::zeros(Ix1::new([3])).unwrap();
        let b = Array::<f64, Ix1>::zeros(Ix1::new([5])).unwrap();
        assert!(a.zip_mut_with(&b, |_, _| {}).is_err());
    }

    #[test]
    fn fold_axis_sum_rows() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let sums = arr.fold_axis(Axis(1), 0.0, |acc, &x| *acc + x).unwrap();
        assert_eq!(sums.shape(), &[2]);
        let data: Vec<f64> = sums.iter().copied().collect();
        assert_eq!(data, vec![6.0, 15.0]);
    }

    #[test]
    fn fold_axis_sum_cols() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let sums = arr.fold_axis(Axis(0), 0.0, |acc, &x| *acc + x).unwrap();
        assert_eq!(sums.shape(), &[3]);
        let data: Vec<f64> = sums.iter().copied().collect();
        assert_eq!(data, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn fold_axis_out_of_bounds() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([2, 3])).unwrap();
        assert!(arr.fold_axis(Axis(2), 0.0, |a, _| *a).is_err());
    }

    #[test]
    fn map_to_different_type() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.5, 2.7, 3.1]).unwrap();
        let ints: Array<i32, Ix1> = arr.map_to(|x| x as i32);
        assert_eq!(ints.as_slice().unwrap(), &[1, 2, 3]);
    }

    #[test]
    fn view_mapv() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let v = arr.view();
        let doubled = v.mapv(|x| x * 2.0);
        assert_eq!(doubled.as_slice().unwrap(), &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn view_fold_axis_sum_rows() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let sums = arr
            .view()
            .fold_axis(Axis(1), 0.0, |acc, &x| *acc + x)
            .unwrap();
        assert_eq!(sums.shape(), &[2]);
        let data: Vec<f64> = sums.iter().copied().collect();
        assert_eq!(data, vec![6.0, 15.0]);
    }

    #[test]
    fn view_fold_axis_sum_cols() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let sums = arr
            .view()
            .fold_axis(Axis(0), 0.0, |acc, &x| *acc + x)
            .unwrap();
        assert_eq!(sums.shape(), &[3]);
        let data: Vec<f64> = sums.iter().copied().collect();
        assert_eq!(data, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn view_fold_axis_out_of_bounds() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([2, 3])).unwrap();
        assert!(arr.view().fold_axis(Axis(2), 0.0, |a, _| *a).is_err());
    }
}