1use super::*;
2
/// Difference operator used by `TotalVariationPenalty`: it decides which pairs of nodes
/// `(a, b)` are compared, each pair contributing the difference `x_b - x_a`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DifferenceOpKind {
    /// Chain of consecutive nodes: edges `(a, a + 1)` for `a in 0..n_eff - 1`.
    ForwardDiff1D,
    /// Explicit list of `(a, b)` node-index pairs. `TotalVariationPenalty::new` requires the
    /// list to be non-empty, every endpoint to be `< n_eff`, and no edge to be a self-loop.
    GraphEdges(Vec<(usize, usize)>),
}
15
16#[derive(Debug, Clone)]
26pub struct TotalVariationPenalty {
27 pub weight: f64,
30 pub n_eff: usize,
32 pub difference_op: DifferenceOpKind,
33 pub smoothing_eps: f64,
34 pub learnable_weight: bool,
35 pub rho_index: usize,
36 pub weight_schedule: Option<ScalarWeightSchedule>,
37}
38
39impl TotalVariationPenalty {
40 #[must_use = "build error must be handled"]
41 pub fn new(
42 weight: f64,
43 n_eff: usize,
44 difference_op: DifferenceOpKind,
45 smoothing_eps: f64,
46 learnable_weight: bool,
47 ) -> Result<Self, String> {
48 if !(weight.is_finite() && weight > 0.0) {
49 return Err(format!(
50 "TotalVariationPenalty::new requires finite weight > 0, got {weight}"
51 ));
52 }
53 if n_eff == 0 {
54 return Err("TotalVariationPenalty::new requires n_eff > 0".to_string());
55 }
56 if !(smoothing_eps.is_finite() && smoothing_eps > 0.0) {
57 return Err(format!(
58 "TotalVariationPenalty::new requires finite smoothing_eps > 0, got {smoothing_eps}"
59 ));
60 }
61 if let DifferenceOpKind::GraphEdges(edges) = &difference_op {
62 if edges.is_empty() {
63 return Err(
64 "TotalVariationPenalty::new GraphEdges requires at least one edge".to_string(),
65 );
66 }
67 for &(a, b) in edges {
68 if a >= n_eff || b >= n_eff {
69 return Err(format!(
70 "TotalVariationPenalty::new graph edge ({a}, {b}) exceeds n_eff {n_eff}"
71 ));
72 }
73 if a == b {
74 return Err(format!(
75 "TotalVariationPenalty::new graph edge ({a}, {b}) is self-referential"
76 ));
77 }
78 }
79 }
80 Ok(Self {
81 weight,
82 n_eff,
83 difference_op,
84 smoothing_eps,
85 learnable_weight,
86 rho_index: 0,
87 weight_schedule: None,
88 })
89 }
90
91 impl_with_weight_schedule!(weight);
92
93 fn resolved_weight(&self, rho: ArrayView1<'_, f64>) -> f64 {
94 if self.learnable_weight {
95 resolve_learnable_weight(self.weight, rho[self.rho_index])
96 } else {
97 self.weight
98 }
99 }
100
101 fn latent_dim(&self, target_len: usize) -> Option<usize> {
102 if self.n_eff == 0 || !target_len.is_multiple_of(self.n_eff) {
103 assert_eq!(
104 target_len % self.n_eff.max(1),
105 0,
106 "target length must be divisible by n_eff"
107 );
108 return None;
109 }
110 Some(target_len / self.n_eff)
111 }
112
113 fn edge_count(&self) -> usize {
114 match &self.difference_op {
115 DifferenceOpKind::ForwardDiff1D => self.n_eff.saturating_sub(1),
116 DifferenceOpKind::GraphEdges(edges) => edges.len(),
117 }
118 }
119
120 fn add_edge_hvp(
121 &self,
122 target: ArrayView1<'_, f64>,
123 v: ArrayView1<'_, f64>,
124 out: &mut Array1<f64>,
125 d: usize,
126 a: usize,
127 b: usize,
128 weight: f64,
129 ) {
130 let eps2 = self.smoothing_eps * self.smoothing_eps;
131 for j in 0..d {
132 let ia = a * d + j;
133 let ib = b * d + j;
134 let diff = target[ib] - target[ia];
135 let r = (diff * diff + eps2).sqrt();
136 let curvature = eps2 / (r * r * r);
137 let dv = v[ib] - v[ia];
138 let h = weight * curvature * dv;
139 out[ia] -= h;
140 out[ib] += h;
141 }
142 }
143
144 fn add_edge_grad(
145 &self,
146 target: ArrayView1<'_, f64>,
147 out: &mut Array1<f64>,
148 d: usize,
149 a: usize,
150 b: usize,
151 weight: f64,
152 ) {
153 let eps2 = self.smoothing_eps * self.smoothing_eps;
154 for j in 0..d {
155 let ia = a * d + j;
156 let ib = b * d + j;
157 let diff = target[ib] - target[ia];
158 let smooth_sign = diff / (diff * diff + eps2).sqrt();
159 let g = weight * smooth_sign;
160 out[ia] -= g;
161 out[ib] += g;
162 }
163 }
164
165 fn add_edge_diag(
166 &self,
167 target: ArrayView1<'_, f64>,
168 out: &mut Array1<f64>,
169 d: usize,
170 a: usize,
171 b: usize,
172 weight: f64,
173 ) {
174 let eps2 = self.smoothing_eps * self.smoothing_eps;
175 for j in 0..d {
176 let ia = a * d + j;
177 let ib = b * d + j;
178 let diff = target[ib] - target[ia];
179 let r = (diff * diff + eps2).sqrt();
180 let curvature = weight * eps2 / (r * r * r);
181 out[ia] += curvature;
182 out[ib] += curvature;
183 }
184 }
185
186 fn add_edge_dense(
187 &self,
188 target: ArrayView1<'_, f64>,
189 out: &mut Array2<f64>,
190 d: usize,
191 a: usize,
192 b: usize,
193 weight: f64,
194 ) {
195 let eps2 = self.smoothing_eps * self.smoothing_eps;
196 for j in 0..d {
197 let ia = a * d + j;
198 let ib = b * d + j;
199 let diff = target[ib] - target[ia];
200 let r = (diff * diff + eps2).sqrt();
201 let curvature = weight * eps2 / (r * r * r);
202 out[[ia, ia]] += curvature;
203 out[[ib, ib]] += curvature;
204 out[[ia, ib]] -= curvature;
205 out[[ib, ia]] -= curvature;
206 }
207 }
208
209 pub fn diag_target(
210 &self,
211 target: ArrayView1<'_, f64>,
212 rho: ArrayView1<'_, f64>,
213 ) -> Array1<f64> {
214 let Some(d) = self.latent_dim(target.len()) else {
215 return Array1::<f64>::zeros(target.len());
216 };
217 let weight = self.resolved_weight(rho);
218 let mut out = Array1::<f64>::zeros(target.len());
219 match &self.difference_op {
220 DifferenceOpKind::ForwardDiff1D => {
221 for a in 0..self.n_eff.saturating_sub(1) {
222 self.add_edge_diag(target, &mut out, d, a, a + 1, weight);
223 }
224 }
225 DifferenceOpKind::GraphEdges(edges) => {
226 for &(a, b) in edges {
227 self.add_edge_diag(target, &mut out, d, a, b, weight);
228 }
229 }
230 }
231 out
232 }
233
234 pub fn as_dense(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array2<f64> {
236 let n = target.len();
237 let Some(d) = self.latent_dim(n) else {
238 return Array2::<f64>::zeros((n, n));
239 };
240 let weight = self.resolved_weight(rho);
241 let mut out = Array2::<f64>::zeros((n, n));
242 match &self.difference_op {
243 DifferenceOpKind::ForwardDiff1D => {
244 for a in 0..self.n_eff.saturating_sub(1) {
245 self.add_edge_dense(target, &mut out, d, a, a + 1, weight);
246 }
247 }
248 DifferenceOpKind::GraphEdges(edges) => {
249 for &(a, b) in edges {
250 self.add_edge_dense(target, &mut out, d, a, b, weight);
251 }
252 }
253 }
254 out
255 }
256
257 pub fn log_det_plus_lambda_i_forward_1d(
258 &self,
259 target: ArrayView1<'_, f64>,
260 rho: ArrayView1<'_, f64>,
261 lambda: f64,
262 ) -> Result<f64, String> {
263 if !matches!(&self.difference_op, DifferenceOpKind::ForwardDiff1D) {
264 return Err(
265 "TotalVariationPenalty::log_det_plus_lambda_i_forward_1d requires ForwardDiff1D"
266 .to_string(),
267 );
268 }
269 let Some(d) = self.latent_dim(target.len()) else {
270 return Err(format!(
271 "TotalVariationPenalty target length {} is not divisible by n_eff {}",
272 target.len(),
273 self.n_eff
274 ));
275 };
276 if !(lambda.is_finite() && lambda > 0.0) {
277 return Err(format!(
278 "TotalVariationPenalty::log_det_plus_lambda_i_forward_1d requires finite λ > 0; got {lambda}"
279 ));
280 }
281 let n = self.n_eff;
282 if n == 1 {
283 return Ok((d as f64) * lambda.ln());
284 }
285 let weight = self.resolved_weight(rho);
286 let eps2 = self.smoothing_eps * self.smoothing_eps;
287 let mut total = 0.0;
288 for j in 0..d {
289 let mut edge_w = vec![0.0; n - 1];
290 for a in 0..n - 1 {
291 let diff = target[(a + 1) * d + j] - target[a * d + j];
292 let r = (diff * diff + eps2).sqrt();
293 edge_w[a] = weight * eps2 / (r * r * r);
294 }
295
296 let mut prev_pivot = lambda + edge_w[0];
297 if !prev_pivot.is_finite() || prev_pivot <= 0.0 {
298 return Err(format!(
299 "TotalVariationPenalty log-det encountered non-positive pivot {prev_pivot:.3e}"
300 ));
301 }
302 total += prev_pivot.ln();
303 for row in 1..n {
304 let left = edge_w[row - 1];
305 let right = if row + 1 < n { edge_w[row] } else { 0.0 };
306 let diag = lambda + left + right;
307 let pivot = diag - left * left / prev_pivot;
308 if !pivot.is_finite() || pivot <= 0.0 {
309 return Err(format!(
310 "TotalVariationPenalty log-det encountered non-positive pivot {pivot:.3e}"
311 ));
312 }
313 total += pivot.ln();
314 prev_pivot = pivot;
315 }
316 }
317 Ok(total)
318 }
319}
320
321impl AnalyticPenalty for TotalVariationPenalty {
322 fn tier(&self) -> PenaltyTier {
323 PenaltyTier::Psi
324 }
325
326 fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
327 let Some(d) = self.latent_dim(target.len()) else {
328 return 0.0;
329 };
330 if self.edge_count() == 0 {
331 return 0.0;
332 }
333 let weight = self.resolved_weight(rho);
334 let eps = self.smoothing_eps;
335 let eps2 = eps * eps;
336 let mut acc = 0.0;
337 match &self.difference_op {
338 DifferenceOpKind::ForwardDiff1D => {
339 for a in 0..self.n_eff.saturating_sub(1) {
340 let b = a + 1;
341 for j in 0..d {
342 let diff = target[b * d + j] - target[a * d + j];
343 acc += (diff * diff + eps2).sqrt() - eps;
344 }
345 }
346 }
347 DifferenceOpKind::GraphEdges(edges) => {
348 for &(a, b) in edges {
349 for j in 0..d {
350 let diff = target[b * d + j] - target[a * d + j];
351 acc += (diff * diff + eps2).sqrt() - eps;
352 }
353 }
354 }
355 }
356 weight * acc
357 }
358
359 fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
360 let Some(d) = self.latent_dim(target.len()) else {
361 return Array1::<f64>::zeros(target.len());
362 };
363 let weight = self.resolved_weight(rho);
364 let mut out = Array1::<f64>::zeros(target.len());
365 match &self.difference_op {
366 DifferenceOpKind::ForwardDiff1D => {
367 for a in 0..self.n_eff.saturating_sub(1) {
368 self.add_edge_grad(target, &mut out, d, a, a + 1, weight);
369 }
370 }
371 DifferenceOpKind::GraphEdges(edges) => {
372 for &(a, b) in edges {
373 self.add_edge_grad(target, &mut out, d, a, b, weight);
374 }
375 }
376 }
377 out
378 }
379
380 fn hvp(
381 &self,
382 target: ArrayView1<'_, f64>,
383 rho: ArrayView1<'_, f64>,
384 v: ArrayView1<'_, f64>,
385 ) -> Array1<f64> {
386 assert_eq!(target.len(), v.len(), "hvp dimension mismatch");
387 if target.len() != v.len() {
388 return Array1::<f64>::zeros(target.len());
389 }
390 let Some(d) = self.latent_dim(target.len()) else {
391 return Array1::<f64>::zeros(target.len());
392 };
393 let weight = self.resolved_weight(rho);
394 let mut out = Array1::<f64>::zeros(target.len());
395 match &self.difference_op {
396 DifferenceOpKind::ForwardDiff1D => {
397 for a in 0..self.n_eff.saturating_sub(1) {
398 self.add_edge_hvp(target, v, &mut out, d, a, a + 1, weight);
399 }
400 }
401 DifferenceOpKind::GraphEdges(edges) => {
402 for &(a, b) in edges {
403 self.add_edge_hvp(target, v, &mut out, d, a, b, weight);
404 }
405 }
406 }
407 out
408 }
409
410 impl_learnable_weight_grad_rho!();
411
412 impl_learnable_weight_rho_count!();
413
414 fn name(&self) -> &str {
415 "total_variation"
416 }
417
418 impl_scalar_apply_schedule!(weight);
419}
420
421#[derive(Debug, Clone)]
443pub struct ShapeMonotonicityPenalty {
444 pub weight: f64,
445 pub n_eff: usize,
446 pub direction: f64,
448 pub smoothing_eps: f64,
449 pub learnable_weight: bool,
450 pub rho_index: usize,
451 pub weight_schedule: Option<ScalarWeightSchedule>,
452}
453
454impl ShapeMonotonicityPenalty {
455 #[must_use = "build error must be handled"]
456 pub fn new(
457 weight: f64,
458 n_eff: usize,
459 direction: f64,
460 smoothing_eps: f64,
461 learnable_weight: bool,
462 ) -> Result<Self, String> {
463 if !(weight.is_finite() && weight > 0.0) {
464 return Err(format!(
465 "ShapeMonotonicityPenalty::new requires finite weight > 0, got {weight}"
466 ));
467 }
468 if n_eff == 0 {
469 return Err("ShapeMonotonicityPenalty::new requires n_eff > 0".to_string());
470 }
471 if !(direction.is_finite() && direction.abs() > 0.0) {
472 return Err(format!(
473 "ShapeMonotonicityPenalty::new requires finite non-zero direction (+1 or -1), got {direction}"
474 ));
475 }
476 if !(smoothing_eps.is_finite() && smoothing_eps > 0.0) {
477 return Err(format!(
478 "ShapeMonotonicityPenalty::new requires finite smoothing_eps > 0, got {smoothing_eps}"
479 ));
480 }
481 Ok(Self {
482 weight,
483 n_eff,
484 direction: direction.signum(),
485 smoothing_eps,
486 learnable_weight,
487 rho_index: 0,
488 weight_schedule: None,
489 })
490 }
491
492 impl_with_weight_schedule!(weight);
493
494 fn resolved_weight(&self, rho: ArrayView1<'_, f64>) -> f64 {
495 if self.learnable_weight {
496 resolve_learnable_weight(self.weight, rho[self.rho_index])
497 } else {
498 self.weight
499 }
500 }
501
502 fn latent_dim(&self, target_len: usize) -> Option<usize> {
503 if self.n_eff == 0 || !target_len.is_multiple_of(self.n_eff) {
504 return None;
505 }
506 Some(target_len / self.n_eff)
507 }
508
509 fn edge_value(&self, target: ArrayView1<'_, f64>, d: usize, a: usize, b: usize) -> f64 {
511 let eps = self.smoothing_eps;
512 let mut acc = 0.0;
513 for j in 0..d {
514 let slope = target[b * d + j] - target[a * d + j];
515 let z = -self.direction * slope / eps;
516 let sp = if z > 0.0 {
518 z + (-z).exp().ln_1p()
519 } else {
520 z.exp().ln_1p()
521 };
522 acc += sp * eps;
523 }
524 acc
525 }
526
527 fn edge_grad(
529 &self,
530 target: ArrayView1<'_, f64>,
531 out: &mut Array1<f64>,
532 d: usize,
533 a: usize,
534 b: usize,
535 weight: f64,
536 ) {
537 let eps = self.smoothing_eps;
538 for j in 0..d {
539 let slope = target[b * d + j] - target[a * d + j];
540 let z = -self.direction * slope / eps;
541 let sigma = if z > 0.0 {
543 1.0 / (1.0 + (-z).exp())
544 } else {
545 let ez = z.exp();
546 ez / (1.0 + ez)
547 };
548 let g = weight * (-self.direction) * sigma;
549 out[a * d + j] -= g;
550 out[b * d + j] += g;
551 }
552 }
553}
554
555impl AnalyticPenalty for ShapeMonotonicityPenalty {
556 fn tier(&self) -> PenaltyTier {
557 PenaltyTier::Psi
558 }
559
560 fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
561 let Some(d) = self.latent_dim(target.len()) else {
562 return 0.0;
563 };
564 if self.n_eff < 2 {
565 return 0.0;
566 }
567 let weight = self.resolved_weight(rho);
568 let mut acc = 0.0;
569 for a in 0..self.n_eff.saturating_sub(1) {
570 acc += self.edge_value(target, d, a, a + 1);
571 }
572 weight * acc
573 }
574
575 fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
576 let Some(d) = self.latent_dim(target.len()) else {
577 return Array1::<f64>::zeros(target.len());
578 };
579 let weight = self.resolved_weight(rho);
580 let mut out = Array1::<f64>::zeros(target.len());
581 for a in 0..self.n_eff.saturating_sub(1) {
582 self.edge_grad(target, &mut out, d, a, a + 1, weight);
583 }
584 out
585 }
586
587 fn hvp(
588 &self,
589 target: ArrayView1<'_, f64>,
590 rho: ArrayView1<'_, f64>,
591 v: ArrayView1<'_, f64>,
592 ) -> Array1<f64> {
593 assert_eq!(target.len(), v.len(), "hvp dimension mismatch");
594 let Some(d) = self.latent_dim(target.len()) else {
595 return Array1::<f64>::zeros(target.len());
596 };
597 let weight = self.resolved_weight(rho);
598 let eps = self.smoothing_eps;
599 let mut out = Array1::<f64>::zeros(target.len());
600 for a in 0..self.n_eff.saturating_sub(1) {
601 let b = a + 1;
602 for j in 0..d {
603 let slope = target[b * d + j] - target[a * d + j];
604 let z = -self.direction * slope / eps;
605 let sigma = if z > 0.0 {
606 1.0 / (1.0 + (-z).exp())
607 } else {
608 let ez = z.exp();
609 ez / (1.0 + ez)
610 };
611 let h = weight * sigma * (1.0 - sigma) / eps;
621 let dv = v[b * d + j] - v[a * d + j];
622 out[a * d + j] -= h * dv;
623 out[b * d + j] += h * dv;
624 }
625 }
626 out
627 }
628
629 impl_learnable_weight_grad_rho!();
630
631 impl_learnable_weight_rho_count!();
632
633 fn name(&self) -> &str {
634 "monotonicity"
635 }
636
637 impl_scalar_apply_schedule!(weight);
638}