Skip to main content

gam_config/
lib.rs

1use gam_models::survival::location_scale::residual_distribution_inverse_link;
2use gam_models::survival::lognormal_kernel::{FrailtySpec, HazardLoading};
3use gam_models::survival::parse_survival_distribution;
4use gam_models::survival::{SurvivalLikelihoodMode, parse_survival_likelihood_mode};
5use gam_inference::formula_dsl::parse_link_choice;
6use gam_inference::model::GroupMetadata;
7use gam_solve::mixture_link::{state_from_beta_logisticspec, state_from_sasspec, state_fromspec};
8use gam_models::fit_orchestration::descriptors::build_analytic_penalty_registry_from_descriptors;
9use gam_models::fit_orchestration::{CtnStage1Recipe, FitConfig};
10use gam_models::transformation_normal::TransformationNormalConfig;
11use gam_problem::types::{InverseLink, LinkFunction, MixtureLinkSpec, SasLinkSpec, StandardLink};
12use ndarray::Array1;
13use serde::Deserialize;
14use serde_json::Value as JsonValue;
15use std::collections::BTreeMap;
16
/// Frailty family chosen on the command line; turned into a `FrailtySpec` by
/// `resolve_cli_frailty_spec`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CliFrailtyKind {
    /// Becomes `FrailtySpec::GaussianShift`; `--hazard-loading` is rejected with this kind.
    GaussianShift,
    /// Becomes `FrailtySpec::HazardMultiplier`; `--hazard-loading` is required with this kind.
    HazardMultiplier,
}
22
/// Hazard-loading mode chosen on the command line; converted to `HazardLoading` by
/// `cli_hazard_loading`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CliHazardLoading {
    /// Converted to `HazardLoading::Full`.
    Full,
    /// Converted to `HazardLoading::LoadedVsUnloaded`.
    LoadedVsUnloaded,
}
28
/// Wire schema of the JSON fit config accepted by `parse_fit_config_json`.
///
/// Every key is optional (`Default` yields an all-`None` config), and unknown keys are
/// rejected by `deny_unknown_fields`. Values are translated into a `FitConfig` by
/// `resolve_json_fit_config`.
#[derive(Default, Deserialize)]
#[serde(deny_unknown_fields)]
struct JsonFitConfig {
    /// Model family name; `"auto"` is mapped to `None` by `normalize_optional_family`.
    family: Option<String>,
    /// Offset column name, stored as `FitConfig::offset_column`.
    offset: Option<String>,
    /// Weight column name, stored as `FitConfig::weight_column`.
    weights: Option<String>,
    ridge_lambda: Option<f64>,
    transformation_normal: Option<bool>,
    /// Validated and lowercased by `parse_survival_likelihood_cli`.
    survival_likelihood: Option<String>,
    /// Must be non-empty after trimming.
    baseline_target: Option<String>,
    baseline_scale: Option<f64>,
    baseline_shape: Option<f64>,
    baseline_rate: Option<f64>,
    baseline_makeham: Option<f64>,
    /// Must be non-empty after trimming.
    z_column: Option<String>,
    logslope_formula: Option<String>,
    /// Converted into a `CtnStage1Recipe` via `JsonCtnStage1::into_recipe`.
    ctn_stage1: Option<JsonCtnStage1>,
    /// Trimmed; an empty string clears the link (`None`).
    link: Option<String>,
    flexible_link: Option<bool>,
    scale_dimensions: Option<bool>,
    adaptive_regularization: Option<bool>,
    noise_formula: Option<String>,
    /// Stored as `FitConfig::noise_offset_column`.
    noise_offset: Option<String>,
    firth: Option<bool>,
    /// Must be >= 1 when present.
    outer_max_iter: Option<usize>,
    /// Parsed by `parse_gpu_policy`.
    gpu: Option<String>,
    /// Mutually exclusive with `groups` (see `parse_group_metadata`).
    group_metadata: Option<GroupMetadata>,
    /// Raw group description; mutually exclusive with `group_metadata`.
    groups: Option<JsonValue>,
    /// Mutually exclusive with `penalty_block_gamma_priors` (see `parse_precision_hyperpriors`).
    precision_hyperpriors: Option<JsonValue>,
    /// Alternative key for the same priors as `precision_hyperpriors`.
    penalty_block_gamma_priors: Option<JsonValue>,
    /// Validated together with `penalties` by the analytic penalty registry builder.
    latents: Option<JsonValue>,
    penalties: Option<JsonValue>,
    /// Stored verbatim as `FitConfig::smooth_overrides`.
    smooths: Option<JsonValue>,
    /// Parsed by `TopologyAutoSelector::from_json` when present.
    topology_auto_selector: Option<JsonValue>,
    /// One of `none`, `hazard-multiplier`, `gaussian-shift` (see `parse_json_frailty_spec`).
    frailty_kind: Option<String>,
    /// Must be finite and >= 0; requires `frailty_kind`.
    frailty_sd: Option<f64>,
    /// `full` or `loaded-vs-unloaded`; only valid with `frailty_kind = hazard-multiplier`.
    hazard_loading: Option<String>,
    /// Passed through unvalidated into `ResolvedFitConfig::training_table_kind`.
    training_table_kind: Option<String>,
}
68
/// JSON description of a CTN stage-1 fit, converted into a `CtnStage1Recipe` by
/// `JsonCtnStage1::into_recipe`. Unknown keys are rejected.
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct JsonCtnStage1 {
    /// Response column passed to `CtnStage1Recipe::new`.
    response_column: String,
    /// Right-hand side of the stage-1 covariate formula.
    covariate_formula_rhs: String,
    /// Optional overrides applied on top of `TransformationNormalConfig::default()`.
    #[serde(default)]
    config: Option<JsonCtnStage1Config>,
    /// Optional weight column for stage 1.
    #[serde(default)]
    weight_column: Option<String>,
    /// Optional offset column for stage 1.
    #[serde(default)]
    offset_column: Option<String>,
}
81
/// Per-field overrides for `TransformationNormalConfig`. Each `Some` value replaces the
/// corresponding default in `JsonCtnStage1::into_recipe`; `None` keeps the default.
/// Unknown keys are rejected.
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct JsonCtnStage1Config {
    #[serde(default)]
    response_degree: Option<usize>,
    #[serde(default)]
    response_num_internal_knots: Option<usize>,
    #[serde(default)]
    response_penalty_order: Option<usize>,
    /// Replaces the whole default list when present (not appended).
    #[serde(default)]
    response_extra_penalty_orders: Option<Vec<usize>>,
    #[serde(default)]
    double_penalty: Option<bool>,
}
96
97impl JsonCtnStage1 {
98    fn into_recipe(self) -> Result<CtnStage1Recipe, String> {
99        let mut config = TransformationNormalConfig::default();
100        if let Some(overrides) = self.config {
101            if let Some(value) = overrides.response_degree {
102                config.response_degree = value;
103            }
104            if let Some(value) = overrides.response_num_internal_knots {
105                config.response_num_internal_knots = value;
106            }
107            if let Some(value) = overrides.response_penalty_order {
108                config.response_penalty_order = value;
109            }
110            if let Some(value) = overrides.response_extra_penalty_orders {
111                config.response_extra_penalty_orders = value;
112            }
113            if let Some(value) = overrides.double_penalty {
114                config.double_penalty = value;
115            }
116        }
117        CtnStage1Recipe::new(
118            &self.response_column,
119            &self.covariate_formula_rhs,
120            config,
121            self.weight_column.as_deref(),
122            self.offset_column.as_deref(),
123        )
124    }
125}
126
/// Result of `parse_fit_config_json`: the validated fit config plus the JSON-only
/// `training_table_kind` key, which is passed through without validation.
pub struct ResolvedFitConfig {
    /// Fully resolved and validated fit configuration.
    pub fit_config: FitConfig,
    /// Raw `training_table_kind` value from the JSON, if any.
    pub training_table_kind: Option<String>,
}
131
/// Already-parsed command-line arguments consumed by `resolve_cli_fit_config`.
///
/// Most fields are copied straight into the matching `FitConfig` field. The exceptions are:
/// - `survival_likelihood`, which is validated and lowercased;
/// - `gpu`, which is parsed into a GPU policy;
/// - the three frailty fields, which are combined into a `FrailtySpec`.
pub struct CliFitConfigInput {
    pub family: Option<String>,
    /// Expectile asymmetry `τ ∈ (0, 1)` for `--family expectile`. The CLI family
    /// flag is a fixed value-enum that cannot carry the inline `expectile(τ)`
    /// asymmetry the string form accepts, so the CLI pins `τ` through this
    /// separate `--expectile-tau` field; it rides into `FitConfig::expectile_tau`
    /// and is read by the shared expectile dispatch seam (#1777).
    pub expectile_tau: Option<f64>,
    pub negative_binomial_theta: Option<f64>,
    pub link: Option<String>,
    pub flexible_link: bool,
    pub offset_column: Option<String>,
    pub weight_column: Option<String>,
    pub noise_offset_column: Option<String>,
    /// Copied verbatim; presumably already canonicalized by `parse_baseline_target_cli`
    /// at argument-parsing time — TODO confirm with the CLI definition.
    pub baseline_target: String,
    pub baseline_scale: Option<f64>,
    pub baseline_shape: Option<f64>,
    pub baseline_rate: Option<f64>,
    pub baseline_makeham: Option<f64>,
    pub time_basis: String,
    pub time_degree: usize,
    pub time_num_internal_knots: usize,
    pub time_smooth_lambda: f64,
    /// Validated and lowercased with `parse_survival_likelihood_cli`.
    pub survival_likelihood: String,
    pub survival_distribution: String,
    pub threshold_time_k: Option<usize>,
    pub threshold_time_degree: usize,
    pub sigma_time_k: Option<usize>,
    pub sigma_time_degree: usize,
    pub noise_formula: Option<String>,
    pub logslope_formula: Option<String>,
    pub z_column: Option<String>,
    pub scale_dimensions: bool,
    pub adaptive_regularization: Option<bool>,
    pub ridge_lambda: f64,
    pub transformation_normal: bool,
    pub firth: bool,
    pub outer_max_iter: Option<usize>,
    /// Parsed with `parse_gpu_policy` when present.
    pub gpu: Option<String>,
    /// Combined with `frailty_sd` and `hazard_loading` by `resolve_cli_frailty_spec`.
    pub frailty_kind: Option<CliFrailtyKind>,
    /// Must be finite and >= 0 when present; requires `frailty_kind`.
    pub frailty_sd: Option<f64>,
    /// Required with `HazardMultiplier`, rejected with `GaussianShift`.
    pub hazard_loading: Option<CliHazardLoading>,
}
175
/// Raw link-related strings consumed by `parse_survival_inverse_link`.
pub struct SurvivalInverseLinkInput<'a> {
    /// Link spec as typed by the user (parsed with `parse_link_choice`).
    pub link: Option<&'a str>,
    /// Comma-separated initial mixture weights; only valid for blended/mixture links.
    pub mixture_rho: Option<&'a str>,
    /// `epsilon,log_delta` initial values; only valid with `--link sas`.
    pub sas_init: Option<&'a str>,
    /// `epsilon,delta` initial values; only valid with `--link beta-logistic`.
    pub beta_logistic_init: Option<&'a str>,
    /// Residual distribution used to derive the link when no link choice is parsed.
    pub survival_distribution: &'a str,
}
183
184pub fn parse_fit_config_json(config_json: Option<&str>) -> Result<ResolvedFitConfig, String> {
185    let json_config = match config_json {
186        Some(raw) if !raw.trim().is_empty() => serde_json::from_str::<JsonFitConfig>(raw)
187            .map_err(|err| format!("invalid fit config json: {err}"))?,
188        _ => JsonFitConfig::default(),
189    };
190    resolve_json_fit_config(json_config)
191}
192
193fn resolve_json_fit_config(json_config: JsonFitConfig) -> Result<ResolvedFitConfig, String> {
194    let training_table_kind = json_config.training_table_kind;
195    let mut fit_config = FitConfig::default();
196    fit_config.group_metadata =
197        parse_group_metadata(json_config.group_metadata, json_config.groups)?;
198    fit_config.penalty_block_gamma_priors = parse_precision_hyperpriors(
199        json_config.precision_hyperpriors,
200        json_config.penalty_block_gamma_priors,
201    )?;
202    let analytic_penalties = json_config.penalties;
203    build_analytic_penalty_registry_from_descriptors(
204        json_config.latents.as_ref(),
205        analytic_penalties.as_ref(),
206    )?;
207    fit_config.latents = json_config.latents;
208    fit_config.analytic_penalties = analytic_penalties;
209    fit_config.smooth_overrides = json_config.smooths;
210    fit_config.topology_auto_selector = json_config
211        .topology_auto_selector
212        .as_ref()
213        .map(gam_solve::topology_selector::TopologyAutoSelector::from_json)
214        .transpose()?;
215    fit_config.family = normalize_optional_family(json_config.family);
216    fit_config.offset_column = json_config.offset;
217    fit_config.weight_column = json_config.weights;
218    if let Some(ridge_lambda) = json_config.ridge_lambda {
219        fit_config.ridge_lambda = ridge_lambda;
220    }
221    if let Some(flag) = json_config.transformation_normal {
222        fit_config.transformation_normal = flag;
223    }
224    if let Some(mode) = json_config.survival_likelihood {
225        fit_config.survival_likelihood = parse_survival_likelihood_cli(&mode)?;
226    }
227    if let Some(target) = json_config.baseline_target {
228        fit_config.baseline_target =
229            resolve_nonempty_string(target, "baseline_target must be a non-empty string")?;
230    }
231    if let Some(value) = json_config.baseline_scale {
232        fit_config.baseline_scale = Some(value);
233    }
234    if let Some(value) = json_config.baseline_shape {
235        fit_config.baseline_shape = Some(value);
236    }
237    if let Some(value) = json_config.baseline_rate {
238        fit_config.baseline_rate = Some(value);
239    }
240    if let Some(value) = json_config.baseline_makeham {
241        fit_config.baseline_makeham = Some(value);
242    }
243    if let Some(z) = json_config.z_column {
244        fit_config.z_column = Some(resolve_nonempty_string(
245            z,
246            "z_column must be a non-empty string",
247        )?);
248    }
249    if let Some(formula) = json_config.logslope_formula {
250        fit_config.logslope_formula = Some(formula);
251    }
252    if let Some(stage1) = json_config.ctn_stage1 {
253        fit_config.ctn_stage1 = Some(stage1.into_recipe()?);
254    }
255    if let Some(link) = json_config.link {
256        let trimmed = link.trim();
257        fit_config.link = if trimmed.is_empty() {
258            None
259        } else {
260            Some(trimmed.to_string())
261        };
262    }
263    if let Some(flag) = json_config.flexible_link {
264        fit_config.flexible_link = flag;
265    }
266    if let Some(flag) = json_config.scale_dimensions {
267        fit_config.scale_dimensions = flag;
268    }
269    if let Some(flag) = json_config.adaptive_regularization {
270        fit_config.adaptive_regularization = Some(flag);
271    }
272    if let Some(formula) = json_config.noise_formula {
273        fit_config.noise_formula = Some(formula);
274    }
275    if let Some(column) = json_config.noise_offset {
276        fit_config.noise_offset_column = Some(column);
277    }
278    if let Some(flag) = json_config.firth {
279        fit_config.firth = flag;
280    }
281    if let Some(value) = json_config.outer_max_iter {
282        if value == 0 {
283            return Err("outer_max_iter must be >= 1".to_string());
284        }
285        fit_config.outer_max_iter = Some(value);
286    }
287    if let Some(raw_gpu) = json_config.gpu {
288        fit_config.gpu_policy = parse_gpu_policy(&raw_gpu)?;
289    }
290    fit_config.frailty = parse_json_frailty_spec(
291        json_config.frailty_kind,
292        json_config.frailty_sd,
293        json_config.hazard_loading,
294    )?;
295    validate_resolved_fit_config(&fit_config)?;
296    Ok(ResolvedFitConfig {
297        fit_config,
298        training_table_kind,
299    })
300}
301
302pub fn resolve_cli_fit_config(input: CliFitConfigInput) -> Result<FitConfig, String> {
303    let mut fit_config = FitConfig::default();
304    fit_config.family = input.family;
305    fit_config.expectile_tau = input.expectile_tau;
306    fit_config.negative_binomial_theta = input.negative_binomial_theta;
307    fit_config.link = input.link;
308    fit_config.flexible_link = input.flexible_link;
309    fit_config.offset_column = input.offset_column;
310    fit_config.weight_column = input.weight_column;
311    fit_config.noise_offset_column = input.noise_offset_column;
312    fit_config.baseline_target = input.baseline_target;
313    fit_config.baseline_scale = input.baseline_scale;
314    fit_config.baseline_shape = input.baseline_shape;
315    fit_config.baseline_rate = input.baseline_rate;
316    fit_config.baseline_makeham = input.baseline_makeham;
317    fit_config.time_basis = input.time_basis;
318    fit_config.time_degree = input.time_degree;
319    fit_config.time_num_internal_knots = input.time_num_internal_knots;
320    fit_config.time_smooth_lambda = input.time_smooth_lambda;
321    fit_config.survival_likelihood = parse_survival_likelihood_cli(&input.survival_likelihood)?;
322    fit_config.survival_distribution = input.survival_distribution;
323    fit_config.threshold_time_k = input.threshold_time_k;
324    fit_config.threshold_time_degree = input.threshold_time_degree;
325    fit_config.sigma_time_k = input.sigma_time_k;
326    fit_config.sigma_time_degree = input.sigma_time_degree;
327    fit_config.noise_formula = input.noise_formula;
328    fit_config.logslope_formula = input.logslope_formula;
329    fit_config.z_column = input.z_column;
330    fit_config.scale_dimensions = input.scale_dimensions;
331    fit_config.adaptive_regularization = input.adaptive_regularization;
332    fit_config.ridge_lambda = input.ridge_lambda;
333    fit_config.transformation_normal = input.transformation_normal;
334    fit_config.firth = input.firth;
335    fit_config.outer_max_iter = input.outer_max_iter;
336    if let Some(raw_gpu) = input.gpu {
337        fit_config.gpu_policy = parse_gpu_policy(&raw_gpu)?;
338    }
339    fit_config.frailty = Some(resolve_cli_frailty_spec(
340        input.frailty_kind,
341        input.frailty_sd,
342        input.hazard_loading,
343        "fit",
344    )?);
345    validate_resolved_fit_config(&fit_config)?;
346    Ok(fit_config)
347}
348
349pub fn resolve_cli_frailty_spec(
350    frailty_kind: Option<CliFrailtyKind>,
351    frailty_sd: Option<f64>,
352    hazard_loading: Option<CliHazardLoading>,
353    context: &str,
354) -> Result<FrailtySpec, String> {
355    let validate_sigma = || -> Result<Option<f64>, String> {
356        match frailty_sd {
357            None => Ok(None),
358            Some(sigma) => {
359                if !sigma.is_finite() || sigma < 0.0 {
360                    return Err(format!(
361                        "{context} requires a finite --frailty-sd >= 0, got {sigma}"
362                    ));
363                }
364                Ok(Some(sigma))
365            }
366        }
367    };
368
369    match frailty_kind {
370        None => {
371            if frailty_sd.is_some() || hazard_loading.is_some() {
372                return Err(format!(
373                    "{context} requires --frailty-kind when --frailty-sd or --hazard-loading is provided"
374                ));
375            }
376            Ok(FrailtySpec::None)
377        }
378        Some(CliFrailtyKind::GaussianShift) => {
379            if hazard_loading.is_some() {
380                return Err(format!(
381                    "{context} does not accept --hazard-loading with --frailty-kind gaussian-shift"
382                ));
383            }
384            Ok(FrailtySpec::GaussianShift {
385                sigma_fixed: validate_sigma()?,
386            })
387        }
388        Some(CliFrailtyKind::HazardMultiplier) => Ok(FrailtySpec::HazardMultiplier {
389            sigma_fixed: validate_sigma()?,
390            loading: hazard_loading.map(cli_hazard_loading).ok_or_else(|| {
391                format!("{context} requires --hazard-loading with --frailty-kind hazard-multiplier")
392            })?,
393        }),
394    }
395}
396
397pub fn parse_survival_likelihood_cli(raw: &str) -> Result<String, String> {
398    let normalized = raw.trim().to_ascii_lowercase();
399    parse_survival_likelihood_mode(&normalized)?;
400    Ok(normalized)
401}
402
/// Canonicalize a `--baseline-target` value (trimmed, ASCII-lowercased) and check that it
/// is one of the supported targets.
///
/// # Errors
/// Returns an "unsupported --baseline-target" message naming the canonicalized input.
pub fn parse_baseline_target_cli(raw: &str) -> Result<String, String> {
    const SUPPORTED_TARGETS: [&str; 4] = ["linear", "weibull", "gompertz", "gompertz-makeham"];

    let target = raw.trim().to_ascii_lowercase();
    if SUPPORTED_TARGETS.contains(&target.as_str()) {
        Ok(target)
    } else {
        Err(format!(
            "unsupported --baseline-target '{target}'; use linear|weibull|gompertz|gompertz-makeham"
        ))
    }
}
412
413pub fn validate_survival_baseline_args(
414    likelihood_mode: SurvivalLikelihoodMode,
415    baseline_target: &str,
416    baseline_scale: Option<f64>,
417    baseline_shape: Option<f64>,
418    baseline_rate: Option<f64>,
419    baseline_makeham: Option<f64>,
420) -> Result<(), String> {
421    if likelihood_mode == SurvivalLikelihoodMode::Weibull {
422        if baseline_rate.is_some() || baseline_makeham.is_some() {
423            return Err(
424                "--survival-likelihood weibull does not use --baseline-rate or --baseline-makeham"
425                    .to_string(),
426            );
427        }
428        if !matches!(baseline_target, "linear" | "weibull") {
429            return Err(
430                "--survival-likelihood weibull supports only --baseline-target linear|weibull"
431                    .to_string(),
432            );
433        }
434        return Ok(());
435    }
436
437    match baseline_target {
438        "linear" => {
439            if baseline_scale.is_some()
440                || baseline_shape.is_some()
441                || baseline_rate.is_some()
442                || baseline_makeham.is_some()
443            {
444                return Err(
445                    "--baseline-target linear does not use baseline parameter flags".to_string(),
446                );
447            }
448        }
449        "weibull" => {
450            if baseline_rate.is_some() || baseline_makeham.is_some() {
451                return Err(
452                    "--baseline-target weibull does not use --baseline-rate or --baseline-makeham"
453                        .to_string(),
454                );
455            }
456        }
457        "gompertz" => {
458            if baseline_scale.is_some() || baseline_makeham.is_some() {
459                return Err(
460                    "--baseline-target gompertz does not use --baseline-scale or --baseline-makeham"
461                        .to_string(),
462                );
463            }
464        }
465        "gompertz-makeham" => {
466            if baseline_scale.is_some() {
467                return Err(
468                    "--baseline-target gompertz-makeham does not use --baseline-scale".to_string(),
469                );
470            }
471        }
472        other => {
473            return Err(format!(
474                "unsupported --baseline-target '{other}'; use linear|weibull|gompertz|gompertz-makeham"
475            ));
476        }
477    }
478    Ok(())
479}
480
/// Parse a comma-separated list of finite floats.
///
/// Surrounding whitespace is trimmed and empty entries (e.g. from `"1,,2"` or a trailing
/// comma) are skipped. Parsing stops at the first bad entry. `label` prefixes error messages.
///
/// # Errors
/// Returns a message for the first entry that is not a number, or that parses to NaN or
/// an infinity.
pub fn parse_comma_f64(v: &str, label: &str) -> Result<Vec<f64>, String> {
    v.split(',')
        .map(str::trim)
        .filter(|token| !token.is_empty())
        .map(|token| -> Result<f64, String> {
            let value: f64 = token
                .parse()
                .map_err(|err| format!("{label} contains non-numeric value '{token}': {err}"))?;
            if value.is_finite() {
                Ok(value)
            } else {
                Err(format!("{label} contains non-finite value '{token}'"))
            }
        })
        .collect()
}
498
499pub fn effective_link_to_standard(
500    link: LinkFunction,
501    context: &str,
502) -> Result<StandardLink, String> {
503    StandardLink::try_from(link).map_err(|_| {
504        format!(
505            "{context}: state-bearing link `{}` must be routed through `InverseLink::Sas` / `InverseLink::BetaLogistic`, not `Standard(_)`",
506            link.name()
507        )
508    })
509}
510
/// Resolve the survival-model inverse link from raw CLI/config strings.
///
/// Resolution order:
/// 1. `loglog` / `cauchit` are rejected up front with a survival-specific message.
/// 2. The link text is parsed with `parse_link_choice`. "Unsupported link" errors are
///    rewritten to include the survival usage text.
/// 3. A blended/mixture choice yields `InverseLink::Mixture`. Its initial rho comes from
///    `mixture_rho`, which needs one value fewer than the number of components and
///    defaults to zeros.
/// 4. SAS and Beta-Logistic choices yield stateful links. Their initial state comes from
///    `sas_init` / `beta_logistic_init` (two values each, default `(0.0, 0.0)`).
/// 5. `log` is rejected. Every other link goes through `effective_link_to_standard`.
/// 6. With no link choice, the link is derived from `survival_distribution`.
///
/// Init flags that do not belong to the chosen link are rejected rather than ignored.
///
/// # Errors
/// Returns a human-readable `String` for:
/// - an unsupported link;
/// - a mismatched init flag;
/// - a wrong value count or a non-numeric value;
/// - an invalid link state.
pub fn parse_survival_inverse_link(
    input: SurvivalInverseLinkInput<'_>,
) -> Result<InverseLink, String> {
    // Reject these before generic parsing so the user gets the survival-specific reason.
    if let Some(raw) = input.link {
        let name = raw.trim().to_ascii_lowercase();
        if name == "loglog" || name == "cauchit" {
            return Err(format!(
                "survival --link {name} is not supported: cauchit and loglog have no \
                 LinkFunction representative and cannot be wrapped in a MixtureLinkSpec; \
                 {}",
                survival_link_usage()
            ));
        }
    }
    // Re-word generic "unsupported link" errors for the survival context; other parse
    // errors are passed through unchanged.
    let choice = parse_link_choice(input.link, false).map_err(|err| {
        let err = err.to_string();
        if let Some(raw) = input.link {
            let name = raw.trim().to_ascii_lowercase();
            if err.starts_with("unsupported --link ") || err.starts_with("unsupported link type ") {
                return format!(
                    "unsupported survival --link '{name}'; {}",
                    survival_link_usage()
                );
            }
        }
        err
    })?;
    if let Some(choice) = choice {
        // Blended/mixture links: only --mixture-rho is meaningful here.
        if let Some(components) = choice.mixture_components {
            if input.sas_init.is_some() || input.beta_logistic_init.is_some() {
                return Err(
                    "survival blended(...) link does not accept --sas-init/--beta-logistic-init"
                        .to_string(),
                );
            }
            // One rho fewer than the number of components.
            let expected = components.len().saturating_sub(1);
            let initial_rho = if let Some(raw) = input.mixture_rho {
                let vals = parse_comma_f64(raw, "--mixture-rho")?;
                if vals.len() != expected {
                    return Err(format!(
                        "--mixture-rho expects {expected} values for blended({})",
                        components
                            .iter()
                            .map(|component| component.name())
                            .collect::<Vec<_>>()
                            .join(",")
                    ));
                }
                Array1::from_vec(vals)
            } else {
                Array1::zeros(expected)
            };
            return state_fromspec(&MixtureLinkSpec {
                components,
                initial_rho,
            })
            .map(InverseLink::Mixture)
            .map_err(|e| format!("invalid survival blended link state: {e}"));
        }

        if input.mixture_rho.is_some() {
            return Err(
                "--mixture-rho requires survival --link blended(...)/mixture(...)".to_string(),
            );
        }
        match choice.link {
            LinkFunction::Sas => {
                if input.beta_logistic_init.is_some() {
                    return Err("--beta-logistic-init requires --link beta-logistic".to_string());
                }
                // Two values: epsilon, log_delta; defaults to (0, 0).
                let (epsilon, log_delta) = if let Some(raw) = input.sas_init {
                    let vals = parse_comma_f64(raw, "--sas-init")?;
                    if vals.len() != 2 {
                        return Err(format!(
                            "--sas-init expects two values: epsilon,log_delta (got {})",
                            vals.len()
                        ));
                    }
                    (vals[0], vals[1])
                } else {
                    (0.0, 0.0)
                };
                state_from_sasspec(SasLinkSpec {
                    initial_epsilon: epsilon,
                    initial_log_delta: log_delta,
                })
                .map(InverseLink::Sas)
                .map_err(|e| format!("invalid survival SAS link state: {e}"))
            }
            LinkFunction::BetaLogistic => {
                if input.sas_init.is_some() {
                    return Err("--sas-init requires --link sas".to_string());
                }
                // Two values: epsilon, delta; stored in `initial_log_delta` of the shared
                // `SasLinkSpec` (NOTE(review): confirm the beta-logistic state expects the raw
                // delta here rather than its log).
                let (epsilon, delta) = if let Some(raw) = input.beta_logistic_init {
                    let vals = parse_comma_f64(raw, "--beta-logistic-init")?;
                    if vals.len() != 2 {
                        return Err(format!(
                            "--beta-logistic-init expects two values: epsilon,delta (got {})",
                            vals.len()
                        ));
                    }
                    (vals[0], vals[1])
                } else {
                    (0.0, 0.0)
                };
                state_from_beta_logisticspec(SasLinkSpec {
                    initial_epsilon: epsilon,
                    initial_log_delta: delta,
                })
                .map(InverseLink::BetaLogistic)
                .map_err(|e| format!("invalid survival Beta-Logistic link state: {e}"))
            }
            LinkFunction::Log => Err(format!(
                "unsupported survival --link 'log'; {}",
                survival_link_usage()
            )),
            other => {
                if input.sas_init.is_some() {
                    return Err("--sas-init requires --link sas".to_string());
                }
                if input.beta_logistic_init.is_some() {
                    return Err("--beta-logistic-init requires --link beta-logistic".to_string());
                }
                Ok(InverseLink::Standard(effective_link_to_standard(
                    other,
                    "survival inverse link",
                )?))
            }
        }
    } else {
        // No explicit link: all link-specific init flags are invalid, and the link
        // follows the residual distribution.
        if input.mixture_rho.is_some() {
            return Err("--mixture-rho requires --link blended(...)/mixture(...)".to_string());
        }
        if input.sas_init.is_some() {
            return Err("--sas-init requires --link sas".to_string());
        }
        if input.beta_logistic_init.is_some() {
            return Err("--beta-logistic-init requires --link beta-logistic".to_string());
        }
        let dist = parse_survival_distribution(input.survival_distribution)?;
        Ok(residual_distribution_inverse_link(dist))
    }
}
654
/// Normalize an optional family name.
///
/// Surrounding whitespace is trimmed. `"auto"` (ASCII case-insensitive) and blank
/// strings mean "let the fitter choose" and become `None`. This matches how an empty
/// `link` is treated in the JSON config. Any other value is returned trimmed.
pub fn normalize_optional_family(family: Option<String>) -> Option<String> {
    let value = family?;
    let trimmed = value.trim();
    if trimmed.is_empty() || trimmed.eq_ignore_ascii_case("auto") {
        return None;
    }
    // Reuse the original allocation when there was nothing to trim.
    if trimmed.len() == value.len() {
        Some(value)
    } else {
        Some(trimmed.to_string())
    }
}
661
/// Trim `raw` and return it, or return `message` as the error when nothing but
/// whitespace is left.
fn resolve_nonempty_string(raw: String, message: &str) -> Result<String, String> {
    match raw.trim() {
        "" => Err(message.to_string()),
        kept => Ok(kept.to_owned()),
    }
}
669
670fn parse_json_frailty_spec(
671    frailty_kind: Option<String>,
672    frailty_sd: Option<f64>,
673    hazard_loading: Option<String>,
674) -> Result<Option<FrailtySpec>, String> {
675    if let Some(kind) = frailty_kind {
676        let trimmed = kind.trim().to_ascii_lowercase();
677        let sigma = frailty_sd;
678        if let Some(value) = sigma
679            && (!value.is_finite() || value < 0.0)
680        {
681            return Err(format!("frailty_sd must be finite and >= 0, got {value}"));
682        }
683        let hazard_loading = hazard_loading
684            .as_ref()
685            .map(|raw| raw.trim().to_ascii_lowercase());
686        let frailty = match trimmed.as_str() {
687            "none" | "" => {
688                if sigma.is_some() || hazard_loading.is_some() {
689                    return Err(
690                        "frailty_kind='none' does not accept frailty_sd or hazard_loading"
691                            .to_string(),
692                    );
693                }
694                FrailtySpec::None
695            }
696            "hazard-multiplier" => {
697                let loading = match hazard_loading.as_deref() {
698                    Some("full") | None => HazardLoading::Full,
699                    Some("loaded-vs-unloaded") => HazardLoading::LoadedVsUnloaded,
700                    Some(other) => {
701                        return Err(format!(
702                            "unknown hazard_loading '{other}'; supported: 'full', 'loaded-vs-unloaded'"
703                        ));
704                    }
705                };
706                FrailtySpec::HazardMultiplier {
707                    sigma_fixed: sigma,
708                    loading,
709                }
710            }
711            "gaussian-shift" => {
712                if hazard_loading.is_some() {
713                    return Err(
714                        "hazard_loading is valid only with frailty_kind='hazard-multiplier'"
715                            .to_string(),
716                    );
717                }
718                FrailtySpec::GaussianShift { sigma_fixed: sigma }
719            }
720            other => {
721                return Err(format!(
722                    "unknown frailty_kind '{other}'; supported: 'none', 'hazard-multiplier', 'gaussian-shift'"
723                ));
724            }
725        };
726        Ok(Some(frailty))
727    } else if frailty_sd.is_some() || hazard_loading.is_some() {
728        Err("frailty_kind is required when frailty_sd or hazard_loading is provided".to_string())
729    } else {
730        Ok(None)
731    }
732}
733
734fn cli_hazard_loading(loading: CliHazardLoading) -> HazardLoading {
735    match loading {
736        CliHazardLoading::Full => HazardLoading::Full,
737        CliHazardLoading::LoadedVsUnloaded => HazardLoading::LoadedVsUnloaded,
738    }
739}
740
741fn parse_group_metadata(
742    direct: Option<GroupMetadata>,
743    groups: Option<JsonValue>,
744) -> Result<Option<GroupMetadata>, String> {
745    match (direct, groups) {
746        (Some(metadata), None) => Ok(nonempty_group_metadata(metadata)),
747        (None, Some(groups)) => group_metadata_from_groups(groups),
748        (None, None) => Ok(None),
749        (Some(_), Some(_)) => {
750            Err("fit config accepts either group_metadata or groups metadata, not both".to_string())
751        }
752    }
753}
754
755fn parse_gamma_pair_value(label: &str, value: JsonValue) -> Result<(String, f64, f64), String> {
756    match value {
757        JsonValue::Array(values) => {
758            if values.len() != 2 {
759                return Err(format!(
760                    "precision_hyperpriors['{label}'] must be [shape, rate]"
761                ));
762            }
763            let shape = values[0]
764                .as_f64()
765                .ok_or_else(|| format!("precision_hyperpriors['{label}'][0] must be numeric"))?;
766            let rate = values[1]
767                .as_f64()
768                .ok_or_else(|| format!("precision_hyperpriors['{label}'][1] must be numeric"))?;
769            Ok((label.to_string(), shape, rate))
770        }
771        JsonValue::Object(mut map) => {
772            let shape = map
773                .remove("shape")
774                .or_else(|| map.remove("a"))
775                .or_else(|| map.remove("a_p"))
776                .ok_or_else(|| format!("precision_hyperpriors['{label}'] missing shape/a"))?
777                .as_f64()
778                .ok_or_else(|| {
779                    format!("precision_hyperpriors['{label}'] shape/a must be numeric")
780                })?;
781            let rate = map
782                .remove("rate")
783                .or_else(|| map.remove("b"))
784                .or_else(|| map.remove("b_p"))
785                .ok_or_else(|| format!("precision_hyperpriors['{label}'] missing rate/b"))?
786                .as_f64()
787                .ok_or_else(|| {
788                    format!("precision_hyperpriors['{label}'] rate/b must be numeric")
789                })?;
790            Ok((label.to_string(), shape, rate))
791        }
792        _ => Err(format!(
793            "precision_hyperpriors['{label}'] must be [shape, rate] or an object"
794        )),
795    }
796}
797
798fn parse_precision_hyperpriors(
799    precision_hyperpriors: Option<JsonValue>,
800    penalty_block_gamma_priors: Option<JsonValue>,
801) -> Result<Vec<(String, f64, f64)>, String> {
802    let raw = match (precision_hyperpriors, penalty_block_gamma_priors) {
803        (Some(_), Some(_)) => {
804            return Err(
805                "fit config accepts either precision_hyperpriors or penalty_block_gamma_priors, not both"
806                    .to_string(),
807            );
808        }
809        (Some(raw), None) | (None, Some(raw)) => raw,
810        (None, None) => {
811            return Ok(Vec::new());
812        }
813    };
814    let raw_name = "precision_hyperpriors";
815    let Some(raw) = (match raw {
816        JsonValue::Null => None,
817        other => Some(other),
818    }) else {
819        return Ok(Vec::new());
820    };
821    match raw {
822        JsonValue::Object(map) => map
823            .into_iter()
824            .map(|(label, value)| parse_gamma_pair_value(&label, value))
825            .collect(),
826        JsonValue::Array(items) => items
827            .into_iter()
828            .enumerate()
829            .map(|(idx, item)| match item {
830                JsonValue::Object(mut obj) => {
831                    let label = obj
832                        .remove("label")
833                        .or_else(|| obj.remove("name"))
834                        .or_else(|| obj.remove("group"))
835                        .ok_or_else(|| format!("{raw_name}[{idx}] needs label/name/group"))?;
836                    let JsonValue::String(label) = label else {
837                        return Err(format!("{raw_name}[{idx}] label must be a string"));
838                    };
839                    parse_gamma_pair_value(&label, JsonValue::Object(obj))
840                }
841                JsonValue::Array(mut values) => {
842                    if values.len() != 2 && values.len() != 3 {
843                        return Err(format!(
844                            "{raw_name}[{idx}] must be [label, shape, rate] or [label, [shape, rate]]"
845                        ));
846                    }
847                    let label = values.remove(0);
848                    let JsonValue::String(label) = label else {
849                        return Err(format!("{raw_name}[{idx}][0] must be a string label"));
850                    };
851                    let pair = if values.len() == 1 {
852                        values.remove(0)
853                    } else {
854                        JsonValue::Array(values)
855                    };
856                    parse_gamma_pair_value(&label, pair)
857                }
858                _ => Err(format!("{raw_name}[{idx}] must be an object or array")),
859            })
860            .collect(),
861        _ => Err(format!("{raw_name} must be a map or array")),
862    }
863}
864
865fn nonempty_group_metadata(metadata: GroupMetadata) -> Option<GroupMetadata> {
866    if metadata.is_empty() {
867        None
868    } else {
869        Some(metadata)
870    }
871}
872
873fn group_metadata_from_groups(groups: JsonValue) -> Result<Option<GroupMetadata>, String> {
874    match groups {
875        JsonValue::Null => Ok(None),
876        JsonValue::Object(map) => {
877            let out = map.into_iter().collect::<BTreeMap<_, _>>();
878            Ok(nonempty_group_metadata(out))
879        }
880        JsonValue::Array(items) => {
881            let mut out = BTreeMap::new();
882            for (idx, item) in items.into_iter().enumerate() {
883                let JsonValue::Object(mut group) = item else {
884                    return Err(format!("groups[{idx}] must be an object"));
885                };
886                let Some(metadata) = group.remove("metadata") else {
887                    continue;
888                };
889                let name = group
890                    .remove("name")
891                    .or_else(|| group.remove("id"))
892                    .or_else(|| group.remove("key"))
893                    .ok_or_else(|| {
894                        format!(
895                            "groups[{idx}] with metadata must include a string name, id, or key"
896                        )
897                    })?;
898                let JsonValue::String(name) = name else {
899                    return Err(format!("groups[{idx}] name/id/key must be a string"));
900                };
901                if name.is_empty() {
902                    return Err(format!("groups[{idx}] name/id/key must be non-empty"));
903                }
904                if out.insert(name.clone(), metadata).is_some() {
905                    return Err(format!("duplicate group metadata key '{name}'"));
906                }
907            }
908            Ok(nonempty_group_metadata(out))
909        }
910        _ => Err("groups must be an object map or an array of group objects".to_string()),
911    }
912}
913
914fn parse_gpu_policy(raw_gpu: &str) -> Result<gam_gpu::GpuPolicy, String> {
915    gam_gpu::GpuPolicy::parse(raw_gpu).ok_or_else(|| {
916        format!(
917            "invalid gpu policy '{}'; supported values are auto, off, force",
918            raw_gpu
919        )
920    })
921}
922
923fn validate_resolved_fit_config(config: &FitConfig) -> Result<(), String> {
924    if !config.ridge_lambda.is_finite() || config.ridge_lambda < 0.0 {
925        return Err("--ridge-lambda must be finite and >= 0".to_string());
926    }
927    let likelihood_mode = parse_survival_likelihood_mode(&config.survival_likelihood)?;
928    validate_survival_baseline_args(
929        likelihood_mode,
930        &config.baseline_target,
931        config.baseline_scale,
932        config.baseline_shape,
933        config.baseline_rate,
934        config.baseline_makeham,
935    )
936}
937
/// Returns the fixed usage hint that lists the link specifications this module
/// advertises: plain links, `sas`, `beta-logistic`, blended/mixture and flexible
/// forms.
fn survival_link_usage() -> &'static str {
    const USAGE: &str = "use identity|logit|probit|cloglog|sas|beta-logistic|blended(...)/mixture(...) or flexible(...)";
    USAGE
}
941
// Unit tests for config resolution.
//
// The two parity tests feed equivalent CLI-shaped and JSON wire inputs through
// their respective resolvers. They require identical `FitConfig`s (or identical
// error strings) from both.
#[cfg(test)]
mod tests {
    use super::*;
    use gam_models::survival::lognormal_kernel::FrailtySpec;
    use serde_json::{Value, json};

