1use crate::custom_family::{
21 BlockWorkingSet, BlockwiseFitOptions, CustomFamily, ExactNewtonJointGradientEvaluation,
22 ExactNewtonJointHessianWorkspace, FamilyEvaluation, ParameterBlockSpec, ParameterBlockState,
23 PenaltyMatrix, fit_custom_family, fit_custom_family_fixed_log_lambdas,
24};
25use crate::gamlss::{FamilyMetadata, ParameterLink};
26use crate::sigma_link::{exp_sigma_eta_for_sigma_scalar, exp_sigma_from_eta_scalar};
27use crate::survival::latent::interval::{
28 LatentFrailtyResolution, LatentIntervalModel, LatentIntervalRowView,
29 validate_latent_interval_inputs,
30};
31use crate::survival::location_scale::{
32 TimeBlockInput, project_onto_linear_constraints, structural_time_coefficient_constraints,
33};
34use crate::survival::lognormal_kernel::{
35 FrailtySpec, HazardLoading, LatentSurvivalEventType, LatentSurvivalRow, LatentSurvivalRowJet,
36 log_kernel_bundle,
37};
38use gam_linalg::matrix::{DenseDesignMatrix, DesignMatrix, SymmetricMatrix};
39use crate::model_types::UnifiedFitResult;
40use gam_solve::pirls::LinearInequalityConstraints;
41use crate::probability::signed_log_sum_exp;
42use crate::quadrature::{IntegratedExpectationMode, QuadratureContext};
43use gam_terms::smooth::{
44 TermCollectionDesign, TermCollectionSpec, build_term_collection_design,
45};
46use crate::fit_orchestration::drivers::freeze_term_collection_from_design;
47use gam_problem::MIN_WEIGHT;
48use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s};
49use std::collections::BTreeMap;
50use std::sync::Arc;
51
/// Errors raised while validating inputs for, or fitting, the latent survival
/// and latent binary families defined in this module.
///
/// Every variant carries a human-readable `reason`.
#[derive(Debug, Clone)]
pub enum LatentSurvivalError {
    /// The frailty specification is the wrong kind or carries an invalid sigma.
    InvalidFrailty { reason: String },
    /// Input data (lengths, design shapes, unloaded components, ...) failed validation.
    InvalidDataset { reason: String },
    /// Block states or designs do not have the expected count, length, or shape.
    BlockMismatch { reason: String },
    /// A numerical evaluation produced an invalid value (e.g. a non-finite sigma
    /// or a failed kernel evaluation).
    NumericalFailure { reason: String },
    /// The requested configuration (e.g. the time monotonicity strategy) is not supported.
    UnsupportedConfiguration { reason: String },
}

