1use std::collections::hash_map::DefaultHasher;
7use std::hash::{Hash, Hasher};
8
9#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
13pub struct ToolFunction {
14 pub name: String,
16 #[serde(skip_serializing_if = "Option::is_none")]
18 pub description: Option<String>,
19 pub parameters: serde_json::Value,
21}
22
23#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
25pub struct ToolDefinition {
26 #[serde(rename = "type")]
28 pub r#type: String,
29 pub function: ToolFunction,
31}
32
33impl ToolDefinition {
34 pub fn function(
36 name: impl Into<String>,
37 description: Option<String>,
38 parameters: serde_json::Value,
39 ) -> Self {
40 Self {
41 r#type: "function".to_string(),
42 function: ToolFunction {
43 name: name.into(),
44 description,
45 parameters,
46 },
47 }
48 }
49}
50
51#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
53pub struct ToolFunctionCall {
54 pub name: String,
56 pub arguments: String,
58}
59
60#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
64pub struct ToolCallResult {
65 pub id: String,
67 #[serde(rename = "type")]
69 pub r#type: String,
70 pub function: ToolFunctionCall,
72}
73
74impl ToolCallResult {
75 pub fn new_function(id: String, name: String, arguments: String) -> Self {
77 Self {
78 id,
79 r#type: "function".to_string(),
80 function: ToolFunctionCall { name, arguments },
81 }
82 }
83}
84
85#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
89pub struct FunctionDefinition {
90 pub name: String,
92 pub description: Option<String>,
94 pub parameters: Option<serde_json::Value>,
96}
97
98#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
100pub struct Tool {
101 #[serde(rename = "type")]
103 pub tool_type: String,
104 pub function: FunctionDefinition,
106}
107
108#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
110#[serde(untagged)]
111pub enum ToolChoice {
112 String(String),
114 Named(NamedToolChoice),
116}
117
118#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
120pub struct NamedToolChoice {
121 #[serde(rename = "type")]
123 pub tool_type: String,
124 pub function: FunctionName,
126}
127
128#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
130pub struct FunctionName {
131 pub name: String,
133}
134
135#[derive(Debug, Clone, serde::Serialize)]
137pub struct ToolCall {
138 pub id: String,
140 #[serde(rename = "type")]
142 pub tool_type: String,
143 pub function: FunctionCallResult,
145}
146
147#[derive(Debug, Clone, serde::Serialize)]
149pub struct FunctionCallResult {
150 pub name: String,
152 pub arguments: String,
154}
155
156#[derive(Debug, Clone, serde::Serialize)]
160pub struct LogprobsContent {
161 pub token: String,
163 pub logprob: f32,
165 pub bytes: Option<Vec<u8>>,
167 pub top_logprobs: Vec<TopLogprob>,
169}
170
171#[derive(Debug, Clone, serde::Serialize)]
173pub struct TopLogprob {
174 pub token: String,
176 pub logprob: f32,
178 pub bytes: Option<Vec<u8>>,
180}
181
182#[derive(Debug, Clone, serde::Serialize)]
184pub struct ChoiceLogprobs {
185 pub content: Option<Vec<LogprobsContent>>,
187}
188
189#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
193pub struct ResponseFormat {
194 #[serde(rename = "type")]
196 pub format_type: String,
197 pub json_schema: Option<JsonSchemaFormat>,
199}
200
201#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
203pub struct JsonSchemaFormat {
204 pub name: String,
206 pub schema: serde_json::Value,
208 pub strict: Option<bool>,
210}
211
212#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
216#[serde(untagged)]
217pub enum StopSequences {
218 Single(String),
220 Multiple(Vec<String>),
222}
223
224impl StopSequences {
225 pub fn as_slice(&self) -> &[String] {
227 match self {
228 StopSequences::Single(s) => std::slice::from_ref(s),
229 StopSequences::Multiple(v) => v.as_slice(),
230 }
231 }
232
233 pub fn into_vec(self) -> Vec<String> {
235 match self {
236 StopSequences::Single(s) => vec![s],
237 StopSequences::Multiple(v) => v,
238 }
239 }
240}
241
242#[derive(Debug, Clone, serde::Serialize)]
246pub struct UsageInfo {
247 pub prompt_tokens: usize,
249 pub completion_tokens: usize,
251 pub total_tokens: usize,
253}
254
255#[derive(Debug, serde::Deserialize)]
259pub struct ExtendedChatRequest {
260 pub messages: Vec<crate::server::ChatMessage>,
262 #[serde(default = "default_max_tokens")]
264 pub max_tokens: usize,
265 pub temperature: Option<f32>,
267 pub top_p: Option<f32>,
269 pub stream: Option<bool>,
271 pub stop: Option<StopSequences>,
273 pub tools: Option<Vec<Tool>>,
275 pub tool_choice: Option<ToolChoice>,
277 pub logprobs: Option<bool>,
279 pub top_logprobs: Option<usize>,
281 pub response_format: Option<ResponseFormat>,
283 pub seed: Option<u64>,
285 pub n: Option<usize>,
287 pub presence_penalty: Option<f32>,
289 pub frequency_penalty: Option<f32>,
291 pub user: Option<String>,
293}
294
/// Serde default for `ExtendedChatRequest::max_tokens` when the field is missing.
fn default_max_tokens() -> usize {
    const DEFAULT_MAX_TOKENS: usize = 256;
    DEFAULT_MAX_TOKENS
}
298
299#[derive(Debug, serde::Serialize)]
303pub struct ExtendedChoice {
304 pub index: usize,
306 pub message: crate::server::ChatMessage,
308 pub finish_reason: String,
310 pub logprobs: Option<ChoiceLogprobs>,
312 pub tool_calls: Option<Vec<ToolCall>>,
314}
315
316#[derive(Debug, serde::Serialize)]
320pub struct ExtendedChatResponse {
321 pub id: String,
323 pub object: String,
325 pub created: u64,
327 pub model: String,
329 pub choices: Vec<ExtendedChoice>,
331 pub usage: UsageInfo,
333 pub system_fingerprint: Option<String>,
335}
336
337pub fn compute_logprobs(
346 logits: &[f32],
347 chosen_token: u32,
348 top_k: usize,
349 id_to_token: &dyn Fn(u32) -> String,
350) -> LogprobsContent {
351 if logits.is_empty() {
352 return LogprobsContent {
353 token: id_to_token(chosen_token),
354 logprob: 0.0,
355 bytes: token_bytes(id_to_token(chosen_token).as_str()),
356 top_logprobs: vec![],
357 };
358 }
359
360 let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
362 let sum_exp: f32 = logits.iter().map(|&l| (l - max_logit).exp()).sum();
363 let log_sum_exp = sum_exp.ln() + max_logit;
364
365 let effective_k = top_k.clamp(1, logits.len());
367 let mut indexed: Vec<(u32, f32)> = logits
368 .iter()
369 .enumerate()
370 .map(|(i, &l)| (i as u32, l - log_sum_exp))
371 .collect();
372 indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
374 indexed.truncate(effective_k);
375
376 let chosen_logprob = logits
377 .get(chosen_token as usize)
378 .copied()
379 .unwrap_or(f32::NEG_INFINITY)
380 - log_sum_exp;
381
382 let chosen_text = id_to_token(chosen_token);
383 let chosen_bytes = token_bytes(&chosen_text);
384
385 let top_logprobs: Vec<TopLogprob> = indexed
386 .iter()
387 .map(|&(tid, lp)| {
388 let text = id_to_token(tid);
389 let bytes = token_bytes(&text);
390 TopLogprob {
391 token: text,
392 logprob: lp,
393 bytes,
394 }
395 })
396 .collect();
397
398 LogprobsContent {
399 token: chosen_text,
400 logprob: chosen_logprob,
401 bytes: chosen_bytes,
402 top_logprobs,
403 }
404}
405
/// UTF-8 bytes of `token`, or `None` for the empty string.
fn token_bytes(token: &str) -> Option<Vec<u8>> {
    (!token.is_empty()).then(|| token.as_bytes().to_vec())
}
414
415pub fn is_valid_json(text: &str) -> bool {
417 let trimmed = text.trim();
418 if trimmed.is_empty() {
419 return false;
420 }
421 serde_json::from_str::<serde_json::Value>(trimmed).is_ok()
422}
423
424pub fn parse_tool_call(text: &str, call_id: &str) -> Option<ToolCall> {
434 let start_tag = "<tool_call>";
435 let end_tag = "</tool_call>";
436
437 let start = text.find(start_tag)?;
438 let inner_start = start + start_tag.len();
439 let end = text[inner_start..].find(end_tag).map(|e| inner_start + e)?;
440
441 let inner = text[inner_start..end].trim();
442 let value: serde_json::Value = serde_json::from_str(inner).ok()?;
443
444 let name = value.get("name")?.as_str()?.to_string();
445 let arguments = match value.get("arguments") {
446 Some(args) => serde_json::to_string(args).ok()?,
447 None => "{}".to_string(),
448 };
449
450 Some(ToolCall {
451 id: call_id.to_string(),
452 tool_type: "function".to_string(),
453 function: FunctionCallResult { name, arguments },
454 })
455}
456
/// Generates a tool-call id of the form `call_XXXXXXXX` (8 lowercase hex digits,
/// 13 characters in total).
///
/// The id hashes the current wall-clock time together with a process-wide
/// counter. The counter guarantees that successive calls hash distinct inputs
/// even when the clock has not advanced between them. That happens with coarse
/// `SystemTime` resolution, or when several tool calls are emitted in a tight
/// loop. The hash is truncated to 32 bits, so ids are short labels that rarely
/// collide, not globally unique identifiers.
pub fn generate_tool_call_id() -> String {
    static COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);

    let ts = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    // Relaxed is sufficient: fetch_add is atomic, so every caller receives a distinct value.
    let seq = COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);

    let mut hasher = DefaultHasher::new();
    ts.hash(&mut hasher);
    seq.hash(&mut hasher);
    format!("call_{:08x}", hasher.finish() & 0xFFFF_FFFF)
}
472
/// Derives a `fp_`-prefixed fingerprint from a configuration description.
///
/// Uses 64-bit FNV-1a over the UTF-8 bytes of `config_hash_input`, rendered as
/// 16 lowercase hex digits. FNV-1a is fully specified, so the fingerprint for a
/// given input is stable across builds and Rust releases. `DefaultHasher`'s
/// algorithm is explicitly unspecified and may change between releases, so it
/// is not used here.
pub fn fingerprint_from_config(config_hash_input: &str) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    let hash = config_hash_input
        .bytes()
        .fold(FNV_OFFSET_BASIS, |acc, byte| {
            (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
        });
    format!("fp_{hash:016x}")
}
482
#[cfg(test)]
mod tests {
    use super::*;

