1use nalgebra::{
8 allocator::Allocator, dimension::DimMin, ComplexField, DefaultAllocator, DimName, VectorN, U1,
9 U6, U7,
10};
11use num_traits::Zero;
12
13mod adams;
14mod bdf;
15mod rk;
16pub use adams::*;
17pub use bdf::*;
18pub use rk::*;
19
20pub enum IVPStatus<N: ComplexField, S: DimName>
22where
23 DefaultAllocator: Allocator<N, S>,
24{
25 Redo,
26 Ok(Vec<(N::RealField, VectorN<N, S>)>),
27 Done,
28}
29
30type Path<Complex, Real, S> = Result<Vec<(Real, VectorN<Complex, S>)>, String>;
31
32pub trait IVPSolver<N: ComplexField, S: DimName>: Sized
35where
36 DefaultAllocator: Allocator<N, S>,
37{
38 fn step<T: Clone, F: FnMut(N::RealField, &[N], &mut T) -> Result<VectorN<N, S>, String>>(
41 &mut self,
42 f: F,
43 params: &mut T,
44 ) -> Result<IVPStatus<N, S>, String>;
45 fn with_tolerance(self, tol: N::RealField) -> Result<Self, String>;
47 fn with_dt_max(self, max: N::RealField) -> Result<Self, String>;
49 fn with_dt_min(self, min: N::RealField) -> Result<Self, String>;
51 fn with_start(self, t_initial: N::RealField) -> Result<Self, String>;
53 fn with_end(self, t_final: N::RealField) -> Result<Self, String>;
55 fn with_initial_conditions(self, start: &[N]) -> Result<Self, String>;
57 fn build(self) -> Self;
59
60 fn get_initial_conditions(&self) -> Option<VectorN<N, S>>;
63 fn get_time(&self) -> Option<N::RealField>;
65 fn check_start(&self) -> Result<(), String>;
68
69 fn solve_ivp<
71 T: Clone,
72 F: FnMut(N::RealField, &[N], &mut T) -> Result<VectorN<N, S>, String>,
73 >(
74 mut self,
75 mut f: F,
76 params: &mut T,
77 ) -> Path<N, N::RealField, S> {
78 self.check_start()?;
79 let mut path = vec![];
80 let init_conditions = self.get_initial_conditions();
81 let time = self.get_time();
82 path.push((time.unwrap(), init_conditions.unwrap()));
83
84 'out: loop {
85 let step = self.step(&mut f, params)?;
86 match step {
87 IVPStatus::Done => break 'out,
88 IVPStatus::Redo => {}
89 IVPStatus::Ok(mut state) => path.append(&mut state),
90 }
91 }
92
93 Ok(path)
94 }
95}
96
97#[derive(Debug, Clone, Default)]
125#[cfg_attr(serialize, derive(Serialize, Deserialize))]
126pub struct Euler<N: ComplexField, S: DimName>
127where
128 DefaultAllocator: Allocator<N, S>,
129{
130 dt: Option<N::RealField>,
131 time: Option<N::RealField>,
132 end: Option<N::RealField>,
133 state: Option<VectorN<N, S>>,
134}
135
136impl<N: ComplexField, S: DimName> Euler<N, S>
137where
138 DefaultAllocator: Allocator<N, S>,
139{
140 pub fn new() -> Self {
141 Euler {
142 dt: None,
143 time: None,
144 end: None,
145 state: None,
146 }
147 }
148}
149
150impl<N: ComplexField, S: DimName> IVPSolver<N, S> for Euler<N, S>
151where
152 DefaultAllocator: Allocator<N, S>,
153{
154 fn step<T: Clone, F: FnMut(N::RealField, &[N], &mut T) -> Result<VectorN<N, S>, String>>(
155 &mut self,
156 mut f: F,
157 params: &mut T,
158 ) -> Result<IVPStatus<N, S>, String> {
159 if self.time >= self.end {
160 return Ok(IVPStatus::Done);
161 }
162 if self.time.unwrap() + self.dt.unwrap() >= self.end.unwrap() {
163 self.dt = Some(self.end.unwrap() - self.time.unwrap());
164 }
165
166 let deriv = f(
167 self.time.unwrap(),
168 self.state.as_ref().unwrap().as_slice(),
169 params,
170 )?;
171
172 *self.state.get_or_insert(VectorN::from_iterator(
173 [N::zero()].repeat(self.state.as_ref().unwrap().as_slice().len()),
174 )) += deriv * N::from_real(self.dt.unwrap());
175 *self.time.get_or_insert(N::RealField::zero()) += self.dt.unwrap();
176 Ok(IVPStatus::Ok(vec![(
177 self.time.unwrap(),
178 self.state.clone().unwrap(),
179 )]))
180 }
181
182 fn with_tolerance(self, _tol: N::RealField) -> Result<Self, String> {
183 Ok(self)
184 }
185
186 fn with_dt_max(mut self, max: N::RealField) -> Result<Self, String> {
187 self.dt = Some(max);
188 Ok(self)
189 }
190
191 fn with_dt_min(self, _min: N::RealField) -> Result<Self, String> {
192 Ok(self)
193 }
194
195 fn with_start(mut self, t_initial: N::RealField) -> Result<Self, String> {
196 if let Some(end) = self.end {
197 if end <= t_initial {
198 return Err("Euler with_end: Start must be after end".to_owned());
199 }
200 }
201 self.time = Some(t_initial);
202 Ok(self)
203 }
204
205 fn with_end(mut self, t_final: N::RealField) -> Result<Self, String> {
206 if let Some(start) = self.time {
207 if start >= t_final {
208 return Err("Euler with_end: Start must be after end".to_owned());
209 }
210 }
211 self.end = Some(t_final);
212 Ok(self)
213 }
214
215 fn with_initial_conditions(mut self, start: &[N]) -> Result<Self, String> {
216 self.state = Some(VectorN::from_column_slice(start));
217 Ok(self)
218 }
219
220 fn build(self) -> Self {
221 self
222 }
223
224 fn get_initial_conditions(&self) -> Option<VectorN<N, S>> {
225 if let Some(state) = &self.state {
226 Some(state.clone())
227 } else {
228 None
229 }
230 }
231
232 fn get_time(&self) -> Option<N::RealField> {
233 self.time
234 }
235
236 fn check_start(&self) -> Result<(), String> {
237 if self.time == None {
238 Err("Euler check_start: No initial time".to_owned())
239 } else if self.end == None {
240 Err("Euler check_start: No end time".to_owned())
241 } else if self.state == None {
242 Err("Euler check_start: No initial conditions".to_owned())
243 } else if self.dt == None {
244 Err("Euler check_start: No dt".to_owned())
245 } else {
246 Ok(())
247 }
248 }
249}
250
251pub fn solve_ivp<N, S, T, F>(
289 (start, end): (N::RealField, N::RealField),
290 (dt_max, dt_min): (N::RealField, N::RealField),
291 y_0: &[N],
292 mut f: F,
293 tol: N::RealField,
294 params: &mut T,
295) -> Path<N, N::RealField, S>
296where
297 N: ComplexField,
298 S: DimName + DimMin<S, Output = S>,
299 T: Clone,
300 F: FnMut(N::RealField, &[N], &mut T) -> Result<VectorN<N, S>, String>,
301 DefaultAllocator: Allocator<N, S>
302 + Allocator<N, U6>
303 + Allocator<N, S, U6>
304 + Allocator<N, U6, U6>
305 + Allocator<N::RealField, U6>
306 + Allocator<N::RealField, U6, U6>
307 + Allocator<N, U7>
308 + Allocator<N, S, S>
309 + Allocator<N, U1, S>
310 + Allocator<(usize, usize), S>,
311{
312 let solver = Adams::new()
313 .with_start(start)?
314 .with_end(end)?
315 .with_dt_max(dt_max)?
316 .with_dt_min(dt_min)?
317 .with_tolerance(tol)?
318 .with_initial_conditions(y_0)?
319 .build();
320
321 let path = solver.solve_ivp(&mut f, &mut params.clone());
322
323 if let Ok(path) = path {
324 return Ok(path);
325 }
326
327 let solver: RK45<N, S> = RK45::new()
328 .with_initial_conditions(y_0)?
329 .with_start(start)?
330 .with_end(end)?
331 .with_dt_max(dt_max)?
332 .with_dt_min(dt_min)?
333 .with_tolerance(tol)?
334 .build();
335
336 let path = solver.solve_ivp(&mut f, &mut params.clone());
337
338 if let Ok(path) = path {
339 return Ok(path);
340 }
341
342 let solver = BDF6::new()
343 .with_start(start)?
344 .with_end(end)?
345 .with_dt_max(dt_max)?
346 .with_dt_min(dt_min)?
347 .with_tolerance(tol)?
348 .with_initial_conditions(y_0)?
349 .build();
350
351 solver.solve_ivp(&mut f, params)
352}