use anyhow::{bail, Result};
use clap::Args;
use ndarray::Array2;
use num_complex::Complex64;
use std::time::{Duration, Instant};
use tokio::net::UdpSocket;
use wifi_densepose_core::types::{
AntennaConfig, CsiFrame, CsiMetadata, DeviceId, FrequencyBand, Timestamp,
};
use wifi_densepose_signal::{
BaselineCalibration, CalibrationConfig, CalibrationDeviationScore, CalibrationRecorder,
};
#[derive(Args, Debug, Clone)]
pub struct CalibrateArgs {
#[arg(long, default_value_t = 5005)]
pub udp_port: u16,
#[arg(long, default_value = "0.0.0.0")]
pub bind: String,
#[arg(long, default_value_t = 30)]
pub duration_s: u32,
#[arg(long, default_value = "./baseline.bin")]
pub output: String,
#[arg(long, default_value = "ht20")]
pub tier: String,
#[arg(long, default_value_t = 20)]
pub banner_every: u32,
#[arg(long, default_value_t = 2.0)]
pub abort_z_threshold: f32,
#[arg(long, default_value_t = 0)]
pub min_frames: u32,
}
const RECV_BUF: usize = 2048;
const ABORT_WINDOW_INTERVALS: u32 = 20;
pub async fn execute(args: CalibrateArgs) -> Result<()> {
validate_args(&args)?;
let mut config = tier_config(&args.tier);
if args.min_frames > 0 {
config.min_frames = args.min_frames;
eprintln!(
"[calibrate] WARN: --min-frames={} overrides ADR-135 tier default ({} for {}). \
This relaxes the phase-concentration guarantee; do not use in production.",
args.min_frames, tier_config(&args.tier).min_frames, args.tier
);
}
let target_frames = config.min_frames as usize;
let addr = format!("{}:{}", args.bind, args.udp_port);
let socket = UdpSocket::bind(&addr).await
.map_err(|e| anyhow::anyhow!("cannot bind UDP socket on {addr}: {e}"))?;
eprintln!("[calibrate] listening on udp://{addr}");
eprintln!(
"[calibrate] capturing {} frames (~{} s, tier={}) — ensure room is empty",
target_frames, args.duration_s, args.tier
);
let mut recorder = CalibrationRecorder::new(config);
let mut buf = vec![0u8; RECV_BUF];
let mut high_z_count: u32 = 0;
let deadline = Instant::now() + Duration::from_secs(args.duration_s as u64);
loop {
let remaining = deadline.saturating_duration_since(Instant::now());
if remaining.is_zero() {
break;
}
let timeout = remaining.min(Duration::from_millis(500));
let recv = tokio::time::timeout(timeout, socket.recv(&mut buf)).await;
let n = match recv {
Ok(Ok(n)) => n,
Ok(Err(e)) => { eprintln!("[calibrate] recv error: {e}"); continue; }
Err(_) => continue, };
let Some(csi_frame) = parse_csi_packet(&buf[..n], &args.tier) else {
continue;
};
let score: CalibrationDeviationScore = match recorder.record(&csi_frame) {
Ok(s) => s,
Err(e) => { eprintln!("[calibrate] WARN frame skipped: {e}"); continue; }
};
let frames = recorder.frames_recorded() as usize;
if args.banner_every > 0 && (frames as u32) % args.banner_every == 0 {
print_banner(frames, target_frames, &score);
if args.abort_z_threshold > 0.0 && score.amplitude_z_median > args.abort_z_threshold {
high_z_count += 1;
if high_z_count >= ABORT_WINDOW_INTERVALS {
bail!(
"aborted: amplitude_z_median={:.2} exceeded threshold={:.2} for {} \
consecutive banner intervals — ensure the room is empty and retry",
score.amplitude_z_median, args.abort_z_threshold, high_z_count
);
}
} else {
high_z_count = 0;
}
}
if frames >= target_frames {
break;
}
}
finalise_and_save(recorder, &args.output)
}
fn print_banner(frames: usize, target: usize, score: &CalibrationDeviationScore) {
let motion_str = if score.motion_flagged {
"YES \u{2190} operator should be still"
} else {
"no"
};
eprintln!(
"[calibrate] {}/{} frames | z_med={:.2} z_max={:.2} | motion: {}",
frames, target, score.amplitude_z_median, score.amplitude_z_max, motion_str
);
}
fn finalise_and_save(recorder: CalibrationRecorder, output: &str) -> Result<()> {
let frames = recorder.frames_recorded();
eprintln!("[calibrate] finalising baseline from {frames} frames…");
let baseline: BaselineCalibration = recorder
.finalize()
.map_err(|e| anyhow::anyhow!("calibration failed: {e}"))?;
let bytes = baseline.to_bytes();
std::fs::write(output, &bytes)
.map_err(|e| anyhow::anyhow!("cannot write {output}: {e}"))?;
eprintln!(
"[calibrate] baseline saved to {output} ({} bytes)",
bytes.len()
);
eprintln!(
"[calibrate] summary: frames={} tier={:?} subcarriers={}",
baseline.frame_count,
baseline.tier,
baseline.subcarriers.len(),
);
Ok(())
}
pub(crate) fn tier_config(tier: &str) -> CalibrationConfig {
match tier.to_ascii_lowercase().as_str() {
"ht40" => CalibrationConfig::ht40(),
"he20" => CalibrationConfig::he20(),
"he40" => CalibrationConfig::he40(),
_ => CalibrationConfig::ht20(), }
}
pub(crate) fn parse_csi_packet(buf: &[u8], tier: &str) -> Option<CsiFrame> {
if buf.len() < 20 {
return None;
}
let magic = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]);
if magic != 0xC511_0001 {
return None;
}
let node_id = buf[4];
let n_antennas = buf[5] as usize;
let n_subcarriers = u16::from_le_bytes([buf[6], buf[7]]) as usize;
let freq_mhz = u32::from_le_bytes([buf[8], buf[9], buf[10], buf[11]]);
let freq_mhz = u16::try_from(freq_mhz).unwrap_or(0);
let _sequence = u32::from_le_bytes([buf[12], buf[13], buf[14], buf[15]]);
let rssi = buf[16] as i8;
let noise_floor = buf[17] as i8;
let _ppdu_type = buf[18];
let n_pairs = n_antennas * n_subcarriers;
let iq_start = 20usize;
if buf.len() < iq_start + n_pairs * 2 {
return None;
}
let mut data = Array2::<Complex64>::zeros((n_antennas.max(1), n_subcarriers.max(1)));
for s in 0..n_antennas {
for k in 0..n_subcarriers {
let idx = s * n_subcarriers + k;
let i_val = buf[iq_start + idx * 2] as i8 as f64;
let q_val = buf[iq_start + idx * 2 + 1] as i8 as f64;
data[[s, k]] = Complex64::new(i_val, q_val);
}
}
let band = if freq_mhz >= 5000 {
FrequencyBand::Band5GHz
} else {
FrequencyBand::Band2_4GHz
};
let bw = tier_to_bw_mhz(tier);
let mut meta = CsiMetadata::new(
DeviceId::new(format!("esp32-node{}", node_id)),
band,
freq_mhz_to_channel(freq_mhz),
);
meta.bandwidth_mhz = bw;
meta.rssi_dbm = rssi;
meta.noise_floor_dbm = noise_floor;
meta.antenna_config = AntennaConfig {
tx_antennas: 1,
rx_antennas: n_antennas as u8,
spacing_mm: None,
};
meta.timestamp = Timestamp::now();
Some(CsiFrame::new(meta, data))
}
fn tier_to_bw_mhz(tier: &str) -> u16 {
match tier.to_ascii_lowercase().as_str() {
"ht40" | "he40" => 40,
_ => 20,
}
}
fn freq_mhz_to_channel(freq_mhz: u16) -> u8 {
if freq_mhz < 3000 {
((freq_mhz.saturating_sub(2407)) / 5) as u8
} else {
((freq_mhz.saturating_sub(5000)) / 5) as u8
}
}
fn validate_args(args: &CalibrateArgs) -> Result<()> {
if args.duration_s < 10 {
bail!(
"--duration-s must be at least 10 s (got {}). \
Fewer frames produce unreliable phase-concentration estimates (ADR-135 §2.3).",
args.duration_s
);
}
if args.duration_s > 300 {
eprintln!(
"[calibrate] WARN: --duration-s={} exceeds 300 s; this is unusual.",
args.duration_s
);
}
let valid = ["ht20", "ht40", "he20", "he40"];
if !valid.contains(&args.tier.to_ascii_lowercase().as_str()) {
bail!(
"--tier must be one of {:?} (got {:?})",
valid, args.tier
);
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_validate_args_min_duration() {
let mut args = default_args();
args.duration_s = 5;
assert!(validate_args(&args).is_err());
}
#[test]
fn test_validate_args_ok() {
let args = default_args();
assert!(validate_args(&args).is_ok());
}
#[test]
fn test_validate_args_bad_tier() {
let mut args = default_args();
args.tier = "ht80".into();
assert!(validate_args(&args).is_err());
}
#[test]
fn test_tier_config_ht20() {
let cfg = tier_config("ht20");
assert_eq!(cfg.num_active, 52);
}
#[test]
fn test_tier_config_ht40() {
let cfg = tier_config("ht40");
assert_eq!(cfg.num_active, 114);
}
#[test]
fn test_tier_config_he20() {
let cfg = tier_config("he20");
assert_eq!(cfg.num_active, 256);
}
#[test]
fn test_parse_csi_packet_bad_magic() {
let buf = vec![0u8; 32];
assert!(parse_csi_packet(&buf, "ht20").is_none());
}
#[test]
fn test_parse_csi_packet_too_short() {
let buf = vec![0u8; 10];
assert!(parse_csi_packet(&buf, "ht20").is_none());
}
fn build_frame(n_subcarriers: u16, ppdu: u8) -> Vec<u8> {
let mut buf = vec![0u8; 20 + n_subcarriers as usize * 2];
buf[0..4].copy_from_slice(&0xC511_0001u32.to_le_bytes());
buf[4] = 12; buf[5] = 1; buf[6..8].copy_from_slice(&n_subcarriers.to_le_bytes());
buf[8..12].copy_from_slice(&2432u32.to_le_bytes()); buf[12..16].copy_from_slice(&11610u32.to_le_bytes()); buf[16] = (-40i8) as u8; buf[17] = (-87i8) as u8; buf[18] = ppdu;
buf[19] = 0x10; for k in 0..n_subcarriers as usize {
buf[20 + k * 2] = (10 + (k % 100) as i8) as u8;
buf[20 + k * 2 + 1] = (k % 50) as u8;
}
buf
}
#[test]
fn test_parse_csi_packet_valid() {
let buf = build_frame(2, 0);
let frame = parse_csi_packet(&buf, "ht20");
assert!(frame.is_some());
let f = frame.unwrap();
assert_eq!(f.num_spatial_streams(), 1);
assert_eq!(f.num_subcarriers(), 2);
assert_eq!(f.metadata.rssi_dbm, -40);
assert_eq!(f.metadata.noise_floor_dbm, -87);
}
#[test]
fn test_parse_csi_packet_he_su_256_bins() {
let buf = build_frame(256, 1);
assert_eq!(buf.len(), 532); let f = parse_csi_packet(&buf, "he20").expect("256-bin HE frame must parse");
assert_eq!(f.num_subcarriers(), 256);
assert_eq!(f.metadata.rssi_dbm, -40);
let mut he = wifi_densepose_signal::CalibrationRecorder::new(tier_config("he20"));
assert!(he.record(&f).is_ok());
let mut ht = wifi_densepose_signal::CalibrationRecorder::new(tier_config("ht20"));
assert!(ht.record(&f).is_err());
}
#[test]
fn test_freq_to_channel_24ghz() {
assert_eq!(freq_mhz_to_channel(2437), 6);
}
#[test]
fn test_freq_to_channel_5ghz() {
assert_eq!(freq_mhz_to_channel(5180), 36);
}
fn default_args() -> CalibrateArgs {
CalibrateArgs {
udp_port: 5005,
bind: "0.0.0.0".into(),
duration_s: 30,
output: "./baseline.bin".into(),
tier: "ht20".into(),
banner_every: 20,
abort_z_threshold: 2.0,
min_frames: 0,
}
}
}