use std::path::Path;
use serde::{Deserialize, Serialize};
/// Configuration for a black-box probing run against an ONNX classifier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelProbeConfig {
/// Width of each input row; must match the model's expected feature count.
pub n_features: usize,
/// Number of probe inputs generated per `probe` call (default 1000).
#[serde(default = "default_n_probes")]
pub n_probes: usize,
/// Relative noise scale applied when perturbing features (default 0.05).
#[serde(default = "default_perturbation_budget")]
pub perturbation_budget: f64,
/// Score cutoff for a positive prediction (default 0.5).
#[serde(default = "default_threshold")]
pub threshold: f64,
/// Class column read from multi-class model output (default 0).
#[serde(default)]
pub target_class: usize,
}
/// Serde default for `ModelProbeConfig::n_probes`.
fn default_n_probes() -> usize {
    1_000
}
/// Serde default for `ModelProbeConfig::perturbation_budget`.
fn default_perturbation_budget() -> f64 {
    5e-2
}
/// Serde default for `ModelProbeConfig::threshold`.
fn default_threshold() -> f64 {
    0.5
}
/// One probed input together with the model's response to it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProbeSample {
/// The (possibly perturbed) feature vector fed to the model.
pub features: Vec<f32>,
/// Raw model score for the configured target class.
pub score: f32,
/// Whether `score` met the configured threshold.
pub predicted_positive: bool,
/// Absolute distance between `score` and the threshold.
pub margin: f32,
}
/// Aggregate statistics over a batch of probe samples.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionStats {
/// Mean of all scores.
pub mean_score: f64,
/// Population standard deviation of scores.
pub std_score: f64,
/// Fraction of samples predicted positive.
pub positive_rate: f64,
/// Count of samples with margin below 0.1 (near the decision boundary).
pub boundary_samples: usize,
/// Mean absolute distance from the threshold.
pub mean_margin: f64,
}
/// Full output of a probing run: per-sample results, aggregates, and the
/// configuration that produced them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProbeResult {
/// Every generated probe with its score.
pub samples: Vec<ProbeSample>,
/// Summary statistics computed over `samples`.
pub stats: PredictionStats,
/// The configuration this run was executed with.
pub config: ModelProbeConfig,
}
/// Wraps a loaded ONNX inference session together with probing parameters.
pub struct ModelProbe {
// ONNX Runtime session used for all inference calls.
session: ort::session::Session,
// Probing parameters; cloned into each `ProbeResult`.
config: ModelProbeConfig,
}
impl ModelProbe {
    /// Opens an ONNX model from `model_path` and pairs it with `config`.
    ///
    /// # Errors
    /// Returns a descriptive message when the session builder fails or the
    /// model file cannot be loaded.
    pub fn load(model_path: &Path, config: ModelProbeConfig) -> Result<Self, String> {
        let session = ort::session::Session::builder()
            .and_then(|mut b| b.commit_from_file(model_path))
            .map_err(|e| {
                format!(
                    "Failed to load ONNX model from {}: {e}",
                    model_path.display()
                )
            })?;
        Ok(Self { session, config })
    }

