use core::fmt::Write as _;
/// One input-feature configuration evaluated in the ablation matrix.
///
/// Each variant has a fixed textual identifier, returned by
/// [`FeatureSet::label`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FeatureSet {
    /// Labelled `csi_only`.
    CsiOnly,
    /// Labelled `cir_only`.
    CirOnly,
    /// Labelled `csi+cir`.
    CsiCir,
    /// Labelled `csi+cir+doppler`.
    CsiCirDoppler,
    /// Labelled `csi+cir+doppler+bfld`.
    CsiCirDopplerBfld,
    /// Labelled `full+uwb`.
    FullUwb,
}

impl FeatureSet {
    /// Every variant, listed in declaration order.
    pub const MATRIX: [Self; 6] = [
        Self::CsiOnly,
        Self::CirOnly,
        Self::CsiCir,
        Self::CsiCirDoppler,
        Self::CsiCirDopplerBfld,
        Self::FullUwb,
    ];

    /// Short, stable identifier for this variant (e.g. `"csi+cir"`).
    #[must_use]
    pub fn label(self) -> &'static str {
        match self {
            Self::FullUwb => "full+uwb",
            Self::CsiCirDopplerBfld => "csi+cir+doppler+bfld",
            Self::CsiCirDoppler => "csi+cir+doppler",
            Self::CsiCir => "csi+cir",
            Self::CirOnly => "cir_only",
            Self::CsiOnly => "csi_only",
        }
    }
}
/// Nearest-rank 50th and 95th percentiles of a set of latency samples.
///
/// Non-finite samples (NaN, `+inf`, `-inf`) are discarded first. For the
/// remaining `n` samples sorted ascending, the q-th percentile is the element
/// at 1-based rank `ceil(q * n)`, clamped to `1..=n`.
///
/// Returns `(p50, p95)`, or `(0.0, 0.0)` when no finite samples remain.
#[must_use]
pub fn latency_percentiles_ms(samples_ms: &[f64]) -> (f64, f64) {
    /// Picks the nearest-rank `q` percentile from an ascending, non-empty slice.
    fn nearest_rank(sorted: &[f64], q: f64) -> f64 {
        let n = sorted.len();
        let rank = (q * n as f64).ceil() as usize;
        sorted[rank.clamp(1, n) - 1]
    }

    let mut finite = Vec::with_capacity(samples_ms.len());
    finite.extend(samples_ms.iter().copied().filter(|x| x.is_finite()));
    if finite.is_empty() {
        return (0.0, 0.0);
    }
    // `total_cmp` gives a total order; only finite values remain at this point anyway.
    finite.sort_by(|a, b| a.total_cmp(b));
    (nearest_rank(&finite, 0.50), nearest_rank(&finite, 0.95))
}
/// False-positive and false-negative rates from confusion-matrix counts.
///
/// * FP rate = `fp / (fp + tn)`, or `0.0` when `fp + tn == 0`.
/// * FN rate = `fn_ / (fn_ + tp)`, or `0.0` when `fn_ + tp == 0`.
///
/// Returns `(fp_rate, fn_rate)`.
///
/// Denominators are summed in `u128`, so extreme counts cannot overflow.
/// A `u64` overflow would panic in debug builds and silently wrap to a wrong
/// rate in release builds.
#[must_use]
pub fn confusion_rates(tp: u64, fp: u64, tn: u64, fn_: u64) -> (f64, f64) {
    /// `num / (num + other)`, defined as `0.0` for an empty denominator.
    fn rate(num: u64, other: u64) -> f64 {
        let denom = u128::from(num) + u128::from(other);
        if denom == 0 {
            0.0
        } else {
            num as f64 / denom as f64
        }
    }
    (rate(fp, tn), rate(fn_, tp))
}
/// Membership-inference leakage derived from the attack's pairwise AUC.
///
/// Every (member, non-member) score pair is a contest:
/// * the member scores 1 if its score is higher;
/// * it scores 0.5 on a tie (equal scores, or a difference below `f64::EPSILON`);
/// * it scores 0 otherwise.
///
/// The mean score is the AUC, which is mapped to `|AUC - 0.5| * 2` and
/// clamped to `[0, 1]`. So `0.0` means the two score sets are indistinguishable
/// and `1.0` means they are perfectly separable, in either direction.
///
/// NaN scores carry no ordering information and are dropped before pairing.
/// Returns `0.0` if either side is empty after that.
///
/// Runs in O(members * non-members) time.
#[must_use]
pub fn membership_inference_leakage(member_scores: &[f64], nonmember_scores: &[f64]) -> f64 {
    // Previously a NaN counted as a member loss, which biased the AUC
    // (all-NaN input reported maximal leakage).
    let members: Vec<f64> = member_scores
        .iter()
        .copied()
        .filter(|s| !s.is_nan())
        .collect();
    let nonmembers: Vec<f64> = nonmember_scores
        .iter()
        .copied()
        .filter(|s| !s.is_nan())
        .collect();
    if members.is_empty() || nonmembers.is_empty() {
        return 0.0;
    }
    let mut wins = 0.0;
    for &m in &members {
        for &n in &nonmembers {
            if m > n {
                wins += 1.0;
            } else if m == n || (m - n).abs() < f64::EPSILON {
                // `m == n` is required for equal infinities, where `m - n` is NaN
                // and the epsilon test alone would score the pair as a loss.
                wins += 0.5;
            }
        }
    }
    // Multiply in f64 so very large inputs cannot overflow `usize`.
    let total = members.len() as f64 * nonmembers.len() as f64;
    let auc = wins / total;
    ((auc - 0.5).abs() * 2.0).clamp(0.0, 1.0)
}
/// Summary metrics for one ablation variant (typically built with
/// [`AblationMetrics::from_run`]).
#[derive(Debug, Clone)]
pub struct AblationMetrics {
    /// Feature configuration these metrics describe.
    pub feature_set: FeatureSet,
    /// Fraction of correct presence decisions, `(tp + tn) / (tp + fp + tn + fn)`.
    pub presence_accuracy: f64,
    /// Localization error copied from the run (presumably metres, per the `_m` suffix).
    pub localization_err_m: f64,
    /// False-positive rate, `fp / (fp + tn)`.
    pub fp_rate: f64,
    /// False-negative rate, `fn / (fn + tp)`.
    pub fn_rate: f64,
    /// Nearest-rank median of the finite latency samples.
    pub latency_p50_ms: f64,
    /// Nearest-rank 95th percentile of the finite latency samples.
    pub latency_p95_ms: f64,
    /// Membership-inference leakage in `[0, 1]`; `0.0` means indistinguishable scores.
    pub privacy_leakage: f64,
    /// `room_a_accuracy - room_b_accuracy`, floored at `0.0`.
    pub cross_room_degradation: f64,
}
/// Raw measurements gathered for one ablation variant; the input to
/// [`AblationMetrics::from_run`].
#[derive(Debug, Clone)]
pub struct VariantRun {
    /// Feature configuration that was evaluated.
    pub feature_set: FeatureSet,
    /// Presence-detection confusion counts, ordered `(tp, fp, tn, fn)`.
    pub confusion: (u64, u64, u64, u64),
    /// Localization error, passed through unchanged (presumably metres — confirm).
    pub localization_err_m: f64,
    /// Latency samples (milliseconds per the `_ms` suffix); non-finite entries are ignored.
    pub latency_samples_ms: Vec<f64>,
    /// Attack scores for member samples, compared pairwise against `nonmember_scores`.
    pub member_scores: Vec<f64>,
    /// Attack scores for non-member samples.
    pub nonmember_scores: Vec<f64>,
    /// Accuracy in room A; the reference side of the cross-room degradation.
    pub room_a_accuracy: f64,
    /// Accuracy in room B; subtracted from `room_a_accuracy`.
    pub room_b_accuracy: f64,
}
impl AblationMetrics {
    /// Derives the summary metrics for one variant from its raw [`VariantRun`].
    ///
    /// * `presence_accuracy` = `(tp + tn) / (tp + fp + tn + fn)`, or `0.0` when
    ///   every count is zero.
    /// * `fp_rate` / `fn_rate` come from [`confusion_rates`].
    /// * `latency_p50_ms` / `latency_p95_ms` come from [`latency_percentiles_ms`].
    /// * `privacy_leakage` comes from [`membership_inference_leakage`].
    /// * `cross_room_degradation` = `room_a_accuracy - room_b_accuracy`, floored
    ///   at `0.0`. `f64::max` also maps a NaN difference to `0.0`.
    /// * `feature_set` and `localization_err_m` are copied through unchanged.
    #[must_use]
    pub fn from_run(run: &VariantRun) -> Self {
        let (tp, fp, tn, fn_) = run.confusion;
        let (fp_rate, fn_rate) = confusion_rates(tp, fp, tn, fn_);
        // Accumulate in u128: adding four u64 counts can overflow u64, which
        // panics in debug builds and silently wraps (bogus accuracy) in release.
        let correct = u128::from(tp) + u128::from(tn);
        let total = correct + u128::from(fp) + u128::from(fn_);
        let presence_accuracy = if total == 0 {
            0.0
        } else {
            correct as f64 / total as f64
        };
        let (latency_p50_ms, latency_p95_ms) = latency_percentiles_ms(&run.latency_samples_ms);
        Self {
            feature_set: run.feature_set,
            presence_accuracy,
            localization_err_m: run.localization_err_m,
            fp_rate,
            fn_rate,
            latency_p50_ms,
            latency_p95_ms,
            privacy_leakage: membership_inference_leakage(
                &run.member_scores,
                &run.nonmember_scores,
            ),
            cross_room_degradation: (run.room_a_accuracy - run.room_b_accuracy).max(0.0),
        }
    }
}
/// A table of per-variant ablation metrics.
///
/// When built with `AblationReport::from_runs`, rows appear in the same order
/// as the input runs.
#[derive(Debug, Clone, Default)]
pub struct AblationReport {
    /// One metrics row per evaluated variant.
    pub rows: Vec<AblationMetrics>,
}
impl AblationReport {
    /// Builds a report holding one [`AblationMetrics`] row per run, in input order.
    #[must_use]
    pub fn from_runs(runs: &[VariantRun]) -> Self {
        let rows: Vec<AblationMetrics> = runs.iter().map(AblationMetrics::from_run).collect();
        Self { rows }
    }

