1use std::sync::{Arc, Mutex};
2
3use serde::{de::Error as DeError, Deserialize, Deserializer, Serialize, Serializer};
4
5use crate::jit::JitBackendType;
6use crate::{
7 adjoint_checkpoint::AdjointCheckpointWrapper,
8 error::DiffsolRtError,
9 host_array::{FromHostArray, HostArray},
10 initial_condition_options::{
11 InitialConditionSolverOptions, InitialConditionSolverOptionsSnapshot,
12 },
13 linear_solver_type::LinearSolverType,
14 matrix_type::MatrixType,
15 ode_options::{OdeSolverOptions, OdeSolverOptionsSnapshot},
16 ode_solver_type::OdeSolverType,
17 scalar_type::ScalarType,
18 solution_wrapper::SolutionWrapper,
19 solve::Solve,
20};
21
22pub struct Ode {
23 pub(crate) solve: Box<dyn Solve>,
24 code: String,
25 scalar_type: ScalarType,
26 jit_backend: Option<JitBackendType>,
27 linear_solver: LinearSolverType,
28 ode_solver: OdeSolverType,
29}
30
31unsafe impl Send for Ode {}
32unsafe impl Sync for Ode {}
33
34#[derive(Clone)]
35pub struct OdeWrapper(Arc<Mutex<Ode>>);
36
37#[derive(Clone, Debug, Serialize, Deserialize)]
38struct OdeWrapperSnapshot {
39 code: String,
40 equation: Vec<u8>,
41 jit_backend: JitBackendType,
42 scalar_type: ScalarType,
43 matrix_type: MatrixType,
44 linear_solver: LinearSolverType,
45 ode_solver: OdeSolverType,
46 rtol: f64,
47 atol: f64,
48 t0: f64,
49 h0: f64,
50 integrate_out: bool,
51 sens_rtol: Option<f64>,
52 sens_atol: Option<f64>,
53 out_rtol: Option<f64>,
54 out_atol: Option<f64>,
55 param_rtol: Option<f64>,
56 param_atol: Option<f64>,
57 ic_options: InitialConditionSolverOptionsSnapshot,
58 ode_options: OdeSolverOptionsSnapshot,
59}
60
61impl OdeWrapper {
62 fn guard(&self) -> Result<std::sync::MutexGuard<'_, Ode>, DiffsolRtError> {
63 self.0.lock().map_err(|_| {
64 DiffsolRtError::from(diffsol::error::DiffsolError::Other(
65 "Failed to acquire lock on ODE solver".to_string(),
66 ))
67 })
68 }
69}
70
71impl OdeWrapper {
72 fn snapshot(&self) -> Result<OdeWrapperSnapshot, DiffsolRtError> {
73 let ode = self.guard()?;
74 let jit_backend = ode.jit_backend.ok_or_else(|| {
75 DiffsolRtError::from(diffsol::error::DiffsolError::Other(
76 "OdeWrapper serialization is only supported for JIT-backed solvers".to_string(),
77 ))
78 })?;
79 Ok(OdeWrapperSnapshot {
80 code: ode.code.clone(),
81 equation: ode.solve.serialized_diffsl()?,
82 jit_backend,
83 scalar_type: ode.scalar_type,
84 matrix_type: ode.solve.matrix_type(),
85 linear_solver: ode.linear_solver,
86 ode_solver: ode.ode_solver,
87 rtol: ode.solve.rtol(),
88 atol: ode.solve.atol(),
89 t0: ode.solve.t0(),
90 h0: ode.solve.h0(),
91 integrate_out: ode.solve.integrate_out(),
92 sens_rtol: ode.solve.sens_rtol(),
93 sens_atol: ode.solve.sens_atol(),
94 out_rtol: ode.solve.out_rtol(),
95 out_atol: ode.solve.out_atol(),
96 param_rtol: ode.solve.param_rtol(),
97 param_atol: ode.solve.param_atol(),
98 ic_options: InitialConditionSolverOptionsSnapshot::from_solve(ode.solve.as_ref()),
99 ode_options: OdeSolverOptionsSnapshot::from_solve(ode.solve.as_ref()),
100 })
101 }
102
103 fn build(
104 code: String,
105 scalar_type: ScalarType,
106 solve: Box<dyn Solve>,
107 jit_backend: Option<JitBackendType>,
108 linear_solver: LinearSolverType,
109 ode_solver: OdeSolverType,
110 ) -> Result<Self, DiffsolRtError> {
111 solve.check(linear_solver)?;
112 Ok(OdeWrapper(Arc::new(Mutex::new(Ode {
113 code,
114 scalar_type,
115 solve,
116 jit_backend,
117 linear_solver,
118 ode_solver,
119 }))))
120 }
121
122 fn from_snapshot(snapshot: OdeWrapperSnapshot) -> Result<Self, DiffsolRtError> {
123 let solve = crate::solve::solve_factory_from_serialized_diffsl(
124 snapshot.equation.as_slice(),
125 snapshot.matrix_type,
126 snapshot.scalar_type,
127 )?;
128 let wrapper = Self::build(
129 snapshot.code,
130 snapshot.scalar_type,
131 solve,
132 Some(snapshot.jit_backend),
133 snapshot.linear_solver,
134 snapshot.ode_solver,
135 )?;
136 {
137 let mut ode = wrapper.guard()?;
138 ode.solve.set_rtol(snapshot.rtol);
139 ode.solve.set_atol(snapshot.atol);
140 ode.solve.set_t0(snapshot.t0);
141 ode.solve.set_h0(snapshot.h0);
142 ode.solve.set_integrate_out(snapshot.integrate_out);
143 ode.solve.set_sens_rtol(snapshot.sens_rtol);
144 ode.solve.set_sens_atol(snapshot.sens_atol);
145 ode.solve.set_out_rtol(snapshot.out_rtol);
146 ode.solve.set_out_atol(snapshot.out_atol);
147 ode.solve.set_param_rtol(snapshot.param_rtol);
148 ode.solve.set_param_atol(snapshot.param_atol);
149 snapshot.ic_options.apply_to_solve(ode.solve.as_mut());
150 snapshot.ode_options.apply_to_solve(ode.solve.as_mut());
151 }
152 Ok(wrapper)
153 }
154
155 #[cfg(feature = "external")]
157 pub fn new_external(
158 rhs_state_deps: Vec<(usize, usize)>,
159 rhs_input_deps: Vec<(usize, usize)>,
160 mass_state_deps: Vec<(usize, usize)>,
161 scalar_type: ScalarType,
162 matrix_type: MatrixType,
163 linear_solver: LinearSolverType,
164 ode_solver: OdeSolverType,
165 ) -> Result<Self, DiffsolRtError> {
166 let solve = crate::solve::solve_factory_external(
167 rhs_state_deps,
168 rhs_input_deps,
169 mass_state_deps,
170 matrix_type,
171 scalar_type,
172 )?;
173 Self::build(
174 String::new(),
175 scalar_type,
176 solve,
177 None,
178 linear_solver,
179 ode_solver,
180 )
181 }
182
183 #[cfg(feature = "diffsl-external-dynamic")]
185 #[allow(clippy::too_many_arguments)]
186 pub fn new_external_dynamic(
187 path: impl Into<std::path::PathBuf>,
188 rhs_state_deps: Vec<(usize, usize)>,
189 rhs_input_deps: Vec<(usize, usize)>,
190 mass_state_deps: Vec<(usize, usize)>,
191 scalar_type: ScalarType,
192 matrix_type: MatrixType,
193 linear_solver: LinearSolverType,
194 ode_solver: OdeSolverType,
195 ) -> Result<Self, DiffsolRtError> {
196 let solve = crate::solve::solve_factory_external_dynamic(
197 path.into(),
198 rhs_state_deps,
199 rhs_input_deps,
200 mass_state_deps,
201 matrix_type,
202 scalar_type,
203 )?;
204 Self::build(
205 String::new(),
206 scalar_type,
207 solve,
208 None,
209 linear_solver,
210 ode_solver,
211 )
212 }
213
214 #[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
216 pub fn new_jit(
217 code: &str,
218 jit_backend: JitBackendType,
219 scalar_type: ScalarType,
220 matrix_type: MatrixType,
221 linear_solver: LinearSolverType,
222 ode_solver: OdeSolverType,
223 ) -> Result<Self, DiffsolRtError> {
224 let solve = crate::solve::solve_factory_jit(code, jit_backend, matrix_type, scalar_type)?;
225 Self::build(
226 code.to_owned(),
227 scalar_type,
228 solve,
229 Some(jit_backend),
230 linear_solver,
231 ode_solver,
232 )
233 }
234
235 pub fn get_matrix_type(&self) -> Result<MatrixType, DiffsolRtError> {
237 Ok(self.guard()?.solve.matrix_type())
238 }
239
240 pub fn get_nstates(&self) -> Result<usize, DiffsolRtError> {
241 Ok(self.guard()?.solve.nstates())
242 }
243
244 pub fn get_nparams(&self) -> Result<usize, DiffsolRtError> {
245 Ok(self.guard()?.solve.nparams())
246 }
247
248 pub fn get_nout(&self) -> Result<usize, DiffsolRtError> {
249 Ok(self.guard()?.solve.nout())
250 }
251
252 pub fn has_stop(&self) -> Result<bool, DiffsolRtError> {
253 Ok(self.guard()?.solve.has_stop())
254 }
255
256 pub fn get_ode_solver(&self) -> Result<OdeSolverType, DiffsolRtError> {
258 Ok(self.guard()?.ode_solver)
259 }
260
261 pub fn set_ode_solver(&self, value: OdeSolverType) -> Result<(), DiffsolRtError> {
262 self.guard()?.ode_solver = value;
263 Ok(())
264 }
265
266 pub fn get_linear_solver(&self) -> Result<LinearSolverType, DiffsolRtError> {
269 Ok(self.guard()?.linear_solver)
270 }
271
272 pub fn set_linear_solver(&self, value: LinearSolverType) -> Result<(), DiffsolRtError> {
273 self.guard()?.solve.check(value)?;
274 self.guard()?.linear_solver = value;
275 Ok(())
276 }
277
278 pub fn get_rtol(&self) -> Result<f64, DiffsolRtError> {
280 Ok(self.guard()?.solve.rtol())
281 }
282
283 pub fn set_rtol(&self, value: f64) -> Result<(), DiffsolRtError> {
284 self.guard()?.solve.set_rtol(value);
285 Ok(())
286 }
287
288 pub fn get_atol(&self) -> Result<f64, DiffsolRtError> {
290 Ok(self.guard()?.solve.atol())
291 }
292
293 pub fn set_atol(&self, value: f64) -> Result<(), DiffsolRtError> {
294 self.guard()?.solve.set_atol(value);
295 Ok(())
296 }
297
298 pub fn get_t0(&self) -> Result<f64, DiffsolRtError> {
300 Ok(self.guard()?.solve.t0())
301 }
302
303 pub fn set_t0(&self, value: f64) -> Result<(), DiffsolRtError> {
304 self.guard()?.solve.set_t0(value);
305 Ok(())
306 }
307
308 pub fn get_h0(&self) -> Result<f64, DiffsolRtError> {
310 Ok(self.guard()?.solve.h0())
311 }
312
313 pub fn set_h0(&self, value: f64) -> Result<(), DiffsolRtError> {
314 self.guard()?.solve.set_h0(value);
315 Ok(())
316 }
317
318 pub fn get_integrate_out(&self) -> Result<bool, DiffsolRtError> {
320 Ok(self.guard()?.solve.integrate_out())
321 }
322
323 pub fn set_integrate_out(&self, value: bool) -> Result<(), DiffsolRtError> {
324 self.guard()?.solve.set_integrate_out(value);
325 Ok(())
326 }
327
328 pub fn get_sens_rtol(&self) -> Result<Option<f64>, DiffsolRtError> {
330 Ok(self.guard()?.solve.sens_rtol())
331 }
332
333 pub fn set_sens_rtol(&self, value: Option<f64>) -> Result<(), DiffsolRtError> {
334 self.guard()?.solve.set_sens_rtol(value);
335 Ok(())
336 }
337
338 pub fn get_sens_atol(&self) -> Result<Option<f64>, DiffsolRtError> {
340 Ok(self.guard()?.solve.sens_atol())
341 }
342
343 pub fn set_sens_atol(&self, value: Option<f64>) -> Result<(), DiffsolRtError> {
344 self.guard()?.solve.set_sens_atol(value);
345 Ok(())
346 }
347
348 pub fn get_out_rtol(&self) -> Result<Option<f64>, DiffsolRtError> {
350 Ok(self.guard()?.solve.out_rtol())
351 }
352
353 pub fn set_out_rtol(&self, value: Option<f64>) -> Result<(), DiffsolRtError> {
354 self.guard()?.solve.set_out_rtol(value);
355 Ok(())
356 }
357
358 pub fn get_out_atol(&self) -> Result<Option<f64>, DiffsolRtError> {
360 Ok(self.guard()?.solve.out_atol())
361 }
362
363 pub fn set_out_atol(&self, value: Option<f64>) -> Result<(), DiffsolRtError> {
364 self.guard()?.solve.set_out_atol(value);
365 Ok(())
366 }
367
368 pub fn get_param_rtol(&self) -> Result<Option<f64>, DiffsolRtError> {
370 Ok(self.guard()?.solve.param_rtol())
371 }
372
373 pub fn set_param_rtol(&self, value: Option<f64>) -> Result<(), DiffsolRtError> {
374 self.guard()?.solve.set_param_rtol(value);
375 Ok(())
376 }
377
378 pub fn get_param_atol(&self) -> Result<Option<f64>, DiffsolRtError> {
380 Ok(self.guard()?.solve.param_atol())
381 }
382
383 pub fn set_param_atol(&self, value: Option<f64>) -> Result<(), DiffsolRtError> {
384 self.guard()?.solve.set_param_atol(value);
385 Ok(())
386 }
387
388 pub fn get_code(&self) -> Result<String, DiffsolRtError> {
389 Ok(self.guard()?.code.clone())
390 }
391
392 pub fn get_scalar_type(&self) -> Result<ScalarType, DiffsolRtError> {
393 Ok(self.guard()?.scalar_type)
394 }
395
396 pub fn get_jit_backend(&self) -> Result<Option<JitBackendType>, DiffsolRtError> {
397 Ok(self.guard()?.jit_backend)
398 }
399
400 pub fn get_ic_options(&self) -> InitialConditionSolverOptions {
401 InitialConditionSolverOptions::new(self.0.clone())
402 }
403
404 pub fn get_options(&self) -> OdeSolverOptions {
405 OdeSolverOptions::new(self.0.clone())
406 }
407
408 pub fn y0(&self, params: HostArray) -> Result<HostArray, DiffsolRtError> {
410 let mut self_guard = self.guard()?;
411 self_guard.solve.y0(params.as_slice()?)
412 }
413
414 pub fn rhs(
416 &self,
417 params: HostArray,
418 t: f64,
419 y: HostArray,
420 ) -> Result<HostArray, DiffsolRtError> {
421 let mut self_guard = self.guard()?;
422 self_guard.solve.rhs(params.as_slice()?, t, y.as_slice()?)
423 }
424
425 pub fn rhs_jac_mul(
427 &self,
428 params: HostArray,
429 t: f64,
430 y: HostArray,
431 v: HostArray,
432 ) -> Result<HostArray, DiffsolRtError> {
433 let mut self_guard = self.guard()?;
434 self_guard
435 .solve
436 .rhs_jac_mul(params.as_slice()?, t, y.as_slice()?, v.as_slice()?)
437 }
438
439 #[allow(clippy::type_complexity)]
457 pub fn solve(
458 &self,
459 params: HostArray,
460 final_time: f64,
461 ) -> Result<SolutionWrapper, DiffsolRtError> {
462 let mut self_guard = self.guard()?;
463 let params = params.as_slice()?;
464 let linear_solver = self_guard.linear_solver;
465 let method = self_guard.ode_solver;
466 let solution = self_guard
467 .solve
468 .solve(method, linear_solver, params, final_time)?;
469 Ok(SolutionWrapper::new(solution))
470 }
471
472 pub fn solve_dense(
486 &self,
487 params: HostArray,
488 t_eval: HostArray,
489 ) -> Result<SolutionWrapper, DiffsolRtError> {
490 let mut self_guard = self.guard()?;
491 let params = params.as_slice()?;
492 let t_eval = t_eval.as_slice()?;
493 let linear_solver = self_guard.linear_solver;
494 let method = self_guard.ode_solver;
495 let solution = self_guard
496 .solve
497 .solve_dense(method, linear_solver, params, t_eval)?;
498 Ok(SolutionWrapper::new(solution))
499 }
500
501 #[allow(clippy::type_complexity)]
514 pub fn solve_fwd_sens(
515 &self,
516 params: HostArray,
517 t_eval: HostArray,
518 ) -> Result<SolutionWrapper, DiffsolRtError> {
519 let mut self_guard = self.guard()?;
520 let params = params.as_slice()?;
521 let t_eval = t_eval.as_slice()?;
522 let linear_solver = self_guard.linear_solver;
523 let method = self_guard.ode_solver;
524 let solution = self_guard
525 .solve
526 .solve_fwd_sens(method, linear_solver, params, t_eval)?;
527 Ok(SolutionWrapper::new(solution))
528 }
529
530 pub fn solve_continuous_adjoint(
536 &self,
537 params: HostArray,
538 final_time: f64,
539 ) -> Result<(HostArray, HostArray), DiffsolRtError> {
540 let mut self_guard = self.guard()?;
541 let linear_solver = self_guard.linear_solver;
542 let ode_solver = self_guard.ode_solver;
543 self_guard.solve.solve_continuous_adjoint(
544 ode_solver,
545 linear_solver,
546 params.as_slice()?,
547 final_time,
548 )
549 }
550
551 pub fn solve_adjoint_fwd(
554 &self,
555 params: HostArray,
556 t_eval: HostArray,
557 ) -> Result<(SolutionWrapper, AdjointCheckpointWrapper), DiffsolRtError> {
558 let mut self_guard = self.guard()?;
559 let params = params.as_slice()?;
560 let t_eval = t_eval.as_slice()?;
561 let linear_solver = self_guard.linear_solver;
562 let method = self_guard.ode_solver;
563 let (solution, checkpoint) =
564 self_guard
565 .solve
566 .solve_adjoint_fwd(method, linear_solver, params, t_eval)?;
567 Ok((SolutionWrapper::new(solution), checkpoint))
568 }
569
570 pub fn solve_adjoint_bkwd(
576 &self,
577 solution: &SolutionWrapper,
578 checkpoint: &AdjointCheckpointWrapper,
579 dgdu_eval: HostArray,
580 ) -> Result<HostArray, DiffsolRtError> {
581 let t_eval_host = solution.get_ts()?;
582 let t_eval = Vec::<f64>::from_host_array(t_eval_host)?;
583 let mut self_guard = self.guard()?;
584 let linear_solver = self_guard.linear_solver;
585 let method = self_guard.ode_solver;
586 self_guard
587 .solve
588 .solve_adjoint_bkwd(method, linear_solver, checkpoint, &t_eval, dgdu_eval)
589 }
590}
591
592impl Serialize for OdeWrapper {
593 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
594 where
595 S: Serializer,
596 {
597 self.snapshot()
598 .map_err(serde::ser::Error::custom)?
599 .serialize(serializer)
600 }
601}
602
603impl<'de> Deserialize<'de> for OdeWrapper {
604 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
605 where
606 D: Deserializer<'de>,
607 {
608 let snapshot = OdeWrapperSnapshot::deserialize(deserializer)?;
609 Self::from_snapshot(snapshot).map_err(DeError::custom)
610 }
611}
612
#[cfg(all(test, feature = "diffsl-external-f64"))]
mod tests {
    use super::*;
    use crate::host_array::FromHostArray;
    use crate::linear_solver_type::LinearSolverType;
    use crate::scalar_type::ScalarType;
    use crate::test_support::{
        assert_close, assert_solution_tail, logistic_state, logistic_state_dr, mass_state_deps,
        rhs_input_deps, rhs_state_deps, vector_host, ASSERT_TOL, LOGISTIC_X0,
    };