    /// Runs inference on a batch of feature rows, returning one score per row.
    ///
    /// Output interpretation:
    /// * exactly `n_samples` values: returned unchanged (single-score model);
    /// * an exact multiple of `n_samples` with enough columns: treated as a
    ///   row-major `[n_samples, n_classes]` matrix and the
    ///   `config.target_class` column is extracted;
    /// * anything else: the first `n_samples` values, as a lenient fallback.
    ///
    /// # Errors
    /// Fails when a row's length differs from `config.n_features`, when the
    /// input tensor cannot be built, or when inference itself fails.
    pub fn predict(&mut self, features: &[Vec<f32>]) -> Result<Vec<f32>, String> {
        if features.is_empty() {
            return Ok(vec![]);
        }
        let n_samples = features.len();
        let n_features = self.config.n_features;
        // Reject ragged input up front: a wrong-length row would otherwise
        // either fail tensor construction with an opaque message or silently
        // misalign every subsequent row.
        if let Some((i, row)) = features
            .iter()
            .enumerate()
            .find(|(_, r)| r.len() != n_features)
        {
            return Err(format!(
                "feature row {i} has {} values, expected {n_features}",
                row.len()
            ));
        }
        let flat: Vec<f32> = features.iter().flat_map(|r| r.iter().copied()).collect();
        let input_tensor =
            ort::value::Tensor::<f32>::from_array(([n_samples as i64, n_features as i64], flat))
                .map_err(|e| format!("Failed to create input tensor: {e}"))?;
        let outputs = self
            .session
            .run(ort::inputs![input_tensor])
            .map_err(|e| format!("Inference failed: {e}"))?;
        let (_name, output) = outputs.iter().next().ok_or("Model returned no outputs")?;
        let (_shape, scores_view) = output
            .try_extract_tensor::<f32>()
            .map_err(|e| format!("Failed to extract output tensor: {e}"))?;
        let scores: Vec<f32> = scores_view.to_vec();
        let target = self.config.target_class;
        if scores.len() == n_samples {
            Ok(scores)
        } else if scores.len() % n_samples == 0 && scores.len() / n_samples > target {
            // Per-class output: require an exact [n_samples, n_classes]
            // layout before chunking. Previously a non-divisible length was
            // chunked anyway, splitting rows mid-sample and selecting
            // meaningless values.
            let n_classes = scores.len() / n_samples;
            Ok(scores
                .chunks(n_classes)
                .map(|row| row.get(target).copied().unwrap_or(0.0))
                .collect())
        } else {
            // Unexpected output size: keep the best-effort truncation rather
            // than failing the whole probe run.
            Ok(scores.into_iter().take(n_samples).collect())
        }
    }

    /// Generates `config.n_probes` perturbed inputs, scores them, and
    /// aggregates the results.
    ///
    /// Bases are drawn round-robin from `seed_samples` when provided,
    /// otherwise from standard-normal noise. Each feature is jittered by
    /// Gaussian noise scaled by `perturbation_budget` times the feature's
    /// magnitude (floored at 1.0 so near-zero features still move). Results
    /// are deterministic for a given `seed`.
    ///
    /// # Errors
    /// Propagates failures from distribution construction or [`Self::predict`].
    pub fn probe(&mut self, seed_samples: &[Vec<f32>], seed: u64) -> Result<ProbeResult, String> {
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;
        use rand_distr::{Distribution, Normal};
        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        let normal = Normal::new(0.0, 1.0).map_err(|e| format!("Normal dist: {e}"))?;
        let n_features = self.config.n_features;
        let budget = self.config.perturbation_budget as f32;
        let threshold = self.config.threshold as f32;
        let mut probe_features = Vec::with_capacity(self.config.n_probes);
        for i in 0..self.config.n_probes {
            let base = if seed_samples.is_empty() {
                (0..n_features)
                    .map(|_| {
                        let v: f64 = normal.sample(&mut rng);
                        v as f32
                    })
                    .collect::<Vec<_>>()
            } else {
                seed_samples[i % seed_samples.len()].clone()
            };
            let perturbed: Vec<f32> = base
                .iter()
                .map(|&v| {
                    let delta: f64 = normal.sample(&mut rng);
                    // Relative perturbation with an absolute floor of 1.0.
                    v + budget * v.abs().max(1.0) * delta as f32
                })
                .collect();
            probe_features.push(perturbed);
        }
        let scores = self.predict(&probe_features)?;
        let samples: Vec<ProbeSample> = probe_features
            .iter()
            .zip(scores.iter())
            .map(|(feat, &score)| {
                let margin = (score - threshold).abs();
                ProbeSample {
                    features: feat.clone(),
                    score,
                    predicted_positive: score >= threshold,
                    margin,
                }
            })
            .collect();
        let stats = compute_stats(&samples);
        Ok(ProbeResult {
            samples,
            stats,
            config: self.config.clone(),
        })
    }

