1use diffsol::{
5 CodegenModule, ConstantOp, DefaultDenseMatrix, DefaultSolver, DiffSl, MatrixCommon,
6 NonLinearOp, NonLinearOpJacobian, OdeBuilder, OdeEquations, OdeSolverProblem, Op, Vector,
7 VectorCommon, VectorHost, VectorRef,
8 error::DiffsolError,
9 matrix::{MatrixHost, MatrixRef},
10};
11#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
12use diffsol::{CodegenModuleCompile, CodegenModuleJit};
13use num_traits::{FromPrimitive, ToPrimitive}; use paste::paste;
15
16use crate::error::DiffsolJsError;
17use crate::host_array::{HostArray, ToHostArray};
18#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
19use crate::jit::JitBackendType;
20#[cfg(feature = "external")]
21use crate::scalar_type::ExternalScalar;
22use crate::scalar_type::{Scalar, ScalarType};
23use crate::{
24 generate_ic_option_accessors, generate_ode_option_accessors, generate_option_accessors,
25 generate_trait_ic_option_accessors, generate_trait_ode_option_accessors,
26 option_value_from_store, option_value_to_store,
27};
28
29use crate::{
30 linear_solver_type::LinearSolverType,
31 matrix_type::{MatrixKind, MatrixType},
32 ode_solver_type::OdeSolverType,
33 valid_linear_solver::{KluValidator, LuValidator, validate_linear_solver},
34};
35
36use crate::solution::Solution;
/// Result of every `Solve` entry point that yields a full solution: a boxed,
/// type-erased [`Solution`] on success, or a [`DiffsolJsError`] on failure.
pub(crate) type SolveResult = Result<Box<dyn Solution>, DiffsolJsError>;
40
/// Object-safe interface over a compiled ODE model, letting callers drive a
/// solver through `Box<dyn Solve>` without knowing its matrix or scalar type.
///
/// Every `params` argument is the full model parameter vector as `f64`. The
/// implementation in this module converts it to its own scalar type and
/// rejects a slice whose length differs from the model's parameter count.
pub(crate) trait Solve {
    /// Matrix representation used by this solver.
    fn matrix_type(&self) -> MatrixType;

    /// Evaluates the ODE right-hand side at time `t` and state `y`, after
    /// installing `params`.
    fn rhs(&mut self, params: &[f64], t: f64, y: &[f64]) -> Result<HostArray, DiffsolJsError>;

    /// Evaluates the product of the right-hand-side Jacobian (at `t`, `y`)
    /// with the vector `v`, after installing `params`.
    fn rhs_jac_mul(
        &mut self,
        params: &[f64],
        t: f64,
        y: &[f64],
        v: &[f64],
    ) -> Result<HostArray, DiffsolJsError>;

    /// Computes the initial state at the problem's initial time for `params`.
    fn y0(&mut self, params: &[f64]) -> Result<HostArray, DiffsolJsError>;

    /// Returns an error if `linear_solver` cannot be used with this solver's
    /// matrix type.
    fn check(&self, linear_solver: LinearSolverType) -> Result<(), DiffsolJsError>;
    /// Sets the relative tolerance.
    fn set_rtol(&mut self, rtol: f64);
    /// Returns the relative tolerance.
    fn rtol(&self) -> f64;
    /// Sets the absolute tolerance (the implementation here fills every state
    /// entry with the same value).
    fn set_atol(&mut self, atol: f64);
    /// Returns the absolute tolerance (the implementation here reports the
    /// first entry of the tolerance vector).
    fn atol(&self) -> f64;

    /// Integrates from the initial time to `final_time` using `method` and
    /// `linear_solver`.
    fn solve(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        params: &[f64],
        final_time: f64,
    ) -> SolveResult;

    /// Like [`Solve::solve`], but dispatched through the hybrid solve path of
    /// [`OdeSolverType`] (exercised by the tests with a hybrid DiffSL model).
    fn solve_hybrid(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        params: &[f64],
        final_time: f64,
    ) -> SolveResult;

    /// Integrates and reports the solution at each time in `t_eval`.
    fn solve_dense(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        params: &[f64],
        t_eval: &[f64],
    ) -> SolveResult;

    /// Hybrid counterpart of [`Solve::solve_dense`].
    fn solve_hybrid_dense(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        params: &[f64],
        t_eval: &[f64],
    ) -> SolveResult;

    /// Like [`Solve::solve_dense`], additionally computing forward
    /// sensitivities of the solution with respect to the parameters.
    fn solve_fwd_sens(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        params: &[f64],
        t_eval: &[f64],
    ) -> SolveResult;

