1use gam_models::survival::location_scale::residual_distribution_inverse_link;
2use gam_models::survival::lognormal_kernel::{FrailtySpec, HazardLoading};
3use gam_models::survival::parse_survival_distribution;
4use gam_models::survival::{SurvivalLikelihoodMode, parse_survival_likelihood_mode};
5use gam_inference::formula_dsl::parse_link_choice;
6use gam_inference::model::GroupMetadata;
7use gam_solve::mixture_link::{state_from_beta_logisticspec, state_from_sasspec, state_fromspec};
8use gam_models::fit_orchestration::descriptors::build_analytic_penalty_registry_from_descriptors;
9use gam_models::fit_orchestration::{CtnStage1Recipe, FitConfig};
10use gam_models::transformation_normal::TransformationNormalConfig;
11use gam_problem::types::{
12 InverseLink, LinkComponent, LinkFunction, MixtureLinkSpec, SasLinkSpec, StandardLink,
13};
14use ndarray::Array1;
15use serde::Deserialize;
16use serde_json::Value as JsonValue;
17use std::collections::BTreeMap;
18
/// Frailty family selectable with `--frailty-kind`.
///
/// Turned into a model-level [`FrailtySpec`] by [`resolve_cli_frailty_spec`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CliFrailtyKind {
    /// `--frailty-kind gaussian-shift`; rejects `--hazard-loading`.
    GaussianShift,
    /// `--frailty-kind hazard-multiplier`; requires `--hazard-loading`.
    HazardMultiplier,
}
24
/// Hazard loading selectable with `--hazard-loading`.
///
/// Only accepted together with [`CliFrailtyKind::HazardMultiplier`]; each variant maps
/// one-to-one onto the same-named [`HazardLoading`] variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CliHazardLoading {
    /// Maps to `HazardLoading::Full`.
    Full,
    /// Maps to `HazardLoading::LoadedVsUnloaded`.
    LoadedVsUnloaded,
}
30
/// Wire format of the JSON fit configuration read by `parse_fit_config_json`.
///
/// Every key is optional and unknown keys are rejected (`deny_unknown_fields`).
/// `resolve_json_fit_config` layers the values onto `FitConfig::default()`:
/// `family`, `offset`, `weights`, `latents`, `penalties`, `smooths`,
/// `topology_auto_selector` and the group / prior / frailty keys are always assigned,
/// so an absent key stores `None` or an empty value. The remaining keys only override
/// the default when present.
// NOTE(review): field order is significant for serde's sequence (array) form of
// this struct; do not reorder fields without checking callers.
#[derive(Default, Deserialize)]
#[serde(deny_unknown_fields)]
struct JsonFitConfig {
    /// Model family; `"auto"` in any ASCII case is treated as unset.
    family: Option<String>,
    /// Offset column, stored as `offset_column`.
    offset: Option<String>,
    /// Weight column, stored as `weight_column`.
    weights: Option<String>,
    /// Ridge penalty; the resolved value must be finite and >= 0.
    ridge_lambda: Option<f64>,
    transformation_normal: Option<bool>,
    /// Survival likelihood mode; trimmed and ASCII-lowercased, then validated.
    survival_likelihood: Option<String>,
    /// Baseline target; trimmed and required to be non-empty.
    baseline_target: Option<String>,
    // The four baseline parameters are checked against `baseline_target` by
    // `validate_survival_baseline_args`.
    baseline_scale: Option<f64>,
    baseline_shape: Option<f64>,
    baseline_rate: Option<f64>,
    baseline_makeham: Option<f64>,
    /// Trimmed and required to be non-empty when present.
    z_column: Option<String>,
    logslope_formula: Option<String>,
    /// Stage-1 CTN description, converted with `JsonCtnStage1::into_recipe`.
    ctn_stage1: Option<JsonCtnStage1>,
    /// Link name; trimmed, and a blank string clears the link.
    link: Option<String>,
    flexible_link: Option<bool>,
    scale_dimensions: Option<bool>,
    adaptive_regularization: Option<bool>,
    noise_formula: Option<String>,
    /// Noise offset column, stored as `noise_offset_column`.
    noise_offset: Option<String>,
    firth: Option<bool>,
    /// Outer iteration cap; must be >= 1.
    outer_max_iter: Option<usize>,
    /// GPU policy name, parsed by `parse_gpu_policy`.
    gpu: Option<String>,
    /// Direct group metadata; mutually exclusive with `groups`.
    group_metadata: Option<GroupMetadata>,
    /// Group metadata as an object map or an array of group objects; mutually
    /// exclusive with `group_metadata`.
    groups: Option<JsonValue>,
    /// Gamma priors per penalty block; alias of `penalty_block_gamma_priors`
    /// (supplying both is an error).
    precision_hyperpriors: Option<JsonValue>,
    penalty_block_gamma_priors: Option<JsonValue>,
    // `latents` and `penalties` are validated together by
    // `build_analytic_penalty_registry_from_descriptors`; `penalties` is stored as
    // `analytic_penalties`.
    latents: Option<JsonValue>,
    penalties: Option<JsonValue>,
    /// Stored verbatim as `smooth_overrides`.
    smooths: Option<JsonValue>,
    /// Parsed by `TopologyAutoSelector::from_json`.
    topology_auto_selector: Option<JsonValue>,
    // The three frailty keys are resolved together by `parse_json_frailty_spec`.
    frailty_kind: Option<String>,
    frailty_sd: Option<f64>,
    hazard_loading: Option<String>,
    /// Passed through to `ResolvedFitConfig::training_table_kind`; not stored in
    /// `FitConfig`.
    training_table_kind: Option<String>,
}
70
/// JSON description of the CTN stage-1 model, converted by `into_recipe`.
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct JsonCtnStage1 {
    /// Passed to `CtnStage1Recipe::new` as the response column.
    response_column: String,
    /// Passed to `CtnStage1Recipe::new` as the covariate formula right-hand side.
    covariate_formula_rhs: String,
    /// Optional overrides applied on top of `TransformationNormalConfig::default()`.
    #[serde(default)]
    config: Option<JsonCtnStage1Config>,
    /// Optional weight column forwarded to the recipe.
    #[serde(default)]
    weight_column: Option<String>,
    /// Optional offset column forwarded to the recipe.
    #[serde(default)]
    offset_column: Option<String>,
}
83
/// Partial override of `TransformationNormalConfig` for the CTN stage-1 recipe.
///
/// Each present field replaces the same-named field of
/// `TransformationNormalConfig::default()`; absent fields keep the default.
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct JsonCtnStage1Config {
    #[serde(default)]
    response_degree: Option<usize>,
    #[serde(default)]
    response_num_internal_knots: Option<usize>,
    #[serde(default)]
    response_penalty_order: Option<usize>,
    /// Replaces the whole default list when present (no merging).
    #[serde(default)]
    response_extra_penalty_orders: Option<Vec<usize>>,
    #[serde(default)]
    double_penalty: Option<bool>,
}
98
99impl JsonCtnStage1 {
100 fn into_recipe(self) -> Result<CtnStage1Recipe, String> {
101 let mut config = TransformationNormalConfig::default();
102 if let Some(overrides) = self.config {
103 if let Some(value) = overrides.response_degree {
104 config.response_degree = value;
105 }
106 if let Some(value) = overrides.response_num_internal_knots {
107 config.response_num_internal_knots = value;
108 }
109 if let Some(value) = overrides.response_penalty_order {
110 config.response_penalty_order = value;
111 }
112 if let Some(value) = overrides.response_extra_penalty_orders {
113 config.response_extra_penalty_orders = value;
114 }
115 if let Some(value) = overrides.double_penalty {
116 config.double_penalty = value;
117 }
118 }
119 CtnStage1Recipe::new(
120 &self.response_column,
121 &self.covariate_formula_rhs,
122 config,
123 self.weight_column.as_deref(),
124 self.offset_column.as_deref(),
125 )
126 }
127}
128
/// Outcome of `parse_fit_config_json`.
pub struct ResolvedFitConfig {
    /// Fit configuration after all JSON overrides and validation were applied.
    pub fit_config: FitConfig,
    /// The JSON `training_table_kind` value, passed through unchanged. It is not
    /// stored in `fit_config` and is not interpreted here.
    pub training_table_kind: Option<String>,
}
133
/// Command-line values for a fit, consumed by `resolve_cli_fit_config`.
///
/// Unless a field says otherwise, its value is copied verbatim onto the same-named
/// field of `FitConfig::default()`.
pub struct CliFitConfigInput {
    /// Copied as-is; unlike the JSON path, `"auto"` is not normalized away here.
    pub family: Option<String>,
    pub expectile_tau: Option<f64>,
    pub negative_binomial_theta: Option<f64>,
    pub link: Option<String>,
    pub flexible_link: bool,
    pub offset_column: Option<String>,
    pub weight_column: Option<String>,
    pub noise_offset_column: Option<String>,
    /// Compared as-is against `linear|weibull|gompertz|gompertz-makeham` during
    /// validation. NOTE(review): callers presumably normalize it with
    /// `parse_baseline_target_cli` first — confirm.
    pub baseline_target: String,
    // Baseline parameters; which of them are allowed depends on `baseline_target`
    // and the survival likelihood (see `validate_survival_baseline_args`).
    pub baseline_scale: Option<f64>,
    pub baseline_shape: Option<f64>,
    pub baseline_rate: Option<f64>,
    pub baseline_makeham: Option<f64>,
    pub time_basis: String,
    pub time_degree: usize,
    pub time_num_internal_knots: usize,
    pub time_smooth_lambda: f64,
    /// Trimmed, ASCII-lowercased and validated by `parse_survival_likelihood_cli`.
    pub survival_likelihood: String,
    pub survival_distribution: String,
    pub threshold_time_k: Option<usize>,
    pub threshold_time_degree: usize,
    pub sigma_time_k: Option<usize>,
    pub sigma_time_degree: usize,
    pub noise_formula: Option<String>,
    pub logslope_formula: Option<String>,
    pub z_column: Option<String>,
    pub scale_dimensions: bool,
    pub adaptive_regularization: Option<bool>,
    /// Must be finite and >= 0.
    pub ridge_lambda: f64,
    pub transformation_normal: bool,
    pub firth: bool,
    pub outer_max_iter: Option<usize>,
    /// GPU policy name parsed by `parse_gpu_policy`; `None` keeps the default policy.
    pub gpu: Option<String>,
    // The three frailty flags are resolved together by `resolve_cli_frailty_spec`
    // with the error context "fit".
    pub frailty_kind: Option<CliFrailtyKind>,
    pub frailty_sd: Option<f64>,
    pub hazard_loading: Option<CliHazardLoading>,
}
177
/// Raw survival link flags consumed by `parse_survival_inverse_link`.
pub struct SurvivalInverseLinkInput<'a> {
    /// `--link` value; when `None`, the link is derived from `survival_distribution`.
    pub link: Option<&'a str>,
    /// `--mixture-rho`: comma-separated initial rho values; only valid with a
    /// blended/mixture link.
    pub mixture_rho: Option<&'a str>,
    /// `--sas-init`: `epsilon,log_delta`; only valid with `--link sas`.
    pub sas_init: Option<&'a str>,
    /// `--beta-logistic-init`: `epsilon,delta`; only valid with `--link beta-logistic`.
    pub beta_logistic_init: Option<&'a str>,
    /// Residual distribution name, used only when `link` is `None`.
    pub survival_distribution: &'a str,
}
185
186pub fn parse_fit_config_json(config_json: Option<&str>) -> Result<ResolvedFitConfig, String> {
187 let json_config = match config_json {
188 Some(raw) if !raw.trim().is_empty() => serde_json::from_str::<JsonFitConfig>(raw)
189 .map_err(|err| format!("invalid fit config json: {err}"))?,
190 _ => JsonFitConfig::default(),
191 };
192 resolve_json_fit_config(json_config)
193}
194
195fn resolve_json_fit_config(json_config: JsonFitConfig) -> Result<ResolvedFitConfig, String> {
196 let training_table_kind = json_config.training_table_kind;
197 let mut fit_config = FitConfig::default();
198 fit_config.group_metadata =
199 parse_group_metadata(json_config.group_metadata, json_config.groups)?;
200 fit_config.penalty_block_gamma_priors = parse_precision_hyperpriors(
201 json_config.precision_hyperpriors,
202 json_config.penalty_block_gamma_priors,
203 )?;
204 let analytic_penalties = json_config.penalties;
205 build_analytic_penalty_registry_from_descriptors(
206 json_config.latents.as_ref(),
207 analytic_penalties.as_ref(),
208 )?;
209 fit_config.latents = json_config.latents;
210 fit_config.analytic_penalties = analytic_penalties;
211 fit_config.smooth_overrides = json_config.smooths;
212 fit_config.topology_auto_selector = json_config
213 .topology_auto_selector
214 .as_ref()
215 .map(gam_solve::topology_selector::TopologyAutoSelector::from_json)
216 .transpose()?;
217 fit_config.family = normalize_optional_family(json_config.family);
218 fit_config.offset_column = json_config.offset;
219 fit_config.weight_column = json_config.weights;
220 if let Some(ridge_lambda) = json_config.ridge_lambda {
221 fit_config.ridge_lambda = ridge_lambda;
222 }
223 if let Some(flag) = json_config.transformation_normal {
224 fit_config.transformation_normal = flag;
225 }
226 if let Some(mode) = json_config.survival_likelihood {
227 fit_config.survival_likelihood = parse_survival_likelihood_cli(&mode)?;
228 }
229 if let Some(target) = json_config.baseline_target {
230 fit_config.baseline_target =
231 resolve_nonempty_string(target, "baseline_target must be a non-empty string")?;
232 }
233 if let Some(value) = json_config.baseline_scale {
234 fit_config.baseline_scale = Some(value);
235 }
236 if let Some(value) = json_config.baseline_shape {
237 fit_config.baseline_shape = Some(value);
238 }
239 if let Some(value) = json_config.baseline_rate {
240 fit_config.baseline_rate = Some(value);
241 }
242 if let Some(value) = json_config.baseline_makeham {
243 fit_config.baseline_makeham = Some(value);
244 }
245 if let Some(z) = json_config.z_column {
246 fit_config.z_column = Some(resolve_nonempty_string(
247 z,
248 "z_column must be a non-empty string",
249 )?);
250 }
251 if let Some(formula) = json_config.logslope_formula {
252 fit_config.logslope_formula = Some(formula);
253 }
254 if let Some(stage1) = json_config.ctn_stage1 {
255 fit_config.ctn_stage1 = Some(stage1.into_recipe()?);
256 }
257 if let Some(link) = json_config.link {
258 let trimmed = link.trim();
259 fit_config.link = if trimmed.is_empty() {
260 None
261 } else {
262 Some(trimmed.to_string())
263 };
264 }
265 if let Some(flag) = json_config.flexible_link {
266 fit_config.flexible_link = flag;
267 }
268 if let Some(flag) = json_config.scale_dimensions {
269 fit_config.scale_dimensions = flag;
270 }
271 if let Some(flag) = json_config.adaptive_regularization {
272 fit_config.adaptive_regularization = Some(flag);
273 }
274 if let Some(formula) = json_config.noise_formula {
275 fit_config.noise_formula = Some(formula);
276 }
277 if let Some(column) = json_config.noise_offset {
278 fit_config.noise_offset_column = Some(column);
279 }
280 if let Some(flag) = json_config.firth {
281 fit_config.firth = flag;
282 }
283 if let Some(value) = json_config.outer_max_iter {
284 if value == 0 {
285 return Err("outer_max_iter must be >= 1".to_string());
286 }
287 fit_config.outer_max_iter = Some(value);
288 }
289 if let Some(raw_gpu) = json_config.gpu {
290 fit_config.gpu_policy = parse_gpu_policy(&raw_gpu)?;
291 }
292 fit_config.frailty = parse_json_frailty_spec(
293 json_config.frailty_kind,
294 json_config.frailty_sd,
295 json_config.hazard_loading,
296 )?;
297 validate_resolved_fit_config(&fit_config)?;
298 Ok(ResolvedFitConfig {
299 fit_config,
300 training_table_kind,
301 })
302}
303
304pub fn resolve_cli_fit_config(input: CliFitConfigInput) -> Result<FitConfig, String> {
305 let mut fit_config = FitConfig::default();
306 fit_config.family = input.family;
307 fit_config.expectile_tau = input.expectile_tau;
308 fit_config.negative_binomial_theta = input.negative_binomial_theta;
309 fit_config.link = input.link;
310 fit_config.flexible_link = input.flexible_link;
311 fit_config.offset_column = input.offset_column;
312 fit_config.weight_column = input.weight_column;
313 fit_config.noise_offset_column = input.noise_offset_column;
314 fit_config.baseline_target = input.baseline_target;
315 fit_config.baseline_scale = input.baseline_scale;
316 fit_config.baseline_shape = input.baseline_shape;
317 fit_config.baseline_rate = input.baseline_rate;
318 fit_config.baseline_makeham = input.baseline_makeham;
319 fit_config.time_basis = input.time_basis;
320 fit_config.time_degree = input.time_degree;
321 fit_config.time_num_internal_knots = input.time_num_internal_knots;
322 fit_config.time_smooth_lambda = input.time_smooth_lambda;
323 fit_config.survival_likelihood = parse_survival_likelihood_cli(&input.survival_likelihood)?;
324 fit_config.survival_distribution = input.survival_distribution;
325 fit_config.threshold_time_k = input.threshold_time_k;
326 fit_config.threshold_time_degree = input.threshold_time_degree;
327 fit_config.sigma_time_k = input.sigma_time_k;
328 fit_config.sigma_time_degree = input.sigma_time_degree;
329 fit_config.noise_formula = input.noise_formula;
330 fit_config.logslope_formula = input.logslope_formula;
331 fit_config.z_column = input.z_column;
332 fit_config.scale_dimensions = input.scale_dimensions;
333 fit_config.adaptive_regularization = input.adaptive_regularization;
334 fit_config.ridge_lambda = input.ridge_lambda;
335 fit_config.transformation_normal = input.transformation_normal;
336 fit_config.firth = input.firth;
337 fit_config.outer_max_iter = input.outer_max_iter;
338 if let Some(raw_gpu) = input.gpu {
339 fit_config.gpu_policy = parse_gpu_policy(&raw_gpu)?;
340 }
341 fit_config.frailty = Some(resolve_cli_frailty_spec(
342 input.frailty_kind,
343 input.frailty_sd,
344 input.hazard_loading,
345 "fit",
346 )?);
347 validate_resolved_fit_config(&fit_config)?;
348 Ok(fit_config)
349}
350
351pub fn resolve_cli_frailty_spec(
352 frailty_kind: Option<CliFrailtyKind>,
353 frailty_sd: Option<f64>,
354 hazard_loading: Option<CliHazardLoading>,
355 context: &str,
356) -> Result<FrailtySpec, String> {
357 let validate_sigma = || -> Result<Option<f64>, String> {
358 match frailty_sd {
359 None => Ok(None),
360 Some(sigma) => {
361 if !sigma.is_finite() || sigma < 0.0 {
362 return Err(format!(
363 "{context} requires a finite --frailty-sd >= 0, got {sigma}"
364 ));
365 }
366 Ok(Some(sigma))
367 }
368 }
369 };
370
371 match frailty_kind {
372 None => {
373 if frailty_sd.is_some() || hazard_loading.is_some() {
374 return Err(format!(
375 "{context} requires --frailty-kind when --frailty-sd or --hazard-loading is provided"
376 ));
377 }
378 Ok(FrailtySpec::None)
379 }
380 Some(CliFrailtyKind::GaussianShift) => {
381 if hazard_loading.is_some() {
382 return Err(format!(
383 "{context} does not accept --hazard-loading with --frailty-kind gaussian-shift"
384 ));
385 }
386 Ok(FrailtySpec::GaussianShift {
387 sigma_fixed: validate_sigma()?,
388 })
389 }
390 Some(CliFrailtyKind::HazardMultiplier) => Ok(FrailtySpec::HazardMultiplier {
391 sigma_fixed: validate_sigma()?,
392 loading: hazard_loading.map(cli_hazard_loading).ok_or_else(|| {
393 format!("{context} requires --hazard-loading with --frailty-kind hazard-multiplier")
394 })?,
395 }),
396 }
397}
398
399pub fn parse_survival_likelihood_cli(raw: &str) -> Result<String, String> {
400 let normalized = raw.trim().to_ascii_lowercase();
401 parse_survival_likelihood_mode(&normalized)?;
402 Ok(normalized)
403}
404
/// Canonicalizes `--baseline-target` (trim + ASCII lowercase).
///
/// # Errors
/// Returns an error naming the canonicalized value unless it is one of
/// `linear`, `weibull`, `gompertz`, or `gompertz-makeham`.
pub fn parse_baseline_target_cli(raw: &str) -> Result<String, String> {
    const SUPPORTED: [&str; 4] = ["linear", "weibull", "gompertz", "gompertz-makeham"];
    let target = raw.trim().to_ascii_lowercase();
    if SUPPORTED.contains(&target.as_str()) {
        return Ok(target);
    }
    Err(format!(
        "unsupported --baseline-target '{target}'; use linear|weibull|gompertz|gompertz-makeham"
    ))
}
414
415pub fn validate_survival_baseline_args(
416 likelihood_mode: SurvivalLikelihoodMode,
417 baseline_target: &str,
418 baseline_scale: Option<f64>,
419 baseline_shape: Option<f64>,
420 baseline_rate: Option<f64>,
421 baseline_makeham: Option<f64>,
422) -> Result<(), String> {
423 if likelihood_mode == SurvivalLikelihoodMode::Weibull {
424 if baseline_rate.is_some() || baseline_makeham.is_some() {
425 return Err(
426 "--survival-likelihood weibull does not use --baseline-rate or --baseline-makeham"
427 .to_string(),
428 );
429 }
430 if !matches!(baseline_target, "linear" | "weibull") {
431 return Err(
432 "--survival-likelihood weibull supports only --baseline-target linear|weibull"
433 .to_string(),
434 );
435 }
436 return Ok(());
437 }
438
439 match baseline_target {
440 "linear" => {
441 if baseline_scale.is_some()
442 || baseline_shape.is_some()
443 || baseline_rate.is_some()
444 || baseline_makeham.is_some()
445 {
446 return Err(
447 "--baseline-target linear does not use baseline parameter flags".to_string(),
448 );
449 }
450 }
451 "weibull" => {
452 if baseline_rate.is_some() || baseline_makeham.is_some() {
453 return Err(
454 "--baseline-target weibull does not use --baseline-rate or --baseline-makeham"
455 .to_string(),
456 );
457 }
458 }
459 "gompertz" => {
460 if baseline_scale.is_some() || baseline_makeham.is_some() {
461 return Err(
462 "--baseline-target gompertz does not use --baseline-scale or --baseline-makeham"
463 .to_string(),
464 );
465 }
466 }
467 "gompertz-makeham" => {
468 if baseline_scale.is_some() {
469 return Err(
470 "--baseline-target gompertz-makeham does not use --baseline-scale".to_string(),
471 );
472 }
473 }
474 other => {
475 return Err(format!(
476 "unsupported --baseline-target '{other}'; use linear|weibull|gompertz|gompertz-makeham"
477 ));
478 }
479 }
480 Ok(())
481}
482
/// Parses a comma-separated list of finite floats.
///
/// Blank entries (including surrounding whitespace) are skipped. `label` names the
/// flag in error messages. The first entry that is non-numeric, or numeric but not
/// finite (such as `inf` or `NaN`), aborts parsing with an error.
pub fn parse_comma_f64(v: &str, label: &str) -> Result<Vec<f64>, String> {
    v.split(',')
        .map(str::trim)
        .filter(|token| !token.is_empty())
        .map(|token| -> Result<f64, String> {
            let value = token
                .parse::<f64>()
                .map_err(|err| format!("{label} contains non-numeric value '{token}': {err}"))?;
            if value.is_finite() {
                Ok(value)
            } else {
                Err(format!("{label} contains non-finite value '{token}'"))
            }
        })
        .collect()
}
500
501pub fn effective_link_to_standard(
502 link: LinkFunction,
503 context: &str,
504) -> Result<StandardLink, String> {
505 StandardLink::try_from(link).map_err(|_| {
506 format!(
507 "{context}: state-bearing link `{}` must be routed through `InverseLink::Sas` / `InverseLink::BetaLogistic`, not `Standard(_)`",
508 link.name()
509 )
510 })
511}
512
/// Resolves the survival inverse link from the raw `--link` family of flags.
///
/// Resolution order:
/// 1. `loglog` / `cauchit` (trimmed, ASCII-lowercased) are built directly as a
///    single-component `InverseLink::Mixture` with an empty `initial_rho`.
/// 2. Any other `--link` goes through `parse_link_choice`. Its "unsupported" errors
///    are rewritten into a survival-specific message listing the accepted links.
/// 3. A blended/mixture choice consumes `--mixture-rho`, which must hold one fewer
///    value than there are components; it defaults to zeros when omitted.
/// 4. `sas` and `beta-logistic` consume their two-value `--*-init` flags (zeros when
///    omitted), and `log` is rejected. Any other link must convert to a `StandardLink`.
/// 5. Without `--link`, the link comes from `survival_distribution` via
///    `residual_distribution_inverse_link`.
///
/// An init flag that does not belong to the selected link is rejected.
///
/// # Errors
/// Returns a message for invalid flag combinations, malformed numeric lists, or
/// link states rejected by the `state_from*` constructors.
pub fn parse_survival_inverse_link(
    input: SurvivalInverseLinkInput<'_>,
) -> Result<InverseLink, String> {
    // loglog and cauchit bypass `parse_link_choice` and become one-component mixtures.
    // NOTE(review): presumably because they have no stateless standard route here — confirm.
    if let Some(raw) = input.link {
        let name = raw.trim().to_ascii_lowercase();
        if name == "loglog" || name == "cauchit" {
            if input.sas_init.is_some() {
                return Err("--sas-init requires --link sas".to_string());
            }
            if input.beta_logistic_init.is_some() {
                return Err("--beta-logistic-init requires --link beta-logistic".to_string());
            }
            if input.mixture_rho.is_some() {
                return Err(
                    "--mixture-rho requires survival --link blended(...)/mixture(...)".to_string(),
                );
            }
            let component = if name == "loglog" {
                LinkComponent::LogLog
            } else {
                LinkComponent::Cauchit
            };
            // A single component has no mixing weights, hence the empty rho vector.
            return state_fromspec(&MixtureLinkSpec {
                components: vec![component],
                initial_rho: Array1::zeros(0),
            })
            .map(InverseLink::Mixture)
            .map_err(|e| format!("invalid survival {name} link state: {e}"));
        }
    }
    // Rewrite generic "unsupported link" errors into a survival-specific usage hint;
    // any other parse error is passed through unchanged.
    let choice = parse_link_choice(input.link, false).map_err(|err| {
        let err = err.to_string();
        if let Some(raw) = input.link {
            let name = raw.trim().to_ascii_lowercase();
            if err.starts_with("unsupported --link ") || err.starts_with("unsupported link type ") {
                return format!(
                    "unsupported survival --link '{name}'; {}",
                    survival_link_usage()
                );
            }
        }
        err
    })?;
    if let Some(choice) = choice {
        // Blended/mixture links: only `--mixture-rho` is accepted.
        if let Some(components) = choice.mixture_components {
            if input.sas_init.is_some() || input.beta_logistic_init.is_some() {
                return Err(
                    "survival blended(...) link does not accept --sas-init/--beta-logistic-init"
                        .to_string(),
                );
            }
            // One rho per component except the first (saturating for an empty list).
            let expected = components.len().saturating_sub(1);
            let initial_rho = if let Some(raw) = input.mixture_rho {
                let vals = parse_comma_f64(raw, "--mixture-rho")?;
                if vals.len() != expected {
                    return Err(format!(
                        "--mixture-rho expects {expected} values for blended({})",
                        components
                            .iter()
                            .map(|component| component.name())
                            .collect::<Vec<_>>()
                            .join(",")
                    ));
                }
                Array1::from_vec(vals)
            } else {
                Array1::zeros(expected)
            };
            return state_fromspec(&MixtureLinkSpec {
                components,
                initial_rho,
            })
            .map(InverseLink::Mixture)
            .map_err(|e| format!("invalid survival blended link state: {e}"));
        }

        // Every remaining link is a single link, so mixture weights are meaningless.
        if input.mixture_rho.is_some() {
            return Err(
                "--mixture-rho requires survival --link blended(...)/mixture(...)".to_string(),
            );
        }
        match choice.link {
            LinkFunction::Sas => {
                if input.beta_logistic_init.is_some() {
                    return Err("--beta-logistic-init requires --link beta-logistic".to_string());
                }
                // `--sas-init epsilon,log_delta`; defaults to (0, 0).
                let (epsilon, log_delta) = if let Some(raw) = input.sas_init {
                    let vals = parse_comma_f64(raw, "--sas-init")?;
                    if vals.len() != 2 {
                        return Err(format!(
                            "--sas-init expects two values: epsilon,log_delta (got {})",
                            vals.len()
                        ));
                    }
                    (vals[0], vals[1])
                } else {
                    (0.0, 0.0)
                };
                state_from_sasspec(SasLinkSpec {
                    initial_epsilon: epsilon,
                    initial_log_delta: log_delta,
                })
                .map(InverseLink::Sas)
                .map_err(|e| format!("invalid survival SAS link state: {e}"))
            }
            LinkFunction::BetaLogistic => {
                if input.sas_init.is_some() {
                    return Err("--sas-init requires --link sas".to_string());
                }
                // `--beta-logistic-init epsilon,delta`; defaults to (0, 0). The second
                // value is stored in `SasLinkSpec::initial_log_delta`.
                // NOTE(review): confirm that field is read as plain `delta` for beta-logistic.
                let (epsilon, delta) = if let Some(raw) = input.beta_logistic_init {
                    let vals = parse_comma_f64(raw, "--beta-logistic-init")?;
                    if vals.len() != 2 {
                        return Err(format!(
                            "--beta-logistic-init expects two values: epsilon,delta (got {})",
                            vals.len()
                        ));
                    }
                    (vals[0], vals[1])
                } else {
                    (0.0, 0.0)
                };
                state_from_beta_logisticspec(SasLinkSpec {
                    initial_epsilon: epsilon,
                    initial_log_delta: delta,
                })
                .map(InverseLink::BetaLogistic)
                .map_err(|e| format!("invalid survival Beta-Logistic link state: {e}"))
            }
            // The log link is explicitly not offered for survival models.
            LinkFunction::Log => Err(format!(
                "unsupported survival --link 'log'; {}",
                survival_link_usage()
            )),
            other => {
                if input.sas_init.is_some() {
                    return Err("--sas-init requires --link sas".to_string());
                }
                if input.beta_logistic_init.is_some() {
                    return Err("--beta-logistic-init requires --link beta-logistic".to_string());
                }
                Ok(InverseLink::Standard(effective_link_to_standard(
                    other,
                    "survival inverse link",
                )?))
            }
        }
    } else {
        // No `--link`: every link-specific flag is an error, and the link is derived
        // from the residual distribution.
        if input.mixture_rho.is_some() {
            return Err("--mixture-rho requires --link blended(...)/mixture(...)".to_string());
        }
        if input.sas_init.is_some() {
            return Err("--sas-init requires --link sas".to_string());
        }
        if input.beta_logistic_init.is_some() {
            return Err("--beta-logistic-init requires --link beta-logistic".to_string());
        }
        let dist = parse_survival_distribution(input.survival_distribution)?;
        Ok(residual_distribution_inverse_link(dist))
    }
}
680
/// Treats the family name `"auto"` (any ASCII case, no trimming) as "unset".
///
/// Any other value, including `None`, is returned unchanged.
pub fn normalize_optional_family(family: Option<String>) -> Option<String> {
    family.filter(|name| !name.eq_ignore_ascii_case("auto"))
}
687
/// Trims `raw` and returns it, or fails with `message` when nothing is left.
fn resolve_nonempty_string(raw: String, message: &str) -> Result<String, String> {
    match raw.trim() {
        "" => Err(message.to_string()),
        trimmed => Ok(trimmed.to_owned()),
    }
}
695
696fn parse_json_frailty_spec(
697 frailty_kind: Option<String>,
698 frailty_sd: Option<f64>,
699 hazard_loading: Option<String>,
700) -> Result<Option<FrailtySpec>, String> {
701 if let Some(kind) = frailty_kind {
702 let trimmed = kind.trim().to_ascii_lowercase();
703 let sigma = frailty_sd;
704 if let Some(value) = sigma
705 && (!value.is_finite() || value < 0.0)
706 {
707 return Err(format!("frailty_sd must be finite and >= 0, got {value}"));
708 }
709 let hazard_loading = hazard_loading
710 .as_ref()
711 .map(|raw| raw.trim().to_ascii_lowercase());
712 let frailty = match trimmed.as_str() {
713 "none" | "" => {
714 if sigma.is_some() || hazard_loading.is_some() {
715 return Err(
716 "frailty_kind='none' does not accept frailty_sd or hazard_loading"
717 .to_string(),
718 );
719 }
720 FrailtySpec::None
721 }
722 "hazard-multiplier" => {
723 let loading = match hazard_loading.as_deref() {
724 Some("full") | None => HazardLoading::Full,
725 Some("loaded-vs-unloaded") => HazardLoading::LoadedVsUnloaded,
726 Some(other) => {
727 return Err(format!(
728 "unknown hazard_loading '{other}'; supported: 'full', 'loaded-vs-unloaded'"
729 ));
730 }
731 };
732 FrailtySpec::HazardMultiplier {
733 sigma_fixed: sigma,
734 loading,
735 }
736 }
737 "gaussian-shift" => {
738 if hazard_loading.is_some() {
739 return Err(
740 "hazard_loading is valid only with frailty_kind='hazard-multiplier'"
741 .to_string(),
742 );
743 }
744 FrailtySpec::GaussianShift { sigma_fixed: sigma }
745 }
746 other => {
747 return Err(format!(
748 "unknown frailty_kind '{other}'; supported: 'none', 'hazard-multiplier', 'gaussian-shift'"
749 ));
750 }
751 };
752 Ok(Some(frailty))
753 } else if frailty_sd.is_some() || hazard_loading.is_some() {
754 Err("frailty_kind is required when frailty_sd or hazard_loading is provided".to_string())
755 } else {
756 Ok(None)
757 }
758}
759
760fn cli_hazard_loading(loading: CliHazardLoading) -> HazardLoading {
761 match loading {
762 CliHazardLoading::Full => HazardLoading::Full,
763 CliHazardLoading::LoadedVsUnloaded => HazardLoading::LoadedVsUnloaded,
764 }
765}
766
767fn parse_group_metadata(
768 direct: Option<GroupMetadata>,
769 groups: Option<JsonValue>,
770) -> Result<Option<GroupMetadata>, String> {
771 match (direct, groups) {
772 (Some(metadata), None) => Ok(nonempty_group_metadata(metadata)),
773 (None, Some(groups)) => group_metadata_from_groups(groups),
774 (None, None) => Ok(None),
775 (Some(_), Some(_)) => {
776 Err("fit config accepts either group_metadata or groups metadata, not both".to_string())
777 }
778 }
779}
780
781fn parse_gamma_pair_value(label: &str, value: JsonValue) -> Result<(String, f64, f64), String> {
782 match value {
783 JsonValue::Array(values) => {
784 if values.len() != 2 {
785 return Err(format!(
786 "precision_hyperpriors['{label}'] must be [shape, rate]"
787 ));
788 }
789 let shape = values[0]
790 .as_f64()
791 .ok_or_else(|| format!("precision_hyperpriors['{label}'][0] must be numeric"))?;
792 let rate = values[1]
793 .as_f64()
794 .ok_or_else(|| format!("precision_hyperpriors['{label}'][1] must be numeric"))?;
795 Ok((label.to_string(), shape, rate))
796 }
797 JsonValue::Object(mut map) => {
798 let shape = map
799 .remove("shape")
800 .or_else(|| map.remove("a"))
801 .or_else(|| map.remove("a_p"))
802 .ok_or_else(|| format!("precision_hyperpriors['{label}'] missing shape/a"))?
803 .as_f64()
804 .ok_or_else(|| {
805 format!("precision_hyperpriors['{label}'] shape/a must be numeric")
806 })?;
807 let rate = map
808 .remove("rate")
809 .or_else(|| map.remove("b"))
810 .or_else(|| map.remove("b_p"))
811 .ok_or_else(|| format!("precision_hyperpriors['{label}'] missing rate/b"))?
812 .as_f64()
813 .ok_or_else(|| {
814 format!("precision_hyperpriors['{label}'] rate/b must be numeric")
815 })?;
816 Ok((label.to_string(), shape, rate))
817 }
818 _ => Err(format!(
819 "precision_hyperpriors['{label}'] must be [shape, rate] or an object"
820 )),
821 }
822}
823
824fn parse_precision_hyperpriors(
825 precision_hyperpriors: Option<JsonValue>,
826 penalty_block_gamma_priors: Option<JsonValue>,
827) -> Result<Vec<(String, f64, f64)>, String> {
828 let raw = match (precision_hyperpriors, penalty_block_gamma_priors) {
829 (Some(_), Some(_)) => {
830 return Err(
831 "fit config accepts either precision_hyperpriors or penalty_block_gamma_priors, not both"
832 .to_string(),
833 );
834 }
835 (Some(raw), None) | (None, Some(raw)) => raw,
836 (None, None) => {
837 return Ok(Vec::new());
838 }
839 };
840 let raw_name = "precision_hyperpriors";
841 let Some(raw) = (match raw {
842 JsonValue::Null => None,
843 other => Some(other),
844 }) else {
845 return Ok(Vec::new());
846 };
847 match raw {
848 JsonValue::Object(map) => map
849 .into_iter()
850 .map(|(label, value)| parse_gamma_pair_value(&label, value))
851 .collect(),
852 JsonValue::Array(items) => items
853 .into_iter()
854 .enumerate()
855 .map(|(idx, item)| match item {
856 JsonValue::Object(mut obj) => {
857 let label = obj
858 .remove("label")
859 .or_else(|| obj.remove("name"))
860 .or_else(|| obj.remove("group"))
861 .ok_or_else(|| format!("{raw_name}[{idx}] needs label/name/group"))?;
862 let JsonValue::String(label) = label else {
863 return Err(format!("{raw_name}[{idx}] label must be a string"));
864 };
865 parse_gamma_pair_value(&label, JsonValue::Object(obj))
866 }
867 JsonValue::Array(mut values) => {
868 if values.len() != 2 && values.len() != 3 {
869 return Err(format!(
870 "{raw_name}[{idx}] must be [label, shape, rate] or [label, [shape, rate]]"
871 ));
872 }
873 let label = values.remove(0);
874 let JsonValue::String(label) = label else {
875 return Err(format!("{raw_name}[{idx}][0] must be a string label"));
876 };
877 let pair = if values.len() == 1 {
878 values.remove(0)
879 } else {
880 JsonValue::Array(values)
881 };
882 parse_gamma_pair_value(&label, pair)
883 }
884 _ => Err(format!("{raw_name}[{idx}] must be an object or array")),
885 })
886 .collect(),
887 _ => Err(format!("{raw_name} must be a map or array")),
888 }
889}
890
891fn nonempty_group_metadata(metadata: GroupMetadata) -> Option<GroupMetadata> {
892 if metadata.is_empty() {
893 None
894 } else {
895 Some(metadata)
896 }
897}
898
899fn group_metadata_from_groups(groups: JsonValue) -> Result<Option<GroupMetadata>, String> {
900 match groups {
901 JsonValue::Null => Ok(None),
902 JsonValue::Object(map) => {
903 let out = map.into_iter().collect::<BTreeMap<_, _>>();
904 Ok(nonempty_group_metadata(out))
905 }
906 JsonValue::Array(items) => {
907 let mut out = BTreeMap::new();
908 for (idx, item) in items.into_iter().enumerate() {
909 let JsonValue::Object(mut group) = item else {
910 return Err(format!("groups[{idx}] must be an object"));
911 };
912 let Some(metadata) = group.remove("metadata") else {
913 continue;
914 };
915 let name = group
916 .remove("name")
917 .or_else(|| group.remove("id"))
918 .or_else(|| group.remove("key"))
919 .ok_or_else(|| {
920 format!(
921 "groups[{idx}] with metadata must include a string name, id, or key"
922 )
923 })?;
924 let JsonValue::String(name) = name else {
925 return Err(format!("groups[{idx}] name/id/key must be a string"));
926 };
927 if name.is_empty() {
928 return Err(format!("groups[{idx}] name/id/key must be non-empty"));
929 }
930 if out.insert(name.clone(), metadata).is_some() {
931 return Err(format!("duplicate group metadata key '{name}'"));
932 }
933 }
934 Ok(nonempty_group_metadata(out))
935 }
936 _ => Err("groups must be an object map or an array of group objects".to_string()),
937 }
938}
939
940fn parse_gpu_policy(raw_gpu: &str) -> Result<gam_gpu::GpuPolicy, String> {
941 gam_gpu::GpuPolicy::parse(raw_gpu).ok_or_else(|| {
942 format!(
943 "invalid gpu policy '{}'; supported values are auto, off, force",
944 raw_gpu
945 )
946 })
947}
948
949fn validate_resolved_fit_config(config: &FitConfig) -> Result<(), String> {
950 if !config.ridge_lambda.is_finite() || config.ridge_lambda < 0.0 {
951 return Err("--ridge-lambda must be finite and >= 0".to_string());
952 }
953 let likelihood_mode = parse_survival_likelihood_mode(&config.survival_likelihood)?;
954 validate_survival_baseline_args(
955 likelihood_mode,
956 &config.baseline_target,
957 config.baseline_scale,
958 config.baseline_shape,
959 config.baseline_rate,
960 config.baseline_makeham,
961 )
962}
963
/// Usage hint appended to survival `--link` errors, listing every accepted link form.
fn survival_link_usage() -> &'static str {
    const USAGE: &str = "use identity|logit|probit|cloglog|loglog|cauchit|sas|beta-logistic|blended(...)/mixture(...) or flexible(...)";
    USAGE
}
967
#[cfg(test)]
mod tests {
    //! Unit tests for fit-config resolution and the small CLI parsing helpers.
    //!
    //! The parity tests check two things. A CLI-shaped `CliFitConfigInput` and
    //! the equivalent JSON wire payload must resolve to the same `FitConfig`.
    //! Invalid inputs must be rejected with the same error message on both
    //! paths. The remaining tests pin the individual string parsers and the
    //! survival-baseline argument validation.

