Skip to main content

batuta/serve/banco/
handlers_batch.rs

1//! Batch inference endpoint handlers.
2
3use axum::{extract::State, http::StatusCode, response::Json};
4use serde::Deserialize;
5
6use super::batch::{BatchItem, BatchItemResult, BatchJob};
7use super::state::BancoState;
8use super::types::ErrorResponse;
9
10/// POST /api/v1/batch — submit a batch of prompts for processing.
11pub async fn submit_batch_handler(
12    State(state): State<BancoState>,
13    Json(request): Json<BatchRequest>,
14) -> Json<BatchJob> {
15    let job = state.batches.run(request.items, |item| {
16        // Try inference when model is loaded
17        #[cfg(feature = "realizar")]
18        if let Some(result) = try_batch_inference(&state, item) {
19            return result;
20        }
21
22        // Dry-run fallback
23        let formatted = state.template_engine.apply(&item.messages);
24        let content = format!(
25            "[batch dry-run] id={} | prompt_len={} | formatted_len={}",
26            item.id,
27            item.messages.len(),
28            formatted.len()
29        );
30        let tokens = (content.len() / 4) as u32;
31        BatchItemResult {
32            id: item.id.clone(),
33            content,
34            finish_reason: "dry_run".to_string(),
35            tokens,
36        }
37    });
38    Json(job)
39}
40
/// Try to run real model inference for a single batch item.
///
/// Returns `None` in any of these cases, and the caller then falls back to its
/// dry-run response:
/// - no quantized model is loaded;
/// - the vocabulary is empty;
/// - the templated prompt encodes to zero tokens;
/// - `generate_sync` returns an error.
///
/// Temperature and `top_k` come from the server-wide `inference_params`; `max_tokens`
/// comes from the item itself.
#[cfg(feature = "realizar")]
fn try_batch_inference(state: &BancoState, item: &BatchItem) -> Option<BatchItemResult> {
    let model = state.model.quantized_model()?;
    let vocab = state.model.vocabulary();
    if vocab.is_empty() {
        return None;
    }

    let formatted = state.template_engine.apply(&item.messages);
    let prompt_tokens = state.model.encode_text(&formatted);
    if prompt_tokens.is_empty() {
        return None;
    }

    // Copy the sampling settings out under a short-lived read guard; the block scope
    // releases the lock before generation starts. A poisoned lock only means a writer
    // panicked, and the fields copied here are plain `Copy` values. So recover the
    // guard rather than silently degrading every later batch item to a dry run.
    let params = {
        let server_params = state
            .inference_params
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        super::inference::SamplingParams {
            temperature: server_params.temperature,
            top_k: server_params.top_k,
            max_tokens: item.max_tokens,
        }
    };

    // NOTE(review): the generation error is discarded and the caller falls back to a
    // dry-run result; consider surfacing the failure to the client instead.
    super::inference::generate_sync(&model, &vocab, &prompt_tokens, &params)
        .ok()
        .map(|result| BatchItemResult {
            id: item.id.clone(),
            content: result.text,
            finish_reason: result.finish_reason,
            tokens: result.token_count,
        })
}
74
75/// GET /api/v1/batch/:id — get batch job status and results.
76pub async fn get_batch_handler(
77    State(state): State<BancoState>,
78    axum::extract::Path(id): axum::extract::Path<String>,
79) -> Result<Json<BatchJob>, (StatusCode, Json<ErrorResponse>)> {
80    state.batches.get(&id).map(Json).ok_or((
81        StatusCode::NOT_FOUND,
82        Json(ErrorResponse::new(format!("Batch {id} not found"), "not_found", 404)),
83    ))
84}
85
86/// GET /api/v1/batch — list all batch jobs.
87pub async fn list_batches_handler(State(state): State<BancoState>) -> Json<BatchListResponse> {
88    Json(BatchListResponse { batches: state.batches.list() })
89}
90
91#[derive(Debug, Deserialize)]
92pub struct BatchRequest {
93    pub items: Vec<BatchItem>,
94}
95
96#[derive(Debug, serde::Serialize)]
97pub struct BatchListResponse {
98    pub batches: Vec<BatchJob>,
99}