1#![forbid(unsafe_code)]
32
33use std::collections::{HashMap, VecDeque};
34use std::fmt;
35use std::fs;
36use std::future::Future;
37use std::io;
38use std::path::{Path, PathBuf};
39use std::pin::Pin;
40use std::sync::Arc;
41use std::time::Duration;
42
43use serde::{Deserialize, Serialize};
44use sha2::{Digest, Sha256};
45
46pub use crate::url_filter::{DefaultUrlFilter, UrlFilter};
47
/// Largest `max_attempts` value that [`JudgeRegistryBuilder::build`] accepts
/// in a [`RetryPolicy`].
pub const MAX_RETRY_ATTEMPTS: u32 = 16;

/// Largest `batch_size` that [`JudgeRegistryBuilder::build`] accepts.
pub const MAX_BATCH_SIZE: usize = 128;

/// Entry capacity used by [`JudgeCache::new`].
pub const DEFAULT_JUDGE_CACHE_CAPACITY: usize = 1024;
56
/// A client that can judge a prompt and produce a [`JudgeVerdict`].
///
/// Implementors must be `Send + Sync`; [`JudgeRegistry`] stores the client as
/// an `Arc<dyn JudgeClient>`.
pub trait JudgeClient: Send + Sync {
    /// Starts judging `prompt` and returns a boxed future that resolves to the
    /// verdict or to a [`JudgeError`].
    ///
    /// The returned future may borrow both `self` and `prompt` for `'a`.
    fn judge<'a>(&'a self, prompt: &'a str) -> JudgeFuture<'a>;
}
71
/// Boxed, pinned, `Send` future returned by [`JudgeClient::judge`].
///
/// It resolves to `Ok(JudgeVerdict)` or `Err(JudgeError)` and may borrow data
/// for the lifetime `'a`.
pub type JudgeFuture<'a> =
    Pin<Box<dyn Future<Output = Result<JudgeVerdict, JudgeError>> + Send + 'a>>;
