1use std::sync::{Arc, Mutex};
2
3use crate::jit::JitBackendType;
4use crate::{
5 error::DiffsolRtError, host_array::HostArray,
6 initial_condition_options::InitialConditionSolverOptions, linear_solver_type::LinearSolverType,
7 matrix_type::MatrixType, ode_options::OdeSolverOptions, ode_solver_type::OdeSolverType,
8 scalar_type::ScalarType, solution_wrapper::SolutionWrapper, solve::Solve,
9};
10
11pub struct Ode {
12 pub(crate) solve: Box<dyn Solve>,
13 code: String,
14 scalar_type: ScalarType,
15 jit_backend: Option<JitBackendType>,
16 linear_solver: LinearSolverType,
17 ode_solver: OdeSolverType,
18}
19
20unsafe impl Send for Ode {}
21unsafe impl Sync for Ode {}
22
23#[derive(Clone)]
24pub struct OdeWrapper(Arc<Mutex<Ode>>);
25
26impl OdeWrapper {
27 fn guard(&self) -> Result<std::sync::MutexGuard<'_, Ode>, DiffsolRtError> {
28 self.0.lock().map_err(|_| {
29 DiffsolRtError::from(diffsol::error::DiffsolError::Other(
30 "Failed to acquire lock on ODE solver".to_string(),
31 ))
32 })
33 }
34}
35
36impl OdeWrapper {
37 fn build(
38 code: String,
39 scalar_type: ScalarType,
40 solve: Box<dyn Solve>,
41 jit_backend: Option<JitBackendType>,
42 linear_solver: LinearSolverType,
43 ode_solver: OdeSolverType,
44 ) -> Result<Self, DiffsolRtError> {
45 solve.check(linear_solver)?;
46 Ok(OdeWrapper(Arc::new(Mutex::new(Ode {
47 code,
48 scalar_type,
49 solve,
50 jit_backend,
51 linear_solver,
52 ode_solver,
53 }))))
54 }
55
56 #[cfg(feature = "external")]
58 pub fn new_external(
59 rhs_state_deps: Vec<(usize, usize)>,
60 rhs_input_deps: Vec<(usize, usize)>,
61 mass_state_deps: Vec<(usize, usize)>,
62 scalar_type: ScalarType,
63 matrix_type: MatrixType,
64 linear_solver: LinearSolverType,
65 ode_solver: OdeSolverType,
66 ) -> Result<Self, DiffsolRtError> {
67 let solve = crate::solve::solve_factory_external(
68 rhs_state_deps,
69 rhs_input_deps,
70 mass_state_deps,
71 matrix_type,
72 scalar_type,
73 )?;
74 Self::build(
75 String::new(),
76 scalar_type,
77 solve,
78 None,
79 linear_solver,
80 ode_solver,
81 )
82 }
83
84 #[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
86 pub fn new_jit(
87 code: &str,
88 jit_backend: JitBackendType,
89 scalar_type: ScalarType,
90 matrix_type: MatrixType,
91 linear_solver: LinearSolverType,
92 ode_solver: OdeSolverType,
93 ) -> Result<Self, DiffsolRtError> {
94 let solve = crate::solve::solve_factory_jit(code, jit_backend, matrix_type, scalar_type)?;
95 Self::build(
96 code.to_owned(),
97 scalar_type,
98 solve,
99 Some(jit_backend),
100 linear_solver,
101 ode_solver,
102 )
103 }
104
105 pub fn get_matrix_type(&self) -> Result<MatrixType, DiffsolRtError> {
107 Ok(self.guard()?.solve.matrix_type())
108 }
109
110 pub fn get_nstates(&self) -> Result<usize, DiffsolRtError> {
111 Ok(self.guard()?.solve.nstates())
112 }
113
114 pub fn get_nparams(&self) -> Result<usize, DiffsolRtError> {
115 Ok(self.guard()?.solve.nparams())
116 }
117
118 pub fn get_nout(&self) -> Result<usize, DiffsolRtError> {
119 Ok(self.guard()?.solve.nout())
120 }
121
122 pub fn has_stop(&self) -> Result<bool, DiffsolRtError> {
123 Ok(self.guard()?.solve.has_stop())
124 }
125
126 pub fn get_ode_solver(&self) -> Result<OdeSolverType, DiffsolRtError> {
128 Ok(self.guard()?.ode_solver)
129 }
130
131 pub fn set_ode_solver(&self, value: OdeSolverType) -> Result<(), DiffsolRtError> {
132 self.guard()?.ode_solver = value;
133 Ok(())
134 }
135
136 pub fn get_linear_solver(&self) -> Result<LinearSolverType, DiffsolRtError> {
139 Ok(self.guard()?.linear_solver)
140 }
141
142 pub fn set_linear_solver(&self, value: LinearSolverType) -> Result<(), DiffsolRtError> {
143 self.guard()?.solve.check(value)?;
144 self.guard()?.linear_solver = value;
145 Ok(())
146 }
147
148 pub fn get_rtol(&self) -> Result<f64, DiffsolRtError> {
150 Ok(self.guard()?.solve.rtol())
151 }
152
153 pub fn set_rtol(&self, value: f64) -> Result<(), DiffsolRtError> {
154 self.guard()?.solve.set_rtol(value);
155 Ok(())
156 }
157
158 pub fn get_atol(&self) -> Result<f64, DiffsolRtError> {
160 Ok(self.guard()?.solve.atol())
161 }
162
163 pub fn set_atol(&self, value: f64) -> Result<(), DiffsolRtError> {
164 self.guard()?.solve.set_atol(value);
165 Ok(())
166 }
167
168 pub fn get_code(&self) -> Result<String, DiffsolRtError> {
169 Ok(self.guard()?.code.clone())
170 }
171
172 pub fn get_scalar_type(&self) -> Result<ScalarType, DiffsolRtError> {
173 Ok(self.guard()?.scalar_type)
174 }
175
176 pub fn get_jit_backend(&self) -> Result<Option<JitBackendType>, DiffsolRtError> {
177 Ok(self.guard()?.jit_backend)
178 }
179
180 pub fn get_ic_options(&self) -> InitialConditionSolverOptions {
181 InitialConditionSolverOptions::new(self.0.clone())
182 }
183
184 pub fn get_options(&self) -> OdeSolverOptions {
185 OdeSolverOptions::new(self.0.clone())
186 }
187
188 pub fn y0(&self, params: HostArray) -> Result<HostArray, DiffsolRtError> {
190 let mut self_guard = self.guard()?;
191 self_guard.solve.y0(params.as_slice()?)
192 }
193
194 pub fn rhs(
196 &self,
197 params: HostArray,
198 t: f64,
199 y: HostArray,
200 ) -> Result<HostArray, DiffsolRtError> {
201 let mut self_guard = self.guard()?;
202 self_guard.solve.rhs(params.as_slice()?, t, y.as_slice()?)
203 }
204
205 pub fn rhs_jac_mul(
207 &self,
208 params: HostArray,
209 t: f64,
210 y: HostArray,
211 v: HostArray,
212 ) -> Result<HostArray, DiffsolRtError> {
213 let mut self_guard = self.guard()?;
214 self_guard
215 .solve
216 .rhs_jac_mul(params.as_slice()?, t, y.as_slice()?, v.as_slice()?)
217 }
218
219 #[allow(clippy::type_complexity)]
237 pub fn solve(
238 &self,
239 params: HostArray,
240 final_time: f64,
241 ) -> Result<SolutionWrapper, DiffsolRtError> {
242 let mut self_guard = self.guard()?;
243 let params = params.as_slice()?;
244 let linear_solver = self_guard.linear_solver;
245 let method = self_guard.ode_solver;
246 let solution = self_guard
247 .solve
248 .solve(method, linear_solver, params, final_time)?;
249 Ok(SolutionWrapper::new(solution))
250 }
251
252 pub fn solve_hybrid(
255 &self,
256 params: HostArray,
257 final_time: f64,
258 ) -> Result<SolutionWrapper, DiffsolRtError> {
259 let mut self_guard = self.guard()?;
260 let params = params.as_slice()?;
261 let linear_solver = self_guard.linear_solver;
262 let method = self_guard.ode_solver;
263 let solution = self_guard
264 .solve
265 .solve_hybrid(method, linear_solver, params, final_time)?;
266 Ok(SolutionWrapper::new(solution))
267 }
268
269 pub fn solve_dense(
283 &self,
284 params: HostArray,
285 t_eval: HostArray,
286 ) -> Result<SolutionWrapper, DiffsolRtError> {
287 let mut self_guard = self.guard()?;
288 let params = params.as_slice()?;
289 let t_eval = t_eval.as_slice()?;
290 let linear_solver = self_guard.linear_solver;
291 let method = self_guard.ode_solver;
292 let solution = self_guard
293 .solve
294 .solve_dense(method, linear_solver, params, t_eval)?;
295 Ok(SolutionWrapper::new(solution))
296 }
297
298 pub fn solve_hybrid_dense(
302 &self,
303 params: HostArray,
304 t_eval: HostArray,
305 ) -> Result<SolutionWrapper, DiffsolRtError> {
306 let mut self_guard = self.guard()?;
307 let params = params.as_slice()?;
308 let t_eval = t_eval.as_slice()?;
309 let linear_solver = self_guard.linear_solver;
310 let method = self_guard.ode_solver;
311 let solution =
312 self_guard
313 .solve
314 .solve_hybrid_dense(method, linear_solver, params, t_eval)?;
315 Ok(SolutionWrapper::new(solution))
316 }
317
318 #[allow(clippy::type_complexity)]
331 pub fn solve_fwd_sens(
332 &self,
333 params: HostArray,
334 t_eval: HostArray,
335 ) -> Result<SolutionWrapper, DiffsolRtError> {
336 let mut self_guard = self.guard()?;
337 let params = params.as_slice()?;
338 let t_eval = t_eval.as_slice()?;
339 let linear_solver = self_guard.linear_solver;
340 let method = self_guard.ode_solver;
341 let solution = self_guard
342 .solve
343 .solve_fwd_sens(method, linear_solver, params, t_eval)?;
344 Ok(SolutionWrapper::new(solution))
345 }
346
347 #[allow(clippy::type_complexity)]
351 pub fn solve_hybrid_fwd_sens(
352 &self,
353 params: HostArray,
354 t_eval: HostArray,
355 ) -> Result<SolutionWrapper, DiffsolRtError> {
356 let mut self_guard = self.guard()?;
357 let params = params.as_slice()?;
358 let t_eval = t_eval.as_slice()?;
359 let linear_solver = self_guard.linear_solver;
360 let method = self_guard.ode_solver;
361 let solution =
362 self_guard
363 .solve
364 .solve_hybrid_fwd_sens(method, linear_solver, params, t_eval)?;
365 Ok(SolutionWrapper::new(solution))
366 }
367
368 #[allow(clippy::type_complexity)]
373 pub fn solve_sum_squares_adj(
374 &self,
375 params: HostArray,
376 data: HostArray,
377 t_eval: HostArray,
378 ) -> Result<(f64, HostArray), DiffsolRtError> {
379 let mut self_guard = self.guard()?;
380 let linear_solver = self_guard.linear_solver;
381 let ode_solver = self_guard.ode_solver;
382
383 self_guard.solve.solve_sum_squares_adj(
384 ode_solver,
385 linear_solver,
386 ode_solver,
387 linear_solver,
388 params.as_slice()?,
389 data,
390 t_eval.as_slice()?,
391 )
392 }
393}
394
#[cfg(all(test, feature = "diffsl-external-f64"))]
mod tests {
    //! Tests for [`OdeWrapper`] built from the externally linked logistic model.

