1#[cfg(feature = "diffsl-external-dynamic")]
5use std::path::PathBuf;
6
7use crate::adjoint_checkpoint::AdjointCheckpointWrapper;
8use crate::error::DiffsolRtError;
9use crate::host_array::HostArray;
10use crate::host_array::ToHostArray;
11#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
12use crate::jit::JitBackendType;
13#[cfg(feature = "external")]
14use crate::scalar_type::ExternalScalar;
15use crate::scalar_type::Scalar;
16use crate::scalar_type::ScalarType;
17use crate::{
18 generate_ic_option_accessors, generate_ode_option_accessors, generate_option_accessors,
19 option_value_from_store, option_value_to_store,
20};
21use crate::{generate_trait_ic_option_accessors, generate_trait_ode_option_accessors};
22use diffsol::ObjectModule;
23use diffsol::OdeBuilder;
24use diffsol::{
25 error::DiffsolError,
26 matrix::{MatrixHost, MatrixRef},
27 CodegenModule, ConstantOp, DefaultDenseMatrix, DefaultSolver, DenseMatrix, DiffSl,
28 MatrixCommon, NonLinearOp, NonLinearOpJacobian, OdeEquations, OdeSolverProblem, Op, Vector,
29 VectorCommon, VectorHost, VectorRef,
30};
31#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
32use diffsol::{CodegenModuleCompile, CodegenModuleJit};
33use num_traits::{FromPrimitive, ToPrimitive}; use paste::paste;
35
36use crate::solve_serialization::{unsupported_serialization_error, SolveSerialization};
37use crate::{
38 linear_solver_type::LinearSolverType, matrix_type::MatrixType, ode_solver_type::OdeSolverType,
39};
40use crate::{
41 matrix_type::MatrixKind,
42 valid_linear_solver::{validate_linear_solver, KluValidator, LuValidator},
43};
44
45use crate::solution::Solution;
48pub(crate) type SolveResult = Result<Box<dyn Solution>, DiffsolRtError>;
49
50pub(crate) trait Solve {
51 fn matrix_type(&self) -> MatrixType;
52 fn nstates(&self) -> usize;
53 fn nparams(&self) -> usize;
54 fn nout(&self) -> usize;
55 fn has_stop(&self) -> bool;
56
57 fn rhs(&mut self, params: &[f64], t: f64, y: &[f64]) -> Result<HostArray, DiffsolRtError>;
58
59 fn rhs_jac_mul(
60 &mut self,
61 params: &[f64],
62 t: f64,
63 y: &[f64],
64 v: &[f64],
65 ) -> Result<HostArray, DiffsolRtError>;
66
67 fn y0(&mut self, params: &[f64]) -> Result<HostArray, DiffsolRtError>;
68
69 fn check(&self, linear_solver: LinearSolverType) -> Result<(), DiffsolRtError>;
70 fn set_rtol(&mut self, rtol: f64);
71 fn rtol(&self) -> f64;
72 fn set_atol(&mut self, atol: f64);
73 fn atol(&self) -> f64;
74 fn set_t0(&mut self, t0: f64);
75 fn t0(&self) -> f64;
76 fn set_h0(&mut self, h0: f64);
77 fn h0(&self) -> f64;
78 fn set_integrate_out(&mut self, integrate_out: bool);
79 fn integrate_out(&self) -> bool;
80 fn set_sens_rtol(&mut self, sens_rtol: Option<f64>);
81 fn sens_rtol(&self) -> Option<f64>;
82 fn set_sens_atol(&mut self, sens_atol: Option<f64>);
83 fn sens_atol(&self) -> Option<f64>;
84 fn set_out_rtol(&mut self, out_rtol: Option<f64>);
85 fn out_rtol(&self) -> Option<f64>;
86 fn set_out_atol(&mut self, out_atol: Option<f64>);
87 fn out_atol(&self) -> Option<f64>;
88 fn set_param_rtol(&mut self, param_rtol: Option<f64>);
89 fn param_rtol(&self) -> Option<f64>;
90 fn set_param_atol(&mut self, param_atol: Option<f64>);
91 fn param_atol(&self) -> Option<f64>;
92 fn serialized_diffsl(&self) -> Result<Vec<u8>, DiffsolRtError> {
93 unsupported_serialization_error(
94 "ODE serialization is only supported for JIT-backed solvers",
95 )
96 }
97
98 fn solve(
100 &mut self,
101 method: OdeSolverType,
102 linear_solver: LinearSolverType,
103 params: &[f64],
104 final_time: f64,
105 ) -> SolveResult;
106
107 fn solve_dense(
108 &mut self,
109 method: OdeSolverType,
110 linear_solver: LinearSolverType,
111 params: &[f64],
112 t_eval: &[f64],
113 ) -> SolveResult;
114
115 fn solve_fwd_sens(
116 &mut self,
117 method: OdeSolverType,
118 linear_solver: LinearSolverType,
119 params: &[f64],
120 t_eval: &[f64],
121 ) -> SolveResult;
122
123 fn solve_continuous_adjoint(
124 &mut self,
125 method: OdeSolverType,
126 linear_solver: LinearSolverType,
127 params: &[f64],
128 final_time: f64,
129 ) -> Result<(HostArray, HostArray), DiffsolRtError>;
130
131 fn solve_adjoint_fwd(
132 &mut self,
133 method: OdeSolverType,
134 linear_solver: LinearSolverType,
135 params: &[f64],
136 t_eval: &[f64],
137 ) -> Result<(Box<dyn Solution>, AdjointCheckpointWrapper), DiffsolRtError>;
138
139 fn solve_adjoint_bkwd(
140 &mut self,
141 method: OdeSolverType,
142 linear_solver: LinearSolverType,
143 checkpoint: &AdjointCheckpointWrapper,
144 t_eval: &[f64],
145 dgdu_eval: HostArray,
146 ) -> Result<HostArray, DiffsolRtError>;
147
148 generate_trait_ic_option_accessors! {
149 use_linesearch: bool,
150 max_linesearch_iterations: usize,
151 max_newton_iterations: usize,
152 max_linear_solver_setups: usize,
153 step_reduction_factor: f64,
154 armijo_constant: f64
155 }
156 generate_trait_ode_option_accessors! {
157 max_nonlinear_solver_iterations: usize,
158 max_error_test_failures: usize,
159 min_timestep: f64,
160 update_jacobian_after_steps: usize,
161 update_rhs_jacobian_after_steps: usize,
162 threshold_to_update_jacobian: f64,
163 threshold_to_update_rhs_jacobian: f64
164 }
165}
#[cfg(feature = "external")]
/// Builds a `Solve` backed by a statically available `diffsl::ExternalModule`.
///
/// Only the scalar types whose `diffsl-external-*` feature is enabled are
/// supported. Any other scalar type yields a `DiffsolError::Other` naming the
/// matrix type.
///
/// # Errors
/// Returns an error if the scalar type is not enabled for external modules, or
/// if constructing the equations or the problem fails.
// When both `diffsl-external-f32` and `diffsl-external-f64` are enabled, every
// `ScalarType` variant has a dedicated arm. The `_` fallback arms below are
// then unreachable, and rustc would emit `unreachable_patterns`, which fails
// `-D warnings` builds. The fallback is still needed when only one feature is
// enabled.
#[allow(unreachable_patterns)]
pub(crate) fn solve_factory_external(
    rhs_state_deps: Vec<(usize, usize)>,
    rhs_input_deps: Vec<(usize, usize)>,
    mass_state_deps: Vec<(usize, usize)>,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError> {
    let solve: Box<dyn Solve> = match matrix_type {
        MatrixType::NalgebraDense => match scalar_type {
            #[cfg(feature = "diffsl-external-f32")]
            ScalarType::F32 => Box::new(GenericSolve::<
                diffsol::NalgebraMat<f32>,
                diffsl::ExternalModule<f32>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[cfg(feature = "diffsl-external-f64")]
            ScalarType::F64 => Box::new(GenericSolve::<
                diffsol::NalgebraMat<f64>,
                diffsl::ExternalModule<f64>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            _ => {
                return Err(DiffsolRtError::from(DiffsolError::Other(
                    "Unsupported scalar type for NalgebraDense".to_string(),
                )));
            }
        },
        MatrixType::FaerDense => match scalar_type {
            #[cfg(feature = "diffsl-external-f32")]
            ScalarType::F32 => Box::new(GenericSolve::<
                diffsol::FaerMat<f32>,
                diffsl::ExternalModule<f32>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[cfg(feature = "diffsl-external-f64")]
            ScalarType::F64 => Box::new(GenericSolve::<
                diffsol::FaerMat<f64>,
                diffsl::ExternalModule<f64>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            _ => {
                return Err(DiffsolRtError::from(DiffsolError::Other(
                    "Unsupported scalar type for FaerDense".to_string(),
                )));
            }
        },
        MatrixType::FaerSparse => match scalar_type {
            #[cfg(feature = "diffsl-external-f32")]
            ScalarType::F32 => Box::new(GenericSolve::<
                diffsol::FaerSparseMat<f32>,
                diffsl::ExternalModule<f32>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[cfg(feature = "diffsl-external-f64")]
            ScalarType::F64 => Box::new(GenericSolve::<
                diffsol::FaerSparseMat<f64>,
                diffsl::ExternalModule<f64>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            _ => {
                return Err(DiffsolRtError::from(DiffsolError::Other(
                    "Unsupported scalar type for FaerSparse".to_string(),
                )));
            }
        },
    };
    Ok(solve)
}
242
#[cfg(feature = "diffsl-external-dynamic")]
/// Builds a `Solve` backed by a `diffsl::ExternalDynModule` created from `path`.
///
/// Every (matrix type, scalar type) combination is supported. Sensitivities
/// are not requested (`include_sensitivities = false`).
///
/// # Errors
/// Propagates any error from `GenericSolve::from_external_dynamic`.
pub(crate) fn solve_factory_external_dynamic(
    path: PathBuf,
    rhs_state_deps: Vec<(usize, usize)>,
    rhs_input_deps: Vec<(usize, usize)>,
    mass_state_deps: Vec<(usize, usize)>,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError> {
    // Expands to the boxed solver for one concrete matrix type / scalar pairing.
    // Each match arm below uses the macro exactly once, so moving the inputs is fine.
    macro_rules! dynamic_solver {
        ($matrix:ty, $scalar:ty) => {
            Box::new(
                GenericSolve::<$matrix, diffsl::ExternalDynModule<$scalar>>::from_external_dynamic(
                    path,
                    rhs_state_deps,
                    rhs_input_deps,
                    mass_state_deps,
                    false,
                )?,
            )
        };
    }

    let solve: Box<dyn Solve> = match (matrix_type, scalar_type) {
        (MatrixType::NalgebraDense, ScalarType::F32) => {
            dynamic_solver!(diffsol::NalgebraMat<f32>, f32)
        }
        (MatrixType::NalgebraDense, ScalarType::F64) => {
            dynamic_solver!(diffsol::NalgebraMat<f64>, f64)
        }
        (MatrixType::FaerDense, ScalarType::F32) => dynamic_solver!(diffsol::FaerMat<f32>, f32),
        (MatrixType::FaerDense, ScalarType::F64) => dynamic_solver!(diffsol::FaerMat<f64>, f64),
        (MatrixType::FaerSparse, ScalarType::F32) => {
            dynamic_solver!(diffsol::FaerSparseMat<f32>, f32)
        }
        (MatrixType::FaerSparse, ScalarType::F64) => {
            dynamic_solver!(diffsol::FaerSparseMat<f64>, f64)
        }
    };
    Ok(solve)
}
322
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
/// Compiles DiffSL `code` with the requested JIT backend and wraps the result
/// in a `Solve` for the given matrix and scalar types.
///
/// Only backends whose cargo feature is enabled are available as variants.
///
/// # Errors
/// Propagates compilation and problem-construction errors from
/// `solve_factory_with_jit_backend`.
pub(crate) fn solve_factory_jit(
    code: &str,
    jit_backend: JitBackendType,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError> {
    match jit_backend {
        #[cfg(feature = "diffsl-llvm")]
        JitBackendType::Llvm => {
            solve_factory_with_jit_backend::<diffsol::LlvmModule>(code, matrix_type, scalar_type)
        }
        #[cfg(feature = "diffsl-cranelift")]
        JitBackendType::Cranelift => {
            let (matrix, scalar) = (matrix_type, scalar_type);
            solve_factory_with_jit_backend::<diffsol::CraneliftJitModule>(code, matrix, scalar)
        }
    }
}
343
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
/// Compiles `code` with codegen backend `CG` into an `OdeSolverProblem` for
/// the selected matrix type and scalar type, and boxes it as a `Solve`.
///
/// # Errors
/// Propagates any error from `OdeBuilder::build_from_diffsl`, e.g. a DiffSL
/// compilation failure.
fn solve_factory_with_jit_backend<CG>(
    code: &str,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError>
where
    CG: CodegenModule
        + CodegenModuleJit
        + CodegenModuleCompile
        + SolveSerialization<diffsol::NalgebraMat<f32>>
        + SolveSerialization<diffsol::NalgebraMat<f64>>
        + SolveSerialization<diffsol::FaerMat<f32>>
        + SolveSerialization<diffsol::FaerMat<f64>>
        + SolveSerialization<diffsol::FaerSparseMat<f32>>
        + SolveSerialization<diffsol::FaerSparseMat<f64>>,
{
    // Builds the problem for one concrete matrix type and wraps it as a solver.
    macro_rules! jit_solver {
        ($matrix:ty) => {{
            let problem = OdeBuilder::<$matrix>::new().build_from_diffsl::<CG>(code)?;
            Box::new(GenericSolve { problem })
        }};
    }

    let solve: Box<dyn Solve> = match (matrix_type, scalar_type) {
        (MatrixType::NalgebraDense, ScalarType::F32) => jit_solver!(diffsol::NalgebraMat<f32>),
        (MatrixType::NalgebraDense, ScalarType::F64) => jit_solver!(diffsol::NalgebraMat<f64>),
        (MatrixType::FaerDense, ScalarType::F32) => jit_solver!(diffsol::FaerMat<f32>),
        (MatrixType::FaerDense, ScalarType::F64) => jit_solver!(diffsol::FaerMat<f64>),
        (MatrixType::FaerSparse, ScalarType::F32) => jit_solver!(diffsol::FaerSparseMat<f32>),
        (MatrixType::FaerSparse, ScalarType::F64) => jit_solver!(diffsol::FaerSparseMat<f64>),
    };
    Ok(solve)
}
401
402pub(crate) fn solve_factory_from_serialized_diffsl(
403 serialized_diffsl: &[u8],
404 matrix_type: MatrixType,
405 scalar_type: ScalarType,
406) -> Result<Box<dyn Solve>, DiffsolRtError> {
407 let solve: Box<dyn Solve> = match matrix_type {
408 MatrixType::NalgebraDense => match scalar_type {
409 ScalarType::F32 => Box::new(
410 GenericSolve::<diffsol::NalgebraMat<f32>, ObjectModule>::from_serialized_diffsl(
411 serialized_diffsl,
412 )?,
413 ),
414 ScalarType::F64 => Box::new(
415 GenericSolve::<diffsol::NalgebraMat<f64>, ObjectModule>::from_serialized_diffsl(
416 serialized_diffsl,
417 )?,
418 ),
419 },
420 MatrixType::FaerDense => match scalar_type {
421 ScalarType::F32 => Box::new(
422 GenericSolve::<diffsol::FaerMat<f32>, ObjectModule>::from_serialized_diffsl(
423 serialized_diffsl,
424 )?,
425 ),
426 ScalarType::F64 => Box::new(
427 GenericSolve::<diffsol::FaerMat<f64>, ObjectModule>::from_serialized_diffsl(
428 serialized_diffsl,
429 )?,
430 ),
431 },
432 MatrixType::FaerSparse => match scalar_type {
433 ScalarType::F32 => Box::new(
434 GenericSolve::<diffsol::FaerSparseMat<f32>, ObjectModule>::from_serialized_diffsl(
435 serialized_diffsl,
436 )?,
437 ),
438 ScalarType::F64 => Box::new(
439 GenericSolve::<diffsol::FaerSparseMat<f64>, ObjectModule>::from_serialized_diffsl(
440 serialized_diffsl,
441 )?,
442 ),
443 },
444 };
445 Ok(solve)
446}
447
/// Concrete solver state for one matrix type `M` and codegen backend `CG`.
///
/// Holds the `OdeSolverProblem` built from DiffSL equations. The `Solve`
/// trait implementation exposes it through an `f64`-based interface.
pub(crate) struct GenericSolve<M, CG>
where
    M: MatrixHost<T: Scalar>,
    M::V: Vector + VectorHost,
    CG: CodegenModule,
{
    /// The equations plus solver options that every `Solve` method operates on.
    problem: OdeSolverProblem<DiffSl<M, CG>>,
}
456
457impl<M, CG> GenericSolve<M, CG>
458where
459 M: MatrixHost<T: Scalar>,
460 M::V: Vector + VectorHost + DefaultDenseMatrix,
461 CG: CodegenModule,
462{
463 fn from_eqn(eqn: DiffSl<M, CG>) -> Result<Self, DiffsolRtError> {
464 let default_p = vec![0.0; eqn.nparams()];
465 let problem = OdeBuilder::<M>::new().p(default_p).build_from_eqn(eqn)?;
466 Ok(GenericSolve { problem })
467 }
468
469 pub(crate) fn setup_problem(&mut self, params: &[f64]) -> Result<(), DiffsolRtError> {
470 let params: Vec<M::T> = params.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
471 let params = M::V::from_slice(¶ms, M::C::default());
472
473 let nparams = self.problem.eqn.nparams();
475 if params.len() == nparams {
476 self.problem.eqn.set_params(¶ms);
477 self.problem.eqn.set_model_index(0);
479 Ok(())
480 } else {
481 Err(DiffsolError::Other(format!(
482 "Expecting {} params but got {}",
483 nparams,
484 params.len()
485 ))
486 .into())
487 }
488 }
489
490 pub(crate) fn serialize_eqn(&self) -> Result<Vec<u8>, DiffsolRtError>
491 where
492 DiffSl<M, CG>: serde::Serialize,
493 {
494 serde_json::to_vec(&self.problem.eqn)
495 .map_err(|e| DiffsolRtError::from(DiffsolError::Other(e.to_string())))
496 }
497}
498
#[cfg(feature = "external")]
impl<M> GenericSolve<M, diffsl::ExternalModule<M::T>>
where
    M: MatrixHost<T: ExternalScalar>,
    M::V: Vector + VectorHost + DefaultDenseMatrix,
{
    /// Creates a solver whose equations come from `diffsl::ExternalModule<M::T>`.
    ///
    /// The three dependency lists are forwarded unchanged to
    /// `DiffSl::from_external`. They are presumably (row, col) sparsity
    /// patterns; confirm against diffsol.
    ///
    /// # Errors
    /// Propagates errors from building the equations or the problem.
    pub fn from_external(
        rhs_state_deps: Vec<(usize, usize)>,
        rhs_input_deps: Vec<(usize, usize)>,
        mass_state_deps: Vec<(usize, usize)>,
        include_sensitivities: bool,
    ) -> Result<Self, DiffsolRtError> {
        let context = M::C::default();
        Self::from_eqn(DiffSl::<M, diffsl::ExternalModule<M::T>>::from_external(
            context,
            rhs_state_deps,
            rhs_input_deps,
            mass_state_deps,
            include_sensitivities,
        )?)
    }
}
521
#[cfg(feature = "diffsl-external-dynamic")]
impl<M> GenericSolve<M, diffsl::ExternalDynModule<M::T>>
where
    M: MatrixHost<T: Scalar>,
    M::V: Vector + VectorHost + DefaultDenseMatrix,
{
    /// Creates a solver whose equations come from a
    /// `diffsl::ExternalDynModule<M::T>` built from `path`.
    ///
    /// The dependency lists are forwarded unchanged to
    /// `DiffSl::from_external_dynamic`.
    ///
    /// # Errors
    /// Propagates errors from building the equations or the problem.
    pub fn from_external_dynamic(
        path: impl Into<PathBuf>,
        rhs_state_deps: Vec<(usize, usize)>,
        rhs_input_deps: Vec<(usize, usize)>,
        mass_state_deps: Vec<(usize, usize)>,
        include_sensitivities: bool,
    ) -> Result<Self, DiffsolRtError> {
        let context = M::C::default();
        Self::from_eqn(
            DiffSl::<M, diffsl::ExternalDynModule<M::T>>::from_external_dynamic(
                path,
                context,
                rhs_state_deps,
                rhs_input_deps,
                mass_state_deps,
                include_sensitivities,
            )?,
        )
    }
}
546
547impl<M> GenericSolve<M, ObjectModule>
548where
549 M: MatrixHost<T: Scalar>,
550 M::V: Vector + VectorHost + DefaultDenseMatrix,
551{
552 pub fn from_serialized_diffsl(serialized_diffsl: &[u8]) -> Result<Self, DiffsolRtError> {
553 let eqn = serde_json::from_slice::<DiffSl<M, ObjectModule>>(serialized_diffsl)
554 .map_err(|e| DiffsolRtError::from(DiffsolError::Other(e.to_string())))?;
555 Self::from_eqn(eqn)
556 }
557}
558
559impl<M, CG> Solve for GenericSolve<M, CG>
560where
561 M: MatrixHost<T: Scalar + ToPrimitive>
562 + DefaultSolver
563 + LuValidator<M>
564 + KluValidator<M>
565 + MatrixKind,
566 CG: CodegenModule,
567 for<'b> <<M::V as DefaultDenseMatrix>::M as MatrixCommon>::Inner: ToHostArray<M::T> + Clone,
568 for<'b> <M::V as VectorCommon>::Inner: ToHostArray<M::T> + Clone,
569 M::V: VectorHost + DefaultDenseMatrix + Send + Sync + 'static,
570 <M::V as DefaultDenseMatrix>::M: Send + Sync,
571 for<'b> &'b M::V: VectorRef<M::V>,
572 for<'b> &'b M: MatrixRef<M>,
573 CG: SolveSerialization<M>,
574{
575 fn matrix_type(&self) -> MatrixType {
576 MatrixType::from_diffsol::<M>()
577 }
578
579 fn nstates(&self) -> usize {
580 self.problem.eqn.nstates()
581 }
582
583 fn nparams(&self) -> usize {
584 self.problem.eqn.nparams()
585 }
586
587 fn nout(&self) -> usize {
588 self.problem.eqn.nout()
589 }
590
591 fn has_stop(&self) -> bool {
592 self.problem.eqn.root().is_some()
593 }
594
595 fn check(&self, linear_solver: LinearSolverType) -> Result<(), DiffsolRtError> {
596 validate_linear_solver::<M>(linear_solver)
597 }
598
599 fn set_atol(&mut self, atol: f64) {
600 self.problem.atol.fill(M::T::from_f64(atol).unwrap());
601 }
602
603 fn atol(&self) -> f64 {
604 self.problem.atol[0].to_f64().unwrap()
605 }
606
607 fn serialized_diffsl(&self) -> Result<Vec<u8>, DiffsolRtError> {
608 CG::serialized_diffsl(self)
609 }
610
611 fn set_rtol(&mut self, rtol: f64) {
612 self.problem.rtol = M::T::from_f64(rtol).unwrap();
613 }
614
615 fn rtol(&self) -> f64 {
616 self.problem.rtol.to_f64().unwrap()
617 }
618
619 fn set_t0(&mut self, t0: f64) {
620 self.problem.t0 = M::T::from_f64(t0).unwrap();
621 }
622
623 fn t0(&self) -> f64 {
624 self.problem.t0.to_f64().unwrap()
625 }
626
627 fn set_h0(&mut self, h0: f64) {
628 self.problem.h0 = M::T::from_f64(h0).unwrap();
629 }
630
631 fn h0(&self) -> f64 {
632 self.problem.h0.to_f64().unwrap()
633 }
634
635 fn set_integrate_out(&mut self, integrate_out: bool) {
636 self.problem.integrate_out = integrate_out;
637 }
638
639 fn integrate_out(&self) -> bool {
640 self.problem.integrate_out
641 }
642
643 fn set_sens_rtol(&mut self, sens_rtol: Option<f64>) {
644 self.problem.sens_rtol = sens_rtol.map(|value| M::T::from_f64(value).unwrap());
645 }
646
647 fn sens_rtol(&self) -> Option<f64> {
648 self.problem.sens_rtol.map(|value| value.to_f64().unwrap())
649 }
650
651 fn set_sens_atol(&mut self, sens_atol: Option<f64>) {
652 self.problem.sens_atol = sens_atol.map(|value| {
653 M::V::from_element(
654 self.problem.eqn.nstates(),
655 M::T::from_f64(value).unwrap(),
656 M::C::default(),
657 )
658 });
659 }
660
661 fn sens_atol(&self) -> Option<f64> {
662 self.problem
663 .sens_atol
664 .as_ref()
665 .and_then(|value| (value.len() > 0).then(|| value.get_index(0).to_f64().unwrap()))
666 }
667
668 fn set_out_rtol(&mut self, out_rtol: Option<f64>) {
669 self.problem.out_rtol = out_rtol.map(|value| M::T::from_f64(value).unwrap());
670 }
671
672 fn out_rtol(&self) -> Option<f64> {
673 self.problem.out_rtol.map(|value| value.to_f64().unwrap())
674 }
675
676 fn set_out_atol(&mut self, out_atol: Option<f64>) {
677 self.problem.out_atol = out_atol.map(|value| {
678 let len = if self.problem.eqn.out().is_some() {
679 self.problem.eqn.nout()
680 } else {
681 self.problem.eqn.nstates()
682 };
683 M::V::from_element(len, M::T::from_f64(value).unwrap(), M::C::default())
684 });
685 }
686
687 fn out_atol(&self) -> Option<f64> {
688 self.problem
689 .out_atol
690 .as_ref()
691 .and_then(|value| (value.len() > 0).then(|| value.get_index(0).to_f64().unwrap()))
692 }
693
694 fn set_param_rtol(&mut self, param_rtol: Option<f64>) {
695 self.problem.param_rtol = param_rtol.map(|value| M::T::from_f64(value).unwrap());
696 }
697
698 fn param_rtol(&self) -> Option<f64> {
699 self.problem.param_rtol.map(|value| value.to_f64().unwrap())
700 }
701
702 fn set_param_atol(&mut self, param_atol: Option<f64>) {
703 self.problem.param_atol = param_atol.map(|value| {
704 M::V::from_element(
705 self.problem.eqn.nparams(),
706 M::T::from_f64(value).unwrap(),
707 M::C::default(),
708 )
709 });
710 }
711
712 fn param_atol(&self) -> Option<f64> {
713 self.problem
714 .param_atol
715 .as_ref()
716 .and_then(|value| (value.len() > 0).then(|| value.get_index(0).to_f64().unwrap()))
717 }
718
719 generate_ic_option_accessors! {
720 use_linesearch: bool,
721 max_linesearch_iterations: usize,
722 max_newton_iterations: usize,
723 max_linear_solver_setups: usize,
724 step_reduction_factor: f64,
725 armijo_constant: f64
726 }
727
728 generate_ode_option_accessors! {
729 max_nonlinear_solver_iterations: usize,
730 max_error_test_failures: usize,
731 min_timestep: f64,
732 update_jacobian_after_steps: usize,
733 update_rhs_jacobian_after_steps: usize,
734 threshold_to_update_jacobian: f64,
735 threshold_to_update_rhs_jacobian: f64
736 }
737
738 fn y0(&mut self, params: &[f64]) -> Result<HostArray, DiffsolRtError> {
739 self.setup_problem(params)?;
740 let n = self.problem.eqn.nstates();
741 let mut y0 = M::V::zeros(n, M::C::default());
742 let t0 = self.problem.t0;
743 self.problem.eqn.init().call_inplace(t0, &mut y0);
744 Ok((*y0.inner()).clone().to_host_array())
745 }
746
747 fn rhs(&mut self, params: &[f64], t: f64, y: &[f64]) -> Result<HostArray, DiffsolRtError> {
748 self.setup_problem(params)?;
749 let n = self.problem.eqn.nstates();
750 let y = y
751 .iter()
752 .map(|&x| M::T::from_f64(x).unwrap())
753 .collect::<Vec<_>>();
754 let y_vec = M::V::from_slice(&y, M::C::default());
755 let mut dydt = M::V::zeros(n, M::C::default());
756 self.problem
757 .eqn
758 .rhs()
759 .call_inplace(&y_vec, M::T::from_f64(t).unwrap(), &mut dydt);
760 Ok((*dydt.inner()).clone().to_host_array())
761 }
762
763 fn rhs_jac_mul(
764 &mut self,
765
766 params: &[f64],
767 t: f64,
768 y: &[f64],
769 v: &[f64],
770 ) -> Result<HostArray, DiffsolRtError> {
771 self.setup_problem(params)?;
772 let n = self.problem.eqn.nstates();
773 let y = y
774 .iter()
775 .map(|&x| M::T::from_f64(x).unwrap())
776 .collect::<Vec<_>>();
777 let v = v
778 .iter()
779 .map(|&x| M::T::from_f64(x).unwrap())
780 .collect::<Vec<_>>();
781 let y_vec = M::V::from_slice(&y, M::C::default());
782 let v_vec = M::V::from_slice(&v, M::C::default());
783 let mut dydt = M::V::zeros(n, M::C::default());
784 self.problem.eqn.rhs().jac_mul_inplace(
785 &y_vec,
786 M::T::from_f64(t).unwrap(),
787 &v_vec,
788 &mut dydt,
789 );
790 Ok((*dydt.inner()).clone().to_host_array())
791 }
792
793 fn solve(
794 &mut self,
795 method: OdeSolverType,
796 linear_solver: LinearSolverType,
797 params: &[f64],
798 final_time: f64,
799 ) -> SolveResult {
800 self.check(linear_solver)?;
801 self.setup_problem(params)?;
802 let final_time = M::T::from_f64(final_time).unwrap();
803 let soln = match linear_solver {
804 LinearSolverType::Default => {
805 method.solve::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, final_time)
806 }
807 LinearSolverType::Lu => {
808 method.solve::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, final_time)
809 }
810 LinearSolverType::Klu => {
811 method.solve::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, final_time)
812 }
813 };
814 Ok(Box::new(soln?))
815 }
816
817 fn solve_dense(
818 &mut self,
819 method: OdeSolverType,
820 linear_solver: LinearSolverType,
821 params: &[f64],
822 t_eval: &[f64],
823 ) -> SolveResult {
824 self.check(linear_solver)?;
825 self.setup_problem(params)?;
826
827 let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
828 let soln =
829 match linear_solver {
830 LinearSolverType::Default => method
831 .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval),
832 LinearSolverType::Lu => method
833 .solve_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
834 LinearSolverType::Klu => method
835 .solve_dense::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, &t_eval),
836 };
837 Ok(Box::new(soln?))
838 }
839
840 fn solve_fwd_sens(
841 &mut self,
842 method: OdeSolverType,
843 linear_solver: LinearSolverType,
844 params: &[f64],
845 t_eval: &[f64],
846 ) -> SolveResult {
847 self.check(linear_solver)?;
848 self.setup_problem(params)?;
849
850 let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
851 let soln = match linear_solver {
852 LinearSolverType::Default => {
853 method.solve_fwd_sens::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval)
854 }
855 LinearSolverType::Lu => method
856 .solve_fwd_sens::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
857 LinearSolverType::Klu => method
858 .solve_fwd_sens::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, &t_eval),
859 };
860 Ok(Box::new(soln?))
861 }
862
863 fn solve_continuous_adjoint(
864 &mut self,
865 method: OdeSolverType,
866 linear_solver: LinearSolverType,
867 params: &[f64],
868 final_time: f64,
869 ) -> Result<(HostArray, HostArray), DiffsolRtError> {
870 self.check(linear_solver)?;
871 self.setup_problem(params)?;
872
873 let final_time = M::T::from_f64(final_time).unwrap();
874 let integrate_out = self.problem.integrate_out;
875 self.problem.integrate_out = true;
876 let result = match linear_solver {
877 LinearSolverType::Default => method
878 .solve_continuous_adjoint::<M, CG, <M as DefaultSolver>::LS>(
879 &mut self.problem,
880 final_time,
881 ),
882 LinearSolverType::Lu => method
883 .solve_continuous_adjoint::<M, CG, <M as LuValidator<M>>::LS>(
884 &mut self.problem,
885 final_time,
886 ),
887 LinearSolverType::Klu => method
888 .solve_continuous_adjoint::<M, CG, <M as KluValidator<M>>::LS>(
889 &mut self.problem,
890 final_time,
891 ),
892 };
893 self.problem.integrate_out = integrate_out;
894 let (integral, gradient) = result?;
895 Ok((
896 (*integral.inner()).clone().to_host_array(),
897 gradient.to_host_array(),
898 ))
899 }
900
901 fn solve_adjoint_fwd(
902 &mut self,
903 method: OdeSolverType,
904 linear_solver: LinearSolverType,
905 params: &[f64],
906 t_eval: &[f64],
907 ) -> Result<(Box<dyn Solution>, AdjointCheckpointWrapper), DiffsolRtError> {
908 self.check(linear_solver)?;
909 self.setup_problem(params)?;
910
911 let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
912 let (soln, checkpoint) = match linear_solver {
913 LinearSolverType::Default => method
914 .solve_adjoint_fwd::<M, CG, <M as DefaultSolver>::LS>(
915 &mut self.problem,
916 &t_eval,
917 params,
918 linear_solver,
919 ),
920 LinearSolverType::Lu => method.solve_adjoint_fwd::<M, CG, <M as LuValidator<M>>::LS>(
921 &mut self.problem,
922 &t_eval,
923 params,
924 linear_solver,
925 ),
926 LinearSolverType::Klu => method.solve_adjoint_fwd::<M, CG, <M as KluValidator<M>>::LS>(
927 &mut self.problem,
928 &t_eval,
929 params,
930 linear_solver,
931 ),
932 }?;
933 Ok((Box::new(soln), AdjointCheckpointWrapper::new(checkpoint)))
934 }
935
936 fn solve_adjoint_bkwd(
937 &mut self,
938 method: OdeSolverType,
939 linear_solver: LinearSolverType,
940 checkpoint: &AdjointCheckpointWrapper,
941 t_eval: &[f64],
942 dgdu_eval: HostArray,
943 ) -> Result<HostArray, DiffsolRtError> {
944 self.check(linear_solver)?;
945 let checkpoint = checkpoint.guard()?;
946 self.setup_problem(checkpoint.params())?;
947
948 let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
949 let dgdu_eval = host_array_to_dense_matrix::<M>(dgdu_eval)?;
950 if dgdu_eval.nrows() != self.problem.eqn.nout() {
951 return Err(DiffsolError::Other(format!(
952 "Expected dgdu_eval to have {} rows, got {}",
953 self.problem.eqn.nout(),
954 dgdu_eval.nrows()
955 ))
956 .into());
957 }
958 if dgdu_eval.ncols() != t_eval.len() {
959 return Err(DiffsolError::Other(format!(
960 "Expected dgdu_eval to have {} columns, got {}",
961 t_eval.len(),
962 dgdu_eval.ncols()
963 ))
964 .into());
965 }
966
967 let gradient = match linear_solver {
968 LinearSolverType::Default => method
969 .solve_adjoint_bkwd::<M, CG, <M as DefaultSolver>::LS>(
970 &mut self.problem,
971 checkpoint.as_ref(),
972 &dgdu_eval,
973 &t_eval,
974 ),
975 LinearSolverType::Lu => method.solve_adjoint_bkwd::<M, CG, <M as LuValidator<M>>::LS>(
976 &mut self.problem,
977 checkpoint.as_ref(),
978 &dgdu_eval,
979 &t_eval,
980 ),
981 LinearSolverType::Klu => method
982 .solve_adjoint_bkwd::<M, CG, <M as KluValidator<M>>::LS>(
983 &mut self.problem,
984 checkpoint.as_ref(),
985 &dgdu_eval,
986 &t_eval,
987 ),
988 }?;
989 Ok(gradient.to_host_array())
990 }
991}
992
993fn host_array_to_dense_matrix<M>(
994 array: HostArray,
995) -> Result<<M::V as DefaultDenseMatrix>::M, DiffsolRtError>
996where
997 M: MatrixHost<T: Scalar>,
998 M::V: DefaultDenseMatrix,
999{
1000 let view = array.as_array::<M::T>()?;
1001 let mut values = Vec::with_capacity(view.nrows() * view.ncols());
1002 for col in 0..view.ncols() {
1003 for row in 0..view.nrows() {
1004 values.push(view[(row, col)]);
1005 }
1006 }
1007 Ok(<M::V as DefaultDenseMatrix>::M::from_vec(
1008 view.nrows(),
1009 view.ncols(),
1010 values,
1011 M::C::default(),
1012 ))
1013}
1014
1015#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
1016mod tests {
1017 use diffsol::MatrixCommon;
1018 use diffsol::{
1019 CodegenModuleCompile, CodegenModuleJit, Context, OdeBuilder, OdeEquations, Vector,
1020 };
1021
1022 #[cfg(feature = "diffsl-llvm")]
1023 use crate::test_support::{hybrid_logistic_state_dr, logistic_state_dr};
1024 use crate::{
1025 host_array::FromHostArray,
1026 linear_solver_type::LinearSolverType,
1027 matrix_type::MatrixType,
1028 ode_solver_type::OdeSolverType,
1029 scalar_type::ScalarType,
1030 test_support::{
1031 assert_close, hybrid_logistic_diffsl_code, hybrid_logistic_state, logistic_diffsl_code,
1032 logistic_state, LOGISTIC_X0,
1033 },
1034 };
1035
1036 use super::{solve_factory_with_jit_backend, GenericSolve, Solve, SolveSerialization};
1037
1038 fn make_generic_solve<T, CG>() -> GenericSolve<diffsol::NalgebraMat<T>, CG>
1039 where
1040 T: crate::scalar_type::Scalar + diffsol::NalgebraScalar,
1041 CG: diffsol::CodegenModule
1042 + CodegenModuleJit
1043 + CodegenModuleCompile
1044 + SolveSerialization<diffsol::NalgebraMat<T>>,
1045 diffsol::NalgebraMat<T>: diffsol::matrix::MatrixHost,
1046 <diffsol::NalgebraMat<T> as MatrixCommon>::T: crate::scalar_type::Scalar,
1047 <diffsol::NalgebraMat<T> as MatrixCommon>::V: diffsol::DefaultDenseMatrix,
1048 {
1049 let problem = OdeBuilder::<diffsol::NalgebraMat<T>>::new()
1050 .build_from_diffsl::<CG>(logistic_diffsl_code())
1051 .unwrap();
1052 GenericSolve { problem }
1053 }
1054
1055 fn assert_factory_supports_all_matrix_and_scalar_types<CG>()
1056 where
1057 CG: diffsol::CodegenModule
1058 + CodegenModuleJit
1059 + CodegenModuleCompile
1060 + SolveSerialization<diffsol::NalgebraMat<f32>>
1061 + SolveSerialization<diffsol::NalgebraMat<f64>>
1062 + SolveSerialization<diffsol::FaerMat<f32>>
1063 + SolveSerialization<diffsol::FaerMat<f64>>
1064 + SolveSerialization<diffsol::FaerSparseMat<f32>>
1065 + SolveSerialization<diffsol::FaerSparseMat<f64>>,
1066 {
1067 for matrix_type in [
1068 MatrixType::NalgebraDense,
1069 MatrixType::FaerDense,
1070 MatrixType::FaerSparse,
1071 ] {
1072 for scalar_type in [ScalarType::F32, ScalarType::F64] {
1073 assert!(solve_factory_with_jit_backend::<CG>(
1074 logistic_diffsl_code(),
1075 matrix_type,
1076 scalar_type,
1077 )
1078 .is_ok());
1079 }
1080 }
1081 }
1082
1083 fn assert_solve_metadata_and_helpers<T, CG>(tol: f64)
1084 where
1085 T: crate::scalar_type::Scalar + diffsol::NalgebraScalar,
1086 CG: diffsol::CodegenModule
1087 + CodegenModuleJit
1088 + CodegenModuleCompile
1089 + SolveSerialization<diffsol::NalgebraMat<T>>,
1090 diffsol::NalgebraMat<T>: diffsol::matrix::MatrixHost,
1091 <diffsol::NalgebraMat<T> as MatrixCommon>::T: crate::scalar_type::Scalar,
1092 <diffsol::NalgebraMat<T> as MatrixCommon>::V: diffsol::DefaultDenseMatrix,
1093 GenericSolve<diffsol::NalgebraMat<T>, CG>: Solve,
1094 {
1095 let mut solve = make_generic_solve::<T, CG>();
1096 assert_eq!(solve.matrix_type(), MatrixType::NalgebraDense);
1097 assert_eq!(solve.nstates(), 1);
1098 assert_eq!(solve.nparams(), 1);
1099 assert_eq!(solve.nout(), 1);
1100 assert!(!solve.has_stop());
1101 assert!(solve.check(LinearSolverType::Default).is_ok());
1102 assert!(solve.check(LinearSolverType::Lu).is_ok());
1103 assert!(solve.check(LinearSolverType::Klu).is_err());
1104
1105 solve.set_atol(1e-5);
1106 solve.set_rtol(1e-4);
1107 assert_close(solve.atol(), 1e-5, tol, "solve atol");
1108 assert_close(solve.rtol(), 1e-4, tol, "solve rtol");
1109 solve.set_t0(0.125);
1110 solve.set_h0(0.25);
1111 solve.set_integrate_out(true);
1112 assert_close(solve.t0(), 0.125, tol, "solve t0");
1113 assert_close(solve.h0(), 0.25, tol, "solve h0");
1114 assert!(solve.integrate_out());
1115
1116 solve.set_sens_rtol(Some(1e-3));
1117 solve.set_sens_atol(Some(1e-4));
1118 solve.set_out_rtol(Some(2e-3));
1119 solve.set_out_atol(Some(2e-4));
1120 solve.set_param_rtol(Some(3e-3));
1121 solve.set_param_atol(Some(3e-4));
1122 assert_close(solve.sens_rtol().unwrap(), 1e-3, tol, "solve sens rtol");
1123 assert_close(solve.sens_atol().unwrap(), 1e-4, tol, "solve sens atol");
1124 assert_close(solve.out_rtol().unwrap(), 2e-3, tol, "solve out rtol");
1125 assert_close(solve.out_atol().unwrap(), 2e-4, tol, "solve out atol");
1126 assert_close(solve.param_rtol().unwrap(), 3e-3, tol, "solve param rtol");
1127 assert_close(solve.param_atol().unwrap(), 3e-4, tol, "solve param atol");
1128
1129 solve.set_sens_rtol(None);
1130 solve.set_sens_atol(None);
1131 solve.set_out_rtol(None);
1132 solve.set_out_atol(None);
1133 solve.set_param_rtol(None);
1134 solve.set_param_atol(None);
1135 assert_eq!(solve.sens_rtol(), None);
1136 assert_eq!(solve.sens_atol(), None);
1137 assert_eq!(solve.out_rtol(), None);
1138 assert_eq!(solve.out_atol(), None);
1139 assert_eq!(solve.param_rtol(), None);
1140 assert_eq!(solve.param_atol(), None);
1141
1142 let y0 = Vec::<f64>::from_host_array(solve.y0(&[2.0]).unwrap()).unwrap();
1143 assert_eq!(y0.len(), 1);
1144 assert_close(y0[0], LOGISTIC_X0, tol, "y0[0]");
1145
1146 let rhs = Vec::<f64>::from_host_array(solve.rhs(&[2.0], 0.0, &[0.25]).unwrap()).unwrap();
1147 assert_eq!(rhs.len(), 1);
1148 assert_close(rhs[0], 0.375, tol, "solve rhs");
1149
1150 let jac_mul =
1151 Vec::<f64>::from_host_array(solve.rhs_jac_mul(&[2.0], 0.0, &[0.25], &[3.0]).unwrap())
1152 .unwrap();
1153 assert_eq!(jac_mul.len(), 1);
1154 assert_close(jac_mul[0], 3.0, tol, "solve rhs jac mul");
1155 }
1156
1157 fn assert_solve_runtime_paths<T, CG>(tol: f64)
1158 where
1159 T: crate::scalar_type::Scalar + diffsol::NalgebraScalar,
1160 CG: diffsol::CodegenModule
1161 + CodegenModuleJit
1162 + CodegenModuleCompile
1163 + SolveSerialization<diffsol::NalgebraMat<T>>,
1164 diffsol::NalgebraMat<T>: diffsol::matrix::MatrixHost,
1165 <diffsol::NalgebraMat<T> as MatrixCommon>::T: crate::scalar_type::Scalar,
1166 <diffsol::NalgebraMat<T> as MatrixCommon>::V: diffsol::DefaultDenseMatrix,
1167 GenericSolve<diffsol::NalgebraMat<T>, CG>: Solve,
1168 {
1169 let mut solve = make_generic_solve::<T, CG>();
1170 let soln = solve
1171 .solve(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], 1.0)
1172 .unwrap();
1173 let ts = Vec::<f64>::from_host_array(soln.get_ts()).unwrap();
1174 let ys = Vec::<Vec<f64>>::from_host_array(soln.get_ys()).unwrap();
1175 assert_close(*ts.last().unwrap(), 1.0, tol, "solve final time");
1176 assert_close(
1177 ys[0][ts.len() - 1],
1178 logistic_state(LOGISTIC_X0, 2.0, 1.0),
1179 tol,
1180 "solve final value",
1181 );
1182
1183 let mut solve = make_generic_solve::<T, CG>();
1184 let dense = solve
1185 .solve_dense(
1186 OdeSolverType::Tsit45,
1187 LinearSolverType::Lu,
1188 &[2.0],
1189 &[0.25, 0.5, 1.0],
1190 )
1191 .unwrap();
1192 let ts = Vec::<f64>::from_host_array(dense.get_ts()).unwrap();
1193 let ys = Vec::<Vec<f64>>::from_host_array(dense.get_ys()).unwrap();
1194 assert_eq!(ts, vec![0.25, 0.5, 1.0]);
1195 for (i, &t) in ts.iter().enumerate() {
1196 assert_close(
1197 ys[0][i],
1198 logistic_state(LOGISTIC_X0, 2.0, t),
1199 tol,
1200 &format!("solve_dense[{i}]"),
1201 );
1202 }
1203
1204 let mut solve = make_generic_solve::<T, CG>();
1205 let err = match solve.solve(OdeSolverType::Bdf, LinearSolverType::Default, &[], 1.0) {
1206 Ok(_) => panic!("expected parameter count mismatch"),
1207 Err(err) => err,
1208 };
1209 assert!(err.to_string().contains("Expecting 1 params but got 0"));
1210
1211 let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<T>>::new()
1212 .build_from_diffsl::<CG>(hybrid_logistic_diffsl_code())
1213 .unwrap();
1214 let mut hybrid_solve = GenericSolve {
1215 problem: hybrid_problem,
1216 };
1217 let hybrid = hybrid_solve
1218 .solve(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], 2.0)
1219 .unwrap();
1220 let hybrid_ts = Vec::<f64>::from_host_array(hybrid.get_ts()).unwrap();
1221 let hybrid_ys = Vec::<Vec<f64>>::from_host_array(hybrid.get_ys()).unwrap();
1222 assert_close(*hybrid_ts.last().unwrap(), 2.0, tol, "solve final time");
1223 assert_close(
1224 hybrid_ys[0][hybrid_ts.len() - 1],
1225 hybrid_logistic_state(2.0, 2.0),
1226 tol,
1227 "solve final value",
1228 );
1229
1230 let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<T>>::new()
1231 .build_from_diffsl::<CG>(hybrid_logistic_diffsl_code())
1232 .unwrap();
1233 let mut hybrid_solve = GenericSolve {
1234 problem: hybrid_problem,
1235 };
1236 let hybrid_dense = hybrid_solve
1237 .solve_dense(
1238 OdeSolverType::Tsit45,
1239 LinearSolverType::Lu,
1240 &[2.0],
1241 &[0.5, 1.0, 1.5, 2.0],
1242 )
1243 .unwrap();
1244 let hybrid_dense_ts = Vec::<f64>::from_host_array(hybrid_dense.get_ts()).unwrap();
1245 let hybrid_dense_ys = Vec::<Vec<f64>>::from_host_array(hybrid_dense.get_ys()).unwrap();
1246 assert_eq!(hybrid_dense_ts, vec![0.5, 1.0, 1.5, 2.0]);
1247 for (i, &t) in hybrid_dense_ts.iter().enumerate() {
1248 assert_close(
1249 hybrid_dense_ys[0][i],
1250 hybrid_logistic_state(2.0, t),
1251 tol,
1252 &format!("solve_dense[{i}]"),
1253 );
1254 }
1255 }
1256
1257 #[cfg(feature = "diffsl-llvm")]
1258 fn assert_solve_sensitivity_paths() {
1259 let t_eval = [0.25, 0.5, 1.0];
1260
1261 let mut solve = make_generic_solve::<f64, diffsol::LlvmModule>();
1262 let sens = solve
1263 .solve_fwd_sens(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], &t_eval)
1264 .unwrap();
1265 let sens_values = sens.get_sens();
1266 assert_eq!(sens_values.len(), 1);
1267 let sens_matrix =
1268 Vec::<Vec<f64>>::from_host_array(sens_values.into_iter().next().unwrap()).unwrap();
1269 for (i, &t) in t_eval.iter().enumerate() {
1270 assert_close(
1271 sens_matrix[0][i],
1272 logistic_state_dr(LOGISTIC_X0, 2.0, t),
1273 5e-4,
1274 &format!("solve_fwd_sens[{i}]"),
1275 );
1276 }
1277
1278 let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<f64>>::new()
1279 .build_from_diffsl::<diffsol::LlvmModule>(hybrid_logistic_diffsl_code())
1280 .unwrap();
1281 let mut solve = GenericSolve {
1282 problem: hybrid_problem,
1283 };
1284 let hybrid_sens = solve
1285 .solve_fwd_sens(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], &t_eval)
1286 .unwrap();
1287 let sens_values = hybrid_sens.get_sens();
1288 let sens_matrix =
1289 Vec::<Vec<f64>>::from_host_array(sens_values.into_iter().next().unwrap()).unwrap();
1290 for (i, &t) in t_eval.iter().enumerate() {
1291 assert_close(
1292 sens_matrix[0][i],
1293 hybrid_logistic_state_dr(2.0, t),
1294 5e-4,
1295 &format!("solve_fwd_sens[{i}]"),
1296 );
1297 }
1298 }
1299
1300 #[cfg(feature = "diffsl-cranelift")]
1301 #[test]
1302 fn solve_factory_supports_all_jit_matrix_and_scalar_types_for_cranelift() {
1303 assert_factory_supports_all_matrix_and_scalar_types::<diffsol::CraneliftJitModule>();
1304 }
1305
1306 #[cfg(feature = "diffsl-cranelift")]
1307 #[test]
1308 fn solve_trait_helpers_and_runtime_paths_for_cranelift() {
1309 assert_solve_metadata_and_helpers::<f64, diffsol::CraneliftJitModule>(1e-12);
1310 assert_solve_runtime_paths::<f64, diffsol::CraneliftJitModule>(5e-4);
1311 assert_solve_metadata_and_helpers::<f32, diffsol::CraneliftJitModule>(5e-3);
1312 assert_solve_runtime_paths::<f32, diffsol::CraneliftJitModule>(5e-3);
1313 }
1314
1315 #[cfg(feature = "diffsl-llvm")]
1316 #[test]
1317 fn solve_factory_supports_all_jit_matrix_and_scalar_types_for_llvm() {
1318 assert_factory_supports_all_matrix_and_scalar_types::<diffsol::LlvmModule>();
1319 }
1320
1321 #[cfg(feature = "diffsl-llvm")]
1322 #[test]
1323 fn solve_trait_helpers_and_runtime_paths_for_llvm() {
1324 assert_solve_metadata_and_helpers::<f64, diffsol::LlvmModule>(1e-12);
1325 assert_solve_runtime_paths::<f64, diffsol::LlvmModule>(5e-4);
1326 assert_solve_metadata_and_helpers::<f32, diffsol::LlvmModule>(5e-3);
1327 assert_solve_runtime_paths::<f32, diffsol::LlvmModule>(5e-3);
1328 }
1329
1330 #[cfg(feature = "diffsl-llvm")]
1331 #[test]
1332 fn solve_trait_sensitivity_paths_for_llvm() {
1333 assert_solve_sensitivity_paths();
1334 }
1335
1336 #[cfg(feature = "diffsl-cranelift")]
1337 #[test]
1338 fn setup_problem_validates_parameter_count_for_cranelift() {
1339 let mut solve = make_generic_solve::<f64, diffsol::CraneliftJitModule>();
1340 let err = solve.setup_problem(&[]).unwrap_err();
1341 assert!(err.to_string().contains("Expecting 1 params but got 0"));
1342
1343 solve.setup_problem(&[2.0]).unwrap();
1344 let mut params = solve.problem.context().vector_zeros(1);
1345 solve.problem.eqn.get_params(&mut params);
1346 assert_eq!(params.get_index(0), 2.0);
1347 }
1348
1349 #[cfg(feature = "diffsl-llvm")]
1350 #[test]
1351 fn setup_problem_validates_parameter_count_for_llvm() {
1352 let mut solve = make_generic_solve::<f64, diffsol::LlvmModule>();
1353 let err = solve.setup_problem(&[]).unwrap_err();
1354 assert!(err.to_string().contains("Expecting 1 params but got 0"));
1355
1356 solve.setup_problem(&[2.0]).unwrap();
1357 let mut params = solve.problem.context().vector_zeros(1);
1358 solve.problem.eqn.get_params(&mut params);
1359 assert_eq!(params.get_index(0), 2.0);
1360 }
1361}