use crate::interpolate::error::{InterpolateError, InterpolateResult};
use crate::interpolate::traits::cubic_spline::{SplineBoundary, SplineCoefficients};
use numr::algorithm::linalg::LinearAlgebraAlgorithms;
use numr::ops::ScalarOps;
use numr::prelude::DType;
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Four tensors describing a tridiagonal linear system.
///
/// `build_boundary_diagonals` returns this tuple in the order
/// `(main, upper, lower, rhs)`.
type DiagonalSystem<R> = (Tensor<R>, Tensor<R>, Tensor<R>, Tensor<R>);
/// Computes the per-interval polynomial coefficients of a cubic spline.
///
/// With `dx = x - x_i`, segment `i` is
/// `a_i + b_i * dx + c_i * dx^2 + d_i * dx^3`.
///
/// The quadratic coefficients `c` are obtained by solving a tridiagonal
/// system whose first and last rows are chosen by `boundary`; `b` and `d`
/// are then derived from `c`, the interval widths and the secant slopes.
///
/// # Arguments
///
/// * `client` - runtime client used for every tensor operation.
/// * `x` - knot abscissae; its first dimension gives the knot count `n`.
/// * `y` - values at the knots.
/// * `boundary` - end condition applied to the first and last rows.
///
/// # Returns
///
/// The tuple `(a, b, c, d)`: `a` is a clone of `y`, `c` has `n` entries,
/// and `b` and `d` are computed once per interval.
///
/// # Errors
///
/// * [`InterpolateError::InsufficientData`] when fewer than two knots are given.
/// * [`InterpolateError::NumericalError`] when the linear solve fails.
/// * Any error raised by the underlying tensor operations is propagated.
pub fn cubic_spline_coefficients<R, C>(
    client: &C,
    x: &Tensor<R>,
    y: &Tensor<R>,
    boundary: &SplineBoundary,
) -> InterpolateResult<SplineCoefficients<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + LinearAlgebraAlgorithms<R> + RuntimeClient<R>,
{
    let device = client.device();
    let n = x.shape()[0];
    if n < 2 {
        return Err(InterpolateError::InsufficientData {
            required: 2,
            actual: n,
            context: String::from("cubic_spline_coefficients"),
        });
    }
    let n_seg = n - 1;
    let n_inner = n - 2;

    // Interval widths h_i = x_{i+1} - x_i.
    let x_next = x.narrow(0, 1, n_seg)?;
    let x_prev = x.narrow(0, 0, n_seg)?;
    let h = client.sub(&x_next, &x_prev)?;

    // Secant slopes s_i = (y_{i+1} - y_i) / h_i.
    let y_next = y.narrow(0, 1, n_seg)?;
    let y_prev = y.narrow(0, 0, n_seg)?;
    let rise = client.sub(&y_next, &y_prev)?;
    let slopes = client.div(&rise, &h)?;

    // Interior main diagonal: 2 * (h_{i-1} + h_i).
    let h_before = h.narrow(0, 0, n_inner)?.contiguous()?;
    let h_after = h.narrow(0, 1, n_inner)?.contiguous()?;
    let h_pair_sum = client.add(&h_before, &h_after)?;
    let main_interior = client.mul_scalar(&h_pair_sum, 2.0)?;

    // Interior right-hand side: 3 * (s_i - s_{i-1}).
    let slope_before = slopes.narrow(0, 0, n_inner)?.contiguous()?;
    let slope_after = slopes.narrow(0, 1, n_inner)?.contiguous()?;
    let slope_jump = client.sub(&slope_after, &slope_before)?;
    let rhs_interior = client.mul_scalar(&slope_jump, 3.0)?;

    // Add the boundary rows, then assemble and solve the dense n x n system.
    let (main_diag, upper_diag, lower_diag, rhs) = build_boundary_diagonals(
        client,
        &h,
        &slopes,
        &main_interior,
        &rhs_interior,
        boundary,
        n,
    )?;
    let system = build_tridiagonal(client, &main_diag, &upper_diag, &lower_diag)?;
    let rhs_column = rhs.reshape(&[n, 1])?;
    let solution = LinearAlgebraAlgorithms::solve(client, &system, &rhs_column).map_err(|e| {
        InterpolateError::NumericalError {
            message: format!("Failed to solve tridiagonal system: {}", e),
        }
    })?;
    let c_tensor = solution.reshape(&[n])?;

    // The constant term of every segment is the knot value itself.
    let a_tensor = y.clone();

    // b_i = s_i - h_i * (2 * c_i + c_{i+1}) / 3
    let c_here = c_tensor.narrow(0, 0, n_seg)?;
    let c_there = c_tensor.narrow(0, 1, n_seg)?;
    let doubled_c_here = client.mul_scalar(&c_here, 2.0)?;
    let weighted_c = client.add(&doubled_c_here, &c_there)?;
    let three = Tensor::full_scalar(&[n_seg], DType::F64, 3.0, device);
    let h_weighted_c = client.mul(&h, &weighted_c)?;
    let correction = client.div(&h_weighted_c, &three)?;
    let b_tensor = client.sub(&slopes, &correction)?;

    // d_i = (c_{i+1} - c_i) / (3 * h_i)
    let c_step = client.sub(&c_there, &c_here)?;
    let tripled_h = client.mul_scalar(&h, 3.0)?;
    let d_tensor = client.div(&c_step, &tripled_h)?;

    Ok((a_tensor, b_tensor, c_tensor, d_tensor))
}
/// Assembles the diagonals and right-hand side of the spline system for `c`.
///
/// Interior rows are supplied by the caller (`main_interior`, `rhs_interior`);
/// this function adds the first and last rows dictated by `boundary` and the
/// matching off-diagonal entries.
///
/// # Arguments
///
/// * `client` - runtime client used for tensor operations.
/// * `h` - interval widths (`n - 1` entries).
/// * `slopes` - secant slopes per interval (`n - 1` entries).
/// * `main_interior` - main-diagonal entries of the `n - 2` interior rows.
/// * `rhs_interior` - right-hand-side entries of the `n - 2` interior rows.
/// * `boundary` - end condition to apply.
/// * `n` - number of knots; the caller guarantees `n >= 2`.
///
/// # Returns
///
/// `(main, upper, lower, rhs)` where `main` and `rhs` have `n` entries and
/// `upper` / `lower` have `n - 1` entries.
///
/// # Boundary rows
///
/// * `Natural`: `c_0 = 0` and `c_{n-1} = 0`.
/// * `Clamped`: `2*h_0*c_0 + h_0*c_1 = 3*(s_0 - left)` and the mirrored row
///   at the right end.
/// * `NotAKnot` with `n < 4`: falls back to the natural rows.
/// * `NotAKnot` with `n >= 4`: the condition `d_0 = d_1` reads
///   `h_1*c_0 - (h_0 + h_1)*c_1 + h_0*c_2 = 0`, which is not tridiagonal.
///   Eliminating `c_2` with the first interior row gives the equivalent
///   two-term row
///   `(h_0^2 - h_1^2)*c_0 + (2*h_0 + h_1)*(h_0 + h_1)*c_1 = 3*h_0*(s_1 - s_0)`.
///   The last row is obtained the same way from `d_{n-3} = d_{n-2}`.
///   The first/last main-diagonal entry is zero when the two adjacent
///   intervals have equal width, so the system must be solved with pivoting.
fn build_boundary_diagonals<R, C>(
    client: &C,
    h: &Tensor<R>,
    slopes: &Tensor<R>,
    main_interior: &Tensor<R>,
    rhs_interior: &Tensor<R>,
    boundary: &SplineBoundary,
    n: usize,
) -> InterpolateResult<DiagonalSystem<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + RuntimeClient<R>,
{
    let device = client.device();
    let one = Tensor::full_scalar(&[1], DType::F64, 1.0, device);
    let zero_1 = Tensor::zeros(&[1], DType::F64, device);

    // Off-diagonal entries belonging to the interior rows.
    let h_lo_int = h.narrow(0, 0, n - 2)?.contiguous()?;
    let h_hi_int = h.narrow(0, 1, n - 2)?.contiguous()?;

    // Natural rows pin c_0 and c_{n-1} to zero; also used as the fallback
    // when there are too few knots for the not-a-knot condition.
    let natural_system = || -> InterpolateResult<DiagonalSystem<R>> {
        let main = client.cat(&[&one, main_interior, &one], 0)?;
        let upper = client.cat(&[&zero_1, &h_hi_int], 0)?;
        let lower = client.cat(&[&h_lo_int, &zero_1], 0)?;
        let rhs = client.cat(&[&zero_1, rhs_interior, &zero_1], 0)?;
        Ok((main, upper, lower, rhs))
    };

    match boundary {
        SplineBoundary::Natural => natural_system(),
        SplineBoundary::Clamped { left, right } => {
            let h_first = h.narrow(0, 0, 1)?.contiguous()?;
            let h_last = h.narrow(0, n - 2, 1)?.contiguous()?;
            let two_h_first = client.mul_scalar(&h_first, 2.0)?;
            let two_h_last = client.mul_scalar(&h_last, 2.0)?;
            let main = client.cat(&[&two_h_first, main_interior, &two_h_last], 0)?;
            let upper = client.cat(&[&h_first, &h_hi_int], 0)?;
            let lower = client.cat(&[&h_lo_int, &h_last], 0)?;

            let s_first = slopes.narrow(0, 0, 1)?.contiguous()?;
            let s_last = slopes.narrow(0, n - 2, 1)?.contiguous()?;
            let left_t = Tensor::full_scalar(&[1], DType::F64, *left, device);
            let right_t = Tensor::full_scalar(&[1], DType::F64, *right, device);
            let rhs_first = client.mul_scalar(&client.sub(&s_first, &left_t)?, 3.0)?;
            let rhs_last = client.mul_scalar(&client.sub(&right_t, &s_last)?, 3.0)?;
            let rhs = client.cat(&[&rhs_first, rhs_interior, &rhs_last], 0)?;
            Ok((main, upper, lower, rhs))
        }
        SplineBoundary::NotAKnot => {
            if n < 4 {
                return natural_system();
            }

            let h0 = h.narrow(0, 0, 1)?.contiguous()?;
            let h1 = h.narrow(0, 1, 1)?.contiguous()?;
            let hn3 = h.narrow(0, n - 3, 1)?.contiguous()?;
            let hn2 = h.narrow(0, n - 2, 1)?.contiguous()?;

            // First row:
            // (h0^2 - h1^2) c_0 + (2 h0 + h1)(h0 + h1) c_1 = 3 h0 (s1 - s0)
            let head_sum = client.add(&h0, &h1)?;
            let head_diff = client.sub(&h0, &h1)?;
            let main_first = client.mul(&head_diff, &head_sum)?;
            let two_h0_plus_h1 = client.add(&client.mul_scalar(&h0, 2.0)?, &h1)?;
            let upper_first = client.mul(&two_h0_plus_h1, &head_sum)?;

            // Last row, with a = h_{n-3}, b = h_{n-2}:
            // (a + 2 b)(a + b) c_{n-2} + (b^2 - a^2) c_{n-1} = 3 b (s_{n-2} - s_{n-3})
            let tail_sum = client.add(&hn3, &hn2)?;
            let tail_diff = client.sub(&hn2, &hn3)?;
            let main_last = client.mul(&tail_diff, &tail_sum)?;
            let hn3_plus_two_hn2 = client.add(&hn3, &client.mul_scalar(&hn2, 2.0)?)?;
            let lower_last = client.mul(&hn3_plus_two_hn2, &tail_sum)?;

            let main = client.cat(&[&main_first, main_interior, &main_last], 0)?;
            let upper = client.cat(&[&upper_first, &h_hi_int], 0)?;
            let lower = client.cat(&[&h_lo_int, &lower_last], 0)?;

            let s0 = slopes.narrow(0, 0, 1)?.contiguous()?;
            let s1 = slopes.narrow(0, 1, 1)?.contiguous()?;
            let sn3 = slopes.narrow(0, n - 3, 1)?.contiguous()?;
            let sn2 = slopes.narrow(0, n - 2, 1)?.contiguous()?;
            let head_jump = client.sub(&s1, &s0)?;
            let tail_jump = client.sub(&sn2, &sn3)?;
            let rhs_first = client.mul_scalar(&client.mul(&h0, &head_jump)?, 3.0)?;
            let rhs_last = client.mul_scalar(&client.mul(&hn2, &tail_jump)?, 3.0)?;
            let rhs = client.cat(&[&rhs_first, rhs_interior, &rhs_last], 0)?;
            Ok((main, upper, lower, rhs))
        }
    }
}
/// Builds a dense square matrix from three diagonals.
///
/// Each off-diagonal is first expanded with `diagflat` into an
/// `(n - 1) x (n - 1)` block and then padded with one zero column and one
/// zero row so that it lands one position above (upper) or below (lower)
/// the main diagonal of the `n x n` result.
///
/// # Arguments
///
/// * `client` - runtime client used for tensor operations.
/// * `main_diag` - main diagonal; its length defines `n`.
/// * `upper_diag` - super-diagonal entries.
/// * `lower_diag` - sub-diagonal entries.
///
/// # Errors
///
/// Propagates any error from the underlying tensor operations.
fn build_tridiagonal<R, C>(
    client: &C,
    main_diag: &Tensor<R>,
    upper_diag: &Tensor<R>,
    lower_diag: &Tensor<R>,
) -> InterpolateResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + RuntimeClient<R>,
{
    let device = client.device();
    let size = main_diag.shape()[0];

    let diagonal_part = client.diagflat(main_diag)?;

    // Super-diagonal: zero column on the left, zero row at the bottom.
    let upper_block = client.diagflat(upper_diag)?;
    let pad_col = Tensor::zeros(&[size - 1, 1], DType::F64, device);
    let upper_wide = client.cat(&[&pad_col, &upper_block], 1)?;
    let pad_row = Tensor::zeros(&[1, size], DType::F64, device);
    let upper_part = client.cat(&[&upper_wide, &pad_row], 0)?;

    // Sub-diagonal: zero column on the right, zero row on top.
    let lower_block = client.diagflat(lower_diag)?;
    let lower_wide = client.cat(&[&lower_block, &pad_col], 1)?;
    let lower_part = client.cat(&[&pad_row, &lower_wide], 0)?;

    let with_upper = client.add(&diagonal_part, &upper_part)?;
    Ok(client.add(&with_upper, &lower_part)?)
}
#[cfg(test)]
mod tests {
    use super::*;
    use numr::runtime::cpu::{CpuClient, CpuDevice, CpuRuntime};

