use crate::DType;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
use crate::integrate::sensitivity::traits::{Checkpoint, CheckpointStrategy};
/// Stores checkpoints (a time `t` plus a tensor `y`, see `add_checkpoint`)
/// together with the schedule of times at which checkpoints are meant to be
/// taken over `t_span`.
///
/// NOTE(review): presumably the stored tensors are replayed during a backward
/// (adjoint) sensitivity pass — confirm against callers.
#[derive(Debug, Clone)]
pub struct CheckpointManager<R: Runtime<DType = DType>> {
/// Stored checkpoints, in the order they were passed to `add_checkpoint`.
checkpoints: Vec<Checkpoint<R>>,
/// Number of checkpoints requested in `new`.
// Only written in `new` and never read by this type's methods (derived
// `Debug`/`Clone` impls do not count as reads), hence the lint allowance.
#[allow(dead_code)]
n_checkpoints: usize,
/// Strategy that was used to build `checkpoint_times`.
// Stored for reference only; not read after construction, hence the allowance.
#[allow(dead_code)]
strategy: CheckpointStrategy,
/// The `[t0, tf]` interval the schedule was built for.
// Stored for reference only; not read after construction, hence the allowance.
#[allow(dead_code)]
t_span: [f64; 2],
/// Scheduled checkpoint times spanning `t_span`, computed once in `new`.
checkpoint_times: Vec<f64>,
}
impl<R: Runtime<DType = DType>> CheckpointManager<R> {
    /// Growth rate `k` of the exponential profile used by
    /// [`CheckpointStrategy::Logarithmic`]; larger values pack more interior
    /// checkpoints close to `t0`.
    const LOG_SPACING_RATE: f64 = 3.0;

    /// Creates a manager with no stored checkpoints and precomputes its schedule.
    ///
    /// * `n_checkpoints` – number of scheduled times *including* both endpoints
    ///   of `t_span`; values below 2 still schedule the two endpoints.
    /// * `strategy` – how the interior times are distributed.
    /// * `t_span` – `[t0, tf]`. No ordering is enforced; with `tf < t0` the
    ///   schedule simply runs from `t0` down to `tf`.
    pub fn new(n_checkpoints: usize, strategy: CheckpointStrategy, t_span: [f64; 2]) -> Self {
        let checkpoint_times = Self::compute_checkpoint_times(n_checkpoints, strategy, t_span);
        Self {
            checkpoints: Vec::with_capacity(n_checkpoints + 2),
            n_checkpoints,
            strategy,
            t_span,
            checkpoint_times,
        }
    }

    /// Builds the checkpoint schedule: `t0`, the interior times, then `tf`.
    ///
    /// With `m = n_interior + 1` intervals and `i = 1..=n_interior`:
    /// * `Uniform` / `Adaptive`: `t_i = t0 + dt * i / m`;
    /// * `Logarithmic`: `t_i = t0 + dt * (e^{k·i/m} - 1) / (e^k - 1)`, an
    ///   exponential profile whose steps grow from `t0` towards `tf`.
    ///
    /// The endpoints are copied from `t_span` verbatim rather than computed, so
    /// the schedule always hits `t0` and `tf` exactly; `t0 + (tf - t0)` is not
    /// guaranteed to round back to `tf`.
    fn compute_checkpoint_times(
        n_checkpoints: usize,
        strategy: CheckpointStrategy,
        t_span: [f64; 2],
    ) -> Vec<f64> {
        let [t0, tf] = t_span;
        let dt = tf - t0;
        // Both endpoints are always part of the schedule, so requests for fewer
        // than two checkpoints still produce `[t0, tf]`.
        let n_interior = n_checkpoints.saturating_sub(2);
        // `n_interior` interior points split the span into this many intervals.
        let n_intervals = (n_interior + 1) as f64;
        let k = Self::LOG_SPACING_RATE;
        let log_norm = k.exp() - 1.0;

        let interior = (1..=n_interior).map(|i| {
            let i = i as f64;
            match strategy {
                // `Adaptive` is scheduled exactly like `Uniform` here.
                CheckpointStrategy::Uniform | CheckpointStrategy::Adaptive => {
                    t0 + dt * i / n_intervals
                }
                CheckpointStrategy::Logarithmic => {
                    let alpha = i / n_intervals;
                    t0 + dt * ((alpha * k).exp() - 1.0) / log_norm
                }
            }
        });

        let mut times = Vec::with_capacity(n_interior + 2);
        times.push(t0);
        times.extend(interior);
        times.push(tf);
        times
    }