// Shared boilerplate for `reason`-carrying error enums. The `?` / `Into::into`
// conversions to `String` elsewhere in this module rely on it providing a
// conversion into `String`. NOTE(review): confirm against the macro definition.
impl_reason_error_boilerplate! {
    LatentSurvivalError {
        InvalidFrailty,
        InvalidDataset,
        BlockMismatch,
        NumericalFailure,
        UnsupportedConfiguration,
    }
}
88
89impl From<crate::block_layout::block_count::BlockCountMismatch> for LatentSurvivalError {
90 fn from(
91 err: crate::block_layout::block_count::BlockCountMismatch,
92 ) -> LatentSurvivalError {
93 LatentSurvivalError::BlockMismatch {
94 reason: err.message(),
95 }
96 }
97}
98
99impl From<String> for LatentSurvivalError {
100 fn from(reason: String) -> LatentSurvivalError {
106 LatentSurvivalError::InvalidDataset { reason }
107 }
108}
109
110pub const LATENT_SURVIVAL_EVENT_INTERVAL: u8 = u8::MAX;
116
117#[inline]
118fn latent_survival_event_type_for(code: u8) -> LatentSurvivalEventType {
119 match code {
120 0 => LatentSurvivalEventType::RightCensored,
121 LATENT_SURVIVAL_EVENT_INTERVAL => LatentSurvivalEventType::IntervalCensored,
122 _ => LatentSurvivalEventType::ExactEvent,
123 }
124}
125
/// Inputs for [`fit_latent_survival_terms`]: per-row survival data, the time
/// transform block, unloaded hazard components, and the mean term collection.
#[derive(Clone)]
pub struct LatentSurvivalTermSpec {
    /// Per-row entry ages.
    pub age_entry: Array1<f64>,
    /// Per-row exit ages.
    pub age_exit: Array1<f64>,
    /// Per-row event codes: `0` right-censored, `LATENT_SURVIVAL_EVENT_INTERVAL`
    /// interval-censored, anything else an exact event.
    pub event_target: Array1<u8>,
    /// Per-row observation weights.
    pub weights: Array1<f64>,
    /// Guard passed to the structural time-derivative constraints.
    pub derivative_guard: f64,
    /// Time-transform designs, offsets, penalties, and monotonicity strategy.
    pub time_block: TimeBlockInput,
    /// Optional design for the interval right endpoint. When `None`, the
    /// fitting code reuses the exit design.
    pub time_design_right: Option<DesignMatrix>,
    /// Optional offset for the interval right endpoint. When `None`, the
    /// fitting code uses zeros. When present, it must have one entry per row.
    pub time_offset_right: Option<Array1<f64>>,
    /// Per-row unloaded (frailty-independent) mass at entry.
    pub unloaded_mass_entry: Array1<f64>,
    /// Per-row unloaded mass at exit.
    pub unloaded_mass_exit: Array1<f64>,
    /// Per-row unloaded mass at the interval right endpoint. May be empty, in
    /// which case the fitting code treats it as zeros.
    pub unloaded_mass_right: Array1<f64>,
    /// Per-row unloaded hazard at exit.
    pub unloaded_hazard_exit: Array1<f64>,
    /// Specification of the mean (location) term collection.
    pub meanspec: TermCollectionSpec,
    /// Offset added to the mean block's linear predictor.
    pub mean_offset: Array1<f64>,
}
152
/// Output of [`fit_latent_survival_terms`].
pub struct LatentSurvivalTermFitResult {
    /// The fitted multi-block model.
    pub fit: UnifiedFitResult,
    /// The mean term-collection design used in the fit.
    pub design: TermCollectionDesign,
    /// The mean spec frozen from `design`.
    pub resolvedspec: TermCollectionSpec,
    /// The latent standard deviation: the fixed value, or the fitted learnable
    /// sigma when it was estimated.
    pub latent_sd: f64,
    /// Offset-channel residuals computed from the fitted block states.
    pub baseline_offset_residuals: crate::survival::OffsetChannelResiduals,
}
165
/// Inputs for [`fit_latent_binary_terms`]. This is the latent survival spec
/// without interval right-endpoint channels or an unloaded exit hazard.
#[derive(Clone)]
pub struct LatentBinaryTermSpec {
    /// Per-row entry ages.
    pub age_entry: Array1<f64>,
    /// Per-row exit ages.
    pub age_exit: Array1<f64>,
    /// Per-row event codes.
    pub event_target: Array1<u8>,
    /// Per-row observation weights.
    pub weights: Array1<f64>,
    /// Guard passed to the structural time-derivative constraints.
    pub derivative_guard: f64,
    /// Time-transform designs, offsets, penalties, and monotonicity strategy.
    pub time_block: TimeBlockInput,
    /// Per-row unloaded mass at entry.
    pub unloaded_mass_entry: Array1<f64>,
    /// Per-row unloaded mass at exit.
    pub unloaded_mass_exit: Array1<f64>,
    /// Specification of the mean (location) term collection.
    pub meanspec: TermCollectionSpec,
    /// Offset added to the mean block's linear predictor.
    pub mean_offset: Array1<f64>,
}
179
/// Output of [`fit_latent_binary_terms`].
pub struct LatentBinaryTermFitResult {
    /// The fitted two-block model.
    pub fit: UnifiedFitResult,
    /// The mean term-collection design used in the fit.
    pub design: TermCollectionDesign,
    /// The mean spec frozen from `design`.
    pub resolvedspec: TermCollectionSpec,
    /// Offset-channel residuals computed from the fitted block states.
    pub baseline_offset_residuals: crate::survival::OffsetChannelResiduals,
}
189
/// Dense, validated time-block ingredients produced by `prepare_latent_time_block`.
#[derive(Clone)]
struct PreparedLatentTimeBlock {
    /// Dense time design evaluated at entry.
    design_entry: Array2<f64>,
    /// Dense time design evaluated at exit.
    design_exit: Array2<f64>,
    /// Dense time-derivative design evaluated at exit.
    design_derivative_exit: Array2<f64>,
    /// Dense design for the interval right endpoint. It has the same shape as
    /// `design_exit` and is a copy of it when no right design was supplied.
    design_right: Array2<f64>,
    /// Structural coefficient constraints derived from the derivative design, if any.
    linear_constraints: Option<LinearInequalityConstraints>,
    /// Dense penalty matrices for the time block.
    penalties: Vec<Array2<f64>>,
    /// Resolved initial time coefficients, if any.
    initial_beta: Option<Array1<f64>>,
}
204
/// Custom family for latent (frailty) survival with right-, interval-, and
/// exact-event rows.
///
/// Block layout is time transform, mean, and optionally a learnable `log_sigma`
/// (see the `BLOCK_*` constants).
#[derive(Clone)]
pub struct LatentSurvivalFamily {
    /// Per-row event codes (see `latent_survival_event_type_for`).
    pub event_target: Array1<u8>,
    /// Per-row observation weights.
    pub weights: Array1<f64>,
    /// Fixed latent sigma, or `None` when a `log_sigma` block is learned.
    pub latent_sd_fixed: Option<f64>,
    /// How the frailty multiplies the hazard.
    pub hazard_loading: HazardLoading,
    /// Per-row unloaded mass at entry.
    pub unloaded_mass_entry: Array1<f64>,
    /// Per-row unloaded mass at exit.
    pub unloaded_mass_exit: Array1<f64>,
    /// Per-row unloaded hazard at exit.
    pub unloaded_hazard_exit: Array1<f64>,
    /// Dense time design at entry.
    pub x_time_entry: Array2<f64>,
    /// Dense time design at exit.
    pub x_time_exit: Array2<f64>,
    /// Dense time-derivative design at exit.
    pub x_time_derivative_exit: Array2<f64>,
    /// Dense time design at the interval right endpoint. Its column count must
    /// match the time coefficients (checked in `time_q_right`).
    pub x_time_right: Array2<f64>,
    /// Per-row offset added to the right-endpoint time predictor.
    pub time_offset_right: Array1<f64>,
    /// Per-row unloaded mass at the interval right endpoint.
    pub unloaded_mass_right: Array1<f64>,
    /// Mean (location) design.
    pub x_mean: DesignMatrix,
    /// Structural constraints on the time coefficients, if any.
    pub time_linear_constraints: Option<LinearInequalityConstraints>,
    /// Shared quadrature context used for kernel evaluation.
    pub quadctx: Arc<QuadratureContext>,
}
232
/// Custom family for the latent binary model. It uses a fixed latent sigma and
/// has two blocks: time transform and mean.
#[derive(Clone)]
pub struct LatentBinaryFamily {
    /// Per-row event codes.
    pub event_target: Array1<u8>,
    /// Per-row observation weights.
    pub weights: Array1<f64>,
    /// Fixed latent standard deviation.
    pub latent_sd: f64,
    /// How the frailty multiplies the hazard.
    pub hazard_loading: HazardLoading,
    /// Per-row unloaded mass at entry.
    pub unloaded_mass_entry: Array1<f64>,
    /// Per-row unloaded mass at exit.
    pub unloaded_mass_exit: Array1<f64>,
    /// Dense time design at entry.
    pub x_time_entry: Array2<f64>,
    /// Dense time design at exit.
    pub x_time_exit: Array2<f64>,
    /// Mean (location) design.
    pub x_mean: DesignMatrix,
    /// Structural constraints on the time coefficients, if any.
    pub time_linear_constraints: Option<LinearInequalityConstraints>,
    /// Shared quadrature context used for kernel evaluation.
    pub quadctx: Arc<QuadratureContext>,
}
247
248impl LatentSurvivalFamily {
249 pub const BLOCK_TIME: usize = 0;
250 pub const BLOCK_MEAN: usize = 1;
251 pub const BLOCK_LOG_SIGMA: usize = 2;
252
253 pub fn parameter_names() -> &'static [&'static str] {
254 &["time_transform", "mean"]
255 }
256
257 pub fn parameter_links() -> &'static [ParameterLink] {
258 &[ParameterLink::Identity, ParameterLink::Identity]
259 }
260
261 pub fn metadata() -> FamilyMetadata {
262 FamilyMetadata {
263 name: "latent_survival",
264 parameternames: Self::parameter_names(),
265 parameter_links: Self::parameter_links(),
266 }
267 }
268
269 fn split_time_eta<'a>(
270 &self,
271 block_states: &'a [ParameterBlockState],
272 ) -> Result<
273 (
274 ArrayView1<'a, f64>,
275 ArrayView1<'a, f64>,
276 ArrayView1<'a, f64>,
277 &'a Array1<f64>,
278 ),
279 LatentSurvivalError,
280 > {
281 let expected_blocks = if self.latent_sd_fixed.is_some() { 2 } else { 3 };
282 crate::block_layout::block_count::validate_block_count::<LatentSurvivalError>(
283 "LatentSurvivalFamily",
284 expected_blocks,
285 block_states.len(),
286 )?;
287 let n = self.event_target.len();
288 let eta_time = &block_states[Self::BLOCK_TIME].eta;
289 let eta_mean = &block_states[Self::BLOCK_MEAN].eta;
290 if eta_time.len() != 3 * n {
291 return Err(LatentSurvivalError::BlockMismatch {
292 reason: format!(
293 "latent survival time eta length mismatch: got {}, expected {}",
294 eta_time.len(),
295 3 * n
296 ),
297 });
298 }
299 if eta_mean.len() != n || self.weights.len() != n {
300 return Err(LatentSurvivalError::BlockMismatch {
301 reason: "latent survival mean eta dimension mismatch".to_string(),
302 });
303 }
304 Ok((
305 eta_time.slice(s![0..n]),
306 eta_time.slice(s![n..2 * n]),
307 eta_time.slice(s![2 * n..3 * n]),
308 eta_mean,
309 ))
310 }
311
312 fn time_q_right(
319 &self,
320 block_states: &[ParameterBlockState],
321 ) -> Result<Array1<f64>, LatentSurvivalError> {
322 let n = self.event_target.len();
323 let beta_time = &block_states[Self::BLOCK_TIME].beta;
324 if self.x_time_right.ncols() != beta_time.len() {
325 return Err(LatentSurvivalError::BlockMismatch {
326 reason: format!(
327 "latent survival interval right design has {} columns but time beta has {}",
328 self.x_time_right.ncols(),
329 beta_time.len()
330 ),
331 });
332 }
333 if self.x_time_right.nrows() != n || self.time_offset_right.len() != n {
334 return Err(LatentSurvivalError::BlockMismatch {
335 reason: "latent survival interval right design/offset row count mismatch"
336 .to_string(),
337 });
338 }
339 let mut q_right = self.x_time_right.dot(beta_time);
340 q_right += &self.time_offset_right;
341 Ok(q_right)
342 }
343
344 fn latent_sd(&self, block_states: &[ParameterBlockState]) -> Result<f64, LatentSurvivalError> {
345 if let Some(sigma) = self.latent_sd_fixed {
346 return Ok(sigma);
347 }
348 let eta = *block_states
349 .get(Self::BLOCK_LOG_SIGMA)
350 .and_then(|state| state.eta.get(0))
351 .ok_or_else(|| LatentSurvivalError::BlockMismatch {
352 reason: "latent survival learnable log_sigma block is missing".to_string(),
353 })?;
354 let sigma = exp_sigma_from_eta_scalar(eta);
355 if !(sigma.is_finite() && sigma > 0.0) {
356 return Err(LatentSurvivalError::NumericalFailure {
357 reason: format!(
358 "latent survival learnable sigma became invalid: log_sigma={eta}, sigma={sigma}"
359 ),
360 });
361 }
362 Ok(sigma)
363 }
364}
365
366impl LatentBinaryFamily {
367 pub const BLOCK_TIME: usize = 0;
368 pub const BLOCK_MEAN: usize = 1;
369
370 fn split_time_eta<'a>(
371 &self,
372 block_states: &'a [ParameterBlockState],
373 ) -> Result<(ArrayView1<'a, f64>, ArrayView1<'a, f64>, &'a Array1<f64>), LatentSurvivalError>
374 {
375 crate::block_layout::block_count::validate_block_count::<LatentSurvivalError>(
376 "LatentBinaryFamily",
377 2,
378 block_states.len(),
379 )?;
380 let n = self.event_target.len();
381 let eta_time = &block_states[Self::BLOCK_TIME].eta;
382 let eta_mean = &block_states[Self::BLOCK_MEAN].eta;
383 if eta_time.len() != 3 * n {
384 return Err(LatentSurvivalError::BlockMismatch {
385 reason: format!(
386 "latent binary time eta length mismatch: got {}, expected {}",
387 eta_time.len(),
388 3 * n
389 ),
390 });
391 }
392 if eta_mean.len() != n || self.weights.len() != n {
393 return Err(LatentSurvivalError::BlockMismatch {
394 reason: "latent binary mean eta dimension mismatch".to_string(),
395 });
396 }
397 Ok((
398 eta_time.slice(s![0..n]),
399 eta_time.slice(s![n..2 * n]),
400 eta_mean,
401 ))
402 }
403}
404
405pub fn fixed_latent_hazard_frailty(
406 frailty: &FrailtySpec,
407 context: &str,
408) -> Result<(f64, HazardLoading), String> {
409 fixed_latent_hazard_frailty_typed(frailty, context).map_err(Into::into)
410}
411
412fn fixed_latent_hazard_frailty_typed(
413 frailty: &FrailtySpec,
414 context: &str,
415) -> Result<(f64, HazardLoading), LatentSurvivalError> {
416 match frailty {
417 FrailtySpec::HazardMultiplier {
418 sigma_fixed: Some(sigma),
419 loading,
420 } if sigma.is_finite() && *sigma >= 0.0 => Ok((*sigma, *loading)),
421 FrailtySpec::HazardMultiplier {
422 sigma_fixed: Some(sigma),
423 ..
424 } => Err(LatentSurvivalError::InvalidFrailty {
425 reason: format!(
426 "{context} requires a finite fixed hazard-multiplier sigma >= 0, got {sigma}"
427 ),
428 }),
429 FrailtySpec::HazardMultiplier {
430 sigma_fixed: None, ..
431 } => Err(LatentSurvivalError::InvalidFrailty {
432 reason: format!("{context} currently requires a fixed hazard-multiplier sigma"),
433 }),
434 FrailtySpec::GaussianShift { .. } => Err(LatentSurvivalError::InvalidFrailty {
435 reason: format!("{context} requires HazardMultiplier frailty, not GaussianShift"),
436 }),
437 FrailtySpec::None => Err(LatentSurvivalError::InvalidFrailty {
438 reason: format!("{context} requires a fixed HazardMultiplier frailty specification"),
439 }),
440 }
441}
442
443pub fn latent_hazard_loading(
444 frailty: &FrailtySpec,
445 context: &str,
446) -> Result<HazardLoading, String> {
447 latent_hazard_loading_typed(frailty, context).map_err(Into::into)
448}
449
450fn latent_hazard_loading_typed(
451 frailty: &FrailtySpec,
452 context: &str,
453) -> Result<HazardLoading, LatentSurvivalError> {
454 match frailty {
455 FrailtySpec::HazardMultiplier { loading, .. } => Ok(*loading),
456 FrailtySpec::GaussianShift { .. } => Err(LatentSurvivalError::InvalidFrailty {
457 reason: format!("{context} requires HazardMultiplier frailty, not GaussianShift"),
458 }),
459 FrailtySpec::None => Err(LatentSurvivalError::InvalidFrailty {
460 reason: format!("{context} requires a HazardMultiplier frailty specification"),
461 }),
462 }
463}
464
/// Per-row first- and (negated) second-order contributions with respect to the
/// entry and exit time predictors.
///
/// NOTE(review): presumably derivatives of the row log-likelihood; the consumer
/// is not visible in this chunk, so confirm.
#[derive(Clone, Copy)]
struct LatentSurvivalTimeJet {
    /// First derivative with respect to the entry time predictor.
    grad_entry: f64,
    /// First derivative with respect to the exit time predictor.
    grad_exit: f64,
    /// Negated second derivative with respect to the entry time predictor.
    neg_hess_entry: f64,
    /// Negated second derivative with respect to the exit time predictor.
    neg_hess_exit: f64,
}
472
473pub fn fit_latent_survival_terms(
474 data: ArrayView2<'_, f64>,
475 spec: LatentSurvivalTermSpec,
476 frailty: FrailtySpec,
477 options: &BlockwiseFitOptions,
478) -> Result<LatentSurvivalTermFitResult, String> {
479 let latent_sd = validate_latent_survival_inputs(data, &spec, &frailty)?;
480 let hazard_loading = latent_hazard_loading(&frailty, "latent-survival")?;
481 let mean_design =
482 build_term_collection_design(data, &spec.meanspec).map_err(|e| e.to_string())?;
483 let resolvedspec = freeze_term_collection_from_design(&spec.meanspec, &mean_design)
484 .map_err(|e| e.to_string())?;
485 let time_prepared = prepare_latent_time_block(
486 &spec.time_block,
487 spec.time_design_right.as_ref(),
488 spec.derivative_guard,
489 )?;
490
491 let n = spec.event_target.len();
492 let time_offset_right = match spec.time_offset_right.as_ref() {
493 Some(offset) => {
494 if offset.len() != n {
495 return Err(format!(
496 "latent survival interval right time offset must have length {n}, got {}",
497 offset.len()
498 ));
499 }
500 offset.clone()
501 }
502 None => Array1::zeros(n),
503 };
504 let unloaded_mass_right = if spec.unloaded_mass_right.is_empty() {
505 Array1::zeros(n)
506 } else {
507 if spec.unloaded_mass_right.len() != n {
508 return Err(format!(
509 "latent survival interval right unloaded mass must have length {n}, got {}",
510 spec.unloaded_mass_right.len()
511 ));
512 }
513 spec.unloaded_mass_right.clone()
514 };
515
516 let family = LatentSurvivalFamily {
517 event_target: spec.event_target.clone(),
518 weights: spec.weights.clone(),
519 latent_sd_fixed: latent_sd,
520 hazard_loading,
521 unloaded_mass_entry: spec.unloaded_mass_entry.clone(),
522 unloaded_mass_exit: spec.unloaded_mass_exit.clone(),
523 unloaded_hazard_exit: spec.unloaded_hazard_exit.clone(),
524 x_time_entry: time_prepared.design_entry.clone(),
525 x_time_exit: time_prepared.design_exit.clone(),
526 x_time_derivative_exit: time_prepared.design_derivative_exit.clone(),
527 x_time_right: time_prepared.design_right.clone(),
528 time_offset_right,
529 unloaded_mass_right,
530 x_mean: mean_design.design.clone(),
531 time_linear_constraints: time_prepared.linear_constraints.clone(),
532 quadctx: Arc::new(QuadratureContext::new()),
533 };
534
535 let mut blocks = vec![
536 build_time_blockspec(&time_prepared, &spec.time_block),
537 build_mean_blockspec(&mean_design, spec.mean_offset.clone()),
538 ];
539 if latent_sd.is_none() {
540 blocks.push(build_log_sigma_blockspec(
541 LEARNABLE_LATENT_SD_SEED,
542 mean_design.design.nrows(),
543 ));
544 }
545 let has_interval_rows = spec
570 .event_target
571 .iter()
572 .any(|&code| code == LATENT_SURVIVAL_EVENT_INTERVAL);
573 if has_interval_rows {
574 let warm_event_target = spec.event_target.mapv(|code| {
575 if code == LATENT_SURVIVAL_EVENT_INTERVAL {
576 0u8
577 } else {
578 code
579 }
580 });
581 let mut warm_family = family.clone();
582 warm_family.event_target = warm_event_target;
583 let warm_fit = fit_custom_family_fixed_log_lambdas(
587 &warm_family,
588 &blocks,
589 options,
590 None,
591 0,
592 None,
593 false,
594 )
595 .map_err(|e| {
596 format!(
597 "latent interval warm start: right-censored-at-L surrogate fit failed \
598 (so the interval fit cannot be safely warm-started; this surrogate is \
599 log-concave and should converge — investigate the surrogate, not the \
600 interval kernel): {e}"
601 )
602 })?;
603 let warm_beta_usable = warm_fit
604 .block_states
605 .iter()
606 .any(|s| s.beta.iter().all(|v| v.is_finite()) && s.beta.iter().any(|&v| v != 0.0));
607 if !warm_beta_usable {
608 return Err(
609 "latent interval warm start: right-censored-at-L surrogate returned a \
610 degenerate (non-finite or all-zero) β across every block; the warm start \
611 cannot seed the interval fit. This indicates the surrogate's time-block \
612 design is rank-deficient or the inner solve stalled at the seed — \
613 investigate the surrogate before retrying the interval fit."
614 .to_string(),
615 );
616 }
617 for (block, state) in blocks.iter_mut().zip(warm_fit.block_states.iter()) {
618 if state.beta.iter().all(|v| v.is_finite()) {
619 block.initial_beta = Some(state.beta.clone());
620 }
621 }
622 }
623 let fit = fit_custom_family(&family, &blocks, options).map_err(|e| e.to_string())?;
624 let latent_sd = family.latent_sd(&fit.block_states)?;
625 let baseline_offset_residuals = family.offset_channel_residuals(&fit.block_states)?;
626 Ok(LatentSurvivalTermFitResult {
627 fit,
628 design: mean_design,
629 resolvedspec,
630 latent_sd,
631 baseline_offset_residuals,
632 })
633}
634
635pub fn fit_latent_binary_terms(
636 data: ArrayView2<'_, f64>,
637 spec: LatentBinaryTermSpec,
638 frailty: FrailtySpec,
639 options: &BlockwiseFitOptions,
640) -> Result<LatentBinaryTermFitResult, String> {
641 let latent_sd = validate_latent_binary_inputs(data, &spec, &frailty)?;
642 let (_, hazard_loading) = fixed_latent_hazard_frailty(&frailty, "latent-binary")?;
643 let mean_design =
644 build_term_collection_design(data, &spec.meanspec).map_err(|e| e.to_string())?;
645 let resolvedspec = freeze_term_collection_from_design(&spec.meanspec, &mean_design)
646 .map_err(|e| e.to_string())?;
647 let time_prepared = prepare_latent_time_block(&spec.time_block, None, spec.derivative_guard)?;
648
649 let family = LatentBinaryFamily {
650 event_target: spec.event_target.clone(),
651 weights: spec.weights.clone(),
652 latent_sd,
653 hazard_loading,
654 unloaded_mass_entry: spec.unloaded_mass_entry.clone(),
655 unloaded_mass_exit: spec.unloaded_mass_exit.clone(),
656 x_time_entry: time_prepared.design_entry.clone(),
657 x_time_exit: time_prepared.design_exit.clone(),
658 x_mean: mean_design.design.clone(),
659 time_linear_constraints: time_prepared.linear_constraints.clone(),
660 quadctx: Arc::new(QuadratureContext::new()),
661 };
662
663 let blocks = vec![
664 build_time_blockspec(&time_prepared, &spec.time_block),
665 build_mean_blockspec(&mean_design, spec.mean_offset.clone()),
666 ];
667 let fit = fit_custom_family(&family, &blocks, options).map_err(|e| e.to_string())?;
668 let baseline_offset_residuals = family.offset_channel_residuals(&fit.block_states)?;
669 Ok(LatentBinaryTermFitResult {
670 fit,
671 design: mean_design,
672 resolvedspec,
673 baseline_offset_residuals,
674 })
675}
676
677struct LatentSurvivalModel;
683
684impl LatentIntervalModel for LatentSurvivalModel {
685 fn context() -> &'static str {
686 "latent-survival"
687 }
688
689 fn allows_interval() -> bool {
690 true
691 }
692
693 fn frailty_policy(
694 frailty: &FrailtySpec,
695 ) -> Result<LatentFrailtyResolution, LatentSurvivalError> {
696 match frailty {
697 FrailtySpec::HazardMultiplier {
698 sigma_fixed,
699 loading,
700 } => {
701 if let Some(sigma) = sigma_fixed
702 && (!sigma.is_finite() || *sigma < 0.0)
703 {
704 return Err(LatentSurvivalError::InvalidFrailty {
705 reason: format!(
706 "latent-survival requires a finite hazard-multiplier sigma >= 0, got {sigma}"
707 ),
708 });
709 }
710 Ok(LatentFrailtyResolution {
711 sigma: *sigma_fixed,
712 loading: *loading,
713 })
714 }
715 FrailtySpec::GaussianShift { .. } => Err(LatentSurvivalError::InvalidFrailty {
716 reason: "latent-survival requires HazardMultiplier frailty, not GaussianShift"
717 .to_string(),
718 }),
719 FrailtySpec::None => Err(LatentSurvivalError::InvalidFrailty {
720 reason: "latent-survival requires a HazardMultiplier frailty specification"
721 .to_string(),
722 }),
723 }
724 }
725}
726
727fn validate_latent_survival_inputs(
728 data: ArrayView2<'_, f64>,
729 spec: &LatentSurvivalTermSpec,
730 frailty: &FrailtySpec,
731) -> Result<Option<f64>, LatentSurvivalError> {
732 let row = LatentIntervalRowView {
733 frailty,
734 age_entry: &spec.age_entry,
735 age_exit: &spec.age_exit,
736 event_target: &spec.event_target,
737 weights: &spec.weights,
738 unloaded_mass_entry: &spec.unloaded_mass_entry,
739 unloaded_mass_exit: &spec.unloaded_mass_exit,
740 unloaded_hazard_exit: Some(&spec.unloaded_hazard_exit),
741 mean_offset: &spec.mean_offset,
742 derivative_guard: spec.derivative_guard,
743 time_block: &spec.time_block,
744 };
745 validate_latent_interval_inputs::<LatentSurvivalModel>(data, &row)
746}
747
748pub(crate) fn validate_unloaded_components_for_loading(
749 context: &str,
750 row_index: usize,
751 loading: HazardLoading,
752 unloaded_entry: f64,
753 unloaded_exit: f64,
754 unloaded_hazard: Option<f64>,
755) -> Result<(), LatentSurvivalError> {
756 match loading {
757 HazardLoading::Full => {
758 if unloaded_entry != 0.0
759 || unloaded_exit != 0.0
760 || unloaded_hazard.is_some_and(|hazard| hazard != 0.0)
761 {
762 return Err(LatentSurvivalError::InvalidDataset {
763 reason: format!(
764 "{context} row {} uses full hazard loading, so unloaded components must be exactly zero; got entry_mass={}, exit_mass={}, exit_hazard={}",
765 row_index + 1,
766 unloaded_entry,
767 unloaded_exit,
768 unloaded_hazard.unwrap_or(0.0)
769 ),
770 });
771 }
772 }
773 HazardLoading::LoadedVsUnloaded => {}
774 }
775 Ok(())
776}
777
778struct LatentBinaryModel;
785
786impl LatentIntervalModel for LatentBinaryModel {
787 fn context() -> &'static str {
788 "latent-binary"
789 }
790
791 fn frailty_policy(
792 frailty: &FrailtySpec,
793 ) -> Result<LatentFrailtyResolution, LatentSurvivalError> {
794 let (sigma, loading) = fixed_latent_hazard_frailty_typed(frailty, "latent-binary")?;
795 Ok(LatentFrailtyResolution {
796 sigma: Some(sigma),
797 loading,
798 })
799 }
800}
801
802fn validate_latent_binary_inputs(
803 data: ArrayView2<'_, f64>,
804 spec: &LatentBinaryTermSpec,
805 frailty: &FrailtySpec,
806) -> Result<f64, LatentSurvivalError> {
807 let row = LatentIntervalRowView {
808 frailty,
809 age_entry: &spec.age_entry,
810 age_exit: &spec.age_exit,
811 event_target: &spec.event_target,
812 weights: &spec.weights,
813 unloaded_mass_entry: &spec.unloaded_mass_entry,
814 unloaded_mass_exit: &spec.unloaded_mass_exit,
815 unloaded_hazard_exit: None,
816 mean_offset: &spec.mean_offset,
817 derivative_guard: spec.derivative_guard,
818 time_block: &spec.time_block,
819 };
820 validate_latent_interval_inputs::<LatentBinaryModel>(data, &row)?.ok_or_else(|| {
825 LatentSurvivalError::InvalidFrailty {
826 reason: "latent-binary requires a fixed latent sigma".to_string(),
827 }
828 })
829}
830
831fn prepare_latent_time_block(
832 input: &TimeBlockInput,
833 design_right: Option<&DesignMatrix>,
834 derivative_guard: f64,
835) -> Result<PreparedLatentTimeBlock, LatentSurvivalError> {
836 if !input.time_monotonicity.is_coordinate_cone() {
837 return Err(LatentSurvivalError::UnsupportedConfiguration {
838 reason: format!(
839 "latent survival requires a coordinate-cone monotonicity strategy; got {:?}",
840 input.time_monotonicity
841 ),
842 });
843 }
844 let design_entry = input
845 .design_entry
846 .try_to_dense_by_chunks("latent survival entry time design")?;
847 let design_exit = input
848 .design_exit
849 .try_to_dense_by_chunks("latent survival exit time design")?;
850 let design_derivative_exit = input
851 .design_derivative_exit
852 .try_to_dense_by_chunks("latent survival derivative time design")?;
853 let design_right = match design_right {
859 Some(matrix) => {
860 let dense =
861 matrix.try_to_dense_by_chunks("latent survival interval right time design")?;
862 if dense.nrows() != design_exit.nrows() || dense.ncols() != design_exit.ncols() {
863 return Err(LatentSurvivalError::InvalidDataset {
864 reason: format!(
865 "latent survival interval right time design must match exit design shape \
866 {:?}, got {:?}",
867 design_exit.dim(),
868 dense.dim()
869 ),
870 });
871 }
872 dense
873 }
874 None => design_exit.clone(),
875 };
876 let linear_constraints = structural_time_coefficient_constraints(
877 &input.design_derivative_exit,
878 &input.derivative_offset_exit,
879 derivative_guard,
880 )?;
881 let initial_beta = match linear_constraints.as_ref() {
882 Some(constraints) => Some(project_onto_linear_constraints(
887 design_exit.ncols(),
888 constraints,
889 input.initial_beta.as_ref(),
890 )?),
891 None => None,
892 };
893 Ok(PreparedLatentTimeBlock {
894 design_entry,
895 design_exit,
896 design_derivative_exit,
897 design_right,
898 linear_constraints,
899 penalties: input.penalties.clone(),
900 initial_beta,
901 })
902}
903
904fn stack_rows(blocks: &[&Array2<f64>]) -> Array2<f64> {
905 let ncols = blocks.first().map_or(0, |m| m.ncols());
906 let nrows = blocks.iter().map(|m| m.nrows()).sum();
907 let mut out = Array2::<f64>::zeros((nrows, ncols));
908 let mut row = 0usize;
909 for block in blocks {
910 let end = row + block.nrows();
911 out.slice_mut(s![row..end, ..]).assign(block);
912 row = end;
913 }
914 out
915}
916
917fn build_time_blockspec(
918 prepared: &PreparedLatentTimeBlock,
919 input: &TimeBlockInput,
920) -> ParameterBlockSpec {
921 let stacked_design = stack_rows(&[
933 &prepared.design_entry,
934 &prepared.design_exit,
935 &prepared.design_derivative_exit,
936 ]);
937 let stacked_offset = gam_linalg::utils::stack_offsets(&[
938 &input.offset_entry,
939 &input.offset_exit,
940 &input.derivative_offset_exit,
941 ]);
942 ParameterBlockSpec {
943 name: "time_transform".to_string(),
944 design: DesignMatrix::Dense(DenseDesignMatrix::from(Arc::new(
945 prepared.design_exit.clone(),
946 ))),
947 offset: input.offset_exit.clone(),
948 penalties: prepared
949 .penalties
950 .iter()
951 .cloned()
952 .map(PenaltyMatrix::Dense)
953 .collect(),
954 nullspace_dims: input.nullspace_dims.clone(),
955 initial_log_lambdas: input
956 .initial_log_lambdas
957 .clone()
958 .unwrap_or_else(|| Array1::zeros(prepared.penalties.len())),
959 initial_beta: prepared.initial_beta.clone(),
960 gauge_priority: 200,
966 jacobian_callback: None,
967 stacked_design: Some(DesignMatrix::Dense(DenseDesignMatrix::from(Arc::new(
968 stacked_design,
969 )))),
970 stacked_offset: Some(stacked_offset),
971 }
972}
973
974fn build_mean_blockspec(design: &TermCollectionDesign, offset: Array1<f64>) -> ParameterBlockSpec {
975 ParameterBlockSpec {
976 name: "mean".to_string(),
977 design: design.design.clone(),
978 offset,
979 penalties: design.penalties_as_penalty_matrix(),
980 nullspace_dims: design.nullspace_dims.clone(),
981 initial_log_lambdas: Array1::zeros(design.penalties.len()),
982 initial_beta: None,
983 gauge_priority: 150,
989 jacobian_callback: None,
990 stacked_design: None,
991 stacked_offset: None,
992 }
993}
994
995const LEARNABLE_LATENT_SD_SEED: f64 = 0.5;
1002
1003fn build_log_sigma_blockspec(initial_sigma: f64, n_obs: usize) -> ParameterBlockSpec {
1004 ParameterBlockSpec {
1005 name: "log_sigma".to_string(),
1006 design: DesignMatrix::Dense(DenseDesignMatrix::from(Arc::new(Array2::from_elem(
1016 (n_obs, 1),
1017 1.0,
1018 )))),
1019 offset: Array1::zeros(n_obs),
1020 penalties: vec![],
1021 nullspace_dims: vec![],
1022 initial_log_lambdas: Array1::zeros(0),
1023 initial_beta: Some(Array1::from_elem(
1024 1,
1025 exp_sigma_eta_for_sigma_scalar(initial_sigma),
1026 )),
1027 gauge_priority: 120,
1030 jacobian_callback: None,
1031 stacked_design: None,
1032 stacked_offset: None,
1033 }
1034}
1035
// Coordinate indices of the per-row "primary" quantities in the latent-survival
// likelihood. They mirror the fields of `LatentSurvivalPrimaryDirection`:
// entry/exit time predictors, derivative at exit, right-endpoint predictor,
// mean, and log sigma. `LATENT_SURVIVAL_PRIMARY_DIM` is the number of
// coordinates. NOTE(review): their consumers are outside this chunk.
const LATENT_SURVIVAL_PRIMARY_Q_ENTRY: usize = 0;
const LATENT_SURVIVAL_PRIMARY_Q_EXIT: usize = 1;
const LATENT_SURVIVAL_PRIMARY_QDOT_EXIT: usize = 2;
const LATENT_SURVIVAL_PRIMARY_Q_RIGHT: usize = 3;
const LATENT_SURVIVAL_PRIMARY_MU: usize = 4;
const LATENT_SURVIVAL_PRIMARY_LOG_SIGMA: usize = 5;
const LATENT_SURVIVAL_PRIMARY_DIM: usize = 6;
1049
1050use gam_math::jet_partitions::MultiDirJet as LatentMultiDirJet;
1051
/// Returns `ln(x)` followed by its first four derivatives in `x`:
/// `[ln x, 1/x, -1/x², 2/x³, -6/x⁴]`.
#[inline]
fn latent_unary_derivatives_log(x: f64) -> [f64; 5] {
    let square = x * x;
    let cube = square * x;
    let fourth = cube * x;
    let value = x.ln();
    [value, 1.0 / x, -1.0 / square, 2.0 / cube, -6.0 / fourth]
}
1074
/// One symbolic term in a sum whose log is evaluated by `latent_kernel_sum_log_jet`.
///
/// Its log magnitude is `ln|coeff| + q_exp·q + tau_exp·log_sigma_factor + ...`.
/// It contributes `qdot_power` powers of `qdot` and uses kernel order `k` from
/// `log_kernel_bundle`. NOTE(review): the rest of the evaluation lies beyond
/// this chunk; confirm the exact term form there.
#[derive(Clone, Copy, Debug)]
struct LatentKernelPrimaryTerm {
    /// Signed scalar coefficient.
    coeff: f64,
    /// Power of `exp(q)`.
    q_exp: usize,
    /// Power of `qdot`.
    qdot_power: usize,
    /// Multiplier on `log_sigma_factor` in the log magnitude.
    tau_exp: usize,
    /// Kernel order requested from `log_kernel_bundle`.
    k: usize,
}
1082}
1083
/// A direction in the kernel's primary coordinates. Used by
/// `latent_kernel_differentiate_terms`, where a zero component skips that
/// coordinate's differentiation rule.
#[derive(Clone, Copy, Debug)]
struct LatentKernelPrimaryDirection {
    /// Component along `q`.
    dq: f64,
    /// Component along `qdot`.
    dqd: f64,
    /// Component along `mu`.
    dmu: f64,
    /// Component along `tau`. NOTE(review): presumably log sigma; confirm.
    dtau: f64,
}
1091
/// A direction in the six per-row primary coordinates of the latent-survival
/// likelihood (see the `LATENT_SURVIVAL_PRIMARY_*` indices).
#[derive(Clone, Copy, Debug)]
struct LatentSurvivalPrimaryDirection {
    /// Component along the entry time predictor.
    dq_entry: f64,
    /// Component along the exit time predictor.
    dq_exit: f64,
    /// Component along the time derivative at exit.
    dqdot_exit: f64,
    /// Component along the interval right-endpoint time predictor.
    dq_right: f64,
    /// Component along the mean.
    dmu: f64,
    /// Component along log sigma.
    dlog_sigma: f64,
}
1101
/// Point at which a kernel term sum is evaluated.
#[derive(Clone, Copy, Debug)]
struct LatentKernelPrimaryState {
    /// Time predictor. `exp(q)` is passed to `log_kernel_bundle`.
    q: f64,
    /// Time derivative. It must be positive and finite whenever a term has
    /// `qdot_power > 0`.
    qdot: f64,
    /// Mean passed to `log_kernel_bundle`.
    mu: f64,
    /// Latent standard deviation passed to `log_kernel_bundle`.
    sigma: f64,
    /// Log factor multiplied by each term's `tau_exp` in the log magnitude.
    log_sigma_factor: f64,
}
1110
1111fn latent_kernel_accumulate_term(
1112 terms: &mut BTreeMap<(usize, usize, usize, usize), f64>,
1113 term: LatentKernelPrimaryTerm,
1114 scale: f64,
1115) {
1116 if scale == 0.0 || term.coeff == 0.0 {
1117 return;
1118 }
1119 *terms
1120 .entry((term.q_exp, term.qdot_power, term.tau_exp, term.k))
1121 .or_insert(0.0) += scale * term.coeff;
1122}
1123
/// Applies one directional derivative to a symbolic sum of kernel terms.
///
/// Each input term expands into zero or more output terms, according to the
/// non-zero components of `dir`. Contributions that share a
/// `(q_exp, qdot_power, tau_exp, k)` key are summed in a `BTreeMap`, so the
/// output is ordered by that key. Entries whose summed coefficient is exactly
/// zero are dropped.
///
/// NOTE(review): the rules below are consistent with terms of the form
/// `coeff · e^{q_exp·q} · qdot^{qdot_power} · e^{tau_exp·log_sigma_factor} · K_k`,
/// where `K_k` is a kernel whose derivative in `q` is `-e^q K_{k+1}`. Confirm
/// this against `log_kernel_bundle` before changing any coefficient.
/// The accumulation order affects floating-point rounding, so keep the
/// statement order as is.
fn latent_kernel_differentiate_terms(
    terms: &[LatentKernelPrimaryTerm],
    dir: LatentKernelPrimaryDirection,
) -> Vec<LatentKernelPrimaryTerm> {
    let mut out = BTreeMap::<(usize, usize, usize, usize), f64>::new();
    for term in terms {
        if dir.dq != 0.0 {
            // Power rule on the e^{q_exp·q} factor.
            if term.q_exp > 0 {
                latent_kernel_accumulate_term(&mut out, *term, dir.dq * term.q_exp as f64);
            }
            // The kernel's own q-dependence raises q_exp and k by one, with
            // the opposite sign.
            latent_kernel_accumulate_term(
                &mut out,
                LatentKernelPrimaryTerm {
                    q_exp: term.q_exp + 1,
                    k: term.k + 1,
                    ..*term
                },
                -dir.dq,
            );
        }
        if dir.dmu != 0.0 {
            // The mean direction scales by k and adds a raised (q_exp+1, k+1)
            // term with the opposite sign.
            if term.k > 0 {
                latent_kernel_accumulate_term(&mut out, *term, dir.dmu * term.k as f64);
            }
            latent_kernel_accumulate_term(
                &mut out,
                LatentKernelPrimaryTerm {
                    q_exp: term.q_exp + 1,
                    k: term.k + 1,
                    ..*term
                },
                -dir.dmu,
            );
        }
        if dir.dtau != 0.0 {
            // Power rule on the tau factor.
            if term.tau_exp > 0 {
                latent_kernel_accumulate_term(&mut out, *term, dir.dtau * term.tau_exp as f64);
            }
            // Three terms, each adding 2 to tau_exp, with coefficients
            // k², -(2k+1), and 1 on kernel orders k, k+1, and k+2.
            let kf = term.k as f64;
            latent_kernel_accumulate_term(
                &mut out,
                LatentKernelPrimaryTerm {
                    tau_exp: term.tau_exp + 2,
                    ..*term
                },
                dir.dtau * kf * kf,
            );
            latent_kernel_accumulate_term(
                &mut out,
                LatentKernelPrimaryTerm {
                    q_exp: term.q_exp + 1,
                    tau_exp: term.tau_exp + 2,
                    k: term.k + 1,
                    ..*term
                },
                -dir.dtau * (2.0 * kf + 1.0),
            );
            latent_kernel_accumulate_term(
                &mut out,
                LatentKernelPrimaryTerm {
                    q_exp: term.q_exp + 2,
                    tau_exp: term.tau_exp + 2,
                    k: term.k + 2,
                    ..*term
                },
                dir.dtau,
            );
        }
        // Power rule on qdot^{qdot_power}. Terms without qdot are constant in
        // this direction.
        if dir.dqd != 0.0 && term.qdot_power > 0 {
            latent_kernel_accumulate_term(
                &mut out,
                LatentKernelPrimaryTerm {
                    qdot_power: term.qdot_power - 1,
                    ..*term
                },
                dir.dqd * term.qdot_power as f64,
            );
        }
    }
    out.into_iter()
        .filter_map(|((q_exp, qdot_power, tau_exp, k), coeff)| {
            (coeff != 0.0).then_some(LatentKernelPrimaryTerm {
                coeff,
                q_exp,
                qdot_power,
                tau_exp,
                k,
            })
        })
        .collect()
}
1215
1216fn latent_kernel_term_lists_for_directions(
1217 base_terms: &[LatentKernelPrimaryTerm],
1218 directions: &[LatentKernelPrimaryDirection],
1219) -> Vec<Vec<LatentKernelPrimaryTerm>> {
1220 fn build_mask(
1221 mask: usize,
1222 base_terms: &[LatentKernelPrimaryTerm],
1223 directions: &[LatentKernelPrimaryDirection],
1224 cache: &mut [Option<Vec<LatentKernelPrimaryTerm>>],
1225 ) -> Vec<LatentKernelPrimaryTerm> {
1226 if let Some(existing) = &cache[mask] {
1227 return existing.clone();
1228 }
1229 let built = if mask == 0 {
1230 base_terms.to_vec()
1231 } else {
1232 let bit = 1usize << mask.trailing_zeros();
1233 let prev = build_mask(mask ^ bit, base_terms, directions, cache);
1234 latent_kernel_differentiate_terms(&prev, directions[bit.trailing_zeros() as usize])
1235 };
1236 cache[mask] = Some(built.clone());
1237 built
1238 }
1239
1240 let mut cache = vec![None; 1usize << directions.len()];
1241 (0..cache.len())
1242 .map(|mask| build_mask(mask, base_terms, directions, &mut cache))
1243 .collect()
1244}
1245
1246fn latent_kernel_sum_log_jet(
1247 quadctx: &QuadratureContext,
1248 base_terms: &[LatentKernelPrimaryTerm],
1249 state: LatentKernelPrimaryState,
1250 directions: &[LatentKernelPrimaryDirection],
1251 context: &str,
1252) -> Result<LatentMultiDirJet, LatentSurvivalError> {
1253 let term_lists = latent_kernel_term_lists_for_directions(base_terms, directions);
1254 let max_k = term_lists
1255 .iter()
1256 .flat_map(|terms| terms.iter().map(|term| term.k))
1257 .max()
1258 .unwrap_or(0);
1259 let bundle =
1260 log_kernel_bundle(quadctx, state.q.exp(), state.mu, state.sigma, max_k).map_err(|e| {
1261 LatentSurvivalError::NumericalFailure {
1262 reason: format!("{context} kernel evaluation failed: {e}"),
1263 }
1264 })?;
1265
1266 let evaluate_terms =
1267 |terms: &[LatentKernelPrimaryTerm]| -> Result<(f64, f64), LatentSurvivalError> {
1268 let mut log_mags = Vec::new();
1269 let mut signs = Vec::new();
1270 for term in terms {
1271 if term.coeff == 0.0 {
1272 continue;
1273 }
1274 if term.qdot_power > 0 && !(state.qdot.is_finite() && state.qdot > 0.0) {
1275 return Err(LatentSurvivalError::NumericalFailure {
1276 reason: format!(
1277 "{context} requires positive finite qdot for exact-event directional terms, got {}",
1278 state.qdot
1279 ),
1280 });
1281 }
1282 let log_qdot = if term.qdot_power > 0 {
1283 state.qdot.ln()
1284 } else {
1285 0.0
1286 };
1287 let log_mag = term.coeff.abs().ln()
1288 + term.q_exp as f64 * state.q
1289 + term.tau_exp as f64 * state.log_sigma_factor
1290 + term.qdot_power as f64 * log_qdot
1291 + bundle.get(term.k);
1292 log_mags.push(log_mag);
1293 signs.push(term.coeff.signum());
1294 }
1295 if log_mags.is_empty() {
1296 return Ok((f64::NEG_INFINITY, 0.0));
1297 }
1298 Ok(signed_log_sum_exp(&log_mags, &signs))
1299 };
1300
1301 let (base_log_sum, base_sign) = evaluate_terms(&term_lists[0])?;
1302 if !(base_log_sum.is_finite() && base_sign > 0.0) {
1303 return Err(LatentSurvivalError::NumericalFailure {
1304 reason: format!("{context} produced a non-positive signed kernel sum"),
1305 });
1306 }
1307
1308 let mut normalized = LatentMultiDirJet::constant(directions.len(), 1.0);
1309 for mask in 1..term_lists.len() {
1310 let (log_abs, sign) = evaluate_terms(&term_lists[mask])?;
1311 normalized.coeffs[mask] = if !log_abs.is_finite() || sign == 0.0 {
1312 0.0
1313 } else {
1314 sign * (log_abs - base_log_sum).exp()
1315 };
1316 }
1317
1318 let mut out = normalized.compose_unary(latent_unary_derivatives_log(1.0));
1319 out.coeffs[0] += base_log_sum;
1320 Ok(out)
1321}
1322
1323fn latent_survival_basis_direction(primary_idx: usize) -> LatentSurvivalPrimaryDirection {
1324 match primary_idx {
1325 LATENT_SURVIVAL_PRIMARY_Q_ENTRY => LatentSurvivalPrimaryDirection {
1326 dq_entry: 1.0,
1327 dq_exit: 0.0,
1328 dqdot_exit: 0.0,
1329 dq_right: 0.0,
1330 dmu: 0.0,
1331 dlog_sigma: 0.0,
1332 },
1333 LATENT_SURVIVAL_PRIMARY_Q_EXIT => LatentSurvivalPrimaryDirection {
1334 dq_entry: 0.0,
1335 dq_exit: 1.0,
1336 dqdot_exit: 0.0,
1337 dq_right: 0.0,
1338 dmu: 0.0,
1339 dlog_sigma: 0.0,
1340 },
1341 LATENT_SURVIVAL_PRIMARY_QDOT_EXIT => LatentSurvivalPrimaryDirection {
1342 dq_entry: 0.0,
1343 dq_exit: 0.0,
1344 dqdot_exit: 1.0,
1345 dq_right: 0.0,
1346 dmu: 0.0,
1347 dlog_sigma: 0.0,
1348 },
1349 LATENT_SURVIVAL_PRIMARY_Q_RIGHT => LatentSurvivalPrimaryDirection {
1350 dq_entry: 0.0,
1351 dq_exit: 0.0,
1352 dqdot_exit: 0.0,
1353 dq_right: 1.0,
1354 dmu: 0.0,
1355 dlog_sigma: 0.0,
1356 },
1357 LATENT_SURVIVAL_PRIMARY_MU => LatentSurvivalPrimaryDirection {
1358 dq_entry: 0.0,
1359 dq_exit: 0.0,
1360 dqdot_exit: 0.0,
1361 dq_right: 0.0,
1362 dmu: 1.0,
1363 dlog_sigma: 0.0,
1364 },
1365 LATENT_SURVIVAL_PRIMARY_LOG_SIGMA => LatentSurvivalPrimaryDirection {
1366 dq_entry: 0.0,
1367 dq_exit: 0.0,
1368 dqdot_exit: 0.0,
1369 dq_right: 0.0,
1370 dmu: 0.0,
1371 dlog_sigma: 1.0,
1372 },
1373 _ => std::panic::panic_any(format!(
1381 "latent survival primary index out of bounds: primary_idx={primary_idx}, primary_dim={LATENT_SURVIVAL_PRIMARY_DIM}"
1382 )),
1383 }
1384}
1385
1386fn latent_survival_map_entry_direction(
1387 direction: LatentSurvivalPrimaryDirection,
1388) -> LatentKernelPrimaryDirection {
1389 LatentKernelPrimaryDirection {
1390 dq: direction.dq_entry,
1391 dqd: 0.0,
1392 dmu: direction.dmu,
1393 dtau: direction.dlog_sigma,
1394 }
1395}
1396
1397fn latent_survival_map_exit_direction(
1398 direction: LatentSurvivalPrimaryDirection,
1399 event_type: LatentSurvivalEventType,
1400) -> LatentKernelPrimaryDirection {
1401 LatentKernelPrimaryDirection {
1402 dq: direction.dq_exit,
1403 dqd: if matches!(event_type, LatentSurvivalEventType::ExactEvent) {
1404 direction.dqdot_exit
1405 } else {
1406 0.0
1407 },
1408 dmu: direction.dmu,
1409 dtau: direction.dlog_sigma,
1410 }
1411}
1412
1413fn latent_survival_map_left_direction(
1417 direction: LatentSurvivalPrimaryDirection,
1418) -> LatentKernelPrimaryDirection {
1419 LatentKernelPrimaryDirection {
1420 dq: direction.dq_exit,
1421 dqd: 0.0,
1422 dmu: direction.dmu,
1423 dtau: direction.dlog_sigma,
1424 }
1425}
1426
1427fn latent_survival_map_right_direction(
1432 direction: LatentSurvivalPrimaryDirection,
1433) -> LatentKernelPrimaryDirection {
1434 LatentKernelPrimaryDirection {
1435 dq: direction.dq_right,
1436 dqd: 0.0,
1437 dmu: direction.dmu,
1438 dtau: direction.dlog_sigma,
1439 }
1440}
1441
1442fn latent_survival_row_primary_log_jet(
1443 quadctx: &QuadratureContext,
1444 row: &LatentSurvivalRow,
1445 q_entry: f64,
1446 q_exit: f64,
1447 qdot_exit: f64,
1448 q_right: f64,
1449 mu: f64,
1450 sigma: f64,
1451 log_sigma_factor: f64,
1452 directions: &[LatentSurvivalPrimaryDirection],
1453) -> Result<LatentMultiDirJet, String> {
1454 let entry_state = LatentKernelPrimaryState {
1455 q: q_entry,
1456 qdot: 1.0,
1457 mu,
1458 sigma,
1459 log_sigma_factor,
1460 };
1461 let entry_directions = directions
1462 .iter()
1463 .copied()
1464 .map(latent_survival_map_entry_direction)
1465 .collect::<Vec<_>>();
1466
1467 let denominator = latent_kernel_sum_log_jet(
1468 quadctx,
1469 &[LatentKernelPrimaryTerm {
1470 coeff: 1.0,
1471 q_exp: 0,
1472 qdot_power: 0,
1473 tau_exp: 0,
1474 k: 0,
1475 }],
1476 entry_state,
1477 &entry_directions,
1478 "latent survival denominator",
1479 )?;
1480
1481 let numerator = match row.event_type {
1486 LatentSurvivalEventType::RightCensored | LatentSurvivalEventType::ExactEvent => {
1487 let exit_state = LatentKernelPrimaryState {
1488 q: q_exit,
1489 qdot: qdot_exit,
1490 mu,
1491 sigma,
1492 log_sigma_factor,
1493 };
1494 let exit_directions = directions
1495 .iter()
1496 .copied()
1497 .map(|dir| latent_survival_map_exit_direction(dir, row.event_type))
1498 .collect::<Vec<_>>();
1499 let numerator_terms = match row.event_type {
1500 LatentSurvivalEventType::RightCensored => vec![LatentKernelPrimaryTerm {
1501 coeff: 1.0,
1502 q_exp: 0,
1503 qdot_power: 0,
1504 tau_exp: 0,
1505 k: 0,
1506 }],
1507 LatentSurvivalEventType::ExactEvent => {
1508 let mut terms = Vec::new();
1509 if row.hazard_unloaded > 0.0 {
1510 terms.push(LatentKernelPrimaryTerm {
1511 coeff: row.hazard_unloaded,
1512 q_exp: 0,
1513 qdot_power: 0,
1514 tau_exp: 0,
1515 k: 0,
1516 });
1517 }
1518 terms.push(LatentKernelPrimaryTerm {
1519 coeff: 1.0,
1520 q_exp: 1,
1521 qdot_power: 1,
1522 tau_exp: 0,
1523 k: 1,
1524 });
1525 terms
1526 }
1527 LatentSurvivalEventType::IntervalCensored => {
1528 return Err(
1533 "interval-censored row reached the single-state numerator branch; \
1534 it must take the dedicated two-state branch"
1535 .to_string(),
1536 );
1537 }
1538 };
1539 latent_kernel_sum_log_jet(
1540 quadctx,
1541 &numerator_terms,
1542 exit_state,
1543 &exit_directions,
1544 "latent survival numerator",
1545 )?
1546 }
1547 LatentSurvivalEventType::IntervalCensored => latent_survival_interval_numerator_log_jet(
1548 quadctx,
1549 row,
1550 q_exit,
1551 q_right,
1552 mu,
1553 sigma,
1554 log_sigma_factor,
1555 directions,
1556 )?,
1557 };
1558
1559 let mut total = numerator.add(&denominator.scale(-1.0));
1560 match row.event_type {
1566 LatentSurvivalEventType::IntervalCensored => {
1567 total.coeffs[0] += row.mass_unloaded_entry;
1568 }
1569 _ => {
1570 total.coeffs[0] += -row.mass_unloaded_exit + row.mass_unloaded_entry;
1571 }
1572 }
1573 Ok(total)
1574}
1575
1576fn latent_survival_interval_numerator_log_jet(
1600 quadctx: &QuadratureContext,
1601 row: &LatentSurvivalRow,
1602 q_exit: f64,
1603 q_right: f64,
1604 mu: f64,
1605 sigma: f64,
1606 log_sigma_factor: f64,
1607 directions: &[LatentSurvivalPrimaryDirection],
1608) -> Result<LatentMultiDirJet, String> {
1609 let single_k0 = [LatentKernelPrimaryTerm {
1610 coeff: 1.0,
1611 q_exp: 0,
1612 qdot_power: 0,
1613 tau_exp: 0,
1614 k: 0,
1615 }];
1616
1617 let left_state = LatentKernelPrimaryState {
1618 q: q_exit,
1619 qdot: 1.0,
1620 mu,
1621 sigma,
1622 log_sigma_factor,
1623 };
1624 let right_state = LatentKernelPrimaryState {
1625 q: q_right,
1626 qdot: 1.0,
1627 mu,
1628 sigma,
1629 log_sigma_factor,
1630 };
1631 let left_directions = directions
1632 .iter()
1633 .copied()
1634 .map(latent_survival_map_left_direction)
1635 .collect::<Vec<_>>();
1636 let right_directions = directions
1637 .iter()
1638 .copied()
1639 .map(latent_survival_map_right_direction)
1640 .collect::<Vec<_>>();
1641
1642 let log_left = latent_kernel_sum_log_jet(
1643 quadctx,
1644 &single_k0,
1645 left_state,
1646 &left_directions,
1647 "latent survival interval left boundary",
1648 )?;
1649 let log_right = latent_kernel_sum_log_jet(
1650 quadctx,
1651 &single_k0,
1652 right_state,
1653 &right_directions,
1654 "latent survival interval right boundary",
1655 )?;
1656
1657 let c_left = (-row.mass_unloaded_left).exp();
1661 let c_right = (-row.mass_unloaded_right).exp();
1662 let exp_left_value = log_left.coeff(0).exp();
1663 let exp_right_value = log_right.coeff(0).exp();
1664 let linear_left = log_left.compose_unary([exp_left_value; 5]).scale(c_left);
1665 let linear_right = log_right.compose_unary([exp_right_value; 5]).scale(c_right);
1666
1667 let linear_numerator = linear_left.add(&linear_right.scale(-1.0));
1668 let base = linear_numerator.coeff(0);
1669 if !(base.is_finite() && base > 0.0) {
1670 return Err(LatentSurvivalError::NumericalFailure {
1671 reason: format!(
1672 "latent survival interval numerator must be a positive survival-mass difference, \
1673 got c_L*K0(M_L) - c_R*K0(M_R) = {base}; require M_L < M_R (i.e. L < R)"
1674 ),
1675 }
1676 .into());
1677 }
1678 Ok(linear_numerator.compose_unary(latent_unary_derivatives_log(base)))
1684}
1685
1686fn latent_survival_row_primary_gradient_hessian(
1687 quadctx: &QuadratureContext,
1688 row: &LatentSurvivalRow,
1689 q_entry: f64,
1690 q_exit: f64,
1691 qdot_exit: f64,
1692 q_right: f64,
1693 mu: f64,
1694 sigma: f64,
1695 include_log_sigma: bool,
1696) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
1697 let log_sigma_factor = if sigma > 0.0 { sigma.ln() } else { 0.0 };
1698 let mut gradient = Array1::<f64>::zeros(LATENT_SURVIVAL_PRIMARY_DIM);
1699 let mut neg_hessian =
1700 Array2::<f64>::zeros((LATENT_SURVIVAL_PRIMARY_DIM, LATENT_SURVIVAL_PRIMARY_DIM));
1701 let active_primary = if include_log_sigma {
1702 LATENT_SURVIVAL_PRIMARY_DIM
1703 } else {
1704 LATENT_SURVIVAL_PRIMARY_LOG_SIGMA
1705 };
1706 let log_lik = latent_survival_row_primary_log_jet(
1707 quadctx,
1708 row,
1709 q_entry,
1710 q_exit,
1711 qdot_exit,
1712 q_right,
1713 mu,
1714 sigma,
1715 log_sigma_factor,
1716 &[],
1717 )?
1718 .coeff(0);
1719 for a in 0..active_primary {
1720 let dir_a = latent_survival_basis_direction(a);
1721 gradient[a] = latent_survival_row_primary_log_jet(
1722 quadctx,
1723 row,
1724 q_entry,
1725 q_exit,
1726 qdot_exit,
1727 q_right,
1728 mu,
1729 sigma,
1730 log_sigma_factor,
1731 &[dir_a],
1732 )?
1733 .coeff(1);
1734 for b in a..active_primary {
1735 let coeff = latent_survival_row_primary_log_jet(
1736 quadctx,
1737 row,
1738 q_entry,
1739 q_exit,
1740 qdot_exit,
1741 q_right,
1742 mu,
1743 sigma,
1744 log_sigma_factor,
1745 &[dir_a, latent_survival_basis_direction(b)],
1746 )?
1747 .coeff(3);
1748 neg_hessian[[a, b]] = -coeff;
1749 neg_hessian[[b, a]] = -coeff;
1750 }
1751 }
1752 Ok((log_lik, gradient, neg_hessian))
1753}
1754
1755fn latent_survival_row_primary_third_contracted(
1756 quadctx: &QuadratureContext,
1757 row: &LatentSurvivalRow,
1758 q_entry: f64,
1759 q_exit: f64,
1760 qdot_exit: f64,
1761 q_right: f64,
1762 mu: f64,
1763 sigma: f64,
1764 direction: &Array1<f64>,
1765 include_log_sigma: bool,
1766) -> Result<Array2<f64>, String> {
1767 let log_sigma_factor = if sigma > 0.0 { sigma.ln() } else { 0.0 };
1768 let active_primary = if include_log_sigma {
1769 LATENT_SURVIVAL_PRIMARY_DIM
1770 } else {
1771 LATENT_SURVIVAL_PRIMARY_LOG_SIGMA
1772 };
1773 let dir = LatentSurvivalPrimaryDirection {
1774 dq_entry: direction[LATENT_SURVIVAL_PRIMARY_Q_ENTRY],
1775 dq_exit: direction[LATENT_SURVIVAL_PRIMARY_Q_EXIT],
1776 dqdot_exit: direction[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT],
1777 dq_right: direction[LATENT_SURVIVAL_PRIMARY_Q_RIGHT],
1778 dmu: direction[LATENT_SURVIVAL_PRIMARY_MU],
1779 dlog_sigma: direction[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA],
1780 };
1781 let mut out = Array2::<f64>::zeros((LATENT_SURVIVAL_PRIMARY_DIM, LATENT_SURVIVAL_PRIMARY_DIM));
1782 for a in 0..active_primary {
1783 let dir_a = latent_survival_basis_direction(a);
1784 for b in a..active_primary {
1785 let coeff = latent_survival_row_primary_log_jet(
1786 quadctx,
1787 row,
1788 q_entry,
1789 q_exit,
1790 qdot_exit,
1791 q_right,
1792 mu,
1793 sigma,
1794 log_sigma_factor,
1795 &[dir_a, latent_survival_basis_direction(b), dir],
1796 )?
1797 .coeff(7);
1798 out[[a, b]] = -coeff;
1799 out[[b, a]] = -coeff;
1800 }
1801 }
1802 Ok(out)
1803}
1804
1805fn latent_survival_row_primary_fourth_contracted(
1806 quadctx: &QuadratureContext,
1807 row: &LatentSurvivalRow,
1808 q_entry: f64,
1809 q_exit: f64,
1810 qdot_exit: f64,
1811 q_right: f64,
1812 mu: f64,
1813 sigma: f64,
1814 direction_u: &Array1<f64>,
1815 direction_v: &Array1<f64>,
1816 include_log_sigma: bool,
1817) -> Result<Array2<f64>, String> {
1818 let log_sigma_factor = if sigma > 0.0 { sigma.ln() } else { 0.0 };
1819 let active_primary = if include_log_sigma {
1820 LATENT_SURVIVAL_PRIMARY_DIM
1821 } else {
1822 LATENT_SURVIVAL_PRIMARY_LOG_SIGMA
1823 };
1824 let dir_u = LatentSurvivalPrimaryDirection {
1825 dq_entry: direction_u[LATENT_SURVIVAL_PRIMARY_Q_ENTRY],
1826 dq_exit: direction_u[LATENT_SURVIVAL_PRIMARY_Q_EXIT],
1827 dqdot_exit: direction_u[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT],
1828 dq_right: direction_u[LATENT_SURVIVAL_PRIMARY_Q_RIGHT],
1829 dmu: direction_u[LATENT_SURVIVAL_PRIMARY_MU],
1830 dlog_sigma: direction_u[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA],
1831 };
1832 let dir_v = LatentSurvivalPrimaryDirection {
1833 dq_entry: direction_v[LATENT_SURVIVAL_PRIMARY_Q_ENTRY],
1834 dq_exit: direction_v[LATENT_SURVIVAL_PRIMARY_Q_EXIT],
1835 dqdot_exit: direction_v[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT],
1836 dq_right: direction_v[LATENT_SURVIVAL_PRIMARY_Q_RIGHT],
1837 dmu: direction_v[LATENT_SURVIVAL_PRIMARY_MU],
1838 dlog_sigma: direction_v[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA],
1839 };
1840 let mut out = Array2::<f64>::zeros((LATENT_SURVIVAL_PRIMARY_DIM, LATENT_SURVIVAL_PRIMARY_DIM));
1841 for a in 0..active_primary {
1842 let dir_a = latent_survival_basis_direction(a);
1843 for b in a..active_primary {
1844 let coeff = latent_survival_row_primary_log_jet(
1845 quadctx,
1846 row,
1847 q_entry,
1848 q_exit,
1849 qdot_exit,
1850 q_right,
1851 mu,
1852 sigma,
1853 log_sigma_factor,
1854 &[dir_a, latent_survival_basis_direction(b), dir_u, dir_v],
1855 )?
1856 .coeff(15);
1857 out[[a, b]] = -coeff;
1858 out[[b, a]] = -coeff;
1859 }
1860 }
1861 Ok(out)
1862}
1863
/// Column ranges of each parameter block inside the flat joint coefficient
/// vector (laid out by `LatentSurvivalFamily::joint_slices`).
#[derive(Clone)]
struct LatentSurvivalJointSlices {
 /// Time-block coefficients; starts at column 0.
 time: std::ops::Range<usize>,
 /// Mean-block coefficients; placed directly after the time block.
 mean: std::ops::Range<usize>,
 /// Single log-sigma coefficient; `None` when the latent SD is fixed.
 log_sigma: Option<std::ops::Range<usize>>,
 /// Total number of joint coefficients (end of the last present block).
 total: usize,
}
1871
/// Per-chunk partial sums for the joint log-likelihood / gradient reduction.
#[derive(Clone)]
struct LatentSurvivalJointGradientAccum {
 /// Weighted log-likelihood accumulated over the chunk's rows.
 ll: f64,
 /// Weighted gradient w.r.t. the flat joint coefficient vector.
 gradient: Array1<f64>,
}
1877
/// Partial sums of log-likelihood, gradient and dense Hessian over the flat
/// joint coefficients.
/// NOTE(review): not used in this chunk; presumably the accumulator for the
/// dense joint Hessian reduction elsewhere in the file — confirm.
#[derive(Clone)]
struct LatentSurvivalJointDenseAccum {
 ll: f64,
 gradient: Array1<f64>,
 hessian: Array2<f64>,
}
1884
/// Partial sum of a dense Hessian only (no log-likelihood or gradient).
/// NOTE(review): not used in this chunk; presumably consumed by a
/// Hessian-only reduction elsewhere in the file — confirm.
#[derive(Clone)]
struct LatentSurvivalDenseHessianAccum {
 hessian: Array2<f64>,
}
1889
1890fn deterministic_latent_survival_row_reduction<Acc, Init, Process, Combine>(
1894 n_rows: usize,
1895 init: Init,
1896 process_row: Process,
1897 mut combine: Combine,
1898) -> Result<Acc, String>
1899where
1900 Acc: Send,
1901 Init: Fn() -> Acc + Sync,
1902 Process: Fn(usize, &mut Acc) -> Result<(), String> + Sync,
1903 Combine: FnMut(&mut Acc, Acc),
1904{
1905 use rayon::iter::{IntoParallelIterator, ParallelIterator};
1906
1907 const TARGET_CHUNK_COUNT: usize = 32;
1908 if n_rows == 0 {
1909 return Ok(init());
1910 }
1911 let chunk_size = n_rows.div_ceil(TARGET_CHUNK_COUNT).max(1);
1912 let n_chunks = n_rows.div_ceil(chunk_size);
1913 let chunk_accumulators: Vec<Acc> = (0..n_chunks)
1914 .into_par_iter()
1915 .map(|chunk_idx| -> Result<Acc, String> {
1916 let start = chunk_idx * chunk_size;
1917 let end = (start + chunk_size).min(n_rows);
1918 let mut acc = init();
1919 for row_idx in start..end {
1920 process_row(row_idx, &mut acc)?;
1921 }
1922 Ok(acc)
1923 })
1924 .collect::<Result<Vec<_>, String>>()?;
1925
1926 let mut total = init();
1927 for acc in chunk_accumulators {
1928 combine(&mut total, acc);
1929 }
1930 Ok(total)
1931}
1932
1933impl LatentSurvivalFamily {
1934 fn build_row_at(
1942 &self,
1943 row_idx: usize,
1944 q_entry: f64,
1945 q_exit: f64,
1946 qdot_exit: f64,
1947 q_right: f64,
1948 ) -> Result<LatentSurvivalRow, LatentSurvivalError> {
1949 let event_type = latent_survival_event_type_for(self.event_target[row_idx]);
1950 build_latent_survival_row(
1951 row_idx,
1952 self.hazard_loading,
1953 event_type,
1954 q_entry,
1955 q_exit,
1956 qdot_exit,
1957 q_right,
1958 self.unloaded_mass_entry[row_idx],
1959 self.unloaded_mass_exit[row_idx],
1960 self.unloaded_mass_right[row_idx],
1961 self.unloaded_hazard_exit[row_idx],
1962 )
1963 }
1964
1965 fn joint_slices(&self) -> LatentSurvivalJointSlices {
1966 let p_time = self.x_time_exit.ncols();
1967 let p_mean = self.x_mean.ncols();
1968 let time = 0..p_time;
1969 let mean = p_time..p_time + p_mean;
1970 let log_sigma = self
1971 .latent_sd_fixed
1972 .is_none()
1973 .then_some((p_time + p_mean)..(p_time + p_mean + 1));
1974 LatentSurvivalJointSlices {
1975 total: log_sigma
1976 .as_ref()
1977 .map_or(p_time + p_mean, |range| range.end),
1978 time,
1979 mean,
1980 log_sigma,
1981 }
1982 }
1983
1984 fn row_primary_direction_from_flat(
1985 &self,
1986 row: usize,
1987 slices: &LatentSurvivalJointSlices,
1988 d_beta_flat: &Array1<f64>,
1989 ) -> Array1<f64> {
1990 let mut out = Array1::<f64>::zeros(LATENT_SURVIVAL_PRIMARY_DIM);
1991 let d_time = d_beta_flat.slice(s![slices.time.clone()]);
1992 out[LATENT_SURVIVAL_PRIMARY_Q_ENTRY] = self.x_time_entry.row(row).dot(&d_time);
1993 out[LATENT_SURVIVAL_PRIMARY_Q_EXIT] = self.x_time_exit.row(row).dot(&d_time);
1994 out[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT] = self.x_time_derivative_exit.row(row).dot(&d_time);
1995 out[LATENT_SURVIVAL_PRIMARY_Q_RIGHT] = self.x_time_right.row(row).dot(&d_time);
1996 out[LATENT_SURVIVAL_PRIMARY_MU] = self
1997 .x_mean
1998 .dot_row_view(row, d_beta_flat.slice(s![slices.mean.clone()]));
1999 if let Some(range) = &slices.log_sigma {
2000 out[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA] = d_beta_flat[range.start];
2001 }
2002 out
2003 }
2004
2005 fn joint_block_ranges(&self) -> Vec<std::ops::Range<usize>> {
2006 let slices = self.joint_slices();
2007 let mut ranges = vec![slices.time.clone(), slices.mean.clone()];
2008 if let Some(log_sigma) = slices.log_sigma {
2009 ranges.push(log_sigma);
2010 }
2011 ranges
2012 }
2013
2014 fn add_pullback_primary_gradient(
2015 &self,
2016 target: &mut Array1<f64>,
2017 row: usize,
2018 slices: &LatentSurvivalJointSlices,
2019 primary_gradient: &Array1<f64>,
2020 weight: f64,
2021 ) -> Result<(), String> {
2022 for (primary_idx, time_vec) in [
2023 (LATENT_SURVIVAL_PRIMARY_Q_ENTRY, self.x_time_entry.row(row)),
2024 (LATENT_SURVIVAL_PRIMARY_Q_EXIT, self.x_time_exit.row(row)),
2025 (
2026 LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2027 self.x_time_derivative_exit.row(row),
2028 ),
2029 (LATENT_SURVIVAL_PRIMARY_Q_RIGHT, self.x_time_right.row(row)),
2030 ] {
2031 let scale = weight * primary_gradient[primary_idx];
2032 if scale == 0.0 {
2033 continue;
2034 }
2035 for i in 0..time_vec.len() {
2036 let xi = time_vec[i];
2037 if xi != 0.0 {
2038 target[slices.time.start + i] += scale * xi;
2039 }
2040 }
2041 }
2042
2043 let mean_scale = weight * primary_gradient[LATENT_SURVIVAL_PRIMARY_MU];
2044 if mean_scale != 0.0 {
2045 self.x_mean
2046 .axpy_row_into(
2047 row,
2048 mean_scale,
2049 &mut target.slice_mut(s![slices.mean.clone()]),
2050 )
2051 .map_err(|error| {
2052 format!(
2053 "latent survival mean gradient pullback dimension mismatch: row={row}, mean_slice={:?}, target_len={}, x_mean_cols={}, error={error}",
2054 slices.mean,
2055 target.len(),
2056 self.x_mean.ncols()
2057 )
2058 })?;
2059 }
2060
2061 if let Some(log_sigma) = &slices.log_sigma {
2062 target[log_sigma.start] += weight * primary_gradient[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA];
2063 }
2064 Ok(())
2065 }
2066
 /// Adds `J_row^T * primary_hessian * J_row` into the dense flat `target`.
 ///
 /// `J_row` maps flat coefficients to this row's primary channels.
 /// `primary_hessian` is read as symmetric: only the upper-triangle entries
 /// `[a, b]` for the cross pairs used below are read, and each cross term is
 /// written to both off-diagonal positions of `target`. No row weight is
 /// applied here, so any weighting must already be folded into
 /// `primary_hessian`.
 ///
 /// # Errors
 /// Returns a message if the mean design row does not match the mean slice,
 /// or if the mean row chunk cannot be extracted.
 fn add_pullback_primary_hessian(
 &self,
 target: &mut Array2<f64>,
 row: usize,
 slices: &LatentSurvivalJointSlices,
 primary_hessian: &Array2<f64>,
 ) -> Result<(), String> {
 // Diagonal weights for the four time channels (entry, exit, qdot, right).
 let time_weights = [
 primary_hessian[[
 LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
 LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
 ]],
 primary_hessian[[
 LATENT_SURVIVAL_PRIMARY_Q_EXIT,
 LATENT_SURVIVAL_PRIMARY_Q_EXIT,
 ]],
 primary_hessian[[
 LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
 LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
 ]],
 primary_hessian[[
 LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
 LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
 ]],
 ];
 // All six unordered pairs of distinct time channels, with their design matrices.
 let time_cross_weights = [
 (
 LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
 LATENT_SURVIVAL_PRIMARY_Q_EXIT,
 &self.x_time_entry,
 &self.x_time_exit,
 ),
 (
 LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
 LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
 &self.x_time_entry,
 &self.x_time_derivative_exit,
 ),
 (
 LATENT_SURVIVAL_PRIMARY_Q_EXIT,
 LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
 &self.x_time_exit,
 &self.x_time_derivative_exit,
 ),
 (
 LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
 LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
 &self.x_time_entry,
 &self.x_time_right,
 ),
 (
 LATENT_SURVIVAL_PRIMARY_Q_EXIT,
 LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
 &self.x_time_exit,
 &self.x_time_right,
 ),
 (
 LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
 LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
 &self.x_time_derivative_exit,
 &self.x_time_right,
 ),
 ];
 // Time x time block. The scope limits the mutable view's lifetime.
 {
 let time_target = &mut target.slice_mut(s![slices.time.clone(), slices.time.clone()]);
 dense_outer_accumulate(time_target, time_weights[0], self.x_time_entry.row(row));
 dense_outer_accumulate(time_target, time_weights[1], self.x_time_exit.row(row));
 dense_outer_accumulate(
 time_target,
 time_weights[2],
 self.x_time_derivative_exit.row(row),
 );
 dense_outer_accumulate(time_target, time_weights[3], self.x_time_right.row(row));
 for (a, b, lhs, rhs) in time_cross_weights {
 let weight = primary_hessian[[a, b]];
 if weight == 0.0 {
 continue;
 }
 // Adds weight * (lhs rhs^T + rhs lhs^T) (by the helper's name; confirm).
 dense_symmetric_cross_accumulate(time_target, weight, lhs.row(row), rhs.row(row));
 }
 }

 // Mean x mean block: rank-one update with the mean design row.
 let mean_weight = primary_hessian[[LATENT_SURVIVAL_PRIMARY_MU, LATENT_SURVIVAL_PRIMARY_MU]];
 self.x_mean
 .syr_row_into_view(
 row,
 mean_weight,
 target.slice_mut(s![slices.mean.clone(), slices.mean.clone()]),
 )
 .map_err(|error| {
 format!(
 "latent survival mean Hessian pullback dimension mismatch: row={row}, mean_slice={:?}, target_dim={:?}, x_mean_cols={}, error={error}",
 slices.mean,
 target.dim(),
 self.x_mean.ncols()
 )
 })?;

 // Extract the mean design row explicitly for the cross terms below.
 let mean_row = self
 .x_mean
 .try_row_chunk(row..row + 1)
 .map_err(|error| {
 format!(
 "latent survival mean pullback row chunk failed: row={row}, x_mean_rows={}, x_mean_cols={}, error={error}",
 self.x_mean.nrows(),
 self.x_mean.ncols()
 )
 })?;
 let mean_vec = mean_row.row(0);
 // Time x mean cross blocks, written to both off-diagonal positions.
 let time_mean_weights = [
 (LATENT_SURVIVAL_PRIMARY_Q_ENTRY, self.x_time_entry.row(row)),
 (LATENT_SURVIVAL_PRIMARY_Q_EXIT, self.x_time_exit.row(row)),
 (
 LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
 self.x_time_derivative_exit.row(row),
 ),
 (LATENT_SURVIVAL_PRIMARY_Q_RIGHT, self.x_time_right.row(row)),
 ];
 for (primary_idx, time_vec) in time_mean_weights {
 let weight = primary_hessian[[primary_idx, LATENT_SURVIVAL_PRIMARY_MU]];
 if weight == 0.0 {
 continue;
 }
 for i in 0..time_vec.len() {
 let xi = time_vec[i];
 if xi == 0.0 {
 continue;
 }
 for j in 0..mean_vec.len() {
 let xj = mean_vec[j];
 if xj == 0.0 {
 continue;
 }
 target[[slices.time.start + i, slices.mean.start + j]] += weight * xi * xj;
 target[[slices.mean.start + j, slices.time.start + i]] += weight * xj * xi;
 }
 }
 }

 // Log-sigma terms exist only when sigma is estimated (one scalar column).
 if let Some(log_sigma) = &slices.log_sigma {
 let sigma_idx = log_sigma.start;
 target[[sigma_idx, sigma_idx]] += primary_hessian[[
 LATENT_SURVIVAL_PRIMARY_LOG_SIGMA,
 LATENT_SURVIVAL_PRIMARY_LOG_SIGMA,
 ]];

 // Time x log-sigma: the design entries of each time channel times one scalar.
 for (primary_idx, time_vec) in [
 (LATENT_SURVIVAL_PRIMARY_Q_ENTRY, self.x_time_entry.row(row)),
 (LATENT_SURVIVAL_PRIMARY_Q_EXIT, self.x_time_exit.row(row)),
 (
 LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
 self.x_time_derivative_exit.row(row),
 ),
 (LATENT_SURVIVAL_PRIMARY_Q_RIGHT, self.x_time_right.row(row)),
 ] {
 let weight = primary_hessian[[primary_idx, LATENT_SURVIVAL_PRIMARY_LOG_SIGMA]];
 if weight == 0.0 {
 continue;
 }
 for i in 0..time_vec.len() {
 let xi = time_vec[i];
 if xi == 0.0 {
 continue;
 }
 target[[slices.time.start + i, sigma_idx]] += weight * xi;
 target[[sigma_idx, slices.time.start + i]] += weight * xi;
 }
 }

 // Mean x log-sigma cross terms.
 let mean_sigma_weight = primary_hessian[[
 LATENT_SURVIVAL_PRIMARY_MU,
 LATENT_SURVIVAL_PRIMARY_LOG_SIGMA,
 ]];
 if mean_sigma_weight != 0.0 {
 for j in 0..mean_vec.len() {
 let xj = mean_vec[j];
 if xj == 0.0 {
 continue;
 }
 target[[slices.mean.start + j, sigma_idx]] += mean_sigma_weight * xj;
 target[[sigma_idx, slices.mean.start + j]] += mean_sigma_weight * xj;
 }
 }
 }
 Ok(())
 }
2253
2254 fn evaluate_exact_newton_joint_gradient_dense(
2255 &self,
2256 block_states: &[ParameterBlockState],
2257 ) -> Result<(f64, Array1<f64>), String> {
2258 let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
2259 let q_right = self.time_q_right(block_states)?;
2260 let sigma = self.latent_sd(block_states)?;
2261 let slices = self.joint_slices();
2262 let include_log_sigma = slices.log_sigma.is_some();
2263 let total = slices.total;
2264 let acc = deterministic_latent_survival_row_reduction(
2265 self.event_target.len(),
2266 || LatentSurvivalJointGradientAccum {
2267 ll: 0.0,
2268 gradient: Array1::<f64>::zeros(total),
2269 },
2270 |row_idx, acc| {
2271 let wi = self.weights[row_idx];
2272 if wi <= MIN_WEIGHT {
2273 return Ok(());
2274 }
2275 let row = self.build_row_at(
2276 row_idx,
2277 q_entry[row_idx],
2278 q_exit[row_idx],
2279 qdot_exit[row_idx],
2280 q_right[row_idx],
2281 )?;
2282 let (row_ll, primary_gradient, _) = latent_survival_row_primary_gradient_hessian(
2283 &self.quadctx,
2284 &row,
2285 q_entry[row_idx],
2286 q_exit[row_idx],
2287 qdot_exit[row_idx],
2288 q_right[row_idx],
2289 mu[row_idx],
2290 sigma,
2291 include_log_sigma,
2292 )?;
2293 acc.ll += wi * row_ll;
2294 self.add_pullback_primary_gradient(
2295 &mut acc.gradient,
2296 row_idx,
2297 &slices,
2298 &primary_gradient,
2299 wi,
2300 )?;
2301 Ok(())
2302 },
2303 |total_acc, chunk_acc| {
2304 total_acc.ll += chunk_acc.ll;
2305 total_acc.gradient += &chunk_acc.gradient;
2306 },
2307 )?;
2308 Ok((acc.ll, acc.gradient))
2309 }
2310
2311 pub fn offset_channel_residuals(
2340 &self,
2341 block_states: &[ParameterBlockState],
2342 ) -> Result<crate::survival::OffsetChannelResiduals, String> {
2343 let n = self.event_target.len();
2344 if block_states.is_empty() {
2345 log::warn!(
2350 "LatentSurvivalFamily::offset_channel_residuals: block_states is empty \
2351 (degraded fit); returning zero offset residuals (n={n})"
2352 );
2353 return Ok(crate::survival::OffsetChannelResiduals {
2354 exit: Array1::<f64>::zeros(n),
2355 entry: Array1::<f64>::zeros(n),
2356 derivative: Array1::<f64>::zeros(n),
2357 right: Array1::<f64>::zeros(n),
2358 });
2359 }
2360 let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
2361 let q_right = self.time_q_right(block_states)?;
2362 let sigma = self.latent_sd(block_states)?;
2363 let include_log_sigma = self.joint_slices().log_sigma.is_some();
2364 let mut entry = Array1::<f64>::zeros(n);
2365 let mut exit = Array1::<f64>::zeros(n);
2366 let mut derivative = Array1::<f64>::zeros(n);
2367 let mut right = Array1::<f64>::zeros(n);
2368 for row_idx in 0..n {
2369 let wi = self.weights[row_idx];
2370 if wi <= MIN_WEIGHT {
2371 continue;
2372 }
2373 let row = self.build_row_at(
2374 row_idx,
2375 q_entry[row_idx],
2376 q_exit[row_idx],
2377 qdot_exit[row_idx],
2378 q_right[row_idx],
2379 )?;
2380 let (_, primary_gradient, _) = latent_survival_row_primary_gradient_hessian(
2381 &self.quadctx,
2382 &row,
2383 q_entry[row_idx],
2384 q_exit[row_idx],
2385 qdot_exit[row_idx],
2386 q_right[row_idx],
2387 mu[row_idx],
2388 sigma,
2389 include_log_sigma,
2390 )?;
2391 entry[row_idx] = -wi * primary_gradient[LATENT_SURVIVAL_PRIMARY_Q_ENTRY];
2393 exit[row_idx] = -wi * primary_gradient[LATENT_SURVIVAL_PRIMARY_Q_EXIT];
2394 derivative[row_idx] = -wi * primary_gradient[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT];
2395 right[row_idx] = -wi * primary_gradient[LATENT_SURVIVAL_PRIMARY_Q_RIGHT];
2402 }
2403 Ok(crate::survival::OffsetChannelResiduals {
2404 exit,
2405 entry,
2406 derivative,
2407 right,
2408 })
2409 }
2410
2411 fn add_pullback_primary_block_diagonals(
2416 &self,
2417 row: usize,
2418 primary_hessian: &Array2<f64>,
2419 time_target: &mut Array2<f64>,
2420 mean_target: &mut Array2<f64>,
2421 log_sigma_target: Option<&mut Array2<f64>>,
2422 ) -> Result<(), String> {
2423 let h = primary_hessian;
2424 dense_outer_accumulate(
2428 time_target,
2429 h[[
2430 LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
2431 LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
2432 ]],
2433 self.x_time_entry.row(row),
2434 );
2435 dense_outer_accumulate(
2436 time_target,
2437 h[[
2438 LATENT_SURVIVAL_PRIMARY_Q_EXIT,
2439 LATENT_SURVIVAL_PRIMARY_Q_EXIT,
2440 ]],
2441 self.x_time_exit.row(row),
2442 );
2443 dense_outer_accumulate(
2444 time_target,
2445 h[[
2446 LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2447 LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2448 ]],
2449 self.x_time_derivative_exit.row(row),
2450 );
2451 dense_outer_accumulate(
2452 time_target,
2453 h[[
2454 LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
2455 LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
2456 ]],
2457 self.x_time_right.row(row),
2458 );
2459 for (a, b, lhs, rhs) in [
2460 (
2461 LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
2462 LATENT_SURVIVAL_PRIMARY_Q_EXIT,
2463 &self.x_time_entry,
2464 &self.x_time_exit,
2465 ),
2466 (
2467 LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
2468 LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2469 &self.x_time_entry,
2470 &self.x_time_derivative_exit,
2471 ),
2472 (
2473 LATENT_SURVIVAL_PRIMARY_Q_EXIT,
2474 LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2475 &self.x_time_exit,
2476 &self.x_time_derivative_exit,
2477 ),
2478 (
2479 LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
2480 LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
2481 &self.x_time_entry,
2482 &self.x_time_right,
2483 ),
2484 (
2485 LATENT_SURVIVAL_PRIMARY_Q_EXIT,
2486 LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
2487 &self.x_time_exit,
2488 &self.x_time_right,
2489 ),
2490 (
2491 LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2492 LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
2493 &self.x_time_derivative_exit,
2494 &self.x_time_right,
2495 ),
2496 ] {
2497 let weight = h[[a, b]];
2498 if weight == 0.0 {
2499 continue;
2500 }
2501 dense_symmetric_cross_accumulate(time_target, weight, lhs.row(row), rhs.row(row));
2502 }
2503 let mean_weight = h[[LATENT_SURVIVAL_PRIMARY_MU, LATENT_SURVIVAL_PRIMARY_MU]];
2505 self.x_mean
2506 .syr_row_into_view(row, mean_weight, mean_target.view_mut())
2507 .map_err(|error| {
2508 format!(
2509 "latent survival mean block-diagonal pullback dimension mismatch: row={row}, mean_target_dim={:?}, x_mean_cols={}, error={error}",
2510 mean_target.dim(),
2511 self.x_mean.ncols()
2512 )
2513 })?;
2514 if let Some(target) = log_sigma_target {
2516 target[[0, 0]] += h[[
2517 LATENT_SURVIVAL_PRIMARY_LOG_SIGMA,
2518 LATENT_SURVIVAL_PRIMARY_LOG_SIGMA,
2519 ]];
2520 }
2521 Ok(())
2522 }
2523
2524 fn evaluate_exact_newton_block_diagonals(
2529 &self,
2530 block_states: &[ParameterBlockState],
2531 ) -> Result<
2532 (
2533 f64,
2534 Array1<f64>,
2535 Array2<f64>,
2536 Array2<f64>,
2537 Option<Array2<f64>>,
2538 ),
2539 String,
2540 > {
2541 let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
2542 let q_right = self.time_q_right(block_states)?;
2543 let sigma = self.latent_sd(block_states)?;
2544 let slices = self.joint_slices();
2545 let include_log_sigma = slices.log_sigma.is_some();
2546 let mut ll = 0.0;
2547 let mut gradient = Array1::<f64>::zeros(slices.total);
2548 let p_time = slices.time.len();
2549 let p_mean = slices.mean.len();
2550 let mut hess_time = Array2::<f64>::zeros((p_time, p_time));
2551 let mut hess_mean = Array2::<f64>::zeros((p_mean, p_mean));
2552 let mut hess_log_sigma = if include_log_sigma {
2553 Some(Array2::<f64>::zeros((1, 1)))
2554 } else {
2555 None
2556 };
2557 for row_idx in 0..self.event_target.len() {
2558 let wi = self.weights[row_idx];
2559 if wi <= MIN_WEIGHT {
2560 continue;
2561 }
2562 let row = self.build_row_at(
2563 row_idx,
2564 q_entry[row_idx],
2565 q_exit[row_idx],
2566 qdot_exit[row_idx],
2567 q_right[row_idx],
2568 )?;
2569 let (row_ll, primary_gradient, primary_hessian) =
2570 latent_survival_row_primary_gradient_hessian(
2571 &self.quadctx,
2572 &row,
2573 q_entry[row_idx],
2574 q_exit[row_idx],
2575 qdot_exit[row_idx],
2576 q_right[row_idx],
2577 mu[row_idx],
2578 sigma,
2579 include_log_sigma,
2580 )?;
2581 ll += wi * row_ll;
2582 self.add_pullback_primary_gradient(
2583 &mut gradient,
2584 row_idx,
2585 &slices,
2586 &primary_gradient,
2587 wi,
2588 )?;
2589 self.add_pullback_primary_block_diagonals(
2590 row_idx,
2591 &(wi * primary_hessian),
2592 &mut hess_time,
2593 &mut hess_mean,
2594 hess_log_sigma.as_mut(),
2595 )?;
2596 }
2597 Ok((ll, gradient, hess_time, hess_mean, hess_log_sigma))
2598 }
2599
    /// Evaluates the weighted log-likelihood together with the dense joint
    /// gradient and dense joint Hessian over every coefficient block.
    ///
    /// Each row whose weight exceeds `MIN_WEIGHT` contributes the primary-space
    /// gradient/Hessian from `latent_survival_row_primary_gradient_hessian`,
    /// pulled back through that row's design rows. The per-row work goes through
    /// `deterministic_latent_survival_row_reduction`: the first closure seeds an
    /// empty accumulator, the second adds one row, and the third merges chunk
    /// accumulators by summation.
    ///
    /// Returns `(log_likelihood, gradient, hessian)` sized `slices.total`. The
    /// Hessian uses the same sign convention as the primary Hessian. That
    /// convention appears to be the negative log-likelihood Hessian, judging by
    /// how the binary family combines it with `neg_hess_scale` — TODO confirm.
    ///
    /// # Errors
    /// Propagates failures from predictor splitting, row construction, the
    /// per-row derivatives, and the pullbacks.
    fn evaluate_exact_newton_joint_dense(
        &self,
        block_states: &[ParameterBlockState],
    ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
        let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
        let q_right = self.time_q_right(block_states)?;
        let sigma = self.latent_sd(block_states)?;
        let slices = self.joint_slices();
        let include_log_sigma = slices.log_sigma.is_some();
        let total = slices.total;
        let acc = deterministic_latent_survival_row_reduction(
            self.event_target.len(),
            // Fresh, zeroed accumulator.
            || LatentSurvivalJointDenseAccum {
                ll: 0.0,
                gradient: Array1::<f64>::zeros(total),
                hessian: Array2::<f64>::zeros((total, total)),
            },
            // Adds one row's weighted contribution.
            |row_idx, acc| {
                let wi = self.weights[row_idx];
                // Negligible-weight rows contribute nothing.
                if wi <= MIN_WEIGHT {
                    return Ok(());
                }
                let row = self.build_row_at(
                    row_idx,
                    q_entry[row_idx],
                    q_exit[row_idx],
                    qdot_exit[row_idx],
                    q_right[row_idx],
                )?;
                let (row_ll, primary_gradient, primary_hessian) =
                    latent_survival_row_primary_gradient_hessian(
                        &self.quadctx,
                        &row,
                        q_entry[row_idx],
                        q_exit[row_idx],
                        qdot_exit[row_idx],
                        q_right[row_idx],
                        mu[row_idx],
                        sigma,
                        include_log_sigma,
                    )?;
                acc.ll += wi * row_ll;
                self.add_pullback_primary_gradient(
                    &mut acc.gradient,
                    row_idx,
                    &slices,
                    &primary_gradient,
                    wi,
                )?;
                // The row weight is folded into the primary Hessian before pullback.
                self.add_pullback_primary_hessian(
                    &mut acc.hessian,
                    row_idx,
                    &slices,
                    &(wi * primary_hessian),
                )?;
                Ok(())
            },
            // Merges a chunk accumulator into the running total.
            |total_acc, chunk_acc| {
                total_acc.ll += chunk_acc.ll;
                total_acc.gradient += &chunk_acc.gradient;
                total_acc.hessian += &chunk_acc.hessian;
            },
        )?;
        Ok((acc.ll, acc.gradient, acc.hessian))
    }
2665
    /// Directional derivative of the dense joint Hessian along the flattened
    /// coefficient direction `d_beta_flat`.
    ///
    /// For each row with weight above `MIN_WEIGHT`, `d_beta_flat` is projected
    /// into the row's primary coordinates. The third-derivative tensor is
    /// contracted with that direction by
    /// `latent_survival_row_primary_third_contracted`. The weighted result is
    /// pulled back like a Hessian and summed through
    /// `deterministic_latent_survival_row_reduction`.
    ///
    /// # Errors
    /// Returns an error when `d_beta_flat.len()` differs from `slices.total`,
    /// and propagates row construction, derivative, and pullback failures.
    fn exact_newton_joint_hessian_directional_derivative_dense(
        &self,
        block_states: &[ParameterBlockState],
        d_beta_flat: &Array1<f64>,
    ) -> Result<Array2<f64>, String> {
        let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
        let q_right = self.time_q_right(block_states)?;
        let sigma = self.latent_sd(block_states)?;
        let slices = self.joint_slices();
        // The direction must cover every joint coefficient.
        if d_beta_flat.len() != slices.total {
            return Err(format!(
                "latent survival joint dH direction length mismatch: got {}, expected {}",
                d_beta_flat.len(),
                slices.total
            ));
        }
        let include_log_sigma = slices.log_sigma.is_some();
        let total = slices.total;
        let acc = deterministic_latent_survival_row_reduction(
            self.event_target.len(),
            || LatentSurvivalDenseHessianAccum {
                hessian: Array2::<f64>::zeros((total, total)),
            },
            |row_idx, acc| {
                let wi = self.weights[row_idx];
                if wi <= MIN_WEIGHT {
                    return Ok(());
                }
                let row = self.build_row_at(
                    row_idx,
                    q_entry[row_idx],
                    q_exit[row_idx],
                    qdot_exit[row_idx],
                    q_right[row_idx],
                )?;
                // Direction expressed in this row's primary coordinates.
                let direction = self.row_primary_direction_from_flat(row_idx, &slices, d_beta_flat);
                // Third-derivative tensor contracted once with `direction`.
                let third = latent_survival_row_primary_third_contracted(
                    &self.quadctx,
                    &row,
                    q_entry[row_idx],
                    q_exit[row_idx],
                    qdot_exit[row_idx],
                    q_right[row_idx],
                    mu[row_idx],
                    sigma,
                    &direction,
                    include_log_sigma,
                )?;
                self.add_pullback_primary_hessian(
                    &mut acc.hessian,
                    row_idx,
                    &slices,
                    &(wi * third),
                )?;
                Ok(())
            },
            |total_acc, chunk_acc| {
                total_acc.hessian += &chunk_acc.hessian;
            },
        )?;
        Ok(acc.hessian)
    }
2728
    /// Second directional derivative of the dense joint Hessian along the
    /// flattened coefficient directions `d_beta_u_flat` and `d_beta_v_flat`.
    ///
    /// For each row with weight above `MIN_WEIGHT`, both directions are projected
    /// into primary coordinates. The fourth-derivative tensor is contracted with
    /// them by `latent_survival_row_primary_fourth_contracted`. The weighted
    /// result is pulled back like a Hessian and summed through
    /// `deterministic_latent_survival_row_reduction`.
    ///
    /// # Errors
    /// Returns an error when either direction's length differs from
    /// `slices.total`, and propagates row construction, derivative, and
    /// pullback failures.
    fn exact_newton_joint_hessian_second_directional_derivative_dense(
        &self,
        block_states: &[ParameterBlockState],
        d_beta_u_flat: &Array1<f64>,
        d_beta_v_flat: &Array1<f64>,
    ) -> Result<Array2<f64>, String> {
        let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
        let q_right = self.time_q_right(block_states)?;
        let sigma = self.latent_sd(block_states)?;
        let slices = self.joint_slices();
        // Both directions must cover every joint coefficient.
        if d_beta_u_flat.len() != slices.total || d_beta_v_flat.len() != slices.total {
            return Err(format!(
                "latent survival joint d2H direction length mismatch: got {} and {}, expected {}",
                d_beta_u_flat.len(),
                d_beta_v_flat.len(),
                slices.total
            ));
        }
        let include_log_sigma = slices.log_sigma.is_some();
        let total = slices.total;
        let acc = deterministic_latent_survival_row_reduction(
            self.event_target.len(),
            || LatentSurvivalDenseHessianAccum {
                hessian: Array2::<f64>::zeros((total, total)),
            },
            |row_idx, acc| {
                let wi = self.weights[row_idx];
                if wi <= MIN_WEIGHT {
                    return Ok(());
                }
                let row = self.build_row_at(
                    row_idx,
                    q_entry[row_idx],
                    q_exit[row_idx],
                    qdot_exit[row_idx],
                    q_right[row_idx],
                )?;
                // Both directions expressed in this row's primary coordinates.
                let direction_u =
                    self.row_primary_direction_from_flat(row_idx, &slices, d_beta_u_flat);
                let direction_v =
                    self.row_primary_direction_from_flat(row_idx, &slices, d_beta_v_flat);
                // Fourth-derivative tensor contracted with `direction_u` and `direction_v`.
                let fourth = latent_survival_row_primary_fourth_contracted(
                    &self.quadctx,
                    &row,
                    q_entry[row_idx],
                    q_exit[row_idx],
                    qdot_exit[row_idx],
                    q_right[row_idx],
                    mu[row_idx],
                    sigma,
                    &direction_u,
                    &direction_v,
                    include_log_sigma,
                )?;
                self.add_pullback_primary_hessian(
                    &mut acc.hessian,
                    row_idx,
                    &slices,
                    &(wi * fourth),
                )?;
                Ok(())
            },
            |total_acc, chunk_acc| {
                total_acc.hessian += &chunk_acc.hessian;
            },
        )?;
        Ok(acc.hessian)
    }
2797}
2798
2799fn log_kernel_ratio(
2800 bundle: &crate::survival::lognormal_kernel::LogLognormalKernelBundle,
2801 num: usize,
2802 den: usize,
2803) -> f64 {
2804 let delta = bundle.get(num) - bundle.get(den);
2805 if delta.is_finite() {
2806 delta.exp()
2807 } else if delta > 0.0 {
2808 f64::INFINITY
2809 } else {
2810 0.0
2811 }
2812}
2813
2814fn logk_q_derivatives(
2815 quadctx: &QuadratureContext,
2816 k: usize,
2817 mass: f64,
2818 mu: f64,
2819 sigma: f64,
2820) -> Result<(f64, f64, IntegratedExpectationMode), LatentSurvivalError> {
2821 if mass <= 0.0 {
2822 return Ok((0.0, 0.0, IntegratedExpectationMode::ExactClosedForm));
2823 }
2824 let bundle = log_kernel_bundle(quadctx, mass, mu, sigma, k + 2).map_err(|e| {
2825 LatentSurvivalError::NumericalFailure {
2826 reason: format!("latent survival kernel evaluation failed: {e}"),
2827 }
2828 })?;
2829 let r1 = log_kernel_ratio(&bundle, k + 1, k);
2830 let r2 = log_kernel_ratio(&bundle, k + 2, k);
2831 let d1 = -mass * r1;
2832 let d2 = d1 + mass * mass * (r2 - r1 * r1);
2833 Ok((d1, d2, bundle.mode))
2834}
2835
/// Derivatives of one row's latent-survival log-likelihood with respect to the
/// entry and exit time predictors, packaged as a `LatentSurvivalTimeJet`.
/// `grad_*` fields are first derivatives; `neg_hess_*` fields are negated
/// second derivatives.
///
/// The entry contribution is the same for every event type: with
/// `(d1, d2) = logk_q_derivatives(k = 0, row.mass_entry)`, it is
/// `grad_entry = -d1` and `neg_hess_entry = d2`.
///
/// The exit contribution depends on `row.event_type`:
/// * `RightCensored`: `logk_q_derivatives(k = 0, row.mass_exit)` enters with a
///   positive sign.
/// * `ExactEvent` with `row.hazard_unloaded > 0`: a two-term mixture.
///   - The loaded term uses the `k = 1` derivatives with `1.0` added to the
///     slope; the unloaded term uses `k = 0`.
///   - The terms are weighted by `hazard * exp(log kernel)`, normalised after a
///     max-shift.
///   - The second derivative uses `Σ w_i (d2_i + d1_i²) - g²`.
/// * `ExactEvent` otherwise: the `k = 1` derivatives, with `1.0` added to the
///   gradient.
/// * `IntervalCensored`: not supported.
///
/// `qdot_exit` is only validated here (positive and finite for exact events).
/// It is presumably already folded into `row.hazard_loaded` — TODO confirm.
///
/// NOTE(review): the `hazard_unloaded > 0` branch evaluates the kernel bundle
/// at `row.mass_exit` three times: once explicitly, and once inside each
/// `logk_q_derivatives` call. An order-3 bundle appears sufficient for all of
/// them; worth sharing if profiling shows this path is hot.
///
/// # Errors
/// * `NumericalFailure` when kernel evaluation fails, when `qdot_exit` is not
///   positive and finite on an exact event, or when the mixture normalizer is
///   not positive and finite.
/// * `UnsupportedConfiguration` for interval-censored rows.
fn latent_survival_time_jet(
    quadctx: &QuadratureContext,
    row: &LatentSurvivalRow,
    qdot_exit: f64,
    mu: f64,
    sigma: f64,
) -> Result<LatentSurvivalTimeJet, LatentSurvivalError> {
    // Entry (left-truncation) term, shared by every event type.
    let (entry_d1, entry_d2, _) = logk_q_derivatives(quadctx, 0, row.mass_entry, mu, sigma)?;
    match row.event_type {
        LatentSurvivalEventType::RightCensored => {
            let (exit_d1, exit_d2, _) = logk_q_derivatives(quadctx, 0, row.mass_exit, mu, sigma)?;
            Ok(LatentSurvivalTimeJet {
                grad_entry: -entry_d1,
                grad_exit: exit_d1,
                neg_hess_entry: entry_d2,
                neg_hess_exit: -exit_d2,
            })
        }
        LatentSurvivalEventType::ExactEvent => {
            if !(qdot_exit.is_finite() && qdot_exit > 0.0) {
                return Err(LatentSurvivalError::NumericalFailure {
                    reason: format!(
                        "latent survival requires positive finite baseline hazard derivative, got {qdot_exit}"
                    ),
                });
            }
            if row.hazard_unloaded > 0.0 {
                // Log-kernel values at the exit mass, used for the mixture weights.
                let bundle =
                    log_kernel_bundle(quadctx, row.mass_exit, mu, sigma, 3).map_err(|e| {
                        LatentSurvivalError::NumericalFailure {
                            reason: format!("latent survival kernel evaluation failed: {e}"),
                        }
                    })?;
                let (unloaded_d1, unloaded_d2, _) =
                    logk_q_derivatives(quadctx, 0, row.mass_exit, mu, sigma)?;
                let (loaded_log_d1, loaded_d2, _) =
                    logk_q_derivatives(quadctx, 1, row.mass_exit, mu, sigma)?;
                // The loaded component's log also carries a unit slope in q.
                let loaded_d1 = 1.0 + loaded_log_d1;
                let log_loaded = row.hazard_loaded.ln() + bundle.get(1);
                let log_unloaded = row.hazard_unloaded.ln() + bundle.get(0);
                // Max-shift before exponentiating to keep the weights in range.
                let shift = log_loaded.max(log_unloaded);
                let loaded_weight = (log_loaded - shift).exp();
                let unloaded_weight = (log_unloaded - shift).exp();
                let normalizer = loaded_weight + unloaded_weight;
                if !(normalizer.is_finite() && normalizer > 0.0) {
                    return Err(LatentSurvivalError::NumericalFailure {
                        reason: "latent survival exact-event numerator became non-finite under loaded/unloaded hazard decomposition"
                            .to_string(),
                    });
                }
                let w_loaded = loaded_weight / normalizer;
                let w_unloaded = unloaded_weight / normalizer;
                // Gradient of the log-mixture is the weighted mean of component slopes.
                let grad_exit = w_loaded * loaded_d1 + w_unloaded * unloaded_d1;
                // Second derivative of the log-mixture: E_w[d2 + d1²] - (E_w[d1])².
                let d2_exit = w_loaded * (loaded_d2 + loaded_d1 * loaded_d1)
                    + w_unloaded * (unloaded_d2 + unloaded_d1 * unloaded_d1)
                    - grad_exit * grad_exit;
                Ok(LatentSurvivalTimeJet {
                    grad_entry: -entry_d1,
                    grad_exit,
                    neg_hess_entry: entry_d2,
                    neg_hess_exit: -d2_exit,
                })
            } else {
                let (exit_d1, exit_d2, _) =
                    logk_q_derivatives(quadctx, 1, row.mass_exit, mu, sigma)?;
                Ok(LatentSurvivalTimeJet {
                    grad_entry: -entry_d1,
                    grad_exit: 1.0 + exit_d1,
                    neg_hess_entry: entry_d2,
                    neg_hess_exit: -exit_d2,
                })
            }
        }
        LatentSurvivalEventType::IntervalCensored => {
            Err(LatentSurvivalError::UnsupportedConfiguration {
                reason:
                    "latent survival dynamic time derivatives do not implement interval censoring"
                        .to_string(),
            })
        }
    }
}
2918
2919fn dense_outer_accumulate<S>(
2920 target: &mut ndarray::ArrayBase<S, ndarray::Ix2>,
2921 weight: f64,
2922 x: ArrayView1<'_, f64>,
2923) where
2924 S: ndarray::DataMut<Elem = f64>,
2925{
2926 for a in 0..x.len() {
2927 let xa = x[a];
2928 if xa == 0.0 {
2929 continue;
2930 }
2931 for b in 0..x.len() {
2932 let xb = x[b];
2933 if xb == 0.0 {
2934 continue;
2935 }
2936 target[[a, b]] += weight * xa * xb;
2937 }
2938 }
2939}
2940
2941fn dense_symmetric_cross_accumulate<S>(
2942 target: &mut ndarray::ArrayBase<S, ndarray::Ix2>,
2943 weight: f64,
2944 x: ArrayView1<'_, f64>,
2945 y: ArrayView1<'_, f64>,
2946) where
2947 S: ndarray::DataMut<Elem = f64>,
2948{
2949 for a in 0..x.len() {
2950 let xa = x[a];
2951 let ya = y[a];
2952 if xa == 0.0 && ya == 0.0 {
2953 continue;
2954 }
2955 for b in 0..x.len() {
2956 let xb = x[b];
2957 let yb = y[b];
2958 let contribution = xa * yb + ya * xb;
2959 if contribution == 0.0 {
2960 continue;
2961 }
2962 target[[a, b]] += weight * contribution;
2963 }
2964 }
2965}
2966
2967fn build_latent_survival_row(
2968 row_index: usize,
2969 hazard_loading: HazardLoading,
2970 event_type: LatentSurvivalEventType,
2971 q_entry: f64,
2972 q_exit: f64,
2973 qdot_exit: f64,
2974 q_right: f64,
2975 unloaded_mass_entry: f64,
2976 unloaded_mass_exit: f64,
2977 unloaded_mass_right: f64,
2978 unloaded_hazard_exit: f64,
2979) -> Result<LatentSurvivalRow, LatentSurvivalError> {
2980 if !(q_entry.is_finite() && q_exit.is_finite()) {
2981 return Err(LatentSurvivalError::NumericalFailure {
2982 reason: format!(
2983 "latent survival requires finite q_entry and q_exit, got q_entry={q_entry}, q_exit={q_exit}"
2984 ),
2985 });
2986 }
2987 if q_exit < q_entry {
2988 return Err(LatentSurvivalError::NumericalFailure {
2989 reason: format!(
2990 "latent survival requires q_exit >= q_entry so cumulative mass is monotone, got q_entry={q_entry}, q_exit={q_exit}"
2991 ),
2992 });
2993 }
2994 if !(unloaded_mass_entry.is_finite()
2995 && unloaded_mass_exit.is_finite()
2996 && unloaded_hazard_exit.is_finite())
2997 {
2998 return Err(LatentSurvivalError::InvalidDataset {
2999 reason: format!(
3000 "latent survival requires finite unloaded components, got entry_mass={unloaded_mass_entry}, exit_mass={unloaded_mass_exit}, exit_hazard={unloaded_hazard_exit}"
3001 ),
3002 });
3003 }
3004 if unloaded_mass_entry < 0.0
3005 || unloaded_mass_exit < unloaded_mass_entry
3006 || unloaded_hazard_exit < 0.0
3007 {
3008 return Err(LatentSurvivalError::InvalidDataset {
3009 reason: format!(
3010 "latent survival requires unloaded masses/hazard to be non-negative and monotone, got entry_mass={unloaded_mass_entry}, exit_mass={unloaded_mass_exit}, exit_hazard={unloaded_hazard_exit}"
3011 ),
3012 });
3013 }
3014 let mass_entry = q_entry.exp();
3015 let mass_exit = q_exit.exp();
3016 let row = match event_type {
3017 LatentSurvivalEventType::RightCensored => {
3018 validate_unloaded_components_for_loading(
3019 "latent-survival",
3020 row_index,
3021 hazard_loading,
3022 unloaded_mass_entry,
3023 unloaded_mass_exit,
3024 Some(unloaded_hazard_exit),
3025 )?;
3026 LatentSurvivalRow::right_censored(
3027 mass_entry,
3028 mass_exit,
3029 unloaded_mass_entry,
3030 unloaded_mass_exit,
3031 )
3032 }
3033 LatentSurvivalEventType::ExactEvent => {
3034 validate_unloaded_components_for_loading(
3035 "latent-survival",
3036 row_index,
3037 hazard_loading,
3038 unloaded_mass_entry,
3039 unloaded_mass_exit,
3040 Some(unloaded_hazard_exit),
3041 )?;
3042 LatentSurvivalRow::exact_event(
3043 mass_entry,
3044 mass_exit,
3045 unloaded_mass_entry,
3046 unloaded_mass_exit,
3047 mass_exit
3048 * if qdot_exit.is_finite() && qdot_exit > 0.0 {
3049 qdot_exit
3050 } else {
3051 return Err(LatentSurvivalError::NumericalFailure {
3052 reason: format!(
3053 "latent survival exact event requires positive finite baseline hazard derivative, got {qdot_exit}"
3054 ),
3055 });
3056 },
3057 unloaded_hazard_exit,
3058 )
3059 }
3060 LatentSurvivalEventType::IntervalCensored => {
3061 if !q_right.is_finite() {
3069 return Err(LatentSurvivalError::NumericalFailure {
3070 reason: format!(
3071 "latent survival interval row {} requires a finite q_right, got {q_right}",
3072 row_index + 1
3073 ),
3074 });
3075 }
3076 if q_right < q_exit {
3077 return Err(LatentSurvivalError::NumericalFailure {
3078 reason: format!(
3079 "latent survival interval row {} requires q_right >= q_exit (R >= L) so the \
3080 survival-mass difference is non-negative, got q_left={q_exit}, q_right={q_right}",
3081 row_index + 1
3082 ),
3083 });
3084 }
3085 if !(unloaded_mass_right.is_finite()) || unloaded_mass_right < unloaded_mass_exit {
3086 return Err(LatentSurvivalError::InvalidDataset {
3087 reason: format!(
3088 "latent survival interval row {} requires a finite unloaded right mass >= unloaded left mass, got left={unloaded_mass_exit}, right={unloaded_mass_right}",
3089 row_index + 1
3090 ),
3091 });
3092 }
3093 let mass_right = q_right.exp();
3097 LatentSurvivalRow::interval_censored(
3098 mass_entry,
3099 mass_exit,
3100 mass_right,
3101 unloaded_mass_entry,
3102 unloaded_mass_exit,
3103 unloaded_mass_right,
3104 )
3105 }
3106 };
3107 row.validate()
3108 .map_err(|e| LatentSurvivalError::InvalidDataset {
3109 reason: e.to_string(),
3110 })?;
3111 Ok(row)
3112}
3113
/// Chain-rule scales that map derivatives of a row's log survival `L` onto the
/// latent-binary row log-likelihood `ℓ(L)`, as produced by
/// `binary_from_log_survival`.
///
/// * `event == 0`: `ℓ = L`.
/// * `event == 1`: `ℓ = ln(1 - e^L)`, with `L` clamped to at most `-1e-15`.
///
/// Callers combine the scales as
/// `grad_scale * ∇L` (gradient) and
/// `grad_scale * H_L + outer_scale * ∇L ∇Lᵀ` (negated Hessian),
/// where `H_L` is the survival Hessian returned by
/// `latent_survival_row_primary_gradient_hessian`.
#[derive(Clone, Copy)]
struct BinaryFromLogSurvival {
    /// Row log-likelihood `ℓ(L)`.
    log_lik: f64,
    /// `ℓ'(L)`; multiplies the log-survival gradient.
    grad_scale: f64,
    /// Coefficient on the negated log-survival Hessian. Always equal to
    /// `grad_scale` (asserted in `binary_from_log_survival`).
    neg_hess_scale: f64,
    /// `-ℓ''(L)` (non-negative); coefficient on the outer product of the
    /// log-survival gradient.
    outer_scale: f64,
    /// `ℓ''(L)`, the derivative of `grad_scale` with respect to `L`.
    grad_scale_prime: f64,
    /// `ℓ'''(L)`, the second derivative of `grad_scale` with respect to `L`.
    grad_scale_second: f64,
    /// `-ℓ'''(L)`, the derivative of `outer_scale` with respect to `L`.
    outer_scale_prime: f64,
    /// `-ℓ''''(L)`, the second derivative of `outer_scale` with respect to `L`.
    outer_scale_second: f64,
}
3140
/// Value and first four derivatives of `ℓ(L) = ln(1 - e^L)` with respect to
/// the log survival `L`.
///
/// The derivatives are written in terms of `s = e^L` (`survival`) and
/// `p = 1 - s` (`event_prob`):
///
/// * `ℓ    = ln p`
/// * `ℓ'   = -s / p`
/// * `ℓ''  = -s / p²`
/// * `ℓ''' = -s (1 + s) / p³`
/// * `ℓ''''= -(s + 4s² + s³) / p⁴`
///
/// Returns them in that order. The caller must guarantee `p > 0`.
#[inline]
fn binary_log_survival_scales(survival: f64, event_prob: f64) -> (f64, f64, f64, f64, f64) {
    let (s, p) = (survival, event_prob);
    let p_sq = p * p;
    let p_cube = p_sq * p;
    let p_quart = p_cube * p;
    let s_sq = s * s;
    let s_cube = s_sq * s;
    (
        p.ln(),
        -s / p,
        -s / p_sq,
        -s * (1.0 + s) / p_cube,
        -(s + 4.0 * s_sq + s_cube) / p_quart,
    )
}
3179
3180fn binary_from_log_survival(
3181 log_survival: f64,
3182 event: u8,
3183) -> Result<BinaryFromLogSurvival, LatentSurvivalError> {
3184 if event == 0 {
3185 return Ok(BinaryFromLogSurvival {
3187 log_lik: log_survival,
3188 grad_scale: 1.0,
3189 neg_hess_scale: 1.0,
3190 outer_scale: 0.0,
3191 grad_scale_prime: 0.0,
3192 grad_scale_second: 0.0,
3193 outer_scale_prime: 0.0,
3194 outer_scale_second: 0.0,
3195 });
3196 }
3197 if event != 1 {
3198 return Err(LatentSurvivalError::InvalidDataset {
3199 reason: format!("latent-binary requires event targets in {{0,1}}, got {event}"),
3200 });
3201 }
3202 const MAX_LOG_SURVIVAL: f64 = -1e-15;
3209 let log_survival = log_survival.min(MAX_LOG_SURVIVAL);
3210 let survival = log_survival.exp();
3211 let event_prob = 1.0 - survival;
3212 if !(event_prob.is_finite() && event_prob > 0.0) {
3213 return Err(LatentSurvivalError::NumericalFailure {
3214 reason: format!(
3215 "latent-binary encountered non-positive event probability from log survival {log_survival}"
3216 ),
3217 });
3218 }
3219 let (log_lik, ell_prime, ell_pp, ell_ppp, ell_pppp) =
3220 binary_log_survival_scales(survival, event_prob);
3221 let grad_scale = ell_prime;
3222 let neg_hess_scale = ell_prime; let outer_scale = -ell_pp;
3224 let grad_scale_prime = ell_pp;
3225 let grad_scale_second = ell_ppp;
3226 let outer_scale_prime = -ell_ppp;
3227 let outer_scale_second = -ell_pppp;
3228 assert!(
3233 (grad_scale - neg_hess_scale).abs() <= 1e-15 * grad_scale.abs().max(1.0),
3234 "binary_from_log_survival invariant: neg_hess_scale ({neg_hess_scale}) must equal grad_scale ({grad_scale}) so that grad_scale and the coefficient on neg_hessian share sign"
3235 );
3236 assert!(
3237 outer_scale >= 0.0 || !outer_scale.is_finite(),
3238 "binary_from_log_survival invariant: outer_scale (= -ℓ'') must be non-negative for event=1; got {outer_scale}"
3239 );
3240 Ok(BinaryFromLogSurvival {
3241 log_lik,
3242 grad_scale,
3243 neg_hess_scale,
3244 outer_scale,
3245 grad_scale_prime,
3246 grad_scale_second,
3247 outer_scale_prime,
3248 outer_scale_second,
3249 })
3250}
3251
3252impl LatentBinaryFamily {
3253 fn build_right_censored_row_at(
3259 &self,
3260 row_idx: usize,
3261 q_entry: f64,
3262 q_exit: f64,
3263 ) -> Result<LatentSurvivalRow, LatentSurvivalError> {
3264 build_latent_survival_row(
3265 row_idx,
3266 self.hazard_loading,
3267 LatentSurvivalEventType::RightCensored,
3268 q_entry,
3269 q_exit,
3270 1.0,
3271 q_exit,
3272 self.unloaded_mass_entry[row_idx],
3273 self.unloaded_mass_exit[row_idx],
3274 0.0,
3275 0.0,
3276 )
3277 }
3278
3279 fn joint_slices(&self) -> LatentSurvivalJointSlices {
3280 let p_time = self.x_time_exit.ncols();
3281 let p_mean = self.x_mean.ncols();
3282 LatentSurvivalJointSlices {
3283 time: 0..p_time,
3284 mean: p_time..p_time + p_mean,
3285 log_sigma: None,
3286 total: p_time + p_mean,
3287 }
3288 }
3289
3290 fn row_primary_direction_from_flat(
3291 &self,
3292 row: usize,
3293 slices: &LatentSurvivalJointSlices,
3294 d_beta_flat: &Array1<f64>,
3295 ) -> Array1<f64> {
3296 let mut out = Array1::<f64>::zeros(LATENT_SURVIVAL_PRIMARY_DIM);
3297 let d_time = d_beta_flat.slice(s![slices.time.clone()]);
3298 out[LATENT_SURVIVAL_PRIMARY_Q_ENTRY] = self.x_time_entry.row(row).dot(&d_time);
3299 out[LATENT_SURVIVAL_PRIMARY_Q_EXIT] = self.x_time_exit.row(row).dot(&d_time);
3300 out[LATENT_SURVIVAL_PRIMARY_MU] = self
3301 .x_mean
3302 .dot_row_view(row, d_beta_flat.slice(s![slices.mean.clone()]));
3303 out
3304 }
3305
3306 fn add_pullback_primary_gradient(
3307 &self,
3308 target: &mut Array1<f64>,
3309 row: usize,
3310 slices: &LatentSurvivalJointSlices,
3311 primary_gradient: &Array1<f64>,
3312 weight: f64,
3313 ) {
3314 for (primary_idx, time_vec) in [
3315 (LATENT_SURVIVAL_PRIMARY_Q_ENTRY, self.x_time_entry.row(row)),
3316 (LATENT_SURVIVAL_PRIMARY_Q_EXIT, self.x_time_exit.row(row)),
3317 ] {
3318 let scale = weight * primary_gradient[primary_idx];
3319 if scale == 0.0 {
3320 continue;
3321 }
3322 for i in 0..time_vec.len() {
3323 let xi = time_vec[i];
3324 if xi != 0.0 {
3325 target[slices.time.start + i] += scale * xi;
3326 }
3327 }
3328 }
3329
3330 let mean_scale = weight * primary_gradient[LATENT_SURVIVAL_PRIMARY_MU];
3331 if mean_scale != 0.0 {
3332 self.x_mean
3333 .axpy_row_into(
3334 row,
3335 mean_scale,
3336 &mut target.slice_mut(s![slices.mean.clone()]),
3337 )
3338 .unwrap_or_else(|error| {
3343 panic!(
3344 "latent binary mean gradient pullback dimension mismatch: row={row}, mean_slice={:?}, target_len={}, x_mean_cols={}, error={error}",
3345 slices.mean,
3346 target.len(),
3347 self.x_mean.ncols()
3348 )
3349 });
3350 }
3351 }
3352
3353 fn add_pullback_primary_hessian(
3354 &self,
3355 target: &mut Array2<f64>,
3356 row: usize,
3357 slices: &LatentSurvivalJointSlices,
3358 primary_hessian: &Array2<f64>,
3359 ) {
3360 {
3361 let time_target = &mut target.slice_mut(s![slices.time.clone(), slices.time.clone()]);
3362 dense_outer_accumulate(
3363 time_target,
3364 primary_hessian[[
3365 LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
3366 LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
3367 ]],
3368 self.x_time_entry.row(row),
3369 );
3370 dense_outer_accumulate(
3371 time_target,
3372 primary_hessian[[
3373 LATENT_SURVIVAL_PRIMARY_Q_EXIT,
3374 LATENT_SURVIVAL_PRIMARY_Q_EXIT,
3375 ]],
3376 self.x_time_exit.row(row),
3377 );
3378 dense_symmetric_cross_accumulate(
3379 time_target,
3380 primary_hessian[[
3381 LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
3382 LATENT_SURVIVAL_PRIMARY_Q_EXIT,
3383 ]],
3384 self.x_time_entry.row(row),
3385 self.x_time_exit.row(row),
3386 );
3387 }
3388
3389 let mean_weight = primary_hessian[[LATENT_SURVIVAL_PRIMARY_MU, LATENT_SURVIVAL_PRIMARY_MU]];
3390 self.x_mean
3391 .syr_row_into_view(
3392 row,
3393 mean_weight,
3394 target.slice_mut(s![slices.mean.clone(), slices.mean.clone()]),
3395 )
3396 .unwrap_or_else(|error| {
3397 panic!(
3403 "latent binary mean Hessian pullback dimension mismatch: row={row}, mean_slice={:?}, target_dim={:?}, x_mean_cols={}, error={error}",
3404 slices.mean,
3405 target.dim(),
3406 self.x_mean.ncols()
3407 )
3408 });
3409
3410 let mean_row = self
3411 .x_mean
3412 .try_row_chunk(row..row + 1)
3413 .unwrap_or_else(|error| {
3414 panic!(
3418 "latent binary mean pullback row chunk failed: row={row}, x_mean_rows={}, x_mean_cols={}, error={error}",
3419 self.x_mean.nrows(),
3420 self.x_mean.ncols()
3421 )
3422 });
3423 let mean_vec = mean_row.row(0);
3424 for (primary_idx, time_vec) in [
3425 (LATENT_SURVIVAL_PRIMARY_Q_ENTRY, self.x_time_entry.row(row)),
3426 (LATENT_SURVIVAL_PRIMARY_Q_EXIT, self.x_time_exit.row(row)),
3427 ] {
3428 let weight = primary_hessian[[primary_idx, LATENT_SURVIVAL_PRIMARY_MU]];
3429 if weight == 0.0 {
3430 continue;
3431 }
3432 for i in 0..time_vec.len() {
3433 let xi = time_vec[i];
3434 if xi == 0.0 {
3435 continue;
3436 }
3437 for j in 0..mean_vec.len() {
3438 let xj = mean_vec[j];
3439 if xj == 0.0 {
3440 continue;
3441 }
3442 target[[slices.time.start + i, slices.mean.start + j]] += weight * xi * xj;
3443 target[[slices.mean.start + j, slices.time.start + i]] += weight * xj * xi;
3444 }
3445 }
3446 }
3447 }
3448
3449 fn evaluate_exact_newton_joint_dense(
3450 &self,
3451 block_states: &[ParameterBlockState],
3452 ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
3453 let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
3454 let slices = self.joint_slices();
3455 let mut ll = 0.0;
3456 let mut gradient = Array1::<f64>::zeros(slices.total);
3457 let mut hessian = Array2::<f64>::zeros((slices.total, slices.total));
3458 for row_idx in 0..self.event_target.len() {
3459 let wi = self.weights[row_idx];
3460 if wi <= MIN_WEIGHT {
3461 continue;
3462 }
3463 let row =
3464 self.build_right_censored_row_at(row_idx, q_entry[row_idx], q_exit[row_idx])?;
3465 let (row_log_survival, survival_gradient, survival_hessian) =
3466 latent_survival_row_primary_gradient_hessian(
3467 &self.quadctx,
3468 &row,
3469 q_entry[row_idx],
3470 q_exit[row_idx],
3471 1.0,
3472 q_exit[row_idx],
3473 mu[row_idx],
3474 self.latent_sd,
3475 false,
3476 )?;
3477 let binary = binary_from_log_survival(row_log_survival, self.event_target[row_idx])?;
3478 ll += wi * binary.log_lik;
3479 let primary_gradient = binary.grad_scale * &survival_gradient;
3480 let mut primary_hessian = binary.grad_scale * survival_hessian;
3481 for a in 0..LATENT_SURVIVAL_PRIMARY_DIM {
3482 for b in 0..LATENT_SURVIVAL_PRIMARY_DIM {
3483 primary_hessian[[a, b]] +=
3484 binary.outer_scale * survival_gradient[a] * survival_gradient[b];
3485 }
3486 }
3487 self.add_pullback_primary_gradient(
3488 &mut gradient,
3489 row_idx,
3490 &slices,
3491 &primary_gradient,
3492 wi,
3493 );
3494 self.add_pullback_primary_hessian(
3495 &mut hessian,
3496 row_idx,
3497 &slices,
3498 &(wi * primary_hessian),
3499 );
3500 }
3501 Ok((ll, gradient, hessian))
3502 }
3503
3504 pub fn offset_channel_residuals(
3516 &self,
3517 block_states: &[ParameterBlockState],
3518 ) -> Result<crate::survival::OffsetChannelResiduals, String> {
3519 let n = self.event_target.len();
3520 if block_states.is_empty() {
3521 log::warn!(
3522 "LatentBinaryFamily::offset_channel_residuals: block_states is empty \
3523 (degraded fit); returning zero offset residuals (n={n})"
3524 );
3525 return Ok(crate::survival::OffsetChannelResiduals {
3526 exit: Array1::<f64>::zeros(n),
3527 entry: Array1::<f64>::zeros(n),
3528 derivative: Array1::<f64>::zeros(n),
3529 right: Array1::<f64>::zeros(n),
3530 });
3531 }
3532 let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
3533 let mut entry = Array1::<f64>::zeros(n);
3534 let mut exit = Array1::<f64>::zeros(n);
3535 for row_idx in 0..n {
3536 let wi = self.weights[row_idx];
3537 if wi <= MIN_WEIGHT {
3538 continue;
3539 }
3540 let row =
3541 self.build_right_censored_row_at(row_idx, q_entry[row_idx], q_exit[row_idx])?;
3542 let (row_log_survival, survival_gradient, _) =
3543 latent_survival_row_primary_gradient_hessian(
3544 &self.quadctx,
3545 &row,
3546 q_entry[row_idx],
3547 q_exit[row_idx],
3548 1.0,
3549 q_exit[row_idx],
3550 mu[row_idx],
3551 self.latent_sd,
3552 false,
3553 )?;
3554 let binary = binary_from_log_survival(row_log_survival, self.event_target[row_idx])?;
3555 entry[row_idx] =
3557 -wi * binary.grad_scale * survival_gradient[LATENT_SURVIVAL_PRIMARY_Q_ENTRY];
3558 exit[row_idx] =
3559 -wi * binary.grad_scale * survival_gradient[LATENT_SURVIVAL_PRIMARY_Q_EXIT];
3560 }
3561 Ok(crate::survival::OffsetChannelResiduals {
3562 exit,
3563 entry,
3564 derivative: Array1::<f64>::zeros(n),
3565 right: Array1::<f64>::zeros(n),
3568 })
3569 }
3570
    /// Directional derivative, along `d_beta_flat`, of the latent-binary dense
    /// joint (negated) Hessian `ℓ' H_L + (-ℓ'') g gᵀ`.
    ///
    /// For each row with weight above `MIN_WEIGHT`, let `u` be the row's primary
    /// direction, `g` the log-survival gradient, and `H_L` the survival Hessian.
    /// With `third = D_u H_L`, `g_u = -H_L u` and `t_u = g · u`, the row
    /// contribution is
    /// `ℓ'·third + ℓ''·t_u·H_L + (-ℓ''')·t_u·g gᵀ + (-ℓ'')·(g_u gᵀ + g g_uᵀ)`.
    /// The coefficients come from the `BinaryFromLogSurvival` fields
    /// `grad_scale`, `grad_scale_prime`, `outer_scale_prime`, and `outer_scale`.
    ///
    /// # Errors
    /// Returns an error when `d_beta_flat.len()` differs from `slices.total`,
    /// and propagates row construction, derivative, and binary-conversion
    /// failures.
    fn exact_newton_joint_hessian_directional_derivative_dense(
        &self,
        block_states: &[ParameterBlockState],
        d_beta_flat: &Array1<f64>,
    ) -> Result<Array2<f64>, String> {
        let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
        let slices = self.joint_slices();
        if d_beta_flat.len() != slices.total {
            return Err(format!(
                "latent binary joint dH direction length mismatch: got {}, expected {}",
                d_beta_flat.len(),
                slices.total
            ));
        }
        let mut out = Array2::<f64>::zeros((slices.total, slices.total));
        for row_idx in 0..self.event_target.len() {
            let wi = self.weights[row_idx];
            if wi <= MIN_WEIGHT {
                continue;
            }
            let row =
                self.build_right_censored_row_at(row_idx, q_entry[row_idx], q_exit[row_idx])?;
            // Right-censored evaluation: qdot_exit = 1.0 and q_right = q_exit are placeholders.
            let (row_log_survival, survival_gradient, survival_hessian) =
                latent_survival_row_primary_gradient_hessian(
                    &self.quadctx,
                    &row,
                    q_entry[row_idx],
                    q_exit[row_idx],
                    1.0,
                    q_exit[row_idx],
                    mu[row_idx],
                    self.latent_sd,
                    false,
                )?;
            let binary = binary_from_log_survival(row_log_survival, self.event_target[row_idx])?;
            let direction = self.row_primary_direction_from_flat(row_idx, &slices, d_beta_flat);
            // Survival Hessian tensor contracted once with `direction`.
            let third = latent_survival_row_primary_third_contracted(
                &self.quadctx,
                &row,
                q_entry[row_idx],
                q_exit[row_idx],
                1.0,
                q_exit[row_idx],
                mu[row_idx],
                self.latent_sd,
                &direction,
                false,
            )?;
            // Directional change of the log-survival gradient.
            let g_u = -survival_hessian.dot(&direction);
            // Directional change of the log survival itself.
            let t_u = survival_gradient.dot(&direction);
            // ℓ'·third + ℓ''·t_u·H_L
            let mut primary = binary.grad_scale * third;
            primary.scaled_add(binary.grad_scale_prime * t_u, &survival_hessian);
            // + (-ℓ''')·t_u·g gᵀ + (-ℓ'')·(g_u gᵀ + g g_uᵀ)
            for a in 0..LATENT_SURVIVAL_PRIMARY_DIM {
                for b in 0..LATENT_SURVIVAL_PRIMARY_DIM {
                    primary[[a, b]] += binary.outer_scale_prime
                        * t_u
                        * survival_gradient[a]
                        * survival_gradient[b]
                        + binary.outer_scale
                            * (g_u[a] * survival_gradient[b] + survival_gradient[a] * g_u[b]);
                }
            }
            self.add_pullback_primary_hessian(&mut out, row_idx, &slices, &(wi * primary));
        }
        Ok(out)
    }
3637
    /// Second directional derivative of the latent-binary joint Hessian along the pair
    /// of flat directions (`d_beta_u_flat`, `d_beta_v_flat`).
    ///
    /// Per row, this differentiates the first-derivative expression (see
    /// `exact_newton_joint_hessian_directional_derivative_dense`) once more along `v`.
    /// It uses the third (`third_u`, `third_v`) and fourth (`fourth`) contracted survival
    /// derivatives. The result is weighted by `wi` and pulled back into joint coefficient
    /// space.
    ///
    /// # Errors
    /// Returns `Err` if either direction has length other than `slices.total`, or if any
    /// per-row evaluation fails.
    fn exact_newton_joint_hessian_second_directional_derivative_dense(
        &self,
        block_states: &[ParameterBlockState],
        d_beta_u_flat: &Array1<f64>,
        d_beta_v_flat: &Array1<f64>,
    ) -> Result<Array2<f64>, String> {
        let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
        let slices = self.joint_slices();
        if d_beta_u_flat.len() != slices.total || d_beta_v_flat.len() != slices.total {
            return Err(format!(
                "latent binary joint d2H direction length mismatch: got {} and {}, expected {}",
                d_beta_u_flat.len(),
                d_beta_v_flat.len(),
                slices.total
            ));
        }
        let mut out = Array2::<f64>::zeros((slices.total, slices.total));
        for row_idx in 0..self.event_target.len() {
            let wi = self.weights[row_idx];
            // Rows at or below MIN_WEIGHT contribute nothing to the joint matrix.
            if wi <= MIN_WEIGHT {
                continue;
            }
            let row =
                self.build_right_censored_row_at(row_idx, q_entry[row_idx], q_exit[row_idx])?;
            // Same fixed kernel arguments as the first-derivative routine:
            // qdot_exit = 1.0, q_right = q_exit, fixed sigma, no log-sigma coordinate.
            let (row_log_survival, survival_gradient, survival_hessian) =
                latent_survival_row_primary_gradient_hessian(
                    &self.quadctx,
                    &row,
                    q_entry[row_idx],
                    q_exit[row_idx],
                    1.0,
                    q_exit[row_idx],
                    mu[row_idx],
                    self.latent_sd,
                    false,
                )?;
            let binary = binary_from_log_survival(row_log_survival, self.event_target[row_idx])?;
            let direction_u = self.row_primary_direction_from_flat(row_idx, &slices, d_beta_u_flat);
            let direction_v = self.row_primary_direction_from_flat(row_idx, &slices, d_beta_v_flat);
            // dN[u]
            let third_u = latent_survival_row_primary_third_contracted(
                &self.quadctx,
                &row,
                q_entry[row_idx],
                q_exit[row_idx],
                1.0,
                q_exit[row_idx],
                mu[row_idx],
                self.latent_sd,
                &direction_u,
                false,
            )?;
            // dN[v]
            let third_v = latent_survival_row_primary_third_contracted(
                &self.quadctx,
                &row,
                q_entry[row_idx],
                q_exit[row_idx],
                1.0,
                q_exit[row_idx],
                mu[row_idx],
                self.latent_sd,
                &direction_v,
                false,
            )?;
            // d²N[u, v]
            let fourth = latent_survival_row_primary_fourth_contracted(
                &self.quadctx,
                &row,
                q_entry[row_idx],
                q_exit[row_idx],
                1.0,
                q_exit[row_idx],
                mu[row_idx],
                self.latent_sd,
                &direction_u,
                &direction_v,
                false,
            )?;
            // First- and second-order changes of g and of the log-survival s.
            let g_u = -survival_hessian.dot(&direction_u);
            let g_v = -survival_hessian.dot(&direction_v);
            let g_uv = -third_v.dot(&direction_u);
            let t_u = survival_gradient.dot(&direction_u);
            let t_v = survival_gradient.dot(&direction_v);
            // d(t_u)/dv = uᵀ ∇²s v, written with the same sign convention as g_u.
            let l_uv = -direction_u.dot(&survival_hessian.dot(&direction_v));
            // Chain-rule coefficients for grad_scale(s) along u, v, and (u, v).
            let c_u = binary.grad_scale_prime * t_u;
            let c_v = binary.grad_scale_prime * t_v;
            let c_uv = binary.grad_scale_second * t_u * t_v + binary.grad_scale_prime * l_uv;
            // Chain-rule coefficients for outer_scale(s) along u, v, and (u, v).
            let o_u = binary.outer_scale_prime * t_u;
            let o_v = binary.outer_scale_prime * t_v;
            let o_uv = binary.outer_scale_second * t_u * t_v + binary.outer_scale_prime * l_uv;
            // d²/dudv [grad_scale * N]
            let mut primary = binary.grad_scale * fourth;
            primary.scaled_add(c_u, &third_v);
            primary.scaled_add(c_v, &third_u);
            primary.scaled_add(c_uv, &survival_hessian);
            // d²/dudv [outer_scale * g gᵀ]
            for a in 0..LATENT_SURVIVAL_PRIMARY_DIM {
                for b in 0..LATENT_SURVIVAL_PRIMARY_DIM {
                    primary[[a, b]] += o_uv * survival_gradient[a] * survival_gradient[b]
                        + o_v * (g_u[a] * survival_gradient[b] + survival_gradient[a] * g_u[b])
                        + o_u * (g_v[a] * survival_gradient[b] + survival_gradient[a] * g_v[b])
                        + binary.outer_scale
                            * (g_uv[a] * survival_gradient[b]
                                + g_u[a] * g_v[b]
                                + g_v[a] * g_u[b]
                                + survival_gradient[a] * g_uv[b]);
                }
            }
            // Pull the weighted primary-space block back into joint coefficient space.
            self.add_pullback_primary_hessian(&mut out, row_idx, &slices, &(wi * primary));
        }
        Ok(out)
    }
3746}
3747
3748trait LatentJointHessianFamily {
3762 fn ws_joint_slices(&self) -> LatentSurvivalJointSlices;
3763
3764 fn ws_evaluate_dense(
3765 &self,
3766 block_states: &[ParameterBlockState],
3767 ) -> Result<(f64, Array1<f64>, Array2<f64>), String>;
3768
3769 fn ws_dh_directional(
3770 &self,
3771 block_states: &[ParameterBlockState],
3772 d_beta_flat: &Array1<f64>,
3773 ) -> Result<Array2<f64>, String>;
3774
3775 fn ws_dh_second_directional(
3776 &self,
3777 block_states: &[ParameterBlockState],
3778 d_beta_u: &Array1<f64>,
3779 d_beta_v: &Array1<f64>,
3780 ) -> Result<Array2<f64>, String>;
3781
3782 fn ws_matvec_into(
3786 &self,
3787 slices: &LatentSurvivalJointSlices,
3788 block_states: &[ParameterBlockState],
3789 v: &Array1<f64>,
3790 out: &mut Array1<f64>,
3791 ) -> Result<bool, String>;
3792
3793 fn ws_label() -> &'static str;
3797}
3798
3799impl LatentJointHessianFamily for LatentSurvivalFamily {
3800 fn ws_joint_slices(&self) -> LatentSurvivalJointSlices {
3801 self.joint_slices()
3802 }
3803
3804 fn ws_evaluate_dense(
3805 &self,
3806 block_states: &[ParameterBlockState],
3807 ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
3808 self.evaluate_exact_newton_joint_dense(block_states)
3809 }
3810
3811 fn ws_dh_directional(
3812 &self,
3813 block_states: &[ParameterBlockState],
3814 d_beta_flat: &Array1<f64>,
3815 ) -> Result<Array2<f64>, String> {
3816 self.exact_newton_joint_hessian_directional_derivative_dense(block_states, d_beta_flat)
3817 }
3818
3819 fn ws_dh_second_directional(
3820 &self,
3821 block_states: &[ParameterBlockState],
3822 d_beta_u: &Array1<f64>,
3823 d_beta_v: &Array1<f64>,
3824 ) -> Result<Array2<f64>, String> {
3825 self.exact_newton_joint_hessian_second_directional_derivative_dense(
3826 block_states,
3827 d_beta_u,
3828 d_beta_v,
3829 )
3830 }
3831
3832 fn ws_matvec_into(
3833 &self,
3834 slices: &LatentSurvivalJointSlices,
3835 block_states: &[ParameterBlockState],
3836 v: &Array1<f64>,
3837 out: &mut Array1<f64>,
3838 ) -> Result<bool, String> {
3839 let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
3840 let q_right = self.time_q_right(block_states)?;
3841 let sigma = self.latent_sd(block_states)?;
3842 let include_log_sigma = slices.log_sigma.is_some();
3843 for row_idx in 0..self.event_target.len() {
3844 let wi = self.weights[row_idx];
3845 if wi <= MIN_WEIGHT {
3846 continue;
3847 }
3848 let row = self.build_row_at(
3849 row_idx,
3850 q_entry[row_idx],
3851 q_exit[row_idx],
3852 qdot_exit[row_idx],
3853 q_right[row_idx],
3854 )?;
3855 let (_, _, primary_hessian) = latent_survival_row_primary_gradient_hessian(
3856 &self.quadctx,
3857 &row,
3858 q_entry[row_idx],
3859 q_exit[row_idx],
3860 qdot_exit[row_idx],
3861 q_right[row_idx],
3862 mu[row_idx],
3863 sigma,
3864 include_log_sigma,
3865 )?;
3866 let primary_dir = self.row_primary_direction_from_flat(row_idx, slices, v);
3867 let primary_hv = primary_hessian.dot(&primary_dir);
3868 self.add_pullback_primary_gradient(out, row_idx, slices, &primary_hv, wi)?;
3869 }
3870 Ok(true)
3871 }
3872
3873 fn ws_label() -> &'static str {
3874 "survival"
3875 }
3876}
3877
3878impl LatentJointHessianFamily for LatentBinaryFamily {
3879 fn ws_joint_slices(&self) -> LatentSurvivalJointSlices {
3880 self.joint_slices()
3881 }
3882
3883 fn ws_evaluate_dense(
3884 &self,
3885 block_states: &[ParameterBlockState],
3886 ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
3887 self.evaluate_exact_newton_joint_dense(block_states)
3888 }
3889
3890 fn ws_dh_directional(
3891 &self,
3892 block_states: &[ParameterBlockState],
3893 d_beta_flat: &Array1<f64>,
3894 ) -> Result<Array2<f64>, String> {
3895 self.exact_newton_joint_hessian_directional_derivative_dense(block_states, d_beta_flat)
3896 }
3897
3898 fn ws_dh_second_directional(
3899 &self,
3900 block_states: &[ParameterBlockState],
3901 d_beta_u: &Array1<f64>,
3902 d_beta_v: &Array1<f64>,
3903 ) -> Result<Array2<f64>, String> {
3904 self.exact_newton_joint_hessian_second_directional_derivative_dense(
3905 block_states,
3906 d_beta_u,
3907 d_beta_v,
3908 )
3909 }
3910
3911 fn ws_matvec_into(
3912 &self,
3913 slices: &LatentSurvivalJointSlices,
3914 block_states: &[ParameterBlockState],
3915 v: &Array1<f64>,
3916 out: &mut Array1<f64>,
3917 ) -> Result<bool, String> {
3918 let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
3919 for row_idx in 0..self.event_target.len() {
3920 let wi = self.weights[row_idx];
3921 if wi <= MIN_WEIGHT {
3922 continue;
3923 }
3924 let row =
3925 self.build_right_censored_row_at(row_idx, q_entry[row_idx], q_exit[row_idx])?;
3926 let (row_log_survival, survival_gradient, survival_hessian) =
3927 latent_survival_row_primary_gradient_hessian(
3928 &self.quadctx,
3929 &row,
3930 q_entry[row_idx],
3931 q_exit[row_idx],
3932 1.0,
3933 q_exit[row_idx],
3934 mu[row_idx],
3935 self.latent_sd,
3936 false,
3937 )?;
3938 let binary = binary_from_log_survival(row_log_survival, self.event_target[row_idx])?;
3939 let primary_dir = self.row_primary_direction_from_flat(row_idx, slices, v);
3940 let mut primary_hv = binary.grad_scale * survival_hessian.dot(&primary_dir);
3941 let outer_dot = survival_gradient.dot(&primary_dir);
3942 for a in 0..LATENT_SURVIVAL_PRIMARY_DIM {
3943 primary_hv[a] += binary.outer_scale * survival_gradient[a] * outer_dot;
3944 }
3945 self.add_pullback_primary_gradient(out, row_idx, slices, &primary_hv, wi);
3946 }
3947 Ok(true)
3948 }
3949
3950 fn ws_label() -> &'static str {
3951 "binary"
3952 }
3953}
3954
3955struct LatentHessianWorkspace<F: LatentJointHessianFamily> {
3962 family: F,
3963 block_states: Vec<ParameterBlockState>,
3964 slices: LatentSurvivalJointSlices,
3965}
3966
3967impl<F: LatentJointHessianFamily> LatentHessianWorkspace<F> {
3968 fn new(family: F, block_states: Vec<ParameterBlockState>) -> Self {
3969 let slices = family.ws_joint_slices();
3970 Self {
3971 family,
3972 block_states,
3973 slices,
3974 }
3975 }
3976}
3977
3978impl<F> ExactNewtonJointHessianWorkspace for LatentHessianWorkspace<F>
3979where
3980 F: LatentJointHessianFamily + Send + Sync + 'static,
3981{
3982 fn hessian_dense(&self) -> Result<Option<Array2<f64>>, String> {
3983 self.family
3984 .ws_evaluate_dense(&self.block_states)
3985 .map(|(_, _, hessian)| Some(hessian))
3986 }
3987
3988 fn hessian_matvec(&self, v: &Array1<f64>) -> Result<Option<Array1<f64>>, String> {
3989 let mut out = Array1::<f64>::zeros(self.slices.total);
3990 self.hessian_matvec_into(v, &mut out)?;
3991 Ok(Some(out))
3992 }
3993
3994 fn hessian_matvec_into(&self, v: &Array1<f64>, out: &mut Array1<f64>) -> Result<bool, String> {
3995 if v.len() != self.slices.total || out.len() != self.slices.total {
3996 return Err(format!(
3997 "latent {} Hessian matvec dimension mismatch: v={} out={} expected={}",
3998 F::ws_label(),
3999 v.len(),
4000 out.len(),
4001 self.slices.total
4002 ));
4003 }
4004 out.fill(0.0);
4005 self.family
4006 .ws_matvec_into(&self.slices, &self.block_states, v, out)
4007 }
4008
4009 fn hessian_diagonal(&self) -> Result<Option<Array1<f64>>, String> {
4010 let dense = self.family.ws_evaluate_dense(&self.block_states)?.2;
4011 Ok(Some(dense.diag().to_owned()))
4012 }
4013
4014 fn directional_derivative(
4015 &self,
4016 d_beta_flat: &Array1<f64>,
4017 ) -> Result<Option<Array2<f64>>, String> {
4018 self.family
4019 .ws_dh_directional(&self.block_states, d_beta_flat)
4020 .map(Some)
4021 }
4022
4023 fn second_directional_derivative(
4024 &self,
4025 d_beta_u: &Array1<f64>,
4026 d_beta_v: &Array1<f64>,
4027 ) -> Result<Option<Array2<f64>>, String> {
4028 self.family
4029 .ws_dh_second_directional(&self.block_states, d_beta_u, d_beta_v)
4030 .map(Some)
4031 }
4032}
4033
/// Exact-Newton Hessian workspace specialised to `LatentSurvivalFamily`.
type LatentSurvivalHessianWorkspace = LatentHessianWorkspace<LatentSurvivalFamily>;
/// Exact-Newton Hessian workspace specialised to `LatentBinaryFamily`.
type LatentBinaryHessianWorkspace = LatentHessianWorkspace<LatentBinaryFamily>;
4036
4037impl CustomFamily for LatentSurvivalFamily {
4038 fn joint_jeffreys_term_required(&self) -> bool {
4042 true
4043 }
4044
4045 fn exact_newton_joint_hessian_beta_dependent(&self) -> bool {
4046 true
4047 }
4048
4049 fn has_explicit_joint_hessian(&self) -> bool {
4050 true
4051 }
4052
4053 fn levenberg_on_ill_conditioning(&self) -> bool {
4074 true
4075 }
4076
4077 fn coefficient_hessian_cost(&self, specs: &[ParameterBlockSpec]) -> u64 {
4078 crate::custom_family::joint_coupled_coefficient_hessian_cost(
4082 self.event_target.len() as u64,
4083 specs,
4084 )
4085 }
4086
4087 fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
4088 let (ll, joint_gradient, hess_time, hess_mean, hess_log_sigma) =
4089 self.evaluate_exact_newton_block_diagonals(block_states)?;
4090 let block_ranges = self.joint_block_ranges();
4091 let mut blockworking_sets = vec![
4092 BlockWorkingSet::ExactNewton {
4093 gradient: joint_gradient.slice(s![block_ranges[0].clone()]).to_owned(),
4094 hessian: SymmetricMatrix::Dense(hess_time),
4095 },
4096 BlockWorkingSet::ExactNewton {
4097 gradient: joint_gradient.slice(s![block_ranges[1].clone()]).to_owned(),
4098 hessian: SymmetricMatrix::Dense(hess_mean),
4099 },
4100 ];
4101 if let (Some(range), Some(hessian)) = (block_ranges.get(2).cloned(), hess_log_sigma) {
4102 blockworking_sets.push(BlockWorkingSet::ExactNewton {
4103 gradient: joint_gradient.slice(s![range]).to_owned(),
4104 hessian: SymmetricMatrix::Dense(hessian),
4105 });
4106 }
4107 Ok(FamilyEvaluation {
4108 log_likelihood: ll,
4109 blockworking_sets,
4110 })
4111 }
4112
4113 fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
4114 use rayon::iter::{IntoParallelIterator, ParallelIterator};
4115 let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
4116 let q_right = self.time_q_right(block_states)?;
4117 let latent_sd = self.latent_sd(block_states)?;
4118 let n = self.event_target.len();
4119 let contributions: Result<Vec<f64>, String> = (0..n)
4123 .into_par_iter()
4124 .map(|i| -> Result<f64, String> {
4125 let wi = self.weights[i];
4126 if wi <= MIN_WEIGHT {
4127 return Ok(0.0);
4128 }
4129 let row = self.build_row_at(i, q_entry[i], q_exit[i], qdot_exit[i], q_right[i])?;
4130 let jet = LatentSurvivalRowJet::evaluate(&self.quadctx, &row, mu[i], latent_sd)
4131 .map_err(|e| format!("LatentSurvivalFamily row {i}: {e}"))?;
4132 Ok(wi * jet.log_lik)
4133 })
4134 .collect();
4135 Ok(contributions?.into_iter().sum())
4136 }
4137
4138 fn block_linear_constraints(
4139 &self,
4140 _: &[ParameterBlockState],
4141 block_idx: usize,
4142 block_spec: &ParameterBlockSpec,
4143 ) -> Result<Option<LinearInequalityConstraints>, String> {
4144 assert!(!block_spec.name.is_empty());
4145 if block_idx == Self::BLOCK_TIME {
4146 Ok(self.time_linear_constraints.clone())
4147 } else {
4148 Ok(None)
4149 }
4150 }
4151
4152 fn exact_newton_joint_hessian(
4153 &self,
4154 block_states: &[ParameterBlockState],
4155 ) -> Result<Option<Array2<f64>>, String> {
4156 self.evaluate_exact_newton_joint_dense(block_states)
4157 .map(|(_, _, hessian)| Some(hessian))
4158 }
4159
4160 fn exact_newton_joint_hessian_workspace(
4161 &self,
4162 block_states: &[ParameterBlockState],
4163 _: &[ParameterBlockSpec],
4164 ) -> Result<Option<Arc<dyn ExactNewtonJointHessianWorkspace>>, String> {
4165 Ok(Some(Arc::new(LatentSurvivalHessianWorkspace::new(
4166 self.clone(),
4167 block_states.to_vec(),
4168 ))))
4169 }
4170
4171 fn exact_newton_joint_gradient_evaluation(
4172 &self,
4173 block_states: &[ParameterBlockState],
4174 _: &[ParameterBlockSpec],
4175 ) -> Result<Option<ExactNewtonJointGradientEvaluation>, String> {
4176 self.evaluate_exact_newton_joint_gradient_dense(block_states)
4177 .map(|(log_likelihood, gradient)| {
4178 Some(ExactNewtonJointGradientEvaluation {
4179 log_likelihood,
4180 gradient,
4181 })
4182 })
4183 }
4184
4185 fn exact_newton_joint_hessian_directional_derivative(
4186 &self,
4187 block_states: &[ParameterBlockState],
4188 d_beta_flat: &Array1<f64>,
4189 ) -> Result<Option<Array2<f64>>, String> {
4190 self.exact_newton_joint_hessian_directional_derivative_dense(block_states, d_beta_flat)
4191 .map(Some)
4192 }
4193
4194 fn exact_newton_joint_hessiansecond_directional_derivative(
4195 &self,
4196 block_states: &[ParameterBlockState],
4197 d_beta_u_flat: &Array1<f64>,
4198 d_beta_v_flat: &Array1<f64>,
4199 ) -> Result<Option<Array2<f64>>, String> {
4200 self.exact_newton_joint_hessian_second_directional_derivative_dense(
4201 block_states,
4202 d_beta_u_flat,
4203 d_beta_v_flat,
4204 )
4205 .map(Some)
4206 }
4207
4208 fn requires_joint_outer_hyper_path(&self) -> bool {
4209 true
4210 }
4211}
4212
impl CustomFamily for LatentBinaryFamily {
    fn joint_jeffreys_term_required(&self) -> bool {
        true
    }

