use super::{convolve, shift_down_fs_over_4, shift_up_fs_over_4};
use ndarray::{Array2, ArrayView2, ArrayViewMut2};
use num_complex::Complex;
use num_traits::{Float, NumAssignOps};
/// Runs one cascade stage over every column of `inp`, writing the filtered
/// bands into the caller-provided `out` array.
///
/// For input column `c`:
/// * output column `2 * c` receives the column passed through
///   `shift_down_fs_over_4` and then `convolve`d with `flt`;
/// * output column `2 * c + 1` receives the untouched column `convolve`d with
///   `flt`;
/// * for the last input column only, output column `2 * c + 2` receives the
///   column passed through `shift_up_fs_over_4` and then `convolve`d with `flt`.
///
/// Only the output buffer is supplied by the caller; each column is still
/// copied into a temporary `Vec` before processing.
///
/// # Arguments
/// * `inp` - input samples; each column is processed independently.
/// * `out` - destination of shape
///   `(inp_rows - flt.len() + 1, 2 * inp_cols + 1)`.
/// * `flt` - filter taps handed to `convolve`.
/// * `integer_phase_offset` - forwarded unchanged to the shift helpers.
///
/// # Panics
/// Panics if `flt` has more taps than `inp` has rows, if `out` does not have
/// the shape described above, or if a filtered column is shorter than the
/// number of output rows.
pub fn triband_cascade_noalloc<T: Float + NumAssignOps>(
    inp: ArrayView2<Complex<T>>,
    mut out: ArrayViewMut2<Complex<T>>,
    flt: &Vec<T>,
    integer_phase_offset: usize,
) {
    let (n_inp_rows, n_inp_cols) = inp.dim();
    let (n_out_rows, n_out_cols) = out.dim();

    // Check the precondition explicitly: a bare `n_inp_rows - flt.len()` would
    // abort with an opaque overflow panic in debug builds and silently wrap
    // around in release builds.
    assert!(
        flt.len() <= n_inp_rows,
        "filter has {} taps but the input only has {} rows",
        flt.len(),
        n_inp_rows
    );
    let expected_out_rows = n_inp_rows - flt.len() + 1;
    assert_eq!(
        n_out_rows, expected_out_rows,
        "output row count does not match inp_rows - flt.len() + 1"
    );
    assert_eq!(
        n_out_cols,
        n_inp_cols * 2 + 1,
        "output column count does not match 2 * inp_cols + 1"
    );

    for icol in 0..n_inp_cols {
        // Copy the column so the helpers receive a contiguous slice.
        let col_vec = inp.column(icol).to_vec();
        let col_slice = col_vec.as_slice();

        let low_tuned: Vec<Complex<T>> = shift_down_fs_over_4(&col_slice, integer_phase_offset);
        let low_flt = convolve(&low_tuned, &flt);
        let mid_flt = convolve(col_slice, &flt);
        assert_eq!(n_out_rows, low_flt.len());

        for irow in 0..n_out_rows {
            out[(irow, icol * 2)] = low_flt[irow];
            out[(irow, icol * 2 + 1)] = mid_flt[irow];
        }

        // Only the last input column contributes the extra upper band that
        // fills the final output column.
        if icol + 1 == n_inp_cols {
            let high_tuned = shift_up_fs_over_4(&col_slice, integer_phase_offset);
            let high_flt = convolve(&high_tuned, &flt);
            for irow in 0..n_out_rows {
                out[(irow, icol * 2 + 2)] = high_flt[irow];
            }
        }
    }
}
pub fn triband_cascade_decimate_noalloc<T: Float + NumAssignOps>(
inp: ArrayView2<Complex<T>>,
mut out: ArrayViewMut2<Complex<T>>,
flt: &Vec<T>,
integer_phase_offset: usize,
start_row: bool,
) {
let n_inp_rows = inp.shape()[0];
let n_inp_cols = inp.shape()[1];
let start_index = if start_row { 0 } else { 1 };
let decimated_rows = (n_inp_rows - start_index + 1) / 2;
let mut decimated_inp = Array2::<Complex<T>>::zeros((decimated_rows, n_inp_cols));
for icol in 0..n_inp_cols {
for irow in 0..decimated_rows {
decimated_inp[(irow, icol)] = inp[(irow * 2 + start_index, icol)];
}
}
let n_out_rows = decimated_inp.shape()[0] - flt.len() + 1;
let n_out_cols = decimated_inp.shape()[1] * 2 + 1;
assert_eq!(n_out_rows, out.shape()[0]);
assert_eq!(n_out_cols, out.shape()[1]);
for icol in 0..decimated_inp.shape()[1] {
let col_vec = decimated_inp.column(icol).to_vec();
let col_slice = col_vec.as_slice();
let tuned: Vec<Complex<T>> = shift_down_fs_over_4(&col_slice, integer_phase_offset);
let low_flt = convolve(&tuned, &flt);
let mid_flt = convolve(col_slice, &flt);
assert_eq!(n_out_rows, low_flt.len());
for irow in 0..n_out_rows {
out[(irow, icol * 2 + 0)] = low_flt[irow];
out[(irow, icol * 2 + 1)] = mid_flt[irow];
}
if icol == decimated_inp.shape()[1] - 1 {
let high_tuned = shift_up_fs_over_4(&col_slice, integer_phase_offset);
let high_flt = convolve(&high_tuned, &flt);
for irow in 0..n_out_rows {
out[(irow, icol * 2 + 2)] = high_flt[irow];
}
}
}
}
/// Allocating wrapper around [`triband_cascade_noalloc`].
///
/// Creates a zero-initialised output array of shape
/// `(inp_rows - flt.len() + 1, 2 * inp_cols + 1)`, fills it by running one
/// cascade stage over `inp`, and returns it.
///
/// # Arguments
/// * `inp` - input samples; each column is processed independently.
/// * `flt` - filter taps handed to `convolve`.
/// * `integer_phase_offset` - forwarded unchanged to the shift helpers.
///
/// # Panics
/// Panics if `flt` has more taps than `inp` has rows.
pub fn triband_cascade<T: Float + NumAssignOps>(
    inp: ArrayView2<Complex<T>>,
    flt: &Vec<T>,
    integer_phase_offset: usize,
) -> Array2<Complex<T>> {
    let (n_inp_rows, n_inp_cols) = inp.dim();

    // Guard the subtraction below: unchecked, it overflows in debug builds and
    // wraps to an enormous row count in release builds.
    assert!(
        flt.len() <= n_inp_rows,
        "filter has {} taps but the input only has {} rows",
        flt.len(),
        n_inp_rows
    );
    let n_out_rows = n_inp_rows - flt.len() + 1;
    let n_out_cols = n_inp_cols * 2 + 1;

    let mut out = Array2::<Complex<T>>::zeros((n_out_rows, n_out_cols));
    triband_cascade_noalloc(inp, out.view_mut(), flt, integer_phase_offset);
    out
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::get_hb_filter;
    use approx::assert_relative_eq;
    use ndarray::{Array1, Array2, Axis};
    use rustfft::{Fft, FftPlanner};
    use std::f64::consts::PI;
    use std::sync::Arc;

    /// Returns every index whose squared magnitude is at least as large as
    /// that of both neighbours. The sequence is treated as circular, so the
    /// first and last elements are neighbours of each other.
    fn find_local_maxima(vv: &[Complex<f32>]) -> Vec<usize> {
        let n = vv.len();
        (0..n)
            .filter(|&ii| {
                let here = vv[ii].norm_sqr();
                let below = vv[(ii + n - 1) % n].norm_sqr();
                let above = vv[(ii + 1) % n].norm_sqr();
                below <= here && above <= here
            })
            .collect()
    }

    /// Builds `n_samples` of a complex exponential with the given amplitude
    /// and frequency expressed in cycles per sample.
    fn tone(amplitude: f32, cycles_per_sample: f64, n_samples: usize) -> Array1<Complex<f32>> {
        Array1::from_iter((0..n_samples).map(|nn| -> Complex<f32> {
            Complex::<f32>::from_polar(
                amplitude,
                (cycles_per_sample * 2.0 * PI * nn as f64) as f32,
            )
        }))
    }

    /// Takes the FFT of every column of `bands` and checks its three
    /// strongest spectral peaks.
    ///
    /// * `expected_len` - required number of rows in every column.
    /// * `expected_freq` - maps `(column, peak rank)` to the frequency (in
    ///   cycles per sample, wrapped to `[-0.5, 0.5)`) that peak must have, or
    ///   `None` when that peak is not checked. Rank 0 is the strongest peak.
    fn check_band_peaks(
        bands: &Array2<Complex<f32>>,
        expected_len: usize,
        expected_freq: impl Fn(usize, usize) -> Option<f64>,
    ) {
        let mut planner = FftPlanner::new();
        let fft: Arc<dyn Fft<f32>> = planner.plan_fft_forward(bands.shape()[0]);

        for (icol, col) in bands.axis_iter(Axis(1)).enumerate() {
            let (mut buffer, _) = col.to_owned().into_raw_vec_and_offset();
            assert_eq!(buffer.len(), expected_len);
            fft.process(&mut buffer);

            // Local maxima ordered from strongest to weakest; the stable sort
            // keeps equal-power peaks in index order.
            let mut idxs: Vec<usize> = find_local_maxima(&buffer);
            idxs.sort_by(|&i, &j| {
                buffer[j]
                    .norm_sqr()
                    .partial_cmp(&buffer[i].norm_sqr())
                    .unwrap_or(std::cmp::Ordering::Equal)
            });

            // The global maximum must be among the detected local maxima.
            let mut strongest_idx = 0;
            let mut strongest_pow = 0.0_f32;
            for (ii, bin) in buffer.iter().enumerate() {
                let pow = bin.norm_sqr();
                if pow > strongest_pow {
                    strongest_pow = pow;
                    strongest_idx = ii;
                }
            }
            assert!(idxs.contains(&strongest_idx));

            for (imax, &max_idx) in idxs[..3].iter().enumerate() {
                // Map the FFT bin to a signed frequency in cycles per sample.
                let max_freq = if max_idx < buffer.len() / 2 {
                    max_idx as f64 / buffer.len() as f64
                } else {
                    (max_idx as f64 / buffer.len() as f64) - 1.0
                };
                let max_db =
                    (buffer[max_idx].norm_sqr() / (buffer.len() as f32).powi(2)).log10() * 10.0;
                println!("{} {} {} {} {}", icol, imax, max_idx, max_freq, max_db);

                if let Some(exp_freq) = expected_freq(icol, imax) {
                    assert_relative_eq!(max_freq, exp_freq, epsilon = 1.0 / 30e3);
                }
            }
        }
    }

    /// Feeds a sum of three tones through two cascade stages and checks that
    /// the dominant spectral peaks of selected output columns sit at the
    /// expected frequencies.
    #[test]
    fn test_run_filter() {
        const N_SAMPLES: usize = 32768;

        let flt = get_hb_filter::<f32>(31);
        let sigs = vec![
            tone(0.5, -0.17, N_SAMPLES),
            tone(1.0, 0.35, N_SAMPLES),
            tone(2.0, 0.21, N_SAMPLES),
        ];
        let mut sum = sigs[0].clone();
        for arr in &sigs[1..] {
            sum += arr;
        }
        let sum_col: Array2<Complex<f32>> = sum.insert_axis(Axis(1));

        // First stage: each column loses 30 rows to the 31-tap filter.
        let flt1 = triband_cascade(sum_col.view(), &flt, 0);
        check_band_peaks(&flt1, N_SAMPLES - 30, |icol, imax| match (icol, imax) {
            (0, 0) => Some(-0.17 + 0.25),
            (1, 0) => Some(0.21),
            (1, 1) => Some(-0.17),
            (1, 2) => Some(0.35),
            (2, 0) => Some(0.21 - 0.25),
            (2, 1) => Some(0.35 - 0.25),
            _ => None,
        });

        // Second stage, applied to the output of the first.
        let flt2 = triband_cascade(flt1.view(), &flt, 0);
        check_band_peaks(&flt2, N_SAMPLES - 60, |icol, imax| match (icol, imax) {
            (3, 0) => Some(0.21),
            (3, 1) => Some(-0.17),
            _ => None,
        });
    }

    /// Checks that a freshly created `Array2` stores its elements in
    /// row-major order.
    #[test]
    fn test_row_major() {
        let nrows = 30;
        let ncols = 5;
        let mut arr = Array2::<Complex<f32>>::zeros((nrows, ncols));
        for ii in 0..nrows {
            for jj in 0..ncols {
                arr[(ii, jj)] = Complex::<f32>::new((ii * ncols + jj) as f32, 0.0);
            }
        }
        assert_eq!(arr.shape()[0], 30);
        assert_eq!(arr.shape()[1], 5);

        // `as_slice` only succeeds for contiguous, standard-layout storage, so
        // it exposes the underlying memory order without raw-pointer access.
        let flat = arr
            .as_slice()
            .expect("array should be contiguous in standard layout");
        for ii in 0..nrows {
            for jj in 0..ncols {
                let offs = ii * ncols + jj;
                assert_eq!(flat[offs].re, (ii * ncols + jj) as f32);
            }
        }
        println!("{}", arr);
    }
}