differential_equations/
solution.rs

1//! Solution container for differential equation solvers.
2
3#[cfg(feature = "polars")]
4use polars::prelude::*;
5
6#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8
9use crate::{
10    stats::{Evals, Steps, Timer},
11    status::Status,
12    traits::{Real, State},
13};
14
15/// The result produced by differential equation solvers.
16///
17/// # Fields
18/// * `y`              - Outputted dependent variable points.
19/// * `t`              - Outputted independent variable points.
20/// * `status`         - Status of the solver.
21/// * `evals`          - Number of function evaluations.
22/// * `steps`          - Number of steps.
23/// * `timer`          - Timer for tracking solution time.
24///
25#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
26#[derive(Debug, Clone)]
27pub struct Solution<T, Y>
28where
29    T: Real,
30    Y: State<T>,
31{
32    /// Outputted independent variable points.
33    pub t: Vec<T>,
34
35    /// Outputted dependent variable points.
36    pub y: Vec<Y>,
37
38    /// Status of the solver.
39    pub status: Status<T, Y>,
40
41    /// Number of function, Jacobian, and related evaluations.
42    pub evals: Evals,
43
44    /// Number of steps taken during the solution.
45    pub steps: Steps,
46
47    /// Timer tracking wall-clock time. `Running` during solving, `Completed` after finalization.
48    pub timer: Timer<T>,
49}
50
51// Initial methods for the solution
52impl<T, Y> Default for Solution<T, Y>
53where
54    T: Real,
55    Y: State<T>,
56{
57    fn default() -> Self {
58        Self::new()
59    }
60}
61
62impl<T, Y> Solution<T, Y>
63where
64    T: Real,
65    Y: State<T>,
66{
67    /// Creates a new Solution object.
68    pub fn new() -> Self {
69        Solution {
70            t: Vec::with_capacity(100),
71            y: Vec::with_capacity(100),
72            status: Status::Uninitialized,
73            evals: Evals::new(),
74            steps: Steps::new(),
75            timer: Timer::Off,
76        }
77    }
78}
79
80// Methods used during solving
81impl<T, Y> Solution<T, Y>
82where
83    T: Real,
84    Y: State<T>,
85{
86    /// Push a new `(t, y)` point into the solution.
87    ///
88    /// # Arguments
89    /// * `t` - The time point.
90    /// * `y` - The state vector.
91    ///
92    pub fn push(&mut self, t: T, y: Y) {
93        self.t.push(t);
94        self.y.push(y);
95    }
96
97    /// Pop the last `(t, y)` point from the solution.
98    ///
99    /// # Returns
100    /// * `Option<(T, SMatrix<T, R, C>)>` - The last point in the solution.
101    ///
102    pub fn pop(&mut self) -> Option<(T, Y)> {
103        if self.t.is_empty() || self.y.is_empty() {
104            return None;
105        }
106        let t = self.t.pop().unwrap();
107        let y = self.y.pop().unwrap();
108        Some((t, y))
109    }
110
111    /// Truncates the solution's (t, y) points to the given index.
112    ///
113    /// # Arguments
114    /// * `index` - The index to truncate to.
115    ///
116    pub fn truncate(&mut self, index: usize) {
117        self.t.truncate(index);
118        self.y.truncate(index);
119    }
120}
121
122// Post-processing methods for the solution
123impl<T, Y> Solution<T, Y>
124where
125    T: Real,
126    Y: State<T>,
127{
128    /// Consume the solution into `(t, y)` vectors.
129    ///
130    /// Status, evaluation counters, steps, and timers are discarded.
131    ///
132    /// # Returns
133    /// * `(Vec<T>, Vec<Y)` - Tuple of time and state vectors.
134    ///
135    pub fn into_tuple(self) -> (Vec<T>, Vec<Y>) {
136        (self.t, self.y)
137    }
138
139    /// Return the last accepted step `(t, y)`.
140    ///
141    /// # Returns
142    /// * `Result<(T, Y), Box<dyn std::error::Error>>` - Result of time and state vector.
143    ///
144    pub fn last(&self) -> Result<(&T, &Y), Box<dyn std::error::Error>> {
145        let t = self.t.last().ok_or("No t steps available")?;
146        let y = self.y.last().ok_or("No y vectors available")?;
147        Ok((t, y))
148    }
149
150    /// Returns an iterator over the solution.
151    ///
152    /// # Returns
153    /// * `std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, Y>>` - An iterator
154    ///   yielding (t, y) tuples.
155    ///
156    pub fn iter(&self) -> std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, Y>> {
157        self.t.iter().zip(self.y.iter())
158    }
159
160    /// Write the solution to CSV using only the standard library.
161    ///
162    /// Note the columns will be named t, y0, y1, ..., yN.
163    ///
164    /// # Arguments
165    /// * `filename` - Name of the file to save the solution.
166    ///
167    /// # Returns
168    /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
169    ///
170    #[cfg(not(feature = "polars"))]
171    pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
172        use std::io::{BufWriter, Write};
173
174        // Create file and path if it does not exist
175        let path = std::path::Path::new(filename);
176        if let Some(parent) = path.parent() {
177            if !parent.exists() {
178                std::fs::create_dir_all(parent)?;
179            }
180        }
181        let file = std::fs::File::create(filename)?;
182        let mut writer = BufWriter::new(file);
183
184        // Length of state vector
185        let n = self.y[0].len();
186
187        // Header
188        let mut header = String::from("t");
189        for i in 0..n {
190            header.push_str(&format!(",y{}", i));
191        }
192        writeln!(writer, "{}", header)?;
193
194        // Data rows
195        for (t, y) in self.iter() {
196            let mut row = format!("{:?}", t);
197            for i in 0..n {
198                row.push_str(&format!(",{:?}", y.get(i)));
199            }
200            writeln!(writer, "{}", row)?;
201        }
202
203        writer.flush()?;
204
205        Ok(())
206    }
207
208    /// Write the solution to CSV via a Polars `DataFrame`.
209    ///
210    /// Note the columns will be named t, y0, y1, ..., yN.
211    ///
212    /// # Arguments
213    /// * `filename` - Name of the file to save the solution.
214    ///
215    /// # Returns
216    /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
217    ///
218    #[cfg(feature = "polars")]
219    pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
220        // Create file and path if it does not exist
221        let path = std::path::Path::new(filename);
222        if !path.exists() {
223            std::fs::create_dir_all(path.parent().unwrap())?;
224        }
225        let mut file = std::fs::File::create(filename)?;
226
227        let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
228        let mut columns = vec![Column::new("t".into(), t)];
229        let n = self.y[0].len();
230        for i in 0..n {
231            let header = format!("y{}", i);
232            columns.push(Column::new(
233                header.into(),
234                self.y
235                    .iter()
236                    .map(|x| x.get(i).to_f64())
237                    .collect::<Vec<f64>>(),
238            ));
239        }
240        let mut df = DataFrame::new(columns)?;
241
242        // Write the DataFrame to CSV
243        CsvWriter::new(&mut file).finish(&mut df)?;
244
245        Ok(())
246    }
247
248    /// Convert the solution to a Polars `DataFrame`.
249    ///
250    /// Requires feature "polars" to be enabled.
251    ///
252    /// Note that the columns will be named t, y0, y1, ..., yN.
253    ///
254    /// # Returns
255    /// * `Result<DataFrame, PolarsError>` - Result of creating the DataFrame.
256    ///
257    #[cfg(feature = "polars")]
258    pub fn to_polars(&self) -> Result<DataFrame, PolarsError> {
259        let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
260        let mut columns = vec![Column::new("t".into(), t)];
261        let n = self.y[0].len();
262        for i in 0..n {
263            let header = format!("y{}", i);
264            columns.push(Column::new(
265                header.into(),
266                self.y
267                    .iter()
268                    .map(|x| x.get(i).to_f64())
269                    .collect::<Vec<f64>>(),
270            ));
271        }
272
273        DataFrame::new(columns)
274    }
275
276    /// Convert the solution to a Polars `DataFrame` with custom column names.
277    ///
278    /// Requires feature "polars" to be enabled.
279    ///
280    /// # Arguments
281    /// * `t_name` - Custom name for the time column
282    /// * `y_names` - Custom names for the state variables
283    ///
284    /// # Returns
285    /// * `Result<DataFrame, PolarsError>` - Result of creating the DataFrame.
286    ///
287    #[cfg(feature = "polars")]
288    pub fn to_named_polars(
289        &self,
290        t_name: &str,
291        y_names: Vec<&str>,
292    ) -> Result<DataFrame, PolarsError> {
293        let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
294        let mut columns = vec![Column::new(t_name.into(), t)];
295
296        let n = self.y[0].len();
297
298        // Validate that we have enough names for all state variables
299        if y_names.len() != n {
300            return Err(PolarsError::ComputeError(
301                format!(
302                    "Expected {} column names for state variables, but got {}",
303                    n,
304                    y_names.len()
305                )
306                .into(),
307            ));
308        }
309
310        for (i, name) in y_names.iter().enumerate() {
311            columns.push(Column::new(
312                (*name).into(),
313                self.y
314                    .iter()
315                    .map(|x| x.get(i).to_f64())
316                    .collect::<Vec<f64>>(),
317            ));
318        }
319
320        DataFrame::new(columns)
321    }
322}