    use super::*;
    use gam_models::survival::lognormal_kernel::FrailtySpec;
    use serde_json::{Value, json};

    /// One parity scenario: a CLI-shaped input and the JSON payload that is
    /// expected to resolve to the same outcome.
    struct ParityCase {
        /// Label reported in assertion and panic messages.
        name: &'static str,
        cli: CliFitConfigInput,
        json: Value,
    }

    /// CLI input with every field set to a baseline value.
    ///
    /// Each parity case starts from this and overrides only the fields it
    /// exercises. The paired JSON payload omits the untouched fields, so
    /// these values presumably mirror the defaults applied on the JSON path.
    /// Verify against `parse_fit_config_json` if a baseline value changes.
    fn base_cli() -> CliFitConfigInput {
        CliFitConfigInput {
            family: None,
            expectile_tau: None,
            negative_binomial_theta: None,
            link: None,
            flexible_link: false,
            offset_column: None,
            weight_column: None,
            noise_offset_column: None,
            baseline_target: "linear".to_string(),
            baseline_scale: None,
            baseline_shape: None,
            baseline_rate: None,
            baseline_makeham: None,
            time_basis: "ispline".to_string(),
            time_degree: 3,
            time_num_internal_knots: 8,
            time_smooth_lambda: 1e-2,
            survival_likelihood: "transformation".to_string(),
            survival_distribution: "gaussian".to_string(),
            threshold_time_k: None,
            threshold_time_degree: 3,
            sigma_time_k: None,
            sigma_time_degree: 3,
            noise_formula: None,
            logslope_formula: None,
            z_column: None,
            scale_dimensions: false,
            adaptive_regularization: None,
            ridge_lambda: 1e-6,
            transformation_normal: false,
            firth: false,
            outer_max_iter: None,
            gpu: None,
            frailty_kind: None,
            frailty_sd: None,
            hazard_loading: None,
        }
    }

