use super::{convolve, shift_down_fs_over_4, shift_up_fs_over_4};
use ndarray::{Array2, ArrayView2, ArrayViewMut2};
use num_complex::Complex;
use num_traits::{Float, NumAssignOps};
/// Splits every column of `inp` into three filtered bands and writes them into `out`.
///
/// For input column `c` (copied out as `col`), the following output columns are written:
///
/// * `2 * c`     — `convolve(shift_down_fs_over_4(col, 0), flt)`
/// * `2 * c + 1` — `convolve(col, flt)` (no shift)
/// * `2 * c + 2` — `convolve(shift_up_fs_over_4(col, 0), flt)`, **only for the last
///   input column**
///
/// Because only the last column contributes the third band, `out` must have
/// `2 * n_inp_cols + 1` columns. NOTE(review): presumably the up-shifted band of
/// column `c` coincides with the down-shifted band of column `c + 1`, which is why
/// it is emitted only once; confirm against the shift helpers.
///
/// `out` must have `n_inp_rows + 1 - flt.len()` rows. `convolve` is expected to
/// return exactly that many samples per band, and this is checked for every band.
/// If `inp` has zero columns, nothing is written to `out`.
///
/// # Panics
///
/// * if `flt` is empty;
/// * if `flt.len() > n_inp_rows + 1`;
/// * if `out` does not have shape `(n_inp_rows + 1 - flt.len(), 2 * n_inp_cols + 1)`;
/// * if `convolve` returns a band whose length differs from the output row count.
pub fn triband_cascade_noalloc<T: Float + NumAssignOps>(
    inp: ArrayView2<Complex<T>>,
    mut out: ArrayViewMut2<Complex<T>>,
    flt: &Vec<T>,
) {
    let (n_inp_rows, n_inp_cols) = inp.dim();
    let (n_out_rows, n_out_cols) = out.dim();

    assert!(
        !flt.is_empty(),
        "triband_cascade_noalloc: filter must not be empty"
    );
    // `rows + 1 - len` with a checked subtraction: the naive `rows - len + 1`
    // underflows usize (debug panic / release wrap-around) for long filters.
    let expected_rows = (n_inp_rows + 1).checked_sub(flt.len()).unwrap_or_else(|| {
        panic!(
            "triband_cascade_noalloc: filter length {} exceeds input rows + 1 ({})",
            flt.len(),
            n_inp_rows + 1
        )
    });
    assert_eq!(n_out_rows, expected_rows);
    assert_eq!(n_out_cols, n_inp_cols * 2 + 1);

    for icol in 0..n_inp_cols {
        // Columns of a 2-D view are generally strided; copy into a contiguous
        // buffer so the slice-based helpers can consume it.
        let col_vec = inp.column(icol).to_vec();

        let tuned: Vec<Complex<T>> = shift_down_fs_over_4(&col_vec, 0);
        let low_flt = convolve(&tuned, flt);
        let mid_flt = convolve(&col_vec, flt);
        // Check every band so a short `convolve` result fails with a clear
        // message instead of an out-of-bounds index.
        assert_eq!(n_out_rows, low_flt.len());
        assert_eq!(n_out_rows, mid_flt.len());
        for irow in 0..n_out_rows {
            out[(irow, 2 * icol)] = low_flt[irow];
            out[(irow, 2 * icol + 1)] = mid_flt[irow];
        }

        // Only the final column emits the up-shifted band (see doc comment).
        if icol + 1 == n_inp_cols {
            let high_tuned = shift_up_fs_over_4(&col_vec, 0);
            let high_flt = convolve(&high_tuned, flt);
            assert_eq!(n_out_rows, high_flt.len());
            for irow in 0..n_out_rows {
                out[(irow, 2 * icol + 2)] = high_flt[irow];
            }
        }
    }
}
/// Allocating wrapper around [`triband_cascade_noalloc`].
///
/// Allocates a zero-initialised output of shape
/// `(n_inp_rows + 1 - flt.len(), 2 * n_inp_cols + 1)`, fills it with
/// [`triband_cascade_noalloc`], and returns it.
///
/// # Panics
///
/// Panics if `flt.len() > n_inp_rows + 1`. It also panics on any precondition
/// violation reported by [`triband_cascade_noalloc`].
#[must_use]
pub fn triband_cascade<T: Float + NumAssignOps>(
    inp: ArrayView2<Complex<T>>,
    flt: &Vec<T>,
) -> Array2<Complex<T>> {
    let (n_inp_rows, n_inp_cols) = inp.dim();
    // Checked `rows + 1 - len`: the naive `rows - len + 1` underflows usize
    // (debug panic / release wrap-around) when the filter is longer than the input.
    let n_out_rows = (n_inp_rows + 1).checked_sub(flt.len()).unwrap_or_else(|| {
        panic!(
            "triband_cascade: filter length {} exceeds input rows + 1 ({})",
            flt.len(),
            n_inp_rows + 1
        )
    });
    let n_out_cols = 2 * n_inp_cols + 1;
    let mut out = Array2::<Complex<T>>::zeros((n_out_rows, n_out_cols));
    triband_cascade_noalloc(inp, out.view_mut(), flt);
    out
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::get_hb_filter;
    use ndarray::{Array1, Array2, Axis};
    use rustfft::{Fft, FftPlanner};
    use std::f64::consts::PI;
    use std::sync::Arc;

    /// Returns the indices whose squared magnitude is >= both neighbours.
    ///
    /// The slice is treated as circular, so the first and last elements are
    /// neighbours.
    fn find_local_maxima(vv: &[Complex<f32>]) -> Vec<usize> {
        let n = vv.len();
        (0..n)
            .filter(|&ii| {
                let nrm = vv[ii].norm_sqr();
                let prev = vv[(ii + n - 1) % n].norm_sqr();
                let next = vv[(ii + 1) % n].norm_sqr();
                prev <= nrm && next <= nrm
            })
            .collect()
    }

    /// Builds `len` samples of a complex exponential with the given amplitude.
    ///
    /// The frequency is `cycles_per_sample`, expressed as a fraction of the sample rate.
    fn tone(amplitude: f32, cycles_per_sample: f64, len: usize) -> Array1<Complex<f32>> {
        Array1::from_iter((0..len).map(|nn| {
            Complex::<f32>::from_polar(
                amplitude,
                (cycles_per_sample * 2.0 * PI * nn as f64) as f32,
            )
        }))
    }

    #[test]
    fn test_run_filter() {
        let flt = get_hb_filter::<f32>(31);
        let n_samples = 32768;

        // Three tones at -fs/4, +fs/4 and DC with distinct amplitudes.
        let mut sum = tone(0.5, -0.25, n_samples);
        sum += &tone(1.0, 0.25, n_samples);
        sum += &tone(2.0, 0.0, n_samples);
        let sum_col: Array2<Complex<f32>> = sum.insert_axis(Axis(1));

        // Three cascade stages: each trims `flt.len() - 1` rows and maps
        // `c` columns to `2c + 1`.
        let flt1 = triband_cascade(sum_col.view(), &flt);
        let flt2 = triband_cascade(flt1.view(), &flt);
        let flt3 = triband_cascade(flt2.view(), &flt);
        assert_eq!(flt1.dim(), (sum_col.nrows() + 1 - flt.len(), 3));
        assert_eq!(flt2.dim(), (flt1.nrows() + 1 - flt.len(), 7));
        assert_eq!(flt3.dim(), (flt2.nrows() + 1 - flt.len(), 15));

        let mut planner = FftPlanner::new();
        let fft: Arc<dyn Fft<f32>> = planner.plan_fft_forward(flt1.nrows());
        for col in flt1.axis_iter(Axis(1)) {
            let mut buffer = col.to_vec();
            fft.process(&mut buffer);

            // Local maxima, strongest first.
            let mut idxs: Vec<usize> = find_local_maxima(&buffer);
            idxs.sort_by(|&i, &j| {
                buffer[j]
                    .norm_sqr()
                    .partial_cmp(&buffer[i].norm_sqr())
                    .unwrap_or(std::cmp::Ordering::Equal)
            });

            // Earliest bin with the strictly largest non-zero magnitude.
            let mut peak: Option<(usize, Complex<f32>)> = None;
            for (ii, &val) in buffer.iter().enumerate() {
                let best = peak.map_or(0.0, |(_, p)| p.norm_sqr());
                if val.norm_sqr() > best {
                    peak = Some((ii, val));
                }
            }
            let (max_idx, max) = peak.expect("FFT output has no non-zero bin");

            assert!(
                idxs.contains(&max_idx),
                "MAX idx {} isn't in list of local maxima?",
                max_idx
            );
            assert!(
                idxs.len() >= 3,
                "expected at least 3 local maxima, found {}",
                idxs.len()
            );

            let db = |i: usize| buffer[i].norm_sqr().log10() * 10.0;
            println!("max_idx={} max={}", max_idx, max);
            println!("idxs {} {} {}", idxs[0], idxs[1], idxs[2]);
            println!("vals {} {} {}", db(idxs[0]), db(idxs[1]), db(idxs[2]));
        }
    }

    #[test]
    fn test_row_major() {
        let nrows: usize = 30;
        let ncols: usize = 5;
        let arr = Array2::from_shape_fn((nrows, ncols), |(ii, jj)| {
            Complex::<f32>::new((ii * ncols + jj) as f32, 0.0)
        });
        assert_eq!(arr.dim(), (30, 5));

        // `as_slice` returns `Some` only for contiguous standard (row-major)
        // layout. It therefore checks the layout and also exposes memory order safely.
        let flat = arr
            .as_slice()
            .expect("Array2 should use contiguous row-major layout");
        assert_eq!(flat.len(), nrows * ncols);
        for (offs, val) in flat.iter().enumerate() {
            assert_eq!(val.re, offs as f32);
        }
        println!("{}", arr);
    }
}