Skip to main content

gam_config/
lib.rs

1use gam_models::survival::location_scale::residual_distribution_inverse_link;
2use gam_models::survival::lognormal_kernel::{FrailtySpec, HazardLoading};
3use gam_models::survival::parse_survival_distribution;
4use gam_models::survival::{SurvivalLikelihoodMode, parse_survival_likelihood_mode};
5use gam_inference::formula_dsl::parse_link_choice;
6use gam_inference::model::GroupMetadata;
7use gam_solve::mixture_link::{state_from_beta_logisticspec, state_from_sasspec, state_fromspec};
8use gam_models::fit_orchestration::descriptors::build_analytic_penalty_registry_from_descriptors;
9use gam_models::fit_orchestration::{CtnStage1Recipe, FitConfig};
10use gam_models::transformation_normal::TransformationNormalConfig;
11use gam_problem::types::{
12    InverseLink, LinkComponent, LinkFunction, MixtureLinkSpec, SasLinkSpec, StandardLink,
13};
14use ndarray::Array1;
15use serde::Deserialize;
16use serde_json::Value as JsonValue;
17use std::collections::BTreeMap;
18
/// Frailty family selected on the command line via `--frailty-kind`.
///
/// Turned into a `FrailtySpec` by `resolve_cli_frailty_spec`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CliFrailtyKind {
    /// `gaussian-shift`; does not accept `--hazard-loading`.
    GaussianShift,
    /// `hazard-multiplier`; requires `--hazard-loading`.
    HazardMultiplier,
}
24
/// Hazard-loading choice selected on the command line via `--hazard-loading`.
///
/// Mapped one-to-one onto the model-side `HazardLoading` by `cli_hazard_loading`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CliHazardLoading {
    /// Maps to `HazardLoading::Full`.
    Full,
    /// Maps to `HazardLoading::LoadedVsUnloaded`.
    LoadedVsUnloaded,
}
30
/// Raw JSON shape of the fit configuration accepted by `parse_fit_config_json`.
///
/// Unknown keys are rejected (`deny_unknown_fields`). Every key is optional;
/// `resolve_json_fit_config` validates/normalizes each one and copies it into
/// the corresponding `FitConfig` field.
#[derive(Default, Deserialize)]
#[serde(deny_unknown_fields)]
struct JsonFitConfig {
    /// Model family; `"auto"` (any ASCII case) is mapped to "unset" by
    /// `normalize_optional_family`.
    family: Option<String>,
    /// Offset column name (stored as `FitConfig::offset_column`).
    offset: Option<String>,
    /// Weight column name (stored as `FitConfig::weight_column`).
    weights: Option<String>,
    ridge_lambda: Option<f64>,
    transformation_normal: Option<bool>,
    /// Survival likelihood mode; trimmed, lowercased and validated by
    /// `parse_survival_likelihood_cli`.
    survival_likelihood: Option<String>,
    /// Baseline target; must be non-empty after trimming.
    baseline_target: Option<String>,
    baseline_scale: Option<f64>,
    baseline_shape: Option<f64>,
    baseline_rate: Option<f64>,
    baseline_makeham: Option<f64>,
    /// Must be non-empty after trimming; stored trimmed.
    z_column: Option<String>,
    logslope_formula: Option<String>,
    /// Stage-1 recipe section; see `JsonCtnStage1`.
    ctn_stage1: Option<JsonCtnStage1>,
    /// Link name; an empty or whitespace-only string means "no explicit link".
    link: Option<String>,
    flexible_link: Option<bool>,
    scale_dimensions: Option<bool>,
    adaptive_regularization: Option<bool>,
    noise_formula: Option<String>,
    /// Noise offset column (stored as `FitConfig::noise_offset_column`).
    noise_offset: Option<String>,
    firth: Option<bool>,
    /// Must be >= 1 when present.
    outer_max_iter: Option<usize>,
    /// GPU policy string, parsed by `parse_gpu_policy`.
    gpu: Option<String>,
    /// Mutually exclusive with `groups`.
    group_metadata: Option<GroupMetadata>,
    /// Alternative group-metadata input; mutually exclusive with `group_metadata`.
    groups: Option<JsonValue>,
    /// Mutually exclusive with `penalty_block_gamma_priors`; both populate
    /// `FitConfig::penalty_block_gamma_priors`.
    precision_hyperpriors: Option<JsonValue>,
    penalty_block_gamma_priors: Option<JsonValue>,
    /// Latent descriptors; validated together with `penalties` by
    /// `build_analytic_penalty_registry_from_descriptors`.
    latents: Option<JsonValue>,
    /// Analytic penalty descriptors (stored as `FitConfig::analytic_penalties`).
    penalties: Option<JsonValue>,
    /// Per-smooth overrides (stored as `FitConfig::smooth_overrides`).
    smooths: Option<JsonValue>,
    /// Parsed by `TopologyAutoSelector::from_json`.
    topology_auto_selector: Option<JsonValue>,
    /// `none`, `hazard-multiplier` or `gaussian-shift` (case-insensitive).
    frailty_kind: Option<String>,
    /// Requires `frailty_kind`; must be finite and >= 0.
    frailty_sd: Option<f64>,
    /// `full` or `loaded-vs-unloaded`; rejected with `gaussian-shift`/`none`.
    hazard_loading: Option<String>,
    /// Returned alongside the config in `ResolvedFitConfig` without interpretation.
    training_table_kind: Option<String>,
}
70
/// JSON form of the `ctn_stage1` section of the fit config.
///
/// Converted into a `CtnStage1Recipe` by `JsonCtnStage1::into_recipe`.
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct JsonCtnStage1 {
    /// Response column name, forwarded to `CtnStage1Recipe::new`.
    response_column: String,
    /// Covariate formula right-hand side, forwarded to `CtnStage1Recipe::new`.
    covariate_formula_rhs: String,
    /// Optional overrides layered over `TransformationNormalConfig::default()`.
    #[serde(default)]
    config: Option<JsonCtnStage1Config>,
    /// Optional weight column, forwarded to `CtnStage1Recipe::new`.
    #[serde(default)]
    weight_column: Option<String>,
    /// Optional offset column, forwarded to `CtnStage1Recipe::new`.
    #[serde(default)]
    offset_column: Option<String>,
}
83
/// Optional overrides for the stage-1 `TransformationNormalConfig`.
///
/// Each present field replaces the matching field of
/// `TransformationNormalConfig::default()`; absent fields keep the default
/// (see `JsonCtnStage1::into_recipe`).
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct JsonCtnStage1Config {
    #[serde(default)]
    response_degree: Option<usize>,
    #[serde(default)]
    response_num_internal_knots: Option<usize>,
    #[serde(default)]
    response_penalty_order: Option<usize>,
    /// Replaces (does not extend) the default list when present.
    #[serde(default)]
    response_extra_penalty_orders: Option<Vec<usize>>,
    #[serde(default)]
    double_penalty: Option<bool>,
}
98
99impl JsonCtnStage1 {
100    fn into_recipe(self) -> Result<CtnStage1Recipe, String> {
101        let mut config = TransformationNormalConfig::default();
102        if let Some(overrides) = self.config {
103            if let Some(value) = overrides.response_degree {
104                config.response_degree = value;
105            }
106            if let Some(value) = overrides.response_num_internal_knots {
107                config.response_num_internal_knots = value;
108            }
109            if let Some(value) = overrides.response_penalty_order {
110                config.response_penalty_order = value;
111            }
112            if let Some(value) = overrides.response_extra_penalty_orders {
113                config.response_extra_penalty_orders = value;
114            }
115            if let Some(value) = overrides.double_penalty {
116                config.double_penalty = value;
117            }
118        }
119        CtnStage1Recipe::new(
120            &self.response_column,
121            &self.covariate_formula_rhs,
122            config,
123            self.weight_column.as_deref(),
124            self.offset_column.as_deref(),
125        )
126    }
127}
128
/// Output of `parse_fit_config_json`.
pub struct ResolvedFitConfig {
    /// Fully assembled and validated fit configuration.
    pub fit_config: FitConfig,
    /// The JSON `training_table_kind` value, carried alongside the config
    /// untouched (`resolve_json_fit_config` does not interpret it).
    pub training_table_kind: Option<String>,
}
133
/// Fit-related command-line arguments, collected by the CLI front end and
/// turned into a `FitConfig` by `resolve_cli_fit_config`.
pub struct CliFitConfigInput {
    /// Family name for `FitConfig::family`.
    pub family: Option<String>,
    /// Expectile asymmetry `τ ∈ (0, 1)` for `--family expectile`. The CLI family
    /// flag is a fixed value-enum that cannot carry the inline `expectile(τ)`
    /// asymmetry the string form accepts, so the CLI pins `τ` through this
    /// separate `--expectile-tau` field; it rides into `FitConfig::expectile_tau`
    /// and is read by the shared expectile dispatch seam (#1777).
    pub expectile_tau: Option<f64>,
    pub negative_binomial_theta: Option<f64>,
    /// Link name for `FitConfig::link`.
    pub link: Option<String>,
    pub flexible_link: bool,
    pub offset_column: Option<String>,
    pub weight_column: Option<String>,
    pub noise_offset_column: Option<String>,
    pub baseline_target: String,
    pub baseline_scale: Option<f64>,
    pub baseline_shape: Option<f64>,
    pub baseline_rate: Option<f64>,
    pub baseline_makeham: Option<f64>,
    pub time_basis: String,
    pub time_degree: usize,
    pub time_num_internal_knots: usize,
    pub time_smooth_lambda: f64,
    /// Survival likelihood mode; normalized and validated by
    /// `parse_survival_likelihood_cli`.
    pub survival_likelihood: String,
    pub survival_distribution: String,
    pub threshold_time_k: Option<usize>,
    pub threshold_time_degree: usize,
    pub sigma_time_k: Option<usize>,
    pub sigma_time_degree: usize,
    pub noise_formula: Option<String>,
    pub logslope_formula: Option<String>,
    pub z_column: Option<String>,
    pub scale_dimensions: bool,
    pub adaptive_regularization: Option<bool>,
    pub ridge_lambda: f64,
    pub transformation_normal: bool,
    pub firth: bool,
    /// Optional outer-iteration cap for `FitConfig::outer_max_iter`.
    pub outer_max_iter: Option<usize>,
    /// GPU policy string, parsed by `parse_gpu_policy` when present.
    pub gpu: Option<String>,
    /// `--frailty-kind`; combined with `frailty_sd` and `hazard_loading` by
    /// `resolve_cli_frailty_spec`.
    pub frailty_kind: Option<CliFrailtyKind>,
    /// `--frailty-sd`; must be finite and >= 0, and requires `frailty_kind`.
    pub frailty_sd: Option<f64>,
    /// `--hazard-loading`; required with, and only valid for, `hazard-multiplier`.
    pub hazard_loading: Option<CliHazardLoading>,
}
177
/// Raw, borrowed CLI inputs consumed by `parse_survival_inverse_link`.
pub struct SurvivalInverseLinkInput<'a> {
    /// `--link` value; `None` falls back to the residual-distribution inverse link.
    pub link: Option<&'a str>,
    /// `--mixture-rho`, comma-separated; only valid with a blended/mixture link.
    pub mixture_rho: Option<&'a str>,
    /// `--sas-init` as `epsilon,log_delta`; only valid with `--link sas`.
    pub sas_init: Option<&'a str>,
    /// `--beta-logistic-init` as `epsilon,delta`; only valid with `--link beta-logistic`.
    pub beta_logistic_init: Option<&'a str>,
    /// Residual distribution name; consulted only when `link` is `None`.
    pub survival_distribution: &'a str,
}
185
186pub fn parse_fit_config_json(config_json: Option<&str>) -> Result<ResolvedFitConfig, String> {
187    let json_config = match config_json {
188        Some(raw) if !raw.trim().is_empty() => serde_json::from_str::<JsonFitConfig>(raw)
189            .map_err(|err| format!("invalid fit config json: {err}"))?,
190        _ => JsonFitConfig::default(),
191    };
192    resolve_json_fit_config(json_config)
193}
194
195fn resolve_json_fit_config(json_config: JsonFitConfig) -> Result<ResolvedFitConfig, String> {
196    let training_table_kind = json_config.training_table_kind;
197    let mut fit_config = FitConfig::default();
198    fit_config.group_metadata =
199        parse_group_metadata(json_config.group_metadata, json_config.groups)?;
200    fit_config.penalty_block_gamma_priors = parse_precision_hyperpriors(
201        json_config.precision_hyperpriors,
202        json_config.penalty_block_gamma_priors,
203    )?;
204    let analytic_penalties = json_config.penalties;
205    build_analytic_penalty_registry_from_descriptors(
206        json_config.latents.as_ref(),
207        analytic_penalties.as_ref(),
208    )?;
209    fit_config.latents = json_config.latents;
210    fit_config.analytic_penalties = analytic_penalties;
211    fit_config.smooth_overrides = json_config.smooths;
212    fit_config.topology_auto_selector = json_config
213        .topology_auto_selector
214        .as_ref()
215        .map(gam_solve::topology_selector::TopologyAutoSelector::from_json)
216        .transpose()?;
217    fit_config.family = normalize_optional_family(json_config.family);
218    fit_config.offset_column = json_config.offset;
219    fit_config.weight_column = json_config.weights;
220    if let Some(ridge_lambda) = json_config.ridge_lambda {
221        fit_config.ridge_lambda = ridge_lambda;
222    }
223    if let Some(flag) = json_config.transformation_normal {
224        fit_config.transformation_normal = flag;
225    }
226    if let Some(mode) = json_config.survival_likelihood {
227        fit_config.survival_likelihood = parse_survival_likelihood_cli(&mode)?;
228    }
229    if let Some(target) = json_config.baseline_target {
230        fit_config.baseline_target =
231            resolve_nonempty_string(target, "baseline_target must be a non-empty string")?;
232    }
233    if let Some(value) = json_config.baseline_scale {
234        fit_config.baseline_scale = Some(value);
235    }
236    if let Some(value) = json_config.baseline_shape {
237        fit_config.baseline_shape = Some(value);
238    }
239    if let Some(value) = json_config.baseline_rate {
240        fit_config.baseline_rate = Some(value);
241    }
242    if let Some(value) = json_config.baseline_makeham {
243        fit_config.baseline_makeham = Some(value);
244    }
245    if let Some(z) = json_config.z_column {
246        fit_config.z_column = Some(resolve_nonempty_string(
247            z,
248            "z_column must be a non-empty string",
249        )?);
250    }
251    if let Some(formula) = json_config.logslope_formula {
252        fit_config.logslope_formula = Some(formula);
253    }
254    if let Some(stage1) = json_config.ctn_stage1 {
255        fit_config.ctn_stage1 = Some(stage1.into_recipe()?);
256    }
257    if let Some(link) = json_config.link {
258        let trimmed = link.trim();
259        fit_config.link = if trimmed.is_empty() {
260            None
261        } else {
262            Some(trimmed.to_string())
263        };
264    }
265    if let Some(flag) = json_config.flexible_link {
266        fit_config.flexible_link = flag;
267    }
268    if let Some(flag) = json_config.scale_dimensions {
269        fit_config.scale_dimensions = flag;
270    }
271    if let Some(flag) = json_config.adaptive_regularization {
272        fit_config.adaptive_regularization = Some(flag);
273    }
274    if let Some(formula) = json_config.noise_formula {
275        fit_config.noise_formula = Some(formula);
276    }
277    if let Some(column) = json_config.noise_offset {
278        fit_config.noise_offset_column = Some(column);
279    }
280    if let Some(flag) = json_config.firth {
281        fit_config.firth = flag;
282    }
283    if let Some(value) = json_config.outer_max_iter {
284        if value == 0 {
285            return Err("outer_max_iter must be >= 1".to_string());
286        }
287        fit_config.outer_max_iter = Some(value);
288    }
289    if let Some(raw_gpu) = json_config.gpu {
290        fit_config.gpu_policy = parse_gpu_policy(&raw_gpu)?;
291    }
292    fit_config.frailty = parse_json_frailty_spec(
293        json_config.frailty_kind,
294        json_config.frailty_sd,
295        json_config.hazard_loading,
296    )?;
297    validate_resolved_fit_config(&fit_config)?;
298    Ok(ResolvedFitConfig {
299        fit_config,
300        training_table_kind,
301    })
302}
303
304pub fn resolve_cli_fit_config(input: CliFitConfigInput) -> Result<FitConfig, String> {
305    let mut fit_config = FitConfig::default();
306    fit_config.family = input.family;
307    fit_config.expectile_tau = input.expectile_tau;
308    fit_config.negative_binomial_theta = input.negative_binomial_theta;
309    fit_config.link = input.link;
310    fit_config.flexible_link = input.flexible_link;
311    fit_config.offset_column = input.offset_column;
312    fit_config.weight_column = input.weight_column;
313    fit_config.noise_offset_column = input.noise_offset_column;
314    fit_config.baseline_target = input.baseline_target;
315    fit_config.baseline_scale = input.baseline_scale;
316    fit_config.baseline_shape = input.baseline_shape;
317    fit_config.baseline_rate = input.baseline_rate;
318    fit_config.baseline_makeham = input.baseline_makeham;
319    fit_config.time_basis = input.time_basis;
320    fit_config.time_degree = input.time_degree;
321    fit_config.time_num_internal_knots = input.time_num_internal_knots;
322    fit_config.time_smooth_lambda = input.time_smooth_lambda;
323    fit_config.survival_likelihood = parse_survival_likelihood_cli(&input.survival_likelihood)?;
324    fit_config.survival_distribution = input.survival_distribution;
325    fit_config.threshold_time_k = input.threshold_time_k;
326    fit_config.threshold_time_degree = input.threshold_time_degree;
327    fit_config.sigma_time_k = input.sigma_time_k;
328    fit_config.sigma_time_degree = input.sigma_time_degree;
329    fit_config.noise_formula = input.noise_formula;
330    fit_config.logslope_formula = input.logslope_formula;
331    fit_config.z_column = input.z_column;
332    fit_config.scale_dimensions = input.scale_dimensions;
333    fit_config.adaptive_regularization = input.adaptive_regularization;
334    fit_config.ridge_lambda = input.ridge_lambda;
335    fit_config.transformation_normal = input.transformation_normal;
336    fit_config.firth = input.firth;
337    fit_config.outer_max_iter = input.outer_max_iter;
338    if let Some(raw_gpu) = input.gpu {
339        fit_config.gpu_policy = parse_gpu_policy(&raw_gpu)?;
340    }
341    fit_config.frailty = Some(resolve_cli_frailty_spec(
342        input.frailty_kind,
343        input.frailty_sd,
344        input.hazard_loading,
345        "fit",
346    )?);
347    validate_resolved_fit_config(&fit_config)?;
348    Ok(fit_config)
349}
350
351pub fn resolve_cli_frailty_spec(
352    frailty_kind: Option<CliFrailtyKind>,
353    frailty_sd: Option<f64>,
354    hazard_loading: Option<CliHazardLoading>,
355    context: &str,
356) -> Result<FrailtySpec, String> {
357    let validate_sigma = || -> Result<Option<f64>, String> {
358        match frailty_sd {
359            None => Ok(None),
360            Some(sigma) => {
361                if !sigma.is_finite() || sigma < 0.0 {
362                    return Err(format!(
363                        "{context} requires a finite --frailty-sd >= 0, got {sigma}"
364                    ));
365                }
366                Ok(Some(sigma))
367            }
368        }
369    };
370
371    match frailty_kind {
372        None => {
373            if frailty_sd.is_some() || hazard_loading.is_some() {
374                return Err(format!(
375                    "{context} requires --frailty-kind when --frailty-sd or --hazard-loading is provided"
376                ));
377            }
378            Ok(FrailtySpec::None)
379        }
380        Some(CliFrailtyKind::GaussianShift) => {
381            if hazard_loading.is_some() {
382                return Err(format!(
383                    "{context} does not accept --hazard-loading with --frailty-kind gaussian-shift"
384                ));
385            }
386            Ok(FrailtySpec::GaussianShift {
387                sigma_fixed: validate_sigma()?,
388            })
389        }
390        Some(CliFrailtyKind::HazardMultiplier) => Ok(FrailtySpec::HazardMultiplier {
391            sigma_fixed: validate_sigma()?,
392            loading: hazard_loading.map(cli_hazard_loading).ok_or_else(|| {
393                format!("{context} requires --hazard-loading with --frailty-kind hazard-multiplier")
394            })?,
395        }),
396    }
397}
398
399pub fn parse_survival_likelihood_cli(raw: &str) -> Result<String, String> {
400    let normalized = raw.trim().to_ascii_lowercase();
401    parse_survival_likelihood_mode(&normalized)?;
402    Ok(normalized)
403}
404
/// Canonicalizes a `--baseline-target` value (trimmed, ASCII-lowercased) and
/// checks it against the supported targets.
///
/// # Errors
///
/// Returns an "unsupported --baseline-target" message listing the valid choices.
pub fn parse_baseline_target_cli(raw: &str) -> Result<String, String> {
    const SUPPORTED: [&str; 4] = ["linear", "weibull", "gompertz", "gompertz-makeham"];
    let target = raw.trim().to_ascii_lowercase();
    if SUPPORTED.contains(&target.as_str()) {
        Ok(target)
    } else {
        Err(format!(
            "unsupported --baseline-target '{target}'; use linear|weibull|gompertz|gompertz-makeham"
        ))
    }
}
414
415pub fn validate_survival_baseline_args(
416    likelihood_mode: SurvivalLikelihoodMode,
417    baseline_target: &str,
418    baseline_scale: Option<f64>,
419    baseline_shape: Option<f64>,
420    baseline_rate: Option<f64>,
421    baseline_makeham: Option<f64>,
422) -> Result<(), String> {
423    if likelihood_mode == SurvivalLikelihoodMode::Weibull {
424        if baseline_rate.is_some() || baseline_makeham.is_some() {
425            return Err(
426                "--survival-likelihood weibull does not use --baseline-rate or --baseline-makeham"
427                    .to_string(),
428            );
429        }
430        if !matches!(baseline_target, "linear" | "weibull") {
431            return Err(
432                "--survival-likelihood weibull supports only --baseline-target linear|weibull"
433                    .to_string(),
434            );
435        }
436        return Ok(());
437    }
438
439    match baseline_target {
440        "linear" => {
441            if baseline_scale.is_some()
442                || baseline_shape.is_some()
443                || baseline_rate.is_some()
444                || baseline_makeham.is_some()
445            {
446                return Err(
447                    "--baseline-target linear does not use baseline parameter flags".to_string(),
448                );
449            }
450        }
451        "weibull" => {
452            if baseline_rate.is_some() || baseline_makeham.is_some() {
453                return Err(
454                    "--baseline-target weibull does not use --baseline-rate or --baseline-makeham"
455                        .to_string(),
456                );
457            }
458        }
459        "gompertz" => {
460            if baseline_scale.is_some() || baseline_makeham.is_some() {
461                return Err(
462                    "--baseline-target gompertz does not use --baseline-scale or --baseline-makeham"
463                        .to_string(),
464                );
465            }
466        }
467        "gompertz-makeham" => {
468            if baseline_scale.is_some() {
469                return Err(
470                    "--baseline-target gompertz-makeham does not use --baseline-scale".to_string(),
471                );
472            }
473        }
474        other => {
475            return Err(format!(
476                "unsupported --baseline-target '{other}'; use linear|weibull|gompertz|gompertz-makeham"
477            ));
478        }
479    }
480    Ok(())
481}
482
/// Parses a comma-separated list of finite `f64` values.
///
/// Whitespace around each entry is ignored, and empty entries (e.g. from
/// `"1,,2"` or a trailing comma) are skipped, so `""` yields an empty vector.
/// `label` names the flag in error messages.
///
/// # Errors
///
/// Returns an error for the first entry that is not a number or is non-finite
/// (`inf`, `NaN`).
pub fn parse_comma_f64(v: &str, label: &str) -> Result<Vec<f64>, String> {
    v.split(',')
        .map(str::trim)
        .filter(|token| !token.is_empty())
        .map(|token| -> Result<f64, String> {
            let value = token
                .parse::<f64>()
                .map_err(|err| format!("{label} contains non-numeric value '{token}': {err}"))?;
            if value.is_finite() {
                Ok(value)
            } else {
                Err(format!("{label} contains non-finite value '{token}'"))
            }
        })
        .collect()
}
500
501pub fn effective_link_to_standard(
502    link: LinkFunction,
503    context: &str,
504) -> Result<StandardLink, String> {
505    StandardLink::try_from(link).map_err(|_| {
506        format!(
507            "{context}: state-bearing link `{}` must be routed through `InverseLink::Sas` / `InverseLink::BetaLogistic`, not `Standard(_)`",
508            link.name()
509        )
510    })
511}
512
/// Resolves the survival `InverseLink` from `--link` and its companion
/// initialisation flags.
///
/// * `loglog` / `cauchit` (matched after trimming and ASCII-lowercasing): a
///   single-component `InverseLink::Mixture`.
/// * Any other `--link` is parsed by `parse_link_choice`:
///   - blended links (`choice.mixture_components` is set): `InverseLink::Mixture`
///     whose initial mixing logits come from `--mixture-rho` (one fewer value
///     than there are components) or default to zeros;
///   - `LinkFunction::Sas`: `InverseLink::Sas`, seeded from
///     `--sas-init epsilon,log_delta` (default `0,0`);
///   - `LinkFunction::BetaLogistic`: `InverseLink::BetaLogistic`, seeded from
///     `--beta-logistic-init epsilon,delta` (default `0,0`);
///   - `LinkFunction::Log`: rejected;
///   - any other link: `InverseLink::Standard` via `effective_link_to_standard`.
/// * No `--link`: the residual-distribution inverse link for
///   `survival_distribution`.
///
/// # Errors
///
/// Returns a message when:
/// * the link is unknown or unsupported for survival;
/// * an init flag is given for a link that does not use it;
/// * an init flag has the wrong number of values or a non-numeric/non-finite
///   entry;
/// * the link-state constructor rejects the requested initial state;
/// * the survival distribution cannot be parsed.
pub fn parse_survival_inverse_link(
    input: SurvivalInverseLinkInput<'_>,
) -> Result<InverseLink, String> {
    if let Some(raw) = input.link {
        let name = raw.trim().to_ascii_lowercase();
        if name == "loglog" || name == "cauchit" {
            // `loglog` and `cauchit` have no scalar `LinkFunction`/`StandardLink`
            // representative, but the blended-link kernels implement their inverse
            // link and derivative jets exactly (`LinkComponent::LogLog` /
            // `LinkComponent::Cauchit`). Represent a survival `--link loglog` /
            // `--link cauchit` as a single-component mixture: it carries weight 1.0
            // with no free mixing logits, so it evaluates as exactly that link and
            // flows end-to-end through the fully-wired `InverseLink::Mixture` survival
            // path (prepare/construct/row-kernel/predict).
            if input.sas_init.is_some() {
                return Err("--sas-init requires --link sas".to_string());
            }
            if input.beta_logistic_init.is_some() {
                return Err("--beta-logistic-init requires --link beta-logistic".to_string());
            }
            if input.mixture_rho.is_some() {
                return Err(
                    "--mixture-rho requires survival --link blended(...)/mixture(...)".to_string(),
                );
            }
            let component = if name == "loglog" {
                LinkComponent::LogLog
            } else {
                LinkComponent::Cauchit
            };
            return state_fromspec(&MixtureLinkSpec {
                components: vec![component],
                initial_rho: Array1::zeros(0),
            })
            .map(InverseLink::Mixture)
            .map_err(|e| format!("invalid survival {name} link state: {e}"));
        }
    }
    // Everything else goes through the shared link-choice parser; its generic
    // "unsupported link" errors are rewritten to list the survival choices.
    let choice = parse_link_choice(input.link, false).map_err(|err| {
        let err = err.to_string();
        if let Some(raw) = input.link {
            let name = raw.trim().to_ascii_lowercase();
            if err.starts_with("unsupported --link ") || err.starts_with("unsupported link type ") {
                return format!(
                    "unsupported survival --link '{name}'; {}",
                    survival_link_usage()
                );
            }
        }
        err
    })?;
    if let Some(choice) = choice {
        if let Some(components) = choice.mixture_components {
            if input.sas_init.is_some() || input.beta_logistic_init.is_some() {
                return Err(
                    "survival blended(...) link does not accept --sas-init/--beta-logistic-init"
                        .to_string(),
                );
            }
            // One initial mixing logit per component beyond the first.
            let expected = components.len().saturating_sub(1);
            let initial_rho = if let Some(raw) = input.mixture_rho {
                let vals = parse_comma_f64(raw, "--mixture-rho")?;
                if vals.len() != expected {
                    return Err(format!(
                        "--mixture-rho expects {expected} values for blended({})",
                        components
                            .iter()
                            .map(|component| component.name())
                            .collect::<Vec<_>>()
                            .join(",")
                    ));
                }
                Array1::from_vec(vals)
            } else {
                Array1::zeros(expected)
            };
            return state_fromspec(&MixtureLinkSpec {
                components,
                initial_rho,
            })
            .map(InverseLink::Mixture)
            .map_err(|e| format!("invalid survival blended link state: {e}"));
        }

        // Non-mixture links take no mixing logits.
        if input.mixture_rho.is_some() {
            return Err(
                "--mixture-rho requires survival --link blended(...)/mixture(...)".to_string(),
            );
        }
        match choice.link {
            LinkFunction::Sas => {
                if input.beta_logistic_init.is_some() {
                    return Err("--beta-logistic-init requires --link beta-logistic".to_string());
                }
                let (epsilon, log_delta) = if let Some(raw) = input.sas_init {
                    let vals = parse_comma_f64(raw, "--sas-init")?;
                    if vals.len() != 2 {
                        return Err(format!(
                            "--sas-init expects two values: epsilon,log_delta (got {})",
                            vals.len()
                        ));
                    }
                    (vals[0], vals[1])
                } else {
                    (0.0, 0.0)
                };
                state_from_sasspec(SasLinkSpec {
                    initial_epsilon: epsilon,
                    initial_log_delta: log_delta,
                })
                .map(InverseLink::Sas)
                .map_err(|e| format!("invalid survival SAS link state: {e}"))
            }
            LinkFunction::BetaLogistic => {
                if input.sas_init.is_some() {
                    return Err("--sas-init requires --link sas".to_string());
                }
                let (epsilon, delta) = if let Some(raw) = input.beta_logistic_init {
                    let vals = parse_comma_f64(raw, "--beta-logistic-init")?;
                    if vals.len() != 2 {
                        return Err(format!(
                            "--beta-logistic-init expects two values: epsilon,delta (got {})",
                            vals.len()
                        ));
                    }
                    (vals[0], vals[1])
                } else {
                    (0.0, 0.0)
                };
                // NOTE(review): the second value is documented to users as `delta`
                // but is stored in `initial_log_delta`; confirm that
                // `state_from_beta_logisticspec` expects that slot to hold delta.
                state_from_beta_logisticspec(SasLinkSpec {
                    initial_epsilon: epsilon,
                    initial_log_delta: delta,
                })
                .map(InverseLink::BetaLogistic)
                .map_err(|e| format!("invalid survival Beta-Logistic link state: {e}"))
            }
            LinkFunction::Log => Err(format!(
                "unsupported survival --link 'log'; {}",
                survival_link_usage()
            )),
            other => {
                if input.sas_init.is_some() {
                    return Err("--sas-init requires --link sas".to_string());
                }
                if input.beta_logistic_init.is_some() {
                    return Err("--beta-logistic-init requires --link beta-logistic".to_string());
                }
                Ok(InverseLink::Standard(effective_link_to_standard(
                    other,
                    "survival inverse link",
                )?))
            }
        }
    } else {
        // No explicit link: every init flag is meaningless here, so reject them
        // and derive the inverse link from the residual distribution instead.
        if input.mixture_rho.is_some() {
            return Err("--mixture-rho requires --link blended(...)/mixture(...)".to_string());
        }
        if input.sas_init.is_some() {
            return Err("--sas-init requires --link sas".to_string());
        }
        if input.beta_logistic_init.is_some() {
            return Err("--beta-logistic-init requires --link beta-logistic".to_string());
        }
        let dist = parse_survival_distribution(input.survival_distribution)?;
        Ok(residual_distribution_inverse_link(dist))
    }
}
680
/// Maps a user-supplied family name to the value stored in `FitConfig::family`.
///
/// Surrounding whitespace is trimmed before matching, so `" auto "` is
/// recognized. `"auto"` (ASCII case-insensitive) and an empty or
/// whitespace-only string both mean "no explicit family" and yield `None`,
/// mirroring how an empty `link` is treated. Any other value is returned
/// trimmed; `None` passes through.
pub fn normalize_optional_family(family: Option<String>) -> Option<String> {
    let family = family?;
    let trimmed = family.trim();
    if trimmed.is_empty() || trimmed.eq_ignore_ascii_case("auto") {
        None
    } else {
        Some(trimmed.to_string())
    }
}
687
/// Returns `raw` with surrounding whitespace removed, or `message` as the
/// error when nothing remains after trimming.
fn resolve_nonempty_string(raw: String, message: &str) -> Result<String, String> {
    match raw.trim() {
        "" => Err(message.to_string()),
        content => Ok(content.to_string()),
    }
}
695
696fn parse_json_frailty_spec(
697    frailty_kind: Option<String>,
698    frailty_sd: Option<f64>,
699    hazard_loading: Option<String>,
700) -> Result<Option<FrailtySpec>, String> {
701    if let Some(kind) = frailty_kind {
702        let trimmed = kind.trim().to_ascii_lowercase();
703        let sigma = frailty_sd;
704        if let Some(value) = sigma
705            && (!value.is_finite() || value < 0.0)
706        {
707            return Err(format!("frailty_sd must be finite and >= 0, got {value}"));
708        }
709        let hazard_loading = hazard_loading
710            .as_ref()
711            .map(|raw| raw.trim().to_ascii_lowercase());
712        let frailty = match trimmed.as_str() {
713            "none" | "" => {
714                if sigma.is_some() || hazard_loading.is_some() {
715                    return Err(
716                        "frailty_kind='none' does not accept frailty_sd or hazard_loading"
717                            .to_string(),
718                    );
719                }
720                FrailtySpec::None
721            }
722            "hazard-multiplier" => {
723                let loading = match hazard_loading.as_deref() {
724                    Some("full") | None => HazardLoading::Full,
725                    Some("loaded-vs-unloaded") => HazardLoading::LoadedVsUnloaded,
726                    Some(other) => {
727                        return Err(format!(
728                            "unknown hazard_loading '{other}'; supported: 'full', 'loaded-vs-unloaded'"
729                        ));
730                    }
731                };
732                FrailtySpec::HazardMultiplier {
733                    sigma_fixed: sigma,
734                    loading,
735                }
736            }
737            "gaussian-shift" => {
738                if hazard_loading.is_some() {
739                    return Err(
740                        "hazard_loading is valid only with frailty_kind='hazard-multiplier'"
741                            .to_string(),
742                    );
743                }
744                FrailtySpec::GaussianShift { sigma_fixed: sigma }
745            }
746            other => {
747                return Err(format!(
748                    "unknown frailty_kind '{other}'; supported: 'none', 'hazard-multiplier', 'gaussian-shift'"
749                ));
750            }
751        };
752        Ok(Some(frailty))
753    } else if frailty_sd.is_some() || hazard_loading.is_some() {
754        Err("frailty_kind is required when frailty_sd or hazard_loading is provided".to_string())
755    } else {
756        Ok(None)
757    }
758}
759
760fn cli_hazard_loading(loading: CliHazardLoading) -> HazardLoading {
761    match loading {
762        CliHazardLoading::Full => HazardLoading::Full,
763        CliHazardLoading::LoadedVsUnloaded => HazardLoading::LoadedVsUnloaded,
764    }
765}
766
767fn parse_group_metadata(
768    direct: Option<GroupMetadata>,
769    groups: Option<JsonValue>,
770) -> Result<Option<GroupMetadata>, String> {
771    match (direct, groups) {
772        (Some(metadata), None) => Ok(nonempty_group_metadata(metadata)),
773        (None, Some(groups)) => group_metadata_from_groups(groups),
774        (None, None) => Ok(None),
775        (Some(_), Some(_)) => {
776            Err("fit config accepts either group_metadata or groups metadata, not both".to_string())
777        }
778    }
779}
780
781fn parse_gamma_pair_value(label: &str, value: JsonValue) -> Result<(String, f64, f64), String> {
782    match value {
783        JsonValue::Array(values) => {
784            if values.len() != 2 {
785                return Err(format!(
786                    "precision_hyperpriors['{label}'] must be [shape, rate]"
787                ));
788            }
789            let shape = values[0]
790                .as_f64()
791                .ok_or_else(|| format!("precision_hyperpriors['{label}'][0] must be numeric"))?;
792            let rate = values[1]
793                .as_f64()
794                .ok_or_else(|| format!("precision_hyperpriors['{label}'][1] must be numeric"))?;
795            Ok((label.to_string(), shape, rate))
796        }
797        JsonValue::Object(mut map) => {
798            let shape = map
799                .remove("shape")
800                .or_else(|| map.remove("a"))
801                .or_else(|| map.remove("a_p"))
802                .ok_or_else(|| format!("precision_hyperpriors['{label}'] missing shape/a"))?
803                .as_f64()
804                .ok_or_else(|| {
805                    format!("precision_hyperpriors['{label}'] shape/a must be numeric")
806                })?;
807            let rate = map
808                .remove("rate")
809                .or_else(|| map.remove("b"))
810                .or_else(|| map.remove("b_p"))
811                .ok_or_else(|| format!("precision_hyperpriors['{label}'] missing rate/b"))?
812                .as_f64()
813                .ok_or_else(|| {
814                    format!("precision_hyperpriors['{label}'] rate/b must be numeric")
815                })?;
816            Ok((label.to_string(), shape, rate))
817        }
818        _ => Err(format!(
819            "precision_hyperpriors['{label}'] must be [shape, rate] or an object"
820        )),
821    }
822}
823
824fn parse_precision_hyperpriors(
825    precision_hyperpriors: Option<JsonValue>,
826    penalty_block_gamma_priors: Option<JsonValue>,
827) -> Result<Vec<(String, f64, f64)>, String> {
828    let raw = match (precision_hyperpriors, penalty_block_gamma_priors) {
829        (Some(_), Some(_)) => {
830            return Err(
831                "fit config accepts either precision_hyperpriors or penalty_block_gamma_priors, not both"
832                    .to_string(),
833            );
834        }
835        (Some(raw), None) | (None, Some(raw)) => raw,
836        (None, None) => {
837            return Ok(Vec::new());
838        }
839    };
840    let raw_name = "precision_hyperpriors";
841    let Some(raw) = (match raw {
842        JsonValue::Null => None,
843        other => Some(other),
844    }) else {
845        return Ok(Vec::new());
846    };
847    match raw {
848        JsonValue::Object(map) => map
849            .into_iter()
850            .map(|(label, value)| parse_gamma_pair_value(&label, value))
851            .collect(),
852        JsonValue::Array(items) => items
853            .into_iter()
854            .enumerate()
855            .map(|(idx, item)| match item {
856                JsonValue::Object(mut obj) => {
857                    let label = obj
858                        .remove("label")
859                        .or_else(|| obj.remove("name"))
860                        .or_else(|| obj.remove("group"))
861                        .ok_or_else(|| format!("{raw_name}[{idx}] needs label/name/group"))?;
862                    let JsonValue::String(label) = label else {
863                        return Err(format!("{raw_name}[{idx}] label must be a string"));
864                    };
865                    parse_gamma_pair_value(&label, JsonValue::Object(obj))
866                }
867                JsonValue::Array(mut values) => {
868                    if values.len() != 2 && values.len() != 3 {
869                        return Err(format!(
870                            "{raw_name}[{idx}] must be [label, shape, rate] or [label, [shape, rate]]"
871                        ));
872                    }
873                    let label = values.remove(0);
874                    let JsonValue::String(label) = label else {
875                        return Err(format!("{raw_name}[{idx}][0] must be a string label"));
876                    };
877                    let pair = if values.len() == 1 {
878                        values.remove(0)
879                    } else {
880                        JsonValue::Array(values)
881                    };
882                    parse_gamma_pair_value(&label, pair)
883                }
884                _ => Err(format!("{raw_name}[{idx}] must be an object or array")),
885            })
886            .collect(),
887        _ => Err(format!("{raw_name} must be a map or array")),
888    }
889}
890
891fn nonempty_group_metadata(metadata: GroupMetadata) -> Option<GroupMetadata> {
892    if metadata.is_empty() {
893        None
894    } else {
895        Some(metadata)
896    }
897}
898
899fn group_metadata_from_groups(groups: JsonValue) -> Result<Option<GroupMetadata>, String> {
900    match groups {
901        JsonValue::Null => Ok(None),
902        JsonValue::Object(map) => {
903            let out = map.into_iter().collect::<BTreeMap<_, _>>();
904            Ok(nonempty_group_metadata(out))
905        }
906        JsonValue::Array(items) => {
907            let mut out = BTreeMap::new();
908            for (idx, item) in items.into_iter().enumerate() {
909                let JsonValue::Object(mut group) = item else {
910                    return Err(format!("groups[{idx}] must be an object"));
911                };
912                let Some(metadata) = group.remove("metadata") else {
913                    continue;
914                };
915                let name = group
916                    .remove("name")
917                    .or_else(|| group.remove("id"))
918                    .or_else(|| group.remove("key"))
919                    .ok_or_else(|| {
920                        format!(
921                            "groups[{idx}] with metadata must include a string name, id, or key"
922                        )
923                    })?;
924                let JsonValue::String(name) = name else {
925                    return Err(format!("groups[{idx}] name/id/key must be a string"));
926                };
927                if name.is_empty() {
928                    return Err(format!("groups[{idx}] name/id/key must be non-empty"));
929                }
930                if out.insert(name.clone(), metadata).is_some() {
931                    return Err(format!("duplicate group metadata key '{name}'"));
932                }
933            }
934            Ok(nonempty_group_metadata(out))
935        }
936        _ => Err("groups must be an object map or an array of group objects".to_string()),
937    }
938}
939
940fn parse_gpu_policy(raw_gpu: &str) -> Result<gam_gpu::GpuPolicy, String> {
941    gam_gpu::GpuPolicy::parse(raw_gpu).ok_or_else(|| {
942        format!(
943            "invalid gpu policy '{}'; supported values are auto, off, force",
944            raw_gpu
945        )
946    })
947}
948
949fn validate_resolved_fit_config(config: &FitConfig) -> Result<(), String> {
950    if !config.ridge_lambda.is_finite() || config.ridge_lambda < 0.0 {
951        return Err("--ridge-lambda must be finite and >= 0".to_string());
952    }
953    let likelihood_mode = parse_survival_likelihood_mode(&config.survival_likelihood)?;
954    validate_survival_baseline_args(
955        likelihood_mode,
956        &config.baseline_target,
957        config.baseline_scale,
958        config.baseline_shape,
959        config.baseline_rate,
960        config.baseline_makeham,
961    )
962}
963
/// Usage hint listing the link specifications accepted for survival models,
/// intended for appending to link-parsing error messages.
fn survival_link_usage() -> &'static str {
    const USAGE: &str = "use identity|logit|probit|cloglog|loglog|cauchit|sas|beta-logistic|blended(...)/mixture(...) or flexible(...)";
    USAGE
}
967
#[cfg(test)]
mod tests {
    use super::*;
    use gam_models::survival::lognormal_kernel::FrailtySpec;
    use serde_json::{Value, json};

