Skip to main content

diffsol_c/
initial_condition_options.rs

1use std::sync::{Arc, Mutex};
2
3use serde::{ser::SerializeStruct, Serialize, Serializer};
4
5use crate::{error::DiffsolRtError, ode::Ode};
6
7#[derive(Clone)]
8pub struct InitialConditionSolverOptions {
9    ode: Arc<Mutex<Ode>>,
10}
11impl InitialConditionSolverOptions {
12    pub(crate) fn new(ode: Arc<Mutex<Ode>>) -> Self {
13        Self { ode }
14    }
15    fn guard(&self) -> Result<std::sync::MutexGuard<'_, Ode>, DiffsolRtError> {
16        self.ode.lock().map_err(|_| {
17            DiffsolRtError::from(diffsol::error::DiffsolError::Other(
18                "Failed to acquire lock on ODE solver".to_string(),
19            ))
20        })
21    }
22}
23
24impl InitialConditionSolverOptions {
25    pub fn get_use_linesearch(&self) -> Result<bool, DiffsolRtError> {
26        Ok(self.guard()?.solve.ic_use_linesearch())
27    }
28    pub fn set_use_linesearch(&self, value: bool) -> Result<(), DiffsolRtError> {
29        self.guard()?.solve.set_ic_use_linesearch(value);
30        Ok(())
31    }
32    pub fn get_max_linesearch_iterations(&self) -> Result<usize, DiffsolRtError> {
33        Ok(self.guard()?.solve.ic_max_linesearch_iterations())
34    }
35    pub fn set_max_linesearch_iterations(&self, value: usize) -> Result<(), DiffsolRtError> {
36        self.guard()?.solve.set_ic_max_linesearch_iterations(value);
37        Ok(())
38    }
39    pub fn get_max_newton_iterations(&self) -> Result<usize, DiffsolRtError> {
40        Ok(self.guard()?.solve.ic_max_newton_iterations())
41    }
42    pub fn set_max_newton_iterations(&self, value: usize) -> Result<(), DiffsolRtError> {
43        self.guard()?.solve.set_ic_max_newton_iterations(value);
44        Ok(())
45    }
46    pub fn get_max_linear_solver_setups(&self) -> Result<usize, DiffsolRtError> {
47        Ok(self.guard()?.solve.ic_max_linear_solver_setups())
48    }
49    pub fn set_max_linear_solver_setups(&self, value: usize) -> Result<(), DiffsolRtError> {
50        self.guard()?.solve.set_ic_max_linear_solver_setups(value);
51        Ok(())
52    }
53    pub fn get_step_reduction_factor(&self) -> Result<f64, DiffsolRtError> {
54        Ok(self.guard()?.solve.ic_step_reduction_factor())
55    }
56    pub fn set_step_reduction_factor(&self, value: f64) -> Result<(), DiffsolRtError> {
57        self.guard()?.solve.set_ic_step_reduction_factor(value);
58        Ok(())
59    }
60    pub fn get_armijo_constant(&self) -> Result<f64, DiffsolRtError> {
61        Ok(self.guard()?.solve.ic_armijo_constant())
62    }
63    pub fn set_armijo_constant(&self, value: f64) -> Result<(), DiffsolRtError> {
64        self.guard()?.solve.set_ic_armijo_constant(value);
65        Ok(())
66    }
67}
68
69impl Serialize for InitialConditionSolverOptions {
70    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
71    where
72        S: Serializer,
73    {
74        let mut state = serializer.serialize_struct("InitialConditionSolverOptions", 6)?;
75        state.serialize_field(
76            "use_linesearch",
77            &self
78                .get_use_linesearch()
79                .map_err(serde::ser::Error::custom)?,
80        )?;
81        state.serialize_field(
82            "max_linesearch_iterations",
83            &self
84                .get_max_linesearch_iterations()
85                .map_err(serde::ser::Error::custom)?,
86        )?;
87        state.serialize_field(
88            "max_newton_iterations",
89            &self
90                .get_max_newton_iterations()
91                .map_err(serde::ser::Error::custom)?,
92        )?;
93        state.serialize_field(
94            "max_linear_solver_setups",
95            &self
96                .get_max_linear_solver_setups()
97                .map_err(serde::ser::Error::custom)?,
98        )?;
99        state.serialize_field(
100            "step_reduction_factor",
101            &self
102                .get_step_reduction_factor()
103                .map_err(serde::ser::Error::custom)?,
104        )?;
105        state.serialize_field(
106            "armijo_constant",
107            &self
108                .get_armijo_constant()
109                .map_err(serde::ser::Error::custom)?,
110        )?;
111        state.end()
112    }
113}
114
#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
mod tests {
    use crate::{
        jit::JitBackendType,
        linear_solver_type::LinearSolverType,
        matrix_type::MatrixType,
        ode::OdeWrapper,
        ode_solver_type::OdeSolverType,
        scalar_type::ScalarType,
        test_support::{available_jit_backends, logistic_diffsl_code},
    };

