differential_equations/solution.rs
1//! Solution of differential equations
2
3use crate::{
4 Status,
5 traits::{CallBackData, Real, State},
6};
7use std::time::Instant;
8
9/// Timer for tracking solution time
10#[derive(Debug, Clone)]
11pub enum Timer<T: Real> {
12 Off,
13 Running(Instant),
14 Completed(T),
15}
16
17impl<T: Real> Timer<T> {
18 /// Starts the timer
19 pub fn start(&mut self) {
20 *self = Timer::Running(Instant::now());
21 }
22
23 /// Returns the elapsed time in seconds
24 pub fn elapsed(&self) -> T {
25 match self {
26 Timer::Off => T::zero(),
27 Timer::Running(start_time) => T::from_f64(start_time.elapsed().as_secs_f64()).unwrap(),
28 Timer::Completed(t) => *t,
29 }
30 }
31
32 /// Complete the running timer and convert it to a completed state
33 pub fn complete(&mut self) {
34 match self {
35 Timer::Off => {}
36 Timer::Running(start_time) => {
37 *self = Timer::Completed(T::from_f64(start_time.elapsed().as_secs_f64()).unwrap());
38 }
39 Timer::Completed(_) => {}
40 }
41 }
42}
43
44#[cfg(feature = "polars")]
45use polars::prelude::*;
46
47/// Solution of returned by differential equation solvers
48///
49/// # Fields
50/// * `y` - Outputted dependent variable points.
51/// * `t` - Outputted independent variable points.
52/// * `status` - Status of the solver.
53/// * `evals` - Number of function evaluations.
54/// * `steps` - Number of steps.
55/// * `rejected_steps` - Number of rejected steps.
56/// * `accepted_steps` - Number of accepted steps.
57/// * `timer` - Timer for tracking solution time.
58///
59#[derive(Debug, Clone)]
60pub struct Solution<T, V, D>
61where
62 T: Real,
63 V: State<T>,
64 D: CallBackData,
65{
66 /// Outputted independent variable points.
67 pub t: Vec<T>,
68
69 /// Outputted dependent variable points.
70 pub y: Vec<V>,
71
72 /// Status of the solver.
73 pub status: Status<T, V, D>,
74
75 /// Number of function evaluations.
76 pub evals: usize,
77
78 /// Total number of steps taken by the solver.
79 pub steps: usize,
80
81 /// Number of rejected steps where the solution step-size had to be reduced.
82 pub rejected_steps: usize,
83
84 /// Number of accepted steps where the solution moved closer to tf.
85 pub accepted_steps: usize,
86
87 /// Timer for tracking solution time - Running during solving, Completed after finalization
88 pub timer: Timer<T>,
89}
90
91// Initial methods for the solution
92impl<T, V, D> Default for Solution<T, V, D>
93where
94 T: Real,
95 V: State<T>,
96 D: CallBackData,
97{
98 fn default() -> Self {
99 Self::new()
100 }
101}
102
103impl<T, V, D> Solution<T, V, D>
104where
105 T: Real,
106 V: State<T>,
107 D: CallBackData,
108{
109 /// Creates a new Solution object.
110 pub fn new() -> Self {
111 Solution {
112 t: Vec::with_capacity(100),
113 y: Vec::with_capacity(100),
114 status: Status::Uninitialized,
115 evals: 0,
116 steps: 0,
117 rejected_steps: 0,
118 accepted_steps: 0,
119 timer: Timer::Off,
120 }
121 }
122}
123
124// Methods used during solving
125impl<T, V, D> Solution<T, V, D>
126where
127 T: Real,
128 V: State<T>,
129 D: CallBackData,
130{
131 /// Puhes a new point to the solution, e.g. t and y vecs.
132 ///
133 /// # Arguments
134 /// * `t` - The time point.
135 /// * `y` - The state vector.
136 ///
137 pub fn push(&mut self, t: T, y: V) {
138 self.t.push(t);
139 self.y.push(y);
140 }
141
142 /// Pops the last point from the solution, e.g. t and y vecs.
143 ///
144 /// # Returns
145 /// * `Option<(T, SMatrix<T, R, C>)>` - The last point in the solution.
146 ///
147 pub fn pop(&mut self) -> Option<(T, V)> {
148 if self.t.is_empty() || self.y.is_empty() {
149 return None;
150 }
151 let t = self.t.pop().unwrap();
152 let y = self.y.pop().unwrap();
153 Some((t, y))
154 }
155
156 /// Truncates the solution's (t, y) points to the given index.
157 ///
158 /// # Arguments
159 /// * `index` - The index to truncate to.
160 ///
161 pub fn truncate(&mut self, index: usize) {
162 self.t.truncate(index);
163 self.y.truncate(index);
164 }
165}
166
167// Post-processing methods for the solution
168impl<T, V, D> Solution<T, V, D>
169where
170 T: Real,
171 V: State<T>,
172 D: CallBackData,
173{
174 /// Simplifies the Solution into a tuple of vectors in form (t, y).
175 /// By doing so, the Solution will be consumed and the status,
176 /// evals, steps, rejected_steps, and accepted_steps will be discarded.
177 ///
178 /// # Returns
179 /// * `(Vec<T>, Vec<V)` - Tuple of time and state vectors.
180 ///
181 pub fn into_tuple(self) -> (Vec<T>, Vec<V>) {
182 (self.t, self.y)
183 }
184
185 /// Returns the last accepted step of the solution in form (t, y).
186 ///
187 /// # Returns
188 /// * `Result<(T, V), Box<dyn std::error::Error>>` - Result of time and state vector.
189 ///
190 pub fn last(&self) -> Result<(&T, &V), Box<dyn std::error::Error>> {
191 let t = self.t.last().ok_or("No t steps available")?;
192 let y = self.y.last().ok_or("No y vectors available")?;
193 Ok((t, y))
194 }
195
196 /// Returns an iterator over the solution.
197 ///
198 /// # Returns
199 /// * `std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, V>>` - An iterator
200 /// yielding (t, y) tuples.
201 ///
202 pub fn iter(&self) -> std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, V>> {
203 self.t.iter().zip(self.y.iter())
204 }
205
206 /// Creates a CSV file of the solution using standard library functionality.
207 ///
208 /// Note the columns will be named t, y0, y1, ..., yN.
209 ///
210 /// # Arguments
211 /// * `filename` - Name of the file to save the solution.
212 ///
213 /// # Returns
214 /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
215 ///
216 #[cfg(not(feature = "polars"))]
217 pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
218 use std::io::{BufWriter, Write};
219
220 // Create file and path if it does not exist
221 let path = std::path::Path::new(filename);
222 if let Some(parent) = path.parent() {
223 if !parent.exists() {
224 std::fs::create_dir_all(parent)?;
225 }
226 }
227 let file = std::fs::File::create(filename)?;
228 let mut writer = BufWriter::new(file);
229
230 // Length of state vector
231 let n = self.y[0].len();
232
233 // Write header
234 let mut header = String::from("t");
235 for i in 0..n {
236 header.push_str(&format!(",y{}", i));
237 }
238 writeln!(writer, "{}", header)?;
239
240 // Write data
241 for (t, y) in self.iter() {
242 let mut row = format!("{:?}", t);
243 for i in 0..n {
244 row.push_str(&format!(",{:?}", y.get(i)));
245 }
246 writeln!(writer, "{}", row)?;
247 }
248
249 // Ensure all data is flushed to disk
250 writer.flush()?;
251
252 Ok(())
253 }
254
255 /// Creates a csv file of the solution using Polars DataFrame.
256 ///
257 /// Note the columns will be named t, y0, y1, ..., yN.
258 ///
259 /// # Arguments
260 /// * `filename` - Name of the file to save the solution.
261 ///
262 /// # Returns
263 /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
264 ///
265 #[cfg(feature = "polars")]
266 pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
267 // Create file and path if it does not exist
268 let path = std::path::Path::new(filename);
269 if !path.exists() {
270 std::fs::create_dir_all(path.parent().unwrap())?;
271 }
272 let mut file = std::fs::File::create(filename)?;
273
274 let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
275 let mut columns = vec![Column::new("t".into(), t)];
276 let n = self.y[0].len();
277 for i in 0..n {
278 let header = format!("y{}", i);
279 columns.push(Column::new(
280 header.into(),
281 self.y
282 .iter()
283 .map(|x| x.get(i).to_f64())
284 .collect::<Vec<f64>>(),
285 ));
286 }
287 let mut df = DataFrame::new(columns)?;
288
289 // Write the dataframe to a csv file
290 CsvWriter::new(&mut file).finish(&mut df)?;
291
292 Ok(())
293 }
294
295 /// Creates a Polars DataFrame of the solution.
296 ///
297 /// Note that the columns will be named t, y0, y1, ..., yN.
298 ///
299 /// # Returns
300 /// * `Result<DataFrame, PolarsError>` - Result of creating the DataFrame.
301 ///
302 #[cfg(feature = "polars")]
303 pub fn to_polars(&self) -> Result<DataFrame, PolarsError> {
304 let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
305 let mut columns = vec![Column::new("t".into(), t)];
306 let n = self.y[0].len();
307 for i in 0..n {
308 let header = format!("y{}", i);
309 columns.push(Column::new(
310 header.into(),
311 self.y
312 .iter()
313 .map(|x| x.get(i).to_f64())
314 .collect::<Vec<f64>>(),
315 ));
316 }
317
318 DataFrame::new(columns)
319 }
320}