    /// Resolves a CLI-shaped input through `resolve_cli_fit_config`.
    fn resolved_cli(input: CliFitConfigInput) -> Result<FitConfig, String> {
        resolve_cli_fit_config(input)
    }

    /// Resolves a JSON wire payload through `parse_fit_config_json`.
    ///
    /// Panics if a successful parse reports a training table kind: these
    /// payloads only carry fit options.
    fn resolved_json(config: Value) -> Result<FitConfig, String> {
        parse_fit_config_json(Some(&config.to_string())).map(|resolved| {
            assert_eq!(resolved.training_table_kind, None);
            resolved.fit_config
        })
    }

    /// Pretty-`Debug` rendering of a config used for structural comparison.
    ///
    /// An unset `frailty` is normalized to `Some(FrailtySpec::None)` first.
    /// The two resolution paths presumably differ in whether they leave
    /// "no frailty" as `None` or spell it out explicitly.
    fn canonical_fit_config(mut config: FitConfig) -> String {
        if config.frailty.is_none() {
            config.frailty = Some(FrailtySpec::None);
        }
        format!("{config:#?}")
    }

    #[test]
    fn cli_shaped_and_json_wire_config_resolution_match() {
        let cases = vec![
            ParityCase {
                name: "family and link selection",
                cli: {
                    let mut input = base_cli();
                    input.family = Some("binomial".to_string());
                    input.link = Some("probit".to_string());
                    input.flexible_link = true;
                    input
                },
                json: json!({
                    "family": "binomial",
                    "link": "probit",
                    "flexible_link": true
                }),
            },
            ParityCase {
                name: "offset weights ridge and noise offset columns",
                cli: {
                    let mut input = base_cli();
                    input.offset_column = Some("eta_offset".to_string());
                    input.weight_column = Some("case_weight".to_string());
                    input.noise_offset_column = Some("sigma_offset".to_string());
                    input.ridge_lambda = 0.125;
                    input
                },
                json: json!({
                    "offset": "eta_offset",
                    "weights": "case_weight",
                    "noise_offset": "sigma_offset",
                    "ridge_lambda": 0.125
                }),
            },
            ParityCase {
                name: "weibull survival likelihood and baseline scale shape",
                cli: {
                    let mut input = base_cli();
                    input.survival_likelihood = "weibull".to_string();
                    input.baseline_target = "weibull".to_string();
                    input.baseline_scale = Some(2.5);
                    input.baseline_shape = Some(1.75);
                    input
                },
                json: json!({
                    "survival_likelihood": "weibull",
                    "baseline_target": "weibull",
                    "baseline_scale": 2.5,
                    "baseline_shape": 1.75
                }),
            },
            ParityCase {
                name: "transformation survival gompertz makeham baseline",
                cli: {
                    let mut input = base_cli();
                    input.survival_likelihood = "transformation".to_string();
                    input.baseline_target = "gompertz-makeham".to_string();
                    input.baseline_shape = Some(1.2);
                    input.baseline_rate = Some(0.04);
                    input.baseline_makeham = Some(0.01);
                    input
                },
                json: json!({
                    "survival_likelihood": "transformation",
                    "baseline_target": "gompertz-makeham",
                    "baseline_shape": 1.2,
                    "baseline_rate": 0.04,
                    "baseline_makeham": 0.01
                }),
            },
            ParityCase {
                name: "survival likelihood values are canonicalized",
                cli: {
                    let mut input = base_cli();
                    input.survival_likelihood = "TRANSFORMATION".to_string();
                    input
                },
                json: json!({
                    "survival_likelihood": "Transformation"
                }),
            },
            ParityCase {
                name: "noise formula logslope z column and scale dimensions",
                cli: {
                    let mut input = base_cli();
                    input.noise_formula = Some("~ s(age) + treatment".to_string());
                    input.logslope_formula = Some("~ s(dose)".to_string());
                    input.z_column = Some("dose".to_string());
                    input.scale_dimensions = true;
                    input
                },
                json: json!({
                    "noise_formula": "~ s(age) + treatment",
                    "logslope_formula": "~ s(dose)",
                    "z_column": "dose",
                    "scale_dimensions": true
                }),
            },
            ParityCase {
                name: "firth transformation normal outer iterations and adaptive regularization",
                cli: {
                    let mut input = base_cli();
                    input.firth = true;
                    input.transformation_normal = true;
                    input.outer_max_iter = Some(7);
                    input.adaptive_regularization = Some(true);
                    input
                },
                json: json!({
                    "firth": true,
                    "transformation_normal": true,
                    "outer_max_iter": 7,
                    "adaptive_regularization": true
                }),
            },
            ParityCase {
                name: "gpu policy toggle",
                cli: {
                    let mut input = base_cli();
                    input.gpu = Some("off".to_string());
                    input
                },
                json: json!({
                    "gpu": "off"
                }),
            },
            ParityCase {
                name: "hazard multiplier frailty fields",
                cli: {
                    let mut input = base_cli();
                    input.frailty_kind = Some(CliFrailtyKind::HazardMultiplier);
                    input.frailty_sd = Some(0.35);
                    input.hazard_loading = Some(CliHazardLoading::LoadedVsUnloaded);
                    input
                },
                json: json!({
                    "frailty_kind": "hazard-multiplier",
                    "frailty_sd": 0.35,
                    "hazard_loading": "loaded-vs-unloaded"
                }),
            },
            ParityCase {
                name: "gaussian shift frailty fields",
                cli: {
                    let mut input = base_cli();
                    input.frailty_kind = Some(CliFrailtyKind::GaussianShift);
                    input.frailty_sd = Some(0.2);
                    input
                },
                json: json!({
                    "frailty_kind": "gaussian-shift",
                    "frailty_sd": 0.2
                }),
            },
        ];

        for case in cases {
            let cli = resolved_cli(case.cli)
                .unwrap_or_else(|err| panic!("{}: CLI-shaped config failed: {err}", case.name));
            let json = resolved_json(case.json)
                .unwrap_or_else(|err| panic!("{}: JSON wire config failed: {err}", case.name));
            assert_eq!(
                canonical_fit_config(cli),
                canonical_fit_config(json),
                "{}",
                case.name
            );
        }
    }

