use crate::enums::{BaseKind, TransformKind};
use crate::traits::{
BaseElements, BaseFromOrtho, BaseGradient, BaseMatOpDiffmat, BaseMatOpLaplacian,
BaseMatOpStencil, BaseSize, BaseTransform,
};
use crate::types::{FloatNum, ScalarNum};
use ndarray::{s, Array2};
use num_complex::Complex;
use num_traits::{One, Zero};
use realfft::{ComplexToReal, RealFftPlanner, RealToComplex};
use std::convert::TryInto;
use std::f64::consts::PI;
use std::ops::{Add, Div, Mul, Sub};
use std::sync::Arc;
/// Real-to-complex Fourier basis.
///
/// Stores the physical size `n`, the spectral size `m` (set to `n / 2 + 1`
/// by `new`) and the realfft forward/inverse plans of length `n` used by the
/// transforms.
#[derive(Clone)]
pub struct FourierR2c<A> {
    /// Number of physical grid points.
    n: usize,
    /// Number of spectral coefficients (`n / 2 + 1`).
    m: usize,
    /// Forward real-to-complex FFT plan of length `n`.
    plan_fwd: Arc<dyn RealToComplex<A>>,
    /// Inverse complex-to-real FFT plan of length `n`.
    plan_bwd: Arc<dyn ComplexToReal<A>>,
}
impl<A: FloatNum> FourierR2c<A> {
    /// Creates a real-to-complex Fourier basis with `n` physical grid points.
    ///
    /// The spectral space holds `n / 2 + 1` complex coefficients. Forward and
    /// inverse FFT plans of length `n` are created once here.
    #[must_use]
    pub fn new(n: usize) -> Self {
        let mut planner = RealFftPlanner::<A>::new();
        Self {
            n,
            m: n / 2 + 1,
            // The planner already returns `Arc`s, so no extra clone is needed.
            plan_fwd: planner.plan_fft_forward(n),
            plan_bwd: planner.plan_fft_inverse(n),
        }
    }

    /// Returns the `n` equispaced grid points `x_j = j * (2π / n)` for
    /// `j = 0, …, n - 1`, covering `[0, 2π)`.
    ///
    /// The points are generated by index. A floating-point range up to `2π` is
    /// avoided because its length, `ceil(2π / (2π / n))`, can round up to
    /// `n + 1`. That would add a spurious point at ≈2π.
    ///
    /// # Panics
    /// Panics if a grid point cannot be represented in `A`.
    #[must_use]
    // `usize -> f64` casts are exact for any realistic grid size.
    #[allow(clippy::cast_precision_loss)]
    pub fn nodes(n: usize) -> Vec<A> {
        let step = 2. * PI / n as f64;
        (0..n)
            .map(|j| A::from_f64(step * j as f64).unwrap())
            .collect()
    }

    /// Returns the wavenumbers `i k` for `k = 0, …, n / 2` as purely imaginary
    /// complex numbers. There are `n / 2 + 1` of them, one per spectral
    /// coefficient.
    ///
    /// # Panics
    /// Panics if some `k` cannot be represented in `A`.
    #[must_use]
    pub fn wavenumber(n: usize) -> Vec<Complex<A>> {
        (0..=n / 2)
            .map(|x| Complex::new(A::zero(), A::from(x).unwrap()))
            .collect::<Vec<Complex<A>>>()
    }
}
impl<A: FloatNum> BaseSize for FourierR2c<A> {
    /// Size of the physical grid, as passed to `new`.
    fn len_phys(&self) -> usize {
        self.n
    }

    /// Number of complex spectral coefficients.
    fn len_spec(&self) -> usize {
        self.m
    }

    /// Size of the orthogonal space; for this basis it equals the spectral size.
    fn len_orth(&self) -> usize {
        self.len_spec()
    }
}
impl<A: FloatNum> BaseElements for FourierR2c<A> {
type RealNum = A;
fn base_kind(&self) -> BaseKind {
BaseKind::FourierR2c
}
fn transform_kind(&self) -> TransformKind {
TransformKind::R2c
}
fn coords(&self) -> Vec<A> {
Self::nodes(self.len_phys())
}
}
impl<A: FloatNum> BaseMatOpDiffmat for FourierR2c<A> {
    type NumType = Complex<A>;

    /// Spectral differentiation matrix of order `deriv`.
    ///
    /// The matrix is diagonal. Its `k`-th entry is `(i k)^deriv`, where `i k`
    /// is taken from `FourierR2c::wavenumber`.
    ///
    /// # Panics
    /// Panics if `deriv == 0` or if `deriv` does not fit in an `i32`.
    fn diffmat(&self, deriv: usize) -> Array2<Self::NumType> {
        assert!(deriv > 0);
        let size = self.len_spec();
        let mut diag_mat = Array2::<Self::NumType>::zeros((size, size));
        let ks = Self::wavenumber(self.len_phys());
        let exponent: i32 = deriv.try_into().unwrap();
        diag_mat
            .diag_mut()
            .iter_mut()
            .zip(ks.iter())
            .for_each(|(entry, k)| *entry = k.powi(exponent));
        diag_mat
    }

    /// Pseudo-inverse of `diffmat(deriv)` together with a row-selection matrix.
    ///
    /// Returns `(pinv, peye)`:
    /// - `pinv` is `diffmat(deriv)` with every diagonal entry inverted except
    ///   the first. The first entry is the zero `k = 0` mode.
    /// - `peye` is the `m × m` identity with its first row removed.
    ///
    /// # Panics
    /// Panics if `deriv == 0`.
    fn diffmat_pinv(&self, deriv: usize) -> (Array2<Self::NumType>, Array2<Self::NumType>) {
        assert!(deriv > 0);
        let identity = Array2::<Self::NumType>::eye(self.m);
        let peye: Array2<Self::NumType> = identity.slice(s![1.., ..]).to_owned();
        let mut pinv = self.diffmat(deriv);
        pinv.slice_mut(s![1.., 1..])
            .diag_mut()
            .iter_mut()
            .for_each(|entry| *entry = Self::NumType::one() / *entry);
        (pinv, peye)
    }
}
impl<A: FloatNum> BaseMatOpStencil for FourierR2c<A> {
    type NumType = A;

