mrrs 0.1.0

Multirate Signal Processing
Documentation
//! A 3-band channelizer with 50% overlap.
//!
//! The goal is a channelizer that takes an input and returns three copies of
//! it, tuned by -fs/4, 0 and +fs/4 respectively, each of which is then run
//! through a half-band filter.
//!
//! Several of these channelizers can be cascaded, but some care is needed to
//! avoid duplicated work. The highest channel of one band is the same as the
//! lowest channel of the next band, so the highest-channel output is only
//! computed for the last band (the last input column).
use super::{convolve, shift_down_fs_over_4, shift_up_fs_over_4};
use ndarray::{Array2, ArrayView2, ArrayViewMut2};
use num_complex::Complex;
use num_traits::{Float, NumAssignOps};

/// Run one stage of the three-band channelizer cascade, writing into `out`.
///
/// Each column of `inp` is an independent complex input stream. For input
/// column `icol`, the following signals are convolved with `flt` and stored in
/// `out`:
///
/// * column `2 * icol`: the input after `shift_down_fs_over_4`;
/// * column `2 * icol + 1`: the input as-is;
/// * column `2 * icol + 2`: the input after `shift_up_fs_over_4`. This one is
///   computed only for the last input column. For every other column this
///   slot is written by the next column's `shift_down_fs_over_4` output, so
///   the band shared by adjacent columns is produced only once.
///
/// Every output column holds exactly `inp.nrows() + 1 - flt.len()` samples,
/// which is the length of a convolution without padding. The lengths
/// returned by `convolve` are checked against that. No decimation happens
/// here, so the output has the same sample spacing as the input.
///
/// NOTE(review): if `shift_down_fs_over_4` moves the spectrum *down* by fs/4,
/// then column `2 * icol` holds the band originally centred at +fs/4, and the
/// output columns run from high to low frequency. Confirm this ordering is
/// intended.
///
/// If `inp` has no columns, nothing is written and the single output column
/// is left as it was.
///
/// # Panics
///
/// * if `flt.len() > inp.nrows() + 1`;
/// * if `out` is not shaped `(inp.nrows() + 1 - flt.len(), 2 * inp.ncols() + 1)`;
/// * if `convolve` returns a different number of samples than `out` has rows.
pub fn triband_cascade_noalloc<T: Float + NumAssignOps>(
    inp: ArrayView2<Complex<T>>,
    mut out: ArrayViewMut2<Complex<T>>,
    flt: &Vec<T>,
) {
    let n_inp_rows = inp.shape()[0];
    let n_inp_cols = inp.shape()[1];

    // Computed as `n + 1 - taps` rather than `n - taps + 1`. With this form a
    // filter of exactly `n + 1` taps yields 0 rows instead of overflowing, and
    // a longer filter gets a clear message. The old form gave an
    // arithmetic-overflow panic in debug builds and a wrapped-around row count
    // in release builds.
    let n_out_rows = (n_inp_rows + 1).checked_sub(flt.len()).unwrap_or_else(|| {
        panic!(
            "filter has {} taps but the input has only {} rows",
            flt.len(),
            n_inp_rows
        )
    });
    let n_out_cols = n_inp_cols * 2 + 1;

    assert_eq!(
        out.shape()[0],
        n_out_rows,
        "`out` has the wrong number of rows"
    );
    assert_eq!(
        out.shape()[1],
        n_out_cols,
        "`out` has the wrong number of columns"
    );

    for icol in 0..n_inp_cols {
        // A column of a row-major array is strided, so copy it into a
        // contiguous buffer for the slice-based helpers.
        let col_vec = inp.column(icol).to_vec();
        let col_slice = col_vec.as_slice();

        let tuned: Vec<Complex<T>> = shift_down_fs_over_4(&col_slice, 0);
        let low_flt = convolve(&tuned, &flt);
        let mid_flt = convolve(col_slice, &flt);
        assert_eq!(n_out_rows, low_flt.len());
        assert_eq!(n_out_rows, mid_flt.len());
        for irow in 0..n_out_rows {
            out[(irow, icol * 2)] = low_flt[irow];
            out[(irow, icol * 2 + 1)] = mid_flt[irow];
        }

        // For every column except the last, slot `2 * icol + 2` is filled by
        // the next column's shift-down output. So the shift-up channel only
        // needs computing for the final column.
        if icol + 1 == n_inp_cols {
            let high_tuned = shift_up_fs_over_4(&col_slice, 0);
            let high_flt = convolve(&high_tuned, &flt);
            assert_eq!(n_out_rows, high_flt.len());
            for irow in 0..n_out_rows {
                out[(irow, icol * 2 + 2)] = high_flt[irow];
            }
        }
    }
}

