Skip to main content

gam_config/
lib.rs

1use gam_models::survival::location_scale::residual_distribution_inverse_link;
2use gam_models::survival::lognormal_kernel::{FrailtySpec, HazardLoading};
3use gam_models::survival::parse_survival_distribution;
4use gam_models::survival::{SurvivalLikelihoodMode, parse_survival_likelihood_mode};
5use gam_inference::formula_dsl::parse_link_choice;
6use gam_inference::model::GroupMetadata;
7use gam_solve::mixture_link::{state_from_beta_logisticspec, state_from_sasspec, state_fromspec};
8use gam_models::fit_orchestration::descriptors::build_analytic_penalty_registry_from_descriptors;
9use gam_models::fit_orchestration::{CtnStage1Recipe, FitConfig};
10use gam_models::transformation_normal::TransformationNormalConfig;
11use gam_problem::types::{InverseLink, LinkFunction, MixtureLinkSpec, SasLinkSpec, StandardLink};
12use ndarray::Array1;
13use serde::Deserialize;
14use serde_json::Value as JsonValue;
15use std::collections::BTreeMap;
16
/// Frailty kind selected on the command line (`--frailty-kind`).
///
/// Converted into a `FrailtySpec` by `resolve_cli_frailty_spec`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CliFrailtyKind {
    /// `--frailty-kind gaussian-shift`; does not accept `--hazard-loading`.
    GaussianShift,
    /// `--frailty-kind hazard-multiplier`; requires `--hazard-loading`.
    HazardMultiplier,
}
22
/// Hazard-loading mode selected on the command line (`--hazard-loading`).
///
/// Mapped one-to-one onto `HazardLoading` by `cli_hazard_loading`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CliHazardLoading {
    /// Maps to `HazardLoading::Full`.
    Full,
    /// Maps to `HazardLoading::LoadedVsUnloaded`.
    LoadedVsUnloaded,
}
28
/// Raw fit configuration as deserialized from JSON.
///
/// Every key is optional. `resolve_json_fit_config` overlays the keys that are present onto
/// `FitConfig::default()`. `deny_unknown_fields` turns a misspelled key into a parse error
/// instead of silently ignoring it.
#[derive(Default, Deserialize)]
#[serde(deny_unknown_fields)]
struct JsonFitConfig {
    /// Model family; `"auto"` (case-insensitive) is normalized to `None`.
    family: Option<String>,
    /// Offset column name; stored as `FitConfig::offset_column`.
    offset: Option<String>,
    /// Weight column name; stored as `FitConfig::weight_column`.
    weights: Option<String>,
    ridge_lambda: Option<f64>,
    transformation_normal: Option<bool>,
    /// Survival likelihood mode; trimmed, lowercased and validated.
    survival_likelihood: Option<String>,
    /// Must be non-empty after trimming.
    baseline_target: Option<String>,
    baseline_scale: Option<f64>,
    baseline_shape: Option<f64>,
    baseline_rate: Option<f64>,
    baseline_makeham: Option<f64>,
    /// Must be non-empty after trimming.
    z_column: Option<String>,
    logslope_formula: Option<String>,
    /// Stage-1 CTN recipe; converted with `JsonCtnStage1::into_recipe`.
    ctn_stage1: Option<JsonCtnStage1>,
    /// Link name; a blank string clears the link (`None`).
    link: Option<String>,
    flexible_link: Option<bool>,
    scale_dimensions: Option<bool>,
    adaptive_regularization: Option<bool>,
    noise_formula: Option<String>,
    /// Noise offset column name; stored as `FitConfig::noise_offset_column`.
    noise_offset: Option<String>,
    firth: Option<bool>,
    /// Outer iteration cap; must be >= 1 when provided.
    outer_max_iter: Option<usize>,
    /// GPU policy string, parsed by `parse_gpu_policy`.
    gpu: Option<String>,
    /// Mutually exclusive with `groups`.
    group_metadata: Option<GroupMetadata>,
    /// Alternative groups payload; mutually exclusive with `group_metadata`.
    groups: Option<JsonValue>,
    /// Gamma priors per penalty block; mutually exclusive with `penalty_block_gamma_priors`.
    precision_hyperpriors: Option<JsonValue>,
    /// Alternative key for `precision_hyperpriors`; the two may not both be given.
    penalty_block_gamma_priors: Option<JsonValue>,
    /// Latent descriptors; validated together with `penalties`.
    latents: Option<JsonValue>,
    /// Analytic penalty descriptors; stored as `FitConfig::analytic_penalties`.
    penalties: Option<JsonValue>,
    /// Passed through as `FitConfig::smooth_overrides`.
    smooths: Option<JsonValue>,
    /// Parsed with `TopologyAutoSelector::from_json` when present.
    topology_auto_selector: Option<JsonValue>,
    /// Frailty settings; combined and validated by `parse_json_frailty_spec`.
    frailty_kind: Option<String>,
    frailty_sd: Option<f64>,
    hazard_loading: Option<String>,
    /// Returned alongside the config in `ResolvedFitConfig`; not interpreted here.
    training_table_kind: Option<String>,
}
63    frailty_kind: Option<String>,
64    frailty_sd: Option<f64>,
65    hazard_loading: Option<String>,
66    training_table_kind: Option<String>,
67}
68
69#[derive(Deserialize)]
70#[serde(deny_unknown_fields)]
71struct JsonCtnStage1 {
72    response_column: String,
73    covariate_formula_rhs: String,
74    #[serde(default)]
75    config: Option<JsonCtnStage1Config>,
76    #[serde(default)]
77    weight_column: Option<String>,
78    #[serde(default)]
79    offset_column: Option<String>,
80}
81
82#[derive(Deserialize)]
83#[serde(deny_unknown_fields)]
84struct JsonCtnStage1Config {
85    #[serde(default)]
86    response_degree: Option<usize>,
87    #[serde(default)]
88    response_num_internal_knots: Option<usize>,
89    #[serde(default)]
90    response_penalty_order: Option<usize>,
91    #[serde(default)]
92    response_extra_penalty_orders: Option<Vec<usize>>,
93    #[serde(default)]
94    double_penalty: Option<bool>,
95}
96
97impl JsonCtnStage1 {
98    fn into_recipe(self) -> Result<CtnStage1Recipe, String> {
99        let mut config = TransformationNormalConfig::default();
100        if let Some(overrides) = self.config {
101            if let Some(value) = overrides.response_degree {
102                config.response_degree = value;
103            }
104            if let Some(value) = overrides.response_num_internal_knots {
105                config.response_num_internal_knots = value;
106            }
107            if let Some(value) = overrides.response_penalty_order {
108                config.response_penalty_order = value;
109            }
110            if let Some(value) = overrides.response_extra_penalty_orders {
111                config.response_extra_penalty_orders = value;
112            }
113            if let Some(value) = overrides.double_penalty {
114                config.double_penalty = value;
115            }
116        }
117        CtnStage1Recipe::new(
118            &self.response_column,
119            &self.covariate_formula_rhs,
120            config,
121            self.weight_column.as_deref(),
122            self.offset_column.as_deref(),
123        )
124    }
125}
126
127pub struct ResolvedFitConfig {
128    pub fit_config: FitConfig,
129    pub training_table_kind: Option<String>,
130}
131
132pub struct CliFitConfigInput {
133    pub family: Option<String>,
134    pub negative_binomial_theta: Option<f64>,
135    pub link: Option<String>,
136    pub flexible_link: bool,
137    pub offset_column: Option<String>,
138    pub weight_column: Option<String>,
139    pub noise_offset_column: Option<String>,
140    pub baseline_target: String,
141    pub baseline_scale: Option<f64>,
142    pub baseline_shape: Option<f64>,
143    pub baseline_rate: Option<f64>,
144    pub baseline_makeham: Option<f64>,
145    pub time_basis: String,
146    pub time_degree: usize,
147    pub time_num_internal_knots: usize,
148    pub time_smooth_lambda: f64,
149    pub survival_likelihood: String,
150    pub survival_distribution: String,
151    pub threshold_time_k: Option<usize>,
152    pub threshold_time_degree: usize,
153    pub sigma_time_k: Option<usize>,
154    pub sigma_time_degree: usize,
155    pub noise_formula: Option<String>,
156    pub logslope_formula: Option<String>,
157    pub z_column: Option<String>,
158    pub scale_dimensions: bool,
159    pub adaptive_regularization: Option<bool>,
160    pub ridge_lambda: f64,
161    pub transformation_normal: bool,
162    pub firth: bool,
163    pub outer_max_iter: Option<usize>,
164    pub gpu: Option<String>,
165    pub frailty_kind: Option<CliFrailtyKind>,
166    pub frailty_sd: Option<f64>,
167    pub hazard_loading: Option<CliHazardLoading>,
168}
169
170pub struct SurvivalInverseLinkInput<'a> {
171    pub link: Option<&'a str>,
172    pub mixture_rho: Option<&'a str>,
173    pub sas_init: Option<&'a str>,
174    pub beta_logistic_init: Option<&'a str>,
175    pub survival_distribution: &'a str,
176}
177
178pub fn parse_fit_config_json(config_json: Option<&str>) -> Result<ResolvedFitConfig, String> {
179    let json_config = match config_json {
180        Some(raw) if !raw.trim().is_empty() => serde_json::from_str::<JsonFitConfig>(raw)
181            .map_err(|err| format!("invalid fit config json: {err}"))?,
182        _ => JsonFitConfig::default(),
183    };
184    resolve_json_fit_config(json_config)
185}
186
187fn resolve_json_fit_config(json_config: JsonFitConfig) -> Result<ResolvedFitConfig, String> {
188    let training_table_kind = json_config.training_table_kind;
189    let mut fit_config = FitConfig::default();
190    fit_config.group_metadata =
191        parse_group_metadata(json_config.group_metadata, json_config.groups)?;
192    fit_config.penalty_block_gamma_priors = parse_precision_hyperpriors(
193        json_config.precision_hyperpriors,
194        json_config.penalty_block_gamma_priors,
195    )?;
196    let analytic_penalties = json_config.penalties;
197    build_analytic_penalty_registry_from_descriptors(
198        json_config.latents.as_ref(),
199        analytic_penalties.as_ref(),
200    )?;
201    fit_config.latents = json_config.latents;
202    fit_config.analytic_penalties = analytic_penalties;
203    fit_config.smooth_overrides = json_config.smooths;
204    fit_config.topology_auto_selector = json_config
205        .topology_auto_selector
206        .as_ref()
207        .map(gam_solve::topology_selector::TopologyAutoSelector::from_json)
208        .transpose()?;
209    fit_config.family = normalize_optional_family(json_config.family);
210    fit_config.offset_column = json_config.offset;
211    fit_config.weight_column = json_config.weights;
212    if let Some(ridge_lambda) = json_config.ridge_lambda {
213        fit_config.ridge_lambda = ridge_lambda;
214    }
215    if let Some(flag) = json_config.transformation_normal {
216        fit_config.transformation_normal = flag;
217    }
218    if let Some(mode) = json_config.survival_likelihood {
219        fit_config.survival_likelihood = parse_survival_likelihood_cli(&mode)?;
220    }
221    if let Some(target) = json_config.baseline_target {
222        fit_config.baseline_target =
223            resolve_nonempty_string(target, "baseline_target must be a non-empty string")?;
224    }
225    if let Some(value) = json_config.baseline_scale {
226        fit_config.baseline_scale = Some(value);
227    }
228    if let Some(value) = json_config.baseline_shape {
229        fit_config.baseline_shape = Some(value);
230    }
231    if let Some(value) = json_config.baseline_rate {
232        fit_config.baseline_rate = Some(value);
233    }
234    if let Some(value) = json_config.baseline_makeham {
235        fit_config.baseline_makeham = Some(value);
236    }
237    if let Some(z) = json_config.z_column {
238        fit_config.z_column = Some(resolve_nonempty_string(
239            z,
240            "z_column must be a non-empty string",
241        )?);
242    }
243    if let Some(formula) = json_config.logslope_formula {
244        fit_config.logslope_formula = Some(formula);
245    }
246    if let Some(stage1) = json_config.ctn_stage1 {
247        fit_config.ctn_stage1 = Some(stage1.into_recipe()?);
248    }
249    if let Some(link) = json_config.link {
250        let trimmed = link.trim();
251        fit_config.link = if trimmed.is_empty() {
252            None
253        } else {
254            Some(trimmed.to_string())
255        };
256    }
257    if let Some(flag) = json_config.flexible_link {
258        fit_config.flexible_link = flag;
259    }
260    if let Some(flag) = json_config.scale_dimensions {
261        fit_config.scale_dimensions = flag;
262    }
263    if let Some(flag) = json_config.adaptive_regularization {
264        fit_config.adaptive_regularization = Some(flag);
265    }
266    if let Some(formula) = json_config.noise_formula {
267        fit_config.noise_formula = Some(formula);
268    }
269    if let Some(column) = json_config.noise_offset {
270        fit_config.noise_offset_column = Some(column);
271    }
272    if let Some(flag) = json_config.firth {
273        fit_config.firth = flag;
274    }
275    if let Some(value) = json_config.outer_max_iter {
276        if value == 0 {
277            return Err("outer_max_iter must be >= 1".to_string());
278        }
279        fit_config.outer_max_iter = Some(value);
280    }
281    if let Some(raw_gpu) = json_config.gpu {
282        fit_config.gpu_policy = parse_gpu_policy(&raw_gpu)?;
283    }
284    fit_config.frailty = parse_json_frailty_spec(
285        json_config.frailty_kind,
286        json_config.frailty_sd,
287        json_config.hazard_loading,
288    )?;
289    validate_resolved_fit_config(&fit_config)?;
290    Ok(ResolvedFitConfig {
291        fit_config,
292        training_table_kind,
293    })
294}
295
296pub fn resolve_cli_fit_config(input: CliFitConfigInput) -> Result<FitConfig, String> {
297    let mut fit_config = FitConfig::default();
298    fit_config.family = input.family;
299    fit_config.negative_binomial_theta = input.negative_binomial_theta;
300    fit_config.link = input.link;
301    fit_config.flexible_link = input.flexible_link;
302    fit_config.offset_column = input.offset_column;
303    fit_config.weight_column = input.weight_column;
304    fit_config.noise_offset_column = input.noise_offset_column;
305    fit_config.baseline_target = input.baseline_target;
306    fit_config.baseline_scale = input.baseline_scale;
307    fit_config.baseline_shape = input.baseline_shape;
308    fit_config.baseline_rate = input.baseline_rate;
309    fit_config.baseline_makeham = input.baseline_makeham;
310    fit_config.time_basis = input.time_basis;
311    fit_config.time_degree = input.time_degree;
312    fit_config.time_num_internal_knots = input.time_num_internal_knots;
313    fit_config.time_smooth_lambda = input.time_smooth_lambda;
314    fit_config.survival_likelihood = parse_survival_likelihood_cli(&input.survival_likelihood)?;
315    fit_config.survival_distribution = input.survival_distribution;
316    fit_config.threshold_time_k = input.threshold_time_k;
317    fit_config.threshold_time_degree = input.threshold_time_degree;
318    fit_config.sigma_time_k = input.sigma_time_k;
319    fit_config.sigma_time_degree = input.sigma_time_degree;
320    fit_config.noise_formula = input.noise_formula;
321    fit_config.logslope_formula = input.logslope_formula;
322    fit_config.z_column = input.z_column;
323    fit_config.scale_dimensions = input.scale_dimensions;
324    fit_config.adaptive_regularization = input.adaptive_regularization;
325    fit_config.ridge_lambda = input.ridge_lambda;
326    fit_config.transformation_normal = input.transformation_normal;
327    fit_config.firth = input.firth;
328    fit_config.outer_max_iter = input.outer_max_iter;
329    if let Some(raw_gpu) = input.gpu {
330        fit_config.gpu_policy = parse_gpu_policy(&raw_gpu)?;
331    }
332    fit_config.frailty = Some(resolve_cli_frailty_spec(
333        input.frailty_kind,
334        input.frailty_sd,
335        input.hazard_loading,
336        "fit",
337    )?);
338    validate_resolved_fit_config(&fit_config)?;
339    Ok(fit_config)
340}
341
342pub fn resolve_cli_frailty_spec(
343    frailty_kind: Option<CliFrailtyKind>,
344    frailty_sd: Option<f64>,
345    hazard_loading: Option<CliHazardLoading>,
346    context: &str,
347) -> Result<FrailtySpec, String> {
348    let validate_sigma = || -> Result<Option<f64>, String> {
349        match frailty_sd {
350            None => Ok(None),
351            Some(sigma) => {
352                if !sigma.is_finite() || sigma < 0.0 {
353                    return Err(format!(
354                        "{context} requires a finite --frailty-sd >= 0, got {sigma}"
355                    ));
356                }
357                Ok(Some(sigma))
358            }
359        }
360    };
361
362    match frailty_kind {
363        None => {
364            if frailty_sd.is_some() || hazard_loading.is_some() {
365                return Err(format!(
366                    "{context} requires --frailty-kind when --frailty-sd or --hazard-loading is provided"
367                ));
368            }
369            Ok(FrailtySpec::None)
370        }
371        Some(CliFrailtyKind::GaussianShift) => {
372            if hazard_loading.is_some() {
373                return Err(format!(
374                    "{context} does not accept --hazard-loading with --frailty-kind gaussian-shift"
375                ));
376            }
377            Ok(FrailtySpec::GaussianShift {
378                sigma_fixed: validate_sigma()?,
379            })
380        }
381        Some(CliFrailtyKind::HazardMultiplier) => Ok(FrailtySpec::HazardMultiplier {
382            sigma_fixed: validate_sigma()?,
383            loading: hazard_loading.map(cli_hazard_loading).ok_or_else(|| {
384                format!("{context} requires --hazard-loading with --frailty-kind hazard-multiplier")
385            })?,
386        }),
387    }
388}
389
390pub fn parse_survival_likelihood_cli(raw: &str) -> Result<String, String> {
391    let normalized = raw.trim().to_ascii_lowercase();
392    parse_survival_likelihood_mode(&normalized)?;
393    Ok(normalized)
394}
395
396pub fn parse_baseline_target_cli(raw: &str) -> Result<String, String> {
397    let normalized = raw.trim().to_ascii_lowercase();
398    match normalized.as_str() {
399        "linear" | "weibull" | "gompertz" | "gompertz-makeham" => Ok(normalized),
400        other => Err(format!(
401            "unsupported --baseline-target '{other}'; use linear|weibull|gompertz|gompertz-makeham"
402        )),
403    }
404}
405
406pub fn validate_survival_baseline_args(
407    likelihood_mode: SurvivalLikelihoodMode,
408    baseline_target: &str,
409    baseline_scale: Option<f64>,
410    baseline_shape: Option<f64>,
411    baseline_rate: Option<f64>,
412    baseline_makeham: Option<f64>,
413) -> Result<(), String> {
414    if likelihood_mode == SurvivalLikelihoodMode::Weibull {
415        if baseline_rate.is_some() || baseline_makeham.is_some() {
416            return Err(
417                "--survival-likelihood weibull does not use --baseline-rate or --baseline-makeham"
418                    .to_string(),
419            );
420        }
421        if !matches!(baseline_target, "linear" | "weibull") {
422            return Err(
423                "--survival-likelihood weibull supports only --baseline-target linear|weibull"
424                    .to_string(),
425            );
426        }
427        return Ok(());
428    }
429
430    match baseline_target {
431        "linear" => {
432            if baseline_scale.is_some()
433                || baseline_shape.is_some()
434                || baseline_rate.is_some()
435                || baseline_makeham.is_some()
436            {
437                return Err(
438                    "--baseline-target linear does not use baseline parameter flags".to_string(),
439                );
440            }
441        }
442        "weibull" => {
443            if baseline_rate.is_some() || baseline_makeham.is_some() {
444                return Err(
445                    "--baseline-target weibull does not use --baseline-rate or --baseline-makeham"
446                        .to_string(),
447                );
448            }
449        }
450        "gompertz" => {
451            if baseline_scale.is_some() || baseline_makeham.is_some() {
452                return Err(
453                    "--baseline-target gompertz does not use --baseline-scale or --baseline-makeham"
454                        .to_string(),
455                );
456            }
457        }
458        "gompertz-makeham" => {
459            if baseline_scale.is_some() {
460                return Err(
461                    "--baseline-target gompertz-makeham does not use --baseline-scale".to_string(),
462                );
463            }
464        }
465        other => {
466            return Err(format!(
467                "unsupported --baseline-target '{other}'; use linear|weibull|gompertz|gompertz-makeham"
468            ));
469        }
470    }
471    Ok(())
472}
473
474pub fn parse_comma_f64(v: &str, label: &str) -> Result<Vec<f64>, String> {
475    let mut out = Vec::new();
476    for part in v.split(',') {
477        let t = part.trim();
478        if t.is_empty() {
479            continue;
480        }
481        let parsed = t
482            .parse::<f64>()
483            .map_err(|err| format!("{label} contains non-numeric value '{t}': {err}"))?;
484        if !parsed.is_finite() {
485            return Err(format!("{label} contains non-finite value '{t}'"));
486        }
487        out.push(parsed);
488    }
489    Ok(out)
490}
491
492pub fn effective_link_to_standard(
493    link: LinkFunction,
494    context: &str,
495) -> Result<StandardLink, String> {
496    StandardLink::try_from(link).map_err(|_| {
497        format!(
498            "{context}: state-bearing link `{}` must be routed through `InverseLink::Sas` / `InverseLink::BetaLogistic`, not `Standard(_)`",
499            link.name()
500        )
501    })
502}
503
504pub fn parse_survival_inverse_link(
505    input: SurvivalInverseLinkInput<'_>,
506) -> Result<InverseLink, String> {
507    if let Some(raw) = input.link {
508        let name = raw.trim().to_ascii_lowercase();
509        if name == "loglog" || name == "cauchit" {
510            return Err(format!(
511                "survival --link {name} is not supported: cauchit and loglog have no \
512                 LinkFunction representative and cannot be wrapped in a MixtureLinkSpec; \
513                 {}",
514                survival_link_usage()
515            ));
516        }
517    }
518    let choice = parse_link_choice(input.link, false).map_err(|err| {
519        let err = err.to_string();
520        if let Some(raw) = input.link {
521            let name = raw.trim().to_ascii_lowercase();
522            if err.starts_with("unsupported --link ") || err.starts_with("unsupported link type ") {
523                return format!(
524                    "unsupported survival --link '{name}'; {}",
525                    survival_link_usage()
526                );
527            }
528        }
529        err
530    })?;
531    if let Some(choice) = choice {
532        if let Some(components) = choice.mixture_components {
533            if input.sas_init.is_some() || input.beta_logistic_init.is_some() {
534                return Err(
535                    "survival blended(...) link does not accept --sas-init/--beta-logistic-init"
536                        .to_string(),
537                );
538            }
539            let expected = components.len().saturating_sub(1);
540            let initial_rho = if let Some(raw) = input.mixture_rho {
541                let vals = parse_comma_f64(raw, "--mixture-rho")?;
542                if vals.len() != expected {
543                    return Err(format!(
544                        "--mixture-rho expects {expected} values for blended({})",
545                        components
546                            .iter()
547                            .map(|component| component.name())
548                            .collect::<Vec<_>>()
549                            .join(",")
550                    ));
551                }
552                Array1::from_vec(vals)
553            } else {
554                Array1::zeros(expected)
555            };
556            return state_fromspec(&MixtureLinkSpec {
557                components,
558                initial_rho,
559            })
560            .map(InverseLink::Mixture)
561            .map_err(|e| format!("invalid survival blended link state: {e}"));
562        }
563
564        if input.mixture_rho.is_some() {
565            return Err(
566                "--mixture-rho requires survival --link blended(...)/mixture(...)".to_string(),
567            );
568        }
569        match choice.link {
570            LinkFunction::Sas => {
571                if input.beta_logistic_init.is_some() {
572                    return Err("--beta-logistic-init requires --link beta-logistic".to_string());
573                }
574                let (epsilon, log_delta) = if let Some(raw) = input.sas_init {
575                    let vals = parse_comma_f64(raw, "--sas-init")?;
576                    if vals.len() != 2 {
577                        return Err(format!(
578                            "--sas-init expects two values: epsilon,log_delta (got {})",
579                            vals.len()
580                        ));
581                    }
582                    (vals[0], vals[1])
583                } else {
584                    (0.0, 0.0)
585                };
586                state_from_sasspec(SasLinkSpec {
587                    initial_epsilon: epsilon,
588                    initial_log_delta: log_delta,
589                })
590                .map(InverseLink::Sas)
591                .map_err(|e| format!("invalid survival SAS link state: {e}"))
592            }
593            LinkFunction::BetaLogistic => {
594                if input.sas_init.is_some() {
595                    return Err("--sas-init requires --link sas".to_string());
596                }
597                let (epsilon, delta) = if let Some(raw) = input.beta_logistic_init {
598                    let vals = parse_comma_f64(raw, "--beta-logistic-init")?;
599                    if vals.len() != 2 {
600                        return Err(format!(
601                            "--beta-logistic-init expects two values: epsilon,delta (got {})",
602                            vals.len()
603                        ));
604                    }
605                    (vals[0], vals[1])
606                } else {
607                    (0.0, 0.0)
608                };
609                state_from_beta_logisticspec(SasLinkSpec {
610                    initial_epsilon: epsilon,
611                    initial_log_delta: delta,
612                })
613                .map(InverseLink::BetaLogistic)
614                .map_err(|e| format!("invalid survival Beta-Logistic link state: {e}"))
615            }
616            LinkFunction::Log => Err(format!(
617                "unsupported survival --link 'log'; {}",
618                survival_link_usage()
619            )),
620            other => {
621                if input.sas_init.is_some() {
622                    return Err("--sas-init requires --link sas".to_string());
623                }
624                if input.beta_logistic_init.is_some() {
625                    return Err("--beta-logistic-init requires --link beta-logistic".to_string());
626                }
627                Ok(InverseLink::Standard(effective_link_to_standard(
628                    other,
629                    "survival inverse link",
630                )?))
631            }
632        }
633    } else {
634        if input.mixture_rho.is_some() {
635            return Err("--mixture-rho requires --link blended(...)/mixture(...)".to_string());
636        }
637        if input.sas_init.is_some() {
638            return Err("--sas-init requires --link sas".to_string());
639        }
640        if input.beta_logistic_init.is_some() {
641            return Err("--beta-logistic-init requires --link beta-logistic".to_string());
642        }
643        let dist = parse_survival_distribution(input.survival_distribution)?;
644        Ok(residual_distribution_inverse_link(dist))
645    }
646}
647
648pub fn normalize_optional_family(family: Option<String>) -> Option<String> {
649    match family {
650        Some(value) if value.eq_ignore_ascii_case("auto") => None,
651        other => other,
652    }
653}
654
655fn resolve_nonempty_string(raw: String, message: &str) -> Result<String, String> {
656    let trimmed = raw.trim();
657    if trimmed.is_empty() {
658        return Err(message.to_string());
659    }
660    Ok(trimmed.to_string())
661}
662
663fn parse_json_frailty_spec(
664    frailty_kind: Option<String>,
665    frailty_sd: Option<f64>,
666    hazard_loading: Option<String>,
667) -> Result<Option<FrailtySpec>, String> {
668    if let Some(kind) = frailty_kind {
669        let trimmed = kind.trim().to_ascii_lowercase();
670        let sigma = frailty_sd;
671        if let Some(value) = sigma
672            && (!value.is_finite() || value < 0.0)
673        {
674            return Err(format!("frailty_sd must be finite and >= 0, got {value}"));
675        }
676        let hazard_loading = hazard_loading
677            .as_ref()
678            .map(|raw| raw.trim().to_ascii_lowercase());
679        let frailty = match trimmed.as_str() {
680            "none" | "" => {
681                if sigma.is_some() || hazard_loading.is_some() {
682                    return Err(
683                        "frailty_kind='none' does not accept frailty_sd or hazard_loading"
684                            .to_string(),
685                    );
686                }
687                FrailtySpec::None
688            }
689            "hazard-multiplier" => {
690                let loading = match hazard_loading.as_deref() {
691                    Some("full") | None => HazardLoading::Full,
692                    Some("loaded-vs-unloaded") => HazardLoading::LoadedVsUnloaded,
693                    Some(other) => {
694                        return Err(format!(
695                            "unknown hazard_loading '{other}'; supported: 'full', 'loaded-vs-unloaded'"
696                        ));
697                    }
698                };
699                FrailtySpec::HazardMultiplier {
700                    sigma_fixed: sigma,
701                    loading,
702                }
703            }
704            "gaussian-shift" => {
705                if hazard_loading.is_some() {
706                    return Err(
707                        "hazard_loading is valid only with frailty_kind='hazard-multiplier'"
708                            .to_string(),
709                    );
710                }
711                FrailtySpec::GaussianShift { sigma_fixed: sigma }
712            }
713            other => {
714                return Err(format!(
715                    "unknown frailty_kind '{other}'; supported: 'none', 'hazard-multiplier', 'gaussian-shift'"
716                ));
717            }
718        };
719        Ok(Some(frailty))
720    } else if frailty_sd.is_some() || hazard_loading.is_some() {
721        Err("frailty_kind is required when frailty_sd or hazard_loading is provided".to_string())
722    } else {
723        Ok(None)
724    }
725}
726
727fn cli_hazard_loading(loading: CliHazardLoading) -> HazardLoading {
728    match loading {
729        CliHazardLoading::Full => HazardLoading::Full,
730        CliHazardLoading::LoadedVsUnloaded => HazardLoading::LoadedVsUnloaded,
731    }
732}
733
734fn parse_group_metadata(
735    direct: Option<GroupMetadata>,
736    groups: Option<JsonValue>,
737) -> Result<Option<GroupMetadata>, String> {
738    match (direct, groups) {
739        (Some(metadata), None) => Ok(nonempty_group_metadata(metadata)),
740        (None, Some(groups)) => group_metadata_from_groups(groups),
741        (None, None) => Ok(None),
742        (Some(_), Some(_)) => {
743            Err("fit config accepts either group_metadata or groups metadata, not both".to_string())
744        }
745    }
746}
747
748fn parse_gamma_pair_value(label: &str, value: JsonValue) -> Result<(String, f64, f64), String> {
749    match value {
750        JsonValue::Array(values) => {
751            if values.len() != 2 {
752                return Err(format!(
753                    "precision_hyperpriors['{label}'] must be [shape, rate]"
754                ));
755            }
756            let shape = values[0]
757                .as_f64()
758                .ok_or_else(|| format!("precision_hyperpriors['{label}'][0] must be numeric"))?;
759            let rate = values[1]
760                .as_f64()
761                .ok_or_else(|| format!("precision_hyperpriors['{label}'][1] must be numeric"))?;
762            Ok((label.to_string(), shape, rate))
763        }
764        JsonValue::Object(mut map) => {
765            let shape = map
766                .remove("shape")
767                .or_else(|| map.remove("a"))
768                .or_else(|| map.remove("a_p"))
769                .ok_or_else(|| format!("precision_hyperpriors['{label}'] missing shape/a"))?
770                .as_f64()
771                .ok_or_else(|| {
772                    format!("precision_hyperpriors['{label}'] shape/a must be numeric")
773                })?;
774            let rate = map
775                .remove("rate")
776                .or_else(|| map.remove("b"))
777                .or_else(|| map.remove("b_p"))
778                .ok_or_else(|| format!("precision_hyperpriors['{label}'] missing rate/b"))?
779                .as_f64()
780                .ok_or_else(|| {
781                    format!("precision_hyperpriors['{label}'] rate/b must be numeric")
782                })?;
783            Ok((label.to_string(), shape, rate))
784        }
785        _ => Err(format!(
786            "precision_hyperpriors['{label}'] must be [shape, rate] or an object"
787        )),
788    }
789}
790
791fn parse_precision_hyperpriors(
792    precision_hyperpriors: Option<JsonValue>,
793    penalty_block_gamma_priors: Option<JsonValue>,
794) -> Result<Vec<(String, f64, f64)>, String> {
795    let raw = match (precision_hyperpriors, penalty_block_gamma_priors) {
796        (Some(_), Some(_)) => {
797            return Err(
798                "fit config accepts either precision_hyperpriors or penalty_block_gamma_priors, not both"
799                    .to_string(),
800            );
801        }
802        (Some(raw), None) | (None, Some(raw)) => raw,
803        (None, None) => {
804            return Ok(Vec::new());
805        }
806    };
807    let raw_name = "precision_hyperpriors";
808    let Some(raw) = (match raw {
809        JsonValue::Null => None,
810        other => Some(other),
811    }) else {
812        return Ok(Vec::new());
813    };
814    match raw {
815        JsonValue::Object(map) => map
816            .into_iter()
817            .map(|(label, value)| parse_gamma_pair_value(&label, value))
818            .collect(),
819        JsonValue::Array(items) => items
820            .into_iter()
821            .enumerate()
822            .map(|(idx, item)| match item {
823                JsonValue::Object(mut obj) => {
824                    let label = obj
825                        .remove("label")
826                        .or_else(|| obj.remove("name"))
827                        .or_else(|| obj.remove("group"))
828                        .ok_or_else(|| format!("{raw_name}[{idx}] needs label/name/group"))?;
829                    let JsonValue::String(label) = label else {
830                        return Err(format!("{raw_name}[{idx}] label must be a string"));
831                    };
832                    parse_gamma_pair_value(&label, JsonValue::Object(obj))
833                }
834                JsonValue::Array(mut values) => {
835                    if values.len() != 2 && values.len() != 3 {
836                        return Err(format!(
837                            "{raw_name}[{idx}] must be [label, shape, rate] or [label, [shape, rate]]"
838                        ));
839                    }
840                    let label = values.remove(0);
841                    let JsonValue::String(label) = label else {
842                        return Err(format!("{raw_name}[{idx}][0] must be a string label"));
843                    };
844                    let pair = if values.len() == 1 {
845                        values.remove(0)
846                    } else {
847                        JsonValue::Array(values)
848                    };
849                    parse_gamma_pair_value(&label, pair)
850                }
851                _ => Err(format!("{raw_name}[{idx}] must be an object or array")),
852            })
853            .collect(),
854        _ => Err(format!("{raw_name} must be a map or array")),
855    }
856}
857
858fn nonempty_group_metadata(metadata: GroupMetadata) -> Option<GroupMetadata> {
859    if metadata.is_empty() {
860        None
861    } else {
862        Some(metadata)
863    }
864}
865
866fn group_metadata_from_groups(groups: JsonValue) -> Result<Option<GroupMetadata>, String> {
867    match groups {
868        JsonValue::Null => Ok(None),
869        JsonValue::Object(map) => {
870            let out = map.into_iter().collect::<BTreeMap<_, _>>();
871            Ok(nonempty_group_metadata(out))
872        }
873        JsonValue::Array(items) => {
874            let mut out = BTreeMap::new();
875            for (idx, item) in items.into_iter().enumerate() {
876                let JsonValue::Object(mut group) = item else {
877                    return Err(format!("groups[{idx}] must be an object"));
878                };
879                let Some(metadata) = group.remove("metadata") else {
880                    continue;
881                };
882                let name = group
883                    .remove("name")
884                    .or_else(|| group.remove("id"))
885                    .or_else(|| group.remove("key"))
886                    .ok_or_else(|| {
887                        format!(
888                            "groups[{idx}] with metadata must include a string name, id, or key"
889                        )
890                    })?;
891                let JsonValue::String(name) = name else {
892                    return Err(format!("groups[{idx}] name/id/key must be a string"));
893                };
894                if name.is_empty() {
895                    return Err(format!("groups[{idx}] name/id/key must be non-empty"));
896                }
897                if out.insert(name.clone(), metadata).is_some() {
898                    return Err(format!("duplicate group metadata key '{name}'"));
899                }
900            }
901            Ok(nonempty_group_metadata(out))
902        }
903        _ => Err("groups must be an object map or an array of group objects".to_string()),
904    }
905}
906
907fn parse_gpu_policy(raw_gpu: &str) -> Result<gam_gpu::GpuPolicy, String> {
908    gam_gpu::GpuPolicy::parse(raw_gpu).ok_or_else(|| {
909        format!(
910            "invalid gpu policy '{}'; supported values are auto, off, force",
911            raw_gpu
912        )
913    })
914}
915
916fn validate_resolved_fit_config(config: &FitConfig) -> Result<(), String> {
917    if !config.ridge_lambda.is_finite() || config.ridge_lambda < 0.0 {
918        return Err("--ridge-lambda must be finite and >= 0".to_string());
919    }
920    let likelihood_mode = parse_survival_likelihood_mode(&config.survival_likelihood)?;
921    validate_survival_baseline_args(
922        likelihood_mode,
923        &config.baseline_target,
924        config.baseline_scale,
925        config.baseline_shape,
926        config.baseline_rate,
927        config.baseline_makeham,
928    )
929}
930
/// Returns a fixed, human-readable hint that enumerates the link spellings
/// accepted for survival models.
fn survival_link_usage() -> &'static str {
    const USAGE: &str = "use identity|logit|probit|cloglog|sas|beta-logistic|blended(...)/mixture(...) or flexible(...)";
    USAGE
}
934
#[cfg(test)]
mod tests {
    // Unit tests for fit-config resolution. The parity tests check that a
    // CLI-shaped input and the equivalent JSON wire config resolve to the same
    // `FitConfig`, or fail with the same error message. The remaining tests
    // cover the individual parsing and validation helpers.
    use super::*;
    use gam_models::survival::lognormal_kernel::FrailtySpec;
    use serde_json::{Value, json};