    /// One scenario expressed both as CLI-shaped input and as the equivalent JSON
    /// wire config. The parity tests require both to resolve identically, or to
    /// fail with the same error.
    struct ParityCase {
        name: &'static str,
        cli: CliFitConfigInput,
        json: Value,
    }

    /// CLI input with every optional field unset and the remaining fields at the
    /// values these tests treat as defaults. Presumably this mirrors the CLI
    /// parser's defaults; keep the two in sync (TODO confirm).
    fn base_cli() -> CliFitConfigInput {
        CliFitConfigInput {
            family: None,
            expectile_tau: None,
            negative_binomial_theta: None,
            link: None,
            flexible_link: false,
            offset_column: None,
            weight_column: None,
            noise_offset_column: None,
            baseline_target: "linear".to_string(),
            baseline_scale: None,
            baseline_shape: None,
            baseline_rate: None,
            baseline_makeham: None,
            time_basis: "ispline".to_string(),
            time_degree: 3,
            time_num_internal_knots: 8,
            time_smooth_lambda: 1e-2,
            survival_likelihood: "transformation".to_string(),
            survival_distribution: "gaussian".to_string(),
            threshold_time_k: None,
            threshold_time_degree: 3,
            sigma_time_k: None,
            sigma_time_degree: 3,
            noise_formula: None,
            logslope_formula: None,
            z_column: None,
            scale_dimensions: false,
            adaptive_regularization: None,
            ridge_lambda: 1e-6,
            transformation_normal: false,
            firth: false,
            outer_max_iter: None,
            gpu: None,
            frailty_kind: None,
            frailty_sd: None,
            hazard_loading: None,
        }
    }

