diffsol_c/
initial_condition_options.rs1use std::sync::{Arc, Mutex};
2
3use serde::{Serialize, Serializer, ser::SerializeStruct};
4
5use crate::{error::DiffsolJsError, ode::Ode};
6
7#[derive(Clone)]
8pub struct InitialConditionSolverOptions {
9 ode: Arc<Mutex<Ode>>,
10}
11impl InitialConditionSolverOptions {
12 pub(crate) fn new(ode: Arc<Mutex<Ode>>) -> Self {
13 Self { ode }
14 }
15 fn guard(&self) -> Result<std::sync::MutexGuard<'_, Ode>, DiffsolJsError> {
16 self.ode.lock().map_err(|_| {
17 DiffsolJsError::from(diffsol::error::DiffsolError::Other(
18 "Failed to acquire lock on ODE solver".to_string(),
19 ))
20 })
21 }
22}
23
24impl InitialConditionSolverOptions {
25 pub fn get_use_linesearch(&self) -> Result<bool, DiffsolJsError> {
26 Ok(self.guard()?.solve.ic_use_linesearch())
27 }
28 pub fn set_use_linesearch(&self, value: bool) -> Result<(), DiffsolJsError> {
29 self.guard()?.solve.set_ic_use_linesearch(value);
30 Ok(())
31 }
32 pub fn get_max_linesearch_iterations(&self) -> Result<usize, DiffsolJsError> {
33 Ok(self.guard()?.solve.ic_max_linesearch_iterations())
34 }
35 pub fn set_max_linesearch_iterations(&self, value: usize) -> Result<(), DiffsolJsError> {
36 self.guard()?.solve.set_ic_max_linesearch_iterations(value);
37 Ok(())
38 }
39 pub fn get_max_newton_iterations(&self) -> Result<usize, DiffsolJsError> {
40 Ok(self.guard()?.solve.ic_max_newton_iterations())
41 }
42 pub fn set_max_newton_iterations(&self, value: usize) -> Result<(), DiffsolJsError> {
43 self.guard()?.solve.set_ic_max_newton_iterations(value);
44 Ok(())
45 }
46 pub fn get_max_linear_solver_setups(&self) -> Result<usize, DiffsolJsError> {
47 Ok(self.guard()?.solve.ic_max_linear_solver_setups())
48 }
49 pub fn set_max_linear_solver_setups(&self, value: usize) -> Result<(), DiffsolJsError> {
50 self.guard()?.solve.set_ic_max_linear_solver_setups(value);
51 Ok(())
52 }
53 pub fn get_step_reduction_factor(&self) -> Result<f64, DiffsolJsError> {
54 Ok(self.guard()?.solve.ic_step_reduction_factor())
55 }
56 pub fn set_step_reduction_factor(&self, value: f64) -> Result<(), DiffsolJsError> {
57 self.guard()?.solve.set_ic_step_reduction_factor(value);
58 Ok(())
59 }
60 pub fn get_armijo_constant(&self) -> Result<f64, DiffsolJsError> {
61 Ok(self.guard()?.solve.ic_armijo_constant())
62 }
63 pub fn set_armijo_constant(&self, value: f64) -> Result<(), DiffsolJsError> {
64 self.guard()?.solve.set_ic_armijo_constant(value);
65 Ok(())
66 }
67}
68
69impl Serialize for InitialConditionSolverOptions {
70 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
71 where
72 S: Serializer,
73 {
74 let mut state = serializer.serialize_struct("InitialConditionSolverOptions", 6)?;
75 state.serialize_field(
76 "use_linesearch",
77 &self
78 .get_use_linesearch()
79 .map_err(serde::ser::Error::custom)?,
80 )?;
81 state.serialize_field(
82 "max_linesearch_iterations",
83 &self
84 .get_max_linesearch_iterations()
85 .map_err(serde::ser::Error::custom)?,
86 )?;
87 state.serialize_field(
88 "max_newton_iterations",
89 &self
90 .get_max_newton_iterations()
91 .map_err(serde::ser::Error::custom)?,
92 )?;
93 state.serialize_field(
94 "max_linear_solver_setups",
95 &self
96 .get_max_linear_solver_setups()
97 .map_err(serde::ser::Error::custom)?,
98 )?;
99 state.serialize_field(
100 "step_reduction_factor",
101 &self
102 .get_step_reduction_factor()
103 .map_err(serde::ser::Error::custom)?,
104 )?;
105 state.serialize_field(
106 "armijo_constant",
107 &self
108 .get_armijo_constant()
109 .map_err(serde::ser::Error::custom)?,
110 )?;
111 state.end()
112 }
113}
114
#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
mod tests {
    use super::InitialConditionSolverOptions;
    use crate::{
        jit::JitBackendType,
        linear_solver_type::LinearSolverType,
        matrix_type::MatrixType,
        ode::OdeWrapper,
        ode_solver_type::OdeSolverType,
        scalar_type::ScalarType,
        test_support::{available_jit_backends, logistic_diffsl_code},
    };