    /// One CLI-vs-JSON comparison: both inputs are expected to resolve identically.
    struct ParityCase {
        name: &'static str,
        cli: CliFitConfigInput,
        json: Value,
    }

    /// CLI input with every option at its default, matching an empty JSON config.
    fn base_cli() -> CliFitConfigInput {
        CliFitConfigInput {
            family: None,
            negative_binomial_theta: None,
            link: None,
            flexible_link: false,
            offset_column: None,
            weight_column: None,
            noise_offset_column: None,
            baseline_target: "linear".to_string(),
            baseline_scale: None,
            baseline_shape: None,
            baseline_rate: None,
            baseline_makeham: None,
            time_basis: "ispline".to_string(),
            time_degree: 3,
            time_num_internal_knots: 8,
            time_smooth_lambda: 1e-2,
            survival_likelihood: "transformation".to_string(),
            survival_distribution: "gaussian".to_string(),
            threshold_time_k: None,
            threshold_time_degree: 3,
            sigma_time_k: None,
            sigma_time_degree: 3,
            noise_formula: None,
            logslope_formula: None,
            z_column: None,
            scale_dimensions: false,
            adaptive_regularization: None,
            ridge_lambda: 1e-6,
            transformation_normal: false,
            firth: false,
            outer_max_iter: None,
            gpu: None,
            frailty_kind: None,
            frailty_sd: None,
            hazard_loading: None,
        }
    }

