Skip to main content

gam_terms/
chunked_kernel_design.rs

1//! Spatial-kernel design operators for basis construction.
2
3use faer::Accum;
4use faer::Par;
5use faer::linalg::matmul::matmul;
6use gam_linalg::faer_ndarray::{
7    CrossprodAccum, CrossprodStructure, FaerArrayView, array2_to_matmut,
8    effective_global_parallelism, fast_atv, fast_av, stream_weighted_crossprod_into,
9};
10use gam_linalg::matrix::{DenseDesignOperator, LinearOperator};
11use gam_problem::Gauge;
12use gam_runtime::resource::MatrixMaterializationError;
13use ndarray::{Array1, Array2, ArrayViewMut2, s};
14use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
15use rayon::slice::ParallelSliceMut;
16use std::ops::Range;
17use std::sync::{Arc, OnceLock};
18
/// Rows of kernel values evaluated per chunk on the streaming paths that run
/// when the dense `[K_eff | poly]` block has not been materialized (2048 rows).
const KERNEL_OPERATOR_ROW_CHUNK_SIZE: usize = 1 << 11;
20
/// Scalar kernel `k(x, c)` evaluated between one observation row `x` and one
/// center row `c`, both passed as coordinate slices of the same length.
///
/// The `Send + Sync + 'static` supertraits allow one evaluator to be shared
/// by the rayon workers that fill kernel chunks and to be stored inside a
/// long-lived design operator.
pub trait SpatialKernelEvaluator: Sync + Send + 'static {
    /// Kernel value for data point `x` against center `c`.
    fn eval(&self, x: &[f64], c: &[f64]) -> f64;
}

/// Any thread-safe closure or function with the kernel signature is an
/// evaluator in its own right.
impl<F> SpatialKernelEvaluator for F
where
    F: 'static + Sync + Send + Fn(&[f64], &[f64]) -> f64,
{
    fn eval(&self, x: &[f64], c: &[f64]) -> f64 {
        (*self)(x, c)
    }
}

