use crate::analysis::pca::{self, PcaResult};
use crate::errors::AnalysisError;
use ndarray::{s, Array1, Array2};
/// Summary of how explained variance is distributed across principal components.
///
/// Produced by [`variance_spectrum_from_pca_result`] and
/// [`VarianceSpectrum::truncated`]; every summary field is derived from `ratios`.
#[derive(Debug, Clone)]
pub struct VarianceSpectrum {
/// Per-component explained variance, copied from the PCA result.
pub explained_variance: Array1<f32>,
/// Per-component explained-variance ratio, copied from the PCA result's
/// `explained_variance_ratio`.
pub ratios: Array1<f32>,
/// Running sum of `ratios`: `cumulative[i]` is the sum of `ratios[..=i]`.
pub cumulative: Array1<f32>,
/// 1-based count of leading components at which `cumulative` first reaches
/// 0.90, or the total number of components if it never does.
pub components_90pct: usize,
/// Same as `components_90pct`, but for a 0.99 threshold.
pub components_99pct: usize,
/// Sum of the first ten entries of `ratios` (all of them if there are fewer).
pub top10_concentration: f32,
}
impl VarianceSpectrum {
pub fn truncated(&self, k: usize) -> Self {
let k = k.max(1).min(self.ratios.len());
let explained_variance = self.explained_variance.slice(s![..k]).to_owned();
let ratios = self.ratios.slice(s![..k]).to_owned();
let mut cumulative = Array1::<f32>::zeros(k);
let mut running = 0.0_f32;
for (index, ratio) in ratios.iter().enumerate() {
running += *ratio;
cumulative[index] = running;
}
let components_90pct = cumulative
.iter()
.position(|&c| c >= 0.90)
.map(|i| i + 1)
.unwrap_or(k);
let components_99pct = cumulative
.iter()
.position(|&c| c >= 0.99)
.map(|i| i + 1)
.unwrap_or(k);
let top10_concentration = ratios.iter().take(10).sum();
Self {
explained_variance,
ratios,
cumulative,
components_90pct,
components_99pct,
top10_concentration,
}
}
}
pub fn variance_spectrum_from_pca_result(result: &PcaResult) -> VarianceSpectrum {
let explained_variance = result.explained_variance.clone();
let ratios = result.explained_variance_ratio.clone();
let mut cumulative = Array1::<f32>::zeros(ratios.len());
let mut cum = 0.0_f32;
for (i, &r) in ratios.iter().enumerate() {
cum += r;
cumulative[i] = cum;
}
let components_90pct = cumulative
.iter()
.position(|&c| c >= 0.90)
.map(|i| i + 1)
.unwrap_or(ratios.len());
let components_99pct = cumulative
.iter()
.position(|&c| c >= 0.99)
.map(|i| i + 1)
.unwrap_or(ratios.len());
let top10_concentration = ratios.iter().take(10).sum();
VarianceSpectrum {
explained_variance,
ratios,
cumulative,
components_90pct,
components_99pct,
top10_concentration,
}
}
/// Runs PCA on `data`, keeping `k` components, and summarizes the resulting
/// variance spectrum via [`variance_spectrum_from_pca_result`].
///
/// `pca::pca` is invoked with a fixed third argument of 500 (presumably an
/// iteration budget — confirm against `pca::pca`).
///
/// # Errors
///
/// Propagates any error returned by `pca::pca`.
pub fn variance_spectrum(
    data: &Array2<f32>,
    k: usize,
) -> Result<VarianceSpectrum, AnalysisError> {
    Ok(variance_spectrum_from_pca_result(&pca::pca(data, k, 500)?))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_spectrum_shape() {
        let data = Array2::from_shape_fn((50, 16), |(i, j)| (i + j) as f32);
        let spec = variance_spectrum(&data, 8).unwrap();
        assert_eq!(spec.ratios.len(), 8);
        assert_eq!(spec.cumulative.len(), 8);
    }

    #[test]
    fn test_cumulative_monotone() {
        let data = Array2::from_shape_fn((40, 8), |(i, j)| (i * 2 + j) as f32);
        let spec = variance_spectrum(&data, 6).unwrap();
        for i in 1..spec.cumulative.len() {
            assert!(spec.cumulative[i] >= spec.cumulative[i - 1] - 1e-5);
        }
    }

    #[test]
    fn test_top10_concentration_range() {
        let data = Array2::from_shape_fn((60, 16), |(i, j)| (i + j) as f32);
        let spec = variance_spectrum(&data, 16).unwrap();
        assert!(spec.top10_concentration >= 0.0);
        assert!(spec.top10_concentration <= 1.0 + 1e-5);
    }

    #[test]
    fn truncation_recomputes_cumulative_fields() {
        let data = Array2::from_shape_fn((60, 16), |(i, j)| (i + j * 2) as f32);
        let spec = variance_spectrum(&data, 12).unwrap();
        let truncated = spec.truncated(5);
        assert_eq!(truncated.explained_variance.len(), 5);
        assert_eq!(truncated.ratios.len(), 5);
        assert_eq!(truncated.cumulative.len(), 5);
        assert!(truncated.cumulative[4] <= 1.0 + 1e-5);
        // The kept prefix is summed in the same order, so it must agree with the
        // untruncated cumulative values.
        for i in 0..5 {
            approx::assert_relative_eq!(
                truncated.cumulative[i],
                spec.cumulative[i],
                epsilon = 1e-6
            );
        }
        // Threshold counts are recomputed, so they can never exceed the kept length.
        assert!((1..=5).contains(&truncated.components_90pct));
        assert!(truncated.components_99pct >= truncated.components_90pct);
        assert!(truncated.components_99pct <= 5);
    }

    #[test]
    fn truncation_clamps_k_to_available_components() {
        let data = Array2::from_shape_fn((60, 16), |(i, j)| (i + j * 2) as f32);
        let spec = variance_spectrum(&data, 6).unwrap();
        assert_eq!(spec.truncated(0).ratios.len(), 1);
        assert_eq!(spec.truncated(100).ratios.len(), spec.ratios.len());
    }

    #[test]
    fn spectrum_from_pca_result_matches_direct_computation() {
        let data = Array2::from_shape_fn((60, 16), |(i, j)| (i + j * 3) as f32);
        // Must match the budget `variance_spectrum` passes to `pca::pca` (500);
        // otherwise the two sides are computed with different settings.
        let result = pca::pca(&data, 8, 500).unwrap();
        let from_result = variance_spectrum_from_pca_result(&result);
        let direct = variance_spectrum(&data, 8).unwrap();
        assert_eq!(
            from_result.explained_variance.len(),
            direct.explained_variance.len()
        );
        assert_eq!(from_result.cumulative.len(), direct.cumulative.len());
        for (a, b) in from_result.cumulative.iter().zip(direct.cumulative.iter()) {
            approx::assert_relative_eq!(*a, *b, epsilon = 1e-5);
        }
        assert_eq!(from_result.components_90pct, direct.components_90pct);
        assert_eq!(from_result.components_99pct, direct.components_99pct);
        approx::assert_relative_eq!(
            from_result.top10_concentration,
            direct.top10_concentration,
            epsilon = 1e-5
        );
    }
}