1use std::sync::{Arc, Mutex};
2
3use serde::{de::Error as DeError, Deserialize, Deserializer, Serialize, Serializer};
4
5use crate::jit::JitBackendType;
6use crate::{
7 error::DiffsolRtError,
8 host_array::HostArray,
9 initial_condition_options::{
10 InitialConditionSolverOptions, InitialConditionSolverOptionsSnapshot,
11 },
12 linear_solver_type::LinearSolverType,
13 matrix_type::MatrixType,
14 ode_options::{OdeSolverOptions, OdeSolverOptionsSnapshot},
15 ode_solver_type::OdeSolverType,
16 scalar_type::ScalarType,
17 solution_wrapper::SolutionWrapper,
18 solve::Solve,
19};
20
/// Mutable state behind an [`OdeWrapper`]: the type-erased solver plus the
/// configuration that is passed to it on every solve call.
pub struct Ode {
    /// Type-erased solver created by one of the `solve_factory_*` functions.
    pub(crate) solve: Box<dyn Solve>,
    /// DiffSL source the solver was built from; empty for external models.
    code: String,
    /// Scalar precision requested when the solver was created.
    scalar_type: ScalarType,
    /// JIT backend used to build the solver; `None` for external models, which
    /// makes the wrapper non-serializable.
    jit_backend: Option<JitBackendType>,
    /// Linear solver passed to every solve call; validated via `Solve::check`
    /// when the wrapper is built and whenever it is changed.
    linear_solver: LinearSolverType,
    /// ODE integration method passed to every solve call.
    ode_solver: OdeSolverType,
}
29
// SAFETY: NOTE(review): these impls assert that `Ode`, including its
// `Box<dyn Solve>`, may be moved to and shared between threads, which the
// compiler cannot verify for the trait object. In this file `Ode` is only
// accessed through `Arc<Mutex<Ode>>` (`OdeWrapper`), which serializes access.
// Confirm that every `Solve` implementation (JIT-compiled code, externally
// linked or dynamically loaded models) is safe to use from a thread other than
// the one that created it.
unsafe impl Send for Ode {}
unsafe impl Sync for Ode {}
32
/// Cheaply cloneable, thread-shareable handle to an [`Ode`].
///
/// Clones share the same underlying solver: cloning copies the `Arc`, not the
/// solver. Every accessor locks the inner mutex for the duration of the call.
#[derive(Clone)]
pub struct OdeWrapper(Arc<Mutex<Ode>>);
35
/// Serialized form of an [`OdeWrapper`]; only JIT-backed solvers can produce one.
///
/// NOTE: the derived serde impls use these field names (and, for sequence-based
/// formats, this field order), so renaming or reordering fields changes the
/// wire format.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct OdeWrapperSnapshot {
    /// Original DiffSL source, kept so `get_code` works after a round trip.
    /// The solver itself is rebuilt from `equation`, not from this text.
    code: String,
    /// Output of `Solve::serialized_diffsl`, fed back to
    /// `solve_factory_from_serialized_diffsl` on deserialization.
    equation: Vec<u8>,
    jit_backend: JitBackendType,
    scalar_type: ScalarType,
    matrix_type: MatrixType,
    linear_solver: LinearSolverType,
    ode_solver: OdeSolverType,
    /// Relative tolerance as reported by the solver (stored as `f64`).
    rtol: f64,
    /// Absolute tolerance as reported by the solver (stored as `f64`).
    atol: f64,
    /// Initial-condition solver options captured from the solver.
    ic_options: InitialConditionSolverOptionsSnapshot,
    /// ODE solver options captured from the solver.
    ode_options: OdeSolverOptionsSnapshot,
}
50
51impl OdeWrapper {
52 fn guard(&self) -> Result<std::sync::MutexGuard<'_, Ode>, DiffsolRtError> {
53 self.0.lock().map_err(|_| {
54 DiffsolRtError::from(diffsol::error::DiffsolError::Other(
55 "Failed to acquire lock on ODE solver".to_string(),
56 ))
57 })
58 }
59}
60
61impl OdeWrapper {
62 fn snapshot(&self) -> Result<OdeWrapperSnapshot, DiffsolRtError> {
63 let ode = self.guard()?;
64 let jit_backend = ode.jit_backend.ok_or_else(|| {
65 DiffsolRtError::from(diffsol::error::DiffsolError::Other(
66 "OdeWrapper serialization is only supported for JIT-backed solvers".to_string(),
67 ))
68 })?;
69 Ok(OdeWrapperSnapshot {
70 code: ode.code.clone(),
71 equation: ode.solve.serialized_diffsl()?,
72 jit_backend,
73 scalar_type: ode.scalar_type,
74 matrix_type: ode.solve.matrix_type(),
75 linear_solver: ode.linear_solver,
76 ode_solver: ode.ode_solver,
77 rtol: ode.solve.rtol(),
78 atol: ode.solve.atol(),
79 ic_options: InitialConditionSolverOptionsSnapshot::from_solve(ode.solve.as_ref()),
80 ode_options: OdeSolverOptionsSnapshot::from_solve(ode.solve.as_ref()),
81 })
82 }
83
84 fn build(
85 code: String,
86 scalar_type: ScalarType,
87 solve: Box<dyn Solve>,
88 jit_backend: Option<JitBackendType>,
89 linear_solver: LinearSolverType,
90 ode_solver: OdeSolverType,
91 ) -> Result<Self, DiffsolRtError> {
92 solve.check(linear_solver)?;
93 Ok(OdeWrapper(Arc::new(Mutex::new(Ode {
94 code,
95 scalar_type,
96 solve,
97 jit_backend,
98 linear_solver,
99 ode_solver,
100 }))))
101 }
102
103 fn from_snapshot(snapshot: OdeWrapperSnapshot) -> Result<Self, DiffsolRtError> {
104 let solve = crate::solve::solve_factory_from_serialized_diffsl(
105 snapshot.equation.as_slice(),
106 snapshot.matrix_type,
107 snapshot.scalar_type,
108 )?;
109 let wrapper = Self::build(
110 snapshot.code,
111 snapshot.scalar_type,
112 solve,
113 Some(snapshot.jit_backend),
114 snapshot.linear_solver,
115 snapshot.ode_solver,
116 )?;
117 {
118 let mut ode = wrapper.guard()?;
119 ode.solve.set_rtol(snapshot.rtol);
120 ode.solve.set_atol(snapshot.atol);
121 snapshot.ic_options.apply_to_solve(ode.solve.as_mut());
122 snapshot.ode_options.apply_to_solve(ode.solve.as_mut());
123 }
124 Ok(wrapper)
125 }
126
127 #[cfg(feature = "external")]
129 pub fn new_external(
130 rhs_state_deps: Vec<(usize, usize)>,
131 rhs_input_deps: Vec<(usize, usize)>,
132 mass_state_deps: Vec<(usize, usize)>,
133 scalar_type: ScalarType,
134 matrix_type: MatrixType,
135 linear_solver: LinearSolverType,
136 ode_solver: OdeSolverType,
137 ) -> Result<Self, DiffsolRtError> {
138 let solve = crate::solve::solve_factory_external(
139 rhs_state_deps,
140 rhs_input_deps,
141 mass_state_deps,
142 matrix_type,
143 scalar_type,
144 )?;
145 Self::build(
146 String::new(),
147 scalar_type,
148 solve,
149 None,
150 linear_solver,
151 ode_solver,
152 )
153 }
154
155 #[cfg(feature = "diffsl-external-dynamic")]
157 #[allow(clippy::too_many_arguments)]
158 pub fn new_external_dynamic(
159 path: impl Into<std::path::PathBuf>,
160 rhs_state_deps: Vec<(usize, usize)>,
161 rhs_input_deps: Vec<(usize, usize)>,
162 mass_state_deps: Vec<(usize, usize)>,
163 scalar_type: ScalarType,
164 matrix_type: MatrixType,
165 linear_solver: LinearSolverType,
166 ode_solver: OdeSolverType,
167 ) -> Result<Self, DiffsolRtError> {
168 let solve = crate::solve::solve_factory_external_dynamic(
169 path.into(),
170 rhs_state_deps,
171 rhs_input_deps,
172 mass_state_deps,
173 matrix_type,
174 scalar_type,
175 )?;
176 Self::build(
177 String::new(),
178 scalar_type,
179 solve,
180 None,
181 linear_solver,
182 ode_solver,
183 )
184 }
185
186 #[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
188 pub fn new_jit(
189 code: &str,
190 jit_backend: JitBackendType,
191 scalar_type: ScalarType,
192 matrix_type: MatrixType,
193 linear_solver: LinearSolverType,
194 ode_solver: OdeSolverType,
195 ) -> Result<Self, DiffsolRtError> {
196 let solve = crate::solve::solve_factory_jit(code, jit_backend, matrix_type, scalar_type)?;
197 Self::build(
198 code.to_owned(),
199 scalar_type,
200 solve,
201 Some(jit_backend),
202 linear_solver,
203 ode_solver,
204 )
205 }
206
207 pub fn get_matrix_type(&self) -> Result<MatrixType, DiffsolRtError> {
209 Ok(self.guard()?.solve.matrix_type())
210 }
211
212 pub fn get_nstates(&self) -> Result<usize, DiffsolRtError> {
213 Ok(self.guard()?.solve.nstates())
214 }
215
216 pub fn get_nparams(&self) -> Result<usize, DiffsolRtError> {
217 Ok(self.guard()?.solve.nparams())
218 }
219
220 pub fn get_nout(&self) -> Result<usize, DiffsolRtError> {
221 Ok(self.guard()?.solve.nout())
222 }
223
224 pub fn has_stop(&self) -> Result<bool, DiffsolRtError> {
225 Ok(self.guard()?.solve.has_stop())
226 }
227
228 pub fn get_ode_solver(&self) -> Result<OdeSolverType, DiffsolRtError> {
230 Ok(self.guard()?.ode_solver)
231 }
232
233 pub fn set_ode_solver(&self, value: OdeSolverType) -> Result<(), DiffsolRtError> {
234 self.guard()?.ode_solver = value;
235 Ok(())
236 }
237
238 pub fn get_linear_solver(&self) -> Result<LinearSolverType, DiffsolRtError> {
241 Ok(self.guard()?.linear_solver)
242 }
243
244 pub fn set_linear_solver(&self, value: LinearSolverType) -> Result<(), DiffsolRtError> {
245 self.guard()?.solve.check(value)?;
246 self.guard()?.linear_solver = value;
247 Ok(())
248 }
249
250 pub fn get_rtol(&self) -> Result<f64, DiffsolRtError> {
252 Ok(self.guard()?.solve.rtol())
253 }
254
255 pub fn set_rtol(&self, value: f64) -> Result<(), DiffsolRtError> {
256 self.guard()?.solve.set_rtol(value);
257 Ok(())
258 }
259
260 pub fn get_atol(&self) -> Result<f64, DiffsolRtError> {
262 Ok(self.guard()?.solve.atol())
263 }
264
265 pub fn set_atol(&self, value: f64) -> Result<(), DiffsolRtError> {
266 self.guard()?.solve.set_atol(value);
267 Ok(())
268 }
269
270 pub fn get_code(&self) -> Result<String, DiffsolRtError> {
271 Ok(self.guard()?.code.clone())
272 }
273
274 pub fn get_scalar_type(&self) -> Result<ScalarType, DiffsolRtError> {
275 Ok(self.guard()?.scalar_type)
276 }
277
278 pub fn get_jit_backend(&self) -> Result<Option<JitBackendType>, DiffsolRtError> {
279 Ok(self.guard()?.jit_backend)
280 }
281
282 pub fn get_ic_options(&self) -> InitialConditionSolverOptions {
283 InitialConditionSolverOptions::new(self.0.clone())
284 }
285
286 pub fn get_options(&self) -> OdeSolverOptions {
287 OdeSolverOptions::new(self.0.clone())
288 }
289
290 pub fn y0(&self, params: HostArray) -> Result<HostArray, DiffsolRtError> {
292 let mut self_guard = self.guard()?;
293 self_guard.solve.y0(params.as_slice()?)
294 }
295
296 pub fn rhs(
298 &self,
299 params: HostArray,
300 t: f64,
301 y: HostArray,
302 ) -> Result<HostArray, DiffsolRtError> {
303 let mut self_guard = self.guard()?;
304 self_guard.solve.rhs(params.as_slice()?, t, y.as_slice()?)
305 }
306
307 pub fn rhs_jac_mul(
309 &self,
310 params: HostArray,
311 t: f64,
312 y: HostArray,
313 v: HostArray,
314 ) -> Result<HostArray, DiffsolRtError> {
315 let mut self_guard = self.guard()?;
316 self_guard
317 .solve
318 .rhs_jac_mul(params.as_slice()?, t, y.as_slice()?, v.as_slice()?)
319 }
320
321 #[allow(clippy::type_complexity)]
339 pub fn solve(
340 &self,
341 params: HostArray,
342 final_time: f64,
343 ) -> Result<SolutionWrapper, DiffsolRtError> {
344 let mut self_guard = self.guard()?;
345 let params = params.as_slice()?;
346 let linear_solver = self_guard.linear_solver;
347 let method = self_guard.ode_solver;
348 let solution = self_guard
349 .solve
350 .solve(method, linear_solver, params, final_time)?;
351 Ok(SolutionWrapper::new(solution))
352 }
353
354 pub fn solve_hybrid(
357 &self,
358 params: HostArray,
359 final_time: f64,
360 ) -> Result<SolutionWrapper, DiffsolRtError> {
361 let mut self_guard = self.guard()?;
362 let params = params.as_slice()?;
363 let linear_solver = self_guard.linear_solver;
364 let method = self_guard.ode_solver;
365 let solution = self_guard
366 .solve
367 .solve_hybrid(method, linear_solver, params, final_time)?;
368 Ok(SolutionWrapper::new(solution))
369 }
370
371 pub fn solve_dense(
385 &self,
386 params: HostArray,
387 t_eval: HostArray,
388 ) -> Result<SolutionWrapper, DiffsolRtError> {
389 let mut self_guard = self.guard()?;
390 let params = params.as_slice()?;
391 let t_eval = t_eval.as_slice()?;
392 let linear_solver = self_guard.linear_solver;
393 let method = self_guard.ode_solver;
394 let solution = self_guard
395 .solve
396 .solve_dense(method, linear_solver, params, t_eval)?;
397 Ok(SolutionWrapper::new(solution))
398 }
399
400 pub fn solve_hybrid_dense(
404 &self,
405 params: HostArray,
406 t_eval: HostArray,
407 ) -> Result<SolutionWrapper, DiffsolRtError> {
408 let mut self_guard = self.guard()?;
409 let params = params.as_slice()?;
410 let t_eval = t_eval.as_slice()?;
411 let linear_solver = self_guard.linear_solver;
412 let method = self_guard.ode_solver;
413 let solution =
414 self_guard
415 .solve
416 .solve_hybrid_dense(method, linear_solver, params, t_eval)?;
417 Ok(SolutionWrapper::new(solution))
418 }
419
420 #[allow(clippy::type_complexity)]
433 pub fn solve_fwd_sens(
434 &self,
435 params: HostArray,
436 t_eval: HostArray,
437 ) -> Result<SolutionWrapper, DiffsolRtError> {
438 let mut self_guard = self.guard()?;
439 let params = params.as_slice()?;
440 let t_eval = t_eval.as_slice()?;
441 let linear_solver = self_guard.linear_solver;
442 let method = self_guard.ode_solver;
443 let solution = self_guard
444 .solve
445 .solve_fwd_sens(method, linear_solver, params, t_eval)?;
446 Ok(SolutionWrapper::new(solution))
447 }
448
449 #[allow(clippy::type_complexity)]
453 pub fn solve_hybrid_fwd_sens(
454 &self,
455 params: HostArray,
456 t_eval: HostArray,
457 ) -> Result<SolutionWrapper, DiffsolRtError> {
458 let mut self_guard = self.guard()?;
459 let params = params.as_slice()?;
460 let t_eval = t_eval.as_slice()?;
461 let linear_solver = self_guard.linear_solver;
462 let method = self_guard.ode_solver;
463 let solution =
464 self_guard
465 .solve
466 .solve_hybrid_fwd_sens(method, linear_solver, params, t_eval)?;
467 Ok(SolutionWrapper::new(solution))
468 }
469
470 #[allow(clippy::type_complexity)]
475 pub fn solve_sum_squares_adj(
476 &self,
477 params: HostArray,
478 data: HostArray,
479 t_eval: HostArray,
480 ) -> Result<(f64, HostArray), DiffsolRtError> {
481 let mut self_guard = self.guard()?;
482 let linear_solver = self_guard.linear_solver;
483 let ode_solver = self_guard.ode_solver;
484
485 self_guard.solve.solve_sum_squares_adj(
486 ode_solver,
487 linear_solver,
488 ode_solver,
489 linear_solver,
490 params.as_slice()?,
491 data,
492 t_eval.as_slice()?,
493 )
494 }
495}
496
497impl Serialize for OdeWrapper {
498 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
499 where
500 S: Serializer,
501 {
502 self.snapshot()
503 .map_err(serde::ser::Error::custom)?
504 .serialize(serializer)
505 }
506}
507
508impl<'de> Deserialize<'de> for OdeWrapper {
509 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
510 where
511 D: Deserializer<'de>,
512 {
513 let snapshot = OdeWrapperSnapshot::deserialize(deserializer)?;
514 Self::from_snapshot(snapshot).map_err(DeError::custom)
515 }
516}
517
#[cfg(all(test, feature = "diffsl-external-f64"))]
mod tests {
    use super::*;
    use crate::host_array::FromHostArray;
    use crate::linear_solver_type::LinearSolverType;
    use crate::scalar_type::ScalarType;
    use crate::test_support::{
        assert_close, assert_solution_tail, logistic_state, logistic_state_dr, mass_state_deps,
        rhs_input_deps, rhs_state_deps, vector_host, ASSERT_TOL, LOGISTIC_X0,
    };

