use super::{DTypeSupport, create_index_tensor};
use crate::algorithm::polynomial::helpers::{
validate_polynomial_coeffs, validate_polynomial_dtype,
};
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::{BinaryOps, IndexingOps, ScalarOps, ShapeOps};
use crate::runtime::{Runtime, RuntimeClient};
use crate::tensor::Tensor;
/// Evaluates a polynomial at every element of `x` using Horner's scheme.
///
/// The coefficient at index 0 of `coeffs` is the constant term and the last
/// coefficient belongs to the highest power, i.e. the value computed is
/// `coeffs[0] + coeffs[1] * x + ... + coeffs[n - 1] * x^(n - 1)`, arranged as
/// `(((coeffs[n - 1]) * x + coeffs[n - 2]) * x + ...) + coeffs[0]`.
///
/// # Arguments
///
/// * `client` - runtime client providing the arithmetic and indexing ops.
/// * `coeffs` - polynomial coefficients, selected along dimension 0.
/// * `x` - points at which the polynomial is evaluated.
/// * `dtype_support` - backend dtype capabilities; also supplies the dtype
///   used for the index tensors.
///
/// # Errors
///
/// * Propagates whatever `validate_polynomial_dtype`, `dtype_support.check`
///   and `validate_polynomial_coeffs` reject.
/// * Returns [`Error::DTypeMismatch`] when `coeffs` and `x` have different
///   dtypes.
/// * Propagates any failure from the underlying `index_select`,
///   `broadcast_to`, `mul` and `add` operations.
pub fn polyval_impl<R, C>(
    client: &C,
    coeffs: &Tensor<R>,
    x: &Tensor<R>,
    dtype_support: DTypeSupport,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R> + BinaryOps<R> + ScalarOps<R> + IndexingOps<R> + ShapeOps<R>,
{
    // Dtype validation: each input on its own, then backend support, then
    // agreement between the two inputs.
    validate_polynomial_dtype(coeffs.dtype())?;
    validate_polynomial_dtype(x.dtype())?;
    dtype_support.check(coeffs.dtype(), "polyval")?;
    if coeffs.dtype() != x.dtype() {
        return Err(Error::DTypeMismatch {
            lhs: coeffs.dtype(),
            rhs: x.dtype(),
        });
    }

    // NOTE(review): presumably the number of coefficients — confirm against
    // `validate_polynomial_coeffs`.
    let num_coeffs = validate_polynomial_coeffs(coeffs.shape())?;
    let device = client.device();
    let index_dtype = dtype_support.index_dtype;

    // Seed the accumulator with the highest-order coefficient, expanded to the
    // shape of `x`. With a single coefficient the fold below runs zero times
    // and this seed is returned as-is.
    let top_index = create_index_tensor::<R>(num_coeffs - 1, index_dtype, device);
    let seed = client
        .index_select(coeffs, 0, &top_index)?
        .broadcast_to(x.shape())?
        .contiguous();

    // Horner step, from the second-highest coefficient down to the constant
    // term: acc <- acc * x + coeffs[k].
    (0..num_coeffs - 1)
        .rev()
        .try_fold(seed, |acc, k| -> Result<Tensor<R>> {
            let scaled = client.mul(&acc, x)?;
            let pick = create_index_tensor::<R>(k, index_dtype, device);
            let coeff_k = client.index_select(coeffs, 0, &pick)?;
            client.add(&scaled, &coeff_k)
        })
}