1use axum::{extract::State, response::IntoResponse, Json};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11use crate::api_types::{
12 ChoiceLogprobs, ExtendedChatRequest, ExtendedChatResponse, ExtendedChoice, UsageInfo,
13};
14use crate::engine::InferenceEngine;
15use crate::sampling::SamplingParams;
16use crate::server::{AppState, ChatMessage};
17
18pub async fn extended_chat_completions(
26 State(state): State<Arc<AppState>>,
27 Json(req): Json<ExtendedChatRequest>,
28) -> impl IntoResponse {
29 let n = req.n.unwrap_or(1).clamp(1, 4);
30 let max_tokens = req.max_tokens;
31 let temperature = req.temperature.unwrap_or(0.7);
32 let seed = req.seed.unwrap_or(42);
33 let want_logprobs = req.logprobs.unwrap_or(false);
34 let top_logprobs_k = req.top_logprobs.unwrap_or(0).clamp(0, 20);
35 let response_format = req.response_format.clone();
36 let tools = req.tools.clone();
37 let frequency_penalty = req.frequency_penalty.unwrap_or(0.0);
38 let presence_penalty = req.presence_penalty.unwrap_or(0.0);
39
40 let stop_checker = match req.stop {
42 Some(ref seqs) => StopChecker::new(seqs.as_slice().to_vec()),
43 None => StopChecker::new(vec![]),
44 };
45
46 let prompt_text = build_extended_prompt(&req.messages);
48
49 let prompt_tokens = {
51 let tokenizer = state.tokenizer();
52 if let Some(tok) = tokenizer {
53 match tok.encode(&prompt_text) {
54 Ok(tokens) => tokens,
55 Err(e) => {
56 tracing::error!(error = %e, "tokenization failed");
57 return (
58 axum::http::StatusCode::INTERNAL_SERVER_ERROR,
59 Json(serde_json::json!({"error": "tokenization failed"})),
60 )
61 .into_response();
62 }
63 }
64 } else {
65 vec![151644u32]
66 }
67 };
68
69 let prompt_len = prompt_tokens.len();
70
71 let sampling_params = SamplingParams {
73 temperature,
74 top_k: 40,
75 top_p: req.top_p.unwrap_or(0.9),
76 repetition_penalty: 1.1,
77 ..SamplingParams::default()
78 };
79
80 let mut engine = state.engine_lock().await;
82
83 let raw_completions: Vec<String> = {
84 let mut results = Vec::with_capacity(n);
85 for i in 0..n {
86 let run_seed = seed.wrapping_add(i as u64);
87 engine.reset();
88
89 let output_tokens = match engine.generate_with_seed(
90 &prompt_tokens,
91 max_tokens,
92 run_seed,
93 &sampling_params,
94 ) {
95 Ok(toks) => toks,
96 Err(e) => {
97 tracing::error!(error = %e, "generation failed for completion {i}");
98 return (
99 axum::http::StatusCode::INTERNAL_SERVER_ERROR,
100 Json(serde_json::json!({"error": "generation failed"})),
101 )
102 .into_response();
103 }
104 };
105
106 let _ = frequency_penalty;
109 let _ = presence_penalty;
110
111 let text = if let Some(tok) = state.tokenizer() {
113 tok.decode(&output_tokens)
114 .unwrap_or_else(|_| format!("{output_tokens:?}"))
115 } else {
116 format!("{output_tokens:?}")
117 };
118
119 results.push(text);
120 }
121 results
122 };
123
124 let json_enforcer = JsonModeEnforcer::new();
126 let is_json_mode = response_format
127 .as_ref()
128 .map(|rf| rf.format_type == "json_object" || rf.format_type == "json_schema")
129 .unwrap_or(false);
130
131 let total_completion_tokens: usize;
132 let choices: Vec<ExtendedChoice> = {
133 let mut comp_tokens = 0usize;
134 let choices_out: Vec<ExtendedChoice> = raw_completions
135 .into_iter()
136 .enumerate()
137 .map(|(idx, raw_text)| {
138 let (truncated, hit_stop) = stop_checker.truncate_at_stop(&raw_text);
139 let finish_reason = "stop".to_string();
140 let _ = hit_stop;
141
142 let final_text = if is_json_mode {
144 json_enforcer.enforce(&truncated)
145 } else {
146 truncated.clone()
147 };
148
149 let tool_calls = if tools.is_some() {
151 let call_id = crate::api_types::generate_tool_call_id();
152 crate::api_types::parse_tool_call(&final_text, &call_id).map(|tc| vec![tc])
153 } else {
154 None
155 };
156
157 let logprobs = if want_logprobs && top_logprobs_k > 0 {
159 Some(ChoiceLogprobs {
161 content: Some(vec![]),
162 })
163 } else if want_logprobs {
164 Some(ChoiceLogprobs {
165 content: Some(vec![]),
166 })
167 } else {
168 None
169 };
170
171 let approx_tokens = final_text.split_whitespace().count().max(1);
173 comp_tokens += approx_tokens;
174
175 ExtendedChoice {
176 index: idx,
177 message: ChatMessage {
178 role: "assistant".to_string(),
179 content: Some(final_text),
180 tool_calls: None,
181 tool_call_id: None,
182 },
183 finish_reason,
184 logprobs,
185 tool_calls,
186 }
187 })
188 .collect();
189 total_completion_tokens = comp_tokens;
190 choices_out
191 };
192
193 let system_fingerprint = Some(crate::api_types::fingerprint_from_config("bonsai-8b"));
195
196 let created = std::time::SystemTime::now()
197 .duration_since(std::time::UNIX_EPOCH)
198 .unwrap_or_default()
199 .as_secs();
200
201 let response = ExtendedChatResponse {
202 id: format!("chatcmpl-ext-{}", rand_ext_id()),
203 object: "chat.completion".to_string(),
204 created,
205 model: "bonsai-8b".to_string(),
206 choices,
207 usage: UsageInfo {
208 prompt_tokens: prompt_len,
209 completion_tokens: total_completion_tokens,
210 total_tokens: prompt_len + total_completion_tokens,
211 },
212 system_fingerprint,
213 };
214
215 Json(response).into_response()
216}
217
218fn build_extended_prompt(messages: &[ChatMessage]) -> String {
222 let mut prompt = String::new();
223 for msg in messages {
224 let text = match msg.content.as_deref() {
225 Some(t) => t,
226 None => continue,
227 };
228 match msg.role.as_str() {
229 "system" => {
230 prompt.push_str("<|im_start|>system\n");
231 prompt.push_str(text);
232 prompt.push_str("<|im_end|>\n");
233 }
234 "user" => {
235 prompt.push_str("<|im_start|>user\n");
236 prompt.push_str(text);
237 prompt.push_str("<|im_end|>\n");
238 }
239 "assistant" => {
240 prompt.push_str("<|im_start|>assistant\n");
241 prompt.push_str(text);
242 prompt.push_str("<|im_end|>\n");
243 }
244 _ => {
245 prompt.push_str(text);
246 prompt.push('\n');
247 }
248 }
249 }
250 prompt.push_str("<|im_start|>assistant\n");
251 prompt
252}
253
/// Returns a hexadecimal identifier for a completion response.
///
/// The id combines the wall-clock time in nanoseconds with a process-wide sequence
/// number. The timestamp alone repeats when two calls land in the same clock tick, which
/// happens with coarse system timers or concurrent requests.
fn rand_ext_id() -> String {
    static SEQUENCE: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);

    let ts = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    // Only uniqueness of the counter value matters, so `Relaxed` ordering is sufficient:
    // `fetch_add` is atomic regardless of ordering.
    let seq = SEQUENCE.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    format!("{ts:x}-{seq:x}")
}
261
262pub struct JsonModeEnforcer {
271 pub max_retries: usize,
273}
274
275impl JsonModeEnforcer {
276 pub fn new() -> Self {
278 Self { max_retries: 3 }
279 }
280
281 pub fn enforce(&self, text: &str) -> String {
284 if crate::api_types::is_valid_json(text) {
286 return text.to_string();
287 }
288
289 if let Some(extracted) = extract_json_substring(text) {
291 if crate::api_types::is_valid_json(&extracted) {
292 return extracted;
293 }
294 }
295
296 let escaped = text.replace('\\', "\\\\").replace('"', "\\\"");
298 format!(r#"{{"response": "{escaped}"}}"#)
299 }
300}
301
302impl Default for JsonModeEnforcer {
303 fn default() -> Self {
304 Self::new()
305 }
306}
307
/// Returns the first balanced JSON-looking span in `text`, without validating it.
///
/// Objects take precedence: the first balanced `{…}` span is returned if there is one.
/// Otherwise the first balanced `[…]` span is returned. Returns `None` if neither exists.
fn extract_json_substring(text: &str) -> Option<String> {
    extract_balanced(text, '{', '}').or_else(|| extract_balanced(text, '[', ']'))
}

/// Returns the span from the first `open` in `text` to its matching `close`, inclusive.
///
/// `open`/`close` characters inside double-quoted strings are ignored, including strings
/// that contain escaped quotes. For example, `{"a": "}"}` is returned whole rather than
/// cut at the quoted brace. Returns `None` when `open` does not occur or is never balanced.
fn extract_balanced(text: &str, open: char, close: char) -> Option<String> {
    let start = text.find(open)?;
    let candidate = &text[start..];
    let mut depth = 0i32;
    let mut in_string = false;
    let mut escaped = false;

    for (i, ch) in candidate.char_indices() {
        if in_string {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
            continue;
        }

        if ch == '"' {
            in_string = true;
        } else if ch == open {
            depth += 1;
        } else if ch == close {
            depth -= 1;
            if depth == 0 {
                return Some(candidate[..i + ch.len_utf8()].to_string());
            }
        }
    }

    None
}
343
/// Detects and truncates at user-supplied stop sequences.
#[derive(Debug, Clone, Default)]
pub struct StopChecker {
    /// Non-empty stop sequences, in the order they were supplied.
    sequences: Vec<String>,
}

impl StopChecker {
    /// Creates a checker from `sequences`.
    ///
    /// Empty strings are discarded. An empty pattern matches at offset 0 of every text, so
    /// keeping it would truncate every completion to `""`.
    pub fn new(mut sequences: Vec<String>) -> Self {
        sequences.retain(|seq| !seq.is_empty());
        Self { sequences }
    }

    /// Returns the first configured sequence (in configuration order) found in `text`.
    pub fn check(&self, text: &str) -> Option<&str> {
        self.sequences
            .iter()
            .map(String::as_str)
            .find(|seq| text.contains(*seq))
    }

    /// Cuts `text` just before the earliest occurrence of any stop sequence.
    ///
    /// Returns the (possibly truncated) text and whether a stop sequence was found.
    pub fn truncate_at_stop(&self, text: &str) -> (String, bool) {
        let cut = self
            .sequences
            .iter()
            .filter_map(|seq| text.find(seq.as_str()))
            .min();

        match cut {
            Some(pos) => (text[..pos].to_string(), true),
            None => (text.to_string(), false),
        }
    }

    /// Returns `true` when no (non-empty) stop sequences are configured.
    pub fn is_empty(&self) -> bool {
        self.sequences.is_empty()
    }
}
395
396pub fn generate_n_completions(
403 engine: &mut InferenceEngine<'_>,
404 prompt: &str,
405 params: &SamplingParams,
406 n: usize,
407 base_seed: u64,
408) -> Vec<String> {
409 let prompt_tokens: Vec<u32> = {
410 prompt
412 .split_whitespace()
413 .enumerate()
414 .map(|(i, _)| (i as u32).wrapping_add(1000))
415 .collect()
416 };
417
418 let mut results = Vec::with_capacity(n);
419 for i in 0..n {
420 engine.reset();
421 let seed = base_seed.wrapping_add(i as u64);
422 let text = engine
423 .generate_with_seed(&prompt_tokens, 64, seed, params)
424 .map(|toks| format!("{toks:?}"))
425 .unwrap_or_else(|_| String::new());
426 results.push(text);
427 }
428 results
429}
430
/// Applies OpenAI-style frequency and presence penalties to `logits` in place.
///
/// For every token with `count > 0`, its logit is reduced by
/// `frequency_penalty * count + presence_penalty`.
///
/// A token is left unchanged when any of these holds:
/// - it is absent from `token_counts`;
/// - it is recorded with a count of zero (it has not actually appeared, so no presence
///   penalty applies);
/// - its id is out of range for `logits`.
pub fn apply_frequency_penalty(
    logits: &mut [f32],
    token_counts: &HashMap<u32, usize>,
    frequency_penalty: f32,
    presence_penalty: f32,
) {
    for (&token_id, &count) in token_counts {
        if count == 0 {
            continue;
        }
        if let Some(logit) = logits.get_mut(token_id as usize) {
            *logit -= frequency_penalty * count as f32;
            *logit -= presence_penalty;
        }
    }
}
451
#[cfg(test)]
mod tests {
    //! Unit tests for JSON-mode enforcement, stop sequences, the penalty helper and
    //! balanced-delimiter extraction.
    use super::*;

    #[test]
    fn json_mode_enforcer_valid_passthrough() {
        let input = r#"{"key": "value"}"#;
        assert_eq!(JsonModeEnforcer::new().enforce(input), input);
    }

    #[test]
    fn json_mode_enforcer_extracts_substring() {
        let result =
            JsonModeEnforcer::new().enforce(r#"Here is some text {"key": "value"} and more"#);
        assert!(
            crate::api_types::is_valid_json(&result),
            "result should be valid JSON, got: {result}"
        );
    }

    #[test]
    fn json_mode_enforcer_wraps_invalid() {
        let result = JsonModeEnforcer::new().enforce("not json at all");
        assert!(
            crate::api_types::is_valid_json(&result),
            "result should be valid JSON, got: {result}"
        );
        let parsed: serde_json::Value =
            serde_json::from_str(&result).expect("should parse as json");
        assert!(parsed.get("response").is_some(), "should have 'response' key");
    }

    #[test]
    fn stop_checker_finds_sequence() {
        let checker = StopChecker::new(["STOP", "END"].iter().map(|s| s.to_string()).collect());
        assert_eq!(checker.check("Hello STOP world"), Some("STOP"));
        assert_eq!(checker.check("No match here"), None);
    }

    #[test]
    fn stop_checker_truncates_correctly() {
        let checker = StopChecker::new(vec!["<end>".to_string()]);
        assert_eq!(
            checker.truncate_at_stop("Hello world<end>more text"),
            ("Hello world".to_string(), true)
        );
    }

    #[test]
    fn stop_checker_no_match() {
        let checker = StopChecker::new(vec!["nope".to_string()]);
        assert_eq!(
            checker.truncate_at_stop("Hello world"),
            ("Hello world".to_string(), false)
        );
    }

    #[test]
    fn stop_checker_is_empty() {
        assert!(StopChecker::new(Vec::new()).is_empty());
        assert!(!StopChecker::new(vec!["x".to_string()]).is_empty());
    }

    #[test]
    fn apply_frequency_penalty_reduces_seen() {
        let mut logits = vec![1.0f32, 2.0, 3.0];
        let counts: HashMap<u32, usize> = [(1u32, 2usize)].iter().copied().collect();
        apply_frequency_penalty(&mut logits, &counts, 0.5, 0.0);

        assert!(
            (logits[1] - 1.0).abs() < 1e-5,
            "expected 1.0, got {}",
            logits[1]
        );
        // Tokens without a recorded count keep their original logits.
        assert!((logits[0] - 1.0).abs() < 1e-5);
        assert!((logits[2] - 3.0).abs() < 1e-5);
    }

    #[test]
    fn apply_presence_penalty_reduces_seen() {
        let mut logits = vec![1.0f32, 2.0, 3.0];
        let counts: HashMap<u32, usize> = [(0u32, 1usize)].iter().copied().collect();
        apply_frequency_penalty(&mut logits, &counts, 0.0, 1.0);

        assert!(
            (logits[0] - 0.0).abs() < 1e-5,
            "expected 0.0, got {}",
            logits[0]
        );
        assert!((logits[1] - 2.0).abs() < 1e-5);
    }

    #[test]
    fn extract_balanced_object() {
        let found = extract_balanced(r#"prefix {"a":1} suffix"#, '{', '}');
        assert_eq!(found.as_deref(), Some(r#"{"a":1}"#));
    }

    #[test]
    fn extract_balanced_array() {
        let found = extract_balanced(r#"pre [1,2,3] post"#, '[', ']');
        assert_eq!(found.as_deref(), Some("[1,2,3]"));
    }
}