#![forbid(unsafe_code)]
use std::collections::{HashMap, VecDeque};
use std::fmt;
use std::fs;
use std::future::Future;
use std::io;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
pub use crate::url_filter::{DefaultUrlFilter, UrlFilter};
/// Largest `RetryPolicy::max_attempts` that `JudgeRegistryBuilder::build` accepts.
pub const MAX_RETRY_ATTEMPTS: u32 = 16;
/// Largest `batch_size` that `JudgeRegistryBuilder::build` accepts (the minimum is 1).
pub const MAX_BATCH_SIZE: usize = 128;
/// Capacity used by `JudgeCache::new` (and therefore `JudgeCache::default`).
pub const DEFAULT_JUDGE_CACHE_CAPACITY: usize = 1024;
/// Asynchronous judge backend: turns a prompt into a [`JudgeVerdict`] or a [`JudgeError`].
///
/// Implementors must be `Send + Sync`; `JudgeRegistry` holds them as `Arc<dyn JudgeClient>`.
pub trait JudgeClient: Send + Sync {
/// Submits `prompt` to the judge. The returned future may borrow both `self` and `prompt`.
fn judge<'a>(&'a self, prompt: &'a str) -> JudgeFuture<'a>;
}
/// Boxed, `Send` future returned by [`JudgeClient::judge`], resolving to a verdict or error.
///
/// Returning a boxed trait object rather than `impl Future` keeps `JudgeClient` usable as
/// `dyn JudgeClient`, which `JudgeRegistry` relies on. The `'a` bound lets the future borrow
/// from the client and the prompt.
pub type JudgeFuture<'a> =
Pin<Box<dyn Future<Output = Result<JudgeVerdict, JudgeError>> + Send + 'a>>;
/// Retry settings carried by a `JudgeRegistry`.
///
/// `JudgeRegistryBuilder::build` only accepts `max_attempts` in `1..=MAX_RETRY_ATTEMPTS`;
/// `max_delay` and `jitter` are not validated or read by anything visible in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryPolicy {
/// Maximum number of attempts (presumably including the first call — confirm with the retry loop).
pub max_attempts: u32,
/// Presumably the upper bound on the delay between attempts — TODO confirm with the retry loop.
pub max_delay: Duration,
/// Presumably whether retry delays are randomized — TODO confirm with the retry loop.
pub jitter: bool,
}
/// SHA-256 digest identifying a `(model_id, prompt)` pair; built by `CacheKey::for_prompt`.
///
/// `JudgeCache` also uses its lowercase hex form as the on-disk file name (`<hex>.json`).
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct CacheKey([u8; 32]);
impl CacheKey {
#[must_use]
pub fn for_prompt(model_id: &str, prompt: &str) -> Self {
let mut hasher = Sha256::new();
update_with_len_prefixed_bytes(&mut hasher, model_id.as_bytes());
update_with_len_prefixed_bytes(&mut hasher, prompt.as_bytes());
Self(hasher.finalize().into())
}
#[must_use]
pub const fn as_bytes(&self) -> &[u8; 32] {
&self.0
}
fn to_hex(self) -> String {
hex_lower(self.as_bytes())
}
}
impl fmt::Debug for CacheKey {
    /// Renders as `CacheKey("<64 lowercase hex chars>")` instead of a raw byte array.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let hex = hex_lower(&self.0);
        f.debug_tuple("CacheKey").field(&hex).finish()
    }
}
/// Bounded least-recently-used cache of [`JudgeVerdict`]s keyed by [`CacheKey`].
///
/// It can optionally be mirrored to a directory of `<hex key>.json` files.
#[derive(Debug)]
pub struct JudgeCache {
/// Maximum number of in-memory entries; constructors clamp it to at least 1.
capacity: usize,
/// The cached verdicts.
entries: HashMap<CacheKey, JudgeVerdict>,
/// Keys ordered from least (front) to most (back) recently used.
recency: VecDeque<CacheKey>,
/// Backing directory, if the cache is disk-backed.
disk_path: Option<PathBuf>,
/// Set by `put`; when true, dropping the cache attempts a flush.
dirty: bool,
}
impl JudgeCache {
    /// Creates an in-memory cache with room for [`DEFAULT_JUDGE_CACHE_CAPACITY`] verdicts.
    #[must_use]
    pub fn new() -> Self {
        Self::with_capacity(DEFAULT_JUDGE_CACHE_CAPACITY)
    }
    /// Creates an in-memory (not disk-backed) cache holding at most `capacity` verdicts.
    ///
    /// A `capacity` of 0 is raised to 1.
    #[must_use]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            capacity: capacity.max(1),
            entries: HashMap::new(),
            recency: VecDeque::new(),
            disk_path: None,
            dirty: false,
        }
    }
    /// Creates a cache backed by the directory `path` and pre-loads the verdicts stored there.
    ///
    /// The directory is created if it is missing. Every `*.json` file is visited in file-name
    /// order. A file is skipped if its stem is not a 64-digit hex key, if it cannot be read, or
    /// if it does not deserialize as a [`JudgeVerdict`].
    ///
    /// If more than `capacity` valid entries exist, the ones loaded first are evicted from
    /// memory. Their files stay on disk until the next [`flush_to_disk`](Self::flush_to_disk).
    ///
    /// # Errors
    ///
    /// Returns any error from creating or listing the directory.
    pub fn with_disk_path(capacity: usize, path: impl Into<PathBuf>) -> io::Result<Self> {
        let disk_path = path.into();
        fs::create_dir_all(&disk_path)?;
        let mut files = fs::read_dir(&disk_path)?
            .filter_map(Result::ok)
            .map(|entry| entry.path())
            .filter(|file| Self::is_json_file(file))
            .collect::<Vec<_>>();
        // `read_dir` order is unspecified; sorting makes load (and eviction) order deterministic.
        files.sort();
        let mut cache = Self::with_capacity(capacity);
        cache.disk_path = Some(disk_path);
        for file in files {
            let Some(key) = cache_key_from_path(&file) else {
                continue;
            };
            let Ok(bytes) = fs::read(&file) else {
                continue;
            };
            let Ok(verdict) = serde_json::from_slice::<JudgeVerdict>(&bytes) else {
                continue;
            };
            // `put_loaded` does not set `dirty`: these entries already exist on disk.
            cache.put_loaded(key, verdict);
        }
        Ok(cache)
    }
    /// Number of verdicts currently held in memory.
    #[must_use]
    pub fn len(&self) -> usize {
        self.entries.len()
    }
    /// Returns `true` when no verdicts are held in memory.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
    /// Maximum number of verdicts kept in memory (always at least 1).
    #[must_use]
    pub const fn capacity(&self) -> usize {
        self.capacity
    }
    /// Backing directory, or `None` for a purely in-memory cache.
    #[must_use]
    pub fn disk_path(&self) -> Option<&Path> {
        self.disk_path.as_deref()
    }
    /// Returns a clone of the verdict cached for `key` and marks it most recently used.
    pub fn get(&mut self, key: &CacheKey) -> Option<JudgeVerdict> {
        let verdict = self.entries.get(key).cloned()?;
        self.touch(*key);
        Some(verdict)
    }
    /// Inserts or replaces the verdict for `key`, marks it most recently used, and marks the
    /// cache dirty.
    ///
    /// Inserting a new key into a full cache evicts the least recently used entry.
    pub fn put(&mut self, key: CacheKey, verdict: JudgeVerdict) {
        self.put_loaded(key, verdict);
        self.dirty = true;
    }
    /// Makes the backing directory mirror the in-memory entries.
    ///
    /// First, `*.json` files whose stem decodes to a key that is no longer cached are deleted.
    /// Other files are left alone. Files that vanish concurrently are not treated as errors.
    ///
    /// Then every cached verdict is written to `<hex key>.json`. Each write goes to a sibling
    /// `<hex key>.json.tmp` file, which is then renamed over the target. Where `fs::rename`
    /// replaces atomically (POSIX), an interrupted flush therefore cannot leave a half-written
    /// `.json` behind. At worst a stray `.json.tmp` remains, which loading and flushing ignore.
    ///
    /// For a cache without a disk path, this only clears the dirty flag.
    ///
    /// # Errors
    ///
    /// Returns the first I/O or serialization error. The dirty flag is then left unchanged.
    pub fn flush_to_disk(&mut self) -> io::Result<()> {
        let Some(dir) = self.disk_path.as_deref() else {
            self.dirty = false;
            return Ok(());
        };
        fs::create_dir_all(dir)?;
        for entry in fs::read_dir(dir)? {
            let file = entry?.path();
            if !Self::is_json_file(&file) {
                continue;
            }
            let evicted =
                cache_key_from_path(&file).is_some_and(|key| !self.entries.contains_key(&key));
            if evicted {
                Self::remove_if_present(&file)?;
            }
        }
        for (key, verdict) in &self.entries {
            let bytes = serde_json::to_vec(verdict).map_err(io::Error::other)?;
            Self::write_atomically(&dir.join(format!("{}.json", key.to_hex())), &bytes)?;
        }
        self.dirty = false;
        Ok(())
    }
    /// Inserts without touching the dirty flag (used when loading entries already on disk).
    fn put_loaded(&mut self, key: CacheKey, verdict: JudgeVerdict) {
        let replacing = self.entries.insert(key, verdict).is_some();
        self.touch(key);
        if !replacing {
            // Only a genuinely new key can push the cache past capacity.
            self.evict_over_capacity();
        }
    }
    /// Moves `key` to the most-recently-used (back) end of `recency`.
    fn touch(&mut self, key: CacheKey) {
        // `recency` holds each key at most once (every insertion goes through here), so
        // removing the single match is enough. Hot keys sit near the back, so search from there.
        if let Some(index) = self.recency.iter().rposition(|candidate| *candidate == key) {
            self.recency.remove(index);
        }
        self.recency.push_back(key);
    }
    /// Evicts least recently used entries until the cache is within `capacity`.
    fn evict_over_capacity(&mut self) {
        while self.entries.len() > self.capacity {
            let Some(oldest) = self.recency.pop_front() else {
                break;
            };
            self.entries.remove(&oldest);
        }
    }
    /// Whether `path` has a `.json` extension, the format used for cache entries.
    fn is_json_file(path: &Path) -> bool {
        path.extension().and_then(|ext| ext.to_str()) == Some("json")
    }
    /// Deletes `path`, treating "already gone" as success.
    fn remove_if_present(path: &Path) -> io::Result<()> {
        match fs::remove_file(path) {
            Err(err) if err.kind() != io::ErrorKind::NotFound => Err(err),
            _ => Ok(()),
        }
    }
    /// Writes `bytes` to `<target>.tmp`, then renames that file over `target`.
    fn write_atomically(target: &Path, bytes: &[u8]) -> io::Result<()> {
        let mut staging = target.as_os_str().to_owned();
        staging.push(".tmp");
        let staging = PathBuf::from(staging);
        let result = fs::write(&staging, bytes).and_then(|()| fs::rename(&staging, target));
        if result.is_err() {
            // Best effort: don't leave a partial or orphaned temp file behind.
            let _ = fs::remove_file(&staging);
        }
        result
    }
}
impl Default for JudgeCache {
fn default() -> Self {
Self::new()
}
}
impl Drop for JudgeCache {
    /// Flushes pending writes when the cache goes out of scope, if anything changed.
    fn drop(&mut self) {
        if !self.dirty {
            return;
        }
        // `Drop` cannot report failures, so this flush is best-effort. Callers that need to
        // observe I/O errors should call `flush_to_disk` explicitly before dropping.
        let _ = self.flush_to_disk();
    }
}
/// Feeds `bytes` into `hasher`, preceded by its length as an 8-byte little-endian `u64`.
///
/// The length prefix keeps consecutive fields unambiguous: `("ab", "c")` and `("a", "bc")`
/// hash differently. A fixed-width `u64` is used instead of `usize`, so the digest (and
/// therefore every cache file name) is the same on 32-bit and 64-bit targets. On 64-bit
/// targets the bytes are identical to `usize::to_le_bytes`, so existing keys stay valid.
fn update_with_len_prefixed_bytes(hasher: &mut Sha256, bytes: &[u8]) {
    // `usize` -> `u64` is lossless on every target Rust supports.
    let len = bytes.len() as u64;
    hasher.update(len.to_le_bytes());
    hasher.update(bytes);
}
/// Encodes `bytes` as lowercase hexadecimal, two characters per byte (high nibble first).
fn hex_lower(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut encoded = String::with_capacity(bytes.len() * 2);
    bytes.iter().for_each(|&b| {
        encoded.push(char::from(DIGITS[usize::from(b >> 4)]));
        encoded.push(char::from(DIGITS[usize::from(b & 0x0f)]));
    });
    encoded
}
fn cache_key_from_path(path: &Path) -> Option<CacheKey> {
let stem = path.file_stem()?.to_str()?;
cache_key_from_hex(stem)
}
/// Decodes exactly 64 hex digits (either case) into a `CacheKey`; anything else yields `None`.
fn cache_key_from_hex(hex: &str) -> Option<CacheKey> {
    let digits = hex.as_bytes();
    if digits.len() != 64 {
        return None;
    }
    let mut decoded = [0_u8; 32];
    // Each output byte comes from one pair of digits, high nibble first.
    for (slot, pair) in decoded.iter_mut().zip(digits.chunks_exact(2)) {
        *slot = (hex_nibble(pair[0])? << 4) | hex_nibble(pair[1])?;
    }
    Some(CacheKey(decoded))
}
/// Decodes one ASCII hex digit (`0-9`, `a-f`, `A-F`) into its value; any other byte is `None`.
fn hex_nibble(byte: u8) -> Option<u8> {
    // `to_digit(16)` accepts exactly the ASCII digits and the letters a-f/A-F.
    // The result is always below 16, so narrowing it to `u8` is lossless.
    char::from(byte).to_digit(16).map(|digit| digit as u8)
}
impl RetryPolicy {
    /// Builds a policy from explicit values.
    ///
    /// Nothing is validated here. `JudgeRegistryBuilder::build` later rejects a
    /// `max_attempts` of 0 or one above `MAX_RETRY_ATTEMPTS`.
    #[must_use]
    pub const fn new(max_attempts: u32, max_delay: Duration, jitter: bool) -> Self {
        Self {
            jitter,
            max_delay,
            max_attempts,
        }
    }
}
impl Default for RetryPolicy {
    /// Six attempts, a `max_delay` of 240 seconds, and jitter enabled.
    fn default() -> Self {
        Self::new(6, Duration::from_secs(240), true)
    }
}
/// Validated judge configuration: client, model id, retry policy, batch size and URL filter.
///
/// Obtain one via [`JudgeRegistry::builder`] followed by `JudgeRegistryBuilder::build`.
pub struct JudgeRegistry {
/// Backend used to run judgments.
client: Arc<dyn JudgeClient>,
/// Model identifier, trimmed and non-empty (enforced by `JudgeRegistryBuilder::build`).
model_id: String,
/// Retry settings; `max_attempts` is in `1..=MAX_RETRY_ATTEMPTS`.
retry_policy: RetryPolicy,
/// Batch size in `1..=MAX_BATCH_SIZE`.
batch_size: usize,
/// URL filter; the builder defaults to `DefaultUrlFilter` (semantics live in `url_filter`).
url_filter: Arc<dyn UrlFilter>,
}
impl std::fmt::Debug for JudgeRegistry {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("JudgeRegistry")
.field("model_id", &self.model_id)
.field("retry_policy", &self.retry_policy)
.field("batch_size", &self.batch_size)
.finish_non_exhaustive()
}
}
impl JudgeRegistry {
    /// Starts a [`JudgeRegistryBuilder`] for `client` and `model_id`.
    ///
    /// The builder starts with [`RetryPolicy::default`], a batch size of 1 and
    /// `DefaultUrlFilter`. Nothing is validated until [`JudgeRegistryBuilder::build`].
    #[must_use]
    pub fn builder(
        client: Arc<dyn JudgeClient>,
        model_id: impl Into<String>,
    ) -> JudgeRegistryBuilder {
        let model_id: String = model_id.into();
        let url_filter: Arc<dyn UrlFilter> = Arc::new(DefaultUrlFilter);
        JudgeRegistryBuilder {
            url_filter,
            batch_size: 1,
            retry_policy: RetryPolicy::default(),
            model_id,
            client,
        }
    }
    /// Shared handle to the judge backend.
    #[must_use]
    pub fn client(&self) -> &Arc<dyn JudgeClient> {
        &self.client
    }
    /// Model identifier, trimmed and non-empty as enforced by `JudgeRegistryBuilder::build`.
    #[must_use]
    pub fn model_id(&self) -> &str {
        self.model_id.as_str()
    }
    /// Retry settings, with `max_attempts` in `1..=MAX_RETRY_ATTEMPTS`.
    #[must_use]
    pub const fn retry_policy(&self) -> &RetryPolicy {
        &self.retry_policy
    }
    /// Configured batch size, in `1..=MAX_BATCH_SIZE`.
    #[must_use]
    pub const fn batch_size(&self) -> usize {
        self.batch_size
    }
    /// Shared handle to the configured URL filter.
    #[must_use]
    pub fn url_filter(&self) -> &Arc<dyn UrlFilter> {
        &self.url_filter
    }
}
/// Unvalidated registry configuration, created by [`JudgeRegistry::builder`].
///
/// All checks happen in `JudgeRegistryBuilder::build`.
pub struct JudgeRegistryBuilder {
/// Backend the registry will use.
client: Arc<dyn JudgeClient>,
/// Raw model identifier; trimmed and checked for emptiness by `build`.
model_id: String,
/// Retry settings; `max_attempts` is range-checked by `build`.
retry_policy: RetryPolicy,
/// Requested batch size; range-checked by `build`.
batch_size: usize,
/// URL filter to install.
url_filter: Arc<dyn UrlFilter>,
}
impl JudgeRegistryBuilder {
    /// Replaces the retry policy; `max_attempts` is checked by [`build`](Self::build).
    #[must_use]
    pub fn with_retry_policy(self, retry_policy: RetryPolicy) -> Self {
        Self {
            retry_policy,
            ..self
        }
    }
    /// Sets the batch size; it must lie in `1..=MAX_BATCH_SIZE` when [`build`](Self::build) runs.
    #[must_use]
    pub fn with_batch_size(self, batch_size: usize) -> Self {
        Self { batch_size, ..self }
    }
    /// Replaces the URL filter (the default is `DefaultUrlFilter`).
    #[must_use]
    pub fn with_url_filter(self, url_filter: Arc<dyn UrlFilter>) -> Self {
        Self { url_filter, ..self }
    }
    /// Validates the configuration and produces a [`JudgeRegistry`].
    ///
    /// Surrounding whitespace is trimmed from `model_id` before it is stored.
    ///
    /// # Errors
    ///
    /// The checks run in this order and the first failure is returned:
    /// - [`JudgeRegistryError::MissingModelId`] if `model_id` is empty after trimming;
    /// - [`JudgeRegistryError::InvalidBatchSize`] if `batch_size` is outside `1..=MAX_BATCH_SIZE`;
    /// - [`JudgeRegistryError::InvalidRetryPolicy`] if `max_attempts` is 0 or exceeds
    ///   [`MAX_RETRY_ATTEMPTS`].
    pub fn build(self) -> Result<JudgeRegistry, JudgeRegistryError> {
        let trimmed = self.model_id.trim();
        if trimmed.is_empty() {
            return Err(JudgeRegistryError::MissingModelId);
        }
        let model_id = trimmed.to_owned();
        let batch_size = self.batch_size;
        if batch_size == 0 || batch_size > MAX_BATCH_SIZE {
            return Err(JudgeRegistryError::InvalidBatchSize { batch_size });
        }
        // The two retry failures are mutually exclusive, so their relative order is irrelevant.
        let retry_problem = match self.retry_policy.max_attempts {
            0 => Some("max_attempts must be greater than 0".to_string()),
            attempts if attempts > MAX_RETRY_ATTEMPTS => Some(format!(
                "max_attempts must be <= {MAX_RETRY_ATTEMPTS}, got {}",
                attempts
            )),
            _ => None,
        };
        if let Some(reason) = retry_problem {
            return Err(JudgeRegistryError::InvalidRetryPolicy { reason });
        }
        Ok(JudgeRegistry {
            client: self.client,
            model_id,
            retry_policy: self.retry_policy,
            batch_size,
            url_filter: self.url_filter,
        })
    }
}
/// Validation failures reported by `JudgeRegistryBuilder::build`.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum JudgeRegistryError {
/// `model_id` was empty or whitespace-only.
#[error("judge registry requires an explicit model_id")]
MissingModelId,
/// `batch_size` was 0 or greater than `MAX_BATCH_SIZE`.
#[error("judge batch_size must be in 1..={MAX_BATCH_SIZE}, got {batch_size}")]
InvalidBatchSize { batch_size: usize },
/// `RetryPolicy::max_attempts` was 0 or above `MAX_RETRY_ATTEMPTS`; `reason` says which.
#[error("invalid judge retry policy: {reason}")]
InvalidRetryPolicy { reason: String },
}
/// Result produced by a [`JudgeClient`] for one prompt.
///
/// It is serde-serializable; `JudgeCache` persists each verdict as a JSON file.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct JudgeVerdict {
/// Numeric score from the judge (scale/range not defined in this module — TODO confirm).
pub score: f64,
/// Pass/fail decision (presumably reported by the judge; confirm how it relates to `score`).
pub pass: bool,
/// Optional free-form explanation.
pub reason: Option<String>,
/// Optional categorical label (this module's tests use `"equivalent"`).
pub label: Option<String>,
}
/// Failure modes of [`JudgeClient::judge`].
#[derive(Debug, thiserror::Error)]
pub enum JudgeError {
/// Presumably a network/transport failure; carries a description. Displays as `transport: …`.
#[error("transport: {0}")]
Transport(String),
/// The judge did not answer in time. Displays as `timeout`.
#[error("timeout")]
Timeout,
/// The judge's reply could not be interpreted. Displays as `malformed response: …`.
#[error("malformed response: {0}")]
MalformedResponse(String),
/// Any other failure. Displays as `other: …`.
#[error("other: {0}")]
Other(String),
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Every `JudgeError` variant renders the message declared in its `#[error]` attribute.
    #[test]
    fn judge_error_display_variants() {
        let cases = [
            (JudgeError::Transport("boom".into()), "transport: boom"),
            (JudgeError::Timeout, "timeout"),
            (
                JudgeError::MalformedResponse("bad".into()),
                "malformed response: bad",
            ),
            (JudgeError::Other("thing".into()), "other: thing"),
        ];
        for (error, expected) in cases {
            assert_eq!(error.to_string(), expected);
        }
    }
    /// All `JudgeVerdict` fields can be set and read back, including by destructuring.
    #[test]
    fn verdict_fields_are_public() {
        let JudgeVerdict {
            score,
            pass,
            reason,
            label,
        } = JudgeVerdict {
            score: 0.75,
            pass: true,
            reason: Some("looks right".into()),
            label: Some("equivalent".into()),
        };
        assert!((score - 0.75).abs() < f64::EPSILON);
        assert!(pass);
        assert_eq!(reason.as_deref(), Some("looks right"));
        assert_eq!(label.as_deref(), Some("equivalent"));
    }
}