1use std::sync::{Arc, Mutex};
2
3use crate::jit::JitBackendType;
4use crate::{
5 error::DiffsolJsError, host_array::HostArray,
6 initial_condition_options::InitialConditionSolverOptions, linear_solver_type::LinearSolverType,
7 matrix_type::MatrixType, ode_options::OdeSolverOptions, ode_solver_type::OdeSolverType,
8 scalar_type::ScalarType, solution_wrapper::SolutionWrapper, solve::Solve,
9};
10
11pub struct Ode {
12 pub(crate) solve: Box<dyn Solve>,
13 code: String,
14 scalar_type: ScalarType,
15 jit_backend: Option<JitBackendType>,
16 linear_solver: LinearSolverType,
17 ode_solver: OdeSolverType,
18}
19
// SAFETY: NOTE(review): `Ode` owns a `Box<dyn Solve>`. Unless `Solve` declares
// `Send`/`Sync` as supertraits (not visible here), that field is neither, so
// these impls assert thread-safety manually. Within this file every use of an
// `Ode` goes through the `Mutex` inside `OdeWrapper`, which serialises access
// to `solve`. Soundness still relies on each concrete `Solve` implementation
// (e.g. JIT-compiled code) not depending on thread-local or `Rc`-style shared
// state — TODO confirm for every backend.
// `Arc<Mutex<Ode>>` only needs `Ode: Send` to be `Send + Sync`; the `Sync`
// impl may therefore be unnecessary — verify no other module relies on it
// before removing.
unsafe impl Send for Ode {}
unsafe impl Sync for Ode {}

