Skip to main content

gam_models/
parameter_block.rs

1use crate::custom_family::{ParameterBlockSpec, PenaltyMatrix};
2use gam_linalg::matrix::DesignMatrix;
3use crate::model_types::PenaltySpec;
4use ndarray::Array1;
5
/// Gauge priority that `ParameterBlockInput::intospec` passes along when the
/// caller does not choose one explicitly.
const DEFAULT_GAUGE_PRIORITY: u8 = 100;
7
/// Generic block input for high-level built-in family APIs.
///
/// Converted into a `ParameterBlockSpec` by `intospec` /
/// `intospec_with_gauge_priority`, which check the dimensional relationships
/// noted on the fields below (`n` = `design.nrows()`, `p` = `design.ncols()`).
#[derive(Clone)]
pub struct ParameterBlockInput {
    /// Design matrix for the block; its shape defines `n` and `p`.
    pub design: DesignMatrix,
    /// Offset vector; must have length `n`.
    pub offset: Array1<f64>,
    /// Penalties for the block. Dense penalties must be `p x p`; block
    /// penalties must be square, sized to their column range, and that range
    /// must end within `p`.
    pub penalties: Vec<PenaltySpec>,
    /// Structural nullspace dimension per penalty (same length as `penalties`).
    /// Empty means "use eigenvalue-based rank detection."
    pub nullspace_dims: Vec<usize>,
    /// Optional starting log-lambda values, one per entry in `penalties`.
    /// `None` is replaced by an all-zeros vector during conversion.
    pub initial_log_lambdas: Option<Array1<f64>>,
    /// Optional starting coefficients; must have length `p` when provided.
    pub initial_beta: Option<Array1<f64>>,
}
20
21impl ParameterBlockInput {
22    pub fn intospec(self, name: &str) -> Result<ParameterBlockSpec, String> {
23        self.intospec_with_gauge_priority(name, DEFAULT_GAUGE_PRIORITY)
24    }
25
26    pub fn intospec_with_gauge_priority(
27        self,
28        name: &str,
29        gauge_priority: u8,
30    ) -> Result<ParameterBlockSpec, String> {
31        let p = self.design.ncols();
32        let n = self.design.nrows();
33        if self.offset.len() != n {
34            return Err(format!(
35                "block '{name}' offset length mismatch: got {}, expected {n}",
36                self.offset.len()
37            ));
38        }
39        if let Some(beta0) = &self.initial_beta
40            && beta0.len() != p
41        {
42            return Err(format!(
43                "block '{name}' initial_beta length mismatch: got {}, expected {p}",
44                beta0.len()
45            ));
46        }
47        for (k, s) in self.penalties.iter().enumerate() {
48            match s {
49                PenaltySpec::Block {
50                    local, col_range, ..
51                } => {
52                    if col_range.end > p
53                        || local.nrows() != col_range.len()
54                        || local.ncols() != col_range.len()
55                    {
56                        return Err(format!(
57                            "block '{name}' penalty {k} block shape mismatch: col_range={}..{}, local={}x{}, total_dim={p}",
58                            col_range.start,
59                            col_range.end,
60                            local.nrows(),
61                            local.ncols()
62                        ));
63                    }
64                }
65                PenaltySpec::Dense(m) | PenaltySpec::DenseWithMean { matrix: m, .. } => {
66                    let (r, c) = m.dim();
67                    if r != p || c != p {
68                        return Err(format!(
69                            "block '{name}' penalty {k} must be {p}x{p}, got {r}x{c}"
70                        ));
71                    }
72                }
73            }
74        }
75        let k = self.penalties.len();
76        let initial_log_lambdas = self
77            .initial_log_lambdas
78            .unwrap_or_else(|| Array1::<f64>::zeros(k));
79        if initial_log_lambdas.len() != k {
80            return Err(format!(
81                "block '{name}' initial_log_lambdas length mismatch: got {}, expected {k}",
82                initial_log_lambdas.len()
83            ));
84        }
85        Ok(ParameterBlockSpec {
86            name: name.to_string(),
87            design: self.design,
88            offset: self.offset,
89            penalties: {
90                self.penalties
91                    .into_iter()
92                    .map(|spec| match spec {
93                        PenaltySpec::Block {
94                            local, col_range, ..
95                        } => PenaltyMatrix::Blockwise {
96                            local,
97                            col_range,
98                            total_dim: p,
99                        },
100                        PenaltySpec::Dense(m) | PenaltySpec::DenseWithMean { matrix: m, .. } => {
101                            PenaltyMatrix::Dense(m)
102                        }
103                    })
104                    .collect()
105            },
106            nullspace_dims: self.nullspace_dims,
107            initial_log_lambdas,
108            initial_beta: self.initial_beta,
109            gauge_priority,
110            jacobian_callback: None,
111            stacked_design: None,
112            stacked_offset: None,
113        })
114    }
115}