differential_equations/
solution.rs

1//! Solution of differential equations
2
3use crate::{
4    Status,
5    traits::{CallBackData, Real, State},
6};
7use std::time::Instant;
8
9/// Timer for tracking solution time
10#[derive(Debug, Clone)]
11pub enum Timer<T: Real> {
12    Off,
13    Running(Instant),
14    Completed(T),
15}
16
17impl<T: Real> Timer<T> {
18    /// Starts the timer
19    pub fn start(&mut self) {
20        *self = Timer::Running(Instant::now());
21    }
22
23    /// Returns the elapsed time in seconds
24    pub fn elapsed(&self) -> T {
25        match self {
26            Timer::Off => T::zero(),
27            Timer::Running(start_time) => T::from_f64(start_time.elapsed().as_secs_f64()).unwrap(),
28            Timer::Completed(t) => *t,
29        }
30    }
31
32    /// Complete the running timer and convert it to a completed state
33    pub fn complete(&mut self) {
34        match self {
35            Timer::Off => {}
36            Timer::Running(start_time) => {
37                *self = Timer::Completed(T::from_f64(start_time.elapsed().as_secs_f64()).unwrap());
38            }
39            Timer::Completed(_) => {}
40        }
41    }
42}
43
44#[cfg(feature = "polars")]
45use polars::prelude::*;
46
47/// Solution of returned by differential equation solvers
48///
49/// # Fields
50/// * `y`              - Outputted dependent variable points.
51/// * `t`              - Outputted independent variable points.
52/// * `status`         - Status of the solver.
53/// * `evals`          - Number of function evaluations.
54/// * `steps`          - Number of steps.
55/// * `rejected_steps` - Number of rejected steps.
56/// * `accepted_steps` - Number of accepted steps.
57/// * `timer`          - Timer for tracking solution time.
58///
59#[derive(Debug, Clone)]
60pub struct Solution<T, V, D>
61where
62    T: Real,
63    V: State<T>,
64    D: CallBackData,
65{
66    /// Outputted independent variable points.
67    pub t: Vec<T>,
68
69    /// Outputted dependent variable points.
70    pub y: Vec<V>,
71
72    /// Status of the solver.
73    pub status: Status<T, V, D>,
74
75    /// Number of function evaluations.
76    pub evals: usize,
77
78    /// Total number of steps taken by the solver.
79    pub steps: usize,
80
81    /// Number of rejected steps where the solution step-size had to be reduced.
82    pub rejected_steps: usize,
83
84    /// Number of accepted steps where the solution moved closer to tf.
85    pub accepted_steps: usize,
86
87    /// Timer for tracking solution time - Running during solving, Completed after finalization
88    pub timer: Timer<T>,
89}
90
91// Initial methods for the solution
92impl<T, V, D> Default for Solution<T, V, D>
93where
94    T: Real,
95    V: State<T>,
96    D: CallBackData,
97{
98    fn default() -> Self {
99        Self::new()
100    }
101}
102
103impl<T, V, D> Solution<T, V, D>
104where
105    T: Real,
106    V: State<T>,
107    D: CallBackData,
108{
109    /// Creates a new Solution object.
110    pub fn new() -> Self {
111        Solution {
112            t: Vec::with_capacity(100),
113            y: Vec::with_capacity(100),
114            status: Status::Uninitialized,
115            evals: 0,
116            steps: 0,
117            rejected_steps: 0,
118            accepted_steps: 0,
119            timer: Timer::Off,
120        }
121    }
122}
123
124// Methods used during solving
125impl<T, V, D> Solution<T, V, D>
126where
127    T: Real,
128    V: State<T>,
129    D: CallBackData,
130{
131    /// Puhes a new point to the solution, e.g. t and y vecs.
132    ///
133    /// # Arguments
134    /// * `t` - The time point.
135    /// * `y` - The state vector.
136    ///
137    pub fn push(&mut self, t: T, y: V) {
138        self.t.push(t);
139        self.y.push(y);
140    }
141
142    /// Pops the last point from the solution, e.g. t and y vecs.
143    ///
144    /// # Returns
145    /// * `Option<(T, SMatrix<T, R, C>)>` - The last point in the solution.
146    ///
147    pub fn pop(&mut self) -> Option<(T, V)> {
148        if self.t.is_empty() || self.y.is_empty() {
149            return None;
150        }
151        let t = self.t.pop().unwrap();
152        let y = self.y.pop().unwrap();
153        Some((t, y))
154    }
155
156    /// Truncates the solution's (t, y) points to the given index.
157    ///
158    /// # Arguments
159    /// * `index` - The index to truncate to.
160    ///
161    pub fn truncate(&mut self, index: usize) {
162        self.t.truncate(index);
163        self.y.truncate(index);
164    }
165}
166
167// Post-processing methods for the solution
168impl<T, V, D> Solution<T, V, D>
169where
170    T: Real,
171    V: State<T>,
172    D: CallBackData,
173{
174    /// Simplifies the Solution into a tuple of vectors in form (t, y).
175    /// By doing so, the Solution will be consumed and the status,
176    /// evals, steps, rejected_steps, and accepted_steps will be discarded.
177    ///
178    /// # Returns
179    /// * `(Vec<T>, Vec<V)` - Tuple of time and state vectors.
180    ///
181    pub fn into_tuple(self) -> (Vec<T>, Vec<V>) {
182        (self.t, self.y)
183    }
184
185    /// Returns the last accepted step of the solution in form (t, y).
186    ///
187    /// # Returns
188    /// * `Result<(T, V), Box<dyn std::error::Error>>` - Result of time and state vector.
189    ///
190    pub fn last(&self) -> Result<(&T, &V), Box<dyn std::error::Error>> {
191        let t = self.t.last().ok_or("No t steps available")?;
192        let y = self.y.last().ok_or("No y vectors available")?;
193        Ok((t, y))
194    }
195
196    /// Returns an iterator over the solution.
197    ///
198    /// # Returns
199    /// * `std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, V>>` - An iterator
200    ///   yielding (t, y) tuples.
201    ///
202    pub fn iter(&self) -> std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, V>> {
203        self.t.iter().zip(self.y.iter())
204    }
205
206    /// Creates a CSV file of the solution using standard library functionality.
207    ///
208    /// Note the columns will be named t, y0, y1, ..., yN.
209    ///
210    /// # Arguments
211    /// * `filename` - Name of the file to save the solution.
212    ///
213    /// # Returns
214    /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
215    ///
216    #[cfg(not(feature = "polars"))]
217    pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
218        use std::io::{BufWriter, Write};
219
220        // Create file and path if it does not exist
221        let path = std::path::Path::new(filename);
222        if let Some(parent) = path.parent() {
223            if !parent.exists() {
224                std::fs::create_dir_all(parent)?;
225            }
226        }
227        let file = std::fs::File::create(filename)?;
228        let mut writer = BufWriter::new(file);
229
230        // Length of state vector
231        let n = self.y[0].len();
232
233        // Write header
234        let mut header = String::from("t");
235        for i in 0..n {
236            header.push_str(&format!(",y{}", i));
237        }
238        writeln!(writer, "{}", header)?;
239
240        // Write data
241        for (t, y) in self.iter() {
242            let mut row = format!("{:?}", t);
243            for i in 0..n {
244                row.push_str(&format!(",{:?}", y.get(i)));
245            }
246            writeln!(writer, "{}", row)?;
247        }
248
249        // Ensure all data is flushed to disk
250        writer.flush()?;
251
252        Ok(())
253    }
254
255    /// Creates a csv file of the solution using Polars DataFrame.
256    ///
257    /// Note the columns will be named t, y0, y1, ..., yN.
258    ///
259    /// # Arguments
260    /// * `filename` - Name of the file to save the solution.
261    ///
262    /// # Returns
263    /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
264    ///
265    #[cfg(feature = "polars")]
266    pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
267        // Create file and path if it does not exist
268        let path = std::path::Path::new(filename);
269        if !path.exists() {
270            std::fs::create_dir_all(path.parent().unwrap())?;
271        }
272        let mut file = std::fs::File::create(filename)?;
273
274        let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
275        let mut columns = vec![Column::new("t".into(), t)];
276        let n = self.y[0].len();
277        for i in 0..n {
278            let header = format!("y{}", i);
279            columns.push(Column::new(
280                header.into(),
281                self.y
282                    .iter()
283                    .map(|x| x.get(i).to_f64())
284                    .collect::<Vec<f64>>(),
285            ));
286        }
287        let mut df = DataFrame::new(columns)?;
288
289        // Write the dataframe to a csv file
290        CsvWriter::new(&mut file).finish(&mut df)?;
291
292        Ok(())
293    }
294
295    /// Creates a Polars DataFrame of the solution.
296    ///
297    /// Note that the columns will be named t, y0, y1, ..., yN.
298    ///
299    /// # Returns
300    /// * `Result<DataFrame, PolarsError>` - Result of creating the DataFrame.
301    ///
302    #[cfg(feature = "polars")]
303    pub fn to_polars(&self) -> Result<DataFrame, PolarsError> {
304        let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
305        let mut columns = vec![Column::new("t".into(), t)];
306        let n = self.y[0].len();
307        for i in 0..n {
308            let header = format!("y{}", i);
309            columns.push(Column::new(
310                header.into(),
311                self.y
312                    .iter()
313                    .map(|x| x.get(i).to_f64())
314                    .collect::<Vec<f64>>(),
315            ));
316        }
317
318        DataFrame::new(columns)
319    }
320}