use std::f64::consts::PI;
use std::panic::{catch_unwind, AssertUnwindSafe};
use std::ptr;
use super::types::{SciComplexVector, SciResult};
/// Minimal complex number `re + i·im` used by the internal FFT routines.
#[derive(Clone, Copy, Debug)]
struct Complex {
    /// Real part.
    re: f64,
    /// Imaginary part.
    im: f64,
}

impl Complex {
    /// Builds the complex number `re + i·im`.
    fn new(re: f64, im: f64) -> Self {
        Self { re, im }
    }

    /// Complex product: `(a + bi)(c + di) = (ac − bd) + (ad + bc)i`.
    fn mul(self, rhs: Self) -> Self {
        let Self { re: a, im: b } = self;
        let Self { re: c, im: d } = rhs;
        Self::new(a * c - b * d, a * d + b * c)
    }

    /// Component-wise sum.
    fn add(self, rhs: Self) -> Self {
        Self::new(self.re + rhs.re, self.im + rhs.im)
    }

    /// Component-wise difference `self − rhs`.
    fn sub(self, rhs: Self) -> Self {
        Self::new(self.re - rhs.re, self.im - rhs.im)
    }
}
/// Returns `true` exactly when `n` is a power of two (1, 2, 4, …); `0` is not.
fn is_power_of_2(n: usize) -> bool {
    n.is_power_of_two()
}
/// Returns the smallest power of two that is `>= n`; `0` maps to `1`.
///
/// Delegates to [`usize::next_power_of_two`], which works for any pointer width.
/// A hand-rolled bit-smear needs a literal `>> 32` step, and that shift is a
/// compile-time overflow error on 32-bit targets.
///
/// # Panics
///
/// In debug builds, panics if the result would exceed `usize::MAX`. Release
/// builds wrap to `0`, as documented for `usize::next_power_of_two`.
fn next_power_of_2(n: usize) -> usize {
    n.next_power_of_two()
}
/// In-place iterative radix-2 Cooley–Tukey FFT.
///
/// Computes `X_k = Σ_j x_j · exp(s · 2πi · jk / n)`, where `n = data.len()`.
/// `s = -1` for the forward transform and `s = +1` when `inverse` is true.
/// No `1/n` normalisation is applied in either direction.
///
/// # Panics
///
/// Panics if `data.len()` is not a power of two. Lengths 0 and 1 are returned
/// unchanged.
fn fft_radix2(data: &mut [Complex], inverse: bool) {
    let n = data.len();
    if n <= 1 {
        return;
    }
    // The butterfly loop below assumes a power-of-two length. Check it in release
    // builds too, so misuse fails loudly instead of yielding a wrong spectrum.
    assert!(
        n.is_power_of_two(),
        "fft_radix2: length {} is not a power of two",
        n
    );

    // Permute into bit-reversed index order so the butterflies can run in place.
    // `j` tracks the bit-reverse of `i` via a reversed-carry increment.
    let mut j: usize = 0;
    for i in 1..n {
        let mut bit = n >> 1;
        while j & bit != 0 {
            j ^= bit;
            bit >>= 1;
        }
        j ^= bit;
        if i < j {
            data.swap(i, j);
        }
    }

    let sign = if inverse { 1.0 } else { -1.0 };
    // Twiddle factors w_t = exp(s·2πi·t/n) for t in 0..n/2, each computed directly
    // with cos/sin. Generating them by repeated multiplication (w *= w_base) lets
    // rounding error accumulate linearly along each stage.
    let twiddles: Vec<Complex> = (0..n / 2)
        .map(|t| {
            let angle = sign * 2.0 * PI * (t as f64) / (n as f64);
            Complex::new(angle.cos(), angle.sin())
        })
        .collect();

    let mut size = 2;
    while size <= n {
        let half = size / 2;
        // exp(s·2πi·jj/size) == twiddles[jj * (n / size)].
        let stride = n / size;
        for block in data.chunks_exact_mut(size) {
            let (lo, hi) = block.split_at_mut(half);
            for (jj, (top, bottom)) in lo.iter_mut().zip(hi.iter_mut()).enumerate() {
                let u = *top;
                let t = twiddles[jj * stride].mul(*bottom);
                *top = u.add(t);
                *bottom = u.sub(t);
            }
        }
        size <<= 1;
    }
}
/// Discrete Fourier transform of arbitrary length (Bluestein / chirp-z algorithm).
///
/// Computes `X_k = Σ_j x_j · exp(s · 2πi · jk / n)`, with `s = -1` for the forward
/// transform and `s = +1` when `inverse` is true. This is the same sign convention
/// as `fft_radix2`, and no `1/n` normalisation is applied. Lengths 0 and 1 are
/// returned as-is, and power-of-two lengths are delegated straight to `fft_radix2`.
///
/// For other lengths, the identity `jk = (j² + k² − (k − j)²) / 2` and the chirp
/// `w_k = exp(s · πi · k² / n)` rewrite the DFT as
/// `X_k = w_k · Σ_j (x_j · w_j) · conj(w_{k−j})`.
/// This is a circular convolution, evaluated with three radix-2 FFTs of length
/// `m = next_power_of_2(2n − 1)`.
fn fft_bluestein(input: &[Complex], inverse: bool) -> Vec<Complex> {
    let n = input.len();
    if n <= 1 {
        return input.to_vec();
    }
    if is_power_of_2(n) {
        let mut data = input.to_vec();
        fft_radix2(&mut data, inverse);
        return data;
    }

    let sign = if inverse { 1.0 } else { -1.0 };
    // exp(s·πi·k²/n) depends only on k² mod 2n. Reducing first keeps the angle in
    // [0, 2π), so precision does not degrade for large k; u128 prevents k² overflow.
    let two_n = 2 * (n as u128);
    let chirp: Vec<Complex> = (0..n)
        .map(|k| {
            let k2_mod = (k as u128 * k as u128) % two_n;
            let angle = sign * PI * (k2_mod as f64) / (n as f64);
            Complex::new(angle.cos(), angle.sin())
        })
        .collect();
    let conj = |c: Complex| Complex::new(c.re, -c.im);

    let m = next_power_of_2(2 * n - 1);
    let zero = Complex::new(0.0, 0.0);

    // a_j = x_j · w_j, zero-padded to length m.
    let mut a = vec![zero; m];
    for (slot, (&x, &w)) in a.iter_mut().zip(input.iter().zip(&chirp)) {
        *slot = x.mul(w);
    }

    // b_d = conj(w_|d|) for |d| < n, laid out circularly (negative d wraps to m + d).
    let mut b = vec![zero; m];
    b[0] = conj(chirp[0]);
    for (k, &w) in chirp.iter().enumerate().skip(1) {
        let w_conj = conj(w);
        b[k] = w_conj;
        b[m - k] = w_conj;
    }

    // Circular convolution a ⊛ b via the convolution theorem.
    fft_radix2(&mut a, false);
    fft_radix2(&mut b, false);
    for (x, &y) in a.iter_mut().zip(&b) {
        *x = x.mul(y);
    }
    fft_radix2(&mut a, true);

    // Undo the length-m round trip scaling, then apply the output chirp w_k.
    let scale = 1.0 / m as f64;
    a.iter()
        .zip(&chirp)
        .map(|(c, &w)| Complex::new(c.re * scale, c.im * scale).mul(w))
        .collect()
}
/// Forward discrete Fourier transform, exported over the C ABI.
///
/// Reads `len` complex samples from split real/imaginary arrays. A null `imag_in`
/// means every imaginary part is `0.0`. The samples are run through `fft_bluestein`
/// in forward mode, which is intended to compute the unnormalised
/// `X_k = Σ_j x_j · exp(−2πi·jk/len)`. On success the `len`-bin result is written
/// into `*out` with `ptr::write`, so any previous contents of `*out` are not dropped.
///
/// Returns `SciResult::ok()` on success. Any of the following returns an error
/// result and leaves `*out` unwritten:
/// * a null `out`;
/// * a null `real_in` with `len > 0`;
/// * an oversized `len`;
/// * a failure building the result;
/// * a panic inside the transform, converted via `SciResult::from_panic`.
///
/// # Safety
///
/// * `out` must be non-null, properly aligned and valid for writing one
///   `SciComplexVector`.
/// * When `len > 0`, `real_in` must point to `len` initialised, properly aligned
///   `f64` values that stay valid and unmodified for the duration of the call.
/// * `imag_in` must be null, or satisfy the same requirement as `real_in`.
#[no_mangle]
pub unsafe extern "C" fn sci_fft_forward(
    real_in: *const f64,
    imag_in: *const f64,
    len: usize,
    out: *mut SciComplexVector,
) -> SciResult {
    if out.is_null() {
        return SciResult::err("sci_fft_forward: out pointer is null");
    }
    if real_in.is_null() && len > 0 {
        return SciResult::err("sci_fft_forward: real_in is null but len > 0");
    }
    // `slice::from_raw_parts` requires len * size_of::<f64>() <= isize::MAX.
    // Reject larger lengths instead of invoking undefined behaviour.
    if len > isize::MAX as usize / std::mem::size_of::<f64>() {
        return SciResult::err("sci_fft_forward: len is too large");
    }
    let result = catch_unwind(AssertUnwindSafe(|| {
        let real_slice: &[f64] = if len == 0 {
            &[]
        } else {
            // SAFETY: real_in is non-null (checked above), the length bound is
            // checked above, and the caller guarantees `len` readable f64s.
            unsafe { std::slice::from_raw_parts(real_in, len) }
        };
        // Borrow the imaginary parts instead of copying them; None means all-zero.
        let imag_slice: Option<&[f64]> = if imag_in.is_null() || len == 0 {
            None
        } else {
            // SAFETY: imag_in is non-null here and the caller guarantees `len`
            // readable f64s; the length bound is checked above.
            Some(unsafe { std::slice::from_raw_parts(imag_in, len) })
        };
        let input: Vec<Complex> = match imag_slice {
            Some(imag) => real_slice
                .iter()
                .zip(imag)
                .map(|(&r, &i)| Complex::new(r, i))
                .collect(),
            None => real_slice.iter().map(|&r| Complex::new(r, 0.0)).collect(),
        };
        let output = fft_bluestein(&input, false);
        let (real_out, imag_out): (Vec<f64>, Vec<f64>) =
            output.iter().map(|c| (c.re, c.im)).unzip();
        SciComplexVector::from_vecs(real_out, imag_out)
            .ok_or_else(|| "sci_fft_forward: internal error creating result".to_string())
    }));
    match result {
        Ok(Ok(cv)) => {
            // SAFETY: out is non-null (checked above) and the caller guarantees it
            // is aligned and writable; ptr::write avoids dropping uninitialised data.
            unsafe { ptr::write(out, cv) };
            SciResult::ok()
        }
        Ok(Err(msg)) => SciResult::err(&msg),
        Err(e) => SciResult::from_panic(e),
    }
}
/// Inverse discrete Fourier transform, exported over the C ABI.
///
/// Reads `len` complex samples from split real/imaginary arrays. A null `imag_in`
/// means every imaginary part is `0.0`. The samples are run through `fft_bluestein`
/// in inverse mode and every output bin is divided by `len`; this is intended as
/// the normalised inverse of `sci_fft_forward`. On success the `len`-bin result is
/// written into `*out` with `ptr::write`, so any previous contents of `*out` are
/// not dropped.
///
/// Returns `SciResult::ok()` on success. Any of the following returns an error
/// result and leaves `*out` unwritten:
/// * a null `out`;
/// * a null `real_in` with `len > 0`;
/// * an oversized `len`;
/// * a failure building the result;
/// * a panic inside the transform, converted via `SciResult::from_panic`.
///
/// # Safety
///
/// * `out` must be non-null, properly aligned and valid for writing one
///   `SciComplexVector`.
/// * When `len > 0`, `real_in` must point to `len` initialised, properly aligned
///   `f64` values that stay valid and unmodified for the duration of the call.
/// * `imag_in` must be null, or satisfy the same requirement as `real_in`.
#[no_mangle]
pub unsafe extern "C" fn sci_fft_inverse(
    real_in: *const f64,
    imag_in: *const f64,
    len: usize,
    out: *mut SciComplexVector,
) -> SciResult {
    if out.is_null() {
        return SciResult::err("sci_fft_inverse: out pointer is null");
    }
    if real_in.is_null() && len > 0 {
        return SciResult::err("sci_fft_inverse: real_in is null but len > 0");
    }
    // `slice::from_raw_parts` requires len * size_of::<f64>() <= isize::MAX.
    // Reject larger lengths instead of invoking undefined behaviour.
    if len > isize::MAX as usize / std::mem::size_of::<f64>() {
        return SciResult::err("sci_fft_inverse: len is too large");
    }
    let result = catch_unwind(AssertUnwindSafe(|| {
        let real_slice: &[f64] = if len == 0 {
            &[]
        } else {
            // SAFETY: real_in is non-null (checked above), the length bound is
            // checked above, and the caller guarantees `len` readable f64s.
            unsafe { std::slice::from_raw_parts(real_in, len) }
        };
        // Borrow the imaginary parts instead of copying them; None means all-zero.
        let imag_slice: Option<&[f64]> = if imag_in.is_null() || len == 0 {
            None
        } else {
            // SAFETY: imag_in is non-null here and the caller guarantees `len`
            // readable f64s; the length bound is checked above.
            Some(unsafe { std::slice::from_raw_parts(imag_in, len) })
        };
        let input: Vec<Complex> = match imag_slice {
            Some(imag) => real_slice
                .iter()
                .zip(imag)
                .map(|(&r, &i)| Complex::new(r, i))
                .collect(),
            None => real_slice.iter().map(|&r| Complex::new(r, 0.0)).collect(),
        };
        let output = fft_bluestein(&input, true);
        // 1/len normalisation, applied while splitting into real/imag vectors.
        // An empty output never reaches the division, so 0/0 cannot occur.
        let n = output.len() as f64;
        let (real_out, imag_out): (Vec<f64>, Vec<f64>) =
            output.iter().map(|c| (c.re / n, c.im / n)).unzip();
        SciComplexVector::from_vecs(real_out, imag_out)
            .ok_or_else(|| "sci_fft_inverse: internal error creating result".to_string())
    }));
    match result {
        Ok(Ok(cv)) => {
            // SAFETY: out is non-null (checked above) and the caller guarantees it
            // is aligned and writable; ptr::write avoids dropping uninitialised data.
            unsafe { ptr::write(out, cv) };
            SciResult::ok()
        }
        Ok(Err(msg)) => SciResult::err(&msg),
        Err(e) => SciResult::from_panic(e),
    }
}
/// Forward FFT of a purely real sequence, exported over the C ABI.
///
/// Convenience wrapper that calls `sci_fft_forward` with a null `imag_in`, so every
/// imaginary input component is treated as `0.0`. The full `len`-bin spectrum is
/// produced; there is no half-spectrum packing.
///
/// # Safety
///
/// `real_in`, `len` and `out` must satisfy the same contract as for
/// `sci_fft_forward`.
#[no_mangle]
pub unsafe extern "C" fn sci_rfft(
    real_in: *const f64,
    len: usize,
    out: *mut SciComplexVector,
) -> SciResult {
    // A null imaginary pointer tells `sci_fft_forward` that the input is real.
    let no_imag: *const f64 = ptr::null();
    // SAFETY: the caller upholds `sci_fft_forward`'s contract for the forwarded
    // arguments, and `sci_fft_forward` explicitly accepts a null `imag_in`.
    unsafe { sci_fft_forward(real_in, no_imag, len, out) }
}