use crate::dense_vector::{DenseVector, DenseVectorMut};
use crate::errors::{LinalgError, SingularMatrixInfo};
use crate::indexing::SpIndex;
use crate::sparse::CsMatViewI;
use crate::sparse::CsVecViewI;
use crate::stack::{self, DStack, StackVal};
use num_traits::Num;
/// Validates the shapes passed to a triangular solver.
///
/// The matrix must be square and its order must equal `rhs.dim()`. The
/// parameter is named `lower_tri_mat`, but the upper triangular solvers in
/// this file call it too.
///
/// # Panics
/// * `"Non square matrix passed to solver"` if `cols != rows`
/// * `"Dimension mismatch"` if `cols != rhs.dim()`
fn check_solver_dimensions<N, I, Iptr, V>(
    lower_tri_mat: &CsMatViewI<N, I, Iptr>,
    rhs: &V,
) where
    V: DenseVector<Scalar = N> + ?Sized,
    I: SpIndex,
    Iptr: SpIndex,
{
    let order = lower_tri_mat.cols();
    assert_eq!(
        order,
        lower_tri_mat.rows(),
        "Non square matrix passed to solver"
    );
    assert_eq!(order, rhs.dim(), "Dimension mismatch");
}
/// Solves the lower triangular system `L x = rhs` in place, with `L` stored
/// in CSR format and `rhs` a dense vector that is overwritten by `x`.
///
/// Rows are processed top to bottom (forward substitution). For row `i`:
///
/// `x_i = (b_i - sum_{j < i} l_ij * x_j) / l_ii`
///
/// Stored entries above the diagonal (`j > i`) are ignored.
///
/// # Errors
/// Returns `LinalgError::SingularMatrix` with reason `"diagonal element is 0"`
/// for the first row whose diagonal is either not stored or stored as zero.
/// Rows before it have already been overwritten in `rhs`.
///
/// # Panics
/// * if the matrix is not square or its order differs from `rhs.dim()`
/// * if the matrix is not in CSR storage
pub fn lsolve_csr_dense_rhs<N, I, Iptr, V>(
    lower_tri_mat: CsMatViewI<N, I, Iptr>,
    mut rhs: V,
) -> Result<(), LinalgError>
where
    N: Clone + Num + std::ops::SubAssign,
    for<'r> &'r N: std::ops::Mul<&'r N, Output = N>,
    V: DenseVectorMut<Scalar = N>,
    I: SpIndex,
    Iptr: SpIndex,
{
    check_solver_dimensions(&lower_tri_mat, &rhs);
    assert!(lower_tri_mat.is_csr(), "Storage mismatch");

    for (i, row) in lower_tri_mat.outer_iterator().enumerate() {
        // A diagonal that is never stored stays at zero and is reported below.
        let mut pivot = N::zero();
        let mut acc = rhs.index(i).clone();
        for (j, l_ij) in row.iter() {
            if j < i {
                // x_j for j < i has already been computed in place.
                acc -= l_ij * rhs.index(j);
            } else if j == i {
                pivot = l_ij.clone();
            }
        }
        if pivot == N::zero() {
            return Err(LinalgError::SingularMatrix(SingularMatrixInfo {
                index: i,
                reason: "diagonal element is 0",
            }));
        }
        *rhs.index_mut(i) = acc / pivot;
    }
    Ok(())
}
/// Solves the lower triangular system `L x = rhs` in place, with `L` stored
/// in CSC format and `rhs` a dense vector that is overwritten by `x`.
///
/// Columns are processed left to right. Each column step (see
/// `lspsolve_csc_process_col`) finalizes one unknown and subtracts its
/// contribution from the rows below the diagonal.
///
/// # Errors
/// Returns `LinalgError::SingularMatrix` for the first column whose diagonal
/// is missing (`"diagonal element is a structural 0"`) or stored as zero
/// (`"diagonal element is a numeric 0"`).
///
/// # Panics
/// * if the matrix is not square or its order differs from `rhs.dim()`
/// * if the matrix is not in CSC storage
pub fn lsolve_csc_dense_rhs<N, I, Iptr, V>(
    lower_tri_mat: CsMatViewI<N, I, Iptr>,
    mut rhs: V,
) -> Result<(), LinalgError>
where
    N: Clone + Num + std::ops::SubAssign,
    for<'r> &'r N:
        std::ops::Mul<&'r N, Output = N> + std::ops::Div<&'r N, Output = N>,
    V: DenseVectorMut<Scalar = N>,
    I: SpIndex,
    Iptr: SpIndex,
{
    check_solver_dimensions(&lower_tri_mat, &rhs);
    assert!(lower_tri_mat.is_csc(), "Storage mismatch");

    // Stops at the first singular column and propagates its error.
    lower_tri_mat
        .outer_iterator()
        .enumerate()
        .try_for_each(|(col_ind, col)| {
            lspsolve_csc_process_col(col, col_ind, &mut rhs)
        })
}
/// Performs one column step of CSC forward substitution.
///
/// The step divides `rhs[col_ind]` by the diagonal entry of `col` and stores
/// the quotient `x` back into `rhs[col_ind]`. It then applies
/// `rhs[row] -= col[row] * x` for every stored entry with `row > col_ind`.
/// Entries of `col` at or above the diagonal do not take part in the update.
///
/// # Errors
/// Returns `LinalgError::SingularMatrix` for `col_ind` when the diagonal is
/// not stored (`"diagonal element is a structural 0"`) or is stored as zero
/// (`"diagonal element is a numeric 0"`). In both cases `rhs` is left
/// untouched.
fn lspsolve_csc_process_col<N, I, V>(
    col: CsVecViewI<N, I>,
    col_ind: usize,
    rhs: &mut V,
) -> Result<(), LinalgError>
where
    N: Clone + Num + std::ops::SubAssign,
    for<'r> &'r N:
        std::ops::Mul<&'r N, Output = N> + std::ops::Div<&'r N, Output = N>,
    V: DenseVectorMut<Scalar = N>,
    I: SpIndex,
{
    let pivot = match col.get(col_ind) {
        Some(pivot) => pivot,
        None => {
            return Err(LinalgError::SingularMatrix(SingularMatrixInfo {
                index: col_ind,
                reason: "diagonal element is a structural 0",
            }));
        }
    };
    if *pivot == N::zero() {
        return Err(LinalgError::SingularMatrix(SingularMatrixInfo {
            index: col_ind,
            reason: "diagonal element is a numeric 0",
        }));
    }

    let solved = rhs.index(col_ind) / pivot;
    *rhs.index_mut(col_ind) = solved.clone();

    let below_diag = col.iter().filter(|&(r, _)| r > col_ind);
    for (row_ind, l_val) in below_diag {
        *rhs.index_mut(row_ind) -= l_val * &solved;
    }
    Ok(())
}
/// Solves the upper triangular system `U x = rhs` in place, with `U` stored
/// in CSC format and `rhs` a dense vector that is overwritten by `x`.
///
/// Columns are processed right to left (back substitution). For column `j`,
/// `x_j = rhs[j] / u_jj` is stored. Then `rhs[i] -= u_ij * x_j` is applied
/// for every stored entry with `i < j`. Entries below the diagonal are
/// ignored.
///
/// # Errors
/// Returns `LinalgError::SingularMatrix` for the first column (from the
/// right) whose diagonal is missing (`"diagonal element is a structural 0"`)
/// or stored as zero (`"diagonal element is a numeric 0"`).
///
/// # Panics
/// * if the matrix is not square or its order differs from `rhs.dim()`
/// * if the matrix is not in CSC storage
pub fn usolve_csc_dense_rhs<N, I, Iptr, V>(
    upper_tri_mat: CsMatViewI<N, I, Iptr>,
    mut rhs: V,
) -> Result<(), LinalgError>
where
    N: Clone + Num + std::ops::SubAssign,
    for<'r> &'r N:
        std::ops::Mul<&'r N, Output = N> + std::ops::Div<&'r N, Output = N>,
    V: DenseVectorMut<Scalar = N>,
    I: SpIndex,
    Iptr: SpIndex,
{
    check_solver_dimensions(&upper_tri_mat, &rhs);
    assert!(upper_tri_mat.is_csc(), "Storage mismatch");

    for (j, col) in upper_tri_mat.outer_iterator().enumerate().rev() {
        let pivot = match col.get(j) {
            Some(pivot) => pivot,
            None => {
                return Err(LinalgError::SingularMatrix(SingularMatrixInfo {
                    index: j,
                    reason: "diagonal element is a structural 0",
                }));
            }
        };
        if *pivot == N::zero() {
            return Err(LinalgError::SingularMatrix(SingularMatrixInfo {
                index: j,
                reason: "diagonal element is a numeric 0",
            }));
        }

        let x_j = rhs.index(j) / pivot;
        *rhs.index_mut(j) = x_j.clone();
        for (i, u_ij) in col.iter() {
            // Only rows above the diagonal still depend on x_j.
            if i < j {
                *rhs.index_mut(i) -= u_ij * &x_j;
            }
        }
    }
    Ok(())
}
/// Solves the upper triangular system `U x = rhs` in place, with `U` stored
/// in CSR format and `rhs` a dense vector that is overwritten by `x`.
///
/// Rows are processed bottom to top (back substitution). For row `i`:
///
/// `x_i = (b_i - sum_{j > i} u_ij * x_j) / u_ii`
///
/// Stored entries below the diagonal (`j < i`) are ignored.
///
/// # Errors
/// Returns `LinalgError::SingularMatrix` with reason
/// `"diagonal element is a numeric 0"` for the first row (from the bottom)
/// whose diagonal is either not stored or stored as zero.
///
/// # Panics
/// * if the matrix is not square or its order differs from `rhs.dim()`
/// * if the matrix is not in CSR storage
pub fn usolve_csr_dense_rhs<N, I, Iptr, V>(
    upper_tri_mat: CsMatViewI<N, I, Iptr>,
    mut rhs: V,
) -> Result<(), LinalgError>
where
    N: Clone + Num + std::ops::SubAssign,
    for<'r> &'r N:
        std::ops::Mul<&'r N, Output = N> + std::ops::Div<&'r N, Output = N>,
    V: DenseVectorMut + DenseVector<Scalar = N>,
    I: SpIndex,
    Iptr: SpIndex,
{
    check_solver_dimensions(&upper_tri_mat, &rhs);
    assert!(upper_tri_mat.is_csr(), "Storage mismatch");

    for (i, row) in upper_tri_mat.outer_iterator().enumerate().rev() {
        // A diagonal that is never stored stays at zero and is reported below.
        let mut pivot = N::zero();
        let mut acc = rhs.index(i).clone();
        for (j, u_ij) in row.iter() {
            if j > i {
                // x_j for j > i has already been computed in place.
                acc -= u_ij * rhs.index(j);
            } else if j == i {
                pivot = u_ij.clone();
            }
        }
        if pivot == N::zero() {
            return Err(LinalgError::SingularMatrix(SingularMatrixInfo {
                index: i,
                reason: "diagonal element is a numeric 0",
            }));
        }
        *rhs.index_mut(i) = acc / pivot;
    }
    Ok(())
}
/// Sparse lower triangular solve with a sparse right-hand side.
///
/// Solves `L x = rhs`, where `L` (`lower_tri_mat`) is lower triangular in CSC
/// storage and `rhs` is a sparse vector. Only the entries of the solution
/// that can be nonzero are computed.
///
/// The solve runs in two phases:
///
/// 1. **Reach (symbolic).** A depth-first search starts at every nonzero index
///    of `rhs` and walks the graph of `L`, with an edge `j -> i` for each
///    stored entry of column `j`. Each newly visited index is pushed onto the
///    right half of `dstack`, but only after all of its descendants have been
///    pushed. The right half therefore holds the reachable set in post-order.
/// 2. **Numeric solve.** The workspace is cleared on the reachable set and
///    `rhs` is scattered into it. Then one forward-substitution column step
///    is applied for each index, in the order `dstack.iter_right()` yields
///    them. The tests in this file rely on that order being
///    most-recently-pushed first, i.e. a topological order of the reach.
///
/// On success, the solution values are stored in `x_workspace` at the indices
/// listed by `dstack.iter_right()`. Provided `visited` is all `false` on
/// entry, entries of `x_workspace` outside that set are neither read nor
/// written. The workspace can therefore be reused across solves without a
/// full reset.
///
/// # Arguments
/// * `lower_tri_mat` - the lower triangular matrix `L`, in CSC storage
/// * `rhs` - the sparse right-hand side
/// * `dstack` - an empty double stack with capacity of at least `2 * n`. On
///   return, its right half lists the nonzero pattern of the solution.
/// * `x_workspace` - a dense vector of length `n` that receives the solution
/// * `visited` - a marker per row index. Every reached index is set to `true`
///   and left that way. Indices already `true` on entry are treated as
///   explored, so the caller presumably has to reset it between solves.
///
/// # Errors
/// Returns `LinalgError::SingularMatrix` when a reached column has a missing
/// or zero diagonal entry.
///
/// # Panics
/// * if `lower_tri_mat` is not in CSC storage
/// * if `dstack.capacity() < 2 * n` or `dstack` is not empty
/// * if `x_workspace.dim() != n`
/// * if a reached index is out of bounds for `visited`
pub fn lsolve_csc_sparse_rhs<N, I, Iptr, V>(
    lower_tri_mat: CsMatViewI<N, I, Iptr>,
    rhs: CsVecViewI<N, I>,
    dstack: &mut DStack<StackVal<usize>>,
    mut x_workspace: V,
    visited: &mut [bool],
) -> Result<(), LinalgError>
where
    N: Clone + Num + std::ops::SubAssign,
    for<'r> &'r N:
        std::ops::Mul<&'r N, Output = N> + std::ops::Div<&'r N, Output = N>,
    V: DenseVectorMut + DenseVector<Scalar = N>,
    I: SpIndex,
    Iptr: SpIndex,
{
    assert!(lower_tri_mat.is_csc(), "Storage mismatch");
    let n = lower_tri_mat.rows();
    assert!(dstack.capacity() >= 2 * n, "dstack cap should be 2*n");
    assert!(
        dstack.is_left_empty() && dstack.is_right_empty(),
        "dstack should be empty"
    );
    assert!(x_workspace.dim() == n, "x should be of len n");

    // Phase 1: reach of the rhs nonzeros in the graph of L.
    for (root_ind, _) in rhs.iter() {
        if visited[root_ind] {
            continue;
        }
        dstack.push_left(StackVal::Enter(root_ind));
        while let Some(stack_val) = dstack.pop_left() {
            match stack_val {
                StackVal::Enter(ind) => {
                    if visited[ind] {
                        continue;
                    }
                    visited[ind] = true;
                    // Pushed before the children, so it is popped after all
                    // of them: this yields the post-order on the right stack.
                    dstack.push_left(StackVal::Exit(ind));
                    if let Some(column) = lower_tri_mat.outer_view(ind) {
                        for (child_ind, _) in column.iter() {
                            dstack.push_left(StackVal::Enter(child_ind));
                        }
                    } else {
                        unreachable!();
                    }
                }
                StackVal::Exit(ind) => {
                    dstack.push_right(StackVal::Enter(ind));
                }
            }
        }
    }

    // Phase 2: numeric solve restricted to the reach.
    // Reached indices that are not nonzeros of `rhs` must start at zero.
    // Clear the whole reach so stale values left in a reused workspace cannot
    // leak into the result, then scatter the rhs nonzeros on top.
    for &ind in dstack.iter_right().map(stack::extract_stack_val) {
        *x_workspace.index_mut(ind) = N::zero();
    }
    rhs.scatter(&mut x_workspace);
    for &ind in dstack.iter_right().map(stack::extract_stack_val) {
        let col = lower_tri_mat.outer_view(ind).expect("ind not in bounds");
        lspsolve_csc_process_col(col, ind, &mut x_workspace)?;
    }
    Ok(())
}
#[cfg(test)]
mod test {
    use crate::sparse::{CsMat, CsVec};
    use crate::stack::{self, DStack, StackVal};
    use ndarray::arr1;
    use std::collections::HashSet;