    use crate::host_array::FromHostArray;
    use crate::linear_solver_type::LinearSolverType;
    use crate::scalar_type::ScalarType;
    use crate::test_support::{
        assert_close, assert_solution_tail, logistic_state, logistic_state_dr, mass_state_deps,
        rhs_input_deps, rhs_state_deps, vector_host, ASSERT_TOL, LOGISTIC_X0,
    };

    use super::*;

    /// Tolerance for comparing integrated states against closed-form values.
    const SOLUTION_TOL: f64 = 5e-4;

    /// Value used for both `rtol` and `atol` in the accuracy tests.
    const TIGHT_TOL: f64 = 1e-8;

    fn all_ode_solvers() -> [OdeSolverType; 4] {
        [
            OdeSolverType::Bdf,
            OdeSolverType::Esdirk34,
            OdeSolverType::TrBdf2,
            OdeSolverType::Tsit45,
        ]
    }

    /// Builds an f64 external-model solver with the default linear solver.
    fn make_ode(matrix_type: MatrixType, ode_solver: OdeSolverType) -> OdeWrapper {
        OdeWrapper::new_external(
            rhs_state_deps(),
            rhs_input_deps(),
            mass_state_deps(),
            ScalarType::F64,
            matrix_type,
            LinearSolverType::Default,
            ode_solver,
        )
        .unwrap()
    }

