Skip to main content

diffsol_c/
ode_options.rs

1use std::sync::{Arc, Mutex};
2
3use serde::{ser::SerializeStruct, Serialize, Serializer};
4
5use crate::{error::DiffsolRtError, ode::Ode};
6
7#[derive(Clone)]
8pub struct OdeSolverOptions {
9    ode: Arc<Mutex<Ode>>,
10}
11impl OdeSolverOptions {
12    pub(crate) fn new(ode: Arc<Mutex<Ode>>) -> Self {
13        Self { ode }
14    }
15    fn guard(&self) -> Result<std::sync::MutexGuard<'_, Ode>, DiffsolRtError> {
16        self.ode.lock().map_err(|_| {
17            DiffsolRtError::from(diffsol::error::DiffsolError::Other(
18                "Failed to acquire lock on Ode object".to_string(),
19            ))
20        })
21    }
22}
23
24impl OdeSolverOptions {
25    pub fn get_max_nonlinear_solver_iterations(&self) -> Result<usize, DiffsolRtError> {
26        Ok(self.guard()?.solve.ode_max_nonlinear_solver_iterations())
27    }
28    pub fn set_max_nonlinear_solver_iterations(&self, value: usize) -> Result<(), DiffsolRtError> {
29        self.guard()?
30            .solve
31            .set_ode_max_nonlinear_solver_iterations(value);
32        Ok(())
33    }
34    pub fn get_max_error_test_failures(&self) -> Result<usize, DiffsolRtError> {
35        Ok(self.guard()?.solve.ode_max_error_test_failures())
36    }
37    pub fn set_max_error_test_failures(&self, value: usize) -> Result<(), DiffsolRtError> {
38        self.guard()?.solve.set_ode_max_error_test_failures(value);
39        Ok(())
40    }
41    pub fn get_update_jacobian_after_steps(&self) -> Result<usize, DiffsolRtError> {
42        Ok(self.guard()?.solve.ode_update_jacobian_after_steps())
43    }
44    pub fn set_update_jacobian_after_steps(&self, value: usize) -> Result<(), DiffsolRtError> {
45        self.guard()?
46            .solve
47            .set_ode_update_jacobian_after_steps(value);
48        Ok(())
49    }
50    pub fn get_update_rhs_jacobian_after_steps(&self) -> Result<usize, DiffsolRtError> {
51        Ok(self.guard()?.solve.ode_update_rhs_jacobian_after_steps())
52    }
53    pub fn set_update_rhs_jacobian_after_steps(&self, value: usize) -> Result<(), DiffsolRtError> {
54        self.guard()?
55            .solve
56            .set_ode_update_rhs_jacobian_after_steps(value);
57        Ok(())
58    }
59    pub fn get_threshold_to_update_jacobian(&self) -> Result<f64, DiffsolRtError> {
60        Ok(self.guard()?.solve.ode_threshold_to_update_jacobian())
61    }
62    pub fn set_threshold_to_update_jacobian(&self, value: f64) -> Result<(), DiffsolRtError> {
63        self.guard()?
64            .solve
65            .set_ode_threshold_to_update_jacobian(value);
66        Ok(())
67    }
68    pub fn get_threshold_to_update_rhs_jacobian(&self) -> Result<f64, DiffsolRtError> {
69        Ok(self.guard()?.solve.ode_threshold_to_update_rhs_jacobian())
70    }
71    pub fn set_threshold_to_update_rhs_jacobian(&self, value: f64) -> Result<(), DiffsolRtError> {
72        self.guard()?
73            .solve
74            .set_ode_threshold_to_update_rhs_jacobian(value);
75        Ok(())
76    }
77    pub fn get_min_timestep(&self) -> Result<f64, DiffsolRtError> {
78        Ok(self.guard()?.solve.ode_min_timestep())
79    }
80    pub fn set_min_timestep(&self, value: f64) -> Result<(), DiffsolRtError> {
81        self.guard()?.solve.set_ode_min_timestep(value);
82        Ok(())
83    }
84}
85
86impl Serialize for OdeSolverOptions {
87    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
88    where
89        S: Serializer,
90    {
91        let mut state = serializer.serialize_struct("OdeSolverOptions", 7)?;
92        state.serialize_field(
93            "max_nonlinear_solver_iterations",
94            &self
95                .get_max_nonlinear_solver_iterations()
96                .map_err(serde::ser::Error::custom)?,
97        )?;
98        state.serialize_field(
99            "max_error_test_failures",
100            &self
101                .get_max_error_test_failures()
102                .map_err(serde::ser::Error::custom)?,
103        )?;
104        state.serialize_field(
105            "update_jacobian_after_steps",
106            &self
107                .get_update_jacobian_after_steps()
108                .map_err(serde::ser::Error::custom)?,
109        )?;
110        state.serialize_field(
111            "update_rhs_jacobian_after_steps",
112            &self
113                .get_update_rhs_jacobian_after_steps()
114                .map_err(serde::ser::Error::custom)?,
115        )?;
116        state.serialize_field(
117            "threshold_to_update_jacobian",
118            &self
119                .get_threshold_to_update_jacobian()
120                .map_err(serde::ser::Error::custom)?,
121        )?;
122        state.serialize_field(
123            "threshold_to_update_rhs_jacobian",
124            &self
125                .get_threshold_to_update_rhs_jacobian()
126                .map_err(serde::ser::Error::custom)?,
127        )?;
128        state.serialize_field(
129            "min_timestep",
130            &self.get_min_timestep().map_err(serde::ser::Error::custom)?,
131        )?;
132        state.end()
133    }
134}
135
#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
mod tests {
    use crate::{
        jit::JitBackendType,
        linear_solver_type::LinearSolverType,
        matrix_type::MatrixType,
        ode::OdeWrapper,
        ode_solver_type::OdeSolverType,
        scalar_type::ScalarType,
        test_support::{available_jit_backends, logistic_diffsl_code},
    };

