use serde::{Deserialize, Serialize};
use std::hash::Hash;
/// A value transform, selectable by name via `TransformType::create_from_str`.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub enum TransformType {
    /// Identity: values pass through unchanged in both directions.
    Linear,
    /// `asinh(x / cofactor)`; inverted as `sinh(y) * cofactor`.
    Arcsinh { cofactor: f32 },
    /// Scaled arcsinh:
    /// `asinh(x * sinh(positive_decades * ln 10) / top_of_scale) + negative_decades * ln 10`.
    Biexponential {
        /// Raw value used to normalise inputs; `0.0` makes the forward `transform` the identity.
        top_of_scale: f32,
        /// Sets the scaling factor `sinh(positive_decades * ln 10)`.
        positive_decades: f32,
        /// Adds a constant offset of `negative_decades * ln 10` to the transformed output.
        negative_decades: f32,
        /// Not read by `transform` / `inverse_transform`; it still participates in
        /// equality, hashing and serialization.
        /// NOTE(review): confirm whether a linearization width is meant to be applied.
        width: f32,
    },
}
impl TransformType {
pub fn create_from_str(s: Option<&str>) -> Self {
match s {
Some("linear") => TransformType::Linear,
Some("arcsinh") => TransformType::Arcsinh { cofactor: 200.0 },
Some("biexponential") | Some("logicle") => TransformType::Biexponential {
top_of_scale: 262144.0,
positive_decades: 4.5,
negative_decades: 0.0,
width: 0.5,
},
_ => TransformType::default(),
}
}
}
/// Bidirectional mapping between raw values and a transformed space.
pub trait Transformable {
    /// Maps a raw value into transformed space.
    fn transform(&self, value: &f32) -> f32;
    /// Maps a transformed value back into raw space; expected to undo `transform`.
    fn inverse_transform(&self, value: &f32) -> f32;
}
/// Converts a value into a display string.
// NOTE(review): `unused` is presumably allowed because nothing in this crate calls
// `format` through the trait yet — confirm, and drop the attribute once it has a caller.
#[allow(unused)]
pub trait Formattable {
    /// Formats `value` as human-readable text.
    fn format(&self, value: &f32) -> String;
}
impl Transformable for TransformType {
    /// Maps a raw value into transformed space.
    ///
    /// * `Linear`: identity.
    /// * `Arcsinh`: `asinh(x / cofactor)`.
    /// * `Biexponential`: `asinh(x * sinh(M·ln10) / T) + A·ln10`, with
    ///   `T = top_of_scale`, `M = positive_decades` and `A = negative_decades`.
    ///   `width` is not used. A `top_of_scale` of 0 is treated as the identity,
    ///   which avoids dividing by zero.
    fn transform(&self, value: &f32) -> f32 {
        match self {
            TransformType::Linear => *value,
            TransformType::Arcsinh { cofactor } => (value / cofactor).asinh(),
            TransformType::Biexponential {
                top_of_scale,
                positive_decades,
                negative_decades,
                width: _,
            } => {
                // Degenerate scale: identity (mirrored in `inverse_transform`).
                if *top_of_scale == 0.0 {
                    return *value;
                }
                let ln_10 = std::f32::consts::LN_10;
                let sinh_m_ln10 = (positive_decades * ln_10).sinh();
                let a_ln10 = negative_decades * ln_10;
                (value * sinh_m_ln10 / top_of_scale).asinh() + a_ln10
            }
        }
    }

