1use crate::custom_family::{
21 BlockWorkingSet, BlockwiseFitOptions, CustomFamily, ExactNewtonJointGradientEvaluation,
22 ExactNewtonJointHessianWorkspace, FamilyEvaluation, ParameterBlockSpec, ParameterBlockState,
23 PenaltyMatrix, fit_custom_family, fit_custom_family_fixed_log_lambdas,
24};
25use crate::gamlss::{FamilyMetadata, ParameterLink};
26use crate::sigma_link::{exp_sigma_eta_for_sigma_scalar, exp_sigma_from_eta_scalar};
27use crate::survival::latent::interval::{
28 LatentFrailtyResolution, LatentIntervalModel, LatentIntervalRowView,
29 validate_latent_interval_inputs,
30};
31use crate::survival::location_scale::{
32 TimeBlockInput, project_onto_linear_constraints, structural_time_coefficient_constraints,
33};
34use crate::survival::lognormal_kernel::{
35 FrailtySpec, HazardLoading, LatentSurvivalEventType, LatentSurvivalRow, LatentSurvivalRowJet,
36 log_kernel_bundle,
37};
38use gam_linalg::matrix::{DenseDesignMatrix, DesignMatrix, SymmetricMatrix};
39use crate::model_types::UnifiedFitResult;
40use gam_solve::pirls::LinearInequalityConstraints;
41use crate::probability::signed_log_sum_exp;
42use crate::quadrature::{IntegratedExpectationMode, QuadratureContext};
43use gam_terms::smooth::{
44 TermCollectionDesign, TermCollectionSpec, build_term_collection_design,
45};
46use crate::fit_orchestration::drivers::freeze_term_collection_from_design;
47use gam_problem::MIN_WEIGHT;
48use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s};
49use std::collections::BTreeMap;
50use std::sync::Arc;
51
/// Errors produced while validating inputs for, or fitting, the latent
/// survival and latent binary families in this module.
///
/// Every variant carries a human-readable `reason`.
#[derive(Debug, Clone)]
pub enum LatentSurvivalError {
    /// The frailty specification is unusable here: wrong variant, or a fixed
    /// sigma that is non-finite or negative.
    InvalidFrailty { reason: String },
    /// Input data failed validation (lengths, shapes or values). Plain
    /// `String` errors are also converted into this variant.
    InvalidDataset { reason: String },
    /// The number of parameter blocks, or their dimensions, disagree with
    /// what the family expects.
    BlockMismatch { reason: String },
    /// A quantity became numerically invalid during evaluation, e.g. a
    /// learnable sigma that is not finite and positive.
    NumericalFailure { reason: String },
    /// The requested configuration is not supported, e.g. a time
    /// monotonicity strategy that is not a coordinate cone.
    UnsupportedConfiguration { reason: String },
}

