Skip to main content

batuta/serve/banco/
handlers_eval.rs

1//! Eval endpoint handlers.
2
3use axum::{extract::State, http::StatusCode, response::Json};
4use serde::Deserialize;
5
6use super::eval::{EvalResult, EvalStatus};
7use super::state::BancoState;
8use super::types::ErrorResponse;
9
10/// POST /api/v1/eval/perplexity — compute perplexity on text.
11pub async fn eval_perplexity_handler(
12    State(state): State<BancoState>,
13    Json(request): Json<PerplexityRequest>,
14) -> Result<Json<EvalResult>, (StatusCode, Json<ErrorResponse>)> {
15    let eval_id = state.evals.next_id();
16    let model_name = state.model.info().map(|m| m.model_id).unwrap_or_else(|| "none".to_string());
17
18    #[cfg(feature = "realizar")]
19    let ppl_result = {
20        let model = state.model.quantized_model();
21        match model {
22            Some(m) => {
23                let token_ids = state.model.encode_text(&request.text);
24                if token_ids.is_empty() {
25                    None
26                } else {
27                    let max_tokens = request.max_tokens.unwrap_or(512) as usize;
28                    let start = std::time::Instant::now();
29                    let result = super::eval::compute_perplexity(&m, &token_ids, max_tokens);
30                    let duration = start.elapsed().as_secs_f64();
31                    result.map(|(ppl, tokens)| (ppl, tokens, duration))
32                }
33            }
34            _ => None,
35        }
36    };
37    #[cfg(not(feature = "realizar"))]
38    let ppl_result: Option<(f64, usize, f64)> = {
39        let _ = &request;
40        None
41    };
42
43    let result = if let Some((ppl, tokens, duration)) = ppl_result {
44        EvalResult {
45            eval_id,
46            model: model_name,
47            metric: "perplexity".to_string(),
48            value: ppl,
49            tokens_evaluated: tokens,
50            duration_secs: duration,
51            status: EvalStatus::Complete,
52        }
53    } else {
54        EvalResult {
55            eval_id,
56            model: model_name,
57            metric: "perplexity".to_string(),
58            value: 0.0,
59            tokens_evaluated: 0,
60            duration_secs: 0.0,
61            status: EvalStatus::NoModel,
62        }
63    };
64
65    state.evals.record(result.clone());
66    Ok(Json(result))
67}
68
69/// GET /api/v1/eval/runs — list eval runs.
70pub async fn list_eval_runs_handler(State(state): State<BancoState>) -> Json<EvalRunsResponse> {
71    Json(EvalRunsResponse { runs: state.evals.list() })
72}
73
74/// GET /api/v1/eval/runs/:id — get eval result.
75pub async fn get_eval_run_handler(
76    State(state): State<BancoState>,
77    axum::extract::Path(id): axum::extract::Path<String>,
78) -> Result<Json<EvalResult>, (StatusCode, Json<ErrorResponse>)> {
79    state.evals.get(&id).map(Json).ok_or((
80        StatusCode::NOT_FOUND,
81        Json(ErrorResponse::new(format!("Eval run {id} not found"), "not_found", 404)),
82    ))
83}
84
85#[derive(Debug, Deserialize)]
86pub struct PerplexityRequest {
87    pub text: String,
88    #[serde(default)]
89    pub max_tokens: Option<u32>,
90}
91
92#[derive(Debug, serde::Serialize)]
93pub struct EvalRunsResponse {
94    pub runs: Vec<EvalResult>,
95}