differential_equations/solution.rs
1//! Solution of differential equations
2
3use crate::{
4 Status,
5 traits::{Real, CallBackData},
6};
7use nalgebra::SMatrix;
8use std::time::Instant;
9
10/// Timer for tracking solution time
11#[derive(Debug, Clone)]
12pub enum Timer<T: Real> {
13 Off,
14 Running(Instant),
15 Completed(T),
16}
17
18impl<T: Real> Timer<T> {
19 /// Starts the timer
20 pub fn start(&mut self) {
21 *self = Timer::Running(Instant::now());
22 }
23
24 /// Returns the elapsed time in seconds
25 pub fn elapsed(&self) -> T {
26 match self {
27 Timer::Off => T::zero(),
28 Timer::Running(start_time) => T::from_f64(start_time.elapsed().as_secs_f64()).unwrap(),
29 Timer::Completed(t) => *t,
30 }
31 }
32
33 /// Complete the running timer and convert it to a completed state
34 pub fn complete(&mut self) {
35 match self {
36 Timer::Off => {}
37 Timer::Running(start_time) => {
38 *self = Timer::Completed(T::from_f64(start_time.elapsed().as_secs_f64()).unwrap());
39 }
40 Timer::Completed(_) => {}
41 }
42 }
43}
44
45#[cfg(feature = "polars")]
46use polars::prelude::*;
47
48/// Solution of returned by differential equation solvers
49///
50/// # Fields
51/// * `y` - Outputted dependent variable points.
52/// * `t` - Outputted independent variable points.
53/// * `status` - Status of the solver.
54/// * `evals` - Number of function evaluations.
55/// * `steps` - Number of steps.
56/// * `rejected_steps` - Number of rejected steps.
57/// * `accepted_steps` - Number of accepted steps.
58/// * `timer` - Timer for tracking solution time.
59///
60#[derive(Debug, Clone)]
61pub struct Solution<T, const R: usize, const C: usize, D>
62where
63 T: Real,
64 D: CallBackData,
65{
66 /// Outputted independent variable points.
67 pub t: Vec<T>,
68
69 /// Outputted dependent variable points.
70 pub y: Vec<SMatrix<T, R, C>>,
71
72 /// Status of the solver.
73 pub status: Status<T, R, C, D>,
74
75 /// Number of function evaluations.
76 pub evals: usize,
77
78 /// Total number of steps taken by the solver.
79 pub steps: usize,
80
81 /// Number of rejected steps where the solution step-size had to be reduced.
82 pub rejected_steps: usize,
83
84 /// Number of accepted steps where the solution moved closer to tf.
85 pub accepted_steps: usize,
86
87 /// Timer for tracking solution time - Running during solving, Completed after finalization
88 pub timer: Timer<T>,
89}
90
91// Initial methods for the solution
92impl<T, const R: usize, const C: usize, D> Solution<T, R, C, D>
93where
94 T: Real,
95 D: CallBackData,
96{
97 /// Creates a new Solution object.
98 pub fn new() -> Self {
99 Solution {
100 t: Vec::with_capacity(100),
101 y: Vec::with_capacity(100),
102 status: Status::Uninitialized,
103 evals: 0,
104 steps: 0,
105 rejected_steps: 0,
106 accepted_steps: 0,
107 timer: Timer::Off,
108 }
109 }
110}
111
112// Methods used during solving
113impl<T, const R: usize, const C: usize, D> Solution<T, R, C, D>
114where
115 T: Real,
116 D: CallBackData,
117{
118 /// Puhes a new point to the solution, e.g. t and y vecs.
119 ///
120 /// # Arguments
121 /// * `t` - The time point.
122 /// * `y` - The state vector.
123 ///
124 pub fn push(&mut self, t: T, y: SMatrix<T, R, C>) {
125 self.t.push(t);
126 self.y.push(y);
127 }
128
129 /// Pops the last point from the solution, e.g. t and y vecs.
130 ///
131 /// # Returns
132 /// * `Option<(T, SMatrix<T, R, C>)>` - The last point in the solution.
133 ///
134 pub fn pop(&mut self) -> Option<(T, SMatrix<T, R, C>)> {
135 if self.t.is_empty() || self.y.is_empty() {
136 return None;
137 }
138 let t = self.t.pop().unwrap();
139 let y = self.y.pop().unwrap();
140 Some((t, y))
141 }
142
143 /// Truncates the solution's (t, y) points to the given index.
144 ///
145 /// # Arguments
146 /// * `index` - The index to truncate to.
147 ///
148 pub fn truncate(&mut self, index: usize) {
149 self.t.truncate(index);
150 self.y.truncate(index);
151 }
152}
153
154// Post-processing methods for the solution
155impl<T, const R: usize, const C: usize, D> Solution<T, R, C, D>
156where
157 T: Real,
158 D: CallBackData,
159{
160 /// Simplifies the Solution into a tuple of vectors in form (t, y).
161 /// By doing so, the Solution will be consumed and the status,
162 /// evals, steps, rejected_steps, and accepted_steps will be discarded.
163 ///
164 /// # Returns
165 /// * `(Vec<T>, Vec<V)` - Tuple of time and state vectors.
166 ///
167 pub fn into_tuple(self) -> (Vec<T>, Vec<SMatrix<T, R, C>>) {
168 (self.t, self.y)
169 }
170
171 /// Returns the last accepted step of the solution in form (t, y).
172 ///
173 /// # Returns
174 /// * `Result<(T, V), Box<dyn std::error::Error>>` - Result of time and state vector.
175 ///
176 pub fn last(&self) -> Result<(&T, &SMatrix<T, R, C>), Box<dyn std::error::Error>> {
177 let t = self.t.last().ok_or("No t steps available")?;
178 let y = self.y.last().ok_or("No y vectors available")?;
179 Ok((t, y))
180 }
181
182 /// Returns an iterator over the solution.
183 ///
184 /// # Returns
185 /// * `std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, V>>` - An iterator
186 /// yielding (t, y) tuples.
187 ///
188 pub fn iter(
189 &self,
190 ) -> std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, SMatrix<T, R, C>>> {
191 self.t.iter().zip(self.y.iter())
192 }
193
194 /// Creates a CSV file of the solution using standard library functionality.
195 ///
196 /// Note the columns will be named t, y0, y1, ..., yN.
197 ///
198 /// # Arguments
199 /// * `filename` - Name of the file to save the solution.
200 ///
201 /// # Returns
202 /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
203 ///
204 #[cfg(not(feature = "polars"))]
205 pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
206 use std::io::{BufWriter, Write};
207
208 // Create file and path if it does not exist
209 let path = std::path::Path::new(filename);
210 if let Some(parent) = path.parent() {
211 if !parent.exists() {
212 std::fs::create_dir_all(parent)?;
213 }
214 }
215 let file = std::fs::File::create(filename)?;
216 let mut writer = BufWriter::new(file);
217
218 // Length of state vector
219 let n = self.y[0].len();
220
221 // Write header
222 let mut header = String::from("t");
223 for i in 0..n {
224 header.push_str(&format!(",y{}", i));
225 }
226 writeln!(writer, "{}", header)?;
227
228 // Write data
229 for (t, y) in self.iter() {
230 let mut row = format!("{:?}", t);
231 for i in 0..n {
232 row.push_str(&format!(",{:?}", y[i]));
233 }
234 writeln!(writer, "{}", row)?;
235 }
236
237 // Ensure all data is flushed to disk
238 writer.flush()?;
239
240 Ok(())
241 }
242
243 /// Creates a csv file of the solution using Polars DataFrame.
244 ///
245 /// Note the columns will be named t, y0, y1, ..., yN.
246 ///
247 /// # Arguments
248 /// * `filename` - Name of the file to save the solution.
249 ///
250 /// # Returns
251 /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
252 ///
253 #[cfg(feature = "polars")]
254 pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
255 // Create file and path if it does not exist
256 let path = std::path::Path::new(filename);
257 if !path.exists() {
258 std::fs::create_dir_all(path.parent().unwrap())?;
259 }
260 let mut file = std::fs::File::create(filename)?;
261
262 let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
263 let mut columns = vec![Column::new("t".into(), t)];
264 let n = self.y[0].len();
265 for i in 0..n {
266 let header = format!("y{}", i);
267 columns.push(Column::new(
268 header.into(),
269 self.y.iter().map(|x| x[i].to_f64()).collect::<Vec<f64>>(),
270 ));
271 }
272 let mut df = DataFrame::new(columns)?;
273
274 // Write the dataframe to a csv file
275 CsvWriter::new(&mut file).finish(&mut df)?;
276
277 Ok(())
278 }
279
280 /// Creates a Polars DataFrame of the solution.
281 ///
282 /// Note that the columns will be named t, y0, y1, ..., yN.
283 ///
284 /// # Returns
285 /// * `Result<DataFrame, PolarsError>` - Result of creating the DataFrame.
286 ///
287 #[cfg(feature = "polars")]
288 pub fn to_polars(&self) -> Result<DataFrame, PolarsError> {
289 let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
290 let mut columns = vec![Column::new("t".into(), t)];
291 let n = self.y[0].len();
292 for i in 0..n {
293 let header = format!("y{}", i);
294 columns.push(Column::new(
295 header.into(),
296 self.y.iter().map(|x| x[i].to_f64()).collect::<Vec<f64>>(),
297 ));
298 }
299
300 DataFrame::new(columns)
301 }
302}