use numra_core::Scalar;
/// Interpolation data for a single integration step, covering the times
/// between `t_start` and `t_end`.
#[derive(Clone, Debug)]
pub struct DenseSegment<S: Scalar> {
/// Time at which the step began.
pub t_start: S,
/// Time at which the step ended. For backward integration
/// `DenseOutput::find_segment` expects this to be below `t_start`.
pub t_end: S,
/// Flat interpolation coefficients. Their layout is defined by the
/// interpolant that built them; `DoPri5Interpolant::build_coefficients`
/// stores `6 * dim` values.
pub coeffs: Vec<S>,
/// Number of state components represented by this segment.
pub dim: usize,
}
impl<S: Scalar> DenseSegment<S> {
    /// Creates a segment from `t_start` to `t_end` holding `coeffs` for a
    /// `dim`-component state.
    pub fn new(t_start: S, t_end: S, coeffs: Vec<S>, dim: usize) -> Self {
        Self {
            t_start,
            t_end,
            coeffs,
            dim,
        }
    }

    /// Returns `true` if `t` lies in the closed interval spanned by the segment.
    ///
    /// Works for forward steps (`t_start <= t_end`) and backward steps
    /// (`t_start > t_end`) alike. A NaN `t` or NaN endpoint yields `false`.
    #[inline]
    pub fn contains(&self, t: S) -> bool {
        // Normalise the endpoints so backward segments are not rejected.
        let (lo, hi) = if self.t_start <= self.t_end {
            (self.t_start, self.t_end)
        } else {
            (self.t_end, self.t_start)
        };
        t >= lo && t <= hi
    }

    /// Signed step size `t_end - t_start`; negative for backward steps.
    #[inline]
    pub fn h(&self) -> S {
        self.t_end - self.t_start
    }

    /// Normalised position of `t` in the segment.
    ///
    /// The result is `0` at `t_start` and `1` at `t_end`, in either direction,
    /// because `h` carries the sign. A zero-length segment makes this divide
    /// by zero.
    #[inline]
    pub fn theta(&self, t: S) -> S {
        (t - self.t_start) / self.h()
    }
}
/// Piecewise dense output for a whole integration: the ordered list of
/// [`DenseSegment`]s produced step by step, plus the integration direction.
#[derive(Clone, Debug)]
pub struct DenseOutput<S: Scalar> {
/// Segments in the order they were appended with `add_segment`.
segments: Vec<DenseSegment<S>>,
// Stored by `new`, but no method visible in this file reads it, so the
// dead-code lint is silenced rather than dropping the field.
#[allow(dead_code)]
dim: usize,
/// Integration direction. `find_segment` treats a positive value as forward
/// in time and any other value as backward.
direction: S,
}
impl<S: Scalar> Default for DenseOutput<S> {
fn default() -> Self {
Self::new(0, S::ONE)
}
}
impl<S: Scalar> DenseOutput<S> {
    /// Creates an empty dense output for a `dim`-component system.
    ///
    /// `direction` is the integration direction. A positive value means time
    /// increases from one segment to the next; any other value is treated as
    /// backward integration.
    pub fn new(dim: usize, direction: S) -> Self {
        Self {
            segments: Vec::new(),
            dim,
            direction,
        }
    }

    /// Appends a segment.
    ///
    /// Segments must be added in integration order for
    /// [`DenseOutput::find_segment`] to locate them; this is not checked here.
    pub fn add_segment(&mut self, segment: DenseSegment<S>) {
        self.segments.push(segment);
    }

    /// Number of stored segments.
    pub fn len(&self) -> usize {
        self.segments.len()
    }

    /// Returns `true` if no segments have been added.
    pub fn is_empty(&self) -> bool {
        self.segments.is_empty()
    }

    /// Returns the covered span `(first.t_start, last.t_end)`, or `None` when empty.
    ///
    /// For backward integration the first element is the later time.
    pub fn tspan(&self) -> Option<(S, S)> {
        let first = self.segments.first()?;
        let last = self.segments.last()?;
        Some((first.t_start, last.t_end))
    }