    /// Resolves CLI-shaped input through the production entry point.
    fn resolved_cli(input: CliFitConfigInput) -> Result<FitConfig, String> {
        resolve_cli_fit_config(input)
    }

    /// Resolves a JSON wire config through `parse_fit_config_json`. It asserts that
    /// these configs never carry a training table kind, then returns only the
    /// fit config.
    fn resolved_json(config: Value) -> Result<FitConfig, String> {
        parse_fit_config_json(Some(&config.to_string())).map(|resolved| {
            assert_eq!(resolved.training_table_kind, None);
            resolved.fit_config
        })
    }

    /// Debug-formats a resolved config after mapping `frailty: None` to
    /// `Some(FrailtySpec::None)`, so the two spellings of "no frailty" compare
    /// equal.
    fn canonical_fit_config(mut config: FitConfig) -> String {
        if config.frailty.is_none() {
            config.frailty = Some(FrailtySpec::None);
        }
        format!("{config:#?}")
    }

    #[test]
    fn cli_shaped_and_json_wire_config_resolution_match() {
        let cases = vec![
            ParityCase {
                name: "family and link selection",
                cli: {
                    let mut input = base_cli();
                    input.family = Some("binomial".to_string());
                    input.link = Some("probit".to_string());
                    input.flexible_link = true;
                    input
                },
                json: json!({
                    "family": "binomial",
                    "link": "probit",
                    "flexible_link": true
                }),
            },
            ParityCase {
                name: "offset weights ridge and noise offset columns",
                cli: {
                    let mut input = base_cli();
                    input.offset_column = Some("eta_offset".to_string());
                    input.weight_column = Some("case_weight".to_string());
                    input.noise_offset_column = Some("sigma_offset".to_string());
                    input.ridge_lambda = 0.125;
                    input
                },
                json: json!({
                    "offset": "eta_offset",
                    "weights": "case_weight",
                    "noise_offset": "sigma_offset",
                    "ridge_lambda": 0.125
                }),
            },
            ParityCase {
                name: "weibull survival likelihood and baseline scale shape",
                cli: {
                    let mut input = base_cli();
                    input.survival_likelihood = "weibull".to_string();
                    input.baseline_target = "weibull".to_string();
                    input.baseline_scale = Some(2.5);
                    input.baseline_shape = Some(1.75);
                    input
                },
                json: json!({
                    "survival_likelihood": "weibull",
                    "baseline_target": "weibull",
                    "baseline_scale": 2.5,
                    "baseline_shape": 1.75
                }),
            },
            ParityCase {
                name: "transformation survival gompertz makeham baseline",
                cli: {
                    let mut input = base_cli();
                    input.survival_likelihood = "transformation".to_string();
                    input.baseline_target = "gompertz-makeham".to_string();
                    input.baseline_shape = Some(1.2);
                    input.baseline_rate = Some(0.04);
                    input.baseline_makeham = Some(0.01);
                    input
                },
                json: json!({
                    "survival_likelihood": "transformation",
                    "baseline_target": "gompertz-makeham",
                    "baseline_shape": 1.2,
                    "baseline_rate": 0.04,
                    "baseline_makeham": 0.01
                }),
            },
            ParityCase {
                name: "survival likelihood values are canonicalized",
                cli: {
                    let mut input = base_cli();
                    input.survival_likelihood = "TRANSFORMATION".to_string();
                    input
                },
                json: json!({
                    "survival_likelihood": "Transformation"
                }),
            },
            ParityCase {
                name: "noise formula logslope z column and scale dimensions",
                cli: {
                    let mut input = base_cli();
                    input.noise_formula = Some("~ s(age) + treatment".to_string());
                    input.logslope_formula = Some("~ s(dose)".to_string());
                    input.z_column = Some("dose".to_string());
                    input.scale_dimensions = true;
                    input
                },
                json: json!({
                    "noise_formula": "~ s(age) + treatment",
                    "logslope_formula": "~ s(dose)",
                    "z_column": "dose",
                    "scale_dimensions": true
                }),
            },
            ParityCase {
                name: "firth transformation normal outer iterations and adaptive regularization",
                cli: {
                    let mut input = base_cli();
                    input.firth = true;
                    input.transformation_normal = true;
                    input.outer_max_iter = Some(7);
                    input.adaptive_regularization = Some(true);
                    input
                },
                json: json!({
                    "firth": true,
                    "transformation_normal": true,
                    "outer_max_iter": 7,
                    "adaptive_regularization": true
                }),
            },
            ParityCase {
                name: "gpu policy toggle",
                cli: {
                    let mut input = base_cli();
                    input.gpu = Some("off".to_string());
                    input
                },
                json: json!({
                    "gpu": "off"
                }),
            },
            ParityCase {
                name: "hazard multiplier frailty fields",
                cli: {
                    let mut input = base_cli();
                    input.frailty_kind = Some(CliFrailtyKind::HazardMultiplier);
                    input.frailty_sd = Some(0.35);
                    input.hazard_loading = Some(CliHazardLoading::LoadedVsUnloaded);
                    input
                },
                json: json!({
                    "frailty_kind": "hazard-multiplier",
                    "frailty_sd": 0.35,
                    "hazard_loading": "loaded-vs-unloaded"
                }),
            },
            ParityCase {
                name: "gaussian shift frailty fields",
                cli: {
                    let mut input = base_cli();
                    input.frailty_kind = Some(CliFrailtyKind::GaussianShift);
                    input.frailty_sd = Some(0.2);
                    input
                },
                json: json!({
                    "frailty_kind": "gaussian-shift",
                    "frailty_sd": 0.2
                }),
            },
        ];

        for case in cases {
            let cli = resolved_cli(case.cli)
                .unwrap_or_else(|err| panic!("{}: CLI-shaped config failed: {err}", case.name));
            let json = resolved_json(case.json)
                .unwrap_or_else(|err| panic!("{}: JSON wire config failed: {err}", case.name));
            assert_eq!(
                canonical_fit_config(cli),
                canonical_fit_config(json),
                "{}",
                case.name
            );
        }
    }

