use core::hash::{Hash, Hasher};
use crate::kernel::{Complex, Float, Problem, ProblemKind, Tensor};
/// Direction of a discrete Fourier transform.
///
/// The explicit discriminants are the values reported by [`Sign::value`]:
/// `Forward` is `-1` and `Backward` is `+1`.
/// NOTE(review): presumably this is the sign of the exponent in the
/// transform kernel (FFTW-style convention) — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Sign {
    Forward = -1,
    Backward = 1,
}

impl Sign {
    /// Returns the numeric sign of this direction.
    ///
    /// `-1` for [`Sign::Forward`], `1` for [`Sign::Backward`]. Usable in
    /// `const` contexts.
    #[must_use]
    pub const fn value(self) -> i32 {
        match self {
            Self::Forward => -1,
            Self::Backward => 1,
        }
    }
}
/// Description of a (possibly batched) complex-to-complex DFT over raw buffers.
///
/// Only raw pointers to the input and output data are stored. Nothing in this
/// struct allocates or frees those buffers; their validity and lifetime are
/// the caller's responsibility.
#[derive(Debug, Clone)]
pub struct DftProblem<T: Float> {
/// Dimensions of a single transform; `transform_size()` is its `total_size()`.
pub sz: Tensor,
/// Batch dimensions; an empty tensor means a single transform (`batch_size()` is 1).
pub vecsz: Tensor,
/// Input buffer. When it equals `output` the problem is treated as in-place.
pub input: *mut Complex<T>,
/// Output buffer; `Problem::zero` writes zeros through this pointer.
pub output: *mut Complex<T>,
/// Transform direction.
pub sign: Sign,
}
// SAFETY: `DftProblem` holds raw `*mut Complex<T>` pointers, which are neither
// `Send` nor `Sync` on their own. These impls assert that moving or sharing a
// problem across threads is sound. That holds only if callers keep the
// pointed-to buffers alive for as long as the problem is used, and do not
// race on them.
// NOTE(review): `Problem::zero` takes `&self` but writes through `output`, so
// two threads sharing a `&DftProblem` and calling `zero` concurrently would
// race on the output buffer. The `Sync` impl therefore depends on an external
// synchronization contract that is not visible here — confirm it is
// documented for callers.
unsafe impl<T: Float> Send for DftProblem<T> {}
unsafe impl<T: Float> Sync for DftProblem<T> {}
impl<T: Float> DftProblem<T> {
    /// Builds an unbatched one-dimensional transform of length `n`.
    ///
    /// The batch tensor is left empty, so [`Self::batch_size`] reports 1.
    /// Passing the same pointer for `input` and `output` yields an in-place
    /// problem.
    #[must_use]
    pub fn new_1d(
        n: usize,
        input: *mut Complex<T>,
        output: *mut Complex<T>,
        sign: Sign,
    ) -> Self {
        let sz = Tensor::rank1(n);
        Self {
            sz,
            vecsz: Tensor::empty(),
            input,
            output,
            sign,
        }
    }

    /// Builds an unbatched two-dimensional transform of shape `n0 x n1`.
    ///
    /// As with [`Self::new_1d`], the batch tensor is left empty.
    #[must_use]
    pub fn new_2d(
        n0: usize,
        n1: usize,
        input: *mut Complex<T>,
        output: *mut Complex<T>,
        sign: Sign,
    ) -> Self {
        let sz = Tensor::rank2(n0, n1);
        Self {
            sz,
            vecsz: Tensor::empty(),
            input,
            output,
            sign,
        }
    }

    /// Returns `true` when the input and output pointers are the same address.
    #[must_use]
    pub fn is_inplace(&self) -> bool {
        core::ptr::eq(self.input, self.output)
    }

    /// Number of elements in one transform (the total size of `sz`).
    #[must_use]
    pub fn transform_size(&self) -> usize {
        self.sz.total_size()
    }

    /// Number of transforms in the batch.
    ///
    /// An empty `vecsz` means "no batching" and counts as a single transform.
    #[must_use]
    pub fn batch_size(&self) -> usize {
        if !self.vecsz.is_empty() {
            return self.vecsz.total_size();
        }
        1
    }
}
/// Hashes the problem's shape, batch shape, direction and in-place flag.
///
/// The `input`/`output` addresses themselves are not hashed; only whether they
/// alias is. Two problems with identical geometry over different buffers
/// therefore hash the same.
/// NOTE(review): presumably intended as a plan-cache key — confirm. No
/// `PartialEq`/`Eq` impl is visible in this section; if one exists or is
/// added, it must compare these same fields to keep `Hash` and `Eq` consistent.
impl<T: Float> Hash for DftProblem<T> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // A tuple hashes its elements in order, so this feeds the hasher the
        // same sequence as hashing each field one after another.
        let key = (&self.sz, &self.vecsz, self.sign, self.is_inplace());
        key.hash(state);
    }
}
impl<T: Float> Problem for DftProblem<T> {
fn kind(&self) -> ProblemKind {
ProblemKind::Dft
}
fn zero(&self) {
let size = self.sz.total_size() * self.vecsz.total_size().max(1);
unsafe {
for i in 0..size {
*self.output.add(i) = Complex::zero();
}
}
}
fn total_size(&self) -> usize {
self.transform_size() * self.batch_size()
}
fn is_inplace(&self) -> bool {
self.input == self.output
}
}