    /// Creates a CPU device together with a client bound to it.
    fn setup() -> (CpuDevice, CpuClient) {
        let device = CpuDevice::new();
        let client = CpuClient::new(device.clone());
        (device, client)
    }

    /// With unit-width intervals, evaluating a segment at its right end is
    /// simply `a + b + c + d`; that must equal the next knot value.
    #[test]
    fn test_natural_spline_interpolates_knots() {
        let (device, client) = setup();
        let x = Tensor::<CpuRuntime>::from_slice(&[0.0, 1.0, 2.0, 3.0], &[4], &device);
        let y = Tensor::<CpuRuntime>::from_slice(&[0.0, 1.0, 4.0, 9.0], &[4], &device);

        let (a, b, c, d) =
            cubic_spline_coefficients(&client, &x, &y, &SplineBoundary::Natural).unwrap();
        let a_v: Vec<f64> = a.to_vec();
        let b_v: Vec<f64> = b.to_vec();
        let c_v: Vec<f64> = c.to_vec();
        let d_v: Vec<f64> = d.to_vec();

        (0..3).for_each(|i| {
            let val = a_v[i] + b_v[i] + c_v[i] + d_v[i];
            assert!(
                (val - a_v[i + 1]).abs() < 1e-10,
                "segment {} endpoint: {} vs {}",
                i,
                val,
                a_v[i + 1]
            );
        });
    }

