1use super::*;
2
/// Prior on a latent matrix in which every row carries its own fixed precision matrix.
///
/// The target slice is read row-major as an `(n_eff, latent_dim)` matrix `T` with rows `t_n`.
/// The penalty value is `0.5 · w · Σ_n t_nᵀ Λ_n t_n − 0.5 · len(target) · ln w`, where
/// `Λ_n = lambda_per_row[n]` and `w` is the (optionally learnable) weight.
#[derive(Debug, Clone)]
pub struct RowPrecisionPriorPenalty {
    /// One precision matrix per row, shape `(n_eff, latent_dim, latent_dim)`; `new` requires
    /// every slice to be finite, symmetric to within `1e-10`, and positive definite.
    pub lambda_per_row: Array3<f64>,
    /// Base weight `w`; `new` requires it to be finite and strictly positive.
    pub weight: f64,
    /// Number of rows the target slice is split into.
    pub n_eff: usize,
    /// When true, the effective weight is `resolve_learnable_weight(weight, rho[rho_index])`.
    pub learnable_weight: bool,
    /// Position in `rho` read for the learnable weight; `new` sets it to `0`.
    pub rho_index: usize,
    /// Slice this penalty acts on; its length must be a multiple of `n_eff`.
    pub target: PsiSlice,
    /// Optional weight schedule; `new` initialises it to `None`.
    pub weight_schedule: Option<ScalarWeightSchedule>,
}
27
28impl RowPrecisionPriorPenalty {
29 #[must_use = "build error must be handled"]
30 pub fn new(
31 target: PsiSlice,
32 lambda_per_row: Array3<f64>,
33 weight: f64,
34 n_eff: usize,
35 learnable_weight: bool,
36 ) -> Result<Self, String> {
37 if target.is_empty() {
38 return Err("RowPrecisionPriorPenalty::new requires a non-empty target".to_string());
39 }
40 if !(weight.is_finite() && weight > 0.0) {
41 return Err(format!(
42 "RowPrecisionPriorPenalty::new requires finite weight > 0, got {weight}"
43 ));
44 }
45 if n_eff == 0 {
46 return Err("RowPrecisionPriorPenalty::new requires n_eff > 0".to_string());
47 }
48 if !target.len().is_multiple_of(n_eff) {
49 return Err(format!(
50 "RowPrecisionPriorPenalty::new target length {} is not divisible by n_eff {}",
51 target.len(),
52 n_eff
53 ));
54 }
55 let latent_dim = target.len() / n_eff;
56 if let Some(expected_dim) = target.latent_dim {
57 let expected = n_eff.checked_mul(expected_dim).ok_or_else(|| {
58 "RowPrecisionPriorPenalty::new target shape overflows usize".to_string()
59 })?;
60 if expected != target.len() {
61 return Err(format!(
62 "RowPrecisionPriorPenalty::new target length {} does not match n_eff {} × latent_dim {}",
63 target.len(),
64 n_eff,
65 expected_dim
66 ));
67 }
68 if expected_dim != latent_dim {
69 return Err(format!(
70 "RowPrecisionPriorPenalty::new inferred latent_dim {latent_dim} does not match target latent_dim {expected_dim}"
71 ));
72 }
73 }
74 let (lambda_n, lambda_rows, lambda_cols) = lambda_per_row.dim();
75 if lambda_n != n_eff || lambda_rows != latent_dim || lambda_cols != latent_dim {
76 return Err(format!(
77 "RowPrecisionPriorPenalty::new lambda_per_row shape must be ({n_eff}, {latent_dim}, {latent_dim}), got ({lambda_n}, {lambda_rows}, {lambda_cols})"
78 ));
79 }
80 for n in 0..n_eff {
81 let mut matrix = Array2::<f64>::zeros((latent_dim, latent_dim));
82 for i in 0..latent_dim {
83 for j in 0..latent_dim {
84 let value = lambda_per_row[[n, i, j]];
85 if !value.is_finite() {
86 return Err(format!(
87 "RowPrecisionPriorPenalty::new lambda_per_row[{n},{i},{j}] must be finite"
88 ));
89 }
90 let transpose = lambda_per_row[[n, j, i]];
91 if (value - transpose).abs() >= 1.0e-10 {
92 return Err(format!(
93 "RowPrecisionPriorPenalty::new lambda_per_row[{n}] must be symmetric; |Λ[{i},{j}] - Λ[{j},{i}]| = {:.3e}",
94 (value - transpose).abs()
95 ));
96 }
97 matrix[[i, j]] = value;
98 }
99 }
100 let (evals, _) = matrix.eigh(Side::Lower).map_err(|err| {
101 format!("RowPrecisionPriorPenalty::new lambda_per_row[{n}] eigendecomposition failed: {err}")
102 })?;
103 let min_eval = evals.iter().fold(f64::INFINITY, |acc, &v| acc.min(v));
104 if !(min_eval.is_finite() && min_eval > 0.0) {
105 return Err(format!(
106 "RowPrecisionPriorPenalty::new lambda_per_row[{n}] must be positive definite; minimum eigenvalue {min_eval:.3e}"
107 ));
108 }
109 }
110 Ok(Self {
111 lambda_per_row,
112 weight,
113 n_eff,
114 learnable_weight,
115 rho_index: 0,
116 target,
117 weight_schedule: None,
118 })
119 }
120
121 impl_with_weight_schedule!(weight);
122
123 fn resolved_weight(&self, rho: ArrayView1<'_, f64>) -> f64 {
124 if self.learnable_weight {
125 resolve_learnable_weight(self.weight, rho[self.rho_index])
126 } else {
127 self.weight
128 }
129 }
130
131 fn latent_dim(&self, target_len: usize) -> Option<usize> {
132 if self.n_eff == 0 || !target_len.is_multiple_of(self.n_eff) {
133 assert_eq!(
134 target_len % self.n_eff.max(1),
135 0,
136 "target length must be divisible by n_eff"
137 );
138 return None;
139 }
140 Some(target_len / self.n_eff)
141 }
142
143 fn target_matrix<'a>(&self, target: ArrayView1<'a, f64>) -> Option<ArrayView2<'a, f64>> {
144 let d = self.latent_dim(target.len())?;
145 target.into_shape_with_order((self.n_eff, d)).ok()
146 }
147
148 fn flatten_matrix(m: &Array2<f64>) -> Array1<f64> {
149 let n_obs = m.nrows();
150 let d = m.ncols();
151 let mut out = Array1::<f64>::zeros(n_obs * d);
152 for n in 0..n_obs {
153 for a in 0..d {
154 out[n * d + a] = m[[n, a]];
155 }
156 }
157 out
158 }
159
160 pub fn diag_target(
161 &self,
162 target: ArrayView1<'_, f64>,
163 rho: ArrayView1<'_, f64>,
164 ) -> Array1<f64> {
165 let Some(t) = self.target_matrix(target) else {
166 return Array1::<f64>::zeros(target.len());
167 };
168 let weight = self.resolved_weight(rho);
169 let mut out = Array1::<f64>::zeros(target.len());
170 for n in 0..t.nrows() {
171 for i in 0..t.ncols() {
172 out[n * t.ncols() + i] = weight * self.lambda_per_row[[n, i, i]];
173 }
174 }
175 out
176 }
177
178 pub fn as_dense(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array2<f64> {
180 let n_total = target.len();
181 let Some(t) = self.target_matrix(target) else {
182 return Array2::<f64>::zeros((n_total, n_total));
183 };
184 let d = t.ncols();
185 let weight = self.resolved_weight(rho);
186 let mut dense = Array2::<f64>::zeros((n_total, n_total));
187 for n in 0..t.nrows() {
188 for i in 0..d {
189 let row = n * d + i;
190 for j in 0..d {
191 dense[[row, n * d + j]] = weight * self.lambda_per_row[[n, i, j]];
192 }
193 }
194 }
195 dense
196 }
197
198 pub fn log_det_plus_lambda_i(
199 &self,
200 rho: ArrayView1<'_, f64>,
201 lambda: f64,
202 ) -> Result<f64, String> {
203 if !(lambda.is_finite() && lambda > 0.0) {
204 return Err(format!(
205 "RowPrecisionPriorPenalty::log_det_plus_lambda_i requires finite λ > 0; got {lambda}"
206 ));
207 }
208 let (n_obs, d, _) = self.lambda_per_row.dim();
209 let weight = self.resolved_weight(rho);
210 let mut sum = 0.0;
211 for n in 0..n_obs {
212 let mut matrix = Array2::<f64>::zeros((d, d));
213 for i in 0..d {
214 for j in 0..d {
215 matrix[[i, j]] = self.lambda_per_row[[n, i, j]];
216 }
217 }
218 let (evals, _) = matrix.eigh(Side::Lower).map_err(|err| {
219 format!("RowPrecisionPriorPenalty::log_det_plus_lambda_i lambda_per_row[{n}] eigendecomposition failed: {err}")
220 })?;
221 for &eval in evals.iter() {
222 let shifted = weight * eval + lambda;
223 if !(shifted.is_finite() && shifted > 0.0) {
224 return Err(format!(
225 "RowPrecisionPriorPenalty::log_det_plus_lambda_i non-positive shifted eigenvalue {shifted:.3e}"
226 ));
227 }
228 sum += shifted.ln();
229 }
230 }
231 Ok(sum)
232 }
233}
234
235impl AnalyticPenalty for RowPrecisionPriorPenalty {
236 fn tier(&self) -> PenaltyTier {
237 PenaltyTier::Psi
238 }
239
240 fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
241 let Some(t) = self.target_matrix(target) else {
242 return 0.0;
243 };
244 let mut acc = 0.0;
245 for n in 0..t.nrows() {
246 for i in 0..t.ncols() {
247 let mut row_dot = 0.0;
248 for j in 0..t.ncols() {
249 row_dot += self.lambda_per_row[[n, i, j]] * t[[n, j]];
250 }
251 acc += t[[n, i]] * row_dot;
252 }
253 }
254 let weight = self.resolved_weight(rho);
255 let log_weight_normalizer = -0.5 * target.len() as f64 * weight.ln();
256 0.5 * weight * acc + log_weight_normalizer
257 }
258
259 fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
260 let Some(t) = self.target_matrix(target) else {
261 return Array1::<f64>::zeros(target.len());
262 };
263 let weight = self.resolved_weight(rho);
264 let mut grad = Array2::<f64>::zeros(t.dim());
265 for n in 0..t.nrows() {
266 for i in 0..t.ncols() {
267 let mut acc = 0.0;
268 for j in 0..t.ncols() {
269 acc += self.lambda_per_row[[n, i, j]] * t[[n, j]];
270 }
271 grad[[n, i]] = weight * acc;
272 }
273 }
274 Self::flatten_matrix(&grad)
275 }
276
277 fn hessian_diag(
278 &self,
279 target: ArrayView1<'_, f64>,
280 rho: ArrayView1<'_, f64>,
281 ) -> Option<Array1<f64>> {
282 let Some(t) = self.target_matrix(target) else {
283 return Some(Array1::<f64>::zeros(target.len()));
284 };
285 for n in 0..t.nrows() {
286 for i in 0..t.ncols() {
287 for j in 0..t.ncols() {
288 if i != j && self.lambda_per_row[[n, i, j]] != 0.0 {
289 return None;
290 }
291 }
292 }
293 }
294 Some(self.diag_target(target, rho))
295 }
296
297 fn hvp(
298 &self,
299 target: ArrayView1<'_, f64>,
300 rho: ArrayView1<'_, f64>,
301 v: ArrayView1<'_, f64>,
302 ) -> Array1<f64> {
303 assert_eq!(target.len(), v.len(), "hvp dimension mismatch");
304 if target.len() != v.len() {
305 return Array1::<f64>::zeros(target.len());
306 }
307 let Some(t) = self.target_matrix(target) else {
308 return Array1::<f64>::zeros(target.len());
309 };
310 let Some(v_mat) = self.target_matrix(v) else {
311 return Array1::<f64>::zeros(target.len());
312 };
313 let weight = self.resolved_weight(rho);
314 let mut out = Array2::<f64>::zeros(t.dim());
315 for n in 0..v_mat.nrows() {
316 for i in 0..v_mat.ncols() {
317 let mut acc = 0.0;
318 for j in 0..v_mat.ncols() {
319 acc += self.lambda_per_row[[n, i, j]] * v_mat[[n, j]];
320 }
321 out[[n, i]] = weight * acc;
322 }
323 }
324 Self::flatten_matrix(&out)
325 }
326
327 fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
328 if !self.learnable_weight {
329 return Array1::<f64>::zeros(0);
330 }
331 let Some(t) = self.target_matrix(target) else {
332 return Array1::<f64>::zeros(1);
333 };
334 let mut quad = 0.0;
335 for n in 0..t.nrows() {
336 for i in 0..t.ncols() {
337 let mut row_dot = 0.0;
338 for j in 0..t.ncols() {
339 row_dot += self.lambda_per_row[[n, i, j]] * t[[n, j]];
340 }
341 quad += t[[n, i]] * row_dot;
342 }
343 }
344 let weight = self.resolved_weight(rho);
345 let mut out = Array1::<f64>::zeros(1);
346 out[self.rho_index] = 0.5 * weight * quad - 0.5 * target.len() as f64;
347 out
348 }
349
350 impl_learnable_weight_rho_count!();
351
352 fn name(&self) -> &str {
353 "row_precision_prior"
354 }
355
356 impl_scalar_apply_schedule!(weight);
357}
358
/// Ridge-projection penalty on a latent matrix, built from auxiliary covariates `U = aux`.
///
/// The target slice is read row-major as an `(n_eff, latent_dim)` matrix `T`. With the ridge
/// hat matrix `P = U (UᵀU + ε I)⁻¹ Uᵀ`, the value is `0.5 · w · tr(Tᵀ (I − P) T)`. In other
/// words, `T` is penalised only through its residual after a ridge regression on `U`. A
/// `−0.5 · len(target) · ln w` term is added when the weight is learnable.
#[derive(Debug, Clone)]
pub struct IvaeRidgeMeanGauge {
    /// Auxiliary covariates `U`, shape `(n_eff, aux_dim)` with `aux_dim > 0`; all entries
    /// are checked to be finite by `new`.
    pub aux: Array2<f64>,
    /// `(UᵀU + ridge_eps · I)⁻¹`, shape `(aux_dim, aux_dim)`, computed once by `new`.
    pub ridge_inv: Array2<f64>,
    /// Ridge added to the diagonal of `UᵀU`; `new` requires it finite and `> 0`.
    pub ridge_eps: f64,
    /// Base weight `w`; `new` requires it finite and `> 0`.
    pub weight: f64,
    /// Number of rows the target slice is split into; `new` requires `aux.nrows() == n_eff`.
    pub n_eff: usize,
    /// When true, the effective weight is `resolve_learnable_weight(weight, rho[rho_index])`.
    pub learnable_weight: bool,
    /// Position in `rho` read for the learnable weight; `new` sets it to `0`.
    pub rho_index: usize,
    /// Slice this penalty acts on; its length must be a multiple of `n_eff`.
    pub target: PsiSlice,
    /// Optional weight schedule; `new` initialises it to `None`.
    pub weight_schedule: Option<ScalarWeightSchedule>,
}
385
386impl IvaeRidgeMeanGauge {
387 #[must_use = "build error must be handled"]
388 pub fn new(
389 target: PsiSlice,
390 aux: Array2<f64>,
391 ridge_eps: f64,
392 weight: f64,
393 n_eff: usize,
394 learnable_weight: bool,
395 ) -> Result<Self, String> {
396 if target.is_empty() {
397 return Err("IvaeRidgeMeanGauge::new requires a non-empty target".to_string());
398 }
399 if !(weight.is_finite() && weight > 0.0) {
400 return Err(format!(
401 "IvaeRidgeMeanGauge::new requires finite weight > 0, got {weight}"
402 ));
403 }
404 if !(ridge_eps.is_finite() && ridge_eps > 0.0) {
405 return Err(format!(
406 "IvaeRidgeMeanGauge::new requires finite ridge_eps > 0, got {ridge_eps}"
407 ));
408 }
409 if n_eff == 0 {
410 return Err("IvaeRidgeMeanGauge::new requires n_eff > 0".to_string());
411 }
412 if !target.len().is_multiple_of(n_eff) {
413 return Err(format!(
414 "IvaeRidgeMeanGauge::new target length {} is not divisible by n_eff {}",
415 target.len(),
416 n_eff
417 ));
418 }
419 let latent_dim = target.len() / n_eff;
420 if let Some(expected_dim) = target.latent_dim {
421 let expected = n_eff.checked_mul(expected_dim).ok_or_else(|| {
422 "IvaeRidgeMeanGauge::new target shape overflows usize".to_string()
423 })?;
424 if expected != target.len() {
425 return Err(format!(
426 "IvaeRidgeMeanGauge::new target length {} does not match n_eff {} × latent_dim {}",
427 target.len(),
428 n_eff,
429 expected_dim
430 ));
431 }
432 if expected_dim != latent_dim {
433 return Err(format!(
434 "IvaeRidgeMeanGauge::new inferred latent_dim {latent_dim} does not match target latent_dim {expected_dim}"
435 ));
436 }
437 }
438 let (aux_n, aux_dim) = aux.dim();
439 if aux_n != n_eff {
440 return Err(format!(
441 "IvaeRidgeMeanGauge::new aux rows must equal n_eff {n_eff}, got {aux_n}"
442 ));
443 }
444 if aux_dim == 0 {
445 return Err("IvaeRidgeMeanGauge::new requires aux dimension > 0".to_string());
446 }
447 for (idx, &value) in aux.iter().enumerate() {
448 if !value.is_finite() {
449 return Err(format!("IvaeRidgeMeanGauge::new aux[{idx}] must be finite"));
450 }
451 }
452 let mut gram = Array2::<f64>::zeros((aux_dim, aux_dim));
453 for n in 0..n_eff {
454 for i in 0..aux_dim {
455 for j in 0..aux_dim {
456 gram[[i, j]] += aux[[n, i]] * aux[[n, j]];
457 }
458 }
459 }
460 for i in 0..aux_dim {
461 gram[[i, i]] += ridge_eps;
462 }
463 let ridge_inv = Self::invert_spd_gram(gram)?;
464 Ok(Self {
465 aux,
466 ridge_inv,
467 ridge_eps,
468 weight,
469 n_eff,
470 learnable_weight,
471 rho_index: 0,
472 target,
473 weight_schedule: None,
474 })
475 }
476
477 impl_with_weight_schedule!(weight);
478
479 fn invert_spd_gram(gram: Array2<f64>) -> Result<Array2<f64>, String> {
480 let q = gram.nrows();
481 let (evals, evecs) = gram.eigh(Side::Lower).map_err(|err| {
482 format!("IvaeRidgeMeanGauge::new ridge Gram eigendecomposition failed: {err}")
483 })?;
484 let mut inv = Array2::<f64>::zeros((q, q));
485 for k in 0..q {
486 let eval = evals[k];
487 if !(eval.is_finite() && eval > 0.0) {
488 return Err(format!(
489 "IvaeRidgeMeanGauge::new ridge Gram must be positive definite; eigenvalue {k} is {eval:.3e}"
490 ));
491 }
492 let inv_eval = 1.0 / eval;
493 for i in 0..q {
494 for j in 0..q {
495 inv[[i, j]] += evecs[[i, k]] * evecs[[j, k]] * inv_eval;
496 }
497 }
498 }
499 Ok(inv)
500 }
501
502 fn resolved_weight(&self, rho: ArrayView1<'_, f64>) -> f64 {
503 if self.learnable_weight {
504 resolve_learnable_weight(self.weight, rho[self.rho_index])
505 } else {
506 self.weight
507 }
508 }
509
510 fn latent_dim(&self, target_len: usize) -> Option<usize> {
511 if self.n_eff == 0 || !target_len.is_multiple_of(self.n_eff) {
512 assert_eq!(
513 target_len % self.n_eff.max(1),
514 0,
515 "target length must be divisible by n_eff"
516 );
517 return None;
518 }
519 Some(target_len / self.n_eff)
520 }
521
522 fn target_matrix<'a>(&self, target: ArrayView1<'a, f64>) -> Option<ArrayView2<'a, f64>> {
523 let d = self.latent_dim(target.len())?;
524 target.into_shape_with_order((self.n_eff, d)).ok()
525 }
526
527 fn flatten_matrix(m: &Array2<f64>) -> Array1<f64> {
528 let n_obs = m.nrows();
529 let d = m.ncols();
530 let mut out = Array1::<f64>::zeros(n_obs * d);
531 for n in 0..n_obs {
532 for a in 0..d {
533 out[n * d + a] = m[[n, a]];
534 }
535 }
536 out
537 }
538
539 fn projected_matrix(&self, x: ArrayView2<'_, f64>) -> Array2<f64> {
540 let q = self.aux.ncols();
541 let d = x.ncols();
542 let mut u_t_x = Array2::<f64>::zeros((q, d));
543 for n in 0..x.nrows() {
544 for i in 0..q {
545 let u_ni = self.aux[[n, i]];
546 for a in 0..d {
547 u_t_x[[i, a]] += u_ni * x[[n, a]];
548 }
549 }
550 }
551 let mut coeff = Array2::<f64>::zeros((q, d));
552 for i in 0..q {
553 for j in 0..q {
554 let inv_ij = self.ridge_inv[[i, j]];
555 for a in 0..d {
556 coeff[[i, a]] += inv_ij * u_t_x[[j, a]];
557 }
558 }
559 }
560 let mut projected = Array2::<f64>::zeros(x.dim());
561 for n in 0..x.nrows() {
562 for i in 0..q {
563 let u_ni = self.aux[[n, i]];
564 for a in 0..d {
565 projected[[n, a]] += u_ni * coeff[[i, a]];
566 }
567 }
568 }
569 projected
570 }
571
572 fn residual_matrix(&self, x: ArrayView2<'_, f64>) -> Array2<f64> {
573 let projected = self.projected_matrix(x);
574 let mut residual = Array2::<f64>::zeros(x.dim());
575 for n in 0..x.nrows() {
576 for a in 0..x.ncols() {
577 residual[[n, a]] = x[[n, a]] - projected[[n, a]];
578 }
579 }
580 residual
581 }
582
583 pub fn diag_target(
584 &self,
585 target: ArrayView1<'_, f64>,
586 rho: ArrayView1<'_, f64>,
587 ) -> Array1<f64> {
588 let Some(t) = self.target_matrix(target) else {
589 return Array1::<f64>::zeros(target.len());
590 };
591 let weight = self.resolved_weight(rho);
592 let mut out = Array1::<f64>::zeros(target.len());
593 for n in 0..t.nrows() {
594 let mut p_nn = 0.0;
595 for i in 0..self.aux.ncols() {
596 for j in 0..self.aux.ncols() {
597 p_nn += self.aux[[n, i]] * self.ridge_inv[[i, j]] * self.aux[[n, j]];
598 }
599 }
600 let diag = weight * (1.0 - p_nn);
601 for a in 0..t.ncols() {
602 out[n * t.ncols() + a] = diag;
603 }
604 }
605 out
606 }
607
608 pub fn as_dense(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array2<f64> {
610 let n_total = target.len();
611 let Some(t) = self.target_matrix(target) else {
612 return Array2::<f64>::zeros((n_total, n_total));
613 };
614 let d = t.ncols();
615 let weight = self.resolved_weight(rho);
616 let mut dense = Array2::<f64>::zeros((n_total, n_total));
617 for n in 0..t.nrows() {
618 for m in 0..t.nrows() {
619 let mut p_nm = 0.0;
620 for i in 0..self.aux.ncols() {
621 for j in 0..self.aux.ncols() {
622 p_nm += self.aux[[n, i]] * self.ridge_inv[[i, j]] * self.aux[[m, j]];
623 }
624 }
625 let entry = weight * (if n == m { 1.0 } else { 0.0 } - p_nm);
626 for a in 0..d {
627 dense[[n * d + a, m * d + a]] = entry;
628 }
629 }
630 }
631 dense
632 }
633}
634
635impl AnalyticPenalty for IvaeRidgeMeanGauge {
636 fn tier(&self) -> PenaltyTier {
637 PenaltyTier::Psi
638 }
639
640 fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
641 let Some(t) = self.target_matrix(target) else {
642 return 0.0;
643 };
644 let residual = self.residual_matrix(t.view());
645 let mut acc = 0.0;
646 for n in 0..t.nrows() {
647 for a in 0..t.ncols() {
648 acc += t[[n, a]] * residual[[n, a]];
649 }
650 }
651 let weight = self.resolved_weight(rho);
652 let mut value = 0.5 * weight * acc;
653 if self.learnable_weight {
654 value -= 0.5 * target.len() as f64 * weight.ln();
655 }
656 value
657 }
658
659 fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
660 let Some(t) = self.target_matrix(target) else {
661 return Array1::<f64>::zeros(target.len());
662 };
663 let weight = self.resolved_weight(rho);
664 let mut grad = self.residual_matrix(t.view());
665 for value in grad.iter_mut() {
666 *value *= weight;
667 }
668 Self::flatten_matrix(&grad)
669 }
670
671 fn hvp(
672 &self,
673 target: ArrayView1<'_, f64>,
674 rho: ArrayView1<'_, f64>,
675 v: ArrayView1<'_, f64>,
676 ) -> Array1<f64> {
677 assert_eq!(target.len(), v.len(), "hvp dimension mismatch");
678 if target.len() != v.len() {
679 return Array1::<f64>::zeros(target.len());
680 }
681 let Some(v_mat) = self.target_matrix(v) else {
682 return Array1::<f64>::zeros(target.len());
683 };
684 let weight = self.resolved_weight(rho);
685 let mut hv = self.residual_matrix(v_mat.view());
686 for value in hv.iter_mut() {
687 *value *= weight;
688 }
689 Self::flatten_matrix(&hv)
690 }
691
692 fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
693 if !self.learnable_weight {
694 return Array1::<f64>::zeros(0);
695 }
696 if self.target_matrix(target).is_none() {
697 return Array1::<f64>::zeros(1);
698 }
699 let mut out = Array1::<f64>::zeros(1);
700 let weight = self.resolved_weight(rho);
701 out[self.rho_index] =
702 self.value(target, rho) + 0.5 * target.len() as f64 * (weight.ln() - 1.0);
703 out
704 }
705
706 impl_learnable_weight_rho_count!();
707
708 fn name(&self) -> &str {
709 "ivae_ridge_mean_gauge"
710 }
711
712 impl_scalar_apply_schedule!(weight);
713}
714
/// Diagonal row-precision prior whose precisions are parametric functions of auxiliary
/// covariates.
///
/// The target slice is read row-major as an `(n_eff, latent_dim)` matrix `T`. Entry `(n, k)`
/// has precision `w · λ_{n,k}`, where
/// `λ_{n,k} = MIN_CONDITIONAL_PRECISION + α_k + β_k · ‖aux[n, ·] − μ_k‖²`. Here
/// `α_k = stable_exp_log_precision(log_alpha[k] + Δ)`, `β_k = softplus(raw_beta[k] + Δ)`,
/// and `μ_k = mu[k, ·] + Δ`.
///
/// The deltas `Δ` come from `rho`, laid out as
/// `[log_alpha (latent_dim) | raw_beta (latent_dim) | mu (latent_dim × aux_dim, row-major) |
/// weight (only when learnable)]`.
#[derive(Debug, Clone)]
pub struct ParametricRowPrecisionPriorPenalty {
    /// Auxiliary covariates, shape `(n_eff, aux_dim)` with `aux_dim > 0`; all entries finite.
    pub aux: Array2<f64>,
    /// Base log-scale intercepts `log α_k`, length `latent_dim`.
    pub log_alpha: Array1<f64>,
    /// Base pre-softplus slopes for `β_k`, length `latent_dim`.
    pub raw_beta: Array1<f64>,
    /// Base centres `μ_k`, shape `(latent_dim, aux_dim)`.
    pub mu: Array2<f64>,
    /// Base weight `w`; `new` requires it finite and `> 0`.
    pub weight: f64,
    /// Number of rows the target slice is split into; must equal `aux.nrows()`.
    pub n_eff: usize,
    /// When true, the weight is resolved from the last `rho` coordinate.
    pub learnable_weight: bool,
    /// Slice this penalty acts on; its length must be a multiple of `n_eff`.
    pub target: PsiSlice,
    /// Optional weight schedule; `new` initialises it to `None`.
    pub weight_schedule: Option<ScalarWeightSchedule>,
}
741
742impl ParametricRowPrecisionPriorPenalty {
743 #[must_use = "build error must be handled"]
744 pub fn new(
745 target: PsiSlice,
746 aux: Array2<f64>,
747 log_alpha: Array1<f64>,
748 raw_beta: Array1<f64>,
749 mu: Array2<f64>,
750 weight: f64,
751 n_eff: usize,
752 learnable_weight: bool,
753 ) -> Result<Self, String> {
754 if target.is_empty() {
755 return Err(
756 "ParametricRowPrecisionPriorPenalty::new requires a non-empty target".to_string(),
757 );
758 }
759 if !(weight.is_finite() && weight > 0.0) {
760 return Err(format!(
761 "ParametricRowPrecisionPriorPenalty::new requires finite weight > 0, got {weight}"
762 ));
763 }
764 if n_eff == 0 {
765 return Err("ParametricRowPrecisionPriorPenalty::new requires n_eff > 0".to_string());
766 }
767 if !target.len().is_multiple_of(n_eff) {
768 return Err(format!(
769 "ParametricRowPrecisionPriorPenalty::new target length {} is not divisible by n_eff {}",
770 target.len(),
771 n_eff
772 ));
773 }
774 let latent_dim = target.len() / n_eff;
775 if latent_dim == 0 {
776 return Err(
777 "ParametricRowPrecisionPriorPenalty::new requires latent_dim > 0".to_string(),
778 );
779 }
780 if let Some(expected_dim) = target.latent_dim {
781 let expected = n_eff.checked_mul(expected_dim).ok_or_else(|| {
782 "ParametricRowPrecisionPriorPenalty::new target shape overflows usize".to_string()
783 })?;
784 if expected != target.len() {
785 return Err(format!(
786 "ParametricRowPrecisionPriorPenalty::new target length {} does not match n_eff {} × latent_dim {}",
787 target.len(),
788 n_eff,
789 expected_dim
790 ));
791 }
792 if expected_dim != latent_dim {
793 return Err(format!(
794 "ParametricRowPrecisionPriorPenalty::new inferred latent_dim {latent_dim} does not match target latent_dim {expected_dim}"
795 ));
796 }
797 }
798 let (aux_n, aux_dim) = aux.dim();
799 if aux_n != n_eff {
800 return Err(format!(
801 "ParametricRowPrecisionPriorPenalty::new aux rows must equal n_eff {n_eff}, got {aux_n}"
802 ));
803 }
804 if aux_dim == 0 {
805 return Err(
806 "ParametricRowPrecisionPriorPenalty::new requires aux dimension > 0".to_string(),
807 );
808 }
809 if log_alpha.len() != latent_dim {
810 return Err(format!(
811 "ParametricRowPrecisionPriorPenalty::new log_alpha length must equal latent_dim {latent_dim}, got {}",
812 log_alpha.len()
813 ));
814 }
815 if raw_beta.len() != latent_dim {
816 return Err(format!(
817 "ParametricRowPrecisionPriorPenalty::new raw_beta length must equal latent_dim {latent_dim}, got {}",
818 raw_beta.len()
819 ));
820 }
821 let (mu_rows, mu_cols) = mu.dim();
822 if mu_rows != latent_dim || mu_cols != aux_dim {
823 return Err(format!(
824 "ParametricRowPrecisionPriorPenalty::new mu shape must be ({latent_dim}, {aux_dim}), got ({mu_rows}, {mu_cols})"
825 ));
826 }
827 for (idx, &value) in aux.iter().enumerate() {
828 if !value.is_finite() {
829 return Err(format!(
830 "ParametricRowPrecisionPriorPenalty::new aux[{idx}] must be finite"
831 ));
832 }
833 }
834 for k in 0..latent_dim {
835 let log_alpha_k = log_alpha[k];
836 if !log_alpha_k.is_finite() {
837 return Err(format!(
838 "ParametricRowPrecisionPriorPenalty::new log_alpha[{k}] must be finite"
839 ));
840 }
841 let alpha_k = log_alpha_k.exp();
842 if !(alpha_k.is_finite() && alpha_k > 0.0) {
843 return Err(format!(
844 "ParametricRowPrecisionPriorPenalty::new exp(log_alpha[{k}]) must be finite and > 0"
845 ));
846 }
847 let raw_beta_k = raw_beta[k];
848 if !raw_beta_k.is_finite() {
849 return Err(format!(
850 "ParametricRowPrecisionPriorPenalty::new raw_beta[{k}] must be finite"
851 ));
852 }
853 let beta_k = gam_linalg::utils::stable_softplus(raw_beta_k);
854 if !(beta_k.is_finite() && beta_k >= 0.0) {
855 return Err(format!(
856 "ParametricRowPrecisionPriorPenalty::new softplus(raw_beta[{k}]) must be finite and >= 0"
857 ));
858 }
859 }
860 for (idx, &value) in mu.iter().enumerate() {
861 if !value.is_finite() {
862 return Err(format!(
863 "ParametricRowPrecisionPriorPenalty::new mu[{idx}] must be finite"
864 ));
865 }
866 }
867 Ok(Self {
868 aux,
869 log_alpha,
870 raw_beta,
871 mu,
872 weight,
873 n_eff,
874 learnable_weight,
875 target,
876 weight_schedule: None,
877 })
878 }
879
880 impl_with_weight_schedule!(weight);
881
882 fn latent_dim(&self, target_len: usize) -> Option<usize> {
883 if self.n_eff == 0 || !target_len.is_multiple_of(self.n_eff) {
884 assert_eq!(
885 target_len % self.n_eff.max(1),
886 0,
887 "target length must be divisible by n_eff"
888 );
889 return None;
890 }
891 Some(target_len / self.n_eff)
892 }
893
894 fn target_matrix<'a>(&self, target: ArrayView1<'a, f64>) -> Option<ArrayView2<'a, f64>> {
895 let d = self.latent_dim(target.len())?;
896 target.into_shape_with_order((self.n_eff, d)).ok()
897 }
898
899 fn flatten_matrix(m: &Array2<f64>) -> Array1<f64> {
900 let n_obs = m.nrows();
901 let d = m.ncols();
902 let mut out = Array1::<f64>::zeros(n_obs * d);
903 for n in 0..n_obs {
904 for a in 0..d {
905 out[n * d + a] = m[[n, a]];
906 }
907 }
908 out
909 }
910
911 fn log_alpha_offset(&self) -> usize {
912 0
913 }
914
915 fn raw_beta_offset(&self) -> usize {
916 self.log_alpha.len()
917 }
918
919 fn mu_offset(&self) -> usize {
920 self.log_alpha.len() + self.raw_beta.len()
921 }
922
923 fn weight_offset(&self) -> usize {
924 self.mu_offset() + self.mu.len()
925 }
926
927 fn active_log_alpha(&self, k: usize, rho: ArrayView1<'_, f64>) -> f64 {
928 self.log_alpha[k] + rho[self.log_alpha_offset() + k]
929 }
930
931 fn active_raw_beta(&self, k: usize, rho: ArrayView1<'_, f64>) -> f64 {
932 self.raw_beta[k] + rho[self.raw_beta_offset() + k]
933 }
934
935 fn active_mu(&self, k: usize, a: usize, rho: ArrayView1<'_, f64>) -> f64 {
936 self.mu[[k, a]] + rho[self.mu_offset() + k * self.aux.ncols() + a]
937 }
938
939 fn resolved_weight(&self, rho: ArrayView1<'_, f64>) -> f64 {
940 if self.learnable_weight {
941 resolve_learnable_weight(self.weight, rho[self.weight_offset()])
942 } else {
943 self.weight
944 }
945 }
946
947 fn lambda_at(&self, n: usize, k: usize, rho: ArrayView1<'_, f64>) -> f64 {
948 let alpha = stable_exp_log_precision(self.active_log_alpha(k, rho));
949 let beta = gam_linalg::utils::stable_softplus(self.active_raw_beta(k, rho));
950 MIN_CONDITIONAL_PRECISION + alpha + beta * self.dist2(n, k, rho)
951 }
952
953 fn dist2(&self, n: usize, k: usize, rho: ArrayView1<'_, f64>) -> f64 {
954 let mut r2 = 0.0;
955 for a in 0..self.aux.ncols() {
956 let delta = self.aux[[n, a]] - self.active_mu(k, a, rho);
957 r2 += delta * delta;
958 }
959 r2
960 }
961
962 pub fn diag_target(
963 &self,
964 target: ArrayView1<'_, f64>,
965 rho: ArrayView1<'_, f64>,
966 ) -> Array1<f64> {
967 let Some(t) = self.target_matrix(target) else {
968 return Array1::<f64>::zeros(target.len());
969 };
970 let weight = self.resolved_weight(rho);
971 let mut out = Array1::<f64>::zeros(target.len());
972 for n in 0..t.nrows() {
973 for k in 0..t.ncols() {
974 out[n * t.ncols() + k] = weight * self.lambda_at(n, k, rho);
975 }
976 }
977 out
978 }
979
980 pub fn as_dense(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array2<f64> {
982 let n_total = target.len();
983 let diag = self.diag_target(target, rho);
984 let mut dense = Array2::<f64>::zeros((n_total, n_total));
985 for i in 0..n_total {
986 dense[[i, i]] = diag[i];
987 }
988 dense
989 }
990
991 pub fn log_det_plus_lambda_i(
992 &self,
993 rho: ArrayView1<'_, f64>,
994 lambda: f64,
995 ) -> Result<f64, String> {
996 if !(lambda.is_finite() && lambda > 0.0) {
997 return Err(format!(
998 "ParametricRowPrecisionPriorPenalty::log_det_plus_lambda_i requires finite λ > 0; got {lambda}"
999 ));
1000 }
1001 let weight = self.resolved_weight(rho);
1002 let mut sum = 0.0;
1003 for n in 0..self.n_eff {
1004 for k in 0..self.log_alpha.len() {
1005 let shifted = lambda + weight * self.lambda_at(n, k, rho);
1006 if !(shifted.is_finite() && shifted > 0.0) {
1007 return Err(format!(
1008 "ParametricRowPrecisionPriorPenalty::log_det_plus_lambda_i non-positive shifted diagonal {shifted:.3e}"
1009 ));
1010 }
1011 sum += shifted.ln();
1012 }
1013 }
1014 Ok(sum)
1015 }
1016}
1017
1018impl AnalyticPenalty for ParametricRowPrecisionPriorPenalty {
1019 fn tier(&self) -> PenaltyTier {
1020 PenaltyTier::Psi
1021 }
1022
1023 fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
1024 let Some(t) = self.target_matrix(target) else {
1025 return 0.0;
1026 };
1027 let weight = self.resolved_weight(rho);
1028 let mut quadratic = 0.0;
1029 let mut log_det = 0.0;
1030 for n in 0..t.nrows() {
1031 for k in 0..t.ncols() {
1032 let lambda = self.lambda_at(n, k, rho);
1033 quadratic += lambda * t[[n, k]] * t[[n, k]];
1034 log_det += (weight * lambda).ln();
1035 }
1036 }
1037 0.5 * weight * quadratic - 0.5 * log_det
1038 }
1039
1040 fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
1041 let Some(t) = self.target_matrix(target) else {
1042 return Array1::<f64>::zeros(target.len());
1043 };
1044 let weight = self.resolved_weight(rho);
1045 let mut grad = Array2::<f64>::zeros(t.dim());
1046 for n in 0..t.nrows() {
1047 for k in 0..t.ncols() {
1048 grad[[n, k]] = weight * self.lambda_at(n, k, rho) * t[[n, k]];
1049 }
1050 }
1051 Self::flatten_matrix(&grad)
1052 }
1053
1054 fn hessian_diag(
1055 &self,
1056 target: ArrayView1<'_, f64>,
1057 rho: ArrayView1<'_, f64>,
1058 ) -> Option<Array1<f64>> {
1059 Some(self.diag_target(target, rho))
1060 }
1061
1062 fn hvp(
1063 &self,
1064 target: ArrayView1<'_, f64>,
1065 rho: ArrayView1<'_, f64>,
1066 v: ArrayView1<'_, f64>,
1067 ) -> Array1<f64> {
1068 assert_eq!(target.len(), v.len(), "hvp dimension mismatch");
1069 if target.len() != v.len() {
1070 return Array1::<f64>::zeros(target.len());
1071 }
1072 let diag = self.diag_target(target, rho);
1073 let mut out = Array1::<f64>::zeros(v.len());
1074 for i in 0..v.len() {
1075 out[i] = diag[i] * v[i];
1076 }
1077 out
1078 }
1079
1080 fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
1081 let Some(t) = self.target_matrix(target) else {
1082 return Array1::<f64>::zeros(self.rho_count());
1083 };
1084 let weight = self.resolved_weight(rho);
1085 let mut out = Array1::<f64>::zeros(self.rho_count());
1086 let d = t.ncols();
1087 let du = self.aux.ncols();
1088 let mut grad_weight_direct = 0.0;
1089 for k in 0..d {
1090 let log_alpha = self.active_log_alpha(k, rho);
1091 let alpha = stable_exp_log_precision(log_alpha);
1092 let raw_beta = self.active_raw_beta(k, rho);
1093 let beta = gam_linalg::utils::stable_softplus(raw_beta);
1094 let beta_jac = gam_linalg::utils::stable_logistic(raw_beta);
1095 let mut grad_alpha_direct = 0.0;
1096 let mut grad_beta_direct = 0.0;
1097 let mut grad_mu_direct = vec![0.0_f64; du];
1098 for n in 0..t.nrows() {
1099 let tk = t[[n, k]];
1100 let sq = tk * tk;
1101 let r2 = self.dist2(n, k, rho);
1102 let lambda = alpha + beta * r2;
1103 let precision_score = 0.5 * weight * sq - 0.5 / lambda;
1104 grad_weight_direct += 0.5 * weight * lambda * sq;
1105 grad_alpha_direct += precision_score;
1106 grad_beta_direct += precision_score * r2;
1107 for a in 0..du {
1108 let delta = self.aux[[n, a]] - self.active_mu(k, a, rho);
1109 grad_mu_direct[a] += -2.0 * precision_score * beta * delta;
1110 }
1111 }
1112 out[self.log_alpha_offset() + k] = grad_alpha_direct * alpha;
1113 out[self.raw_beta_offset() + k] = grad_beta_direct * beta_jac;
1114 for a in 0..du {
1115 out[self.mu_offset() + k * du + a] = grad_mu_direct[a];
1116 }
1117 }
1118 if self.learnable_weight {
1119 out[self.weight_offset()] = grad_weight_direct - 0.5 * target.len() as f64;
1120 }
1121 out
1122 }
1123
1124 fn rho_count(&self) -> usize {
1125 self.log_alpha.len()
1126 + self.raw_beta.len()
1127 + self.mu.len()
1128 + usize::from(self.learnable_weight)
1129 }
1130
1131 fn name(&self) -> &str {
1132 "parametric_row_precision_prior"
1133 }
1134
1135 impl_scalar_apply_schedule!(weight);
1136}