1use crate::model_types::EstimationError;
15use crate::probability::signed_log_sum_exp;
16use crate::quadrature::{
17 IntegratedExpectationMode, QuadratureContext, lognormal_laplace_unit_log_term_shared,
18};
19use serde::{Deserialize, Serialize};
20use std::fmt;
21
/// Errors raised when a frailty specification is rejected for a model family.
#[derive(Debug, Clone)]
pub enum LognormalKernelError {
    /// The specification cannot be used; `reason` is a human-readable explanation.
    InvalidSpec { reason: String },
}

// NOTE(review): `impl_reason_error_boilerplate!` is defined elsewhere in the crate.
// `FrailtySpec::validate_for_marginal_slope` calls `.to_string()` on this type, so a
// `Display` impl exists — presumably generated here from `reason`; confirm against
// the macro definition.
impl_reason_error_boilerplate! {
    LognormalKernelError {
        InvalidSpec,
    }
}
41
/// How a hazard-multiplier frailty is applied to the hazard.
///
/// Serialized in kebab-case: `"full"` / `"loaded-vs-unloaded"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum HazardLoading {
    // NOTE(review): presumably the frailty multiplies the entire hazard — confirm.
    Full,
    // NOTE(review): presumably only a "loaded" hazard component is multiplied by the
    // frailty while an "unloaded" one is not (cf. the loaded/unloaded fields of
    // `LatentSurvivalRow`) — confirm.
    LoadedVsUnloaded,
}
56
/// Frailty (latent random effect) specification.
///
/// Serialized as an internally tagged object whose `"frailty_kind"` key is one of
/// `"none"`, `"gaussian-shift"` or `"hazard-multiplier"`; field names are kept as-is.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "frailty_kind", rename_all = "kebab-case")]
pub enum FrailtySpec {
    /// No frailty term.
    None,
    /// Gaussian frailty; described as "exact probit scaling" in the marginal-slope
    /// validation message (see `ProbitFrailtyScaleJet`).
    GaussianShift {
        // NOTE(review): presumably `Some(sigma)` fixes the frailty scale and `None`
        // leaves it to be estimated — confirm with callers.
        sigma_fixed: Option<f64>,
    },
    /// Multiplicative frailty on the hazard; rejected for marginal-slope families by
    /// `validate_for_marginal_slope_typed`.
    HazardMultiplier {
        // NOTE(review): same presumed meaning as `GaussianShift::sigma_fixed` — confirm.
        sigma_fixed: Option<f64>,
        /// Which hazard components the frailty applies to.
        loading: HazardLoading,
    },
}
91
92impl FrailtySpec {
93 #[inline]
103 pub fn is_active(&self) -> bool {
104 !matches!(self, Self::None)
105 }
106
107 pub fn validate_for_marginal_slope(&self) -> Result<(), String> {
119 self.validate_for_marginal_slope_typed()
120 .map_err(|e| e.to_string())
121 }
122
123 pub fn validate_for_marginal_slope_typed(&self) -> Result<(), LognormalKernelError> {
127 match self {
128 Self::None | Self::GaussianShift { .. } => Ok(()),
129 Self::HazardMultiplier { .. } => Err(LognormalKernelError::InvalidSpec {
130 reason:
131 "HazardMultiplier frailty is not finite-state exact with score_warp/linkwiggle \
132 cubic marginal-slope families. Use GaussianShift frailty (exact probit scaling) \
133 or use the standalone latent-cloglog/latent-survival families instead."
134 .to_string(),
135 }),
136 }
137 }
138}
139
/// Returns `(s, alpha)` with `s = 1 / sqrt(1 + sigma^2)` and
/// `alpha = sigma^2 / (1 + sigma^2)`.
///
/// For `|sigma| > 1` both are rewritten in terms of `1 / |sigma|`, so `sigma^2` is
/// never formed and the results stay finite (tending to `(0, 1)`) even when
/// `sigma^2` would overflow.
#[inline]
fn probit_frailty_scale_components(sigma: f64) -> (f64, f64) {
    let magnitude = sigma.abs();
    if magnitude > 1.0 {
        let recip = magnitude.recip();
        let one_plus_recip_sq = 1.0 + recip * recip;
        return (recip / one_plus_recip_sq.sqrt(), 1.0 / one_plus_recip_sq);
    }
    let sigma_sq = sigma * sigma;
    let one_plus_sigma_sq = 1.0 + sigma_sq;
    (1.0 / one_plus_sigma_sq.sqrt(), sigma_sq / one_plus_sigma_sq)
}
155
156#[derive(Clone, Copy, Debug)]
164pub struct ProbitFrailtyScaleJet {
165 pub s: f64,
167 pub alpha: f64,
169 pub ds: f64,
171 pub d2s: f64,
173}
174
175impl ProbitFrailtyScaleJet {
176 pub fn new(sigma: f64) -> Self {
181 let (s, alpha) = probit_frailty_scale_components(sigma);
182 Self {
183 s,
184 alpha,
185 ds: -alpha * s,
186 d2s: alpha * (3.0 * alpha - 2.0) * s,
187 }
188 }
189
190 pub fn from_log_sigma(log_sigma: f64) -> Self {
192 Self::new(log_sigma.exp())
193 }
194}
195
196#[inline]
197fn worst_mode(
198 a: IntegratedExpectationMode,
199 b: IntegratedExpectationMode,
200) -> IntegratedExpectationMode {
201 if a.rank() >= b.rank() { a } else { b }
202}
203
204#[inline]
215fn validate_kernel_inputs(m: f64, mu: f64, sigma: f64) -> Result<(), EstimationError> {
216 if !m.is_finite() || m < 0.0 {
217 crate::bail_invalid_estim!("lognormal kernel requires finite m >= 0, got {m}");
218 }
219 if !mu.is_finite() || !sigma.is_finite() || sigma < 0.0 {
220 crate::bail_invalid_estim!(
221 "lognormal kernel requires finite mu and sigma >= 0, got mu={mu}, sigma={sigma}"
222 );
223 }
224 Ok::<(), _>(())
225}
226
227#[inline]
228pub fn log_kernel_term(
229 quadctx: &QuadratureContext,
230 k: usize,
231 m: f64,
232 mu: f64,
233 sigma: f64,
234) -> Result<(f64, IntegratedExpectationMode), EstimationError> {
235 validate_kernel_inputs(m, mu, sigma)?;
236 let kf = k as f64;
237 let sigma2 = sigma * sigma;
238 if !sigma2.is_finite() {
239 crate::bail_invalid_estim!(
240 "lognormal kernel sigma is outside the finite exact-derivative range: sigma={sigma}"
241 );
242 }
243 let prefix_bound = kf * mu.abs() + 0.5 * kf * kf * sigma2;
244 if !prefix_bound.is_finite() {
245 crate::bail_invalid_estim!(
246 "lognormal kernel prefix is outside the finite exact-derivative range: k={k}, mu={mu}, sigma={sigma}"
247 );
248 }
249 let prefix = kf * mu + 0.5 * kf * kf * sigma2;
250 if m == 0.0 {
251 return Ok((prefix, IntegratedExpectationMode::ExactClosedForm));
252 }
253 let log_m = m.ln();
254 let shifted_bound = mu.abs() + kf * sigma2 + log_m.abs();
255 if !shifted_bound.is_finite() {
256 crate::bail_invalid_estim!(
257 "lognormal kernel shifted location is outside the finite exact-derivative range: k={k}, m={m}, mu={mu}, sigma={sigma}"
258 );
259 }
260 let shifted_mu = mu + kf * sigma2 + log_m;
261 let (log_laplace, mode) = lognormal_laplace_unit_log_term_shared(quadctx, shifted_mu, sigma);
266 Ok((prefix + log_laplace, mode))
267}
268
269#[derive(Clone, Debug)]
271pub struct LogLognormalKernelBundle {
272 pub log_values: Vec<f64>,
273 pub mode: IntegratedExpectationMode,
274}
275
276impl LogLognormalKernelBundle {
277 #[inline]
278 pub fn get(&self, k: usize) -> f64 {
279 self.log_values[k]
280 }
281
282 #[inline]
283 pub fn len(&self) -> usize {
284 self.log_values.len()
285 }
286}
287
288pub fn log_kernel_bundle(
291 quadctx: &QuadratureContext,
292 m: f64,
293 mu: f64,
294 sigma: f64,
295 max_k: usize,
296) -> Result<LogLognormalKernelBundle, EstimationError> {
297 validate_kernel_inputs(m, mu, sigma)?;
298 let mut log_values = Vec::with_capacity(max_k + 1);
299 let sigma2 = sigma * sigma;
300 if !sigma2.is_finite() {
301 crate::bail_invalid_estim!(
302 "lognormal kernel sigma is outside the finite exact-derivative range: sigma={sigma}"
303 );
304 }
305 let max_kf = max_k as f64;
306 let prefix_bound = max_kf * mu.abs() + 0.5 * max_kf * max_kf * sigma2;
307 if !prefix_bound.is_finite() {
308 crate::bail_invalid_estim!(
309 "lognormal kernel bundle prefix is outside the finite exact-derivative range: max_k={max_k}, mu={mu}, sigma={sigma}"
310 );
311 }
312 if m == 0.0 {
313 let mut prefix = 0.0;
314 for k in 0..=max_k {
315 log_values.push(prefix);
316 prefix += mu + (k as f64 + 0.5) * sigma2;
317 }
318 return Ok(LogLognormalKernelBundle {
319 log_values,
320 mode: IntegratedExpectationMode::ExactClosedForm,
321 });
322 }
323
324 let log_m = m.ln();
325 let shifted_bound = mu.abs() + max_kf * sigma2 + log_m.abs();
326 if !shifted_bound.is_finite() {
327 crate::bail_invalid_estim!(
328 "lognormal kernel bundle shifted location is outside the finite exact-derivative range: max_k={max_k}, m={m}, mu={mu}, sigma={sigma}"
329 );
330 }
331 let mut shifted_mu = mu + log_m;
332 let mut prefix = 0.0;
333 let mut mode = IntegratedExpectationMode::ExactClosedForm;
334 for k in 0..=max_k {
335 let (log_laplace, val_mode) =
336 lognormal_laplace_unit_log_term_shared(quadctx, shifted_mu, sigma);
337 log_values.push(if log_laplace.is_finite() {
338 prefix + log_laplace
339 } else {
340 f64::NEG_INFINITY
341 });
342 mode = worst_mode(mode, val_mode);
343 prefix += mu + (k as f64 + 0.5) * sigma2;
344 shifted_mu += sigma2;
345 }
346 Ok(LogLognormalKernelBundle { log_values, mode })
347}
348
/// Ratios `K_k^{(r)} / K_k` of the `mu`-derivatives of the kernel
/// `K_k(m) = E[Z^k exp(-m Z)]` to the kernel itself, for `r = 0..=order`
/// (capped at 4).
///
/// With `Z = exp(mu + sigma * eps)` we have `dZ/dmu = Z`, so
/// `d/dmu K_k = k K_k - m K_{k+1}`. Iterating that recurrence writes the r-th
/// derivative through `K_k, ..., K_{k+r}`, so only the ratios
/// `K_{k+r} / K_k = exp(log K_{k+r} - log K_k)` taken from `log_bundle` are needed.
///
/// `jet[0]` is always `1.0`; entries above `order` are left at `0.0`.
///
/// # Panics
/// Panics (through `LogLognormalKernelBundle::get`) unless `log_bundle` holds indices
/// up to `k + min(order, 4)`.
pub fn kernel_ratio_jet(
    log_bundle: &LogLognormalKernelBundle,
    k: usize,
    m: f64,
    order: usize,
) -> [f64; 5] {
    let kf = k as f64;
    let log_k0 = log_bundle.get(k);

    // rk[r] = K_{k+r} / K_k, formed in log space.
    let mut rk = [0.0f64; 5];
    for r in 1..=order.min(4) {
        let delta = log_bundle.get(k + r) - log_k0;
        // Non-finite differences (from `-inf` bundle entries) map to +inf when
        // positive and to 0.0 otherwise, including NaN from `-inf - -inf`.
        rk[r] = if delta.is_finite() {
            delta.exp()
        } else if delta > 0.0 {
            f64::INFINITY
        } else {
            0.0
        };
    }

    let mut jet = [0.0; 5];
    jet[0] = 1.0;

    // K' / K = k - m r1
    if order >= 1 {
        jet[1] = kf - m * rk[1];
    }
    // K'' / K = k^2 - (2k+1) m r1 + m^2 r2
    if order >= 2 {
        jet[2] = kf * kf - (2.0 * kf + 1.0) * m * rk[1] + m * m * rk[2];
    }
    // K''' / K = k^3 - (3k^2+3k+1) m r1 + 3(k+1) m^2 r2 - m^3 r3
    if order >= 3 {
        jet[3] = kf * kf * kf - (3.0 * kf * kf + 3.0 * kf + 1.0) * m * rk[1]
            + 3.0 * (kf + 1.0) * m * m * rk[2]
            - m * m * m * rk[3];
    }
    // K'''' / K = k^4 - (4k^3+6k^2+4k+1) m r1 + (6k^2+12k+7) m^2 r2
    //             - (4k+6) m^3 r3 + m^4 r4
    if order >= 4 {
        let k2 = kf * kf;
        let k3 = k2 * kf;
        let k4 = k3 * kf;
        let m2 = m * m;
        let m3 = m2 * m;
        let m4 = m3 * m;
        jet[4] = k4 - (4.0 * k3 + 6.0 * k2 + 4.0 * kf + 1.0) * m * rk[1]
            + (6.0 * k2 + 12.0 * kf + 7.0) * m2 * rk[2]
            - (4.0 * kf + 6.0) * m3 * rk[3]
            + m4 * rk[4];
    }

    jet
}
411
412pub use crate::quadrature::{
418 LatentCLogLogJet5, latent_cloglog_inverse_link_jet, latent_cloglog_jet5,
419};
420
/// One term `coeff * K_k(m)` of a signed sum of lognormal kernels.
#[derive(Clone, Copy, Debug)]
pub struct KernelSumTerm {
    /// Signed multiplier; negative values are allowed (the sum is combined with a
    /// signed log-sum-exp).
    pub coeff: f64,
    /// Power `k` of the latent variable in `E[Z^k exp(-m Z)]`.
    pub k: usize,
    /// Rate `m` in `E[Z^k exp(-m Z)]`.
    pub m: f64,
}
433
/// `log S` for a signed kernel sum `S(mu) = sum_i coeff_i K_{k_i}(m_i)`, together with
/// its first four derivatives with respect to `mu`.
///
/// When the sum is not strictly positive (or its log is not finite), `value` is
/// `-inf` and every derivative is `0.0`.
#[derive(Clone, Copy, Debug)]
pub struct LogKernelSumJet {
    /// `log S`.
    pub value: f64,
    /// `d log S / d mu`.
    pub d1: f64,
    /// `d^2 log S / d mu^2`.
    pub d2: f64,
    /// `d^3 log S / d mu^3`.
    pub d3: f64,
    /// `d^4 log S / d mu^4`.
    pub d4: f64,
    /// Worst integration mode among the kernel evaluations that produced this jet.
    pub mode: IntegratedExpectationMode,
}
461
462impl LogKernelSumJet {
463 #[inline]
464 fn non_positive(mode: IntegratedExpectationMode) -> Self {
465 Self {
466 value: f64::NEG_INFINITY,
467 d1: 0.0,
468 d2: 0.0,
469 d3: 0.0,
470 d4: 0.0,
471 mode,
472 }
473 }
474
475 #[inline]
476 fn from_log_value_and_ratios(
477 value: f64,
478 ratio: [f64; 5],
479 mode: IntegratedExpectationMode,
480 ) -> Self {
481 let r1 = ratio[1];
482 let r2 = ratio[2];
483 let r3 = ratio[3];
484 let r4 = ratio[4];
485 Self {
486 value,
487 d1: r1,
488 d2: r2 - r1 * r1,
489 d3: r3 - 3.0 * r1 * r2 + 2.0 * r1 * r1 * r1,
490 d4: r4 - 4.0 * r1 * r3 - 3.0 * r2 * r2 + 12.0 * r1 * r1 * r2 - 6.0 * r1.powi(4),
491 mode,
492 }
493 }
494
495 #[inline]
496 fn term_log_mag_and_ratio(
497 bundle: &LogLognormalKernelBundle,
498 term: KernelSumTerm,
499 ) -> (f64, [f64; 5]) {
500 (
501 term.coeff.abs().ln() + bundle.get(term.k),
502 kernel_ratio_jet(bundle, term.k, term.m, 4),
505 )
506 }
507
508 fn evaluate_two_terms(
509 quadctx: &QuadratureContext,
510 t0: KernelSumTerm,
511 t1: KernelSumTerm,
512 mu: f64,
513 sigma: f64,
514 ) -> Result<Self, EstimationError> {
515 let max_k_needed = t0.k.max(t1.k) + 4;
516 let bundle0 = log_kernel_bundle(quadctx, t0.m, mu, sigma, max_k_needed)?;
517 let mut overall_mode = bundle0.mode;
518 let bundle1_owned = if (t0.m - t1.m).abs() < 1e-300 {
519 None
520 } else {
521 let bundle1 = log_kernel_bundle(quadctx, t1.m, mu, sigma, max_k_needed)?;
522 overall_mode = worst_mode(overall_mode, bundle1.mode);
523 Some(bundle1)
524 };
525 let bundle1 = bundle1_owned.as_ref().unwrap_or(&bundle0);
526
527 let (log_mag0, ratio0) = Self::term_log_mag_and_ratio(&bundle0, t0);
528 let (log_mag1, ratio1) = Self::term_log_mag_and_ratio(bundle1, t1);
529 let log_mags = [log_mag0, log_mag1];
530 let signs = [t0.coeff.signum(), t1.coeff.signum()];
531 let (log_s, sign_s) = signed_log_sum_exp(&log_mags, &signs);
532 if !log_s.is_finite() || sign_s <= 0.0 {
533 return Ok(Self::non_positive(overall_mode));
534 }
535
536 let w0 = sign_s * signs[0] * (log_mag0 - log_s).exp();
537 let w1 = sign_s * signs[1] * (log_mag1 - log_s).exp();
538 let wr1 = w0 * ratio0[1] + w1 * ratio1[1];
539 let wr2 = w0 * ratio0[2] + w1 * ratio1[2];
540 let wr3 = w0 * ratio0[3] + w1 * ratio1[3];
541 let wr4 = w0 * ratio0[4] + w1 * ratio1[4];
542
543 Ok(Self {
544 value: log_s,
545 d1: wr1,
546 d2: wr2 - wr1 * wr1,
547 d3: wr3 - 3.0 * wr1 * wr2 + 2.0 * wr1 * wr1 * wr1,
548 d4: wr4 - 4.0 * wr1 * wr3 - 3.0 * wr2 * wr2 + 12.0 * wr1 * wr1 * wr2
549 - 6.0 * wr1.powi(4),
550 mode: overall_mode,
551 })
552 }
553
554 pub fn single_term(
559 quadctx: &QuadratureContext,
560 k: usize,
561 m: f64,
562 mu: f64,
563 sigma: f64,
564 ) -> Result<Self, EstimationError> {
565 let max_k_needed = k + 4;
566 let lb = log_kernel_bundle(quadctx, m, mu, sigma, max_k_needed)?;
567 Ok(Self::from_log_value_and_ratios(
568 lb.get(k),
569 kernel_ratio_jet(&lb, k, m, 4),
570 lb.mode,
571 ))
572 }
573
574 pub fn evaluate(
588 quadctx: &QuadratureContext,
589 terms: &[KernelSumTerm],
590 mu: f64,
591 sigma: f64,
592 ) -> Result<Self, EstimationError> {
593 if terms.is_empty() {
594 crate::bail_invalid_estim!("KernelSumJet requires at least one term");
597 }
598
599 if terms.len() == 1 {
601 let t = &terms[0];
602 if t.coeff <= 0.0 {
603 return Ok(Self::non_positive(
607 IntegratedExpectationMode::ExactClosedForm,
608 ));
609 }
610 let jet = Self::single_term(quadctx, t.k, t.m, mu, sigma)?;
611 return Ok(Self {
612 value: t.coeff.ln() + jet.value,
613 d1: jet.d1,
614 d2: jet.d2,
615 d3: jet.d3,
616 d4: jet.d4,
617 mode: jet.mode,
618 });
619 }
620 if terms.len() == 2 {
621 return Self::evaluate_two_terms(quadctx, terms[0], terms[1], mu, sigma);
622 }
623
624 let max_k_needed = terms.iter().map(|t| t.k).max().unwrap_or(0) + 4;
625
626 let mut log_bundles: Vec<(f64, LogLognormalKernelBundle)> = Vec::with_capacity(2);
628 let mut overall_mode = IntegratedExpectationMode::ExactClosedForm;
629 for term in terms {
630 if !log_bundles
631 .iter()
632 .any(|(m, _)| (*m - term.m).abs() < 1e-300)
633 {
634 let b = log_kernel_bundle(quadctx, term.m, mu, sigma, max_k_needed)?;
635 overall_mode = worst_mode(overall_mode, b.mode);
636 log_bundles.push((term.m, b));
637 }
638 }
639
640 let get_lb = |m: f64| -> &LogLognormalKernelBundle {
641 &log_bundles
642 .iter()
643 .find(|(bm, _)| (*bm - m).abs() < 1e-300)
644 .unwrap()
645 .1
646 };
647
648 let mut log_mags: Vec<f64> = Vec::with_capacity(terms.len());
650 let mut signs: Vec<f64> = Vec::with_capacity(terms.len());
651 let mut ratios: Vec<[f64; 5]> = Vec::with_capacity(terms.len());
652 for term in terms {
653 let lb = get_lb(term.m);
654 log_mags.push(term.coeff.abs().ln() + lb.get(term.k));
655 signs.push(term.coeff.signum());
656 ratios.push(kernel_ratio_jet(lb, term.k, term.m, 4));
657 }
658
659 let (log_s, sign_s) = signed_log_sum_exp(&log_mags, &signs);
661
662 if !log_s.is_finite() || sign_s <= 0.0 {
663 return Ok(Self::non_positive(overall_mode));
665 }
666
667 let mut wr1 = 0.0;
670 let mut wr2 = 0.0;
671 let mut wr3 = 0.0;
672 let mut wr4 = 0.0;
673 for i in 0..terms.len() {
674 let w = sign_s * signs[i] * (log_mags[i] - log_s).exp();
675 wr1 += w * ratios[i][1];
676 wr2 += w * ratios[i][2];
677 wr3 += w * ratios[i][3];
678 wr4 += w * ratios[i][4];
679 }
680
681 Ok(Self {
682 value: log_s,
683 d1: wr1,
684 d2: wr2 - wr1 * wr1,
685 d3: wr3 - 3.0 * wr1 * wr2 + 2.0 * wr1 * wr1 * wr1,
686 d4: wr4 - 4.0 * wr1 * wr3 - 3.0 * wr2 * wr2 + 12.0 * wr1 * wr1 * wr2
687 - 6.0 * wr1.powi(4),
688 mode: overall_mode,
689 })
690 }
691}
692
/// Censoring pattern of one latent-survival observation.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LatentSurvivalEventType {
    /// Survival from entry to exit with no event hazard.
    RightCensored,
    /// Event at the exit time (carries an event hazard).
    ExactEvent,
    /// Event somewhere between a left and a right bound.
    IntervalCensored,
}

