Skip to main content

ferray_window/
functional.rs

1// ferray-window: Functional programming utilities
2//
3// Implements NumPy-equivalent functional utilities: vectorize, piecewise,
4// apply_along_axis, and apply_over_axes.
5
6use ferray_core::Array;
7use ferray_core::dimension::{Axis, Dimension, Ix1, IxDyn};
8use ferray_core::dtype::Element;
9use ferray_core::error::{FerrayError, FerrayResult};
10
11/// Wrap a scalar function to operate elementwise on arrays.
12///
13/// Returns a closure that accepts `&Array<T, D>` and returns
14/// `FerrayResult<Array<U, D>>`, applying `f` to every element.
15///
16/// This is NumPy's `np.vectorize` — in Rust it is essentially `.mapv()`
17/// wrapped as a reusable callable.
18///
19/// # Example
20/// ```ignore
21/// let square = vectorize(|x: f64| x * x);
22/// let result = square(&input_array)?;
23/// ```
24pub fn vectorize<T, U, F>(f: F) -> impl Fn(&Array<T, Ix1>) -> FerrayResult<Array<U, Ix1>>
25where
26    T: Element + Copy,
27    U: Element,
28    F: Fn(T) -> U,
29{
30    move |input: &Array<T, Ix1>| {
31        let data: Vec<U> = input.iter().map(|&x| f(x)).collect();
32        Array::from_vec(Ix1::new([data.len()]), data)
33    }
34}
35
36/// Wrap a scalar function to operate elementwise on arrays of any dimension.
37///
38/// Like [`vectorize`], but works with any dimension type `D`.
39///
40/// # Example
41/// ```ignore
42/// let square = vectorize_nd(|x: f64| x * x);
43/// let result = square(&input_2d_array)?;
44/// ```
45pub fn vectorize_nd<T, U, F, D>(f: F) -> impl Fn(&Array<T, D>) -> FerrayResult<Array<U, D>>
46where
47    T: Element + Copy,
48    U: Element,
49    D: Dimension,
50    F: Fn(T) -> U,
51{
52    move |input: &Array<T, D>| {
53        let data: Vec<U> = input.iter().map(|&x| f(x)).collect();
54        Array::from_vec(input.dim().clone(), data)
55    }
56}
57
58/// Evaluate a piecewise-defined function.
59///
60/// For each element position, the first condition in `condlist` that is `true`
61/// determines which function from `funclist` is applied. Elements where no
62/// condition is true receive the `default` value.
63///
64/// This is equivalent to `numpy.piecewise(x, condlist, funclist)`.
65///
66/// # Arguments
67/// * `x` - The input array.
68/// * `condlist` - A slice of boolean arrays, each the same shape as `x`.
69/// * `funclist` - A slice of functions, one per condition. Each function maps `T -> T`.
70/// * `default` - The default value for elements where no condition is true.
71///
72/// # Errors
73/// - Returns `FerrayError::InvalidValue` if `condlist` and `funclist` have different lengths.
74/// - Returns `FerrayError::ShapeMismatch` if any condition array has a different shape than `x`.
75pub fn piecewise<T, D>(
76    x: &Array<T, D>,
77    condlist: &[Array<bool, D>],
78    funclist: &[Box<dyn Fn(T) -> T>],
79    default: T,
80) -> FerrayResult<Array<T, D>>
81where
82    T: Element + Copy,
83    D: Dimension,
84{
85    if condlist.len() != funclist.len() {
86        return Err(FerrayError::invalid_value(format!(
87            "piecewise: condlist length ({}) must equal funclist length ({})",
88            condlist.len(),
89            funclist.len()
90        )));
91    }
92
93    for (i, cond) in condlist.iter().enumerate() {
94        if cond.shape() != x.shape() {
95            return Err(FerrayError::shape_mismatch(format!(
96                "piecewise: condlist[{i}] shape {:?} does not match x shape {:?}",
97                cond.shape(),
98                x.shape()
99            )));
100        }
101    }
102
103    let size = x.size();
104    let mut result_data = vec![default; size];
105    let x_data: Vec<T> = x.iter().copied().collect();
106
107    // Collect all condition data upfront
108    let cond_data: Vec<Vec<bool>> = condlist
109        .iter()
110        .map(|c| c.iter().copied().collect())
111        .collect();
112
113    // For each element, find the first matching condition
114    for i in 0..size {
115        for (j, cond) in cond_data.iter().enumerate() {
116            if cond[i] {
117                result_data[i] = funclist[j](x_data[i]);
118                break;
119            }
120        }
121    }
122
123    Array::from_vec(x.dim().clone(), result_data)
124}
125
126/// Apply a function along one axis of an array.
127///
128/// The function receives 1-D slices (lanes) along the specified axis and
129/// returns a scalar value. The result has one fewer dimension than the input
130/// (the specified axis is removed).
131///
132/// This is equivalent to `numpy.apply_along_axis(func1d, axis, arr)` when
133/// `func1d` returns a scalar.
134///
135/// # Arguments
136/// * `func` - A function that takes a 1-D array view and returns a scalar result.
137/// * `axis` - The axis along which to apply the function.
138/// * `a` - The input array.
139///
140/// # Errors
141/// - Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim`.
142/// - Propagates any error from the function or array construction.
143pub fn apply_along_axis<T, D>(
144    func: impl Fn(&Array<T, Ix1>) -> FerrayResult<T>,
145    axis: Axis,
146    a: &Array<T, D>,
147) -> FerrayResult<Array<T, IxDyn>>
148where
149    T: Element + Copy,
150    D: Dimension,
151{
152    let ndim = a.ndim();
153    let ax = axis.index();
154    if ax >= ndim {
155        return Err(FerrayError::axis_out_of_bounds(ax, ndim));
156    }
157
158    // Collect lanes along the axis, apply the function, collect results
159    let lanes_iter = a.lanes(axis)?;
160    let mut results = Vec::new();
161
162    for lane in lanes_iter {
163        // Convert the Ix1 ArrayView to an owned Array<T, Ix1>
164        let owned_lane = lane.to_owned();
165        let val = func(&owned_lane)?;
166        results.push(val);
167    }
168
169    // Compute the result shape: input shape with the axis dimension removed
170    let mut result_shape: Vec<usize> = a.shape().to_vec();
171    result_shape.remove(ax);
172    if result_shape.is_empty() {
173        // 0-D result when input was 1-D
174        result_shape.push(results.len());
175    }
176
177    Array::from_vec(IxDyn::new(&result_shape), results)
178}
179
180/// Apply a reducing function repeatedly over multiple axes.
181///
182/// The function is applied to the array over each specified axis in sequence.
183/// After each application, the axis dimension is kept with size 1 (keepdims
184/// semantics) to maintain dimensionality alignment for subsequent reductions.
185///
186/// This is equivalent to `numpy.apply_over_axes(func, a, axes)`.
187///
188/// # Arguments
189/// * `func` - A function that reduces an array along a single axis, returning
190///   a dynamic-rank array with the axis dimension reduced (but kept as size 1).
191/// * `a` - The input array.
192/// * `axes` - The axes over which to apply the function.
193///
194/// # Errors
195/// - Returns `FerrayError::AxisOutOfBounds` if any axis is out of bounds.
196/// - Propagates any error from the function.
197pub fn apply_over_axes(
198    func: impl Fn(&Array<f64, IxDyn>, Axis) -> FerrayResult<Array<f64, IxDyn>>,
199    a: &Array<f64, IxDyn>,
200    axes: &[usize],
201) -> FerrayResult<Array<f64, IxDyn>> {
202    let ndim = a.ndim();
203    for &ax in axes {
204        if ax >= ndim {
205            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
206        }
207    }
208
209    let mut current = a.clone();
210    for &ax in axes {
211        current = func(&current, Axis(ax))?;
212        // Ensure the result has the same number of dimensions as current
213        // (keepdims semantics): if the function collapsed an axis, we don't
214        // need to re-expand since we expect the function to keep dims.
215    }
216
217    Ok(current)
218}
219
220/// Helper: sum along an axis with keepdims semantics (keeps the axis as size 1).
221///
222/// This is useful as a `func` argument for [`apply_over_axes`].
223///
224/// # Errors
225/// Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim`.
226pub fn sum_axis_keepdims(a: &Array<f64, IxDyn>, axis: Axis) -> FerrayResult<Array<f64, IxDyn>> {
227    let ndim = a.ndim();
228    let ax = axis.index();
229    if ax >= ndim {
230        return Err(FerrayError::axis_out_of_bounds(ax, ndim));
231    }
232
233    let reduced = a.fold_axis(axis, 0.0, |acc, &x| *acc + x)?;
234
235    // Reinsert the axis as size 1 (keepdims)
236    let mut new_shape: Vec<usize> = reduced.shape().to_vec();
237    new_shape.insert(ax, 1);
238    let data: Vec<f64> = reduced.iter().copied().collect();
239    Array::from_vec(IxDyn::new(&new_shape), data)
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::Ix2;

    /// Build a 1-D `f64` array from `values`.
    fn arr1(values: Vec<f64>) -> Array<f64, Ix1> {
        Array::from_vec(Ix1::new([values.len()]), values).unwrap()
    }

    /// Build a 1-D boolean mask from `flags`.
    fn arr1_bool(flags: Vec<bool>) -> Array<bool, Ix1> {
        Array::from_vec(Ix1::new([flags.len()]), flags).unwrap()
    }

    /// The matrix [[1, 2, 3], [4, 5, 6]] with a static `Ix2` dimension.
    fn mat_2x3() -> Array<f64, Ix2> {
        Array::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap()
    }

    /// The same 2x3 matrix with a dynamic `IxDyn` dimension.
    fn dyn_2x3() -> Array<f64, IxDyn> {
        Array::from_vec(IxDyn::new(&[2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap()
    }

    /// Flatten a dynamic-rank result into a `Vec` in iteration order.
    fn values(arr: &Array<f64, IxDyn>) -> Vec<f64> {
        arr.iter().copied().collect()
    }

    /// Reduction used by the `apply_along_axis` tests: the sum of a lane.
    fn lane_sum(lane: &Array<f64, Ix1>) -> FerrayResult<f64> {
        Ok(lane.iter().sum())
    }

    // -----------------------------------------------------------------------
    // vectorize / vectorize_nd
    // -----------------------------------------------------------------------

    /// AC-4: `vectorize(|x| x.powi(2))` squares every element.
    #[test]
    fn vectorize_square_ac4() {
        let square = vectorize(|x: f64| x.powi(2));
        let squared = square(&arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0])).unwrap();
        assert_eq!(squared.as_slice().unwrap(), &[1.0, 4.0, 9.0, 16.0, 25.0]);
    }

    #[test]
    fn vectorize_matches_mapv() {
        let sine = |x: f64| x.sin();
        let input = arr1(vec![0.0, 1.0, 2.0, 3.0]);
        let expected = input.mapv(sine);
        let actual = vectorize(sine)(&input).unwrap();
        assert_eq!(actual.as_slice().unwrap(), expected.as_slice().unwrap());
    }

    #[test]
    fn vectorize_nd_2d() {
        let square = vectorize_nd(|x: f64| x * x);
        let out = square(&mat_2x3()).unwrap();
        assert_eq!(out.shape(), &[2, 3]);
        assert_eq!(out.as_slice().unwrap(), &[1.0, 4.0, 9.0, 16.0, 25.0, 36.0]);
    }

    #[test]
    fn vectorize_empty() {
        let plus_one = vectorize(|x: f64| x + 1.0);
        let out = plus_one(&arr1(Vec::new())).unwrap();
        assert_eq!(out.shape(), &[0]);
    }

    // -----------------------------------------------------------------------
    // piecewise
    // -----------------------------------------------------------------------

    #[test]
    fn piecewise_basic() {
        let x = arr1(vec![-2.0, -1.0, 0.0, 1.0, 2.0]);
        let negative = arr1_bool(vec![true, true, false, false, false]);
        let positive = arr1_bool(vec![false, false, false, true, true]);
        let pieces: [Box<dyn Fn(f64) -> f64>; 2] = [
            Box::new(|v: f64| -v),      // negate the negatives
            Box::new(|v: f64| v * 2.0), // double the positives
        ];

        // Zero matches neither mask and falls back to the default.
        let out = piecewise(&x, &[negative, positive], &pieces, 0.0).unwrap();
        assert_eq!(out.as_slice().unwrap(), &[2.0, 1.0, 0.0, 2.0, 4.0]);
    }

    #[test]
    fn piecewise_first_match_wins() {
        let x = arr1(vec![1.0, 2.0, 3.0]);
        // Both masks select every element, so only the first function may run.
        let masks = [arr1_bool(vec![true; 3]), arr1_bool(vec![true; 3])];

        let out = piecewise(
            &x,
            &masks,
            &[Box::new(|v: f64| v * 10.0), Box::new(|v: f64| v * 100.0)],
            0.0,
        )
        .unwrap();

        assert_eq!(out.as_slice().unwrap(), &[10.0, 20.0, 30.0]);
    }

    #[test]
    fn piecewise_no_match_uses_default() {
        let x = arr1(vec![1.0, 2.0, 3.0]);
        let never = arr1_bool(vec![false; 3]);

        let out = piecewise(&x, &[never], &[Box::new(|v: f64| v * 10.0)], -999.0).unwrap();

        assert_eq!(out.as_slice().unwrap(), &[-999.0; 3]);
    }

    #[test]
    fn piecewise_length_mismatch() {
        let x = arr1(vec![1.0, 2.0]);
        let only_mask = arr1_bool(vec![true, false]);

        let outcome = piecewise(
            &x,
            &[only_mask],
            &[Box::new(|v: f64| v), Box::new(|v: f64| v)],
            0.0,
        );

        assert!(outcome.is_err(), "one mask with two functions must be rejected");
    }

    #[test]
    fn piecewise_shape_mismatch() {
        let x = arr1(vec![1.0, 2.0]);
        let too_long = arr1_bool(vec![true, false, true]);

        let outcome = piecewise(&x, &[too_long], &[Box::new(|v: f64| v)], 0.0);

        assert!(outcome.is_err(), "a mask whose shape differs from x must be rejected");
    }

    // -----------------------------------------------------------------------
    // apply_along_axis
    // -----------------------------------------------------------------------

    /// AC-5: summing along axis 0 produces column sums.
    #[test]
    fn apply_along_axis_col_sums_ac5() {
        // Lanes along axis 0 are the columns [1,4], [2,5], [3,6].
        let sums = apply_along_axis(lane_sum, Axis(0), &mat_2x3()).unwrap();
        assert_eq!(sums.shape(), &[3]);
        assert_eq!(values(&sums), vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn apply_along_axis_row_sums() {
        // Lanes along axis 1 are the rows [1,2,3], [4,5,6].
        let sums = apply_along_axis(lane_sum, Axis(1), &mat_2x3()).unwrap();
        assert_eq!(sums.shape(), &[2]);
        assert_eq!(values(&sums), vec![6.0, 15.0]);
    }

    #[test]
    fn apply_along_axis_1d() {
        let total = apply_along_axis(lane_sum, Axis(0), &arr1(vec![1.0, 2.0, 3.0])).unwrap();
        // A 1-D input collapses to a single reduced value.
        assert_eq!(values(&total), vec![6.0]);
    }

    #[test]
    fn apply_along_axis_out_of_bounds() {
        assert!(apply_along_axis(|_| Ok(0.0), Axis(5), &mat_2x3()).is_err());
    }

    // -----------------------------------------------------------------------
    // apply_over_axes
    // -----------------------------------------------------------------------

    #[test]
    fn apply_over_axes_sum() {
        // Axis 0 first: shape [1, 3] holding [5, 7, 9]; then axis 1: [1, 1] holding [21].
        let total = apply_over_axes(sum_axis_keepdims, &dyn_2x3(), &[0, 1]).unwrap();
        assert_eq!(total.shape(), &[1, 1]);
        assert_eq!(values(&total), vec![21.0]);
    }

    #[test]
    fn apply_over_axes_single_axis() {
        let col_sums = apply_over_axes(sum_axis_keepdims, &dyn_2x3(), &[0]).unwrap();
        assert_eq!(col_sums.shape(), &[1, 3]);
        assert_eq!(values(&col_sums), vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn apply_over_axes_out_of_bounds() {
        assert!(apply_over_axes(sum_axis_keepdims, &dyn_2x3(), &[5]).is_err());
    }

    // -----------------------------------------------------------------------
    // sum_axis_keepdims
    // -----------------------------------------------------------------------

    #[test]
    fn sum_axis_keepdims_basic() {
        let col_sums = sum_axis_keepdims(&dyn_2x3(), Axis(0)).unwrap();
        assert_eq!(col_sums.shape(), &[1, 3]);
        assert_eq!(values(&col_sums), vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn sum_axis_keepdims_axis1() {
        let row_sums = sum_axis_keepdims(&dyn_2x3(), Axis(1)).unwrap();
        assert_eq!(row_sums.shape(), &[2, 1]);
        assert_eq!(values(&row_sums), vec![6.0, 15.0]);
    }
}