differential_equations/solution.rs
1//! Solution of differential equations
2
3use crate::{
4 Status,
5 traits::{Real, State, CallBackData},
6};
7use std::time::Instant;
8
9/// Timer for tracking solution time
10#[derive(Debug, Clone)]
11pub enum Timer<T: Real> {
12 Off,
13 Running(Instant),
14 Completed(T),
15}
16
17impl<T: Real> Timer<T> {
18 /// Starts the timer
19 pub fn start(&mut self) {
20 *self = Timer::Running(Instant::now());
21 }
22
23 /// Returns the elapsed time in seconds
24 pub fn elapsed(&self) -> T {
25 match self {
26 Timer::Off => T::zero(),
27 Timer::Running(start_time) => T::from_f64(start_time.elapsed().as_secs_f64()).unwrap(),
28 Timer::Completed(t) => *t,
29 }
30 }
31
32 /// Complete the running timer and convert it to a completed state
33 pub fn complete(&mut self) {
34 match self {
35 Timer::Off => {}
36 Timer::Running(start_time) => {
37 *self = Timer::Completed(T::from_f64(start_time.elapsed().as_secs_f64()).unwrap());
38 }
39 Timer::Completed(_) => {}
40 }
41 }
42}
43
44#[cfg(feature = "polars")]
45use polars::prelude::*;
46
47/// Solution of returned by differential equation solvers
48///
49/// # Fields
50/// * `y` - Outputted dependent variable points.
51/// * `t` - Outputted independent variable points.
52/// * `status` - Status of the solver.
53/// * `evals` - Number of function evaluations.
54/// * `steps` - Number of steps.
55/// * `rejected_steps` - Number of rejected steps.
56/// * `accepted_steps` - Number of accepted steps.
57/// * `timer` - Timer for tracking solution time.
58///
59#[derive(Debug, Clone)]
60pub struct Solution<T, V, D>
61where
62 T: Real,
63 V: State<T>,
64 D: CallBackData,
65{
66 /// Outputted independent variable points.
67 pub t: Vec<T>,
68
69 /// Outputted dependent variable points.
70 pub y: Vec<V>,
71
72 /// Status of the solver.
73 pub status: Status<T, V, D>,
74
75 /// Number of function evaluations.
76 pub evals: usize,
77
78 /// Total number of steps taken by the solver.
79 pub steps: usize,
80
81 /// Number of rejected steps where the solution step-size had to be reduced.
82 pub rejected_steps: usize,
83
84 /// Number of accepted steps where the solution moved closer to tf.
85 pub accepted_steps: usize,
86
87 /// Timer for tracking solution time - Running during solving, Completed after finalization
88 pub timer: Timer<T>,
89}
90
91// Initial methods for the solution
92impl<T, V, D> Solution<T, V, D>
93where
94 T: Real,
95 V: State<T>,
96 D: CallBackData,
97{
98 /// Creates a new Solution object.
99 pub fn new() -> Self {
100 Solution {
101 t: Vec::with_capacity(100),
102 y: Vec::with_capacity(100),
103 status: Status::Uninitialized,
104 evals: 0,
105 steps: 0,
106 rejected_steps: 0,
107 accepted_steps: 0,
108 timer: Timer::Off,
109 }
110 }
111}
112
113// Methods used during solving
114impl<T, V, D> Solution<T, V, D>
115where
116 T: Real,
117 V: State<T>,
118 D: CallBackData,
119{
120 /// Puhes a new point to the solution, e.g. t and y vecs.
121 ///
122 /// # Arguments
123 /// * `t` - The time point.
124 /// * `y` - The state vector.
125 ///
126 pub fn push(&mut self, t: T, y: V) {
127 self.t.push(t);
128 self.y.push(y);
129 }
130
131 /// Pops the last point from the solution, e.g. t and y vecs.
132 ///
133 /// # Returns
134 /// * `Option<(T, SMatrix<T, R, C>)>` - The last point in the solution.
135 ///
136 pub fn pop(&mut self) -> Option<(T, V)> {
137 if self.t.is_empty() || self.y.is_empty() {
138 return None;
139 }
140 let t = self.t.pop().unwrap();
141 let y = self.y.pop().unwrap();
142 Some((t, y))
143 }
144
145 /// Truncates the solution's (t, y) points to the given index.
146 ///
147 /// # Arguments
148 /// * `index` - The index to truncate to.
149 ///
150 pub fn truncate(&mut self, index: usize) {
151 self.t.truncate(index);
152 self.y.truncate(index);
153 }
154}
155
156// Post-processing methods for the solution
157impl<T, V, D> Solution<T, V, D>
158where
159 T: Real,
160 V: State<T>,
161 D: CallBackData,
162{
163 /// Simplifies the Solution into a tuple of vectors in form (t, y).
164 /// By doing so, the Solution will be consumed and the status,
165 /// evals, steps, rejected_steps, and accepted_steps will be discarded.
166 ///
167 /// # Returns
168 /// * `(Vec<T>, Vec<V)` - Tuple of time and state vectors.
169 ///
170 pub fn into_tuple(self) -> (Vec<T>, Vec<V>) {
171 (self.t, self.y)
172 }
173
174 /// Returns the last accepted step of the solution in form (t, y).
175 ///
176 /// # Returns
177 /// * `Result<(T, V), Box<dyn std::error::Error>>` - Result of time and state vector.
178 ///
179 pub fn last(&self) -> Result<(&T, &V), Box<dyn std::error::Error>> {
180 let t = self.t.last().ok_or("No t steps available")?;
181 let y = self.y.last().ok_or("No y vectors available")?;
182 Ok((t, y))
183 }
184
185 /// Returns an iterator over the solution.
186 ///
187 /// # Returns
188 /// * `std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, V>>` - An iterator
189 /// yielding (t, y) tuples.
190 ///
191 pub fn iter(
192 &self,
193 ) -> std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, V>> {
194 self.t.iter().zip(self.y.iter())
195 }
196
197 /// Creates a CSV file of the solution using standard library functionality.
198 ///
199 /// Note the columns will be named t, y0, y1, ..., yN.
200 ///
201 /// # Arguments
202 /// * `filename` - Name of the file to save the solution.
203 ///
204 /// # Returns
205 /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
206 ///
207 #[cfg(not(feature = "polars"))]
208 pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
209 use std::io::{BufWriter, Write};
210
211 // Create file and path if it does not exist
212 let path = std::path::Path::new(filename);
213 if let Some(parent) = path.parent() {
214 if !parent.exists() {
215 std::fs::create_dir_all(parent)?;
216 }
217 }
218 let file = std::fs::File::create(filename)?;
219 let mut writer = BufWriter::new(file);
220
221 // Length of state vector
222 let n = self.y[0].len();
223
224 // Write header
225 let mut header = String::from("t");
226 for i in 0..n {
227 header.push_str(&format!(",y{}", i));
228 }
229 writeln!(writer, "{}", header)?;
230
231 // Write data
232 for (t, y) in self.iter() {
233 let mut row = format!("{:?}", t);
234 for i in 0..n {
235 row.push_str(&format!(",{:?}", y.get(i)));
236 }
237 writeln!(writer, "{}", row)?;
238 }
239
240 // Ensure all data is flushed to disk
241 writer.flush()?;
242
243 Ok(())
244 }
245
246 /// Creates a csv file of the solution using Polars DataFrame.
247 ///
248 /// Note the columns will be named t, y0, y1, ..., yN.
249 ///
250 /// # Arguments
251 /// * `filename` - Name of the file to save the solution.
252 ///
253 /// # Returns
254 /// * `Result<(), Box<dyn std::error::Error>>` - Result of writing the file.
255 ///
256 #[cfg(feature = "polars")]
257 pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
258 // Create file and path if it does not exist
259 let path = std::path::Path::new(filename);
260 if !path.exists() {
261 std::fs::create_dir_all(path.parent().unwrap())?;
262 }
263 let mut file = std::fs::File::create(filename)?;
264
265 let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
266 let mut columns = vec![Column::new("t".into(), t)];
267 let n = self.y[0].len();
268 for i in 0..n {
269 let header = format!("y{}", i);
270 columns.push(Column::new(
271 header.into(),
272 self.y.iter().map(|x| x.get(i).to_f64()).collect::<Vec<f64>>(),
273 ));
274 }
275 let mut df = DataFrame::new(columns)?;
276
277 // Write the dataframe to a csv file
278 CsvWriter::new(&mut file).finish(&mut df)?;
279
280 Ok(())
281 }
282
283 /// Creates a Polars DataFrame of the solution.
284 ///
285 /// Note that the columns will be named t, y0, y1, ..., yN.
286 ///
287 /// # Returns
288 /// * `Result<DataFrame, PolarsError>` - Result of creating the DataFrame.
289 ///
290 #[cfg(feature = "polars")]
291 pub fn to_polars(&self) -> Result<DataFrame, PolarsError> {
292 let t = self.t.iter().map(|x| x.to_f64()).collect::<Vec<f64>>();
293 let mut columns = vec![Column::new("t".into(), t)];
294 let n = self.y[0].len();
295 for i in 0..n {
296 let header = format!("y{}", i);
297 columns.push(Column::new(
298 header.into(),
299 self.y.iter().map(|x| x.get(i).to_f64()).collect::<Vec<f64>>(),
300 ));
301 }
302
303 DataFrame::new(columns)
304 }
305}