    fn exact_newton_joint_hessian_beta_dependent(&self) -> bool {
        true
    }

    fn has_explicit_joint_hessian(&self) -> bool {
        true
    }

    fn levenberg_on_ill_conditioning(&self) -> bool {
        true
    }

    /// Cost estimate for the coupled joint Hessian, driven by the row count.
    fn coefficient_hessian_cost(&self, specs: &[ParameterBlockSpec]) -> u64 {
        crate::custom_family::joint_coupled_coefficient_hessian_cost(
            self.event_target.len() as u64,
            specs,
        )
    }

    /// Block-diagonal exact-Newton working sets for the time and mean blocks.
    ///
    /// There is no log-sigma block: this family uses the fixed `self.latent_sd`. Each row
    /// maps its latent log-survival through `binary_from_log_survival`. The mean and time
    /// blocks then accumulate `grad_scale`-scaled gradients and negative Hessians of the
    /// form `neg_hess_scale * (survival neg-Hessian) + outer_scale * (score outer product)`.
    ///
    /// # Errors
    /// Returns `Err` on non-finite predictors for a weighted row, or if row construction,
    /// the survival jets, the binary transform, or the mean design row read fails.
    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
        let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
        let n = self.event_target.len();
        let p_time = self.x_time_exit.ncols();
        let p_mean = self.x_mean.ncols();

