differential_equations/solution.rs
1//! Solution container for differential equation solvers.
2
3#[cfg(feature = "polars")]
4use polars::prelude::*;
5
6#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8
9use crate::{
10 stats::{Evals, Steps, Timer},
11 status::Status,
12 traits::{Real, State},
13};
14
15/// The result produced by differential equation solvers.
16///
17/// # Fields
18/// * `y` - Outputted dependent variable points.
19/// * `t` - Outputted independent variable points.
20/// * `status` - Status of the solver.
21/// * `evals` - Number of function evaluations.
22/// * `steps` - Number of steps.
23/// * `timer` - Timer for tracking solution time.
24///
25#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
26#[derive(Debug, Clone)]
27pub struct Solution<T, Y>
28where
29 T: Real,
30 Y: State<T>,
31{
32 /// Outputted independent variable points.
33 pub t: Vec<T>,
34
35 /// Outputted dependent variable points.
36 pub y: Vec<Y>,
37
38 /// Status of the solver.
39 pub status: Status<T, Y>,
40
41 /// Number of function, Jacobian, and related evaluations.
42 pub evals: Evals,
43
44 /// Number of steps taken during the solution.
45 pub steps: Steps,
46
47 /// Timer tracking wall-clock time. `Running` during solving, `Completed` after finalization.
48 pub timer: Timer<T>,
49}
50
51// Initial methods for the solution
52impl<T, Y> Default for Solution<T, Y>
53where
54 T: Real,
55 Y: State<T>,
56{
57 fn default() -> Self {
58 Self::new()
59 }
60}
61
62impl<T, Y> Solution<T, Y>
63where
64 T: Real,
65 Y: State<T>,
66{
67 /// Creates a new Solution object.
68 pub fn new() -> Self {
69 Solution {
70 t: Vec::with_capacity(100),
71 y: Vec::with_capacity(100),
72 status: Status::Uninitialized,
73 evals: Evals::new(),
74 steps: Steps::new(),
75 timer: Timer::Off,
76 }
77 }
78}
79
80// Methods used during solving
81impl<T, Y> Solution<T, Y>
82where
83 T: Real,
84 Y: State<T>,
85{
86 /// Push a new `(t, y)` point into the solution.
87 ///
88 /// # Arguments
89 /// * `t` - The time point.
90 /// * `y` - The state vector.
91 ///
92 pub fn push(&mut self, t: T, y: Y) {
93 self.t.push(t);
94 self.y.push(y);
95 }
96
97 /// Pop the last `(t, y)` point from the solution.
98 ///
99 /// # Returns
100 /// * `Option<(T, SMatrix<T, R, C>)>` - The last point in the solution.
101 ///
102 pub fn pop(&mut self) -> Option<(T, Y)> {
103 if self.t.is_empty() || self.y.is_empty() {
104 return None;
105 }
106 let t = self.t.pop().unwrap();
107 let y = self.y.pop().unwrap();
108 Some((t, y))
109 }
110
111 /// Truncates the solution's (t, y) points to the given index.
112 ///
113 /// # Arguments
114 /// * `index` - The index to truncate to.
115 ///
116 pub fn truncate(&mut self, index: usize) {
117 self.t.truncate(index);
118 self.y.truncate(index);
119 }
120}
121
122// Post-processing methods for the solution
123impl<T, Y> Solution<T, Y>
124where
125 T: Real,
126 Y: State<T>,
127{
128 /// Consume the solution into `(t, y)` vectors.
129 ///
130 /// Status, evaluation counters, steps, and timers are discarded.
131 ///
132 /// # Returns
133 /// * `(Vec<T>, Vec<Y)` - Tuple of time and state vectors.
134 ///
135 pub fn into_tuple(self) -> (Vec<T>, Vec<Y>) {
136 (self.t, self.y)
137 }
138
139 /// Return the last accepted step `(t, y)`.
140 ///
141 /// # Returns
142 /// * `Result<(T, Y), Box<dyn std::error::Error>>` - Result of time and state vector.
143 ///
144 pub fn last(&self) -> Result<(&T, &Y), Box<dyn std::error::Error>> {
145 let t = self.t.last().ok_or("No t steps available")?;
146 let y = self.y.last().ok_or("No y vectors available")?;
147 Ok((t, y))
148 }
149
150 /// Returns an iterator over the solution.
151 ///
152 /// # Returns
153 /// * `std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, Y>>` - An iterator
154 /// yielding (t, y) tuples.
155 ///
156 pub fn iter(&self) -> std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, Y>> {
157 self.t.iter().zip(self.y.iter())
158 }
159
160 /// Write the solution to CSV using only the standard library.
161 ///
162 /// Note the columns will be named t, y0, y1, ..., yN.
163 ///
164 /// # Arguments
165 /// * `filename` - Name of the file to save the solution.
166 ///
167 /// # Returns
168 /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
169 ///
170 #[cfg(not(feature = "polars"))]
171 pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
172 use std::io::{BufWriter, Write};
173
174 // Create file and path if it does not exist
175 let path = std::path::Path::new(filename);
176 if let Some(parent) = path.parent() {
177 if !parent.exists() {
178 std::fs::create_dir_all(parent)?;
179 }
180 }
181 let file = std::fs::File::create(filename)?;
182 let mut writer = BufWriter::new(file);
183
184 // Length of state vector
185 let n = self.y[0].len();
186
187 // Header
188 let mut header = String::from("t");
189 for i in 0..n {
190 header.push_str(&format!(",y{}", i));
191 }
192 writeln!(writer, "{}", header)?;
193
194 // Data rows
195 for (t, y) in self.iter() {
196 let mut row = format!("{:?}", t);
197 for i in 0..n {
198 row.push_str(&format!(",{:?}", y.get(i)));
199 }
200 writeln!(writer, "{}", row)?;
201 }
202
203 writer.flush()?;
204
205 Ok(())
206 }
207
208 /// Write the solution to CSV via a Polars `DataFrame`.
209 ///
210 /// Note the columns will be named t, y0, y1, ..., yN.
211 ///
212 /// # Arguments
213 /// * `filename` - Name of the file to save the solution.
214 ///
215 /// # Returns
216 /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
217 ///
218 #[cfg(feature = "polars")]
219 pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
220 // Create file and path if it does not exist
221 let path = std::path::Path::new(filename);
222 if !path.exists() {
223 std::fs::create_dir_all(path.parent().unwrap())?;
224 }
225 let mut file = std::fs::File::create(filename)?;
226
227 let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
228 let mut columns = vec![Column::new("t".into(), t)];
229 let n = self.y[0].len();
230 for i in 0..n {
231 let header = format!("y{}", i);
232 columns.push(Column::new(
233 header.into(),
234 self.y
235 .iter()
236 .map(|x| x.get(i).to_f64())
237 .collect::<Vec<f64>>(),
238 ));
239 }
240 let mut df = DataFrame::new(columns)?;
241
242 // Write the DataFrame to CSV
243 CsvWriter::new(&mut file).finish(&mut df)?;
244
245 Ok(())
246 }
247
248 /// Convert the solution to a Polars `DataFrame`.
249 ///
250 /// Requires feature "polars" to be enabled.
251 ///
252 /// Note that the columns will be named t, y0, y1, ..., yN.
253 ///
254 /// # Returns
255 /// * `Result<DataFrame, PolarsError>` - Result of creating the DataFrame.
256 ///
257 #[cfg(feature = "polars")]
258 pub fn to_polars(&self) -> Result<DataFrame, PolarsError> {
259 let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
260 let mut columns = vec![Column::new("t".into(), t)];
261 let n = self.y[0].len();
262 for i in 0..n {
263 let header = format!("y{}", i);
264 columns.push(Column::new(
265 header.into(),
266 self.y
267 .iter()
268 .map(|x| x.get(i).to_f64())
269 .collect::<Vec<f64>>(),
270 ));
271 }
272
273 DataFrame::new(columns)
274 }
275
276 /// Convert the solution to a Polars `DataFrame` with custom column names.
277 ///
278 /// Requires feature "polars" to be enabled.
279 ///
280 /// # Arguments
281 /// * `t_name` - Custom name for the time column
282 /// * `y_names` - Custom names for the state variables
283 ///
284 /// # Returns
285 /// * `Result<DataFrame, PolarsError>` - Result of creating the DataFrame.
286 ///
287 #[cfg(feature = "polars")]
288 pub fn to_named_polars(
289 &self,
290 t_name: &str,
291 y_names: Vec<&str>,
292 ) -> Result<DataFrame, PolarsError> {
293 let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
294 let mut columns = vec![Column::new(t_name.into(), t)];
295
296 let n = self.y[0].len();
297
298 // Validate that we have enough names for all state variables
299 if y_names.len() != n {
300 return Err(PolarsError::ComputeError(
301 format!(
302 "Expected {} column names for state variables, but got {}",
303 n,
304 y_names.len()
305 )
306 .into(),
307 ));
308 }
309
310 for (i, name) in y_names.iter().enumerate() {
311 columns.push(Column::new(
312 (*name).into(),
313 self.y
314 .iter()
315 .map(|x| x.get(i).to_f64())
316 .collect::<Vec<f64>>(),
317 ));
318 }
319
320 DataFrame::new(columns)
321 }
322}