use super::error::{MdapResult, ScalingError};
/// Projection for a single MDAP run, as filled in by `estimate_mdap`.
#[derive(Debug, Clone)]
pub struct MdapEstimate {
    /// Projected total spend for the whole run, in US dollars.
    pub expected_cost_usd: f64,
    /// Projected number of API calls (samples requested) for the whole run.
    pub expected_api_calls: u64,
    /// Probability that the complete task succeeds when `recommended_k` is used.
    pub success_probability: f64,
    /// Vote margin `k` chosen to reach the requested target success rate.
    pub recommended_k: u32,
    /// Rough duration of the whole run, in seconds.
    pub estimated_time_seconds: f64,
    /// Per-step success probability the estimate was computed from.
    pub per_step_success: f64,
    /// Number of steps the estimate was computed for.
    pub num_steps: u64,
}
/// Projects cost, API-call volume, success probability and duration for a
/// task of `num_steps` steps that is executed with per-step voting.
///
/// The vote margin is taken from [`calculate_k_min`] and the resulting
/// whole-task success probability from [`calculate_p_full`].
///
/// # Arguments
/// * `num_steps` - number of steps in the task; must be non-zero.
/// * `per_step_success_rate` - probability `p` that a single sample is
///   correct; must lie strictly between 0.5 and 1.0.
/// * `valid_response_rate` - fraction of samples that are usable; clamped to
///   `[0.01, 1.0]` before use.
/// * `cost_per_sample_usd` - price of one sample; used as given (not validated).
/// * `target_success_rate` - desired whole-task success probability; must lie
///   strictly between 0.0 and 1.0.
///
/// # Errors
/// * `ScalingError::InvalidStepCount` when `num_steps` is zero.
/// * `ScalingError::VotingCannotConverge` when `per_step_success_rate <= 0.5`.
/// * `ScalingError::InvalidSuccessProbability` when `per_step_success_rate`
///   is NaN or `>= 1.0`.
/// * `ScalingError::InvalidTargetProbability` when `target_success_rate` is
///   NaN or outside the open interval `(0, 1)`.
pub fn estimate_mdap(
    num_steps: u64,
    per_step_success_rate: f64,
    valid_response_rate: f64,
    cost_per_sample_usd: f64,
    target_success_rate: f64,
) -> MdapResult<MdapEstimate> {
    if num_steps == 0 {
        return Err(ScalingError::InvalidStepCount(0).into());
    }
    // NaN compares false against every bound, so it has to be rejected
    // explicitly or it would slip through all the range checks below.
    if per_step_success_rate.is_nan() {
        return Err(ScalingError::InvalidSuccessProbability(per_step_success_rate).into());
    }
    if per_step_success_rate <= 0.5 {
        return Err(ScalingError::VotingCannotConverge {
            p: per_step_success_rate,
        }
        .into());
    }
    if per_step_success_rate >= 1.0 {
        return Err(ScalingError::InvalidSuccessProbability(per_step_success_rate).into());
    }
    if target_success_rate.is_nan() || target_success_rate <= 0.0 || target_success_rate >= 1.0 {
        return Err(ScalingError::InvalidTargetProbability(target_success_rate).into());
    }

    let p = per_step_success_rate;
    let s = num_steps as f64;
    let v = valid_response_rate.clamp(0.01, 1.0);
    let c = cost_per_sample_usd;

    let k_min = calculate_k_min(s, p, target_success_rate);
    let p_full = calculate_p_full(s, p, k_min);
    let k = f64::from(k_min);

    // Each step needs about k / (2p - 1) valid votes (the same formula as
    // `calculate_expected_votes`), and only a fraction `v` of samples is valid.
    let valid_vote_yield = v * (2.0 * p - 1.0);
    let expected_cost = (c * s * k) / valid_vote_yield;
    // The call count uses the same denominator as the cost, so that
    // cost == calls * price holds up to the final rounding.
    let expected_calls = ((s * k) / valid_vote_yield).ceil() as u64;

    // Heuristic: 0.5 s for every group of up to four samples in a step
    // (presumably one parallel batch - confirm against the executor).
    let time_per_step = 0.5 * (k / 4.0).ceil();
    let estimated_time = s * time_per_step;

    Ok(MdapEstimate {
        expected_cost_usd: expected_cost,
        expected_api_calls: expected_calls,
        success_probability: p_full,
        recommended_k: k_min,
        estimated_time_seconds: estimated_time,
        per_step_success: p,
        num_steps,
    })
}
/// Returns the smallest vote margin `k` for which the whole-task success
/// probability `(1 / (1 + r^k))^num_steps`, with `r = (1 - p) / p`, reaches
/// `target`.
///
/// Solving that inequality gives `k >= ln(a) / ln(r)` where
/// `a = target^(-1 / num_steps) - 1`.
///
/// # Arguments
/// * `num_steps` - number of steps in the task.
/// * `p` - per-step success probability.
/// * `target` - desired whole-task success probability.
///
/// # Returns
/// * `u32::MAX` when `p` is NaN or `p <= 0.5` (voting cannot converge).
/// * For `target >= 0.9999`: the result clamped to `1..=100`, or `10` when
///   the formula is degenerate (`a <= 0` or `r <= 0`).
/// * Otherwise: the result, at least `1`; `1` when the formula is degenerate.
pub fn calculate_k_min(num_steps: f64, p: f64, target: f64) -> u32 {
    /// Targets at or above this value use the capped, more conservative path.
    const HIGH_TARGET: f64 = 0.9999;

    // NaN fails every comparison, so test for it explicitly; otherwise it
    // would fall through and come out of the float-to-int cast as k = 1.
    if p.is_nan() || p <= 0.5 {
        return u32::MAX;
    }

    let high_target = target >= HIGH_TARGET;
    let ratio = (1.0 - p) / p;

    // a = target^(-1/s) - 1, evaluated as expm1(-ln(target) / s). The direct
    // form rounds to exactly 0 once -ln(target)/s drops below f64 epsilon
    // (very large step counts), which used to yield k = 1.
    let a = (-target.ln() / num_steps).exp_m1();

    if a <= 0.0 || ratio <= 0.0 {
        return if high_target { 10 } else { 1 };
    }

    // `as u32` saturates: negative and NaN quotients become 0 and are then
    // lifted to 1 below; oversized quotients become u32::MAX.
    let k = (a.ln() / ratio.ln()).ceil() as u32;
    if high_target {
        k.clamp(1, 100)
    } else {
        k.max(1)
    }
}
/// Probability that every one of `num_steps` steps is decided correctly when
/// each step is voted on with margin `k`.
///
/// Per step the success probability is `1 / (1 + r^k)` with
/// `r = (1 - p) / p`; the whole task succeeds with that value raised to the
/// power `num_steps`.
///
/// # Arguments
/// * `num_steps` - number of steps in the task.
/// * `p` - per-step success probability.
/// * `k` - vote margin.
///
/// # Returns
/// `0.0` when `p <= 0.5`, otherwise the whole-task success probability.
pub fn calculate_p_full(num_steps: f64, p: f64, k: u32) -> f64 {
    if p <= 0.5 {
        return 0.0;
    }

    let ratio = (1.0 - p) / p;
    // Evaluate r^k for every k. Treating it as 0 for k > 50 is wrong when p
    // is close to 0.5 (for p = 0.51, r^51 is about 0.13). `powf` also avoids
    // the wrap-around a `k as i32` cast would have for k > i32::MAX.
    let ratio_k = ratio.powf(f64::from(k));
    if ratio_k == 0.0 {
        // The wrong answer can never win a step, whatever the step count.
        return 1.0;
    }

    // (1 / (1 + x))^s == exp(-s * ln(1 + x)). `ln_1p` keeps tiny x from being
    // absorbed into 1.0, which matters once num_steps is very large.
    (-num_steps * ratio_k.ln_1p()).exp()
}
/// Expected number of votes needed to settle one step with margin `k`.
///
/// Evaluates `k / (2p - 1)`. When `p <= 0.5` the result is
/// `f64::INFINITY`.
pub fn calculate_expected_votes(p: f64, k: u32) -> f64 {
    if p <= 0.5 {
        f64::INFINITY
    } else {
        let net_gain_per_vote = 2.0 * p - 1.0;
        f64::from(k) / net_gain_per_vote
    }
}
/// Estimates the per-step success probability from observed sample counts.
///
/// Red-flagged samples are removed from the denominator, so the result is
/// `correct_samples / (total_samples - red_flagged_samples)`, clamped to
/// `[0.0, 1.0]`. When no valid samples remain the function returns `0.5`.
pub fn estimate_per_step_success(
    total_samples: u64,
    correct_samples: u64,
    red_flagged_samples: u64,
) -> f64 {
    match total_samples.saturating_sub(red_flagged_samples) {
        0 => 0.5,
        usable => {
            let observed = correct_samples as f64 / usable as f64;
            observed.clamp(0.0, 1.0)
        }
    }
}
/// Estimates the fraction of samples that were not red-flagged.
///
/// Returns `(total_samples - red_flagged_samples) / total_samples`, clamped
/// to `[0.01, 1.0]`. With no samples at all the function returns `0.95`.
pub fn estimate_valid_response_rate(total_samples: u64, red_flagged_samples: u64) -> f64 {
    match total_samples {
        0 => 0.95,
        total => {
            let kept = total.saturating_sub(red_flagged_samples);
            let share = kept as f64 / total as f64;
            share.clamp(0.01, 1.0)
        }
    }
}
/// Expected total cost of a run with vote margin `k`.
///
/// Evaluates `cost_per_call * num_steps * k / (v * (2p - 1))`, where
/// `v` is `valid_rate` clamped to `[0.01, 1.0]` and `p` is
/// `per_step_success` clamped to `[0.51, 0.999]`.
pub fn calculate_expected_cost(
    num_steps: u64,
    k: u32,
    valid_rate: f64,
    per_step_success: f64,
    cost_per_call: f64,
) -> f64 {
    let usable_share = valid_rate.clamp(0.01, 1.0);
    let success = per_step_success.clamp(0.51, 0.999);

    let baseline_spend = cost_per_call * num_steps as f64 * k as f64;
    let efficiency = usable_share * (2.0 * success - 1.0);
    baseline_spend / efficiency
}
/// Suggests the vote margin `k` that a budget can pay for.
///
/// Evaluates `budget_usd * v * (2p - 1) / (cost_per_call * num_steps)`,
/// rounds down, and never returns less than `1`. `v` is `valid_rate` clamped
/// to `[0.01, 1.0]` and `p` is `per_step_success` clamped to `[0.51, 0.999]`.
pub fn suggest_k_for_budget(
    num_steps: u64,
    per_step_success: f64,
    valid_rate: f64,
    cost_per_call: f64,
    budget_usd: f64,
) -> u32 {
    let usable_share = valid_rate.clamp(0.01, 1.0);
    let success = per_step_success.clamp(0.51, 0.999);

    let effective_budget = budget_usd * usable_share * (2.0 * success - 1.0);
    let spend_per_margin_unit = cost_per_call * num_steps as f64;
    let affordable = (effective_budget / spend_per_margin_unit).floor();

    // The float-to-int cast saturates, so negative or NaN values become 0
    // and are then raised to the minimum margin of 1.
    (affordable as u32).max(1)
}
/// Token pricing for one model, expressed per 1,000 tokens (presumably USD,
/// matching the `_usd` quantities used elsewhere in this file).
#[derive(Debug, Clone)]
pub struct ModelCosts {
    /// Price of 1,000 input tokens.
    pub input_per_1k: f64,
    /// Price of 1,000 output tokens.
    pub output_per_1k: f64,
}

