Skip to main content

ferray_window/
functional.rs

1// ferray-window: Functional programming utilities
2//
3// Implements NumPy-equivalent functional utilities: vectorize, piecewise,
4// apply_along_axis, and apply_over_axes.
5
6use ferray_core::Array;
7use ferray_core::dimension::{Axis, Dimension, Ix1, IxDyn};
8use ferray_core::dtype::Element;
9use ferray_core::error::{FerrayError, FerrayResult};
10use num_traits::Zero;
11
12/// Wrap a scalar function to operate elementwise on arrays of any dimension.
13///
14/// Returns a closure that accepts `&Array<T, D>` and returns
15/// `FerrayResult<Array<U, D>>`, applying `f` to every element.
16///
17/// This is NumPy's `np.vectorize` — in Rust it is essentially `.mapv()`
18/// wrapped as a reusable callable.
19///
20/// Works with any dimension type: `Ix1`, `Ix2`, `IxDyn`, etc.
21///
22/// # Example
23/// ```ignore
24/// let square = vectorize(|x: f64| x * x);
25/// let result = square(&input_array)?;
26/// ```
27pub fn vectorize<T, U, F, D>(f: F) -> impl Fn(&Array<T, D>) -> FerrayResult<Array<U, D>>
28where
29    T: Element + Copy,
30    U: Element,
31    D: Dimension,
32    F: Fn(T) -> U,
33{
34    move |input: &Array<T, D>| {
35        let data: Vec<U> = input.iter().map(|&x| f(x)).collect();
36        Array::from_vec(input.dim().clone(), data)
37    }
38}
39
40/// Evaluate a piecewise-defined function.
41///
42/// For each element position, the first condition in `condlist` that is `true`
43/// determines which function from `funclist` is applied. Elements where no
44/// condition is true receive the `default` value.
45///
46/// This is equivalent to `numpy.piecewise(x, condlist, funclist)`.
47///
48/// # Arguments
49/// * `x` - The input array.
50/// * `condlist` - A slice of boolean arrays, each the same shape as `x`.
51/// * `funclist` - A slice of function references, one per condition. The
52///   slice element type is `&dyn Fn(T) -> T` (not `Box<dyn Fn>`) so
53///   callers can pass stack-allocated closures without heap allocation
54///   (#297).
55/// * `default` - The default value for elements where no condition is true.
56///
57/// # Errors
58/// - Returns `FerrayError::InvalidValue` if `condlist` and `funclist` have different lengths.
59/// - Returns `FerrayError::ShapeMismatch` if any condition array has a different shape than `x`.
60pub fn piecewise<T, D>(
61    x: &Array<T, D>,
62    condlist: &[Array<bool, D>],
63    funclist: &[&dyn Fn(T) -> T],
64    default: T,
65) -> FerrayResult<Array<T, D>>
66where
67    T: Element + Copy,
68    D: Dimension,
69{
70    if condlist.len() != funclist.len() {
71        return Err(FerrayError::invalid_value(format!(
72            "piecewise: condlist length ({}) must equal funclist length ({})",
73            condlist.len(),
74            funclist.len()
75        )));
76    }
77
78    for (i, cond) in condlist.iter().enumerate() {
79        if cond.shape() != x.shape() {
80            return Err(FerrayError::shape_mismatch(format!(
81                "piecewise: condlist[{i}] shape {:?} does not match x shape {:?}",
82                cond.shape(),
83                x.shape()
84            )));
85        }
86    }
87
88    let size = x.size();
89    let mut result_data = vec![default; size];
90    let x_data: Vec<T> = x.iter().copied().collect();
91
92    // Collect all condition data upfront
93    let cond_data: Vec<Vec<bool>> = condlist
94        .iter()
95        .map(|c| c.iter().copied().collect())
96        .collect();
97
98    // For each element, find the first matching condition
99    for i in 0..size {
100        for (j, cond) in cond_data.iter().enumerate() {
101            if cond[i] {
102                result_data[i] = funclist[j](x_data[i]);
103                break;
104            }
105        }
106    }
107
108    Array::from_vec(x.dim().clone(), result_data)
109}
110
111/// Apply a function along one axis of an array.
112///
113/// The function receives 1-D slices (lanes) along the specified axis and
114/// returns a scalar value. The result has one fewer dimension than the input
115/// (the specified axis is removed).
116///
117/// This is equivalent to `numpy.apply_along_axis(func1d, axis, arr)` when
118/// `func1d` returns a scalar.
119///
120/// # Arguments
121/// * `func` - A function that takes a 1-D array view and returns a scalar result.
122/// * `axis` - The axis along which to apply the function.
123/// * `a` - The input array.
124///
125/// # Errors
126/// - Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim`.
127/// - Propagates any error from the function or array construction.
128pub fn apply_along_axis<T, D>(
129    func: impl Fn(&Array<T, Ix1>) -> FerrayResult<T>,
130    axis: Axis,
131    a: &Array<T, D>,
132) -> FerrayResult<Array<T, IxDyn>>
133where
134    T: Element + Copy,
135    D: Dimension,
136{
137    let ndim = a.ndim();
138    let ax = axis.index();
139    if ax >= ndim {
140        return Err(FerrayError::axis_out_of_bounds(ax, ndim));
141    }
142
143    // Collect lanes along the axis, apply the function, collect results
144    let lanes_iter = a.lanes(axis)?;
145    let mut results = Vec::new();
146
147    for lane in lanes_iter {
148        // Convert the Ix1 ArrayView to an owned Array<T, Ix1>
149        let owned_lane = lane.to_owned();
150        let val = func(&owned_lane)?;
151        results.push(val);
152    }
153
154    // Compute the result shape: input shape with the axis dimension removed
155    let mut result_shape: Vec<usize> = a.shape().to_vec();
156    result_shape.remove(ax);
157    if result_shape.is_empty() {
158        // 0-D result when input was 1-D
159        result_shape.push(results.len());
160    }
161
162    Array::from_vec(IxDyn::new(&result_shape), results)
163}
164
165/// Apply a reducing function repeatedly over multiple axes.
166///
167/// The function is applied to the array over each specified axis in sequence.
168/// After each application, the axis dimension is kept with size 1 (keepdims
169/// semantics) to maintain dimensionality alignment for subsequent reductions.
170///
171/// This is equivalent to `numpy.apply_over_axes(func, a, axes)`. Generic
172/// over any `T: Element + Copy` so f32/i32/etc. arrays work too — the
173/// previous signature hard-coded `Array<f64, IxDyn>` for no real
174/// reason (#182).
175///
176/// # Errors
177/// - Returns `FerrayError::AxisOutOfBounds` if any axis is out of bounds.
178/// - Propagates any error from the function.
179pub fn apply_over_axes<T, F>(
180    func: F,
181    a: &Array<T, IxDyn>,
182    axes: &[usize],
183) -> FerrayResult<Array<T, IxDyn>>
184where
185    T: Element + Copy,
186    F: Fn(&Array<T, IxDyn>, Axis) -> FerrayResult<Array<T, IxDyn>>,
187{
188    let ndim = a.ndim();
189    for &ax in axes {
190        if ax >= ndim {
191            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
192        }
193    }
194
195    let mut current = a.clone();
196    for &ax in axes {
197        current = func(&current, Axis(ax))?;
198        // Ensure the result has the same number of dimensions as current
199        // (keepdims semantics): if the function collapsed an axis, we don't
200        // need to re-expand since we expect the function to keep dims.
201    }
202
203    Ok(current)
204}
205
206/// Helper: sum along an axis with keepdims semantics (keeps the axis as size 1).
207///
208/// This is useful as a `func` argument for [`apply_over_axes`]. Generic
209/// over any `T: Element + Copy + Zero + Add` so it works for any
210/// addable numeric type (#182).
211///
212/// # Errors
213/// Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim`.
214pub fn sum_axis_keepdims<T>(a: &Array<T, IxDyn>, axis: Axis) -> FerrayResult<Array<T, IxDyn>>
215where
216    T: Element + Copy + Zero + core::ops::Add<Output = T>,
217{
218    let ndim = a.ndim();
219    let ax = axis.index();
220    if ax >= ndim {
221        return Err(FerrayError::axis_out_of_bounds(ax, ndim));
222    }
223
224    let reduced = a.fold_axis(axis, <T as Zero>::zero(), |acc, &x| *acc + x)?;
225
226    // Reinsert the axis as size 1 (keepdims)
227    let mut new_shape: Vec<usize> = reduced.shape().to_vec();
228    new_shape.insert(ax, 1);
229    let data: Vec<T> = reduced.iter().copied().collect();
230    Array::from_vec(IxDyn::new(&new_shape), data)
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::Ix2;

    fn arr1(data: Vec<f64>) -> Array<f64, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    fn arr1_bool(data: Vec<bool>) -> Array<bool, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    // -----------------------------------------------------------------------
    // AC-4: vectorize(|x: f64| x.powi(2))(&array) produces element-squared
    // -----------------------------------------------------------------------
    #[test]
    fn vectorize_square_ac4() {
        let square = vectorize(|x: f64| x.powi(2));
        let input = arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let result = square(&input).unwrap();
        assert_eq!(result.as_slice().unwrap(), &[1.0, 4.0, 9.0, 16.0, 25.0][..]);
    }

    #[test]
    fn vectorize_matches_mapv() {
        let f = |x: f64| x.sin();
        let vf = vectorize(f);
        let input = arr1(vec![0.0, 1.0, 2.0, 3.0]);
        let via_vectorize = vf(&input).unwrap();
        let via_mapv = input.mapv(f);
        assert_eq!(
            via_vectorize.as_slice().unwrap(),
            via_mapv.as_slice().unwrap()
        );
    }

    #[test]
    fn vectorize_2d_generic_dimension() {
        // vectorize is now generic over dimension — this used to live
        // in a separate `vectorize_nd` that has been removed (#293).
        let square = vectorize(|x: f64| x * x);
        let input =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        let result = square(&input).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(
            result.as_slice().unwrap(),
            &[1.0, 4.0, 9.0, 16.0, 25.0, 36.0][..]
        );
    }

    #[test]
    fn vectorize_empty() {
        let f = vectorize(|x: f64| x + 1.0);
        let input = arr1(vec![]);
        let result = f(&input).unwrap();
        assert_eq!(result.shape(), &[0]);
    }

    #[test]
    fn vectorize_changes_element_type() {
        // The output element type `U` may differ from the input type `T`.
        let is_positive = vectorize(|x: f64| x > 0.0);
        let input = arr1(vec![-1.0, 0.0, 2.0]);
        let result = is_positive(&input).unwrap();
        assert_eq!(result.shape(), &[3]);
        assert_eq!(result.as_slice().unwrap(), &[false, false, true][..]);
    }

    // -----------------------------------------------------------------------
    // Piecewise tests
    // -----------------------------------------------------------------------
    #[test]
    fn piecewise_basic() {
        let x = arr1(vec![-2.0, -1.0, 0.0, 1.0, 2.0]);
        let cond_neg = arr1_bool(vec![true, true, false, false, false]);
        let cond_pos = arr1_bool(vec![false, false, false, true, true]);

        // piecewise now takes `&[&dyn Fn]` rather than `&[Box<dyn Fn>]`
        // so callers don't need a heap allocation per branch (#297).
        let neg: &dyn Fn(f64) -> f64 = &|v| -v;
        let pos: &dyn Fn(f64) -> f64 = &|v| v * 2.0;
        let result = piecewise(&x, &[cond_neg, cond_pos], &[neg, pos], 0.0).unwrap();

        let s = result.as_slice().unwrap();
        assert_eq!(s, &[2.0, 1.0, 0.0, 2.0, 4.0]);
    }

    #[test]
    fn piecewise_first_match_wins() {
        let x = arr1(vec![1.0, 2.0, 3.0]);
        // Both conditions true for all elements
        let cond1 = arr1_bool(vec![true, true, true]);
        let cond2 = arr1_bool(vec![true, true, true]);

        let f1: &dyn Fn(f64) -> f64 = &|v| v * 10.0;
        let f2: &dyn Fn(f64) -> f64 = &|v| v * 100.0;
        let result = piecewise(&x, &[cond1, cond2], &[f1, f2], 0.0).unwrap();

        // First condition wins
        let s = result.as_slice().unwrap();
        assert_eq!(s, &[10.0, 20.0, 30.0]);
    }

    #[test]
    fn piecewise_no_match_uses_default() {
        let x = arr1(vec![1.0, 2.0, 3.0]);
        let cond = arr1_bool(vec![false, false, false]);

        let f: &dyn Fn(f64) -> f64 = &|v| v * 10.0;
        let result = piecewise(&x, &[cond], &[f], -999.0).unwrap();

        let s = result.as_slice().unwrap();
        assert_eq!(s, &[-999.0, -999.0, -999.0]);
    }

    #[test]
    fn piecewise_empty_condlist_uses_default() {
        // With no conditions at all, every element falls through to `default`.
        let x = arr1(vec![1.0, 2.0]);
        let no_conds: [Array<bool, Ix1>; 0] = [];
        let no_funcs: [&dyn Fn(f64) -> f64; 0] = [];
        let result = piecewise(&x, &no_conds, &no_funcs, 7.0).unwrap();

        assert_eq!(result.shape(), &[2]);
        assert_eq!(result.as_slice().unwrap(), &[7.0, 7.0]);
    }

    #[test]
    fn piecewise_length_mismatch() {
        let x = arr1(vec![1.0, 2.0]);
        let cond = arr1_bool(vec![true, false]);
        let f1: &dyn Fn(f64) -> f64 = &|v| v;
        let f2: &dyn Fn(f64) -> f64 = &|v| v;
        assert!(piecewise(&x, &[cond], &[f1, f2], 0.0).is_err());
    }

    #[test]
    fn piecewise_shape_mismatch() {
        let x = arr1(vec![1.0, 2.0]);
        let cond = arr1_bool(vec![true, false, true]); // wrong shape
        let f: &dyn Fn(f64) -> f64 = &|v| v;
        assert!(piecewise(&x, &[cond], &[f], 0.0).is_err());
    }

    // -----------------------------------------------------------------------
    // AC-5: apply_along_axis sum along axis 0 produces column sums
    // -----------------------------------------------------------------------
    #[test]
    fn apply_along_axis_col_sums_ac5() {
        let m = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();

        let result = apply_along_axis(
            |col| {
                let sum: f64 = col.iter().sum();
                Ok(sum)
            },
            Axis(0),
            &m,
        )
        .unwrap();

        // Lanes along axis 0 yield columns: [1,4], [2,5], [3,6]
        // Sums: [5, 7, 9]
        assert_eq!(result.shape(), &[3]);
        let data: Vec<f64> = result.iter().copied().collect();
        assert_eq!(data, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn apply_along_axis_row_sums() {
        let m = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();

        let result = apply_along_axis(
            |row| {
                let sum: f64 = row.iter().sum();
                Ok(sum)
            },
            Axis(1),
            &m,
        )
        .unwrap();

        // Lanes along axis 1 yield rows: [1,2,3], [4,5,6]
        // Sums: [6, 15]
        assert_eq!(result.shape(), &[2]);
        let data: Vec<f64> = result.iter().copied().collect();
        assert_eq!(data, vec![6.0, 15.0]);
    }

    #[test]
    fn apply_along_axis_1d() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let result = apply_along_axis(
            |lane| {
                let sum: f64 = lane.iter().sum();
                Ok(sum)
            },
            Axis(0),
            &a,
        )
        .unwrap();
        // A 1-D input yields a shape-[1] array holding the single result
        // (not a 0-D array as in NumPy).
        assert_eq!(result.shape(), &[1]);
        let data: Vec<f64> = result.iter().copied().collect();
        assert_eq!(data, vec![6.0]);
    }

    #[test]
    fn apply_along_axis_out_of_bounds() {
        let m = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        assert!(apply_along_axis(|_| Ok(0.0), Axis(5), &m).is_err());
    }

    #[test]
    fn apply_along_axis_propagates_func_error() {
        let a = arr1(vec![1.0, 2.0]);
        let result = apply_along_axis(
            |_| Err(FerrayError::invalid_value("boom".to_string())),
            Axis(0),
            &a,
        );
        assert!(result.is_err());
    }

    // -----------------------------------------------------------------------
    // apply_over_axes tests
    // -----------------------------------------------------------------------
    #[test]
    fn apply_over_axes_sum() {
        // 2x3 array, sum over axis 0 then axis 1
        let a =
            Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();

        let result = apply_over_axes(sum_axis_keepdims, &a, &[0, 1]).unwrap();

        // After sum axis 0: shape [1, 3], values [5, 7, 9]
        // After sum axis 1: shape [1, 1], values [21]
        assert_eq!(result.shape(), &[1, 1]);
        let data: Vec<f64> = result.iter().copied().collect();
        assert_eq!(data, vec![21.0]);
    }

    #[test]
    fn apply_over_axes_single_axis() {
        let a =
            Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();

        let result = apply_over_axes(sum_axis_keepdims, &a, &[0]).unwrap();
        assert_eq!(result.shape(), &[1, 3]);
        let data: Vec<f64> = result.iter().copied().collect();
        assert_eq!(data, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn apply_over_axes_empty_axes_returns_input() {
        let a =
            Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();

        let result = apply_over_axes(sum_axis_keepdims, &a, &[]).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        let data: Vec<f64> = result.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn apply_over_axes_out_of_bounds() {
        let a =
            Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        assert!(apply_over_axes(sum_axis_keepdims, &a, &[5]).is_err());
    }

    // -----------------------------------------------------------------------
    // sum_axis_keepdims tests
    // -----------------------------------------------------------------------
    #[test]
    fn sum_axis_keepdims_basic() {
        let a =
            Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();

        let result = sum_axis_keepdims(&a, Axis(0)).unwrap();
        assert_eq!(result.shape(), &[1, 3]);
        let data: Vec<f64> = result.iter().copied().collect();
        assert_eq!(data, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn sum_axis_keepdims_axis1() {
        let a =
            Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();

        let result = sum_axis_keepdims(&a, Axis(1)).unwrap();
        assert_eq!(result.shape(), &[2, 1]);
        let data: Vec<f64> = result.iter().copied().collect();
        assert_eq!(data, vec![6.0, 15.0]);
    }

    #[test]
    fn sum_axis_keepdims_out_of_bounds() {
        let a =
            Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        assert!(sum_axis_keepdims(&a, Axis(2)).is_err());
    }
}