    #[test]
    fn cli_shaped_and_json_wire_config_resolution_rejections_match() {
        let cases = vec![
            ParityCase {
                name: "negative ridge lambda",
                cli: {
                    let mut input = base_cli();
                    input.ridge_lambda = -1.0;
                    input
                },
                json: json!({
                    "ridge_lambda": -1.0
                }),
            },
            ParityCase {
                name: "unknown gpu policy",
                cli: {
                    let mut input = base_cli();
                    input.gpu = Some("cuda".to_string());
                    input
                },
                json: json!({
                    "gpu": "cuda"
                }),
            },
            ParityCase {
                name: "linear baseline rejects shape",
                cli: {
                    let mut input = base_cli();
                    input.baseline_shape = Some(1.1);
                    input
                },
                json: json!({
                    "baseline_shape": 1.1
                }),
            },
            ParityCase {
                name: "weibull likelihood rejects gompertz target",
                cli: {
                    let mut input = base_cli();
                    input.survival_likelihood = "weibull".to_string();
                    input.baseline_target = "gompertz".to_string();
                    input
                },
                json: json!({
                    "survival_likelihood": "weibull",
                    "baseline_target": "gompertz"
                }),
            },
        ];

        // Both paths must fail, and with byte-identical messages.
        for case in cases {
            let cli = resolved_cli(case.cli).expect_err(case.name);
            let json = resolved_json(case.json).expect_err(case.name);
            assert_eq!(cli, json, "{}", case.name);
        }
    }

