use std::collections::HashMap;
use uuid::Uuid;
use crate::{
bootstrap, cost,
error::PlanError,
routing,
types::{
Aggregates, ConfidenceIntervals, PerRouteBreakdown, PlanInput, PlanResult, ProposedRoute,
RequestLog,
},
};
/// Replays a window of logged requests against a set of proposed routes and
/// returns projected cost, cache and latency aggregates with bootstrap CIs.
///
/// Routes are handed to the matcher sorted by descending `priority` (ties broken
/// by ascending `id`), and requests are processed in ascending `id` order, so
/// processing order does not depend on the caller's ordering. The
/// `proposed_routes` in the result keep the caller's original order.
///
/// When `config.l1_ttl_seconds` is set, the headline cache-hit rate is replaced
/// by the L1 projection's rate. L2 projections are attached only if at least one
/// request carries an embedding.
///
/// # Errors
///
/// Propagates the error from `validate`: `PlanError::InvalidWindow` when
/// `window_end <= window_start`, `PlanError::ZeroBootstrapIterations` when
/// `bootstrap_iterations == 0`.
pub fn replay(input: PlanInput) -> Result<PlanResult, PlanError> {
    validate(&input)?;

    let mut ordered_routes = input.proposed_routes.clone();
    ordered_routes.sort_by(|lhs, rhs| {
        rhs.priority
            .cmp(&lhs.priority)
            .then_with(|| lhs.id.cmp(&rhs.id))
    });

    let mut ordered_requests = input.requests.clone();
    ordered_requests.sort_by_key(|req| req.id);

    let l1_hit_ids =
        crate::cache_projection::project_l1_hit_ids(&ordered_requests, &input.config);
    let projection = project_requests(
        &ordered_requests,
        &ordered_routes,
        &input.pricing,
        &l1_hit_ids,
    );

    let mut aggregates = aggregate(&projection);
    if input.config.l1_ttl_seconds.is_some() {
        let l1 = crate::cache_projection::project_l1_hits(&ordered_requests, &input.config);
        aggregates.cache_hit_rate_projected = l1.projected_l1_hit_rate;
    }
    // `any` is already false for an empty slice, so no separate emptiness check.
    if ordered_requests.iter().any(|req| req.embedding.is_some()) {
        let l2 = crate::l2_projection::project_l2_hits(&ordered_requests, &input.config);
        aggregates.l2_projections = l2.per_threshold;
        aggregates.l2_poisoning_candidates = l2.poisoning_candidates;
    }

    let confidence_intervals =
        compute_cis(&projection, input.seed, input.bootstrap_iterations);
    let sample_size = ordered_requests.len();
    let mut caveats = build_caveats(
        sample_size,
        aggregates.requests_unprice_able,
        projection.latency_unprojected,
        projection.would_block,
    );
    caveats.extend(wide_ci_caveats(&aggregates, &confidence_intervals));
    let per_route_breakdown = build_per_route(projection.per_route);

    Ok(PlanResult {
        plan_id: input.plan_id,
        org_id: input.org_id,
        window_start: input.window_start,
        window_end: input.window_end,
        sample_size: sample_size as u32,
        aggregates,
        confidence_intervals,
        per_route_breakdown,
        caveats,
        quality: None,
        proposed_routes: input.proposed_routes,
    })
}
/// Runs [`replay`] and then attaches a quality score for the same requests.
///
/// The requests are captured before `input` is consumed by `replay`, and are
/// passed to `crate::quality::score_quality` together with `quality_config`,
/// `judge` and `proposed_response_for` (forwarded unchanged).
///
/// # Errors
///
/// [`ReplayWithQualityError::Replay`] if `replay` fails, or
/// [`ReplayWithQualityError::Quality`] if quality scoring fails.
pub async fn replay_with_quality<F>(
    input: PlanInput,
    judge: &dyn crate::quality::JudgeProvider,
    quality_config: &crate::quality::QualityConfig,
    proposed_response_for: F,
) -> Result<PlanResult, ReplayWithQualityError>
where
    F: Fn(&Uuid) -> Option<String>,
{
    // `replay` takes ownership of `input`, so keep our own copy of the requests.
    let scored_requests = input.requests.clone();
    // Both error types convert via the `#[from]` impls on ReplayWithQualityError.
    let mut plan = replay(input)?;
    let score = crate::quality::score_quality(
        &scored_requests,
        quality_config,
        judge,
        proposed_response_for,
    )
    .await?;
    plan.quality = Some(score);
    Ok(plan)
}
#[derive(Debug, thiserror::Error)]
pub enum ReplayWithQualityError {
#[error("replay: {0}")]
Replay(#[from] crate::error::PlanError),
#[error("quality: {0}")]
Quality(#[from] crate::quality::QualityError),
}
/// Rejects plan inputs that cannot produce a meaningful replay.
///
/// # Errors
///
/// * `PlanError::InvalidWindow` (both bounds rendered as RFC 3339) when the
///   window is empty or inverted, i.e. `window_end <= window_start`.
/// * `PlanError::ZeroBootstrapIterations` when `bootstrap_iterations` is 0.
///
/// The window check runs first.
fn validate(input: &PlanInput) -> Result<(), PlanError> {
    let (start, end) = (&input.window_start, &input.window_end);
    if end <= start {
        return Err(PlanError::InvalidWindow {
            start: start.to_rfc3339(),
            end: end.to_rfc3339(),
        });
    }
    match input.bootstrap_iterations {
        0 => Err(PlanError::ZeroBootstrapIterations),
        _ => Ok(()),
    }
}
/// Running totals for one proposed route, accumulated while requests are replayed.
struct PerRouteBucket {
    /// Id of the route these totals belong to.
    route_id: Uuid,
    /// Route name, copied from the route when the bucket is created.
    route_name: String,
    /// Sum of `baseline_cost_usd` over the requests this route matched and priced.
    baseline_cost_usd: f64,
    /// Sum of the projected cost of those same requests.
    projected_cost_usd: f64,
    /// Number of requests this route matched and could price.
    matched: u32,
}
/// Result of replaying requests against the proposed routes.
///
/// The `per_request_*` vectors are index-aligned: entry `i` in each describes
/// the same request.
struct Projection {
    /// Per-route totals, keyed by route id.
    per_route: HashMap<Uuid, PerRouteBucket>,
    /// Logged baseline cost of each request.
    per_request_baseline: Vec<f64>,
    /// Projected cost of each request under the proposed routes.
    per_request_projected: Vec<f64>,
    /// Projected (or, when unavailable, original) latency in milliseconds.
    per_request_latency: Vec<f64>,
    /// 1.0 if the logged request was served from cache, else 0.0.
    per_request_cache_hit: Vec<f64>,
    /// Requests that matched a route and were priced on its target model.
    requests_rerouted: u32,
    /// Requests that matched no route.
    requests_unchanged: u32,
    /// Requests that matched a route whose target model had no pricing entry.
    requests_unprice_able: u32,
    /// Rerouted requests whose target model had no latency history.
    latency_unprojected: u32,
    /// Rerouted requests whose projected cost exceeded the route's `max_cost_usd`.
    would_block: u32,
}
/// Replays each request against `routes` and collects per-request cost/latency series.
///
/// For every request, in slice order, exactly one value is pushed onto each of
/// the four `per_request_*` vectors, so they stay index-aligned with `requests`.
/// Each request lands in exactly one of three outcomes:
///
/// * no route matched: original `cost_usd` and `latency_ms`, counted in
///   `requests_unchanged`;
/// * a route matched but no pricing entry exists for its target model:
///   original `cost_usd` and `latency_ms`, counted in `requests_unprice_able`;
/// * a route matched and was priced: projected cost, plus the target model's
///   median latency when known (otherwise the original latency, counted in
///   `latency_unprojected`). The request is counted in `requests_rerouted` and
///   in that route's bucket.
///
/// Requests whose id is in `cache_hit_ids` get a projected cost of `0.0` in all
/// three outcomes.
fn project_requests(
    requests: &[RequestLog],
    routes: &[ProposedRoute],
    pricing: &crate::types::PricingTable,
    cache_hit_ids: &std::collections::HashSet<Uuid>,
) -> Projection {
    let cap = requests.len();
    let mut per_request_baseline = Vec::with_capacity(cap);
    let mut per_request_projected = Vec::with_capacity(cap);
    let mut per_request_latency = Vec::with_capacity(cap);
    let mut per_request_cache_hit = Vec::with_capacity(cap);
    let mut per_route: HashMap<Uuid, PerRouteBucket> = HashMap::new();
    let mut requests_rerouted: u32 = 0;
    let mut requests_unchanged: u32 = 0;
    let mut requests_unprice_able: u32 = 0;
    let mut latency_unprojected: u32 = 0;
    let mut would_block: u32 = 0;
    // Median observed latency per model, taken from this same request sample.
    let model_medians = model_median_latencies(requests);
    // Reverse index model -> provider, built from pricing keys split on the first ':'.
    // Keys are sorted first so that when several providers price the same model,
    // the alphabetically smallest key wins (`or_insert` keeps the first entry).
    // NOTE(review): assumes `pricing_key` formats keys as "provider:model" — confirm.
    let model_to_provider: HashMap<&str, &str> = {
        let mut keys: Vec<&str> = pricing.keys().map(String::as_str).collect();
        keys.sort_unstable();
        let mut m: HashMap<&str, &str> = HashMap::new();
        for k in keys {
            if let Some((prov, model)) = k.split_once(':') {
                m.entry(model).or_insert(prov);
            }
        }
        m
    };
    for req in requests {
        per_request_baseline.push(req.baseline_cost_usd);
        // Historical cache flag from the log, not the projected `is_cache_hit` below.
        per_request_cache_hit.push(if req.cached { 1.0 } else { 0.0 });
        let is_cache_hit = cache_hit_ids.contains(&req.id);
        let matched = routing::match_route(req, routes);
        match matched {
            Some(route) => {
                // Prefer pricing the target model on the request's current provider.
                // Otherwise fall back to a provider the pricing table lists for that
                // model. If there is none, the key is rebuilt from the request's own
                // provider, which presumably reproduces `same_provider_key` and so
                // takes the unpriced branch below.
                let same_provider_key =
                    crate::types::pricing_key(&req.provider, &route.then.target_model);
                let target_key = if pricing.contains_key(&same_provider_key) {
                    same_provider_key
                } else {
                    let target_provider = model_to_provider
                        .get(route.then.target_model.as_str())
                        .copied()
                        .unwrap_or(req.provider.as_str());
                    crate::types::pricing_key(target_provider, &route.then.target_model)
                };
                if let Some(p) = pricing.get(&target_key) {
                    let projected = cost::project_cost(req, &route.then.target_model, p);
                    let mut projected_cost = if is_cache_hit {
                        0.0
                    } else {
                        projected.cost_usd
                    };
                    // If the projected cost exceeds the route's `max_cost_usd`
                    // ceiling, keep the request's original cost and count it in
                    // `would_block`. Cache hits are exempt.
                    // NOTE(review): such a request is still counted in
                    // `requests_rerouted` and in its route bucket, and it gets the
                    // target model's latency, while the caveat text says "counted
                    // unchanged" — confirm this is intended.
                    if !is_cache_hit
                        && route
                            .then
                            .max_cost_usd
                            .is_some_and(|c| projected.cost_usd > c)
                    {
                        projected_cost = req.cost_usd;
                        would_block += 1;
                    }
                    per_request_projected.push(projected_cost);
                    // Without latency history for the target model, keep the
                    // request's own latency and record that it was not projected.
                    match model_medians.get(route.then.target_model.as_str()) {
                        Some(&med) => per_request_latency.push(med),
                        None => {
                            per_request_latency.push(f64::from(req.latency_ms));
                            latency_unprojected += 1;
                        }
                    }
                    let bucket = per_route.entry(route.id).or_insert_with(|| PerRouteBucket {
                        route_id: route.id,
                        route_name: route.name.clone(),
                        matched: 0,
                        baseline_cost_usd: 0.0,
                        projected_cost_usd: 0.0,
                    });
                    bucket.matched += 1;
                    bucket.baseline_cost_usd += req.baseline_cost_usd;
                    bucket.projected_cost_usd += projected_cost;
                    requests_rerouted += 1;
                } else {
                    // Matched, but the target model cannot be priced: the request
                    // keeps its original cost and latency.
                    per_request_projected.push(if is_cache_hit { 0.0 } else { req.cost_usd });
                    per_request_latency.push(f64::from(req.latency_ms));
                    requests_unprice_able += 1;
                }
            }
            None => {
                // No route matched: original cost (or 0.0 for a projected cache hit)
                // and original latency.
                per_request_projected.push(if is_cache_hit { 0.0 } else { req.cost_usd });
                per_request_latency.push(f64::from(req.latency_ms));
                requests_unchanged += 1;
            }
        }
    }
    Projection {
        per_request_baseline,
        per_request_projected,
        per_request_latency,
        per_request_cache_hit,
        per_route,
        requests_rerouted,
        requests_unchanged,
        requests_unprice_able,
        latency_unprojected,
        would_block,
    }
}
/// Median observed `latency_ms` for every `model` that appears in `requests`.
///
/// For an even number of samples this is the upper median: the element at index
/// `len / 2` of the sorted latencies.
fn model_median_latencies(requests: &[RequestLog]) -> HashMap<&str, f64> {
    let mut samples_by_model: HashMap<&str, Vec<u32>> = HashMap::new();
    for request in requests {
        samples_by_model
            .entry(request.model.as_str())
            .or_default()
            .push(request.latency_ms);
    }

    let mut medians = HashMap::with_capacity(samples_by_model.len());
    for (model, mut samples) in samples_by_model {
        // Every vector was created by a push, so it is non-empty and `mid` is in
        // bounds. Selecting index `mid` yields the same value a full sort would.
        let mid = samples.len() / 2;
        let (_, median, _) = samples.select_nth_unstable(mid);
        medians.insert(model, f64::from(*median));
    }
    medians
}
/// Collapses a [`Projection`] into headline totals, rates and latency percentiles.
///
/// Savings are computed on the totals and floored at zero. The savings
/// percentage is 0 unless the baseline total is strictly positive. L2 fields
/// start empty and are filled in by the caller when applicable.
fn aggregate(p: &Projection) -> Aggregates {
    let total_baseline_cost_usd: f64 = p.per_request_baseline.iter().sum();
    let total_projected_cost_usd: f64 = p.per_request_projected.iter().sum();
    let projected_savings_usd = (total_baseline_cost_usd - total_projected_cost_usd).max(0.0);

    // Keep the guard as `> 0.0`: a NaN or non-positive baseline must yield 0%.
    let projected_savings_pct = if total_baseline_cost_usd > 0.0 {
        projected_savings_usd / total_baseline_cost_usd * 100.0
    } else {
        0.0
    };

    let hits = &p.per_request_cache_hit;
    let cache_hit_rate_projected = match hits.len() {
        0 => 0.0,
        len => hits.iter().sum::<f64>() / len as f64,
    };

    Aggregates {
        total_baseline_cost_usd,
        total_projected_cost_usd,
        projected_savings_usd,
        projected_savings_pct,
        cache_hit_rate_projected,
        p50_latency_ms_projected: percentile(&p.per_request_latency, 0.50),
        p95_latency_ms_projected: percentile(&p.per_request_latency, 0.95),
        requests_rerouted: p.requests_rerouted,
        requests_unchanged: p.requests_unchanged,
        requests_unprice_able: p.requests_unprice_able,
        l2_projections: Vec::new(),
        l2_poisoning_candidates: 0,
    }
}
/// Bootstraps 95% confidence intervals for the headline metrics of a [`Projection`].
///
/// Each metric gets its own seed (`seed`, `seed + 1`, … `seed + 4`, wrapping),
/// so results are reproducible for a given `seed` and `iterations`.
///
/// The savings-USD interval is built from *signed* per-request savings and
/// clamped at zero only at the end. That makes it describe the same statistic as
/// the headline `projected_savings_usd` (`max(0, Σbaseline − Σprojected)`) and as
/// `savings_pct_95`, which also clamps after summing. Clamping each request
/// individually would drop requests whose projected cost exceeds their baseline
/// and bias the interval upward, away from the point estimate it is compared
/// against.
fn compute_cis(p: &Projection, seed: u64, iterations: u32) -> ConfidenceIntervals {
    let n = p.per_request_baseline.len() as f64;
    // Signed savings: requests that got more expensive must pull the interval
    // down, exactly as they pull down the headline total.
    let savings_per_req: Vec<f64> = p
        .per_request_baseline
        .iter()
        .zip(&p.per_request_projected)
        .map(|(b, pr)| b - pr)
        .collect();
    // NOTE(review): assumes `bootstrap_ci` returns (lo, hi) quantiles of the
    // resampled mean. Because `max(0, ·)` is monotone, clamping the scaled
    // endpoints is then equivalent to clamping each resampled total.
    let (sv_lo_mean, sv_hi_mean) =
        bootstrap::bootstrap_ci(&savings_per_req, seed, iterations, (0.025, 0.975));
    let savings_usd_95 = ((sv_lo_mean * n).max(0.0), (sv_hi_mean * n).max(0.0));
    let savings_pct_95 = bootstrap_pct_savings_ci(
        &p.per_request_baseline,
        &p.per_request_projected,
        seed.wrapping_add(1),
        iterations,
    );
    let cache_hit_rate_95 = bootstrap::bootstrap_ci(
        &p.per_request_cache_hit,
        seed.wrapping_add(2),
        iterations,
        (0.025, 0.975),
    );
    let p50_latency_ms_95 = bootstrap_percentile_ci(
        &p.per_request_latency,
        0.50,
        seed.wrapping_add(3),
        iterations,
    );
    let p95_latency_ms_95 = bootstrap_percentile_ci(
        &p.per_request_latency,
        0.95,
        seed.wrapping_add(4),
        iterations,
    );
    ConfidenceIntervals {
        savings_usd_95,
        savings_pct_95,
        cache_hit_rate_95,
        p50_latency_ms_95,
        p95_latency_ms_95,
    }
}
/// Percentile-bootstrap 95% interval for the `q`-th percentile of `values`.
///
/// Draws `iterations` resamples of size `values.len()` with replacement, using a
/// `ChaCha8Rng` seeded from `seed`. It computes [`percentile`] on each resample
/// and returns the 2.5% and 97.5% order statistics of those estimates. Returns
/// `(0.0, 0.0)` when `values` is empty or `iterations` is 0.
fn bootstrap_percentile_ci(values: &[f64], q: f64, seed: u64, iterations: u32) -> (f64, f64) {
    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaCha8Rng;

    let len = values.len();
    if len == 0 || iterations == 0 {
        return (0.0, 0.0);
    }

    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    // One scratch buffer reused across all resamples.
    let mut resample: Vec<f64> = Vec::with_capacity(len);
    let mut estimates: Vec<f64> = (0..iterations)
        .map(|_| {
            resample.clear();
            resample.extend((0..len).map(|_| values[rng.gen_range(0..len)]));
            percentile(&resample, q)
        })
        .collect();

    estimates.sort_by(|x, y| x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal));
    let draws = f64::from(iterations);
    let lower = (0.025 * draws) as usize;
    let upper = ((0.975 * draws) as usize).min(iterations as usize - 1);
    (estimates[lower], estimates[upper])
}
/// Turns per-route accumulators into report rows, sorted by `route_id`.
///
/// `savings_usd` is the bucket's baseline minus projected cost, floored at zero.
fn build_per_route(buckets: HashMap<Uuid, PerRouteBucket>) -> Vec<PerRouteBreakdown> {
    let mut breakdown = Vec::with_capacity(buckets.len());
    for bucket in buckets.into_values() {
        let savings_usd = (bucket.baseline_cost_usd - bucket.projected_cost_usd).max(0.0);
        breakdown.push(PerRouteBreakdown {
            route_id: bucket.route_id,
            route_name: bucket.route_name,
            matched: bucket.matched,
            baseline_cost_usd: bucket.baseline_cost_usd,
            projected_cost_usd: bucket.projected_cost_usd,
            savings_usd,
        });
    }
    // HashMap iteration order is unspecified; sort for a deterministic report.
    breakdown.sort_by(|a, b| a.route_id.cmp(&b.route_id));
    breakdown
}
/// Builds the human-readable caveats for a replay, in a fixed order.
///
/// The possible caveats are: small sample (fewer than 1000 requests), unpriced
/// target models, rerouted requests without latency history, and requests over
/// a `max_cost_usd` ceiling. Each count-based caveat is emitted only when its
/// count is non-zero.
fn build_caveats(
    sample_size: usize,
    requests_unprice_able: u32,
    latency_unprojected: u32,
    would_block: u32,
) -> Vec<String> {
    let small_sample = (sample_size < 1000).then(|| {
        format!("Small sample size ({sample_size} requests) — confidence intervals are wide.")
    });
    let unpriced = (requests_unprice_able > 0).then(|| {
        format!(
            "{requests_unprice_able} request(s) routed to a target model with no pricing entry — counted as unchanged."
        )
    });
    let unprojected_latency = (latency_unprojected > 0).then(|| {
        format!(
            "{latency_unprojected} rerouted request(s) had no latency history for the target model — their latency is shown unchanged, not projected."
        )
    });
    let blocked = (would_block > 0).then(|| {
        format!(
            "{would_block} request(s) would be rejected by a max_cost_usd ceiling — counted unchanged, not as savings."
        )
    });
    // Explicit `IntoIterator::into_iter` yields owned values on every edition.
    IntoIterator::into_iter([small_sample, unpriced, unprojected_latency, blocked])
        .flatten()
        .collect()
}
/// Caveats for confidence intervals that are wide relative to their point estimate.
///
/// Relative width is `|hi - lo| / |center|`. A caveat is emitted when it exceeds
/// 0.30, for the savings (USD) interval and the p50 latency interval. A metric
/// whose center is within `f64::EPSILON` of zero is skipped.
///
/// NOTE(review): the messages print the *full* relative width after a "±",
/// which reads like a half-width — confirm which one is intended.
pub(crate) fn wide_ci_caveats(aggregates: &Aggregates, cis: &ConfidenceIntervals) -> Vec<String> {
    /// Full interval width over `|center|`; `None` when the center is ~0.
    fn relative_width((lo, hi): (f64, f64), center: f64) -> Option<f64> {
        if center.abs() < f64::EPSILON {
            None
        } else {
            Some((hi - lo).abs() / center.abs())
        }
    }

    const WIDE: f64 = 0.30;
    let mut caveats = Vec::new();

    let savings_width = relative_width(cis.savings_usd_95, aggregates.projected_savings_usd)
        .filter(|&w| w > WIDE);
    if let Some(width) = savings_width {
        caveats.push(format!(
            "Savings CI is wide: ±{:.0}% relative width. Treat the headline savings number as a rough estimate; consider scanning a larger window.",
            width * 100.0
        ));
    }

    let p50_width = relative_width(cis.p50_latency_ms_95, aggregates.p50_latency_ms_projected)
        .filter(|&w| w > WIDE);
    if let Some(width) = p50_width {
        caveats.push(format!(
            "p50 latency CI is wide: ±{:.0}% relative width.",
            width * 100.0
        ));
    }

    caveats
}
/// The `q`-quantile of `values` (`q` in `[0, 1]`) by rounded rank.
///
/// Sorts a copy of `values` and returns the element at index
/// `round(q * (len - 1))`, clamped to the last index. Returns `0.0` for an empty
/// slice.
fn percentile(values: &[f64], q: f64) -> f64 {
    match values.len().checked_sub(1) {
        None => 0.0,
        Some(last) => {
            let mut sorted = values.to_vec();
            sorted.sort_by(|x, y| x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal));
            let rank = (q * last as f64).round() as usize;
            sorted[rank.min(last)]
        }
    }
}
/// Percentile-bootstrap 95% interval for the savings percentage.
///
/// Each resample draws `baseline.len()` paired indices with replacement, using a
/// `ChaCha8Rng` seeded from `seed`. It computes `(Σb − Σp) / Σb × 100`, which is
/// 0 when `Σb` is not positive, and floors the result at 0. The 2.5% and 97.5%
/// order statistics are returned. Returns `(0.0, 0.0)` for empty input, zero
/// iterations, or mismatched slice lengths.
fn bootstrap_pct_savings_ci(
    baseline: &[f64],
    projected: &[f64],
    seed: u64,
    iterations: u32,
) -> (f64, f64) {
    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaCha8Rng;

    let n = baseline.len();
    if n == 0 || iterations == 0 || projected.len() != n {
        return (0.0, 0.0);
    }

    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    let mut estimates: Vec<f64> = (0..iterations)
        .map(|_| {
            // Same index for both series keeps baseline/projected pairs together.
            let (b_total, p_total) = (0..n).fold((0.0, 0.0), |(b_acc, p_acc), _| {
                let i = rng.gen_range(0..n);
                (b_acc + baseline[i], p_acc + projected[i])
            });
            let pct = if b_total > 0.0 {
                (b_total - p_total) / b_total * 100.0
            } else {
                0.0
            };
            pct.max(0.0)
        })
        .collect();

    estimates.sort_by(|x, y| x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal));
    let last = estimates.len() - 1;
    let pick = |frac: f64| ((frac * f64::from(iterations)) as usize).min(last);
    (estimates[pick(0.025)], estimates[pick(0.975)])
}