// gam/gamlss.rs

1use crate::basis::{
2    BasisOptions, Dense, KnotSource, create_basis, create_difference_penalty_matrix,
3};
4use crate::custom_family::{
5    BlockWorkingSet, BlockwiseFitOptions, BlockwiseFitResult, CustomFamily, FamilyEvaluation,
6    KnownLinkWiggle, ParameterBlockSpec, ParameterBlockState, fit_custom_family,
7};
8use crate::faer_ndarray::{fast_ata, fast_atv};
9use crate::generative::{CustomFamilyGenerative, GenerativeSpec, NoiseModel};
10use crate::matrix::DesignMatrix;
11use crate::pirls::WorkingLikelihood as EngineWorkingLikelihood;
12use crate::probability::{normal_cdf_approx, normal_pdf};
13use crate::types::{LikelihoodFamily, LinkFunction};
14use faer::Mat as FaerMat;
15use faer::Side;
16use faer::linalg::solvers::{
17    Lblt as FaerLblt, Ldlt as FaerLdlt, Llt as FaerLlt, Solve as FaerSolve,
18};
19use ndarray::{Array1, Array2, ArrayView1, s};
20
/// Probability clamp margin: fitted probabilities are kept inside `[MIN_PROB, 1 - MIN_PROB]`
/// and the binomial variance `mu * (1 - mu)` is floored at this value.
const MIN_PROB: f64 = 1e-10;
/// Floor on the magnitude of link derivatives before they are used as divisors.
const MIN_DERIV: f64 = 1e-8;
/// Lower bound applied to per-observation working weights.
const MIN_WEIGHT: f64 = 1e-12;
/// A max/min ratio of the per-observation beta values above this emits a warning.
const BETA_RANGE_WARN_THRESHOLD: f64 = 1.10;
/// An effective sample size `sum_w * pi * (1 - pi)` below this emits a warning.
const BINOMIAL_EFFECTIVE_N_WARN_THRESHOLD: f64 = 25.0;
26
/// Generic block input for high-level built-in family APIs.
///
/// Converted into a `ParameterBlockSpec` by [`ParameterBlockInput::into_spec`], which
/// enforces the dimension requirements noted on each field.
#[derive(Clone)]
pub struct ParameterBlockInput {
    /// Design matrix of the block; its row/column counts define `n` and `p` below.
    pub design: DesignMatrix,
    /// Per-row offset; must have length `n` (the number of design rows).
    pub offset: Array1<f64>,
    /// Penalty matrices; each must be `p x p` (`p` = number of design columns).
    pub penalties: Vec<Array2<f64>>,
    /// Optional starting log smoothing parameters, one per penalty; zeros when `None`.
    pub initial_log_lambdas: Option<Array1<f64>>,
    /// Optional starting coefficients; must have length `p` when provided.
    pub initial_beta: Option<Array1<f64>>,
}
36
/// Static descriptive metadata for a family: its name plus the names and links of its
/// distribution parameters.
#[derive(Clone, Debug)]
pub struct FamilyMetadata {
    /// Family name.
    pub name: &'static str,
    /// Names of the distribution parameters.
    pub parameter_names: &'static [&'static str],
    /// Link of each parameter — presumably parallel to `parameter_names`; TODO confirm.
    pub parameter_links: &'static [ParameterLink],
}
43
/// Configuration consumed by `build_wiggle_block_input_from_seed` when building a wiggle block.
#[derive(Clone, Debug)]
pub struct WiggleBlockConfig {
    /// Spline degree forwarded to `create_basis`.
    pub degree: usize,
    /// Number of internal knots requested when generating knots from the seed range.
    pub num_internal_knots: usize,
    /// Order forwarded to `create_difference_penalty_matrix`.
    pub penalty_order: usize,
    /// When `true`, an identity penalty is appended after the difference penalty.
    pub double_penalty: bool,
}
51
52impl ParameterBlockInput {
53    pub fn into_spec(self, name: &str) -> Result<ParameterBlockSpec, String> {
54        let p = self.design.ncols();
55        let n = self.design.nrows();
56        if self.offset.len() != n {
57            return Err(format!(
58                "block '{name}' offset length mismatch: got {}, expected {n}",
59                self.offset.len()
60            ));
61        }
62        if let Some(beta0) = &self.initial_beta {
63            if beta0.len() != p {
64                return Err(format!(
65                    "block '{name}' initial_beta length mismatch: got {}, expected {p}",
66                    beta0.len()
67                ));
68            }
69        }
70        for (k, s) in self.penalties.iter().enumerate() {
71            let (r, c) = s.dim();
72            if r != p || c != p {
73                return Err(format!(
74                    "block '{name}' penalty {k} must be {p}x{p}, got {r}x{c}"
75                ));
76            }
77        }
78        let k = self.penalties.len();
79        let initial_log_lambdas = self
80            .initial_log_lambdas
81            .unwrap_or_else(|| Array1::<f64>::zeros(k));
82        if initial_log_lambdas.len() != k {
83            return Err(format!(
84                "block '{name}' initial_log_lambdas length mismatch: got {}, expected {k}",
85                initial_log_lambdas.len()
86            ));
87        }
88        Ok(ParameterBlockSpec {
89            name: name.to_string(),
90            design: self.design,
91            offset: self.offset,
92            penalties: self.penalties,
93            initial_log_lambdas,
94            initial_beta: self.initial_beta,
95        })
96    }
97}
98
/// Checks that `[sigma_min, sigma_max]` is a usable scale interval.
///
/// Both bounds must be finite and strictly positive, and `sigma_min` must not exceed
/// `sigma_max`. The first violated rule is reported, prefixed with `context`.
fn validate_sigma_bounds(sigma_min: f64, sigma_max: f64, context: &str) -> Result<(), String> {
    let bounds = [sigma_min, sigma_max];
    if bounds.iter().any(|bound| !bound.is_finite()) {
        return Err(format!("{context}: sigma bounds must be finite"));
    }
    if bounds.iter().any(|&bound| bound <= 0.0) {
        return Err(format!(
            "{context}: sigma bounds must be strictly positive (got min={sigma_min}, max={sigma_max})"
        ));
    }
    if sigma_max < sigma_min {
        return Err(format!(
            "{context}: sigma_min ({sigma_min}) must be <= sigma_max ({sigma_max})"
        ));
    }
    Ok(())
}
115
/// Returns `Ok(())` when `found == expected`; otherwise an error message labelled with `name`.
fn validate_len_match(name: &str, expected: usize, found: usize) -> Result<(), String> {
    if expected == found {
        Ok(())
    } else {
        Err(format!(
            "{name} length mismatch: expected {expected}, found {found}"
        ))
    }
}
124
125fn validate_weights(weights: &Array1<f64>, context: &str) -> Result<(), String> {
126    for (i, &w) in weights.iter().enumerate() {
127        if !w.is_finite() || w < 0.0 {
128            return Err(format!(
129                "{context}: weights must be finite and non-negative; found weights[{i}]={w}"
130            ));
131        }
132    }
133    Ok(())
134}
135
136fn validate_binomial_response(y: &Array1<f64>, context: &str) -> Result<(), String> {
137    for (i, &yi) in y.iter().enumerate() {
138        if !yi.is_finite() || !(0.0..=1.0).contains(&yi) {
139            return Err(format!(
140                "{context}: binomial response must be finite in [0,1]; found y[{i}]={yi}"
141            ));
142        }
143    }
144    Ok(())
145}
146
147pub fn initialize_wiggle_knots_from_seed(
148    seed: ArrayView1<'_, f64>,
149    degree: usize,
150    num_internal_knots: usize,
151) -> Result<Array1<f64>, String> {
152    let seed_min = seed.iter().copied().fold(f64::INFINITY, f64::min);
153    let mut seed_max = seed.iter().copied().fold(f64::NEG_INFINITY, f64::max);
154    if !seed_min.is_finite() || !seed_max.is_finite() {
155        return Err("non-finite seed for wiggle knot initialization".to_string());
156    }
157    if (seed_max - seed_min).abs() < 1e-12 {
158        seed_max = seed_min + 1e-6;
159    }
160    let (_, knots) = create_basis::<Dense>(
161        seed,
162        KnotSource::Generate {
163            data_range: (seed_min, seed_max),
164            num_internal_knots,
165        },
166        degree,
167        BasisOptions::value(),
168    )
169    .map_err(|e| e.to_string())?;
170    Ok(knots)
171}
172
173pub fn build_wiggle_block_input_from_knots(
174    seed: ArrayView1<'_, f64>,
175    knots: &Array1<f64>,
176    degree: usize,
177    penalty_order: usize,
178    double_penalty: bool,
179) -> Result<ParameterBlockInput, String> {
180    let (basis, _) = create_basis::<Dense>(
181        seed,
182        KnotSource::Provided(knots.view()),
183        degree,
184        BasisOptions::value(),
185    )
186    .map_err(|e| e.to_string())?;
187    let full = (*basis).clone();
188    if full.ncols() < 2 {
189        return Err("wiggle basis has fewer than two columns".to_string());
190    }
191    let design = full.slice(s![.., 1..]).to_owned();
192    let p = design.ncols();
193    let mut penalties =
194        vec![create_difference_penalty_matrix(p, penalty_order, None).map_err(|e| e.to_string())?];
195    if double_penalty {
196        penalties.push(Array2::<f64>::eye(p));
197    }
198    Ok(ParameterBlockInput {
199        design: DesignMatrix::Dense(design),
200        offset: Array1::zeros(seed.len()),
201        penalties,
202        initial_log_lambdas: None,
203        initial_beta: None,
204    })
205}
206
207pub fn build_wiggle_block_input_from_seed(
208    seed: ArrayView1<'_, f64>,
209    cfg: &WiggleBlockConfig,
210) -> Result<(ParameterBlockInput, Array1<f64>), String> {
211    let knots = initialize_wiggle_knots_from_seed(seed, cfg.degree, cfg.num_internal_knots)?;
212    let block = build_wiggle_block_input_from_knots(
213        seed,
214        &knots,
215        cfg.degree,
216        cfg.penalty_order,
217        cfg.double_penalty,
218    )?;
219    Ok((block, knots))
220}
221
222fn validate_block_rows(name: &str, n: usize, block: &ParameterBlockInput) -> Result<(), String> {
223    validate_len_match(
224        &format!("block '{name}' offset vs response"),
225        n,
226        block.offset.len(),
227    )?;
228    validate_len_match(
229        &format!("block '{name}' design rows vs response"),
230        n,
231        block.design.nrows(),
232    )
233}
234
235/// Shared single-block GLM evaluation adapter backed by the engine-level
236/// `WorkingLikelihood` implementation used by PIRLS.
237fn evaluate_single_block_glm(
238    family: LikelihoodFamily,
239    y: &Array1<f64>,
240    weights: &Array1<f64>,
241    eta: &Array1<f64>,
242) -> Result<FamilyEvaluation, String> {
243    let n = y.len();
244    if eta.len() != n || weights.len() != n {
245        return Err("single-block GLM input size mismatch".to_string());
246    }
247    let mut mu = Array1::<f64>::zeros(n);
248    let mut z = Array1::<f64>::zeros(n);
249    let mut w = Array1::<f64>::zeros(n);
250    family
251        .irls_update(y.view(), eta, weights.view(), &mut mu, &mut w, &mut z, None)
252        .map_err(|e| e.to_string())?;
253    let ll = family
254        .log_likelihood(y.view(), eta, &mu, weights.view())
255        .map_err(|e| e.to_string())?;
256    Ok(FamilyEvaluation {
257        log_likelihood: ll,
258        block_working_sets: vec![BlockWorkingSet {
259            working_response: z,
260            working_weights: w,
261            gradient_eta: None,
262        }],
263    })
264}
265
266fn initial_log_lambdas_or_zeros(block: &ParameterBlockInput) -> Result<Array1<f64>, String> {
267    let k = block.penalties.len();
268    let lambdas = block
269        .initial_log_lambdas
270        .clone()
271        .unwrap_or_else(|| Array1::<f64>::zeros(k));
272    if lambdas.len() != k {
273        return Err(format!(
274            "initial_log_lambdas length mismatch: got {}, expected {}",
275            lambdas.len(),
276            k
277        ));
278    }
279    Ok(lambdas)
280}
281
282fn solve_weighted_projection(
283    design: &DesignMatrix,
284    offset: &Array1<f64>,
285    target_eta: &Array1<f64>,
286    weights: &Array1<f64>,
287    ridge_floor: f64,
288) -> Result<Array1<f64>, String> {
289    let n = design.nrows();
290    let p = design.ncols();
291    if offset.len() != n || target_eta.len() != n || weights.len() != n {
292        return Err("solve_weighted_projection dimension mismatch".to_string());
293    }
294
295    let (mut xtwx, xtwy) = match design {
296        DesignMatrix::Dense(x) => {
297            let mut xw = x.clone();
298            for i in 0..n {
299                let sw = weights[i].max(0.0).sqrt();
300                if sw != 1.0 {
301                    let mut row = xw.row_mut(i);
302                    row *= sw;
303                }
304            }
305            let xtwx = fast_ata(&xw);
306            let mut y_w = target_eta - offset;
307            for i in 0..n {
308                y_w[i] *= weights[i].max(0.0).sqrt();
309            }
310            let xtwy = fast_atv(&xw, &y_w);
311            (xtwx, xtwy)
312        }
313        DesignMatrix::Sparse(xs) => {
314            let csr = xs
315                .as_ref()
316                .to_row_major()
317                .map_err(|_| "failed to obtain CSR view for weighted projection".to_string())?;
318            let sym = csr.symbolic();
319            let row_ptr = sym.row_ptr();
320            let col_idx = sym.col_idx();
321            let vals = csr.val();
322            let mut xtwx = Array2::<f64>::zeros((p, p));
323            let mut xtwy = Array1::<f64>::zeros(p);
324
325            for i in 0..n {
326                let wi = weights[i].max(0.0);
327                if wi == 0.0 {
328                    continue;
329                }
330                let y_star = target_eta[i] - offset[i];
331                let start = row_ptr[i];
332                let end = row_ptr[i + 1];
333                for a_ptr in start..end {
334                    let a = col_idx[a_ptr];
335                    let xa = vals[a_ptr];
336                    xtwy[a] += wi * xa * y_star;
337                    for b_ptr in a_ptr..end {
338                        let b = col_idx[b_ptr];
339                        let xb = vals[b_ptr];
340                        xtwx[[a, b]] += wi * xa * xb;
341                    }
342                }
343            }
344            for a in 0..p {
345                for b in 0..a {
346                    xtwx[[a, b]] = xtwx[[b, a]];
347                }
348            }
349            (xtwx, xtwy)
350        }
351    };
352    for a in 0..p {
353        xtwx[[a, a]] += ridge_floor.max(1e-12);
354    }
355
356    let h = crate::faer_ndarray::FaerArrayView::new(&xtwx);
357    let mut rhs_mat = FaerMat::zeros(p, 1);
358    for i in 0..p {
359        rhs_mat[(i, 0)] = xtwy[i];
360    }
361
362    if let Ok(ch) = FaerLlt::new(h.as_ref(), Side::Lower) {
363        ch.solve_in_place(rhs_mat.as_mut());
364    } else if let Ok(ld) = FaerLdlt::new(h.as_ref(), Side::Lower) {
365        ld.solve_in_place(rhs_mat.as_mut());
366    } else {
367        let lb = FaerLblt::new(h.as_ref(), Side::Lower);
368        lb.solve_in_place(rhs_mat.as_mut());
369    }
370
371    let mut beta = Array1::<f64>::zeros(p);
372    for i in 0..p {
373        beta[i] = rhs_mat[(i, 0)];
374    }
375    if beta.iter().any(|v| !v.is_finite()) {
376        return Err("solve_weighted_projection produced non-finite coefficients".to_string());
377    }
378    Ok(beta)
379}
380
381fn weighted_prevalence(y: &Array1<f64>, weights: &Array1<f64>) -> f64 {
382    let w_sum: f64 = weights.iter().copied().sum();
383    if w_sum <= 0.0 {
384        return 0.5;
385    }
386    let y_w_sum: f64 = y.iter().zip(weights.iter()).map(|(&yi, &wi)| yi * wi).sum();
387    (y_w_sum / w_sum).clamp(0.0, 1.0)
388}
389
390fn emit_binomial_alpha_beta_warnings(
391    context: &str,
392    beta_values: &Array1<f64>,
393    y: &Array1<f64>,
394    weights: &Array1<f64>,
395) {
396    if beta_values.is_empty() {
397        return;
398    }
399    let beta_min = beta_values.iter().copied().fold(f64::INFINITY, f64::min);
400    let beta_max = beta_values
401        .iter()
402        .copied()
403        .fold(f64::NEG_INFINITY, f64::max);
404
405    if !beta_min.is_finite() || !beta_max.is_finite() || beta_min <= 0.0 {
406        log::warn!(
407            "[GAMLSS][{}] non-positive or non-finite beta encountered (min={}, max={})",
408            context,
409            beta_min,
410            beta_max
411        );
412    } else {
413        let ratio = beta_max / beta_min;
414        if ratio > BETA_RANGE_WARN_THRESHOLD {
415            log::warn!(
416                "[GAMLSS][{}] beta range ratio {:.3} exceeds {:.3}; transformed-penalty distortion risk is elevated",
417                context,
418                ratio,
419                BETA_RANGE_WARN_THRESHOLD
420            );
421        }
422    }
423
424    let pi = weighted_prevalence(y, weights);
425    let w_sum: f64 = weights.iter().copied().sum();
426    let n_eff = w_sum * pi * (1.0 - pi);
427    if n_eff < BINOMIAL_EFFECTIVE_N_WARN_THRESHOLD {
428        log::warn!(
429            "[GAMLSS][{}] low effective sample size N_eff={:.3} (sum_w={:.3}, prevalence={:.3}); location-scale separation artifacts are more likely",
430            context,
431            n_eff,
432            w_sum,
433            pi
434        );
435    }
436}
437
/// Two-block family used only for warm starting the binomial location-scale fits.
///
/// Its evaluation models the success probability as
/// `normal_cdf_approx(eta_alpha + clamp(eta_beta, beta_min, beta_max) * score)`.
#[derive(Clone)]
struct BinomialAlphaBetaWarmStartFamily {
    /// Response values.
    y: Array1<f64>,
    /// Per-observation score multiplied by the (clamped) beta predictor.
    score: Array1<f64>,
    /// Per-observation prior weights.
    weights: Array1<f64>,
    /// Lower clamp for the beta predictor.
    beta_min: f64,
    /// Upper clamp for the beta predictor.
    beta_max: f64,
}
446
impl BinomialAlphaBetaWarmStartFamily {
    /// Position of the alpha block within the block-state slice.
    const BLOCK_ALPHA: usize = 0;
    /// Position of the beta block within the block-state slice.
    const BLOCK_BETA: usize = 1;
}
451
452impl CustomFamily for BinomialAlphaBetaWarmStartFamily {
453    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
454        if block_states.len() != 2 {
455            return Err(format!(
456                "BinomialAlphaBetaWarmStartFamily expects 2 blocks, got {}",
457                block_states.len()
458            ));
459        }
460        let n = self.y.len();
461        let eta_alpha = &block_states[Self::BLOCK_ALPHA].eta;
462        let eta_beta = &block_states[Self::BLOCK_BETA].eta;
463        if eta_alpha.len() != n
464            || eta_beta.len() != n
465            || self.score.len() != n
466            || self.weights.len() != n
467        {
468            return Err("BinomialAlphaBetaWarmStartFamily input size mismatch".to_string());
469        }
470
471        let mut z_alpha = Array1::<f64>::zeros(n);
472        let mut w_alpha = Array1::<f64>::zeros(n);
473        let mut z_beta = Array1::<f64>::zeros(n);
474        let mut w_beta = Array1::<f64>::zeros(n);
475        let mut ll = 0.0_f64;
476
477        for i in 0..n {
478            let raw_beta = eta_beta[i];
479            let beta = raw_beta.clamp(self.beta_min, self.beta_max);
480            let dbeta_deta = if raw_beta >= self.beta_min && raw_beta <= self.beta_max {
481                1.0
482            } else {
483                0.0
484            };
485            let q = eta_alpha[i] + beta * self.score[i];
486            let mu = normal_cdf_approx(q).clamp(MIN_PROB, 1.0 - MIN_PROB);
487            let dmu_dq = normal_pdf(q).max(MIN_DERIV);
488            let var = (mu * (1.0 - mu)).max(MIN_PROB);
489
490            ll += self.weights[i] * (self.y[i] * mu.ln() + (1.0 - self.y[i]) * (1.0 - mu).ln());
491
492            let dmu_alpha = dmu_dq;
493            w_alpha[i] = (self.weights[i] * (dmu_alpha * dmu_alpha / var)).max(MIN_WEIGHT);
494            z_alpha[i] = eta_alpha[i] + (self.y[i] - mu) / signed_with_floor(dmu_alpha, MIN_DERIV);
495
496            let chain_beta = self.score[i] * dbeta_deta;
497            let dmu_beta = dmu_dq * chain_beta;
498            w_beta[i] = (self.weights[i] * (dmu_beta * dmu_beta / var)).max(MIN_WEIGHT);
499            z_beta[i] = eta_beta[i] + (self.y[i] - mu) / signed_with_floor(dmu_beta, MIN_DERIV);
500        }
501
502        Ok(FamilyEvaluation {
503            log_likelihood: ll,
504            block_working_sets: vec![
505                BlockWorkingSet {
506                    working_response: z_alpha,
507                    working_weights: w_alpha,
508                    gradient_eta: None,
509                },
510                BlockWorkingSet {
511                    working_response: z_beta,
512                    working_weights: w_beta,
513                    gradient_eta: None,
514                },
515            ],
516        })
517    }
518
519    fn post_update_beta(
520        &self,
521        block_index: usize,
522        beta: Array1<f64>,
523    ) -> Result<Array1<f64>, String> {
524        if block_index != Self::BLOCK_BETA {
525            return Ok(beta);
526        }
527        Ok(beta.mapv(|v| v.clamp(self.beta_min, self.beta_max)))
528    }
529}
530
531fn try_binomial_alpha_beta_warm_start(
532    y: &Array1<f64>,
533    score: &Array1<f64>,
534    weights: &Array1<f64>,
535    sigma_min: f64,
536    sigma_max: f64,
537    threshold_block: &ParameterBlockInput,
538    log_sigma_block: &ParameterBlockInput,
539    options: &BlockwiseFitOptions,
540) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
541    let beta_min = (1.0 / sigma_max.max(1e-12)).max(1e-12);
542    let beta_max = (1.0 / sigma_min.max(1e-12)).max(beta_min + 1e-12);
543    let warm_family = BinomialAlphaBetaWarmStartFamily {
544        y: y.clone(),
545        score: score.clone(),
546        weights: weights.clone(),
547        beta_min,
548        beta_max,
549    };
550
551    let alpha_spec = ParameterBlockSpec {
552        name: "alpha_warm".to_string(),
553        design: threshold_block.design.clone(),
554        offset: threshold_block.offset.clone(),
555        penalties: threshold_block.penalties.clone(),
556        initial_log_lambdas: initial_log_lambdas_or_zeros(threshold_block)?,
557        initial_beta: None,
558    };
559    let beta_spec = ParameterBlockSpec {
560        name: "beta_warm".to_string(),
561        design: log_sigma_block.design.clone(),
562        offset: log_sigma_block.offset.clone(),
563        penalties: log_sigma_block.penalties.clone(),
564        initial_log_lambdas: initial_log_lambdas_or_zeros(log_sigma_block)?,
565        initial_beta: None,
566    };
567
568    let warm_options = BlockwiseFitOptions {
569        inner_max_cycles: options.inner_max_cycles.min(40).max(5),
570        inner_tol: options.inner_tol,
571        outer_max_iter: options.outer_max_iter.min(20).max(3),
572        outer_tol: options.outer_tol.max(1e-6),
573        min_weight: options.min_weight,
574        ridge_floor: options.ridge_floor.max(1e-10),
575        ridge_policy: options.ridge_policy,
576        // Warm start optimization focuses on robust initialization, not REML correction.
577        use_reml_objective: false,
578    };
579    let warm_fit = fit_custom_family(&warm_family, &[alpha_spec, beta_spec], &warm_options)?;
580    let eta_alpha = &warm_fit.block_states[BinomialAlphaBetaWarmStartFamily::BLOCK_ALPHA].eta;
581    let eta_beta = &warm_fit.block_states[BinomialAlphaBetaWarmStartFamily::BLOCK_BETA].eta;
582    if eta_alpha.len() != y.len() || eta_beta.len() != y.len() {
583        return Err("warm start eta length mismatch".to_string());
584    }
585
586    let beta_obs = eta_beta.mapv(|v| v.clamp(beta_min, beta_max));
587    let t_target = Array1::from_iter(
588        eta_alpha
589            .iter()
590            .zip(beta_obs.iter())
591            .map(|(&a, &b)| -a / b.max(1e-12)),
592    );
593    let log_sigma_target = beta_obs.mapv(|b| -b.max(1e-12).ln());
594
595    let beta_t = solve_weighted_projection(
596        &threshold_block.design,
597        &threshold_block.offset,
598        &t_target,
599        weights,
600        options.ridge_floor.max(1e-10),
601    )?;
602    let beta_log_sigma = solve_weighted_projection(
603        &log_sigma_block.design,
604        &log_sigma_block.offset,
605        &log_sigma_target,
606        weights,
607        options.ridge_floor.max(1e-10),
608    )?;
609
610    Ok((beta_t, beta_log_sigma, beta_obs))
611}
612
/// Input for `fit_gaussian_location_scale`.
#[derive(Clone)]
pub struct GaussianLocationScaleSpec {
    /// Response vector; its length defines the number of observations.
    pub y: Array1<f64>,
    /// Prior weights; same length as `y`, finite and non-negative.
    pub weights: Array1<f64>,
    /// Lower sigma bound; finite, strictly positive, and `<= sigma_max`.
    pub sigma_min: f64,
    /// Upper sigma bound; finite and strictly positive.
    pub sigma_max: f64,
    /// Block fitted under the name `"mu"`.
    pub mu_block: ParameterBlockInput,
    /// Block fitted under the name `"log_sigma"`.
    pub log_sigma_block: ParameterBlockInput,
}
622
/// Input for `fit_binomial_logit`.
#[derive(Clone)]
pub struct BinomialLogitSpec {
    /// Response vector; every value must be finite and within `[0, 1]`.
    pub y: Array1<f64>,
    /// Prior weights; same length as `y`, finite and non-negative.
    pub weights: Array1<f64>,
    /// Block fitted under the name `"eta"`.
    pub eta_block: ParameterBlockInput,
}
629
/// Input for `fit_poisson_log`.
#[derive(Clone)]
pub struct PoissonLogSpec {
    /// Response vector; its length defines the number of observations.
    pub y: Array1<f64>,
    /// Prior weights; same length as `y`, finite and non-negative.
    pub weights: Array1<f64>,
    /// Block fitted under the name `"eta"`.
    pub eta_block: ParameterBlockInput,
}
636
/// Input for `fit_gamma_log`.
#[derive(Clone)]
pub struct GammaLogSpec {
    /// Response vector; its length defines the number of observations.
    pub y: Array1<f64>,
    /// Prior weights; same length as `y`, finite and non-negative.
    pub weights: Array1<f64>,
    /// Gamma shape parameter (k > 0); must also be finite.
    pub shape: f64,
    /// Block fitted under the name `"eta"`.
    pub eta_block: ParameterBlockInput,
}
645
/// Input for `fit_binomial_location_scale_probit`.
#[derive(Clone)]
pub struct BinomialLocationScaleProbitSpec {
    /// Response vector; every value must be finite and within `[0, 1]`.
    pub y: Array1<f64>,
    /// Per-observation score; same length as `y`.
    pub score: Array1<f64>,
    /// Prior weights; same length as `y`, finite and non-negative.
    pub weights: Array1<f64>,
    /// Lower sigma bound; finite, strictly positive, and `<= sigma_max`.
    pub sigma_min: f64,
    /// Upper sigma bound; finite and strictly positive.
    pub sigma_max: f64,
    /// Block fitted under the name `"threshold"`.
    pub threshold_block: ParameterBlockInput,
    /// Block fitted under the name `"log_sigma"`.
    pub log_sigma_block: ParameterBlockInput,
}
656
/// Input for `fit_binomial_location_scale_probit_wiggle`.
#[derive(Clone)]
pub struct BinomialLocationScaleProbitWiggleSpec {
    /// Response vector; every value must be finite and within `[0, 1]`.
    pub y: Array1<f64>,
    /// Per-observation score; same length as `y`.
    pub score: Array1<f64>,
    /// Prior weights; same length as `y`, finite and non-negative.
    pub weights: Array1<f64>,
    /// Lower sigma bound; finite, strictly positive, and `<= sigma_max`.
    pub sigma_min: f64,
    /// Upper sigma bound; finite and strictly positive.
    pub sigma_max: f64,
    /// Knot vector of the wiggle basis; needs at least `wiggle_degree + 2` entries.
    pub wiggle_knots: Array1<f64>,
    /// Degree of the wiggle basis; must be at least 1.
    pub wiggle_degree: usize,
    /// Block fitted under the name `"threshold"`.
    pub threshold_block: ParameterBlockInput,
    /// Block fitted under the name `"log_sigma"`.
    pub log_sigma_block: ParameterBlockInput,
    /// Block fitted under the name `"wiggle"`.
    pub wiggle_block: ParameterBlockInput,
}
670
671pub fn fit_gaussian_location_scale(
672    spec: GaussianLocationScaleSpec,
673    options: &BlockwiseFitOptions,
674) -> Result<BlockwiseFitResult, String> {
675    let n = spec.y.len();
676    validate_len_match("weights vs y", n, spec.weights.len())?;
677    validate_weights(&spec.weights, "fit_gaussian_location_scale")?;
678    validate_sigma_bounds(
679        spec.sigma_min,
680        spec.sigma_max,
681        "fit_gaussian_location_scale",
682    )?;
683    validate_block_rows("mu", n, &spec.mu_block)?;
684    validate_block_rows("log_sigma", n, &spec.log_sigma_block)?;
685
686    let family = GaussianLocationScaleFamily {
687        y: spec.y,
688        weights: spec.weights,
689        sigma_min: spec.sigma_min,
690        sigma_max: spec.sigma_max,
691    };
692    let blocks = vec![
693        spec.mu_block.into_spec("mu")?,
694        spec.log_sigma_block.into_spec("log_sigma")?,
695    ];
696    fit_custom_family(&family, &blocks, options)
697}
698
699pub fn fit_binomial_logit(
700    spec: BinomialLogitSpec,
701    options: &BlockwiseFitOptions,
702) -> Result<BlockwiseFitResult, String> {
703    let n = spec.y.len();
704    validate_len_match("weights vs y", n, spec.weights.len())?;
705    validate_weights(&spec.weights, "fit_binomial_logit")?;
706    validate_binomial_response(&spec.y, "fit_binomial_logit")?;
707    validate_block_rows("eta", n, &spec.eta_block)?;
708
709    let family = BinomialLogitFamily {
710        y: spec.y,
711        weights: spec.weights,
712    };
713    let blocks = vec![spec.eta_block.into_spec("eta")?];
714    fit_custom_family(&family, &blocks, options)
715}
716
717pub fn fit_poisson_log(
718    spec: PoissonLogSpec,
719    options: &BlockwiseFitOptions,
720) -> Result<BlockwiseFitResult, String> {
721    let n = spec.y.len();
722    validate_len_match("weights vs y", n, spec.weights.len())?;
723    validate_weights(&spec.weights, "fit_poisson_log")?;
724    validate_block_rows("eta", n, &spec.eta_block)?;
725
726    let family = PoissonLogFamily {
727        y: spec.y,
728        weights: spec.weights,
729    };
730    let blocks = vec![spec.eta_block.into_spec("eta")?];
731    fit_custom_family(&family, &blocks, options)
732}
733
734pub fn fit_gamma_log(
735    spec: GammaLogSpec,
736    options: &BlockwiseFitOptions,
737) -> Result<BlockwiseFitResult, String> {
738    let n = spec.y.len();
739    validate_len_match("weights vs y", n, spec.weights.len())?;
740    validate_weights(&spec.weights, "fit_gamma_log")?;
741    validate_block_rows("eta", n, &spec.eta_block)?;
742    if !spec.shape.is_finite() || spec.shape <= 0.0 {
743        return Err(format!(
744            "fit_gamma_log: shape must be finite and > 0, got {}",
745            spec.shape
746        ));
747    }
748
749    let family = GammaLogFamily {
750        y: spec.y,
751        weights: spec.weights,
752        shape: spec.shape,
753    };
754    let blocks = vec![spec.eta_block.into_spec("eta")?];
755    fit_custom_family(&family, &blocks, options)
756}
757
758pub fn fit_binomial_location_scale_probit(
759    spec: BinomialLocationScaleProbitSpec,
760    options: &BlockwiseFitOptions,
761) -> Result<BlockwiseFitResult, String> {
762    let n = spec.y.len();
763    validate_len_match("score vs y", n, spec.score.len())?;
764    validate_len_match("weights vs y", n, spec.weights.len())?;
765    validate_weights(&spec.weights, "fit_binomial_location_scale_probit")?;
766    validate_binomial_response(&spec.y, "fit_binomial_location_scale_probit")?;
767    validate_sigma_bounds(
768        spec.sigma_min,
769        spec.sigma_max,
770        "fit_binomial_location_scale_probit",
771    )?;
772    validate_block_rows("threshold", n, &spec.threshold_block)?;
773    validate_block_rows("log_sigma", n, &spec.log_sigma_block)?;
774
775    let BinomialLocationScaleProbitSpec {
776        y,
777        score,
778        weights,
779        sigma_min,
780        sigma_max,
781        mut threshold_block,
782        mut log_sigma_block,
783    } = spec;
784
785    match try_binomial_alpha_beta_warm_start(
786        &y,
787        &score,
788        &weights,
789        sigma_min,
790        sigma_max,
791        &threshold_block,
792        &log_sigma_block,
793        options,
794    ) {
795        Ok((beta_t0, beta_ls0, beta_warm)) => {
796            threshold_block.initial_beta = Some(beta_t0);
797            log_sigma_block.initial_beta = Some(beta_ls0);
798            emit_binomial_alpha_beta_warnings("warm-start", &beta_warm, &y, &weights);
799        }
800        Err(err) => {
801            log::warn!(
802                "[GAMLSS][fit_binomial_location_scale_probit] alpha/beta warm start failed, falling back to direct initialization: {}",
803                err
804            );
805        }
806    }
807
808    let family = BinomialLocationScaleProbitFamily {
809        y: y.clone(),
810        score: score.clone(),
811        weights: weights.clone(),
812        sigma_min,
813        sigma_max,
814    };
815    let blocks = vec![
816        threshold_block.into_spec("threshold")?,
817        log_sigma_block.into_spec("log_sigma")?,
818    ];
819    let fit = fit_custom_family(&family, &blocks, options)?;
820    let beta_final = fit.block_states[BinomialLocationScaleProbitFamily::BLOCK_LOG_SIGMA]
821        .eta
822        .mapv(f64::exp)
823        .mapv(|s| 1.0 / s.clamp(sigma_min, sigma_max).max(1e-12));
824    emit_binomial_alpha_beta_warnings("final-fit", &beta_final, &y, &weights);
825    Ok(fit)
826}
827
828pub fn fit_binomial_location_scale_probit_wiggle(
829    spec: BinomialLocationScaleProbitWiggleSpec,
830    options: &BlockwiseFitOptions,
831) -> Result<BlockwiseFitResult, String> {
832    let n = spec.y.len();
833    validate_len_match("score vs y", n, spec.score.len())?;
834    validate_len_match("weights vs y", n, spec.weights.len())?;
835    validate_weights(&spec.weights, "fit_binomial_location_scale_probit_wiggle")?;
836    validate_binomial_response(&spec.y, "fit_binomial_location_scale_probit_wiggle")?;
837    validate_sigma_bounds(
838        spec.sigma_min,
839        spec.sigma_max,
840        "fit_binomial_location_scale_probit_wiggle",
841    )?;
842    validate_block_rows("threshold", n, &spec.threshold_block)?;
843    validate_block_rows("log_sigma", n, &spec.log_sigma_block)?;
844    validate_block_rows("wiggle", n, &spec.wiggle_block)?;
845    if spec.wiggle_degree < 1 {
846        return Err(format!(
847            "fit_binomial_location_scale_probit_wiggle: wiggle_degree must be >= 1, got {}",
848            spec.wiggle_degree
849        ));
850    }
851    if spec.wiggle_knots.len() < spec.wiggle_degree + 2 {
852        return Err(format!(
853            "fit_binomial_location_scale_probit_wiggle: wiggle_knots length {} is too short for degree {}",
854            spec.wiggle_knots.len(),
855            spec.wiggle_degree
856        ));
857    }
858
859    let BinomialLocationScaleProbitWiggleSpec {
860        y,
861        score,
862        weights,
863        sigma_min,
864        sigma_max,
865        wiggle_knots,
866        wiggle_degree,
867        mut threshold_block,
868        mut log_sigma_block,
869        wiggle_block,
870    } = spec;
871
872    match try_binomial_alpha_beta_warm_start(
873        &y,
874        &score,
875        &weights,
876        sigma_min,
877        sigma_max,
878        &threshold_block,
879        &log_sigma_block,
880        options,
881    ) {
882        Ok((beta_t0, beta_ls0, beta_warm)) => {
883            threshold_block.initial_beta = Some(beta_t0);
884            log_sigma_block.initial_beta = Some(beta_ls0);
885            emit_binomial_alpha_beta_warnings("warm-start-wiggle", &beta_warm, &y, &weights);
886        }
887        Err(err) => {
888            log::warn!(
889                "[GAMLSS][fit_binomial_location_scale_probit_wiggle] alpha/beta warm start failed, falling back to direct initialization: {}",
890                err
891            );
892        }
893    }
894
895    let family = BinomialLocationScaleProbitWiggleFamily {
896        y: y.clone(),
897        score: score.clone(),
898        weights: weights.clone(),
899        sigma_min,
900        sigma_max,
901        wiggle_knots,
902        wiggle_degree,
903    };
904    let blocks = vec![
905        threshold_block.into_spec("threshold")?,
906        log_sigma_block.into_spec("log_sigma")?,
907        wiggle_block.into_spec("wiggle")?,
908    ];
909    let fit = fit_custom_family(&family, &blocks, options)?;
910    let beta_final = fit.block_states[BinomialLocationScaleProbitWiggleFamily::BLOCK_LOG_SIGMA]
911        .eta
912        .mapv(f64::exp)
913        .mapv(|s| 1.0 / s.clamp(sigma_min, sigma_max).max(1e-12));
914    emit_binomial_alpha_beta_warnings("final-fit-wiggle", &beta_final, &y, &weights);
915    Ok(fit)
916}
917
/// Identifies the link attached to one distribution parameter of a
/// multi-parameter GAMLSS family (reported through `FamilyMetadata`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParameterLink {
    /// Identity link (declared for the Gaussian location block).
    Identity,
    /// Log link (declared for log-scale blocks and the Poisson/Gamma predictors).
    Log,
    /// Logit link (declared for the binomial logit predictor).
    Logit,
    /// Probit link (declared for the binomial location-scale threshold block).
    Probit,
    /// Learnable smooth departure from a known base link.
    Wiggle,
}
928
/// Returns `v` with its magnitude raised to at least `floor`, keeping the sign.
///
/// Anything that is not `>= 0.0` (negative values and NaN) takes the negative
/// branch; both `0.0` and `-0.0` take the positive branch.
fn signed_with_floor(v: f64, floor: f64) -> f64 {
    let magnitude = v.abs().max(floor);
    if v >= 0.0 {
        return magnitude;
    }
    -magnitude
}
933
934struct BinomialLocationScaleCore {
935    sigma: Array1<f64>,
936    dsigma_deta: Array1<f64>,
937    q0: Array1<f64>,
938    mu: Array1<f64>,
939    dmu_dq: Array1<f64>,
940    log_likelihood: f64,
941}
942
943fn binomial_location_scale_core(
944    y: &Array1<f64>,
945    score: &Array1<f64>,
946    weights: &Array1<f64>,
947    eta_t: &Array1<f64>,
948    eta_ls: &Array1<f64>,
949    eta_wiggle: Option<&Array1<f64>>,
950    sigma_min: f64,
951    sigma_max: f64,
952) -> Result<BinomialLocationScaleCore, String> {
953    let n = y.len();
954    if score.len() != n || weights.len() != n || eta_t.len() != n || eta_ls.len() != n {
955        return Err("binomial location-scale core size mismatch".to_string());
956    }
957    if let Some(w) = eta_wiggle {
958        if w.len() != n {
959            return Err("binomial location-scale core wiggle size mismatch".to_string());
960        }
961    }
962
963    let mut sigma = Array1::<f64>::zeros(n);
964    let mut dsigma_deta = Array1::<f64>::zeros(n);
965    let mut q0 = Array1::<f64>::zeros(n);
966    let mut mu = Array1::<f64>::zeros(n);
967    let mut dmu_dq = Array1::<f64>::zeros(n);
968    let mut ll = 0.0;
969
970    for i in 0..n {
971        let raw = eta_ls[i].exp();
972        sigma[i] = raw.clamp(sigma_min, sigma_max);
973        dsigma_deta[i] = if raw >= sigma_min && raw <= sigma_max {
974            raw
975        } else {
976            0.0
977        };
978        q0[i] = (score[i] - eta_t[i]) / sigma[i].max(1e-12);
979        let q = q0[i] + eta_wiggle.map_or(0.0, |w| w[i]);
980        mu[i] = normal_cdf_approx(q).clamp(MIN_PROB, 1.0 - MIN_PROB);
981        dmu_dq[i] = normal_pdf(q).max(MIN_DERIV);
982        ll += weights[i] * (y[i] * mu[i].ln() + (1.0 - y[i]) * (1.0 - mu[i]).ln());
983    }
984
985    Ok(BinomialLocationScaleCore {
986        sigma,
987        dsigma_deta,
988        q0,
989        mu,
990        dmu_dq,
991        log_likelihood: ll,
992    })
993}
994
995fn binomial_location_scale_working_sets(
996    y: &Array1<f64>,
997    weights: &Array1<f64>,
998    eta_t: &Array1<f64>,
999    eta_ls: &Array1<f64>,
1000    eta_wiggle: Option<&Array1<f64>>,
1001    core: &BinomialLocationScaleCore,
1002) -> (BlockWorkingSet, BlockWorkingSet, Option<BlockWorkingSet>) {
1003    let n = y.len();
1004    let mut z_t = Array1::<f64>::zeros(n);
1005    let mut w_t = Array1::<f64>::zeros(n);
1006    let mut z_ls = Array1::<f64>::zeros(n);
1007    let mut w_ls = Array1::<f64>::zeros(n);
1008    let mut z_w = eta_wiggle.map(|_| Array1::<f64>::zeros(n));
1009    let mut w_w = eta_wiggle.map(|_| Array1::<f64>::zeros(n));
1010
1011    for i in 0..n {
1012        let var = (core.mu[i] * (1.0 - core.mu[i])).max(MIN_PROB);
1013
1014        // Location/threshold chain: dq/deta_t = -1/sigma
1015        let chain_t = -1.0 / core.sigma[i].max(1e-12);
1016        let dmu_t = core.dmu_dq[i] * chain_t;
1017        w_t[i] = (weights[i] * (dmu_t * dmu_t / var)).max(MIN_WEIGHT);
1018        z_t[i] = eta_t[i] + (y[i] - core.mu[i]) / signed_with_floor(dmu_t, MIN_DERIV);
1019
1020        // Scale chain: dq/deta_log_sigma = -q0 * dsigma/deta / sigma
1021        // This is the generic location-scale structure; the -Z multiplier appears here.
1022        let chain_ls = {
1023            let s = core.sigma[i].max(1e-12);
1024            -core.q0[i] * core.dsigma_deta[i] / s
1025        };
1026        let dmu_ls = core.dmu_dq[i] * chain_ls;
1027        w_ls[i] = (weights[i] * (dmu_ls * dmu_ls / var)).max(MIN_WEIGHT);
1028        z_ls[i] = eta_ls[i] + (y[i] - core.mu[i]) / signed_with_floor(dmu_ls, MIN_DERIV);
1029
1030        if let (Some(eta_w), Some(z_wv), Some(w_wv)) = (eta_wiggle, z_w.as_mut(), w_w.as_mut()) {
1031            // Wiggle enters additively in q, so chain is 1.
1032            let dmu_w = core.dmu_dq[i];
1033            w_wv[i] = (weights[i] * (dmu_w * dmu_w / var)).max(MIN_WEIGHT);
1034            z_wv[i] = eta_w[i] + (y[i] - core.mu[i]) / signed_with_floor(dmu_w, MIN_DERIV);
1035        }
1036    }
1037
1038    let t_ws = BlockWorkingSet {
1039        working_response: z_t,
1040        working_weights: w_t,
1041        gradient_eta: None,
1042    };
1043    let ls_ws = BlockWorkingSet {
1044        working_response: z_ls,
1045        working_weights: w_ls,
1046        gradient_eta: None,
1047    };
1048    let w_ws = match (z_w, w_w) {
1049        (Some(z), Some(w)) => Some(BlockWorkingSet {
1050            working_response: z,
1051            working_weights: w,
1052            gradient_eta: None,
1053        }),
1054        _ => None,
1055    };
1056    (t_ws, ls_ws, w_ws)
1057}
1058
1059/// Built-in Gaussian location-scale family:
1060/// - Block 0: location μ(·) with identity link
1061/// - Block 1: log-scale log σ(·) with log link
1062#[derive(Clone)]
1063pub struct GaussianLocationScaleFamily {
1064    pub y: Array1<f64>,
1065    pub weights: Array1<f64>,
1066    pub sigma_min: f64,
1067    pub sigma_max: f64,
1068}
1069
1070impl GaussianLocationScaleFamily {
1071    pub const BLOCK_MU: usize = 0;
1072    pub const BLOCK_LOG_SIGMA: usize = 1;
1073
1074    pub fn parameter_names() -> &'static [&'static str] {
1075        &["mu", "log_sigma"]
1076    }
1077
1078    pub fn parameter_links() -> &'static [ParameterLink] {
1079        &[ParameterLink::Identity, ParameterLink::Log]
1080    }
1081
1082    pub fn metadata() -> FamilyMetadata {
1083        FamilyMetadata {
1084            name: "gaussian_location_scale",
1085            parameter_names: Self::parameter_names(),
1086            parameter_links: Self::parameter_links(),
1087        }
1088    }
1089}
1090
1091impl CustomFamily for GaussianLocationScaleFamily {
1092    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
1093        if block_states.len() != 2 {
1094            return Err(format!(
1095                "GaussianLocationScaleFamily expects 2 blocks, got {}",
1096                block_states.len()
1097            ));
1098        }
1099        let n = self.y.len();
1100        let eta_mu = &block_states[Self::BLOCK_MU].eta;
1101        let eta_log_sigma = &block_states[Self::BLOCK_LOG_SIGMA].eta;
1102        if eta_mu.len() != n || eta_log_sigma.len() != n || self.weights.len() != n {
1103            return Err("GaussianLocationScaleFamily input size mismatch".to_string());
1104        }
1105
1106        let mut sigma = Array1::<f64>::zeros(n);
1107        let mut dsigma_deta = Array1::<f64>::zeros(n);
1108        let mut ll = 0.0;
1109
1110        for i in 0..n {
1111            let raw = eta_log_sigma[i].exp();
1112            sigma[i] = raw.clamp(self.sigma_min, self.sigma_max);
1113            dsigma_deta[i] = if raw >= self.sigma_min && raw <= self.sigma_max {
1114                raw
1115            } else {
1116                0.0
1117            };
1118            let r = self.y[i] - eta_mu[i];
1119            let s2 = (sigma[i] * sigma[i]).max(1e-20);
1120            ll += self.weights[i] * (-0.5 * (r * r / s2 + (2.0 * std::f64::consts::PI * s2).ln()));
1121        }
1122
1123        let mut z_mu = Array1::<f64>::zeros(n);
1124        let mut w_mu = Array1::<f64>::zeros(n);
1125        let mut z_ls = Array1::<f64>::zeros(n);
1126        let mut w_ls = Array1::<f64>::zeros(n);
1127
1128        for i in 0..n {
1129            let r = self.y[i] - eta_mu[i];
1130            let s = sigma[i].max(1e-10);
1131            let s2 = (s * s).max(1e-20);
1132
1133            // mu block (identity): canonical WLS
1134            w_mu[i] = (self.weights[i] / s2).max(MIN_WEIGHT);
1135            z_mu[i] = eta_mu[i] + r;
1136
1137            // log-sigma block using score + expected information (Fisher)
1138            // score_u = ((y-mu)^2 / sigma^2 - 1) * d(log sigma)/du, with u = eta_log_sigma.
1139            // Here d(log sigma)/du = dsigma/(sigma du).
1140            let dlogsigma_du = if dsigma_deta[i] == 0.0 {
1141                0.0
1142            } else {
1143                (dsigma_deta[i] / s).clamp(-1.0, 1.0)
1144            };
1145            let score_u = self.weights[i] * ((r * r / s2) - 1.0) * dlogsigma_du;
1146            let info_u = (2.0 * self.weights[i] * dlogsigma_du * dlogsigma_du).max(MIN_WEIGHT);
1147            z_ls[i] = eta_log_sigma[i] + score_u / info_u;
1148            w_ls[i] = info_u;
1149        }
1150
1151        Ok(FamilyEvaluation {
1152            log_likelihood: ll,
1153            block_working_sets: vec![
1154                BlockWorkingSet {
1155                    working_response: z_mu,
1156                    working_weights: w_mu,
1157                    gradient_eta: None,
1158                },
1159                BlockWorkingSet {
1160                    working_response: z_ls,
1161                    working_weights: w_ls,
1162                    gradient_eta: None,
1163                },
1164            ],
1165        })
1166    }
1167}
1168
1169impl CustomFamilyGenerative for GaussianLocationScaleFamily {
1170    fn generative_spec(
1171        &self,
1172        block_states: &[ParameterBlockState],
1173    ) -> Result<GenerativeSpec, String> {
1174        if block_states.len() != 2 {
1175            return Err(format!(
1176                "GaussianLocationScaleFamily expects 2 blocks, got {}",
1177                block_states.len()
1178            ));
1179        }
1180        let mu = block_states[Self::BLOCK_MU].eta.clone();
1181        let sigma = block_states[Self::BLOCK_LOG_SIGMA]
1182            .eta
1183            .mapv(f64::exp)
1184            .mapv(|s| s.clamp(self.sigma_min, self.sigma_max));
1185        Ok(GenerativeSpec {
1186            mean: mu,
1187            noise: NoiseModel::Gaussian { sigma },
1188        })
1189    }
1190}
1191
1192/// Built-in binomial logit family (single parameter block).
1193#[derive(Clone)]
1194pub struct BinomialLogitFamily {
1195    pub y: Array1<f64>,
1196    pub weights: Array1<f64>,
1197}
1198
1199impl BinomialLogitFamily {
1200    pub const BLOCK_ETA: usize = 0;
1201
1202    pub fn parameter_names() -> &'static [&'static str] {
1203        &["eta"]
1204    }
1205
1206    pub fn parameter_links() -> &'static [ParameterLink] {
1207        &[ParameterLink::Logit]
1208    }
1209
1210    pub fn metadata() -> FamilyMetadata {
1211        FamilyMetadata {
1212            name: "binomial_logit",
1213            parameter_names: Self::parameter_names(),
1214            parameter_links: Self::parameter_links(),
1215        }
1216    }
1217}
1218
1219impl CustomFamily for BinomialLogitFamily {
1220    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
1221        if block_states.len() != 1 {
1222            return Err(format!(
1223                "BinomialLogitFamily expects 1 block, got {}",
1224                block_states.len()
1225            ));
1226        }
1227        let eta = &block_states[Self::BLOCK_ETA].eta;
1228        let n = self.y.len();
1229        if eta.len() != n || self.weights.len() != n {
1230            return Err("BinomialLogitFamily input size mismatch".to_string());
1231        }
1232        evaluate_single_block_glm(LikelihoodFamily::BinomialLogit, &self.y, &self.weights, eta)
1233    }
1234}
1235
1236impl CustomFamilyGenerative for BinomialLogitFamily {
1237    fn generative_spec(
1238        &self,
1239        block_states: &[ParameterBlockState],
1240    ) -> Result<GenerativeSpec, String> {
1241        if block_states.len() != 1 {
1242            return Err(format!(
1243                "BinomialLogitFamily expects 1 block, got {}",
1244                block_states.len()
1245            ));
1246        }
1247        let mean = block_states[Self::BLOCK_ETA].eta.mapv(|e| {
1248            (1.0 / (1.0 + (-e.clamp(-30.0, 30.0)).exp())).clamp(MIN_PROB, 1.0 - MIN_PROB)
1249        });
1250        Ok(GenerativeSpec {
1251            mean,
1252            noise: NoiseModel::Bernoulli,
1253        })
1254    }
1255}
1256
1257/// Built-in Poisson log-link family (single parameter block).
1258#[derive(Clone)]
1259pub struct PoissonLogFamily {
1260    pub y: Array1<f64>,
1261    pub weights: Array1<f64>,
1262}
1263
1264impl PoissonLogFamily {
1265    pub const BLOCK_ETA: usize = 0;
1266
1267    pub fn parameter_names() -> &'static [&'static str] {
1268        &["eta"]
1269    }
1270
1271    pub fn parameter_links() -> &'static [ParameterLink] {
1272        &[ParameterLink::Log]
1273    }
1274
1275    pub fn metadata() -> FamilyMetadata {
1276        FamilyMetadata {
1277            name: "poisson_log",
1278            parameter_names: Self::parameter_names(),
1279            parameter_links: Self::parameter_links(),
1280        }
1281    }
1282}
1283
1284impl CustomFamily for PoissonLogFamily {
1285    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
1286        if block_states.len() != 1 {
1287            return Err(format!(
1288                "PoissonLogFamily expects 1 block, got {}",
1289                block_states.len()
1290            ));
1291        }
1292        let eta = &block_states[Self::BLOCK_ETA].eta;
1293        let n = self.y.len();
1294        if eta.len() != n || self.weights.len() != n {
1295            return Err("PoissonLogFamily input size mismatch".to_string());
1296        }
1297
1298        let mut mu = Array1::<f64>::zeros(n);
1299        let mut ll = 0.0;
1300        let mut z = Array1::<f64>::zeros(n);
1301        let mut w = Array1::<f64>::zeros(n);
1302
1303        for i in 0..n {
1304            let yi = self.y[i];
1305            if !yi.is_finite() || yi < 0.0 {
1306                return Err(format!(
1307                    "PoissonLogFamily requires non-negative finite y; found y[{i}]={yi}"
1308                ));
1309            }
1310            let e = eta[i].clamp(-30.0, 30.0);
1311            let m = e.exp().max(1e-12);
1312            mu[i] = m;
1313            // Drop log(y!) constant in objective.
1314            ll += self.weights[i] * (yi * e - m);
1315            let dmu = m.max(MIN_DERIV);
1316            let var = m.max(MIN_PROB);
1317            w[i] = (self.weights[i] * (dmu * dmu / var)).max(MIN_WEIGHT);
1318            z[i] = e + (yi - m) / signed_with_floor(dmu, MIN_DERIV);
1319        }
1320
1321        Ok(FamilyEvaluation {
1322            log_likelihood: ll,
1323            block_working_sets: vec![BlockWorkingSet {
1324                working_response: z,
1325                working_weights: w,
1326                gradient_eta: None,
1327            }],
1328        })
1329    }
1330}
1331
1332impl CustomFamilyGenerative for PoissonLogFamily {
1333    fn generative_spec(
1334        &self,
1335        block_states: &[ParameterBlockState],
1336    ) -> Result<GenerativeSpec, String> {
1337        if block_states.len() != 1 {
1338            return Err(format!(
1339                "PoissonLogFamily expects 1 block, got {}",
1340                block_states.len()
1341            ));
1342        }
1343        let mean = block_states[Self::BLOCK_ETA]
1344            .eta
1345            .mapv(|e| e.clamp(-30.0, 30.0).exp().max(1e-12));
1346        Ok(GenerativeSpec {
1347            mean,
1348            noise: NoiseModel::Poisson,
1349        })
1350    }
1351}
1352
1353/// Built-in Gamma log-link family (single parameter block, fixed shape).
1354#[derive(Clone)]
1355pub struct GammaLogFamily {
1356    pub y: Array1<f64>,
1357    pub weights: Array1<f64>,
1358    pub shape: f64,
1359}
1360
1361impl GammaLogFamily {
1362    pub const BLOCK_ETA: usize = 0;
1363
1364    pub fn parameter_names() -> &'static [&'static str] {
1365        &["eta"]
1366    }
1367
1368    pub fn parameter_links() -> &'static [ParameterLink] {
1369        &[ParameterLink::Log]
1370    }
1371
1372    pub fn metadata() -> FamilyMetadata {
1373        FamilyMetadata {
1374            name: "gamma_log",
1375            parameter_names: Self::parameter_names(),
1376            parameter_links: Self::parameter_links(),
1377        }
1378    }
1379}
1380
1381impl CustomFamily for GammaLogFamily {
1382    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
1383        if block_states.len() != 1 {
1384            return Err(format!(
1385                "GammaLogFamily expects 1 block, got {}",
1386                block_states.len()
1387            ));
1388        }
1389        let eta = &block_states[Self::BLOCK_ETA].eta;
1390        let n = self.y.len();
1391        if eta.len() != n || self.weights.len() != n {
1392            return Err("GammaLogFamily input size mismatch".to_string());
1393        }
1394        if !self.shape.is_finite() || self.shape <= 0.0 {
1395            return Err("GammaLogFamily shape must be finite and > 0".to_string());
1396        }
1397
1398        let mut mu = Array1::<f64>::zeros(n);
1399        let mut ll = 0.0;
1400        let mut z = Array1::<f64>::zeros(n);
1401        let mut w = Array1::<f64>::zeros(n);
1402
1403        for i in 0..n {
1404            let yi = self.y[i];
1405            if !yi.is_finite() || yi <= 0.0 {
1406                return Err(format!(
1407                    "GammaLogFamily requires positive finite y; found y[{i}]={yi}"
1408                ));
1409            }
1410            let e = eta[i].clamp(-30.0, 30.0);
1411            let m = e.exp().max(1e-12);
1412            mu[i] = m;
1413            // Gamma(shape=k, scale=mu/k), dropping constants independent of eta.
1414            ll += self.weights[i] * (-self.shape * (yi / m + m.ln()));
1415            let dmu = m.max(MIN_DERIV);
1416            let var = (m * m / self.shape).max(MIN_PROB);
1417            w[i] = (self.weights[i] * (dmu * dmu / var)).max(MIN_WEIGHT);
1418            z[i] = e + (yi - m) / signed_with_floor(dmu, MIN_DERIV);
1419        }
1420
1421        Ok(FamilyEvaluation {
1422            log_likelihood: ll,
1423            block_working_sets: vec![BlockWorkingSet {
1424                working_response: z,
1425                working_weights: w,
1426                gradient_eta: None,
1427            }],
1428        })
1429    }
1430}
1431
1432impl CustomFamilyGenerative for GammaLogFamily {
1433    fn generative_spec(
1434        &self,
1435        block_states: &[ParameterBlockState],
1436    ) -> Result<GenerativeSpec, String> {
1437        if block_states.len() != 1 {
1438            return Err(format!(
1439                "GammaLogFamily expects 1 block, got {}",
1440                block_states.len()
1441            ));
1442        }
1443        let mean = block_states[Self::BLOCK_ETA]
1444            .eta
1445            .mapv(|e| e.clamp(-30.0, 30.0).exp().max(1e-12));
1446        Ok(GenerativeSpec {
1447            mean,
1448            noise: NoiseModel::Gamma { shape: self.shape },
1449        })
1450    }
1451}
1452
1453/// Built-in binomial location-scale probit family.
1454///
1455/// Parameters:
1456/// - Block 0: threshold/location T(covariates)
1457/// - Block 1: log-scale log σ(covariates)
1458/// - fixed score input enters as q = (score - T) / σ
1459#[derive(Clone)]
1460pub struct BinomialLocationScaleProbitFamily {
1461    pub y: Array1<f64>,
1462    pub score: Array1<f64>,
1463    pub weights: Array1<f64>,
1464    pub sigma_min: f64,
1465    pub sigma_max: f64,
1466}
1467
1468impl BinomialLocationScaleProbitFamily {
1469    pub const BLOCK_T: usize = 0;
1470    pub const BLOCK_LOG_SIGMA: usize = 1;
1471
1472    pub fn parameter_names() -> &'static [&'static str] {
1473        &["threshold", "log_sigma"]
1474    }
1475
1476    pub fn parameter_links() -> &'static [ParameterLink] {
1477        &[ParameterLink::Probit, ParameterLink::Log]
1478    }
1479
1480    pub fn metadata() -> FamilyMetadata {
1481        FamilyMetadata {
1482            name: "binomial_location_scale_probit",
1483            parameter_names: Self::parameter_names(),
1484            parameter_links: Self::parameter_links(),
1485        }
1486    }
1487}
1488
1489impl CustomFamily for BinomialLocationScaleProbitFamily {
1490    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
1491        if block_states.len() != 2 {
1492            return Err(format!(
1493                "BinomialLocationScaleProbitFamily expects 2 blocks, got {}",
1494                block_states.len()
1495            ));
1496        }
1497        let n = self.y.len();
1498        let eta_t = &block_states[Self::BLOCK_T].eta;
1499        let eta_ls = &block_states[Self::BLOCK_LOG_SIGMA].eta;
1500        if eta_t.len() != n || eta_ls.len() != n || self.weights.len() != n || self.score.len() != n
1501        {
1502            return Err("BinomialLocationScaleProbitFamily input size mismatch".to_string());
1503        }
1504
1505        let core = binomial_location_scale_core(
1506            &self.y,
1507            &self.score,
1508            &self.weights,
1509            eta_t,
1510            eta_ls,
1511            None,
1512            self.sigma_min,
1513            self.sigma_max,
1514        )?;
1515        let (t_ws, ls_ws, _none) = binomial_location_scale_working_sets(
1516            &self.y,
1517            &self.weights,
1518            eta_t,
1519            eta_ls,
1520            None,
1521            &core,
1522        );
1523
1524        Ok(FamilyEvaluation {
1525            log_likelihood: core.log_likelihood,
1526            block_working_sets: vec![t_ws, ls_ws],
1527        })
1528    }
1529}
1530
1531impl CustomFamilyGenerative for BinomialLocationScaleProbitFamily {
1532    fn generative_spec(
1533        &self,
1534        block_states: &[ParameterBlockState],
1535    ) -> Result<GenerativeSpec, String> {
1536        if block_states.len() != 2 {
1537            return Err(format!(
1538                "BinomialLocationScaleProbitFamily expects 2 blocks, got {}",
1539                block_states.len()
1540            ));
1541        }
1542        let eta_t = &block_states[Self::BLOCK_T].eta;
1543        let eta_ls = &block_states[Self::BLOCK_LOG_SIGMA].eta;
1544        if eta_t.len() != self.score.len() || eta_ls.len() != self.score.len() {
1545            return Err("BinomialLocationScaleProbitFamily generative size mismatch".to_string());
1546        }
1547        let mut mean = Array1::<f64>::zeros(self.score.len());
1548        for i in 0..mean.len() {
1549            let sigma = eta_ls[i]
1550                .exp()
1551                .clamp(self.sigma_min, self.sigma_max)
1552                .max(1e-12);
1553            let q = (self.score[i] - eta_t[i]) / sigma;
1554            mean[i] = normal_cdf_approx(q).clamp(MIN_PROB, 1.0 - MIN_PROB);
1555        }
1556        Ok(GenerativeSpec {
1557            mean,
1558            noise: NoiseModel::Bernoulli,
1559        })
1560    }
1561}
1562
1563/// Built-in binomial location-scale probit with learnable wiggle on q.
1564///
1565/// Block structure:
1566/// - Block 0: threshold T(covariates)
1567/// - Block 1: log sigma(covariates)
1568/// - Block 2: wiggle(q) represented by B-spline coefficients on q
1569#[derive(Clone)]
1570pub struct BinomialLocationScaleProbitWiggleFamily {
1571    pub y: Array1<f64>,
1572    pub score: Array1<f64>,
1573    pub weights: Array1<f64>,
1574    pub sigma_min: f64,
1575    pub sigma_max: f64,
1576    pub wiggle_knots: Array1<f64>,
1577    pub wiggle_degree: usize,
1578}
1579
1580impl BinomialLocationScaleProbitWiggleFamily {
1581    pub const BLOCK_T: usize = 0;
1582    pub const BLOCK_LOG_SIGMA: usize = 1;
1583    pub const BLOCK_WIGGLE: usize = 2;
1584
1585    pub fn parameter_names() -> &'static [&'static str] {
1586        &["threshold", "log_sigma", "wiggle"]
1587    }
1588
1589    pub fn parameter_links() -> &'static [ParameterLink] {
1590        &[
1591            ParameterLink::Probit,
1592            ParameterLink::Log,
1593            ParameterLink::Wiggle,
1594        ]
1595    }
1596
1597    pub fn metadata() -> FamilyMetadata {
1598        FamilyMetadata {
1599            name: "binomial_location_scale_probit_wiggle",
1600            parameter_names: Self::parameter_names(),
1601            parameter_links: Self::parameter_links(),
1602        }
1603    }
1604
1605    pub fn initialize_wiggle_knots_from_q(
1606        q_seed: ArrayView1<'_, f64>,
1607        degree: usize,
1608        num_internal_knots: usize,
1609    ) -> Result<Array1<f64>, String> {
1610        initialize_wiggle_knots_from_seed(q_seed, degree, num_internal_knots)
1611    }
1612
1613    fn wiggle_design(&self, q0: ArrayView1<'_, f64>) -> Result<Array2<f64>, String> {
1614        let (basis, _) = create_basis::<Dense>(
1615            q0,
1616            KnotSource::Provided(self.wiggle_knots.view()),
1617            self.wiggle_degree,
1618            BasisOptions::value(),
1619        )
1620        .map_err(|e| e.to_string())?;
1621        let full = (*basis).clone();
1622        if full.ncols() < 2 {
1623            return Err("wiggle basis has fewer than two columns".to_string());
1624        }
1625        Ok(full.slice(s![.., 1..]).to_owned())
1626    }
1627
1628    /// Build a turnkey wiggle block from a q-seed vector and knot settings.
1629    /// Returns both the block input and the generated knot vector.
1630    pub fn build_wiggle_block_input(
1631        q_seed: ArrayView1<'_, f64>,
1632        degree: usize,
1633        num_internal_knots: usize,
1634        penalty_order: usize,
1635        double_penalty: bool,
1636    ) -> Result<(ParameterBlockInput, Array1<f64>), String> {
1637        let knots = Self::initialize_wiggle_knots_from_q(q_seed, degree, num_internal_knots)?;
1638        let block = build_wiggle_block_input_from_knots(
1639            q_seed,
1640            &knots,
1641            degree,
1642            penalty_order,
1643            double_penalty,
1644        )?;
1645        Ok((block, knots))
1646    }
1647}
1648
1649impl CustomFamily for BinomialLocationScaleProbitWiggleFamily {
1650    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
1651        if block_states.len() != 3 {
1652            return Err(format!(
1653                "BinomialLocationScaleProbitWiggleFamily expects 3 blocks, got {}",
1654                block_states.len()
1655            ));
1656        }
1657        let n = self.y.len();
1658        let eta_t = &block_states[Self::BLOCK_T].eta;
1659        let eta_ls = &block_states[Self::BLOCK_LOG_SIGMA].eta;
1660        let eta_w = &block_states[Self::BLOCK_WIGGLE].eta;
1661        if eta_t.len() != n || eta_ls.len() != n || eta_w.len() != n || self.score.len() != n {
1662            return Err("BinomialLocationScaleProbitWiggleFamily input size mismatch".to_string());
1663        }
1664
1665        let core = binomial_location_scale_core(
1666            &self.y,
1667            &self.score,
1668            &self.weights,
1669            eta_t,
1670            eta_ls,
1671            Some(eta_w),
1672            self.sigma_min,
1673            self.sigma_max,
1674        )?;
1675        let (t_ws, ls_ws, w_ws) = binomial_location_scale_working_sets(
1676            &self.y,
1677            &self.weights,
1678            eta_t,
1679            eta_ls,
1680            Some(eta_w),
1681            &core,
1682        );
1683        let w_ws = w_ws.ok_or_else(|| "wiggle working set missing".to_string())?;
1684
1685        Ok(FamilyEvaluation {
1686            log_likelihood: core.log_likelihood,
1687            block_working_sets: vec![t_ws, ls_ws, w_ws],
1688        })
1689    }
1690
1691    fn known_link_wiggle(&self) -> Option<KnownLinkWiggle> {
1692        Some(KnownLinkWiggle {
1693            base_link: LinkFunction::Probit,
1694            wiggle_block: Some(Self::BLOCK_WIGGLE),
1695        })
1696    }
1697
1698    fn block_geometry(
1699        &self,
1700        block_index: usize,
1701        block_states: &[ParameterBlockState],
1702        spec: &crate::custom_family::ParameterBlockSpec,
1703    ) -> Result<(DesignMatrix, Array1<f64>), String> {
1704        if block_index != Self::BLOCK_WIGGLE {
1705            return Ok((spec.design.clone(), spec.offset.clone()));
1706        }
1707        if block_states.len() < 2 {
1708            return Err("wiggle geometry requires threshold and log-sigma blocks".to_string());
1709        }
1710        let eta_t = &block_states[Self::BLOCK_T].eta;
1711        let eta_ls = &block_states[Self::BLOCK_LOG_SIGMA].eta;
1712        if eta_t.len() != self.score.len() || eta_ls.len() != self.score.len() {
1713            return Err("wiggle geometry input size mismatch".to_string());
1714        }
1715        let mut q0 = Array1::<f64>::zeros(self.score.len());
1716        for i in 0..q0.len() {
1717            let sigma = eta_ls[i]
1718                .exp()
1719                .clamp(self.sigma_min, self.sigma_max)
1720                .max(1e-12);
1721            q0[i] = (self.score[i] - eta_t[i]) / sigma;
1722        }
1723        let x = self.wiggle_design(q0.view())?;
1724        if x.ncols() != spec.design.ncols() {
1725            return Err(format!(
1726                "dynamic wiggle design col mismatch: got {}, expected {}",
1727                x.ncols(),
1728                spec.design.ncols()
1729            ));
1730        }
1731        let nrows = x.nrows();
1732        Ok((DesignMatrix::Dense(x), Array1::zeros(nrows)))
1733    }
1734}
1735
1736impl CustomFamilyGenerative for BinomialLocationScaleProbitWiggleFamily {
1737    fn generative_spec(
1738        &self,
1739        block_states: &[ParameterBlockState],
1740    ) -> Result<GenerativeSpec, String> {
1741        if block_states.len() != 3 {
1742            return Err(format!(
1743                "BinomialLocationScaleProbitWiggleFamily expects 3 blocks, got {}",
1744                block_states.len()
1745            ));
1746        }
1747        let eta_t = &block_states[Self::BLOCK_T].eta;
1748        let eta_ls = &block_states[Self::BLOCK_LOG_SIGMA].eta;
1749        let eta_w = &block_states[Self::BLOCK_WIGGLE].eta;
1750        if eta_t.len() != self.score.len()
1751            || eta_ls.len() != self.score.len()
1752            || eta_w.len() != self.score.len()
1753        {
1754            return Err(
1755                "BinomialLocationScaleProbitWiggleFamily generative size mismatch".to_string(),
1756            );
1757        }
1758        let mut mean = Array1::<f64>::zeros(self.score.len());
1759        for i in 0..mean.len() {
1760            let sigma = eta_ls[i]
1761                .exp()
1762                .clamp(self.sigma_min, self.sigma_max)
1763                .max(1e-12);
1764            let q0 = (self.score[i] - eta_t[i]) / sigma;
1765            mean[i] = normal_cdf_approx(q0 + eta_w[i]).clamp(MIN_PROB, 1.0 - MIN_PROB);
1766        }
1767        Ok(GenerativeSpec {
1768            mean,
1769            noise: NoiseModel::Bernoulli,
1770        })
1771    }
1772}
1773
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a block input with a single all-ones design column, a zero offset,
    /// no penalties, and no initial values.
    fn intercept_block(n: usize) -> ParameterBlockInput {
        let ones = Array2::from_elem((n, 1), 1.0);
        ParameterBlockInput {
            design: DesignMatrix::Dense(ones),
            offset: Array1::zeros(n),
            penalties: vec![],
            initial_log_lambdas: None,
            initial_beta: None,
        }
    }

    /// With an all-ones design, zero offset, and unit weights, projecting a
    /// constant target of 0.2 must yield a single finite coefficient within
    /// 1e-6 of 0.2.
    #[test]
    fn weighted_projection_returns_finite_coefficients() {
        let n_obs = 8usize;
        let ones = DesignMatrix::Dense(Array2::from_elem((n_obs, 1), 1.0));
        let no_offset = Array1::zeros(n_obs);
        let constant_target = Array1::from_elem(n_obs, 0.2);
        let unit_weights = Array1::from_elem(n_obs, 1.0);

        let beta = solve_weighted_projection(
            &ones,
            &no_offset,
            &constant_target,
            &unit_weights,
            1e-10,
        )
        .unwrap();

        assert_eq!(beta.len(), 1);
        let intercept = beta[0];
        assert!(intercept.is_finite());
        assert!((intercept - 0.2).abs() < 1e-6);
    }

    /// The warm start on intercept-only blocks must return one finite coefficient
    /// per block and a third vector whose entries are all finite and strictly positive.
    #[test]
    fn alpha_beta_warm_start_produces_finite_targets() {
        let n_obs = 16usize;
        // Every third observation is a success.
        let y = Array1::from_shape_fn(n_obs, |i| if i % 3 == 0 { 1.0 } else { 0.0 });
        let score = Array1::from_shape_fn(n_obs, |i| i as f64 / n_obs as f64 - 0.5);
        let weights = Array1::from_elem(n_obs, 1.0);
        let threshold = intercept_block(n_obs);
        let log_sigma = intercept_block(n_obs);
        let options = BlockwiseFitOptions::default();

        let (beta_t, beta_ls, beta_obs) = try_binomial_alpha_beta_warm_start(
            &y, &score, &weights, 0.25, 4.0, &threshold, &log_sigma, &options,
        )
        .unwrap();

        assert_eq!(beta_t.len(), 1);
        assert_eq!(beta_ls.len(), 1);
        assert!(beta_t[0].is_finite());
        assert!(beta_ls[0].is_finite());
        assert!(beta_obs.iter().all(|v| v.is_finite() && *v > 0.0));
    }

    /// A full fit on intercept-only blocks must succeed, report two block states,
    /// and produce a finite log-likelihood.
    #[test]
    fn fit_binomial_location_scale_probit_runs_with_warm_start_path() {
        let n_obs = 32usize;
        // Every fourth observation is a success.
        let y = Array1::from_shape_fn(n_obs, |i| if i % 4 == 0 { 1.0 } else { 0.0 });
        let score = Array1::from_shape_fn(n_obs, |i| (i as f64 - 16.0) / 10.0);
        let weights = Array1::from_elem(n_obs, 1.0);

        let spec = BinomialLocationScaleProbitSpec {
            y,
            score,
            weights,
            sigma_min: 0.3,
            sigma_max: 3.0,
            threshold_block: intercept_block(n_obs),
            log_sigma_block: intercept_block(n_obs),
        };
        let options = BlockwiseFitOptions::default();

        let fit = fit_binomial_location_scale_probit(spec, &options)
            .expect("binomial location-scale probit should fit");

        assert_eq!(fit.block_states.len(), 2);
        assert!(fit.log_likelihood.is_finite());
    }
}