        let mut ll = 0.0;
        let mut grad_time = Array1::<f64>::zeros(p_time);
        let mut hess_time = Array2::<f64>::zeros((p_time, p_time));
        let mut grad_mean = Array1::<f64>::zeros(p_mean);
        let mut hess_mean = Array2::<f64>::zeros((p_mean, p_mean));
        // Reused 1×p_mean buffer so each mean design row is read without reallocating.
        let mut mean_row_buf = Array2::<f64>::zeros((1, p_mean));

        for i in 0..n {
            let wi = self.weights[i];
            if wi <= MIN_WEIGHT {
                continue;
            }
            if !(q_entry[i].is_finite() && q_exit[i].is_finite() && mu[i].is_finite()) {
                return Err(format!(
                    "latent-binary row {i} contains non-finite predictors: q_entry={}, q_exit={}, mu={}",
                    q_entry[i], q_exit[i], mu[i]
                ));
            }
            let row = self.build_right_censored_row_at(i, q_entry[i], q_exit[i])?;
            let survival_jet =
                LatentSurvivalRowJet::evaluate(&self.quadctx, &row, mu[i], self.latent_sd)
                    .map_err(|e| format!("LatentBinaryFamily row {i}: {e}"))?;
            let binary = binary_from_log_survival(survival_jet.log_lik, self.event_target[i])?;
            ll += wi * binary.log_lik;

            // Mean block: chain rule through the survival score w.r.t. mu.
            self.x_mean
                .row_chunk_into(i..i + 1, mean_row_buf.view_mut())
                .map_err(|e| format!("LatentBinaryFamily row {i} mean row_chunk: {e}"))?;
            let mean_vec = mean_row_buf.row(0);
            let mean_grad_scale = wi * binary.grad_scale * survival_jet.score;
            for j in 0..p_mean {
                grad_mean[j] += mean_grad_scale * mean_vec[j];
            }
            let mean_neg_hess = wi
                * (binary.neg_hess_scale * survival_jet.neg_hessian
                    + binary.outer_scale * survival_jet.score * survival_jet.score);
            dense_outer_accumulate(&mut hess_mean, mean_neg_hess, mean_vec);

            // Time block: entry and exit design rows each receive their own jet terms.
            let time_jet =
                latent_survival_time_jet(&self.quadctx, &row, 0.0, mu[i], self.latent_sd)?;
            let t_entry = self.x_time_entry.row(i);
            let t_exit = self.x_time_exit.row(i);
            for j in 0..p_time {
                grad_time[j] += wi
                    * binary.grad_scale
                    * (time_jet.grad_entry * t_entry[j] + time_jet.grad_exit * t_exit[j]);
            }
            // The survival-curvature part has no entry/exit cross term here; the only
            // entry/exit cross term comes from the outer-product part below.
            dense_outer_accumulate(
                &mut hess_time,
                wi * binary.neg_hess_scale * time_jet.neg_hess_entry,
                t_entry,
            );
            dense_outer_accumulate(
                &mut hess_time,
                wi * binary.neg_hess_scale * time_jet.neg_hess_exit,
                t_exit,
            );
            // Outer-product part (g gᵀ) split into entry², exit², and symmetric cross.
            if binary.outer_scale != 0.0 {
                dense_outer_accumulate(
                    &mut hess_time,
                    wi * binary.outer_scale * time_jet.grad_entry * time_jet.grad_entry,
                    t_entry,
                );
                dense_outer_accumulate(
                    &mut hess_time,
                    wi * binary.outer_scale * time_jet.grad_exit * time_jet.grad_exit,
                    t_exit,
                );
                dense_symmetric_cross_accumulate(
                    &mut hess_time,
                    wi * binary.outer_scale * time_jet.grad_entry * time_jet.grad_exit,
                    t_entry,
                    t_exit,
                );
            }
        }

