Skip to main content

gam_terms/analytic_penalties/
penalty_trait.rs

1use super::*;
2
/// NOTE(review): the name suggests a lower bound applied to conditional
/// precisions; the use sites are outside this file — confirm the exact role
/// before relying on this description.
pub(crate) const MIN_CONDITIONAL_PRECISION: f64 = 1.0e-12;

/// Floor applied to an assignment probability before taking its logarithm in the
/// entropic / softmax-assignment penalties, keeping `ln(a)` finite (and the
/// `a·ln(a)` contribution → 0) as `a → 0` without changing the value anywhere
/// `a` is not numerically zero.
pub const ENTROPY_LOG_PROBABILITY_FLOOR: f64 = 1e-300;

/// Margin `ε` of the clamp `[ε, 1−ε]` applied to IBP-assignment probabilities
/// before `ln`/`1/p`, keeping them strictly inside the open interval `(0, 1)`
/// so the Bernoulli cross-entropy and its score stay finite at the simplex
/// boundary.
pub(crate) const IBP_PROBABILITY_CLAMP: f64 = 1.0e-12;

/// Interior tolerance for the IBP straight-through Bernoulli mean: the
/// pass-through Jacobian `∂π/∂(mass)` is taken only when the unclamped mean lies
/// strictly inside `(δ, 1−δ)`; at the saturated boundary the gradient is zero.
pub(crate) const IBP_INTERIOR_TOL: f64 = 1.0e-9;

/// Floor on the IBP posterior-count denominator `n + a − 1`, guarding the
/// per-component mean against a zero (or negative) effective count.
pub(crate) const IBP_COUNT_DENOM_FLOOR: f64 = 1.0e-9;
24
25// ---------------------------------------------------------------------------
26// Common trait
27// ---------------------------------------------------------------------------
28
/// Whether a penalty's target is a slice of `β` (decoder coefficients), a
/// slice of extension coordinates (per-observation latent field, e.g.
/// `LatentCoordValues`),
/// or a slice of `ρ` (a hyperparameter sub-block — rare, used by hyperpriors
/// that we don't yet ship analytically).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PenaltyTier {
    /// The target is a slice of `β` (decoder coefficients).
    Beta,
    /// The target is a slice of extension coordinates (per-observation latent
    /// field). NOTE(review): variant-to-description mapping follows the order
    /// of the enum-level doc — confirm.
    Psi,
    /// The target is a slice of `ρ` (a hyperparameter sub-block).
    Rho,
}
40
/// Column / coordinate range that a penalty acts on.
///
/// For the β tier this mirrors `BlockwisePenalty::col_range`; for the
/// extension-coordinate tier it is the per-observation flat index, matching
/// the row-major flat layout of `LatentCoordValues` (`n * d + a`).
#[derive(Debug, Clone)]
pub struct PsiSlice {
    /// Half-open flat range (start included, end excluded) into the underlying
    /// ext-coordinate vector.
    pub range: std::ops::Range<usize>,
    /// Latent dimensionality for latent-coordinate slices; lets the flat slice
    /// be reshaped into per-row `(n_obs, d)` blocks.
    pub latent_dim: Option<usize>,
}

impl PsiSlice {
    /// Build a slice covering the whole vector, i.e. the range `0..len`.
    #[must_use]
    pub fn full(len: usize, latent_dim: Option<usize>) -> Self {
        let range = 0..len;
        Self { range, latent_dim }
    }

    /// Number of flat coordinates covered by the range (zero when the range
    /// is empty or reversed).
    pub fn len(&self) -> usize {
        let std::ops::Range { start, end } = self.range;
        end.saturating_sub(start)
    }

