frinZ 4.4.0

Fringe search for the Yamaguchi Interferometer and the Japanese VLBI Network
Documentation
use ndarray::prelude::*;
use num_complex::Complex;
use rustfft::FftPlanner;
use std::f64::consts::PI;

/// Single-precision complex sample; every spectrum and FFT buffer in this module uses it.
type C32 = Complex<f32>;

/// Phase model removed from the per-row spectra before the rate FFT.
///
/// For row `j` and channel `k`, the applied unit phasor has phase
/// `-(2π·rate_hz·t + π·acel_hz·t²) - 2π·delay_samples·k / fft_point`,
/// with `t = j·effective_integration_length + start_time_offset_sec`.
#[derive(Copy, Clone)]
struct PhaseCorrection {
    /// Linear (fringe-rate) term, in Hz.
    rate_hz: f32,
    /// Delay term in sampling intervals; divided by the sampling speed to get seconds.
    delay_samples: f32,
    /// Coefficient of the quadratic `t²` term (presumably Hz/s — confirm with callers).
    acel_hz: f32,
    /// Time step between consecutive rows, in seconds.
    effective_integration_length: f32,
    /// Time assigned to row 0, in seconds.
    start_time_offset_sec: f32,
}

impl PhaseCorrection {
    /// True when at least one of the rate, delay, or acceleration terms is non-zero.
    fn is_enabled(self) -> bool {
        [self.rate_hz, self.delay_samples, self.acel_hz]
            .iter()
            .any(|&term| term != 0.0)
    }

    /// True when the correction is enabled and can be evaluated for the given
    /// sampling speed and FFT size: the sampling speed must be non-zero, the FFT
    /// must have at least two points, and the row spacing must not be
    /// (numerically) zero.
    fn is_valid_for(self, sampling_speed: u32, fft_point: u32) -> bool {
        if !self.is_enabled() || sampling_speed == 0 || fft_point < 2 {
            return false;
        }
        f64::from(self.effective_integration_length).abs() > 1e-9
    }
}

pub fn process_fft(
    complex_vec: &[C32],
    physical_length: i32,
    fft_point: i32,
    sampling_speed: i32,
    rfi_ranges: &[(usize, usize)],
    rate_padding: u32,
) -> (Array2<C32>, usize) {
    process_fft_impl(
        complex_vec,
        physical_length,
        fft_point,
        sampling_speed,
        rfi_ranges,
        rate_padding,
        None,
    )
}

/// Like `process_fft`, but multiplies each input sample by a phase-correction
/// phasor before the rate FFT.
///
/// For row `j` and channel `k`, the phasor has phase
/// `-(2π·rate·t + π·acel·t²) - 2π·delay_samples·k / fft_point`, where
/// `t = j · effective_integration_length + start_time_offset_sec`. The result
/// matches correcting the input with `apply_phase_correction_in_place` and
/// then calling `process_fft`.
///
/// No correction is applied (the output equals `process_fft`'s) when all three
/// correction terms are zero, `sampling_speed` is zero, `fft_point` is below 2,
/// or `|effective_integration_length| <= 1e-9`.
pub fn process_fft_with_phase_correction(
    complex_vec: &[C32],
    physical_length: i32,
    fft_point: i32,
    sampling_speed: i32,
    rfi_ranges: &[(usize, usize)],
    rate_padding: u32,
    rate_hz_for_correction: f32,
    delay_samples_for_correction: f32,
    acel_hz_for_correction: f32,
    effective_integration_length: f32,
    start_time_offset_sec: f32,
) -> (Array2<C32>, usize) {
    let correction = Some(PhaseCorrection {
        rate_hz: rate_hz_for_correction,
        delay_samples: delay_samples_for_correction,
        acel_hz: acel_hz_for_correction,
        effective_integration_length,
        start_time_offset_sec,
    });
    process_fft_impl(
        complex_vec,
        physical_length,
        fft_point,
        sampling_speed,
        rfi_ranges,
        rate_padding,
        correction,
    )
}

