Skip to main content

gam_models/
vector_response.rs

1//! Vector-valued response support.
2//!
3//! Many smooths sharing one latent: the shape function in the latent-variable
4//! engine maps to a reduced activation vector (tens-to-hundreds of dimensions,
5//! after a random-matrix noise cut). This module defines the response-side
6//! types, the Gaussian vector likelihood, and the connector trait the inner
7//! solver consumes.
8//!
9//! Conventions:
10//! - `Y` is shape `(N, M)`: `N` rows, `M` output dimensions.
11//! - `eta` is shape `(N, M)`: the linear predictor with one column per output.
12//! - For Gaussian identity-link, mean(η) = η, so the likelihood depends only
13//!   on `eta` and `Y`.
14//!
15//! The Hessian is block-structured: per-row (N independent blocks for the
16//! Gaussian case), each of size `(M, M)`. For a Gaussian likelihood with
17//! Diagonal/Isotropic noise this per-row block is itself diagonal — exactly
18//! what the arrow Schur elimination in `solver/arrow_schur.rs` consumes.
19
20use crate::model_types::EstimationError;
21use ndarray::{Array1, Array2, Array3, ArrayView2};
22
/// Per-output noise model for a vector response.
///
/// The variants are parameterised differently: `Isotropic` and `Diagonal`
/// carry standard deviations σ, which [`VectorNoise::diag_precision`] turns
/// into precisions `1/σ²`, whereas `LowRank` carries a precision directly.
///
/// `LowRank` stores the symmetric structured precision
/// `W = diag(diag) + U Uᵀ`, with `factor` holding `U`. The vector likelihood
/// consumes the owned arrays directly; PIRLS low-rank Gram assembly is handled
/// by `gam_linalg::low_rank_weight::LowRankWeight` and
/// `gam_solve::pirls`.
#[derive(Clone, Debug)]
pub enum VectorNoise {
    /// Shared σ across all M outputs: Σ = σ² I_M.
    Isotropic(f64),
    /// Per-output σ_m: Σ = diag(σ_m²).
    Diagonal(Array1<f64>),
    /// Symmetric structured form `W = diag(diag) + factor · factorᵀ`.
    LowRank {
        /// Diagonal piece of the precision; `diag_precision` requires length
        /// M with every entry finite and strictly positive.
        diag: Array1<f64>,
        /// Low-rank factor `U`; `GaussianVectorLikelihood::from_target`
        /// requires M rows and finite entries.
        factor: Array2<f64>,
    },
}
42
43impl VectorNoise {
44    /// Per-output precision vector (1/σ_m²) for the Isotropic / Diagonal cases.
45    /// LowRank returns the diagonal piece only; the low-rank correction is
46    /// applied separately by the Piece 5 weight code.
47    pub fn diag_precision(&self, m: usize) -> Result<Array1<f64>, EstimationError> {
48        match self {
49            Self::Isotropic(sigma) => {
50                if !sigma.is_finite() || *sigma <= 0.0 {
51                    crate::bail_invalid_estim!(
52                        "VectorNoise::Isotropic: σ must be > 0 and finite (got {sigma})",
53                    );
54                }
55                let p = 1.0 / (sigma * sigma);
56                Ok(Array1::from_elem(m, p))
57            }
58            Self::Diagonal(sigma) => {
59                if sigma.len() != m {
60                    crate::bail_invalid_estim!(
61                        "VectorNoise::Diagonal: σ length {} ≠ M={m}",
62                        sigma.len()
63                    );
64                }
65                let mut out = Array1::<f64>::zeros(m);
66                for j in 0..m {
67                    let s = sigma[j];
68                    if !s.is_finite() || s <= 0.0 {
69                        crate::bail_invalid_estim!(
70                            "VectorNoise::Diagonal: σ[{j}] must be > 0 and finite (got {s})",
71                        );
72                    }
73                    out[j] = 1.0 / (s * s);
74                }
75                Ok(out)
76            }
77            Self::LowRank { diag, .. } => {
78                if diag.len() != m {
79                    crate::bail_invalid_estim!(
80                        "VectorNoise::LowRank: diag length {} ≠ M={m}",
81                        diag.len()
82                    );
83                }
84                let mut out = Array1::<f64>::zeros(m);
85                for j in 0..m {
86                    let d = diag[j];
87                    if !d.is_finite() || d <= 0.0 {
88                        crate::bail_invalid_estim!(
89                            "VectorNoise::LowRank: diag[{j}] must be > 0 (got {d})",
90                        );
91                    }
92                    // `diag` is the PRECISION diagonal (W = diag(d) + F·Fᵀ).
93                    // Pass it through unchanged.
94                    out[j] = d;
95                }
96                Ok(out)
97            }
98        }
99    }
100}
101
/// Vector-valued response target.
///
/// `y` is `(N, M)`; `row_weights` (if present) is length `N` and scales the
/// per-row contribution to the likelihood (e.g. observation weights from a
/// re-sampling or inverse-probability scheme).
///
/// The fields are public, so a value can be assembled without going through
/// [`VectorResponseTarget::with_row_weights`]; for that reason
/// `GaussianVectorLikelihood::from_target` re-validates `row_weights`.
#[derive(Clone, Debug)]
pub struct VectorResponseTarget {
    /// shape (N, M) — N rows × M output dimensions.
    pub y: Array2<f64>,
    /// per-output noise (or shared scalar).
    pub noise: VectorNoise,
    /// optional row weights (N,); `None` means every row has weight 1.
    pub row_weights: Option<Array1<f64>>,
}
116
117impl VectorResponseTarget {
118    pub fn new(y: Array2<f64>, noise: VectorNoise) -> Self {
119        Self {
120            y,
121            noise,
122            row_weights: None,
123        }
124    }
125
126    pub fn with_row_weights(mut self, w: Array1<f64>) -> Result<Self, EstimationError> {
127        validate_row_weights(&w, self.y.nrows())?;
128        self.row_weights = Some(w);
129        Ok(self)
130    }
131
132    pub fn n(&self) -> usize {
133        self.y.nrows()
134    }
135    pub fn m(&self) -> usize {
136        self.y.ncols()
137    }
138}
139
/// Absolute tolerance on the per-row simplex constraint `Σ_c y_{n,c} = 1`.
///
/// [`validate_multinomial_simplex`] rejects a row when
/// `|Σ_c y_{n,c} − 1|` exceeds this value. Because the target mass is exactly
/// 1, the absolute and relative deviations coincide.
///
/// The multinomial-logit log-likelihood `ℓ = Σ_c y_c log p_c` has the
/// canonical residual gradient `y_a − p_a` and Fisher block
/// `p_a δ_{ab} − p_a p_b` **only** when each target row is a probability
/// vector (`y_c ≥ 0`, `Σ_c y_c = 1`). For a general row mass `s = Σ_c y_c`
/// the true derivatives are `y_a − s p_a` and `s (p_a δ_{ab} − p_a p_b)`, so
/// any row whose mass deviates from 1 makes the implemented gradient/Hessian
/// disagree with the implemented objective. We therefore require simplex rows
/// at every construction boundary and reject anything else, rather than
/// silently fitting with inconsistent curvature. The tolerance absorbs only
/// floating-point round-off in an otherwise-exact one-hot / label-smoothed
/// row (e.g. a sum of `K` rationals), not genuine count or proportional data.
pub(crate) const MULTINOMIAL_SIMPLEX_TOL: f64 = 1.0e-9;
154
155/// Validate that every row of a multinomial target `y ∈ ℝ^{N×K}` is a point on
156/// the probability simplex: `y_{n,c} ≥ 0` for all entries and
157/// `Σ_c y_{n,c} = 1` for every row (up to [`MULTINOMIAL_SIMPLEX_TOL`]). This
158/// is the precondition under which [`MultinomialLogitLikelihood`]'s residual
159/// gradient and Fisher block are the exact derivatives of its log-likelihood;
160/// see the constant's docs. Finiteness is checked first so the message points
161/// at the offending entry rather than at a NaN-poisoned row sum.
162pub(crate) fn validate_multinomial_simplex(
163    y: ArrayView2<f64>,
164    context: &str,
165) -> Result<(), EstimationError> {
166    let (n, k) = y.dim();
167    for row in 0..n {
168        let mut row_sum = 0.0_f64;
169        for c in 0..k {
170            let v = y[[row, c]];
171            if !v.is_finite() {
172                crate::bail_invalid_estim!("{context}: y[{row},{c}] must be finite (got {v})");
173            }
174            if v < 0.0 {
175                crate::bail_invalid_estim!(
176                    "{context}: multinomial target must be a probability vector \
177                     (y_c ≥ 0); got y[{row},{c}] = {v}"
178                );
179            }
180            row_sum += v;
181        }
182        if (row_sum - 1.0).abs() > MULTINOMIAL_SIMPLEX_TOL {
183            crate::bail_invalid_estim!(
184                "{context}: multinomial target rows must sum to 1 (one-hot for \
185                 hard labels, or a label-smoothed probability vector); row {row} \
186                 sums to {row_sum}. The softmax residual gradient y_a − p_a and \
187                 Fisher block p_a δ_ab − p_a p_b are the derivatives of \
188                 Σ_c y_c log p_c only when the row mass is 1."
189            );
190        }
191    }
192    Ok(())
193}
194
195fn validate_row_weights(weights: &Array1<f64>, n: usize) -> Result<(), EstimationError> {
196    if weights.len() != n {
197        crate::bail_invalid_estim!("row_weights length {} ≠ N={n}", weights.len());
198    }
199    for (idx, weight) in weights.iter().copied().enumerate() {
200        if !(weight.is_finite() && weight >= 0.0) {
201            crate::bail_invalid_estim!(
202                "row_weights[{idx}] must be finite and non-negative (got {weight})"
203            );
204        }
205    }
206    Ok(())
207}
208
209/// Connector trait the inner solver (Piece 1) plugs into.
210///
211/// `eta` is the `(N, M)` linear predictor; `y` is the `(N, M)` target. The
212/// implementation is responsible for any link inversion. The `hess_diag`
213/// return is the per-element diagonal of the per-row Hessian block; for a
214/// Diagonal-noise Gaussian this is exactly `(N, M)` of per-output precisions.
215pub trait VectorLikelihood {
216    /// log p(Y | η).
217    fn log_lik(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> f64;
218
219    /// ∂ log p(Y | η) / ∂ η, shape (N, M).
220    fn grad_eta(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> Array2<f64>;
221
222    /// Diagonal of the per-row Hessian −∂² log p / ∂ η ∂ η, shape (N, M).
223    /// This is the per-row block consumed by `solver/arrow_schur.rs`.
224    fn hess_diag(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> Array2<f64>;
225
226    /// Per-row dense Hessian block −∂² log p / ∂η_a ∂η_b, shape (N, M, M).
227    ///
228    /// Default implementation lifts [`Self::hess_diag`] onto the per-row
229    /// diagonal, valid only when the per-row Hessian is genuinely diagonal
230    /// across outputs (e.g. Gaussian with Isotropic/Diagonal noise).
231    /// Likelihoods with off-diagonal output coupling must override this:
232    /// [`GaussianVectorLikelihood`] with a low-rank precision factor `F`
233    /// (block `w·(diag(precision) + F·Fᵀ)`, off-diagonals `w·Σ_k F[a,k]·F[b,k]`)
234    /// and multinomial-logit (per-row Fisher block `p_a (δ_ab − p_b)`).
235    ///
236    /// The returned array is consumed by
237    /// [`gam_solve::pirls::dense_block_xtwx`] /
238    /// [`gam_solve::pirls::dense_block_xtwy`] to build `XᵀWX` and `XᵀWy`
239    /// for vector-response IRLS in output-major coefficient ordering.
240    fn hess_block(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> Array3<f64> {
241        let diag = self.hess_diag(eta, y);
242        let (n, m) = diag.dim();
243        let mut out = Array3::<f64>::zeros((n, m, m));
244        for row in 0..n {
245            for j in 0..m {
246                out[[row, j, j]] = diag[[row, j]];
247            }
248        }
249        out
250    }
251}
252
/// Gaussian vector likelihood with identity link.
///
/// `log p(Y|η) = −½ Σ_n w_n · rᵀ W r` where `r = Y_n − η_n` and `W` is the
/// per-output **precision** matrix. For Isotropic / Diagonal `W = diag(prec)`;
/// for `LowRank` it is `W = diag(prec) + F · Fᵀ`, with `F` carried alongside
/// the diagonal here.
///
/// (Up to the constant log-determinant of the noise covariance, dropped here
/// because it does not depend on β or the latent t; the determinant is
/// accounted for in the REML score, not the inner likelihood.)
///
/// `from_target` validates all three fields. Because the fields are public, a
/// value built by hand bypasses that validation.
#[derive(Clone, Debug)]
pub struct GaussianVectorLikelihood {
    /// Per-output diagonal precision (length M). For Isotropic / Diagonal /
    /// LowRank this is the diagonal piece of the precision matrix
    /// (`1/σ_m²` for Diagonal/Isotropic; `diag` for LowRank).
    pub precision: Array1<f64>,
    /// Optional dense rank-r factor `F` of size `(M, r)` such that the full
    /// per-row precision is `diag(precision) + F · Fᵀ`. `None` for the
    /// Isotropic / Diagonal cases.
    pub factor: Option<Array2<f64>>,
    /// Optional row weights (length N), or None for uniform (weight 1.0 per
    /// row).
    pub row_weights: Option<Array1<f64>>,
}
276
277impl GaussianVectorLikelihood {
278    pub fn from_target(target: &VectorResponseTarget) -> Result<Self, EstimationError> {
279        if let Some(weights) = target.row_weights.as_ref() {
280            validate_row_weights(weights, target.n())?;
281        }
282        let precision = target.noise.diag_precision(target.m())?;
283        let factor = match &target.noise {
284            VectorNoise::LowRank { factor, .. } => {
285                if factor.nrows() != target.m() {
286                    crate::bail_invalid_estim!(
287                        "VectorNoise::LowRank: factor has {} rows but M={}",
288                        factor.nrows(),
289                        target.m()
290                    );
291                }
292                for ((row, col), value) in factor.indexed_iter() {
293                    if !value.is_finite() {
294                        crate::bail_invalid_estim!(
295                            "VectorNoise::LowRank: factor[{row},{col}] must be finite (got {value})"
296                        );
297                    }
298                }
299                Some(factor.clone())
300            }
301            _ => None,
302        };
303        Ok(Self {
304            precision,
305            factor,
306            row_weights: target.row_weights.clone(),
307        })
308    }
309
310    #[inline]
311    fn row_weight(&self, n: usize) -> f64 {
312        self.row_weights.as_ref().map_or(1.0, |w| w[n])
313    }
314}
315
316impl VectorLikelihood for GaussianVectorLikelihood {
317    fn log_lik(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> f64 {
318        assert_eq!(eta.dim(), y.dim());
319        assert_eq!(eta.ncols(), self.precision.len());
320        let m = eta.ncols();
321        let rank = self.factor.as_ref().map_or(0, |f| f.ncols());
322        let mut acc = 0.0;
323        // Scratch buffer for Fᵀ r (length rank), reused across rows.
324        let mut ftr = vec![0.0f64; rank];
325        for n in 0..eta.nrows() {
326            let w = self.row_weight(n);
327            // Diagonal part: Σ_m d_m r_m²
328            let mut row_acc = 0.0;
329            for j in 0..m {
330                let r = y[[n, j]] - eta[[n, j]];
331                row_acc += self.precision[j] * r * r;
332            }
333            // Low-rank part: ||Fᵀ r||²
334            if let Some(f) = self.factor.as_ref() {
335                for k in 0..rank {
336                    ftr[k] = 0.0;
337                }
338                for j in 0..m {
339                    let r = y[[n, j]] - eta[[n, j]];
340                    for k in 0..rank {
341                        ftr[k] += f[[j, k]] * r;
342                    }
343                }
344                for k in 0..rank {
345                    row_acc += ftr[k] * ftr[k];
346                }
347            }
348            acc += w * row_acc;
349        }
350        -0.5 * acc
351    }
352
353    fn grad_eta(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> Array2<f64> {
354        assert_eq!(eta.dim(), y.dim());
355        let (n_rows, n_cols) = eta.dim();
356        let rank = self.factor.as_ref().map_or(0, |f| f.ncols());
357        let mut out = Array2::<f64>::zeros((n_rows, n_cols));
358        let mut ftr = vec![0.0f64; rank];
359        for n in 0..n_rows {
360            let w = self.row_weight(n);
361            // Diagonal part: w · d_m · (y − η)_m
362            for j in 0..n_cols {
363                out[[n, j]] = w * self.precision[j] * (y[[n, j]] - eta[[n, j]]);
364            }
365            // Low-rank part: + w · F (Fᵀ r) for r = y − η
366            if let Some(f) = self.factor.as_ref() {
367                for k in 0..rank {
368                    ftr[k] = 0.0;
369                }
370                for j in 0..n_cols {
371                    let r = y[[n, j]] - eta[[n, j]];
372                    for k in 0..rank {
373                        ftr[k] += f[[j, k]] * r;
374                    }
375                }
376                for j in 0..n_cols {
377                    let mut s = 0.0;
378                    for k in 0..rank {
379                        s += f[[j, k]] * ftr[k];
380                    }
381                    out[[n, j]] += w * s;
382                }
383            }
384        }
385        out
386    }
387
388    fn hess_diag(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> Array2<f64> {
389        assert_eq!(eta.dim(), y.dim());
390        // Diagonal of −∂² log p / ∂η² = w · diag(diag(d) + F·Fᵀ); the diagonal
391        // of (F·Fᵀ) at output m is Σ_k F[m, k]². This is the diagonal
392        // *preconditioner* only — the off-diagonal cross terms F[a, k]·F[b, k]
393        // are carried by the full per-row block in [`Self::hess_block`] (which
394        // this type overrides whenever `factor` is present). Callers that need
395        // the true Hessian must use `hess_block`, not this diagonal.
396        let (n_rows, n_cols) = eta.dim();
397        let mut out = Array2::<f64>::zeros((n_rows, n_cols));
398        // Pre-compute Σ_k F[m, k]² per output m (independent of n).
399        let f_row_sqsum: Option<Array1<f64>> = self.factor.as_ref().map(|f| {
400            let m = f.nrows();
401            let r = f.ncols();
402            let mut s = Array1::<f64>::zeros(m);
403            for j in 0..m {
404                let mut acc = 0.0;
405                for k in 0..r {
406                    let v = f[[j, k]];
407                    acc += v * v;
408                }
409                s[j] = acc;
410            }
411            s
412        });
413        for n in 0..n_rows {
414            let w = self.row_weight(n);
415            for j in 0..n_cols {
416                let mut d = self.precision[j];
417                if let Some(s) = f_row_sqsum.as_ref() {
418                    d += s[j];
419                }
420                out[[n, j]] = w * d;
421            }
422        }
423        out
424    }
425
426    fn hess_block(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> Array3<f64> {
427        // Per-row dense block −∂² log p / ∂η_a ∂η_b. With log-likelihood
428        //     ℓ = −½ Σ_n w_n · rₙᵀ W rₙ,   r = y − η,   W = diag(precision) + F·Fᵀ,
429        // the gradient is wₙ · W rₙ and the negative Hessian block is exactly
430        //     H_{n,a,b} = w_n · ( precision_a · δ_ab + Σ_k F[a,k] · F[b,k] ).
431        // This is the true second derivative of `log_lik` (it differentiates
432        // `grad_eta` exactly); the diagonal-only trait default would drop the
433        // F·Fᵀ cross terms F[a,k]·F[b,k] for a ≠ b, so it must be overridden
434        // whenever a low-rank factor is present.
435        assert_eq!(eta.dim(), y.dim());
436        assert_eq!(eta.ncols(), self.precision.len());
437        let (n_rows, m) = eta.dim();
438        let rank = self.factor.as_ref().map_or(0, |f| f.ncols());
439
440        // Per-output Gram of the low-rank factor, G_{a,b} = Σ_k F[a,k]·F[b,k].
441        // Independent of the row n, so assemble once and scale by w_n.
442        let gram: Option<Array2<f64>> = self.factor.as_ref().map(|f| {
443            let mut g = Array2::<f64>::zeros((m, m));
444            for a in 0..m {
445                for b in a..m {
446                    let mut acc = 0.0;
447                    for k in 0..rank {
448                        acc += f[[a, k]] * f[[b, k]];
449                    }
450                    g[[a, b]] = acc;
451                    g[[b, a]] = acc;
452                }
453            }
454            g
455        });
456
457        let mut out = Array3::<f64>::zeros((n_rows, m, m));
458        for n in 0..n_rows {
459            let w = self.row_weight(n);
460            for a in 0..m {
461                for b in 0..m {
462                    let mut val = if a == b { self.precision[a] } else { 0.0 };
463                    if let Some(g) = gram.as_ref() {
464                        val += g[[a, b]];
465                    }
466                    out[[n, a, b]] = w * val;
467                }
468            }
469        }
470        out
471    }
472}
473
474// ─────────────────────────────────────────────────────────────────────────────
475// Piece 5 / Piece 1 row-block support
476// ─────────────────────────────────────────────────────────────────────────────
477
/// Multinomial-logit (softmax) likelihood with explicit reference class.
///
/// Conventions:
/// - `K` is the total number of classes; the linear predictor has `M = K - 1`
///   columns corresponding to the *active* classes. Class `K - 1` is the
///   reference class with η_{K-1} ≡ 0 (so the gauge is fixed by construction
///   and no additional sum-to-zero projection is required at the η level).
/// - `y` is the categorical response with shape `(N, K)`. Each row must be a
///   point on the probability simplex (`y_c ≥ 0`, `Σ_c y_c = 1`): a one-hot
///   indicator for hard-label classification, or a label-smoothed probability
///   vector. The row *weight* `w_n` scales the whole row's likelihood
///   contribution and is independent of the row mass — it is **not** the row
///   sum. Callers enforce the simplex precondition via
///   [`validate_multinomial_simplex`] at every construction boundary; under it
///   the residual gradient `y_a − p_a` and Fisher block `p_a δ_ab − p_a p_b`
///   below are the exact derivatives of the log-likelihood `Σ_c y_c log p_c`.
/// - `eta` is the active linear predictor with shape `(N, M = K - 1)`.
///
/// Softmax with baseline:
/// ```text
///     p_a   = exp(η_a) / (1 + Σ_b exp(η_b))           for a ∈ [0, K-1)
///     p_{K-1} = 1 / (1 + Σ_b exp(η_b))
/// ```
///
/// Log-likelihood (rows with weight `w_n`, default 1.0). The two forms agree
/// under the simplex precondition `Σ_c y_{n,c} = 1`:
/// ```text
///     log L = Σ_n w_n · ( Σ_{a < K-1} y_{n,a} · η_{n,a} − log(1 + Σ_b exp(η_{n,b})) )
///           = Σ_n w_n · Σ_{c ∈ [0, K)} y_{n,c} · log p_{n,c}
/// ```
///
/// Per-row gradient w.r.t. the active η is the canonical Bernoulli/softmax
/// residual:
/// ```text
///     ∂ log L / ∂η_{n,a} = w_n · (y_{n,a} − p_{n,a})       for a ∈ [0, K-1)
/// ```
///
/// Per-row Fisher (= observed, since logit is canonical for the multinomial)
/// information block, shape `(M, M)`:
/// ```text
///     H_{n,a,b} = w_n · ( p_{n,a} · δ_{ab} − p_{n,a} · p_{n,b} )
/// ```
///
/// This is the standard reference-coded multinomial-logit GLM. The dense
/// per-row block flows through [`VectorLikelihood::hess_block`] into
/// [`gam_solve::pirls::dense_block_xtwx`], which builds the stacked
/// `XᵀWX` in output-major coefficient ordering `β = [β_0; β_1; …; β_{K-2}]`
/// with each per-class block of size `(P, P)`.
#[derive(Clone, Debug)]
pub struct MultinomialLogitLikelihood {
    /// Number of active classes `M = K − 1`. Cached for shape checks.
    pub active_classes: usize,
    /// Optional row weights (length N), or `None` for uniform 1.0.
    /// `with_row_weights` checks the entries but cannot check the length
    /// against N, since N is not known to this type at construction.
    pub row_weights: Option<Array1<f64>>,
}
532
533impl MultinomialLogitLikelihood {
534    /// Construct from the total number of classes `K ≥ 2`.
535    pub fn with_classes(total_classes: usize) -> Result<Self, EstimationError> {
536        if total_classes < 2 {
537            crate::bail_invalid_estim!(
538                "MultinomialLogitLikelihood requires K ≥ 2 classes (got {total_classes})"
539            );
540        }
541        Ok(Self {
542            active_classes: total_classes - 1,
543            row_weights: None,
544        })
545    }
546
547    /// Attach per-row weights (length N, finite and non-negative).
548    pub fn with_row_weights(mut self, w: Array1<f64>) -> Result<Self, EstimationError> {
549        validate_row_weights(&w, w.len())?;
550        self.row_weights = Some(w);
551        Ok(self)
552    }
553
554    /// Total class count `K = M + 1`.
555    #[inline]
556    pub fn total_classes(&self) -> usize {
557        self.active_classes + 1
558    }
559
560    #[inline]
561    fn row_weight(&self, n: usize) -> f64 {
562        self.row_weights.as_ref().map_or(1.0, |w| w[n])
563    }
564
565    /// Numerically-stable softmax with implicit reference column (η_{K-1} = 0).
566    ///
567    /// Writes `K` probabilities into `out` (length `M + 1`). The shift uses
568    /// `max(0, max(eta_active))` so the reference class is included in the
569    /// max and the denominator stays bounded. This is the canonical
570    /// reference implementation; the FFI surface and any direct
571    /// matrix-free callers route through this method rather than carrying
572    /// their own softmax.
573    pub fn softmax_with_baseline(eta_active: &[f64], out: &mut [f64]) {
574        assert_eq!(out.len(), eta_active.len() + 1);
575        let mut max_eta = 0.0_f64;
576        for &v in eta_active {
577            if v > max_eta {
578                max_eta = v;
579            }
580        }
581        let baseline = (-max_eta).exp();
582        let mut denom = baseline;
583        for (idx, &v) in eta_active.iter().enumerate() {
584            let e = (v - max_eta).exp();
585            out[idx] = e;
586            denom += e;
587        }
588        for v in out.iter_mut().take(eta_active.len()) {
589            *v /= denom;
590        }
591        out[eta_active.len()] = baseline / denom;
592    }
593
594    /// Convenience: compute the full (N, K) probability matrix from
595    /// (N, K-1) active linear predictor. This is the multinomial inverse
596    /// link used by prediction.
597    pub fn probabilities(&self, eta: ArrayView2<f64>) -> Array2<f64> {
598        let n = eta.nrows();
599        let m = self.active_classes;
600        assert_eq!(eta.ncols(), m, "η must have K-1 columns");
601        let k = self.total_classes();
602        let mut probs = Array2::<f64>::zeros((n, k));
603        let mut eta_row = vec![0.0_f64; m];
604        let mut probs_row = vec![0.0_f64; k];
605        for row in 0..n {
606            for j in 0..m {
607                eta_row[j] = eta[[row, j]];
608            }
609            Self::softmax_with_baseline(&eta_row, &mut probs_row);
610            for j in 0..k {
611                probs[[row, j]] = probs_row[j];
612            }
613        }
614        probs
615    }
616}
617
618impl VectorLikelihood for MultinomialLogitLikelihood {
619    fn log_lik(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> f64 {
620        let n = eta.nrows();
621        let m = self.active_classes;
622        let k = self.total_classes();
623        assert_eq!(eta.ncols(), m, "η must have K-1 columns");
624        assert_eq!(y.dim(), (n, k), "y must be (N, K) one-hot encoded");
625        let mut acc = 0.0_f64;
626        let mut eta_row = vec![0.0_f64; m];
627        let mut probs_row = vec![0.0_f64; k];
628        for row in 0..n {
629            let w = self.row_weight(row);
630            for j in 0..m {
631                eta_row[j] = eta[[row, j]];
632            }
633            Self::softmax_with_baseline(&eta_row, &mut probs_row);
634            let mut row_acc = 0.0_f64;
635            for c in 0..k {
636                let yc = y[[row, c]];
637                if yc != 0.0 {
638                    // Guard against log(0) when p underflows; clamp the
639                    // probability away from zero by 1e-300 — outside the
640                    // representable range, the residual still drives the
641                    // gradient correctly.
642                    let p = probs_row[c].max(1.0e-300);
643                    row_acc += yc * p.ln();
644                }
645            }
646            acc += w * row_acc;
647        }
648        acc
649    }
650
651    fn grad_eta(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> Array2<f64> {
652        let n = eta.nrows();
653        let m = self.active_classes;
654        let k = self.total_classes();
655        assert_eq!(eta.ncols(), m, "η must have K-1 columns");
656        assert_eq!(y.dim(), (n, k), "y must be (N, K) one-hot encoded");
657        let mut out = Array2::<f64>::zeros((n, m));
658        let mut eta_row = vec![0.0_f64; m];
659        let mut probs_row = vec![0.0_f64; k];
660        for row in 0..n {
661            let w = self.row_weight(row);
662            for j in 0..m {
663                eta_row[j] = eta[[row, j]];
664            }
665            Self::softmax_with_baseline(&eta_row, &mut probs_row);
666            for j in 0..m {
667                out[[row, j]] = w * (y[[row, j]] - probs_row[j]);
668            }
669        }
670        out
671    }
672
673    fn hess_diag(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> Array2<f64> {
674        // Per-row diagonal of the (M, M) Fisher block:
675        //     H_{n,a,a} = w_n · p_{n,a} · (1 − p_{n,a})
676        // Provided for callers that explicitly want the diagonal-only
677        // preconditioner; the joint dense block ships through `hess_block`.
678        let n = eta.nrows();
679        let m = self.active_classes;
680        let k = self.total_classes();
681        assert_eq!(eta.ncols(), m, "η must have K-1 columns");
682        assert_eq!(y.dim(), (n, k), "y must be (N, K) one-hot encoded");
683        let mut out = Array2::<f64>::zeros((n, m));
684        let mut eta_row = vec![0.0_f64; m];
685        let mut probs_row = vec![0.0_f64; k];
686        for row in 0..n {
687            let w = self.row_weight(row);
688            for j in 0..m {
689                eta_row[j] = eta[[row, j]];
690            }
691            Self::softmax_with_baseline(&eta_row, &mut probs_row);
692            for j in 0..m {
693                let p = probs_row[j];
694                out[[row, j]] = w * p * (1.0 - p);
695            }
696        }
697        out
698    }
699
700    fn hess_block(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> Array3<f64> {
701        // Per-row dense (M, M) Fisher / observed-information block:
702        //     H_{n,a,b} = w_n · ( p_{n,a} · δ_{ab} − p_{n,a} · p_{n,b} )
703        let n = eta.nrows();
704        let m = self.active_classes;
705        let k = self.total_classes();
706        assert_eq!(eta.ncols(), m, "η must have K-1 columns");
707        assert_eq!(y.dim(), (n, k), "y must be (N, K) one-hot encoded");
708        let mut out = Array3::<f64>::zeros((n, m, m));
709        let mut eta_row = vec![0.0_f64; m];
710        let mut probs_row = vec![0.0_f64; k];
711        for row in 0..n {
712            let w = self.row_weight(row);
713            for j in 0..m {
714                eta_row[j] = eta[[row, j]];
715            }
716            Self::softmax_with_baseline(&eta_row, &mut probs_row);
717            for a in 0..m {
718                let pa = probs_row[a];
719                out[[row, a, a]] = w * pa * (1.0 - pa);
720                for b in (a + 1)..m {
721                    let off = -w * pa * probs_row[b];
722                    out[[row, a, b]] = off;
723                    out[[row, b, a]] = off;
724                }
725            }
726        }
727        out
728    }
729}
730
731#[cfg(test)]
732mod tests {
733    use super::*;
734    use ndarray::{Array1, Array2};
735
736    // Macro (not fn) so the assertion / panic tokens are inlined into each
737    // caller's test body, satisfying the build.rs scanner that looks for
738    // `assert!(` / `panic!(` directly in the `#[test]` function.
739    macro_rules! expect_invalid_input {
740        ($result:expr, $needle:expr $(,)?) => {{
741            let needle: &str = $needle;
742            match $result {
743                Ok(_) => {
744                    panic!("expected EstimationError::InvalidInput containing `{needle}`, got Ok")
745                }
746                Err(EstimationError::InvalidInput(msg)) => {
747                    assert!(
748                        msg.contains(needle),
749                        "InvalidInput message `{msg}` does not contain `{needle}`"
750                    );
751                    msg
752                }
753                Err(other) => panic!(
754                    "expected EstimationError::InvalidInput containing `{needle}`, got {other:?}"
755                ),
756            }
757        }};
758    }
759
760    fn dummy_target(n: usize, m: usize) -> VectorResponseTarget {
761        VectorResponseTarget::new(Array2::<f64>::zeros((n, m)), VectorNoise::Isotropic(1.0))
762    }
763
764    #[test]
765    fn with_row_weights_rejects_wrong_length() {
766        let target = dummy_target(4, 2);
767        let weights = Array1::from(vec![1.0, 1.0, 1.0]);
768        expect_invalid_input!(target.with_row_weights(weights), "row_weights length");
769    }
770
771    #[test]
772    fn with_row_weights_rejects_negative_entry() {
773        let target = dummy_target(3, 2);
774        let weights = Array1::from(vec![1.0, -0.5, 2.0]);
775        expect_invalid_input!(
776            target.with_row_weights(weights),
777            "must be finite and non-negative",
778        );
779    }
780
781    #[test]
782    fn with_row_weights_rejects_nan_entry() {
783        let target = dummy_target(3, 2);
784        let weights = Array1::from(vec![1.0, f64::NAN, 2.0]);
785        expect_invalid_input!(
786            target.with_row_weights(weights),
787            "must be finite and non-negative",
788        );
789    }
790
791    #[test]
792    fn with_row_weights_rejects_infinite_entry() {
793        let target = dummy_target(3, 2);
794        let weights = Array1::from(vec![1.0, f64::INFINITY, 2.0]);
795        expect_invalid_input!(
796            target.with_row_weights(weights),
797            "must be finite and non-negative",
798        );
799    }
800
801    #[test]
802    fn with_row_weights_accepts_zero_and_positive() {
803        let target = dummy_target(3, 2);
804        let weights = Array1::from(vec![0.0, 1.5, 3.0]);
805        let weighted = target
806            .with_row_weights(weights)
807            .expect("zero / positive weights should be accepted");
808        assert!(weighted.row_weights.is_some());
809    }
810
811    #[test]
812    fn from_target_rejects_low_rank_factor_with_wrong_row_count() {
813        let n = 4;
814        let m = 3;
815        // factor has 2 rows instead of M = 3.
816        let factor = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
817        let target = VectorResponseTarget::new(
818            Array2::<f64>::zeros((n, m)),
819            VectorNoise::LowRank {
820                diag: Array1::from(vec![1.0; m]),
821                factor,
822            },
823        );
824        expect_invalid_input!(GaussianVectorLikelihood::from_target(&target), "factor has",);
825    }
826
827    #[test]
828    fn from_target_rejects_non_finite_low_rank_factor_entry() {
829        let n = 4;
830        let m = 3;
831        let mut factor = Array2::<f64>::zeros((m, 2));
832        factor[[1, 0]] = f64::NAN;
833        let target = VectorResponseTarget::new(
834            Array2::<f64>::zeros((n, m)),
835            VectorNoise::LowRank {
836                diag: Array1::from(vec![1.0; m]),
837                factor,
838            },
839        );
840        expect_invalid_input!(
841            GaussianVectorLikelihood::from_target(&target),
842            "must be finite",
843        );
844    }
845
846    #[test]
847    fn from_target_accepts_well_formed_low_rank_factor() {
848        let n = 2;
849        let m = 3;
850        let factor = Array2::from_shape_vec((m, 2), vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6]).unwrap();
851        let target = VectorResponseTarget::new(
852            Array2::<f64>::zeros((n, m)),
853            VectorNoise::LowRank {
854                diag: Array1::from(vec![1.0; m]),
855                factor: factor.clone(),
856            },
857        );
858        let lik = GaussianVectorLikelihood::from_target(&target)
859            .expect("well-formed low-rank factor should be accepted");
860        let stored = lik.factor.expect("low-rank factor should be carried");
861        assert_eq!(stored.dim(), (m, 2));
862        for ((i, j), v) in stored.indexed_iter() {
863            assert_eq!(*v, factor[[i, j]]);
864        }
865        // `GaussianVectorLikelihood::precision` is the per-output diagonal
866        // of length `M`, populated from `target.noise.diag_precision(M)`
867        // — not a per-row precision of length `N`. The historical
868        // `assert_eq!(n, lik.precision.len().max(n))` reduces to
869        // `precision.len() ≤ n`, which is the opposite of the contract
870        // (and false for any `M > N`, the typical multivariate-response
871        // shape).
872        assert_eq!(m, lik.precision.len());
873    }
874
875    #[test]
876    fn from_target_propagates_row_weight_length_mismatch() {
877        let n = 3;
878        let m = 2;
879        let target = VectorResponseTarget {
880            y: Array2::<f64>::zeros((n, m)),
881            noise: VectorNoise::Isotropic(1.0),
882            row_weights: Some(Array1::from(vec![1.0, 1.0])),
883        };
884        expect_invalid_input!(
885            GaussianVectorLikelihood::from_target(&target),
886            "row_weights length",
887        );
888    }
889}