Skip to main content

gam_terms/analytic_penalties/
row_precision.rs

1use super::*;
2
// ---------------------------------------------------------------------------
// Row-precision prior penalty
// ---------------------------------------------------------------------------
6
7/// Fixed zero-mean Gaussian row-precision prior on the latent block.
8///
9/// Evaluates the row-wise precision energy `½ μ Σ_n t_nᵀ Λ_n t_n`, with the
10/// `ρ`-dependent Gaussian precision normalizer when `μ` is learnable. Callers
11/// pass one positive-definite precision matrix per row. This is not the iVAE
12/// conditional-mean gauge `½ μ ||t - h(u)||²`; use `LatentIdMode::AuxPrior`
13/// for the ridge/linear projection-residual gauge.
14#[derive(Debug, Clone)]
15pub struct RowPrecisionPriorPenalty {
16    pub lambda_per_row: Array3<f64>,
17    /// Base strength. If `learnable_weight` is true, the resolved strength is
18    /// `weight * exp(rho[rho_index])`; otherwise it is fixed at `weight`.
19    pub weight: f64,
20    /// Number of rows in the row-major matrix-valued latent block.
21    pub n_eff: usize,
22    pub learnable_weight: bool,
23    pub rho_index: usize,
24    pub target: PsiSlice,
25    pub weight_schedule: Option<ScalarWeightSchedule>,
26}
27
28impl RowPrecisionPriorPenalty {
29    #[must_use = "build error must be handled"]
30    pub fn new(
31        target: PsiSlice,
32        lambda_per_row: Array3<f64>,
33        weight: f64,
34        n_eff: usize,
35        learnable_weight: bool,
36    ) -> Result<Self, String> {
37        if target.is_empty() {
38            return Err("RowPrecisionPriorPenalty::new requires a non-empty target".to_string());
39        }
40        if !(weight.is_finite() && weight > 0.0) {
41            return Err(format!(
42                "RowPrecisionPriorPenalty::new requires finite weight > 0, got {weight}"
43            ));
44        }
45        if n_eff == 0 {
46            return Err("RowPrecisionPriorPenalty::new requires n_eff > 0".to_string());
47        }
48        if !target.len().is_multiple_of(n_eff) {
49            return Err(format!(
50                "RowPrecisionPriorPenalty::new target length {} is not divisible by n_eff {}",
51                target.len(),
52                n_eff
53            ));
54        }
55        let latent_dim = target.len() / n_eff;
56        if let Some(expected_dim) = target.latent_dim {
57            let expected = n_eff.checked_mul(expected_dim).ok_or_else(|| {
58                "RowPrecisionPriorPenalty::new target shape overflows usize".to_string()
59            })?;
60            if expected != target.len() {
61                return Err(format!(
62                    "RowPrecisionPriorPenalty::new target length {} does not match n_eff {} × latent_dim {}",
63                    target.len(),
64                    n_eff,
65                    expected_dim
66                ));
67            }
68            if expected_dim != latent_dim {
69                return Err(format!(
70                    "RowPrecisionPriorPenalty::new inferred latent_dim {latent_dim} does not match target latent_dim {expected_dim}"
71                ));
72            }
73        }
74        let (lambda_n, lambda_rows, lambda_cols) = lambda_per_row.dim();
75        if lambda_n != n_eff || lambda_rows != latent_dim || lambda_cols != latent_dim {
76            return Err(format!(
77                "RowPrecisionPriorPenalty::new lambda_per_row shape must be ({n_eff}, {latent_dim}, {latent_dim}), got ({lambda_n}, {lambda_rows}, {lambda_cols})"
78            ));
79        }
80        for n in 0..n_eff {
81            let mut matrix = Array2::<f64>::zeros((latent_dim, latent_dim));
82            for i in 0..latent_dim {
83                for j in 0..latent_dim {
84                    let value = lambda_per_row[[n, i, j]];
85                    if !value.is_finite() {
86                        return Err(format!(
87                            "RowPrecisionPriorPenalty::new lambda_per_row[{n},{i},{j}] must be finite"
88                        ));
89                    }
90                    let transpose = lambda_per_row[[n, j, i]];
91                    if (value - transpose).abs() >= 1.0e-10 {
92                        return Err(format!(
93                            "RowPrecisionPriorPenalty::new lambda_per_row[{n}] must be symmetric; |Λ[{i},{j}] - Λ[{j},{i}]| = {:.3e}",
94                            (value - transpose).abs()
95                        ));
96                    }
97                    matrix[[i, j]] = value;
98                }
99            }
100            let (evals, _) = matrix.eigh(Side::Lower).map_err(|err| {
101                format!("RowPrecisionPriorPenalty::new lambda_per_row[{n}] eigendecomposition failed: {err}")
102            })?;
103            let min_eval = evals.iter().fold(f64::INFINITY, |acc, &v| acc.min(v));
104            if !(min_eval.is_finite() && min_eval > 0.0) {
105                return Err(format!(
106                    "RowPrecisionPriorPenalty::new lambda_per_row[{n}] must be positive definite; minimum eigenvalue {min_eval:.3e}"
107                ));
108            }
109        }
110        Ok(Self {
111            lambda_per_row,
112            weight,
113            n_eff,
114            learnable_weight,
115            rho_index: 0,
116            target,
117            weight_schedule: None,
118        })
119    }
120
121    impl_with_weight_schedule!(weight);
122
123    fn resolved_weight(&self, rho: ArrayView1<'_, f64>) -> f64 {
124        if self.learnable_weight {
125            resolve_learnable_weight(self.weight, rho[self.rho_index])
126        } else {
127            self.weight
128        }
129    }
130
131    fn latent_dim(&self, target_len: usize) -> Option<usize> {
132        if self.n_eff == 0 || !target_len.is_multiple_of(self.n_eff) {
133            assert_eq!(
134                target_len % self.n_eff.max(1),
135                0,
136                "target length must be divisible by n_eff"
137            );
138            return None;
139        }
140        Some(target_len / self.n_eff)
141    }
142
143    fn target_matrix<'a>(&self, target: ArrayView1<'a, f64>) -> Option<ArrayView2<'a, f64>> {
144        let d = self.latent_dim(target.len())?;
145        target.into_shape_with_order((self.n_eff, d)).ok()
146    }
147
148    fn flatten_matrix(m: &Array2<f64>) -> Array1<f64> {
149        let n_obs = m.nrows();
150        let d = m.ncols();
151        let mut out = Array1::<f64>::zeros(n_obs * d);
152        for n in 0..n_obs {
153            for a in 0..d {
154                out[n * d + a] = m[[n, a]];
155            }
156        }
157        out
158    }
159
160    pub fn diag_target(
161        &self,
162        target: ArrayView1<'_, f64>,
163        rho: ArrayView1<'_, f64>,
164    ) -> Array1<f64> {
165        let Some(t) = self.target_matrix(target) else {
166            return Array1::<f64>::zeros(target.len());
167        };
168        let weight = self.resolved_weight(rho);
169        let mut out = Array1::<f64>::zeros(target.len());
170        for n in 0..t.nrows() {
171            for i in 0..t.ncols() {
172                out[n * t.ncols() + i] = weight * self.lambda_per_row[[n, i, i]];
173            }
174        }
175        out
176    }
177
178    /// Materialize the row-block-diagonal Hessian for exact spectral paths.
179    pub fn as_dense(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array2<f64> {
180        let n_total = target.len();
181        let Some(t) = self.target_matrix(target) else {
182            return Array2::<f64>::zeros((n_total, n_total));
183        };
184        let d = t.ncols();
185        let weight = self.resolved_weight(rho);
186        let mut dense = Array2::<f64>::zeros((n_total, n_total));
187        for n in 0..t.nrows() {
188            for i in 0..d {
189                let row = n * d + i;
190                for j in 0..d {
191                    dense[[row, n * d + j]] = weight * self.lambda_per_row[[n, i, j]];
192                }
193            }
194        }
195        dense
196    }
197
198    pub fn log_det_plus_lambda_i(
199        &self,
200        rho: ArrayView1<'_, f64>,
201        lambda: f64,
202    ) -> Result<f64, String> {
203        if !(lambda.is_finite() && lambda > 0.0) {
204            return Err(format!(
205                "RowPrecisionPriorPenalty::log_det_plus_lambda_i requires finite λ > 0; got {lambda}"
206            ));
207        }
208        let (n_obs, d, _) = self.lambda_per_row.dim();
209        let weight = self.resolved_weight(rho);
210        let mut sum = 0.0;
211        for n in 0..n_obs {
212            let mut matrix = Array2::<f64>::zeros((d, d));
213            for i in 0..d {
214                for j in 0..d {
215                    matrix[[i, j]] = self.lambda_per_row[[n, i, j]];
216                }
217            }
218            let (evals, _) = matrix.eigh(Side::Lower).map_err(|err| {
219                format!("RowPrecisionPriorPenalty::log_det_plus_lambda_i lambda_per_row[{n}] eigendecomposition failed: {err}")
220            })?;
221            for &eval in evals.iter() {
222                let shifted = weight * eval + lambda;
223                if !(shifted.is_finite() && shifted > 0.0) {
224                    return Err(format!(
225                        "RowPrecisionPriorPenalty::log_det_plus_lambda_i non-positive shifted eigenvalue {shifted:.3e}"
226                    ));
227                }
228                sum += shifted.ln();
229            }
230        }
231        Ok(sum)
232    }
233}
234
235impl AnalyticPenalty for RowPrecisionPriorPenalty {
236    fn tier(&self) -> PenaltyTier {
237        PenaltyTier::Psi
238    }
239
240    fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
241        let Some(t) = self.target_matrix(target) else {
242            return 0.0;
243        };
244        let mut acc = 0.0;
245        for n in 0..t.nrows() {
246            for i in 0..t.ncols() {
247                let mut row_dot = 0.0;
248                for j in 0..t.ncols() {
249                    row_dot += self.lambda_per_row[[n, i, j]] * t[[n, j]];
250                }
251                acc += t[[n, i]] * row_dot;
252            }
253        }
254        let weight = self.resolved_weight(rho);
255        let log_weight_normalizer = -0.5 * target.len() as f64 * weight.ln();
256        0.5 * weight * acc + log_weight_normalizer
257    }
258
259    fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
260        let Some(t) = self.target_matrix(target) else {
261            return Array1::<f64>::zeros(target.len());
262        };
263        let weight = self.resolved_weight(rho);
264        let mut grad = Array2::<f64>::zeros(t.dim());
265        for n in 0..t.nrows() {
266            for i in 0..t.ncols() {
267                let mut acc = 0.0;
268                for j in 0..t.ncols() {
269                    acc += self.lambda_per_row[[n, i, j]] * t[[n, j]];
270                }
271                grad[[n, i]] = weight * acc;
272            }
273        }
274        Self::flatten_matrix(&grad)
275    }
276
277    fn hessian_diag(
278        &self,
279        target: ArrayView1<'_, f64>,
280        rho: ArrayView1<'_, f64>,
281    ) -> Option<Array1<f64>> {
282        let Some(t) = self.target_matrix(target) else {
283            return Some(Array1::<f64>::zeros(target.len()));
284        };
285        for n in 0..t.nrows() {
286            for i in 0..t.ncols() {
287                for j in 0..t.ncols() {
288                    if i != j && self.lambda_per_row[[n, i, j]] != 0.0 {
289                        return None;
290                    }
291                }
292            }
293        }
294        Some(self.diag_target(target, rho))
295    }
296
297    fn hvp(
298        &self,
299        target: ArrayView1<'_, f64>,
300        rho: ArrayView1<'_, f64>,
301        v: ArrayView1<'_, f64>,
302    ) -> Array1<f64> {
303        assert_eq!(target.len(), v.len(), "hvp dimension mismatch");
304        if target.len() != v.len() {
305            return Array1::<f64>::zeros(target.len());
306        }
307        let Some(t) = self.target_matrix(target) else {
308            return Array1::<f64>::zeros(target.len());
309        };
310        let Some(v_mat) = self.target_matrix(v) else {
311            return Array1::<f64>::zeros(target.len());
312        };
313        let weight = self.resolved_weight(rho);
314        let mut out = Array2::<f64>::zeros(t.dim());
315        for n in 0..v_mat.nrows() {
316            for i in 0..v_mat.ncols() {
317                let mut acc = 0.0;
318                for j in 0..v_mat.ncols() {
319                    acc += self.lambda_per_row[[n, i, j]] * v_mat[[n, j]];
320                }
321                out[[n, i]] = weight * acc;
322            }
323        }
324        Self::flatten_matrix(&out)
325    }
326
327    fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
328        if !self.learnable_weight {
329            return Array1::<f64>::zeros(0);
330        }
331        let Some(t) = self.target_matrix(target) else {
332            return Array1::<f64>::zeros(1);
333        };
334        let mut quad = 0.0;
335        for n in 0..t.nrows() {
336            for i in 0..t.ncols() {
337                let mut row_dot = 0.0;
338                for j in 0..t.ncols() {
339                    row_dot += self.lambda_per_row[[n, i, j]] * t[[n, j]];
340                }
341                quad += t[[n, i]] * row_dot;
342            }
343        }
344        let weight = self.resolved_weight(rho);
345        let mut out = Array1::<f64>::zeros(1);
346        out[self.rho_index] = 0.5 * weight * quad - 0.5 * target.len() as f64;
347        out
348    }
349
350    impl_learnable_weight_rho_count!();
351
352    fn name(&self) -> &str {
353        "row_precision_prior"
354    }
355
356    impl_scalar_apply_schedule!(weight);
357}
358
// ---------------------------------------------------------------------------
// iVAE ridge conditional-mean gauge penalty
// ---------------------------------------------------------------------------
362
363/// iVAE conditional-mean gauge penalty on the latent block.
364///
365/// Khemakhem et al. (2020) identify nonlinear ICA/iVAE latent factors from
366/// auxiliary-variable variation up to an affine transform under sufficient
367/// variation in `u`. This penalty implements the conditional-mean side of that
368/// signal as `0.5 * μ * ||t - U(UᵀU + εI)⁻¹Uᵀt||²`, penalizing only the
369/// component of each latent axis not explained by a ridge linear fit to `u`.
370#[derive(Debug, Clone)]
371pub struct IvaeRidgeMeanGauge {
372    pub aux: Array2<f64>,
373    pub ridge_inv: Array2<f64>,
374    pub ridge_eps: f64,
375    /// Base strength. If `learnable_weight` is true, the resolved strength is
376    /// `weight * exp(rho[rho_index])`; otherwise it is fixed at `weight`.
377    pub weight: f64,
378    /// Number of rows in the row-major matrix-valued latent block.
379    pub n_eff: usize,
380    pub learnable_weight: bool,
381    pub rho_index: usize,
382    pub target: PsiSlice,
383    pub weight_schedule: Option<ScalarWeightSchedule>,
384}
385
386impl IvaeRidgeMeanGauge {
387    #[must_use = "build error must be handled"]
388    pub fn new(
389        target: PsiSlice,
390        aux: Array2<f64>,
391        ridge_eps: f64,
392        weight: f64,
393        n_eff: usize,
394        learnable_weight: bool,
395    ) -> Result<Self, String> {
396        if target.is_empty() {
397            return Err("IvaeRidgeMeanGauge::new requires a non-empty target".to_string());
398        }
399        if !(weight.is_finite() && weight > 0.0) {
400            return Err(format!(
401                "IvaeRidgeMeanGauge::new requires finite weight > 0, got {weight}"
402            ));
403        }
404        if !(ridge_eps.is_finite() && ridge_eps > 0.0) {
405            return Err(format!(
406                "IvaeRidgeMeanGauge::new requires finite ridge_eps > 0, got {ridge_eps}"
407            ));
408        }
409        if n_eff == 0 {
410            return Err("IvaeRidgeMeanGauge::new requires n_eff > 0".to_string());
411        }
412        if !target.len().is_multiple_of(n_eff) {
413            return Err(format!(
414                "IvaeRidgeMeanGauge::new target length {} is not divisible by n_eff {}",
415                target.len(),
416                n_eff
417            ));
418        }
419        let latent_dim = target.len() / n_eff;
420        if let Some(expected_dim) = target.latent_dim {
421            let expected = n_eff.checked_mul(expected_dim).ok_or_else(|| {
422                "IvaeRidgeMeanGauge::new target shape overflows usize".to_string()
423            })?;
424            if expected != target.len() {
425                return Err(format!(
426                    "IvaeRidgeMeanGauge::new target length {} does not match n_eff {} × latent_dim {}",
427                    target.len(),
428                    n_eff,
429                    expected_dim
430                ));
431            }
432            if expected_dim != latent_dim {
433                return Err(format!(
434                    "IvaeRidgeMeanGauge::new inferred latent_dim {latent_dim} does not match target latent_dim {expected_dim}"
435                ));
436            }
437        }
438        let (aux_n, aux_dim) = aux.dim();
439        if aux_n != n_eff {
440            return Err(format!(
441                "IvaeRidgeMeanGauge::new aux rows must equal n_eff {n_eff}, got {aux_n}"
442            ));
443        }
444        if aux_dim == 0 {
445            return Err("IvaeRidgeMeanGauge::new requires aux dimension > 0".to_string());
446        }
447        for (idx, &value) in aux.iter().enumerate() {
448            if !value.is_finite() {
449                return Err(format!("IvaeRidgeMeanGauge::new aux[{idx}] must be finite"));
450            }
451        }
452        let mut gram = Array2::<f64>::zeros((aux_dim, aux_dim));
453        for n in 0..n_eff {
454            for i in 0..aux_dim {
455                for j in 0..aux_dim {
456                    gram[[i, j]] += aux[[n, i]] * aux[[n, j]];
457                }
458            }
459        }
460        for i in 0..aux_dim {
461            gram[[i, i]] += ridge_eps;
462        }
463        let ridge_inv = Self::invert_spd_gram(gram)?;
464        Ok(Self {
465            aux,
466            ridge_inv,
467            ridge_eps,
468            weight,
469            n_eff,
470            learnable_weight,
471            rho_index: 0,
472            target,
473            weight_schedule: None,
474        })
475    }
476
477    impl_with_weight_schedule!(weight);
478
479    fn invert_spd_gram(gram: Array2<f64>) -> Result<Array2<f64>, String> {
480        let q = gram.nrows();
481        let (evals, evecs) = gram.eigh(Side::Lower).map_err(|err| {
482            format!("IvaeRidgeMeanGauge::new ridge Gram eigendecomposition failed: {err}")
483        })?;
484        let mut inv = Array2::<f64>::zeros((q, q));
485        for k in 0..q {
486            let eval = evals[k];
487            if !(eval.is_finite() && eval > 0.0) {
488                return Err(format!(
489                    "IvaeRidgeMeanGauge::new ridge Gram must be positive definite; eigenvalue {k} is {eval:.3e}"
490                ));
491            }
492            let inv_eval = 1.0 / eval;
493            for i in 0..q {
494                for j in 0..q {
495                    inv[[i, j]] += evecs[[i, k]] * evecs[[j, k]] * inv_eval;
496                }
497            }
498        }
499        Ok(inv)
500    }
501
502    fn resolved_weight(&self, rho: ArrayView1<'_, f64>) -> f64 {
503        if self.learnable_weight {
504            resolve_learnable_weight(self.weight, rho[self.rho_index])
505        } else {
506            self.weight
507        }
508    }
509
510    fn latent_dim(&self, target_len: usize) -> Option<usize> {
511        if self.n_eff == 0 || !target_len.is_multiple_of(self.n_eff) {
512            assert_eq!(
513                target_len % self.n_eff.max(1),
514                0,
515                "target length must be divisible by n_eff"
516            );
517            return None;
518        }
519        Some(target_len / self.n_eff)
520    }
521
522    fn target_matrix<'a>(&self, target: ArrayView1<'a, f64>) -> Option<ArrayView2<'a, f64>> {
523        let d = self.latent_dim(target.len())?;
524        target.into_shape_with_order((self.n_eff, d)).ok()
525    }
526
527    fn flatten_matrix(m: &Array2<f64>) -> Array1<f64> {
528        let n_obs = m.nrows();
529        let d = m.ncols();
530        let mut out = Array1::<f64>::zeros(n_obs * d);
531        for n in 0..n_obs {
532            for a in 0..d {
533                out[n * d + a] = m[[n, a]];
534            }
535        }
536        out
537    }
538
539    fn projected_matrix(&self, x: ArrayView2<'_, f64>) -> Array2<f64> {
540        let q = self.aux.ncols();
541        let d = x.ncols();
542        let mut u_t_x = Array2::<f64>::zeros((q, d));
543        for n in 0..x.nrows() {
544            for i in 0..q {
545                let u_ni = self.aux[[n, i]];
546                for a in 0..d {
547                    u_t_x[[i, a]] += u_ni * x[[n, a]];
548                }
549            }
550        }
551        let mut coeff = Array2::<f64>::zeros((q, d));
552        for i in 0..q {
553            for j in 0..q {
554                let inv_ij = self.ridge_inv[[i, j]];
555                for a in 0..d {
556                    coeff[[i, a]] += inv_ij * u_t_x[[j, a]];
557                }
558            }
559        }
560        let mut projected = Array2::<f64>::zeros(x.dim());
561        for n in 0..x.nrows() {
562            for i in 0..q {
563                let u_ni = self.aux[[n, i]];
564                for a in 0..d {
565                    projected[[n, a]] += u_ni * coeff[[i, a]];
566                }
567            }
568        }
569        projected
570    }
571
572    fn residual_matrix(&self, x: ArrayView2<'_, f64>) -> Array2<f64> {
573        let projected = self.projected_matrix(x);
574        let mut residual = Array2::<f64>::zeros(x.dim());
575        for n in 0..x.nrows() {
576            for a in 0..x.ncols() {
577                residual[[n, a]] = x[[n, a]] - projected[[n, a]];
578            }
579        }
580        residual
581    }
582
583    pub fn diag_target(
584        &self,
585        target: ArrayView1<'_, f64>,
586        rho: ArrayView1<'_, f64>,
587    ) -> Array1<f64> {
588        let Some(t) = self.target_matrix(target) else {
589            return Array1::<f64>::zeros(target.len());
590        };
591        let weight = self.resolved_weight(rho);
592        let mut out = Array1::<f64>::zeros(target.len());
593        for n in 0..t.nrows() {
594            let mut p_nn = 0.0;
595            for i in 0..self.aux.ncols() {
596                for j in 0..self.aux.ncols() {
597                    p_nn += self.aux[[n, i]] * self.ridge_inv[[i, j]] * self.aux[[n, j]];
598                }
599            }
600            let diag = weight * (1.0 - p_nn);
601            for a in 0..t.ncols() {
602                out[n * t.ncols() + a] = diag;
603            }
604        }
605        out
606    }
607
608    /// Materialize `μ(I - U(UᵀU + εI)⁻¹Uᵀ)` repeated per latent axis.
609    pub fn as_dense(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array2<f64> {
610        let n_total = target.len();
611        let Some(t) = self.target_matrix(target) else {
612            return Array2::<f64>::zeros((n_total, n_total));
613        };
614        let d = t.ncols();
615        let weight = self.resolved_weight(rho);
616        let mut dense = Array2::<f64>::zeros((n_total, n_total));
617        for n in 0..t.nrows() {
618            for m in 0..t.nrows() {
619                let mut p_nm = 0.0;
620                for i in 0..self.aux.ncols() {
621                    for j in 0..self.aux.ncols() {
622                        p_nm += self.aux[[n, i]] * self.ridge_inv[[i, j]] * self.aux[[m, j]];
623                    }
624                }
625                let entry = weight * (if n == m { 1.0 } else { 0.0 } - p_nm);
626                for a in 0..d {
627                    dense[[n * d + a, m * d + a]] = entry;
628                }
629            }
630        }
631        dense
632    }
633}
634
635impl AnalyticPenalty for IvaeRidgeMeanGauge {
636    fn tier(&self) -> PenaltyTier {
637        PenaltyTier::Psi
638    }
639
640    fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
641        let Some(t) = self.target_matrix(target) else {
642            return 0.0;
643        };
644        let residual = self.residual_matrix(t.view());
645        let mut acc = 0.0;
646        for n in 0..t.nrows() {
647            for a in 0..t.ncols() {
648                acc += t[[n, a]] * residual[[n, a]];
649            }
650        }
651        let weight = self.resolved_weight(rho);
652        let mut value = 0.5 * weight * acc;
653        if self.learnable_weight {
654            value -= 0.5 * target.len() as f64 * weight.ln();
655        }
656        value
657    }
658
659    fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
660        let Some(t) = self.target_matrix(target) else {
661            return Array1::<f64>::zeros(target.len());
662        };
663        let weight = self.resolved_weight(rho);
664        let mut grad = self.residual_matrix(t.view());
665        for value in grad.iter_mut() {
666            *value *= weight;
667        }
668        Self::flatten_matrix(&grad)
669    }
670
671    fn hvp(
672        &self,
673        target: ArrayView1<'_, f64>,
674        rho: ArrayView1<'_, f64>,
675        v: ArrayView1<'_, f64>,
676    ) -> Array1<f64> {
677        assert_eq!(target.len(), v.len(), "hvp dimension mismatch");
678        if target.len() != v.len() {
679            return Array1::<f64>::zeros(target.len());
680        }
681        let Some(v_mat) = self.target_matrix(v) else {
682            return Array1::<f64>::zeros(target.len());
683        };
684        let weight = self.resolved_weight(rho);
685        let mut hv = self.residual_matrix(v_mat.view());
686        for value in hv.iter_mut() {
687            *value *= weight;
688        }
689        Self::flatten_matrix(&hv)
690    }
691
692    fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
693        if !self.learnable_weight {
694            return Array1::<f64>::zeros(0);
695        }
696        if self.target_matrix(target).is_none() {
697            return Array1::<f64>::zeros(1);
698        }
699        let mut out = Array1::<f64>::zeros(1);
700        let weight = self.resolved_weight(rho);
701        out[self.rho_index] =
702            self.value(target, rho) + 0.5 * target.len() as f64 * (weight.ln() - 1.0);
703        out
704    }
705
706    impl_learnable_weight_rho_count!();
707
708    fn name(&self) -> &str {
709        "ivae_ridge_mean_gauge"
710    }
711
712    impl_scalar_apply_schedule!(weight);
713}
714
// ---------------------------------------------------------------------------
// Parametric row-precision prior penalty
// ---------------------------------------------------------------------------
718
719/// Parametric zero-mean Gaussian row-precision prior on the latent block.
720///
721/// Uses a diagonal precision
722/// `λ_k(u_n) = exp(log_alpha_k) + softplus(raw_beta_k) ||u_n - μ_k||²`.
723/// REML may learn that conditional precision map, including the Gaussian
724/// precision normalizer derivatives. This is not a learnable conditional
725/// mean map and does not implement the iVAE projection-residual gauge.
726#[derive(Debug, Clone)]
727pub struct ParametricRowPrecisionPriorPenalty {
728    pub aux: Array2<f64>,
729    pub log_alpha: Array1<f64>,
730    pub raw_beta: Array1<f64>,
731    pub mu: Array2<f64>,
732    /// Base strength. If `learnable_weight` is true, the resolved strength is
733    /// `weight * exp(rho[weight_rho_index])`; otherwise it is fixed at `weight`.
734    pub weight: f64,
735    /// Number of rows in the row-major matrix-valued latent block.
736    pub n_eff: usize,
737    pub learnable_weight: bool,
738    pub target: PsiSlice,
739    pub weight_schedule: Option<ScalarWeightSchedule>,
740}
741
742impl ParametricRowPrecisionPriorPenalty {
743    #[must_use = "build error must be handled"]
744    pub fn new(
745        target: PsiSlice,
746        aux: Array2<f64>,
747        log_alpha: Array1<f64>,
748        raw_beta: Array1<f64>,
749        mu: Array2<f64>,
750        weight: f64,
751        n_eff: usize,
752        learnable_weight: bool,
753    ) -> Result<Self, String> {
754        if target.is_empty() {
755            return Err(
756                "ParametricRowPrecisionPriorPenalty::new requires a non-empty target".to_string(),
757            );
758        }
759        if !(weight.is_finite() && weight > 0.0) {
760            return Err(format!(
761                "ParametricRowPrecisionPriorPenalty::new requires finite weight > 0, got {weight}"
762            ));
763        }
764        if n_eff == 0 {
765            return Err("ParametricRowPrecisionPriorPenalty::new requires n_eff > 0".to_string());
766        }
767        if !target.len().is_multiple_of(n_eff) {
768            return Err(format!(
769                "ParametricRowPrecisionPriorPenalty::new target length {} is not divisible by n_eff {}",
770                target.len(),
771                n_eff
772            ));
773        }
774        let latent_dim = target.len() / n_eff;
775        if latent_dim == 0 {
776            return Err(
777                "ParametricRowPrecisionPriorPenalty::new requires latent_dim > 0".to_string(),
778            );
779        }
780        if let Some(expected_dim) = target.latent_dim {
781            let expected = n_eff.checked_mul(expected_dim).ok_or_else(|| {
782                "ParametricRowPrecisionPriorPenalty::new target shape overflows usize".to_string()
783            })?;
784            if expected != target.len() {
785                return Err(format!(
786                    "ParametricRowPrecisionPriorPenalty::new target length {} does not match n_eff {} × latent_dim {}",
787                    target.len(),
788                    n_eff,
789                    expected_dim
790                ));
791            }
792            if expected_dim != latent_dim {
793                return Err(format!(
794                    "ParametricRowPrecisionPriorPenalty::new inferred latent_dim {latent_dim} does not match target latent_dim {expected_dim}"
795                ));
796            }
797        }
798        let (aux_n, aux_dim) = aux.dim();
799        if aux_n != n_eff {
800            return Err(format!(
801                "ParametricRowPrecisionPriorPenalty::new aux rows must equal n_eff {n_eff}, got {aux_n}"
802            ));
803        }
804        if aux_dim == 0 {
805            return Err(
806                "ParametricRowPrecisionPriorPenalty::new requires aux dimension > 0".to_string(),
807            );
808        }
809        if log_alpha.len() != latent_dim {
810            return Err(format!(
811                "ParametricRowPrecisionPriorPenalty::new log_alpha length must equal latent_dim {latent_dim}, got {}",
812                log_alpha.len()
813            ));
814        }
815        if raw_beta.len() != latent_dim {
816            return Err(format!(
817                "ParametricRowPrecisionPriorPenalty::new raw_beta length must equal latent_dim {latent_dim}, got {}",
818                raw_beta.len()
819            ));
820        }
821        let (mu_rows, mu_cols) = mu.dim();
822        if mu_rows != latent_dim || mu_cols != aux_dim {
823            return Err(format!(
824                "ParametricRowPrecisionPriorPenalty::new mu shape must be ({latent_dim}, {aux_dim}), got ({mu_rows}, {mu_cols})"
825            ));
826        }
827        for (idx, &value) in aux.iter().enumerate() {
828            if !value.is_finite() {
829                return Err(format!(
830                    "ParametricRowPrecisionPriorPenalty::new aux[{idx}] must be finite"
831                ));
832            }
833        }
834        for k in 0..latent_dim {
835            let log_alpha_k = log_alpha[k];
836            if !log_alpha_k.is_finite() {
837                return Err(format!(
838                    "ParametricRowPrecisionPriorPenalty::new log_alpha[{k}] must be finite"
839                ));
840            }
841            let alpha_k = log_alpha_k.exp();
842            if !(alpha_k.is_finite() && alpha_k > 0.0) {
843                return Err(format!(
844                    "ParametricRowPrecisionPriorPenalty::new exp(log_alpha[{k}]) must be finite and > 0"
845                ));
846            }
847            let raw_beta_k = raw_beta[k];
848            if !raw_beta_k.is_finite() {
849                return Err(format!(
850                    "ParametricRowPrecisionPriorPenalty::new raw_beta[{k}] must be finite"
851                ));
852            }
853            let beta_k = gam_linalg::utils::stable_softplus(raw_beta_k);
854            if !(beta_k.is_finite() && beta_k >= 0.0) {
855                return Err(format!(
856                    "ParametricRowPrecisionPriorPenalty::new softplus(raw_beta[{k}]) must be finite and >= 0"
857                ));
858            }
859        }
860        for (idx, &value) in mu.iter().enumerate() {
861            if !value.is_finite() {
862                return Err(format!(
863                    "ParametricRowPrecisionPriorPenalty::new mu[{idx}] must be finite"
864                ));
865            }
866        }
867        Ok(Self {
868            aux,
869            log_alpha,
870            raw_beta,
871            mu,
872            weight,
873            n_eff,
874            learnable_weight,
875            target,
876            weight_schedule: None,
877        })
878    }
879
880    impl_with_weight_schedule!(weight);
881
882    fn latent_dim(&self, target_len: usize) -> Option<usize> {
883        if self.n_eff == 0 || !target_len.is_multiple_of(self.n_eff) {
884            assert_eq!(
885                target_len % self.n_eff.max(1),
886                0,
887                "target length must be divisible by n_eff"
888            );
889            return None;
890        }
891        Some(target_len / self.n_eff)
892    }
893
894    fn target_matrix<'a>(&self, target: ArrayView1<'a, f64>) -> Option<ArrayView2<'a, f64>> {
895        let d = self.latent_dim(target.len())?;
896        target.into_shape_with_order((self.n_eff, d)).ok()
897    }
898
899    fn flatten_matrix(m: &Array2<f64>) -> Array1<f64> {
900        let n_obs = m.nrows();
901        let d = m.ncols();
902        let mut out = Array1::<f64>::zeros(n_obs * d);
903        for n in 0..n_obs {
904            for a in 0..d {
905                out[n * d + a] = m[[n, a]];
906            }
907        }
908        out
909    }
910
911    fn log_alpha_offset(&self) -> usize {
912        0
913    }
914
915    fn raw_beta_offset(&self) -> usize {
916        self.log_alpha.len()
917    }
918
919    fn mu_offset(&self) -> usize {
920        self.log_alpha.len() + self.raw_beta.len()
921    }
922
923    fn weight_offset(&self) -> usize {
924        self.mu_offset() + self.mu.len()
925    }
926
927    fn active_log_alpha(&self, k: usize, rho: ArrayView1<'_, f64>) -> f64 {
928        self.log_alpha[k] + rho[self.log_alpha_offset() + k]
929    }
930
931    fn active_raw_beta(&self, k: usize, rho: ArrayView1<'_, f64>) -> f64 {
932        self.raw_beta[k] + rho[self.raw_beta_offset() + k]
933    }
934
935    fn active_mu(&self, k: usize, a: usize, rho: ArrayView1<'_, f64>) -> f64 {
936        self.mu[[k, a]] + rho[self.mu_offset() + k * self.aux.ncols() + a]
937    }
938
939    fn resolved_weight(&self, rho: ArrayView1<'_, f64>) -> f64 {
940        if self.learnable_weight {
941            resolve_learnable_weight(self.weight, rho[self.weight_offset()])
942        } else {
943            self.weight
944        }
945    }
946
947    fn lambda_at(&self, n: usize, k: usize, rho: ArrayView1<'_, f64>) -> f64 {
948        let alpha = stable_exp_log_precision(self.active_log_alpha(k, rho));
949        let beta = gam_linalg::utils::stable_softplus(self.active_raw_beta(k, rho));
950        MIN_CONDITIONAL_PRECISION + alpha + beta * self.dist2(n, k, rho)
951    }
952
953    fn dist2(&self, n: usize, k: usize, rho: ArrayView1<'_, f64>) -> f64 {
954        let mut r2 = 0.0;
955        for a in 0..self.aux.ncols() {
956            let delta = self.aux[[n, a]] - self.active_mu(k, a, rho);
957            r2 += delta * delta;
958        }
959        r2
960    }
961
962    pub fn diag_target(
963        &self,
964        target: ArrayView1<'_, f64>,
965        rho: ArrayView1<'_, f64>,
966    ) -> Array1<f64> {
967        let Some(t) = self.target_matrix(target) else {
968            return Array1::<f64>::zeros(target.len());
969        };
970        let weight = self.resolved_weight(rho);
971        let mut out = Array1::<f64>::zeros(target.len());
972        for n in 0..t.nrows() {
973            for k in 0..t.ncols() {
974                out[n * t.ncols() + k] = weight * self.lambda_at(n, k, rho);
975            }
976        }
977        out
978    }
979
980    /// Materialize the row-block-diagonal Hessian for exact spectral paths.
981    pub fn as_dense(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array2<f64> {
982        let n_total = target.len();
983        let diag = self.diag_target(target, rho);
984        let mut dense = Array2::<f64>::zeros((n_total, n_total));
985        for i in 0..n_total {
986            dense[[i, i]] = diag[i];
987        }
988        dense
989    }
990
991    pub fn log_det_plus_lambda_i(
992        &self,
993        rho: ArrayView1<'_, f64>,
994        lambda: f64,
995    ) -> Result<f64, String> {
996        if !(lambda.is_finite() && lambda > 0.0) {
997            return Err(format!(
998                "ParametricRowPrecisionPriorPenalty::log_det_plus_lambda_i requires finite λ > 0; got {lambda}"
999            ));
1000        }
1001        let weight = self.resolved_weight(rho);
1002        let mut sum = 0.0;
1003        for n in 0..self.n_eff {
1004            for k in 0..self.log_alpha.len() {
1005                let shifted = lambda + weight * self.lambda_at(n, k, rho);
1006                if !(shifted.is_finite() && shifted > 0.0) {
1007                    return Err(format!(
1008                        "ParametricRowPrecisionPriorPenalty::log_det_plus_lambda_i non-positive shifted diagonal {shifted:.3e}"
1009                    ));
1010                }
1011                sum += shifted.ln();
1012            }
1013        }
1014        Ok(sum)
1015    }
1016}
1017
1018impl AnalyticPenalty for ParametricRowPrecisionPriorPenalty {
1019    fn tier(&self) -> PenaltyTier {
1020        PenaltyTier::Psi
1021    }
1022
1023    fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
1024        let Some(t) = self.target_matrix(target) else {
1025            return 0.0;
1026        };
1027        let weight = self.resolved_weight(rho);
1028        let mut quadratic = 0.0;
1029        let mut log_det = 0.0;
1030        for n in 0..t.nrows() {
1031            for k in 0..t.ncols() {
1032                let lambda = self.lambda_at(n, k, rho);
1033                quadratic += lambda * t[[n, k]] * t[[n, k]];
1034                log_det += (weight * lambda).ln();
1035            }
1036        }
1037        0.5 * weight * quadratic - 0.5 * log_det
1038    }
1039
1040    fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
1041        let Some(t) = self.target_matrix(target) else {
1042            return Array1::<f64>::zeros(target.len());
1043        };
1044        let weight = self.resolved_weight(rho);
1045        let mut grad = Array2::<f64>::zeros(t.dim());
1046        for n in 0..t.nrows() {
1047            for k in 0..t.ncols() {
1048                grad[[n, k]] = weight * self.lambda_at(n, k, rho) * t[[n, k]];
1049            }
1050        }
1051        Self::flatten_matrix(&grad)
1052    }
1053
1054    fn hessian_diag(
1055        &self,
1056        target: ArrayView1<'_, f64>,
1057        rho: ArrayView1<'_, f64>,
1058    ) -> Option<Array1<f64>> {
1059        Some(self.diag_target(target, rho))
1060    }
1061
1062    fn hvp(
1063        &self,
1064        target: ArrayView1<'_, f64>,
1065        rho: ArrayView1<'_, f64>,
1066        v: ArrayView1<'_, f64>,
1067    ) -> Array1<f64> {
1068        assert_eq!(target.len(), v.len(), "hvp dimension mismatch");
1069        if target.len() != v.len() {
1070            return Array1::<f64>::zeros(target.len());
1071        }
1072        let diag = self.diag_target(target, rho);
1073        let mut out = Array1::<f64>::zeros(v.len());
1074        for i in 0..v.len() {
1075            out[i] = diag[i] * v[i];
1076        }
1077        out
1078    }
1079
1080    fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
1081        let Some(t) = self.target_matrix(target) else {
1082            return Array1::<f64>::zeros(self.rho_count());
1083        };
1084        let weight = self.resolved_weight(rho);
1085        let mut out = Array1::<f64>::zeros(self.rho_count());
1086        let d = t.ncols();
1087        let du = self.aux.ncols();
1088        let mut grad_weight_direct = 0.0;
1089        for k in 0..d {
1090            let log_alpha = self.active_log_alpha(k, rho);
1091            let alpha = stable_exp_log_precision(log_alpha);
1092            let raw_beta = self.active_raw_beta(k, rho);
1093            let beta = gam_linalg::utils::stable_softplus(raw_beta);
1094            let beta_jac = gam_linalg::utils::stable_logistic(raw_beta);
1095            let mut grad_alpha_direct = 0.0;
1096            let mut grad_beta_direct = 0.0;
1097            let mut grad_mu_direct = vec![0.0_f64; du];
1098            for n in 0..t.nrows() {
1099                let tk = t[[n, k]];
1100                let sq = tk * tk;
1101                let r2 = self.dist2(n, k, rho);
1102                let lambda = alpha + beta * r2;
1103                let precision_score = 0.5 * weight * sq - 0.5 / lambda;
1104                grad_weight_direct += 0.5 * weight * lambda * sq;
1105                grad_alpha_direct += precision_score;
1106                grad_beta_direct += precision_score * r2;
1107                for a in 0..du {
1108                    let delta = self.aux[[n, a]] - self.active_mu(k, a, rho);
1109                    grad_mu_direct[a] += -2.0 * precision_score * beta * delta;
1110                }
1111            }
1112            out[self.log_alpha_offset() + k] = grad_alpha_direct * alpha;
1113            out[self.raw_beta_offset() + k] = grad_beta_direct * beta_jac;
1114            for a in 0..du {
1115                out[self.mu_offset() + k * du + a] = grad_mu_direct[a];
1116            }
1117        }
1118        if self.learnable_weight {
1119            out[self.weight_offset()] = grad_weight_direct - 0.5 * target.len() as f64;
1120        }
1121        out
1122    }
1123
1124    fn rho_count(&self) -> usize {
1125        self.log_alpha.len()
1126            + self.raw_beta.len()
1127            + self.mu.len()
1128            + usize::from(self.learnable_weight)
1129    }
1130
1131    fn name(&self) -> &str {
1132        "parametric_row_precision_prior"
1133    }
1134
1135    impl_scalar_apply_schedule!(weight);
1136}