cellular_raza_core/
time.rs

1//! Controls how the simulation time is advanced
2
3use kdam::BarExt;
4use serde::{Deserialize, Serialize};
5
6#[cfg(feature = "tracing")]
7use tracing::instrument;
8
9use cellular_raza_concepts::TimeError;
10
11/// A [TimeEvent] describes that a certain action is to be executed after the next iteration step.
12#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Eq, Serialize)]
13pub enum TimeEvent {
14    /// Saves a partial simulation run which is suitable for data readout but not full recovery of the
15    /// simulation for restarting.
16    PartialSave,
17    /// Performs a complete save from which the simulation should be able to be recovered.
18    FullSave,
19}
20
21/// Represents the next time point which is returned by the [TimeStepper::advance] method.
22///
23/// It is important to note that the absolute time value $t$ is not meant to be used
24/// in updating steps but rather for saving results and annotating them correctly.
25/// Library authors are advised to keep this in mind.
26#[derive(Clone, Debug)]
27pub struct NextTimePoint<F> {
28    /// Time increment $dt$
29    pub increment: F,
30    /// Time value $t$
31    pub time: F,
32    /// Current iteration
33    pub iteration: usize,
34    /// Event at this iteration, or None
35    pub event: Option<TimeEvent>,
36}
37
38/// Increments time of the simulation
39///
40/// In the future we hope to add adaptive steppers depending on a specified accuracy function.
41pub trait TimeStepper<F> {
42    /// Advances the time stepper to the next time point. Also returns if there is an event
43    /// scheduled to take place and the next time value and iteration number
44    #[must_use]
45    fn advance(&mut self) -> Result<Option<NextTimePoint<F>>, TimeError>;
46
47    /// Indicates if the initial values should be stored
48    #[must_use]
49    fn save_initial(&self) -> Option<NextTimePoint<F>>;
50
51    /// Retrieved the last point at which the simulation was fully recovered.
52    /// This might be helpful in the future when error handling is more mature and able to recover.
53    fn get_last_full_save(&self) -> Option<(F, usize)>;
54
55    /// Creates a bar that tracks the simulation progress
56    fn initialize_bar(&self, title: Option<&str>) -> Result<kdam::Bar, TimeError>;
57
58    /// Update a given bar to show the current simulation state
59    #[allow(unused)]
60    fn update_bar(&self, bar: &mut kdam::Bar) -> Result<(), std::io::Error>;
61}
62
63/// Time stepping with a fixed time length
64///
65/// This time-stepper increments the time variable by the same length.
66/// ```
67/// # use cellular_raza_core::time::FixedStepsize;
68/// let t0 = 1.0;
69/// let dt = 0.2;
70/// let partial_save_points = vec![3.0, 5.0, 11.0, 20.0];
71/// let time_stepper = FixedStepsize::from_partial_save_points(t0, dt, partial_save_points).unwrap();
72/// ```
73#[derive(Clone, Deserialize, Serialize)]
74pub struct FixedStepsize<F> {
75    // The stepsize which was fixed
76    dt: F,
77    t0: F,
78    // An ordered set of time points to store every value at which we should evaluate
79    all_events: Vec<(F, usize, TimeEvent)>,
80    current_time: F,
81    current_iteration: usize,
82    maximum_iterations: usize,
83    current_event: Option<TimeEvent>,
84    past_events: Vec<(F, usize, TimeEvent)>,
85}
86
87impl<F> FixedStepsize<F>
88where
89    F: num::Float + num::ToPrimitive + num::FromPrimitive,
90{
91    /// Construct the stepper from initial time, increment,
92    /// number of steps and save interval
93    #[cfg_attr(feature = "tracing", instrument(skip_all))]
94    pub fn from_partial_save_steps(
95        t0: F,
96        dt: F,
97        n_steps: u64,
98        save_interval: u64,
99    ) -> Result<Self, TimeError> {
100        let max_save_points = n_steps.div_ceil(save_interval);
101        let save_point_to_float = |u: u64| -> Result<F, TimeError> {
102            F::from_u64(save_interval * u).ok_or(TimeError(format!(
103                "Could not convert save_interval={save_interval} to type: {}",
104                std::any::type_name::<F>()
105            )))
106        };
107        let partial_save_points = (0..max_save_points + 1)
108            .map(|n| Ok(t0 + save_point_to_float(n)? * dt))
109            .collect::<Result<_, TimeError>>()?;
110        Self::from_partial_save_points(t0, dt, partial_save_points)
111    }
112
113    /// Similar to [Self::from_partial_save_points] but specify the time step between every save
114    /// point together with the integration step.
115    #[cfg_attr(feature = "tracing", instrument(skip_all))]
116    pub fn from_partial_save_interval(
117        t0: F,
118        dt: F,
119        t_max: F,
120        save_interval: F,
121    ) -> Result<Self, TimeError> {
122        let mut partial_save_points = vec![];
123        let mut t = t0;
124        while t <= t_max {
125            partial_save_points.push(t);
126            t = t + save_interval;
127        }
128        Self::from_partial_save_points(t0, dt, partial_save_points)
129    }
130
131    /// Similar to [Self::from_partial_save_interval] but specify a multiple of the time increment
132    /// instead of a floating point value.
133    /// This method is preferred over the one previously mentioned.
134    #[cfg_attr(feature = "tracing", instrument(skip_all))]
135    pub fn from_partial_save_freq(
136        t0: F,
137        dt: F,
138        t_max: F,
139        save_freq: usize,
140    ) -> Result<Self, TimeError> {
141        let max_iterations = F::to_usize(&((t_max - t0) / dt).round())
142            .ok_or(TimeError(format!("Could not round value to usize")))?;
143        let all_events = (0..max_iterations)
144            .step_by(save_freq)
145            .map(|n| {
146                Ok((
147                    t0 + F::from_usize(n * save_freq).ok_or(TimeError(format!(
148                        "Could not convert usize {} to type {}",
149                        n,
150                        std::any::type_name::<F>()
151                    )))? * dt,
152                    n,
153                    TimeEvent::PartialSave,
154                ))
155            })
156            .collect::<Result<Vec<_>, TimeError>>()?;
157        Ok(Self {
158            dt,
159            t0,
160            all_events,
161            current_time: t0,
162            current_iteration: 0,
163            maximum_iterations: max_iterations,
164            current_event: Some(TimeEvent::PartialSave),
165            past_events: Vec::new(),
166        })
167    }
168
169    /// Simple function to construct the stepper from an initial time point, the time increment and
170    /// the time points at which the simulation should be saved. Notice that these saves do not
171    /// cover [FullSaves](TimeEvent::FullSave) but only [PartialSaves](TimeEvent::PartialSave).
172    #[cfg_attr(feature = "tracing", instrument(skip_all))]
173    pub fn from_partial_save_points(
174        t0: F,
175        dt: F,
176        partial_save_points: Vec<F>,
177    ) -> Result<Self, TimeError> {
178        // Sort the save points
179        let mut save_points = partial_save_points;
180        save_points.sort_by(|x, y| x.partial_cmp(y).unwrap());
181        if save_points.iter().any(|x| t0 > *x) {
182            return Err(TimeError(
183                "Invalid time configuration! Evaluation time point is before starting time point."
184                    .to_owned(),
185            ));
186        }
187        let last_save_point = save_points
188            .clone()
189            .into_iter()
190            .max_by(|x, y| x.partial_cmp(y).unwrap())
191            .ok_or(TimeError(
192                "No savepoints specified. Simulation will not save any results.".to_owned(),
193            ))?;
194        let maximum_iterations =
195            (((last_save_point - t0) / dt).round())
196                .to_usize()
197                .ok_or(TimeError(
198                    "An error in casting of float type to usize occurred".to_owned(),
199                ))?;
200        let all_events = save_points
201            .clone()
202            .into_iter()
203            .map(|t_save| {
204                (
205                    t_save,
206                    ((t_save - t0) / dt).round().to_usize().unwrap(),
207                    TimeEvent::PartialSave,
208                )
209            })
210            .collect();
211
212        let current_event = if t0
213            == save_points
214                .into_iter()
215                .min_by(|x, y| x.partial_cmp(y).unwrap())
216                .unwrap()
217        {
218            Some(TimeEvent::PartialSave)
219        } else {
220            None
221        };
222
223        Ok(Self {
224            dt,
225            t0,
226            all_events,
227            current_time: t0,
228            current_iteration: 0,
229            maximum_iterations,
230            // TODO check this again
231            current_event,
232            past_events: Vec::new(),
233        })
234    }
235}
236
237impl<F> TimeStepper<F> for FixedStepsize<F>
238where
239    F: num::Float + num::FromPrimitive,
240{
241    #[cfg_attr(feature = "tracing", instrument(skip_all))]
242    fn advance(&mut self) -> Result<Option<NextTimePoint<F>>, TimeError> {
243        self.current_iteration += 1;
244        self.current_time = F::from_usize(self.current_iteration).ok_or(TimeError(
245            "Error when casting from usize to floating point value".to_owned(),
246        ))? * self.dt
247            + self.t0;
248        // TODO Check if a current event should take place
249        let event = self
250            .all_events
251            .iter()
252            .filter(|(_, iteration, _)| *iteration == self.current_iteration)
253            .map(|(_, _, event)| event.clone())
254            .last();
255
256        if self.current_iteration <= self.maximum_iterations {
257            Ok(Some(NextTimePoint {
258                increment: self.dt,
259                time: self.current_time,
260                iteration: self.current_iteration,
261                event,
262            }))
263        } else {
264            Ok(None)
265        }
266    }
267
268    #[cfg_attr(feature = "tracing", instrument(skip_all))]
269    fn save_initial(&self) -> Option<NextTimePoint<F>> {
270        if self.current_time == self.t0 {
271            Some(NextTimePoint {
272                increment: self.dt,
273                time: self.current_time,
274                iteration: self.current_iteration,
275                event: Some(TimeEvent::PartialSave),
276            })
277        } else {
278            None
279        }
280    }
281
282    #[cfg_attr(feature = "tracing", instrument(skip_all))]
283    fn get_last_full_save(&self) -> Option<(F, usize)> {
284        self.past_events
285            .clone()
286            .into_iter()
287            .filter(|(_, _, event)| *event == TimeEvent::FullSave)
288            .last()
289            .and_then(|x| Some((x.0, x.1)))
290    }
291
292    #[cfg_attr(feature = "tracing", instrument(skip_all))]
293    fn initialize_bar(&self, title: Option<&str>) -> Result<kdam::Bar, TimeError> {
294        let bar_format = "\
295        {desc}{percentage:3.0}%|{animation}| \
296        {count}/{total} \
297        [{elapsed}, \
298        {rate:.2}{unit}/s{postfix}]";
299        let mut bar = kdam::BarBuilder::default()
300            .total(self.maximum_iterations)
301            .bar_format(bar_format)
302            .dynamic_ncols(true);
303        if let Some(title) = title {
304            bar = bar.desc(title);
305        }
306        Ok(bar.build()?)
307    }
308
309    #[cfg_attr(feature = "tracing", instrument(skip_all))]
310    fn update_bar(&self, bar: &mut kdam::Bar) -> Result<(), std::io::Error> {
311        let _ = bar.update(1)?;
312        Ok(())
313    }
314}
315
#[cfg(test)]
mod test_time_stepper {
    use rand::Rng;
    use rand::SeedableRng;