    /// Logistic growth rate passed as the model's single parameter.
    const RATE: f64 = 2.0;
    /// rtol/atol applied before accuracy-sensitive solves.
    const TIGHT_TOL: f64 = 1e-8;
    /// Allowed deviation of integrated states from the expected values.
    const SOLUTION_TOL: f64 = 5e-4;

    fn all_ode_solvers() -> [OdeSolverType; 4] {
        [
            OdeSolverType::Bdf,
            OdeSolverType::Esdirk34,
            OdeSolverType::TrBdf2,
            OdeSolverType::Tsit45,
        ]
    }

    /// Builds the externally linked logistic model with the default linear solver.
    fn make_ode(matrix_type: MatrixType, ode_solver: OdeSolverType) -> OdeWrapper {
        OdeWrapper::new_external(
            rhs_state_deps(),
            rhs_input_deps(),
            mass_state_deps(),
            ScalarType::F64,
            matrix_type,
            LinearSolverType::Default,
            ode_solver,
        )
        .unwrap()
    }

    /// Like `make_ode`, with rtol and atol both set to `TIGHT_TOL`.
    fn make_tight_ode(matrix_type: MatrixType, ode_solver: OdeSolverType) -> OdeWrapper {
        let ode = make_ode(matrix_type, ode_solver);
        ode.set_rtol(TIGHT_TOL).unwrap();
        ode.set_atol(TIGHT_TOL).unwrap();
        ode
    }

