differential_equations/
solution.rs

1//! Solution of differential equations
2
3use crate::{
4    Status,
5    stats::{Evals, Steps, Timer},
6    traits::{CallBackData, Real, State},
7};
8
9#[cfg(feature = "polars")]
10use polars::prelude::*;
11
/// Solution returned by differential equation solvers.
///
/// Holds the outputted `(t, y)` points together with the solver status and
/// bookkeeping (evaluation counters, step counters, timer).
///
/// # Fields
/// * `t`      - Outputted independent variable points.
/// * `y`      - Outputted dependent variable points.
/// * `status` - Status of the solver.
/// * `evals`  - Number of function, jacobian, etc evaluations.
/// * `steps`  - Number of steps taken during the solution.
/// * `timer`  - Timer for tracking solution time.
///
#[derive(Debug, Clone)]
pub struct Solution<T, V, D>
where
    T: Real,
    V: State<T>,
    D: CallBackData,
{
    /// Outputted independent variable points.
    pub t: Vec<T>,

    /// Outputted dependent variable points.
    pub y: Vec<V>,

    /// Status of the solver.
    pub status: Status<T, V, D>,

    /// Number of function, jacobian, etc evaluations.
    pub evals: Evals,

    /// Number of steps taken during the solution.
    pub steps: Steps,

    /// Timer for tracking solution time - Running during solving, Completed after finalization
    pub timer: Timer<T>,
}
49
50// Initial methods for the solution
51impl<T, V, D> Default for Solution<T, V, D>
52where
53    T: Real,
54    V: State<T>,
55    D: CallBackData,
56{
57    fn default() -> Self {
58        Self::new()
59    }
60}
61
62impl<T, V, D> Solution<T, V, D>
63where
64    T: Real,
65    V: State<T>,
66    D: CallBackData,
67{
68    /// Creates a new Solution object.
69    pub fn new() -> Self {
70        Solution {
71            t: Vec::with_capacity(100),
72            y: Vec::with_capacity(100),
73            status: Status::Uninitialized,
74            evals: Evals::new(),
75            steps: Steps::new(),
76            timer: Timer::Off,
77        }
78    }
79}
80
81// Methods used during solving
82impl<T, V, D> Solution<T, V, D>
83where
84    T: Real,
85    V: State<T>,
86    D: CallBackData,
87{
88    /// Puhes a new point to the solution, e.g. t and y vecs.
89    ///
90    /// # Arguments
91    /// * `t` - The time point.
92    /// * `y` - The state vector.
93    ///
94    pub fn push(&mut self, t: T, y: V) {
95        self.t.push(t);
96        self.y.push(y);
97    }
98
99    /// Pops the last point from the solution, e.g. t and y vecs.
100    ///
101    /// # Returns
102    /// * `Option<(T, SMatrix<T, R, C>)>` - The last point in the solution.
103    ///
104    pub fn pop(&mut self) -> Option<(T, V)> {
105        if self.t.is_empty() || self.y.is_empty() {
106            return None;
107        }
108        let t = self.t.pop().unwrap();
109        let y = self.y.pop().unwrap();
110        Some((t, y))
111    }
112
113    /// Truncates the solution's (t, y) points to the given index.
114    ///
115    /// # Arguments
116    /// * `index` - The index to truncate to.
117    ///
118    pub fn truncate(&mut self, index: usize) {
119        self.t.truncate(index);
120        self.y.truncate(index);
121    }
122}
123
124// Post-processing methods for the solution
125impl<T, V, D> Solution<T, V, D>
126where
127    T: Real,
128    V: State<T>,
129    D: CallBackData,
130{
131    /// Simplifies the Solution into a tuple of vectors in form (t, y).
132    /// By doing so, the Solution will be consumed and the status,
133    /// evals, steps, rejected_steps, and accepted_steps will be discarded.
134    ///
135    /// # Returns
136    /// * `(Vec<T>, Vec<V)` - Tuple of time and state vectors.
137    ///
138    pub fn into_tuple(self) -> (Vec<T>, Vec<V>) {
139        (self.t, self.y)
140    }
141
142    /// Returns the last accepted step of the solution in form (t, y).
143    ///
144    /// # Returns
145    /// * `Result<(T, V), Box<dyn std::error::Error>>` - Result of time and state vector.
146    ///
147    pub fn last(&self) -> Result<(&T, &V), Box<dyn std::error::Error>> {
148        let t = self.t.last().ok_or("No t steps available")?;
149        let y = self.y.last().ok_or("No y vectors available")?;
150        Ok((t, y))
151    }
152
153    /// Returns an iterator over the solution.
154    ///
155    /// # Returns
156    /// * `std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, V>>` - An iterator
157    ///   yielding (t, y) tuples.
158    ///
159    pub fn iter(&self) -> std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, V>> {
160        self.t.iter().zip(self.y.iter())
161    }
162
163    /// Creates a CSV file of the solution using standard library functionality.
164    ///
165    /// Note the columns will be named t, y0, y1, ..., yN.
166    ///
167    /// # Arguments
168    /// * `filename` - Name of the file to save the solution.
169    ///
170    /// # Returns
171    /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
172    ///
173    #[cfg(not(feature = "polars"))]
174    pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
175        use std::io::{BufWriter, Write};
176
177        // Create file and path if it does not exist
178        let path = std::path::Path::new(filename);
179        if let Some(parent) = path.parent() {
180            if !parent.exists() {
181                std::fs::create_dir_all(parent)?;
182            }
183        }
184        let file = std::fs::File::create(filename)?;
185        let mut writer = BufWriter::new(file);
186
187        // Length of state vector
188        let n = self.y[0].len();
189
190        // Write header
191        let mut header = String::from("t");
192        for i in 0..n {
193            header.push_str(&format!(",y{}", i));
194        }
195        writeln!(writer, "{}", header)?;
196
197        // Write data
198        for (t, y) in self.iter() {
199            let mut row = format!("{:?}", t);
200            for i in 0..n {
201                row.push_str(&format!(",{:?}", y.get(i)));
202            }
203            writeln!(writer, "{}", row)?;
204        }
205
206        // Ensure all data is flushed to disk
207        writer.flush()?;
208
209        Ok(())
210    }
211
212    /// Creates a csv file of the solution using Polars DataFrame.
213    ///
214    /// Note the columns will be named t, y0, y1, ..., yN.
215    ///
216    /// # Arguments
217    /// * `filename` - Name of the file to save the solution.
218    ///
219    /// # Returns
220    /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
221    ///
222    #[cfg(feature = "polars")]
223    pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
224        // Create file and path if it does not exist
225        let path = std::path::Path::new(filename);
226        if !path.exists() {
227            std::fs::create_dir_all(path.parent().unwrap())?;
228        }
229        let mut file = std::fs::File::create(filename)?;
230
231        let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
232        let mut columns = vec![Column::new("t".into(), t)];
233        let n = self.y[0].len();
234        for i in 0..n {
235            let header = format!("y{}", i);
236            columns.push(Column::new(
237                header.into(),
238                self.y
239                    .iter()
240                    .map(|x| x.get(i).to_f64())
241                    .collect::<Vec<f64>>(),
242            ));
243        }
244        let mut df = DataFrame::new(columns)?;
245
246        // Write the dataframe to a csv file
247        CsvWriter::new(&mut file).finish(&mut df)?;
248
249        Ok(())
250    }
251
252    /// Creates a Polars DataFrame of the solution.
253    ///
254    /// Requires feature "polars" to be enabled.
255    ///
256    /// Note that the columns will be named t, y0, y1, ..., yN.
257    ///
258    /// # Returns
259    /// * `Result<DataFrame, PolarsError>` - Result of creating the DataFrame.
260    ///
261    #[cfg(feature = "polars")]
262    pub fn to_polars(&self) -> Result<DataFrame, PolarsError> {
263        let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
264        let mut columns = vec![Column::new("t".into(), t)];
265        let n = self.y[0].len();
266        for i in 0..n {
267            let header = format!("y{}", i);
268            columns.push(Column::new(
269                header.into(),
270                self.y
271                    .iter()
272                    .map(|x| x.get(i).to_f64())
273                    .collect::<Vec<f64>>(),
274            ));
275        }
276
277        DataFrame::new(columns)
278    }
279
280    /// Creates a Polars DataFrame with column names.
281    ///
282    /// Requires feature "polars" to be enabled.
283    ///
284    /// # Arguments
285    /// * `t_name` - Custom name for the time column
286    /// * `y_names` - Custom names for the state variables
287    ///
288    /// # Returns
289    /// * `Result<DataFrame, PolarsError>` - Result of creating the DataFrame.
290    ///
291    #[cfg(feature = "polars")]
292    pub fn to_named_polars(
293        &self,
294        t_name: &str,
295        y_names: Vec<&str>,
296    ) -> Result<DataFrame, PolarsError> {
297        let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
298        let mut columns = vec![Column::new(t_name.into(), t)];
299
300        let n = self.y[0].len();
301
302        // Validate that we have enough names for all state variables
303        if y_names.len() != n {
304            return Err(PolarsError::ComputeError(
305                format!(
306                    "Expected {} column names for state variables, but got {}",
307                    n,
308                    y_names.len()
309                )
310                .into(),
311            ));
312        }
313
314        for (i, name) in y_names.iter().enumerate() {
315            columns.push(Column::new(
316                (*name).into(),
317                self.y
318                    .iter()
319                    .map(|x| x.get(i).to_f64())
320                    .collect::<Vec<f64>>(),
321            ));
322        }
323
324        DataFrame::new(columns)
325    }
326}