    /// Scheduled checkpoint times, starting at `t0` and ending at `tf`.
    pub fn checkpoint_times(&self) -> &[f64] {
        &self.checkpoint_times
    }

    /// Appends a checkpoint holding time `t` and tensor `y`.
    ///
    /// No ordering is checked here, but [`find_interval`](Self::find_interval)
    /// assumes checkpoints are appended in ascending time order.
    pub fn add_checkpoint(&mut self, t: f64, y: Tensor<R>) {
        self.checkpoints.push(Checkpoint::new(t, y));
    }

    /// All stored checkpoints, in insertion order.
    pub fn checkpoints(&self) -> &[Checkpoint<R>] {
        &self.checkpoints
    }

    /// Number of stored checkpoints (not the number of scheduled times).
    pub fn len(&self) -> usize {
        self.checkpoints.len()
    }

    /// `true` when no checkpoint has been stored yet.
    pub fn is_empty(&self) -> bool {
        self.checkpoints.is_empty()
    }

    /// Returns the indices `(before, after)` of the stored checkpoints bracketing `t`.
    ///
    /// `before` is the last checkpoint of the leading run whose time is `<= t`,
    /// and `after` is the next index, clamped to the final checkpoint. Hence:
    /// * `t` earlier than the first checkpoint yields `(0, 1)` (`(0, 0)` when
    ///   only one checkpoint is stored);
    /// * `t` at or beyond the last checkpoint yields `(last, last)`.
    ///
    /// The scan stops at the first checkpoint later than `t`, so results are
    /// only meaningful when checkpoints were added in ascending time order.
    /// Returns `None` when no checkpoint is stored.
    pub fn find_interval(&self, t: f64) -> Option<(usize, usize)> {
        let last = self.checkpoints.len().checked_sub(1)?;
        // Length of the leading run of checkpoints at or before `t`.
        let n_not_after = self
            .checkpoints
            .iter()
            .take_while(|ck| ck.t <= t)
            .count();
        let before_idx = n_not_after.saturating_sub(1);
        let after_idx = (before_idx + 1).min(last);
        Some((before_idx, after_idx))
    }

    /// Stored checkpoint at `index`, or `None` if out of range.
    pub fn get(&self, index: usize) -> Option<&Checkpoint<R>> {
        self.checkpoints.get(index)
    }

    /// Iterates over stored checkpoints from the most recently added to the first.
    pub fn iter_reverse(&self) -> impl Iterator<Item = &Checkpoint<R>> {
        self.checkpoints.iter().rev()
    }

    /// `true` if `t` lies strictly within `tol` of any scheduled checkpoint time.
    ///
    /// Linear in the number of scheduled times.
    pub fn should_checkpoint(&self, t: f64, tol: f64) -> bool {
        self.checkpoint_times.iter().any(|&tc| (t - tc).abs() < tol)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use numr::runtime::cpu::{CpuDevice, CpuRuntime};

    /// Absolute tolerance for comparing computed checkpoint times.
    const TOL: f64 = 1e-10;

    /// One-element tensor holding `t`, used as a stand-in state.
    fn make_state(t: f64) -> Tensor<CpuRuntime> {
        let device = CpuDevice::new();
        Tensor::<CpuRuntime>::from_slice(&[t], &[1], &device)
    }

    /// Asserts that `actual` matches `expected` element-wise within `TOL`.
    fn assert_times_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(actual.len(), expected.len(), "times={:?}", actual);
        for (a, e) in actual.iter().zip(expected) {
            assert!(
                (a - e).abs() < TOL,
                "got {:?}, expected {:?}",
                actual,
                expected
            );
        }
    }

    /// Uniform manager over `[0, 1]` with a checkpoint stored at each of `times`.
    fn manager_with_checkpoints(times: &[f64]) -> CheckpointManager<CpuRuntime> {
        let mut manager =
            CheckpointManager::<CpuRuntime>::new(5, CheckpointStrategy::Uniform, [0.0, 1.0]);
        for &t in times {
            manager.add_checkpoint(t, make_state(t));
        }
        manager
    }

    #[test]
    fn test_uniform_checkpoint_times() {
        let manager =
            CheckpointManager::<CpuRuntime>::new(5, CheckpointStrategy::Uniform, [0.0, 1.0]);
        assert_times_close(manager.checkpoint_times(), &[0.0, 0.25, 0.5, 0.75, 1.0]);
    }

