Skip to main content

anofox_ml_core/
function_transformer.rs

1//! FunctionTransformer — wraps a closure as a pipeline-compatible transformer.
2//!
3//! Useful for inserting arbitrary transformations into a Pipeline without
4//! defining a full struct + impl.
5
6use std::sync::Arc;
7
8use ndarray::Array2;
9
10use crate::error::Result;
11use crate::float::Float;
12use crate::traits::{FitUnsupervised, Transform};
13
14/// A transformer that applies an arbitrary function to the data.
15///
16/// Implements `FitTransform` so it can be used directly in a `Pipeline`.
17/// The fit step is a no-op — only the transform function is called.
18///
19/// # Example
20///
21/// ```ignore
22/// use anofox_ml_core::{Pipeline, FunctionTransformer};
23///
24/// let log_transform = FunctionTransformer::new(|x: &Array2<f64>| {
25///     Ok(x.mapv(|v| v.ln()))
26/// });
27///
28/// let pipeline = Pipeline::new()
29///     .push_transformer("log", log_transform);
30/// ```
31pub struct FunctionTransformer<F: Float> {
32    func: Arc<dyn Fn(&Array2<F>) -> Result<Array2<F>> + Send + Sync>,
33}
34
35impl<F: Float> FunctionTransformer<F> {
36    /// Create a new FunctionTransformer from a closure.
37    pub fn new(func: impl Fn(&Array2<F>) -> Result<Array2<F>> + Send + Sync + 'static) -> Self {
38        Self {
39            func: Arc::new(func),
40        }
41    }
42}
43
44impl<F: Float> std::fmt::Debug for FunctionTransformer<F> {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        f.debug_struct("FunctionTransformer").finish()
47    }
48}
49
50impl<F: Float> Clone for FunctionTransformer<F> {
51    fn clone(&self) -> Self {
52        Self {
53            func: Arc::clone(&self.func),
54        }
55    }
56}
57
58/// Fitted function transformer — just holds the closure.
59pub struct FittedFunctionTransformer<F: Float> {
60    func: Arc<dyn Fn(&Array2<F>) -> Result<Array2<F>> + Send + Sync>,
61}
62
63unsafe impl<F: Float> Send for FittedFunctionTransformer<F> {}
64unsafe impl<F: Float> Sync for FittedFunctionTransformer<F> {}
65
66impl<F: Float> Transform<F> for FittedFunctionTransformer<F> {
67    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
68        (self.func)(x)
69    }
70}
71
72impl<F: Float + 'static> FitUnsupervised<F> for FunctionTransformer<F> {
73    type Fitted = FittedFunctionTransformer<F>;
74
75    fn fit(&self, _x: &Array2<F>) -> Result<Self::Fitted> {
76        Ok(FittedFunctionTransformer {
77            func: Arc::clone(&self.func),
78        })
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::{FitUnsupervised, Transform};
    use ndarray::array;

    /// Compile-time check helper: only instantiable for `Send + Sync` types.
    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn test_function_transformer_identity() {
        let ft = FunctionTransformer::<f64>::new(|x| Ok(x.to_owned()));
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = FitUnsupervised::fit(&ft, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();
        assert_eq!(transformed, x);
    }

    #[test]
    fn test_function_transformer_log() {
        let ft = FunctionTransformer::<f64>::new(|x| Ok(x.mapv(|v| v.ln())));
        let x = array![[1.0, std::f64::consts::E], [std::f64::consts::E, 1.0]];
        let fitted = FitUnsupervised::fit(&ft, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();
        let expected = array![[0.0, 1.0], [1.0, 0.0]];
        assert_eq!(transformed.dim(), expected.dim());
        for (got, want) in transformed.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-10, "got {got}, want {want}");
        }
    }

    #[test]
    fn test_function_transformer_scale() {
        let ft = FunctionTransformer::<f64>::new(|x| Ok(x.mapv(|v| v * 2.0)));
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = FitUnsupervised::fit(&ft, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();
        assert_eq!(transformed, array![[2.0, 4.0], [6.0, 8.0]]);
    }

    #[test]
    fn test_function_transformer_clone() {
        let ft = FunctionTransformer::<f64>::new(|x| Ok(x.mapv(|v| v + 1.0)));
        let ft2 = ft.clone();
        let x = array![[1.0]];
        let f1 = FitUnsupervised::fit(&ft, &x).unwrap();
        let f2 = FitUnsupervised::fit(&ft2, &x).unwrap();
        assert_eq!(f1.transform(&x).unwrap(), f2.transform(&x).unwrap());
    }

    #[test]
    fn test_fit_ignores_training_data() {
        // Fitting on one shape and transforming another must work: fit learns nothing.
        let ft = FunctionTransformer::<f64>::new(|x| Ok(x.mapv(|v| v * 10.0)));
        let train = array![[1.0, 2.0]];
        let test = array![[3.0], [4.0], [5.0]];
        let fitted = FitUnsupervised::fit(&ft, &train).unwrap();
        assert_eq!(fitted.transform(&test).unwrap(), array![[30.0], [40.0], [50.0]]);
    }

    #[test]
    fn test_function_transformer_is_send_sync() {
        // Guards the auto-trait guarantee that makes manual `unsafe impl`s unnecessary.
        assert_send_sync::<FunctionTransformer<f64>>();
        assert_send_sync::<FittedFunctionTransformer<f64>>();
    }
}