1use numra_core::uncertainty::Uncertain;
39use numra_core::Scalar;
40
41use crate::error::SolverError;
42use crate::problem::OdeSystem;
43use crate::sensitivity::solve_forward_sensitivity_with;
44use crate::solver::{Solver, SolverOptions, SolverResult, SolverStats};
45
/// Strategy used by [`solve_with_uncertainty`] to propagate parameter uncertainty
/// into the state trajectory.
#[derive(Clone, Debug)]
pub enum UncertaintyMode {
    /// Linearised propagation through forward sensitivities (see [`solve_trajectory`]).
    Trajectory,
    /// Sampling-based propagation (see [`solve_monte_carlo`]).
    MonteCarlo {
        /// Number of random parameter draws to solve.
        n_samples: usize,
    },
}
61
62#[derive(Clone, Debug)]
64pub struct UncertainParam<S: Scalar> {
65 pub name: String,
67 pub nominal: S,
69 pub std: S,
71}
72
73impl<S: Scalar> UncertainParam<S> {
74 pub fn new(name: impl Into<String>, nominal: S, std: S) -> Self {
76 Self {
77 name: name.into(),
78 nominal,
79 std,
80 }
81 }
82
83 pub fn from_uncertain(name: impl Into<String>, u: Uncertain<S>) -> Self {
85 Self {
86 name: name.into(),
87 nominal: u.mean,
88 std: u.std(),
89 }
90 }
91
92 pub fn variance(&self) -> S {
94 self.std * self.std
95 }
96}
97
98#[derive(Clone, Debug)]
100pub struct UncertainSolverResult<S: Scalar> {
101 pub result: SolverResult<S>,
103 pub sigma: Vec<S>,
106 pub sensitivities: Option<Vec<Vec<S>>>,
110 pub params: Vec<UncertainParam<S>>,
112}
113
114impl<S: Scalar> UncertainSolverResult<S> {
115 pub fn sigma_at(&self, i: usize, j: usize) -> S {
117 self.sigma[i * self.result.dim + j]
118 }
119
120 pub fn uncertain_at(&self, i: usize, j: usize) -> Uncertain<S> {
122 let mean = self.result.y_at(i)[j];
123 let std = self.sigma_at(i, j);
124 Uncertain::from_std(mean, std)
125 }
126
127 pub fn sensitivity_at(&self, i: usize, j: usize, k: usize) -> Option<S> {
130 self.sensitivities.as_ref().map(|sens| {
131 let n_params = self.params.len();
132 sens[i][j * n_params + k]
133 })
134 }
135
136 pub fn len(&self) -> usize {
138 self.result.len()
139 }
140
141 pub fn is_empty(&self) -> bool {
143 self.result.is_empty()
144 }
145}
146
147pub fn solve_trajectory<Sol, S, F>(
170 model: F,
171 y0: &[S],
172 t0: S,
173 tf: S,
174 params: &[UncertainParam<S>],
175 options: &SolverOptions<S>,
176) -> Result<UncertainSolverResult<S>, SolverError>
177where
178 S: Scalar,
179 Sol: Solver<S>,
180 F: Fn(S, &[S], &mut [S], &[S]),
181{
182 let n_states = y0.len();
183 let n_params = params.len();
184 let nominal_params: Vec<S> = params.iter().map(|p| p.nominal).collect();
185 let variances: Vec<S> = params.iter().map(|p| p.variance()).collect();
186
187 let rhs = move |t: S, y: &[S], p: &[S], dydt: &mut [S]| {
190 model(t, y, dydt, p);
191 };
192
193 let sens =
194 solve_forward_sensitivity_with::<Sol, S, _>(rhs, y0, &nominal_params, t0, tf, options)?;
195
196 if !sens.success {
197 return Err(SolverError::Other(sens.message));
198 }
199
200 let n_times = sens.len();
208 let mut sens_out = Vec::with_capacity(n_times);
209 let mut sigma_out = Vec::with_capacity(n_times * n_states);
210
211 for i in 0..n_times {
212 let block = sens.sensitivity_at(i);
213 let mut row_major = vec![S::ZERO; n_states * n_params];
214 for j in 0..n_states {
215 for k in 0..n_params {
216 row_major[j * n_params + k] = block[k * n_states + j];
217 }
218 }
219 sens_out.push(row_major);
220
221 for j in 0..n_states {
223 let mut var_j = S::ZERO;
224 for k in 0..n_params {
225 let dydp = block[k * n_states + j];
226 var_j = var_j + dydp * dydp * variances[k];
227 }
228 sigma_out.push(var_j.sqrt());
229 }
230 }
231
232 let nominal_result = SolverResult {
233 t: sens.t,
234 y: sens.y,
235 dim: n_states,
236 stats: sens.stats,
237 success: true,
238 message: String::new(),
239 events: Vec::new(),
240 terminated_by_event: false,
241 dense_output: None,
242 };
243
244 Ok(UncertainSolverResult {
245 result: nominal_result,
246 sigma: sigma_out,
247 sensitivities: Some(sens_out),
248 params: params.to_vec(),
249 })
250}
251
252pub fn solve_monte_carlo<Sol, S, F>(
280 model: F,
281 y0: &[S],
282 t0: S,
283 tf: S,
284 params: &[UncertainParam<S>],
285 n_samples: usize,
286 options: &SolverOptions<S>,
287 seed: u64,
288) -> Result<UncertainSolverResult<S>, SolverError>
289where
290 S: Scalar,
291 Sol: Solver<S>,
292 F: Fn(S, &[S], &mut [S], &[S]) + Send + Sync,
293{
294 let n_states = y0.len();
295 let n_params = params.len();
296
297 let nominal_params: Vec<S> = params.iter().map(|p| p.nominal).collect();
299 let nominal_sys = ParameterizedWrapper {
300 model: &model,
301 params: nominal_params.clone(),
302 n_dim: n_states,
303 };
304
305 let nominal_result = Sol::solve(&nominal_sys, t0, tf, y0, options)?;
306 if !nominal_result.success {
307 return Err(SolverError::Other(nominal_result.message));
308 }
309
310 let n_times = nominal_result.len();
311
312 let mut sum_final = vec![S::ZERO; n_states];
316 let mut sum_sq_final = vec![S::ZERO; n_states];
317 let mut n_success: usize = 0;
318
319 let mut rng_state = seed;
320
321 for _ in 0..n_samples {
322 let mut p_sample = Vec::with_capacity(n_params);
324 for param in params {
325 let z = box_muller_sample(&mut rng_state);
326 let p_val = param.nominal + param.std * S::from_f64(z);
327 p_sample.push(p_val);
328 }
329
330 let sample_sys = ParameterizedWrapper {
331 model: &model,
332 params: p_sample,
333 n_dim: n_states,
334 };
335
336 match Sol::solve(&sample_sys, t0, tf, y0, options) {
337 Ok(result) if result.success => {
338 if let Some(y_final) = result.y_final() {
339 n_success += 1;
340 for j in 0..n_states {
341 sum_final[j] = sum_final[j] + y_final[j];
342 sum_sq_final[j] = sum_sq_final[j] + y_final[j] * y_final[j];
343 }
344 }
345 }
346 _ => {
347 }
349 }
350 }
351
352 if n_success < 2 {
353 return Err(SolverError::Other(
354 "Monte Carlo: fewer than 2 samples succeeded".to_string(),
355 ));
356 }
357
358 let n_s = S::from_usize(n_success);
360 let mut sigma_final = Vec::with_capacity(n_states);
361 for j in 0..n_states {
362 let mean = sum_final[j] / n_s;
363 let var = (sum_sq_final[j] / n_s - mean * mean) * n_s / (n_s - S::ONE);
364 let std = if var > S::ZERO { var.sqrt() } else { S::ZERO };
365 sigma_final.push(std);
366 }
367
368 let mut sigma = Vec::with_capacity(n_times * n_states);
372 for i in 0..n_times {
373 let frac = if n_times > 1 {
374 S::from_usize(i) / S::from_usize(n_times - 1)
375 } else {
376 S::ONE
377 };
378 for j in 0..n_states {
379 sigma.push(sigma_final[j] * frac);
380 }
381 }
382
383 let mc_result = SolverResult {
384 t: nominal_result.t.clone(),
385 y: nominal_result.y.clone(),
386 dim: n_states,
387 stats: SolverStats::new(),
388 success: true,
389 message: format!("{}/{} samples succeeded", n_success, n_samples),
390 events: Vec::new(),
391 terminated_by_event: false,
392 dense_output: None,
393 };
394
395 Ok(UncertainSolverResult {
396 result: mc_result,
397 sigma,
398 sensitivities: None,
399 params: params.to_vec(),
400 })
401}
402
403pub fn solve_with_uncertainty<Sol, S, F>(
432 model: F,
433 y0: &[S],
434 t0: S,
435 tf: S,
436 params: &[UncertainParam<S>],
437 mode: &UncertaintyMode,
438 options: &SolverOptions<S>,
439 seed: Option<u64>,
440) -> Result<UncertainSolverResult<S>, SolverError>
441where
442 S: Scalar,
443 Sol: Solver<S>,
444 F: Fn(S, &[S], &mut [S], &[S]) + Send + Sync,
445{
446 match mode {
447 UncertaintyMode::Trajectory => {
448 solve_trajectory::<Sol, S, F>(model, y0, t0, tf, params, options)
449 }
450 UncertaintyMode::MonteCarlo { n_samples } => solve_monte_carlo::<Sol, S, F>(
451 model,
452 y0,
453 t0,
454 tf,
455 params,
456 *n_samples,
457 options,
458 seed.unwrap_or(42),
459 ),
460 }
461}
462
463struct ParameterizedWrapper<'a, S: Scalar, F> {
470 model: &'a F,
471 params: Vec<S>,
472 n_dim: usize,
473}
474
475impl<S: Scalar, F> OdeSystem<S> for ParameterizedWrapper<'_, S, F>
476where
477 F: Fn(S, &[S], &mut [S], &[S]) + Send + Sync,
478{
479 fn dim(&self) -> usize {
480 self.n_dim
481 }
482
483 fn rhs(&self, t: S, y: &[S], dydt: &mut [S]) {
484 (self.model)(t, y, dydt, &self.params);
485 }
486}
487
/// Advances a SplitMix64 generator held in `state` and returns its next 64-bit output.
fn splitmix64(state: &mut u64) -> u64 {
    const GOLDEN_GAMMA: u64 = 0x9e37_79b9_7f4a_7c15;
    const MIX_1: u64 = 0xbf58_476d_1ce4_e5b9;
    const MIX_2: u64 = 0x94d0_49bb_1331_11eb;

    let xorshift = |x: u64, shift: u32| x ^ (x >> shift);

    *state = state.wrapping_add(GOLDEN_GAMMA);
    let z = xorshift(*state, 30).wrapping_mul(MIX_1);
    let z = xorshift(z, 27).wrapping_mul(MIX_2);
    xorshift(z, 31)
}