        Ok(FamilyEvaluation {
            log_likelihood: ll,
            blockworking_sets: vec![
                BlockWorkingSet::ExactNewton {
                    gradient: grad_time,
                    hessian: SymmetricMatrix::Dense(hess_time),
                },
                BlockWorkingSet::ExactNewton {
                    gradient: grad_mean,
                    hessian: SymmetricMatrix::Dense(hess_mean),
                },
            ],
        })
    }

    /// Weighted binary log-likelihood, evaluated serially row by row.
    ///
    /// NOTE(review): unlike `evaluate`, this path does not reject non-finite predictors
    /// up front; confirm that `LatentSurvivalRowJet::evaluate` errors on them.
    fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
        let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
        let mut ll = 0.0;
        for i in 0..self.event_target.len() {
            let wi = self.weights[i];
            if wi <= MIN_WEIGHT {
                continue;
            }
            let row = self.build_right_censored_row_at(i, q_entry[i], q_exit[i])?;
            let survival_jet =
                LatentSurvivalRowJet::evaluate(&self.quadctx, &row, mu[i], self.latent_sd)
                    .map_err(|e| format!("LatentBinaryFamily row {i}: {e}"))?;
            ll +=
                wi * binary_from_log_survival(survival_jet.log_lik, self.event_target[i])?.log_lik;
        }
        Ok(ll)
    }

    /// Only the time block carries linear inequality constraints.
    fn block_linear_constraints(
        &self,
        _: &[ParameterBlockState],
        block_idx: usize,
        block_spec: &ParameterBlockSpec,
    ) -> Result<Option<LinearInequalityConstraints>, String> {
        // NOTE(review): panics on an empty block name; this appears to exist mainly to
        // consume `block_spec` — confirm it is intended as a hard invariant.
        assert!(!block_spec.name.is_empty());
        if block_idx == Self::BLOCK_TIME {
            Ok(self.time_linear_constraints.clone())
        } else {
            Ok(None)
        }
    }

    fn exact_newton_joint_hessian(
        &self,
        block_states: &[ParameterBlockState],
    ) -> Result<Option<Array2<f64>>, String> {
        self.evaluate_exact_newton_joint_dense(block_states)
            .map(|(_, _, hessian)| Some(hessian))
    }

    /// Workspace owning a clone of this family and the current block states.
    fn exact_newton_joint_hessian_workspace(
        &self,
        block_states: &[ParameterBlockState],
        _: &[ParameterBlockSpec],
    ) -> Result<Option<Arc<dyn ExactNewtonJointHessianWorkspace>>, String> {
        Ok(Some(Arc::new(LatentBinaryHessianWorkspace::new(
            self.clone(),
            block_states.to_vec(),
        ))))
    }

    /// Joint log-likelihood and gradient.
    ///
    /// NOTE(review): this goes through `evaluate_exact_newton_joint_dense`, which also
    /// builds the dense joint Hessian only for it to be discarded. The survival family
    /// uses a gradient-only helper here; consider one for the binary family as well.
    fn exact_newton_joint_gradient_evaluation(
        &self,
        block_states: &[ParameterBlockState],
        _: &[ParameterBlockSpec],
    ) -> Result<Option<ExactNewtonJointGradientEvaluation>, String> {
        self.evaluate_exact_newton_joint_dense(block_states)
            .map(|(log_likelihood, gradient, _)| {
                Some(ExactNewtonJointGradientEvaluation {
                    log_likelihood,
                    gradient,
                })
            })
    }

    fn exact_newton_joint_hessian_directional_derivative(
        &self,
        block_states: &[ParameterBlockState],
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        self.exact_newton_joint_hessian_directional_derivative_dense(block_states, d_beta_flat)
            .map(Some)
    }

    fn exact_newton_joint_hessiansecond_directional_derivative(
        &self,
        block_states: &[ParameterBlockState],
        d_beta_u_flat: &Array1<f64>,
        d_beta_v_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        self.exact_newton_joint_hessian_second_directional_derivative_dense(
            block_states,
            d_beta_u_flat,
            d_beta_v_flat,
        )
        .map(Some)
    }

    fn requires_joint_outer_hyper_path(&self) -> bool {
        true
    }
}
4439
4440#[cfg(test)]
4441mod tests {
4442 use super::*;
4443 use crate::custom_family::BlockWorkingSet;
4444 use gam_linalg::matrix::DenseDesignMatrix;
4445 use ndarray::array;
4446
4447 fn learnable_sigma_test_family() -> LatentSurvivalFamily {
4448 LatentSurvivalFamily {
4449 event_target: array![1u8, 0u8],
4450 weights: array![1.0, 0.7],
4451 latent_sd_fixed: None,
4452 hazard_loading: HazardLoading::LoadedVsUnloaded,
4453 unloaded_mass_entry: array![0.02, 0.03],
4454 unloaded_mass_exit: array![0.05, 0.08],
4455 unloaded_hazard_exit: array![0.04, 0.0],
4456 x_time_entry: array![[1.0, -0.2], [0.4, 0.7]],
4457 x_time_exit: array![[1.3, 0.1], [0.9, 1.0]],
4458 x_time_derivative_exit: array![[0.8, 0.4], [0.6, 0.5]],
4459 x_time_right: array![[1.3, 0.1], [0.9, 1.0]],
4460 time_offset_right: Array1::zeros(2),
4461 unloaded_mass_right: Array1::zeros(2),
4462 x_mean: DesignMatrix::Dense(DenseDesignMatrix::from(array![[1.0, -0.3], [0.2, 0.9]])),
4463 time_linear_constraints: None,
4464 quadctx: Arc::new(QuadratureContext::new()),
4465 }
4466 }
4467
4468 fn learnable_sigma_test_joint_beta() -> Array1<f64> {
4469 array![0.15, 0.25, 0.1, -0.15, 0.35_f64.ln()]
4470 }
4471
4472 fn survival_stress_test_family(n: usize) -> LatentSurvivalFamily {
4473 LatentSurvivalFamily {
4474 event_target: Array1::from_iter((0..n).map(|i| if i % 3 == 0 { 1u8 } else { 0u8 })),
4475 weights: Array1::from_iter((0..n).map(|i| 0.55 + 0.03 * ((i % 7) as f64))),
4476 latent_sd_fixed: None,
4477 hazard_loading: HazardLoading::LoadedVsUnloaded,
4478 unloaded_mass_entry: Array1::from_iter(
4479 (0..n).map(|i| 0.015 + 0.0015 * ((i % 11) as f64)),
4480 ),
4481 unloaded_mass_exit: Array1::from_iter((0..n).map(|i| 0.06 + 0.002 * ((i % 13) as f64))),
4482 unloaded_hazard_exit: Array1::from_iter((0..n).map(|i| {
4483 if i % 4 == 0 {
4484 0.018 + 0.001 * ((i % 5) as f64)
4485 } else {
4486 0.0
4487 }
4488 })),
4489 x_time_entry: Array2::from_shape_fn((n, 4), |(i, j)| {
4490 0.2 + 0.03 * ((i + 2 * j) % 9) as f64 - if j == 1 { 0.12 } else { 0.0 }
4491 }),
4492 x_time_exit: Array2::from_shape_fn((n, 4), |(i, j)| {
4493 0.35 + 0.025 * ((2 * i + j) % 10) as f64 - if j == 2 { 0.08 } else { 0.0 }
4494 }),
4495 x_time_derivative_exit: Array2::from_shape_fn((n, 4), |(i, j)| {
4496 0.45 + 0.015 * ((i + 3 * j) % 8) as f64
4497 }),
4498 x_time_right: Array2::from_shape_fn((n, 4), |(i, j)| {
4499 0.35 + 0.025 * ((2 * i + j) % 10) as f64 - if j == 2 { 0.08 } else { 0.0 }
4500 }),
4501 time_offset_right: Array1::zeros(n),
4502 unloaded_mass_right: Array1::zeros(n),
4503 x_mean: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::from_shape_fn(
4504 (n, 3),
4505 |(i, j)| 0.1 + 0.04 * ((3 * i + j) % 7) as f64 - if j == 0 { 0.18 } else { 0.0 },
4506 ))),
4507 time_linear_constraints: None,
4508 quadctx: Arc::new(QuadratureContext::new()),
4509 }
4510 }
4511
4512 fn survival_stress_test_joint_beta() -> Array1<f64> {
4513 array![0.18, 0.11, 0.07, 0.13, -0.09, 0.05, 0.12, 0.42_f64.ln()]
4514 }
4515
4516 fn latent_survival_states_from_joint_beta(
4517 family: &LatentSurvivalFamily,
4518 joint_beta: &Array1<f64>,
4519 ) -> Vec<ParameterBlockState> {
4520 let slices = family.joint_slices();
4521 let n = family.event_target.len();
4522 let beta_time = joint_beta.slice(s![slices.time.clone()]).to_owned();
4523 let beta_mean = joint_beta.slice(s![slices.mean.clone()]).to_owned();
4524
4525 let mut eta_time = Array1::<f64>::zeros(3 * n);
4526 eta_time
4527 .slice_mut(s![0..n])
4528 .assign(&gam_linalg::faer_ndarray::fast_av(
4529 &family.x_time_entry,
4530 &beta_time,
4531 ));
4532 eta_time
4533 .slice_mut(s![n..2 * n])
4534 .assign(&gam_linalg::faer_ndarray::fast_av(
4535 &family.x_time_exit,
4536 &beta_time,
4537 ));
4538 eta_time
4539 .slice_mut(s![2 * n..3 * n])
4540 .assign(&gam_linalg::faer_ndarray::fast_av(
4541 &family.x_time_derivative_exit,
4542 &beta_time,
4543 ));
4544
4545 let mut states = vec![
4546 ParameterBlockState {
4547 beta: beta_time,
4548 eta: eta_time,
4549 },
4550 ParameterBlockState {
4551 beta: beta_mean.clone(),
4552 eta: family.x_mean.dot(&beta_mean),
4553 },
4554 ];
4555 if let Some(log_sigma) = slices.log_sigma {
4556 let beta_log_sigma = array![joint_beta[log_sigma.start]];
4557 states.push(ParameterBlockState {
4558 beta: beta_log_sigma.clone(),
4559 eta: beta_log_sigma,
4560 });
4561 }
4562 states
4563 }
4564
4565 fn max_relative_array1(left: &Array1<f64>, right: &Array1<f64>) -> f64 {
4566 left.iter()
4567 .zip(right.iter())
4568 .map(|(l, r)| (l - r).abs() / l.abs().max(r.abs()).max(1e-12))
4569 .fold(0.0_f64, f64::max)
4570 }
4571
4572 fn max_relative_array2(left: &Array2<f64>, right: &Array2<f64>) -> f64 {
4573 left.iter()
4574 .zip(right.iter())
4575 .map(|(l, r)| (l - r).abs() / l.abs().max(r.abs()).max(1e-12))
4576 .fold(0.0_f64, f64::max)
4577 }
4578
4579 fn frobenius_relative_array2(left: &Array2<f64>, right: &Array2<f64>) -> f64 {
4580 let mut diff2 = 0.0_f64;
4581 let mut scale2 = 0.0_f64;
4582 for (l, r) in left.iter().zip(right.iter()) {
4583 let d = l - r;
4584 diff2 += d * d;
4585 scale2 += l * l + r * r;
4586 }
4587 diff2.sqrt() / scale2.sqrt().max(1e-12)
4588 }
4589
4590 fn latent_survival_row_loglik_from_primary(
4591 quadctx: &QuadratureContext,
4592 row: &LatentSurvivalRow,
4593 primary: &Array1<f64>,
4594 ) -> f64 {
4595 let q_entry = primary[LATENT_SURVIVAL_PRIMARY_Q_ENTRY];
4596 let q_exit = primary[LATENT_SURVIVAL_PRIMARY_Q_EXIT];
4597 let qdot_exit = primary[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT];
4598 let q_right = primary[LATENT_SURVIVAL_PRIMARY_Q_RIGHT];
4599 let mu = primary[LATENT_SURVIVAL_PRIMARY_MU];
4600 let sigma = primary[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA].exp();
4601 latent_survival_row_primary_gradient_hessian(
4602 quadctx, row, q_entry, q_exit, qdot_exit, q_right, mu, sigma, true,
4603 )
4604 .expect("row primary evaluation")
4605 .0
4606 }
4607
4608 fn latent_test_specs(n: usize, block_dims: &[(&str, usize)]) -> Vec<ParameterBlockSpec> {
4609 block_dims
4610 .iter()
4611 .map(|(name, p)| ParameterBlockSpec {
4612 name: (*name).to_string(),
4613 design: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((n, *p)))),
4614 offset: Array1::zeros(n),
4615 penalties: Vec::new(),
4616 nullspace_dims: Vec::new(),
4617 initial_log_lambdas: Array1::zeros(0),
4618 initial_beta: None,
4619 gauge_priority: 100,
4620 jacobian_callback: None,
4621 stacked_design: None,
4622 stacked_offset: None,
4623 })
4624 .collect()
4625 }
4626
4627 fn fixed_sigma_binary_test_family() -> LatentBinaryFamily {
4628 LatentBinaryFamily {
4629 event_target: array![1u8, 0u8],
4630 weights: array![1.0, 0.7],
4631 latent_sd: 0.35,
4632 hazard_loading: HazardLoading::LoadedVsUnloaded,
4633 unloaded_mass_entry: array![0.02, 0.03],
4634 unloaded_mass_exit: array![0.05, 0.08],
4635 x_time_entry: array![[1.0, -0.2], [0.4, 0.7]],
4636 x_time_exit: array![[1.3, 0.1], [0.9, 1.0]],
4637 x_mean: DesignMatrix::Dense(DenseDesignMatrix::from(array![[1.0, -0.3], [0.2, 0.9]])),
4638 time_linear_constraints: None,
4639 quadctx: Arc::new(QuadratureContext::new()),
4640 }
4641 }
4642
4643 fn latent_binary_states_from_joint_beta(
4644 family: &LatentBinaryFamily,
4645 joint_beta: &Array1<f64>,
4646 ) -> Vec<ParameterBlockState> {
4647 let slices = family.joint_slices();
4648 let n = family.event_target.len();
4649 let beta_time = joint_beta.slice(s![slices.time.clone()]).to_owned();
4650 let beta_mean = joint_beta.slice(s![slices.mean.clone()]).to_owned();
4651
4652 let mut eta_time = Array1::<f64>::zeros(3 * n);
4653 eta_time
4654 .slice_mut(s![0..n])
4655 .assign(&gam_linalg::faer_ndarray::fast_av(
4656 &family.x_time_entry,
4657 &beta_time,
4658 ));
4659 eta_time
4660 .slice_mut(s![n..2 * n])
4661 .assign(&gam_linalg::faer_ndarray::fast_av(
4662 &family.x_time_exit,
4663 &beta_time,
4664 ));
4665
4666 vec![
4667 ParameterBlockState {
4668 beta: beta_time,
4669 eta: eta_time,
4670 },
4671 ParameterBlockState {
4672 beta: beta_mean.clone(),
4673 eta: family.x_mean.dot(&beta_mean),
4674 },
4675 ]
4676 }
4677
4678 use crate::survival::location_scale::{TimeBlockInput, TimeBlockMonotonicity};
4681
4682 fn validation_time_block(n: usize, p_time: usize) -> TimeBlockInput {
4686 let design = |fill: f64| {
4687 DesignMatrix::Dense(DenseDesignMatrix::from(Array2::from_elem(
4688 (n, p_time),
4689 fill,
4690 )))
4691 };
4692 TimeBlockInput {
4693 design_entry: design(0.1),
4694 design_exit: design(0.2),
4695 design_derivative_exit: design(0.3),
4696 offset_entry: Array1::zeros(n),
4697 offset_exit: Array1::zeros(n),
4698 derivative_offset_exit: Array1::zeros(n),
4699 time_monotonicity: TimeBlockMonotonicity::EnforcedByCoordinateCone,
4700 penalties: Vec::new(),
4701 nullspace_dims: Vec::new(),
4702 initial_log_lambdas: None,
4703 initial_beta: None,
4704 }
4705 }
4706
4707 fn empty_meanspec() -> TermCollectionSpec {
4708 TermCollectionSpec {
4709 linear_terms: Vec::new(),
4710 random_effect_terms: Vec::new(),
4711 smooth_terms: Vec::new(),
4712 }
4713 }
4714
4715 fn valid_survival_spec(n: usize, p_time: usize) -> LatentSurvivalTermSpec {
4718 LatentSurvivalTermSpec {
4719 age_entry: Array1::zeros(n),
4720 age_exit: Array1::from_elem(n, 1.0),
4721 event_target: Array1::from_shape_fn(n, |i| (i % 2) as u8),
4722 weights: Array1::from_elem(n, 1.0),
4723 derivative_guard: 0.0,
4724 time_block: validation_time_block(n, p_time),
4725 time_design_right: None,
4726 time_offset_right: None,
4727 unloaded_mass_entry: Array1::from_elem(n, 0.01),
4728 unloaded_mass_exit: Array1::from_elem(n, 0.05),
4729 unloaded_mass_right: Array1::zeros(0),
4730 unloaded_hazard_exit: Array1::from_elem(n, 0.02),
4731 meanspec: empty_meanspec(),
4732 mean_offset: Array1::zeros(n),
4733 }
4734 }
4735
4736 fn valid_binary_spec(n: usize, p_time: usize) -> LatentBinaryTermSpec {
4739 LatentBinaryTermSpec {
4740 age_entry: Array1::zeros(n),
4741 age_exit: Array1::from_elem(n, 1.0),
4742 event_target: Array1::from_shape_fn(n, |i| (i % 2) as u8),
4743 weights: Array1::from_elem(n, 1.0),
4744 derivative_guard: 0.0,
4745 time_block: validation_time_block(n, p_time),
4746 unloaded_mass_entry: Array1::from_elem(n, 0.01),
4747 unloaded_mass_exit: Array1::from_elem(n, 0.05),
4748 meanspec: empty_meanspec(),
4749 mean_offset: Array1::zeros(n),
4750 }
4751 }
4752
4753 fn loaded_frailty() -> FrailtySpec {
4754 FrailtySpec::HazardMultiplier {
4755 sigma_fixed: Some(0.3),
4756 loading: HazardLoading::LoadedVsUnloaded,
4757 }
4758 }
4759
4760 #[test]
4767 fn latent_interval_validation_parity_across_models() {
4768 let n = 2;
4769 let p_time = 2;
4770 let data = Array2::<f64>::zeros((n, 3));
4771
4772 let surv_sigma = validate_latent_survival_inputs(
4776 data.view(),
4777 &valid_survival_spec(n, p_time),
4778 &loaded_frailty(),
4779 )
4780 .expect("valid survival spec must validate");
4781 assert_eq!(surv_sigma, Some(0.3));
4782 let bin_sigma = validate_latent_binary_inputs(
4783 data.view(),
4784 &valid_binary_spec(n, p_time),
4785 &loaded_frailty(),
4786 )
4787 .expect("valid binary spec must validate");
4788 assert_eq!(bin_sigma, 0.3);
4789
4790 let empty = Array2::<f64>::zeros((0, 3));
4792 let surv_empty = validate_latent_survival_inputs(
4793 empty.view(),
4794 &valid_survival_spec(n, p_time),
4795 &loaded_frailty(),
4796 )
4797 .expect_err("empty data must be rejected");
4798 assert_eq!(
4799 surv_empty.to_string(),
4800 "latent-survival requires a non-empty dataset"
4801 );
4802 let bin_empty = validate_latent_binary_inputs(
4803 empty.view(),
4804 &valid_binary_spec(n, p_time),
4805 &loaded_frailty(),
4806 )
4807 .expect_err("empty data must be rejected");
4808 assert_eq!(
4809 bin_empty.to_string(),
4810 "latent-binary requires a non-empty dataset"
4811 );
4812
4813 let mut surv_bad = valid_survival_spec(n, p_time);
4817 surv_bad.weights = Array1::from_elem(n + 1, 1.0);
4818 let surv_size = validate_latent_survival_inputs(data.view(), &surv_bad, &loaded_frailty())
4819 .expect_err("size mismatch must be rejected");
4820 let surv_msg = surv_size.to_string();
4821 assert!(
4822 surv_msg.starts_with("latent-survival size mismatch")
4823 && surv_msg.contains("unloaded_hazard="),
4824 "survival size-mismatch message must include unloaded_hazard: {surv_msg}"
4825 );
4826 let mut bin_bad = valid_binary_spec(n, p_time);
4827 bin_bad.weights = Array1::from_elem(n + 1, 1.0);
4828 let bin_size = validate_latent_binary_inputs(data.view(), &bin_bad, &loaded_frailty())
4829 .expect_err("size mismatch must be rejected");
4830 let bin_msg = bin_size.to_string();
4831 assert!(
4832 bin_msg.starts_with("latent-binary size mismatch")
4833 && !bin_msg.contains("unloaded_hazard"),
4834 "binary size-mismatch message must omit unloaded_hazard: {bin_msg}"
4835 );
4836
4837 let mut surv_neg_hazard = valid_survival_spec(n, p_time);
4840 surv_neg_hazard.unloaded_hazard_exit[0] = -1.0;
4841 let surv_decomp =
4842 validate_latent_survival_inputs(data.view(), &surv_neg_hazard, &loaded_frailty())
4843 .expect_err("negative unloaded hazard must be rejected");
4844 assert_eq!(
4845 surv_decomp.to_string(),
4846 "latent-survival row 1 has invalid unloaded hazard decomposition: entry_mass=0.01, exit_mass=0.05, exit_hazard=-1"
4847 );
4848 let mut bin_bad_mass = valid_binary_spec(n, p_time);
4849 bin_bad_mass.unloaded_mass_exit[0] = 0.0; let bin_decomp =
4851 validate_latent_binary_inputs(data.view(), &bin_bad_mass, &loaded_frailty())
4852 .expect_err("non-monotone unloaded mass must be rejected");
4853 assert_eq!(
4854 bin_decomp.to_string(),
4855 "latent-binary row 1 has invalid unloaded mass decomposition: entry_mass=0.01, exit_mass=0"
4856 );
4857
4858 let mut surv_event = valid_survival_spec(n, p_time);
4861 surv_event.event_target[1] = 7;
4862 let surv_event_err =
4863 validate_latent_survival_inputs(data.view(), &surv_event, &loaded_frailty())
4864 .expect_err("invalid event target must be rejected");
4865 assert_eq!(
4866 surv_event_err.to_string(),
4867 "latent-survival row 2 has invalid event target 7; expected 0 or 1"
4868 );
4869 let mut bin_event = valid_binary_spec(n, p_time);
4870 bin_event.event_target[1] = 7;
4871 let bin_event_err =
4872 validate_latent_binary_inputs(data.view(), &bin_event, &loaded_frailty())
4873 .expect_err("invalid event target must be rejected");
4874 assert_eq!(
4875 bin_event_err.to_string(),
4876 "latent-binary row 2 has invalid event target 7; expected 0 or 1"
4877 );
4878
4879 let learnable = FrailtySpec::HazardMultiplier {
4882 sigma_fixed: None,
4883 loading: HazardLoading::LoadedVsUnloaded,
4884 };
4885 let surv_learnable = validate_latent_survival_inputs(
4886 data.view(),
4887 &valid_survival_spec(n, p_time),
4888 &learnable,
4889 )
4890 .expect("survival accepts a learnable latent scale");
4891 assert_eq!(surv_learnable, None);
4892 let bin_learnable =
4893 validate_latent_binary_inputs(data.view(), &valid_binary_spec(n, p_time), &learnable)
4894 .expect_err("binary requires a fixed latent scale");
4895 assert_eq!(
4896 bin_learnable.to_string(),
4897 "latent-binary currently requires a fixed hazard-multiplier sigma"
4898 );
4899
4900 let mut surv_time_bad = valid_survival_spec(n, p_time);
4903 surv_time_bad.time_block.design_entry = DesignMatrix::Dense(DenseDesignMatrix::from(
4904 Array2::from_elem((n, p_time + 1), 0.1),
4905 ));
4906 let surv_time_err =
4907 validate_latent_survival_inputs(data.view(), &surv_time_bad, &loaded_frailty())
4908 .expect_err("time block column mismatch must be rejected");
4909 assert!(
4910 surv_time_err
4911 .to_string()
4912 .starts_with("latent-survival time block column mismatch"),
4913 "unexpected survival time-block message: {surv_time_err}"
4914 );
4915 }
4916
4917 #[test]
4918 fn latent_survival_coefficient_cost_uses_joint_coupled_formula() {
4919 let family = learnable_sigma_test_family();
4925 let n = family.event_target.len() as u64;
4926 let p_time = 2u64;
4927 let p_mean = 2u64;
4928 let p_log_sigma = 1u64;
4929 let specs = vec![
4930 ParameterBlockSpec {
4931 name: "time".to_string(),
4932 design: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((
4933 n as usize,
4934 p_time as usize,
4935 )))),
4936 offset: Array1::zeros(n as usize),
4937 penalties: Vec::new(),
4938 nullspace_dims: Vec::new(),
4939 initial_log_lambdas: Array1::zeros(0),
4940 initial_beta: None,
4941 gauge_priority: 100,
4942 jacobian_callback: None,
4943 stacked_design: None,
4944 stacked_offset: None,
4945 },
4946 ParameterBlockSpec {
4947 name: "mean".to_string(),
4948 design: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((
4949 n as usize,
4950 p_mean as usize,
4951 )))),
4952 offset: Array1::zeros(n as usize),
4953 penalties: Vec::new(),
4954 nullspace_dims: Vec::new(),
4955 initial_log_lambdas: Array1::zeros(0),
4956 initial_beta: None,
4957 gauge_priority: 100,
4958 jacobian_callback: None,
4959 stacked_design: None,
4960 stacked_offset: None,
4961 },
4962 ParameterBlockSpec {
4963 name: "log_sigma".to_string(),
4964 design: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((
4965 n as usize,
4966 p_log_sigma as usize,
4967 )))),
4968 offset: Array1::zeros(n as usize),
4969 penalties: Vec::new(),
4970 nullspace_dims: Vec::new(),
4971 initial_log_lambdas: Array1::zeros(0),
4972 initial_beta: None,
4973 gauge_priority: 100,
4974 jacobian_callback: None,
4975 stacked_design: None,
4976 stacked_offset: None,
4977 },
4978 ];
4979 let p_total = p_time + p_mean + p_log_sigma;
4980 let expected_joint = n * p_total * p_total;
4981 let expected_block_diag =
4982 n * (p_time * p_time + p_mean * p_mean + p_log_sigma * p_log_sigma);
4983 assert_eq!(family.coefficient_hessian_cost(&specs), expected_joint);
4984 assert!(expected_joint > expected_block_diag);
4987 }
4988
4989 #[test]
4990 fn latent_family_planner_keeps_outer_hessian_at_large_n() {
4991 use crate::custom_family::custom_family_outer_derivatives;
4992 use gam_problem::{DeclaredHessianForm, Derivative};
4993
4994 let options = BlockwiseFitOptions::default();
4995 let large_n = 50_001;
4996
4997 let survival = learnable_sigma_test_family();
4998 let survival_specs =
4999 latent_test_specs(large_n, &[("time", 2), ("mean", 2), ("log_sigma", 1)]);
5000 let (surv_grad, surv_hess) =
5001 custom_family_outer_derivatives(&survival, &survival_specs, &options);
5002 assert_eq!(surv_grad, Derivative::Analytic);
5003 assert_eq!(surv_hess, DeclaredHessianForm::Either);
5004
5005 let binary = fixed_sigma_binary_test_family();
5006 let binary_specs = latent_test_specs(large_n, &[("time", 2), ("mean", 2)]);
5007 let (bin_grad, bin_hess) =
5008 custom_family_outer_derivatives(&binary, &binary_specs, &options);
5009 assert_eq!(bin_grad, Derivative::Analytic);
5010 assert_eq!(bin_hess, DeclaredHessianForm::Either);
5011 }
5012
5013 #[test]
5014 fn latent_families_arm_self_vanishing_levenberg_on_ill_conditioning() {
5015 assert!(
5028 learnable_sigma_test_family().levenberg_on_ill_conditioning(),
5029 "LatentSurvivalFamily must arm the self-vanishing Levenberg floor so the \
5030 indefinite interval-censored joint Hessian converges (see #1108)"
5031 );
5032 assert!(
5033 fixed_sigma_binary_test_family().levenberg_on_ill_conditioning(),
5034 "LatentBinaryFamily must arm the self-vanishing Levenberg floor on its \
5035 constrained coupled time block (see #1108)"
5036 );
5037 }
5038
5039 #[test]
5040 fn latent_binary_exact_joint_hessian_and_workspace_matvec_match_fd() {
5041 let family = fixed_sigma_binary_test_family();
5042 let beta = array![0.15, 0.25, 0.1, -0.15];
5043 let states = latent_binary_states_from_joint_beta(&family, &beta);
5044 let h = 1e-6;
5045
5046 let analytic_hessian = family
5047 .exact_newton_joint_hessian(&states)
5048 .expect("analytic latent binary joint hessian evaluation")
5049 .expect("latent binary should expose exact joint hessian");
5050
5051 for j in 0..beta.len() {
5052 let mut beta_plus = beta.clone();
5053 beta_plus[j] += h;
5054 let gradient_plus = family
5055 .exact_newton_joint_gradient_evaluation(
5056 &latent_binary_states_from_joint_beta(&family, &beta_plus),
5057 &[],
5058 )
5059 .expect("joint gradient plus")
5060 .expect("joint gradient should exist")
5061 .gradient;
5062
5063 let mut beta_minus = beta.clone();
5064 beta_minus[j] -= h;
5065 let gradient_minus = family
5066 .exact_newton_joint_gradient_evaluation(
5067 &latent_binary_states_from_joint_beta(&family, &beta_minus),
5068 &[],
5069 )
5070 .expect("joint gradient minus")
5071 .expect("joint gradient should exist")
5072 .gradient;
5073
5074 let fd_column = -((&gradient_plus - &gradient_minus) / (2.0 * h));
5075 let analytic_column = analytic_hessian.column(j).to_owned();
5076 let rel = max_relative_array1(&analytic_column, &fd_column);
5077 assert!(
5078 rel < 5e-4,
5079 "latent binary joint Hessian column {j} mismatch: rel={rel}, analytic={analytic_column:?}, fd={fd_column:?}"
5080 );
5081 }
5082
5083 let workspace = family
5084 .exact_newton_joint_hessian_workspace(&states, &[])
5085 .expect("latent binary hessian workspace")
5086 .expect("workspace should exist");
5087 let direction = array![0.4, -0.2, 0.3, 0.1];
5088 let hv = workspace
5089 .hessian_matvec(&direction)
5090 .expect("workspace matvec")
5091 .expect("workspace should support matvec");
5092 let dense_hv = analytic_hessian.dot(&direction);
5093 assert!(
5094 max_relative_array1(&hv, &dense_hv) < 1e-12,
5095 "latent binary workspace HVP mismatch: hv={hv:?}, dense={dense_hv:?}"
5096 );
5097
5098 let dh = workspace
5099 .directional_derivative(&direction)
5100 .expect("workspace dH")
5101 .expect("workspace should support dH");
5102 let fd_step = 1e-5;
5103 let h_plus = family
5104 .exact_newton_joint_hessian(&latent_binary_states_from_joint_beta(
5105 &family,
5106 &(beta.clone() + &(fd_step * &direction)),
5107 ))
5108 .expect("hessian plus")
5109 .expect("hessian plus should exist");
5110 let h_minus = family
5111 .exact_newton_joint_hessian(&latent_binary_states_from_joint_beta(
5112 &family,
5113 &(beta - &(fd_step * &direction)),
5114 ))
5115 .expect("hessian minus")
5116 .expect("hessian minus should exist");
5117 let fd_dh = (&h_plus - &h_minus) / (2.0 * fd_step);
5118 assert!(
5119 max_relative_array2(&dh, &fd_dh) < 2e-4,
5120 "latent binary workspace dH mismatch: dh={dh:?}, fd={fd_dh:?}"
5121 );
5122 }
5123
5124 #[test]
5125 fn latent_survival_learnable_sigma_block_matches_family_fd() {
5126 let family = learnable_sigma_test_family();
5127 let beta = learnable_sigma_test_joint_beta();
5128 let states = latent_survival_states_from_joint_beta(&family, &beta);
5129 let slices = family.joint_slices();
5130 let sigma_idx = slices
5131 .log_sigma
5132 .as_ref()
5133 .expect("learnable sigma test family should expose log_sigma")
5134 .start;
5135 let h = 2e-4;
5136
5137 let eval = family
5138 .evaluate(&states)
5139 .expect("learnable latent survival evaluation");
5140 let joint_gradient = family
5141 .exact_newton_joint_gradient_evaluation(&states, &[])
5142 .expect("joint gradient evaluation")
5143 .expect("joint gradient should exist")
5144 .gradient;
5145 let joint_hessian = family
5146 .exact_newton_joint_hessian(&states)
5147 .expect("joint hessian evaluation")
5148 .expect("joint hessian should exist");
5149 assert_eq!(eval.blockworking_sets.len(), 3);
5150
5151 let (block_grad, block_neg_hess) =
5152 match &eval.blockworking_sets[LatentSurvivalFamily::BLOCK_LOG_SIGMA] {
5153 BlockWorkingSet::ExactNewton { gradient, hessian } => {
5154 let neg_hess = match hessian {
5155 SymmetricMatrix::Dense(mat) => mat[[0, 0]],
5156 _ => panic!("log_sigma block should use a dense exact-Newton Hessian"),
5157 };
5158 (gradient[0], neg_hess)
5159 }
5160 _ => panic!("log_sigma block should use ExactNewton"),
5161 };
5162
5163 assert!((block_grad - joint_gradient[sigma_idx]).abs() < 1e-12);
5164 assert!((block_neg_hess - joint_hessian[[sigma_idx, sigma_idx]]).abs() < 1e-12);
5165
5166 let mut beta_plus = beta.clone();
5167 beta_plus[sigma_idx] += h;
5168 let ll_plus = family
5169 .log_likelihood_only(&latent_survival_states_from_joint_beta(&family, &beta_plus))
5170 .expect("ll plus");
5171 let ll_0 = family.log_likelihood_only(&states).expect("ll base");
5172 let mut beta_minus = beta.clone();
5173 beta_minus[sigma_idx] -= h;
5174 let ll_minus = family
5175 .log_likelihood_only(&latent_survival_states_from_joint_beta(
5176 &family,
5177 &beta_minus,
5178 ))
5179 .expect("ll minus");
5180
5181 let fd_grad = (ll_plus - ll_minus) / (2.0 * h);
5182 let fd_neg_hess = -(ll_plus - 2.0 * ll_0 + ll_minus) / (h * h);
5183 assert!(
5184 (joint_gradient[sigma_idx] - fd_grad).abs()
5185 / joint_gradient[sigma_idx]
5186 .abs()
5187 .max(fd_grad.abs())
5188 .max(1e-12)
5189 < 2e-3,
5190 "family log_sigma grad={}, fd={fd_grad}",
5191 joint_gradient[sigma_idx]
5192 );
5193 assert!(
5194 (joint_hessian[[sigma_idx, sigma_idx]] - fd_neg_hess).abs()
5195 / joint_hessian[[sigma_idx, sigma_idx]]
5196 .abs()
5197 .max(fd_neg_hess.abs())
5198 .max(1e-10)
5199 < 2e-2,
5200 "family log_sigma neg_hess={}, fd={fd_neg_hess}",
5201 joint_hessian[[sigma_idx, sigma_idx]]
5202 );
5203 }
5204
5205 #[test]
5206 fn latent_survival_exact_joint_hessian_matches_gradient_fd() {
5207 let family = learnable_sigma_test_family();
5208 let beta = learnable_sigma_test_joint_beta();
5209 let states = latent_survival_states_from_joint_beta(&family, &beta);
5210 let h = 1e-6;
5211
5212 let analytic_hessian = family
5213 .exact_newton_joint_hessian(&states)
5214 .expect("analytic joint hessian evaluation")
5215 .expect("latent survival should expose exact joint hessian");
5216
5217 for j in 0..beta.len() {
5218 let mut beta_plus = beta.clone();
5219 beta_plus[j] += h;
5220 let gradient_plus = family
5221 .exact_newton_joint_gradient_evaluation(
5222 &latent_survival_states_from_joint_beta(&family, &beta_plus),
5223 &[],
5224 )
5225 .expect("joint gradient plus")
5226 .expect("joint gradient should exist")
5227 .gradient;
5228
5229 let mut beta_minus = beta.clone();
5230 beta_minus[j] -= h;
5231 let gradient_minus = family
5232 .exact_newton_joint_gradient_evaluation(
5233 &latent_survival_states_from_joint_beta(&family, &beta_minus),
5234 &[],
5235 )
5236 .expect("joint gradient minus")
5237 .expect("joint gradient should exist")
5238 .gradient;
5239
5240 let fd_column = (&gradient_plus - &gradient_minus) / (2.0 * h);
5241 let analytic_column = analytic_hessian.column(j).to_owned();
5242 let rel = max_relative_array1(&analytic_column, &(-fd_column));
5243 assert!(
5244 rel < 5e-4,
5245 "joint Hessian column {j} mismatch: rel={rel}, analytic={analytic_column:?}, fd={:?}",
5246 -((&gradient_plus - &gradient_minus) / (2.0 * h))
5247 );
5248 }
5249 }
5250
5251 #[test]
5258 fn latent_survival_offset_channel_residuals_match_finite_difference() {
5259 let family = survival_stress_test_family(24);
5260 let beta = survival_stress_test_joint_beta();
5261 let states = latent_survival_states_from_joint_beta(&family, &beta);
5262 let n = family.event_target.len();
5263
5264 let residuals = family
5265 .offset_channel_residuals(&states)
5266 .expect("offset channel residuals");
5267 let sum_entry: f64 = residuals.entry.sum();
5268 let sum_exit: f64 = residuals.exit.sum();
5269 let sum_deriv: f64 = residuals.derivative.sum();
5270
5271 let neg_ll_with_offset = |channel: usize, delta: f64| -> f64 {
5273 let mut shifted = states.clone();
5274 let slice = match channel {
5275 0 => s![0..n],
5276 1 => s![n..2 * n],
5277 2 => s![2 * n..3 * n],
5278 _ => unreachable!(),
5279 };
5280 shifted[LatentSurvivalFamily::BLOCK_TIME]
5281 .eta
5282 .slice_mut(slice)
5283 .mapv_inplace(|v| v + delta);
5284 let (ll, _) = family
5285 .evaluate_exact_newton_joint_gradient_dense(&shifted)
5286 .expect("shifted joint gradient evaluation");
5287 -ll
5288 };
5289
5290 let h = 1e-6;
5291 let fd_entry = (neg_ll_with_offset(0, h) - neg_ll_with_offset(0, -h)) / (2.0 * h);
5292 let fd_exit = (neg_ll_with_offset(1, h) - neg_ll_with_offset(1, -h)) / (2.0 * h);
5293 let fd_deriv = (neg_ll_with_offset(2, h) - neg_ll_with_offset(2, -h)) / (2.0 * h);
5294
5295 assert!(
5296 (sum_entry - fd_entry).abs() <= 1e-5 * fd_entry.abs().max(1.0),
5297 "entry-channel residual sum mismatch: analytic={sum_entry}, fd={fd_entry}"
5298 );
5299 assert!(
5300 (sum_exit - fd_exit).abs() <= 1e-5 * fd_exit.abs().max(1.0),
5301 "exit-channel residual sum mismatch: analytic={sum_exit}, fd={fd_exit}"
5302 );
5303 assert!(
5304 (sum_deriv - fd_deriv).abs() <= 1e-5 * fd_deriv.abs().max(1.0),
5305 "derivative-channel residual sum mismatch: analytic={sum_deriv}, fd={fd_deriv}"
5306 );
5307 }
5308
5309 #[test]
5310 fn latent_survival_exact_joint_parallel_stress_is_repeatable() {
5311 let family = survival_stress_test_family(96);
5312 let beta = survival_stress_test_joint_beta();
5313 let states = latent_survival_states_from_joint_beta(&family, &beta);
5314 let direction_u = array![0.03, -0.02, 0.01, 0.04, -0.015, 0.025, -0.005, 0.02];
5315 let direction_v = array![-0.01, 0.035, -0.025, 0.015, 0.02, -0.01, 0.03, -0.015];
5316
5317 let (ll_a, grad_a) = family
5318 .evaluate_exact_newton_joint_gradient_dense(&states)
5319 .expect("stress joint gradient evaluation");
5320 let (ll_b, grad_b) = family
5321 .evaluate_exact_newton_joint_gradient_dense(&states)
5322 .expect("repeat stress joint gradient evaluation");
5323 assert_eq!(ll_a.to_bits(), ll_b.to_bits());
5324 assert_eq!(grad_a, grad_b);
5325
5326 let (joint_ll_a, joint_grad_a, hess_a) = family
5327 .evaluate_exact_newton_joint_dense(&states)
5328 .expect("stress joint dense evaluation");
5329 let (joint_ll_b, joint_grad_b, hess_b) = family
5330 .evaluate_exact_newton_joint_dense(&states)
5331 .expect("repeat stress joint dense evaluation");
5332 assert_eq!(joint_ll_a.to_bits(), joint_ll_b.to_bits());
5333 assert_eq!(joint_grad_a, joint_grad_b);
5334 assert_eq!(hess_a, hess_b);
5335 assert!(hess_a.iter().all(|value| value.is_finite()));
5336 assert!(max_relative_array2(&hess_a, &hess_a.t().to_owned()) < 1e-12);
5337
5338 let dh_a = family
5339 .exact_newton_joint_hessian_directional_derivative_dense(&states, &direction_u)
5340 .expect("stress joint dH evaluation");
5341 let dh_b = family
5342 .exact_newton_joint_hessian_directional_derivative_dense(&states, &direction_u)
5343 .expect("repeat stress joint dH evaluation");
5344 assert_eq!(dh_a, dh_b);
5345 assert!(dh_a.iter().all(|value| value.is_finite()));
5346 assert!(max_relative_array2(&dh_a, &dh_a.t().to_owned()) < 1e-12);
5347
5348 let d2h_a = family
5349 .exact_newton_joint_hessian_second_directional_derivative_dense(
5350 &states,
5351 &direction_u,
5352 &direction_v,
5353 )
5354 .expect("stress joint d2H evaluation");
5355 let d2h_b = family
5356 .exact_newton_joint_hessian_second_directional_derivative_dense(
5357 &states,
5358 &direction_u,
5359 &direction_v,
5360 )
5361 .expect("repeat stress joint d2H evaluation");
5362 assert_eq!(d2h_a, d2h_b);
5363 assert!(d2h_a.iter().all(|value| value.is_finite()));
5364 assert!(max_relative_array2(&d2h_a, &d2h_a.t().to_owned()) < 1e-12);
5365 }
5366
5367 #[test]
5368 fn latent_survival_exact_joint_dh_matches_hessian_fd() {
5369 let family = learnable_sigma_test_family();
5370 let beta = learnable_sigma_test_joint_beta();
5371 let states = latent_survival_states_from_joint_beta(&family, &beta);
5372 let h = 2e-4;
5373 let direction = array![0.07, -0.03, 0.05, 0.02, -0.04];
5374
5375 let analytic = family
5376 .exact_newton_joint_hessian_directional_derivative(&states, &direction)
5377 .expect("analytic joint dH evaluation")
5378 .expect("latent survival should expose exact joint dH");
5379
5380 let hessian_plus = family
5381 .exact_newton_joint_hessian(&latent_survival_states_from_joint_beta(
5382 &family,
5383 &(beta.clone() + h * &direction),
5384 ))
5385 .expect("joint hessian plus")
5386 .expect("joint hessian should exist");
5387 let hessian_minus = family
5388 .exact_newton_joint_hessian(&latent_survival_states_from_joint_beta(
5389 &family,
5390 &(beta.clone() - h * &direction),
5391 ))
5392 .expect("joint hessian minus")
5393 .expect("joint hessian should exist");
5394
5395 let fd = (&hessian_plus - &hessian_minus) / (2.0 * h);
5396 let rel = frobenius_relative_array2(&analytic, &fd);
5397 assert!(rel < 2e-3, "joint dH mismatch: rel={rel}");
5398 }
5399
5400 #[test]
5401 fn latent_survival_exact_joint_d2h_matches_directional_fd() {
5402 let family = learnable_sigma_test_family();
5403 let beta = learnable_sigma_test_joint_beta();
5404 let states = latent_survival_states_from_joint_beta(&family, &beta);
5405 let h = 5e-4;
5406 let direction_u = array![0.07, -0.03, 0.05, 0.02, -0.04];
5407 let direction_v = array![-0.02, 0.06, -0.01, 0.03, 0.05];
5408
5409 let analytic = family
5410 .exact_newton_joint_hessiansecond_directional_derivative(
5411 &states,
5412 &direction_u,
5413 &direction_v,
5414 )
5415 .expect("analytic joint d2H evaluation")
5416 .expect("latent survival should expose exact joint d2H");
5417 let swapped = family
5418 .exact_newton_joint_hessiansecond_directional_derivative(
5419 &states,
5420 &direction_v,
5421 &direction_u,
5422 )
5423 .expect("swapped analytic joint d2H evaluation")
5424 .expect("latent survival should expose exact joint d2H");
5425 let symmetry_rel = max_relative_array2(&analytic, &swapped);
5426 assert!(
5427 symmetry_rel < 1e-10,
5428 "joint d2H should be symmetric in directions, got rel={symmetry_rel}"
5429 );
5430
5431 let dh_plus = family
5432 .exact_newton_joint_hessian_directional_derivative(
5433 &latent_survival_states_from_joint_beta(
5434 &family,
5435 &(beta.clone() + h * &direction_v),
5436 ),
5437 &direction_u,
5438 )
5439 .expect("joint dH plus")
5440 .expect("joint dH should exist");
5441 let dh_minus = family
5442 .exact_newton_joint_hessian_directional_derivative(
5443 &latent_survival_states_from_joint_beta(
5444 &family,
5445 &(beta.clone() - h * &direction_v),
5446 ),
5447 &direction_u,
5448 )
5449 .expect("joint dH minus")
5450 .expect("joint dH should exist");
5451
5452 let fd = (&dh_plus - &dh_minus) / (2.0 * h);
5453 let rel = frobenius_relative_array2(&analytic, &fd);
5454 assert!(rel < 2.5e-2, "joint d2H mismatch: rel={rel}");
5455 }
5456
5457 #[test]
5458 fn latent_survival_row_primary_derivatives_match_fd() {
5459 let quadctx = QuadratureContext::new();
5460 let row = LatentSurvivalRow::exact_event(0.35, 1.4, 0.1, 0.45, 0.8, 0.12);
5461 let primary = array![
5466 0.35f64.ln(),
5467 1.4f64.ln(),
5468 0.8,
5469 1.6f64.ln(),
5470 -0.2,
5471 0.4f64.ln()
5472 ];
5473 let sigma = primary[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA].exp();
5474 let h_grad = 1e-6;
5475 let h_hess = 2e-4;
5476
5477 let (_, gradient, neg_hessian) = latent_survival_row_primary_gradient_hessian(
5478 &quadctx,
5479 &row,
5480 primary[LATENT_SURVIVAL_PRIMARY_Q_ENTRY],
5481 primary[LATENT_SURVIVAL_PRIMARY_Q_EXIT],
5482 primary[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT],
5483 primary[LATENT_SURVIVAL_PRIMARY_Q_RIGHT],
5484 primary[LATENT_SURVIVAL_PRIMARY_MU],
5485 sigma,
5486 true,
5487 )
5488 .expect("analytic row primary gradient/hessian");
5489
5490 for j in 0..LATENT_SURVIVAL_PRIMARY_DIM {
5491 let mut plus = primary.clone();
5492 plus[j] += h_grad;
5493 let mut minus = primary.clone();
5494 minus[j] -= h_grad;
5495 let fd_grad = (latent_survival_row_loglik_from_primary(&quadctx, &row, &plus)
5496 - latent_survival_row_loglik_from_primary(&quadctx, &row, &minus))
5497 / (2.0 * h_grad);
5498 let rel_grad =
5499 (gradient[j] - fd_grad).abs() / gradient[j].abs().max(fd_grad.abs()).max(1e-12);
5500 assert!(
5501 rel_grad < 2e-4,
5502 "row primary grad[{j}] mismatch: analytic={}, fd={fd_grad}, rel={rel_grad}",
5503 gradient[j]
5504 );
5505
5506 for k in 0..LATENT_SURVIVAL_PRIMARY_DIM {
5507 let mut pp = primary.clone();
5508 pp[j] += h_hess;
5509 pp[k] += h_hess;
5510 let mut pm = primary.clone();
5511 pm[j] += h_hess;
5512 pm[k] -= h_hess;
5513 let mut mp = primary.clone();
5514 mp[j] -= h_hess;
5515 mp[k] += h_hess;
5516 let mut mm = primary.clone();
5517 mm[j] -= h_hess;
5518 mm[k] -= h_hess;
5519 let fd_neg_hess = -(latent_survival_row_loglik_from_primary(&quadctx, &row, &pp)
5520 - latent_survival_row_loglik_from_primary(&quadctx, &row, &pm)
5521 - latent_survival_row_loglik_from_primary(&quadctx, &row, &mp)
5522 + latent_survival_row_loglik_from_primary(&quadctx, &row, &mm))
5523 / (4.0 * h_hess * h_hess);
5524 let analytic = neg_hessian[[j, k]];
5525 let abs_err = (analytic - fd_neg_hess).abs();
5526 let rel = abs_err / analytic.abs().max(fd_neg_hess.abs()).max(1e-10);
5527 assert!(
5528 abs_err < 2e-5 || rel < 2e-3,
5529 "row primary neg_hess[{j},{k}] mismatch: analytic={analytic}, fd={fd_neg_hess}, abs_err={abs_err}, rel={rel}"
5530 );
5531 }
5532 }
5533 }
5534
5535 #[test]
5536 fn latent_survival_interval_row_primary_derivatives_match_fd() {
5537 let quadctx = QuadratureContext::new();
5549 let q_entry = -1.2_f64; let q_exit = -0.4_f64; let q_right = 0.5_f64; let mu = -0.15_f64;
5554 let log_sigma = 0.3_f64; let row = LatentSurvivalRow::interval_censored(
5558 q_entry.exp(), q_exit.exp(), q_right.exp(), 0.01, 0.02, 0.05, );
5565 assert!(matches!(
5566 row.event_type,
5567 LatentSurvivalEventType::IntervalCensored
5568 ));
5569
5570 let primary = array![q_entry, q_exit, 0.7, q_right, mu, log_sigma];
5574 let sigma = primary[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA].exp();
5575 let h_grad = 1e-6;
5576 let h_hess = 2e-4;
5577
5578 let (_, gradient, neg_hessian) = latent_survival_row_primary_gradient_hessian(
5579 &quadctx,
5580 &row,
5581 primary[LATENT_SURVIVAL_PRIMARY_Q_ENTRY],
5582 primary[LATENT_SURVIVAL_PRIMARY_Q_EXIT],
5583 primary[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT],
5584 primary[LATENT_SURVIVAL_PRIMARY_Q_RIGHT],
5585 primary[LATENT_SURVIVAL_PRIMARY_MU],
5586 sigma,
5587 true,
5588 )
5589 .expect("analytic interval row primary gradient/hessian");
5590
5591 let value = latent_survival_row_loglik_from_primary(&quadctx, &row, &primary);
5594 assert!(
5595 value.is_finite(),
5596 "interval row log-likelihood must be finite on a well-posed bracket, got {value}"
5597 );
5598
5599 for j in 0..LATENT_SURVIVAL_PRIMARY_DIM {
5600 let mut plus = primary.clone();
5601 plus[j] += h_grad;
5602 let mut minus = primary.clone();
5603 minus[j] -= h_grad;
5604 let fd_grad = (latent_survival_row_loglik_from_primary(&quadctx, &row, &plus)
5605 - latent_survival_row_loglik_from_primary(&quadctx, &row, &minus))
5606 / (2.0 * h_grad);
5607 let rel_grad =
5608 (gradient[j] - fd_grad).abs() / gradient[j].abs().max(fd_grad.abs()).max(1e-12);
5609 assert!(
5610 rel_grad < 2e-4,
5611 "interval row primary grad[{j}] mismatch: analytic={}, fd={fd_grad}, rel={rel_grad}",
5612 gradient[j]
5613 );
5614
5615 for k in 0..LATENT_SURVIVAL_PRIMARY_DIM {
5616 let mut pp = primary.clone();
5617 pp[j] += h_hess;
5618 pp[k] += h_hess;
5619 let mut pm = primary.clone();
5620 pm[j] += h_hess;
5621 pm[k] -= h_hess;
5622 let mut mp = primary.clone();
5623 mp[j] -= h_hess;
5624 mp[k] += h_hess;
5625 let mut mm = primary.clone();
5626 mm[j] -= h_hess;
5627 mm[k] -= h_hess;
5628 let fd_neg_hess = -(latent_survival_row_loglik_from_primary(&quadctx, &row, &pp)
5629 - latent_survival_row_loglik_from_primary(&quadctx, &row, &pm)
5630 - latent_survival_row_loglik_from_primary(&quadctx, &row, &mp)
5631 + latent_survival_row_loglik_from_primary(&quadctx, &row, &mm))
5632 / (4.0 * h_hess * h_hess);
5633 let analytic = neg_hessian[[j, k]];
5634 let abs_err = (analytic - fd_neg_hess).abs();
5635 let rel = abs_err / analytic.abs().max(fd_neg_hess.abs()).max(1e-10);
5636 assert!(
5637 abs_err < 5e-5 || rel < 3e-3,
5638 "interval row primary neg_hess[{j},{k}] mismatch: analytic={analytic}, fd={fd_neg_hess}, abs_err={abs_err}, rel={rel}"
5639 );
5640 }
5641 }
5642 }
5643}