use std::collections::{BTreeMap, BTreeSet};
use chrono::{DateTime, Utc};
use uuid::Uuid;
use crate::types::{L2Projection, L2SweepResult, PlanConfig, RequestLog};
/// A cached response held in the simulated L2 cache while a request stream is
/// replayed.
///
/// It keeps only what the projection needs:
/// - the insertion time, for TTL eviction;
/// - the embedding, for similarity lookup;
/// - the outcome fields that `outcomes_diverged` compares to flag poisoning
///   candidates.
#[derive(Clone, Debug)]
struct LiveEntry {
    /// Timestamp of the request that populated this entry.
    ts: DateTime<Utc>,
    /// Embedding of the request that populated this entry.
    embedding: Vec<f32>,
    /// Finish reason recorded for that request, if any.
    finish_reason: Option<String>,
    /// Output token count recorded for that request.
    output_tokens: u32,
}
/// Projects L2 (semantic-cache) hits for every threshold in
/// `config.l2_threshold_sweep`. The requests are replayed in `(ts, id)` order.
///
/// Returns `L2SweepResult::default()` when there is nothing to project. That
/// covers no requests, an empty sweep, no `l2_ttl_seconds`, or no request
/// carrying an embedding.
///
/// Otherwise:
/// - `per_threshold` holds one projection per sweep value, in sweep order.
/// - `poisoning_candidates` counts *distinct* request ids flagged at any
///   threshold, so a request flagged at several thresholds is counted once.
///   The count saturates at `u32::MAX`.
#[must_use]
pub fn project_l2_hits(requests: &[RequestLog], config: &PlanConfig) -> L2SweepResult {
    let has_work = !requests.is_empty() && !config.l2_threshold_sweep.is_empty();
    let ttl_secs = match (has_work, config.l2_ttl_seconds) {
        (true, Some(secs)) if requests.iter().any(|r| r.embedding.is_some()) => secs,
        _ => return L2SweepResult::default(),
    };
    let ttl = chrono::Duration::seconds(i64::from(ttl_secs));

    // Deterministic replay order: timestamp first, request id as tie-breaker.
    let mut replay_order: Vec<&RequestLog> = requests.iter().collect();
    replay_order.sort_by_key(|r| (r.ts, r.id));

    // Shared across thresholds so the aggregate is deduplicated, not summed.
    let mut flagged_ids: BTreeSet<Uuid> = BTreeSet::new();
    let per_threshold: Vec<L2Projection> = config
        .l2_threshold_sweep
        .iter()
        .map(|&threshold| run_single_threshold(&replay_order, threshold, ttl, &mut flagged_ids))
        .collect();

    L2SweepResult {
        per_threshold,
        poisoning_candidates: u32::try_from(flagged_ids.len()).unwrap_or(u32::MAX),
    }
}
/// Replays `sorted` through a simulated L2 cache at a single similarity
/// `threshold` and reports projected hits and poisoning candidates.
///
/// Replay rules:
/// - Requests without an embedding are skipped and do not count toward
///   `total`.
/// - Entries are bucketed per `(provider, model)`.
/// - Before each lookup, entries with `ts < req.ts - ttl` are evicted from that
///   request's bucket. This assumes `sorted` is in ascending `ts` order; the
///   caller sorts by `(ts, id)`.
/// - A request is a hit when its best in-bucket cosine similarity is
///   `>= threshold`. Ties resolve to the earliest-inserted entry.
/// - Misses are inserted as new entries. Hits are not inserted, so reuse never
///   refreshes an entry's timestamp.
///
/// Each hit whose outcome diverges from its source entry (per
/// `outcomes_diverged`) is counted in this projection. Its request id is also
/// added to `distinct_poisoning`, which the caller shares across the whole
/// sweep.
fn run_single_threshold(
    sorted: &[&RequestLog],
    threshold: f32,
    ttl: chrono::Duration,
    distinct_poisoning: &mut BTreeSet<Uuid>,
) -> L2Projection {
    // Keys borrow from the request logs, which outlive this call. This avoids
    // cloning provider/model into fresh Strings on every lookup.
    let mut active: BTreeMap<(&str, &str), Vec<LiveEntry>> = BTreeMap::new();
    let mut total_considered: u32 = 0;
    let mut hits: u32 = 0;
    let mut poisoning: u32 = 0;
    for &req in sorted {
        let Some(embedding) = req.embedding.as_ref() else {
            continue;
        };
        total_considered = total_considered.saturating_add(1);
        let bucket = active
            .entry((req.provider.as_str(), req.model.as_str()))
            .or_default();
        // Evict entries that fell out of the TTL window ending at this request.
        let cutoff = req.ts - ttl;
        bucket.retain(|e| e.ts >= cutoff);
        // Best candidate at or above the threshold. Only a strictly greater
        // similarity displaces the current best, so ties keep the earliest entry.
        let best = bucket
            .iter()
            .map(|entry| (entry, cosine(embedding, &entry.embedding)))
            .filter(|&(_, sim)| sim >= threshold)
            .fold(None::<(&LiveEntry, f32)>, |best, (entry, sim)| match best {
                Some((_, best_sim)) if best_sim >= sim => best,
                _ => Some((entry, sim)),
            });
        if let Some((source, _sim)) = best {
            hits = hits.saturating_add(1);
            if outcomes_diverged(source, req) {
                poisoning = poisoning.saturating_add(1);
                distinct_poisoning.insert(req.id);
            }
        } else {
            bucket.push(LiveEntry {
                ts: req.ts,
                embedding: embedding.clone(),
                finish_reason: req.finish_reason.clone(),
                output_tokens: req.output_tokens,
            });
        }
    }
    let rate = if total_considered == 0 {
        0.0
    } else {
        f64::from(hits) / f64::from(total_considered)
    };
    L2Projection {
        threshold,
        total: total_considered,
        projected_l2_hits: hits,
        projected_l2_hit_rate: rate,
        poisoning_candidates: poisoning,
    }
}
/// Cosine similarity of `a` and `b`, clamped to `[-1.0, 1.0]`.
///
/// Returns `0.0` when the slices differ in length or either vector has zero
/// norm, since there is no meaningful angle in those cases.
///
/// Sums are accumulated in `f64` for three reasons:
/// - identical vectors come out exactly `1.0` instead of a few ULPs below it,
///   so a `1.0` threshold still matches exact duplicates;
/// - large components do not overflow `x * x` to infinity;
/// - tiny components do not underflow to a spurious zero norm.
#[allow(clippy::cast_possible_truncation)] // clamped to [-1, 1]; narrowing only drops precision
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    let (dot, sum_a, sum_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sum_a, sum_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sum_a + x * x, sum_b + y * y)
        },
    );
    if sum_a == 0.0 || sum_b == 0.0 {
        return 0.0;
    }
    // Rounding can push the ratio marginally outside [-1, 1]; clamp before narrowing.
    (dot / (sum_a.sqrt() * sum_b.sqrt())).clamp(-1.0, 1.0) as f32
}
/// Returns true when serving `req` from the cached `source` entry would differ
/// materially from the outcome logged on `req` itself.
///
/// Two signals are checked:
/// - Finish reason: flagged only when both sides recorded one and they differ.
///   A missing reason on either side counts as "no evidence".
/// - Output length: flagged when the token counts differ by more than
///   `max(20, req.output_tokens / 4)`. The tolerance scales with `req` only, so
///   the check is not symmetric in its two arguments.
fn outcomes_diverged(source: &LiveEntry, req: &RequestLog) -> bool {
    let finish_mismatch = matches!(
        (source.finish_reason.as_deref(), req.finish_reason.as_deref()),
        (Some(cached), Some(logged)) if cached != logged
    );
    let allowed_drift = (req.output_tokens / 4).max(20);
    let drift = source.output_tokens.abs_diff(req.output_tokens);
    finish_mismatch || drift > allowed_drift
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::TimeZone;
    use uuid::Uuid;

    /// Builds a minimal `RequestLog` for provider "anthropic" and model
    /// "claude", timestamped `secs` seconds after 2026-05-01T00:00:00Z.
    fn req_with(
        id_seed: u128,
        secs: i64,
        embedding: Option<Vec<f32>>,
        finish_reason: Option<&str>,
        output_tokens: u32,
    ) -> RequestLog {
        RequestLog {
            id: Uuid::from_u128(id_seed),
            org_id: Uuid::nil(),
            ts: Utc.with_ymd_and_hms(2026, 5, 1, 0, 0, 0).unwrap()
                + chrono::Duration::seconds(secs),
            provider: "anthropic".into(),
            model: "claude".into(),
            input_tokens: 100,
            output_tokens,
            cached_tokens: 0,
            cost_usd: 0.0,
            baseline_cost_usd: 0.0,
            cached: false,
            cache_layer: None,
            matched_route_id: None,
            latency_ms: 0,
            upstream_latency_ms: None,
            status: 200,
            tag: None,
            embedding,
            finish_reason: finish_reason.map(String::from),
            body: None,
            response_body: None,
        }
    }

    #[test]
    fn cosine_identical_vectors_is_one() {
        let a = vec![1.0_f32, 0.0, 0.0];
        assert!((cosine(&a, &a) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_orthogonal_is_zero() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![0.0_f32, 1.0];
        assert!(cosine(&a, &b).abs() < 1e-6);
    }

    #[test]
    fn cosine_mismatched_length_returns_zero() {
        assert_eq!(cosine(&[1.0, 0.0], &[1.0]), 0.0);
    }

    #[test]
    fn outcomes_diverged_token_delta() {
        let src = LiveEntry {
            ts: Utc::now(),
            embedding: vec![],
            finish_reason: None,
            output_tokens: 100,
        };
        let req = req_with(1, 0, None, None, 100);
        assert!(!outcomes_diverged(&src, &req));
        let mut src2 = src.clone();
        src2.output_tokens = 130;
        assert!(outcomes_diverged(&src2, &req));
    }

    #[test]
    fn outcomes_diverged_finish_reason() {
        let src = LiveEntry {
            ts: Utc::now(),
            embedding: vec![],
            finish_reason: Some("length".into()),
            output_tokens: 100,
        };
        let req = req_with(1, 0, None, Some("stop"), 100);
        assert!(outcomes_diverged(&src, &req));
    }

    #[test]
    fn outcomes_diverged_missing_finish_reason_does_not_flag() {
        let src = LiveEntry {
            ts: Utc::now(),
            embedding: vec![],
            finish_reason: None,
            output_tokens: 100,
        };
        let req = req_with(1, 0, None, Some("stop"), 100);
        assert!(!outcomes_diverged(&src, &req));
    }

    #[test]
    fn empty_sweep_when_no_embeddings() {
        let reqs = vec![
            req_with(1, 0, None, None, 10),
            req_with(2, 1, None, None, 10),
        ];
        // Sweep set explicitly so this exercises the "no embeddings" guard, not
        // whatever PlanConfig::default() puts in the sweep.
        let cfg = PlanConfig {
            l2_ttl_seconds: Some(60),
            l2_threshold_sweep: vec![0.90],
            ..PlanConfig::default()
        };
        let result = project_l2_hits(&reqs, &cfg);
        assert!(result.per_threshold.is_empty());
        assert_eq!(result.poisoning_candidates, 0);
    }

    #[test]
    fn empty_sweep_when_no_thresholds() {
        let reqs = vec![req_with(1, 0, Some(vec![1.0, 0.0]), None, 10)];
        let cfg = PlanConfig {
            l2_ttl_seconds: Some(60),
            l2_threshold_sweep: vec![],
            ..PlanConfig::default()
        };
        let result = project_l2_hits(&reqs, &cfg);
        assert!(result.per_threshold.is_empty());
        assert_eq!(result.poisoning_candidates, 0);
    }

    #[test]
    fn poisoning_reported_per_threshold_and_deduped_in_aggregate() {
        let emb = Some(vec![1.0_f32, 0.0, 0.0]);
        let reqs = vec![
            req_with(1, 0, emb.clone(), Some("length"), 100),
            req_with(2, 1, emb.clone(), Some("stop"), 100),
        ];
        let cfg = PlanConfig {
            l2_ttl_seconds: Some(600),
            l2_threshold_sweep: vec![0.80, 0.90, 0.95],
            ..PlanConfig::default()
        };
        let result = project_l2_hits(&reqs, &cfg);
        assert_eq!(result.per_threshold.len(), 3);
        for proj in &result.per_threshold {
            assert_eq!(proj.projected_l2_hits, 1);
            assert_eq!(
                proj.poisoning_candidates, 1,
                "threshold {} should report its own poisoning count",
                proj.threshold
            );
        }
        assert_eq!(
            result.poisoning_candidates, 1,
            "aggregate must dedup across the sweep, not sum (would be 3)"
        );
    }

    #[test]
    fn poisoning_aggregate_counts_distinct_requests() {
        let a = Some(vec![1.0_f32, 0.0]);
        let b = Some(vec![0.0_f32, 1.0]);
        let reqs = vec![
            req_with(1, 0, a.clone(), Some("length"), 100),
            req_with(2, 1, a.clone(), Some("stop"), 100),
            req_with(3, 2, b.clone(), Some("length"), 100),
            req_with(4, 3, b.clone(), Some("stop"), 100),
        ];
        let cfg = PlanConfig {
            l2_ttl_seconds: Some(600),
            l2_threshold_sweep: vec![0.90, 0.95],
            ..PlanConfig::default()
        };
        let result = project_l2_hits(&reqs, &cfg);
        for proj in &result.per_threshold {
            assert_eq!(proj.poisoning_candidates, 2);
        }
        assert_eq!(result.poisoning_candidates, 2);
    }

    #[test]
    fn empty_sweep_when_ttl_none() {
        let reqs = vec![req_with(1, 0, Some(vec![1.0, 0.0]), None, 10)];
        // Sweep set explicitly so this exercises the "no TTL" guard rather than
        // passing via an empty default sweep.
        let cfg = PlanConfig {
            l2_ttl_seconds: None,
            l2_threshold_sweep: vec![0.90],
            ..PlanConfig::default()
        };
        let result = project_l2_hits(&reqs, &cfg);
        assert!(result.per_threshold.is_empty());
    }

    #[test]
    fn expired_entries_do_not_hit() {
        let emb = Some(vec![1.0_f32, 0.0]);
        let reqs = vec![
            req_with(1, 0, emb.clone(), None, 10),
            req_with(2, 120, emb, None, 10),
        ];
        let cfg = PlanConfig {
            l2_ttl_seconds: Some(60),
            l2_threshold_sweep: vec![0.90],
            ..PlanConfig::default()
        };
        let result = project_l2_hits(&reqs, &cfg);
        assert_eq!(result.per_threshold.len(), 1);
        let proj = &result.per_threshold[0];
        assert_eq!(proj.total, 2);
        assert_eq!(proj.projected_l2_hits, 0);
    }

    #[test]
    fn different_models_do_not_share_cache() {
        let emb = Some(vec![1.0_f32, 0.0]);
        let first = req_with(1, 0, emb.clone(), None, 10);
        let mut second = req_with(2, 1, emb, None, 10);
        second.model = "other".into();
        let cfg = PlanConfig {
            l2_ttl_seconds: Some(600),
            l2_threshold_sweep: vec![0.90],
            ..PlanConfig::default()
        };
        let result = project_l2_hits(&[first, second], &cfg);
        assert_eq!(result.per_threshold.len(), 1);
        assert_eq!(result.per_threshold[0].projected_l2_hits, 0);
    }

    #[test]
    fn threshold_gates_hits_and_rate() {
        // cosine([1, 0], [0.6, 0.8]) is about 0.6: a hit at 0.5, a miss at 0.9.
        let reqs = vec![
            req_with(1, 0, Some(vec![1.0, 0.0]), None, 10),
            req_with(2, 1, Some(vec![0.6, 0.8]), None, 10),
        ];
        let cfg = PlanConfig {
            l2_ttl_seconds: Some(600),
            l2_threshold_sweep: vec![0.5, 0.9],
            ..PlanConfig::default()
        };
        let result = project_l2_hits(&reqs, &cfg);
        assert_eq!(result.per_threshold.len(), 2);
        assert_eq!(result.per_threshold[0].projected_l2_hits, 1);
        assert_eq!(result.per_threshold[0].projected_l2_hit_rate, 0.5);
        assert_eq!(result.per_threshold[1].projected_l2_hits, 0);
        assert_eq!(result.per_threshold[1].projected_l2_hit_rate, 0.0);
        assert_eq!(result.poisoning_candidates, 0);
    }
}