gam_models/vector_response.rs
1//! Vector-valued response support.
2//!
3//! Many smooths sharing one latent: the shape function in the latent-variable
4//! engine maps to a reduced activation vector (tens-to-hundreds of dimensions,
5//! after a random-matrix noise cut). This module defines the response-side
6//! types, the Gaussian vector likelihood, and the connector trait the inner
7//! solver consumes.
8//!
9//! Conventions:
10//! - `Y` is shape `(N, M)`: `N` rows, `M` output dimensions.
11//! - `eta` is shape `(N, M)`: the linear predictor with one column per output.
12//! - For Gaussian identity-link, mean(η) = η, so the likelihood depends only
13//! on `eta` and `Y`.
14//!
15//! The Hessian is block-structured: per-row (N independent blocks for the
16//! Gaussian case), each of size `(M, M)`. For a Gaussian likelihood with
17//! Diagonal/Isotropic noise this per-row block is itself diagonal — exactly
18//! what the arrow Schur elimination in `solver/arrow_schur.rs` consumes.
19
20use crate::model_types::EstimationError;
21use ndarray::{Array1, Array2, Array3, ArrayView2};
22
23/// Per-output noise model for a vector response.
24///
25/// `LowRank` stores the symmetric structured precision
26/// `W = diag(diag) + U Uᵀ`, with `factor` holding `U`. The vector likelihood
27/// consumes the owned arrays directly; PIRLS low-rank Gram assembly is handled
28/// by `gam_linalg::low_rank_weight::LowRankWeight` and
29/// `gam_solve::pirls`.
30#[derive(Clone, Debug)]
31pub enum VectorNoise {
32 /// Shared σ across all M outputs: Σ = σ² I_M.
33 Isotropic(f64),
34 /// Per-output σ_m: Σ = diag(σ_m²).
35 Diagonal(Array1<f64>),
36 /// Symmetric structured form `W = diag(diag) + factor · factorᵀ`.
37 LowRank {
38 diag: Array1<f64>,
39 factor: Array2<f64>,
40 },
41}
42
43impl VectorNoise {
44 /// Per-output precision vector (1/σ_m²) for the Isotropic / Diagonal cases.
45 /// LowRank returns the diagonal piece only; the low-rank correction is
46 /// applied separately by the Piece 5 weight code.
47 pub fn diag_precision(&self, m: usize) -> Result<Array1<f64>, EstimationError> {
48 match self {
49 Self::Isotropic(sigma) => {
50 if !sigma.is_finite() || *sigma <= 0.0 {
51 crate::bail_invalid_estim!(
52 "VectorNoise::Isotropic: σ must be > 0 and finite (got {sigma})",
53 );
54 }
55 let p = 1.0 / (sigma * sigma);
56 Ok(Array1::from_elem(m, p))
57 }
58 Self::Diagonal(sigma) => {
59 if sigma.len() != m {
60 crate::bail_invalid_estim!(
61 "VectorNoise::Diagonal: σ length {} ≠ M={m}",
62 sigma.len()
63 );
64 }
65 let mut out = Array1::<f64>::zeros(m);
66 for j in 0..m {
67 let s = sigma[j];
68 if !s.is_finite() || s <= 0.0 {
69 crate::bail_invalid_estim!(
70 "VectorNoise::Diagonal: σ[{j}] must be > 0 and finite (got {s})",
71 );
72 }
73 out[j] = 1.0 / (s * s);
74 }
75 Ok(out)
76 }
77 Self::LowRank { diag, .. } => {
78 if diag.len() != m {
79 crate::bail_invalid_estim!(
80 "VectorNoise::LowRank: diag length {} ≠ M={m}",
81 diag.len()
82 );
83 }
84 let mut out = Array1::<f64>::zeros(m);
85 for j in 0..m {
86 let d = diag[j];
87 if !d.is_finite() || d <= 0.0 {
88 crate::bail_invalid_estim!(
89 "VectorNoise::LowRank: diag[{j}] must be > 0 (got {d})",
90 );
91 }
92 // `diag` is the PRECISION diagonal (W = diag(d) + F·Fᵀ).
93 // Pass it through unchanged.
94 out[j] = d;
95 }
96 Ok(out)
97 }
98 }
99 }
100}
101
102/// Vector-valued response target.
103///
104/// `y` is `(N, M)`; `row_weights` (if present) is length `N` and scales the
105/// per-row contribution to the likelihood (e.g. observation weights from a
106/// re-sampling or inverse-probability scheme).
107#[derive(Clone, Debug)]
108pub struct VectorResponseTarget {
109 /// shape (N, M) — N rows × M output dimensions.
110 pub y: Array2<f64>,
111 /// per-output noise (or shared scalar).
112 pub noise: VectorNoise,
113 /// optional row weights (N,).
114 pub row_weights: Option<Array1<f64>>,
115}
116
117impl VectorResponseTarget {
118 pub fn new(y: Array2<f64>, noise: VectorNoise) -> Self {
119 Self {
120 y,
121 noise,
122 row_weights: None,
123 }
124 }
125
126 pub fn with_row_weights(mut self, w: Array1<f64>) -> Result<Self, EstimationError> {
127 validate_row_weights(&w, self.y.nrows())?;
128 self.row_weights = Some(w);
129 Ok(self)
130 }
131
132 pub fn n(&self) -> usize {
133 self.y.nrows()
134 }
135 pub fn m(&self) -> usize {
136 self.y.ncols()
137 }
138}
139
/// Tolerance on the per-row simplex constraint `Σ_c y_{n,c} = 1`.
///
/// For the multinomial-logit log-likelihood `ℓ = Σ_c y_c log p_c`, the
/// canonical residual gradient `y_a − p_a` and Fisher block
/// `p_a δ_{ab} − p_a p_b` are correct **only** when every target row is a
/// probability vector (`y_c ≥ 0`, `Σ_c y_c = 1`). With a general row mass
/// `s = Σ_c y_c` the true derivatives are `y_a − s p_a` and
/// `s (p_a δ_{ab} − p_a p_b)`, so a row whose mass is not 1 would make the
/// implemented gradient/Hessian disagree with the implemented objective.
/// Simplex rows are therefore required at every construction boundary and
/// anything else is rejected instead of being fitted with inconsistent
/// curvature. The tolerance is meant to absorb only floating-point round-off
/// in an otherwise-exact one-hot or label-smoothed row (e.g. a sum of `K`
/// rationals) — not genuine count or proportional data.
pub(crate) const MULTINOMIAL_SIMPLEX_TOL: f64 = 1e-9;
154
155/// Validate that every row of a multinomial target `y ∈ ℝ^{N×K}` is a point on
156/// the probability simplex: `y_{n,c} ≥ 0` for all entries and
157/// `Σ_c y_{n,c} = 1` for every row (up to [`MULTINOMIAL_SIMPLEX_TOL`]). This
158/// is the precondition under which [`MultinomialLogitLikelihood`]'s residual
159/// gradient and Fisher block are the exact derivatives of its log-likelihood;
160/// see the constant's docs. Finiteness is checked first so the message points
161/// at the offending entry rather than at a NaN-poisoned row sum.
162pub(crate) fn validate_multinomial_simplex(
163 y: ArrayView2<f64>,
164 context: &str,
165) -> Result<(), EstimationError> {
166 let (n, k) = y.dim();
167 for row in 0..n {
168 let mut row_sum = 0.0_f64;
169 for c in 0..k {
170 let v = y[[row, c]];
171 if !v.is_finite() {
172 crate::bail_invalid_estim!("{context}: y[{row},{c}] must be finite (got {v})");
173 }
174 if v < 0.0 {
175 crate::bail_invalid_estim!(
176 "{context}: multinomial target must be a probability vector \
177 (y_c ≥ 0); got y[{row},{c}] = {v}"
178 );
179 }
180 row_sum += v;
181 }
182 if (row_sum - 1.0).abs() > MULTINOMIAL_SIMPLEX_TOL {
183 crate::bail_invalid_estim!(
184 "{context}: multinomial target rows must sum to 1 (one-hot for \
185 hard labels, or a label-smoothed probability vector); row {row} \
186 sums to {row_sum}. The softmax residual gradient y_a − p_a and \
187 Fisher block p_a δ_ab − p_a p_b are the derivatives of \
188 Σ_c y_c log p_c only when the row mass is 1."
189 );
190 }
191 }
192 Ok(())
193}
194
195fn validate_row_weights(weights: &Array1<f64>, n: usize) -> Result<(), EstimationError> {
196 if weights.len() != n {
197 crate::bail_invalid_estim!("row_weights length {} ≠ N={n}", weights.len());
198 }
199 for (idx, weight) in weights.iter().copied().enumerate() {
200 if !(weight.is_finite() && weight >= 0.0) {
201 crate::bail_invalid_estim!(
202 "row_weights[{idx}] must be finite and non-negative (got {weight})"
203 );
204 }
205 }
206 Ok(())
207}
208
209/// Connector trait the inner solver (Piece 1) plugs into.
210///
211/// `eta` is the `(N, M)` linear predictor; `y` is the `(N, M)` target. The
212/// implementation is responsible for any link inversion. The `hess_diag`
213/// return is the per-element diagonal of the per-row Hessian block; for a
214/// Diagonal-noise Gaussian this is exactly `(N, M)` of per-output precisions.
215pub trait VectorLikelihood {
216 /// log p(Y | η).
217 fn log_lik(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> f64;
218
219 /// ∂ log p(Y | η) / ∂ η, shape (N, M).
220 fn grad_eta(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> Array2<f64>;
221
222 /// Diagonal of the per-row Hessian −∂² log p / ∂ η ∂ η, shape (N, M).
223 /// This is the per-row block consumed by `solver/arrow_schur.rs`.
224 fn hess_diag(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> Array2<f64>;
225
226 /// Per-row dense Hessian block −∂² log p / ∂η_a ∂η_b, shape (N, M, M).
227 ///
228 /// Default implementation lifts [`Self::hess_diag`] onto the per-row
229 /// diagonal, valid only when the per-row Hessian is genuinely diagonal
230 /// across outputs (e.g. Gaussian with Isotropic/Diagonal noise).
231 /// Likelihoods with off-diagonal output coupling must override this:
232 /// [`GaussianVectorLikelihood`] with a low-rank precision factor `F`
233 /// (block `w·(diag(precision) + F·Fᵀ)`, off-diagonals `w·Σ_k F[a,k]·F[b,k]`)
234 /// and multinomial-logit (per-row Fisher block `p_a (δ_ab − p_b)`).
235 ///
236 /// The returned array is consumed by
237 /// [`gam_solve::pirls::dense_block_xtwx`] /
238 /// [`gam_solve::pirls::dense_block_xtwy`] to build `XᵀWX` and `XᵀWy`
239 /// for vector-response IRLS in output-major coefficient ordering.
240 fn hess_block(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> Array3<f64> {
241 let diag = self.hess_diag(eta, y);
242 let (n, m) = diag.dim();
243 let mut out = Array3::<f64>::zeros((n, m, m));
244 for row in 0..n {
245 for j in 0..m {
246 out[[row, j, j]] = diag[[row, j]];
247 }
248 }
249 out
250 }
251}
252
253/// Gaussian vector likelihood with identity link.
254///
255/// `log p(Y|η) = −½ Σ_n w_n · rᵀ W r` where `r = Y_n − η_n` and `W` is the
256/// per-output **precision** matrix. For Isotropic / Diagonal `W = diag(prec)`;
257/// for `LowRank` it is `W = diag(prec) + F · Fᵀ`, with `F` carried alongside
258/// the diagonal here.
259///
260/// (Up to the constant log-determinant of the noise covariance, dropped here
261/// because it does not depend on β or the latent t; the determinant is
262/// accounted for in the REML score, not the inner likelihood.)
263#[derive(Clone, Debug)]
264pub struct GaussianVectorLikelihood {
265 /// Per-output diagonal precision (length M). For Isotropic / Diagonal /
266 /// LowRank this is the diagonal piece of the precision matrix
267 /// (`1/σ_m²` for Diagonal/Isotropic; `diag` for LowRank).
268 pub precision: Array1<f64>,
269 /// Optional dense rank-r factor `F` of size `(M, r)` such that the full
270 /// per-row precision is `diag(precision) + F · Fᵀ`. `None` for the
271 /// Isotropic / Diagonal cases.
272 pub factor: Option<Array2<f64>>,
273 /// Optional row weights (length N), or None for uniform.
274 pub row_weights: Option<Array1<f64>>,
275}
276
277impl GaussianVectorLikelihood {
278 pub fn from_target(target: &VectorResponseTarget) -> Result<Self, EstimationError> {
279 if let Some(weights) = target.row_weights.as_ref() {
280 validate_row_weights(weights, target.n())?;
281 }
282 let precision = target.noise.diag_precision(target.m())?;
283 let factor = match &target.noise {
284 VectorNoise::LowRank { factor, .. } => {
285 if factor.nrows() != target.m() {
286 crate::bail_invalid_estim!(
287 "VectorNoise::LowRank: factor has {} rows but M={}",
288 factor.nrows(),
289 target.m()
290 );
291 }
292 for ((row, col), value) in factor.indexed_iter() {
293 if !value.is_finite() {
294 crate::bail_invalid_estim!(
295 "VectorNoise::LowRank: factor[{row},{col}] must be finite (got {value})"
296 );
297 }
298 }
299 Some(factor.clone())
300 }
301 _ => None,
302 };
303 Ok(Self {
304 precision,
305 factor,
306 row_weights: target.row_weights.clone(),
307 })
308 }
309
310 #[inline]
311 fn row_weight(&self, n: usize) -> f64 {
312 self.row_weights.as_ref().map_or(1.0, |w| w[n])
313 }
314}
315
316impl VectorLikelihood for GaussianVectorLikelihood {
317 fn log_lik(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> f64 {
318 assert_eq!(eta.dim(), y.dim());
319 assert_eq!(eta.ncols(), self.precision.len());
320 let m = eta.ncols();
321 let rank = self.factor.as_ref().map_or(0, |f| f.ncols());
322 let mut acc = 0.0;
323 // Scratch buffer for Fᵀ r (length rank), reused across rows.
324 let mut ftr = vec![0.0f64; rank];
325 for n in 0..eta.nrows() {
326 let w = self.row_weight(n);
327 // Diagonal part: Σ_m d_m r_m²
328 let mut row_acc = 0.0;
329 for j in 0..m {
330 let r = y[[n, j]] - eta[[n, j]];
331 row_acc += self.precision[j] * r * r;
332 }
333 // Low-rank part: ||Fᵀ r||²
334 if let Some(f) = self.factor.as_ref() {
335 for k in 0..rank {
336 ftr[k] = 0.0;
337 }
338 for j in 0..m {
339 let r = y[[n, j]] - eta[[n, j]];
340 for k in 0..rank {
341 ftr[k] += f[[j, k]] * r;
342 }
343 }
344 for k in 0..rank {
345 row_acc += ftr[k] * ftr[k];
346 }
347 }
348 acc += w * row_acc;
349 }
350 -0.5 * acc
351 }
352
353 fn grad_eta(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> Array2<f64> {
354 assert_eq!(eta.dim(), y.dim());
355 let (n_rows, n_cols) = eta.dim();
356 let rank = self.factor.as_ref().map_or(0, |f| f.ncols());
357 let mut out = Array2::<f64>::zeros((n_rows, n_cols));
358 let mut ftr = vec![0.0f64; rank];
359 for n in 0..n_rows {
360 let w = self.row_weight(n);
361 // Diagonal part: w · d_m · (y − η)_m
362 for j in 0..n_cols {
363 out[[n, j]] = w * self.precision[j] * (y[[n, j]] - eta[[n, j]]);
364 }
365 // Low-rank part: + w · F (Fᵀ r) for r = y − η
366 if let Some(f) = self.factor.as_ref() {
367 for k in 0..rank {
368 ftr[k] = 0.0;
369 }
370 for j in 0..n_cols {
371 let r = y[[n, j]] - eta[[n, j]];
372 for k in 0..rank {
373 ftr[k] += f[[j, k]] * r;
374 }
375 }
376 for j in 0..n_cols {
377 let mut s = 0.0;
378 for k in 0..rank {
379 s += f[[j, k]] * ftr[k];
380 }
381 out[[n, j]] += w * s;
382 }
383 }
384 }
385 out
386 }
387
388 fn hess_diag(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> Array2<f64> {
389 assert_eq!(eta.dim(), y.dim());
390 // Diagonal of −∂² log p / ∂η² = w · diag(diag(d) + F·Fᵀ); the diagonal
391 // of (F·Fᵀ) at output m is Σ_k F[m, k]². This is the diagonal
392 // *preconditioner* only — the off-diagonal cross terms F[a, k]·F[b, k]
393 // are carried by the full per-row block in [`Self::hess_block`] (which
394 // this type overrides whenever `factor` is present). Callers that need
395 // the true Hessian must use `hess_block`, not this diagonal.
396 let (n_rows, n_cols) = eta.dim();
397 let mut out = Array2::<f64>::zeros((n_rows, n_cols));
398 // Pre-compute Σ_k F[m, k]² per output m (independent of n).
399 let f_row_sqsum: Option<Array1<f64>> = self.factor.as_ref().map(|f| {
400 let m = f.nrows();
401 let r = f.ncols();
402 let mut s = Array1::<f64>::zeros(m);
403 for j in 0..m {
404 let mut acc = 0.0;
405 for k in 0..r {
406 let v = f[[j, k]];
407 acc += v * v;
408 }
409 s[j] = acc;
410 }
411 s
412 });
413 for n in 0..n_rows {
414 let w = self.row_weight(n);
415 for j in 0..n_cols {
416 let mut d = self.precision[j];
417 if let Some(s) = f_row_sqsum.as_ref() {
418 d += s[j];
419 }
420 out[[n, j]] = w * d;
421 }
422 }
423 out
424 }
425
426 fn hess_block(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> Array3<f64> {
427 // Per-row dense block −∂² log p / ∂η_a ∂η_b. With log-likelihood
428 // ℓ = −½ Σ_n w_n · rₙᵀ W rₙ, r = y − η, W = diag(precision) + F·Fᵀ,
429 // the gradient is wₙ · W rₙ and the negative Hessian block is exactly
430 // H_{n,a,b} = w_n · ( precision_a · δ_ab + Σ_k F[a,k] · F[b,k] ).
431 // This is the true second derivative of `log_lik` (it differentiates
432 // `grad_eta` exactly); the diagonal-only trait default would drop the
433 // F·Fᵀ cross terms F[a,k]·F[b,k] for a ≠ b, so it must be overridden
434 // whenever a low-rank factor is present.
435 assert_eq!(eta.dim(), y.dim());
436 assert_eq!(eta.ncols(), self.precision.len());
437 let (n_rows, m) = eta.dim();
438 let rank = self.factor.as_ref().map_or(0, |f| f.ncols());
439
440 // Per-output Gram of the low-rank factor, G_{a,b} = Σ_k F[a,k]·F[b,k].
441 // Independent of the row n, so assemble once and scale by w_n.
442 let gram: Option<Array2<f64>> = self.factor.as_ref().map(|f| {
443 let mut g = Array2::<f64>::zeros((m, m));
444 for a in 0..m {
445 for b in a..m {
446 let mut acc = 0.0;
447 for k in 0..rank {
448 acc += f[[a, k]] * f[[b, k]];
449 }
450 g[[a, b]] = acc;
451 g[[b, a]] = acc;
452 }
453 }
454 g
455 });
456
457 let mut out = Array3::<f64>::zeros((n_rows, m, m));
458 for n in 0..n_rows {
459 let w = self.row_weight(n);
460 for a in 0..m {
461 for b in 0..m {
462 let mut val = if a == b { self.precision[a] } else { 0.0 };
463 if let Some(g) = gram.as_ref() {
464 val += g[[a, b]];
465 }
466 out[[n, a, b]] = w * val;
467 }
468 }
469 }
470 out
471 }
472}
473
474// ─────────────────────────────────────────────────────────────────────────────
475// Piece 5 / Piece 1 row-block support
476// ─────────────────────────────────────────────────────────────────────────────
477
478/// Multinomial-logit (softmax) likelihood with explicit reference class.
479///
480/// Conventions:
481/// - `K` is the total number of classes; the linear predictor has `M = K - 1`
482/// columns corresponding to the *active* classes. Class `K - 1` is the
483/// reference class with η_{K-1} ≡ 0 (so the gauge is fixed by construction
484/// and no additional sum-to-zero projection is required at the η level).
485/// - `y` is the categorical response with shape `(N, K)`. Each row must be a
486/// point on the probability simplex (`y_c ≥ 0`, `Σ_c y_c = 1`): a one-hot
487/// indicator for hard-label classification, or a label-smoothed probability
488/// vector. The row *weight* `w_n` scales the whole row's likelihood
489/// contribution and is independent of the row mass — it is **not** the row
490/// sum. Callers enforce the simplex precondition via
491/// [`validate_multinomial_simplex`] at every construction boundary; under it
492/// the residual gradient `y_a − p_a` and Fisher block `p_a δ_ab − p_a p_b`
493/// below are the exact derivatives of the log-likelihood `Σ_c y_c log p_c`.
494/// - `eta` is the active linear predictor with shape `(N, M = K - 1)`.
495///
496/// Softmax with baseline:
497/// ```text
498/// p_a = exp(η_a) / (1 + Σ_b exp(η_b)) for a ∈ [0, K-1)
499/// p_{K-1} = 1 / (1 + Σ_b exp(η_b))
500/// ```
501///
502/// Log-likelihood (rows with weight `w_n`, default 1.0):
503/// ```text
504/// log L = Σ_n w_n · ( Σ_{a < K-1} y_{n,a} · η_{n,a} − log(1 + Σ_b exp(η_{n,b})) )
505/// = Σ_n w_n · Σ_{c ∈ [0, K)} y_{n,c} · log p_{n,c}
506/// ```
507///
508/// Per-row gradient w.r.t. the active η is the canonical Bernoulli/softmax
509/// residual:
510/// ```text
511/// ∂ log L / ∂η_{n,a} = w_n · (y_{n,a} − p_{n,a}) for a ∈ [0, K-1)
512/// ```
513///
514/// Per-row Fisher (= observed, since logit is canonical for the multinomial)
515/// information block, shape `(M, M)`:
516/// ```text
517/// H_{n,a,b} = w_n · ( p_{n,a} · δ_{ab} − p_{n,a} · p_{n,b} )
518/// ```
519///
520/// This is the standard reference-coded multinomial-logit GLM. The dense
521/// per-row block flows through [`VectorLikelihood::hess_block`] into
522/// [`gam_solve::pirls::dense_block_xtwx`], which builds the stacked
523/// `XᵀWX` in output-major coefficient ordering `β = [β_0; β_1; …; β_{K-2}]`
524/// with each per-class block of size `(P, P)`.
525#[derive(Clone, Debug)]
526pub struct MultinomialLogitLikelihood {
527 /// Number of active classes `M = K − 1`. Cached for shape checks.
528 pub active_classes: usize,
529 /// Optional row weights (length N), or `None` for uniform 1.0.
530 pub row_weights: Option<Array1<f64>>,
531}
532
533impl MultinomialLogitLikelihood {
534 /// Construct from the total number of classes `K ≥ 2`.
535 pub fn with_classes(total_classes: usize) -> Result<Self, EstimationError> {
536 if total_classes < 2 {
537 crate::bail_invalid_estim!(
538 "MultinomialLogitLikelihood requires K ≥ 2 classes (got {total_classes})"
539 );
540 }
541 Ok(Self {
542 active_classes: total_classes - 1,
543 row_weights: None,
544 })
545 }
546
547 /// Attach per-row weights (length N, finite and non-negative).
548 pub fn with_row_weights(mut self, w: Array1<f64>) -> Result<Self, EstimationError> {
549 validate_row_weights(&w, w.len())?;
550 self.row_weights = Some(w);
551 Ok(self)
552 }
553
554 /// Total class count `K = M + 1`.
555 #[inline]
556 pub fn total_classes(&self) -> usize {
557 self.active_classes + 1
558 }
559
560 #[inline]
561 fn row_weight(&self, n: usize) -> f64 {
562 self.row_weights.as_ref().map_or(1.0, |w| w[n])
563 }
564
565 /// Numerically-stable softmax with implicit reference column (η_{K-1} = 0).
566 ///
567 /// Writes `K` probabilities into `out` (length `M + 1`). The shift uses
568 /// `max(0, max(eta_active))` so the reference class is included in the
569 /// max and the denominator stays bounded. This is the canonical
570 /// reference implementation; the FFI surface and any direct
571 /// matrix-free callers route through this method rather than carrying
572 /// their own softmax.
573 pub fn softmax_with_baseline(eta_active: &[f64], out: &mut [f64]) {
574 assert_eq!(out.len(), eta_active.len() + 1);
575 let mut max_eta = 0.0_f64;
576 for &v in eta_active {
577 if v > max_eta {
578 max_eta = v;
579 }
580 }
581 let baseline = (-max_eta).exp();
582 let mut denom = baseline;
583 for (idx, &v) in eta_active.iter().enumerate() {
584 let e = (v - max_eta).exp();
585 out[idx] = e;
586 denom += e;
587 }
588 for v in out.iter_mut().take(eta_active.len()) {
589 *v /= denom;
590 }
591 out[eta_active.len()] = baseline / denom;
592 }
593
594 /// Convenience: compute the full (N, K) probability matrix from
595 /// (N, K-1) active linear predictor. This is the multinomial inverse
596 /// link used by prediction.
597 pub fn probabilities(&self, eta: ArrayView2<f64>) -> Array2<f64> {
598 let n = eta.nrows();
599 let m = self.active_classes;
600 assert_eq!(eta.ncols(), m, "η must have K-1 columns");
601 let k = self.total_classes();
602 let mut probs = Array2::<f64>::zeros((n, k));
603 let mut eta_row = vec![0.0_f64; m];
604 let mut probs_row = vec![0.0_f64; k];
605 for row in 0..n {
606 for j in 0..m {
607 eta_row[j] = eta[[row, j]];
608 }
609 Self::softmax_with_baseline(&eta_row, &mut probs_row);
610 for j in 0..k {
611 probs[[row, j]] = probs_row[j];
612 }
613 }
614 probs
615 }
616}
617
618impl VectorLikelihood for MultinomialLogitLikelihood {
619 fn log_lik(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> f64 {
620 let n = eta.nrows();
621 let m = self.active_classes;
622 let k = self.total_classes();
623 assert_eq!(eta.ncols(), m, "η must have K-1 columns");
624 assert_eq!(y.dim(), (n, k), "y must be (N, K) one-hot encoded");
625 let mut acc = 0.0_f64;
626 let mut eta_row = vec![0.0_f64; m];
627 let mut probs_row = vec![0.0_f64; k];
628 for row in 0..n {
629 let w = self.row_weight(row);
630 for j in 0..m {
631 eta_row[j] = eta[[row, j]];
632 }
633 Self::softmax_with_baseline(&eta_row, &mut probs_row);
634 let mut row_acc = 0.0_f64;
635 for c in 0..k {
636 let yc = y[[row, c]];
637 if yc != 0.0 {
638 // Guard against log(0) when p underflows; clamp the
639 // probability away from zero by 1e-300 — outside the
640 // representable range, the residual still drives the
641 // gradient correctly.
642 let p = probs_row[c].max(1.0e-300);
643 row_acc += yc * p.ln();
644 }
645 }
646 acc += w * row_acc;
647 }
648 acc
649 }
650
651 fn grad_eta(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> Array2<f64> {
652 let n = eta.nrows();
653 let m = self.active_classes;
654 let k = self.total_classes();
655 assert_eq!(eta.ncols(), m, "η must have K-1 columns");
656 assert_eq!(y.dim(), (n, k), "y must be (N, K) one-hot encoded");
657 let mut out = Array2::<f64>::zeros((n, m));
658 let mut eta_row = vec![0.0_f64; m];
659 let mut probs_row = vec![0.0_f64; k];
660 for row in 0..n {
661 let w = self.row_weight(row);
662 for j in 0..m {
663 eta_row[j] = eta[[row, j]];
664 }
665 Self::softmax_with_baseline(&eta_row, &mut probs_row);
666 for j in 0..m {
667 out[[row, j]] = w * (y[[row, j]] - probs_row[j]);
668 }
669 }
670 out
671 }
672
673 fn hess_diag(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> Array2<f64> {
674 // Per-row diagonal of the (M, M) Fisher block:
675 // H_{n,a,a} = w_n · p_{n,a} · (1 − p_{n,a})
676 // Provided for callers that explicitly want the diagonal-only
677 // preconditioner; the joint dense block ships through `hess_block`.
678 let n = eta.nrows();
679 let m = self.active_classes;
680 let k = self.total_classes();
681 assert_eq!(eta.ncols(), m, "η must have K-1 columns");
682 assert_eq!(y.dim(), (n, k), "y must be (N, K) one-hot encoded");
683 let mut out = Array2::<f64>::zeros((n, m));
684 let mut eta_row = vec![0.0_f64; m];
685 let mut probs_row = vec![0.0_f64; k];
686 for row in 0..n {
687 let w = self.row_weight(row);
688 for j in 0..m {
689 eta_row[j] = eta[[row, j]];
690 }
691 Self::softmax_with_baseline(&eta_row, &mut probs_row);
692 for j in 0..m {
693 let p = probs_row[j];
694 out[[row, j]] = w * p * (1.0 - p);
695 }
696 }
697 out
698 }
699
700 fn hess_block(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> Array3<f64> {
701 // Per-row dense (M, M) Fisher / observed-information block:
702 // H_{n,a,b} = w_n · ( p_{n,a} · δ_{ab} − p_{n,a} · p_{n,b} )
703 let n = eta.nrows();
704 let m = self.active_classes;
705 let k = self.total_classes();
706 assert_eq!(eta.ncols(), m, "η must have K-1 columns");
707 assert_eq!(y.dim(), (n, k), "y must be (N, K) one-hot encoded");
708 let mut out = Array3::<f64>::zeros((n, m, m));
709 let mut eta_row = vec![0.0_f64; m];
710 let mut probs_row = vec![0.0_f64; k];
711 for row in 0..n {
712 let w = self.row_weight(row);
713 for j in 0..m {
714 eta_row[j] = eta[[row, j]];
715 }
716 Self::softmax_with_baseline(&eta_row, &mut probs_row);
717 for a in 0..m {
718 let pa = probs_row[a];
719 out[[row, a, a]] = w * pa * (1.0 - pa);
720 for b in (a + 1)..m {
721 let off = -w * pa * probs_row[b];
722 out[[row, a, b]] = off;
723 out[[row, b, a]] = off;
724 }
725 }
726 }
727 out
728 }
729}
730
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array1, Array2};

    // Asserts that `$result` is `Err(EstimationError::InvalidInput(msg))` with
    // `msg` containing `$needle`, and evaluates to `msg`.
    //
    // Macro (not fn) so the assertion / panic tokens are inlined into each
    // caller's test body, satisfying the build.rs scanner that looks for
    // `assert!(` / `panic!(` directly in the `#[test]` function.
    macro_rules! expect_invalid_input {
        ($result:expr, $needle:expr $(,)?) => {{
            let needle: &str = $needle;
            match $result {
                Ok(_) => {
                    panic!("expected EstimationError::InvalidInput containing `{needle}`, got Ok")
                }
                Err(EstimationError::InvalidInput(msg)) => {
                    assert!(
                        msg.contains(needle),
                        "InvalidInput message `{msg}` does not contain `{needle}`"
                    );
                    msg
                }
                Err(other) => panic!(
                    "expected EstimationError::InvalidInput containing `{needle}`, got {other:?}"
                ),
            }
        }};
    }

    // All-zero `(n, m)` target with unit isotropic noise and no row weights.
    fn dummy_target(n: usize, m: usize) -> VectorResponseTarget {
        VectorResponseTarget::new(Array2::<f64>::zeros((n, m)), VectorNoise::Isotropic(1.0))
    }

    // Three weights for a four-row target must be rejected.
    #[test]
    fn with_row_weights_rejects_wrong_length() {
        let target = dummy_target(4, 2);
        let weights = Array1::from(vec![1.0, 1.0, 1.0]);
        expect_invalid_input!(target.with_row_weights(weights), "row_weights length");
    }

    // A negative weight must be rejected.
    #[test]
    fn with_row_weights_rejects_negative_entry() {
        let target = dummy_target(3, 2);
        let weights = Array1::from(vec![1.0, -0.5, 2.0]);
        expect_invalid_input!(
            target.with_row_weights(weights),
            "must be finite and non-negative",
        );
    }

    // A NaN weight must be rejected.
    #[test]
    fn with_row_weights_rejects_nan_entry() {
        let target = dummy_target(3, 2);
        let weights = Array1::from(vec![1.0, f64::NAN, 2.0]);
        expect_invalid_input!(
            target.with_row_weights(weights),
            "must be finite and non-negative",
        );
    }

    // An infinite weight must be rejected.
    #[test]
    fn with_row_weights_rejects_infinite_entry() {
        let target = dummy_target(3, 2);
        let weights = Array1::from(vec![1.0, f64::INFINITY, 2.0]);
        expect_invalid_input!(
            target.with_row_weights(weights),
            "must be finite and non-negative",
        );
    }

    // Zero is a legal weight; only negative / non-finite values are refused.
    #[test]
    fn with_row_weights_accepts_zero_and_positive() {
        let target = dummy_target(3, 2);
        let weights = Array1::from(vec![0.0, 1.5, 3.0]);
        let weighted = target
            .with_row_weights(weights)
            .expect("zero / positive weights should be accepted");
        assert!(weighted.row_weights.is_some());
    }

    // The low-rank factor must have exactly M rows.
    #[test]
    fn from_target_rejects_low_rank_factor_with_wrong_row_count() {
        let n = 4;
        let m = 3;
        // factor has 2 rows instead of M = 3.
        let factor = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
        let target = VectorResponseTarget::new(
            Array2::<f64>::zeros((n, m)),
            VectorNoise::LowRank {
                diag: Array1::from(vec![1.0; m]),
                factor,
            },
        );
        expect_invalid_input!(GaussianVectorLikelihood::from_target(&target), "factor has",);
    }

    // A NaN anywhere in the low-rank factor must be rejected.
    #[test]
    fn from_target_rejects_non_finite_low_rank_factor_entry() {
        let n = 4;
        let m = 3;
        let mut factor = Array2::<f64>::zeros((m, 2));
        factor[[1, 0]] = f64::NAN;
        let target = VectorResponseTarget::new(
            Array2::<f64>::zeros((n, m)),
            VectorNoise::LowRank {
                diag: Array1::from(vec![1.0; m]),
                factor,
            },
        );
        expect_invalid_input!(
            GaussianVectorLikelihood::from_target(&target),
            "must be finite",
        );
    }

    // A valid (M, r) factor is carried into the likelihood unchanged, and the
    // stored precision has one entry per output.
    #[test]
    fn from_target_accepts_well_formed_low_rank_factor() {
        let n = 2;
        let m = 3;
        let factor = Array2::from_shape_vec((m, 2), vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6]).unwrap();
        let target = VectorResponseTarget::new(
            Array2::<f64>::zeros((n, m)),
            VectorNoise::LowRank {
                diag: Array1::from(vec![1.0; m]),
                factor: factor.clone(),
            },
        );
        let lik = GaussianVectorLikelihood::from_target(&target)
            .expect("well-formed low-rank factor should be accepted");
        let stored = lik.factor.expect("low-rank factor should be carried");
        assert_eq!(stored.dim(), (m, 2));
        for ((i, j), v) in stored.indexed_iter() {
            assert_eq!(*v, factor[[i, j]]);
        }
        // `GaussianVectorLikelihood::precision` is the per-output diagonal
        // of length `M`, populated from `target.noise.diag_precision(M)`
        // — not a per-row precision of length `N`. The historical
        // `assert_eq!(n, lik.precision.len().max(n))` reduces to
        // `precision.len() ≤ n`, which is the opposite of the contract
        // (and false for any `M > N`, the typical multivariate-response
        // shape).
        assert_eq!(m, lik.precision.len());
    }

    // The struct literal bypasses `with_row_weights`, so `from_target` has to
    // catch the length mismatch itself (two weights for three rows).
    #[test]
    fn from_target_propagates_row_weight_length_mismatch() {
        let n = 3;
        let m = 2;
        let target = VectorResponseTarget {
            y: Array2::<f64>::zeros((n, m)),
            noise: VectorNoise::Isotropic(1.0),
            row_weights: Some(Array1::from(vec![1.0, 1.0])),
        };
        expect_invalid_input!(
            GaussianVectorLikelihood::from_target(&target),
            "row_weights length",
        );
    }
}