/// Per-row linear operators: a β-block coupling `J_β` with its transpose, and a
/// local Jacobian `L` with its transpose.
///
/// Every method acts on a single `row`, and results are written into
/// caller-provided buffers. Whether a method overwrites or accumulates into its
/// output is up to the implementor. As implemented by `SaeKroneckerRows`:
/// - `apply_jbeta` and `apply_l` overwrite their outputs.
/// - `scatter_jbeta_t` and `apply_l_t` add into theirs.
pub trait SaeKroneckerRow {
    /// Applies the row's β coupling: reads `x_beta`, writes the result into `u_out`.
    fn apply_jbeta(&self, row: usize, x_beta: &[f64], u_out: &mut [f64]);

    /// Transpose of [`apply_jbeta`](Self::apply_jbeta): scatters `u` back into `y_beta`.
    fn scatter_jbeta_t(&self, row: usize, u: &[f64], y_beta: &mut [f64]);

    /// Applies the row's local operator `L` to `u`, writing into `w_out`.
    fn apply_l(&self, row: usize, u: &[f64], w_out: &mut [f64]);

    /// Applies `Lᵀ` to `v`, writing into `u_out`.
    fn apply_l_t(&self, row: usize, v: &[f64], u_out: &mut [f64]);
}
/// Row-wise operator data backing [`SaeKroneckerRow`].
///
/// Row `i` is described by two pieces:
/// - `a_phi[i]`: `(beta_base, phi)` pairs. Each pair couples the length-`p`
///   window `beta_base .. beta_base + p` of the β vector to the row with
///   scalar weight `phi`.
/// - `local_jac[i]`: a dense matrix with `p` columns stored row-major. Entry
///   `(c, j)` lives at index `c * p + j`, giving `local_jac[i].len() / p` rows.
#[derive(Debug, Clone)]
pub struct SaeKroneckerRows {
    /// Block width shared by every β window and every `local_jac` row.
    pub(crate) p: usize,
    /// Per-row `(beta_base, phi)` pairs; `beta_base` is an offset into the β vector.
    pub(crate) a_phi: Vec<Vec<(usize, f64)>>,
    /// Per-row local Jacobians, row-major with `p` columns.
    pub(crate) local_jac: Vec<Vec<f64>>,
}

impl SaeKroneckerRows {
    /// Bundles per-row β couplings and local Jacobians into one operator set.
    ///
    /// # Panics
    ///
    /// - If `a_phi` and `local_jac` have a different number of rows.
    /// - If a `local_jac` row's length is not a multiple of `p`. For `p == 0`,
    ///   only empty rows are accepted. The row operators derive the local row
    ///   count by integer division by `p`, so a ragged row would otherwise have
    ///   its trailing coefficients silently ignored.
    pub fn new(p: usize, a_phi: Vec<Vec<(usize, f64)>>, local_jac: Vec<Vec<f64>>) -> Self {
        assert_eq!(
            a_phi.len(),
            local_jac.len(),
            "SaeKroneckerRows: a_phi rows ({}) != local_jac rows ({})",
            a_phi.len(),
            local_jac.len(),
        );
        for (row, jac) in local_jac.iter().enumerate() {
            // `checked_rem` is `None` only when `p == 0`, where 0 is the sole multiple.
            let whole_blocks = jac
                .len()
                .checked_rem(p)
                .map_or(jac.is_empty(), |rem| rem == 0);
            assert!(
                whole_blocks,
                "SaeKroneckerRows: local_jac row {} has length {}, which is not a multiple of p ({})",
                row,
                jac.len(),
                p,
            );
        }
        Self {
            p,
            a_phi,
            local_jac,
        }
    }
}
impl SaeKroneckerRow for SaeKroneckerRows {
    /// Sets `u_out[j] = Σ_k phi_k · x_beta[base_k + j]` for `j < p`, over the
    /// `(base_k, phi_k)` pairs of `a_phi[row]`.
    ///
    /// All of `u_out` is zeroed first; only its leading `p` entries receive
    /// contributions. Pairs whose weight is exactly `0.0` are skipped, so their
    /// `x_beta` window is never read (a NaN or infinity there does not leak in).
    fn apply_jbeta(&self, row: usize, x_beta: &[f64], u_out: &mut [f64]) {
        u_out.fill(0.0);
        let p = self.p;
        let active = self.a_phi[row].iter().filter(|&&(_, phi)| phi != 0.0);
        for &(beta_base, phi) in active {
            for (j, slot) in u_out[..p].iter_mut().enumerate() {
                *slot += phi * x_beta[beta_base + j];
            }
        }
    }

    /// Adjoint of `apply_jbeta`: adds `phi_k · u[j]` into
    /// `y_beta[base_k + j]` for every nonzero-weight pair of `a_phi[row]`.
    ///
    /// `y_beta` is accumulated into, never cleared.
    fn scatter_jbeta_t(&self, row: usize, u: &[f64], y_beta: &mut [f64]) {
        let p = self.p;
        let active = self.a_phi[row].iter().filter(|&&(_, phi)| phi != 0.0);
        for &(beta_base, phi) in active {
            for (j, &u_j) in u[..p].iter().enumerate() {
                y_beta[beta_base + j] += phi * u_j;
            }
        }
    }

    /// Writes `w_out[c] = Σ_j local_jac[row][c·p + j] · u[j]` for each full
    /// length-`p` chunk `c` of the row's Jacobian.
    ///
    /// Entries of `w_out` past the chunk count are left untouched. A trailing
    /// partial chunk is ignored. Panics if `p == 0`.
    fn apply_l(&self, row: usize, u: &[f64], w_out: &mut [f64]) {
        let p = self.p;
        for (c, coeffs) in self.local_jac[row].chunks_exact(p).enumerate() {
            // Left-to-right accumulation from +0.0, matching a plain `acc +=` loop.
            w_out[c] = coeffs
                .iter()
                .zip(&u[..p])
                .fold(0.0_f64, |acc, (&a, &b)| acc + a * b);
        }
    }

    /// Adds `Σ_c local_jac[row][c·p + j] · v[c]` into `u_out[j]` for `j < p`.
    ///
    /// `u_out` is accumulated into, never cleared. Chunks whose weight `v[c]` is
    /// exactly `0.0` are skipped. A trailing partial chunk is ignored.
    /// Panics if `p == 0`.
    fn apply_l_t(&self, row: usize, v: &[f64], u_out: &mut [f64]) {
        let p = self.p;
        for (c, coeffs) in self.local_jac[row].chunks_exact(p).enumerate() {
            let weight = v[c];
            if weight != 0.0 {
                for (slot, &a) in u_out[..p].iter_mut().zip(coeffs) {
                    *slot += a * weight;
                }
            }
        }
    }
}