diffsol_c/
initial_condition_options.rs1use std::sync::{Arc, Mutex};
2
3use serde::{Deserialize, Serialize, Serializer};
4
5use crate::{error::DiffsolRtError, ode::Ode, solve::Solve};
6
7#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
8pub struct InitialConditionSolverOptionsSnapshot {
9 pub use_linesearch: bool,
10 pub max_linesearch_iterations: usize,
11 pub max_newton_iterations: usize,
12 pub max_linear_solver_setups: usize,
13 pub step_reduction_factor: f64,
14 pub armijo_constant: f64,
15}
16
17impl InitialConditionSolverOptionsSnapshot {
18 pub(crate) fn from_solve(solve: &dyn Solve) -> Self {
19 Self {
20 use_linesearch: solve.ic_use_linesearch(),
21 max_linesearch_iterations: solve.ic_max_linesearch_iterations(),
22 max_newton_iterations: solve.ic_max_newton_iterations(),
23 max_linear_solver_setups: solve.ic_max_linear_solver_setups(),
24 step_reduction_factor: solve.ic_step_reduction_factor(),
25 armijo_constant: solve.ic_armijo_constant(),
26 }
27 }
28
29 pub(crate) fn apply_to_solve(&self, solve: &mut dyn Solve) {
30 solve.set_ic_use_linesearch(self.use_linesearch);
31 solve.set_ic_max_linesearch_iterations(self.max_linesearch_iterations);
32 solve.set_ic_max_newton_iterations(self.max_newton_iterations);
33 solve.set_ic_max_linear_solver_setups(self.max_linear_solver_setups);
34 solve.set_ic_step_reduction_factor(self.step_reduction_factor);
35 solve.set_ic_armijo_constant(self.armijo_constant);
36 }
37}
38
39#[derive(Clone)]
40pub struct InitialConditionSolverOptions {
41 ode: Arc<Mutex<Ode>>,
42}
43impl InitialConditionSolverOptions {
44 pub(crate) fn new(ode: Arc<Mutex<Ode>>) -> Self {
45 Self { ode }
46 }
47 fn guard(&self) -> Result<std::sync::MutexGuard<'_, Ode>, DiffsolRtError> {
48 self.ode.lock().map_err(|_| {
49 DiffsolRtError::from(diffsol::error::DiffsolError::Other(
50 "Failed to acquire lock on ODE solver".to_string(),
51 ))
52 })
53 }
54}
55
56impl InitialConditionSolverOptions {
57 pub fn get_use_linesearch(&self) -> Result<bool, DiffsolRtError> {
58 Ok(self.guard()?.solve.ic_use_linesearch())
59 }
60 pub fn set_use_linesearch(&self, value: bool) -> Result<(), DiffsolRtError> {
61 self.guard()?.solve.set_ic_use_linesearch(value);
62 Ok(())
63 }
64 pub fn get_max_linesearch_iterations(&self) -> Result<usize, DiffsolRtError> {
65 Ok(self.guard()?.solve.ic_max_linesearch_iterations())
66 }
67 pub fn set_max_linesearch_iterations(&self, value: usize) -> Result<(), DiffsolRtError> {
68 self.guard()?.solve.set_ic_max_linesearch_iterations(value);
69 Ok(())
70 }
71 pub fn get_max_newton_iterations(&self) -> Result<usize, DiffsolRtError> {
72 Ok(self.guard()?.solve.ic_max_newton_iterations())
73 }
74 pub fn set_max_newton_iterations(&self, value: usize) -> Result<(), DiffsolRtError> {
75 self.guard()?.solve.set_ic_max_newton_iterations(value);
76 Ok(())
77 }
78 pub fn get_max_linear_solver_setups(&self) -> Result<usize, DiffsolRtError> {
79 Ok(self.guard()?.solve.ic_max_linear_solver_setups())
80 }
81 pub fn set_max_linear_solver_setups(&self, value: usize) -> Result<(), DiffsolRtError> {
82 self.guard()?.solve.set_ic_max_linear_solver_setups(value);
83 Ok(())
84 }
85 pub fn get_step_reduction_factor(&self) -> Result<f64, DiffsolRtError> {
86 Ok(self.guard()?.solve.ic_step_reduction_factor())
87 }
88 pub fn set_step_reduction_factor(&self, value: f64) -> Result<(), DiffsolRtError> {
89 self.guard()?.solve.set_ic_step_reduction_factor(value);
90 Ok(())
91 }
92 pub fn get_armijo_constant(&self) -> Result<f64, DiffsolRtError> {
93 Ok(self.guard()?.solve.ic_armijo_constant())
94 }
95 pub fn set_armijo_constant(&self, value: f64) -> Result<(), DiffsolRtError> {
96 self.guard()?.solve.set_ic_armijo_constant(value);
97 Ok(())
98 }
99}
100
101impl Serialize for InitialConditionSolverOptions {
102 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
103 where
104 S: Serializer,
105 {
106 let guard = self.guard().map_err(serde::ser::Error::custom)?;
107 InitialConditionSolverOptionsSnapshot::from_solve(guard.solve.as_ref())
108 .serialize(serializer)
109 }
110}
111
#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
mod tests {
    use crate::{
        jit::JitBackendType,
        linear_solver_type::LinearSolverType,
        matrix_type::MatrixType,
        ode::OdeWrapper,
        ode_solver_type::OdeSolverType,
        scalar_type::ScalarType,
        test_support::{available_jit_backends, logistic_diffsl_code},
    };