    /// One CLI-shaped input paired with the JSON wire config expected to resolve identically.
    struct ParityCase {
        name: &'static str,
        cli: CliFitConfigInput,
        json: Value,
    }

    /// CLI input with every field at its baseline value; each case overrides only what it tests.
    fn base_cli() -> CliFitConfigInput {
        CliFitConfigInput {
            family: None,
            expectile_tau: None,
            negative_binomial_theta: None,
            link: None,
            flexible_link: false,
            offset_column: None,
            weight_column: None,
            noise_offset_column: None,
            baseline_target: "linear".to_string(),
            baseline_scale: None,
            baseline_shape: None,
            baseline_rate: None,
            baseline_makeham: None,
            time_basis: "ispline".to_string(),
            time_degree: 3,
            time_num_internal_knots: 8,
            time_smooth_lambda: 1e-2,
            survival_likelihood: "transformation".to_string(),
            survival_distribution: "gaussian".to_string(),
            threshold_time_k: None,
            threshold_time_degree: 3,
            sigma_time_k: None,
            sigma_time_degree: 3,
            noise_formula: None,
            logslope_formula: None,
            z_column: None,
            scale_dimensions: false,
            adaptive_regularization: None,
            ridge_lambda: 1e-6,
            transformation_normal: false,
            firth: false,
            outer_max_iter: None,
            gpu: None,
            frailty_kind: None,
            frailty_sd: None,
            hazard_loading: None,
        }
    }

