//! mrrs 0.1.1
//!
//! Multi rate filtering tools
//! Implement a 3-band channelizer with 50% overlap
//!
//! Goal is to build a channelizer that takes an input, and returns data that
//! has been tuned by -fs/4, 0, and fs/4, and then ran through a half bandwidth
//! filter.
//!
//! Multiple of those channelizers can be cascaded together, though some care
//! needs to be taken to remove duplication. Since the highest channel of one
//! band is the same as the lowest channel of another we only compute the
//! highest channel output on the last leg.
use super::{convolve, shift_down_fs_over_4, shift_up_fs_over_4};
use ndarray::{Array2, ArrayView2, ArrayViewMut2};
use num_complex::Complex;
use num_traits::{Float, NumAssignOps};

/// Perform a single-stage 3-band channelization with 50% overlap
///
/// Uses the naive tuning / convolving implementation to make a clear easily verifiable test case.
/// Perform a single-stage 3-band channelization with 50% overlap
///
/// Each input column yields a low channel (tuned via `shift_down_fs_over_4`)
/// and a mid channel (untuned); the final column additionally yields the high
/// channel (tuned via `shift_up_fs_over_4`). All channels are convolved with
/// `flt`. Naive tuning / convolving keeps this implementation easy to verify.
///
/// `out` must be shaped `(n_inp_rows - flt.len() + 1, n_inp_cols * 2 + 1)`.
pub fn triband_cascade_noalloc<T: Float + NumAssignOps>(
    inp: ArrayView2<Complex<T>>,
    mut out: ArrayViewMut2<Complex<T>>,
    flt: &Vec<T>,
    integer_phase_offset: usize,
) {
    let n_inp_rows = inp.shape()[0];
    let n_inp_cols = inp.shape()[1];
    let n_out_rows = out.shape()[0];
    let n_out_cols = out.shape()[1];

    // "Valid" convolution length for rows; low + mid per column, plus one high.
    assert_eq!(n_out_rows, n_inp_rows - flt.len() + 1);
    assert_eq!(n_out_cols, n_inp_cols * 2 + 1);

    for icol in 0..n_inp_cols {
        // Copy the column so the simple slice-based helpers can operate on it.
        let column = inp.column(icol).to_vec();
        let samples = column.as_slice();

        // NOTE: despite the `_noalloc` name this naive version still allocates;
        // the signature is what future non-naive implementations will share.
        let tuned_low: Vec<Complex<T>> = shift_down_fs_over_4(&samples, integer_phase_offset);

        let low = convolve(&tuned_low, &flt);
        let mid = convolve(samples, &flt);

        assert_eq!(n_out_rows, low.len());

        // Interleave results: channel 2k is low-band, channel 2k + 1 is mid.
        for irow in 0..n_out_rows {
            out[(irow, 2 * icol)] = low[irow];
            out[(irow, 2 * icol + 1)] = mid[irow];
        }

        // The high band duplicates the low band of the adjacent leg, so it is
        // only emitted for the final column.
        let is_last_column = icol + 1 == n_inp_cols;
        if is_last_column {
            let tuned_high = shift_up_fs_over_4(&samples, integer_phase_offset);
            let high = convolve(&tuned_high, &flt);
            for irow in 0..n_out_rows {
                out[(irow, 2 * icol + 2)] = high[irow];
            }
        }
    }
}

