1use faer::Accum;
4use faer::Par;
5use faer::linalg::matmul::matmul;
6use gam_linalg::faer_ndarray::{
7 CrossprodAccum, CrossprodStructure, FaerArrayView, array2_to_matmut,
8 effective_global_parallelism, fast_atv, fast_av, stream_weighted_crossprod_into,
9};
10use gam_linalg::matrix::{DenseDesignOperator, LinearOperator};
11use gam_problem::Gauge;
12use gam_runtime::resource::MatrixMaterializationError;
13use ndarray::{Array1, Array2, ArrayViewMut2, s};
14use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
15use rayon::slice::ParallelSliceMut;
16use std::ops::Range;
17use std::sync::{Arc, OnceLock};
18
/// Number of design-matrix rows processed per chunk when the operator streams
/// its rows instead of using a cached dense matrix (see the `step_by` loops in
/// `apply`, `apply_transpose` and `diag_xtw_x`).
const KERNEL_OPERATOR_ROW_CHUNK_SIZE: usize = 2048;
20
/// A kernel function evaluated between one data point and one center.
///
/// The `Send + Sync + 'static` bounds let a single evaluator be shared by the
/// parallel row loops that fill kernel blocks.
pub trait SpatialKernelEvaluator: Send + Sync + 'static {
    /// Returns the kernel value for point `x` and center `c`, each passed as a
    /// slice of coordinates.
    fn eval(&self, x: &[f64], c: &[f64]) -> f64;
}
24
25impl<F> SpatialKernelEvaluator for F
26where
27 F: Fn(&[f64], &[f64]) -> f64 + Send + Sync + 'static,
28{
29 fn eval(&self, x: &[f64], c: &[f64]) -> f64 {
30 self(x, c)
31 }
32}
33
34impl<F> SpatialKernelEvaluator for Arc<F>
35where
36 F: Fn(&[f64], &[f64]) -> f64 + Send + Sync + 'static + ?Sized,
37{
38 fn eval(&self, x: &[f64], c: &[f64]) -> f64 {
39 self.as_ref()(x, c)
40 }
41}
42
/// Design-matrix operator whose columns are kernel evaluations between data
/// rows and center rows, optionally followed by the columns of a polynomial
/// basis.
///
/// Rows are produced in chunks on demand; when the full matrix is small enough
/// it is built once and cached in `materialized`.
pub struct ChunkedKernelDesignOperator<K: SpatialKernelEvaluator> {
    /// Data points, one per row, stored in standard (row-major) layout so each
    /// row can be read as a contiguous slice.
    data: Arc<Array2<f64>>,
    /// Kernel centers, one per row, with the same column count as `data`; also
    /// stored in standard layout.
    centers: Arc<Array2<f64>>,
    /// Kernel evaluated on every (data row, center row) pair.
    kernel: K,
    /// Optional gauge applied to each raw kernel block via `restrict_design`;
    /// its `raw_total()` equals the number of centers and its
    /// `reduced_total()` is the resulting kernel column count.
    kernel_gauge: Option<Arc<Gauge>>,
    /// Optional block of extra columns appended after the kernel columns; has
    /// one row per data row.
    poly_basis: Option<Arc<Array2<f64>>>,
    /// Number of data rows (rows of the design matrix).
    n: usize,
    /// Total column count: kernel columns plus polynomial columns.
    total_cols: usize,
    /// Cache for the full dense design. Unset until first queried; afterwards
    /// `Some(matrix)` if the matrix fit under the size cap, `None` otherwise.
    materialized: OnceLock<Option<Arc<Array2<f64>>>>,
}
72
73impl<K: SpatialKernelEvaluator> ChunkedKernelDesignOperator<K> {
74 pub fn new(
75 data: Arc<Array2<f64>>,
76 centers: Arc<Array2<f64>>,
77 kernel: K,
78 kernel_gauge: Option<Arc<Gauge>>,
79 poly_basis: Option<Arc<Array2<f64>>>,
80 ) -> Result<Self, String> {
81 let n = data.nrows();
82 let k = centers.nrows();
83 if data.ncols() != centers.ncols() {
84 return Err(format!(
85 "ChunkedKernelDesignOperator: data dim {} != centers dim {}",
86 data.ncols(),
87 centers.ncols(),
88 ));
89 }
90 if let Some(gauge) = kernel_gauge.as_ref()
91 && gauge.raw_total() != k
92 {
93 return Err(format!(
94 "ChunkedKernelDesignOperator: kernel gauge raw width {} != centers rows {}",
95 gauge.raw_total(),
96 k,
97 ));
98 }
99 if let Some(poly) = poly_basis.as_ref()
100 && poly.nrows() != n
101 {
102 return Err(format!(
103 "ChunkedKernelDesignOperator: poly_basis rows {} != data rows {}",
104 poly.nrows(),
105 n,
106 ));
107 }
108 let k_eff = kernel_gauge.as_ref().map_or(k, |g| g.reduced_total());
109 let poly_cols = poly_basis.as_ref().map_or(0, |p| p.ncols());
110 Ok(Self {
111 data: Arc::new(data.as_standard_layout().to_owned()),
112 centers: Arc::new(centers.as_standard_layout().to_owned()),
113 kernel,
114 kernel_gauge,
115 poly_basis,
116 n,
117 total_cols: k_eff + poly_cols,
118 materialized: OnceLock::new(),
119 })
120 }
121
122 const MATERIALIZE_MAX_BYTES: usize = 1024 * 1024 * 1024;
130
131 fn materialized_combined(&self) -> Option<&Array2<f64>> {
148 if let Some(slot) = self.materialized.get() {
149 return slot.as_ref().map(|a| a.as_ref());
150 }
151 let bytes = self
152 .n
153 .checked_mul(self.total_cols)
154 .and_then(|cells| cells.checked_mul(std::mem::size_of::<f64>()));
155 let computed = match bytes {
156 Some(b) if b <= Self::MATERIALIZE_MAX_BYTES => {
157 Some(Arc::new(self.build_row_chunk_combined(0..self.n)))
158 }
159 _ => None,
160 };
161 if self.materialized.set(computed).is_err() {
162 return self
163 .materialized
164 .get()
165 .and_then(|opt| opt.as_ref().map(|a| a.as_ref()));
166 }
167 self.materialized
168 .get()
169 .and_then(|opt| opt.as_ref().map(|a| a.as_ref()))
170 }
171
172 fn kernel_chunk(&self, rows: Range<usize>) -> Array2<f64> {
179 let chunk_n = rows.end - rows.start;
180 let k_raw = self.centers.nrows();
181 let dim = self.data.ncols();
182 let data = self
183 .data
184 .as_slice()
185 .expect("ChunkedKernelDesignOperator stores standard-layout data");
186 let centers = self
187 .centers
188 .as_slice()
189 .expect("ChunkedKernelDesignOperator stores standard-layout centers");
190 let kernel = &self.kernel;
191 let mut values = vec![0.0_f64; chunk_n * k_raw];
192 values
193 .par_chunks_mut(k_raw)
194 .enumerate()
195 .for_each(|(local, out_row)| {
196 let global = rows.start + local;
197 let x_start = global * dim;
198 let x = &data[x_start..x_start + dim];
199 for j in 0..k_raw {
200 let c_start = j * dim;
201 out_row[j] = kernel.eval(x, ¢ers[c_start..c_start + dim]);
202 }
203 });
204 let kernel_block = Array2::from_shape_vec((chunk_n, k_raw), values)
205 .expect("kernel chunk shape should match generated values");
206 if let Some(gauge) = self.kernel_gauge.as_ref() {
207 gauge.restrict_design(&kernel_block)
208 } else {
209 kernel_block
210 }
211 }
212}
213
214impl<K: SpatialKernelEvaluator> LinearOperator for ChunkedKernelDesignOperator<K> {
215 fn nrows(&self) -> usize {
216 self.n
217 }
218 fn ncols(&self) -> usize {
219 self.total_cols
220 }
221 fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
222 if let Some(combined) = self.materialized_combined() {
223 return fast_av(combined, vector);
224 }
225 let k_eff = self
226 .kernel_gauge
227 .as_ref()
228 .map_or(self.centers.nrows(), |g| g.reduced_total());
229 let v_kernel = vector.slice(s![..k_eff]);
230 let mut result = Array1::<f64>::zeros(self.n);
231 for start in (0..self.n).step_by(KERNEL_OPERATOR_ROW_CHUNK_SIZE) {
233 let end = (start + KERNEL_OPERATOR_ROW_CHUNK_SIZE).min(self.n);
234 let chunk = self.kernel_chunk(start..end);
235 let partial = fast_av(&chunk, &v_kernel);
236 result.slice_mut(s![start..end]).assign(&partial);
237 }
238 if let Some(poly) = self.poly_basis.as_ref() {
239 let v_poly = vector.slice(s![k_eff..]);
240 let poly_part = fast_av(poly, &v_poly);
241 result += &poly_part;
242 }
243 result
244 }
245 fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
246 if let Some(combined) = self.materialized_combined() {
247 return fast_atv(combined, vector);
248 }
249 let k_eff = self
250 .kernel_gauge
251 .as_ref()
252 .map_or(self.centers.nrows(), |g| g.reduced_total());
253 let mut result = Array1::<f64>::zeros(self.total_cols);
254 for start in (0..self.n).step_by(KERNEL_OPERATOR_ROW_CHUNK_SIZE) {
256 let end = (start + KERNEL_OPERATOR_ROW_CHUNK_SIZE).min(self.n);
257 let chunk = self.kernel_chunk(start..end);
258 let v_slice = vector.slice(s![start..end]);
259 let partial = fast_atv(&chunk, &v_slice);
260 result.slice_mut(s![..k_eff]).scaled_add(1.0, &partial);
261 }
262 if let Some(poly) = self.poly_basis.as_ref() {
264 let poly_part = fast_atv(poly, vector);
265 result.slice_mut(s![k_eff..]).assign(&poly_part);
266 }
267 result
268 }
269 fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
270 let p = self.total_cols;
271 if let Some(combined) = self.materialized_combined() {
276 let mut xtwx = Array2::<f64>::zeros((p, p));
277 stream_weighted_crossprod_into(
278 combined,
279 weights,
280 &mut xtwx,
281 CrossprodStructure::Full,
282 CrossprodAccum::Replace,
283 effective_global_parallelism(),
284 );
285 return Ok(xtwx);
286 }
287 let n = self.n;
291 if n == 0 || p == 0 {
292 return Ok(Array2::<f64>::zeros((p, p)));
293 }
294 let chunk_starts: Vec<usize> = (0..n).step_by(KERNEL_OPERATOR_ROW_CHUNK_SIZE).collect();
295 let xtwx = chunk_starts
296 .into_par_iter()
297 .fold(
298 || Array2::<f64>::zeros((p, p)),
299 |mut acc, start| {
300 let end = (start + KERNEL_OPERATOR_ROW_CHUNK_SIZE).min(n);
301 let chunk = self.row_chunk_combined(start..end);
302 let mut wchunk = chunk.clone();
303 for local in 0..(end - start) {
304 let wi = weights[start + local];
305 wchunk.row_mut(local).mapv_inplace(|v| v * wi);
306 }
307 let chunk_view = FaerArrayView::new(&chunk);
308 let wchunk_view = FaerArrayView::new(&wchunk);
309 let mut acc_view = array2_to_matmut(&mut acc);
310 matmul(
311 acc_view.as_mut(),
312 Accum::Add,
313 chunk_view.as_ref().transpose(),
314 wchunk_view.as_ref(),
315 1.0,
316 Par::Seq,
317 );
318 acc
319 },
320 )
321 .reduce(
322 || Array2::<f64>::zeros((p, p)),
323 |mut a, b| {
324 a += &b;
325 a
326 },
327 );
328 Ok(xtwx)
329 }
330}
331
332impl<K: SpatialKernelEvaluator> ChunkedKernelDesignOperator<K> {
333 pub(crate) fn row_chunk_combined(&self, rows: Range<usize>) -> Array2<f64> {
336 if let Some(combined) = self.materialized_combined() {
337 return combined.slice(s![rows, ..]).to_owned();
338 }
339 self.build_row_chunk_combined(rows)
340 }
341
342 fn build_row_chunk_combined(&self, rows: Range<usize>) -> Array2<f64> {
346 let chunk_n = rows.end - rows.start;
347 let k_eff = self
348 .kernel_gauge
349 .as_ref()
350 .map_or(self.centers.nrows(), |g| g.reduced_total());
351 let kernel = self.kernel_chunk(rows.clone());
352 let poly_cols = self.poly_basis.as_ref().map_or(0, |p| p.ncols());
353 let mut combined = Array2::<f64>::zeros((chunk_n, k_eff + poly_cols));
354 combined.slice_mut(s![.., ..k_eff]).assign(&kernel);
355 if let Some(poly) = self.poly_basis.as_ref() {
356 combined
357 .slice_mut(s![.., k_eff..])
358 .assign(&poly.slice(s![rows, ..]));
359 }
360 combined
361 }
362}
363
364impl<K: SpatialKernelEvaluator> DenseDesignOperator for ChunkedKernelDesignOperator<K> {
365 fn as_dense_ref(&self) -> Option<&Array2<f64>> {
369 self.materialized_combined()
370 }
371
372 fn row_chunk_into(
373 &self,
374 rows: Range<usize>,
375 mut out: ArrayViewMut2<'_, f64>,
376 ) -> Result<(), MatrixMaterializationError> {
377 if out.nrows() != rows.end - rows.start || out.ncols() != self.total_cols {
378 return Err(MatrixMaterializationError::MissingRowChunk {
379 context: "ChunkedKernelDesignOperator::row_chunk_into shape mismatch",
380 });
381 }
382 if let Some(combined) = self.materialized_combined() {
383 out.assign(&combined.slice(s![rows, ..]));
384 } else {
385 out.assign(&self.row_chunk_combined(rows));
386 }
387 Ok(())
388 }
389
390 fn to_dense(&self) -> Array2<f64> {
391 if let Some(combined) = self.materialized_combined() {
392 return combined.clone();
393 }
394 self.row_chunk_combined(0..self.n)
395 }
396}
397
#[cfg(test)]
mod chunked_kernel_operator_tests {
    use super::*;
    use gam_linalg::matrix::DenseDesignMatrix;
    use ndarray::{Array1, Array2, array};
    use std::sync::Arc;

