use ndarray::{Array1, Array2};
/// Overall acceptance tier, as returned by `determine_tier`.
///
/// Variants are declared from worst to best, so the derived `Ord` gives
/// `Fail < Bronze < Silver < Gold`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum RuViewTier {
/// Tracking did not pass.
Fail,
/// Tracking passed, joint error did not.
Bronze,
/// Tracking and joint error passed, vital signs did not.
Silver,
/// Tracking, joint error and vital signs all passed.
Gold,
}
/// Renders the tier as its upper-case label: `FAIL`, `BRONZE`, `SILVER` or `GOLD`.
impl std::fmt::Display for RuViewTier {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the label first, then emit it with a single write.
        let label = match *self {
            RuViewTier::Gold => "GOLD",
            RuViewTier::Silver => "SILVER",
            RuViewTier::Bronze => "BRONZE",
            RuViewTier::Fail => "FAIL",
        };
        f.write_str(label)
    }
}
/// Pass/fail limits applied by `evaluate_joint_error`.
#[derive(Debug, Clone)]
pub struct JointErrorThresholds {
/// Minimum PCK over all visible keypoints (inclusive; compared with `>=`).
pub pck_all: f32,
/// Minimum PCK over the torso keypoints (inclusive; compared with `>=`).
pub pck_torso: f32,
/// Minimum mean OKS across frames (inclusive; compared with `>=`).
pub oks: f32,
/// Exclusive upper bound on torso-centroid jitter RMS (compared with `<`).
/// Presumably metres given the `_m` suffix — confirm against the caller's units.
pub jitter_rms_m: f32,
/// Exclusive upper bound on the 95th-percentile keypoint error (compared with `<`).
/// Presumably metres given the `_m` suffix — confirm against the caller's units.
pub max_error_p95_m: f32,
}
impl Default for JointErrorThresholds {
fn default() -> Self {
JointErrorThresholds {
pck_all: 0.70,
pck_torso: 0.80,
oks: 0.50,
jitter_rms_m: 0.03,
max_error_p95_m: 0.15,
}
}
}
/// Metrics and verdict produced by `evaluate_joint_error`.
#[derive(Debug, Clone)]
pub struct JointErrorResult {
/// Fraction of visible keypoints whose error is within the PCK distance threshold.
pub pck_all: f32,
/// Same fraction restricted to the keypoints listed in `TORSO_INDICES`.
pub pck_torso: f32,
/// Mean per-frame OKS.
pub oks: f32,
/// RMS of the frame-to-frame displacement of the predicted torso centroid.
pub jitter_rms_m: f32,
/// 95th-percentile keypoint error over all visible keypoints in all frames.
pub max_error_p95_m: f32,
/// `true` when every metric satisfies its threshold.
pub passes: bool,
}
/// Per-keypoint falloff constants used by `compute_single_oks`, indexed by keypoint row.
// NOTE(review): the values look like the standard COCO keypoint sigmas
// (nose, eyes, ears, shoulders, elbows, wrists, hips, knees, ankles) —
// confirm that the keypoint ordering of the inputs matches.
const COCO_SIGMAS: [f32; 17] = [
0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072,
0.062, 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089,
];
/// Keypoint rows treated as the torso for torso PCK and torso-centroid jitter.
// NOTE(review): under COCO-17 ordering these would be the two shoulders and
// the two hips — confirm the ordering used by the caller.
const TORSO_INDICES: [usize; 4] = [5, 6, 11, 12];
pub fn evaluate_joint_error(
pred_kpts: &[Array2<f32>],
gt_kpts: &[Array2<f32>],
visibility: &[Array1<f32>],
scale: &[f32],
thresholds: &JointErrorThresholds,
) -> JointErrorResult {
let n = pred_kpts.len();
if n == 0 {
return JointErrorResult {
pck_all: 0.0,
pck_torso: 0.0,
oks: 0.0,
jitter_rms_m: f32::MAX,
max_error_p95_m: f32::MAX,
passes: false,
};
}
let pck_threshold = 0.2;
let mut all_correct = 0_usize;
let mut all_total = 0_usize;
let mut torso_correct = 0_usize;
let mut torso_total = 0_usize;
let mut oks_sum = 0.0_f64;
let mut per_kp_errors: Vec<Vec<f32>> = vec![Vec::new(); 17];
for i in 0..n {
let bbox_diag = compute_bbox_diag(>_kpts[i], &visibility[i]);
let safe_diag = bbox_diag.max(1e-3);
let dist_thr = pck_threshold * safe_diag;
for j in 0..17 {
if visibility[i][j] < 0.5 {
continue;
}
let dx = pred_kpts[i][[j, 0]] - gt_kpts[i][[j, 0]];
let dy = pred_kpts[i][[j, 1]] - gt_kpts[i][[j, 1]];
let dist = (dx * dx + dy * dy).sqrt();
per_kp_errors[j].push(dist);
all_total += 1;
if dist <= dist_thr {
all_correct += 1;
}
if TORSO_INDICES.contains(&j) {
torso_total += 1;
if dist <= dist_thr {
torso_correct += 1;
}
}
}
let s = scale.get(i).copied().unwrap_or(1.0);
let oks_frame = compute_single_oks(&pred_kpts[i], >_kpts[i], &visibility[i], s);
oks_sum += oks_frame as f64;
}
let pck_all = if all_total > 0 { all_correct as f32 / all_total as f32 } else { 0.0 };
let pck_torso = if torso_total > 0 { torso_correct as f32 / torso_total as f32 } else { 0.0 };
let oks = (oks_sum / n as f64) as f32;
let jitter_rms_m = compute_torso_jitter(pred_kpts, visibility);
let max_error_p95_m = compute_p95_max_error(&per_kp_errors);
let passes = pck_all >= thresholds.pck_all
&& pck_torso >= thresholds.pck_torso
&& oks >= thresholds.oks
&& jitter_rms_m < thresholds.jitter_rms_m
&& max_error_p95_m < thresholds.max_error_p95_m;
JointErrorResult {
pck_all,
pck_torso,
oks,
jitter_rms_m,
max_error_p95_m,
passes,
}
}
/// Pass/fail limits applied by `evaluate_tracking`.
#[derive(Debug, Clone)]
pub struct TrackingThresholds {
/// Maximum number of identity switches allowed (inclusive; compared with `<=`).
pub max_id_switches: usize,
/// Exclusive upper bound on the fragmentation ratio (compared with `<`).
pub max_frag_ratio: f32,
/// Maximum false tracks per minute allowed (inclusive; compared with `<=`).
pub max_false_tracks_per_min: f32,
}
impl Default for TrackingThresholds {
fn default() -> Self {
TrackingThresholds {
max_id_switches: 0,
max_frag_ratio: 0.05,
max_false_tracks_per_min: 0.0,
}
}
}
/// Per-frame tracking input consumed by `evaluate_tracking`.
#[derive(Debug, Clone)]
pub struct TrackingFrame {
/// Index of the frame; carried along but not read by `evaluate_tracking`.
pub frame_idx: usize,
/// Ground-truth identities present in this frame.
pub gt_ids: Vec<u32>,
/// Predicted track identities present in this frame.
pub pred_ids: Vec<u32>,
/// Matches for this frame as `(pred_id, gt_id)` pairs.
pub assignments: Vec<(u32, u32)>,
}
/// Metrics and verdict produced by `evaluate_tracking`.
#[derive(Debug, Clone)]
pub struct TrackingResult {
/// Number of times a ground-truth identity was matched to a different
/// prediction than in the immediately preceding frame.
pub id_switches: usize,
/// Fraction of ground-truth tracks (seen in at least two frames) that were
/// matched, lost, and matched again.
pub fragmentation_ratio: f32,
/// Unmatched predictions summed over all frames, divided by the duration in minutes.
pub false_tracks_per_min: f32,
/// `1 - (misses + false positives + id switches) / total ground-truth count`;
/// `0.0` when there is no ground truth.
pub mota: f32,
/// Number of frames evaluated.
pub n_frames: usize,
/// `true` when all tracking thresholds are satisfied.
pub passes: bool,
}
/// Computes identity-switch, fragmentation, false-track and MOTA metrics over
/// a sequence of frames and applies the tracking gate.
///
/// * An **identity switch** is counted when a ground-truth id is matched to a
///   different prediction than in the immediately preceding frame. A
///   ground-truth id that was unmatched in the preceding frame never counts.
/// * **Misses** / **false positives** per frame are the ground-truth /
///   predicted ids in excess of the number of assignments.
/// * A ground-truth track is **fragmented** when it is matched, then
///   unmatched, then matched again. Only tracks present in at least two frames
///   enter the ratio.
/// * `false_tracks_per_min` divides the false-positive total by
///   `duration_minutes`, floored at `1e-6` to avoid division by zero.
///
/// The result passes when id switches and false tracks per minute are at or
/// below their limits and the fragmentation ratio is strictly below its limit.
///
/// An empty `frames` slice yields a failing result with all metrics zero.
pub fn evaluate_tracking(
    frames: &[TrackingFrame],
    duration_minutes: f32,
    thresholds: &TrackingThresholds,
) -> TrackingResult {
    let n_frames = frames.len();
    if n_frames == 0 {
        return TrackingResult {
            id_switches: 0,
            fragmentation_ratio: 0.0,
            false_tracks_per_min: 0.0,
            mota: 0.0,
            n_frames: 0,
            passes: false,
        };
    }

    let mut id_switches = 0_usize;
    let mut total_gt = 0_usize;
    let mut total_misses = 0_usize;
    let mut total_false_positives = 0_usize;
    // Ground-truth id -> prediction id it was matched to in the previous frame.
    let mut prev_assignment: std::collections::HashMap<u32, u32> =
        std::collections::HashMap::new();
    // Ground-truth id -> matched/unmatched flag for every frame it appears in.
    let mut gt_track_presence: std::collections::HashMap<u32, Vec<bool>> =
        std::collections::HashMap::new();

    for frame in frames {
        let n_matched = frame.assignments.len();
        total_gt += frame.gt_ids.len();
        total_misses += frame.gt_ids.len().saturating_sub(n_matched);
        total_false_positives += frame.pred_ids.len().saturating_sub(n_matched);

        let mut current_assignment: std::collections::HashMap<u32, u32> =
            std::collections::HashMap::with_capacity(n_matched);
        for &(pred_id, truth_id) in &frame.assignments {
            current_assignment.insert(truth_id, pred_id);
            if let Some(&prev_pred) = prev_assignment.get(&truth_id) {
                if prev_pred != pred_id {
                    id_switches += 1;
                }
            }
        }

        // `current_assignment` holds exactly the ground-truth ids matched in
        // this frame, so a key lookup replaces a scan over `assignments`.
        for &truth_id in &frame.gt_ids {
            gt_track_presence
                .entry(truth_id)
                .or_default()
                .push(current_assignment.contains_key(&truth_id));
        }

        prev_assignment = current_assignment;
    }

    let mut n_fragmented = 0_usize;
    let mut n_tracks = 0_usize;
    for presence in gt_track_presence.values() {
        if presence.len() < 2 {
            continue;
        }
        n_tracks += 1;

        // Look for the pattern matched -> unmatched -> matched.
        let mut was_present = false;
        let mut lost = false;
        let mut has_gap = false;
        for &present in presence {
            if was_present && !present {
                lost = true;
            }
            if lost && present {
                has_gap = true;
                break;
            }
            was_present = present;
        }
        if has_gap {
            n_fragmented += 1;
        }
    }

    let fragmentation_ratio = if n_tracks > 0 {
        n_fragmented as f32 / n_tracks as f32
    } else {
        0.0
    };

    let safe_duration = duration_minutes.max(1e-6);
    let false_tracks_per_min = total_false_positives as f32 / safe_duration;

    let mota = if total_gt > 0 {
        1.0 - (total_misses + total_false_positives + id_switches) as f32 / total_gt as f32
    } else {
        0.0
    };

    let passes = id_switches <= thresholds.max_id_switches
        && fragmentation_ratio < thresholds.max_frag_ratio
        && false_tracks_per_min <= thresholds.max_false_tracks_per_min;

    TrackingResult {
        id_switches,
        fragmentation_ratio,
        false_tracks_per_min,
        mota,
        n_frames,
        passes,
    }
}
/// Pass/fail limits applied by `evaluate_vital_signs`.
#[derive(Debug, Clone)]
pub struct VitalSignThresholds {
/// Maximum mean absolute breathing-rate error in BPM (inclusive; compared with `<=`).
pub breathing_bpm_tolerance: f32,
/// Minimum mean breathing SNR (inclusive; compared with `>=`).
pub breathing_snr_db: f32,
/// Maximum mean absolute heartbeat-rate error in BPM (inclusive; compared with `<=`).
pub heartbeat_bpm_tolerance: f32,
/// Minimum mean heartbeat SNR (inclusive; compared with `>=`).
pub heartbeat_snr_db: f32,
/// Not read by `evaluate_vital_signs`.
// NOTE(review): presumably a micro-motion amplitude limit in metres — confirm intended use.
pub micro_motion_m: f32,
/// Not read by `evaluate_vital_signs`.
// NOTE(review): presumably the range in metres at which micro-motion is assessed — confirm.
pub micro_motion_range_m: f32,
}
impl Default for VitalSignThresholds {
fn default() -> Self {
VitalSignThresholds {
breathing_bpm_tolerance: 2.0,
breathing_snr_db: 6.0,
heartbeat_bpm_tolerance: 5.0,
heartbeat_snr_db: 3.0,
micro_motion_m: 0.001,
micro_motion_range_m: 3.0,
}
}
}
/// One vital-sign estimate paired with its reference value, consumed by
/// `evaluate_vital_signs`.
#[derive(Debug, Clone)]
pub struct VitalSignMeasurement {
/// Estimated breathing rate in BPM.
pub breathing_bpm: f32,
/// Reference breathing rate in BPM.
pub gt_breathing_bpm: f32,
/// SNR of the breathing estimate.
pub breathing_snr_db: f32,
/// Estimated heartbeat rate in BPM, if available.
pub heartbeat_bpm: Option<f32>,
/// Reference heartbeat rate in BPM, if available.
pub gt_heartbeat_bpm: Option<f32>,
/// SNR of the heartbeat estimate, if available.
pub heartbeat_snr_db: Option<f32>,
}
/// Metrics and verdict produced by `evaluate_vital_signs`.
#[derive(Debug, Clone)]
pub struct VitalSignResult {
/// Mean absolute breathing-rate error over all measurements.
pub breathing_error_bpm: f32,
/// Mean breathing SNR over all measurements.
pub breathing_snr_db: f32,
/// Mean absolute heartbeat-rate error over measurements that carry all three
/// heartbeat fields; `None` when there are none.
pub heartbeat_error_bpm: Option<f32>,
/// Mean heartbeat SNR over the same measurements; `None` when there are none.
pub heartbeat_snr_db: Option<f32>,
/// Number of measurements evaluated.
pub n_measurements: usize,
/// `true` when breathing (and heartbeat, if evaluated) satisfy their thresholds.
pub passes: bool,
}
/// Averages breathing and heartbeat accuracy over `measurements` and applies
/// the vital-sign gate.
///
/// Breathing is evaluated on every measurement. Heartbeat is evaluated only on
/// measurements where the estimate, the reference and the SNR are all present;
/// when no measurement qualifies, the heartbeat check is treated as passed and
/// both heartbeat fields of the result are `None`.
///
/// An empty slice yields a failing result whose breathing error is `f32::MAX`.
pub fn evaluate_vital_signs(
    measurements: &[VitalSignMeasurement],
    thresholds: &VitalSignThresholds,
) -> VitalSignResult {
    let n_measurements = measurements.len();
    if n_measurements == 0 {
        return VitalSignResult {
            breathing_error_bpm: f32::MAX,
            breathing_snr_db: 0.0,
            heartbeat_error_bpm: None,
            heartbeat_snr_db: None,
            n_measurements: 0,
            passes: false,
        };
    }
    let total = n_measurements as f32;

    // Breathing: mean absolute error and mean SNR over every measurement.
    let breathing_error_bpm = measurements
        .iter()
        .map(|m| (m.breathing_bpm - m.gt_breathing_bpm).abs())
        .sum::<f32>()
        / total;
    let breathing_snr_db = measurements
        .iter()
        .map(|m| m.breathing_snr_db)
        .sum::<f32>()
        / total;

    // Heartbeat: keep only measurements that carry all three optional fields.
    let complete: Vec<(f32, f32, f32)> = measurements
        .iter()
        .filter_map(|m| Some((m.heartbeat_bpm?, m.gt_heartbeat_bpm?, m.heartbeat_snr_db?)))
        .collect();

    let (heartbeat_error_bpm, heartbeat_snr_db) = match complete.len() {
        0 => (None, None),
        count => {
            let count = count as f32;
            let mean_error = complete
                .iter()
                .map(|&(estimate, reference, _)| (estimate - reference).abs())
                .sum::<f32>()
                / count;
            let mean_snr = complete.iter().map(|&(_, _, snr)| snr).sum::<f32>() / count;
            (Some(mean_error), Some(mean_snr))
        }
    };

    let breathing_ok = breathing_error_bpm <= thresholds.breathing_bpm_tolerance
        && breathing_snr_db >= thresholds.breathing_snr_db;

    // Absent heartbeat data does not block acceptance.
    let heartbeat_ok = if let (Some(err), Some(snr)) = (heartbeat_error_bpm, heartbeat_snr_db) {
        err <= thresholds.heartbeat_bpm_tolerance && snr >= thresholds.heartbeat_snr_db
    } else {
        true
    };

    VitalSignResult {
        breathing_error_bpm,
        breathing_snr_db,
        heartbeat_error_bpm,
        heartbeat_snr_db,
        n_measurements,
        passes: breathing_ok && heartbeat_ok,
    }
}
/// Bundle of the three criterion results together with the overall tier.
#[derive(Debug, Clone)]
pub struct RuViewAcceptanceResult {
/// Joint-error metrics and verdict.
pub joint_error: JointErrorResult,
/// Tracking metrics and verdict.
pub tracking: TrackingResult,
/// Vital-sign metrics and verdict.
pub vital_signs: VitalSignResult,
/// Overall acceptance tier.
pub tier: RuViewTier,
}
impl RuViewAcceptanceResult {
    /// Builds a one-line report containing the tier, PCK and OKS (three
    /// decimals), MOTA (three decimals), the identity-switch count and the
    /// breathing error (one decimal).
    pub fn summary(&self) -> String {
        let Self {
            joint_error,
            tracking,
            vital_signs,
            tier,
        } = self;

        format!(
            "RuView Tier={} | PCK={:.3} OKS={:.3} | MOTA={:.3} IDsw={} | Breathing={:.1}BPM err",
            tier,
            joint_error.pck_all,
            joint_error.oks,
            tracking.mota,
            tracking.id_switches,
            vital_signs.breathing_error_bpm,
        )
    }
}
/// Maps the three criterion verdicts to an acceptance tier.
///
/// The criteria are checked in the order tracking, joint error, vital signs;
/// the first one that fails decides the tier, and later ones are ignored:
///
/// | tracking | joint error | vital signs | tier   |
/// |----------|-------------|-------------|--------|
/// | fail     | any         | any         | Fail   |
/// | pass     | fail        | any         | Bronze |
/// | pass     | pass        | fail        | Silver |
/// | pass     | pass        | pass        | Gold   |
pub fn determine_tier(
    joint_error: &JointErrorResult,
    tracking: &TrackingResult,
    vital_signs: &VitalSignResult,
) -> RuViewTier {
    match (tracking.passes, joint_error.passes, vital_signs.passes) {
        (false, _, _) => RuViewTier::Fail,
        (true, false, _) => RuViewTier::Bronze,
        (true, true, false) => RuViewTier::Silver,
        (true, true, true) => RuViewTier::Gold,
    }
}
/// Returns the diagonal length of the axis-aligned bounding box around the
/// visible keypoints of `kp`.
///
/// At most the first 17 rows are considered; a row counts as visible when its
/// entry in `vis` is `>= 0.5`. Returns `0.0` when no row is visible.
fn compute_bbox_diag(kp: &Array2<f32>, vis: &Array1<f32>) -> f32 {
    let rows = kp.shape()[0].min(17);

    // Running extents as (x_lo, x_hi, y_lo, y_hi); `None` until a visible row is seen.
    let mut extents: Option<(f32, f32, f32, f32)> = None;
    for j in (0..rows).filter(|&j| vis[j] >= 0.5) {
        let (x, y) = (kp[[j, 0]], kp[[j, 1]]);
        let (x_lo, x_hi, y_lo, y_hi) =
            extents.unwrap_or((f32::MAX, f32::MIN, f32::MAX, f32::MIN));
        extents = Some((x_lo.min(x), x_hi.max(x), y_lo.min(y), y_hi.max(y)));
    }

    match extents {
        None => 0.0,
        Some((x_lo, x_hi, y_lo, y_hi)) => {
            let width = (x_hi - x_lo).max(0.0);
            let height = (y_hi - y_lo).max(0.0);
            (width * width + height * height).sqrt()
        }
    }
}
/// Computes an object-keypoint-similarity score for one frame.
///
/// For every keypoint `j` with `vis[j] >= 0.5` the similarity is
/// `exp(-d² / (2 · s² · k²))`, where `d` is the Euclidean distance between
/// prediction and ground truth and `k = COCO_SIGMAS[j]`. The result is the mean
/// similarity over the visible keypoints, or `0.0` when none is visible.
///
/// A scale of zero would make the exponent's denominator zero. Instead of
/// producing NaN, this case uses the limit of the formula: an exact match
/// scores `1.0` and any other prediction scores `0.0`.
// NOTE(review): the COCO reference evaluation uses k = 2·sigma and the object
// area for s²; here sigma and `s * s` are used directly. Confirm this is the
// intended definition before comparing scores with COCO tooling.
fn compute_single_oks(pred: &Array2<f32>, gt: &Array2<f32>, vis: &Array1<f32>, s: f32) -> f32 {
    let s_sq = s * s;
    let mut num = 0.0_f32;
    let mut den = 0.0_f32;

    for (j, &k) in COCO_SIGMAS.iter().enumerate() {
        if vis[j] < 0.5 {
            continue;
        }
        den += 1.0;

        let dx = pred[[j, 0]] - gt[[j, 0]];
        let dy = pred[[j, 1]] - gt[[j, 1]];
        let d_sq = dx * dx + dy * dy;

        let spread = 2.0 * s_sq * k * k;
        num += if spread == 0.0 {
            // Degenerate scale: avoid 0/0 (NaN) and x/0 (infinite exponent).
            if d_sq == 0.0 {
                1.0
            } else {
                0.0
            }
        } else {
            (-d_sq / spread).exp()
        };
    }

    if den > 0.0 {
        num / den
    } else {
        0.0
    }
}
/// Returns the RMS frame-to-frame displacement of the predicted torso centroid.
///
/// The centroid of a frame is the mean of the `TORSO_INDICES` keypoints whose
/// visibility is `>= 0.5`; frames with no visible torso keypoint have no
/// centroid. Only consecutive frame pairs where both centroids exist
/// contribute. Returns `0.0` when fewer than two frames are given or when no
/// such pair exists.
fn compute_torso_jitter(pred_kpts: &[Array2<f32>], visibility: &[Array1<f32>]) -> f32 {
    if pred_kpts.len() < 2 {
        return 0.0;
    }

    // Mean position of the visible torso keypoints of a single frame.
    let torso_centroid = |kp: &Array2<f32>, vis: &Array1<f32>| -> Option<(f32, f32)> {
        let (sum_x, sum_y, hits) = TORSO_INDICES
            .iter()
            .filter(|&&idx| vis[idx] >= 0.5)
            .fold((0.0_f32, 0.0_f32, 0_usize), |(sx, sy, count), &idx| {
                (sx + kp[[idx, 0]], sy + kp[[idx, 1]], count + 1)
            });
        if hits > 0 {
            Some((sum_x / hits as f32, sum_y / hits as f32))
        } else {
            None
        }
    };

    let centroids: Vec<Option<(f32, f32)>> = pred_kpts
        .iter()
        .zip(visibility)
        .map(|(kp, vis)| torso_centroid(kp, vis))
        .collect();

    // Accumulate squared displacements in f64 over consecutive frame pairs.
    let mut sum_sq = 0.0_f64;
    let mut n_pairs = 0_usize;
    for pair in centroids.windows(2) {
        if let (Some((x0, y0)), Some((x1, y1))) = (pair[0], pair[1]) {
            let dx = (x1 - x0) as f64;
            let dy = (y1 - y0) as f64;
            sum_sq += dx * dx + dy * dy;
            n_pairs += 1;
        }
    }

    if n_pairs == 0 {
        0.0
    } else {
        (sum_sq / n_pairs as f64).sqrt() as f32
    }
}
/// Returns the 95th-percentile error over all per-keypoint error lists.
///
/// All lists are pooled and sorted ascending; the element at index
/// `floor(0.95 * len)`, clamped to the last element, is returned.
///
/// * Returns `0.0` when there are no errors at all.
/// * Returns NaN when any error is NaN, so that a threshold comparison on the
///   result fails instead of silently ignoring the invalid sample.
fn compute_p95_max_error(per_kp_errors: &[Vec<f32>]) -> f32 {
    let mut all_errors: Vec<f32> = per_kp_errors.iter().flatten().copied().collect();
    if all_errors.is_empty() {
        return 0.0;
    }
    if all_errors.iter().any(|e| e.is_nan()) {
        return f32::NAN;
    }

    // `total_cmp` is a total order, which the sort requires; a partial
    // comparison with an `Equal` fallback is not one.
    all_errors.sort_unstable_by(f32::total_cmp);

    let last = all_errors.len() - 1;
    let idx = ((all_errors.len() as f64 * 0.95) as usize).min(last);
    all_errors[idx]
}
// Unit tests for the joint-error, tracking, vital-sign and tier logic.
#[cfg(test)]
mod tests {
use super::*;
use ndarray::{Array1, Array2};
// Returns `(pred, gt, vis)` for one frame where the prediction is an exact
// copy of the ground truth and all 17 keypoints are visible.
fn make_perfect_kpts() -> (Array2<f32>, Array2<f32>, Array1<f32>) {
let kp = Array2::from_shape_fn((17, 2), |(j, d)| {
if d == 0 { j as f32 * 0.05 } else { j as f32 * 0.03 }
});
let vis = Array1::ones(17);
(kp.clone(), kp, vis)
}
// Returns `(pred, gt, vis)` for one frame where every predicted coordinate is
// offset from the ground truth by `noise * sin(j * 7 + d * 3)`; all keypoints
// are visible.
fn make_noisy_kpts(noise: f32) -> (Array2<f32>, Array2<f32>, Array1<f32>) {
let gt = Array2::from_shape_fn((17, 2), |(j, d)| {
if d == 0 { j as f32 * 0.03 } else { j as f32 * 0.02 }
});
let pred = Array2::from_shape_fn((17, 2), |(j, d)| {
gt[[j, d]] + noise * ((j * 7 + d * 3) as f32).sin()
});
let vis = Array1::ones(17);
(pred, gt, vis)
}
// Exact predictions give PCK = 1 and OKS ~ 1.
#[test]
fn joint_error_perfect_predictions_pass() {
let (pred, gt, vis) = make_perfect_kpts();
let result = evaluate_joint_error(
&[pred],
&[gt],
&[vis],
&[1.0],
&JointErrorThresholds::default(),
);
assert_eq!(result.pck_all, 1.0, "perfect predictions should have PCK=1.0");
assert!((result.oks - 1.0).abs() < 1e-3, "perfect predictions should have OKS~1.0");
}
// No frames at all must not pass.
#[test]
fn joint_error_empty_returns_fail() {
let result = evaluate_joint_error(
&[],
&[],
&[],
&[],
&JointErrorThresholds::default(),
);
assert!(!result.passes);
}
// Large noise must push at least one keypoint outside the PCK threshold.
#[test]
fn joint_error_noisy_predictions_lower_pck() {
let (pred, gt, vis) = make_noisy_kpts(0.5);
let result = evaluate_joint_error(
&[pred],
&[gt],
&[vis],
&[1.0],
&JointErrorThresholds::default(),
);
assert!(result.pck_all < 1.0, "noisy predictions should have PCK < 1.0");
}
// Identical assignments in every frame: no switches, passes the default gate.
#[test]
fn tracking_no_id_switches_pass() {
let frames: Vec<TrackingFrame> = (0..100)
.map(|i| TrackingFrame {
frame_idx: i,
gt_ids: vec![1, 2],
pred_ids: vec![1, 2],
assignments: vec![(1, 1), (2, 2)],
})
.collect();
let result = evaluate_tracking(&frames, 1.0, &TrackingThresholds::default());
assert_eq!(result.id_switches, 0);
assert!(result.passes);
}
// Swapping the two identities in frame 5 must register as ID switches and
// fail the default gate, which allows none.
#[test]
fn tracking_id_switches_detected() {
let mut frames: Vec<TrackingFrame> = (0..10)
.map(|i| TrackingFrame {
frame_idx: i,
gt_ids: vec![1, 2],
pred_ids: vec![1, 2],
assignments: vec![(1, 1), (2, 2)],
})
.collect();
frames[5].assignments = vec![(2, 1), (1, 2)];
let result = evaluate_tracking(&frames, 1.0, &TrackingThresholds::default());
assert!(result.id_switches >= 1, "should detect ID switch at frame 5");
assert!(!result.passes, "ID switches should cause failure");
}
// No frames at all must not pass.
#[test]
fn tracking_empty_returns_fail() {
let result = evaluate_tracking(&[], 1.0, &TrackingThresholds::default());
assert!(!result.passes);
}
// 0.5 BPM error with good SNR and no heartbeat data passes.
#[test]
fn vital_signs_accurate_breathing_passes() {
let measurements = vec![
VitalSignMeasurement {
breathing_bpm: 15.0,
gt_breathing_bpm: 14.5,
breathing_snr_db: 10.0,
heartbeat_bpm: None,
gt_heartbeat_bpm: None,
heartbeat_snr_db: None,
},
VitalSignMeasurement {
breathing_bpm: 16.0,
gt_breathing_bpm: 15.5,
breathing_snr_db: 8.0,
heartbeat_bpm: None,
gt_heartbeat_bpm: None,
heartbeat_snr_db: None,
},
];
let result = evaluate_vital_signs(&measurements, &VitalSignThresholds::default());
assert!(result.breathing_error_bpm <= 2.0);
assert!(result.passes);
}
// A 10 BPM breathing error exceeds the default 2 BPM tolerance.
#[test]
fn vital_signs_inaccurate_breathing_fails() {
let measurements = vec![VitalSignMeasurement {
breathing_bpm: 25.0,
gt_breathing_bpm: 15.0,
breathing_snr_db: 10.0,
heartbeat_bpm: None,
gt_heartbeat_bpm: None,
heartbeat_snr_db: None,
}];
let result = evaluate_vital_signs(&measurements, &VitalSignThresholds::default());
assert!(!result.passes, "10 BPM error should fail");
}
// No measurements at all must not pass.
#[test]
fn vital_signs_empty_returns_fail() {
let result = evaluate_vital_signs(&[], &VitalSignThresholds::default());
assert!(!result.passes);
}
// All three criteria pass -> Gold.
#[test]
fn tier_determination_gold() {
let je = JointErrorResult {
pck_all: 0.85,
pck_torso: 0.90,
oks: 0.65,
jitter_rms_m: 0.01,
max_error_p95_m: 0.10,
passes: true,
};
let tr = TrackingResult {
id_switches: 0,
fragmentation_ratio: 0.01,
false_tracks_per_min: 0.0,
mota: 0.95,
n_frames: 1000,
passes: true,
};
let vs = VitalSignResult {
breathing_error_bpm: 1.0,
breathing_snr_db: 8.0,
heartbeat_error_bpm: Some(3.0),
heartbeat_snr_db: Some(4.0),
n_measurements: 10,
passes: true,
};
assert_eq!(determine_tier(&je, &tr, &vs), RuViewTier::Gold);
}
// Tracking and joint error pass, vital signs fail -> Silver.
// Uses the test-only `Default` impls defined at the bottom of this module.
#[test]
fn tier_determination_silver() {
let je = JointErrorResult { passes: true, ..Default::default() };
let tr = TrackingResult { passes: true, ..Default::default() };
let vs = VitalSignResult { passes: false, ..Default::default() };
assert_eq!(determine_tier(&je, &tr, &vs), RuViewTier::Silver);
}
// Tracking passes, joint error fails -> Bronze regardless of vital signs.
#[test]
fn tier_determination_bronze() {
let je = JointErrorResult { passes: false, ..Default::default() };
let tr = TrackingResult { passes: true, ..Default::default() };
let vs = VitalSignResult { passes: false, ..Default::default() };
assert_eq!(determine_tier(&je, &tr, &vs), RuViewTier::Bronze);
}
// Tracking fails -> Fail even when the other criteria pass.
#[test]
fn tier_determination_fail() {
let je = JointErrorResult { passes: true, ..Default::default() };
let tr = TrackingResult { passes: false, ..Default::default() };
let vs = VitalSignResult { passes: true, ..Default::default() };
assert_eq!(determine_tier(&je, &tr, &vs), RuViewTier::Fail);
}
// The derived ordering follows declaration order of the variants.
#[test]
fn tier_ordering() {
assert!(RuViewTier::Gold > RuViewTier::Silver);
assert!(RuViewTier::Silver > RuViewTier::Bronze);
assert!(RuViewTier::Bronze > RuViewTier::Fail);
}
// Test-only `Default` impls: all-zero / `None` metrics with `passes: false`,
// so tests can use struct-update syntax and set only `passes`.
impl Default for JointErrorResult {
fn default() -> Self {
JointErrorResult {
pck_all: 0.0,
pck_torso: 0.0,
oks: 0.0,
jitter_rms_m: 0.0,
max_error_p95_m: 0.0,
passes: false,
}
}
}
impl Default for TrackingResult {
fn default() -> Self {
TrackingResult {
id_switches: 0,
fragmentation_ratio: 0.0,
false_tracks_per_min: 0.0,
mota: 0.0,
n_frames: 0,
passes: false,
}
}
}
impl Default for VitalSignResult {
fn default() -> Self {
VitalSignResult {
breathing_error_bpm: 0.0,
breathing_snr_db: 0.0,
heartbeat_error_bpm: None,
heartbeat_snr_db: None,
n_measurements: 0,
passes: false,
}
}
}
}