75
/// Retry configuration carried by a [`JudgeRegistry`].
///
/// Nothing in this part of the file performs retries, so the notes on
/// `max_delay` and `jitter` below are inferred from their names.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryPolicy {
    /// Attempt budget. [`JudgeRegistryBuilder::build`] rejects `0` and any
    /// value above [`MAX_RETRY_ATTEMPTS`].
    pub max_attempts: u32,
    /// NOTE(review): presumably the upper bound on the delay between
    /// attempts — confirm against the retry loop.
    pub max_delay: Duration,
    /// NOTE(review): presumably enables randomized jitter on the delay —
    /// confirm against the retry loop.
    pub jitter: bool,
}
86
/// 32-byte key identifying one cached verdict.
///
/// Keys are produced either by [`CacheKey::for_prompt`] (a SHA-256 digest) or
/// by decoding the 64-character hex stem of a cache file name. `Debug` is
/// implemented by hand so the key prints as hex rather than as a byte array.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct CacheKey([u8; 32]);
93
94impl CacheKey {
95 #[must_use]
97 pub fn for_prompt(model_id: &str, prompt: &str) -> Self {
98 let mut hasher = Sha256::new();
99 update_with_len_prefixed_bytes(&mut hasher, model_id.as_bytes());
100 update_with_len_prefixed_bytes(&mut hasher, prompt.as_bytes());
101 Self(hasher.finalize().into())
102 }
103
104 #[must_use]
106 pub const fn as_bytes(&self) -> &[u8; 32] {
107 &self.0
108 }
109
110 fn to_hex(self) -> String {
111 hex_lower(self.as_bytes())
112 }
113}
114
115impl fmt::Debug for CacheKey {
116 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
117 f.debug_tuple("CacheKey")
118 .field(&hex_lower(self.as_bytes()))
119 .finish()
120 }
121}
122
/// Bounded cache of [`JudgeVerdict`]s keyed by [`CacheKey`], evicting the
/// least recently used entry when full and optionally persisted as one JSON
/// file per entry.
#[derive(Debug)]
pub struct JudgeCache {
    /// Maximum number of entries kept; the constructors raise `0` to `1`.
    capacity: usize,
    /// The cached verdicts.
    entries: HashMap<CacheKey, JudgeVerdict>,
    /// Keys ordered from least recently used (front) to most recently used
    /// (back); eviction pops from the front.
    recency: VecDeque<CacheKey>,
    /// Directory holding the `<hex key>.json` files, when persistence is on.
    disk_path: Option<PathBuf>,
    /// Set by `put`, cleared by a successful `flush_to_disk`; `Drop` flushes
    /// while it is set.
    dirty: bool,
}
136
137impl JudgeCache {
138 #[must_use]
140 pub fn new() -> Self {
141 Self::with_capacity(DEFAULT_JUDGE_CACHE_CAPACITY)
142 }
143
144 #[must_use]
149 pub fn with_capacity(capacity: usize) -> Self {
150 Self {
151 capacity: capacity.max(1),
152 entries: HashMap::new(),
153 recency: VecDeque::new(),
154 disk_path: None,
155 dirty: false,
156 }
157 }
158
159 pub fn with_disk_path(capacity: usize, path: impl Into<PathBuf>) -> io::Result<Self> {
165 let disk_path = path.into();
166 fs::create_dir_all(&disk_path)?;
167
168 let mut cache = Self {
169 capacity: capacity.max(1),
170 entries: HashMap::new(),
171 recency: VecDeque::new(),
172 disk_path: Some(disk_path.clone()),
173 dirty: false,
174 };
175
176 let mut files = fs::read_dir(&disk_path)?
177 .filter_map(Result::ok)
178 .map(|entry| entry.path())
179 .filter(|path| path.extension().and_then(|ext| ext.to_str()) == Some("json"))
180 .collect::<Vec<_>>();
181 files.sort();
182
183 for path in files {
184 let Some(key) = cache_key_from_path(&path) else {
185 continue;
186 };
187 let Ok(bytes) = fs::read(&path) else {
188 continue;
189 };
190 let Ok(verdict) = serde_json::from_slice::<JudgeVerdict>(&bytes) else {
191 continue;
192 };
193 cache.put_loaded(key, verdict);
194 }
195
196 cache.dirty = false;
197 Ok(cache)
198 }
199
200 #[must_use]
202 pub fn len(&self) -> usize {
203 self.entries.len()
204 }
205
206 #[must_use]
208 pub fn is_empty(&self) -> bool {
209 self.entries.is_empty()
210 }
211
212 #[must_use]
214 pub const fn capacity(&self) -> usize {
215 self.capacity
216 }
217
218 #[must_use]
220 pub fn disk_path(&self) -> Option<&Path> {
221 self.disk_path.as_deref()
222 }
223
224 pub fn get(&mut self, key: &CacheKey) -> Option<JudgeVerdict> {
226 let verdict = self.entries.get(key).cloned();
227 if verdict.is_some() {
228 self.touch(*key);
229 }
230 verdict
231 }
232
233 pub fn put(&mut self, key: CacheKey, verdict: JudgeVerdict) {
235 let replacing = self.entries.insert(key, verdict).is_some();
236 self.touch(key);
237 self.dirty = true;
238
239 if !replacing {
240 self.evict_over_capacity();
241 }
242 }
243
244 pub fn flush_to_disk(&mut self) -> io::Result<()> {
246 let Some(path) = self.disk_path.as_ref() else {
247 self.dirty = false;
248 return Ok(());
249 };
250
251 fs::create_dir_all(path)?;
252 for entry in fs::read_dir(path)? {
253 let entry = entry?;
254 let path = entry.path();
255 if path.extension().and_then(|ext| ext.to_str()) != Some("json") {
256 continue;
257 }
258 let remove =
259 cache_key_from_path(&path).is_some_and(|key| !self.entries.contains_key(&key));
260 if remove {
261 fs::remove_file(path)?;
262 }
263 }
264
265 for (key, verdict) in &self.entries {
266 let path = path.join(format!("{}.json", key.to_hex()));
267 let bytes = serde_json::to_vec(verdict).map_err(io::Error::other)?;
268 fs::write(path, bytes)?;
269 }
270
271 self.dirty = false;
272 Ok(())
273 }
274
275 fn put_loaded(&mut self, key: CacheKey, verdict: JudgeVerdict) {
276 let replacing = self.entries.insert(key, verdict).is_some();
277 self.touch(key);
278
279 if !replacing {
280 self.evict_over_capacity();
281 }
282 }
283
284 fn touch(&mut self, key: CacheKey) {
285 self.recency.retain(|candidate| candidate != &key);
286 self.recency.push_back(key);
287 }
288
289 fn evict_over_capacity(&mut self) {
290 while self.entries.len() > self.capacity {
291 if let Some(oldest) = self.recency.pop_front() {
292 self.entries.remove(&oldest);
293 } else {
294 break;
295 }
296 }
297 }
298}
299
300impl Default for JudgeCache {
301 fn default() -> Self {
302 Self::new()
303 }
304}
305
306impl Drop for JudgeCache {
307 fn drop(&mut self) {
308 if self.dirty {
309 let _ = self.flush_to_disk();
310 }
311 }
312}
313
314fn update_with_len_prefixed_bytes(hasher: &mut Sha256, bytes: &[u8]) {
315 hasher.update(bytes.len().to_le_bytes());
316 hasher.update(bytes);
317}
318
/// Encodes `bytes` as lowercase hexadecimal, two characters per byte.
fn hex_lower(bytes: &[u8]) -> String {
    const DIGITS: [char; 16] = [
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
    ];

    let mut encoded = String::with_capacity(bytes.len() * 2);
    // High nibble first, then low nibble.
    encoded.extend(bytes.iter().flat_map(|&byte| {
        [
            DIGITS[usize::from(byte >> 4)],
            DIGITS[usize::from(byte & 0x0f)],
        ]
    }));
    encoded
}
328
329fn cache_key_from_path(path: &Path) -> Option<CacheKey> {
330 let stem = path.file_stem()?.to_str()?;
331 cache_key_from_hex(stem)
332}
333
334fn cache_key_from_hex(hex: &str) -> Option<CacheKey> {
335 if hex.len() != 64 {
336 return None;
337 }
338
339 let mut bytes = [0_u8; 32];
340 let raw = hex.as_bytes();
341 for (idx, byte) in bytes.iter_mut().enumerate() {
342 let high = hex_nibble(raw[idx * 2])?;
343 let low = hex_nibble(raw[idx * 2 + 1])?;
344 *byte = (high << 4) | low;
345 }
346
347 Some(CacheKey(bytes))
348}
349
/// Converts one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value `0..=15`.
///
/// Returns `None` for every other byte.
fn hex_nibble(byte: u8) -> Option<u8> {
    // Fold case once so a single letter range covers `a-f` and `A-F`;
    // digits and non-ASCII bytes are left untouched by the fold.
    let folded = byte.to_ascii_lowercase();
    if folded.is_ascii_digit() {
        Some(folded - b'0')
    } else if (b'a'..=b'f').contains(&folded) {
        Some(folded - b'a' + 10)
    } else {
        None
    }
}
358
impl RetryPolicy {
    /// Builds a policy from its three fields without validating them.
    ///
    /// Range checks on `max_attempts` happen later, in
    /// [`JudgeRegistryBuilder::build`].
    #[must_use]
    pub const fn new(max_attempts: u32, max_delay: Duration, jitter: bool) -> Self {
        Self {
            max_attempts,
            max_delay,
            jitter,
        }
    }
}
370
371impl Default for RetryPolicy {
372 fn default() -> Self {
373 Self {
374 max_attempts: 6,
375 max_delay: Duration::from_secs(240),
376 jitter: true,
377 }
378 }
379}
380
/// Validated judge configuration: a client plus the model id, retry policy,
/// batch size and URL filter that go with it.
///
/// Instances come from [`JudgeRegistry::builder`] followed by
/// [`JudgeRegistryBuilder::build`], which enforces the invariants noted on
/// the fields.
pub struct JudgeRegistry {
    /// Client used to obtain verdicts.
    client: Arc<dyn JudgeClient>,
    /// Model identifier; trimmed and non-empty.
    model_id: String,
    /// Retry settings; `max_attempts` is within `1..=MAX_RETRY_ATTEMPTS`.
    retry_policy: RetryPolicy,
    /// Batch size; within `1..=MAX_BATCH_SIZE`.
    batch_size: usize,
    /// URL filter supplied to the builder (`DefaultUrlFilter` by default).
    url_filter: Arc<dyn UrlFilter>,
}
393
394impl std::fmt::Debug for JudgeRegistry {
395 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
396 f.debug_struct("JudgeRegistry")
397 .field("model_id", &self.model_id)
398 .field("retry_policy", &self.retry_policy)
399 .field("batch_size", &self.batch_size)
400 .finish_non_exhaustive()
401 }
402}
403
impl JudgeRegistry {
    /// Starts a builder for a registry that uses `client` and `model_id`.
    ///
    /// The builder starts with `RetryPolicy::default()`, a batch size of `1`
    /// and `DefaultUrlFilter`. Nothing is validated until
    /// [`JudgeRegistryBuilder::build`] is called.
    #[must_use]
    pub fn builder(
        client: Arc<dyn JudgeClient>,
        model_id: impl Into<String>,
    ) -> JudgeRegistryBuilder {
        JudgeRegistryBuilder {
            client,
            model_id: model_id.into(),
            retry_policy: RetryPolicy::default(),
            batch_size: 1,
            url_filter: Arc::new(DefaultUrlFilter),
        }
    }