    /// Maps a transformed value back into raw space.
    ///
    /// This is the mathematical inverse of `transform` for non-degenerate
    /// parameters, up to floating-point rounding.
    ///
    /// * `Linear`: identity.
    /// * `Arcsinh`: `sinh(y) * cofactor`.
    /// * `Biexponential`: `T * sinh(y - A·ln10) / sinh(M·ln10)`. A `top_of_scale`
    ///   of 0 is the identity, matching `transform`. Not invertible when
    ///   `positive_decades` is 0, because `sinh(0) == 0`.
    fn inverse_transform(&self, value: &f32) -> f32 {
        match self {
            TransformType::Linear => *value,
            // Deliberately free of logging: this runs once per value, e.g. for
            // every label produced by `Formattable::format`.
            TransformType::Arcsinh { cofactor } => value.sinh() * cofactor,
            TransformType::Biexponential {
                top_of_scale,
                positive_decades,
                negative_decades,
                width: _,
            } => {
                // Mirror the forward transform's degenerate-scale identity so the
                // pair still round-trips when `top_of_scale` is 0.
                if *top_of_scale == 0.0 {
                    return *value;
                }
                let ln_10 = std::f32::consts::LN_10;
                let sinh_m_ln10 = (positive_decades * ln_10).sinh();
                let a_ln10 = negative_decades * ln_10;
                top_of_scale * (value - a_ln10).sinh() / sinh_m_ln10
            }
        }
    }
}
impl Formattable for TransformType {
    /// Renders `value` in one-decimal scientific notation (e.g. `1.0e3`).
    ///
    /// For non-linear transforms, `value` is taken to be in transformed space
    /// and is mapped back to raw units via `inverse_transform` before display.
    fn format(&self, value: &f32) -> String {
        let raw = match self {
            TransformType::Linear => *value,
            TransformType::Arcsinh { .. } | TransformType::Biexponential { .. } => {
                self.inverse_transform(value)
            }
        };
        format!("{:.1e}", raw)
    }
}
impl Default for TransformType {
fn default() -> Self {
TransformType::Arcsinh { cofactor: 200.0 }
}
}
impl Hash for TransformType {
    /// Hashes a variant tag followed by the bit pattern of each `f32` field.
    ///
    /// `PartialEq` is derived and compares floats with `==`, under which
    /// `0.0 == -0.0` even though their bit patterns differ. Zeros are therefore
    /// canonicalised to `+0.0` before hashing, so values that compare equal
    /// always hash equal, as the `Hash` contract requires. NaN fields never
    /// compare equal, so they impose no constraint.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Bit pattern of `x`, with `-0.0` folded onto `+0.0`.
        fn canonical_bits(x: f32) -> u32 {
            if x == 0.0 {
                0.0_f32.to_bits()
            } else {
                x.to_bits()
            }
        }

        match self {
            TransformType::Linear => "linear".hash(state),
            TransformType::Arcsinh { cofactor } => {
                "arcsinh".hash(state);
                canonical_bits(*cofactor).hash(state);
            }
            TransformType::Biexponential {
                top_of_scale,
                positive_decades,
                negative_decades,
                width,
            } => {
                "biexponential".hash(state);
                canonical_bits(*top_of_scale).hash(state);
                canonical_bits(*positive_decades).hash(state);
                canonical_bits(*negative_decades).hash(state);
                canonical_bits(*width).hash(state);
            }
        }
    }
}
#[test]
fn test_transform() {
    // Linear is the identity in both directions.
    let linear = TransformType::Linear;
    assert_eq!(linear.transform(&1.0), 1.0);
    assert_eq!(linear.inverse_transform(&1.0), 1.0);

    // Arcsinh with cofactor 200: small inputs map to roughly x / cofactor.
    let asinh = TransformType::Arcsinh { cofactor: 200.0 };
    let forward = asinh.transform(&1.0);
    assert!(
        (forward - 0.005).abs() < 1e-6,
        "Expected ~0.005, got {}",
        forward
    );
    let back = asinh.inverse_transform(&0.005);
    assert!((back - 1.0).abs() < 1e-5, "Expected ~1.0, got {}", back);

    // Negative, zero and cofactor-sized inputs must not produce NaN.
    for input in [-1.0_f32, 0.0, -200.0] {
        assert!(!asinh.transform(&input).is_nan());
    }
}
#[test]
fn test_transform_type_partial_eq_and_hash_consistency() {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    // Feeds every value, in order, into one fresh hasher and returns the digest.
    let digest = |values: &[&TransformType]| -> u64 {
        let mut hasher = DefaultHasher::new();
        for v in values {
            v.hash(&mut hasher);
        }
        hasher.finish()
    };

    let a = TransformType::Arcsinh { cofactor: 200.0 };
    let b = TransformType::Arcsinh { cofactor: 200.0 };
    let c = TransformType::Arcsinh { cofactor: 150.0 };
    assert_eq!(a, b);
    assert_ne!(a, c);

    // Equal values must hash equally.
    assert_eq!(digest(&[&a]), digest(&[&b]));
    // The streams (a, c) and (b, b) differ only in c vs b, so the digests must differ.
    assert_ne!(digest(&[&a, &c]), digest(&[&b, &b]));
}