    fn resolved_cli(input: CliFitConfigInput) -> Result<FitConfig, String> {
        resolve_cli_fit_config(input)
    }

    /// Resolves a JSON wire config. None of the parity cases name a training
    /// table, so that field is asserted to be absent.
    fn resolved_json(config: Value) -> Result<FitConfig, String> {
        parse_fit_config_json(Some(&config.to_string())).map(|resolved| {
            assert_eq!(resolved.training_table_kind, None);
            resolved.fit_config
        })
    }

    /// Renders a config for comparison. An unset frailty and an explicit
    /// `FrailtySpec::None` are treated as equivalent.
    fn canonical_fit_config(mut config: FitConfig) -> String {
        if config.frailty.is_none() {
            config.frailty = Some(FrailtySpec::None);
        }
        format!("{config:#?}")
    }

    #[test]
    fn cli_shaped_and_json_wire_config_resolution_match() {
        let cases = vec![
            ParityCase {
                name: "family and link selection",
                cli: {
                    let mut input = base_cli();
                    input.family = Some("binomial".to_string());
                    input.link = Some("probit".to_string());
                    input.flexible_link = true;
                    input
                },
                json: json!({
                    "family": "binomial",
                    "link": "probit",
                    "flexible_link": true
                }),
            },
            ParityCase {
                name: "offset weights ridge and noise offset columns",
                cli: {
                    let mut input = base_cli();
                    input.offset_column = Some("eta_offset".to_string());
                    input.weight_column = Some("case_weight".to_string());
                    input.noise_offset_column = Some("sigma_offset".to_string());
                    input.ridge_lambda = 0.125;
                    input
                },
                json: json!({
                    "offset": "eta_offset",
                    "weights": "case_weight",
                    "noise_offset": "sigma_offset",
                    "ridge_lambda": 0.125
                }),
            },
            ParityCase {
                name: "weibull survival likelihood and baseline scale shape",
                cli: {
                    let mut input = base_cli();
                    input.survival_likelihood = "weibull".to_string();
                    input.baseline_target = "weibull".to_string();
                    input.baseline_scale = Some(2.5);
                    input.baseline_shape = Some(1.75);
                    input
                },
                json: json!({
                    "survival_likelihood": "weibull",
                    "baseline_target": "weibull",
                    "baseline_scale": 2.5,
                    "baseline_shape": 1.75
                }),
            },
            ParityCase {
                name: "transformation survival gompertz makeham baseline",
                cli: {
                    let mut input = base_cli();
                    input.survival_likelihood = "transformation".to_string();
                    input.baseline_target = "gompertz-makeham".to_string();
                    input.baseline_shape = Some(1.2);
                    input.baseline_rate = Some(0.04);
                    input.baseline_makeham = Some(0.01);
                    input
                },
                json: json!({
                    "survival_likelihood": "transformation",
                    "baseline_target": "gompertz-makeham",
                    "baseline_shape": 1.2,
                    "baseline_rate": 0.04,
                    "baseline_makeham": 0.01
                }),
            },
            ParityCase {
                name: "survival likelihood values are canonicalized",
                cli: {
                    let mut input = base_cli();
                    input.survival_likelihood = "TRANSFORMATION".to_string();
                    input
                },
                json: json!({
                    "survival_likelihood": "Transformation"
                }),
            },
            ParityCase {
                name: "noise formula logslope z column and scale dimensions",
                cli: {
                    let mut input = base_cli();
                    input.noise_formula = Some("~ s(age) + treatment".to_string());
                    input.logslope_formula = Some("~ s(dose)".to_string());
                    input.z_column = Some("dose".to_string());
                    input.scale_dimensions = true;
                    input
                },
                json: json!({
                    "noise_formula": "~ s(age) + treatment",
                    "logslope_formula": "~ s(dose)",
                    "z_column": "dose",
                    "scale_dimensions": true
                }),
            },
            ParityCase {
                name: "firth transformation normal outer iterations and adaptive regularization",
                cli: {
                    let mut input = base_cli();
                    input.firth = true;
                    input.transformation_normal = true;
                    input.outer_max_iter = Some(7);
                    input.adaptive_regularization = Some(true);
                    input
                },
                json: json!({
                    "firth": true,
                    "transformation_normal": true,
                    "outer_max_iter": 7,
                    "adaptive_regularization": true
                }),
            },
            ParityCase {
                name: "gpu policy toggle",
                cli: {
                    let mut input = base_cli();
                    input.gpu = Some("off".to_string());
                    input
                },
                json: json!({
                    "gpu": "off"
                }),
            },
            ParityCase {
                name: "hazard multiplier frailty fields",
                cli: {
                    let mut input = base_cli();
                    input.frailty_kind = Some(CliFrailtyKind::HazardMultiplier);
                    input.frailty_sd = Some(0.35);
                    input.hazard_loading = Some(CliHazardLoading::LoadedVsUnloaded);
                    input
                },
                json: json!({
                    "frailty_kind": "hazard-multiplier",
                    "frailty_sd": 0.35,
                    "hazard_loading": "loaded-vs-unloaded"
                }),
            },
            ParityCase {
                name: "gaussian shift frailty fields",
                cli: {
                    let mut input = base_cli();
                    input.frailty_kind = Some(CliFrailtyKind::GaussianShift);
                    input.frailty_sd = Some(0.2);
                    input
                },
                json: json!({
                    "frailty_kind": "gaussian-shift",
                    "frailty_sd": 0.2
                }),
            },
        ];

        for case in cases {
            let cli = resolved_cli(case.cli)
                .unwrap_or_else(|err| panic!("{}: CLI-shaped config failed: {err}", case.name));
            let json = resolved_json(case.json)
                .unwrap_or_else(|err| panic!("{}: JSON wire config failed: {err}", case.name));
            assert_eq!(
                canonical_fit_config(cli),
                canonical_fit_config(json),
                "{}",
                case.name
            );
        }
    }