impl ModelCosts {
    /// Pricing preset for Claude Sonnet.
    pub fn claude_sonnet() -> Self {
        Self {
            input_per_1k: 0.003,
            output_per_1k: 0.015,
        }
    }

    /// Pricing preset for Claude Haiku.
    pub fn claude_haiku() -> Self {
        Self {
            input_per_1k: 0.00025,
            output_per_1k: 0.00125,
        }
    }

    /// Pricing preset for GPT-4o.
    pub fn gpt4o() -> Self {
        Self {
            input_per_1k: 0.0025,
            output_per_1k: 0.01,
        }
    }

    /// Pricing preset for GPT-4o mini.
    pub fn gpt4o_mini() -> Self {
        Self {
            input_per_1k: 0.00015,
            output_per_1k: 0.0006,
        }
    }

    /// Cost of a single call that consumes `input_tokens` and produces
    /// `output_tokens`, priced with this table.
    pub fn estimate_call_cost(&self, input_tokens: u32, output_tokens: u32) -> f64 {
        const TOKENS_PER_PRICING_UNIT: f64 = 1000.0;

        let prompt_cost = f64::from(input_tokens) / TOKENS_PER_PRICING_UNIT * self.input_per_1k;
        let completion_cost =
            f64::from(output_tokens) / TOKENS_PER_PRICING_UNIT * self.output_per_1k;
        prompt_cost + completion_cost
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A highly reliable model on a modest task needs only a small margin.
    #[test]
    fn test_calculate_k_min_basic() {
        let margin = calculate_k_min(100.0, 0.99, 0.95);
        assert!(margin >= 1);
        assert!(margin <= 10);
    }

    /// A weaker model needs a margin larger than one.
    #[test]
    fn test_calculate_k_min_low_p() {
        let margin = calculate_k_min(10.0, 0.6, 0.95);
        assert!(margin > 1);
    }

    /// At or below p = 0.5 voting cannot converge.
    #[test]
    fn test_calculate_k_min_edge_cases() {
        let at_half = calculate_k_min(10.0, 0.5, 0.95);
        assert_eq!(at_half, u32::MAX);

        let below_half = calculate_k_min(10.0, 0.4, 0.95);
        assert_eq!(below_half, u32::MAX);
    }

    /// Higher per-step accuracy and margin give a higher overall probability.
    #[test]
    fn test_calculate_p_full() {
        let strong = calculate_p_full(10.0, 0.99, 5);
        assert!(strong > 0.99);

        let weak = calculate_p_full(10.0, 0.7, 3);
        assert!(weak < strong);
    }

    /// p = 0.5 yields a zero overall success probability.
    #[test]
    fn test_calculate_p_full_convergence() {
        let at_half = calculate_p_full(10.0, 0.5, 5);
        assert_eq!(at_half, 0.0);
    }

    /// A valid request produces a populated, plausible estimate.
    #[test]
    fn test_estimate_mdap_valid() {
        let projection = estimate_mdap(100, 0.99, 0.95, 0.001, 0.95).unwrap();
        assert!(projection.success_probability > 0.9);
        assert!(projection.recommended_k >= 1);
        assert!(projection.expected_cost_usd > 0.0);
        assert!(projection.expected_api_calls > 0);
    }

    /// A per-step success rate below 0.5 is rejected.
    #[test]
    fn test_estimate_mdap_invalid_p() {
        let outcome = estimate_mdap(100, 0.4, 0.95, 0.001, 0.95);
        assert!(outcome.is_err());
    }

    /// Zero steps is rejected.
    #[test]
    fn test_estimate_mdap_invalid_steps() {
        let outcome = estimate_mdap(0, 0.99, 0.95, 0.001, 0.95);
        assert!(outcome.is_err());
    }

    /// Red-flagged samples are excluded from the denominator.
    #[test]
    fn test_estimate_per_step_success() {
        let observed = estimate_per_step_success(100, 80, 10);
        assert!((observed - 0.889).abs() < 0.01);

        let nothing_usable = estimate_per_step_success(100, 0, 100);
        assert_eq!(nothing_usable, 0.5);
    }

    /// The valid rate is the non-flagged share, with a default for no data.
    #[test]
    fn test_estimate_valid_response_rate() {
        let rate = estimate_valid_response_rate(100, 10);
        assert_eq!(rate, 0.9);

        let no_data = estimate_valid_response_rate(0, 0);
        assert_eq!(no_data, 0.95);
    }

    /// Cost is positive and grows with the vote margin.
    #[test]
    fn test_calculate_expected_cost() {
        let small_margin = calculate_expected_cost(100, 3, 0.95, 0.99, 0.001);
        assert!(small_margin > 0.0);

        let large_margin = calculate_expected_cost(100, 10, 0.95, 0.99, 0.001);
        assert!(large_margin > small_margin);
    }

    /// A smaller budget never buys a larger margin.
    #[test]
    fn test_suggest_k_for_budget() {
        let roomy = suggest_k_for_budget(100, 0.99, 0.95, 0.001, 1.0);
        assert!(roomy >= 1);

        let tight = suggest_k_for_budget(100, 0.99, 0.95, 0.001, 0.1);
        assert!(tight <= roomy);
    }

    /// 1,000 input and 500 output tokens on the Sonnet preset cost 0.0105.
    #[test]
    fn test_model_costs() {
        let pricing = ModelCosts::claude_sonnet();
        let per_call = pricing.estimate_call_cost(1000, 500);
        assert!((per_call - 0.0105).abs() < 0.001);
    }

    /// Expected votes follow k / (2p - 1) and diverge at p = 0.5.
    #[test]
    fn test_calculate_expected_votes() {
        let reliable = calculate_expected_votes(0.99, 3);
        assert!((reliable - 3.06).abs() < 0.1);

        let coin_flip = calculate_expected_votes(0.5, 3);
        assert!(coin_flip.is_infinite());
    }

    /// A million-step task still reaches the requested reliability.
    #[test]
    fn test_high_step_count() {
        let projection = estimate_mdap(1_000_000, 0.99, 0.95, 0.0001, 0.95).unwrap();
        assert!(projection.success_probability > 0.9);
    }
}