ferrolearn_preprocess/
function_transformer.rs1use ferrolearn_core::error::FerroError;
11use ferrolearn_core::traits::Transform;
12use ndarray::Array2;
13use num_traits::Float;
14
/// Transformer that applies a user-supplied scalar function to every element
/// of an input matrix.
///
/// The function is held as a boxed trait object, so any closure (including
/// one that captures its environment) or plain function can be stored. The
/// `Send + Sync` bounds on the boxed closure make the transformer itself
/// `Send + Sync`.
pub struct FunctionTransformer<F> {
    /// Element-wise mapping applied to each value.
    func: Box<dyn Fn(F) -> F + Sync + Send>,
}
39
40impl<F: Float + Send + Sync + 'static> FunctionTransformer<F> {
41 pub fn new<Func>(func: Func) -> Self
45 where
46 Func: Fn(F) -> F + Send + Sync + 'static,
47 {
48 Self {
49 func: Box::new(func),
50 }
51 }
52}
53
54impl<F: Float + Send + Sync + 'static> std::fmt::Debug for FunctionTransformer<F> {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 f.debug_struct("FunctionTransformer")
57 .field("func", &"<fn(F) -> F>")
58 .finish()
59 }
60}
61
62impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FunctionTransformer<F> {
67 type Output = Array2<F>;
68 type Error = FerroError;
69
70 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
78 let out = x.mapv(|v| (self.func)(v));
79 Ok(out)
80 }
81}
82
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_identity_function() {
        let ft = FunctionTransformer::<f64>::new(|v| v);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let out = ft.transform(&x).unwrap();
        for (a, b) in x.iter().zip(out.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-15);
        }
    }

    #[test]
    fn test_sqrt_function() {
        let ft = FunctionTransformer::<f64>::new(|v: f64| v.sqrt());
        let x = array![[1.0, 4.0], [9.0, 16.0]];
        let out = ft.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 0]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 1]], 4.0, epsilon = 1e-10);
    }

    #[test]
    fn test_ln_function() {
        let ft = FunctionTransformer::<f64>::new(|v: f64| v.ln());
        let x = array![[1.0, 2.0]];
        let out = ft.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 2.0_f64.ln(), epsilon = 1e-10);
    }

    #[test]
    fn test_negate_function() {
        let ft = FunctionTransformer::<f64>::new(|v| -v);
        let x = array![[1.0, -2.0, 3.0]];
        let out = ft.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], -1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 2]], -3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_constant_function() {
        let ft = FunctionTransformer::<f64>::new(|_| 42.0);
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = ft.transform(&x).unwrap();
        for v in out.iter() {
            assert_abs_diff_eq!(*v, 42.0, epsilon = 1e-15);
        }
    }

    #[test]
    fn test_preserves_shape() {
        let ft = FunctionTransformer::<f64>::new(|v| v * 2.0);
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = ft.transform(&x).unwrap();
        assert_eq!(out.shape(), x.shape());
    }

    #[test]
    fn test_clamp_function() {
        let ft = FunctionTransformer::<f64>::new(|v: f64| v.clamp(0.0, 1.0));
        let x = array![[-1.0, 0.5, 2.0]];
        let out = ft.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 0.5, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 2]], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_f32_function() {
        let ft = FunctionTransformer::<f32>::new(|v: f32| v * 2.0);
        let x: Array2<f32> = array![[1.0f32, 2.0], [3.0, 4.0]];
        let out = ft.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 2.0f32, epsilon = 1e-6);
        assert_abs_diff_eq!(out[[1, 1]], 8.0f32, epsilon = 1e-6);
    }

    #[test]
    fn test_closure_captures_environment() {
        let scale = 3.0_f64;
        let ft = FunctionTransformer::<f64>::new(move |v| v * scale);
        let x = array![[1.0, 2.0]];
        let out = ft.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 6.0, epsilon = 1e-10);
    }

    #[test]
    fn test_empty_matrix() {
        let ft = FunctionTransformer::<f64>::new(|v| v);
        let x: Array2<f64> = Array2::zeros((0, 3));
        let out = ft.transform(&x).unwrap();
        assert_eq!(out.shape(), &[0, 3]);
    }

    /// NaN inputs are passed through to the user function untouched.
    #[test]
    fn test_nan_propagates() {
        let ft = FunctionTransformer::<f64>::new(|v| v + 1.0);
        let x = array![[f64::NAN, 1.0]];
        let out = ft.transform(&x).unwrap();
        assert!(out[[0, 0]].is_nan());
        assert_abs_diff_eq!(out[[0, 1]], 2.0, epsilon = 1e-15);
    }

    /// A column-major (transposed) input must map element-for-element by
    /// logical index, not by memory order.
    #[test]
    fn test_column_major_input() {
        let ft = FunctionTransformer::<f64>::new(|v| v * 10.0);
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]].reversed_axes();
        let out = ft.transform(&x).unwrap();
        assert_eq!(out.shape(), &[3, 2]);
        for (a, b) in x.iter().zip(out.iter()) {
            assert_abs_diff_eq!(*a * 10.0, *b, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_input_is_not_modified() {
        let ft = FunctionTransformer::<f64>::new(|v| v + 100.0);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let before = x.clone();
        ft.transform(&x).unwrap();
        assert_eq!(x, before);
    }

    #[test]
    fn test_debug_output() {
        let ft = FunctionTransformer::<f64>::new(|v| v);
        let rendered = format!("{ft:?}");
        assert!(rendered.contains("FunctionTransformer"));
        assert!(rendered.contains("<fn(F) -> F>"));
    }
}