    fn resolved_cli(input: CliFitConfigInput) -> Result<FitConfig, String> {
        resolve_cli_fit_config(input)
    }

    /// Resolves a JSON wire config. These inputs never name a training table,
    /// so the resolved table kind must be `None`.
    fn resolved_json(config: Value) -> Result<FitConfig, String> {
        parse_fit_config_json(Some(&config.to_string())).map(|resolved| {
            assert_eq!(resolved.training_table_kind, None);
            resolved.fit_config
        })
    }

    /// Pretty-printed `Debug` form used for comparison.
    /// An absent frailty is normalized to `FrailtySpec::None` first, so the
    /// two spellings compare equal.
    fn canonical_fit_config(mut config: FitConfig) -> String {
        if config.frailty.is_none() {
            config.frailty = Some(FrailtySpec::None);
        }
        format!("{config:#?}")
    }

    #[test]
    fn cli_shaped_and_json_wire_config_resolution_match() {
        let cases = vec![
            ParityCase {
                name: "family and link selection",
                cli: {
                    let mut input = base_cli();
                    input.family = Some("binomial".to_string());
                    input.link = Some("probit".to_string());
                    input.flexible_link = true;
                    input
                },
                json: json!({
                    "family": "binomial",
                    "link": "probit",
                    "flexible_link": true
                }),
            },
            ParityCase {
                name: "offset weights ridge and noise offset columns",
                cli: {
                    let mut input = base_cli();
                    input.offset_column = Some("eta_offset".to_string());
                    input.weight_column = Some("case_weight".to_string());
                    input.noise_offset_column = Some("sigma_offset".to_string());
                    input.ridge_lambda = 0.125;
                    input
                },
                json: json!({
                    "offset": "eta_offset",
                    "weights": "case_weight",
                    "noise_offset": "sigma_offset",
                    "ridge_lambda": 0.125
                }),
            },
            ParityCase {
                name: "weibull survival likelihood and baseline scale shape",
                cli: {
                    let mut input = base_cli();
                    input.survival_likelihood = "weibull".to_string();
                    input.baseline_target = "weibull".to_string();
                    input.baseline_scale = Some(2.5);
                    input.baseline_shape = Some(1.75);
                    input
                },
                json: json!({
                    "survival_likelihood": "weibull",
                    "baseline_target": "weibull",
                    "baseline_scale": 2.5,
                    "baseline_shape": 1.75
                }),
            },
            ParityCase {
                name: "transformation survival gompertz makeham baseline",
                cli: {
                    let mut input = base_cli();
                    input.survival_likelihood = "transformation".to_string();
                    input.baseline_target = "gompertz-makeham".to_string();
                    input.baseline_shape = Some(1.2);
                    input.baseline_rate = Some(0.04);
                    input.baseline_makeham = Some(0.01);
                    input
                },
                json: json!({
                    "survival_likelihood": "transformation",
                    "baseline_target": "gompertz-makeham",
                    "baseline_shape": 1.2,
                    "baseline_rate": 0.04,
                    "baseline_makeham": 0.01
                }),
            },
            ParityCase {
                name: "survival likelihood values are canonicalized",
                cli: {
                    let mut input = base_cli();
                    input.survival_likelihood = "TRANSFORMATION".to_string();
                    input
                },
                json: json!({
                    "survival_likelihood": "Transformation"
                }),
            },
            ParityCase {
                name: "noise formula logslope z column and scale dimensions",
                cli: {
                    let mut input = base_cli();
                    input.noise_formula = Some("~ s(age) + treatment".to_string());
                    input.logslope_formula = Some("~ s(dose)".to_string());
                    input.z_column = Some("dose".to_string());
                    input.scale_dimensions = true;
                    input
                },
                json: json!({
                    "noise_formula": "~ s(age) + treatment",
                    "logslope_formula": "~ s(dose)",
                    "z_column": "dose",
                    "scale_dimensions": true
                }),
            },
            ParityCase {
                name: "firth transformation normal outer iterations and adaptive regularization",
                cli: {
                    let mut input = base_cli();
                    input.firth = true;
                    input.transformation_normal = true;
                    input.outer_max_iter = Some(7);
                    input.adaptive_regularization = Some(true);
                    input
                },
                json: json!({
                    "firth": true,
                    "transformation_normal": true,
                    "outer_max_iter": 7,
                    "adaptive_regularization": true
                }),
            },
            ParityCase {
                name: "gpu policy toggle",
                cli: {
                    let mut input = base_cli();
                    input.gpu = Some("off".to_string());
                    input
                },
                json: json!({
                    "gpu": "off"
                }),
            },
            ParityCase {
                name: "hazard multiplier frailty fields",
                cli: {
                    let mut input = base_cli();
                    input.frailty_kind = Some(CliFrailtyKind::HazardMultiplier);
                    input.frailty_sd = Some(0.35);
                    input.hazard_loading = Some(CliHazardLoading::LoadedVsUnloaded);
                    input
                },
                json: json!({
                    "frailty_kind": "hazard-multiplier",
                    "frailty_sd": 0.35,
                    "hazard_loading": "loaded-vs-unloaded"
                }),
            },
            ParityCase {
                name: "gaussian shift frailty fields",
                cli: {
                    let mut input = base_cli();
                    input.frailty_kind = Some(CliFrailtyKind::GaussianShift);
                    input.frailty_sd = Some(0.2);
                    input
                },
                json: json!({
                    "frailty_kind": "gaussian-shift",
                    "frailty_sd": 0.2
                }),
            },
        ];

        for case in cases {
            let cli = resolved_cli(case.cli)
                .unwrap_or_else(|err| panic!("{}: CLI-shaped config failed: {err}", case.name));
            let json = resolved_json(case.json)
                .unwrap_or_else(|err| panic!("{}: JSON wire config failed: {err}", case.name));
            assert_eq!(
                canonical_fit_config(cli),
                canonical_fit_config(json),
                "{}",
                case.name
            );
        }
    }

