gam_terms/analytic_penalties/penalty_trait.rs
1use super::*;
2
/// Smallest admissible conditional precision.
///
/// NOTE(review): no use site is visible in this file; presumably callers floor a
/// conditional precision at this value so it (and its reciprocal) stays finite —
/// confirm at the call sites.
pub(crate) const MIN_CONDITIONAL_PRECISION: f64 = 1e-12;

/// Lower bound applied to an assignment probability `a` before `ln(a)` in the
/// entropic / softmax-assignment penalties.
///
/// Keeps `ln(a)` finite (so the `a·ln(a)` contribution tends to 0) as `a → 0`,
/// and leaves every `a ≥ 1e-300` unchanged.
pub const ENTROPY_LOG_PROBABILITY_FLOOR: f64 = 1e-300;

/// Margin `ε` of the clamp `[ε, 1−ε]` applied to IBP-assignment probabilities
/// before `ln` / `1/p`, so the Bernoulli cross-entropy and its score stay
/// finite at the simplex boundary.
pub(crate) const IBP_PROBABILITY_CLAMP: f64 = 1e-12;

/// Interior tolerance `δ` for the IBP straight-through Bernoulli mean.
///
/// The pass-through Jacobian `∂π/∂(mass)` is used only while the unclamped mean
/// lies strictly inside `(δ, 1−δ)`; at the saturated boundary the gradient is 0.
pub(crate) const IBP_INTERIOR_TOL: f64 = 1e-9;

/// Floor on the IBP posterior-count denominator `n + a − 1`, protecting the
/// per-component mean from a zero (or negative) effective count.
pub(crate) const IBP_COUNT_DENOM_FLOOR: f64 = 1e-9;
24
25// ---------------------------------------------------------------------------
26// Common trait
27// ---------------------------------------------------------------------------
28
29/// Whether a penalty's target is a slice of `β` (decoder coefficients), a
30/// slice of extension coordinates (per-observation latent field, e.g.
31/// `LatentCoordValues`),
32/// or a slice of `ρ` (a hyperparameter sub-block — rare, used by hyperpriors
33/// that we don't yet ship analytically).
34#[derive(Debug, Clone, Copy, PartialEq, Eq)]
35pub enum PenaltyTier {
36 Beta,
37 Psi,
38 Rho,
39}
40
/// Column / coordinate range a penalty operates over.
///
/// For the β tier this mirrors `BlockwisePenalty::col_range`. For the
/// extension-coordinate tier it is the per-observation flat index into the
/// row-major `LatentCoordValues` layout (`n * d + a`).
#[derive(Debug, Clone)]
pub struct PsiSlice {
    /// Half-open `start..end` flat range into the underlying ext-coordinate vector.
    pub range: std::ops::Range<usize>,
    /// Latent dimensionality `d` for latent-coordinate slices, used to view the
    /// flat slice as per-row `(n_obs, d)` blocks (presumably `None` for
    /// non-latent slices).
    pub latent_dim: Option<usize>,
}

impl PsiSlice {
    /// Slice spanning `0..len` of the underlying vector.
    #[must_use]
    pub fn full(len: usize, latent_dim: Option<usize>) -> Self {
        let range = 0..len;
        Self { range, latent_dim }
    }

    /// Number of flat entries covered; 0 when `end <= start`.
    #[must_use]
    pub fn len(&self) -> usize {
        self.range.end.saturating_sub(self.range.start)
    }

