Skip to main content

gam_problem/
penalty_matrix.rs

//! The `PenaltyMatrix` carrier (dense / Kronecker-factored / block-local, with
//! optional precision-label and fixed-log-precision wrappers) used by every
//! custom-family block, plus its constructors and the `Array2` conversion.
3
4use ndarray::{Array1, Array2, Axis};
5use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
6
7/// A penalty matrix that may be stored in Kronecker-factored form.
8///
9/// For tensor-product terms (e.g. time-varying survival covariates), the penalty
10/// has the structure `S = left ⊗ right` (Kronecker product). Keeping this
11/// factored avoids materializing (p_left × p_right)² dense entries and enables
12/// exact log-determinant computation via `log|A ⊗ B| = n_B log|A| + n_A log|B|`.
13///
14/// Dense penalties are stored as-is.  Callers that need a raw `Array2<f64>` can
15/// call `as_dense()` (zero-cost for Dense, lazy-materialized for KroneckerFactored).
16#[derive(Clone, Debug)]
17pub enum PenaltyMatrix {
18    Dense(Array2<f64>),
19    KroneckerFactored {
20        left: Array2<f64>,
21        right: Array2<f64>,
22    },
23    /// Block-local penalty: `local` is `block_dim × block_dim`, embedded at
24    /// `col_range` in the full parameter space of dimension `total_dim`.
25    /// Avoids materializing the full `total_dim × total_dim` matrix.
26    Blockwise {
27        local: Array2<f64>,
28        col_range: std::ops::Range<usize>,
29        total_dim: usize,
30    },
31    /// Wrapper assigning this penalty component to a user-visible precision
32    /// label. Components with the same label share one smoothing parameter.
33    Labeled {
34        label: String,
35        inner: Box<PenaltyMatrix>,
36    },
37    /// Wrapper fixing this penalty component at a physical log-precision.
38    /// Fixed components remain in the block-local physical penalty layout but
39    /// are removed from the REML outer coordinate vector.
40    Fixed {
41        log_lambda: f64,
42        inner: Box<PenaltyMatrix>,
43    },
44}
45
46impl PenaltyMatrix {
47    /// Number of rows (= number of columns, since penalties are square).
48    pub fn dim(&self) -> usize {
49        match self {
50            Self::Dense(m) => m.nrows(),
51            Self::KroneckerFactored { left, right } => left.nrows() * right.nrows(),
52            Self::Blockwise { total_dim, .. } => *total_dim,
53            Self::Labeled { inner, .. } | Self::Fixed { inner, .. } => inner.dim(),
54        }
55    }
56
57    /// Returns (nrows, ncols) like Array2::dim().
58    pub fn shape(&self) -> (usize, usize) {
59        let d = self.dim();
60        (d, d)
61    }
62
63    /// Materialize the full dense matrix.
64    pub fn to_dense(&self) -> Array2<f64> {
65        match self {
66            Self::Dense(m) => m.clone(),
67            Self::KroneckerFactored { left, right } => kronecker_product(left, right),
68            Self::Blockwise {
69                local,
70                col_range,
71                total_dim,
72            } => {
73                let mut g = Array2::zeros((*total_dim, *total_dim));
74                g.slice_mut(ndarray::s![
75                    col_range.start..col_range.end,
76                    col_range.start..col_range.end
77                ])
78                .assign(local);
79                g
80            }
81            Self::Labeled { inner, .. } | Self::Fixed { inner, .. } => inner.to_dense(),
82        }
83    }
84
85    /// Borrow the inner dense matrix if Dense, otherwise materialize.
86    pub fn as_dense_cow(&self) -> std::borrow::Cow<'_, Array2<f64>> {
87        match self {
88            Self::Dense(m) => std::borrow::Cow::Borrowed(m),
89            Self::KroneckerFactored { .. }
90            | Self::Blockwise { .. }
91            | Self::Labeled { .. }
92            | Self::Fixed { .. } => std::borrow::Cow::Owned(self.to_dense()),
93        }
94    }
95
96    /// Returns a reference to the inner matrix if this is a Dense variant.
97    pub fn as_dense_ref(&self) -> Option<&Array2<f64>> {
98        match self {
99            Self::Dense(m) => Some(m),
100            Self::Fixed { inner, .. } => inner.as_dense_ref(),
101            Self::KroneckerFactored { .. } | Self::Blockwise { .. } | Self::Labeled { .. } => None,
102        }
103    }
104
105    pub fn with_precision_label(self, label: impl Into<String>) -> Self {
106        Self::Labeled {
107            label: label.into(),
108            inner: Box::new(self),
109        }
110    }
111
112    pub fn precision_label(&self) -> Option<&str> {
113        match self {
114            Self::Labeled { label, .. } => Some(label.as_str()),
115            Self::Fixed { .. } => None,
116            _ => None,
117        }
118    }
119
120    pub fn with_fixed_log_lambda(self, log_lambda: f64) -> Self {
121        Self::Fixed {
122            log_lambda,
123            inner: Box::new(self),
124        }
125    }
126
127    pub fn fixed_log_lambda(&self) -> Option<f64> {
128        match self {
129            Self::Fixed { log_lambda, .. } => Some(*log_lambda),
130            Self::Labeled { inner, .. } => inner.fixed_log_lambda(),
131            _ => None,
132        }
133    }
134
135    /// Compute S * v using the row-major Kronecker vec trick when factored:
136    ///   (A ⊗ B) vec_rm(V) = vec_rm(A V Bᵀ)
137    /// where V = reshape(v, (p_left, p_right)).
138    pub fn dot(&self, v: &Array1<f64>) -> Array1<f64> {
139        match self {
140            Self::Dense(m) => m.dot(v),
141            Self::KroneckerFactored { left, right } => {
142                let p_left = left.nrows();
143                let p_right = right.nrows();
144                // v is ordered by i_left * p_right + i_right.
145                let v_mat =
146                    ndarray::ArrayView2::from_shape((p_left, p_right), v.as_slice().unwrap())
147                        .unwrap();
148                let avbt = left.dot(&v_mat).dot(&right.t());
149                let standard = avbt.as_standard_layout();
150                Array1::from_iter(standard.iter().copied())
151            }
152            Self::Blockwise {
153                local,
154                col_range,
155                total_dim,
156            } => {
157                let mut out = Array1::zeros(*total_dim);
158                let v_block = v.slice(ndarray::s![col_range.clone()]);
159                let result_block = local.dot(&v_block);
160                out.slice_mut(ndarray::s![col_range.clone()])
161                    .assign(&result_block);
162                out
163            }
164            Self::Labeled { inner, .. } | Self::Fixed { inner, .. } => inner.dot(v),
165        }
166    }
167
168    /// Add λ * self to a mutable dense accumulator.
169    pub fn add_scaled_to(&self, lambda: f64, target: &mut Array2<f64>) {
170        match self {
171            Self::Dense(m) => {
172                target.scaled_add(lambda, m);
173            }
174            Self::KroneckerFactored { left, right } => {
175                let p_left = left.nrows();
176                let p_right = right.nrows();
177                for i1 in 0..p_left {
178                    for j1 in 0..p_left {
179                        let a_ij = left[[i1, j1]];
180                        if a_ij == 0.0 {
181                            continue;
182                        }
183                        let scaled_a = lambda * a_ij;
184                        for i2 in 0..p_right {
185                            let row = i1 * p_right + i2;
186                            for j2 in 0..p_right {
187                                let col = j1 * p_right + j2;
188                                target[[row, col]] += scaled_a * right[[i2, j2]];
189                            }
190                        }
191                    }
192                }
193            }
194            Self::Blockwise {
195                local, col_range, ..
196            } => {
197                target
198                    .slice_mut(ndarray::s![col_range.clone(), col_range.clone()])
199                    .scaled_add(lambda, local);
200            }
201            Self::Labeled { inner, .. } | Self::Fixed { inner, .. } => {
202                inner.add_scaled_to(lambda, target)
203            }
204        }
205    }
206
207    /// Add λ * diag(self) to a mutable diagonal accumulator.
208    pub fn add_scaled_diag_to(&self, lambda: f64, target: &mut Array1<f64>) {
209        match self {
210            Self::Dense(m) => {
211                let p = m.nrows().min(target.len());
212                for j in 0..p {
213                    target[j] += lambda * m[[j, j]];
214                }
215            }
216            Self::KroneckerFactored { left, right } => {
217                let p_left = left.nrows();
218                let p_right = right.nrows();
219                assert_eq!(target.len(), p_left * p_right);
220                for i_left in 0..p_left {
221                    let left_diag = left[[i_left, i_left]];
222                    if left_diag == 0.0 {
223                        continue;
224                    }
225                    let scaled_left = lambda * left_diag;
226                    for i_right in 0..p_right {
227                        target[i_left * p_right + i_right] +=
228                            scaled_left * right[[i_right, i_right]];
229                    }
230                }
231            }
232            Self::Blockwise {
233                local, col_range, ..
234            } => {
235                let width = local.nrows().min(col_range.len());
236                for local_idx in 0..width {
237                    target[col_range.start + local_idx] += lambda * local[[local_idx, local_idx]];
238                }
239            }
240            Self::Labeled { inner, .. } | Self::Fixed { inner, .. } => {
241                inner.add_scaled_diag_to(lambda, target)
242            }
243        }
244    }
245
246    /// Compute the quadratic form β' S β.
247    pub fn quadratic_form(&self, beta: &Array1<f64>) -> f64 {
248        match self {
249            Self::Dense(m) => beta.dot(&m.dot(beta)),
250            Self::KroneckerFactored { .. } => {
251                let sv = self.dot(beta);
252                beta.dot(&sv)
253            }
254            Self::Blockwise {
255                local, col_range, ..
256            } => {
257                let beta_block = beta.slice(ndarray::s![col_range.clone()]);
258                let sv = local.dot(&beta_block);
259                beta_block.dot(&sv)
260            }
261            Self::Labeled { inner, .. } | Self::Fixed { inner, .. } => inner.quadratic_form(beta),
262        }
263    }
264
265    /// Access dimensions like an Array2.
266    pub fn nrows(&self) -> usize {
267        self.dim()
268    }
269
270    pub fn ncols(&self) -> usize {
271        self.dim()
272    }
273}
274
275impl From<Array2<f64>> for PenaltyMatrix {
276    fn from(m: Array2<f64>) -> Self {
277        Self::Dense(m)
278    }
279}
280
281/// Computes the Kronecker product A ⊗ B for penalty matrix construction.
282/// This is used to create tensor product penalties that enforce smoothness
283/// in multiple dimensions for interaction terms.
284fn kronecker_product(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
285    let (arows, a_cols) = a.dim();
286    let (brows, b_cols) = b.dim();
287    if arows == 0 || a_cols == 0 || brows == 0 || b_cols == 0 {
288        return Array2::zeros((arows * brows, a_cols * b_cols));
289    }
290    let mut result = Array2::zeros((arows * brows, a_cols * b_cols));
291
292    result
293        .axis_chunks_iter_mut(Axis(0), brows)
294        .into_par_iter()
295        .enumerate()
296        .for_each(|(i, mut row_block)| {
297            let arow = a.row(i);
298            let col_chunks = row_block.axis_chunks_iter_mut(Axis(1), b_cols);
299            for (j, mut block) in col_chunks.into_iter().enumerate() {
300                let aval = arow[j];
301                if aval == 0.0 {
302                    continue;
303                }
304                for (dest, &src) in block.iter_mut().zip(b.iter()) {
305                    *dest = aval * src;
306                }
307            }
308        });
309
310    result
311}