    #[test]
    fn cli_shaped_and_json_wire_config_resolution_rejections_match() {
        let cases = vec![
            ParityCase {
                name: "negative ridge lambda",
                cli: {
                    let mut input = base_cli();
                    input.ridge_lambda = -1.0;
                    input
                },
                json: json!({
                    "ridge_lambda": -1.0
                }),
            },
            ParityCase {
                name: "unknown gpu policy",
                cli: {
                    let mut input = base_cli();
                    input.gpu = Some("cuda".to_string());
                    input
                },
                json: json!({
                    "gpu": "cuda"
                }),
            },
            ParityCase {
                name: "linear baseline rejects shape",
                cli: {
                    let mut input = base_cli();
                    input.baseline_shape = Some(1.1);
                    input
                },
                json: json!({
                    "baseline_shape": 1.1
                }),
            },
            ParityCase {
                name: "weibull likelihood rejects gompertz target",
                cli: {
                    let mut input = base_cli();
                    input.survival_likelihood = "weibull".to_string();
                    input.baseline_target = "gompertz".to_string();
                    input
                },
                json: json!({
                    "survival_likelihood": "weibull",
                    "baseline_target": "gompertz"
                }),
            },
        ];

        for case in cases {
            let cli = resolved_cli(case.cli).expect_err(case.name);
            let json = resolved_json(case.json).expect_err(case.name);
            assert_eq!(cli, json, "{}", case.name);
        }
    }

