#![allow(dead_code)]
use crate::DType;
use crate::interpolate::error::{InterpolateError, InterpolateResult};
use numr::ops::{CompareOps, ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Knot arrays together with the metadata computed while validating them.
///
/// Values returned by `validate_inputs` hold 1-D `x` and `y` of equal length
/// `n >= 2`, with `x` checked for strictly increasing order.
pub struct ValidatedData<R>
where
    R: Runtime<DType = DType>,
{
    /// Knot abscissae.
    pub x: Tensor<R>,
    /// Knot ordinates, one per entry of `x`.
    pub y: Tensor<R>,
    /// Number of knots.
    pub n: usize,
    /// First abscissa, `x[0]`.
    pub x_min: f64,
    /// Last abscissa, `x[n - 1]`.
    pub x_max: f64,
}
/// Borrowed knot data for cubic Hermite evaluation.
///
/// Consumed by `evaluate_hermite_tensor` and `derivative_hermite_tensor`,
/// which gather `x`, `y` and `slopes` at the endpoints of each query's interval.
pub struct HermiteDataTensor<'a, R>
where
    R: Runtime<DType = DType>,
{
    /// Knot abscissae. They must be sorted ascending because the evaluators
    /// run `searchsorted` on them. NOTE(review): the evaluators do not re-check this.
    pub x: &'a Tensor<R>,
    /// Knot values.
    pub y: &'a Tensor<R>,
    /// First derivative prescribed at each knot.
    pub slopes: &'a Tensor<R>,
    /// Number of knots; presumably equal to `x.shape()[0]`. TODO confirm at call sites.
    pub n: usize,
}
/// Validates the knot arrays shared by the 1-D interpolators.
///
/// Checks that `x` and `y` are 1-D, have the same length, contain at least two
/// points, and that `x` is strictly increasing. On success the tensors are
/// returned (cloned) together with their length and the first/last abscissa.
///
/// `context` is copied into any error so the caller's name shows up in messages.
///
/// # Errors
///
/// - [`InterpolateError::InvalidParameter`] if either tensor is not 1-D.
/// - [`InterpolateError::ShapeMismatch`] if `x` and `y` differ in length.
/// - [`InterpolateError::InsufficientData`] if fewer than two points are given.
/// - [`InterpolateError::NotMonotonic`] if `x` is not strictly increasing.
///   This includes `x` containing NaN, which has no order.
/// - Any error from making `x` contiguous, converted through `?`.
pub fn validate_inputs<R: Runtime<DType = DType>>(
    x: &Tensor<R>,
    y: &Tensor<R>,
    context: &str,
) -> InterpolateResult<ValidatedData<R>> {
    let x_shape = x.shape();
    let y_shape = y.shape();
    if x_shape.len() != 1 || y_shape.len() != 1 {
        return Err(InterpolateError::InvalidParameter {
            parameter: "x, y".to_string(),
            message: "x and y must be 1D tensors".to_string(),
        });
    }
    let n = x_shape[0];
    if n != y_shape[0] {
        return Err(InterpolateError::ShapeMismatch {
            expected: n,
            actual: y_shape[0],
            context: context.to_string(),
        });
    }
    if n < 2 {
        return Err(InterpolateError::InsufficientData {
            required: 2,
            actual: n,
            context: context.to_string(),
        });
    }
    // NOTE(review): reads `x` as f64. Confirm that callers always pass F64 knots.
    let x_data: Vec<f64> = x.contiguous()?.to_vec();
    // `w[1] > w[0]` is false when either side is NaN, so NaN knots are rejected
    // here. The previous `x[i] <= x[i - 1]` test let them through silently.
    if !x_data.windows(2).all(|w| w[1] > w[0]) {
        return Err(InterpolateError::NotMonotonic {
            context: context.to_string(),
        });
    }
    let x_min = x_data[0];
    let x_max = x_data[n - 1];
    Ok(ValidatedData {
        x: x.clone(),
        y: y.clone(),
        n,
        x_min,
        x_max,
    })
}
/// Per-query quantities of the knot interval each query point falls into.
///
/// `y0`/`y1` and `d0`/`d1` are the values and slopes at the left/right knot of
/// the interval, `h` is the interval width, and `t` is the query's normalized
/// position `(x_new - x0) / h` within it.
struct HermiteInterval<R: Runtime<DType = DType>> {
    h: Tensor<R>,
    t: Tensor<R>,
    y0: Tensor<R>,
    y1: Tensor<R>,
    d0: Tensor<R>,
    d1: Tensor<R>,
}

/// Finds each query's knot interval and gathers the data that value and
/// derivative evaluation both need.
///
/// Queries outside `[x[0], x[n - 1]]` are assigned to the first or last
/// interval, so the callers extrapolate with the boundary cubic.
///
/// # Errors
///
/// - [`InterpolateError::InvalidParameter`] if `x_new` is not 1-D.
/// - [`InterpolateError::InsufficientData`] if `data.n < 2`. Previously this
///   underflowed `n - 1` (n == 0) or indexed out of range (n == 1).
/// - [`InterpolateError::NumericalError`] if a tensor operation fails.
fn locate_intervals<R, C>(
    client: &C,
    x_new: &Tensor<R>,
    data: &HermiteDataTensor<'_, R>,
) -> InterpolateResult<HermiteInterval<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + CompareOps<R> + RuntimeClient<R>,
{
    let x_new_shape = x_new.shape();
    if x_new_shape.len() != 1 {
        return Err(InterpolateError::InvalidParameter {
            parameter: "x_new".to_string(),
            message: "x_new must be a 1D tensor".to_string(),
        });
    }
    let n = data.n;
    if n < 2 {
        return Err(InterpolateError::InsufficientData {
            required: 2,
            actual: n,
            context: "cubic Hermite evaluation".to_string(),
        });
    }
    let m = x_new_shape[0];

    // The `false` flag presumably selects left-side insertion points; confirm
    // against numr. Either way the clamp below keeps every index valid.
    let indices = client
        .searchsorted(data.x, x_new, false)
        .map_err(to_interp_err)?;
    let ones = create_constant_tensor_i64(client, m, 1)?;
    // `n >= 2` here, and tensor lengths fit in i64.
    let last = create_constant_tensor_i64(client, m, (n - 1) as i64)?;

    // Clamp insertion points to [1, n - 1] so every query maps to a valid
    // interval [upper - 1, upper]; out-of-range queries use a boundary interval.
    let upper = client
        .maximum(
            &client.minimum(&indices, &last).map_err(to_interp_err)?,
            &ones,
        )
        .map_err(to_interp_err)?;
    let lower = client.sub(&upper, &ones).map_err(to_interp_err)?;

    let gather = |src: &Tensor<R>, idx: &Tensor<R>| {
        client.index_select(src, 0, idx).map_err(to_interp_err)
    };
    let x0 = gather(data.x, &lower)?;
    let x1 = gather(data.x, &upper)?;
    let y0 = gather(data.y, &lower)?;
    let y1 = gather(data.y, &upper)?;
    let d0 = gather(data.slopes, &lower)?;
    let d1 = gather(data.slopes, &upper)?;

    let h = client.sub(&x1, &x0).map_err(to_interp_err)?;
    let t = client
        .div(&client.sub(x_new, &x0).map_err(to_interp_err)?, &h)
        .map_err(to_interp_err)?;
    Ok(HermiteInterval {
        h,
        t,
        y0,
        y1,
        d0,
        d1,
    })
}

/// Evaluates the cubic Hermite interpolant defined by `data` at `x_new`.
///
/// On each interval this computes
/// `h00(t)*y0 + h10(t)*h*d0 + h01(t)*y1 + h11(t)*h*d1` with the standard
/// unit-interval Hermite basis. Points outside the knot range are
/// extrapolated with the first or last interval's cubic.
///
/// # Errors
///
/// - [`InterpolateError::InvalidParameter`] if `x_new` is not 1-D.
/// - [`InterpolateError::InsufficientData`] if `data.n < 2`.
/// - [`InterpolateError::NumericalError`] if a tensor operation fails.
pub fn evaluate_hermite_tensor<R, C>(
    client: &C,
    x_new: &Tensor<R>,
    data: &HermiteDataTensor<'_, R>,
) -> InterpolateResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + CompareOps<R> + RuntimeClient<R>,
{
    let HermiteInterval {
        h,
        t,
        y0,
        y1,
        d0,
        d1,
    } = locate_intervals(client, x_new, data)?;
    let t2 = client.mul(&t, &t).map_err(to_interp_err)?;
    let t3 = client.mul(&t2, &t).map_err(to_interp_err)?;
    // h00 = 2t^3 - 3t^2 + 1
    let h00 = client
        .add_scalar(
            &client
                .sub(
                    &client.mul_scalar(&t3, 2.0).map_err(to_interp_err)?,
                    &client.mul_scalar(&t2, 3.0).map_err(to_interp_err)?,
                )
                .map_err(to_interp_err)?,
            1.0,
        )
        .map_err(to_interp_err)?;
    // h10 = t^3 - 2t^2 + t
    let h10 = client
        .add(
            &client
                .sub(&t3, &client.mul_scalar(&t2, 2.0).map_err(to_interp_err)?)
                .map_err(to_interp_err)?,
            &t,
        )
        .map_err(to_interp_err)?;
    // h01 = -2t^3 + 3t^2
    let h01 = client
        .add(
            &client.mul_scalar(&t3, -2.0).map_err(to_interp_err)?,
            &client.mul_scalar(&t2, 3.0).map_err(to_interp_err)?,
        )
        .map_err(to_interp_err)?;
    // h11 = t^3 - t^2
    let h11 = client.sub(&t3, &t2).map_err(to_interp_err)?;
    let term1 = client.mul(&h00, &y0).map_err(to_interp_err)?;
    // Slopes are scaled by h because the basis is defined on the unit interval.
    let term2 = client
        .mul(&h10, &client.mul(&h, &d0).map_err(to_interp_err)?)
        .map_err(to_interp_err)?;
    let term3 = client.mul(&h01, &y1).map_err(to_interp_err)?;
    let term4 = client
        .mul(&h11, &client.mul(&h, &d1).map_err(to_interp_err)?)
        .map_err(to_interp_err)?;
    let result = client
        .add(
            &client.add(&term1, &term2).map_err(to_interp_err)?,
            &client.add(&term3, &term4).map_err(to_interp_err)?,
        )
        .map_err(to_interp_err)?;
    Ok(result)
}

/// Evaluates the first derivative of the cubic Hermite interpolant at `x_new`.
///
/// By the chain rule `d/dx = (1/h) d/dt`. The value terms (`h00`, `h01`) are
/// divided by `h`. In the slope terms (`h10`, `h11`) the `h` factor from
/// evaluation cancels, so they are used as-is. Points outside the knot range
/// use the first or last interval's cubic.
///
/// # Errors
///
/// - [`InterpolateError::InvalidParameter`] if `x_new` is not 1-D.
/// - [`InterpolateError::InsufficientData`] if `data.n < 2`.
/// - [`InterpolateError::NumericalError`] if a tensor operation fails.
pub fn derivative_hermite_tensor<R, C>(
    client: &C,
    x_new: &Tensor<R>,
    data: &HermiteDataTensor<'_, R>,
) -> InterpolateResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + CompareOps<R> + RuntimeClient<R>,
{
    let HermiteInterval {
        h,
        t,
        y0,
        y1,
        d0,
        d1,
    } = locate_intervals(client, x_new, data)?;
    let t2 = client.mul(&t, &t).map_err(to_interp_err)?;
    // dh00/dx = (6t^2 - 6t) / h
    let dh00 = client
        .div(
            &client
                .sub(
                    &client.mul_scalar(&t2, 6.0).map_err(to_interp_err)?,
                    &client.mul_scalar(&t, 6.0).map_err(to_interp_err)?,
                )
                .map_err(to_interp_err)?,
            &h,
        )
        .map_err(to_interp_err)?;
    // dh10/dt = 3t^2 - 4t + 1
    let dh10 = client
        .add_scalar(
            &client
                .sub(
                    &client.mul_scalar(&t2, 3.0).map_err(to_interp_err)?,
                    &client.mul_scalar(&t, 4.0).map_err(to_interp_err)?,
                )
                .map_err(to_interp_err)?,
            1.0,
        )
        .map_err(to_interp_err)?;
    // dh01/dx = (-6t^2 + 6t) / h
    let dh01 = client
        .div(
            &client
                .add(
                    &client.mul_scalar(&t2, -6.0).map_err(to_interp_err)?,
                    &client.mul_scalar(&t, 6.0).map_err(to_interp_err)?,
                )
                .map_err(to_interp_err)?,
            &h,
        )
        .map_err(to_interp_err)?;
    // dh11/dt = 3t^2 - 2t
    let dh11 = client
        .sub(
            &client.mul_scalar(&t2, 3.0).map_err(to_interp_err)?,
            &client.mul_scalar(&t, 2.0).map_err(to_interp_err)?,
        )
        .map_err(to_interp_err)?;
    let term1 = client.mul(&dh00, &y0).map_err(to_interp_err)?;
    let term2 = client.mul(&dh10, &d0).map_err(to_interp_err)?;
    let term3 = client.mul(&dh01, &y1).map_err(to_interp_err)?;
    let term4 = client.mul(&dh11, &d1).map_err(to_interp_err)?;
    let result = client
        .add(
            &client.add(&term1, &term2).map_err(to_interp_err)?,
            &client.add(&term3, &term4).map_err(to_interp_err)?,
        )
        .map_err(to_interp_err)?;
    Ok(result)
}
/// Builds a 1-D tensor of length `len` whose every element is the `i64`
/// `value`, placed on `client`'s device.
///
/// Never actually fails. It returns a `Result` so call sites can use `?`
/// uniformly.
fn create_constant_tensor_i64<R, C>(
    client: &C,
    len: usize,
    value: i64,
) -> InterpolateResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R>,
{
    let device = client.device();
    let filled: Vec<i64> = std::iter::repeat(value).take(len).collect();
    let tensor = Tensor::from_slice(&filled, &[len], device);
    Ok(tensor)
}
/// Wraps a failed `numr` tensor operation as an
/// [`InterpolateError::NumericalError`], keeping the original error text.
fn to_interp_err(e: numr::error::Error) -> InterpolateError {
    let message = format!("Tensor operation failed: {}", e);
    InterpolateError::NumericalError { message }
}
#[cfg(test)]
mod tests {
    use super::*;
    use numr::runtime::cpu::{CpuClient, CpuDevice, CpuRuntime};

    fn setup() -> (CpuDevice, CpuClient) {
        let device = CpuDevice::new();
        let client = CpuClient::new(device.clone());
        (device, client)
    }

    /// Knots of y = x^2 at x = 0, 1, 2 with the exact slopes y' = 2x.
    ///
    /// A cubic Hermite interpolant with exact slopes reproduces any polynomial
    /// of degree <= 3 exactly, including when extrapolating from the boundary
    /// intervals, so the expected outputs below are exact.
    fn parabola_knots(
        device: &CpuDevice,
    ) -> (Tensor<CpuRuntime>, Tensor<CpuRuntime>, Tensor<CpuRuntime>) {
        let x = Tensor::<CpuRuntime>::from_slice(&[0.0, 1.0, 2.0], &[3], device);
        let y = Tensor::<CpuRuntime>::from_slice(&[0.0, 1.0, 4.0], &[3], device);
        let slopes = Tensor::<CpuRuntime>::from_slice(&[0.0, 2.0, 4.0], &[3], device);
        (x, y, slopes)
    }

    fn assert_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected) {
            assert!((a - e).abs() < 1e-12, "got {}, expected {}", a, e);
        }
    }

    #[test]
    fn test_validate_inputs() {
        let (device, _client) = setup();
        let x = Tensor::<CpuRuntime>::from_slice(&[0.0, 1.0, 2.0], &[3], &device);
        let y = Tensor::<CpuRuntime>::from_slice(&[0.0, 1.0, 4.0], &[3], &device);
        let result = validate_inputs(&x, &y, "test");
        assert!(result.is_ok());
        let data = result.unwrap();
        assert_eq!(data.n, 3);
        assert!((data.x_min - 0.0).abs() < 1e-10);
        assert!((data.x_max - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_validate_inputs_non_monotonic() {
        let (device, _client) = setup();
        let x = Tensor::<CpuRuntime>::from_slice(&[0.0, 2.0, 1.0], &[3], &device);
        let y = Tensor::<CpuRuntime>::from_slice(&[0.0, 1.0, 2.0], &[3], &device);
        let result = validate_inputs(&x, &y, "test");
        assert!(matches!(result, Err(InterpolateError::NotMonotonic { .. })));
    }

    #[test]
    fn test_validate_inputs_shape_mismatch() {
        let (device, _client) = setup();
        let x = Tensor::<CpuRuntime>::from_slice(&[0.0, 1.0, 2.0], &[3], &device);
        let y = Tensor::<CpuRuntime>::from_slice(&[0.0, 1.0], &[2], &device);
        let result = validate_inputs(&x, &y, "test");
        assert!(matches!(
            result,
            Err(InterpolateError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_evaluate_hermite_reproduces_quadratic() {
        let (device, client) = setup();
        let (x, y, slopes) = parabola_knots(&device);
        let data = HermiteDataTensor {
            x: &x,
            y: &y,
            slopes: &slopes,
            n: 3,
        };
        // Left extrapolation, interior points, a knot, and right extrapolation.
        let x_new =
            Tensor::<CpuRuntime>::from_slice(&[-1.0, 0.5, 1.5, 2.0, 3.0], &[5], &device);
        let out: Vec<f64> = evaluate_hermite_tensor(&client, &x_new, &data)
            .unwrap()
            .to_vec();
        assert_close(&out, &[1.0, 0.25, 2.25, 4.0, 9.0]);
    }

    #[test]
    fn test_derivative_hermite_reproduces_quadratic_slope() {
        let (device, client) = setup();
        let (x, y, slopes) = parabola_knots(&device);
        let data = HermiteDataTensor {
            x: &x,
            y: &y,
            slopes: &slopes,
            n: 3,
        };
        let x_new =
            Tensor::<CpuRuntime>::from_slice(&[-1.0, 0.5, 1.5, 2.0, 3.0], &[5], &device);
        let out: Vec<f64> = derivative_hermite_tensor(&client, &x_new, &data)
            .unwrap()
            .to_vec();
        assert_close(&out, &[-2.0, 1.0, 3.0, 4.0, 6.0]);
    }
}