1use ndarray::{Array1, Array2, Axis};
5use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
6
7#[derive(Clone, Debug)]
17pub enum PenaltyMatrix {
18 Dense(Array2<f64>),
19 KroneckerFactored {
20 left: Array2<f64>,
21 right: Array2<f64>,
22 },
23 Blockwise {
27 local: Array2<f64>,
28 col_range: std::ops::Range<usize>,
29 total_dim: usize,
30 },
31 Labeled {
34 label: String,
35 inner: Box<PenaltyMatrix>,
36 },
37 Fixed {
41 log_lambda: f64,
42 inner: Box<PenaltyMatrix>,
43 },
44}
45
46impl PenaltyMatrix {
47 pub fn dim(&self) -> usize {
49 match self {
50 Self::Dense(m) => m.nrows(),
51 Self::KroneckerFactored { left, right } => left.nrows() * right.nrows(),
52 Self::Blockwise { total_dim, .. } => *total_dim,
53 Self::Labeled { inner, .. } | Self::Fixed { inner, .. } => inner.dim(),
54 }
55 }
56
57 pub fn shape(&self) -> (usize, usize) {
59 let d = self.dim();
60 (d, d)
61 }
62
63 pub fn to_dense(&self) -> Array2<f64> {
65 match self {
66 Self::Dense(m) => m.clone(),
67 Self::KroneckerFactored { left, right } => kronecker_product(left, right),
68 Self::Blockwise {
69 local,
70 col_range,
71 total_dim,
72 } => {
73 let mut g = Array2::zeros((*total_dim, *total_dim));
74 g.slice_mut(ndarray::s![
75 col_range.start..col_range.end,
76 col_range.start..col_range.end
77 ])
78 .assign(local);
79 g
80 }
81 Self::Labeled { inner, .. } | Self::Fixed { inner, .. } => inner.to_dense(),
82 }
83 }
84
85 pub fn as_dense_cow(&self) -> std::borrow::Cow<'_, Array2<f64>> {
87 match self {
88 Self::Dense(m) => std::borrow::Cow::Borrowed(m),
89 Self::KroneckerFactored { .. }
90 | Self::Blockwise { .. }
91 | Self::Labeled { .. }
92 | Self::Fixed { .. } => std::borrow::Cow::Owned(self.to_dense()),
93 }
94 }
95
96 pub fn as_dense_ref(&self) -> Option<&Array2<f64>> {
98 match self {
99 Self::Dense(m) => Some(m),
100 Self::Fixed { inner, .. } => inner.as_dense_ref(),
101 Self::KroneckerFactored { .. } | Self::Blockwise { .. } | Self::Labeled { .. } => None,
102 }
103 }
104
105 pub fn with_precision_label(self, label: impl Into<String>) -> Self {
106 Self::Labeled {
107 label: label.into(),
108 inner: Box::new(self),
109 }
110 }
111
112 pub fn precision_label(&self) -> Option<&str> {
113 match self {
114 Self::Labeled { label, .. } => Some(label.as_str()),
115 Self::Fixed { .. } => None,
116 _ => None,
117 }
118 }
119
120 pub fn with_fixed_log_lambda(self, log_lambda: f64) -> Self {
121 Self::Fixed {
122 log_lambda,
123 inner: Box::new(self),
124 }
125 }
126
127 pub fn fixed_log_lambda(&self) -> Option<f64> {
128 match self {
129 Self::Fixed { log_lambda, .. } => Some(*log_lambda),
130 Self::Labeled { inner, .. } => inner.fixed_log_lambda(),
131 _ => None,
132 }
133 }
134
135 pub fn dot(&self, v: &Array1<f64>) -> Array1<f64> {
139 match self {
140 Self::Dense(m) => m.dot(v),
141 Self::KroneckerFactored { left, right } => {
142 let p_left = left.nrows();
143 let p_right = right.nrows();
144 let v_mat =
146 ndarray::ArrayView2::from_shape((p_left, p_right), v.as_slice().unwrap())
147 .unwrap();
148 let avbt = left.dot(&v_mat).dot(&right.t());
149 let standard = avbt.as_standard_layout();
150 Array1::from_iter(standard.iter().copied())
151 }
152 Self::Blockwise {
153 local,
154 col_range,
155 total_dim,
156 } => {
157 let mut out = Array1::zeros(*total_dim);
158 let v_block = v.slice(ndarray::s![col_range.clone()]);
159 let result_block = local.dot(&v_block);
160 out.slice_mut(ndarray::s![col_range.clone()])
161 .assign(&result_block);
162 out
163 }
164 Self::Labeled { inner, .. } | Self::Fixed { inner, .. } => inner.dot(v),
165 }
166 }
167
168 pub fn add_scaled_to(&self, lambda: f64, target: &mut Array2<f64>) {
170 match self {
171 Self::Dense(m) => {
172 target.scaled_add(lambda, m);
173 }
174 Self::KroneckerFactored { left, right } => {
175 let p_left = left.nrows();
176 let p_right = right.nrows();
177 for i1 in 0..p_left {
178 for j1 in 0..p_left {
179 let a_ij = left[[i1, j1]];
180 if a_ij == 0.0 {
181 continue;
182 }
183 let scaled_a = lambda * a_ij;
184 for i2 in 0..p_right {
185 let row = i1 * p_right + i2;
186 for j2 in 0..p_right {
187 let col = j1 * p_right + j2;
188 target[[row, col]] += scaled_a * right[[i2, j2]];
189 }
190 }
191 }
192 }
193 }
194 Self::Blockwise {
195 local, col_range, ..
196 } => {
197 target
198 .slice_mut(ndarray::s![col_range.clone(), col_range.clone()])
199 .scaled_add(lambda, local);
200 }
201 Self::Labeled { inner, .. } | Self::Fixed { inner, .. } => {
202 inner.add_scaled_to(lambda, target)
203 }
204 }
205 }
206
207 pub fn add_scaled_diag_to(&self, lambda: f64, target: &mut Array1<f64>) {
209 match self {
210 Self::Dense(m) => {
211 let p = m.nrows().min(target.len());
212 for j in 0..p {
213 target[j] += lambda * m[[j, j]];
214 }
215 }
216 Self::KroneckerFactored { left, right } => {
217 let p_left = left.nrows();
218 let p_right = right.nrows();
219 assert_eq!(target.len(), p_left * p_right);
220 for i_left in 0..p_left {
221 let left_diag = left[[i_left, i_left]];
222 if left_diag == 0.0 {
223 continue;
224 }
225 let scaled_left = lambda * left_diag;
226 for i_right in 0..p_right {
227 target[i_left * p_right + i_right] +=
228 scaled_left * right[[i_right, i_right]];
229 }
230 }
231 }
232 Self::Blockwise {
233 local, col_range, ..
234 } => {
235 let width = local.nrows().min(col_range.len());
236 for local_idx in 0..width {
237 target[col_range.start + local_idx] += lambda * local[[local_idx, local_idx]];
238 }
239 }
240 Self::Labeled { inner, .. } | Self::Fixed { inner, .. } => {
241 inner.add_scaled_diag_to(lambda, target)
242 }
243 }
244 }
245
246 pub fn quadratic_form(&self, beta: &Array1<f64>) -> f64 {
248 match self {
249 Self::Dense(m) => beta.dot(&m.dot(beta)),
250 Self::KroneckerFactored { .. } => {
251 let sv = self.dot(beta);
252 beta.dot(&sv)
253 }
254 Self::Blockwise {
255 local, col_range, ..
256 } => {
257 let beta_block = beta.slice(ndarray::s![col_range.clone()]);
258 let sv = local.dot(&beta_block);
259 beta_block.dot(&sv)
260 }
261 Self::Labeled { inner, .. } | Self::Fixed { inner, .. } => inner.quadratic_form(beta),
262 }
263 }
264
265 pub fn nrows(&self) -> usize {
267 self.dim()
268 }
269
270 pub fn ncols(&self) -> usize {
271 self.dim()
272 }
273}
274
275impl From<Array2<f64>> for PenaltyMatrix {
276 fn from(m: Array2<f64>) -> Self {
277 Self::Dense(m)
278 }
279}
280
281fn kronecker_product(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
285 let (arows, a_cols) = a.dim();
286 let (brows, b_cols) = b.dim();
287 if arows == 0 || a_cols == 0 || brows == 0 || b_cols == 0 {
288 return Array2::zeros((arows * brows, a_cols * b_cols));
289 }
290 let mut result = Array2::zeros((arows * brows, a_cols * b_cols));
291
292 result
293 .axis_chunks_iter_mut(Axis(0), brows)
294 .into_par_iter()
295 .enumerate()
296 .for_each(|(i, mut row_block)| {
297 let arow = a.row(i);
298 let col_chunks = row_block.axis_chunks_iter_mut(Axis(1), b_cols);
299 for (j, mut block) in col_chunks.into_iter().enumerate() {
300 let aval = arow[j];
301 if aval == 0.0 {
302 continue;
303 }
304 for (dest, &src) in block.iter_mut().zip(b.iter()) {
305 *dest = aval * src;
306 }
307 }
308 });
309
310 result
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn dense_dim_and_shape() {
        let m = array![[1.0, 0.0], [0.0, 2.0]];
        let p = PenaltyMatrix::Dense(m);
        assert_eq!(p.dim(), 2);
        assert_eq!(p.shape(), (2, 2));
        assert_eq!(p.nrows(), 2);
        assert_eq!(p.ncols(), 2);
    }

    #[test]
    fn dense_to_dense_is_clone() {
        let m = array![[3.0, 1.0], [1.0, 4.0]];
        let p = PenaltyMatrix::Dense(m.clone());
        assert_eq!(p.to_dense(), m);
    }

    #[test]
    fn dense_dot_product() {
        let m = array![[1.0, 0.0], [0.0, 2.0]];
        let p = PenaltyMatrix::Dense(m);
        let v = ndarray::array![3.0, 5.0];
        let result = p.dot(&v);
        assert_eq!(result.as_slice().unwrap(), &[3.0, 10.0]);
    }

    #[test]
    fn dense_quadratic_form() {
        let m = array![[1.0, 0.0], [0.0, 2.0]];
        let p = PenaltyMatrix::Dense(m);
        let beta = ndarray::array![3.0, 2.0];
        assert!((p.quadratic_form(&beta) - 17.0).abs() < 1e-14);
    }

    #[test]
    fn dense_add_scaled_to() {
        let s = array![[1.0, 0.0], [0.0, 1.0]];
        let p = PenaltyMatrix::Dense(s);
        let mut acc = ndarray::Array2::<f64>::zeros((2, 2));
        p.add_scaled_to(3.0, &mut acc);
        assert_eq!(acc, array![[3.0, 0.0], [0.0, 3.0]]);
    }

    #[test]
    fn dense_add_scaled_diag_to() {
        let s = array![[2.0, 5.0], [5.0, 7.0]];
        let p = PenaltyMatrix::Dense(s);
        let mut diag = ndarray::array![0.0, 0.0];
        p.add_scaled_diag_to(1.0, &mut diag);
        assert_eq!(diag.as_slice().unwrap(), &[2.0, 7.0]);
    }

    #[test]
    fn kronecker_dim_is_product() {
        let left = array![[1.0, 0.0], [0.0, 1.0]];
        let right = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
        let p = PenaltyMatrix::KroneckerFactored { left, right };
        assert_eq!(p.dim(), 6);
    }

    #[test]
    fn kronecker_to_dense_identity_x_identity() {
        let eye2 = ndarray::Array2::<f64>::eye(2);
        let p = PenaltyMatrix::KroneckerFactored {
            left: eye2.clone(),
            right: eye2,
        };
        let dense = p.to_dense();
        assert_eq!(dense, ndarray::Array2::<f64>::eye(4));
    }

    #[test]
    fn kronecker_dot_matches_dense_dot() {
        let left = array![[2.0, 0.0], [0.0, 3.0]];
        let right = array![[1.0, 1.0], [0.0, 1.0]];
        let p = PenaltyMatrix::KroneckerFactored {
            left: left.clone(),
            right: right.clone(),
        };
        let dense = p.to_dense();
        let v = ndarray::array![1.0, 2.0, 3.0, 4.0];
        let got = p.dot(&v);
        let expected = dense.dot(&v);
        for (a, b) in got.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-14, "got={a} expected={b}");
        }
    }

    /// Shared non-trivial Kronecker penalty (2x2 ⊗ 3x3) for the dense-equivalence tests.
    fn sample_kronecker() -> PenaltyMatrix {
        PenaltyMatrix::KroneckerFactored {
            left: array![[2.0, -1.0], [-1.0, 2.0]],
            right: array![[1.0, 0.5, 0.0], [0.5, 1.0, 0.5], [0.0, 0.5, 1.0]],
        }
    }

    #[test]
    fn kronecker_add_scaled_to_matches_dense() {
        let p = sample_kronecker();
        let mut got = ndarray::Array2::<f64>::zeros((6, 6));
        p.add_scaled_to(1.5, &mut got);
        let expected = p.to_dense() * 1.5;
        for (a, b) in got.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-12, "got={a} expected={b}");
        }
    }

    #[test]
    fn kronecker_add_scaled_diag_to_matches_dense() {
        let p = sample_kronecker();
        let mut diag = ndarray::Array1::<f64>::zeros(6);
        p.add_scaled_diag_to(2.0, &mut diag);
        let dense = p.to_dense();
        for j in 0..6 {
            assert!((diag[j] - 2.0 * dense[[j, j]]).abs() < 1e-12);
        }
    }

    #[test]
    fn kronecker_quadratic_form_matches_dense() {
        let p = sample_kronecker();
        let beta = ndarray::array![1.0, -2.0, 0.5, 3.0, -1.0, 2.0];
        let dense = p.to_dense();
        let expected = beta.dot(&dense.dot(&beta));
        assert!((p.quadratic_form(&beta) - expected).abs() < 1e-10);
    }

    #[test]
    fn blockwise_dim_is_total() {
        let local = array![[1.0, 0.0], [0.0, 1.0]];
        let p = PenaltyMatrix::Blockwise {
            local,
            col_range: 1..3,
            total_dim: 5,
        };
        assert_eq!(p.dim(), 5);
    }

    #[test]
    fn blockwise_to_dense_embeds_local_block() {
        let local = array![[2.0, 1.0], [1.0, 3.0]];
        let p = PenaltyMatrix::Blockwise {
            local,
            col_range: 1..3,
            total_dim: 3,
        };
        let dense = p.to_dense();
        assert_eq!(dense[[0, 0]], 0.0);
        assert_eq!(dense[[1, 1]], 2.0);
        assert_eq!(dense[[1, 2]], 1.0);
        assert_eq!(dense[[2, 1]], 1.0);
        assert_eq!(dense[[2, 2]], 3.0);
    }

    #[test]
    fn blockwise_dot_only_touches_block() {
        let local = array![[2.0, 0.0], [0.0, 3.0]];
        let p = PenaltyMatrix::Blockwise {
            local,
            col_range: 1..3,
            total_dim: 4,
        };
        let v = ndarray::array![7.0, 1.0, 2.0, 9.0];
        let out = p.dot(&v);
        assert_eq!(out.as_slice().unwrap(), &[0.0, 2.0, 6.0, 0.0]);
    }

    #[test]
    fn blockwise_add_scaled_to_matches_dense() {
        let p = PenaltyMatrix::Blockwise {
            local: array![[2.0, 1.0], [1.0, 3.0]],
            col_range: 1..3,
            total_dim: 4,
        };
        let mut acc = ndarray::Array2::<f64>::zeros((4, 4));
        p.add_scaled_to(2.0, &mut acc);
        assert_eq!(acc, p.to_dense() * 2.0);
    }

    #[test]
    fn blockwise_quadratic_form_reads_only_block() {
        let p = PenaltyMatrix::Blockwise {
            local: array![[2.0, 1.0], [1.0, 3.0]],
            col_range: 1..3,
            total_dim: 4,
        };
        // Block of beta is [1, 2]: local·[1, 2] = [4, 7]; [1, 2]·[4, 7] = 18.
        let beta = ndarray::array![7.0, 1.0, 2.0, 9.0];
        assert!((p.quadratic_form(&beta) - 18.0).abs() < 1e-14);
    }

    #[test]
    fn labeled_inherits_dim_and_delegates_dot() {
        let m = array![[1.0, 0.0], [0.0, 2.0]];
        let p = PenaltyMatrix::Dense(m).with_precision_label("smooth");
        assert_eq!(p.dim(), 2);
        assert_eq!(p.precision_label(), Some("smooth"));
        let v = ndarray::array![3.0, 4.0];
        let out = p.dot(&v);
        assert_eq!(out.as_slice().unwrap(), &[3.0, 8.0]);
    }

    #[test]
    fn fixed_inherits_dim_and_exposes_log_lambda() {
        let m = array![[5.0, 0.0], [0.0, 5.0]];
        let p = PenaltyMatrix::Dense(m).with_fixed_log_lambda(2.5);
        assert_eq!(p.dim(), 2);
        assert_eq!(p.fixed_log_lambda(), Some(2.5));
    }

    #[test]
    fn fixed_log_lambda_is_found_through_label() {
        let m = array![[5.0, 0.0], [0.0, 5.0]];
        let p = PenaltyMatrix::Dense(m)
            .with_fixed_log_lambda(-1.0)
            .with_precision_label("outer");
        assert_eq!(p.fixed_log_lambda(), Some(-1.0));
        assert_eq!(p.precision_label(), Some("outer"));
    }
}