    #[test]
    fn cli_shaped_and_json_wire_config_resolution_rejections_match() {
        let cases = vec![
            ParityCase {
                name: "negative ridge lambda",
                cli: {
                    let mut input = base_cli();
                    input.ridge_lambda = -1.0;
                    input
                },
                json: json!({
                    "ridge_lambda": -1.0
                }),
            },
            ParityCase {
                name: "unknown gpu policy",
                cli: {
                    let mut input = base_cli();
                    input.gpu = Some("cuda".to_string());
                    input
                },
                json: json!({
                    "gpu": "cuda"
                }),
            },
            ParityCase {
                name: "linear baseline rejects shape",
                cli: {
                    let mut input = base_cli();
                    input.baseline_shape = Some(1.1);
                    input
                },
                json: json!({
                    "baseline_shape": 1.1
                }),
            },
            ParityCase {
                name: "weibull likelihood rejects gompertz target",
                cli: {
                    let mut input = base_cli();
                    input.survival_likelihood = "weibull".to_string();
                    input.baseline_target = "gompertz".to_string();
                    input
                },
                json: json!({
                    "survival_likelihood": "weibull",
                    "baseline_target": "gompertz"
                }),
            },
        ];

        for case in cases {
            let cli = resolved_cli(case.cli).expect_err(case.name);
            let json = resolved_json(case.json).expect_err(case.name);
            assert_eq!(cli, json, "{}", case.name);
        }
    }