/// Perform a single-stage 3-band channelization on a 2:1 row-decimated input
///
/// The rows of `inp` are first decimated by two — keeping rows 0, 2, 4, …
/// when `start_row` is true, or rows 1, 3, 5, … when it is false — and the
/// decimated data is then channelized like `triband_cascade_noalloc`: each
/// column produces a low (`shift_down_fs_over_4`) and a mid (untuned)
/// channel, and the final column additionally produces the high
/// (`shift_up_fs_over_4`) channel, all convolved with `flt`.
///
/// `out` must be shaped `(decimated_rows - flt.len() + 1, n_inp_cols * 2 + 1)`
/// where `decimated_rows` is the number of rows surviving decimation.
///
/// # Panics
/// Panics if `start_row` is false and `inp` has no rows, if the decimated
/// input is shorter than the filter, or if `out` has the wrong shape.
pub fn triband_cascade_decimate_noalloc<T: Float + NumAssignOps>(
    inp: ArrayView2<Complex<T>>,
    mut out: ArrayViewMut2<Complex<T>>,
    flt: &Vec<T>,
    integer_phase_offset: usize,
    start_row: bool,
) {
    let n_inp_rows = inp.shape()[0];
    let n_inp_cols = inp.shape()[1];

    // Row offset of the first kept sample: 0 keeps even rows, 1 keeps odd rows.
    let start_index = if start_row { 0 } else { 1 };
    // Explicit check so an empty input fails loudly instead of wrapping the
    // unsigned subtraction below in release builds.
    assert!(
        n_inp_rows >= start_index,
        "input must have at least one row when start_row is false"
    );

    // Number of rows surviving 2:1 decimation (ceil of the remaining count).
    let decimated_rows = (n_inp_rows - start_index + 1) / 2;
    assert!(
        decimated_rows + 1 >= flt.len(),
        "decimated input is shorter than the filter"
    );

    // Gather every other row, starting at start_index.
    let decimated_inp = Array2::<Complex<T>>::from_shape_fn(
        (decimated_rows, n_inp_cols),
        |(irow, icol)| inp[(irow * 2 + start_index, icol)],
    );

    let n_out_rows = decimated_rows - flt.len() + 1;
    let n_out_cols = n_inp_cols * 2 + 1;

    assert_eq!(n_out_rows, out.shape()[0]);
    assert_eq!(n_out_cols, out.shape()[1]);

    for icol in 0..n_inp_cols {
        // Copy the column out so the naive slice-based helpers can be used.
        let col_vec = decimated_inp.column(icol).to_vec();
        let col_slice = col_vec.as_slice();

        // Low channel: tuned by the fs/4 shifter before filtering.
        let tuned: Vec<Complex<T>> = shift_down_fs_over_4(&col_slice, integer_phase_offset);

        let low_flt = convolve(&tuned, &flt);
        let mid_flt = convolve(col_slice, &flt);

        assert_eq!(n_out_rows, low_flt.len());

        // Interleave results: channel 2k is low-band, channel 2k + 1 is mid.
        for irow in 0..n_out_rows {
            out[(irow, icol * 2)] = low_flt[irow];
            out[(irow, icol * 2 + 1)] = mid_flt[irow];
        }

        // The high channel duplicates the low channel of the adjacent leg, so
        // it is only emitted for the last column.
        if icol == n_inp_cols - 1 {
            let high_tuned = shift_up_fs_over_4(&col_slice, integer_phase_offset);
            let high_flt = convolve(&high_tuned, &flt);
            for irow in 0..n_out_rows {
                out[(irow, icol * 2 + 2)] = high_flt[irow];
            }
        }
    }
}

