1#[cfg(feature = "polars")]
4use polars::prelude::*;
5
6#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8
9use crate::{
10 stats::{Evals, Steps, Timer},
11 status::Status,
12 traits::{Real, State},
13};
14
/// Recorded `(t, y)` samples together with a status and bookkeeping counters
/// from `crate::stats`.
///
/// `t[i]` and `y[i]` form one sample: the methods on this type (`push`, `pop`,
/// `truncate`) always modify both vectors together. The fields are public,
/// so direct mutation can break that pairing.
// NOTE(review): presumably produced by a solver run (t = independent
// variable, y = state) — inferred from naming, confirm against callers.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Solution<T, Y>
where
    T: Real,
    Y: State<T>,
{
    /// Recorded `t` values; exported as the `t` column by the CSV/polars writers.
    pub t: Vec<T>,

    /// Recorded states; `y[i]` belongs to `t[i]`. Component `k` is exported as column `y{k}`.
    pub y: Vec<Y>,

    /// Status of the run; constructors initialise it to `Status::Uninitialized`.
    pub status: Status<T, Y>,

    /// Evaluation counters; constructors initialise them with `Evals::new()`.
    pub evals: Evals,

    /// Step counters; constructors initialise them with `Steps::new()`.
    pub steps: Steps,

    /// Timing information; constructors initialise it to `Timer::Off`.
    pub timer: Timer<T>,
}
50
51impl<T, Y> Default for Solution<T, Y>
53where
54 T: Real,
55 Y: State<T>,
56{
57 fn default() -> Self {
58 Self::new()
59 }
60}
61
62impl<T, Y> Solution<T, Y>
63where
64 T: Real,
65 Y: State<T>,
66{
67 pub fn new() -> Self {
69 Solution {
70 t: Vec::new(),
71 y: Vec::new(),
72 status: Status::Uninitialized,
73 evals: Evals::new(),
74 steps: Steps::new(),
75 timer: Timer::Off,
76 }
77 }
78
79 pub fn new_with_capacity(capacity: usize) -> Self {
84 Solution {
85 t: Vec::with_capacity(capacity),
86 y: Vec::with_capacity(capacity),
87 status: Status::Uninitialized,
88 evals: Evals::new(),
89 steps: Steps::new(),
90 timer: Timer::Off,
91 }
92 }
93}
94
95impl<T, Y> Solution<T, Y>
97where
98 T: Real,
99 Y: State<T>,
100{
101 pub fn push(&mut self, t: T, y: Y) {
108 self.t.push(t);
109 self.y.push(y);
110 }
111
112 pub fn pop(&mut self) -> Option<(T, Y)> {
118 if self.t.is_empty() || self.y.is_empty() {
119 return None;
120 }
121 let t = self.t.pop().unwrap();
122 let y = self.y.pop().unwrap();
123 Some((t, y))
124 }
125
126 pub fn truncate(&mut self, index: usize) {
132 self.t.truncate(index);
133 self.y.truncate(index);
134 }
135}
136
137impl<T, Y> Solution<T, Y>
139where
140 T: Real,
141 Y: State<T>,
142{
143 pub fn into_tuple(self) -> (Vec<T>, Vec<Y>) {
151 (self.t, self.y)
152 }
153
154 pub fn last(&self) -> Result<(&T, &Y), Box<dyn std::error::Error>> {
160 let t = self.t.last().ok_or("No t steps available")?;
161 let y = self.y.last().ok_or("No y vectors available")?;
162 Ok((t, y))
163 }
164
165 pub fn iter(&self) -> std::iter::Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, Y>> {
172 self.t.iter().zip(self.y.iter())
173 }
174
175 #[cfg(not(feature = "polars"))]
186 pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
187 use std::io::{BufWriter, Write};
188
189 let path = std::path::Path::new(filename);
191 if let Some(parent) = path.parent()
192 && !parent.exists()
193 {
194 std::fs::create_dir_all(parent)?;
195 }
196 let file = std::fs::File::create(filename)?;
197 let mut writer = BufWriter::new(file);
198
199 let n = self.y[0].len();
201
202 let mut header = String::from("t");
204 for i in 0..n {
205 header.push_str(&format!(",y{}", i));
206 }
207 writeln!(writer, "{}", header)?;
208
209 for (t, y) in self.iter() {
211 let mut row = format!("{:?}", t);
212 for i in 0..n {
213 row.push_str(&format!(",{:?}", y.get_component(i)));
214 }
215 writeln!(writer, "{}", row)?;
216 }
217
218 writer.flush()?;
219
220 Ok(())
221 }
222
223 #[cfg(feature = "polars")]
234 pub fn to_csv(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
235 let path = std::path::Path::new(filename);
237 if let Some(parent) = path.parent()
238 && !parent.exists()
239 {
240 std::fs::create_dir_all(parent)?;
241 }
242 let mut file = std::fs::File::create(filename)?;
243
244 let t = self
245 .t
246 .iter()
247 .map(simba::scalar::SupersetOf::<f64>::to_subset_unchecked)
248 .collect::<Vec<f64>>();
249 let mut columns = vec![Column::new("t".into(), t)];
250 let n = self.y[0].len();
251 for i in 0..n {
252 let header = format!("y{}", i);
253 columns.push(Column::new(
254 header.into(),
255 self.y
256 .iter()
257 .map(|y| {
258 simba::scalar::SupersetOf::<f64>::to_subset_unchecked(&y.get_component(i))
259 })
260 .collect::<Vec<f64>>(),
261 ));
262 }
263 let mut df = DataFrame::new(self.t.len(), columns)?;
264
265 CsvWriter::new(&mut file).finish(&mut df)?;
267
268 Ok(())
269 }
270
271 #[cfg(feature = "polars")]
281 pub fn to_polars(&self) -> Result<DataFrame, PolarsError> {
282 let t = self
283 .t
284 .iter()
285 .map(simba::scalar::SupersetOf::<f64>::to_subset_unchecked)
286 .collect::<Vec<f64>>();
287 let mut columns = vec![Column::new("t".into(), t)];
288 let n = self.y[0].len();
289 for i in 0..n {
290 let header = format!("y{}", i);
291 columns.push(Column::new(
292 header.into(),
293 self.y
294 .iter()
295 .map(|y| {
296 simba::scalar::SupersetOf::<f64>::to_subset_unchecked(&y.get_component(i))
297 })
298 .collect::<Vec<f64>>(),
299 ));
300 }
301
302 DataFrame::new(self.t.len(), columns)
303 }
304
305 #[cfg(feature = "polars")]
317 pub fn to_named_polars(
318 &self,
319 t_name: &str,
320 y_names: Vec<&str>,
321 ) -> Result<DataFrame, PolarsError> {
322 let t = self
323 .t
324 .iter()
325 .map(simba::scalar::SupersetOf::<f64>::to_subset_unchecked)
326 .collect::<Vec<f64>>();
327 let mut columns = vec![Column::new(t_name.into(), t)];
328
329 let n = self.y[0].len();
330
331 if y_names.len() != n {
333 return Err(PolarsError::ComputeError(
334 format!(
335 "Expected {} column names for state variables, but got {}",
336 n,
337 y_names.len()
338 )
339 .into(),
340 ));
341 }
342
343 for (i, name) in y_names.iter().enumerate() {
344 columns.push(Column::new(
345 (*name).into(),
346 self.y
347 .iter()
348 .map(|y| {
349 simba::scalar::SupersetOf::<f64>::to_subset_unchecked(&y.get_component(i))
350 })
351 .collect::<Vec<f64>>(),
352 ));
353 }
354
355 DataFrame::new(self.t.len(), columns)
356 }
357}
358
// Unit tests for the container API, using `f64` for both `T` and `Y`.
#[cfg(test)]
mod tests {
    use super::*;

