use crate::circulant_toeplitz::{CirculantMatrix, ToeplitzMatrix};
use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::iter::Sum;
/// Solves the Toeplitz system `T x = b` with the Levinson recursion in `O(n²)` time.
///
/// The matrix is described by its first row `r` and first column `c`:
/// `T[i][j] = r[j - i]` for `j >= i` and `T[i][j] = c[i - j]` for `i > j`.
/// Hence `r[0]` and `c[0]` both hold the (shared) diagonal element. The matrix does not
/// need to be symmetric.
///
/// # Errors
///
/// * [`LinalgError::InvalidInputError`] if the inputs are empty.
/// * [`LinalgError::ShapeError`] if `r`, `c` and `b` do not all have the same length.
/// * [`LinalgError::ValueError`] if `r[0]` and `c[0]` disagree.
/// * [`LinalgError::SingularMatrixError`] if the diagonal element is (near) zero, or if a
///   leading principal minor becomes near-singular during the recursion. Levinson
///   recursion performs no pivoting. It can therefore also reject a nonsingular matrix
///   whose leading minors vanish, e.g. `[[0, 1], [1, 0]]`.
pub fn solve_toeplitz<F>(
    r: &ArrayView1<F>,
    c: &ArrayView1<F>,
    b: &ArrayView1<F>,
) -> LinalgResult<Array1<F>>
where
    F: Float + NumAssign + Sum + ScalarOperand + Send + Sync + Into<f64> + Clone + 'static,
{
    let n = r.len();
    if n == 0 {
        return Err(LinalgError::InvalidInputError(
            "solve_toeplitz: input vectors must be non-empty".to_string(),
        ));
    }
    if c.len() != n {
        return Err(LinalgError::ShapeError(format!(
            "solve_toeplitz: r has length {n} but c has length {}",
            c.len()
        )));
    }
    if b.len() != n {
        return Err(LinalgError::ShapeError(format!(
            "solve_toeplitz: r has length {n} but b has length {}",
            b.len()
        )));
    }
    if (r[0] - c[0]).abs() > F::epsilon() * (r[0].abs() + F::one()) {
        return Err(LinalgError::ValueError(
            "solve_toeplitz: r[0] must equal c[0] (diagonal element)".to_string(),
        ));
    }
    let t0 = r[0];
    if t0.abs() < F::epsilon() {
        return Err(LinalgError::SingularMatrixError(
            "solve_toeplitz: diagonal element is zero (singular matrix)".to_string(),
        ));
    }
    if n == 1 {
        return Ok(Array1::from_vec(vec![b[0] / t0]));
    }

    // Invariants once the leading k x k block T_k has been processed (first k entries):
    //   T_k * fwd = e_first,  T_k * bwd = e_last,  T_k * sol = b[..k].
    // All three vectors are stored in natural (not reversed) order.
    let mut fwd = vec![F::zero(); n];
    let mut bwd = vec![F::zero(); n];
    let mut sol = vec![F::zero(); n];
    // Scratch buffers for the next fwd/bwd. They are swapped in each step instead of
    // cloning the old vectors.
    let mut fwd_next = vec![F::zero(); n];
    let mut bwd_next = vec![F::zero(); n];
    fwd[0] = F::one() / t0;
    bwd[0] = F::one() / t0;
    sol[0] = b[0] / t0;

    let tol = F::from(1e-14_f64).unwrap_or(F::epsilon());
    for k in 1..n {
        // Residuals of the zero-padded vectors against the enlarged block T_{k+1}:
        //   lambda = last entry of T_{k+1} [fwd; 0]   (row k, columns 0..k)
        //   mu     = first entry of T_{k+1} [0; bwd]  (row 0, columns 1..=k)
        //   eps    = b[k] - last entry of T_{k+1} [sol; 0]
        let mut lambda = F::zero();
        let mut mu = F::zero();
        let mut eps = b[k];
        for j in 0..k {
            lambda += c[k - j] * fwd[j];
            mu += r[j + 1] * bwd[j];
            eps -= c[k - j] * sol[j];
        }
        let den = F::one() - lambda * mu;
        if den.abs() < tol {
            return Err(LinalgError::SingularMatrixError(format!(
                "solve_toeplitz: Levinson step {k}: leading minor is near-singular"
            )));
        }
        let den_inv = F::one() / den;

        // fwd' = ([fwd; 0] - lambda [0; bwd]) / den
        // bwd' = ([0; bwd] - mu [fwd; 0]) / den
        for j in 0..=k {
            let f_pad = if j < k { fwd[j] } else { F::zero() };
            let b_pad = if j > 0 { bwd[j - 1] } else { F::zero() };
            fwd_next[j] = (f_pad - lambda * b_pad) * den_inv;
            bwd_next[j] = (b_pad - mu * f_pad) * den_inv;
        }
        std::mem::swap(&mut fwd, &mut fwd_next);
        std::mem::swap(&mut bwd, &mut bwd_next);

        // [sol; 0] satisfies every row except the last. Adding eps * bwd corrects the
        // last row and leaves the others unchanged, because T_{k+1} bwd = e_last.
        for j in 0..=k {
            sol[j] += eps * bwd[j];
        }
    }
    Ok(Array1::from_vec(sol))
}
/// Solves the circulant system `C x = b`, where `C` is generated by the vector `c`.
///
/// The validated inputs are handed to [`CirculantMatrix::new`] and
/// [`CirculantMatrix::solve`]. Any error those return (for example for a singular
/// matrix) is propagated unchanged.
///
/// # Errors
///
/// * [`LinalgError::InvalidInputError`] if `c` is empty.
/// * [`LinalgError::ShapeError`] if `b.len()` differs from `c.len()`.
pub fn solve_circulant<F>(c: &ArrayView1<F>, b: &ArrayView1<F>) -> LinalgResult<Array1<F>>
where
    F: Float + NumAssign + Sum + ScalarOperand + Send + Sync + Into<f64> + Clone + 'static,
{
    let n = c.len();
    match (n, b.len()) {
        (0, _) => Err(LinalgError::InvalidInputError(
            "solve_circulant: input vectors must be non-empty".to_string(),
        )),
        (_, rhs_len) if rhs_len != n => Err(LinalgError::ShapeError(format!(
            "solve_circulant: c has length {n} but b has length {}",
            rhs_len
        ))),
        _ => {
            let mut circulant = CirculantMatrix::new(c.to_owned())?;
            circulant.solve(b)
        }
    }
}
/// Solves the tridiagonal system `A x = b` with the Thomas algorithm (no pivoting).
///
/// Row `i` of `A` reads `lower[i] * x[i - 1] + diag[i] * x[i] + upper[i] * x[i + 1]`, so
/// `lower[0]` and `upper[n - 1]` do not affect the result. All four vectors must have the
/// same length `n = diag.len()`.
///
/// # Errors
///
/// * [`LinalgError::InvalidInputError`] if `diag` is empty.
/// * [`LinalgError::ShapeError`] if `lower`, `upper` or `b` differ in length from `diag`.
/// * [`LinalgError::SingularMatrixError`] if a pivot is smaller than `F::epsilon()` in
///   absolute value. No row exchanges are performed, so this may also be reported for a
///   nonsingular matrix that is not diagonally dominant.
pub fn solve_tridiagonal<F>(
    lower: &ArrayView1<F>,
    diag: &ArrayView1<F>,
    upper: &ArrayView1<F>,
    b: &ArrayView1<F>,
) -> LinalgResult<Array1<F>>
where
    F: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static,
{
    let n = diag.len();
    if n == 0 {
        return Err(LinalgError::InvalidInputError(
            "solve_tridiagonal: diagonal must be non-empty".to_string(),
        ));
    }
    let lengths_match = [lower.len(), upper.len(), b.len()]
        .iter()
        .all(|&len| len == n);
    if !lengths_match {
        return Err(LinalgError::ShapeError(format!(
            "solve_tridiagonal: all vectors must have length {n}"
        )));
    }

    let head = diag[0];
    if head.abs() < F::epsilon() {
        return Err(LinalgError::SingularMatrixError(
            "solve_tridiagonal: diagonal element at index 0 is zero (singular)".to_string(),
        ));
    }

    // Forward sweep. Afterwards each row i reads x[i] = rhs_mod[i] - sup_mod[i] * x[i + 1].
    let mut sup_mod = Array1::<F>::zeros(n);
    let mut rhs_mod = Array1::<F>::zeros(n);
    sup_mod[0] = upper[0] / head;
    rhs_mod[0] = b[0] / head;
    for row in 1..n {
        let sub = lower[row];
        let pivot = diag[row] - sub * sup_mod[row - 1];
        if pivot.abs() < F::epsilon() {
            return Err(LinalgError::SingularMatrixError(format!(
                "solve_tridiagonal: pivot at index {row} is zero (singular matrix)"
            )));
        }
        // The last row has no superdiagonal entry, so its sup_mod stays zero.
        if row + 1 < n {
            sup_mod[row] = upper[row] / pivot;
        }
        rhs_mod[row] = (b[row] - sub * rhs_mod[row - 1]) / pivot;
    }

    // Backward sweep, done in place: rhs_mod becomes the solution vector.
    for row in (0..n - 1).rev() {
        let next = rhs_mod[row + 1];
        rhs_mod[row] -= sup_mod[row] * next;
    }
    Ok(rhs_mod)
}
/// Solves the banded system `A x = b` by Gaussian elimination without pivoting.
///
/// `ab` holds `A` in LAPACK-style band storage with `kl + ku + 1` rows and `n` columns:
/// `ab[[ku + i - j, j]] == A[i, j]` for every `(i, j)` inside the band
/// (`j - ku <= i <= j + kl`). Row `ku` is the main diagonal, the rows above it are the
/// superdiagonals and the rows below it are the subdiagonals. Slots that fall outside
/// the matrix are ignored.
///
/// An empty system (`n == 0`) yields an empty solution.
///
/// # Errors
///
/// * [`LinalgError::ShapeError`] if `ab` does not have `kl + ku + 1` rows, or if `b.len()`
///   differs from the number of columns of `ab`.
/// * [`LinalgError::SingularMatrixError`] if a pivot is smaller than `F::epsilon()` in
///   absolute value. No row exchanges are performed, so this may also be reported for a
///   nonsingular matrix.
pub fn solve_banded<F>(
    kl: usize,
    ku: usize,
    ab: &ArrayView2<F>,
    b: &ArrayView1<F>,
) -> LinalgResult<Array1<F>>
where
    F: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static,
{
    let bandwidth = kl + ku + 1;
    let (ab_rows, n) = (ab.nrows(), ab.ncols());
    if ab_rows != bandwidth {
        return Err(LinalgError::ShapeError(format!(
            "solve_banded: ab should have {} rows (kl+ku+1), got {ab_rows}",
            bandwidth
        )));
    }
    if b.len() != n {
        return Err(LinalgError::ShapeError(format!(
            "solve_banded: RHS length {} does not match matrix size {n}",
            b.len()
        )));
    }
    if n == 0 {
        return Ok(Array1::zeros(0));
    }
    let mut ab_work = ab.to_owned();
    let mut x = b.to_owned();

    // Forward elimination. Without pivoting, the upper factor keeps at most ku
    // superdiagonals, so every update below lands inside the band storage (no fill-in).
    for k in 0..n {
        let pivot = ab_work[[ku, k]];
        if pivot.abs() < F::epsilon() {
            return Err(LinalgError::SingularMatrixError(format!(
                "solve_banded: zero pivot at column {k}"
            )));
        }
        // Rows k+1 ..= k+kl and columns k ..= k+ku, clipped to the matrix size.
        let lower_end = (kl + 1).min(n - k);
        let upper_end = (ku + 1).min(n - k);
        for i in 1..lower_end {
            // factor = A[k+i, k] / A[k, k]
            let factor = ab_work[[ku + i, k]] / pivot;
            // A[k+i, k] is eliminated by construction, so store an exact zero.
            ab_work[[ku + i, k]] = F::zero();
            // Only an exactly-zero multiplier may be skipped. A tiny but nonzero factor
            // can still change row k+i substantially when row k has large entries.
            if factor == F::zero() {
                continue;
            }
            for j in 1..upper_end {
                // A[k+i, k+j] -= factor * A[k, k+j]
                let src_val = ab_work[[ku - j, k + j]];
                ab_work[[ku + i - j, k + j]] -= factor * src_val;
            }
            let xk = x[k];
            x[k + i] -= factor * xk;
        }
    }

    // Back substitution with the upper-triangular factor (rows 0..=ku of ab_work).
    for k in (0..n).rev() {
        let pivot = ab_work[[ku, k]];
        if pivot.abs() < F::epsilon() {
            return Err(LinalgError::SingularMatrixError(format!(
                "solve_banded: zero pivot at column {k} during back-substitution"
            )));
        }
        let upper_end = (ku + 1).min(n - k);
        for j in 1..upper_end {
            let xkj = x[k + j];
            x[k] -= ab_work[[ku - j, k + j]] * xkj;
        }
        x[k] /= pivot;
    }
    Ok(x)
}
/// Solves `T x = b` for a square triangular matrix `T` by substitution.
///
/// * `lower == true`: only the lower triangle of `t` (including the diagonal) is read, and
///   rows are solved top-down (forward substitution). Otherwise only the upper triangle is
///   read, and rows are solved bottom-up (back substitution).
/// * `unit_diagonal == true`: the diagonal of `t` is treated as all ones and is never read.
///
/// # Errors
///
/// * [`LinalgError::InvalidInputError`] if `t` has no rows.
/// * [`LinalgError::ShapeError`] if `t` is not square or `b.len()` differs from its size.
/// * [`LinalgError::SingularMatrixError`] if a diagonal entry that has to be divided by is
///   smaller than `F::epsilon()` in absolute value.
pub fn solve_triangular<F>(
    t: &ArrayView2<F>,
    b: &ArrayView1<F>,
    lower: bool,
    unit_diagonal: bool,
) -> LinalgResult<Array1<F>>
where
    F: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static,
{
    let n = t.nrows();
    if n == 0 {
        return Err(LinalgError::InvalidInputError(
            "solve_triangular: matrix must be non-empty".to_string(),
        ));
    }
    if t.ncols() != n {
        return Err(LinalgError::ShapeError(
            "solve_triangular: matrix must be square".to_string(),
        ));
    }
    if b.len() != n {
        return Err(LinalgError::ShapeError(format!(
            "solve_triangular: matrix is {n}×{n} but b has length {}",
            b.len()
        )));
    }

    // Turns the reduced right-hand side of `row` into the solution entry by dividing by
    // the diagonal, unless the diagonal is implicitly one.
    let finish_row = |row: usize, reduced: F| -> LinalgResult<F> {
        if unit_diagonal {
            return Ok(reduced);
        }
        let pivot = t[[row, row]];
        if pivot.abs() < F::epsilon() {
            return Err(LinalgError::SingularMatrixError(format!(
                "solve_triangular: zero diagonal at row {row}"
            )));
        }
        Ok(reduced / pivot)
    };

    let mut x = Array1::<F>::zeros(n);
    if lower {
        for row in 0..n {
            let reduced = (0..row).fold(b[row], |acc, col| acc - t[[row, col]] * x[col]);
            x[row] = finish_row(row, reduced)?;
        }
    } else {
        for row in (0..n).rev() {
            let reduced = (row + 1..n).fold(b[row], |acc, col| acc - t[[row, col]] * x[col]);
            x[row] = finish_row(row, reduced)?;
        }
    }
    Ok(x)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1, Array2};

    /// Fails the current test unless `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {expected}, got {actual} (tolerance {tol})"
        );
    }

    /// Residuals of the 3x3 system tridiag(-1, 2, -1) x = [1, 0, 1].
    fn second_difference_residuals(x: &Array1<f64>) -> [f64; 3] {
        [
            2.0 * x[0] - x[1] - 1.0,
            -x[0] + 2.0 * x[1] - x[2],
            -x[1] + 2.0 * x[2] - 1.0,
        ]
    }

    // ---------------------------------------------------------------- Toeplitz

    #[test]
    fn test_toeplitz_identity() {
        let first_row = array![1.0_f64, 0.0, 0.0];
        let first_col = array![1.0_f64, 0.0, 0.0];
        let rhs = array![2.0_f64, 3.0, 4.0];
        let x = solve_toeplitz(&first_row.view(), &first_col.view(), &rhs.view())
            .expect("toeplitz identity");
        x.iter()
            .zip(rhs.iter())
            .for_each(|(&got, &want)| assert_close(got, want, 1e-10));
    }

    #[test]
    fn test_toeplitz_symmetric_2x2() {
        // [[2, 1], [1, 2]] x = [3, 3]  =>  x = [1, 1]
        let first_row = array![2.0_f64, 1.0];
        let first_col = array![2.0_f64, 1.0];
        let rhs = array![3.0_f64, 3.0];
        let x = solve_toeplitz(&first_row.view(), &first_col.view(), &rhs.view())
            .expect("toeplitz 2x2");
        assert_close(x[0], 1.0, 1e-10);
        assert_close(x[1], 1.0, 1e-10);
    }

    #[test]
    fn test_toeplitz_nonsymmetric() {
        // First row [2, 1] and first column [2, -1] give T = [[2, 1], [-1, 2]].
        let first_row = array![2.0_f64, 1.0];
        let first_col = array![2.0_f64, -1.0];
        let rhs = array![1.0_f64, 0.0];
        let x = solve_toeplitz(&first_row.view(), &first_col.view(), &rhs.view())
            .expect("toeplitz nonsym");
        let dense = [[2.0_f64, 1.0], [-1.0, 2.0]];
        for (row, &target) in dense.iter().zip(rhs.iter()) {
            let residual = row[0] * x[0] + row[1] * x[1] - target;
            assert_close(residual, 0.0, 1e-10);
        }
    }

    #[test]
    fn test_toeplitz_diagonal_mismatch_error() {
        // r[0] = 2 but c[0] = 3: the row and column disagree on the diagonal.
        let first_row = array![2.0_f64, 1.0];
        let first_col = array![3.0_f64, 1.0];
        let rhs = array![1.0_f64, 0.0];
        assert!(solve_toeplitz(&first_row.view(), &first_col.view(), &rhs.view()).is_err());
    }

    #[test]
    fn test_toeplitz_singular_error() {
        // A zero diagonal element is rejected before the recursion starts.
        let first_row = array![0.0_f64, 1.0];
        let first_col = array![0.0_f64, 1.0];
        let rhs = array![1.0_f64, 0.0];
        assert!(solve_toeplitz(&first_row.view(), &first_col.view(), &rhs.view()).is_err());
    }

    // --------------------------------------------------------------- Circulant

    #[test]
    fn test_circulant_2x2() {
        let generator = array![2.0_f64, 1.0];
        let rhs = array![3.0_f64, 3.0];
        let x = solve_circulant(&generator.view(), &rhs.view()).expect("circulant 2x2");
        assert_close(x[0], 1.0, 1e-9);
        assert_close(x[1], 1.0, 1e-9);
    }

    #[test]
    fn test_circulant_identity() {
        let generator = array![1.0_f64, 0.0, 0.0, 0.0];
        let rhs = array![5.0_f64, 3.0, 1.0, 2.0];
        let x = solve_circulant(&generator.view(), &rhs.view()).expect("circulant identity");
        x.iter()
            .zip(rhs.iter())
            .for_each(|(&got, &want)| assert_close(got, want, 1e-9));
    }

    #[test]
    fn test_circulant_singular_error() {
        let generator = array![0.0_f64, 0.0];
        let rhs = array![1.0_f64, 0.0];
        assert!(solve_circulant(&generator.view(), &rhs.view()).is_err());
    }

    #[test]
    fn test_circulant_dimension_mismatch_error() {
        let generator = array![1.0_f64, 0.0];
        let rhs = array![1.0_f64, 0.0, 0.0];
        assert!(solve_circulant(&generator.view(), &rhs.view()).is_err());
    }

    // ------------------------------------------------------------- Tridiagonal

    #[test]
    fn test_tridiagonal_simple() {
        let sub = array![0.0_f64, -1.0, -1.0];
        let main = array![2.0_f64, 2.0, 2.0];
        let sup = array![-1.0_f64, -1.0, 0.0];
        let rhs = array![1.0_f64, 0.0, 1.0];
        let x = solve_tridiagonal(&sub.view(), &main.view(), &sup.view(), &rhs.view())
            .expect("tridiagonal simple");
        for residual in second_difference_residuals(&x) {
            assert_close(residual, 0.0, 1e-10);
        }
    }

    #[test]
    fn test_tridiagonal_diagonal_system() {
        let sub = array![0.0_f64, 0.0, 0.0];
        let main = array![2.0_f64, 4.0, 8.0];
        let sup = array![0.0_f64, 0.0, 0.0];
        let rhs = array![6.0_f64, 12.0, 24.0];
        let x = solve_tridiagonal(&sub.view(), &main.view(), &sup.view(), &rhs.view())
            .expect("tridiagonal diagonal");
        for &value in x.iter().take(3) {
            assert_close(value, 3.0, 1e-10);
        }
    }

    #[test]
    fn test_tridiagonal_singular_error() {
        // A zero leading pivot is rejected (the Thomas algorithm does not pivot).
        let sub = array![0.0_f64, 1.0];
        let main = array![0.0_f64, 2.0];
        let sup = array![1.0_f64, 0.0];
        let rhs = array![1.0_f64, 1.0];
        assert!(solve_tridiagonal(&sub.view(), &main.view(), &sup.view(), &rhs.view()).is_err());
    }

    // ------------------------------------------------------------------ Banded

    #[test]
    fn test_banded_tridiagonal() {
        // Band storage with kl = ku = 1: superdiagonal row, diagonal row, subdiagonal row.
        let ab = Array2::from_shape_vec(
            (3, 3),
            vec![0.0_f64, -1.0, -1.0, 2.0, 2.0, 2.0, -1.0, -1.0, 0.0],
        )
        .expect("shape");
        let rhs = array![1.0_f64, 0.0, 1.0];
        let x = solve_banded(1, 1, &ab.view(), &rhs.view()).expect("banded tridiagonal");
        for residual in second_difference_residuals(&x) {
            assert_close(residual, 0.0, 1e-10);
        }
    }

    #[test]
    fn test_banded_diagonal_only() {
        let ab = Array2::from_shape_vec((1, 3), vec![2.0_f64, 4.0, 8.0]).expect("shape");
        let rhs = array![6.0_f64, 12.0, 24.0];
        let x = solve_banded(0, 0, &ab.view(), &rhs.view()).expect("banded diagonal");
        for &value in x.iter().take(3) {
            assert_close(value, 3.0, 1e-10);
        }
    }

    #[test]
    fn test_banded_shape_error() {
        // kl = ku = 1 requires 3 band rows; only 2 are supplied.
        let ab = Array2::<f64>::zeros((2, 3));
        let rhs = array![1.0_f64, 0.0, 1.0];
        assert!(solve_banded(1, 1, &ab.view(), &rhs.view()).is_err());
    }

    #[test]
    fn test_banded_rhs_mismatch_error() {
        let ab = Array2::<f64>::zeros((1, 3));
        let rhs = array![1.0_f64, 0.0];
        assert!(solve_banded(0, 0, &ab.view(), &rhs.view()).is_err());
    }

    // -------------------------------------------------------------- Triangular

    #[test]
    fn test_triangular_lower() {
        let tri = array![[1.0_f64, 0.0], [2.0, 1.0]];
        let rhs = array![3.0_f64, 7.0];
        let x = solve_triangular(&tri.view(), &rhs.view(), true, false).expect("lower tri");
        assert_close(x[0], 3.0, 1e-10);
        assert_close(x[1], 1.0, 1e-10);
    }

    #[test]
    fn test_triangular_upper() {
        let tri = array![[2.0_f64, 1.0], [0.0, 3.0]];
        let rhs = array![5.0_f64, 6.0];
        let x = solve_triangular(&tri.view(), &rhs.view(), false, false).expect("upper tri");
        assert_close(x[0], 1.5, 1e-10);
        assert_close(x[1], 2.0, 1e-10);
    }

    #[test]
    fn test_triangular_unit_diagonal_lower() {
        // The stored diagonal (9.9) must be ignored when unit_diagonal is set.
        let tri = array![[9.9_f64, 0.0], [2.0, 9.9]];
        let rhs = array![3.0_f64, 7.0];
        let x = solve_triangular(&tri.view(), &rhs.view(), true, true).expect("unit lower");
        assert_close(x[0], 3.0, 1e-10);
        assert_close(x[1], 1.0, 1e-10);
    }

    #[test]
    fn test_triangular_singular_error() {
        let tri = array![[0.0_f64, 0.0], [1.0, 1.0]];
        let rhs = array![1.0_f64, 2.0];
        assert!(solve_triangular(&tri.view(), &rhs.view(), true, false).is_err());
    }

    #[test]
    fn test_triangular_nonsquare_error() {
        let tri = Array2::<f64>::zeros((2, 3));
        let rhs = array![1.0_f64, 0.0];
        assert!(solve_triangular(&tri.view(), &rhs.view(), true, false).is_err());
    }

    #[test]
    fn test_triangular_dimension_mismatch_error() {
        let tri = Array2::<f64>::eye(3);
        let rhs = array![1.0_f64, 0.0];
        assert!(solve_triangular(&tri.view(), &rhs.view(), true, false).is_err());
    }

    #[test]
    fn test_triangular_3x3_lower() {
        let tri = array![[1.0_f64, 0.0, 0.0], [2.0, 3.0, 0.0], [4.0, 5.0, 6.0]];
        let rhs = array![1.0_f64, 8.0, 32.0];
        let x = solve_triangular(&tri.view(), &rhs.view(), true, false).expect("3x3 lower");
        for (&got, want) in x.iter().zip([1.0, 2.0, 3.0]) {
            assert_close(got, want, 1e-10);
        }
    }

    #[test]
    fn test_triangular_3x3_upper() {
        let tri = array![[6.0_f64, 5.0, 4.0], [0.0, 3.0, 2.0], [0.0, 0.0, 1.0]];
        let rhs = array![32.0_f64, 8.0, 1.0];
        let x = solve_triangular(&tri.view(), &rhs.view(), false, false).expect("3x3 upper");
        for (&got, want) in x.iter().zip([3.0, 2.0, 1.0]).rev() {
            assert_close(got, want, 1e-10);
        }
    }
}