use crate::bias_correction::estimate_success_rate;
use crate::error::{JudgyError, Result};
use rand::distributions::{Bernoulli, Distribution};
use rand::prelude::*;
use serde::{Deserialize, Serialize};
/// Parameters for simulating a labeled test set of binary judge predictions.
///
/// Consumed by [`generate_test_data`], which emits `n_positive` examples labeled `1`
/// followed by `n_negative` examples labeled `0`, drawing each prediction independently.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SyntheticConfig {
/// Number of examples whose true label is `1`; `generate_test_data` rejects `0`.
pub n_positive: usize,
/// Number of examples whose true label is `0`; `generate_test_data` rejects `0`.
pub n_negative: usize,
/// Probability that the judge predicts `1` on a true-`1` example; must lie in `[0, 1]`.
pub true_positive_rate: f64,
/// Probability that the judge predicts `0` on a true-`0` example; must lie in `[0, 1]`.
pub true_negative_rate: f64,
/// RNG seed for reproducible output; `None` seeds the RNG from entropy.
pub random_seed: Option<u64>,
}
impl Default for SyntheticConfig {
fn default() -> Self {
Self {
n_positive: 50,
n_negative: 50,
true_positive_rate: 0.8,
true_negative_rate: 0.85,
random_seed: None,
}
}
}
/// Output of [`run_sensitivity_experiment`]: one entry per swept value, aligned by index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SensitivityResult {
/// Swept TPR (or TNR) values, evenly spaced over `config.test_range`.
pub values: Vec<f64>,
/// Point estimate (`theta_hat`) from `estimate_success_rate`; NaN where estimation failed.
pub estimates: Vec<f64>,
/// Lower bound reported by `estimate_success_rate` (called with `0.95`); NaN on failure.
pub lower_bounds: Vec<f64>,
/// Upper bound reported by `estimate_success_rate` (called with `0.95`); NaN on failure.
pub upper_bounds: Vec<f64>,
/// Uncorrected rate: fraction of simulated unlabeled predictions equal to `1`.
pub raw_rates: Vec<f64>,
/// Copy of the configuration that produced this result.
pub config: SensitivityConfig,
}
/// Settings for [`run_sensitivity_experiment`], which sweeps one judge accuracy rate
/// while holding the other fixed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SensitivityConfig {
/// Pass rate used to draw the hidden true labels of each simulated unlabeled set.
pub true_pass_rate: f64,
/// `(min, max)` interval that is swept; `min` must be strictly less than `max`.
pub test_range: (f64, f64),
/// Value held constant for whichever rate (TPR or TNR) is not being swept.
pub fixed_value: f64,
/// `true` sweeps the true positive rate; `false` sweeps the true negative rate.
pub vary_tpr: bool,
/// Number of sweep points; with more than one point both ends of `test_range` are
/// included, and a single point uses `min`.
pub n_points: usize,
/// Template for the labeled test set. Its rates and seed are overridden at every sweep
/// point, so only its sample sizes take effect.
pub test_config: SyntheticConfig,
/// Number of unlabeled predictions simulated per sweep point.
pub n_unlabeled: usize,
/// Bootstrap iteration count forwarded to `estimate_success_rate`.
pub bootstrap_iterations: usize,
/// Master seed from which per-point generator seeds are drawn; `None` uses entropy.
pub random_seed: Option<u64>,
}
impl Default for SensitivityConfig {
fn default() -> Self {
Self {
true_pass_rate: 0.8,
test_range: (0.5, 1.0),
fixed_value: 0.85,
vary_tpr: true,
n_points: 10,
test_config: SyntheticConfig::default(),
n_unlabeled: 1000,
bootstrap_iterations: 2000,
random_seed: None,
}
}
}
/// Simulates a labeled test set together with the judge's predictions on it.
///
/// Both returned vectors have length `n_positive + n_negative`. The first `n_positive`
/// labels are `1` and the rest are `0`. Each prediction is drawn independently:
/// - a positive is predicted `1` with probability `true_positive_rate`;
/// - a negative is predicted `0` with probability `true_negative_rate`.
///
/// # Errors
/// Returns an input-validation error if either rate lies outside `[0, 1]` or if either
/// class count is zero.
pub fn generate_test_data(config: &SyntheticConfig) -> Result<(Vec<u8>, Vec<u8>)> {
    validate_rates(config.true_positive_rate, config.true_negative_rate)?;
    if config.n_positive == 0 || config.n_negative == 0 {
        return Err(JudgyError::input_validation(
            "n_positive and n_negative must be positive".to_string(),
        ));
    }

    let mut rng = config
        .random_seed
        .map_or_else(StdRng::from_entropy, StdRng::seed_from_u64);

    // Labels: every positive first, then every negative.
    let mut labels = vec![1u8; config.n_positive];
    labels.resize(config.n_positive + config.n_negative, 0u8);

    let tpr_coin = Bernoulli::new(config.true_positive_rate)
        .map_err(|e| JudgyError::config(format!("Invalid TPR: {}", e)))?;
    let tnr_coin = Bernoulli::new(config.true_negative_rate)
        .map_err(|e| JudgyError::config(format!("Invalid TNR: {}", e)))?;

    // A correct call on a positive is `1`; a correct call on a negative is `0`.
    let mut predictions: Vec<u8> = (0..config.n_positive)
        .map(|_| u8::from(tpr_coin.sample(&mut rng)))
        .collect();
    predictions.extend((0..config.n_negative).map(|_| u8::from(!tnr_coin.sample(&mut rng))));

    Ok((labels, predictions))
}
/// Simulates judge predictions on an unlabeled population.
///
/// First draws `n_samples` hidden true labels, each passing with probability
/// `true_pass_rate`. Then draws one prediction per label using the given TPR/TNR. Only the
/// predictions (`0`/`1`) are returned; the hidden labels are discarded.
///
/// # Errors
/// Returns an input-validation error if any rate lies outside `[0, 1]` or `n_samples` is 0.
pub fn generate_unlabeled_data(
    n_samples: usize,
    true_pass_rate: f64,
    true_positive_rate: f64,
    true_negative_rate: f64,
    random_seed: Option<u64>,
) -> Result<Vec<u8>> {
    let in_unit_interval = |x: f64| (0.0..=1.0).contains(&x);
    if !in_unit_interval(true_pass_rate) {
        return Err(JudgyError::input_validation(
            "true_pass_rate must be between 0 and 1".to_string(),
        ));
    }
    validate_rates(true_positive_rate, true_negative_rate)?;
    if n_samples == 0 {
        return Err(JudgyError::input_validation(
            "n_samples must be positive".to_string(),
        ));
    }

    let mut rng = random_seed.map_or_else(StdRng::from_entropy, StdRng::seed_from_u64);

    let pass_coin = Bernoulli::new(true_pass_rate)
        .map_err(|e| JudgyError::config(format!("Invalid pass rate: {}", e)))?;

    // All hidden labels are drawn before any prediction; this fixes which RNG draws feed
    // which step for a given seed.
    let passes: Vec<bool> = (0..n_samples).map(|_| pass_coin.sample(&mut rng)).collect();

    let tpr_coin = Bernoulli::new(true_positive_rate)
        .map_err(|e| JudgyError::config(format!("Invalid TPR: {}", e)))?;
    let tnr_coin = Bernoulli::new(true_negative_rate)
        .map_err(|e| JudgyError::config(format!("Invalid TNR: {}", e)))?;

    let predictions: Vec<u8> = passes
        .into_iter()
        .map(|passed| {
            let judged_pass = if passed {
                tpr_coin.sample(&mut rng)
            } else {
                !tnr_coin.sample(&mut rng)
            };
            u8::from(judged_pass)
        })
        .collect();
    Ok(predictions)
}
/// Sweeps the judge's TPR (or TNR) across `config.test_range` and records how the
/// bias-corrected estimate responds.
///
/// For each of `config.n_points` evenly spaced values, a fresh labeled test set and
/// unlabeled set are simulated, each with its own seed drawn from a master RNG. The
/// estimator `estimate_success_rate` is then run on them. When `config.n_points > 1` the
/// sweep includes both ends of `test_range` exactly.
///
/// # Errors
/// - A config error if `test_range` is not a strictly increasing pair (NaN bounds included).
/// - Any validation error from [`generate_test_data`] or [`generate_unlabeled_data`], e.g. when
///   a swept or fixed rate falls outside `[0, 1]`.
///
/// Failures of `estimate_success_rate` itself are not propagated. The matching entries of
/// `estimates`, `lower_bounds` and `upper_bounds` are set to NaN and the sweep continues.
pub fn run_sensitivity_experiment(config: &SensitivityConfig) -> Result<SensitivityResult> {
    let (min_val, max_val) = config.test_range;
    // `min_val >= max_val` is false when either bound is NaN, so NaN is rejected explicitly.
    if min_val.is_nan() || max_val.is_nan() || min_val >= max_val {
        return Err(JudgyError::config(
            "test_range min must be less than max".to_string(),
        ));
    }

    let mut rng = match config.random_seed {
        Some(seed) => StdRng::seed_from_u64(seed),
        None => StdRng::from_entropy(),
    };

    let n_points = config.n_points;
    let values: Vec<f64> = (0..n_points)
        .map(|i| {
            if n_points == 1 {
                min_val
            } else if i == n_points - 1 {
                // Pin the endpoint. The interpolation is not guaranteed to round to exactly
                // `max_val`, and a value such as 1.0000000000000002 would be rejected downstream.
                max_val
            } else {
                min_val + (max_val - min_val) * i as f64 / (n_points - 1) as f64
            }
        })
        .collect();

    let mut estimates = Vec::with_capacity(n_points);
    let mut lower_bounds = Vec::with_capacity(n_points);
    let mut upper_bounds = Vec::with_capacity(n_points);
    let mut raw_rates = Vec::with_capacity(n_points);

    for &accuracy_val in &values {
        let (tpr, tnr) = if config.vary_tpr {
            (accuracy_val, config.fixed_value)
        } else {
            (config.fixed_value, accuracy_val)
        };
        // Each generator gets its own seed drawn from the master RNG. When
        // `config.random_seed` is set, the whole sweep is therefore reproducible.
        let test_config = SyntheticConfig {
            true_positive_rate: tpr,
            true_negative_rate: tnr,
            random_seed: Some(rng.gen()),
            ..config.test_config.clone()
        };
        let (test_labels, test_preds) = generate_test_data(&test_config)?;
        let unlabeled_preds = generate_unlabeled_data(
            config.n_unlabeled,
            config.true_pass_rate,
            tpr,
            tnr,
            Some(rng.gen()),
        )?;

        // Uncorrected estimate: fraction of unlabeled items the judge marked as `1`.
        let raw_success_rate = unlabeled_preds.iter().map(|&x| f64::from(x)).sum::<f64>()
            / unlabeled_preds.len() as f64;
        raw_rates.push(raw_success_rate);

        // Estimation failures become NaN rather than aborting the whole sweep.
        let (estimate, lower, upper) = match estimate_success_rate(
            &test_labels,
            &test_preds,
            &unlabeled_preds,
            config.bootstrap_iterations,
            0.95, // presumably the confidence level of the reported bounds
        ) {
            Ok(result) => (result.theta_hat, result.lower_bound, result.upper_bound),
            Err(_) => (f64::NAN, f64::NAN, f64::NAN),
        };
        estimates.push(estimate);
        lower_bounds.push(lower);
        upper_bounds.push(upper);
    }

    Ok(SensitivityResult {
        values,
        estimates,
        lower_bounds,
        upper_bounds,
        raw_rates,
        config: config.clone(),
    })
}
/// Builds a ready-made dataset for a named judge scenario.
///
/// Returns `(test_labels, test_preds, unlabeled_preds)`. The labeled test set always has
/// 50 positives and 50 negatives, and the unlabeled set has 500 predictions. Scenarios as
/// `(TPR, TNR, true pass rate)`:
/// - `"good_judge"`: `(0.95, 0.90, 0.8)`
/// - `"mediocre_judge"`: `(0.75, 0.70, 0.6)`
/// - `"biased_judge"`: `(0.90, 0.60, 0.7)`
/// - `"poor_judge"`: `(0.60, 0.55, 0.5)`
///
/// With `Some(seed)`, the test set is generated from `seed`. The unlabeled set uses a
/// different seed derived from it, so the two samples do not share a random stream.
///
/// # Errors
/// Returns a config error for an unknown scenario name.
pub fn create_example_dataset(
    scenario: &str,
    random_seed: Option<u64>,
) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>)> {
    // Nonzero constant (the 64-bit golden-ratio value) XORed into the caller's seed. Seeding
    // both generators with the same value would make the unlabeled labels and the test-set
    // predictions consume identical uniform draws, correlating the two "independent" samples.
    const UNLABELED_SEED_SALT: u64 = 0x9E37_79B9_7F4A_7C15;

    let (tpr, tnr, true_rate) = match scenario {
        "good_judge" => (0.95, 0.90, 0.8),
        "mediocre_judge" => (0.75, 0.70, 0.6),
        "biased_judge" => (0.90, 0.60, 0.7),
        "poor_judge" => (0.60, 0.55, 0.5),
        _ => {
            return Err(JudgyError::config(format!(
                "Unknown scenario '{}'. Choose from: good_judge, mediocre_judge, biased_judge, poor_judge",
                scenario
            )));
        }
    };

    let test_config = SyntheticConfig {
        n_positive: 50,
        n_negative: 50,
        true_positive_rate: tpr,
        true_negative_rate: tnr,
        random_seed,
    };
    let (test_labels, test_preds) = generate_test_data(&test_config)?;

    let unlabeled_seed = random_seed.map(|seed| seed ^ UNLABELED_SEED_SALT);
    let unlabeled_preds = generate_unlabeled_data(500, true_rate, tpr, tnr, unlabeled_seed)?;

    Ok((test_labels, test_preds, unlabeled_preds))
}
/// Checks that both judge accuracy rates lie in the closed interval `[0, 1]`.
///
/// TPR is checked before TNR, so when both are invalid the TPR error is reported. NaN
/// is rejected because it is not contained in the interval.
fn validate_rates(tpr: f64, tnr: f64) -> Result<()> {
    let checks = [
        (tpr, "true_positive_rate must be between 0 and 1"),
        (tnr, "true_negative_rate must be between 0 and 1"),
    ];
    match checks.iter().find(|(rate, _)| !(0.0..=1.0).contains(rate)) {
        Some((_, message)) => Err(JudgyError::input_validation((*message).to_string())),
        None => Ok(()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Observed `(TPR, TNR)` of `preds` against binary `labels`.
    ///
    /// A rate is NaN (0/0) when its class does not occur in `labels`.
    fn observed_rates(labels: &[u8], preds: &[u8]) -> (f64, f64) {
        let (mut tp, mut fn_count, mut tn, mut fp) = (0u32, 0u32, 0u32, 0u32);
        for (&label, &pred) in labels.iter().zip(preds) {
            match (label, pred) {
                (1, 1) => tp += 1,
                (1, 0) => fn_count += 1,
                (0, 0) => tn += 1,
                (0, 1) => fp += 1,
                _ => {}
            }
        }
        (
            f64::from(tp) / f64::from(tp + fn_count),
            f64::from(tn) / f64::from(tn + fp),
        )
    }

    #[test]
    fn test_generate_test_data_basic() {
        let config = SyntheticConfig {
            n_positive: 10,
            n_negative: 5,
            true_positive_rate: 0.8,
            true_negative_rate: 0.9,
            random_seed: Some(42),
        };
        let (test_labels, test_preds) = generate_test_data(&config).unwrap();
        assert_eq!(test_labels.len(), 15);
        assert_eq!(test_preds.len(), 15);
        let positive_count = test_labels.iter().filter(|&&x| x == 1).count();
        let negative_count = test_labels.iter().filter(|&&x| x == 0).count();
        assert_eq!(positive_count, 10);
        assert_eq!(negative_count, 5);
        assert!(test_labels.iter().all(|&x| x == 0 || x == 1));
        assert!(test_preds.iter().all(|&x| x == 0 || x == 1));
    }

    #[test]
    fn test_generate_test_data_perfect_accuracy() {
        let config = SyntheticConfig {
            n_positive: 20,
            n_negative: 20,
            true_positive_rate: 1.0,
            true_negative_rate: 1.0,
            random_seed: Some(42),
        };
        let (test_labels, test_preds) = generate_test_data(&config).unwrap();
        assert_eq!(test_labels, test_preds);
    }

    #[test]
    fn test_generate_test_data_zero_accuracy() {
        let config = SyntheticConfig {
            n_positive: 10,
            n_negative: 10,
            true_positive_rate: 0.0,
            true_negative_rate: 0.0,
            random_seed: Some(42),
        };
        let (test_labels, test_preds) = generate_test_data(&config).unwrap();
        for (label, pred) in test_labels.iter().zip(test_preds.iter()) {
            assert_eq!(*pred, 1 - *label);
        }
    }

    #[test]
    fn test_generate_test_data_reproducibility() {
        let config = SyntheticConfig {
            n_positive: 5,
            n_negative: 5,
            true_positive_rate: 0.7,
            true_negative_rate: 0.8,
            random_seed: Some(123),
        };
        let result1 = generate_test_data(&config).unwrap();
        let result2 = generate_test_data(&config).unwrap();
        assert_eq!(result1.0, result2.0);
        assert_eq!(result1.1, result2.1);
    }

    #[test]
    fn test_generate_test_data_input_validation() {
        let valid = SyntheticConfig::default();
        let invalid = [
            SyntheticConfig { true_positive_rate: -0.1, ..valid.clone() },
            SyntheticConfig { true_positive_rate: 1.1, ..valid.clone() },
            SyntheticConfig { true_positive_rate: f64::NAN, ..valid.clone() },
            SyntheticConfig { true_negative_rate: -0.1, ..valid.clone() },
            SyntheticConfig { true_negative_rate: 1.1, ..valid.clone() },
            SyntheticConfig { n_positive: 0, ..valid.clone() },
            SyntheticConfig { n_negative: 0, ..valid.clone() },
        ];
        for config in &invalid {
            assert!(
                generate_test_data(config).is_err(),
                "expected an error for {:?}",
                config
            );
        }
        assert!(generate_test_data(&valid).is_ok());
    }

    #[test]
    fn test_generate_unlabeled_data_basic() {
        let unlabeled_preds = generate_unlabeled_data(100, 0.6, 0.8, 0.9, Some(42)).unwrap();
        assert_eq!(unlabeled_preds.len(), 100);
        assert!(unlabeled_preds.iter().all(|&x| x == 0 || x == 1));
    }

    #[test]
    fn test_generate_unlabeled_data_extreme_pass_rates() {
        let unlabeled_preds = generate_unlabeled_data(50, 1.0, 1.0, 1.0, Some(42)).unwrap();
        assert!(unlabeled_preds.iter().all(|&x| x == 1));
        let unlabeled_preds = generate_unlabeled_data(50, 0.0, 1.0, 1.0, Some(42)).unwrap();
        assert!(unlabeled_preds.iter().all(|&x| x == 0));
    }

    #[test]
    fn test_generate_unlabeled_data_input_validation() {
        assert!(generate_unlabeled_data(10, -0.1, 0.8, 0.8, None).is_err());
        assert!(generate_unlabeled_data(10, 1.1, 0.8, 0.8, None).is_err());
        assert!(generate_unlabeled_data(10, 0.5, -0.1, 0.8, None).is_err());
        assert!(generate_unlabeled_data(10, 0.5, 1.1, 0.8, None).is_err());
        assert!(generate_unlabeled_data(10, 0.5, 0.8, -0.1, None).is_err());
        assert!(generate_unlabeled_data(10, 0.5, 0.8, 1.1, None).is_err());
        assert!(generate_unlabeled_data(0, 0.5, 0.8, 0.8, None).is_err());
    }

    #[test]
    fn test_run_sensitivity_experiment_tpr() {
        let config = SensitivityConfig {
            test_range: (0.6, 0.8),
            n_points: 3,
            vary_tpr: true,
            fixed_value: 0.9,
            bootstrap_iterations: 50,
            random_seed: Some(42),
            ..Default::default()
        };
        let result = run_sensitivity_experiment(&config).unwrap();
        assert_eq!(result.values.len(), 3);
        assert_eq!(result.estimates.len(), 3);
        assert_eq!(result.lower_bounds.len(), 3);
        assert_eq!(result.upper_bounds.len(), 3);
        assert_eq!(result.raw_rates.len(), 3);
        assert_relative_eq!(result.values[0], 0.6, epsilon = 1e-10);
        assert_relative_eq!(result.values[2], 0.8, epsilon = 1e-10);
        for &estimate in result.estimates.iter().filter(|e| !e.is_nan()) {
            assert!((0.0..=1.0).contains(&estimate));
        }
    }

    #[test]
    fn test_run_sensitivity_experiment_tnr() {
        let config = SensitivityConfig {
            test_range: (0.6, 0.8),
            n_points: 3,
            vary_tpr: false,
            fixed_value: 0.9,
            bootstrap_iterations: 50,
            random_seed: Some(42),
            ..Default::default()
        };
        let result = run_sensitivity_experiment(&config).unwrap();
        assert_eq!(result.values.len(), 3);
        assert_eq!(result.estimates.len(), 3);
        assert_eq!(result.lower_bounds.len(), 3);
        assert_eq!(result.upper_bounds.len(), 3);
    }

    #[test]
    fn test_run_sensitivity_experiment_invalid_range() {
        for &range in &[(0.8, 0.6), (0.7, 0.7)] {
            let config = SensitivityConfig {
                test_range: range,
                random_seed: Some(42),
                ..Default::default()
            };
            assert!(
                matches!(run_sensitivity_experiment(&config), Err(JudgyError::Config(_))),
                "expected a config error for range {:?}",
                range
            );
        }
    }

    #[test]
    fn test_create_example_dataset_all_scenarios() {
        let scenarios = ["good_judge", "mediocre_judge", "biased_judge", "poor_judge"];
        for scenario in &scenarios {
            let (test_labels, test_preds, unlabeled_preds) =
                create_example_dataset(scenario, Some(42)).unwrap();
            assert_eq!(test_labels.len(), 100);
            assert_eq!(test_preds.len(), 100);
            assert_eq!(unlabeled_preds.len(), 500);
            assert!(test_labels.iter().all(|&x| x == 0 || x == 1));
            assert!(test_preds.iter().all(|&x| x == 0 || x == 1));
            assert!(unlabeled_preds.iter().all(|&x| x == 0 || x == 1));
            let positive_count = test_labels.iter().filter(|&&x| x == 1).count();
            assert_eq!(positive_count, 50);
        }
    }

    #[test]
    fn test_create_example_dataset_reproducibility() {
        let result1 = create_example_dataset("good_judge", Some(123)).unwrap();
        let result2 = create_example_dataset("good_judge", Some(123)).unwrap();
        assert_eq!(result1.0, result2.0);
        assert_eq!(result1.1, result2.1);
        assert_eq!(result1.2, result2.2);
    }

    #[test]
    fn test_create_example_dataset_different_scenarios_differ() {
        let good_result = create_example_dataset("good_judge", Some(42)).unwrap();
        let poor_result = create_example_dataset("poor_judge", Some(42)).unwrap();
        assert_ne!(good_result.1, poor_result.1);
    }

    #[test]
    fn test_create_example_dataset_invalid_scenario() {
        let result = create_example_dataset("invalid_scenario", Some(42));
        assert!(matches!(result, Err(JudgyError::Config(_))));
    }

    #[test]
    fn test_scenario_accuracy_properties() {
        let (labels, preds, _) = create_example_dataset("good_judge", Some(42)).unwrap();
        let (tpr, tnr) = observed_rates(&labels, &preds);
        assert!(tpr > 0.8);
        assert!(tnr > 0.8);

        let (labels, preds, _) = create_example_dataset("poor_judge", Some(42)).unwrap();
        let (tpr_poor, tnr_poor) = observed_rates(&labels, &preds);
        assert!(tpr_poor < tpr);
        assert!(tnr_poor < tnr);
    }
}