use std::ops::Mul;
use na::{DMatrix, DMatrixView, DVector, DVectorView, DVectorViewMut};
use crate::csv::CsVecRef;
use crate::{
CscMatrixView, CsrMatrix, CsrMatrixView, CsrMatrixViewMethods, DiagonalBlockMatrixView, Real,
};
/// Multiplies a CSR matrix `a` by a block-diagonal matrix `b`, adding one row to
/// the (initially empty) CSR matrix `o` for every row of `a`.
///
/// The stored entries of each row of `a` are consumed block by block: for block
/// `k`, the entries whose column index lies below the end of the block's row
/// range are the only ones multiplied against that block's columns. This
/// assumes the column indices of a row are sorted ascending and that the blocks
/// tile the columns of `a` contiguously (TODO confirm these are guaranteed
/// invariants of the input types).
///
/// Products whose magnitude does not exceed `T::zero_threshold()` are not stored.
///
/// # Panics
/// Panics if `a.ncols() != b.nrows()`, if `b.ncols() != o.ncols()`, or if `o`
/// already has rows.
pub(crate) fn mul_csr_bd_to<T>(
    a: CsrMatrixView<T>,
    b: DiagonalBlockMatrixView<T>,
    o: &mut CsrMatrix<T>,
) where
    T: Real,
{
    assert_eq!(a.ncols(), b.nrows());
    assert_eq!(b.ncols(), o.ncols());
    assert_eq!(o.nrows(), 0);
    for row_idx in 0..a.nrows() {
        let mut row_out = o.new_row_builder(T::zero_threshold());
        let a_row = a.get_row(row_idx);
        // Number of stored entries of `a_row` already handled by earlier blocks.
        let mut consumed = 0;
        for block_idx in 0..b.num_blocks() {
            let span = b.get_block_row_range(block_idx);
            let block = b.view_block(block_idx);
            // Entries (after the consumed ones) whose column is inside this block.
            let in_block = a_row
                .indices()
                .iter()
                .skip(consumed)
                .take_while(|&&c| c < span.end)
                .count();
            for local_col in 0..block.ncols() {
                let block_col = block.column(local_col);
                let mut acc = T::zero();
                for (k, a_val) in a_row.iter().skip(consumed).take(in_block) {
                    // `k` is a global column of `a`, i.e. a global row of `b`;
                    // shift it into the block's local row coordinates.
                    acc += a_val * block_col[k - span.start];
                }
                if acc.abs() > T::zero_threshold() {
                    row_out.push(span.start + local_col, acc);
                }
            }
            consumed += in_block;
        }
    }
}
/// Multiplies a CSR matrix `a` by a CSC matrix `b`, adding one row to the
/// (initially empty) CSR matrix `o` for every row of `a`.
///
/// Output entry `(i, j)` is the sparse dot product of row `i` of `a` with column
/// `j` of `b` (see `dot_csvec`). Values whose magnitude does not exceed
/// `T::zero_threshold()` are not stored.
///
/// # Panics
/// Panics if `o` already has rows, if `a.ncols() != b.nrows()`, or if
/// `b.ncols() != o.ncols()`.
pub(crate) fn mul_csr_csc_to<T: Real>(
    a: CsrMatrixView<T>,
    b: CscMatrixView<T>,
    o: &mut CsrMatrix<T>,
) {
    assert_eq!(o.nrows(), 0);
    assert_eq!(a.ncols(), b.nrows());
    // Column indices pushed below range over `0..b.ncols()`, so the output must be
    // exactly that wide; otherwise out-of-range columns could be stored.
    assert_eq!(b.ncols(), o.ncols());
    for i in 0..a.nrows() {
        let mut or = o.new_row_builder(T::zero_threshold());
        let ar = a.get_row(i);
        for j in 0..b.ncols() {
            let bc = b.get_col(j);
            let o_ij = dot_csvec(ar, bc);
            if o_ij.abs() > T::zero_threshold() {
                or.push(j, o_ij);
            }
        }
    }
}
/// Computes `o = a * b` for a CSR matrix `a` and a dense vector `b`.
///
/// Every entry of `o` is overwritten: `o[i]` becomes the sum of
/// `a[i, j] * b[j]` over the stored entries of row `i`.
///
/// # Panics
/// Panics if `a.nrows() != o.len()`, if `a.ncols() != b.len()`, or if `b` is
/// empty.
#[inline]
pub(crate) fn mul_csr_dvec_to_dvec<T: Real>(
    a: CsrMatrixView<T>,
    b: DVectorView<T>,
    mut o: DVectorViewMut<T>,
) {
    assert_eq!(a.nrows(), o.len());
    // Dimension check matching `mul_csc_dvec_to_dvec`; without it a `b` longer
    // than `a.ncols()` would be silently accepted.
    assert_eq!(a.ncols(), b.len());
    assert!(!b.is_empty());
    for i in 0..a.nrows() {
        let ar = a.get_row(i);
        let mut o_i = T::zero();
        for (j, a_ij) in ar.iter() {
            o_i += a_ij * b[j];
        }
        o[i] = o_i;
    }
}
/// Accumulates the product of a CSC matrix `a` and a dense vector `b` into `o`.
///
/// For each stored entry `a[i, j]`, `a[i, j] * b[j]` is added to `o[i]`. The
/// caller is expected to pass a zeroed `o`. This is checked only in debug
/// builds; in release builds any existing contents of `o` are added to.
///
/// # Panics
/// Panics if `a.ncols() != b.len()`, if `o.len() != a.nrows()`, or if `b` is
/// empty.
pub(crate) fn mul_csc_dvec_to_dvec<T>(
    a: CscMatrixView<T>,
    b: DVectorView<T>,
    mut o: DVectorViewMut<T>,
) where
    T: Real,
{
    assert_eq!(a.ncols(), b.len());
    assert_eq!(o.len(), a.nrows());
    assert!(!b.is_empty());
    debug_assert!(o.iter().all(|v| v.abs() < T::zero_threshold()));
    // `b.len() == a.ncols()` was asserted, so enumerating `b` visits every column.
    for (col_idx, b_entry) in b.iter().enumerate() {
        let b_j = *b_entry;
        let column = a.get_col(col_idx);
        for (row_idx, a_ij) in column.iter() {
            o[row_idx] += a_ij * b_j;
        }
    }
}
/// Multiplies a CSR matrix `a` by a dense matrix `b`, adding one row to the
/// (initially empty) CSR matrix `o` for every row of `a`.
///
/// Output entry `(i, j)` is the dot product of sparse row `i` of `a` with dense
/// column `j` of `b` (see `dot_csv_dv`). Values whose magnitude does not exceed
/// `T::zero_threshold()` are not stored.
///
/// # Panics
/// Panics if `a.ncols() != b.nrows()`, if `o.ncols() != b.ncols()`, or if `o`
/// already has rows.
pub(crate) fn mul_csr_dmat_to_csr<T>(a: CsrMatrixView<T>, b: DMatrixView<T>, o: &mut CsrMatrix<T>)
where
    T: Real,
{
    assert_eq!(a.ncols(), b.nrows());
    assert_eq!(o.ncols(), b.ncols());
    assert_eq!(o.nrows(), 0);
    let out_cols = b.ncols();
    for row_idx in 0..a.nrows() {
        let sparse_row = a.get_row(row_idx);
        let mut builder = o.new_row_builder(T::zero_threshold());
        for col_idx in 0..out_cols {
            let value = dot_csv_dv(sparse_row, b.column(col_idx));
            if value.abs() > T::zero_threshold() {
                builder.push(col_idx, value);
            }
        }
    }
}
/// Computes the dot product of two sparse vectors.
///
/// The two index arrays are walked in lockstep as a merge, advancing whichever
/// side has the smaller index. Only indices stored in both vectors contribute.
/// The cost is `O(nnz(a) + nnz(b))`. The merge assumes each vector's indices are
/// sorted ascending (presumably a `CsVecRef` invariant; confirm).
///
/// # Panics
/// Panics if either vector stores fewer values than indices.
pub(crate) fn dot_csvec<T: Real>(a: CsVecRef<T>, b: CsVecRef<T>) -> T {
    let col_a = a.indices();
    let col_b = b.indices();
    // Trim the value slices to the index lengths with one checked operation. This
    // makes the unchecked reads below provably in bounds from this function alone,
    // rather than relying on an unchecked external invariant (a mismatch was
    // previously undefined behaviour; it now panics here).
    let values_a = &a.values()[..col_a.len()];
    let values_b = &b.values()[..col_b.len()];
    let mut res = T::zero();
    let mut ia = 0;
    let mut ib = 0;
    while ia < col_a.len() && ib < col_b.len() {
        // SAFETY: the loop condition gives `ia < col_a.len()` and `ib < col_b.len()`.
        let (ca, cb) = unsafe { (*col_a.get_unchecked(ia), *col_b.get_unchecked(ib)) };
        match ca.cmp(&cb) {
            std::cmp::Ordering::Less => {
                ia += 1;
            }
            std::cmp::Ordering::Equal => {
                // SAFETY: `values_a.len() == col_a.len()` and
                // `values_b.len() == col_b.len()` by the slicing above, and the
                // loop condition bounds `ia`/`ib` by those lengths.
                res += unsafe { *values_a.get_unchecked(ia) * *values_b.get_unchecked(ib) };
                ia += 1;
                ib += 1;
            }
            std::cmp::Ordering::Greater => {
                ib += 1;
            }
        }
    }
    res
}
/// Dot product of a sparse vector `a` with a dense vector `b`.
///
/// Only the `(index, value)` pairs yielded by `a.iter()` are visited. Each pair
/// contributes `value * b[index]`.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
pub(crate) fn dot_csv_dv<T>(a: CsVecRef<T>, b: DVectorView<T>) -> T
where
    T: Real,
{
    assert_eq!(a.len(), b.len());
    let mut acc = T::zero();
    a.iter().for_each(|(idx, a_val)| acc += a_val * b[idx]);
    acc
}
/// Writes the element-wise sum of a sparse vector `a` and a dense vector `d`
/// into `o`, i.e. `o = d + a`.
///
/// `o` is first overwritten with `d`. Then each stored entry `(i, a_i)` of `a` is
/// added to `o[i]`. Positions not stored in `a` keep the value from `d`.
///
/// # Panics
/// Panics if `a`, `d` and `o` do not all have the same length.
pub(crate) fn add_csv_dv<T>(a: CsVecRef<T>, d: DVectorView<T>, mut o: DVectorViewMut<T>)
where
    T: Real,
{
    assert_eq!(a.len(), d.len());
    assert_eq!(a.len(), o.len());
    o.copy_from(&d);
    for (i, a_i) in a.iter() {
        // `o[i]` already holds `d[i]`; adding `a_i` yields the sum. Scaling by
        // `d[i]` here would compute `d[i] * (1 + a_i)` instead.
        o[i] += a_i;
    }
}
/// Computes `o = a * b` for a block-diagonal matrix `a` and a dense vector `b`.
///
/// Blocks are read straight out of `a.values()`. Block `k` of size `s` is taken
/// as the `s * s` consecutive values that follow the preceding blocks. They are
/// interpreted by `DMatrixView::from_slice`, which is column-major. Each block
/// multiplies the matching `s`-long segment of `b` and overwrites the matching
/// segment of `o`.
///
/// # Panics
/// Panics if `a.ncols() != b.len()` or `a.nrows() != o.len()`.
pub(crate) fn mul_bd_vec<T>(
    a: DiagonalBlockMatrixView<T>,
    b: DVectorView<T>,
    mut o: DVectorViewMut<T>,
) where
    T: Real,
{
    assert_eq!(a.ncols(), b.len());
    assert_eq!(a.nrows(), o.len());
    let values = a.values();
    let (mut value_start, mut row_start) = (0, 0);
    for k in 0..a.num_blocks() {
        let size = a.get_block_size(k);
        let value_end = value_start + size * size;
        let rows = row_start..row_start + size;
        let block = DMatrixView::from_slice(&values[value_start..value_end], size, size);
        let mut o_segment = o.rows_range_mut(rows.clone());
        let b_segment = b.rows_range(rows);
        block.mul_to(&b_segment, &mut o_segment);
        value_start = value_end;
        row_start += size;
    }
}
/// Updates the diagonal of the square CSR matrix `o` in place:
/// `o[i, i] = o[i, i] * diag_scale[i] + diag_add[i]` for every row `i`.
///
/// Only the diagonal entries are written; every other stored value is left
/// unchanged. The diagonal entry is located with a binary search over the row's
/// column indices, so those indices must be sorted ascending.
///
/// # Panics
/// Panics if `o` is not square, if `diag_scale` or `diag_add` has a length
/// other than `o.nrows()`, or if some row does not store its diagonal entry.
pub fn mul_add_diag_to_csr<T: Real>(
    o: &mut CsrMatrix<T>,
    diag_scale: DVectorView<T>,
    diag_add: DVectorView<T>,
) {
    assert_eq!(o.ncols(), o.nrows());
    assert_eq!(o.nrows(), diag_scale.len());
    assert_eq!(o.nrows(), diag_add.len());
    let n = o.nrows();
    for row_idx in 0..n {
        let scale = diag_scale[row_idx];
        let shift = diag_add[row_idx];
        let row = o.get_row_mut(row_idx);
        let pos = row
            .col_indices
            .binary_search(&row_idx)
            .expect("Diagonal element must be present in CSR row");
        row.values[pos] = row.values[pos] * scale + shift;
    }
}
/// `block_diagonal * vector_view`: allocates a zero vector with `self.nrows()`
/// entries and fills it via `mul_bd_vec`.
impl<'a, T: Real> Mul<DVectorView<'a, T>> for DiagonalBlockMatrixView<'a, T> {
    type Output = DVector<T>;

    #[inline]
    fn mul(self, rhs: DVectorView<'a, T>) -> DVector<T> {
        let rows = self.nrows();
        let mut product = DVector::zeros(rows);
        mul_bd_vec(self, rhs, product.as_view_mut());
        product
    }
}
/// `block_diagonal * owned_vector`: allocates a zero vector with `self.nrows()`
/// entries and fills it via `mul_bd_vec`, viewing the owned right-hand side.
impl<'a, T: Real> Mul<DVector<T>> for DiagonalBlockMatrixView<'a, T> {
    type Output = DVector<T>;

