1use crate::basis::{
2 BasisOptions, Dense, KnotSource, create_basis, create_difference_penalty_matrix,
3};
4use crate::custom_family::{
5 BlockWorkingSet, BlockwiseFitOptions, BlockwiseFitResult, CustomFamily, FamilyEvaluation,
6 KnownLinkWiggle, ParameterBlockSpec, ParameterBlockState, fit_custom_family,
7};
8use crate::faer_ndarray::{fast_ata, fast_atv};
9use crate::generative::{CustomFamilyGenerative, GenerativeSpec, NoiseModel};
10use crate::matrix::DesignMatrix;
11use crate::pirls::WorkingLikelihood as EngineWorkingLikelihood;
12use crate::probability::{normal_cdf_approx, normal_pdf};
13use crate::types::{LikelihoodFamily, LinkFunction};
14use faer::Mat as FaerMat;
15use faer::Side;
16use faer::linalg::solvers::{
17 Lblt as FaerLblt, Ldlt as FaerLdlt, Llt as FaerLlt, Solve as FaerSolve,
18};
19use ndarray::{Array1, Array2, ArrayView1, s};
20
/// Clamp margin for predicted probabilities (kept in `[MIN_PROB, 1 - MIN_PROB]`)
/// and floor for the Bernoulli variance `mu * (1 - mu)`, so logs and divisions stay finite.
const MIN_PROB: f64 = 1e-10;
/// Floor on `|d mu / d eta|` (and on the normal density) used when forming IRLS
/// working responses, to avoid dividing by (near) zero.
const MIN_DERIV: f64 = 1e-8;
/// Floor applied to every IRLS working weight.
const MIN_WEIGHT: f64 = 1e-12;
/// `max(beta) / min(beta)` ratio above which `emit_binomial_alpha_beta_warnings` warns.
const BETA_RANGE_WARN_THRESHOLD: f64 = 1.10;
/// Effective binomial sample size `sum(w) * p * (1 - p)` below which
/// `emit_binomial_alpha_beta_warnings` warns.
const BINOMIAL_EFFECTIVE_N_WARN_THRESHOLD: f64 = 25.0;
26
/// Caller-supplied description of one parameter block, prior to validation.
///
/// Converted into a [`ParameterBlockSpec`] by [`ParameterBlockInput::into_spec`],
/// which checks every field against the shape of `design`.
#[derive(Clone)]
pub struct ParameterBlockInput {
    /// Design matrix: one row per observation, one column per coefficient.
    pub design: DesignMatrix,
    /// Additive offset of the block's linear predictor; one entry per design row.
    pub offset: Array1<f64>,
    /// Penalty matrices, each square with side equal to the number of design columns.
    pub penalties: Vec<Array2<f64>>,
    /// Starting log smoothing parameters, one per penalty; `None` means all zeros.
    pub initial_log_lambdas: Option<Array1<f64>>,
    /// Starting coefficients, one per design column; `None` leaves initialisation
    /// to the fitter.
    pub initial_beta: Option<Array1<f64>>,
}

/// Static, human-readable description of a distribution family.
#[derive(Clone, Debug)]
pub struct FamilyMetadata {
    /// Short identifier of the family.
    pub name: &'static str,
    /// Names of the distributional parameters, in block order.
    pub parameter_names: &'static [&'static str],
    /// Link associated with each parameter, in the same order as `parameter_names`.
    pub parameter_links: &'static [ParameterLink],
}