    /// Dot-product kernel shared by the tests; its values are easy to check by hand.
    fn dot_kernel(x: &[f64], c: &[f64]) -> f64 {
        x.iter().zip(c.iter()).map(|(xi, ci)| xi * ci).sum::<f64>()
    }

    /// Without a gauge or polynomial block, the column count is the number of centers.
    #[test]
    fn chunked_kernel_operator_uses_center_rows_for_column_count() {
        let points = Arc::new(array![[0.0, 1.0], [1.0, 0.5]]);
        let knots = Arc::new(array![[0.0, 0.0], [1.0, 1.0], [2.0, -1.0]]);
        let op = ChunkedKernelDesignOperator::new(points, knots, dot_kernel, None, None)
            .expect("chunked kernel operator");

        assert_eq!(op.ncols(), 3);
        let block = op.row_chunk_combined(0..2);
        assert_eq!(block.dim(), (2, 3));
    }

    /// A gauge with the wrong raw width and a polynomial block with the wrong
    /// row count are both rejected by the constructor.
    #[test]
    fn chunked_kernel_operator_rejects_incompatible_optional_shapes() {
        let points = Arc::new(array![[0.0, 1.0], [1.0, 0.5]]);
        let knots = Arc::new(array![[0.0, 0.0], [1.0, 1.0], [2.0, -1.0]]);
        let zero_kernel = |_: &[f64], _: &[f64]| 0.0;
        let bad_gauge = Arc::new(gam_problem::Gauge::from_block_transforms(&[
            Array2::<f64>::zeros((2, 1)),
        ]));
        let bad_poly = Arc::new(Array2::<f64>::zeros((3, 1)));

        let with_bad_gauge = ChunkedKernelDesignOperator::new(
            points.clone(),
            knots.clone(),
            zero_kernel,
            Some(bad_gauge),
            None,
        );
        let Err(gauge_err) = with_bad_gauge else {
            panic!("gauge raw width should match centers rows")
        };
        assert!(gauge_err.contains("kernel gauge raw width 2 != centers rows 3"));

        let with_bad_poly =
            ChunkedKernelDesignOperator::new(points, knots, zero_kernel, None, Some(bad_poly));
        let Err(poly_err) = with_bad_poly else {
            panic!("poly rows should match data rows")
        };
        assert!(poly_err.contains("poly_basis rows 3 != data rows 2"));
    }

