Skip to main content

gam_solve/
measure_jet_gram_cache.rs

//! Sufficient-statistic caches for #1033 mechanism (a): the measure-jet
//! fixed-design case.
//!
//! This module serves single-scale-mode measure jets where `dX/dpsi == 0`: the
//! design matrix `X` is theta-invariant across the lambda/rho outer loop, while
//! the penalty and, for GLM PIRLS, the scalar working-weight diagonal `W` may
//! change. It is distinct from `GaussianFixedCache`, which covers only the
//! Gaussian+identity lane with constant `W`, and from `PsiGramTensor`, which
//! covers design-moving psi via Chebyshev expansions (#1033 mechanism (b)).
//!
//! Invariant: the n-row work from the measure-jet basis builder happens once per
//! fit, at construction. The Gaussian constant-`W` accessors cost O(p^2) or less
//! and never re-touch the n design rows. The GLM changing-`W` lane keeps the
//! fixed rows cached and performs only the irreducible weighted contractions
//! needed when PIRLS weights move.
16
17use gam_linalg::faer_ndarray::{fast_xt_diag_x, fast_xt_diag_y};
18use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
19
20/// Gaussian / constant-`W` sufficient statistics for a fixed design.
21///
22/// This stores `X'WX`, `X'W(y - offset)`, and `(y - offset)'W(y - offset)` so
23/// per-lambda assembly and RSS/evidence terms are n-free. It generalizes the
24/// constant-design idea beyond the existing Gaussian+identity-only cache while
25/// keeping the same fixed-`W` requirement for this lane.
26pub struct FixedDesignGramCache {
27    xtwx: Array2<f64>,
28    xtwy: Array1<f64>,
29    ywy: f64,
30    n: usize,
31    p: usize,
32}
33
34impl FixedDesignGramCache {
35    /// Build fixed-design Gaussian sufficient statistics.
36    ///
37    /// The right-hand side is routed through `fast_xt_diag_y`, the same weighted
38    /// contraction primitive used by the runtime recompute path.
39    pub fn build(
40        x: ArrayView2<'_, f64>,
41        y: ArrayView1<'_, f64>,
42        offset: Option<ArrayView1<'_, f64>>,
43        weights: Option<ArrayView1<'_, f64>>,
44    ) -> Result<Self, String> {
45        let n = x.nrows();
46        let p = x.ncols();
47        if y.len() != n {
48            return Err(format!(
49                "y length {} must match design row count {}",
50                y.len(),
51                n
52            ));
53        }
54        if let Some(offset_values) = offset {
55            if offset_values.len() != n {
56                return Err(format!(
57                    "offset length {} must match design row count {}",
58                    offset_values.len(),
59                    n
60                ));
61            }
62        }
63        if let Some(weight_values) = weights {
64            if weight_values.len() != n {
65                return Err(format!(
66                    "weights length {} must match design row count {}",
67                    weight_values.len(),
68                    n
69                ));
70            }
71            validate_nonnegative_finite_weights(weight_values)?;
72        }
73        validate_finite_vector("y", y)?;
74        if let Some(offset_values) = offset {
75            validate_finite_vector("offset", offset_values)?;
76        }
77        validate_finite_matrix("x", x)?;
78
79        let r = match offset {
80            Some(offset_values) => &y.to_owned() - &offset_values.to_owned(),
81            None => y.to_owned(),
82        };
83        let w = match weights {
84            Some(weight_values) => weight_values.to_owned(),
85            None => Array1::ones(n),
86        };
87        let x_owned = x.to_owned();
88        let xtwx = fast_xt_diag_x(&x_owned, &w);
89        let r2 = r.view().insert_axis(ndarray::Axis(1));
90        let xtwy_mat = fast_xt_diag_y(&x_owned, &w, &r2);
91        let xtwy = xtwy_mat.column(0).to_owned();
92        let ywy = weighted_sum_squares(w.view(), r.view());
93
94        Ok(Self {
95            xtwx,
96            xtwy,
97            ywy,
98            n,
99            p,
100        })
101    }
102
103    pub fn n(&self) -> usize {
104        self.n
105    }
106
107    pub fn p(&self) -> usize {
108        self.p
109    }
110
111    pub fn xtwx(&self) -> ArrayView2<'_, f64> {
112        self.xtwx.view()
113    }
114
115    pub fn xtwy(&self) -> ArrayView1<'_, f64> {
116        self.xtwy.view()
117    }
118
119    pub fn ywy(&self) -> f64 {
120        self.ywy
121    }
122
123    /// Assemble `X'WX + S` for the inner solver without revisiting design rows.
124    pub fn penalized_normal_matrix(
125        &self,
126        penalty: ArrayView2<'_, f64>,
127    ) -> Result<Array2<f64>, String> {
128        if penalty.nrows() != self.p || penalty.ncols() != self.p {
129            return Err(format!(
130                "penalty shape {}x{} must match {}x{}",
131                penalty.nrows(),
132                penalty.ncols(),
133                self.p,
134                self.p
135            ));
136        }
137        let mut normal = self.xtwx.clone();
138        normal += &penalty;
139        Ok(normal)
140    }
141
142    /// Compute penalty-free weighted RSS from sufficient statistics.
143    pub fn penalized_rss(&self, beta: ArrayView1<'_, f64>) -> Result<f64, String> {
144        if beta.len() != self.p {
145            return Err(format!(
146                "beta length {} must match design column count {}",
147                beta.len(),
148                self.p
149            ));
150        }
151        // Expanding (r - Xb)'W(r - Xb) gives ywy - 2 b'X'Wr + b'X'WXb.
152        let gram_beta = self.xtwx.dot(&beta);
153        let linear = beta.dot(&self.xtwy);
154        let quadratic = beta.dot(&gram_beta);
155        Ok(self.ywy - 2.0 * linear + quadratic)
156    }
157}
158
159/// Cached fixed design rows for GLM / changing-`W` PIRLS trials.
160///
161/// This cache owns the theta-invariant `X` rows once. Each trial recomputes
162/// `X'WX` and `X'Wz` because the scalar working weights and working response
163/// genuinely move during PIRLS. The Gaussian constant-Gram trick does not apply
164/// when `W` changes; the saved work is the expensive measure-jet basis/design
165/// construction, not the unavoidable weighted contraction over fixed rows.
166pub struct FixedDesignRowCache {
167    x: Array2<f64>,
168    n: usize,
169    p: usize,
170}
171
172impl FixedDesignRowCache {
173    /// Cache a finite, non-empty fixed design.
174    pub fn build(x: ArrayView2<'_, f64>) -> Result<Self, String> {
175        if x.nrows() == 0 || x.ncols() == 0 {
176            return Err(format!(
177                "design must be non-empty, got shape {}x{}",
178                x.nrows(),
179                x.ncols()
180            ));
181        }
182        validate_finite_matrix("x", x)?;
183        let n = x.nrows();
184        let p = x.ncols();
185        Ok(Self {
186            x: x.to_owned(),
187            n,
188            p,
189        })
190    }
191
192    pub fn n(&self) -> usize {
193        self.n
194    }
195
196    pub fn p(&self) -> usize {
197        self.p
198    }
199
200    pub fn design(&self) -> ArrayView2<'_, f64> {
201        self.x.view()
202    }
203
204    /// Recompute `X' diag(weights) X` over cached rows.
205    ///
206    /// This remains O(n p^2), the irreducible weighted contraction when `W`
207    /// changes. It avoids rebuilding the n-row measure-jet design.
208    pub fn xtwx(&self, weights: ArrayView1<'_, f64>) -> Result<Array2<f64>, String> {
209        self.validate_changing_weights(weights)?;
210        Ok(fast_xt_diag_x(&self.x, &weights))
211    }
212
213    /// Recompute `X' diag(weights) z` over cached rows for a PIRLS response.
214    pub fn xtwz(
215        &self,
216        weights: ArrayView1<'_, f64>,
217        z: ArrayView1<'_, f64>,
218    ) -> Result<Array1<f64>, String> {
219        self.validate_changing_weights(weights)?;
220        if z.len() != self.n {
221            return Err(format!(
222                "z length {} must match design row count {}",
223                z.len(),
224                self.n
225            ));
226        }
227        validate_finite_vector("z", z)?;
228        let z2 = z.insert_axis(ndarray::Axis(1));
229        let xtwz_mat = fast_xt_diag_y(&self.x, &weights, &z2);
230        Ok(xtwz_mat.column(0).to_owned())
231    }
232
233    fn validate_changing_weights(&self, weights: ArrayView1<'_, f64>) -> Result<(), String> {
234        if weights.len() != self.n {
235            return Err(format!(
236                "weights length {} must match design row count {}",
237                weights.len(),
238                self.n
239            ));
240        }
241        validate_finite_vector("weights", weights)
242    }
243}
244
245fn validate_finite_matrix(name: &str, matrix: ArrayView2<'_, f64>) -> Result<(), String> {
246    for ((row, col), value) in matrix.indexed_iter() {
247        if !(*value).is_finite() {
248            return Err(format!("{name}[{row},{col}] must be finite"));
249        }
250    }
251    Ok(())
252}
253
254fn validate_finite_vector(name: &str, vector: ArrayView1<'_, f64>) -> Result<(), String> {
255    for (index, value) in vector.iter().enumerate() {
256        if !(*value).is_finite() {
257            return Err(format!("{name}[{index}] must be finite"));
258        }
259    }
260    Ok(())
261}
262
263fn validate_nonnegative_finite_weights(weights: ArrayView1<'_, f64>) -> Result<(), String> {
264    for (index, weight) in weights.iter().enumerate() {
265        if !(*weight).is_finite() {
266            return Err(format!("weights[{index}] must be finite"));
267        }
268        if *weight < 0.0 {
269            return Err(format!("weights[{index}] must be non-negative"));
270        }
271    }
272    Ok(())
273}
274
275fn weighted_sum_squares(weights: ArrayView1<'_, f64>, values: ArrayView1<'_, f64>) -> f64 {
276    weights
277        .iter()
278        .zip(values.iter())
279        .map(|(weight, value)| *weight * *value * *value)
280        .sum()
281}
282
#[cfg(test)]
mod tests {
    use super::{FixedDesignGramCache, FixedDesignRowCache};
    use approx::assert_abs_diff_eq;
    use gam_linalg::faer_ndarray::fast_xt_diag_x;
    use ndarray::{Array1, Array2};