    use super::InitialConditionSolverOptions;

    /// Builds an options handle for the logistic test model, compiled with `jit_backend`
    /// and solved with BDF on dense nalgebra matrices.
    fn make_options(jit_backend: JitBackendType) -> InitialConditionSolverOptions {
        OdeWrapper::new_jit(
            logistic_diffsl_code(),
            jit_backend,
            ScalarType::F64,
            MatrixType::NalgebraDense,
            LinearSolverType::Default,
            OdeSolverType::Bdf,
        )
        .unwrap()
        .get_ic_options()
    }

    #[test]
    fn initial_condition_options_roundtrip_and_serialize() {
        for jit_backend in available_jit_backends() {
            let options = make_options(jit_backend);

            // Drive the flag both ways so the setter is exercised whatever its default is.
            options.set_use_linesearch(false).unwrap();
            assert!(!options.get_use_linesearch().unwrap());
            options.set_use_linesearch(true).unwrap();

            options.set_max_linesearch_iterations(13).unwrap();
            options.set_max_newton_iterations(17).unwrap();
            options.set_max_linear_solver_setups(19).unwrap();
            options.set_step_reduction_factor(0.5).unwrap();
            options.set_armijo_constant(1e-4).unwrap();

            assert!(options.get_use_linesearch().unwrap());
            assert_eq!(options.get_max_linesearch_iterations().unwrap(), 13);
            assert_eq!(options.get_max_newton_iterations().unwrap(), 17);
            assert_eq!(options.get_max_linear_solver_setups().unwrap(), 19);
            assert_eq!(options.get_step_reduction_factor().unwrap(), 0.5);
            assert_eq!(options.get_armijo_constant().unwrap(), 1e-4);

            let value = serde_json::to_value(&options).unwrap();
            // Pin the exact field set: it must match the length hint given to
            // `serialize_struct` and must not grow stray keys.
            let object = value
                .as_object()
                .expect("options should serialize to a JSON object");
            assert_eq!(object.len(), 6);
            assert_eq!(value["use_linesearch"], true);
            assert_eq!(value["max_linesearch_iterations"], 13);
            assert_eq!(value["max_newton_iterations"], 17);
            assert_eq!(value["max_linear_solver_setups"], 19);
            assert_eq!(value["step_reduction_factor"], 0.5);
            assert_eq!(value["armijo_constant"], 1e-4);
        }
    }

    #[test]
    fn initial_condition_options_clones_share_state() {
        for jit_backend in available_jit_backends() {
            let options = make_options(jit_backend);
            let alias = options.clone();

            // Both handles wrap the same `Arc<Mutex<Ode>>`, so writes through one must be
            // visible through the other.
            alias.set_max_newton_iterations(23).unwrap();
            assert_eq!(options.get_max_newton_iterations().unwrap(), 23);

            options.set_armijo_constant(2e-4).unwrap();
            assert_eq!(alias.get_armijo_constant().unwrap(), 2e-4);
        }
    }
}