    // ── parse_precision_hyperpriors ───────────────────────────────────────

    #[test]
    fn precision_hyperpriors_map_form_accepts_array_pairs() {
        let parsed =
            parse_precision_hyperpriors(Some(json!({ "s(age)": [2.0, 0.5] })), None).unwrap();
        assert_eq!(parsed, vec![("s(age)".to_string(), 2.0, 0.5)]);
    }

    #[test]
    fn precision_hyperpriors_array_form_accepts_all_entry_layouts() {
        let parsed = parse_precision_hyperpriors(
            None,
            Some(json!([
                ["s(age)", 2.0, 0.5],
                ["s(x)", [4.0, 2.0]],
                { "label": "s(dose)", "a": 1.0, "b": 3.0 }
            ])),
        )
        .unwrap();
        assert_eq!(
            parsed,
            vec![
                ("s(age)".to_string(), 2.0, 0.5),
                ("s(x)".to_string(), 4.0, 2.0),
                ("s(dose)".to_string(), 1.0, 3.0),
            ]
        );
    }

    #[test]
    fn precision_hyperpriors_null_or_absent_is_empty() {
        assert!(parse_precision_hyperpriors(None, None).unwrap().is_empty());
        assert!(parse_precision_hyperpriors(Some(Value::Null), None)
            .unwrap()
            .is_empty());
    }