    // Values written, then read back through the getters and the JSON output.
    // They are pairwise distinct so a mix-up between fields would be caught.
    const USE_LINESEARCH: bool = true;
    const MAX_LINESEARCH_ITERATIONS: usize = 13;
    const MAX_NEWTON_ITERATIONS: usize = 17;
    const MAX_LINEAR_SOLVER_SETUPS: usize = 19;
    const STEP_REDUCTION_FACTOR: f64 = 0.5;
    const ARMIJO_CONSTANT: f64 = 1e-4;

    /// Builds IC options for the logistic test model (f64 scalars, dense
    /// nalgebra matrices, default linear solver, BDF) on `backend`.
    fn make_options(backend: JitBackendType) -> InitialConditionSolverOptions {
        let wrapper = OdeWrapper::new_jit(
            logistic_diffsl_code(),
            backend,
            ScalarType::F64,
            MatrixType::NalgebraDense,
            LinearSolverType::Default,
            OdeSolverType::Bdf,
        )
        .unwrap();
        wrapper.get_ic_options()
    }

    #[test]
    fn initial_condition_options_roundtrip_and_serialize() {
        for backend in available_jit_backends() {
            let options = make_options(backend);

            // Write every setting through the public setters...
            options.set_use_linesearch(USE_LINESEARCH).unwrap();
            options
                .set_max_linesearch_iterations(MAX_LINESEARCH_ITERATIONS)
                .unwrap();
            options.set_max_newton_iterations(MAX_NEWTON_ITERATIONS).unwrap();
            options
                .set_max_linear_solver_setups(MAX_LINEAR_SOLVER_SETUPS)
                .unwrap();
            options.set_step_reduction_factor(STEP_REDUCTION_FACTOR).unwrap();
            options.set_armijo_constant(ARMIJO_CONSTANT).unwrap();

            // ...read each one back through the getters...
            assert_eq!(options.get_use_linesearch().unwrap(), USE_LINESEARCH);
            assert_eq!(
                options.get_max_linesearch_iterations().unwrap(),
                MAX_LINESEARCH_ITERATIONS
            );
            assert_eq!(
                options.get_max_newton_iterations().unwrap(),
                MAX_NEWTON_ITERATIONS
            );
            assert_eq!(
                options.get_max_linear_solver_setups().unwrap(),
                MAX_LINEAR_SOLVER_SETUPS
            );
            assert_eq!(
                options.get_step_reduction_factor().unwrap(),
                STEP_REDUCTION_FACTOR
            );
            assert_eq!(options.get_armijo_constant().unwrap(), ARMIJO_CONSTANT);

            // ...and confirm the serde representation carries the same values.
            let json = serde_json::to_value(&options).unwrap();
            assert_eq!(json["use_linesearch"], USE_LINESEARCH);
            assert_eq!(json["max_linesearch_iterations"], MAX_LINESEARCH_ITERATIONS);
            assert_eq!(json["max_newton_iterations"], MAX_NEWTON_ITERATIONS);
            assert_eq!(json["max_linear_solver_setups"], MAX_LINEAR_SOLVER_SETUPS);
            assert_eq!(json["step_reduction_factor"], STEP_REDUCTION_FACTOR);
            assert_eq!(json["armijo_constant"], ARMIJO_CONSTANT);
        }
    }
}