Skip to main content

gam_problem/
custom_family_blockwise.rs

//! Neutral blockwise custom-family contract primitives shared by the
//! `CustomFamily` trait layer (`gam-model-api`) and the solver
//! (`gam-solve`): the IRLS weight / ridge floors, the block-spec consistency
//! validator, and the exact-Newton outer-curvature payload.
//!
//! These carry no dependency on the `CustomFamily` trait itself, so they live
//! in the neutral `gam-problem` crate and are re-exported upward, keeping a
//! single definition shared across crates.
9
10use crate::{CustomFamilyError, ParameterBlockSpec};
11use ndarray::Array2;
12use std::collections::BTreeMap;
13
14/// Floor applied to IRLS working weights so downstream divisions cannot hit
15/// exact zero. Used as the default `minweight` in `CustomFamilyOptions` and
16/// mirrored in tests that override it.
17///
18/// Sourced from the canonical positive-weight floor
19/// ([`crate::types::MIN_WEIGHT`] = `1e-12`) so every floored family shares one
20/// definition; this alias keeps the descriptive local name at the `minweight`
21/// defaults.
22pub const CUSTOM_FAMILY_WEIGHT_FLOOR: f64 = crate::types::MIN_WEIGHT;
23
/// Starting ridge δ used by the explicit-stabilization Cholesky escalation
/// schedule.
///
/// Through the active `RidgePolicy`, this δ enters three places: the
/// quadratic term, the Laplace Hessian, and the penalty log-determinant.
pub const CUSTOM_FAMILY_RIDGE_FLOOR: f64 = 1.0e-12;
28
29pub fn validate_blockspec_consistency(
30    specs: &[ParameterBlockSpec],
31) -> Result<Vec<usize>, String> {
32    let mut seen_names = BTreeMap::<String, usize>::new();
33    for (b, spec) in specs.iter().enumerate() {
34        if let Some(prev) = seen_names.insert(spec.name.clone(), b) {
35            return Err(CustomFamilyError::ConstraintViolation {
36                reason: format!(
37                    "duplicate parameter block name '{}' at indices {prev} and {b}: block names must be unique so coefficient labels resolved by name are unambiguous",
38                    spec.name
39                ),
40            }
41            .into());
42        }
43    }
44    let mut penalty_counts = Vec::with_capacity(specs.len());
45    for (b, spec) in specs.iter().enumerate() {
46        let n = spec.design.nrows();
47        if spec.offset.len() != n {
48            return Err(CustomFamilyError::DimensionMismatch {
49                reason: format!(
50                    "block {b} offset length mismatch: got {}, expected {}",
51                    spec.offset.len(),
52                    n
53                ),
54            }
55            .into());
56        }
57        // `stacked_design` and `stacked_offset` must be `Some` together
58        // and their row/length must agree.  This enforces the contract
59        // that `solver_design()` and `solver_offset()` always return a
60        // matched pair.
61        match (&spec.stacked_design, &spec.stacked_offset) {
62            (Some(sd), Some(so)) => {
63                if sd.nrows() != so.len() {
64                    return Err(CustomFamilyError::DimensionMismatch {
65                        reason: format!(
66                            "block {b} stacked_design/stacked_offset row mismatch: \
67                             stacked_design.nrows()={}, stacked_offset.len()={}",
68                            sd.nrows(),
69                            so.len(),
70                        ),
71                    }
72                    .into());
73                }
74                if sd.ncols() != spec.design.ncols() {
75                    return Err(CustomFamilyError::DimensionMismatch {
76                        reason: format!(
77                            "block {b} stacked_design column count {} disagrees with \
78                             design column count {}",
79                            sd.ncols(),
80                            spec.design.ncols(),
81                        ),
82                    }
83                    .into());
84                }
85            }
86            (None, None) => {}
87            (Some(_), None) | (None, Some(_)) => {
88                return Err(CustomFamilyError::ConstraintViolation {
89                    reason: format!(
90                        "block {b} stacked_design and stacked_offset must be Some together \
91                         or both None"
92                    ),
93                }
94                .into());
95            }
96        }
97        let p = spec.design.ncols();
98        if let Some(beta0) = &spec.initial_beta
99            && beta0.len() != p
100        {
101            return Err(CustomFamilyError::DimensionMismatch {
102                reason: format!(
103                    "block {b} initial_beta length mismatch: got {}, expected {p}",
104                    beta0.len()
105                ),
106            }
107            .into());
108        }
109        if spec.initial_log_lambdas.len() != spec.penalties.len() {
110            return Err(CustomFamilyError::DimensionMismatch {
111                reason: format!(
112                    "block {b} initial_log_lambdas length {} does not match penalties {}",
113                    spec.initial_log_lambdas.len(),
114                    spec.penalties.len()
115                ),
116            }
117            .into());
118        }
119        for (k, s) in spec.penalties.iter().enumerate() {
120            let (r, c) = s.shape();
121            if r != p || c != p {
122                return Err(CustomFamilyError::DimensionMismatch {
123                    reason: format!("block {b} penalty {k} must be {p}x{p}, got {r}x{c}"),
124                }
125                .into());
126            }
127        }
128        penalty_counts.push(spec.penalties.len());
129    }
130    Ok(penalty_counts)
131}
132
133/// Scale-aware exact joint curvature payload for the outer REML evaluator.
134pub struct ExactNewtonOuterCurvature {
135    pub hessian: Array2<f64>,
136    pub rho_curvature_scale: f64,
137    pub hessian_logdet_correction: f64,
138}