    /// Like [`make_ode`], with tight relative and absolute tolerances applied.
    fn make_tight_ode(matrix_type: MatrixType, ode_solver: OdeSolverType) -> OdeWrapper {
        let ode = make_ode(matrix_type, ode_solver);
        ode.set_rtol(TIGHT_TOL).unwrap();
        ode.set_atol(TIGHT_TOL).unwrap();
        ode
    }

    fn assert_runtime_dispatch(matrix_type: MatrixType) {
        let ode = make_ode(matrix_type, OdeSolverType::Bdf);
        assert_eq!(ode.get_matrix_type().unwrap(), matrix_type);
        assert_eq!(ode.get_nstates().unwrap(), 1);
        assert_eq!(ode.get_nparams().unwrap(), 1);
        assert_eq!(ode.get_nout().unwrap(), 1);
        assert!(ode.has_stop().unwrap());

        let y0 = ode.y0(vector_host(&[2.0])).unwrap();
        assert_eq!(Vec::<f64>::from_host_array(y0).unwrap(), vec![LOGISTIC_X0]);

        let rhs = ode
            .rhs(vector_host(&[2.0]), 0.0, vector_host(&[0.25]))
            .unwrap();
        assert_close(
            Vec::<f64>::from_host_array(rhs).unwrap()[0],
            0.375,
            ASSERT_TOL,
            "rhs(0.25)",
        );

        let rhs_jac_mul = ode
            .rhs_jac_mul(
                vector_host(&[2.0]),
                0.0,
                vector_host(&[0.25]),
                vector_host(&[3.0]),
            )
            .unwrap();
        assert_close(
            Vec::<f64>::from_host_array(rhs_jac_mul).unwrap()[0],
            3.0,
            ASSERT_TOL,
            "rhs_jac_mul(0.25, 3.0)",
        );
    }