    /// Every ODE method exercised by the hybrid (root + reset) tests.
    fn all_ode_solvers() -> [OdeSolverType; 4] {
        [
            OdeSolverType::Bdf,
            OdeSolverType::Esdirk34,
            OdeSolverType::TrBdf2,
            OdeSolverType::Tsit45,
        ]
    }

    /// Builds an f64 solver for the external test model with the default linear solver.
    fn make_ode(matrix_type: MatrixType, ode_solver: OdeSolverType) -> OdeWrapper {
        OdeWrapper::new_external(
            rhs_state_deps(),
            rhs_input_deps(),
            mass_state_deps(),
            ScalarType::F64,
            matrix_type,
            LinearSolverType::Default,
            ode_solver,
        )
        .unwrap()
    }

    /// Like `make_ode`, but with relative and absolute tolerances tightened to 1e-8.
    fn make_tight_ode(matrix_type: MatrixType, ode_solver: OdeSolverType) -> OdeWrapper {
        let ode = make_ode(matrix_type, ode_solver);
        ode.set_rtol(1e-8).unwrap();
        ode.set_atol(1e-8).unwrap();
        ode
    }

    /// Checks model metadata and pointwise evaluations for one matrix type.
    fn assert_runtime_dispatch(matrix_type: MatrixType) {
        let ode = make_ode(matrix_type, OdeSolverType::Bdf);
        assert_eq!(ode.get_matrix_type().unwrap(), matrix_type);
        assert_eq!(ode.get_nstates().unwrap(), 1);
        assert_eq!(ode.get_nparams().unwrap(), 1);
        assert_eq!(ode.get_nout().unwrap(), 1);
        assert!(ode.has_stop().unwrap());

        let y0 = ode.y0(vector_host(&[2.0])).unwrap();
        assert_eq!(Vec::<f64>::from_host_array(y0).unwrap(), vec![LOGISTIC_X0]);

        let rhs = ode
            .rhs(vector_host(&[2.0]), 0.0, vector_host(&[0.25]))
            .unwrap();
        assert_close(
            Vec::<f64>::from_host_array(rhs).unwrap()[0],
            0.375,
            ASSERT_TOL,
            "rhs(0.25)",
        );

        let rhs_jac_mul = ode
            .rhs_jac_mul(
                vector_host(&[2.0]),
                0.0,
                vector_host(&[0.25]),
                vector_host(&[3.0]),
            )
            .unwrap();
        assert_close(
            Vec::<f64>::from_host_array(rhs_jac_mul).unwrap()[0],
            3.0,
            ASSERT_TOL,
            "rhs_jac_mul(0.25, 3.0)",
        );
    }

