use std::sync::Arc;
use ndarray::Array2;
use crate::error::Result;
use crate::float::Float;
use crate::traits::{FitUnsupervised, Transform};
/// A transformer whose behaviour is defined entirely by a caller-supplied
/// function over 2-D arrays.
///
/// The function receives a borrowed [`Array2`] and returns either a new
/// array or an error. It is kept behind an [`Arc`], so the transformer can be
/// duplicated without copying the function itself.
pub struct FunctionTransformer<F>
where
    F: Float,
{
    /// The wrapped function. The `Send + Sync` bounds on the trait object
    /// allow the shared pointer to be moved to and used from other threads.
    func: Arc<dyn Fn(&Array2<F>) -> Result<Array2<F>> + Send + Sync>,
}
impl<F: Float> FunctionTransformer<F> {
    /// Builds a transformer around `func`.
    ///
    /// `func` is called with a borrowed 2-D array and must return the
    /// transformed array or an error. It has to be `Send + Sync + 'static`
    /// because it is stored as a shared, type-erased trait object.
    pub fn new(func: impl Fn(&Array2<F>) -> Result<Array2<F>> + Send + Sync + 'static) -> Self {
        // Allocate once; the concrete closure type is erased to `dyn Fn`
        // when the pointer is stored in the field below.
        let shared = Arc::new(func);
        Self { func: shared }
    }
}
impl<F: Float> std::fmt::Debug for FunctionTransformer<F> {
    /// Writes only the type name: the sole field is an opaque function
    /// object, so there is nothing further to display.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("FunctionTransformer")
    }
}
impl<F: Float> Clone for FunctionTransformer<F> {
    /// Returns a transformer that shares the same underlying function.
    ///
    /// Only the reference count of the [`Arc`] is incremented; the function
    /// object itself is not duplicated.
    fn clone(&self) -> Self {
        let func = Arc::clone(&self.func);
        Self { func }
    }
}
/// The fitted counterpart of [`FunctionTransformer`].
///
/// It holds a shared handle to the same function as the transformer it was
/// produced from and applies that function when asked to transform data.
///
/// # Thread safety
///
/// This type is `Send + Sync` without any manual `unsafe impl`: its only
/// field is an `Arc<dyn Fn(..) + Send + Sync>`, which is itself
/// `Send + Sync`, so the compiler derives both auto traits. Letting the
/// compiler derive them means that adding a non-thread-safe field in the
/// future correctly removes the guarantee instead of silently keeping an
/// unsound promise.
pub struct FittedFunctionTransformer<F: Float> {
    /// The wrapped function, shared with the originating transformer.
    func: Arc<dyn Fn(&Array2<F>) -> Result<Array2<F>> + Send + Sync>,
}
impl<F: Float> Transform<F> for FittedFunctionTransformer<F> {
    /// Applies the wrapped function to `x` and returns its result unchanged,
    /// including any error the function reports.
    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
        // Borrow the trait object out of the `Arc` and invoke it directly.
        let apply = &*self.func;
        apply(x)
    }
}
impl<F: Float + 'static> FitUnsupervised<F> for FunctionTransformer<F> {
    type Fitted = FittedFunctionTransformer<F>;

    /// Produces the fitted transformer.
    ///
    /// The input `_x` is not inspected: nothing is learned from the data, and
    /// the fitted value simply shares this transformer's function. This
    /// implementation always returns `Ok`.
    fn fit(&self, _x: &Array2<F>) -> Result<Self::Fitted> {
        let func = Arc::clone(&self.func);
        Ok(FittedFunctionTransformer { func })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::{FitUnsupervised, Transform};
    use ndarray::array;

    /// Absolute tolerance used when comparing floating-point results.
    const TOLERANCE: f64 = 1e-10;

    /// Fits `transformer` on `input`, then applies the fitted transformer to
    /// that same `input`, panicking if either step fails.
    fn fit_then_apply(
        transformer: &FunctionTransformer<f64>,
        input: &ndarray::Array2<f64>,
    ) -> ndarray::Array2<f64> {
        let fitted = FitUnsupervised::fit(transformer, input).unwrap();
        fitted.transform(input).unwrap()
    }

    /// An identity function must return the input unchanged.
    #[test]
    fn test_function_transformer_identity() {
        let identity = FunctionTransformer::<f64>::new(|x| Ok(x.to_owned()));
        let input = array![[1.0, 2.0], [3.0, 4.0]];

        let output = fit_then_apply(&identity, &input);

        assert_eq!(output, input);
    }

    /// Element-wise natural log: ln(1) = 0 and ln(e) = 1.
    #[test]
    fn test_function_transformer_log() {
        let log = FunctionTransformer::<f64>::new(|x| Ok(x.mapv(|v| v.ln())));
        let input = array![[1.0, std::f64::consts::E], [std::f64::consts::E, 1.0]];

        let output = fit_then_apply(&log, &input);

        assert!((output[[0, 0]] - 0.0).abs() < TOLERANCE);
        assert!((output[[0, 1]] - 1.0).abs() < TOLERANCE);
    }

    /// Element-wise doubling.
    #[test]
    fn test_function_transformer_scale() {
        let doubler = FunctionTransformer::<f64>::new(|x| Ok(x.mapv(|v| v * 2.0)));
        let input = array![[1.0, 2.0], [3.0, 4.0]];

        let output = fit_then_apply(&doubler, &input);

        assert_eq!(output, array![[2.0, 4.0], [6.0, 8.0]]);
    }

    /// A cloned transformer must behave exactly like the original.
    #[test]
    fn test_function_transformer_clone() {
        let original = FunctionTransformer::<f64>::new(|x| Ok(x.mapv(|v| v + 1.0)));
        let duplicate = original.clone();
        let input = array![[1.0]];

        let from_original = fit_then_apply(&original, &input);
        let from_duplicate = fit_then_apply(&duplicate, &input);

        assert_eq!(from_original, from_duplicate);
    }
}