    /// `true` when the slice covers no entries.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
72
/// Resolve a learnable penalty strength `base_weight · exp(rho)` as a finite
/// value that is strictly nonzero whenever `base_weight` is nonzero.
///
/// A naive `base_weight * rho.exp()` overflows to `inf` for `rho ≳ 709`. That
/// `inf` then turns into `NaN` (`inf · 0.0`, `inf / inf`) in the
/// value/grad/Hessian. For `rho ≲ -745` the exponential underflows to `0.0`,
/// which silently disables a penalty with a positive base weight and
/// reintroduces `0/0` in ratios that divide by the strength.
///
/// The product is therefore formed in log-space. The log-strength
/// `ln|base_weight| + rho` is clamped to `[-700, 700]` before exponentiating,
/// and the sign of `base_weight` is reapplied. `exp(700) ≈ 1.0e304` and
/// `exp(-700) ≈ 9.9e-305` both sit inside the normal `f64` range, with headroom
/// for further `O(1)` arithmetic. The clamp acts on the combined log-strength,
/// so an extreme `base_weight` is also moved into that band even at `rho = 0`.
///
/// A zero `base_weight` short-circuits to `0.0` before any check on `rho`.
///
/// # Panics
/// Panics when `base_weight` is nonzero and either input is non-finite.
pub fn resolve_learnable_weight(base_weight: f64, rho: f64) -> f64 {
    // (lower, upper) limits on the log-magnitude of the returned strength.
    const LOG_STRENGTH_BAND: (f64, f64) = (-700.0, 700.0);

    if base_weight == 0.0 {
        return 0.0;
    }
    assert!(
        base_weight.is_finite() && rho.is_finite(),
        "resolve_learnable_weight requires finite inputs; got base_weight={base_weight}, rho={rho}"
    );

    let (lower, upper) = LOG_STRENGTH_BAND;
    let magnitude = (base_weight.abs().ln() + rho).clamp(lower, upper).exp();
    magnitude.copysign(base_weight)
}
106
/// Evaluate `exp(log_alpha)` with the exponent limited to `[-700, 700]`,
/// yielding a finite, strictly-positive precision.
///
/// A raw `log_alpha.exp()` overflows to `inf` for `log_alpha ≳ 709`. An `inf`
/// precision then poisons the ARD value/grad/Hessian via `inf · 0.0 = NaN`.
/// The raw exponential also underflows to `0.0` for `log_alpha ≲ -745`, and a
/// zero precision drops a prior the term still expects to be positive. Limiting
/// the exponent keeps the result between `exp(-700) ≈ 9.9e-305` and
/// `exp(700) ≈ 1.0e304` (#742, Issue 4). Infinite inputs saturate to those
/// bounds.
///
/// The final floor at `f64::MIN_POSITIVE` only takes effect for a `NaN` input.
/// The exponential of a clamped finite value is already above it, so `NaN`
/// maps to `f64::MIN_POSITIVE` rather than propagating.
pub(crate) fn stable_exp_log_precision(log_alpha: f64) -> f64 {
    const LOG_BOUND: f64 = 700.0;
    let bounded = log_alpha.clamp(-LOG_BOUND, LOG_BOUND);
    f64::max(bounded.exp(), f64::MIN_POSITIVE)
}
126
127/// Scalar annealing schedule for analytic penalty weights.
128///
129/// This is the penalty-weight analogue of [`crate::terms::sae::manifold::GumbelTemperatureSchedule`]:
130/// it starts with a weak analytic regularizer and ramps toward the target
131/// weight during REML outer iterations. This follows the standard annealed
132/// regularization pattern in deep learning, where optimization first finds
133/// good fits before stronger structure constrains the solution. It also
134/// addresses the general observation that hand-picked analytic weights
135/// materially affect outcomes — fixed tight auxiliary scales can outperform
136/// learned weights on one dataset and underperform on another. A schedule
137/// side-steps that brittle initial choice by ramping the constraint.
138#[derive(Debug, Clone)]
139pub struct ScalarWeightSchedule {
140 pub w_start: f64,
141 pub w_end: f64,
142 pub kind: ScheduleKind,
143 pub iter_count: usize,
144}
145
146impl ScalarWeightSchedule {
147 #[must_use = "build error must be handled"]
148 pub fn new(w_start: f64, w_end: f64, kind: ScheduleKind) -> Result<Self, String> {
149 let schedule = Self {
150 w_start,
151 w_end,
152 kind,
153 iter_count: 0,
154 };
155 schedule.validate()?;
156 Ok(schedule)
157 }
158
159 pub fn validate(&self) -> Result<(), String> {
160 if !(self.w_start.is_finite() && self.w_start >= 0.0) {
161 return Err(format!(
162 "ScalarWeightSchedule: w_start must be finite and non-negative; got {}",
163 self.w_start
164 ));
165 }
166 if !(self.w_end.is_finite() && self.w_end >= 0.0) {
167 return Err(format!(
168 "ScalarWeightSchedule: w_end must be finite and non-negative; got {}",
169 self.w_end
170 ));
171 }
172 match &self.kind {
173 ScheduleKind::Geometric { rate } => {
174 if !(rate.is_finite() && *rate > 0.0 && *rate < 1.0) {
175 return Err(format!(
176 "ScalarWeightSchedule::Geometric: rate must be in (0, 1); got {rate}"
177 ));
178 }
179 }
180 ScheduleKind::Linear { steps } => {
181 if *steps == 0 {
182 return Err("ScalarWeightSchedule::Linear: steps must be positive".into());
183 }
184 }
185 ScheduleKind::ReciprocalIter => {}
186 }
187 Ok(())
188 }
189
190 pub fn current_weight(&self, iter: usize) -> f64 {
191 let delta = self.w_end - self.w_start;
192 let raw = match &self.kind {
193 ScheduleKind::Geometric { rate } => self.w_end - delta * rate.powf(iter as f64),
194 ScheduleKind::Linear { steps } => {
195 if iter >= *steps {
196 self.w_end
197 } else {
198 let frac = iter as f64 / *steps as f64;
199 self.w_start + frac * delta
200 }
201 }
202 ScheduleKind::ReciprocalIter => self.w_end - delta / (1.0 + iter as f64),
203 };
204 raw.clamp(self.w_start.min(self.w_end), self.w_start.max(self.w_end))
205 }
206
207 pub fn step(&mut self) -> f64 {
208 let weight = self.current_weight(self.iter_count);
209 self.iter_count += 1;
210 weight
211 }
212}
213
214/// Uniform interface implemented by every analytic penalty in this module.
215///
216/// `target` is the relevant slice of the β or extension-coordinate vector, viewed as
217/// a flat `ArrayView1`. The owning REML driver is responsible for slicing the
218/// global parameter vector before calling, and for routing the returned
219/// gradient back into the correct global indices.
220pub trait AnalyticPenalty: Send + Sync {
221 /// Tier the target lives in (β or ext-coord).
222 fn tier(&self) -> PenaltyTier;
223
224 /// Scalar penalty contribution `P(target; ρ)`. The strength factor
225 /// `exp(ρ)` (or whatever parameterization the penalty uses) is folded in.
226 fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64;
227
228 /// Gradient `∂P/∂target`, same length as `target`.
229 fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64>;
230
231 /// Diagonal of the Hessian `diag(∂²P/∂target²)` when the Hessian is
232 /// block-diagonal. Returns `None` for penalties whose Hessian is dense
233 /// (Isometry); those implement [`Self::hvp`] instead. The default
234 /// signals "no closed-form diagonal" by returning `None` for any
235 /// non-empty target — concrete penalties either override with their
236 /// own analytic diagonal or rely on the matrix-free `hvp` path.
237 fn hessian_diag(
238 &self,
239 target: ArrayView1<'_, f64>,
240 rho: ArrayView1<'_, f64>,
241 ) -> Option<Array1<f64>> {
242 assert!(
243 rho.iter().all(|value| value.is_finite()),
244 "analytic-penalty rho must be finite"
245 );
246 if target.is_empty() {
247 Some(Array1::zeros(0))
248 } else {
249 None
250 }
251 }
252
253 /// Hessian-vector product `H v = (∂²P/∂target²) v`, in closed form.
254 ///
255 /// The default covers every penalty whose Hessian is diagonal: it reads the
256 /// analytic [`Self::hessian_diag`] and forms `diag ⊙ v`. Penalties with a
257 /// dense (non-diagonal) Hessian — e.g. `IsometryPenalty`,
258 /// `SheafConsistencyPenalty`, the orthogonality / nuclear-norm family —
259 /// return `None` from `hessian_diag` and supply their own analytic `hvp`
260 /// override (Laplacian/Gram-vector products). There is no finite-difference
261 /// path: a penalty that reaches the default without a closed-form diagonal
262 /// is a programming error and panics rather than silently differencing its
263 /// own gradient (SPEC: finite differences are never used outside tests).
264 fn hvp(
265 &self,
266 target: ArrayView1<'_, f64>,
267 rho: ArrayView1<'_, f64>,
268 v: ArrayView1<'_, f64>,
269 ) -> Array1<f64> {
270 let diag = self.hessian_diag(target, rho).unwrap_or_else(|| {
271 // SAFETY: programming-error invariant, never a runtime/data condition.
272 // A penalty whose Hessian is non-diagonal MUST override `hvp` with its
273 // closed-form Hessian-vector product; reaching this default means the
274 // impl is missing that override. SPEC forbids a finite-difference
275 // fallback outside tests, so there is no recoverable path — failing
276 // loud here is the contract.
277 panic!(
278 "AnalyticPenalty::hvp default reached for `{}`, whose Hessian is \
279 not diagonal (hessian_diag returned None). Such a penalty must \
280 override `hvp` with its closed-form Hessian-vector product; the \
281 default never finite-differences.",
282 self.name()
283 )
284 });
285 assert_eq!(diag.len(), v.len(), "hvp dimension mismatch");
286 let mut out = Array1::<f64>::zeros(v.len());
287 for i in 0..v.len() {
288 out[i] = diag[i] * v[i];
289 }
290 out
291 }
292
293 /// Diagonal of a **PSD majorizer** of the Hessian — the positive
294 /// re-weighted-ℓ₂ / MM surrogate `diag(B(target; ρ))` with
295 /// `B ⪰ ∂²P/∂target²` everywhere and `B ⪰ 0`. This is a *different*
296 /// operator from [`Self::hessian_diag`]: for nonconvex penalties (log
297 /// sparsity, JumpReLU) the exact Hessian is indefinite, but the inner
298 /// Newton / PIRLS solve and the log-det / preconditioner pipeline require
299 /// a PSD curvature block. For convex penalties the majorizer coincides
300 /// with the exact Hessian, so the default simply delegates to
301 /// [`Self::hessian_diag`]; nonconvex penalties override.
302 fn psd_majorizer_diag(
303 &self,
304 target: ArrayView1<'_, f64>,
305 rho: ArrayView1<'_, f64>,
306 ) -> Option<Array1<f64>> {
307 self.hessian_diag(target, rho)
308 }
309
310 /// Matrix-vector product against the **PSD majorizer** `B(target; ρ) v`
311 /// (see [`Self::psd_majorizer_diag`]). For convex penalties this is the
312 /// exact Hessian-vector product, so the default delegates to
313 /// [`Self::hvp`]; nonconvex penalties override to return their PSD
314 /// surrogate instead of the indefinite true Hessian.
315 fn psd_majorizer_hvp(
316 &self,
317 target: ArrayView1<'_, f64>,
318 rho: ArrayView1<'_, f64>,
319 v: ArrayView1<'_, f64>,
320 ) -> Array1<f64> {
321 if let Some(diag) = self.psd_majorizer_diag(target, rho) {
322 assert_eq!(diag.len(), v.len(), "psd_majorizer_hvp dimension mismatch");
323 let mut out = Array1::<f64>::zeros(v.len());
324 for i in 0..v.len() {
325 out[i] = diag[i] * v[i];
326 }
327 return out;
328 }
329 self.hvp(target, rho, v)
330 }
331
332 /// Gradient of the penalty value w.r.t. each owned ρ-axis. Length equals
333 /// [`Self::rho_count`].
334 fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64>;
335
336 /// Number of REML-selectable hyperparameter axes this penalty contributes
337 /// to the outer ρ vector.
338 fn rho_count(&self) -> usize;
339
340 /// Human-readable identifier for diagnostics / logging.
341 fn name(&self) -> &str;
342
343 /// Update any attached scalar weight schedule at the given REML outer
344 /// iteration. Penalties without schedules keep their stored weight.
345 fn apply_schedule(&mut self, iter: usize) {
346 // REML outer loops are bounded well below 1,000,000; a value beyond
347 // that cap signals counter corruption rather than a legitimate
348 // iteration count, so refuse to silently accept it.
349 assert!(
350 iter < 1_000_000,
351 "apply_schedule received implausible outer iteration {iter}",
352 );
353 }
354}
355
356pub(crate) fn advance_scalar_weight(
357 weight: &mut f64,
358 schedule: &mut Option<ScalarWeightSchedule>,
359 iter: usize,
360) {
361 if let Some(schedule) = schedule.as_mut() {
362 *weight = schedule.current_weight(iter);
363 schedule.iter_count = iter + 1;
364 }
365}
366
/// Generate the standard `with_weight_schedule` builder for a penalty struct.
///
/// The struct keeps its scalar weight in `$field` and its schedule in
/// `weight_schedule: Option<ScalarWeightSchedule>`. The builder stores the
/// schedule and replaces the current weight with the schedule's value at its
/// stored iteration counter. Expand inside the struct's inherent `impl … {}`
/// block.
macro_rules! impl_with_weight_schedule {
    ($field:ident) => {
        /// Attach a scalar weight schedule and seed the current weight from
        /// the schedule's stored iteration counter.
        #[must_use]
        pub fn with_weight_schedule(mut self, schedule: ScalarWeightSchedule) -> Self {
            let seeded_weight = schedule.current_weight(schedule.iter_count);
            self.weight_schedule = Some(schedule);
            self.$field = seeded_weight;
            self
        }
    };
}
384
/// Generate the standard [`AnalyticPenalty::apply_schedule`] override for a
/// penalty that keeps its scalar weight in `$field` and its schedule in
/// `weight_schedule`. Expand inside the `impl AnalyticPenalty for …` block.
///
/// NOTE(review): unlike the trait's default `apply_schedule`, this override
/// does not assert `iter < 1_000_000`; it forwards directly to
/// `advance_scalar_weight`.
macro_rules! impl_scalar_apply_schedule {
    ($field:ident) => {
        fn apply_schedule(&mut self, iter: usize) {
            let weight = &mut self.$field;
            let schedule = &mut self.weight_schedule;
            advance_scalar_weight(weight, schedule, iter);
        }
    };
}
395
/// Generate the standard learnable-scalar-weight [`AnalyticPenalty::grad_rho`]
/// for a penalty whose single owned ρ-axis is the (optionally learnable)
/// log-weight at `self.rho_index`, gated by `self.learnable_weight`. Expand
/// inside the `impl AnalyticPenalty for …` block.
///
/// The result has length [`AnalyticPenalty::rho_count`], which is 0 or 1, so
/// the single derivative always sits at local position 0. `self.rho_index`
/// locates the log-weight inside the incoming `rho` view, not inside this
/// output. Writing to `out[self.rho_index]` therefore panicked for any
/// non-zero index.
///
/// The derivative is reported as the penalty value itself, which is exact for a
/// value of the form `exp(ρ) · g(target)`. NOTE(review): a penalty that
/// resolves its strength through `resolve_learnable_weight` is only exactly
/// differentiated inside that function's unclamped log-strength band —
/// confirm per penalty.
macro_rules! impl_learnable_weight_grad_rho {
    () => {
        fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
            if !self.learnable_weight {
                return Array1::<f64>::zeros(0);
            }
            // One owned axis -> one output slot, always at index 0.
            Array1::<f64>::from_elem(1, self.value(target, rho))
        }
    };
}
412
/// Generate the standard learnable-scalar-weight [`AnalyticPenalty::rho_count`].
/// It reports one ρ-axis when `self.learnable_weight` is set and none
/// otherwise. Expand inside the `impl AnalyticPenalty for …` block.
macro_rules! impl_learnable_weight_rho_count {
    () => {
        fn rho_count(&self) -> usize {
            if self.learnable_weight {
                1
            } else {
                0
            }
        }
    };
}
422}