Skip to main content

gam_terms/basis/
matern_gradient.rs

1//! Streaming closed-form gradients for Matérn radial basis values.
2//!
3//! This is the lightweight public primitive for composition-engine callers
4//! that need `dK/dtheta` without finite differences or a full smooth-term
5//! build. It streams row chunks over `(data, centers)` and supports the global
6//! log-kappa coordinate plus per-axis anisotropic log-scale coordinates.
7
8use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s};
9use rayon::prelude::*;
10
11use crate::basis::duchon_kernel_math::centered_aniso_metric_weights;
12use crate::basis::{BasisError, MaternNu};
13
/// Number of data rows handled per streaming chunk when the caller does not pick one.
///
/// A chunk spans `DEFAULT_ROW_CHUNK × n_centers` `f64` gradient values (16 KiB per center),
/// and the rows inside a chunk are spread across rayon tasks.
const DEFAULT_ROW_CHUNK: usize = 1 << 11;
18
/// Cut-off for `a` in `P(a) · exp(-a)` beyond which the product is treated as `0.0`.
///
/// In `f64`, `exp(-a)` falls below `f64::MIN_POSITIVE` near `a ≈ 708.4`. It reaches the
/// smallest subnormal (≈ 4.9e-324) near `a ≈ 744.4` and rounds to exactly `0.0` only once
/// `a` exceeds ≈ 745.13. Truncating at 745 is therefore not an exact-zero boundary. For the
/// low-degree polynomials used in this module, it discards results of subnormal magnitude at
/// most.
const EXP_NEG_UNDERFLOW: f64 = 745.0;
23
/// Coordinate with respect to which Matérn basis values are differentiated.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MaternBasisGradientTarget {
    /// Global `log κ`, where `κ = 1 / length_scale`.
    LogKappa,
    /// Anisotropic log-scale of the given (0-based) input axis.
    AnisoLogScale(usize),
}
29
30#[derive(Debug, Clone)]
31pub struct StreamingMaternBasisGradientEvaluator {
32    centers: Array2<f64>,
33    length_scale: f64,
34    nu: MaternNu,
35    metric_weights: Vec<f64>,
36    chunk_size: usize,
37}
38
39impl StreamingMaternBasisGradientEvaluator {
40    pub fn new(
41        centers: ArrayView2<'_, f64>,
42        length_scale: f64,
43        nu: MaternNu,
44        aniso_log_scales: Option<&[f64]>,
45        chunk_size: Option<usize>,
46    ) -> Result<Self, BasisError> {
47        if centers.ncols() == 0 {
48            crate::bail_invalid_basis!(
49                "StreamingMaternBasisGradientEvaluator requires centers with at least one column"
50                    .to_string(),
51            );
52        }
53        if centers.iter().any(|v| !v.is_finite()) {
54            crate::bail_invalid_basis!(
55                "StreamingMaternBasisGradientEvaluator centers must be finite"
56            );
57        }
58        if !(length_scale.is_finite() && length_scale > 0.0) {
59            crate::bail_invalid_basis!(
60                "StreamingMaternBasisGradientEvaluator length_scale must be finite and positive; got {length_scale}"
61            );
62        }
63        let metric_weights = match aniso_log_scales {
64            Some(eta) => {
65                if eta.len() != centers.ncols() {
66                    crate::bail_dim_basis!(
67                        "aniso_log_scales length {} != center dimension {}",
68                        eta.len(),
69                        centers.ncols()
70                    );
71                }
72                for (axis, value) in eta.iter().enumerate() {
73                    if !value.is_finite() {
74                        return Err(BasisError::InvalidInput(format!(
75                            "aniso_log_scales[{axis}] must be finite"
76                        )));
77                    }
78                }
79                centered_aniso_metric_weights(eta)
80            }
81            None => vec![1.0; centers.ncols()],
82        };
83        Ok(Self {
84            centers: centers.as_standard_layout().to_owned(),
85            length_scale,
86            nu,
87            metric_weights,
88            chunk_size: chunk_size.unwrap_or(DEFAULT_ROW_CHUNK).max(1),
89        })
90    }
91
92    pub fn n_centers(&self) -> usize {
93        self.centers.nrows()
94    }
95
96    pub fn dimension(&self) -> usize {
97        self.centers.ncols()
98    }
99
100    pub fn row_chunk_gradient(
101        &self,
102        data: ArrayView2<'_, f64>,
103        start: usize,
104        end: usize,
105        target: MaternBasisGradientTarget,
106    ) -> Result<Array2<f64>, BasisError> {
107        self.validate_data(data)?;
108        if start > end || end > data.nrows() {
109            crate::bail_invalid_basis!(
110                "Matérn gradient row chunk {start}..{end} is outside data with {} rows",
111                data.nrows()
112            );
113        }
114        if let MaternBasisGradientTarget::AnisoLogScale(axis) = target
115            && axis >= self.dimension()
116        {
117            crate::bail_invalid_basis!(
118                "Matérn anisotropic gradient axis {axis} out of bounds for dimension {}",
119                self.dimension()
120            );
121        }
122
123        let chunk_n = end - start;
124        let k = self.n_centers();
125        let dim = self.dimension();
126        let centers = self
127            .centers
128            .as_slice()
129            .expect("standard-layout Matérn gradient centers");
130        let mut values = vec![0.0_f64; chunk_n * k];
131        values
132            .par_chunks_mut(k)
133            .enumerate()
134            .for_each(|(local, row)| {
135                let global = start + local;
136                for center_idx in 0..k {
137                    let c = &centers[center_idx * dim..(center_idx + 1) * dim];
138                    let mut r2 = 0.0;
139                    let mut axis_component = 0.0;
140                    for axis in 0..dim {
141                        let h = data[[global, axis]] - c[axis];
142                        let component = self.metric_weights[axis] * h * h;
143                        r2 += component;
144                        if target == MaternBasisGradientTarget::AnisoLogScale(axis) {
145                            axis_component = component;
146                        }
147                    }
148                    let d_log_kappa =
149                        matern_log_kappa_derivative(r2.sqrt(), self.length_scale, self.nu);
150                    row[center_idx] = match target {
151                        MaternBasisGradientTarget::LogKappa => d_log_kappa,
152                        MaternBasisGradientTarget::AnisoLogScale(_) => {
153                            if r2 == 0.0 {
154                                0.0
155                            } else {
156                                let centered_component = axis_component - r2 / dim as f64;
157                                d_log_kappa * centered_component / r2
158                            }
159                        }
160                    };
161                }
162            });
163        Array2::from_shape_vec((chunk_n, k), values).map_err(|err| {
164            BasisError::InvalidInput(format!("Matérn gradient chunk shape failed: {err}"))
165        })
166    }
167
168    pub fn evaluate(
169        &self,
170        data: ArrayView2<'_, f64>,
171        target: MaternBasisGradientTarget,
172    ) -> Result<Array2<f64>, BasisError> {
173        self.validate_data(data)?;
174        let mut out = Array2::<f64>::zeros((data.nrows(), self.n_centers()));
175        for start in (0..data.nrows()).step_by(self.chunk_size) {
176            let end = (start + self.chunk_size).min(data.nrows());
177            let chunk = self.row_chunk_gradient(data, start, end, target)?;
178            out.slice_mut(s![start..end, ..]).assign(&chunk);
179        }
180        Ok(out)
181    }
182
183    pub fn forward_mul(
184        &self,
185        data: ArrayView2<'_, f64>,
186        target: MaternBasisGradientTarget,
187        coeffs: ArrayView1<'_, f64>,
188    ) -> Result<Array1<f64>, BasisError> {
189        self.validate_data(data)?;
190        if coeffs.len() != self.n_centers() {
191            crate::bail_dim_basis!(
192                "Matérn gradient coeff length {} != centers {}",
193                coeffs.len(),
194                self.n_centers()
195            );
196        }
197        let mut out = Array1::<f64>::zeros(data.nrows());
198        for start in (0..data.nrows()).step_by(self.chunk_size) {
199            let end = (start + self.chunk_size).min(data.nrows());
200            let chunk = self.row_chunk_gradient(data, start, end, target)?;
201            out.slice_mut(s![start..end]).assign(&chunk.dot(&coeffs));
202        }
203        Ok(out)
204    }
205
206    fn validate_data(&self, data: ArrayView2<'_, f64>) -> Result<(), BasisError> {
207        if data.ncols() != self.dimension() {
208            crate::bail_dim_basis!(
209                "Matérn gradient data dimension {} != center dimension {}",
210                data.ncols(),
211                self.dimension()
212            );
213        }
214        if data.iter().any(|v| !v.is_finite()) {
215            crate::bail_invalid_basis!("Matérn gradient data must be finite");
216        }
217        Ok::<(), _>(())
218    }
219}
220
221fn stable_poly_exp(a: f64, coeffs: &[f64]) -> f64 {
222    if a > EXP_NEG_UNDERFLOW {
223        return 0.0;
224    }
225    let mut poly = 0.0;
226    for &coeff in coeffs.iter().rev() {
227        poly = poly * a + coeff;
228    }
229    poly * (-a).exp()
230}
231
232fn matern_log_kappa_derivative(r: f64, length_scale: f64, nu: MaternNu) -> f64 {
233    let x = r / length_scale;
234    match nu {
235        MaternNu::Half => stable_poly_exp(x, &[0.0, -1.0]),
236        MaternNu::ThreeHalves => stable_poly_exp(3.0_f64.sqrt() * x, &[0.0, 0.0, -1.0]),
237        MaternNu::FiveHalves => {
238            stable_poly_exp(5.0_f64.sqrt() * x, &[0.0, 0.0, -1.0 / 3.0, -1.0 / 3.0])
239        }
240        MaternNu::SevenHalves => stable_poly_exp(
241            7.0_f64.sqrt() * x,
242            &[0.0, 0.0, -1.0 / 5.0, -1.0 / 5.0, -1.0 / 15.0],
243        ),
244        MaternNu::NineHalves => stable_poly_exp(
245            9.0_f64.sqrt() * x,
246            &[0.0, 0.0, -1.0 / 7.0, -1.0 / 7.0, -2.0 / 35.0, -1.0 / 105.0],
247        ),
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Reference Matérn values `P_ν(a) · exp(-a)` with `a = sqrt(2ν) · r / length_scale`.
    fn matern_value_from_distance(r: f64, length_scale: f64, nu: MaternNu) -> f64 {
        let x = r / length_scale;
        match nu {
            MaternNu::Half => stable_poly_exp(x, &[1.0]),
            MaternNu::ThreeHalves => stable_poly_exp(3.0_f64.sqrt() * x, &[1.0, 1.0]),
            MaternNu::FiveHalves => stable_poly_exp(5.0_f64.sqrt() * x, &[1.0, 1.0, 1.0 / 3.0]),
            MaternNu::SevenHalves => {
                stable_poly_exp(7.0_f64.sqrt() * x, &[1.0, 1.0, 2.0 / 5.0, 1.0 / 15.0])
            }
            MaternNu::NineHalves => stable_poly_exp(
                9.0_f64.sqrt() * x,
                &[1.0, 1.0, 3.0 / 7.0, 2.0 / 21.0, 1.0 / 105.0],
            ),
        }
    }

    /// Every smoothness order is checked, not only ν = 5/2, so each coefficient table in
    /// `matern_log_kappa_derivative` is pinned against a central finite difference.
    #[test]
    fn log_kappa_gradient_matches_finite_difference() {
        let data = array![[0.1, 0.2], [1.0, -0.3], [0.4, 0.8]];
        let centers = array![[0.0, 0.0], [0.8, 0.5]];
        let length_scale = 1.3;
        let h: f64 = 1.0e-5;
        for nu in [
            MaternNu::Half,
            MaternNu::ThreeHalves,
            MaternNu::FiveHalves,
            MaternNu::SevenHalves,
            MaternNu::NineHalves,
        ] {
            let eval = StreamingMaternBasisGradientEvaluator::new(
                centers.view(),
                length_scale,
                nu,
                None,
                Some(2),
            )
            .unwrap();
            let analytic = eval
                .evaluate(data.view(), MaternBasisGradientTarget::LogKappa)
                .unwrap();
            for i in 0..data.nrows() {
                for j in 0..centers.nrows() {
                    let r = ((0..data.ncols())
                        .map(|axis| {
                            let d = data[[i, axis]] - centers[[j, axis]];
                            d * d
                        })
                        .sum::<f64>())
                    .sqrt();
                    // log κ = -log(length_scale): +h in log κ scales the length scale by exp(-h).
                    let plus = matern_value_from_distance(r, length_scale * (-h).exp(), nu);
                    let minus = matern_value_from_distance(r, length_scale * h.exp(), nu);
                    let fd = (plus - minus) / (2.0 * h);
                    assert!(
                        (analytic[[i, j]] - fd).abs() < 1.0e-8,
                        "{nu:?} ({i}, {j}): analytic {} vs finite difference {fd}",
                        analytic[[i, j]]
                    );
                }
            }
        }
    }

    #[test]
    fn anisotropic_axis_gradient_matches_finite_difference() {
        let data = array![[0.2, -0.1], [1.1, 0.7]];
        let centers = array![[0.0, 0.0], [0.6, 0.4], [1.0, -0.2]];
        let eta = [0.2_f64, -0.2];
        let eval = StreamingMaternBasisGradientEvaluator::new(
            centers.view(),
            0.9,
            MaternNu::ThreeHalves,
            Some(&eta),
            Some(1),
        )
        .unwrap();
        let analytic = eval
            .evaluate(data.view(), MaternBasisGradientTarget::AnisoLogScale(1))
            .unwrap();
        let h = 1.0e-5;
        for i in 0..data.nrows() {
            for j in 0..centers.nrows() {
                // Perturb only eta[1]; the centering inside the weights is part of the gradient.
                let value_at = |axis_eta: f64| {
                    let eta_trial = [eta[0], axis_eta];
                    let weights = centered_aniso_metric_weights(&eta_trial);
                    let r = ((0..2)
                        .map(|axis| {
                            let d = data[[i, axis]] - centers[[j, axis]];
                            weights[axis] * d * d
                        })
                        .sum::<f64>())
                    .sqrt();
                    matern_value_from_distance(r, 0.9, MaternNu::ThreeHalves)
                };
                let fd = (value_at(eta[1] + h) - value_at(eta[1] - h)) / (2.0 * h);
                assert!((analytic[[i, j]] - fd).abs() < 1.0e-8);
            }
        }
    }

    #[test]
    fn coincident_points_have_zero_gradient() {
        let centers = array![[0.3, -0.4], [1.0, 1.0]];
        let data = array![[0.3, -0.4]];
        let eta = [0.1_f64, -0.1];
        let eval = StreamingMaternBasisGradientEvaluator::new(
            centers.view(),
            0.7,
            MaternNu::FiveHalves,
            Some(&eta),
            None,
        )
        .unwrap();
        for target in [
            MaternBasisGradientTarget::LogKappa,
            MaternBasisGradientTarget::AnisoLogScale(0),
            MaternBasisGradientTarget::AnisoLogScale(1),
        ] {
            let grad = eval.evaluate(data.view(), target).unwrap();
            assert_eq!(grad[[0, 0]], 0.0, "{target:?}");
            assert!(grad[[0, 1]].is_finite(), "{target:?}");
        }
    }

    #[test]
    fn invalid_axis_and_row_ranges_are_rejected() {
        let data = array![[0.1, 0.2], [0.3, 0.4]];
        let centers = array![[0.0, 0.0]];
        let eval = StreamingMaternBasisGradientEvaluator::new(
            centers.view(),
            1.0,
            MaternNu::Half,
            None,
            None,
        )
        .unwrap();
        let bad_axis = MaternBasisGradientTarget::AnisoLogScale(2);
        let log_kappa = MaternBasisGradientTarget::LogKappa;
        assert!(eval.row_chunk_gradient(data.view(), 0, 2, bad_axis).is_err());
        assert!(eval.evaluate(data.view(), bad_axis).is_err());
        assert!(eval.row_chunk_gradient(data.view(), 2, 1, log_kappa).is_err());
        assert!(eval.row_chunk_gradient(data.view(), 0, 3, log_kappa).is_err());
        assert!(eval.evaluate(array![[0.1, 0.2, 0.3]].view(), log_kappa).is_err());
    }

    #[test]
    fn forward_mul_matches_materialized_dot() {
        let data = array![[0.1], [0.3], [0.8]];
        let centers = array![[0.0], [0.5]];
        let coeffs = array![2.0, -0.25];
        // Chunk size 2 over 3 rows exercises a short final chunk.
        let eval = StreamingMaternBasisGradientEvaluator::new(
            centers.view(),
            1.1,
            MaternNu::NineHalves,
            None,
            Some(2),
        )
        .unwrap();
        let dense = eval
            .evaluate(data.view(), MaternBasisGradientTarget::LogKappa)
            .unwrap();
        let streaming = eval
            .forward_mul(
                data.view(),
                MaternBasisGradientTarget::LogKappa,
                coeffs.view(),
            )
            .unwrap();
        assert_eq!(streaming, dense.dot(&coeffs));
    }
}