1#[cfg(feature = "diffsl-external-dynamic")]
5use std::path::PathBuf;
6
7use crate::error::DiffsolRtError;
8use crate::host_array::HostArray;
9use crate::host_array::ToHostArray;
10#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
11use crate::jit::JitBackendType;
12#[cfg(feature = "external")]
13use crate::scalar_type::ExternalScalar;
14use crate::scalar_type::Scalar;
15use crate::scalar_type::ScalarType;
16use crate::{
17 generate_ic_option_accessors, generate_ode_option_accessors, generate_option_accessors,
18 option_value_from_store, option_value_to_store,
19};
20use crate::{generate_trait_ic_option_accessors, generate_trait_ode_option_accessors};
21use diffsol::ObjectModule;
22use diffsol::OdeBuilder;
23use diffsol::{
24 error::DiffsolError,
25 matrix::{MatrixHost, MatrixRef},
26 CodegenModule, ConstantOp, DefaultDenseMatrix, DefaultSolver, DiffSl, MatrixCommon,
27 NonLinearOp, NonLinearOpJacobian, OdeEquations, OdeSolverProblem, Op, Vector, VectorCommon,
28 VectorHost, VectorRef,
29};
30#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
31use diffsol::{CodegenModuleCompile, CodegenModuleJit};
32use num_traits::{FromPrimitive, ToPrimitive}; use paste::paste;
34
35use crate::solve_serialization::{unsupported_serialization_error, SolveSerialization};
36use crate::{
37 linear_solver_type::LinearSolverType, matrix_type::MatrixType, ode_solver_type::OdeSolverType,
38};
39use crate::{
40 matrix_type::MatrixKind,
41 valid_linear_solver::{validate_linear_solver, KluValidator, LuValidator},
42};
43
44use crate::solution::Solution;
/// Return type of the trajectory-producing `solve*` methods on [`Solve`]: a boxed,
/// type-erased [`Solution`] on success, or a [`DiffsolRtError`].
pub(crate) type SolveResult = Result<Box<dyn Solution>, DiffsolRtError>;
48
/// Object-safe interface over a configured ODE problem, hiding the concrete matrix type,
/// scalar type and DiffSL codegen backend.
///
/// The `solve_factory_*` functions in this module produce values of this trait as
/// `Box<dyn Solve>`. Methods that take `params` install those parameter values before
/// evaluating or integrating the model.
pub(crate) trait Solve {
    /// Matrix backend the problem was built with.
    fn matrix_type(&self) -> MatrixType;
    /// Number of state variables in the model.
    fn nstates(&self) -> usize;
    /// Number of parameters expected in every `params` argument.
    fn nparams(&self) -> usize;
    /// Number of model outputs.
    fn nout(&self) -> usize;
    /// Whether the model defines a root (stop) function.
    fn has_stop(&self) -> bool;

    /// Evaluates the right-hand side at time `t` and state `y` using `params`.
    fn rhs(&mut self, params: &[f64], t: f64, y: &[f64]) -> Result<HostArray, DiffsolRtError>;

    /// Multiplies the right-hand-side Jacobian, evaluated at `(t, y)` with `params`, by `v`.
    fn rhs_jac_mul(
        &mut self,
        params: &[f64],
        t: f64,
        y: &[f64],
        v: &[f64],
    ) -> Result<HostArray, DiffsolRtError>;

    /// Evaluates the model's initial state for `params`.
    fn y0(&mut self, params: &[f64]) -> Result<HostArray, DiffsolRtError>;

    /// Returns an error if `linear_solver` cannot be used with this problem's matrix type.
    fn check(&self, linear_solver: LinearSolverType) -> Result<(), DiffsolRtError>;
    /// Sets the relative tolerance.
    fn set_rtol(&mut self, rtol: f64);
    /// Current relative tolerance.
    fn rtol(&self) -> f64;
    /// Sets the absolute tolerance.
    fn set_atol(&mut self, atol: f64);
    /// Current absolute tolerance.
    fn atol(&self) -> f64;
    /// Serialized form of the DiffSL equations.
    ///
    /// The default body delegates to `unsupported_serialization_error` (presumably always an
    /// `Err`); implementors whose backend can serialize override it.
    fn serialized_diffsl(&self) -> Result<Vec<u8>, DiffsolRtError> {
        unsupported_serialization_error(
            "ODE serialization is only supported for JIT-backed solvers",
        )
    }

    /// Integrates the problem up to `final_time` with the given solver and linear solver.
    fn solve(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        params: &[f64],
        final_time: f64,
    ) -> SolveResult;

    /// Like [`Solve::solve`], but through the solver's hybrid-solve path
    /// (presumably supporting model resets/discontinuities — confirm in `OdeSolverType`).
    fn solve_hybrid(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        params: &[f64],
        final_time: f64,
    ) -> SolveResult;

    /// Integrates the problem and reports the solution at the times in `t_eval`.
    fn solve_dense(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        params: &[f64],
        t_eval: &[f64],
    ) -> SolveResult;

    /// Hybrid-path counterpart of [`Solve::solve_dense`].
    fn solve_hybrid_dense(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        params: &[f64],
        t_eval: &[f64],
    ) -> SolveResult;

    /// Integrates the problem with forward sensitivities, reported at the times in `t_eval`.
    fn solve_fwd_sens(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        params: &[f64],
        t_eval: &[f64],
    ) -> SolveResult;

    /// Hybrid-path counterpart of [`Solve::solve_fwd_sens`].
    fn solve_hybrid_fwd_sens(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        params: &[f64],
        t_eval: &[f64],
    ) -> SolveResult;

    /// Computes a sum-of-squares objective against `data` at `t_eval` together with its
    /// adjoint-computed sensitivity, returned as `(value, gradient)`.
    /// NOTE(review): gradient presumably w.r.t. `params` — confirm in `OdeSolverType`.
    // The signature legitimately needs every input below (forward and backward solver
    // choices plus data), so clippy's complexity/argument-count lints are silenced here.
    #[allow(clippy::type_complexity)]
    #[allow(clippy::too_many_arguments)]
    fn solve_sum_squares_adj(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        backwards_method: OdeSolverType,
        backwards_linear_solver: LinearSolverType,
        params: &[f64],
        data: HostArray,
        t_eval: &[f64],
    ) -> Result<(f64, HostArray), DiffsolRtError>;

