// batuta/serve/banco/handlers_tokens.rs
//! Token-related handlers — embeddings, tokenize, detokenize, inference parameters.

use axum::{extract::State, response::Json};

use super::state::BancoState;
use super::types::{
    DetokenizeRequest, DetokenizeResponse, EmbeddingData, EmbeddingsRequest, EmbeddingsResponse,
    EmbeddingsUsage, InferenceParams, TokenizeRequest, TokenizeResponse,
};
use crate::serve::templates::ChatMessage;

// ============================================================================
// Embeddings
// ============================================================================

16pub async fn embeddings_handler(
17    State(state): State<BancoState>,
18    Json(request): Json<EmbeddingsRequest>,
19) -> Json<EmbeddingsResponse> {
20    let model_name = request.model.unwrap_or_else(|| "banco-heuristic".to_string());
21    let texts = request.input.texts();
22
23    #[cfg(feature = "realizar")]
24    let use_model = state.model.has_inference_model();
25
26    let data: Vec<EmbeddingData> = texts
27        .iter()
28        .enumerate()
29        .map(|(i, text)| {
30            #[cfg(feature = "realizar")]
31            let embedding = if use_model {
32                model_embedding(&state, text).unwrap_or_else(|| heuristic_embedding(text))
33            } else {
34                heuristic_embedding(text)
35            };
36            #[cfg(not(feature = "realizar"))]
37            let embedding = heuristic_embedding(text);
38            EmbeddingData { object: "embedding".to_string(), index: i as u32, embedding }
39        })
40        .collect();
41
42    let total_tokens: u32 = texts
43        .iter()
44        .map(|t| state.context_manager.estimate_tokens(&[ChatMessage::user(*t)]) as u32)
45        .sum();
46
47    Json(EmbeddingsResponse {
48        object: "list".to_string(),
49        data,
50        model: model_name,
51        usage: EmbeddingsUsage { prompt_tokens: total_tokens, total_tokens },
52    })
53}
/// Embed `text` with the loaded quantized model, if one is available.
///
/// Returns `None` when no quantized model is loaded, or when the inference
/// backend cannot produce an embedding for the encoded token ids.
#[cfg(feature = "realizar")]
fn model_embedding(state: &BancoState, text: &str) -> Option<Vec<f32>> {
    state
        .model
        .quantized_model()
        .and_then(|model| super::inference::embed_tokens(&model, &state.model.encode_text(text)))
}
/// Deterministic fallback embedding: hashes `text` with FNV-1a, then expands
/// the hash into 128 pseudo-random components in [-1.0, 1.0] via an LCG.
///
/// The values carry no semantic meaning — identical inputs simply always map
/// to identical vectors so the endpoint stays usable without a model.
fn heuristic_embedding(text: &str) -> Vec<f32> {
    // FNV-1a over the raw bytes seeds the generator.
    let mut hash: u64 = 0xcbf2_9ce4_8422_2325;
    for byte in text.bytes() {
        hash ^= byte as u64;
        hash = hash.wrapping_mul(0x0100_0000_01b3);
    }
    // Knuth's MMIX LCG; the top 32 bits are the best-mixed.
    let mut state = hash;
    (0..128)
        .map(|_| {
            state = state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
            // Map the top 32 bits onto [-1.0, 1.0]. Fix: the previous `>> 33`
            // kept only 31 bits, confining every component to [-1, 0) instead
            // of the symmetric range the `/ (MAX/2) - 1` construction intends.
            ((state >> 32) as f32 / (u32::MAX as f32 / 2.0)) - 1.0
        })
        .collect()
}

// ============================================================================
// Tokenize / Detokenize
// ============================================================================

81pub async fn tokenize_handler(
82    State(state): State<BancoState>,
83    Json(request): Json<TokenizeRequest>,
84) -> Json<TokenizeResponse> {
85    #[cfg(feature = "realizar")]
86    if state.model.has_inference_model() {
87        let tokens = state.model.encode_text(&request.text);
88        if !tokens.is_empty() {
89            let count = tokens.len() as u32;
90            return Json(TokenizeResponse { tokens, count });
91        }
92    }
93
94    let estimated = state.context_manager.estimate_tokens(&[ChatMessage::user(&request.text)]);
95    let tokens: Vec<u32> = (0..estimated as u32).collect();
96    Json(TokenizeResponse { count: estimated as u32, tokens })
97}
99pub async fn detokenize_handler(
100    State(state): State<BancoState>,
101    Json(request): Json<DetokenizeRequest>,
102) -> Json<DetokenizeResponse> {
103    #[cfg(feature = "realizar")]
104    if state.model.has_inference_model() {
105        let vocab = state.model.vocabulary();
106        if !vocab.is_empty() {
107            let text: String =
108                request.tokens.iter().filter_map(|&id| vocab.get(id as usize)).cloned().collect();
109            return Json(DetokenizeResponse { text });
110        }
111    }
112
113    let _ = &state;
114    let approx_chars = request.tokens.len() * 4;
115    let text = format!("[{} tokens ≈ {} chars]", request.tokens.len(), approx_chars);
116    Json(DetokenizeResponse { text })
117}

// ============================================================================
// Inference Parameters
// ============================================================================

123pub async fn get_parameters_handler(State(state): State<BancoState>) -> Json<InferenceParams> {
124    let params = state.inference_params.read().unwrap_or_else(|e| e.into_inner());
125    Json(params.clone())
126}
128pub async fn update_parameters_handler(
129    State(state): State<BancoState>,
130    Json(update): Json<InferenceParams>,
131) -> Json<InferenceParams> {
132    if let Ok(mut params) = state.inference_params.write() {
133        *params = update;
134    }
135    let params = state.inference_params.read().unwrap_or_else(|e| e.into_inner());
136    Json(params.clone())
137}