/// Shared worker behind `process_fft` and `process_fft_with_phase_correction`.
///
/// `complex_vec` is read as consecutive rows of `fft_point / 2` channels:
/// element `j * (fft_point / 2) + i` is channel `i` of row `j`. Trailing
/// elements that do not fill a whole row are ignored. Channel 0 is always
/// skipped, as are channels covered by `rfi_ranges` (inclusive `(min, max)`
/// channel indices). For every other channel, the per-row series is processed
/// in these steps:
///
/// 1. Optionally phase-corrected.
/// 2. Zero-padded to the rate-axis length.
/// 3. Forward-FFT'd.
/// 4. Reordered so the zero-rate bin sits at index `padding_length / 2` with
///    the rate axis increasing monotonically.
/// 5. Scaled by `(fft_point / physical_length) * (512 / bandwidth_MHz)`,
///    where the bandwidth is `sampling_speed / 2`. A non-positive
///    `physical_length` counts as 1; a non-positive bandwidth leaves that
///    factor at 1.
///
/// Skipped channels stay zero.
///
/// The rate-axis length is `max(rows, 1) * max(rate_padding, 1)`, doubled
/// when there is at most one row.
///
/// Returns the `(fft_point / 2, padding_length)` array and `padding_length`.
fn process_fft_impl(
    complex_vec: &[C32],
    physical_length: i32,
    fft_point: i32,
    sampling_speed: i32,
    rfi_ranges: &[(usize, usize)],
    rate_padding: u32,
    phase_correction: Option<PhaseCorrection>,
) -> (Array2<C32>, usize) {
    // A negative `fft_point` would otherwise wrap into an enormous channel
    // count (and allocation); treat it like any other too-small FFT size.
    let fft_point_half = usize::try_from(fft_point / 2).unwrap_or(0);
    let rows = if fft_point_half == 0 {
        0
    } else {
        complex_vec.len() / fft_point_half
    };
    let base_length = rows.max(1);
    let mut padding_length = base_length.saturating_mul(rate_padding.max(1) as usize);
    if base_length == 1 {
        // Guarantee at least two rate bins when there is at most one row.
        padding_length = padding_length.saturating_mul(2);
    }

    let length_f32 = if physical_length > 0 {
        physical_length as f32
    } else {
        1.0
    };
    // `length_f32 >= 1` here, so the division needs no zero guard.
    let fft_scale = fft_point as f32 / length_f32;
    let bandwidth_mhz = sampling_speed as f32 / 2.0 / 1_000_000.0;
    let power_scale = if bandwidth_mhz > 0.0 {
        512.0 / bandwidth_mhz
    } else {
        1.0
    };
    let scale_factor = fft_scale * power_scale;

    let mut planner = FftPlanner::new();
    let fft = planner.plan_fft_forward(padding_length);
    // One scratch buffer shared by every channel; `Fft::process` would
    // allocate a fresh one on each call.
    let mut scratch = vec![C32::new(0.0, 0.0); fft.get_inplace_scratch_len()];

    let mut freq_rate_array = Array2::<C32>::zeros((fft_point_half, padding_length));
    let mut fft_exe = vec![C32::new(0.0, 0.0); padding_length];
    let rfi_mask = build_channel_rfi_mask(rfi_ranges, fft_point_half);

    let phase_factors = phase_correction.and_then(|phase| {
        build_phase_factors(
            phase,
            fft_point_half,
            rows,
            sampling_speed as u32,
            fft_point as u32,
        )
    });

    // fftshift split point. Bins [0, ceil(N/2)) hold DC and the positive
    // rates; bins [ceil(N/2), N) hold the negative ones. Splitting at
    // floor(N/2) instead would, for odd N, place the highest positive bin in
    // front of the negative bins and break the monotonic rate axis.
    let non_negative_len = padding_length - padding_length / 2;

    for i in (1..fft_point_half).filter(|&i| !rfi_mask[i]) {
        for (j, slot) in fft_exe[..rows].iter_mut().enumerate() {
            let mut sample = complex_vec[j * fft_point_half + i];
            if let Some((delay_factors, row_factors)) = &phase_factors {
                sample *= row_factors[j] * delay_factors[i];
            }
            *slot = sample;
        }
        fft_exe[rows..].fill(C32::new(0.0, 0.0));

        fft.process_with_scratch(&mut fft_exe, &mut scratch);

        let (non_negative, negative) = fft_exe.split_at(non_negative_len);
        let mut row = freq_rate_array.row_mut(i);
        for (dst, src) in row.iter_mut().zip(negative.iter().chain(non_negative)) {
            *dst = *src * scale_factor;
        }
    }

    (freq_rate_array, padding_length)
}

/// Marks channels in `0..channels` covered by any inclusive `(min, max)` range;
/// ranges are clipped to the channel count and empty/out-of-range ones ignored.
fn build_channel_rfi_mask(rfi_ranges: &[(usize, usize)], channels: usize) -> Vec<bool> {
    let mut mask = vec![false; channels];
    for &(min, max) in rfi_ranges {
        if min >= channels {
            continue;
        }
        // `min < channels` guarantees `channels >= 1`, so this cannot underflow.
        let end = max.min(channels - 1);
        if end < min {
            continue;
        }
        mask[min..=end].fill(true);
    }
    mask
}

