Skip to main content

diffsol_c/
ode_options.rs

1use std::sync::{Arc, Mutex};
2
3use serde::{Deserialize, Serialize, Serializer};
4
5use crate::{error::DiffsolRtError, ode::Ode, solve::Solve};
6
7#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
8pub struct OdeSolverOptionsSnapshot {
9    pub max_nonlinear_solver_iterations: usize,
10    pub max_error_test_failures: usize,
11    pub update_jacobian_after_steps: usize,
12    pub update_rhs_jacobian_after_steps: usize,
13    pub threshold_to_update_jacobian: f64,
14    pub threshold_to_update_rhs_jacobian: f64,
15    pub min_timestep: f64,
16}
17
18impl OdeSolverOptionsSnapshot {
19    pub(crate) fn from_solve(solve: &dyn Solve) -> Self {
20        Self {
21            max_nonlinear_solver_iterations: solve.ode_max_nonlinear_solver_iterations(),
22            max_error_test_failures: solve.ode_max_error_test_failures(),
23            update_jacobian_after_steps: solve.ode_update_jacobian_after_steps(),
24            update_rhs_jacobian_after_steps: solve.ode_update_rhs_jacobian_after_steps(),
25            threshold_to_update_jacobian: solve.ode_threshold_to_update_jacobian(),
26            threshold_to_update_rhs_jacobian: solve.ode_threshold_to_update_rhs_jacobian(),
27            min_timestep: solve.ode_min_timestep(),
28        }
29    }
30
31    pub(crate) fn apply_to_solve(&self, solve: &mut dyn Solve) {
32        solve.set_ode_max_nonlinear_solver_iterations(self.max_nonlinear_solver_iterations);
33        solve.set_ode_max_error_test_failures(self.max_error_test_failures);
34        solve.set_ode_update_jacobian_after_steps(self.update_jacobian_after_steps);
35        solve.set_ode_update_rhs_jacobian_after_steps(self.update_rhs_jacobian_after_steps);
36        solve.set_ode_threshold_to_update_jacobian(self.threshold_to_update_jacobian);
37        solve.set_ode_threshold_to_update_rhs_jacobian(self.threshold_to_update_rhs_jacobian);
38        solve.set_ode_min_timestep(self.min_timestep);
39    }
40}
41
42#[derive(Clone)]
43pub struct OdeSolverOptions {
44    ode: Arc<Mutex<Ode>>,
45}
46impl OdeSolverOptions {
47    pub(crate) fn new(ode: Arc<Mutex<Ode>>) -> Self {
48        Self { ode }
49    }
50    fn guard(&self) -> Result<std::sync::MutexGuard<'_, Ode>, DiffsolRtError> {
51        self.ode.lock().map_err(|_| {
52            DiffsolRtError::from(diffsol::error::DiffsolError::Other(
53                "Failed to acquire lock on Ode object".to_string(),
54            ))
55        })
56    }
57}
58
59impl OdeSolverOptions {
60    pub fn get_max_nonlinear_solver_iterations(&self) -> Result<usize, DiffsolRtError> {
61        Ok(self.guard()?.solve.ode_max_nonlinear_solver_iterations())
62    }
63    pub fn set_max_nonlinear_solver_iterations(&self, value: usize) -> Result<(), DiffsolRtError> {
64        self.guard()?
65            .solve
66            .set_ode_max_nonlinear_solver_iterations(value);
67        Ok(())
68    }
69    pub fn get_max_error_test_failures(&self) -> Result<usize, DiffsolRtError> {
70        Ok(self.guard()?.solve.ode_max_error_test_failures())
71    }
72    pub fn set_max_error_test_failures(&self, value: usize) -> Result<(), DiffsolRtError> {
73        self.guard()?.solve.set_ode_max_error_test_failures(value);
74        Ok(())
75    }
76    pub fn get_update_jacobian_after_steps(&self) -> Result<usize, DiffsolRtError> {
77        Ok(self.guard()?.solve.ode_update_jacobian_after_steps())
78    }
79    pub fn set_update_jacobian_after_steps(&self, value: usize) -> Result<(), DiffsolRtError> {
80        self.guard()?
81            .solve
82            .set_ode_update_jacobian_after_steps(value);
83        Ok(())
84    }
85    pub fn get_update_rhs_jacobian_after_steps(&self) -> Result<usize, DiffsolRtError> {
86        Ok(self.guard()?.solve.ode_update_rhs_jacobian_after_steps())
87    }
88    pub fn set_update_rhs_jacobian_after_steps(&self, value: usize) -> Result<(), DiffsolRtError> {
89        self.guard()?
90            .solve
91            .set_ode_update_rhs_jacobian_after_steps(value);
92        Ok(())
93    }
94    pub fn get_threshold_to_update_jacobian(&self) -> Result<f64, DiffsolRtError> {
95        Ok(self.guard()?.solve.ode_threshold_to_update_jacobian())
96    }
97    pub fn set_threshold_to_update_jacobian(&self, value: f64) -> Result<(), DiffsolRtError> {
98        self.guard()?
99            .solve
100            .set_ode_threshold_to_update_jacobian(value);
101        Ok(())
102    }
103    pub fn get_threshold_to_update_rhs_jacobian(&self) -> Result<f64, DiffsolRtError> {
104        Ok(self.guard()?.solve.ode_threshold_to_update_rhs_jacobian())
105    }
106    pub fn set_threshold_to_update_rhs_jacobian(&self, value: f64) -> Result<(), DiffsolRtError> {
107        self.guard()?
108            .solve
109            .set_ode_threshold_to_update_rhs_jacobian(value);
110        Ok(())
111    }
112    pub fn get_min_timestep(&self) -> Result<f64, DiffsolRtError> {
113        Ok(self.guard()?.solve.ode_min_timestep())
114    }
115    pub fn set_min_timestep(&self, value: f64) -> Result<(), DiffsolRtError> {
116        self.guard()?.solve.set_ode_min_timestep(value);
117        Ok(())
118    }
119}
120
121impl Serialize for OdeSolverOptions {
122    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
123    where
124        S: Serializer,
125    {
126        let guard = self.guard().map_err(serde::ser::Error::custom)?;
127        OdeSolverOptionsSnapshot::from_solve(guard.solve.as_ref()).serialize(serializer)
128    }
129}
130
#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
mod tests {
    use super::OdeSolverOptions;
    use crate::{
        jit::JitBackendType,
        linear_solver_type::LinearSolverType,
        matrix_type::MatrixType,
        ode::OdeWrapper,
        ode_solver_type::OdeSolverType,
        scalar_type::ScalarType,
        test_support::{available_jit_backends, logistic_diffsl_code},
    };

