use ndarray::{s, Array1, Array4};
use wifi_densepose_signal::csi_processor::CsiData;
use wifi_densepose_signal::features::FeatureExtractor;
/// Number of entries in the vector returned by `extract_signal_features`.
pub const FEATURE_LEN: usize = 12;
/// Centre frequency, in hertz, handed to the `CsiData` builder for every frame.
const DEFAULT_CENTRE_FREQ_HZ: f64 = 2.4e9;
/// Bandwidth, in hertz, handed to the `CsiData` builder for every frame.
const DEFAULT_BANDWIDTH_HZ: f64 = 40.0e6;
/// Condenses a 4-D amplitude/phase window into a fixed-length feature vector.
///
/// The dimensions of `amplitude` are read as `(n_t, n_tx, n_rx, n_sc)`. Only
/// the middle frame along the first axis (`n_t / 2`) is used: it is flattened
/// to an `(n_tx * n_rx, n_sc)` matrix of `f64`, wrapped in a `CsiData` built
/// with the default centre frequency and bandwidth, and run through
/// `FeatureExtractor::default_config()`.
///
/// # Returns
///
/// An `Array1<f32>` of length [`FEATURE_LEN`] holding, in order: amplitude
/// peak, amplitude RMS, amplitude dynamic range, the average of
/// `amplitude.mean`, the average of `amplitude.variance`, phase coherence,
/// the average of `phase.variance`, PSD total power, PSD peak power, PSD peak
/// frequency, PSD centroid and PSD bandwidth. Each value is passed through
/// `sanitise` before being stored.
///
/// An all-zero vector is returned when any dimension is zero, when the two
/// inputs have different shapes, when the frame cannot be reshaped, or when
/// the `CsiData` builder rejects the input.
///
/// # Panics
///
/// In builds with debug assertions enabled, panics if `amplitude` and `phase`
/// have different shapes.
pub fn extract_signal_features(amplitude: &Array4<f32>, phase: &Array4<f32>) -> Array1<f32> {
    let (n_t, n_tx, n_rx, n_sc) = amplitude.dim();
    debug_assert_eq!(amplitude.dim(), phase.dim(), "amplitude/phase shape mismatch");

    // The assertion above is compiled out of release builds. Without this
    // guard a shorter `phase` could make the slice below index out of bounds
    // and panic, whereas every other failure path degrades to a zero vector.
    if amplitude.dim() != phase.dim() {
        return Array1::zeros(FEATURE_LEN);
    }
    if n_t == 0 || n_tx == 0 || n_rx == 0 || n_sc == 0 {
        return Array1::zeros(FEATURE_LEN);
    }

    let n_ant = n_tx * n_rx;
    // Use the middle frame of the window.
    let t = n_t / 2;

    // Flatten frame `t` into a `Vec<f64>` in the slice's iteration order.
    let to_2d = |src: &Array4<f32>| -> Vec<f64> {
        src.slice(s![t, .., .., ..])
            .iter()
            .map(|&v| f64::from(v))
            .collect()
    };

    let amp2d = match ndarray::Array2::from_shape_vec((n_ant, n_sc), to_2d(amplitude)) {
        Ok(a) => a,
        Err(_) => return Array1::zeros(FEATURE_LEN),
    };
    let phase2d = match ndarray::Array2::from_shape_vec((n_ant, n_sc), to_2d(phase)) {
        Ok(p) => p,
        Err(_) => return Array1::zeros(FEATURE_LEN),
    };

    let csi = match CsiData::builder()
        .amplitude(amp2d)
        .phase(phase2d)
        .frequency(DEFAULT_CENTRE_FREQ_HZ)
        .bandwidth(DEFAULT_BANDWIDTH_HZ)
        .build()
    {
        Ok(c) => c,
        Err(_) => return Array1::zeros(FEATURE_LEN),
    };

    let feats = FeatureExtractor::default_config().extract(&csi);

    // Collapse the multi-valued statistics to a single scalar each.
    let amp_mean_overall = mean_or_zero(feats.amplitude.mean.iter().copied());
    let amp_var_overall = mean_or_zero(feats.amplitude.variance.iter().copied());
    let phase_var_overall = mean_or_zero(feats.phase.variance.iter().copied());

    let raw = [
        feats.amplitude.peak,
        feats.amplitude.rms,
        feats.amplitude.dynamic_range,
        amp_mean_overall,
        amp_var_overall,
        feats.phase.coherence,
        phase_var_overall,
        feats.psd.total_power,
        feats.psd.peak_power,
        feats.psd.peak_frequency,
        feats.psd.centroid,
        feats.psd.bandwidth,
    ];
    debug_assert_eq!(raw.len(), FEATURE_LEN);

    Array1::from_iter(raw.iter().map(|&v| sanitise(v)))
}
/// Arithmetic mean of the values yielded by `it`.
///
/// Values are accumulated left to right starting from `0.0`. An iterator that
/// yields nothing produces `0.0` rather than a division by zero.
fn mean_or_zero<I: Iterator<Item = f64>>(it: I) -> f64 {
    let mut total = 0.0_f64;
    let mut count = 0_usize;
    for value in it {
        total += value;
        count += 1;
    }
    match count {
        0 => 0.0,
        _ => total / count as f64,
    }
}
/// Narrows `v` to `f32`, substituting `0.0` whenever the result would not be
/// finite.
///
/// This covers NaN and ±infinity inputs as well as finite `f64` values whose
/// magnitude exceeds the `f32` range: an `as` cast turns those into
/// ±infinity, so finiteness is checked *after* narrowing rather than before.
fn sanitise(v: f64) -> f32 {
    // Deliberate narrowing cast; out-of-range values are caught just below.
    let narrowed = v as f32;
    if narrowed.is_finite() {
        narrowed
    } else {
        0.0
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array4;

    /// An input with zero-length axes takes the early-return path and yields
    /// an all-zero vector of the expected length.
    #[test]
    fn zero_sized_input_yields_zero_vector() {
        let empty = Array4::<f32>::zeros((0, 0, 0, 0));
        let f = extract_signal_features(&empty, &empty);
        assert_eq!(f.len(), FEATURE_LEN);
        assert!(f.iter().all(|&v| v == 0.0));
    }

    /// Constant amplitude and phase inputs must still produce a vector of the
    /// expected length whose entries are all finite.
    #[test]
    fn constant_input_is_finite_and_correct_length() {
        let csi_amp = Array4::<f32>::from_elem((4, 3, 3, 56), 1.5);
        let csi_phase = Array4::<f32>::from_elem((4, 3, 3, 56), 0.25);
        // The original call was missing its first argument and did not parse.
        let f = extract_signal_features(&csi_amp, &csi_phase);
        assert_eq!(f.len(), FEATURE_LEN);
        assert!(f.iter().all(|v| v.is_finite()), "features must be finite: {f:?}");
    }
}