batuta/serve/banco/
handlers_batch.rs1use axum::{extract::State, http::StatusCode, response::Json};
4use serde::Deserialize;
5
6use super::batch::{BatchItem, BatchItemResult, BatchJob};
7use super::state::BancoState;
8use super::types::ErrorResponse;
9
10pub async fn submit_batch_handler(
12 State(state): State<BancoState>,
13 Json(request): Json<BatchRequest>,
14) -> Json<BatchJob> {
15 let job = state.batches.run(request.items, |item| {
16 #[cfg(feature = "realizar")]
18 if let Some(result) = try_batch_inference(&state, item) {
19 return result;
20 }
21
22 let formatted = state.template_engine.apply(&item.messages);
24 let content = format!(
25 "[batch dry-run] id={} | prompt_len={} | formatted_len={}",
26 item.id,
27 item.messages.len(),
28 formatted.len()
29 );
30 let tokens = (content.len() / 4) as u32;
31 BatchItemResult {
32 id: item.id.clone(),
33 content,
34 finish_reason: "dry_run".to_string(),
35 tokens,
36 }
37 });
38 Json(job)
39}
40
/// Attempts real model inference for a single batch item.
///
/// Returns `None`, which signals the caller to fall back to a dry-run result,
/// when any of the following happens:
/// - no quantized model is loaded;
/// - the vocabulary is empty;
/// - the formatted prompt encodes to zero tokens;
/// - the inference-parameter read lock cannot be acquired;
/// - generation returns an error.
#[cfg(feature = "realizar")]
fn try_batch_inference(state: &BancoState, item: &BatchItem) -> Option<BatchItemResult> {
    let model = state.model.quantized_model()?;
    let vocab = state.model.vocabulary();
    if vocab.is_empty() {
        return None;
    }

    let formatted = state.template_engine.apply(&item.messages);
    let prompt_tokens = state.model.encode_text(&formatted);
    if prompt_tokens.is_empty() {
        return None;
    }

    // Copy the server-wide sampling settings out under the read lock. The guard
    // is released at the end of this block, before generation starts, so
    // inference does not hold the lock. A failed lock acquisition (e.g. a
    // poisoned lock) is treated as "inference unavailable".
    let params = {
        let server_params = state.inference_params.read().ok()?;
        super::inference::SamplingParams {
            temperature: server_params.temperature,
            top_k: server_params.top_k,
            max_tokens: item.max_tokens,
        }
    };

    // Generation errors are deliberately not propagated. The caller substitutes
    // a dry-run result instead, so one failing item does not fail the batch.
    super::inference::generate_sync(&model, &vocab, &prompt_tokens, &params)
        .ok()
        .map(|result| BatchItemResult {
            id: item.id.clone(),
            content: result.text,
            finish_reason: result.finish_reason,
            tokens: result.token_count,
        })
}
74
75pub async fn get_batch_handler(
77 State(state): State<BancoState>,
78 axum::extract::Path(id): axum::extract::Path<String>,
79) -> Result<Json<BatchJob>, (StatusCode, Json<ErrorResponse>)> {
80 state.batches.get(&id).map(Json).ok_or((
81 StatusCode::NOT_FOUND,
82 Json(ErrorResponse::new(format!("Batch {id} not found"), "not_found", 404)),
83 ))
84}
85
86pub async fn list_batches_handler(State(state): State<BancoState>) -> Json<BatchListResponse> {
88 Json(BatchListResponse { batches: state.batches.list() })
89}
90
91#[derive(Debug, Deserialize)]
92pub struct BatchRequest {
93 pub items: Vec<BatchItem>,
94}
95
96#[derive(Debug, serde::Serialize)]
97pub struct BatchListResponse {
98 pub batches: Vec<BatchJob>,
99}