    #[test]
    fn precision_hyperpriors_reject_both_field_spellings() {
        let err = parse_precision_hyperpriors(Some(json!({})), Some(json!({}))).unwrap_err();
        assert_eq!(
            err,
            "fit config accepts either precision_hyperpriors or penalty_block_gamma_priors, not both"
        );
    }

    // ── group_metadata_from_groups ────────────────────────────────────────

    #[test]
    fn groups_without_metadata_are_skipped() {
        let resolved = group_metadata_from_groups(json!([{ "name": "a" }])).unwrap();
        assert!(resolved.is_none());
    }

    #[test]
    fn groups_reject_duplicate_identifiers_across_aliases() {
        let err = group_metadata_from_groups(json!([
            { "name": "a", "metadata": 1 },
            { "id": "a", "metadata": 2 }
        ]))
        .unwrap_err();
        assert_eq!(err, "duplicate group metadata key 'a'");
    }

    // ── parse_comma_f64 ───────────────────────────────────────────────────

    #[test]
    fn parse_comma_f64_empty_string_returns_empty_vec() {
        assert_eq!(parse_comma_f64("", "x").unwrap(), Vec::<f64>::new());
        assert_eq!(parse_comma_f64("   ", "x").unwrap(), Vec::<f64>::new());
    }

