use nalgebra::{DMatrix, DVector, SymmetricEigen};
use rand::rngs::StdRng;
use rand::SeedableRng;
use serde::{Deserialize, Serialize};
use super::nystrom::{GaussianSketch, NystromApprox};
/// Configuration for [`solve_eht_randnla`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RandNlaConfig {
    /// Requested sketch width `k`. `None` means `GaussianSketch::default_k(n)`
    /// is used. Any value is clamped to `n`, and a (clamped) `k` equal to `n`
    /// routes the solve through the exact dense solver instead of the
    /// randomized one.
    pub sketch_size: Option<usize>,
    /// Seed for the `StdRng` handed to `GaussianSketch::new` when the
    /// randomized path is taken.
    pub seed: u64,
    /// Tolerance on the relative residual `‖HC − SCE‖_F / ‖SCE‖_F` of the
    /// randomized solution; above it the result is rejected (only when
    /// `fallback_enabled` is `true`).
    pub max_error: f64,
    /// When `true`, a randomized result whose residual exceeds `max_error`
    /// is replaced by the exact dense solution.
    pub fallback_enabled: bool,
}
impl Default for RandNlaConfig {
    /// Default settings: automatic sketch size (`None`), fixed seed `42`,
    /// relative-residual tolerance `1e-3`, and exact fallback enabled.
    fn default() -> Self {
        let sketch_size = None;
        let seed = 42;
        let max_error = 1e-3;
        let fallback_enabled = true;
        Self {
            sketch_size,
            seed,
            max_error,
            fallback_enabled,
        }
    }
}
/// Diagnostics returned by [`solve_eht_randnla`] alongside the eigenpairs.
#[derive(Debug, Clone)]
pub struct RandNlaInfo {
    /// Sketch size after clamping to `n`; reported on every path, including
    /// the exact one.
    pub k: usize,
    /// Relative residual `‖HC − SCE‖_F / ‖SCE‖_F` (denominator floored at
    /// `1e-15`). NOTE: when `used_fallback` is `true`, this is the residual of
    /// the *rejected* randomized solution, not of the exact solution that was
    /// actually returned.
    pub residual_error: f64,
    /// `true` if the randomized result was rejected and replaced by the exact
    /// dense solve.
    pub used_fallback: bool,
}
/// Solves the generalized symmetric eigenproblem `H C = S C E` using a
/// randomized approximation of `S^{-1/2}`, with optional exact fallback.
///
/// Steps:
/// 1. Choose the sketch size `k` (`config.sketch_size`, or
///    `GaussianSketch::default_k(n)` when `None`) and clamp it to `n`.
///    If `k == n`, the exact dense solver is used directly.
/// 2. Otherwise draw a `GaussianSketch` from an RNG seeded with
///    `config.seed`, build a `NystromApprox` of `S` from it and take its
///    `inverse_sqrt()` as `X ≈ S^{-1/2}`.
/// 3. Diagonalize `X H X`, back-transform the eigenvectors as `C = X C'`,
///    and sort the eigenpairs by ascending energy.
/// 4. Compute the relative residual `‖HC − SCE‖_F / ‖SCE‖_F`. If
///    `config.fallback_enabled` and the residual exceeds `config.max_error`
///    (or is NaN), recompute the solution exactly.
///
/// # Returns
/// `(energies, coefficients, info)`. `energies` are ascending, and column `j`
/// of `coefficients` is the eigenvector for `energies[j]`. When the exact
/// fallback runs, `info.used_fallback` is `true` and `info.residual_error`
/// holds the residual of the *rejected* randomized solution.
///
/// # Panics
/// If `h` is not square, or `s` does not have the same shape as `h`.
pub fn solve_eht_randnla(
    h: &DMatrix<f64>,
    s: &DMatrix<f64>,
    config: &RandNlaConfig,
) -> (DVector<f64>, DMatrix<f64>, RandNlaInfo) {
    let n = h.nrows();
    assert_eq!(h.ncols(), n, "solve_eht_randnla: H must be square");
    assert_eq!(
        s.shape(),
        (n, n),
        "solve_eht_randnla: S must have the same shape as H"
    );

    let k = config
        .sketch_size
        .unwrap_or_else(|| GaussianSketch::default_k(n))
        .min(n);
    if k >= n {
        // A full-width sketch buys nothing over the exact solver.
        return solve_exact_with_info(h, s, k);
    }

    // The RNG is only needed on the randomized path.
    let mut rng = StdRng::seed_from_u64(config.seed);
    let sketch = GaussianSketch::new(&mut rng, n, k);
    let nystrom = NystromApprox::from_matrix(s, &sketch);
    let s_inv_sqrt = nystrom.inverse_sqrt();

    let h_prime = &s_inv_sqrt * h * &s_inv_sqrt;
    let h_eigen = SymmetricEigen::new(h_prime);
    // Move the eigenvectors out of the decomposition instead of cloning them.
    let c = &s_inv_sqrt * h_eigen.eigenvectors;
    let (sorted_energies, sorted_c) = sort_eigenpairs(&h_eigen.eigenvalues, &c);
    let residual = compute_residual(h, s, &sorted_energies, &sorted_c);

    // NaN compares false with everything, so `residual > max_error` alone
    // would silently accept a broken randomized solve; treat NaN as failure.
    let rejected = residual.is_nan() || residual > config.max_error;
    if config.fallback_enabled && rejected {
        let (e_exact, c_exact, mut info) = solve_exact_with_info(h, s, k);
        info.used_fallback = true;
        // Report the residual that triggered the fallback.
        info.residual_error = residual;
        return (e_exact, c_exact, info);
    }

    let info = RandNlaInfo {
        k,
        residual_error: residual,
        used_fallback: false,
    };
    (sorted_energies, sorted_c, info)
}
/// Exact dense solve of `H C = S C E` by symmetric orthogonalization.
///
/// `S^{-1/2}` is formed from the eigendecomposition `S = V Λ Vᵀ` as
/// `V diag(λ^{-1/2}) Vᵀ`. Eigenvalues of `S` at or below `1e-10` get a zero
/// factor, so their directions are projected out rather than amplified.
/// `H' = S^{-1/2} H S^{-1/2}` is then diagonalized, its eigenvectors are
/// back-transformed as `C = S^{-1/2} C'`, and the eigenpairs are returned in
/// ascending order of energy.
///
/// `k` takes no part in the computation; it is only recorded in the returned
/// [`RandNlaInfo`], which always has `used_fallback = false`.
///
/// # Panics
/// If `h` is not square, or `s` does not have the same shape as `h`.
fn solve_exact_with_info(
    h: &DMatrix<f64>,
    s: &DMatrix<f64>,
    k: usize,
) -> (DVector<f64>, DMatrix<f64>, RandNlaInfo) {
    // Eigenvalues of S at or below this cutoff are treated as zero.
    const S_EIGEN_CUTOFF: f64 = 1e-10;

    let n = h.nrows();
    assert_eq!(h.ncols(), n, "solve_exact_with_info: H must be square");
    assert_eq!(
        s.shape(),
        (n, n),
        "solve_exact_with_info: S must have the same shape as H"
    );

    let s_eigen = SymmetricEigen::new(s.clone());
    let inv_sqrt = s_eigen.eigenvalues.map(|val| {
        if val > S_EIGEN_CUTOFF {
            1.0 / val.sqrt()
        } else {
            0.0
        }
    });
    let s_vecs = &s_eigen.eigenvectors;
    // V · diag(λ^{-1/2}): scale each column of V directly. Multiplying by a
    // dense n×n diagonal matrix gives the same values but costs an extra
    // O(n³) product.
    let scaled_vecs = DMatrix::from_fn(n, n, |i, j| s_vecs[(i, j)] * inv_sqrt[j]);
    let s_inv_sqrt = scaled_vecs * s_vecs.transpose();

    let h_prime = &s_inv_sqrt * h * &s_inv_sqrt;
    let h_eigen = SymmetricEigen::new(h_prime);
    // Move the eigenvectors out of the decomposition instead of cloning them.
    let c = &s_inv_sqrt * h_eigen.eigenvectors;
    let (sorted_energies, sorted_c) = sort_eigenpairs(&h_eigen.eigenvalues, &c);
    let residual = compute_residual(h, s, &sorted_energies, &sorted_c);
    let info = RandNlaInfo {
        k,
        residual_error: residual,
        used_fallback: false,
    };
    (sorted_energies, sorted_c, info)
}
fn sort_eigenpairs(energies: &DVector<f64>, c: &DMatrix<f64>) -> (DVector<f64>, DMatrix<f64>) {
let n = energies.len();
let mut indices: Vec<usize> = (0..n).collect();
indices.sort_by(|&a, &b| energies[a].partial_cmp(&energies[b]).unwrap());
let mut sorted_energies = DVector::zeros(n);
let mut sorted_c = DMatrix::zeros(c.nrows(), n);
for (new_idx, &old_idx) in indices.iter().enumerate() {
sorted_energies[new_idx] = energies[old_idx];
for row in 0..c.nrows() {
sorted_c[(row, new_idx)] = c[(row, old_idx)];
}
}
(sorted_energies, sorted_c)
}
fn compute_residual(
h: &DMatrix<f64>,
s: &DMatrix<f64>,
energies: &DVector<f64>,
c: &DMatrix<f64>,
) -> f64 {
let n = energies.len();
let sc = s * c;
let mut sce = sc.clone();
for j in 0..n {
for i in 0..sce.nrows() {
sce[(i, j)] *= energies[j];
}
}
let hc = h * c;
let diff = &hc - ⪰
let mut diff_norm = 0.0;
let mut sce_norm = 0.0;
for j in 0..n {
for i in 0..diff.nrows() {
diff_norm += diff[(i, j)] * diff[(i, j)];
sce_norm += sce[(i, j)] * sce[(i, j)];
}
}
diff_norm.sqrt() / sce_norm.sqrt().max(1e-15)
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;

    /// Builds a reproducible toy `(H, S)` pair of size `n` from seed 42.
    ///
    /// `S` is the identity plus symmetric off-diagonal noise drawn from
    /// `[-0.3, 0.3)` and damped by `exp(-0.1 * |i - j|)`. `H` is symmetric,
    /// with diagonal drawn from `[-20, -5)` and off-diagonals from `[-3, 0)`.
    /// Positive definiteness of `S` is not enforced here.
    /// NOTE(review): confirm it holds for the sizes used below.
    fn make_toy_h_s(n: usize) -> (DMatrix<f64>, DMatrix<f64>) {
        let mut rng = StdRng::seed_from_u64(42);
        let mut s = DMatrix::identity(n, n);
        for i in 0..n {
            for j in (i + 1)..n {
                let decay = (-0.1 * (j - i) as f64).exp();
                let off = rng.gen_range(-0.3..0.3) * decay;
                s[(i, j)] = off;
                s[(j, i)] = off;
            }
        }
        let mut h = DMatrix::zeros(n, n);
        for i in 0..n {
            h[(i, i)] = rng.gen_range(-20.0..-5.0);
            for j in (i + 1)..n {
                let off = rng.gen_range(-3.0..0.0);
                h[(i, j)] = off;
                h[(j, i)] = off;
            }
        }
        (h, s)
    }

    /// `sketch_size == n` takes the exact path (no fallback involved).
    #[test]
    fn test_exact_path_small_system() {
        let (h, s) = make_toy_h_s(6);
        let config = RandNlaConfig {
            sketch_size: Some(6),
            seed: 42,
            ..Default::default()
        };
        let (energies, _c, info) = solve_eht_randnla(&h, &s, &config);
        assert!(!info.used_fallback);
        assert!(info.residual_error < 1e-10);
        for i in 1..energies.len() {
            assert!(energies[i] >= energies[i - 1] - 1e-12);
        }
    }

    /// A sketch larger than `n` is clamped to `n` and solved exactly.
    #[test]
    fn test_sketch_size_clamped_to_n() {
        let n = 8;
        let (h, s) = make_toy_h_s(n);
        let config = RandNlaConfig {
            sketch_size: Some(n + 5),
            ..Default::default()
        };
        let (energies, c, info) = solve_eht_randnla(&h, &s, &config);
        assert_eq!(info.k, n);
        assert!(!info.used_fallback);
        assert_eq!(energies.len(), n);
        assert_eq!(c.shape(), (n, n));
    }

    #[test]
    fn test_randomized_vs_exact() {
        let n = 20;
        let (h, s) = make_toy_h_s(n);
        let config_exact = RandNlaConfig {
            sketch_size: Some(n),
            seed: 42,
            ..Default::default()
        };
        let (e_exact, _, _) = solve_eht_randnla(&h, &s, &config_exact);
        let config_rand = RandNlaConfig {
            sketch_size: Some(17),
            seed: 42,
            max_error: 1.0,
            fallback_enabled: false,
        };
        let (e_rand, _, info) = solve_eht_randnla(&h, &s, &config_rand);
        assert!(!info.used_fallback);
        assert_eq!(info.k, 17);
        let homo_exact = e_exact[n / 2 - 1];
        let homo_rand = e_rand[n / 2 - 1];
        let rel_diff = ((homo_exact - homo_rand) / homo_exact).abs();
        assert!(
            rel_diff < 0.10,
            "HOMO relative error = {:.4}%, expected < 10%",
            rel_diff * 100.0
        );
    }

    #[test]
    fn test_residual_computation() {
        let n = 10;
        let (h, s) = make_toy_h_s(n);
        let config = RandNlaConfig {
            sketch_size: Some(n),
            seed: 42,
            ..Default::default()
        };
        let (_, _, info) = solve_eht_randnla(&h, &s, &config);
        assert!(
            info.residual_error < 1e-8,
            "Exact solve residual = {}, expected < 1e-8",
            info.residual_error
        );
    }

    /// The fallback decision must agree with the reported residual, and a
    /// fallback result must be the exact solution.
    #[test]
    fn test_fallback_on_poor_approximation() {
        let n = 10;
        let (h, s) = make_toy_h_s(n);
        let config = RandNlaConfig {
            sketch_size: Some(3),
            seed: 42,
            max_error: 0.001,
            fallback_enabled: true,
        };
        let (energies, _, info) = solve_eht_randnla(&h, &s, &config);
        assert_eq!(info.k, 3);

        if info.used_fallback {
            assert!(
                info.residual_error.is_nan() || info.residual_error > config.max_error,
                "fallback taken although residual {} <= max_error {}",
                info.residual_error,
                config.max_error
            );
            let exact_config = RandNlaConfig {
                sketch_size: Some(n),
                seed: 42,
                ..Default::default()
            };
            let (e_exact, _, _) = solve_eht_randnla(&h, &s, &exact_config);
            assert_eq!(energies.len(), e_exact.len());
            for (a, b) in energies.iter().zip(e_exact.iter()) {
                assert!((a - b).abs() < 1e-12, "fallback energy {a} != exact {b}");
            }
        } else {
            assert!(
                info.residual_error <= config.max_error,
                "randomized result kept although residual {} > max_error {}",
                info.residual_error,
                config.max_error
            );
        }
    }
}