    use super::OdeSolverOptions;

    /// JIT-compiles `logistic_diffsl_code()` with `backend` (f64 scalars, dense
    /// nalgebra matrices, default linear solver, BDF) and returns its options.
    fn make_options(backend: JitBackendType) -> OdeSolverOptions {
        OdeWrapper::new_jit(
            logistic_diffsl_code(),
            backend,
            ScalarType::F64,
            MatrixType::NalgebraDense,
            LinearSolverType::Default,
            OdeSolverType::Bdf,
        )
        .unwrap()
        .get_options()
    }

    /// Every option set through the handle reads back unchanged and appears
    /// under its own key in the serialized JSON, for each available JIT backend.
    #[test]
    fn ode_solver_options_roundtrip_and_serialize() {
        for backend in available_jit_backends() {
            let opts = make_options(backend);

            // Write a distinct value into every option.
            opts.set_max_nonlinear_solver_iterations(17).unwrap();
            opts.set_max_error_test_failures(19).unwrap();
            opts.set_update_jacobian_after_steps(23).unwrap();
            opts.set_update_rhs_jacobian_after_steps(29).unwrap();
            opts.set_threshold_to_update_jacobian(1e-3).unwrap();
            opts.set_threshold_to_update_rhs_jacobian(2e-3).unwrap();
            opts.set_min_timestep(1e-4).unwrap();

            // Integer options read back unchanged.
            let counts = [
                opts.get_max_nonlinear_solver_iterations().unwrap(),
                opts.get_max_error_test_failures().unwrap(),
                opts.get_update_jacobian_after_steps().unwrap(),
                opts.get_update_rhs_jacobian_after_steps().unwrap(),
            ];
            assert_eq!(counts, [17, 19, 23, 29]);

            // Float options: exact comparison is intended, because the same
            // literal is written and read back.
            let reals = [
                opts.get_threshold_to_update_jacobian().unwrap(),
                opts.get_threshold_to_update_rhs_jacobian().unwrap(),
                opts.get_min_timestep().unwrap(),
            ];
            assert_eq!(reals, [1e-3, 2e-3, 1e-4]);

            // The serialized form carries the same values under snake_case keys.
            let json = serde_json::to_value(&opts).unwrap();
            let expected = [
                ("max_nonlinear_solver_iterations", serde_json::json!(17)),
                ("max_error_test_failures", serde_json::json!(19)),
                ("update_jacobian_after_steps", serde_json::json!(23)),
                ("update_rhs_jacobian_after_steps", serde_json::json!(29)),
                ("threshold_to_update_jacobian", serde_json::json!(1e-3)),
                ("threshold_to_update_rhs_jacobian", serde_json::json!(2e-3)),
                ("min_timestep", serde_json::json!(1e-4)),
            ];
            for (key, want) in expected {
                assert_eq!(json[key], want, "unexpected serialized value for `{key}`");
            }
        }
    }
}