    #[test]
    fn parse_comma_f64_single_value() {
        // 1.25 is exactly representable. The previous literal 3.14 tripped
        // clippy's deny-by-default `approx_constant` lint (approximate PI).
        assert_eq!(parse_comma_f64("1.25", "x").unwrap(), vec![1.25]);
    }

    #[test]
    fn parse_comma_f64_multiple_values_with_spaces() {
        let result = parse_comma_f64("1.0, 2.5, -3.0", "x").unwrap();
        assert_eq!(result, vec![1.0, 2.5, -3.0]);
    }

    #[test]
    fn parse_comma_f64_non_numeric_returns_error() {
        let err = parse_comma_f64("1.0, bad, 3.0", "--vals").unwrap_err();
        assert!(err.contains("--vals"), "error should name the label: {err}");
        assert!(err.contains("bad"), "error should name the bad token: {err}");
    }

    #[test]
    fn parse_comma_f64_infinity_returns_error() {
        let err = parse_comma_f64("inf", "--vals").unwrap_err();
        assert!(err.contains("non-finite"), "error should say non-finite: {err}");
    }

    #[test]
    fn parse_comma_f64_nan_returns_error() {
        let err = parse_comma_f64("nan", "--vals").unwrap_err();
        assert!(err.contains("non-finite"), "error should say non-finite: {err}");
    }

