use faer::Accum;
use faer::Par;
use faer::linalg::matmul::matmul;
use gam_linalg::faer_ndarray::{
CrossprodAccum, CrossprodStructure, FaerArrayView, array2_to_matmut,
effective_global_parallelism, fast_atv, fast_av, stream_weighted_crossprod_into,
};
use gam_linalg::matrix::{DenseDesignOperator, LinearOperator};
use gam_problem::Gauge;
use gam_runtime::resource::MatrixMaterializationError;
use ndarray::{Array1, Array2, ArrayViewMut2, s};
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
use rayon::slice::ParallelSliceMut;
use std::ops::Range;
use std::sync::{Arc, OnceLock};
/// Number of rows handled per chunk by the streaming (non-materialized) code
/// paths in this file: it is both the `step_by` stride over row indices and the
/// maximum height of each chunk that is built.
const KERNEL_OPERATOR_ROW_CHUNK_SIZE: usize = 2048;
/// Scalar kernel evaluated between one data point and one center, both given
/// as coordinate slices.
///
/// Implementors must be `Send + Sync + 'static`.
pub trait SpatialKernelEvaluator: Send + Sync + 'static {
    /// Returns the kernel value for point `x` against center `c`.
    fn eval(&self, x: &[f64], c: &[f64]) -> f64;
}

/// Any thread-safe, `'static` closure or function taking two coordinate
/// slices is an evaluator; `eval` simply forwards to the call.
impl<F: Fn(&[f64], &[f64]) -> f64 + Send + Sync + 'static> SpatialKernelEvaluator for F {
    fn eval(&self, x: &[f64], c: &[f64]) -> f64 {
        let callee: &F = self;
        callee(x, c)
    }
}

/// Shared callables (including unsized ones such as `Arc<dyn Fn…>`) are
/// evaluators too; `eval` forwards to the pointee.
impl<F: Fn(&[f64], &[f64]) -> f64 + Send + Sync + 'static + ?Sized> SpatialKernelEvaluator
    for Arc<F>
{
    fn eval(&self, x: &[f64], c: &[f64]) -> f64 {
        // Deref coercion `&Arc<F>` -> `&F`; calling through `&F` never moves the callable.
        let callee: &F = self;
        callee(x, c)
    }
}
/// Design-matrix operator whose kernel columns are computed from `data` rows
/// and `centers` rows by calling `kernel`, rather than being supplied as a
/// precomputed matrix.
///
/// Columns are the kernel block (passed through `kernel_gauge` when present)
/// followed by the optional `poly_basis` columns. A dense copy of the whole
/// matrix is built lazily and cached when it fits within
/// `MATERIALIZE_MAX_BYTES`; otherwise rows are rebuilt chunk by chunk on
/// every use.
pub struct ChunkedKernelDesignOperator<K: SpatialKernelEvaluator> {
/// Data points, one per row, held in standard layout so each row can be
/// read as a contiguous slice.
data: Arc<Array2<f64>>,
/// Kernel centers, one per row, held in standard layout; the constructor
/// requires the same column count as `data`.
centers: Arc<Array2<f64>>,
/// Evaluator called once per (data row, center row) pair.
kernel: K,
/// Optional gauge applied to each raw kernel block via `restrict_design`;
/// the constructor requires its `raw_total()` to equal `centers.nrows()`,
/// and its `reduced_total()` is used as the kernel column count.
kernel_gauge: Option<Arc<Gauge>>,
/// Optional extra columns placed after the kernel columns; the constructor
/// requires one row per data row.
poly_basis: Option<Arc<Array2<f64>>>,
/// Number of rows (`data.nrows()` at construction time).
n: usize,
/// Total number of columns: kernel columns plus `poly_basis` columns.
total_cols: usize,
/// Lazily filled cache of the full dense matrix. Unset until first use;
/// `Some(None)` records that the matrix was too large to materialize.
materialized: OnceLock<Option<Arc<Array2<f64>>>>,
}
impl<K: SpatialKernelEvaluator> ChunkedKernelDesignOperator<K> {
    /// Creates an operator over `data` (one point per row) and `centers` (one
    /// center per row).
    ///
    /// The column count is the kernel width (`centers.nrows()`, or
    /// `kernel_gauge.reduced_total()` when a gauge is supplied) plus
    /// `poly_basis.ncols()` when a polynomial block is supplied.
    ///
    /// Inputs that are already in standard layout are shared as-is; any other
    /// layout is copied into standard layout so that rows can later be read as
    /// contiguous slices.
    ///
    /// # Errors
    /// Returns a message when `data` and `centers` disagree on column count,
    /// when the gauge's `raw_total()` differs from the number of centers, or
    /// when `poly_basis` does not have exactly one row per data row.
    pub fn new(
        data: Arc<Array2<f64>>,
        centers: Arc<Array2<f64>>,
        kernel: K,
        kernel_gauge: Option<Arc<Gauge>>,
        poly_basis: Option<Arc<Array2<f64>>>,
    ) -> Result<Self, String> {
        // Reuses the caller's allocation when it is already standard layout
        // instead of unconditionally deep-copying a potentially large matrix.
        fn into_standard_layout(array: Arc<Array2<f64>>) -> Arc<Array2<f64>> {
            if array.is_standard_layout() {
                array
            } else {
                Arc::new(array.as_standard_layout().to_owned())
            }
        }

        let n = data.nrows();
        let k = centers.nrows();
        if data.ncols() != centers.ncols() {
            return Err(format!(
                "ChunkedKernelDesignOperator: data dim {} != centers dim {}",
                data.ncols(),
                centers.ncols(),
            ));
        }
        if let Some(gauge) = kernel_gauge.as_ref()
            && gauge.raw_total() != k
        {
            return Err(format!(
                "ChunkedKernelDesignOperator: kernel gauge raw width {} != centers rows {}",
                gauge.raw_total(),
                k,
            ));
        }
        if let Some(poly) = poly_basis.as_ref()
            && poly.nrows() != n
        {
            return Err(format!(
                "ChunkedKernelDesignOperator: poly_basis rows {} != data rows {}",
                poly.nrows(),
                n,
            ));
        }
        let k_eff = kernel_gauge.as_ref().map_or(k, |g| g.reduced_total());
        let poly_cols = poly_basis.as_ref().map_or(0, |p| p.ncols());
        Ok(Self {
            data: into_standard_layout(data),
            centers: into_standard_layout(centers),
            kernel,
            kernel_gauge,
            poly_basis,
            n,
            total_cols: k_eff + poly_cols,
            materialized: OnceLock::new(),
        })
    }

