use num_traits::{Num, NumCast};
/// Selects which element-wise transform [`apply_transform`] performs.
///
/// Derives `PartialEq`, `Eq` and `Hash` in addition to `Copy`, so values can be
/// compared, matched in assertions, and used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransformType {
    /// Square root (see [`sqrt_transform`]).
    Sqrt,
    /// Natural logarithm (see [`log_transform`]).
    Log,
    /// Hyperbolic tangent (see [`tanh_transform`]).
    Tanh,
}
pub fn apply_transform<T>(data: &[T], transform_type: TransformType) -> Vec<f64>
where
T: Num + NumCast + Copy,
{
match transform_type {
TransformType::Sqrt => sqrt_transform(data),
TransformType::Log => log_transform(data),
TransformType::Tanh => tanh_transform(data),
}
}
/// Applies the square root to every element of `data`, returning the results as `f64`.
///
/// Some elements are skipped instead of producing NaN:
/// - elements that cannot be converted to `f64`;
/// - elements outside the real domain of `sqrt`, i.e. negative values and NaN.
///
/// This matches [`log_transform`], which likewise drops out-of-domain inputs. As a
/// result the output may be shorter than `data`. An empty slice yields an empty vector.
pub fn sqrt_transform<T>(data: &[T]) -> Vec<f64>
where
    T: Num + NumCast + Copy,
{
    data.iter()
        .filter_map(|&x| {
            let x_f: f64 = NumCast::from(x)?;
            // `x_f >= 0.0` is false for negatives and for NaN, so neither reaches
            // `sqrt` (which would return NaN). `-0.0` passes and yields `-0.0`.
            if x_f >= 0.0 {
                Some(x_f.sqrt())
            } else {
                None
            }
        })
        .collect()
}
/// Applies the natural logarithm to every element of `data`, returning `f64` results.
///
/// Some elements are skipped, so the output may be shorter than `data`:
/// - elements that cannot be converted to `f64`;
/// - elements outside the domain of `ln`, i.e. zero, negatives and NaN.
///
/// An empty slice yields an empty vector.
pub fn log_transform<T>(data: &[T]) -> Vec<f64>
where
    T: Num + NumCast + Copy,
{
    data.iter()
        .copied()
        .filter_map(|value| {
            let as_float: Option<f64> = NumCast::from(value);
            // `v > 0.0` rejects zero, negatives and NaN in one comparison.
            as_float.filter(|&v| v > 0.0).map(f64::ln)
        })
        .collect()
}
pub fn tanh_transform<T>(data: &[T]) -> Vec<f64>
where
T: Num + NumCast + Copy,
{
if data.is_empty() {
return Vec::new();
}
data.iter()
.filter_map(|&x| NumCast::from(x).map(f64::tanh))
.collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `a` and `b` differ by less than `eps`.
    fn assert_close(a: f64, b: f64, eps: f64) {
        assert!(
            (a - b).abs() < eps,
            "Expected {:.6}, got {:.6}, diff = {:.6}",
            b,
            a,
            (a - b).abs()
        );
    }

    /// Asserts that `result` has exactly as many elements as `expected` and that
    /// each pair is within `eps`.
    ///
    /// The length check matters: zipping alone stops at the shorter side, so
    /// missing or extra elements would otherwise go unnoticed.
    fn assert_all_close(result: &[f64], expected: &[f64], eps: f64) {
        assert_eq!(
            result.len(),
            expected.len(),
            "length mismatch: got {:?}, expected {:?}",
            result,
            expected
        );
        for (&a, &b) in result.iter().zip(expected.iter()) {
            assert_close(a, b, eps);
        }
    }

    #[test]
    fn test_apply_transform_dispatches_to_each_transform() {
        let data = vec![1.0_f64, 4.0, 9.0];
        assert_eq!(apply_transform(&data, TransformType::Sqrt), sqrt_transform(&data));
        assert_eq!(apply_transform(&data, TransformType::Log), log_transform(&data));
        assert_eq!(apply_transform(&data, TransformType::Tanh), tanh_transform(&data));
    }

    #[test]
    fn test_sqrt_transform_with_f64() {
        let data = vec![0.0, 1.0, 4.0, 9.0, 16.0];
        let result = sqrt_transform(&data);
        assert_all_close(&result, &[0.0, 1.0, 2.0, 3.0, 4.0], 1e-6);
    }

    #[test]
    fn test_sqrt_transform_with_u32() {
        let data = vec![0u32, 1, 4, 9, 16, 25];
        let result = sqrt_transform(&data);
        assert_all_close(&result, &[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], 1e-6);
    }

    #[test]
    fn test_sqrt_transform_with_empty_input() {
        let data: Vec<u32> = vec![];
        let result = sqrt_transform(&data);
        assert!(result.is_empty());
    }

    #[test]
    fn test_sqrt_transform_keeps_all_convertible_u8_values() {
        // Every u8 converts to f64, so nothing should be dropped.
        let data = vec![1u8, 4u8, 9u8];
        let result = sqrt_transform(&data);
        assert_all_close(&result, &[1.0, 2.0, 3.0], 1e-6);
    }

    #[test]
    fn test_log_transform_with_f64() {
        let data = vec![1.0, std::f64::consts::E, 10.0, 100.0];
        let expected = [0.0, 1.0, 10.0_f64.ln(), 100.0_f64.ln()];
        let result = log_transform(&data);
        assert_all_close(&result, &expected, 1e-10);
    }

    #[test]
    fn test_log_transform_with_u32() {
        let data = vec![1u32, 2, 10, 100];
        let expected: Vec<f64> = [1.0_f64, 2.0, 10.0, 100.0].iter().map(|x| x.ln()).collect();
        let result = log_transform(&data);
        assert_all_close(&result, &expected, 1e-10);
    }

    #[test]
    fn test_log_transform_skips_zeros_and_negatives() {
        let data = vec![0.0, -1.0, -100.0, 1.0];
        let result = log_transform(&data);
        assert_all_close(&result, &[0.0], 1e-10);
    }

    #[test]
    fn test_log_transform_empty_input() {
        let data: Vec<f64> = vec![];
        let result = log_transform(&data);
        assert!(result.is_empty());
    }

    #[test]
    fn test_tanh_transform_with_f64() {
        let data: Vec<f64> = vec![0.0, 1.0, -1.0, 10.0, -10.0];
        let expected: Vec<f64> = data.iter().map(|x| x.tanh()).collect();
        let result = tanh_transform(&data);
        assert_all_close(&result, &expected, 1e-6);
    }

    #[test]
    fn test_tanh_transform_with_integers() {
        let data = vec![-3i32, -1, 0, 1, 3];
        // `f64::from(i32)` is lossless, unlike an `as` cast in general.
        let expected: Vec<f64> = data.iter().map(|&x| f64::from(x).tanh()).collect();
        let result = tanh_transform(&data);
        assert_all_close(&result, &expected, 1e-6);
    }

    #[test]
    fn test_tanh_transform_extremes() {
        let data = vec![1000.0, -1000.0];
        let result = tanh_transform(&data);
        assert_all_close(&result, &[1.0, -1.0], 1e-6);
    }

    #[test]
    fn test_tanh_transform_empty_input() {
        let data: Vec<f64> = vec![];
        let result = tanh_transform(&data);
        assert!(result.is_empty());
    }
}