/// Per-row operator interface: every method takes a `row` index, reads one
/// input slice and writes into a caller-provided output slice.
///
/// The method notes below describe what the `SaeKroneckerRows` implementation
/// in this file does. NOTE(review): other implementors are presumably expected
/// to follow the same overwrite/accumulate conventions — confirm.
pub trait SaeKroneckerRow {
/// Forward "J-beta" product for `row`: reads `x_beta`, writes `u_out`.
///
/// In `SaeKroneckerRows`, `u_out` is zero-filled first and then receives the
/// sum of `phi * x_beta[base + j]` over the row's `(base, phi)` entries, so
/// any previous contents of `u_out` are discarded.
fn apply_jbeta(&self, row: usize, x_beta: &[f64], u_out: &mut [f64]);
/// Transposed counterpart of `apply_jbeta`: reads `u`, updates `y_beta`.
///
/// In `SaeKroneckerRows` this *accumulates* (`+=`) `phi * u[j]` into
/// `y_beta[base + j]`; `y_beta` is not cleared beforehand.
fn scatter_jbeta_t(&self, row: usize, u: &[f64], y_beta: &mut [f64]);
/// Applies the row's local Jacobian to `u`, writing the result to `w_out`.
///
/// In `SaeKroneckerRows`, `w_out[c]` is overwritten with the dot product of
/// the `c`-th length-`p` slice of the flattened Jacobian with `u`.
fn apply_l(&self, row: usize, u: &[f64], w_out: &mut [f64]);
/// Transposed counterpart of `apply_l`: reads `v`, updates `u_out`.
///
/// In `SaeKroneckerRows` this *accumulates* (`+=`) into `u_out`; the caller
/// must clear `u_out` first if a fresh result is wanted.
fn apply_l_t(&self, row: usize, v: &[f64], u_out: &mut [f64]);
}
/// Host-side per-row tables backing the `SaeKroneckerRow` implementation in
/// this file.
///
/// Both tables live behind `Arc<[..]>`, so cloning this struct only bumps
/// reference counts and the row data stays shared with every other holder of
/// the same `Arc`s.
#[derive(Debug, Clone)]
pub struct SaeKroneckerRows {
/// Block width: the number of consecutive `f64` values addressed per `a_phi`
/// entry, and the length of each row inside a flattened `local_jac` buffer.
pub(crate) p: usize,
/// One list per row of `(beta_base, phi)` pairs: `beta_base` is the starting
/// offset into the beta vector and `phi` the scalar weight applied to the
/// `p` values starting there. Entries whose `phi` is `0.0` are skipped.
pub(crate) a_phi: std::sync::Arc<[Vec<(usize, f64)>]>,
/// One flat buffer per row, read as `len / p` consecutive rows of `p`
/// values each (element `(c, j)` at index `c * p + j`).
pub(crate) local_jac: std::sync::Arc<[Vec<f64>]>,
}
impl SaeKroneckerRows {
    /// Bundles the block width `p` with the shared per-row tables.
    ///
    /// The `Arc`s are moved in as-is, so the new value aliases whatever
    /// allocations the caller passed rather than copying them.
    ///
    /// # Panics
    ///
    /// Panics when `a_phi` and `local_jac` do not have the same number of rows.
    pub fn new(
        p: usize,
        a_phi: std::sync::Arc<[Vec<(usize, f64)>]>,
        local_jac: std::sync::Arc<[Vec<f64>]>,
    ) -> Self {
        // Both tables are indexed by the same row id, so their lengths must agree.
        let (phi_rows, jac_rows) = (a_phi.len(), local_jac.len());
        assert_eq!(
            phi_rows, jac_rows,
            "SaeKroneckerRows: a_phi rows ({}) != local_jac rows ({})",
            phi_rows, jac_rows,
        );
        Self {
            p,
            a_phi,
            local_jac,
        }
    }
}
impl SaeKroneckerRow for SaeKroneckerRows {
    /// Overwrites `u_out` with a weighted sum of `p`-wide windows of `x_beta`.
    ///
    /// The whole of `u_out` is cleared first. Then every `(offset, weight)`
    /// entry of `a_phi[row]` with a non-zero weight adds
    /// `weight * x_beta[offset + j]` into `u_out[j]` for `j < p`.
    ///
    /// # Panics
    ///
    /// Panics if `row` is out of range or a touched index lies outside
    /// `x_beta` / `u_out`.
    fn apply_jbeta(&self, row: usize, x_beta: &[f64], u_out: &mut [f64]) {
        u_out.fill(0.0);
        let width = self.p;
        // Zero weights contribute nothing, so they are dropped up front.
        let active = self.a_phi[row].iter().filter(|entry| entry.1 != 0.0);
        for &(offset, weight) in active {
            (0..width).for_each(|j| u_out[j] += weight * x_beta[offset + j]);
        }
    }