    fn assert_solver_dense_solution(matrix_type: MatrixType, ode_solver: OdeSolverType) {
        let ode = make_tight_ode(matrix_type, ode_solver);

        let t_eval = [0.25, 0.5, 1.0];
        let solution = ode
            .solve_dense(vector_host(&[2.0]), vector_host(&t_eval))
            .unwrap();

        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, SOLUTION_TOL);
    }

    /// Time at which the external model's stop condition fires for r = 2.
    fn hybrid_root_time() -> f64 {
        0.5 * 9.0_f64.ln()
    }

    fn assert_hybrid_solution_applies_reset_after_root(ode_solver: OdeSolverType) {
        let ode = make_tight_ode(MatrixType::NalgebraDense, ode_solver);

        let final_time = 2.0;
        let solution = ode.solve_hybrid(vector_host(&[2.0]), final_time).unwrap();
        let ys = solution.get_ys().unwrap();
        let ys = ys.as_array::<f64>().unwrap();
        let ts = Vec::<f64>::from_host_array(solution.get_ts().unwrap()).unwrap();
        let root_time = hybrid_root_time();

        assert_eq!(ys.nrows(), 1);
        assert_eq!(ys.ncols(), ts.len());
        assert!(!ts.is_empty(), "expected hybrid solve to produce output");
        assert_close(
            *ts.last().unwrap(),
            final_time,
            ASSERT_TOL,
            "hybrid final time",
        );
        assert_close(
            ys[(0, ys.ncols() - 1)],
            1.0,
            SOLUTION_TOL,
            "hybrid final value",
        );
        assert!(
            ts.iter().any(|&t| t < root_time),
            "expected pre-root samples"
        );
        assert!(
            ts.iter().any(|&t| t > root_time),
            "expected post-root samples after reset"
        );
    }

    fn assert_hybrid_dense_solution_continues_after_reset(ode_solver: OdeSolverType) {
        let ode = make_tight_ode(MatrixType::NalgebraDense, ode_solver);

        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
        let solution = ode
            .solve_hybrid_dense(vector_host(&[2.0]), vector_host(&t_eval))
            .unwrap();
        let ys = solution.get_ys().unwrap();
        let ys = ys.as_array::<f64>().unwrap();

        assert_eq!(ys.nrows(), 1);
        assert_eq!(ys.ncols(), t_eval.len());
        assert_close(
            ys[(0, 0)],
            logistic_state(LOGISTIC_X0, 2.0, t_eval[0]),
            SOLUTION_TOL,
            "hybrid dense pre-root value",
        );
        assert_close(
            ys[(0, 1)],
            logistic_state(LOGISTIC_X0, 2.0, t_eval[1]),
            SOLUTION_TOL,
            "hybrid dense near-root value",
        );
        // Every sample from index 2 onwards lies after the root, where the state is reset.
        for col in 2..t_eval.len() {
            assert_close(
                ys[(0, col)],
                1.0,
                SOLUTION_TOL,
                "hybrid dense post-root value",
            );
        }
    }

    fn assert_hybrid_forward_sensitivities_complete_across_reset(ode_solver: OdeSolverType) {
        let ode = make_tight_ode(MatrixType::NalgebraDense, ode_solver);

        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
        let solution = ode
            .solve_hybrid_fwd_sens(vector_host(&[2.0]), vector_host(&t_eval))
            .unwrap();
        let ys = solution.get_ys().unwrap();
        let ys = ys.as_array::<f64>().unwrap();
        let sens = solution.get_sens().unwrap();

        assert_eq!(ys.nrows(), 1);
        assert_eq!(ys.ncols(), t_eval.len());
        assert_eq!(sens.len(), 1);
        let sens_values = sens[0].as_array::<f64>().unwrap();
        assert_eq!(sens_values.nrows(), 1);
        assert_eq!(sens_values.ncols(), t_eval.len());
        assert_close(
            ys[(0, 0)],
            logistic_state(LOGISTIC_X0, 2.0, t_eval[0]),
            SOLUTION_TOL,
            "hybrid sens pre-root value",
        );
        for col in 2..t_eval.len() {
            assert_close(
                ys[(0, col)],
                1.0,
                SOLUTION_TOL,
                "hybrid sens post-root value",
            );
            assert!(
                sens_values[(0, col)].is_finite(),
                "expected finite post-root sensitivity at column {col}"
            );
        }
    }

    #[test]
    fn runtime_dispatch_matches_requested_matrix_type() {
        for matrix_type in [
            MatrixType::NalgebraDense,
            MatrixType::FaerDense,
            MatrixType::FaerSparse,
        ] {
            assert_runtime_dispatch(matrix_type);
        }
    }

    #[test]
    fn bdf_dense_solution_matches_logistic_solution() {
        assert_solver_dense_solution(MatrixType::NalgebraDense, OdeSolverType::Bdf);
    }

    #[test]
    fn esdirk34_dense_solution_matches_logistic_solution() {
        assert_solver_dense_solution(MatrixType::FaerDense, OdeSolverType::Esdirk34);
    }

    #[test]
    fn tr_bdf2_sparse_solution_matches_logistic_solution() {
        assert_solver_dense_solution(MatrixType::FaerSparse, OdeSolverType::TrBdf2);
    }

    #[test]
    fn tsit45_dense_solution_matches_logistic_solution() {
        assert_solver_dense_solution(MatrixType::NalgebraDense, OdeSolverType::Tsit45);
    }

    #[test]
    fn bdf_forward_sensitivities_match_logistic_derivative() {
        let ode = make_tight_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);

        let t_eval = [0.25, 0.5, 1.0];
        let solution = ode
            .solve_fwd_sens(vector_host(&[2.0]), vector_host(&t_eval))
            .unwrap();

        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, SOLUTION_TOL);
        let sens = solution.get_sens().unwrap();
        assert_eq!(sens.len(), 1);
        let sens_values = sens[0].as_array::<f64>().unwrap();
        assert_eq!(sens_values.nrows(), 1);
        assert_eq!(sens_values.ncols(), t_eval.len());
        for (i, &t) in t_eval.iter().enumerate() {
            assert_close(
                sens_values[(0, i)],
                logistic_state_dr(LOGISTIC_X0, 2.0, t),
                ASSERT_TOL,
                &format!("sensitivity[{i}]"),
            );
        }
    }

    #[test]
    fn bdf_sum_squares_adjoint_matches_external_logistic_model() {
        let ode = make_tight_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);

        // Data generated from the exact solution, so both objective and gradient should vanish.
        let t_eval = [0.0, 0.25, 0.5, 1.0];
        let data_values: Vec<f64> = t_eval
            .iter()
            .map(|&t| logistic_state(LOGISTIC_X0, 2.0, t))
            .collect();
        let data = crate::test_support::matrix_host(1, t_eval.len(), &data_values);
        let (value, sens) = ode
            .solve_sum_squares_adj(vector_host(&[2.0]), data, vector_host(&t_eval))
            .unwrap();
        let grad = Vec::<f64>::from_host_array(sens).unwrap();

        assert_close(value, 0.0, ASSERT_TOL, "sum_squares objective");
        assert_eq!(grad.len(), 1);
        assert_close(grad[0], 0.0, ASSERT_TOL, "sum_squares gradient");
    }

    #[test]
    fn hybrid_solution_applies_reset_after_root_for_all_solvers() {
        for ode_solver in all_ode_solvers() {
            assert_hybrid_solution_applies_reset_after_root(ode_solver);
        }
    }

    #[test]
    fn hybrid_dense_solution_continues_after_reset_for_all_solvers() {
        for ode_solver in all_ode_solvers() {
            assert_hybrid_dense_solution_continues_after_reset(ode_solver);
        }
    }

    #[test]
    fn hybrid_forward_sensitivities_complete_across_reset_for_all_solvers() {
        for ode_solver in all_ode_solvers() {
            assert_hybrid_forward_sensitivities_complete_across_reset(ode_solver);
        }
    }
}
690
#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
mod jit_tests {
    //! Tests for [`OdeWrapper`] built from DiffSL source through the available JIT backends.