/// Allocating wrapper around [`triband_cascade_noalloc`].
///
/// Allocates a zero-filled output of shape
/// `(inp.nrows() + 1 - flt.len(), 2 * inp.ncols() + 1)`, fills it by running
/// one channelizer stage over `inp`, and returns it. Passing the result back
/// in as `inp` runs the next stage of the cascade.
///
/// # Panics
///
/// Panics if `flt.len() > inp.nrows() + 1`. It also panics under the same
/// conditions as [`triband_cascade_noalloc`].
pub fn triband_cascade<T: Float + NumAssignOps>(
    inp: ArrayView2<Complex<T>>,
    flt: &Vec<T>,
) -> Array2<Complex<T>> {
    let n_inp_rows = inp.shape()[0];
    let n_inp_cols = inp.shape()[1];

    // Validate before allocating. `n_inp_rows - flt.len() + 1` would overflow
    // for a too-long filter: a debug panic, or a wrapped, absurdly large
    // allocation request in release builds.
    let n_out_rows = (n_inp_rows + 1).checked_sub(flt.len()).unwrap_or_else(|| {
        panic!(
            "filter has {} taps but the input has only {} rows",
            flt.len(),
            n_inp_rows
        )
    });
    let n_out_cols = n_inp_cols * 2 + 1;

    let mut out = Array2::<Complex<T>>::zeros((n_out_rows, n_out_cols));
    triband_cascade_noalloc(inp, out.view_mut(), flt);
    out
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::get_hb_filter;
    use ndarray::{Array1, Array2, Axis};
    use rustfft::{Fft, FftPlanner};
    use std::cmp::Ordering;
    use std::f64::consts::PI;
    use std::sync::Arc;

    /// Return the indices of all circular local maxima of `|vv|^2`.
    ///
    /// An index counts as a local maximum when neither neighbour has strictly
    /// greater power. Neighbours wrap around at both ends, as DFT bins do.
    /// Every sample on a flat plateau is reported.
    fn find_local_maxima(vv: &[Complex<f32>]) -> Vec<usize> {
        let n = vv.len();
        (0..n)
            .filter(|&ii| {
                let cur = vv[ii].norm_sqr();
                let prev = vv[(ii + n - 1) % n].norm_sqr();
                let next = vv[(ii + 1) % n].norm_sqr();
                prev <= cur && next <= cur
            })
            .collect()
    }

    /// Smoke test: three cascaded stages over a sum of tones at -fs/4, +fs/4
    /// and DC.
    ///
    /// Checks the output shapes of every stage, then prints the strongest
    /// spectral peaks of each first-stage channel.
    #[test]
    fn test_run_filter() {
        let flt = get_hb_filter::<f32>(31);
        let n_samples: usize = 32768;
        let tone = |amplitude: f32, cycles_per_sample: f64| -> Array1<Complex<f32>> {
            Array1::from_iter((0..n_samples).map(move |nn| {
                Complex::<f32>::from_polar(
                    amplitude,
                    (cycles_per_sample * 2.0 * PI * nn as f64) as f32,
                )
            }))
        };
        let sigs = [tone(0.5, -0.25), tone(1.0, 0.25), tone(2.0, 0.0)];
        let mut sum = sigs[0].clone();
        for arr in &sigs[1..] {
            sum += arr;
        }
        let sum_col: Array2<Complex<f32>> = sum.insert_axis(Axis(1));

        let flt1 = triband_cascade(sum_col.view(), &flt);
        let flt2 = triband_cascade(flt1.view(), &flt);
        let flt3 = triband_cascade(flt2.view(), &flt);

        // Each stage drops `flt.len() - 1` rows (unpadded convolution) and
        // maps `c` input columns to `2 * c + 1` output columns.
        let trim = flt.len() - 1;
        assert_eq!(flt1.shape(), &[n_samples - trim, 3]);
        assert_eq!(flt2.shape(), &[n_samples - 2 * trim, 7]);
        assert_eq!(flt3.shape(), &[n_samples - 3 * trim, 15]);

        let mut planner = FftPlanner::new();
        let fft: Arc<dyn Fft<f32>> = planner.plan_fft_forward(flt1.shape()[0]);
        for col in flt1.axis_iter(Axis(1)) {
            let mut buffer = col.to_vec();
            fft.process(&mut buffer);

            // Strongest local maxima first.
            let mut idxs = find_local_maxima(&buffer);
            idxs.sort_by(|&i, &j| {
                buffer[j]
                    .norm_sqr()
                    .partial_cmp(&buffer[i].norm_sqr())
                    .unwrap_or(Ordering::Equal)
            });

            // First bin with the largest power. The global maximum is always
            // a circular local maximum, so it must appear in `idxs`.
            let max_idx = (0..buffer.len())
                .reduce(|best, ii| {
                    if buffer[ii].norm_sqr() > buffer[best].norm_sqr() {
                        ii
                    } else {
                        best
                    }
                })
                .expect("FFT output must not be empty");
            let max = buffer[max_idx];
            assert!(
                idxs.contains(&max_idx),
                "global maximum at bin {} is not among the local maxima",
                max_idx
            );

            let top: Vec<usize> = idxs.iter().copied().take(3).collect();
            let top_db: Vec<f32> = top
                .iter()
                .map(|&ii| buffer[ii].norm_sqr().log10() * 10.0)
                .collect();
            println!("max_idx={} max={}", max_idx, max);
            println!("idxs {:?}", top);
            println!("vals {:?}", top_db);
            //TODO: Actually assert something, probably want to do something different for each column
        }
    }

    /// Sanity check that `Array2` stores elements row-major (C order).
    ///
    /// Element `(ii, jj)` must sit at memory offset `ii * ncols + jj`.
    #[test]
    fn test_row_major() {
        let nrows = 30;
        let ncols = 5;
        let mut arr = Array2::<Complex<f32>>::zeros((nrows, ncols));
        for ((ii, jj), val) in arr.indexed_iter_mut() {
            *val = Complex::<f32>::new((ii * ncols + jj) as f32, 0.0);
        }
        assert_eq!(arr.shape(), &[nrows, ncols]);

        // `as_slice` only succeeds for standard (row-major, contiguous) layout.
        // It then exposes the elements in memory order, so no raw-pointer
        // reads are needed.
        let flat = arr
            .as_slice()
            .expect("zeros() should produce a standard-layout array");
        for (offs, val) in flat.iter().enumerate() {
            assert_eq!(val.re, offs as f32);
        }
        println!("{}", arr);
    }
}