/// A kernel closure shared behind an `Arc` (including unsized `dyn Fn`
/// pointees) evaluates by delegating to the pointee.
impl<F> SpatialKernelEvaluator for Arc<F>
where
    F: ?Sized + 'static + Sync + Send + Fn(&[f64], &[f64]) -> f64,
{
    fn eval(&self, x: &[f64], c: &[f64]) -> f64 {
        let pointee: &F = self;
        pointee(x, c)
    }
}
42
43/// Chunked kernel design operator for spatial smooths (TPS, Matérn, Duchon).
44///
45/// Instead of storing a dense n × k matrix, evaluates K(data[i], center[j])
46/// on-the-fly in row chunks. Memory usage is O(chunk_size × k) instead of O(n × k).
47///
48/// The optional `poly_basis` appends polynomial columns after the kernel columns
49/// (e.g., linear polynomial for TPS identifiability).
50///
51/// The optional `kernel_gauge` restricts the kernel coefficient block through
52/// a Gauge section, so the effective design is [K_reduced | poly] instead of
53/// [K | poly].
54pub struct ChunkedKernelDesignOperator<K: SpatialKernelEvaluator> {
55    /// Observation data points (n × d).
56    data: Arc<Array2<f64>>,
57    /// Radial basis centers (k × d).
58    centers: Arc<Array2<f64>>,
59    /// Kernel evaluator: (data_row, center_row) -> f64.
60    kernel: K,
61    /// Optional coefficient-space gauge applied to kernel columns.
62    kernel_gauge: Option<Arc<Gauge>>,
63    /// Optional polynomial basis columns (n × m) appended after kernel columns.
64    poly_basis: Option<Arc<Array2<f64>>>,
65    n: usize,
66    total_cols: usize,
67    /// One-time-materialized [K_eff | poly] block, populated on first hot use.
68    /// Only allocated when the dense block fits within the materialization budget;
69    /// reused across all PIRLS iterations and outer-seed evaluations.
70    materialized: OnceLock<Option<Arc<Array2<f64>>>>,
71}
72
73impl<K: SpatialKernelEvaluator> ChunkedKernelDesignOperator<K> {
74    pub fn new(
75        data: Arc<Array2<f64>>,
76        centers: Arc<Array2<f64>>,
77        kernel: K,
78        kernel_gauge: Option<Arc<Gauge>>,
79        poly_basis: Option<Arc<Array2<f64>>>,
80    ) -> Result<Self, String> {
81        let n = data.nrows();
82        let k = centers.nrows();
83        if data.ncols() != centers.ncols() {
84            return Err(format!(
85                "ChunkedKernelDesignOperator: data dim {} != centers dim {}",
86                data.ncols(),
87                centers.ncols(),
88            ));
89        }
90        if let Some(gauge) = kernel_gauge.as_ref()
91            && gauge.raw_total() != k
92        {
93            return Err(format!(
94                "ChunkedKernelDesignOperator: kernel gauge raw width {} != centers rows {}",
95                gauge.raw_total(),
96                k,
97            ));
98        }
99        if let Some(poly) = poly_basis.as_ref()
100            && poly.nrows() != n
101        {
102            return Err(format!(
103                "ChunkedKernelDesignOperator: poly_basis rows {} != data rows {}",
104                poly.nrows(),
105                n,
106            ));
107        }
108        let k_eff = kernel_gauge.as_ref().map_or(k, |g| g.reduced_total());
109        let poly_cols = poly_basis.as_ref().map_or(0, |p| p.ncols());
110        Ok(Self {
111            data: Arc::new(data.as_standard_layout().to_owned()),
112            centers: Arc::new(centers.as_standard_layout().to_owned()),
113            kernel,
114            kernel_gauge,
115            poly_basis,
116            n,
117            total_cols: k_eff + poly_cols,
118            materialized: OnceLock::new(),
119        })
120    }
121
122    /// Maximum bytes we are willing to spend on the one-shot materialized
123    /// [K_eff | poly] block. The lazy operator was originally selected because
124    /// the *initial* fit-time allocation budget was tight, but once PIRLS is
125    /// running we will pay the kernel-evaluation cost on every iteration unless
126    /// we cache the result. 1 GB is generous enough to cover large-scale
127    /// dense Duchon / TPS (n = 320k, p = 117 → ~300 MiB) while still rejecting
128    /// pathological dense kernels.
129    const MATERIALIZE_MAX_BYTES: usize = 1024 * 1024 * 1024;
130
131    /// Get-or-build the materialized [K_eff | poly] dense block.  Returns
132    /// `None` when the block would exceed `MATERIALIZE_MAX_BYTES`; in that
133    /// case callers must fall back to row-chunked evaluation.
134    ///
135    /// Implementation note: the build path runs `par_chunks_mut` inside
136    /// `kernel_chunk`, so we deliberately compute *outside* the
137    /// `OnceLock`. Using `get_or_init` would hold the lock across that
138    /// nested rayon work, and any sibling rayon workers that reach
139    /// `get_or_init` while the build is in flight would park on the
140    /// `OnceLock`'s OS mutex — starving the build's nested `par_iter`
141    /// and deadlocking the whole pool (every worker in
142    /// `pthread_mutex_wait`, init thread in `pthread_cond_wait`,
143    /// 0% CPU). See `feedback_oncelock_rayon_deadlock`. Computing
144    /// without the lock costs at most one redundant build per racing
145    /// caller — `OnceLock::set` discards losers; `get` after `set`
146    /// always observes the winning value regardless of who won.
147    fn materialized_combined(&self) -> Option<&Array2<f64>> {
148        if let Some(slot) = self.materialized.get() {
149            return slot.as_ref().map(|a| a.as_ref());
150        }
151        let bytes = self
152            .n
153            .checked_mul(self.total_cols)
154            .and_then(|cells| cells.checked_mul(std::mem::size_of::<f64>()));
155        let computed = match bytes {
156            Some(b) if b <= Self::MATERIALIZE_MAX_BYTES => {
157                Some(Arc::new(self.build_row_chunk_combined(0..self.n)))
158            }
159            _ => None,
160        };
161        if self.materialized.set(computed).is_err() {
162            return self
163                .materialized
164                .get()
165                .and_then(|opt| opt.as_ref().map(|a| a.as_ref()));
166        }
167        self.materialized
168            .get()
169            .and_then(|opt| opt.as_ref().map(|a| a.as_ref()))
170    }
171
172    /// Evaluate kernel block for a range of rows, then restrict it through the
173    /// coefficient Gauge when present.
174    ///
175    /// This is not a matrix Kronecker product. The center rows are coordinate
176    /// arguments to `kernel.eval(data_row, center_row)`; each output entry is a
177    /// scalar kernel value before the optional column projection.
178    fn kernel_chunk(&self, rows: Range<usize>) -> Array2<f64> {
179        let chunk_n = rows.end - rows.start;
180        let k_raw = self.centers.nrows();
181        let dim = self.data.ncols();
182        let data = self
183            .data
184            .as_slice()
185            .expect("ChunkedKernelDesignOperator stores standard-layout data");
186        let centers = self
187            .centers
188            .as_slice()
189            .expect("ChunkedKernelDesignOperator stores standard-layout centers");
190        let kernel = &self.kernel;
191        let mut values = vec![0.0_f64; chunk_n * k_raw];
192        values
193            .par_chunks_mut(k_raw)
194            .enumerate()
195            .for_each(|(local, out_row)| {
196                let global = rows.start + local;
197                let x_start = global * dim;
198                let x = &data[x_start..x_start + dim];
199                for j in 0..k_raw {
200                    let c_start = j * dim;
201                    out_row[j] = kernel.eval(x, &centers[c_start..c_start + dim]);
202                }
203            });
204        let kernel_block = Array2::from_shape_vec((chunk_n, k_raw), values)
205            .expect("kernel chunk shape should match generated values");
206        if let Some(gauge) = self.kernel_gauge.as_ref() {
207            gauge.restrict_design(&kernel_block)
208        } else {
209            kernel_block
210        }
211    }
212}
213
214impl<K: SpatialKernelEvaluator> LinearOperator for ChunkedKernelDesignOperator<K> {
215    fn nrows(&self) -> usize {
216        self.n
217    }
218    fn ncols(&self) -> usize {
219        self.total_cols
220    }
221    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
222        if let Some(combined) = self.materialized_combined() {
223            return fast_av(combined, vector);
224        }
225        let k_eff = self
226            .kernel_gauge
227            .as_ref()
228            .map_or(self.centers.nrows(), |g| g.reduced_total());
229        let v_kernel = vector.slice(s![..k_eff]);
230        let mut result = Array1::<f64>::zeros(self.n);
231        // Process in chunks to limit memory.
232        for start in (0..self.n).step_by(KERNEL_OPERATOR_ROW_CHUNK_SIZE) {
233            let end = (start + KERNEL_OPERATOR_ROW_CHUNK_SIZE).min(self.n);
234            let chunk = self.kernel_chunk(start..end);
235            let partial = fast_av(&chunk, &v_kernel);
236            result.slice_mut(s![start..end]).assign(&partial);
237        }
238        if let Some(poly) = self.poly_basis.as_ref() {
239            let v_poly = vector.slice(s![k_eff..]);
240            let poly_part = fast_av(poly, &v_poly);
241            result += &poly_part;
242        }
243        result
244    }
245    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
246        if let Some(combined) = self.materialized_combined() {
247            return fast_atv(combined, vector);
248        }
249        let k_eff = self
250            .kernel_gauge
251            .as_ref()
252            .map_or(self.centers.nrows(), |g| g.reduced_total());
253        let mut result = Array1::<f64>::zeros(self.total_cols);
254        // Kernel part: chunked accumulation of K^T v.
255        for start in (0..self.n).step_by(KERNEL_OPERATOR_ROW_CHUNK_SIZE) {
256            let end = (start + KERNEL_OPERATOR_ROW_CHUNK_SIZE).min(self.n);
257            let chunk = self.kernel_chunk(start..end);
258            let v_slice = vector.slice(s![start..end]);
259            let partial = fast_atv(&chunk, &v_slice);
260            result.slice_mut(s![..k_eff]).scaled_add(1.0, &partial);
261        }
262        // Poly part.
263        if let Some(poly) = self.poly_basis.as_ref() {
264            let poly_part = fast_atv(poly, vector);
265            result.slice_mut(s![k_eff..]).assign(&poly_part);
266        }
267        result
268    }
269    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
270        let p = self.total_cols;
271        // [STAGE] kernel-xtwx: prefer the one-shot materialized [K_eff | poly]
272        // block + faer streaming GEMM.  This is the BLAS-3 path that beats
273        // the per-iteration kernel rebuild by an order of magnitude on dense
274        // Duchon / TPS designs.
275        if let Some(combined) = self.materialized_combined() {
276            let mut xtwx = Array2::<f64>::zeros((p, p));
277            stream_weighted_crossprod_into(
278                combined,
279                weights,
280                &mut xtwx,
281                CrossprodStructure::Full,
282                CrossprodAccum::Replace,
283                effective_global_parallelism(),
284            );
285            return Ok(xtwx);
286        }
287        // Fallback: design too large to materialize.  Run row chunks in
288        // parallel, each thread folding into its own p×p accumulator and
289        // performing one BLAS-3 GEMM (Xc^T·(W·Xc)) per chunk.
290        let n = self.n;
291        if n == 0 || p == 0 {
292            return Ok(Array2::<f64>::zeros((p, p)));
293        }
294        let chunk_starts: Vec<usize> = (0..n).step_by(KERNEL_OPERATOR_ROW_CHUNK_SIZE).collect();
295        let xtwx = chunk_starts
296            .into_par_iter()
297            .fold(
298                || Array2::<f64>::zeros((p, p)),
299                |mut acc, start| {
300                    let end = (start + KERNEL_OPERATOR_ROW_CHUNK_SIZE).min(n);
301                    let chunk = self.row_chunk_combined(start..end);
302                    let mut wchunk = chunk.clone();
303                    for local in 0..(end - start) {
304                        let wi = weights[start + local];
305                        wchunk.row_mut(local).mapv_inplace(|v| v * wi);
306                    }
307                    let chunk_view = FaerArrayView::new(&chunk);
308                    let wchunk_view = FaerArrayView::new(&wchunk);
309                    let mut acc_view = array2_to_matmut(&mut acc);
310                    matmul(
311                        acc_view.as_mut(),
312                        Accum::Add,
313                        chunk_view.as_ref().transpose(),
314                        wchunk_view.as_ref(),
315                        1.0,
316                        Par::Seq,
317                    );
318                    acc
319                },
320            )
321            .reduce(
322                || Array2::<f64>::zeros((p, p)),
323                |mut a, b| {
324                    a += &b;
325                    a
326                },
327            );
328        Ok(xtwx)
329    }
330}
331
332impl<K: SpatialKernelEvaluator> ChunkedKernelDesignOperator<K> {
333    /// Combined row chunk: [kernel_chunk | poly_chunk]. Reuses the cached
334    /// materialization when available to avoid recomputing kernel evaluations.
335    pub(crate) fn row_chunk_combined(&self, rows: Range<usize>) -> Array2<f64> {
336        if let Some(combined) = self.materialized_combined() {
337            return combined.slice(s![rows, ..]).to_owned();
338        }
339        self.build_row_chunk_combined(rows)
340    }
341
342    /// Build a combined row chunk from scratch, bypassing the cache. Used by
343    /// `row_chunk_combined` on a cache miss and by `materialized_combined`'s
344    /// initializer (which must avoid re-entering the OnceLock).
345    fn build_row_chunk_combined(&self, rows: Range<usize>) -> Array2<f64> {
346        let chunk_n = rows.end - rows.start;
347        let k_eff = self
348            .kernel_gauge
349            .as_ref()
350            .map_or(self.centers.nrows(), |g| g.reduced_total());
351        let kernel = self.kernel_chunk(rows.clone());
352        let poly_cols = self.poly_basis.as_ref().map_or(0, |p| p.ncols());
353        let mut combined = Array2::<f64>::zeros((chunk_n, k_eff + poly_cols));
354        combined.slice_mut(s![.., ..k_eff]).assign(&kernel);
355        if let Some(poly) = self.poly_basis.as_ref() {
356            combined
357                .slice_mut(s![.., k_eff..])
358                .assign(&poly.slice(s![rows, ..]));
359        }
360        combined
361    }
362}
363
364impl<K: SpatialKernelEvaluator> DenseDesignOperator for ChunkedKernelDesignOperator<K> {
365    /// Expose the cached [K_eff | poly] materialization so cross-block paths
366    /// can use the Dense × Dense BLAS-3 fast path instead of falling back to
367    /// chunked scalar accumulation.
368    fn as_dense_ref(&self) -> Option<&Array2<f64>> {
369        self.materialized_combined()
370    }
371
372    fn row_chunk_into(
373        &self,
374        rows: Range<usize>,
375        mut out: ArrayViewMut2<'_, f64>,
376    ) -> Result<(), MatrixMaterializationError> {
377        if out.nrows() != rows.end - rows.start || out.ncols() != self.total_cols {
378            return Err(MatrixMaterializationError::MissingRowChunk {
379                context: "ChunkedKernelDesignOperator::row_chunk_into shape mismatch",
380            });
381        }
382        if let Some(combined) = self.materialized_combined() {
383            out.assign(&combined.slice(s![rows, ..]));
384        } else {
385            out.assign(&self.row_chunk_combined(rows));
386        }
387        Ok(())
388    }
389
390    fn to_dense(&self) -> Array2<f64> {
391        if let Some(combined) = self.materialized_combined() {
392            return combined.clone();
393        }
394        self.row_chunk_combined(0..self.n)
395    }
396}
397
#[cfg(test)]
mod chunked_kernel_operator_tests {
    use super::*;
    use gam_linalg::matrix::DenseDesignMatrix;
    use ndarray::{Array1, Array2, array};
    use std::sync::Arc;

