1use gam_models::survival::location_scale::residual_distribution_inverse_link;
2use gam_models::survival::lognormal_kernel::{FrailtySpec, HazardLoading};
3use gam_models::survival::parse_survival_distribution;
4use gam_models::survival::{SurvivalLikelihoodMode, parse_survival_likelihood_mode};
5use gam_inference::formula_dsl::parse_link_choice;
6use gam_inference::model::GroupMetadata;
7use gam_solve::mixture_link::{state_from_beta_logisticspec, state_from_sasspec, state_fromspec};
8use gam_models::fit_orchestration::descriptors::build_analytic_penalty_registry_from_descriptors;
9use gam_models::fit_orchestration::{CtnStage1Recipe, FitConfig};
10use gam_models::transformation_normal::TransformationNormalConfig;
11use gam_problem::types::{InverseLink, LinkFunction, MixtureLinkSpec, SasLinkSpec, StandardLink};
12use ndarray::Array1;
13use serde::Deserialize;
14use serde_json::Value as JsonValue;
15use std::collections::BTreeMap;
16
/// Frailty model kind as selected on the command line.
///
/// [`resolve_cli_frailty_spec`] turns this into the matching `FrailtySpec`
/// variant; its error messages refer to the variants as `gaussian-shift`
/// and `hazard-multiplier`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CliFrailtyKind {
    /// Resolves to `FrailtySpec::GaussianShift`; a hazard loading is rejected.
    GaussianShift,
    /// Resolves to `FrailtySpec::HazardMultiplier`; a hazard loading is required.
    HazardMultiplier,
}
22
/// Hazard-loading choice as selected on the command line.
///
/// Each variant is mapped onto the `HazardLoading` variant of the same name
/// by `cli_hazard_loading`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CliHazardLoading {
    /// Maps to `HazardLoading::Full`.
    Full,
    /// Maps to `HazardLoading::LoadedVsUnloaded`.
    LoadedVsUnloaded,
}
28
/// Wire-format shape of the JSON fit configuration.
///
/// Every key is optional and unknown keys are rejected while deserializing
/// (`deny_unknown_fields`). `resolve_json_fit_config` applies the supplied
/// values on top of `FitConfig::default()`.
#[derive(Default, Deserialize)]
#[serde(deny_unknown_fields)]
struct JsonFitConfig {
    /// Family name; passed through `normalize_optional_family`, which maps "auto" to `None`.
    family: Option<String>,
    /// Stored as `FitConfig::offset_column`.
    offset: Option<String>,
    /// Stored as `FitConfig::weight_column`.
    weights: Option<String>,
    /// Must be finite and >= 0 (checked by `validate_resolved_fit_config`).
    ridge_lambda: Option<f64>,
    transformation_normal: Option<bool>,
    /// Trimmed and lower-cased, then checked by `parse_survival_likelihood_mode`.
    survival_likelihood: Option<String>,
    /// Trimmed; must not be empty.
    baseline_target: Option<String>,
    baseline_scale: Option<f64>,
    baseline_shape: Option<f64>,
    baseline_rate: Option<f64>,
    baseline_makeham: Option<f64>,
    /// Trimmed; must not be empty.
    z_column: Option<String>,
    logslope_formula: Option<String>,
    /// Converted with `JsonCtnStage1::into_recipe`.
    ctn_stage1: Option<JsonCtnStage1>,
    /// Trimmed; an empty string resolves to `None`.
    link: Option<String>,
    flexible_link: Option<bool>,
    scale_dimensions: Option<bool>,
    adaptive_regularization: Option<bool>,
    noise_formula: Option<String>,
    /// Stored as `FitConfig::noise_offset_column`.
    noise_offset: Option<String>,
    firth: Option<bool>,
    /// Must be >= 1 when present.
    outer_max_iter: Option<usize>,
    /// Parsed by `parse_gpu_policy` into `FitConfig::gpu_policy`.
    gpu: Option<String>,
    /// Mutually exclusive with `groups` (see `parse_group_metadata`).
    group_metadata: Option<GroupMetadata>,
    /// Alternative source of group metadata: an object map or an array of group objects.
    groups: Option<JsonValue>,
    /// Mutually exclusive with `penalty_block_gamma_priors`; both feed
    /// `FitConfig::penalty_block_gamma_priors`.
    precision_hyperpriors: Option<JsonValue>,
    penalty_block_gamma_priors: Option<JsonValue>,
    /// Checked together with `penalties` by
    /// `build_analytic_penalty_registry_from_descriptors` before being stored.
    latents: Option<JsonValue>,
    /// Stored as `FitConfig::analytic_penalties`.
    penalties: Option<JsonValue>,
    /// Stored as `FitConfig::smooth_overrides`.
    smooths: Option<JsonValue>,
    /// Decoded with `TopologyAutoSelector::from_json`.
    topology_auto_selector: Option<JsonValue>,
    /// Combined with `frailty_sd` and `hazard_loading` by `parse_json_frailty_spec`.
    frailty_kind: Option<String>,
    frailty_sd: Option<f64>,
    hazard_loading: Option<String>,
    /// Returned unchanged in `ResolvedFitConfig::training_table_kind`.
    training_table_kind: Option<String>,
}
68
/// Wire-format description of the `ctn_stage1` entry of the JSON fit config.
///
/// `response_column` and `covariate_formula_rhs` are required keys; the
/// remaining keys are optional. Unknown keys are rejected.
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct JsonCtnStage1 {
    /// Passed as the first argument of `CtnStage1Recipe::new`.
    response_column: String,
    /// Passed as the second argument of `CtnStage1Recipe::new`.
    covariate_formula_rhs: String,
    /// Optional overrides applied on top of `TransformationNormalConfig::default()`.
    #[serde(default)]
    config: Option<JsonCtnStage1Config>,
    #[serde(default)]
    weight_column: Option<String>,
    #[serde(default)]
    offset_column: Option<String>,
}
81
/// Optional per-field overrides for `TransformationNormalConfig`.
///
/// A key that is absent leaves the corresponding default untouched (see
/// `JsonCtnStage1::into_recipe`). Unknown keys are rejected.
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct JsonCtnStage1Config {
    #[serde(default)]
    response_degree: Option<usize>,
    #[serde(default)]
    response_num_internal_knots: Option<usize>,
    #[serde(default)]
    response_penalty_order: Option<usize>,
    #[serde(default)]
    response_extra_penalty_orders: Option<Vec<usize>>,
    #[serde(default)]
    double_penalty: Option<bool>,
}
96
97impl JsonCtnStage1 {
98 fn into_recipe(self) -> Result<CtnStage1Recipe, String> {
99 let mut config = TransformationNormalConfig::default();
100 if let Some(overrides) = self.config {
101 if let Some(value) = overrides.response_degree {
102 config.response_degree = value;
103 }
104 if let Some(value) = overrides.response_num_internal_knots {
105 config.response_num_internal_knots = value;
106 }
107 if let Some(value) = overrides.response_penalty_order {
108 config.response_penalty_order = value;
109 }
110 if let Some(value) = overrides.response_extra_penalty_orders {
111 config.response_extra_penalty_orders = value;
112 }
113 if let Some(value) = overrides.double_penalty {
114 config.double_penalty = value;
115 }
116 }
117 CtnStage1Recipe::new(
118 &self.response_column,
119 &self.covariate_formula_rhs,
120 config,
121 self.weight_column.as_deref(),
122 self.offset_column.as_deref(),
123 )
124 }
125}
126
/// Result of resolving a JSON fit configuration (see [`parse_fit_config_json`]).
pub struct ResolvedFitConfig {
    /// The resolved and validated fit configuration.
    pub fit_config: FitConfig,
    /// The `training_table_kind` value from the JSON input, passed through unchanged.
    pub training_table_kind: Option<String>,
}
131
/// Command-line shaped inputs consumed by [`resolve_cli_fit_config`].
///
/// Most fields are copied onto the `FitConfig` field of the same name. The
/// exceptions are documented on the individual fields below.
pub struct CliFitConfigInput {
    pub family: Option<String>,
    pub expectile_tau: Option<f64>,
    pub negative_binomial_theta: Option<f64>,
    pub link: Option<String>,
    pub flexible_link: bool,
    pub offset_column: Option<String>,
    pub weight_column: Option<String>,
    pub noise_offset_column: Option<String>,
    pub baseline_target: String,
    pub baseline_scale: Option<f64>,
    pub baseline_shape: Option<f64>,
    pub baseline_rate: Option<f64>,
    pub baseline_makeham: Option<f64>,
    pub time_basis: String,
    pub time_degree: usize,
    pub time_num_internal_knots: usize,
    pub time_smooth_lambda: f64,
    /// Normalized (trimmed, lower-cased) and validated by
    /// [`parse_survival_likelihood_cli`] before being stored.
    pub survival_likelihood: String,
    pub survival_distribution: String,
    pub threshold_time_k: Option<usize>,
    pub threshold_time_degree: usize,
    pub sigma_time_k: Option<usize>,
    pub sigma_time_degree: usize,
    pub noise_formula: Option<String>,
    pub logslope_formula: Option<String>,
    pub z_column: Option<String>,
    pub scale_dimensions: bool,
    pub adaptive_regularization: Option<bool>,
    /// Must be finite and >= 0.
    pub ridge_lambda: f64,
    pub transformation_normal: bool,
    pub firth: bool,
    pub outer_max_iter: Option<usize>,
    /// Raw GPU policy string; when present it is parsed into `FitConfig::gpu_policy`.
    pub gpu: Option<String>,
    /// Combined with `frailty_sd` and `hazard_loading` by
    /// [`resolve_cli_frailty_spec`] into `FitConfig::frailty`.
    pub frailty_kind: Option<CliFrailtyKind>,
    pub frailty_sd: Option<f64>,
    pub hazard_loading: Option<CliHazardLoading>,
}
175
/// Raw string inputs for [`parse_survival_inverse_link`].
pub struct SurvivalInverseLinkInput<'a> {
    /// Requested link, if any; handed to `parse_link_choice`.
    pub link: Option<&'a str>,
    /// Comma-separated initial rho values; only accepted with a blended/mixture link.
    pub mixture_rho: Option<&'a str>,
    /// Comma-separated `epsilon,log_delta`; only accepted with the SAS link.
    pub sas_init: Option<&'a str>,
    /// Comma-separated `epsilon,delta`; only accepted with the beta-logistic link.
    pub beta_logistic_init: Option<&'a str>,
    /// Residual distribution name; consulted only when no link is given.
    pub survival_distribution: &'a str,
}
183
184pub fn parse_fit_config_json(config_json: Option<&str>) -> Result<ResolvedFitConfig, String> {
185 let json_config = match config_json {
186 Some(raw) if !raw.trim().is_empty() => serde_json::from_str::<JsonFitConfig>(raw)
187 .map_err(|err| format!("invalid fit config json: {err}"))?,
188 _ => JsonFitConfig::default(),
189 };
190 resolve_json_fit_config(json_config)
191}
192
193fn resolve_json_fit_config(json_config: JsonFitConfig) -> Result<ResolvedFitConfig, String> {
194 let training_table_kind = json_config.training_table_kind;
195 let mut fit_config = FitConfig::default();
196 fit_config.group_metadata =
197 parse_group_metadata(json_config.group_metadata, json_config.groups)?;
198 fit_config.penalty_block_gamma_priors = parse_precision_hyperpriors(
199 json_config.precision_hyperpriors,
200 json_config.penalty_block_gamma_priors,
201 )?;
202 let analytic_penalties = json_config.penalties;
203 build_analytic_penalty_registry_from_descriptors(
204 json_config.latents.as_ref(),
205 analytic_penalties.as_ref(),
206 )?;
207 fit_config.latents = json_config.latents;
208 fit_config.analytic_penalties = analytic_penalties;
209 fit_config.smooth_overrides = json_config.smooths;
210 fit_config.topology_auto_selector = json_config
211 .topology_auto_selector
212 .as_ref()
213 .map(gam_solve::topology_selector::TopologyAutoSelector::from_json)
214 .transpose()?;
215 fit_config.family = normalize_optional_family(json_config.family);
216 fit_config.offset_column = json_config.offset;
217 fit_config.weight_column = json_config.weights;
218 if let Some(ridge_lambda) = json_config.ridge_lambda {
219 fit_config.ridge_lambda = ridge_lambda;
220 }
221 if let Some(flag) = json_config.transformation_normal {
222 fit_config.transformation_normal = flag;
223 }
224 if let Some(mode) = json_config.survival_likelihood {
225 fit_config.survival_likelihood = parse_survival_likelihood_cli(&mode)?;
226 }
227 if let Some(target) = json_config.baseline_target {
228 fit_config.baseline_target =
229 resolve_nonempty_string(target, "baseline_target must be a non-empty string")?;
230 }
231 if let Some(value) = json_config.baseline_scale {
232 fit_config.baseline_scale = Some(value);
233 }
234 if let Some(value) = json_config.baseline_shape {
235 fit_config.baseline_shape = Some(value);
236 }
237 if let Some(value) = json_config.baseline_rate {
238 fit_config.baseline_rate = Some(value);
239 }
240 if let Some(value) = json_config.baseline_makeham {
241 fit_config.baseline_makeham = Some(value);
242 }
243 if let Some(z) = json_config.z_column {
244 fit_config.z_column = Some(resolve_nonempty_string(
245 z,
246 "z_column must be a non-empty string",
247 )?);
248 }
249 if let Some(formula) = json_config.logslope_formula {
250 fit_config.logslope_formula = Some(formula);
251 }
252 if let Some(stage1) = json_config.ctn_stage1 {
253 fit_config.ctn_stage1 = Some(stage1.into_recipe()?);
254 }
255 if let Some(link) = json_config.link {
256 let trimmed = link.trim();
257 fit_config.link = if trimmed.is_empty() {
258 None
259 } else {
260 Some(trimmed.to_string())
261 };
262 }
263 if let Some(flag) = json_config.flexible_link {
264 fit_config.flexible_link = flag;
265 }
266 if let Some(flag) = json_config.scale_dimensions {
267 fit_config.scale_dimensions = flag;
268 }
269 if let Some(flag) = json_config.adaptive_regularization {
270 fit_config.adaptive_regularization = Some(flag);
271 }
272 if let Some(formula) = json_config.noise_formula {
273 fit_config.noise_formula = Some(formula);
274 }
275 if let Some(column) = json_config.noise_offset {
276 fit_config.noise_offset_column = Some(column);
277 }
278 if let Some(flag) = json_config.firth {
279 fit_config.firth = flag;
280 }
281 if let Some(value) = json_config.outer_max_iter {
282 if value == 0 {
283 return Err("outer_max_iter must be >= 1".to_string());
284 }
285 fit_config.outer_max_iter = Some(value);
286 }
287 if let Some(raw_gpu) = json_config.gpu {
288 fit_config.gpu_policy = parse_gpu_policy(&raw_gpu)?;
289 }
290 fit_config.frailty = parse_json_frailty_spec(
291 json_config.frailty_kind,
292 json_config.frailty_sd,
293 json_config.hazard_loading,
294 )?;
295 validate_resolved_fit_config(&fit_config)?;
296 Ok(ResolvedFitConfig {
297 fit_config,
298 training_table_kind,
299 })
300}
301
302pub fn resolve_cli_fit_config(input: CliFitConfigInput) -> Result<FitConfig, String> {
303 let mut fit_config = FitConfig::default();
304 fit_config.family = input.family;
305 fit_config.expectile_tau = input.expectile_tau;
306 fit_config.negative_binomial_theta = input.negative_binomial_theta;
307 fit_config.link = input.link;
308 fit_config.flexible_link = input.flexible_link;
309 fit_config.offset_column = input.offset_column;
310 fit_config.weight_column = input.weight_column;
311 fit_config.noise_offset_column = input.noise_offset_column;
312 fit_config.baseline_target = input.baseline_target;
313 fit_config.baseline_scale = input.baseline_scale;
314 fit_config.baseline_shape = input.baseline_shape;
315 fit_config.baseline_rate = input.baseline_rate;
316 fit_config.baseline_makeham = input.baseline_makeham;
317 fit_config.time_basis = input.time_basis;
318 fit_config.time_degree = input.time_degree;
319 fit_config.time_num_internal_knots = input.time_num_internal_knots;
320 fit_config.time_smooth_lambda = input.time_smooth_lambda;
321 fit_config.survival_likelihood = parse_survival_likelihood_cli(&input.survival_likelihood)?;
322 fit_config.survival_distribution = input.survival_distribution;
323 fit_config.threshold_time_k = input.threshold_time_k;
324 fit_config.threshold_time_degree = input.threshold_time_degree;
325 fit_config.sigma_time_k = input.sigma_time_k;
326 fit_config.sigma_time_degree = input.sigma_time_degree;
327 fit_config.noise_formula = input.noise_formula;
328 fit_config.logslope_formula = input.logslope_formula;
329 fit_config.z_column = input.z_column;
330 fit_config.scale_dimensions = input.scale_dimensions;
331 fit_config.adaptive_regularization = input.adaptive_regularization;
332 fit_config.ridge_lambda = input.ridge_lambda;
333 fit_config.transformation_normal = input.transformation_normal;
334 fit_config.firth = input.firth;
335 fit_config.outer_max_iter = input.outer_max_iter;
336 if let Some(raw_gpu) = input.gpu {
337 fit_config.gpu_policy = parse_gpu_policy(&raw_gpu)?;
338 }
339 fit_config.frailty = Some(resolve_cli_frailty_spec(
340 input.frailty_kind,
341 input.frailty_sd,
342 input.hazard_loading,
343 "fit",
344 )?);
345 validate_resolved_fit_config(&fit_config)?;
346 Ok(fit_config)
347}
348
349pub fn resolve_cli_frailty_spec(
350 frailty_kind: Option<CliFrailtyKind>,
351 frailty_sd: Option<f64>,
352 hazard_loading: Option<CliHazardLoading>,
353 context: &str,
354) -> Result<FrailtySpec, String> {
355 let validate_sigma = || -> Result<Option<f64>, String> {
356 match frailty_sd {
357 None => Ok(None),
358 Some(sigma) => {
359 if !sigma.is_finite() || sigma < 0.0 {
360 return Err(format!(
361 "{context} requires a finite --frailty-sd >= 0, got {sigma}"
362 ));
363 }
364 Ok(Some(sigma))
365 }
366 }
367 };
368
369 match frailty_kind {
370 None => {
371 if frailty_sd.is_some() || hazard_loading.is_some() {
372 return Err(format!(
373 "{context} requires --frailty-kind when --frailty-sd or --hazard-loading is provided"
374 ));
375 }
376 Ok(FrailtySpec::None)
377 }
378 Some(CliFrailtyKind::GaussianShift) => {
379 if hazard_loading.is_some() {
380 return Err(format!(
381 "{context} does not accept --hazard-loading with --frailty-kind gaussian-shift"
382 ));
383 }
384 Ok(FrailtySpec::GaussianShift {
385 sigma_fixed: validate_sigma()?,
386 })
387 }
388 Some(CliFrailtyKind::HazardMultiplier) => Ok(FrailtySpec::HazardMultiplier {
389 sigma_fixed: validate_sigma()?,
390 loading: hazard_loading.map(cli_hazard_loading).ok_or_else(|| {
391 format!("{context} requires --hazard-loading with --frailty-kind hazard-multiplier")
392 })?,
393 }),
394 }
395}
396
397pub fn parse_survival_likelihood_cli(raw: &str) -> Result<String, String> {
398 let normalized = raw.trim().to_ascii_lowercase();
399 parse_survival_likelihood_mode(&normalized)?;
400 Ok(normalized)
401}
402
/// Normalizes a `--baseline-target` value (trim + ASCII lower-case) and checks
/// that it is one of `linear`, `weibull`, `gompertz`, `gompertz-makeham`.
///
/// # Errors
///
/// Returns an `unsupported --baseline-target` message for any other value.
pub fn parse_baseline_target_cli(raw: &str) -> Result<String, String> {
    const SUPPORTED: [&str; 4] = ["linear", "weibull", "gompertz", "gompertz-makeham"];

    let normalized = raw.trim().to_ascii_lowercase();
    if SUPPORTED.contains(&normalized.as_str()) {
        return Ok(normalized);
    }
    let other = normalized;
    Err(format!(
        "unsupported --baseline-target '{other}'; use linear|weibull|gompertz|gompertz-makeham"
    ))
}
412
413pub fn validate_survival_baseline_args(
414 likelihood_mode: SurvivalLikelihoodMode,
415 baseline_target: &str,
416 baseline_scale: Option<f64>,
417 baseline_shape: Option<f64>,
418 baseline_rate: Option<f64>,
419 baseline_makeham: Option<f64>,
420) -> Result<(), String> {
421 if likelihood_mode == SurvivalLikelihoodMode::Weibull {
422 if baseline_rate.is_some() || baseline_makeham.is_some() {
423 return Err(
424 "--survival-likelihood weibull does not use --baseline-rate or --baseline-makeham"
425 .to_string(),
426 );
427 }
428 if !matches!(baseline_target, "linear" | "weibull") {
429 return Err(
430 "--survival-likelihood weibull supports only --baseline-target linear|weibull"
431 .to_string(),
432 );
433 }
434 return Ok(());
435 }
436
437 match baseline_target {
438 "linear" => {
439 if baseline_scale.is_some()
440 || baseline_shape.is_some()
441 || baseline_rate.is_some()
442 || baseline_makeham.is_some()
443 {
444 return Err(
445 "--baseline-target linear does not use baseline parameter flags".to_string(),
446 );
447 }
448 }
449 "weibull" => {
450 if baseline_rate.is_some() || baseline_makeham.is_some() {
451 return Err(
452 "--baseline-target weibull does not use --baseline-rate or --baseline-makeham"
453 .to_string(),
454 );
455 }
456 }
457 "gompertz" => {
458 if baseline_scale.is_some() || baseline_makeham.is_some() {
459 return Err(
460 "--baseline-target gompertz does not use --baseline-scale or --baseline-makeham"
461 .to_string(),
462 );
463 }
464 }
465 "gompertz-makeham" => {
466 if baseline_scale.is_some() {
467 return Err(
468 "--baseline-target gompertz-makeham does not use --baseline-scale".to_string(),
469 );
470 }
471 }
472 other => {
473 return Err(format!(
474 "unsupported --baseline-target '{other}'; use linear|weibull|gompertz|gompertz-makeham"
475 ));
476 }
477 }
478 Ok(())
479}
480
/// Parses a comma-separated list of finite floating-point numbers.
///
/// Each item is trimmed; empty items (e.g. from a trailing or doubled comma)
/// are skipped. `label` prefixes every error message.
///
/// # Errors
///
/// Fails on the first item that does not parse as `f64` or that parses to a
/// non-finite value.
pub fn parse_comma_f64(v: &str, label: &str) -> Result<Vec<f64>, String> {
    v.split(',')
        .map(str::trim)
        .filter(|t| !t.is_empty())
        .map(|t| match t.parse::<f64>() {
            Ok(number) if number.is_finite() => Ok(number),
            Ok(_) => Err(format!("{label} contains non-finite value '{t}'")),
            Err(err) => Err(format!("{label} contains non-numeric value '{t}': {err}")),
        })
        .collect()
}
498
499pub fn effective_link_to_standard(
500 link: LinkFunction,
501 context: &str,
502) -> Result<StandardLink, String> {
503 StandardLink::try_from(link).map_err(|_| {
504 format!(
505 "{context}: state-bearing link `{}` must be routed through `InverseLink::Sas` / `InverseLink::BetaLogistic`, not `Standard(_)`",
506 link.name()
507 )
508 })
509}
510
511pub fn parse_survival_inverse_link(
512 input: SurvivalInverseLinkInput<'_>,
513) -> Result<InverseLink, String> {
514 if let Some(raw) = input.link {
515 let name = raw.trim().to_ascii_lowercase();
516 if name == "loglog" || name == "cauchit" {
517 return Err(format!(
518 "survival --link {name} is not supported: cauchit and loglog have no \
519 LinkFunction representative and cannot be wrapped in a MixtureLinkSpec; \
520 {}",
521 survival_link_usage()
522 ));
523 }
524 }
525 let choice = parse_link_choice(input.link, false).map_err(|err| {
526 let err = err.to_string();
527 if let Some(raw) = input.link {
528 let name = raw.trim().to_ascii_lowercase();
529 if err.starts_with("unsupported --link ") || err.starts_with("unsupported link type ") {
530 return format!(
531 "unsupported survival --link '{name}'; {}",
532 survival_link_usage()
533 );
534 }
535 }
536 err
537 })?;
538 if let Some(choice) = choice {
539 if let Some(components) = choice.mixture_components {
540 if input.sas_init.is_some() || input.beta_logistic_init.is_some() {
541 return Err(
542 "survival blended(...) link does not accept --sas-init/--beta-logistic-init"
543 .to_string(),
544 );
545 }
546 let expected = components.len().saturating_sub(1);
547 let initial_rho = if let Some(raw) = input.mixture_rho {
548 let vals = parse_comma_f64(raw, "--mixture-rho")?;
549 if vals.len() != expected {
550 return Err(format!(
551 "--mixture-rho expects {expected} values for blended({})",
552 components
553 .iter()
554 .map(|component| component.name())
555 .collect::<Vec<_>>()
556 .join(",")
557 ));
558 }
559 Array1::from_vec(vals)
560 } else {
561 Array1::zeros(expected)
562 };
563 return state_fromspec(&MixtureLinkSpec {
564 components,
565 initial_rho,
566 })
567 .map(InverseLink::Mixture)
568 .map_err(|e| format!("invalid survival blended link state: {e}"));
569 }
570
571 if input.mixture_rho.is_some() {
572 return Err(
573 "--mixture-rho requires survival --link blended(...)/mixture(...)".to_string(),
574 );
575 }
576 match choice.link {
577 LinkFunction::Sas => {
578 if input.beta_logistic_init.is_some() {
579 return Err("--beta-logistic-init requires --link beta-logistic".to_string());
580 }
581 let (epsilon, log_delta) = if let Some(raw) = input.sas_init {
582 let vals = parse_comma_f64(raw, "--sas-init")?;
583 if vals.len() != 2 {
584 return Err(format!(
585 "--sas-init expects two values: epsilon,log_delta (got {})",
586 vals.len()
587 ));
588 }
589 (vals[0], vals[1])
590 } else {
591 (0.0, 0.0)
592 };
593 state_from_sasspec(SasLinkSpec {
594 initial_epsilon: epsilon,
595 initial_log_delta: log_delta,
596 })
597 .map(InverseLink::Sas)
598 .map_err(|e| format!("invalid survival SAS link state: {e}"))
599 }
600 LinkFunction::BetaLogistic => {
601 if input.sas_init.is_some() {
602 return Err("--sas-init requires --link sas".to_string());
603 }
604 let (epsilon, delta) = if let Some(raw) = input.beta_logistic_init {
605 let vals = parse_comma_f64(raw, "--beta-logistic-init")?;
606 if vals.len() != 2 {
607 return Err(format!(
608 "--beta-logistic-init expects two values: epsilon,delta (got {})",
609 vals.len()
610 ));
611 }
612 (vals[0], vals[1])
613 } else {
614 (0.0, 0.0)
615 };
616 state_from_beta_logisticspec(SasLinkSpec {
617 initial_epsilon: epsilon,
618 initial_log_delta: delta,
619 })
620 .map(InverseLink::BetaLogistic)
621 .map_err(|e| format!("invalid survival Beta-Logistic link state: {e}"))
622 }
623 LinkFunction::Log => Err(format!(
624 "unsupported survival --link 'log'; {}",
625 survival_link_usage()
626 )),
627 other => {
628 if input.sas_init.is_some() {
629 return Err("--sas-init requires --link sas".to_string());
630 }
631 if input.beta_logistic_init.is_some() {
632 return Err("--beta-logistic-init requires --link beta-logistic".to_string());
633 }
634 Ok(InverseLink::Standard(effective_link_to_standard(
635 other,
636 "survival inverse link",
637 )?))
638 }
639 }
640 } else {
641 if input.mixture_rho.is_some() {
642 return Err("--mixture-rho requires --link blended(...)/mixture(...)".to_string());
643 }
644 if input.sas_init.is_some() {
645 return Err("--sas-init requires --link sas".to_string());
646 }
647 if input.beta_logistic_init.is_some() {
648 return Err("--beta-logistic-init requires --link beta-logistic".to_string());
649 }
650 let dist = parse_survival_distribution(input.survival_distribution)?;
651 Ok(residual_distribution_inverse_link(dist))
652 }
653}
654
/// Maps the `auto` sentinel to `None`, leaving every other family untouched.
///
/// The comparison ignores ASCII case and surrounding whitespace, so `"auto"`,
/// `"AUTO"` and `" auto "` all mean "no explicit family". Values other than the
/// sentinel are returned exactly as given (they are not trimmed).
pub fn normalize_optional_family(family: Option<String>) -> Option<String> {
    family.filter(|value| !value.trim().eq_ignore_ascii_case("auto"))
}
661
/// Trims `raw` and returns it, or fails with `message` when nothing is left.
fn resolve_nonempty_string(raw: String, message: &str) -> Result<String, String> {
    match raw.trim() {
        "" => Err(message.to_string()),
        kept => Ok(kept.to_string()),
    }
}
669
670fn parse_json_frailty_spec(
671 frailty_kind: Option<String>,
672 frailty_sd: Option<f64>,
673 hazard_loading: Option<String>,
674) -> Result<Option<FrailtySpec>, String> {
675 if let Some(kind) = frailty_kind {
676 let trimmed = kind.trim().to_ascii_lowercase();
677 let sigma = frailty_sd;
678 if let Some(value) = sigma
679 && (!value.is_finite() || value < 0.0)
680 {
681 return Err(format!("frailty_sd must be finite and >= 0, got {value}"));
682 }
683 let hazard_loading = hazard_loading
684 .as_ref()
685 .map(|raw| raw.trim().to_ascii_lowercase());
686 let frailty = match trimmed.as_str() {
687 "none" | "" => {
688 if sigma.is_some() || hazard_loading.is_some() {
689 return Err(
690 "frailty_kind='none' does not accept frailty_sd or hazard_loading"
691 .to_string(),
692 );
693 }
694 FrailtySpec::None
695 }
696 "hazard-multiplier" => {
697 let loading = match hazard_loading.as_deref() {
698 Some("full") | None => HazardLoading::Full,
699 Some("loaded-vs-unloaded") => HazardLoading::LoadedVsUnloaded,
700 Some(other) => {
701 return Err(format!(
702 "unknown hazard_loading '{other}'; supported: 'full', 'loaded-vs-unloaded'"
703 ));
704 }
705 };
706 FrailtySpec::HazardMultiplier {
707 sigma_fixed: sigma,
708 loading,
709 }
710 }
711 "gaussian-shift" => {
712 if hazard_loading.is_some() {
713 return Err(
714 "hazard_loading is valid only with frailty_kind='hazard-multiplier'"
715 .to_string(),
716 );
717 }
718 FrailtySpec::GaussianShift { sigma_fixed: sigma }
719 }
720 other => {
721 return Err(format!(
722 "unknown frailty_kind '{other}'; supported: 'none', 'hazard-multiplier', 'gaussian-shift'"
723 ));
724 }
725 };
726 Ok(Some(frailty))
727 } else if frailty_sd.is_some() || hazard_loading.is_some() {
728 Err("frailty_kind is required when frailty_sd or hazard_loading is provided".to_string())
729 } else {
730 Ok(None)
731 }
732}
733
/// Converts the CLI hazard-loading choice into the model-level `HazardLoading`
/// variant of the same name.
fn cli_hazard_loading(loading: CliHazardLoading) -> HazardLoading {
    match loading {
        CliHazardLoading::Full => HazardLoading::Full,
        CliHazardLoading::LoadedVsUnloaded => HazardLoading::LoadedVsUnloaded,
    }
}
740
741fn parse_group_metadata(
742 direct: Option<GroupMetadata>,
743 groups: Option<JsonValue>,
744) -> Result<Option<GroupMetadata>, String> {
745 match (direct, groups) {
746 (Some(metadata), None) => Ok(nonempty_group_metadata(metadata)),
747 (None, Some(groups)) => group_metadata_from_groups(groups),
748 (None, None) => Ok(None),
749 (Some(_), Some(_)) => {
750 Err("fit config accepts either group_metadata or groups metadata, not both".to_string())
751 }
752 }
753}
754
755fn parse_gamma_pair_value(label: &str, value: JsonValue) -> Result<(String, f64, f64), String> {
756 match value {
757 JsonValue::Array(values) => {
758 if values.len() != 2 {
759 return Err(format!(
760 "precision_hyperpriors['{label}'] must be [shape, rate]"
761 ));
762 }
763 let shape = values[0]
764 .as_f64()
765 .ok_or_else(|| format!("precision_hyperpriors['{label}'][0] must be numeric"))?;
766 let rate = values[1]
767 .as_f64()
768 .ok_or_else(|| format!("precision_hyperpriors['{label}'][1] must be numeric"))?;
769 Ok((label.to_string(), shape, rate))
770 }
771 JsonValue::Object(mut map) => {
772 let shape = map
773 .remove("shape")
774 .or_else(|| map.remove("a"))
775 .or_else(|| map.remove("a_p"))
776 .ok_or_else(|| format!("precision_hyperpriors['{label}'] missing shape/a"))?
777 .as_f64()
778 .ok_or_else(|| {
779 format!("precision_hyperpriors['{label}'] shape/a must be numeric")
780 })?;
781 let rate = map
782 .remove("rate")
783 .or_else(|| map.remove("b"))
784 .or_else(|| map.remove("b_p"))
785 .ok_or_else(|| format!("precision_hyperpriors['{label}'] missing rate/b"))?
786 .as_f64()
787 .ok_or_else(|| {
788 format!("precision_hyperpriors['{label}'] rate/b must be numeric")
789 })?;
790 Ok((label.to_string(), shape, rate))
791 }
792 _ => Err(format!(
793 "precision_hyperpriors['{label}'] must be [shape, rate] or an object"
794 )),
795 }
796}
797
798fn parse_precision_hyperpriors(
799 precision_hyperpriors: Option<JsonValue>,
800 penalty_block_gamma_priors: Option<JsonValue>,
801) -> Result<Vec<(String, f64, f64)>, String> {
802 let raw = match (precision_hyperpriors, penalty_block_gamma_priors) {
803 (Some(_), Some(_)) => {
804 return Err(
805 "fit config accepts either precision_hyperpriors or penalty_block_gamma_priors, not both"
806 .to_string(),
807 );
808 }
809 (Some(raw), None) | (None, Some(raw)) => raw,
810 (None, None) => {
811 return Ok(Vec::new());
812 }
813 };
814 let raw_name = "precision_hyperpriors";
815 let Some(raw) = (match raw {
816 JsonValue::Null => None,
817 other => Some(other),
818 }) else {
819 return Ok(Vec::new());
820 };
821 match raw {
822 JsonValue::Object(map) => map
823 .into_iter()
824 .map(|(label, value)| parse_gamma_pair_value(&label, value))
825 .collect(),
826 JsonValue::Array(items) => items
827 .into_iter()
828 .enumerate()
829 .map(|(idx, item)| match item {
830 JsonValue::Object(mut obj) => {
831 let label = obj
832 .remove("label")
833 .or_else(|| obj.remove("name"))
834 .or_else(|| obj.remove("group"))
835 .ok_or_else(|| format!("{raw_name}[{idx}] needs label/name/group"))?;
836 let JsonValue::String(label) = label else {
837 return Err(format!("{raw_name}[{idx}] label must be a string"));
838 };
839 parse_gamma_pair_value(&label, JsonValue::Object(obj))
840 }
841 JsonValue::Array(mut values) => {
842 if values.len() != 2 && values.len() != 3 {
843 return Err(format!(
844 "{raw_name}[{idx}] must be [label, shape, rate] or [label, [shape, rate]]"
845 ));
846 }
847 let label = values.remove(0);
848 let JsonValue::String(label) = label else {
849 return Err(format!("{raw_name}[{idx}][0] must be a string label"));
850 };
851 let pair = if values.len() == 1 {
852 values.remove(0)
853 } else {
854 JsonValue::Array(values)
855 };
856 parse_gamma_pair_value(&label, pair)
857 }
858 _ => Err(format!("{raw_name}[{idx}] must be an object or array")),
859 })
860 .collect(),
861 _ => Err(format!("{raw_name} must be a map or array")),
862 }
863}
864
865fn nonempty_group_metadata(metadata: GroupMetadata) -> Option<GroupMetadata> {
866 if metadata.is_empty() {
867 None
868 } else {
869 Some(metadata)
870 }
871}
872
873fn group_metadata_from_groups(groups: JsonValue) -> Result<Option<GroupMetadata>, String> {
874 match groups {
875 JsonValue::Null => Ok(None),
876 JsonValue::Object(map) => {
877 let out = map.into_iter().collect::<BTreeMap<_, _>>();
878 Ok(nonempty_group_metadata(out))
879 }
880 JsonValue::Array(items) => {
881 let mut out = BTreeMap::new();
882 for (idx, item) in items.into_iter().enumerate() {
883 let JsonValue::Object(mut group) = item else {
884 return Err(format!("groups[{idx}] must be an object"));
885 };
886 let Some(metadata) = group.remove("metadata") else {
887 continue;
888 };
889 let name = group
890 .remove("name")
891 .or_else(|| group.remove("id"))
892 .or_else(|| group.remove("key"))
893 .ok_or_else(|| {
894 format!(
895 "groups[{idx}] with metadata must include a string name, id, or key"
896 )
897 })?;
898 let JsonValue::String(name) = name else {
899 return Err(format!("groups[{idx}] name/id/key must be a string"));
900 };
901 if name.is_empty() {
902 return Err(format!("groups[{idx}] name/id/key must be non-empty"));
903 }
904 if out.insert(name.clone(), metadata).is_some() {
905 return Err(format!("duplicate group metadata key '{name}'"));
906 }
907 }
908 Ok(nonempty_group_metadata(out))
909 }
910 _ => Err("groups must be an object map or an array of group objects".to_string()),
911 }
912}
913
914fn parse_gpu_policy(raw_gpu: &str) -> Result<gam_gpu::GpuPolicy, String> {
915 gam_gpu::GpuPolicy::parse(raw_gpu).ok_or_else(|| {
916 format!(
917 "invalid gpu policy '{}'; supported values are auto, off, force",
918 raw_gpu
919 )
920 })
921}
922
923fn validate_resolved_fit_config(config: &FitConfig) -> Result<(), String> {
924 if !config.ridge_lambda.is_finite() || config.ridge_lambda < 0.0 {
925 return Err("--ridge-lambda must be finite and >= 0".to_string());
926 }
927 let likelihood_mode = parse_survival_likelihood_mode(&config.survival_likelihood)?;
928 validate_survival_baseline_args(
929 likelihood_mode,
930 &config.baseline_target,
931 config.baseline_scale,
932 config.baseline_shape,
933 config.baseline_rate,
934 config.baseline_makeham,
935 )
936}
937
/// Usage hint appended to survival `--link` error messages.
fn survival_link_usage() -> &'static str {
    const USAGE: &str =
        "use identity|logit|probit|cloglog|sas|beta-logistic|blended(...)/mixture(...) or flexible(...)";
    USAGE
}
941
942#[cfg(test)]
943mod tests {
944 use super::*;
945 use gam_models::survival::lognormal_kernel::FrailtySpec;
946 use serde_json::{Value, json};
947
948 struct ParityCase {
949 name: &'static str,
950 cli: CliFitConfigInput,
951 json: Value,
952 }
953
954 fn base_cli() -> CliFitConfigInput {
955 CliFitConfigInput {
956 family: None,
957 expectile_tau: None,
958 negative_binomial_theta: None,
959 link: None,
960 flexible_link: false,
961 offset_column: None,
962 weight_column: None,
963 noise_offset_column: None,
964 baseline_target: "linear".to_string(),
965 baseline_scale: None,
966 baseline_shape: None,
967 baseline_rate: None,
968 baseline_makeham: None,
969 time_basis: "ispline".to_string(),
970 time_degree: 3,
971 time_num_internal_knots: 8,
972 time_smooth_lambda: 1e-2,
973 survival_likelihood: "transformation".to_string(),
974 survival_distribution: "gaussian".to_string(),
975 threshold_time_k: None,
976 threshold_time_degree: 3,
977 sigma_time_k: None,
978 sigma_time_degree: 3,
979 noise_formula: None,
980 logslope_formula: None,
981 z_column: None,
982 scale_dimensions: false,
983 adaptive_regularization: None,
984 ridge_lambda: 1e-6,
985 transformation_normal: false,
986 firth: false,
987 outer_max_iter: None,
988 gpu: None,
989 frailty_kind: None,
990 frailty_sd: None,
991 hazard_loading: None,
992 }
993 }
994
995 fn resolved_cli(input: CliFitConfigInput) -> Result<FitConfig, String> {
996 resolve_cli_fit_config(input)
997 }
998
999 fn resolved_json(config: Value) -> Result<FitConfig, String> {
1000 parse_fit_config_json(Some(&config.to_string())).map(|resolved| {
1001 assert_eq!(resolved.training_table_kind, None);
1002 resolved.fit_config
1003 })
1004 }
1005
1006 fn canonical_fit_config(mut config: FitConfig) -> String {
1007 if config.frailty.is_none() {
1008 config.frailty = Some(FrailtySpec::None);
1009 }
1010 format!("{config:#?}")
1011 }
1012
/// Checks that accepted settings resolve identically through both entry points.
///
/// Each case pairs a CLI-shaped input (the baseline from `base_cli` with a few
/// fields overridden) with the equivalent JSON wire object. Both must resolve
/// successfully and produce the same canonical rendering.
#[test]
fn cli_shaped_and_json_wire_config_resolution_match() {
    let cases = vec![
        ParityCase {
            name: "family and link selection",
            cli: CliFitConfigInput {
                family: Some("binomial".to_string()),
                link: Some("probit".to_string()),
                flexible_link: true,
                ..base_cli()
            },
            json: json!({
                "family": "binomial",
                "link": "probit",
                "flexible_link": true
            }),
        },
        ParityCase {
            name: "offset weights ridge and noise offset columns",
            cli: CliFitConfigInput {
                offset_column: Some("eta_offset".to_string()),
                weight_column: Some("case_weight".to_string()),
                noise_offset_column: Some("sigma_offset".to_string()),
                ridge_lambda: 0.125,
                ..base_cli()
            },
            json: json!({
                "offset": "eta_offset",
                "weights": "case_weight",
                "noise_offset": "sigma_offset",
                "ridge_lambda": 0.125
            }),
        },
        ParityCase {
            name: "weibull survival likelihood and baseline scale shape",
            cli: CliFitConfigInput {
                survival_likelihood: "weibull".to_string(),
                baseline_target: "weibull".to_string(),
                baseline_scale: Some(2.5),
                baseline_shape: Some(1.75),
                ..base_cli()
            },
            json: json!({
                "survival_likelihood": "weibull",
                "baseline_target": "weibull",
                "baseline_scale": 2.5,
                "baseline_shape": 1.75
            }),
        },
        ParityCase {
            name: "transformation survival gompertz makeham baseline",
            cli: CliFitConfigInput {
                survival_likelihood: "transformation".to_string(),
                baseline_target: "gompertz-makeham".to_string(),
                baseline_shape: Some(1.2),
                baseline_rate: Some(0.04),
                baseline_makeham: Some(0.01),
                ..base_cli()
            },
            json: json!({
                "survival_likelihood": "transformation",
                "baseline_target": "gompertz-makeham",
                "baseline_shape": 1.2,
                "baseline_rate": 0.04,
                "baseline_makeham": 0.01
            }),
        },
        ParityCase {
            name: "survival likelihood values are canonicalized",
            cli: CliFitConfigInput {
                survival_likelihood: "TRANSFORMATION".to_string(),
                ..base_cli()
            },
            json: json!({
                "survival_likelihood": "Transformation"
            }),
        },
        ParityCase {
            name: "noise formula logslope z column and scale dimensions",
            cli: CliFitConfigInput {
                noise_formula: Some("~ s(age) + treatment".to_string()),
                logslope_formula: Some("~ s(dose)".to_string()),
                z_column: Some("dose".to_string()),
                scale_dimensions: true,
                ..base_cli()
            },
            json: json!({
                "noise_formula": "~ s(age) + treatment",
                "logslope_formula": "~ s(dose)",
                "z_column": "dose",
                "scale_dimensions": true
            }),
        },
        ParityCase {
            name: "firth transformation normal outer iterations and adaptive regularization",
            cli: CliFitConfigInput {
                firth: true,
                transformation_normal: true,
                outer_max_iter: Some(7),
                adaptive_regularization: Some(true),
                ..base_cli()
            },
            json: json!({
                "firth": true,
                "transformation_normal": true,
                "outer_max_iter": 7,
                "adaptive_regularization": true
            }),
        },
        ParityCase {
            name: "gpu policy toggle",
            cli: CliFitConfigInput {
                gpu: Some("off".to_string()),
                ..base_cli()
            },
            json: json!({
                "gpu": "off"
            }),
        },
        ParityCase {
            name: "hazard multiplier frailty fields",
            cli: CliFitConfigInput {
                frailty_kind: Some(CliFrailtyKind::HazardMultiplier),
                frailty_sd: Some(0.35),
                hazard_loading: Some(CliHazardLoading::LoadedVsUnloaded),
                ..base_cli()
            },
            json: json!({
                "frailty_kind": "hazard-multiplier",
                "frailty_sd": 0.35,
                "hazard_loading": "loaded-vs-unloaded"
            }),
        },
        ParityCase {
            name: "gaussian shift frailty fields",
            cli: CliFitConfigInput {
                frailty_kind: Some(CliFrailtyKind::GaussianShift),
                frailty_sd: Some(0.2),
                ..base_cli()
            },
            json: json!({
                "frailty_kind": "gaussian-shift",
                "frailty_sd": 0.2
            }),
        },
    ];

    for ParityCase { name, cli, json } in cases {
        let from_cli = resolved_cli(cli)
            .unwrap_or_else(|err| panic!("{}: CLI-shaped config failed: {err}", name));
        let from_json = resolved_json(json)
            .unwrap_or_else(|err| panic!("{}: JSON wire config failed: {err}", name));
        assert_eq!(
            canonical_fit_config(from_cli),
            canonical_fit_config(from_json),
            "{}",
            name
        );
    }
}
1183
/// Checks that invalid settings are rejected identically by both entry points.
///
/// Each case must fail through the CLI-shaped path and through the JSON wire
/// path, and the two error values must compare equal.
#[test]
fn cli_shaped_and_json_wire_config_resolution_rejections_match() {
    let cases = vec![
        ParityCase {
            name: "negative ridge lambda",
            cli: CliFitConfigInput {
                ridge_lambda: -1.0,
                ..base_cli()
            },
            json: json!({
                "ridge_lambda": -1.0
            }),
        },
        ParityCase {
            name: "unknown gpu policy",
            cli: CliFitConfigInput {
                gpu: Some("cuda".to_string()),
                ..base_cli()
            },
            json: json!({
                "gpu": "cuda"
            }),
        },
        ParityCase {
            name: "linear baseline rejects shape",
            cli: CliFitConfigInput {
                baseline_shape: Some(1.1),
                ..base_cli()
            },
            json: json!({
                "baseline_shape": 1.1
            }),
        },
        ParityCase {
            name: "weibull likelihood rejects gompertz target",
            cli: CliFitConfigInput {
                survival_likelihood: "weibull".to_string(),
                baseline_target: "gompertz".to_string(),
                ..base_cli()
            },
            json: json!({
                "survival_likelihood": "weibull",
                "baseline_target": "gompertz"
            }),
        },
    ];

    for ParityCase { name, cli, json } in cases {
        let cli_error = resolved_cli(cli).expect_err(name);
        let json_error = resolved_json(json).expect_err(name);
        assert_eq!(cli_error, json_error, "{}", name);
    }
}
1241
/// Empty and whitespace-only input must both parse to an empty vector.
#[test]
fn parse_comma_f64_empty_string_returns_empty_vec() {
    for blank in ["", " "] {
        assert_eq!(parse_comma_f64(blank, "x").unwrap(), Vec::<f64>::new());
    }
}
1249
/// A lone token with no commas must parse to a one-element vector.
#[test]
fn parse_comma_f64_single_value() {
    // Deliberately not `3.14`: that literal looks like a truncated PI and trips
    // clippy's deny-by-default `approx_constant` lint, which fails
    // `cargo clippy --all-targets`. Any other finite decimal takes the same path.
    assert_eq!(parse_comma_f64("42.5", "x").unwrap(), vec![42.5]);
}
1254
/// Comma-separated tokens with spaces after the commas parse in order,
/// including a negative value.
#[test]
fn parse_comma_f64_multiple_values_with_spaces() {
    let expected = vec![1.0, 2.5, -3.0];
    assert_eq!(parse_comma_f64("1.0, 2.5, -3.0", "x").unwrap(), expected);
}
1260
/// A non-numeric token must yield an error that mentions both the label
/// passed by the caller and the offending token.
#[test]
fn parse_comma_f64_non_numeric_returns_error() {
    let outcome = parse_comma_f64("1.0, bad, 3.0", "--vals");
    let err = outcome.unwrap_err();
    let (names_label, names_token) = (err.contains("--vals"), err.contains("bad"));
    assert!(names_label, "error should name the label: {err}");
    assert!(names_token, "error should name the bad token: {err}");
}
1267
/// The token "inf" must be rejected with an error mentioning "non-finite".
#[test]
fn parse_comma_f64_infinity_returns_error() {
    let outcome = parse_comma_f64("inf", "--vals");
    let err = outcome.unwrap_err();
    let flagged = err.contains("non-finite");
    assert!(flagged, "error should say non-finite: {err}");
}
1273
/// The token "nan" must be rejected with an error mentioning "non-finite".
#[test]
fn parse_comma_f64_nan_returns_error() {
    let outcome = parse_comma_f64("nan", "--vals");
    let err = outcome.unwrap_err();
    let flagged = err.contains("non-finite");
    assert!(flagged, "error should say non-finite: {err}");
}
1279
/// An absent family must stay absent after normalization.
#[test]
fn normalize_optional_family_none_passthrough() {
    let normalized = normalize_optional_family(None);
    assert!(normalized.is_none());
}
1286
/// "auto" in lower, title, and upper case must all normalize to `None`.
#[test]
fn normalize_optional_family_auto_becomes_none() {
    for spelling in ["auto", "Auto", "AUTO"] {
        assert_eq!(normalize_optional_family(Some(spelling.to_string())), None);
    }
}
1293
/// Family names other than "auto" must come back unchanged.
#[test]
fn normalize_optional_family_non_auto_passthrough() {
    for family in ["binomial", "gaussian"] {
        assert_eq!(
            normalize_optional_family(Some(family.to_string())),
            Some(family.to_string())
        );
    }
}
1305
/// Accepted likelihood names parse to their lowercase form, whatever the
/// case of the input.
#[test]
fn parse_survival_likelihood_cli_valid_values() {
    // (raw input, expected parsed value)
    let table = [
        ("transformation", "transformation"),
        ("weibull", "weibull"),
        ("WEIBULL", "weibull"),
        ("Transformation", "transformation"),
    ];
    for (raw, expected) in table {
        assert_eq!(parse_survival_likelihood_cli(raw).unwrap(), expected);
    }
}
1328
/// An unrecognized likelihood name and the empty string must both be rejected.
#[test]
fn parse_survival_likelihood_cli_invalid_returns_error() {
    for rejected in ["lognormal", ""] {
        assert!(parse_survival_likelihood_cli(rejected).is_err());
    }
}
1334
/// Each accepted baseline target parses to itself, and a padded,
/// mixed-case spelling parses to the lowercase name.
#[test]
fn parse_baseline_target_cli_valid_values() {
    for target in ["linear", "weibull", "gompertz", "gompertz-makeham"] {
        let parsed = parse_baseline_target_cli(target).unwrap();
        assert_eq!(parsed, target, "should accept '{target}'");
    }

    let padded = parse_baseline_target_cli(" Weibull ").unwrap();
    assert_eq!(padded, "weibull");
}
1352
/// An unknown baseline target must be rejected with an error that echoes
/// the rejected value.
#[test]
fn parse_baseline_target_cli_invalid_returns_error() {
    let outcome = parse_baseline_target_cli("cox");
    let err = outcome.unwrap_err();
    let echoes_value = err.contains("cox");
    assert!(echoes_value, "error should name the bad value: {err}");
}
1358
/// With the "transformation" likelihood, the "linear" target must be
/// rejected when a baseline scale is supplied.
#[test]
fn validate_survival_baseline_args_linear_rejects_params() {
    let mode = parse_survival_likelihood_mode("transformation").unwrap();
    let verdict = validate_survival_baseline_args(mode, "linear", Some(1.0), None, None, None);
    assert!(verdict.is_err());
}
1374
/// With the "transformation" likelihood, the "linear" target must be
/// accepted when no baseline parameters are supplied.
#[test]
fn validate_survival_baseline_args_linear_accepts_no_params() {
    let mode = parse_survival_likelihood_mode("transformation").unwrap();
    let verdict = validate_survival_baseline_args(mode, "linear", None, None, None, None);
    assert!(verdict.is_ok());
}
1383
/// The "weibull" likelihood must reject a "gompertz" baseline target even
/// when no baseline parameters are supplied.
#[test]
fn validate_survival_baseline_args_weibull_likelihood_rejects_gompertz_target() {
    let mode = parse_survival_likelihood_mode("weibull").unwrap();
    let verdict = validate_survival_baseline_args(mode, "gompertz", None, None, None, None);
    assert!(verdict.is_err());
}
1397
/// With the "transformation" likelihood, the "gompertz" target must be
/// rejected when a baseline scale is supplied.
#[test]
fn validate_survival_baseline_args_gompertz_rejects_scale() {
    let mode = parse_survival_likelihood_mode("transformation").unwrap();
    let verdict = validate_survival_baseline_args(mode, "gompertz", Some(2.0), None, None, None);
    assert!(verdict.is_err());
}
1411
/// With the "transformation" likelihood, the "gompertz-makeham" target must
/// be rejected when a baseline scale is supplied.
#[test]
fn validate_survival_baseline_args_gompertz_makeham_rejects_scale() {
    let mode = parse_survival_likelihood_mode("transformation").unwrap();
    let verdict =
        validate_survival_baseline_args(mode, "gompertz-makeham", Some(1.0), None, None, None);
    assert!(verdict.is_err());
}
1425}