/// Configuration of a penalised spline "wiggle" block built by
/// `build_wiggle_block_input_from_seed`.
#[derive(Clone, Debug)]
pub struct WiggleBlockConfig {
    /// Spline degree.
    pub degree: usize,
    /// Number of internal knots generated over the seed's range.
    pub num_internal_knots: usize,
    /// Order of the difference penalty on adjacent coefficients.
    pub penalty_order: usize,
    /// When true, an identity (ridge) penalty is added next to the difference penalty.
    pub double_penalty: bool,
}
51
52impl ParameterBlockInput {
53 pub fn into_spec(self, name: &str) -> Result<ParameterBlockSpec, String> {
54 let p = self.design.ncols();
55 let n = self.design.nrows();
56 if self.offset.len() != n {
57 return Err(format!(
58 "block '{name}' offset length mismatch: got {}, expected {n}",
59 self.offset.len()
60 ));
61 }
62 if let Some(beta0) = &self.initial_beta {
63 if beta0.len() != p {
64 return Err(format!(
65 "block '{name}' initial_beta length mismatch: got {}, expected {p}",
66 beta0.len()
67 ));
68 }
69 }
70 for (k, s) in self.penalties.iter().enumerate() {
71 let (r, c) = s.dim();
72 if r != p || c != p {
73 return Err(format!(
74 "block '{name}' penalty {k} must be {p}x{p}, got {r}x{c}"
75 ));
76 }
77 }
78 let k = self.penalties.len();
79 let initial_log_lambdas = self
80 .initial_log_lambdas
81 .unwrap_or_else(|| Array1::<f64>::zeros(k));
82 if initial_log_lambdas.len() != k {
83 return Err(format!(
84 "block '{name}' initial_log_lambdas length mismatch: got {}, expected {k}",
85 initial_log_lambdas.len()
86 ));
87 }
88 Ok(ParameterBlockSpec {
89 name: name.to_string(),
90 design: self.design,
91 offset: self.offset,
92 penalties: self.penalties,
93 initial_log_lambdas,
94 initial_beta: self.initial_beta,
95 })
96 }
97}
98
/// Checks that `[sigma_min, sigma_max]` is a valid, strictly positive, finite interval.
///
/// # Errors
/// Returns a message prefixed with `context` for non-finite bounds, non-positive
/// bounds, or `sigma_min > sigma_max` (checked in that order).
fn validate_sigma_bounds(sigma_min: f64, sigma_max: f64, context: &str) -> Result<(), String> {
    let problem = if !(sigma_min.is_finite() && sigma_max.is_finite()) {
        Some(format!("{context}: sigma bounds must be finite"))
    } else if sigma_min.min(sigma_max) <= 0.0 {
        Some(format!(
            "{context}: sigma bounds must be strictly positive (got min={sigma_min}, max={sigma_max})"
        ))
    } else if sigma_min > sigma_max {
        Some(format!(
            "{context}: sigma_min ({sigma_min}) must be <= sigma_max ({sigma_max})"
        ))
    } else {
        None
    };
    problem.map_or(Ok(()), Err)
}
115
/// Returns `Ok(())` when `found == expected`; otherwise an error of the form
/// `"<name> length mismatch: expected <expected>, found <found>"`.
fn validate_len_match(name: &str, expected: usize, found: usize) -> Result<(), String> {
    if expected == found {
        Ok(())
    } else {
        Err(format!(
            "{name} length mismatch: expected {expected}, found {found}"
        ))
    }
}
124
125fn validate_weights(weights: &Array1<f64>, context: &str) -> Result<(), String> {
126 for (i, &w) in weights.iter().enumerate() {
127 if !w.is_finite() || w < 0.0 {
128 return Err(format!(
129 "{context}: weights must be finite and non-negative; found weights[{i}]={w}"
130 ));
131 }
132 }
133 Ok(())
134}
135
136fn validate_binomial_response(y: &Array1<f64>, context: &str) -> Result<(), String> {
137 for (i, &yi) in y.iter().enumerate() {
138 if !yi.is_finite() || !(0.0..=1.0).contains(&yi) {
139 return Err(format!(
140 "{context}: binomial response must be finite in [0,1]; found y[{i}]={yi}"
141 ));
142 }
143 }
144 Ok(())
145}
146
147pub fn initialize_wiggle_knots_from_seed(
148 seed: ArrayView1<'_, f64>,
149 degree: usize,
150 num_internal_knots: usize,
151) -> Result<Array1<f64>, String> {
152 let seed_min = seed.iter().copied().fold(f64::INFINITY, f64::min);
153 let mut seed_max = seed.iter().copied().fold(f64::NEG_INFINITY, f64::max);
154 if !seed_min.is_finite() || !seed_max.is_finite() {
155 return Err("non-finite seed for wiggle knot initialization".to_string());
156 }
157 if (seed_max - seed_min).abs() < 1e-12 {
158 seed_max = seed_min + 1e-6;
159 }
160 let (_, knots) = create_basis::<Dense>(
161 seed,
162 KnotSource::Generate {
163 data_range: (seed_min, seed_max),
164 num_internal_knots,
165 },
166 degree,
167 BasisOptions::value(),
168 )
169 .map_err(|e| e.to_string())?;
170 Ok(knots)
171}
172
173pub fn build_wiggle_block_input_from_knots(
174 seed: ArrayView1<'_, f64>,
175 knots: &Array1<f64>,
176 degree: usize,
177 penalty_order: usize,
178 double_penalty: bool,
179) -> Result<ParameterBlockInput, String> {
180 let (basis, _) = create_basis::<Dense>(
181 seed,
182 KnotSource::Provided(knots.view()),
183 degree,
184 BasisOptions::value(),
185 )
186 .map_err(|e| e.to_string())?;
187 let full = (*basis).clone();
188 if full.ncols() < 2 {
189 return Err("wiggle basis has fewer than two columns".to_string());
190 }
191 let design = full.slice(s![.., 1..]).to_owned();
192 let p = design.ncols();
193 let mut penalties =
194 vec![create_difference_penalty_matrix(p, penalty_order, None).map_err(|e| e.to_string())?];
195 if double_penalty {
196 penalties.push(Array2::<f64>::eye(p));
197 }
198 Ok(ParameterBlockInput {
199 design: DesignMatrix::Dense(design),
200 offset: Array1::zeros(seed.len()),
201 penalties,
202 initial_log_lambdas: None,
203 initial_beta: None,
204 })
205}
206
207pub fn build_wiggle_block_input_from_seed(
208 seed: ArrayView1<'_, f64>,
209 cfg: &WiggleBlockConfig,
210) -> Result<(ParameterBlockInput, Array1<f64>), String> {
211 let knots = initialize_wiggle_knots_from_seed(seed, cfg.degree, cfg.num_internal_knots)?;
212 let block = build_wiggle_block_input_from_knots(
213 seed,
214 &knots,
215 cfg.degree,
216 cfg.penalty_order,
217 cfg.double_penalty,
218 )?;
219 Ok((block, knots))
220}
221
222fn validate_block_rows(name: &str, n: usize, block: &ParameterBlockInput) -> Result<(), String> {
223 validate_len_match(
224 &format!("block '{name}' offset vs response"),
225 n,
226 block.offset.len(),
227 )?;
228 validate_len_match(
229 &format!("block '{name}' design rows vs response"),
230 n,
231 block.design.nrows(),
232 )
233}
234
235fn evaluate_single_block_glm(
238 family: LikelihoodFamily,
239 y: &Array1<f64>,
240 weights: &Array1<f64>,
241 eta: &Array1<f64>,
242) -> Result<FamilyEvaluation, String> {
243 let n = y.len();
244 if eta.len() != n || weights.len() != n {
245 return Err("single-block GLM input size mismatch".to_string());
246 }
247 let mut mu = Array1::<f64>::zeros(n);
248 let mut z = Array1::<f64>::zeros(n);
249 let mut w = Array1::<f64>::zeros(n);
250 family
251 .irls_update(y.view(), eta, weights.view(), &mut mu, &mut w, &mut z, None)
252 .map_err(|e| e.to_string())?;
253 let ll = family
254 .log_likelihood(y.view(), eta, &mu, weights.view())
255 .map_err(|e| e.to_string())?;
256 Ok(FamilyEvaluation {
257 log_likelihood: ll,
258 block_working_sets: vec![BlockWorkingSet {
259 working_response: z,
260 working_weights: w,
261 gradient_eta: None,
262 }],
263 })
264}
265
266fn initial_log_lambdas_or_zeros(block: &ParameterBlockInput) -> Result<Array1<f64>, String> {
267 let k = block.penalties.len();
268 let lambdas = block
269 .initial_log_lambdas
270 .clone()
271 .unwrap_or_else(|| Array1::<f64>::zeros(k));
272 if lambdas.len() != k {
273 return Err(format!(
274 "initial_log_lambdas length mismatch: got {}, expected {}",
275 lambdas.len(),
276 k
277 ));
278 }
279 Ok(lambdas)
280}
281
282fn solve_weighted_projection(
283 design: &DesignMatrix,
284 offset: &Array1<f64>,
285 target_eta: &Array1<f64>,
286 weights: &Array1<f64>,
287 ridge_floor: f64,
288) -> Result<Array1<f64>, String> {
289 let n = design.nrows();
290 let p = design.ncols();
291 if offset.len() != n || target_eta.len() != n || weights.len() != n {
292 return Err("solve_weighted_projection dimension mismatch".to_string());
293 }
294
295 let (mut xtwx, xtwy) = match design {
296 DesignMatrix::Dense(x) => {
297 let mut xw = x.clone();
298 for i in 0..n {
299 let sw = weights[i].max(0.0).sqrt();
300 if sw != 1.0 {
301 let mut row = xw.row_mut(i);
302 row *= sw;
303 }
304 }
305 let xtwx = fast_ata(&xw);
306 let mut y_w = target_eta - offset;
307 for i in 0..n {
308 y_w[i] *= weights[i].max(0.0).sqrt();
309 }
310 let xtwy = fast_atv(&xw, &y_w);
311 (xtwx, xtwy)
312 }
313 DesignMatrix::Sparse(xs) => {
314 let csr = xs
315 .as_ref()
316 .to_row_major()
317 .map_err(|_| "failed to obtain CSR view for weighted projection".to_string())?;
318 let sym = csr.symbolic();
319 let row_ptr = sym.row_ptr();
320 let col_idx = sym.col_idx();
321 let vals = csr.val();
322 let mut xtwx = Array2::<f64>::zeros((p, p));
323 let mut xtwy = Array1::<f64>::zeros(p);
324
325 for i in 0..n {
326 let wi = weights[i].max(0.0);
327 if wi == 0.0 {
328 continue;
329 }
330 let y_star = target_eta[i] - offset[i];
331 let start = row_ptr[i];
332 let end = row_ptr[i + 1];
333 for a_ptr in start..end {
334 let a = col_idx[a_ptr];
335 let xa = vals[a_ptr];
336 xtwy[a] += wi * xa * y_star;
337 for b_ptr in a_ptr..end {
338 let b = col_idx[b_ptr];
339 let xb = vals[b_ptr];
340 xtwx[[a, b]] += wi * xa * xb;
341 }
342 }
343 }
344 for a in 0..p {
345 for b in 0..a {
346 xtwx[[a, b]] = xtwx[[b, a]];
347 }
348 }
349 (xtwx, xtwy)
350 }
351 };
352 for a in 0..p {
353 xtwx[[a, a]] += ridge_floor.max(1e-12);
354 }
355
356 let h = crate::faer_ndarray::FaerArrayView::new(&xtwx);
357 let mut rhs_mat = FaerMat::zeros(p, 1);
358 for i in 0..p {
359 rhs_mat[(i, 0)] = xtwy[i];
360 }
361
362 if let Ok(ch) = FaerLlt::new(h.as_ref(), Side::Lower) {
363 ch.solve_in_place(rhs_mat.as_mut());
364 } else if let Ok(ld) = FaerLdlt::new(h.as_ref(), Side::Lower) {
365 ld.solve_in_place(rhs_mat.as_mut());
366 } else {
367 let lb = FaerLblt::new(h.as_ref(), Side::Lower);
368 lb.solve_in_place(rhs_mat.as_mut());
369 }
370
371 let mut beta = Array1::<f64>::zeros(p);
372 for i in 0..p {
373 beta[i] = rhs_mat[(i, 0)];
374 }
375 if beta.iter().any(|v| !v.is_finite()) {
376 return Err("solve_weighted_projection produced non-finite coefficients".to_string());
377 }
378 Ok(beta)
379}
380
381fn weighted_prevalence(y: &Array1<f64>, weights: &Array1<f64>) -> f64 {
382 let w_sum: f64 = weights.iter().copied().sum();
383 if w_sum <= 0.0 {
384 return 0.5;
385 }
386 let y_w_sum: f64 = y.iter().zip(weights.iter()).map(|(&yi, &wi)| yi * wi).sum();
387 (y_w_sum / w_sum).clamp(0.0, 1.0)
388}
389
390fn emit_binomial_alpha_beta_warnings(
391 context: &str,
392 beta_values: &Array1<f64>,
393 y: &Array1<f64>,
394 weights: &Array1<f64>,
395) {
396 if beta_values.is_empty() {
397 return;
398 }
399 let beta_min = beta_values.iter().copied().fold(f64::INFINITY, f64::min);
400 let beta_max = beta_values
401 .iter()
402 .copied()
403 .fold(f64::NEG_INFINITY, f64::max);
404
405 if !beta_min.is_finite() || !beta_max.is_finite() || beta_min <= 0.0 {
406 log::warn!(
407 "[GAMLSS][{}] non-positive or non-finite beta encountered (min={}, max={})",
408 context,
409 beta_min,
410 beta_max
411 );
412 } else {
413 let ratio = beta_max / beta_min;
414 if ratio > BETA_RANGE_WARN_THRESHOLD {
415 log::warn!(
416 "[GAMLSS][{}] beta range ratio {:.3} exceeds {:.3}; transformed-penalty distortion risk is elevated",
417 context,
418 ratio,
419 BETA_RANGE_WARN_THRESHOLD
420 );
421 }
422 }
423
424 let pi = weighted_prevalence(y, weights);
425 let w_sum: f64 = weights.iter().copied().sum();
426 let n_eff = w_sum * pi * (1.0 - pi);
427 if n_eff < BINOMIAL_EFFECTIVE_N_WARN_THRESHOLD {
428 log::warn!(
429 "[GAMLSS][{}] low effective sample size N_eff={:.3} (sum_w={:.3}, prevalence={:.3}); location-scale separation artifacts are more likely",
430 context,
431 n_eff,
432 w_sum,
433 pi
434 );
435 }
436}
437
/// Auxiliary two-block probit family used only to compute a warm start for the
/// binomial location-scale fits.
///
/// Its probit index is `q = alpha + beta * score`, with `beta` clamped to
/// `[beta_min, beta_max]` (see the `CustomFamily` impl).
#[derive(Clone)]
struct BinomialAlphaBetaWarmStartFamily {
    /// Responses; the callers in this file validate them to lie in `[0, 1]`.
    y: Array1<f64>,
    /// Per-observation score multiplied by `beta` in the probit index.
    score: Array1<f64>,
    /// Prior observation weights.
    weights: Array1<f64>,
    /// Lower clamp for `beta` (`try_binomial_alpha_beta_warm_start` sets it to `1 / sigma_max`).
    beta_min: f64,
    /// Upper clamp for `beta` (`try_binomial_alpha_beta_warm_start` sets it to `1 / sigma_min`).
    beta_max: f64,
}

impl BinomialAlphaBetaWarmStartFamily {
    /// Position of the `alpha` block in the block-state slice.
    const BLOCK_ALPHA: usize = 0;
    /// Position of the `beta` block in the block-state slice.
    const BLOCK_BETA: usize = 1;
}
451
impl CustomFamily for BinomialAlphaBetaWarmStartFamily {
    /// Evaluates the probit log-likelihood with index `q = alpha + beta * score`.
    /// Returns one IRLS working set (working response, working weight) per block.
    ///
    /// `beta` is clamped to `[beta_min, beta_max]`. Where the raw value lies
    /// outside that range, its derivative is treated as zero, so the beta block
    /// only receives the `MIN_WEIGHT` floor for that observation.
    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
        if block_states.len() != 2 {
            return Err(format!(
                "BinomialAlphaBetaWarmStartFamily expects 2 blocks, got {}",
                block_states.len()
            ));
        }
        let n = self.y.len();
        let eta_alpha = &block_states[Self::BLOCK_ALPHA].eta;
        let eta_beta = &block_states[Self::BLOCK_BETA].eta;
        if eta_alpha.len() != n
            || eta_beta.len() != n
            || self.score.len() != n
            || self.weights.len() != n
        {
            return Err("BinomialAlphaBetaWarmStartFamily input size mismatch".to_string());
        }

