differential_equations/
solution.rs

1//! Solution of differential equations
2
3use crate::{
4    Status,
5    traits::{Real, CallBackData},
6};
7use nalgebra::SMatrix;
8use std::time::Instant;
9
10/// Timer for tracking solution time
11#[derive(Debug, Clone)]
12pub enum Timer<T: Real> {
13    Off,
14    Running(Instant),
15    Completed(T),
16}
17
18impl<T: Real> Timer<T> {
19    /// Starts the timer
20    pub fn start(&mut self) {
21        *self = Timer::Running(Instant::now());
22    }
23
24    /// Returns the elapsed time in seconds
25    pub fn elapsed(&self) -> T {
26        match self {
27            Timer::Off => T::zero(),
28            Timer::Running(start_time) => T::from_f64(start_time.elapsed().as_secs_f64()).unwrap(),
29            Timer::Completed(t) => *t,
30        }
31    }
32
33    /// Complete the running timer and convert it to a completed state
34    pub fn complete(&mut self) {
35        match self {
36            Timer::Off => {}
37            Timer::Running(start_time) => {
38                *self = Timer::Completed(T::from_f64(start_time.elapsed().as_secs_f64()).unwrap());
39            }
40            Timer::Completed(_) => {}
41        }
42    }
43}
44
45#[cfg(feature = "polars")]
46use polars::prelude::*;
47
48/// Solution of returned by differential equation solvers
49///
50/// # Fields
51/// * `y`              - Outputted dependent variable points.
52/// * `t`              - Outputted independent variable points.
53/// * `status`         - Status of the solver.
54/// * `evals`          - Number of function evaluations.
55/// * `steps`          - Number of steps.
56/// * `rejected_steps` - Number of rejected steps.
57/// * `accepted_steps` - Number of accepted steps.
58/// * `timer`          - Timer for tracking solution time.
59///
60#[derive(Debug, Clone)]
61pub struct Solution<T, const R: usize, const C: usize, D>
62where
63    T: Real,
64    D: CallBackData,
65{
66    /// Outputted independent variable points.
67    pub t: Vec<T>,
68
69    /// Outputted dependent variable points.
70    pub y: Vec<SMatrix<T, R, C>>,
71
72    /// Status of the solver.
73    pub status: Status<T, R, C, D>,
74
75    /// Number of function evaluations.
76    pub evals: usize,
77
78    /// Total number of steps taken by the solver.
79    pub steps: usize,
80
81    /// Number of rejected steps where the solution step-size had to be reduced.
82    pub rejected_steps: usize,
83
84    /// Number of accepted steps where the solution moved closer to tf.
85    pub accepted_steps: usize,
86
87    /// Timer for tracking solution time - Running during solving, Completed after finalization
88    pub timer: Timer<T>,
89}
90
91// Initial methods for the solution
92impl<T, const R: usize, const C: usize, D> Solution<T, R, C, D>
93where
94    T: Real,
95    D: CallBackData,
96{
97    /// Creates a new Solution object.
98    pub fn new() -> Self {
99        Solution {
100            t: Vec::with_capacity(100),
101            y: Vec::with_capacity(100),
102            status: Status::Uninitialized,
103            evals: 0,
104            steps: 0,
105            rejected_steps: 0,
106            accepted_steps: 0,
107            timer: Timer::Off,
108        }
109    }
110}
111
112// Methods used during solving
113impl<T, const R: usize, const C: usize, D> Solution<T, R, C, D>
114where
115    T: Real,
116    D: CallBackData,
117{
118    /// Puhes a new point to the solution, e.g. t and y vecs.
119    ///
120    /// # Arguments
121    /// * `t` - The time point.
122    /// * `y` - The state vector.
123    ///
124    pub fn push(&mut self, t: T, y: SMatrix<T, R, C>) {
125        self.t.push(t);
126        self.y.push(y);
127    }
128
129    /// Pops the last point from the solution, e.g. t and y vecs.
130    ///
131    /// # Returns
132    /// * `Option<(T, SMatrix<T, R, C>)>` - The last point in the solution.
133    ///
134    pub fn pop(&mut self) -> Option<(T, SMatrix<T, R, C>)> {
135        if self.t.is_empty() || self.y.is_empty() {
136            return None;
137        }
138        let t = self.t.pop().unwrap();
139        let y = self.y.pop().unwrap();
140        Some((t, y))
141    }
142
143    /// Truncates the solution's (t, y) points to the given index.
144    ///
145    /// # Arguments
146    /// * `index` - The index to truncate to.
147    ///
148    pub fn truncate(&mut self, index: usize) {
149        self.t.truncate(index);
150        self.y.truncate(index);
151    }
152}
153
154// Post-processing methods for the solution
155impl<T, const R: usize, const C: usize, D> Solution<T, R, C, D>
156where
157    T: Real,
158    D: CallBackData,
159{
160    /// Simplifies the Solution into a tuple of vectors in form (t, y).
161    /// By doing so, the Solution will be consumed and the status,
162    /// evals, steps, rejected_steps, and accepted_steps will be discarded.
163    ///
164    /// # Returns
165    /// * `(Vec<T>, Vec<V)` - Tuple of time and state vectors.
166    ///
167    pub fn into_tuple(self) -> (Vec<T>, Vec<SMatrix<T, R, C>>) {
168        (self.t, self.y)
169    }
170
171    /// Returns the last accepted step of the solution in form (t, y).
172    ///
173    /// # Returns
174    /// * `Result<(T, V), Box<dyn std::error::Error>>` - Result of time and state vector.
175    ///
176    pub fn last(&self) -> Result<(&T, &SMatrix<T, R, C>), Box<dyn std::error::Error>> {
177        let t = self.t.last().ok_or("No t steps available")?;
178        let y = self.y.last().ok_or("No y vectors available")?;
179        Ok((t, y))
180    }
181
182    /// Returns an iterator over the solution.
183    ///
184    /// # Returns
185    /// * `std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, V>>` - An iterator
186    ///   yielding (t, y) tuples.
187    ///
188    pub fn iter(
189        &self,
190    ) -> std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, SMatrix<T, R, C>>> {
191        self.t.iter().zip(self.y.iter())
192    }
193
194    /// Creates a CSV file of the solution using standard library functionality.
195    ///
196    /// Note the columns will be named t, y0, y1, ..., yN.
197    ///
198    /// # Arguments
199    /// * `filename` - Name of the file to save the solution.
200    ///
201    /// # Returns
202    /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
203    ///
204    #[cfg(not(feature = "polars"))]
205    pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
206        use std::io::{BufWriter, Write};
207
208        // Create file and path if it does not exist
209        let path = std::path::Path::new(filename);
210        if let Some(parent) = path.parent() {
211            if !parent.exists() {
212                std::fs::create_dir_all(parent)?;
213            }
214        }
215        let file = std::fs::File::create(filename)?;
216        let mut writer = BufWriter::new(file);
217
218        // Length of state vector
219        let n = self.y[0].len();
220
221        // Write header
222        let mut header = String::from("t");
223        for i in 0..n {
224            header.push_str(&format!(",y{}", i));
225        }
226        writeln!(writer, "{}", header)?;
227
228        // Write data
229        for (t, y) in self.iter() {
230            let mut row = format!("{:?}", t);
231            for i in 0..n {
232                row.push_str(&format!(",{:?}", y[i]));
233            }
234            writeln!(writer, "{}", row)?;
235        }
236
237        // Ensure all data is flushed to disk
238        writer.flush()?;
239
240        Ok(())
241    }
242
243    /// Creates a csv file of the solution using Polars DataFrame.
244    ///
245    /// Note the columns will be named t, y0, y1, ..., yN.
246    ///
247    /// # Arguments
248    /// * `filename` - Name of the file to save the solution.
249    ///
250    /// # Returns
251    /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
252    ///
253    #[cfg(feature = "polars")]
254    pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
255        // Create file and path if it does not exist
256        let path = std::path::Path::new(filename);
257        if !path.exists() {
258            std::fs::create_dir_all(path.parent().unwrap())?;
259        }
260        let mut file = std::fs::File::create(filename)?;
261
262        let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
263        let mut columns = vec![Column::new("t".into(), t)];
264        let n = self.y[0].len();
265        for i in 0..n {
266            let header = format!("y{}", i);
267            columns.push(Column::new(
268                header.into(),
269                self.y.iter().map(|x| x[i].to_f64()).collect::<Vec<f64>>(),
270            ));
271        }
272        let mut df = DataFrame::new(columns)?;
273
274        // Write the dataframe to a csv file
275        CsvWriter::new(&mut file).finish(&mut df)?;
276
277        Ok(())
278    }
279
280    /// Creates a Polars DataFrame of the solution.
281    ///
282    /// Note that the columns will be named t, y0, y1, ..., yN.
283    ///
284    /// # Returns
285    /// * `Result<DataFrame, PolarsError>` - Result of creating the DataFrame.
286    ///
287    #[cfg(feature = "polars")]
288    pub fn to_polars(&self) -> Result<DataFrame, PolarsError> {
289        let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
290        let mut columns = vec![Column::new("t".into(), t)];
291        let n = self.y[0].len();
292        for i in 0..n {
293            let header = format!("y{}", i);
294            columns.push(Column::new(
295                header.into(),
296                self.y.iter().map(|x| x[i].to_f64()).collect::<Vec<f64>>(),
297            ));
298        }
299
300        DataFrame::new(columns)
301    }
302}