use vyre_primitives::math::sheaf_laplacian_eigenvalue::cpu_ref_into;
/// Iteration count handed to the dominant-spectrum solver by `spectral_gap_into` and
/// `suggested_cluster_count_into` (and therefore by their owned-scratch wrappers).
pub const DEFAULT_POWER_ITERATIONS: u32 = 32;
/// Reusable working buffers for the dominant-spectrum routines.
///
/// Holding on to one of these across calls lets the `_with_scratch` / `_into` entry points
/// reuse the same three allocations instead of building fresh vectors on every solve.
#[derive(Debug, Default)]
pub struct SheafSpectrumScratch {
    /// Starting vector handed to the solver.
    v_init: Vec<f64>,
    /// Output vector of the solver; exposed read-only via `eigenvector()`.
    v: Vec<f64>,
    /// Secondary work buffer handed to the solver.
    v_next: Vec<f64>,
}

impl SheafSpectrumScratch {
    /// Borrow the `v` buffer written by the most recent solve.
    ///
    /// The slice is empty for a freshly defaulted scratch and after a solve on an empty
    /// diagonal (which clears every buffer).
    #[must_use]
    pub fn eigenvector(&self) -> &[f64] {
        self.v.as_slice()
    }
}
#[must_use]
pub fn dominant_spectrum(restriction_diag: &[f64], iterations: u32) -> (f64, Vec<f64>) {
use crate::observability::{bump, sheaf_spectral_clustering_calls};
bump(&sheaf_spectral_clustering_calls);
let mut scratch = SheafSpectrumScratch::default();
let lambda = dominant_spectrum_with_scratch(restriction_diag, iterations, &mut scratch);
(lambda, scratch.v)
}
pub fn dominant_spectrum_with_scratch(
restriction_diag: &[f64],
iterations: u32,
scratch: &mut SheafSpectrumScratch,
) -> f64 {
dominant_spectrum_into(
restriction_diag,
iterations,
&mut scratch.v_init,
&mut scratch.v,
&mut scratch.v_next,
)
}
/// Estimate the dominant eigenvalue using caller-supplied buffers.
///
/// `v_init` is overwritten with `n` copies of `1/sqrt(n)` (a unit-norm uniform vector),
/// where `n = restriction_diag.len()`. It is then passed together with `iterations`, `v`
/// and `v_next` to `cpu_ref_into`, whose return value is returned unchanged. Presumably that
/// is the eigenvalue estimate; see `sheaf_laplacian_eigenvalue::cpu_ref_into`.
///
/// For an empty diagonal all three buffers are cleared and `0.0` is returned without
/// calling the solver.
pub fn dominant_spectrum_into(
    restriction_diag: &[f64],
    iterations: u32,
    v_init: &mut Vec<f64>,
    v: &mut Vec<f64>,
    v_next: &mut Vec<f64>,
) -> f64 {
    let len = restriction_diag.len();

    // `v_init` is rebuilt from scratch on every path, so clear it up front.
    v_init.clear();
    if len == 0 {
        v.clear();
        v_next.clear();
        return 0.0;
    }

    let uniform = (len as f64).sqrt().recip();
    v_init.resize(len, uniform);
    cpu_ref_into(restriction_diag, v_init, iterations, v, v_next)
}
/// Owned-scratch convenience wrapper around [`spectral_gap_into`].
///
/// Allocates a temporary [`SheafSpectrumScratch`] for the single call; prefer the `_into`
/// variant when computing many gaps.
#[must_use]
pub fn spectral_gap(restriction_diag: &[f64]) -> f64 {
    spectral_gap_into(restriction_diag, &mut SheafSpectrumScratch::default())
}
/// Ratio of the dominant eigenvalue to the largest diagonal entry, clamped to `[0, 1]`.
///
/// Runs the dominant-spectrum solve with [`DEFAULT_POWER_ITERATIONS`] iterations using
/// `scratch`, so the eigenvector is left in `scratch` as a side effect. The normaliser is
/// the maximum of `0.0` and all diagonal entries (NaN entries are ignored by `f64::max`).
/// When that normaliser is at most `1e-20`, the ratio is treated as undefined and `0.0` is
/// returned.
pub fn spectral_gap_into(restriction_diag: &[f64], scratch: &mut SheafSpectrumScratch) -> f64 {
    // The solve always runs first so `scratch` is refreshed even on the degenerate path.
    let dominant =
        dominant_spectrum_with_scratch(restriction_diag, DEFAULT_POWER_ITERATIONS, scratch);
    let largest_entry = restriction_diag.iter().copied().fold(0.0_f64, f64::max);

    if largest_entry <= 1e-20 {
        return 0.0;
    }
    (dominant / largest_entry).clamp(0.0, 1.0)
}
/// Owned-scratch convenience wrapper around [`suggested_cluster_count_into`].
///
/// Allocates a temporary [`SheafSpectrumScratch`] for the single call; prefer the `_into`
/// variant in loops.
#[must_use]
pub fn suggested_cluster_count(restriction_diag: &[f64]) -> u32 {
    suggested_cluster_count_into(restriction_diag, &mut SheafSpectrumScratch::default())
}
/// Heuristic cluster count: the number of sign runs in the dominant eigenvector.
///
/// Runs the dominant-spectrum solve with [`DEFAULT_POWER_ITERATIONS`] iterations, reusing
/// `scratch`. It then walks the resulting eigenvector left to right. The count starts at one
/// and increases by one each time a nonzero component's sign differs from the sign of the
/// most recent nonzero component.
///
/// Components that are exactly zero (`+0.0` or `-0.0`) or NaN carry no sign. They never
/// start a new run and never reset the current one.
///
/// Returns `0` for an empty diagonal and at least `1` otherwise. After the call,
/// `scratch.eigenvector()` holds the vector that was inspected.
pub fn suggested_cluster_count_into(
    restriction_diag: &[f64],
    scratch: &mut SheafSpectrumScratch,
) -> u32 {
    // Strict three-way sign. `f64::signum` is unsuitable here because it returns ±1.0 for
    // ±0.0 (never 0.0) and NaN for NaN. With `signum`, zero components were treated as
    // signed, so `-0.0` next to a positive value counted as a spurious boundary, and each
    // NaN component added two boundaries.
    fn strict_sign(x: f64) -> i8 {
        if x > 0.0 {
            1
        } else if x < 0.0 {
            -1
        } else {
            0 // ±0.0 and NaN
        }
    }

    // Only the eigenvector left in `scratch` is needed; the eigenvalue is discarded.
    let _ = dominant_spectrum_with_scratch(restriction_diag, DEFAULT_POWER_ITERATIONS, scratch);

    let (first, rest) = match scratch.eigenvector().split_first() {
        Some((&first, rest)) => (first, rest),
        None => return 0,
    };

    let mut count: u32 = 1;
    // 0 means "no signed component seen yet".
    let mut last_sign = strict_sign(first);
    for &x in rest {
        let sign = strict_sign(x);
        if sign == 0 {
            continue;
        }
        if last_sign != 0 && sign != last_sign {
            count = count.saturating_add(1);
        }
        last_sign = sign;
    }
    count
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance comparison scaled by the magnitudes of both operands.
    fn approx_eq(a: f64, b: f64) -> bool {
        let tolerance = 1e-3 * (1.0 + a.abs() + b.abs());
        (a - b).abs() < tolerance
    }

    #[test]
    fn dominant_eigenvalue_of_uniform_diag_is_diag_value() {
        let diag = [0.7; 4];
        let (lambda, _) = dominant_spectrum(&diag, 64);
        assert!(approx_eq(lambda, 0.7), "got lambda={lambda}");
    }

    #[test]
    fn dominant_eigenvalue_of_nonuniform_picks_max() {
        let diag = [0.1, 0.5, 0.9, 0.3];
        let (lambda, v) = dominant_spectrum(&diag, 128);
        assert!((lambda - 0.9).abs() < 0.01, "got lambda={lambda}");

        // The eigenvector should concentrate on the entry holding the largest diagonal value.
        let dominant_index = v
            .iter()
            .map(|component| component.abs())
            .enumerate()
            .max_by(|(_, lhs), (_, rhs)| lhs.partial_cmp(rhs).unwrap())
            .map_or(0, |(index, _)| index);
        assert_eq!(dominant_index, 2);
    }

    #[test]
    fn empty_input_returns_zero_spectrum() {
        let (lambda, v) = dominant_spectrum(&[], 32);
        assert_eq!(lambda, 0.0);
        assert!(v.is_empty());
    }

    #[test]
    fn spectral_gap_is_one_for_uniform_diag() {
        let gap = spectral_gap(&[0.5; 8]);
        assert!((gap - 1.0).abs() < 1e-3);
    }

    #[test]
    fn scratch_paths_match_owned_spectral_helpers() {
        let diag = [0.1, 0.5, 0.9, 0.3];
        let mut scratch = SheafSpectrumScratch::default();

        let (owned_lambda, owned_v) = dominant_spectrum(&diag, 64);
        let borrowed_lambda = dominant_spectrum_with_scratch(&diag, 64, &mut scratch);
        assert!(approx_eq(owned_lambda, borrowed_lambda));
        assert_eq!(scratch.eigenvector().len(), owned_v.len());

        let owned_gap = spectral_gap(&diag);
        let scratch_gap = spectral_gap_into(&diag, &mut scratch);
        assert!(approx_eq(owned_gap, scratch_gap));

        let owned_count = suggested_cluster_count(&diag);
        let scratch_count = suggested_cluster_count_into(&diag, &mut scratch);
        assert_eq!(owned_count, scratch_count);
    }

    #[test]
    fn suggested_cluster_count_at_least_one() {
        assert!(suggested_cluster_count(&[0.7; 4]) >= 1);
    }
}