// differential_equations/solution.rs

//! Solution of differential equations.
2
use std::path::Path;
use std::time::Instant;

use crate::{
    Status,
    traits::{CallBackData, Real, State},
};

#[cfg(feature = "polars")]
use polars::prelude::*;
11
12/// Timer for tracking solution time
13#[derive(Debug, Clone)]
14pub enum Timer<T: Real> {
15    Off,
16    Running(Instant),
17    Completed(T),
18}
19
20impl<T: Real> Timer<T> {
21    /// Starts the timer
22    pub fn start(&mut self) {
23        *self = Timer::Running(Instant::now());
24    }
25
26    /// Returns the elapsed time in seconds
27    pub fn elapsed(&self) -> T {
28        match self {
29            Timer::Off => T::zero(),
30            Timer::Running(start_time) => T::from_f64(start_time.elapsed().as_secs_f64()).unwrap(),
31            Timer::Completed(t) => *t,
32        }
33    }
34
35    /// Complete the running timer and convert it to a completed state
36    pub fn complete(&mut self) {
37        match self {
38            Timer::Off => {}
39            Timer::Running(start_time) => {
40                *self = Timer::Completed(T::from_f64(start_time.elapsed().as_secs_f64()).unwrap());
41            }
42            Timer::Completed(_) => {}
43        }
44    }
45}
46
47/// Solution of returned by differential equation solvers
48///
49/// # Fields
50/// * `y`              - Outputted dependent variable points.
51/// * `t`              - Outputted independent variable points.
52/// * `status`         - Status of the solver.
53/// * `evals`          - Number of function evaluations.
54/// * `steps`          - Number of steps.
55/// * `rejected_steps` - Number of rejected steps.
56/// * `accepted_steps` - Number of accepted steps.
57/// * `timer`          - Timer for tracking solution time.
58///
59#[derive(Debug, Clone)]
60pub struct Solution<T, V, D>
61where
62    T: Real,
63    V: State<T>,
64    D: CallBackData,
65{
66    /// Outputted independent variable points.
67    pub t: Vec<T>,
68
69    /// Outputted dependent variable points.
70    pub y: Vec<V>,
71
72    /// Status of the solver.
73    pub status: Status<T, V, D>,
74
75    /// Number of function evaluations.
76    pub evals: usize,
77
78    /// Number of Jacobian evaluations.
79    pub jac_evals: usize,
80
81    /// Total number of steps taken by the solver.
82    pub steps: usize,
83
84    /// Number of rejected steps where the solution step-size had to be reduced.
85    pub rejected_steps: usize,
86
87    /// Number of accepted steps where the solution moved closer to tf.
88    pub accepted_steps: usize,
89
90    /// Timer for tracking solution time - Running during solving, Completed after finalization
91    pub timer: Timer<T>,
92}
93
94// Initial methods for the solution
95impl<T, V, D> Default for Solution<T, V, D>
96where
97    T: Real,
98    V: State<T>,
99    D: CallBackData,
100{
101    fn default() -> Self {
102        Self::new()
103    }
104}
105
106impl<T, V, D> Solution<T, V, D>
107where
108    T: Real,
109    V: State<T>,
110    D: CallBackData,
111{
112    /// Creates a new Solution object.
113    pub fn new() -> Self {
114        Solution {
115            t: Vec::with_capacity(100),
116            y: Vec::with_capacity(100),
117            status: Status::Uninitialized,
118            evals: 0,
119            steps: 0,
120            jac_evals: 0,
121            rejected_steps: 0,
122            accepted_steps: 0,
123            timer: Timer::Off,
124        }
125    }
126}
127
128// Methods used during solving
129impl<T, V, D> Solution<T, V, D>
130where
131    T: Real,
132    V: State<T>,
133    D: CallBackData,
134{
135    /// Puhes a new point to the solution, e.g. t and y vecs.
136    ///
137    /// # Arguments
138    /// * `t` - The time point.
139    /// * `y` - The state vector.
140    ///
141    pub fn push(&mut self, t: T, y: V) {
142        self.t.push(t);
143        self.y.push(y);
144    }
145
146    /// Pops the last point from the solution, e.g. t and y vecs.
147    ///
148    /// # Returns
149    /// * `Option<(T, SMatrix<T, R, C>)>` - The last point in the solution.
150    ///
151    pub fn pop(&mut self) -> Option<(T, V)> {
152        if self.t.is_empty() || self.y.is_empty() {
153            return None;
154        }
155        let t = self.t.pop().unwrap();
156        let y = self.y.pop().unwrap();
157        Some((t, y))
158    }
159
160    /// Truncates the solution's (t, y) points to the given index.
161    ///
162    /// # Arguments
163    /// * `index` - The index to truncate to.
164    ///
165    pub fn truncate(&mut self, index: usize) {
166        self.t.truncate(index);
167        self.y.truncate(index);
168    }
169}
170
171// Post-processing methods for the solution
172impl<T, V, D> Solution<T, V, D>
173where
174    T: Real,
175    V: State<T>,
176    D: CallBackData,
177{
178    /// Simplifies the Solution into a tuple of vectors in form (t, y).
179    /// By doing so, the Solution will be consumed and the status,
180    /// evals, steps, rejected_steps, and accepted_steps will be discarded.
181    ///
182    /// # Returns
183    /// * `(Vec<T>, Vec<V)` - Tuple of time and state vectors.
184    ///
185    pub fn into_tuple(self) -> (Vec<T>, Vec<V>) {
186        (self.t, self.y)
187    }
188
189    /// Returns the last accepted step of the solution in form (t, y).
190    ///
191    /// # Returns
192    /// * `Result<(T, V), Box<dyn std::error::Error>>` - Result of time and state vector.
193    ///
194    pub fn last(&self) -> Result<(&T, &V), Box<dyn std::error::Error>> {
195        let t = self.t.last().ok_or("No t steps available")?;
196        let y = self.y.last().ok_or("No y vectors available")?;
197        Ok((t, y))
198    }
199
200    /// Returns an iterator over the solution.
201    ///
202    /// # Returns
203    /// * `std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, V>>` - An iterator
204    ///   yielding (t, y) tuples.
205    ///
206    pub fn iter(&self) -> std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, V>> {
207        self.t.iter().zip(self.y.iter())
208    }
209
210    /// Creates a CSV file of the solution using standard library functionality.
211    ///
212    /// Note the columns will be named t, y0, y1, ..., yN.
213    ///
214    /// # Arguments
215    /// * `filename` - Name of the file to save the solution.
216    ///
217    /// # Returns
218    /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
219    ///
220    #[cfg(not(feature = "polars"))]
221    pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
222        use std::io::{BufWriter, Write};
223
224        // Create file and path if it does not exist
225        let path = std::path::Path::new(filename);
226        if let Some(parent) = path.parent() {
227            if !parent.exists() {
228                std::fs::create_dir_all(parent)?;
229            }
230        }
231        let file = std::fs::File::create(filename)?;
232        let mut writer = BufWriter::new(file);
233
234        // Length of state vector
235        let n = self.y[0].len();
236
237        // Write header
238        let mut header = String::from("t");
239        for i in 0..n {
240            header.push_str(&format!(",y{}", i));
241        }
242        writeln!(writer, "{}", header)?;
243
244        // Write data
245        for (t, y) in self.iter() {
246            let mut row = format!("{:?}", t);
247            for i in 0..n {
248                row.push_str(&format!(",{:?}", y.get(i)));
249            }
250            writeln!(writer, "{}", row)?;
251        }
252
253        // Ensure all data is flushed to disk
254        writer.flush()?;
255
256        Ok(())
257    }
258
259    /// Creates a csv file of the solution using Polars DataFrame.
260    ///
261    /// Note the columns will be named t, y0, y1, ..., yN.
262    ///
263    /// # Arguments
264    /// * `filename` - Name of the file to save the solution.
265    ///
266    /// # Returns
267    /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
268    ///
269    #[cfg(feature = "polars")]
270    pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
271        // Create file and path if it does not exist
272        let path = std::path::Path::new(filename);
273        if !path.exists() {
274            std::fs::create_dir_all(path.parent().unwrap())?;
275        }
276        let mut file = std::fs::File::create(filename)?;
277
278        let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
279        let mut columns = vec![Column::new("t".into(), t)];
280        let n = self.y[0].len();
281        for i in 0..n {
282            let header = format!("y{}", i);
283            columns.push(Column::new(
284                header.into(),
285                self.y
286                    .iter()
287                    .map(|x| x.get(i).to_f64())
288                    .collect::<Vec<f64>>(),
289            ));
290        }
291        let mut df = DataFrame::new(columns)?;
292
293        // Write the dataframe to a csv file
294        CsvWriter::new(&mut file).finish(&mut df)?;
295
296        Ok(())
297    }
298
299    /// Creates a Polars DataFrame of the solution.
300    /// 
301    /// Requires feature "polars" to be enabled.
302    /// 
303    /// Note that the columns will be named t, y0, y1, ..., yN.
304    ///
305    /// # Returns
306    /// * `Result<DataFrame, PolarsError>` - Result of creating the DataFrame.
307    ///
308    #[cfg(feature = "polars")]
309    pub fn to_polars(&self) -> Result<DataFrame, PolarsError> {
310        let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
311        let mut columns = vec![Column::new("t".into(), t)];
312        let n = self.y[0].len();
313        for i in 0..n {
314            let header = format!("y{}", i);
315            columns.push(Column::new(
316                header.into(),
317                self.y
318                    .iter()
319                    .map(|x| x.get(i).to_f64())
320                    .collect::<Vec<f64>>(),
321            ));
322        }
323
324        DataFrame::new(columns)
325    }
326
327    /// Creates a Polars DataFrame with column names.
328    /// 
329    /// Requires feature "polars" to be enabled.
330    ///
331    /// # Arguments
332    /// * `t_name` - Custom name for the time column
333    /// * `y_names` - Custom names for the state variables
334    ///
335    /// # Returns
336    /// * `Result<DataFrame, PolarsError>` - Result of creating the DataFrame.
337    ///
338    #[cfg(feature = "polars")]
339    pub fn to_named_polars(&self, t_name: &str, y_names: Vec<&str>) -> Result<DataFrame, PolarsError> {
340        let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
341        let mut columns = vec![Column::new(t_name.into(), t)];
342        
343        let n = self.y[0].len();
344        
345        // Validate that we have enough names for all state variables
346        if y_names.len() != n {
347            return Err(PolarsError::ComputeError(
348                format!("Expected {} column names for state variables, but got {}", 
349                        n, y_names.len()).into()
350            ));
351        }
352        
353        for (i, name) in y_names.iter().enumerate() {
354            columns.push(Column::new(
355                (*name).into(),
356                self.y
357                    .iter()
358                    .map(|x| x.get(i).to_f64())
359                    .collect::<Vec<f64>>(),
360            ));
361        }
362    
363        DataFrame::new(columns)
364    }
365}