ferrolearn_preprocess/function_transformer.rs
1//! Function transformer: apply a user-provided function element-wise.
2//!
3//! Wraps any `Fn(F) -> F` callable and applies it to every element in the
4//! input matrix. This is useful for applying non-standard transformations
5//! such as `ln`, `sqrt`, or custom domain-specific functions.
6//!
7//! This transformer is **stateless** — no fitting is required. Call
8//! [`Transform::transform`] directly.
9//!
10//! # `## REQ status`
11//!
12//! Binary (R-DEFER-2), translating `sklearn/preprocessing/_function_transformer.py`
13//! (`class FunctionTransformer(TransformerMixin, BaseEstimator)`). Design doc:
14//! `.design/preprocess/function_transformer.md`. Expected values from the live sklearn 1.5.2
15//! oracle (R-CHAR-3). HONEST (R-HONEST-3): ferrolearn ships a THIN element-wise wrapper —
16//! `func` is a scalar `Fn(F) -> F` applied via `mapv`, NOT sklearn's whole-array `func(X)`;
17//! matches sklearn only on the element-wise/ufunc subset. Consumer: crate re-export
18//! (`lib.rs`, grandfathered S5).
19//!
20//! | REQ | Status | Evidence |
21//! |---|---|---|
22//! | REQ-1 (element-wise forward transform) | SHIPPED (scoped) | `Transform::transform` = `x.mapv(\|v\| (self.func)(v))`, shape-preserving, infallible; mirrors sklearn `_transform` (`_function_transformer.py:375-379`) for element-wise ufunc `func`. Critic-verified bit-identical to live sklearn: `guard_log1p_/expm1_/sqrt_/log_nan_inf_/empty_matrix_*` (5 green) in `tests/divergence_function_transformer.rs`. Consumer: `pub use function_transformer::FunctionTransformer` (`lib.rs:114`). Caveat: scalar `Fn(F)->F`, not array `Fn(X)->X`. |
23//! | REQ-2 (func=None identity default) | NOT-STARTED | open prereq blocker #1112. `new` requires a closure; no identity default (`_identity`, `:22-24`). |
24//! | REQ-3 (whole-array func, headline) | NOT-STARTED | open prereq blocker #1113. `Box<dyn Fn(F)->F>` cannot read the array / change shape; sklearn `func(X)` is array→array (`:375-379`). |
25//! | REQ-4 (inverse_func / inverse_transform) | NOT-STARTED | open prereq blocker #1114. No inverse path (sklearn `:309-325`). |
26//! | REQ-5 (validate / accept_sparse) | NOT-STARTED | open prereq blocker #1115. No input validation (sklearn `:173-182`). |
27//! | REQ-6 (fit / check_inverse / is_fitted) | NOT-STARTED | open prereq blocker #1116. No fit/check_inverse (sklearn `:213-235`, `:184-210`). |
28//! | REQ-7 (feature_names_out / n_features_in_) | NOT-STARTED | open prereq blocker #1117. None (sklearn `:327-373`). |
29//! | REQ-8 (kw_args / inv_kw_args) | NOT-STARTED | open prereq blocker #1118. No kwarg forwarding (sklearn `:93-101`,`:379`). |
30//! | REQ-9 (ctor surface + _parameter_constraints) | NOT-STARTED | open prereq blocker #1119. Only `func`; 7 params + validation missing (R-DEV-2, sklearn `:141-171`). |
31//! | REQ-10 (PyO3 binding) | NOT-STARTED | open prereq blocker #1120. No `ferrolearn-python` registration. |
32//! | REQ-11 (ferray substrate) | NOT-STARTED | open prereq blocker #1121. `ndarray`/`num_traits`, not `ferray-core`/`ferray-ufunc` (R-SUBSTRATE-1/2). |
33
34use ferrolearn_core::error::FerroError;
35use ferrolearn_core::traits::Transform;
36use ndarray::Array2;
37use num_traits::Float;
38
39// ---------------------------------------------------------------------------
40// FunctionTransformer
41// ---------------------------------------------------------------------------
42
/// Element-wise function transformer that holds no fitted state.
///
/// Owns a boxed `Fn(F) -> F` closure and, on [`Transform::transform`], maps
/// every entry of the input matrix through it independently.
///
/// # Examples
///
/// ```
/// use ferrolearn_preprocess::function_transformer::FunctionTransformer;
/// use ferrolearn_core::traits::Transform;
/// use ndarray::array;
///
/// // Element-wise natural logarithm (inputs must be > 0 for finite output).
/// let ft = FunctionTransformer::<f64>::new(|v| v.ln());
/// let x = array![[1.0, 2.0], [3.0, 4.0]];
/// let out = ft.transform(&x).unwrap();
/// assert_eq!(out.shape(), x.shape());
/// assert_eq!(out[[0, 0]], 0.0); // ln(1) == 0
/// ```
pub struct FunctionTransformer<F> {
    /// Closure invoked once per matrix element.
    func: Box<dyn Fn(F) -> F + Sync + Send>,
}
63
64impl<F: Float + Send + Sync + 'static> FunctionTransformer<F> {
65 /// Create a new `FunctionTransformer` with the given function.
66 ///
67 /// The function will be applied element-wise to the input matrix.
68 pub fn new<Func>(func: Func) -> Self
69 where
70 Func: Fn(F) -> F + Send + Sync + 'static,
71 {
72 Self {
73 func: Box::new(func),
74 }
75 }
76}
77
78impl<F: Float + Send + Sync + 'static> std::fmt::Debug for FunctionTransformer<F> {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 f.debug_struct("FunctionTransformer")
81 .field("func", &"<fn(F) -> F>")
82 .finish()
83 }
84}
85
86// ---------------------------------------------------------------------------
87// Trait implementations
88// ---------------------------------------------------------------------------
89
90impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FunctionTransformer<F> {
91 type Output = Array2<F>;
92 type Error = FerroError;
93
94 /// Apply the stored function to every element of `x`.
95 ///
96 /// # Errors
97 ///
98 /// This implementation never returns an error for well-formed inputs.
99 /// Note: if the user-provided function produces NaN or infinity for
100 /// certain inputs, those values will appear in the output without error.
101 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
102 let out = x.mapv(|v| (self.func)(v));
103 Ok(out)
104 }
105}
106
107// ---------------------------------------------------------------------------
108// Tests
109// ---------------------------------------------------------------------------
110
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_identity_function() {
        let ft = FunctionTransformer::<f64>::new(|v| v);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let out = ft.transform(&x).unwrap();
        for (a, b) in x.iter().zip(out.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-15);
        }
    }

    #[test]
    fn test_sqrt_function() {
        let ft = FunctionTransformer::<f64>::new(|v: f64| v.sqrt());
        let x = array![[1.0, 4.0], [9.0, 16.0]];
        let out = ft.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 0]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 1]], 4.0, epsilon = 1e-10);
    }

    #[test]
    fn test_ln_function() {
        let ft = FunctionTransformer::<f64>::new(|v: f64| v.ln());
        let x = array![[1.0, 2.0]];
        let out = ft.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10); // ln(1) = 0
        assert_abs_diff_eq!(out[[0, 1]], 2.0_f64.ln(), epsilon = 1e-10);
    }

    #[test]
    fn test_negate_function() {
        let ft = FunctionTransformer::<f64>::new(|v| -v);
        let x = array![[1.0, -2.0, 3.0]];
        let out = ft.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], -1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 2]], -3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_constant_function() {
        let ft = FunctionTransformer::<f64>::new(|_| 42.0);
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = ft.transform(&x).unwrap();
        for v in &out {
            assert_abs_diff_eq!(*v, 42.0, epsilon = 1e-15);
        }
    }

    #[test]
    fn test_preserves_shape() {
        let ft = FunctionTransformer::<f64>::new(|v| v * 2.0);
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = ft.transform(&x).unwrap();
        assert_eq!(out.shape(), x.shape());
    }

    #[test]
    fn test_clamp_function() {
        let ft = FunctionTransformer::<f64>::new(|v: f64| v.clamp(0.0, 1.0));
        let x = array![[-1.0, 0.5, 2.0]];
        let out = ft.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 0.5, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 2]], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_f32_function() {
        let ft = FunctionTransformer::<f32>::new(|v: f32| v * 2.0);
        let x: Array2<f32> = array![[1.0f32, 2.0], [3.0, 4.0]];
        let out = ft.transform(&x).unwrap();
        assert!((out[[0, 0]] - 2.0f32).abs() < 1e-6);
        assert!((out[[1, 1]] - 8.0f32).abs() < 1e-6);
    }

    #[test]
    fn test_closure_captures_environment() {
        let scale = 3.0_f64;
        let ft = FunctionTransformer::<f64>::new(move |v| v * scale);
        let x = array![[1.0, 2.0]];
        let out = ft.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 6.0, epsilon = 1e-10);
    }

    #[test]
    fn test_empty_matrix() {
        let ft = FunctionTransformer::<f64>::new(|v| v);
        let x: Array2<f64> = Array2::zeros((0, 3));
        let out = ft.transform(&x).unwrap();
        assert_eq!(out.shape(), &[0, 3]);
    }

    #[test]
    fn test_non_finite_values_pass_through() {
        // `transform` performs no validation: non-finite values produced by
        // (or fed to) the closure must appear unchanged in the output.
        let ft = FunctionTransformer::<f64>::new(|v: f64| v.ln());
        let x = array![[0.0, -1.0, f64::INFINITY, f64::NAN]];
        let out = ft.transform(&x).unwrap();
        assert_eq!(out[[0, 0]], f64::NEG_INFINITY); // ln(0) = -inf
        assert!(out[[0, 1]].is_nan()); // ln(-1) = NaN
        assert_eq!(out[[0, 2]], f64::INFINITY); // ln(inf) = inf
        assert!(out[[0, 3]].is_nan()); // NaN propagates
    }

    #[test]
    fn test_non_standard_layout_input() {
        // A transposed (column-major) matrix must be mapped element-for-element
        // with its logical indexing preserved.
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]].reversed_axes();
        assert!(!x.is_standard_layout());
        let ft = FunctionTransformer::<f64>::new(|v| v * 10.0);
        let out = ft.transform(&x).unwrap();
        assert_eq!(out.shape(), &[3, 2]);
        assert_abs_diff_eq!(out[[0, 1]], 40.0, epsilon = 1e-12);
        assert_abs_diff_eq!(out[[2, 1]], 60.0, epsilon = 1e-12);
        for (a, b) in x.iter().zip(out.iter()) {
            assert_abs_diff_eq!(*a * 10.0, *b, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_debug_output_hides_closure() {
        let ft = FunctionTransformer::<f64>::new(|v| v);
        assert_eq!(
            format!("{ft:?}"),
            r#"FunctionTransformer { func: "<fn(F) -> F>" }"#
        );
    }
}
211}