    #[inline]
    fn mul(self, rhs: DVector<T>) -> DVector<T> {
        let rows = self.nrows();
        let mut product = DVector::zeros(rows);
        let rhs_view = rhs.as_view();
        mul_bd_vec(self, rhs_view, product.as_view_mut());
        product
    }
}
/// `csr * block_diagonal`: starts from `CsrMatrix::new(rhs.ncols())` and lets
/// `mul_csr_bd_to` add the product rows.
impl<'a, T: Real> Mul<DiagonalBlockMatrixView<'a, T>> for CsrMatrixView<'a, T> {
    type Output = CsrMatrix<T>;

    fn mul(self, rhs: DiagonalBlockMatrixView<'a, T>) -> Self::Output {
        let width = rhs.ncols();
        let mut product = CsrMatrix::new(width);
        mul_csr_bd_to(self, rhs, &mut product);
        product
    }
}
/// `csr * csc`: starts from `CsrMatrix::new(rhs.ncols())` and lets
/// `mul_csr_csc_to` add the product rows.
impl<'a, T: Real> Mul<CscMatrixView<'a, T>> for CsrMatrixView<'a, T> {
    type Output = CsrMatrix<T>;

    fn mul(self, rhs: CscMatrixView<'a, T>) -> Self::Output {
        let width = rhs.ncols();
        let mut product = CsrMatrix::new(width);
        mul_csr_csc_to(self, rhs, &mut product);
        product
    }
}
/// `csr * dense_matrix`: starts from `CsrMatrix::new(rhs.ncols())` and lets
/// `mul_csr_dmat_to_csr` add the product rows, so the result stays sparse.
impl<'a, T: Real> Mul<DMatrixView<'a, T>> for CsrMatrixView<'a, T> {
    type Output = CsrMatrix<T>;

    fn mul(self, rhs: DMatrixView<'a, T>) -> Self::Output {
        let width = rhs.ncols();
        let mut product = CsrMatrix::new(width);
        mul_csr_dmat_to_csr(self, rhs, &mut product);
        product
    }
}
/// `csr * vector_view`: allocates a zero vector with `self.nrows()` entries and
/// fills it via `mul_csr_dvec_to_dvec`.
impl<'a, T: Real> Mul<DVectorView<'a, T>> for CsrMatrixView<'a, T> {
    type Output = DVector<T>;

    #[inline]
    fn mul(self, rhs: DVectorView<'a, T>) -> Self::Output {
        let rows = self.nrows();
        let mut product = DVector::zeros(rows);
        mul_csr_dvec_to_dvec(self, rhs, product.as_view_mut());
        product
    }
}
/// `csr * owned_vector`: borrows the owned vector as a view and delegates to the
/// `Mul<DVectorView>` implementation.
impl<'a, T: Real> Mul<DVector<T>> for CsrMatrixView<'a, T> {
    type Output = DVector<T>;

