1use kdam::BarExt;
4use serde::{Deserialize, Serialize};
5
6#[cfg(feature = "tracing")]
7use tracing::instrument;
8
9use cellular_raza_concepts::TimeError;
10
/// Event that a [`TimeStepper`] attaches to a time point to signal that simulation results
/// should be saved there.
#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Eq, Serialize)]
pub enum TimeEvent {
    /// Partial save. Every save point scheduled by the `FixedStepsize` constructors uses this
    /// variant.
    // NOTE(review): presumably only a subset of the simulation state is written — confirm
    // against the storage code.
    PartialSave,
    /// Full save; `get_last_full_save` searches the recorded events for this variant.
    // NOTE(review): presumably a complete checkpoint that a simulation can resume from —
    // confirm against the storage code.
    FullSave,
}
20
/// A single time point produced by a [`TimeStepper`].
#[derive(Clone, Debug)]
pub struct NextTimePoint<F> {
    /// Time increment used to reach this time point (always `dt` for `FixedStepsize`).
    pub increment: F,
    /// Absolute simulation time of this point.
    pub time: F,
    /// Number of steps taken since the start; the initial time point has iteration `0`.
    pub iteration: usize,
    /// Save event scheduled for this time point, if any.
    pub event: Option<TimeEvent>,
}
37
/// Produces the sequence of time points of a simulation and marks the points at which
/// results should be saved.
pub trait TimeStepper<F> {
    /// Advances to the next time point.
    ///
    /// Returns `Ok(Some(_))` with the new time point, or `Ok(None)` once the stepper has run
    /// past its final time point (this is how `FixedStepsize` signals the end).
    ///
    /// # Errors
    /// Returns a [`TimeError`] if the next time point cannot be computed (for
    /// `FixedStepsize`: if the iteration count cannot be converted to `F`).
    #[must_use]
    fn advance(&mut self) -> Result<Option<NextTimePoint<F>>, TimeError>;

    /// Returns the time point describing the initial state so that it can be saved, or
    /// `None` if the stepper is no longer at its starting point.
    #[must_use]
    fn save_initial(&self) -> Option<NextTimePoint<F>>;

    /// Returns `(time, iteration)` of the most recent [`TimeEvent::FullSave`] recorded by the
    /// stepper, or `None` if there is none.
    fn get_last_full_save(&self) -> Option<(F, usize)>;

    /// Creates a progress bar for this stepper, labelled with `title` if one is given.
    ///
    /// # Errors
    /// Returns a [`TimeError`] if the bar cannot be built.
    fn initialize_bar(&self, title: Option<&str>) -> Result<kdam::Bar, TimeError>;