    /// Hybrid counterpart of [`Solve::solve_fwd_sens`].
    fn solve_hybrid_fwd_sens(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        params: &[f64],
        t_eval: &[f64],
    ) -> SolveResult;

    /// Runs a forward solve (`method`/`linear_solver`) and a backward adjoint
    /// pass (`backwards_method`/`backwards_linear_solver`) against `data`
    /// sampled at `t_eval`, returning a scalar objective and its gradient
    /// (presumably a sum-of-squares misfit w.r.t. the parameters, per the name).
    #[allow(clippy::type_complexity)]
    #[allow(clippy::too_many_arguments)]
    fn solve_sum_squares_adj(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        backwards_method: OdeSolverType,
        backwards_linear_solver: LinearSolverType,
        params: &[f64],
        data: HostArray,
        t_eval: &[f64],
    ) -> Result<(f64, HostArray), DiffsolJsError>;

    // Accessor declarations for initial-condition solver options, generated
    // by macro (presumably one getter/setter pair per listed option).
    generate_trait_ic_option_accessors! {
        use_linesearch: bool,
        max_linesearch_iterations: usize,
        max_newton_iterations: usize,
        max_linear_solver_setups: usize,
        step_reduction_factor: f64,
        armijo_constant: f64
    }
    // Accessor declarations for ODE-integrator options, generated by macro.
    generate_trait_ode_option_accessors! {
        max_nonlinear_solver_iterations: usize,
        max_error_test_failures: usize,
        min_timestep: f64,
        update_jacobian_after_steps: usize,
        update_rhs_jacobian_after_steps: usize,
        threshold_to_update_jacobian: f64,
        threshold_to_update_rhs_jacobian: f64
    }
}
/// Builds a [`Solve`] implementation backed by an externally compiled DiffSL
/// module, for the requested matrix representation and scalar type.
///
/// The `*_deps` arguments are forwarded unchanged to
/// [`GenericSolve::from_external`] (presumably `(row, column)` dependency
/// pairs describing sparsity). Sensitivities are not requested
/// (`include_sensitivities` is always `false`).
///
/// # Errors
/// Returns an "Unsupported scalar type" error when the scalar type's
/// `diffsl-external-*` feature is not enabled, and propagates any error from
/// [`GenericSolve::from_external`].
#[cfg(feature = "external")]
pub(crate) fn solve_factory_external(
    rhs_state_deps: Vec<(usize, usize)>,
    rhs_input_deps: Vec<(usize, usize)>,
    mass_state_deps: Vec<(usize, usize)>,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolJsError> {
    // Note on the `_` fallback arms below: they are only reachable when one of
    // the scalar features is disabled. With both enabled the explicit arms are
    // exhaustive, so the lint is silenced locally rather than warning.
    let solve: Box<dyn Solve> = match matrix_type {
        MatrixType::NalgebraDense => match scalar_type {
            #[cfg(feature = "diffsl-external-f32")]
            ScalarType::F32 => Box::new(GenericSolve::<
                diffsol::NalgebraMat<f32>,
                diffsl::ExternalModule<f32>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[cfg(feature = "diffsl-external-f64")]
            ScalarType::F64 => Box::new(GenericSolve::<
                diffsol::NalgebraMat<f64>,
                diffsl::ExternalModule<f64>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[allow(unreachable_patterns)]
            _ => {
                return Err(DiffsolJsError::from(DiffsolError::Other(
                    "Unsupported scalar type for NalgebraDense".to_string(),
                )));
            }
        },
        MatrixType::FaerDense => match scalar_type {
            #[cfg(feature = "diffsl-external-f32")]
            ScalarType::F32 => Box::new(GenericSolve::<
                diffsol::FaerMat<f32>,
                diffsl::ExternalModule<f32>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[cfg(feature = "diffsl-external-f64")]
            ScalarType::F64 => Box::new(GenericSolve::<
                diffsol::FaerMat<f64>,
                diffsl::ExternalModule<f64>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[allow(unreachable_patterns)]
            _ => {
                return Err(DiffsolJsError::from(DiffsolError::Other(
                    "Unsupported scalar type for FaerDense".to_string(),
                )));
            }
        },
        MatrixType::FaerSparse => match scalar_type {
            #[cfg(feature = "diffsl-external-f32")]
            ScalarType::F32 => Box::new(GenericSolve::<
                diffsol::FaerSparseMat<f32>,
                diffsl::ExternalModule<f32>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[cfg(feature = "diffsl-external-f64")]
            ScalarType::F64 => Box::new(GenericSolve::<
                diffsol::FaerSparseMat<f64>,
                diffsl::ExternalModule<f64>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[allow(unreachable_patterns)]
            _ => {
                return Err(DiffsolJsError::from(DiffsolError::Other(
                    "Unsupported scalar type for FaerSparse".to_string(),
                )));
            }
        },
    };
    Ok(solve)
}
218
/// Builds a [`Solve`] implementation by JIT-compiling the DiffSL source `code`
/// with the chosen backend, matrix representation and scalar type.
///
/// # Errors
/// Propagates any error raised while compiling `code` or building the problem.
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
pub(crate) fn solve_factory_jit(
    code: &str,
    jit_backend: JitBackendType,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolJsError> {
    // Choose the backend-specific instantiation first, then invoke it once.
    let build_for_backend: fn(&str, MatrixType, ScalarType) -> Result<Box<dyn Solve>, DiffsolJsError> =
        match jit_backend {
            #[cfg(feature = "diffsl-cranelift")]
            JitBackendType::Cranelift => solve_factory_with_jit_backend::<diffsol::CraneliftJitModule>,
            #[cfg(feature = "diffsl-llvm")]
            JitBackendType::Llvm => solve_factory_with_jit_backend::<diffsol::LlvmModule>,
        };
    build_for_backend(code, matrix_type, scalar_type)
}
239
/// Compiles DiffSL `code` with the codegen backend `CG` and wraps the
/// resulting problem in a type-erased [`GenericSolve`] for the requested
/// matrix/scalar combination.
///
/// # Errors
/// Propagates any error from `OdeBuilder::build_from_diffsl`.
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
fn solve_factory_with_jit_backend<CG>(
    code: &str,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolJsError>
where
    CG: CodegenModule + CodegenModuleJit + CodegenModuleCompile,
{
    // Builds the problem for one concrete matrix type; `?` returns early from
    // the enclosing function on compilation/build failure.
    macro_rules! compile_for {
        ($mat:ty) => {{
            let problem = OdeBuilder::<$mat>::new().build_from_diffsl::<CG>(code)?;
            Box::new(GenericSolve { problem })
        }};
    }

    let solve: Box<dyn Solve> = match (matrix_type, scalar_type) {
        (MatrixType::NalgebraDense, ScalarType::F32) => compile_for!(diffsol::NalgebraMat<f32>),
        (MatrixType::NalgebraDense, ScalarType::F64) => compile_for!(diffsol::NalgebraMat<f64>),
        (MatrixType::FaerDense, ScalarType::F32) => compile_for!(diffsol::FaerMat<f32>),
        (MatrixType::FaerDense, ScalarType::F64) => compile_for!(diffsol::FaerMat<f64>),
        (MatrixType::FaerSparse, ScalarType::F32) => compile_for!(diffsol::FaerSparseMat<f32>),
        (MatrixType::FaerSparse, ScalarType::F64) => compile_for!(diffsol::FaerSparseMat<f64>),
    };
    Ok(solve)
}
289
/// Concrete [`Solve`] implementation: a diffsol ODE problem whose equations
/// come from a DiffSL module of codegen type `CG`, using matrix type `M` for
/// linear algebra.
pub(crate) struct GenericSolve<M, CG>
where
    M: MatrixHost<T: Scalar>,
    M::V: Vector + VectorHost,
    CG: CodegenModule,
{
    // The problem every solve call operates on; its equations, tolerances
    // (`rtol`/`atol`) and initial time (`t0`) are read and written in place.
    problem: OdeSolverProblem<DiffSl<M, CG>>,
}
298
299impl<M, CG> GenericSolve<M, CG>
300where
301 M: MatrixHost<T: Scalar>,
302 M::V: Vector + VectorHost + DefaultDenseMatrix,
303 CG: CodegenModule,
304{
305 pub(crate) fn setup_problem(&mut self, params: &[f64]) -> Result<(), DiffsolJsError> {
306 let params: Vec<M::T> = params.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
307 let params = M::V::from_slice(¶ms, M::C::default());
308
309 let nparams = self.problem.eqn.nparams();
311 if params.len() == nparams {
312 self.problem.eqn.set_params(¶ms);
313 Ok(())
314 } else {
315 Err(DiffsolError::Other(format!(
316 "Expecting {} params but got {}",
317 nparams,
318 params.len()
319 ))
320 .into())
321 }
322 }
323}
324
#[cfg(feature = "external")]
impl<M> GenericSolve<M, diffsl::ExternalModule<M::T>>
where
    M: MatrixHost<T: ExternalScalar>,
    M::V: Vector + VectorHost + DefaultDenseMatrix,
{
    /// Wraps an externally compiled DiffSL module in a [`GenericSolve`].
    ///
    /// The dependency lists and `include_sensitivities` are forwarded to
    /// `DiffSl::from_external` as-is.
    ///
    /// # Errors
    /// Propagates errors from constructing the equations or building the
    /// problem.
    pub fn from_external(
        rhs_state_deps: Vec<(usize, usize)>,
        rhs_input_deps: Vec<(usize, usize)>,
        mass_state_deps: Vec<(usize, usize)>,
        include_sensitivities: bool,
    ) -> Result<Self, DiffsolJsError> {
        let context = M::C::default();
        let equations = DiffSl::<M, diffsl::ExternalModule<M::T>>::from_external(
            context,
            rhs_state_deps,
            rhs_input_deps,
            mass_state_deps,
            include_sensitivities,
        )?;

        // Seed the builder with an all-zero parameter vector sized to the
        // model; each `Solve` entry point overwrites it via `setup_problem`.
        let zero_params = vec![0.0; equations.nparams()];
        let problem = OdeBuilder::<M>::new()
            .p(zero_params)
            .build_from_eqn(equations)?;

        Ok(Self { problem })
    }
}
349
350impl<M, CG> Solve for GenericSolve<M, CG>
351where
352 M: MatrixHost<T: Scalar + ToPrimitive>
353 + DefaultSolver
354 + LuValidator<M>
355 + KluValidator<M>
356 + MatrixKind,
357 CG: CodegenModule,
358 for<'b> <<M::V as DefaultDenseMatrix>::M as MatrixCommon>::Inner: ToHostArray<M::T> + Clone,
359 for<'b> <M::V as VectorCommon>::Inner: ToHostArray<M::T> + Clone,
360 M::V: VectorHost + DefaultDenseMatrix + Send + Sync + 'static,
361 <M::V as DefaultDenseMatrix>::M: Send + Sync,
362 for<'b> &'b M::V: VectorRef<M::V>,
363 for<'b> &'b M: MatrixRef<M>,
364{
365 fn matrix_type(&self) -> MatrixType {
366 MatrixType::from_diffsol::<M>()
367 }
368
369 fn check(&self, linear_solver: LinearSolverType) -> Result<(), DiffsolJsError> {
370 validate_linear_solver::<M>(linear_solver)
371 }
372
373 fn set_atol(&mut self, atol: f64) {
374 self.problem.atol.fill(M::T::from_f64(atol).unwrap());
375 }
376
377 fn atol(&self) -> f64 {
378 self.problem.atol[0].to_f64().unwrap()
379 }
380
381 fn set_rtol(&mut self, rtol: f64) {
382 self.problem.rtol = M::T::from_f64(rtol).unwrap();
383 }
384
385 fn rtol(&self) -> f64 {
386 self.problem.rtol.to_f64().unwrap()
387 }
388
389 generate_ic_option_accessors! {
390 use_linesearch: bool,
391 max_linesearch_iterations: usize,
392 max_newton_iterations: usize,
393 max_linear_solver_setups: usize,
394 step_reduction_factor: f64,
395 armijo_constant: f64
396 }
397
398 generate_ode_option_accessors! {
399 max_nonlinear_solver_iterations: usize,
400 max_error_test_failures: usize,
401 min_timestep: f64,
402 update_jacobian_after_steps: usize,
403 update_rhs_jacobian_after_steps: usize,
404 threshold_to_update_jacobian: f64,
405 threshold_to_update_rhs_jacobian: f64
406 }
407
408 fn y0(&mut self, params: &[f64]) -> Result<HostArray, DiffsolJsError> {
409 self.setup_problem(params)?;
410 let n = self.problem.eqn.nstates();
411 let mut y0 = M::V::zeros(n, M::C::default());
412 let t0 = self.problem.t0;
413 self.problem.eqn.init().call_inplace(t0, &mut y0);
414 Ok((*y0.inner()).clone().to_host_array())
415 }
416
417 fn rhs(&mut self, params: &[f64], t: f64, y: &[f64]) -> Result<HostArray, DiffsolJsError> {
418 self.setup_problem(params)?;
419 let n = self.problem.eqn.nstates();
420 let y = y
421 .iter()
422 .map(|&x| M::T::from_f64(x).unwrap())
423 .collect::<Vec<_>>();
424 let y_vec = M::V::from_slice(&y, M::C::default());
425 let mut dydt = M::V::zeros(n, M::C::default());
426 self.problem
427 .eqn
428 .rhs()
429 .call_inplace(&y_vec, M::T::from_f64(t).unwrap(), &mut dydt);
430 Ok((*dydt.inner()).clone().to_host_array())
431 }
432
433 fn rhs_jac_mul(
434 &mut self,
435
436 params: &[f64],
437 t: f64,
438 y: &[f64],
439 v: &[f64],
440 ) -> Result<HostArray, DiffsolJsError> {
441 self.setup_problem(params)?;
442 let n = self.problem.eqn.nstates();
443 let y = y
444 .iter()
445 .map(|&x| M::T::from_f64(x).unwrap())
446 .collect::<Vec<_>>();
447 let v = v
448 .iter()
449 .map(|&x| M::T::from_f64(x).unwrap())
450 .collect::<Vec<_>>();
451 let y_vec = M::V::from_slice(&y, M::C::default());
452 let v_vec = M::V::from_slice(&v, M::C::default());
453 let mut dydt = M::V::zeros(n, M::C::default());
454 self.problem.eqn.rhs().jac_mul_inplace(
455 &y_vec,
456 M::T::from_f64(t).unwrap(),
457 &v_vec,
458 &mut dydt,
459 );
460 Ok((*dydt.inner()).clone().to_host_array())
461 }
462
463 fn solve(
464 &mut self,
465 method: OdeSolverType,
466 linear_solver: LinearSolverType,
467 params: &[f64],
468 final_time: f64,
469 ) -> SolveResult {
470 self.check(linear_solver)?;
471 self.setup_problem(params)?;
472 let final_time = M::T::from_f64(final_time).unwrap();
473 let soln = match linear_solver {
474 LinearSolverType::Default => {
475 method.solve::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, final_time)
476 }
477 LinearSolverType::Lu => {
478 method.solve::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, final_time)
479 }
480 LinearSolverType::Klu => {
481 method.solve::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, final_time)
482 }
483 };
484 Ok(Box::new(soln?))
485 }
486
487 fn solve_dense(
488 &mut self,
489 method: OdeSolverType,
490 linear_solver: LinearSolverType,
491 params: &[f64],
492 t_eval: &[f64],
493 ) -> SolveResult {
494 self.check(linear_solver)?;
495 self.setup_problem(params)?;
496
497 let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
498 let soln =
499 match linear_solver {
500 LinearSolverType::Default => method
501 .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval),
502 LinearSolverType::Lu => method
503 .solve_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
504 LinearSolverType::Klu => method
505 .solve_dense::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, &t_eval),
506 };
507 Ok(Box::new(soln?))
508 }
509
510 fn solve_hybrid(
511 &mut self,
512 method: OdeSolverType,
513 linear_solver: LinearSolverType,
514 params: &[f64],
515 final_time: f64,
516 ) -> SolveResult {
517 self.check(linear_solver)?;
518 self.setup_problem(params)?;
519 let final_time = M::T::from_f64(final_time).unwrap();
520 let soln = match linear_solver {
521 LinearSolverType::Default => method
522 .solve_hybrid::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, final_time),
523 LinearSolverType::Lu => method
524 .solve_hybrid::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, final_time),
525 LinearSolverType::Klu => method
526 .solve_hybrid::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, final_time),
527 };
528 Ok(Box::new(soln?))
529 }
530
531 fn solve_fwd_sens(
532 &mut self,
533 method: OdeSolverType,
534 linear_solver: LinearSolverType,
535 params: &[f64],
536 t_eval: &[f64],
537 ) -> SolveResult {
538 self.check(linear_solver)?;
539 self.setup_problem(params)?;
540
541 let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
542 let soln = match linear_solver {
543 LinearSolverType::Default => {
544 method.solve_fwd_sens::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval)
545 }
546 LinearSolverType::Lu => method
547 .solve_fwd_sens::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
548 LinearSolverType::Klu => method
549 .solve_fwd_sens::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, &t_eval),
550 };
551 Ok(Box::new(soln?))
552 }
553
554 fn solve_hybrid_dense(
555 &mut self,
556 method: OdeSolverType,
557 linear_solver: LinearSolverType,
558 params: &[f64],
559 t_eval: &[f64],
560 ) -> SolveResult {
561 self.check(linear_solver)?;
562 self.setup_problem(params)?;
563
564 let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
565 let soln = match linear_solver {
566 LinearSolverType::Default => method
567 .solve_hybrid_dense::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval),
568 LinearSolverType::Lu => method
569 .solve_hybrid_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
570 LinearSolverType::Klu => method
571 .solve_hybrid_dense::<M, CG, <M as KluValidator<M>>::LS>(
572 &mut self.problem,
573 &t_eval,
574 ),
575 };
576 Ok(Box::new(soln?))
577 }
578
579 fn solve_hybrid_fwd_sens(
580 &mut self,
581 method: OdeSolverType,
582 linear_solver: LinearSolverType,
583 params: &[f64],
584 t_eval: &[f64],
585 ) -> SolveResult {
586 self.check(linear_solver)?;
587 self.setup_problem(params)?;
588
589 let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
590 let soln = match linear_solver {
591 LinearSolverType::Default => method
592 .solve_hybrid_fwd_sens::<M, CG, <M as DefaultSolver>::LS>(
593 &mut self.problem,
594 &t_eval,
595 ),
596 LinearSolverType::Lu => method
597 .solve_hybrid_fwd_sens::<M, CG, <M as LuValidator<M>>::LS>(
598 &mut self.problem,
599 &t_eval,
600 ),
601 LinearSolverType::Klu => method
602 .solve_hybrid_fwd_sens::<M, CG, <M as KluValidator<M>>::LS>(
603 &mut self.problem,
604 &t_eval,
605 ),
606 };
607 Ok(Box::new(soln?))
608 }
609
610 fn solve_sum_squares_adj(
611 &mut self,
612
613 method: OdeSolverType,
614 linear_solver: LinearSolverType,
615 backwards_method: OdeSolverType,
616 backwards_linear_solver: LinearSolverType,
617 params: &[f64],
618 data: HostArray,
619 t_eval: &[f64],
620 ) -> Result<(f64, HostArray), DiffsolJsError> {
621 self.check(linear_solver)?;
622 self.setup_problem(params)?;
623
624 let data = data.as_array()?;
625 let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
626
627 let previous_integrate_out = self.problem.integrate_out;
628 self.problem.integrate_out = true;
629 let result = match linear_solver {
630 LinearSolverType::Default => method
631 .solve_sum_squares_adj::<M, CG, <M as DefaultSolver>::LS>(
632 &mut self.problem,
633 data,
634 &t_eval,
635 backwards_method,
636 backwards_linear_solver,
637 ),
638 LinearSolverType::Lu => method
639 .solve_sum_squares_adj::<M, CG, <M as LuValidator<M>>::LS>(
640 &mut self.problem,
641 data,
642 &t_eval,
643 backwards_method,
644 backwards_linear_solver,
645 ),
646 LinearSolverType::Klu => method
647 .solve_sum_squares_adj::<M, CG, <M as KluValidator<M>>::LS>(
648 &mut self.problem,
649 data,
650 &t_eval,
651 backwards_method,
652 backwards_linear_solver,
653 ),
654 };
655 self.problem.integrate_out = previous_integrate_out;
656 let (y, y_sens) = result?;
657
658 Ok((
659 y.to_f64().unwrap(),
660 (*y_sens.inner()).clone().to_host_array(),
661 ))
662 }
663}
664
#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
mod tests {
    use diffsol::{
        CodegenModuleCompile, CodegenModuleJit, Context, OdeBuilder, OdeEquations, Vector,
    };