    use super::*;

    /// Creates a stepper with random `t0`, `dt` and four save points, each drawn from its own
    /// interval. All values are sampled as `f32` from a seeded `ChaCha8Rng`.
    fn generate_new_fixed_stepper<F>(rng_seed: u64) -> FixedStepsize<F>
    where
        F: num::Float + From<f32> + num::FromPrimitive,
    {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(rng_seed);
        let mut sample =
            |low: f32, high: f32| -> F { <F as From<f32>>::from(rng.random_range(low..high)) };
        // The order of the draws matters for reproducibility: t0, dt, then the save points.
        let t0 = sample(0.0, 1.0);
        let dt = sample(0.1, 2.0);
        let save_points = vec![
            sample(0.01, 1.8),
            sample(2.01, 3.8),
            sample(4.01, 5.8),
            sample(6.01, 7.8),
        ];
        FixedStepsize::<F>::from_partial_save_points(t0, dt, save_points).unwrap()
    }

    #[test]
    fn initialization() {
        let (t0, dt) = (1.0, 0.2);
        let stepper =
            FixedStepsize::from_partial_save_points(t0, dt, vec![3.0, 5.0, 11.0, 20.0]).unwrap();
        assert_eq!(stepper.current_time, t0);
        assert_eq!(stepper.dt, 0.2);
        assert_eq!(stepper.current_iteration, 0);
        assert_eq!(stepper.current_event, None);
    }

