differential_equations/solution.rs
1//! Solution of differential equations
2
3use crate::{
4 Status,
5 stats::{Evals, Steps, Timer},
6 traits::{CallBackData, Real, State},
7};
8
9#[cfg(feature = "polars")]
10use polars::prelude::*;
11
/// Solution returned by differential equation solvers.
///
/// Points are appended pairwise by `push`, so `t[i]` is the independent
/// variable value that belongs to the state `y[i]`.
///
/// # Fields
/// * `t` - Outputted independent variable points.
/// * `y` - Outputted dependent variable points.
/// * `status` - Status of the solver.
/// * `evals` - Number of function, jacobian, etc evaluations.
/// * `steps` - Number of steps taken during the solution.
/// * `timer` - Timer for tracking solution time.
///
#[derive(Debug, Clone)]
pub struct Solution<T, V, D>
where
    T: Real,
    V: State<T>,
    D: CallBackData,
{
    /// Outputted independent variable points.
    pub t: Vec<T>,

    /// Outputted dependent variable points.
    pub y: Vec<V>,

    /// Status of the solver.
    pub status: Status<T, V, D>,

    /// Number of function, jacobian, etc evaluations.
    pub evals: Evals,

    /// Number of steps taken during the solution.
    pub steps: Steps,

    /// Timer for tracking solution time - Running during solving, Completed after finalization
    pub timer: Timer<T>,
}
49
50// Initial methods for the solution
51impl<T, V, D> Default for Solution<T, V, D>
52where
53 T: Real,
54 V: State<T>,
55 D: CallBackData,
56{
57 fn default() -> Self {
58 Self::new()
59 }
60}
61
62impl<T, V, D> Solution<T, V, D>
63where
64 T: Real,
65 V: State<T>,
66 D: CallBackData,
67{
68 /// Creates a new Solution object.
69 pub fn new() -> Self {
70 Solution {
71 t: Vec::with_capacity(100),
72 y: Vec::with_capacity(100),
73 status: Status::Uninitialized,
74 evals: Evals::new(),
75 steps: Steps::new(),
76 timer: Timer::Off,
77 }
78 }
79}
80
81// Methods used during solving
82impl<T, V, D> Solution<T, V, D>
83where
84 T: Real,
85 V: State<T>,
86 D: CallBackData,
87{
88 /// Puhes a new point to the solution, e.g. t and y vecs.
89 ///
90 /// # Arguments
91 /// * `t` - The time point.
92 /// * `y` - The state vector.
93 ///
94 pub fn push(&mut self, t: T, y: V) {
95 self.t.push(t);
96 self.y.push(y);
97 }
98
99 /// Pops the last point from the solution, e.g. t and y vecs.
100 ///
101 /// # Returns
102 /// * `Option<(T, SMatrix<T, R, C>)>` - The last point in the solution.
103 ///
104 pub fn pop(&mut self) -> Option<(T, V)> {
105 if self.t.is_empty() || self.y.is_empty() {
106 return None;
107 }
108 let t = self.t.pop().unwrap();
109 let y = self.y.pop().unwrap();
110 Some((t, y))
111 }
112
113 /// Truncates the solution's (t, y) points to the given index.
114 ///
115 /// # Arguments
116 /// * `index` - The index to truncate to.
117 ///
118 pub fn truncate(&mut self, index: usize) {
119 self.t.truncate(index);
120 self.y.truncate(index);
121 }
122}
123
124// Post-processing methods for the solution
125impl<T, V, D> Solution<T, V, D>
126where
127 T: Real,
128 V: State<T>,
129 D: CallBackData,
130{
131 /// Simplifies the Solution into a tuple of vectors in form (t, y).
132 /// By doing so, the Solution will be consumed and the status,
133 /// evals, steps, rejected_steps, and accepted_steps will be discarded.
134 ///
135 /// # Returns
136 /// * `(Vec<T>, Vec<V)` - Tuple of time and state vectors.
137 ///
138 pub fn into_tuple(self) -> (Vec<T>, Vec<V>) {
139 (self.t, self.y)
140 }
141
142 /// Returns the last accepted step of the solution in form (t, y).
143 ///
144 /// # Returns
145 /// * `Result<(T, V), Box<dyn std::error::Error>>` - Result of time and state vector.
146 ///
147 pub fn last(&self) -> Result<(&T, &V), Box<dyn std::error::Error>> {
148 let t = self.t.last().ok_or("No t steps available")?;
149 let y = self.y.last().ok_or("No y vectors available")?;
150 Ok((t, y))
151 }
152
153 /// Returns an iterator over the solution.
154 ///
155 /// # Returns
156 /// * `std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, V>>` - An iterator
157 /// yielding (t, y) tuples.
158 ///
159 pub fn iter(&self) -> std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, V>> {
160 self.t.iter().zip(self.y.iter())
161 }
162
163 /// Creates a CSV file of the solution using standard library functionality.
164 ///
165 /// Note the columns will be named t, y0, y1, ..., yN.
166 ///
167 /// # Arguments
168 /// * `filename` - Name of the file to save the solution.
169 ///
170 /// # Returns
171 /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
172 ///
173 #[cfg(not(feature = "polars"))]
174 pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
175 use std::io::{BufWriter, Write};
176
177 // Create file and path if it does not exist
178 let path = std::path::Path::new(filename);
179 if let Some(parent) = path.parent() {
180 if !parent.exists() {
181 std::fs::create_dir_all(parent)?;
182 }
183 }
184 let file = std::fs::File::create(filename)?;
185 let mut writer = BufWriter::new(file);
186
187 // Length of state vector
188 let n = self.y[0].len();
189
190 // Write header
191 let mut header = String::from("t");
192 for i in 0..n {
193 header.push_str(&format!(",y{}", i));
194 }
195 writeln!(writer, "{}", header)?;
196
197 // Write data
198 for (t, y) in self.iter() {
199 let mut row = format!("{:?}", t);
200 for i in 0..n {
201 row.push_str(&format!(",{:?}", y.get(i)));
202 }
203 writeln!(writer, "{}", row)?;
204 }
205
206 // Ensure all data is flushed to disk
207 writer.flush()?;
208
209 Ok(())
210 }
211
212 /// Creates a csv file of the solution using Polars DataFrame.
213 ///
214 /// Note the columns will be named t, y0, y1, ..., yN.
215 ///
216 /// # Arguments
217 /// * `filename` - Name of the file to save the solution.
218 ///
219 /// # Returns
220 /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
221 ///
222 #[cfg(feature = "polars")]
223 pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
224 // Create file and path if it does not exist
225 let path = std::path::Path::new(filename);
226 if !path.exists() {
227 std::fs::create_dir_all(path.parent().unwrap())?;
228 }
229 let mut file = std::fs::File::create(filename)?;
230
231 let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
232 let mut columns = vec![Column::new("t".into(), t)];
233 let n = self.y[0].len();
234 for i in 0..n {
235 let header = format!("y{}", i);
236 columns.push(Column::new(
237 header.into(),
238 self.y
239 .iter()
240 .map(|x| x.get(i).to_f64())
241 .collect::<Vec<f64>>(),
242 ));
243 }
244 let mut df = DataFrame::new(columns)?;
245
246 // Write the dataframe to a csv file
247 CsvWriter::new(&mut file).finish(&mut df)?;
248
249 Ok(())
250 }
251
252 /// Creates a Polars DataFrame of the solution.
253 ///
254 /// Requires feature "polars" to be enabled.
255 ///
256 /// Note that the columns will be named t, y0, y1, ..., yN.
257 ///
258 /// # Returns
259 /// * `Result<DataFrame, PolarsError>` - Result of creating the DataFrame.
260 ///
261 #[cfg(feature = "polars")]
262 pub fn to_polars(&self) -> Result<DataFrame, PolarsError> {
263 let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
264 let mut columns = vec![Column::new("t".into(), t)];
265 let n = self.y[0].len();
266 for i in 0..n {
267 let header = format!("y{}", i);
268 columns.push(Column::new(
269 header.into(),
270 self.y
271 .iter()
272 .map(|x| x.get(i).to_f64())
273 .collect::<Vec<f64>>(),
274 ));
275 }
276
277 DataFrame::new(columns)
278 }
279
280 /// Creates a Polars DataFrame with column names.
281 ///
282 /// Requires feature "polars" to be enabled.
283 ///
284 /// # Arguments
285 /// * `t_name` - Custom name for the time column
286 /// * `y_names` - Custom names for the state variables
287 ///
288 /// # Returns
289 /// * `Result<DataFrame, PolarsError>` - Result of creating the DataFrame.
290 ///
291 #[cfg(feature = "polars")]
292 pub fn to_named_polars(
293 &self,
294 t_name: &str,
295 y_names: Vec<&str>,
296 ) -> Result<DataFrame, PolarsError> {
297 let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
298 let mut columns = vec![Column::new(t_name.into(), t)];
299
300 let n = self.y[0].len();
301
302 // Validate that we have enough names for all state variables
303 if y_names.len() != n {
304 return Err(PolarsError::ComputeError(
305 format!(
306 "Expected {} column names for state variables, but got {}",
307 n,
308 y_names.len()
309 )
310 .into(),
311 ));
312 }
313
314 for (i, name) in y_names.iter().enumerate() {
315 columns.push(Column::new(
316 (*name).into(),
317 self.y
318 .iter()
319 .map(|x| x.get(i).to_f64())
320 .collect::<Vec<f64>>(),
321 ));
322 }
323
324 DataFrame::new(columns)
325 }
326}