    fn tok(id: u32) -> String {
        format!("tok{id}")
    }

    #[test]
    fn stop_sequences_single_as_slice() {
        let s = StopSequences::Single("stop".to_string());
        assert_eq!(s.as_slice(), &["stop"]);
    }

    #[test]
    fn stop_sequences_multiple_as_slice() {
        let s = StopSequences::Multiple(vec!["a".to_string(), "b".to_string()]);
        assert_eq!(s.as_slice(), &["a", "b"]);
    }

    #[test]
    fn stop_sequences_single_into_vec() {
        let s = StopSequences::Single("x".to_string());
        assert_eq!(s.into_vec(), vec!["x"]);
    }

    #[test]
    fn stop_sequences_multiple_into_vec() {
        let s = StopSequences::Multiple(vec!["a".to_string(), "b".to_string()]);
        assert_eq!(s.into_vec(), vec!["a", "b"]);
    }

    #[test]
    fn is_valid_json_object() {
        assert!(is_valid_json(r#"{"key": "value"}"#));
    }

    #[test]
    fn is_valid_json_array() {
        assert!(is_valid_json(r#"[1, 2, 3]"#));
    }

    #[test]
    fn is_valid_json_invalid() {
        assert!(!is_valid_json("not json"));
        assert!(!is_valid_json(""));
    }

    #[test]
    fn is_valid_json_whitespace_handling() {
        assert!(!is_valid_json("  \n\t "));
        assert!(is_valid_json("  {\"a\": 1}\n"));
    }

    #[test]
    fn parse_tool_call_valid() {
        let text = r#"<tool_call>{"name":"get_weather","arguments":{"city":"London"}}</tool_call>"#;
        let tc = parse_tool_call(text, "call_abc123").expect("should parse");
        assert_eq!(tc.function.name, "get_weather");
        assert_eq!(tc.id, "call_abc123");
        assert_eq!(tc.tool_type, "function");
        assert_eq!(tc.function.arguments, r#"{"city":"London"}"#);
    }

    #[test]
    fn parse_tool_call_missing_arguments_defaults_to_empty_object() {
        let text = r#"<tool_call>{"name":"ping"}</tool_call>"#;
        let tc = parse_tool_call(text, "call_1").expect("should parse");
        assert_eq!(tc.function.name, "ping");
        assert_eq!(tc.function.arguments, "{}");
    }

    #[test]
    fn parse_tool_call_ignores_surrounding_text_and_whitespace() {
        let text = "Sure.\n<tool_call>\n  {\"name\": \"ping\"}\n</tool_call> done";
        let tc = parse_tool_call(text, "call_2").expect("should parse");
        assert_eq!(tc.function.name, "ping");
    }

    #[test]
    fn parse_tool_call_invalid() {
        let text = "No tool call here";
        assert!(parse_tool_call(text, "call_x").is_none());
    }

    #[test]
    fn parse_tool_call_rejects_malformed_payloads() {
        // Unterminated block.
        assert!(parse_tool_call(r#"<tool_call>{"name":"ping"}"#, "c").is_none());
        // Not JSON.
        assert!(parse_tool_call("<tool_call>not json</tool_call>", "c").is_none());
        // Missing or non-string name.
        assert!(parse_tool_call(r#"<tool_call>{"arguments":{}}</tool_call>"#, "c").is_none());
        assert!(parse_tool_call(r#"<tool_call>{"name":42}</tool_call>"#, "c").is_none());
    }

    #[test]
    fn generate_tool_call_id_prefix() {
        let id = generate_tool_call_id();
        assert!(id.starts_with("call_"), "expected call_ prefix, got: {id}");
        assert_eq!(id.len(), 13, "expected 13 chars, got: {id}");
    }

    #[test]
    fn fingerprint_from_config_stable() {
        let fp1 = fingerprint_from_config("bonsai-8b");
        let fp2 = fingerprint_from_config("bonsai-8b");
        assert_eq!(fp1, fp2);
        assert!(fp1.starts_with("fp_"));
    }

    #[test]
    fn fingerprint_from_config_differs_per_input() {
        assert_ne!(fingerprint_from_config("a"), fingerprint_from_config("b"));
    }

    #[test]
    fn compute_logprobs_top_tokens() {
        let logits = vec![1.0f32, 3.0, 2.0, 0.5, 1.5];
        let lp = compute_logprobs(&logits, 1, 3, &tok);
        assert_eq!(lp.token, "tok1");
        assert!(
            lp.logprob <= 0.0,
            "logprob should be <= 0 (log probability)"
        );
        assert_eq!(lp.top_logprobs.len(), 3);
        assert_eq!(lp.top_logprobs[0].token, "tok1");
    }

    #[test]
    fn compute_logprobs_empty_logits() {
        let lp = compute_logprobs(&[], 7, 5, &tok);
        assert_eq!(lp.token, "tok7");
        assert_eq!(lp.logprob, 0.0);
        assert_eq!(lp.bytes.as_deref(), Some(&b"tok7"[..]));
        assert!(lp.top_logprobs.is_empty());
    }

    #[test]
    fn compute_logprobs_full_distribution_normalises_and_orders() {
        let logits = [0.5f32, -1.0, 2.0, 2.0];
        let lp = compute_logprobs(&logits, 2, logits.len(), &tok);
        let total: f32 = lp.top_logprobs.iter().map(|t| t.logprob.exp()).sum();
        assert!((total - 1.0).abs() < 1e-5, "probabilities sum to {total}");
        // Descending by logprob; equal logits keep ascending token-id order.
        let order: Vec<&str> = lp.top_logprobs.iter().map(|t| t.token.as_str()).collect();
        assert_eq!(order, ["tok2", "tok3", "tok0", "tok1"]);
    }

    #[test]
    fn compute_logprobs_out_of_range_chosen_token() {
        let lp = compute_logprobs(&[1.0, 2.0], 9, 1, &tok);
        assert_eq!(lp.token, "tok9");
        assert_eq!(lp.logprob, f32::NEG_INFINITY);
        assert_eq!(lp.top_logprobs.len(), 1);
        assert_eq!(lp.top_logprobs[0].token, "tok1");
    }
}