// Generates the shared reason-carrying error boilerplate for every variant.
// NOTE(review): this presumably includes `From<LatentSurvivalError> for String`,
// which the `?` conversions in the `String`-returning fit functions rely on —
// confirm in the macro definition.
impl_reason_error_boilerplate! {
    LatentSurvivalError {
        InvalidFrailty,
        InvalidDataset,
        BlockMismatch,
        NumericalFailure,
        UnsupportedConfiguration,
    }
}
88
89impl From<crate::block_layout::block_count::BlockCountMismatch> for LatentSurvivalError {
90 fn from(
91 err: crate::block_layout::block_count::BlockCountMismatch,
92 ) -> LatentSurvivalError {
93 LatentSurvivalError::BlockMismatch {
94 reason: err.message(),
95 }
96 }
97}
98
99impl From<String> for LatentSurvivalError {
100 fn from(reason: String) -> LatentSurvivalError {
106 LatentSurvivalError::InvalidDataset { reason }
107 }
108}
109
110pub const LATENT_SURVIVAL_EVENT_INTERVAL: u8 = u8::MAX;
116
117#[inline]
118fn latent_survival_event_type_for(code: u8) -> LatentSurvivalEventType {
119 match code {
120 0 => LatentSurvivalEventType::RightCensored,
121 LATENT_SURVIVAL_EVENT_INTERVAL => LatentSurvivalEventType::IntervalCensored,
122 _ => LatentSurvivalEventType::ExactEvent,
123 }
124}
125
/// Inputs for [`fit_latent_survival_terms`]: per-row survival data, the time
/// transform block, optional interval right-endpoint pieces, unloaded
/// components and the mean-term specification.
#[derive(Clone)]
pub struct LatentSurvivalTermSpec {
    /// Per-row entry ages.
    pub age_entry: Array1<f64>,
    /// Per-row exit ages.
    pub age_exit: Array1<f64>,
    /// Per-row event codes.
    ///
    /// `0` is right-censored, [`LATENT_SURVIVAL_EVENT_INTERVAL`] is
    /// interval-censored, and any other value is an exact event.
    pub event_target: Array1<u8>,
    /// Per-row observation weights.
    pub weights: Array1<f64>,
    /// Guard forwarded to `structural_time_coefficient_constraints` when the
    /// time block is prepared.
    pub derivative_guard: f64,
    /// Time-transform block input: entry/exit/derivative designs, offsets,
    /// penalties and optional initial values.
    pub time_block: TimeBlockInput,
    /// Optional right-endpoint time design for interval rows.
    ///
    /// It must match the exit design's shape. When `None`, the exit design is
    /// reused.
    pub time_design_right: Option<DesignMatrix>,
    /// Optional right-endpoint time offset. It must have one entry per row;
    /// `None` means all zeros.
    pub time_offset_right: Option<Array1<f64>>,
    /// Unloaded mass at entry.
    /// NOTE(review): presumably checked by
    /// `validate_unloaded_components_for_loading`, which requires zeros under
    /// full hazard loading — confirm.
    pub unloaded_mass_entry: Array1<f64>,
    /// Unloaded mass at exit (same validation caveat as the entry mass).
    pub unloaded_mass_exit: Array1<f64>,
    /// Unloaded mass at the right endpoint.
    ///
    /// An empty array means all zeros; otherwise it must have one entry per
    /// row.
    pub unloaded_mass_right: Array1<f64>,
    /// Unloaded hazard at exit.
    pub unloaded_hazard_exit: Array1<f64>,
    /// Term-collection specification for the mean block.
    pub meanspec: TermCollectionSpec,
    /// Offset for the mean block's linear predictor.
    pub mean_offset: Array1<f64>,
}
152
/// Output of [`fit_latent_survival_terms`].
pub struct LatentSurvivalTermFitResult {
    /// The blockwise fit (time, mean and, when learned, log-sigma blocks).
    pub fit: UnifiedFitResult,
    /// The mean-block design built from the input `meanspec`.
    pub design: TermCollectionDesign,
    /// The mean spec frozen against the built design.
    pub resolvedspec: TermCollectionSpec,
    /// The latent standard deviation: the fixed value, or the learned one.
    pub latent_sd: f64,
    /// Offset-channel residuals reported by the family at the fitted state.
    pub baseline_offset_residuals: crate::survival::OffsetChannelResiduals,
}
165
/// Inputs for [`fit_latent_binary_terms`].
///
/// Compared with [`LatentSurvivalTermSpec`], this has no right-endpoint
/// design/offset, no right or hazard unloaded components, and a fixed sigma.
#[derive(Clone)]
pub struct LatentBinaryTermSpec {
    /// Per-row entry ages.
    pub age_entry: Array1<f64>,
    /// Per-row exit ages.
    pub age_exit: Array1<f64>,
    /// Per-row event codes.
    pub event_target: Array1<u8>,
    /// Per-row observation weights.
    pub weights: Array1<f64>,
    /// Guard forwarded to the structural time-coefficient constraints.
    pub derivative_guard: f64,
    /// Time-transform block input.
    pub time_block: TimeBlockInput,
    /// Unloaded mass at entry.
    pub unloaded_mass_entry: Array1<f64>,
    /// Unloaded mass at exit.
    pub unloaded_mass_exit: Array1<f64>,
    /// Term-collection specification for the mean block.
    pub meanspec: TermCollectionSpec,
    /// Offset for the mean block's linear predictor.
    pub mean_offset: Array1<f64>,
}
179
/// Output of [`fit_latent_binary_terms`].
pub struct LatentBinaryTermFitResult {
    /// The two-block (time, mean) fit.
    pub fit: UnifiedFitResult,
    /// The mean-block design built from the input `meanspec`.
    pub design: TermCollectionDesign,
    /// The mean spec frozen against the built design.
    pub resolvedspec: TermCollectionSpec,
    /// Offset-channel residuals reported by the family at the fitted state.
    pub baseline_offset_residuals: crate::survival::OffsetChannelResiduals,
}
189
/// Dense, validated form of a [`TimeBlockInput`], produced by
/// `prepare_latent_time_block`.
#[derive(Clone)]
struct PreparedLatentTimeBlock {
    /// Dense entry-time design.
    design_entry: Array2<f64>,
    /// Dense exit-time design.
    design_exit: Array2<f64>,
    /// Dense design of the time derivative at exit.
    design_derivative_exit: Array2<f64>,
    /// Dense right-endpoint design. It has the exit design's shape and is a
    /// copy of it when no right design was supplied.
    design_right: Array2<f64>,
    /// Structural inequality constraints on the time coefficients, if any.
    linear_constraints: Option<LinearInequalityConstraints>,
    /// Dense penalty matrices for the time block.
    penalties: Vec<Array2<f64>>,
    /// Initial time coefficients. When constraints exist, this is projected
    /// onto them.
    initial_beta: Option<Array1<f64>>,
}
204
/// Custom family for latent survival models with hazard-multiplier frailty.
///
/// Blocks: time transform (rows stacked entry/exit/derivative), mean, and an
/// optional learnable `log_sigma` block used when `latent_sd_fixed` is `None`.
#[derive(Clone)]
pub struct LatentSurvivalFamily {
    /// Per-row event codes (see `latent_survival_event_type_for`).
    pub event_target: Array1<u8>,
    /// Per-row observation weights.
    pub weights: Array1<f64>,
    /// Fixed latent sigma, or `None` when sigma is learned from the
    /// `log_sigma` block.
    pub latent_sd_fixed: Option<f64>,
    /// Hazard loading mode (`Full` or `LoadedVsUnloaded`).
    pub hazard_loading: HazardLoading,
    /// Unloaded mass at entry.
    pub unloaded_mass_entry: Array1<f64>,
    /// Unloaded mass at exit.
    pub unloaded_mass_exit: Array1<f64>,
    /// Unloaded hazard at exit.
    pub unloaded_hazard_exit: Array1<f64>,
    /// Dense entry-time design.
    pub x_time_entry: Array2<f64>,
    /// Dense exit-time design.
    pub x_time_exit: Array2<f64>,
    /// Dense design of the time derivative at exit.
    pub x_time_derivative_exit: Array2<f64>,
    /// Dense right-endpoint time design (the exit design when none was
    /// supplied).
    pub x_time_right: Array2<f64>,
    /// Right-endpoint time offset (zeros when none was supplied).
    pub time_offset_right: Array1<f64>,
    /// Right-endpoint unloaded mass (zeros when none was supplied).
    pub unloaded_mass_right: Array1<f64>,
    /// Mean-block design matrix.
    pub x_mean: DesignMatrix,
    /// Structural linear constraints on the time coefficients, if any.
    pub time_linear_constraints: Option<LinearInequalityConstraints>,
    /// Quadrature context shared between clones of the family.
    pub quadctx: Arc<QuadratureContext>,
}
232
/// Custom family for the latent binary model.
///
/// It has two blocks (time transform and mean) and a fixed latent sigma.
#[derive(Clone)]
pub struct LatentBinaryFamily {
    /// Per-row event codes.
    pub event_target: Array1<u8>,
    /// Per-row observation weights.
    pub weights: Array1<f64>,
    /// Fixed latent standard deviation.
    pub latent_sd: f64,
    /// Hazard loading mode (`Full` or `LoadedVsUnloaded`).
    pub hazard_loading: HazardLoading,
    /// Unloaded mass at entry.
    pub unloaded_mass_entry: Array1<f64>,
    /// Unloaded mass at exit.
    pub unloaded_mass_exit: Array1<f64>,
    /// Dense entry-time design.
    pub x_time_entry: Array2<f64>,
    /// Dense exit-time design.
    pub x_time_exit: Array2<f64>,
    /// Mean-block design matrix.
    pub x_mean: DesignMatrix,
    /// Structural linear constraints on the time coefficients, if any.
    pub time_linear_constraints: Option<LinearInequalityConstraints>,
    /// Quadrature context shared between clones of the family.
    pub quadctx: Arc<QuadratureContext>,
}
247
248impl LatentSurvivalFamily {
249 pub const BLOCK_TIME: usize = 0;
250 pub const BLOCK_MEAN: usize = 1;
251 pub const BLOCK_LOG_SIGMA: usize = 2;
252
253 pub fn parameter_names() -> &'static [&'static str] {
254 &["time_transform", "mean"]
255 }
256
257 pub fn parameter_links() -> &'static [ParameterLink] {
258 &[ParameterLink::Identity, ParameterLink::Identity]
259 }
260
261 pub fn metadata() -> FamilyMetadata {
262 FamilyMetadata {
263 name: "latent_survival",
264 parameternames: Self::parameter_names(),
265 parameter_links: Self::parameter_links(),
266 }
267 }
268
269 fn split_time_eta<'a>(
270 &self,
271 block_states: &'a [ParameterBlockState],
272 ) -> Result<
273 (
274 ArrayView1<'a, f64>,
275 ArrayView1<'a, f64>,
276 ArrayView1<'a, f64>,
277 &'a Array1<f64>,
278 ),
279 LatentSurvivalError,
280 > {
281 let expected_blocks = if self.latent_sd_fixed.is_some() { 2 } else { 3 };
282 crate::block_layout::block_count::validate_block_count::<LatentSurvivalError>(
283 "LatentSurvivalFamily",
284 expected_blocks,
285 block_states.len(),
286 )?;
287 let n = self.event_target.len();
288 let eta_time = &block_states[Self::BLOCK_TIME].eta;
289 let eta_mean = &block_states[Self::BLOCK_MEAN].eta;
290 if eta_time.len() != 3 * n {
291 return Err(LatentSurvivalError::BlockMismatch {
292 reason: format!(
293 "latent survival time eta length mismatch: got {}, expected {}",
294 eta_time.len(),
295 3 * n
296 ),
297 });
298 }
299 if eta_mean.len() != n || self.weights.len() != n {
300 return Err(LatentSurvivalError::BlockMismatch {
301 reason: "latent survival mean eta dimension mismatch".to_string(),
302 });
303 }
304 Ok((
305 eta_time.slice(s![0..n]),
306 eta_time.slice(s![n..2 * n]),
307 eta_time.slice(s![2 * n..3 * n]),
308 eta_mean,
309 ))
310 }
311
312 fn time_q_right(
319 &self,
320 block_states: &[ParameterBlockState],
321 ) -> Result<Array1<f64>, LatentSurvivalError> {
322 let n = self.event_target.len();
323 let beta_time = &block_states[Self::BLOCK_TIME].beta;
324 if self.x_time_right.ncols() != beta_time.len() {
325 return Err(LatentSurvivalError::BlockMismatch {
326 reason: format!(
327 "latent survival interval right design has {} columns but time beta has {}",
328 self.x_time_right.ncols(),
329 beta_time.len()
330 ),
331 });
332 }
333 if self.x_time_right.nrows() != n || self.time_offset_right.len() != n {
334 return Err(LatentSurvivalError::BlockMismatch {
335 reason: "latent survival interval right design/offset row count mismatch"
336 .to_string(),
337 });
338 }
339 let mut q_right = self.x_time_right.dot(beta_time);
340 q_right += &self.time_offset_right;
341 Ok(q_right)
342 }
343
344 fn latent_sd(&self, block_states: &[ParameterBlockState]) -> Result<f64, LatentSurvivalError> {
345 if let Some(sigma) = self.latent_sd_fixed {
346 return Ok(sigma);
347 }
348 let eta = *block_states
349 .get(Self::BLOCK_LOG_SIGMA)
350 .and_then(|state| state.eta.get(0))
351 .ok_or_else(|| LatentSurvivalError::BlockMismatch {
352 reason: "latent survival learnable log_sigma block is missing".to_string(),
353 })?;
354 let sigma = exp_sigma_from_eta_scalar(eta);
355 if !(sigma.is_finite() && sigma > 0.0) {
356 return Err(LatentSurvivalError::NumericalFailure {
357 reason: format!(
358 "latent survival learnable sigma became invalid: log_sigma={eta}, sigma={sigma}"
359 ),
360 });
361 }
362 Ok(sigma)
363 }
364}
365
366impl LatentBinaryFamily {
367 pub const BLOCK_TIME: usize = 0;
368 pub const BLOCK_MEAN: usize = 1;
369
370 fn split_time_eta<'a>(
371 &self,
372 block_states: &'a [ParameterBlockState],
373 ) -> Result<(ArrayView1<'a, f64>, ArrayView1<'a, f64>, &'a Array1<f64>), LatentSurvivalError>
374 {
375 crate::block_layout::block_count::validate_block_count::<LatentSurvivalError>(
376 "LatentBinaryFamily",
377 2,
378 block_states.len(),
379 )?;
380 let n = self.event_target.len();
381 let eta_time = &block_states[Self::BLOCK_TIME].eta;
382 let eta_mean = &block_states[Self::BLOCK_MEAN].eta;
383 if eta_time.len() != 3 * n {
384 return Err(LatentSurvivalError::BlockMismatch {
385 reason: format!(
386 "latent binary time eta length mismatch: got {}, expected {}",
387 eta_time.len(),
388 3 * n
389 ),
390 });
391 }
392 if eta_mean.len() != n || self.weights.len() != n {
393 return Err(LatentSurvivalError::BlockMismatch {
394 reason: "latent binary mean eta dimension mismatch".to_string(),
395 });
396 }
397 Ok((
398 eta_time.slice(s![0..n]),
399 eta_time.slice(s![n..2 * n]),
400 eta_mean,
401 ))
402 }
403}
404
405pub fn fixed_latent_hazard_frailty(
406 frailty: &FrailtySpec,
407 context: &str,
408) -> Result<(f64, HazardLoading), String> {
409 fixed_latent_hazard_frailty_typed(frailty, context).map_err(Into::into)
410}
411
412fn fixed_latent_hazard_frailty_typed(
413 frailty: &FrailtySpec,
414 context: &str,
415) -> Result<(f64, HazardLoading), LatentSurvivalError> {
416 match frailty {
417 FrailtySpec::HazardMultiplier {
418 sigma_fixed: Some(sigma),
419 loading,
420 } if sigma.is_finite() && *sigma >= 0.0 => Ok((*sigma, *loading)),
421 FrailtySpec::HazardMultiplier {
422 sigma_fixed: Some(sigma),
423 ..
424 } => Err(LatentSurvivalError::InvalidFrailty {
425 reason: format!(
426 "{context} requires a finite fixed hazard-multiplier sigma >= 0, got {sigma}"
427 ),
428 }),
429 FrailtySpec::HazardMultiplier {
430 sigma_fixed: None, ..
431 } => Err(LatentSurvivalError::InvalidFrailty {
432 reason: format!("{context} currently requires a fixed hazard-multiplier sigma"),
433 }),
434 FrailtySpec::GaussianShift { .. } => Err(LatentSurvivalError::InvalidFrailty {
435 reason: format!("{context} requires HazardMultiplier frailty, not GaussianShift"),
436 }),
437 FrailtySpec::None => Err(LatentSurvivalError::InvalidFrailty {
438 reason: format!("{context} requires a fixed HazardMultiplier frailty specification"),
439 }),
440 }
441}
442
443pub fn latent_hazard_loading(
444 frailty: &FrailtySpec,
445 context: &str,
446) -> Result<HazardLoading, String> {
447 latent_hazard_loading_typed(frailty, context).map_err(Into::into)
448}
449
450fn latent_hazard_loading_typed(
451 frailty: &FrailtySpec,
452 context: &str,
453) -> Result<HazardLoading, LatentSurvivalError> {
454 match frailty {
455 FrailtySpec::HazardMultiplier { loading, .. } => Ok(*loading),
456 FrailtySpec::GaussianShift { .. } => Err(LatentSurvivalError::InvalidFrailty {
457 reason: format!("{context} requires HazardMultiplier frailty, not GaussianShift"),
458 }),
459 FrailtySpec::None => Err(LatentSurvivalError::InvalidFrailty {
460 reason: format!("{context} requires a HazardMultiplier frailty specification"),
461 }),
462 }
463}
464
/// Per-row first-derivative and negated second-derivative terms with respect
/// to the entry and exit time predictors.
///
/// NOTE(review): presumably derivatives of a row log-likelihood; the code
/// that fills this struct is outside this view — confirm.
#[derive(Clone, Copy)]
struct LatentSurvivalTimeJet {
    /// First derivative with respect to the entry-time predictor.
    grad_entry: f64,
    /// First derivative with respect to the exit-time predictor.
    grad_exit: f64,
    /// Negated second derivative with respect to the entry-time predictor.
    neg_hess_entry: f64,
    /// Negated second derivative with respect to the exit-time predictor.
    neg_hess_exit: f64,
}
472
/// Fits the latent-survival family.
///
/// The fit uses a stacked time-transform block and a mean block built from
/// `spec.meanspec`. When the frailty leaves sigma unfixed, a learnable
/// `log_sigma` block is added as well.
///
/// When any row is interval-censored ([`LATENT_SURVIVAL_EVENT_INTERVAL`]), a
/// warm start runs first via `fit_custom_family_fixed_log_lambdas`:
/// 1. Fit a surrogate in which interval rows are right-censored (code `0`).
/// 2. Only if that fails and the surrogate contains no finite events at all,
///    fit a surrogate in which interval rows are exact events (code `1`).
///
/// Finite warm-start coefficients then seed the final `fit_custom_family`.
///
/// # Errors
/// Returns a message string if input validation, frailty resolution, design
/// construction, right-endpoint length checks, the warm start, or the final
/// fit fails.
pub fn fit_latent_survival_terms(
    data: ArrayView2<'_, f64>,
    spec: LatentSurvivalTermSpec,
    frailty: FrailtySpec,
    options: &BlockwiseFitOptions,
) -> Result<LatentSurvivalTermFitResult, String> {
    // `Some(sigma)` when the frailty fixes sigma; `None` means sigma is
    // learned through the extra block added below.
    let latent_sd = validate_latent_survival_inputs(data, &spec, &frailty)?;
    let hazard_loading = latent_hazard_loading(&frailty, "latent-survival")?;
    let mean_design =
        build_term_collection_design(data, &spec.meanspec).map_err(|e| e.to_string())?;
    let resolvedspec = freeze_term_collection_from_design(&spec.meanspec, &mean_design)
        .map_err(|e| e.to_string())?;
    let time_prepared = prepare_latent_time_block(
        &spec.time_block,
        spec.time_design_right.as_ref(),
        spec.derivative_guard,
    )?;

    let n = spec.event_target.len();
    // A missing right-endpoint offset defaults to zeros; a supplied one must
    // have one entry per row.
    let time_offset_right = match spec.time_offset_right.as_ref() {
        Some(offset) => {
            if offset.len() != n {
                return Err(format!(
                    "latent survival interval right time offset must have length {n}, got {}",
                    offset.len()
                ));
            }
            offset.clone()
        }
        None => Array1::zeros(n),
    };
    // An empty right-endpoint unloaded mass is shorthand for all zeros.
    let unloaded_mass_right = if spec.unloaded_mass_right.is_empty() {
        Array1::zeros(n)
    } else {
        if spec.unloaded_mass_right.len() != n {
            return Err(format!(
                "latent survival interval right unloaded mass must have length {n}, got {}",
                spec.unloaded_mass_right.len()
            ));
        }
        spec.unloaded_mass_right.clone()
    };

    let family = LatentSurvivalFamily {
        event_target: spec.event_target.clone(),
        weights: spec.weights.clone(),
        latent_sd_fixed: latent_sd,
        hazard_loading,
        unloaded_mass_entry: spec.unloaded_mass_entry.clone(),
        unloaded_mass_exit: spec.unloaded_mass_exit.clone(),
        unloaded_hazard_exit: spec.unloaded_hazard_exit.clone(),
        x_time_entry: time_prepared.design_entry.clone(),
        x_time_exit: time_prepared.design_exit.clone(),
        x_time_derivative_exit: time_prepared.design_derivative_exit.clone(),
        x_time_right: time_prepared.design_right.clone(),
        time_offset_right,
        unloaded_mass_right,
        x_mean: mean_design.design.clone(),
        time_linear_constraints: time_prepared.linear_constraints.clone(),
        quadctx: Arc::new(QuadratureContext::new()),
    };

    // Block order must match the family's BLOCK_* indices: time, mean, then
    // log_sigma (only when sigma is learned).
    let mut blocks = vec![
        build_time_blockspec(&time_prepared, &spec.time_block),
        build_mean_blockspec(&mean_design, spec.mean_offset.clone()),
    ];
    if latent_sd.is_none() {
        blocks.push(build_log_sigma_blockspec(
            LEARNABLE_LATENT_SD_SEED,
            mean_design.design.nrows(),
        ));
    }
    let has_interval_rows = spec
        .event_target
        .iter()
        .any(|&code| code == LATENT_SURVIVAL_EVENT_INTERVAL);
    if has_interval_rows {
        // Warm-start stage 1: recode every interval row as right-censored
        // (code 0). The error messages call this the "right-censored-at-L"
        // surrogate.
        let censored_warm_event_target = spec.event_target.mapv(|code| {
            if code == LATENT_SURVIVAL_EVENT_INTERVAL {
                0u8
            } else {
                code
            }
        });
        let mut warm_family = family.clone();
        warm_family.event_target = censored_warm_event_target;
        let warm_fit_result = fit_custom_family_fixed_log_lambdas(
            &warm_family,
            &blocks,
            options,
            None,
            0,
            None,
            false,
        );
        let warm_fit = match warm_fit_result {
            Ok(fit) => fit,
            Err(censored_error) => {
                // If the surrogate still has finite events, its failure is
                // reported directly instead of being papered over by stage 2.
                let has_finite_event_in_censored_surrogate =
                    warm_family.event_target.iter().any(|&code| code != 0);
                if has_finite_event_in_censored_surrogate {
                    return Err(format!(
                        "latent interval warm start: right-censored-at-L surrogate fit failed \
                         (so the interval fit cannot be safely warm-started; this surrogate is \
                         log-concave and should converge — investigate the surrogate, not the \
                         interval kernel): {censored_error}"
                    ));
                }

                // Warm-start stage 2: every row was censored, so recode
                // interval rows as exact events (code 1) at the lower
                // endpoint and try again.
                let lower_event_warm_target = spec.event_target.mapv(|code| {
                    if code == LATENT_SURVIVAL_EVENT_INTERVAL {
                        1u8
                    } else {
                        code
                    }
                });
                let mut event_warm_family = family.clone();
                event_warm_family.event_target = lower_event_warm_target;
                fit_custom_family_fixed_log_lambdas(
                    &event_warm_family,
                    &blocks,
                    options,
                    None,
                    0,
                    None,
                    false,
                )
                .map_err(|event_error| {
                    format!(
                        "latent interval warm start failed: the right-censored-at-L surrogate \
                         has no finite failures and refused its boundary optimum ({censored_error}); \
                         the finite lower-endpoint event surrogate also failed ({event_error})"
                    )
                })?
            }
        };
        // Reject the warm start unless at least one block has coefficients
        // that are all finite and not all zero.
        let warm_beta_usable = warm_fit
            .block_states
            .iter()
            .any(|s| s.beta.iter().all(|v| v.is_finite()) && s.beta.iter().any(|&v| v != 0.0));
        if !warm_beta_usable {
            return Err(
                "latent interval warm start: right-censored-at-L surrogate returned a \
                 degenerate (non-finite or all-zero) β across every block; the warm start \
                 cannot seed the interval fit. This indicates the surrogate's time-block \
                 design is rank-deficient or the inner solve stalled at the seed — \
                 investigate the surrogate before retrying the interval fit."
                    .to_string(),
            );
        }
        // Seed each block individually, skipping any block whose warm
        // coefficients contain a non-finite value.
        for (block, state) in blocks.iter_mut().zip(warm_fit.block_states.iter()) {
            if state.beta.iter().all(|v| v.is_finite()) {
                block.initial_beta = Some(state.beta.clone());
            }
        }
    }
    let fit = fit_custom_family(&family, &blocks, options).map_err(|e| e.to_string())?;
    // Report the fixed sigma, or the sigma recovered from the fitted
    // log_sigma block.
    let latent_sd = family.latent_sd(&fit.block_states)?;
    let baseline_offset_residuals = family.offset_channel_residuals(&fit.block_states)?;
    Ok(LatentSurvivalTermFitResult {
        fit,
        design: mean_design,
        resolvedspec,
        latent_sd,
        baseline_offset_residuals,
    })
}
675
676pub fn fit_latent_binary_terms(
677 data: ArrayView2<'_, f64>,
678 spec: LatentBinaryTermSpec,
679 frailty: FrailtySpec,
680 options: &BlockwiseFitOptions,
681) -> Result<LatentBinaryTermFitResult, String> {
682 let latent_sd = validate_latent_binary_inputs(data, &spec, &frailty)?;
683 let (_, hazard_loading) = fixed_latent_hazard_frailty(&frailty, "latent-binary")?;
684 let mean_design =
685 build_term_collection_design(data, &spec.meanspec).map_err(|e| e.to_string())?;
686 let resolvedspec = freeze_term_collection_from_design(&spec.meanspec, &mean_design)
687 .map_err(|e| e.to_string())?;
688 let time_prepared = prepare_latent_time_block(&spec.time_block, None, spec.derivative_guard)?;
689
690 let family = LatentBinaryFamily {
691 event_target: spec.event_target.clone(),
692 weights: spec.weights.clone(),
693 latent_sd,
694 hazard_loading,
695 unloaded_mass_entry: spec.unloaded_mass_entry.clone(),
696 unloaded_mass_exit: spec.unloaded_mass_exit.clone(),
697 x_time_entry: time_prepared.design_entry.clone(),
698 x_time_exit: time_prepared.design_exit.clone(),
699 x_mean: mean_design.design.clone(),
700 time_linear_constraints: time_prepared.linear_constraints.clone(),
701 quadctx: Arc::new(QuadratureContext::new()),
702 };
703
704 let blocks = vec![
705 build_time_blockspec(&time_prepared, &spec.time_block),
706 build_mean_blockspec(&mean_design, spec.mean_offset.clone()),
707 ];
708 let fit = fit_custom_family(&family, &blocks, options).map_err(|e| e.to_string())?;
709 let baseline_offset_residuals = family.offset_channel_residuals(&fit.block_states)?;
710 Ok(LatentBinaryTermFitResult {
711 fit,
712 design: mean_design,
713 resolvedspec,
714 baseline_offset_residuals,
715 })
716}
717
718struct LatentSurvivalModel;
724
725impl LatentIntervalModel for LatentSurvivalModel {
726 fn context() -> &'static str {
727 "latent-survival"
728 }
729
730 fn allows_interval() -> bool {
731 true
732 }
733
734 fn frailty_policy(
735 frailty: &FrailtySpec,
736 ) -> Result<LatentFrailtyResolution, LatentSurvivalError> {
737 match frailty {
738 FrailtySpec::HazardMultiplier {
739 sigma_fixed,
740 loading,
741 } => {
742 if let Some(sigma) = sigma_fixed
743 && (!sigma.is_finite() || *sigma < 0.0)
744 {
745 return Err(LatentSurvivalError::InvalidFrailty {
746 reason: format!(
747 "latent-survival requires a finite hazard-multiplier sigma >= 0, got {sigma}"
748 ),
749 });
750 }
751 Ok(LatentFrailtyResolution {
752 sigma: *sigma_fixed,
753 loading: *loading,
754 })
755 }
756 FrailtySpec::GaussianShift { .. } => Err(LatentSurvivalError::InvalidFrailty {
757 reason: "latent-survival requires HazardMultiplier frailty, not GaussianShift"
758 .to_string(),
759 }),
760 FrailtySpec::None => Err(LatentSurvivalError::InvalidFrailty {
761 reason: "latent-survival requires a HazardMultiplier frailty specification"
762 .to_string(),
763 }),
764 }
765 }
766}
767
768fn validate_latent_survival_inputs(
769 data: ArrayView2<'_, f64>,
770 spec: &LatentSurvivalTermSpec,
771 frailty: &FrailtySpec,
772) -> Result<Option<f64>, LatentSurvivalError> {
773 let row = LatentIntervalRowView {
774 frailty,
775 age_entry: &spec.age_entry,
776 age_exit: &spec.age_exit,
777 event_target: &spec.event_target,
778 weights: &spec.weights,
779 unloaded_mass_entry: &spec.unloaded_mass_entry,
780 unloaded_mass_exit: &spec.unloaded_mass_exit,
781 unloaded_hazard_exit: Some(&spec.unloaded_hazard_exit),
782 mean_offset: &spec.mean_offset,
783 derivative_guard: spec.derivative_guard,
784 time_block: &spec.time_block,
785 };
786 validate_latent_interval_inputs::<LatentSurvivalModel>(data, &row)
787}
788
789pub(crate) fn validate_unloaded_components_for_loading(
790 context: &str,
791 row_index: usize,
792 loading: HazardLoading,
793 unloaded_entry: f64,
794 unloaded_exit: f64,
795 unloaded_hazard: Option<f64>,
796) -> Result<(), LatentSurvivalError> {
797 match loading {
798 HazardLoading::Full => {
799 if unloaded_entry != 0.0
800 || unloaded_exit != 0.0
801 || unloaded_hazard.is_some_and(|hazard| hazard != 0.0)
802 {
803 return Err(LatentSurvivalError::InvalidDataset {
804 reason: format!(
805 "{context} row {} uses full hazard loading, so unloaded components must be exactly zero; got entry_mass={}, exit_mass={}, exit_hazard={}",
806 row_index + 1,
807 unloaded_entry,
808 unloaded_exit,
809 unloaded_hazard.unwrap_or(0.0)
810 ),
811 });
812 }
813 }
814 HazardLoading::LoadedVsUnloaded => {}
815 }
816 Ok(())
817}
818
819struct LatentBinaryModel;
826
827impl LatentIntervalModel for LatentBinaryModel {
828 fn context() -> &'static str {
829 "latent-binary"
830 }
831
832 fn frailty_policy(
833 frailty: &FrailtySpec,
834 ) -> Result<LatentFrailtyResolution, LatentSurvivalError> {
835 let (sigma, loading) = fixed_latent_hazard_frailty_typed(frailty, "latent-binary")?;
836 Ok(LatentFrailtyResolution {
837 sigma: Some(sigma),
838 loading,
839 })
840 }
841}
842
843fn validate_latent_binary_inputs(
844 data: ArrayView2<'_, f64>,
845 spec: &LatentBinaryTermSpec,
846 frailty: &FrailtySpec,
847) -> Result<f64, LatentSurvivalError> {
848 let row = LatentIntervalRowView {
849 frailty,
850 age_entry: &spec.age_entry,
851 age_exit: &spec.age_exit,
852 event_target: &spec.event_target,
853 weights: &spec.weights,
854 unloaded_mass_entry: &spec.unloaded_mass_entry,
855 unloaded_mass_exit: &spec.unloaded_mass_exit,
856 unloaded_hazard_exit: None,
857 mean_offset: &spec.mean_offset,
858 derivative_guard: spec.derivative_guard,
859 time_block: &spec.time_block,
860 };
861 validate_latent_interval_inputs::<LatentBinaryModel>(data, &row)?.ok_or_else(|| {
866 LatentSurvivalError::InvalidFrailty {
867 reason: "latent-binary requires a fixed latent sigma".to_string(),
868 }
869 })
870}
871
872fn prepare_latent_time_block(
873 input: &TimeBlockInput,
874 design_right: Option<&DesignMatrix>,
875 derivative_guard: f64,
876) -> Result<PreparedLatentTimeBlock, LatentSurvivalError> {
877 if !input.time_monotonicity.is_coordinate_cone() {
878 return Err(LatentSurvivalError::UnsupportedConfiguration {
879 reason: format!(
880 "latent survival requires a coordinate-cone monotonicity strategy; got {:?}",
881 input.time_monotonicity
882 ),
883 });
884 }
885 let design_entry = input
886 .design_entry
887 .try_to_dense_by_chunks("latent survival entry time design")?;
888 let design_exit = input
889 .design_exit
890 .try_to_dense_by_chunks("latent survival exit time design")?;
891 let design_derivative_exit = input
892 .design_derivative_exit
893 .try_to_dense_by_chunks("latent survival derivative time design")?;
894 let design_right = match design_right {
900 Some(matrix) => {
901 let dense =
902 matrix.try_to_dense_by_chunks("latent survival interval right time design")?;
903 if dense.nrows() != design_exit.nrows() || dense.ncols() != design_exit.ncols() {
904 return Err(LatentSurvivalError::InvalidDataset {
905 reason: format!(
906 "latent survival interval right time design must match exit design shape \
907 {:?}, got {:?}",
908 design_exit.dim(),
909 dense.dim()
910 ),
911 });
912 }
913 dense
914 }
915 None => design_exit.clone(),
916 };
917 let linear_constraints = structural_time_coefficient_constraints(
918 &input.design_derivative_exit,
919 &input.derivative_offset_exit,
920 derivative_guard,
921 )?;
922 let initial_beta = match linear_constraints.as_ref() {
923 Some(constraints) => Some(project_onto_linear_constraints(
928 design_exit.ncols(),
929 constraints,
930 input.initial_beta.as_ref(),
931 )?),
932 None => None,
933 };
934 Ok(PreparedLatentTimeBlock {
935 design_entry,
936 design_exit,
937 design_derivative_exit,
938 design_right,
939 linear_constraints,
940 penalties: input.penalties.clone(),
941 initial_beta,
942 })
943}
944
945fn stack_rows(blocks: &[&Array2<f64>]) -> Array2<f64> {
946 let ncols = blocks.first().map_or(0, |m| m.ncols());
947 let nrows = blocks.iter().map(|m| m.nrows()).sum();
948 let mut out = Array2::<f64>::zeros((nrows, ncols));
949 let mut row = 0usize;
950 for block in blocks {
951 let end = row + block.nrows();
952 out.slice_mut(s![row..end, ..]).assign(block);
953 row = end;
954 }
955 out
956}
957
958fn build_time_blockspec(
959 prepared: &PreparedLatentTimeBlock,
960 input: &TimeBlockInput,
961) -> ParameterBlockSpec {
962 let stacked_design = stack_rows(&[
974 &prepared.design_entry,
975 &prepared.design_exit,
976 &prepared.design_derivative_exit,
977 ]);
978 let stacked_offset = gam_linalg::utils::stack_offsets(&[
979 &input.offset_entry,
980 &input.offset_exit,
981 &input.derivative_offset_exit,
982 ]);
983 ParameterBlockSpec {
984 name: "time_transform".to_string(),
985 design: DesignMatrix::Dense(DenseDesignMatrix::from(Arc::new(
986 prepared.design_exit.clone(),
987 ))),
988 offset: input.offset_exit.clone(),
989 penalties: prepared
990 .penalties
991 .iter()
992 .cloned()
993 .map(PenaltyMatrix::Dense)
994 .collect(),
995 nullspace_dims: input.nullspace_dims.clone(),
996 initial_log_lambdas: input
997 .initial_log_lambdas
998 .clone()
999 .unwrap_or_else(|| Array1::zeros(prepared.penalties.len())),
1000 initial_beta: prepared.initial_beta.clone(),
1001 gauge_priority: 200,
1007 jacobian_callback: None,
1008 stacked_design: Some(DesignMatrix::Dense(DenseDesignMatrix::from(Arc::new(
1009 stacked_design,
1010 )))),
1011 stacked_offset: Some(stacked_offset),
1012 }
1013}
1014
1015fn build_mean_blockspec(design: &TermCollectionDesign, offset: Array1<f64>) -> ParameterBlockSpec {
1016 ParameterBlockSpec {
1017 name: "mean".to_string(),
1018 design: design.design.clone(),
1019 offset,
1020 penalties: design.penalties_as_penalty_matrix(),
1021 nullspace_dims: design.nullspace_dims.clone(),
1022 initial_log_lambdas: Array1::zeros(design.penalties.len()),
1023 initial_beta: None,
1024 gauge_priority: 150,
1030 jacobian_callback: None,
1031 stacked_design: None,
1032 stacked_offset: None,
1033 }
1034}
1035
1036const LEARNABLE_LATENT_SD_SEED: f64 = 0.5;
1043
1044fn build_log_sigma_blockspec(initial_sigma: f64, n_obs: usize) -> ParameterBlockSpec {
1045 ParameterBlockSpec {
1046 name: "log_sigma".to_string(),
1047 design: DesignMatrix::Dense(DenseDesignMatrix::from(Arc::new(Array2::from_elem(
1057 (n_obs, 1),
1058 1.0,
1059 )))),
1060 offset: Array1::zeros(n_obs),
1061 penalties: vec![],
1062 nullspace_dims: vec![],
1063 initial_log_lambdas: Array1::zeros(0),
1064 initial_beta: Some(Array1::from_elem(
1065 1,
1066 exp_sigma_eta_for_sigma_scalar(initial_sigma),
1067 )),
1068 gauge_priority: 120,
1071 jacobian_callback: None,
1072 stacked_design: None,
1073 stacked_offset: None,
1074 }
1075}
1076
// Slot indices into a per-row vector of primary variables, plus its length.
// The slots mirror the fields of `LatentSurvivalPrimaryDirection`: q at entry,
// q at exit, q-dot at exit (presumably the time derivative), q at the
// interval right endpoint, mu and log sigma.
// NOTE(review): the code that consumes these indices is outside this view —
// confirm the layout there.
const LATENT_SURVIVAL_PRIMARY_Q_ENTRY: usize = 0;
const LATENT_SURVIVAL_PRIMARY_Q_EXIT: usize = 1;
const LATENT_SURVIVAL_PRIMARY_QDOT_EXIT: usize = 2;
const LATENT_SURVIVAL_PRIMARY_Q_RIGHT: usize = 3;
const LATENT_SURVIVAL_PRIMARY_MU: usize = 4;
const LATENT_SURVIVAL_PRIMARY_LOG_SIGMA: usize = 5;
const LATENT_SURVIVAL_PRIMARY_DIM: usize = 6;
1090
1091use gam_math::jet_partitions::MultiDirJet as LatentMultiDirJet;
1092
/// Returns `ln(x)` followed by its first four derivatives:
/// `[ln x, 1/x, -1/x², 2/x³, -6/x⁴]`.
///
/// The powers are built by repeated multiplication (`x·x`, then `·x`, then
/// `·x`), so every entry rounds identically to that chain.
#[inline]
fn latent_unary_derivatives_log(x: f64) -> [f64; 5] {
    let square = x * x;
    let cube = square * x;
    let quartic = cube * x;
    let value = x.ln();
    let first = 1.0 / x;
    let second = -1.0 / square;
    let third = 2.0 / cube;
    let fourth = -6.0 / quartic;
    [value, first, second, third, fourth]
}
1115
/// One monomial-like term in a symbolic sum manipulated by
/// `latent_kernel_differentiate_terms`.
///
/// Terms are keyed by `(q_exp, qdot_power, tau_exp, k)` and carry a scalar
/// coefficient.
/// NOTE(review): only `qdot_power` demonstrably acts as a plain power (it is
/// lowered by the power rule). The meanings of `q_exp`, `tau_exp` and the
/// kernel order `k` follow the differentiation rules there — confirm against
/// the kernel definition.
#[derive(Clone, Copy, Debug)]
struct LatentKernelPrimaryTerm {
    /// Scalar multiplier of the term.
    coeff: f64,
    /// Exponent tracked for the q-dependent factor.
    q_exp: usize,
    /// Power of q-dot.
    qdot_power: usize,
    /// Exponent tracked for the tau-dependent factor.
    tau_exp: usize,
    /// Kernel index.
    k: usize,
}
1124
/// A direction in the kernel's primary variables (q, q-dot, mu, tau) along
/// which `latent_kernel_differentiate_terms` differentiates. Zero components
/// are skipped.
#[derive(Clone, Copy, Debug)]
struct LatentKernelPrimaryDirection {
    /// Component along q.
    dq: f64,
    /// Component along q-dot.
    dqd: f64,
    /// Component along mu.
    dmu: f64,
    /// Component along tau.
    dtau: f64,
}
1132
/// A direction in the per-row latent-survival primary variables.
///
/// The fields correspond to the `LATENT_SURVIVAL_PRIMARY_*` slots.
#[derive(Clone, Copy, Debug)]
struct LatentSurvivalPrimaryDirection {
    /// Component along q at entry.
    dq_entry: f64,
    /// Component along q at exit.
    dq_exit: f64,
    /// Component along q-dot at exit.
    dqdot_exit: f64,
    /// Component along q at the interval right endpoint.
    dq_right: f64,
    /// Component along mu.
    dmu: f64,
    /// Component along log sigma.
    dlog_sigma: f64,
}
1142
/// Point at which kernel term sums are evaluated.
///
/// NOTE(review): `log_sigma_factor` is presumably the chain-rule factor that
/// maps tau-derivatives to log-sigma derivatives; its producer is outside
/// this view — confirm.
#[derive(Clone, Copy, Debug)]
struct LatentKernelPrimaryState {
    /// Time-transform value q.
    q: f64,
    /// Time-transform derivative q-dot.
    qdot: f64,
    /// Location mu.
    mu: f64,
    /// Latent standard deviation.
    sigma: f64,
    /// Scaling factor associated with log sigma.
    log_sigma_factor: f64,
}
1151
1152fn latent_kernel_accumulate_term(
1153 terms: &mut BTreeMap<(usize, usize, usize, usize), f64>,
1154 term: LatentKernelPrimaryTerm,
1155 scale: f64,
1156) {
1157 if scale == 0.0 || term.coeff == 0.0 {
1158 return;
1159 }
1160 *terms
1161 .entry((term.q_exp, term.qdot_power, term.tau_exp, term.k))
1162 .or_insert(0.0) += scale * term.coeff;
1163}
1164
/// Applies one directional derivative to a symbolic sum of kernel terms.
///
/// Returns the merged result ordered by `(q_exp, qdot_power, tau_exp, k)`,
/// with exactly-zero coefficients removed. Each nonzero component of `dir`
/// contributes its rewrite rules (listed at the branches below), scaled by
/// that component.
///
/// Contributions that land on the same key are summed in the order they are
/// generated. That order affects floating-point rounding, so do not reorder
/// these branches.
fn latent_kernel_differentiate_terms(
    terms: &[LatentKernelPrimaryTerm],
    dir: LatentKernelPrimaryDirection,
) -> Vec<LatentKernelPrimaryTerm> {
    let mut out = BTreeMap::<(usize, usize, usize, usize), f64>::new();
    for term in terms {
        if dir.dq != 0.0 {
            // q-rule: the term itself scaled by q_exp, plus a
            // (q_exp+1, k+1) term with coefficient -1.
            if term.q_exp > 0 {
                latent_kernel_accumulate_term(&mut out, *term, dir.dq * term.q_exp as f64);
            }
            latent_kernel_accumulate_term(
                &mut out,
                LatentKernelPrimaryTerm {
                    q_exp: term.q_exp + 1,
                    k: term.k + 1,
                    ..*term
                },
                -dir.dq,
            );
        }
        if dir.dmu != 0.0 {
            // mu-rule: the term itself scaled by k, plus the same
            // (q_exp+1, k+1) term with coefficient -1 as the q-rule.
            if term.k > 0 {
                latent_kernel_accumulate_term(&mut out, *term, dir.dmu * term.k as f64);
            }
            latent_kernel_accumulate_term(
                &mut out,
                LatentKernelPrimaryTerm {
                    q_exp: term.q_exp + 1,
                    k: term.k + 1,
                    ..*term
                },
                -dir.dmu,
            );
        }
        if dir.dtau != 0.0 {
            // tau-rule: the term itself scaled by tau_exp, then three terms
            // with tau_exp+2:
            //   (q, k)     scaled by k^2
            //   (q+1, k+1) scaled by -(2k+1)
            //   (q+2, k+2) scaled by 1
            if term.tau_exp > 0 {
                latent_kernel_accumulate_term(&mut out, *term, dir.dtau * term.tau_exp as f64);
            }
            let kf = term.k as f64;
            latent_kernel_accumulate_term(
                &mut out,
                LatentKernelPrimaryTerm {
                    tau_exp: term.tau_exp + 2,
                    ..*term
                },
                dir.dtau * kf * kf,
            );
            latent_kernel_accumulate_term(
                &mut out,
                LatentKernelPrimaryTerm {
                    q_exp: term.q_exp + 1,
                    tau_exp: term.tau_exp + 2,
                    k: term.k + 1,
                    ..*term
                },
                -dir.dtau * (2.0 * kf + 1.0),
            );
            latent_kernel_accumulate_term(
                &mut out,
                LatentKernelPrimaryTerm {
                    q_exp: term.q_exp + 2,
                    tau_exp: term.tau_exp + 2,
                    k: term.k + 2,
                    ..*term
                },
                dir.dtau,
            );
        }
        if dir.dqd != 0.0 && term.qdot_power > 0 {
            // q-dot power rule: d(qdot^p) = p * qdot^(p-1).
            latent_kernel_accumulate_term(
                &mut out,
                LatentKernelPrimaryTerm {
                    qdot_power: term.qdot_power - 1,
                    ..*term
                },
                dir.dqd * term.qdot_power as f64,
            );
        }
    }
    // Drop slots whose contributions cancelled to exactly zero.
    out.into_iter()
        .filter_map(|((q_exp, qdot_power, tau_exp, k), coeff)| {
            (coeff != 0.0).then_some(LatentKernelPrimaryTerm {
                coeff,
                q_exp,
                qdot_power,
                tau_exp,
                k,
            })
        })
        .collect()
}
1256
1257fn latent_kernel_term_lists_for_directions(
1258 base_terms: &[LatentKernelPrimaryTerm],
1259 directions: &[LatentKernelPrimaryDirection],
1260) -> Vec<Vec<LatentKernelPrimaryTerm>> {
1261 fn build_mask(
1262 mask: usize,
1263 base_terms: &[LatentKernelPrimaryTerm],
1264 directions: &[LatentKernelPrimaryDirection],
1265 cache: &mut [Option<Vec<LatentKernelPrimaryTerm>>],
1266 ) -> Vec<LatentKernelPrimaryTerm> {
1267 if let Some(existing) = &cache[mask] {
1268 return existing.clone();
1269 }
1270 let built = if mask == 0 {
1271 base_terms.to_vec()
1272 } else {
1273 let bit = 1usize << mask.trailing_zeros();
1274 let prev = build_mask(mask ^ bit, base_terms, directions, cache);
1275 latent_kernel_differentiate_terms(&prev, directions[bit.trailing_zeros() as usize])
1276 };
1277 cache[mask] = Some(built.clone());
1278 built
1279 }
1280
1281 let mut cache = vec![None; 1usize << directions.len()];
1282 (0..cache.len())
1283 .map(|mask| build_mask(mask, base_terms, directions, &mut cache))
1284 .collect()
1285}
1286
/// Log-jet of a signed sum of kernel terms along `directions`.
///
/// `term_lists[mask]` (from `latent_kernel_term_lists_for_directions`) holds
/// the symbolic derivative of `base_terms` along the directions selected by
/// `mask`. Each list is evaluated as a signed sum in log space: term `t`
/// contributes `sign(t.coeff) * exp(ln|t.coeff| + t.q_exp * q
/// + t.tau_exp * log_sigma_factor + t.qdot_power * ln(qdot) + bundle.get(t.k))`,
/// where `bundle` comes from `log_kernel_bundle(quadctx, exp(q), mu, sigma, max_k)`
/// and `max_k` is the largest kernel order appearing in any list.
///
/// The base sum `S` (mask 0) must be finite and positive. Every derivative sum
/// is divided by `S`, giving a jet whose value is exactly 1. That jet is
/// composed with `log` at 1 and then shifted by `ln S`, so the result is the
/// jet of `log(sum)`.
///
/// # Errors
///
/// `NumericalFailure` (message prefixed with `context`) when the kernel bundle
/// fails, when a term with `qdot_power > 0` meets a non-finite or non-positive
/// `qdot`, or when the base sum is not finite and positive.
fn latent_kernel_sum_log_jet(
    quadctx: &QuadratureContext,
    base_terms: &[LatentKernelPrimaryTerm],
    state: LatentKernelPrimaryState,
    directions: &[LatentKernelPrimaryDirection],
    context: &str,
) -> Result<LatentMultiDirJet, LatentSurvivalError> {
    let term_lists = latent_kernel_term_lists_for_directions(base_terms, directions);
    // Highest kernel order needed across all derivative lists; the bundle is
    // evaluated once up to this order.
    let max_k = term_lists
        .iter()
        .flat_map(|terms| terms.iter().map(|term| term.k))
        .max()
        .unwrap_or(0);
    let bundle =
        log_kernel_bundle(quadctx, state.q.exp(), state.mu, state.sigma, max_k).map_err(|e| {
            LatentSurvivalError::NumericalFailure {
                reason: format!("{context} kernel evaluation failed: {e}"),
            }
        })?;

    // Returns (log|sum|, sign(sum)) for one term list; an empty list maps to
    // (-inf, 0), i.e. an exact zero.
    let evaluate_terms =
        |terms: &[LatentKernelPrimaryTerm]| -> Result<(f64, f64), LatentSurvivalError> {
            let mut log_mags = Vec::new();
            let mut signs = Vec::new();
            for term in terms {
                if term.coeff == 0.0 {
                    continue;
                }
                // ln(qdot) is only defined for positive qdot; reject instead
                // of producing NaN.
                if term.qdot_power > 0 && !(state.qdot.is_finite() && state.qdot > 0.0) {
                    return Err(LatentSurvivalError::NumericalFailure {
                        reason: format!(
                            "{context} requires positive finite qdot for exact-event directional terms, got {}",
                            state.qdot
                        ),
                    });
                }
                let log_qdot = if term.qdot_power > 0 {
                    state.qdot.ln()
                } else {
                    0.0
                };
                let log_mag = term.coeff.abs().ln()
                    + term.q_exp as f64 * state.q
                    + term.tau_exp as f64 * state.log_sigma_factor
                    + term.qdot_power as f64 * log_qdot
                    + bundle.get(term.k);
                log_mags.push(log_mag);
                signs.push(term.coeff.signum());
            }
            if log_mags.is_empty() {
                return Ok((f64::NEG_INFINITY, 0.0));
            }
            Ok(signed_log_sum_exp(&log_mags, &signs))
        };

    let (base_log_sum, base_sign) = evaluate_terms(&term_lists[0])?;
    // The log of the base sum must exist.
    if !(base_log_sum.is_finite() && base_sign > 0.0) {
        return Err(LatentSurvivalError::NumericalFailure {
            reason: format!("{context} produced a non-positive signed kernel sum"),
        });
    }

    // Derivative sums relative to the base sum, so the jet's value is 1.
    let mut normalized = LatentMultiDirJet::constant(directions.len(), 1.0);
    for mask in 1..term_lists.len() {
        let (log_abs, sign) = evaluate_terms(&term_lists[mask])?;
        // NOTE(review): any non-finite `log_abs` is treated as an exact zero
        // here. That is right for -inf (empty or cancelled sum), but it also
        // silently zeroes a +inf (overflow) or NaN derivative. Confirm this
        // is intended.
        normalized.coeffs[mask] = if !log_abs.is_finite() || sign == 0.0 {
            0.0
        } else {
            sign * (log_abs - base_log_sum).exp()
        };
    }

    // log(S * n) = ln S + log(n), with n evaluated at 1.
    let mut out = normalized.compose_unary(latent_unary_derivatives_log(1.0));
    out.coeffs[0] += base_log_sum;
    Ok(out)
}
1363
1364fn latent_survival_basis_direction(primary_idx: usize) -> LatentSurvivalPrimaryDirection {
1365 match primary_idx {
1366 LATENT_SURVIVAL_PRIMARY_Q_ENTRY => LatentSurvivalPrimaryDirection {
1367 dq_entry: 1.0,
1368 dq_exit: 0.0,
1369 dqdot_exit: 0.0,
1370 dq_right: 0.0,
1371 dmu: 0.0,
1372 dlog_sigma: 0.0,
1373 },
1374 LATENT_SURVIVAL_PRIMARY_Q_EXIT => LatentSurvivalPrimaryDirection {
1375 dq_entry: 0.0,
1376 dq_exit: 1.0,
1377 dqdot_exit: 0.0,
1378 dq_right: 0.0,
1379 dmu: 0.0,
1380 dlog_sigma: 0.0,
1381 },
1382 LATENT_SURVIVAL_PRIMARY_QDOT_EXIT => LatentSurvivalPrimaryDirection {
1383 dq_entry: 0.0,
1384 dq_exit: 0.0,
1385 dqdot_exit: 1.0,
1386 dq_right: 0.0,
1387 dmu: 0.0,
1388 dlog_sigma: 0.0,
1389 },
1390 LATENT_SURVIVAL_PRIMARY_Q_RIGHT => LatentSurvivalPrimaryDirection {
1391 dq_entry: 0.0,
1392 dq_exit: 0.0,
1393 dqdot_exit: 0.0,
1394 dq_right: 1.0,
1395 dmu: 0.0,
1396 dlog_sigma: 0.0,
1397 },
1398 LATENT_SURVIVAL_PRIMARY_MU => LatentSurvivalPrimaryDirection {
1399 dq_entry: 0.0,
1400 dq_exit: 0.0,
1401 dqdot_exit: 0.0,
1402 dq_right: 0.0,
1403 dmu: 1.0,
1404 dlog_sigma: 0.0,
1405 },
1406 LATENT_SURVIVAL_PRIMARY_LOG_SIGMA => LatentSurvivalPrimaryDirection {
1407 dq_entry: 0.0,
1408 dq_exit: 0.0,
1409 dqdot_exit: 0.0,
1410 dq_right: 0.0,
1411 dmu: 0.0,
1412 dlog_sigma: 1.0,
1413 },
1414 _ => std::panic::panic_any(format!(
1422 "latent survival primary index out of bounds: primary_idx={primary_idx}, primary_dim={LATENT_SURVIVAL_PRIMARY_DIM}"
1423 )),
1424 }
1425}
1426
1427fn latent_survival_map_entry_direction(
1428 direction: LatentSurvivalPrimaryDirection,
1429) -> LatentKernelPrimaryDirection {
1430 LatentKernelPrimaryDirection {
1431 dq: direction.dq_entry,
1432 dqd: 0.0,
1433 dmu: direction.dmu,
1434 dtau: direction.dlog_sigma,
1435 }
1436}
1437
1438fn latent_survival_map_exit_direction(
1439 direction: LatentSurvivalPrimaryDirection,
1440 event_type: LatentSurvivalEventType,
1441) -> LatentKernelPrimaryDirection {
1442 LatentKernelPrimaryDirection {
1443 dq: direction.dq_exit,
1444 dqd: if matches!(event_type, LatentSurvivalEventType::ExactEvent) {
1445 direction.dqdot_exit
1446 } else {
1447 0.0
1448 },
1449 dmu: direction.dmu,
1450 dtau: direction.dlog_sigma,
1451 }
1452}
1453
1454fn latent_survival_map_left_direction(
1458 direction: LatentSurvivalPrimaryDirection,
1459) -> LatentKernelPrimaryDirection {
1460 LatentKernelPrimaryDirection {
1461 dq: direction.dq_exit,
1462 dqd: 0.0,
1463 dmu: direction.dmu,
1464 dtau: direction.dlog_sigma,
1465 }
1466}
1467
1468fn latent_survival_map_right_direction(
1473 direction: LatentSurvivalPrimaryDirection,
1474) -> LatentKernelPrimaryDirection {
1475 LatentKernelPrimaryDirection {
1476 dq: direction.dq_right,
1477 dqd: 0.0,
1478 dmu: direction.dmu,
1479 dtau: direction.dlog_sigma,
1480 }
1481}
1482
/// Directional log-jet of one row's log-likelihood contribution in primary
/// coordinates.
///
/// The value coefficient is `log(numerator) - log(denominator)` plus an
/// unloaded-mass offset, where:
/// - the denominator is the single `k = 0` kernel term at the entry state
///   (`q = q_entry`);
/// - the numerator depends on `row.event_type`:
///   - `RightCensored`: the single `k = 0` kernel term at the exit state;
///   - `ExactEvent`: `hazard_unloaded` times the `k = 0` term (only when
///     `hazard_unloaded > 0`), plus the `q_exp = 1, qdot_power = 1, k = 1`
///     term at the exit state;
///   - `IntervalCensored`: delegated to
///     `latent_survival_interval_numerator_log_jet` (left boundary at
///     `q_exit`, right boundary at `q_right`);
/// - the offset is `mass_unloaded_entry - mass_unloaded_exit` for
///   non-interval rows, and `mass_unloaded_entry` alone for interval rows.
///   The interval numerator already includes `exp(-mass_unloaded_left/right)`.
///
/// Coefficient `mask` of the returned jet is the mixed derivative along the
/// `directions` whose bits are set in `mask`.
///
/// # Errors
///
/// Returns the kernel-evaluation and positivity failures of
/// `latent_kernel_sum_log_jet` and of the interval numerator, converted to
/// `String`.
fn latent_survival_row_primary_log_jet(
    quadctx: &QuadratureContext,
    row: &LatentSurvivalRow,
    q_entry: f64,
    q_exit: f64,
    qdot_exit: f64,
    q_right: f64,
    mu: f64,
    sigma: f64,
    log_sigma_factor: f64,
    directions: &[LatentSurvivalPrimaryDirection],
) -> Result<LatentMultiDirJet, String> {
    // Entry state: the denominator carries no qdot factor, so qdot is a
    // placeholder here.
    let entry_state = LatentKernelPrimaryState {
        q: q_entry,
        qdot: 1.0,
        mu,
        sigma,
        log_sigma_factor,
    };
    let entry_directions = directions
        .iter()
        .copied()
        .map(latent_survival_map_entry_direction)
        .collect::<Vec<_>>();

    let denominator = latent_kernel_sum_log_jet(
        quadctx,
        &[LatentKernelPrimaryTerm {
            coeff: 1.0,
            q_exp: 0,
            qdot_power: 0,
            tau_exp: 0,
            k: 0,
        }],
        entry_state,
        &entry_directions,
        "latent survival denominator",
    )?;

    let numerator = match row.event_type {
        LatentSurvivalEventType::RightCensored | LatentSurvivalEventType::ExactEvent => {
            let exit_state = LatentKernelPrimaryState {
                q: q_exit,
                qdot: qdot_exit,
                mu,
                sigma,
                log_sigma_factor,
            };
            let exit_directions = directions
                .iter()
                .copied()
                .map(|dir| latent_survival_map_exit_direction(dir, row.event_type))
                .collect::<Vec<_>>();
            // Term order is kept fixed: it determines the floating-point
            // summation order inside the signed log-sum-exp.
            let numerator_terms = match row.event_type {
                LatentSurvivalEventType::RightCensored => vec![LatentKernelPrimaryTerm {
                    coeff: 1.0,
                    q_exp: 0,
                    qdot_power: 0,
                    tau_exp: 0,
                    k: 0,
                }],
                LatentSurvivalEventType::ExactEvent => {
                    let mut terms = Vec::new();
                    // Unloaded hazard contribution, omitted when it is zero.
                    if row.hazard_unloaded > 0.0 {
                        terms.push(LatentKernelPrimaryTerm {
                            coeff: row.hazard_unloaded,
                            q_exp: 0,
                            qdot_power: 0,
                            tau_exp: 0,
                            k: 0,
                        });
                    }
                    // Loaded hazard contribution: qdot * exp(q) with the
                    // first-order kernel.
                    terms.push(LatentKernelPrimaryTerm {
                        coeff: 1.0,
                        q_exp: 1,
                        qdot_power: 1,
                        tau_exp: 0,
                        k: 1,
                    });
                    terms
                }
                // Unreachable via the outer match; kept as an explicit
                // error rather than a panic.
                LatentSurvivalEventType::IntervalCensored => {
                    return Err(
                        "interval-censored row reached the single-state numerator branch; \
                         it must take the dedicated two-state branch"
                            .to_string(),
                    );
                }
            };
            latent_kernel_sum_log_jet(
                quadctx,
                &numerator_terms,
                exit_state,
                &exit_directions,
                "latent survival numerator",
            )?
        }
        LatentSurvivalEventType::IntervalCensored => latent_survival_interval_numerator_log_jet(
            quadctx,
            row,
            q_exit,
            q_right,
            mu,
            sigma,
            log_sigma_factor,
            directions,
        )?,
    };

    let mut total = numerator.add(&denominator.scale(-1.0));
    // Constant unloaded-mass offsets only shift the value coefficient.
    match row.event_type {
        LatentSurvivalEventType::IntervalCensored => {
            total.coeffs[0] += row.mass_unloaded_entry;
        }
        _ => {
            total.coeffs[0] += -row.mass_unloaded_exit + row.mass_unloaded_entry;
        }
    }
    Ok(total)
}
1616
1617fn latent_survival_interval_numerator_log_jet(
1641 quadctx: &QuadratureContext,
1642 row: &LatentSurvivalRow,
1643 q_exit: f64,
1644 q_right: f64,
1645 mu: f64,
1646 sigma: f64,
1647 log_sigma_factor: f64,
1648 directions: &[LatentSurvivalPrimaryDirection],
1649) -> Result<LatentMultiDirJet, String> {
1650 let single_k0 = [LatentKernelPrimaryTerm {
1651 coeff: 1.0,
1652 q_exp: 0,
1653 qdot_power: 0,
1654 tau_exp: 0,
1655 k: 0,
1656 }];
1657
1658 let left_state = LatentKernelPrimaryState {
1659 q: q_exit,
1660 qdot: 1.0,
1661 mu,
1662 sigma,
1663 log_sigma_factor,
1664 };
1665 let right_state = LatentKernelPrimaryState {
1666 q: q_right,
1667 qdot: 1.0,
1668 mu,
1669 sigma,
1670 log_sigma_factor,
1671 };
1672 let left_directions = directions
1673 .iter()
1674 .copied()
1675 .map(latent_survival_map_left_direction)
1676 .collect::<Vec<_>>();
1677 let right_directions = directions
1678 .iter()
1679 .copied()
1680 .map(latent_survival_map_right_direction)
1681 .collect::<Vec<_>>();
1682
1683 let log_left = latent_kernel_sum_log_jet(
1684 quadctx,
1685 &single_k0,
1686 left_state,
1687 &left_directions,
1688 "latent survival interval left boundary",
1689 )?;
1690 let log_right = latent_kernel_sum_log_jet(
1691 quadctx,
1692 &single_k0,
1693 right_state,
1694 &right_directions,
1695 "latent survival interval right boundary",
1696 )?;
1697
1698 let c_left = (-row.mass_unloaded_left).exp();
1702 let c_right = (-row.mass_unloaded_right).exp();
1703 let exp_left_value = log_left.coeff(0).exp();
1704 let exp_right_value = log_right.coeff(0).exp();
1705 let linear_left = log_left.compose_unary([exp_left_value; 5]).scale(c_left);
1706 let linear_right = log_right.compose_unary([exp_right_value; 5]).scale(c_right);
1707
1708 let linear_numerator = linear_left.add(&linear_right.scale(-1.0));
1709 let base = linear_numerator.coeff(0);
1710 if !(base.is_finite() && base > 0.0) {
1711 return Err(LatentSurvivalError::NumericalFailure {
1712 reason: format!(
1713 "latent survival interval numerator must be a positive survival-mass difference, \
1714 got c_L*K0(M_L) - c_R*K0(M_R) = {base}; require M_L < M_R (i.e. L < R)"
1715 ),
1716 }
1717 .into());
1718 }
1719 Ok(linear_numerator.compose_unary(latent_unary_derivatives_log(base)))
1725}
1726
1727fn latent_survival_row_primary_gradient_hessian(
1728 quadctx: &QuadratureContext,
1729 row: &LatentSurvivalRow,
1730 q_entry: f64,
1731 q_exit: f64,
1732 qdot_exit: f64,
1733 q_right: f64,
1734 mu: f64,
1735 sigma: f64,
1736 include_log_sigma: bool,
1737) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
1738 let log_sigma_factor = if sigma > 0.0 { sigma.ln() } else { 0.0 };
1739 let mut gradient = Array1::<f64>::zeros(LATENT_SURVIVAL_PRIMARY_DIM);
1740 let mut neg_hessian =
1741 Array2::<f64>::zeros((LATENT_SURVIVAL_PRIMARY_DIM, LATENT_SURVIVAL_PRIMARY_DIM));
1742 let active_primary = if include_log_sigma {
1743 LATENT_SURVIVAL_PRIMARY_DIM
1744 } else {
1745 LATENT_SURVIVAL_PRIMARY_LOG_SIGMA
1746 };
1747 let log_lik = latent_survival_row_primary_log_jet(
1748 quadctx,
1749 row,
1750 q_entry,
1751 q_exit,
1752 qdot_exit,
1753 q_right,
1754 mu,
1755 sigma,
1756 log_sigma_factor,
1757 &[],
1758 )?
1759 .coeff(0);
1760 for a in 0..active_primary {
1761 let dir_a = latent_survival_basis_direction(a);
1762 gradient[a] = latent_survival_row_primary_log_jet(
1763 quadctx,
1764 row,
1765 q_entry,
1766 q_exit,
1767 qdot_exit,
1768 q_right,
1769 mu,
1770 sigma,
1771 log_sigma_factor,
1772 &[dir_a],
1773 )?
1774 .coeff(1);
1775 for b in a..active_primary {
1776 let coeff = latent_survival_row_primary_log_jet(
1777 quadctx,
1778 row,
1779 q_entry,
1780 q_exit,
1781 qdot_exit,
1782 q_right,
1783 mu,
1784 sigma,
1785 log_sigma_factor,
1786 &[dir_a, latent_survival_basis_direction(b)],
1787 )?
1788 .coeff(3);
1789 neg_hessian[[a, b]] = -coeff;
1790 neg_hessian[[b, a]] = -coeff;
1791 }
1792 }
1793 Ok((log_lik, gradient, neg_hessian))
1794}
1795
1796fn latent_survival_row_primary_third_contracted(
1797 quadctx: &QuadratureContext,
1798 row: &LatentSurvivalRow,
1799 q_entry: f64,
1800 q_exit: f64,
1801 qdot_exit: f64,
1802 q_right: f64,
1803 mu: f64,
1804 sigma: f64,
1805 direction: &Array1<f64>,
1806 include_log_sigma: bool,
1807) -> Result<Array2<f64>, String> {
1808 let log_sigma_factor = if sigma > 0.0 { sigma.ln() } else { 0.0 };
1809 let active_primary = if include_log_sigma {
1810 LATENT_SURVIVAL_PRIMARY_DIM
1811 } else {
1812 LATENT_SURVIVAL_PRIMARY_LOG_SIGMA
1813 };
1814 let dir = LatentSurvivalPrimaryDirection {
1815 dq_entry: direction[LATENT_SURVIVAL_PRIMARY_Q_ENTRY],
1816 dq_exit: direction[LATENT_SURVIVAL_PRIMARY_Q_EXIT],
1817 dqdot_exit: direction[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT],
1818 dq_right: direction[LATENT_SURVIVAL_PRIMARY_Q_RIGHT],
1819 dmu: direction[LATENT_SURVIVAL_PRIMARY_MU],
1820 dlog_sigma: direction[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA],
1821 };
1822 let mut out = Array2::<f64>::zeros((LATENT_SURVIVAL_PRIMARY_DIM, LATENT_SURVIVAL_PRIMARY_DIM));
1823 for a in 0..active_primary {
1824 let dir_a = latent_survival_basis_direction(a);
1825 for b in a..active_primary {
1826 let coeff = latent_survival_row_primary_log_jet(
1827 quadctx,
1828 row,
1829 q_entry,
1830 q_exit,
1831 qdot_exit,
1832 q_right,
1833 mu,
1834 sigma,
1835 log_sigma_factor,
1836 &[dir_a, latent_survival_basis_direction(b), dir],
1837 )?
1838 .coeff(7);
1839 out[[a, b]] = -coeff;
1840 out[[b, a]] = -coeff;
1841 }
1842 }
1843 Ok(out)
1844}
1845
1846fn latent_survival_row_primary_fourth_contracted(
1847 quadctx: &QuadratureContext,
1848 row: &LatentSurvivalRow,
1849 q_entry: f64,
1850 q_exit: f64,
1851 qdot_exit: f64,
1852 q_right: f64,
1853 mu: f64,
1854 sigma: f64,
1855 direction_u: &Array1<f64>,
1856 direction_v: &Array1<f64>,
1857 include_log_sigma: bool,
1858) -> Result<Array2<f64>, String> {
1859 let log_sigma_factor = if sigma > 0.0 { sigma.ln() } else { 0.0 };
1860 let active_primary = if include_log_sigma {
1861 LATENT_SURVIVAL_PRIMARY_DIM
1862 } else {
1863 LATENT_SURVIVAL_PRIMARY_LOG_SIGMA
1864 };
1865 let dir_u = LatentSurvivalPrimaryDirection {
1866 dq_entry: direction_u[LATENT_SURVIVAL_PRIMARY_Q_ENTRY],
1867 dq_exit: direction_u[LATENT_SURVIVAL_PRIMARY_Q_EXIT],
1868 dqdot_exit: direction_u[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT],
1869 dq_right: direction_u[LATENT_SURVIVAL_PRIMARY_Q_RIGHT],
1870 dmu: direction_u[LATENT_SURVIVAL_PRIMARY_MU],
1871 dlog_sigma: direction_u[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA],
1872 };
1873 let dir_v = LatentSurvivalPrimaryDirection {
1874 dq_entry: direction_v[LATENT_SURVIVAL_PRIMARY_Q_ENTRY],
1875 dq_exit: direction_v[LATENT_SURVIVAL_PRIMARY_Q_EXIT],
1876 dqdot_exit: direction_v[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT],
1877 dq_right: direction_v[LATENT_SURVIVAL_PRIMARY_Q_RIGHT],
1878 dmu: direction_v[LATENT_SURVIVAL_PRIMARY_MU],
1879 dlog_sigma: direction_v[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA],
1880 };
1881 let mut out = Array2::<f64>::zeros((LATENT_SURVIVAL_PRIMARY_DIM, LATENT_SURVIVAL_PRIMARY_DIM));
1882 for a in 0..active_primary {
1883 let dir_a = latent_survival_basis_direction(a);
1884 for b in a..active_primary {
1885 let coeff = latent_survival_row_primary_log_jet(
1886 quadctx,
1887 row,
1888 q_entry,
1889 q_exit,
1890 qdot_exit,
1891 q_right,
1892 mu,
1893 sigma,
1894 log_sigma_factor,
1895 &[dir_a, latent_survival_basis_direction(b), dir_u, dir_v],
1896 )?
1897 .coeff(15);
1898 out[[a, b]] = -coeff;
1899 out[[b, a]] = -coeff;
1900 }
1901 }
1902 Ok(out)
1903}
1904
/// Column layout of the flattened joint coefficient vector, as produced by
/// `LatentSurvivalFamily::joint_slices`.
#[derive(Clone)]
struct LatentSurvivalJointSlices {
    /// Columns of the time block (`0..p_time`).
    time: std::ops::Range<usize>,
    /// Columns of the mean block, immediately after the time block.
    mean: std::ops::Range<usize>,
    /// Single-column range of the `log σ` coefficient; `None` when the latent
    /// SD is fixed.
    log_sigma: Option<std::ops::Range<usize>>,
    /// Total length of the flattened coefficient vector.
    total: usize,
}
1912
/// Per-chunk accumulator for the joint gradient pass: the weighted
/// log-likelihood sum and the gradient in flattened coefficient space.
#[derive(Clone)]
struct LatentSurvivalJointGradientAccum {
    /// Sum of `weight * row log-likelihood`.
    ll: f64,
    /// Weighted gradient, length `LatentSurvivalJointSlices::total`.
    gradient: Array1<f64>,
}
1918
/// Accumulator holding log-likelihood, gradient and a dense Hessian in
/// flattened coefficient space.
///
/// NOTE(review): not used in this part of the file; presumably the chunk
/// accumulator of a dense value/gradient/Hessian pass elsewhere. Confirm.
#[derive(Clone)]
struct LatentSurvivalJointDenseAccum {
    /// Accumulated log-likelihood.
    ll: f64,
    /// Accumulated gradient.
    gradient: Array1<f64>,
    /// Accumulated dense Hessian (sign convention set by the caller).
    hessian: Array2<f64>,
}
1925
/// Accumulator holding only a dense matrix in flattened coefficient space.
///
/// NOTE(review): not used in this part of the file; presumably the chunk
/// accumulator of a Hessian-only (or directional-derivative) pass elsewhere.
/// Confirm.
#[derive(Clone)]
struct LatentSurvivalDenseHessianAccum {
    /// Accumulated dense matrix.
    hessian: Array2<f64>,
}
1930
1931fn deterministic_latent_survival_row_reduction<Acc, Init, Process, Combine>(
1935 n_rows: usize,
1936 init: Init,
1937 process_row: Process,
1938 mut combine: Combine,
1939) -> Result<Acc, String>
1940where
1941 Acc: Send,
1942 Init: Fn() -> Acc + Sync,
1943 Process: Fn(usize, &mut Acc) -> Result<(), String> + Sync,
1944 Combine: FnMut(&mut Acc, Acc),
1945{
1946 use rayon::iter::{IntoParallelIterator, ParallelIterator};
1947
1948 const TARGET_CHUNK_COUNT: usize = 32;
1949 if n_rows == 0 {
1950 return Ok(init());
1951 }
1952 let chunk_size = n_rows.div_ceil(TARGET_CHUNK_COUNT).max(1);
1953 let n_chunks = n_rows.div_ceil(chunk_size);
1954 let chunk_accumulators: Vec<Acc> = (0..n_chunks)
1955 .into_par_iter()
1956 .map(|chunk_idx| -> Result<Acc, String> {
1957 let start = chunk_idx * chunk_size;
1958 let end = (start + chunk_size).min(n_rows);
1959 let mut acc = init();
1960 for row_idx in start..end {
1961 process_row(row_idx, &mut acc)?;
1962 }
1963 Ok(acc)
1964 })
1965 .collect::<Result<Vec<_>, String>>()?;
1966
1967 let mut total = init();
1968 for acc in chunk_accumulators {
1969 combine(&mut total, acc);
1970 }
1971 Ok(total)
1972}
1973
1974impl LatentSurvivalFamily {
1975 fn build_row_at(
1983 &self,
1984 row_idx: usize,
1985 q_entry: f64,
1986 q_exit: f64,
1987 qdot_exit: f64,
1988 q_right: f64,
1989 ) -> Result<LatentSurvivalRow, LatentSurvivalError> {
1990 let event_type = latent_survival_event_type_for(self.event_target[row_idx]);
1991 build_latent_survival_row(
1992 row_idx,
1993 self.hazard_loading,
1994 event_type,
1995 q_entry,
1996 q_exit,
1997 qdot_exit,
1998 q_right,
1999 self.unloaded_mass_entry[row_idx],
2000 self.unloaded_mass_exit[row_idx],
2001 self.unloaded_mass_right[row_idx],
2002 self.unloaded_hazard_exit[row_idx],
2003 )
2004 }
2005
2006 fn joint_slices(&self) -> LatentSurvivalJointSlices {
2007 let p_time = self.x_time_exit.ncols();
2008 let p_mean = self.x_mean.ncols();
2009 let time = 0..p_time;
2010 let mean = p_time..p_time + p_mean;
2011 let log_sigma = self
2012 .latent_sd_fixed
2013 .is_none()
2014 .then_some((p_time + p_mean)..(p_time + p_mean + 1));
2015 LatentSurvivalJointSlices {
2016 total: log_sigma
2017 .as_ref()
2018 .map_or(p_time + p_mean, |range| range.end),
2019 time,
2020 mean,
2021 log_sigma,
2022 }
2023 }
2024
2025 fn row_primary_direction_from_flat(
2026 &self,
2027 row: usize,
2028 slices: &LatentSurvivalJointSlices,
2029 d_beta_flat: &Array1<f64>,
2030 ) -> Array1<f64> {
2031 let mut out = Array1::<f64>::zeros(LATENT_SURVIVAL_PRIMARY_DIM);
2032 let d_time = d_beta_flat.slice(s![slices.time.clone()]);
2033 out[LATENT_SURVIVAL_PRIMARY_Q_ENTRY] = self.x_time_entry.row(row).dot(&d_time);
2034 out[LATENT_SURVIVAL_PRIMARY_Q_EXIT] = self.x_time_exit.row(row).dot(&d_time);
2035 out[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT] = self.x_time_derivative_exit.row(row).dot(&d_time);
2036 out[LATENT_SURVIVAL_PRIMARY_Q_RIGHT] = self.x_time_right.row(row).dot(&d_time);
2037 out[LATENT_SURVIVAL_PRIMARY_MU] = self
2038 .x_mean
2039 .dot_row_view(row, d_beta_flat.slice(s![slices.mean.clone()]));
2040 if let Some(range) = &slices.log_sigma {
2041 out[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA] = d_beta_flat[range.start];
2042 }
2043 out
2044 }
2045
2046 fn joint_block_ranges(&self) -> Vec<std::ops::Range<usize>> {
2047 let slices = self.joint_slices();
2048 let mut ranges = vec![slices.time.clone(), slices.mean.clone()];
2049 if let Some(log_sigma) = slices.log_sigma {
2050 ranges.push(log_sigma);
2051 }
2052 ranges
2053 }
2054
2055 fn add_pullback_primary_gradient(
2056 &self,
2057 target: &mut Array1<f64>,
2058 row: usize,
2059 slices: &LatentSurvivalJointSlices,
2060 primary_gradient: &Array1<f64>,
2061 weight: f64,
2062 ) -> Result<(), String> {
2063 for (primary_idx, time_vec) in [
2064 (LATENT_SURVIVAL_PRIMARY_Q_ENTRY, self.x_time_entry.row(row)),
2065 (LATENT_SURVIVAL_PRIMARY_Q_EXIT, self.x_time_exit.row(row)),
2066 (
2067 LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2068 self.x_time_derivative_exit.row(row),
2069 ),
2070 (LATENT_SURVIVAL_PRIMARY_Q_RIGHT, self.x_time_right.row(row)),
2071 ] {
2072 let scale = weight * primary_gradient[primary_idx];
2073 if scale == 0.0 {
2074 continue;
2075 }
2076 for i in 0..time_vec.len() {
2077 let xi = time_vec[i];
2078 if xi != 0.0 {
2079 target[slices.time.start + i] += scale * xi;
2080 }
2081 }
2082 }
2083
2084 let mean_scale = weight * primary_gradient[LATENT_SURVIVAL_PRIMARY_MU];
2085 if mean_scale != 0.0 {
2086 self.x_mean
2087 .axpy_row_into(
2088 row,
2089 mean_scale,
2090 &mut target.slice_mut(s![slices.mean.clone()]),
2091 )
2092 .map_err(|error| {
2093 format!(
2094 "latent survival mean gradient pullback dimension mismatch: row={row}, mean_slice={:?}, target_len={}, x_mean_cols={}, error={error}",
2095 slices.mean,
2096 target.len(),
2097 self.x_mean.ncols()
2098 )
2099 })?;
2100 }
2101
2102 if let Some(log_sigma) = &slices.log_sigma {
2103 target[log_sigma.start] += weight * primary_gradient[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA];
2104 }
2105 Ok(())
2106 }
2107
    /// Adds one row's primary-space Hessian into the joint coefficient-space
    /// matrix `target` (`J_rowᵀ H J_row`), pulled back through the row's
    /// design vectors.
    ///
    /// Covered blocks: time/time via the four time design rows (diagonal
    /// channels plus all six cross pairs), mean/mean via `x_mean`, time/mean,
    /// and, when `log σ` is estimated, the `log σ` row/column (whose design
    /// "vector" is the scalar 1) against itself, the time channels and the
    /// mean.
    ///
    /// `primary_hessian` is used as given; no weight is applied here. It is
    /// treated as symmetric: for every cross pair only the entry `[[a, b]]`
    /// listed below is read, and it is written to both triangles of `target`.
    ///
    /// # Errors
    ///
    /// Returns a descriptive message when the mean Hessian pullback or the mean
    /// row extraction fails.
    fn add_pullback_primary_hessian(
        &self,
        target: &mut Array2<f64>,
        row: usize,
        slices: &LatentSurvivalJointSlices,
        primary_hessian: &Array2<f64>,
    ) -> Result<(), String> {
        // Diagonal primary entries of the four time channels.
        let time_weights = [
            primary_hessian[[
                LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
                LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
            ]],
            primary_hessian[[
                LATENT_SURVIVAL_PRIMARY_Q_EXIT,
                LATENT_SURVIVAL_PRIMARY_Q_EXIT,
            ]],
            primary_hessian[[
                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
            ]],
            primary_hessian[[
                LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
                LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
            ]],
        ];
        // All six unordered pairs of distinct time channels with their design
        // matrices.
        let time_cross_weights = [
            (
                LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
                LATENT_SURVIVAL_PRIMARY_Q_EXIT,
                &self.x_time_entry,
                &self.x_time_exit,
            ),
            (
                LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
                &self.x_time_entry,
                &self.x_time_derivative_exit,
            ),
            (
                LATENT_SURVIVAL_PRIMARY_Q_EXIT,
                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
                &self.x_time_exit,
                &self.x_time_derivative_exit,
            ),
            (
                LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
                LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
                &self.x_time_entry,
                &self.x_time_right,
            ),
            (
                LATENT_SURVIVAL_PRIMARY_Q_EXIT,
                LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
                &self.x_time_exit,
                &self.x_time_right,
            ),
            (
                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
                LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
                &self.x_time_derivative_exit,
                &self.x_time_right,
            ),
        ];
        // Time/time block. NOTE(review): dense_outer_accumulate presumably adds
        // `w * x xᵀ` and dense_symmetric_cross_accumulate `w * (x yᵀ + y xᵀ)`.
        // Confirm against their definitions.
        {
            let time_target = &mut target.slice_mut(s![slices.time.clone(), slices.time.clone()]);
            dense_outer_accumulate(time_target, time_weights[0], self.x_time_entry.row(row));
            dense_outer_accumulate(time_target, time_weights[1], self.x_time_exit.row(row));
            dense_outer_accumulate(
                time_target,
                time_weights[2],
                self.x_time_derivative_exit.row(row),
            );
            dense_outer_accumulate(time_target, time_weights[3], self.x_time_right.row(row));
            for (a, b, lhs, rhs) in time_cross_weights {
                let weight = primary_hessian[[a, b]];
                if weight == 0.0 {
                    continue;
                }
                dense_symmetric_cross_accumulate(time_target, weight, lhs.row(row), rhs.row(row));
            }
        }

        // Mean/mean block: rank-one update with the mean design row.
        let mean_weight = primary_hessian[[LATENT_SURVIVAL_PRIMARY_MU, LATENT_SURVIVAL_PRIMARY_MU]];
        self.x_mean
            .syr_row_into_view(
                row,
                mean_weight,
                target.slice_mut(s![slices.mean.clone(), slices.mean.clone()]),
            )
            .map_err(|error| {
                format!(
                    "latent survival mean Hessian pullback dimension mismatch: row={row}, mean_slice={:?}, target_dim={:?}, x_mean_cols={}, error={error}",
                    slices.mean,
                    target.dim(),
                    self.x_mean.ncols()
                )
            })?;

        // Dense copy of this row of the mean design, reused for the cross
        // blocks below.
        let mean_row = self
            .x_mean
            .try_row_chunk(row..row + 1)
            .map_err(|error| {
                format!(
                    "latent survival mean pullback row chunk failed: row={row}, x_mean_rows={}, x_mean_cols={}, error={error}",
                    self.x_mean.nrows(),
                    self.x_mean.ncols()
                )
            })?;
        let mean_vec = mean_row.row(0);
        // Time/mean cross block, written to both triangles.
        let time_mean_weights = [
            (LATENT_SURVIVAL_PRIMARY_Q_ENTRY, self.x_time_entry.row(row)),
            (LATENT_SURVIVAL_PRIMARY_Q_EXIT, self.x_time_exit.row(row)),
            (
                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
                self.x_time_derivative_exit.row(row),
            ),
            (LATENT_SURVIVAL_PRIMARY_Q_RIGHT, self.x_time_right.row(row)),
        ];
        for (primary_idx, time_vec) in time_mean_weights {
            let weight = primary_hessian[[primary_idx, LATENT_SURVIVAL_PRIMARY_MU]];
            if weight == 0.0 {
                continue;
            }
            for i in 0..time_vec.len() {
                let xi = time_vec[i];
                if xi == 0.0 {
                    continue;
                }
                for j in 0..mean_vec.len() {
                    let xj = mean_vec[j];
                    if xj == 0.0 {
                        continue;
                    }
                    target[[slices.time.start + i, slices.mean.start + j]] += weight * xi * xj;
                    target[[slices.mean.start + j, slices.time.start + i]] += weight * xj * xi;
                }
            }
        }

        // log σ row/column: its design "vector" is the scalar 1.
        if let Some(log_sigma) = &slices.log_sigma {
            let sigma_idx = log_sigma.start;
            target[[sigma_idx, sigma_idx]] += primary_hessian[[
                LATENT_SURVIVAL_PRIMARY_LOG_SIGMA,
                LATENT_SURVIVAL_PRIMARY_LOG_SIGMA,
            ]];

            for (primary_idx, time_vec) in [
                (LATENT_SURVIVAL_PRIMARY_Q_ENTRY, self.x_time_entry.row(row)),
                (LATENT_SURVIVAL_PRIMARY_Q_EXIT, self.x_time_exit.row(row)),
                (
                    LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
                    self.x_time_derivative_exit.row(row),
                ),
                (LATENT_SURVIVAL_PRIMARY_Q_RIGHT, self.x_time_right.row(row)),
            ] {
                let weight = primary_hessian[[primary_idx, LATENT_SURVIVAL_PRIMARY_LOG_SIGMA]];
                if weight == 0.0 {
                    continue;
                }
                for i in 0..time_vec.len() {
                    let xi = time_vec[i];
                    if xi == 0.0 {
                        continue;
                    }
                    target[[slices.time.start + i, sigma_idx]] += weight * xi;
                    target[[sigma_idx, slices.time.start + i]] += weight * xi;
                }
            }

            let mean_sigma_weight = primary_hessian[[
                LATENT_SURVIVAL_PRIMARY_MU,
                LATENT_SURVIVAL_PRIMARY_LOG_SIGMA,
            ]];
            if mean_sigma_weight != 0.0 {
                for j in 0..mean_vec.len() {
                    let xj = mean_vec[j];
                    if xj == 0.0 {
                        continue;
                    }
                    target[[slices.mean.start + j, sigma_idx]] += mean_sigma_weight * xj;
                    target[[sigma_idx, slices.mean.start + j]] += mean_sigma_weight * xj;
                }
            }
        }
        Ok(())
    }
2294
2295 fn evaluate_exact_newton_joint_gradient_dense(
2296 &self,
2297 block_states: &[ParameterBlockState],
2298 ) -> Result<(f64, Array1<f64>), String> {
2299 let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
2300 let q_right = self.time_q_right(block_states)?;
2301 let sigma = self.latent_sd(block_states)?;
2302 let slices = self.joint_slices();
2303 let include_log_sigma = slices.log_sigma.is_some();
2304 let total = slices.total;
2305 let acc = deterministic_latent_survival_row_reduction(
2306 self.event_target.len(),
2307 || LatentSurvivalJointGradientAccum {
2308 ll: 0.0,
2309 gradient: Array1::<f64>::zeros(total),
2310 },
2311 |row_idx, acc| {
2312 let wi = self.weights[row_idx];
2313 if wi <= MIN_WEIGHT {
2314 return Ok(());
2315 }
2316 let row = self.build_row_at(
2317 row_idx,
2318 q_entry[row_idx],
2319 q_exit[row_idx],
2320 qdot_exit[row_idx],
2321 q_right[row_idx],
2322 )?;
2323 let (row_ll, primary_gradient, _) = latent_survival_row_primary_gradient_hessian(
2324 &self.quadctx,
2325 &row,
2326 q_entry[row_idx],
2327 q_exit[row_idx],
2328 qdot_exit[row_idx],
2329 q_right[row_idx],
2330 mu[row_idx],
2331 sigma,
2332 include_log_sigma,
2333 )?;
2334 acc.ll += wi * row_ll;
2335 self.add_pullback_primary_gradient(
2336 &mut acc.gradient,
2337 row_idx,
2338 &slices,
2339 &primary_gradient,
2340 wi,
2341 )?;
2342 Ok(())
2343 },
2344 |total_acc, chunk_acc| {
2345 total_acc.ll += chunk_acc.ll;
2346 total_acc.gradient += &chunk_acc.gradient;
2347 },
2348 )?;
2349 Ok((acc.ll, acc.gradient))
2350 }
2351
2352 pub fn offset_channel_residuals(
2381 &self,
2382 block_states: &[ParameterBlockState],
2383 ) -> Result<crate::survival::OffsetChannelResiduals, String> {
2384 let n = self.event_target.len();
2385 if block_states.is_empty() {
2386 log::warn!(
2391 "LatentSurvivalFamily::offset_channel_residuals: block_states is empty \
2392 (degraded fit); returning zero offset residuals (n={n})"
2393 );
2394 return Ok(crate::survival::OffsetChannelResiduals {
2395 exit: Array1::<f64>::zeros(n),
2396 entry: Array1::<f64>::zeros(n),
2397 derivative: Array1::<f64>::zeros(n),
2398 right: Array1::<f64>::zeros(n),
2399 });
2400 }
2401 let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
2402 let q_right = self.time_q_right(block_states)?;
2403 let sigma = self.latent_sd(block_states)?;
2404 let include_log_sigma = self.joint_slices().log_sigma.is_some();
2405 let mut entry = Array1::<f64>::zeros(n);
2406 let mut exit = Array1::<f64>::zeros(n);
2407 let mut derivative = Array1::<f64>::zeros(n);
2408 let mut right = Array1::<f64>::zeros(n);
2409 for row_idx in 0..n {
2410 let wi = self.weights[row_idx];
2411 if wi <= MIN_WEIGHT {
2412 continue;
2413 }
2414 let row = self.build_row_at(
2415 row_idx,
2416 q_entry[row_idx],
2417 q_exit[row_idx],
2418 qdot_exit[row_idx],
2419 q_right[row_idx],
2420 )?;
2421 let (_, primary_gradient, _) = latent_survival_row_primary_gradient_hessian(
2422 &self.quadctx,
2423 &row,
2424 q_entry[row_idx],
2425 q_exit[row_idx],
2426 qdot_exit[row_idx],
2427 q_right[row_idx],
2428 mu[row_idx],
2429 sigma,
2430 include_log_sigma,
2431 )?;
2432 entry[row_idx] = -wi * primary_gradient[LATENT_SURVIVAL_PRIMARY_Q_ENTRY];
2434 exit[row_idx] = -wi * primary_gradient[LATENT_SURVIVAL_PRIMARY_Q_EXIT];
2435 derivative[row_idx] = -wi * primary_gradient[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT];
2436 right[row_idx] = -wi * primary_gradient[LATENT_SURVIVAL_PRIMARY_Q_RIGHT];
2443 }
2444 Ok(crate::survival::OffsetChannelResiduals {
2445 exit,
2446 entry,
2447 derivative,
2448 right,
2449 })
2450 }
2451
2452 fn add_pullback_primary_block_diagonals(
2457 &self,
2458 row: usize,
2459 primary_hessian: &Array2<f64>,
2460 time_target: &mut Array2<f64>,
2461 mean_target: &mut Array2<f64>,
2462 log_sigma_target: Option<&mut Array2<f64>>,
2463 ) -> Result<(), String> {
2464 let h = primary_hessian;
2465 dense_outer_accumulate(
2469 time_target,
2470 h[[
2471 LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
2472 LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
2473 ]],
2474 self.x_time_entry.row(row),
2475 );
2476 dense_outer_accumulate(
2477 time_target,
2478 h[[
2479 LATENT_SURVIVAL_PRIMARY_Q_EXIT,
2480 LATENT_SURVIVAL_PRIMARY_Q_EXIT,
2481 ]],
2482 self.x_time_exit.row(row),
2483 );
2484 dense_outer_accumulate(
2485 time_target,
2486 h[[
2487 LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2488 LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2489 ]],
2490 self.x_time_derivative_exit.row(row),
2491 );
2492 dense_outer_accumulate(
2493 time_target,
2494 h[[
2495 LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
2496 LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
2497 ]],
2498 self.x_time_right.row(row),
2499 );
2500 for (a, b, lhs, rhs) in [
2501 (
2502 LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
2503 LATENT_SURVIVAL_PRIMARY_Q_EXIT,
2504 &self.x_time_entry,
2505 &self.x_time_exit,
2506 ),
2507 (
2508 LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
2509 LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2510 &self.x_time_entry,
2511 &self.x_time_derivative_exit,
2512 ),
2513 (
2514 LATENT_SURVIVAL_PRIMARY_Q_EXIT,
2515 LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2516 &self.x_time_exit,
2517 &self.x_time_derivative_exit,
2518 ),
2519 (
2520 LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
2521 LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
2522 &self.x_time_entry,
2523 &self.x_time_right,
2524 ),
2525 (
2526 LATENT_SURVIVAL_PRIMARY_Q_EXIT,
2527 LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
2528 &self.x_time_exit,
2529 &self.x_time_right,
2530 ),
2531 (
2532 LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2533 LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
2534 &self.x_time_derivative_exit,
2535 &self.x_time_right,
2536 ),
2537 ] {
2538 let weight = h[[a, b]];
2539 if weight == 0.0 {
2540 continue;
2541 }
2542 dense_symmetric_cross_accumulate(time_target, weight, lhs.row(row), rhs.row(row));
2543 }
2544 let mean_weight = h[[LATENT_SURVIVAL_PRIMARY_MU, LATENT_SURVIVAL_PRIMARY_MU]];
2546 self.x_mean
2547 .syr_row_into_view(row, mean_weight, mean_target.view_mut())
2548 .map_err(|error| {
2549 format!(
2550 "latent survival mean block-diagonal pullback dimension mismatch: row={row}, mean_target_dim={:?}, x_mean_cols={}, error={error}",
2551 mean_target.dim(),
2552 self.x_mean.ncols()
2553 )
2554 })?;
2555 if let Some(target) = log_sigma_target {
2557 target[[0, 0]] += h[[
2558 LATENT_SURVIVAL_PRIMARY_LOG_SIGMA,
2559 LATENT_SURVIVAL_PRIMARY_LOG_SIGMA,
2560 ]];
2561 }
2562 Ok(())
2563 }
2564
2565 fn evaluate_exact_newton_block_diagonals(
2570 &self,
2571 block_states: &[ParameterBlockState],
2572 ) -> Result<
2573 (
2574 f64,
2575 Array1<f64>,
2576 Array2<f64>,
2577 Array2<f64>,
2578 Option<Array2<f64>>,
2579 ),
2580 String,
2581 > {
2582 let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
2583 let q_right = self.time_q_right(block_states)?;
2584 let sigma = self.latent_sd(block_states)?;
2585 let slices = self.joint_slices();
2586 let include_log_sigma = slices.log_sigma.is_some();
2587 let mut ll = 0.0;
2588 let mut gradient = Array1::<f64>::zeros(slices.total);
2589 let p_time = slices.time.len();
2590 let p_mean = slices.mean.len();
2591 let mut hess_time = Array2::<f64>::zeros((p_time, p_time));
2592 let mut hess_mean = Array2::<f64>::zeros((p_mean, p_mean));
2593 let mut hess_log_sigma = if include_log_sigma {
2594 Some(Array2::<f64>::zeros((1, 1)))
2595 } else {
2596 None
2597 };
2598 for row_idx in 0..self.event_target.len() {
2599 let wi = self.weights[row_idx];
2600 if wi <= MIN_WEIGHT {
2601 continue;
2602 }
2603 let row = self.build_row_at(
2604 row_idx,
2605 q_entry[row_idx],
2606 q_exit[row_idx],
2607 qdot_exit[row_idx],
2608 q_right[row_idx],
2609 )?;
2610 let (row_ll, primary_gradient, primary_hessian) =
2611 latent_survival_row_primary_gradient_hessian(
2612 &self.quadctx,
2613 &row,
2614 q_entry[row_idx],
2615 q_exit[row_idx],
2616 qdot_exit[row_idx],
2617 q_right[row_idx],
2618 mu[row_idx],
2619 sigma,
2620 include_log_sigma,
2621 )?;
2622 ll += wi * row_ll;
2623 self.add_pullback_primary_gradient(
2624 &mut gradient,
2625 row_idx,
2626 &slices,
2627 &primary_gradient,
2628 wi,
2629 )?;
2630 self.add_pullback_primary_block_diagonals(
2631 row_idx,
2632 &(wi * primary_hessian),
2633 &mut hess_time,
2634 &mut hess_mean,
2635 hess_log_sigma.as_mut(),
2636 )?;
2637 }
2638 Ok((ll, gradient, hess_time, hess_mean, hess_log_sigma))
2639 }
2640
    /// Exact-Newton evaluation of the joint objective with a dense Hessian over the
    /// whole flat coefficient vector (time, mean and, when present, log-sigma blocks).
    ///
    /// Per row whose weight exceeds `MIN_WEIGHT`:
    /// * `latent_survival_row_primary_gradient_hessian` supplies the row's primary
    ///   log-likelihood, gradient and Hessian;
    /// * the weighted values are pulled back through the row's design vectors into
    ///   the flat space of length `slices.total`.
    ///
    /// Rows are processed by `deterministic_latent_survival_row_reduction`, and the
    /// per-chunk accumulators are merged by plain summation.
    ///
    /// The Hessian sign follows whatever `latent_survival_row_primary_gradient_hessian`
    /// returns. NOTE(review): neighbouring code treats it as a negative Hessian;
    /// confirm.
    ///
    /// Returns `(log_likelihood, gradient, hessian)`. Any row-level error propagates
    /// out via `?`.
    fn evaluate_exact_newton_joint_dense(
        &self,
        block_states: &[ParameterBlockState],
    ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
        let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
        let q_right = self.time_q_right(block_states)?;
        let sigma = self.latent_sd(block_states)?;
        let slices = self.joint_slices();
        let include_log_sigma = slices.log_sigma.is_some();
        let total = slices.total;
        let acc = deterministic_latent_survival_row_reduction(
            self.event_target.len(),
            // Fresh zeroed accumulator for each chunk of rows.
            || LatentSurvivalJointDenseAccum {
                ll: 0.0,
                gradient: Array1::<f64>::zeros(total),
                hessian: Array2::<f64>::zeros((total, total)),
            },
            |row_idx, acc| {
                let wi = self.weights[row_idx];
                // Rows at or below the weight floor contribute nothing.
                if wi <= MIN_WEIGHT {
                    return Ok(());
                }
                let row = self.build_row_at(
                    row_idx,
                    q_entry[row_idx],
                    q_exit[row_idx],
                    qdot_exit[row_idx],
                    q_right[row_idx],
                )?;
                let (row_ll, primary_gradient, primary_hessian) =
                    latent_survival_row_primary_gradient_hessian(
                        &self.quadctx,
                        &row,
                        q_entry[row_idx],
                        q_exit[row_idx],
                        qdot_exit[row_idx],
                        q_right[row_idx],
                        mu[row_idx],
                        sigma,
                        include_log_sigma,
                    )?;
                acc.ll += wi * row_ll;
                // The gradient pullback applies `wi` itself; the Hessian is pre-weighted.
                self.add_pullback_primary_gradient(
                    &mut acc.gradient,
                    row_idx,
                    &slices,
                    &primary_gradient,
                    wi,
                )?;
                self.add_pullback_primary_hessian(
                    &mut acc.hessian,
                    row_idx,
                    &slices,
                    &(wi * primary_hessian),
                )?;
                Ok(())
            },
            // Merge one chunk's partial sums into the running total.
            |total_acc, chunk_acc| {
                total_acc.ll += chunk_acc.ll;
                total_acc.gradient += &chunk_acc.gradient;
                total_acc.hessian += &chunk_acc.hessian;
            },
        )?;
        Ok((acc.ll, acc.gradient, acc.hessian))
    }
2706
    /// Directional derivative of the dense joint Hessian along the flat coefficient
    /// direction `d_beta_flat`.
    ///
    /// For each row whose weight exceeds `MIN_WEIGHT`:
    /// * the flat direction is mapped onto the row's primary channels with
    ///   `row_primary_direction_from_flat`;
    /// * the row's third-derivative tensor is contracted with that primary direction
    ///   by `latent_survival_row_primary_third_contracted`;
    /// * the resulting weighted primary-space matrix is pulled back into coefficient
    ///   space.
    ///
    /// Per-chunk results from `deterministic_latent_survival_row_reduction` are summed.
    ///
    /// # Errors
    /// Returns an error when `d_beta_flat.len() != slices.total`, or when any row fails
    /// to build or evaluate.
    fn exact_newton_joint_hessian_directional_derivative_dense(
        &self,
        block_states: &[ParameterBlockState],
        d_beta_flat: &Array1<f64>,
    ) -> Result<Array2<f64>, String> {
        let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
        let q_right = self.time_q_right(block_states)?;
        let sigma = self.latent_sd(block_states)?;
        let slices = self.joint_slices();
        if d_beta_flat.len() != slices.total {
            return Err(format!(
                "latent survival joint dH direction length mismatch: got {}, expected {}",
                d_beta_flat.len(),
                slices.total
            ));
        }
        let include_log_sigma = slices.log_sigma.is_some();
        let total = slices.total;
        let acc = deterministic_latent_survival_row_reduction(
            self.event_target.len(),
            || LatentSurvivalDenseHessianAccum {
                hessian: Array2::<f64>::zeros((total, total)),
            },
            |row_idx, acc| {
                let wi = self.weights[row_idx];
                // Rows at or below the weight floor contribute nothing.
                if wi <= MIN_WEIGHT {
                    return Ok(());
                }
                let row = self.build_row_at(
                    row_idx,
                    q_entry[row_idx],
                    q_exit[row_idx],
                    qdot_exit[row_idx],
                    q_right[row_idx],
                )?;
                // Project the coefficient-space direction onto this row's primary channels.
                let direction = self.row_primary_direction_from_flat(row_idx, &slices, d_beta_flat);
                let third = latent_survival_row_primary_third_contracted(
                    &self.quadctx,
                    &row,
                    q_entry[row_idx],
                    q_exit[row_idx],
                    qdot_exit[row_idx],
                    q_right[row_idx],
                    mu[row_idx],
                    sigma,
                    &direction,
                    include_log_sigma,
                )?;
                self.add_pullback_primary_hessian(
                    &mut acc.hessian,
                    row_idx,
                    &slices,
                    &(wi * third),
                )?;
                Ok(())
            },
            |total_acc, chunk_acc| {
                total_acc.hessian += &chunk_acc.hessian;
            },
        )?;
        Ok(acc.hessian)
    }
2769
    /// Second directional derivative of the dense joint Hessian along the flat
    /// coefficient directions `d_beta_u_flat` and `d_beta_v_flat`.
    ///
    /// For each row whose weight exceeds `MIN_WEIGHT`:
    /// * both flat directions are mapped onto the row's primary channels;
    /// * `latent_survival_row_primary_fourth_contracted` contracts the row's
    ///   fourth-derivative tensor with the two primary directions;
    /// * the weighted primary-space matrix is pulled back into coefficient space.
    ///
    /// Per-chunk results are summed.
    ///
    /// # Errors
    /// Returns an error when either direction's length differs from `slices.total`, or
    /// when any row fails to build or evaluate.
    fn exact_newton_joint_hessian_second_directional_derivative_dense(
        &self,
        block_states: &[ParameterBlockState],
        d_beta_u_flat: &Array1<f64>,
        d_beta_v_flat: &Array1<f64>,
    ) -> Result<Array2<f64>, String> {
        let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
        let q_right = self.time_q_right(block_states)?;
        let sigma = self.latent_sd(block_states)?;
        let slices = self.joint_slices();
        if d_beta_u_flat.len() != slices.total || d_beta_v_flat.len() != slices.total {
            return Err(format!(
                "latent survival joint d2H direction length mismatch: got {} and {}, expected {}",
                d_beta_u_flat.len(),
                d_beta_v_flat.len(),
                slices.total
            ));
        }
        let include_log_sigma = slices.log_sigma.is_some();
        let total = slices.total;
        let acc = deterministic_latent_survival_row_reduction(
            self.event_target.len(),
            || LatentSurvivalDenseHessianAccum {
                hessian: Array2::<f64>::zeros((total, total)),
            },
            |row_idx, acc| {
                let wi = self.weights[row_idx];
                // Rows at or below the weight floor contribute nothing.
                if wi <= MIN_WEIGHT {
                    return Ok(());
                }
                let row = self.build_row_at(
                    row_idx,
                    q_entry[row_idx],
                    q_exit[row_idx],
                    qdot_exit[row_idx],
                    q_right[row_idx],
                )?;
                // Both directions projected onto this row's primary channels.
                let direction_u =
                    self.row_primary_direction_from_flat(row_idx, &slices, d_beta_u_flat);
                let direction_v =
                    self.row_primary_direction_from_flat(row_idx, &slices, d_beta_v_flat);
                let fourth = latent_survival_row_primary_fourth_contracted(
                    &self.quadctx,
                    &row,
                    q_entry[row_idx],
                    q_exit[row_idx],
                    qdot_exit[row_idx],
                    q_right[row_idx],
                    mu[row_idx],
                    sigma,
                    &direction_u,
                    &direction_v,
                    include_log_sigma,
                )?;
                self.add_pullback_primary_hessian(
                    &mut acc.hessian,
                    row_idx,
                    &slices,
                    &(wi * fourth),
                )?;
                Ok(())
            },
            |total_acc, chunk_acc| {
                total_acc.hessian += &chunk_acc.hessian;
            },
        )?;
        Ok(acc.hessian)
    }
2838}
2839
2840fn log_kernel_ratio(
2841 bundle: &crate::survival::lognormal_kernel::LogLognormalKernelBundle,
2842 num: usize,
2843 den: usize,
2844) -> f64 {
2845 let delta = bundle.get(num) - bundle.get(den);
2846 if delta.is_finite() {
2847 delta.exp()
2848 } else if delta > 0.0 {
2849 f64::INFINITY
2850 } else {
2851 0.0
2852 }
2853}
2854
2855fn logk_q_derivatives(
2856 quadctx: &QuadratureContext,
2857 k: usize,
2858 mass: f64,
2859 mu: f64,
2860 sigma: f64,
2861) -> Result<(f64, f64, IntegratedExpectationMode), LatentSurvivalError> {
2862 if mass <= 0.0 {
2863 return Ok((0.0, 0.0, IntegratedExpectationMode::ExactClosedForm));
2864 }
2865 let bundle = log_kernel_bundle(quadctx, mass, mu, sigma, k + 2).map_err(|e| {
2866 LatentSurvivalError::NumericalFailure {
2867 reason: format!("latent survival kernel evaluation failed: {e}"),
2868 }
2869 })?;
2870 let r1 = log_kernel_ratio(&bundle, k + 1, k);
2871 let r2 = log_kernel_ratio(&bundle, k + 2, k);
2872 let d1 = -mass * r1;
2873 let d2 = d1 + mass * mass * (r2 - r1 * r1);
2874 Ok((d1, d2, bundle.mode))
2875}
2876
/// Time-channel jet of one latent-survival row's log-likelihood: the gradient and
/// negative Hessian with respect to `q_entry` and `q_exit`.
///
/// Notation: `ℓ_k(q)` is the order-`k` log kernel at mass `exp(q)`, whose
/// q-derivatives `logk_q_derivatives(k, ..)` returns.
///
/// * Entry (every event type): contributes `-ℓ_0(q_entry)`.
/// * `RightCensored`: the exit contributes `+ℓ_0(q_exit)`.
/// * `ExactEvent` without unloaded hazard: the exit contributes `q_exit + ℓ_1(q_exit)`
///   (hence the `+1` in the gradient).
/// * `ExactEvent` with `hazard_unloaded > 0`: the exit term is the log of a two-term
///   mixture, `hazard_loaded·K_1 + hazard_unloaded·K_0`, evaluated in log space.
///   - The gradient is the weight-averaged component log-derivative.
///   - The second derivative is `Σ w (d2 + d1²) - grad²`.
///   - The loaded component's log-derivative carries the `+1` described above.
/// * `IntervalCensored`: not supported.
///
/// # Errors
/// * `NumericalFailure` if a kernel evaluation fails, if `qdot_exit` is not positive
///   and finite for an exact event, or if the mixture normaliser is non-finite or
///   non-positive.
/// * `UnsupportedConfiguration` for interval-censored rows.
fn latent_survival_time_jet(
    quadctx: &QuadratureContext,
    row: &LatentSurvivalRow,
    qdot_exit: f64,
    mu: f64,
    sigma: f64,
) -> Result<LatentSurvivalTimeJet, LatentSurvivalError> {
    // The entry term is -ℓ_0(q_entry) for every event type, so its signs are fixed here.
    let (entry_d1, entry_d2, _) = logk_q_derivatives(quadctx, 0, row.mass_entry, mu, sigma)?;
    match row.event_type {
        LatentSurvivalEventType::RightCensored => {
            let (exit_d1, exit_d2, _) = logk_q_derivatives(quadctx, 0, row.mass_exit, mu, sigma)?;
            Ok(LatentSurvivalTimeJet {
                grad_entry: -entry_d1,
                grad_exit: exit_d1,
                neg_hess_entry: entry_d2,
                neg_hess_exit: -exit_d2,
            })
        }
        LatentSurvivalEventType::ExactEvent => {
            if !(qdot_exit.is_finite() && qdot_exit > 0.0) {
                return Err(LatentSurvivalError::NumericalFailure {
                    reason: format!(
                        "latent survival requires positive finite baseline hazard derivative, got {qdot_exit}"
                    ),
                });
            }
            if row.hazard_unloaded > 0.0 {
                // Mixture numerator: combine the loaded (K_1) and unloaded (K_0)
                // components with log-sum-exp style weights to avoid over/underflow.
                // NOTE(review): the kernel at row.mass_exit is evaluated three times
                // here (this bundle plus the two logk_q_derivatives calls).
                let bundle =
                    log_kernel_bundle(quadctx, row.mass_exit, mu, sigma, 3).map_err(|e| {
                        LatentSurvivalError::NumericalFailure {
                            reason: format!("latent survival kernel evaluation failed: {e}"),
                        }
                    })?;
                let (unloaded_d1, unloaded_d2, _) =
                    logk_q_derivatives(quadctx, 0, row.mass_exit, mu, sigma)?;
                let (loaded_log_d1, loaded_d2, _) =
                    logk_q_derivatives(quadctx, 1, row.mass_exit, mu, sigma)?;
                // The loaded component's log-derivative gains +1 on top of ℓ_1's.
                let loaded_d1 = 1.0 + loaded_log_d1;
                let log_loaded = row.hazard_loaded.ln() + bundle.get(1);
                let log_unloaded = row.hazard_unloaded.ln() + bundle.get(0);
                // Shift by the larger log term so at least one weight is exactly 1.
                let shift = log_loaded.max(log_unloaded);
                let loaded_weight = (log_loaded - shift).exp();
                let unloaded_weight = (log_unloaded - shift).exp();
                let normalizer = loaded_weight + unloaded_weight;
                if !(normalizer.is_finite() && normalizer > 0.0) {
                    return Err(LatentSurvivalError::NumericalFailure {
                        reason: "latent survival exact-event numerator became non-finite under loaded/unloaded hazard decomposition"
                            .to_string(),
                    });
                }
                let w_loaded = loaded_weight / normalizer;
                let w_unloaded = unloaded_weight / normalizer;
                // d/dq log Σ c_i = Σ w_i d1_i;  d²/dq² log Σ c_i = Σ w_i (d2_i + d1_i²) - (Σ w_i d1_i)².
                let grad_exit = w_loaded * loaded_d1 + w_unloaded * unloaded_d1;
                let d2_exit = w_loaded * (loaded_d2 + loaded_d1 * loaded_d1)
                    + w_unloaded * (unloaded_d2 + unloaded_d1 * unloaded_d1)
                    - grad_exit * grad_exit;
                Ok(LatentSurvivalTimeJet {
                    grad_entry: -entry_d1,
                    grad_exit,
                    neg_hess_entry: entry_d2,
                    neg_hess_exit: -d2_exit,
                })
            } else {
                // Pure loaded hazard: exit term is q_exit + ℓ_1(q_exit).
                let (exit_d1, exit_d2, _) =
                    logk_q_derivatives(quadctx, 1, row.mass_exit, mu, sigma)?;
                Ok(LatentSurvivalTimeJet {
                    grad_entry: -entry_d1,
                    grad_exit: 1.0 + exit_d1,
                    neg_hess_entry: entry_d2,
                    neg_hess_exit: -exit_d2,
                })
            }
        }
        LatentSurvivalEventType::IntervalCensored => {
            Err(LatentSurvivalError::UnsupportedConfiguration {
                reason:
                    "latent survival dynamic time derivatives do not implement interval censoring"
                        .to_string(),
            })
        }
    }
}
2959
2960fn dense_outer_accumulate<S>(
2961 target: &mut ndarray::ArrayBase<S, ndarray::Ix2>,
2962 weight: f64,
2963 x: ArrayView1<'_, f64>,
2964) where
2965 S: ndarray::DataMut<Elem = f64>,
2966{
2967 for a in 0..x.len() {
2968 let xa = x[a];
2969 if xa == 0.0 {
2970 continue;
2971 }
2972 for b in 0..x.len() {
2973 let xb = x[b];
2974 if xb == 0.0 {
2975 continue;
2976 }
2977 target[[a, b]] += weight * xa * xb;
2978 }
2979 }
2980}
2981
2982fn dense_symmetric_cross_accumulate<S>(
2983 target: &mut ndarray::ArrayBase<S, ndarray::Ix2>,
2984 weight: f64,
2985 x: ArrayView1<'_, f64>,
2986 y: ArrayView1<'_, f64>,
2987) where
2988 S: ndarray::DataMut<Elem = f64>,
2989{
2990 for a in 0..x.len() {
2991 let xa = x[a];
2992 let ya = y[a];
2993 if xa == 0.0 && ya == 0.0 {
2994 continue;
2995 }
2996 for b in 0..x.len() {
2997 let xb = x[b];
2998 let yb = y[b];
2999 let contribution = xa * yb + ya * xb;
3000 if contribution == 0.0 {
3001 continue;
3002 }
3003 target[[a, b]] += weight * contribution;
3004 }
3005 }
3006}
3007
3008fn build_latent_survival_row(
3009 row_index: usize,
3010 hazard_loading: HazardLoading,
3011 event_type: LatentSurvivalEventType,
3012 q_entry: f64,
3013 q_exit: f64,
3014 qdot_exit: f64,
3015 q_right: f64,
3016 unloaded_mass_entry: f64,
3017 unloaded_mass_exit: f64,
3018 unloaded_mass_right: f64,
3019 unloaded_hazard_exit: f64,
3020) -> Result<LatentSurvivalRow, LatentSurvivalError> {
3021 if !(q_entry.is_finite() && q_exit.is_finite()) {
3022 return Err(LatentSurvivalError::NumericalFailure {
3023 reason: format!(
3024 "latent survival requires finite q_entry and q_exit, got q_entry={q_entry}, q_exit={q_exit}"
3025 ),
3026 });
3027 }
3028 if q_exit < q_entry {
3029 return Err(LatentSurvivalError::NumericalFailure {
3030 reason: format!(
3031 "latent survival requires q_exit >= q_entry so cumulative mass is monotone, got q_entry={q_entry}, q_exit={q_exit}"
3032 ),
3033 });
3034 }
3035 if !(unloaded_mass_entry.is_finite()
3036 && unloaded_mass_exit.is_finite()
3037 && unloaded_hazard_exit.is_finite())
3038 {
3039 return Err(LatentSurvivalError::InvalidDataset {
3040 reason: format!(
3041 "latent survival requires finite unloaded components, got entry_mass={unloaded_mass_entry}, exit_mass={unloaded_mass_exit}, exit_hazard={unloaded_hazard_exit}"
3042 ),
3043 });
3044 }
3045 if unloaded_mass_entry < 0.0
3046 || unloaded_mass_exit < unloaded_mass_entry
3047 || unloaded_hazard_exit < 0.0
3048 {
3049 return Err(LatentSurvivalError::InvalidDataset {
3050 reason: format!(
3051 "latent survival requires unloaded masses/hazard to be non-negative and monotone, got entry_mass={unloaded_mass_entry}, exit_mass={unloaded_mass_exit}, exit_hazard={unloaded_hazard_exit}"
3052 ),
3053 });
3054 }
3055 let mass_entry = q_entry.exp();
3056 let mass_exit = q_exit.exp();
3057 let row = match event_type {
3058 LatentSurvivalEventType::RightCensored => {
3059 validate_unloaded_components_for_loading(
3060 "latent-survival",
3061 row_index,
3062 hazard_loading,
3063 unloaded_mass_entry,
3064 unloaded_mass_exit,
3065 Some(unloaded_hazard_exit),
3066 )?;
3067 LatentSurvivalRow::right_censored(
3068 mass_entry,
3069 mass_exit,
3070 unloaded_mass_entry,
3071 unloaded_mass_exit,
3072 )
3073 }
3074 LatentSurvivalEventType::ExactEvent => {
3075 validate_unloaded_components_for_loading(
3076 "latent-survival",
3077 row_index,
3078 hazard_loading,
3079 unloaded_mass_entry,
3080 unloaded_mass_exit,
3081 Some(unloaded_hazard_exit),
3082 )?;
3083 LatentSurvivalRow::exact_event(
3084 mass_entry,
3085 mass_exit,
3086 unloaded_mass_entry,
3087 unloaded_mass_exit,
3088 mass_exit
3089 * if qdot_exit.is_finite() && qdot_exit > 0.0 {
3090 qdot_exit
3091 } else {
3092 return Err(LatentSurvivalError::NumericalFailure {
3093 reason: format!(
3094 "latent survival exact event requires positive finite baseline hazard derivative, got {qdot_exit}"
3095 ),
3096 });
3097 },
3098 unloaded_hazard_exit,
3099 )
3100 }
3101 LatentSurvivalEventType::IntervalCensored => {
3102 if !q_right.is_finite() {
3110 return Err(LatentSurvivalError::NumericalFailure {
3111 reason: format!(
3112 "latent survival interval row {} requires a finite q_right, got {q_right}",
3113 row_index + 1
3114 ),
3115 });
3116 }
3117 if q_right < q_exit {
3118 return Err(LatentSurvivalError::NumericalFailure {
3119 reason: format!(
3120 "latent survival interval row {} requires q_right >= q_exit (R >= L) so the \
3121 survival-mass difference is non-negative, got q_left={q_exit}, q_right={q_right}",
3122 row_index + 1
3123 ),
3124 });
3125 }
3126 if !(unloaded_mass_right.is_finite()) || unloaded_mass_right < unloaded_mass_exit {
3127 return Err(LatentSurvivalError::InvalidDataset {
3128 reason: format!(
3129 "latent survival interval row {} requires a finite unloaded right mass >= unloaded left mass, got left={unloaded_mass_exit}, right={unloaded_mass_right}",
3130 row_index + 1
3131 ),
3132 });
3133 }
3134 let mass_right = q_right.exp();
3138 LatentSurvivalRow::interval_censored(
3139 mass_entry,
3140 mass_exit,
3141 mass_right,
3142 unloaded_mass_entry,
3143 unloaded_mass_exit,
3144 unloaded_mass_right,
3145 )
3146 }
3147 };
3148 row.validate()
3149 .map_err(|e| LatentSurvivalError::InvalidDataset {
3150 reason: e.to_string(),
3151 })?;
3152 Ok(row)
3153}
3154
/// Chain-rule factors that lift log-survival derivatives to the latent-binary
/// log-likelihood `ℓ(L)`, where `L` is a row's log survival.
///
/// Produced by `binary_from_log_survival`:
/// * `event == 0`: `ℓ(L) = L`.
/// * `event == 1`: `ℓ(L) = ln(1 - e^L)`.
///
/// Primes denote derivatives with respect to `L`.
#[derive(Clone, Copy)]
struct BinaryFromLogSurvival {
    /// Binary log-likelihood `ℓ(L)`.
    log_lik: f64,
    /// `ℓ'(L)`: multiplies the log-survival gradient.
    grad_scale: f64,
    /// Coefficient on the log-survival negative Hessian; always equal to `grad_scale`.
    neg_hess_scale: f64,
    /// `-ℓ''(L)`: coefficient on the outer product `∇L ∇Lᵀ` (non-negative).
    outer_scale: f64,
    /// `ℓ''(L)`: first derivative of `grad_scale` in `L`.
    grad_scale_prime: f64,
    /// `ℓ'''(L)`: second derivative of `grad_scale` in `L`.
    grad_scale_second: f64,
    /// `-ℓ'''(L)`: first derivative of `outer_scale` in `L`.
    outer_scale_prime: f64,
    /// `-ℓ''''(L)`: second derivative of `outer_scale` in `L`.
    outer_scale_second: f64,
}
3181
/// Derivatives of `ℓ(L) = ln p` with respect to the log survival `L`, where
/// `p = 1 - S = 1 - e^L`.
///
/// Takes `survival = S` and `event_prob = p`, and returns
/// `(ℓ, ℓ', ℓ'', ℓ''', ℓ'''')`:
/// * `ℓ'    = -S / p`
/// * `ℓ''   = -S / p²`
/// * `ℓ'''  = -S (1 + S) / p³`
/// * `ℓ'''' = -(S + 4S² + S³) / p⁴`
#[inline]
fn binary_log_survival_scales(survival: f64, event_prob: f64) -> (f64, f64, f64, f64, f64) {
    // Powers are built by repeated multiplication (not `powi`) so rounding is stable.
    let s1 = survival;
    let s_sq = s1 * s1;
    let s_cube = s_sq * s1;
    let p1 = event_prob;
    let p_sq = p1 * p1;
    let p_cube = p_sq * p1;
    let p_quad = p_cube * p1;
    (
        event_prob.ln(),
        -s1 / p1,
        -s1 / p_sq,
        -s1 * (1.0 + s1) / p_cube,
        -(s1 + 4.0 * s_sq + s_cube) / p_quad,
    )
}
3220
3221fn binary_from_log_survival(
3222 log_survival: f64,
3223 event: u8,
3224) -> Result<BinaryFromLogSurvival, LatentSurvivalError> {
3225 if event == 0 {
3226 return Ok(BinaryFromLogSurvival {
3228 log_lik: log_survival,
3229 grad_scale: 1.0,
3230 neg_hess_scale: 1.0,
3231 outer_scale: 0.0,
3232 grad_scale_prime: 0.0,
3233 grad_scale_second: 0.0,
3234 outer_scale_prime: 0.0,
3235 outer_scale_second: 0.0,
3236 });
3237 }
3238 if event != 1 {
3239 return Err(LatentSurvivalError::InvalidDataset {
3240 reason: format!("latent-binary requires event targets in {{0,1}}, got {event}"),
3241 });
3242 }
3243 const MAX_LOG_SURVIVAL: f64 = -1e-15;
3250 let log_survival = log_survival.min(MAX_LOG_SURVIVAL);
3251 let survival = log_survival.exp();
3252 let event_prob = 1.0 - survival;
3253 if !(event_prob.is_finite() && event_prob > 0.0) {
3254 return Err(LatentSurvivalError::NumericalFailure {
3255 reason: format!(
3256 "latent-binary encountered non-positive event probability from log survival {log_survival}"
3257 ),
3258 });
3259 }
3260 let (log_lik, ell_prime, ell_pp, ell_ppp, ell_pppp) =
3261 binary_log_survival_scales(survival, event_prob);
3262 let grad_scale = ell_prime;
3263 let neg_hess_scale = ell_prime; let outer_scale = -ell_pp;
3265 let grad_scale_prime = ell_pp;
3266 let grad_scale_second = ell_ppp;
3267 let outer_scale_prime = -ell_ppp;
3268 let outer_scale_second = -ell_pppp;
3269 assert!(
3274 (grad_scale - neg_hess_scale).abs() <= 1e-15 * grad_scale.abs().max(1.0),
3275 "binary_from_log_survival invariant: neg_hess_scale ({neg_hess_scale}) must equal grad_scale ({grad_scale}) so that grad_scale and the coefficient on neg_hessian share sign"
3276 );
3277 assert!(
3278 outer_scale >= 0.0 || !outer_scale.is_finite(),
3279 "binary_from_log_survival invariant: outer_scale (= -ℓ'') must be non-negative for event=1; got {outer_scale}"
3280 );
3281 Ok(BinaryFromLogSurvival {
3282 log_lik,
3283 grad_scale,
3284 neg_hess_scale,
3285 outer_scale,
3286 grad_scale_prime,
3287 grad_scale_second,
3288 outer_scale_prime,
3289 outer_scale_second,
3290 })
3291}
3292
3293impl LatentBinaryFamily {
3294 fn build_right_censored_row_at(
3300 &self,
3301 row_idx: usize,
3302 q_entry: f64,
3303 q_exit: f64,
3304 ) -> Result<LatentSurvivalRow, LatentSurvivalError> {
3305 build_latent_survival_row(
3306 row_idx,
3307 self.hazard_loading,
3308 LatentSurvivalEventType::RightCensored,
3309 q_entry,
3310 q_exit,
3311 1.0,
3312 q_exit,
3313 self.unloaded_mass_entry[row_idx],
3314 self.unloaded_mass_exit[row_idx],
3315 0.0,
3316 0.0,
3317 )
3318 }
3319
3320 fn joint_slices(&self) -> LatentSurvivalJointSlices {
3321 let p_time = self.x_time_exit.ncols();
3322 let p_mean = self.x_mean.ncols();
3323 LatentSurvivalJointSlices {
3324 time: 0..p_time,
3325 mean: p_time..p_time + p_mean,
3326 log_sigma: None,
3327 total: p_time + p_mean,
3328 }
3329 }
3330
3331 fn row_primary_direction_from_flat(
3332 &self,
3333 row: usize,
3334 slices: &LatentSurvivalJointSlices,
3335 d_beta_flat: &Array1<f64>,
3336 ) -> Array1<f64> {
3337 let mut out = Array1::<f64>::zeros(LATENT_SURVIVAL_PRIMARY_DIM);
3338 let d_time = d_beta_flat.slice(s![slices.time.clone()]);
3339 out[LATENT_SURVIVAL_PRIMARY_Q_ENTRY] = self.x_time_entry.row(row).dot(&d_time);
3340 out[LATENT_SURVIVAL_PRIMARY_Q_EXIT] = self.x_time_exit.row(row).dot(&d_time);
3341 out[LATENT_SURVIVAL_PRIMARY_MU] = self
3342 .x_mean
3343 .dot_row_view(row, d_beta_flat.slice(s![slices.mean.clone()]));
3344 out
3345 }
3346
3347 fn add_pullback_primary_gradient(
3348 &self,
3349 target: &mut Array1<f64>,
3350 row: usize,
3351 slices: &LatentSurvivalJointSlices,
3352 primary_gradient: &Array1<f64>,
3353 weight: f64,
3354 ) {
3355 for (primary_idx, time_vec) in [
3356 (LATENT_SURVIVAL_PRIMARY_Q_ENTRY, self.x_time_entry.row(row)),
3357 (LATENT_SURVIVAL_PRIMARY_Q_EXIT, self.x_time_exit.row(row)),
3358 ] {
3359 let scale = weight * primary_gradient[primary_idx];
3360 if scale == 0.0 {
3361 continue;
3362 }
3363 for i in 0..time_vec.len() {
3364 let xi = time_vec[i];
3365 if xi != 0.0 {
3366 target[slices.time.start + i] += scale * xi;
3367 }
3368 }
3369 }
3370
3371 let mean_scale = weight * primary_gradient[LATENT_SURVIVAL_PRIMARY_MU];
3372 if mean_scale != 0.0 {
3373 self.x_mean
3374 .axpy_row_into(
3375 row,
3376 mean_scale,
3377 &mut target.slice_mut(s![slices.mean.clone()]),
3378 )
3379 .unwrap_or_else(|error| {
3384 panic!(
3385 "latent binary mean gradient pullback dimension mismatch: row={row}, mean_slice={:?}, target_len={}, x_mean_cols={}, error={error}",
3386 slices.mean,
3387 target.len(),
3388 self.x_mean.ncols()
3389 )
3390 });
3391 }
3392 }
3393
    /// Pulls one row's pre-weighted primary-space Hessian back into the flat joint
    /// matrix `target`.
    ///
    /// Only the channels a latent-binary row uses are visited:
    /// * time×time: the entry and exit time channels, including their symmetric
    ///   cross term;
    /// * mean×mean: via the mean design's `syr_row_into_view`;
    /// * time×mean couplings: written into both off-diagonal blocks so that `target`
    ///   stays symmetric.
    ///
    /// # Panics
    /// If the mean design rejects the mean slice of `target`, or cannot produce row
    /// `row` (dimension mismatch).
    fn add_pullback_primary_hessian(
        &self,
        target: &mut Array2<f64>,
        row: usize,
        slices: &LatentSurvivalJointSlices,
        primary_hessian: &Array2<f64>,
    ) {
        // Time×time block; the inner scope bounds the mutable sub-view borrow of `target`.
        {
            let time_target = &mut target.slice_mut(s![slices.time.clone(), slices.time.clone()]);
            dense_outer_accumulate(
                time_target,
                primary_hessian[[
                    LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
                    LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
                ]],
                self.x_time_entry.row(row),
            );
            dense_outer_accumulate(
                time_target,
                primary_hessian[[
                    LATENT_SURVIVAL_PRIMARY_Q_EXIT,
                    LATENT_SURVIVAL_PRIMARY_Q_EXIT,
                ]],
                self.x_time_exit.row(row),
            );
            // Entry/exit coupling: h[e, x] * (x_e x_xᵀ + x_x x_eᵀ).
            dense_symmetric_cross_accumulate(
                time_target,
                primary_hessian[[
                    LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
                    LATENT_SURVIVAL_PRIMARY_Q_EXIT,
                ]],
                self.x_time_entry.row(row),
                self.x_time_exit.row(row),
            );
        }

        // Mean×mean block: rank-1 update with the mean design row.
        let mean_weight = primary_hessian[[LATENT_SURVIVAL_PRIMARY_MU, LATENT_SURVIVAL_PRIMARY_MU]];
        self.x_mean
            .syr_row_into_view(
                row,
                mean_weight,
                target.slice_mut(s![slices.mean.clone(), slices.mean.clone()]),
            )
            .unwrap_or_else(|error| {
                panic!(
                    "latent binary mean Hessian pullback dimension mismatch: row={row}, mean_slice={:?}, target_dim={:?}, x_mean_cols={}, error={error}",
                    slices.mean,
                    target.dim(),
                    self.x_mean.ncols()
                )
            });

        // Materialise the mean design row once for the time×mean cross terms below.
        let mean_row = self
            .x_mean
            .try_row_chunk(row..row + 1)
            .unwrap_or_else(|error| {
                panic!(
                    "latent binary mean pullback row chunk failed: row={row}, x_mean_rows={}, x_mean_cols={}, error={error}",
                    self.x_mean.nrows(),
                    self.x_mean.ncols()
                )
            });
        let mean_vec = mean_row.row(0);
        // Time×mean couplings, mirrored into both off-diagonal blocks.
        for (primary_idx, time_vec) in [
            (LATENT_SURVIVAL_PRIMARY_Q_ENTRY, self.x_time_entry.row(row)),
            (LATENT_SURVIVAL_PRIMARY_Q_EXIT, self.x_time_exit.row(row)),
        ] {
            let weight = primary_hessian[[primary_idx, LATENT_SURVIVAL_PRIMARY_MU]];
            if weight == 0.0 {
                continue;
            }
            for i in 0..time_vec.len() {
                let xi = time_vec[i];
                if xi == 0.0 {
                    continue;
                }
                for j in 0..mean_vec.len() {
                    let xj = mean_vec[j];
                    if xj == 0.0 {
                        continue;
                    }
                    target[[slices.time.start + i, slices.mean.start + j]] += weight * xi * xj;
                    target[[slices.mean.start + j, slices.time.start + i]] += weight * xj * xi;
                }
            }
        }
    }
3489
3490 fn evaluate_exact_newton_joint_dense(
3491 &self,
3492 block_states: &[ParameterBlockState],
3493 ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
3494 let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
3495 let slices = self.joint_slices();
3496 let mut ll = 0.0;
3497 let mut gradient = Array1::<f64>::zeros(slices.total);
3498 let mut hessian = Array2::<f64>::zeros((slices.total, slices.total));
3499 for row_idx in 0..self.event_target.len() {
3500 let wi = self.weights[row_idx];
3501 if wi <= MIN_WEIGHT {
3502 continue;
3503 }
3504 let row =
3505 self.build_right_censored_row_at(row_idx, q_entry[row_idx], q_exit[row_idx])?;
3506 let (row_log_survival, survival_gradient, survival_hessian) =
3507 latent_survival_row_primary_gradient_hessian(
3508 &self.quadctx,
3509 &row,
3510 q_entry[row_idx],
3511 q_exit[row_idx],
3512 1.0,
3513 q_exit[row_idx],
3514 mu[row_idx],
3515 self.latent_sd,
3516 false,
3517 )?;
3518 let binary = binary_from_log_survival(row_log_survival, self.event_target[row_idx])?;
3519 ll += wi * binary.log_lik;
3520 let primary_gradient = binary.grad_scale * &survival_gradient;
3521 let mut primary_hessian = binary.grad_scale * survival_hessian;
3522 for a in 0..LATENT_SURVIVAL_PRIMARY_DIM {
3523 for b in 0..LATENT_SURVIVAL_PRIMARY_DIM {
3524 primary_hessian[[a, b]] +=
3525 binary.outer_scale * survival_gradient[a] * survival_gradient[b];
3526 }
3527 }
3528 self.add_pullback_primary_gradient(
3529 &mut gradient,
3530 row_idx,
3531 &slices,
3532 &primary_gradient,
3533 wi,
3534 );
3535 self.add_pullback_primary_hessian(
3536 &mut hessian,
3537 row_idx,
3538 &slices,
3539 &(wi * primary_hessian),
3540 );
3541 }
3542 Ok((ll, gradient, hessian))
3543 }
3544
3545 pub fn offset_channel_residuals(
3557 &self,
3558 block_states: &[ParameterBlockState],
3559 ) -> Result<crate::survival::OffsetChannelResiduals, String> {
3560 let n = self.event_target.len();
3561 if block_states.is_empty() {
3562 log::warn!(
3563 "LatentBinaryFamily::offset_channel_residuals: block_states is empty \
3564 (degraded fit); returning zero offset residuals (n={n})"
3565 );
3566 return Ok(crate::survival::OffsetChannelResiduals {
3567 exit: Array1::<f64>::zeros(n),
3568 entry: Array1::<f64>::zeros(n),
3569 derivative: Array1::<f64>::zeros(n),
3570 right: Array1::<f64>::zeros(n),
3571 });
3572 }
3573 let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
3574 let mut entry = Array1::<f64>::zeros(n);
3575 let mut exit = Array1::<f64>::zeros(n);
3576 for row_idx in 0..n {
3577 let wi = self.weights[row_idx];
3578 if wi <= MIN_WEIGHT {
3579 continue;
3580 }
3581 let row =
3582 self.build_right_censored_row_at(row_idx, q_entry[row_idx], q_exit[row_idx])?;
3583 let (row_log_survival, survival_gradient, _) =
3584 latent_survival_row_primary_gradient_hessian(
3585 &self.quadctx,
3586 &row,
3587 q_entry[row_idx],
3588 q_exit[row_idx],
3589 1.0,
3590 q_exit[row_idx],
3591 mu[row_idx],
3592 self.latent_sd,
3593 false,
3594 )?;
3595 let binary = binary_from_log_survival(row_log_survival, self.event_target[row_idx])?;
3596 entry[row_idx] =
3598 -wi * binary.grad_scale * survival_gradient[LATENT_SURVIVAL_PRIMARY_Q_ENTRY];
3599 exit[row_idx] =
3600 -wi * binary.grad_scale * survival_gradient[LATENT_SURVIVAL_PRIMARY_Q_EXIT];
3601 }
3602 Ok(crate::survival::OffsetChannelResiduals {
3603 exit,
3604 entry,
3605 derivative: Array1::<f64>::zeros(n),
3606 right: Array1::<f64>::zeros(n),
3609 })
3610 }
3611
    /// Directional derivative, along `d_beta_flat`, of the latent-binary dense joint
    /// Hessian produced by `evaluate_exact_newton_joint_dense`.
    ///
    /// Notation, per row:
    /// * `L` is the log survival and `g = ∇L`;
    /// * `N` is the survival Hessian matrix from
    ///   `latent_survival_row_primary_gradient_hessian`;
    /// * `u` is the row's primary direction.
    ///
    /// The primary-space Hessian is `ℓ'·N + (-ℓ'')·g gᵀ`. Its derivative along `u` is
    ///
    ///   `ℓ'·T[u] + ℓ''·(g·u)·N + (-ℓ''')·(g·u)·g gᵀ + (-ℓ'')·(g_u gᵀ + g g_uᵀ)`
    ///
    /// where:
    /// * `T[u]` comes from `latent_survival_row_primary_third_contracted`;
    /// * `g_u = -N u`. NOTE(review): this sign implies `N` is the negative Hessian of
    ///   `L`; confirm.
    ///
    /// The scale factors map to the fields of `binary_from_log_survival`'s output:
    /// `ℓ' = grad_scale`, `ℓ'' = grad_scale_prime`, `-ℓ''' = outer_scale_prime`,
    /// `-ℓ'' = outer_scale`.
    ///
    /// # Errors
    /// Returns an error when `d_beta_flat.len() != slices.total`, or when any row fails
    /// to build or evaluate.
    fn exact_newton_joint_hessian_directional_derivative_dense(
        &self,
        block_states: &[ParameterBlockState],
        d_beta_flat: &Array1<f64>,
    ) -> Result<Array2<f64>, String> {
        let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
        let slices = self.joint_slices();
        if d_beta_flat.len() != slices.total {
            return Err(format!(
                "latent binary joint dH direction length mismatch: got {}, expected {}",
                d_beta_flat.len(),
                slices.total
            ));
        }
        let mut out = Array2::<f64>::zeros((slices.total, slices.total));
        for row_idx in 0..self.event_target.len() {
            let wi = self.weights[row_idx];
            // Rows at or below the weight floor contribute nothing.
            if wi <= MIN_WEIGHT {
                continue;
            }
            let row =
                self.build_right_censored_row_at(row_idx, q_entry[row_idx], q_exit[row_idx])?;
            let (row_log_survival, survival_gradient, survival_hessian) =
                latent_survival_row_primary_gradient_hessian(
                    &self.quadctx,
                    &row,
                    q_entry[row_idx],
                    q_exit[row_idx],
                    1.0,
                    q_exit[row_idx],
                    mu[row_idx],
                    self.latent_sd,
                    false,
                )?;
            let binary = binary_from_log_survival(row_log_survival, self.event_target[row_idx])?;
            let direction = self.row_primary_direction_from_flat(row_idx, &slices, d_beta_flat);
            let third = latent_survival_row_primary_third_contracted(
                &self.quadctx,
                &row,
                q_entry[row_idx],
                q_exit[row_idx],
                1.0,
                q_exit[row_idx],
                mu[row_idx],
                self.latent_sd,
                &direction,
                false,
            )?;
            // Directional derivative of the gradient, and of L, along u.
            let g_u = -survival_hessian.dot(&direction);
            let t_u = survival_gradient.dot(&direction);
            // ℓ'·T[u] + ℓ''·(g·u)·N
            let mut primary = binary.grad_scale * third;
            primary.scaled_add(binary.grad_scale_prime * t_u, &survival_hessian);
            // + (-ℓ''')·(g·u)·g gᵀ + (-ℓ'')·(g_u gᵀ + g g_uᵀ)
            for a in 0..LATENT_SURVIVAL_PRIMARY_DIM {
                for b in 0..LATENT_SURVIVAL_PRIMARY_DIM {
                    primary[[a, b]] += binary.outer_scale_prime
                        * t_u
                        * survival_gradient[a]
                        * survival_gradient[b]
                        + binary.outer_scale
                            * (g_u[a] * survival_gradient[b] + survival_gradient[a] * g_u[b]);
                }
            }
            self.add_pullback_primary_hessian(&mut out, row_idx, &slices, &(wi * primary));
        }
        Ok(out)
    }
3678
    /// Dense mixed second directional derivative, along `d_beta_u_flat` and
    /// `d_beta_v_flat`, of the latent-binary joint Hessian.
    ///
    /// This differentiates the first-derivative formula (see
    /// `exact_newton_joint_hessian_directional_derivative_dense`) a second time, along `v`.
    /// It uses the same row curvature `N = grad_scale(S) * H_S + outer_scale(S) * g g^T`,
    /// together with:
    ///
    /// ```text
    /// t_u  = g . u                 t_v = g . v
    /// g_u  = -H_S u                g_v = -H_S v
    /// l_uv = -u^T H_S v            (mixed change of S)
    /// g_uv = -dH_S[v] u            (mixed change of g)
    /// c_*  = changes of grad_scale, o_* = changes of outer_scale
    ///
    /// d2N[u,v] = grad_scale * d2H_S[u,v]
    ///          + c_u * dH_S[v] + c_v * dH_S[u] + c_uv * H_S
    ///          + o_uv * g g^T
    ///          + o_v * (g_u g^T + g g_u^T) + o_u * (g_v g^T + g g_v^T)
    ///          + outer_scale * (g_uv g^T + g_u g_v^T + g_v g_u^T + g g_uv^T)
    /// ```
    ///
    /// `dH_S[.]` and `d2H_S[.,.]` come from the `_third_contracted` / `_fourth_contracted`
    /// kernels (presumably contracted higher derivatives — confirm against the kernels).
    /// Each row's block is weighted by `wi` and pulled back to the flat coefficient space.
    /// Rows with `wi <= MIN_WEIGHT` are skipped.
    ///
    /// # Errors
    /// Returns `Err` if either direction's length differs from `slices.total`, or if any
    /// row construction or kernel evaluation fails.
    fn exact_newton_joint_hessian_second_directional_derivative_dense(
        &self,
        block_states: &[ParameterBlockState],
        d_beta_u_flat: &Array1<f64>,
        d_beta_v_flat: &Array1<f64>,
    ) -> Result<Array2<f64>, String> {
        let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
        let slices = self.joint_slices();
        if d_beta_u_flat.len() != slices.total || d_beta_v_flat.len() != slices.total {
            return Err(format!(
                "latent binary joint d2H direction length mismatch: got {} and {}, expected {}",
                d_beta_u_flat.len(),
                d_beta_v_flat.len(),
                slices.total
            ));
        }
        let mut out = Array2::<f64>::zeros((slices.total, slices.total));
        for row_idx in 0..self.event_target.len() {
            let wi = self.weights[row_idx];
            if wi <= MIN_WEIGHT {
                continue;
            }
            let row =
                self.build_right_censored_row_at(row_idx, q_entry[row_idx], q_exit[row_idx])?;
            // Right-censored survival kernel with qdot_exit = 1.0, q_right = q_exit, and no
            // log-sigma coordinate (sigma is the fixed field `self.latent_sd`).
            let (row_log_survival, survival_gradient, survival_hessian) =
                latent_survival_row_primary_gradient_hessian(
                    &self.quadctx,
                    &row,
                    q_entry[row_idx],
                    q_exit[row_idx],
                    1.0,
                    q_exit[row_idx],
                    mu[row_idx],
                    self.latent_sd,
                    false,
                )?;
            let binary = binary_from_log_survival(row_log_survival, self.event_target[row_idx])?;
            let direction_u = self.row_primary_direction_from_flat(row_idx, &slices, d_beta_u_flat);
            let direction_v = self.row_primary_direction_from_flat(row_idx, &slices, d_beta_v_flat);
            let third_u = latent_survival_row_primary_third_contracted(
                &self.quadctx,
                &row,
                q_entry[row_idx],
                q_exit[row_idx],
                1.0,
                q_exit[row_idx],
                mu[row_idx],
                self.latent_sd,
                &direction_u,
                false,
            )?;
            let third_v = latent_survival_row_primary_third_contracted(
                &self.quadctx,
                &row,
                q_entry[row_idx],
                q_exit[row_idx],
                1.0,
                q_exit[row_idx],
                mu[row_idx],
                self.latent_sd,
                &direction_v,
                false,
            )?;
            let fourth = latent_survival_row_primary_fourth_contracted(
                &self.quadctx,
                &row,
                q_entry[row_idx],
                q_exit[row_idx],
                1.0,
                q_exit[row_idx],
                mu[row_idx],
                self.latent_sd,
                &direction_u,
                &direction_v,
                false,
            )?;
            // First-order changes of the log-survival gradient along u and v.
            let g_u = -survival_hessian.dot(&direction_u);
            let g_v = -survival_hessian.dot(&direction_v);
            // Mixed change of the gradient along (u, v).
            let g_uv = -third_v.dot(&direction_u);
            // First-order changes of the log-survival along u and v.
            let t_u = survival_gradient.dot(&direction_u);
            let t_v = survival_gradient.dot(&direction_v);
            // Mixed change of the log-survival along (u, v).
            let l_uv = -direction_u.dot(&survival_hessian.dot(&direction_v));
            // Chain-rule changes of grad_scale(S) ...
            let c_u = binary.grad_scale_prime * t_u;
            let c_v = binary.grad_scale_prime * t_v;
            let c_uv = binary.grad_scale_second * t_u * t_v + binary.grad_scale_prime * l_uv;
            // ... and of outer_scale(S).
            let o_u = binary.outer_scale_prime * t_u;
            let o_v = binary.outer_scale_prime * t_v;
            let o_uv = binary.outer_scale_second * t_u * t_v + binary.outer_scale_prime * l_uv;
            let mut primary = binary.grad_scale * fourth;
            primary.scaled_add(c_u, &third_v);
            primary.scaled_add(c_v, &third_u);
            primary.scaled_add(c_uv, &survival_hessian);
            // Second derivative of the rank-one outer_scale * g g^T term.
            for a in 0..LATENT_SURVIVAL_PRIMARY_DIM {
                for b in 0..LATENT_SURVIVAL_PRIMARY_DIM {
                    primary[[a, b]] += o_uv * survival_gradient[a] * survival_gradient[b]
                        + o_v * (g_u[a] * survival_gradient[b] + survival_gradient[a] * g_u[b])
                        + o_u * (g_v[a] * survival_gradient[b] + survival_gradient[a] * g_v[b])
                        + binary.outer_scale
                            * (g_uv[a] * survival_gradient[b]
                                + g_u[a] * g_v[b]
                                + g_v[a] * g_u[b]
                                + survival_gradient[a] * g_uv[b]);
                }
            }
            self.add_pullback_primary_hessian(&mut out, row_idx, &slices, &(wi * primary));
        }
        Ok(out)
    }
3787}
3788
3789trait LatentJointHessianFamily {
3803 fn ws_joint_slices(&self) -> LatentSurvivalJointSlices;
3804
3805 fn ws_evaluate_dense(
3806 &self,
3807 block_states: &[ParameterBlockState],
3808 ) -> Result<(f64, Array1<f64>, Array2<f64>), String>;
3809
3810 fn ws_dh_directional(
3811 &self,
3812 block_states: &[ParameterBlockState],
3813 d_beta_flat: &Array1<f64>,
3814 ) -> Result<Array2<f64>, String>;
3815
3816 fn ws_dh_second_directional(
3817 &self,
3818 block_states: &[ParameterBlockState],
3819 d_beta_u: &Array1<f64>,
3820 d_beta_v: &Array1<f64>,
3821 ) -> Result<Array2<f64>, String>;
3822
3823 fn ws_matvec_into(
3827 &self,
3828 slices: &LatentSurvivalJointSlices,
3829 block_states: &[ParameterBlockState],
3830 v: &Array1<f64>,
3831 out: &mut Array1<f64>,
3832 ) -> Result<bool, String>;
3833
3834 fn ws_label() -> &'static str;
3838}
3839
3840impl LatentJointHessianFamily for LatentSurvivalFamily {
3841 fn ws_joint_slices(&self) -> LatentSurvivalJointSlices {
3842 self.joint_slices()
3843 }
3844
3845 fn ws_evaluate_dense(
3846 &self,
3847 block_states: &[ParameterBlockState],
3848 ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
3849 self.evaluate_exact_newton_joint_dense(block_states)
3850 }
3851
3852 fn ws_dh_directional(
3853 &self,
3854 block_states: &[ParameterBlockState],
3855 d_beta_flat: &Array1<f64>,
3856 ) -> Result<Array2<f64>, String> {
3857 self.exact_newton_joint_hessian_directional_derivative_dense(block_states, d_beta_flat)
3858 }
3859
3860 fn ws_dh_second_directional(
3861 &self,
3862 block_states: &[ParameterBlockState],
3863 d_beta_u: &Array1<f64>,
3864 d_beta_v: &Array1<f64>,
3865 ) -> Result<Array2<f64>, String> {
3866 self.exact_newton_joint_hessian_second_directional_derivative_dense(
3867 block_states,
3868 d_beta_u,
3869 d_beta_v,
3870 )
3871 }
3872
3873 fn ws_matvec_into(
3874 &self,
3875 slices: &LatentSurvivalJointSlices,
3876 block_states: &[ParameterBlockState],
3877 v: &Array1<f64>,
3878 out: &mut Array1<f64>,
3879 ) -> Result<bool, String> {
3880 let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
3881 let q_right = self.time_q_right(block_states)?;
3882 let sigma = self.latent_sd(block_states)?;
3883 let include_log_sigma = slices.log_sigma.is_some();
3884 for row_idx in 0..self.event_target.len() {
3885 let wi = self.weights[row_idx];
3886 if wi <= MIN_WEIGHT {
3887 continue;
3888 }
3889 let row = self.build_row_at(
3890 row_idx,
3891 q_entry[row_idx],
3892 q_exit[row_idx],
3893 qdot_exit[row_idx],
3894 q_right[row_idx],
3895 )?;
3896 let (_, _, primary_hessian) = latent_survival_row_primary_gradient_hessian(
3897 &self.quadctx,
3898 &row,
3899 q_entry[row_idx],
3900 q_exit[row_idx],
3901 qdot_exit[row_idx],
3902 q_right[row_idx],
3903 mu[row_idx],
3904 sigma,
3905 include_log_sigma,
3906 )?;
3907 let primary_dir = self.row_primary_direction_from_flat(row_idx, slices, v);
3908 let primary_hv = primary_hessian.dot(&primary_dir);
3909 self.add_pullback_primary_gradient(out, row_idx, slices, &primary_hv, wi)?;
3910 }
3911 Ok(true)
3912 }
3913
3914 fn ws_label() -> &'static str {
3915 "survival"
3916 }
3917}
3918
3919impl LatentJointHessianFamily for LatentBinaryFamily {
3920 fn ws_joint_slices(&self) -> LatentSurvivalJointSlices {
3921 self.joint_slices()
3922 }
3923
3924 fn ws_evaluate_dense(
3925 &self,
3926 block_states: &[ParameterBlockState],
3927 ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
3928 self.evaluate_exact_newton_joint_dense(block_states)
3929 }
3930
3931 fn ws_dh_directional(
3932 &self,
3933 block_states: &[ParameterBlockState],
3934 d_beta_flat: &Array1<f64>,
3935 ) -> Result<Array2<f64>, String> {
3936 self.exact_newton_joint_hessian_directional_derivative_dense(block_states, d_beta_flat)
3937 }
3938
3939 fn ws_dh_second_directional(
3940 &self,
3941 block_states: &[ParameterBlockState],
3942 d_beta_u: &Array1<f64>,
3943 d_beta_v: &Array1<f64>,
3944 ) -> Result<Array2<f64>, String> {
3945 self.exact_newton_joint_hessian_second_directional_derivative_dense(
3946 block_states,
3947 d_beta_u,
3948 d_beta_v,
3949 )
3950 }
3951
3952 fn ws_matvec_into(
3953 &self,
3954 slices: &LatentSurvivalJointSlices,
3955 block_states: &[ParameterBlockState],
3956 v: &Array1<f64>,
3957 out: &mut Array1<f64>,
3958 ) -> Result<bool, String> {
3959 let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
3960 for row_idx in 0..self.event_target.len() {
3961 let wi = self.weights[row_idx];
3962 if wi <= MIN_WEIGHT {
3963 continue;
3964 }
3965 let row =
3966 self.build_right_censored_row_at(row_idx, q_entry[row_idx], q_exit[row_idx])?;
3967 let (row_log_survival, survival_gradient, survival_hessian) =
3968 latent_survival_row_primary_gradient_hessian(
3969 &self.quadctx,
3970 &row,
3971 q_entry[row_idx],
3972 q_exit[row_idx],
3973 1.0,
3974 q_exit[row_idx],
3975 mu[row_idx],
3976 self.latent_sd,
3977 false,
3978 )?;
3979 let binary = binary_from_log_survival(row_log_survival, self.event_target[row_idx])?;
3980 let primary_dir = self.row_primary_direction_from_flat(row_idx, slices, v);
3981 let mut primary_hv = binary.grad_scale * survival_hessian.dot(&primary_dir);
3982 let outer_dot = survival_gradient.dot(&primary_dir);
3983 for a in 0..LATENT_SURVIVAL_PRIMARY_DIM {
3984 primary_hv[a] += binary.outer_scale * survival_gradient[a] * outer_dot;
3985 }
3986 self.add_pullback_primary_gradient(out, row_idx, slices, &primary_hv, wi);
3987 }
3988 Ok(true)
3989 }
3990
3991 fn ws_label() -> &'static str {
3992 "binary"
3993 }
3994}
3995
3996struct LatentHessianWorkspace<F: LatentJointHessianFamily> {
4003 family: F,
4004 block_states: Vec<ParameterBlockState>,
4005 slices: LatentSurvivalJointSlices,
4006}
4007
4008impl<F: LatentJointHessianFamily> LatentHessianWorkspace<F> {
4009 fn new(family: F, block_states: Vec<ParameterBlockState>) -> Self {
4010 let slices = family.ws_joint_slices();
4011 Self {
4012 family,
4013 block_states,
4014 slices,
4015 }
4016 }
4017}
4018
4019impl<F> ExactNewtonJointHessianWorkspace for LatentHessianWorkspace<F>
4020where
4021 F: LatentJointHessianFamily + Send + Sync + 'static,
4022{
4023 fn hessian_dense(&self) -> Result<Option<Array2<f64>>, String> {
4024 self.family
4025 .ws_evaluate_dense(&self.block_states)
4026 .map(|(_, _, hessian)| Some(hessian))
4027 }
4028
4029 fn hessian_matvec(&self, v: &Array1<f64>) -> Result<Option<Array1<f64>>, String> {
4030 let mut out = Array1::<f64>::zeros(self.slices.total);
4031 self.hessian_matvec_into(v, &mut out)?;
4032 Ok(Some(out))
4033 }
4034
4035 fn hessian_matvec_into(&self, v: &Array1<f64>, out: &mut Array1<f64>) -> Result<bool, String> {
4036 if v.len() != self.slices.total || out.len() != self.slices.total {
4037 return Err(format!(
4038 "latent {} Hessian matvec dimension mismatch: v={} out={} expected={}",
4039 F::ws_label(),
4040 v.len(),
4041 out.len(),
4042 self.slices.total
4043 ));
4044 }
4045 out.fill(0.0);
4046 self.family
4047 .ws_matvec_into(&self.slices, &self.block_states, v, out)
4048 }
4049
4050 fn hessian_diagonal(&self) -> Result<Option<Array1<f64>>, String> {
4051 let dense = self.family.ws_evaluate_dense(&self.block_states)?.2;
4052 Ok(Some(dense.diag().to_owned()))
4053 }
4054
4055 fn directional_derivative(
4056 &self,
4057 d_beta_flat: &Array1<f64>,
4058 ) -> Result<Option<Array2<f64>>, String> {
4059 self.family
4060 .ws_dh_directional(&self.block_states, d_beta_flat)
4061 .map(Some)
4062 }
4063
4064 fn second_directional_derivative(
4065 &self,
4066 d_beta_u: &Array1<f64>,
4067 d_beta_v: &Array1<f64>,
4068 ) -> Result<Option<Array2<f64>>, String> {
4069 self.family
4070 .ws_dh_second_directional(&self.block_states, d_beta_u, d_beta_v)
4071 .map(Some)
4072 }
4073}
4074
/// Joint-Hessian workspace specialised to the latent-survival family.
type LatentSurvivalHessianWorkspace = LatentHessianWorkspace<LatentSurvivalFamily>;
/// Joint-Hessian workspace specialised to the latent-binary family.
type LatentBinaryHessianWorkspace = LatentHessianWorkspace<LatentBinaryFamily>;
4077
4078impl CustomFamily for LatentSurvivalFamily {
4079 fn joint_jeffreys_term_required(&self) -> bool {
4083 true
4084 }
4085
4086 fn exact_newton_joint_hessian_beta_dependent(&self) -> bool {
4087 true
4088 }
4089
4090 fn has_explicit_joint_hessian(&self) -> bool {
4091 true
4092 }
4093
4094 fn levenberg_on_ill_conditioning(&self) -> bool {
4115 true
4116 }
4117
4118 fn coefficient_hessian_cost(&self, specs: &[ParameterBlockSpec]) -> u64 {
4119 crate::custom_family::joint_coupled_coefficient_hessian_cost(
4123 self.event_target.len() as u64,
4124 specs,
4125 )
4126 }
4127
4128 fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
4129 let (ll, joint_gradient, hess_time, hess_mean, hess_log_sigma) =
4130 self.evaluate_exact_newton_block_diagonals(block_states)?;
4131 let block_ranges = self.joint_block_ranges();
4132 let mut blockworking_sets = vec![
4133 BlockWorkingSet::ExactNewton {
4134 gradient: joint_gradient.slice(s![block_ranges[0].clone()]).to_owned(),
4135 hessian: SymmetricMatrix::Dense(hess_time),
4136 },
4137 BlockWorkingSet::ExactNewton {
4138 gradient: joint_gradient.slice(s![block_ranges[1].clone()]).to_owned(),
4139 hessian: SymmetricMatrix::Dense(hess_mean),
4140 },
4141 ];
4142 if let (Some(range), Some(hessian)) = (block_ranges.get(2).cloned(), hess_log_sigma) {
4143 blockworking_sets.push(BlockWorkingSet::ExactNewton {
4144 gradient: joint_gradient.slice(s![range]).to_owned(),
4145 hessian: SymmetricMatrix::Dense(hessian),
4146 });
4147 }
4148 Ok(FamilyEvaluation {
4149 log_likelihood: ll,
4150 blockworking_sets,
4151 })
4152 }
4153
4154 fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
4155 use rayon::iter::{IntoParallelIterator, ParallelIterator};
4156 let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
4157 let q_right = self.time_q_right(block_states)?;
4158 let latent_sd = self.latent_sd(block_states)?;
4159 let n = self.event_target.len();
4160 let contributions: Result<Vec<f64>, String> = (0..n)
4164 .into_par_iter()
4165 .map(|i| -> Result<f64, String> {
4166 let wi = self.weights[i];
4167 if wi <= MIN_WEIGHT {
4168 return Ok(0.0);
4169 }
4170 let row = self.build_row_at(i, q_entry[i], q_exit[i], qdot_exit[i], q_right[i])?;
4171 let jet = LatentSurvivalRowJet::evaluate(&self.quadctx, &row, mu[i], latent_sd)
4172 .map_err(|e| format!("LatentSurvivalFamily row {i}: {e}"))?;
4173 Ok(wi * jet.log_lik)
4174 })
4175 .collect();
4176 Ok(contributions?.into_iter().sum())
4177 }
4178
4179 fn block_linear_constraints(
4180 &self,
4181 _: &[ParameterBlockState],
4182 block_idx: usize,
4183 block_spec: &ParameterBlockSpec,
4184 ) -> Result<Option<LinearInequalityConstraints>, String> {
4185 assert!(!block_spec.name.is_empty());
4186 if block_idx == Self::BLOCK_TIME {
4187 Ok(self.time_linear_constraints.clone())
4188 } else {
4189 Ok(None)
4190 }
4191 }
4192
4193 fn exact_newton_joint_hessian(
4194 &self,
4195 block_states: &[ParameterBlockState],
4196 ) -> Result<Option<Array2<f64>>, String> {
4197 self.evaluate_exact_newton_joint_dense(block_states)
4198 .map(|(_, _, hessian)| Some(hessian))
4199 }
4200
4201 fn exact_newton_joint_hessian_workspace(
4202 &self,
4203 block_states: &[ParameterBlockState],
4204 _: &[ParameterBlockSpec],
4205 ) -> Result<Option<Arc<dyn ExactNewtonJointHessianWorkspace>>, String> {
4206 Ok(Some(Arc::new(LatentSurvivalHessianWorkspace::new(
4207 self.clone(),
4208 block_states.to_vec(),
4209 ))))
4210 }
4211
4212 fn exact_newton_joint_gradient_evaluation(
4213 &self,
4214 block_states: &[ParameterBlockState],
4215 _: &[ParameterBlockSpec],
4216 ) -> Result<Option<ExactNewtonJointGradientEvaluation>, String> {
4217 self.evaluate_exact_newton_joint_gradient_dense(block_states)
4218 .map(|(log_likelihood, gradient)| {
4219 Some(ExactNewtonJointGradientEvaluation {
4220 log_likelihood,
4221 gradient,
4222 })
4223 })
4224 }
4225
4226 fn exact_newton_joint_hessian_directional_derivative(
4227 &self,
4228 block_states: &[ParameterBlockState],
4229 d_beta_flat: &Array1<f64>,
4230 ) -> Result<Option<Array2<f64>>, String> {
4231 self.exact_newton_joint_hessian_directional_derivative_dense(block_states, d_beta_flat)
4232 .map(Some)
4233 }
4234
4235 fn exact_newton_joint_hessiansecond_directional_derivative(
4236 &self,
4237 block_states: &[ParameterBlockState],
4238 d_beta_u_flat: &Array1<f64>,
4239 d_beta_v_flat: &Array1<f64>,
4240 ) -> Result<Option<Array2<f64>>, String> {
4241 self.exact_newton_joint_hessian_second_directional_derivative_dense(
4242 block_states,
4243 d_beta_u_flat,
4244 d_beta_v_flat,
4245 )
4246 .map(Some)
4247 }
4248
4249 fn requires_joint_outer_hyper_path(&self) -> bool {
4250 true
4251 }
4252}
4253
4254impl CustomFamily for LatentBinaryFamily {
4255 fn joint_jeffreys_term_required(&self) -> bool {
4259 true
4260 }
4261
4262 fn exact_newton_joint_hessian_beta_dependent(&self) -> bool {
4263 true
4264 }
4265
4266 fn has_explicit_joint_hessian(&self) -> bool {
4267 true
4268 }
4269
4270 fn levenberg_on_ill_conditioning(&self) -> bool {
4278 true
4279 }
4280
4281 fn coefficient_hessian_cost(&self, specs: &[ParameterBlockSpec]) -> u64 {
4282 crate::custom_family::joint_coupled_coefficient_hessian_cost(
4283 self.event_target.len() as u64,
4284 specs,
4285 )
4286 }
4287
4288 fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
4289 let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
4290 let n = self.event_target.len();
4291 let p_time = self.x_time_exit.ncols();
4292 let p_mean = self.x_mean.ncols();
4293
4294 let mut ll = 0.0;
4295 let mut grad_time = Array1::<f64>::zeros(p_time);
4296 let mut hess_time = Array2::<f64>::zeros((p_time, p_time));
4297 let mut grad_mean = Array1::<f64>::zeros(p_mean);
4298 let mut hess_mean = Array2::<f64>::zeros((p_mean, p_mean));
4299 let mut mean_row_buf = Array2::<f64>::zeros((1, p_mean));
4302
4303 for i in 0..n {
4304 let wi = self.weights[i];
4305 if wi <= MIN_WEIGHT {
4306 continue;
4307 }
4308 if !(q_entry[i].is_finite() && q_exit[i].is_finite() && mu[i].is_finite()) {
4309 return Err(format!(
4310 "latent-binary row {i} contains non-finite predictors: q_entry={}, q_exit={}, mu={}",
4311 q_entry[i], q_exit[i], mu[i]
4312 ));
4313 }
4314 let row = self.build_right_censored_row_at(i, q_entry[i], q_exit[i])?;
4315 let survival_jet =
4316 LatentSurvivalRowJet::evaluate(&self.quadctx, &row, mu[i], self.latent_sd)
4317 .map_err(|e| format!("LatentBinaryFamily row {i}: {e}"))?;
4318 let binary = binary_from_log_survival(survival_jet.log_lik, self.event_target[i])?;
4319 ll += wi * binary.log_lik;
4320
4321 self.x_mean
4322 .row_chunk_into(i..i + 1, mean_row_buf.view_mut())
4323 .map_err(|e| format!("LatentBinaryFamily row {i} mean row_chunk: {e}"))?;
4324 let mean_vec = mean_row_buf.row(0);
4325 let mean_grad_scale = wi * binary.grad_scale * survival_jet.score;
4326 for j in 0..p_mean {
4327 grad_mean[j] += mean_grad_scale * mean_vec[j];
4328 }
4329 let mean_neg_hess = wi
4330 * (binary.neg_hess_scale * survival_jet.neg_hessian
4331 + binary.outer_scale * survival_jet.score * survival_jet.score);
4332 dense_outer_accumulate(&mut hess_mean, mean_neg_hess, mean_vec);
4333
4334 let time_jet =
4335 latent_survival_time_jet(&self.quadctx, &row, 0.0, mu[i], self.latent_sd)?;
4336 let t_entry = self.x_time_entry.row(i);
4337 let t_exit = self.x_time_exit.row(i);
4338 for j in 0..p_time {
4339 grad_time[j] += wi
4340 * binary.grad_scale
4341 * (time_jet.grad_entry * t_entry[j] + time_jet.grad_exit * t_exit[j]);
4342 }
4343 dense_outer_accumulate(
4344 &mut hess_time,
4345 wi * binary.neg_hess_scale * time_jet.neg_hess_entry,
4346 t_entry,
4347 );
4348 dense_outer_accumulate(
4349 &mut hess_time,
4350 wi * binary.neg_hess_scale * time_jet.neg_hess_exit,
4351 t_exit,
4352 );
4353 if binary.outer_scale != 0.0 {
4354 dense_outer_accumulate(
4355 &mut hess_time,
4356 wi * binary.outer_scale * time_jet.grad_entry * time_jet.grad_entry,
4357 t_entry,
4358 );
4359 dense_outer_accumulate(
4360 &mut hess_time,
4361 wi * binary.outer_scale * time_jet.grad_exit * time_jet.grad_exit,
4362 t_exit,
4363 );
4364 dense_symmetric_cross_accumulate(
4365 &mut hess_time,
4366 wi * binary.outer_scale * time_jet.grad_entry * time_jet.grad_exit,
4367 t_entry,
4368 t_exit,
4369 );
4370 }
4371 }
4372
4373 Ok(FamilyEvaluation {
4374 log_likelihood: ll,
4375 blockworking_sets: vec![
4376 BlockWorkingSet::ExactNewton {
4377 gradient: grad_time,
4378 hessian: SymmetricMatrix::Dense(hess_time),
4379 },
4380 BlockWorkingSet::ExactNewton {
4381 gradient: grad_mean,
4382 hessian: SymmetricMatrix::Dense(hess_mean),
4383 },
4384 ],
4385 })
4386 }
4387
4388 fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
4389 let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
4390 let mut ll = 0.0;
4391 for i in 0..self.event_target.len() {
4392 let wi = self.weights[i];
4393 if wi <= MIN_WEIGHT {
4394 continue;
4395 }
4396 let row = self.build_right_censored_row_at(i, q_entry[i], q_exit[i])?;
4397 let survival_jet =
4398 LatentSurvivalRowJet::evaluate(&self.quadctx, &row, mu[i], self.latent_sd)
4399 .map_err(|e| format!("LatentBinaryFamily row {i}: {e}"))?;
4400 ll +=
4401 wi * binary_from_log_survival(survival_jet.log_lik, self.event_target[i])?.log_lik;
4402 }
4403 Ok(ll)
4404 }
4405
4406 fn block_linear_constraints(
4407 &self,
4408 _: &[ParameterBlockState],
4409 block_idx: usize,
4410 block_spec: &ParameterBlockSpec,
4411 ) -> Result<Option<LinearInequalityConstraints>, String> {
4412 assert!(!block_spec.name.is_empty());
4413 if block_idx == Self::BLOCK_TIME {
4414 Ok(self.time_linear_constraints.clone())
4415 } else {
4416 Ok(None)
4417 }
4418 }
4419
4420 fn exact_newton_joint_hessian(
4421 &self,
4422 block_states: &[ParameterBlockState],
4423 ) -> Result<Option<Array2<f64>>, String> {
4424 self.evaluate_exact_newton_joint_dense(block_states)
4425 .map(|(_, _, hessian)| Some(hessian))
4426 }
4427
4428 fn exact_newton_joint_hessian_workspace(
4429 &self,
4430 block_states: &[ParameterBlockState],
4431 _: &[ParameterBlockSpec],
4432 ) -> Result<Option<Arc<dyn ExactNewtonJointHessianWorkspace>>, String> {
4433 Ok(Some(Arc::new(LatentBinaryHessianWorkspace::new(
4434 self.clone(),
4435 block_states.to_vec(),
4436 ))))
4437 }
4438
4439 fn exact_newton_joint_gradient_evaluation(
4440 &self,
4441 block_states: &[ParameterBlockState],
4442 _: &[ParameterBlockSpec],
4443 ) -> Result<Option<ExactNewtonJointGradientEvaluation>, String> {
4444 self.evaluate_exact_newton_joint_dense(block_states)
4445 .map(|(log_likelihood, gradient, _)| {
4446 Some(ExactNewtonJointGradientEvaluation {
4447 log_likelihood,
4448 gradient,
4449 })
4450 })
4451 }
4452
4453 fn exact_newton_joint_hessian_directional_derivative(
4454 &self,
4455 block_states: &[ParameterBlockState],
4456 d_beta_flat: &Array1<f64>,
4457 ) -> Result<Option<Array2<f64>>, String> {
4458 self.exact_newton_joint_hessian_directional_derivative_dense(block_states, d_beta_flat)
4459 .map(Some)
4460 }
4461
4462 fn exact_newton_joint_hessiansecond_directional_derivative(
4463 &self,
4464 block_states: &[ParameterBlockState],
4465 d_beta_u_flat: &Array1<f64>,
4466 d_beta_v_flat: &Array1<f64>,
4467 ) -> Result<Option<Array2<f64>>, String> {
4468 self.exact_newton_joint_hessian_second_directional_derivative_dense(
4469 block_states,
4470 d_beta_u_flat,
4471 d_beta_v_flat,
4472 )
4473 .map(Some)
4474 }
4475
4476 fn requires_joint_outer_hyper_path(&self) -> bool {
4477 true
4478 }
4479}
4480
4481#[cfg(test)]
4482mod tests {
4483 use super::*;
4484 use crate::custom_family::BlockWorkingSet;
4485 use gam_linalg::matrix::DenseDesignMatrix;
4486 use ndarray::array;
4487
    /// Two-row latent-survival fixture with learnable sigma (`latent_sd_fixed: None`),
    /// two time coefficients and two mean coefficients.
    ///
    /// Row 0 is an event and row 1 is censored. The right-time design repeats the exit
    /// design, and the right-side offset and unloaded mass are zero.
    fn learnable_sigma_test_family() -> LatentSurvivalFamily {
        LatentSurvivalFamily {
            event_target: array![1u8, 0u8],
            weights: array![1.0, 0.7],
            latent_sd_fixed: None,
            hazard_loading: HazardLoading::LoadedVsUnloaded,
            unloaded_mass_entry: array![0.02, 0.03],
            unloaded_mass_exit: array![0.05, 0.08],
            unloaded_hazard_exit: array![0.04, 0.0],
            x_time_entry: array![[1.0, -0.2], [0.4, 0.7]],
            x_time_exit: array![[1.3, 0.1], [0.9, 1.0]],
            x_time_derivative_exit: array![[0.8, 0.4], [0.6, 0.5]],
            x_time_right: array![[1.3, 0.1], [0.9, 1.0]],
            time_offset_right: Array1::zeros(2),
            unloaded_mass_right: Array1::zeros(2),
            x_mean: DesignMatrix::Dense(DenseDesignMatrix::from(array![[1.0, -0.3], [0.2, 0.9]])),
            time_linear_constraints: None,
            quadctx: Arc::new(QuadratureContext::new()),
        }
    }
4508
    /// Flat joint beta for `learnable_sigma_test_family`: five entries, the last being
    /// `ln(0.35)` (presumably ordered time, mean, log-sigma per `joint_slices` — confirm).
    fn learnable_sigma_test_joint_beta() -> Array1<f64> {
        array![0.15, 0.25, 0.1, -0.15, 0.35_f64.ln()]
    }
4512
    /// Deterministic `n`-row latent-survival fixture with 4 time columns, 3 mean
    /// columns and learnable sigma.
    ///
    /// Every third row (`i % 3 == 0`) is an event. Weights cycle through
    /// `0.55..=0.73`. The unloaded exit hazard is nonzero only on every fourth row. The
    /// right-time design repeats the exit design, with zero right-side offset and mass.
    fn survival_stress_test_family(n: usize) -> LatentSurvivalFamily {
        LatentSurvivalFamily {
            event_target: Array1::from_iter((0..n).map(|i| if i % 3 == 0 { 1u8 } else { 0u8 })),
            weights: Array1::from_iter((0..n).map(|i| 0.55 + 0.03 * ((i % 7) as f64))),
            latent_sd_fixed: None,
            hazard_loading: HazardLoading::LoadedVsUnloaded,
            unloaded_mass_entry: Array1::from_iter(
                (0..n).map(|i| 0.015 + 0.0015 * ((i % 11) as f64)),
            ),
            unloaded_mass_exit: Array1::from_iter((0..n).map(|i| 0.06 + 0.002 * ((i % 13) as f64))),
            unloaded_hazard_exit: Array1::from_iter((0..n).map(|i| {
                if i % 4 == 0 {
                    0.018 + 0.001 * ((i % 5) as f64)
                } else {
                    0.0
                }
            })),
            x_time_entry: Array2::from_shape_fn((n, 4), |(i, j)| {
                0.2 + 0.03 * ((i + 2 * j) % 9) as f64 - if j == 1 { 0.12 } else { 0.0 }
            }),
            x_time_exit: Array2::from_shape_fn((n, 4), |(i, j)| {
                0.35 + 0.025 * ((2 * i + j) % 10) as f64 - if j == 2 { 0.08 } else { 0.0 }
            }),
            x_time_derivative_exit: Array2::from_shape_fn((n, 4), |(i, j)| {
                0.45 + 0.015 * ((i + 3 * j) % 8) as f64
            }),
            x_time_right: Array2::from_shape_fn((n, 4), |(i, j)| {
                0.35 + 0.025 * ((2 * i + j) % 10) as f64 - if j == 2 { 0.08 } else { 0.0 }
            }),
            time_offset_right: Array1::zeros(n),
            unloaded_mass_right: Array1::zeros(n),
            x_mean: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::from_shape_fn(
                (n, 3),
                |(i, j)| 0.1 + 0.04 * ((3 * i + j) % 7) as f64 - if j == 0 { 0.18 } else { 0.0 },
            ))),
            time_linear_constraints: None,
            quadctx: Arc::new(QuadratureContext::new()),
        }
    }