    // ── parse_comma_f64 ───────────────────────────────────────────────────

    #[test]
    fn parse_comma_f64_empty_string_returns_empty_vec() {
        assert_eq!(parse_comma_f64("", "x").unwrap(), Vec::<f64>::new());
        assert_eq!(parse_comma_f64("   ", "x").unwrap(), Vec::<f64>::new());
    }

    #[test]
    fn parse_comma_f64_single_value() {
        // Avoid literals such as 3.14: clippy's deny-by-default `approx_constant`
        // lint treats them as a mistyped PI and fails `cargo clippy --tests`.
        assert_eq!(parse_comma_f64("2.75", "x").unwrap(), vec![2.75]);
    }

    #[test]
    fn parse_comma_f64_multiple_values_with_spaces() {
        let result = parse_comma_f64("1.0, 2.5, -3.0", "x").unwrap();
        assert_eq!(result, vec![1.0, 2.5, -3.0]);
    }

    #[test]
    fn parse_comma_f64_non_numeric_returns_error() {
        let err = parse_comma_f64("1.0, bad, 3.0", "--vals").unwrap_err();
        assert!(err.contains("--vals"), "error should name the label: {err}");
        assert!(err.contains("bad"), "error should name the bad token: {err}");
    }

    #[test]
    fn parse_comma_f64_infinity_returns_error() {
        let err = parse_comma_f64("inf", "--vals").unwrap_err();
        assert!(err.contains("non-finite"), "error should say non-finite: {err}");
    }