    /// Solves densely on a few pre-root times and compares with the logistic solution.
    fn assert_solver_dense_solution(matrix_type: MatrixType, ode_solver: OdeSolverType) {
        let ode = make_tight_ode(matrix_type, ode_solver);

        let t_eval = [0.25, 0.5, 1.0];
        let solution = ode
            .solve_dense(vector_host(&[2.0]), vector_host(&t_eval))
            .unwrap();

        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
    }

    /// Time at which the test model's root fires for parameter 2.0.
    /// NOTE(review): derived from the fixture's initial condition — confirm if it changes.
    fn hybrid_root_time() -> f64 {
        0.5 * 9.0_f64.ln()
    }

    fn assert_hybrid_solution_applies_reset_after_root(ode_solver: OdeSolverType) {
        let ode = make_tight_ode(MatrixType::NalgebraDense, ode_solver);

        let final_time = 2.0;
        let solution = ode.solve(vector_host(&[2.0]), final_time).unwrap();
        let ys = solution.get_ys().unwrap();
        let ys = ys.as_array::<f64>().unwrap();
        let ts = Vec::<f64>::from_host_array(solution.get_ts().unwrap()).unwrap();
        let root_time = hybrid_root_time();

        assert_eq!(ys.nrows(), 1);
        assert_eq!(ys.ncols(), ts.len());
        assert!(!ts.is_empty(), "expected hybrid solve to produce output");
        assert_close(
            *ts.last().unwrap(),
            final_time,
            ASSERT_TOL,
            "hybrid final time",
        );
        assert_close(ys[(0, ys.ncols() - 1)], 1.0, 5e-4, "hybrid final value");
        assert!(
            ts.iter().any(|&t| t < root_time),
            "expected pre-root samples"
        );
        assert!(
            ts.iter().any(|&t| t > root_time),
            "expected post-root samples after reset"
        );
    }

