use crate::bias_correction::{estimate_success_rate, EstimationResult};
use crate::error::{JudgyError, Result};
use crate::synthetic::{run_sensitivity_experiment, SensitivityConfig, SyntheticConfig};
use crate::utils::{
format_float, format_percentage, load_binary_from_csv, parse_binary_string, parse_range,
validate_probability,
};
use crate::{DEFAULT_BOOTSTRAP_ITERATIONS, DEFAULT_CONFIDENCE_LEVEL};
use clap::{Parser, Subcommand};
use serde_json;
use std::fs::File;
use std::io::Write;
use std::path::PathBuf;
// Top-level command-line interface of the `llmjury` binary.
//
// Plain `//` comments are used deliberately here: clap's derive turns `///`
// doc comments into help text, and the help text of this command is set
// explicitly through the `about` / `long_about` attributes below.
#[derive(Parser)]
#[command(
    name = "llmjury",
    about = "A Rust CLI tool for estimating success rates when using LLM judges for evaluation",
    version = crate::VERSION,
    long_about = None
)]
pub struct Cli {
    // The subcommand selected on the command line.
    #[command(subcommand)]
    pub command: Commands,
}
// Subcommands understood by `llmjury`.
//
// The `///` comments on the variants are rendered by clap as each
// subcommand's description in `llmjury --help`.
#[derive(Subcommand)]
pub enum Commands {
    /// Estimate the bias-corrected success rate from labeled test data and unlabeled judge predictions
    Estimate(EstimateArgs),
    /// Run synthetic sensitivity experiments that sweep the judge's TPR and TNR
    SynthExperiment(SynthExperimentArgs),
}
// Arguments of the `estimate` subcommand.
//
// Each of the three binary inputs can be passed inline (comma-separated, e.g.
// `1,0,1,0`) or as a file path. clap rejects giving both forms of the same
// input. The `///` comments below become the `--help` text of each flag.
#[derive(Parser)]
pub struct EstimateArgs {
    /// Ground-truth binary labels of the labeled test set, comma-separated (e.g. "1,0,1,0")
    #[arg(long, conflicts_with = "test_labels_file")]
    pub test_labels: Option<String>,
    /// Judge predictions on the labeled test set, comma-separated (e.g. "1,0,0,1")
    #[arg(long, conflicts_with = "test_preds_file")]
    pub test_preds: Option<String>,
    /// Judge predictions on the unlabeled set, comma-separated (e.g. "1,1,0")
    #[arg(long, conflicts_with = "unlabeled_preds_file")]
    pub unlabeled_preds: Option<String>,
    /// CSV file with the ground-truth binary labels of the labeled test set
    #[arg(long, conflicts_with = "test_labels")]
    pub test_labels_file: Option<PathBuf>,
    /// CSV file with the judge predictions on the labeled test set
    #[arg(long, conflicts_with = "test_preds")]
    pub test_preds_file: Option<PathBuf>,
    /// CSV file with the judge predictions on the unlabeled set
    #[arg(long, conflicts_with = "unlabeled_preds")]
    pub unlabeled_preds_file: Option<PathBuf>,
    /// Number of bootstrap iterations used to build the confidence interval
    #[arg(long, default_value_t = DEFAULT_BOOTSTRAP_ITERATIONS)]
    pub bootstrap_iterations: usize,
    /// Confidence level of the interval, given as a probability (e.g. 0.95)
    #[arg(long, default_value_t = DEFAULT_CONFIDENCE_LEVEL)]
    pub confidence_level: f64,
    /// File to save the results in; results are printed to stdout when omitted
    #[arg(long)]
    pub output: Option<PathBuf>,
    /// Output format: text, json or csv
    #[arg(long, default_value = "text")]
    pub format: String,
}
// Arguments of the `synth-experiment` subcommand.
//
// The `///` comments below become the `--help` text of each flag.
#[derive(Parser)]
pub struct SynthExperimentArgs {
    /// Fraction of simulated items that truly fail (the true pass rate is 1 minus this value)
    #[arg(long, default_value_t = 0.1)]
    pub true_failure_rate: f64,
    /// Judge TPR values to sweep, as "min,max"; its midpoint is the fixed TPR during the TNR sweep
    #[arg(long, default_value = "0.5,1.0")]
    pub tpr_range: String,
    /// Judge TNR values to sweep, as "min,max"; its midpoint is the fixed TNR during the TPR sweep
    #[arg(long, default_value = "0.5,1.0")]
    pub tnr_range: String,
    /// Number of points sampled in each swept range
    #[arg(long, default_value_t = 10)]
    pub n_points: usize,
    /// Number of positive examples in the synthetic labeled test set
    #[arg(long, default_value_t = 100)]
    pub n_test_positive: usize,
    /// Number of negative examples in the synthetic labeled test set
    #[arg(long, default_value_t = 100)]
    pub n_test_negative: usize,
    /// Number of examples in the synthetic unlabeled set
    #[arg(long, default_value_t = 1000)]
    pub n_unlabeled: usize,
    /// Number of bootstrap iterations
    #[arg(long, default_value_t = 2000)]
    pub bootstrap_iterations: usize,
    /// Optional random seed
    #[arg(long)]
    pub seed: Option<u64>,
    /// Write the combined results as JSON to this file instead of printing a summary
    #[arg(long)]
    pub output: Option<PathBuf>,
}
impl EstimateArgs {
pub fn load_data(&self) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>)> {
let test_labels = if let Some(ref file) = self.test_labels_file {
load_binary_from_csv(file)?
} else if let Some(ref data) = self.test_labels {
parse_binary_string(data)?
} else {
return Err(JudgyError::input_validation(
"Must provide either --test-labels or --test-labels-file".to_string(),
));
};
let test_preds = if let Some(ref file) = self.test_preds_file {
load_binary_from_csv(file)?
} else if let Some(ref data) = self.test_preds {
parse_binary_string(data)?
} else {
return Err(JudgyError::input_validation(
"Must provide either --test-preds or --test-preds-file".to_string(),
));
};
let unlabeled_preds = if let Some(ref file) = self.unlabeled_preds_file {
load_binary_from_csv(file)?
} else if let Some(ref data) = self.unlabeled_preds {
parse_binary_string(data)?
} else {
return Err(JudgyError::input_validation(
"Must provide either --unlabeled-preds or --unlabeled-preds-file".to_string(),
));
};
Ok((test_labels, test_preds, unlabeled_preds))
}
pub fn validate(&self) -> Result<()> {
validate_probability(self.confidence_level, "confidence_level")?;
if self.bootstrap_iterations == 0 {
return Err(JudgyError::input_validation(
"bootstrap_iterations must be positive".to_string(),
));
}
match self.format.as_str() {
"text" | "json" | "csv" => Ok(()),
_ => Err(JudgyError::input_validation(format!(
"Invalid format '{}'. Must be one of: text, json, csv",
self.format
))),
}
}
}
impl SynthExperimentArgs {
    /// Validates the arguments and builds the two sensitivity configurations.
    ///
    /// The first configuration sweeps TPR across `--tpr-range`, with TNR held
    /// at the midpoint of `--tnr-range`. The second sweeps TNR across
    /// `--tnr-range`, with TPR held at the midpoint of `--tpr-range`.
    ///
    /// # Errors
    /// Returns an error if `true_failure_rate` or any range endpoint fails
    /// `validate_probability`, or if a range string cannot be parsed. Also
    /// returns an error if `n_points`, `n_test_positive`, `n_test_negative`,
    /// `n_unlabeled` or `bootstrap_iterations` is zero.
    pub fn create_config(&self) -> Result<(SensitivityConfig, SensitivityConfig)> {
        validate_probability(self.true_failure_rate, "true_failure_rate")?;

        let tpr_range = parse_range(&self.tpr_range)?;
        let tnr_range = parse_range(&self.tnr_range)?;
        for (endpoint, label) in [
            (tpr_range.0, "tpr_range.min"),
            (tpr_range.1, "tpr_range.max"),
            (tnr_range.0, "tnr_range.min"),
            (tnr_range.1, "tnr_range.max"),
        ] {
            validate_probability(endpoint, label)?;
        }

        // The counts are checked in a fixed order, so the first failing entry
        // determines which error is reported.
        let count_checks = [
            (self.n_points == 0, "n_points must be positive"),
            (
                self.n_test_positive == 0 || self.n_test_negative == 0,
                "n_test_positive and n_test_negative must be positive",
            ),
            (self.n_unlabeled == 0, "n_unlabeled must be positive"),
            (
                self.bootstrap_iterations == 0,
                "bootstrap_iterations must be positive",
            ),
        ];
        if let Some((_, message)) = count_checks.iter().find(|(failed, _)| *failed) {
            return Err(JudgyError::input_validation(message.to_string()));
        }

        let true_pass_rate = 1.0 - self.true_failure_rate;
        let tpr_midpoint = (tpr_range.0 + tpr_range.1) / 2.0;
        let tnr_midpoint = (tnr_range.0 + tnr_range.1) / 2.0;

        // NOTE(review): the baseline judge rates (0.8 / 0.85) are hard-coded;
        // presumably the sweep overrides them per point. Confirm in `synthetic`.
        let test_config = SyntheticConfig {
            n_positive: self.n_test_positive,
            n_negative: self.n_test_negative,
            true_positive_rate: 0.8,
            true_negative_rate: 0.85,
            random_seed: self.seed,
        };

        let tpr_config = SensitivityConfig {
            true_pass_rate,
            test_range: tpr_range,
            fixed_value: tnr_midpoint,
            vary_tpr: true,
            n_points: self.n_points,
            test_config: test_config.clone(),
            n_unlabeled: self.n_unlabeled,
            bootstrap_iterations: self.bootstrap_iterations,
            random_seed: self.seed,
        };
        let tnr_config = SensitivityConfig {
            true_pass_rate,
            test_range: tnr_range,
            fixed_value: tpr_midpoint,
            vary_tpr: false,
            n_points: self.n_points,
            test_config,
            n_unlabeled: self.n_unlabeled,
            bootstrap_iterations: self.bootstrap_iterations,
            random_seed: self.seed,
        };

        Ok((tpr_config, tnr_config))
    }
}
pub fn run_estimate(args: &EstimateArgs) -> Result<()> {
args.validate()?;
let (test_labels, test_preds, unlabeled_preds) = args.load_data()?;
let result = estimate_success_rate(
&test_labels,
&test_preds,
&unlabeled_preds,
args.bootstrap_iterations,
args.confidence_level,
)?;
match args.format.as_str() {
"text" => output_text(&result),
"json" => output_json(&result, args.output.as_ref())?,
"csv" => output_csv(&result, args.output.as_ref())?,
_ => unreachable!(), }
Ok(())
}
pub fn run_synth_experiment(args: &SynthExperimentArgs) -> Result<()> {
let (tpr_config, tnr_config) = args.create_config()?;
println!("Running TPR sensitivity experiment...");
let tpr_result = run_sensitivity_experiment(&tpr_config)?;
println!("Running TNR sensitivity experiment...");
let tnr_result = run_sensitivity_experiment(&tnr_config)?;
if let Some(ref output_path) = args.output {
let combined_results = serde_json::json!({
"tpr_sensitivity": tpr_result,
"tnr_sensitivity": tnr_result,
"metadata": {
"true_pass_rate": 1.0 - args.true_failure_rate,
"tpr_range": args.tpr_range,
"tnr_range": args.tnr_range,
"n_points": args.n_points,
"n_test_positive": args.n_test_positive,
"n_test_negative": args.n_test_negative,
"n_unlabeled": args.n_unlabeled,
"bootstrap_iterations": args.bootstrap_iterations,
"seed": args.seed,
}
});
let mut file = File::create(output_path)?;
writeln!(file, "{}", serde_json::to_string_pretty(&combined_results)?)?;
println!("Results saved to: {}", output_path.display());
} else {
println!("\nTPR Sensitivity Results:");
print_sensitivity_summary(&tpr_result);
println!("\nTNR Sensitivity Results:");
print_sensitivity_summary(&tnr_result);
}
Ok(())
}
fn output_text(result: &EstimationResult) {
println!("Bias-Corrected Success Rate Estimation");
println!("=====================================");
println!();
println!("Point Estimate:");
println!(
" Estimated true pass rate: {}",
format_percentage(result.theta_hat, 3)
);
println!();
println!("Confidence Interval:");
println!(
" {}% confidence interval: [{}, {}]",
format_float(result.confidence_level * 100.0, 1),
format_percentage(result.lower_bound, 3),
format_percentage(result.upper_bound, 3)
);
println!();
println!("Judge Performance:");
println!(
" True Positive Rate (TPR): {}",
format_percentage(result.judge_metrics.tpr, 1)
);
println!(
" True Negative Rate (TNR): {}",
format_percentage(result.judge_metrics.tnr, 1)
);
println!(
" False Positive Rate (FPR): {}",
format_percentage(result.judge_metrics.fpr, 1)
);
println!(
" False Negative Rate (FNR): {}",
format_percentage(result.judge_metrics.fnr, 1)
);
println!(
" Overall Accuracy: {}",
format_percentage(result.judge_metrics.accuracy, 1)
);
println!();
println!("Other Metrics:");
println!(
" Raw observed pass rate: {}",
format_percentage(result.raw_pass_rate, 3)
);
println!(" Bootstrap iterations: {}", result.bootstrap_iterations);
}
/// Serializes the estimation result as pretty-printed JSON.
///
/// When `output_path` is set, the JSON (followed by a newline) is written to
/// that file and a confirmation line is printed. Otherwise the JSON is printed
/// to stdout.
///
/// # Errors
/// Returns an error if serialization fails or the file cannot be written.
fn output_json(result: &EstimationResult, output_path: Option<&PathBuf>) -> Result<()> {
    let rendered = serde_json::to_string_pretty(result)?;
    match output_path {
        Some(path) => {
            let mut file = File::create(path)?;
            writeln!(file, "{}", rendered)?;
            println!("Results saved to: {}", path.display());
        }
        None => println!("{}", rendered),
    }
    Ok(())
}
/// Renders the estimation result as a two-column `metric,value` CSV.
///
/// When `output_path` is set, the CSV is written to that file and a
/// confirmation line is printed. Otherwise it is printed to stdout.
///
/// # Errors
/// Returns an error if the output file cannot be created or written.
fn output_csv(result: &EstimationResult, output_path: Option<&PathBuf>) -> Result<()> {
    let metrics = &result.judge_metrics;
    // Row order defines the order of lines in the CSV output.
    let rows = [
        ("theta_hat", result.theta_hat.to_string()),
        ("lower_bound", result.lower_bound.to_string()),
        ("upper_bound", result.upper_bound.to_string()),
        ("confidence_level", result.confidence_level.to_string()),
        ("tpr", metrics.tpr.to_string()),
        ("tnr", metrics.tnr.to_string()),
        ("fpr", metrics.fpr.to_string()),
        ("fnr", metrics.fnr.to_string()),
        ("accuracy", metrics.accuracy.to_string()),
        ("raw_pass_rate", result.raw_pass_rate.to_string()),
        (
            "bootstrap_iterations",
            result.bootstrap_iterations.to_string(),
        ),
    ];

    let mut csv_content = String::from("metric,value\n");
    for (metric, value) in &rows {
        csv_content.push_str(metric);
        csv_content.push(',');
        csv_content.push_str(value);
        csv_content.push('\n');
    }

    match output_path {
        Some(path) => {
            File::create(path)?.write_all(csv_content.as_bytes())?;
            println!("Results saved to: {}", path.display());
        }
        None => print!("{}", csv_content),
    }
    Ok(())
}
/// Prints a short human-readable summary of one sensitivity sweep.
///
/// The summary covers:
/// - the swept metric's range and the fixed value of the other metric;
/// - the test and unlabeled set sizes;
/// - up to three example points.
///
/// Points whose corrected estimate is NaN are skipped *before* the three
/// examples are chosen. NaN estimates tend to cluster where the correction is
/// ill-conditioned (often at the start of a sweep), and taking first could
/// otherwise print no examples at all.
fn print_sensitivity_summary(result: &crate::synthetic::SensitivityResult) {
    let (metric_name, fixed_name) = if result.config.vary_tpr {
        ("TPR", "TNR")
    } else {
        ("TNR", "TPR")
    };
    let test_config = &result.config.test_config;

    println!(
        " {} range: {:.1}% to {:.1}% (fixed {} = {:.1}%)",
        metric_name,
        result.values.first().unwrap_or(&0.0) * 100.0,
        result.values.last().unwrap_or(&0.0) * 100.0,
        fixed_name,
        result.config.fixed_value * 100.0
    );
    println!(
        " Test set size: {} ({} positive + {} negative)",
        test_config.n_positive + test_config.n_negative,
        test_config.n_positive,
        test_config.n_negative
    );
    println!(" Unlabeled set size: {}", result.config.n_unlabeled);
    println!(" Example results:");

    let examples = result
        .values
        .iter()
        .zip(&result.estimates)
        .zip(&result.lower_bounds)
        .zip(&result.upper_bounds)
        .zip(&result.raw_rates)
        .filter(|((((_, est), _), _), _)| !est.is_nan())
        .take(3);
    for ((((val, est), lower), upper), raw) in examples {
        println!(
            " {}={:.0}%: Raw={:.1}%, Corrected={:.1}% [{:.1}%, {:.1}%]",
            metric_name,
            val * 100.0,
            raw * 100.0,
            est * 100.0,
            lower * 100.0,
            upper * 100.0
        );
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Synthetic-experiment arguments that pass `create_config`; individual
    /// tests tweak single fields to probe specific validation rules.
    fn valid_synth_args() -> SynthExperimentArgs {
        SynthExperimentArgs {
            true_failure_rate: 0.2,
            tpr_range: "0.6,0.9".to_string(),
            tnr_range: "0.7,0.95".to_string(),
            n_points: 5,
            n_test_positive: 50,
            n_test_negative: 50,
            n_unlabeled: 500,
            bootstrap_iterations: 1000,
            seed: Some(42),
            output: None,
        }
    }

    #[test]
    fn test_estimate_args_load_data_from_strings() {
        let args = EstimateArgs {
            test_labels: Some("1,0,1,0".to_string()),
            test_preds: Some("1,0,0,1".to_string()),
            unlabeled_preds: Some("1,1,0".to_string()),
            test_labels_file: None,
            test_preds_file: None,
            unlabeled_preds_file: None,
            bootstrap_iterations: 1000,
            confidence_level: 0.95,
            output: None,
            format: "text".to_string(),
        };
        let (test_labels, test_preds, unlabeled_preds) = args.load_data().unwrap();
        assert_eq!(test_labels, vec![1, 0, 1, 0]);
        assert_eq!(test_preds, vec![1, 0, 0, 1]);
        assert_eq!(unlabeled_preds, vec![1, 1, 0]);
    }

    #[test]
    fn test_estimate_args_load_data_from_files() -> Result<()> {
        let mut test_labels_file = NamedTempFile::new()?;
        writeln!(test_labels_file, "1\n0\n1\n0")?;
        let mut test_preds_file = NamedTempFile::new()?;
        writeln!(test_preds_file, "1\n0\n0\n1")?;
        let mut unlabeled_file = NamedTempFile::new()?;
        writeln!(unlabeled_file, "1\n1\n0")?;
        let args = EstimateArgs {
            test_labels: None,
            test_preds: None,
            unlabeled_preds: None,
            test_labels_file: Some(test_labels_file.path().to_path_buf()),
            test_preds_file: Some(test_preds_file.path().to_path_buf()),
            unlabeled_preds_file: Some(unlabeled_file.path().to_path_buf()),
            bootstrap_iterations: 1000,
            confidence_level: 0.95,
            output: None,
            format: "text".to_string(),
        };
        let (test_labels, test_preds, unlabeled_preds) = args.load_data()?;
        assert_eq!(test_labels, vec![1, 0, 1, 0]);
        assert_eq!(test_preds, vec![1, 0, 0, 1]);
        assert_eq!(unlabeled_preds, vec![1, 1, 0]);
        Ok(())
    }

    #[test]
    fn test_estimate_args_load_data_requires_every_input() {
        // The unlabeled predictions are missing in both forms.
        let args = EstimateArgs {
            test_labels: Some("1,0".to_string()),
            test_preds: Some("1,0".to_string()),
            unlabeled_preds: None,
            test_labels_file: None,
            test_preds_file: None,
            unlabeled_preds_file: None,
            bootstrap_iterations: 1000,
            confidence_level: 0.95,
            output: None,
            format: "text".to_string(),
        };
        assert!(args.load_data().is_err());
    }

    #[test]
    fn test_estimate_args_file_takes_precedence_over_inline() -> Result<()> {
        // clap forbids passing both forms, but a directly built struct may;
        // load_data reads the file form first.
        let mut labels_file = NamedTempFile::new()?;
        writeln!(labels_file, "1\n0\n1\n0")?;
        let args = EstimateArgs {
            test_labels: Some("0,0,0,0".to_string()),
            test_preds: Some("1,0,0,1".to_string()),
            unlabeled_preds: Some("1,1,0".to_string()),
            test_labels_file: Some(labels_file.path().to_path_buf()),
            test_preds_file: None,
            unlabeled_preds_file: None,
            bootstrap_iterations: 1000,
            confidence_level: 0.95,
            output: None,
            format: "text".to_string(),
        };
        let (test_labels, _, _) = args.load_data()?;
        assert_eq!(test_labels, vec![1, 0, 1, 0]);
        Ok(())
    }

    #[test]
    fn test_estimate_args_validation() {
        let mut args = EstimateArgs {
            test_labels: Some("1,0".to_string()),
            test_preds: Some("1,0".to_string()),
            unlabeled_preds: Some("1,0".to_string()),
            test_labels_file: None,
            test_preds_file: None,
            unlabeled_preds_file: None,
            bootstrap_iterations: 1000,
            confidence_level: 0.95,
            output: None,
            format: "text".to_string(),
        };
        assert!(args.validate().is_ok());
        args.confidence_level = 1.5;
        assert!(args.validate().is_err());
        args.confidence_level = 0.95;
        args.bootstrap_iterations = 0;
        assert!(args.validate().is_err());
        args.bootstrap_iterations = 1000;
        args.format = "invalid".to_string();
        assert!(args.validate().is_err());
    }

    #[test]
    fn test_synth_experiment_args_create_config() {
        let args = valid_synth_args();
        let (tpr_config, tnr_config) = args.create_config().unwrap();
        assert_eq!(tpr_config.true_pass_rate, 0.8);
        assert_eq!(tpr_config.test_range, (0.6, 0.9));
        assert!(tpr_config.vary_tpr);
        assert_eq!(tpr_config.n_points, 5);
        assert_eq!(tnr_config.true_pass_rate, 0.8);
        assert_eq!(tnr_config.test_range, (0.7, 0.95));
        assert!(!tnr_config.vary_tpr);
        assert_eq!(tnr_config.n_points, 5);
    }

    #[test]
    fn test_synth_experiment_args_fixed_values_and_seed() {
        let (tpr_config, tnr_config) = valid_synth_args().create_config().unwrap();
        // Each sweep holds the other metric at the midpoint of its range.
        assert_eq!(tpr_config.fixed_value, (0.7 + 0.95) / 2.0);
        assert_eq!(tnr_config.fixed_value, (0.6 + 0.9) / 2.0);
        assert_eq!(tpr_config.random_seed, Some(42));
        assert_eq!(tnr_config.random_seed, Some(42));
    }

    #[test]
    fn test_synth_experiment_args_rejects_invalid_values() {
        let mut args = valid_synth_args();
        args.n_points = 0;
        assert!(args.create_config().is_err());

        let mut args = valid_synth_args();
        args.n_test_positive = 0;
        assert!(args.create_config().is_err());

        let mut args = valid_synth_args();
        args.n_test_negative = 0;
        assert!(args.create_config().is_err());

        let mut args = valid_synth_args();
        args.n_unlabeled = 0;
        assert!(args.create_config().is_err());

        let mut args = valid_synth_args();
        args.bootstrap_iterations = 0;
        assert!(args.create_config().is_err());

        let mut args = valid_synth_args();
        args.true_failure_rate = 1.5;
        assert!(args.create_config().is_err());

        let mut args = valid_synth_args();
        args.tpr_range = "0.5,1.5".to_string();
        assert!(args.create_config().is_err());
    }
}