    #[test]
    fn parse_comma_f64_nan_returns_error() {
        let err = parse_comma_f64("nan", "--vals").unwrap_err();
        assert!(err.contains("non-finite"), "error should say non-finite: {err}");
    }

    // ── normalize_optional_family ─────────────────────────────────────────

    #[test]
    fn normalize_optional_family_none_passthrough() {
        assert_eq!(normalize_optional_family(None), None);
    }

    #[test]
    fn normalize_optional_family_auto_becomes_none() {
        assert_eq!(normalize_optional_family(Some("auto".to_string())), None);
        assert_eq!(normalize_optional_family(Some("Auto".to_string())), None);
        assert_eq!(normalize_optional_family(Some("AUTO".to_string())), None);
    }

    #[test]
    fn normalize_optional_family_non_auto_passthrough() {
        assert_eq!(
            normalize_optional_family(Some("binomial".to_string())),
            Some("binomial".to_string())
        );
        assert_eq!(
            normalize_optional_family(Some("gaussian".to_string())),
            Some("gaussian".to_string())
        );
    }

    // ── parse_survival_likelihood_cli ─────────────────────────────────────

    #[test]
    fn parse_survival_likelihood_cli_valid_values() {
        assert_eq!(
            parse_survival_likelihood_cli("transformation").unwrap(),
            "transformation"
        );
        assert_eq!(
            parse_survival_likelihood_cli("weibull").unwrap(),
            "weibull"
        );
        // case-insensitive
        assert_eq!(
            parse_survival_likelihood_cli("WEIBULL").unwrap(),
            "weibull"
        );
        assert_eq!(
            parse_survival_likelihood_cli("Transformation").unwrap(),
            "transformation"
        );
    }