    /// Inputs that are not in standard layout are still evaluated correctly.
    #[test]
    fn chunked_kernel_operator_canonicalizes_non_contiguous_inputs() {
        let points = Arc::new(array![[0.0, 1.0], [1.0, 0.5]].reversed_axes());
        let knots = Arc::new(array![[0.0, 1.0, 2.0], [0.0, 1.0, -1.0]].reversed_axes());
        assert!(!points.is_standard_layout());
        assert!(!knots.is_standard_layout());

        let op = ChunkedKernelDesignOperator::new(points, knots, dot_kernel, None, None)
            .expect("chunked kernel operator");
        let block = op.row_chunk_combined(0..2);

        assert_eq!(block.dim(), (2, 3));
        assert_eq!(block[[0, 0]], 0.0);
        assert_eq!(block[[1, 1]], 1.5);
    }

    /// Wrapping the operator in a `DenseDesignMatrix` still exposes the cached
    /// dense block through `as_dense_ref`.
    #[test]
    fn chunked_kernel_operator_exposes_cached_dense_to_block_dispatch() {
        let points = Arc::new(array![[0.0, 1.0], [1.0, 0.5], [2.0, -1.0]]);
        let knots = Arc::new(array![[0.0, 0.0], [1.0, 1.0]]);
        let op = ChunkedKernelDesignOperator::new(points, knots, dot_kernel, None, None)
            .expect("chunked kernel operator");
        let expected = op.to_dense();

        let dense_design = DenseDesignMatrix::from(Arc::new(op));

        // Touch the operator once through the wrapper before asking for the dense view.
        let probe = Array1::from_elem(3, 1.0);
        let warmed = dense_design.apply_transpose(&probe);
        assert_eq!(warmed.len(), expected.ncols());

        let dense_ref = dense_design
            .as_dense_ref()
            .expect("DenseDesignMatrix::as_dense_ref must reach the cached kernel block");
        assert_eq!(dense_ref.dim(), expected.dim());
        expected.indexed_iter().for_each(|((row, col), want)| {
            assert!((dense_ref[[row, col]] - want).abs() < 1e-12);
        });
    }
}