    // `into_tuple` hands back the pushed `t` and `y` values in insertion order.
    #[test]
    fn test_into_tuple() {
        let mut sol: Solution<f64, f64> = Solution::new();
        sol.push(0.0, 10.0);
        sol.push(1.0, 20.0);

        let (t, y) = sol.into_tuple();
        assert_eq!(t, vec![0.0, 1.0]);
        assert_eq!(y, vec![10.0, 20.0]);
    }

    // Walks one solution through construction, push, last, pop, iter and
    // truncate. The steps share state, so their order matters.
    #[test]
    fn test_solution_lifecycle() {
        // Both constructors start empty; the capacity variant pre-reserves.
        let sol_new: Solution<f64, f64> = Solution::new();
        assert!(sol_new.t.is_empty());
        assert!(sol_new.y.is_empty());

        let sol_cap: Solution<f64, f64> = Solution::new_with_capacity(10);
        assert!(sol_cap.t.is_empty());
        assert!(sol_cap.y.is_empty());
        assert!(sol_cap.t.capacity() >= 10);
        assert!(sol_cap.y.capacity() >= 10);

        // A single push appends to both vectors.
        let mut sol = sol_new;
        sol.push(2.0, 30.0);
        assert_eq!(sol.t.len(), 1);
        assert_eq!(sol.y.len(), 1);
        assert_eq!(sol.t[0], 2.0);
        assert_eq!(sol.y[0], 30.0);

        // `last` borrows the most recent sample without removing it.
        let last = sol.last().unwrap();
        assert_eq!(*last.0, 2.0);
        assert_eq!(*last.1, 30.0);

        // `pop` removes the sample and leaves both vectors empty.
        let popped = sol.pop();
        assert_eq!(popped, Some((2.0, 30.0)));
        assert!(sol.t.is_empty());
        assert!(sol.y.is_empty());

        // On an empty solution, `last` errors and `pop` returns None.
        assert!(sol.last().is_err());

        assert_eq!(sol.pop(), None);

        // `iter` yields the pairs in insertion order.
        sol.push(0.0, 10.0);
        sol.push(1.0, 20.0);
        sol.push(2.0, 30.0);

        let expected = vec![(0.0, 10.0), (1.0, 20.0), (2.0, 30.0)];
        let actual: Vec<(f64, f64)> = sol.iter().map(|(&t, &y)| (t, y)).collect();
        assert_eq!(actual, expected);

        // `truncate` keeps only the leading samples of both vectors.
        sol.truncate(1);
        assert_eq!(sol.t.len(), 1);
        assert_eq!(sol.y.len(), 1);
        assert_eq!(sol.t[0], 0.0);
        assert_eq!(sol.y[0], 10.0);
    }
}