use super::{DTypeSupport, create_arange_tensor, create_index_tensor};
use crate::algorithm::linalg::LinearAlgebraAlgorithms;
use crate::algorithm::polynomial::helpers::validate_polynomial_coeffs;
use crate::algorithm::polynomial::helpers::validate_polynomial_dtype;
use crate::algorithm::polynomial::types::PolynomialRoots;
use crate::dtype::DType;
use crate::error::Result;
use crate::ops::{
BinaryOps, CompareOps, IndexingOps, LinalgOps, ReduceOps, ScalarOps, ShapeOps, UtilityOps,
};
use crate::runtime::{Runtime, RuntimeClient};
use crate::tensor::Tensor;
/// Computes the roots of a polynomial as the eigenvalues of its companion matrix.
///
/// `coeffs` is a 1-D tensor of `n` coefficients. Index `n - 1` holds the leading
/// coefficient and index `0` the constant term, i.e. ascending powers:
/// `p(x) = c_0 + c_1 x + ... + c_d x^d` with `d = n - 1`.
///
/// The companion matrix of `p` is the `d x d` matrix with ones on the sub-diagonal
/// and `-c_i / c_d` (for `i in 0..d`) in its last column. Its eigenvalues are the
/// roots of `p`.
///
/// # Returns
/// The real and imaginary parts of the roots, exactly as produced by
/// `eig_decompose`. A constant polynomial (`n == 1`) has no roots and yields two
/// empty tensors of the input dtype.
///
/// # Errors
/// Propagates errors from dtype/shape validation and from every tensor operation
/// used to build and decompose the companion matrix.
///
/// NOTE(review): a zero leading coefficient is not rejected here. The division
/// that forms the last column would then produce non-finite entries. Confirm
/// whether `validate_polynomial_coeffs` or callers guard against this.
pub fn polyroots_impl<R, C>(
    client: &C,
    coeffs: &Tensor<R>,
    dtype_support: DTypeSupport,
) -> Result<PolynomialRoots<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R>
        + LinearAlgebraAlgorithms<R>
        + BinaryOps<R>
        + ScalarOps<R>
        + LinalgOps<R>
        + IndexingOps<R>
        + ShapeOps<R>
        + UtilityOps<R>
        + ReduceOps<R>
        + CompareOps<R>,
{
    validate_polynomial_dtype(coeffs.dtype())?;
    dtype_support.check(coeffs.dtype(), "polyroots")?;
    let n = validate_polynomial_coeffs(coeffs.shape())?;

    let dtype = coeffs.dtype();
    let device = client.device();
    let index_dtype = dtype_support.index_dtype;

    // A constant polynomial has no roots.
    if n == 1 {
        return Ok(PolynomialRoots {
            roots_real: Tensor::zeros(&[0], dtype, device),
            roots_imag: Tensor::zeros(&[0], dtype, device),
        });
    }

    let degree = n - 1;

    // The leading coefficient c_d is stored at the last index.
    let last_idx = create_index_tensor::<R>(degree, index_dtype, device);
    let leading_tensor = client.index_select(coeffs, 0, &last_idx)?;

    // Companion skeleton: an identity block on the sub-diagonal, a zero top row and
    // a zero last column (the last column is filled in below):
    //   [ 0 .. 0  | 0 ]
    //   [ I_{d-1} | 0 ]
    // For d == 1 there is no sub-diagonal, so the skeleton is just a 1x1 zero.
    let companion = if degree > 1 {
        let sub_eye = client.eye(degree - 1, None, dtype)?;
        let zeros_row = Tensor::zeros(&[1, degree - 1], dtype, device);
        let zeros_col = Tensor::zeros(&[degree, 1], dtype, device);
        let sub_with_top = client.cat(&[&zeros_row, &sub_eye], 0)?;
        client.cat(&[&sub_with_top, &zeros_col], 1)?
    } else {
        Tensor::zeros(&[degree, degree], dtype, device)
    };

    // Last column entries: -c_i / c_d for i in 0..d.
    let coeff_indices = create_arange_tensor::<R>(0, degree, index_dtype, device);
    let lower_coeffs = client.index_select(coeffs, 0, &coeff_indices)?;
    let neg_coeffs = client.neg(&lower_coeffs)?;
    let leading_broadcast = leading_tensor.broadcast_to(&[degree])?.contiguous();
    let last_col = client.div(&neg_coeffs, &leading_broadcast)?;
    let last_col_2d = last_col.reshape(&[degree, 1])?;

    // Every row writes its entry into column d - 1.
    let col_indices_2d =
        Tensor::full_scalar(&[degree, 1], index_dtype, (degree - 1) as f64, device);
    let companion = client.scatter(&companion, 1, &col_indices_2d, &last_col_2d)?;

    // The eigenvalues of the companion matrix are the polynomial's roots.
    let eig = client.eig_decompose(&companion)?;
    Ok(PolynomialRoots {
        roots_real: eig.eigenvalues_real,
        roots_imag: eig.eigenvalues_imag,
    })
}