    use crate::{
        host_array::FromHostArray,
        linear_solver_type::LinearSolverType,
        matrix_type::MatrixType,
        ode_solver_type::OdeSolverType,
        scalar_type::ScalarType,
        test_support::{
            LOGISTIC_X0, assert_close, hybrid_logistic_diffsl_code, hybrid_logistic_state,
            hybrid_logistic_state_dr, logistic_diffsl_code, logistic_integral, logistic_state,
            logistic_state_dr, matrix_host,
        },
    };

    use super::{GenericSolve, Solve, solve_factory_with_jit_backend};

    /// Builds a dense f64 solver for the shared logistic test model.
    fn make_generic_solve<CG>() -> GenericSolve<diffsol::NalgebraMat<f64>, CG>
    where
        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
    {
        let problem = OdeBuilder::<diffsol::NalgebraMat<f64>>::new()
            .build_from_diffsl::<CG>(logistic_diffsl_code())
            .unwrap();
        GenericSolve { problem }
    }

    /// Builds a dense f64 solver for the shared hybrid logistic test model.
    fn make_hybrid_generic_solve<CG>() -> GenericSolve<diffsol::NalgebraMat<f64>, CG>
    where
        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
    {
        let problem = OdeBuilder::<diffsol::NalgebraMat<f64>>::new()
            .build_from_diffsl::<CG>(hybrid_logistic_diffsl_code())
            .unwrap();
        GenericSolve { problem }
    }