/// Cheaply clonable handle to a shared [`Ode`].
///
/// Clones share the same underlying `Ode` through the `Arc`; every access is
/// serialised by the inner `Mutex`.
#[derive(Clone)]
pub struct OdeWrapper(Arc<Mutex<Ode>>);
25
26impl OdeWrapper {
27 fn guard(&self) -> Result<std::sync::MutexGuard<'_, Ode>, DiffsolJsError> {
28 self.0.lock().map_err(|_| {
29 DiffsolJsError::from(diffsol::error::DiffsolError::Other(
30 "Failed to acquire lock on ODE solver".to_string(),
31 ))
32 })
33 }
34}
35
36impl OdeWrapper {
37 fn build(
38 code: String,
39 scalar_type: ScalarType,
40 solve: Box<dyn Solve>,
41 jit_backend: Option<JitBackendType>,
42 linear_solver: LinearSolverType,
43 ode_solver: OdeSolverType,
44 ) -> Result<Self, DiffsolJsError> {
45 solve.check(linear_solver)?;
46 Ok(OdeWrapper(Arc::new(Mutex::new(Ode {
47 code,
48 scalar_type,
49 solve,
50 jit_backend,
51 linear_solver,
52 ode_solver,
53 }))))
54 }
55
56 #[cfg(feature = "external")]
58 pub fn new_external(
59 rhs_state_deps: Vec<(usize, usize)>,
60 rhs_input_deps: Vec<(usize, usize)>,
61 mass_state_deps: Vec<(usize, usize)>,
62 scalar_type: ScalarType,
63 matrix_type: MatrixType,
64 linear_solver: LinearSolverType,
65 ode_solver: OdeSolverType,
66 ) -> Result<Self, DiffsolJsError> {
67 let solve = crate::solve::solve_factory_external(
68 rhs_state_deps,
69 rhs_input_deps,
70 mass_state_deps,
71 matrix_type,
72 scalar_type,
73 )?;
74 Self::build(
75 String::new(),
76 scalar_type,
77 solve,
78 None,
79 linear_solver,
80 ode_solver,
81 )
82 }
83
84 #[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
86 pub fn new_jit(
87 code: &str,
88 jit_backend: JitBackendType,
89 scalar_type: ScalarType,
90 matrix_type: MatrixType,
91 linear_solver: LinearSolverType,
92 ode_solver: OdeSolverType,
93 ) -> Result<Self, DiffsolJsError> {
94 let solve = crate::solve::solve_factory_jit(code, jit_backend, matrix_type, scalar_type)?;
95 Self::build(
96 code.to_owned(),
97 scalar_type,
98 solve,
99 Some(jit_backend),
100 linear_solver,
101 ode_solver,
102 )
103 }
104
105 pub fn get_matrix_type(&self) -> Result<MatrixType, DiffsolJsError> {
107 Ok(self.guard()?.solve.matrix_type())
108 }
109
110 pub fn get_ode_solver(&self) -> Result<OdeSolverType, DiffsolJsError> {
112 Ok(self.guard()?.ode_solver)
113 }
114
115 pub fn set_ode_solver(&self, value: OdeSolverType) -> Result<(), DiffsolJsError> {
116 self.guard()?.ode_solver = value;
117 Ok(())
118 }
119
120 pub fn get_linear_solver(&self) -> Result<LinearSolverType, DiffsolJsError> {
123 Ok(self.guard()?.linear_solver)
124 }
125
126 pub fn set_linear_solver(&self, value: LinearSolverType) -> Result<(), DiffsolJsError> {
127 self.guard()?.solve.check(value)?;
128 self.guard()?.linear_solver = value;
129 Ok(())
130 }
131
132 pub fn get_rtol(&self) -> Result<f64, DiffsolJsError> {
134 Ok(self.guard()?.solve.rtol())
135 }
136
137 pub fn set_rtol(&self, value: f64) -> Result<(), DiffsolJsError> {
138 self.guard()?.solve.set_rtol(value);
139 Ok(())
140 }
141
142 pub fn get_atol(&self) -> Result<f64, DiffsolJsError> {
144 Ok(self.guard()?.solve.atol())
145 }
146
147 pub fn set_atol(&self, value: f64) -> Result<(), DiffsolJsError> {
148 self.guard()?.solve.set_atol(value);
149 Ok(())
150 }
151
152 pub fn get_code(&self) -> Result<String, DiffsolJsError> {
153 Ok(self.guard()?.code.clone())
154 }
155
156 pub fn get_scalar_type(&self) -> Result<ScalarType, DiffsolJsError> {
157 Ok(self.guard()?.scalar_type)
158 }
159
160 pub fn get_jit_backend(&self) -> Result<Option<JitBackendType>, DiffsolJsError> {
161 Ok(self.guard()?.jit_backend)
162 }
163
164 pub fn get_ic_options(&self) -> InitialConditionSolverOptions {
165 InitialConditionSolverOptions::new(self.0.clone())
166 }
167
168 pub fn get_options(&self) -> OdeSolverOptions {
169 OdeSolverOptions::new(self.0.clone())
170 }
171
172 pub fn y0(&self, params: HostArray) -> Result<HostArray, DiffsolJsError> {
174 let mut self_guard = self.guard()?;
175 self_guard.solve.y0(params.as_slice()?)
176 }
177
178 pub fn rhs(
180 &self,
181 params: HostArray,
182 t: f64,
183 y: HostArray,
184 ) -> Result<HostArray, DiffsolJsError> {
185 let mut self_guard = self.guard()?;
186 self_guard.solve.rhs(params.as_slice()?, t, y.as_slice()?)
187 }
188
189 pub fn rhs_jac_mul(
191 &self,
192 params: HostArray,
193 t: f64,
194 y: HostArray,
195 v: HostArray,
196 ) -> Result<HostArray, DiffsolJsError> {
197 let mut self_guard = self.guard()?;
198 self_guard
199 .solve
200 .rhs_jac_mul(params.as_slice()?, t, y.as_slice()?, v.as_slice()?)
201 }
202
203 #[allow(clippy::type_complexity)]
221 pub fn solve(
222 &self,
223 params: HostArray,
224 final_time: f64,
225 ) -> Result<SolutionWrapper, DiffsolJsError> {
226 let mut self_guard = self.guard()?;
227 let params = params.as_slice()?;
228 let linear_solver = self_guard.linear_solver;
229 let method = self_guard.ode_solver;
230 let solution = self_guard
231 .solve
232 .solve(method, linear_solver, params, final_time)?;
233 Ok(SolutionWrapper::new(solution))
234 }
235
236 pub fn solve_hybrid(
239 &self,
240 params: HostArray,
241 final_time: f64,
242 ) -> Result<SolutionWrapper, DiffsolJsError> {
243 let mut self_guard = self.guard()?;
244 let params = params.as_slice()?;
245 let linear_solver = self_guard.linear_solver;
246 let method = self_guard.ode_solver;
247 let solution = self_guard
248 .solve
249 .solve_hybrid(method, linear_solver, params, final_time)?;
250 Ok(SolutionWrapper::new(solution))
251 }
252
253 pub fn solve_dense(
267 &self,
268 params: HostArray,
269 t_eval: HostArray,
270 ) -> Result<SolutionWrapper, DiffsolJsError> {
271 let mut self_guard = self.guard()?;
272 let params = params.as_slice()?;
273 let t_eval = t_eval.as_slice()?;
274 let linear_solver = self_guard.linear_solver;
275 let method = self_guard.ode_solver;
276 let solution = self_guard
277 .solve
278 .solve_dense(method, linear_solver, params, t_eval)?;
279 Ok(SolutionWrapper::new(solution))
280 }
281
282 pub fn solve_hybrid_dense(
286 &self,
287 params: HostArray,
288 t_eval: HostArray,
289 ) -> Result<SolutionWrapper, DiffsolJsError> {
290 let mut self_guard = self.guard()?;
291 let params = params.as_slice()?;
292 let t_eval = t_eval.as_slice()?;
293 let linear_solver = self_guard.linear_solver;
294 let method = self_guard.ode_solver;
295 let solution =
296 self_guard
297 .solve
298 .solve_hybrid_dense(method, linear_solver, params, t_eval)?;
299 Ok(SolutionWrapper::new(solution))
300 }
301
302 #[allow(clippy::type_complexity)]
315 pub fn solve_fwd_sens(
316 &self,
317 params: HostArray,
318 t_eval: HostArray,
319 ) -> Result<SolutionWrapper, DiffsolJsError> {
320 let mut self_guard = self.guard()?;
321 let params = params.as_slice()?;
322 let t_eval = t_eval.as_slice()?;
323 let linear_solver = self_guard.linear_solver;
324 let method = self_guard.ode_solver;
325 let solution = self_guard
326 .solve
327 .solve_fwd_sens(method, linear_solver, params, t_eval)?;
328 Ok(SolutionWrapper::new(solution))
329 }
330
331 #[allow(clippy::type_complexity)]
335 pub fn solve_hybrid_fwd_sens(
336 &self,
337 params: HostArray,
338 t_eval: HostArray,
339 ) -> Result<SolutionWrapper, DiffsolJsError> {
340 let mut self_guard = self.guard()?;
341 let params = params.as_slice()?;
342 let t_eval = t_eval.as_slice()?;
343 let linear_solver = self_guard.linear_solver;
344 let method = self_guard.ode_solver;
345 let solution =
346 self_guard
347 .solve
348 .solve_hybrid_fwd_sens(method, linear_solver, params, t_eval)?;
349 Ok(SolutionWrapper::new(solution))
350 }
351
352 #[allow(clippy::type_complexity)]
357 pub(crate) fn solve_sum_squares_adj(
358 &self,
359 params: HostArray,
360 data: HostArray,
361 t_eval: HostArray,
362 ) -> Result<(f64, HostArray), DiffsolJsError> {
363 let mut self_guard = self.guard()?;
364 let linear_solver = self_guard.linear_solver;
365 let ode_solver = self_guard.ode_solver;
366
367 self_guard.solve.solve_sum_squares_adj(
368 ode_solver,
369 linear_solver,
370 ode_solver,
371 linear_solver,
372 params.as_slice()?,
373 data,
374 t_eval.as_slice()?,
375 )
376 }
377}
378
#[cfg(all(test, feature = "diffsl-external-f64"))]
mod tests {
    use crate::host_array::FromHostArray;
    use crate::linear_solver_type::LinearSolverType;
    use crate::scalar_type::ScalarType;
    use crate::test_support::{
        ASSERT_TOL, LOGISTIC_X0, assert_close, assert_solution_tail, logistic_integral,
        logistic_state, logistic_state_dr, mass_state_deps, rhs_input_deps, rhs_state_deps,
        vector_host,
    };

