use oxifft::{irfft, rfft, Complex};
use std::f64::consts::PI;
/// Forward complex FFT of a real-valued signal.
///
/// Every sample is promoted to a complex number with a zero imaginary part
/// and the whole buffer is handed to `oxifft::fft`, so the full complex
/// spectrum is returned (no half-spectrum folding as with `rfft`).
fn fft_real_full(signal: &[f64]) -> Vec<Complex<f64>> {
    let mut promoted = Vec::with_capacity(signal.len());
    promoted.extend(signal.iter().copied().map(|sample| Complex::new(sample, 0.0)));
    oxifft::fft(&promoted)
}
/// Inverse complex FFT that keeps only the real component of each sample.
///
/// Imaginary parts are discarded. For the spectrum of a real signal they
/// should be zero up to rounding. Whether `oxifft::ifft` applies the `1/N`
/// normalisation is not visible here (NOTE(review): confirm against the
/// oxifft docs). The caller in this file only locates and interpolates the
/// peak, and both operations are scale-invariant.
fn ifft_to_real_full(spectrum: &[Complex<f64>]) -> Vec<f64> {
    let mut real_parts = Vec::new();
    for sample in oxifft::ifft(spectrum) {
        real_parts.push(sample.re);
    }
    real_parts
}
/// Normalised cross-power spectrum of two complex spectra, bin by bin.
///
/// For each pair of bins computes `P = conj(fa[k]) · fb[k]` and emits
/// `P / |P|`, a unit phasor carrying only the phase difference. Bins whose
/// magnitude is below `f64::EPSILON` become `0 + 0i` rather than being
/// divided by (almost) zero. If the slices differ in length, the extra bins
/// of the longer one are ignored.
fn rfft_cross_power(fa: &[Complex<f64>], fb: &[Complex<f64>]) -> Vec<Complex<f64>> {
    let mut phasors = Vec::with_capacity(fa.len().min(fb.len()));
    for (lhs, rhs) in fa.iter().zip(fb) {
        let product = lhs.conj() * *rhs;
        let magnitude = product.norm();
        let phasor = if magnitude < f64::EPSILON {
            Complex::new(0.0, 0.0)
        } else {
            product / magnitude
        };
        phasors.push(phasor);
    }
    phasors
}
/// Discrete Fourier transform of a real signal by direct O(n²) summation.
///
/// Returns one `(re, im)` pair per frequency bin `k`, with
/// `X[k] = Σⱼ signal[j] · e^{−2πi·k·j/n}` (no normalisation). An empty
/// signal produces an empty spectrum.
#[must_use]
pub fn dft_1d_naive(signal: &[f64]) -> Vec<(f64, f64)> {
    let len = signal.len();
    let mut spectrum = Vec::with_capacity(len);
    for k in 0..len {
        let bin = signal
            .iter()
            .enumerate()
            .fold((0.0_f64, 0.0_f64), |(re, im), (j, &sample)| {
                let angle = -2.0 * PI * k as f64 * j as f64 / len as f64;
                (re + sample * angle.cos(), im + sample * angle.sin())
            });
        spectrum.push(bin);
    }
    spectrum
}
/// Inverse DFT by direct O(n²) summation, returning only the real part.
///
/// For each output index `j` computes
/// `x[j] = (1/n) · Re Σₖ S[k] · e^{+2πi·k·j/n}`, where `S[k] = spectrum[k]`
/// is a `(re, im)` pair. The imaginary part of the inverse is not computed.
/// The result is therefore only a faithful inverse for spectra with
/// Hermitian symmetry, i.e. the spectrum of a real signal such as the output
/// of [`dft_1d_naive`]. An empty spectrum yields an empty vector.
#[must_use]
pub fn idft_1d_naive(spectrum: &[(f64, f64)]) -> Vec<f64> {
    let n = spectrum.len();
    if n == 0 {
        return Vec::new();
    }
    (0..n)
        .map(|j| {
            let mut re = 0.0_f64;
            for (k, &(sk_re, sk_im)) in spectrum.iter().enumerate() {
                let angle = 2.0 * PI * k as f64 * j as f64 / n as f64;
                // Real part of S[k]·e^{iθ} = re·cosθ − im·sinθ; the imaginary
                // part is never needed, so it is not accumulated.
                re += sk_re * angle.cos() - sk_im * angle.sin();
            }
            re / n as f64
        })
        .collect()
}
#[must_use]
pub fn dft_1d(signal: &[f64]) -> Vec<(f64, f64)> {
dft_1d_naive(signal)
}
#[must_use]
pub fn idft_1d(spectrum: &[(f64, f64)]) -> Vec<f64> {
idft_1d_naive(spectrum)
}
/// Normalised cross-power spectrum of two spectra given as `(re, im)` pairs.
///
/// For every bin computes `P = conj(A)·B` and returns `P / |P|`, a
/// unit-magnitude phasor that keeps only the phase difference between the
/// inputs. Bins with `|P| < f64::EPSILON` are returned as `(0.0, 0.0)`
/// instead of dividing by (almost) zero. If the inputs differ in length, the
/// extra bins of the longer one are ignored.
#[must_use]
pub fn cross_power_spectrum(a: &[(f64, f64)], b: &[(f64, f64)]) -> Vec<(f64, f64)> {
    let mut normalised = Vec::with_capacity(a.len().min(b.len()));
    for (&(a_re, a_im), &(b_re, b_im)) in a.iter().zip(b) {
        // conj(a) · b = (a_re − i·a_im)(b_re + i·b_im)
        let re = a_re * b_re + a_im * b_im;
        let im = a_re * b_im - a_im * b_re;
        let magnitude = (re * re + im * im).sqrt();
        let phasor = if magnitude < f64::EPSILON {
            (0.0, 0.0)
        } else {
            (re / magnitude, im / magnitude)
        };
        normalised.push(phasor);
    }
    normalised
}
/// Returns the index of the largest value in `signal`.
///
/// NaN samples are skipped. If several samples share the maximum, the last
/// one is returned (the `Iterator::max_by` tie rule). Returns `0` for an
/// empty slice or when every sample is NaN.
#[must_use]
pub fn find_peak_index(signal: &[f64]) -> usize {
    signal
        .iter()
        .enumerate()
        // A NaN is unordered against everything. If it were kept, `max_by`
        // would let it (and then whatever sample follows it) replace the
        // running maximum, so `[5, NaN, 1]` would report index 2.
        .filter(|(_, v)| !v.is_nan())
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        .map_or(0, |(i, _)| i)
}
/// Refines an integer peak position with three-point parabolic interpolation.
///
/// Fits a parabola through `signal[peak_idx - 1]`, `signal[peak_idx]` and
/// `signal[peak_idx + 1]` and returns the abscissa of its vertex,
/// `peak_idx + δ` with `δ = (y₊₁ − y₋₁) / (2·(2y₀ − y₋₁ − y₊₁))`. When
/// `peak_idx` is a local maximum, `|δ| ≤ 0.5`, and the vertex moves towards
/// the larger neighbour.
///
/// Returns `peak_idx as f64` unchanged when either neighbour is missing
/// (fewer than 3 samples, `peak_idx == 0`, or `peak_idx >= len - 1`). It also
/// returns it unchanged when the three points are (nearly) collinear or
/// non-finite, because then no vertex is defined.
#[must_use]
pub fn interpolate_peak(signal: &[f64], peak_idx: usize) -> f64 {
    let n = signal.len();
    if n < 3 || peak_idx == 0 || peak_idx >= n - 1 {
        return peak_idx as f64;
    }
    let y_m1 = signal[peak_idx - 1];
    let y_0 = signal[peak_idx];
    let y_p1 = signal[peak_idx + 1];
    // Twice the (negated) curvature of the fitted parabola; zero means collinear points.
    let denom = 2.0 * (2.0 * y_0 - y_m1 - y_p1);
    if !denom.is_finite() || denom.abs() < f64::EPSILON {
        return peak_idx as f64;
    }
    // Vertex of y = a·x² + b·x + c through x = −1, 0, 1 is −b/(2a), which
    // equals (y₊₁ − y₋₁) / (2·(2y₀ − y₋₁ − y₊₁)). A larger right neighbour
    // pushes the peak right.
    let delta = (y_p1 - y_m1) / denom;
    peak_idx as f64 + delta
}
/// Estimates the circular shift of `b` relative to `a` by phase correlation.
///
/// Steps:
/// 1. Transform both signals with `rfft`.
/// 2. Normalise their cross-power spectrum with [`rfft_cross_power`].
/// 3. Transform back with `irfft` to get a circular correlation whose peak
///    sits at the best-aligning lag.
/// 4. Refine the peak with parabolic interpolation over its circular
///    neighbours.
/// 5. Wrap the result into `(-n/2, n/2]`.
///
/// A positive result means `b` is `a` delayed, i.e. `b[i] ≈ a[i - shift]`.
/// Returns `0.0` when the inputs are empty or differ in length.
#[must_use]
pub fn phase_correlate_1d(a: &[f64], b: &[f64]) -> f64 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    let n = a.len();
    let fa: Vec<Complex<f64>> = rfft(a);
    let fb: Vec<Complex<f64>> = rfft(b);
    let cps = rfft_cross_power(&fa, &fb);
    let corr: Vec<f64> = irfft(&cps, n);
    let len = corr.len();
    if len == 0 {
        return 0.0;
    }
    let peak_idx = find_peak_index(&corr);
    // The correlation is circular: lag 0's left neighbour is lag −1, stored at
    // the end of the buffer (and vice versa). Gathering wrapped neighbours lets
    // edge peaks be refined too. Those are exactly where sub-sample shifts with
    // |shift| < 1 land, and `interpolate_peak` alone would return them unrefined.
    let neighbourhood = [
        corr[(peak_idx + len - 1) % len],
        corr[peak_idx],
        corr[(peak_idx + 1) % len],
    ];
    let sub_peak = peak_idx as f64 + (interpolate_peak(&neighbourhood, 1) - 1.0);
    let nf = len as f64;
    // Lags past the midpoint are negative shifts in circular correlation.
    if sub_peak > nf / 2.0 {
        sub_peak - nf
    } else {
        sub_peak
    }
}
#[must_use]
pub fn phase_correlate_1d_full_complex(a: &[f64], b: &[f64]) -> f64 {
if a.is_empty() || a.len() != b.len() {
return 0.0;
}
let fa: Vec<Complex<f64>> = fft_real_full(a);
let fb: Vec<Complex<f64>> = fft_real_full(b);
let cps: Vec<Complex<f64>> = fa
.iter()
.zip(fb.iter())
.map(|(ca, cb)| {
let prod = ca.conj() * *cb;
let mag = prod.norm();
if mag < f64::EPSILON {
Complex::new(0.0, 0.0)
} else {
prod / mag
}
})
.collect();
let corr: Vec<f64> = ifft_to_real_full(&cps);
let peak_idx = find_peak_index(&corr);
let sub_peak = interpolate_peak(&corr, peak_idx);
let n = a.len() as f64;
if sub_peak > n / 2.0 {
sub_peak - n
} else {
sub_peak
}
}
/// Phase-correlation shift estimate computed entirely with the O(n²) DFT.
///
/// Follows the same pipeline and sign convention as `phase_correlate_1d`: a
/// positive result means `b[i] ≈ a[i - shift]`, wrapped into `(-n/2, n/2]`.
/// It is built from [`dft_1d_naive`], [`cross_power_spectrum`] and
/// [`idft_1d_naive`]. The tests compare it against the FFT-based version.
/// Returns `0.0` when the inputs are empty or differ in length.
#[must_use]
pub fn phase_correlate_1d_naive(a: &[f64], b: &[f64]) -> f64 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    let fa = dft_1d_naive(a);
    let fb = dft_1d_naive(b);
    let cps = cross_power_spectrum(&fa, &fb);
    let corr = idft_1d_naive(&cps);
    let len = corr.len();
    if len == 0 {
        return 0.0;
    }
    let peak_idx = find_peak_index(&corr);
    // Circular correlation: take wrapped neighbours so peaks at lag 0 or
    // lag len−1 (small sub-sample shifts) are refined instead of being
    // returned as whole samples by `interpolate_peak`'s boundary guard.
    let neighbourhood = [
        corr[(peak_idx + len - 1) % len],
        corr[peak_idx],
        corr[(peak_idx + 1) % len],
    ];
    let sub_peak = peak_idx as f64 + (interpolate_peak(&neighbourhood, 1) - 1.0);
    let n = len as f64;
    // Lags past the midpoint are negative shifts in circular correlation.
    if sub_peak > n / 2.0 {
        sub_peak - n
    } else {
        sub_peak
    }
}
/// Estimates the `(dx, dy)` translation between two images via 1-D projections.
///
/// Both images are row-major with pixel `(x, y)` at index `y * width + x`.
/// Each image is reduced to two profiles:
/// - column sums, one value per `x`;
/// - row sums, one value per `y`.
///
/// Matching profiles are then phase-correlated with [`phase_correlate_1d`].
/// This is a separable approximation: it recovers pure translations but
/// ignores 2-D structure.
///
/// Returns `(0.0, 0.0)` in these cases:
/// - `width * height` overflows;
/// - the image is empty (zero width or height);
/// - either slice's length is not `width * height`.
#[must_use]
pub fn phase_correlate_2d(a: &[f64], b: &[f64], width: usize, height: usize) -> (f64, f64) {
    // `checked_mul` avoids an overflow panic (debug) or wrap-around (release) on
    // absurd dimensions. Zero-area images are rejected because their all-zero
    // profiles would otherwise yield a spurious nonzero shift.
    let valid = matches!(
        width.checked_mul(height),
        Some(count) if count > 0 && a.len() == count && b.len() == count
    );
    if !valid {
        return (0.0, 0.0);
    }

    // One row-major pass per image: accumulate column sums (x profile) and
    // push each row's total (y profile).
    let project = |img: &[f64]| -> (Vec<f64>, Vec<f64>) {
        let mut x_profile = vec![0.0_f64; width];
        let mut y_profile = Vec::with_capacity(height);
        for row in img.chunks_exact(width) {
            for (acc, &value) in x_profile.iter_mut().zip(row) {
                *acc += value;
            }
            y_profile.push(row.iter().sum::<f64>());
        }
        (x_profile, y_profile)
    };

    let (a_x_profile, a_y_profile) = project(a);
    let (b_x_profile, b_y_profile) = project(b);
    let dx = phase_correlate_1d(&a_x_profile, &b_x_profile);
    let dy = phase_correlate_1d(&a_y_profile, &b_y_profile);
    (dx, dy)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Gaussian-windowed cosine used as a smooth, well-localised test signal.
    ///
    /// `sharpness` multiplies the Gaussian exponent: larger values give a
    /// narrower envelope centred on the middle sample. `cycles` is the number
    /// of cosine periods across the `n` samples.
    fn windowed_cosine(n: usize, sharpness: f64, cycles: f64) -> Vec<f64> {
        (0..n)
            .map(|i| {
                let t = i as f64 / n as f64;
                let env = (-sharpness * (t - 0.5).powi(2)).exp();
                env * (2.0 * PI * cycles * t).cos()
            })
            .collect()
    }

    /// Circular right shift by `shift` samples (requires `shift <= signal.len()`):
    /// `out[i] = signal[(i - shift) mod n]`.
    fn circular_shift(signal: &[f64], shift: usize) -> Vec<f64> {
        let n = signal.len();
        (0..n).map(|i| signal[(i + n - shift) % n]).collect()
    }

    /// Advances a 64-bit linear congruential generator (Knuth's MMIX
    /// constants) and maps the new state to a float in `[-1, 1)`.
    fn lcg_uniform(state: &mut u64) -> f64 {
        *state = state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        (*state >> 11) as f64 / (1u64 << 52) as f64 - 1.0
    }

    #[test]
    fn test_dft_empty() {
        assert!(dft_1d(&[]).is_empty());
    }

    #[test]
    fn test_idft_empty() {
        assert!(idft_1d(&[]).is_empty());
    }

    #[test]
    fn test_dft_idft_round_trip() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];
        let spectrum = dft_1d(&signal);
        let recovered = idft_1d(&spectrum);
        assert_eq!(recovered.len(), signal.len());
        for (a, b) in signal.iter().zip(recovered.iter()) {
            assert!((a - b).abs() < 1e-9, "{a} ≠ {b}");
        }
    }

    #[test]
    fn test_dft_dc_component() {
        let signal = vec![2.0_f64; 4];
        let spectrum = dft_1d(&signal);
        assert!((spectrum[0].0 - 8.0).abs() < 1e-9);
        assert!(spectrum[0].1.abs() < 1e-9);
        assert!(spectrum[1].0.abs() < 1e-9);
        assert!(spectrum[2].0.abs() < 1e-9);
    }

    #[test]
    fn test_cross_power_spectrum_identical() {
        let s = vec![(1.0_f64, 0.0_f64), (0.5, 0.5)];
        let cps = cross_power_spectrum(&s, &s);
        for (re, im) in &cps {
            let mag = (re * re + im * im).sqrt();
            assert!((mag - 1.0).abs() < 1e-9 || mag.abs() < 1e-9);
        }
    }

    #[test]
    fn test_cross_power_spectrum_zero_element() {
        let a = vec![(0.0_f64, 0.0_f64)];
        let b = vec![(1.0_f64, 0.0_f64)];
        let cps = cross_power_spectrum(&a, &b);
        assert_eq!(cps[0], (0.0, 0.0));
    }

    #[test]
    fn test_find_peak_index_basic() {
        let s = vec![0.1, 0.5, 0.9, 0.3];
        assert_eq!(find_peak_index(&s), 2);
    }

    #[test]
    fn test_find_peak_index_empty() {
        assert_eq!(find_peak_index(&[]), 0);
    }

    #[test]
    fn test_interpolate_peak_boundary_returns_index() {
        let s = vec![0.1, 0.5, 0.9, 0.3, 0.1];
        let result = interpolate_peak(&s, 4);
        assert_eq!(result, 4.0);
    }

    #[test]
    fn test_interpolate_peak_symmetric_returns_exact() {
        let s = vec![0.25_f64, 0.75, 1.0, 0.75, 0.25];
        let peak = find_peak_index(&s);
        let refined = interpolate_peak(&s, peak);
        assert!((refined - 2.0).abs() < 1e-9);
    }

    #[test]
    fn test_phase_correlate_1d_identical_signals() {
        let signal = vec![1.0, 2.0, 3.0, 4.0, 3.0, 2.0, 1.0, 0.5];
        let offset = phase_correlate_1d(&signal, &signal);
        assert!(offset.abs() < 0.5, "Offset for identical signals: {offset}");
    }

    #[test]
    fn test_phase_correlate_1d_empty() {
        assert_eq!(phase_correlate_1d(&[], &[]), 0.0);
    }

    #[test]
    fn test_phase_correlate_1d_mismatched_lengths() {
        let a = vec![1.0, 2.0];
        let b = vec![1.0, 2.0, 3.0];
        assert_eq!(phase_correlate_1d(&a, &b), 0.0);
    }

    #[test]
    fn test_phase_correlate_2d_identical_images() {
        let img = vec![
            10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0, 100.0, 90.0,
            80.0, 70.0,
        ];
        let (dx, dy) = phase_correlate_2d(&img, &img, 4, 4);
        assert!(
            dx.abs() < 0.5 && dy.abs() < 0.5,
            "dx={dx}, dy={dy} for identical images"
        );
    }

    #[test]
    fn test_phase_correlate_2d_mismatched_size() {
        let a = vec![1.0; 16];
        let b = vec![1.0; 9];
        let (dx, dy) = phase_correlate_2d(&a, &b, 4, 4);
        assert_eq!((dx, dy), (0.0, 0.0));
    }

    #[test]
    fn test_oxifft_phase_correlate_matches_naive() {
        let n = 128usize;
        let known_shift = 3usize;
        let a = windowed_cosine(n, 50.0, 8.0);
        let b = circular_shift(&a, known_shift);
        let rfft_shift = phase_correlate_1d(&a, &b);
        let naive_shift = phase_correlate_1d_naive(&a, &b);
        assert!(
            (rfft_shift - known_shift as f64).abs() < 0.5,
            "rFFT shift={rfft_shift:.4}, expected ≈{known_shift}"
        );
        assert!(
            (rfft_shift - naive_shift).abs() < 0.5,
            "rFFT shift={rfft_shift:.4} diverges from naïve={naive_shift:.4}"
        );
    }

    #[test]
    fn test_rfft_matches_full_complex_integer_shift() {
        let n = 128usize;
        let known_shift = 3usize;
        let a = windowed_cosine(n, 50.0, 8.0);
        let b = circular_shift(&a, known_shift);
        let rfft_shift = phase_correlate_1d(&a, &b);
        let full_shift = phase_correlate_1d_full_complex(&a, &b);
        assert!(
            (rfft_shift - known_shift as f64).abs() < 0.5,
            "rFFT shift={rfft_shift:.4}, expected ≈{known_shift}"
        );
        assert!(
            (full_shift - known_shift as f64).abs() < 0.5,
            "full-complex shift={full_shift:.4}, expected ≈{known_shift}"
        );
        assert!(
            (rfft_shift - full_shift).abs() < 0.5,
            "rFFT ({rfft_shift:.4}) diverges from full-complex ({full_shift:.4})"
        );
    }

    /// Integer shift of a narrow, high-frequency pulse; also checks the rfft
    /// half-spectrum length.
    #[test]
    fn test_rfft_narrow_pulse_integer_shift() {
        let n = 256usize;
        let dummy: Vec<f64> = vec![1.0; n];
        let n_half = rfft(&dummy).len();
        assert_eq!(n_half, n / 2 + 1, "rfft half-spectrum length mismatch");
        let a = windowed_cosine(n, 200.0, 16.0);
        let shift_int = 2usize;
        let b_int = circular_shift(&a, shift_int);
        let rfft_result = phase_correlate_1d(&a, &b_int);
        let full_result = phase_correlate_1d_full_complex(&a, &b_int);
        assert!(
            (rfft_result - full_result).abs() < 0.5,
            "rFFT ({rfft_result:.4}) disagrees with full-complex ({full_result:.4})"
        );
        assert!(
            (rfft_result - shift_int as f64).abs() < 0.5,
            "rFFT shift={rfft_result:.4}, expected ≈{shift_int}"
        );
    }

    #[test]
    fn test_rfft_zero_shift() {
        let n = 128usize;
        let a = windowed_cosine(n, 50.0, 8.0);
        let shift = phase_correlate_1d(&a, &a);
        assert!(
            shift.abs() < 0.5,
            "Zero-shift test: expected shift ≈ 0.0, got {shift:.4}"
        );
    }

    #[test]
    fn test_rfft_noise_robustness() {
        let n = 256usize;
        let known_shift = 5usize;
        let mut state: u64 = 0xFEED_FACE_CAFE_BABE_u64;
        let signal_a: Vec<f64> = (0..n).map(|_| lcg_uniform(&mut state)).collect();
        let mut signal_b = circular_shift(&signal_a, known_shift);
        let noise_amp = 0.05_f64;
        let mut noise_state: u64 = 0xDEAD_BEEF_CAFE_1234_u64;
        for sample in &mut signal_b {
            *sample += noise_amp * lcg_uniform(&mut noise_state);
        }
        let shift = phase_correlate_1d(&signal_a, &signal_b);
        assert!(
            (shift - known_shift as f64).abs() < 1.0,
            "Noise robustness: expected shift ≈ {known_shift}, got {shift:.4}"
        );
    }
}