1use ndarray::{Array1, ArrayView1};
2
/// First-order jet of a sigma link evaluated at one linear-predictor value `eta`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SigmaJet1 {
    /// Link value `sigma(eta)`.
    pub sigma: f64,
    /// First derivative `d sigma / d eta`.
    pub d1: f64,
}
8
/// Third-order jet of a sigma link: the value plus derivatives of order 1 through 3
/// with respect to `eta`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SigmaJet3 {
    /// Link value `sigma(eta)`.
    pub sigma: f64,
    /// First derivative with respect to `eta`.
    pub d1: f64,
    /// Second derivative with respect to `eta`.
    pub d2: f64,
    /// Third derivative with respect to `eta`.
    pub d3: f64,
}
16
/// Fourth-order jet of a sigma link: the value plus derivatives of order 1 through 4
/// with respect to `eta`. Crate-internal; public callers receive `SigmaJet3` or tuples.
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) struct SigmaJet4 {
    /// Link value `sigma(eta)`.
    pub sigma: f64,
    /// First derivative with respect to `eta`.
    pub d1: f64,
    /// Second derivative with respect to `eta`.
    pub d2: f64,
    /// Third derivative with respect to `eta`.
    pub d3: f64,
    /// Fourth derivative with respect to `eta`.
    pub d4: f64,
}
25
/// Natural exponential `e^eta` with plain IEEE-754 semantics.
///
/// Despite the name, no clamping is performed: arguments above roughly `709.78`
/// overflow to `+inf`, arguments below roughly `-745` underflow to `0.0`, and NaN
/// propagates.
#[inline]
pub fn safe_exp(eta: f64) -> f64 {
    f64::exp(eta)
}
35
36#[inline]
37pub fn exp_sigma_jet1_scalar(eta: f64) -> SigmaJet1 {
38 let sigma = safe_exp(eta);
39 SigmaJet1 { sigma, d1: sigma }
40}
41
42#[inline]
43pub fn exp_sigma_from_eta_scalar(eta: f64) -> f64 {
44 safe_exp(eta)
45}
46
/// Upper bound on the exponent used by [`exp_neg_stable`].
///
/// `exp(500.0)` is about `1.4e217`, well below `f64::MAX` (about `1.8e308`), so
/// capping the exponent here keeps the result finite.
pub const EXP_NEG_STABLE_MAX_ARG: f64 = 500.0;

