1use crate::linalg_helpers::{dense_matvec_into, dense_transpose_matvec_scaled_add_into};
6use crate::reml_contract_panic;
7use ndarray::{Array1, Array2, ArrayView1, ArrayViewMut1};
8
/// One quadratic penalty term `βᵀ S β` on a coefficient vector `β`, stored in a
/// factored form (a root `R` with `S = Rᵀ R`, or a per-margin eigenvalue vector)
/// so that the full `S` only has to be materialised on request
/// (`scaled_dense_matrix`).
#[derive(Clone, Debug)]
pub enum PenaltyCoordinate {
    /// `S = Rᵀ R` acting on the whole coefficient vector; the coefficient
    /// dimension is `R.ncols()` and the rank reported is `R.nrows()`.
    DenseRoot(Array2<f64>),
    /// Same as `DenseRoot`, plus a prior mean `m` (length `root.ncols()`).
    /// `apply_shifted_penalty` / `shifted_quadratic` evaluate the penalty at
    /// `β - m`; the plain `apply_*` / `quadratic` methods ignore `m`.
    DenseRootCentered {
        /// Root `R` of the penalty `Rᵀ R`.
        root: Array2<f64>,
        /// Prior mean `m` subtracted from `β` by the shifted methods.
        prior_mean: Array1<f64>,
    },
    /// `S = Rᵀ R` embedded on coordinates `start..end` of a vector of length
    /// `total_dim`; the penalty is zero outside that block.
    BlockRoot {
        /// Root acting on the block; `root.ncols() == end - start`.
        root: Array2<f64>,
        /// First coefficient index of the block (inclusive).
        start: usize,
        /// One past the last coefficient index of the block.
        end: usize,
        /// Length of the full coefficient vector.
        total_dim: usize,
    },
    /// Block-local root with a prior mean of length `end - start`; the shifted
    /// methods penalise `β[start..end] - prior_mean`.
    BlockRootCentered {
        /// Root acting on the block; `root.ncols() == end - start`.
        root: Array2<f64>,
        /// First coefficient index of the block (inclusive).
        start: usize,
        /// One past the last coefficient index of the block.
        end: usize,
        /// Length of the full coefficient vector.
        total_dim: usize,
        /// Prior mean for the block, subtracted by the shifted methods.
        prior_mean: Array1<f64>,
    },
    /// Diagonal penalty for one margin of a tensor-product basis: the
    /// coordinate whose index along margin `dim_index` is `j` is scaled by
    /// `eigenvalues[dim_index][j]`. Coordinates are laid out row-major over
    /// `marginal_dims` (the last margin varies fastest).
    ///
    /// NOTE(review): presumably `β` is expressed in the joint eigenbasis of the
    /// marginal penalties, so this diagonal equals `I ⊗ … ⊗ S_k ⊗ … ⊗ I` —
    /// confirm against the code that builds this variant.
    KroneckerMarginal {
        /// Per-margin eigenvalue vectors; only entry `dim_index` scales coordinates.
        eigenvalues: Vec<Array1<f64>>,
        /// Index of the margin this coordinate penalises.
        dim_index: usize,
        /// Size of each margin; their product is expected to equal `total_dim`
        /// (asserted when the penalty is applied or materialised).
        marginal_dims: Vec<usize>,
        /// Length of the full coefficient vector.
        total_dim: usize,
    },
}
55
56impl PenaltyCoordinate {
57 pub fn from_dense_root(root: Array2<f64>) -> Self {
58 Self::DenseRoot(root)
59 }
60
61 pub fn from_dense_root_with_mean(root: Array2<f64>, prior_mean: Array1<f64>) -> Self {
62 assert_eq!(root.ncols(), prior_mean.len());
63 if prior_mean.iter().all(|&value| value == 0.0) {
64 Self::DenseRoot(root)
65 } else {
66 Self::DenseRootCentered { root, prior_mean }
67 }
68 }
69
70 pub fn from_block_root(root: Array2<f64>, start: usize, end: usize, total_dim: usize) -> Self {
71 assert_eq!(
72 root.ncols(),
73 end.saturating_sub(start),
74 "block prior root column count must match block width"
75 );
76 assert!(
77 end <= total_dim,
78 "block prior root end exceeds total dimension: start={start}, end={end}, total_dim={total_dim}, root_dim={:?}",
79 root.dim()
80 );
81 Self::BlockRoot {
82 root,
83 start,
84 end,
85 total_dim,
86 }
87 }
88
89 pub fn from_block_root_with_mean(
90 root: Array2<f64>,
91 start: usize,
92 end: usize,
93 total_dim: usize,
94 prior_mean: Array1<f64>,
95 ) -> Self {
96 assert_eq!(
97 root.ncols(),
98 end.saturating_sub(start),
99 "centered block prior root column count must match block width"
100 );
101 assert_eq!(
102 prior_mean.len(),
103 end.saturating_sub(start),
104 "centered block prior mean length must match block width"
105 );
106 assert!(
107 end <= total_dim,
108 "centered block prior root end exceeds total dimension: start={start}, end={end}, total_dim={total_dim}, root_dim={:?}, prior_mean_len={}",
109 root.dim(),
110 prior_mean.len()
111 );
112 if prior_mean.iter().all(|&value| value == 0.0) {
113 Self::from_block_root(root, start, end, total_dim)
114 } else {
115 Self::BlockRootCentered {
116 root,
117 start,
118 end,
119 total_dim,
120 prior_mean,
121 }
122 }
123 }
124
125 pub fn rank(&self) -> usize {
126 match self {
127 Self::DenseRoot(root)
128 | Self::DenseRootCentered { root, .. }
129 | Self::BlockRoot { root, .. }
130 | Self::BlockRootCentered { root, .. } => root.nrows(),
131 Self::KroneckerMarginal {
132 eigenvalues,
133 dim_index,
134 ..
135 } => {
136 let nz = eigenvalues[*dim_index]
139 .iter()
140 .filter(|&&v| v.abs() > 1e-12)
141 .count();
142 let other: usize = eigenvalues
143 .iter()
144 .enumerate()
145 .filter(|&(j, _)| j != *dim_index)
146 .map(|(_, e)| e.len())
147 .product::<usize>()
148 .max(1);
149 nz * other
150 }
151 }
152 }
153
154 pub fn dim(&self) -> usize {
155 match self {
156 Self::DenseRoot(root) | Self::DenseRootCentered { root, .. } => root.ncols(),
157 Self::BlockRoot { total_dim, .. }
158 | Self::BlockRootCentered { total_dim, .. }
159 | Self::KroneckerMarginal { total_dim, .. } => *total_dim,
160 }
161 }
162
163 pub fn uses_operator_fast_path(&self) -> bool {
164 matches!(
165 self,
166 Self::BlockRoot { .. }
167 | Self::BlockRootCentered { .. }
168 | Self::KroneckerMarginal { .. }
169 )
170 }
171
172 pub fn project_into_subspace(&self, z: &Array2<f64>) -> Self {
189 assert_eq!(
190 z.nrows(),
191 self.dim(),
192 "PenaltyCoordinate::project_into_subspace: free-basis row count {} does not match coordinate dimension {}",
193 z.nrows(),
194 self.dim()
195 );
196 match self {
197 Self::DenseRoot(root) => Self::DenseRoot(root.dot(z)),
198 Self::DenseRootCentered { root, prior_mean } => {
199 Self::from_dense_root_with_mean(root.dot(z), z.t().dot(prior_mean))
200 }
201 Self::BlockRoot {
202 root, start, end, ..
203 } => {
204 let z_block = z.slice(ndarray::s![*start..*end, ..]);
205 Self::DenseRoot(root.dot(&z_block))
206 }
207 Self::BlockRootCentered {
208 root,
209 start,
210 end,
211 prior_mean,
212 ..
213 } => {
214 let z_block = z.slice(ndarray::s![*start..*end, ..]);
215 let z_block_owned = z_block.to_owned();
219 Self::from_dense_root_with_mean(
220 root.dot(&z_block_owned),
221 z_block_owned.t().dot(prior_mean),
222 )
223 }
224 Self::KroneckerMarginal { .. } => reml_contract_panic(
225 "PenaltyCoordinate::project_into_subspace: Kronecker-factored \
226 coordinates do not co-occur with linear-inequality active sets \
227 (box/monotone constraints lower to dense/block roots)",
228 ),
229 }
230 }
231
232 pub(crate) fn apply_root(&self, beta: &Array1<f64>) -> Array1<f64> {
233 assert_eq!(beta.len(), self.dim());
234 match self {
235 Self::DenseRoot(root) | Self::DenseRootCentered { root, .. } => root.dot(beta),
236 Self::BlockRoot {
237 root, start, end, ..
238 }
239 | Self::BlockRootCentered {
240 root, start, end, ..
241 } => root.dot(&beta.slice(ndarray::s![*start..*end])),
242 Self::KroneckerMarginal { .. } => {
243 reml_contract_panic(
251 "apply_root not supported for KroneckerMarginal; use apply_penalty directly",
252 );
253 }
254 }
255 }
256
257 pub fn apply_penalty(&self, beta: &Array1<f64>, scale: f64) -> Array1<f64> {
258 assert_eq!(beta.len(), self.dim());
259 let mut out = Array1::<f64>::zeros(self.dim());
260 self.apply_penalty_view_into(beta.view(), scale, out.view_mut());
261 out
262 }
263
264 pub fn apply_penalty_view_into(
265 &self,
266 beta: ArrayView1<'_, f64>,
267 scale: f64,
268 mut out: ArrayViewMut1<'_, f64>,
269 ) {
270 assert_eq!(beta.len(), self.dim());
271 assert_eq!(out.len(), self.dim());
272 out.fill(0.0);
273 self.scaled_add_penalty_view(beta, scale, out);
274 }
275
276 pub fn scaled_add_penalty_view(
277 &self,
278 beta: ArrayView1<'_, f64>,
279 scale: f64,
280 mut out: ArrayViewMut1<'_, f64>,
281 ) {
282 assert_eq!(beta.len(), self.dim());
283 assert_eq!(out.len(), self.dim());
284 if scale == 0.0 {
285 return;
286 }
287 match self {
288 Self::DenseRoot(_)
289 | Self::DenseRootCentered { .. }
290 | Self::BlockRoot { .. }
291 | Self::BlockRootCentered { .. } => match self {
292 Self::DenseRoot(root) | Self::DenseRootCentered { root, .. } => {
293 let mut root_beta = Array1::<f64>::zeros(root.nrows());
294 dense_matvec_into(root, beta, root_beta.view_mut());
295 dense_transpose_matvec_scaled_add_into(
296 root,
297 root_beta.view(),
298 scale,
299 out.view_mut(),
300 );
301 }
302 Self::BlockRoot {
303 root,
304 start,
305 end,
306 total_dim: _,
307 }
308 | Self::BlockRootCentered {
309 root,
310 start,
311 end,
312 total_dim: _,
313 ..
314 } => {
315 let beta_block = beta.slice(ndarray::s![*start..*end]);
316 let mut root_beta = Array1::<f64>::zeros(root.nrows());
317 dense_matvec_into(root, beta_block, root_beta.view_mut());
318 let out_block = out.slice_mut(ndarray::s![*start..*end]);
319 dense_transpose_matvec_scaled_add_into(
320 root,
321 root_beta.view(),
322 scale,
323 out_block,
324 );
325 }
326 Self::KroneckerMarginal { .. } => {}
328 },
329 Self::KroneckerMarginal {
330 eigenvalues,
331 dim_index,
332 marginal_dims,
333 total_dim,
334 } => {
335 let k = *dim_index;
338 let q_k = marginal_dims[k];
339 let stride_k: usize = marginal_dims[k + 1..]
340 .iter()
341 .copied()
342 .product::<usize>()
343 .max(1);
344 let outer_size: usize =
345 marginal_dims[..k].iter().copied().product::<usize>().max(1);
346 let inner_size = stride_k;
347 let eigs = &eigenvalues[k];
348 assert_eq!(
349 outer_size * q_k * stride_k,
350 *total_dim,
351 "KroneckerMarginal dimension mismatch in apply"
352 );
353
354 for outer in 0..outer_size {
355 for j in 0..q_k {
356 let mu = eigs[j] * scale;
357 if mu == 0.0 {
358 continue;
359 }
360 let base = outer * q_k * stride_k + j * stride_k;
361 for inner in 0..inner_size {
362 let idx = base + inner;
363 out[idx] += mu * beta[idx];
364 }
365 }
366 }
367 }
368 }
369 }
370
371 pub fn quadratic(&self, beta: &Array1<f64>, scale: f64) -> f64 {
372 match self {
373 Self::DenseRoot(_)
374 | Self::DenseRootCentered { .. }
375 | Self::BlockRoot { .. }
376 | Self::BlockRootCentered { .. } => {
377 let root_beta = self.apply_root(beta);
378 scale * root_beta.dot(&root_beta)
379 }
380 Self::KroneckerMarginal {
381 eigenvalues,
382 dim_index,
383 marginal_dims,
384 ..
385 } => {
386 let k = *dim_index;
388 let q_k = marginal_dims[k];
389 let stride_k: usize = marginal_dims[k + 1..]
390 .iter()
391 .copied()
392 .product::<usize>()
393 .max(1);
394 let outer_size: usize =
395 marginal_dims[..k].iter().copied().product::<usize>().max(1);
396 let inner_size = stride_k;
397 let eigs = &eigenvalues[k];
398
399 let mut sum = 0.0;
400 for outer in 0..outer_size {
401 for j in 0..q_k {
402 let mu = eigs[j];
403 if mu == 0.0 {
404 continue;
405 }
406 let base = outer * q_k * stride_k + j * stride_k;
407 for inner in 0..inner_size {
408 let v = beta[base + inner];
409 sum += mu * v * v;
410 }
411 }
412 }
413 sum * scale
414 }
415 }
416 }
417
418 pub fn apply_shifted_penalty(&self, beta: &Array1<f64>, scale: f64) -> Array1<f64> {
419 match self {
420 Self::DenseRootCentered { root, prior_mean } => {
421 let centered = beta - prior_mean;
422 let root_beta = root.dot(¢ered);
423 let mut out = root.t().dot(&root_beta);
424 out *= scale;
425 out
426 }
427 Self::BlockRootCentered {
428 root,
429 start,
430 end,
431 total_dim,
432 prior_mean,
433 } => {
434 let mut out = Array1::<f64>::zeros(*total_dim);
435 let beta_block = beta.slice(ndarray::s![*start..*end]);
436 let centered = beta_block.to_owned() - prior_mean;
437 let root_beta = root.dot(¢ered);
438 let mut block = root.t().dot(&root_beta);
439 block *= scale;
440 out.slice_mut(ndarray::s![*start..*end]).assign(&block);
441 out
442 }
443 _ => self.apply_penalty(beta, scale),
444 }
445 }
446
447 pub fn shifted_quadratic(&self, beta: &Array1<f64>, scale: f64) -> f64 {
448 match self {
449 Self::DenseRootCentered { root, prior_mean } => {
450 let centered = beta - prior_mean;
451 let root_beta = root.dot(¢ered);
452 scale * root_beta.dot(&root_beta)
453 }
454 Self::BlockRootCentered {
455 root,
456 start,
457 end,
458 prior_mean,
459 ..
460 } => {
461 let beta_block = beta.slice(ndarray::s![*start..*end]);
462 let centered = beta_block.to_owned() - prior_mean;
463 let root_beta = root.dot(¢ered);
464 scale * root_beta.dot(&root_beta)
465 }
466 _ => self.quadratic(beta, scale),
467 }
468 }
469
470 pub fn scaled_dense_matrix(&self, scale: f64) -> Array2<f64> {
471 match self {
472 Self::DenseRoot(root) | Self::DenseRootCentered { root, .. } => {
473 let mut out = root.t().dot(root);
474 out *= scale;
475 out
476 }
477 Self::BlockRoot {
478 root,
479 start,
480 end,
481 total_dim,
482 }
483 | Self::BlockRootCentered {
484 root,
485 start,
486 end,
487 total_dim,
488 ..
489 } => {
490 let mut out = Array2::<f64>::zeros((*total_dim, *total_dim));
491 let mut block = root.t().dot(root);
492 block *= scale;
493 out.slice_mut(ndarray::s![*start..*end, *start..*end])
494 .assign(&block);
495 out
496 }
497 Self::KroneckerMarginal {
498 eigenvalues,
499 dim_index,
500 marginal_dims,
501 total_dim,
502 } => {
503 let k = *dim_index;
505 let q_k = marginal_dims[k];
506 let stride_k: usize = marginal_dims[k + 1..]
507 .iter()
508 .copied()
509 .product::<usize>()
510 .max(1);
511 let outer_size: usize =
512 marginal_dims[..k].iter().copied().product::<usize>().max(1);
513 let eigs = &eigenvalues[k];
514 assert_eq!(
515 outer_size * q_k * stride_k,
516 *total_dim,
517 "KroneckerMarginal dimension mismatch in to_dense"
518 );
519
520 let mut out = Array2::<f64>::zeros((*total_dim, *total_dim));
521 for outer in 0..outer_size {
522 for j in 0..q_k {
523 let mu = eigs[j] * scale;
524 let base = outer * q_k * stride_k + j * stride_k;
525 for inner in 0..stride_k {
526 let idx = base + inner;
527 out[[idx, idx]] = mu;
528 }
529 }
530 }
531 out
532 }
533 }
534 }
535
536 pub fn scaled_block_local(&self, scale: f64) -> (Array2<f64>, usize, usize) {
540 match self {
541 Self::DenseRoot(root) | Self::DenseRootCentered { root, .. } => {
542 let mut out = root.t().dot(root);
543 out *= scale;
544 let p = out.nrows();
545 (out, 0, p)
546 }
547 Self::BlockRoot {
548 root, start, end, ..
549 }
550 | Self::BlockRootCentered {
551 root, start, end, ..
552 } => {
553 let mut block = root.t().dot(root);
554 block *= scale;
555 (block, *start, *end)
556 }
557 Self::KroneckerMarginal { total_dim, .. } => {
558 let mat = self.scaled_dense_matrix(scale);
560 (mat, 0, *total_dim)
561 }
562 }
563 }
564
565 pub fn is_block_local(&self) -> bool {
567 matches!(
568 self,
569 Self::BlockRoot { .. }
570 | Self::BlockRootCentered { .. }
571 | Self::KroneckerMarginal { .. }
572 )
573 }
574
575 pub fn scaled_matvec(&self, v: &Array1<f64>, scale: f64) -> Array1<f64> {
578 match self {
579 Self::DenseRoot(root) | Self::DenseRootCentered { root, .. } => {
580 let root_v = root.dot(v);
581 let mut out = root.t().dot(&root_v);
582 out *= scale;
583 out
584 }
585 Self::BlockRoot {
586 root, start, end, ..
587 }
588 | Self::BlockRootCentered {
589 root, start, end, ..
590 } => {
591 let mut out = Array1::zeros(v.len());
592 let v_block = v.slice(ndarray::s![*start..*end]);
593 let root_v = root.dot(&v_block);
594 let mut block_result = root.t().dot(&root_v);
595 block_result *= scale;
596 out.slice_mut(ndarray::s![*start..*end])
597 .assign(&block_result);
598 out
599 }
600 Self::KroneckerMarginal { .. } => {
601 self.apply_penalty(v, scale)
603 }
604 }
605 }
606
607 pub fn canonical_structural_key(&self) -> u64 {
629 use std::hash::{Hash, Hasher};
630 let mut hasher = std::collections::hash_map::DefaultHasher::new();
631
632 let quant = |v: f64| -> i64 {
636 if !v.is_finite() || v.abs() <= 1e-300 {
637 return 0;
638 }
639 let q = (v.abs().ln() * 1.0e6).round() as i64;
641 if v < 0.0 { -q } else { q }
642 };
643
644 match self {
645 Self::DenseRoot(root)
646 | Self::DenseRootCentered { root, .. }
647 | Self::BlockRoot { root, .. }
648 | Self::BlockRootCentered { root, .. } => {
649 0u8.hash(&mut hasher);
654 root.nrows().hash(&mut hasher); root.ncols().hash(&mut hasher); let sk = root.t().dot(root);
657 let n = sk.nrows().min(sk.ncols());
666 let trace1 = (0..n).map(|i| sk[[i, i]]).sum::<f64>();
667 let frob_sq = sk.iter().map(|&x| x * x).sum::<f64>(); let sk2 = sk.dot(&sk);
669 let trace3 = {
670 let sk3diag = sk2.dot(&sk);
671 (0..n).map(|i| sk3diag[[i, i]]).sum::<f64>()
672 };
673 let mut invariants = [quant(trace1), quant(frob_sq), quant(trace3)];
674 invariants.sort_unstable();
677 invariants.hash(&mut hasher);
678 }
679 Self::KroneckerMarginal {
680 eigenvalues,
681 dim_index,
682 marginal_dims,
683 ..
684 } => {
685 1u8.hash(&mut hasher);
690 let mut margin_spectrum: Vec<i64> =
691 eigenvalues[*dim_index].iter().map(|&e| quant(e)).collect();
692 margin_spectrum.sort_unstable();
693 margin_spectrum.hash(&mut hasher);
694 let mut dims_sorted = marginal_dims.clone();
695 dims_sorted.sort_unstable();
696 dims_sorted.hash(&mut hasher);
697 }
698 }
699
700 hasher.finish()
701 }
702}