    /// `true` when the range covers no coordinate at all.
    pub fn is_empty(&self) -> bool {
        self.range.start >= self.range.end
    }
}
72
/// Compute the learnable penalty strength `base_weight · exp(rho)` so that the
/// result is always finite and, for a nonzero base weight, never exactly `0.0`.
///
/// The naive product `base_weight * rho.exp()` breaks at both ends: for finite
/// `rho ≳ 709` it overflows to `inf`, which then poisons the solve through
/// `inf · 0.0 = NaN` or `inf / inf = NaN` in the value/grad/Hessian; for
/// `rho ≲ -745` the exponential underflows to `0.0`, silently switching off a
/// penalty whose base weight is strictly positive and bringing back `0/0` in
/// ratios that divide by the strength.
///
/// Instead the product is formed in log-space, the *log-strength* is clamped
/// to a band symmetric about zero, and only then exponentiated. The band sits
/// inside the largest / smallest positive normal `f64` with a safety margin,
/// so later multiplications by `O(1)` factors remain finite. The sign of
/// `base_weight` is carried over to the result.
///
/// # Panics
///
/// Panics when `base_weight` is nonzero and either input is not finite. A zero
/// `base_weight` returns `0.0` before any validation takes place.
pub fn resolve_learnable_weight(base_weight: f64, rho: f64) -> f64 {
    // Half-width of the admissible log-strength band: exp(±700) is a finite
    // normal `f64` with headroom left for downstream `O(1)` arithmetic.
    const LOG_STRENGTH_BOUND: f64 = 700.0;

    if base_weight == 0.0 {
        return 0.0;
    }
    assert!(
        base_weight.is_finite() && rho.is_finite(),
        "resolve_learnable_weight requires finite inputs; got base_weight={base_weight}, rho={rho}"
    );

    let log_magnitude =
        (base_weight.abs().ln() + rho).clamp(-LOG_STRENGTH_BOUND, LOG_STRENGTH_BOUND);
    // `exp` of a finite exponent inside the band is strictly positive, so
    // re-attaching the sign is a plain negation.
    let magnitude = log_magnitude.exp();
    if base_weight.is_sign_negative() {
        -magnitude
    } else {
        magnitude
    }
}
106
/// Turn a learnable log-precision into the precision `exp(log_alpha)`, with
/// the exponent confined to the finite-normal band so the result is a finite,
/// strictly-positive `f64`.
///
/// Exponentiating the raw value overflows to `inf` for `log_alpha ≳ 709` (an
/// `inf` precision then poisons the ARD value/grad/Hessian via
/// `inf · 0.0 = NaN`) and underflows to exact `0.0` for `log_alpha ≲ -745` (a
/// zero precision drops a prior the term still expects to be positive).
/// Clamping the exponent and flooring the result at the smallest positive
/// normal avoids both while still spanning arbitrarily small / large values
/// within range (#742, Issue 4).
///
/// A NaN `log_alpha` propagates through the clamp and `exp`, and is then
/// replaced by `f64::MIN_POSITIVE` by the final floor.
pub(crate) fn stable_exp_log_precision(log_alpha: f64) -> f64 {
    // Half-width of the admissible exponent band.
    const LOG_PRECISION_BOUND: f64 = 700.0;

    let bounded_exponent = log_alpha.clamp(-LOG_PRECISION_BOUND, LOG_PRECISION_BOUND);
    let precision = bounded_exponent.exp();
    f64::max(precision, f64::MIN_POSITIVE)
}
126
127/// Scalar annealing schedule for analytic penalty weights.
128///
129/// This is the penalty-weight analogue of [`crate::terms::sae::manifold::GumbelTemperatureSchedule`]:
130/// it starts with a weak analytic regularizer and ramps toward the target
131/// weight during REML outer iterations. This follows the standard annealed
132/// regularization pattern in deep learning, where optimization first finds
133/// good fits before stronger structure constrains the solution. It also
134/// addresses the general observation that hand-picked analytic weights
135/// materially affect outcomes — fixed tight auxiliary scales can outperform
136/// learned weights on one dataset and underperform on another. A schedule
137/// side-steps that brittle initial choice by ramping the constraint.
138#[derive(Debug, Clone)]
139pub struct ScalarWeightSchedule {
140    pub w_start: f64,
141    pub w_end: f64,
142    pub kind: ScheduleKind,
143    pub iter_count: usize,
144}
145
146impl ScalarWeightSchedule {
147    #[must_use = "build error must be handled"]
148    pub fn new(w_start: f64, w_end: f64, kind: ScheduleKind) -> Result<Self, String> {
149        let schedule = Self {
150            w_start,
151            w_end,
152            kind,
153            iter_count: 0,
154        };
155        schedule.validate()?;
156        Ok(schedule)
157    }
158
159    pub fn validate(&self) -> Result<(), String> {
160        if !(self.w_start.is_finite() && self.w_start >= 0.0) {
161            return Err(format!(
162                "ScalarWeightSchedule: w_start must be finite and non-negative; got {}",
163                self.w_start
164            ));
165        }
166        if !(self.w_end.is_finite() && self.w_end >= 0.0) {
167            return Err(format!(
168                "ScalarWeightSchedule: w_end must be finite and non-negative; got {}",
169                self.w_end
170            ));
171        }
172        match &self.kind {
173            ScheduleKind::Geometric { rate } => {
174                if !(rate.is_finite() && *rate > 0.0 && *rate < 1.0) {
175                    return Err(format!(
176                        "ScalarWeightSchedule::Geometric: rate must be in (0, 1); got {rate}"
177                    ));
178                }
179            }
180            ScheduleKind::Linear { steps } => {
181                if *steps == 0 {
182                    return Err("ScalarWeightSchedule::Linear: steps must be positive".into());
183                }
184            }
185            ScheduleKind::ReciprocalIter => {}
186        }
187        Ok(())
188    }
189
190    pub fn current_weight(&self, iter: usize) -> f64 {
191        let delta = self.w_end - self.w_start;
192        let raw = match &self.kind {
193            ScheduleKind::Geometric { rate } => self.w_end - delta * rate.powf(iter as f64),
194            ScheduleKind::Linear { steps } => {
195                if iter >= *steps {
196                    self.w_end
197                } else {
198                    let frac = iter as f64 / *steps as f64;
199                    self.w_start + frac * delta
200                }
201            }
202            ScheduleKind::ReciprocalIter => self.w_end - delta / (1.0 + iter as f64),
203        };
204        raw.clamp(self.w_start.min(self.w_end), self.w_start.max(self.w_end))
205    }
206
207    pub fn step(&mut self) -> f64 {
208        let weight = self.current_weight(self.iter_count);
209        self.iter_count += 1;
210        weight
211    }
212}
213
214/// Uniform interface implemented by every analytic penalty in this module.
215///
216/// `target` is the relevant slice of the β or extension-coordinate vector, viewed as
217/// a flat `ArrayView1`. The owning REML driver is responsible for slicing the
218/// global parameter vector before calling, and for routing the returned
219/// gradient back into the correct global indices.
220pub trait AnalyticPenalty: Send + Sync {
221    /// Tier the target lives in (β or ext-coord).
222    fn tier(&self) -> PenaltyTier;
223
224    /// Scalar penalty contribution `P(target; ρ)`. The strength factor
225    /// `exp(ρ)` (or whatever parameterization the penalty uses) is folded in.
226    fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64;
227
228    /// Gradient `∂P/∂target`, same length as `target`.
229    fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64>;
230
231    /// Diagonal of the Hessian `diag(∂²P/∂target²)` when the Hessian is
232    /// block-diagonal. Returns `None` for penalties whose Hessian is dense
233    /// (Isometry); those implement [`Self::hvp`] instead. The default
234    /// signals "no closed-form diagonal" by returning `None` for any
235    /// non-empty target — concrete penalties either override with their
236    /// own analytic diagonal or rely on the matrix-free `hvp` path.
237    fn hessian_diag(
238        &self,
239        target: ArrayView1<'_, f64>,
240        rho: ArrayView1<'_, f64>,
241    ) -> Option<Array1<f64>> {
242        assert!(
243            rho.iter().all(|value| value.is_finite()),
244            "analytic-penalty rho must be finite"
245        );
246        if target.is_empty() {
247            Some(Array1::zeros(0))
248        } else {
249            None
250        }
251    }
252
253    /// Hessian-vector product `H v = (∂²P/∂target²) v`, in closed form.
254    ///
255    /// The default covers every penalty whose Hessian is diagonal: it reads the
256    /// analytic [`Self::hessian_diag`] and forms `diag ⊙ v`. Penalties with a
257    /// dense (non-diagonal) Hessian — e.g. `IsometryPenalty`,
258    /// `SheafConsistencyPenalty`, the orthogonality / nuclear-norm family —
259    /// return `None` from `hessian_diag` and supply their own analytic `hvp`
260    /// override (Laplacian/Gram-vector products). There is no finite-difference
261    /// path: a penalty that reaches the default without a closed-form diagonal
262    /// is a programming error and panics rather than silently differencing its
263    /// own gradient (SPEC: finite differences are never used outside tests).
264    fn hvp(
265        &self,
266        target: ArrayView1<'_, f64>,
267        rho: ArrayView1<'_, f64>,
268        v: ArrayView1<'_, f64>,
269    ) -> Array1<f64> {
270        let diag = self.hessian_diag(target, rho).unwrap_or_else(|| {
271            // SAFETY: programming-error invariant, never a runtime/data condition.
272            // A penalty whose Hessian is non-diagonal MUST override `hvp` with its
273            // closed-form Hessian-vector product; reaching this default means the
274            // impl is missing that override. SPEC forbids a finite-difference
275            // fallback outside tests, so there is no recoverable path — failing
276            // loud here is the contract.
277            panic!(
278                "AnalyticPenalty::hvp default reached for `{}`, whose Hessian is \
279                 not diagonal (hessian_diag returned None). Such a penalty must \
280                 override `hvp` with its closed-form Hessian-vector product; the \
281                 default never finite-differences.",
282                self.name()
283            )
284        });
285        assert_eq!(diag.len(), v.len(), "hvp dimension mismatch");
286        let mut out = Array1::<f64>::zeros(v.len());
287        for i in 0..v.len() {
288            out[i] = diag[i] * v[i];
289        }
290        out
291    }
292
293    /// Diagonal of a **PSD majorizer** of the Hessian — the positive
294    /// re-weighted-ℓ₂ / MM surrogate `diag(B(target; ρ))` with
295    /// `B ⪰ ∂²P/∂target²` everywhere and `B ⪰ 0`. This is a *different*
296    /// operator from [`Self::hessian_diag`]: for nonconvex penalties (log
297    /// sparsity, JumpReLU) the exact Hessian is indefinite, but the inner
298    /// Newton / PIRLS solve and the log-det / preconditioner pipeline require
299    /// a PSD curvature block. For convex penalties the majorizer coincides
300    /// with the exact Hessian, so the default simply delegates to
301    /// [`Self::hessian_diag`]; nonconvex penalties override.
302    fn psd_majorizer_diag(
303        &self,
304        target: ArrayView1<'_, f64>,
305        rho: ArrayView1<'_, f64>,
306    ) -> Option<Array1<f64>> {
307        self.hessian_diag(target, rho)
308    }
309
310    /// Matrix-vector product against the **PSD majorizer** `B(target; ρ) v`
311    /// (see [`Self::psd_majorizer_diag`]). For convex penalties this is the
312    /// exact Hessian-vector product, so the default delegates to
313    /// [`Self::hvp`]; nonconvex penalties override to return their PSD
314    /// surrogate instead of the indefinite true Hessian.
315    fn psd_majorizer_hvp(
316        &self,
317        target: ArrayView1<'_, f64>,
318        rho: ArrayView1<'_, f64>,
319        v: ArrayView1<'_, f64>,
320    ) -> Array1<f64> {
321        if let Some(diag) = self.psd_majorizer_diag(target, rho) {
322            assert_eq!(diag.len(), v.len(), "psd_majorizer_hvp dimension mismatch");
323            let mut out = Array1::<f64>::zeros(v.len());
324            for i in 0..v.len() {
325                out[i] = diag[i] * v[i];
326            }
327            return out;
328        }
329        self.hvp(target, rho, v)
330    }
331
332    /// Gradient of the penalty value w.r.t. each owned ρ-axis. Length equals
333    /// [`Self::rho_count`].
334    fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64>;
335
336    /// Number of REML-selectable hyperparameter axes this penalty contributes
337    /// to the outer ρ vector.
338    fn rho_count(&self) -> usize;
339
340    /// Human-readable identifier for diagnostics / logging.
341    fn name(&self) -> &str;
342
343    /// Update any attached scalar weight schedule at the given REML outer
344    /// iteration. Penalties without schedules keep their stored weight.
345    fn apply_schedule(&mut self, iter: usize) {
346        // REML outer loops are bounded well below 1,000,000; a value beyond
347        // that cap signals counter corruption rather than a legitimate
348        // iteration count, so refuse to silently accept it.
349        assert!(
350            iter < 1_000_000,
351            "apply_schedule received implausible outer iteration {iter}",
352        );
353    }
354}
355
356pub(crate) fn advance_scalar_weight(
357    weight: &mut f64,
358    schedule: &mut Option<ScalarWeightSchedule>,
359    iter: usize,
360) {
361    if let Some(schedule) = schedule.as_mut() {
362        *weight = schedule.current_weight(iter);
363        schedule.iter_count = iter + 1;
364    }
365}
366
/// Generate the standard scalar-weight-schedule builder for a penalty struct
/// that keeps its scalar weight in `$field` and its schedule in
/// `weight_schedule: Option<ScalarWeightSchedule>`. The generated builder
/// seeds the current weight from the schedule and stores the schedule. Invoke
/// it inside the struct's inherent `impl … {}` block.
macro_rules! impl_with_weight_schedule {
    ($field:ident) => {
        /// Attach a scalar weight schedule; the current weight is set to the
        /// value the schedule reports for its own stored iteration counter.
        #[must_use]
        pub fn with_weight_schedule(mut self, schedule: ScalarWeightSchedule) -> Self {
            let seeded_weight = schedule.current_weight(schedule.iter_count);
            self.weight_schedule = Some(schedule);
            self.$field = seeded_weight;
            self
        }
    };
}
384
/// Generate the standard [`AnalyticPenalty::apply_schedule`] override for a
/// penalty that keeps its scalar weight in `$field` and its schedule in
/// `weight_schedule`. Invoke it inside the `impl AnalyticPenalty for …` block.
macro_rules! impl_scalar_apply_schedule {
    ($field:ident) => {
        fn apply_schedule(&mut self, iter: usize) {
            // Two disjoint field borrows handed to the shared helper.
            let weight = &mut self.$field;
            let schedule = &mut self.weight_schedule;
            advance_scalar_weight(weight, schedule, iter);
        }
    };
}
395
/// Emit the standard learnable-scalar-weight [`AnalyticPenalty::grad_rho`] for a
/// penalty whose single owned ρ-axis is the (optionally learnable) log-weight at
/// `self.rho_index`, gated by `self.learnable_weight`. Invoke inside the `impl
/// AnalyticPenalty for …` block.
///
/// The generated method returns an empty array when the weight is not
/// learnable, and otherwise a one-element array holding `self.value(target,
/// rho)`, so its length always matches the `0`/`1` axis count of a
/// learnable-weight penalty. NOTE(review): using the value as the ρ-gradient
/// presumably relies on the strength entering as `exp(ρ)` — confirm for
/// penalties whose strength is clamped.
macro_rules! impl_learnable_weight_grad_rho {
    () => {
        fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
            if !self.learnable_weight {
                return Array1::<f64>::zeros(0);
            }
            // The output has exactly one slot, so the entry always goes to
            // position 0. Indexing the one-element output with
            // `self.rho_index` would go out of bounds (and panic) for any
            // non-zero index.
            let mut out = Array1::<f64>::zeros(1);
            out[0] = self.value(target, rho);
            out
        }
    };
}
412
/// Generate the standard learnable-scalar-weight
/// [`AnalyticPenalty::rho_count`]: a learnable weight contributes exactly one
/// ρ-axis, a fixed weight contributes none. Invoke it inside the
/// `impl AnalyticPenalty for …` block.
macro_rules! impl_learnable_weight_rho_count {
    () => {
        fn rho_count(&self) -> usize {
            // `true` converts to 1, `false` to 0.
            self.learnable_weight.into()
        }
    };
}