1#[cfg(feature = "diffsl-external-dynamic")]
5use std::path::PathBuf;
6
7use crate::adjoint_checkpoint::AdjointCheckpointWrapper;
8use crate::error::DiffsolRtError;
9use crate::host_array::HostArray;
10use crate::host_array::ToHostArray;
11#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
12use crate::jit::JitBackendType;
13#[cfg(feature = "external")]
14use crate::scalar_type::ExternalScalar;
15use crate::scalar_type::Scalar;
16use crate::scalar_type::ScalarType;
17use crate::{
18 generate_ic_option_accessors, generate_ode_option_accessors, generate_option_accessors,
19 option_value_from_store, option_value_to_store,
20};
21use crate::{generate_trait_ic_option_accessors, generate_trait_ode_option_accessors};
22use diffsol::ObjectModule;
23use diffsol::OdeBuilder;
24use diffsol::{
25 error::DiffsolError,
26 matrix::{MatrixHost, MatrixRef},
27 CodegenModule, ConstantOp, DefaultDenseMatrix, DefaultSolver, DenseMatrix, DiffSl,
28 MatrixCommon, NonLinearOp, NonLinearOpJacobian, OdeEquations, OdeSolverProblem, Op, Vector,
29 VectorCommon, VectorHost, VectorRef,
30};
31#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
32use diffsol::{CodegenModuleCompile, CodegenModuleJit};
33use num_traits::{FromPrimitive, ToPrimitive}; use paste::paste;
35
36use crate::solve_serialization::{unsupported_serialization_error, SolveSerialization};
37use crate::{
38 linear_solver_type::LinearSolverType, matrix_type::MatrixType, ode_solver_type::OdeSolverType,
39};
40use crate::{
41 matrix_type::MatrixKind,
42 valid_linear_solver::{validate_linear_solver, KluValidator, LuValidator},
43};
44
45use crate::solution::Solution;
48pub(crate) type SolveResult = Result<Box<dyn Solution>, DiffsolRtError>;
49
50pub(crate) trait Solve {
51 fn matrix_type(&self) -> MatrixType;
52 fn nstates(&self) -> usize;
53 fn nparams(&self) -> usize;
54 fn nout(&self) -> usize;
55 fn has_stop(&self) -> bool;
56
57 fn rhs(&mut self, params: &[f64], t: f64, y: &[f64]) -> Result<HostArray, DiffsolRtError>;
58
59 fn rhs_jac_mul(
60 &mut self,
61 params: &[f64],
62 t: f64,
63 y: &[f64],
64 v: &[f64],
65 ) -> Result<HostArray, DiffsolRtError>;
66
67 fn y0(&mut self, params: &[f64]) -> Result<HostArray, DiffsolRtError>;
68
69 fn check(&self, linear_solver: LinearSolverType) -> Result<(), DiffsolRtError>;
70 fn set_rtol(&mut self, rtol: f64);
71 fn rtol(&self) -> f64;
72 fn set_atol(&mut self, atol: f64);
73 fn atol(&self) -> f64;
74 fn set_t0(&mut self, t0: f64);
75 fn t0(&self) -> f64;
76 fn set_h0(&mut self, h0: f64);
77 fn h0(&self) -> f64;
78 fn set_integrate_out(&mut self, integrate_out: bool);
79 fn integrate_out(&self) -> bool;
80 fn set_sens_rtol(&mut self, sens_rtol: Option<f64>);
81 fn sens_rtol(&self) -> Option<f64>;
82 fn set_sens_atol(&mut self, sens_atol: Option<f64>);
83 fn sens_atol(&self) -> Option<f64>;
84 fn set_out_rtol(&mut self, out_rtol: Option<f64>);
85 fn out_rtol(&self) -> Option<f64>;
86 fn set_out_atol(&mut self, out_atol: Option<f64>);
87 fn out_atol(&self) -> Option<f64>;
88 fn set_param_rtol(&mut self, param_rtol: Option<f64>);
89 fn param_rtol(&self) -> Option<f64>;
90 fn set_param_atol(&mut self, param_atol: Option<f64>);
91 fn param_atol(&self) -> Option<f64>;
92 fn serialized_diffsl(&self) -> Result<Vec<u8>, DiffsolRtError> {
93 unsupported_serialization_error(
94 "ODE serialization is only supported for JIT-backed solvers",
95 )
96 }
97
98 fn solve(
100 &mut self,
101 method: OdeSolverType,
102 linear_solver: LinearSolverType,
103 params: &[f64],
104 final_time: f64,
105 ) -> SolveResult;
106
107 fn solve_dense(
108 &mut self,
109 method: OdeSolverType,
110 linear_solver: LinearSolverType,
111 params: &[f64],
112 t_eval: &[f64],
113 ) -> SolveResult;
114
115 fn solve_fwd_sens(
116 &mut self,
117 method: OdeSolverType,
118 linear_solver: LinearSolverType,
119 params: &[f64],
120 t_eval: &[f64],
121 ) -> SolveResult;
122
123 fn solve_continuous_adjoint(
124 &mut self,
125 method: OdeSolverType,
126 linear_solver: LinearSolverType,
127 params: &[f64],
128 final_time: f64,
129 ) -> Result<(HostArray, HostArray), DiffsolRtError>;
130
131 fn solve_adjoint_fwd(
132 &mut self,
133 method: OdeSolverType,
134 linear_solver: LinearSolverType,
135 params: &[f64],
136 t_eval: &[f64],
137 ) -> Result<(Box<dyn Solution>, AdjointCheckpointWrapper), DiffsolRtError>;
138
139 fn solve_adjoint_bkwd(
140 &mut self,
141 method: OdeSolverType,
142 linear_solver: LinearSolverType,
143 checkpoint: &AdjointCheckpointWrapper,
144 t_eval: &[f64],
145 dgdu_eval: HostArray,
146 ) -> Result<HostArray, DiffsolRtError>;
147
148 generate_trait_ic_option_accessors! {
149 use_linesearch: bool,
150 max_linesearch_iterations: usize,
151 max_newton_iterations: usize,
152 max_linear_solver_setups: usize,
153 step_reduction_factor: f64,
154 armijo_constant: f64
155 }
156 generate_trait_ode_option_accessors! {
157 max_nonlinear_solver_iterations: usize,
158 max_error_test_failures: usize,
159 min_timestep: f64,
160 update_jacobian_after_steps: usize,
161 update_rhs_jacobian_after_steps: usize,
162 threshold_to_update_jacobian: f64,
163 threshold_to_update_rhs_jacobian: f64
164 }
165}
/// Builds a [`Solve`] backed by a statically linked external DiffSL module
/// ([`diffsl::ExternalModule`]) for the requested matrix backend and scalar type.
///
/// `F32` / `F64` support is compiled in only with the `diffsl-external-f32` /
/// `diffsl-external-f64` features respectively.
///
/// # Errors
/// Returns an error if:
/// - the requested scalar type was not compiled in, or
/// - constructing the external equations or the problem fails.
#[cfg(feature = "external")]
pub(crate) fn solve_factory_external(
    rhs_state_deps: Vec<(usize, usize)>,
    rhs_input_deps: Vec<(usize, usize)>,
    mass_state_deps: Vec<(usize, usize)>,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError> {
    // The `_` fallback arms below are reachable only when one of the scalar features
    // is disabled. With both enabled, the arms above them are already exhaustive.
    // `unreachable_patterns` is allowed there so the build stays warning-free.
    let solve: Box<dyn Solve> = match matrix_type {
        MatrixType::NalgebraDense => match scalar_type {
            #[cfg(feature = "diffsl-external-f32")]
            ScalarType::F32 => Box::new(GenericSolve::<
                diffsol::NalgebraMat<f32>,
                diffsl::ExternalModule<f32>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[cfg(feature = "diffsl-external-f64")]
            ScalarType::F64 => Box::new(GenericSolve::<
                diffsol::NalgebraMat<f64>,
                diffsl::ExternalModule<f64>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[allow(unreachable_patterns)]
            _ => return Err(unsupported_external_scalar("NalgebraDense")),
        },
        MatrixType::FaerDense => match scalar_type {
            #[cfg(feature = "diffsl-external-f32")]
            ScalarType::F32 => Box::new(GenericSolve::<
                diffsol::FaerMat<f32>,
                diffsl::ExternalModule<f32>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[cfg(feature = "diffsl-external-f64")]
            ScalarType::F64 => Box::new(GenericSolve::<
                diffsol::FaerMat<f64>,
                diffsl::ExternalModule<f64>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[allow(unreachable_patterns)]
            _ => return Err(unsupported_external_scalar("FaerDense")),
        },
        MatrixType::FaerSparse => match scalar_type {
            #[cfg(feature = "diffsl-external-f32")]
            ScalarType::F32 => Box::new(GenericSolve::<
                diffsol::FaerSparseMat<f32>,
                diffsl::ExternalModule<f32>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[cfg(feature = "diffsl-external-f64")]
            ScalarType::F64 => Box::new(GenericSolve::<
                diffsol::FaerSparseMat<f64>,
                diffsl::ExternalModule<f64>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[allow(unreachable_patterns)]
            _ => return Err(unsupported_external_scalar("FaerSparse")),
        },
    };
    Ok(solve)
}

/// Error reported when the requested scalar type has no external-module support
/// compiled in for the matrix backend named `matrix`.
#[cfg(feature = "external")]
fn unsupported_external_scalar(matrix: &str) -> DiffsolRtError {
    DiffsolRtError::from(DiffsolError::Other(format!(
        "Unsupported scalar type for {matrix}"
    )))
}
242
/// Builds a [`Solve`] backed by an external DiffSL module loaded at runtime from
/// `path` ([`diffsl::ExternalDynModule`]).
///
/// The module is built for the requested matrix backend and scalar type.
///
/// # Errors
/// Propagates any error from constructing the external equations or the problem.
#[cfg(feature = "diffsl-external-dynamic")]
pub(crate) fn solve_factory_external_dynamic(
    path: PathBuf,
    rhs_state_deps: Vec<(usize, usize)>,
    rhs_input_deps: Vec<(usize, usize)>,
    mass_state_deps: Vec<(usize, usize)>,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError> {
    // Every backend/scalar pair is constructed identically; only the type parameters
    // change. The macro stamps out the call and erases the concrete type.
    macro_rules! dynamic_solve {
        ($mat:ty, $scalar:ty) => {{
            let built = GenericSolve::<$mat, diffsl::ExternalDynModule<$scalar>>::from_external_dynamic(
                path,
                rhs_state_deps,
                rhs_input_deps,
                mass_state_deps,
                false,
            )?;
            Box::new(built) as Box<dyn Solve>
        }};
    }

    let solve = match (matrix_type, scalar_type) {
        (MatrixType::NalgebraDense, ScalarType::F32) => {
            dynamic_solve!(diffsol::NalgebraMat<f32>, f32)
        }
        (MatrixType::NalgebraDense, ScalarType::F64) => {
            dynamic_solve!(diffsol::NalgebraMat<f64>, f64)
        }
        (MatrixType::FaerDense, ScalarType::F32) => dynamic_solve!(diffsol::FaerMat<f32>, f32),
        (MatrixType::FaerDense, ScalarType::F64) => dynamic_solve!(diffsol::FaerMat<f64>, f64),
        (MatrixType::FaerSparse, ScalarType::F32) => {
            dynamic_solve!(diffsol::FaerSparseMat<f32>, f32)
        }
        (MatrixType::FaerSparse, ScalarType::F64) => {
            dynamic_solve!(diffsol::FaerSparseMat<f64>, f64)
        }
    };
    Ok(solve)
}
322
/// Compiles `code` with the requested JIT backend and returns a type-erased solver
/// for the given matrix backend and scalar type.
///
/// # Errors
/// Propagates compilation or problem-construction errors from
/// `solve_factory_with_jit_backend`.
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
pub(crate) fn solve_factory_jit(
    code: &str,
    jit_backend: JitBackendType,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError> {
    // Select the factory monomorphised for the chosen code generator, then call it once.
    let factory: fn(&str, MatrixType, ScalarType) -> Result<Box<dyn Solve>, DiffsolRtError> =
        match jit_backend {
            #[cfg(feature = "diffsl-cranelift")]
            JitBackendType::Cranelift => {
                solve_factory_with_jit_backend::<diffsol::CraneliftJitModule>
            }
            #[cfg(feature = "diffsl-llvm")]
            JitBackendType::Llvm => solve_factory_with_jit_backend::<diffsol::LlvmModule>,
        };
    factory(code, matrix_type, scalar_type)
}
343
/// Compiles `code` with the code-generation module `CG` into an [`OdeSolverProblem`].
///
/// The problem is built for the requested matrix backend and scalar type, then
/// wrapped as a `Box<dyn Solve>`.
///
/// # Errors
/// Propagates any error from `build_from_diffsl`.
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
fn solve_factory_with_jit_backend<CG>(
    code: &str,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError>
where
    CG: CodegenModule
        + CodegenModuleJit
        + CodegenModuleCompile
        + SolveSerialization<diffsol::NalgebraMat<f32>>
        + SolveSerialization<diffsol::NalgebraMat<f64>>
        + SolveSerialization<diffsol::FaerMat<f32>>
        + SolveSerialization<diffsol::FaerMat<f64>>
        + SolveSerialization<diffsol::FaerSparseMat<f32>>
        + SolveSerialization<diffsol::FaerSparseMat<f64>>,
{
    // Build the problem for matrix type `$mat` and erase the concrete solver type.
    macro_rules! jit_solve {
        ($mat:ty) => {{
            let problem = OdeBuilder::<$mat>::new().build_from_diffsl::<CG>(code)?;
            Box::new(GenericSolve { problem }) as Box<dyn Solve>
        }};
    }

    let solve = match (matrix_type, scalar_type) {
        (MatrixType::NalgebraDense, ScalarType::F32) => jit_solve!(diffsol::NalgebraMat<f32>),
        (MatrixType::NalgebraDense, ScalarType::F64) => jit_solve!(diffsol::NalgebraMat<f64>),
        (MatrixType::FaerDense, ScalarType::F32) => jit_solve!(diffsol::FaerMat<f32>),
        (MatrixType::FaerDense, ScalarType::F64) => jit_solve!(diffsol::FaerMat<f64>),
        (MatrixType::FaerSparse, ScalarType::F32) => jit_solve!(diffsol::FaerSparseMat<f32>),
        (MatrixType::FaerSparse, ScalarType::F64) => jit_solve!(diffsol::FaerSparseMat<f64>),
    };
    Ok(solve)
}
401
402pub(crate) fn solve_factory_from_serialized_diffsl(
403 serialized_diffsl: &[u8],
404 matrix_type: MatrixType,
405 scalar_type: ScalarType,
406) -> Result<Box<dyn Solve>, DiffsolRtError> {
407 let solve: Box<dyn Solve> = match matrix_type {
408 MatrixType::NalgebraDense => match scalar_type {
409 ScalarType::F32 => Box::new(
410 GenericSolve::<diffsol::NalgebraMat<f32>, ObjectModule>::from_serialized_diffsl(
411 serialized_diffsl,
412 )?,
413 ),
414 ScalarType::F64 => Box::new(
415 GenericSolve::<diffsol::NalgebraMat<f64>, ObjectModule>::from_serialized_diffsl(
416 serialized_diffsl,
417 )?,
418 ),
419 },
420 MatrixType::FaerDense => match scalar_type {
421 ScalarType::F32 => Box::new(
422 GenericSolve::<diffsol::FaerMat<f32>, ObjectModule>::from_serialized_diffsl(
423 serialized_diffsl,
424 )?,
425 ),
426 ScalarType::F64 => Box::new(
427 GenericSolve::<diffsol::FaerMat<f64>, ObjectModule>::from_serialized_diffsl(
428 serialized_diffsl,
429 )?,
430 ),
431 },
432 MatrixType::FaerSparse => match scalar_type {
433 ScalarType::F32 => Box::new(
434 GenericSolve::<diffsol::FaerSparseMat<f32>, ObjectModule>::from_serialized_diffsl(
435 serialized_diffsl,
436 )?,
437 ),
438 ScalarType::F64 => Box::new(
439 GenericSolve::<diffsol::FaerSparseMat<f64>, ObjectModule>::from_serialized_diffsl(
440 serialized_diffsl,
441 )?,
442 ),
443 },
444 };
445 Ok(solve)
446}
447
448pub(crate) struct GenericSolve<M, CG>
449where
450 M: MatrixHost<T: Scalar>,
451 M::V: Vector + VectorHost,
452 CG: CodegenModule,
453{
454 problem: OdeSolverProblem<DiffSl<M, CG>>,
455}
456
457impl<M, CG> GenericSolve<M, CG>
458where
459 M: MatrixHost<T: Scalar>,
460 M::V: Vector + VectorHost + DefaultDenseMatrix,
461 CG: CodegenModule,
462{
463 fn from_eqn(eqn: DiffSl<M, CG>) -> Result<Self, DiffsolRtError> {
464 let default_p = vec![0.0; eqn.nparams()];
465 let problem = OdeBuilder::<M>::new().p(default_p).build_from_eqn(eqn)?;
466 Ok(GenericSolve { problem })
467 }
468
469 pub(crate) fn setup_problem(&mut self, params: &[f64]) -> Result<(), DiffsolRtError> {
470 let params: Vec<M::T> = params.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
471 let params = M::V::from_slice(¶ms, M::C::default());
472
473 let nparams = self.problem.eqn.nparams();
475 if params.len() == nparams {
476 self.problem.eqn.set_params(¶ms);
477 Ok(())
478 } else {
479 Err(DiffsolError::Other(format!(
480 "Expecting {} params but got {}",
481 nparams,
482 params.len()
483 ))
484 .into())
485 }
486 }
487
488 pub(crate) fn serialize_eqn(&self) -> Result<Vec<u8>, DiffsolRtError>
489 where
490 DiffSl<M, CG>: serde::Serialize,
491 {
492 serde_json::to_vec(&self.problem.eqn)
493 .map_err(|e| DiffsolRtError::from(DiffsolError::Other(e.to_string())))
494 }
495}
496
#[cfg(feature = "external")]
impl<M> GenericSolve<M, diffsl::ExternalModule<M::T>>
where
    M: MatrixHost<T: ExternalScalar>,
    M::V: Vector + VectorHost + DefaultDenseMatrix,
{
    /// Builds a solver whose equations come from a statically linked external module.
    ///
    /// The `*_deps` lists and `include_sensitivities` are forwarded unchanged to
    /// `DiffSl::from_external`.
    ///
    /// # Errors
    /// Propagates errors from constructing the equations or the problem.
    pub fn from_external(
        rhs_state_deps: Vec<(usize, usize)>,
        rhs_input_deps: Vec<(usize, usize)>,
        mass_state_deps: Vec<(usize, usize)>,
        include_sensitivities: bool,
    ) -> Result<Self, DiffsolRtError> {
        let ctx = M::C::default();
        DiffSl::<M, diffsl::ExternalModule<M::T>>::from_external(
            ctx,
            rhs_state_deps,
            rhs_input_deps,
            mass_state_deps,
            include_sensitivities,
        )
        .map_err(DiffsolRtError::from)
        .and_then(Self::from_eqn)
    }
}
519
#[cfg(feature = "diffsl-external-dynamic")]
impl<M> GenericSolve<M, diffsl::ExternalDynModule<M::T>>
where
    M: MatrixHost<T: Scalar>,
    M::V: Vector + VectorHost + DefaultDenseMatrix,
{
    /// Builds a solver whose equations come from an external module loaded from `path`.
    ///
    /// `path`, the `*_deps` lists and `include_sensitivities` are forwarded unchanged
    /// to `DiffSl::from_external_dynamic`.
    ///
    /// # Errors
    /// Propagates errors from loading the module, building the equations, or building
    /// the problem.
    pub fn from_external_dynamic(
        path: impl Into<PathBuf>,
        rhs_state_deps: Vec<(usize, usize)>,
        rhs_input_deps: Vec<(usize, usize)>,
        mass_state_deps: Vec<(usize, usize)>,
        include_sensitivities: bool,
    ) -> Result<Self, DiffsolRtError> {
        let ctx = M::C::default();
        DiffSl::<M, diffsl::ExternalDynModule<M::T>>::from_external_dynamic(
            path,
            ctx,
            rhs_state_deps,
            rhs_input_deps,
            mass_state_deps,
            include_sensitivities,
        )
        .map_err(DiffsolRtError::from)
        .and_then(Self::from_eqn)
    }
}
544
545impl<M> GenericSolve<M, ObjectModule>
546where
547 M: MatrixHost<T: Scalar>,
548 M::V: Vector + VectorHost + DefaultDenseMatrix,
549{
550 pub fn from_serialized_diffsl(serialized_diffsl: &[u8]) -> Result<Self, DiffsolRtError> {
551 let eqn = serde_json::from_slice::<DiffSl<M, ObjectModule>>(serialized_diffsl)
552 .map_err(|e| DiffsolRtError::from(DiffsolError::Other(e.to_string())))?;
553 Self::from_eqn(eqn)
554 }
555}
556
557impl<M, CG> Solve for GenericSolve<M, CG>
558where
559 M: MatrixHost<T: Scalar + ToPrimitive>
560 + DefaultSolver
561 + LuValidator<M>
562 + KluValidator<M>
563 + MatrixKind,
564 CG: CodegenModule,
565 for<'b> <<M::V as DefaultDenseMatrix>::M as MatrixCommon>::Inner: ToHostArray<M::T> + Clone,
566 for<'b> <M::V as VectorCommon>::Inner: ToHostArray<M::T> + Clone,
567 M::V: VectorHost + DefaultDenseMatrix + Send + Sync + 'static,
568 <M::V as DefaultDenseMatrix>::M: Send + Sync,
569 for<'b> &'b M::V: VectorRef<M::V>,
570 for<'b> &'b M: MatrixRef<M>,
571 CG: SolveSerialization<M>,
572{
573 fn matrix_type(&self) -> MatrixType {
574 MatrixType::from_diffsol::<M>()
575 }
576
577 fn nstates(&self) -> usize {
578 self.problem.eqn.nstates()
579 }
580
581 fn nparams(&self) -> usize {
582 self.problem.eqn.nparams()
583 }
584
585 fn nout(&self) -> usize {
586 self.problem.eqn.nout()
587 }
588
589 fn has_stop(&self) -> bool {
590 self.problem.eqn.root().is_some()
591 }
592
593 fn check(&self, linear_solver: LinearSolverType) -> Result<(), DiffsolRtError> {
594 validate_linear_solver::<M>(linear_solver)
595 }
596
597 fn set_atol(&mut self, atol: f64) {
598 self.problem.atol.fill(M::T::from_f64(atol).unwrap());
599 }
600
601 fn atol(&self) -> f64 {
602 self.problem.atol[0].to_f64().unwrap()
603 }
604
605 fn serialized_diffsl(&self) -> Result<Vec<u8>, DiffsolRtError> {
606 CG::serialized_diffsl(self)
607 }
608
609 fn set_rtol(&mut self, rtol: f64) {
610 self.problem.rtol = M::T::from_f64(rtol).unwrap();
611 }
612
613 fn rtol(&self) -> f64 {
614 self.problem.rtol.to_f64().unwrap()
615 }
616
617 fn set_t0(&mut self, t0: f64) {
618 self.problem.t0 = M::T::from_f64(t0).unwrap();
619 }
620
621 fn t0(&self) -> f64 {
622 self.problem.t0.to_f64().unwrap()
623 }
624
625 fn set_h0(&mut self, h0: f64) {
626 self.problem.h0 = M::T::from_f64(h0).unwrap();
627 }
628
629 fn h0(&self) -> f64 {
630 self.problem.h0.to_f64().unwrap()
631 }
632
633 fn set_integrate_out(&mut self, integrate_out: bool) {
634 self.problem.integrate_out = integrate_out;
635 }
636
637 fn integrate_out(&self) -> bool {
638 self.problem.integrate_out
639 }
640
641 fn set_sens_rtol(&mut self, sens_rtol: Option<f64>) {
642 self.problem.sens_rtol = sens_rtol.map(|value| M::T::from_f64(value).unwrap());
643 }
644
645 fn sens_rtol(&self) -> Option<f64> {
646 self.problem.sens_rtol.map(|value| value.to_f64().unwrap())
647 }
648
649 fn set_sens_atol(&mut self, sens_atol: Option<f64>) {
650 self.problem.sens_atol = sens_atol.map(|value| {
651 M::V::from_element(
652 self.problem.eqn.nstates(),
653 M::T::from_f64(value).unwrap(),
654 M::C::default(),
655 )
656 });
657 }
658
659 fn sens_atol(&self) -> Option<f64> {
660 self.problem
661 .sens_atol
662 .as_ref()
663 .and_then(|value| (value.len() > 0).then(|| value.get_index(0).to_f64().unwrap()))
664 }
665
666 fn set_out_rtol(&mut self, out_rtol: Option<f64>) {
667 self.problem.out_rtol = out_rtol.map(|value| M::T::from_f64(value).unwrap());
668 }
669
670 fn out_rtol(&self) -> Option<f64> {
671 self.problem.out_rtol.map(|value| value.to_f64().unwrap())
672 }
673
674 fn set_out_atol(&mut self, out_atol: Option<f64>) {
675 self.problem.out_atol = out_atol.map(|value| {
676 let len = if self.problem.eqn.out().is_some() {
677 self.problem.eqn.nout()
678 } else {
679 self.problem.eqn.nstates()
680 };
681 M::V::from_element(len, M::T::from_f64(value).unwrap(), M::C::default())
682 });
683 }
684
685 fn out_atol(&self) -> Option<f64> {
686 self.problem
687 .out_atol
688 .as_ref()
689 .and_then(|value| (value.len() > 0).then(|| value.get_index(0).to_f64().unwrap()))
690 }
691
692 fn set_param_rtol(&mut self, param_rtol: Option<f64>) {
693 self.problem.param_rtol = param_rtol.map(|value| M::T::from_f64(value).unwrap());
694 }
695
696 fn param_rtol(&self) -> Option<f64> {
697 self.problem.param_rtol.map(|value| value.to_f64().unwrap())
698 }
699
700 fn set_param_atol(&mut self, param_atol: Option<f64>) {
701 self.problem.param_atol = param_atol.map(|value| {
702 M::V::from_element(
703 self.problem.eqn.nparams(),
704 M::T::from_f64(value).unwrap(),
705 M::C::default(),
706 )
707 });
708 }
709
710 fn param_atol(&self) -> Option<f64> {
711 self.problem
712 .param_atol
713 .as_ref()
714 .and_then(|value| (value.len() > 0).then(|| value.get_index(0).to_f64().unwrap()))
715 }
716
717 generate_ic_option_accessors! {
718 use_linesearch: bool,
719 max_linesearch_iterations: usize,
720 max_newton_iterations: usize,
721 max_linear_solver_setups: usize,
722 step_reduction_factor: f64,
723 armijo_constant: f64
724 }
725
726 generate_ode_option_accessors! {
727 max_nonlinear_solver_iterations: usize,
728 max_error_test_failures: usize,
729 min_timestep: f64,
730 update_jacobian_after_steps: usize,
731 update_rhs_jacobian_after_steps: usize,
732 threshold_to_update_jacobian: f64,
733 threshold_to_update_rhs_jacobian: f64
734 }
735
736 fn y0(&mut self, params: &[f64]) -> Result<HostArray, DiffsolRtError> {
737 self.setup_problem(params)?;
738 let n = self.problem.eqn.nstates();
739 let mut y0 = M::V::zeros(n, M::C::default());
740 let t0 = self.problem.t0;
741 self.problem.eqn.init().call_inplace(t0, &mut y0);
742 Ok((*y0.inner()).clone().to_host_array())
743 }
744
745 fn rhs(&mut self, params: &[f64], t: f64, y: &[f64]) -> Result<HostArray, DiffsolRtError> {
746 self.setup_problem(params)?;
747 let n = self.problem.eqn.nstates();
748 let y = y
749 .iter()
750 .map(|&x| M::T::from_f64(x).unwrap())
751 .collect::<Vec<_>>();
752 let y_vec = M::V::from_slice(&y, M::C::default());
753 let mut dydt = M::V::zeros(n, M::C::default());
754 self.problem
755 .eqn
756 .rhs()
757 .call_inplace(&y_vec, M::T::from_f64(t).unwrap(), &mut dydt);
758 Ok((*dydt.inner()).clone().to_host_array())
759 }
760
761 fn rhs_jac_mul(
762 &mut self,
763
764 params: &[f64],
765 t: f64,
766 y: &[f64],
767 v: &[f64],
768 ) -> Result<HostArray, DiffsolRtError> {
769 self.setup_problem(params)?;
770 let n = self.problem.eqn.nstates();
771 let y = y
772 .iter()
773 .map(|&x| M::T::from_f64(x).unwrap())
774 .collect::<Vec<_>>();
775 let v = v
776 .iter()
777 .map(|&x| M::T::from_f64(x).unwrap())
778 .collect::<Vec<_>>();
779 let y_vec = M::V::from_slice(&y, M::C::default());
780 let v_vec = M::V::from_slice(&v, M::C::default());
781 let mut dydt = M::V::zeros(n, M::C::default());
782 self.problem.eqn.rhs().jac_mul_inplace(
783 &y_vec,
784 M::T::from_f64(t).unwrap(),
785 &v_vec,
786 &mut dydt,
787 );
788 Ok((*dydt.inner()).clone().to_host_array())
789 }
790
791 fn solve(
792 &mut self,
793 method: OdeSolverType,
794 linear_solver: LinearSolverType,
795 params: &[f64],
796 final_time: f64,
797 ) -> SolveResult {
798 self.check(linear_solver)?;
799 self.setup_problem(params)?;
800 let final_time = M::T::from_f64(final_time).unwrap();
801 let soln = match linear_solver {
802 LinearSolverType::Default => {
803 method.solve::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, final_time)
804 }
805 LinearSolverType::Lu => {
806 method.solve::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, final_time)
807 }
808 LinearSolverType::Klu => {
809 method.solve::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, final_time)
810 }
811 };
812 Ok(Box::new(soln?))
813 }
814
815 fn solve_dense(
816 &mut self,
817 method: OdeSolverType,
818 linear_solver: LinearSolverType,
819 params: &[f64],
820 t_eval: &[f64],
821 ) -> SolveResult {
822 self.check(linear_solver)?;
823 self.setup_problem(params)?;
824
825 let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
826 let soln =
827 match linear_solver {
828 LinearSolverType::Default => method
829 .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval),
830 LinearSolverType::Lu => method
831 .solve_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
832 LinearSolverType::Klu => method
833 .solve_dense::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, &t_eval),
834 };
835 Ok(Box::new(soln?))
836 }
837
838 fn solve_fwd_sens(
839 &mut self,
840 method: OdeSolverType,
841 linear_solver: LinearSolverType,
842 params: &[f64],
843 t_eval: &[f64],
844 ) -> SolveResult {
845 self.check(linear_solver)?;
846 self.setup_problem(params)?;
847
848 let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
849 let soln = match linear_solver {
850 LinearSolverType::Default => {
851 method.solve_fwd_sens::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval)
852 }
853 LinearSolverType::Lu => method
854 .solve_fwd_sens::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
855 LinearSolverType::Klu => method
856 .solve_fwd_sens::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, &t_eval),
857 };
858 Ok(Box::new(soln?))
859 }
860
861 fn solve_continuous_adjoint(
862 &mut self,
863 method: OdeSolverType,
864 linear_solver: LinearSolverType,
865 params: &[f64],
866 final_time: f64,
867 ) -> Result<(HostArray, HostArray), DiffsolRtError> {
868 self.check(linear_solver)?;
869 self.setup_problem(params)?;
870
871 let final_time = M::T::from_f64(final_time).unwrap();
872 let integrate_out = self.problem.integrate_out;
873 self.problem.integrate_out = true;
874 let result = match linear_solver {
875 LinearSolverType::Default => method
876 .solve_continuous_adjoint::<M, CG, <M as DefaultSolver>::LS>(
877 &mut self.problem,
878 final_time,
879 ),
880 LinearSolverType::Lu => method
881 .solve_continuous_adjoint::<M, CG, <M as LuValidator<M>>::LS>(
882 &mut self.problem,
883 final_time,
884 ),
885 LinearSolverType::Klu => method
886 .solve_continuous_adjoint::<M, CG, <M as KluValidator<M>>::LS>(
887 &mut self.problem,
888 final_time,
889 ),
890 };
891 self.problem.integrate_out = integrate_out;
892 let (integral, gradient) = result?;
893 Ok((
894 (*integral.inner()).clone().to_host_array(),
895 gradient.to_host_array(),
896 ))
897 }
898
899 fn solve_adjoint_fwd(
900 &mut self,
901 method: OdeSolverType,
902 linear_solver: LinearSolverType,
903 params: &[f64],
904 t_eval: &[f64],
905 ) -> Result<(Box<dyn Solution>, AdjointCheckpointWrapper), DiffsolRtError> {
906 self.check(linear_solver)?;
907 self.setup_problem(params)?;
908
909 let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
910 let (soln, checkpoint) = match linear_solver {
911 LinearSolverType::Default => method
912 .solve_adjoint_fwd::<M, CG, <M as DefaultSolver>::LS>(
913 &mut self.problem,
914 &t_eval,
915 params,
916 linear_solver,
917 ),
918 LinearSolverType::Lu => method.solve_adjoint_fwd::<M, CG, <M as LuValidator<M>>::LS>(
919 &mut self.problem,
920 &t_eval,
921 params,
922 linear_solver,
923 ),
924 LinearSolverType::Klu => method.solve_adjoint_fwd::<M, CG, <M as KluValidator<M>>::LS>(
925 &mut self.problem,
926 &t_eval,
927 params,
928 linear_solver,
929 ),
930 }?;
931 Ok((Box::new(soln), AdjointCheckpointWrapper::new(checkpoint)))
932 }
933
934 fn solve_adjoint_bkwd(
935 &mut self,
936 method: OdeSolverType,
937 linear_solver: LinearSolverType,
938 checkpoint: &AdjointCheckpointWrapper,
939 t_eval: &[f64],
940 dgdu_eval: HostArray,
941 ) -> Result<HostArray, DiffsolRtError> {
942 self.check(linear_solver)?;
943 let checkpoint = checkpoint.guard()?;
944 self.setup_problem(checkpoint.params())?;
945
946 let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
947 let dgdu_eval = host_array_to_dense_matrix::<M>(dgdu_eval)?;
948 if dgdu_eval.nrows() != self.problem.eqn.nout() {
949 return Err(DiffsolError::Other(format!(
950 "Expected dgdu_eval to have {} rows, got {}",
951 self.problem.eqn.nout(),
952 dgdu_eval.nrows()
953 ))
954 .into());
955 }
956 if dgdu_eval.ncols() != t_eval.len() {
957 return Err(DiffsolError::Other(format!(
958 "Expected dgdu_eval to have {} columns, got {}",
959 t_eval.len(),
960 dgdu_eval.ncols()
961 ))
962 .into());
963 }
964
965 let gradient = match linear_solver {
966 LinearSolverType::Default => method
967 .solve_adjoint_bkwd::<M, CG, <M as DefaultSolver>::LS>(
968 &mut self.problem,
969 checkpoint.as_ref(),
970 &dgdu_eval,
971 &t_eval,
972 ),
973 LinearSolverType::Lu => method.solve_adjoint_bkwd::<M, CG, <M as LuValidator<M>>::LS>(
974 &mut self.problem,
975 checkpoint.as_ref(),
976 &dgdu_eval,
977 &t_eval,
978 ),
979 LinearSolverType::Klu => method
980 .solve_adjoint_bkwd::<M, CG, <M as KluValidator<M>>::LS>(
981 &mut self.problem,
982 checkpoint.as_ref(),
983 &dgdu_eval,
984 &t_eval,
985 ),
986 }?;
987 Ok(gradient.to_host_array())
988 }
989}
990
991fn host_array_to_dense_matrix<M>(
992 array: HostArray,
993) -> Result<<M::V as DefaultDenseMatrix>::M, DiffsolRtError>
994where
995 M: MatrixHost<T: Scalar>,
996 M::V: DefaultDenseMatrix,
997{
998 let view = array.as_array::<M::T>()?;
999 let mut values = Vec::with_capacity(view.nrows() * view.ncols());
1000 for col in 0..view.ncols() {
1001 for row in 0..view.nrows() {
1002 values.push(view[(row, col)]);
1003 }
1004 }
1005 Ok(<M::V as DefaultDenseMatrix>::M::from_vec(
1006 view.nrows(),
1007 view.ncols(),
1008 values,
1009 M::C::default(),
1010 ))
1011}
1012
1013#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
1014mod tests {
1015 use diffsol::MatrixCommon;
1016 use diffsol::{
1017 CodegenModuleCompile, CodegenModuleJit, Context, OdeBuilder, OdeEquations, Vector,
1018 };
1019
1020 #[cfg(feature = "diffsl-llvm")]
1021 use crate::test_support::{hybrid_logistic_state_dr, logistic_state_dr};
1022 use crate::{
1023 host_array::FromHostArray,
1024 linear_solver_type::LinearSolverType,
1025 matrix_type::MatrixType,
1026 ode_solver_type::OdeSolverType,
1027 scalar_type::ScalarType,
1028 test_support::{
1029 assert_close, hybrid_logistic_diffsl_code, hybrid_logistic_state, logistic_diffsl_code,
1030 logistic_state, LOGISTIC_X0,
1031 },
1032 };
1033
1034 use super::{solve_factory_with_jit_backend, GenericSolve, Solve, SolveSerialization};
1035
1036 fn make_generic_solve<T, CG>() -> GenericSolve<diffsol::NalgebraMat<T>, CG>
1037 where
1038 T: crate::scalar_type::Scalar + diffsol::NalgebraScalar,
1039 CG: diffsol::CodegenModule
1040 + CodegenModuleJit
1041 + CodegenModuleCompile
1042 + SolveSerialization<diffsol::NalgebraMat<T>>,
1043 diffsol::NalgebraMat<T>: diffsol::matrix::MatrixHost,
1044 <diffsol::NalgebraMat<T> as MatrixCommon>::T: crate::scalar_type::Scalar,
1045 <diffsol::NalgebraMat<T> as MatrixCommon>::V: diffsol::DefaultDenseMatrix,
1046 {
1047 let problem = OdeBuilder::<diffsol::NalgebraMat<T>>::new()
1048 .build_from_diffsl::<CG>(logistic_diffsl_code())
1049 .unwrap();
1050 GenericSolve { problem }
1051 }
1052
1053 fn assert_factory_supports_all_matrix_and_scalar_types<CG>()
1054 where
1055 CG: diffsol::CodegenModule
1056 + CodegenModuleJit
1057 + CodegenModuleCompile
1058 + SolveSerialization<diffsol::NalgebraMat<f32>>
1059 + SolveSerialization<diffsol::NalgebraMat<f64>>
1060 + SolveSerialization<diffsol::FaerMat<f32>>
1061 + SolveSerialization<diffsol::FaerMat<f64>>
1062 + SolveSerialization<diffsol::FaerSparseMat<f32>>
1063 + SolveSerialization<diffsol::FaerSparseMat<f64>>,
1064 {
1065 for matrix_type in [
1066 MatrixType::NalgebraDense,
1067 MatrixType::FaerDense,
1068 MatrixType::FaerSparse,
1069 ] {
1070 for scalar_type in [ScalarType::F32, ScalarType::F64] {
1071 assert!(solve_factory_with_jit_backend::<CG>(
1072 logistic_diffsl_code(),
1073 matrix_type,
1074 scalar_type,
1075 )
1076 .is_ok());
1077 }
1078 }
1079 }
1080
1081 fn assert_solve_metadata_and_helpers<T, CG>(tol: f64)
1082 where
1083 T: crate::scalar_type::Scalar + diffsol::NalgebraScalar,
1084 CG: diffsol::CodegenModule
1085 + CodegenModuleJit
1086 + CodegenModuleCompile
1087 + SolveSerialization<diffsol::NalgebraMat<T>>,
1088 diffsol::NalgebraMat<T>: diffsol::matrix::MatrixHost,
1089 <diffsol::NalgebraMat<T> as MatrixCommon>::T: crate::scalar_type::Scalar,
1090 <diffsol::NalgebraMat<T> as MatrixCommon>::V: diffsol::DefaultDenseMatrix,
1091 GenericSolve<diffsol::NalgebraMat<T>, CG>: Solve,
1092 {
1093 let mut solve = make_generic_solve::<T, CG>();
1094 assert_eq!(solve.matrix_type(), MatrixType::NalgebraDense);
1095 assert_eq!(solve.nstates(), 1);
1096 assert_eq!(solve.nparams(), 1);
1097 assert_eq!(solve.nout(), 1);
1098 assert!(!solve.has_stop());
1099 assert!(solve.check(LinearSolverType::Default).is_ok());
1100 assert!(solve.check(LinearSolverType::Lu).is_ok());
1101 assert!(solve.check(LinearSolverType::Klu).is_err());
1102
1103 solve.set_atol(1e-5);
1104 solve.set_rtol(1e-4);
1105 assert_close(solve.atol(), 1e-5, tol, "solve atol");
1106 assert_close(solve.rtol(), 1e-4, tol, "solve rtol");
1107 solve.set_t0(0.125);
1108 solve.set_h0(0.25);
1109 solve.set_integrate_out(true);
1110 assert_close(solve.t0(), 0.125, tol, "solve t0");
1111 assert_close(solve.h0(), 0.25, tol, "solve h0");
1112 assert!(solve.integrate_out());
1113
1114 solve.set_sens_rtol(Some(1e-3));
1115 solve.set_sens_atol(Some(1e-4));
1116 solve.set_out_rtol(Some(2e-3));
1117 solve.set_out_atol(Some(2e-4));
1118 solve.set_param_rtol(Some(3e-3));
1119 solve.set_param_atol(Some(3e-4));
1120 assert_close(solve.sens_rtol().unwrap(), 1e-3, tol, "solve sens rtol");
1121 assert_close(solve.sens_atol().unwrap(), 1e-4, tol, "solve sens atol");
1122 assert_close(solve.out_rtol().unwrap(), 2e-3, tol, "solve out rtol");
1123 assert_close(solve.out_atol().unwrap(), 2e-4, tol, "solve out atol");
1124 assert_close(solve.param_rtol().unwrap(), 3e-3, tol, "solve param rtol");
1125 assert_close(solve.param_atol().unwrap(), 3e-4, tol, "solve param atol");
1126
1127 solve.set_sens_rtol(None);
1128 solve.set_sens_atol(None);
1129 solve.set_out_rtol(None);
1130 solve.set_out_atol(None);
1131 solve.set_param_rtol(None);
1132 solve.set_param_atol(None);
1133 assert_eq!(solve.sens_rtol(), None);
1134 assert_eq!(solve.sens_atol(), None);
1135 assert_eq!(solve.out_rtol(), None);
1136 assert_eq!(solve.out_atol(), None);
1137 assert_eq!(solve.param_rtol(), None);
1138 assert_eq!(solve.param_atol(), None);
1139
1140 let y0 = Vec::<f64>::from_host_array(solve.y0(&[2.0]).unwrap()).unwrap();
1141 assert_eq!(y0.len(), 1);
1142 assert_close(y0[0], LOGISTIC_X0, tol, "y0[0]");
1143
1144 let rhs = Vec::<f64>::from_host_array(solve.rhs(&[2.0], 0.0, &[0.25]).unwrap()).unwrap();
1145 assert_eq!(rhs.len(), 1);
1146 assert_close(rhs[0], 0.375, tol, "solve rhs");
1147
1148 let jac_mul =
1149 Vec::<f64>::from_host_array(solve.rhs_jac_mul(&[2.0], 0.0, &[0.25], &[3.0]).unwrap())
1150 .unwrap();
1151 assert_eq!(jac_mul.len(), 1);
1152 assert_close(jac_mul[0], 3.0, tol, "solve rhs jac mul");
1153 }
1154
1155 fn assert_solve_runtime_paths<T, CG>(tol: f64)
1156 where
1157 T: crate::scalar_type::Scalar + diffsol::NalgebraScalar,
1158 CG: diffsol::CodegenModule
1159 + CodegenModuleJit
1160 + CodegenModuleCompile
1161 + SolveSerialization<diffsol::NalgebraMat<T>>,
1162 diffsol::NalgebraMat<T>: diffsol::matrix::MatrixHost,
1163 <diffsol::NalgebraMat<T> as MatrixCommon>::T: crate::scalar_type::Scalar,
1164 <diffsol::NalgebraMat<T> as MatrixCommon>::V: diffsol::DefaultDenseMatrix,
1165 GenericSolve<diffsol::NalgebraMat<T>, CG>: Solve,
1166 {
1167 let mut solve = make_generic_solve::<T, CG>();
1168 let soln = solve
1169 .solve(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], 1.0)
1170 .unwrap();
1171 let ts = Vec::<f64>::from_host_array(soln.get_ts()).unwrap();
1172 let ys = Vec::<Vec<f64>>::from_host_array(soln.get_ys()).unwrap();
1173 assert_close(*ts.last().unwrap(), 1.0, tol, "solve final time");
1174 assert_close(
1175 ys[0][ts.len() - 1],
1176 logistic_state(LOGISTIC_X0, 2.0, 1.0),
1177 tol,
1178 "solve final value",
1179 );
1180
1181 let mut solve = make_generic_solve::<T, CG>();
1182 let dense = solve
1183 .solve_dense(
1184 OdeSolverType::Tsit45,
1185 LinearSolverType::Lu,
1186 &[2.0],
1187 &[0.25, 0.5, 1.0],
1188 )
1189 .unwrap();
1190 let ts = Vec::<f64>::from_host_array(dense.get_ts()).unwrap();
1191 let ys = Vec::<Vec<f64>>::from_host_array(dense.get_ys()).unwrap();
1192 assert_eq!(ts, vec![0.25, 0.5, 1.0]);
1193 for (i, &t) in ts.iter().enumerate() {
1194 assert_close(
1195 ys[0][i],
1196 logistic_state(LOGISTIC_X0, 2.0, t),
1197 tol,
1198 &format!("solve_dense[{i}]"),
1199 );
1200 }
1201
1202 let mut solve = make_generic_solve::<T, CG>();
1203 let err = match solve.solve(OdeSolverType::Bdf, LinearSolverType::Default, &[], 1.0) {
1204 Ok(_) => panic!("expected parameter count mismatch"),
1205 Err(err) => err,
1206 };
1207 assert!(err.to_string().contains("Expecting 1 params but got 0"));
1208
1209 let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<T>>::new()
1210 .build_from_diffsl::<CG>(hybrid_logistic_diffsl_code())
1211 .unwrap();
1212 let mut hybrid_solve = GenericSolve {
1213 problem: hybrid_problem,
1214 };
1215 let hybrid = hybrid_solve
1216 .solve(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], 2.0)
1217 .unwrap();
1218 let hybrid_ts = Vec::<f64>::from_host_array(hybrid.get_ts()).unwrap();
1219 let hybrid_ys = Vec::<Vec<f64>>::from_host_array(hybrid.get_ys()).unwrap();
1220 assert_close(*hybrid_ts.last().unwrap(), 2.0, tol, "solve final time");
1221 assert_close(
1222 hybrid_ys[0][hybrid_ts.len() - 1],
1223 hybrid_logistic_state(2.0, 2.0),
1224 tol,
1225 "solve final value",
1226 );
1227
1228 let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<T>>::new()
1229 .build_from_diffsl::<CG>(hybrid_logistic_diffsl_code())
1230 .unwrap();
1231 let mut hybrid_solve = GenericSolve {
1232 problem: hybrid_problem,
1233 };
1234 let hybrid_dense = hybrid_solve
1235 .solve_dense(
1236 OdeSolverType::Tsit45,
1237 LinearSolverType::Lu,
1238 &[2.0],
1239 &[0.5, 1.0, 1.5, 2.0],
1240 )
1241 .unwrap();
1242 let hybrid_dense_ts = Vec::<f64>::from_host_array(hybrid_dense.get_ts()).unwrap();
1243 let hybrid_dense_ys = Vec::<Vec<f64>>::from_host_array(hybrid_dense.get_ys()).unwrap();
1244 assert_eq!(hybrid_dense_ts, vec![0.5, 1.0, 1.5, 2.0]);
1245 for (i, &t) in hybrid_dense_ts.iter().enumerate() {
1246 assert_close(
1247 hybrid_dense_ys[0][i],
1248 hybrid_logistic_state(2.0, t),
1249 tol,
1250 &format!("solve_dense[{i}]"),
1251 );
1252 }
1253 }
1254
/// Checks forward sensitivities with respect to the single model parameter,
/// computed through the `Solve` trait. Runs the logistic and hybrid-logistic
/// models (f64, dense nalgebra, LLVM backend) at fixed output times and
/// compares against the analytic derivatives from `test_support`.
#[cfg(feature = "diffsl-llvm")]
fn assert_solve_sensitivity_paths() {
    // Tolerance shared by every sensitivity comparison below.
    const SENS_TOL: f64 = 5e-4;
    let t_eval = [0.25, 0.5, 1.0];

    // Logistic model.
    let mut solve = make_generic_solve::<f64, diffsol::LlvmModule>();
    let sens = solve
        .solve_fwd_sens(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], &t_eval)
        .unwrap();
    let sens_values = sens.get_sens();
    // The model has one parameter, so exactly one sensitivity block is expected.
    assert_eq!(sens_values.len(), 1);
    let sens_matrix =
        Vec::<Vec<f64>>::from_host_array(sens_values.into_iter().next().unwrap()).unwrap();
    for (i, &t) in t_eval.iter().enumerate() {
        assert_close(
            sens_matrix[0][i],
            logistic_state_dr(LOGISTIC_X0, 2.0, t),
            SENS_TOL,
            &format!("logistic solve_fwd_sens[{i}]"),
        );
    }

    // Hybrid logistic model.
    let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<f64>>::new()
        .build_from_diffsl::<diffsol::LlvmModule>(hybrid_logistic_diffsl_code())
        .unwrap();
    let mut solve = GenericSolve {
        problem: hybrid_problem,
    };
    let hybrid_sens = solve
        .solve_fwd_sens(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], &t_eval)
        .unwrap();
    let sens_values = hybrid_sens.get_sens();
    // Same single-parameter check as above; it was previously missing here.
    assert_eq!(sens_values.len(), 1);
    let sens_matrix =
        Vec::<Vec<f64>>::from_host_array(sens_values.into_iter().next().unwrap()).unwrap();
    for (i, &t) in t_eval.iter().enumerate() {
        assert_close(
            sens_matrix[0][i],
            hybrid_logistic_state_dr(2.0, t),
            SENS_TOL,
            &format!("hybrid solve_fwd_sens[{i}]"),
        );
    }
}
1297
/// The solve factory handles every matrix/scalar combination with the
/// Cranelift JIT backend.
#[cfg(feature = "diffsl-cranelift")]
#[test]
fn solve_factory_supports_all_jit_matrix_and_scalar_types_for_cranelift() {
    type Backend = diffsol::CraneliftJitModule;
    assert_factory_supports_all_matrix_and_scalar_types::<Backend>();
}
1303
/// Metadata, helper and runtime paths of the `Solve` trait with the Cranelift
/// JIT backend, for both f64 and f32 scalars.
#[cfg(feature = "diffsl-cranelift")]
#[test]
fn solve_trait_helpers_and_runtime_paths_for_cranelift() {
    type Backend = diffsol::CraneliftJitModule;

    // f64: tight tolerance for the helpers, looser for full integrations.
    assert_solve_metadata_and_helpers::<f64, Backend>(1e-12);
    assert_solve_runtime_paths::<f64, Backend>(5e-4);

    // f32: single precision limits both to a looser tolerance.
    assert_solve_metadata_and_helpers::<f32, Backend>(5e-3);
    assert_solve_runtime_paths::<f32, Backend>(5e-3);
}
1312
/// The solve factory handles every matrix/scalar combination with the LLVM
/// JIT backend.
#[cfg(feature = "diffsl-llvm")]
#[test]
fn solve_factory_supports_all_jit_matrix_and_scalar_types_for_llvm() {
    type Backend = diffsol::LlvmModule;
    assert_factory_supports_all_matrix_and_scalar_types::<Backend>();
}
1318
/// Metadata, helper and runtime paths of the `Solve` trait with the LLVM JIT
/// backend, for both f64 and f32 scalars.
#[cfg(feature = "diffsl-llvm")]
#[test]
fn solve_trait_helpers_and_runtime_paths_for_llvm() {
    type Backend = diffsol::LlvmModule;

    // f64: tight tolerance for the helpers, looser for full integrations.
    assert_solve_metadata_and_helpers::<f64, Backend>(1e-12);
    assert_solve_runtime_paths::<f64, Backend>(5e-4);

    // f32: single precision limits both to a looser tolerance.
    assert_solve_metadata_and_helpers::<f32, Backend>(5e-3);
    assert_solve_runtime_paths::<f32, Backend>(5e-3);
}
1327
/// Forward-sensitivity paths of the `Solve` trait with the LLVM JIT backend.
#[cfg(feature = "diffsl-llvm")]
#[test]
fn solve_trait_sensitivity_paths_for_llvm() {
    assert_solve_sensitivity_paths()
}
1333
/// `setup_problem` rejects a parameter slice of the wrong length and stores a
/// correctly sized one on the equations (Cranelift backend).
#[cfg(feature = "diffsl-cranelift")]
#[test]
fn setup_problem_validates_parameter_count_for_cranelift() {
    let mut generic = make_generic_solve::<f64, diffsol::CraneliftJitModule>();

    // An empty slice does not match the model's single parameter.
    let message = generic.setup_problem(&[]).unwrap_err().to_string();
    assert!(message.contains("Expecting 1 params but got 0"));

    // A one-element slice is accepted and can be read back from the equations.
    generic.setup_problem(&[2.0]).unwrap();
    let mut stored = generic.problem.context().vector_zeros(1);
    generic.problem.eqn.get_params(&mut stored);
    assert_eq!(stored.get_index(0), 2.0);
}
1346
/// `setup_problem` rejects a parameter slice of the wrong length and stores a
/// correctly sized one on the equations (LLVM backend).
#[cfg(feature = "diffsl-llvm")]
#[test]
fn setup_problem_validates_parameter_count_for_llvm() {
    let mut generic = make_generic_solve::<f64, diffsol::LlvmModule>();

    // An empty slice does not match the model's single parameter.
    let message = generic.setup_problem(&[]).unwrap_err().to_string();
    assert!(message.contains("Expecting 1 params but got 0"));

    // A one-element slice is accepted and can be read back from the equations.
    generic.setup_problem(&[2.0]).unwrap();
    let mut stored = generic.problem.context().vector_zeros(1);
    generic.problem.eqn.get_params(&mut stored);
    assert_eq!(stored.get_index(0), 2.0);
}
1359}