    /// Euclidean inner product `<x, c>`, giving hand-checkable kernel values.
    fn dot_kernel(x: &[f64], c: &[f64]) -> f64 {
        x.iter().zip(c).map(|(xi, ci)| xi * ci).sum()
    }

    #[test]
    fn chunked_kernel_operator_uses_center_rows_for_column_count() {
        let data = Arc::new(array![[0.0, 1.0], [1.0, 0.5]]);
        let centers = Arc::new(array![[0.0, 0.0], [1.0, 1.0], [2.0, -1.0]]);
        let operator = ChunkedKernelDesignOperator::new(data, centers, dot_kernel, None, None)
            .expect("chunked kernel operator");

        assert_eq!(operator.ncols(), 3);
        assert_eq!(operator.row_chunk_combined(0..2).dim(), (2, 3));
    }

    #[test]
    fn chunked_kernel_operator_rejects_incompatible_optional_shapes() {
        let data = Arc::new(array![[0.0, 1.0], [1.0, 0.5]]);
        let centers = Arc::new(array![[0.0, 0.0], [1.0, 1.0], [2.0, -1.0]]);
        let zero_kernel = |_: &[f64], _: &[f64]| 0.0;
        // Gauge with raw width 2, while there are 3 centers.
        let bad_gauge = Arc::new(gam_problem::Gauge::from_block_transforms(&[
            Array2::<f64>::zeros((2, 1)),
        ]));
        // Polynomial basis with 3 rows, while there are 2 data points.
        let bad_poly = Arc::new(Array2::<f64>::zeros((3, 1)));

        let Err(gauge_err) = ChunkedKernelDesignOperator::new(
            Arc::clone(&data),
            Arc::clone(&centers),
            zero_kernel,
            Some(bad_gauge),
            None,
        ) else {
            // SAFETY: test asserting validation rejects mismatched gauge raw width; Ok means the validator regressed.
            panic!("gauge raw width should match centers rows");
        };
        assert!(gauge_err.contains("kernel gauge raw width 2 != centers rows 3"));

        let Err(poly_err) =
            ChunkedKernelDesignOperator::new(data, centers, zero_kernel, None, Some(bad_poly))
        else {
            // SAFETY: test asserting validation rejects mismatched poly rows; Ok means the validator regressed.
            panic!("poly rows should match data rows");
        };
        assert!(poly_err.contains("poly_basis rows 3 != data rows 2"));
    }