    /// Stencil of this basis: the `m × m` identity.
    fn stencil(&self) -> Array2<Self::NumType> {
        let size = self.len_spec();
        Array2::<A>::eye(size)
    }

    /// Inverse stencil; the identity is its own inverse.
    fn stencil_inv(&self) -> Array2<Self::NumType> {
        let size = self.len_spec();
        Array2::<A>::eye(size)
    }
}
impl<A: FloatNum> BaseMatOpLaplacian for FourierR2c<A> {
    type NumType = A;

    /// Spectral Laplacian: an `m × m` diagonal matrix with entries `-k²`.
    fn laplacian(&self) -> Array2<Self::NumType> {
        let size = self.m;
        let ks = Self::wavenumber(self.len_phys());
        let mut lap = Array2::<A>::zeros((size, size));
        lap.diag_mut()
            .iter_mut()
            .zip(ks.iter())
            .for_each(|(entry, k)| *entry = -k.im * k.im);
        lap
    }

    /// Pseudo-inverse of the Laplacian together with a row-selection matrix.
    ///
    /// Returns `(d_pinv, i_pinv)`:
    /// - `d_pinv` is the Laplacian with every diagonal entry inverted except
    ///   the first. The first entry is the zero `k = 0` mode.
    /// - `i_pinv` is the `m × m` identity with its first row removed.
    fn laplacian_pinv(&self) -> (Array2<Self::NumType>, Array2<Self::NumType>) {
        let mut d_pinv = self.laplacian();
        d_pinv
            .slice_mut(s![1.., 1..])
            .diag_mut()
            .iter_mut()
            .for_each(|entry| *entry = A::one() / *entry);
        let i_pinv = Array2::<A>::eye(self.m).slice(s![1.., ..]).to_owned();
        (d_pinv, i_pinv)
    }
}
impl<A, T> BaseGradient<T> for FourierR2c<A>
where
    A: FloatNum,
    T: ScalarNum
        + Add<Complex<A>, Output = T>
        + Mul<Complex<A>, Output = T>
        + Div<Complex<A>, Output = T>
        + Sub<Complex<A>, Output = T>,
{
    /// Differentiates spectral coefficients `n_times` times.
    ///
    /// `indata` is copied into `outdata`. The coefficient at index `k` is then
    /// multiplied by `i k` once per derivative order.
    ///
    /// # Panics
    /// Panics unless `indata`, `outdata` and the spectral size all have the
    /// same length.
    fn gradient_slice(&self, indata: &[T], outdata: &mut [T], n_times: usize) {
        assert!(outdata.len() == indata.len());
        assert!(outdata.len() == self.len_spec());
        outdata.copy_from_slice(indata);
        // Each coefficient is independent, so the inner loop applies all
        // `n_times` factors of `i k` to one coefficient at a time.
        for (idx, coeff) in outdata.iter_mut().enumerate() {
            let ik: Complex<A> = Complex::<A>::new(A::zero(), A::from_usize(idx).unwrap());
            for _ in 0..n_times {
                *coeff = *coeff * ik;
            }
        }
    }
}
impl<A, T> BaseFromOrtho<T> for FourierR2c<A>
where
    A: FloatNum,
    T: ScalarNum,
{
    /// Identity mapping to orthogonal coefficients.
    ///
    /// Copies as many elements as the shorter of the two slices holds.
    fn to_ortho_slice(&self, indata: &[T], outdata: &mut [T]) {
        outdata
            .iter_mut()
            .zip(indata)
            .for_each(|(dst, src)| *dst = *src);
    }

    /// Identity mapping from orthogonal coefficients.
    ///
    /// Copies as many elements as the shorter of the two slices holds.
    fn from_ortho_slice(&self, indata: &[T], outdata: &mut [T]) {
        outdata
            .iter_mut()
            .zip(indata)
            .for_each(|(dst, src)| *dst = *src);
    }
}
impl<A: FloatNum> BaseTransform for FourierR2c<A> {
type Physical = A;
type Spectral = Complex<A>;
fn forward_slice(&self, indata: &[Self::Physical], outdata: &mut [Self::Spectral]) {
assert!(indata.len() == self.len_phys());
assert!(outdata.len() == self.len_spec());
let mut indata_mut = vec![Self::Physical::zero(); indata.len()];
for (a, b) in indata_mut.iter_mut().zip(indata.iter()) {
*a = *b;
}
self.plan_fwd.process(&mut indata_mut, outdata).unwrap();
}
fn backward_slice(&self, indata: &[Self::Spectral], outdata: &mut [Self::Physical]) {
assert!(indata.len() == self.len_spec());
assert!(outdata.len() == self.len_phys());
let mut indata_mut = vec![Self::Spectral::zero(); indata.len()];
let cor = A::one() / A::from_usize(self.len_phys()).unwrap();
for (a, b) in indata_mut.iter_mut().zip(indata.iter()) {
a.re = cor * b.re;
a.im = cor * b.im;
}
indata_mut[0].im = A::zero();
if self.n % 2 == 0 {
indata_mut[self.m - 1].im = A::zero();
}
self.plan_bwd.process(&mut indata_mut, outdata).unwrap();
}
}