    fn assert_hybrid_dense_solution_continues_after_reset(ode_solver: OdeSolverType) {
        let ode = make_tight_ode(MatrixType::NalgebraDense, ode_solver);

        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
        let solution = ode
            .solve_dense(vector_host(&[2.0]), vector_host(&t_eval))
            .unwrap();
        let ys = solution.get_ys().unwrap();
        let ys = ys.as_array::<f64>().unwrap();

        assert_eq!(ys.nrows(), 1);
        assert_eq!(ys.ncols(), t_eval.len());
        assert_close(
            ys[(0, 0)],
            logistic_state(LOGISTIC_X0, 2.0, t_eval[0]),
            5e-4,
            "hybrid dense pre-root value",
        );
        assert_close(
            ys[(0, 1)],
            logistic_state(LOGISTIC_X0, 2.0, t_eval[1]),
            5e-4,
            "hybrid dense near-root value",
        );
        for col in 2..t_eval.len() {
            assert_close(ys[(0, col)], 1.0, 5e-4, "hybrid dense post-root value");
        }
    }

    fn assert_hybrid_forward_sensitivities_complete_across_reset(ode_solver: OdeSolverType) {
        let ode = make_tight_ode(MatrixType::NalgebraDense, ode_solver);

        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
        let solution = ode
            .solve_fwd_sens(vector_host(&[2.0]), vector_host(&t_eval))
            .unwrap();
        let ys = solution.get_ys().unwrap();
        let ys = ys.as_array::<f64>().unwrap();
        let sens = solution.get_sens().unwrap();

        assert_eq!(ys.nrows(), 1);
        assert_eq!(ys.ncols(), t_eval.len());
        assert_eq!(sens.len(), 1);
        let sens_values = sens[0].as_array::<f64>().unwrap();
        assert_eq!(sens_values.nrows(), 1);
        assert_eq!(sens_values.ncols(), t_eval.len());
        assert_close(
            ys[(0, 0)],
            logistic_state(LOGISTIC_X0, 2.0, t_eval[0]),
            5e-4,
            "hybrid sens pre-root value",
        );
        for col in 2..t_eval.len() {
            assert_close(ys[(0, col)], 1.0, 5e-4, "hybrid sens post-root value");
            assert!(
                sens_values[(0, col)].is_finite(),
                "expected finite post-root sensitivity at column {col}"
            );
        }
    }

