batuta/serve/banco/handlers_inference.rs

use super::state::BancoState;
use super::types::BancoChatRequest;

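/// Runs a blocking chat completion against the loaded quantized model.
/// Returns `(text, finish_reason, token_count)` on success, or `None` when no
/// model is loaded, the vocabulary or prompt is empty, or inference fails.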
pub fn try_inference(
    state: &BancoState,
    request: &BancoChatRequest,
) -> Option<(String, String, u32)> {
    // Bail out early if no quantized model or vocabulary is available.
    let model = state.model.quantized_model()?;
    let vocab = state.model.vocabulary();
    if vocab.is_empty() {
        return None;
    }

    // Render the chat messages through the prompt template and tokenize them.
    let formatted = state.template_engine.apply(&request.messages);
    let prompt_tokens = state.model.encode_text(&formatted);
    if prompt_tokens.is_empty() {
        return None;
    }

    let server_params = state.inference_params.read().ok()?;
    let params = super::inference::SamplingParams {
        // A request temperature at the 0.7 default is treated as "unset", so the
        // server-configured temperature wins; any other value overrides it.
        temperature: if (request.temperature - 0.7).abs() < f32::EPSILON {
            server_params.temperature
        } else {
            request.temperature
        },
        top_k: server_params.top_k,
        max_tokens: request.max_tokens,
    };
    // Release the read lock before generation starts.
    drop(server_params);

    match super::inference::generate_sync(&model, &vocab, &prompt_tokens, &params) {
        Ok(result) => Some((result.text, result.finish_reason, result.token_count)),
        Err(e) => {
            eprintln!("[banco] inference error: {e}");
            None
        }
    }
}

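/// Streaming variant of [`try_inference`]: generates the full response and
/// returns the tokens as `(text, finish_reason)` pairs for the caller to emit
/// as stream chunks. Returns `None` under the same conditions as
/// [`try_inference`].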
pub fn try_stream_inference(
    state: &BancoState,
    request: &BancoChatRequest,
) -> Option<Vec<(String, Option<String>)>> {
    // Same preconditions as `try_inference`: a loaded quantized model, a
    // non-empty vocabulary, and a non-empty tokenized prompt.
    let model = state.model.quantized_model()?;
    let vocab = state.model.vocabulary();
    if vocab.is_empty() {
        return None;
    }

    let formatted = state.template_engine.apply(&request.messages);
    let prompt_tokens = state.model.encode_text(&formatted);
    if prompt_tokens.is_empty() {
        return None;
    }

    let server_params = state.inference_params.read().ok()?;
    let params = super::inference::SamplingParams {
        // As above: a request temperature at the 0.7 default defers to the
        // server-configured value.
        temperature: if (request.temperature - 0.7).abs() < f32::EPSILON {
            server_params.temperature
        } else {
            request.temperature
        },
        top_k: server_params.top_k,
        max_tokens: request.max_tokens,
    };
    drop(server_params);

    match super::inference::generate_stream_tokens(&model, &vocab, &prompt_tokens, &params) {
        Ok(stream_tokens) => {
            let result: Vec<(String, Option<String>)> = stream_tokens
                .into_iter()
                .map(|st| (st.text, st.finish_reason))
                .collect();
            Some(result)
        }
        Err(e) => {
            eprintln!("[banco] stream inference error: {e}");
            None
        }
    }
}