    use super::*;

    /// Growth-rate parameter passed to the external logistic model in every test.
    const R: f64 = 2.0;
    /// `rtol` / `atol` applied by [`make_tight_ode`].
    const TIGHT_TOL: f64 = 1e-8;
    /// Tolerance for comparing solver output against closed-form values.
    const SOLVE_TOL: f64 = 5e-4;

    /// Builds an `OdeWrapper` over the external model from `test_support`,
    /// using `f64` scalars and the default linear solver.
    fn make_ode(matrix_type: MatrixType, ode_solver: OdeSolverType) -> OdeWrapper {
        OdeWrapper::new_external(
            rhs_state_deps(),
            rhs_input_deps(),
            mass_state_deps(),
            ScalarType::F64,
            matrix_type,
            LinearSolverType::Default,
            ode_solver,
        )
        .unwrap()
    }

    /// Same as [`make_ode`] but with both tolerances set to [`TIGHT_TOL`].
    fn make_tight_ode(matrix_type: MatrixType, ode_solver: OdeSolverType) -> OdeWrapper {
        let ode = make_ode(matrix_type, ode_solver);
        ode.set_rtol(TIGHT_TOL).unwrap();
        ode.set_atol(TIGHT_TOL).unwrap();
        ode
    }

    /// Checks that `matrix_type` is honoured and that `y0`, `rhs` and
    /// `rhs_jac_mul` evaluate the model pointwise.
    fn assert_runtime_dispatch(matrix_type: MatrixType) {
        let ode = make_ode(matrix_type, OdeSolverType::Bdf);
        assert_eq!(ode.get_matrix_type().unwrap(), matrix_type);

        let y0 = ode.y0(vector_host(&[R])).unwrap();
        assert_eq!(Vec::<f64>::from_host_array(y0).unwrap(), vec![LOGISTIC_X0]);

        // Expected values assume a logistic rhs r*y*(1-y): 2 * 0.25 * 0.75 = 0.375.
        let rhs = ode
            .rhs(vector_host(&[R]), 0.0, vector_host(&[0.25]))
            .unwrap();
        assert_close(
            Vec::<f64>::from_host_array(rhs).unwrap()[0],
            0.375,
            ASSERT_TOL,
            "rhs(0.25)",
        );

        // ...and its Jacobian-vector product r*(1-2y)*v: 2 * 0.5 * 3 = 3.
        let rhs_jac_mul = ode
            .rhs_jac_mul(
                vector_host(&[R]),
                0.0,
                vector_host(&[0.25]),
                vector_host(&[3.0]),
            )
            .unwrap();
        assert_close(
            Vec::<f64>::from_host_array(rhs_jac_mul).unwrap()[0],
            3.0,
            ASSERT_TOL,
            "rhs_jac_mul(0.25, 3.0)",
        );
    }

