1use crate::model_types::EstimationError;
15use crate::probability::signed_log_sum_exp;
16use crate::quadrature::{
17 IntegratedExpectationMode, QuadratureContext, lognormal_laplace_unit_log_term_shared,
18};
19use serde::{Deserialize, Serialize};
20use std::fmt;
21
22#[derive(Debug, Clone)]
30pub enum LognormalKernelError {
31 InvalidSpec { reason: String },
34}
35
36impl_reason_error_boilerplate! {
37 LognormalKernelError {
38 InvalidSpec,
39 }
40}
41
42#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
46#[serde(rename_all = "kebab-case")]
47pub enum HazardLoading {
48 Full,
50 LoadedVsUnloaded,
55}
56
57#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
71#[serde(tag = "frailty_kind", rename_all = "kebab-case")]
72pub enum FrailtySpec {
73 None,
75 GaussianShift {
79 sigma_fixed: Option<f64>,
81 },
82 HazardMultiplier {
85 sigma_fixed: Option<f64>,
87 loading: HazardLoading,
89 },
90}
91
92impl FrailtySpec {
93 pub fn validate_for_marginal_slope(&self) -> Result<(), String> {
105 self.validate_for_marginal_slope_typed()
106 .map_err(|e| e.to_string())
107 }
108
109 pub fn validate_for_marginal_slope_typed(&self) -> Result<(), LognormalKernelError> {
113 match self {
114 Self::None | Self::GaussianShift { .. } => Ok(()),
115 Self::HazardMultiplier { .. } => Err(LognormalKernelError::InvalidSpec {
116 reason:
117 "HazardMultiplier frailty is not finite-state exact with score_warp/linkwiggle \
118 cubic marginal-slope families. Use GaussianShift frailty (exact probit scaling) \
119 or use the standalone latent-cloglog/latent-survival families instead."
120 .to_string(),
121 }),
122 }
123 }
124}
125
/// Probit frailty scale pieces for scale `sigma`: returns `(s, alpha)` with
/// `s = 1 / sqrt(1 + sigma^2)` and `alpha = sigma^2 / (1 + sigma^2)`.
///
/// For `|sigma| > 1` both are computed from `1 / |sigma|` so that squaring a
/// large `sigma` cannot overflow (`sigma = inf` yields `(0, 1)`).
#[inline]
fn probit_frailty_scale_components(sigma: f64) -> (f64, f64) {
    let magnitude = sigma.abs();
    if magnitude > 1.0 {
        // s = (1/|sigma|) / sqrt(1 + 1/sigma^2), alpha = 1 / (1 + 1/sigma^2).
        let r = magnitude.recip();
        let one_plus_r2 = 1.0 + r * r;
        return (r / one_plus_r2.sqrt(), one_plus_r2.recip());
    }
    let sigma_sq = sigma * sigma;
    let one_plus = 1.0 + sigma_sq;
    (one_plus.sqrt().recip(), sigma_sq / one_plus)
}
141
142#[derive(Clone, Copy, Debug)]
150pub struct ProbitFrailtyScaleJet {
151 pub s: f64,
153 pub alpha: f64,
155 pub ds: f64,
157 pub d2s: f64,
159}
160
161impl ProbitFrailtyScaleJet {
162 pub fn new(sigma: f64) -> Self {
167 let (s, alpha) = probit_frailty_scale_components(sigma);
168 Self {
169 s,
170 alpha,
171 ds: -alpha * s,
172 d2s: alpha * (3.0 * alpha - 2.0) * s,
173 }
174 }
175
176 pub fn from_log_sigma(log_sigma: f64) -> Self {
178 Self::new(log_sigma.exp())
179 }
180}
181
182#[inline]
183fn worst_mode(
184 a: IntegratedExpectationMode,
185 b: IntegratedExpectationMode,
186) -> IntegratedExpectationMode {
187 if a.rank() >= b.rank() { a } else { b }
188}
189
190#[inline]
201fn validate_kernel_inputs(m: f64, mu: f64, sigma: f64) -> Result<(), EstimationError> {
202 if !m.is_finite() || m < 0.0 {
203 crate::bail_invalid_estim!("lognormal kernel requires finite m >= 0, got {m}");
204 }
205 if !mu.is_finite() || !sigma.is_finite() || sigma < 0.0 {
206 crate::bail_invalid_estim!(
207 "lognormal kernel requires finite mu and sigma >= 0, got mu={mu}, sigma={sigma}"
208 );
209 }
210 Ok::<(), _>(())
211}
212
213#[inline]
214pub fn log_kernel_term(
215 quadctx: &QuadratureContext,
216 k: usize,
217 m: f64,
218 mu: f64,
219 sigma: f64,
220) -> Result<(f64, IntegratedExpectationMode), EstimationError> {
221 validate_kernel_inputs(m, mu, sigma)?;
222 let kf = k as f64;
223 let sigma2 = sigma * sigma;
224 if !sigma2.is_finite() {
225 crate::bail_invalid_estim!(
226 "lognormal kernel sigma is outside the finite exact-derivative range: sigma={sigma}"
227 );
228 }
229 let prefix_bound = kf * mu.abs() + 0.5 * kf * kf * sigma2;
230 if !prefix_bound.is_finite() {
231 crate::bail_invalid_estim!(
232 "lognormal kernel prefix is outside the finite exact-derivative range: k={k}, mu={mu}, sigma={sigma}"
233 );
234 }
235 let prefix = kf * mu + 0.5 * kf * kf * sigma2;
236 if m == 0.0 {
237 return Ok((prefix, IntegratedExpectationMode::ExactClosedForm));
238 }
239 let log_m = m.ln();
240 let shifted_bound = mu.abs() + kf * sigma2 + log_m.abs();
241 if !shifted_bound.is_finite() {
242 crate::bail_invalid_estim!(
243 "lognormal kernel shifted location is outside the finite exact-derivative range: k={k}, m={m}, mu={mu}, sigma={sigma}"
244 );
245 }
246 let shifted_mu = mu + kf * sigma2 + log_m;
247 let (log_laplace, mode) = lognormal_laplace_unit_log_term_shared(quadctx, shifted_mu, sigma);
252 Ok((prefix + log_laplace, mode))
253}
254
255#[derive(Clone, Debug)]
257pub struct LogLognormalKernelBundle {
258 pub log_values: Vec<f64>,
259 pub mode: IntegratedExpectationMode,
260}
261
262impl LogLognormalKernelBundle {
263 #[inline]
264 pub fn get(&self, k: usize) -> f64 {
265 self.log_values[k]
266 }
267
268 #[inline]
269 pub fn len(&self) -> usize {
270 self.log_values.len()
271 }
272}
273
274pub fn log_kernel_bundle(
277 quadctx: &QuadratureContext,
278 m: f64,
279 mu: f64,
280 sigma: f64,
281 max_k: usize,
282) -> Result<LogLognormalKernelBundle, EstimationError> {
283 validate_kernel_inputs(m, mu, sigma)?;
284 let mut log_values = Vec::with_capacity(max_k + 1);
285 let sigma2 = sigma * sigma;
286 if !sigma2.is_finite() {
287 crate::bail_invalid_estim!(
288 "lognormal kernel sigma is outside the finite exact-derivative range: sigma={sigma}"
289 );
290 }
291 let max_kf = max_k as f64;
292 let prefix_bound = max_kf * mu.abs() + 0.5 * max_kf * max_kf * sigma2;
293 if !prefix_bound.is_finite() {
294 crate::bail_invalid_estim!(
295 "lognormal kernel bundle prefix is outside the finite exact-derivative range: max_k={max_k}, mu={mu}, sigma={sigma}"
296 );
297 }
298 if m == 0.0 {
299 let mut prefix = 0.0;
300 for k in 0..=max_k {
301 log_values.push(prefix);
302 prefix += mu + (k as f64 + 0.5) * sigma2;
303 }
304 return Ok(LogLognormalKernelBundle {
305 log_values,
306 mode: IntegratedExpectationMode::ExactClosedForm,
307 });
308 }
309
310 let log_m = m.ln();
311 let shifted_bound = mu.abs() + max_kf * sigma2 + log_m.abs();
312 if !shifted_bound.is_finite() {
313 crate::bail_invalid_estim!(
314 "lognormal kernel bundle shifted location is outside the finite exact-derivative range: max_k={max_k}, m={m}, mu={mu}, sigma={sigma}"
315 );
316 }
317 let mut shifted_mu = mu + log_m;
318 let mut prefix = 0.0;
319 let mut mode = IntegratedExpectationMode::ExactClosedForm;
320 for k in 0..=max_k {
321 let (log_laplace, val_mode) =
322 lognormal_laplace_unit_log_term_shared(quadctx, shifted_mu, sigma);
323 log_values.push(if log_laplace.is_finite() {
324 prefix + log_laplace
325 } else {
326 f64::NEG_INFINITY
327 });
328 mode = worst_mode(mode, val_mode);
329 prefix += mu + (k as f64 + 0.5) * sigma2;
330 shifted_mu += sigma2;
331 }
332 Ok(LogLognormalKernelBundle { log_values, mode })
333}
334
335pub fn kernel_ratio_jet(
345 log_bundle: &LogLognormalKernelBundle,
346 k: usize,
347 m: f64,
348 order: usize,
349) -> [f64; 5] {
350 let kf = k as f64;
351 let log_k0 = log_bundle.get(k);
352
353 let mut rk = [0.0f64; 5]; for r in 1..=order.min(4) {
358 let delta = log_bundle.get(k + r) - log_k0;
359 rk[r] = if delta.is_finite() {
360 delta.exp()
361 } else if delta > 0.0 {
362 f64::INFINITY
363 } else {
364 0.0
365 };
366 }
367
368 let mut jet = [0.0; 5];
369 jet[0] = 1.0;
370
371 if order >= 1 {
372 jet[1] = kf - m * rk[1];
373 }
374 if order >= 2 {
375 jet[2] = kf * kf - (2.0 * kf + 1.0) * m * rk[1] + m * m * rk[2];
376 }
377 if order >= 3 {
378 jet[3] = kf * kf * kf - (3.0 * kf * kf + 3.0 * kf + 1.0) * m * rk[1]
379 + 3.0 * (kf + 1.0) * m * m * rk[2]
380 - m * m * m * rk[3];
381 }
382 if order >= 4 {
383 let k2 = kf * kf;
384 let k3 = k2 * kf;
385 let k4 = k3 * kf;
386 let m2 = m * m;
387 let m3 = m2 * m;
388 let m4 = m3 * m;
389 jet[4] = k4 - (4.0 * k3 + 6.0 * k2 + 4.0 * kf + 1.0) * m * rk[1]
390 + (6.0 * k2 + 12.0 * kf + 7.0) * m2 * rk[2]
391 - (4.0 * kf + 6.0) * m3 * rk[3]
392 + m4 * rk[4];
393 }
394
395 jet
396}
397
398pub use crate::quadrature::{
404 LatentCLogLogJet5, latent_cloglog_inverse_link_jet, latent_cloglog_jet5,
405};
406
/// One term `coeff * K_k(m)` of a signed sum of lognormal kernels.
#[derive(Clone, Copy, Debug)]
pub struct KernelSumTerm {
    /// Signed multiplier of the kernel; the sign is carried separately from
    /// the log-magnitude during summation.
    pub coeff: f64,
    /// Kernel order, i.e. the index into a `LogLognormalKernelBundle`.
    pub k: usize,
    /// Cumulative mass at which the kernel is evaluated (must be finite, `>= 0`).
    pub m: f64,
}
419
420#[derive(Clone, Copy, Debug)]
434pub struct LogKernelSumJet {
435 pub value: f64,
437 pub d1: f64,
439 pub d2: f64,
441 pub d3: f64,
443 pub d4: f64,
445 pub mode: IntegratedExpectationMode,
446}
447
448impl LogKernelSumJet {
449 #[inline]
450 fn non_positive(mode: IntegratedExpectationMode) -> Self {
451 Self {
452 value: f64::NEG_INFINITY,
453 d1: 0.0,
454 d2: 0.0,
455 d3: 0.0,
456 d4: 0.0,
457 mode,
458 }
459 }
460
461 #[inline]
462 fn from_log_value_and_ratios(
463 value: f64,
464 ratio: [f64; 5],
465 mode: IntegratedExpectationMode,
466 ) -> Self {
467 let r1 = ratio[1];
468 let r2 = ratio[2];
469 let r3 = ratio[3];
470 let r4 = ratio[4];
471 Self {
472 value,
473 d1: r1,
474 d2: r2 - r1 * r1,
475 d3: r3 - 3.0 * r1 * r2 + 2.0 * r1 * r1 * r1,
476 d4: r4 - 4.0 * r1 * r3 - 3.0 * r2 * r2 + 12.0 * r1 * r1 * r2 - 6.0 * r1.powi(4),
477 mode,
478 }
479 }
480
481 #[inline]
482 fn term_log_mag_and_ratio(
483 bundle: &LogLognormalKernelBundle,
484 term: KernelSumTerm,
485 ) -> (f64, [f64; 5]) {
486 (
487 term.coeff.abs().ln() + bundle.get(term.k),
488 kernel_ratio_jet(bundle, term.k, term.m, 4),
491 )
492 }
493
494 fn evaluate_two_terms(
495 quadctx: &QuadratureContext,
496 t0: KernelSumTerm,
497 t1: KernelSumTerm,
498 mu: f64,
499 sigma: f64,
500 ) -> Result<Self, EstimationError> {
501 let max_k_needed = t0.k.max(t1.k) + 4;
502 let bundle0 = log_kernel_bundle(quadctx, t0.m, mu, sigma, max_k_needed)?;
503 let mut overall_mode = bundle0.mode;
504 let bundle1_owned = if (t0.m - t1.m).abs() < 1e-300 {
505 None
506 } else {
507 let bundle1 = log_kernel_bundle(quadctx, t1.m, mu, sigma, max_k_needed)?;
508 overall_mode = worst_mode(overall_mode, bundle1.mode);
509 Some(bundle1)
510 };
511 let bundle1 = bundle1_owned.as_ref().unwrap_or(&bundle0);
512
513 let (log_mag0, ratio0) = Self::term_log_mag_and_ratio(&bundle0, t0);
514 let (log_mag1, ratio1) = Self::term_log_mag_and_ratio(bundle1, t1);
515 let log_mags = [log_mag0, log_mag1];
516 let signs = [t0.coeff.signum(), t1.coeff.signum()];
517 let (log_s, sign_s) = signed_log_sum_exp(&log_mags, &signs);
518 if !log_s.is_finite() || sign_s <= 0.0 {
519 return Ok(Self::non_positive(overall_mode));
520 }
521
522 let w0 = sign_s * signs[0] * (log_mag0 - log_s).exp();
523 let w1 = sign_s * signs[1] * (log_mag1 - log_s).exp();
524 let wr1 = w0 * ratio0[1] + w1 * ratio1[1];
525 let wr2 = w0 * ratio0[2] + w1 * ratio1[2];
526 let wr3 = w0 * ratio0[3] + w1 * ratio1[3];
527 let wr4 = w0 * ratio0[4] + w1 * ratio1[4];
528
529 Ok(Self {
530 value: log_s,
531 d1: wr1,
532 d2: wr2 - wr1 * wr1,
533 d3: wr3 - 3.0 * wr1 * wr2 + 2.0 * wr1 * wr1 * wr1,
534 d4: wr4 - 4.0 * wr1 * wr3 - 3.0 * wr2 * wr2 + 12.0 * wr1 * wr1 * wr2
535 - 6.0 * wr1.powi(4),
536 mode: overall_mode,
537 })
538 }
539
540 pub fn single_term(
545 quadctx: &QuadratureContext,
546 k: usize,
547 m: f64,
548 mu: f64,
549 sigma: f64,
550 ) -> Result<Self, EstimationError> {
551 let max_k_needed = k + 4;
552 let lb = log_kernel_bundle(quadctx, m, mu, sigma, max_k_needed)?;
553 Ok(Self::from_log_value_and_ratios(
554 lb.get(k),
555 kernel_ratio_jet(&lb, k, m, 4),
556 lb.mode,
557 ))
558 }
559
560 pub fn evaluate(
574 quadctx: &QuadratureContext,
575 terms: &[KernelSumTerm],
576 mu: f64,
577 sigma: f64,
578 ) -> Result<Self, EstimationError> {
579 if terms.is_empty() {
580 crate::bail_invalid_estim!("KernelSumJet requires at least one term");
583 }
584
585 if terms.len() == 1 {
587 let t = &terms[0];
588 if t.coeff <= 0.0 {
589 return Ok(Self::non_positive(
593 IntegratedExpectationMode::ExactClosedForm,
594 ));
595 }
596 let jet = Self::single_term(quadctx, t.k, t.m, mu, sigma)?;
597 return Ok(Self {
598 value: t.coeff.ln() + jet.value,
599 d1: jet.d1,
600 d2: jet.d2,
601 d3: jet.d3,
602 d4: jet.d4,
603 mode: jet.mode,
604 });
605 }
606 if terms.len() == 2 {
607 return Self::evaluate_two_terms(quadctx, terms[0], terms[1], mu, sigma);
608 }
609
610 let max_k_needed = terms.iter().map(|t| t.k).max().unwrap_or(0) + 4;
611
612 let mut log_bundles: Vec<(f64, LogLognormalKernelBundle)> = Vec::with_capacity(2);
614 let mut overall_mode = IntegratedExpectationMode::ExactClosedForm;
615 for term in terms {
616 if !log_bundles
617 .iter()
618 .any(|(m, _)| (*m - term.m).abs() < 1e-300)
619 {
620 let b = log_kernel_bundle(quadctx, term.m, mu, sigma, max_k_needed)?;
621 overall_mode = worst_mode(overall_mode, b.mode);
622 log_bundles.push((term.m, b));
623 }
624 }
625
626 let get_lb = |m: f64| -> &LogLognormalKernelBundle {
627 &log_bundles
628 .iter()
629 .find(|(bm, _)| (*bm - m).abs() < 1e-300)
630 .unwrap()
631 .1
632 };
633
634 let mut log_mags: Vec<f64> = Vec::with_capacity(terms.len());
636 let mut signs: Vec<f64> = Vec::with_capacity(terms.len());
637 let mut ratios: Vec<[f64; 5]> = Vec::with_capacity(terms.len());
638 for term in terms {
639 let lb = get_lb(term.m);
640 log_mags.push(term.coeff.abs().ln() + lb.get(term.k));
641 signs.push(term.coeff.signum());
642 ratios.push(kernel_ratio_jet(lb, term.k, term.m, 4));
643 }
644
645 let (log_s, sign_s) = signed_log_sum_exp(&log_mags, &signs);
647
648 if !log_s.is_finite() || sign_s <= 0.0 {
649 return Ok(Self::non_positive(overall_mode));
651 }
652
653 let mut wr1 = 0.0;
656 let mut wr2 = 0.0;
657 let mut wr3 = 0.0;
658 let mut wr4 = 0.0;
659 for i in 0..terms.len() {
660 let w = sign_s * signs[i] * (log_mags[i] - log_s).exp();
661 wr1 += w * ratios[i][1];
662 wr2 += w * ratios[i][2];
663 wr3 += w * ratios[i][3];
664 wr4 += w * ratios[i][4];
665 }
666
667 Ok(Self {
668 value: log_s,
669 d1: wr1,
670 d2: wr2 - wr1 * wr1,
671 d3: wr3 - 3.0 * wr1 * wr2 + 2.0 * wr1 * wr1 * wr1,
672 d4: wr4 - 4.0 * wr1 * wr3 - 3.0 * wr2 * wr2 + 12.0 * wr1 * wr1 * wr2
673 - 6.0 * wr1.powi(4),
674 mode: overall_mode,
675 })
676 }
677}
678
/// How a latent-survival row was observed.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LatentSurvivalEventType {
    /// Event not observed by the exit time.
    RightCensored,
    /// Event observed at the exit time.
    ExactEvent,
    /// Event known to fall between the left and right times.
    IntervalCensored,
}

