mod convolve;
mod polyfromroots;
mod polymul;
mod polyroots;
mod polyval;
pub use convolve::convolve_impl;
pub use polyfromroots::polyfromroots_impl;
pub use polymul::polymul_impl;
pub use polyroots::polyroots_impl;
pub use polyval::polyval_impl;
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::runtime::Runtime;
use crate::tensor::Tensor;
/// Flags describing which floating-point dtypes are accepted, together with
/// an associated integer dtype for index tensors.
///
/// Validation against these flags is performed by [`DTypeSupport::check`].
#[derive(Debug, Clone, Copy)]
pub struct DTypeSupport {
/// When `true`, `check` accepts `F32` and also the reduced-precision
/// `F16`, `BF16`, `FP8E4M3` and `FP8E5M2` dtypes.
pub f32: bool,
/// When `true`, `check` accepts `F64`.
pub f64: bool,
/// Integer dtype paired with this profile (`I64` in `FULL`, `I32` in
/// `F32_ONLY`).
/// NOTE(review): presumably the dtype callers hand to the index-tensor
/// helpers in this module — confirm against call sites.
pub index_dtype: DType,
}
impl DTypeSupport {
pub const FULL: Self = Self {
f32: true,
f64: true,
index_dtype: DType::I64,
};
pub const F32_ONLY: Self = Self {
f32: true,
f64: false,
index_dtype: DType::I32,
};
pub fn check(&self, dtype: DType, op: &'static str) -> Result<()> {
match dtype {
DType::F32 if self.f32 => Ok(()),
DType::F64 if self.f64 => Ok(()),
DType::F16 | DType::BF16 | DType::FP8E4M3 | DType::FP8E5M2 if self.f32 => Ok(()),
DType::F32 | DType::F64 => Err(Error::UnsupportedDType { dtype, op }),
_ => Err(Error::UnsupportedDType { dtype, op }),
}
}
}
/// Builds a one-element tensor of shape `[1]` holding `index` on `device`.
///
/// The element is stored as `i32` when `index_dtype` is `DType::I32`; any
/// other `index_dtype` value falls back to `i64`.
///
/// The conversion uses `as`, so an `index` that does not fit the target
/// integer type is silently truncated rather than reported.
pub(crate) fn create_index_tensor<R: Runtime<DType = DType>>(
    index: usize,
    index_dtype: DType,
    device: &R::Device,
) -> Tensor<R> {
    if matches!(index_dtype, DType::I32) {
        let narrow = index as i32;
        Tensor::<R>::from_slice(&[narrow], &[1], device)
    } else {
        let wide = index as i64;
        Tensor::<R>::from_slice(&[wide], &[1], device)
    }
}
/// Builds a 1-D tensor on `device` containing the integers of the half-open
/// range `start..end`, in ascending order.
///
/// Elements are stored as `i32` when `index_dtype` is `DType::I32`; any other
/// `index_dtype` value falls back to `i64`. The tensor shape is the number of
/// generated elements, which is zero when `start >= end`.
///
/// Each value is converted with `as`, so values that do not fit the target
/// integer type are silently truncated.
pub(crate) fn create_arange_tensor<R: Runtime<DType = DType>>(
    start: usize,
    end: usize,
    index_dtype: DType,
    device: &R::Device,
) -> Tensor<R> {
    let span = start..end;

    if matches!(index_dtype, DType::I32) {
        let mut values: Vec<i32> = Vec::with_capacity(span.len());
        values.extend(span.map(|i| i as i32));
        Tensor::<R>::from_slice(&values, &[values.len()], device)
    } else {
        let mut values: Vec<i64> = Vec::with_capacity(span.len());
        values.extend(span.map(|i| i as i64));
        Tensor::<R>::from_slice(&values, &[values.len()], device)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises `DTypeSupport::check` for both predefined profiles.
    #[test]
    fn test_dtype_support() {
        // (profile, dtype under test, whether `check` should succeed)
        let cases = [
            (DTypeSupport::FULL, DType::F32, true),
            (DTypeSupport::FULL, DType::F64, true),
            (DTypeSupport::FULL, DType::I32, false),
            (DTypeSupport::F32_ONLY, DType::F32, true),
            (DTypeSupport::F32_ONLY, DType::F64, false),
        ];

        for &(support, dtype, expect_ok) in cases.iter() {
            let outcome = support.check(dtype, "test");
            if expect_ok {
                assert!(outcome.is_ok());
            } else {
                assert!(outcome.is_err());
            }
        }
    }
}