4552
    /// Flat joint beta for `survival_stress_test_family`: eight entries (matching 4 time
    /// + 3 mean + 1 log-sigma coefficients), the last being `ln(0.42)`.
    fn survival_stress_test_joint_beta() -> Array1<f64> {
        array![0.18, 0.11, 0.07, 0.13, -0.09, 0.05, 0.12, 0.42_f64.ln()]
    }
4556
4557 fn latent_survival_states_from_joint_beta(
4558 family: &LatentSurvivalFamily,
4559 joint_beta: &Array1<f64>,
4560 ) -> Vec<ParameterBlockState> {
4561 let slices = family.joint_slices();
4562 let n = family.event_target.len();
4563 let beta_time = joint_beta.slice(s![slices.time.clone()]).to_owned();
4564 let beta_mean = joint_beta.slice(s![slices.mean.clone()]).to_owned();
4565
4566 let mut eta_time = Array1::<f64>::zeros(3 * n);
4567 eta_time
4568 .slice_mut(s![0..n])
4569 .assign(&gam_linalg::faer_ndarray::fast_av(
4570 &family.x_time_entry,
4571 &beta_time,
4572 ));
4573 eta_time
4574 .slice_mut(s![n..2 * n])
4575 .assign(&gam_linalg::faer_ndarray::fast_av(
4576 &family.x_time_exit,
4577 &beta_time,
4578 ));
4579 eta_time
4580 .slice_mut(s![2 * n..3 * n])
4581 .assign(&gam_linalg::faer_ndarray::fast_av(
4582 &family.x_time_derivative_exit,
4583 &beta_time,
4584 ));
4585
4586 let mut states = vec![
4587 ParameterBlockState {
4588 beta: beta_time,
4589 eta: eta_time,
4590 },
4591 ParameterBlockState {
4592 beta: beta_mean.clone(),
4593 eta: family.x_mean.dot(&beta_mean),
4594 },
4595 ];
4596 if let Some(log_sigma) = slices.log_sigma {
4597 let beta_log_sigma = array![joint_beta[log_sigma.start]];
4598 states.push(ParameterBlockState {
4599 beta: beta_log_sigma.clone(),
4600 eta: beta_log_sigma,
4601 });
4602 }
4603 states
4604 }
4605
4606 fn max_relative_array1(left: &Array1<f64>, right: &Array1<f64>) -> f64 {
4607 left.iter()
4608 .zip(right.iter())
4609 .map(|(l, r)| (l - r).abs() / l.abs().max(r.abs()).max(1e-12))
4610 .fold(0.0_f64, f64::max)
4611 }
4612
4613 fn max_relative_array2(left: &Array2<f64>, right: &Array2<f64>) -> f64 {
4614 left.iter()
4615 .zip(right.iter())
4616 .map(|(l, r)| (l - r).abs() / l.abs().max(r.abs()).max(1e-12))
4617 .fold(0.0_f64, f64::max)
4618 }
4619
4620 fn frobenius_relative_array2(left: &Array2<f64>, right: &Array2<f64>) -> f64 {
4621 let mut diff2 = 0.0_f64;
4622 let mut scale2 = 0.0_f64;
4623 for (l, r) in left.iter().zip(right.iter()) {
4624 let d = l - r;
4625 diff2 += d * d;
4626 scale2 += l * l + r * r;
4627 }
4628 diff2.sqrt() / scale2.sqrt().max(1e-12)
4629 }
4630
4631 fn latent_survival_row_loglik_from_primary(
4632 quadctx: &QuadratureContext,
4633 row: &LatentSurvivalRow,
4634 primary: &Array1<f64>,
4635 ) -> f64 {
4636 let q_entry = primary[LATENT_SURVIVAL_PRIMARY_Q_ENTRY];
4637 let q_exit = primary[LATENT_SURVIVAL_PRIMARY_Q_EXIT];
4638 let qdot_exit = primary[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT];
4639 let q_right = primary[LATENT_SURVIVAL_PRIMARY_Q_RIGHT];
4640 let mu = primary[LATENT_SURVIVAL_PRIMARY_MU];
4641 let sigma = primary[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA].exp();
4642 latent_survival_row_primary_gradient_hessian(
4643 quadctx, row, q_entry, q_exit, qdot_exit, q_right, mu, sigma, true,
4644 )
4645 .expect("row primary evaluation")
4646 .0
4647 }
4648
4649 fn latent_test_specs(n: usize, block_dims: &[(&str, usize)]) -> Vec<ParameterBlockSpec> {
4650 block_dims
4651 .iter()
4652 .map(|(name, p)| ParameterBlockSpec {
4653 name: (*name).to_string(),
4654 design: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((n, *p)))),
4655 offset: Array1::zeros(n),
4656 penalties: Vec::new(),
4657 nullspace_dims: Vec::new(),
4658 initial_log_lambdas: Array1::zeros(0),
4659 initial_beta: None,
4660 gauge_priority: 100,
4661 jacobian_callback: None,
4662 stacked_design: None,
4663 stacked_offset: None,
4664 })
4665 .collect()
4666 }
4667
    /// Two-row latent-binary fixture with fixed `latent_sd = 0.35`.
    ///
    /// Its entry/exit time designs and mean design match `learnable_sigma_test_family`.
    /// There is no derivative or right-time design.
    fn fixed_sigma_binary_test_family() -> LatentBinaryFamily {
        LatentBinaryFamily {
            event_target: array![1u8, 0u8],
            weights: array![1.0, 0.7],
            latent_sd: 0.35,
            hazard_loading: HazardLoading::LoadedVsUnloaded,
            unloaded_mass_entry: array![0.02, 0.03],
            unloaded_mass_exit: array![0.05, 0.08],
            x_time_entry: array![[1.0, -0.2], [0.4, 0.7]],
            x_time_exit: array![[1.3, 0.1], [0.9, 1.0]],
            x_mean: DesignMatrix::Dense(DenseDesignMatrix::from(array![[1.0, -0.3], [0.2, 0.9]])),
            time_linear_constraints: None,
            quadctx: Arc::new(QuadratureContext::new()),
        }
    }