        let mut z_alpha = Array1::<f64>::zeros(n);
        let mut w_alpha = Array1::<f64>::zeros(n);
        let mut z_beta = Array1::<f64>::zeros(n);
        let mut w_beta = Array1::<f64>::zeros(n);
        let mut ll = 0.0_f64;

        for i in 0..n {
            let raw_beta = eta_beta[i];
            let beta = raw_beta.clamp(self.beta_min, self.beta_max);
            // Derivative of the clamp: 1 inside the bounds, 0 where clamped.
            let dbeta_deta = if raw_beta >= self.beta_min && raw_beta <= self.beta_max {
                1.0
            } else {
                0.0
            };
            let q = eta_alpha[i] + beta * self.score[i];
            // Probability is kept away from 0 and 1 so both log terms stay finite.
            let mu = normal_cdf_approx(q).clamp(MIN_PROB, 1.0 - MIN_PROB);
            let dmu_dq = normal_pdf(q).max(MIN_DERIV);
            // Bernoulli variance, floored to avoid division by zero.
            let var = (mu * (1.0 - mu)).max(MIN_PROB);

            ll += self.weights[i] * (self.y[i] * mu.ln() + (1.0 - self.y[i]) * (1.0 - mu).ln());

            // alpha enters q with coefficient 1, so d mu / d eta_alpha = phi(q).
            // Fisher-scoring weight: prior * (d mu/d eta)^2 / var.
            // Working response: eta + (y - mu) / (d mu/d eta), with the derivative
            // magnitude floored at MIN_DERIV.
            let dmu_alpha = dmu_dq;
            w_alpha[i] = (self.weights[i] * (dmu_alpha * dmu_alpha / var)).max(MIN_WEIGHT);
            z_alpha[i] = eta_alpha[i] + (y_minus_mu_placeholder_never_used_guard(self.y[i], mu)) / signed_with_floor(dmu_alpha, MIN_DERIV);

            // d q / d eta_beta = score * d beta / d eta_beta.
            let chain_beta = self.score[i] * dbeta_deta;
            let dmu_beta = dmu_dq * chain_beta;
            w_beta[i] = (self.weights[i] * (dmu_beta * dmu_beta / var)).max(MIN_WEIGHT);
            z_beta[i] = eta_beta[i] + (self.y[i] - mu) / signed_with_floor(dmu_beta, MIN_DERIV);
        }

        Ok(FamilyEvaluation {
            log_likelihood: ll,
            block_working_sets: vec![
                BlockWorkingSet {
                    working_response: z_alpha,
                    working_weights: w_alpha,
                    gradient_eta: None,
                },
                BlockWorkingSet {
                    working_response: z_beta,
                    working_weights: w_beta,
                    gradient_eta: None,
                },
            ],
        })
    }

    /// Clamps every entry of the beta block's updated vector to
    /// `[beta_min, beta_max]`. Other blocks pass through unchanged.
    ///
    /// NOTE(review): this may be clamping coefficients rather than the linear
    /// predictor — confirm against `fit_custom_family`. The name and
    /// `ParameterBlockSpec::initial_beta` both suggest `beta` is the block's
    /// coefficient vector. If so, the clamp forces every coefficient into a
    /// positive interval, because `try_binomial_alpha_beta_warm_start` sets both
    /// bounds above zero. That is only equivalent to bounding `eta_beta` for an
    /// intercept-only design.
    fn post_update_beta(
        &self,
        block_index: usize,
        beta: Array1<f64>,
    ) -> Result<Array1<f64>, String> {
        if block_index != Self::BLOCK_BETA {
            return Ok(beta);
        }
        Ok(beta.mapv(|v| v.clamp(self.beta_min, self.beta_max)))
    }
}
530
531fn try_binomial_alpha_beta_warm_start(
532 y: &Array1<f64>,
533 score: &Array1<f64>,
534 weights: &Array1<f64>,
535 sigma_min: f64,
536 sigma_max: f64,
537 threshold_block: &ParameterBlockInput,
538 log_sigma_block: &ParameterBlockInput,
539 options: &BlockwiseFitOptions,
540) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
541 let beta_min = (1.0 / sigma_max.max(1e-12)).max(1e-12);
542 let beta_max = (1.0 / sigma_min.max(1e-12)).max(beta_min + 1e-12);
543 let warm_family = BinomialAlphaBetaWarmStartFamily {
544 y: y.clone(),
545 score: score.clone(),
546 weights: weights.clone(),
547 beta_min,
548 beta_max,
549 };
550
551 let alpha_spec = ParameterBlockSpec {
552 name: "alpha_warm".to_string(),
553 design: threshold_block.design.clone(),
554 offset: threshold_block.offset.clone(),
555 penalties: threshold_block.penalties.clone(),
556 initial_log_lambdas: initial_log_lambdas_or_zeros(threshold_block)?,
557 initial_beta: None,
558 };
559 let beta_spec = ParameterBlockSpec {
560 name: "beta_warm".to_string(),
561 design: log_sigma_block.design.clone(),
562 offset: log_sigma_block.offset.clone(),
563 penalties: log_sigma_block.penalties.clone(),
564 initial_log_lambdas: initial_log_lambdas_or_zeros(log_sigma_block)?,
565 initial_beta: None,
566 };
567
568 let warm_options = BlockwiseFitOptions {
569 inner_max_cycles: options.inner_max_cycles.min(40).max(5),
570 inner_tol: options.inner_tol,
571 outer_max_iter: options.outer_max_iter.min(20).max(3),
572 outer_tol: options.outer_tol.max(1e-6),
573 min_weight: options.min_weight,
574 ridge_floor: options.ridge_floor.max(1e-10),
575 ridge_policy: options.ridge_policy,
576 use_reml_objective: false,
578 };
579 let warm_fit = fit_custom_family(&warm_family, &[alpha_spec, beta_spec], &warm_options)?;
580 let eta_alpha = &warm_fit.block_states[BinomialAlphaBetaWarmStartFamily::BLOCK_ALPHA].eta;
581 let eta_beta = &warm_fit.block_states[BinomialAlphaBetaWarmStartFamily::BLOCK_BETA].eta;
582 if eta_alpha.len() != y.len() || eta_beta.len() != y.len() {
583 return Err("warm start eta length mismatch".to_string());
584 }
585
586 let beta_obs = eta_beta.mapv(|v| v.clamp(beta_min, beta_max));
587 let t_target = Array1::from_iter(
588 eta_alpha
589 .iter()
590 .zip(beta_obs.iter())
591 .map(|(&a, &b)| -a / b.max(1e-12)),
592 );
593 let log_sigma_target = beta_obs.mapv(|b| -b.max(1e-12).ln());
594
595 let beta_t = solve_weighted_projection(
596 &threshold_block.design,
597 &threshold_block.offset,
598 &t_target,
599 weights,
600 options.ridge_floor.max(1e-10),
601 )?;
602 let beta_log_sigma = solve_weighted_projection(
603 &log_sigma_block.design,
604 &log_sigma_block.offset,
605 &log_sigma_target,
606 weights,
607 options.ridge_floor.max(1e-10),
608 )?;
609
610 Ok((beta_t, beta_log_sigma, beta_obs))
611}
612
/// Inputs for [`fit_gaussian_location_scale`].
#[derive(Clone)]
pub struct GaussianLocationScaleSpec {
    /// Observed responses.
    pub y: Array1<f64>,
    /// Prior observation weights (finite, non-negative).
    pub weights: Array1<f64>,
    /// Lower bound applied to sigma.
    pub sigma_min: f64,
    /// Upper bound applied to sigma.
    pub sigma_max: f64,
    /// Block for the mean.
    pub mu_block: ParameterBlockInput,
    /// Block for log(sigma).
    pub log_sigma_block: ParameterBlockInput,
}

/// Inputs for [`fit_binomial_logit`].
#[derive(Clone)]
pub struct BinomialLogitSpec {
    /// Responses in `[0, 1]`.
    pub y: Array1<f64>,
    /// Prior observation weights (finite, non-negative).
    pub weights: Array1<f64>,
    /// Block for the linear predictor.
    pub eta_block: ParameterBlockInput,
}

/// Inputs for [`fit_poisson_log`].
#[derive(Clone)]
pub struct PoissonLogSpec {
    /// Observed responses.
    pub y: Array1<f64>,
    /// Prior observation weights (finite, non-negative).
    pub weights: Array1<f64>,
    /// Block for the linear predictor.
    pub eta_block: ParameterBlockInput,
}

/// Inputs for [`fit_gamma_log`].
#[derive(Clone)]
pub struct GammaLogSpec {
    /// Observed responses.
    pub y: Array1<f64>,
    /// Prior observation weights (finite, non-negative).
    pub weights: Array1<f64>,
    /// Gamma shape parameter; must be finite and strictly positive.
    pub shape: f64,
    /// Block for the linear predictor.
    pub eta_block: ParameterBlockInput,
}

/// Inputs for [`fit_binomial_location_scale_probit`].
#[derive(Clone)]
pub struct BinomialLocationScaleProbitSpec {
    /// Responses in `[0, 1]`.
    pub y: Array1<f64>,
    /// Per-observation score compared against the threshold; one per observation.
    pub score: Array1<f64>,
    /// Prior observation weights (finite, non-negative).
    pub weights: Array1<f64>,
    /// Lower bound applied to sigma.
    pub sigma_min: f64,
    /// Upper bound applied to sigma.
    pub sigma_max: f64,
    /// Block for the threshold.
    pub threshold_block: ParameterBlockInput,
    /// Block for log(sigma).
    pub log_sigma_block: ParameterBlockInput,
}

