// gam_models/parameter_block.rs
use crate::custom_family::{ParameterBlockSpec, PenaltyMatrix};
use crate::model_types::PenaltySpec;
use gam_linalg::matrix::DesignMatrix;
use ndarray::Array1;

/// Gauge priority that [`ParameterBlockInput::intospec`] forwards to
/// [`ParameterBlockInput::intospec_with_gauge_priority`] when the caller does not pick one.
const DEFAULT_GAUGE_PRIORITY: u8 = 100;

8#[derive(Clone)]
10pub struct ParameterBlockInput {
11 pub design: DesignMatrix,
12 pub offset: Array1<f64>,
13 pub penalties: Vec<PenaltySpec>,
14 pub nullspace_dims: Vec<usize>,
17 pub initial_log_lambdas: Option<Array1<f64>>,
18 pub initial_beta: Option<Array1<f64>>,
19}
20
21impl ParameterBlockInput {
22 pub fn intospec(self, name: &str) -> Result<ParameterBlockSpec, String> {
23 self.intospec_with_gauge_priority(name, DEFAULT_GAUGE_PRIORITY)
24 }
25
26 pub fn intospec_with_gauge_priority(
27 self,
28 name: &str,
29 gauge_priority: u8,
30 ) -> Result<ParameterBlockSpec, String> {
31 let p = self.design.ncols();
32 let n = self.design.nrows();
33 if self.offset.len() != n {
34 return Err(format!(
35 "block '{name}' offset length mismatch: got {}, expected {n}",
36 self.offset.len()
37 ));
38 }
39 if let Some(beta0) = &self.initial_beta
40 && beta0.len() != p
41 {
42 return Err(format!(
43 "block '{name}' initial_beta length mismatch: got {}, expected {p}",
44 beta0.len()
45 ));
46 }
47 for (k, s) in self.penalties.iter().enumerate() {
48 match s {
49 PenaltySpec::Block {
50 local, col_range, ..
51 } => {
52 if col_range.end > p
53 || local.nrows() != col_range.len()
54 || local.ncols() != col_range.len()
55 {
56 return Err(format!(
57 "block '{name}' penalty {k} block shape mismatch: col_range={}..{}, local={}x{}, total_dim={p}",
58 col_range.start,
59 col_range.end,
60 local.nrows(),
61 local.ncols()
62 ));
63 }
64 }
65 PenaltySpec::Dense(m) | PenaltySpec::DenseWithMean { matrix: m, .. } => {
66 let (r, c) = m.dim();
67 if r != p || c != p {
68 return Err(format!(
69 "block '{name}' penalty {k} must be {p}x{p}, got {r}x{c}"
70 ));
71 }
72 }
73 }
74 }
75 let k = self.penalties.len();
76 let initial_log_lambdas = self
77 .initial_log_lambdas
78 .unwrap_or_else(|| Array1::<f64>::zeros(k));
79 if initial_log_lambdas.len() != k {
80 return Err(format!(
81 "block '{name}' initial_log_lambdas length mismatch: got {}, expected {k}",
82 initial_log_lambdas.len()
83 ));
84 }
85 Ok(ParameterBlockSpec {
86 name: name.to_string(),
87 design: self.design,
88 offset: self.offset,
89 penalties: {
90 self.penalties
91 .into_iter()
92 .map(|spec| match spec {
93 PenaltySpec::Block {
94 local, col_range, ..
95 } => PenaltyMatrix::Blockwise {
96 local,
97 col_range,
98 total_dim: p,
99 },
100 PenaltySpec::Dense(m) | PenaltySpec::DenseWithMean { matrix: m, .. } => {
101 PenaltyMatrix::Dense(m)
102 }
103 })
104 .collect()
105 },
106 nullspace_dims: self.nullspace_dims,
107 initial_log_lambdas,
108 initial_beta: self.initial_beta,
109 gauge_priority,
110 jacobian_callback: None,
111 stacked_design: None,
112 stacked_offset: None,
113 })
114 }
115}