    #[test]
    fn cli_shaped_and_json_wire_config_resolution_rejections_match() {
        // Both front ends must reject these inputs with byte-identical messages.
        let cases = vec![
            ParityCase {
                name: "negative ridge lambda",
                cli: {
                    let mut input = base_cli();
                    input.ridge_lambda = -1.0;
                    input
                },
                json: json!({
                    "ridge_lambda": -1.0
                }),
            },
            ParityCase {
                name: "unknown gpu policy",
                cli: {
                    let mut input = base_cli();
                    input.gpu = Some("cuda".to_string());
                    input
                },
                json: json!({
                    "gpu": "cuda"
                }),
            },
            ParityCase {
                name: "linear baseline rejects shape",
                cli: {
                    let mut input = base_cli();
                    input.baseline_shape = Some(1.1);
                    input
                },
                json: json!({
                    "baseline_shape": 1.1
                }),
            },
            ParityCase {
                name: "weibull likelihood rejects gompertz target",
                cli: {
                    let mut input = base_cli();
                    input.survival_likelihood = "weibull".to_string();
                    input.baseline_target = "gompertz".to_string();
                    input
                },
                json: json!({
                    "survival_likelihood": "weibull",
                    "baseline_target": "gompertz"
                }),
            },
        ];

        for case in cases {
            let cli = resolved_cli(case.cli).expect_err(case.name);
            let json = resolved_json(case.json).expect_err(case.name);
            assert_eq!(cli, json, "{}", case.name);
        }
    }