    use crate::host_array::FromHostArray;
    use crate::jit::JitBackendType;
    use crate::linear_solver_type::LinearSolverType;
    use crate::scalar_type::ScalarType;
    use crate::test_support::{
        assert_close, assert_solution_tail, available_jit_backends, hybrid_logistic_diffsl_code,
        hybrid_logistic_period, hybrid_logistic_state, logistic_diffsl_code, logistic_state,
        vector_host, ASSERT_TOL, LOGISTIC_X0,
    };
    #[cfg(feature = "diffsl-llvm")]
    use crate::test_support::{hybrid_logistic_state_dr, logistic_state_dr};

    use super::*;

    /// Tolerance for comparing integrated states against closed-form values.
    const SOLUTION_TOL: f64 = 5e-4;

    /// Value used for both `rtol` and `atol` in the accuracy tests.
    const TIGHT_TOL: f64 = 1e-8;

    fn all_ode_solvers() -> [OdeSolverType; 4] {
        [
            OdeSolverType::Bdf,
            OdeSolverType::Esdirk34,
            OdeSolverType::TrBdf2,
            OdeSolverType::Tsit45,
        ]
    }

    /// Compiles `code` with `jit_backend` into an f64 solver using the default linear solver.
    fn make_ode_from_code(
        code: &str,
        jit_backend: JitBackendType,
        matrix_type: MatrixType,
        ode_solver: OdeSolverType,
    ) -> OdeWrapper {
        OdeWrapper::new_jit(
            code,
            jit_backend,
            ScalarType::F64,
            matrix_type,
            LinearSolverType::Default,
            ode_solver,
        )
        .unwrap()
    }

    /// Solver for the plain logistic DiffSL model.
    fn make_ode(
        jit_backend: JitBackendType,
        matrix_type: MatrixType,
        ode_solver: OdeSolverType,
    ) -> OdeWrapper {
        make_ode_from_code(logistic_diffsl_code(), jit_backend, matrix_type, ode_solver)
    }

    /// Solver for the piecewise (reset-on-root) logistic DiffSL model.
    fn make_hybrid_ode(
        jit_backend: JitBackendType,
        matrix_type: MatrixType,
        ode_solver: OdeSolverType,
    ) -> OdeWrapper {
        make_ode_from_code(
            hybrid_logistic_diffsl_code(),
            jit_backend,
            matrix_type,
            ode_solver,
        )
    }

    /// Applies tight relative and absolute tolerances for the accuracy tests.
    fn set_tight_tolerances(ode: &OdeWrapper) {
        ode.set_rtol(TIGHT_TOL).unwrap();
        ode.set_atol(TIGHT_TOL).unwrap();
    }