    /// Returns the first row whose `feature_set` equals `fs`, or `None` if absent.
    #[must_use]
    pub fn get(&self, fs: FeatureSet) -> Option<&AblationMetrics> {
        self.rows.iter().find(|row| row.feature_set == fs)
    }

    /// Whether the `CsiCir` row beats the `CsiOnly` row on at least `min_wins`
    /// of three criteria:
    ///
    /// 1. strictly higher `presence_accuracy`;
    /// 2. strictly lower `localization_err_m`;
    /// 3. `latency_p95_ms` no higher (a tie counts as a win).
    ///
    /// Returns `false` if either row is missing. A comparison involving NaN is
    /// never a win.
    #[must_use]
    pub fn csi_cir_beats_csi_only(&self, min_wins: usize) -> bool {
        let Some(baseline) = self.get(FeatureSet::CsiOnly) else {
            return false;
        };
        let Some(fused) = self.get(FeatureSet::CsiCir) else {
            return false;
        };
        let criteria = [
            fused.presence_accuracy > baseline.presence_accuracy,
            fused.localization_err_m < baseline.localization_err_m,
            fused.latency_p95_ms <= baseline.latency_p95_ms,
        ];
        let wins: usize = criteria.iter().map(|&won| usize::from(won)).sum();
        min_wins <= wins
    }