/// Inputs for [`fit_binomial_location_scale_probit_wiggle`]: the probit
/// location-scale inputs plus a spline "wiggle" block.
#[derive(Clone)]
pub struct BinomialLocationScaleProbitWiggleSpec {
    /// Responses in `[0, 1]`.
    pub y: Array1<f64>,
    /// Per-observation score compared against the threshold; one per observation.
    pub score: Array1<f64>,
    /// Prior observation weights (finite, non-negative).
    pub weights: Array1<f64>,
    /// Lower bound applied to sigma.
    pub sigma_min: f64,
    /// Upper bound applied to sigma.
    pub sigma_max: f64,
    /// Knot vector of the wiggle spline; needs at least `wiggle_degree + 2` entries.
    pub wiggle_knots: Array1<f64>,
    /// Degree of the wiggle spline; must be at least 1.
    pub wiggle_degree: usize,
    /// Block for the threshold.
    pub threshold_block: ParameterBlockInput,
    /// Block for log(sigma).
    pub log_sigma_block: ParameterBlockInput,
    /// Block for the wiggle coefficients.
    pub wiggle_block: ParameterBlockInput,
}
670
671pub fn fit_gaussian_location_scale(
672 spec: GaussianLocationScaleSpec,
673 options: &BlockwiseFitOptions,
674) -> Result<BlockwiseFitResult, String> {
675 let n = spec.y.len();
676 validate_len_match("weights vs y", n, spec.weights.len())?;
677 validate_weights(&spec.weights, "fit_gaussian_location_scale")?;
678 validate_sigma_bounds(
679 spec.sigma_min,
680 spec.sigma_max,
681 "fit_gaussian_location_scale",
682 )?;
683 validate_block_rows("mu", n, &spec.mu_block)?;
684 validate_block_rows("log_sigma", n, &spec.log_sigma_block)?;
685
686 let family = GaussianLocationScaleFamily {
687 y: spec.y,
688 weights: spec.weights,
689 sigma_min: spec.sigma_min,
690 sigma_max: spec.sigma_max,
691 };
692 let blocks = vec![
693 spec.mu_block.into_spec("mu")?,
694 spec.log_sigma_block.into_spec("log_sigma")?,
695 ];
696 fit_custom_family(&family, &blocks, options)
697}
698
699pub fn fit_binomial_logit(
700 spec: BinomialLogitSpec,
701 options: &BlockwiseFitOptions,
702) -> Result<BlockwiseFitResult, String> {
703 let n = spec.y.len();
704 validate_len_match("weights vs y", n, spec.weights.len())?;
705 validate_weights(&spec.weights, "fit_binomial_logit")?;
706 validate_binomial_response(&spec.y, "fit_binomial_logit")?;
707 validate_block_rows("eta", n, &spec.eta_block)?;
708
709 let family = BinomialLogitFamily {
710 y: spec.y,
711 weights: spec.weights,
712 };
713 let blocks = vec![spec.eta_block.into_spec("eta")?];
714 fit_custom_family(&family, &blocks, options)
715}
716
717pub fn fit_poisson_log(
718 spec: PoissonLogSpec,
719 options: &BlockwiseFitOptions,
720) -> Result<BlockwiseFitResult, String> {
721 let n = spec.y.len();
722 validate_len_match("weights vs y", n, spec.weights.len())?;
723 validate_weights(&spec.weights, "fit_poisson_log")?;
724 validate_block_rows("eta", n, &spec.eta_block)?;
725
726 let family = PoissonLogFamily {
727 y: spec.y,
728 weights: spec.weights,
729 };
730 let blocks = vec![spec.eta_block.into_spec("eta")?];
731 fit_custom_family(&family, &blocks, options)
732}
733
734pub fn fit_gamma_log(
735 spec: GammaLogSpec,
736 options: &BlockwiseFitOptions,
737) -> Result<BlockwiseFitResult, String> {
738 let n = spec.y.len();
739 validate_len_match("weights vs y", n, spec.weights.len())?;
740 validate_weights(&spec.weights, "fit_gamma_log")?;
741 validate_block_rows("eta", n, &spec.eta_block)?;
742 if !spec.shape.is_finite() || spec.shape <= 0.0 {
743 return Err(format!(
744 "fit_gamma_log: shape must be finite and > 0, got {}",
745 spec.shape
746 ));
747 }
748
749 let family = GammaLogFamily {
750 y: spec.y,
751 weights: spec.weights,
752 shape: spec.shape,
753 };
754 let blocks = vec![spec.eta_block.into_spec("eta")?];
755 fit_custom_family(&family, &blocks, options)
756}
757
758pub fn fit_binomial_location_scale_probit(
759 spec: BinomialLocationScaleProbitSpec,
760 options: &BlockwiseFitOptions,
761) -> Result<BlockwiseFitResult, String> {
762 let n = spec.y.len();
763 validate_len_match("score vs y", n, spec.score.len())?;
764 validate_len_match("weights vs y", n, spec.weights.len())?;
765 validate_weights(&spec.weights, "fit_binomial_location_scale_probit")?;
766 validate_binomial_response(&spec.y, "fit_binomial_location_scale_probit")?;
767 validate_sigma_bounds(
768 spec.sigma_min,
769 spec.sigma_max,
770 "fit_binomial_location_scale_probit",
771 )?;
772 validate_block_rows("threshold", n, &spec.threshold_block)?;
773 validate_block_rows("log_sigma", n, &spec.log_sigma_block)?;
774
775 let BinomialLocationScaleProbitSpec {
776 y,
777 score,
778 weights,
779 sigma_min,
780 sigma_max,
781 mut threshold_block,
782 mut log_sigma_block,
783 } = spec;
784
785 match try_binomial_alpha_beta_warm_start(
786 &y,
787 &score,
788 &weights,
789 sigma_min,
790 sigma_max,
791 &threshold_block,
792 &log_sigma_block,
793 options,
794 ) {
795 Ok((beta_t0, beta_ls0, beta_warm)) => {
796 threshold_block.initial_beta = Some(beta_t0);
797 log_sigma_block.initial_beta = Some(beta_ls0);
798 emit_binomial_alpha_beta_warnings("warm-start", &beta_warm, &y, &weights);
799 }
800 Err(err) => {
801 log::warn!(
802 "[GAMLSS][fit_binomial_location_scale_probit] alpha/beta warm start failed, falling back to direct initialization: {}",
803 err
804 );
805 }
806 }
807
808 let family = BinomialLocationScaleProbitFamily {
809 y: y.clone(),
810 score: score.clone(),
811 weights: weights.clone(),
812 sigma_min,
813 sigma_max,
814 };
815 let blocks = vec![
816 threshold_block.into_spec("threshold")?,
817 log_sigma_block.into_spec("log_sigma")?,
818 ];
819 let fit = fit_custom_family(&family, &blocks, options)?;
820 let beta_final = fit.block_states[BinomialLocationScaleProbitFamily::BLOCK_LOG_SIGMA]
821 .eta
822 .mapv(f64::exp)
823 .mapv(|s| 1.0 / s.clamp(sigma_min, sigma_max).max(1e-12));
824 emit_binomial_alpha_beta_warnings("final-fit", &beta_final, &y, &weights);
825 Ok(fit)
826}
827
828pub fn fit_binomial_location_scale_probit_wiggle(
829 spec: BinomialLocationScaleProbitWiggleSpec,
830 options: &BlockwiseFitOptions,
831) -> Result<BlockwiseFitResult, String> {
832 let n = spec.y.len();
833 validate_len_match("score vs y", n, spec.score.len())?;
834 validate_len_match("weights vs y", n, spec.weights.len())?;
835 validate_weights(&spec.weights, "fit_binomial_location_scale_probit_wiggle")?;
836 validate_binomial_response(&spec.y, "fit_binomial_location_scale_probit_wiggle")?;
837 validate_sigma_bounds(
838 spec.sigma_min,
839 spec.sigma_max,
840 "fit_binomial_location_scale_probit_wiggle",
841 )?;
842 validate_block_rows("threshold", n, &spec.threshold_block)?;
843 validate_block_rows("log_sigma", n, &spec.log_sigma_block)?;
844 validate_block_rows("wiggle", n, &spec.wiggle_block)?;
845 if spec.wiggle_degree < 1 {
846 return Err(format!(
847 "fit_binomial_location_scale_probit_wiggle: wiggle_degree must be >= 1, got {}",
848 spec.wiggle_degree
849 ));
850 }
851 if spec.wiggle_knots.len() < spec.wiggle_degree + 2 {
852 return Err(format!(
853 "fit_binomial_location_scale_probit_wiggle: wiggle_knots length {} is too short for degree {}",
854 spec.wiggle_knots.len(),
855 spec.wiggle_degree
856 ));
857 }
858
859 let BinomialLocationScaleProbitWiggleSpec {
860 y,
861 score,
862 weights,
863 sigma_min,
864 sigma_max,
865 wiggle_knots,
866 wiggle_degree,
867 mut threshold_block,
868 mut log_sigma_block,
869 wiggle_block,
870 } = spec;
871
872 match try_binomial_alpha_beta_warm_start(
873 &y,
874 &score,
875 &weights,
876 sigma_min,
877 sigma_max,
878 &threshold_block,
879 &log_sigma_block,
880 options,
881 ) {
882 Ok((beta_t0, beta_ls0, beta_warm)) => {
883 threshold_block.initial_beta = Some(beta_t0);
884 log_sigma_block.initial_beta = Some(beta_ls0);
885 emit_binomial_alpha_beta_warnings("warm-start-wiggle", &beta_warm, &y, &weights);
886 }
887 Err(err) => {
888 log::warn!(
889 "[GAMLSS][fit_binomial_location_scale_probit_wiggle] alpha/beta warm start failed, falling back to direct initialization: {}",
890 err
891 );
892 }
893 }
894
895 let family = BinomialLocationScaleProbitWiggleFamily {
896 y: y.clone(),
897 score: score.clone(),
898 weights: weights.clone(),
899 sigma_min,
900 sigma_max,
901 wiggle_knots,
902 wiggle_degree,
903 };
904 let blocks = vec![
905 threshold_block.into_spec("threshold")?,
906 log_sigma_block.into_spec("log_sigma")?,
907 wiggle_block.into_spec("wiggle")?,
908 ];
909 let fit = fit_custom_family(&family, &blocks, options)?;
910 let beta_final = fit.block_states[BinomialLocationScaleProbitWiggleFamily::BLOCK_LOG_SIGMA]
911 .eta
912 .mapv(f64::exp)
913 .mapv(|s| 1.0 / s.clamp(sigma_min, sigma_max).max(1e-12));
914 emit_binomial_alpha_beta_warnings("final-fit-wiggle", &beta_final, &y, &weights);
915 Ok(fit)
916}
917
/// Link associated with a distributional parameter, as reported in
/// [`FamilyMetadata::parameter_links`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ParameterLink {
    /// The parameter equals its linear predictor.
    Identity,
    /// The parameter is `exp` of its linear predictor (e.g. sigma from log_sigma).
    Log,
    /// Logit link.
    Logit,
    /// Probit link.
    Probit,
    /// Presumably marks the spline "wiggle" correction of the probit
    /// location-scale family — TODO confirm.
    Wiggle,
}
928
/// Returns `v` with its magnitude raised to at least `floor`, keeping its sign.
///
/// Inputs that satisfy `v >= 0.0`, including `-0.0`, yield `+max(|v|, floor)`.
/// All other inputs, including NaN, yield `-max(|v|, floor)`. For NaN that
/// result is `-floor`.
fn signed_with_floor(v: f64, floor: f64) -> f64 {
    let magnitude = v.abs().max(floor);
    if v >= 0.0 {
        return magnitude;
    }
    -magnitude
}
933
/// Per-observation quantities shared by the binomial location-scale probit
/// families, produced by `binomial_location_scale_core`.
struct BinomialLocationScaleCore {
    /// `exp(eta_ls)` clamped to `[sigma_min, sigma_max]`.
    sigma: Array1<f64>,
    /// `d sigma / d eta_ls`: the unclamped `exp(eta_ls)` inside the bounds, 0 where clamped.
    dsigma_deta: Array1<f64>,
    /// Location-scale index `(score - eta_t) / sigma`, before any wiggle term.
    q0: Array1<f64>,
    /// `Phi(q)`, with `q = q0 + wiggle`, clamped to `[MIN_PROB, 1 - MIN_PROB]`.
    mu: Array1<f64>,
    /// Normal density `phi(q)`, floored at `MIN_DERIV`.
    dmu_dq: Array1<f64>,
    /// Weighted Bernoulli log-likelihood summed over observations.
    log_likelihood: f64,
}
942
943fn binomial_location_scale_core(
944 y: &Array1<f64>,
945 score: &Array1<f64>,
946 weights: &Array1<f64>,
947 eta_t: &Array1<f64>,
948 eta_ls: &Array1<f64>,
949 eta_wiggle: Option<&Array1<f64>>,
950 sigma_min: f64,
951 sigma_max: f64,
952) -> Result<BinomialLocationScaleCore, String> {
953 let n = y.len();
954 if score.len() != n || weights.len() != n || eta_t.len() != n || eta_ls.len() != n {
955 return Err("binomial location-scale core size mismatch".to_string());
956 }
957 if let Some(w) = eta_wiggle {
958 if w.len() != n {
959 return Err("binomial location-scale core wiggle size mismatch".to_string());
960 }
961 }
962
963 let mut sigma = Array1::<f64>::zeros(n);
964 let mut dsigma_deta = Array1::<f64>::zeros(n);
965 let mut q0 = Array1::<f64>::zeros(n);
966 let mut mu = Array1::<f64>::zeros(n);
967 let mut dmu_dq = Array1::<f64>::zeros(n);
968 let mut ll = 0.0;
969
970 for i in 0..n {
971 let raw = eta_ls[i].exp();
972 sigma[i] = raw.clamp(sigma_min, sigma_max);
973 dsigma_deta[i] = if raw >= sigma_min && raw <= sigma_max {
974 raw
975 } else {
976 0.0
977 };
978 q0[i] = (score[i] - eta_t[i]) / sigma[i].max(1e-12);
979 let q = q0[i] + eta_wiggle.map_or(0.0, |w| w[i]);
980 mu[i] = normal_cdf_approx(q).clamp(MIN_PROB, 1.0 - MIN_PROB);
981 dmu_dq[i] = normal_pdf(q).max(MIN_DERIV);
982 ll += weights[i] * (y[i] * mu[i].ln() + (1.0 - y[i]) * (1.0 - mu[i]).ln());
983 }
984
985 Ok(BinomialLocationScaleCore {
986 sigma,
987 dsigma_deta,
988 q0,
989 mu,
990 dmu_dq,
991 log_likelihood: ll,
992 })
993}
994
/// Builds IRLS working sets for the threshold, log-sigma and (optional) wiggle
/// blocks from precomputed core quantities.
///
/// For each block with linear predictor `eta_b`, the loop computes
/// `d mu / d eta_b = phi(q) * d q / d eta_b`. From that it sets the weight to
/// `w = prior * (d mu / d eta_b)^2 / var` (floored at `MIN_WEIGHT`) and the
/// working response to `z = eta_b + (y - mu) / (d mu / d eta_b)`, with the
/// derivative magnitude floored at `MIN_DERIV`. The third element of the
/// returned tuple is `Some` only when `eta_wiggle` is `Some`.
fn binomial_location_scale_working_sets(
    y: &Array1<f64>,
    weights: &Array1<f64>,
    eta_t: &Array1<f64>,
    eta_ls: &Array1<f64>,
    eta_wiggle: Option<&Array1<f64>>,
    core: &BinomialLocationScaleCore,
) -> (BlockWorkingSet, BlockWorkingSet, Option<BlockWorkingSet>) {
    let n = y.len();
    let mut z_t = Array1::<f64>::zeros(n);
    let mut w_t = Array1::<f64>::zeros(n);
    let mut z_ls = Array1::<f64>::zeros(n);
    let mut w_ls = Array1::<f64>::zeros(n);
    // Wiggle buffers exist only when a wiggle predictor was supplied.
    let mut z_w = eta_wiggle.map(|_| Array1::<f64>::zeros(n));
    let mut w_w = eta_wiggle.map(|_| Array1::<f64>::zeros(n));

    for i in 0..n {
        // Bernoulli variance, floored to avoid division by zero.
        let var = (core.mu[i] * (1.0 - core.mu[i])).max(MIN_PROB);

        // q0 = (score - t) / sigma  =>  d q / d t = -1 / sigma.
        let chain_t = -1.0 / core.sigma[i].max(1e-12);
        let dmu_t = core.dmu_dq[i] * chain_t;
        w_t[i] = (weights[i] * (dmu_t * dmu_t / var)).max(MIN_WEIGHT);
        z_t[i] = eta_t[i] + (y[i] - core.mu[i]) / signed_with_floor(dmu_t, MIN_DERIV);

        // d q0 / d sigma = -q0 / sigma, times d sigma / d eta_ls (0 where clamped).
        let chain_ls = {
            let s = core.sigma[i].max(1e-12);
            -core.q0[i] * core.dsigma_deta[i] / s
        };
        let dmu_ls = core.dmu_dq[i] * chain_ls;
        w_ls[i] = (weights[i] * (dmu_ls * dmu_ls / var)).max(MIN_WEIGHT);
        z_ls[i] = eta_ls[i] + (y[i] - core.mu[i]) / signed_with_floor(dmu_ls, MIN_DERIV);

        if let (Some(eta_w), Some(z_wv), Some(w_wv)) = (eta_wiggle, z_w.as_mut(), w_w.as_mut()) {
            // The wiggle term enters q additively, so d q / d eta_w = 1.
            let dmu_w = core.dmu_dq[i];
            w_wv[i] = (weights[i] * (dmu_w * dmu_w / var)).max(MIN_WEIGHT);
            z_wv[i] = eta_w[i] + (y[i] - core.mu[i]) / signed_with_floor(dmu_w, MIN_DERIV);
        }
    }

    let t_ws = BlockWorkingSet {
        working_response: z_t,
        working_weights: w_t,
        gradient_eta: None,
    };
    let ls_ws = BlockWorkingSet {
        working_response: z_ls,
        working_weights: w_ls,
        gradient_eta: None,
    };
    let w_ws = match (z_w, w_w) {
        (Some(z), Some(w)) => Some(BlockWorkingSet {
            working_response: z,
            working_weights: w,
            gradient_eta: None,
        }),
        _ => None,
    };
    (t_ws, ls_ws, w_ws)
}
1058
/// Gaussian location-scale family: `y ~ N(mu, sigma^2)`.
///
/// `mu` is the identity-link block. `sigma = exp(eta_log_sigma)` is clamped to
/// `[sigma_min, sigma_max]`.
#[derive(Clone)]
pub struct GaussianLocationScaleFamily {
    /// Observed responses.
    pub y: Array1<f64>,
    /// Prior observation weights multiplying each log-likelihood term.
    pub weights: Array1<f64>,
    /// Lower clamp for sigma.
    pub sigma_min: f64,
    /// Upper clamp for sigma.
    pub sigma_max: f64,
}
1069
1070impl GaussianLocationScaleFamily {
1071 pub const BLOCK_MU: usize = 0;
1072 pub const BLOCK_LOG_SIGMA: usize = 1;
1073
1074 pub fn parameter_names() -> &'static [&'static str] {
1075 &["mu", "log_sigma"]
1076 }
1077
1078 pub fn parameter_links() -> &'static [ParameterLink] {
1079 &[ParameterLink::Identity, ParameterLink::Log]
1080 }
1081
1082 pub fn metadata() -> FamilyMetadata {
1083 FamilyMetadata {
1084 name: "gaussian_location_scale",
1085 parameter_names: Self::parameter_names(),
1086 parameter_links: Self::parameter_links(),
1087 }
1088 }
1089}
1090
1091impl CustomFamily for GaussianLocationScaleFamily {
1092 fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
1093 if block_states.len() != 2 {
1094 return Err(format!(
1095 "GaussianLocationScaleFamily expects 2 blocks, got {}",
1096 block_states.len()
1097 ));
1098 }
1099 let n = self.y.len();
1100 let eta_mu = &block_states[Self::BLOCK_MU].eta;
1101 let eta_log_sigma = &block_states[Self::BLOCK_LOG_SIGMA].eta;
1102 if eta_mu.len() != n || eta_log_sigma.len() != n || self.weights.len() != n {
1103 return Err("GaussianLocationScaleFamily input size mismatch".to_string());
1104 }
1105
1106 let mut sigma = Array1::<f64>::zeros(n);
1107 let mut dsigma_deta = Array1::<f64>::zeros(n);
1108 let mut ll = 0.0;
1109
1110 for i in 0..n {
1111 let raw = eta_log_sigma[i].exp();
1112 sigma[i] = raw.clamp(self.sigma_min, self.sigma_max);
1113 dsigma_deta[i] = if raw >= self.sigma_min && raw <= self.sigma_max {
1114 raw
1115 } else {
1116 0.0
1117 };
1118 let r = self.y[i] - eta_mu[i];
1119 let s2 = (sigma[i] * sigma[i]).max(1e-20);
1120 ll += self.weights[i] * (-0.5 * (r * r / s2 + (2.0 * std::f64::consts::PI * s2).ln()));
1121 }
1122
1123 let mut z_mu = Array1::<f64>::zeros(n);
1124 let mut w_mu = Array1::<f64>::zeros(n);
1125 let mut z_ls = Array1::<f64>::zeros(n);
1126 let mut w_ls = Array1::<f64>::zeros(n);
1127
1128 for i in 0..n {
1129 let r = self.y[i] - eta_mu[i];
1130 let s = sigma[i].max(1e-10);
1131 let s2 = (s * s).max(1e-20);
1132
1133 w_mu[i] = (self.weights[i] / s2).max(MIN_WEIGHT);
1135 z_mu[i] = eta_mu[i] + r;
1136
1137 let dlogsigma_du = if dsigma_deta[i] == 0.0 {
1141 0.0
1142 } else {
1143 (dsigma_deta[i] / s).clamp(-1.0, 1.0)
1144 };
1145 let score_u = self.weights[i] * ((r * r / s2) - 1.0) * dlogsigma_du;
1146 let info_u = (2.0 * self.weights[i] * dlogsigma_du * dlogsigma_du).max(MIN_WEIGHT);
1147 z_ls[i] = eta_log_sigma[i] + score_u / info_u;
1148 w_ls[i] = info_u;
1149 }
1150
1151 Ok(FamilyEvaluation {
1152 log_likelihood: ll,
1153 block_working_sets: vec![
1154 BlockWorkingSet {
1155 working_response: z_mu,
1156 working_weights: w_mu,
1157 gradient_eta: None,
1158 },
1159 BlockWorkingSet {
1160 working_response: z_ls,
1161 working_weights: w_ls,
1162 gradient_eta: None,
1163 },
1164 ],
1165 })
1166 }
1167}
1168
1169impl CustomFamilyGenerative for GaussianLocationScaleFamily {
1170 fn generative_spec(
1171 &self,
1172 block_states: &[ParameterBlockState],
1173 ) -> Result<GenerativeSpec, String> {
1174 if block_states.len() != 2 {
1175 return Err(format!(
1176 "GaussianLocationScaleFamily expects 2 blocks, got {}",
1177 block_states.len()
1178 ));
1179 }
1180 let mu = block_states[Self::BLOCK_MU].eta.clone();
1181 let sigma = block_states[Self::BLOCK_LOG_SIGMA]
1182 .eta
1183 .mapv(f64::exp)
1184 .mapv(|s| s.clamp(self.sigma_min, self.sigma_max));
1185 Ok(GenerativeSpec {
1186 mean: mu,
1187 noise: NoiseModel::Gaussian { sigma },
1188 })
1189 }
1190}
1191
/// Binary/binomial response with a single logit-linked linear predictor.
#[derive(Clone)]
pub struct BinomialLogitFamily {
    /// Observed responses (presumably 0/1 or proportions — TODO confirm against
    /// `evaluate_single_block_glm`).
    pub y: Array1<f64>,
    /// Per-observation weights, forwarded to `evaluate_single_block_glm`.
    pub weights: Array1<f64>,
}
1198
1199impl BinomialLogitFamily {
1200 pub const BLOCK_ETA: usize = 0;
1201
1202 pub fn parameter_names() -> &'static [&'static str] {
1203 &["eta"]
1204 }
1205
1206 pub fn parameter_links() -> &'static [ParameterLink] {
1207 &[ParameterLink::Logit]
1208 }
1209
1210 pub fn metadata() -> FamilyMetadata {
1211 FamilyMetadata {
1212 name: "binomial_logit",
1213 parameter_names: Self::parameter_names(),
1214 parameter_links: Self::parameter_links(),
1215 }
1216 }
1217}
1218
1219impl CustomFamily for BinomialLogitFamily {
1220 fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
1221 if block_states.len() != 1 {
1222 return Err(format!(
1223 "BinomialLogitFamily expects 1 block, got {}",
1224 block_states.len()
1225 ));
1226 }
1227 let eta = &block_states[Self::BLOCK_ETA].eta;
1228 let n = self.y.len();
1229 if eta.len() != n || self.weights.len() != n {
1230 return Err("BinomialLogitFamily input size mismatch".to_string());
1231 }
1232 evaluate_single_block_glm(LikelihoodFamily::BinomialLogit, &self.y, &self.weights, eta)
1233 }
1234}
1235
1236impl CustomFamilyGenerative for BinomialLogitFamily {
1237 fn generative_spec(
1238 &self,
1239 block_states: &[ParameterBlockState],
1240 ) -> Result<GenerativeSpec, String> {
1241 if block_states.len() != 1 {
1242 return Err(format!(
1243 "BinomialLogitFamily expects 1 block, got {}",
1244 block_states.len()
1245 ));
1246 }
1247 let mean = block_states[Self::BLOCK_ETA].eta.mapv(|e| {
1248 (1.0 / (1.0 + (-e.clamp(-30.0, 30.0)).exp())).clamp(MIN_PROB, 1.0 - MIN_PROB)
1249 });
1250 Ok(GenerativeSpec {
1251 mean,
1252 noise: NoiseModel::Bernoulli,
1253 })
1254 }
1255}
1256
/// Poisson count response with a log-linked linear predictor.
#[derive(Clone)]
pub struct PoissonLogFamily {
    /// Observed counts; `evaluate` rejects negative or non-finite values.
    pub y: Array1<f64>,
    /// Per-observation weights multiplying each log-likelihood term.
    pub weights: Array1<f64>,
}
1263
1264impl PoissonLogFamily {
1265 pub const BLOCK_ETA: usize = 0;
1266
1267 pub fn parameter_names() -> &'static [&'static str] {
1268 &["eta"]
1269 }
1270
1271 pub fn parameter_links() -> &'static [ParameterLink] {
1272 &[ParameterLink::Log]
1273 }
1274
1275 pub fn metadata() -> FamilyMetadata {
1276 FamilyMetadata {
1277 name: "poisson_log",
1278 parameter_names: Self::parameter_names(),
1279 parameter_links: Self::parameter_links(),
1280 }
1281 }
1282}
1283
1284impl CustomFamily for PoissonLogFamily {
1285 fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
1286 if block_states.len() != 1 {
1287 return Err(format!(
1288 "PoissonLogFamily expects 1 block, got {}",
1289 block_states.len()
1290 ));
1291 }
1292 let eta = &block_states[Self::BLOCK_ETA].eta;
1293 let n = self.y.len();
1294 if eta.len() != n || self.weights.len() != n {
1295 return Err("PoissonLogFamily input size mismatch".to_string());
1296 }
1297
1298 let mut mu = Array1::<f64>::zeros(n);
1299 let mut ll = 0.0;
1300 let mut z = Array1::<f64>::zeros(n);
1301 let mut w = Array1::<f64>::zeros(n);
1302
1303 for i in 0..n {
1304 let yi = self.y[i];
1305 if !yi.is_finite() || yi < 0.0 {
1306 return Err(format!(
1307 "PoissonLogFamily requires non-negative finite y; found y[{i}]={yi}"
1308 ));
1309 }
1310 let e = eta[i].clamp(-30.0, 30.0);
1311 let m = e.exp().max(1e-12);
1312 mu[i] = m;
1313 ll += self.weights[i] * (yi * e - m);
1315 let dmu = m.max(MIN_DERIV);
1316 let var = m.max(MIN_PROB);
1317 w[i] = (self.weights[i] * (dmu * dmu / var)).max(MIN_WEIGHT);
1318 z[i] = e + (yi - m) / signed_with_floor(dmu, MIN_DERIV);
1319 }
1320
1321 Ok(FamilyEvaluation {
1322 log_likelihood: ll,
1323 block_working_sets: vec![BlockWorkingSet {
1324 working_response: z,
1325 working_weights: w,
1326 gradient_eta: None,
1327 }],
1328 })
1329 }
1330}
1331
1332impl CustomFamilyGenerative for PoissonLogFamily {
1333 fn generative_spec(
1334 &self,
1335 block_states: &[ParameterBlockState],
1336 ) -> Result<GenerativeSpec, String> {
1337 if block_states.len() != 1 {
1338 return Err(format!(
1339 "PoissonLogFamily expects 1 block, got {}",
1340 block_states.len()
1341 ));
1342 }
1343 let mean = block_states[Self::BLOCK_ETA]
1344 .eta
1345 .mapv(|e| e.clamp(-30.0, 30.0).exp().max(1e-12));
1346 Ok(GenerativeSpec {
1347 mean,
1348 noise: NoiseModel::Poisson,
1349 })
1350 }
1351}
1352
/// Gamma response with a log-linked mean and a fixed shape parameter.
#[derive(Clone)]
pub struct GammaLogFamily {
    /// Observed responses; `evaluate` rejects non-positive or non-finite values.
    pub y: Array1<f64>,
    /// Per-observation weights multiplying each log-likelihood term.
    pub weights: Array1<f64>,
    /// Gamma shape parameter; `evaluate` requires it to be finite and > 0.
    pub shape: f64,
}
1360
1361impl GammaLogFamily {
1362 pub const BLOCK_ETA: usize = 0;
1363
1364 pub fn parameter_names() -> &'static [&'static str] {
1365 &["eta"]
1366 }
1367
1368 pub fn parameter_links() -> &'static [ParameterLink] {
1369 &[ParameterLink::Log]
1370 }
1371
1372 pub fn metadata() -> FamilyMetadata {
1373 FamilyMetadata {
1374 name: "gamma_log",
1375 parameter_names: Self::parameter_names(),
1376 parameter_links: Self::parameter_links(),
1377 }
1378 }
1379}
1380
1381impl CustomFamily for GammaLogFamily {
1382 fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
1383 if block_states.len() != 1 {
1384 return Err(format!(
1385 "GammaLogFamily expects 1 block, got {}",
1386 block_states.len()
1387 ));
1388 }
1389 let eta = &block_states[Self::BLOCK_ETA].eta;
1390 let n = self.y.len();
1391 if eta.len() != n || self.weights.len() != n {
1392 return Err("GammaLogFamily input size mismatch".to_string());
1393 }
1394 if !self.shape.is_finite() || self.shape <= 0.0 {
1395 return Err("GammaLogFamily shape must be finite and > 0".to_string());
1396 }
1397
1398 let mut mu = Array1::<f64>::zeros(n);
1399 let mut ll = 0.0;
1400 let mut z = Array1::<f64>::zeros(n);
1401 let mut w = Array1::<f64>::zeros(n);
1402
1403 for i in 0..n {
1404 let yi = self.y[i];
1405 if !yi.is_finite() || yi <= 0.0 {
1406 return Err(format!(
1407 "GammaLogFamily requires positive finite y; found y[{i}]={yi}"
1408 ));
1409 }
1410 let e = eta[i].clamp(-30.0, 30.0);
1411 let m = e.exp().max(1e-12);
1412 mu[i] = m;
1413 ll += self.weights[i] * (-self.shape * (yi / m + m.ln()));
1415 let dmu = m.max(MIN_DERIV);
1416 let var = (m * m / self.shape).max(MIN_PROB);
1417 w[i] = (self.weights[i] * (dmu * dmu / var)).max(MIN_WEIGHT);
1418 z[i] = e + (yi - m) / signed_with_floor(dmu, MIN_DERIV);
1419 }
1420
1421 Ok(FamilyEvaluation {
1422 log_likelihood: ll,
1423 block_working_sets: vec![BlockWorkingSet {
1424 working_response: z,
1425 working_weights: w,
1426 gradient_eta: None,
1427 }],
1428 })
1429 }
1430}
1431
1432impl CustomFamilyGenerative for GammaLogFamily {
1433 fn generative_spec(
1434 &self,
1435 block_states: &[ParameterBlockState],
1436 ) -> Result<GenerativeSpec, String> {
1437 if block_states.len() != 1 {
1438 return Err(format!(
1439 "GammaLogFamily expects 1 block, got {}",
1440 block_states.len()
1441 ));
1442 }
1443 let mean = block_states[Self::BLOCK_ETA]
1444 .eta
1445 .mapv(|e| e.clamp(-30.0, 30.0).exp().max(1e-12));
1446 Ok(GenerativeSpec {
1447 mean,
1448 noise: NoiseModel::Gamma { shape: self.shape },
1449 })
1450 }
1451}
1452
/// Binary response whose success probability is `Phi((score - threshold) / sigma)`.
///
/// The threshold and `log(sigma)` are separate parameter blocks; `exp(log_sigma)` is clamped
/// into `[sigma_min, sigma_max]`.
#[derive(Clone)]
pub struct BinomialLocationScaleProbitFamily {
    /// Observed binary responses (validation presumably happens in
    /// `binomial_location_scale_core` — TODO confirm).
    pub y: Array1<f64>,
    /// Per-observation score compared against the threshold in the probit argument.
    pub score: Array1<f64>,
    /// Per-observation likelihood weights.
    pub weights: Array1<f64>,
    /// Lower clamp applied to `exp(eta_log_sigma)`.
    pub sigma_min: f64,
    /// Upper clamp applied to `exp(eta_log_sigma)`.
    pub sigma_max: f64,
}
1467
1468impl BinomialLocationScaleProbitFamily {
1469 pub const BLOCK_T: usize = 0;
1470 pub const BLOCK_LOG_SIGMA: usize = 1;
1471
1472 pub fn parameter_names() -> &'static [&'static str] {
1473 &["threshold", "log_sigma"]
1474 }
1475
1476 pub fn parameter_links() -> &'static [ParameterLink] {
1477 &[ParameterLink::Probit, ParameterLink::Log]
1478 }
1479
1480 pub fn metadata() -> FamilyMetadata {
1481 FamilyMetadata {
1482 name: "binomial_location_scale_probit",
1483 parameter_names: Self::parameter_names(),
1484 parameter_links: Self::parameter_links(),
1485 }
1486 }
1487}
1488
1489impl CustomFamily for BinomialLocationScaleProbitFamily {
1490 fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
1491 if block_states.len() != 2 {
1492 return Err(format!(
1493 "BinomialLocationScaleProbitFamily expects 2 blocks, got {}",
1494 block_states.len()
1495 ));
1496 }
1497 let n = self.y.len();
1498 let eta_t = &block_states[Self::BLOCK_T].eta;
1499 let eta_ls = &block_states[Self::BLOCK_LOG_SIGMA].eta;
1500 if eta_t.len() != n || eta_ls.len() != n || self.weights.len() != n || self.score.len() != n
1501 {
1502 return Err("BinomialLocationScaleProbitFamily input size mismatch".to_string());
1503 }
1504
1505 let core = binomial_location_scale_core(
1506 &self.y,
1507 &self.score,
1508 &self.weights,
1509 eta_t,
1510 eta_ls,
1511 None,
1512 self.sigma_min,
1513 self.sigma_max,
1514 )?;
1515 let (t_ws, ls_ws, _none) = binomial_location_scale_working_sets(
1516 &self.y,
1517 &self.weights,
1518 eta_t,
1519 eta_ls,
1520 None,
1521 &core,
1522 );
1523
1524 Ok(FamilyEvaluation {
1525 log_likelihood: core.log_likelihood,
1526 block_working_sets: vec![t_ws, ls_ws],
1527 })
1528 }
1529}
1530
1531impl CustomFamilyGenerative for BinomialLocationScaleProbitFamily {
1532 fn generative_spec(
1533 &self,
1534 block_states: &[ParameterBlockState],
1535 ) -> Result<GenerativeSpec, String> {
1536 if block_states.len() != 2 {
1537 return Err(format!(
1538 "BinomialLocationScaleProbitFamily expects 2 blocks, got {}",
1539 block_states.len()
1540 ));
1541 }
1542 let eta_t = &block_states[Self::BLOCK_T].eta;
1543 let eta_ls = &block_states[Self::BLOCK_LOG_SIGMA].eta;
1544 if eta_t.len() != self.score.len() || eta_ls.len() != self.score.len() {
1545 return Err("BinomialLocationScaleProbitFamily generative size mismatch".to_string());
1546 }
1547 let mut mean = Array1::<f64>::zeros(self.score.len());
1548 for i in 0..mean.len() {
1549 let sigma = eta_ls[i]
1550 .exp()
1551 .clamp(self.sigma_min, self.sigma_max)
1552 .max(1e-12);
1553 let q = (self.score[i] - eta_t[i]) / sigma;
1554 mean[i] = normal_cdf_approx(q).clamp(MIN_PROB, 1.0 - MIN_PROB);
1555 }
1556 Ok(GenerativeSpec {
1557 mean,
1558 noise: NoiseModel::Bernoulli,
1559 })
1560 }
1561}
1562
/// Probit location-scale binary model with an additive spline "wiggle" on the probit argument.
///
/// The success probability is `Phi(q0 + wiggle)`, where `q0 = (score - threshold) / sigma`.
/// The wiggle block's design is a B-spline basis evaluated at `q0`.
#[derive(Clone)]
pub struct BinomialLocationScaleProbitWiggleFamily {
    /// Observed binary responses.
    pub y: Array1<f64>,
    /// Per-observation score compared against the threshold.
    pub score: Array1<f64>,
    /// Per-observation likelihood weights.
    pub weights: Array1<f64>,
    /// Lower clamp applied to `exp(eta_log_sigma)`.
    pub sigma_min: f64,
    /// Upper clamp applied to `exp(eta_log_sigma)`.
    pub sigma_max: f64,
    /// Knot vector passed to `create_basis` as `KnotSource::Provided` for the wiggle basis.
    pub wiggle_knots: Array1<f64>,
    /// Degree passed to `create_basis` for the wiggle basis.
    pub wiggle_degree: usize,
}
1579
1580impl BinomialLocationScaleProbitWiggleFamily {
1581 pub const BLOCK_T: usize = 0;
1582 pub const BLOCK_LOG_SIGMA: usize = 1;
1583 pub const BLOCK_WIGGLE: usize = 2;
1584
1585 pub fn parameter_names() -> &'static [&'static str] {
1586 &["threshold", "log_sigma", "wiggle"]
1587 }
1588
1589 pub fn parameter_links() -> &'static [ParameterLink] {
1590 &[
1591 ParameterLink::Probit,
1592 ParameterLink::Log,
1593 ParameterLink::Wiggle,
1594 ]
1595 }
1596
1597 pub fn metadata() -> FamilyMetadata {
1598 FamilyMetadata {
1599 name: "binomial_location_scale_probit_wiggle",
1600 parameter_names: Self::parameter_names(),
1601 parameter_links: Self::parameter_links(),
1602 }
1603 }
1604
1605 pub fn initialize_wiggle_knots_from_q(
1606 q_seed: ArrayView1<'_, f64>,
1607 degree: usize,
1608 num_internal_knots: usize,
1609 ) -> Result<Array1<f64>, String> {
1610 initialize_wiggle_knots_from_seed(q_seed, degree, num_internal_knots)
1611 }
1612
1613 fn wiggle_design(&self, q0: ArrayView1<'_, f64>) -> Result<Array2<f64>, String> {
1614 let (basis, _) = create_basis::<Dense>(
1615 q0,
1616 KnotSource::Provided(self.wiggle_knots.view()),
1617 self.wiggle_degree,
1618 BasisOptions::value(),
1619 )
1620 .map_err(|e| e.to_string())?;
1621 let full = (*basis).clone();
1622 if full.ncols() < 2 {
1623 return Err("wiggle basis has fewer than two columns".to_string());
1624 }
1625 Ok(full.slice(s![.., 1..]).to_owned())
1626 }
1627
1628 pub fn build_wiggle_block_input(
1631 q_seed: ArrayView1<'_, f64>,
1632 degree: usize,
1633 num_internal_knots: usize,
1634 penalty_order: usize,
1635 double_penalty: bool,
1636 ) -> Result<(ParameterBlockInput, Array1<f64>), String> {
1637 let knots = Self::initialize_wiggle_knots_from_q(q_seed, degree, num_internal_knots)?;
1638 let block = build_wiggle_block_input_from_knots(
1639 q_seed,
1640 &knots,
1641 degree,
1642 penalty_order,
1643 double_penalty,
1644 )?;
1645 Ok((block, knots))
1646 }
1647}
1648
1649impl CustomFamily for BinomialLocationScaleProbitWiggleFamily {
1650 fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
1651 if block_states.len() != 3 {
1652 return Err(format!(
1653 "BinomialLocationScaleProbitWiggleFamily expects 3 blocks, got {}",
1654 block_states.len()
1655 ));
1656 }
1657 let n = self.y.len();
1658 let eta_t = &block_states[Self::BLOCK_T].eta;
1659 let eta_ls = &block_states[Self::BLOCK_LOG_SIGMA].eta;
1660 let eta_w = &block_states[Self::BLOCK_WIGGLE].eta;
1661 if eta_t.len() != n || eta_ls.len() != n || eta_w.len() != n || self.score.len() != n {
1662 return Err("BinomialLocationScaleProbitWiggleFamily input size mismatch".to_string());
1663 }
1664
1665 let core = binomial_location_scale_core(
1666 &self.y,
1667 &self.score,
1668 &self.weights,
1669 eta_t,
1670 eta_ls,
1671 Some(eta_w),
1672 self.sigma_min,
1673 self.sigma_max,
1674 )?;
1675 let (t_ws, ls_ws, w_ws) = binomial_location_scale_working_sets(
1676 &self.y,
1677 &self.weights,
1678 eta_t,
1679 eta_ls,
1680 Some(eta_w),
1681 &core,
1682 );
1683 let w_ws = w_ws.ok_or_else(|| "wiggle working set missing".to_string())?;
1684
1685 Ok(FamilyEvaluation {
1686 log_likelihood: core.log_likelihood,
1687 block_working_sets: vec![t_ws, ls_ws, w_ws],
1688 })
1689 }
1690
1691 fn known_link_wiggle(&self) -> Option<KnownLinkWiggle> {
1692 Some(KnownLinkWiggle {
1693 base_link: LinkFunction::Probit,
1694 wiggle_block: Some(Self::BLOCK_WIGGLE),
1695 })
1696 }
1697
1698 fn block_geometry(
1699 &self,
1700 block_index: usize,
1701 block_states: &[ParameterBlockState],
1702 spec: &crate::custom_family::ParameterBlockSpec,
1703 ) -> Result<(DesignMatrix, Array1<f64>), String> {
1704 if block_index != Self::BLOCK_WIGGLE {
1705 return Ok((spec.design.clone(), spec.offset.clone()));
1706 }
1707 if block_states.len() < 2 {
1708 return Err("wiggle geometry requires threshold and log-sigma blocks".to_string());
1709 }
1710 let eta_t = &block_states[Self::BLOCK_T].eta;
1711 let eta_ls = &block_states[Self::BLOCK_LOG_SIGMA].eta;
1712 if eta_t.len() != self.score.len() || eta_ls.len() != self.score.len() {
1713 return Err("wiggle geometry input size mismatch".to_string());
1714 }
1715 let mut q0 = Array1::<f64>::zeros(self.score.len());
1716 for i in 0..q0.len() {
1717 let sigma = eta_ls[i]
1718 .exp()
1719 .clamp(self.sigma_min, self.sigma_max)
1720 .max(1e-12);
1721 q0[i] = (self.score[i] - eta_t[i]) / sigma;
1722 }
1723 let x = self.wiggle_design(q0.view())?;
1724 if x.ncols() != spec.design.ncols() {
1725 return Err(format!(
1726 "dynamic wiggle design col mismatch: got {}, expected {}",
1727 x.ncols(),
1728 spec.design.ncols()
1729 ));
1730 }
1731 let nrows = x.nrows();
1732 Ok((DesignMatrix::Dense(x), Array1::zeros(nrows)))
1733 }
1734}
1735
1736impl CustomFamilyGenerative for BinomialLocationScaleProbitWiggleFamily {
1737 fn generative_spec(
1738 &self,
1739 block_states: &[ParameterBlockState],
1740 ) -> Result<GenerativeSpec, String> {
1741 if block_states.len() != 3 {
1742 return Err(format!(
1743 "BinomialLocationScaleProbitWiggleFamily expects 3 blocks, got {}",
1744 block_states.len()
1745 ));
1746 }
1747 let eta_t = &block_states[Self::BLOCK_T].eta;
1748 let eta_ls = &block_states[Self::BLOCK_LOG_SIGMA].eta;
1749 let eta_w = &block_states[Self::BLOCK_WIGGLE].eta;
1750 if eta_t.len() != self.score.len()
1751 || eta_ls.len() != self.score.len()
1752 || eta_w.len() != self.score.len()
1753 {
1754 return Err(
1755 "BinomialLocationScaleProbitWiggleFamily generative size mismatch".to_string(),
1756 );
1757 }
1758 let mut mean = Array1::<f64>::zeros(self.score.len());
1759 for i in 0..mean.len() {
1760 let sigma = eta_ls[i]
1761 .exp()
1762 .clamp(self.sigma_min, self.sigma_max)
1763 .max(1e-12);
1764 let q0 = (self.score[i] - eta_t[i]) / sigma;
1765 mean[i] = normal_cdf_approx(q0 + eta_w[i]).clamp(MIN_PROB, 1.0 - MIN_PROB);
1766 }
1767 Ok(GenerativeSpec {
1768 mean,
1769 noise: NoiseModel::Bernoulli,
1770 })
1771 }
1772}
1773
#[cfg(test)]
mod tests {
    use super::*;