    #[test]
    fn parse_comma_f64_empty_string_returns_empty_vec() {
        assert_eq!(parse_comma_f64("", "x").unwrap(), Vec::<f64>::new());
        assert_eq!(parse_comma_f64(" ", "x").unwrap(), Vec::<f64>::new());
    }

    #[test]
    fn parse_comma_f64_single_value() {
        // Deliberately not 3.14: clippy's deny-by-default `approx_constant`
        // lint rejects PI look-alike literals and breaks `clippy --all-targets`.
        assert_eq!(parse_comma_f64("2.75", "x").unwrap(), vec![2.75]);
    }

    #[test]
    fn parse_comma_f64_multiple_values_with_spaces() {
        let result = parse_comma_f64("1.0, 2.5, -3.0", "x").unwrap();
        assert_eq!(result, vec![1.0, 2.5, -3.0]);
    }

    #[test]
    fn parse_comma_f64_non_numeric_returns_error() {
        let err = parse_comma_f64("1.0, bad, 3.0", "--vals").unwrap_err();
        assert!(err.contains("--vals"), "error should name the label: {err}");
        assert!(err.contains("bad"), "error should name the bad token: {err}");
    }

    #[test]
    fn parse_comma_f64_infinity_returns_error() {
        let err = parse_comma_f64("inf", "--vals").unwrap_err();
        assert!(err.contains("non-finite"), "error should say non-finite: {err}");
    }