    fn assert_runtime_dispatch(jit_backend: JitBackendType, matrix_type: MatrixType) {
        let ode = make_ode(jit_backend, matrix_type, OdeSolverType::Bdf);
        assert_eq!(ode.get_matrix_type().unwrap(), matrix_type);
        assert_eq!(ode.get_code().unwrap(), logistic_diffsl_code());
        assert_eq!(ode.get_nstates().unwrap(), 1);
        assert_eq!(ode.get_nparams().unwrap(), 1);
        assert_eq!(ode.get_nout().unwrap(), 1);
        assert!(!ode.has_stop().unwrap());

        let y0 = ode.y0(vector_host(&[2.0])).unwrap();
        assert_eq!(Vec::<f64>::from_host_array(y0).unwrap(), vec![LOGISTIC_X0]);

        let rhs = ode
            .rhs(vector_host(&[2.0]), 0.0, vector_host(&[0.25]))
            .unwrap();
        assert_close(
            Vec::<f64>::from_host_array(rhs).unwrap()[0],
            0.375,
            ASSERT_TOL,
            "jit rhs(0.25)",
        );

        let rhs_jac_mul = ode
            .rhs_jac_mul(
                vector_host(&[2.0]),
                0.0,
                vector_host(&[0.25]),
                vector_host(&[3.0]),
            )
            .unwrap();
        assert_close(
            Vec::<f64>::from_host_array(rhs_jac_mul).unwrap()[0],
            3.0,
            ASSERT_TOL,
            "jit rhs_jac_mul(0.25, 3.0)",
        );
    }