impl fmt::Display for LatentSurvivalEventType {
    /// Writes the snake_case label; formatter width/fill flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::RightCensored => "right_censored",
            Self::ExactEvent => "exact_event",
            Self::IntervalCensored => "interval_censored",
        };
        f.write_str(label)
    }
}
715
/// One observation for the latent-frailty survival likelihood.
///
/// "Loaded" masses are the rates `m` of the frailty kernels `K_k(m)`.
/// "Unloaded" masses enter the likelihood as plain `exp(-mass)` factors.
/// `validate` requires the fields not used by `event_type` to be zero.
#[derive(Clone, Copy, Debug)]
pub struct LatentSurvivalRow {
    /// Censoring pattern; selects which of the fields below are used.
    pub event_type: LatentSurvivalEventType,
    /// Loaded mass at entry. Values `<= 1e-300` are treated by `LatentSurvivalRowJet`
    /// as no entry factor.
    pub mass_entry: f64,
    /// Loaded mass at exit (right-censored and exact-event rows).
    pub mass_exit: f64,
    /// Loaded mass at the interval's left bound (interval-censored rows).
    pub mass_left: f64,
    /// Loaded mass at the interval's right bound (interval-censored rows).
    pub mass_right: f64,
    /// Unloaded mass at the interval's left bound (interval-censored rows).
    pub mass_unloaded_left: f64,
    /// Unloaded mass at the interval's right bound (interval-censored rows).
    pub mass_unloaded_right: f64,
    /// Unloaded mass at entry.
    pub mass_unloaded_entry: f64,
    /// Unloaded mass at exit (right-censored and exact-event rows).
    pub mass_unloaded_exit: f64,
    /// Event hazard multiplied by the frailty; weights `K_1` in exact-event rows.
    pub hazard_loaded: f64,
    /// Event hazard not multiplied by the frailty; weights `K_0` in exact-event rows.
    pub hazard_unloaded: f64,
}
755
756impl LatentSurvivalRow {
757 pub fn right_censored(
763 mass_entry: f64,
764 mass_exit: f64,
765 mass_unloaded_entry: f64,
766 mass_unloaded_exit: f64,
767 ) -> Self {
768 Self {
769 event_type: LatentSurvivalEventType::RightCensored,
770 mass_entry,
771 mass_exit,
772 mass_left: 0.0,
773 mass_right: 0.0,
774 mass_unloaded_left: 0.0,
775 mass_unloaded_right: 0.0,
776 mass_unloaded_entry,
777 mass_unloaded_exit,
778 hazard_loaded: 0.0,
779 hazard_unloaded: 0.0,
780 }
781 }
782
783 pub fn exact_event(
785 mass_entry: f64,
786 mass_exit: f64,
787 mass_unloaded_entry: f64,
788 mass_unloaded_exit: f64,
789 hazard_loaded: f64,
790 hazard_unloaded: f64,
791 ) -> Self {
792 Self {
793 event_type: LatentSurvivalEventType::ExactEvent,
794 mass_entry,
795 mass_exit,
796 mass_left: 0.0,
797 mass_right: 0.0,
798 mass_unloaded_left: 0.0,
799 mass_unloaded_right: 0.0,
800 mass_unloaded_entry,
801 mass_unloaded_exit,
802 hazard_loaded,
803 hazard_unloaded,
804 }
805 }
806
807 pub fn interval_censored(
809 mass_entry: f64,
810 mass_left: f64,
811 mass_right: f64,
812 mass_unloaded_entry: f64,
813 mass_unloaded_left: f64,
814 mass_unloaded_right: f64,
815 ) -> Self {
816 Self {
817 event_type: LatentSurvivalEventType::IntervalCensored,
818 mass_entry,
819 mass_exit: 0.0,
820 mass_left,
821 mass_right,
822 mass_unloaded_left,
823 mass_unloaded_right,
824 mass_unloaded_entry,
825 mass_unloaded_exit: 0.0,
826 hazard_loaded: 0.0,
827 hazard_unloaded: 0.0,
828 }
829 }
830
831 pub fn validate(&self) -> Result<(), EstimationError> {
832 let fields = [
833 ("mass_entry", self.mass_entry),
834 ("mass_exit", self.mass_exit),
835 ("mass_left", self.mass_left),
836 ("mass_right", self.mass_right),
837 ("mass_unloaded_left", self.mass_unloaded_left),
838 ("mass_unloaded_right", self.mass_unloaded_right),
839 ("mass_unloaded_entry", self.mass_unloaded_entry),
840 ("mass_unloaded_exit", self.mass_unloaded_exit),
841 ("hazard_loaded", self.hazard_loaded),
842 ("hazard_unloaded", self.hazard_unloaded),
843 ];
844 for (name, value) in fields {
845 if !value.is_finite() || value < 0.0 {
846 crate::bail_invalid_estim!(
847 "latent survival row has invalid {name}={value}; expected a finite non-negative value"
848 );
849 }
850 }
851
852 match self.event_type {
853 LatentSurvivalEventType::RightCensored => {
854 if self.mass_exit < self.mass_entry {
855 crate::bail_invalid_estim!(
856 "latent survival right-censored row requires mass_exit >= mass_entry, got {} < {}",
857 self.mass_exit,
858 self.mass_entry
859 );
860 }
861 if self.mass_unloaded_exit < self.mass_unloaded_entry {
862 crate::bail_invalid_estim!(
863 "latent survival right-censored row requires unloaded exit mass >= unloaded entry mass, got {} < {}",
864 self.mass_unloaded_exit,
865 self.mass_unloaded_entry
866 );
867 }
868 if self.mass_left > 0.0
869 || self.mass_right > 0.0
870 || self.mass_unloaded_left > 0.0
871 || self.mass_unloaded_right > 0.0
872 || self.hazard_loaded > 0.0
873 || self.hazard_unloaded > 0.0
874 {
875 crate::bail_invalid_estim!("latent survival right-censored row cannot carry interval masses or event hazards"
876 .to_string(),);
877 }
878 }
879 LatentSurvivalEventType::ExactEvent => {
880 if self.mass_exit < self.mass_entry {
881 crate::bail_invalid_estim!(
882 "latent survival exact-event row requires mass_exit >= mass_entry, got {} < {}",
883 self.mass_exit,
884 self.mass_entry
885 );
886 }
887 if self.mass_unloaded_exit < self.mass_unloaded_entry {
888 crate::bail_invalid_estim!(
889 "latent survival exact-event row requires unloaded exit mass >= unloaded entry mass, got {} < {}",
890 self.mass_unloaded_exit,
891 self.mass_unloaded_entry
892 );
893 }
894 if self.mass_left > 0.0
895 || self.mass_right > 0.0
896 || self.mass_unloaded_left > 0.0
897 || self.mass_unloaded_right > 0.0
898 {
899 crate::bail_invalid_estim!(
900 "latent survival exact-event row cannot carry interval masses"
901 );
902 }
903 if self.hazard_loaded == 0.0 && self.hazard_unloaded == 0.0 {
904 crate::bail_invalid_estim!("latent survival exact-event row requires a positive loaded or unloaded hazard"
905 .to_string(),);
906 }
907 }
908 LatentSurvivalEventType::IntervalCensored => {
909 if self.mass_left < self.mass_entry || self.mass_right < self.mass_left {
910 crate::bail_invalid_estim!(
911 "latent survival interval row requires mass_entry <= mass_left <= mass_right, got entry={}, left={}, right={}",
912 self.mass_entry,
913 self.mass_left,
914 self.mass_right
915 );
916 }
917 if self.mass_unloaded_left < self.mass_unloaded_entry
918 || self.mass_unloaded_right < self.mass_unloaded_left
919 {
920 crate::bail_invalid_estim!(
921 "latent survival interval row requires unloaded_entry <= unloaded_left <= unloaded_right, got entry={}, left={}, right={}",
922 self.mass_unloaded_entry,
923 self.mass_unloaded_left,
924 self.mass_unloaded_right
925 );
926 }
927 if self.mass_exit > 0.0
928 || self.mass_unloaded_exit > 0.0
929 || self.hazard_loaded > 0.0
930 || self.hazard_unloaded > 0.0
931 {
932 crate::bail_invalid_estim!(
933 "latent survival interval row cannot carry exit masses or event hazards"
934 .to_string(),
935 );
936 }
937 }
938 }
939
940 Ok(())
941 }
942}
943
944fn exact_event_kernel_jet(
945 quadctx: &QuadratureContext,
946 row: &LatentSurvivalRow,
947 mu: f64,
948 sigma: f64,
949) -> Result<LogKernelSumJet, EstimationError> {
950 if row.hazard_loaded < 0.0 || row.hazard_unloaded < 0.0 {
951 crate::bail_invalid_estim!(
952 "latent survival exact-event hazards must be non-negative, got loaded={} unloaded={}",
953 row.hazard_loaded,
954 row.hazard_unloaded
955 );
956 }
957 match (row.hazard_unloaded > 0.0, row.hazard_loaded > 0.0) {
958 (true, true) => {
959 let terms = [
960 KernelSumTerm {
961 coeff: row.hazard_unloaded,
962 k: 0,
963 m: row.mass_exit,
964 },
965 KernelSumTerm {
966 coeff: row.hazard_loaded,
967 k: 1,
968 m: row.mass_exit,
969 },
970 ];
971 LogKernelSumJet::evaluate(quadctx, &terms, mu, sigma)
972 }
973 (true, false) => {
974 let jet = LogKernelSumJet::single_term(quadctx, 0, row.mass_exit, mu, sigma)?;
975 Ok(LogKernelSumJet {
976 value: row.hazard_unloaded.ln() + jet.value,
977 d1: jet.d1,
978 d2: jet.d2,
979 d3: jet.d3,
980 d4: jet.d4,
981 mode: jet.mode,
982 })
983 }
984 (false, true) => {
985 let jet = LogKernelSumJet::single_term(quadctx, 1, row.mass_exit, mu, sigma)?;
986 Ok(LogKernelSumJet {
987 value: row.hazard_loaded.ln() + jet.value,
988 d1: jet.d1,
989 d2: jet.d2,
990 d3: jet.d3,
991 d4: jet.d4,
992 mode: jet.mode,
993 })
994 }
995 (false, false) => Err(EstimationError::InvalidInput(
996 "latent survival exact-event row requires a positive loaded or unloaded hazard"
997 .to_string(),
998 )),
999 }
1000}
1001
/// Per-row log-likelihood of the latent survival model and its derivatives.
#[derive(Clone, Copy, Debug)]
pub struct LatentSurvivalRowJet {
    /// Row log-likelihood.
    pub log_lik: f64,
    /// `d log_lik / d mu`.
    pub score: f64,
    /// `-d^2 log_lik / d mu^2`.
    pub neg_hessian: f64,
    /// `d^3 log_lik / d mu^3` (not negated).
    pub d3: f64,
    /// `d log_lik / d log(sigma)`.
    pub score_log_sigma: f64,
    /// `-d^2 log_lik / d log(sigma)^2`.
    pub neg_hessian_log_sigma: f64,
}
1017
1018#[inline]
1019fn log_sigma_score_from_log_sum(jet: &LogKernelSumJet, sigma: f64) -> f64 {
1020 let sigma2 = sigma * sigma;
1021 sigma2 * (jet.d2 + jet.d1 * jet.d1)
1022}
1023
1024#[inline]
1025fn log_sigma_neg_hessian_from_log_sum(jet: &LogKernelSumJet, sigma: f64) -> f64 {
1026 let sigma2 = sigma * sigma;
1027 let sigma4 = sigma2 * sigma2;
1028 let d1 = jet.d1;
1029 let d2 = jet.d2;
1030 let d3 = jet.d3;
1031 let d4 = jet.d4;
1032 let s2_over_s = d2 + d1 * d1;
1033 let s4_over_s_minus_s2_sq = d4 + 4.0 * d1 * d3 + 2.0 * d2 * d2 + 4.0 * d1 * d1 * d2;
1038 -(2.0 * sigma2 * s2_over_s + sigma4 * s4_over_s_minus_s2_sq)
1039}
1040
1041impl LatentSurvivalRowJet {
1042 pub fn evaluate(
1043 quadctx: &QuadratureContext,
1044 row: &LatentSurvivalRow,
1045 mu: f64,
1046 sigma: f64,
1047 ) -> Result<Self, EstimationError> {
1048 row.validate()?;
1049 match row.event_type {
1050 LatentSurvivalEventType::RightCensored => Self::right_censored(quadctx, mu, sigma, row),
1051 LatentSurvivalEventType::ExactEvent => Self::exact_event(quadctx, mu, sigma, row),
1052 LatentSurvivalEventType::IntervalCensored => {
1053 Self::interval_censored(quadctx, mu, sigma, row)
1054 }
1055 }
1056 }
1057
1058 fn right_censored(
1066 quadctx: &QuadratureContext,
1067 mu: f64,
1068 sigma: f64,
1069 row: &LatentSurvivalRow,
1070 ) -> Result<Self, EstimationError> {
1071 let has_unloaded =
1072 row.mass_unloaded_exit.abs() > 1e-300 || row.mass_unloaded_entry.abs() > 1e-300;
1073
1074 let mass_exit_loaded = row.mass_exit;
1078 let mass_entry_loaded = row.mass_entry;
1079
1080 let unloaded_offset = if has_unloaded {
1082 -row.mass_unloaded_exit + row.mass_unloaded_entry
1083 } else {
1084 0.0
1085 };
1086
1087 let num = LogKernelSumJet::single_term(quadctx, 0, mass_exit_loaded, mu, sigma)?;
1088 if mass_entry_loaded > 1e-300 {
1089 let den = LogKernelSumJet::single_term(quadctx, 0, mass_entry_loaded, mu, sigma)?;
1090 Ok(Self {
1091 log_lik: unloaded_offset + num.value - den.value,
1092 score: num.d1 - den.d1,
1093 neg_hessian: -(num.d2 - den.d2),
1094 d3: num.d3 - den.d3,
1095 score_log_sigma: log_sigma_score_from_log_sum(&num, sigma)
1096 - log_sigma_score_from_log_sum(&den, sigma),
1097 neg_hessian_log_sigma: log_sigma_neg_hessian_from_log_sum(&num, sigma)
1098 - log_sigma_neg_hessian_from_log_sum(&den, sigma),
1099 })
1100 } else {
1101 Ok(Self {
1102 log_lik: unloaded_offset + num.value,
1103 score: num.d1,
1104 neg_hessian: -num.d2,
1105 d3: num.d3,
1106 score_log_sigma: log_sigma_score_from_log_sum(&num, sigma),
1107 neg_hessian_log_sigma: log_sigma_neg_hessian_from_log_sum(&num, sigma),
1108 })
1109 }
1110 }
1111
1112 fn exact_event(
1116 quadctx: &QuadratureContext,
1117 mu: f64,
1118 sigma: f64,
1119 row: &LatentSurvivalRow,
1120 ) -> Result<Self, EstimationError> {
1121 let unloaded_offset =
1122 if row.mass_unloaded_exit.abs() > 1e-300 || row.mass_unloaded_entry.abs() > 1e-300 {
1123 -row.mass_unloaded_exit + row.mass_unloaded_entry
1124 } else {
1125 0.0
1126 };
1127 let num = exact_event_kernel_jet(quadctx, row, mu, sigma)?;
1128
1129 if row.mass_entry > 1e-300 {
1130 let den = LogKernelSumJet::single_term(quadctx, 0, row.mass_entry, mu, sigma)?;
1131 Ok(Self {
1132 log_lik: unloaded_offset + num.value - den.value,
1133 score: num.d1 - den.d1,
1134 neg_hessian: -(num.d2 - den.d2),
1135 d3: num.d3 - den.d3,
1136 score_log_sigma: log_sigma_score_from_log_sum(&num, sigma)
1137 - log_sigma_score_from_log_sum(&den, sigma),
1138 neg_hessian_log_sigma: log_sigma_neg_hessian_from_log_sum(&num, sigma)
1139 - log_sigma_neg_hessian_from_log_sum(&den, sigma),
1140 })
1141 } else {
1142 Ok(Self {
1143 log_lik: unloaded_offset + num.value,
1144 score: num.d1,
1145 neg_hessian: -num.d2,
1146 d3: num.d3,
1147 score_log_sigma: log_sigma_score_from_log_sum(&num, sigma),
1148 neg_hessian_log_sigma: log_sigma_neg_hessian_from_log_sum(&num, sigma),
1149 })
1150 }
1151 }
1152
1153 fn interval_censored(
1155 quadctx: &QuadratureContext,
1156 mu: f64,
1157 sigma: f64,
1158 row: &LatentSurvivalRow,
1159 ) -> Result<Self, EstimationError> {
1160 let num_terms = [
1161 KernelSumTerm {
1162 coeff: (-row.mass_unloaded_left).exp(),
1163 k: 0,
1164 m: row.mass_left,
1165 },
1166 KernelSumTerm {
1167 coeff: -(-row.mass_unloaded_right).exp(),
1168 k: 0,
1169 m: row.mass_right,
1170 },
1171 ];
1172 let num = LogKernelSumJet::evaluate(quadctx, &num_terms, mu, sigma)?;
1173
1174 if row.mass_entry > 1e-300 {
1175 let den = LogKernelSumJet::single_term(quadctx, 0, row.mass_entry, mu, sigma)?;
1176 Ok(Self {
1177 log_lik: num.value + row.mass_unloaded_entry - den.value,
1178 score: num.d1 - den.d1,
1179 neg_hessian: -(num.d2 - den.d2),
1180 d3: num.d3 - den.d3,
1181 score_log_sigma: log_sigma_score_from_log_sum(&num, sigma)
1182 - log_sigma_score_from_log_sum(&den, sigma),
1183 neg_hessian_log_sigma: log_sigma_neg_hessian_from_log_sum(&num, sigma)
1184 - log_sigma_neg_hessian_from_log_sum(&den, sigma),
1185 })
1186 } else {
1187 Ok(Self {
1188 log_lik: num.value + row.mass_unloaded_entry,
1189 score: num.d1,
1190 neg_hessian: -num.d2,
1191 d3: num.d3,
1192 score_log_sigma: log_sigma_score_from_log_sum(&num, sigma),
1193 neg_hessian_log_sigma: log_sigma_neg_hessian_from_log_sum(&num, sigma),
1194 })
1195 }
1196 }
1197}
1198
1199#[cfg(test)]
1200mod tests {
1201 use super::*;
1202
    /// Weighted Bernoulli log-likelihood of `y` under the latent-cloglog mean at
    /// `(eta, sigma)`. The mean is clamped to `[1e-12, 1 - 1e-12]` so both logs stay
    /// finite.
    fn latent_binomial_row_log_lik(
        ctx: &QuadratureContext,
        eta: f64,
        sigma: f64,
        y: f64,
        weight: f64,
    ) -> f64 {
        let mu = latent_cloglog_jet5(ctx, eta, sigma)
            .expect("latent jet")
            .mean;
        let mu = mu.clamp(1e-12, 1.0 - 1e-12);
        weight * (y * mu.ln() + (1.0 - y) * (1.0 - mu).ln())
    }
1216
1217 #[test]
1218 fn kernel_ratio_jet_d1_fd_check() {
1219 let ctx = QuadratureContext::new();
1220 let mu = 0.3;
1221 let sigma = 0.5;
1222 let m = 1.0;
1223 let k = 0usize;
1224 let h = 1e-5;
1225
1226 let bundle = log_kernel_bundle(&ctx, m, mu, sigma, k + 4).unwrap();
1227 let log_k = bundle.get(k);
1228 let ratios = kernel_ratio_jet(&bundle, k, m, 2);
1229 let kc = log_k.exp();
1230 let d1 = kc * ratios[1];
1231 let d2 = kc * ratios[2];
1232
1233 let kp = log_kernel_term(&ctx, k, m, mu + h, sigma).unwrap().0.exp();
1234 let km = log_kernel_term(&ctx, k, m, mu - h, sigma).unwrap().0.exp();
1235 let fd_d1 = (kp - km) / (2.0 * h);
1236 assert!(
1237 (d1 - fd_d1).abs() / fd_d1.abs().max(1e-15) < 1e-4,
1238 "d1: jet={d1}, fd={fd_d1}",
1239 );
1240
1241 let fd_d2 = (kp - 2.0 * kc + km) / (h * h);
1242 assert!(
1243 (d2 - fd_d2).abs() / fd_d2.abs().max(1e-15) < 1e-3,
1244 "d2: jet={d2}, fd={fd_d2}",
1245 );
1246 }
1247
1248 #[test]
1249 fn survival_right_censored_score_fd() {
1250 let ctx = QuadratureContext::new();
1251 let mu = -0.5;
1252 let sigma = 0.3;
1253 let h = 1e-6;
1254 let row = LatentSurvivalRow::right_censored(0.0, 2.0, 0.0, 0.0);
1255 let ll_p = LatentSurvivalRowJet::evaluate(&ctx, &row, mu + h, sigma)
1256 .unwrap()
1257 .log_lik;
1258 let ll_m = LatentSurvivalRowJet::evaluate(&ctx, &row, mu - h, sigma)
1259 .unwrap()
1260 .log_lik;
1261 let fd_score = (ll_p - ll_m) / (2.0 * h);
1262 let jet = LatentSurvivalRowJet::evaluate(&ctx, &row, mu, sigma).unwrap();
1263 assert!(
1264 (jet.score - fd_score).abs() / fd_score.abs().max(1e-15) < 1e-3,
1265 "score={}, fd={fd_score}",
1266 jet.score
1267 );
1268 }
1269
1270 #[test]
1271 fn survival_exact_event_score_fd() {
1272 let ctx = QuadratureContext::new();
1273 let mu = 0.2;
1274 let sigma = 0.5;
1275 let h = 1e-6;
1276 let row = LatentSurvivalRow::exact_event(0.0, 1.5, 0.0, 0.0, (-0.3f64).exp(), 0.0);
1277 let ll_p = LatentSurvivalRowJet::evaluate(&ctx, &row, mu + h, sigma)
1278 .unwrap()
1279 .log_lik;
1280 let ll_m = LatentSurvivalRowJet::evaluate(&ctx, &row, mu - h, sigma)
1281 .unwrap()
1282 .log_lik;
1283 let fd_score = (ll_p - ll_m) / (2.0 * h);
1284 let jet = LatentSurvivalRowJet::evaluate(&ctx, &row, mu, sigma).unwrap();
1285 assert!(
1286 (jet.score - fd_score).abs() / fd_score.abs().max(1e-15) < 1e-3,
1287 "score={}, fd={fd_score}",
1288 jet.score
1289 );
1290 }
1291
1292 #[test]
1293 fn survival_exact_event_loaded_vs_unloaded_score_fd() {
1294 let ctx = QuadratureContext::new();
1295 let mu = -0.1;
1296 let sigma = 0.4;
1297 let h = 1e-6;
1298 let row = LatentSurvivalRow::exact_event(0.3, 1.2, 0.2, 0.6, 0.9, 0.15);
1299 let ll_p = LatentSurvivalRowJet::evaluate(&ctx, &row, mu + h, sigma)
1300 .unwrap()
1301 .log_lik;
1302 let ll_m = LatentSurvivalRowJet::evaluate(&ctx, &row, mu - h, sigma)
1303 .unwrap()
1304 .log_lik;
1305 let fd_score = (ll_p - ll_m) / (2.0 * h);
1306 let jet = LatentSurvivalRowJet::evaluate(&ctx, &row, mu, sigma).unwrap();
1307 assert!(
1308 (jet.score - fd_score).abs() / fd_score.abs().max(1e-15) < 1e-3,
1309 "score={}, fd={fd_score}",
1310 jet.score
1311 );
1312 }
1313
1314 #[test]
1315 fn survival_right_censored_loaded_vs_unloaded_score_fd() {
1316 let ctx = QuadratureContext::new();
1317 let mu = 0.15;
1318 let sigma: f64 = 0.35;
1319 let h = 1e-6;
1320 let row = LatentSurvivalRow::right_censored(0.4, 1.7, 0.1, 0.5);
1321 let ll_p = LatentSurvivalRowJet::evaluate(&ctx, &row, mu + h, sigma)
1322 .unwrap()
1323 .log_lik;
1324 let ll_m = LatentSurvivalRowJet::evaluate(&ctx, &row, mu - h, sigma)
1325 .unwrap()
1326 .log_lik;
1327 let fd_score = (ll_p - ll_m) / (2.0 * h);
1328 let jet = LatentSurvivalRowJet::evaluate(&ctx, &row, mu, sigma).unwrap();
1329 assert!(
1330 (jet.score - fd_score).abs() / fd_score.abs().max(1e-15) < 1e-3,
1331 "score={}, fd={fd_score}",
1332 jet.score
1333 );
1334 }
1335
1336 #[test]
1337 fn survival_interval_censored_score_fd() {
1338 let ctx = QuadratureContext::new();
1339 let mu = 0.0;
1340 let sigma = 0.6;
1341 let h = 1e-6;
1342 let row = LatentSurvivalRow::interval_censored(0.0, 1.0, 2.0, 0.0, 0.0, 0.0);
1343 let ll_p = LatentSurvivalRowJet::evaluate(&ctx, &row, mu + h, sigma)
1344 .unwrap()
1345 .log_lik;
1346 let ll_m = LatentSurvivalRowJet::evaluate(&ctx, &row, mu - h, sigma)
1347 .unwrap()
1348 .log_lik;
1349 let fd_score = (ll_p - ll_m) / (2.0 * h);
1350 let jet = LatentSurvivalRowJet::evaluate(&ctx, &row, mu, sigma).unwrap();
1351 assert!(
1352 (jet.score - fd_score).abs() / fd_score.abs().max(1e-15) < 1e-3,
1353 "score={}, fd={fd_score}",
1354 jet.score
1355 );
1356 }
1357
1358 #[test]
1359 fn survival_interval_censored_neg_hessian_fd() {
1360 let ctx = QuadratureContext::new();
1364 let mu = -0.2;
1365 let sigma = 0.55;
1366 let h = 2e-4;
1367 let row = LatentSurvivalRow::interval_censored(0.0, 0.7, 1.9, 0.0, 0.0, 0.0);
1368 let ll = |m: f64| {
1369 LatentSurvivalRowJet::evaluate(&ctx, &row, m, sigma)
1370 .unwrap()
1371 .log_lik
1372 };
1373 let fd_d2 = (ll(mu + h) - 2.0 * ll(mu) + ll(mu - h)) / (h * h);
1374 let jet = LatentSurvivalRowJet::evaluate(&ctx, &row, mu, sigma).unwrap();
1375 assert!(
1376 (jet.neg_hessian - (-fd_d2)).abs() / fd_d2.abs().max(1e-12) < 1e-2,
1377 "interval neg_hessian={}, fd(-d2)={}",
1378 jet.neg_hessian,
1379 -fd_d2
1380 );
1381 }
1382
1383 #[test]
1384 fn survival_interval_censored_log_sigma_score_fd() {
1385 let ctx = QuadratureContext::new();
1390 let mu = 0.1;
1391 let sigma: f64 = 0.6;
1392 let h = 1e-5;
1393 let row = LatentSurvivalRow::interval_censored(0.0, 0.8, 2.1, 0.0, 0.0, 0.0);
1394 let ll_at = |s: f64| {
1395 LatentSurvivalRowJet::evaluate(&ctx, &row, mu, s)
1396 .unwrap()
1397 .log_lik
1398 };
1399 let fd_dlogsigma =
1401 (ll_at((sigma.ln() + h).exp()) - ll_at((sigma.ln() - h).exp())) / (2.0 * h);
1402 let jet = LatentSurvivalRowJet::evaluate(&ctx, &row, mu, sigma).unwrap();
1403 assert!(
1404 (jet.score_log_sigma - fd_dlogsigma).abs() / fd_dlogsigma.abs().max(1e-12) < 1e-3,
1405 "interval score_log_sigma={}, fd={fd_dlogsigma}",
1406 jet.score_log_sigma
1407 );
1408 }
1409
1410 #[test]
1411 fn log_kernel_single_term_log_sigma_derivatives_match_ghq_reference() {
1412 let ctx = QuadratureContext::new();
1413 let mu = 0.2;
1414 let sigma = 1.0;
1415 let jet = LogKernelSumJet::single_term(&ctx, 0, 1.0, mu, sigma).unwrap();
1416 let ghq = crate::inference::quadrature::cloglog_ghq_derivatives_adaptive(&ctx, mu, sigma);
1417 let survival = (1.0 - ghq.l).max(1e-300);
1418 let survival_sigma_over_survival = -ghq.l_sigma / survival;
1419 let ref_score = sigma * survival_sigma_over_survival;
1420 let ref_neg_hessian = -(ref_score
1421 + sigma
1422 * sigma
1423 * (-ghq.l_sigmasigma / survival - survival_sigma_over_survival.powi(2)));
1424
1425 assert!(
1426 (log_sigma_score_from_log_sum(&jet, sigma) - ref_score).abs()
1427 / ref_score.abs().max(1e-12)
1428 < 1e-4,
1429 "log-sigma score={}, ref={ref_score}",
1430 log_sigma_score_from_log_sum(&jet, sigma)
1431 );
1432 assert!(
1433 (log_sigma_neg_hessian_from_log_sum(&jet, sigma) - ref_neg_hessian).abs()
1434 / ref_neg_hessian.abs().max(1e-12)
1435 < 1e-3,
1436 "log-sigma neg_hessian={}, ref={ref_neg_hessian}",
1437 log_sigma_neg_hessian_from_log_sum(&jet, sigma)
1438 );
1439 }
1440
1441 #[test]
1442 fn log_kernel_sum_jet_single_term_d1_fd() {
1443 let ctx = QuadratureContext::new();
1444 let mu = 0.5;
1445 let sigma = 0.4;
1446 let m = 1.0;
1447 let k = 0usize;
1448 let h = 1e-6;
1449
1450 let jet = LogKernelSumJet::single_term(&ctx, k, m, mu, sigma).unwrap();
1451 let val_p = log_kernel_term(&ctx, k, m, mu + h, sigma).unwrap().0;
1452 let val_m = log_kernel_term(&ctx, k, m, mu - h, sigma).unwrap().0;
1453 let fd_d1 = (val_p - val_m) / (2.0 * h);
1454 assert!(
1455 (jet.d1 - fd_d1).abs() / fd_d1.abs().max(1e-15) < 1e-3,
1456 "d1={}, fd={fd_d1}",
1457 jet.d1
1458 );
1459 }
1460
1461 #[test]
1462 fn log_kernel_sum_jet_single_term_d4_fd() {
1463 let ctx = QuadratureContext::new();
1464 let mu = 0.35;
1465 let sigma = 0.45;
1466 let m = 1.2;
1467 let k = 1usize;
1468 let h = 2e-3;
1469
1470 let jet = LogKernelSumJet::single_term(&ctx, k, m, mu, sigma).unwrap();
1471 let v_pp = log_kernel_term(&ctx, k, m, mu + 2.0 * h, sigma).unwrap().0;
1472 let v_p = log_kernel_term(&ctx, k, m, mu + h, sigma).unwrap().0;
1473 let v_0 = log_kernel_term(&ctx, k, m, mu, sigma).unwrap().0;
1474 let v_m = log_kernel_term(&ctx, k, m, mu - h, sigma).unwrap().0;
1475 let v_mm = log_kernel_term(&ctx, k, m, mu - 2.0 * h, sigma).unwrap().0;
1476 let fd_d4 = (v_mm - 4.0 * v_m + 6.0 * v_0 - 4.0 * v_p + v_pp) / h.powi(4);
1477 assert!(
1478 (jet.d4 - fd_d4).abs() / jet.d4.abs().max(fd_d4.abs()).max(1e-8) < 2e-2,
1479 "d4={}, fd={fd_d4}",
1480 jet.d4
1481 );
1482 }
1483
1484 #[test]
1485 fn latent_cloglog_jet_matches_point_limit_at_zero_sigma() {
1486 let ctx = QuadratureContext::new();
1487 let eta = -0.4;
1488 let jet = latent_cloglog_jet5(&ctx, eta, 0.0).expect("latent jet");
1489 let t = eta.exp();
1490 let d1 = (eta - t).exp();
1491 let d2 = (1.0 - t) * d1;
1492 let d3 = (t * t - 3.0 * t + 1.0) * d1;
1493 let d4 = (-t * t * t + 6.0 * t * t - 7.0 * t + 1.0) * d1;
1494 let d5 = (t.powi(4) - 10.0 * t.powi(3) + 25.0 * t * t - 15.0 * t + 1.0) * d1;
1495 assert!((jet.mean - (1.0 - (-t).exp())).abs() < 1e-12);
1496 assert!((jet.d1 - d1).abs() < 1e-12);
1497 assert!((jet.d2 - d2).abs() < 1e-12);
1498 assert!((jet.d3 - d3).abs() < 1e-12);
1499 assert!((jet.d4 - d4).abs() < 1e-12);
1500 assert!((jet.d5 - d5).abs() < 1e-12);
1501 }
1502
1503 #[test]
1504 fn latent_cloglog_jet_matches_exact_kernel_recurrence() {
1505 let ctx = QuadratureContext::new();
1506 let cases = [(-4.0, 0.15), (-1.2, 0.35), (0.4, 0.6), (1.3, 0.9)];
1507
1508 for (eta, sigma) in cases {
1509 let jet = latent_cloglog_jet5(&ctx, eta, sigma).expect("latent jet");
1510 let bundle = log_kernel_bundle(&ctx, 1.0, eta, sigma, 5).expect("kernel bundle");
1511 let k0 = bundle.get(0);
1512 let k1 = bundle.get(1).exp();
1513 let k2 = bundle.get(2).exp();
1514 let k3 = bundle.get(3).exp();
1515 let k4 = bundle.get(4).exp();
1516 let k5 = bundle.get(5).exp();
1517
1518 let mean = if k0.is_finite() { -k0.exp_m1() } else { 1.0 };
1519 let d1 = k1;
1520 let d2 = k1 - k2;
1521 let d3 = k1 - 3.0 * k2 + k3;
1522 let d4 = k1 - 7.0 * k2 + 6.0 * k3 - k4;
1523 let d5 = k1 - 15.0 * k2 + 25.0 * k3 - 10.0 * k4 + k5;
1524
1525 assert!((jet.mean - mean).abs() < 1e-12);
1526 assert!((jet.d1 - d1).abs() < 1e-12);
1527 assert!((jet.d2 - d2).abs() < 1e-12);
1528 assert!((jet.d3 - d3).abs() < 1e-12);
1529 assert!((jet.d4 - d4).abs() < 1e-12);
1530 assert!((jet.d5 - d5).abs() < 1e-12);
1531 }
1532 }
1533
1534 #[test]
1535 fn latent_cloglog_binomial_row_neg_hessian_matches_fd() {
1536 let ctx = QuadratureContext::new();
1537 let eta = 0.4;
1538 let sigma = 0.6;
1539 let y = 0.35;
1540 let weight = 2.0;
1541 let h = 1e-4;
1542
1543 let jet = latent_cloglog_jet5(&ctx, eta, sigma).expect("latent jet");
1544 let mu = jet.mean.clamp(1e-12, 1.0 - 1e-12);
1545 let ellmu = y / mu - (1.0 - y) / (1.0 - mu);
1546 let ellmumu = -y / (mu * mu) - (1.0 - y) / ((1.0 - mu) * (1.0 - mu));
1547 let neg_hessian = -weight * (ellmumu * jet.d1 * jet.d1 + ellmu * jet.d2);
1548
1549 let ll_minus = latent_binomial_row_log_lik(&ctx, eta - h, sigma, y, weight);
1550 let ll0 = latent_binomial_row_log_lik(&ctx, eta, sigma, y, weight);
1551 let ll_plus = latent_binomial_row_log_lik(&ctx, eta + h, sigma, y, weight);
1552 let neg_hessian_fd = -(ll_plus - 2.0 * ll0 + ll_minus) / (h * h);
1553
1554 let err = (neg_hessian - neg_hessian_fd).abs();
1555 let tol = 2e-5_f64.max(3e-3 * neg_hessian_fd.abs());
1556 assert!(
1557 err <= tol,
1558 "latent cloglog Bernoulli row curvature mismatch: analytic={} fd={}",
1559 neg_hessian,
1560 neg_hessian_fd
1561 );
1562 }
1563}