    #[test]
    fn runtime_dispatch_matches_requested_matrix_type() {
        for matrix_type in [
            MatrixType::NalgebraDense,
            MatrixType::FaerDense,
            MatrixType::FaerSparse,
        ] {
            assert_runtime_dispatch(matrix_type);
        }
    }

    #[test]
    fn bdf_dense_solution_matches_logistic_solution() {
        assert_solver_dense_solution(MatrixType::NalgebraDense, OdeSolverType::Bdf);
    }

    #[test]
    fn esdirk34_dense_solution_matches_logistic_solution() {
        assert_solver_dense_solution(MatrixType::FaerDense, OdeSolverType::Esdirk34);
    }

    #[test]
    fn tr_bdf2_sparse_solution_matches_logistic_solution() {
        assert_solver_dense_solution(MatrixType::FaerSparse, OdeSolverType::TrBdf2);
    }

    #[test]
    fn tsit45_dense_solution_matches_logistic_solution() {
        assert_solver_dense_solution(MatrixType::NalgebraDense, OdeSolverType::Tsit45);
    }

    #[test]
    fn bdf_forward_sensitivities_match_logistic_derivative() {
        let ode = make_tight_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);

        let t_eval = [0.25, 0.5, 1.0];
        let solution = ode
            .solve_fwd_sens(vector_host(&[2.0]), vector_host(&t_eval))
            .unwrap();

        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
        let sens = solution.get_sens().unwrap();
        assert_eq!(sens.len(), 1);
        let sens_values = sens[0].as_array::<f64>().unwrap();
        assert_eq!(sens_values.nrows(), 1);
        assert_eq!(sens_values.ncols(), t_eval.len());
        for (i, &t) in t_eval.iter().enumerate() {
            assert_close(
                sens_values[(0, i)],
                logistic_state_dr(LOGISTIC_X0, 2.0, t),
                ASSERT_TOL,
                &format!("sensitivity[{i}]"),
            );
        }
    }

    #[test]
    fn hybrid_solution_applies_reset_after_root_for_all_solvers() {
        for ode_solver in all_ode_solvers() {
            assert_hybrid_solution_applies_reset_after_root(ode_solver);
        }
    }

    #[test]
    fn hybrid_dense_solution_continues_after_reset_for_all_solvers() {
        for ode_solver in all_ode_solvers() {
            assert_hybrid_dense_solution_continues_after_reset(ode_solver);
        }
    }

    #[test]
    fn hybrid_forward_sensitivities_complete_across_reset_for_all_solvers() {
        for ode_solver in all_ode_solvers() {
            assert_hybrid_forward_sensitivities_complete_across_reset(ode_solver);
        }
    }
}
885
#[cfg(all(test, feature = "diffsl-external-dynamic"))]
mod dynamic_tests {
    use crate::host_array::FromHostArray;
    use crate::linear_solver_type::LinearSolverType;
    use crate::scalar_type::ScalarType;
    use crate::test_support::{
        assert_close, assert_solution_tail, external_dynamic_fixture_path, mass_state_deps,
        rhs_input_deps, rhs_state_deps, vector_host, ASSERT_TOL, LOGISTIC_X0,
    };