    fn assert_factory_supports_all_matrix_and_scalar_types<CG>()
    where
        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
    {
        for matrix_type in [
            MatrixType::NalgebraDense,
            MatrixType::FaerDense,
            MatrixType::FaerSparse,
        ] {
            for scalar_type in [ScalarType::F32, ScalarType::F64] {
                assert!(
                    solve_factory_with_jit_backend::<CG>(
                        logistic_diffsl_code(),
                        matrix_type,
                        scalar_type,
                    )
                    .is_ok()
                );
            }
        }
    }

    fn assert_solve_metadata_and_helpers<CG>()
    where
        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
    {
        let mut solve = make_generic_solve::<CG>();
        assert_eq!(solve.matrix_type(), MatrixType::NalgebraDense);
        assert!(solve.check(LinearSolverType::Default).is_ok());
        assert!(solve.check(LinearSolverType::Lu).is_ok());
        assert!(solve.check(LinearSolverType::Klu).is_err());

        solve.set_atol(1e-5);
        solve.set_rtol(1e-4);
        assert_close(solve.atol(), 1e-5, 1e-12, "solve atol");
        assert_close(solve.rtol(), 1e-4, 1e-12, "solve rtol");

        let y0 = Vec::<f64>::from_host_array(solve.y0(&[2.0]).unwrap()).unwrap();
        assert_eq!(y0, vec![LOGISTIC_X0]);

        let rhs = Vec::<f64>::from_host_array(solve.rhs(&[2.0], 0.0, &[0.25]).unwrap()).unwrap();
        assert_eq!(rhs.len(), 1);
        assert_close(rhs[0], 0.375, 1e-12, "solve rhs");

        let jac_mul =
            Vec::<f64>::from_host_array(solve.rhs_jac_mul(&[2.0], 0.0, &[0.25], &[3.0]).unwrap())
                .unwrap();
        assert_eq!(jac_mul.len(), 1);
        assert_close(jac_mul[0], 3.0, 1e-12, "solve rhs jac mul");
    }