    /// Smooth deterministic design used across tests.
    fn deterministic_design(n: usize, p: usize) -> Array2<f64> {
        Array2::from_shape_fn((n, p), |(i, j)| {
            let row = i as f64 + 1.0;
            let col = j as f64 + 1.0;
            ((row * 0.17 + col * 0.31).sin()) + row * col * 0.002
        })
    }

    fn deterministic_response(n: usize) -> Array1<f64> {
        Array1::from_shape_fn(n, |i| {
            let row = i as f64 + 1.0;
            (row * 0.23).cos() + row * 0.015
        })
    }

    fn deterministic_offset(n: usize) -> Array1<f64> {
        Array1::from_shape_fn(n, |i| {
            let row = i as f64 + 1.0;
            0.2 * (row * 0.11).sin() - 0.01 * row
        })
    }

    /// Weights bounded below by 0.4 for non-negative `scale`.
    fn deterministic_weights(n: usize, scale: f64) -> Array1<f64> {
        Array1::from_shape_fn(n, |i| {
            let row = i as f64 + 1.0;
            0.4 + scale * (1.0 + (row * 0.19).sin())
        })
    }

    /// Reference `X' diag(w) X` by an explicit triple loop.
    fn naive_xtwx(x: &Array2<f64>, weights: &Array1<f64>) -> Array2<f64> {
        let (n, p) = x.dim();
        let mut out = Array2::zeros((p, p));
        for row in 0..n {
            for a in 0..p {
                for b in 0..p {
                    out[[a, b]] += x[[row, a]] * weights[row] * x[[row, b]];
                }
            }
        }
        out
    }