    // Macro-generated accessor declarations for the initial-condition (Newton / line-search)
    // solver options; see the macro for the exact getter/setter signatures.
    generate_trait_ic_option_accessors! {
        use_linesearch: bool,
        max_linesearch_iterations: usize,
        max_newton_iterations: usize,
        max_linear_solver_setups: usize,
        step_reduction_factor: f64,
        armijo_constant: f64
    }
    // Macro-generated accessor declarations for the ODE time-stepping solver options.
    generate_trait_ode_option_accessors! {
        max_nonlinear_solver_iterations: usize,
        max_error_test_failures: usize,
        min_timestep: f64,
        update_jacobian_after_steps: usize,
        update_rhs_jacobian_after_steps: usize,
        threshold_to_update_jacobian: f64,
        threshold_to_update_rhs_jacobian: f64
    }
}
/// Builds a [`Solve`] backed by an externally provided DiffSL module (`diffsl::ExternalModule`).
///
/// `rhs_state_deps`, `rhs_input_deps` and `mass_state_deps` are forwarded unchanged to
/// `DiffSl::from_external` (presumably `(row, col)` dependency/sparsity pairs — confirm against
/// diffsl). Sensitivities are not requested (`include_sensitivities = false`).
///
/// # Errors
/// Returns an error if `scalar_type` was not compiled in (features `diffsl-external-f32` /
/// `diffsl-external-f64`), or if constructing the problem fails.
#[cfg(feature = "external")]
pub(crate) fn solve_factory_external(
    rhs_state_deps: Vec<(usize, usize)>,
    rhs_input_deps: Vec<(usize, usize)>,
    mass_state_deps: Vec<(usize, usize)>,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError> {
    let solve: Box<dyn Solve> = match matrix_type {
        MatrixType::NalgebraDense => match scalar_type {
            #[cfg(feature = "diffsl-external-f32")]
            ScalarType::F32 => Box::new(GenericSolve::<
                diffsol::NalgebraMat<f32>,
                diffsl::ExternalModule<f32>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[cfg(feature = "diffsl-external-f64")]
            ScalarType::F64 => Box::new(GenericSolve::<
                diffsol::NalgebraMat<f64>,
                diffsl::ExternalModule<f64>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            // Only reachable when a scalar feature is disabled; with both enabled the pattern
            // is unreachable, so silence the lint instead of warning in that configuration.
            #[allow(unreachable_patterns)]
            _ => {
                return Err(DiffsolRtError::from(DiffsolError::Other(
                    "Unsupported scalar type for NalgebraDense".to_string(),
                )));
            }
        },
        MatrixType::FaerDense => match scalar_type {
            #[cfg(feature = "diffsl-external-f32")]
            ScalarType::F32 => Box::new(GenericSolve::<
                diffsol::FaerMat<f32>,
                diffsl::ExternalModule<f32>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[cfg(feature = "diffsl-external-f64")]
            ScalarType::F64 => Box::new(GenericSolve::<
                diffsol::FaerMat<f64>,
                diffsl::ExternalModule<f64>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            // See the NalgebraDense arm: unreachable when both scalar features are enabled.
            #[allow(unreachable_patterns)]
            _ => {
                return Err(DiffsolRtError::from(DiffsolError::Other(
                    "Unsupported scalar type for FaerDense".to_string(),
                )));
            }
        },
        MatrixType::FaerSparse => match scalar_type {
            #[cfg(feature = "diffsl-external-f32")]
            ScalarType::F32 => Box::new(GenericSolve::<
                diffsol::FaerSparseMat<f32>,
                diffsl::ExternalModule<f32>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[cfg(feature = "diffsl-external-f64")]
            ScalarType::F64 => Box::new(GenericSolve::<
                diffsol::FaerSparseMat<f64>,
                diffsl::ExternalModule<f64>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            // See the NalgebraDense arm: unreachable when both scalar features are enabled.
            #[allow(unreachable_patterns)]
            _ => {
                return Err(DiffsolRtError::from(DiffsolError::Other(
                    "Unsupported scalar type for FaerSparse".to_string(),
                )));
            }
        },
    };
    Ok(solve)
}
235
/// Builds a [`Solve`] backed by a `diffsl::ExternalDynModule` created from `path`.
///
/// The three dependency lists are forwarded unchanged to
/// `DiffSl::from_external_dynamic`; sensitivities are not requested.
///
/// # Errors
/// Propagates any failure from constructing the external module or the ODE problem.
#[cfg(feature = "diffsl-external-dynamic")]
pub(crate) fn solve_factory_external_dynamic(
    path: PathBuf,
    rhs_state_deps: Vec<(usize, usize)>,
    rhs_input_deps: Vec<(usize, usize)>,
    mass_state_deps: Vec<(usize, usize)>,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError> {
    // Expands to the dynamic-module constructor for one (matrix, scalar) pairing. Only one
    // match arm runs, so moving the owned arguments inside each expansion is fine.
    macro_rules! load_dynamic {
        ($mat:ty, $scalar:ty) => {
            Box::new(
                GenericSolve::<$mat, diffsl::ExternalDynModule<$scalar>>::from_external_dynamic(
                    path,
                    rhs_state_deps,
                    rhs_input_deps,
                    mass_state_deps,
                    false,
                )?,
            )
        };
    }

    let solve: Box<dyn Solve> = match (matrix_type, scalar_type) {
        (MatrixType::NalgebraDense, ScalarType::F32) => {
            load_dynamic!(diffsol::NalgebraMat<f32>, f32)
        }
        (MatrixType::NalgebraDense, ScalarType::F64) => {
            load_dynamic!(diffsol::NalgebraMat<f64>, f64)
        }
        (MatrixType::FaerDense, ScalarType::F32) => load_dynamic!(diffsol::FaerMat<f32>, f32),
        (MatrixType::FaerDense, ScalarType::F64) => load_dynamic!(diffsol::FaerMat<f64>, f64),
        (MatrixType::FaerSparse, ScalarType::F32) => {
            load_dynamic!(diffsol::FaerSparseMat<f32>, f32)
        }
        (MatrixType::FaerSparse, ScalarType::F64) => {
            load_dynamic!(diffsol::FaerSparseMat<f64>, f64)
        }
    };
    Ok(solve)
}
315
/// Compiles DiffSL `code` with the selected JIT backend and returns it as a [`Solve`].
///
/// # Errors
/// Propagates any failure from [`solve_factory_with_jit_backend`].
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
pub(crate) fn solve_factory_jit(
    code: &str,
    jit_backend: JitBackendType,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError> {
    // Pick the monomorphised factory for the requested backend, then invoke it once.
    let factory: fn(&str, MatrixType, ScalarType) -> Result<Box<dyn Solve>, DiffsolRtError> =
        match jit_backend {
            #[cfg(feature = "diffsl-cranelift")]
            JitBackendType::Cranelift => {
                solve_factory_with_jit_backend::<diffsol::CraneliftJitModule>
            }
            #[cfg(feature = "diffsl-llvm")]
            JitBackendType::Llvm => solve_factory_with_jit_backend::<diffsol::LlvmModule>,
        };
    factory(code, matrix_type, scalar_type)
}
336
/// Compiles `code` with JIT codegen backend `CG` into a [`Solve`] for the requested matrix
/// and scalar types.
///
/// # Errors
/// Propagates any error returned by `OdeBuilder::build_from_diffsl`.
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
fn solve_factory_with_jit_backend<CG>(
    code: &str,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError>
where
    CG: CodegenModule
        + CodegenModuleJit
        + CodegenModuleCompile
        + SolveSerialization<diffsol::NalgebraMat<f32>>
        + SolveSerialization<diffsol::NalgebraMat<f64>>
        + SolveSerialization<diffsol::FaerMat<f32>>
        + SolveSerialization<diffsol::FaerMat<f64>>
        + SolveSerialization<diffsol::FaerSparseMat<f32>>
        + SolveSerialization<diffsol::FaerSparseMat<f64>>,
{
    // One arm per (matrix, scalar) pairing; the builder's matrix type fixes GenericSolve's `M`.
    let solve: Box<dyn Solve> = match (matrix_type, scalar_type) {
        (MatrixType::NalgebraDense, ScalarType::F32) => Box::new(GenericSolve {
            problem: OdeBuilder::<diffsol::NalgebraMat<f32>>::new()
                .build_from_diffsl::<CG>(code)?,
        }),
        (MatrixType::NalgebraDense, ScalarType::F64) => Box::new(GenericSolve {
            problem: OdeBuilder::<diffsol::NalgebraMat<f64>>::new()
                .build_from_diffsl::<CG>(code)?,
        }),
        (MatrixType::FaerDense, ScalarType::F32) => Box::new(GenericSolve {
            problem: OdeBuilder::<diffsol::FaerMat<f32>>::new().build_from_diffsl::<CG>(code)?,
        }),
        (MatrixType::FaerDense, ScalarType::F64) => Box::new(GenericSolve {
            problem: OdeBuilder::<diffsol::FaerMat<f64>>::new().build_from_diffsl::<CG>(code)?,
        }),
        (MatrixType::FaerSparse, ScalarType::F32) => Box::new(GenericSolve {
            problem: OdeBuilder::<diffsol::FaerSparseMat<f32>>::new()
                .build_from_diffsl::<CG>(code)?,
        }),
        (MatrixType::FaerSparse, ScalarType::F64) => Box::new(GenericSolve {
            problem: OdeBuilder::<diffsol::FaerSparseMat<f64>>::new()
                .build_from_diffsl::<CG>(code)?,
        }),
    };
    Ok(solve)
}
394
395pub(crate) fn solve_factory_from_serialized_diffsl(
396 serialized_diffsl: &[u8],
397 matrix_type: MatrixType,
398 scalar_type: ScalarType,
399) -> Result<Box<dyn Solve>, DiffsolRtError> {
400 let solve: Box<dyn Solve> = match matrix_type {
401 MatrixType::NalgebraDense => match scalar_type {
402 ScalarType::F32 => Box::new(
403 GenericSolve::<diffsol::NalgebraMat<f32>, ObjectModule>::from_serialized_diffsl(
404 serialized_diffsl,
405 )?,
406 ),
407 ScalarType::F64 => Box::new(
408 GenericSolve::<diffsol::NalgebraMat<f64>, ObjectModule>::from_serialized_diffsl(
409 serialized_diffsl,
410 )?,
411 ),
412 },
413 MatrixType::FaerDense => match scalar_type {
414 ScalarType::F32 => Box::new(
415 GenericSolve::<diffsol::FaerMat<f32>, ObjectModule>::from_serialized_diffsl(
416 serialized_diffsl,
417 )?,
418 ),
419 ScalarType::F64 => Box::new(
420 GenericSolve::<diffsol::FaerMat<f64>, ObjectModule>::from_serialized_diffsl(
421 serialized_diffsl,
422 )?,
423 ),
424 },
425 MatrixType::FaerSparse => match scalar_type {
426 ScalarType::F32 => Box::new(
427 GenericSolve::<diffsol::FaerSparseMat<f32>, ObjectModule>::from_serialized_diffsl(
428 serialized_diffsl,
429 )?,
430 ),
431 ScalarType::F64 => Box::new(
432 GenericSolve::<diffsol::FaerSparseMat<f64>, ObjectModule>::from_serialized_diffsl(
433 serialized_diffsl,
434 )?,
435 ),
436 },
437 };
438 Ok(solve)
439}
440
/// Concrete [`Solve`] implementation: a diffsol [`OdeSolverProblem`] whose equations are a
/// [`DiffSl`] model with matrix type `M` and codegen backend `CG`.
pub(crate) struct GenericSolve<M, CG>
where
    M: MatrixHost<T: Scalar>,
    M::V: Vector + VectorHost,
    CG: CodegenModule,
{
    // The wrapped problem; its `eqn`, `atol`, `rtol` and `t0` are read/written by the
    // `Solve` impl, and it is passed `&mut` to the `OdeSolverType` solve routines.
    problem: OdeSolverProblem<DiffSl<M, CG>>,
}
449
450impl<M, CG> GenericSolve<M, CG>
451where
452 M: MatrixHost<T: Scalar>,
453 M::V: Vector + VectorHost + DefaultDenseMatrix,
454 CG: CodegenModule,
455{
456 fn from_eqn(eqn: DiffSl<M, CG>) -> Result<Self, DiffsolRtError> {
457 let default_p = vec![0.0; eqn.nparams()];
458 let problem = OdeBuilder::<M>::new().p(default_p).build_from_eqn(eqn)?;
459 Ok(GenericSolve { problem })
460 }
461
462 pub(crate) fn setup_problem(&mut self, params: &[f64]) -> Result<(), DiffsolRtError> {
463 let params: Vec<M::T> = params.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
464 let params = M::V::from_slice(¶ms, M::C::default());
465
466 let nparams = self.problem.eqn.nparams();
468 if params.len() == nparams {
469 self.problem.eqn.set_params(¶ms);
470 Ok(())
471 } else {
472 Err(DiffsolError::Other(format!(
473 "Expecting {} params but got {}",
474 nparams,
475 params.len()
476 ))
477 .into())
478 }
479 }
480
481 pub(crate) fn serialize_eqn(&self) -> Result<Vec<u8>, DiffsolRtError>
482 where
483 DiffSl<M, CG>: serde::Serialize,
484 {
485 serde_json::to_vec(&self.problem.eqn)
486 .map_err(|e| DiffsolRtError::from(DiffsolError::Other(e.to_string())))
487 }
488}
489
#[cfg(feature = "external")]
impl<M> GenericSolve<M, diffsl::ExternalModule<M::T>>
where
    M: MatrixHost<T: ExternalScalar>,
    M::V: Vector + VectorHost + DefaultDenseMatrix,
{
    /// Builds a solver around a `diffsl::ExternalModule` equation set.
    ///
    /// The dependency lists and `include_sensitivities` are forwarded unchanged to
    /// `DiffSl::from_external`, using a default-constructed context.
    ///
    /// # Errors
    /// Propagates failures from building the equations or the ODE problem.
    pub fn from_external(
        rhs_state_deps: Vec<(usize, usize)>,
        rhs_input_deps: Vec<(usize, usize)>,
        mass_state_deps: Vec<(usize, usize)>,
        include_sensitivities: bool,
    ) -> Result<Self, DiffsolRtError> {
        let ctx = M::C::default();
        DiffSl::<M, diffsl::ExternalModule<M::T>>::from_external(
            ctx,
            rhs_state_deps,
            rhs_input_deps,
            mass_state_deps,
            include_sensitivities,
        )
        .map_err(DiffsolRtError::from)
        .and_then(Self::from_eqn)
    }
}
512
#[cfg(feature = "diffsl-external-dynamic")]
impl<M> GenericSolve<M, diffsl::ExternalDynModule<M::T>>
where
    M: MatrixHost<T: Scalar>,
    M::V: Vector + VectorHost + DefaultDenseMatrix,
{
    /// Builds a solver around a `diffsl::ExternalDynModule` created from `path`.
    ///
    /// The remaining arguments are forwarded unchanged to `DiffSl::from_external_dynamic`,
    /// together with a default-constructed context.
    ///
    /// # Errors
    /// Propagates failures from building the equations or the ODE problem.
    pub fn from_external_dynamic(
        path: impl Into<PathBuf>,
        rhs_state_deps: Vec<(usize, usize)>,
        rhs_input_deps: Vec<(usize, usize)>,
        mass_state_deps: Vec<(usize, usize)>,
        include_sensitivities: bool,
    ) -> Result<Self, DiffsolRtError> {
        let built = DiffSl::<M, diffsl::ExternalDynModule<M::T>>::from_external_dynamic(
            path,
            M::C::default(),
            rhs_state_deps,
            rhs_input_deps,
            mass_state_deps,
            include_sensitivities,
        );
        match built {
            Ok(eqn) => Self::from_eqn(eqn),
            Err(err) => Err(err.into()),
        }
    }
}
537
538impl<M> GenericSolve<M, ObjectModule>
539where
540 M: MatrixHost<T: Scalar>,
541 M::V: Vector + VectorHost + DefaultDenseMatrix,
542{
543 pub fn from_serialized_diffsl(serialized_diffsl: &[u8]) -> Result<Self, DiffsolRtError> {
544 let eqn = serde_json::from_slice::<DiffSl<M, ObjectModule>>(serialized_diffsl)
545 .map_err(|e| DiffsolRtError::from(DiffsolError::Other(e.to_string())))?;
546 Self::from_eqn(eqn)
547 }
548}
549
550impl<M, CG> Solve for GenericSolve<M, CG>
551where
552 M: MatrixHost<T: Scalar + ToPrimitive>
553 + DefaultSolver
554 + LuValidator<M>
555 + KluValidator<M>
556 + MatrixKind,
557 CG: CodegenModule,
558 for<'b> <<M::V as DefaultDenseMatrix>::M as MatrixCommon>::Inner: ToHostArray<M::T> + Clone,
559 for<'b> <M::V as VectorCommon>::Inner: ToHostArray<M::T> + Clone,
560 M::V: VectorHost + DefaultDenseMatrix + Send + Sync + 'static,
561 <M::V as DefaultDenseMatrix>::M: Send + Sync,
562 for<'b> &'b M::V: VectorRef<M::V>,
563 for<'b> &'b M: MatrixRef<M>,
564 CG: SolveSerialization<M>,
565{
566 fn matrix_type(&self) -> MatrixType {
567 MatrixType::from_diffsol::<M>()
568 }
569
570 fn nstates(&self) -> usize {
571 self.problem.eqn.nstates()
572 }
573
574 fn nparams(&self) -> usize {
575 self.problem.eqn.nparams()
576 }
577
578 fn nout(&self) -> usize {
579 self.problem.eqn.nout()
580 }
581
582 fn has_stop(&self) -> bool {
583 self.problem.eqn.root().is_some()
584 }
585
586 fn check(&self, linear_solver: LinearSolverType) -> Result<(), DiffsolRtError> {
587 validate_linear_solver::<M>(linear_solver)
588 }
589
590 fn set_atol(&mut self, atol: f64) {
591 self.problem.atol.fill(M::T::from_f64(atol).unwrap());
592 }
593
594 fn atol(&self) -> f64 {
595 self.problem.atol[0].to_f64().unwrap()
596 }
597
598 fn serialized_diffsl(&self) -> Result<Vec<u8>, DiffsolRtError> {
599 CG::serialized_diffsl(self)
600 }
601
602 fn set_rtol(&mut self, rtol: f64) {
603 self.problem.rtol = M::T::from_f64(rtol).unwrap();
604 }
605
606 fn rtol(&self) -> f64 {
607 self.problem.rtol.to_f64().unwrap()
608 }
609
610 generate_ic_option_accessors! {
611 use_linesearch: bool,
612 max_linesearch_iterations: usize,
613 max_newton_iterations: usize,
614 max_linear_solver_setups: usize,
615 step_reduction_factor: f64,
616 armijo_constant: f64
617 }
618
619 generate_ode_option_accessors! {
620 max_nonlinear_solver_iterations: usize,
621 max_error_test_failures: usize,
622 min_timestep: f64,
623 update_jacobian_after_steps: usize,
624 update_rhs_jacobian_after_steps: usize,
625 threshold_to_update_jacobian: f64,
626 threshold_to_update_rhs_jacobian: f64
627 }
628
629 fn y0(&mut self, params: &[f64]) -> Result<HostArray, DiffsolRtError> {
630 self.setup_problem(params)?;
631 let n = self.problem.eqn.nstates();
632 let mut y0 = M::V::zeros(n, M::C::default());
633 let t0 = self.problem.t0;
634 self.problem.eqn.init().call_inplace(t0, &mut y0);
635 Ok((*y0.inner()).clone().to_host_array())
636 }
637
638 fn rhs(&mut self, params: &[f64], t: f64, y: &[f64]) -> Result<HostArray, DiffsolRtError> {
639 self.setup_problem(params)?;
640 let n = self.problem.eqn.nstates();
641 let y = y
642 .iter()
643 .map(|&x| M::T::from_f64(x).unwrap())
644 .collect::<Vec<_>>();
645 let y_vec = M::V::from_slice(&y, M::C::default());
646 let mut dydt = M::V::zeros(n, M::C::default());
647 self.problem
648 .eqn
649 .rhs()
650 .call_inplace(&y_vec, M::T::from_f64(t).unwrap(), &mut dydt);
651 Ok((*dydt.inner()).clone().to_host_array())
652 }
653
654 fn rhs_jac_mul(
655 &mut self,
656
657 params: &[f64],
658 t: f64,
659 y: &[f64],
660 v: &[f64],
661 ) -> Result<HostArray, DiffsolRtError> {
662 self.setup_problem(params)?;
663 let n = self.problem.eqn.nstates();
664 let y = y
665 .iter()
666 .map(|&x| M::T::from_f64(x).unwrap())
667 .collect::<Vec<_>>();
668 let v = v
669 .iter()
670 .map(|&x| M::T::from_f64(x).unwrap())
671 .collect::<Vec<_>>();
672 let y_vec = M::V::from_slice(&y, M::C::default());
673 let v_vec = M::V::from_slice(&v, M::C::default());
674 let mut dydt = M::V::zeros(n, M::C::default());
675 self.problem.eqn.rhs().jac_mul_inplace(
676 &y_vec,
677 M::T::from_f64(t).unwrap(),
678 &v_vec,
679 &mut dydt,
680 );
681 Ok((*dydt.inner()).clone().to_host_array())
682 }
683
684 fn solve(
685 &mut self,
686 method: OdeSolverType,
687 linear_solver: LinearSolverType,
688 params: &[f64],
689 final_time: f64,
690 ) -> SolveResult {
691 self.check(linear_solver)?;
692 self.setup_problem(params)?;
693 let final_time = M::T::from_f64(final_time).unwrap();
694 let soln = match linear_solver {
695 LinearSolverType::Default => {
696 method.solve::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, final_time)
697 }
698 LinearSolverType::Lu => {
699 method.solve::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, final_time)
700 }
701 LinearSolverType::Klu => {
702 method.solve::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, final_time)
703 }
704 };
705 Ok(Box::new(soln?))
706 }
707
708 fn solve_dense(
709 &mut self,
710 method: OdeSolverType,
711 linear_solver: LinearSolverType,
712 params: &[f64],
713 t_eval: &[f64],
714 ) -> SolveResult {
715 self.check(linear_solver)?;
716 self.setup_problem(params)?;
717
718 let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
719 let soln =
720 match linear_solver {
721 LinearSolverType::Default => method
722 .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval),
723 LinearSolverType::Lu => method
724 .solve_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
725 LinearSolverType::Klu => method
726 .solve_dense::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, &t_eval),
727 };
728 Ok(Box::new(soln?))
729 }
730
731 fn solve_hybrid(
732 &mut self,
733 method: OdeSolverType,
734 linear_solver: LinearSolverType,
735 params: &[f64],
736 final_time: f64,
737 ) -> SolveResult {
738 self.check(linear_solver)?;
739 self.setup_problem(params)?;
740 let final_time = M::T::from_f64(final_time).unwrap();
741 let soln = match linear_solver {
742 LinearSolverType::Default => method
743 .solve_hybrid::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, final_time),
744 LinearSolverType::Lu => method
745 .solve_hybrid::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, final_time),
746 LinearSolverType::Klu => method
747 .solve_hybrid::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, final_time),
748 };
749 Ok(Box::new(soln?))
750 }
751
752 fn solve_fwd_sens(
753 &mut self,
754 method: OdeSolverType,
755 linear_solver: LinearSolverType,
756 params: &[f64],
757 t_eval: &[f64],
758 ) -> SolveResult {
759 self.check(linear_solver)?;
760 self.setup_problem(params)?;
761
762 let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
763 let soln = match linear_solver {
764 LinearSolverType::Default => {
765 method.solve_fwd_sens::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval)
766 }
767 LinearSolverType::Lu => method
768 .solve_fwd_sens::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
769 LinearSolverType::Klu => method
770 .solve_fwd_sens::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, &t_eval),
771 };
772 Ok(Box::new(soln?))
773 }
774
775 fn solve_hybrid_dense(
776 &mut self,
777 method: OdeSolverType,
778 linear_solver: LinearSolverType,
779 params: &[f64],
780 t_eval: &[f64],
781 ) -> SolveResult {
782 self.check(linear_solver)?;
783 self.setup_problem(params)?;
784
785 let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
786 let soln = match linear_solver {
787 LinearSolverType::Default => method
788 .solve_hybrid_dense::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval),
789 LinearSolverType::Lu => method
790 .solve_hybrid_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
791 LinearSolverType::Klu => method
792 .solve_hybrid_dense::<M, CG, <M as KluValidator<M>>::LS>(
793 &mut self.problem,
794 &t_eval,
795 ),
796 };
797 Ok(Box::new(soln?))
798 }
799
800 fn solve_hybrid_fwd_sens(
801 &mut self,
802 method: OdeSolverType,
803 linear_solver: LinearSolverType,
804 params: &[f64],
805 t_eval: &[f64],
806 ) -> SolveResult {
807 self.check(linear_solver)?;
808 self.setup_problem(params)?;
809
810 let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
811 let soln = match linear_solver {
812 LinearSolverType::Default => method
813 .solve_hybrid_fwd_sens::<M, CG, <M as DefaultSolver>::LS>(
814 &mut self.problem,
815 &t_eval,
816 ),
817 LinearSolverType::Lu => method
818 .solve_hybrid_fwd_sens::<M, CG, <M as LuValidator<M>>::LS>(
819 &mut self.problem,
820 &t_eval,
821 ),
822 LinearSolverType::Klu => method
823 .solve_hybrid_fwd_sens::<M, CG, <M as KluValidator<M>>::LS>(
824 &mut self.problem,
825 &t_eval,
826 ),
827 };
828 Ok(Box::new(soln?))
829 }
830
831 fn solve_sum_squares_adj(
832 &mut self,
833
834 method: OdeSolverType,
835 linear_solver: LinearSolverType,
836 backwards_method: OdeSolverType,
837 backwards_linear_solver: LinearSolverType,
838 params: &[f64],
839 data: HostArray,
840 t_eval: &[f64],
841 ) -> Result<(f64, HostArray), DiffsolRtError> {
842 self.check(linear_solver)?;
843 self.setup_problem(params)?;
844
845 let data = data.as_array()?;
846 let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
847
848 let result = match linear_solver {
849 LinearSolverType::Default => method
850 .solve_sum_squares_adj::<M, CG, <M as DefaultSolver>::LS>(
851 &mut self.problem,
852 data,
853 &t_eval,
854 backwards_method,
855 backwards_linear_solver,
856 ),
857 LinearSolverType::Lu => method
858 .solve_sum_squares_adj::<M, CG, <M as LuValidator<M>>::LS>(
859 &mut self.problem,
860 data,
861 &t_eval,
862 backwards_method,
863 backwards_linear_solver,
864 ),
865 LinearSolverType::Klu => method
866 .solve_sum_squares_adj::<M, CG, <M as KluValidator<M>>::LS>(
867 &mut self.problem,
868 data,
869 &t_eval,
870 backwards_method,
871 backwards_linear_solver,
872 ),
873 };
874 let (y, y_sens) = result?;
875
876 Ok((
877 y.to_f64().unwrap(),
878 (*y_sens.inner()).clone().to_host_array(),
879 ))
880 }
881}
882
883#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
884mod tests {
885 use diffsol::MatrixCommon;
886 use diffsol::{
887 CodegenModuleCompile, CodegenModuleJit, Context, OdeBuilder, OdeEquations, Vector,
888 };
889
890 #[cfg(feature = "diffsl-llvm")]
891 use crate::test_support::{hybrid_logistic_state_dr, logistic_state_dr, matrix_host};
892 use crate::{
893 host_array::FromHostArray,
894 linear_solver_type::LinearSolverType,
895 matrix_type::MatrixType,
896 ode_solver_type::OdeSolverType,
897 scalar_type::ScalarType,
898 test_support::{
899 assert_close, hybrid_logistic_diffsl_code, hybrid_logistic_state, logistic_diffsl_code,
900 logistic_state, LOGISTIC_X0,
901 },
902 };
903
904 use super::{solve_factory_with_jit_backend, GenericSolve, Solve, SolveSerialization};
905
906 fn make_generic_solve<T, CG>() -> GenericSolve<diffsol::NalgebraMat<T>, CG>
907 where
908 T: crate::scalar_type::Scalar + diffsol::NalgebraScalar,
909 CG: diffsol::CodegenModule
910 + CodegenModuleJit
911 + CodegenModuleCompile
912 + SolveSerialization<diffsol::NalgebraMat<T>>,
913 diffsol::NalgebraMat<T>: diffsol::matrix::MatrixHost,
914 <diffsol::NalgebraMat<T> as MatrixCommon>::T: crate::scalar_type::Scalar,
915 <diffsol::NalgebraMat<T> as MatrixCommon>::V: diffsol::DefaultDenseMatrix,
916 {
917 let problem = OdeBuilder::<diffsol::NalgebraMat<T>>::new()
918 .build_from_diffsl::<CG>(logistic_diffsl_code())
919 .unwrap();
920 GenericSolve { problem }
921 }
922
923 fn assert_factory_supports_all_matrix_and_scalar_types<CG>()
924 where
925 CG: diffsol::CodegenModule
926 + CodegenModuleJit
927 + CodegenModuleCompile
928 + SolveSerialization<diffsol::NalgebraMat<f32>>
929 + SolveSerialization<diffsol::NalgebraMat<f64>>
930 + SolveSerialization<diffsol::FaerMat<f32>>
931 + SolveSerialization<diffsol::FaerMat<f64>>
932 + SolveSerialization<diffsol::FaerSparseMat<f32>>
933 + SolveSerialization<diffsol::FaerSparseMat<f64>>,
934 {
935 for matrix_type in [
936 MatrixType::NalgebraDense,
937 MatrixType::FaerDense,
938 MatrixType::FaerSparse,
939 ] {
940 for scalar_type in [ScalarType::F32, ScalarType::F64] {
941 assert!(solve_factory_with_jit_backend::<CG>(
942 logistic_diffsl_code(),
943 matrix_type,
944 scalar_type,
945 )
946 .is_ok());
947 }
948 }
949 }
950
951 fn assert_solve_metadata_and_helpers<T, CG>(tol: f64)
952 where
953 T: crate::scalar_type::Scalar + diffsol::NalgebraScalar,
954 CG: diffsol::CodegenModule
955 + CodegenModuleJit
956 + CodegenModuleCompile
957 + SolveSerialization<diffsol::NalgebraMat<T>>,
958 diffsol::NalgebraMat<T>: diffsol::matrix::MatrixHost,
959 <diffsol::NalgebraMat<T> as MatrixCommon>::T: crate::scalar_type::Scalar,
960 <diffsol::NalgebraMat<T> as MatrixCommon>::V: diffsol::DefaultDenseMatrix,
961 GenericSolve<diffsol::NalgebraMat<T>, CG>: Solve,
962 {
963 let mut solve = make_generic_solve::<T, CG>();
964 assert_eq!(solve.matrix_type(), MatrixType::NalgebraDense);
965 assert_eq!(solve.nstates(), 1);
966 assert_eq!(solve.nparams(), 1);
967 assert_eq!(solve.nout(), 1);
968 assert!(!solve.has_stop());
969 assert!(solve.check(LinearSolverType::Default).is_ok());
970 assert!(solve.check(LinearSolverType::Lu).is_ok());
971 assert!(solve.check(LinearSolverType::Klu).is_err());
972
973 solve.set_atol(1e-5);
974 solve.set_rtol(1e-4);
975 assert_close(solve.atol(), 1e-5, tol, "solve atol");
976 assert_close(solve.rtol(), 1e-4, tol, "solve rtol");
977
978 let y0 = Vec::<f64>::from_host_array(solve.y0(&[2.0]).unwrap()).unwrap();
979 assert_eq!(y0.len(), 1);
980 assert_close(y0[0], LOGISTIC_X0, tol, "y0[0]");
981
982 let rhs = Vec::<f64>::from_host_array(solve.rhs(&[2.0], 0.0, &[0.25]).unwrap()).unwrap();
983 assert_eq!(rhs.len(), 1);
984 assert_close(rhs[0], 0.375, tol, "solve rhs");
985
986 let jac_mul =
987 Vec::<f64>::from_host_array(solve.rhs_jac_mul(&[2.0], 0.0, &[0.25], &[3.0]).unwrap())
988 .unwrap();
989 assert_eq!(jac_mul.len(), 1);
990 assert_close(jac_mul[0], 3.0, tol, "solve rhs jac mul");
991 }
992
993 fn assert_solve_runtime_paths<T, CG>(tol: f64)
994 where
995 T: crate::scalar_type::Scalar + diffsol::NalgebraScalar,
996 CG: diffsol::CodegenModule
997 + CodegenModuleJit
998 + CodegenModuleCompile
999 + SolveSerialization<diffsol::NalgebraMat<T>>,
1000 diffsol::NalgebraMat<T>: diffsol::matrix::MatrixHost,
1001 <diffsol::NalgebraMat<T> as MatrixCommon>::T: crate::scalar_type::Scalar,
1002 <diffsol::NalgebraMat<T> as MatrixCommon>::V: diffsol::DefaultDenseMatrix,
1003 GenericSolve<diffsol::NalgebraMat<T>, CG>: Solve,
1004 {
1005 let mut solve = make_generic_solve::<T, CG>();
1006 let soln = solve
1007 .solve(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], 1.0)
1008 .unwrap();
1009 let ts = Vec::<f64>::from_host_array(soln.get_ts()).unwrap();
1010 let ys = Vec::<Vec<f64>>::from_host_array(soln.get_ys()).unwrap();
1011 assert_close(*ts.last().unwrap(), 1.0, tol, "solve final time");
1012 assert_close(
1013 ys[0][ts.len() - 1],
1014 logistic_state(LOGISTIC_X0, 2.0, 1.0),
1015 tol,
1016 "solve final value",
1017 );
1018
1019 let mut solve = make_generic_solve::<T, CG>();
1020 let dense = solve
1021 .solve_dense(
1022 OdeSolverType::Tsit45,
1023 LinearSolverType::Lu,
1024 &[2.0],
1025 &[0.25, 0.5, 1.0],
1026 )
1027 .unwrap();
1028 let ts = Vec::<f64>::from_host_array(dense.get_ts()).unwrap();
1029 let ys = Vec::<Vec<f64>>::from_host_array(dense.get_ys()).unwrap();
1030 assert_eq!(ts, vec![0.25, 0.5, 1.0]);
1031 for (i, &t) in ts.iter().enumerate() {
1032 assert_close(
1033 ys[0][i],
1034 logistic_state(LOGISTIC_X0, 2.0, t),
1035 tol,
1036 &format!("solve_dense[{i}]"),
1037 );
1038 }
1039
1040 let mut solve = make_generic_solve::<T, CG>();
1041 let err = match solve.solve(OdeSolverType::Bdf, LinearSolverType::Default, &[], 1.0) {
1042 Ok(_) => panic!("expected parameter count mismatch"),
1043 Err(err) => err,
1044 };
1045 assert!(err.to_string().contains("Expecting 1 params but got 0"));
1046
1047 let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<T>>::new()
1048 .build_from_diffsl::<CG>(hybrid_logistic_diffsl_code())
1049 .unwrap();
1050 let mut hybrid_solve = GenericSolve {
1051 problem: hybrid_problem,
1052 };
1053 let hybrid = hybrid_solve
1054 .solve_hybrid(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], 2.0)
1055 .unwrap();
1056 let hybrid_ts = Vec::<f64>::from_host_array(hybrid.get_ts()).unwrap();
1057 let hybrid_ys = Vec::<Vec<f64>>::from_host_array(hybrid.get_ys()).unwrap();
1058 assert_close(
1059 *hybrid_ts.last().unwrap(),
1060 2.0,
1061 tol,
1062 "solve_hybrid final time",
1063 );
1064 assert_close(
1065 hybrid_ys[0][hybrid_ts.len() - 1],
1066 hybrid_logistic_state(2.0, 2.0),
1067 tol,
1068 "solve_hybrid final value",
1069 );
1070
1071 let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<T>>::new()
1072 .build_from_diffsl::<CG>(hybrid_logistic_diffsl_code())
1073 .unwrap();
1074 let mut hybrid_solve = GenericSolve {
1075 problem: hybrid_problem,
1076 };
1077 let hybrid_dense = hybrid_solve
1078 .solve_hybrid_dense(
1079 OdeSolverType::Tsit45,
1080 LinearSolverType::Lu,
1081 &[2.0],
1082 &[0.5, 1.0, 1.5, 2.0],
1083 )
1084 .unwrap();
1085 let hybrid_dense_ts = Vec::<f64>::from_host_array(hybrid_dense.get_ts()).unwrap();
1086 let hybrid_dense_ys = Vec::<Vec<f64>>::from_host_array(hybrid_dense.get_ys()).unwrap();
1087 assert_eq!(hybrid_dense_ts, vec![0.5, 1.0, 1.5, 2.0]);
1088 for (i, &t) in hybrid_dense_ts.iter().enumerate() {
1089 assert_close(
1090 hybrid_dense_ys[0][i],
1091 hybrid_logistic_state(2.0, t),
1092 tol,
1093 &format!("solve_hybrid_dense[{i}]"),
1094 );
1095 }
1096 }
1097
1098 #[cfg(feature = "diffsl-llvm")]
1099 fn assert_solve_sensitivity_paths() {
1100 let t_eval = [0.25, 0.5, 1.0];
1101
1102 let mut solve = make_generic_solve::<f64, diffsol::LlvmModule>();
1103 let sens = solve
1104 .solve_fwd_sens(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], &t_eval)
1105 .unwrap();
1106 let sens_values = sens.get_sens();
1107 assert_eq!(sens_values.len(), 1);
1108 let sens_matrix =
1109 Vec::<Vec<f64>>::from_host_array(sens_values.into_iter().next().unwrap()).unwrap();
1110 for (i, &t) in t_eval.iter().enumerate() {
1111 assert_close(
1112 sens_matrix[0][i],
1113 logistic_state_dr(LOGISTIC_X0, 2.0, t),
1114 5e-4,
1115 &format!("solve_fwd_sens[{i}]"),
1116 );
1117 }
1118
1119 let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<f64>>::new()
1120 .build_from_diffsl::<diffsol::LlvmModule>(hybrid_logistic_diffsl_code())
1121 .unwrap();
1122 let mut solve = GenericSolve {
1123 problem: hybrid_problem,
1124 };
1125 let hybrid_sens = solve
1126 .solve_hybrid_fwd_sens(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], &t_eval)
1127 .unwrap();
1128 let sens_values = hybrid_sens.get_sens();
1129 let sens_matrix =
1130 Vec::<Vec<f64>>::from_host_array(sens_values.into_iter().next().unwrap()).unwrap();
1131 for (i, &t) in t_eval.iter().enumerate() {
1132 assert_close(
1133 sens_matrix[0][i],
1134 hybrid_logistic_state_dr(2.0, t),
1135 5e-4,
1136 &format!("solve_hybrid_fwd_sens[{i}]"),
1137 );
1138 }
1139
1140 let adjoint_t_eval = [0.0, 0.25, 0.5, 1.0];
1141 let adjoint_data: Vec<f64> = adjoint_t_eval
1142 .iter()
1143 .map(|&t| logistic_state(LOGISTIC_X0, 2.0, t))
1144 .collect();
1145 let mut solve = make_generic_solve::<f64, diffsol::LlvmModule>();
1146 let (objective, gradient) = solve
1147 .solve_sum_squares_adj(
1148 OdeSolverType::Bdf,
1149 LinearSolverType::Lu,
1150 OdeSolverType::TrBdf2,
1151 LinearSolverType::Lu,
1152 &[2.0],
1153 matrix_host(1, adjoint_t_eval.len(), &adjoint_data),
1154 &adjoint_t_eval,
1155 )
1156 .unwrap();
1157 assert!(objective.is_finite());
1158 let gradient = Vec::<f64>::from_host_array(gradient).unwrap();
1159 assert_eq!(gradient.len(), 1);
1160 assert!(gradient[0].is_finite());
1161 }
1162
1163 #[cfg(feature = "diffsl-cranelift")]
1164 #[test]
1165 fn solve_factory_supports_all_jit_matrix_and_scalar_types_for_cranelift() {
1166 assert_factory_supports_all_matrix_and_scalar_types::<diffsol::CraneliftJitModule>();
1167 }
1168
1169 #[cfg(feature = "diffsl-cranelift")]
1170 #[test]
1171 fn solve_trait_helpers_and_runtime_paths_for_cranelift() {
1172 assert_solve_metadata_and_helpers::<f64, diffsol::CraneliftJitModule>(1e-12);
1173 assert_solve_runtime_paths::<f64, diffsol::CraneliftJitModule>(5e-4);
1174 assert_solve_metadata_and_helpers::<f32, diffsol::CraneliftJitModule>(5e-3);
1175 assert_solve_runtime_paths::<f32, diffsol::CraneliftJitModule>(5e-3);
1176 }
1177
1178 #[cfg(feature = "diffsl-llvm")]
1179 #[test]
1180 fn solve_factory_supports_all_jit_matrix_and_scalar_types_for_llvm() {
1181 assert_factory_supports_all_matrix_and_scalar_types::<diffsol::LlvmModule>();
1182 }
1183
1184 #[cfg(feature = "diffsl-llvm")]
1185 #[test]
1186 fn solve_trait_helpers_and_runtime_paths_for_llvm() {
1187 assert_solve_metadata_and_helpers::<f64, diffsol::LlvmModule>(1e-12);
1188 assert_solve_runtime_paths::<f64, diffsol::LlvmModule>(5e-4);
1189 assert_solve_metadata_and_helpers::<f32, diffsol::LlvmModule>(5e-3);
1190 assert_solve_runtime_paths::<f32, diffsol::LlvmModule>(5e-3);
1191 }
1192
1193 #[cfg(feature = "diffsl-llvm")]
1194 #[test]
1195 fn solve_trait_sensitivity_paths_for_llvm() {
1196 assert_solve_sensitivity_paths();
1197 }
1198
1199 #[cfg(feature = "diffsl-cranelift")]
1200 #[test]
1201 fn setup_problem_validates_parameter_count_for_cranelift() {
1202 let mut solve = make_generic_solve::<f64, diffsol::CraneliftJitModule>();
1203 let err = solve.setup_problem(&[]).unwrap_err();
1204 assert!(err.to_string().contains("Expecting 1 params but got 0"));
1205
1206 solve.setup_problem(&[2.0]).unwrap();
1207 let mut params = solve.problem.context().vector_zeros(1);
1208 solve.problem.eqn.get_params(&mut params);
1209 assert_eq!(params.get_index(0), 2.0);
1210 }
1211
1212 #[cfg(feature = "diffsl-llvm")]
1213 #[test]
1214 fn setup_problem_validates_parameter_count_for_llvm() {
1215 let mut solve = make_generic_solve::<f64, diffsol::LlvmModule>();
1216 let err = solve.setup_problem(&[]).unwrap_err();
1217 assert!(err.to_string().contains("Expecting 1 params but got 0"));
1218
1219 solve.setup_problem(&[2.0]).unwrap();
1220 let mut params = solve.problem.context().vector_zeros(1);
1221 solve.problem.eqn.get_params(&mut params);
1222 assert_eq!(params.get_index(0), 2.0);
1223 }
1224}