    #[test]
    fn parse_survival_likelihood_cli_invalid_returns_error() {
        assert!(parse_survival_likelihood_cli("lognormal").is_err());
        assert!(parse_survival_likelihood_cli("").is_err());
    }

    // ── parse_baseline_target_cli ─────────────────────────────────────────

    #[test]
    fn parse_baseline_target_cli_valid_values() {
        for target in &["linear", "weibull", "gompertz", "gompertz-makeham"] {
            assert_eq!(
                parse_baseline_target_cli(target).unwrap(),
                *target,
                "should accept '{target}'"
            );
        }
        // trimmed and lowercased
        assert_eq!(
            parse_baseline_target_cli("  Weibull  ").unwrap(),
            "weibull"
        );
    }

    #[test]
    fn parse_baseline_target_cli_invalid_returns_error() {
        let err = parse_baseline_target_cli("cox").unwrap_err();
        assert!(err.contains("cox"), "error should name the bad value: {err}");
    }

    // ── validate_survival_baseline_args ───────────────────────────────────

    #[test]
    fn validate_survival_baseline_args_linear_rejects_params() {
        let mode = parse_survival_likelihood_mode("transformation").unwrap();
        assert!(validate_survival_baseline_args(
            mode,
            "linear",
            Some(1.0),
            None,
            None,
            None
        )
        .is_err());
    }

    #[test]
    fn validate_survival_baseline_args_linear_accepts_no_params() {
        let mode = parse_survival_likelihood_mode("transformation").unwrap();
        assert!(validate_survival_baseline_args(
            mode, "linear", None, None, None, None
        )
        .is_ok());
    }

    #[test]
    fn validate_survival_baseline_args_weibull_likelihood_rejects_gompertz_target() {
        let mode = parse_survival_likelihood_mode("weibull").unwrap();
        assert!(validate_survival_baseline_args(
            mode,
            "gompertz",
            None,
            None,
            None,
            None
        )
        .is_err());
    }

    #[test]
    fn validate_survival_baseline_args_gompertz_rejects_scale() {
        let mode = parse_survival_likelihood_mode("transformation").unwrap();
        assert!(validate_survival_baseline_args(
            mode,
            "gompertz",
            Some(2.0),
            None,
            None,
            None
        )
        .is_err());
    }

    #[test]
    fn validate_survival_baseline_args_gompertz_makeham_rejects_scale() {
        let mode = parse_survival_likelihood_mode("transformation").unwrap();
        assert!(validate_survival_baseline_args(
            mode,
            "gompertz-makeham",
            Some(1.0),
            None,
            None,
            None
        )
        .is_err());
    }
}