batuta/serve/banco/
handlers_tokens.rs1use axum::{extract::State, response::Json};
4
5use super::state::BancoState;
6use super::types::{
7 DetokenizeRequest, DetokenizeResponse, EmbeddingData, EmbeddingsRequest, EmbeddingsResponse,
8 EmbeddingsUsage, InferenceParams, TokenizeRequest, TokenizeResponse,
9};
10use crate::serve::templates::ChatMessage;
11
12pub async fn embeddings_handler(
17 State(state): State<BancoState>,
18 Json(request): Json<EmbeddingsRequest>,
19) -> Json<EmbeddingsResponse> {
20 let model_name = request.model.unwrap_or_else(|| "banco-heuristic".to_string());
21 let texts = request.input.texts();
22
23 #[cfg(feature = "realizar")]
24 let use_model = state.model.has_inference_model();
25
26 let data: Vec<EmbeddingData> = texts
27 .iter()
28 .enumerate()
29 .map(|(i, text)| {
30 #[cfg(feature = "realizar")]
31 let embedding = if use_model {
32 model_embedding(&state, text).unwrap_or_else(|| heuristic_embedding(text))
33 } else {
34 heuristic_embedding(text)
35 };
36 #[cfg(not(feature = "realizar"))]
37 let embedding = heuristic_embedding(text);
38 EmbeddingData { object: "embedding".to_string(), index: i as u32, embedding }
39 })
40 .collect();
41
42 let total_tokens: u32 = texts
43 .iter()
44 .map(|t| state.context_manager.estimate_tokens(&[ChatMessage::user(*t)]) as u32)
45 .sum();
46
47 Json(EmbeddingsResponse {
48 object: "list".to_string(),
49 data,
50 model: model_name,
51 usage: EmbeddingsUsage { prompt_tokens: total_tokens, total_tokens },
52 })
53}
54
/// Embed `text` with the loaded quantized model, if one is available.
///
/// Returns `None` when no quantized model is loaded or when the
/// inference backend produces no embedding for the encoded tokens.
#[cfg(feature = "realizar")]
fn model_embedding(state: &BancoState, text: &str) -> Option<Vec<f32>> {
    let model = match state.model.quantized_model() {
        Some(m) => m,
        None => return None,
    };
    let ids = state.model.encode_text(text);
    super::inference::embed_tokens(&model, &ids)
}
61
/// Deterministic pseudo-embedding used when no real model is available.
///
/// A 64-bit FNV-1a digest of the text's UTF-8 bytes seeds an MMIX-style
/// linear congruential generator; 128 successive LCG outputs are mapped
/// into roughly [-1.0, 1.0). The same text always yields the same vector.
fn heuristic_embedding(text: &str) -> Vec<f32> {
    // FNV-1a: 64-bit offset basis and prime.
    const FNV_OFFSET: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0100_0000_01b3;

    let mut seed = FNV_OFFSET;
    for byte in text.bytes() {
        seed ^= u64::from(byte);
        seed = seed.wrapping_mul(FNV_PRIME);
    }

    // Expand the digest into 128 components via an LCG; the top bits
    // (after >> 33) are the highest-quality bits of this generator.
    let mut rng = seed;
    let mut out = Vec::with_capacity(128);
    for _ in 0..128 {
        rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
        out.push(((rng >> 33) as f32 / (u32::MAX as f32 / 2.0)) - 1.0);
    }
    out
}
76
77pub async fn tokenize_handler(
82 State(state): State<BancoState>,
83 Json(request): Json<TokenizeRequest>,
84) -> Json<TokenizeResponse> {
85 #[cfg(feature = "realizar")]
86 if state.model.has_inference_model() {
87 let tokens = state.model.encode_text(&request.text);
88 if !tokens.is_empty() {
89 let count = tokens.len() as u32;
90 return Json(TokenizeResponse { tokens, count });
91 }
92 }
93
94 let estimated = state.context_manager.estimate_tokens(&[ChatMessage::user(&request.text)]);
95 let tokens: Vec<u32> = (0..estimated as u32).collect();
96 Json(TokenizeResponse { count: estimated as u32, tokens })
97}
98
99pub async fn detokenize_handler(
100 State(state): State<BancoState>,
101 Json(request): Json<DetokenizeRequest>,
102) -> Json<DetokenizeResponse> {
103 #[cfg(feature = "realizar")]
104 if state.model.has_inference_model() {
105 let vocab = state.model.vocabulary();
106 if !vocab.is_empty() {
107 let text: String =
108 request.tokens.iter().filter_map(|&id| vocab.get(id as usize)).cloned().collect();
109 return Json(DetokenizeResponse { text });
110 }
111 }
112
113 let _ = &state;
114 let approx_chars = request.tokens.len() * 4;
115 let text = format!("[{} tokens ≈ {} chars]", request.tokens.len(), approx_chars);
116 Json(DetokenizeResponse { text })
117}
118
119pub async fn get_parameters_handler(State(state): State<BancoState>) -> Json<InferenceParams> {
124 let params = state.inference_params.read().unwrap_or_else(|e| e.into_inner());
125 Json(params.clone())
126}
127
128pub async fn update_parameters_handler(
129 State(state): State<BancoState>,
130 Json(update): Json<InferenceParams>,
131) -> Json<InferenceParams> {
132 if let Ok(mut params) = state.inference_params.write() {
133 *params = update;
134 }
135 let params = state.inference_params.read().unwrap_or_else(|e| e.into_inner());
136 Json(params.clone())
137}