// batuta/serve/banco/eval.rs

1//! Model evaluation — perplexity and benchmarks.
2//!
3//! Uses the existing inference engine to compute perplexity on text samples.
4//! Perplexity measures how well the model predicts the next token.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9
/// Eval run result.
///
/// Serialized as-is for API responses (fields keep their Rust names;
/// `status` serializes in snake_case per [`EvalStatus`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalResult {
    /// Unique run identifier, `eval-{epoch_secs}-{seq}` (see `EvalStore::next_id`).
    pub eval_id: String,
    /// Model identifier supplied by the caller — presumably a name or path; confirm at call site.
    pub model: String,
    /// Metric name (e.g. "perplexity") — free-form string chosen by the caller.
    pub metric: String,
    /// Metric value; for perplexity this is `exp(-mean log prob)`.
    pub value: f64,
    /// Number of tokens actually scored during the eval.
    pub tokens_evaluated: usize,
    /// Wall-clock duration of the run in seconds.
    pub duration_secs: f64,
    /// Current lifecycle state of the run.
    pub status: EvalStatus,
}
21
/// Eval status.
///
/// Serializes in snake_case (`"running"`, `"complete"`, `"failed"`, `"no_model"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum EvalStatus {
    /// Run started but has not finished yet.
    Running,
    /// Run finished and `value` holds the final metric.
    Complete,
    /// Run aborted with an error.
    Failed,
    /// No model was loaded, so the eval could not run
    /// (see `compute_perplexity`, which requires a loaded model).
    NoModel,
}
31
/// Eval store — tracks evaluation runs.
///
/// Thread-safe: the run map is behind an `RwLock` and ID sequence numbers
/// come from an atomic counter, so the store can be shared via `Arc`.
pub struct EvalStore {
    // Runs keyed by `eval_id`.
    runs: RwLock<HashMap<String, EvalResult>>,
    // Monotonic sequence number folded into generated eval IDs.
    counter: std::sync::atomic::AtomicU64,
}
37
38impl EvalStore {
39    #[must_use]
40    pub fn new() -> Arc<Self> {
41        Arc::new(Self {
42            runs: RwLock::new(HashMap::new()),
43            counter: std::sync::atomic::AtomicU64::new(0),
44        })
45    }
46
47    /// Record an eval result.
48    pub fn record(&self, result: EvalResult) {
49        if let Ok(mut store) = self.runs.write() {
50            store.insert(result.eval_id.clone(), result);
51        }
52    }
53
54    /// List all eval runs (most recent first).
55    #[must_use]
56    pub fn list(&self) -> Vec<EvalResult> {
57        let store = self.runs.read().unwrap_or_else(|e| e.into_inner());
58        let mut runs: Vec<EvalResult> = store.values().cloned().collect();
59        runs.sort_by(|a, b| b.eval_id.cmp(&a.eval_id));
60        runs
61    }
62
63    /// Get an eval run by ID.
64    #[must_use]
65    pub fn get(&self, id: &str) -> Option<EvalResult> {
66        self.runs.read().unwrap_or_else(|e| e.into_inner()).get(id).cloned()
67    }
68
69    /// Generate a unique eval ID.
70    pub fn next_id(&self) -> String {
71        let seq = self.counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
72        format!("eval-{}-{seq}", epoch_secs())
73    }
74}
75
/// Compute perplexity on a text sample using the inference engine.
///
/// PPL = exp(-1/N * Σ log P(token_i | context))
///
/// Requires a loaded model with inference feature. Returns None without.
/// Accepts pre-tokenized IDs so the caller can use proper BPE encoding.
///
/// Returns `Some((perplexity, tokens_scored))`, or `None` when fewer than
/// two tokens are available to score (after truncation to `max_tokens`)
/// or any forward pass fails.
#[cfg(feature = "realizar")]
pub fn compute_perplexity(
    model: &Arc<realizar::gguf::OwnedQuantizedModel>,
    token_ids: &[u32],
    max_tokens: usize,
) -> Option<(f64, usize)> {
    // Guard on the *truncated* length: with max_tokens == 0 the previous
    // `0..eval_len - 1` loop bound underflowed usize (panic in debug builds,
    // wraparound in release). Scoring needs at least one (context, target) pair.
    let eval_len = token_ids.len().min(max_tokens);
    if eval_len < 2 {
        return None;
    }

    let config = model.config();
    let num_kv_heads = config.num_kv_heads;
    let head_dim = config.hidden_dim / config.num_heads;
    let kv_dim = num_kv_heads * head_dim;

    let mut cache =
        realizar::gguf::OwnedQuantizedKVCache::new(config.num_layers, kv_dim, eval_len + 1);

    let mut total_log_prob = 0.0f64;
    let mut count = 0usize;

    // Feed token[pos], then score the model's probability for token[pos + 1].
    // Any forward error aborts the whole eval via `?`.
    for pos in 0..eval_len - 1 {
        let logits = model.forward_single_with_cache(token_ids[pos], &mut cache, pos).ok()?;

        // Numerically stable log-softmax: subtract the max logit before exp.
        let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exp_sum: f32 = logits.iter().map(|&l| (l - max_logit).exp()).sum();
        let next_token = token_ids[pos + 1] as usize;

        // Skip targets outside the logits range rather than panicking on index.
        if next_token < logits.len() {
            // log P(next) = logit - max - ln(Σ exp(logit - max)), accumulated in f64.
            let log_prob = (logits[next_token] - max_logit) as f64 - (exp_sum as f64).ln();
            total_log_prob += log_prob;
            count += 1;
        }
    }

    if count == 0 {
        return None;
    }

    let ppl = (-total_log_prob / count as f64).exp();
    Some((ppl, count))
}
126
/// Seconds since the Unix epoch, or 0 if the system clock predates 1970.
fn epoch_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}