differential_equations/solution.rs
1//! Solution of differential equations
2
3use crate::{
4 Status,
5 traits::{CallBackData, Real, State},
6};
7use std::time::Instant;
8
9#[cfg(feature = "polars")]
10use polars::prelude::*;
11
12/// Timer for tracking solution time
13#[derive(Debug, Clone)]
14pub enum Timer<T: Real> {
15 Off,
16 Running(Instant),
17 Completed(T),
18}
19
20impl<T: Real> Timer<T> {
21 /// Starts the timer
22 pub fn start(&mut self) {
23 *self = Timer::Running(Instant::now());
24 }
25
26 /// Returns the elapsed time in seconds
27 pub fn elapsed(&self) -> T {
28 match self {
29 Timer::Off => T::zero(),
30 Timer::Running(start_time) => T::from_f64(start_time.elapsed().as_secs_f64()).unwrap(),
31 Timer::Completed(t) => *t,
32 }
33 }
34
35 /// Complete the running timer and convert it to a completed state
36 pub fn complete(&mut self) {
37 match self {
38 Timer::Off => {}
39 Timer::Running(start_time) => {
40 *self = Timer::Completed(T::from_f64(start_time.elapsed().as_secs_f64()).unwrap());
41 }
42 Timer::Completed(_) => {}
43 }
44 }
45}
46
47/// Solution of returned by differential equation solvers
48///
49/// # Fields
50/// * `y` - Outputted dependent variable points.
51/// * `t` - Outputted independent variable points.
52/// * `status` - Status of the solver.
53/// * `evals` - Number of function evaluations.
54/// * `steps` - Number of steps.
55/// * `rejected_steps` - Number of rejected steps.
56/// * `accepted_steps` - Number of accepted steps.
57/// * `timer` - Timer for tracking solution time.
58///
59#[derive(Debug, Clone)]
60pub struct Solution<T, V, D>
61where
62 T: Real,
63 V: State<T>,
64 D: CallBackData,
65{
66 /// Outputted independent variable points.
67 pub t: Vec<T>,
68
69 /// Outputted dependent variable points.
70 pub y: Vec<V>,
71
72 /// Status of the solver.
73 pub status: Status<T, V, D>,
74
75 /// Number of function evaluations.
76 pub evals: usize,
77
78 /// Number of Jacobian evaluations.
79 pub jac_evals: usize,
80
81 /// Total number of steps taken by the solver.
82 pub steps: usize,
83
84 /// Number of rejected steps where the solution step-size had to be reduced.
85 pub rejected_steps: usize,
86
87 /// Number of accepted steps where the solution moved closer to tf.
88 pub accepted_steps: usize,
89
90 /// Timer for tracking solution time - Running during solving, Completed after finalization
91 pub timer: Timer<T>,
92}
93
94// Initial methods for the solution
95impl<T, V, D> Default for Solution<T, V, D>
96where
97 T: Real,
98 V: State<T>,
99 D: CallBackData,
100{
101 fn default() -> Self {
102 Self::new()
103 }
104}
105
106impl<T, V, D> Solution<T, V, D>
107where
108 T: Real,
109 V: State<T>,
110 D: CallBackData,
111{
112 /// Creates a new Solution object.
113 pub fn new() -> Self {
114 Solution {
115 t: Vec::with_capacity(100),
116 y: Vec::with_capacity(100),
117 status: Status::Uninitialized,
118 evals: 0,
119 steps: 0,
120 jac_evals: 0,
121 rejected_steps: 0,
122 accepted_steps: 0,
123 timer: Timer::Off,
124 }
125 }
126}
127
128// Methods used during solving
129impl<T, V, D> Solution<T, V, D>
130where
131 T: Real,
132 V: State<T>,
133 D: CallBackData,
134{
135 /// Puhes a new point to the solution, e.g. t and y vecs.
136 ///
137 /// # Arguments
138 /// * `t` - The time point.
139 /// * `y` - The state vector.
140 ///
141 pub fn push(&mut self, t: T, y: V) {
142 self.t.push(t);
143 self.y.push(y);
144 }
145
146 /// Pops the last point from the solution, e.g. t and y vecs.
147 ///
148 /// # Returns
149 /// * `Option<(T, SMatrix<T, R, C>)>` - The last point in the solution.
150 ///
151 pub fn pop(&mut self) -> Option<(T, V)> {
152 if self.t.is_empty() || self.y.is_empty() {
153 return None;
154 }
155 let t = self.t.pop().unwrap();
156 let y = self.y.pop().unwrap();
157 Some((t, y))
158 }
159
160 /// Truncates the solution's (t, y) points to the given index.
161 ///
162 /// # Arguments
163 /// * `index` - The index to truncate to.
164 ///
165 pub fn truncate(&mut self, index: usize) {
166 self.t.truncate(index);
167 self.y.truncate(index);
168 }
169}
170
171// Post-processing methods for the solution
172impl<T, V, D> Solution<T, V, D>
173where
174 T: Real,
175 V: State<T>,
176 D: CallBackData,
177{
178 /// Simplifies the Solution into a tuple of vectors in form (t, y).
179 /// By doing so, the Solution will be consumed and the status,
180 /// evals, steps, rejected_steps, and accepted_steps will be discarded.
181 ///
182 /// # Returns
183 /// * `(Vec<T>, Vec<V)` - Tuple of time and state vectors.
184 ///
185 pub fn into_tuple(self) -> (Vec<T>, Vec<V>) {
186 (self.t, self.y)
187 }
188
189 /// Returns the last accepted step of the solution in form (t, y).
190 ///
191 /// # Returns
192 /// * `Result<(T, V), Box<dyn std::error::Error>>` - Result of time and state vector.
193 ///
194 pub fn last(&self) -> Result<(&T, &V), Box<dyn std::error::Error>> {
195 let t = self.t.last().ok_or("No t steps available")?;
196 let y = self.y.last().ok_or("No y vectors available")?;
197 Ok((t, y))
198 }
199
200 /// Returns an iterator over the solution.
201 ///
202 /// # Returns
203 /// * `std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, V>>` - An iterator
204 /// yielding (t, y) tuples.
205 ///
206 pub fn iter(&self) -> std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, V>> {
207 self.t.iter().zip(self.y.iter())
208 }
209
210 /// Creates a CSV file of the solution using standard library functionality.
211 ///
212 /// Note the columns will be named t, y0, y1, ..., yN.
213 ///
214 /// # Arguments
215 /// * `filename` - Name of the file to save the solution.
216 ///
217 /// # Returns
218 /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
219 ///
220 #[cfg(not(feature = "polars"))]
221 pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
222 use std::io::{BufWriter, Write};
223
224 // Create file and path if it does not exist
225 let path = std::path::Path::new(filename);
226 if let Some(parent) = path.parent() {
227 if !parent.exists() {
228 std::fs::create_dir_all(parent)?;
229 }
230 }
231 let file = std::fs::File::create(filename)?;
232 let mut writer = BufWriter::new(file);
233
234 // Length of state vector
235 let n = self.y[0].len();
236
237 // Write header
238 let mut header = String::from("t");
239 for i in 0..n {
240 header.push_str(&format!(",y{}", i));
241 }
242 writeln!(writer, "{}", header)?;
243
244 // Write data
245 for (t, y) in self.iter() {
246 let mut row = format!("{:?}", t);
247 for i in 0..n {
248 row.push_str(&format!(",{:?}", y.get(i)));
249 }
250 writeln!(writer, "{}", row)?;
251 }
252
253 // Ensure all data is flushed to disk
254 writer.flush()?;
255
256 Ok(())
257 }
258
259 /// Creates a csv file of the solution using Polars DataFrame.
260 ///
261 /// Note the columns will be named t, y0, y1, ..., yN.
262 ///
263 /// # Arguments
264 /// * `filename` - Name of the file to save the solution.
265 ///
266 /// # Returns
267 /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
268 ///
269 #[cfg(feature = "polars")]
270 pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
271 // Create file and path if it does not exist
272 let path = std::path::Path::new(filename);
273 if !path.exists() {
274 std::fs::create_dir_all(path.parent().unwrap())?;
275 }
276 let mut file = std::fs::File::create(filename)?;
277
278 let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
279 let mut columns = vec![Column::new("t".into(), t)];
280 let n = self.y[0].len();
281 for i in 0..n {
282 let header = format!("y{}", i);
283 columns.push(Column::new(
284 header.into(),
285 self.y
286 .iter()
287 .map(|x| x.get(i).to_f64())
288 .collect::<Vec<f64>>(),
289 ));
290 }
291 let mut df = DataFrame::new(columns)?;
292
293 // Write the dataframe to a csv file
294 CsvWriter::new(&mut file).finish(&mut df)?;
295
296 Ok(())
297 }
298
299 /// Creates a Polars DataFrame of the solution.
300 ///
301 /// Requires feature "polars" to be enabled.
302 ///
303 /// Note that the columns will be named t, y0, y1, ..., yN.
304 ///
305 /// # Returns
306 /// * `Result<DataFrame, PolarsError>` - Result of creating the DataFrame.
307 ///
308 #[cfg(feature = "polars")]
309 pub fn to_polars(&self) -> Result<DataFrame, PolarsError> {
310 let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
311 let mut columns = vec![Column::new("t".into(), t)];
312 let n = self.y[0].len();
313 for i in 0..n {
314 let header = format!("y{}", i);
315 columns.push(Column::new(
316 header.into(),
317 self.y
318 .iter()
319 .map(|x| x.get(i).to_f64())
320 .collect::<Vec<f64>>(),
321 ));
322 }
323
324 DataFrame::new(columns)
325 }
326
327 /// Creates a Polars DataFrame with column names.
328 ///
329 /// Requires feature "polars" to be enabled.
330 ///
331 /// # Arguments
332 /// * `t_name` - Custom name for the time column
333 /// * `y_names` - Custom names for the state variables
334 ///
335 /// # Returns
336 /// * `Result<DataFrame, PolarsError>` - Result of creating the DataFrame.
337 ///
338 #[cfg(feature = "polars")]
339 pub fn to_named_polars(&self, t_name: &str, y_names: Vec<&str>) -> Result<DataFrame, PolarsError> {
340 let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
341 let mut columns = vec![Column::new(t_name.into(), t)];
342
343 let n = self.y[0].len();
344
345 // Validate that we have enough names for all state variables
346 if y_names.len() != n {
347 return Err(PolarsError::ComputeError(
348 format!("Expected {} column names for state variables, but got {}",
349 n, y_names.len()).into()
350 ));
351 }
352
353 for (i, name) in y_names.iter().enumerate() {
354 columns.push(Column::new(
355 (*name).into(),
356 self.y
357 .iter()
358 .map(|x| x.get(i).to_f64())
359 .collect::<Vec<f64>>(),
360 ));
361 }
362
363 DataFrame::new(columns)
364 }
365}