4683
4684 fn latent_binary_states_from_joint_beta(
4685 family: &LatentBinaryFamily,
4686 joint_beta: &Array1<f64>,
4687 ) -> Vec<ParameterBlockState> {
4688 let slices = family.joint_slices();
4689 let n = family.event_target.len();
4690 let beta_time = joint_beta.slice(s![slices.time.clone()]).to_owned();
4691 let beta_mean = joint_beta.slice(s![slices.mean.clone()]).to_owned();
4692
4693 let mut eta_time = Array1::<f64>::zeros(3 * n);
4694 eta_time
4695 .slice_mut(s![0..n])
4696 .assign(&gam_linalg::faer_ndarray::fast_av(
4697 &family.x_time_entry,
4698 &beta_time,
4699 ));
4700 eta_time
4701 .slice_mut(s![n..2 * n])
4702 .assign(&gam_linalg::faer_ndarray::fast_av(
4703 &family.x_time_exit,
4704 &beta_time,
4705 ));
4706
4707 vec![
4708 ParameterBlockState {
4709 beta: beta_time,
4710 eta: eta_time,
4711 },
4712 ParameterBlockState {
4713 beta: beta_mean.clone(),
4714 eta: family.x_mean.dot(&beta_mean),
4715 },
4716 ]
4717 }
4718
4719 use crate::survival::location_scale::{TimeBlockInput, TimeBlockMonotonicity};
4722
4723 fn validation_time_block(n: usize, p_time: usize) -> TimeBlockInput {
4727 let design = |fill: f64| {
4728 DesignMatrix::Dense(DenseDesignMatrix::from(Array2::from_elem(
4729 (n, p_time),
4730 fill,
4731 )))
4732 };
4733 TimeBlockInput {
4734 design_entry: design(0.1),
4735 design_exit: design(0.2),
4736 design_derivative_exit: design(0.3),
4737 offset_entry: Array1::zeros(n),
4738 offset_exit: Array1::zeros(n),
4739 derivative_offset_exit: Array1::zeros(n),
4740 time_monotonicity: TimeBlockMonotonicity::EnforcedByCoordinateCone,
4741 penalties: Vec::new(),
4742 nullspace_dims: Vec::new(),
4743 initial_log_lambdas: None,
4744 initial_beta: None,
4745 }
4746 }
4747
4748 fn empty_meanspec() -> TermCollectionSpec {
4749 TermCollectionSpec {
4750 linear_terms: Vec::new(),
4751 random_effect_terms: Vec::new(),
4752 smooth_terms: Vec::new(),
4753 }
4754 }
4755
4756 fn valid_survival_spec(n: usize, p_time: usize) -> LatentSurvivalTermSpec {
4759 LatentSurvivalTermSpec {
4760 age_entry: Array1::zeros(n),
4761 age_exit: Array1::from_elem(n, 1.0),
4762 event_target: Array1::from_shape_fn(n, |i| (i % 2) as u8),
4763 weights: Array1::from_elem(n, 1.0),
4764 derivative_guard: 0.0,
4765 time_block: validation_time_block(n, p_time),
4766 time_design_right: None,
4767 time_offset_right: None,
4768 unloaded_mass_entry: Array1::from_elem(n, 0.01),
4769 unloaded_mass_exit: Array1::from_elem(n, 0.05),
4770 unloaded_mass_right: Array1::zeros(0),
4771 unloaded_hazard_exit: Array1::from_elem(n, 0.02),
4772 meanspec: empty_meanspec(),
4773 mean_offset: Array1::zeros(n),
4774 }
4775 }
4776
4777 fn valid_binary_spec(n: usize, p_time: usize) -> LatentBinaryTermSpec {
4780 LatentBinaryTermSpec {
4781 age_entry: Array1::zeros(n),
4782 age_exit: Array1::from_elem(n, 1.0),
4783 event_target: Array1::from_shape_fn(n, |i| (i % 2) as u8),
4784 weights: Array1::from_elem(n, 1.0),
4785 derivative_guard: 0.0,
4786 time_block: validation_time_block(n, p_time),
4787 unloaded_mass_entry: Array1::from_elem(n, 0.01),
4788 unloaded_mass_exit: Array1::from_elem(n, 0.05),
4789 meanspec: empty_meanspec(),
4790 mean_offset: Array1::zeros(n),
4791 }
4792 }
4793
4794 fn loaded_frailty() -> FrailtySpec {
4795 FrailtySpec::HazardMultiplier {
4796 sigma_fixed: Some(0.3),
4797 loading: HazardLoading::LoadedVsUnloaded,
4798 }
4799 }
4800
/// Parity check between `validate_latent_survival_inputs` and
/// `validate_latent_binary_inputs`.
///
/// Each shared rejection must fire in both validators with a
/// model-specific message prefix:
/// - empty data;
/// - size mismatch;
/// - invalid unloaded decomposition;
/// - invalid event target.
///
/// Only the survival validator accepts a learnable (unfixed) sigma. Only
/// its size-mismatch message mentions the unloaded hazard.
#[test]
fn latent_interval_validation_parity_across_models() {
    let (n, p_time) = (2, 2);
    let frailty = loaded_frailty();
    let data = Array2::<f64>::zeros((n, 3));
    let rows = data.view();

    // Happy path: both validators accept well-formed specs and echo sigma.
    let surv_sigma =
        validate_latent_survival_inputs(rows, &valid_survival_spec(n, p_time), &frailty)
            .expect("valid survival spec must validate");
    assert_eq!(surv_sigma, Some(0.3));
    let bin_sigma = validate_latent_binary_inputs(rows, &valid_binary_spec(n, p_time), &frailty)
        .expect("valid binary spec must validate");
    assert_eq!(bin_sigma, 0.3);

    // Empty datasets are rejected by both validators.
    let empty = Array2::<f64>::zeros((0, 3));
    let no_rows = empty.view();
    let surv_empty =
        validate_latent_survival_inputs(no_rows, &valid_survival_spec(n, p_time), &frailty)
            .expect_err("empty data must be rejected");
    assert_eq!(
        surv_empty.to_string(),
        "latent-survival requires a non-empty dataset"
    );
    let bin_empty =
        validate_latent_binary_inputs(no_rows, &valid_binary_spec(n, p_time), &frailty)
            .expect_err("empty data must be rejected");
    assert_eq!(
        bin_empty.to_string(),
        "latent-binary requires a non-empty dataset"
    );

    // Size mismatches: only the survival message lists the unloaded hazard.
    let mut surv_bad = valid_survival_spec(n, p_time);
    surv_bad.weights = Array1::from_elem(n + 1, 1.0);
    let surv_msg = validate_latent_survival_inputs(rows, &surv_bad, &frailty)
        .expect_err("size mismatch must be rejected")
        .to_string();
    assert!(
        surv_msg.starts_with("latent-survival size mismatch")
            && surv_msg.contains("unloaded_hazard="),
        "survival size-mismatch message must include unloaded_hazard: {surv_msg}"
    );
    let mut bin_bad = valid_binary_spec(n, p_time);
    bin_bad.weights = Array1::from_elem(n + 1, 1.0);
    let bin_msg = validate_latent_binary_inputs(rows, &bin_bad, &frailty)
        .expect_err("size mismatch must be rejected")
        .to_string();
    assert!(
        bin_msg.starts_with("latent-binary size mismatch")
            && !bin_msg.contains("unloaded_hazard"),
        "binary size-mismatch message must omit unloaded_hazard: {bin_msg}"
    );

    // Invalid unloaded decompositions; index 0 is reported as "row 1".
    let mut surv_neg_hazard = valid_survival_spec(n, p_time);
    surv_neg_hazard.unloaded_hazard_exit[0] = -1.0;
    let surv_decomp = validate_latent_survival_inputs(rows, &surv_neg_hazard, &frailty)
        .expect_err("negative unloaded hazard must be rejected");
    assert_eq!(
        surv_decomp.to_string(),
        "latent-survival row 1 has invalid unloaded hazard decomposition: entry_mass=0.01, exit_mass=0.05, exit_hazard=-1"
    );
    let mut bin_bad_mass = valid_binary_spec(n, p_time);
    bin_bad_mass.unloaded_mass_exit[0] = 0.0;
    let bin_decomp = validate_latent_binary_inputs(rows, &bin_bad_mass, &frailty)
        .expect_err("non-monotone unloaded mass must be rejected");
    assert_eq!(
        bin_decomp.to_string(),
        "latent-binary row 1 has invalid unloaded mass decomposition: entry_mass=0.01, exit_mass=0"
    );

    // Event targets outside {0, 1}; index 1 is reported as "row 2".
    let mut surv_event = valid_survival_spec(n, p_time);
    surv_event.event_target[1] = 7;
    let surv_event_err = validate_latent_survival_inputs(rows, &surv_event, &frailty)
        .expect_err("invalid event target must be rejected");
    assert_eq!(
        surv_event_err.to_string(),
        "latent-survival row 2 has invalid event target 7; expected 0 or 1"
    );
    let mut bin_event = valid_binary_spec(n, p_time);
    bin_event.event_target[1] = 7;
    let bin_event_err = validate_latent_binary_inputs(rows, &bin_event, &frailty)
        .expect_err("invalid event target must be rejected");
    assert_eq!(
        bin_event_err.to_string(),
        "latent-binary row 2 has invalid event target 7; expected 0 or 1"
    );

    // A learnable sigma is survival-only.
    let learnable = FrailtySpec::HazardMultiplier {
        sigma_fixed: None,
        loading: HazardLoading::LoadedVsUnloaded,
    };
    let surv_learnable =
        validate_latent_survival_inputs(rows, &valid_survival_spec(n, p_time), &learnable)
            .expect("survival accepts a learnable latent scale");
    assert_eq!(surv_learnable, None);
    let bin_learnable =
        validate_latent_binary_inputs(rows, &valid_binary_spec(n, p_time), &learnable)
            .expect_err("binary requires a fixed latent scale");
    assert_eq!(
        bin_learnable.to_string(),
        "latent-binary currently requires a fixed hazard-multiplier sigma"
    );

    // A time-block design with an extra column is rejected.
    let mut surv_time_bad = valid_survival_spec(n, p_time);
    let too_wide = Array2::from_elem((n, p_time + 1), 0.1);
    surv_time_bad.time_block.design_entry =
        DesignMatrix::Dense(DenseDesignMatrix::from(too_wide));
    let surv_time_err = validate_latent_survival_inputs(rows, &surv_time_bad, &frailty)
        .expect_err("time block column mismatch must be rejected");
    assert!(
        surv_time_err
            .to_string()
            .starts_with("latent-survival time block column mismatch"),
        "unexpected survival time-block message: {surv_time_err}"
    );
}
4957
/// `coefficient_hessian_cost` for the learnable-sigma survival family must
/// equal the fully joint cost `n * p_total^2`. That cost is strictly larger
/// than the block-diagonal sum `n * sum(p_block^2)`.
#[test]
fn latent_survival_coefficient_cost_uses_joint_coupled_formula() {
    let family = learnable_sigma_test_family();
    let n_rows = family.event_target.len();
    let block_dims = [("time", 2usize), ("mean", 2usize), ("log_sigma", 1usize)];
    // Use the shared placeholder builder. It produces exactly the zero-design,
    // penalty-free specs this test needs, so they are not spelled out by hand.
    let specs = latent_test_specs(n_rows, &block_dims);

    let n = n_rows as u64;
    let p_total: u64 = block_dims.iter().map(|&(_, p)| p as u64).sum();
    let expected_joint = n * p_total * p_total;
    let expected_block_diag: u64 = block_dims
        .iter()
        .map(|&(_, p)| n * (p as u64) * (p as u64))
        .sum();
    assert_eq!(family.coefficient_hessian_cost(&specs), expected_joint);
    assert!(expected_joint > expected_block_diag);
}
5029
/// With 50,001 rows, `custom_family_outer_derivatives` must still report an
/// analytic gradient and an `Either` Hessian form for both the
/// learnable-sigma survival family and the fixed-sigma binary family.
#[test]
fn latent_family_planner_keeps_outer_hessian_at_large_n() {
    use crate::custom_family::custom_family_outer_derivatives;
    use gam_problem::{DeclaredHessianForm, Derivative};

    const LARGE_N: usize = 50_001;
    let options = BlockwiseFitOptions::default();

    let survival_family = learnable_sigma_test_family();
    let survival_blocks =
        latent_test_specs(LARGE_N, &[("time", 2), ("mean", 2), ("log_sigma", 1)]);
    let survival_plan =
        custom_family_outer_derivatives(&survival_family, &survival_blocks, &options);
    assert_eq!(survival_plan.0, Derivative::Analytic);
    assert_eq!(survival_plan.1, DeclaredHessianForm::Either);

    let binary_family = fixed_sigma_binary_test_family();
    let binary_blocks = latent_test_specs(LARGE_N, &[("time", 2), ("mean", 2)]);
    let binary_plan = custom_family_outer_derivatives(&binary_family, &binary_blocks, &options);
    assert_eq!(binary_plan.0, Derivative::Analytic);
    assert_eq!(binary_plan.1, DeclaredHessianForm::Either);
}
5053
/// Both latent families must report `levenberg_on_ill_conditioning() == true`
/// (see #1108).
#[test]
fn latent_families_arm_self_vanishing_levenberg_on_ill_conditioning() {
    let survival_armed = learnable_sigma_test_family().levenberg_on_ill_conditioning();
    assert!(
        survival_armed,
        "LatentSurvivalFamily must arm the self-vanishing Levenberg floor so the \
         indefinite interval-censored joint Hessian converges (see #1108)"
    );

    let binary_armed = fixed_sigma_binary_test_family().levenberg_on_ill_conditioning();
    assert!(
        binary_armed,
        "LatentBinaryFamily must arm the self-vanishing Levenberg floor on its \
         constrained coupled time block (see #1108)"
    );
}
5079
/// Checks the exact joint Hessian of `LatentBinaryFamily` three ways:
/// - each column against the negated central difference of the joint
///   gradient (step 1e-6);
/// - the workspace matvec against `H · d`;
/// - the workspace directional derivative against a central difference of
///   the dense Hessian along `d` (step 1e-5).
#[test]
fn latent_binary_exact_joint_hessian_and_workspace_matvec_match_fd() {
    let family = fixed_sigma_binary_test_family();
    let beta = array![0.15, 0.25, 0.1, -0.15];
    let states = latent_binary_states_from_joint_beta(&family, &beta);
    let grad_step = 1e-6;

    let analytic_hessian = family
        .exact_newton_joint_hessian(&states)
        .expect("analytic latent binary joint hessian evaluation")
        .expect("latent binary should expose exact joint hessian");

    // Joint gradient at `coefs`. `context` becomes the outer `expect`
    // message, so a failure reports which probe broke.
    let gradient_at = |coefs: &Array1<f64>, context: &str| {
        family
            .exact_newton_joint_gradient_evaluation(
                &latent_binary_states_from_joint_beta(&family, coefs),
                &[],
            )
            .expect(context)
            .expect("joint gradient should exist")
            .gradient
    };

    for col in 0..beta.len() {
        let mut forward = beta.clone();
        forward[col] += grad_step;
        let mut backward = beta.clone();
        backward[col] -= grad_step;
        let gradient_plus = gradient_at(&forward, "joint gradient plus");
        let gradient_minus = gradient_at(&backward, "joint gradient minus");

        // The analytic matrix is compared with minus the gradient's Jacobian.
        let fd_column = -((&gradient_plus - &gradient_minus) / (2.0 * grad_step));
        let analytic_column = analytic_hessian.column(col).to_owned();
        let rel = max_relative_array1(&analytic_column, &fd_column);
        assert!(
            rel < 5e-4,
            "latent binary joint Hessian column {col} mismatch: rel={rel}, analytic={analytic_column:?}, fd={fd_column:?}"
        );
    }

    let workspace = family
        .exact_newton_joint_hessian_workspace(&states, &[])
        .expect("latent binary hessian workspace")
        .expect("workspace should exist");
    let direction = array![0.4, -0.2, 0.3, 0.1];

    let hv = workspace
        .hessian_matvec(&direction)
        .expect("workspace matvec")
        .expect("workspace should support matvec");
    let dense_hv = analytic_hessian.dot(&direction);
    assert!(
        max_relative_array1(&hv, &dense_hv) < 1e-12,
        "latent binary workspace HVP mismatch: hv={hv:?}, dense={dense_hv:?}"
    );

    let dh = workspace
        .directional_derivative(&direction)
        .expect("workspace dH")
        .expect("workspace should support dH");
    let dh_step = 1e-5;
    let offset = dh_step * &direction;
    let h_plus = family
        .exact_newton_joint_hessian(&latent_binary_states_from_joint_beta(
            &family,
            &(&beta + &offset),
        ))
        .expect("hessian plus")
        .expect("hessian plus should exist");
    let h_minus = family
        .exact_newton_joint_hessian(&latent_binary_states_from_joint_beta(
            &family,
            &(&beta - &offset),
        ))
        .expect("hessian minus")
        .expect("hessian minus should exist");
    let fd_dh = (&h_plus - &h_minus) / (2.0 * dh_step);
    assert!(
        max_relative_array2(&dh, &fd_dh) < 2e-4,
        "latent binary workspace dH mismatch: dh={dh:?}, fd={fd_dh:?}"
    );
}
5164
/// Checks the log-sigma block of the learnable-sigma survival family.
///
/// The `BLOCK_LOG_SIGMA` working set from `evaluate` must be an
/// exact-Newton block with a dense Hessian. Its `[0]` gradient and `[0, 0]`
/// negative Hessian must equal the log-sigma entries of the joint
/// gradient/Hessian.
///
/// Those joint entries must also match central differences of
/// `log_likelihood_only` in the log-sigma coefficient:
/// - step 2e-4;
/// - relative tolerance 2e-3 for the gradient and 2e-2 for the negative
///   Hessian.
#[test]
fn latent_survival_learnable_sigma_block_matches_family_fd() {
    let family = learnable_sigma_test_family();
    let beta = learnable_sigma_test_joint_beta();
    let states = latent_survival_states_from_joint_beta(&family, &beta);
    let sigma_idx = family
        .joint_slices()
        .log_sigma
        .as_ref()
        .expect("learnable sigma test family should expose log_sigma")
        .start;
    let step = 2e-4;

    let eval = family
        .evaluate(&states)
        .expect("learnable latent survival evaluation");
    let joint_gradient = family
        .exact_newton_joint_gradient_evaluation(&states, &[])
        .expect("joint gradient evaluation")
        .expect("joint gradient should exist")
        .gradient;
    let joint_hessian = family
        .exact_newton_joint_hessian(&states)
        .expect("joint hessian evaluation")
        .expect("joint hessian should exist");
    assert_eq!(eval.blockworking_sets.len(), 3);

    let log_sigma_set = &eval.blockworking_sets[LatentSurvivalFamily::BLOCK_LOG_SIGMA];
    let (block_grad, block_neg_hess) = match log_sigma_set {
        BlockWorkingSet::ExactNewton { gradient, hessian } => match hessian {
            SymmetricMatrix::Dense(mat) => {
                let neg_hess = mat[[0, 0]];
                (gradient[0], neg_hess)
            }
            _ => panic!("log_sigma block should use a dense exact-Newton Hessian"),
        },
        _ => panic!("log_sigma block should use ExactNewton"),
    };

    assert!((block_grad - joint_gradient[sigma_idx]).abs() < 1e-12);
    assert!((block_neg_hess - joint_hessian[[sigma_idx, sigma_idx]]).abs() < 1e-12);

    // Log-likelihood with the log-sigma coefficient shifted by `delta`.
    let loglik_shifted = |delta: f64, context: &str| {
        let mut shifted = beta.clone();
        shifted[sigma_idx] += delta;
        family
            .log_likelihood_only(&latent_survival_states_from_joint_beta(&family, &shifted))
            .expect(context)
    };
    let ll_plus = loglik_shifted(step, "ll plus");
    let ll_0 = family.log_likelihood_only(&states).expect("ll base");
    let ll_minus = loglik_shifted(-step, "ll minus");

    let fd_grad = (ll_plus - ll_minus) / (2.0 * step);
    let fd_neg_hess = -(ll_plus - 2.0 * ll_0 + ll_minus) / (step * step);
    let rel_err = |analytic: f64, fd: f64, floor: f64| {
        (analytic - fd).abs() / analytic.abs().max(fd.abs()).max(floor)
    };

    let analytic_grad = joint_gradient[sigma_idx];
    assert!(
        rel_err(analytic_grad, fd_grad, 1e-12) < 2e-3,
        "family log_sigma grad={analytic_grad}, fd={fd_grad}"
    );
    let analytic_neg_hess = joint_hessian[[sigma_idx, sigma_idx]];
    assert!(
        rel_err(analytic_neg_hess, fd_neg_hess, 1e-10) < 2e-2,
        "family log_sigma neg_hess={analytic_neg_hess}, fd={fd_neg_hess}"
    );
}
5245
/// Each column of the learnable-sigma survival family's exact joint Hessian
/// must match the negated central difference of the joint gradient along
/// that coefficient. The step is 1e-6 and the maximum relative error must
/// stay below 5e-4.
#[test]
fn latent_survival_exact_joint_hessian_matches_gradient_fd() {
    let family = learnable_sigma_test_family();
    let beta = learnable_sigma_test_joint_beta();
    let states = latent_survival_states_from_joint_beta(&family, &beta);
    let h = 1e-6;

    let analytic_hessian = family
        .exact_newton_joint_hessian(&states)
        .expect("analytic joint hessian evaluation")
        .expect("latent survival should expose exact joint hessian");

    for j in 0..beta.len() {
        let mut beta_plus = beta.clone();
        beta_plus[j] += h;
        let gradient_plus = family
            .exact_newton_joint_gradient_evaluation(
                &latent_survival_states_from_joint_beta(&family, &beta_plus),
                &[],
            )
            .expect("joint gradient plus")
            .expect("joint gradient should exist")
            .gradient;

        let mut beta_minus = beta.clone();
        beta_minus[j] -= h;
        let gradient_minus = family
            .exact_newton_joint_gradient_evaluation(
                &latent_survival_states_from_joint_beta(&family, &beta_minus),
                &[],
            )
            .expect("joint gradient minus")
            .expect("joint gradient should exist")
            .gradient;

        // The analytic matrix is minus the gradient's Jacobian. Negate the
        // difference once and borrow it for both the comparison and the
        // failure message, instead of rebuilding it inside the message.
        let fd_column = -((&gradient_plus - &gradient_minus) / (2.0 * h));
        let analytic_column = analytic_hessian.column(j).to_owned();
        let rel = max_relative_array1(&analytic_column, &fd_column);
        assert!(
            rel < 5e-4,
            "joint Hessian column {j} mismatch: rel={rel}, analytic={analytic_column:?}, fd={fd_column:?}"
        );
    }
}
5291
/// The per-channel sums of `offset_channel_residuals` must match central
/// differences of the negative log-likelihood. Each difference shifts the
/// matching `n`-row slice of the time-block eta by a common offset:
/// - entry: rows `0..n`;
/// - exit: rows `n..2n`;
/// - derivative: rows `2n..3n`.
///
/// The step is 1e-6 and the tolerance is 1e-5 · max(|fd|, 1).
#[test]
fn latent_survival_offset_channel_residuals_match_finite_difference() {
    let family = survival_stress_test_family(24);
    let beta = survival_stress_test_joint_beta();
    let states = latent_survival_states_from_joint_beta(&family, &beta);
    let n = family.event_target.len();

    let residuals = family
        .offset_channel_residuals(&states)
        .expect("offset channel residuals");
    let sum_entry: f64 = residuals.entry.sum();
    let sum_exit: f64 = residuals.exit.sum();
    let sum_deriv: f64 = residuals.derivative.sum();

    // Negative log-likelihood after adding `delta` to the time-block eta
    // entries selected by `rows`.
    let neg_loglik_shifted = |rows: std::ops::Range<usize>, delta: f64| -> f64 {
        let mut shifted = states.clone();
        shifted[LatentSurvivalFamily::BLOCK_TIME]
            .eta
            .slice_mut(s![rows])
            .mapv_inplace(|v| v + delta);
        let (ll, _) = family
            .evaluate_exact_newton_joint_gradient_dense(&shifted)
            .expect("shifted joint gradient evaluation");
        -ll
    };

    let h = 1e-6;
    let central_fd = |rows: std::ops::Range<usize>| {
        (neg_loglik_shifted(rows.clone(), h) - neg_loglik_shifted(rows, -h)) / (2.0 * h)
    };
    let fd_entry = central_fd(0..n);
    let fd_exit = central_fd(n..2 * n);
    let fd_deriv = central_fd(2 * n..3 * n);

    let channels = [
        ("entry", sum_entry, fd_entry),
        ("exit", sum_exit, fd_exit),
        ("derivative", sum_deriv, fd_deriv),
    ];
    for (channel, analytic, fd) in channels {
        assert!(
            (analytic - fd).abs() <= 1e-5 * fd.abs().max(1.0),
            "{channel}-channel residual sum mismatch: analytic={analytic}, fd={fd}"
        );
    }
}
5349
/// Repeatability stress check on a 96-row survival fixture.
///
/// Each of these is evaluated twice on the same states, and the two runs
/// must be bitwise identical:
/// - the dense joint gradient;
/// - the dense joint Hessian;
/// - the first and second directional Hessian derivatives.
///
/// Every resulting matrix must also be finite and symmetric to 1e-12
/// relative error.
#[test]
fn latent_survival_exact_joint_parallel_stress_is_repeatable() {
    let family = survival_stress_test_family(96);
    let states =
        latent_survival_states_from_joint_beta(&family, &survival_stress_test_joint_beta());
    let direction_u = array![0.03, -0.02, 0.01, 0.04, -0.015, 0.025, -0.005, 0.02];
    let direction_v = array![-0.01, 0.035, -0.025, 0.015, 0.02, -0.01, 0.03, -0.015];

    // Gradient-only path.
    let [(ll_a, grad_a), (ll_b, grad_b)] = [
        "stress joint gradient evaluation",
        "repeat stress joint gradient evaluation",
    ]
    .map(|context| {
        family
            .evaluate_exact_newton_joint_gradient_dense(&states)
            .expect(context)
    });
    assert_eq!(ll_a.to_bits(), ll_b.to_bits());
    assert_eq!(grad_a, grad_b);

    // Full dense path: log-likelihood, gradient, and Hessian.
    let [(joint_ll_a, joint_grad_a, hess_a), (joint_ll_b, joint_grad_b, hess_b)] = [
        "stress joint dense evaluation",
        "repeat stress joint dense evaluation",
    ]
    .map(|context| family.evaluate_exact_newton_joint_dense(&states).expect(context));
    assert_eq!(joint_ll_a.to_bits(), joint_ll_b.to_bits());
    assert_eq!(joint_grad_a, joint_grad_b);
    assert_eq!(hess_a, hess_b);
    assert!(hess_a.iter().all(|value| value.is_finite()));
    assert!(max_relative_array2(&hess_a, &hess_a.t().to_owned()) < 1e-12);

    // First directional derivative of the Hessian along `direction_u`.
    let [dh_a, dh_b] = [
        "stress joint dH evaluation",
        "repeat stress joint dH evaluation",
    ]
    .map(|context| {
        family
            .exact_newton_joint_hessian_directional_derivative_dense(&states, &direction_u)
            .expect(context)
    });
    assert_eq!(dh_a, dh_b);
    assert!(dh_a.iter().all(|value| value.is_finite()));
    assert!(max_relative_array2(&dh_a, &dh_a.t().to_owned()) < 1e-12);

    // Second directional derivative along (`direction_u`, `direction_v`).
    let [d2h_a, d2h_b] = [
        "stress joint d2H evaluation",
        "repeat stress joint d2H evaluation",
    ]
    .map(|context| {
        family
            .exact_newton_joint_hessian_second_directional_derivative_dense(
                &states,
                &direction_u,
                &direction_v,
            )
            .expect(context)
    });
    assert_eq!(d2h_a, d2h_b);
    assert!(d2h_a.iter().all(|value| value.is_finite()));
    assert!(max_relative_array2(&d2h_a, &d2h_a.t().to_owned()) < 1e-12);
}
5407
/// The analytic directional derivative of the joint Hessian along a fixed
/// direction must match a central difference of the exact joint Hessian.
/// The step is 2e-4 and the Frobenius relative error must stay below 2e-3.
#[test]
fn latent_survival_exact_joint_dh_matches_hessian_fd() {
    let family = learnable_sigma_test_family();
    let beta = learnable_sigma_test_joint_beta();
    let states = latent_survival_states_from_joint_beta(&family, &beta);
    let step = 2e-4;
    let direction = array![0.07, -0.03, 0.05, 0.02, -0.04];

    let analytic = family
        .exact_newton_joint_hessian_directional_derivative(&states, &direction)
        .expect("analytic joint dH evaluation")
        .expect("latent survival should expose exact joint dH");

    // Exact joint Hessian at `coefs`, labelled by `context` on failure.
    let hessian_at = |coefs: Array1<f64>, context: &str| {
        family
            .exact_newton_joint_hessian(&latent_survival_states_from_joint_beta(&family, &coefs))
            .expect(context)
            .expect("joint hessian should exist")
    };
    let displacement = step * &direction;
    let hessian_plus = hessian_at(&beta + &displacement, "joint hessian plus");
    let hessian_minus = hessian_at(&beta - &displacement, "joint hessian minus");

    let fd = (&hessian_plus - &hessian_minus) / (2.0 * step);
    let rel = frobenius_relative_array2(&analytic, &fd);
    assert!(rel < 2e-3, "joint dH mismatch: rel={rel}");
}
5440
/// Checks the analytic second directional derivative of the joint Hessian.
///
/// It must be symmetric in its two directions to 1e-10 relative error. It
/// must also match a central difference, along `direction_v`, of the first
/// directional derivative along `direction_u`: step 5e-4, Frobenius
/// relative error below 2.5e-2.
#[test]
fn latent_survival_exact_joint_d2h_matches_directional_fd() {
    let family = learnable_sigma_test_family();
    let beta = learnable_sigma_test_joint_beta();
    let states = latent_survival_states_from_joint_beta(&family, &beta);
    let step = 5e-4;
    let direction_u = array![0.07, -0.03, 0.05, 0.02, -0.04];
    let direction_v = array![-0.02, 0.06, -0.01, 0.03, 0.05];

    let second_derivative = |first: &Array1<f64>, second: &Array1<f64>, context: &str| {
        family
            .exact_newton_joint_hessiansecond_directional_derivative(&states, first, second)
            .expect(context)
            .expect("latent survival should expose exact joint d2H")
    };
    let analytic = second_derivative(
        &direction_u,
        &direction_v,
        "analytic joint d2H evaluation",
    );
    let swapped = second_derivative(
        &direction_v,
        &direction_u,
        "swapped analytic joint d2H evaluation",
    );
    let symmetry_rel = max_relative_array2(&analytic, &swapped);
    assert!(
        symmetry_rel < 1e-10,
        "joint d2H should be symmetric in directions, got rel={symmetry_rel}"
    );

    // First directional derivative along `direction_u`, evaluated at `coefs`.
    let dh_along_u_at = |coefs: Array1<f64>, context: &str| {
        family
            .exact_newton_joint_hessian_directional_derivative(
                &latent_survival_states_from_joint_beta(&family, &coefs),
                &direction_u,
            )
            .expect(context)
            .expect("joint dH should exist")
    };
    let displacement = step * &direction_v;
    let dh_plus = dh_along_u_at(&beta + &displacement, "joint dH plus");
    let dh_minus = dh_along_u_at(&beta - &displacement, "joint dH minus");

    let fd = (&dh_plus - &dh_minus) / (2.0 * step);
    let rel = frobenius_relative_array2(&analytic, &fd);
    assert!(rel < 2.5e-2, "joint d2H mismatch: rel={rel}");
}
5497
5498 #[test]
5499 fn latent_survival_row_primary_derivatives_match_fd() {
5500 let quadctx = QuadratureContext::new();
5501 let row = LatentSurvivalRow::exact_event(0.35, 1.4, 0.1, 0.45, 0.8, 0.12);
5502 let primary = array![
5507 0.35f64.ln(),
5508 1.4f64.ln(),
5509 0.8,
5510 1.6f64.ln(),
5511 -0.2,
5512 0.4f64.ln()
5513 ];
5514 let sigma = primary[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA].exp();
5515 let h_grad = 1e-6;
5516 let h_hess = 2e-4;
5517
5518 let (_, gradient, neg_hessian) = latent_survival_row_primary_gradient_hessian(
5519 &quadctx,
5520 &row,
5521 primary[LATENT_SURVIVAL_PRIMARY_Q_ENTRY],
5522 primary[LATENT_SURVIVAL_PRIMARY_Q_EXIT],
5523 primary[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT],
5524 primary[LATENT_SURVIVAL_PRIMARY_Q_RIGHT],
5525 primary[LATENT_SURVIVAL_PRIMARY_MU],
5526 sigma,
5527 true,
5528 )
5529 .expect("analytic row primary gradient/hessian");
5530
5531 for j in 0..LATENT_SURVIVAL_PRIMARY_DIM {
5532 let mut plus = primary.clone();
5533 plus[j] += h_grad;
5534 let mut minus = primary.clone();
5535 minus[j] -= h_grad;
5536 let fd_grad = (latent_survival_row_loglik_from_primary(&quadctx, &row, &plus)
5537 - latent_survival_row_loglik_from_primary(&quadctx, &row, &minus))
5538 / (2.0 * h_grad);
5539 let rel_grad =
5540 (gradient[j] - fd_grad).abs() / gradient[j].abs().max(fd_grad.abs()).max(1e-12);
5541 assert!(
5542 rel_grad < 2e-4,
5543 "row primary grad[{j}] mismatch: analytic={}, fd={fd_grad}, rel={rel_grad}",
5544 gradient[j]
5545 );
5546
5547 for k in 0..LATENT_SURVIVAL_PRIMARY_DIM {
5548 let mut pp = primary.clone();
5549 pp[j] += h_hess;
5550 pp[k] += h_hess;
5551 let mut pm = primary.clone();
5552 pm[j] += h_hess;
5553 pm[k] -= h_hess;
5554 let mut mp = primary.clone();
5555 mp[j] -= h_hess;
5556 mp[k] += h_hess;
5557 let mut mm = primary.clone();
5558 mm[j] -= h_hess;
5559 mm[k] -= h_hess;
5560 let fd_neg_hess = -(latent_survival_row_loglik_from_primary(&quadctx, &row, &pp)
5561 - latent_survival_row_loglik_from_primary(&quadctx, &row, &pm)
5562 - latent_survival_row_loglik_from_primary(&quadctx, &row, &mp)
5563 + latent_survival_row_loglik_from_primary(&quadctx, &row, &mm))
5564 / (4.0 * h_hess * h_hess);
5565 let analytic = neg_hessian[[j, k]];
5566 let abs_err = (analytic - fd_neg_hess).abs();
5567 let rel = abs_err / analytic.abs().max(fd_neg_hess.abs()).max(1e-10);
5568 assert!(
5569 abs_err < 2e-5 || rel < 2e-3,
5570 "row primary neg_hess[{j},{k}] mismatch: analytic={analytic}, fd={fd_neg_hess}, abs_err={abs_err}, rel={rel}"
5571 );
5572 }
5573 }
5574 }
5575
5576 #[test]
5577 fn latent_survival_interval_row_primary_derivatives_match_fd() {
5578 let quadctx = QuadratureContext::new();
5590 let q_entry = -1.2_f64; let q_exit = -0.4_f64; let q_right = 0.5_f64; let mu = -0.15_f64;
5595 let log_sigma = 0.3_f64; let row = LatentSurvivalRow::interval_censored(
5599 q_entry.exp(), q_exit.exp(), q_right.exp(), 0.01, 0.02, 0.05, );
5606 assert!(matches!(
5607 row.event_type,
5608 LatentSurvivalEventType::IntervalCensored
5609 ));
5610
5611 let primary = array![q_entry, q_exit, 0.7, q_right, mu, log_sigma];
5615 let sigma = primary[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA].exp();
5616 let h_grad = 1e-6;
5617 let h_hess = 2e-4;
5618
5619 let (_, gradient, neg_hessian) = latent_survival_row_primary_gradient_hessian(
5620 &quadctx,
5621 &row,
5622 primary[LATENT_SURVIVAL_PRIMARY_Q_ENTRY],
5623 primary[LATENT_SURVIVAL_PRIMARY_Q_EXIT],
5624 primary[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT],
5625 primary[LATENT_SURVIVAL_PRIMARY_Q_RIGHT],
5626 primary[LATENT_SURVIVAL_PRIMARY_MU],
5627 sigma,
5628 true,
5629 )
5630 .expect("analytic interval row primary gradient/hessian");
5631
5632 let value = latent_survival_row_loglik_from_primary(&quadctx, &row, &primary);
5635 assert!(
5636 value.is_finite(),
5637 "interval row log-likelihood must be finite on a well-posed bracket, got {value}"
5638 );
5639
5640 for j in 0..LATENT_SURVIVAL_PRIMARY_DIM {
5641 let mut plus = primary.clone();
5642 plus[j] += h_grad;
5643 let mut minus = primary.clone();
5644 minus[j] -= h_grad;
5645 let fd_grad = (latent_survival_row_loglik_from_primary(&quadctx, &row, &plus)
5646 - latent_survival_row_loglik_from_primary(&quadctx, &row, &minus))
5647 / (2.0 * h_grad);
5648 let rel_grad =
5649 (gradient[j] - fd_grad).abs() / gradient[j].abs().max(fd_grad.abs()).max(1e-12);
5650 assert!(
5651 rel_grad < 2e-4,
5652 "interval row primary grad[{j}] mismatch: analytic={}, fd={fd_grad}, rel={rel_grad}",
5653 gradient[j]
5654 );
5655
5656 for k in 0..LATENT_SURVIVAL_PRIMARY_DIM {
5657 let mut pp = primary.clone();
5658 pp[j] += h_hess;
5659 pp[k] += h_hess;
5660 let mut pm = primary.clone();
5661 pm[j] += h_hess;
5662 pm[k] -= h_hess;
5663 let mut mp = primary.clone();
5664 mp[j] -= h_hess;
5665 mp[k] += h_hess;
5666 let mut mm = primary.clone();
5667 mm[j] -= h_hess;
5668 mm[k] -= h_hess;
5669 let fd_neg_hess = -(latent_survival_row_loglik_from_primary(&quadctx, &row, &pp)
5670 - latent_survival_row_loglik_from_primary(&quadctx, &row, &pm)
5671 - latent_survival_row_loglik_from_primary(&quadctx, &row, &mp)
5672 + latent_survival_row_loglik_from_primary(&quadctx, &row, &mm))
5673 / (4.0 * h_hess * h_hess);
5674 let analytic = neg_hessian[[j, k]];
5675 let abs_err = (analytic - fd_neg_hess).abs();
5676 let rel = abs_err / analytic.abs().max(fd_neg_hess.abs()).max(1e-10);
5677 assert!(
5678 abs_err < 5e-5 || rel < 3e-3,
5679 "interval row primary neg_hess[{j},{k}] mismatch: analytic={analytic}, fd={fd_neg_hess}, abs_err={abs_err}, rel={rel}"
5680 );
5681 }
5682 }
5683 }
5684}