    use super::{InitialConditionSolverOptions, InitialConditionSolverOptionsSnapshot};

    /// Builds IC options backed by a fresh logistic-model ODE for `jit_backend`.
    fn make_options(jit_backend: JitBackendType) -> InitialConditionSolverOptions {
        OdeWrapper::new_jit(
            logistic_diffsl_code(),
            jit_backend,
            ScalarType::F64,
            MatrixType::NalgebraDense,
            LinearSolverType::Default,
            OdeSolverType::Bdf,
        )
        .unwrap()
        .get_ic_options()
    }

    /// Settings used by the tests. The values are deliberately uncommon so that a setter
    /// (or `apply_to_solve`) that silently does nothing cannot pass by coincidentally
    /// matching the solver's defaults.
    fn expected_snapshot() -> InitialConditionSolverOptionsSnapshot {
        InitialConditionSolverOptionsSnapshot {
            use_linesearch: true,
            max_linesearch_iterations: 13,
            max_newton_iterations: 17,
            max_linear_solver_setups: 19,
            step_reduction_factor: 0.25,
            armijo_constant: 2e-4,
        }
    }

    #[test]
    fn initial_condition_options_roundtrip_and_serialize() {
        let expected = expected_snapshot();
        for jit_backend in available_jit_backends() {
            let options = make_options(jit_backend);

            // Toggle the flag both ways so the setter is exercised whatever the default is.
            options.set_use_linesearch(!expected.use_linesearch).unwrap();
            assert_eq!(
                options.get_use_linesearch().unwrap(),
                !expected.use_linesearch
            );
            options.set_use_linesearch(expected.use_linesearch).unwrap();

            options
                .set_max_linesearch_iterations(expected.max_linesearch_iterations)
                .unwrap();
            options
                .set_max_newton_iterations(expected.max_newton_iterations)
                .unwrap();
            options
                .set_max_linear_solver_setups(expected.max_linear_solver_setups)
                .unwrap();
            options
                .set_step_reduction_factor(expected.step_reduction_factor)
                .unwrap();
            options.set_armijo_constant(expected.armijo_constant).unwrap();

            assert_eq!(options.get_use_linesearch().unwrap(), expected.use_linesearch);
            assert_eq!(options.get_max_linesearch_iterations().unwrap(), 13);
            assert_eq!(options.get_max_newton_iterations().unwrap(), 17);
            assert_eq!(options.get_max_linear_solver_setups().unwrap(), 19);
            assert_eq!(options.get_step_reduction_factor().unwrap(), 0.25);
            assert_eq!(options.get_armijo_constant().unwrap(), 2e-4);

            let value = serde_json::to_value(&options).unwrap();
            assert_eq!(value["use_linesearch"], true);
            assert_eq!(value["max_linesearch_iterations"], 13);
            assert_eq!(value["max_newton_iterations"], 17);
            assert_eq!(value["max_linear_solver_setups"], 19);
            assert_eq!(value["step_reduction_factor"], 0.25);
            assert_eq!(value["armijo_constant"], 2e-4);

            // The serialized form must deserialize back into an identical snapshot.
            let decoded: InitialConditionSolverOptionsSnapshot =
                serde_json::from_value(value).unwrap();
            assert_eq!(decoded, expected);
        }
    }

    #[test]
    fn snapshot_apply_to_solve_transfers_all_settings() {
        let expected = expected_snapshot();
        for jit_backend in available_jit_backends() {
            let target = make_options(jit_backend);

            // Scope the guard so the lock is released before serialization re-locks it.
            {
                let mut ode = target.guard().unwrap();
                expected.apply_to_solve(ode.solve.as_mut());
            }

            let observed: InitialConditionSolverOptionsSnapshot =
                serde_json::from_value(serde_json::to_value(&target).unwrap()).unwrap();
            assert_eq!(observed, expected);
        }
    }
}