    #[test]
    fn parse_comma_f64_nan_returns_error() {
        let err = parse_comma_f64("nan", "--vals").unwrap_err();
        assert!(err.contains("non-finite"), "error should say non-finite: {err}");
    }

    #[test]
    fn normalize_optional_family_none_passthrough() {
        assert_eq!(normalize_optional_family(None), None);
    }

    #[test]
    fn normalize_optional_family_auto_becomes_none() {
        assert_eq!(normalize_optional_family(Some("auto".to_string())), None);
        assert_eq!(normalize_optional_family(Some("Auto".to_string())), None);
        assert_eq!(normalize_optional_family(Some("AUTO".to_string())), None);
    }

    #[test]
    fn normalize_optional_family_non_auto_passthrough() {
        assert_eq!(
            normalize_optional_family(Some("binomial".to_string())),
            Some("binomial".to_string())
        );
        assert_eq!(
            normalize_optional_family(Some("gaussian".to_string())),
            Some("gaussian".to_string())
        );
    }

    #[test]
    fn parse_survival_likelihood_cli_valid_values() {
        assert_eq!(
            parse_survival_likelihood_cli("transformation").unwrap(),
            "transformation"
        );
        assert_eq!(parse_survival_likelihood_cli("weibull").unwrap(), "weibull");
        // Matching ignores case and the canonical lowercase spelling is returned.
        assert_eq!(parse_survival_likelihood_cli("WEIBULL").unwrap(), "weibull");
        assert_eq!(
            parse_survival_likelihood_cli("Transformation").unwrap(),
            "transformation"
        );
    }