    /// Single all-ones design column with zero offset, no penalties and no warm starts.
    fn intercept_block(n: usize) -> ParameterBlockInput {
        let design = DesignMatrix::Dense(Array2::ones((n, 1)));
        ParameterBlockInput {
            design,
            offset: Array1::zeros(n),
            penalties: vec![],
            initial_log_lambdas: None,
            initial_beta: None,
        }
    }

    /// Length-`n` vector whose entry `i` is `f(i)`.
    fn vector_from_index(n: usize, f: impl Fn(usize) -> f64) -> Array1<f64> {
        (0..n).map(f).collect()
    }

    #[test]
    fn weighted_projection_returns_finite_coefficients() {
        let n = 8usize;
        let design = DesignMatrix::Dense(Array2::ones((n, 1)));
        let offset = Array1::zeros(n);
        let target_eta = Array1::from_elem(n, 0.2);
        let weights = Array1::from_elem(n, 1.0);

        let beta = solve_weighted_projection(&design, &offset, &target_eta, &weights, 1e-10)
            .unwrap();

        assert_eq!(beta.len(), 1);
        assert!(beta[0].is_finite());
        assert!((beta[0] - 0.2).abs() < 1e-6);
    }

    #[test]
    fn alpha_beta_warm_start_produces_finite_targets() {
        let n = 16usize;
        let y = vector_from_index(n, |i| if i % 3 == 0 { 1.0 } else { 0.0 });
        let score = vector_from_index(n, |i| i as f64 / n as f64 - 0.5);
        let weights = Array1::from_elem(n, 1.0);
        let threshold = intercept_block(n);
        let log_sigma = intercept_block(n);
        let options = BlockwiseFitOptions::default();

        let (beta_t, beta_ls, beta_obs) = try_binomial_alpha_beta_warm_start(
            &y, &score, &weights, 0.25, 4.0, &threshold, &log_sigma, &options,
        )
        .unwrap();

        assert_eq!(beta_t.len(), 1);
        assert_eq!(beta_ls.len(), 1);
        assert!(beta_t[0].is_finite());
        assert!(beta_ls[0].is_finite());
        assert!(beta_obs.iter().all(|v| v.is_finite() && *v > 0.0));
    }

    #[test]
    fn fit_binomial_location_scale_probit_runs_with_warm_start_path() {
        let n = 32usize;
        let spec = BinomialLocationScaleProbitSpec {
            y: vector_from_index(n, |i| if i % 4 == 0 { 1.0 } else { 0.0 }),
            score: vector_from_index(n, |i| (i as f64 - 16.0) / 10.0),
            weights: Array1::from_elem(n, 1.0),
            sigma_min: 0.3,
            sigma_max: 3.0,
            threshold_block: intercept_block(n),
            log_sigma_block: intercept_block(n),
        };

        let fit = fit_binomial_location_scale_probit(spec, &BlockwiseFitOptions::default())
            .expect("binomial location-scale probit should fit");

        assert_eq!(fit.block_states.len(), 2);
        assert!(fit.log_likelihood.is_finite());
    }
}