    // ── parse_comma_f64 ───────────────────────────────────────────────────

    #[test]
    fn parse_comma_f64_empty_string_returns_empty_vec() {
        assert_eq!(parse_comma_f64("", "x").unwrap(), Vec::<f64>::new());
        assert_eq!(parse_comma_f64("   ", "x").unwrap(), Vec::<f64>::new());
    }

    #[test]
    fn parse_comma_f64_single_value() {
        // Use a value that does not approximate a math constant: a literal like
        // 3.14 trips clippy's deny-by-default `approx_constant` lint.
        assert_eq!(parse_comma_f64("3.25", "x").unwrap(), vec![3.25]);
    }

    #[test]
    fn parse_comma_f64_multiple_values_with_spaces() {
        let result = parse_comma_f64("1.0, 2.5, -3.0", "x").unwrap();
        assert_eq!(result, vec![1.0, 2.5, -3.0]);
    }

    #[test]
    fn parse_comma_f64_non_numeric_returns_error() {
        let err = parse_comma_f64("1.0, bad, 3.0", "--vals").unwrap_err();
        assert!(err.contains("--vals"), "error should name the label: {err}");
        assert!(err.contains("bad"), "error should name the bad token: {err}");
    }

    #[test]
    fn parse_comma_f64_infinity_returns_error() {
        let err = parse_comma_f64("inf", "--vals").unwrap_err();
        assert!(err.contains("non-finite"), "error should say non-finite: {err}");
    }