    /// Reference unweighted `X'X`.
    fn naive_xtx(x: &Array2<f64>) -> Array2<f64> {
        naive_xtwx(x, &Array1::ones(x.nrows()))
    }

    fn naive_xtwy(x: &Array2<f64>, weights: &Array1<f64>, r: &Array1<f64>) -> Array1<f64> {
        let (n, p) = x.dim();
        let mut out = Array1::zeros(p);
        for row in 0..n {
            for col in 0..p {
                out[col] += x[[row, col]] * weights[row] * r[row];
            }
        }
        out
    }

    fn naive_xtwz(x: &Array2<f64>, weights: &Array1<f64>, z: &Array1<f64>) -> Array1<f64> {
        naive_xtwy(x, weights, z)
    }

    fn naive_ywy(weights: &Array1<f64>, r: &Array1<f64>) -> f64 {
        let mut sum = 0.0;
        for row in 0..weights.len() {
            sum += weights[row] * r[row] * r[row];
        }
        sum
    }

    fn assert_matrix_close(actual: ndarray::ArrayView2<'_, f64>, expected: &Array2<f64>, eps: f64) {
        assert_eq!(actual.nrows(), expected.nrows());
        assert_eq!(actual.ncols(), expected.ncols());
        for row in 0..expected.nrows() {
            for col in 0..expected.ncols() {
                assert_abs_diff_eq!(actual[[row, col]], expected[[row, col]], epsilon = eps);
            }
        }
    }

