Skip to main content

gam_problem/
penalty_matrix.rs

//! The `PenaltyMatrix` carrier (dense / Kronecker-factored / blockwise storage, plus
//! precision-label and fixed-λ wrappers) used by every custom-family block, together
//! with its constructors and the `Array2` conversion.
3
4use ndarray::{Array1, Array2, Axis};
5use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
6
7/// A penalty matrix that may be stored in Kronecker-factored form.
8///
9/// For tensor-product terms (e.g. time-varying survival covariates), the penalty
10/// has the structure `S = left ⊗ right` (Kronecker product). Keeping this
11/// factored avoids materializing (p_left × p_right)² dense entries and enables
12/// exact log-determinant computation via `log|A ⊗ B| = n_B log|A| + n_A log|B|`.
13///
14/// Dense penalties are stored as-is.  Callers that need a raw `Array2<f64>` can
15/// call `as_dense()` (zero-cost for Dense, lazy-materialized for KroneckerFactored).
16#[derive(Clone, Debug)]
17pub enum PenaltyMatrix {
18    Dense(Array2<f64>),
19    KroneckerFactored {
20        left: Array2<f64>,
21        right: Array2<f64>,
22    },
23    /// Block-local penalty: `local` is `block_dim × block_dim`, embedded at
24    /// `col_range` in the full parameter space of dimension `total_dim`.
25    /// Avoids materializing the full `total_dim × total_dim` matrix.
26    Blockwise {
27        local: Array2<f64>,
28        col_range: std::ops::Range<usize>,
29        total_dim: usize,
30    },
31    /// Wrapper assigning this penalty component to a user-visible precision
32    /// label. Components with the same label share one smoothing parameter.
33    Labeled {
34        label: String,
35        inner: Box<PenaltyMatrix>,
36    },
37    /// Wrapper fixing this penalty component at a physical log-precision.
38    /// Fixed components remain in the block-local physical penalty layout but
39    /// are removed from the REML outer coordinate vector.
40    Fixed {
41        log_lambda: f64,
42        inner: Box<PenaltyMatrix>,
43    },
44}
45
46impl PenaltyMatrix {
47    /// Number of rows (= number of columns, since penalties are square).
48    pub fn dim(&self) -> usize {
49        match self {
50            Self::Dense(m) => m.nrows(),
51            Self::KroneckerFactored { left, right } => left.nrows() * right.nrows(),
52            Self::Blockwise { total_dim, .. } => *total_dim,
53            Self::Labeled { inner, .. } | Self::Fixed { inner, .. } => inner.dim(),
54        }
55    }
56
57    /// Returns (nrows, ncols) like Array2::dim().
58    pub fn shape(&self) -> (usize, usize) {
59        let d = self.dim();
60        (d, d)
61    }
62
63    /// Materialize the full dense matrix.
64    pub fn to_dense(&self) -> Array2<f64> {
65        match self {
66            Self::Dense(m) => m.clone(),
67            Self::KroneckerFactored { left, right } => kronecker_product(left, right),
68            Self::Blockwise {
69                local,
70                col_range,
71                total_dim,
72            } => {
73                let mut g = Array2::zeros((*total_dim, *total_dim));
74                g.slice_mut(ndarray::s![
75                    col_range.start..col_range.end,
76                    col_range.start..col_range.end
77                ])
78                .assign(local);
79                g
80            }
81            Self::Labeled { inner, .. } | Self::Fixed { inner, .. } => inner.to_dense(),
82        }
83    }
84
85    /// Borrow the inner dense matrix if Dense, otherwise materialize.
86    pub fn as_dense_cow(&self) -> std::borrow::Cow<'_, Array2<f64>> {
87        match self {
88            Self::Dense(m) => std::borrow::Cow::Borrowed(m),
89            Self::KroneckerFactored { .. }
90            | Self::Blockwise { .. }
91            | Self::Labeled { .. }
92            | Self::Fixed { .. } => std::borrow::Cow::Owned(self.to_dense()),
93        }
94    }
95
96    /// Returns a reference to the inner matrix if this is a Dense variant.
97    pub fn as_dense_ref(&self) -> Option<&Array2<f64>> {
98        match self {
99            Self::Dense(m) => Some(m),
100            Self::Fixed { inner, .. } => inner.as_dense_ref(),
101            Self::KroneckerFactored { .. } | Self::Blockwise { .. } | Self::Labeled { .. } => None,
102        }
103    }
104
105    pub fn with_precision_label(self, label: impl Into<String>) -> Self {
106        Self::Labeled {
107            label: label.into(),
108            inner: Box::new(self),
109        }
110    }
111
112    pub fn precision_label(&self) -> Option<&str> {
113        match self {
114            Self::Labeled { label, .. } => Some(label.as_str()),
115            Self::Fixed { .. } => None,
116            _ => None,
117        }
118    }
119
120    pub fn with_fixed_log_lambda(self, log_lambda: f64) -> Self {
121        Self::Fixed {
122            log_lambda,
123            inner: Box::new(self),
124        }
125    }
126
127    pub fn fixed_log_lambda(&self) -> Option<f64> {
128        match self {
129            Self::Fixed { log_lambda, .. } => Some(*log_lambda),
130            Self::Labeled { inner, .. } => inner.fixed_log_lambda(),
131            _ => None,
132        }
133    }
134
135    /// Compute S * v using the row-major Kronecker vec trick when factored:
136    ///   (A ⊗ B) vec_rm(V) = vec_rm(A V Bᵀ)
137    /// where V = reshape(v, (p_left, p_right)).
138    pub fn dot(&self, v: &Array1<f64>) -> Array1<f64> {
139        match self {
140            Self::Dense(m) => m.dot(v),
141            Self::KroneckerFactored { left, right } => {
142                let p_left = left.nrows();
143                let p_right = right.nrows();
144                // v is ordered by i_left * p_right + i_right.
145                let v_mat =
146                    ndarray::ArrayView2::from_shape((p_left, p_right), v.as_slice().unwrap())
147                        .unwrap();
148                let avbt = left.dot(&v_mat).dot(&right.t());
149                let standard = avbt.as_standard_layout();
150                Array1::from_iter(standard.iter().copied())
151            }
152            Self::Blockwise {
153                local,
154                col_range,
155                total_dim,
156            } => {
157                let mut out = Array1::zeros(*total_dim);
158                let v_block = v.slice(ndarray::s![col_range.clone()]);
159                let result_block = local.dot(&v_block);
160                out.slice_mut(ndarray::s![col_range.clone()])
161                    .assign(&result_block);
162                out
163            }
164            Self::Labeled { inner, .. } | Self::Fixed { inner, .. } => inner.dot(v),
165        }
166    }
167
168    /// Add λ * self to a mutable dense accumulator.
169    pub fn add_scaled_to(&self, lambda: f64, target: &mut Array2<f64>) {
170        match self {
171            Self::Dense(m) => {
172                target.scaled_add(lambda, m);
173            }
174            Self::KroneckerFactored { left, right } => {
175                let p_left = left.nrows();
176                let p_right = right.nrows();
177                for i1 in 0..p_left {
178                    for j1 in 0..p_left {
179                        let a_ij = left[[i1, j1]];
180                        if a_ij == 0.0 {
181                            continue;
182                        }
183                        let scaled_a = lambda * a_ij;
184                        for i2 in 0..p_right {
185                            let row = i1 * p_right + i2;
186                            for j2 in 0..p_right {
187                                let col = j1 * p_right + j2;
188                                target[[row, col]] += scaled_a * right[[i2, j2]];
189                            }
190                        }
191                    }
192                }
193            }
194            Self::Blockwise {
195                local, col_range, ..
196            } => {
197                target
198                    .slice_mut(ndarray::s![col_range.clone(), col_range.clone()])
199                    .scaled_add(lambda, local);
200            }
201            Self::Labeled { inner, .. } | Self::Fixed { inner, .. } => {
202                inner.add_scaled_to(lambda, target)
203            }
204        }
205    }
206
207    /// Add λ * diag(self) to a mutable diagonal accumulator.
208    pub fn add_scaled_diag_to(&self, lambda: f64, target: &mut Array1<f64>) {
209        match self {
210            Self::Dense(m) => {
211                let p = m.nrows().min(target.len());
212                for j in 0..p {
213                    target[j] += lambda * m[[j, j]];
214                }
215            }
216            Self::KroneckerFactored { left, right } => {
217                let p_left = left.nrows();
218                let p_right = right.nrows();
219                assert_eq!(target.len(), p_left * p_right);
220                for i_left in 0..p_left {
221                    let left_diag = left[[i_left, i_left]];
222                    if left_diag == 0.0 {
223                        continue;
224                    }
225                    let scaled_left = lambda * left_diag;
226                    for i_right in 0..p_right {
227                        target[i_left * p_right + i_right] +=
228                            scaled_left * right[[i_right, i_right]];
229                    }
230                }
231            }
232            Self::Blockwise {
233                local, col_range, ..
234            } => {
235                let width = local.nrows().min(col_range.len());
236                for local_idx in 0..width {
237                    target[col_range.start + local_idx] += lambda * local[[local_idx, local_idx]];
238                }
239            }
240            Self::Labeled { inner, .. } | Self::Fixed { inner, .. } => {
241                inner.add_scaled_diag_to(lambda, target)
242            }
243        }
244    }
245
246    /// Compute the quadratic form β' S β.
247    pub fn quadratic_form(&self, beta: &Array1<f64>) -> f64 {
248        match self {
249            Self::Dense(m) => beta.dot(&m.dot(beta)),
250            Self::KroneckerFactored { .. } => {
251                let sv = self.dot(beta);
252                beta.dot(&sv)
253            }
254            Self::Blockwise {
255                local, col_range, ..
256            } => {
257                let beta_block = beta.slice(ndarray::s![col_range.clone()]);
258                let sv = local.dot(&beta_block);
259                beta_block.dot(&sv)
260            }
261            Self::Labeled { inner, .. } | Self::Fixed { inner, .. } => inner.quadratic_form(beta),
262        }
263    }
264
265    /// Access dimensions like an Array2.
266    pub fn nrows(&self) -> usize {
267        self.dim()
268    }
269
270    pub fn ncols(&self) -> usize {
271        self.dim()
272    }
273}
274
275impl From<Array2<f64>> for PenaltyMatrix {
276    fn from(m: Array2<f64>) -> Self {
277        Self::Dense(m)
278    }
279}
280
281/// Computes the Kronecker product A ⊗ B for penalty matrix construction.
282/// This is used to create tensor product penalties that enforce smoothness
283/// in multiple dimensions for interaction terms.
284fn kronecker_product(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
285    let (arows, a_cols) = a.dim();
286    let (brows, b_cols) = b.dim();
287    if arows == 0 || a_cols == 0 || brows == 0 || b_cols == 0 {
288        return Array2::zeros((arows * brows, a_cols * b_cols));
289    }
290    let mut result = Array2::zeros((arows * brows, a_cols * b_cols));
291
292    result
293        .axis_chunks_iter_mut(Axis(0), brows)
294        .into_par_iter()
295        .enumerate()
296        .for_each(|(i, mut row_block)| {
297            let arow = a.row(i);
298            let col_chunks = row_block.axis_chunks_iter_mut(Axis(1), b_cols);
299            for (j, mut block) in col_chunks.into_iter().enumerate() {
300                let aval = arow[j];
301                if aval == 0.0 {
302                    continue;
303                }
304                for (dest, &src) in block.iter_mut().zip(b.iter()) {
305                    *dest = aval * src;
306                }
307            }
308        });
309
310    result
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Asserts two equally-shaped arrays agree entrywise within `tol`.
    fn assert_all_close<'a>(
        got: impl IntoIterator<Item = &'a f64>,
        expected: impl IntoIterator<Item = &'a f64>,
        tol: f64,
    ) {
        for (a, b) in got.into_iter().zip(expected) {
            assert!((a - b).abs() < tol, "got={a} expected={b}");
        }
    }

    // ── Dense variant ─────────────────────────────────────────────────────────

    #[test]
    fn dense_dim_and_shape() {
        let m = array![[1.0, 0.0], [0.0, 2.0]];
        let p = PenaltyMatrix::Dense(m);
        assert_eq!(p.dim(), 2);
        assert_eq!(p.shape(), (2, 2));
        assert_eq!(p.nrows(), 2);
        assert_eq!(p.ncols(), 2);
    }

    #[test]
    fn dense_to_dense_is_clone() {
        let m = array![[3.0, 1.0], [1.0, 4.0]];
        let p = PenaltyMatrix::Dense(m.clone());
        assert_eq!(p.to_dense(), m);
    }

    #[test]
    fn dense_dot_product() {
        // [[1, 0], [0, 2]] · [3, 5] = [3, 10]
        let m = array![[1.0, 0.0], [0.0, 2.0]];
        let p = PenaltyMatrix::Dense(m);
        let v = ndarray::array![3.0, 5.0];
        let result = p.dot(&v);
        assert_eq!(result.as_slice().unwrap(), &[3.0, 10.0]);
    }

    #[test]
    fn dense_quadratic_form() {
        // beta' S beta with S=diag(1,2), beta=[3,2] → 9 + 8 = 17
        let m = array![[1.0, 0.0], [0.0, 2.0]];
        let p = PenaltyMatrix::Dense(m);
        let beta = ndarray::array![3.0, 2.0];
        assert!((p.quadratic_form(&beta) - 17.0).abs() < 1e-14);
    }

    #[test]
    fn dense_add_scaled_to() {
        let s = array![[1.0, 0.0], [0.0, 1.0]];
        let p = PenaltyMatrix::Dense(s);
        let mut acc = ndarray::Array2::<f64>::zeros((2, 2));
        p.add_scaled_to(3.0, &mut acc);
        assert_eq!(acc, array![[3.0, 0.0], [0.0, 3.0]]);
    }

    #[test]
    fn dense_add_scaled_diag_to() {
        let s = array![[2.0, 5.0], [5.0, 7.0]];
        let p = PenaltyMatrix::Dense(s);
        let mut diag = ndarray::array![0.0, 0.0];
        p.add_scaled_diag_to(1.0, &mut diag);
        // diagonal entries are 2.0 and 7.0
        assert_eq!(diag.as_slice().unwrap(), &[2.0, 7.0]);
    }

    #[test]
    fn dense_as_dense_cow_borrows() {
        let p = PenaltyMatrix::Dense(array![[1.0, 2.0], [2.0, 5.0]]);
        assert!(matches!(p.as_dense_cow(), std::borrow::Cow::Borrowed(_)));
    }

    // ── KroneckerFactored variant ─────────────────────────────────────────────

    #[test]
    fn kronecker_dim_is_product() {
        let left = array![[1.0, 0.0], [0.0, 1.0]]; // 2×2
        let right = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]; // 3×3
        let p = PenaltyMatrix::KroneckerFactored { left, right };
        assert_eq!(p.dim(), 6);
    }

    #[test]
    fn kronecker_to_dense_identity_x_identity() {
        // I_2 ⊗ I_2 = I_4
        let eye2 = ndarray::Array2::<f64>::eye(2);
        let p = PenaltyMatrix::KroneckerFactored {
            left: eye2.clone(),
            right: eye2,
        };
        let dense = p.to_dense();
        assert_eq!(dense, ndarray::Array2::<f64>::eye(4));
    }

    #[test]
    fn kronecker_dot_matches_dense_dot() {
        let left = array![[2.0, 0.0], [0.0, 3.0]];
        let right = array![[1.0, 1.0], [0.0, 1.0]];
        let p = PenaltyMatrix::KroneckerFactored {
            left: left.clone(),
            right: right.clone(),
        };
        // Compare to materialised version
        let dense = p.to_dense();
        let v = ndarray::array![1.0, 2.0, 3.0, 4.0];
        let got = p.dot(&v);
        let expected = dense.dot(&v);
        assert_all_close(got.iter(), expected.iter(), 1e-14);
    }

    #[test]
    fn kronecker_add_scaled_to_matches_dense() {
        let left = array![[2.0, -1.0], [-1.0, 2.0]];
        let right = array![[1.0, 0.5, 0.0], [0.5, 1.0, 0.5], [0.0, 0.5, 1.0]];
        let p = PenaltyMatrix::KroneckerFactored { left, right };
        // Start from a non-zero accumulator to check that existing entries are kept.
        let mut acc = ndarray::Array2::<f64>::ones((6, 6));
        p.add_scaled_to(0.5, &mut acc);
        let mut expected = ndarray::Array2::<f64>::ones((6, 6));
        expected.scaled_add(0.5, &p.to_dense());
        assert_all_close(acc.iter(), expected.iter(), 1e-12);
    }

    #[test]
    fn kronecker_add_scaled_diag_to() {
        let left = array![[2.0, 1.0], [1.0, 3.0]];
        let right = array![[4.0, 0.0], [0.0, 5.0]];
        let p = PenaltyMatrix::KroneckerFactored { left, right };
        let mut diag = ndarray::Array1::<f64>::zeros(4);
        p.add_scaled_diag_to(2.0, &mut diag);
        // diag(A ⊗ B) = [2·4, 2·5, 3·4, 3·5] = [8, 10, 12, 15], scaled by 2.
        assert_eq!(diag.as_slice().unwrap(), &[16.0, 20.0, 24.0, 30.0]);
    }

    #[test]
    fn kronecker_quadratic_form_matches_dense() {
        let left = array![[2.0, 0.0], [0.0, 3.0]];
        let right = array![[1.0, 1.0], [0.0, 1.0]];
        let p = PenaltyMatrix::KroneckerFactored { left, right };
        let beta = ndarray::array![1.0, 2.0, 3.0, 4.0];
        let dense = p.to_dense();
        let expected = beta.dot(&dense.dot(&beta));
        assert!((p.quadratic_form(&beta) - expected).abs() < 1e-12);
    }

    #[test]
    fn kronecker_as_dense_cow_materializes() {
        let eye2 = ndarray::Array2::<f64>::eye(2);
        let p = PenaltyMatrix::KroneckerFactored {
            left: eye2.clone(),
            right: eye2,
        };
        let cow = p.as_dense_cow();
        assert!(matches!(&cow, std::borrow::Cow::Owned(_)));
        assert_eq!(*cow, ndarray::Array2::<f64>::eye(4));
    }

    // ── Blockwise variant ─────────────────────────────────────────────────────

    #[test]
    fn blockwise_dim_is_total() {
        let local = array![[1.0, 0.0], [0.0, 1.0]];
        let p = PenaltyMatrix::Blockwise {
            local,
            col_range: 1..3,
            total_dim: 5,
        };
        assert_eq!(p.dim(), 5);
    }

    #[test]
    fn blockwise_to_dense_embeds_local_block() {
        // 3×3 total with local 2×2 at cols 1..3
        let local = array![[2.0, 1.0], [1.0, 3.0]];
        let p = PenaltyMatrix::Blockwise {
            local,
            col_range: 1..3,
            total_dim: 3,
        };
        let dense = p.to_dense();
        assert_eq!(dense[[0, 0]], 0.0);
        assert_eq!(dense[[1, 1]], 2.0);
        assert_eq!(dense[[1, 2]], 1.0);
        assert_eq!(dense[[2, 1]], 1.0);
        assert_eq!(dense[[2, 2]], 3.0);
    }

    #[test]
    fn blockwise_dot_only_touches_block() {
        let local = array![[2.0, 0.0], [0.0, 3.0]];
        let p = PenaltyMatrix::Blockwise {
            local,
            col_range: 1..3,
            total_dim: 4,
        };
        let v = ndarray::array![7.0, 1.0, 2.0, 9.0];
        let out = p.dot(&v);
        // v[1..3] = [1,2]; local * [1,2] = [2,6]; embedded at positions 1..3
        assert_eq!(out.as_slice().unwrap(), &[0.0, 2.0, 6.0, 0.0]);
    }

    #[test]
    fn blockwise_add_scaled_to_only_touches_block() {
        let local = array![[2.0, 1.0], [1.0, 3.0]];
        let p = PenaltyMatrix::Blockwise {
            local,
            col_range: 1..3,
            total_dim: 4,
        };
        let mut acc = ndarray::Array2::<f64>::zeros((4, 4));
        p.add_scaled_to(2.0, &mut acc);
        assert_eq!(acc, p.to_dense() * 2.0);
    }

    #[test]
    fn blockwise_add_scaled_diag_to_only_touches_block() {
        let local = array![[2.0, 1.0], [1.0, 3.0]];
        let p = PenaltyMatrix::Blockwise {
            local,
            col_range: 1..3,
            total_dim: 4,
        };
        let mut diag = ndarray::Array1::<f64>::zeros(4);
        p.add_scaled_diag_to(0.5, &mut diag);
        assert_eq!(diag.as_slice().unwrap(), &[0.0, 1.0, 1.5, 0.0]);
    }

    #[test]
    fn blockwise_quadratic_form_uses_block_only() {
        let local = array![[2.0, 0.0], [0.0, 3.0]];
        let p = PenaltyMatrix::Blockwise {
            local,
            col_range: 1..3,
            total_dim: 4,
        };
        // Only beta[1..3] = [1, 2] contributes: 2·1² + 3·2² = 14.
        let beta = ndarray::array![7.0, 1.0, 2.0, 9.0];
        assert!((p.quadratic_form(&beta) - 14.0).abs() < 1e-14);
    }

    // ── Labeled / Fixed wrappers ──────────────────────────────────────────────

    #[test]
    fn labeled_inherits_dim_and_delegates_dot() {
        let m = array![[1.0, 0.0], [0.0, 2.0]];
        let p = PenaltyMatrix::Dense(m).with_precision_label("smooth");
        assert_eq!(p.dim(), 2);
        assert_eq!(p.precision_label(), Some("smooth"));
        let v = ndarray::array![3.0, 4.0];
        let out = p.dot(&v);
        assert_eq!(out.as_slice().unwrap(), &[3.0, 8.0]);
    }

    #[test]
    fn fixed_inherits_dim_and_exposes_log_lambda() {
        let m = array![[5.0, 0.0], [0.0, 5.0]];
        let p = PenaltyMatrix::Dense(m).with_fixed_log_lambda(2.5);
        assert_eq!(p.dim(), 2);
        assert_eq!(p.fixed_log_lambda(), Some(2.5));
    }

    #[test]
    fn labeled_over_fixed_exposes_label_and_log_lambda() {
        let p = PenaltyMatrix::Dense(ndarray::Array2::<f64>::eye(2))
            .with_fixed_log_lambda(-1.0)
            .with_precision_label("shared");
        assert_eq!(p.precision_label(), Some("shared"));
        assert_eq!(p.fixed_log_lambda(), Some(-1.0));
        assert_eq!(p.to_dense(), ndarray::Array2::<f64>::eye(2));
    }
}