    fn assert_solver_dense_solution(
        jit_backend: JitBackendType,
        matrix_type: MatrixType,
        ode_solver: OdeSolverType,
    ) {
        let ode = make_ode(jit_backend, matrix_type, ode_solver);
        set_tight_tolerances(&ode);

        let t_eval = [0.25, 0.5, 1.0];
        let solution = ode
            .solve_dense(vector_host(&[2.0]), vector_host(&t_eval))
            .unwrap();

        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, SOLUTION_TOL);
    }

    fn hybrid_t_eval() -> [f64; 7] {
        [0.5, 1.0, 2.0, 2.5, 3.0, 4.0, 4.8]
    }

    fn assert_hybrid_solution_matches_piecewise_logistic_diffsl_model(
        jit_backend: JitBackendType,
        ode_solver: OdeSolverType,
    ) {
        let r = 2.0;
        let final_time = 5.0;
        let tau = hybrid_logistic_period(r);
        let ode = make_hybrid_ode(jit_backend, MatrixType::NalgebraDense, ode_solver);
        set_tight_tolerances(&ode);
        assert_eq!(ode.get_nstates().unwrap(), 1);
        assert_eq!(ode.get_nparams().unwrap(), 1);
        assert_eq!(ode.get_nout().unwrap(), 1);
        assert!(ode.has_stop().unwrap());

        let solution = ode.solve_hybrid(vector_host(&[r]), final_time).unwrap();
        let ys = solution.get_ys().unwrap();
        let ys = ys.as_array::<f64>().unwrap();
        let ts = Vec::<f64>::from_host_array(solution.get_ts().unwrap()).unwrap();

        assert_eq!(ys.nrows(), 1);
        assert_eq!(ys.ncols(), ts.len());
        assert!(!ts.is_empty(), "expected hybrid solve to produce output");
        assert_close(
            *ts.last().unwrap(),
            final_time,
            ASSERT_TOL,
            "jit hybrid final time",
        );
        assert_close(
            ys[(0, ys.ncols() - 1)],
            hybrid_logistic_state(r, final_time),
            SOLUTION_TOL,
            "jit hybrid final value",
        );
        assert!(
            ts.iter().any(|&t| (t - tau).abs() < 1e-3),
            "expected an output sample at the first reset"
        );
        assert!(
            ts.iter().any(|&t| (t - 2.0 * tau).abs() < 1e-3),
            "expected an output sample at the second reset"
        );

        // Samples that land on a reset may hold either the pre- or post-reset state, so they are
        // excluded from the pointwise comparison.
        let is_reset_time = |t: f64| ((t / tau).round() * tau - t).abs() < 1e-3;
        for (col, &t) in ts.iter().enumerate().filter(|&(_, &t)| !is_reset_time(t)) {
            assert_close(
                ys[(0, col)],
                hybrid_logistic_state(r, t),
                SOLUTION_TOL,
                &format!("jit hybrid value[{col}]"),
            );
        }
    }

    fn assert_hybrid_dense_solution_matches_piecewise_logistic_diffsl_model(
        jit_backend: JitBackendType,
        ode_solver: OdeSolverType,
    ) {
        let r = 2.0;
        let t_eval = hybrid_t_eval();
        let ode = make_hybrid_ode(jit_backend, MatrixType::NalgebraDense, ode_solver);
        set_tight_tolerances(&ode);

        let solution = ode
            .solve_hybrid_dense(vector_host(&[r]), vector_host(&t_eval))
            .unwrap();
        let ys = solution.get_ys().unwrap();
        let ys = ys.as_array::<f64>().unwrap();
        let ts = Vec::<f64>::from_host_array(solution.get_ts().unwrap()).unwrap();

        assert_eq!(ys.nrows(), 1);
        assert_eq!(ys.ncols(), t_eval.len());
        assert_eq!(ts, t_eval);
        for (col, &t) in t_eval.iter().enumerate() {
            assert_close(
                ys[(0, col)],
                hybrid_logistic_state(r, t),
                SOLUTION_TOL,
                &format!("jit hybrid dense value[{col}]"),
            );
        }
    }

    #[cfg(feature = "diffsl-llvm")]
    fn assert_hybrid_forward_sensitivities_match_piecewise_logistic_diffsl_model(
        ode_solver: OdeSolverType,
    ) {
        let r = 2.0;
        let t_eval = hybrid_t_eval();
        let ode = make_hybrid_ode(JitBackendType::Llvm, MatrixType::NalgebraDense, ode_solver);
        set_tight_tolerances(&ode);

        let solution = ode
            .solve_hybrid_fwd_sens(vector_host(&[r]), vector_host(&t_eval))
            .unwrap();
        let ys = solution.get_ys().unwrap();
        let ys = ys.as_array::<f64>().unwrap();
        let sens = solution.get_sens().unwrap();

        assert_eq!(ys.nrows(), 1);
        assert_eq!(ys.ncols(), t_eval.len());
        assert_eq!(sens.len(), 1);
        let sens_values = sens[0].as_array::<f64>().unwrap();
        assert_eq!(sens_values.nrows(), 1);
        assert_eq!(sens_values.ncols(), t_eval.len());
        for (col, &t) in t_eval.iter().enumerate() {
            assert_close(
                ys[(0, col)],
                hybrid_logistic_state(r, t),
                SOLUTION_TOL,
                &format!("jit hybrid sens value[{col}]"),
            );
            assert_close(
                sens_values[(0, col)],
                hybrid_logistic_state_dr(r, t),
                SOLUTION_TOL,
                &format!("jit hybrid sensitivity[{col}]"),
            );
        }
    }

    #[test]
    fn runtime_dispatch_matches_requested_matrix_type_from_diffsl() {
        for jit_backend in available_jit_backends() {
            for matrix_type in [
                MatrixType::NalgebraDense,
                MatrixType::FaerDense,
                MatrixType::FaerSparse,
            ] {
                assert_runtime_dispatch(jit_backend, matrix_type);
            }
        }
    }

    #[test]
    fn dense_solution_matches_logistic_solution_from_diffsl() {
        for jit_backend in available_jit_backends() {
            for (matrix_type, solver) in [
                (MatrixType::FaerDense, OdeSolverType::Esdirk34),
                (MatrixType::FaerSparse, OdeSolverType::TrBdf2),
                (MatrixType::NalgebraDense, OdeSolverType::Tsit45),
            ] {
                assert_solver_dense_solution(jit_backend, matrix_type, solver);
            }
        }
    }

    #[test]
    fn bdf_dense_solution_matches_logistic_diffsl_model() {
        for jit_backend in available_jit_backends() {
            assert_solver_dense_solution(
                jit_backend,
                MatrixType::NalgebraDense,
                OdeSolverType::Bdf,
            );
        }
    }

    #[test]
    fn bdf_solution_matches_logistic_diffsl_model() {
        for jit_backend in available_jit_backends() {
            let x0 = LOGISTIC_X0;
            let r = 2.0;
            let ode = make_ode(jit_backend, MatrixType::NalgebraDense, OdeSolverType::Bdf);
            set_tight_tolerances(&ode);

            let final_time = 1.0;
            let solution = ode.solve(vector_host(&[r]), final_time).unwrap();

            let ys = solution.get_ys().unwrap();
            let ys = ys.as_array::<f64>().unwrap();
            let ts = Vec::<f64>::from_host_array(solution.get_ts().unwrap()).unwrap();

            assert_eq!(ys.nrows(), 1);
            assert_eq!(ys.ncols(), ts.len());
            assert!(
                !ts.is_empty(),
                "expected solve() to record at least one time point"
            );
            assert_close(
                *ts.last().unwrap(),
                final_time,
                ASSERT_TOL,
                "solve final time",
            );
            for (i, &t) in ts.iter().enumerate() {
                assert_close(
                    ys[(0, i)],
                    logistic_state(x0, r, t),
                    SOLUTION_TOL,
                    &format!("solve value[{i}]"),
                );
            }
        }
    }

    #[test]
    fn hybrid_solution_matches_piecewise_logistic_diffsl_model() {
        for jit_backend in available_jit_backends() {
            for ode_solver in all_ode_solvers() {
                assert_hybrid_solution_matches_piecewise_logistic_diffsl_model(
                    jit_backend,
                    ode_solver,
                );
            }
        }
    }

    #[test]
    fn hybrid_dense_solution_matches_piecewise_logistic_diffsl_model() {
        for jit_backend in available_jit_backends() {
            for ode_solver in all_ode_solvers() {
                assert_hybrid_dense_solution_matches_piecewise_logistic_diffsl_model(
                    jit_backend,
                    ode_solver,
                );
            }
        }
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn bdf_forward_sensitivities_match_logistic_derivative_from_diffsl() {
        let ode = make_ode(
            JitBackendType::Llvm,
            MatrixType::NalgebraDense,
            OdeSolverType::Bdf,
        );
        set_tight_tolerances(&ode);

        let t_eval = [0.25, 0.5, 1.0];
        let solution = ode
            .solve_fwd_sens(vector_host(&[2.0]), vector_host(&t_eval))
            .unwrap();

        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, SOLUTION_TOL);
        let sens = solution.get_sens().unwrap();
        assert_eq!(sens.len(), 1);
        let sens_values = sens[0].as_array::<f64>().unwrap();
        assert_eq!(sens_values.nrows(), 1);
        assert_eq!(sens_values.ncols(), t_eval.len());
        for (i, &t) in t_eval.iter().enumerate() {
            assert_close(
                sens_values[(0, i)],
                logistic_state_dr(LOGISTIC_X0, 2.0, t),
                ASSERT_TOL,
                &format!("jit sensitivity[{i}]"),
            );
        }
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn hybrid_forward_sensitivities_match_piecewise_logistic_diffsl_model() {
        for ode_solver in all_ode_solvers() {
            assert_hybrid_forward_sensitivities_match_piecewise_logistic_diffsl_model(ode_solver);
        }
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn bdf_sum_squares_adjoint_matches_logistic_diffsl_model() {
        let ode = make_ode(
            JitBackendType::Llvm,
            MatrixType::NalgebraDense,
            OdeSolverType::Bdf,
        );
        set_tight_tolerances(&ode);

        // Data generated from the exact solution, so the objective should vanish.
        let t_eval = [0.0, 0.25, 0.5, 1.0];
        let data_values: Vec<f64> = t_eval
            .iter()
            .map(|&t| logistic_state(LOGISTIC_X0, 2.0, t))
            .collect();
        let data = crate::test_support::matrix_host(1, t_eval.len(), &data_values);
        let (value, sens) = ode
            .solve_sum_squares_adj(vector_host(&[2.0]), data, vector_host(&t_eval))
            .unwrap();
        let grad = Vec::<f64>::from_host_array(sens).unwrap();

        assert_close(value, 0.0, ASSERT_TOL, "jit sum_squares objective");
        assert_eq!(grad.len(), 1);
        assert!(
            grad[0].is_finite(),
            "jit sum_squares gradient should be finite"
        );
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn bdf_sum_squares_adjoint_matches_finite_difference_gradient_for_logistic_model() {
        let logistic_model = r#"
            in_i { r = 1, k = 1, y0 = 0.1 }
            u { y0 }
            F { r * u * (1.0 - u / k) }
        "#;
        // Independent, identically configured solvers for the finite-difference and adjoint paths.
        let make_logistic_ode = || {
            let ode = make_ode_from_code(
                logistic_model,
                JitBackendType::Llvm,
                MatrixType::NalgebraDense,
                OdeSolverType::Bdf,
            );
            set_tight_tolerances(&ode);
            ode
        };
        let ode = make_logistic_ode();

        let t_eval: [f64; 5] = [0.0, 0.1, 0.3, 0.7, 1.0];
        let data_params: [f64; 3] = [1.2, 0.9, 0.2];
        let fit_params: [f64; 3] = [0.8, 1.3, 0.12];
        let fd_step: f64 = 1e-6;

        let data_solution = ode
            .solve_dense(vector_host(&data_params), vector_host(&t_eval))
            .unwrap();
        let data_ys = data_solution.get_ys().unwrap();
        let data_ys = data_ys.as_array::<f64>().unwrap();
        let data_values: Vec<f64> = (0..t_eval.len()).map(|col| data_ys[(0, col)]).collect();

        // Sum-of-squares misfit between a dense solve at `params` and the synthetic data.
        let objective_from_dense = |params: [f64; 3]| -> f64 {
            let solution = ode
                .solve_dense(vector_host(&params), vector_host(&t_eval))
                .unwrap();
            let ys = solution.get_ys().unwrap();
            let ys = ys.as_array::<f64>().unwrap();
            data_values
                .iter()
                .enumerate()
                .map(|(col, &observed)| {
                    let residual = ys[(0, col)] - observed;
                    residual * residual
                })
                .sum()
        };

        let objective_fd = objective_from_dense(fit_params);
        // Central differences with a step scaled by the parameter magnitude (at least fd_step).
        let finite_difference_gradient: Vec<f64> = (0..fit_params.len())
            .map(|i| {
                let step = fd_step * fit_params[i].abs().max(1.0);
                let mut plus = fit_params;
                let mut minus = fit_params;
                plus[i] += step;
                minus[i] -= step;
                (objective_from_dense(plus) - objective_from_dense(minus)) / (2.0 * step)
            })
            .collect();

        let data = crate::test_support::matrix_host(1, t_eval.len(), &data_values);
        let ode_adj = make_logistic_ode();
        let (objective_adj, sens) = ode_adj
            .solve_sum_squares_adj(vector_host(&fit_params), data, vector_host(&t_eval))
            .unwrap();
        let adjoint_gradient = Vec::<f64>::from_host_array(sens).unwrap();

        assert_eq!(adjoint_gradient.len(), 3);
        assert_close(
            objective_adj,
            objective_fd,
            1e-5,
            "sum_squares objective from dense finite differences",
        );
        for (i, (&adjoint, &fd)) in adjoint_gradient
            .iter()
            .zip(&finite_difference_gradient)
            .enumerate()
        {
            assert_close(
                adjoint,
                fd,
                5e-4,
                &format!("sum_squares gradient component {i}"),
            );
        }
    }
}