use crate::DType;
use numr::error::{Error, Result};
use numr::ops::{ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Integrates `y` along its last axis with composite Simpson's rule.
///
/// When `x` is `Some`, it supplies the sample positions along the last axis and `dx` is
/// ignored; when it is `None`, consecutive samples are taken to be `dx` apart.
///
/// # Errors
/// Returns [`Error::InvalidArgument`] when `y` is 0-dimensional or has fewer than two samples
/// along its last axis, and propagates any error from the underlying tensor operations.
pub fn simpson_impl<R, C>(
    client: &C,
    y: &Tensor<R>,
    x: Option<&Tensor<R>>,
    dx: f64,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
{
    // The integration axis is the last one; a 0-d tensor has none.
    let Some(&n) = y.shape().last() else {
        return Err(Error::InvalidArgument {
            arg: "y",
            reason: "simpson: tensor must be at least 1D".to_string(),
        });
    };
    if n < 2 {
        return Err(Error::InvalidArgument {
            arg: "y",
            reason: "simpson: need at least 2 points".to_string(),
        });
    }
    match x {
        Some(x_val) => simpson_variable_spacing(client, y, x_val, n),
        None => simpson_constant_spacing(client, y, dx, n),
    }
}
/// Simpson's rule for `n` samples spaced a constant `dx` apart along the last axis of `y`.
///
/// Two samples reduce to the trapezoidal rule `(y0 + y1) * dx / 2`. Otherwise every sample is
/// multiplied by its coefficient from `create_simpson_weights` and the weighted sum along the
/// last axis is scaled by `dx / 3`. The last axis is removed from the result.
fn simpson_constant_spacing<R, C>(
    client: &C,
    y: &Tensor<R>,
    dx: f64,
    n: usize,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
{
    let rank = y.shape().len();
    let axis = rank - 1;

    if n == 2 {
        let endpoints = client.sum(y, &[axis], false)?;
        return client.mul_scalar(&endpoints, 0.5 * dx);
    }

    let coeffs = create_simpson_weights(n, n - 1);
    // Batched input gets a leading unit axis on the weights; this relies on `mul` broadcasting
    // the weight row across the leading dimensions of `y`.
    let coeff_shape = match rank {
        1 => vec![n],
        _ => vec![1, n],
    };
    let coeff_tensor = Tensor::<R>::from_slice(&coeffs, &coeff_shape, client.device());
    let weighted_sum = client.sum(&client.mul(y, &coeff_tensor)?, &[axis], false)?;
    client.mul_scalar(&weighted_sum, dx / 3.0)
}
/// Builds the per-sample coefficients of composite Simpson's rule for `n` equally spaced
/// samples spanning `intervals` (expected to be `n - 1`) intervals.
///
/// The weighted sum of the samples must be multiplied by `dx / 3` to obtain the integral.
///
/// - With an even interval count, the classic `1, 4, 2, 4, …, 2, 4, 1` pattern covers every
///   sample.
/// - With an odd count, Simpson's rule covers the first `n - 1` samples. The final interval is
///   closed with the trapezoidal rule, whose `dx / 2` endpoint weight is `1.5` in units of
///   `dx / 3`.
/// - For `n == 2` this degenerates to the pure trapezoid `[1.5, 1.5]`.
/// - Fewer than two samples span no interval, so every weight is zero.
fn create_simpson_weights(n: usize, intervals: usize) -> Vec<f64> {
    let mut weights = vec![0.0; n];
    if n < 2 {
        return weights;
    }
    // Number of leading samples covered by Simpson's rule (odd whenever `intervals == n - 1`).
    let simpson_len = if intervals.is_multiple_of(2) { n } else { n - 1 };
    // A Simpson span needs at least one pair of intervals (three samples) to contribute;
    // a lone sample has zero width and must not pick up an endpoint weight of 1.
    if simpson_len >= 3 {
        let last = simpson_len - 1;
        for (i, w) in weights.iter_mut().enumerate().take(simpson_len) {
            *w = if i == 0 || i == last {
                1.0
            } else if i % 2 == 1 {
                4.0
            } else {
                2.0
            };
        }
    }
    if simpson_len < n {
        // Trapezoid on the last interval: dx/2 * (y[n-2] + y[n-1]) == dx/3 * 1.5 * (...).
        weights[n - 2] += 1.5;
        weights[n - 1] += 1.5;
    }
    weights
}
fn simpson_variable_spacing<R, C>(
client: &C,
y: &Tensor<R>,
x: &Tensor<R>,
n: usize,
) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
{
let y_shape = y.shape();
let last_dim = y_shape.len() - 1;
let x_last_dim = x.shape().len() - 1;
if n == 2 {
let x_left = x.narrow(x_last_dim as isize, 0, 1)?.contiguous()?;
let x_right = x.narrow(x_last_dim as isize, 1, 1)?.contiguous()?;
let dx = client.sub(&x_right, &x_left)?;
let y_left = y.narrow(last_dim as isize, 0, 1)?.contiguous()?;
let y_right = y.narrow(last_dim as isize, 1, 1)?.contiguous()?;
let y_sum = client.add(&y_left, &y_right)?;
let area = client.mul(&dx, &y_sum)?;
let scaled = client.mul_scalar(&area, 0.5)?;
return client.sum(&scaled, &[last_dim], false);
}
let intervals = n - 1;
if intervals.is_multiple_of(2) {
simpson_even_intervals(client, y, x, n, last_dim, x_last_dim)
} else {
simpson_odd_intervals(client, y, x, n, last_dim, x_last_dim)
}
}
/// Sums Simpson's rule over the sample triples `(2j, 2j + 1, 2j + 2)` for `j in 0..num_pairs`.
///
/// Uses the three-point formula for unequal spacing. With `h0 = x1 - x0`, `h1 = x2 - x1` and
/// `hsum = h0 + h1`, each triple contributes
///
/// `hsum / 6 * ((2 - h1/h0) * y0 + hsum² / (h0 * h1) * y1 + (2 - h0/h1) * y2)`.
///
/// This reduces to the classic `h/3 * (y0 + 4*y1 + y2)` when `h0 == h1 == h`, and it is exact
/// for quadratics on any spacing. Repeated `x` positions (a zero-width interval) yield
/// non-finite values. The per-triple contributions are summed along `last_dim`.
fn simpson_triples_sum<R, C>(
    client: &C,
    y: &Tensor<R>,
    x: &Tensor<R>,
    num_pairs: usize,
    last_dim: usize,
    x_last_dim: usize,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
{
    // Index tensor selecting sample `2j + offset` of every triple.
    let triple_indices = |offset: usize| {
        let indices: Vec<i64> = (0..num_pairs).map(|j| (2 * j + offset) as i64).collect();
        Tensor::<R>::from_slice(&indices, &[num_pairs], client.device())
    };
    let (idx0, idx1, idx2) = (triple_indices(0), triple_indices(1), triple_indices(2));

    let y0 = client.index_select(y, last_dim, &idx0)?;
    let y1 = client.index_select(y, last_dim, &idx1)?;
    let y2 = client.index_select(y, last_dim, &idx2)?;
    let x0 = client.index_select(x, x_last_dim, &idx0)?;
    let x1 = client.index_select(x, x_last_dim, &idx1)?;
    let x2 = client.index_select(x, x_last_dim, &idx2)?;

    let h0 = client.sub(&x1, &x0)?;
    let h1 = client.sub(&x2, &x1)?;
    let hsum = client.add(&h0, &h1)?;

    // w0 = 2 - h1/h0 == (2*h0 - h1) / h0
    let two_h0 = client.mul_scalar(&h0, 2.0)?;
    let w0 = client.div(&client.sub(&two_h0, &h1)?, &h0)?;
    // w1 = hsum² / (h0 * h1)
    let hsum_sq = client.mul(&hsum, &hsum)?;
    let hprod = client.mul(&h0, &h1)?;
    let w1 = client.div(&hsum_sq, &hprod)?;
    // w2 = 2 - h0/h1 == (2*h1 - h0) / h1
    let two_h1 = client.mul_scalar(&h1, 2.0)?;
    let w2 = client.div(&client.sub(&two_h1, &h0)?, &h1)?;

    let term0 = client.mul(&w0, &y0)?;
    let term1 = client.mul(&w1, &y1)?;
    let term2 = client.mul(&w2, &y2)?;
    let weighted = client.add(&client.add(&term0, &term1)?, &term2)?;

    let scale = client.mul_scalar(&hsum, 1.0 / 6.0)?;
    let contrib = client.mul(&scale, &weighted)?;
    client.sum(&contrib, &[last_dim], false)
}

/// Simpson's rule over all `n` samples when the interval count `n - 1` is even.
fn simpson_even_intervals<R, C>(
    client: &C,
    y: &Tensor<R>,
    x: &Tensor<R>,
    n: usize,
    last_dim: usize,
    x_last_dim: usize,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
{
    let num_pairs = (n - 1) / 2;
    simpson_triples_sum(client, y, x, num_pairs, last_dim, x_last_dim)
}

/// Simpson's rule when the interval count `n - 1` is odd.
///
/// The first `n - 1` samples (an even number of intervals) are integrated with Simpson's rule,
/// and the final interval `[x[n-2], x[n-1]]` is added with the trapezoidal rule.
fn simpson_odd_intervals<R, C>(
    client: &C,
    y: &Tensor<R>,
    x: &Tensor<R>,
    n: usize,
    last_dim: usize,
    x_last_dim: usize,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
{
    let num_pairs = (n - 2) / 2;
    let simpson_integral = simpson_triples_sum(client, y, x, num_pairs, last_dim, x_last_dim)?;

    // Trapezoid on the final interval: (x[n-1] - x[n-2]) * (y[n-2] + y[n-1]) / 2.
    let last_x_left = x.narrow(x_last_dim as isize, n - 2, 1)?.contiguous()?;
    let last_x_right = x.narrow(x_last_dim as isize, n - 1, 1)?.contiguous()?;
    let last_dx = client.sub(&last_x_right, &last_x_left)?;
    let last_y_left = y.narrow(last_dim as isize, n - 2, 1)?.contiguous()?;
    let last_y_right = y.narrow(last_dim as isize, n - 1, 1)?.contiguous()?;
    let last_y_sum = client.add(&last_y_left, &last_y_right)?;
    let trap_area = client.mul(&last_dx, &last_y_sum)?;
    let trap_scaled = client.mul_scalar(&trap_area, 0.5)?;
    let trap_integral = client.sum(&trap_scaled, &[last_dim], false)?;
    client.add(&simpson_integral, &trap_integral)
}

#[cfg(test)]
mod tests {
    use super::*;
    use numr::runtime::cpu::{CpuClient, CpuDevice};

    fn get_client() -> CpuClient {
        let device = CpuDevice::new();
        CpuClient::new(device)
    }

    #[test]
    fn test_simpson_constant_spacing() {
        let client = get_client();
        let y = Tensor::from_slice(&[0.0, 0.25, 1.0, 2.25, 4.0], &[5], client.device());
        let result = simpson_impl(&client, &y, None, 0.5).unwrap();
        let values: Vec<f64> = result.to_vec();
        assert!((values[0] - 8.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_simpson_constant_spacing_odd_intervals() {
        // y = x on [0, 3]: Simpson on the first two intervals plus a trapezoid on the last.
        let client = get_client();
        let y = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0], &[4], client.device());
        let result = simpson_impl(&client, &y, None, 1.0).unwrap();
        let values: Vec<f64> = result.to_vec();
        assert!((values[0] - 4.5).abs() < 1e-10);
    }

    #[test]
    fn test_simpson_constant_spacing_two_points() {
        let client = get_client();
        let y = Tensor::from_slice(&[1.0, 3.0], &[2], client.device());
        let result = simpson_impl(&client, &y, None, 2.0).unwrap();
        let values: Vec<f64> = result.to_vec();
        assert!((values[0] - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_simpson_constant_spacing_batch() {
        let client = get_client();
        let y = Tensor::from_slice(
            &[0.0, 0.5, 1.0, 1.5, 2.0, 0.0, 0.25, 1.0, 2.25, 4.0],
            &[2, 5],
            client.device(),
        );
        let result = simpson_impl(&client, &y, None, 0.5).unwrap();
        let values: Vec<f64> = result.to_vec();
        assert!((values[0] - 2.0).abs() < 1e-10);
        assert!((values[1] - 8.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_simpson_variable_spacing_even() {
        let client = get_client();
        let x = Tensor::from_slice(&[0.0, 0.5, 1.0, 1.5, 2.0], &[5], client.device());
        let y = Tensor::from_slice(&[0.0, 0.25, 1.0, 2.25, 4.0], &[5], client.device());
        let result = simpson_impl(&client, &y, Some(&x), 0.5).unwrap();
        let values: Vec<f64> = result.to_vec();
        assert!((values[0] - 8.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_simpson_variable_spacing_odd() {
        let client = get_client();
        let x = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0], &[4], client.device());
        let y = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0], &[4], client.device());
        let result = simpson_impl(&client, &y, Some(&x), 1.0).unwrap();
        let values: Vec<f64> = result.to_vec();
        assert!((values[0] - 4.5).abs() < 1e-10);
    }

    #[test]
    fn test_simpson_variable_spacing_nonuniform_even() {
        // Simpson's rule is exact for quadratics on any spacing: integral of x^2 over [0, 1].
        let client = get_client();
        let x = Tensor::from_slice(&[0.0, 0.1, 1.0], &[3], client.device());
        let y = Tensor::from_slice(&[0.0, 0.01, 1.0], &[3], client.device());
        let result = simpson_impl(&client, &y, Some(&x), 1.0).unwrap();
        let values: Vec<f64> = result.to_vec();
        assert!((values[0] - 1.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_simpson_variable_spacing_nonuniform_odd() {
        // Simpson on [0, 0.5, 2] gives 8/3 exactly for x^2; the trapezoid on [2, 3] adds 6.5.
        let client = get_client();
        let x = Tensor::from_slice(&[0.0, 0.5, 2.0, 3.0], &[4], client.device());
        let y = Tensor::from_slice(&[0.0, 0.25, 4.0, 9.0], &[4], client.device());
        let result = simpson_impl(&client, &y, Some(&x), 1.0).unwrap();
        let values: Vec<f64> = result.to_vec();
        assert!((values[0] - (8.0 / 3.0 + 6.5)).abs() < 1e-10);
    }

    #[test]
    fn test_simpson_variable_spacing_nonuniform_batch() {
        // Rows are y = x and y = x^2 sampled at the same non-uniform positions.
        let client = get_client();
        let x = Tensor::from_slice(&[0.0, 0.1, 1.0], &[3], client.device());
        let y = Tensor::from_slice(&[0.0, 0.1, 1.0, 0.0, 0.01, 1.0], &[2, 3], client.device());
        let result = simpson_impl(&client, &y, Some(&x), 1.0).unwrap();
        let values: Vec<f64> = result.to_vec();
        assert!((values[0] - 0.5).abs() < 1e-10);
        assert!((values[1] - 1.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_simpson_batch() {
        let client = get_client();
        let x = Tensor::from_slice(&[0.0, 0.5, 1.0, 1.5, 2.0], &[5], client.device());
        let y = Tensor::from_slice(
            &[0.0, 0.5, 1.0, 1.5, 2.0, 0.0, 0.25, 1.0, 2.25, 4.0],
            &[2, 5],
            client.device(),
        );
        let result = simpson_impl(&client, &y, Some(&x), 0.5).unwrap();
        let values: Vec<f64> = result.to_vec();
        assert!((values[0] - 2.0).abs() < 1e-10);
        assert!((values[1] - 8.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_simpson_two_points() {
        let client = get_client();
        let x = Tensor::from_slice(&[0.0, 1.0], &[2], client.device());
        let y = Tensor::from_slice(&[0.0, 1.0], &[2], client.device());
        let result = simpson_impl(&client, &y, Some(&x), 1.0).unwrap();
        let values: Vec<f64> = result.to_vec();
        assert!((values[0] - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_simpson_rejects_single_point() {
        let client = get_client();
        let y = Tensor::from_slice(&[1.0], &[1], client.device());
        assert!(simpson_impl(&client, &y, None, 1.0).is_err());
    }
}