    #[test]
    fn parse_comma_f64_nan_returns_error() {
        let err = parse_comma_f64("nan", "--vals").unwrap_err();
        assert!(err.contains("non-finite"), "error should say non-finite: {err}");
    }

    // ── normalize_optional_family ─────────────────────────────────────────

    #[test]
    fn normalize_optional_family_none_passthrough() {
        assert_eq!(normalize_optional_family(None), None);
    }

    #[test]
    fn normalize_optional_family_auto_becomes_none() {
        assert_eq!(normalize_optional_family(Some("auto".to_string())), None);
        assert_eq!(normalize_optional_family(Some("Auto".to_string())), None);
        assert_eq!(normalize_optional_family(Some("AUTO".to_string())), None);
    }

    #[test]
    fn normalize_optional_family_non_auto_passthrough() {
        assert_eq!(
            normalize_optional_family(Some("binomial".to_string())),
            Some("binomial".to_string())
        );
        assert_eq!(
            normalize_optional_family(Some("gaussian".to_string())),
            Some("gaussian".to_string())
        );
    }

    // ── parse_survival_likelihood_cli ─────────────────────────────────────

    #[test]
    fn parse_survival_likelihood_cli_valid_values() {
        assert_eq!(
            parse_survival_likelihood_cli("transformation").unwrap(),
            "transformation"
        );
        assert_eq!(
            parse_survival_likelihood_cli("weibull").unwrap(),
            "weibull"
        );
        // case-insensitive
        assert_eq!(
            parse_survival_likelihood_cli("WEIBULL").unwrap(),
            "weibull"
        );
        assert_eq!(
            parse_survival_likelihood_cli("Transformation").unwrap(),
            "transformation"
        );
    }

