use ndarray::{Array1, Array2, ArrayView2};
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, StandardNormal};
use crate::lasso::subsampling::{SubsamplingScheme, get_row_indices, select_rows};
/// Default cap on the number of power-iteration rounds, used by
/// `find_largest_singular_value` when the caller passes `maxits = None`.
const LIPSCHITZ_MAXITS: usize = 20;
/// Default relative-change threshold for early termination, used by
/// `find_largest_singular_value` when the caller passes `tol = None`.
// NOTE(review): the `LIPSCHITZ_` prefix suggests the estimate feeds a Lipschitz-constant /
// step-size computation elsewhere in the crate — confirm against the callers.
const LIPSCHITZ_TOL: f64 = 5e-3;
/// Performs one power-iteration step on the Gram matrix `xᵀx`.
///
/// The product `xᵀ(x v)` is formed as two matrix-vector multiplications, so the Gram
/// matrix itself is never materialised.
///
/// Returns the resulting vector scaled to unit Euclidean length, together with its length
/// before scaling. When that length is zero the vector is returned as-is (all zeros) to
/// avoid dividing by zero.
fn power_iteration(x: &ArrayView2<f64>, v: &Array1<f64>) -> (Array1<f64>, f64) {
    let projected = x.dot(v);
    let mut next: Array1<f64> = x.t().dot(&projected);
    let magnitude = norm2(&next);
    if magnitude > 0.0 {
        next.mapv_inplace(|entry| entry / magnitude);
    }
    (next, magnitude)
}
/// Performs one power-iteration step using only a subset of the rows of `x`.
///
/// Row indices are drawn with `get_row_indices(x.nrows(), scheme, rng)`, the matching rows
/// are extracted with `select_rows`, and a regular [`power_iteration`] step is run on that
/// sub-matrix starting from `v`.
///
/// The returned magnitude is divided by the fraction of rows that were used.
/// NOTE(review): presumably this compensates for the Gram matrix of a row subsample being
/// roughly that fraction of the full Gram matrix — confirm against the subsampling design.
///
/// If the fraction is not strictly positive (an empty subsample, or `x` has no rows) the
/// rescaled magnitude is reported as `0.0` instead of the `NaN` that `0.0 / 0.0` would
/// produce.
///
/// Returns `(next_vector, rescaled_magnitude)`.
fn subsampled_power_iteration(
    x: &Array2<f64>,
    v: &Array1<f64>,
    scheme: &SubsamplingScheme,
    rng: &mut ChaCha8Rng,
) -> (Array1<f64>, f64) {
    let indices = get_row_indices(x.nrows(), scheme, rng);
    let x_sub = select_rows(&x.view(), &indices);
    let (v_new, s) = power_iteration(&x_sub.view(), v);

    let frac = indices.len() as f64 / x.nrows() as f64;
    // `frac > 0.0` is false for both 0.0 and NaN, so every degenerate case lands in `else`.
    let rescaled = if frac > 0.0 { s / frac } else { 0.0 };
    (v_new, rescaled)
}
/// Euclidean (L2) norm of `v`: the square root of the dot product of `v` with itself.
#[inline]
fn norm2(v: &Array1<f64>) -> f64 {
    let squared_length = v.dot(v);
    squared_length.sqrt()
}
/// Estimates the largest singular value of `x` by (optionally subsampled) power iteration
/// on `xᵀx`.
///
/// A random start vector is drawn from a ChaCha8 generator seeded with `seed` and scaled to
/// unit length. The same generator is then passed to every subsampled step, so the result
/// is reproducible for a fixed set of arguments.
///
/// Every round produces an estimate of the dominant eigenvalue of `xᵀx`. An estimate is
/// accepted, together with its vector, only if it exceeds the best estimate so far;
/// otherwise it is discarded and the next round restarts from the best vector. From the
/// second round on, iteration stops as soon as the relative difference between the new
/// estimate and the best accepted one drops below `tol`.
///
/// # Arguments
/// * `x` - matrix whose largest singular value is wanted.
/// * `seed` - seed for the random start vector and for row subsampling.
/// * `scheme` - row-subsampling scheme forwarded to each power-iteration step.
/// * `maxits` - maximum number of rounds; defaults to `LIPSCHITZ_MAXITS`.
/// * `tol` - relative-change threshold for early exit; defaults to `LIPSCHITZ_TOL`.
///
/// # Returns
/// The square root of the best eigenvalue estimate. The result is `0.0` when `x` has no
/// rows or no columns, when no round produced a positive estimate, or when `maxits` is `0`.
pub fn find_largest_singular_value(
    x: &Array2<f64>,
    seed: u64,
    scheme: &SubsamplingScheme,
    maxits: Option<usize>,
    tol: Option<f64>,
) -> f64 {
    let maxits = maxits.unwrap_or(LIPSCHITZ_MAXITS);
    let tol = tol.unwrap_or(LIPSCHITZ_TOL);

    let p = x.ncols();
    // A matrix without entries has no non-zero singular value.
    if x.nrows() == 0 || p == 0 {
        return 0.0;
    }

    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    let normal = StandardNormal;
    let raw: Vec<f64> = (0..p).map(|_| normal.sample(&mut rng)).collect();
    let mut v = Array1::from_vec(raw);
    let n0 = norm2(&v);
    if n0 > 0.0 {
        v.mapv_inplace(|entry| entry / n0);
    }

    // Best eigenvalue estimate of xᵀx accepted so far. It must start at zero rather than at
    // the norm of the random start vector: that norm is unrelated to the spectrum of `x`
    // and would block every estimate smaller than it from ever being accepted.
    let mut s_best = 0.0_f64;

    for i in 0..maxits {
        let (v_new, s_new) = subsampled_power_iteration(x, &v, scheme, &mut rng);

        // Relative change against the last accepted estimate. This is NaN when both values
        // are zero (or `s_new` is NaN); the comparison below is then false, so iteration
        // simply continues.
        let improvement = (s_new - s_best).abs() / s_new.abs().max(s_best.abs());
        if i > 0 && improvement < tol {
            return s_best.sqrt();
        }

        // Keep the new iterate only when it improves the estimate; otherwise the next
        // round restarts from the current best vector.
        if s_new > s_best {
            s_best = s_new;
            v = v_new;
        }
    }

    s_best.sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Runs the estimator without subsampling, using the default iteration cap and tolerance.
    fn estimate(matrix: &Array2<f64>, seed: u64) -> f64 {
        find_largest_singular_value(matrix, seed, &SubsamplingScheme::None, None, None)
    }

    // All singular values of the identity are 1.
    #[test]
    fn singular_value_identity() {
        let identity = Array2::<f64>::eye(3);
        assert_abs_diff_eq!(estimate(&identity, 0), 1.0, epsilon = 0.05);
    }

    // Scaling the identity scales every singular value by the same factor.
    #[test]
    fn singular_value_scaled_identity() {
        let scaled = Array2::<f64>::eye(3) * 5.0;
        assert_abs_diff_eq!(estimate(&scaled, 1), 5.0, epsilon = 0.1);
    }

    // For a diagonal matrix the largest singular value is the largest absolute entry.
    #[test]
    fn singular_value_diagonal() {
        let mut diag = Array2::<f64>::zeros((4, 4));
        for (i, value) in [1.0, 2.0, 7.0, 3.0].iter().enumerate() {
            diag[[i, i]] = *value;
        }
        assert_abs_diff_eq!(estimate(&diag, 42), 7.0, epsilon = 0.2);
    }

    // An outer product u vᵀ has the single non-zero singular value ‖u‖·‖v‖.
    #[test]
    fn singular_value_rank1_matrix() {
        let column = ndarray::array![[1.0], [2.0], [3.0]];
        let row = ndarray::array![[1.0, 1.0]];
        let outer = column.dot(&row);
        let expected = (1.0 + 4.0 + 9.0_f64).sqrt() * 2.0_f64.sqrt();
        let sv = find_largest_singular_value(&outer, 7, &SubsamplingScheme::None, Some(50), None);
        assert_abs_diff_eq!(sv, expected, epsilon = 0.1);
    }

    // More rows than columns.
    #[test]
    fn singular_value_rectangular_tall() {
        let tall = ndarray::array![[1.0, 0.0], [0.0, 2.0], [0.0, 0.0]];
        assert_abs_diff_eq!(estimate(&tall, 0), 2.0, epsilon = 0.1);
    }

    // More columns than rows.
    #[test]
    fn singular_value_rectangular_wide() {
        let wide = ndarray::array![[3.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
        assert_abs_diff_eq!(estimate(&wide, 0), 3.0, epsilon = 0.15);
    }

    #[test]
    fn singular_value_positive_for_nonzero_matrix() {
        let dense = ndarray::array![[1.0, 2.0], [3.0, 4.0]];
        assert!(estimate(&dense, 0) > 0.0);
    }

    // Two runs with the same seed must agree.
    #[test]
    fn singular_value_deterministic_with_same_seed() {
        let matrix = ndarray::array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let first = estimate(&matrix, 42);
        let second = estimate(&matrix, 42);
        assert_abs_diff_eq!(first, second, epsilon = 1e-15);
    }

    // 3-4-5 right triangle.
    #[test]
    fn norm2_known_values() {
        let legs = Array1::from_vec(vec![3.0, 4.0]);
        assert_abs_diff_eq!(norm2(&legs), 5.0, epsilon = 1e-15);
    }
}