    /// Gathers the `(index, value)` pairs produced by a sparse-rhs solve.
    /// The index pattern is read from the right half of `dstack` and the
    /// values from the dense workspace `xw`.
    fn collect_solution(
        dstack: &DStack<StackVal<usize>>,
        xw: &[i32],
    ) -> HashSet<(usize, i32)> {
        dstack
            .iter_right()
            .map(stack::extract_stack_val)
            .map(|&i| (i, xw[i]))
            .collect()
    }

    #[test]
    fn lsolve_csr_dense_rhs() {
        // | 1 0 0 |       | 3 |
        // | 0 2 0 | x  =  | 2 |
        // | 1 0 1 |       | 4 |
        let l = CsMat::new(
            (3, 3),
            vec![0, 1, 2, 4],
            vec![0, 1, 0, 2],
            vec![1, 2, 1, 1],
        );
        let mut x = arr1(&[3, 2, 4]);
        super::lsolve_csr_dense_rhs(l.view(), x.view_mut()).unwrap();
        assert_eq!(x, arr1(&[3, 1, 1]));
    }

    #[test]
    fn lsolve_csc_dense_rhs() {
        // | 1 0 0 |       | 3 |
        // | 1 2 0 | x  =  | 5 |
        // | 0 0 3 |       | 3 |
        let l = CsMat::new_csc(
            (3, 3),
            vec![0, 2, 3, 4],
            vec![0, 1, 1, 2],
            vec![1, 1, 2, 3],
        );

        // Owned vector right-hand side.
        let mut x = vec![3, 5, 3];
        super::lsolve_csc_dense_rhs(l.view(), &mut x).unwrap();
        assert_eq!(x, vec![3, 1, 1]);

        // Slice right-hand side.
        let y: &mut [i32] = &mut [3, 5, 3];
        super::lsolve_csc_dense_rhs(l.view(), &mut y[..]).unwrap();
        assert_eq!(y, &[3, 1, 1]);
    }