    #[test]
    fn chunked_kernel_operator_canonicalizes_non_contiguous_inputs() {
        // Transposed views are column-major, i.e. not in standard layout.
        let data = Arc::new(array![[0.0, 1.0], [1.0, 0.5]].reversed_axes());
        let centers = Arc::new(array![[0.0, 1.0, 2.0], [0.0, 1.0, -1.0]].reversed_axes());
        assert!(!data.is_standard_layout());
        assert!(!centers.is_standard_layout());

        let operator = ChunkedKernelDesignOperator::new(data, centers, dot_kernel, None, None)
            .expect("chunked kernel operator");
        let block = operator.row_chunk_combined(0..2);

        assert_eq!(block.dim(), (2, 3));
        assert_eq!(block[[0, 0]], 0.0);
        assert_eq!(block[[1, 1]], 1.5);
    }

    #[test]
    fn chunked_kernel_operator_exposes_cached_dense_to_block_dispatch() {
        let data = Arc::new(array![[0.0, 1.0], [1.0, 0.5], [2.0, -1.0]]);
        let centers = Arc::new(array![[0.0, 0.0], [1.0, 1.0]]);
        let operator = ChunkedKernelDesignOperator::new(data, centers, dot_kernel, None, None)
            .expect("chunked kernel operator");
        let reference = operator.to_dense();

        let dense_design = DenseDesignMatrix::from(Arc::new(operator));

        let ones = Array1::from_elem(3, 1.0);
        let warmed = dense_design.apply_transpose(&ones);
        assert_eq!(warmed.len(), reference.ncols());

        let dense_ref = dense_design
            .as_dense_ref()
            .expect("DenseDesignMatrix::as_dense_ref must reach the cached kernel block");
        assert_eq!(dense_ref.dim(), reference.dim());
        assert!(
            reference
                .indexed_iter()
                .all(|((r, c), v)| (dense_ref[[r, c]] - v).abs() < 1e-12)
        );
    }
}