/// Computes `exp(-x)` with the exponent `-x` capped above at
/// [`EXP_NEG_STABLE_MAX_ARG`].
///
/// - For `x < -EXP_NEG_STABLE_MAX_ARG`, including `x = -inf`, the result saturates
///   at `exp(EXP_NEG_STABLE_MAX_ARG)` instead of overflowing to `+inf`.
/// - There is no lower cap, so large positive `x` still underflows towards `0.0`.
/// - A NaN input yields NaN.
#[inline]
pub fn exp_neg_stable(x: f64) -> f64 {
    let arg = -x;
    // An explicit comparison instead of `f64::min`: `min` returns the non-NaN
    // operand, which would silently turn a NaN input into `exp(500)`. Here NaN
    // fails the comparison and propagates through `exp`.
    if arg > EXP_NEG_STABLE_MAX_ARG {
        EXP_NEG_STABLE_MAX_ARG.exp()
    } else {
        arg.exp()
    }
}
66
67#[inline]
73pub fn exp_sigma_inverse_from_eta_scalar(eta: f64) -> f64 {
74 exp_neg_stable(eta)
75}
76
/// Inverse of the exp link: returns `eta = ln(sigma)`.
///
/// # Panics
///
/// Panics if `sigma` is not finite (NaN or ±inf), or if it is not strictly positive.
#[inline]
pub fn exp_sigma_eta_for_sigma_scalar(sigma: f64) -> f64 {
    if !sigma.is_finite() {
        panic!("exp sigma inverse link requires finite sigma: sigma={sigma}");
    }
    // `sigma` is finite here, so `<= 0.0` is exactly the negation of `> 0.0`.
    if sigma <= 0.0 {
        panic!("exp sigma inverse link: sigma must be positive (got sigma={sigma})");
    }
    sigma.ln()
}
89
90#[inline]
91pub fn exp_sigma_jet3_scalar(eta: f64) -> SigmaJet3 {
92 let jet = exp_sigma_jet4_scalar(eta);
93 SigmaJet3 {
94 sigma: jet.sigma,
95 d1: jet.d1,
96 d2: jet.d2,
97 d3: jet.d3,
98 }
99}
100
101#[inline]
102pub fn exp_sigma_derivs_up_to_third_scalar(eta: f64) -> (f64, f64, f64, f64) {
103 let jet = exp_sigma_jet3_scalar(eta);
104 (jet.sigma, jet.d1, jet.d2, jet.d3)
105}
106
107pub fn exp_sigma_derivs_up_to_third(
108 eta: ArrayView1<'_, f64>,
109) -> (Array1<f64>, Array1<f64>, Array1<f64>, Array1<f64>) {
110 let n = eta.len();
111 let mut sigma = Array1::<f64>::uninit(n);
112 let mut d1 = Array1::<f64>::uninit(n);
113 let mut d2 = Array1::<f64>::uninit(n);
114 let mut d3 = Array1::<f64>::uninit(n);
115 for i in 0..n {
116 let jet = exp_sigma_jet3_scalar(eta[i]);
117 sigma[i].write(jet.sigma);
118 d1[i].write(jet.d1);
119 d2[i].write(jet.d2);
120 d3[i].write(jet.d3);
121 }
122 unsafe {
125 (
126 sigma.assume_init(),
127 d1.assume_init(),
128 d2.assume_init(),
129 d3.assume_init(),
130 )
131 }
132}
133
134#[inline]
135pub(crate) fn exp_sigma_jet4_scalar(eta: f64) -> SigmaJet4 {
136 let sigma = safe_exp(eta);
137 SigmaJet4 {
138 sigma,
139 d1: sigma,
140 d2: sigma,
141 d3: sigma,
142 d4: sigma,
143 }
144}
145
146#[inline]
147pub fn exp_sigma_derivs_up_to_fourth_scalar(eta: f64) -> (f64, f64, f64, f64, f64) {
148 let jet = exp_sigma_jet4_scalar(eta);
149 (jet.sigma, jet.d1, jet.d2, jet.d3, jet.d4)
150}
151
152pub fn exp_sigma_derivs_up_to_fourth(
153 eta: ArrayView1<'_, f64>,
154) -> (
155 Array1<f64>,
156 Array1<f64>,
157 Array1<f64>,
158 Array1<f64>,
159 Array1<f64>,
160) {
161 let n = eta.len();
162 let mut sigma = Array1::<f64>::uninit(n);
163 let mut d1 = Array1::<f64>::uninit(n);
164 let mut d2 = Array1::<f64>::uninit(n);
165 let mut d3 = Array1::<f64>::uninit(n);
166 let mut d4 = Array1::<f64>::uninit(n);
167 for i in 0..n {
168 let jet = exp_sigma_jet4_scalar(eta[i]);
169 sigma[i].write(jet.sigma);
170 d1[i].write(jet.d1);
171 d2[i].write(jet.d2);
172 d3[i].write(jet.d3);
173 d4[i].write(jet.d4);
174 }
175 unsafe {
178 (
179 sigma.assume_init(),
180 d1.assume_init(),
181 d2.assume_init(),
182 d3.assume_init(),
183 d4.assume_init(),
184 )
185 }
186}
187
/// Additive floor of the logb link `sigma = LOGB_SIGMA_FLOOR + exp(eta)`.
///
/// Because `exp(eta) >= 0`, `sigma` never drops below this value for any non-NaN
/// `eta`. Equal to `0.01`.
pub const LOGB_SIGMA_FLOOR: f64 = 1e-2;
236
237#[inline]
238pub fn logb_sigma_jet1_scalar(eta: f64) -> SigmaJet1 {
239 let s = safe_exp(eta);
240 SigmaJet1 {
241 sigma: LOGB_SIGMA_FLOOR + s,
242 d1: s,
243 }
244}
245
246#[inline]
247pub fn logb_sigma_from_eta_scalar(eta: f64) -> f64 {
248 LOGB_SIGMA_FLOOR + safe_exp(eta)
249}
250
251#[inline]
263pub fn logb_sigma_from_eta_with_floor_scalar(floor: f64, eta: f64) -> f64 {
264 floor + safe_exp(eta)
265}
266
267#[inline]
268pub fn logb_sigma_eta_for_sigma_scalar(sigma: f64) -> f64 {
269 assert!(
270 sigma.is_finite(),
271 "logb sigma inverse link requires finite sigma: sigma={sigma}"
272 );
273 assert!(
274 sigma > LOGB_SIGMA_FLOOR,
275 "logb sigma inverse link: sigma must exceed LOGB_SIGMA_FLOOR (got sigma={sigma}, floor={LOGB_SIGMA_FLOOR})"
276 );
277 (sigma - LOGB_SIGMA_FLOOR).ln()
278}
279
280#[inline]
281pub fn logb_sigma_jet3_scalar(eta: f64) -> SigmaJet3 {
282 let jet = logb_sigma_jet4_scalar(eta);
283 SigmaJet3 {
284 sigma: jet.sigma,
285 d1: jet.d1,
286 d2: jet.d2,
287 d3: jet.d3,
288 }
289}
290
291#[inline]
292pub fn logb_sigma_derivs_up_to_third_scalar(eta: f64) -> (f64, f64, f64, f64) {
293 let jet = logb_sigma_jet3_scalar(eta);
294 (jet.sigma, jet.d1, jet.d2, jet.d3)
295}
296
297#[inline]
298pub(crate) fn logb_sigma_jet4_scalar(eta: f64) -> SigmaJet4 {
299 let s = safe_exp(eta);
300 SigmaJet4 {
301 sigma: LOGB_SIGMA_FLOOR + s,
302 d1: s,
303 d2: s,
304 d3: s,
305 d4: s,
306 }
307}
308
309#[inline]
310pub fn logb_sigma_derivs_up_to_fourth_scalar(eta: f64) -> (f64, f64, f64, f64, f64) {
311 let jet = logb_sigma_jet4_scalar(eta);
312 (jet.sigma, jet.d1, jet.d2, jet.d3, jet.d4)
313}
314
315pub fn logb_sigma_derivs_up_to_fourth(
316 eta: ArrayView1<'_, f64>,
317) -> (
318 Array1<f64>,
319 Array1<f64>,
320 Array1<f64>,
321 Array1<f64>,
322 Array1<f64>,
323) {
324 let n = eta.len();
325 let mut sigma = Array1::<f64>::uninit(n);
326 let mut d1 = Array1::<f64>::uninit(n);
327 let mut d2 = Array1::<f64>::uninit(n);
328 let mut d3 = Array1::<f64>::uninit(n);
329 let mut d4 = Array1::<f64>::uninit(n);
330 for i in 0..n {
331 let jet = logb_sigma_jet4_scalar(eta[i]);
332 sigma[i].write(jet.sigma);
333 d1[i].write(jet.d1);
334 d2[i].write(jet.d2);
335 d3[i].write(jet.d3);
336 d4[i].write(jet.d4);
337 }
338 unsafe {
341 (
342 sigma.assume_init(),
343 d1.assume_init(),
344 d2.assume_init(),
345 d3.assume_init(),
346 d4.assume_init(),
347 )
348 }
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::Path;

    /// Recursively appends every file with a `.rs` extension under `dir` to `out`.
    ///
    /// Best-effort: if `dir` cannot be read (e.g. it does not exist), the function
    /// returns silently. Directory entries that fail to read are dropped by `flatten()`.
    fn collect_rs_files(dir: &Path, out: &mut Vec<std::path::PathBuf>) {
        let Ok(entries) = fs::read_dir(dir) else {
            return;
        };
        for entry in entries.flatten() {
            let path = entry.path();
            if path.is_dir() {
                collect_rs_files(&path, out);
                continue;
            }
            if path.extension().and_then(|e| e.to_str()) == Some("rs") {
                out.push(path);
            }
        }
    }

    /// Returns `s` with every whitespace character removed. This makes the source scan
    /// below insensitive to formatting: `fn  foo (` and `fn foo(` both become `fnfoo(`.
    fn stripwhitespace(s: &str) -> String {
        s.chars().filter(|c| !c.is_whitespace()).collect()
    }

    /// Source-scan guard: fails if any `.rs` file under `<crate>/src` contains one of
    /// the forbidden sigma-link patterns listed below.
    ///
    /// If `src` cannot be read, no files are collected and the test passes vacuously.
    #[test]
    fn forbid_bounded_sigma_link_pattern_in_source() {
        let root = Path::new(env!("CARGO_MANIFEST_DIR")).join("src");
        let mut files = Vec::new();
        collect_rs_files(&root, &mut files);

        // Matched against whitespace-stripped source, hence the `fn<name>(` spelling
        // of the function-definition entries.
        let bad_patterns = [
            "bounded_sigma",
            "model.sigma_min",
            "model.sigma_max",
            "payload.sigma_min",
            "payload.sigma_max",
            "survival_sigma_min",
            "survival_sigma_max",
            "fnsafe_sigma_from_eta(",
            "fnsigma_and_deriv_from_eta(",
            "fnsigma_from_eta_scalar(",
        ];

        for file in files {
            // NOTE(review): presumably the path of the module that defines this test.
            // Its own pattern list would otherwise match itself — confirm.
            if file.ends_with("families/sigma_link.rs") {
                continue;
            }
            // Files that cannot be read as UTF-8 text are skipped rather than failing.
            let Ok(content) = fs::read_to_string(&file) else {
                continue;
            };
            let compact = stripwhitespace(&content);
            for pat in bad_patterns {
                assert!(
                    !compact.contains(pat),
                    "forbidden sigma link pattern '{pat}' found in {}",
                    file.display()
                );
            }
        }
    }

    /// Compares analytic sigma-link derivatives against finite differences.
    ///
    /// `sigma_at` evaluates the link; `derivs_at` returns the analytic
    /// `(sigma, d1, d2, d3)`. At each sample `eta`:
    /// - `d1` is checked against a central difference with step `h`;
    /// - `d2` is checked against a second central difference with step `h`, centred
    ///   on the analytic `s` returned by `derivs_at` rather than on `sigma_at(eta)`;
    /// - `d3` is checked against a central difference (step `h3`) of second
    ///   differences (also step `h3`) of `sigma_at`.
    ///
    /// Each check is relative to `max(|analytic|, |fd|, 1)`. The tolerances loosen
    /// with derivative order: 1e-8, 1e-5, 5e-4.
    fn assert_sigma_derivs_match_fd(
        sigma_at: impl Fn(f64) -> f64,
        derivs_at: impl Fn(f64) -> (f64, f64, f64, f64),
    ) {
        let h = 1e-5;
        // Larger step for the nested third-derivative estimate.
        let h3 = 2e-3;
        let points = [-6.0, -3.5, -1.2, 0.0, 0.8, 2.1, 6.0];

        for &eta in &points {
            let (s, d1, d2, d3) = derivs_at(eta);
            let s_plus = sigma_at(eta + h);
            let s_minus = sigma_at(eta - h);

            let d1fd = (s_plus - s_minus) / (2.0 * h);
            let d2fd = (s_plus - 2.0 * s + s_minus) / (h * h);
            // Second-difference estimate of d2 at an arbitrary point, used to build d3.
            let d2_at = |x: f64| {
                let xp = sigma_at(x + h3);
                let xc = sigma_at(x);
                let xm = sigma_at(x - h3);
                (xp - 2.0 * xc + xm) / (h3 * h3)
            };
            let d3fd = (d2_at(eta + h3) - d2_at(eta - h3)) / (2.0 * h3);

            let d1_scale = d1.abs().max(d1fd.abs()).max(1.0);
            let d2_scale = d2.abs().max(d2fd.abs()).max(1.0);
            let d3_scale = d3.abs().max(d3fd.abs()).max(1.0);

            assert!((d1 - d1fd).abs() < 1e-8 * d1_scale);
            assert!((d2 - d2fd).abs() < 1e-5 * d2_scale);
            assert!((d3 - d3fd).abs() < 5e-4 * d3_scale);
        }
    }

    /// Exp-link analytic derivatives up to third order agree with finite differences.
    #[test]
    fn exp_sigma_derivatives_match_finite_difference() {
        assert_sigma_derivs_match_fd(
            exp_sigma_from_eta_scalar,
            exp_sigma_derivs_up_to_third_scalar,
        );
    }

    /// The fourth-order jet agrees with the third-order jet on d1..d3. Its d4 also
    /// matches a central difference of the analytic d3.
    #[test]
    fn exp_sigma_fourth_derivative_matches_finite_difference() {
        let h = 2e-3;
        let points = [-6.0, -3.0, -1.1, 0.0, 0.6, 1.9, 5.5];

        let d3_at = |x: f64| exp_sigma_derivs_up_to_third_scalar(x).3;
        for &eta in &points {
            let (_, d1_4, d2_4, d3_4, d4_4) = exp_sigma_derivs_up_to_fourth_scalar(eta);
            let (_, d1_3, d2_3, d3_3) = exp_sigma_derivs_up_to_third_scalar(eta);
            assert!((d1_4 - d1_3).abs() < 1e-12);
            assert!((d2_4 - d2_3).abs() < 1e-12);
            assert!((d3_4 - d3_3).abs() < 1e-12);

            let d4fd = (d3_at(eta + h) - d3_at(eta - h)) / (2.0 * h);
            let d4_scale = d4_4.abs().max(d4fd.abs()).max(1.0);
            assert!((d4_4 - d4fd).abs() < 5e-4 * d4_scale);
        }
    }

    /// The vectorised exp jet matches the scalar jet element-wise, including at the
    /// large-magnitude inputs ±701 (exp(701) is on the order of 1e304).
    #[test]
    fn exp_sigmavectorized_up_to_fourth_matches_scalar() {
        let eta = Array1::from_vec(vec![-701.0, -4.2, -1.4, -0.2, 0.4, 1.9, 3.1, 701.0]);
        let (s, d1, d2, d3, d4) = exp_sigma_derivs_up_to_fourth(eta.view());
        for i in 0..eta.len() {
            let (ss, d1s, d2s, d3s, d4s) = exp_sigma_derivs_up_to_fourth_scalar(eta[i]);
            assert!((s[i] - ss).abs() < 1e-12);
            assert!((d1[i] - d1s).abs() < 1e-12);
            assert!((d2[i] - d2s).abs() < 1e-12);
            assert!((d3[i] - d3s).abs() < 1e-12);
            assert!((d4[i] - d4s).abs() < 1e-12);
        }
    }

    /// The exp inverse link returns `ln(sigma)` for a positive finite `sigma`.
    #[test]
    fn exp_sigma_inverse_accepts_positive_sigma() {
        let eta = exp_sigma_eta_for_sigma_scalar(2.5);
        assert!(eta.is_finite());
        assert!((eta - 2.5_f64.ln()).abs() < 1e-12);
    }

    /// The exp inverse link panics on `sigma = 0`.
    #[test]
    #[should_panic(expected = "sigma must be positive")]
    fn exp_sigma_inverse_rejects_non_positive_sigma() {
        exp_sigma_eta_for_sigma_scalar(0.0);
    }

    /// `safe_exp` is a plain `exp`: it overflows to `+inf` and underflows to `0.0`
    /// rather than clamping.
    #[test]
    fn safe_exp_matches_native_exp_semantics() {
        assert!(safe_exp(0.0).is_finite());
        assert!(safe_exp(700.0).is_finite());
        assert!(safe_exp(-700.0).is_finite());
        assert!(safe_exp(1000.0).is_infinite());
        assert_eq!(safe_exp(-1000.0), 0.0);
        assert!(safe_exp(f64::MAX).is_infinite());
        assert_eq!(safe_exp(f64::MIN), 0.0);
        assert!((safe_exp(1.0) - 1.0_f64.exp()).abs() < 1e-15);
        assert!((safe_exp(-5.0) - (-5.0_f64).exp()).abs() < 1e-15);
    }

    /// Near the edges of the f64 exp range, every jet component equals `eta.exp()`
    /// bit-for-bit. 709 is just below the overflow threshold (about 709.78). At -745,
    /// exp is a subnormal just above underflow to zero.
    #[test]
    fn exp_sigma_derivatives_match_exact_exp_in_far_tails() {
        for &eta in &[709.0, -745.0] {
            let (sigma, d1, d2, d3, d4) = exp_sigma_derivs_up_to_fourth_scalar(eta);
            assert_eq!(sigma, eta.exp());
            assert_eq!(d1, sigma);
            assert_eq!(d2, sigma);
            assert_eq!(d3, sigma);
            assert_eq!(d4, sigma);
        }
    }

    /// For very negative `eta`, the logb link stays at or above the floor and finite.
    /// Consequently `1 / sigma^2` is bounded by `1 / floor^2`.
    #[test]
    fn logb_sigma_floor_bounds_below_for_arbitrarily_negative_eta() {
        for &eta in &[-1000.0, -100.0, -50.0, -10.0] {
            let sigma = logb_sigma_from_eta_scalar(eta);
            assert!(sigma >= LOGB_SIGMA_FLOOR);
            assert!(sigma.is_finite());
            let inv_s2 = (sigma * sigma).recip();
            assert!(inv_s2 <= LOGB_SIGMA_FLOOR.powi(-2) + 1e-12);
        }
    }

    /// For moderately large `eta`, the floor is negligible and logb agrees with the
    /// exp link to within 1% relative error.
    #[test]
    fn logb_sigma_recovers_exp_link_in_upper_regime() {
        for &eta in &[3.0, 5.0, 10.0] {
            let logb = logb_sigma_from_eta_scalar(eta);
            let pure_exp = exp_sigma_from_eta_scalar(eta);
            let rel_err = (logb - pure_exp).abs() / pure_exp;
            assert!(rel_err < 1e-2);
        }
    }

    /// All logb jets carry `floor + exp(eta)` as the value and `exp(eta)` for every
    /// derivative order.
    #[test]
    fn logb_sigma_jet_d1_through_d4_match_pure_exp_eta() {
        for &eta in &[-3.0_f64, 0.0, 2.0] {
            let s = eta.exp();
            let jet1 = logb_sigma_jet1_scalar(eta);
            let jet3 = logb_sigma_jet3_scalar(eta);
            let jet4 = logb_sigma_jet4_scalar(eta);
            assert!((jet1.sigma - (LOGB_SIGMA_FLOOR + s)).abs() < 1e-12);
            assert!((jet1.d1 - s).abs() < 1e-12);
            assert!((jet3.sigma - (LOGB_SIGMA_FLOOR + s)).abs() < 1e-12);
            assert!((jet3.d1 - s).abs() < 1e-12);
            assert!((jet3.d2 - s).abs() < 1e-12);
            assert!((jet3.d3 - s).abs() < 1e-12);
            assert!((jet4.d4 - s).abs() < 1e-12);
        }
    }

    /// Logb-link analytic derivatives up to third order agree with finite differences.
    #[test]
    fn logb_sigma_derivatives_match_finite_difference() {
        assert_sigma_derivs_match_fd(
            logb_sigma_from_eta_scalar,
            logb_sigma_derivs_up_to_third_scalar,
        );
    }

    /// sigma -> eta -> sigma round-trips through the logb inverse and forward links.
    /// The cases range from just above the floor to 1e6; tolerance is relative to
    /// `max(sigma, 1)`.
    #[test]
    fn logb_sigma_inverse_round_trip() {
        for &sigma in &[
            LOGB_SIGMA_FLOOR + 1e-3,
            LOGB_SIGMA_FLOOR + 0.5,
            1.0,
            10.0,
            1e6,
        ] {
            let eta = logb_sigma_eta_for_sigma_scalar(sigma);
            let recovered = logb_sigma_from_eta_scalar(eta);
            let scale = sigma.abs().max(1.0);
            assert!((recovered - sigma).abs() < 1e-10 * scale);
        }
    }

    /// The logb inverse link panics when `sigma` equals the floor exactly.
    #[test]
    #[should_panic(expected = "sigma must exceed LOGB_SIGMA_FLOOR")]
    fn logb_sigma_inverse_rejects_sigma_at_floor() {
        logb_sigma_eta_for_sigma_scalar(LOGB_SIGMA_FLOOR);
    }

    /// The vectorised logb jet matches the scalar jet element-wise, including at the
    /// large-magnitude inputs ±701.
    #[test]
    fn logb_sigma_vectorized_matches_scalar() {
        let eta = Array1::from_vec(vec![-701.0, -4.2, -1.4, -0.2, 0.4, 1.9, 3.1, 701.0]);
        let (s, d1, d2, d3, d4) = logb_sigma_derivs_up_to_fourth(eta.view());
        for i in 0..eta.len() {
            let (ss, d1s, d2s, d3s, d4s) = logb_sigma_derivs_up_to_fourth_scalar(eta[i]);
            assert!((s[i] - ss).abs() < 1e-12);
            assert!((d1[i] - d1s).abs() < 1e-12);
            assert!((d2[i] - d2s).abs() < 1e-12);
            assert!((d3[i] - d3s).abs() < 1e-12);
            assert!((d4[i] - d4s).abs() < 1e-12);
        }
    }
}