use crate::core::{IgraphError, IgraphResult};
/// Chooses how many leading singular values are "significant" using the
/// profile-likelihood elbow criterion of Zhu & Ghodsi (2006).
///
/// For every split point `k` in `1..n`, the values are modelled as two groups,
/// `sv[..k]` and `sv[k..]`, each normally distributed around its own mean with a
/// shared (pooled) variance. The "no split" model puts all `n` values in one
/// group. The model with the largest profile log-likelihood wins: the function
/// returns its split point `k`, or `n` if the single-group model wins. Ties go
/// to the smallest `k`. Values are used in the order given; they are not sorted.
///
/// In every model the residual term of the log-likelihood is the sum of squared
/// deviations divided by twice the fitted variance, which is exactly
/// `(n - 2) / 2` for a split and `(n - 1) / 2` for the single group. The profile
/// is therefore evaluated in that closed form; recomputing the residual from
/// running sums cancels catastrophically when a group has (almost) no spread.
///
/// Edge cases:
/// * a single value yields `1`;
/// * if every value is identical there is no elbow and `n` is returned;
/// * a split whose two groups are each constant has zero pooled variance and an
///   infinite profile, so the first such split is always chosen;
/// * a model whose profile is NaN (for example the only split of a two-element
///   input, which has zero degrees of freedom) is never chosen; if no model is
///   comparable, `n` is returned. Non-finite inputs are not validated.
///
/// # Errors
///
/// Returns [`IgraphError::InvalidArgument`] if `sv` is empty.
// `usize as f64` is exact for any slice length that fits in memory (< 2^53), and
// the exact `==` in the constant-input check is intentional.
#[allow(clippy::cast_precision_loss, clippy::float_cmp)]
pub fn dim_select(sv: &[f64]) -> IgraphResult<usize> {
    let n = sv.len();
    if n == 0 {
        return Err(IgraphError::InvalidArgument(
            "Need at least one singular value for dimensionality selection".to_string(),
        ));
    }
    if n == 1 {
        return Ok(1);
    }
    // A constant sequence has no elbow: keep every dimension. Deciding this up
    // front stops rounding noise in the running means from "finding" a split.
    let first = sv[0];
    if sv.iter().all(|&x| x == first) {
        return Ok(n);
    }

    let nf = n as f64;
    // Running sums and means of the left group `sv[..=i]` (1) and the right
    // group `sv[i + 1..]` (2); the right group starts out as the whole input.
    let mut sum1 = 0.0_f64;
    let mut sum2: f64 = sv.iter().sum();
    let mut mean1 = 0.0_f64;
    let mut mean2 = sum2 / nf;
    // Sums of squared deviations from the group means, maintained with
    // Welford's update (left group) and the matching downdate (right group).
    let mut ss1 = 0.0_f64;
    let mut ss2: f64 = sv.iter().map(|&x| (mean2 - x) * (mean2 - x)).sum();

    let mut best = f64::NEG_INFINITY;
    let mut dim = n;
    // Split after position i: the left group holds i + 1 values.
    for (i, &x) in sv[..n - 1].iter().enumerate() {
        let n1 = (i + 1) as f64;
        let n2 = (n - i - 1) as f64;
        sum1 += x;
        sum2 -= x;
        let oldmean1 = mean1;
        let oldmean2 = mean2;
        mean1 = sum1 / n1;
        mean2 = sum2 / n2;
        ss1 += (x - oldmean1) * (x - mean1);
        ss2 -= (x - oldmean2) * (x - mean2);
        // A one-element group has no spread (and no degrees of freedom).
        let var1 = if i == 0 { 0.0 } else { ss1 / (n1 - 1.0) };
        let var2 = if i == n - 2 { 0.0 } else { ss2 / (n2 - 1.0) };
        let pooled = ((n1 - 1.0) * var1 + (n2 - 1.0) * var2) / (nf - 2.0);
        // Rounding (mainly in the downdate) can push a true zero slightly
        // negative; clamp so sqrt yields 0 (infinite profile) rather than NaN.
        // `f64::max` is avoided because it would also turn a NaN pooled
        // variance (n == 2: zero degrees of freedom) into 0.
        let pooled = if pooled < 0.0 { 0.0 } else { pooled };
        let profile = -nf * pooled.sqrt().ln() - (nf - 2.0) / 2.0;
        if profile > best {
            best = profile;
            dim = i + 1;
        }
    }

    // The "no split" model: all n values in a single group.
    let x = sv[n - 1];
    let mean_all = (sum1 + x) / nf;
    let ss_all = ss1 + (x - mean1) * (x - mean_all);
    let var_all = ss_all / (nf - 1.0);
    let var_all = if var_all < 0.0 { 0.0 } else { var_all };
    let profile = -nf * var_all.sqrt().ln() - (nf - 1.0) / 2.0;
    if profile > best {
        dim = n;
    }
    Ok(dim)
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn empty_input_errors() {
        assert!(dim_select(&[]).is_err());
    }
    #[test]
    fn single_value_is_one() {
        assert_eq!(dim_select(&[42.0]).unwrap(), 1);
    }
    #[test]
    fn ascending_ramp_splits_at_midpoint() {
        let sv: Vec<f64> = (1..=100).map(f64::from).collect();
        assert_eq!(dim_select(&sv).unwrap(), 50);
    }
    #[test]
    fn small_ramp_anchor() {
        let sv: Vec<f64> = (1..=10).map(f64::from).collect();
        assert_eq!(dim_select(&sv).unwrap(), 5);
    }
    #[test]
    fn clear_gap_is_detected() {
        let sv = [100.0, 99.0, 98.0, 1.0, 0.9, 0.8, 0.7, 0.6];
        assert_eq!(dim_select(&sv).unwrap(), 3);
    }
    #[test]
    fn two_values_returns_two() {
        assert_eq!(dim_select(&[2.0, 1.0]).unwrap(), 2);
    }
    #[test]
    fn result_within_bounds() {
        let sv = [5.0, 4.0, 3.0, 2.0, 1.0];
        let d = dim_select(&sv).unwrap();
        assert!((1..=sv.len()).contains(&d));
    }
    #[test]
    fn perfectly_separated_groups_pick_the_split() {
        // For [2, 2, 1, 1] the downdated right-hand sum of squares rounds to
        // -2^-53 at the ideal split; it must still win rather than become NaN.
        assert_eq!(dim_select(&[2.0, 2.0, 1.0, 1.0]).unwrap(), 2);
        assert_eq!(dim_select(&[5.0, 5.0, 1.0, 1.0]).unwrap(), 2);
        assert_eq!(dim_select(&[10.0, 10.0, 10.0, 1.0, 1.0, 1.0]).unwrap(), 3);
    }
    #[test]
    fn constant_input_keeps_all_dimensions() {
        assert_eq!(dim_select(&[3.0, 3.0, 3.0]).unwrap(), 3);
        assert_eq!(dim_select(&[0.1; 4]).unwrap(), 4);
        assert_eq!(dim_select(&[7.0, 7.0]).unwrap(), 2);
    }
}