use numra_core::Scalar;
use crate::error::SolverError;
/// Emits interpolated solution values at a caller-supplied list of output times.
///
/// The emitter keeps a cursor into `points` and only ever moves it forward, so
/// each requested time is written at most once even when it coincides with the
/// boundary shared by two consecutive steps.
pub struct TEvalEmitter<'a, S: Scalar> {
    /// Requested output times, consumed in slice order.
    points: &'a [S],
    /// Position in `points` of the next time that has not been emitted yet.
    idx: usize,
    /// Direction of integration: a strictly positive value selects the
    /// ascending interval test, anything else selects the descending one.
    direction: S,
}

impl<'a, S: Scalar> TEvalEmitter<'a, S> {
    /// Builds an emitter whose cursor starts at the first entry of `points`.
    pub fn new(points: &'a [S], direction: S) -> Self {
        Self {
            points,
            idx: 0,
            direction,
        }
    }

    /// Reports whether every requested output time has been emitted.
    pub fn is_done(&self) -> bool {
        self.points.len() <= self.idx
    }

    /// Writes every pending output time covered by the step `t_old -> t_new`.
    ///
    /// For each pending point inside the closed interval between `t_old` and
    /// `t_new`, the time is appended to `t_out` and `y_old.len()` interpolated
    /// components are appended to `y_out` (flat layout, one group per time).
    /// Values come from the cubic Hermite interpolant defined by the endpoint
    /// states `y_old` / `y_new` and the endpoint slopes `dy_old` / `dy_new`.
    ///
    /// A zero-length step (`t_new == t_old`) performs no interpolation: pending
    /// points equal to `t_old` are emitted with `y_old` copied verbatim.
    ///
    /// Scanning stops at the first pending point that lies outside the step;
    /// the cursor is left on that point, so it is re-examined on the next call.
    ///
    /// # Panics
    ///
    /// Panics if `dy_old`, `y_new` or `dy_new` is shorter than `y_old` while a
    /// point is being interpolated.
    pub fn emit_step(
        &mut self,
        t_old: S,
        y_old: &[S],
        dy_old: &[S],
        t_new: S,
        y_new: &[S],
        dy_new: &[S],
        t_out: &mut Vec<S>,
        y_out: &mut Vec<S>,
    ) {
        let h = t_new - t_old;

        // Degenerate step: dividing by `h` is impossible, so only exact hits
        // on `t_old` can be served, straight from the known state.
        if h == S::ZERO {
            while let Some(&pending) = self.points.get(self.idx) {
                if pending != t_old {
                    break;
                }
                t_out.push(t_old);
                y_out.extend_from_slice(y_old);
                self.idx += 1;
            }
            return;
        }

        let dim = y_old.len();
        // Order the step endpoints once so the membership test below is the
        // same closed-interval check for both integration directions.
        let (lower, upper) = if self.direction > S::ZERO {
            (t_old, t_new)
        } else {
            (t_new, t_old)
        };

        while let Some(&t_q) = self.points.get(self.idx) {
            let inside = lower <= t_q && t_q <= upper;
            if !inside {
                break;
            }

            // Normalised position of the query time within the step.
            let theta = (t_q - t_old) / h;
            let (h00, h10, h01, h11) = Self::hermite_basis(theta);
            // The slope weights carry the step length; they do not depend on
            // the component index, so compute them once per output time.
            let slope_old_weight = h10 * h;
            let slope_new_weight = h11 * h;

            t_out.push(t_q);
            y_out.extend((0..dim).map(|i| {
                h00 * y_old[i]
                    + slope_old_weight * dy_old[i]
                    + h01 * y_new[i]
                    + slope_new_weight * dy_new[i]
            }));
            self.idx += 1;
        }
    }

