use simba::scalar::ComplexField;
/// Working precision requested for the computation.
///
/// Each variant carries an explicit discriminant (`0`, `1`, `-1`).
/// NOTE(review): these values look like an external encoding (e.g. an FFI or
/// bindings convention); confirm before renumbering.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TworkType {
    /// `f64` working precision; `safe_epsilon` uses a precision floor of `1e-8`
    /// for it.
    Float64 = 0,
    /// Extended working precision; `safe_epsilon` derives its precision floor
    /// from the crate's `Df64` type.
    Float64X2 = 1,
    /// Defer the choice: `safe_epsilon` resolves this to `Float64` or
    /// `Float64X2` depending on the requested epsilon.
    Auto = -1,
}
/// Strategy used for the singular value decomposition.
///
/// Each variant carries an explicit discriminant (`0`, `1`, `-1`).
/// NOTE(review): these values look like an external encoding (e.g. an FFI or
/// bindings convention); confirm before renumbering.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SVDStrategy {
    /// The variant `safe_epsilon` picks when the requested epsilon is at or
    /// above the precision floor.
    Fast = 0,
    /// The variant `safe_epsilon` picks when the requested epsilon is below the
    /// precision floor of the working precision.
    Accurate = 1,
    /// Defer the choice to `safe_epsilon`.
    Auto = -1,
}
/// Resolves a requested accuracy into a usable tolerance, a concrete working
/// precision, and a concrete SVD strategy.
///
/// # Arguments
///
/// * `epsilon` - Requested accuracy. Must be non-negative. `NaN` is accepted:
///   the precision floor is returned as the tolerance, [`TworkType::Auto`]
///   resolves to [`TworkType::Float64X2`], and [`SVDStrategy::Auto`] resolves
///   to [`SVDStrategy::Fast`].
/// * `twork` - Working precision. [`TworkType::Auto`] picks
///   [`TworkType::Float64X2`] when `epsilon` is `NaN` or below `1e-8`, and
///   [`TworkType::Float64`] otherwise. Explicit values are passed through.
/// * `svd_strategy` - SVD strategy. [`SVDStrategy::Auto`] picks
///   [`SVDStrategy::Accurate`] when `epsilon` is strictly below the precision
///   floor of the resolved working precision, and [`SVDStrategy::Fast`]
///   otherwise (including for `NaN`). Explicit values are passed through.
///
/// # Returns
///
/// `(safe_eps, twork, svd_strategy)`, where:
///
/// * `safe_eps` is `epsilon` raised to at least the precision floor, or the
///   floor itself when `epsilon` is `NaN`. The floor is `1e-8` for `Float64`
///   and `sqrt(Df64::epsilon())` for `Float64X2`.
/// * The returned precision and strategy are never `Auto`.
///
/// # Panics
///
/// Panics with `"eps_required must be non-negative"` if `epsilon < 0.0`.
pub fn safe_epsilon(
    epsilon: f64,
    twork: TworkType,
    svd_strategy: SVDStrategy,
) -> (f64, TworkType, SVDStrategy) {
    // `NaN < 0.0` is false, so NaN deliberately passes this check.
    if epsilon < 0.0 {
        panic!("eps_required must be non-negative");
    }

    // Resolve the working precision. Both NaN and requests tighter than 1e-8
    // select the extended type.
    let twork_actual = match twork {
        TworkType::Auto => {
            if epsilon.is_nan() || epsilon < 1e-8 {
                TworkType::Float64X2
            } else {
                TworkType::Float64
            }
        }
        other => other,
    };

    // Smallest tolerance accepted for the chosen working precision.
    let precision_floor = match twork_actual {
        TworkType::Float64 => 1e-8,
        TworkType::Float64X2 => {
            use crate::numeric::CustomNumeric;
            crate::Df64::epsilon().sqrt().to_f64()
        }
        // `twork_actual` was resolved above and can never be `Auto` here.
        // Naming the arm explicitly, rather than using `_`, turns any future
        // variant into a compile error instead of a silent default.
        TworkType::Auto => {
            unreachable!("TworkType::Auto is resolved before computing the precision floor")
        }
    };

    let safe_eps = if epsilon.is_nan() {
        precision_floor
    } else {
        epsilon.max(precision_floor)
    };

    // `safe_eps == max(epsilon, floor)`, so `epsilon < safe_eps` holds exactly
    // when the request is below the precision floor. NaN always maps to `Fast`.
    let svd_strategy_actual = match svd_strategy {
        SVDStrategy::Auto => {
            if !epsilon.is_nan() && epsilon < safe_eps {
                SVDStrategy::Accurate
            } else {
                SVDStrategy::Fast
            }
        }
        other => other,
    };

    (safe_eps, twork_actual, svd_strategy_actual)
}
// Unit tests for `safe_epsilon`. They cover Auto resolution, explicit
// pass-through, clamping to the precision floor, and the NaN / zero / negative
// edge cases.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_safe_epsilon_auto_float64() {
        let (safe_eps, twork, _) = safe_epsilon(1e-7, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64);
        assert_eq!(safe_eps, 1e-7);
    }

    #[test]
    fn test_safe_epsilon_auto_float64x2() {
        let (safe_eps, twork, _) = safe_epsilon(1e-10, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64X2);
        assert_eq!(safe_eps, 1e-10);
    }

    #[test]
    fn test_safe_epsilon_auto_boundary_selects_float64() {
        // The Auto precision choice uses a strict `< 1e-8`, so exactly 1e-8 stays on f64.
        let (safe_eps, twork, strategy) =
            safe_epsilon(1e-8, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64);
        assert_eq!(safe_eps, 1e-8);
        assert_eq!(strategy, SVDStrategy::Fast);
    }

    #[test]
    fn test_safe_epsilon_explicit_precision() {
        let (safe_eps, twork, _) = safe_epsilon(1e-7, TworkType::Float64X2, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64X2);
        assert_eq!(safe_eps, 1e-7);
    }

    #[test]
    fn test_safe_epsilon_explicit_float64_clamps_to_floor() {
        // An explicit f64 request tighter than its floor is raised to 1e-8
        // and triggers the Accurate strategy.
        let (safe_eps, twork, strategy) =
            safe_epsilon(1e-12, TworkType::Float64, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64);
        assert_eq!(safe_eps, 1e-8);
        assert_eq!(strategy, SVDStrategy::Accurate);
    }

    #[test]
    fn test_safe_epsilon_nan_uses_floor() {
        let (safe_eps, twork, strategy) =
            safe_epsilon(f64::NAN, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64X2);
        assert!(safe_eps.is_finite() && safe_eps > 0.0);
        assert!(safe_eps < 1e-8);
        assert_eq!(strategy, SVDStrategy::Fast);
    }

    #[test]
    fn test_safe_epsilon_zero_is_allowed() {
        let (safe_eps, twork, strategy) = safe_epsilon(0.0, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64X2);
        assert!(safe_eps > 0.0);
        assert_eq!(strategy, SVDStrategy::Accurate);
    }

    #[test]
    fn test_svd_strategy_auto_accurate() {
        let (_, _, strategy) = safe_epsilon(1e-20, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(strategy, SVDStrategy::Accurate);
    }

    #[test]
    fn test_svd_strategy_auto_fast() {
        let (_, _, strategy) = safe_epsilon(1e-7, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(strategy, SVDStrategy::Fast);
    }

    #[test]
    fn test_svd_strategy_explicit_is_kept() {
        let (_, _, strategy) = safe_epsilon(1e-20, TworkType::Auto, SVDStrategy::Fast);
        assert_eq!(strategy, SVDStrategy::Fast);
        let (_, _, strategy) = safe_epsilon(1e-3, TworkType::Auto, SVDStrategy::Accurate);
        assert_eq!(strategy, SVDStrategy::Accurate);
    }

    #[test]
    #[should_panic(expected = "eps_required must be non-negative")]
    fn test_negative_epsilon_panics() {
        safe_epsilon(-1.0, TworkType::Auto, SVDStrategy::Auto);
    }
}