    fn assert_solve_runtime_paths<CG>()
    where
        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
    {
        let mut solve = make_generic_solve::<CG>();
        let soln = solve
            .solve(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], 1.0)
            .unwrap();
        let ts = Vec::<f64>::from_host_array(soln.get_ts()).unwrap();
        let ys = Vec::<Vec<f64>>::from_host_array(soln.get_ys()).unwrap();
        assert_close(*ts.last().unwrap(), 1.0, 5e-4, "solve final time");
        assert_close(
            ys[0][ts.len() - 1],
            logistic_state(LOGISTIC_X0, 2.0, 1.0),
            5e-4,
            "solve final value",
        );

        let mut solve = make_generic_solve::<CG>();
        let dense = solve
            .solve_dense(
                OdeSolverType::Tsit45,
                LinearSolverType::Lu,
                &[2.0],
                &[0.25, 0.5, 1.0],
            )
            .unwrap();
        let ts = Vec::<f64>::from_host_array(dense.get_ts()).unwrap();
        let ys = Vec::<Vec<f64>>::from_host_array(dense.get_ys()).unwrap();
        assert_eq!(ts, vec![0.25, 0.5, 1.0]);
        for (i, &t) in ts.iter().enumerate() {
            assert_close(
                ys[0][i],
                logistic_state(LOGISTIC_X0, 2.0, t),
                5e-4,
                &format!("solve_dense[{i}]"),
            );
        }