    #[test]
    fn parse_survival_likelihood_cli_invalid_returns_error() {
        assert!(parse_survival_likelihood_cli("lognormal").is_err());
        assert!(parse_survival_likelihood_cli("").is_err());
    }

    // ── parse_baseline_target_cli ─────────────────────────────────────────

    #[test]
    fn parse_baseline_target_cli_valid_values() {
        for target in &["linear", "weibull", "gompertz", "gompertz-makeham"] {
            assert_eq!(
                parse_baseline_target_cli(target).unwrap(),
                *target,
                "should accept '{target}'"
            );
        }
        // trimmed and lowercased
        assert_eq!(
            parse_baseline_target_cli("  Weibull  ").unwrap(),
            "weibull"
        );
    }

    #[test]
    fn parse_baseline_target_cli_invalid_returns_error() {
        let err = parse_baseline_target_cli("cox").unwrap_err();
        assert!(err.contains("cox"), "error should name the bad value: {err}");
    }

    // ── validate_survival_baseline_args ───────────────────────────────────
    // Argument order: mode, target, scale, shape, rate, makeham.

    #[test]
    fn validate_survival_baseline_args_linear_rejects_params() {
        let mode = parse_survival_likelihood_mode("transformation").unwrap();
        assert!(validate_survival_baseline_args(
            mode,
            "linear",
            Some(1.0),
            None,
            None,
            None
        )
        .is_err());
    }

    #[test]
    fn validate_survival_baseline_args_linear_accepts_no_params() {
        let mode = parse_survival_likelihood_mode("transformation").unwrap();
        assert!(validate_survival_baseline_args(
            mode, "linear", None, None, None, None
        )
        .is_ok());
    }

    #[test]
    fn validate_survival_baseline_args_weibull_likelihood_rejects_gompertz_target() {
        let mode = parse_survival_likelihood_mode("weibull").unwrap();
        assert!(validate_survival_baseline_args(
            mode,
            "gompertz",
            None,
            None,
            None,
            None
        )
        .is_err());
    }

    #[test]
    fn validate_survival_baseline_args_gompertz_rejects_scale() {
        let mode = parse_survival_likelihood_mode("transformation").unwrap();
        assert!(validate_survival_baseline_args(
            mode,
            "gompertz",
            Some(2.0),
            None,
            None,
            None
        )
        .is_err());
    }

    #[test]
    fn validate_survival_baseline_args_gompertz_makeham_rejects_scale() {
        let mode = parse_survival_likelihood_mode("transformation").unwrap();
        assert!(validate_survival_baseline_args(
            mode,
            "gompertz-makeham",
            Some(1.0),
            None,
            None,
            None
        )
        .is_err());
    }
}