1use std::sync::{Arc, Mutex};
2
3use serde::{Deserialize, Serialize, Serializer};
4
5use crate::{error::DiffsolRtError, ode::Ode, solve::Solve};
6
7#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
8pub struct OdeSolverOptionsSnapshot {
9 pub max_nonlinear_solver_iterations: usize,
10 pub max_error_test_failures: usize,
11 pub update_jacobian_after_steps: usize,
12 pub update_rhs_jacobian_after_steps: usize,
13 pub threshold_to_update_jacobian: f64,
14 pub threshold_to_update_rhs_jacobian: f64,
15 pub min_timestep: f64,
16}
17
18impl OdeSolverOptionsSnapshot {
19 pub(crate) fn from_solve(solve: &dyn Solve) -> Self {
20 Self {
21 max_nonlinear_solver_iterations: solve.ode_max_nonlinear_solver_iterations(),
22 max_error_test_failures: solve.ode_max_error_test_failures(),
23 update_jacobian_after_steps: solve.ode_update_jacobian_after_steps(),
24 update_rhs_jacobian_after_steps: solve.ode_update_rhs_jacobian_after_steps(),
25 threshold_to_update_jacobian: solve.ode_threshold_to_update_jacobian(),
26 threshold_to_update_rhs_jacobian: solve.ode_threshold_to_update_rhs_jacobian(),
27 min_timestep: solve.ode_min_timestep(),
28 }
29 }
30
31 pub(crate) fn apply_to_solve(&self, solve: &mut dyn Solve) {
32 solve.set_ode_max_nonlinear_solver_iterations(self.max_nonlinear_solver_iterations);
33 solve.set_ode_max_error_test_failures(self.max_error_test_failures);
34 solve.set_ode_update_jacobian_after_steps(self.update_jacobian_after_steps);
35 solve.set_ode_update_rhs_jacobian_after_steps(self.update_rhs_jacobian_after_steps);
36 solve.set_ode_threshold_to_update_jacobian(self.threshold_to_update_jacobian);
37 solve.set_ode_threshold_to_update_rhs_jacobian(self.threshold_to_update_rhs_jacobian);
38 solve.set_ode_min_timestep(self.min_timestep);
39 }
40}
41
42#[derive(Clone)]
43pub struct OdeSolverOptions {
44 ode: Arc<Mutex<Ode>>,
45}
46impl OdeSolverOptions {
47 pub(crate) fn new(ode: Arc<Mutex<Ode>>) -> Self {
48 Self { ode }
49 }
50 fn guard(&self) -> Result<std::sync::MutexGuard<'_, Ode>, DiffsolRtError> {
51 self.ode.lock().map_err(|_| {
52 DiffsolRtError::from(diffsol::error::DiffsolError::Other(
53 "Failed to acquire lock on Ode object".to_string(),
54 ))
55 })
56 }
57}
58
59impl OdeSolverOptions {
60 pub fn get_max_nonlinear_solver_iterations(&self) -> Result<usize, DiffsolRtError> {
61 Ok(self.guard()?.solve.ode_max_nonlinear_solver_iterations())
62 }
63 pub fn set_max_nonlinear_solver_iterations(&self, value: usize) -> Result<(), DiffsolRtError> {
64 self.guard()?
65 .solve
66 .set_ode_max_nonlinear_solver_iterations(value);
67 Ok(())
68 }
69 pub fn get_max_error_test_failures(&self) -> Result<usize, DiffsolRtError> {
70 Ok(self.guard()?.solve.ode_max_error_test_failures())
71 }
72 pub fn set_max_error_test_failures(&self, value: usize) -> Result<(), DiffsolRtError> {
73 self.guard()?.solve.set_ode_max_error_test_failures(value);
74 Ok(())
75 }
76 pub fn get_update_jacobian_after_steps(&self) -> Result<usize, DiffsolRtError> {
77 Ok(self.guard()?.solve.ode_update_jacobian_after_steps())
78 }
79 pub fn set_update_jacobian_after_steps(&self, value: usize) -> Result<(), DiffsolRtError> {
80 self.guard()?
81 .solve
82 .set_ode_update_jacobian_after_steps(value);
83 Ok(())
84 }
85 pub fn get_update_rhs_jacobian_after_steps(&self) -> Result<usize, DiffsolRtError> {
86 Ok(self.guard()?.solve.ode_update_rhs_jacobian_after_steps())
87 }
88 pub fn set_update_rhs_jacobian_after_steps(&self, value: usize) -> Result<(), DiffsolRtError> {
89 self.guard()?
90 .solve
91 .set_ode_update_rhs_jacobian_after_steps(value);
92 Ok(())
93 }
94 pub fn get_threshold_to_update_jacobian(&self) -> Result<f64, DiffsolRtError> {
95 Ok(self.guard()?.solve.ode_threshold_to_update_jacobian())
96 }
97 pub fn set_threshold_to_update_jacobian(&self, value: f64) -> Result<(), DiffsolRtError> {
98 self.guard()?
99 .solve
100 .set_ode_threshold_to_update_jacobian(value);
101 Ok(())
102 }
103 pub fn get_threshold_to_update_rhs_jacobian(&self) -> Result<f64, DiffsolRtError> {
104 Ok(self.guard()?.solve.ode_threshold_to_update_rhs_jacobian())
105 }
106 pub fn set_threshold_to_update_rhs_jacobian(&self, value: f64) -> Result<(), DiffsolRtError> {
107 self.guard()?
108 .solve
109 .set_ode_threshold_to_update_rhs_jacobian(value);
110 Ok(())
111 }
112 pub fn get_min_timestep(&self) -> Result<f64, DiffsolRtError> {
113 Ok(self.guard()?.solve.ode_min_timestep())
114 }
115 pub fn set_min_timestep(&self, value: f64) -> Result<(), DiffsolRtError> {
116 self.guard()?.solve.set_ode_min_timestep(value);
117 Ok(())
118 }
119}
120
121impl Serialize for OdeSolverOptions {
122 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
123 where
124 S: Serializer,
125 {
126 let guard = self.guard().map_err(serde::ser::Error::custom)?;
127 OdeSolverOptionsSnapshot::from_solve(guard.solve.as_ref()).serialize(serializer)
128 }
129}
130
#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
mod tests {
    use crate::{
        jit::JitBackendType,
        linear_solver_type::LinearSolverType,
        matrix_type::MatrixType,
        ode::OdeWrapper,
        ode_solver_type::OdeSolverType,
        scalar_type::ScalarType,
        test_support::{available_jit_backends, logistic_diffsl_code},
    };