/// Precomputes the separable phase-correction phasors.
///
/// Returns `None` when `phase` is disabled or not valid for the given sampling
/// speed / FFT size. Otherwise it returns two vectors:
///
/// - Per-channel delay phasors, `fft_point_half` long. Channel `k` has phase
///   `-2π · (delay_samples / sampling_speed) · k · (sampling_speed / fft_point)`.
/// - Per-row rate/acceleration phasors, `rows` long. Row `j` has phase
///   `-2π·rate_hz·t - π·acel_hz·t²`, with
///   `t = j·effective_integration_length + start_time_offset_sec`.
///
/// The correction for row `j`, channel `k` is the product of the two.
fn build_phase_factors(
    phase: PhaseCorrection,
    fft_point_half: usize,
    rows: usize,
    sampling_speed: u32,
    fft_point: u32,
) -> Option<(Vec<C32>, Vec<C32>)> {
    if !phase.is_valid_for(sampling_speed, fft_point) {
        return None;
    }

    let unit_phasor = |angle: f64| {
        let (sin, cos) = angle.sin_cos();
        C32::new(cos as f32, sin as f32)
    };

    // Coefficients are hoisted as left-to-right prefixes of the original
    // products, so every angle is computed with the same operation order.
    let freq_resolution_hz = sampling_speed as f64 / fft_point as f64;
    let delay_seconds = phase.delay_samples as f64 / sampling_speed as f64;
    let delay_coeff = -2.0 * PI * delay_seconds;
    let delay_factors: Vec<C32> = (0..fft_point_half)
        .map(|channel| unit_phasor(delay_coeff * channel as f64 * freq_resolution_hz))
        .collect();

    let rate_coeff = -2.0 * PI * phase.rate_hz as f64;
    let acel_coeff = PI * phase.acel_hz as f64;
    let row_step_sec = phase.effective_integration_length as f64;
    let offset_sec = phase.start_time_offset_sec as f64;
    let row_factors: Vec<C32> = (0..rows)
        .map(|row| {
            let t = row as f64 * row_step_sec + offset_sec;
            unit_phasor(rate_coeff * t - acel_coeff * t.powi(2))
        })
        .collect();

    Some((delay_factors, row_factors))
}

/// Inverse-FFTs each rate column of a frequency-rate array into a delay-rate array.
///
/// Column `i` of `freq_rate_array` is processed as follows:
///
/// - Its first `min(rows, fft_point)` entries are used, zero-padded to
///   `fft_point`.
/// - That vector is inverse-transformed and divided by `fft_point`.
/// - The result is written to row `i` of the output, with each half of the
///   transform output in reverse order. Output index `k < fft_point / 2`
///   receives sample `fft_point / 2 - 1 - k`; index `fft_point / 2 + k`
///   receives sample `fft_point - 1 - k`.
///
/// Returns a `(padding_length, fft_point)` array; rows beyond the input's
/// column count stay zero.
///
/// # Panics
/// If `freq_rate_array` has more columns than `padding_length`.
pub fn process_ifft(
    freq_rate_array: &Array2<C32>,
    fft_point: i32,
    padding_length: usize,
) -> Array2<C32> {
    let fft_point_usize = fft_point as usize;
    let mut delay_rate_array = Array2::<C32>::zeros((padding_length, fft_point_usize));
    let mut planner = FftPlanner::new();
    let ifft = planner.plan_fft_inverse(fft_point_usize);
    // Shared by every column; `Fft::process` would allocate one per call.
    let mut scratch = vec![C32::new(0.0, 0.0); ifft.get_inplace_scratch_len()];
    let mut ifft_exe = vec![C32::new(0.0, 0.0); fft_point_usize];
    let (freq_rows, rate_columns) = freq_rate_array.dim();
    let freq_bins = freq_rows.min(fft_point_usize);
    let half = fft_point_usize / 2;
    let scale = fft_point_usize as f32;

    for i in 0..rate_columns {
        let column = freq_rate_array.column(i);
        for (dst, src) in ifft_exe[..freq_bins].iter_mut().zip(column.iter()) {
            *dst = *src;
        }
        ifft_exe[freq_bins..].fill(C32::new(0.0, 0.0));

        ifft.process_with_scratch(&mut ifft_exe, &mut scratch);

        let (first_half, second_half) = ifft_exe.split_at(half);
        let reordered = first_half.iter().rev().chain(second_half.iter().rev());
        let mut row = delay_rate_array.row_mut(i);
        for (dst, src) in row.iter_mut().zip(reordered) {
            *dst = *src / scale;
        }
    }

    delay_rate_array
}