/// Allocating wrapper around [`triband_cascade_noalloc`]
///
/// Sizes and zero-initializes the output array for a single cascade stage,
/// runs the channelizer, and returns the result.
pub fn triband_cascade<T: Float + NumAssignOps>(
    inp: ArrayView2<Complex<T>>,
    flt: &Vec<T>,
    integer_phase_offset: usize,
) -> Array2<Complex<T>> {
    // "Valid" convolution trims flt.len() - 1 rows; each input column fans
    // out to a low/mid channel pair, plus one extra high channel at the end.
    let out_shape = (inp.shape()[0] - flt.len() + 1, inp.shape()[1] * 2 + 1);

    let mut out = Array2::<Complex<T>>::zeros(out_shape);
    triband_cascade_noalloc(inp, out.view_mut(), flt, integer_phase_offset);
    out
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::get_hb_filter;
    use approx::assert_relative_eq;
    use ndarray::{Array1, Array2, Axis};
    use rustfft::{Fft, FftPlanner};
    use std::f64::consts::PI;
    use std::sync::Arc;

    /// Return the indices of all circular local maxima by squared magnitude
    ///
    /// Neighbors wrap around, so the first and last samples are compared to
    /// each other. A point counts as a maximum when both neighbors are `<=`
    /// it, so flat plateaus report every plateau sample.
    fn find_local_maxima(vv: &[Complex<f32>]) -> Vec<usize> {
        let len = vv.len();
        (0..len)
            .filter(|&ii| {
                // `+ len` keeps the predecessor index in range without
                // underflowing at ii == 0.
                let prev = vv[(ii + len - 1) % len].norm_sqr();
                let next = vv[(ii + 1) % len].norm_sqr();
                let here = vv[ii].norm_sqr();
                prev <= here && next <= here
            })
            .collect()
    }
    // End-to-end check: channelize a sum of three known tones through one and
    // then two cascade stages, FFT each output channel, and verify that the
    // strongest spectral peaks land at the expected (tuned) frequencies.
    #[test]
    fn test_run_filter() {
        // 31-tap half-band filter: each cascade stage trims flt.len() - 1 = 30
        // rows from the output.
        let flt = get_hb_filter::<f32>(31);
        // Three complex tones at normalized frequencies -0.17, 0.35, and 0.21
        // cycles/sample with amplitudes 0.5, 1.0, and 2.0 respectively.
        let sigs = vec![
            Array1::from_iter((0..32768).map(|nn| -> Complex<f32> {
                Complex::<f32>::from_polar(0.5, (-0.17 * 2.0 * PI * nn as f64) as f32)
            })),
            Array1::from_iter((0..32768).map(|nn| -> Complex<f32> {
                Complex::<f32>::from_polar(1.0, (0.35 * 2.0 * PI * nn as f64) as f32)
            })),
            Array1::from_iter((0..32768).map(|nn| -> Complex<f32> {
                Complex::<f32>::from_polar(2.0, (0.21 * 2.0 * PI * nn as f64) as f32)
            })),
        ];
        // Sum the tones into one composite signal, shaped as a single column.
        let mut sum = sigs[0].clone();
        for arr in &sigs[1..] {
            sum += arr;
        }
        let sum_col: Array2<Complex<f32>> = sum.insert_axis(ndarray::Axis(1));

        // First cascade stage: one input column becomes three channels.
        let flt1 = triband_cascade(sum_col.view(), &flt, 0);

        let mut planner = FftPlanner::new();
        let fft: Arc<dyn Fft<f32>> = planner.plan_fft_forward(flt1.shape()[0]);
        for (icol, col) in flt1.axis_iter(Axis(1)).enumerate() {
            let (mut buffer, _) = col.to_owned().into_raw_vec_and_offset();
            assert_eq!(buffer.len(), 32768-30);
            fft.process(&mut buffer);
            // Collect spectral peaks, then sort them by descending magnitude.
            let mut idxs: Vec<usize> = find_local_maxima(&buffer);
            idxs.sort_by(|&i, &j| {
                buffer[j]
                    .norm_sqr()
                    .partial_cmp(&buffer[i].norm_sqr())
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            // Cross-check: the global maximum found by a plain linear scan
            // must be among the reported local maxima.
            let mut max = Complex::<f32>::new(0.0, 0.0);
            let mut max_idx = 0;
            for ii in 0..buffer.len() {
                if buffer[ii].norm_sqr() > max.norm_sqr() {
                    max = buffer[ii];
                    max_idx = ii;
                }
            }

            assert!(idxs.contains(&max_idx));

            // Check the three strongest peaks of each channel against the
            // expected tone locations.
            for imax in 0..3 {
                let max_idx = idxs[imax];
                // Convert FFT bin index to signed normalized frequency in
                // [-0.5, 0.5).
                let max_freq = if max_idx < buffer.len() / 2 {
                    max_idx as f64 / buffer.len() as f64
                } else {
                    (max_idx as f64 / buffer.len() as f64) - 1.0
                };
                let max_db =
                    (buffer[max_idx].norm_sqr() / (buffer.len() as f32).powi(2)).log10() * 10.0;
                println!("{} {} {} {} {}", icol, imax, max_idx, max_freq, max_db);
                // Key encodes (channel, peak rank) as icol * 10 + imax. The
                // low channel (icol 0) sees tones shifted by +0.25, the high
                // channel (icol 2) by -0.25; unlisted entries are peaks this
                // test makes no claim about (attenuated or aliased energy).
                let exp_freq = match icol * 10 + imax {
                    00 => Some(-0.17 + 0.25),
                    10 => Some(0.21),
                    11 => Some(-0.17),
                    12 => Some(0.35),
                    20 => Some(0.21 - 0.25),
                    21 => Some(0.35 - 0.25),
                    _ => None,
                };
                if exp_freq.is_some() {
                    //1 / 30e3 is approximately the FFT bin resolution.
                    assert_relative_eq!(max_freq, exp_freq.unwrap(), epsilon = 1.0 / 30e3);
                }
            }
        }

        // Second cascade stage: three channels in, seven channels out, with
        // another 30 rows trimmed by the filter.
        let flt2 = triband_cascade(flt1.view(), &flt, 0);
        let mut planner = FftPlanner::new();
        let fft: Arc<dyn Fft<f32>> = planner.plan_fft_forward(flt2.shape()[0]);
        for (icol, col) in flt2.axis_iter(Axis(1)).enumerate() {
            let (mut buffer, _) = col.to_owned().into_raw_vec_and_offset();
            assert_eq!(buffer.len(), 32768-60);
            fft.process(&mut buffer);
            // Same peak extraction and ranking as for the first stage.
            let mut idxs: Vec<usize> = find_local_maxima(&buffer);
            idxs.sort_by(|&i, &j| {
                buffer[j]
                    .norm_sqr()
                    .partial_cmp(&buffer[i].norm_sqr())
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            let mut max = Complex::<f32>::new(0.0, 0.0);
            let mut max_idx = 0;
            for ii in 0..buffer.len() {
                if buffer[ii].norm_sqr() > max.norm_sqr() {
                    max = buffer[ii];
                    max_idx = ii;
                }
            }

            assert!(idxs.contains(&max_idx));

            for imax in 0..3 {
                let max_idx = idxs[imax];
                let max_freq = if max_idx < buffer.len() / 2 {
                    max_idx as f64 / buffer.len() as f64
                } else {
                    (max_idx as f64 / buffer.len() as f64) - 1.0
                };
                let max_db =
                    (buffer[max_idx].norm_sqr() / (buffer.len() as f32).powi(2)).log10() * 10.0;
                println!("{} {} {} {} {}", icol, imax, max_idx, max_freq, max_db);
                // After two stages only channel 3's two strongest peaks are
                // pinned down here.
                let exp_freq = match icol * 10 + imax {
                    30 => Some(0.21),
                    31 => Some(-0.17),
                    _ => None,
                };
                if exp_freq.is_some() {
                    //1 / 30e3 is approximately the FFT bin resolution.
                    assert_relative_eq!(max_freq, exp_freq.unwrap(), epsilon = 1.0 / 30e3);
                }
            }
        }
    }

    // Verify that ndarray stores Array2 data in row-major (C) order: element
    // (ii, jj) sits at flat offset ii * ncols + jj from the base pointer.
    #[test]
    fn test_row_major() {
        let nrows = 30;
        let ncols = 5;
        let mut arr = Array2::<Complex<f32>>::zeros((nrows, ncols));
        for ii in 0..nrows {
            for jj in 0..ncols {
                // Store the expected flat offset so it can be recovered below
                // through the raw pointer.
                arr[(ii, jj)] = Complex::<f32>::new((ii * ncols + jj) as f32, 0.0);
            }
        }
        assert_eq!(arr.shape()[0], 30);
        assert_eq!(arr.shape()[1], 5);
        let ptr: *const Complex<f32> = arr.as_ptr();
        // SAFETY: `arr` owns its freshly-allocated nrows x ncols buffer for the
        // whole block, and every offset read is ii * ncols + jj with
        // ii < nrows and jj < ncols, i.e. strictly inside the allocation
        // (assuming the row-major layout this test exists to confirm).
        unsafe {
            for ii in 0..nrows {
                for jj in 0..ncols {
                    let offs = ii * ncols + jj;
                    let val = *(ptr.wrapping_add(offs));
                    assert_eq!(val.re, (ii * ncols + jj) as f32);
                }
            }
        }
        println!("{}", arr);
    }
}