    /// Finds the segment whose time interval contains `t`.
    ///
    /// Supports both forward and backward integration, as selected by the
    /// `direction` given to [`DenseOutput::new`].
    ///
    /// Returns `None` in these cases:
    /// - `t` lies outside the recorded span;
    /// - `t` falls in a gap between segments;
    /// - `t` is NaN.
    ///
    /// At a time shared by two adjacent segments, the earlier one (in
    /// integration order) is returned.
    pub fn find_segment(&self, t: S) -> Option<&DenseSegment<S>> {
        let forward = self.direction > S::ZERO;
        // Segments are stored in integration order, so their end times are
        // monotone along `direction`. Binary-search for the first segment
        // whose end has reached `t`.
        let idx = self.segments.partition_point(|seg| {
            if forward {
                seg.t_end < t
            } else {
                seg.t_end > t
            }
        });
        let seg = self.segments.get(idx)?;
        // Orientation-aware membership test: backward segments have
        // `t_end <= t_start`.
        let inside = if forward {
            seg.t_start <= t && t <= seg.t_end
        } else {
            seg.t_end <= t && t <= seg.t_start
        };
        inside.then_some(seg)
    }

    /// Removes all segments, keeping `dim` and `direction`.
    pub fn clear(&mut self) {
        self.segments.clear();
    }

    /// All stored segments, in the order they were added.
    pub fn segments(&self) -> &[DenseSegment<S>] {
        &self.segments
    }
}
/// Evaluates the interpolating polynomial stored in a [`DenseSegment`].
///
/// Each implementation defines how it reads `DenseSegment::coeffs`.
pub trait DenseInterpolant<S: Scalar> {
/// Writes the interpolated state at time `t` into `y_out`.
///
/// NOTE(review): `t` is presumably meant to lie within `segment`. The
/// `DoPri5Interpolant` implementation does not check this and extrapolates
/// for times outside it.
fn interpolate(&self, segment: &DenseSegment<S>, t: S, y_out: &mut [S]);
/// Writes the time derivative of the interpolated state at `t` into `dydt_out`.
fn interpolate_derivative(&self, segment: &DenseSegment<S>, t: S, dydt_out: &mut [S]);
}
/// Dense-output interpolant for the Dormand–Prince 5(4) ("DoPri5") method.
///
/// Coefficients are built by `DoPri5Interpolant::build_coefficients` and
/// evaluated through its [`DenseInterpolant`] implementation. The type holds no
/// state of its own.
#[derive(Clone, Debug, Default)]
pub struct DoPri5Interpolant;
impl<S: Scalar> DenseInterpolant<S> for DoPri5Interpolant {
    /// Evaluates `y(t) = y0 + h·θ·(d0 + θ·d1 + θ²·d2 + θ³·d3 + θ⁴·d4)` for each
    /// of the segment's `dim` components, with `θ = (t - t_start) / h`.
    ///
    /// `segment.coeffs` is read as six consecutive blocks of `dim` values:
    /// `[y0 | d0 | d1 | d2 | d3 | d4]`. `t` is not range-checked.
    ///
    /// # Panics
    /// If `y_out` has fewer than `dim` entries or `coeffs` fewer than `6 * dim`.
    fn interpolate(&self, segment: &DenseSegment<S>, t: S, y_out: &mut [S]) {
        let n = segment.dim;
        let coeffs = &segment.coeffs;
        let theta = segment.theta(t);
        // h·θ is the same for every component.
        let scale = segment.h() * theta;
        for comp in 0..n {
            // Horner's scheme, starting from the innermost (d4) block and
            // folding outward through d3, d2, d1, d0.
            let poly = (1..5).rev().fold(coeffs[5 * n + comp], |acc, block| {
                coeffs[block * n + comp] + theta * acc
            });
            y_out[comp] = coeffs[comp] + scale * poly;
        }
    }