    #[test]
    fn usolve_csc_dense_rhs() {
        // | 1 0 1 |       | 4 |
        // | 0 2 0 | x  =  | 2 |
        // | 0 0 3 |       | 3 |
        let u = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 2, 4],
            vec![0, 1, 0, 2],
            vec![1, 2, 1, 3],
        );
        let mut x = vec![4, 2, 3];
        super::usolve_csc_dense_rhs(u.view(), &mut x).unwrap();
        assert_eq!(x, vec![3, 1, 1]);
    }

    #[test]
    fn usolve_csr_dense_rhs() {
        // | 1 1 0 |       | 4 |
        // | 0 5 3 | x  =  | 8 |
        // | 0 0 1 |       | 1 |
        let u = CsMat::new(
            (3, 3),
            vec![0, 2, 4, 5],
            vec![0, 1, 1, 2, 2],
            vec![1, 1, 5, 3, 1],
        );
        let mut x = vec![4, 8, 1];
        super::usolve_csr_dense_rhs(u.view(), &mut x).unwrap();
        assert_eq!(x, vec![3, 1, 1]);
    }

    #[test]
    fn lspsolve_csc() {
        let l = CsMat::new_csc(
            (5, 5),
            vec![0, 2, 5, 6, 8, 9],
            vec![0, 1, 1, 2, 4, 2, 3, 4, 4],
            vec![1, 1, 2, 3, 2, 3, 7, 3, 5],
        );
        let b = CsVec::new(5, vec![1, 2, 4], vec![4, 9, 9]);
        let mut xw = vec![1; 5];
        let mut visited = vec![false; 5];
        let mut dstack = DStack::with_capacity(2 * 5);
        super::lsolve_csc_sparse_rhs(
            l.view(),
            b.view(),
            &mut dstack,
            &mut xw,
            &mut visited,
        )
        .unwrap();
        let expected =
            CsVec::new(5, vec![1, 2, 4], vec![2, 1, 1]).to_set();
        assert_eq!(collect_solution(&dstack, &xw), expected);

        let l = CsMat::new_csc(
            (7, 7),
            vec![0, 2, 4, 6, 7, 9, 10, 11],
            vec![0, 2, 1, 6, 2, 5, 3, 4, 6, 5, 6],
            vec![1, 1, 2, 3, 3, 1, 7, 5, 2, 1, 2],
        );
        let b = CsVec::new(7, vec![0, 2, 3, 5], vec![1, 7, 7, 3]);
        let mut dstack = DStack::with_capacity(2 * 7);
        let mut xw = vec![1; 7];
        let mut visited = vec![false; 7];
        super::lsolve_csc_sparse_rhs(
            l.view(),
            b.view(),
            &mut dstack,
            &mut xw,
            &mut visited,
        )
        .unwrap();
        let expected =
            CsVec::new(7, vec![0, 2, 3, 5], vec![1, 2, 1, 1]).to_set();
        assert_eq!(collect_solution(&dstack, &xw), expected);
    }
}