use crate::DType;
use crate::interpolate::error::InterpolateResult;
use crate::interpolate::traits::interp1d::InterpMethod;
use numr::ops::{CompareOps, ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
pub fn interp1d_evaluate<R, C>(
client: &C,
x: &Tensor<R>,
y: &Tensor<R>,
x_new: &Tensor<R>,
method: InterpMethod,
) -> InterpolateResult<Tensor<R>>
where
R: Runtime<DType = DType>,
C: TensorOps<R> + ScalarOps<R> + CompareOps<R> + RuntimeClient<R>,
{
match method {
InterpMethod::Nearest => evaluate_nearest(client, x, y, x_new),
InterpMethod::Linear => evaluate_linear(client, x, y, x_new),
InterpMethod::Cubic => evaluate_cubic(client, x, y, x_new),
}
}
/// Nearest-neighbour interpolation: each query takes the `y` value of the
/// closest knot in `x`.
///
/// The nearest knot is found exactly by locating the query among the `n - 1`
/// midpoints between consecutive knots. `searchsorted(midpoints, q)` is then
/// the index of the nearest knot, always in `0..=n-1`, and the result is
/// gathered straight from `y`. The output is therefore always one of the `y`
/// samples. An epsilon-guarded blend of the two neighbours would leak a mix of
/// both when the query sits within ~1e-14 of a midpoint.
///
/// Queries below `x[0]` / above `x[n-1]` map to the first / last sample.
/// Because every output is a gathered sample, NaN queries do not propagate NaN;
/// they land wherever `searchsorted` orders NaN.
///
/// Assumes `x` is 1-D, sorted ascending and the same length as `y`; this is not
/// checked here (NOTE(review): confirm callers validate it).
fn evaluate_nearest<R, C>(
    client: &C,
    x: &Tensor<R>,
    y: &Tensor<R>,
    x_new: &Tensor<R>,
) -> InterpolateResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
{
    let n = x.shape()[0];
    let device = client.device();

    // Indices 0..n-1 and 1..n select the left/right knot of every gap.
    let gaps = n.saturating_sub(1);
    let left: Vec<i64> = (0..gaps).map(|i| i as i64).collect();
    let right: Vec<i64> = (1..n).map(|i| i as i64).collect();
    let left_idx = Tensor::<R>::from_slice(&left, &[gaps], device);
    let right_idx = Tensor::<R>::from_slice(&right, &[gaps], device);

    let x_left = client.index_select(x, 0, &left_idx)?;
    let x_right = client.index_select(x, 0, &right_idx)?;
    // left + (right - left) / 2 stays inside [left, right] without overflowing.
    let half_gap = client.mul_scalar(&client.sub(&x_right, &x_left)?, 0.5)?;
    let midpoints = client.add(&x_left, &half_gap)?;

    // Right-side insertion sends a query lying exactly on a midpoint to the
    // upper knot, the same tie-break the previous blend formula produced.
    // NOTE(review): assumes the bool selects right-side insertion, as in
    // `torch.searchsorted(right=...)`; confirm against numr.
    let nearest = client.searchsorted(&midpoints, x_new, true)?;
    Ok(client.index_select(y, 0, &nearest)?)
}
/// Piecewise-linear interpolation of `(x, y)` at `x_new`.
///
/// Each query is bracketed by `hi = clamp(searchsorted(x, q), 1, n - 1)` and
/// `lo = hi - 1`, then evaluated as `y[lo] + (y[hi] - y[lo]) * t` with
/// `t = (q - x[lo]) / (x[hi] - x[lo] + 1e-14)`. The clamp means queries outside
/// the data range reuse the first/last segment, i.e. they are linearly
/// extrapolated.
///
/// The `1e-14` guard is absolute. It keeps the division finite when the two
/// bracketing knots coincide, but it visibly biases `t` if knot spacing is
/// itself near 1e-14. It is applied with `add_scalar` so it follows the input
/// dtype rather than introducing an f64 tensor.
///
/// Requires `n >= 2`: with a single knot the clamped upper index 1 is out of
/// range. Assumes `x` is 1-D, sorted ascending and the same length as `y`.
fn evaluate_linear<R, C>(
    client: &C,
    x: &Tensor<R>,
    y: &Tensor<R>,
    x_new: &Tensor<R>,
) -> InterpolateResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
{
    let m = x_new.shape()[0];
    let n = x.shape()[0];
    let device = client.device();

    // Integer bounds for the bracket clamp, one entry per query.
    let ones = Tensor::<R>::from_slice(&vec![1i64; m], &[m], device);
    let last = Tensor::<R>::from_slice(&vec![(n - 1) as i64; m], &[m], device);
    let pos = client.searchsorted(x, x_new, false)?;
    let hi = client.maximum(&client.minimum(&pos, &last)?, &ones)?;
    let lo = client.sub(&hi, &ones)?;

    let x0 = client.index_select(x, 0, &lo)?;
    let x1 = client.index_select(x, 0, &hi)?;
    let y0 = client.index_select(y, 0, &lo)?;
    let y1 = client.index_select(y, 0, &hi)?;

    let dx = client.add_scalar(&client.sub(&x1, &x0)?, 1e-14)?;
    let t = client.div(&client.sub(x_new, &x0)?, &dx)?;
    let dy = client.sub(&y1, &y0)?;
    Ok(client.add(&y0, &client.mul(&dy, &t)?)?)
}
/// Local cubic Hermite interpolation of `(x, y)` at `x_new`.
///
/// Each query is evaluated on its bracketing segment `[x[i1], x[i2]]` using a
/// four-knot stencil `i0 <= i1 < i2 <= i3`. The outer indices are clamped to the
/// array ends, so on the first/last segment they coincide with `i1`/`i2`.
///
/// Tangents at `x[i1]` and `x[i2]` are the arithmetic mean of the two adjacent
/// secant slopes. On a uniform grid this equals the central difference, i.e.
/// Catmull–Rom tangents. Where the stencil was clamped (zero outer spacing) the
/// tangent falls back to the segment's own secant.
///
/// Neighbouring segments share tangents, so the curve is C1. It is not a C2
/// spline: no global system is solved, so values differ from a
/// natural/not-a-knot cubic spline.
///
/// The absolute `1e-14` guard on spacings keeps divisions finite. It also
/// produces the 0/1 "outer neighbour exists" weights, which blur if knot
/// spacing approaches 1e-14. All constants go through scalar ops so they follow
/// the input dtype.
///
/// Requires `n >= 2`. Assumes `x` is 1-D, sorted ascending and the same length
/// as `y`.
fn evaluate_cubic<R, C>(
    client: &C,
    x: &Tensor<R>,
    y: &Tensor<R>,
    x_new: &Tensor<R>,
) -> InterpolateResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
{
    let m = x_new.shape()[0];
    let n = x.shape()[0];
    let device = client.device();
    let eps = 1e-14;

    // Stencil indices: i2 = clamp(searchsorted, 1, n-1), i1 = i2 - 1,
    // i0 = max(i1 - 1, 0), i3 = min(i2 + 1, n - 1).
    let zeros = Tensor::<R>::from_slice(&vec![0i64; m], &[m], device);
    let ones = Tensor::<R>::from_slice(&vec![1i64; m], &[m], device);
    let last = Tensor::<R>::from_slice(&vec![(n - 1) as i64; m], &[m], device);
    let pos = client.searchsorted(x, x_new, false)?;
    let i2 = client.maximum(&client.minimum(&pos, &last)?, &ones)?;
    let i1 = client.sub(&i2, &ones)?;
    let i0 = client.maximum(&client.sub(&i1, &ones)?, &zeros)?;
    let i3 = client.minimum(&client.add(&i2, &ones)?, &last)?;

    let x0 = client.index_select(x, 0, &i0)?;
    let x1 = client.index_select(x, 0, &i1)?;
    let x2 = client.index_select(x, 0, &i2)?;
    let x3 = client.index_select(x, 0, &i3)?;
    let y0 = client.index_select(y, 0, &i0)?;
    let y1 = client.index_select(y, 0, &i1)?;
    let y2 = client.index_select(y, 0, &i2)?;
    let y3 = client.index_select(y, 0, &i3)?;

    // Knot spacings and secant slopes of the three stencil segments.
    let h0 = client.sub(&x1, &x0)?;
    let h1 = client.sub(&x2, &x1)?;
    let h2 = client.sub(&x3, &x2)?;
    let h1_safe = client.add_scalar(&h1, eps)?;
    let s01 = client.div(&client.sub(&y1, &y0)?, &client.add_scalar(&h0, eps)?)?;
    let s12 = client.div(&client.sub(&y2, &y1)?, &h1_safe)?;
    let s23 = client.div(&client.sub(&y3, &y2)?, &client.add_scalar(&h2, eps)?)?;

    // w = |h| / (|h| + eps) is exactly 0 when the outer knot was clamped onto the
    // inner one and ~1 otherwise. Tangent = (1 - w) * s12 + w * mean
    //                                   = s12 + w * (mean - s12).
    let h0_abs = client.abs(&h0)?;
    let w_left = client.div(&h0_abs, &client.add_scalar(&h0_abs, eps)?)?;
    let h2_abs = client.abs(&h2)?;
    let w_right = client.div(&h2_abs, &client.add_scalar(&h2_abs, eps)?)?;

    let mean_left = client.mul_scalar(&client.add(&s01, &s12)?, 0.5)?;
    let mean_right = client.mul_scalar(&client.add(&s12, &s23)?, 0.5)?;
    let m1 = client.add(&s12, &client.mul(&w_left, &client.sub(&mean_left, &s12)?)?)?;
    let m2 = client.add(&s12, &client.mul(&w_right, &client.sub(&mean_right, &s12)?)?)?;

    // Normalised position on the segment and the cubic Hermite basis.
    let t = client.div(&client.sub(x_new, &x1)?, &h1_safe)?;
    let t2 = client.mul(&t, &t)?;
    let t3 = client.mul(&t2, &t)?;
    // h00 = 2t^3 - 3t^2 + 1
    let h00 = client.add_scalar(
        &client.sub(&client.mul_scalar(&t3, 2.0)?, &client.mul_scalar(&t2, 3.0)?)?,
        1.0,
    )?;
    // h10 = t^3 - 2t^2 + t
    let h10 = client.add(&client.sub(&t3, &client.mul_scalar(&t2, 2.0)?)?, &t)?;
    // h01 = 3t^2 - 2t^3
    let h01 = client.sub(&client.mul_scalar(&t2, 3.0)?, &client.mul_scalar(&t3, 2.0)?)?;
    // h11 = t^3 - t^2
    let h11 = client.sub(&t3, &t2)?;

    // p(t) = h00*y1 + h10*h1*m1 + h01*y2 + h11*h1*m2 (tangents scaled by h1 because
    // the basis is expressed in the normalised coordinate t).
    let left_terms = client.add(
        &client.mul(&h00, &y1)?,
        &client.mul(&h10, &client.mul(&h1, &m1)?)?,
    )?;
    let right_terms = client.add(
        &client.mul(&h01, &y2)?,
        &client.mul(&h11, &client.mul(&h1, &m2)?)?,
    )?;
    Ok(client.add(&left_terms, &right_terms)?)
}
#[cfg(test)]
mod tests {
    use super::*;
    use numr::runtime::cpu::{CpuClient, CpuDevice, CpuRuntime};

    fn setup() -> (CpuDevice, CpuClient) {
        let device = CpuDevice::new();
        let client = CpuClient::new(device.clone());
        (device, client)
    }

    /// Runs `interp1d_evaluate` on f64 samples and returns the host-side result.
    fn run(xs: &[f64], ys: &[f64], queries: &[f64], method: InterpMethod) -> Vec<f64> {
        let (device, client) = setup();
        let x = Tensor::<CpuRuntime>::from_slice(xs, &[xs.len()], &device);
        let y = Tensor::<CpuRuntime>::from_slice(ys, &[ys.len()], &device);
        let x_new = Tensor::<CpuRuntime>::from_slice(queries, &[queries.len()], &device);
        let y_new = interp1d_evaluate(&client, &x, &y, &x_new, method).unwrap();
        y_new.to_vec()
    }

    /// Asserts element-wise agreement within `tol`, reporting the first mismatch.
    fn assert_close(actual: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < tol, "index {}: got {}, expected {}", i, a, e);
        }
    }

    #[test]
    fn test_linear_interpolation() {
        let got = run(
            &[0.0, 1.0, 2.0, 3.0],
            &[0.0, 2.0, 4.0, 6.0],
            &[0.5, 1.5, 2.5],
            InterpMethod::Linear,
        );
        assert_close(&got, &[1.0, 3.0, 5.0], 1e-10);
    }

    #[test]
    fn test_nearest_interpolation() {
        // Below range, either side of a midpoint, on a knot, and above range.
        let got = run(
            &[0.0, 1.0, 2.0, 3.0],
            &[0.0, 10.0, 20.0, 30.0],
            &[-1.0, 0.2, 0.8, 1.0, 2.9, 5.0],
            InterpMethod::Nearest,
        );
        assert_close(&got, &[0.0, 0.0, 10.0, 10.0, 30.0, 30.0], 1e-10);
    }

    #[test]
    fn test_cubic_reproduces_linear_data() {
        // Includes both end knots, where the stencil is clamped.
        let got = run(
            &[0.0, 1.0, 2.0, 3.0],
            &[0.0, 2.0, 4.0, 6.0],
            &[0.0, 0.5, 1.5, 2.5, 3.0],
            InterpMethod::Cubic,
        );
        assert_close(&got, &[0.0, 1.0, 3.0, 5.0, 6.0], 1e-10);
    }

    #[test]
    fn test_cubic_quadratic_data() {
        // y = x^2 on a uniform grid: interior segments use mean-of-secant tangents,
        // which are exact here (2.25, 6.25). End segments fall back to the secant
        // slope at the outer knot, giving 0.375 and 12.375 instead of 0.25 / 12.25.
        let got = run(
            &[0.0, 1.0, 2.0, 3.0, 4.0],
            &[0.0, 1.0, 4.0, 9.0, 16.0],
            &[0.5, 1.5, 2.5, 3.5],
            InterpMethod::Cubic,
        );
        assert_close(&got, &[0.375, 2.25, 6.25, 12.375], 1e-10);
    }
}