use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
/// A function offered to the model, carried inside a [`Tool`] in a request's `tools` list.
///
/// Both `Option` fields serialize as JSON `null` when `None` (no `skip_serializing_if`).
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct FunctionDefinition {
    /// Name of the function.
    pub name: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Optional parameter description kept as raw JSON and not validated here
    /// (presumably a JSON Schema object — confirm against the prompt builder).
    pub parameters: Option<serde_json::Value>,
}
/// One entry of a request's `tools` list.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct Tool {
    /// Tool kind, read from / written to the JSON key `"type"`; not validated here
    /// (presumably `"function"`).
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function this tool exposes.
    pub function: FunctionDefinition,
}
/// The request's `tool_choice` value: either a bare string or an object naming one function.
///
/// Because of `#[serde(untagged)]`, deserialization tries the variants in declaration order.
/// Any JSON string becomes [`ToolChoice::String`], and its contents are not validated here.
/// A JSON object matching [`NamedToolChoice`] becomes [`ToolChoice::Named`].
/// Serialization writes the inner value with no enum wrapper.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
#[serde(untagged)]
pub enum ToolChoice {
    String(String),
    Named(NamedToolChoice),
}
/// Object form of [`ToolChoice`], selecting a single function by name.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct NamedToolChoice {
    /// Tool kind from the JSON key `"type"`; not validated here.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function being selected.
    pub function: FunctionName,
}
/// A bare function name, as carried by [`NamedToolChoice`].
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct FunctionName {
    /// Name of the selected function.
    pub name: String,
}
/// A tool invocation extracted from model output (built by `parse_tool_call`).
///
/// Only `Serialize` is derived, so this is an output-only type.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ToolCall {
    /// Identifier for this call; `parse_tool_call` copies it from its `call_id` argument.
    pub id: String,
    /// Serialized under the JSON key `"type"`; `parse_tool_call` always sets `"function"`.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function name and its arguments.
    pub function: FunctionCallResult,
}
/// Name and arguments of a function the model asked to call.
#[derive(Debug, Clone, serde::Serialize)]
pub struct FunctionCallResult {
    /// Function name as emitted by the model.
    pub name: String,
    /// Arguments as text, so they serialize as a JSON *string* rather than a nested object.
    /// `parse_tool_call` fills this with JSON text and uses `"{}"` when none were supplied.
    pub arguments: String,
}
/// Log-probability record for one generated token, as built by `compute_logprobs`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct LogprobsContent {
    /// Text of the token that was emitted.
    pub token: String,
    /// Natural-log probability of that token (`compute_logprobs` uses a log-softmax).
    pub logprob: f32,
    /// UTF-8 bytes of `token` (serialized as an array of integers), or `None` if empty.
    pub bytes: Option<Vec<u8>>,
    /// Most likely alternatives at this position, highest log-probability first.
    pub top_logprobs: Vec<TopLogprob>,
}
/// One alternative token reported in [`LogprobsContent::top_logprobs`].
#[derive(Debug, Clone, serde::Serialize)]
pub struct TopLogprob {
    /// Text of the alternative token.
    pub token: String,
    /// Natural-log probability of the alternative.
    pub logprob: f32,
    /// UTF-8 bytes of `token`, or `None` if the token text is empty.
    pub bytes: Option<Vec<u8>>,
}
/// Log-probability payload attached to one response choice.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ChoiceLogprobs {
    /// Presumably one [`LogprobsContent`] per generated token (confirm where this is filled);
    /// serialized as `null` when `None`.
    pub content: Option<Vec<LogprobsContent>>,
}
/// The request's `response_format` object.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct ResponseFormat {
    /// Format kind from the JSON key `"type"`; free-form and not validated here.
    #[serde(rename = "type")]
    pub format_type: String,
    /// Schema details; a missing key deserializes to `None`.
    pub json_schema: Option<JsonSchemaFormat>,
}
/// A named JSON schema attached to a [`ResponseFormat`].
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct JsonSchemaFormat {
    /// Schema name.
    pub name: String,
    /// The schema itself, kept as raw JSON and not validated here.
    pub schema: serde_json::Value,
    /// Optional strictness flag; how (or whether) it is enforced is not visible in this module.
    pub strict: Option<bool>,
}
/// The request's `stop` value, accepted either as a single string or as a list of strings.
///
/// Because of `#[serde(untagged)]`, a JSON string deserializes to `Single` and a JSON array
/// of strings deserializes to `Multiple`. Serialization writes the inner value unwrapped.
/// Use [`StopSequences::as_slice`] or [`StopSequences::into_vec`] to handle both forms uniformly.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
#[serde(untagged)]
pub enum StopSequences {
    Single(String),
    Multiple(Vec<String>),
}
impl StopSequences {
pub fn as_slice(&self) -> &[String] {
match self {
StopSequences::Single(s) => std::slice::from_ref(s),
StopSequences::Multiple(v) => v.as_slice(),
}
}
pub fn into_vec(self) -> Vec<String> {
match self {
StopSequences::Single(s) => vec![s],
StopSequences::Multiple(v) => v,
}
}
}
/// Token accounting reported in an [`ExtendedChatResponse`].
#[derive(Debug, Clone, serde::Serialize)]
pub struct UsageInfo {
    /// Number of prompt tokens.
    pub prompt_tokens: usize,
    /// Number of generated tokens.
    pub completion_tokens: usize,
    /// NOTE(review): presumably `prompt_tokens + completion_tokens`; nothing here enforces it.
    pub total_tokens: usize,
}
/// Chat-completion request body.
///
/// NOTE(review): the field set presumably mirrors the OpenAI chat-completions request.
/// Which optional fields are actually honored is decided by the handler, which is not visible here.
///
/// Only `Deserialize` is derived, so this is an input-only type. `Option` fields are `None`
/// when their key is absent. Because `deny_unknown_fields` is not set, unrecognized keys are
/// silently ignored.
#[derive(Debug, serde::Deserialize)]
pub struct ExtendedChatRequest {
    /// Conversation history.
    pub messages: Vec<crate::server::ChatMessage>,
    /// Generation limit; falls back to `default_max_tokens()` (256) when the key is omitted.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    /// Sampling temperature, if provided.
    pub temperature: Option<f32>,
    /// `top_p` sampling value, if provided.
    pub top_p: Option<f32>,
    /// Whether the client asked for a streamed response.
    pub stream: Option<bool>,
    /// Stop sequence(s); see [`StopSequences`].
    pub stop: Option<StopSequences>,
    /// Tools offered to the model.
    pub tools: Option<Vec<Tool>>,
    /// Tool-selection policy; see [`ToolChoice`].
    pub tool_choice: Option<ToolChoice>,
    /// Whether per-token log-probabilities are requested.
    pub logprobs: Option<bool>,
    /// Number of alternatives per token to report (presumably passed to `compute_logprobs`).
    pub top_logprobs: Option<usize>,
    /// Requested output format; see [`ResponseFormat`].
    pub response_format: Option<ResponseFormat>,
    /// Sampling seed, if provided.
    pub seed: Option<u64>,
    /// Number of choices requested, if provided.
    pub n: Option<usize>,
    /// Presence penalty, if provided.
    pub presence_penalty: Option<f32>,
    /// Frequency penalty, if provided.
    pub frequency_penalty: Option<f32>,
    /// Opaque end-user identifier supplied by the client.
    pub user: Option<String>,
}
/// Serde default for `ExtendedChatRequest::max_tokens` when a request omits the key.
fn default_max_tokens() -> usize {
    const FALLBACK_MAX_TOKENS: usize = 256;
    FALLBACK_MAX_TOKENS
}
/// One choice in an [`ExtendedChatResponse`] (output-only: only `Serialize` is derived).
#[derive(Debug, serde::Serialize)]
pub struct ExtendedChoice {
    /// Index of this choice (presumably its position in `choices`).
    pub index: usize,
    /// Message produced for this choice.
    pub message: crate::server::ChatMessage,
    /// Why generation stopped, as free-form text (not validated here).
    pub finish_reason: String,
    /// Per-token log-probabilities; serialized as `null` when `None`.
    pub logprobs: Option<ChoiceLogprobs>,
    /// Tool calls parsed from the output; serialized as `null` when `None`.
    ///
    /// NOTE(review): the OpenAI chat-completions schema nests `tool_calls` inside `message`
    /// rather than beside it. Confirm that clients of this server expect this placement.
    pub tool_calls: Option<Vec<ToolCall>>,
}
/// Chat-completion response body (output-only: only `Serialize` is derived).
#[derive(Debug, serde::Serialize)]
pub struct ExtendedChatResponse {
    /// Response identifier.
    pub id: String,
    /// Object-type tag; its value is chosen by the caller that builds the response.
    pub object: String,
    /// Creation timestamp — presumably Unix seconds; confirm at the construction site.
    pub created: u64,
    /// Model name reported to the client.
    pub model: String,
    /// Generated choices.
    pub choices: Vec<ExtendedChoice>,
    /// Token accounting for the request.
    pub usage: UsageInfo,
    /// Backend fingerprint (e.g. from `fingerprint_from_config`); serialized as `null` when `None`.
    pub system_fingerprint: Option<String>,
}
/// Builds the log-probability record for one sampled token.
///
/// Log-probabilities are natural logs from a numerically stable log-softmax over `logits`:
/// the maximum logit is subtracted before exponentiating.
///
/// * `logits` – raw scores indexed by token id (assumed to span the vocabulary — TODO confirm).
/// * `chosen_token` – id of the token that was actually emitted.
/// * `top_k` – number of alternatives to report. It is clamped to `1..=logits.len()`, so at
///   least one alternative is returned whenever `logits` is non-empty.
/// * `id_to_token` – maps a token id to its text.
///
/// Alternatives are ordered by descending log-probability, with ties broken by ascending
/// token id.
///
/// Edge cases:
/// * If `logits` is empty, the chosen token is reported with log-probability `0.0` and no
///   alternatives.
/// * If `chosen_token` is out of range, its log-probability is `-inf`.
pub fn compute_logprobs(
    logits: &[f32],
    chosen_token: u32,
    top_k: usize,
    id_to_token: &dyn Fn(u32) -> String,
) -> LogprobsContent {
    // Ranking: higher log-probability first, lower token id breaks ties. `total_cmp` is a
    // total order even when NaN is present. Slice sorting/selection requires one; the previous
    // `partial_cmp(..).unwrap_or(Equal)` is not total and may make `sort_by` panic on recent Rust.
    fn by_rank(a: &(u32, f32), b: &(u32, f32)) -> std::cmp::Ordering {
        b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0))
    }

    // Resolve the chosen token's text once; it is needed on every path.
    let chosen_text = id_to_token(chosen_token);
    let chosen_bytes = token_bytes(&chosen_text);

    if logits.is_empty() {
        return LogprobsContent {
            token: chosen_text,
            logprob: 0.0,
            bytes: chosen_bytes,
            top_logprobs: Vec::new(),
        };
    }

    let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let sum_exp: f32 = logits.iter().map(|&l| (l - max_logit).exp()).sum();
    let log_sum_exp = sum_exp.ln() + max_logit;

    let chosen_logprob = logits
        .get(chosen_token as usize)
        .copied()
        .unwrap_or(f32::NEG_INFINITY)
        - log_sum_exp;

    let effective_k = top_k.clamp(1, logits.len());
    let mut ranked: Vec<(u32, f32)> = logits
        .iter()
        .enumerate()
        .map(|(i, &l)| (i as u32, l - log_sum_exp))
        .collect();
    // Move the `effective_k` best entries to the front in O(n), then sort only those. This
    // avoids an O(n log n) sort of the whole vocabulary for every generated token.
    if effective_k < ranked.len() {
        ranked.select_nth_unstable_by(effective_k - 1, by_rank);
        ranked.truncate(effective_k);
    }
    ranked.sort_unstable_by(by_rank);

    let top_logprobs: Vec<TopLogprob> = ranked
        .into_iter()
        .map(|(tid, lp)| {
            let token = id_to_token(tid);
            let bytes = token_bytes(&token);
            TopLogprob {
                token,
                logprob: lp,
                bytes,
            }
        })
        .collect();

    LogprobsContent {
        token: chosen_text,
        logprob: chosen_logprob,
        bytes: chosen_bytes,
        top_logprobs,
    }
}
/// Returns the UTF-8 bytes of `token`, or `None` when the token text is empty.
fn token_bytes(token: &str) -> Option<Vec<u8>> {
    let has_text = !token.is_empty();
    has_text.then(|| token.bytes().collect())
}
/// Reports whether `text` parses as a single JSON value, ignoring surrounding whitespace.
///
/// Surrounding whitespace is removed with `str::trim`, which strips Unicode whitespace.
/// Empty or all-whitespace input is therefore rejected without invoking the parser.
/// Any JSON value counts: object, array, string, number, boolean, or `null`.
pub fn is_valid_json(text: &str) -> bool {
    match text.trim() {
        "" => false,
        payload => serde_json::from_str::<serde_json::Value>(payload).is_ok(),
    }
}
/// Extracts the first `<tool_call>…</tool_call>` span of model output as a [`ToolCall`].
///
/// The text between the tags (trimmed) must be a JSON object with a string `"name"` field
/// and an optional `"arguments"` field. The returned call uses `call_id` as its id and
/// `"function"` as its type.
///
/// `arguments` is normalized to a JSON-encoded string:
/// * A JSON string that itself parses to a JSON object is used verbatim. Re-serializing it
///   would double-encode it into a quoted string (`"{\"a\":1}"`), which clients cannot use as
///   an arguments object.
/// * A missing value or `null` becomes `"{}"`.
/// * Anything else is re-serialized with `serde_json::to_string`.
///
/// Returns `None` in these cases:
/// * either tag is missing;
/// * the payload is not valid JSON;
/// * `"name"` is absent or not a string.
///
/// Only the first tool call in `text` is considered.
pub fn parse_tool_call(text: &str, call_id: &str) -> Option<ToolCall> {
    const START_TAG: &str = "<tool_call>";
    const END_TAG: &str = "</tool_call>";

    // First start tag, then the first end tag after it.
    let (_, after_start) = text.split_once(START_TAG)?;
    let (inner, _) = after_start.split_once(END_TAG)?;
    let value: serde_json::Value = serde_json::from_str(inner.trim()).ok()?;

    let name = value.get("name")?.as_str()?.to_string();
    let arguments = match value.get("arguments") {
        None | Some(serde_json::Value::Null) => "{}".to_string(),
        // Already-encoded arguments object: pass through instead of double-encoding.
        Some(serde_json::Value::String(encoded))
            if matches!(
                serde_json::from_str::<serde_json::Value>(encoded),
                Ok(serde_json::Value::Object(_))
            ) =>
        {
            encoded.clone()
        }
        Some(args) => serde_json::to_string(args).ok()?,
    };

    Some(ToolCall {
        id: call_id.to_string(),
        tool_type: "function".to_string(),
        function: FunctionCallResult { name, arguments },
    })
}
/// Generates a tool-call id of the form `call_XXXXXXXX` (8 lowercase hex digits).
///
/// The id is a hash of the current wall-clock time in nanoseconds plus a process-wide
/// sequence number. The sequence number guarantees that every call in this process hashes a
/// distinct input. Without it, ids generated within the same clock tick (for example several
/// tool calls in one response, or on platforms with a coarse clock) were identical.
///
/// Only the low 32 bits of the hash are kept. Ids are therefore short, and collisions are
/// unlikely but not impossible.
pub fn generate_tool_call_id() -> String {
    // Process-wide counter. Relaxed ordering suffices: `fetch_add` is atomic, so every caller
    // observes a distinct value, and no other memory is synchronized through it.
    static SEQUENCE: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);

    let ts = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    let seq = SEQUENCE.fetch_add(1, std::sync::atomic::Ordering::Relaxed);

    let mut hasher = DefaultHasher::new();
    ts.hash(&mut hasher);
    seq.hash(&mut hasher);
    let hash = hasher.finish();
    format!("call_{:08x}", hash & 0xFFFF_FFFF)
}
/// Derives a `fp_`-prefixed fingerprint (lowercase hex, no padding) from a configuration string.
///
/// The value comes from `DefaultHasher::new()`, so it is deterministic for a given input
/// within one build. The standard library does not promise the hash algorithm stays the same
/// across Rust releases, so the fingerprint may change after a toolchain upgrade.
pub fn fingerprint_from_config(config_hash_input: &str) -> String {
    let digest = {
        let mut state = DefaultHasher::new();
        config_hash_input.hash(&mut state);
        state.finish()
    };
    format!("fp_{digest:x}")
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- StopSequences ----------------------------------------------------

    #[test]
    fn stop_sequences_single_as_slice() {
        let s = StopSequences::Single("stop".to_string());
        assert_eq!(s.as_slice(), &["stop"]);
    }

    #[test]
    fn stop_sequences_multiple_as_slice() {
        let s = StopSequences::Multiple(vec!["a".to_string(), "b".to_string()]);
        assert_eq!(s.as_slice(), &["a", "b"]);
    }

    #[test]
    fn stop_sequences_single_into_vec() {
        let s = StopSequences::Single("x".to_string());
        assert_eq!(s.into_vec(), vec!["x"]);
    }

    #[test]
    fn stop_sequences_multiple_into_vec() {
        let s = StopSequences::Multiple(vec!["a".to_string(), "b".to_string()]);
        assert_eq!(s.into_vec(), vec!["a", "b"]);
    }

    // ---- is_valid_json ----------------------------------------------------

    #[test]
    fn is_valid_json_object() {
        assert!(is_valid_json(r#"{"key": "value"}"#));
    }

    #[test]
    fn is_valid_json_array() {
        assert!(is_valid_json(r#"[1, 2, 3]"#));
    }

    #[test]
    fn is_valid_json_invalid() {
        assert!(!is_valid_json("not json"));
        assert!(!is_valid_json(""));
    }

    #[test]
    fn is_valid_json_whitespace_handling() {
        // All-whitespace input is rejected; surrounding whitespace around a value is ignored.
        assert!(!is_valid_json("  \n\t "));
        assert!(is_valid_json("  {\"a\": 1}\n"));
    }

    // ---- parse_tool_call --------------------------------------------------

    #[test]
    fn parse_tool_call_valid() {
        let text = r#"<tool_call>{"name":"get_weather","arguments":{"city":"London"}}</tool_call>"#;
        let tc = parse_tool_call(text, "call_abc123").expect("should parse");
        assert_eq!(tc.function.name, "get_weather");
        assert_eq!(tc.id, "call_abc123");
        assert_eq!(tc.tool_type, "function");
    }

    #[test]
    fn parse_tool_call_invalid() {
        let text = "No tool call here";
        assert!(parse_tool_call(text, "call_x").is_none());
    }

    #[test]
    fn parse_tool_call_object_arguments_are_json_encoded() {
        let text = r#"<tool_call>{"name":"get_weather","arguments":{"city":"London"}}</tool_call>"#;
        let tc = parse_tool_call(text, "c").expect("should parse");
        assert_eq!(tc.function.arguments, r#"{"city":"London"}"#);
    }

    #[test]
    fn parse_tool_call_missing_arguments_default_to_empty_object() {
        let tc = parse_tool_call(r#"<tool_call>{"name":"ping"}</tool_call>"#, "c")
            .expect("should parse");
        assert_eq!(tc.function.name, "ping");
        assert_eq!(tc.function.arguments, "{}");
    }

    #[test]
    fn parse_tool_call_tolerates_whitespace_and_surrounding_text() {
        let text = "Sure.\n<tool_call>\n  {\"name\":\"f\"}\n</tool_call> done";
        let tc = parse_tool_call(text, "c").expect("should parse");
        assert_eq!(tc.function.name, "f");
    }

    #[test]
    fn parse_tool_call_requires_closing_tag() {
        assert!(parse_tool_call(r#"<tool_call>{"name":"f"}"#, "c").is_none());
    }

    #[test]
    fn parse_tool_call_requires_string_name() {
        assert!(parse_tool_call(r#"<tool_call>{"name":1}</tool_call>"#, "c").is_none());
        assert!(parse_tool_call(r#"<tool_call>{"arguments":{}}</tool_call>"#, "c").is_none());
    }

    // ---- ids and fingerprints ---------------------------------------------

    #[test]
    fn generate_tool_call_id_prefix() {
        let id = generate_tool_call_id();
        assert!(id.starts_with("call_"), "expected call_ prefix, got: {id}");
        assert_eq!(id.len(), 13, "expected 13 chars, got: {id}");
    }

    #[test]
    fn fingerprint_from_config_stable() {
        let fp1 = fingerprint_from_config("bonsai-8b");
        let fp2 = fingerprint_from_config("bonsai-8b");
        assert_eq!(fp1, fp2);
        assert!(fp1.starts_with("fp_"));
    }

    #[test]
    fn fingerprint_from_config_distinguishes_inputs() {
        assert_ne!(fingerprint_from_config("a"), fingerprint_from_config("b"));
    }

    // ---- compute_logprobs -------------------------------------------------

    #[test]
    fn compute_logprobs_top_tokens() {
        let logits = vec![1.0f32, 3.0, 2.0, 0.5, 1.5];
        let lp = compute_logprobs(&logits, 1, 3, &|id| format!("tok{id}"));
        assert_eq!(lp.token, "tok1");
        assert!(
            lp.logprob <= 0.0,
            "logprob should be <= 0 (log probability)"
        );
        assert_eq!(lp.top_logprobs.len(), 3);
        assert_eq!(lp.top_logprobs[0].token, "tok1");
    }

    #[test]
    fn compute_logprobs_empty_logits() {
        let lp = compute_logprobs(&[], 7, 5, &|id| format!("tok{id}"));
        assert_eq!(lp.token, "tok7");
        assert_eq!(lp.logprob, 0.0);
        assert_eq!(lp.bytes, Some(b"tok7".to_vec()));
        assert!(lp.top_logprobs.is_empty());
    }

    #[test]
    fn compute_logprobs_top_k_capped_at_vocab_size() {
        let lp = compute_logprobs(&[0.0, 1.0], 0, 10, &|id| format!("tok{id}"));
        assert_eq!(lp.top_logprobs.len(), 2);
        // The alternatives cover the whole vocabulary here, so probabilities sum to ~1.
        let total: f32 = lp.top_logprobs.iter().map(|t| t.logprob.exp()).sum();
        assert!((total - 1.0).abs() < 1e-5, "sum was {total}");
    }

    #[test]
    fn compute_logprobs_out_of_range_chosen_token() {
        let lp = compute_logprobs(&[0.0, 1.0], 9, 1, &|id| format!("tok{id}"));
        assert_eq!(lp.token, "tok9");
        assert_eq!(lp.logprob, f32::NEG_INFINITY);
    }

    #[test]
    fn compute_logprobs_ties_keep_token_id_order() {
        let lp = compute_logprobs(&[1.0, 1.0, 1.0], 0, 2, &|id| format!("tok{id}"));
        let tokens: Vec<&str> = lp.top_logprobs.iter().map(|t| t.token.as_str()).collect();
        assert_eq!(tokens, vec!["tok0", "tok1"]);
    }
}