    fn assert_runtime_dispatch(matrix_type: MatrixType) {
        let ode = make_ode(matrix_type, OdeSolverType::Bdf);
        assert_eq!(ode.get_matrix_type().unwrap(), matrix_type);
        assert_eq!(ode.get_nstates().unwrap(), 1);
        assert_eq!(ode.get_nparams().unwrap(), 1);
        assert_eq!(ode.get_nout().unwrap(), 1);
        assert!(ode.has_stop().unwrap());

        let y0 = ode.y0(vector_host(&[RATE])).unwrap();
        assert_eq!(Vec::<f64>::from_host_array(y0).unwrap(), vec![LOGISTIC_X0]);

        let rhs = ode
            .rhs(vector_host(&[RATE]), 0.0, vector_host(&[0.25]))
            .unwrap();
        assert_close(
            Vec::<f64>::from_host_array(rhs).unwrap()[0],
            0.375,
            ASSERT_TOL,
            "rhs(0.25)",
        );

        let rhs_jac_mul = ode
            .rhs_jac_mul(
                vector_host(&[RATE]),
                0.0,
                vector_host(&[0.25]),
                vector_host(&[3.0]),
            )
            .unwrap();
        assert_close(
            Vec::<f64>::from_host_array(rhs_jac_mul).unwrap()[0],
            3.0,
            ASSERT_TOL,
            "rhs_jac_mul(0.25, 3.0)",
        );
    }

