1use ndarray::{Array1, ArrayView1};
2use serde::{Deserialize, Serialize};
3use std::ops::{Deref, DerefMut};
4
5pub use gam_linalg::{RidgeDeterminantMode, RidgePolicy};
6
/// Lower bound for weights, equal to `1e-12`.
///
/// NOTE(review): presumably used to floor observation/working weights so they stay
/// strictly positive; confirm at the use sites.
pub const MIN_WEIGHT: f64 = 1.0e-12;
10
11pub use gam_spec::*;
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
15pub enum RidgeMatrixForm {
16 ScaledIdentity,
18}
19
20#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
22pub struct RidgePassport {
23 pub delta: f64,
25 pub matrix_form: RidgeMatrixForm,
26 pub policy: RidgePolicy,
27}
28
29impl RidgePassport {
30 pub const fn scaled_identity(delta: f64, policy: RidgePolicy) -> Self {
31 Self {
32 delta,
33 matrix_form: RidgeMatrixForm::ScaledIdentity,
34 policy,
35 }
36 }
37
38 #[inline]
39 pub const fn penalty_logdet_ridge(self) -> f64 {
40 if self.policy.include_penalty_logdet {
41 self.delta
42 } else {
43 0.0
44 }
45 }
46
47 #[inline]
48 pub const fn laplacehessianridge(self) -> f64 {
49 if self.policy.include_laplacehessian {
50 self.delta
51 } else {
52 0.0
53 }
54 }
55}
56
57#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
89pub struct Inertia {
90 pub positive: usize,
91 pub zero: usize,
92 pub negative: usize,
93}
94
95#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
97pub enum StabilizationRule {
98 FixedConstant,
100 InertiaTarget { spd_floor: f64 },
102 Heuristic,
104 UserSpecified,
106 BackoffEscalation { attempts: usize },
108}
109
110#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
112pub enum StabilizationKind {
113 None,
114 SolverDampingOnly,
117 NumericalPerturbation {
122 backward_error_bound: Option<f64>,
123 },
124 ExplicitPrior,
127}
128
129#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
136pub struct StabilizationLedger {
137 pub kind: StabilizationKind,
138 pub delta: f64,
139 pub matrix_form: RidgeMatrixForm,
140 pub chosen_by: StabilizationRule,
141 pub inertia_before: Option<Inertia>,
142 pub inertia_after: Option<Inertia>,
143}
144
145impl StabilizationLedger {
146 pub const fn none() -> Self {
148 Self {
149 kind: StabilizationKind::None,
150 delta: 0.0,
151 matrix_form: RidgeMatrixForm::ScaledIdentity,
152 chosen_by: StabilizationRule::FixedConstant,
153 inertia_before: None,
154 inertia_after: None,
155 }
156 }
157
158 pub const fn solver_damping(delta: f64, chosen_by: StabilizationRule) -> Self {
162 Self {
163 kind: StabilizationKind::SolverDampingOnly,
164 delta,
165 matrix_form: RidgeMatrixForm::ScaledIdentity,
166 chosen_by,
167 inertia_before: None,
168 inertia_after: None,
169 }
170 }
171
172 pub const fn numerical_perturbation(
176 delta: f64,
177 chosen_by: StabilizationRule,
178 backward_error_bound: Option<f64>,
179 ) -> Self {
180 Self {
181 kind: StabilizationKind::NumericalPerturbation {
182 backward_error_bound,
183 },
184 delta,
185 matrix_form: RidgeMatrixForm::ScaledIdentity,
186 chosen_by,
187 inertia_before: None,
188 inertia_after: None,
189 }
190 }
191
192 pub const fn explicit_prior(delta: f64, matrix_form: RidgeMatrixForm) -> Self {
196 Self {
197 kind: StabilizationKind::ExplicitPrior,
198 delta,
199 matrix_form,
200 chosen_by: StabilizationRule::UserSpecified,
201 inertia_before: None,
202 inertia_after: None,
203 }
204 }
205
206 pub const fn from_passport(passport: RidgePassport) -> Self {
218 let any_included = passport.policy.include_quadratic_penalty
219 || passport.policy.include_laplacehessian
220 || passport.policy.include_penalty_logdet;
221 let kind = if any_included {
222 StabilizationKind::ExplicitPrior
223 } else {
224 StabilizationKind::NumericalPerturbation {
225 backward_error_bound: None,
226 }
227 };
228 Self {
229 kind,
230 delta: passport.delta,
231 matrix_form: passport.matrix_form,
232 chosen_by: StabilizationRule::FixedConstant,
233 inertia_before: None,
234 inertia_after: None,
235 }
236 }
237
238 #[inline]
242 pub const fn quadratic_delta(&self) -> f64 {
243 match self.kind {
244 StabilizationKind::ExplicitPrior => self.delta,
245 StabilizationKind::None
246 | StabilizationKind::SolverDampingOnly
247 | StabilizationKind::NumericalPerturbation { .. } => 0.0,
248 }
249 }
250
251 #[inline]
255 pub const fn laplace_hessian_delta(&self) -> f64 {
256 match self.kind {
257 StabilizationKind::ExplicitPrior => self.delta,
258 StabilizationKind::None
259 | StabilizationKind::SolverDampingOnly
260 | StabilizationKind::NumericalPerturbation { .. } => 0.0,
261 }
262 }
263
264 #[inline]
268 pub const fn penalty_logdet_delta(&self) -> f64 {
269 match self.kind {
270 StabilizationKind::ExplicitPrior => self.delta,
271 StabilizationKind::None
272 | StabilizationKind::SolverDampingOnly
273 | StabilizationKind::NumericalPerturbation { .. } => 0.0,
274 }
275 }
276}
277macro_rules! array1_f64_newtype {
282 ($name:ident $(, $extra:ident)*) => {
283 #[repr(transparent)]
284 #[derive(Clone, Debug, PartialEq)]
285 pub struct $name(pub Array1<f64>);
286
287 impl $name {
288 #[inline]
289 pub fn new(values: Array1<f64>) -> Self {
290 Self(values)
291 }
292
293 #[inline]
294 pub fn zeros(len: usize) -> Self {
295 Self(Array1::zeros(len))
296 }
297 }
298
299 impl Deref for $name {
300 type Target = Array1<f64>;
301 #[inline]
302 fn deref(&self) -> &Self::Target { &self.0 }
303 }
304
305 impl DerefMut for $name {
306 #[inline]
307 fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 }
308 }
309
310 impl AsRef<Array1<f64>> for $name {
311 #[inline]
312 fn as_ref(&self) -> &Array1<f64> { &self.0 }
313 }
314
315 impl From<Array1<f64>> for $name {
316 #[inline]
317 fn from(values: Array1<f64>) -> Self { Self(values) }
318 }
319
320 impl From<$name> for Array1<f64> {
321 #[inline]
322 fn from(values: $name) -> Self { values.0 }
323 }
324
325 $( array1_f64_newtype!(@extra $name $extra); )*
326 };
327 (@extra $name:ident exp) => {
328 impl $name {
329 #[inline]
330 pub fn exp(&self) -> Array1<f64> { self.0.mapv(f64::exp) }
331 }
332 };
333}
334
335array1_f64_newtype!(Coefficients);
336array1_f64_newtype!(LinearPredictor);
337array1_f64_newtype!(LogSmoothingParams, exp);
338
339#[repr(transparent)]
347#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
348pub struct SmoothTermIdx(usize);
349
350impl SmoothTermIdx {
351 #[inline]
352 pub const fn new(idx: usize) -> Self {
353 Self(idx)
354 }
355
356 #[inline]
361 pub const fn placeholder() -> Self {
362 Self(usize::MAX)
363 }
364
365 #[inline]
366 pub const fn get(self) -> usize {
367 self.0
368 }
369
370 #[inline]
371 pub const fn is_placeholder(self) -> bool {
372 self.0 == usize::MAX
373 }
374}
375
376impl std::fmt::Display for SmoothTermIdx {
377 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
378 write!(f, "{}", self.0)
379 }
380}
381
382#[repr(transparent)]
391#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
392pub struct PenaltyIdx(usize);
393
394impl PenaltyIdx {
395 #[inline]
396 pub const fn new(idx: usize) -> Self {
397 Self(idx)
398 }
399
400 #[inline]
401 pub const fn get(self) -> usize {
402 self.0
403 }
404}
405
406impl std::fmt::Display for PenaltyIdx {
407 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
408 write!(f, "{}", self.0)
409 }
410}
411
412#[repr(transparent)]
426#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
427pub struct BasisIdx(usize);
428
429impl BasisIdx {
430 #[inline]
431 pub const fn new(idx: usize) -> Self {
432 Self(idx)
433 }
434
435 #[inline]
436 pub const fn get(self) -> usize {
437 self.0
438 }
439}
440
441impl std::fmt::Display for BasisIdx {
442 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
443 write!(f, "{}", self.0)
444 }
445}
446
447#[repr(transparent)]
460#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
461pub struct ColIdx(usize);
462
463impl ColIdx {
464 #[inline]
465 pub const fn new(idx: usize) -> Self {
466 Self(idx)
467 }
468
469 #[inline]
470 pub const fn get(self) -> usize {
471 self.0
472 }
473}
474
475impl std::fmt::Display for ColIdx {
476 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
477 write!(f, "{}", self.0)
478 }
479}
480
481#[repr(transparent)]
489#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
490pub struct RowIdx(usize);
491
492impl RowIdx {
493 #[inline]
494 pub const fn new(idx: usize) -> Self {
495 Self(idx)
496 }
497
498 #[inline]
499 pub const fn get(self) -> usize {
500 self.0
501 }
502}
503
504impl std::fmt::Display for RowIdx {
505 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
506 write!(f, "{}", self.0)
507 }
508}
509
510#[repr(transparent)]
511#[derive(Clone, Copy, Debug)]
512pub struct LogSmoothingParamsView<'a>(pub ArrayView1<'a, f64>);
513
514impl<'a> LogSmoothingParamsView<'a> {
515 pub fn new(values: ArrayView1<'a, f64>) -> Self {
516 Self(values)
517 }
518
519 pub fn exp(&self) -> Array1<f64> {
520 self.0.mapv(f64::exp)
521 }
522}
523
524impl<'a> Deref for LogSmoothingParamsView<'a> {
525 type Target = ArrayView1<'a, f64>;
526
527 fn deref(&self) -> &Self::Target {
528 &self.0
529 }
530}
531
#[cfg(test)]
mod newtype_tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn smooth_term_idx_new_get_roundtrip() {
        let seven = SmoothTermIdx::new(7);
        assert_eq!(seven.get(), 7);
        assert!(!seven.is_placeholder());
        assert_eq!(seven.to_string(), "7");
    }

    #[test]
    fn smooth_term_idx_placeholder_is_detected() {
        let placeholder = SmoothTermIdx::placeholder();
        assert_eq!(placeholder.get(), usize::MAX);
        assert!(placeholder.is_placeholder());
    }

    #[test]
    fn smooth_term_idx_ordering() {
        let (lo, hi) = (SmoothTermIdx::new(1), SmoothTermIdx::new(2));
        assert!(lo < hi);
        assert_eq!(lo, SmoothTermIdx::new(1));
    }

    #[test]
    fn coefficients_zeros_and_deref() {
        let coeffs = Coefficients::zeros(3);
        assert_eq!(coeffs.len(), 3);
        assert!(coeffs.iter().copied().all(|v| v == 0.0));
    }

    #[test]
    fn coefficients_from_array1() {
        let expected = array![1.0, 2.0, 3.0];
        let wrapped: Coefficients = expected.clone().into();
        assert_eq!(*wrapped, expected);
    }

    #[test]
    fn log_smoothing_params_exp_matches_elementwise() {
        let raw = array![0.0_f64, 1.0, -1.0];
        let expected = raw.mapv(f64::exp);
        let rho = LogSmoothingParams::new(raw);
        assert_eq!(rho.exp(), expected);
    }

    #[test]
    fn linear_predictor_zeros_and_deref() {
        let eta = LinearPredictor::zeros(4);
        assert_eq!(eta.len(), 4);
        assert!(eta.iter().copied().all(|v| v == 0.0));
    }
}
589
#[cfg(test)]
mod ridge_policy_tests {
    use super::{RidgePassport, RidgePolicy, StabilizationKind, StabilizationLedger};