    /// Returns up to `top_n` samples closest to the decision boundary,
    /// ordered by ascending margin.
    pub fn boundary_samples(result: &ProbeResult, top_n: usize) -> Vec<&ProbeSample> {
        let mut sorted: Vec<&ProbeSample> = result.samples.iter().collect();
        // `total_cmp` gives a deterministic total order even for NaN margins
        // (NaN sorts last), unlike the partial_cmp-with-Equal fallback.
        sorted.sort_unstable_by(|a, b| a.margin.total_cmp(&b.margin));
        sorted.truncate(top_n);
        sorted
    }
}
/// Aggregates per-sample probe outputs into summary statistics.
///
/// Returns all-zero stats for an empty slice. `std_score` is the population
/// standard deviation; a sample counts as "boundary" when its margin is
/// below 0.1.
fn compute_stats(samples: &[ProbeSample]) -> PredictionStats {
    if samples.is_empty() {
        return PredictionStats {
            mean_score: 0.0,
            std_score: 0.0,
            positive_rate: 0.0,
            boundary_samples: 0,
            mean_margin: 0.0,
        };
    }
    let n = samples.len() as f64;
    // Accumulate the first-pass quantities in a single traversal.
    let mut score_sum = 0.0f64;
    let mut margin_sum = 0.0f64;
    let mut positives = 0usize;
    let mut near_boundary = 0usize;
    for s in samples {
        score_sum += f64::from(s.score);
        margin_sum += f64::from(s.margin);
        if s.predicted_positive {
            positives += 1;
        }
        if s.margin < 0.1 {
            near_boundary += 1;
        }
    }
    let mean_score = score_sum / n;
    // Second pass for the population variance around the mean.
    let variance = samples
        .iter()
        .map(|s| {
            let d = f64::from(s.score) - mean_score;
            d * d
        })
        .sum::<f64>()
        / n;
    PredictionStats {
        mean_score,
        std_score: variance.sqrt(),
        positive_rate: positives as f64 / n,
        boundary_samples: near_boundary,
        mean_margin: margin_sum / n,
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a sample whose margin and positivity derive from a 0.5 threshold.
    fn sample_at(features: Vec<f32>, score: f32) -> ProbeSample {
        ProbeSample {
            features,
            score,
            predicted_positive: score >= 0.5,
            margin: (score - 0.5).abs(),
        }
    }

    #[test]
    fn test_probe_config_defaults() {
        let config = ModelProbeConfig {
            n_features: 10,
            n_probes: default_n_probes(),
            perturbation_budget: default_perturbation_budget(),
            threshold: default_threshold(),
            target_class: 0,
        };
        assert_eq!(config.n_probes, 1000);
        assert!((config.perturbation_budget - 0.05).abs() < 1e-10);
        assert!((config.threshold - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_prediction_stats_serialization() {
        let stats = PredictionStats {
            mean_score: 0.45,
            std_score: 0.15,
            positive_rate: 0.3,
            boundary_samples: 50,
            mean_margin: 0.12,
        };
        // Round-trip through JSON and confirm a field survives intact.
        let json = serde_json::to_string(&stats).unwrap();
        let parsed: PredictionStats = serde_json::from_str(&json).unwrap();
        assert!((parsed.mean_score - 0.45).abs() < 1e-10);
    }

    #[test]
    fn test_boundary_samples_sorting() {
        let result = ProbeResult {
            samples: vec![
                sample_at(vec![1.0], 0.9),
                sample_at(vec![2.0], 0.51),
                sample_at(vec![3.0], 0.3),
            ],
            stats: compute_stats(&[]),
            config: ModelProbeConfig {
                n_features: 1,
                n_probes: 3,
                perturbation_budget: 0.05,
                threshold: 0.5,
                target_class: 0,
            },
        };
        let top = ModelProbe::boundary_samples(&result, 2);
        assert_eq!(top.len(), 2);
        // Smallest margin (the 0.51-score sample) must come first.
        assert!(top[0].margin <= top[1].margin);
        assert!((top[0].margin - 0.01).abs() < 1e-5);
    }

    #[test]
    fn test_compute_stats_empty() {
        let stats = compute_stats(&[]);
        assert!(stats.mean_score.abs() < 1e-10);
        assert_eq!(stats.boundary_samples, 0);
    }

    #[test]
    fn test_compute_stats_basic() {
        let samples = vec![sample_at(vec![], 0.8), sample_at(vec![], 0.2)];
        let stats = compute_stats(&samples);
        assert!((stats.mean_score - 0.5).abs() < 1e-6);
        assert!((stats.positive_rate - 0.5).abs() < 1e-10);
    }
}