Skip to main content

ferrolearn_preprocess/
function_transformer.rs

1//! Function transformer: apply a user-provided function element-wise.
2//!
//! Wraps any `Fn(F) -> F + Send + Sync + 'static` callable and applies it to
//! every element of the input matrix. This is useful for non-standard
//! transformations such as `ln`, `sqrt`, or custom domain-specific functions.
6//!
7//! This transformer is **stateless** — no fitting is required. Call
8//! [`Transform::transform`] directly.
9
10use ferrolearn_core::error::FerroError;
11use ferrolearn_core::traits::Transform;
12use ndarray::Array2;
13use num_traits::Float;
14
15// ---------------------------------------------------------------------------
16// FunctionTransformer
17// ---------------------------------------------------------------------------
18
/// A stateless element-wise function transformer.
///
/// Wraps a boxed `Fn(F) -> F` closure and applies it to every element in
/// the input matrix. The closure is required to be `Send + Sync`, so the
/// transformer itself can be shared across threads.
///
/// Values that the closure maps to NaN or infinity (for example `ln` of a
/// non-positive number) are passed through to the output unchanged; no
/// error is raised for them.
///
/// # Examples
///
/// ```
/// use ferrolearn_preprocess::function_transformer::FunctionTransformer;
/// use ferrolearn_core::traits::Transform;
/// use ndarray::array;
///
/// // Apply natural logarithm element-wise (values must be > 0)
/// let ft = FunctionTransformer::<f64>::new(|v| v.ln());
/// let x = array![[1.0, 2.0], [3.0, 4.0]];
/// let out = ft.transform(&x).unwrap();
///
/// assert_eq!(out.shape(), &[2, 2]);
/// assert_eq!(out[[0, 0]], 0.0); // ln(1) = 0
/// assert!((out[[1, 1]] - 4.0_f64.ln()).abs() < 1e-12);
/// ```
pub struct FunctionTransformer<F> {
    /// The element-wise mapping applied to each matrix entry.
    func: Box<dyn Fn(F) -> F + Send + Sync>,
}
39
40impl<F: Float + Send + Sync + 'static> FunctionTransformer<F> {
41    /// Create a new `FunctionTransformer` with the given function.
42    ///
43    /// The function will be applied element-wise to the input matrix.
44    pub fn new<Func>(func: Func) -> Self
45    where
46        Func: Fn(F) -> F + Send + Sync + 'static,
47    {
48        Self {
49            func: Box::new(func),
50        }
51    }
52}
53
54impl<F: Float + Send + Sync + 'static> std::fmt::Debug for FunctionTransformer<F> {
55    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56        f.debug_struct("FunctionTransformer")
57            .field("func", &"<fn(F) -> F>")
58            .finish()
59    }
60}
61
62// ---------------------------------------------------------------------------
63// Trait implementations
64// ---------------------------------------------------------------------------
65
66impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FunctionTransformer<F> {
67    type Output = Array2<F>;
68    type Error = FerroError;
69
70    /// Apply the stored function to every element of `x`.
71    ///
72    /// # Errors
73    ///
74    /// This implementation never returns an error for well-formed inputs.
75    /// Note: if the user-provided function produces NaN or infinity for
76    /// certain inputs, those values will appear in the output without error.
77    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
78        let out = x.mapv(|v| (self.func)(v));
79        Ok(out)
80    }
81}
82
83// ---------------------------------------------------------------------------
84// Tests
85// ---------------------------------------------------------------------------
86
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_identity_function() {
        let ft = FunctionTransformer::<f64>::new(|v| v);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let out = ft.transform(&x).unwrap();
        for (a, b) in x.iter().zip(out.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-15);
        }
    }

    #[test]
    fn test_sqrt_function() {
        let ft = FunctionTransformer::<f64>::new(|v: f64| v.sqrt());
        let x = array![[1.0, 4.0], [9.0, 16.0]];
        let out = ft.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 0]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 1]], 4.0, epsilon = 1e-10);
    }

    #[test]
    fn test_ln_function() {
        let ft = FunctionTransformer::<f64>::new(|v: f64| v.ln());
        let x = array![[1.0, 2.0]];
        let out = ft.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10); // ln(1) = 0
        assert_abs_diff_eq!(out[[0, 1]], 2.0_f64.ln(), epsilon = 1e-10);
    }

    #[test]
    fn test_negate_function() {
        let ft = FunctionTransformer::<f64>::new(|v| -v);
        let x = array![[1.0, -2.0, 3.0]];
        let out = ft.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], -1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 2]], -3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_constant_function() {
        let ft = FunctionTransformer::<f64>::new(|_| 42.0);
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = ft.transform(&x).unwrap();
        for v in out.iter() {
            assert_abs_diff_eq!(*v, 42.0, epsilon = 1e-15);
        }
    }

    #[test]
    fn test_preserves_shape() {
        let ft = FunctionTransformer::<f64>::new(|v| v * 2.0);
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = ft.transform(&x).unwrap();
        assert_eq!(out.shape(), x.shape());
    }

    #[test]
    fn test_clamp_function() {
        let ft = FunctionTransformer::<f64>::new(|v: f64| v.clamp(0.0, 1.0));
        let x = array![[-1.0, 0.5, 2.0]];
        let out = ft.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 0.5, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 2]], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_f32_function() {
        let ft = FunctionTransformer::<f32>::new(|v: f32| v * 2.0);
        let x: Array2<f32> = array![[1.0f32, 2.0], [3.0, 4.0]];
        let out = ft.transform(&x).unwrap();
        assert!((out[[0, 0]] - 2.0f32).abs() < 1e-6);
        assert!((out[[1, 1]] - 8.0f32).abs() < 1e-6);
    }

    #[test]
    fn test_closure_captures_environment() {
        let scale = 3.0_f64;
        let ft = FunctionTransformer::<f64>::new(move |v| v * scale);
        let x = array![[1.0, 2.0]];
        let out = ft.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 6.0, epsilon = 1e-10);
    }

    #[test]
    fn test_empty_matrix() {
        let ft = FunctionTransformer::<f64>::new(|v| v);
        let x: Array2<f64> = Array2::zeros((0, 3));
        let out = ft.transform(&x).unwrap();
        assert_eq!(out.shape(), &[0, 3]);
    }

    /// Non-finite values produced by the closure must be passed through,
    /// not turned into an error.
    #[test]
    fn test_non_finite_outputs_pass_through() {
        let ft = FunctionTransformer::<f64>::new(|v: f64| v.ln());
        let x = array![[-1.0, 0.0]];
        let out = ft.transform(&x).unwrap();
        assert!(out[[0, 0]].is_nan()); // ln of a negative number
        assert!(out[[0, 1]].is_infinite() && out[[0, 1]].is_sign_negative()); // ln(0)
    }

    /// A transposed (column-major) input must be mapped by logical index.
    #[test]
    fn test_non_standard_layout_input() {
        let x = array![[1.0, 2.0], [3.0, 4.0]].reversed_axes();
        let ft = FunctionTransformer::<f64>::new(|v| v * 10.0);
        let out = ft.transform(&x).unwrap();
        assert_eq!(out, array![[10.0, 30.0], [20.0, 40.0]]);
    }

    #[test]
    fn test_debug_hides_closure() {
        let ft = FunctionTransformer::<f64>::new(|v| v);
        assert_eq!(
            format!("{ft:?}"),
            "FunctionTransformer { func: \"<fn(F) -> F>\" }"
        );
    }
}