    /// Largest dense matrix, in bytes, that will be cached (1 GiB).
    const MATERIALIZE_MAX_BYTES: usize = 1024 * 1024 * 1024;

    /// Returns the cached dense matrix, building it on first use.
    ///
    /// Returns `None` when `n * total_cols * size_of::<f64>()` overflows or
    /// exceeds [`Self::MATERIALIZE_MAX_BYTES`]; that decision is cached too,
    /// so the size test runs only until the slot is filled.
    fn materialized_combined(&self) -> Option<&Array2<f64>> {
        if self.materialized.get().is_none() {
            let fits = self
                .n
                .checked_mul(self.total_cols)
                .and_then(|cells| cells.checked_mul(std::mem::size_of::<f64>()))
                .is_some_and(|bytes| bytes <= Self::MATERIALIZE_MAX_BYTES);
            let computed = fits.then(|| Arc::new(self.build_row_chunk_combined(0..self.n)));
            // `set` never blocks: if another caller filled the slot first, this
            // value is dropped and the stored one is returned below.
            let _ = self.materialized.set(computed);
        }
        self.materialized.get().and_then(|slot| slot.as_deref())
    }

    /// Evaluates the kernel for every (data row, center) pair in `rows`.
    ///
    /// The raw block has `rows.end - rows.start` rows and `centers.nrows()`
    /// columns; when a gauge is present the block returned is
    /// `gauge.restrict_design(&raw_block)`.
    ///
    /// # Panics
    /// Panics if `rows.start > rows.end` in builds with overflow checks, or if
    /// the range indexes past the stored data.
    fn kernel_chunk(&self, rows: Range<usize>) -> Array2<f64> {
        let chunk_n = rows.end - rows.start;
        let k_raw = self.centers.nrows();
        let dim = self.data.ncols();
        let data = self
            .data
            .as_slice()
            .expect("ChunkedKernelDesignOperator stores standard-layout data");
        let centers = self
            .centers
            .as_slice()
            .expect("ChunkedKernelDesignOperator stores standard-layout centers");
        let kernel = &self.kernel;
        let mut values = vec![0.0_f64; chunk_n * k_raw];
        // `par_chunks_mut` panics on a chunk size of zero; with no centers the
        // block is `chunk_n x 0` and there is nothing to fill.
        if k_raw > 0 {
            values
                .par_chunks_mut(k_raw)
                .enumerate()
                .for_each(|(local, out_row)| {
                    let x_start = (rows.start + local) * dim;
                    let x = &data[x_start..x_start + dim];
                    for (j, out) in out_row.iter_mut().enumerate() {
                        let c_start = j * dim;
                        *out = kernel.eval(x, &centers[c_start..c_start + dim]);
                    }
                });
        }
        let kernel_block = Array2::from_shape_vec((chunk_n, k_raw), values)
            .expect("kernel chunk shape should match generated values");
        match self.kernel_gauge.as_ref() {
            Some(gauge) => gauge.restrict_design(&kernel_block),
            None => kernel_block,
        }
    }
}
impl<K: SpatialKernelEvaluator> LinearOperator for ChunkedKernelDesignOperator<K> {
fn nrows(&self) -> usize {
self.n
}
fn ncols(&self) -> usize {
self.total_cols
}
fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
if let Some(combined) = self.materialized_combined() {
return fast_av(combined, vector);
}
let k_eff = self
.kernel_gauge
.as_ref()
.map_or(self.centers.nrows(), |g| g.reduced_total());
let v_kernel = vector.slice(s![..k_eff]);
let mut result = Array1::<f64>::zeros(self.n);
for start in (0..self.n).step_by(KERNEL_OPERATOR_ROW_CHUNK_SIZE) {
let end = (start + KERNEL_OPERATOR_ROW_CHUNK_SIZE).min(self.n);
let chunk = self.kernel_chunk(start..end);
let partial = fast_av(&chunk, &v_kernel);
result.slice_mut(s![start..end]).assign(&partial);
}
if let Some(poly) = self.poly_basis.as_ref() {
let v_poly = vector.slice(s![k_eff..]);
let poly_part = fast_av(poly, &v_poly);
result += &poly_part;
}
result
}
fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
if let Some(combined) = self.materialized_combined() {
return fast_atv(combined, vector);
}
let k_eff = self
.kernel_gauge
.as_ref()
.map_or(self.centers.nrows(), |g| g.reduced_total());
let mut result = Array1::<f64>::zeros(self.total_cols);
for start in (0..self.n).step_by(KERNEL_OPERATOR_ROW_CHUNK_SIZE) {
let end = (start + KERNEL_OPERATOR_ROW_CHUNK_SIZE).min(self.n);
let chunk = self.kernel_chunk(start..end);
let v_slice = vector.slice(s![start..end]);
let partial = fast_atv(&chunk, &v_slice);
result.slice_mut(s![..k_eff]).scaled_add(1.0, &partial);
}
if let Some(poly) = self.poly_basis.as_ref() {
let poly_part = fast_atv(poly, vector);
result.slice_mut(s![k_eff..]).assign(&poly_part);
}
result
}
fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
let p = self.total_cols;
if let Some(combined) = self.materialized_combined() {
let mut xtwx = Array2::<f64>::zeros((p, p));
stream_weighted_crossprod_into(
combined,
weights,
&mut xtwx,
CrossprodStructure::Full,
CrossprodAccum::Replace,
effective_global_parallelism(),
);
return Ok(xtwx);
}
let n = self.n;
if n == 0 || p == 0 {
return Ok(Array2::<f64>::zeros((p, p)));
}
let chunk_starts: Vec<usize> = (0..n).step_by(KERNEL_OPERATOR_ROW_CHUNK_SIZE).collect();
let xtwx = chunk_starts
.into_par_iter()
.fold(
|| Array2::<f64>::zeros((p, p)),
|mut acc, start| {
let end = (start + KERNEL_OPERATOR_ROW_CHUNK_SIZE).min(n);
let chunk = self.row_chunk_combined(start..end);
let mut wchunk = chunk.clone();
for local in 0..(end - start) {
let wi = weights[start + local];
wchunk.row_mut(local).mapv_inplace(|v| v * wi);
}
let chunk_view = FaerArrayView::new(&chunk);
let wchunk_view = FaerArrayView::new(&wchunk);
let mut acc_view = array2_to_matmut(&mut acc);
matmul(
acc_view.as_mut(),
Accum::Add,
chunk_view.as_ref().transpose(),
wchunk_view.as_ref(),
1.0,
Par::Seq,
);
acc
},
)
.reduce(
|| Array2::<f64>::zeros((p, p)),
|mut a, b| {
a += &b;
a
},
);
Ok(xtwx)
}
}
impl<K: SpatialKernelEvaluator> ChunkedKernelDesignOperator<K> {
    /// Returns an owned copy of the full-width rows in `rows`.
    ///
    /// Copies out of the cached dense matrix when one exists; otherwise the
    /// rows are rebuilt from the kernel and polynomial blocks.
    pub(crate) fn row_chunk_combined(&self, rows: Range<usize>) -> Array2<f64> {
        match self.materialized_combined() {
            Some(combined) => combined.slice(s![rows, ..]).to_owned(),
            None => self.build_row_chunk_combined(rows),
        }
    }