    /// The shared judge client.
    #[must_use]
    pub fn client(&self) -> &Arc<dyn JudgeClient> {
        &self.client
    }

    /// The trimmed, non-empty model identifier.
    #[must_use]
    pub fn model_id(&self) -> &str {
        &self.model_id
    }

    /// The validated retry policy.
    #[must_use]
    pub const fn retry_policy(&self) -> &RetryPolicy {
        &self.retry_policy
    }

    /// The validated batch size (`1..=MAX_BATCH_SIZE`).
    #[must_use]
    pub const fn batch_size(&self) -> usize {
        self.batch_size
    }

    /// The shared URL filter.
    #[must_use]
    pub fn url_filter(&self) -> &Arc<dyn UrlFilter> {
        &self.url_filter
    }
}
450
/// Builder for [`JudgeRegistry`], created by [`JudgeRegistry::builder`].
///
/// Holds the not-yet-validated settings; [`JudgeRegistryBuilder::build`]
/// checks them.
pub struct JudgeRegistryBuilder {
    /// Client the registry will use.
    client: Arc<dyn JudgeClient>,
    /// Model identifier as supplied by the caller (trimmed during `build`).
    model_id: String,
    /// Retry settings; starts as `RetryPolicy::default()`.
    retry_policy: RetryPolicy,
    /// Batch size; starts as `1`.
    batch_size: usize,
    /// URL filter; starts as `DefaultUrlFilter`.
    url_filter: Arc<dyn UrlFilter>,
}
459
460impl JudgeRegistryBuilder {
461 #[must_use]
463 pub fn with_retry_policy(mut self, retry_policy: RetryPolicy) -> Self {
464 self.retry_policy = retry_policy;
465 self
466 }
467
468 #[must_use]
470 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
471 self.batch_size = batch_size;
472 self
473 }
474
475 #[must_use]
477 pub fn with_url_filter(mut self, url_filter: Arc<dyn UrlFilter>) -> Self {
478 self.url_filter = url_filter;
479 self
480 }
481
482 pub fn build(self) -> Result<JudgeRegistry, JudgeRegistryError> {
484 let model_id = self.model_id.trim().to_string();
485 if model_id.is_empty() {
486 return Err(JudgeRegistryError::MissingModelId);
487 }
488 if !(1..=MAX_BATCH_SIZE).contains(&self.batch_size) {
489 return Err(JudgeRegistryError::InvalidBatchSize {
490 batch_size: self.batch_size,
491 });
492 }
493 if self.retry_policy.max_attempts > MAX_RETRY_ATTEMPTS {
494 return Err(JudgeRegistryError::InvalidRetryPolicy {
495 reason: format!(
496 "max_attempts must be <= {MAX_RETRY_ATTEMPTS}, got {}",
497 self.retry_policy.max_attempts
498 ),
499 });
500 }
501 if self.retry_policy.max_attempts == 0 {
502 return Err(JudgeRegistryError::InvalidRetryPolicy {
503 reason: "max_attempts must be greater than 0".to_string(),
504 });
505 }
506
507 Ok(JudgeRegistry {
508 client: self.client,
509 model_id,
510 retry_policy: self.retry_policy,
511 batch_size: self.batch_size,
512 url_filter: self.url_filter,
513 })
514 }
515}
516
/// Validation failures reported by [`JudgeRegistryBuilder::build`].
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum JudgeRegistryError {
    /// The model id was empty after trimming whitespace.
    #[error("judge registry requires an explicit model_id")]
    MissingModelId,
    /// The batch size was `0` or larger than [`MAX_BATCH_SIZE`].
    #[error("judge batch_size must be in 1..={MAX_BATCH_SIZE}, got {batch_size}")]
    InvalidBatchSize { batch_size: usize },
    /// The retry policy was rejected; `reason` says which bound was violated.
    #[error("invalid judge retry policy: {reason}")]
    InvalidRetryPolicy { reason: String },
}
530
/// Result produced by a [`JudgeClient`]; this is also the value stored in
/// [`JudgeCache`] and serialized to its JSON files.
///
/// The field notes below are inferred from the names — this part of the file
/// does not show how a verdict is produced or interpreted.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct JudgeVerdict {
    /// NOTE(review): presumably the numeric score assigned by the judge;
    /// range and scale are not visible here — confirm.
    pub score: f64,
    /// NOTE(review): presumably whether the judged item passed — confirm.
    pub pass: bool,
    /// Optional free-text explanation, if the judge supplied one.
    pub reason: Option<String>,
    /// Optional label, if the judge supplied one.
    pub label: Option<String>,
}
547
/// Errors a [`JudgeClient`] can return from `judge`.
#[derive(Debug, thiserror::Error)]
pub enum JudgeError {
    /// Transport-level failure; displayed as `transport: <message>`.
    #[error("transport: {0}")]
    Transport(String),
    /// The request timed out; displayed as `timeout`.
    #[error("timeout")]
    Timeout,
    /// The response could not be interpreted; displayed as
    /// `malformed response: <message>`.
    #[error("malformed response: {0}")]
    MalformedResponse(String),
    /// Any other failure; displayed as `other: <message>`.
    #[error("other: {0}")]
    Other(String),
}
567
#[cfg(test)]
mod tests {
    use super::*;

    /// Each `JudgeError` variant renders the exact message callers rely on.
    #[test]
    fn judge_error_display_variants() {
        let cases = [
            (JudgeError::Transport("boom".into()), "transport: boom"),
            (JudgeError::Timeout, "timeout"),
            (
                JudgeError::MalformedResponse("bad".into()),
                "malformed response: bad",
            ),
            (JudgeError::Other("thing".into()), "other: thing"),
        ];

        for (error, expected) in cases {
            assert_eq!(error.to_string(), expected);
        }
    }

    /// A verdict can be built and read field by field from outside the type.
    #[test]
    fn verdict_fields_are_public() {
        let JudgeVerdict {
            score,
            pass,
            reason,
            label,
        } = JudgeVerdict {
            score: 0.75,
            pass: true,
            reason: Some("looks right".into()),
            label: Some("equivalent".into()),
        };

        assert!((score - 0.75).abs() < f64::EPSILON);
        assert!(pass);
        assert_eq!(reason.as_deref(), Some("looks right"));
        assert_eq!(label.as_deref(), Some("equivalent"));
    }
}