Skip to main content

gam_terms/analytic_penalties/
total_variation.rs

1use super::*;
2
3// ---------------------------------------------------------------------------
4// Total variation penalty
5// ---------------------------------------------------------------------------
6
/// Shape of the first-difference operator used by [`TotalVariationPenalty`].
///
/// Every variant enumerates row pairs `(a, b)`; the penalized quantity for a
/// pair is the row difference `target[b, :] - target[a, :]`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DifferenceOpKind {
    /// Path graph with rows connected as `(0, 1), (1, 2), ..., (n_eff - 2, n_eff - 1)`.
    ForwardDiff1D,
    /// Explicit adjacency list; each edge row has `-1` at `from`, `+1` at `to`,
    /// so edge `(from, to)` penalizes `target[to, :] - target[from, :]`.
    /// [`TotalVariationPenalty::new`] rejects an empty list, rows `>= n_eff`,
    /// and self-loops.
    GraphEdges(Vec<(usize, usize)>),
}
15
16/// Coordinatewise/anisotropic smoothed-L¹ total variation on a row-major
17/// `(n_eff, d)` latent block.
18///
19/// Uses the differentiable Huber-style kernel `φ(x)=sqrt(x²+ε²)-ε` separately
20/// for each edge and latent axis. This is not vector-norm/isotropic edge TV:
21/// the Hessian intentionally has no cross-axis terms. The difference operator
22/// defines the prior shape: forward 1-D differences for ordered context
23/// windows, or graph edges for adjacency-structured atoms. Pair TV with
24/// Orthogonality when piecewise-constant atoms need a gauge-fixed basis.
25#[derive(Debug, Clone)]
26pub struct TotalVariationPenalty {
27    /// Base strength. If `learnable_weight` is true, the resolved strength is
28    /// `weight * exp(rho[rho_index])`; otherwise it is fixed at `weight`.
29    pub weight: f64,
30    /// Number of rows in the row-major latent coefficient block.
31    pub n_eff: usize,
32    pub difference_op: DifferenceOpKind,
33    pub smoothing_eps: f64,
34    pub learnable_weight: bool,
35    pub rho_index: usize,
36    pub weight_schedule: Option<ScalarWeightSchedule>,
37}
38
39impl TotalVariationPenalty {
40    #[must_use = "build error must be handled"]
41    pub fn new(
42        weight: f64,
43        n_eff: usize,
44        difference_op: DifferenceOpKind,
45        smoothing_eps: f64,
46        learnable_weight: bool,
47    ) -> Result<Self, String> {
48        if !(weight.is_finite() && weight > 0.0) {
49            return Err(format!(
50                "TotalVariationPenalty::new requires finite weight > 0, got {weight}"
51            ));
52        }
53        if n_eff == 0 {
54            return Err("TotalVariationPenalty::new requires n_eff > 0".to_string());
55        }
56        if !(smoothing_eps.is_finite() && smoothing_eps > 0.0) {
57            return Err(format!(
58                "TotalVariationPenalty::new requires finite smoothing_eps > 0, got {smoothing_eps}"
59            ));
60        }
61        if let DifferenceOpKind::GraphEdges(edges) = &difference_op {
62            if edges.is_empty() {
63                return Err(
64                    "TotalVariationPenalty::new GraphEdges requires at least one edge".to_string(),
65                );
66            }
67            for &(a, b) in edges {
68                if a >= n_eff || b >= n_eff {
69                    return Err(format!(
70                        "TotalVariationPenalty::new graph edge ({a}, {b}) exceeds n_eff {n_eff}"
71                    ));
72                }
73                if a == b {
74                    return Err(format!(
75                        "TotalVariationPenalty::new graph edge ({a}, {b}) is self-referential"
76                    ));
77                }
78            }
79        }
80        Ok(Self {
81            weight,
82            n_eff,
83            difference_op,
84            smoothing_eps,
85            learnable_weight,
86            rho_index: 0,
87            weight_schedule: None,
88        })
89    }
90
91    impl_with_weight_schedule!(weight);
92
93    fn resolved_weight(&self, rho: ArrayView1<'_, f64>) -> f64 {
94        if self.learnable_weight {
95            resolve_learnable_weight(self.weight, rho[self.rho_index])
96        } else {
97            self.weight
98        }
99    }
100
101    fn latent_dim(&self, target_len: usize) -> Option<usize> {
102        if self.n_eff == 0 || !target_len.is_multiple_of(self.n_eff) {
103            assert_eq!(
104                target_len % self.n_eff.max(1),
105                0,
106                "target length must be divisible by n_eff"
107            );
108            return None;
109        }
110        Some(target_len / self.n_eff)
111    }
112
113    fn edge_count(&self) -> usize {
114        match &self.difference_op {
115            DifferenceOpKind::ForwardDiff1D => self.n_eff.saturating_sub(1),
116            DifferenceOpKind::GraphEdges(edges) => edges.len(),
117        }
118    }
119
120    fn add_edge_hvp(
121        &self,
122        target: ArrayView1<'_, f64>,
123        v: ArrayView1<'_, f64>,
124        out: &mut Array1<f64>,
125        d: usize,
126        a: usize,
127        b: usize,
128        weight: f64,
129    ) {
130        let eps2 = self.smoothing_eps * self.smoothing_eps;
131        for j in 0..d {
132            let ia = a * d + j;
133            let ib = b * d + j;
134            let diff = target[ib] - target[ia];
135            let r = (diff * diff + eps2).sqrt();
136            let curvature = eps2 / (r * r * r);
137            let dv = v[ib] - v[ia];
138            let h = weight * curvature * dv;
139            out[ia] -= h;
140            out[ib] += h;
141        }
142    }
143
144    fn add_edge_grad(
145        &self,
146        target: ArrayView1<'_, f64>,
147        out: &mut Array1<f64>,
148        d: usize,
149        a: usize,
150        b: usize,
151        weight: f64,
152    ) {
153        let eps2 = self.smoothing_eps * self.smoothing_eps;
154        for j in 0..d {
155            let ia = a * d + j;
156            let ib = b * d + j;
157            let diff = target[ib] - target[ia];
158            let smooth_sign = diff / (diff * diff + eps2).sqrt();
159            let g = weight * smooth_sign;
160            out[ia] -= g;
161            out[ib] += g;
162        }
163    }
164
165    fn add_edge_diag(
166        &self,
167        target: ArrayView1<'_, f64>,
168        out: &mut Array1<f64>,
169        d: usize,
170        a: usize,
171        b: usize,
172        weight: f64,
173    ) {
174        let eps2 = self.smoothing_eps * self.smoothing_eps;
175        for j in 0..d {
176            let ia = a * d + j;
177            let ib = b * d + j;
178            let diff = target[ib] - target[ia];
179            let r = (diff * diff + eps2).sqrt();
180            let curvature = weight * eps2 / (r * r * r);
181            out[ia] += curvature;
182            out[ib] += curvature;
183        }
184    }
185
186    fn add_edge_dense(
187        &self,
188        target: ArrayView1<'_, f64>,
189        out: &mut Array2<f64>,
190        d: usize,
191        a: usize,
192        b: usize,
193        weight: f64,
194    ) {
195        let eps2 = self.smoothing_eps * self.smoothing_eps;
196        for j in 0..d {
197            let ia = a * d + j;
198            let ib = b * d + j;
199            let diff = target[ib] - target[ia];
200            let r = (diff * diff + eps2).sqrt();
201            let curvature = weight * eps2 / (r * r * r);
202            out[[ia, ia]] += curvature;
203            out[[ib, ib]] += curvature;
204            out[[ia, ib]] -= curvature;
205            out[[ib, ia]] -= curvature;
206        }
207    }
208
209    pub fn diag_target(
210        &self,
211        target: ArrayView1<'_, f64>,
212        rho: ArrayView1<'_, f64>,
213    ) -> Array1<f64> {
214        let Some(d) = self.latent_dim(target.len()) else {
215            return Array1::<f64>::zeros(target.len());
216        };
217        let weight = self.resolved_weight(rho);
218        let mut out = Array1::<f64>::zeros(target.len());
219        match &self.difference_op {
220            DifferenceOpKind::ForwardDiff1D => {
221                for a in 0..self.n_eff.saturating_sub(1) {
222                    self.add_edge_diag(target, &mut out, d, a, a + 1, weight);
223                }
224            }
225            DifferenceOpKind::GraphEdges(edges) => {
226                for &(a, b) in edges {
227                    self.add_edge_diag(target, &mut out, d, a, b, weight);
228                }
229            }
230        }
231        out
232    }
233
234    /// Materialize `Dᵀ diag(φ''(D T)) D` for diagnostics and small graph cases.
235    pub fn as_dense(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array2<f64> {
236        let n = target.len();
237        let Some(d) = self.latent_dim(n) else {
238            return Array2::<f64>::zeros((n, n));
239        };
240        let weight = self.resolved_weight(rho);
241        let mut out = Array2::<f64>::zeros((n, n));
242        match &self.difference_op {
243            DifferenceOpKind::ForwardDiff1D => {
244                for a in 0..self.n_eff.saturating_sub(1) {
245                    self.add_edge_dense(target, &mut out, d, a, a + 1, weight);
246                }
247            }
248            DifferenceOpKind::GraphEdges(edges) => {
249                for &(a, b) in edges {
250                    self.add_edge_dense(target, &mut out, d, a, b, weight);
251                }
252            }
253        }
254        out
255    }
256
257    pub fn log_det_plus_lambda_i_forward_1d(
258        &self,
259        target: ArrayView1<'_, f64>,
260        rho: ArrayView1<'_, f64>,
261        lambda: f64,
262    ) -> Result<f64, String> {
263        if !matches!(&self.difference_op, DifferenceOpKind::ForwardDiff1D) {
264            return Err(
265                "TotalVariationPenalty::log_det_plus_lambda_i_forward_1d requires ForwardDiff1D"
266                    .to_string(),
267            );
268        }
269        let Some(d) = self.latent_dim(target.len()) else {
270            return Err(format!(
271                "TotalVariationPenalty target length {} is not divisible by n_eff {}",
272                target.len(),
273                self.n_eff
274            ));
275        };
276        if !(lambda.is_finite() && lambda > 0.0) {
277            return Err(format!(
278                "TotalVariationPenalty::log_det_plus_lambda_i_forward_1d requires finite λ > 0; got {lambda}"
279            ));
280        }
281        let n = self.n_eff;
282        if n == 1 {
283            return Ok((d as f64) * lambda.ln());
284        }
285        let weight = self.resolved_weight(rho);
286        let eps2 = self.smoothing_eps * self.smoothing_eps;
287        let mut total = 0.0;
288        for j in 0..d {
289            let mut edge_w = vec![0.0; n - 1];
290            for a in 0..n - 1 {
291                let diff = target[(a + 1) * d + j] - target[a * d + j];
292                let r = (diff * diff + eps2).sqrt();
293                edge_w[a] = weight * eps2 / (r * r * r);
294            }
295
296            let mut prev_pivot = lambda + edge_w[0];
297            if !prev_pivot.is_finite() || prev_pivot <= 0.0 {
298                return Err(format!(
299                    "TotalVariationPenalty log-det encountered non-positive pivot {prev_pivot:.3e}"
300                ));
301            }
302            total += prev_pivot.ln();
303            for row in 1..n {
304                let left = edge_w[row - 1];
305                let right = if row + 1 < n { edge_w[row] } else { 0.0 };
306                let diag = lambda + left + right;
307                let pivot = diag - left * left / prev_pivot;
308                if !pivot.is_finite() || pivot <= 0.0 {
309                    return Err(format!(
310                        "TotalVariationPenalty log-det encountered non-positive pivot {pivot:.3e}"
311                    ));
312                }
313                total += pivot.ln();
314                prev_pivot = pivot;
315            }
316        }
317        Ok(total)
318    }
319}
320
321impl AnalyticPenalty for TotalVariationPenalty {
322    fn tier(&self) -> PenaltyTier {
323        PenaltyTier::Psi
324    }
325
326    fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
327        let Some(d) = self.latent_dim(target.len()) else {
328            return 0.0;
329        };
330        if self.edge_count() == 0 {
331            return 0.0;
332        }
333        let weight = self.resolved_weight(rho);
334        let eps = self.smoothing_eps;
335        let eps2 = eps * eps;
336        let mut acc = 0.0;
337        match &self.difference_op {
338            DifferenceOpKind::ForwardDiff1D => {
339                for a in 0..self.n_eff.saturating_sub(1) {
340                    let b = a + 1;
341                    for j in 0..d {
342                        let diff = target[b * d + j] - target[a * d + j];
343                        acc += (diff * diff + eps2).sqrt() - eps;
344                    }
345                }
346            }
347            DifferenceOpKind::GraphEdges(edges) => {
348                for &(a, b) in edges {
349                    for j in 0..d {
350                        let diff = target[b * d + j] - target[a * d + j];
351                        acc += (diff * diff + eps2).sqrt() - eps;
352                    }
353                }
354            }
355        }
356        weight * acc
357    }
358
359    fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
360        let Some(d) = self.latent_dim(target.len()) else {
361            return Array1::<f64>::zeros(target.len());
362        };
363        let weight = self.resolved_weight(rho);
364        let mut out = Array1::<f64>::zeros(target.len());
365        match &self.difference_op {
366            DifferenceOpKind::ForwardDiff1D => {
367                for a in 0..self.n_eff.saturating_sub(1) {
368                    self.add_edge_grad(target, &mut out, d, a, a + 1, weight);
369                }
370            }
371            DifferenceOpKind::GraphEdges(edges) => {
372                for &(a, b) in edges {
373                    self.add_edge_grad(target, &mut out, d, a, b, weight);
374                }
375            }
376        }
377        out
378    }
379
380    fn hvp(
381        &self,
382        target: ArrayView1<'_, f64>,
383        rho: ArrayView1<'_, f64>,
384        v: ArrayView1<'_, f64>,
385    ) -> Array1<f64> {
386        assert_eq!(target.len(), v.len(), "hvp dimension mismatch");
387        if target.len() != v.len() {
388            return Array1::<f64>::zeros(target.len());
389        }
390        let Some(d) = self.latent_dim(target.len()) else {
391            return Array1::<f64>::zeros(target.len());
392        };
393        let weight = self.resolved_weight(rho);
394        let mut out = Array1::<f64>::zeros(target.len());
395        match &self.difference_op {
396            DifferenceOpKind::ForwardDiff1D => {
397                for a in 0..self.n_eff.saturating_sub(1) {
398                    self.add_edge_hvp(target, v, &mut out, d, a, a + 1, weight);
399                }
400            }
401            DifferenceOpKind::GraphEdges(edges) => {
402                for &(a, b) in edges {
403                    self.add_edge_hvp(target, v, &mut out, d, a, b, weight);
404                }
405            }
406        }
407        out
408    }
409
410    impl_learnable_weight_grad_rho!();
411
412    impl_learnable_weight_rho_count!();
413
414    fn name(&self) -> &str {
415        "total_variation"
416    }
417
418    impl_scalar_apply_schedule!(weight);
419}
420
421// ---------------------------------------------------------------------------
422// Monotonicity penalty (1D shape constraint)
423// ---------------------------------------------------------------------------
424
425/// Soft monotonicity penalty over a row-major `(n_eff, d)` latent block.
426///
427/// For each adjacent pair `(a, a+1)` along the leading axis and each output
428/// column `j`, the penalty contribution is
429///
430///     softplus(-direction * (target[a+1, j] - target[a, j]) / smoothing_eps)
431///     * smoothing_eps
432///
433/// which is the smoothed hinge that hits zero when the slope agrees with
434/// `direction` (+1 ⇒ non-decreasing, -1 ⇒ non-increasing) and grows
435/// approximately linearly when it disagrees. The Hessian is positive
436/// semidefinite (softplus is convex) so the penalty composes cleanly with
437/// PIRLS/REML.
438///
439/// `n_eff` is the number of latent rows along the constrained axis; the
440/// remaining `target.len() / n_eff` columns are penalized independently and
441/// summed.
442#[derive(Debug, Clone)]
443pub struct ShapeMonotonicityPenalty {
444    pub weight: f64,
445    pub n_eff: usize,
446    /// `+1.0` for non-decreasing, `-1.0` for non-increasing along the leading axis.
447    pub direction: f64,
448    pub smoothing_eps: f64,
449    pub learnable_weight: bool,
450    pub rho_index: usize,
451    pub weight_schedule: Option<ScalarWeightSchedule>,
452}
453
454impl ShapeMonotonicityPenalty {
455    #[must_use = "build error must be handled"]
456    pub fn new(
457        weight: f64,
458        n_eff: usize,
459        direction: f64,
460        smoothing_eps: f64,
461        learnable_weight: bool,
462    ) -> Result<Self, String> {
463        if !(weight.is_finite() && weight > 0.0) {
464            return Err(format!(
465                "ShapeMonotonicityPenalty::new requires finite weight > 0, got {weight}"
466            ));
467        }
468        if n_eff == 0 {
469            return Err("ShapeMonotonicityPenalty::new requires n_eff > 0".to_string());
470        }
471        if !(direction.is_finite() && direction.abs() > 0.0) {
472            return Err(format!(
473                "ShapeMonotonicityPenalty::new requires finite non-zero direction (+1 or -1), got {direction}"
474            ));
475        }
476        if !(smoothing_eps.is_finite() && smoothing_eps > 0.0) {
477            return Err(format!(
478                "ShapeMonotonicityPenalty::new requires finite smoothing_eps > 0, got {smoothing_eps}"
479            ));
480        }
481        Ok(Self {
482            weight,
483            n_eff,
484            direction: direction.signum(),
485            smoothing_eps,
486            learnable_weight,
487            rho_index: 0,
488            weight_schedule: None,
489        })
490    }
491
492    impl_with_weight_schedule!(weight);
493
494    fn resolved_weight(&self, rho: ArrayView1<'_, f64>) -> f64 {
495        if self.learnable_weight {
496            resolve_learnable_weight(self.weight, rho[self.rho_index])
497        } else {
498            self.weight
499        }
500    }
501
502    fn latent_dim(&self, target_len: usize) -> Option<usize> {
503        if self.n_eff == 0 || !target_len.is_multiple_of(self.n_eff) {
504            return None;
505        }
506        Some(target_len / self.n_eff)
507    }
508
509    /// Smoothed-hinge contribution for a single edge `(a, b)` and column `j`.
510    fn edge_value(&self, target: ArrayView1<'_, f64>, d: usize, a: usize, b: usize) -> f64 {
511        let eps = self.smoothing_eps;
512        let mut acc = 0.0;
513        for j in 0..d {
514            let slope = target[b * d + j] - target[a * d + j];
515            let z = -self.direction * slope / eps;
516            // softplus(z) * eps, computed in a numerically stable form.
517            let sp = if z > 0.0 {
518                z + (-z).exp().ln_1p()
519            } else {
520                z.exp().ln_1p()
521            };
522            acc += sp * eps;
523        }
524        acc
525    }
526
527    /// d softplus(-dir * slope / eps) * eps / d target = -dir * sigma(-dir*slope/eps).
528    fn edge_grad(
529        &self,
530        target: ArrayView1<'_, f64>,
531        out: &mut Array1<f64>,
532        d: usize,
533        a: usize,
534        b: usize,
535        weight: f64,
536    ) {
537        let eps = self.smoothing_eps;
538        for j in 0..d {
539            let slope = target[b * d + j] - target[a * d + j];
540            let z = -self.direction * slope / eps;
541            // Stable sigmoid(z).
542            let sigma = if z > 0.0 {
543                1.0 / (1.0 + (-z).exp())
544            } else {
545                let ez = z.exp();
546                ez / (1.0 + ez)
547            };
548            let g = weight * (-self.direction) * sigma;
549            out[a * d + j] -= g;
550            out[b * d + j] += g;
551        }
552    }
553}
554
555impl AnalyticPenalty for ShapeMonotonicityPenalty {
556    fn tier(&self) -> PenaltyTier {
557        PenaltyTier::Psi
558    }
559
560    fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
561        let Some(d) = self.latent_dim(target.len()) else {
562            return 0.0;
563        };
564        if self.n_eff < 2 {
565            return 0.0;
566        }
567        let weight = self.resolved_weight(rho);
568        let mut acc = 0.0;
569        for a in 0..self.n_eff.saturating_sub(1) {
570            acc += self.edge_value(target, d, a, a + 1);
571        }
572        weight * acc
573    }
574
575    fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
576        let Some(d) = self.latent_dim(target.len()) else {
577            return Array1::<f64>::zeros(target.len());
578        };
579        let weight = self.resolved_weight(rho);
580        let mut out = Array1::<f64>::zeros(target.len());
581        for a in 0..self.n_eff.saturating_sub(1) {
582            self.edge_grad(target, &mut out, d, a, a + 1, weight);
583        }
584        out
585    }
586
587    fn hvp(
588        &self,
589        target: ArrayView1<'_, f64>,
590        rho: ArrayView1<'_, f64>,
591        v: ArrayView1<'_, f64>,
592    ) -> Array1<f64> {
593        assert_eq!(target.len(), v.len(), "hvp dimension mismatch");
594        let Some(d) = self.latent_dim(target.len()) else {
595            return Array1::<f64>::zeros(target.len());
596        };
597        let weight = self.resolved_weight(rho);
598        let eps = self.smoothing_eps;
599        let mut out = Array1::<f64>::zeros(target.len());
600        for a in 0..self.n_eff.saturating_sub(1) {
601            let b = a + 1;
602            for j in 0..d {
603                let slope = target[b * d + j] - target[a * d + j];
604                let z = -self.direction * slope / eps;
605                let sigma = if z > 0.0 {
606                    1.0 / (1.0 + (-z).exp())
607                } else {
608                    let ez = z.exp();
609                    ez / (1.0 + ez)
610                };
611                // d²P/d(target_a)d(target_b) follows from the chain rule on
612                // z = -dir * (target_b - target_a) / eps. The penalty value is
613                // `softplus(z) * eps` (note the outer eps from `edge_value`).
614                // softplus''(z) = sigma(z)(1 - sigma(z)) and the (dz/dtarget)²
615                // factor is 1/eps², but the value's outer `* eps` cancels one of
616                // those, leaving `sigma(1 - sigma) / eps` — exactly the eps power
617                // that keeps `hvp` consistent with the finite difference of
618                // `grad_target` (whose own eps already cancelled). Off-diagonal
619                // entries carry an extra minus sign from the difference.
620                let h = weight * sigma * (1.0 - sigma) / eps;
621                let dv = v[b * d + j] - v[a * d + j];
622                out[a * d + j] -= h * dv;
623                out[b * d + j] += h * dv;
624            }
625        }
626        out
627    }
628
629    impl_learnable_weight_grad_rho!();
630
631    impl_learnable_weight_rho_count!();
632
633    fn name(&self) -> &str {
634        "monotonicity"
635    }
636
637    impl_scalar_apply_schedule!(weight);
638}