    /// Solves densely with the given matrix type / method and compares against
    /// the closed-form logistic solution.
    fn assert_solver_dense_solution(matrix_type: MatrixType, ode_solver: OdeSolverType) {
        let ode = make_tight_ode(matrix_type, ode_solver);

        let t_eval = [0.25, 0.5, 1.0];
        let solution = ode
            .solve_dense(vector_host(&[R]), vector_host(&t_eval))
            .unwrap();

        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, R, SOLVE_TOL);
    }

    /// Time at which the external model's root fires and its reset is applied.
    ///
    /// NOTE(review): 0.5 * ln 9 (= ln 3, ~1.0986) is hard-coded here and must
    /// agree with the root function of the external model in `test_support`.
    fn hybrid_root_time() -> f64 {
        0.5 * 9.0_f64.ln()
    }

    #[test]
    fn runtime_dispatch_matches_requested_matrix_type() {
        for matrix_type in [
            MatrixType::NalgebraDense,
            MatrixType::FaerDense,
            MatrixType::FaerSparse,
        ] {
            assert_runtime_dispatch(matrix_type);
        }
    }

    #[test]
    fn bdf_dense_solution_matches_logistic_solution() {
        assert_solver_dense_solution(MatrixType::NalgebraDense, OdeSolverType::Bdf);
    }

    #[test]
    fn esdirk34_dense_solution_matches_logistic_solution() {
        assert_solver_dense_solution(MatrixType::FaerDense, OdeSolverType::Esdirk34);
    }

    #[test]
    fn tr_bdf2_sparse_solution_matches_logistic_solution() {
        assert_solver_dense_solution(MatrixType::FaerSparse, OdeSolverType::TrBdf2);
    }

    #[test]
    fn tsit45_dense_solution_matches_logistic_solution() {
        assert_solver_dense_solution(MatrixType::NalgebraDense, OdeSolverType::Tsit45);
    }

    #[test]
    fn bdf_forward_sensitivities_match_logistic_derivative() {
        let ode = make_tight_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);

        let t_eval = [0.25, 0.5, 1.0];
        let solution = ode
            .solve_fwd_sens(vector_host(&[R]), vector_host(&t_eval))
            .unwrap();

        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, R, SOLVE_TOL);
        let sens = solution.get_sens().unwrap();
        assert_eq!(sens.len(), 1);
        let sens_values = sens[0].as_array::<f64>().unwrap();
        assert_eq!(sens_values.nrows(), 1);
        assert_eq!(sens_values.ncols(), t_eval.len());
        for (i, &t) in t_eval.iter().enumerate() {
            assert_close(
                sens_values[(0, i)],
                logistic_state_dr(LOGISTIC_X0, R, t),
                ASSERT_TOL,
                &format!("sensitivity[{i}]"),
            );
        }
    }

    #[test]
    fn bdf_sum_squares_adjoint_matches_external_logistic_model() {
        let ode = make_tight_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);

        // Data is generated from `logistic_integral`, which this test expects
        // to match the model output, so objective and gradient should vanish.
        let t_eval = [0.0, 0.25, 0.5, 1.0];
        let data_values: Vec<f64> = t_eval
            .iter()
            .map(|&t| logistic_integral(LOGISTIC_X0, R, t))
            .collect();
        let data = crate::test_support::matrix_host(1, t_eval.len(), &data_values);
        let (value, sens) = ode
            .solve_sum_squares_adj(vector_host(&[R]), data, vector_host(&t_eval))
            .unwrap();
        let grad = Vec::<f64>::from_host_array(sens).unwrap();

        assert_close(value, 0.0, ASSERT_TOL, "sum_squares objective");
        assert_eq!(grad.len(), 1);
        assert_close(grad[0], 0.0, ASSERT_TOL, "sum_squares gradient");
    }

    #[test]
    fn bdf_hybrid_solution_applies_reset_after_root() {
        let ode = make_tight_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);

        let final_time = 2.0;
        let solution = ode.solve_hybrid(vector_host(&[R]), final_time).unwrap();
        let ys = solution.get_ys().unwrap();
        let ys = ys.as_array::<f64>().unwrap();
        let ts = Vec::<f64>::from_host_array(solution.get_ts().unwrap()).unwrap();
        let root_time = hybrid_root_time();

        assert_eq!(ys.nrows(), 1);
        assert_eq!(ys.ncols(), ts.len());
        assert!(!ts.is_empty(), "expected hybrid solve to produce output");
        assert_close(
            *ts.last().unwrap(),
            final_time,
            ASSERT_TOL,
            "hybrid final time",
        );
        assert_close(ys[(0, ys.ncols() - 1)], 1.0, SOLVE_TOL, "hybrid final value");
        assert!(
            ts.iter().any(|&t| t < root_time),
            "expected pre-root samples"
        );
        assert!(
            ts.iter().any(|&t| t > root_time),
            "expected post-root samples after reset"
        );
    }

    #[test]
    fn bdf_hybrid_dense_solution_continues_after_reset() {
        let ode = make_tight_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);

        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
        // The checks below treat columns 0-1 as pre-root and 2.. as post-root.
        let root_time = hybrid_root_time();
        assert!(t_eval[1] < root_time && root_time < t_eval[2]);

        let solution = ode
            .solve_hybrid_dense(vector_host(&[R]), vector_host(&t_eval))
            .unwrap();
        let ys = solution.get_ys().unwrap();
        let ys = ys.as_array::<f64>().unwrap();

        assert_eq!(ys.nrows(), 1);
        assert_eq!(ys.ncols(), t_eval.len());
        assert_close(
            ys[(0, 0)],
            logistic_state(LOGISTIC_X0, R, t_eval[0]),
            SOLVE_TOL,
            "hybrid dense pre-root value",
        );
        assert_close(
            ys[(0, 1)],
            logistic_state(LOGISTIC_X0, R, t_eval[1]),
            SOLVE_TOL,
            "hybrid dense near-root value",
        );
        for col in 2..t_eval.len() {
            assert_close(ys[(0, col)], 1.0, SOLVE_TOL, "hybrid dense post-root value");
        }
    }

    #[test]
    fn bdf_hybrid_forward_sensitivities_complete_across_reset() {
        let ode = make_tight_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);

        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
        // The checks below treat column 0 as pre-root and 2.. as post-root.
        let root_time = hybrid_root_time();
        assert!(t_eval[1] < root_time && root_time < t_eval[2]);

        let solution = ode
            .solve_hybrid_fwd_sens(vector_host(&[R]), vector_host(&t_eval))
            .unwrap();
        let ys = solution.get_ys().unwrap();
        let ys = ys.as_array::<f64>().unwrap();
        let sens = solution.get_sens().unwrap();

        assert_eq!(ys.nrows(), 1);
        assert_eq!(ys.ncols(), t_eval.len());
        assert_eq!(sens.len(), 1);
        let sens_values = sens[0].as_array::<f64>().unwrap();
        assert_eq!(sens_values.nrows(), 1);
        assert_eq!(sens_values.ncols(), t_eval.len());
        assert_close(
            ys[(0, 0)],
            logistic_state(LOGISTIC_X0, R, t_eval[0]),
            SOLVE_TOL,
            "hybrid sens pre-root value",
        );
        for col in 2..t_eval.len() {
            assert_close(ys[(0, col)], 1.0, SOLVE_TOL, "hybrid sens post-root value");
            assert!(
                sens_values[(0, col)].is_finite(),
                "expected finite post-root sensitivity at column {col}"
            );
        }
    }
}
644
#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
mod jit_tests {
    use crate::host_array::FromHostArray;
    use crate::jit::JitBackendType;
    use crate::linear_solver_type::LinearSolverType;
    use crate::scalar_type::ScalarType;
    use crate::test_support::{
        ASSERT_TOL, LOGISTIC_X0, assert_close, assert_solution_tail, available_jit_backends,
        hybrid_logistic_diffsl_code, hybrid_logistic_period, hybrid_logistic_state,
        logistic_diffsl_code, logistic_state, vector_host,
    };
    // Only the LLVM-gated tests at the bottom of this module use these helpers;
    // importing them unconditionally triggers `unused_imports` in
    // Cranelift-only builds.
    #[cfg(feature = "diffsl-llvm")]
    use crate::test_support::{hybrid_logistic_state_dr, logistic_integral, logistic_state_dr};