    #[test]
    fn solver_only_ridge_policy_stays_off_objective_accounting() {
        let passport = RidgePassport::scaled_identity(1.0e-4, RidgePolicy::solver_only());
        let policy = passport.policy;

        assert!(
            !policy.include_quadratic_penalty,
            "solver-only ridge must not add a quadratic prior"
        );

        // Passport-level accessors must all be gated off by the solver-only policy.
        let passport_checks = [
            (
                passport.penalty_logdet_ridge(),
                "solver-only ridge must not shift the penalty logdet",
            ),
            (
                passport.laplacehessianridge(),
                "solver-only ridge must not shift the Laplace Hessian",
            ),
        ];
        for &(ridge, why) in &passport_checks {
            assert_eq!(ridge, 0.0, "{why}");
        }

        let ledger = StabilizationLedger::from_passport(passport);
        let is_unbounded_perturbation = matches!(
            ledger.kind,
            StabilizationKind::NumericalPerturbation {
                backward_error_bound: None
            }
        );
        assert!(
            is_unbounded_perturbation,
            "solver-only ridge is a numerical perturbation, not an explicit prior"
        );

        // The ledger must report no contribution to any objective component.
        let ledger_checks = [
            (
                ledger.quadratic_delta(),
                "solver-only ridge must not contribute to the optimized objective",
            ),
            (
                ledger.laplace_hessian_delta(),
                "solver-only ridge must not contribute to REML curvature accounting",
            ),
            (
                ledger.penalty_logdet_delta(),
                "solver-only ridge must not contribute to determinant accounting",
            ),
        ];
        for &(delta, why) in &ledger_checks {
            assert_eq!(delta, 0.0, "{why}");
        }
    }
}