differential_equations/solution.rs
1//! Solution container for differential equation solvers.
2
3use crate::{
4 stats::{Evals, Steps, Timer},
5 status::Status,
6 traits::{CallBackData, Real, State},
7};
8
9#[cfg(feature = "polars")]
10use polars::prelude::*;
11
12/// The result produced by differential equation solvers.
13///
14/// # Fields
15/// * `y` - Outputted dependent variable points.
16/// * `t` - Outputted independent variable points.
17/// * `status` - Status of the solver.
18/// * `evals` - Number of function evaluations.
19/// * `steps` - Number of steps.
20/// * `timer` - Timer for tracking solution time.
21///
22#[derive(Debug, Clone)]
23pub struct Solution<T, Y, D>
24where
25 T: Real,
26 Y: State<T>,
27 D: CallBackData,
28{
29 /// Outputted independent variable points.
30 pub t: Vec<T>,
31
32 /// Outputted dependent variable points.
33 pub y: Vec<Y>,
34
35 /// Status of the solver.
36 pub status: Status<T, Y, D>,
37
38 /// Number of function, Jacobian, and related evaluations.
39 pub evals: Evals,
40
41 /// Number of steps taken during the solution.
42 pub steps: Steps,
43
44 /// Timer tracking wall-clock time. `Running` during solving, `Completed` after finalization.
45 pub timer: Timer<T>,
46}
47
48// Initial methods for the solution
49impl<T, Y, D> Default for Solution<T, Y, D>
50where
51 T: Real,
52 Y: State<T>,
53 D: CallBackData,
54{
55 fn default() -> Self {
56 Self::new()
57 }
58}
59
60impl<T, Y, D> Solution<T, Y, D>
61where
62 T: Real,
63 Y: State<T>,
64 D: CallBackData,
65{
66 /// Creates a new Solution object.
67 pub fn new() -> Self {
68 Solution {
69 t: Vec::with_capacity(100),
70 y: Vec::with_capacity(100),
71 status: Status::Uninitialized,
72 evals: Evals::new(),
73 steps: Steps::new(),
74 timer: Timer::Off,
75 }
76 }
77}
78
79// Methods used during solving
80impl<T, Y, D> Solution<T, Y, D>
81where
82 T: Real,
83 Y: State<T>,
84 D: CallBackData,
85{
86 /// Push a new `(t, y)` point into the solution.
87 ///
88 /// # Arguments
89 /// * `t` - The time point.
90 /// * `y` - The state vector.
91 ///
92 pub fn push(&mut self, t: T, y: Y) {
93 self.t.push(t);
94 self.y.push(y);
95 }
96
97 /// Pop the last `(t, y)` point from the solution.
98 ///
99 /// # Returns
100 /// * `Option<(T, SMatrix<T, R, C>)>` - The last point in the solution.
101 ///
102 pub fn pop(&mut self) -> Option<(T, Y)> {
103 if self.t.is_empty() || self.y.is_empty() {
104 return None;
105 }
106 let t = self.t.pop().unwrap();
107 let y = self.y.pop().unwrap();
108 Some((t, y))
109 }
110
111 /// Truncates the solution's (t, y) points to the given index.
112 ///
113 /// # Arguments
114 /// * `index` - The index to truncate to.
115 ///
116 pub fn truncate(&mut self, index: usize) {
117 self.t.truncate(index);
118 self.y.truncate(index);
119 }
120}
121
122// Post-processing methods for the solution
123impl<T, Y, D> Solution<T, Y, D>
124where
125 T: Real,
126 Y: State<T>,
127 D: CallBackData,
128{
129 /// Consume the solution into `(t, y)` vectors.
130 ///
131 /// Status, evaluation counters, steps, and timers are discarded.
132 ///
133 /// # Returns
134 /// * `(Vec<T>, Vec<Y)` - Tuple of time and state vectors.
135 ///
136 pub fn into_tuple(self) -> (Vec<T>, Vec<Y>) {
137 (self.t, self.y)
138 }
139
140 /// Return the last accepted step `(t, y)`.
141 ///
142 /// # Returns
143 /// * `Result<(T, Y), Box<dyn std::error::Error>>` - Result of time and state vector.
144 ///
145 pub fn last(&self) -> Result<(&T, &Y), Box<dyn std::error::Error>> {
146 let t = self.t.last().ok_or("No t steps available")?;
147 let y = self.y.last().ok_or("No y vectors available")?;
148 Ok((t, y))
149 }
150
151 /// Returns an iterator over the solution.
152 ///
153 /// # Returns
154 /// * `std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, Y>>` - An iterator
155 /// yielding (t, y) tuples.
156 ///
157 pub fn iter(&self) -> std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, Y>> {
158 self.t.iter().zip(self.y.iter())
159 }
160
161 /// Write the solution to CSV using only the standard library.
162 ///
163 /// Note the columns will be named t, y0, y1, ..., yN.
164 ///
165 /// # Arguments
166 /// * `filename` - Name of the file to save the solution.
167 ///
168 /// # Returns
169 /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
170 ///
171 #[cfg(not(feature = "polars"))]
172 pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
173 use std::io::{BufWriter, Write};
174
175 // Create file and path if it does not exist
176 let path = std::path::Path::new(filename);
177 if let Some(parent) = path.parent() {
178 if !parent.exists() {
179 std::fs::create_dir_all(parent)?;
180 }
181 }
182 let file = std::fs::File::create(filename)?;
183 let mut writer = BufWriter::new(file);
184
185 // Length of state vector
186 let n = self.y[0].len();
187
188 // Header
189 let mut header = String::from("t");
190 for i in 0..n {
191 header.push_str(&format!(",y{}", i));
192 }
193 writeln!(writer, "{}", header)?;
194
195 // Data rows
196 for (t, y) in self.iter() {
197 let mut row = format!("{:?}", t);
198 for i in 0..n {
199 row.push_str(&format!(",{:?}", y.get(i)));
200 }
201 writeln!(writer, "{}", row)?;
202 }
203
204 writer.flush()?;
205
206 Ok(())
207 }
208
209 /// Write the solution to CSV via a Polars `DataFrame`.
210 ///
211 /// Note the columns will be named t, y0, y1, ..., yN.
212 ///
213 /// # Arguments
214 /// * `filename` - Name of the file to save the solution.
215 ///
216 /// # Returns
217 /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
218 ///
219 #[cfg(feature = "polars")]
220 pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
221 // Create file and path if it does not exist
222 let path = std::path::Path::new(filename);
223 if !path.exists() {
224 std::fs::create_dir_all(path.parent().unwrap())?;
225 }
226 let mut file = std::fs::File::create(filename)?;
227
228 let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
229 let mut columns = vec![Column::new("t".into(), t)];
230 let n = self.y[0].len();
231 for i in 0..n {
232 let header = format!("y{}", i);
233 columns.push(Column::new(
234 header.into(),
235 self.y
236 .iter()
237 .map(|x| x.get(i).to_f64())
238 .collect::<Vec<f64>>(),
239 ));
240 }
241 let mut df = DataFrame::new(columns)?;
242
243 // Write the DataFrame to CSV
244 CsvWriter::new(&mut file).finish(&mut df)?;
245
246 Ok(())
247 }
248
249 /// Convert the solution to a Polars `DataFrame`.
250 ///
251 /// Requires feature "polars" to be enabled.
252 ///
253 /// Note that the columns will be named t, y0, y1, ..., yN.
254 ///
255 /// # Returns
256 /// * `Result<DataFrame, PolarsError>` - Result of creating the DataFrame.
257 ///
258 #[cfg(feature = "polars")]
259 pub fn to_polars(&self) -> Result<DataFrame, PolarsError> {
260 let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
261 let mut columns = vec![Column::new("t".into(), t)];
262 let n = self.y[0].len();
263 for i in 0..n {
264 let header = format!("y{}", i);
265 columns.push(Column::new(
266 header.into(),
267 self.y
268 .iter()
269 .map(|x| x.get(i).to_f64())
270 .collect::<Vec<f64>>(),
271 ));
272 }
273
274 DataFrame::new(columns)
275 }
276
277 /// Convert the solution to a Polars `DataFrame` with custom column names.
278 ///
279 /// Requires feature "polars" to be enabled.
280 ///
281 /// # Arguments
282 /// * `t_name` - Custom name for the time column
283 /// * `y_names` - Custom names for the state variables
284 ///
285 /// # Returns
286 /// * `Result<DataFrame, PolarsError>` - Result of creating the DataFrame.
287 ///
288 #[cfg(feature = "polars")]
289 pub fn to_named_polars(
290 &self,
291 t_name: &str,
292 y_names: Vec<&str>,
293 ) -> Result<DataFrame, PolarsError> {
294 let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
295 let mut columns = vec![Column::new(t_name.into(), t)];
296
297 let n = self.y[0].len();
298
299 // Validate that we have enough names for all state variables
300 if y_names.len() != n {
301 return Err(PolarsError::ComputeError(
302 format!(
303 "Expected {} column names for state variables, but got {}",
304 n,
305 y_names.len()
306 )
307 .into(),
308 ));
309 }
310
311 for (i, name) in y_names.iter().enumerate() {
312 columns.push(Column::new(
313 (*name).into(),
314 self.y
315 .iter()
316 .map(|x| x.get(i).to_f64())
317 .collect::<Vec<f64>>(),
318 ));
319 }
320
321 DataFrame::new(columns)
322 }
323}