    /// Transpose of [`apply_jbeta`](SaeKroneckerRow::apply_jbeta): scatters
    /// `weight * u[j]` into `y_beta[offset + j]` for each non-zero entry of
    /// `a_phi[row]`.
    ///
    /// Accumulates into `y_beta`; existing contents are kept.
    ///
    /// # Panics
    ///
    /// Panics if `row` is out of range or a touched index lies outside
    /// `u` / `y_beta`.
    fn scatter_jbeta_t(&self, row: usize, u: &[f64], y_beta: &mut [f64]) {
        let width = self.p;
        self.a_phi[row]
            .iter()
            .filter(|entry| entry.1 != 0.0)
            .for_each(|&(offset, weight)| {
                for j in 0..width {
                    y_beta[offset + j] += weight * u[j];
                }
            });
    }

    /// Writes `w_out[c] = sum_j jac[c * p + j] * u[j]` for every complete
    /// length-`p` row `c` of `local_jac[row]`.
    ///
    /// Trailing values that do not fill a whole row are ignored, and entries
    /// of `w_out` beyond the row count are left untouched.
    ///
    /// # Panics
    ///
    /// Panics if `row` is out of range, if `p == 0`, or if `u` / `w_out` are
    /// too short for the indices touched.
    fn apply_l(&self, row: usize, u: &[f64], w_out: &mut [f64]) {
        let coeffs = &self.local_jac[row];
        for (c, jac_row) in coeffs.chunks_exact(self.p).enumerate() {
            // Explicit left-to-right fold from 0.0 keeps the summation order fixed.
            let dot = jac_row
                .iter()
                .enumerate()
                .fold(0.0_f64, |acc, (j, &coeff)| acc + coeff * u[j]);
            w_out[c] = dot;
        }
    }

    /// Transpose of [`apply_l`](SaeKroneckerRow::apply_l): adds
    /// `jac[c * p + j] * v[c]` into `u_out[j]` for every complete row `c`.
    ///
    /// Accumulates into `u_out` (it is not cleared here); rows whose `v[c]`
    /// is `0.0` are skipped.
    ///
    /// # Panics
    ///
    /// Panics if `row` is out of range, if `p == 0`, or if `v` / `u_out` are
    /// too short for the indices touched.
    fn apply_l_t(&self, row: usize, v: &[f64], u_out: &mut [f64]) {
        let coeffs = &self.local_jac[row];
        for (c, jac_row) in coeffs.chunks_exact(self.p).enumerate() {
            let scale = v[c];
            if scale != 0.0 {
                for (j, &coeff) in jac_row.iter().enumerate() {
                    u_out[j] += coeff * scale;
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::SaeKroneckerRows;
    use gam_solve::arrow_schur::DeviceSaePcgData;
    use std::sync::Arc;

    /// A host `SaeKroneckerRows` and a `DeviceSaePcgData` built from clones of
    /// the same `Arc`s must point at the very same allocations, with no
    /// additional references created along the way.
    #[test]
    fn device_and_kron_rows_share_backing_alloc_1033() {
        let width = 6usize;

        let phi_rows: Vec<Vec<(usize, f64)>> =
            vec![vec![(0usize, 2.0f64), (12, 1.0)], vec![(0, 0.5)]];
        let jac_rows: Vec<Vec<f64>> = vec![vec![1.0; 4 * width], vec![2.0; 4 * width]];
        let shared_phi: Arc<[Vec<(usize, f64)>]> = Arc::from(phi_rows);
        let shared_jac: Arc<[Vec<f64>]> = Arc::from(jac_rows);

        let host = SaeKroneckerRows::new(
            width,
            Arc::clone(&shared_phi),
            Arc::clone(&shared_jac),
        );
        let device = DeviceSaePcgData {
            p: width,
            beta_dim: 6,
            a_phi: Arc::clone(&shared_phi),
            local_jac: Arc::clone(&shared_jac),
            smooth_blocks: Vec::new(),
            sparse_g_blocks: Vec::new(),
            frame: None,
        };

        // Pointer identity: both sides alias the same backing storage.
        let jac_is_shared = Arc::ptr_eq(&host.local_jac, &device.local_jac);
        assert!(
            jac_is_shared,
            "host SaeKroneckerRows and DeviceSaePcgData must share one local_jac alloc"
        );
        let phi_is_shared = Arc::ptr_eq(&host.a_phi, &device.a_phi);
        assert!(
            phi_is_shared,
            "host SaeKroneckerRows and DeviceSaePcgData must share one a_phi alloc"
        );

        // One handle here, one in `host`, one in `device` and nothing else.
        let jac_refs = Arc::strong_count(&shared_jac);
        assert_eq!(
            jac_refs, 3,
            "exactly three references (original, host, device) share the Jacobian"
        );
    }
}