    // Every option gets its own distinct value, so a getter/setter wired to the wrong
    // option would make an assertion fail.
    const MAX_NONLINEAR_SOLVER_ITERATIONS: usize = 17;
    const MAX_ERROR_TEST_FAILURES: usize = 19;
    const UPDATE_JACOBIAN_AFTER_STEPS: usize = 23;
    const UPDATE_RHS_JACOBIAN_AFTER_STEPS: usize = 29;
    const THRESHOLD_TO_UPDATE_JACOBIAN: f64 = 1e-3;
    const THRESHOLD_TO_UPDATE_RHS_JACOBIAN: f64 = 2e-3;
    const MIN_TIMESTEP: f64 = 1e-4;

    /// Builds a logistic-model ODE for `jit_backend` and returns its options handle.
    fn make_options(jit_backend: JitBackendType) -> OdeSolverOptions {
        let wrapper = OdeWrapper::new_jit(
            logistic_diffsl_code(),
            jit_backend,
            ScalarType::F64,
            MatrixType::NalgebraDense,
            LinearSolverType::Default,
            OdeSolverType::Bdf,
        )
        .unwrap();
        wrapper.get_options()
    }

    /// Sets every option, reads each one back, then checks the JSON serialization
    /// reports the same values. Repeated for every available JIT backend.
    #[test]
    fn ode_solver_options_roundtrip_and_serialize() {
        for jit_backend in available_jit_backends() {
            let options = make_options(jit_backend);

            // Write phase.
            options
                .set_max_nonlinear_solver_iterations(MAX_NONLINEAR_SOLVER_ITERATIONS)
                .unwrap();
            options
                .set_max_error_test_failures(MAX_ERROR_TEST_FAILURES)
                .unwrap();
            options
                .set_update_jacobian_after_steps(UPDATE_JACOBIAN_AFTER_STEPS)
                .unwrap();
            options
                .set_update_rhs_jacobian_after_steps(UPDATE_RHS_JACOBIAN_AFTER_STEPS)
                .unwrap();
            options
                .set_threshold_to_update_jacobian(THRESHOLD_TO_UPDATE_JACOBIAN)
                .unwrap();
            options
                .set_threshold_to_update_rhs_jacobian(THRESHOLD_TO_UPDATE_RHS_JACOBIAN)
                .unwrap();
            options.set_min_timestep(MIN_TIMESTEP).unwrap();

            // Read-back phase through the getters.
            assert_eq!(
                options.get_max_nonlinear_solver_iterations().unwrap(),
                MAX_NONLINEAR_SOLVER_ITERATIONS
            );
            assert_eq!(
                options.get_max_error_test_failures().unwrap(),
                MAX_ERROR_TEST_FAILURES
            );
            assert_eq!(
                options.get_update_jacobian_after_steps().unwrap(),
                UPDATE_JACOBIAN_AFTER_STEPS
            );
            assert_eq!(
                options.get_update_rhs_jacobian_after_steps().unwrap(),
                UPDATE_RHS_JACOBIAN_AFTER_STEPS
            );
            assert_eq!(
                options.get_threshold_to_update_jacobian().unwrap(),
                THRESHOLD_TO_UPDATE_JACOBIAN
            );
            assert_eq!(
                options.get_threshold_to_update_rhs_jacobian().unwrap(),
                THRESHOLD_TO_UPDATE_RHS_JACOBIAN
            );
            assert_eq!(options.get_min_timestep().unwrap(), MIN_TIMESTEP);

            // Serialization phase: the JSON object must carry the same values.
            let json = serde_json::to_value(&options).unwrap();
            assert_eq!(
                json["max_nonlinear_solver_iterations"],
                MAX_NONLINEAR_SOLVER_ITERATIONS
            );
            assert_eq!(json["max_error_test_failures"], MAX_ERROR_TEST_FAILURES);
            assert_eq!(
                json["update_jacobian_after_steps"],
                UPDATE_JACOBIAN_AFTER_STEPS
            );
            assert_eq!(
                json["update_rhs_jacobian_after_steps"],
                UPDATE_RHS_JACOBIAN_AFTER_STEPS
            );
            assert_eq!(
                json["threshold_to_update_jacobian"],
                THRESHOLD_TO_UPDATE_JACOBIAN
            );
            assert_eq!(
                json["threshold_to_update_rhs_jacobian"],
                THRESHOLD_TO_UPDATE_RHS_JACOBIAN
            );
            assert_eq!(json["min_timestep"], MIN_TIMESTEP);
        }
    }
}