use core::hash::{Hash, Hasher};
use crate::kernel::{Complex, Float, Problem, ProblemKind, Tensor};
/// The flavour of real-data transform a [`RdftProblem`] describes.
///
/// Variant order is significant: it determines the derived `Hash` output.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum RdftKind {
    /// Real input, complex output (see `RdftProblem::new_r2c_1d`).
    R2C,
    /// Complex input, real output (see `RdftProblem::new_c2r_1d`).
    C2R,
    /// Presumably real to "halfcomplex" packed output, per FFTW naming — confirm.
    R2HC,
    /// Presumably "halfcomplex" packed input to real output, per FFTW naming — confirm.
    HC2R,
    /// Presumably a general real-to-real transform — confirm against the solvers.
    R2R,
}
/// A real-data DFT problem description.
///
/// It holds the transform shape, the batch shape, two raw buffer pointers, and the
/// transform kind. The problem does not own its buffers. It only records raw
/// pointers to them, which is why cloning it is cheap and copies addresses, not data.
#[derive(Debug, Clone)]
pub struct RdftProblem<T>
where
    T: Float,
{
    /// Shape of a single transform.
    pub sz: Tensor,
    /// Shape of the batch ("vector") of transforms; empty for an unbatched problem.
    pub vecsz: Tensor,
    /// Real-valued buffer: the input for `new_r2c_1d`, the output for `new_c2r_1d`.
    pub real_buf: *mut T,
    /// Complex-valued buffer: the output for `new_r2c_1d`, the input for `new_c2r_1d`.
    pub complex_buf: *mut Complex<T>,
    /// Which real-data transform variant this problem describes.
    pub kind: RdftKind,
}

// SAFETY: NOTE(review): this type only stores raw pointers and, in this file,
// never dereferences them (they are only compared). These impls let a problem
// cross or be shared between threads. The compiler cannot check that the
// pointed-to buffers stay valid and free of conflicting access while that
// happens. Confirm every executor that reads/writes through these pointers
// upholds that invariant.
unsafe impl<T> Send for RdftProblem<T> where T: Float {}
unsafe impl<T> Sync for RdftProblem<T> where T: Float {}
impl<T> RdftProblem<T>
where
    T: Float,
{
    /// Builds an unbatched one-dimensional real-to-complex problem of length `n`.
    ///
    /// `real_input` becomes `real_buf` and `complex_output` becomes `complex_buf`.
    /// The pointers are stored as-is and are not dereferenced here.
    #[must_use]
    pub fn new_r2c_1d(n: usize, real_input: *mut T, complex_output: *mut Complex<T>) -> Self {
        let (sz, vecsz) = (Tensor::rank1(n), Tensor::empty());
        Self {
            sz,
            vecsz,
            real_buf: real_input,
            complex_buf: complex_output,
            kind: RdftKind::R2C,
        }
    }

    /// Builds an unbatched one-dimensional complex-to-real problem of length `n`.
    ///
    /// `complex_input` becomes `complex_buf` and `real_output` becomes `real_buf`.
    /// The pointers are stored as-is and are not dereferenced here.
    #[must_use]
    pub fn new_c2r_1d(n: usize, complex_input: *mut Complex<T>, real_output: *mut T) -> Self {
        let (sz, vecsz) = (Tensor::rank1(n), Tensor::empty());
        Self {
            sz,
            vecsz,
            real_buf: real_output,
            complex_buf: complex_input,
            kind: RdftKind::C2R,
        }
    }

    /// Number of points in one transform, as reported by `Tensor::total_size` on `sz`.
    #[must_use]
    pub fn transform_size(&self) -> usize {
        let shape = &self.sz;
        shape.total_size()
    }

    /// Number of complex values on the complex side: `transform_size() / 2 + 1`.
    ///
    /// NOTE(review): this is the half-spectrum length of a 1-D transform. For a
    /// multi-dimensional `sz` it is not the usual `(last / 2 + 1) * rest` size,
    /// so confirm it is only used for rank-1 problems.
    #[must_use]
    pub fn complex_size(&self) -> usize {
        let n = self.transform_size();
        n / 2 + 1
    }
}
impl<T: Float> Hash for RdftProblem<T> {
    /// Hashes the problem's geometry, its kind, and whether it is in-place.
    ///
    /// Buffer addresses are deliberately excluded, so the same geometry on
    /// different buffers hashes identically. In-placeness is included because an
    /// in-place and an out-of-place transform of the same geometry are different
    /// problems (`Problem::is_inplace` tells them apart). A plan built for one is
    /// not valid for the other. NOTE(review): this matters if the hash keys a
    /// plan cache; confirm against the planner.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.sz.hash(state);
        self.vecsz.hash(state);
        self.kind.hash(state);
        // Same comparison as `Problem::is_inplace`, written inline so this impl
        // does not depend on the trait impl.
        let inplace = self.real_buf as *const () == self.complex_buf as *const ();
        inplace.hash(state);
    }
}
impl<T> Problem for RdftProblem<T>
where
    T: Float,
{
    /// Identifies this problem as a real-data DFT.
    fn kind(&self) -> ProblemKind {
        ProblemKind::Rdft
    }

    /// Does nothing; neither buffer is touched.
    ///
    /// NOTE(review): confirm whether callers expect this to clear the input buffer.
    fn zero(&self) {}

    /// Points in one transform multiplied by the batch size.
    ///
    /// A batch tensor whose total size is 0 is counted as a batch of one, so an
    /// unbatched problem reports just its transform size.
    fn total_size(&self) -> usize {
        let per_transform = self.sz.total_size();
        let batch = core::cmp::max(self.vecsz.total_size(), 1);
        per_transform * batch
    }

    /// `true` when the real and complex buffers start at the same address.
    fn is_inplace(&self) -> bool {
        core::ptr::eq(self.real_buf.cast::<()>(), self.complex_buf.cast::<()>())
    }
}