batuta/serve/banco/handlers_inference.rs

//! Inference bridge: connects banco handlers to the realizar inference engine.
//!
//! Gated behind `#[cfg(feature = "realizar")]`.
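
// How this module is presumably gated in the parent `banco` module; a sketch
// only, since the parent `mod.rs` is not shown in this file:
//
//     #[cfg(feature = "realizar")]
//     mod handlers_inference;
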
use super::state::BancoState;
use super::types::BancoChatRequest;
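
// The `super::inference` items used below are assumed to look roughly like
// this; a sketch inferred from the call sites, not the actual definitions:
//
//     pub struct SamplingParams {
//         pub temperature: f32,
//         pub top_k: usize,    // field types here are guesses
//         pub max_tokens: u32,
//     }
//
//     pub struct GenerateResult {
//         pub text: String,
//         pub finish_reason: String,
//         pub token_count: u32,
//     }
//
//     pub fn generate_sync(..) -> Result<GenerateResult, _>;
//     pub fn generate_stream_tokens(..) -> Result<Vec<StreamToken>, _>;
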
/// Try to run inference if a model is loaded and the inference feature is enabled.
/// Returns `Some((content, finish_reason, completion_tokens))` on success.
pub fn try_inference(
    state: &BancoState,
    request: &BancoChatRequest,
) -> Option<(String, String, u32)> {
    // Bail out early unless a quantized model and a non-empty vocabulary are loaded.
    let model = state.model.quantized_model()?;
    let vocab = state.model.vocabulary();
    if vocab.is_empty() {
        return None;
    }

    // Render the chat messages through the prompt template, then tokenize.
    let formatted = state.template_engine.apply(&request.messages);
    // Use the proper BPE tokenizer when available, else the greedy fallback.
    let prompt_tokens = state.model.encode_text(&formatted);
    if prompt_tokens.is_empty() {
        return None;
    }

    // Resolve sampling parameters: a request temperature equal to 0.7
    // (presumably the serialized client default) is treated as "unset" and
    // defers to the server-side setting; any other value overrides it.
    let server_params = state.inference_params.read().ok()?;
    let params = super::inference::SamplingParams {
        temperature: if (request.temperature - 0.7).abs() < f32::EPSILON {
            server_params.temperature
        } else {
            request.temperature
        },
        top_k: server_params.top_k,
        max_tokens: request.max_tokens,
    };
    drop(server_params);

    match super::inference::generate_sync(&model, &vocab, &prompt_tokens, &params) {
        Ok(result) => Some((result.text, result.finish_reason, result.token_count)),
        Err(e) => {
            eprintln!("[banco] inference error: {e}");
            None
        }
    }
}
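
// A hypothetical call site in a non-streaming chat handler; the
// `chat_completion_response` helper is illustrative, not an actual banco API:
//
//     if let Some((content, finish_reason, completion_tokens)) =
//         try_inference(&state, &request)
//     {
//         return chat_completion_response(content, finish_reason, completion_tokens);
//     }
//     // ...otherwise fall through to the handler's non-inference path.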

/// Try to generate streaming tokens via inference.
/// Returns `Some` with a vector of `(text, optional_finish_reason)` pairs on success.
pub fn try_stream_inference(
    state: &BancoState,
    request: &BancoChatRequest,
) -> Option<Vec<(String, Option<String>)>> {
    // Same preconditions as `try_inference`: loaded model and non-empty vocabulary.
    let model = state.model.quantized_model()?;
    let vocab = state.model.vocabulary();
    if vocab.is_empty() {
        return None;
    }

    // Render the chat messages through the prompt template, then tokenize.
    let formatted = state.template_engine.apply(&request.messages);
    // Use the proper BPE tokenizer when available, else the greedy fallback.
    let prompt_tokens = state.model.encode_text(&formatted);
    if prompt_tokens.is_empty() {
        return None;
    }

    // Same parameter resolution as `try_inference`: a request temperature equal
    // to 0.7 (presumably the serialized default) defers to the server setting.
    let server_params = state.inference_params.read().ok()?;
    let params = super::inference::SamplingParams {
        temperature: if (request.temperature - 0.7).abs() < f32::EPSILON {
            server_params.temperature
        } else {
            request.temperature
        },
        top_k: server_params.top_k,
        max_tokens: request.max_tokens,
    };
    drop(server_params);

    match super::inference::generate_stream_tokens(&model, &vocab, &prompt_tokens, &params) {
        Ok(stream_tokens) => {
            let result: Vec<(String, Option<String>)> =
                stream_tokens.into_iter().map(|st| (st.text, st.finish_reason)).collect();
            Some(result)
        }
        Err(e) => {
            eprintln!("[banco] stream inference error: {e}");
            None
        }
    }
}
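
// A hypothetical consumer of the streamed tokens, e.g. for SSE chunking; the
// `emit_chunk` helper is illustrative, not an actual banco API:
//
//     if let Some(tokens) = try_stream_inference(&state, &request) {
//         for (text, finish_reason) in tokens {
//             // `finish_reason` is presumably `Some(..)` only on the final entry.
//             emit_chunk(text, finish_reason);
//         }
//     }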