/// Inverse-FFTs a single spectrum and returns it normalised and reordered.
///
/// `input` is zero-padded to `ifft_size`. Entries beyond `ifft_size` are
/// dropped; previously they caused a slice-length panic. This matches how
/// `process_ifft` clips its frequency axis to `fft_point`.
///
/// The transform output is divided by `ifft_size`. Each half is then emitted
/// in reverse order: result index `k < ifft_size / 2` holds sample
/// `ifft_size / 2 - 1 - k`, and index `ifft_size / 2 + k` holds sample
/// `ifft_size - 1 - k`.
pub fn perform_ifft_on_vec(input: &[C32], ifft_size: usize) -> Vec<C32> {
    let mut planner = FftPlanner::new();
    let ifft = planner.plan_fft_inverse(ifft_size);

    let mut ifft_exe = vec![C32::new(0.0, 0.0); ifft_size];
    let used = input.len().min(ifft_size);
    ifft_exe[..used].copy_from_slice(&input[..used]);

    ifft.process(&mut ifft_exe);

    let scale = ifft_size as f32;
    let (first_half, second_half) = ifft_exe.split_at(ifft_size / 2);
    first_half
        .iter()
        .rev()
        .chain(second_half.iter().rev())
        .map(|&value| value / scale)
        .collect()
}

/// Multiplies `data` in place by the same phase-correction phasors that
/// `process_fft_with_phase_correction` applies.
///
/// `data` is treated as consecutive rows of `fft_point_half` channels. The
/// call is a no-op in these cases:
///
/// - `data` is empty or `fft_point_half` is zero.
/// - `data` does not split into whole rows.
/// - The correction is disabled or invalid: all terms are zero, the sampling
///   speed is zero, `fft_point < 2`, or the integration length is
///   (numerically) zero.
pub fn apply_phase_correction_in_place(
    data: &mut [C32],
    fft_point_half: usize,
    rate_hz_for_correction: f32,
    delay_samples_for_correction: f32,
    acel_hz_for_correction: f32,
    effective_integration_length: f32,
    sampling_speed: u32,
    fft_point: u32,
    start_time_offset_sec: f32,
) {
    // `fft_point_half != 0` is checked first so the modulo never divides by zero.
    let whole_rows =
        fft_point_half != 0 && !data.is_empty() && data.len() % fft_point_half == 0;
    if !whole_rows {
        return;
    }

    let correction = PhaseCorrection {
        rate_hz: rate_hz_for_correction,
        delay_samples: delay_samples_for_correction,
        acel_hz: acel_hz_for_correction,
        effective_integration_length,
        start_time_offset_sec,
    };
    let row_count = data.len() / fft_point_half;

    if let Some((delay_factors, row_factors)) =
        build_phase_factors(correction, fft_point_half, row_count, sampling_speed, fft_point)
    {
        for (spectrum, row_factor) in data.chunks_mut(fft_point_half).zip(&row_factors) {
            for (value, delay_factor) in spectrum.iter_mut().zip(&delay_factors) {
                *value *= *row_factor * *delay_factor;
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    const TOLERANCE: f32 = 1.0e-4;

    // Applying the correction inside the FFT must give the same array as
    // correcting the input first and then running the uncorrected FFT.
    #[test]
    fn fused_phase_correction_fft_matches_pre_corrected_fft() {
        const FFT_POINT: i32 = 8;
        const SAMPLING_SPEED_HZ: u32 = 32_000_000;
        const ROWS: usize = 5;
        const RATE_HZ: f32 = 0.03;
        const DELAY_SAMPLES: f32 = 0.2;
        const ACEL_HZ: f32 = 0.001;
        const INTEGRATION_SEC: f32 = 0.5;
        const START_OFFSET_SEC: f32 = 0.25;

        let channels = (FFT_POINT / 2) as usize;
        let input: Vec<C32> = (0..ROWS * channels)
            .map(|n| {
                let x = n as f32;
                C32::new(x * 0.25 + 1.0, x * -0.125)
            })
            .collect();

        let mut corrected = input.clone();
        apply_phase_correction_in_place(
            &mut corrected,
            channels,
            RATE_HZ,
            DELAY_SAMPLES,
            ACEL_HZ,
            INTEGRATION_SEC,
            SAMPLING_SPEED_HZ,
            FFT_POINT as u32,
            START_OFFSET_SEC,
        );

        let sampling_speed = SAMPLING_SPEED_HZ as i32;
        let (expected, expected_padding) =
            process_fft(&corrected, ROWS as i32, FFT_POINT, sampling_speed, &[], 1);
        let (actual, actual_padding) = process_fft_with_phase_correction(
            &input,
            ROWS as i32,
            FFT_POINT,
            sampling_speed,
            &[],
            1,
            RATE_HZ,
            DELAY_SAMPLES,
            ACEL_HZ,
            INTEGRATION_SEC,
            START_OFFSET_SEC,
        );

        assert_eq!(expected_padding, actual_padding);
        assert_eq!(expected.dim(), actual.dim());
        for (want, got) in expected.iter().zip(actual.iter()) {
            assert!((want.re - got.re).abs() < TOLERANCE);
            assert!((want.im - got.im).abs() < TOLERANCE);
        }
    }
}