    #[inline]
    fn mul(self, rhs: DVector<T>) -> Self::Output {
        // The explicit type pins `as_view`'s generic dimensions/strides so the
        // right `Mul` impl is selected.
        let view: DVectorView<'_, T> = rhs.as_view();
        self * view
    }
}
/// `csc * vector_view`: allocates a zero vector with `self.nrows()` entries. The
/// zeros satisfy the precondition of `mul_csc_dvec_to_dvec`, which then
/// accumulates the product into it.
impl<'a, T: Real> Mul<DVectorView<'a, T>> for CscMatrixView<'a, T> {
    type Output = DVector<T>;

    #[inline]
    fn mul(self, rhs: DVectorView<'a, T>) -> Self::Output {
        let rows = self.nrows();
        let mut product = DVector::zeros(rows);
        mul_csc_dvec_to_dvec(self, rhs, product.as_view_mut());
        product
    }
}
/// `block_diagonal * dense_matrix`: each diagonal block multiplies the band of
/// rows of `rhs` covered by the block's row range and writes the same band of
/// the result.
impl<'a, T: Real> Mul<DMatrixView<'a, T>> for DiagonalBlockMatrixView<'a, T> {
    type Output = DMatrix<T>;

    /// # Panics
    /// Panics if `self.ncols() != rhs.nrows()`.
    fn mul(self, rhs: DMatrixView<'a, T>) -> Self::Output {
        // Without this check, rows of `rhs` beyond the last block would be silently
        // ignored (the vector counterpart `mul_bd_vec` performs the same check).
        assert_eq!(self.ncols(), rhs.nrows());
        let mut result = DMatrix::zeros(self.nrows(), rhs.ncols());
        for bindex in 0..self.num_blocks() {
            let block = self.view_block(bindex);
            let range = self.get_block_row_range(bindex);
            let mut output = result.rows_range_mut(range.clone());
            let rhs = rhs.rows_range(range);
            block.mul_to(&rhs, &mut output);
        }
        result
    }
}
mod add {
    use std::ops::Add;

    use super::*;

    /// `sparse + dense`: allocates a vector of length `self.len()` and lets
    /// `add_csv_dv` write the element-wise sum into it.
    impl<'a, T: Real> Add<DVectorView<'a, T>> for CsVecRef<'a, T> {
        type Output = DVector<T>;

        #[inline]
        fn add(self, rhs: DVectorView<'a, T>) -> Self::Output {
            let len = self.len();
            // The zeros are only placeholders; `add_csv_dv` overwrites them.
            let mut sum = DVector::zeros(len);
            add_csv_dv(self, rhs, sum.as_view_mut());
            sum
        }
    }
}