differential_equations/
solution.rs

1//! Solution of differential equations
2
3use crate::{
4    Status,
5    traits::{CallBackData, Real, State},
6};
7use std::time::Instant;
8
9/// Timer for tracking solution time
10#[derive(Debug, Clone)]
11pub enum Timer<T: Real> {
12    Off,
13    Running(Instant),
14    Completed(T),
15}
16
17impl<T: Real> Timer<T> {
18    /// Starts the timer
19    pub fn start(&mut self) {
20        *self = Timer::Running(Instant::now());
21    }
22
23    /// Returns the elapsed time in seconds
24    pub fn elapsed(&self) -> T {
25        match self {
26            Timer::Off => T::zero(),
27            Timer::Running(start_time) => T::from_f64(start_time.elapsed().as_secs_f64()).unwrap(),
28            Timer::Completed(t) => *t,
29        }
30    }
31
32    /// Complete the running timer and convert it to a completed state
33    pub fn complete(&mut self) {
34        match self {
35            Timer::Off => {}
36            Timer::Running(start_time) => {
37                *self = Timer::Completed(T::from_f64(start_time.elapsed().as_secs_f64()).unwrap());
38            }
39            Timer::Completed(_) => {}
40        }
41    }
42}
43
44#[cfg(feature = "polars")]
45use polars::prelude::*;
46
/// Solution returned by differential equation solvers.
///
/// # Fields
/// * `t`              - Outputted independent variable points.
/// * `y`              - Outputted dependent variable points.
/// * `status`         - Status of the solver.
/// * `evals`          - Number of function evaluations.
/// * `jac_evals`      - Number of Jacobian evaluations.
/// * `steps`          - Number of steps.
/// * `rejected_steps` - Number of rejected steps.
/// * `accepted_steps` - Number of accepted steps.
/// * `timer`          - Timer for tracking solution time.
///
#[derive(Debug, Clone)]
pub struct Solution<T, V, D>
where
    T: Real,
    V: State<T>,
    D: CallBackData,
{
    /// Outputted independent variable points.
    pub t: Vec<T>,

    /// Outputted dependent variable points.
    pub y: Vec<V>,

    /// Status of the solver.
    pub status: Status<T, V, D>,

    /// Number of function evaluations.
    pub evals: usize,

    /// Number of Jacobian evaluations.
    pub jac_evals: usize,

    /// Total number of steps taken by the solver.
    pub steps: usize,

    /// Number of rejected steps where the solution step-size had to be reduced.
    pub rejected_steps: usize,

    /// Number of accepted steps where the solution moved closer to tf.
    pub accepted_steps: usize,

    /// Timer for tracking solution time - Running during solving, Completed after finalization
    pub timer: Timer<T>,
}
93
94// Initial methods for the solution
95impl<T, V, D> Default for Solution<T, V, D>
96where
97    T: Real,
98    V: State<T>,
99    D: CallBackData,
100{
101    fn default() -> Self {
102        Self::new()
103    }
104}
105
106impl<T, V, D> Solution<T, V, D>
107where
108    T: Real,
109    V: State<T>,
110    D: CallBackData,
111{
112    /// Creates a new Solution object.
113    pub fn new() -> Self {
114        Solution {
115            t: Vec::with_capacity(100),
116            y: Vec::with_capacity(100),
117            status: Status::Uninitialized,
118            evals: 0,
119            steps: 0,
120            jac_evals: 0,
121            rejected_steps: 0,
122            accepted_steps: 0,
123            timer: Timer::Off,
124        }
125    }
126}
127
128// Methods used during solving
129impl<T, V, D> Solution<T, V, D>
130where
131    T: Real,
132    V: State<T>,
133    D: CallBackData,
134{
135    /// Puhes a new point to the solution, e.g. t and y vecs.
136    ///
137    /// # Arguments
138    /// * `t` - The time point.
139    /// * `y` - The state vector.
140    ///
141    pub fn push(&mut self, t: T, y: V) {
142        self.t.push(t);
143        self.y.push(y);
144    }
145
146    /// Pops the last point from the solution, e.g. t and y vecs.
147    ///
148    /// # Returns
149    /// * `Option<(T, SMatrix<T, R, C>)>` - The last point in the solution.
150    ///
151    pub fn pop(&mut self) -> Option<(T, V)> {
152        if self.t.is_empty() || self.y.is_empty() {
153            return None;
154        }
155        let t = self.t.pop().unwrap();
156        let y = self.y.pop().unwrap();
157        Some((t, y))
158    }
159
160    /// Truncates the solution's (t, y) points to the given index.
161    ///
162    /// # Arguments
163    /// * `index` - The index to truncate to.
164    ///
165    pub fn truncate(&mut self, index: usize) {
166        self.t.truncate(index);
167        self.y.truncate(index);
168    }
169}
170
171// Post-processing methods for the solution
172impl<T, V, D> Solution<T, V, D>
173where
174    T: Real,
175    V: State<T>,
176    D: CallBackData,
177{
178    /// Simplifies the Solution into a tuple of vectors in form (t, y).
179    /// By doing so, the Solution will be consumed and the status,
180    /// evals, steps, rejected_steps, and accepted_steps will be discarded.
181    ///
182    /// # Returns
183    /// * `(Vec<T>, Vec<V)` - Tuple of time and state vectors.
184    ///
185    pub fn into_tuple(self) -> (Vec<T>, Vec<V>) {
186        (self.t, self.y)
187    }
188
189    /// Returns the last accepted step of the solution in form (t, y).
190    ///
191    /// # Returns
192    /// * `Result<(T, V), Box<dyn std::error::Error>>` - Result of time and state vector.
193    ///
194    pub fn last(&self) -> Result<(&T, &V), Box<dyn std::error::Error>> {
195        let t = self.t.last().ok_or("No t steps available")?;
196        let y = self.y.last().ok_or("No y vectors available")?;
197        Ok((t, y))
198    }
199
200    /// Returns an iterator over the solution.
201    ///
202    /// # Returns
203    /// * `std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, V>>` - An iterator
204    ///   yielding (t, y) tuples.
205    ///
206    pub fn iter(&self) -> std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, V>> {
207        self.t.iter().zip(self.y.iter())
208    }
209
210    /// Creates a CSV file of the solution using standard library functionality.
211    ///
212    /// Note the columns will be named t, y0, y1, ..., yN.
213    ///
214    /// # Arguments
215    /// * `filename` - Name of the file to save the solution.
216    ///
217    /// # Returns
218    /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
219    ///
220    #[cfg(not(feature = "polars"))]
221    pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
222        use std::io::{BufWriter, Write};
223
224        // Create file and path if it does not exist
225        let path = std::path::Path::new(filename);
226        if let Some(parent) = path.parent() {
227            if !parent.exists() {
228                std::fs::create_dir_all(parent)?;
229            }
230        }
231        let file = std::fs::File::create(filename)?;
232        let mut writer = BufWriter::new(file);
233
234        // Length of state vector
235        let n = self.y[0].len();
236
237        // Write header
238        let mut header = String::from("t");
239        for i in 0..n {
240            header.push_str(&format!(",y{}", i));
241        }
242        writeln!(writer, "{}", header)?;
243
244        // Write data
245        for (t, y) in self.iter() {
246            let mut row = format!("{:?}", t);
247            for i in 0..n {
248                row.push_str(&format!(",{:?}", y.get(i)));
249            }
250            writeln!(writer, "{}", row)?;
251        }
252
253        // Ensure all data is flushed to disk
254        writer.flush()?;
255
256        Ok(())
257    }
258
259    /// Creates a csv file of the solution using Polars DataFrame.
260    ///
261    /// Note the columns will be named t, y0, y1, ..., yN.
262    ///
263    /// # Arguments
264    /// * `filename` - Name of the file to save the solution.
265    ///
266    /// # Returns
267    /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
268    ///
269    #[cfg(feature = "polars")]
270    pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
271        // Create file and path if it does not exist
272        let path = std::path::Path::new(filename);
273        if !path.exists() {
274            std::fs::create_dir_all(path.parent().unwrap())?;
275        }
276        let mut file = std::fs::File::create(filename)?;
277
278        let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
279        let mut columns = vec![Column::new("t".into(), t)];
280        let n = self.y[0].len();
281        for i in 0..n {
282            let header = format!("y{}", i);
283            columns.push(Column::new(
284                header.into(),
285                self.y
286                    .iter()
287                    .map(|x| x.get(i).to_f64())
288                    .collect::<Vec<f64>>(),
289            ));
290        }
291        let mut df = DataFrame::new(columns)?;
292
293        // Write the dataframe to a csv file
294        CsvWriter::new(&mut file).finish(&mut df)?;
295
296        Ok(())
297    }
298
299    /// Creates a Polars DataFrame of the solution.
300    ///
301    /// Note that the columns will be named t, y0, y1, ..., yN.
302    ///
303    /// # Returns
304    /// * `Result<DataFrame, PolarsError>` - Result of creating the DataFrame.
305    ///
306    #[cfg(feature = "polars")]
307    pub fn to_polars(&self) -> Result<DataFrame, PolarsError> {
308        let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
309        let mut columns = vec![Column::new("t".into(), t)];
310        let n = self.y[0].len();
311        for i in 0..n {
312            let header = format!("y{}", i);
313            columns.push(Column::new(
314                header.into(),
315                self.y
316                    .iter()
317                    .map(|x| x.get(i).to_f64())
318                    .collect::<Vec<f64>>(),
319            ));
320        }
321
322        DataFrame::new(columns)
323    }
324}