    use super::*;

    /// Growth-rate parameter passed to the DiffSL models in every test.
    const R: f64 = 2.0;
    /// `rtol` / `atol` applied by [`with_tight_tolerances`].
    const TIGHT_TOL: f64 = 1e-8;
    /// Tolerance for comparing solver output against closed-form values.
    const SOLVE_TOL: f64 = 5e-4;

    /// JIT-compiles `code` into an `f64` ODE using the default linear solver.
    fn make_ode_from_code(
        code: &str,
        jit_backend: JitBackendType,
        matrix_type: MatrixType,
        ode_solver: OdeSolverType,
    ) -> OdeWrapper {
        OdeWrapper::new_jit(
            code,
            jit_backend,
            ScalarType::F64,
            matrix_type,
            LinearSolverType::Default,
            ode_solver,
        )
        .unwrap()
    }

    /// ODE for the plain logistic DiffSL model.
    fn make_ode(
        jit_backend: JitBackendType,
        matrix_type: MatrixType,
        ode_solver: OdeSolverType,
    ) -> OdeWrapper {
        make_ode_from_code(logistic_diffsl_code(), jit_backend, matrix_type, ode_solver)
    }

    /// ODE for the hybrid (root + reset) logistic DiffSL model.
    fn make_hybrid_ode(
        jit_backend: JitBackendType,
        matrix_type: MatrixType,
        ode_solver: OdeSolverType,
    ) -> OdeWrapper {
        make_ode_from_code(
            hybrid_logistic_diffsl_code(),
            jit_backend,
            matrix_type,
            ode_solver,
        )
    }