        let mut solve = make_generic_solve::<CG>();
        let err = match solve.solve(OdeSolverType::Bdf, LinearSolverType::Default, &[], 1.0) {
            Ok(_) => panic!("expected parameter count mismatch"),
            Err(err) => err,
        };
        assert!(err.to_string().contains("Expecting 1 params but got 0"));

        let mut hybrid_solve = make_hybrid_generic_solve::<CG>();
        let hybrid = hybrid_solve
            .solve_hybrid(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], 2.0)
            .unwrap();
        let hybrid_ts = Vec::<f64>::from_host_array(hybrid.get_ts()).unwrap();
        let hybrid_ys = Vec::<Vec<f64>>::from_host_array(hybrid.get_ys()).unwrap();
        assert_close(
            *hybrid_ts.last().unwrap(),
            2.0,
            5e-4,
            "solve_hybrid final time",
        );
        assert_close(
            hybrid_ys[0][hybrid_ts.len() - 1],
            hybrid_logistic_state(2.0, 2.0),
            5e-4,
            "solve_hybrid final value",
        );

        let mut hybrid_solve = make_hybrid_generic_solve::<CG>();
        let hybrid_dense = hybrid_solve
            .solve_hybrid_dense(
                OdeSolverType::Tsit45,
                LinearSolverType::Lu,
                &[2.0],
                &[0.5, 1.0, 1.5, 2.0],
            )
            .unwrap();
        let hybrid_dense_ts = Vec::<f64>::from_host_array(hybrid_dense.get_ts()).unwrap();
        let hybrid_dense_ys = Vec::<Vec<f64>>::from_host_array(hybrid_dense.get_ys()).unwrap();
        assert_eq!(hybrid_dense_ts, vec![0.5, 1.0, 1.5, 2.0]);
        for (i, &t) in hybrid_dense_ts.iter().enumerate() {
            assert_close(
                hybrid_dense_ys[0][i],
                hybrid_logistic_state(2.0, t),
                5e-4,
                &format!("solve_hybrid_dense[{i}]"),
            );
        }
    }

    /// Checks that `setup_problem` rejects a wrong parameter count and stores
    /// a correctly sized parameter vector.
    fn assert_setup_problem_validates_parameter_count<CG>()
    where
        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
    {
        let mut solve = make_generic_solve::<CG>();
        let err = solve.setup_problem(&[]).unwrap_err();
        assert!(err.to_string().contains("Expecting 1 params but got 0"));

        solve.setup_problem(&[2.0]).unwrap();
        let mut params = solve.problem.context().vector_zeros(1);
        solve.problem.eqn.get_params(&mut params);
        assert_eq!(params.get_index(0), 2.0);
    }

    #[cfg(feature = "diffsl-llvm")]
    fn assert_solve_sensitivity_paths() {
        let t_eval = [0.25, 0.5, 1.0];

        let mut solve = make_generic_solve::<diffsol::LlvmModule>();
        let sens = solve
            .solve_fwd_sens(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], &t_eval)
            .unwrap();
        let sens_values = sens.get_sens();
        assert_eq!(sens_values.len(), 1);
        let sens_matrix =
            Vec::<Vec<f64>>::from_host_array(sens_values.into_iter().next().unwrap()).unwrap();
        for (i, &t) in t_eval.iter().enumerate() {
            assert_close(
                sens_matrix[0][i],
                logistic_state_dr(LOGISTIC_X0, 2.0, t),
                5e-4,
                &format!("solve_fwd_sens[{i}]"),
            );
        }

        let mut solve = make_hybrid_generic_solve::<diffsol::LlvmModule>();
        let hybrid_sens = solve
            .solve_hybrid_fwd_sens(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], &t_eval)
            .unwrap();
        let sens_values = hybrid_sens.get_sens();
        // One parameter, so exactly one sensitivity matrix is expected here too.
        assert_eq!(sens_values.len(), 1);
        let sens_matrix =
            Vec::<Vec<f64>>::from_host_array(sens_values.into_iter().next().unwrap()).unwrap();
        for (i, &t) in t_eval.iter().enumerate() {
            assert_close(
                sens_matrix[0][i],
                hybrid_logistic_state_dr(2.0, t),
                5e-4,
                &format!("solve_hybrid_fwd_sens[{i}]"),
            );
        }

        let adjoint_t_eval = [0.0, 0.25, 0.5, 1.0];
        let adjoint_data: Vec<f64> = adjoint_t_eval
            .iter()
            .map(|&t| logistic_integral(LOGISTIC_X0, 2.0, t))
            .collect();
        let mut solve = make_generic_solve::<diffsol::LlvmModule>();
        let (objective, gradient) = solve
            .solve_sum_squares_adj(
                OdeSolverType::Bdf,
                LinearSolverType::Lu,
                OdeSolverType::TrBdf2,
                LinearSolverType::Lu,
                &[2.0],
                matrix_host(1, adjoint_t_eval.len(), &adjoint_data),
                &adjoint_t_eval,
            )
            .unwrap();
        assert!(objective.is_finite());
        let gradient = Vec::<f64>::from_host_array(gradient).unwrap();
        assert_eq!(gradient.len(), 1);
        assert!(gradient[0].is_finite());
    }

    #[cfg(feature = "diffsl-cranelift")]
    #[test]
    fn solve_factory_supports_all_jit_matrix_and_scalar_types_for_cranelift() {
        assert_factory_supports_all_matrix_and_scalar_types::<diffsol::CraneliftJitModule>();
    }

    #[cfg(feature = "diffsl-cranelift")]
    #[test]
    fn solve_trait_helpers_and_runtime_paths_for_cranelift() {
        assert_solve_metadata_and_helpers::<diffsol::CraneliftJitModule>();
        assert_solve_runtime_paths::<diffsol::CraneliftJitModule>();
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn solve_factory_supports_all_jit_matrix_and_scalar_types_for_llvm() {
        assert_factory_supports_all_matrix_and_scalar_types::<diffsol::LlvmModule>();
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn solve_trait_helpers_and_runtime_paths_for_llvm() {
        assert_solve_metadata_and_helpers::<diffsol::LlvmModule>();
        assert_solve_runtime_paths::<diffsol::LlvmModule>();
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn solve_trait_sensitivity_paths_for_llvm() {
        assert_solve_sensitivity_paths();
    }

    #[cfg(feature = "diffsl-cranelift")]
    #[test]
    fn setup_problem_validates_parameter_count_for_cranelift() {
        assert_setup_problem_validates_parameter_count::<diffsol::CraneliftJitModule>();
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn setup_problem_validates_parameter_count_for_llvm() {
        assert_setup_problem_validates_parameter_count::<diffsol::LlvmModule>();
    }
}