    use super::{OdeSolverOptions, OdeSolverOptionsSnapshot};

    /// Builds the options handle of a fresh logistic-model BDF solver on `jit_backend`.
    fn make_options(jit_backend: JitBackendType) -> OdeSolverOptions {
        OdeWrapper::new_jit(
            logistic_diffsl_code(),
            jit_backend,
            ScalarType::F64,
            MatrixType::NalgebraDense,
            LinearSolverType::Default,
            OdeSolverType::Bdf,
        )
        .unwrap()
        .get_options()
    }

    #[test]
    fn ode_solver_options_roundtrip_and_serialize() {
        for jit_backend in available_jit_backends() {
            let options = make_options(jit_backend);
            options.set_max_nonlinear_solver_iterations(17).unwrap();
            options.set_max_error_test_failures(19).unwrap();
            options.set_update_jacobian_after_steps(23).unwrap();
            options.set_update_rhs_jacobian_after_steps(29).unwrap();
            options.set_threshold_to_update_jacobian(1e-3).unwrap();
            options.set_threshold_to_update_rhs_jacobian(2e-3).unwrap();
            options.set_min_timestep(1e-4).unwrap();

            assert_eq!(options.get_max_nonlinear_solver_iterations().unwrap(), 17);
            assert_eq!(options.get_max_error_test_failures().unwrap(), 19);
            assert_eq!(options.get_update_jacobian_after_steps().unwrap(), 23);
            assert_eq!(options.get_update_rhs_jacobian_after_steps().unwrap(), 29);
            assert_eq!(options.get_threshold_to_update_jacobian().unwrap(), 1e-3);
            assert_eq!(
                options.get_threshold_to_update_rhs_jacobian().unwrap(),
                2e-3
            );
            assert_eq!(options.get_min_timestep().unwrap(), 1e-4);

            let value = serde_json::to_value(&options).unwrap();
            assert_eq!(value["max_nonlinear_solver_iterations"], 17);
            assert_eq!(value["max_error_test_failures"], 19);
            assert_eq!(value["update_jacobian_after_steps"], 23);
            assert_eq!(value["update_rhs_jacobian_after_steps"], 29);
            assert_eq!(value["threshold_to_update_jacobian"], 1e-3);
            assert_eq!(value["threshold_to_update_rhs_jacobian"], 2e-3);
            assert_eq!(value["min_timestep"], 1e-4);
        }
    }

    /// Covers the paths the test above does not: `apply_to_solve` and
    /// deserializing the serialized options back into a snapshot.
    #[test]
    fn ode_solver_options_snapshot_applies_and_deserializes() {
        let expected = OdeSolverOptionsSnapshot {
            max_nonlinear_solver_iterations: 11,
            max_error_test_failures: 13,
            update_jacobian_after_steps: 31,
            update_rhs_jacobian_after_steps: 37,
            threshold_to_update_jacobian: 5e-3,
            threshold_to_update_rhs_jacobian: 7e-3,
            min_timestep: 1e-6,
        };

        for jit_backend in available_jit_backends() {
            let options = make_options(jit_backend);

            // Scope the guard so the lock is released before the getters re-lock.
            {
                let mut guard = options.guard().unwrap();
                expected.apply_to_solve(guard.solve.as_mut());
            }

            assert_eq!(options.get_max_nonlinear_solver_iterations().unwrap(), 11);
            assert_eq!(options.get_max_error_test_failures().unwrap(), 13);
            assert_eq!(options.get_update_jacobian_after_steps().unwrap(), 31);
            assert_eq!(options.get_update_rhs_jacobian_after_steps().unwrap(), 37);
            assert_eq!(options.get_threshold_to_update_jacobian().unwrap(), 5e-3);
            assert_eq!(
                options.get_threshold_to_update_rhs_jacobian().unwrap(),
                7e-3
            );
            assert_eq!(options.get_min_timestep().unwrap(), 1e-6);

            // `to_value`/`from_value` keep f64s as numbers (no text parsing),
            // so exact float equality is valid here.
            let value = serde_json::to_value(&options).unwrap();
            let decoded: OdeSolverOptionsSnapshot = serde_json::from_value(value).unwrap();
            assert_eq!(decoded, expected);
        }
    }
}