1use std::sync::{Arc, Mutex};
2
3use serde::{ser::SerializeStruct, Serialize, Serializer};
4
5use crate::{error::DiffsolJsError, ode::Ode};
6
7#[derive(Clone)]
8pub struct OdeSolverOptions {
9 ode: Arc<Mutex<Ode>>,
10}
11impl OdeSolverOptions {
12 pub(crate) fn new(ode: Arc<Mutex<Ode>>) -> Self {
13 Self { ode }
14 }
15 fn guard(&self) -> Result<std::sync::MutexGuard<'_, Ode>, DiffsolJsError> {
16 self.ode.lock().map_err(|_| {
17 DiffsolJsError::from(diffsol::error::DiffsolError::Other(
18 "Failed to acquire lock on Ode object".to_string(),
19 ))
20 })
21 }
22}
23
24impl OdeSolverOptions {
25 pub fn get_max_nonlinear_solver_iterations(&self) -> Result<usize, DiffsolJsError> {
26 Ok(self.guard()?.solve.ode_max_nonlinear_solver_iterations())
27 }
28 pub fn set_max_nonlinear_solver_iterations(&self, value: usize) -> Result<(), DiffsolJsError> {
29 self.guard()?
30 .solve
31 .set_ode_max_nonlinear_solver_iterations(value);
32 Ok(())
33 }
34 pub fn get_max_error_test_failures(&self) -> Result<usize, DiffsolJsError> {
35 Ok(self.guard()?.solve.ode_max_error_test_failures())
36 }
37 pub fn set_max_error_test_failures(&self, value: usize) -> Result<(), DiffsolJsError> {
38 self.guard()?.solve.set_ode_max_error_test_failures(value);
39 Ok(())
40 }
41 pub fn get_update_jacobian_after_steps(&self) -> Result<usize, DiffsolJsError> {
42 Ok(self.guard()?.solve.ode_update_jacobian_after_steps())
43 }
44 pub fn set_update_jacobian_after_steps(&self, value: usize) -> Result<(), DiffsolJsError> {
45 self.guard()?
46 .solve
47 .set_ode_update_jacobian_after_steps(value);
48 Ok(())
49 }
50 pub fn get_update_rhs_jacobian_after_steps(&self) -> Result<usize, DiffsolJsError> {
51 Ok(self.guard()?.solve.ode_update_rhs_jacobian_after_steps())
52 }
53 pub fn set_update_rhs_jacobian_after_steps(&self, value: usize) -> Result<(), DiffsolJsError> {
54 self.guard()?
55 .solve
56 .set_ode_update_rhs_jacobian_after_steps(value);
57 Ok(())
58 }
59 pub fn get_threshold_to_update_jacobian(&self) -> Result<f64, DiffsolJsError> {
60 Ok(self.guard()?.solve.ode_threshold_to_update_jacobian())
61 }
62 pub fn set_threshold_to_update_jacobian(&self, value: f64) -> Result<(), DiffsolJsError> {
63 self.guard()?
64 .solve
65 .set_ode_threshold_to_update_jacobian(value);
66 Ok(())
67 }
68 pub fn get_threshold_to_update_rhs_jacobian(&self) -> Result<f64, DiffsolJsError> {
69 Ok(self.guard()?.solve.ode_threshold_to_update_rhs_jacobian())
70 }
71 pub fn set_threshold_to_update_rhs_jacobian(&self, value: f64) -> Result<(), DiffsolJsError> {
72 self.guard()?
73 .solve
74 .set_ode_threshold_to_update_rhs_jacobian(value);
75 Ok(())
76 }
77 pub fn get_min_timestep(&self) -> Result<f64, DiffsolJsError> {
78 Ok(self.guard()?.solve.ode_min_timestep())
79 }
80 pub fn set_min_timestep(&self, value: f64) -> Result<(), DiffsolJsError> {
81 self.guard()?.solve.set_ode_min_timestep(value);
82 Ok(())
83 }
84}
85
86impl Serialize for OdeSolverOptions {
87 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
88 where
89 S: Serializer,
90 {
91 let mut state = serializer.serialize_struct("OdeSolverOptions", 7)?;
92 state.serialize_field(
93 "max_nonlinear_solver_iterations",
94 &self
95 .get_max_nonlinear_solver_iterations()
96 .map_err(serde::ser::Error::custom)?,
97 )?;
98 state.serialize_field(
99 "max_error_test_failures",
100 &self
101 .get_max_error_test_failures()
102 .map_err(serde::ser::Error::custom)?,
103 )?;
104 state.serialize_field(
105 "update_jacobian_after_steps",
106 &self
107 .get_update_jacobian_after_steps()
108 .map_err(serde::ser::Error::custom)?,
109 )?;
110 state.serialize_field(
111 "update_rhs_jacobian_after_steps",
112 &self
113 .get_update_rhs_jacobian_after_steps()
114 .map_err(serde::ser::Error::custom)?,
115 )?;
116 state.serialize_field(
117 "threshold_to_update_jacobian",
118 &self
119 .get_threshold_to_update_jacobian()
120 .map_err(serde::ser::Error::custom)?,
121 )?;
122 state.serialize_field(
123 "threshold_to_update_rhs_jacobian",
124 &self
125 .get_threshold_to_update_rhs_jacobian()
126 .map_err(serde::ser::Error::custom)?,
127 )?;
128 state.serialize_field(
129 "min_timestep",
130 &self.get_min_timestep().map_err(serde::ser::Error::custom)?,
131 )?;
132 state.end()
133 }
134}
135
#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
mod tests {
    use super::OdeSolverOptions;