    fn assert_vector_close(actual: ndarray::ArrayView1<'_, f64>, expected: &Array1<f64>, eps: f64) {
        assert_eq!(actual.len(), expected.len());
        for index in 0..expected.len() {
            assert_abs_diff_eq!(actual[index], expected[index], epsilon = eps);
        }
    }

    #[test]
    fn gaussian_xtwx_matches_naive() {
        let n = 40;
        let p = 4;
        let x = deterministic_design(n, p);
        let y = deterministic_response(n);
        let cache = FixedDesignGramCache::build(x.view(), y.view(), None, None).unwrap();
        let naive = naive_xtx(&x);
        assert_matrix_close(cache.xtwx(), &naive, 1.0e-9);
    }

    #[test]
    fn gaussian_weighted_xtwx_matches_naive() {
        let n = 40;
        let p = 4;
        let x = deterministic_design(n, p);
        let y = deterministic_response(n);
        let offset = deterministic_offset(n);
        let weights = deterministic_weights(n, 0.35);
        let cache = FixedDesignGramCache::build(
            x.view(),
            y.view(),
            Some(offset.view()),
            Some(weights.view()),
        )
        .unwrap();
        assert_eq!(cache.n(), n);
        assert_eq!(cache.p(), p);
        assert_matrix_close(cache.xtwx(), &naive_xtwx(&x, &weights), 1.0e-9);
    }

    #[test]
    fn gaussian_xtwy_and_ywy_match_naive() {
        let n = 40;
        let p = 4;
        let x = deterministic_design(n, p);
        let y = deterministic_response(n);
        let offset = deterministic_offset(n);
        let weights = deterministic_weights(n, 0.35);
        let r = &y - &offset;
        let cache = FixedDesignGramCache::build(
            x.view(),
            y.view(),
            Some(offset.view()),
            Some(weights.view()),
        )
        .unwrap();
        let expected_xtwy = naive_xtwy(&x, &weights, &r);
        let expected_ywy = naive_ywy(&weights, &r);
        assert_vector_close(cache.xtwy(), &expected_xtwy, 1.0e-9);
        assert_abs_diff_eq!(cache.ywy(), expected_ywy, epsilon = 1.0e-9);
    }

    #[test]
    fn penalized_rss_matches_direct_residual() {
        let n = 40;
        let p = 4;
        let x = deterministic_design(n, p);
        let y = deterministic_response(n);
        let offset = deterministic_offset(n);
        let weights = deterministic_weights(n, 0.21);
        let beta = Array1::from_vec(vec![0.4, -0.2, 0.15, 0.05]);
        let r = &y - &offset;
        let cache = FixedDesignGramCache::build(
            x.view(),
            y.view(),
            Some(offset.view()),
            Some(weights.view()),
        )
        .unwrap();
        let mut direct = 0.0;
        for row in 0..n {
            let mut fit = 0.0;
            for col in 0..p {
                fit += x[[row, col]] * beta[col];
            }
            let residual = r[row] - fit;
            direct += weights[row] * residual * residual;
        }
        let cached = cache.penalized_rss(beta.view()).unwrap();
        assert_abs_diff_eq!(cached, direct, epsilon = 1.0e-8);
    }

    #[test]
    fn penalized_normal_matrix_adds_penalty() {
        let n = 40;
        let p = 4;
        let x = deterministic_design(n, p);
        let y = deterministic_response(n);
        let cache = FixedDesignGramCache::build(x.view(), y.view(), None, None).unwrap();
        let penalty = Array2::from_shape_fn((p, p), |(row, col)| {
            if row == col {
                0.5 + row as f64 * 0.1
            } else {
                0.02 * (row + col) as f64
            }
        });
        let normal = cache.penalized_normal_matrix(penalty.view()).unwrap();
        for row in 0..p {
            for col in 0..p {
                let expected = cache.xtwx()[[row, col]] + penalty[[row, col]];
                assert_abs_diff_eq!(normal[[row, col]], expected, epsilon = 1.0e-12);
            }
        }
    }