    /// The linear coefficient of the first segment is the derivative at the
    /// left end, which the clamped condition fixes to `left`.
    #[test]
    fn test_clamped_spline_derivative() {
        let (device, client) = setup();
        let x = Tensor::<CpuRuntime>::from_slice(&[0.0, 1.0, 2.0], &[3], &device);
        let y = Tensor::<CpuRuntime>::from_slice(&[0.0, 1.0, 0.0], &[3], &device);
        let boundary = SplineBoundary::Clamped {
            left: 1.0,
            right: -1.0,
        };

        let (_a, b, _c, _d) = cubic_spline_coefficients(&client, &x, &y, &boundary).unwrap();
        let b_v: Vec<f64> = b.to_vec();

        assert!(
            (b_v[0] - 1.0).abs() < 1e-10,
            "left derivative: {} vs 1.0",
            b_v[0]
        );
    }

    /// Same right-endpoint check as the natural test, on five knots with the
    /// not-a-knot end condition.
    #[test]
    fn test_not_a_knot_spline_continuity() {
        let (device, client) = setup();
        let x = Tensor::<CpuRuntime>::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0], &[5], &device);
        let y = Tensor::<CpuRuntime>::from_slice(&[0.0, 1.0, 4.0, 9.0, 16.0], &[5], &device);

        let (a, b, c, d) =
            cubic_spline_coefficients(&client, &x, &y, &SplineBoundary::NotAKnot).unwrap();
        let a_v: Vec<f64> = a.to_vec();
        let b_v: Vec<f64> = b.to_vec();
        let c_v: Vec<f64> = c.to_vec();
        let d_v: Vec<f64> = d.to_vec();

        (0..4).for_each(|i| {
            let val = a_v[i] + b_v[i] + c_v[i] + d_v[i];
            assert!(
                (val - a_v[i + 1]).abs() < 1e-8,
                "segment {} mismatch: {} vs {}",
                i,
                val,
                a_v[i + 1]
            );
        });
    }
}