    use crate::{
        jit::JitBackendType,
        linear_solver_type::LinearSolverType,
        matrix_type::MatrixType,
        ode::OdeWrapper,
        ode_solver_type::OdeSolverType,
        scalar_type::ScalarType,
        test_support::{available_jit_backends, logistic_diffsl_code},
    };

    /// Builds a BDF solver for the logistic DiffSL test model on `backend` (f64 scalars, dense
    /// nalgebra matrices, default linear solver) and returns its options handle.
    fn make_options(backend: JitBackendType) -> OdeSolverOptions {
        let wrapper = OdeWrapper::new_jit(
            logistic_diffsl_code(),
            backend,
            ScalarType::F64,
            MatrixType::NalgebraDense,
            LinearSolverType::Default,
            OdeSolverType::Bdf,
        )
        .unwrap();
        wrapper.get_options()
    }

    /// Every setter must be observable through its getter and through the serialized JSON.
    #[test]
    fn ode_solver_options_roundtrip_and_serialize() {
        for backend in available_jit_backends() {
            let opts = make_options(backend);

            // Give each option a distinct value so a mixed-up field would be caught.
            opts.set_max_nonlinear_solver_iterations(17).unwrap();
            opts.set_max_error_test_failures(19).unwrap();
            opts.set_update_jacobian_after_steps(23).unwrap();
            opts.set_update_rhs_jacobian_after_steps(29).unwrap();
            opts.set_threshold_to_update_jacobian(1e-3).unwrap();
            opts.set_threshold_to_update_rhs_jacobian(2e-3).unwrap();
            opts.set_min_timestep(1e-4).unwrap();

            // Getter round-trip.
            assert_eq!(opts.get_max_nonlinear_solver_iterations().unwrap(), 17);
            assert_eq!(opts.get_max_error_test_failures().unwrap(), 19);
            assert_eq!(opts.get_update_jacobian_after_steps().unwrap(), 23);
            assert_eq!(opts.get_update_rhs_jacobian_after_steps().unwrap(), 29);
            assert_eq!(opts.get_threshold_to_update_jacobian().unwrap(), 1e-3);
            assert_eq!(opts.get_threshold_to_update_rhs_jacobian().unwrap(), 2e-3);
            assert_eq!(opts.get_min_timestep().unwrap(), 1e-4);

            // Serialized form exposes the same values under snake_case keys.
            let json = serde_json::to_value(&opts).unwrap();
            let expected_counts = [
                ("max_nonlinear_solver_iterations", 17),
                ("max_error_test_failures", 19),
                ("update_jacobian_after_steps", 23),
                ("update_rhs_jacobian_after_steps", 29),
            ];
            for (key, want) in expected_counts {
                assert_eq!(json[key], want);
            }
            let expected_reals = [
                ("threshold_to_update_jacobian", 1e-3),
                ("threshold_to_update_rhs_jacobian", 2e-3),
                ("min_timestep", 1e-4),
            ];
            for (key, want) in expected_reals {
                assert_eq!(json[key], want);
            }
        }
    }
}