/// Normalization convention applied to a forward/inverse transform pair.
///
/// The concrete multiplier for a given length and direction is produced by
/// [`FftNorm::scale_factor_f64`]. The default mode is [`FftNorm::Backward`].
#[derive(Debug, Clone, Copy, Default)]
#[derive(PartialEq, Eq, Hash)]
pub enum FftNorm {
    /// Forward transform is left unscaled; the inverse is scaled by `1 / n`.
    #[default]
    Backward,
    /// Forward transform is scaled by `1 / n`; the inverse is left unscaled.
    Forward,
    /// Both directions are scaled by `1 / sqrt(n)`.
    Ortho,
}
/// Which half of a transform pair is being computed.
///
/// Combined with an [`FftNorm`] mode to select the multiplier returned by
/// [`FftNorm::scale_factor_f64`].
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum FftDirection {
    /// The forward transform.
    Forward,
    /// The inverse transform.
    Inverse,
}
impl FftNorm {
    /// Returns the multiplier to apply to a transform of length `n` computed
    /// in `direction` under this normalization mode.
    ///
    /// | mode       | `Forward`     | `Inverse`     |
    /// |------------|---------------|---------------|
    /// | `Backward` | `1`           | `1 / n`       |
    /// | `Forward`  | `1 / n`       | `1`           |
    /// | `Ortho`    | `1 / sqrt(n)` | `1 / sqrt(n)` |
    ///
    /// In every mode the forward and inverse factors multiply to `1 / n`.
    ///
    /// For `n == 0` this returns `1.0`. An empty transform has no samples to
    /// scale, and dividing by zero would otherwise produce `inf`. That value
    /// poisons any arithmetic the caller combines it with.
    #[must_use]
    pub fn scale_factor_f64(self, n: usize, direction: FftDirection) -> f64 {
        if n == 0 {
            // Nothing to scale; keep the factor finite instead of 1/0 = inf.
            return 1.0;
        }
        // `usize -> f64` is exact for all lengths up to 2^53.
        let nf = n as f64;
        match (self, direction) {
            // The unscaled side of a `Backward`/`Forward` pair.
            (FftNorm::Backward, FftDirection::Forward)
            | (FftNorm::Forward, FftDirection::Inverse) => 1.0,
            // The side of a `Backward`/`Forward` pair that carries the full 1/n.
            (FftNorm::Backward, FftDirection::Inverse)
            | (FftNorm::Forward, FftDirection::Forward) => 1.0 / nf,
            // `Ortho` splits 1/n evenly between the two directions.
            (FftNorm::Ortho, _) => 1.0 / nf.sqrt(),
        }
    }

    /// Crate-internal shorthand for [`FftNorm::scale_factor_f64`].
    #[inline]
    #[must_use]
    pub(crate) fn scale_factor(self, n: usize, direction: FftDirection) -> f64 {
        self.scale_factor_f64(n, direction)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing computed scale factors.
    const TOL: f64 = 1e-15;

    /// True when `actual` lies strictly within [`TOL`] of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOL
    }

    #[test]
    fn backward_norm_factors() {
        let len = 8;
        let fwd = FftNorm::Backward.scale_factor(len, FftDirection::Forward);
        let inv = FftNorm::Backward.scale_factor(len, FftDirection::Inverse);
        assert_eq!(fwd, 1.0);
        assert!(close(inv, 0.125));
    }

    #[test]
    fn forward_norm_factors() {
        let len = 8;
        let fwd = FftNorm::Forward.scale_factor(len, FftDirection::Forward);
        let inv = FftNorm::Forward.scale_factor(len, FftDirection::Inverse);
        assert!(close(fwd, 0.125));
        assert_eq!(inv, 1.0);
    }

    #[test]
    fn ortho_norm_factors() {
        let len = 4;
        // 1 / sqrt(4)
        let want = 0.5;
        for dir in [FftDirection::Forward, FftDirection::Inverse] {
            assert!(close(FftNorm::Ortho.scale_factor(len, dir), want));
        }
    }

    #[test]
    fn default_is_backward() {
        assert_eq!(FftNorm::default(), FftNorm::Backward);
    }
}