    /// Builds the full-width rows in `rows` from scratch: kernel columns first,
    /// then the polynomial columns when a polynomial block is present.
    fn build_row_chunk_combined(&self, rows: Range<usize>) -> Array2<f64> {
        let chunk_n = rows.end - rows.start;
        let k_eff = self
            .kernel_gauge
            .as_ref()
            .map_or(self.centers.nrows(), |g| g.reduced_total());
        let kernel = self.kernel_chunk(rows.clone());
        let poly = self.poly_basis.as_ref();

        // Without a polynomial block the combined chunk is exactly the kernel
        // block, so hand it back instead of allocating a second buffer of the
        // same size and copying into it. The shape/layout guard keeps the
        // result indistinguishable from the copy-based path below.
        if poly.is_none() && kernel.dim() == (chunk_n, k_eff) && kernel.is_standard_layout() {
            return kernel;
        }

        let poly_cols = poly.map_or(0, |p| p.ncols());
        let mut combined = Array2::<f64>::zeros((chunk_n, k_eff + poly_cols));
        combined.slice_mut(s![.., ..k_eff]).assign(&kernel);
        if let Some(poly) = poly {
            combined
                .slice_mut(s![.., k_eff..])
                .assign(&poly.slice(s![rows, ..]));
        }
        combined
    }
}
impl<K: SpatialKernelEvaluator> DenseDesignOperator for ChunkedKernelDesignOperator<K> {
    /// Returns the cached dense matrix, building it on first use, or `None`
    /// when the matrix is too large to materialize.
    fn as_dense_ref(&self) -> Option<&Array2<f64>> {
        self.materialized_combined()
    }