    /// Advances `bar`, which was created by `initialize_bar` (for `FixedStepsize`: by one
    /// step per call).
    // NOTE(review): the reason for `#[allow(unused)]` is not visible in this file — confirm
    // whether it is still needed.
    #[allow(unused)]
    fn update_bar(&self, bar: &mut kdam::Bar) -> Result<(), std::io::Error>;
}
62
/// Time stepper with a constant step size `dt` and a precomputed list of save events.
///
/// Instances are created with the `from_partial_save_*` constructors.
// NOTE(review): field order is part of the serde representation — do not reorder fields
// without considering previously serialized states.
#[derive(Clone, Deserialize, Serialize)]
pub struct FixedStepsize<F> {
    /// Constant time increment between two consecutive iterations.
    dt: F,
    /// Start time; iteration `i` corresponds to time `i * dt + t0`.
    t0: F,
    /// All scheduled events as `(time, iteration, event)`.
    all_events: Vec<(F, usize, TimeEvent)>,
    /// Time of the current iteration.
    current_time: F,
    /// Number of steps taken so far (`0` right after construction).
    current_iteration: usize,
    /// Last iteration for which `advance` still returns a time point.
    maximum_iterations: usize,
    /// Set by the constructors to the save event of the starting time point (`None` if the
    /// start is not a save point).
    current_event: Option<TimeEvent>,
    /// Record of passed events as `(time, iteration, event)`; `get_last_full_save` searches
    /// it for the latest `FullSave`. Empty after construction.
    past_events: Vec<(F, usize, TimeEvent)>,
}
86
87impl<F> FixedStepsize<F>
88where
89 F: num::Float + num::ToPrimitive + num::FromPrimitive,
90{
91 #[cfg_attr(feature = "tracing", instrument(skip_all))]
94 pub fn from_partial_save_steps(
95 t0: F,
96 dt: F,
97 n_steps: u64,
98 save_interval: u64,
99 ) -> Result<Self, TimeError> {
100 let max_save_points = n_steps.div_ceil(save_interval);
101 let save_point_to_float = |u: u64| -> Result<F, TimeError> {
102 F::from_u64(save_interval * u).ok_or(TimeError(format!(
103 "Could not convert save_interval={save_interval} to type: {}",
104 std::any::type_name::<F>()
105 )))
106 };
107 let partial_save_points = (0..max_save_points + 1)
108 .map(|n| Ok(t0 + save_point_to_float(n)? * dt))
109 .collect::<Result<_, TimeError>>()?;
110 Self::from_partial_save_points(t0, dt, partial_save_points)
111 }
112
113 #[cfg_attr(feature = "tracing", instrument(skip_all))]
116 pub fn from_partial_save_interval(
117 t0: F,
118 dt: F,
119 t_max: F,
120 save_interval: F,
121 ) -> Result<Self, TimeError> {
122 let mut partial_save_points = vec![];
123 let mut t = t0;
124 while t <= t_max {
125 partial_save_points.push(t);
126 t = t + save_interval;
127 }
128 Self::from_partial_save_points(t0, dt, partial_save_points)
129 }
130
131 #[cfg_attr(feature = "tracing", instrument(skip_all))]
135 pub fn from_partial_save_freq(
136 t0: F,
137 dt: F,
138 t_max: F,
139 save_freq: usize,
140 ) -> Result<Self, TimeError> {
141 let max_iterations = F::to_usize(&((t_max - t0) / dt).round())
142 .ok_or(TimeError(format!("Could not round value to usize")))?;
143 let all_events = (0..max_iterations)
144 .step_by(save_freq)
145 .map(|n| {
146 Ok((
147 t0 + F::from_usize(n * save_freq).ok_or(TimeError(format!(
148 "Could not convert usize {} to type {}",
149 n,
150 std::any::type_name::<F>()
151 )))? * dt,
152 n,
153 TimeEvent::PartialSave,
154 ))
155 })
156 .collect::<Result<Vec<_>, TimeError>>()?;
157 Ok(Self {
158 dt,
159 t0,
160 all_events,
161 current_time: t0,
162 current_iteration: 0,
163 maximum_iterations: max_iterations,
164 current_event: Some(TimeEvent::PartialSave),
165 past_events: Vec::new(),
166 })
167 }
168
169 #[cfg_attr(feature = "tracing", instrument(skip_all))]
173 pub fn from_partial_save_points(
174 t0: F,
175 dt: F,
176 partial_save_points: Vec<F>,
177 ) -> Result<Self, TimeError> {
178 let mut save_points = partial_save_points;
180 save_points.sort_by(|x, y| x.partial_cmp(y).unwrap());
181 if save_points.iter().any(|x| t0 > *x) {
182 return Err(TimeError(
183 "Invalid time configuration! Evaluation time point is before starting time point."
184 .to_owned(),
185 ));
186 }
187 let last_save_point = save_points
188 .clone()
189 .into_iter()
190 .max_by(|x, y| x.partial_cmp(y).unwrap())
191 .ok_or(TimeError(
192 "No savepoints specified. Simulation will not save any results.".to_owned(),
193 ))?;
194 let maximum_iterations =
195 (((last_save_point - t0) / dt).round())
196 .to_usize()
197 .ok_or(TimeError(
198 "An error in casting of float type to usize occurred".to_owned(),
199 ))?;
200 let all_events = save_points
201 .clone()
202 .into_iter()
203 .map(|t_save| {
204 (
205 t_save,
206 ((t_save - t0) / dt).round().to_usize().unwrap(),
207 TimeEvent::PartialSave,
208 )
209 })
210 .collect();
211
212 let current_event = if t0
213 == save_points
214 .into_iter()
215 .min_by(|x, y| x.partial_cmp(y).unwrap())
216 .unwrap()
217 {
218 Some(TimeEvent::PartialSave)
219 } else {
220 None
221 };
222
223 Ok(Self {
224 dt,
225 t0,
226 all_events,
227 current_time: t0,
228 current_iteration: 0,
229 maximum_iterations,
230 current_event,
232 past_events: Vec::new(),
233 })
234 }
235}
236
237impl<F> TimeStepper<F> for FixedStepsize<F>
238where
239 F: num::Float + num::FromPrimitive,
240{
241 #[cfg_attr(feature = "tracing", instrument(skip_all))]
242 fn advance(&mut self) -> Result<Option<NextTimePoint<F>>, TimeError> {
243 self.current_iteration += 1;
244 self.current_time = F::from_usize(self.current_iteration).ok_or(TimeError(
245 "Error when casting from usize to floating point value".to_owned(),
246 ))? * self.dt
247 + self.t0;
248 let event = self
250 .all_events
251 .iter()
252 .filter(|(_, iteration, _)| *iteration == self.current_iteration)
253 .map(|(_, _, event)| event.clone())
254 .last();
255
256 if self.current_iteration <= self.maximum_iterations {
257 Ok(Some(NextTimePoint {
258 increment: self.dt,
259 time: self.current_time,
260 iteration: self.current_iteration,
261 event,
262 }))
263 } else {
264 Ok(None)
265 }
266 }
267
268 #[cfg_attr(feature = "tracing", instrument(skip_all))]
269 fn save_initial(&self) -> Option<NextTimePoint<F>> {
270 if self.current_time == self.t0 {
271 Some(NextTimePoint {
272 increment: self.dt,
273 time: self.current_time,
274 iteration: self.current_iteration,
275 event: Some(TimeEvent::PartialSave),
276 })
277 } else {
278 None
279 }
280 }
281
282 #[cfg_attr(feature = "tracing", instrument(skip_all))]
283 fn get_last_full_save(&self) -> Option<(F, usize)> {
284 self.past_events
285 .clone()
286 .into_iter()
287 .filter(|(_, _, event)| *event == TimeEvent::FullSave)
288 .last()
289 .and_then(|x| Some((x.0, x.1)))
290 }
291
292 #[cfg_attr(feature = "tracing", instrument(skip_all))]
293 fn initialize_bar(&self, title: Option<&str>) -> Result<kdam::Bar, TimeError> {
294 let bar_format = "\
295 {desc}{percentage:3.0}%|{animation}| \
296 {count}/{total} \
297 [{elapsed}, \
298 {rate:.2}{unit}/s{postfix}]";
299 let mut bar = kdam::BarBuilder::default()
300 .total(self.maximum_iterations)
301 .bar_format(bar_format)
302 .dynamic_ncols(true);
303 if let Some(title) = title {
304 bar = bar.desc(title);
305 }
306 Ok(bar.build()?)
307 }
308
309 #[cfg_attr(feature = "tracing", instrument(skip_all))]
310 fn update_bar(&self, bar: &mut kdam::Bar) -> Result<(), std::io::Error> {
311 let _ = bar.update(1)?;
312 Ok(())
313 }
314}
315
#[cfg(test)]
mod test_time_stepper {
    use rand::Rng;
    use rand::SeedableRng;

