Skip to main content

ferray_core/array/
methods.rs

// ferray-core: Closure-based operations (REQ-38)
//   mapv, mapv_inplace, zip_mut_with, fold_axis, map_to
3
4use crate::dimension::{Axis, Dimension, IxDyn};
5use crate::dtype::Element;
6use crate::error::{FerrayError, FerrayResult};
7
8use super::owned::Array;
9use super::view::ArrayView;
10
11impl<T: Element, D: Dimension> Array<T, D> {
12    /// Apply a closure to every element, returning a new array.
13    ///
14    /// The closure receives each element by value (cloned) and must return
15    /// the same type. For type-changing maps, collect via iterators.
16    pub fn mapv(&self, f: impl Fn(T) -> T) -> Self {
17        let inner = self.inner.mapv(&f);
18        Self::from_ndarray(inner)
19    }
20
21    /// Apply a closure to every element in place.
22    pub fn mapv_inplace(&mut self, f: impl Fn(T) -> T) {
23        self.inner.mapv_inplace(&f);
24    }
25
26    /// Zip this array mutably with another array of the same shape,
27    /// applying a closure to each pair of elements.
28    ///
29    /// The closure receives `(&mut T, &T)` — the first element is from
30    /// `self` and can be modified, the second is from `other`.
31    ///
32    /// # Errors
33    /// Returns `FerrayError::ShapeMismatch` if shapes differ.
34    pub fn zip_mut_with(
35        &mut self,
36        other: &Array<T, D>,
37        f: impl Fn(&mut T, &T),
38    ) -> FerrayResult<()> {
39        if self.shape() != other.shape() {
40            return Err(FerrayError::shape_mismatch(format!(
41                "cannot zip arrays with shapes {:?} and {:?}",
42                self.shape(),
43                other.shape(),
44            )));
45        }
46        self.inner.zip_mut_with(&other.inner, |a, b| f(a, b));
47        Ok(())
48    }
49
50    /// Fold (reduce) along the given axis.
51    ///
52    /// `init` provides the initial accumulator value for each lane.
53    /// The closure receives `(accumulator, &element)` and must return
54    /// the new accumulator.
55    ///
56    /// Returns an array with one fewer dimension (the folded axis removed).
57    /// The result is always returned as a dynamic-rank array.
58    ///
59    /// # Errors
60    /// Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim`.
61    pub fn fold_axis(
62        &self,
63        axis: Axis,
64        init: T,
65        fold: impl FnMut(&T, &T) -> T,
66    ) -> FerrayResult<Array<T, IxDyn>>
67    where
68        D::NdarrayDim: ndarray::RemoveAxis,
69    {
70        let ndim = self.ndim();
71        if axis.index() >= ndim {
72            return Err(FerrayError::axis_out_of_bounds(axis.index(), ndim));
73        }
74        let nd_axis = ndarray::Axis(axis.index());
75        let mut fold = fold;
76        let result = self.inner.fold_axis(nd_axis, init, |acc, x| fold(acc, x));
77        let dyn_result = result.into_dyn();
78        Ok(Array::from_ndarray(dyn_result))
79    }
80
81    /// Apply a closure elementwise, producing an array of a different type.
82    ///
83    /// Unlike `mapv` which preserves the element type, this allows
84    /// mapping to a different `Element` type.
85    pub fn map_to<U: Element>(&self, f: impl Fn(T) -> U) -> Array<U, D> {
86        let inner = self.inner.mapv(&f);
87        Array::from_ndarray(inner)
88    }
89}
90
91// ---------------------------------------------------------------------------
92// ArrayView methods
93// ---------------------------------------------------------------------------
94
95impl<T: Element, D: Dimension> ArrayView<'_, T, D> {
96    /// Apply a closure to every element, returning a new owned array.
97    pub fn mapv(&self, f: impl Fn(T) -> T) -> Array<T, D> {
98        let inner = self.inner.mapv(&f);
99        Array::from_ndarray(inner)
100    }
101
102    /// Fold along an axis.
103    pub fn fold_axis(
104        &self,
105        axis: Axis,
106        init: T,
107        fold: impl FnMut(&T, &T) -> T,
108    ) -> FerrayResult<Array<T, IxDyn>>
109    where
110        D::NdarrayDim: ndarray::RemoveAxis,
111    {
112        let ndim = self.ndim();
113        if axis.index() >= ndim {
114            return Err(FerrayError::axis_out_of_bounds(axis.index(), ndim));
115        }
116        let nd_axis = ndarray::Axis(axis.index());
117        let mut fold = fold;
118        let result = self.inner.fold_axis(nd_axis, init, |acc, x| fold(acc, x));
119        let dyn_result = result.into_dyn();
120        Ok(Array::from_ndarray(dyn_result))
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    #[test]
    fn mapv_double() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let doubled = arr.mapv(|x| x * 2.0);
        assert_eq!(doubled.as_slice().unwrap(), &[2.0, 4.0, 6.0, 8.0]);
        // Original unchanged
        assert_eq!(arr.as_slice().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn mapv_inplace_negate() {
        let mut arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, -2.0, 3.0]).unwrap();
        arr.mapv_inplace(|x| -x);
        assert_eq!(arr.as_slice().unwrap(), &[-1.0, 2.0, -3.0]);
    }

    #[test]
    fn zip_mut_with_add() {
        let mut a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
        a.zip_mut_with(&b, |x, y| *x += y).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[11.0, 22.0, 33.0]);
    }

    #[test]
    fn zip_mut_with_shape_mismatch() {
        let mut a = Array::<f64, Ix1>::zeros(Ix1::new([3])).unwrap();
        let b = Array::<f64, Ix1>::zeros(Ix1::new([5])).unwrap();
        assert!(a.zip_mut_with(&b, |_, _| {}).is_err());
    }

    #[test]
    fn zip_mut_with_same_size_different_shape() {
        // Same element count, different shape: must still be rejected.
        let mut a = Array::<f64, Ix2>::zeros(Ix2::new([2, 3])).unwrap();
        let b = Array::<f64, Ix2>::zeros(Ix2::new([3, 2])).unwrap();
        assert!(a.zip_mut_with(&b, |_, _| {}).is_err());
    }

    #[test]
    fn zip_mut_with_does_not_broadcast() {
        // A [1, 3] operand is broadcast-compatible with [2, 3], but
        // zip_mut_with requires identical shapes.
        let mut a = Array::<f64, Ix2>::zeros(Ix2::new([2, 3])).unwrap();
        let b = Array::<f64, Ix2>::zeros(Ix2::new([1, 3])).unwrap();
        assert!(a.zip_mut_with(&b, |_, _| {}).is_err());
    }

    #[test]
    fn fold_axis_sum_rows() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        // Sum along axis 1 (sum each row)
        let sums = arr.fold_axis(Axis(1), 0.0, |acc, &x| *acc + x).unwrap();
        assert_eq!(sums.shape(), &[2]);
        let data: Vec<f64> = sums.iter().copied().collect();
        assert_eq!(data, vec![6.0, 15.0]);
    }

    #[test]
    fn fold_axis_sum_cols() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        // Sum along axis 0 (sum each column)
        let sums = arr.fold_axis(Axis(0), 0.0, |acc, &x| *acc + x).unwrap();
        assert_eq!(sums.shape(), &[3]);
        let data: Vec<f64> = sums.iter().copied().collect();
        assert_eq!(data, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn fold_axis_out_of_bounds() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([2, 3])).unwrap();
        assert!(arr.fold_axis(Axis(2), 0.0, |a, _| *a).is_err());
    }

    #[test]
    fn map_to_different_type() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.5, 2.7, 3.1]).unwrap();
        let ints: Array<i32, Ix1> = arr.map_to(|x| x as i32);
        assert_eq!(ints.as_slice().unwrap(), &[1, 2, 3]);
    }

    #[test]
    fn view_mapv() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let v = arr.view();
        let doubled = v.mapv(|x| x * 2.0);
        assert_eq!(doubled.as_slice().unwrap(), &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn view_fold_axis_sum_cols() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let v = arr.view();
        let sums = v.fold_axis(Axis(0), 0.0, |acc, &x| *acc + x).unwrap();
        assert_eq!(sums.shape(), &[3]);
        let data: Vec<f64> = sums.iter().copied().collect();
        assert_eq!(data, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn view_fold_axis_out_of_bounds() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([2, 3])).unwrap();
        let v = arr.view();
        assert!(v.fold_axis(Axis(2), 0.0, |a, _| *a).is_err());
    }
}