    /// Evaluates `dy/dt = d0 + 2θ·d1 + 3θ²·d2 + 4θ³·d3 + 5θ⁴·d4`.
    ///
    /// This is the exact derivative of [`DoPri5Interpolant::interpolate`]: the
    /// `h` factor there cancels `dθ/dt = 1/h`.
    ///
    /// # Panics
    /// If `dydt_out` has fewer than `dim` entries or `coeffs` fewer than `6 * dim`.
    fn interpolate_derivative(&self, segment: &DenseSegment<S>, t: S, dydt_out: &mut [S]) {
        let n = segment.dim;
        let coeffs = &segment.coeffs;
        let theta = segment.theta(t);
        let theta2 = theta * theta;
        let theta3 = theta2 * theta;
        let theta4 = theta3 * theta;
        // (p+1)·θ^p weights for d1..d4; identical for every component.
        let weights = [
            S::from_f64(2.0) * theta,
            S::from_f64(3.0) * theta2,
            S::from_f64(4.0) * theta3,
            S::from_f64(5.0) * theta4,
        ];
        for comp in 0..n {
            // Accumulate left to right, starting from d0 (block 1); weight p
            // pairs with block p + 2.
            dydt_out[comp] = weights
                .iter()
                .enumerate()
                .fold(coeffs[n + comp], |acc, (p, &w)| {
                    acc + w * coeffs[(p + 2) * n + comp]
                });
        }
    }
}
impl DoPri5Interpolant {
    /// Builds the dense-output coefficients for one Dormand–Prince 5(4) step.
    ///
    /// # Arguments
    /// * `y0`, `y1`: state at the start and end of the step (at least `dim`
    ///   entries each).
    /// * `k`: the seven stage derivatives `k1..k7`, stored contiguously as
    ///   `7 * dim` values. `k7` is the derivative at the new point.
    /// * `h`: signed step size; must be non-zero.
    /// * `dim`: number of state components.
    ///
    /// # Returns
    /// `6 * dim` values laid out as `[y0 | d0 | d1 | d2 | d3 | d4]`, each block
    /// `dim` long. For `θ ∈ [0, 1]` they satisfy
    /// `y(θ) = y0 + h·θ·(d0 + θ·d1 + θ²·d2 + θ³·d3 + θ⁴·d4)`.
    ///
    /// This is Hairer's continuous extension (`CONTD5` in `dopri5.f`),
    /// `y0 + θ·(Δy + (1-θ)·(r3 + θ·(r4 + (1-θ)·r5)))`, expanded in powers of θ.
    /// It reproduces `y1` at θ = 1 and the slopes `k1`/`k7` at θ = 0/1.
    /// `d4` is always zero.
    ///
    /// # Panics
    /// If `y0`/`y1` have fewer than `dim` entries or `k` fewer than `7 * dim`.
    pub fn build_coefficients<S: Scalar>(y0: &[S], y1: &[S], k: &[S], h: S, dim: usize) -> Vec<S> {
        // Dense-output stage weights from Hairer's dopri5.f. The weight for k2
        // is zero, so k2 is not read.
        const W1: f64 = -12_715_105_075.0 / 11_282_082_432.0;
        const W3: f64 = 87_487_479_700.0 / 32_700_410_799.0;
        const W4: f64 = -10_690_763_975.0 / 1_880_347_072.0;
        const W5: f64 = 701_980_252_875.0 / 199_316_789_632.0;
        const W6: f64 = -1_453_857_185.0 / 822_651_844.0;
        const W7: f64 = 69_997_945.0 / 29_380_423.0;

        let k1 = &k[..dim];
        let k3 = &k[2 * dim..3 * dim];
        let k4 = &k[3 * dim..4 * dim];
        let k5 = &k[4 * dim..5 * dim];
        let k6 = &k[5 * dim..6 * dim];
        let k7 = &k[6 * dim..7 * dim];

        let (w1, w3, w4) = (S::from_f64(W1), S::from_f64(W3), S::from_f64(W4));
        let (w5, w6, w7) = (S::from_f64(W5), S::from_f64(W6), S::from_f64(W7));
        let two = S::from_f64(2.0);
        let three = S::from_f64(3.0);

        let mut coeffs = vec![S::ZERO; 6 * dim];
        coeffs[..dim].copy_from_slice(&y0[..dim]);
        for i in 0..dim {
            // Mean slope over the step. Hairer's r-terms divided by h reduce
            // to combinations of it and the stage slopes:
            // r3/h = k1 - s, r4/h = 2s - k1 - k7, r5/h = q.
            let s = (y1[i] - y0[i]) / h;
            let q = w1 * k1[i] + w3 * k3[i] + w4 * k4[i] + w5 * k5[i] + w6 * k6[i] + w7 * k7[i];
            coeffs[dim + i] = k1[i];
            coeffs[2 * dim + i] = three * s - two * k1[i] - k7[i] + q;
            coeffs[3 * dim + i] = k1[i] + k7[i] - two * s - two * q;
            coeffs[4 * dim + i] = q;
            // The d4 block (5 * dim + i) stays zero: the extension is quartic.
        }
        coeffs
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dense_segment_contains() {
        let seg = DenseSegment::<f64>::new(0.0, 1.0, vec![], 1);
        assert!(seg.contains(0.0));
        assert!(seg.contains(0.5));
        assert!(seg.contains(1.0));
        assert!(!seg.contains(-0.1));
        assert!(!seg.contains(1.1));
    }

    #[test]
    fn test_dense_segment_theta() {
        let seg = DenseSegment::<f64>::new(1.0, 2.0, vec![], 1);
        assert!((seg.theta(1.0) - 0.0).abs() < 1e-10);
        assert!((seg.theta(1.5) - 0.5).abs() < 1e-10);
        assert!((seg.theta(2.0) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_dense_output_find_segment() {
        let mut dense = DenseOutput::<f64>::new(1, 1.0);
        dense.add_segment(DenseSegment::new(0.0, 1.0, vec![], 1));
        dense.add_segment(DenseSegment::new(1.0, 2.0, vec![], 1));
        dense.add_segment(DenseSegment::new(2.0, 3.0, vec![], 1));
        assert!(dense.find_segment(0.5).is_some());
        assert!(dense.find_segment(1.5).is_some());
        assert!(dense.find_segment(2.5).is_some());
        assert!(dense.find_segment(-0.5).is_none());
        assert!(dense.find_segment(3.5).is_none());
    }

    #[test]
    fn test_dopri5_interpolant_endpoints() {
        let y0 = vec![0.0];
        let y1 = vec![1.5];
        let h = 0.5;
        let k = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let coeffs = DoPri5Interpolant::build_coefficients(&y0, &y1, &k, h, 1);
        let seg = DenseSegment::new(0.0, h, coeffs, 1);
        let interp = DoPri5Interpolant;

        let mut y = vec![0.0];
        interp.interpolate(&seg, 0.0, &mut y);
        assert!((y[0] - y0[0]).abs() < 1e-12);
        interp.interpolate(&seg, h, &mut y);
        assert!((y[0] - y1[0]).abs() < 1e-12, "y(t_end) = {}", y[0]);

        let mut dy = vec![0.0];
        interp.interpolate_derivative(&seg, 0.0, &mut dy);
        assert!((dy[0] - k[0]).abs() < 1e-12, "y'(t_start) = {}", dy[0]);
        interp.interpolate_derivative(&seg, h, &mut dy);
        assert!((dy[0] - k[6]).abs() < 1e-12, "y'(t_end) = {}", dy[0]);
    }

    #[test]
    fn test_dopri5_interpolant_reproduces_quadratic() {
        // y = t², y' = 2t, with stages sampled at the DoPri5 nodes
        // c = [0, 1/5, 3/10, 4/5, 8/9, 1, 1]. A 4th-order extension must be exact.
        let h = 1.0;
        let nodes = [0.0, 0.2, 0.3, 0.8, 8.0 / 9.0, 1.0, 1.0];
        let k: Vec<f64> = nodes.iter().map(|c| 2.0 * c * h).collect();
        let y0 = vec![0.0];
        let y1 = vec![h * h];
        let coeffs = DoPri5Interpolant::build_coefficients(&y0, &y1, &k, h, 1);
        let seg = DenseSegment::new(0.0, h, coeffs, 1);
        let mut y = vec![0.0];
        for &t in &[0.25, 0.5, 0.75] {
            DoPri5Interpolant.interpolate(&seg, t, &mut y);
            assert!((y[0] - t * t).abs() < 1e-12, "y({t}) = {}", y[0]);
        }
    }

    #[test]
    fn test_dense_output_tspan() {
        let mut dense = DenseOutput::<f64>::new(1, 1.0);
        assert!(dense.tspan().is_none());
        dense.add_segment(DenseSegment::new(0.0, 1.0, vec![], 1));
        dense.add_segment(DenseSegment::new(1.0, 2.0, vec![], 1));
        let (t0, tf) = dense.tspan().unwrap();
        assert!((t0 - 0.0).abs() < 1e-10);
        assert!((tf - 2.0).abs() < 1e-10);
    }
}