    /// Writes the full-width rows in `rows` into `out`.
    ///
    /// # Errors
    /// Returns `MatrixMaterializationError::MissingRowChunk` when `rows` is
    /// inverted or reaches past `nrows()`, or when `out` is not
    /// `rows.len() x ncols()`.
    fn row_chunk_into(
        &self,
        rows: Range<usize>,
        mut out: ArrayViewMut2<'_, f64>,
    ) -> Result<(), MatrixMaterializationError> {
        // Reject bad ranges up front: an inverted range would underflow the
        // length computation below, and a range past `n` would panic while
        // slicing instead of reporting an error.
        if rows.start > rows.end || rows.end > self.n {
            return Err(MatrixMaterializationError::MissingRowChunk {
                context: "ChunkedKernelDesignOperator::row_chunk_into row range out of bounds",
            });
        }
        if out.nrows() != rows.end - rows.start || out.ncols() != self.total_cols {
            return Err(MatrixMaterializationError::MissingRowChunk {
                context: "ChunkedKernelDesignOperator::row_chunk_into shape mismatch",
            });
        }
        match self.materialized_combined() {
            Some(combined) => out.assign(&combined.slice(s![rows, ..])),
            None => out.assign(&self.build_row_chunk_combined(rows)),
        }
        Ok(())
    }

    /// Returns an owned dense copy of the whole matrix: a clone of the cache
    /// when one exists, otherwise a freshly built matrix.
    fn to_dense(&self) -> Array2<f64> {
        match self.materialized_combined() {
            Some(combined) => combined.clone(),
            None => self.build_row_chunk_combined(0..self.n),
        }
    }
}
#[cfg(test)]
mod chunked_kernel_operator_tests {
    use super::*;
    use gam_linalg::matrix::DenseDesignMatrix;
    use ndarray::{Array1, Array2, array};
    use std::sync::Arc;