    /// Renders the report as a Markdown pipe table.
    ///
    /// The output is a header line, a separator line, then one line per row in
    /// order; every line ends with `\n`. Latencies are printed with 2 decimals
    /// and every other metric with 3.
    #[must_use]
    pub fn to_markdown(&self) -> String {
        let mut table = String::new();
        table.push_str(
            "| variant | presence_acc | loc_err_m | fp | fn | p50_ms | p95_ms | privacy_leak | xroom_degr |",
        );
        table.push('\n');
        table.push_str("|---|---|---|---|---|---|---|---|---|");
        table.push('\n');
        for row in &self.rows {
            // Formatting into a `String` cannot fail, so the `fmt::Result` is discarded.
            let _ = writeln!(
                table,
                "| {} | {:.3} | {:.3} | {:.3} | {:.3} | {:.2} | {:.2} | {:.3} | {:.3} |",
                row.feature_set.label(),
                row.presence_accuracy,
                row.localization_err_m,
                row.fp_rate,
                row.fn_rate,
                row.latency_p50_ms,
                row.latency_p95_ms,
                row.privacy_leakage,
                row.cross_room_degradation,
            );
        }
        table
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance float comparison used throughout these tests.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-9
    }

    /// Builds a run with neutral privacy scores and ten identical latency samples.
    fn neutral_run(
        feature_set: FeatureSet,
        confusion: (u64, u64, u64, u64),
        localization_err_m: f64,
        latency_ms: f64,
        rooms: (f64, f64),
    ) -> VariantRun {
        VariantRun {
            feature_set,
            confusion,
            localization_err_m,
            latency_samples_ms: vec![latency_ms; 10],
            member_scores: vec![0.5],
            nonmember_scores: vec![0.5],
            room_a_accuracy: rooms.0,
            room_b_accuracy: rooms.1,
        }
    }

