differential_equations/
solution.rs

1//! Solution container for differential equation solvers.
2
3#[cfg(feature = "polars")]
4use polars::prelude::*;
5
6use crate::{
7    stats::{Evals, Steps, Timer},
8    status::Status,
9    traits::{CallBackData, Real, State},
10};
11
12/// The result produced by differential equation solvers.
13///
14/// # Fields
15/// * `y`              - Outputted dependent variable points.
16/// * `t`              - Outputted independent variable points.
17/// * `status`         - Status of the solver.
18/// * `evals`          - Number of function evaluations.
19/// * `steps`          - Number of steps.
20/// * `timer`          - Timer for tracking solution time.
21///
22#[derive(Debug, Clone)]
23pub struct Solution<T, Y, D>
24where
25    T: Real,
26    Y: State<T>,
27    D: CallBackData,
28{
29    /// Outputted independent variable points.
30    pub t: Vec<T>,
31
32    /// Outputted dependent variable points.
33    pub y: Vec<Y>,
34
35    /// Status of the solver.
36    pub status: Status<T, Y, D>,
37
38    /// Number of function, Jacobian, and related evaluations.
39    pub evals: Evals,
40
41    /// Number of steps taken during the solution.
42    pub steps: Steps,
43
44    /// Timer tracking wall-clock time. `Running` during solving, `Completed` after finalization.
45    pub timer: Timer<T>,
46}
47
48// Initial methods for the solution
49impl<T, Y, D> Default for Solution<T, Y, D>
50where
51    T: Real,
52    Y: State<T>,
53    D: CallBackData,
54{
55    fn default() -> Self {
56        Self::new()
57    }
58}
59
60impl<T, Y, D> Solution<T, Y, D>
61where
62    T: Real,
63    Y: State<T>,
64    D: CallBackData,
65{
66    /// Creates a new Solution object.
67    pub fn new() -> Self {
68        Solution {
69            t: Vec::with_capacity(100),
70            y: Vec::with_capacity(100),
71            status: Status::Uninitialized,
72            evals: Evals::new(),
73            steps: Steps::new(),
74            timer: Timer::Off,
75        }
76    }
77}
78
79// Methods used during solving
80impl<T, Y, D> Solution<T, Y, D>
81where
82    T: Real,
83    Y: State<T>,
84    D: CallBackData,
85{
86    /// Push a new `(t, y)` point into the solution.
87    ///
88    /// # Arguments
89    /// * `t` - The time point.
90    /// * `y` - The state vector.
91    ///
92    pub fn push(&mut self, t: T, y: Y) {
93        self.t.push(t);
94        self.y.push(y);
95    }
96
97    /// Pop the last `(t, y)` point from the solution.
98    ///
99    /// # Returns
100    /// * `Option<(T, SMatrix<T, R, C>)>` - The last point in the solution.
101    ///
102    pub fn pop(&mut self) -> Option<(T, Y)> {
103        if self.t.is_empty() || self.y.is_empty() {
104            return None;
105        }
106        let t = self.t.pop().unwrap();
107        let y = self.y.pop().unwrap();
108        Some((t, y))
109    }
110
111    /// Truncates the solution's (t, y) points to the given index.
112    ///
113    /// # Arguments
114    /// * `index` - The index to truncate to.
115    ///
116    pub fn truncate(&mut self, index: usize) {
117        self.t.truncate(index);
118        self.y.truncate(index);
119    }
120}
121
122// Post-processing methods for the solution
123impl<T, Y, D> Solution<T, Y, D>
124where
125    T: Real,
126    Y: State<T>,
127    D: CallBackData,
128{
129    /// Consume the solution into `(t, y)` vectors.
130    ///
131    /// Status, evaluation counters, steps, and timers are discarded.
132    ///
133    /// # Returns
134    /// * `(Vec<T>, Vec<Y)` - Tuple of time and state vectors.
135    ///
136    pub fn into_tuple(self) -> (Vec<T>, Vec<Y>) {
137        (self.t, self.y)
138    }
139
140    /// Return the last accepted step `(t, y)`.
141    ///
142    /// # Returns
143    /// * `Result<(T, Y), Box<dyn std::error::Error>>` - Result of time and state vector.
144    ///
145    pub fn last(&self) -> Result<(&T, &Y), Box<dyn std::error::Error>> {
146        let t = self.t.last().ok_or("No t steps available")?;
147        let y = self.y.last().ok_or("No y vectors available")?;
148        Ok((t, y))
149    }
150
151    /// Returns an iterator over the solution.
152    ///
153    /// # Returns
154    /// * `std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, Y>>` - An iterator
155    ///   yielding (t, y) tuples.
156    ///
157    pub fn iter(&self) -> std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, Y>> {
158        self.t.iter().zip(self.y.iter())
159    }
160
161    /// Write the solution to CSV using only the standard library.
162    ///
163    /// Note the columns will be named t, y0, y1, ..., yN.
164    ///
165    /// # Arguments
166    /// * `filename` - Name of the file to save the solution.
167    ///
168    /// # Returns
169    /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
170    ///
171    #[cfg(not(feature = "polars"))]
172    pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
173        use std::io::{BufWriter, Write};
174
175        // Create file and path if it does not exist
176        let path = std::path::Path::new(filename);
177        if let Some(parent) = path.parent() {
178            if !parent.exists() {
179                std::fs::create_dir_all(parent)?;
180            }
181        }
182        let file = std::fs::File::create(filename)?;
183        let mut writer = BufWriter::new(file);
184
185        // Length of state vector
186        let n = self.y[0].len();
187
188        // Header
189        let mut header = String::from("t");
190        for i in 0..n {
191            header.push_str(&format!(",y{}", i));
192        }
193        writeln!(writer, "{}", header)?;
194
195        // Data rows
196        for (t, y) in self.iter() {
197            let mut row = format!("{:?}", t);
198            for i in 0..n {
199                row.push_str(&format!(",{:?}", y.get(i)));
200            }
201            writeln!(writer, "{}", row)?;
202        }
203
204        writer.flush()?;
205
206        Ok(())
207    }
208
209    /// Write the solution to CSV via a Polars `DataFrame`.
210    ///
211    /// Note the columns will be named t, y0, y1, ..., yN.
212    ///
213    /// # Arguments
214    /// * `filename` - Name of the file to save the solution.
215    ///
216    /// # Returns
217    /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
218    ///
219    #[cfg(feature = "polars")]
220    pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
221        // Create file and path if it does not exist
222        let path = std::path::Path::new(filename);
223        if !path.exists() {
224            std::fs::create_dir_all(path.parent().unwrap())?;
225        }
226        let mut file = std::fs::File::create(filename)?;
227
228        let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
229        let mut columns = vec![Column::new("t".into(), t)];
230        let n = self.y[0].len();
231        for i in 0..n {
232            let header = format!("y{}", i);
233            columns.push(Column::new(
234                header.into(),
235                self.y
236                    .iter()
237                    .map(|x| x.get(i).to_f64())
238                    .collect::<Vec<f64>>(),
239            ));
240        }
241        let mut df = DataFrame::new(columns)?;
242
243        // Write the DataFrame to CSV
244        CsvWriter::new(&mut file).finish(&mut df)?;
245
246        Ok(())
247    }
248
249    /// Convert the solution to a Polars `DataFrame`.
250    ///
251    /// Requires feature "polars" to be enabled.
252    ///
253    /// Note that the columns will be named t, y0, y1, ..., yN.
254    ///
255    /// # Returns
256    /// * `Result<DataFrame, PolarsError>` - Result of creating the DataFrame.
257    ///
258    #[cfg(feature = "polars")]
259    pub fn to_polars(&self) -> Result<DataFrame, PolarsError> {
260        let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
261        let mut columns = vec![Column::new("t".into(), t)];
262        let n = self.y[0].len();
263        for i in 0..n {
264            let header = format!("y{}", i);
265            columns.push(Column::new(
266                header.into(),
267                self.y
268                    .iter()
269                    .map(|x| x.get(i).to_f64())
270                    .collect::<Vec<f64>>(),
271            ));
272        }
273
274        DataFrame::new(columns)
275    }
276
277    /// Convert the solution to a Polars `DataFrame` with custom column names.
278    ///
279    /// Requires feature "polars" to be enabled.
280    ///
281    /// # Arguments
282    /// * `t_name` - Custom name for the time column
283    /// * `y_names` - Custom names for the state variables
284    ///
285    /// # Returns
286    /// * `Result<DataFrame, PolarsError>` - Result of creating the DataFrame.
287    ///
288    #[cfg(feature = "polars")]
289    pub fn to_named_polars(
290        &self,
291        t_name: &str,
292        y_names: Vec<&str>,
293    ) -> Result<DataFrame, PolarsError> {
294        let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
295        let mut columns = vec![Column::new(t_name.into(), t)];
296
297        let n = self.y[0].len();
298
299        // Validate that we have enough names for all state variables
300        if y_names.len() != n {
301            return Err(PolarsError::ComputeError(
302                format!(
303                    "Expected {} column names for state variables, but got {}",
304                    n,
305                    y_names.len()
306                )
307                .into(),
308            ));
309        }
310
311        for (i, name) in y_names.iter().enumerate() {
312            columns.push(Column::new(
313                (*name).into(),
314                self.y
315                    .iter()
316                    .map(|x| x.get(i).to_f64())
317                    .collect::<Vec<f64>>(),
318            ));
319        }
320
321        DataFrame::new(columns)
322    }
323}