    #[test]
    #[should_panic]
    fn panic_wrong_save_points() {
        // The start time lies after the first two save points, so construction has to fail.
        let (t0, dt) = (10.0, 0.2);
        let save_points = vec![3.0, 5.0, 11.0, 20.0];
        FixedStepsize::from_partial_save_points(t0, dt, save_points).unwrap();
    }

    #[test]
    fn stepping_1() {
        let (t0, dt) = (1.0, 0.2);
        let mut stepper =
            FixedStepsize::from_partial_save_points(t0, dt, vec![3.0, 5.0, 11.0, 20.0]).unwrap();

        for i in 1..=10 {
            let next = stepper.advance().unwrap().unwrap();
            assert_eq!(next.increment, dt);
            assert_eq!(next.time, t0 + i as f64 * dt);
            assert_eq!(next.iteration, i);
            // Only the save point at t = 3.0 (iteration 10) lies within the first 10 steps.
            let expected_event = if i == 10 {
                Some(TimeEvent::PartialSave)
            } else {
                None
            };
            assert_eq!(next.event, expected_event);
        }
    }

    #[test]
    fn stepping_2() {
        let (t0, dt) = (0.0, 0.1);
        let save_points = vec![0.5, 0.7, 0.9, 1.0];
        let mut stepper =
            FixedStepsize::from_partial_save_points(t0, dt, save_points.clone()).unwrap();

        assert_eq!(stepper.current_time, t0);
        for i in 1..=10 {
            let next = stepper.advance().unwrap().unwrap();
            assert_eq!(next.increment, dt);
            assert_eq!(next.time, t0 + i as f64 * dt);
            assert_eq!(next.iteration, i);
            let is_save_point = save_points.iter().any(|&t_save| t_save == next.time);
            if is_save_point {
                assert_eq!(next.event, Some(TimeEvent::PartialSave));
            }
        }
    }

    fn test_stepping(rng_seed: u64) {
        let mut stepper = generate_new_fixed_stepper::<f32>(rng_seed);
        // `any` stops at the first `None`, i.e. as soon as the stepper has finished.
        let finished = (0..100).any(|_| stepper.advance().unwrap().is_none());
        assert!(
            finished,
            "The time stepper should have reached the end by now"
        );
    }

    #[test]
    fn stepping_end_0() {
        test_stepping(0);
    }

    #[test]
    fn stepping_end_1() {
        test_stepping(1);
    }

    #[test]
    fn stepping_end_2() {
        test_stepping(2);
    }

    #[test]
    fn stepping_end_3() {
        test_stepping(3);
    }

    #[test]
    fn produce_correct_increments() {
        let (t0, dt, t_max, save_interval) = (10.0, 0.1, 11.0, 0.25);
        let mut stepper =
            FixedStepsize::from_partial_save_interval(t0, dt, t_max, save_interval).unwrap();
        let all_points: Vec<_> = std::iter::from_fn(move || stepper.advance().unwrap()).collect();
        for point in all_points {
            assert_eq!(point.increment, 0.1);
            if point.event.is_some() {
                assert!((point.time - t0) % save_interval < dt);
            }
        }
    }
}