    /// Sets both tolerances to [`TIGHT_TOL`] and hands the ODE back.
    fn with_tight_tolerances(ode: OdeWrapper) -> OdeWrapper {
        ode.set_rtol(TIGHT_TOL).unwrap();
        ode.set_atol(TIGHT_TOL).unwrap();
        ode
    }

    /// Checks that `matrix_type` and the source code are recorded and that
    /// `y0`, `rhs` and `rhs_jac_mul` evaluate the compiled model pointwise.
    fn assert_runtime_dispatch(jit_backend: JitBackendType, matrix_type: MatrixType) {
        let ode = make_ode(jit_backend, matrix_type, OdeSolverType::Bdf);
        assert_eq!(ode.get_matrix_type().unwrap(), matrix_type);
        assert_eq!(ode.get_code().unwrap(), logistic_diffsl_code());

        let y0 = ode.y0(vector_host(&[R])).unwrap();
        assert_eq!(Vec::<f64>::from_host_array(y0).unwrap(), vec![LOGISTIC_X0]);

        // Expected values assume a logistic rhs r*y*(1-y): 2 * 0.25 * 0.75 = 0.375.
        let rhs = ode
            .rhs(vector_host(&[R]), 0.0, vector_host(&[0.25]))
            .unwrap();
        assert_close(
            Vec::<f64>::from_host_array(rhs).unwrap()[0],
            0.375,
            ASSERT_TOL,
            "jit rhs(0.25)",
        );

        // ...and its Jacobian-vector product r*(1-2y)*v: 2 * 0.5 * 3 = 3.
        let rhs_jac_mul = ode
            .rhs_jac_mul(
                vector_host(&[R]),
                0.0,
                vector_host(&[0.25]),
                vector_host(&[3.0]),
            )
            .unwrap();
        assert_close(
            Vec::<f64>::from_host_array(rhs_jac_mul).unwrap()[0],
            3.0,
            ASSERT_TOL,
            "jit rhs_jac_mul(0.25, 3.0)",
        );
    }