    use super::*;

    /// Builds an f64 solver from the dynamically loaded test fixture.
    fn make_ode(matrix_type: MatrixType, ode_solver: OdeSolverType) -> OdeWrapper {
        OdeWrapper::new_external_dynamic(
            external_dynamic_fixture_path(),
            rhs_state_deps(),
            rhs_input_deps(),
            mass_state_deps(),
            ScalarType::F64,
            matrix_type,
            LinearSolverType::Default,
            ode_solver,
        )
        .unwrap()
    }

    /// Checks metadata and pointwise evaluations of the dynamic model for one matrix type.
    fn check_dispatch_for(matrix_type: MatrixType) {
        let wrapper = make_ode(matrix_type, OdeSolverType::Bdf);
        assert_eq!(wrapper.get_matrix_type().unwrap(), matrix_type);
        assert_eq!(wrapper.get_code().unwrap(), "");
        assert_eq!(wrapper.get_jit_backend().unwrap(), None);
        assert_eq!(wrapper.get_nstates().unwrap(), 1);
        assert_eq!(wrapper.get_nparams().unwrap(), 1);
        assert_eq!(wrapper.get_nout().unwrap(), 1);
        assert!(wrapper.has_stop().unwrap());

        let initial = Vec::<f64>::from_host_array(wrapper.y0(vector_host(&[2.0])).unwrap());
        assert_eq!(initial.unwrap(), vec![LOGISTIC_X0]);

        let rhs_host = wrapper
            .rhs(vector_host(&[2.0]), 0.0, vector_host(&[0.25]))
            .unwrap();
        let rhs_values = Vec::<f64>::from_host_array(rhs_host).unwrap();
        assert_close(rhs_values[0], 0.375, ASSERT_TOL, "rhs(0.25)");

        let jac_host = wrapper
            .rhs_jac_mul(
                vector_host(&[2.0]),
                0.0,
                vector_host(&[0.25]),
                vector_host(&[3.0]),
            )
            .unwrap();
        let jac_values = Vec::<f64>::from_host_array(jac_host).unwrap();
        assert_close(jac_values[0], 3.0, ASSERT_TOL, "rhs_jac_mul(0.25, 3.0)");
    }

    #[test]
    fn runtime_dispatch_matches_requested_matrix_type() {
        let matrix_types = [
            MatrixType::NalgebraDense,
            MatrixType::FaerDense,
            MatrixType::FaerSparse,
        ];
        matrix_types.into_iter().for_each(check_dispatch_for);
    }

    #[test]
    fn dense_solution_matches_logistic_solution() {
        let wrapper = make_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);
        wrapper.set_rtol(1e-8).unwrap();
        wrapper.set_atol(1e-8).unwrap();

        let times = [0.25, 0.5, 1.0];
        let dense = wrapper
            .solve_dense(vector_host(&[2.0]), vector_host(&times))
            .unwrap();

        assert_solution_tail(&dense, &times, LOGISTIC_X0, 2.0, 5e-4);
    }

    #[test]
    fn non_jit_serialization_is_rejected() {
        let wrapper = make_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);
        let failure = serde_json::to_string(&wrapper).unwrap_err();
        let message = failure.to_string();
        assert!(message.contains("JIT-backed"));
    }
}
978}