use crate::error::{JudgyError, Result};
use rand::prelude::*;
use serde::{Deserialize, Serialize};
/// Output of [`estimate_success_rate`]: a judge-corrected success-rate estimate,
/// its bootstrap confidence interval, and the quantities it was derived from.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EstimationResult {
/// Corrected success-rate estimate,
/// `(raw_pass_rate + tnr - 1) / (tpr + tnr - 1)`, clamped to `[0, 1]`.
pub theta_hat: f64,
/// Lower end of the bootstrap percentile interval at `confidence_level`.
pub lower_bound: f64,
/// Upper end of the bootstrap percentile interval at `confidence_level`.
pub upper_bound: f64,
/// Confidence level requested by the caller; validated to lie strictly in `(0, 1)`.
pub confidence_level: f64,
/// Number of bootstrap resamples requested by the caller. Resamples that yield
/// no usable estimate are skipped, so the interval may rest on fewer of them.
pub bootstrap_iterations: usize,
/// Judge error rates measured on the labeled test set.
pub judge_metrics: JudgeMetrics,
/// Fraction of `unlabeled_preds` equal to `1` (the uncorrected pass rate).
pub raw_pass_rate: f64,
}
/// Confusion-matrix rates of a binary judge measured against ground-truth labels,
/// as computed by `JudgeMetrics::from_test_data`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct JudgeMetrics {
/// True positive rate: share of label-`1` items the judge predicted as `1`.
pub tpr: f64,
/// True negative rate: share of label-`0` items the judge predicted as `0`.
pub tnr: f64,
/// False positive rate: share of label-`0` items the judge predicted as `1`
/// (the complement of `tnr`).
pub fpr: f64,
/// False negative rate: share of label-`1` items the judge predicted as `0`
/// (the complement of `tpr`).
pub fnr: f64,
/// Share of all items whose prediction equals the label.
pub accuracy: f64,
}
impl JudgeMetrics {
    /// Computes the judge's confusion-matrix rates from a labeled test set.
    ///
    /// `test_labels[i]` is the ground truth and `test_preds[i]` the judge's
    /// verdict for the same item; every value must be `0` or `1`.
    ///
    /// # Errors
    ///
    /// Returns an input-validation error when the slices differ in length, when
    /// any value is not `0` or `1`, or when `test_labels` lacks either class
    /// (the per-class rates would otherwise divide by zero). Empty input falls
    /// into the last case.
    pub fn from_test_data(test_labels: &[u8], test_preds: &[u8]) -> Result<Self> {
        if test_labels.len() != test_preds.len() {
            return Err(JudgyError::input_validation(format!(
                "test_labels and test_preds must have the same length. Got {} and {}",
                test_labels.len(),
                test_preds.len()
            )));
        }

        // Explicitly `usize`: unannotated integer literals default to `i32`,
        // which would overflow for slices longer than `i32::MAX`.
        let (mut tp, mut fp, mut tn, mut fn_) = (0usize, 0usize, 0usize, 0usize);
        for (&label, &pred) in test_labels.iter().zip(test_preds) {
            match (label, pred) {
                (1, 1) => tp += 1,
                (0, 1) => fp += 1,
                (0, 0) => tn += 1,
                (1, 0) => fn_ += 1,
                _ => {
                    return Err(JudgyError::input_validation(
                        "Labels and predictions must be 0 or 1".to_string(),
                    ))
                }
            }
        }

        let positive_count = tp + fn_;
        let negative_count = tn + fp;
        // Both classes are required so every rate below has a non-zero denominator.
        if positive_count == 0 || negative_count == 0 {
            return Err(JudgyError::input_validation(
                "test_labels must contain both positive and negative examples".to_string(),
            ));
        }

        let positives = positive_count as f64;
        let negatives = negative_count as f64;
        Ok(Self {
            tpr: tp as f64 / positives,
            tnr: tn as f64 / negatives,
            fpr: fp as f64 / negatives,
            fnr: fn_ as f64 / positives,
            accuracy: (tp + tn) as f64 / test_labels.len() as f64,
        })
    }
}
/// Estimates a system's true success rate from the verdicts of an imperfect
/// binary judge, correcting for the judge's measured error rates.
///
/// * `test_labels` / `test_preds` — ground truth and judge verdicts on a labeled
///   test set; used to measure the judge's TPR and TNR.
/// * `unlabeled_preds` — judge verdicts on the items whose success rate is wanted.
/// * `bootstrap_iterations` — number of bootstrap resamples of the test set.
/// * `confidence_level` — coverage of the returned interval, strictly in `(0, 1)`.
///
/// The point estimate is `(p_obs + tnr - 1) / (tpr + tnr - 1)` clamped to
/// `[0, 1]`, where `p_obs` is the fraction of `unlabeled_preds` equal to `1`.
///
/// # Errors
///
/// * An input-validation error for malformed inputs (see `validate_inputs`) or a
///   test set that lacks either class.
/// * `JudgyError::JudgeAccuracyTooLow` when `tpr + tnr <= 1`: the judge is no
///   better than chance and the correction is undefined.
/// * A bootstrap error when no resample produces a usable estimate.
pub fn estimate_success_rate(
    test_labels: &[u8],
    test_preds: &[u8],
    unlabeled_preds: &[u8],
    bootstrap_iterations: usize,
    confidence_level: f64,
) -> Result<EstimationResult> {
    validate_inputs(
        test_labels,
        test_preds,
        unlabeled_preds,
        confidence_level,
        bootstrap_iterations,
    )?;

    let judge_metrics = JudgeMetrics::from_test_data(test_labels, test_preds)?;
    let (tpr, tnr) = (judge_metrics.tpr, judge_metrics.tnr);

    // The correction divides by (tpr + tnr - 1), so the judge must beat chance.
    let sum_of_rates = tpr + tnr;
    if sum_of_rates <= 1.0 {
        return Err(JudgyError::JudgeAccuracyTooLow {
            tpr_plus_tnr: sum_of_rates,
        });
    }
    let informedness = sum_of_rates - 1.0;

    let passes: f64 = unlabeled_preds.iter().copied().map(f64::from).sum();
    let raw_pass_rate = passes / unlabeled_preds.len() as f64;

    let corrected = (raw_pass_rate + tnr - 1.0) / informedness;
    let theta_hat = corrected.clamp(0.0, 1.0);

    let (lower_bound, upper_bound) = bootstrap_confidence_interval(
        test_labels,
        test_preds,
        raw_pass_rate,
        bootstrap_iterations,
        confidence_level,
    )?;

    Ok(EstimationResult {
        theta_hat,
        lower_bound,
        upper_bound,
        confidence_level,
        bootstrap_iterations,
        judge_metrics,
        raw_pass_rate,
    })
}
/// Checks every precondition of `estimate_success_rate` before any work is done.
///
/// Rules are checked in a fixed order and the first failure is reported:
/// 1. `test_labels` and `test_preds` have equal length;
/// 2. `test_labels` is non-empty;
/// 3. `unlabeled_preds` is non-empty;
/// 4. `confidence_level` lies strictly inside `(0, 1)` (NaN is rejected);
/// 5. `bootstrap_iterations` is non-zero;
/// 6. every entry of `test_labels`, then `test_preds`, then `unlabeled_preds`
///    is `0` or `1`.
///
/// # Errors
///
/// Returns an input-validation error describing the first violated rule; for
/// non-binary data the message names the slice, the index and the bad value.
fn validate_inputs(
    test_labels: &[u8],
    test_preds: &[u8],
    unlabeled_preds: &[u8],
    confidence_level: f64,
    bootstrap_iterations: usize,
) -> Result<()> {
    if test_labels.len() != test_preds.len() {
        return Err(JudgyError::input_validation(format!(
            "test_labels and test_preds must have the same length. Got {} and {}",
            test_labels.len(),
            test_preds.len()
        )));
    }

    // Scalar preconditions in evaluation order; the first failing one wins.
    let confidence_in_range = confidence_level > 0.0 && confidence_level < 1.0;
    let scalar_checks: [(bool, &str); 4] = [
        (test_labels.is_empty(), "test_labels cannot be empty"),
        (unlabeled_preds.is_empty(), "unlabeled_preds cannot be empty"),
        (
            !confidence_in_range,
            "confidence_level must be between 0 and 1 (exclusive)",
        ),
        (
            bootstrap_iterations == 0,
            "bootstrap_iterations must be positive",
        ),
    ];
    if let Some(&(_, message)) = scalar_checks.iter().find(|&&(failed, _)| failed) {
        return Err(JudgyError::input_validation(message.to_string()));
    }

    // Index and value of the first entry that is neither 0 nor 1, if any.
    let first_non_binary =
        |values: &[u8]| values.iter().copied().enumerate().find(|&(_, v)| v > 1);

    if let Some((i, val)) = first_non_binary(test_labels) {
        return Err(JudgyError::input_validation(format!(
            "test_labels[{}] = {} is not 0 or 1",
            i, val
        )));
    }
    if let Some((i, val)) = first_non_binary(test_preds) {
        return Err(JudgyError::input_validation(format!(
            "test_preds[{}] = {} is not 0 or 1",
            i, val
        )));
    }
    if let Some((i, val)) = first_non_binary(unlabeled_preds) {
        return Err(JudgyError::input_validation(format!(
            "unlabeled_preds[{}] = {} is not 0 or 1",
            i, val
        )));
    }

    Ok(())
}
fn bootstrap_confidence_interval(
test_labels: &[u8],
test_preds: &[u8],
observed_pass_rate: f64,
bootstrap_iterations: usize,
confidence_level: f64,
) -> Result<(f64, f64)> {
let mut rng = StdRng::from_entropy();
let test_size = test_labels.len();
let indices: Vec<usize> = (0..test_size).collect();
let mut bootstrap_samples = Vec::new();
for _ in 0..bootstrap_iterations {
let bootstrap_indices: Vec<usize> = indices
.choose_multiple(&mut rng, test_size)
.cloned()
.collect();
let bootstrap_labels: Vec<u8> = bootstrap_indices.iter().map(|&i| test_labels[i]).collect();
let bootstrap_preds: Vec<u8> = bootstrap_indices.iter().map(|&i| test_preds[i]).collect();
match JudgeMetrics::from_test_data(&bootstrap_labels, &bootstrap_preds) {
Ok(metrics) => {
let tpr_plus_tnr = metrics.tpr + metrics.tnr;
if tpr_plus_tnr > 1.0 {
let denominator = tpr_plus_tnr - 1.0;
let theta_bootstrap =
((observed_pass_rate + metrics.tnr - 1.0) / denominator).clamp(0.0, 1.0);
bootstrap_samples.push(theta_bootstrap);
}
}
Err(_) => {
continue;
}
}
}
if bootstrap_samples.is_empty() {
return Err(JudgyError::bootstrap(
"No valid bootstrap samples generated. This may indicate insufficient test data or very poor judge performance.".to_string(),
));
}
bootstrap_samples.sort_by(|a, b| a.partial_cmp(b).unwrap());
let alpha = 1.0 - confidence_level;
let lower_percentile = alpha / 2.0;
let upper_percentile = 1.0 - alpha / 2.0;
let lower_idx = ((bootstrap_samples.len() as f64 * lower_percentile) as usize)
.min(bootstrap_samples.len() - 1);
let upper_idx = ((bootstrap_samples.len() as f64 * upper_percentile) as usize)
.min(bootstrap_samples.len() - 1);
let lower_bound = bootstrap_samples[lower_idx];
let upper_bound = bootstrap_samples[upper_idx];
Ok((lower_bound, upper_bound))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_judge_metrics_perfect_judge() {
        let test_labels = vec![1, 1, 1, 0, 0, 0];
        let test_preds = vec![1, 1, 1, 0, 0, 0];
        let metrics = JudgeMetrics::from_test_data(&test_labels, &test_preds).unwrap();
        assert_relative_eq!(metrics.tpr, 1.0);
        assert_relative_eq!(metrics.tnr, 1.0);
        assert_relative_eq!(metrics.fpr, 0.0);
        assert_relative_eq!(metrics.fnr, 0.0);
        assert_relative_eq!(metrics.accuracy, 1.0);
    }

    /// A judge that always answers the opposite of the label (not a random one).
    #[test]
    fn test_judge_metrics_inverted_judge() {
        let test_labels = vec![1, 1, 1, 1, 0, 0, 0, 0];
        let test_preds = vec![0, 0, 0, 0, 1, 1, 1, 1];
        let metrics = JudgeMetrics::from_test_data(&test_labels, &test_preds).unwrap();
        assert_relative_eq!(metrics.tpr, 0.0);
        assert_relative_eq!(metrics.tnr, 0.0);
        assert_relative_eq!(metrics.fpr, 1.0);
        assert_relative_eq!(metrics.fnr, 1.0);
        assert_relative_eq!(metrics.accuracy, 0.0);
    }

    #[test]
    fn test_judge_metrics_mixed_judge() {
        let test_labels = vec![1, 1, 0, 0, 1, 0, 1, 0];
        let test_preds = vec![1, 0, 0, 1, 1, 0, 1, 0];
        let metrics = JudgeMetrics::from_test_data(&test_labels, &test_preds).unwrap();
        assert_relative_eq!(metrics.tpr, 0.75);
        assert_relative_eq!(metrics.tnr, 0.75);
        assert_relative_eq!(metrics.fpr, 0.25);
        assert_relative_eq!(metrics.fnr, 0.25);
        assert_relative_eq!(metrics.accuracy, 0.75);
    }

    #[test]
    fn test_judge_metrics_rejects_bad_input() {
        assert!(matches!(
            JudgeMetrics::from_test_data(&[1, 0], &[1]),
            Err(JudgyError::InputValidation(_))
        ));
        assert!(matches!(
            JudgeMetrics::from_test_data(&[1, 2], &[1, 0]),
            Err(JudgyError::InputValidation(_))
        ));
        assert!(matches!(
            JudgeMetrics::from_test_data(&[], &[]),
            Err(JudgyError::InputValidation(_))
        ));
    }

    #[test]
    fn test_estimate_success_rate_basic() {
        let test_labels = vec![1, 1, 0, 0, 1, 0, 1, 0];
        let test_preds = vec![1, 0, 0, 1, 1, 0, 1, 0];
        let unlabeled_preds = vec![1, 1, 0, 1, 0, 1, 0, 1, 1, 0];
        let result =
            estimate_success_rate(&test_labels, &test_preds, &unlabeled_preds, 100, 0.95).unwrap();
        // tpr = tnr = 0.75 and p_obs = 0.6, so theta = (0.6 + 0.75 - 1) / 0.5 = 0.7.
        assert_relative_eq!(result.raw_pass_rate, 0.6, epsilon = 1e-12);
        assert_relative_eq!(result.theta_hat, 0.7, epsilon = 1e-12);
        assert!(result.theta_hat >= 0.0 && result.theta_hat <= 1.0);
        assert!(result.lower_bound >= 0.0 && result.lower_bound <= 1.0);
        assert!(result.upper_bound >= 0.0 && result.upper_bound <= 1.0);
        assert!(result.lower_bound <= result.theta_hat);
        assert!(result.theta_hat <= result.upper_bound);
        assert_relative_eq!(result.confidence_level, 0.95);
        assert_eq!(result.bootstrap_iterations, 100);
    }

    #[test]
    fn test_estimate_success_rate_perfect_judge() {
        let test_labels = vec![1, 1, 1, 0, 0, 0];
        let test_preds = vec![1, 1, 1, 0, 0, 0];
        let unlabeled_preds = vec![1, 1, 0, 0];
        let result =
            estimate_success_rate(&test_labels, &test_preds, &unlabeled_preds, 100, 0.95).unwrap();
        // A perfect judge needs no correction: every estimate equals the raw rate.
        assert_relative_eq!(result.raw_pass_rate, 0.5);
        assert_relative_eq!(result.theta_hat, 0.5);
        assert_relative_eq!(result.lower_bound, 0.5);
        assert_relative_eq!(result.upper_bound, 0.5);
    }

    #[test]
    fn test_input_validation_mismatched_lengths() {
        let test_labels = vec![1, 0];
        let test_preds = vec![1, 0, 1];
        let unlabeled_preds = vec![1, 0, 1];
        let result = estimate_success_rate(&test_labels, &test_preds, &unlabeled_preds, 100, 0.95);
        assert!(matches!(result, Err(JudgyError::InputValidation(_))));
    }

    #[test]
    fn test_input_validation_empty_arrays() {
        let empty: Vec<u8> = vec![];
        let valid = vec![1, 0, 1, 0];
        let result = estimate_success_rate(&empty, &empty, &valid, 100, 0.95);
        assert!(matches!(result, Err(JudgyError::InputValidation(_))));
        let result = estimate_success_rate(&valid, &valid, &empty, 100, 0.95);
        assert!(matches!(result, Err(JudgyError::InputValidation(_))));
    }

    #[test]
    fn test_input_validation_non_binary() {
        let test_labels = vec![1, 2, 0, 1];
        let test_preds = vec![1, 0, 1, 1];
        let unlabeled_preds = vec![1, 0, 1];
        let result = estimate_success_rate(&test_labels, &test_preds, &unlabeled_preds, 100, 0.95);
        assert!(matches!(result, Err(JudgyError::InputValidation(_))));
    }

    #[test]
    fn test_input_validation_invalid_confidence_level() {
        let test_labels = vec![1, 0, 1, 0];
        let test_preds = vec![1, 0, 1, 1];
        let unlabeled_preds = vec![1, 0, 1];
        let result = estimate_success_rate(&test_labels, &test_preds, &unlabeled_preds, 100, 1.5);
        assert!(matches!(result, Err(JudgyError::InputValidation(_))));
        let result = estimate_success_rate(&test_labels, &test_preds, &unlabeled_preds, 100, 0.0);
        assert!(matches!(result, Err(JudgyError::InputValidation(_))));
        let result =
            estimate_success_rate(&test_labels, &test_preds, &unlabeled_preds, 100, f64::NAN);
        assert!(matches!(result, Err(JudgyError::InputValidation(_))));
    }

    #[test]
    fn test_input_validation_zero_bootstrap_iterations() {
        let test_labels = vec![1, 0, 1, 0];
        let test_preds = vec![1, 0, 1, 0];
        let unlabeled_preds = vec![1, 0, 1];
        let result = estimate_success_rate(&test_labels, &test_preds, &unlabeled_preds, 0, 0.95);
        assert!(matches!(result, Err(JudgyError::InputValidation(_))));
    }

    #[test]
    fn test_judge_accuracy_too_low() {
        let test_labels = vec![1, 1, 1, 1, 0, 0, 0, 0];
        let test_preds = vec![0, 0, 0, 0, 1, 1, 1, 1];
        let unlabeled_preds = vec![1, 0, 1, 0];
        let result = estimate_success_rate(&test_labels, &test_preds, &unlabeled_preds, 100, 0.95);
        assert!(matches!(
            result,
            Err(JudgyError::JudgeAccuracyTooLow { .. })
        ));
    }

    #[test]
    fn test_no_positive_examples() {
        let test_labels = vec![0, 0, 0];
        let test_preds = vec![1, 0, 1];
        let unlabeled_preds = vec![1, 0, 1];
        let result = estimate_success_rate(&test_labels, &test_preds, &unlabeled_preds, 100, 0.95);
        assert!(matches!(result, Err(JudgyError::InputValidation(_))));
    }

    #[test]
    fn test_no_negative_examples() {
        let test_labels = vec![1, 1, 1];
        let test_preds = vec![1, 0, 1];
        let unlabeled_preds = vec![1, 0, 1];
        let result = estimate_success_rate(&test_labels, &test_preds, &unlabeled_preds, 100, 0.95);
        assert!(matches!(result, Err(JudgyError::InputValidation(_))));
    }

    #[test]
    fn test_different_confidence_levels() {
        let test_labels = vec![1, 1, 0, 0, 1, 0, 1, 0];
        let test_preds = vec![1, 0, 0, 1, 1, 0, 1, 0];
        let unlabeled_preds = vec![1, 1, 0, 1, 0, 1, 1, 0, 1, 0];
        let result_90 =
            estimate_success_rate(&test_labels, &test_preds, &unlabeled_preds, 200, 0.90).unwrap();
        let result_99 =
            estimate_success_rate(&test_labels, &test_preds, &unlabeled_preds, 200, 0.99).unwrap();
        // The point estimate is deterministic and independent of the confidence level.
        assert_relative_eq!(result_90.theta_hat, result_99.theta_hat);
        let width_90 = result_90.upper_bound - result_90.lower_bound;
        let width_99 = result_99.upper_bound - result_99.lower_bound;
        assert!(width_90 >= 0.0);
        assert!(width_99 >= 0.0);
    }
}