    #[test]
    fn parse_survival_likelihood_cli_invalid_returns_error() {
        assert!(parse_survival_likelihood_cli("lognormal").is_err());
        assert!(parse_survival_likelihood_cli("").is_err());
    }

    #[test]
    fn parse_baseline_target_cli_valid_values() {
        for target in &["linear", "weibull", "gompertz", "gompertz-makeham"] {
            assert_eq!(
                parse_baseline_target_cli(target).unwrap(),
                *target,
                "should accept '{target}'"
            );
        }
        // Surrounding whitespace is trimmed and case is folded.
        assert_eq!(parse_baseline_target_cli(" Weibull ").unwrap(), "weibull");
    }

    #[test]
    fn parse_baseline_target_cli_invalid_returns_error() {
        let err = parse_baseline_target_cli("cox").unwrap_err();
        assert!(err.contains("cox"), "error should name the bad value: {err}");
    }

    // Argument order for `validate_survival_baseline_args` is
    // (mode, target, scale, shape, rate, makeham).

    #[test]
    fn validate_survival_baseline_args_linear_rejects_params() {
        let mode = parse_survival_likelihood_mode("transformation").unwrap();
        assert!(
            validate_survival_baseline_args(mode, "linear", Some(1.0), None, None, None).is_err()
        );
    }

    #[test]
    fn validate_survival_baseline_args_linear_accepts_no_params() {
        let mode = parse_survival_likelihood_mode("transformation").unwrap();
        assert!(validate_survival_baseline_args(mode, "linear", None, None, None, None).is_ok());
    }

    #[test]
    fn validate_survival_baseline_args_weibull_likelihood_rejects_gompertz_target() {
        let mode = parse_survival_likelihood_mode("weibull").unwrap();
        assert!(
            validate_survival_baseline_args(mode, "gompertz", None, None, None, None).is_err()
        );
    }

    #[test]
    fn validate_survival_baseline_args_gompertz_rejects_scale() {
        let mode = parse_survival_likelihood_mode("transformation").unwrap();
        assert!(
            validate_survival_baseline_args(mode, "gompertz", Some(2.0), None, None, None)
                .is_err()
        );
    }

    #[test]
    fn validate_survival_baseline_args_gompertz_makeham_rejects_scale() {
        let mode = parse_survival_likelihood_mode("transformation").unwrap();
        assert!(validate_survival_baseline_args(
            mode,
            "gompertz-makeham",
            Some(1.0),
            None,
            None,
            None
        )
        .is_err());
    }
}