/// Draws one standard-normal sample using the Box–Muller transform (cosine branch only).
///
/// Each attempt consumes two SplitMix64 outputs. An attempt is retried when the first
/// uniform is `<= 1e-15`, which keeps `ln(u1)` finite.
fn box_muller_sample(state: &mut u64) -> f64 {
    // The top 53 bits of a u64, scaled by 2^-53, give a uniform in [0, 1).
    const TWO_POW_53: f64 = (1u64 << 53) as f64;
    let mut uniform = || ((splitmix64(state) >> 11) as f64) / TWO_POW_53;

    loop {
        let u1 = uniform();
        let u2 = uniform();
        if u1 <= 1e-15 {
            continue;
        }
        let radius = (-2.0 * u1.ln()).sqrt();
        let angle = 2.0 * core::f64::consts::PI * u2;
        return radius * angle.cos();
    }
}
509
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DoPri5, Radau5};

    /// dy/dt = -k·y with k ~ N(0.5, 0.05²): the analytic sensitivity is dy/dk = -t·e^{-kt},
    /// so the linearised sigma is t·e^{-kt}·sigma_k.
    #[test]
    fn test_trajectory_exponential_decay() {
        let params = vec![UncertainParam::new("k", 0.5, 0.05)];

        let result = solve_with_uncertainty::<DoPri5, f64, _>(
            |_t, y, dydt, p| {
                dydt[0] = -p[0] * y[0];
            },
            &[1.0],
            0.0,
            5.0,
            &params,
            &UncertaintyMode::Trajectory,
            &SolverOptions::default().rtol(1e-8).atol(1e-10),
            None,
        )
        .expect("solve_with_uncertainty failed");

        assert!(result.result.success);
        assert!(!result.is_empty());

        let n = result.len();
        let t_final = result.result.t[n - 1];
        let k = 0.5;
        let sigma_k = 0.05;

        let exact_sigma = t_final * (-k * t_final).exp() * sigma_k;
        let computed_sigma = result.sigma_at(n - 1, 0);

        assert!(
            (computed_sigma - exact_sigma).abs() < 0.001,
            "sigma: computed={}, exact={}, err={}",
            computed_sigma,
            exact_sigma,
            (computed_sigma - exact_sigma).abs()
        );

        assert!(result.sensitivities.is_some());
        let dydp = result.sensitivity_at(n - 1, 0, 0).unwrap();
        let exact_dydp = -t_final * (-k * t_final).exp();
        assert!(
            (dydp - exact_dydp).abs() < 0.001,
            "dy/dk: computed={}, exact={}, err={}",
            dydp,
            exact_dydp,
            (dydp - exact_dydp).abs()
        );
    }

    /// Two uncertain parameters: dy/dt = -a·y + b relaxes towards b/a = 2.
    #[test]
    fn test_trajectory_two_params() {
        let params = vec![
            UncertainParam::new("a", 1.0, 0.1),
            UncertainParam::new("b", 2.0, 0.2),
        ];

        let result = solve_trajectory::<DoPri5, f64, _>(
            |_t, y, dydt, p| {
                dydt[0] = -p[0] * y[0] + p[1];
            },
            &[1.0],
            0.0,
            3.0,
            &params,
            &SolverOptions::default().rtol(1e-8).atol(1e-10),
        )
        .expect("solve_trajectory failed");

        assert!(result.result.success);

        let n = result.len();
        let y_final = result.result.y_at(n - 1)[0];
        assert!(
            (y_final - 2.0).abs() < 0.1,
            "y(3) = {}, expected near 2.0",
            y_final
        );

        let sigma_final = result.sigma_at(n - 1, 0);
        assert!(
            sigma_final > 0.0,
            "sigma should be positive: {}",
            sigma_final
        );
    }

    /// The linearised and sampled sigmas agree for a mildly nonlinear problem.
    #[test]
    fn test_trajectory_vs_monte_carlo() {
        let params = vec![UncertainParam::new("k", 0.5, 0.05)];

        let traj_result = solve_with_uncertainty::<DoPri5, f64, _>(
            |_t, y, dydt, p| {
                dydt[0] = -p[0] * y[0];
            },
            &[1.0],
            0.0,
            2.0,
            &params,
            &UncertaintyMode::Trajectory,
            &SolverOptions::default().rtol(1e-8).atol(1e-10),
            None,
        )
        .expect("trajectory failed");

        let mc_result = solve_with_uncertainty::<DoPri5, f64, _>(
            |_t, y, dydt, p| {
                dydt[0] = -p[0] * y[0];
            },
            &[1.0],
            0.0,
            2.0,
            &params,
            &UncertaintyMode::MonteCarlo { n_samples: 5000 },
            &SolverOptions::default().rtol(1e-8).atol(1e-10),
            Some(12345),
        )
        .expect("monte carlo failed");

        let n_traj = traj_result.len();
        let n_mc = mc_result.len();
        let sigma_traj = traj_result.sigma_at(n_traj - 1, 0);
        let sigma_mc = mc_result.sigma_at(n_mc - 1, 0);

        let rel_diff = (sigma_traj - sigma_mc).abs() / sigma_traj;
        assert!(
            rel_diff < 0.15,
            "Trajectory sigma={}, MC sigma={}, rel_diff={}",
            sigma_traj,
            sigma_mc,
            rel_diff
        );
    }

    /// Uncertainty propagation works unchanged with an implicit (stiff) solver.
    #[test]
    fn test_composability_stiff_solver() {
        let params = vec![UncertainParam::new("k", 50.0, 5.0)];

        let result = solve_trajectory::<Radau5, f64, _>(
            |_t, y, dydt, p| {
                dydt[0] = -p[0] * y[0];
            },
            &[1.0],
            0.0,
            0.2,
            &params,
            &SolverOptions::default().rtol(1e-4).atol(1e-7).h0(1e-4),
        )
        .expect("stiff solve failed");

        assert!(result.result.success);
        assert!(!result.is_empty());

        let n = result.len();
        let sigma = result.sigma_at(n - 1, 0);
        assert!(sigma >= 0.0, "sigma should be non-negative: {}", sigma);
    }

    /// Multi-state, multi-parameter system: checks sigma per state and Jacobian shape.
    #[test]
    fn test_trajectory_lotka_volterra() {
        let params = vec![
            UncertainParam::new("alpha", 1.0, 0.1),
            UncertainParam::new("beta", 0.1, 0.01),
            UncertainParam::new("delta", 0.075, 0.005),
            UncertainParam::new("gamma", 1.5, 0.1),
        ];

        let result = solve_trajectory::<DoPri5, f64, _>(
            |_t, y, dydt, p| {
                let x = y[0];
                let yy = y[1];
                dydt[0] = p[0] * x - p[1] * x * yy;
                dydt[1] = p[2] * x * yy - p[3] * yy;
            },
            &[10.0, 5.0],
            0.0,
            5.0,
            &params,
            &SolverOptions::default().rtol(1e-8).atol(1e-10),
        )
        .expect("Lotka-Volterra solve failed");

        assert!(result.result.success);
        assert_eq!(result.result.dim, 2);

        let n = result.len();
        let sigma_prey = result.sigma_at(n - 1, 0);
        let sigma_pred = result.sigma_at(n - 1, 1);
        assert!(sigma_prey > 0.0, "prey sigma should be positive");
        assert!(sigma_pred > 0.0, "predator sigma should be positive");

        // 2 states × 4 parameters per output time.
        let sens = result.sensitivities.as_ref().unwrap();
        assert_eq!(sens[n - 1].len(), 2 * 4);
    }

    #[test]
    fn test_uncertain_param_from_uncertain() {
        let u = Uncertain::from_std(5.0, 0.5);
        let p = UncertainParam::from_uncertain("x", u);
        assert!((p.nominal - 5.0).abs() < 1e-10);
        assert!((p.std - 0.5).abs() < 1e-10);
        assert!((p.variance() - 0.25).abs() < 1e-10);
    }

    /// With zero parameter spread, the propagated spread is zero everywhere.
    #[test]
    fn test_zero_uncertainty() {
        let params = vec![UncertainParam::new("k", 0.5, 0.0)];

        let result = solve_trajectory::<DoPri5, f64, _>(
            |_t, y, dydt, p| {
                dydt[0] = -p[0] * y[0];
            },
            &[1.0],
            0.0,
            2.0,
            &params,
            &SolverOptions::default(),
        )
        .expect("solve failed");

        for i in 0..result.len() {
            let sigma = result.sigma_at(i, 0);
            assert!(
                sigma.abs() < 1e-10,
                "sigma should be ~0 at t={}: got {}",
                result.result.t[i],
                sigma
            );
        }
    }

    #[test]
    fn test_uncertain_at() {
        let params = vec![UncertainParam::new("k", 0.5, 0.05)];

        let result = solve_trajectory::<DoPri5, f64, _>(
            |_t, y, dydt, p| {
                dydt[0] = -p[0] * y[0];
            },
            &[1.0],
            0.0,
            1.0,
            &params,
            &SolverOptions::default().rtol(1e-8),
        )
        .expect("solve failed");

        let n = result.len();
        let u = result.uncertain_at(n - 1, 0);
        assert!(u.mean > 0.0);
        assert!(u.variance > 0.0);
        assert!((u.std() - result.sigma_at(n - 1, 0)).abs() < 1e-14);
    }

    /// A fixed seed must give bit-identical Monte Carlo statistics.
    #[test]
    fn test_monte_carlo_reproducible_with_seed() {
        let params = vec![UncertainParam::new("k", 0.5, 0.05)];
        let run = || {
            solve_monte_carlo::<DoPri5, f64, _>(
                |_t, y, dydt, p| {
                    dydt[0] = -p[0] * y[0];
                },
                &[1.0],
                0.0,
                1.0,
                &params,
                50,
                &SolverOptions::default().rtol(1e-6).atol(1e-9),
                7,
            )
            .expect("monte carlo failed")
        };

        let a = run();
        let b = run();
        assert_eq!(a.sigma, b.sigma);
        assert!(a.sensitivities.is_none());
    }

    /// A sample variance needs at least two successful runs.
    #[test]
    fn test_monte_carlo_requires_two_samples() {
        let params = vec![UncertainParam::new("k", 0.5, 0.05)];
        let result = solve_monte_carlo::<DoPri5, f64, _>(
            |_t, y, dydt, p| {
                dydt[0] = -p[0] * y[0];
            },
            &[1.0],
            0.0,
            1.0,
            &params,
            1,
            &SolverOptions::default(),
            1,
        );
        assert!(result.is_err());
    }
}