use serde::{Deserialize, Serialize};
/// Tunable parameters for [`MembershipInferenceAttack`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MiaConfig {
/// Highest AUC-ROC that still counts as passing: `evaluate` sets
/// `passes` when the measured AUC is less than or equal to this value.
pub auc_threshold: f64,
/// Number of nearest synthetic records whose distances are averaged to
/// score a single record (capped at the size of the synthetic set).
pub k_neighbors: usize,
}
impl Default for MiaConfig {
fn default() -> Self {
Self {
auc_threshold: 0.6,
k_neighbors: 1,
}
}
}
/// Outcome of a single [`MembershipInferenceAttack::evaluate`] run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MiaResults {
/// Area under the ROC curve of the distance-based attack
/// (`0.5` is reported when any input set is empty).
pub auc_roc: f64,
/// Accuracy at the score cut-off that maximises accuracy.
pub accuracy: f64,
/// Precision at that same cut-off.
pub precision: f64,
/// Recall at that same cut-off.
pub recall: f64,
/// `true` when `auc_roc <= auc_threshold`, or when any input set is empty.
pub passes: bool,
/// Number of member records supplied to the evaluation.
pub n_members: usize,
/// Number of non-member records supplied to the evaluation.
pub n_non_members: usize,
/// The threshold the AUC was compared against, copied from the config.
pub auc_threshold: f64,
}
/// Distance-based membership inference attack: scores records by how close
/// they lie to a synthetic dataset and measures how well that score
/// separates members from non-members.
pub struct MembershipInferenceAttack {
// Parameters controlling neighbour count and the pass/fail AUC threshold.
config: MiaConfig,
}
impl MembershipInferenceAttack {
    /// Creates an attack that uses the supplied configuration.
    pub fn new(config: MiaConfig) -> Self {
        Self { config }
    }

    /// Creates an attack configured with `MiaConfig::default()`.
    pub fn with_defaults() -> Self {
        Self::new(MiaConfig::default())
    }

    /// Runs the attack and summarises how well it separates members from
    /// non-members.
    ///
    /// Every record in `members` and `non_members` is scored with the
    /// negated mean distance to its nearest records in `synthetic`, so a
    /// record closer to the synthetic data gets a higher score. Members are
    /// the positive class.
    ///
    /// If any of the three inputs is empty, a neutral result is returned
    /// (`auc_roc` and `accuracy` of `0.5`, zero precision and recall,
    /// `passes == true`).
    pub fn evaluate(
        &self,
        members: &[Vec<f64>],
        non_members: &[Vec<f64>],
        synthetic: &[Vec<f64>],
    ) -> MiaResults {
        let n_members = members.len();
        let n_non_members = non_members.len();
        let auc_threshold = self.config.auc_threshold;

        // Nothing to compare against: report the neutral, passing outcome.
        if n_members == 0 || n_non_members == 0 || synthetic.is_empty() {
            return MiaResults {
                auc_roc: 0.5,
                accuracy: 0.5,
                precision: 0.0,
                recall: 0.0,
                passes: true,
                n_members,
                n_non_members,
                auc_threshold,
            };
        }

        // Members first, then non-members; distance is negated so that
        // "closer to synthetic" means "higher score".
        let member_scores = members
            .iter()
            .map(|record| (-self.knn_distance(record, synthetic), true));
        let non_member_scores = non_members
            .iter()
            .map(|record| (-self.knn_distance(record, synthetic), false));
        let scored: Vec<(f64, bool)> = member_scores.chain(non_member_scores).collect();

        let auc_roc = compute_auc(&scored);
        let (accuracy, precision, recall) = compute_best_threshold_metrics(&scored);

        MiaResults {
            auc_roc,
            accuracy,
            precision,
            recall,
            passes: auc_roc <= auc_threshold,
            n_members,
            n_non_members,
            auc_threshold,
        }
    }

