1use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s};
9use rayon::prelude::*;
10
11use crate::basis::duchon_kernel_math::centered_aniso_metric_weights;
12use crate::basis::{BasisError, MaternNu};
13
/// Number of data rows processed per chunk when the caller passes `None` as
/// `chunk_size` to `StreamingMaternBasisGradientEvaluator::new`.
const DEFAULT_ROW_CHUNK: usize = 2_048;

/// Cut-off for `exp(-a)` in `stable_poly_exp`: for arguments above this value
/// `exp(-a)` is at most the smallest positive subnormal `f64` (~4.9e-324), so the
/// product is returned as exactly `0.0` instead of being evaluated.
const EXP_NEG_UNDERFLOW: f64 = 745.0;
23
/// Parameter with respect to which the Matérn basis columns are differentiated.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MaternBasisGradientTarget {
    /// Log of the inverse length scale (`log κ`, with `κ = 1 / length_scale`).
    LogKappa,
    /// The anisotropic log scale of the given axis (zero-based column index).
    AnisoLogScale(usize),
}
29
30#[derive(Debug, Clone)]
31pub struct StreamingMaternBasisGradientEvaluator {
32 centers: Array2<f64>,
33 length_scale: f64,
34 nu: MaternNu,
35 metric_weights: Vec<f64>,
36 chunk_size: usize,
37}
38
39impl StreamingMaternBasisGradientEvaluator {
40 pub fn new(
41 centers: ArrayView2<'_, f64>,
42 length_scale: f64,
43 nu: MaternNu,
44 aniso_log_scales: Option<&[f64]>,
45 chunk_size: Option<usize>,
46 ) -> Result<Self, BasisError> {
47 if centers.ncols() == 0 {
48 crate::bail_invalid_basis!(
49 "StreamingMaternBasisGradientEvaluator requires centers with at least one column"
50 .to_string(),
51 );
52 }
53 if centers.iter().any(|v| !v.is_finite()) {
54 crate::bail_invalid_basis!(
55 "StreamingMaternBasisGradientEvaluator centers must be finite"
56 );
57 }
58 if !(length_scale.is_finite() && length_scale > 0.0) {
59 crate::bail_invalid_basis!(
60 "StreamingMaternBasisGradientEvaluator length_scale must be finite and positive; got {length_scale}"
61 );
62 }
63 let metric_weights = match aniso_log_scales {
64 Some(eta) => {
65 if eta.len() != centers.ncols() {
66 crate::bail_dim_basis!(
67 "aniso_log_scales length {} != center dimension {}",
68 eta.len(),
69 centers.ncols()
70 );
71 }
72 for (axis, value) in eta.iter().enumerate() {
73 if !value.is_finite() {
74 return Err(BasisError::InvalidInput(format!(
75 "aniso_log_scales[{axis}] must be finite"
76 )));
77 }
78 }
79 centered_aniso_metric_weights(eta)
80 }
81 None => vec![1.0; centers.ncols()],
82 };
83 Ok(Self {
84 centers: centers.as_standard_layout().to_owned(),
85 length_scale,
86 nu,
87 metric_weights,
88 chunk_size: chunk_size.unwrap_or(DEFAULT_ROW_CHUNK).max(1),
89 })
90 }
91
92 pub fn n_centers(&self) -> usize {
93 self.centers.nrows()
94 }
95
96 pub fn dimension(&self) -> usize {
97 self.centers.ncols()
98 }
99
100 pub fn row_chunk_gradient(
101 &self,
102 data: ArrayView2<'_, f64>,
103 start: usize,
104 end: usize,
105 target: MaternBasisGradientTarget,
106 ) -> Result<Array2<f64>, BasisError> {
107 self.validate_data(data)?;
108 if start > end || end > data.nrows() {
109 crate::bail_invalid_basis!(
110 "Matérn gradient row chunk {start}..{end} is outside data with {} rows",
111 data.nrows()
112 );
113 }
114 if let MaternBasisGradientTarget::AnisoLogScale(axis) = target
115 && axis >= self.dimension()
116 {
117 crate::bail_invalid_basis!(
118 "Matérn anisotropic gradient axis {axis} out of bounds for dimension {}",
119 self.dimension()
120 );
121 }
122
123 let chunk_n = end - start;
124 let k = self.n_centers();
125 let dim = self.dimension();
126 let centers = self
127 .centers
128 .as_slice()
129 .expect("standard-layout Matérn gradient centers");
130 let mut values = vec![0.0_f64; chunk_n * k];
131 values
132 .par_chunks_mut(k)
133 .enumerate()
134 .for_each(|(local, row)| {
135 let global = start + local;
136 for center_idx in 0..k {
137 let c = ¢ers[center_idx * dim..(center_idx + 1) * dim];
138 let mut r2 = 0.0;
139 let mut axis_component = 0.0;
140 for axis in 0..dim {
141 let h = data[[global, axis]] - c[axis];
142 let component = self.metric_weights[axis] * h * h;
143 r2 += component;
144 if target == MaternBasisGradientTarget::AnisoLogScale(axis) {
145 axis_component = component;
146 }
147 }
148 let d_log_kappa =
149 matern_log_kappa_derivative(r2.sqrt(), self.length_scale, self.nu);
150 row[center_idx] = match target {
151 MaternBasisGradientTarget::LogKappa => d_log_kappa,
152 MaternBasisGradientTarget::AnisoLogScale(_) => {
153 if r2 == 0.0 {
154 0.0
155 } else {
156 let centered_component = axis_component - r2 / dim as f64;
157 d_log_kappa * centered_component / r2
158 }
159 }
160 };
161 }
162 });
163 Array2::from_shape_vec((chunk_n, k), values).map_err(|err| {
164 BasisError::InvalidInput(format!("Matérn gradient chunk shape failed: {err}"))
165 })
166 }
167
168 pub fn evaluate(
169 &self,
170 data: ArrayView2<'_, f64>,
171 target: MaternBasisGradientTarget,
172 ) -> Result<Array2<f64>, BasisError> {
173 self.validate_data(data)?;
174 let mut out = Array2::<f64>::zeros((data.nrows(), self.n_centers()));
175 for start in (0..data.nrows()).step_by(self.chunk_size) {
176 let end = (start + self.chunk_size).min(data.nrows());
177 let chunk = self.row_chunk_gradient(data, start, end, target)?;
178 out.slice_mut(s![start..end, ..]).assign(&chunk);
179 }
180 Ok(out)
181 }
182
183 pub fn forward_mul(
184 &self,
185 data: ArrayView2<'_, f64>,
186 target: MaternBasisGradientTarget,
187 coeffs: ArrayView1<'_, f64>,
188 ) -> Result<Array1<f64>, BasisError> {
189 self.validate_data(data)?;
190 if coeffs.len() != self.n_centers() {
191 crate::bail_dim_basis!(
192 "Matérn gradient coeff length {} != centers {}",
193 coeffs.len(),
194 self.n_centers()
195 );
196 }
197 let mut out = Array1::<f64>::zeros(data.nrows());
198 for start in (0..data.nrows()).step_by(self.chunk_size) {
199 let end = (start + self.chunk_size).min(data.nrows());
200 let chunk = self.row_chunk_gradient(data, start, end, target)?;
201 out.slice_mut(s![start..end]).assign(&chunk.dot(&coeffs));
202 }
203 Ok(out)
204 }
205
206 fn validate_data(&self, data: ArrayView2<'_, f64>) -> Result<(), BasisError> {
207 if data.ncols() != self.dimension() {
208 crate::bail_dim_basis!(
209 "Matérn gradient data dimension {} != center dimension {}",
210 data.ncols(),
211 self.dimension()
212 );
213 }
214 if data.iter().any(|v| !v.is_finite()) {
215 crate::bail_invalid_basis!("Matérn gradient data must be finite");
216 }
217 Ok::<(), _>(())
218 }
219}
220
221fn stable_poly_exp(a: f64, coeffs: &[f64]) -> f64 {
222 if a > EXP_NEG_UNDERFLOW {
223 return 0.0;
224 }
225 let mut poly = 0.0;
226 for &coeff in coeffs.iter().rev() {
227 poly = poly * a + coeff;
228 }
229 poly * (-a).exp()
230}
231
232fn matern_log_kappa_derivative(r: f64, length_scale: f64, nu: MaternNu) -> f64 {
233 let x = r / length_scale;
234 match nu {
235 MaternNu::Half => stable_poly_exp(x, &[0.0, -1.0]),
236 MaternNu::ThreeHalves => stable_poly_exp(3.0_f64.sqrt() * x, &[0.0, 0.0, -1.0]),
237 MaternNu::FiveHalves => {
238 stable_poly_exp(5.0_f64.sqrt() * x, &[0.0, 0.0, -1.0 / 3.0, -1.0 / 3.0])
239 }
240 MaternNu::SevenHalves => stable_poly_exp(
241 7.0_f64.sqrt() * x,
242 &[0.0, 0.0, -1.0 / 5.0, -1.0 / 5.0, -1.0 / 15.0],
243 ),
244 MaternNu::NineHalves => stable_poly_exp(
245 9.0_f64.sqrt() * x,
246 &[0.0, 0.0, -1.0 / 7.0, -1.0 / 7.0, -2.0 / 35.0, -1.0 / 105.0],
247 ),
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Closed-form half-integer Matérn kernel value at distance `r`. It is the
    /// reference function for the finite-difference checks below.
    fn matern_value_from_distance(r: f64, length_scale: f64, nu: MaternNu) -> f64 {
        const HALF: &[f64] = &[1.0];
        const THREE_HALVES: &[f64] = &[1.0, 1.0];
        const FIVE_HALVES: &[f64] = &[1.0, 1.0, 1.0 / 3.0];
        const SEVEN_HALVES: &[f64] = &[1.0, 1.0, 2.0 / 5.0, 1.0 / 15.0];
        const NINE_HALVES: &[f64] = &[1.0, 1.0, 3.0 / 7.0, 2.0 / 21.0, 1.0 / 105.0];

        let scaled = r / length_scale;
        let (argument, poly) = match nu {
            MaternNu::Half => (scaled, HALF),
            MaternNu::ThreeHalves => (3.0_f64.sqrt() * scaled, THREE_HALVES),
            MaternNu::FiveHalves => (5.0_f64.sqrt() * scaled, FIVE_HALVES),
            MaternNu::SevenHalves => (7.0_f64.sqrt() * scaled, SEVEN_HALVES),
            MaternNu::NineHalves => (9.0_f64.sqrt() * scaled, NINE_HALVES),
        };
        stable_poly_exp(argument, poly)
    }

    /// Weighted Euclidean distance `sqrt(Σ w_a (p_a - c_a)²)` between two rows.
    fn weighted_distance(
        point: ArrayView1<'_, f64>,
        center: ArrayView1<'_, f64>,
        weights: &[f64],
    ) -> f64 {
        point
            .iter()
            .zip(center.iter())
            .zip(weights)
            .map(|((&p, &c), &w)| {
                let d = p - c;
                w * d * d
            })
            .sum::<f64>()
            .sqrt()
    }

    #[test]
    fn log_kappa_gradient_matches_finite_difference() {
        let nu = MaternNu::FiveHalves;
        let length_scale = 1.3;
        let data = array![[0.1, 0.2], [1.0, -0.3], [0.4, 0.8]];
        let centers = array![[0.0, 0.0], [0.8, 0.5]];
        let eval = StreamingMaternBasisGradientEvaluator::new(
            centers.view(),
            length_scale,
            nu,
            None,
            Some(2),
        )
        .unwrap();
        let analytic = eval
            .evaluate(data.view(), MaternBasisGradientTarget::LogKappa)
            .unwrap();

        let step: f64 = 1.0e-5;
        let unit_weights = [1.0; 2];
        for (i, point) in data.outer_iter().enumerate() {
            for (j, center) in centers.outer_iter().enumerate() {
                let r = weighted_distance(point, center, &unit_weights);
                // Raising log κ by `step` shrinks the length scale by exp(-step).
                let plus = matern_value_from_distance(r, length_scale * (-step).exp(), nu);
                let minus = matern_value_from_distance(r, length_scale * step.exp(), nu);
                let fd = (plus - minus) / (2.0 * step);
                assert!((analytic[[i, j]] - fd).abs() < 1.0e-8);
            }
        }
    }

    #[test]
    fn anisotropic_axis_gradient_matches_finite_difference() {
        let nu = MaternNu::ThreeHalves;
        let length_scale = 0.9;
        let data = array![[0.2, -0.1], [1.1, 0.7]];
        let centers = array![[0.0, 0.0], [0.6, 0.4], [1.0, -0.2]];
        let eta = [0.2_f64, -0.2];
        let eval = StreamingMaternBasisGradientEvaluator::new(
            centers.view(),
            length_scale,
            nu,
            Some(&eta),
            Some(1),
        )
        .unwrap();
        let analytic = eval
            .evaluate(data.view(), MaternBasisGradientTarget::AnisoLogScale(1))
            .unwrap();

        let step = 1.0e-5;
        for (i, point) in data.outer_iter().enumerate() {
            for (j, center) in centers.outer_iter().enumerate() {
                // Kernel value with the second log scale replaced by `axis_eta`.
                let value_at = |axis_eta: f64| {
                    let weights = centered_aniso_metric_weights(&[eta[0], axis_eta]);
                    let r = weighted_distance(point, center, &weights);
                    matern_value_from_distance(r, length_scale, nu)
                };
                let fd = (value_at(eta[1] + step) - value_at(eta[1] - step)) / (2.0 * step);
                assert!((analytic[[i, j]] - fd).abs() < 1.0e-8);
            }
        }
    }

    #[test]
    fn forward_mul_matches_materialized_dot() {
        let target = MaternBasisGradientTarget::LogKappa;
        let data = array![[0.1], [0.3], [0.8]];
        let centers = array![[0.0], [0.5]];
        let coeffs = array![2.0, -0.25];
        let eval = StreamingMaternBasisGradientEvaluator::new(
            centers.view(),
            1.1,
            MaternNu::NineHalves,
            None,
            Some(2),
        )
        .unwrap();

        let dense = eval.evaluate(data.view(), target).unwrap();
        let streaming = eval.forward_mul(data.view(), target, coeffs.view()).unwrap();
        assert_eq!(streaming, dense.dot(&coeffs));
    }
}