    use super::*;

    /// Builds a stepper with a seeded random `t0`, `dt` and four random save points.
    fn generate_new_fixed_stepper<F>(rng_seed: u64) -> FixedStepsize<F>
    where
        F: num::Float + From<f32> + num::FromPrimitive,
    {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(rng_seed);
        let t0 = <F as From<_>>::from(rng.random_range(0.0..1.0));
        let dt = <F as From<_>>::from(rng.random_range(0.1..2.0));
        let save_points = vec![
            <F as From<_>>::from(rng.random_range(0.01..1.8)),
            <F as From<_>>::from(rng.random_range(2.01..3.8)),
            <F as From<_>>::from(rng.random_range(4.01..5.8)),
            <F as From<_>>::from(rng.random_range(6.01..7.8)),
        ];
        FixedStepsize::<F>::from_partial_save_points(t0, dt, save_points).unwrap()
    }

    #[test]
    fn initialization() {
        let t0 = 1.0;
        let dt = 0.2;
        let save_points = vec![3.0, 5.0, 11.0, 20.0];
        let time_stepper = FixedStepsize::from_partial_save_points(t0, dt, save_points).unwrap();
        assert_eq!(t0, time_stepper.current_time);
        assert_eq!(0.2, time_stepper.dt);
        assert_eq!(0, time_stepper.current_iteration);
        assert_eq!(None, time_stepper.current_event);
    }

