Skip to main content

ferrolearn_preprocess/
function_transformer.rs

1//! Function transformer: apply a user-provided function element-wise.
2//!
//! Wraps any `Fn(F) -> F + Send + Sync + 'static` callable and applies it to
//! every element of the input matrix, returning a new matrix of the same shape.
//! This is useful for non-standard transformations such as `ln`, `sqrt`, or
//! custom domain-specific functions.
6//!
7//! This transformer is **stateless** — no fitting is required. Call
8//! [`Transform::transform`] directly.
9//!
10//! # `## REQ status`
11//!
12//! Binary (R-DEFER-2), translating `sklearn/preprocessing/_function_transformer.py`
13//! (`class FunctionTransformer(TransformerMixin, BaseEstimator)`). Design doc:
14//! `.design/preprocess/function_transformer.md`. Expected values from the live sklearn 1.5.2
15//! oracle (R-CHAR-3). HONEST (R-HONEST-3): ferrolearn ships a THIN element-wise wrapper —
16//! `func` is a scalar `Fn(F) -> F` applied via `mapv`, NOT sklearn's whole-array `func(X)`;
17//! matches sklearn only on the element-wise/ufunc subset. Consumer: crate re-export
18//! (`lib.rs`, grandfathered S5).
19//!
20//! | REQ | Status | Evidence |
21//! |---|---|---|
22//! | REQ-1 (element-wise forward transform) | SHIPPED (scoped) | `Transform::transform` = `x.mapv(\|v\| (self.func)(v))`, shape-preserving, infallible; mirrors sklearn `_transform` (`_function_transformer.py:375-379`) for element-wise ufunc `func`. Critic-verified bit-identical to live sklearn: `guard_log1p_/expm1_/sqrt_/log_nan_inf_/empty_matrix_*` (5 green) in `tests/divergence_function_transformer.rs`. Consumer: `pub use function_transformer::FunctionTransformer` (`lib.rs:114`). Caveat: scalar `Fn(F)->F`, not array `Fn(X)->X`. |
23//! | REQ-2 (func=None identity default) | NOT-STARTED | open prereq blocker #1112. `new` requires a closure; no identity default (`_identity`, `:22-24`). |
24//! | REQ-3 (whole-array func, headline) | NOT-STARTED | open prereq blocker #1113. `Box<dyn Fn(F)->F>` cannot read the array / change shape; sklearn `func(X)` is array→array (`:375-379`). |
25//! | REQ-4 (inverse_func / inverse_transform) | NOT-STARTED | open prereq blocker #1114. No inverse path (sklearn `:309-325`). |
26//! | REQ-5 (validate / accept_sparse) | NOT-STARTED | open prereq blocker #1115. No input validation (sklearn `:173-182`). |
27//! | REQ-6 (fit / check_inverse / is_fitted) | NOT-STARTED | open prereq blocker #1116. No fit/check_inverse (sklearn `:213-235`, `:184-210`). |
28//! | REQ-7 (feature_names_out / n_features_in_) | NOT-STARTED | open prereq blocker #1117. None (sklearn `:327-373`). |
29//! | REQ-8 (kw_args / inv_kw_args) | NOT-STARTED | open prereq blocker #1118. No kwarg forwarding (sklearn `:93-101`,`:379`). |
30//! | REQ-9 (ctor surface + _parameter_constraints) | NOT-STARTED | open prereq blocker #1119. Only `func`; 7 params + validation missing (R-DEV-2, sklearn `:141-171`). |
31//! | REQ-10 (PyO3 binding) | NOT-STARTED | open prereq blocker #1120. No `ferrolearn-python` registration. |
32//! | REQ-11 (ferray substrate) | NOT-STARTED | open prereq blocker #1121. `ndarray`/`num_traits`, not `ferray-core`/`ferray-ufunc` (R-SUBSTRATE-1/2). |
33
34use ferrolearn_core::error::FerroError;
35use ferrolearn_core::traits::Transform;
36use ndarray::Array2;
37use num_traits::Float;
38
39// ---------------------------------------------------------------------------
40// FunctionTransformer
41// ---------------------------------------------------------------------------
42
/// Element-wise transformer driven by a user-supplied scalar closure.
///
/// The closure is stored type-erased (boxed) and invoked once per element
/// whenever [`Transform::transform`] is called. The output always has the same
/// shape as the input. There is no fitting step: the transformer is stateless.
///
/// # Examples
///
/// ```
/// use ferrolearn_preprocess::function_transformer::FunctionTransformer;
/// use ferrolearn_core::traits::Transform;
/// use ndarray::array;
///
/// // Natural logarithm of each entry (inputs should be strictly positive).
/// let log_ft = FunctionTransformer::<f64>::new(|v| v.ln());
/// let input = array![[1.0, 2.0], [3.0, 4.0]];
/// let logged = log_ft.transform(&input).unwrap();
/// assert_eq!(logged.shape(), input.shape());
/// ```
pub struct FunctionTransformer<F> {
    /// Scalar map applied to every matrix entry.
    func: Box<dyn Fn(F) -> F + Send + Sync + 'static>,
}
63
64impl<F: Float + Send + Sync + 'static> FunctionTransformer<F> {
65    /// Create a new `FunctionTransformer` with the given function.
66    ///
67    /// The function will be applied element-wise to the input matrix.
68    pub fn new<Func>(func: Func) -> Self
69    where
70        Func: Fn(F) -> F + Send + Sync + 'static,
71    {
72        Self {
73            func: Box::new(func),
74        }
75    }
76}
77
78impl<F: Float + Send + Sync + 'static> std::fmt::Debug for FunctionTransformer<F> {
79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80        f.debug_struct("FunctionTransformer")
81            .field("func", &"<fn(F) -> F>")
82            .finish()
83    }
84}
85
86// ---------------------------------------------------------------------------
87// Trait implementations
88// ---------------------------------------------------------------------------
89
90impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FunctionTransformer<F> {
91    type Output = Array2<F>;
92    type Error = FerroError;
93
94    /// Apply the stored function to every element of `x`.
95    ///
96    /// # Errors
97    ///
98    /// This implementation never returns an error for well-formed inputs.
99    /// Note: if the user-provided function produces NaN or infinity for
100    /// certain inputs, those values will appear in the output without error.
101    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
102        let out = x.mapv(|v| (self.func)(v));
103        Ok(out)
104    }
105}
106
107// ---------------------------------------------------------------------------
108// Tests
109// ---------------------------------------------------------------------------
110
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_identity_function() {
        let ft = FunctionTransformer::<f64>::new(|v| v);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let out = ft.transform(&x).unwrap();
        for (a, b) in x.iter().zip(out.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-15);
        }
    }

    #[test]
    fn test_sqrt_function() {
        let ft = FunctionTransformer::<f64>::new(|v: f64| v.sqrt());
        let x = array![[1.0, 4.0], [9.0, 16.0]];
        let out = ft.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 0]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 1]], 4.0, epsilon = 1e-10);
    }

    #[test]
    fn test_ln_function() {
        let ft = FunctionTransformer::<f64>::new(|v: f64| v.ln());
        let x = array![[1.0, 2.0]];
        let out = ft.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10); // ln(1) = 0
        assert_abs_diff_eq!(out[[0, 1]], 2.0_f64.ln(), epsilon = 1e-10);
    }

    #[test]
    fn test_negate_function() {
        let ft = FunctionTransformer::<f64>::new(|v| -v);
        let x = array![[1.0, -2.0, 3.0]];
        let out = ft.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], -1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 2]], -3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_constant_function() {
        let ft = FunctionTransformer::<f64>::new(|_| 42.0);
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = ft.transform(&x).unwrap();
        for v in &out {
            assert_abs_diff_eq!(*v, 42.0, epsilon = 1e-15);
        }
    }

    #[test]
    fn test_preserves_shape() {
        let ft = FunctionTransformer::<f64>::new(|v| v * 2.0);
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = ft.transform(&x).unwrap();
        assert_eq!(out.shape(), x.shape());
    }

    #[test]
    fn test_clamp_function() {
        let ft = FunctionTransformer::<f64>::new(|v: f64| v.clamp(0.0, 1.0));
        let x = array![[-1.0, 0.5, 2.0]];
        let out = ft.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 0.5, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 2]], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_f32_function() {
        let ft = FunctionTransformer::<f32>::new(|v: f32| v * 2.0);
        let x: Array2<f32> = array![[1.0f32, 2.0], [3.0, 4.0]];
        let out = ft.transform(&x).unwrap();
        assert!((out[[0, 0]] - 2.0f32).abs() < 1e-6);
        assert!((out[[1, 1]] - 8.0f32).abs() < 1e-6);
    }

    #[test]
    fn test_closure_captures_environment() {
        let scale = 3.0_f64;
        let ft = FunctionTransformer::<f64>::new(move |v| v * scale);
        let x = array![[1.0, 2.0]];
        let out = ft.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 6.0, epsilon = 1e-10);
    }

    #[test]
    fn test_empty_matrix() {
        let ft = FunctionTransformer::<f64>::new(|v| v);
        let x: Array2<f64> = Array2::zeros((0, 3));
        let out = ft.transform(&x).unwrap();
        assert_eq!(out.shape(), &[0, 3]);
    }

    #[test]
    fn test_nan_and_inf_pass_through() {
        // ln(-1) = NaN and ln(0) = -inf: documented to propagate, not error.
        let ft = FunctionTransformer::<f64>::new(|v: f64| v.ln());
        let x = array![[-1.0, 0.0]];
        let out = ft.transform(&x).unwrap();
        assert!(out[[0, 0]].is_nan());
        assert_eq!(out[[0, 1]], f64::NEG_INFINITY);
    }

    #[test]
    fn test_column_major_input() {
        // `reversed_axes` yields a non-standard (column-major) layout; the
        // mapping must still be element-for-element in logical order.
        let ft = FunctionTransformer::<f64>::new(|v| v * 2.0);
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]].reversed_axes();
        let out = ft.transform(&x).unwrap();
        let expected = array![[2.0, 8.0], [4.0, 10.0], [6.0, 12.0]];
        assert_eq!(out, expected);
    }

    #[test]
    fn test_debug_hides_closure() {
        let ft = FunctionTransformer::<f64>::new(|v| v);
        assert_eq!(
            format!("{ft:?}"),
            r#"FunctionTransformer { func: "<fn(F) -> F>" }"#
        );
    }

    #[test]
    fn test_is_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<FunctionTransformer<f64>>();
        assert_send_sync::<FunctionTransformer<f32>>();
    }
}