    // ── normalize_optional_family ─────────────────────────────────────────

    #[test]
    fn normalize_optional_family_none_passthrough() {
        assert_eq!(normalize_optional_family(None), None);
    }

    #[test]
    fn normalize_optional_family_auto_becomes_none() {
        assert_eq!(normalize_optional_family(Some("auto".to_string())), None);
        assert_eq!(normalize_optional_family(Some("Auto".to_string())), None);
        assert_eq!(normalize_optional_family(Some("AUTO".to_string())), None);
    }

    #[test]
    fn normalize_optional_family_non_auto_passthrough() {
        assert_eq!(
            normalize_optional_family(Some("binomial".to_string())),
            Some("binomial".to_string())
        );
        assert_eq!(
            normalize_optional_family(Some("gaussian".to_string())),
            Some("gaussian".to_string())
        );
    }

    // ── parse_survival_likelihood_cli ─────────────────────────────────────

    #[test]
    fn parse_survival_likelihood_cli_valid_values() {
        assert_eq!(
            parse_survival_likelihood_cli("transformation").unwrap(),
            "transformation"
        );
        assert_eq!(
            parse_survival_likelihood_cli("weibull").unwrap(),
            "weibull"
        );
        // case-insensitive
        assert_eq!(
            parse_survival_likelihood_cli("WEIBULL").unwrap(),
            "weibull"
        );
        assert_eq!(
            parse_survival_likelihood_cli("Transformation").unwrap(),
            "transformation"
        );
    }

    #[test]
    fn parse_survival_likelihood_cli_invalid_returns_error() {
        assert!(parse_survival_likelihood_cli("lognormal").is_err());
        assert!(parse_survival_likelihood_cli("").is_err());
    }

    // ── parse_baseline_target_cli ─────────────────────────────────────────

    #[test]
    fn parse_baseline_target_cli_valid_values() {
        for target in &["linear", "weibull", "gompertz", "gompertz-makeham"] {
            assert_eq!(
                parse_baseline_target_cli(target).unwrap(),
                *target,
                "should accept '{target}'"
            );
        }
        // trimmed and lowercased
        assert_eq!(
            parse_baseline_target_cli("  Weibull  ").unwrap(),
            "weibull"
        );
    }

    #[test]
    fn parse_baseline_target_cli_invalid_returns_error() {
        let err = parse_baseline_target_cli("cox").unwrap_err();
        assert!(err.contains("cox"), "error should name the bad value: {err}");
    }

    // ── validate_survival_baseline_args ───────────────────────────────────

    #[test]
    fn validate_survival_baseline_args_linear_rejects_params() {
        let mode = parse_survival_likelihood_mode("transformation").unwrap();
        assert!(validate_survival_baseline_args(
            mode,
            "linear",
            Some(1.0),
            None,
            None,
            None
        )
        .is_err());
    }

    #[test]
    fn validate_survival_baseline_args_linear_accepts_no_params() {
        let mode = parse_survival_likelihood_mode("transformation").unwrap();
        assert!(validate_survival_baseline_args(
            mode, "linear", None, None, None, None
        )
        .is_ok());
    }

    #[test]
    fn validate_survival_baseline_args_weibull_likelihood_rejects_gompertz_target() {
        let mode = parse_survival_likelihood_mode("weibull").unwrap();
        assert!(validate_survival_baseline_args(
            mode,
            "gompertz",
            None,
            None,
            None,
            None
        )
        .is_err());
    }

    #[test]
    fn validate_survival_baseline_args_gompertz_rejects_scale() {
        let mode = parse_survival_likelihood_mode("transformation").unwrap();
        assert!(validate_survival_baseline_args(
            mode,
            "gompertz",
            Some(2.0),
            None,
            None,
            None
        )
        .is_err());
    }

    #[test]
    fn validate_survival_baseline_args_gompertz_makeham_rejects_scale() {
        let mode = parse_survival_likelihood_mode("transformation").unwrap();
        assert!(validate_survival_baseline_args(
            mode,
            "gompertz-makeham",
            Some(1.0),
            None,
            None,
            None
        )
        .is_err());
    }
}