    /// Evaluates the four cubic Hermite basis polynomials at `theta`.
    ///
    /// Returned in the order (old value, old slope, new value, new slope).
    fn hermite_basis(theta: S) -> (S, S, S, S) {
        let three = S::from_f64(3.0);
        let theta2 = theta * theta;
        let theta3 = theta2 * theta;
        (
            S::TWO * theta3 - three * theta2 + S::ONE,
            theta3 - S::TWO * theta2 + theta,
            -S::TWO * theta3 + three * theta2,
            theta3 - theta2,
        )
    }
}
/// Checks that `grid` is a usable list of output times for an integration
/// running from `t0` to `tf`.
///
/// An empty grid is always accepted. Otherwise the grid must be
///
/// * monotone in the direction of integration (ascending when `tf >= t0`,
///   descending otherwise; repeated values are allowed), and
/// * contained in the closed interval spanned by `t0` and `tf`.
///
/// Both checks allow a slack of `16 * EPSILON * (|tf - t0| + 1)` so that
/// round-off in how the caller built the grid does not cause a rejection.
///
/// # Errors
///
/// Returns [`SolverError::Other`] with a descriptive message when the grid is
/// out of order, or when any entry lies outside the interval. An entry that is
/// NaN (or a NaN bound) is reported as lying outside the interval.
pub fn validate_grid<S: Scalar>(grid: &[S], t0: S, tf: S) -> Result<(), SolverError> {
    if grid.is_empty() {
        return Ok(());
    }

    let direction = if tf >= t0 { S::ONE } else { -S::ONE };
    let (lo, hi) = if direction > S::ZERO {
        (t0, tf)
    } else {
        (tf, t0)
    };
    let span = (tf - t0).abs();
    let tol = S::EPSILON * S::from_f64(16.0) * (span + S::ONE);

    for window in grid.windows(2) {
        // Signed gap measured along the direction of integration; only a gap
        // that goes backwards by more than the tolerance is an ordering error.
        let d = window[1] - window[0];
        if d * direction < -tol {
            return Err(SolverError::Other(
                "t_eval must be sorted in the direction of integration".into(),
            ));
        }
    }

    let lower = lo - tol;
    let upper = hi + tol;
    for &t in grid {
        // Deliberately a positive containment test: every comparison with NaN
        // is false, so a NaN entry (or NaN bound) is rejected here instead of
        // slipping through a `t < lower || t > upper` style check.
        let in_range = t >= lower && t <= upper;
        if !in_range {
            return Err(SolverError::Other(format!(
                "t_eval contains {} which lies outside [t0, tf] = [{}, {}]",
                t.to_f64(),
                t0.to_f64(),
                tf.to_f64()
            )));
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing interpolated values.
    const TOL: f64 = 1e-15;

    /// Pushes one step of a one-dimensional problem through `emitter`.
    ///
    /// `from` and `to` are `(t, y, dy)` triples describing the start and the
    /// end of the step; results are appended to `t_out` and `y_out`.
    fn feed(
        emitter: &mut TEvalEmitter<'_, f64>,
        from: (f64, f64, f64),
        to: (f64, f64, f64),
        t_out: &mut Vec<f64>,
        y_out: &mut Vec<f64>,
    ) {
        let (t_old, y_old, dy_old) = from;
        let (t_new, y_new, dy_new) = to;
        emitter.emit_step(
            t_old,
            &[y_old],
            &[dy_old],
            t_new,
            &[y_new],
            &[dy_new],
            t_out,
            y_out,
        );
    }

    #[test]
    fn validate_rejects_out_of_range() {
        // One grid overshoots the end of the interval, the other starts before it.
        let grids: [[f64; 3]; 2] = [[0.0, 1.0, 5.0], [-1.0, 0.0, 1.0]];
        for grid in &grids {
            assert!(validate_grid::<f64>(grid, 0.0, 4.0).is_err());
        }
    }

    #[test]
    fn validate_rejects_unsorted() {
        // (grid, t0, tf): one forward and one backward integration.
        let cases: [([f64; 3], f64, f64); 2] =
            [([0.0, 2.0, 1.0], 0.0, 4.0), ([4.0, 1.0, 2.0], 4.0, 0.0)];
        for &(grid, t0, tf) in &cases {
            assert!(validate_grid::<f64>(&grid, t0, tf).is_err());
        }
    }

    #[test]
    fn validate_accepts_descending_for_backward() {
        let descending: [f64; 5] = [4.0, 3.0, 2.0, 1.0, 0.0];
        assert!(validate_grid::<f64>(&descending, 4.0, 0.0).is_ok());
    }

    #[test]
    fn emit_reproduces_endpoints_exactly() {
        // y(t) = t on [0, 1]: the interpolant must return t at every point.
        let grid: [f64; 3] = [0.0, 0.5, 1.0];
        let mut emitter = TEvalEmitter::new(&grid, 1.0_f64);
        let (mut t_out, mut y_out) = (Vec::new(), Vec::new());

        feed(
            &mut emitter,
            (0.0, 0.0, 1.0),
            (1.0, 1.0, 1.0),
            &mut t_out,
            &mut y_out,
        );

        assert_eq!(t_out, vec![0.0, 0.5, 1.0]);
        let expected: [f64; 3] = [0.0, 0.5, 1.0];
        for (i, want) in expected.iter().enumerate() {
            assert!((y_out[i] - want).abs() < TOL);
        }
        assert!(emitter.is_done());
    }

    #[test]
    fn emit_advances_across_step_boundary_without_double_count() {
        // t = 1 ends the first step and starts the second; it must appear once.
        let grid: [f64; 2] = [1.0, 2.0];
        let mut emitter = TEvalEmitter::new(&grid, 1.0_f64);
        let (mut t_out, mut y_out) = (Vec::new(), Vec::new());

        feed(
            &mut emitter,
            (0.0, 0.0, 1.0),
            (1.0, 1.0, 1.0),
            &mut t_out,
            &mut y_out,
        );
        assert_eq!(t_out.len(), 1);

        feed(
            &mut emitter,
            (1.0, 1.0, 1.0),
            (2.0, 2.0, 1.0),
            &mut t_out,
            &mut y_out,
        );
        assert_eq!(t_out, vec![1.0, 2.0]);
    }

    #[test]
    fn emit_handles_backward_direction() {
        // Same line y(t) = t, traversed from t = 1 down to t = 0.
        let grid: [f64; 3] = [1.0, 0.5, 0.0];
        let mut emitter = TEvalEmitter::new(&grid, -1.0_f64);
        let (mut t_out, mut y_out) = (Vec::new(), Vec::new());

        feed(
            &mut emitter,
            (1.0, 1.0, 1.0),
            (0.0, 0.0, 1.0),
            &mut t_out,
            &mut y_out,
        );

        assert_eq!(t_out, vec![1.0, 0.5, 0.0]);
        assert!((y_out[1] - 0.5).abs() < TOL);
    }
}