    #[test]
    fn test_uniform_checkpoint_times_reversed_span() {
        let manager =
            CheckpointManager::<CpuRuntime>::new(5, CheckpointStrategy::Uniform, [1.0, 0.0]);
        assert_times_close(manager.checkpoint_times(), &[1.0, 0.75, 0.5, 0.25, 0.0]);
    }

    #[test]
    fn test_adaptive_matches_uniform() {
        let uniform =
            CheckpointManager::<CpuRuntime>::new(7, CheckpointStrategy::Uniform, [0.0, 2.0]);
        let adaptive =
            CheckpointManager::<CpuRuntime>::new(7, CheckpointStrategy::Adaptive, [0.0, 2.0]);
        assert_eq!(adaptive.checkpoint_times(), uniform.checkpoint_times());
    }

    #[test]
    fn test_logarithmic_checkpoint_times() {
        let manager =
            CheckpointManager::<CpuRuntime>::new(5, CheckpointStrategy::Logarithmic, [0.0, 1.0]);
        let times = manager.checkpoint_times();
        assert_eq!(times.len(), 5);
        assert!((times[0] - 0.0).abs() < TOL);
        assert!((times[4] - 1.0).abs() < TOL);
        // Steps grow monotonically, i.e. checkpoints cluster near t0.
        let steps: Vec<f64> = times.windows(2).map(|w| w[1] - w[0]).collect();
        for pair in steps.windows(2) {
            assert!(pair[0] < pair[1], "steps should grow: {:?}", steps);
        }
    }

    #[test]
    fn test_fewer_than_two_checkpoints_still_span_interval() {
        let strategies = [
            CheckpointStrategy::Uniform,
            CheckpointStrategy::Logarithmic,
            CheckpointStrategy::Adaptive,
        ];
        for strategy in strategies {
            for n in 0..=2 {
                let manager = CheckpointManager::<CpuRuntime>::new(n, strategy, [0.0, 1.0]);
                assert_times_close(manager.checkpoint_times(), &[0.0, 1.0]);
            }
        }
    }

    #[test]
    fn test_add_and_find_checkpoint() {
        let manager = manager_with_checkpoints(&[0.0, 0.5, 1.0]);
        assert_eq!(manager.len(), 3);
        assert!(!manager.is_empty());
        assert_eq!(manager.find_interval(0.3), Some((0, 1)));
        assert_eq!(manager.find_interval(0.7), Some((1, 2)));
    }

    #[test]
    fn test_find_interval_edge_cases() {
        let empty = manager_with_checkpoints(&[]);
        assert!(empty.is_empty());
        assert_eq!(empty.find_interval(0.5), None);

        let single = manager_with_checkpoints(&[0.5]);
        assert_eq!(single.find_interval(0.0), Some((0, 0)));
        assert_eq!(single.find_interval(1.0), Some((0, 0)));

        let manager = manager_with_checkpoints(&[0.0, 0.5, 1.0]);
        // Before the first checkpoint: falls back to the first interval.
        assert_eq!(manager.find_interval(-0.1), Some((0, 1)));
        // Exactly on an interior checkpoint: that checkpoint starts the interval.
        assert_eq!(manager.find_interval(0.5), Some((1, 2)));
        // At or past the last checkpoint: clamped to the final index.
        assert_eq!(manager.find_interval(1.0), Some((2, 2)));
        assert_eq!(manager.find_interval(2.0), Some((2, 2)));
    }

    #[test]
    fn test_get_and_iter_reverse() {
        let manager = manager_with_checkpoints(&[0.0, 0.5, 1.0]);
        assert_eq!(manager.get(1).map(|ck| ck.t), Some(0.5));
        assert!(manager.get(3).is_none());
        let reversed: Vec<f64> = manager.iter_reverse().map(|ck| ck.t).collect();
        assert_eq!(reversed, vec![1.0, 0.5, 0.0]);
    }

    #[test]
    fn test_should_checkpoint() {
        let manager =
            CheckpointManager::<CpuRuntime>::new(5, CheckpointStrategy::Uniform, [0.0, 1.0]);
        let tol = 1e-8;
        for t in [0.0, 0.25, 0.5, 0.75, 1.0] {
            assert!(manager.should_checkpoint(t, tol), "t={}", t);
        }
        assert!(!manager.should_checkpoint(0.3, tol));
        assert!(!manager.should_checkpoint(1.5, tol));
    }
}