Skip to main content

oxibonsai_runtime/
api_types.rs

//! Extended OpenAI-compatible API types.
//!
//! Provides request/response types for full OpenAI API compatibility, including
//! function calling (tools), logprobs, JSON mode, and multi-completion support.
5
6use std::collections::hash_map::DefaultHasher;
7use std::hash::{Hash, Hasher};
8
9// ── Phase 19: Tool calling types ──────────────────────────────────────────────
10
11/// A function definition for tool use.
12#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
13pub struct ToolFunction {
14    /// The name of the function.
15    pub name: String,
16    /// An optional description of the function.
17    #[serde(skip_serializing_if = "Option::is_none")]
18    pub description: Option<String>,
19    /// JSON Schema object describing the function parameters.
20    pub parameters: serde_json::Value,
21}
22
23/// A tool available to the model (OpenAI-compatible format).
24#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
25pub struct ToolDefinition {
26    /// Must be `"function"`.
27    #[serde(rename = "type")]
28    pub r#type: String,
29    /// The function definition.
30    pub function: ToolFunction,
31}
32
33impl ToolDefinition {
34    /// Convenience constructor for a function-type tool.
35    pub fn function(
36        name: impl Into<String>,
37        description: Option<String>,
38        parameters: serde_json::Value,
39    ) -> Self {
40        Self {
41            r#type: "function".to_string(),
42            function: ToolFunction {
43                name: name.into(),
44                description,
45                parameters,
46            },
47        }
48    }
49}
50
51/// A function call made by the model (name + serialised arguments).
52#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
53pub struct ToolFunctionCall {
54    /// Name of the function invoked.
55    pub name: String,
56    /// JSON-encoded arguments string.
57    pub arguments: String,
58}
59
60/// A tool call produced by the model in a chat completion response.
61///
62/// Uses `r#type` (serialised as `"type"`) to avoid the reserved keyword.
63#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
64pub struct ToolCallResult {
65    /// Unique identifier for this tool call (prefix `call_`).
66    pub id: String,
67    /// Type of tool call — always `"function"`.
68    #[serde(rename = "type")]
69    pub r#type: String,
70    /// The function invoked.
71    pub function: ToolFunctionCall,
72}
73
74impl ToolCallResult {
75    /// Construct a `ToolCallResult` for a function call.
76    pub fn new_function(id: String, name: String, arguments: String) -> Self {
77        Self {
78            id,
79            r#type: "function".to_string(),
80            function: ToolFunctionCall { name, arguments },
81        }
82    }
83}
84
85// ── Function calling ──────────────────────────────────────────────────────────
86
87/// A function that can be called by the model.
88#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
89pub struct FunctionDefinition {
90    /// The name of the function.
91    pub name: String,
92    /// A description of what the function does.
93    pub description: Option<String>,
94    /// The parameters the function accepts (JSON Schema object).
95    pub parameters: Option<serde_json::Value>,
96}
97
98/// A tool that can be used during generation.
99#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
100pub struct Tool {
101    /// The type of tool. Currently only `"function"` is supported.
102    #[serde(rename = "type")]
103    pub tool_type: String,
104    /// The function definition.
105    pub function: FunctionDefinition,
106}
107
108/// Controls which tool (if any) is called by the model.
109#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
110#[serde(untagged)]
111pub enum ToolChoice {
112    /// A string value: `"none"`, `"auto"`, or `"required"`.
113    String(String),
114    /// A specific named tool to call.
115    Named(NamedToolChoice),
116}
117
118/// A specific tool choice identifying a function by name.
119#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
120pub struct NamedToolChoice {
121    /// The type of the tool (e.g. `"function"`).
122    #[serde(rename = "type")]
123    pub tool_type: String,
124    /// The function to call.
125    pub function: FunctionName,
126}
127
128/// A function identified by name only.
129#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
130pub struct FunctionName {
131    /// The name of the function.
132    pub name: String,
133}
134
135/// A tool call made by the model in the response.
136#[derive(Debug, Clone, serde::Serialize)]
137pub struct ToolCall {
138    /// A unique ID for this tool call.
139    pub id: String,
140    /// The type of tool call (always `"function"`).
141    #[serde(rename = "type")]
142    pub tool_type: String,
143    /// The function that was called.
144    pub function: FunctionCallResult,
145}
146
147/// The result of a function call — name and serialized arguments.
148#[derive(Debug, Clone, serde::Serialize)]
149pub struct FunctionCallResult {
150    /// The name of the function called.
151    pub name: String,
152    /// The arguments to the function as a JSON string.
153    pub arguments: String,
154}
155
156// ── Logprobs ─────────────────────────────────────────────────────────────────
157
158/// Log probability information for a single generated token.
159#[derive(Debug, Clone, serde::Serialize)]
160pub struct LogprobsContent {
161    /// The token text.
162    pub token: String,
163    /// The log probability of this token.
164    pub logprob: f32,
165    /// The UTF-8 bytes of the token, if representable.
166    pub bytes: Option<Vec<u8>>,
167    /// The top alternative tokens at this position.
168    pub top_logprobs: Vec<TopLogprob>,
169}
170
171/// A top-k alternative token and its log probability.
172#[derive(Debug, Clone, serde::Serialize)]
173pub struct TopLogprob {
174    /// The token text.
175    pub token: String,
176    /// The log probability of this token.
177    pub logprob: f32,
178    /// The UTF-8 bytes of the token, if representable.
179    pub bytes: Option<Vec<u8>>,
180}
181
182/// Logprob information attached to a choice.
183#[derive(Debug, Clone, serde::Serialize)]
184pub struct ChoiceLogprobs {
185    /// Per-token log probability content for the choice.
186    pub content: Option<Vec<LogprobsContent>>,
187}
188
189// ── Response format ───────────────────────────────────────────────────────────
190
191/// The format in which the model should return its response.
192#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
193pub struct ResponseFormat {
194    /// `"text"`, `"json_object"`, or `"json_schema"`.
195    #[serde(rename = "type")]
196    pub format_type: String,
197    /// JSON schema definition (only used when `format_type == "json_schema"`).
198    pub json_schema: Option<JsonSchemaFormat>,
199}
200
201/// A named JSON schema that the model output must conform to.
202#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
203pub struct JsonSchemaFormat {
204    /// A human-readable name for the schema.
205    pub name: String,
206    /// The JSON Schema object.
207    pub schema: serde_json::Value,
208    /// Whether the model must strictly follow the schema.
209    pub strict: Option<bool>,
210}
211
212// ── Stop sequences ────────────────────────────────────────────────────────────
213
214/// One or more stop sequences that terminate generation.
215#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
216#[serde(untagged)]
217pub enum StopSequences {
218    /// A single stop sequence string.
219    Single(String),
220    /// Multiple stop sequence strings.
221    Multiple(Vec<String>),
222}
223
224impl StopSequences {
225    /// Return a slice of stop sequence strings.
226    pub fn as_slice(&self) -> &[String] {
227        match self {
228            StopSequences::Single(s) => std::slice::from_ref(s),
229            StopSequences::Multiple(v) => v.as_slice(),
230        }
231    }
232
233    /// Consume and return all stop sequences as a `Vec<String>`.
234    pub fn into_vec(self) -> Vec<String> {
235        match self {
236            StopSequences::Single(s) => vec![s],
237            StopSequences::Multiple(v) => v,
238        }
239    }
240}
241
242// ── Usage info (public alias used by ExtendedChatResponse) ───────────────────
243
244/// Token usage information for a completion request.
245#[derive(Debug, Clone, serde::Serialize)]
246pub struct UsageInfo {
247    /// Tokens consumed by the prompt.
248    pub prompt_tokens: usize,
249    /// Tokens generated in the completion.
250    pub completion_tokens: usize,
251    /// Total tokens (prompt + completion).
252    pub total_tokens: usize,
253}
254
255// ── Extended chat completion request ─────────────────────────────────────────
256
257/// A full OpenAI-compatible chat completion request including all optional fields.
258#[derive(Debug, serde::Deserialize)]
259pub struct ExtendedChatRequest {
260    /// The conversation messages.
261    pub messages: Vec<crate::server::ChatMessage>,
262    /// Maximum number of tokens to generate.
263    #[serde(default = "default_max_tokens")]
264    pub max_tokens: usize,
265    /// Sampling temperature (0.0 = greedy).
266    pub temperature: Option<f32>,
267    /// Nucleus sampling probability threshold.
268    pub top_p: Option<f32>,
269    /// Whether to stream the response as SSE.
270    pub stream: Option<bool>,
271    /// Sequences that stop generation.
272    pub stop: Option<StopSequences>,
273    /// Tools available to the model.
274    pub tools: Option<Vec<Tool>>,
275    /// Controls which tool is called, if any.
276    pub tool_choice: Option<ToolChoice>,
277    /// Whether to return log probabilities for generated tokens.
278    pub logprobs: Option<bool>,
279    /// Number of top alternative tokens to include in logprobs (0–20).
280    pub top_logprobs: Option<usize>,
281    /// Format constraint for the response.
282    pub response_format: Option<ResponseFormat>,
283    /// Seed for deterministic generation.
284    pub seed: Option<u64>,
285    /// Number of independent completions to generate (default 1, max 4).
286    pub n: Option<usize>,
287    /// Penalty applied for tokens that are present in the context.
288    pub presence_penalty: Option<f32>,
289    /// Penalty applied proportional to a token's frequency in the context.
290    pub frequency_penalty: Option<f32>,
291    /// An optional identifier for the end user.
292    pub user: Option<String>,
293}
294
/// Serde default for `ExtendedChatRequest::max_tokens` when the field is omitted.
fn default_max_tokens() -> usize {
    const FALLBACK_MAX_TOKENS: usize = 256;
    FALLBACK_MAX_TOKENS
}
298
299// ── Extended choice with logprobs ─────────────────────────────────────────────
300
301/// A single completion choice that may include logprobs and tool calls.
302#[derive(Debug, serde::Serialize)]
303pub struct ExtendedChoice {
304    /// Zero-based index of this choice among all returned completions.
305    pub index: usize,
306    /// The generated assistant message.
307    pub message: crate::server::ChatMessage,
308    /// Why generation stopped (`"stop"`, `"length"`, `"tool_calls"`, etc.).
309    pub finish_reason: String,
310    /// Log probability information (present only when `logprobs` was requested).
311    pub logprobs: Option<ChoiceLogprobs>,
312    /// Tool calls made by the model, if any.
313    pub tool_calls: Option<Vec<ToolCall>>,
314}
315
316// ── Extended completion response ──────────────────────────────────────────────
317
318/// A full OpenAI-compatible chat completion response.
319#[derive(Debug, serde::Serialize)]
320pub struct ExtendedChatResponse {
321    /// Unique identifier for this completion.
322    pub id: String,
323    /// Object type: always `"chat.completion"`.
324    pub object: String,
325    /// Unix timestamp of creation.
326    pub created: u64,
327    /// The model that generated this completion.
328    pub model: String,
329    /// One or more completion choices.
330    pub choices: Vec<ExtendedChoice>,
331    /// Token usage statistics.
332    pub usage: UsageInfo,
333    /// A fingerprint of the model/backend configuration for reproducibility.
334    pub system_fingerprint: Option<String>,
335}
336
337// ── Utility functions ─────────────────────────────────────────────────────────
338
339/// Compute logprob information for the chosen token, including top-k alternatives.
340///
341/// `logits` is the raw (pre-softmax) logit vector from the model.
342/// `chosen_token` is the index of the token that was actually sampled.
343/// `top_k` is the number of alternatives to include (clamped to `logits.len()`).
344/// `id_to_token` maps a token ID to its string representation.
345pub fn compute_logprobs(
346    logits: &[f32],
347    chosen_token: u32,
348    top_k: usize,
349    id_to_token: &dyn Fn(u32) -> String,
350) -> LogprobsContent {
351    if logits.is_empty() {
352        return LogprobsContent {
353            token: id_to_token(chosen_token),
354            logprob: 0.0,
355            bytes: token_bytes(id_to_token(chosen_token).as_str()),
356            top_logprobs: vec![],
357        };
358    }
359
360    // Compute log-softmax over the full logit vector.
361    let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
362    let sum_exp: f32 = logits.iter().map(|&l| (l - max_logit).exp()).sum();
363    let log_sum_exp = sum_exp.ln() + max_logit;
364
365    // Build sorted list of (token_id, logprob) for top-k.
366    let effective_k = top_k.clamp(1, logits.len());
367    let mut indexed: Vec<(u32, f32)> = logits
368        .iter()
369        .enumerate()
370        .map(|(i, &l)| (i as u32, l - log_sum_exp))
371        .collect();
372    // Partial sort: bring top-k to the front.
373    indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
374    indexed.truncate(effective_k);
375
376    let chosen_logprob = logits
377        .get(chosen_token as usize)
378        .copied()
379        .unwrap_or(f32::NEG_INFINITY)
380        - log_sum_exp;
381
382    let chosen_text = id_to_token(chosen_token);
383    let chosen_bytes = token_bytes(&chosen_text);
384
385    let top_logprobs: Vec<TopLogprob> = indexed
386        .iter()
387        .map(|&(tid, lp)| {
388            let text = id_to_token(tid);
389            let bytes = token_bytes(&text);
390            TopLogprob {
391                token: text,
392                logprob: lp,
393                bytes,
394            }
395        })
396        .collect();
397
398    LogprobsContent {
399        token: chosen_text,
400        logprob: chosen_logprob,
401        bytes: chosen_bytes,
402        top_logprobs,
403    }
404}
405
/// Copies the UTF-8 bytes of `token` into a new vector; an empty token maps
/// to `None`.
fn token_bytes(token: &str) -> Option<Vec<u8>> {
    match token.as_bytes() {
        [] => None,
        raw => Some(raw.to_vec()),
    }
}
414
415/// Return `true` if `text` is valid JSON (object or array).
416pub fn is_valid_json(text: &str) -> bool {
417    let trimmed = text.trim();
418    if trimmed.is_empty() {
419        return false;
420    }
421    serde_json::from_str::<serde_json::Value>(trimmed).is_ok()
422}
423
424/// Attempt to parse a tool call from generated text.
425///
426/// The model is expected to emit tool calls in the form:
427/// ```text
428/// <tool_call>{"name": "fn_name", "arguments": {...}}</tool_call>
429/// ```
430///
431/// Returns `Some(ToolCall)` on success, `None` if the pattern is not found
432/// or the inner JSON cannot be parsed.
433pub fn parse_tool_call(text: &str, call_id: &str) -> Option<ToolCall> {
434    let start_tag = "<tool_call>";
435    let end_tag = "</tool_call>";
436
437    let start = text.find(start_tag)?;
438    let inner_start = start + start_tag.len();
439    let end = text[inner_start..].find(end_tag).map(|e| inner_start + e)?;
440
441    let inner = text[inner_start..end].trim();
442    let value: serde_json::Value = serde_json::from_str(inner).ok()?;
443
444    let name = value.get("name")?.as_str()?.to_string();
445    let arguments = match value.get("arguments") {
446        Some(args) => serde_json::to_string(args).ok()?,
447        None => "{}".to_string(),
448    };
449
450    Some(ToolCall {
451        id: call_id.to_string(),
452        tool_type: "function".to_string(),
453        function: FunctionCallResult { name, arguments },
454    })
455}
456
/// Generate a tool call identifier with the `call_` prefix.
///
/// Hashes the current time (nanoseconds since the Unix epoch) together with a
/// process-wide sequence number and keeps the low 32 bits as 8 hex characters,
/// yielding identifiers such as `call_1a2b3c4d`.
///
/// The sequence number keeps the hash input distinct for calls that land on
/// the same clock reading (coarse clocks, or a clock set before the epoch,
/// which reads as `0`). Identifiers carry only 32 bits, so they are
/// collision-resistant rather than guaranteed unique.
pub fn generate_tool_call_id() -> String {
    static SEQUENCE: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);

    let ts = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    // Relaxed is sufficient: only the uniqueness of each fetched value matters.
    let seq = SEQUENCE.fetch_add(1, std::sync::atomic::Ordering::Relaxed);

    let mut hasher = DefaultHasher::new();
    ts.hash(&mut hasher);
    seq.hash(&mut hasher);
    let hash = hasher.finish();
    format!("call_{:08x}", hash & 0xFFFF_FFFF)
}
472
/// Derive a hex fingerprint (prefixed with `fp_`) from a configuration string.
///
/// The same input yields the same fingerprint within a given build, which lets
/// clients notice backend configuration changes via `system_fingerprint`.
///
/// NOTE(review): the digest comes from `DefaultHasher`, whose algorithm the
/// standard library does not guarantee across Rust releases — confirm whether
/// fingerprints need to stay comparable across toolchain upgrades.
pub fn fingerprint_from_config(config_hash_input: &str) -> String {
    let mut state = DefaultHasher::new();
    Hash::hash(config_hash_input, &mut state);
    let digest = state.finish();
    format!("fp_{:x}", digest)
}
482
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned `Vec<String>` fixture from string literals.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    // ── StopSequences ────────────────────────────────────────────────────────

    #[test]
    fn stop_sequences_single_as_slice() {
        let stop = StopSequences::Single(String::from("stop"));
        assert_eq!(stop.as_slice(), &["stop"]);
    }

    #[test]
    fn stop_sequences_multiple_as_slice() {
        let stop = StopSequences::Multiple(strings(&["a", "b"]));
        assert_eq!(stop.as_slice(), &["a", "b"]);
    }

    #[test]
    fn stop_sequences_single_into_vec() {
        let stop = StopSequences::Single(String::from("x"));
        assert_eq!(stop.into_vec(), vec!["x"]);
    }

    #[test]
    fn stop_sequences_multiple_into_vec() {
        let stop = StopSequences::Multiple(strings(&["a", "b"]));
        assert_eq!(stop.into_vec(), vec!["a", "b"]);
    }

    // ── is_valid_json ────────────────────────────────────────────────────────

    #[test]
    fn is_valid_json_object() {
        let object = r#"{"key": "value"}"#;
        assert!(is_valid_json(object));
    }

    #[test]
    fn is_valid_json_array() {
        let array = r#"[1, 2, 3]"#;
        assert!(is_valid_json(array));
    }

    #[test]
    fn is_valid_json_invalid() {
        for bad in ["not json", ""].iter() {
            assert!(!is_valid_json(bad));
        }
    }

    // ── parse_tool_call ──────────────────────────────────────────────────────

    #[test]
    fn parse_tool_call_valid() {
        let text = r#"<tool_call>{"name":"get_weather","arguments":{"city":"London"}}</tool_call>"#;
        let parsed = parse_tool_call(text, "call_abc123").expect("should parse");
        assert_eq!(parsed.id, "call_abc123");
        assert_eq!(parsed.tool_type, "function");
        assert_eq!(parsed.function.name, "get_weather");
    }

    #[test]
    fn parse_tool_call_invalid() {
        let parsed = parse_tool_call("No tool call here", "call_x");
        assert!(parsed.is_none());
    }

    // ── Identifier helpers ───────────────────────────────────────────────────

    #[test]
    fn generate_tool_call_id_prefix() {
        let id = generate_tool_call_id();
        assert!(id.starts_with("call_"), "expected call_ prefix, got: {id}");
        assert_eq!(id.len(), 13, "expected 13 chars, got: {id}");
    }

    #[test]
    fn fingerprint_from_config_stable() {
        let first = fingerprint_from_config("bonsai-8b");
        let second = fingerprint_from_config("bonsai-8b");
        assert!(first.starts_with("fp_"));
        assert_eq!(first, second);
    }

    // ── compute_logprobs ─────────────────────────────────────────────────────

    #[test]
    fn compute_logprobs_top_tokens() {
        let logits = [1.0f32, 3.0, 2.0, 0.5, 1.5];
        let name_of = |id: u32| format!("tok{id}");
        let result = compute_logprobs(&logits, 1, 3, &name_of);

        assert_eq!(result.token, "tok1");
        assert!(
            result.logprob <= 0.0,
            "logprob should be <= 0 (log probability)"
        );
        assert_eq!(result.top_logprobs.len(), 3);
        // Index 1 holds the largest logit, so it must lead the alternatives.
        assert_eq!(result.top_logprobs[0].token, "tok1");
    }
}