    #[test]
    fn gram_cache_accessors_reject_shape_mismatch() {
        let n = 40;
        let p = 4;
        let x = deterministic_design(n, p);
        let y = deterministic_response(n);
        let cache = FixedDesignGramCache::build(x.view(), y.view(), None, None).unwrap();

        let wrong_penalty = Array2::<f64>::zeros((p + 1, p + 1));
        assert!(cache.penalized_normal_matrix(wrong_penalty.view()).is_err());

        let short_beta = Array1::<f64>::zeros(p - 1);
        assert!(cache.penalized_rss(short_beta.view()).is_err());
    }

    #[test]
    fn row_cache_xtwx_matches_fresh_build_across_weights() {
        let n = 40;
        let p = 4;
        let x = deterministic_design(n, p);
        let cache = FixedDesignRowCache::build(x.view()).unwrap();
        let weight_sets = [
            deterministic_weights(n, 0.12),
            deterministic_weights(n, 0.27),
            deterministic_weights(n, 0.41),
        ];
        for weights in weight_sets.iter() {
            let cached = cache.xtwx(weights.view()).unwrap();
            let fresh = fast_xt_diag_x(&x, weights);
            assert_matrix_close(cached.view(), &fresh, 1.0e-12);
        }
    }

    #[test]
    fn row_cache_xtwz_matches_naive() {
        let n = 40;
        let p = 4;
        let x = deterministic_design(n, p);
        let weights = deterministic_weights(n, 0.33);
        let z = Array1::from_shape_fn(n, |i| {
            let row = i as f64 + 1.0;
            (row * 0.07).sin() + 0.03 * row
        });
        let cache = FixedDesignRowCache::build(x.view()).unwrap();
        let cached = cache.xtwz(weights.view(), z.view()).unwrap();
        let expected = naive_xtwz(&x, &weights, &z);
        assert_vector_close(cached.view(), &expected, 1.0e-9);
    }

    #[test]
    fn build_rejects_shape_mismatch() {
        let n = 40;
        let p = 4;
        let x = deterministic_design(n, p);
        let mismatched_y = deterministic_response(n - 1);
        assert!(FixedDesignGramCache::build(x.view(), mismatched_y.view(), None, None).is_err());

        let y = deterministic_response(n);
        let mut weights = deterministic_weights(n, 0.2);
        weights[3] = f64::NAN;
        assert!(
            FixedDesignGramCache::build(x.view(), y.view(), None, Some(weights.view())).is_err()
        );
    }

    #[test]
    fn build_rejects_invalid_offset_weights_and_design() {
        let n = 40;
        let p = 4;
        let x = deterministic_design(n, p);
        let y = deterministic_response(n);

        let short_offset = deterministic_offset(n - 1);
        assert!(FixedDesignGramCache::build(
            x.view(),
            y.view(),
            Some(short_offset.view()),
            None
        )
        .is_err());

        let mut negative_weights = deterministic_weights(n, 0.2);
        negative_weights[5] = -0.1;
        assert!(FixedDesignGramCache::build(
            x.view(),
            y.view(),
            None,
            Some(negative_weights.view())
        )
        .is_err());

        let mut bad_x = x.clone();
        bad_x[[7, 2]] = f64::INFINITY;
        assert!(FixedDesignGramCache::build(bad_x.view(), y.view(), None, None).is_err());
    }

    #[test]
    fn row_cache_rejects_invalid_inputs() {
        let n = 40;
        let p = 4;

        let empty_rows = Array2::<f64>::zeros((0, 3));
        assert!(FixedDesignRowCache::build(empty_rows.view()).is_err());
        let empty_cols = Array2::<f64>::zeros((3, 0));
        assert!(FixedDesignRowCache::build(empty_cols.view()).is_err());
        let mut bad_x = deterministic_design(n, p);
        bad_x[[0, 1]] = f64::NAN;
        assert!(FixedDesignRowCache::build(bad_x.view()).is_err());

        let x = deterministic_design(n, p);
        let cache = FixedDesignRowCache::build(x.view()).unwrap();
        assert_eq!(cache.n(), n);
        assert_eq!(cache.p(), p);
        assert_eq!(cache.design(), x.view());

        let short_weights = deterministic_weights(n - 1, 0.3);
        assert!(cache.xtwx(short_weights.view()).is_err());

        let weights = deterministic_weights(n, 0.3);
        let short_z = deterministic_response(n - 1);
        assert!(cache.xtwz(weights.view(), short_z.view()).is_err());

        let mut bad_z = deterministic_response(n);
        bad_z[4] = f64::NAN;
        assert!(cache.xtwz(weights.view(), bad_z.view()).is_err());
    }
}