    /// Mean distance from `record` to its `k_neighbors` nearest entries in
    /// `dataset` (all entries when the dataset is smaller than that).
    ///
    /// Returns `f64::MAX` when no neighbour can be used, i.e. when
    /// `k_neighbors` is zero or `dataset` is empty.
    fn knn_distance(&self, record: &[f64], dataset: &[Vec<f64>]) -> f64 {
        let neighbour_count = self.config.k_neighbors.min(dataset.len());
        if neighbour_count == 0 {
            return f64::MAX;
        }

        let mut by_distance: Vec<f64> = dataset
            .iter()
            .map(|candidate| euclidean_distance(record, candidate))
            .collect();
        by_distance.sort_by(|lhs, rhs| lhs.total_cmp(rhs));

        let nearest_total: f64 = by_distance.iter().take(neighbour_count).sum();
        nearest_total / neighbour_count as f64
    }
}
/// Euclidean distance between `a` and `b`, computed over the coordinates
/// both slices have in common.
///
/// When the slices differ in length, the trailing coordinates of the longer
/// one are ignored. Two empty slices are at distance `0.0`.
fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
    // `zip` stops at the shorter slice, matching a min-length index loop.
    let squared_total: f64 = a
        .iter()
        .zip(b.iter())
        .map(|(x, y)| (x - y).powi(2))
        .sum();
    squared_total.sqrt()
}
/// Area under the ROC curve for `(score, is_positive)` pairs, where a higher
/// score means "more likely positive".
///
/// The curve is integrated with the trapezoidal rule, sweeping the decision
/// threshold from the highest score downwards. Records that share a score
/// are crossed by the threshold together, so they contribute one ROC point
/// per distinct score. This makes the result independent of input order and
/// gives tied positive/negative pairs half credit.
///
/// Returns `0.5` when `scored` is empty or contains only one class.
fn compute_auc(scored: &[(f64, bool)]) -> f64 {
    let total_pos = scored.iter().filter(|s| s.1).count() as f64;
    let total_neg = scored.len() as f64 - total_pos;
    if total_pos == 0.0 || total_neg == 0.0 {
        return 0.5;
    }

    let mut sorted = scored.to_vec();
    sorted.sort_by(|a, b| b.0.total_cmp(&a.0));

    let mut auc = 0.0;
    let mut tp = 0.0;
    let mut fp = 0.0;
    let mut prev_fpr = 0.0;
    let mut prev_tpr = 0.0;
    for (i, &(score, is_pos)) in sorted.iter().enumerate() {
        if is_pos {
            tp += 1.0;
        } else {
            fp += 1.0;
        }

        // Do not emit a ROC point in the middle of a run of equal scores:
        // no threshold can separate those records, and stepping through
        // them one by one would make the area depend on their order.
        if let Some(next) = sorted.get(i + 1) {
            if next.0 == score || next.0.total_cmp(&score).is_eq() {
                continue;
            }
        }

        let tpr = tp / total_pos;
        let fpr = fp / total_neg;
        auc += (fpr - prev_fpr) * (tpr + prev_tpr) / 2.0;
        prev_fpr = fpr;
        prev_tpr = tpr;
    }
    auc
}
/// Finds the score cut-off with the highest accuracy and returns
/// `(accuracy, precision, recall)` measured at that cut-off.
///
/// Records are ranked by descending score and every prefix of that ranking
/// is tried as the "predicted positive" set. A prefix is only considered
/// when it ends on a boundary between distinct scores, because a threshold
/// cannot split records that share the same score. When several cut-offs
/// reach the same accuracy, the one with the highest threshold wins.
///
/// Returns `(0.5, 0.0, 0.0)` for empty input.
fn compute_best_threshold_metrics(scored: &[(f64, bool)]) -> (f64, f64, f64) {
    if scored.is_empty() {
        return (0.5, 0.0, 0.0);
    }

    let mut sorted = scored.to_vec();
    sorted.sort_by(|a, b| b.0.total_cmp(&a.0));

    let total = sorted.len() as f64;
    let total_pos = sorted.iter().filter(|s| s.1).count() as f64;

    let mut best = (0.0, 0.0, 0.0);
    let mut tp = 0.0;
    for (i, &(score, is_pos)) in sorted.iter().enumerate() {
        if is_pos {
            tp += 1.0;
        }

        // A cut inside a run of equal scores is not a realizable threshold;
        // evaluating it would report metrics that depend on input order.
        if let Some(next) = sorted.get(i + 1) {
            if next.0 == score || next.0.total_cmp(&score).is_eq() {
                continue;
            }
        }

        // Everything ranked up to and including `i` is predicted positive.
        let predicted_pos = (i + 1) as f64;
        let fn_count = total_pos - tp;
        let tn = total - predicted_pos - fn_count;

        let accuracy = (tp + tn) / total;
        // `predicted_pos` is at least 1 here, so the division is safe.
        let precision = tp / predicted_pos;
        let recall = if total_pos > 0.0 { tp / total_pos } else { 0.0 };

        if accuracy > best.0 {
            best = (accuracy, precision, recall);
        }
    }
    best
}
// Unit tests for the attack, its helper functions, and serde round-trips.
#[cfg(test)]
// `unwrap` is acceptable here: a failed unwrap simply fails the test.
#[allow(clippy::unwrap_used)]
mod tests {
use super::*;
// Synthetic data is a verbatim clone of the members, so every member sits
// at distance zero and the attack must separate the two groups well.
#[test]
fn test_exact_copy_high_auc() {
let members: Vec<Vec<f64>> = (0..50)
.map(|i| vec![i as f64, (i * 2) as f64, (i * 3) as f64])
.collect();
let non_members: Vec<Vec<f64>> = (100..150)
.map(|i| vec![i as f64, (i * 2) as f64, (i * 3) as f64])
.collect();
let synthetic = members.clone();
let attack = MembershipInferenceAttack::with_defaults();
let results = attack.evaluate(&members, &non_members, &synthetic);
assert!(
results.auc_roc > 0.8,
"Expected high AUC for exact copies, got {}",
results.auc_roc
);
assert!(
!results.passes,
"Should NOT pass privacy check for exact copies"
);
}
// Synthetic data is generated from an index range (200..300) that overlaps
// neither members (0..50) nor non-members (50..100). Despite the test name,
// the fixture is deterministic rather than random.
#[test]
fn test_random_data_low_auc() {
let members: Vec<Vec<f64>> = (0..50).map(|i| vec![i as f64, (i * 2) as f64]).collect();
let non_members: Vec<Vec<f64>> =
(50..100).map(|i| vec![i as f64, (i * 2) as f64]).collect();
let synthetic: Vec<Vec<f64>> = (200..300).map(|i| vec![i as f64, (i * 2) as f64]).collect();
let attack = MembershipInferenceAttack::with_defaults();
let results = attack.evaluate(&members, &non_members, &synthetic);
assert!(
results.auc_roc < 0.7,
"Expected low AUC for unrelated data, got {}",
results.auc_roc
);
}
// Empty inputs take the early-return path: neutral AUC and a passing result.
#[test]
fn test_empty_inputs() {
let attack = MembershipInferenceAttack::with_defaults();
let results = attack.evaluate(&[], &[], &[]);
assert!(results.passes);
assert!((results.auc_roc - 0.5).abs() < 1e-10);
}
// 3-4-5 right triangle: the distance from the origin must be exactly 5.
#[test]
fn test_euclidean_distance() {
let a = vec![0.0, 0.0, 0.0];
let b = vec![3.0, 4.0, 0.0];
let dist = euclidean_distance(&a, &b);
assert!((dist - 5.0).abs() < 1e-10);
}
// Every positive outranks every negative, so the AUC must be 1.0.
#[test]
fn test_compute_auc_perfect() {
let scored: Vec<(f64, bool)> = vec![
(1.0, true),
(0.9, true),
(0.8, true),
(0.3, false),
(0.2, false),
(0.1, false),
];
let auc = compute_auc(&scored);
assert!(
(auc - 1.0).abs() < 1e-10,
"Perfect AUC should be 1.0, got {}",
auc
);
}
// Positives and negatives alternate in rank; the AUC only has to land
// within 0.2 of chance level.
#[test]
fn test_compute_auc_random() {
let scored: Vec<(f64, bool)> = vec![
(0.6, true),
(0.5, false),
(0.4, true),
(0.3, false),
(0.2, true),
(0.1, false),
];
let auc = compute_auc(&scored);
assert!(
(auc - 0.5).abs() < 0.2,
"Near-random AUC should be around 0.5, got {}",
auc
);
}
// The default config must survive a JSON round-trip unchanged.
#[test]
fn test_mia_config_serde() {
let config = MiaConfig::default();
let json = serde_json::to_string(&config).unwrap();
let parsed: MiaConfig = serde_json::from_str(&json).unwrap();
assert!((parsed.auc_threshold - 0.6).abs() < 1e-10);
assert_eq!(parsed.k_neighbors, 1);
}
// A hand-built result must survive a JSON round-trip unchanged.
#[test]
fn test_mia_results_serde() {
let results = MiaResults {
auc_roc: 0.55,
accuracy: 0.52,
precision: 0.51,
recall: 0.53,
passes: true,
n_members: 100,
n_non_members: 100,
auc_threshold: 0.6,
};
let json = serde_json::to_string(&results).unwrap();
let parsed: MiaResults = serde_json::from_str(&json).unwrap();
assert!((parsed.auc_roc - 0.55).abs() < 1e-10);
assert!(parsed.passes);
}
}