    /// Solves densely on a fixed grid and compares against the logistic solution.
    fn assert_solver_dense_solution(matrix_type: MatrixType, ode_solver: OdeSolverType) {
        let ode = make_tight_ode(matrix_type, ode_solver);

        let t_eval = [0.25, 0.5, 1.0];
        let solution = ode
            .solve_dense(vector_host(&[RATE]), vector_host(&t_eval))
            .unwrap();

        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, RATE, SOLUTION_TOL);
    }

    /// Time at which the hybrid fixture's root is expected to fire.
    fn hybrid_root_time() -> f64 {
        0.5 * 9.0_f64.ln()
    }

    fn assert_hybrid_solution_applies_reset_after_root(ode_solver: OdeSolverType) {
        let ode = make_tight_ode(MatrixType::NalgebraDense, ode_solver);

        let final_time = 2.0;
        let solution = ode.solve_hybrid(vector_host(&[RATE]), final_time).unwrap();
        let ys = solution.get_ys().unwrap();
        let ys = ys.as_array::<f64>().unwrap();
        let ts = Vec::<f64>::from_host_array(solution.get_ts().unwrap()).unwrap();
        let root_time = hybrid_root_time();

        assert_eq!(ys.nrows(), 1);
        assert_eq!(ys.ncols(), ts.len());
        assert!(!ts.is_empty(), "expected hybrid solve to produce output");
        assert_close(
            *ts.last().unwrap(),
            final_time,
            ASSERT_TOL,
            "hybrid final time",
        );
        assert_close(
            ys[(0, ys.ncols() - 1)],
            1.0,
            SOLUTION_TOL,
            "hybrid final value",
        );
        assert!(
            ts.iter().any(|&t| t < root_time),
            "expected pre-root samples"
        );
        assert!(
            ts.iter().any(|&t| t > root_time),
            "expected post-root samples after reset"
        );
    }

    fn assert_hybrid_dense_solution_continues_after_reset(ode_solver: OdeSolverType) {
        let ode = make_tight_ode(MatrixType::NalgebraDense, ode_solver);

        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
        let solution = ode
            .solve_hybrid_dense(vector_host(&[RATE]), vector_host(&t_eval))
            .unwrap();
        let ys = solution.get_ys().unwrap();
        let ys = ys.as_array::<f64>().unwrap();

        assert_eq!(ys.nrows(), 1);
        assert_eq!(ys.ncols(), t_eval.len());
        assert_close(
            ys[(0, 0)],
            logistic_state(LOGISTIC_X0, RATE, t_eval[0]),
            SOLUTION_TOL,
            "hybrid dense pre-root value",
        );
        assert_close(
            ys[(0, 1)],
            logistic_state(LOGISTIC_X0, RATE, t_eval[1]),
            SOLUTION_TOL,
            "hybrid dense near-root value",
        );
        for col in 2..t_eval.len() {
            assert_close(
                ys[(0, col)],
                1.0,
                SOLUTION_TOL,
                "hybrid dense post-root value",
            );
        }
    }

    fn assert_hybrid_forward_sensitivities_complete_across_reset(ode_solver: OdeSolverType) {
        let ode = make_tight_ode(MatrixType::NalgebraDense, ode_solver);

        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
        let solution = ode
            .solve_hybrid_fwd_sens(vector_host(&[RATE]), vector_host(&t_eval))
            .unwrap();
        let ys = solution.get_ys().unwrap();
        let ys = ys.as_array::<f64>().unwrap();
        let sens = solution.get_sens().unwrap();

        assert_eq!(ys.nrows(), 1);
        assert_eq!(ys.ncols(), t_eval.len());
        assert_eq!(sens.len(), 1);
        let sens_values = sens[0].as_array::<f64>().unwrap();
        assert_eq!(sens_values.nrows(), 1);
        assert_eq!(sens_values.ncols(), t_eval.len());
        assert_close(
            ys[(0, 0)],
            logistic_state(LOGISTIC_X0, RATE, t_eval[0]),
            SOLUTION_TOL,
            "hybrid sens pre-root value",
        );
        for col in 2..t_eval.len() {
            assert_close(
                ys[(0, col)],
                1.0,
                SOLUTION_TOL,
                "hybrid sens post-root value",
            );
            assert!(
                sens_values[(0, col)].is_finite(),
                "expected finite post-root sensitivity at column {col}"
            );
        }
    }

    #[test]
    fn runtime_dispatch_matches_requested_matrix_type() {
        for matrix_type in [
            MatrixType::NalgebraDense,
            MatrixType::FaerDense,
            MatrixType::FaerSparse,
        ] {
            assert_runtime_dispatch(matrix_type);
        }
    }

    #[test]
    fn bdf_dense_solution_matches_logistic_solution() {
        assert_solver_dense_solution(MatrixType::NalgebraDense, OdeSolverType::Bdf);
    }

    #[test]
    fn esdirk34_dense_solution_matches_logistic_solution() {
        assert_solver_dense_solution(MatrixType::FaerDense, OdeSolverType::Esdirk34);
    }

    #[test]
    fn tr_bdf2_sparse_solution_matches_logistic_solution() {
        assert_solver_dense_solution(MatrixType::FaerSparse, OdeSolverType::TrBdf2);
    }

    #[test]
    fn tsit45_dense_solution_matches_logistic_solution() {
        assert_solver_dense_solution(MatrixType::NalgebraDense, OdeSolverType::Tsit45);
    }

    #[test]
    fn bdf_forward_sensitivities_match_logistic_derivative() {
        let ode = make_tight_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);

        let t_eval = [0.25, 0.5, 1.0];
        let solution = ode
            .solve_fwd_sens(vector_host(&[RATE]), vector_host(&t_eval))
            .unwrap();

        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, RATE, SOLUTION_TOL);
        let sens = solution.get_sens().unwrap();
        assert_eq!(sens.len(), 1);
        let sens_values = sens[0].as_array::<f64>().unwrap();
        assert_eq!(sens_values.nrows(), 1);
        assert_eq!(sens_values.ncols(), t_eval.len());
        for (i, &t) in t_eval.iter().enumerate() {
            assert_close(
                sens_values[(0, i)],
                logistic_state_dr(LOGISTIC_X0, RATE, t),
                ASSERT_TOL,
                &format!("sensitivity[{i}]"),
            );
        }
    }

    #[test]
    fn bdf_sum_squares_adjoint_matches_external_logistic_model() {
        let ode = make_tight_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);

        // The data is the exact logistic solution, so the objective and its
        // gradient should both vanish.
        let t_eval = [0.0, 0.25, 0.5, 1.0];
        let data_values: Vec<f64> = t_eval
            .iter()
            .map(|&t| logistic_state(LOGISTIC_X0, RATE, t))
            .collect();
        let data = crate::test_support::matrix_host(1, t_eval.len(), &data_values);
        let (value, sens) = ode
            .solve_sum_squares_adj(vector_host(&[RATE]), data, vector_host(&t_eval))
            .unwrap();
        let grad = Vec::<f64>::from_host_array(sens).unwrap();

        assert_close(value, 0.0, ASSERT_TOL, "sum_squares objective");
        assert_eq!(grad.len(), 1);
        assert_close(grad[0], 0.0, ASSERT_TOL, "sum_squares gradient");
    }

    #[test]
    fn hybrid_solution_applies_reset_after_root_for_all_solvers() {
        for ode_solver in all_ode_solvers() {
            assert_hybrid_solution_applies_reset_after_root(ode_solver);
        }
    }

    #[test]
    fn hybrid_dense_solution_continues_after_reset_for_all_solvers() {
        for ode_solver in all_ode_solvers() {
            assert_hybrid_dense_solution_continues_after_reset(ode_solver);
        }
    }

    #[test]
    fn hybrid_forward_sensitivities_complete_across_reset_for_all_solvers() {
        for ode_solver in all_ode_solvers() {
            assert_hybrid_forward_sensitivities_complete_across_reset(ode_solver);
        }
    }
}
812
#[cfg(all(test, feature = "diffsl-external-dynamic"))]
mod dynamic_tests {
    use crate::host_array::FromHostArray;
    use crate::linear_solver_type::LinearSolverType;
    use crate::scalar_type::ScalarType;
    use crate::test_support::{
        assert_close, assert_solution_tail, external_dynamic_fixture_path, mass_state_deps,
        rhs_input_deps, rhs_state_deps, vector_host, ASSERT_TOL, LOGISTIC_X0,
    };

    use super::*;

    /// Loads the logistic fixture library with the default linear solver.
    fn make_ode(matrix_type: MatrixType, ode_solver: OdeSolverType) -> OdeWrapper {
        let fixture = external_dynamic_fixture_path();
        let created = OdeWrapper::new_external_dynamic(
            fixture,
            rhs_state_deps(),
            rhs_input_deps(),
            mass_state_deps(),
            ScalarType::F64,
            matrix_type,
            LinearSolverType::Default,
            ode_solver,
        );
        created.unwrap()
    }

    /// Checks metadata and pointwise evaluations for one matrix type.
    fn check_dispatch_for(matrix_type: MatrixType) {
        let ode = make_ode(matrix_type, OdeSolverType::Bdf);

        // A dynamically loaded model carries no DiffSL source and no JIT backend.
        assert_eq!(ode.get_matrix_type().unwrap(), matrix_type);
        assert_eq!(ode.get_code().unwrap(), "");
        assert_eq!(ode.get_jit_backend().unwrap(), None);
        assert_eq!(ode.get_nstates().unwrap(), 1);
        assert_eq!(ode.get_nparams().unwrap(), 1);
        assert_eq!(ode.get_nout().unwrap(), 1);
        assert!(ode.has_stop().unwrap());

        let initial_state = ode.y0(vector_host(&[2.0])).unwrap();
        assert_eq!(
            Vec::<f64>::from_host_array(initial_state).unwrap(),
            vec![LOGISTIC_X0]
        );

        let derivative = ode
            .rhs(vector_host(&[2.0]), 0.0, vector_host(&[0.25]))
            .unwrap();
        let derivative = Vec::<f64>::from_host_array(derivative).unwrap();
        assert_close(derivative[0], 0.375, ASSERT_TOL, "rhs(0.25)");

        let product = ode
            .rhs_jac_mul(
                vector_host(&[2.0]),
                0.0,
                vector_host(&[0.25]),
                vector_host(&[3.0]),
            )
            .unwrap();
        let product = Vec::<f64>::from_host_array(product).unwrap();
        assert_close(product[0], 3.0, ASSERT_TOL, "rhs_jac_mul(0.25, 3.0)");
    }

