use crate::error::{LinalgError, LinalgResult};
use crate::norm::vector_norm;
use scirs2_core::ndarray::{Array1, Array2, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use scirs2_core::random::{self, Rng};
use std::iter::Sum;
/// Computes the `k` dominant eigenpairs of a symmetric matrix using power
/// iteration with Hotelling deflation.
///
/// NOTE(review): power iteration converges to the eigenvalue of largest
/// *magnitude*; for matrices with large negative eigenvalues this may differ
/// from the algebraically largest — confirm the intended semantics.
///
/// # Arguments
/// * `a` - Symmetric input matrix (n x n).
/// * `k` - Number of eigenpairs to compute (0 <= k <= n).
/// * `max_iter` - Maximum power-iteration steps per eigenpair.
/// * `tol` - Convergence tolerance on successive eigenvalue estimates.
///
/// # Returns
/// `(eigenvalues, eigenvectors)` where `eigenvalues` has length `k` and
/// `eigenvectors` is `n x k`; column `i` pairs with `eigenvalues[i]`.
///
/// # Errors
/// * `ShapeError` if the matrix is not square or not symmetric.
/// * `ValueError` if `k > n`.
/// * Convergence errors propagated from the power iteration.
#[allow(dead_code)]
pub fn largest_k_eigh<F>(
    a: &ArrayView2<F>,
    k: usize,
    max_iter: usize,
    tol: F,
) -> LinalgResult<(Array1<F>, Array2<F>)>
where
    F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
{
    if a.nrows() != a.ncols() {
        return Err(LinalgError::ShapeError(format!(
            "Expected square matrix, got shape {:?}",
            a.shape()
        )));
    }
    // Symmetry is required: the deflation update below assumes v * v^T
    // removes the eigenpair exactly, which only holds for symmetric A.
    for i in 0..a.nrows() {
        for j in (i + 1)..a.ncols() {
            if (a[[i, j]] - a[[j, i]]).abs() > F::epsilon() {
                return Err(LinalgError::ShapeError(
                    "Matrix must be symmetric for this solver".to_string(),
                ));
            }
        }
    }
    let n = a.nrows();
    if k > n {
        return Err(LinalgError::ValueError(format!(
            "k ({}) cannot be larger than matrix size ({})",
            k, n
        )));
    }
    if k == 0 {
        // Nothing requested: return empty results with consistent shapes.
        return Ok((Array1::zeros(0), Array2::zeros((n, 0))));
    }
    let mut eigenvalues = Array1::zeros(k);
    let mut eigenvectors = Array2::zeros((n, k));
    // Deflate a working copy so the caller's matrix is left untouched.
    let mut a_work = a.to_owned();
    for i in 0..k {
        // `?` propagates convergence failures (was a manual match that only
        // re-wrapped the error).
        let (eigenvalue, eigenvector) =
            power_iteration_with_convergence(&a_work.view(), max_iter, tol)?;
        eigenvalues[i] = eigenvalue;
        // Store the eigenvector as column i (replaces the element-wise loop).
        eigenvectors.column_mut(i).assign(&eigenvector);
        // Skip deflation after the final eigenpair — it would be wasted work.
        if i == k - 1 {
            break;
        }
        // Hotelling deflation: A <- A - lambda * v * v^T removes the found
        // eigenpair so the next iteration converges to the next dominant one.
        for p in 0..n {
            for q in 0..n {
                a_work[[p, q]] -= eigenvalue * eigenvector[p] * eigenvector[q];
            }
        }
    }
    Ok((eigenvalues, eigenvectors))
}
#[allow(dead_code)]
pub fn smallest_k_eigh<F>(
a: &ArrayView2<F>,
k: usize,
max_iter: usize,
tol: F,
) -> LinalgResult<(Array1<F>, Array2<F>)>
where
F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
{
if a.nrows() != a.ncols() {
return Err(LinalgError::ShapeError(format!(
"Expected square matrix, got shape {:?}",
a.shape()
)));
}
for i in 0..a.nrows() {
for j in (i + 1)..a.ncols() {
if (a[[i, j]] - a[[j, i]]).abs() > F::epsilon() {
return Err(LinalgError::ShapeError(
"Matrix must be symmetric for this solver".to_string(),
));
}
}
}
let n = a.nrows();
if k > n {
return Err(LinalgError::ValueError(format!(
"k ({}) cannot be larger than matrix size ({})",
k, n
)));
}
if k == 0 {
return Ok((Array1::zeros(0), Array2::zeros((n, 0))));
}
let (_largest_eigenvalue_) = match power_iteration_with_convergence(a, max_iter, tol) {
Ok((lambda, v)) => (lambda, v),
Err(e) => return Err(e),
};
let full_k = n.min(5); let (all_eigenvalues, all_eigenvectors) = largest_k_eigh(a, full_k, max_iter, tol)?;
if k > all_eigenvalues.len() {
return Err(LinalgError::ValueError(format!(
"Requested {} eigenvalues but matrix only has {} computed eigenvalues. Use a full eigenvalue solver.",
k, all_eigenvalues.len()
)));
}
let mut eigenvalue_pairs: Vec<(F, usize)> = all_eigenvalues
.iter()
.enumerate()
.map(|(i, &lambda)| (lambda, i))
.collect();
eigenvalue_pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
let mut eigenvalues = Array1::zeros(k);
let mut eigenvectors = Array2::zeros((n, k));
for i in 0..k {
let (eigenvalue, orig_index) = eigenvalue_pairs[i];
eigenvalues[i] = eigenvalue;
for j in 0..n {
eigenvectors[[j, i]] = all_eigenvectors[[j, orig_index]];
}
}
Ok((eigenvalues, eigenvectors))
}
#[allow(dead_code)]
fn power_iteration_with_convergence<F>(
a: &ArrayView2<F>,
max_iter: usize,
tol: F,
) -> LinalgResult<(F, Array1<F>)>
where
F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
{
let n = a.nrows();
let mut rng = scirs2_core::random::rng();
let mut b = Array1::zeros(n);
for i in 0..n {
b[i] = F::from(rng.random_range(-1.0..1.0)).unwrap_or(F::zero());
}
let norm_b = vector_norm(&b.view(), 2)?;
b.mapv_inplace(|x| x / norm_b);
let mut eigenvalue = F::zero();
let mut prev_eigenvalue = F::zero();
for _ in 0..max_iter {
let mut b_new = a.dot(&b);
eigenvalue = F::zero();
for i in 0..n {
eigenvalue += b[i] * b_new[i];
}
let norm_b_new = vector_norm(&b_new.view(), 2)?;
if norm_b_new < F::epsilon() {
return Err(LinalgError::ComputationError(
"Power iteration produced zero vector".to_string(),
));
}
b_new.mapv_inplace(|x| x / norm_b_new);
if (eigenvalue - prev_eigenvalue).abs() < tol {
return Ok((eigenvalue, b_new));
}
prev_eigenvalue = eigenvalue;
b = b_new;
}
let current_residual = (eigenvalue - prev_eigenvalue).abs();
Err(LinalgError::convergence_with_suggestions(
"Power iteration",
max_iter,
tol.to_f64().unwrap_or(1e-10),
Some(current_residual.to_f64().unwrap_or(1.0)),
))
}
/// Solves `A x = b` given a precomputed LU factorization `P A = L U`.
///
/// `p` is the permutation matrix, `l` the lower-triangular factor, and `u`
/// the upper-triangular factor. Performs forward substitution on
/// `L y = P b`, then back substitution on `U x = y`.
#[allow(dead_code)]
fn solve_with_lu<F>(p: &Array2<F>, l: &Array2<F>, u: &Array2<F>, b: &Array1<F>) -> Array1<F>
where
    F: Float + NumAssign + 'static,
{
    let n = b.len();
    // Apply the row permutation to the right-hand side.
    let permuted = p.dot(b);
    // Forward substitution: solve L y = P b, top row first.
    let mut y: Array1<F> = Array1::zeros(n);
    for row in 0..n {
        let acc = (0..row).fold(F::zero(), |s, col| s + l[[row, col]] * y[col]);
        y[row] = (permuted[row] - acc) / l[[row, row]];
    }
    // Back substitution: solve U x = y, bottom row first.
    let mut x: Array1<F> = Array1::zeros(n);
    for row in (0..n).rev() {
        let acc = ((row + 1)..n).fold(F::zero(), |s, col| s + u[[row, col]] * x[col]);
        x[row] = (y[row] - acc) / u[[row, row]];
    }
    x
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// True when column `col` of `vecs` is (up to sign) the unit vector along
    /// `axis`, within a 1e-3 tolerance on each component.
    fn axis_aligned(vecs: &Array2<f64>, col: usize, axis: usize) -> bool {
        (0..3).all(|row| {
            let expected = if row == axis { 1.0 } else { 0.0 };
            (vecs[[row, col]].abs() - expected).abs() < 1e-3
        })
    }

    #[test]
    fn test_largest_k_eigh_simple() {
        // Diagonal matrix: eigenvalues are 2, 1, 3 with axis-aligned eigenvectors.
        let a = array![[2.0_f64, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 3.0]];
        let (eigenvalues, eigenvectors) =
            largest_k_eigh(&a.view(), 2, 100, 1e-10).expect("Operation failed");
        assert_relative_eq!(eigenvalues[0], 3.0, epsilon = 1e-8);
        assert_relative_eq!(eigenvalues[1], 2.0, epsilon = 1e-8);
        assert!(
            axis_aligned(&eigenvectors, 0, 2),
            "First eigenvector (λ=3.0) should be along z-axis: [{}, {}, {}]",
            eigenvectors[[0, 0]],
            eigenvectors[[1, 0]],
            eigenvectors[[2, 0]]
        );
        assert!(
            axis_aligned(&eigenvectors, 1, 0),
            "Second eigenvector (λ=2.0) should be along x-axis: [{}, {}, {}]",
            eigenvectors[[0, 1]],
            eigenvectors[[1, 1]],
            eigenvectors[[2, 1]]
        );
    }

    #[test]
    fn test_smallest_k_eigh_simple() {
        // Same diagonal matrix; smallest two eigenvalues are 1 and 2.
        let a = array![[2.0_f64, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 3.0]];
        let (eigenvalues, eigenvectors) =
            smallest_k_eigh(&a.view(), 2, 100, 1e-10).expect("Operation failed");
        assert_relative_eq!(eigenvalues[0], 1.0, epsilon = 1e-8);
        assert_relative_eq!(eigenvalues[1], 2.0, epsilon = 1e-8);
        assert!(
            axis_aligned(&eigenvectors, 0, 1),
            "First eigenvector (λ=1.0) should be along y-axis: [{}, {}, {}]",
            eigenvectors[[0, 0]],
            eigenvectors[[1, 0]],
            eigenvectors[[2, 0]]
        );
        assert!(
            axis_aligned(&eigenvectors, 1, 0),
            "Second eigenvector (λ=2.0) should be along x-axis: [{}, {}, {}]",
            eigenvectors[[0, 1]],
            eigenvectors[[1, 1]],
            eigenvectors[[2, 1]]
        );
    }

    #[test]
    fn test_power_iteration_with_convergence() {
        // Dominant eigenpair of [[3,1],[1,3]] is λ=4 with v = ±(1,1)/sqrt(2).
        let a = array![[3.0_f64, 1.0], [1.0, 3.0]];
        let (eigenvalue, eigenvector) =
            power_iteration_with_convergence(&a.view(), 100, 1e-10).expect("Operation failed");
        assert_relative_eq!(eigenvalue, 4.0, epsilon = 1e-8);
        // The returned eigenvector must be unit length.
        let norm = vector_norm(&eigenvector.view(), 2).expect("Operation failed");
        assert_relative_eq!(norm, 1.0, epsilon = 1e-10);
        let expected_val = 1.0 / 2.0_f64.sqrt();
        // The eigenvector sign is arbitrary: accept either orientation.
        let matches_sign = |sign: f64| {
            (eigenvector[0] - sign * expected_val).abs() < 1e-4
                && (eigenvector[1] - sign * expected_val).abs() < 1e-4
        };
        assert!(
            matches_sign(1.0) || matches_sign(-1.0),
            "Eigenvector {:?} is not close to [{}, {}] or [{}, {}]",
            eigenvector,
            expected_val,
            expected_val,
            -expected_val,
            -expected_val
        );
    }
}