Skip to main content

differential_equations/
solution.rs

1//! Solution container for differential equation solvers.
2
3#[cfg(feature = "polars")]
4use polars::prelude::*;
5
6#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8
9use crate::{
10    stats::{Evals, Steps, Timer},
11    status::Status,
12    traits::{Real, State},
13};
14
15/// The result produced by differential equation solvers.
16///
17/// # Fields
18/// * `y`              - Outputted dependent variable points.
19/// * `t`              - Outputted independent variable points.
20/// * `status`         - Status of the solver.
21/// * `evals`          - Number of function evaluations.
22/// * `steps`          - Number of steps.
23/// * `timer`          - Timer for tracking solution time.
24///
25#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
26#[derive(Debug, Clone)]
27pub struct Solution<T, Y>
28where
29    T: Real,
30    Y: State<T>,
31{
32    /// Outputted independent variable points.
33    pub t: Vec<T>,
34
35    /// Outputted dependent variable points.
36    pub y: Vec<Y>,
37
38    /// Status of the solver.
39    pub status: Status<T, Y>,
40
41    /// Number of function, Jacobian, and related evaluations.
42    pub evals: Evals,
43
44    /// Number of steps taken during the solution.
45    pub steps: Steps,
46
47    /// Timer tracking wall-clock time. `Running` during solving, `Completed` after finalization.
48    pub timer: Timer<T>,
49}
50
51// Initial methods for the solution
52impl<T, Y> Default for Solution<T, Y>
53where
54    T: Real,
55    Y: State<T>,
56{
57    fn default() -> Self {
58        Self::new()
59    }
60}
61
62impl<T, Y> Solution<T, Y>
63where
64    T: Real,
65    Y: State<T>,
66{
67    /// Creates a new Solution object.
68    pub fn new() -> Self {
69        Solution {
70            t: Vec::new(),
71            y: Vec::new(),
72            status: Status::Uninitialized,
73            evals: Evals::new(),
74            steps: Steps::new(),
75            timer: Timer::Off,
76        }
77    }
78
79    /// Creates a new Solution object with pre-allocated capacity for points.
80    ///
81    /// # Arguments
82    /// * `capacity` - Initial capacity for the vectors holding time and state points.
83    pub fn new_with_capacity(capacity: usize) -> Self {
84        Solution {
85            t: Vec::with_capacity(capacity),
86            y: Vec::with_capacity(capacity),
87            status: Status::Uninitialized,
88            evals: Evals::new(),
89            steps: Steps::new(),
90            timer: Timer::Off,
91        }
92    }
93}
94
95// Methods used during solving
96impl<T, Y> Solution<T, Y>
97where
98    T: Real,
99    Y: State<T>,
100{
101    /// Push a new `(t, y)` point into the solution.
102    ///
103    /// # Arguments
104    /// * `t` - The time point.
105    /// * `y` - The state vector.
106    ///
107    pub fn push(&mut self, t: T, y: Y) {
108        self.t.push(t);
109        self.y.push(y);
110    }
111
112    /// Pop the last `(t, y)` point from the solution.
113    ///
114    /// # Returns
115    /// * `Option<(T, SMatrix<T, R, C>)>` - The last point in the solution.
116    ///
117    pub fn pop(&mut self) -> Option<(T, Y)> {
118        if self.t.is_empty() || self.y.is_empty() {
119            return None;
120        }
121        let t = self.t.pop().unwrap();
122        let y = self.y.pop().unwrap();
123        Some((t, y))
124    }
125
126    /// Truncates the solution's (t, y) points to the given index.
127    ///
128    /// # Arguments
129    /// * `index` - The index to truncate to.
130    ///
131    pub fn truncate(&mut self, index: usize) {
132        self.t.truncate(index);
133        self.y.truncate(index);
134    }
135}
136
137// Post-processing methods for the solution
138impl<T, Y> Solution<T, Y>
139where
140    T: Real,
141    Y: State<T>,
142{
143    /// Consume the solution into `(t, y)` vectors.
144    ///
145    /// Status, evaluation counters, steps, and timers are discarded.
146    ///
147    /// # Returns
148    /// * `(Vec<T>, Vec<Y)` - Tuple of time and state vectors.
149    ///
150    pub fn into_tuple(self) -> (Vec<T>, Vec<Y>) {
151        (self.t, self.y)
152    }
153
154    /// Return the last accepted step `(t, y)`.
155    ///
156    /// # Returns
157    /// * `Result<(T, Y), Box<dyn std::error::Error>>` - Result of time and state vector.
158    ///
159    pub fn last(&self) -> Result<(&T, &Y), Box<dyn std::error::Error>> {
160        let t = self.t.last().ok_or("No t steps available")?;
161        let y = self.y.last().ok_or("No y vectors available")?;
162        Ok((t, y))
163    }
164
165    /// Returns an iterator over the solution.
166    ///
167    /// # Returns
168    /// * `std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, Y>>` - An iterator
169    ///   yielding (t, y) tuples.
170    ///
171    pub fn iter(&self) -> std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, Y>> {
172        self.t.iter().zip(self.y.iter())
173    }
174
175    /// Write the solution to CSV using only the standard library.
176    ///
177    /// Note the columns will be named t, y0, y1, ..., yN.
178    ///
179    /// # Arguments
180    /// * `filename` - Name of the file to save the solution.
181    ///
182    /// # Returns
183    /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
184    ///
185    #[cfg(not(feature = "polars"))]
186    pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
187        use std::io::{BufWriter, Write};
188
189        // Create file and path if it does not exist
190        let path = std::path::Path::new(filename);
191        if let Some(parent) = path.parent()
192            && !parent.exists()
193        {
194            std::fs::create_dir_all(parent)?;
195        }
196        let file = std::fs::File::create(filename)?;
197        let mut writer = BufWriter::new(file);
198
199        // Length of state vector
200        let n = self.y[0].len();
201
202        // Header
203        let mut header = String::from("t");
204        for i in 0..n {
205            header.push_str(&format!(",y{}", i));
206        }
207        writeln!(writer, "{}", header)?;
208
209        // Data rows
210        for (t, y) in self.iter() {
211            let mut row = format!("{:?}", t);
212            for i in 0..n {
213                row.push_str(&format!(",{:?}", y.get_component(i)));
214            }
215            writeln!(writer, "{}", row)?;
216        }
217
218        writer.flush()?;
219
220        Ok(())
221    }
222
223    /// Write the solution to CSV via a Polars `DataFrame`.
224    ///
225    /// Note the columns will be named t, y0, y1, ..., yN.
226    ///
227    /// # Arguments
228    /// * `filename` - Name of the file to save the solution.
229    ///
230    /// # Returns
231    /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
232    ///
233    #[cfg(feature = "polars")]
234    pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
235        // Create file and path if it does not exist
236        let path = std::path::Path::new(filename);
237        if let Some(parent) = path.parent()
238            && !parent.exists()
239        {
240            std::fs::create_dir_all(parent)?;
241        }
242        let mut file = std::fs::File::create(filename)?;
243
244        let t = self
245            .t
246            .iter()
247            .map(simba::scalar::SupersetOf::<f64>::to_subset_unchecked)
248            .collect::<Vec<f64>>();
249        let mut columns = vec![Column::new("t".into(), t)];
250        let n = self.y[0].len();
251        for i in 0..n {
252            let header = format!("y{}", i);
253            columns.push(Column::new(
254                header.into(),
255                self.y
256                    .iter()
257                    .map(|y| {
258                        simba::scalar::SupersetOf::<f64>::to_subset_unchecked(&y.get_component(i))
259                    })
260                    .collect::<Vec<f64>>(),
261            ));
262        }
263        let mut df = DataFrame::new(self.t.len(), columns)?;
264
265        // Write the DataFrame to CSV
266        CsvWriter::new(&mut file).finish(&mut df)?;
267
268        Ok(())
269    }
270
271    /// Convert the solution to a Polars `DataFrame`.
272    ///
273    /// Requires feature "polars" to be enabled.
274    ///
275    /// Note that the columns will be named t, y0, y1, ..., yN.
276    ///
277    /// # Returns
278    /// * `Result<DataFrame, PolarsError>` - Result of creating the DataFrame.
279    ///
280    #[cfg(feature = "polars")]
281    pub fn to_polars(&self) -> Result<DataFrame, PolarsError> {
282        let t = self
283            .t
284            .iter()
285            .map(simba::scalar::SupersetOf::<f64>::to_subset_unchecked)
286            .collect::<Vec<f64>>();
287        let mut columns = vec![Column::new("t".into(), t)];
288        let n = self.y[0].len();
289        for i in 0..n {
290            let header = format!("y{}", i);
291            columns.push(Column::new(
292                header.into(),
293                self.y
294                    .iter()
295                    .map(|y| {
296                        simba::scalar::SupersetOf::<f64>::to_subset_unchecked(&y.get_component(i))
297                    })
298                    .collect::<Vec<f64>>(),
299            ));
300        }
301
302        DataFrame::new(self.t.len(), columns)
303    }
304
305    /// Convert the solution to a Polars `DataFrame` with custom column names.
306    ///
307    /// Requires feature "polars" to be enabled.
308    ///
309    /// # Arguments
310    /// * `t_name` - Custom name for the time column
311    /// * `y_names` - Custom names for the state variables
312    ///
313    /// # Returns
314    /// * `Result<DataFrame, PolarsError>` - Result of creating the DataFrame.
315    ///
316    #[cfg(feature = "polars")]
317    pub fn to_named_polars(
318        &self,
319        t_name: &str,
320        y_names: Vec<&str>,
321    ) -> Result<DataFrame, PolarsError> {
322        let t = self
323            .t
324            .iter()
325            .map(simba::scalar::SupersetOf::<f64>::to_subset_unchecked)
326            .collect::<Vec<f64>>();
327        let mut columns = vec![Column::new(t_name.into(), t)];
328
329        let n = self.y[0].len();
330
331        // Validate that we have enough names for all state variables
332        if y_names.len() != n {
333            return Err(PolarsError::ComputeError(
334                format!(
335                    "Expected {} column names for state variables, but got {}",
336                    n,
337                    y_names.len()
338                )
339                .into(),
340            ));
341        }
342
343        for (i, name) in y_names.iter().enumerate() {
344            columns.push(Column::new(
345                (*name).into(),
346                self.y
347                    .iter()
348                    .map(|y| {
349                        simba::scalar::SupersetOf::<f64>::to_subset_unchecked(&y.get_component(i))
350                    })
351                    .collect::<Vec<f64>>(),
352            ));
353        }
354
355        DataFrame::new(self.t.len(), columns)
356    }
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    /// `into_tuple` returns the stored points in insertion order.
    #[test]
    fn test_into_tuple() {
        let mut solution: Solution<f64, f64> = Solution::new();
        for (t, y) in [(0.0, 10.0), (1.0, 20.0)] {
            solution.push(t, y);
        }

        let (times, states) = solution.into_tuple();
        assert_eq!(times, [0.0, 1.0]);
        assert_eq!(states, [10.0, 20.0]);
    }

    /// Walks one solution through construction, push, last, pop, iter, and truncate.
    #[test]
    fn test_solution_lifecycle() {
        // Both constructors start without any points.
        let fresh: Solution<f64, f64> = Solution::new();
        assert!(fresh.t.is_empty());
        assert!(fresh.y.is_empty());

        let reserved: Solution<f64, f64> = Solution::new_with_capacity(10);
        assert!(reserved.t.is_empty());
        assert!(reserved.y.is_empty());
        assert!(reserved.t.capacity() >= 10);
        assert!(reserved.y.capacity() >= 10);

        // push stores exactly one point in each vector.
        let mut solution = fresh;
        solution.push(2.0, 30.0);
        assert_eq!(solution.t, [2.0]);
        assert_eq!(solution.y, [30.0]);

        // last on a non-empty solution borrows the newest point.
        let (&last_t, &last_y) = solution.last().unwrap();
        assert_eq!(last_t, 2.0);
        assert_eq!(last_y, 30.0);

        // pop hands the point back and leaves the solution empty.
        assert_eq!(solution.pop(), Some((2.0, 30.0)));
        assert!(solution.t.is_empty());
        assert!(solution.y.is_empty());

        // last and pop on an empty solution report that nothing is stored.
        assert!(solution.last().is_err());
        assert!(solution.pop().is_none());

        // iter yields the points in insertion order.
        let points = [(0.0, 10.0), (1.0, 20.0), (2.0, 30.0)];
        for (t, y) in points {
            solution.push(t, y);
        }
        let collected: Vec<(f64, f64)> = solution.iter().map(|(t, y)| (*t, *y)).collect();
        assert_eq!(collected, points);

        // truncate keeps only the leading point.
        solution.truncate(1);
        assert_eq!(solution.t, [0.0]);
        assert_eq!(solution.y, [10.0]);
    }
}