use crate::csc::CscMatrix;
use crate::ops::Op;
use crate::ops::serial::cs::{
spadd_cs_prealloc, spmm_cs_dense, spmm_cs_prealloc, spmm_cs_prealloc_unchecked,
};
use crate::ops::serial::{OperationError, OperationErrorKind};
use nalgebra::{ClosedAddAssign, ClosedMulAssign, DMatrixView, DMatrixViewMut, RealField, Scalar};
use num_traits::{One, Zero};
use std::borrow::Cow;
/// Public entry point for a product between a CSC matrix `a` and a dense matrix `b`,
/// written into the dense matrix `c`.
///
/// This function only performs argument conversion: `b` is converted into an
/// `Op<DMatrixView>` and `c` into a `DMatrixViewMut`, after which everything is
/// forwarded to the non-generic helper `spmm_csc_dense_`.
///
/// NOTE(review): judging by the parameter names this is the usual
/// `C <- beta * C + alpha * op(A) * op(B)` update; the arithmetic itself lives in
/// `spmm_cs_dense` and should be confirmed there.
pub fn spmm_csc_dense<'a, T>(
    beta: T,
    c: impl Into<DMatrixViewMut<'a, T>>,
    alpha: T,
    a: Op<&CscMatrix<T>>,
    b: Op<impl Into<DMatrixView<'a, T>>>,
) where
    T: Scalar + ClosedAddAssign + ClosedMulAssign + Zero + One,
{
    // Resolve the `impl Into<..>` arguments up front so the helper below is
    // instantiated once per scalar type rather than once per argument type.
    let dense_rhs = b.convert();
    let output: DMatrixViewMut<'a, T> = c.into();
    spmm_csc_dense_(beta, output, alpha, a, dense_rhs)
}
/// Concrete-view implementation behind `spmm_csc_dense`.
///
/// Checks the operand dimensions with `assert_compatible_spmm_dims!`, then hands
/// the underlying `cs` storage of `a` to the format-agnostic `spmm_cs_dense` kernel.
fn spmm_csc_dense_<T>(
    beta: T,
    c: DMatrixViewMut<'_, T>,
    alpha: T,
    a: Op<&CscMatrix<T>>,
    b: Op<DMatrixView<'_, T>>,
) where
    T: Scalar + ClosedAddAssign + ClosedMulAssign + Zero + One,
{
    assert_compatible_spmm_dims!(c, a, b);

    // The operation flag on `a` is flipped before its raw storage is passed on.
    // NOTE(review): presumably `spmm_cs_dense` reads the storage with the opposite
    // (CSR-style) major axis, so CSC storage looks like the transpose — confirm.
    let flipped = a.transposed();
    let a_storage = flipped.map_same_op(|matrix| &matrix.cs);

    spmm_cs_dense(beta, c, alpha, a_storage, b)
}
/// Sparse addition into a preallocated CSC matrix `c`, using `a` (optionally
/// transposed, as indicated by its `Op` wrapper) as the other operand.
///
/// The dimensions are checked with `assert_compatible_spadd_dims!`, after which the
/// underlying `cs` storages are forwarded to `spadd_cs_prealloc`, keeping the
/// operation flag of `a` unchanged.
///
/// NOTE(review): by the parameter names this looks like
/// `C <- beta * C + alpha * op(A)`; confirm against `spadd_cs_prealloc`.
///
/// # Errors
///
/// Returns whatever error `spadd_cs_prealloc` reports.
pub fn spadd_csc_prealloc<T>(
    beta: T,
    c: &mut CscMatrix<T>,
    alpha: T,
    a: Op<&CscMatrix<T>>,
) -> Result<(), OperationError>
where
    T: Scalar + ClosedAddAssign + ClosedMulAssign + Zero + One,
{
    assert_compatible_spadd_dims!(c, a);

    let a_storage = a.map_same_op(|matrix| &matrix.cs);
    spadd_cs_prealloc(beta, &mut c.cs, alpha, a_storage)
}
pub fn spmm_csc_prealloc<T>(
beta: T,
c: &mut CscMatrix<T>,
alpha: T,
a: Op<&CscMatrix<T>>,
b: Op<&CscMatrix<T>>,
) -> Result<(), OperationError>
where
T: Scalar + ClosedAddAssign + ClosedMulAssign + Zero + One,
{
assert_compatible_spmm_dims!(c, a, b);
use Op::NoOp;
match (&a, &b) {
(NoOp(a), NoOp(b)) => {
spmm_cs_prealloc(beta, &mut c.cs, alpha, &b.cs, &a.cs)
}
_ => spmm_csc_transposed(beta, c, alpha, a, b, spmm_csc_prealloc),
}
}
pub fn spmm_csc_prealloc_unchecked<T>(
beta: T,
c: &mut CscMatrix<T>,
alpha: T,
a: Op<&CscMatrix<T>>,
b: Op<&CscMatrix<T>>,
) -> Result<(), OperationError>
where
T: Scalar + ClosedAddAssign + ClosedMulAssign + Zero + One,
{
assert_compatible_spmm_dims!(c, a, b);
use Op::NoOp;
match (&a, &b) {
(NoOp(a), NoOp(b)) => {
spmm_cs_prealloc_unchecked(beta, &mut c.cs, alpha, &b.cs, &a.cs)
}
_ => spmm_csc_transposed(beta, c, alpha, a, b, spmm_csc_prealloc_unchecked),
}
}
/// Handles sparse-sparse products in which at least one operand is transposed.
///
/// Every operand wrapped in `Op::Transpose` is replaced by an explicitly computed
/// transposed copy; operands wrapped in `Op::NoOp` are borrowed as they are. The
/// supplied `spmm_kernel` is then invoked with both operands marked `NoOp`.
///
/// # Panics
///
/// Hits `unreachable!()` when both operands are `NoOp`; callers are expected to
/// handle that case themselves.
fn spmm_csc_transposed<T, F>(
    beta: T,
    c: &mut CscMatrix<T>,
    alpha: T,
    a: Op<&CscMatrix<T>>,
    b: Op<&CscMatrix<T>>,
    spmm_kernel: F,
) -> Result<(), OperationError>
where
    T: Scalar + ClosedAddAssign + ClosedMulAssign + Zero + One,
    F: Fn(
        T,
        &mut CscMatrix<T>,
        T,
        Op<&CscMatrix<T>>,
        Op<&CscMatrix<T>>,
    ) -> Result<(), OperationError>,
{
    use Op::{NoOp, Transpose};

    // This path exists only for transposed operands.
    if let (NoOp(_), NoOp(_)) = (&a, &b) {
        unreachable!()
    }

    // Materialize transposes where requested; `Cow` avoids copying operands that
    // are used as-is. `a` is processed before `b`.
    let lhs: Cow<'_, CscMatrix<T>> = match a {
        Transpose(matrix) => Cow::Owned(matrix.transpose()),
        NoOp(matrix) => Cow::Borrowed(matrix),
    };
    let rhs: Cow<'_, CscMatrix<T>> = match b {
        Transpose(matrix) => Cow::Owned(matrix.transpose()),
        NoOp(matrix) => Cow::Borrowed(matrix),
    };

    spmm_kernel(beta, c, alpha, NoOp(lhs.as_ref()), NoOp(rhs.as_ref()))
}
/// Solves the triangular system `op(L) X = B` in place, where `L` is a CSC matrix
/// and `b` holds the right-hand side on entry and the solution on success.
///
/// `Op::NoOp` selects forward substitution with `L`; `Op::Transpose` selects
/// backward substitution with the transpose of `L`. In both cases only the diagonal
/// entry of each column and the entries stored after it are read.
///
/// # Errors
///
/// Returns an error of kind `OperationErrorKind::Singular` if a diagonal entry is
/// missing from the stored pattern or is equal to zero. Columns of `b` processed
/// before the failure have already been overwritten.
///
/// # Panics
///
/// Panics if `l` is not square, or if its number of rows differs from that of `b`.
pub fn spsolve_csc_lower_triangular<'a, T: RealField>(
    l: Op<&CscMatrix<T>>,
    b: impl Into<DMatrixViewMut<'a, T>>,
) -> Result<(), OperationError> {
    let rhs = b.into();
    let factor: &CscMatrix<T> = l.inner_ref();

    assert_eq!(
        factor.nrows(),
        factor.ncols(),
        "Matrix must be square for triangular solve."
    );
    assert_eq!(
        factor.nrows(),
        rhs.nrows(),
        "Dimension mismatch in sparse lower triangular solver."
    );

    match l {
        Op::Transpose(matrix) => spsolve_csc_lower_triangular_transpose(matrix, rhs),
        Op::NoOp(matrix) => spsolve_csc_lower_triangular_no_transpose(matrix, rhs),
    }
}
/// Forward substitution for `L X = B`, overwriting `b` with the solution one
/// right-hand-side column at a time.
///
/// For each column `k` of `l`, the diagonal entry is located by a linear scan of
/// the column's row indices. The unknown `x[k]` is divided by it and then eliminated
/// from the rows stored after the diagonal entry. Entries stored before the
/// diagonal entry are never read.
///
/// Returns a singular-matrix error as soon as a diagonal entry is absent or zero;
/// `b` may be partially overwritten at that point.
fn spsolve_csc_lower_triangular_no_transpose<T: RealField>(
    l: &CscMatrix<T>,
    b: DMatrixViewMut<'_, T>,
) -> Result<(), OperationError> {
    let mut solution = b;
    let rhs_count = solution.ncols();

    for rhs_index in 0..rhs_count {
        let mut unknowns = solution.column_mut(rhs_index);

        for col in 0..l.ncols() {
            let lane = l.col(col);
            let rows = lane.row_indices();
            let vals = lane.values();

            // Locate the diagonal entry; a structurally missing diagonal is treated
            // the same way as an explicit zero.
            let diag_pos = match rows.iter().position(|&row| row == col) {
                Some(pos) => pos,
                None => return spsolve_encountered_zero_diagonal(),
            };
            let pivot = vals[diag_pos].clone();
            if pivot == T::zero() {
                return spsolve_encountered_zero_diagonal();
            }

            // Finalize x[col], keeping an owned copy so the column can be mutably
            // indexed while eliminating.
            unknowns[col] /= pivot;
            let solved = unknowns[col].clone();

            // Eliminate x[col] from every row stored after the diagonal entry.
            let trailing_rows = &rows[(diag_pos + 1)..];
            let trailing_vals = &vals[(diag_pos + 1)..];
            for (&row, coeff) in trailing_rows.iter().zip(trailing_vals) {
                unknowns[row] -= coeff.clone() * solved.clone();
            }

            // Store the solved value again so x[col] is final regardless of what the
            // elimination loop touched.
            unknowns[col] = solved;
        }
    }

    Ok(())
}
fn spsolve_encountered_zero_diagonal() -> Result<(), OperationError> {
let message = "Matrix contains at least one diagonal entry that is zero.";
Err(OperationError::from_kind_and_message(
OperationErrorKind::Singular,
String::from(message),
))
}
/// Backward substitution for `L^T X = B`, overwriting `b` with the solution one
/// right-hand-side column at a time.
///
/// Column `i` of `l` is used as row `i` of the transposed system, so the unknowns
/// are resolved from the last index down to the first. For each `i`, the products
/// of the entries stored after the diagonal entry with the corresponding unknowns
/// are subtracted from `x[i]`, and the result is divided by the diagonal entry.
/// Entries stored before the diagonal entry are never read.
///
/// Returns a singular-matrix error as soon as a diagonal entry is absent or zero;
/// `b` may be partially overwritten at that point.
fn spsolve_csc_lower_triangular_transpose<T: RealField>(
    l: &CscMatrix<T>,
    b: DMatrixViewMut<'_, T>,
) -> Result<(), OperationError> {
    let mut solution = b;
    let rhs_count = solution.ncols();

    for rhs_index in 0..rhs_count {
        let mut unknowns = solution.column_mut(rhs_index);

        for row in (0..l.ncols()).rev() {
            let lane = l.col(row);
            let indices = lane.row_indices();
            let vals = lane.values();

            // Locate the diagonal entry; a structurally missing diagonal is treated
            // the same way as an explicit zero.
            let diag_pos = match indices.iter().position(|&idx| row == idx) {
                Some(pos) => pos,
                None => return spsolve_encountered_zero_diagonal(),
            };
            let pivot = vals[diag_pos].clone();
            if pivot == T::zero() {
                return spsolve_encountered_zero_diagonal();
            }

            // Accumulate into a local so the column is only read inside the loop.
            let mut accumulated = unknowns[row].clone();

            let trailing_indices = &indices[(diag_pos + 1)..];
            let trailing_vals = &vals[(diag_pos + 1)..];
            for (&idx, coeff) in trailing_indices.iter().zip(trailing_vals) {
                let known = unknowns[idx].clone();
                accumulated -= coeff.clone() * known;
            }

            unknowns[row] = accumulated / pivot;
        }
    }

    Ok(())
}