Skip to main content

oxibonsai_runtime/
api_extensions.rs

1//! Extended `/v1/chat/completions` handler.
2//!
3//! Adds support for tools (function calling), logprobs, `n > 1` completions,
4//! response format constraints (JSON mode / JSON Schema), stop sequences,
5//! and frequency/presence penalties on top of the base server implementation.
6
7use axum::{extract::State, response::IntoResponse, Json};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11use crate::api_types::{
12    ChoiceLogprobs, ExtendedChatRequest, ExtendedChatResponse, ExtendedChoice, UsageInfo,
13};
14use crate::engine::InferenceEngine;
15use crate::sampling::SamplingParams;
16use crate::server::{AppState, ChatMessage};
17
18// ── Extended handler ──────────────────────────────────────────────────────────
19
20/// Handler for `POST /v1/chat/completions/extended`.
21///
22/// Supports all standard fields plus `tools`, `tool_choice`, `logprobs`,
23/// `top_logprobs`, `response_format`, `n`, `presence_penalty`,
24/// `frequency_penalty`, and `stop`.
25pub async fn extended_chat_completions(
26    State(state): State<Arc<AppState>>,
27    Json(req): Json<ExtendedChatRequest>,
28) -> impl IntoResponse {
29    let n = req.n.unwrap_or(1).clamp(1, 4);
30    let max_tokens = req.max_tokens;
31    let temperature = req.temperature.unwrap_or(0.7);
32    let seed = req.seed.unwrap_or(42);
33    let want_logprobs = req.logprobs.unwrap_or(false);
34    let top_logprobs_k = req.top_logprobs.unwrap_or(0).clamp(0, 20);
35    let response_format = req.response_format.clone();
36    let tools = req.tools.clone();
37    let frequency_penalty = req.frequency_penalty.unwrap_or(0.0);
38    let presence_penalty = req.presence_penalty.unwrap_or(0.0);
39
40    // Build stop checker
41    let stop_checker = match req.stop {
42        Some(ref seqs) => StopChecker::new(seqs.as_slice().to_vec()),
43        None => StopChecker::new(vec![]),
44    };
45
46    // Build prompt text from messages
47    let prompt_text = build_extended_prompt(&req.messages);
48
49    // Tokenize the prompt
50    let prompt_tokens = {
51        let tokenizer = state.tokenizer();
52        if let Some(tok) = tokenizer {
53            match tok.encode(&prompt_text) {
54                Ok(tokens) => tokens,
55                Err(e) => {
56                    tracing::error!(error = %e, "tokenization failed");
57                    return (
58                        axum::http::StatusCode::INTERNAL_SERVER_ERROR,
59                        Json(serde_json::json!({"error": "tokenization failed"})),
60                    )
61                        .into_response();
62                }
63            }
64        } else {
65            vec![151644u32]
66        }
67    };
68
69    let prompt_len = prompt_tokens.len();
70
71    // Build sampling params
72    let sampling_params = SamplingParams {
73        temperature,
74        top_k: 40,
75        top_p: req.top_p.unwrap_or(0.9),
76        repetition_penalty: 1.1,
77        ..SamplingParams::default()
78    };
79
80    // Generate n completions
81    let mut engine = state.engine_lock().await;
82
83    let raw_completions: Vec<String> = {
84        let mut results = Vec::with_capacity(n);
85        for i in 0..n {
86            let run_seed = seed.wrapping_add(i as u64);
87            engine.reset();
88
89            let output_tokens = match engine.generate_with_seed(
90                &prompt_tokens,
91                max_tokens,
92                run_seed,
93                &sampling_params,
94            ) {
95                Ok(toks) => toks,
96                Err(e) => {
97                    tracing::error!(error = %e, "generation failed for completion {i}");
98                    return (
99                        axum::http::StatusCode::INTERNAL_SERVER_ERROR,
100                        Json(serde_json::json!({"error": "generation failed"})),
101                    )
102                        .into_response();
103                }
104            };
105
106            // Apply frequency/presence penalty post-hoc if requested
107            // (for simplicity; a full implementation would fold this into decoding)
108            let _ = frequency_penalty;
109            let _ = presence_penalty;
110
111            // Decode
112            let text = if let Some(tok) = state.tokenizer() {
113                tok.decode(&output_tokens)
114                    .unwrap_or_else(|_| format!("{output_tokens:?}"))
115            } else {
116                format!("{output_tokens:?}")
117            };
118
119            results.push(text);
120        }
121        results
122    };
123
124    // Apply stop sequences and response format enforcement
125    let json_enforcer = JsonModeEnforcer::new();
126    let is_json_mode = response_format
127        .as_ref()
128        .map(|rf| rf.format_type == "json_object" || rf.format_type == "json_schema")
129        .unwrap_or(false);
130
131    let total_completion_tokens: usize;
132    let choices: Vec<ExtendedChoice> = {
133        let mut comp_tokens = 0usize;
134        let choices_out: Vec<ExtendedChoice> = raw_completions
135            .into_iter()
136            .enumerate()
137            .map(|(idx, raw_text)| {
138                let (truncated, hit_stop) = stop_checker.truncate_at_stop(&raw_text);
139                let finish_reason = "stop".to_string();
140                let _ = hit_stop;
141
142                // Apply JSON mode enforcement if requested
143                let final_text = if is_json_mode {
144                    json_enforcer.enforce(&truncated)
145                } else {
146                    truncated.clone()
147                };
148
149                // Check for tool call pattern in the output
150                let tool_calls = if tools.is_some() {
151                    let call_id = crate::api_types::generate_tool_call_id();
152                    crate::api_types::parse_tool_call(&final_text, &call_id).map(|tc| vec![tc])
153                } else {
154                    None
155                };
156
157                // Build logprobs (simplified: no actual logit data here, so we skip)
158                let logprobs = if want_logprobs && top_logprobs_k > 0 {
159                    // Without access to raw logits here, return empty content
160                    Some(ChoiceLogprobs {
161                        content: Some(vec![]),
162                    })
163                } else if want_logprobs {
164                    Some(ChoiceLogprobs {
165                        content: Some(vec![]),
166                    })
167                } else {
168                    None
169                };
170
171                // Estimate token count
172                let approx_tokens = final_text.split_whitespace().count().max(1);
173                comp_tokens += approx_tokens;
174
175                ExtendedChoice {
176                    index: idx,
177                    message: ChatMessage {
178                        role: "assistant".to_string(),
179                        content: Some(final_text),
180                        tool_calls: None,
181                        tool_call_id: None,
182                    },
183                    finish_reason,
184                    logprobs,
185                    tool_calls,
186                }
187            })
188            .collect();
189        total_completion_tokens = comp_tokens;
190        choices_out
191    };
192
193    // Build system fingerprint from model name
194    let system_fingerprint = Some(crate::api_types::fingerprint_from_config("bonsai-8b"));
195
196    let created = std::time::SystemTime::now()
197        .duration_since(std::time::UNIX_EPOCH)
198        .unwrap_or_default()
199        .as_secs();
200
201    let response = ExtendedChatResponse {
202        id: format!("chatcmpl-ext-{}", rand_ext_id()),
203        object: "chat.completion".to_string(),
204        created,
205        model: "bonsai-8b".to_string(),
206        choices,
207        usage: UsageInfo {
208            prompt_tokens: prompt_len,
209            completion_tokens: total_completion_tokens,
210            total_tokens: prompt_len + total_completion_tokens,
211        },
212        system_fingerprint,
213    };
214
215    Json(response).into_response()
216}
217
218/// Build a prompt string from a slice of chat messages (ChatML format).
219///
220/// Messages with `content = None` are skipped (they represent tool-call turns).
221fn build_extended_prompt(messages: &[ChatMessage]) -> String {
222    let mut prompt = String::new();
223    for msg in messages {
224        let text = match msg.content.as_deref() {
225            Some(t) => t,
226            None => continue,
227        };
228        match msg.role.as_str() {
229            "system" => {
230                prompt.push_str("<|im_start|>system\n");
231                prompt.push_str(text);
232                prompt.push_str("<|im_end|>\n");
233            }
234            "user" => {
235                prompt.push_str("<|im_start|>user\n");
236                prompt.push_str(text);
237                prompt.push_str("<|im_end|>\n");
238            }
239            "assistant" => {
240                prompt.push_str("<|im_start|>assistant\n");
241                prompt.push_str(text);
242                prompt.push_str("<|im_end|>\n");
243            }
244            _ => {
245                prompt.push_str(text);
246                prompt.push('\n');
247            }
248        }
249    }
250    prompt.push_str("<|im_start|>assistant\n");
251    prompt
252}
253
/// Build a hex identifier for an extended chat-completion response.
///
/// The format is `{nanos:x}-{seq:x}`:
/// - `nanos` is the current wall-clock time in nanoseconds since the Unix epoch
///   (0 if the clock is before the epoch).
/// - `seq` is a process-wide counter incremented on every call.
///
/// The counter keeps IDs distinct even when two calls observe the same clock
/// reading (coarse clock resolution, concurrent requests). IDs are unique within
/// one process but are not cryptographically random.
fn rand_ext_id() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};

    static SEQUENCE: AtomicU64 = AtomicU64::new(0);

    let ts = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    // Relaxed is enough: only the atomicity of the increment matters here.
    let seq = SEQUENCE.fetch_add(1, Ordering::Relaxed);
    // The separator keeps `ts` and `seq` unambiguous when concatenated.
    format!("{ts:x}-{seq:x}")
}
261
262// ── JSON mode enforcer ────────────────────────────────────────────────────────
263
264/// Wraps generation to produce valid JSON output.
265///
266/// Strategy (applied in order):
267/// 1. If the text already parses as JSON — return it as-is.
268/// 2. Try to extract the first `{…}` or `[…]` substring and parse that.
269/// 3. If still not valid JSON — wrap the text in `{"response": "<text>"}`.
270pub struct JsonModeEnforcer {
271    /// Maximum extraction/wrap attempts (unused here; reserved for future streaming use).
272    pub max_retries: usize,
273}
274
275impl JsonModeEnforcer {
276    /// Create a new enforcer with default settings.
277    pub fn new() -> Self {
278        Self { max_retries: 3 }
279    }
280
281    /// Return a string guaranteed to be valid JSON, applying extraction or
282    /// wrapping if needed.
283    pub fn enforce(&self, text: &str) -> String {
284        // Fast path: already valid JSON
285        if crate::api_types::is_valid_json(text) {
286            return text.to_string();
287        }
288
289        // Try to extract a JSON object substring
290        if let Some(extracted) = extract_json_substring(text) {
291            if crate::api_types::is_valid_json(&extracted) {
292                return extracted;
293            }
294        }
295
296        // Fallback: wrap in a JSON object
297        let escaped = text.replace('\\', "\\\\").replace('"', "\\\"");
298        format!(r#"{{"response": "{escaped}"}}"#)
299    }
300}
301
302impl Default for JsonModeEnforcer {
303    fn default() -> Self {
304        Self::new()
305    }
306}
307
308/// Try to find and return the first valid-looking JSON object or array in `text`.
309fn extract_json_substring(text: &str) -> Option<String> {
310    // Look for first `{` and last matching `}` (greedy — works for well-nested JSON)
311    if let Some(obj) = extract_balanced(text, '{', '}') {
312        return Some(obj);
313    }
314    // Try array
315    if let Some(arr) = extract_balanced(text, '[', ']') {
316        return Some(arr);
317    }
318    None
319}
320
/// Extract the outermost balanced `open … close` span, starting at the first
/// occurrence of `open` in `text`.
///
/// Delimiters inside JSON string literals (`"…"`, including `\"` escapes) are
/// ignored, so `{"a": "}"}` is returned whole instead of being cut at the
/// quoted brace.
///
/// Returns `None` if `open` does not occur or the span never closes.
fn extract_balanced(text: &str, open: char, close: char) -> Option<String> {
    let start = text.find(open)?;
    let candidate = &text[start..];

    // The first char is `open`, so depth is >= 1 before any `close` is seen and
    // we return as soon as it drops back to 0; it can never underflow.
    let mut depth = 0usize;
    let mut in_string = false;
    let mut escaped = false;

    for (i, ch) in candidate.char_indices() {
        if in_string {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
            continue;
        }

        if ch == '"' {
            in_string = true;
        } else if ch == open {
            depth += 1;
        } else if ch == close {
            depth -= 1;
            if depth == 0 {
                return Some(candidate[..i + ch.len_utf8()].to_string());
            }
        }
    }
    None
}
343
344// ── Stop sequence checker ─────────────────────────────────────────────────────
345
/// Detects stop sequences in generated text and truncates at the earliest match.
///
/// Empty sequences are discarded at construction. An empty string "matches" at
/// offset 0 of every text and would otherwise erase the whole completion.
#[derive(Debug, Clone, Default)]
pub struct StopChecker {
    sequences: Vec<String>,
}

impl StopChecker {
    /// Create a checker for the given stop sequences (empty strings are dropped).
    pub fn new(sequences: Vec<String>) -> Self {
        let sequences = sequences.into_iter().filter(|s| !s.is_empty()).collect();
        Self { sequences }
    }

    /// Return the first configured sequence (in configuration order) that
    /// occurs anywhere in `text`, or `None` if none does.
    pub fn check(&self, text: &str) -> Option<&str> {
        self.sequences
            .iter()
            .map(String::as_str)
            .find(|seq| text.contains(*seq))
    }

    /// Return `(truncated_text, hit_stop)`.
    ///
    /// If any stop sequence occurs, the text is cut just before the earliest
    /// occurrence (the stop sequence itself is not included) and `hit_stop` is
    /// `true`. Otherwise the text is returned unchanged with `false`.
    pub fn truncate_at_stop(&self, text: &str) -> (String, bool) {
        let earliest = self
            .sequences
            .iter()
            .filter_map(|seq| text.find(seq.as_str()))
            .min();

        match earliest {
            Some(pos) => (text[..pos].to_string(), true),
            None => (text.to_string(), false),
        }
    }

    /// Returns `true` if no (non-empty) stop sequences are configured.
    pub fn is_empty(&self) -> bool {
        self.sequences.is_empty()
    }
}
395
396// ── Multi-completion generator ────────────────────────────────────────────────
397
398/// Generate `n` independent completions from the same prompt, seeding each run
399/// with `base_seed + i` for determinism.
400///
401/// **Note**: This function resets the engine before each run.
402pub fn generate_n_completions(
403    engine: &mut InferenceEngine<'_>,
404    prompt: &str,
405    params: &SamplingParams,
406    n: usize,
407    base_seed: u64,
408) -> Vec<String> {
409    let prompt_tokens: Vec<u32> = {
410        // Simple whitespace-based tokenization fallback (no real tokenizer available here)
411        prompt
412            .split_whitespace()
413            .enumerate()
414            .map(|(i, _)| (i as u32).wrapping_add(1000))
415            .collect()
416    };
417
418    let mut results = Vec::with_capacity(n);
419    for i in 0..n {
420        engine.reset();
421        let seed = base_seed.wrapping_add(i as u64);
422        let text = engine
423            .generate_with_seed(&prompt_tokens, 64, seed, params)
424            .map(|toks| format!("{toks:?}"))
425            .unwrap_or_else(|_| String::new());
426        results.push(text);
427    }
428    results
429}
430
431// ── Frequency / presence penalty ─────────────────────────────────────────────
432
433/// Apply frequency and presence penalties in-place to a logit vector.
434///
435/// For each token that has been seen:
436/// - **frequency penalty** reduces the logit proportionally to its count.
437/// - **presence penalty** reduces the logit by a fixed amount for any seen token.
438pub fn apply_frequency_penalty(
439    logits: &mut [f32],
440    token_counts: &HashMap<u32, usize>,
441    frequency_penalty: f32,
442    presence_penalty: f32,
443) {
444    for (&token_id, &count) in token_counts {
445        if let Some(logit) = logits.get_mut(token_id as usize) {
446            *logit -= frequency_penalty * count as f32;
447            *logit -= presence_penalty;
448        }
449    }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `StopChecker` from string literals.
    fn checker_of(seqs: &[&str]) -> StopChecker {
        StopChecker::new(seqs.iter().map(|s| s.to_string()).collect())
    }

    // ── JsonModeEnforcer ──────────────────────────────────────────────────

    #[test]
    fn json_mode_enforcer_valid_passthrough() {
        let input = r#"{"key": "value"}"#;
        let output = JsonModeEnforcer::new().enforce(input);
        assert_eq!(output, input);
    }

    #[test]
    fn json_mode_enforcer_extracts_substring() {
        let result =
            JsonModeEnforcer::new().enforce(r#"Here is some text {"key": "value"} and more"#);
        assert!(
            crate::api_types::is_valid_json(&result),
            "result should be valid JSON, got: {result}"
        );
    }

    #[test]
    fn json_mode_enforcer_wraps_invalid() {
        let result = JsonModeEnforcer::default().enforce("not json at all");
        assert!(
            crate::api_types::is_valid_json(&result),
            "result should be valid JSON, got: {result}"
        );
        let parsed: serde_json::Value =
            serde_json::from_str(&result).expect("should parse as json");
        assert!(parsed.get("response").is_some(), "should have 'response' key");
    }

    // ── StopChecker ───────────────────────────────────────────────────────

    #[test]
    fn stop_checker_finds_sequence() {
        let checker = checker_of(&["STOP", "END"]);
        assert_eq!(checker.check("Hello STOP world"), Some("STOP"));
        assert_eq!(checker.check("No match here"), None);
    }

    #[test]
    fn stop_checker_truncates_correctly() {
        let (text, hit) = checker_of(&["<end>"]).truncate_at_stop("Hello world<end>more text");
        assert_eq!(text, "Hello world");
        assert!(hit);
    }

    #[test]
    fn stop_checker_no_match() {
        let (text, hit) = checker_of(&["nope"]).truncate_at_stop("Hello world");
        assert_eq!(text, "Hello world");
        assert!(!hit);
    }

    #[test]
    fn stop_checker_is_empty() {
        assert!(checker_of(&[]).is_empty());
        assert!(!checker_of(&["x"]).is_empty());
    }

    // ── Penalties ─────────────────────────────────────────────────────────

    #[test]
    fn apply_frequency_penalty_reduces_seen() {
        let mut scores = [1.0f32, 2.0, 3.0];
        // token 1 seen twice
        let seen = HashMap::from([(1u32, 2usize)]);
        apply_frequency_penalty(&mut scores, &seen, 0.5, 0.0);
        // token 1 drops by 0.5 * 2 = 1.0
        assert!(
            (scores[1] - 1.0).abs() < 1e-5,
            "expected 1.0, got {}",
            scores[1]
        );
        // the other tokens keep their logits
        assert!((scores[0] - 1.0).abs() < 1e-5);
        assert!((scores[2] - 3.0).abs() < 1e-5);
    }

    #[test]
    fn apply_presence_penalty_reduces_seen() {
        let mut scores = [1.0f32, 2.0, 3.0];
        let seen = HashMap::from([(0u32, 1usize)]);
        apply_frequency_penalty(&mut scores, &seen, 0.0, 1.0);
        assert!(
            (scores[0] - 0.0).abs() < 1e-5,
            "expected 0.0, got {}",
            scores[0]
        );
        assert!((scores[1] - 2.0).abs() < 1e-5);
    }

    // ── Balanced extraction ───────────────────────────────────────────────

    #[test]
    fn extract_balanced_object() {
        let found = extract_balanced(r#"prefix {"a":1} suffix"#, '{', '}');
        assert_eq!(found.as_deref(), Some(r#"{"a":1}"#));
    }

    #[test]
    fn extract_balanced_array() {
        let found = extract_balanced("pre [1,2,3] post", '[', ']');
        assert_eq!(found.as_deref(), Some("[1,2,3]"));
    }
}