#[allow(unused_imports)]
use crate::algebra::blas::{dot_conj, nrm2};
#[allow(unused_imports)]
use crate::algebra::prelude::*;
use crate::matrix::sparse::CsrMatrix;
/// Sparsity pattern (structure only, no values) of a matrix in compressed
/// sparse row (CSR) layout.
///
/// Row `i` owns the column indices `col_idx[row_ptr[i]..row_ptr[i + 1]]`, so
/// `row_ptr` has `nrows + 1` entries and its last entry equals
/// `col_idx.len()`. Patterns produced by [`rap_symbolic`] store each row's
/// column indices sorted ascending and without duplicates.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CsrPattern {
    /// Number of rows.
    pub nrows: usize,
    /// Number of columns.
    pub ncols: usize,
    /// Row offsets into `col_idx`; length `nrows + 1`, starting at 0.
    pub row_ptr: Vec<usize>,
    /// Column index of every stored entry, grouped row by row.
    pub col_idx: Vec<usize>,
}
/// Computes the sparsity pattern of the triple product `R * A * P`.
///
/// For every row `i` of `r`, the output row contains each column `c` of `p`
/// reachable through some path `r[i, k] -> a[k, j] -> p[j, c]`. Only the
/// stored structure of the inputs is inspected; their values (including
/// explicitly stored zeros) are ignored.
///
/// The column indices within each output row are sorted ascending and contain
/// no duplicates, which [`rap_numeric`] relies on for its per-row binary
/// search.
///
/// # Panics
///
/// Panics if the operand shapes are incompatible, i.e. if
/// `r.ncols() != a.nrows()` or `a.ncols() != p.nrows()`.
pub fn rap_symbolic(r: &CsrMatrix<f64>, a: &CsrMatrix<f64>, p: &CsrMatrix<f64>) -> CsrPattern {
    assert_eq!(
        r.ncols(),
        a.nrows(),
        "rap_symbolic: r.ncols() must equal a.nrows()"
    );
    assert_eq!(
        a.ncols(),
        p.nrows(),
        "rap_symbolic: a.ncols() must equal p.nrows()"
    );

    let nc = r.nrows();
    let ncols_out = p.ncols();
    let rp_r = r.row_ptr();
    let cj_r = r.col_idx();
    let rp_a = a.row_ptr();
    let cj_a = a.col_idx();
    let rp_p = p.row_ptr();
    let cj_p = p.col_idx();

    let mut row_ptr = Vec::with_capacity(nc + 1);
    let mut col_idx: Vec<usize> = Vec::new();
    row_ptr.push(0);

    // `marker[c] == i` means column `c` has already been recorded for output
    // row `i`. Deduplicating on the fly keeps the scratch buffer bounded by the
    // row's final length instead of the (possibly much larger) number of
    // R-A-P paths, and avoids sorting duplicates.
    let mut marker = vec![usize::MAX; ncols_out];
    // Scratch buffer for the current row's columns, reused across rows to
    // avoid one allocation per row.
    let mut cols: Vec<usize> = Vec::new();

    for i in 0..nc {
        cols.clear();
        for &k in &cj_r[rp_r[i]..rp_r[i + 1]] {
            for &j in &cj_a[rp_a[k]..rp_a[k + 1]] {
                for &c in &cj_p[rp_p[j]..rp_p[j + 1]] {
                    if marker[c] != i {
                        marker[c] = i;
                        cols.push(c);
                    }
                }
            }
        }
        // Rows must be sorted so `rap_numeric` can binary-search them.
        cols.sort_unstable();
        col_idx.extend_from_slice(&cols);
        row_ptr.push(col_idx.len());
    }

    CsrPattern {
        nrows: nc,
        ncols: ncols_out,
        row_ptr,
        col_idx,
    }
}
/// Computes the values of `R * A * P` into a precomputed sparsity pattern.
///
/// `pat` is expected to come from [`rap_symbolic`] applied to matrices with
/// the same structure as `r`, `a` and `p`. Each row of `pat.col_idx` must be
/// sorted ascending, because entries are located by binary search.
/// `out_vals[q]` receives the value of the entry whose column is
/// `pat.col_idx[q]`, using the same row grouping as `pat.row_ptr`.
///
/// `out_vals` is zeroed before accumulation, so the call fully overwrites it.
/// Contributions whose column does not appear in the corresponding pattern
/// row are silently dropped.
///
/// # Panics
///
/// Panics if any of the following holds:
/// - `out_vals.len() != pat.col_idx.len()`;
/// - `pat.row_ptr.len() != pat.nrows + 1`;
/// - `pat.nrows != r.nrows()`;
/// - `r.ncols() != a.nrows()`;
/// - `a.ncols() != p.nrows()`.
pub fn rap_numeric(
    pat: &CsrPattern,
    r: &CsrMatrix<f64>,
    a: &CsrMatrix<f64>,
    p: &CsrMatrix<f64>,
    out_vals: &mut [f64],
) {
    assert_eq!(out_vals.len(), pat.col_idx.len());
    assert_eq!(
        pat.row_ptr.len(),
        pat.nrows + 1,
        "rap_numeric: pattern row_ptr must have nrows + 1 entries"
    );
    assert_eq!(
        pat.nrows,
        r.nrows(),
        "rap_numeric: pattern row count must equal r.nrows()"
    );
    assert_eq!(
        r.ncols(),
        a.nrows(),
        "rap_numeric: r.ncols() must equal a.nrows()"
    );
    assert_eq!(
        a.ncols(),
        p.nrows(),
        "rap_numeric: a.ncols() must equal p.nrows()"
    );

    out_vals.fill(0.0);

    let rp_r = r.row_ptr();
    let cj_r = r.col_idx();
    let vv_r = r.values();
    let rp_a = a.row_ptr();
    let cj_a = a.col_idx();
    let vv_a = a.values();
    let rp_p = p.row_ptr();
    let cj_p = p.col_idx();
    let vv_p = p.values();
    let pr = &pat.row_ptr;
    let pc = &pat.col_idx;

    for i in 0..pat.nrows {
        let row_start = pr[i];
        let row_end = pr[i + 1];
        if row_start == row_end {
            continue;
        }
        // Borrow the pattern row and accumulate directly into the matching
        // (already zeroed) slice of `out_vals`; no per-row copies needed.
        let cols = &pc[row_start..row_end];
        let vals = &mut out_vals[row_start..row_end];

        for rpos in rp_r[i]..rp_r[i + 1] {
            let k = cj_r[rpos];
            let r_ik = vv_r[rpos];
            for apos in rp_a[k]..rp_a[k + 1] {
                let j = cj_a[apos];
                let a_kj = vv_a[apos];
                for ppos in rp_p[j]..rp_p[j + 1] {
                    // Entries missing from the pattern row are dropped on purpose
                    // (see the doc comment above).
                    if let Ok(idx) = cols.binary_search(&cj_p[ppos]) {
                        vals[idx] += r_ik * a_kj * vv_p[ppos];
                    }
                }
            }
        }
    }
}