// anofox_ml_core/function_transformer.rs
use std::sync::Arc;

use ndarray::Array2;

use crate::error::Result;
use crate::float::Float;
use crate::traits::{FitUnsupervised, Transform};
14pub struct FunctionTransformer<F: Float> {
32 func: Arc<dyn Fn(&Array2<F>) -> Result<Array2<F>> + Send + Sync>,
33}
34
35impl<F: Float> FunctionTransformer<F> {
36 pub fn new(func: impl Fn(&Array2<F>) -> Result<Array2<F>> + Send + Sync + 'static) -> Self {
38 Self {
39 func: Arc::new(func),
40 }
41 }
42}
43
44impl<F: Float> std::fmt::Debug for FunctionTransformer<F> {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 f.debug_struct("FunctionTransformer").finish()
47 }
48}
49
50impl<F: Float> Clone for FunctionTransformer<F> {
51 fn clone(&self) -> Self {
52 Self {
53 func: Arc::clone(&self.func),
54 }
55 }
56}
57
58pub struct FittedFunctionTransformer<F: Float> {
60 func: Arc<dyn Fn(&Array2<F>) -> Result<Array2<F>> + Send + Sync>,
61}
62
63unsafe impl<F: Float> Send for FittedFunctionTransformer<F> {}
64unsafe impl<F: Float> Sync for FittedFunctionTransformer<F> {}
65
66impl<F: Float> Transform<F> for FittedFunctionTransformer<F> {
67 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
68 (self.func)(x)
69 }
70}
71
72impl<F: Float + 'static> FitUnsupervised<F> for FunctionTransformer<F> {
73 type Fitted = FittedFunctionTransformer<F>;
74
75 fn fit(&self, _x: &Array2<F>) -> Result<Self::Fitted> {
76 Ok(FittedFunctionTransformer {
77 func: Arc::clone(&self.func),
78 })
79 }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::{FitUnsupervised, Transform};
    use ndarray::array;

    /// Fits `ft` on `x`, then transforms that same matrix.
    fn fit_transform(ft: &FunctionTransformer<f64>, x: &Array2<f64>) -> Array2<f64> {
        let fitted = FitUnsupervised::fit(ft, x).unwrap();
        fitted.transform(x).unwrap()
    }

    #[test]
    fn test_function_transformer_identity() {
        let ft = FunctionTransformer::<f64>::new(|x| Ok(x.to_owned()));
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        assert_eq!(fit_transform(&ft, &x), x);
    }

    #[test]
    fn test_function_transformer_log() {
        let ft = FunctionTransformer::<f64>::new(|x| Ok(x.mapv(|v| v.ln())));
        let x = array![[1.0, std::f64::consts::E], [std::f64::consts::E, 1.0]];
        let transformed = fit_transform(&ft, &x);
        let expected = array![[0.0, 1.0], [1.0, 0.0]];
        // Check the shape first so the element-wise zip below cannot silently truncate.
        assert_eq!(transformed.shape(), expected.shape());
        for (got, want) in transformed.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-10, "got {got}, want {want}");
        }
    }

    #[test]
    fn test_function_transformer_scale() {
        let ft = FunctionTransformer::<f64>::new(|x| Ok(x.mapv(|v| v * 2.0)));
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        assert_eq!(fit_transform(&ft, &x), array![[2.0, 4.0], [6.0, 8.0]]);
    }

    #[test]
    fn test_function_transformer_clone() {
        let ft = FunctionTransformer::<f64>::new(|x| Ok(x.mapv(|v| v + 1.0)));
        let ft2 = ft.clone();
        let x = array![[1.0]];
        let f1 = FitUnsupervised::fit(&ft, &x).unwrap();
        let f2 = FitUnsupervised::fit(&ft2, &x).unwrap();
        assert_eq!(f1.transform(&x).unwrap(), f2.transform(&x).unwrap());
    }

    #[test]
    fn test_function_transformer_fit_ignores_training_data() {
        // Fit on one matrix, transform a different one with a different shape.
        let ft = FunctionTransformer::<f64>::new(|x| Ok(x.mapv(|v| -v)));
        let fitted = FitUnsupervised::fit(&ft, &array![[100.0, 200.0, 300.0]]).unwrap();
        let x = array![[1.0], [2.0]];
        assert_eq!(fitted.transform(&x).unwrap(), array![[-1.0], [-2.0]]);
    }

    #[test]
    fn test_transformers_are_send_and_sync() {
        // Compile-time guard: both types must remain shareable across threads.
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<FunctionTransformer<f64>>();
        assert_send_sync::<FittedFunctionTransformer<f64>>();
    }
}