impl fmt::Display for LatentSurvivalEventType {
    /// Writes the snake_case label of the variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            LatentSurvivalEventType::RightCensored => "right_censored",
            LatentSurvivalEventType::ExactEvent => "exact_event",
            LatentSurvivalEventType::IntervalCensored => "interval_censored",
        };
        f.write_str(label)
    }
}
701
702#[derive(Clone, Copy, Debug)]
717pub struct LatentSurvivalRow {
718 pub event_type: LatentSurvivalEventType,
719 pub mass_entry: f64,
722 pub mass_exit: f64,
724 pub mass_left: f64,
726 pub mass_right: f64,
728 pub mass_unloaded_left: f64,
730 pub mass_unloaded_right: f64,
732 pub mass_unloaded_entry: f64,
734 pub mass_unloaded_exit: f64,
736 pub hazard_loaded: f64,
738 pub hazard_unloaded: f64,
740}
741
742impl LatentSurvivalRow {
743 pub fn right_censored(
749 mass_entry: f64,
750 mass_exit: f64,
751 mass_unloaded_entry: f64,
752 mass_unloaded_exit: f64,
753 ) -> Self {
754 Self {
755 event_type: LatentSurvivalEventType::RightCensored,
756 mass_entry,
757 mass_exit,
758 mass_left: 0.0,
759 mass_right: 0.0,
760 mass_unloaded_left: 0.0,
761 mass_unloaded_right: 0.0,
762 mass_unloaded_entry,
763 mass_unloaded_exit,
764 hazard_loaded: 0.0,
765 hazard_unloaded: 0.0,
766 }
767 }
768
769 pub fn exact_event(
771 mass_entry: f64,
772 mass_exit: f64,
773 mass_unloaded_entry: f64,
774 mass_unloaded_exit: f64,
775 hazard_loaded: f64,
776 hazard_unloaded: f64,
777 ) -> Self {
778 Self {
779 event_type: LatentSurvivalEventType::ExactEvent,
780 mass_entry,
781 mass_exit,
782 mass_left: 0.0,
783 mass_right: 0.0,
784 mass_unloaded_left: 0.0,
785 mass_unloaded_right: 0.0,
786 mass_unloaded_entry,
787 mass_unloaded_exit,
788 hazard_loaded,
789 hazard_unloaded,
790 }
791 }
792
793 pub fn interval_censored(
795 mass_entry: f64,
796 mass_left: f64,
797 mass_right: f64,
798 mass_unloaded_entry: f64,
799 mass_unloaded_left: f64,
800 mass_unloaded_right: f64,
801 ) -> Self {
802 Self {
803 event_type: LatentSurvivalEventType::IntervalCensored,
804 mass_entry,
805 mass_exit: 0.0,
806 mass_left,
807 mass_right,
808 mass_unloaded_left,
809 mass_unloaded_right,
810 mass_unloaded_entry,
811 mass_unloaded_exit: 0.0,
812 hazard_loaded: 0.0,
813 hazard_unloaded: 0.0,
814 }
815 }
816
817 pub fn validate(&self) -> Result<(), EstimationError> {
818 let fields = [
819 ("mass_entry", self.mass_entry),
820 ("mass_exit", self.mass_exit),
821 ("mass_left", self.mass_left),
822 ("mass_right", self.mass_right),
823 ("mass_unloaded_left", self.mass_unloaded_left),
824 ("mass_unloaded_right", self.mass_unloaded_right),
825 ("mass_unloaded_entry", self.mass_unloaded_entry),
826 ("mass_unloaded_exit", self.mass_unloaded_exit),
827 ("hazard_loaded", self.hazard_loaded),
828 ("hazard_unloaded", self.hazard_unloaded),
829 ];
830 for (name, value) in fields {
831 if !value.is_finite() || value < 0.0 {
832 crate::bail_invalid_estim!(
833 "latent survival row has invalid {name}={value}; expected a finite non-negative value"
834 );
835 }
836 }
837
838 match self.event_type {
839 LatentSurvivalEventType::RightCensored => {
840 if self.mass_exit < self.mass_entry {
841 crate::bail_invalid_estim!(
842 "latent survival right-censored row requires mass_exit >= mass_entry, got {} < {}",
843 self.mass_exit,
844 self.mass_entry
845 );
846 }
847 if self.mass_unloaded_exit < self.mass_unloaded_entry {
848 crate::bail_invalid_estim!(
849 "latent survival right-censored row requires unloaded exit mass >= unloaded entry mass, got {} < {}",
850 self.mass_unloaded_exit,
851 self.mass_unloaded_entry
852 );
853 }
854 if self.mass_left > 0.0
855 || self.mass_right > 0.0
856 || self.mass_unloaded_left > 0.0
857 || self.mass_unloaded_right > 0.0
858 || self.hazard_loaded > 0.0
859 || self.hazard_unloaded > 0.0
860 {
861 crate::bail_invalid_estim!("latent survival right-censored row cannot carry interval masses or event hazards"
862 .to_string(),);
863 }
864 }
865 LatentSurvivalEventType::ExactEvent => {
866 if self.mass_exit < self.mass_entry {
867 crate::bail_invalid_estim!(
868 "latent survival exact-event row requires mass_exit >= mass_entry, got {} < {}",
869 self.mass_exit,
870 self.mass_entry
871 );
872 }
873 if self.mass_unloaded_exit < self.mass_unloaded_entry {
874 crate::bail_invalid_estim!(
875 "latent survival exact-event row requires unloaded exit mass >= unloaded entry mass, got {} < {}",
876 self.mass_unloaded_exit,
877 self.mass_unloaded_entry
878 );
879 }
880 if self.mass_left > 0.0
881 || self.mass_right > 0.0
882 || self.mass_unloaded_left > 0.0
883 || self.mass_unloaded_right > 0.0
884 {
885 crate::bail_invalid_estim!(
886 "latent survival exact-event row cannot carry interval masses"
887 );
888 }
889 if self.hazard_loaded == 0.0 && self.hazard_unloaded == 0.0 {
890 crate::bail_invalid_estim!("latent survival exact-event row requires a positive loaded or unloaded hazard"
891 .to_string(),);
892 }
893 }
894 LatentSurvivalEventType::IntervalCensored => {
895 if self.mass_left < self.mass_entry || self.mass_right < self.mass_left {
896 crate::bail_invalid_estim!(
897 "latent survival interval row requires mass_entry <= mass_left <= mass_right, got entry={}, left={}, right={}",
898 self.mass_entry,
899 self.mass_left,
900 self.mass_right
901 );
902 }
903 if self.mass_unloaded_left < self.mass_unloaded_entry
904 || self.mass_unloaded_right < self.mass_unloaded_left
905 {
906 crate::bail_invalid_estim!(
907 "latent survival interval row requires unloaded_entry <= unloaded_left <= unloaded_right, got entry={}, left={}, right={}",
908 self.mass_unloaded_entry,
909 self.mass_unloaded_left,
910 self.mass_unloaded_right
911 );
912 }
913 if self.mass_exit > 0.0
914 || self.mass_unloaded_exit > 0.0
915 || self.hazard_loaded > 0.0
916 || self.hazard_unloaded > 0.0
917 {
918 crate::bail_invalid_estim!(
919 "latent survival interval row cannot carry exit masses or event hazards"
920 .to_string(),
921 );
922 }
923 }
924 }
925
926 Ok(())
927 }
928}
929
930fn exact_event_kernel_jet(
931 quadctx: &QuadratureContext,
932 row: &LatentSurvivalRow,
933 mu: f64,
934 sigma: f64,
935) -> Result<LogKernelSumJet, EstimationError> {
936 if row.hazard_loaded < 0.0 || row.hazard_unloaded < 0.0 {
937 crate::bail_invalid_estim!(
938 "latent survival exact-event hazards must be non-negative, got loaded={} unloaded={}",
939 row.hazard_loaded,
940 row.hazard_unloaded
941 );
942 }
943 match (row.hazard_unloaded > 0.0, row.hazard_loaded > 0.0) {
944 (true, true) => {
945 let terms = [
946 KernelSumTerm {
947 coeff: row.hazard_unloaded,
948 k: 0,
949 m: row.mass_exit,
950 },
951 KernelSumTerm {
952 coeff: row.hazard_loaded,
953 k: 1,
954 m: row.mass_exit,
955 },
956 ];
957 LogKernelSumJet::evaluate(quadctx, &terms, mu, sigma)
958 }
959 (true, false) => {
960 let jet = LogKernelSumJet::single_term(quadctx, 0, row.mass_exit, mu, sigma)?;
961 Ok(LogKernelSumJet {
962 value: row.hazard_unloaded.ln() + jet.value,
963 d1: jet.d1,
964 d2: jet.d2,
965 d3: jet.d3,
966 d4: jet.d4,
967 mode: jet.mode,
968 })
969 }
970 (false, true) => {
971 let jet = LogKernelSumJet::single_term(quadctx, 1, row.mass_exit, mu, sigma)?;
972 Ok(LogKernelSumJet {
973 value: row.hazard_loaded.ln() + jet.value,
974 d1: jet.d1,
975 d2: jet.d2,
976 d3: jet.d3,
977 d4: jet.d4,
978 mode: jet.mode,
979 })
980 }
981 (false, false) => Err(EstimationError::InvalidInput(
982 "latent survival exact-event row requires a positive loaded or unloaded hazard"
983 .to_string(),
984 )),
985 }
986}
987
/// Per-row log-likelihood and its derivatives for the latent survival model.
#[derive(Clone, Copy, Debug)]
pub struct LatentSurvivalRowJet {
    /// Row log-likelihood.
    pub log_lik: f64,
    /// `d log_lik / d mu`.
    pub score: f64,
    /// `-d^2 log_lik / d mu^2`.
    pub neg_hessian: f64,
    /// `d^3 log_lik / d mu^3` (not negated).
    pub d3: f64,
    /// `d log_lik / d log(sigma)`.
    pub score_log_sigma: f64,
    /// `-d^2 log_lik / d log(sigma)^2`.
    pub neg_hessian_log_sigma: f64,
}
1003
1004#[inline]
1005fn log_sigma_score_from_log_sum(jet: &LogKernelSumJet, sigma: f64) -> f64 {
1006 let sigma2 = sigma * sigma;
1007 sigma2 * (jet.d2 + jet.d1 * jet.d1)
1008}
1009
1010#[inline]
1011fn log_sigma_neg_hessian_from_log_sum(jet: &LogKernelSumJet, sigma: f64) -> f64 {
1012 let sigma2 = sigma * sigma;
1013 let sigma4 = sigma2 * sigma2;
1014 let d1 = jet.d1;
1015 let d2 = jet.d2;
1016 let d3 = jet.d3;
1017 let d4 = jet.d4;
1018 let s2_over_s = d2 + d1 * d1;
1019 let s4_over_s_minus_s2_sq = d4 + 4.0 * d1 * d3 + 2.0 * d2 * d2 + 4.0 * d1 * d1 * d2;
1024 -(2.0 * sigma2 * s2_over_s + sigma4 * s4_over_s_minus_s2_sq)
1025}
1026
1027impl LatentSurvivalRowJet {
1028 pub fn evaluate(
1029 quadctx: &QuadratureContext,
1030 row: &LatentSurvivalRow,
1031 mu: f64,
1032 sigma: f64,
1033 ) -> Result<Self, EstimationError> {
1034 row.validate()?;
1035 match row.event_type {
1036 LatentSurvivalEventType::RightCensored => Self::right_censored(quadctx, mu, sigma, row),
1037 LatentSurvivalEventType::ExactEvent => Self::exact_event(quadctx, mu, sigma, row),
1038 LatentSurvivalEventType::IntervalCensored => {
1039 Self::interval_censored(quadctx, mu, sigma, row)
1040 }
1041 }
1042 }
1043
1044 fn right_censored(
1052 quadctx: &QuadratureContext,
1053 mu: f64,
1054 sigma: f64,
1055 row: &LatentSurvivalRow,
1056 ) -> Result<Self, EstimationError> {
1057 let has_unloaded =
1058 row.mass_unloaded_exit.abs() > 1e-300 || row.mass_unloaded_entry.abs() > 1e-300;
1059
1060 let mass_exit_loaded = row.mass_exit;
1064 let mass_entry_loaded = row.mass_entry;
1065
1066 let unloaded_offset = if has_unloaded {
1068 -row.mass_unloaded_exit + row.mass_unloaded_entry
1069 } else {
1070 0.0
1071 };
1072
1073 let num = LogKernelSumJet::single_term(quadctx, 0, mass_exit_loaded, mu, sigma)?;
1074 if mass_entry_loaded > 1e-300 {
1075 let den = LogKernelSumJet::single_term(quadctx, 0, mass_entry_loaded, mu, sigma)?;
1076 Ok(Self {
1077 log_lik: unloaded_offset + num.value - den.value,
1078 score: num.d1 - den.d1,
1079 neg_hessian: -(num.d2 - den.d2),
1080 d3: num.d3 - den.d3,
1081 score_log_sigma: log_sigma_score_from_log_sum(&num, sigma)
1082 - log_sigma_score_from_log_sum(&den, sigma),
1083 neg_hessian_log_sigma: log_sigma_neg_hessian_from_log_sum(&num, sigma)
1084 - log_sigma_neg_hessian_from_log_sum(&den, sigma),
1085 })
1086 } else {
1087 Ok(Self {
1088 log_lik: unloaded_offset + num.value,
1089 score: num.d1,
1090 neg_hessian: -num.d2,
1091 d3: num.d3,
1092 score_log_sigma: log_sigma_score_from_log_sum(&num, sigma),
1093 neg_hessian_log_sigma: log_sigma_neg_hessian_from_log_sum(&num, sigma),
1094 })
1095 }
1096 }
1097
1098 fn exact_event(
1102 quadctx: &QuadratureContext,
1103 mu: f64,
1104 sigma: f64,
1105 row: &LatentSurvivalRow,
1106 ) -> Result<Self, EstimationError> {
1107 let unloaded_offset =
1108 if row.mass_unloaded_exit.abs() > 1e-300 || row.mass_unloaded_entry.abs() > 1e-300 {
1109 -row.mass_unloaded_exit + row.mass_unloaded_entry
1110 } else {
1111 0.0
1112 };
1113 let num = exact_event_kernel_jet(quadctx, row, mu, sigma)?;
1114
1115 if row.mass_entry > 1e-300 {
1116 let den = LogKernelSumJet::single_term(quadctx, 0, row.mass_entry, mu, sigma)?;
1117 Ok(Self {
1118 log_lik: unloaded_offset + num.value - den.value,
1119 score: num.d1 - den.d1,
1120 neg_hessian: -(num.d2 - den.d2),
1121 d3: num.d3 - den.d3,
1122 score_log_sigma: log_sigma_score_from_log_sum(&num, sigma)
1123 - log_sigma_score_from_log_sum(&den, sigma),
1124 neg_hessian_log_sigma: log_sigma_neg_hessian_from_log_sum(&num, sigma)
1125 - log_sigma_neg_hessian_from_log_sum(&den, sigma),
1126 })
1127 } else {
1128 Ok(Self {
1129 log_lik: unloaded_offset + num.value,
1130 score: num.d1,
1131 neg_hessian: -num.d2,
1132 d3: num.d3,
1133 score_log_sigma: log_sigma_score_from_log_sum(&num, sigma),
1134 neg_hessian_log_sigma: log_sigma_neg_hessian_from_log_sum(&num, sigma),
1135 })
1136 }
1137 }
1138
1139 fn interval_censored(
1141 quadctx: &QuadratureContext,
1142 mu: f64,
1143 sigma: f64,
1144 row: &LatentSurvivalRow,
1145 ) -> Result<Self, EstimationError> {
1146 let num_terms = [
1147 KernelSumTerm {
1148 coeff: (-row.mass_unloaded_left).exp(),
1149 k: 0,
1150 m: row.mass_left,
1151 },
1152 KernelSumTerm {
1153 coeff: -(-row.mass_unloaded_right).exp(),
1154 k: 0,
1155 m: row.mass_right,
1156 },
1157 ];
1158 let num = LogKernelSumJet::evaluate(quadctx, &num_terms, mu, sigma)?;
1159
1160 if row.mass_entry > 1e-300 {
1161 let den = LogKernelSumJet::single_term(quadctx, 0, row.mass_entry, mu, sigma)?;
1162 Ok(Self {
1163 log_lik: num.value + row.mass_unloaded_entry - den.value,
1164 score: num.d1 - den.d1,
1165 neg_hessian: -(num.d2 - den.d2),
1166 d3: num.d3 - den.d3,
1167 score_log_sigma: log_sigma_score_from_log_sum(&num, sigma)
1168 - log_sigma_score_from_log_sum(&den, sigma),
1169 neg_hessian_log_sigma: log_sigma_neg_hessian_from_log_sum(&num, sigma)
1170 - log_sigma_neg_hessian_from_log_sum(&den, sigma),
1171 })
1172 } else {
1173 Ok(Self {
1174 log_lik: num.value + row.mass_unloaded_entry,
1175 score: num.d1,
1176 neg_hessian: -num.d2,
1177 d3: num.d3,
1178 score_log_sigma: log_sigma_score_from_log_sum(&num, sigma),
1179 neg_hessian_log_sigma: log_sigma_neg_hessian_from_log_sum(&num, sigma),
1180 })
1181 }
1182 }
1183}
1184
1185#[cfg(test)]
1186mod tests {
1187 use super::*;
1188
1189 fn latent_binomial_row_log_lik(
1190 ctx: &QuadratureContext,
1191 eta: f64,
1192 sigma: f64,
1193 y: f64,
1194 weight: f64,
1195 ) -> f64 {
1196 let mu = latent_cloglog_jet5(ctx, eta, sigma)
1197 .expect("latent jet")
1198 .mean;
1199 let mu = mu.clamp(1e-12, 1.0 - 1e-12);
1200 weight * (y * mu.ln() + (1.0 - y) * (1.0 - mu).ln())
1201 }
1202
1203 #[test]
1204 fn kernel_ratio_jet_d1_fd_check() {
1205 let ctx = QuadratureContext::new();
1206 let mu = 0.3;
1207 let sigma = 0.5;
1208 let m = 1.0;
1209 let k = 0usize;
1210 let h = 1e-5;
1211
1212 let bundle = log_kernel_bundle(&ctx, m, mu, sigma, k + 4).unwrap();
1213 let log_k = bundle.get(k);
1214 let ratios = kernel_ratio_jet(&bundle, k, m, 2);
1215 let kc = log_k.exp();
1216 let d1 = kc * ratios[1];
1217 let d2 = kc * ratios[2];
1218
1219 let kp = log_kernel_term(&ctx, k, m, mu + h, sigma).unwrap().0.exp();
1220 let km = log_kernel_term(&ctx, k, m, mu - h, sigma).unwrap().0.exp();
1221 let fd_d1 = (kp - km) / (2.0 * h);
1222 assert!(
1223 (d1 - fd_d1).abs() / fd_d1.abs().max(1e-15) < 1e-4,
1224 "d1: jet={d1}, fd={fd_d1}",
1225 );
1226
1227 let fd_d2 = (kp - 2.0 * kc + km) / (h * h);
1228 assert!(
1229 (d2 - fd_d2).abs() / fd_d2.abs().max(1e-15) < 1e-3,
1230 "d2: jet={d2}, fd={fd_d2}",
1231 );
1232 }
1233
1234 #[test]
1235 fn survival_right_censored_score_fd() {
1236 let ctx = QuadratureContext::new();
1237 let mu = -0.5;
1238 let sigma = 0.3;
1239 let h = 1e-6;
1240 let row = LatentSurvivalRow::right_censored(0.0, 2.0, 0.0, 0.0);
1241 let ll_p = LatentSurvivalRowJet::evaluate(&ctx, &row, mu + h, sigma)
1242 .unwrap()
1243 .log_lik;
1244 let ll_m = LatentSurvivalRowJet::evaluate(&ctx, &row, mu - h, sigma)
1245 .unwrap()
1246 .log_lik;
1247 let fd_score = (ll_p - ll_m) / (2.0 * h);
1248 let jet = LatentSurvivalRowJet::evaluate(&ctx, &row, mu, sigma).unwrap();
1249 assert!(
1250 (jet.score - fd_score).abs() / fd_score.abs().max(1e-15) < 1e-3,
1251 "score={}, fd={fd_score}",
1252 jet.score
1253 );
1254 }
1255
1256 #[test]
1257 fn survival_exact_event_score_fd() {
1258 let ctx = QuadratureContext::new();
1259 let mu = 0.2;
1260 let sigma = 0.5;
1261 let h = 1e-6;
1262 let row = LatentSurvivalRow::exact_event(0.0, 1.5, 0.0, 0.0, (-0.3f64).exp(), 0.0);
1263 let ll_p = LatentSurvivalRowJet::evaluate(&ctx, &row, mu + h, sigma)
1264 .unwrap()
1265 .log_lik;
1266 let ll_m = LatentSurvivalRowJet::evaluate(&ctx, &row, mu - h, sigma)
1267 .unwrap()
1268 .log_lik;
1269 let fd_score = (ll_p - ll_m) / (2.0 * h);
1270 let jet = LatentSurvivalRowJet::evaluate(&ctx, &row, mu, sigma).unwrap();
1271 assert!(
1272 (jet.score - fd_score).abs() / fd_score.abs().max(1e-15) < 1e-3,
1273 "score={}, fd={fd_score}",
1274 jet.score
1275 );
1276 }
1277
1278 #[test]
1279 fn survival_exact_event_loaded_vs_unloaded_score_fd() {
1280 let ctx = QuadratureContext::new();
1281 let mu = -0.1;
1282 let sigma = 0.4;
1283 let h = 1e-6;
1284 let row = LatentSurvivalRow::exact_event(0.3, 1.2, 0.2, 0.6, 0.9, 0.15);
1285 let ll_p = LatentSurvivalRowJet::evaluate(&ctx, &row, mu + h, sigma)
1286 .unwrap()
1287 .log_lik;
1288 let ll_m = LatentSurvivalRowJet::evaluate(&ctx, &row, mu - h, sigma)
1289 .unwrap()
1290 .log_lik;
1291 let fd_score = (ll_p - ll_m) / (2.0 * h);
1292 let jet = LatentSurvivalRowJet::evaluate(&ctx, &row, mu, sigma).unwrap();
1293 assert!(
1294 (jet.score - fd_score).abs() / fd_score.abs().max(1e-15) < 1e-3,
1295 "score={}, fd={fd_score}",
1296 jet.score
1297 );
1298 }
1299
1300 #[test]
1301 fn survival_right_censored_loaded_vs_unloaded_score_fd() {
1302 let ctx = QuadratureContext::new();
1303 let mu = 0.15;
1304 let sigma: f64 = 0.35;
1305 let h = 1e-6;
1306 let row = LatentSurvivalRow::right_censored(0.4, 1.7, 0.1, 0.5);
1307 let ll_p = LatentSurvivalRowJet::evaluate(&ctx, &row, mu + h, sigma)
1308 .unwrap()
1309 .log_lik;
1310 let ll_m = LatentSurvivalRowJet::evaluate(&ctx, &row, mu - h, sigma)
1311 .unwrap()
1312 .log_lik;
1313 let fd_score = (ll_p - ll_m) / (2.0 * h);
1314 let jet = LatentSurvivalRowJet::evaluate(&ctx, &row, mu, sigma).unwrap();
1315 assert!(
1316 (jet.score - fd_score).abs() / fd_score.abs().max(1e-15) < 1e-3,
1317 "score={}, fd={fd_score}",
1318 jet.score
1319 );
1320 }
1321
1322 #[test]
1323 fn survival_interval_censored_score_fd() {
1324 let ctx = QuadratureContext::new();
1325 let mu = 0.0;
1326 let sigma = 0.6;
1327 let h = 1e-6;
1328 let row = LatentSurvivalRow::interval_censored(0.0, 1.0, 2.0, 0.0, 0.0, 0.0);
1329 let ll_p = LatentSurvivalRowJet::evaluate(&ctx, &row, mu + h, sigma)
1330 .unwrap()
1331 .log_lik;
1332 let ll_m = LatentSurvivalRowJet::evaluate(&ctx, &row, mu - h, sigma)
1333 .unwrap()
1334 .log_lik;
1335 let fd_score = (ll_p - ll_m) / (2.0 * h);
1336 let jet = LatentSurvivalRowJet::evaluate(&ctx, &row, mu, sigma).unwrap();
1337 assert!(
1338 (jet.score - fd_score).abs() / fd_score.abs().max(1e-15) < 1e-3,
1339 "score={}, fd={fd_score}",
1340 jet.score
1341 );
1342 }
1343
/// Checks `neg_hessian` (minus the second mu-derivative of `log_lik`) for an
/// interval-censored row against a three-point central second difference.
///
/// The step `h = 2e-4` is larger than in the score checks: the second difference
/// divides by `h^2`, which amplifies rounding noise in `log_lik`.
#[test]
fn survival_interval_censored_neg_hessian_fd() {
    let ctx = QuadratureContext::new();
    let row = LatentSurvivalRow::interval_censored(0.0, 0.7, 1.9, 0.0, 0.0, 0.0);
    let sigma = 0.55;
    let mu = -0.2;
    let h = 2e-4;

    let log_lik_at = |m: f64| {
        let row_jet = LatentSurvivalRowJet::evaluate(&ctx, &row, m, sigma).unwrap();
        row_jet.log_lik
    };
    let second_difference = log_lik_at(mu + h) - 2.0 * log_lik_at(mu) + log_lik_at(mu - h);
    // Negate the FD second derivative so it is directly comparable to `neg_hessian`.
    let fd_neg_hessian = -(second_difference / (h * h));

    let jet = LatentSurvivalRowJet::evaluate(&ctx, &row, mu, sigma).unwrap();
    let rel_err = (jet.neg_hessian - fd_neg_hessian).abs() / fd_neg_hessian.abs().max(1e-12);
    assert!(
        rel_err < 1e-2,
        "interval neg_hessian={}, fd(-d2)={}",
        jet.neg_hessian,
        fd_neg_hessian
    );
}
1368
/// Checks `score_log_sigma` (derivative of `log_lik` with respect to `ln(sigma)`)
/// for an interval-censored row against a central difference taken on the
/// log-sigma scale.
#[test]
fn survival_interval_censored_log_sigma_score_fd() {
    let ctx = QuadratureContext::new();
    let row = LatentSurvivalRow::interval_censored(0.0, 0.8, 2.1, 0.0, 0.0, 0.0);
    let mu = 0.1;
    let sigma: f64 = 0.6;
    let h = 1e-5;

    // Perturb ln(sigma) rather than sigma so the quotient approximates d/d(ln sigma).
    let log_sigma = sigma.ln();
    let log_lik_at_log_sigma = |ls: f64| {
        LatentSurvivalRowJet::evaluate(&ctx, &row, mu, ls.exp())
            .unwrap()
            .log_lik
    };
    let upper = log_lik_at_log_sigma(log_sigma + h);
    let lower = log_lik_at_log_sigma(log_sigma - h);
    let fd_dlogsigma = (upper - lower) / (2.0 * h);

    let jet = LatentSurvivalRowJet::evaluate(&ctx, &row, mu, sigma).unwrap();
    let rel_err = (jet.score_log_sigma - fd_dlogsigma).abs() / fd_dlogsigma.abs().max(1e-12);
    assert!(
        rel_err < 1e-3,
        "interval score_log_sigma={}, fd={fd_dlogsigma}",
        jet.score_log_sigma
    );
}
1395
/// Cross-checks the log-sigma score and negative Hessian derived from a
/// single-term `LogKernelSumJet` against a reference built from
/// `cloglog_ghq_derivatives_adaptive`.
///
/// The reference sets `S = 1 - ghq.l` (floored at 1e-300) and treats
/// `-ghq.l_sigma` / `-ghq.l_sigmasigma` as `S_sigma` / `S_sigmasigma`:
/// - score = `sigma * S_sigma / S`
/// - d^2/d(ln sigma)^2 = `score + sigma^2 * (S_sigmasigma / S - (S_sigma / S)^2)`
///
/// The negated second derivative is compared with the negative Hessian helper.
#[test]
fn log_kernel_single_term_log_sigma_derivatives_match_ghq_reference() {
    let ctx = QuadratureContext::new();
    let mu = 0.2;
    let sigma = 1.0;

    let jet = LogKernelSumJet::single_term(&ctx, 0, 1.0, mu, sigma).unwrap();
    let ghq = crate::inference::quadrature::cloglog_ghq_derivatives_adaptive(&ctx, mu, sigma);

    // Floor keeps the divisions finite if the GHQ mean saturates at 1.
    let survival = (1.0 - ghq.l).max(1e-300);
    let ratio = -ghq.l_sigma / survival; // S_sigma / S
    let ratio_slope = -ghq.l_sigmasigma / survival - ratio.powi(2); // d(S_sigma / S)/d sigma
    let ref_score = sigma * ratio;
    let ref_neg_hessian = -(ref_score + sigma * sigma * ratio_slope);

    let score = log_sigma_score_from_log_sum(&jet, sigma);
    assert!(
        (score - ref_score).abs() / ref_score.abs().max(1e-12) < 1e-4,
        "log-sigma score={}, ref={ref_score}",
        score
    );

    let neg_hessian = log_sigma_neg_hessian_from_log_sum(&jet, sigma);
    assert!(
        (neg_hessian - ref_neg_hessian).abs() / ref_neg_hessian.abs().max(1e-12) < 1e-3,
        "log-sigma neg_hessian={}, ref={ref_neg_hessian}",
        neg_hessian
    );
}
1426
/// Checks `LogKernelSumJet::single_term(..).d1` against a central difference of
/// the scalar `log_kernel_term` value (`.0`) in its mu argument (term index 0).
#[test]
fn log_kernel_sum_jet_single_term_d1_fd() {
    let ctx = QuadratureContext::new();
    let (k, m) = (0usize, 1.0);
    let (mu, sigma) = (0.5, 0.4);
    let h = 1e-6;

    let jet = LogKernelSumJet::single_term(&ctx, k, m, mu, sigma).unwrap();
    let term_value = |at: f64| log_kernel_term(&ctx, k, m, at, sigma).unwrap().0;
    let fd_d1 = (term_value(mu + h) - term_value(mu - h)) / (2.0 * h);

    let rel_err = (jet.d1 - fd_d1).abs() / fd_d1.abs().max(1e-15);
    assert!(rel_err < 1e-3, "d1={}, fd={fd_d1}", jet.d1);
}
1446
/// Checks the fourth mu-derivative `d4` of a single-term `LogKernelSumJet`
/// against the five-point central stencil `(1, -4, 6, -4, 1) / h^4` applied to
/// `log_kernel_term`.
///
/// The tolerance is relative to the larger of the two magnitudes (floored at 1e-8),
/// and is loose (2%) because the stencil divides by `h^4`.
#[test]
fn log_kernel_sum_jet_single_term_d4_fd() {
    let ctx = QuadratureContext::new();
    let (mu, sigma) = (0.35, 0.45);
    let (m, k) = (1.2, 1usize);
    let h: f64 = 2e-3;

    let jet = LogKernelSumJet::single_term(&ctx, k, m, mu, sigma).unwrap();

    // Samples at mu + 2h, mu + h, mu, mu - h, mu - 2h.
    let term_value = |offset: f64| log_kernel_term(&ctx, k, m, mu + offset, sigma).unwrap().0;
    let [v_pp, v_p, v_0, v_m, v_mm] = [2.0 * h, h, 0.0, -h, -2.0 * h].map(term_value);
    let fd_d4 = (v_mm - 4.0 * v_m + 6.0 * v_0 - 4.0 * v_p + v_pp) / h.powi(4);

    let scale = jet.d4.abs().max(fd_d4.abs()).max(1e-8);
    assert!((jet.d4 - fd_d4).abs() / scale < 2e-2, "d4={}, fd={fd_d4}", jet.d4);
}
1469
/// Checks that `latent_cloglog_jet5` at `sigma = 0` reduces to the closed-form
/// plain cloglog inverse link `F(eta) = 1 - exp(-exp(eta))` and its first five
/// derivatives.
///
/// With `t = exp(eta)`, `F' = exp(eta - t)`, and every higher derivative is `F'`
/// times a polynomial `p(t)`. Each step uses
/// `d/deta [p(t) F'] = (t p'(t) + (1 - t) p(t)) F'`, because `dt/deta = t` and
/// `dF'/deta = (1 - t) F'`.
///
/// Each comparison is labelled, so a failure reports which quantity disagreed and
/// both values.
#[test]
fn latent_cloglog_jet_matches_point_limit_at_zero_sigma() {
    let ctx = QuadratureContext::new();
    let eta: f64 = -0.4;
    let jet = latent_cloglog_jet5(&ctx, eta, 0.0).expect("latent jet");

    let t = eta.exp();
    let d1 = (eta - t).exp();
    // (label, jet value, closed-form value)
    let checks = [
        // 1 - exp(-t), written via exp_m1 to avoid cancellation for small t.
        ("mean", jet.mean, -(-t).exp_m1()),
        ("d1", jet.d1, d1),
        ("d2", jet.d2, (1.0 - t) * d1),
        ("d3", jet.d3, (t * t - 3.0 * t + 1.0) * d1),
        ("d4", jet.d4, (-t * t * t + 6.0 * t * t - 7.0 * t + 1.0) * d1),
        (
            "d5",
            jet.d5,
            (t.powi(4) - 10.0 * t.powi(3) + 25.0 * t * t - 15.0 * t + 1.0) * d1,
        ),
    ];
    for (name, got, want) in checks {
        assert!(
            (got - want).abs() < 1e-12,
            "zero-sigma {name}: jet={got}, closed form={want}"
        );
    }
}
1488
/// Checks that the `latent_cloglog_jet5` mean and derivatives agree with the
/// exact recurrence built from `log_kernel_bundle(&ctx, 1.0, eta, sigma, 5)`.
///
/// Setup:
/// - `bundle.get(0)` is a log value, and the reference mean is `1 - exp(bundle.get(0))`.
/// - For `j >= 1`, `K_j = exp(bundle.get(j))`.
/// - `d_n` is `sum_j (-1)^(j-1) S(n, j) K_j`, where `S(n, j)` are Stirling numbers
///   of the second kind: 1; 1,1; 1,3,1; 1,7,6,1; 1,15,25,10,1.
///
/// A NaN in `bundle.get(0)` fails the test rather than being silently mapped to
/// `mean = 1`. Every assertion names the `(eta, sigma)` case and the quantity.
#[test]
fn latent_cloglog_jet_matches_exact_kernel_recurrence() {
    let ctx = QuadratureContext::new();
    let cases: [(f64, f64); 4] = [(-4.0, 0.15), (-1.2, 0.35), (0.4, 0.6), (1.3, 0.9)];

    for (eta, sigma) in cases {
        let jet = latent_cloglog_jet5(&ctx, eta, sigma).expect("latent jet");
        let bundle = log_kernel_bundle(&ctx, 1.0, eta, sigma, 5).expect("kernel bundle");

        let k0 = bundle.get(0);
        // A NaN kernel value is a bug; do not let it masquerade as a saturated mean.
        assert!(
            !k0.is_nan(),
            "eta={eta}, sigma={sigma}: log kernel term 0 is NaN"
        );
        let k1 = bundle.get(1).exp();
        let k2 = bundle.get(2).exp();
        let k3 = bundle.get(3).exp();
        let k4 = bundle.get(4).exp();
        let k5 = bundle.get(5).exp();

        // exp_m1(-inf) == -1, so k0 == -inf already yields mean == 1 with no special case.
        let mean = -k0.exp_m1();
        let d1 = k1;
        let d2 = k1 - k2;
        let d3 = k1 - 3.0 * k2 + k3;
        let d4 = k1 - 7.0 * k2 + 6.0 * k3 - k4;
        let d5 = k1 - 15.0 * k2 + 25.0 * k3 - 10.0 * k4 + k5;

        // (label, jet value, recurrence value)
        let checks = [
            ("mean", jet.mean, mean),
            ("d1", jet.d1, d1),
            ("d2", jet.d2, d2),
            ("d3", jet.d3, d3),
            ("d4", jet.d4, d4),
            ("d5", jet.d5, d5),
        ];
        for (name, got, want) in checks {
            assert!(
                (got - want).abs() < 1e-12,
                "eta={eta}, sigma={sigma}: {name} jet={got}, recurrence={want}"
            );
        }
    }
}
1519
/// Checks the analytic curvature (negative second eta-derivative) of a weighted
/// binomial row under the latent cloglog mean against a three-point central
/// second difference of `latent_binomial_row_log_lik`.
///
/// The analytic side differentiates `weight * (y ln p + (1 - y) ln(1 - p))`.
/// Here `p` is the latent mean clamped to `[1e-12, 1 - 1e-12]`. The chain rule
/// gives `-weight * (l''(p) * p'^2 + l'(p) * p'')`, with `p' = jet.d1` and
/// `p'' = jet.d2`.
///
/// NOTE(review): this assumes `latent_binomial_row_log_lik` evaluates that same
/// log-likelihood; confirm against its definition.
#[test]
fn latent_cloglog_binomial_row_neg_hessian_matches_fd() {
    let ctx = QuadratureContext::new();
    let (eta, sigma) = (0.4, 0.6);
    let (y, weight) = (0.35, 2.0);
    let h = 1e-4;

    // Analytic curvature via the chain rule through the latent mean.
    let jet = latent_cloglog_jet5(&ctx, eta, sigma).expect("latent jet");
    let p = jet.mean.clamp(1e-12, 1.0 - 1e-12);
    let dl_dp = y / p - (1.0 - y) / (1.0 - p);
    let d2l_dp2 = -y / (p * p) - (1.0 - y) / ((1.0 - p) * (1.0 - p));
    let analytic = -weight * (d2l_dp2 * jet.d1 * jet.d1 + dl_dp * jet.d2);

    // Numerical curvature from the row log-likelihood itself.
    let row_ll = |at: f64| latent_binomial_row_log_lik(&ctx, at, sigma, y, weight);
    let numeric = -(row_ll(eta + h) - 2.0 * row_ll(eta) + row_ll(eta - h)) / (h * h);

    // Absolute floor for tiny curvatures, otherwise 0.3% relative.
    let err = (analytic - numeric).abs();
    let tol = 2e-5_f64.max(3e-3 * numeric.abs());
    assert!(
        err <= tol,
        "latent cloglog Bernoulli row curvature mismatch: analytic={} fd={}",
        analytic,
        numeric
    );
}
1549}