    /// Save points before `t0` must be rejected with an error. Asserting on `Err` (instead of
    /// `#[should_panic]`) ensures an unrelated panic cannot make this test pass.
    #[test]
    fn error_on_save_points_before_t0() {
        let t0 = 10.0;
        let dt = 0.2;
        let save_points = vec![3.0, 5.0, 11.0, 20.0];
        assert!(FixedStepsize::from_partial_save_points(t0, dt, save_points).is_err());
    }

    #[test]
    fn stepping_1() {
        let t0 = 1.0;
        let dt = 0.2;
        let save_points = vec![3.0, 5.0, 11.0, 20.0];
        let mut time_stepper =
            FixedStepsize::from_partial_save_points(t0, dt, save_points).unwrap();

        for i in 1..11 {
            let next = time_stepper.advance().unwrap().unwrap();
            assert_eq!(dt, next.increment);
            assert_eq!(t0 + i as f64 * dt, next.time);
            assert_eq!(i, next.iteration);
            if i == 10 {
                assert_eq!(Some(TimeEvent::PartialSave), next.event);
            } else {
                assert_eq!(None, next.event);
            }
        }
    }

    #[test]
    fn stepping_2() {
        let t0 = 0.0;
        let dt = 0.1;
        let save_points = vec![0.5, 0.7, 0.9, 1.0];
        // Iterations of the save points above. Compare iterations, not float times:
        // e.g. 7.0 * 0.1 != 0.7 in f64, so a time-based check silently skips points.
        let save_iterations = [5, 7, 9, 10];
        let mut time_stepper =
            FixedStepsize::from_partial_save_points(t0, dt, save_points).unwrap();

        assert_eq!(t0, time_stepper.current_time);
        for i in 1..11 {
            let next = time_stepper.advance().unwrap().unwrap();
            assert_eq!(dt, next.increment);
            assert_eq!(t0 + i as f64 * dt, next.time);
            assert_eq!(i, next.iteration);
            let expected = save_iterations
                .contains(&i)
                .then_some(TimeEvent::PartialSave);
            assert_eq!(expected, next.event);
        }
        // Iteration 10 is the last save point, so the stepper must stop afterwards.
        assert!(time_stepper.advance().unwrap().is_none());
    }

    /// Advances a randomly generated stepper and requires it to terminate within 100 steps.
    fn test_stepping(rng_seed: u64) {
        let mut time_stepper = generate_new_fixed_stepper::<f32>(rng_seed);

        for _ in 0..100 {
            if time_stepper.advance().unwrap().is_none() {
                return;
            }
        }
        panic!("The time stepper should have reached the end by now");
    }

    #[test]
    fn stepping_end_0() {
        test_stepping(0);
    }

    #[test]
    fn stepping_end_1() {
        test_stepping(1);
    }

    #[test]
    fn stepping_end_2() {
        test_stepping(2);
    }

    #[test]
    fn stepping_end_3() {
        test_stepping(3);
    }

    #[test]
    fn produce_correct_increments() {
        let t0 = 10.0;
        let dt = 0.1;
        let t_max = 11.0;
        let save_interval = 0.25;
        let mut stepper =
            FixedStepsize::from_partial_save_interval(t0, dt, t_max, save_interval).unwrap();
        for time in std::iter::from_fn(move || stepper.advance().unwrap()) {
            assert_eq!(time.increment, 0.1);
            if time.event.is_some() {
                assert!((time.time - t0) % save_interval < dt);
            }
        }
    }
}