    #[test]
    fn latency_percentiles_nearest_rank() {
        let ramp: Vec<f64> = (1..=100u32).map(f64::from).collect();
        let (median, tail) = latency_percentiles_ms(&ramp);
        assert!(close(median, 50.0));
        assert!(close(tail, 95.0));
        assert_eq!(latency_percentiles_ms(&[]), (0.0, 0.0));
    }

    #[test]
    fn latency_percentiles_with_nan_does_not_panic() {
        let noisy = [
            10.0,
            f64::NAN,
            20.0,
            30.0,
            f64::INFINITY,
            40.0,
            f64::NEG_INFINITY,
            50.0,
        ];
        let (median, tail) = latency_percentiles_ms(&noisy);
        assert!(median.is_finite() && tail.is_finite());
        assert!(close(median, 30.0));
        assert!(close(tail, 50.0));
        assert_eq!(latency_percentiles_ms(&[f64::NAN, f64::NAN]), (0.0, 0.0));
    }

    #[test]
    fn confusion_rates_basic() {
        let (fp_rate, fn_rate) = confusion_rates(80, 10, 90, 20);
        assert!(close(fp_rate, 0.1));
        assert!(close(fn_rate, 0.2));
    }

    #[test]
    fn mia_leakage_zero_when_indistinguishable_high_when_separable() {
        let shared = [0.5, 0.6, 0.7];
        assert!(membership_inference_leakage(&shared, &shared) < 1e-9);
        let members = [0.9, 0.95, 0.99];
        let nonmembers = [0.1, 0.2, 0.3];
        assert!(close(membership_inference_leakage(&members, &nonmembers), 1.0));
    }

    #[test]
    fn csi_cir_beats_csi_only_acceptance() {
        let runs = [
            neutral_run(FeatureSet::CsiOnly, (70, 15, 70, 30), 0.40, 10.0, (0.8, 0.6)),
            neutral_run(FeatureSet::CsiCir, (88, 6, 90, 12), 0.22, 11.0, (0.85, 0.80)),
        ];
        let report = AblationReport::from_runs(&runs);
        assert!(report.csi_cir_beats_csi_only(2));
        let table = report.to_markdown();
        assert!(table.contains("csi_only") && table.contains("csi+cir"));
        assert_eq!(table, AblationReport::from_runs(&runs).to_markdown());
    }

    #[test]
    fn matrix_has_six_variants() {
        assert_eq!(FeatureSet::MATRIX.len(), 6);
    }
}