    /// Solves densely with the given backend / matrix type / method and compares
    /// against the closed-form logistic solution.
    fn assert_solver_dense_solution(
        jit_backend: JitBackendType,
        matrix_type: MatrixType,
        ode_solver: OdeSolverType,
    ) {
        let ode = with_tight_tolerances(make_ode(jit_backend, matrix_type, ode_solver));

        let t_eval = [0.25, 0.5, 1.0];
        let solution = ode
            .solve_dense(vector_host(&[R]), vector_host(&t_eval))
            .unwrap();

        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, R, SOLVE_TOL);
    }

    /// Evaluation times used by the hybrid-model tests.
    fn hybrid_t_eval() -> [f64; 7] {
        [0.5, 1.0, 2.0, 2.5, 3.0, 4.0, 4.8]
    }

    #[test]
    fn runtime_dispatch_matches_requested_matrix_type_from_diffsl() {
        for jit_backend in available_jit_backends() {
            for matrix_type in [
                MatrixType::NalgebraDense,
                MatrixType::FaerDense,
                MatrixType::FaerSparse,
            ] {
                assert_runtime_dispatch(jit_backend, matrix_type);
            }
        }
    }

    #[test]
    fn dense_solution_matches_logistic_solution_from_diffsl() {
        for jit_backend in available_jit_backends() {
            for (matrix_type, solver) in [
                (MatrixType::FaerDense, OdeSolverType::Esdirk34),
                (MatrixType::FaerSparse, OdeSolverType::TrBdf2),
                (MatrixType::NalgebraDense, OdeSolverType::Tsit45),
            ] {
                assert_solver_dense_solution(jit_backend, matrix_type, solver);
            }
        }
    }

    #[test]
    fn bdf_dense_solution_matches_logistic_diffsl_model() {
        for jit_backend in available_jit_backends() {
            assert_solver_dense_solution(
                jit_backend,
                MatrixType::NalgebraDense,
                OdeSolverType::Bdf,
            );
        }
    }

    #[test]
    fn bdf_solution_matches_logistic_diffsl_model() {
        let final_time = 1.0;
        for jit_backend in available_jit_backends() {
            let ode = with_tight_tolerances(make_ode(
                jit_backend,
                MatrixType::NalgebraDense,
                OdeSolverType::Bdf,
            ));

            let solution = ode.solve(vector_host(&[R]), final_time).unwrap();
            let ys = solution.get_ys().unwrap();
            let ys = ys.as_array::<f64>().unwrap();
            let ts = Vec::<f64>::from_host_array(solution.get_ts().unwrap()).unwrap();

            assert_eq!(ys.nrows(), 1);
            assert_eq!(ys.ncols(), ts.len());
            assert!(
                !ts.is_empty(),
                "expected solve() to record at least one time point"
            );
            assert_close(
                *ts.last().unwrap(),
                final_time,
                ASSERT_TOL,
                "solve final time",
            );
            for (i, &t) in ts.iter().enumerate() {
                assert_close(
                    ys[(0, i)],
                    logistic_state(LOGISTIC_X0, R, t),
                    SOLVE_TOL,
                    &format!("solve value[{i}]"),
                );
            }
        }
    }

    #[test]
    fn hybrid_solution_matches_piecewise_logistic_diffsl_model() {
        let final_time = 5.0;
        let tau = hybrid_logistic_period(R);
        for jit_backend in available_jit_backends() {
            let ode = with_tight_tolerances(make_hybrid_ode(
                jit_backend,
                MatrixType::NalgebraDense,
                OdeSolverType::Bdf,
            ));

            let solution = ode.solve_hybrid(vector_host(&[R]), final_time).unwrap();
            let ys = solution.get_ys().unwrap();
            let ys = ys.as_array::<f64>().unwrap();
            let ts = Vec::<f64>::from_host_array(solution.get_ts().unwrap()).unwrap();

            assert_eq!(ys.nrows(), 1);
            assert_eq!(ys.ncols(), ts.len());
            assert!(!ts.is_empty(), "expected hybrid solve to produce output");
            assert_close(
                *ts.last().unwrap(),
                final_time,
                ASSERT_TOL,
                "jit hybrid final time",
            );
            assert_close(
                ys[(0, ys.ncols() - 1)],
                hybrid_logistic_state(R, final_time),
                SOLVE_TOL,
                "jit hybrid final value",
            );
            assert!(
                ts.iter().any(|&t| (t - tau).abs() < 1e-3),
                "expected an output near the first reset time"
            );
            assert!(
                ts.iter().any(|&t| (t - 2.0 * tau).abs() < 1e-3),
                "expected an output near the second reset time"
            );
            for (col, &t) in ts.iter().enumerate() {
                // Outputs at a multiple of `tau` sit on a reset; presumably they
                // may hold either the pre- or post-reset value, so skip them.
                if ((t / tau).round() * tau - t).abs() < 1e-3 {
                    continue;
                }
                assert_close(
                    ys[(0, col)],
                    hybrid_logistic_state(R, t),
                    SOLVE_TOL,
                    &format!("jit hybrid value[{col}]"),
                );
            }
        }
    }

    #[test]
    fn hybrid_dense_solution_matches_piecewise_logistic_diffsl_model() {
        let t_eval = hybrid_t_eval();
        for jit_backend in available_jit_backends() {
            let ode = with_tight_tolerances(make_hybrid_ode(
                jit_backend,
                MatrixType::NalgebraDense,
                OdeSolverType::Bdf,
            ));

            let solution = ode
                .solve_hybrid_dense(vector_host(&[R]), vector_host(&t_eval))
                .unwrap();
            let ys = solution.get_ys().unwrap();
            let ys = ys.as_array::<f64>().unwrap();
            let ts = Vec::<f64>::from_host_array(solution.get_ts().unwrap()).unwrap();

            assert_eq!(ys.nrows(), 1);
            assert_eq!(ys.ncols(), t_eval.len());
            assert_eq!(ts, t_eval);
            for (col, &t) in t_eval.iter().enumerate() {
                assert_close(
                    ys[(0, col)],
                    hybrid_logistic_state(R, t),
                    SOLVE_TOL,
                    &format!("jit hybrid dense value[{col}]"),
                );
            }
        }
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn bdf_forward_sensitivities_match_logistic_derivative_from_diffsl() {
        let ode = with_tight_tolerances(make_ode(
            JitBackendType::Llvm,
            MatrixType::NalgebraDense,
            OdeSolverType::Bdf,
        ));

        let t_eval = [0.25, 0.5, 1.0];
        let solution = ode
            .solve_fwd_sens(vector_host(&[R]), vector_host(&t_eval))
            .unwrap();

        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, R, SOLVE_TOL);
        let sens = solution.get_sens().unwrap();
        assert_eq!(sens.len(), 1);
        let sens_values = sens[0].as_array::<f64>().unwrap();
        assert_eq!(sens_values.nrows(), 1);
        assert_eq!(sens_values.ncols(), t_eval.len());
        for (i, &t) in t_eval.iter().enumerate() {
            assert_close(
                sens_values[(0, i)],
                logistic_state_dr(LOGISTIC_X0, R, t),
                ASSERT_TOL,
                &format!("jit sensitivity[{i}]"),
            );
        }
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn hybrid_forward_sensitivities_match_piecewise_logistic_diffsl_model() {
        let t_eval = hybrid_t_eval();
        let ode = with_tight_tolerances(make_hybrid_ode(
            JitBackendType::Llvm,
            MatrixType::NalgebraDense,
            OdeSolverType::Bdf,
        ));

        let solution = ode
            .solve_hybrid_fwd_sens(vector_host(&[R]), vector_host(&t_eval))
            .unwrap();
        let ys = solution.get_ys().unwrap();
        let ys = ys.as_array::<f64>().unwrap();
        let sens = solution.get_sens().unwrap();

        assert_eq!(ys.nrows(), 1);
        assert_eq!(ys.ncols(), t_eval.len());
        assert_eq!(sens.len(), 1);
        let sens_values = sens[0].as_array::<f64>().unwrap();
        assert_eq!(sens_values.nrows(), 1);
        assert_eq!(sens_values.ncols(), t_eval.len());
        for (col, &t) in t_eval.iter().enumerate() {
            assert_close(
                ys[(0, col)],
                hybrid_logistic_state(R, t),
                SOLVE_TOL,
                &format!("jit hybrid sens value[{col}]"),
            );
            assert_close(
                sens_values[(0, col)],
                hybrid_logistic_state_dr(R, t),
                SOLVE_TOL,
                &format!("jit hybrid sensitivity[{col}]"),
            );
        }
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn bdf_sum_squares_adjoint_matches_logistic_diffsl_model() {
        let ode = with_tight_tolerances(make_ode(
            JitBackendType::Llvm,
            MatrixType::NalgebraDense,
            OdeSolverType::Bdf,
        ));

        // Data is generated from `logistic_integral`, which this test expects
        // to match the model output, so the objective should vanish.
        let t_eval = [0.0, 0.25, 0.5, 1.0];
        let data_values: Vec<f64> = t_eval
            .iter()
            .map(|&t| logistic_integral(LOGISTIC_X0, R, t))
            .collect();
        let data = crate::test_support::matrix_host(1, t_eval.len(), &data_values);
        let (value, sens) = ode
            .solve_sum_squares_adj(vector_host(&[R]), data, vector_host(&t_eval))
            .unwrap();
        let grad = Vec::<f64>::from_host_array(sens).unwrap();

        assert_close(value, 0.0, ASSERT_TOL, "jit sum_squares objective");
        assert_eq!(grad.len(), 1);
        assert!(
            grad[0].is_finite(),
            "jit sum_squares gradient should be finite"
        );
    }
}