use super::{DTypeSupport, convolve_impl, create_index_tensor};
use crate::algorithm::fft::FftAlgorithms;
use crate::algorithm::polynomial::helpers::{validate_polynomial_dtype, validate_polynomial_roots};
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::{BinaryOps, ComplexOps, IndexingOps, ReduceOps, ShapeOps, UnaryOps, UtilityOps};
use crate::runtime::{Runtime, RuntimeClient};
use crate::tensor::Tensor;
/// Builds the coefficients of the monic polynomial whose roots are given.
///
/// The roots are supplied as two tensors holding their real and imaginary
/// parts. The polynomial is the product of the linear factors `(x - r_i)` over
/// all roots. Each factor is represented as the coefficient pair `[-r_i, 1]`,
/// constant term first. Assuming `convolve_impl` computes a full linear
/// convolution (TODO confirm), the coefficients therefore come out in
/// ascending order of power.
///
/// The product is accumulated in complex arithmetic:
/// `(a + ib)(c + id) = (ac - bd) + i(ad + bc)`, with each term computed as a
/// convolution. Only the real part of the resulting coefficients is returned;
/// the accumulated imaginary part is discarded. For purely real roots, or
/// roots that come in complex-conjugate pairs, that imaginary part is
/// mathematically zero.
///
/// With zero roots, the constant polynomial `[1]` is returned.
///
/// # Errors
///
/// - Any error from `validate_polynomial_dtype`, `validate_polynomial_roots`
///   or `dtype_support.check`.
/// - [`Error::DTypeMismatch`] if `roots_real` and `roots_imag` have different
///   dtypes.
/// - [`Error::ShapeMismatch`] if they describe a different number of roots.
/// - Any error propagated from the client ops or `convolve_impl`.
pub fn polyfromroots_impl<R, C>(
    client: &C,
    roots_real: &Tensor<R>,
    roots_imag: &Tensor<R>,
    dtype_support: DTypeSupport,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R>
        + BinaryOps<R>
        + UnaryOps<R>
        + IndexingOps<R>
        + ShapeOps<R>
        + UtilityOps<R>
        + ReduceOps<R>
        + FftAlgorithms<R>
        + ComplexOps<R>,
{
    validate_polynomial_dtype(roots_real.dtype())?;
    validate_polynomial_dtype(roots_imag.dtype())?;
    dtype_support.check(roots_real.dtype(), "polyfromroots")?;
    if roots_real.dtype() != roots_imag.dtype() {
        return Err(Error::DTypeMismatch {
            lhs: roots_real.dtype(),
            rhs: roots_imag.dtype(),
        });
    }

    let n_roots = validate_polynomial_roots(roots_real.shape())?;
    let n_imag = validate_polynomial_roots(roots_imag.shape())?;
    if n_roots != n_imag {
        return Err(Error::ShapeMismatch {
            expected: vec![n_roots],
            got: vec![n_imag],
        });
    }

    let dtype = roots_real.dtype();
    let device = client.device();
    let index_dtype = dtype_support.index_dtype;

    // The empty product is the constant polynomial 1.
    if n_roots == 0 {
        return Ok(Tensor::full_scalar(&[1], dtype, 1.0, device));
    }

    // The leading coefficient of every linear factor is 1 + 0i. It is the same
    // on every iteration, so build it once rather than reallocating per root.
    let one = Tensor::full_scalar(&[1], dtype, 1.0, device);
    let zero = Tensor::full_scalar(&[1], dtype, 0.0, device);

    // Running product, starting from the constant polynomial 1 + 0i.
    let mut p_real = Tensor::full_scalar(&[1], dtype, 1.0, device);
    let mut p_imag = Tensor::full_scalar(&[1], dtype, 0.0, device);

    for i in 0..n_roots {
        // Extract root i as single-element tensors (real and imaginary parts).
        let idx = create_index_tensor::<R>(i, index_dtype, device);
        let r_real = client.index_select(roots_real, 0, &idx)?;
        let r_imag = client.index_select(roots_imag, 0, &idx)?;

        // Factor (x - r_i) as coefficients [-r_i, 1], split into re/im parts.
        let neg_r_real = client.neg(&r_real)?;
        let neg_r_imag = client.neg(&r_imag)?;
        let factor_real = client.cat(&[&neg_r_real, &one], 0)?;
        let factor_imag = client.cat(&[&neg_r_imag, &zero], 0)?;

        // Complex polynomial multiplication via four real convolutions:
        //   re = p_re * f_re - p_im * f_im
        //   im = p_re * f_im + p_im * f_re
        let conv_rr = convolve_impl(client, &p_real, &factor_real, dtype_support)?;
        let conv_ii = convolve_impl(client, &p_imag, &factor_imag, dtype_support)?;
        let conv_ri = convolve_impl(client, &p_real, &factor_imag, dtype_support)?;
        let conv_ir = convolve_impl(client, &p_imag, &factor_real, dtype_support)?;
        p_real = client.sub(&conv_rr, &conv_ii)?;
        p_imag = client.add(&conv_ri, &conv_ir)?;
    }

    // Only the real part is returned; `p_imag` is intentionally dropped here.
    Ok(p_real)
}