    #[test]
    fn runtime_dispatch_matches_requested_matrix_type() {
        let matrix_types = [
            MatrixType::NalgebraDense,
            MatrixType::FaerDense,
            MatrixType::FaerSparse,
        ];
        matrix_types.iter().copied().for_each(check_dispatch_for);
    }

    #[test]
    fn dense_solution_matches_logistic_solution() {
        let ode = make_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);
        ode.set_rtol(1e-8).unwrap();
        ode.set_atol(1e-8).unwrap();

        let times = [0.25, 0.5, 1.0];
        let dense = ode
            .solve_dense(vector_host(&[2.0]), vector_host(&times))
            .unwrap();

        assert_solution_tail(&dense, &times, LOGISTIC_X0, 2.0, 5e-4);
    }

    #[test]
    fn non_jit_serialization_is_rejected() {
        let ode = make_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);
        let message = serde_json::to_string(&ode).unwrap_err().to_string();
        assert!(message.contains("JIT-backed"));
    }
}
906
907#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
908mod jit_tests {
909 use crate::host_array::FromHostArray;
910 use crate::jit::JitBackendType;
911 use crate::linear_solver_type::LinearSolverType;
912 use crate::scalar_type::ScalarType;
913 use crate::test_support::{
914 assert_close, assert_solution_tail, available_jit_backends, hybrid_logistic_diffsl_code,
915 hybrid_logistic_period, hybrid_logistic_state, logistic_diffsl_code, logistic_state,
916 vector_host, ASSERT_TOL, LOGISTIC_X0,
917 };
918 #[cfg(feature = "diffsl-llvm")]
919 use crate::test_support::{hybrid_logistic_state_dr, logistic_state_dr};
920 #[cfg(any(
921 all(feature = "diffsl-llvm", not(feature = "diffsl-cranelift")),
922 all(feature = "diffsl-cranelift", not(feature = "diffsl-llvm"))
923 ))]
924 use serde_json::Value;
925 use serde_json::{self};
926
927 use super::*;
928
929 fn all_ode_solvers() -> [OdeSolverType; 4] {
930 [
931 OdeSolverType::Bdf,
932 OdeSolverType::Esdirk34,
933 OdeSolverType::TrBdf2,
934 OdeSolverType::Tsit45,
935 ]
936 }
937
938 fn make_ode(
939 jit_backend: JitBackendType,
940 scalar_type: ScalarType,
941 matrix_type: MatrixType,
942 ode_solver: OdeSolverType,
943 ) -> OdeWrapper {
944 OdeWrapper::new_jit(
945 logistic_diffsl_code(),
946 jit_backend,
947 scalar_type,
948 matrix_type,
949 LinearSolverType::Default,
950 ode_solver,
951 )
952 .unwrap()
953 }
954
955 fn make_hybrid_ode(
956 jit_backend: JitBackendType,
957 matrix_type: MatrixType,
958 ode_solver: OdeSolverType,
959 ) -> OdeWrapper {
960 OdeWrapper::new_jit(
961 hybrid_logistic_diffsl_code(),
962 jit_backend,
963 ScalarType::F64,
964 matrix_type,
965 LinearSolverType::Default,
966 ode_solver,
967 )
968 .unwrap()
969 }
970
971 fn serialized_linear_solver(matrix_type: MatrixType) -> LinearSolverType {
972 match matrix_type {
973 MatrixType::NalgebraDense | MatrixType::FaerDense => LinearSolverType::Lu,
974 MatrixType::FaerSparse => LinearSolverType::Default,
975 }
976 }
977
978 fn configure_serialized_ode(ode: &OdeWrapper, matrix_type: MatrixType) {
979 ode.set_linear_solver(serialized_linear_solver(matrix_type))
980 .unwrap();
981 ode.set_ode_solver(OdeSolverType::TrBdf2).unwrap();
982 ode.set_rtol(1e-7).unwrap();
983 ode.set_atol(1e-9).unwrap();
984
985 let ic_options = ode.get_ic_options();
986 ic_options.set_use_linesearch(true).unwrap();
987 ic_options.set_max_linesearch_iterations(13).unwrap();
988 ic_options.set_max_newton_iterations(17).unwrap();
989 ic_options.set_max_linear_solver_setups(19).unwrap();
990 ic_options.set_step_reduction_factor(0.5).unwrap();
991 ic_options.set_armijo_constant(1e-4).unwrap();
992
993 let options = ode.get_options();
994 options.set_max_nonlinear_solver_iterations(23).unwrap();
995 options.set_max_error_test_failures(29).unwrap();
996 options.set_update_jacobian_after_steps(31).unwrap();
997 options.set_update_rhs_jacobian_after_steps(37).unwrap();
998 options.set_threshold_to_update_jacobian(1e-3).unwrap();
999 options.set_threshold_to_update_rhs_jacobian(2e-3).unwrap();
1000 options.set_min_timestep(1e-4).unwrap();
1001 }
1002
1003 fn scalar_value(value: f64, scalar_type: ScalarType) -> f64 {
1004 match scalar_type {
1005 ScalarType::F32 => (value as f32) as f64,
1006 ScalarType::F64 => value,
1007 }
1008 }
1009
1010 fn assert_serialization_roundtrip(
1011 jit_backend: JitBackendType,
1012 scalar_type: ScalarType,
1013 matrix_type: MatrixType,
1014 ) {
1015 let ode = make_ode(jit_backend, scalar_type, matrix_type, OdeSolverType::Bdf);
1016 configure_serialized_ode(&ode, matrix_type);
1017
1018 #[cfg(feature = "diffsl-cranelift")]
1019 if jit_backend == JitBackendType::Cranelift {
1020 let err = serde_json::to_string(&ode).unwrap_err().to_string();
1021 assert!(err.contains("not supported for Cranelift"));
1022 return;
1023 }
1024
1025 let y0_before = Vec::<f64>::from_host_array(ode.y0(vector_host(&[2.0])).unwrap()).unwrap();
1026 let rhs_before = Vec::<f64>::from_host_array(
1027 ode.rhs(vector_host(&[2.0]), 0.0, vector_host(&[0.25]))
1028 .unwrap(),
1029 )
1030 .unwrap();
1031
1032 let encoded = serde_json::to_string(&ode).unwrap();
1033 let decoded: OdeWrapper = serde_json::from_str(&encoded).unwrap();
1034
1035 assert_eq!(decoded.get_jit_backend().unwrap(), Some(jit_backend));
1036 assert_eq!(decoded.get_code().unwrap(), logistic_diffsl_code());
1037 assert_eq!(decoded.get_scalar_type().unwrap(), scalar_type);
1038 assert_eq!(decoded.get_matrix_type().unwrap(), matrix_type);
1039 assert_eq!(
1040 decoded.get_linear_solver().unwrap(),
1041 serialized_linear_solver(matrix_type)
1042 );
1043 assert_eq!(decoded.get_ode_solver().unwrap(), OdeSolverType::TrBdf2);
1044 assert_close(
1045 decoded.get_rtol().unwrap(),
1046 scalar_value(1e-7, scalar_type),
1047 1e-12,
1048 "serialized rtol",
1049 );
1050 assert_close(
1051 decoded.get_atol().unwrap(),
1052 scalar_value(1e-9, scalar_type),
1053 1e-12,
1054 "serialized atol",
1055 );
1056
1057 let ic_options = decoded.get_ic_options();
1058 assert!(ic_options.get_use_linesearch().unwrap());
1059 assert_eq!(ic_options.get_max_linesearch_iterations().unwrap(), 13);
1060 assert_eq!(ic_options.get_max_newton_iterations().unwrap(), 17);
1061 assert_eq!(ic_options.get_max_linear_solver_setups().unwrap(), 19);
1062 assert_close(
1063 ic_options.get_step_reduction_factor().unwrap(),
1064 scalar_value(0.5, scalar_type),
1065 1e-12,
1066 "serialized step_reduction_factor",
1067 );
1068 assert_close(
1069 ic_options.get_armijo_constant().unwrap(),
1070 scalar_value(1e-4, scalar_type),
1071 1e-12,
1072 "serialized armijo_constant",
1073 );
1074
1075 let options = decoded.get_options();
1076 assert_eq!(options.get_max_nonlinear_solver_iterations().unwrap(), 23);
1077 assert_eq!(options.get_max_error_test_failures().unwrap(), 29);
1078 assert_eq!(options.get_update_jacobian_after_steps().unwrap(), 31);
1079 assert_eq!(options.get_update_rhs_jacobian_after_steps().unwrap(), 37);
1080 assert_close(
1081 options.get_threshold_to_update_jacobian().unwrap(),
1082 scalar_value(1e-3, scalar_type),
1083 1e-12,
1084 "serialized threshold_to_update_jacobian",
1085 );
1086 assert_close(
1087 options.get_threshold_to_update_rhs_jacobian().unwrap(),
1088 scalar_value(2e-3, scalar_type),
1089 1e-12,
1090 "serialized threshold_to_update_rhs_jacobian",
1091 );
1092 assert_close(
1093 options.get_min_timestep().unwrap(),
1094 scalar_value(1e-4, scalar_type),
1095 1e-12,
1096 "serialized min_timestep",
1097 );
1098
1099 let y0_after =
1100 Vec::<f64>::from_host_array(decoded.y0(vector_host(&[2.0])).unwrap()).unwrap();
1101 let rhs_after = Vec::<f64>::from_host_array(
1102 decoded
1103 .rhs(vector_host(&[2.0]), 0.0, vector_host(&[0.25]))
1104 .unwrap(),
1105 )
1106 .unwrap();
1107 assert_eq!(y0_after, y0_before);
1108 assert_close(
1109 rhs_after[0],
1110 rhs_before[0],
1111 ASSERT_TOL,
1112 "serialized rhs matches",
1113 );
1114
1115 decoded
1116 .set_linear_solver(serialized_linear_solver(matrix_type))
1117 .unwrap();
1118 let t_eval = [0.25, 0.5, 1.0];
1119 let solution = decoded
1120 .solve_dense(vector_host(&[2.0]), vector_host(&t_eval))
1121 .unwrap();
1122 assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
1123 }
1124
1125 fn assert_runtime_dispatch(jit_backend: JitBackendType, matrix_type: MatrixType) {
1126 let ode = make_ode(
1127 jit_backend,
1128 ScalarType::F64,
1129 matrix_type,
1130 OdeSolverType::Bdf,
1131 );
1132 assert_eq!(ode.get_matrix_type().unwrap(), matrix_type);
1133 assert_eq!(ode.get_code().unwrap(), logistic_diffsl_code());
1134 assert_eq!(ode.get_nstates().unwrap(), 1);
1135 assert_eq!(ode.get_nparams().unwrap(), 1);
1136 assert_eq!(ode.get_nout().unwrap(), 1);
1137 assert!(!ode.has_stop().unwrap());
1138
1139 let y0 = ode.y0(vector_host(&[2.0])).unwrap();
1140 assert_eq!(Vec::<f64>::from_host_array(y0).unwrap(), vec![LOGISTIC_X0]);
1141
1142 let rhs = ode
1143 .rhs(vector_host(&[2.0]), 0.0, vector_host(&[0.25]))
1144 .unwrap();
1145 assert_close(
1146 Vec::<f64>::from_host_array(rhs).unwrap()[0],
1147 0.375,
1148 ASSERT_TOL,
1149 "jit rhs(0.25)",
1150 );
1151
1152 let rhs_jac_mul = ode
1153 .rhs_jac_mul(
1154 vector_host(&[2.0]),
1155 0.0,
1156 vector_host(&[0.25]),
1157 vector_host(&[3.0]),
1158 )
1159 .unwrap();
1160 assert_close(
1161 Vec::<f64>::from_host_array(rhs_jac_mul).unwrap()[0],
1162 3.0,
1163 ASSERT_TOL,
1164 "jit rhs_jac_mul(0.25, 3.0)",
1165 );
1166 }
1167
1168 fn assert_solver_dense_solution(
1169 jit_backend: JitBackendType,
1170 scalar_type: ScalarType,
1171 matrix_type: MatrixType,
1172 ode_solver: OdeSolverType,
1173 ) {
1174 let ode = make_ode(jit_backend, scalar_type, matrix_type, ode_solver);
1175 ode.set_rtol(1e-8).unwrap();
1176 ode.set_atol(1e-8).unwrap();
1177
1178 let t_eval = [0.25, 0.5, 1.0];
1179 let solution = ode
1180 .solve_dense(vector_host(&[2.0]), vector_host(&t_eval))
1181 .unwrap();
1182
1183 assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
1184 }
1185
1186 fn hybrid_t_eval() -> [f64; 7] {
1187 [0.5, 1.0, 2.0, 2.5, 3.0, 4.0, 4.8]
1188 }
1189
1190 fn assert_hybrid_solution_matches_piecewise_logistic_diffsl_model(
1191 jit_backend: JitBackendType,
1192 ode_solver: OdeSolverType,
1193 ) {
1194 let r = 2.0;
1195 let final_time = 5.0;
1196 let tau = hybrid_logistic_period(r);
1197 let ode = make_hybrid_ode(jit_backend, MatrixType::NalgebraDense, ode_solver);
1198 ode.set_rtol(1e-8).unwrap();
1199 ode.set_atol(1e-8).unwrap();
1200 assert_eq!(ode.get_nstates().unwrap(), 1);
1201 assert_eq!(ode.get_nparams().unwrap(), 1);
1202 assert_eq!(ode.get_nout().unwrap(), 1);
1203 assert!(ode.has_stop().unwrap());
1204
1205 let solution = ode.solve_hybrid(vector_host(&[r]), final_time).unwrap();
1206 let ys = solution.get_ys().unwrap();
1207 let ys = ys.as_array::<f64>().unwrap();
1208 let ts = Vec::<f64>::from_host_array(solution.get_ts().unwrap()).unwrap();
1209
1210 assert_eq!(ys.nrows(), 1);
1211 assert_eq!(ys.ncols(), ts.len());
1212 assert!(!ts.is_empty(), "expected hybrid solve to produce output");
1213 assert_close(
1214 *ts.last().unwrap(),
1215 final_time,
1216 ASSERT_TOL,
1217 "jit hybrid final time",
1218 );
1219 assert_close(
1220 ys[(0, ys.ncols() - 1)],
1221 hybrid_logistic_state(r, final_time),
1222 5e-4,
1223 "jit hybrid final value",
1224 );
1225 assert!(ts.iter().any(|&t| (t - tau).abs() < 1e-3));
1226 assert!(ts.iter().any(|&t| (t - 2.0 * tau).abs() < 1e-3));
1227 for (col, &t) in ts.iter().enumerate() {
1228 if ((t / tau).round() * tau - t).abs() < 1e-3 {
1229 continue;
1230 }
1231 assert_close(
1232 ys[(0, col)],
1233 hybrid_logistic_state(r, t),
1234 5e-4,
1235 &format!("jit hybrid value[{col}]"),
1236 );
1237 }
1238 }
1239
1240 fn assert_hybrid_dense_solution_matches_piecewise_logistic_diffsl_model(
1241 jit_backend: JitBackendType,
1242 ode_solver: OdeSolverType,
1243 ) {
1244 let r = 2.0;
1245 let t_eval = hybrid_t_eval();
1246 let ode = make_hybrid_ode(jit_backend, MatrixType::NalgebraDense, ode_solver);
1247 ode.set_rtol(1e-8).unwrap();
1248 ode.set_atol(1e-8).unwrap();
1249
1250 let solution = ode
1251 .solve_hybrid_dense(vector_host(&[r]), vector_host(&t_eval))
1252 .unwrap();
1253 let ys = solution.get_ys().unwrap();
1254 let ys = ys.as_array::<f64>().unwrap();
1255 let ts = Vec::<f64>::from_host_array(solution.get_ts().unwrap()).unwrap();
1256
1257 assert_eq!(ys.nrows(), 1);
1258 assert_eq!(ys.ncols(), t_eval.len());
1259 assert_eq!(ts, t_eval);
1260 for (col, &t) in t_eval.iter().enumerate() {
1261 assert_close(
1262 ys[(0, col)],
1263 hybrid_logistic_state(r, t),
1264 5e-4,
1265 &format!("jit hybrid dense value[{col}]"),
1266 );
1267 }
1268 }
1269
1270 #[cfg(feature = "diffsl-llvm")]
1271 fn assert_hybrid_forward_sensitivities_match_piecewise_logistic_diffsl_model(
1272 ode_solver: OdeSolverType,
1273 ) {
1274 let r = 2.0;
1275 let t_eval = hybrid_t_eval();
1276 let ode = make_hybrid_ode(JitBackendType::Llvm, MatrixType::NalgebraDense, ode_solver);
1277 ode.set_rtol(1e-8).unwrap();
1278 ode.set_atol(1e-8).unwrap();
1279
1280 let solution = ode
1281 .solve_hybrid_fwd_sens(vector_host(&[r]), vector_host(&t_eval))
1282 .unwrap();
1283 let ys = solution.get_ys().unwrap();
1284 let ys = ys.as_array::<f64>().unwrap();
1285 let sens = solution.get_sens().unwrap();
1286
1287 assert_eq!(ys.nrows(), 1);
1288 assert_eq!(ys.ncols(), t_eval.len());
1289 assert_eq!(sens.len(), 1);
1290 let sens_values = sens[0].as_array::<f64>().unwrap();
1291 assert_eq!(sens_values.nrows(), 1);
1292 assert_eq!(sens_values.ncols(), t_eval.len());
1293 for (col, &t) in t_eval.iter().enumerate() {
1294 assert_close(
1295 ys[(0, col)],
1296 hybrid_logistic_state(r, t),
1297 5e-4,
1298 &format!("jit hybrid sens value[{col}]"),
1299 );
1300 assert_close(
1301 sens_values[(0, col)],
1302 hybrid_logistic_state_dr(r, t),
1303 5e-4,
1304 &format!("jit hybrid sensitivity[{col}]"),
1305 );
1306 }
1307 }
1308
1309 #[test]
1310 fn runtime_dispatch_matches_requested_matrix_type_from_diffsl() {
1311 for jit_backend in available_jit_backends() {
1312 for matrix_type in [
1313 MatrixType::NalgebraDense,
1314 MatrixType::FaerDense,
1315 MatrixType::FaerSparse,
1316 ] {
1317 assert_runtime_dispatch(jit_backend, matrix_type);
1318 }
1319 }
1320 }
1321
1322 #[test]
1323 fn dense_solution_matches_logistic_solution_from_diffsl() {
1324 for jit_backend in available_jit_backends() {
1325 for scalar_type in [ScalarType::F64, ScalarType::F32] {
1326 for (matrix_type, solver) in [
1327 (MatrixType::FaerDense, OdeSolverType::Esdirk34),
1328 (MatrixType::FaerSparse, OdeSolverType::TrBdf2),
1329 (MatrixType::NalgebraDense, OdeSolverType::Tsit45),
1330 ] {
1331 assert_solver_dense_solution(jit_backend, scalar_type, matrix_type, solver);
1332 }
1333 }
1334 }
1335 }
1336
1337 #[test]
1338 fn bdf_dense_solution_matches_logistic_diffsl_model() {
1339 for jit_backend in available_jit_backends() {
1340 let ode = make_ode(
1341 jit_backend,
1342 ScalarType::F64,
1343 MatrixType::NalgebraDense,
1344 OdeSolverType::Bdf,
1345 );
1346 ode.set_rtol(1e-8).unwrap();
1347 ode.set_atol(1e-8).unwrap();
1348
1349 let t_eval = [0.25, 0.5, 1.0];
1350 let solution = ode
1351 .solve_dense(vector_host(&[2.0]), vector_host(&t_eval))
1352 .unwrap();
1353
1354 assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
1355 }
1356 }
1357
1358 #[test]
1359 fn bdf_solution_matches_logistic_diffsl_model() {
1360 for jit_backend in available_jit_backends() {
1361 let x0 = LOGISTIC_X0;
1362 let r = 2.0;
1363 let ode = make_ode(
1364 jit_backend,
1365 ScalarType::F64,
1366 MatrixType::NalgebraDense,
1367 OdeSolverType::Bdf,
1368 );
1369 ode.set_rtol(1e-8).unwrap();
1370 ode.set_atol(1e-8).unwrap();
1371
1372 let final_time = 1.0;
1373 let solution = ode.solve(vector_host(&[r]), final_time).unwrap();
1374
1375 let ys = solution.get_ys().unwrap();
1376 let ys = ys.as_array::<f64>().unwrap();
1377 let ts = Vec::<f64>::from_host_array(solution.get_ts().unwrap()).unwrap();
1378
1379 assert_eq!(ys.nrows(), 1);
1380 assert_eq!(ys.ncols(), ts.len());
1381 assert!(
1382 !ts.is_empty(),
1383 "expected solve() to record at least one time point"
1384 );
1385 assert_close(
1386 *ts.last().unwrap(),
1387 final_time,
1388 ASSERT_TOL,
1389 "solve final time",
1390 );
1391 for (i, &t) in ts.iter().enumerate() {
1392 assert_close(
1393 ys[(0, i)],
1394 logistic_state(x0, r, t),
1395 5e-4,
1396 &format!("solve value[{i}]"),
1397 );
1398 }
1399 }
1400 }
1401
1402 #[cfg_attr(
1403 all(target_os = "macos", target_arch = "x86_64"),
1404 ignore = "from_external_object is unsupported on Intel macOS due to unsupported relocations"
1405 )]
1406 #[test]
1407 fn serialization_roundtrip_restores_full_solver_state() {
1408 for jit_backend in available_jit_backends() {
1409 for scalar_type in [ScalarType::F64, ScalarType::F32] {
1410 for matrix_type in [MatrixType::NalgebraDense, MatrixType::FaerSparse] {
1411 assert_serialization_roundtrip(jit_backend, scalar_type, matrix_type);
1412 }
1413 }
1414 }
1415 }
1416
1417 #[cfg(all(feature = "diffsl-llvm", not(feature = "diffsl-cranelift")))]
1418 #[test]
1419 fn deserialization_rejects_unavailable_jit_backend() {
1420 let ode = make_ode(
1421 JitBackendType::Llvm,
1422 ScalarType::F64,
1423 MatrixType::NalgebraDense,
1424 OdeSolverType::Bdf,
1425 );
1426 let mut value = serde_json::to_value(&ode).unwrap();
1427 value["jit_backend"] = Value::String("cranelift".to_string());
1428 let err = serde_json::from_value::<OdeWrapper>(value)
1429 .err()
1430 .unwrap()
1431 .to_string();
1432 assert!(err.contains("unknown variant"));
1433 }
1434
1435 #[cfg(all(feature = "diffsl-cranelift", not(feature = "diffsl-llvm")))]
1436 #[test]
1437 fn deserialization_rejects_unavailable_jit_backend() {
1438 let ode = make_ode(
1439 JitBackendType::Cranelift,
1440 ScalarType::F64,
1441 MatrixType::NalgebraDense,
1442 OdeSolverType::Bdf,
1443 );
1444 let mut value = serde_json::to_value(&ode).unwrap();
1445 value["jit_backend"] = Value::String("llvm".to_string());
1446 let err = serde_json::from_value::<OdeWrapper>(value)
1447 .err()
1448 .unwrap()
1449 .to_string();
1450 assert!(err.contains("unknown variant"));
1451 }
1452
1453 #[test]
1454 fn hybrid_solution_matches_piecewise_logistic_diffsl_model() {
1455 for jit_backend in available_jit_backends() {
1456 for ode_solver in all_ode_solvers() {
1457 assert_hybrid_solution_matches_piecewise_logistic_diffsl_model(
1458 jit_backend,
1459 ode_solver,
1460 );
1461 }
1462 }
1463 }
1464
1465 #[test]
1466 fn hybrid_dense_solution_matches_piecewise_logistic_diffsl_model() {
1467 for jit_backend in available_jit_backends() {
1468 for ode_solver in all_ode_solvers() {
1469 assert_hybrid_dense_solution_matches_piecewise_logistic_diffsl_model(
1470 jit_backend,
1471 ode_solver,
1472 );
1473 }
1474 }
1475 }
1476
1477 #[cfg(feature = "diffsl-llvm")]
1478 #[test]
1479 fn bdf_forward_sensitivities_match_logistic_derivative_from_diffsl() {
1480 let ode = make_ode(
1481 JitBackendType::Llvm,
1482 ScalarType::F64,
1483 MatrixType::NalgebraDense,
1484 OdeSolverType::Bdf,
1485 );
1486 ode.set_rtol(1e-8).unwrap();
1487 ode.set_atol(1e-8).unwrap();
1488
1489 let t_eval = [0.25, 0.5, 1.0];
1490 let solution = ode
1491 .solve_fwd_sens(vector_host(&[2.0]), vector_host(&t_eval))
1492 .unwrap();
1493
1494 assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
1495 let sens = solution.get_sens().unwrap();
1496 assert_eq!(sens.len(), 1);
1497 let sens_values = sens[0].as_array::<f64>().unwrap();
1498 assert_eq!(sens_values.nrows(), 1);
1499 assert_eq!(sens_values.ncols(), t_eval.len());
1500 for (i, &t) in t_eval.iter().enumerate() {
1501 assert_close(
1502 sens_values[(0, i)],
1503 logistic_state_dr(LOGISTIC_X0, 2.0, t),
1504 ASSERT_TOL,
1505 &format!("jit sensitivity[{i}]"),
1506 );
1507 }
1508 }
1509
1510 #[cfg(feature = "diffsl-llvm")]
1511 #[test]
1512 fn hybrid_forward_sensitivities_match_piecewise_logistic_diffsl_model() {
1513 for ode_solver in all_ode_solvers() {
1514 assert_hybrid_forward_sensitivities_match_piecewise_logistic_diffsl_model(ode_solver);
1515 }
1516 }
1517
1518 #[cfg(feature = "diffsl-llvm")]
1519 #[test]
1520 fn bdf_sum_squares_adjoint_matches_logistic_diffsl_model() {
1521 let ode = make_ode(
1522 JitBackendType::Llvm,
1523 ScalarType::F64,
1524 MatrixType::NalgebraDense,
1525 OdeSolverType::Bdf,
1526 );
1527 ode.set_rtol(1e-8).unwrap();
1528 ode.set_atol(1e-8).unwrap();
1529
1530 let t_eval = [0.0, 0.25, 0.5, 1.0];
1531 let data_values: Vec<f64> = t_eval
1532 .iter()
1533 .map(|&t| logistic_state(LOGISTIC_X0, 2.0, t))
1534 .collect();
1535 let data = crate::test_support::matrix_host(1, t_eval.len(), &data_values);
1536 let (value, sens) = ode
1537 .solve_sum_squares_adj(vector_host(&[2.0]), data, vector_host(&t_eval))
1538 .unwrap();
1539 let grad = Vec::<f64>::from_host_array(sens).unwrap();
1540
1541 assert_close(value, 0.0, ASSERT_TOL, "jit sum_squares objective");
1542 assert_eq!(grad.len(), 1);
1543 assert!(
1544 grad[0].is_finite(),
1545 "jit sum_squares gradient should be finite"
1546 );
1547 }
1548
1549 #[cfg(feature = "diffsl-llvm")]
1550 #[test]
1551 fn bdf_sum_squares_adjoint_matches_finite_difference_gradient_for_logistic_model() {
1552 let logistic_model = r#"
1553 in_i { r = 1, k = 1, y0 = 0.1 }
1554 u { y0 }
1555 F { r * u * (1.0 - u / k) }
1556 "#;
1557 let ode = OdeWrapper::new_jit(
1558 logistic_model,
1559 JitBackendType::Llvm,
1560 ScalarType::F64,
1561 MatrixType::NalgebraDense,
1562 LinearSolverType::Default,
1563 OdeSolverType::Bdf,
1564 )
1565 .unwrap();
1566 ode.set_rtol(1e-8).unwrap();
1567 ode.set_atol(1e-8).unwrap();
1568
1569 let t_eval = [0.0, 0.1, 0.3, 0.7, 1.0];
1570 let data_params = [1.2, 0.9, 0.2];
1571 let fit_params = [0.8, 1.3, 0.12];
1572 let fd_step = 1e-6;
1573
1574 let data_solution = ode
1575 .solve_dense(vector_host(&data_params), vector_host(&t_eval))
1576 .unwrap();
1577 let data_ys = data_solution.get_ys().unwrap();
1578 let data_ys = data_ys.as_array::<f64>().unwrap();
1579 let data_values: Vec<f64> = (0..t_eval.len()).map(|col| data_ys[(0, col)]).collect();
1580
1581 let objective_from_dense = |params: [f64; 3]| -> f64 {
1582 let solution = ode
1583 .solve_dense(vector_host(¶ms), vector_host(&t_eval))
1584 .unwrap();
1585 let ys = solution.get_ys().unwrap();
1586 let ys = ys.as_array::<f64>().unwrap();
1587 (0..t_eval.len())
1588 .map(|col| {
1589 let residual = ys[(0, col)] - data_values[col];
1590 residual * residual
1591 })
1592 .sum()
1593 };
1594
1595 let objective_fd = objective_from_dense(fit_params);
1596 let mut finite_difference_gradient = [0.0; 3];
1597 for i in 0..fit_params.len() {
1598 let mut plus = fit_params;
1599 let mut minus = fit_params;
1600 let step = fd_step * fit_params[i].abs().max(1.0);
1601 plus[i] += step;
1602 minus[i] -= step;
1603 finite_difference_gradient[i] =
1604 (objective_from_dense(plus) - objective_from_dense(minus)) / (2.0 * step);
1605 }
1606
1607 let data = crate::test_support::matrix_host(1, t_eval.len(), &data_values);
1608 let ode_adj = OdeWrapper::new_jit(
1609 logistic_model,
1610 JitBackendType::Llvm,
1611 ScalarType::F64,
1612 MatrixType::NalgebraDense,
1613 LinearSolverType::Default,
1614 OdeSolverType::Bdf,
1615 )
1616 .unwrap();
1617 ode_adj.set_rtol(1e-8).unwrap();
1618 ode_adj.set_atol(1e-8).unwrap();
1619
1620 let (objective_adj, sens) = ode_adj
1621 .solve_sum_squares_adj(vector_host(&fit_params), data, vector_host(&t_eval))
1622 .unwrap();
1623 let adjoint_gradient = Vec::<f64>::from_host_array(sens).unwrap();
1624
1625 assert_eq!(adjoint_gradient.len(), 3);
1626 assert_close(
1627 objective_adj,
1628 objective_fd,
1629 1e-5,
1630 "sum_squares objective from dense finite differences",
1631 );
1632 for i in 0..adjoint_gradient.len() {
1633 assert_close(
1634 adjoint_gradient[i],
1635 finite_difference_gradient[i],
1636 5e-4,
1637 &format!("sum_squares gradient component {i}"),
1638 );
1639 }
1640 }
1641}