    /// Plain dot product, used as a stand-in kernel with easily predicted values.
    fn dot_kernel(x: &[f64], c: &[f64]) -> f64 {
        x.iter().zip(c.iter()).map(|(xi, ci)| xi * ci).sum::<f64>()
    }

    /// Kernel that ignores its inputs; used where only shape validation matters.
    fn zero_kernel(_: &[f64], _: &[f64]) -> f64 {
        0.0
    }

    /// Without a gauge or polynomial block, the column count is the number of centers.
    #[test]
    fn chunked_kernel_operator_uses_center_rows_for_column_count() {
        let points = Arc::new(array![[0.0, 1.0], [1.0, 0.5]]);
        let knots = Arc::new(array![[0.0, 0.0], [1.0, 1.0], [2.0, -1.0]]);
        let operator = ChunkedKernelDesignOperator::new(points, knots, dot_kernel, None, None)
            .expect("chunked kernel operator");

        assert_eq!(operator.ncols(), 3);
        assert_eq!(operator.row_chunk_combined(0..2).dim(), (2, 3));
    }

    /// A gauge with the wrong raw width and a polynomial block with the wrong
    /// row count are both rejected by the constructor.
    #[test]
    fn chunked_kernel_operator_rejects_incompatible_optional_shapes() {
        let points = Arc::new(array![[0.0, 1.0], [1.0, 0.5]]);
        let knots = Arc::new(array![[0.0, 0.0], [1.0, 1.0], [2.0, -1.0]]);
        let bad_gauge = Arc::new(gam_problem::Gauge::from_block_transforms(&[
            Array2::<f64>::zeros((2, 1)),
        ]));
        let bad_poly = Arc::new(Array2::<f64>::zeros((3, 1)));

        let Err(gauge_err) = ChunkedKernelDesignOperator::new(
            points.clone(),
            knots.clone(),
            zero_kernel,
            Some(bad_gauge),
            None,
        ) else {
            panic!("gauge raw width should match centers rows");
        };
        assert!(gauge_err.contains("kernel gauge raw width 2 != centers rows 3"));

        let Err(poly_err) =
            ChunkedKernelDesignOperator::new(points, knots, zero_kernel, None, Some(bad_poly))
        else {
            panic!("poly rows should match data rows");
        };
        assert!(poly_err.contains("poly_basis rows 3 != data rows 2"));
    }

    /// Transposed (non-standard-layout) inputs still produce correct kernel values.
    #[test]
    fn chunked_kernel_operator_canonicalizes_non_contiguous_inputs() {
        let points = Arc::new(array![[0.0, 1.0], [1.0, 0.5]].reversed_axes());
        let knots = Arc::new(array![[0.0, 1.0, 2.0], [0.0, 1.0, -1.0]].reversed_axes());
        assert!(!points.is_standard_layout());
        assert!(!knots.is_standard_layout());

        let operator = ChunkedKernelDesignOperator::new(points, knots, dot_kernel, None, None)
            .expect("chunked kernel operator");
        let block = operator.row_chunk_combined(0..2);

        assert_eq!(block.dim(), (2, 3));
        assert_eq!(block[[0, 0]], 0.0);
        assert_eq!(block[[1, 1]], 1.5);
    }

    /// After wrapping in `DenseDesignMatrix`, the cached dense block is reachable
    /// through `as_dense_ref` and matches `to_dense`.
    #[test]
    fn chunked_kernel_operator_exposes_cached_dense_to_block_dispatch() {
        let points = Arc::new(array![[0.0, 1.0], [1.0, 0.5], [2.0, -1.0]]);
        let knots = Arc::new(array![[0.0, 0.0], [1.0, 1.0]]);
        let op = ChunkedKernelDesignOperator::new(points, knots, dot_kernel, None, None)
            .expect("chunked kernel operator");
        let expected = op.to_dense();
        let dense_design = DenseDesignMatrix::from(Arc::new(op));

        // Exercise the operator once before asking for the dense reference.
        let ones = Array1::from_elem(3, 1.0);
        let warmed = dense_design.apply_transpose(&ones);
        assert_eq!(warmed.len(), expected.ncols());

        let dense_ref = dense_design
            .as_dense_ref()
            .expect("DenseDesignMatrix::as_dense_ref must reach the cached kernel block");
        assert_eq!(dense_ref.dim(), expected.dim());
        for (index, want) in expected.indexed_iter() {
            assert!((dense_ref[index] - want).abs() < 1e-12);
        }
    }
}