1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
6pub enum BackendKind {
7 Gemini,
8 OpenAI,
9 Anthropic,
10 DeepSeek,
11 OpenRouter,
12 Ollama,
13 ZAI,
14 Moonshot,
15 HuggingFace,
16 Minimax,
17}
18
19#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
20pub struct Usage {
21 pub prompt_tokens: u32,
22 pub completion_tokens: u32,
23 pub total_tokens: u32,
24 pub cached_prompt_tokens: Option<u32>,
25 pub cache_creation_tokens: Option<u32>,
26 pub cache_read_tokens: Option<u32>,
27}
28
29impl Usage {
30 #[inline]
31 fn has_cache_read_metric(&self) -> bool {
32 self.cache_read_tokens.is_some() || self.cached_prompt_tokens.is_some()
33 }
34
35 #[inline]
36 fn has_any_cache_metrics(&self) -> bool {
37 self.has_cache_read_metric() || self.cache_creation_tokens.is_some()
38 }
39
40 #[inline]
41 pub fn cache_read_tokens_or_fallback(&self) -> u32 {
42 self.cache_read_tokens
43 .or(self.cached_prompt_tokens)
44 .unwrap_or(0)
45 }
46
47 #[inline]
48 pub fn cache_creation_tokens_or_zero(&self) -> u32 {
49 self.cache_creation_tokens.unwrap_or(0)
50 }
51
52 #[inline]
53 pub fn cache_hit_rate(&self) -> Option<f64> {
54 if !self.has_any_cache_metrics() {
55 return None;
56 }
57 let read = self.cache_read_tokens_or_fallback() as f64;
58 let creation = self.cache_creation_tokens_or_zero() as f64;
59 let total = read + creation;
60 if total > 0.0 {
61 Some((read / total) * 100.0)
62 } else {
63 None
64 }
65 }
66
67 #[inline]
68 pub fn is_cache_hit(&self) -> Option<bool> {
69 self.has_any_cache_metrics()
70 .then(|| self.cache_read_tokens_or_fallback() > 0)
71 }
72
73 #[inline]
74 pub fn is_cache_miss(&self) -> Option<bool> {
75 self.has_any_cache_metrics().then(|| {
76 self.cache_creation_tokens_or_zero() > 0 && self.cache_read_tokens_or_fallback() == 0
77 })
78 }
79
80 #[inline]
81 pub fn total_cache_tokens(&self) -> u32 {
82 let read = self.cache_read_tokens_or_fallback();
83 let creation = self.cache_creation_tokens_or_zero();
84 read + creation
85 }
86
87 #[inline]
88 pub fn cache_savings_ratio(&self) -> Option<f64> {
89 if !self.has_cache_read_metric() {
90 return None;
91 }
92 let read = self.cache_read_tokens_or_fallback() as f64;
93 let prompt = self.prompt_tokens as f64;
94 if prompt > 0.0 {
95 Some(read / prompt)
96 } else {
97 None
98 }
99 }
100}
101
#[cfg(test)]
mod usage_tests {
    use super::Usage;

    /// Builds a `Usage` with fixed token totals (1000 prompt, 200 completion,
    /// 1200 total) and the supplied cache figures.
    fn usage_with_cache(
        cached_prompt: Option<u32>,
        creation: Option<u32>,
        read: Option<u32>,
    ) -> Usage {
        Usage {
            prompt_tokens: 1_000,
            completion_tokens: 200,
            total_tokens: 1_200,
            cached_prompt_tokens: cached_prompt,
            cache_creation_tokens: creation,
            cache_read_tokens: read,
        }
    }

    #[test]
    fn cache_helpers_fall_back_to_cached_prompt_tokens() {
        // No explicit read count: helpers must use `cached_prompt_tokens`.
        let usage = usage_with_cache(Some(600), Some(150), None);

        assert_eq!(usage.cache_read_tokens_or_fallback(), 600);
        assert_eq!(usage.cache_creation_tokens_or_zero(), 150);
        assert_eq!(usage.total_cache_tokens(), 750);
        assert_eq!(usage.is_cache_hit(), Some(true));
        assert_eq!(usage.is_cache_miss(), Some(false));
        assert_eq!(usage.cache_savings_ratio(), Some(0.6));
        assert_eq!(usage.cache_hit_rate(), Some(80.0));
    }

    #[test]
    fn cache_helpers_preserve_unknown_without_metrics() {
        // Nothing reported: boolean and ratio helpers must answer "unknown".
        let usage = usage_with_cache(None, None, None);

        assert_eq!(usage.total_cache_tokens(), 0);
        assert_eq!(usage.is_cache_hit(), None);
        assert_eq!(usage.is_cache_miss(), None);
        assert_eq!(usage.cache_savings_ratio(), None);
        assert_eq!(usage.cache_hit_rate(), None);
    }
}
144
145#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
146pub enum FinishReason {
147 #[default]
148 Stop,
149 Length,
150 ToolCalls,
151 ContentFilter,
152 Pause,
153 Refusal,
154 Error(String),
155}
156
157#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
159pub struct ToolCall {
160 pub id: String,
162
163 #[serde(rename = "type")]
165 pub call_type: String,
166
167 #[serde(skip_serializing_if = "Option::is_none")]
169 pub function: Option<FunctionCall>,
170
171 #[serde(skip_serializing_if = "Option::is_none")]
173 pub text: Option<String>,
174
175 #[serde(skip_serializing_if = "Option::is_none")]
177 pub thought_signature: Option<String>,
178}
179
180#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
182pub struct FunctionCall {
183 #[serde(default, skip_serializing_if = "Option::is_none")]
185 pub namespace: Option<String>,
186
187 pub name: String,
189
190 pub arguments: String,
192}
193
194impl ToolCall {
195 pub fn function(id: String, name: String, arguments: String) -> Self {
197 Self::function_with_namespace(id, None, name, arguments)
198 }
199
200 pub fn function_with_namespace(
202 id: String,
203 namespace: Option<String>,
204 name: String,
205 arguments: String,
206 ) -> Self {
207 Self {
208 id,
209 call_type: "function".to_owned(),
210 function: Some(FunctionCall {
211 namespace,
212 name,
213 arguments,
214 }),
215 text: None,
216 thought_signature: None,
217 }
218 }
219
220 pub fn custom(id: String, name: String, text: String) -> Self {
222 Self {
223 id,
224 call_type: "custom".to_owned(),
225 function: Some(FunctionCall {
226 namespace: None,
227 name,
228 arguments: text.clone(),
229 }),
230 text: Some(text),
231 thought_signature: None,
232 }
233 }
234
235 pub fn parsed_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
237 if let Some(ref func) = self.function {
238 parse_tool_arguments(&func.arguments)
239 } else {
240 serde_json::from_str("")
242 }
243 }
244
245 pub fn validate(&self) -> Result<(), String> {
247 if self.id.is_empty() {
248 return Err("Tool call ID cannot be empty".to_owned());
249 }
250
251 match self.call_type.as_str() {
252 "function" => {
253 if let Some(func) = &self.function {
254 if func.name.is_empty() {
255 return Err("Function name cannot be empty".to_owned());
256 }
257 if let Err(e) = self.parsed_arguments() {
259 return Err(format!("Invalid JSON in function arguments: {}", e));
260 }
261 } else {
262 return Err("Function tool call missing function details".to_owned());
263 }
264 }
265 "custom" => {
266 if let Some(func) = &self.function {
268 if func.name.is_empty() {
269 return Err("Custom tool name cannot be empty".to_owned());
270 }
271 } else {
272 return Err("Custom tool call missing function details".to_owned());
273 }
274 }
275 _ => return Err(format!("Unsupported tool call type: {}", self.call_type)),
276 }
277
278 Ok(())
279 }
280}
281
282fn parse_tool_arguments(raw_arguments: &str) -> Result<serde_json::Value, serde_json::Error> {
283 let trimmed = raw_arguments.trim();
284 match serde_json::from_str(trimmed) {
285 Ok(parsed) => Ok(parsed),
286 Err(primary_error) => {
287 if let Some(candidate) = extract_balanced_json(trimmed)
288 && let Ok(parsed) = serde_json::from_str(candidate)
289 {
290 return Ok(parsed);
291 }
292 Err(primary_error)
293 }
294 }
295}
296
/// Returns the first balanced `{...}` or `[...]` span in `input`, if any.
///
/// The scan starts at the first `{` or `[` and tracks nesting only for that
/// bracket kind. Brackets inside double-quoted strings are ignored, and a
/// backslash inside a string escapes the character that follows it. Returns
/// `None` when there is no opening bracket or the span never closes.
fn extract_balanced_json(input: &str) -> Option<&str> {
    let start = input.find(['{', '['])?;
    let bytes = input.as_bytes();
    let (opening, closing) = match bytes.get(start).copied()? {
        b'{' => (b'{', b'}'),
        b'[' => (b'[', b']'),
        _ => return None,
    };

    let mut depth = 0usize;
    let mut in_string = false;
    let mut escaped = false;

    // Every byte we care about is ASCII, and ASCII bytes never occur inside a
    // multi-byte UTF-8 sequence, so scanning bytes is equivalent to scanning
    // chars here.
    for (index, &byte) in bytes.iter().enumerate().skip(start) {
        if in_string {
            if escaped {
                escaped = false;
            } else if byte == b'\\' {
                escaped = true;
            } else if byte == b'"' {
                in_string = false;
            }
            continue;
        }

        if byte == b'"' {
            in_string = true;
        } else if byte == opening {
            depth += 1;
        } else if byte == closing {
            depth = depth.saturating_sub(1);
            if depth == 0 {
                return input.get(start..=index);
            }
        }
    }

    None
}
342
343#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
345pub struct LLMResponse {
346 pub content: Option<String>,
348
349 pub tool_calls: Option<Vec<ToolCall>>,
351
352 pub model: String,
354
355 pub usage: Option<Usage>,
357
358 pub finish_reason: FinishReason,
360
361 pub reasoning: Option<String>,
363
364 pub reasoning_details: Option<Vec<String>>,
366
367 pub tool_references: Vec<String>,
369
370 pub request_id: Option<String>,
372
373 pub organization_id: Option<String>,
375}
376
377impl LLMResponse {
378 pub fn new(model: impl Into<String>, content: impl Into<String>) -> Self {
380 Self {
381 content: Some(content.into()),
382 tool_calls: None,
383 model: model.into(),
384 usage: None,
385 finish_reason: FinishReason::Stop,
386 reasoning: None,
387 reasoning_details: None,
388 tool_references: Vec::new(),
389 request_id: None,
390 organization_id: None,
391 }
392 }
393
394 pub fn content_text(&self) -> &str {
396 self.content.as_deref().unwrap_or("")
397 }
398
399 pub fn content_string(&self) -> String {
401 self.content.clone().unwrap_or_default()
402 }
403}
404
405#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
406pub struct LLMErrorMetadata {
407 pub provider: Option<String>,
408 pub status: Option<u16>,
409 pub code: Option<String>,
410 pub request_id: Option<String>,
411 pub organization_id: Option<String>,
412 pub retry_after: Option<String>,
413 pub message: Option<String>,
414}
415
416impl LLMErrorMetadata {
417 pub fn new(
418 provider: impl Into<String>,
419 status: Option<u16>,
420 code: Option<String>,
421 request_id: Option<String>,
422 organization_id: Option<String>,
423 retry_after: Option<String>,
424 message: Option<String>,
425 ) -> Box<Self> {
426 Box::new(Self {
427 provider: Some(provider.into()),
428 status,
429 code,
430 request_id,
431 organization_id,
432 retry_after,
433 message,
434 })
435 }
436}
437
438#[derive(Debug, thiserror::Error, Serialize, Deserialize, Clone)]
440#[serde(tag = "type", rename_all = "snake_case")]
441pub enum LLMError {
442 #[error("Authentication failed: {message}")]
443 Authentication {
444 message: String,
445 metadata: Option<Box<LLMErrorMetadata>>,
446 },
447 #[error("Rate limit exceeded")]
448 RateLimit {
449 metadata: Option<Box<LLMErrorMetadata>>,
450 },
451 #[error("Invalid request: {message}")]
452 InvalidRequest {
453 message: String,
454 metadata: Option<Box<LLMErrorMetadata>>,
455 },
456 #[error("Network error: {message}")]
457 Network {
458 message: String,
459 metadata: Option<Box<LLMErrorMetadata>>,
460 },
461 #[error("Provider error: {message}")]
462 Provider {
463 message: String,
464 metadata: Option<Box<LLMErrorMetadata>>,
465 },
466}
467
#[cfg(test)]
mod tests {
    use super::ToolCall;
    use serde_json::json;

    /// Builds the `read_file` function call shared by the parsing tests.
    fn read_file_call(arguments: &str) -> ToolCall {
        ToolCall::function(
            "call_read".to_string(),
            "read_file".to_string(),
            arguments.to_string(),
        )
    }

    #[test]
    fn parsed_arguments_accepts_trailing_characters() {
        let parsed = read_file_call(r#"{"path":"src/main.rs"} trailing text"#)
            .parsed_arguments()
            .expect("arguments with trailing text should recover");

        assert_eq!(parsed, json!({"path":"src/main.rs"}));
    }

    #[test]
    fn parsed_arguments_accepts_code_fenced_json() {
        let parsed = read_file_call("```json\n{\"path\":\"src/lib.rs\",\"limit\":25}\n```")
            .parsed_arguments()
            .expect("code-fenced arguments should recover");

        assert_eq!(parsed, json!({"path":"src/lib.rs","limit":25}));
    }

    #[test]
    fn parsed_arguments_rejects_incomplete_json() {
        // The object is never closed, so neither the direct parse nor the
        // balanced-span fallback can succeed.
        let outcome = read_file_call(r#"{"path":"src/main.rs""#).parsed_arguments();

        assert!(outcome.is_err());
    }

    #[test]
    fn function_call_serializes_optional_namespace() {
        let call = ToolCall::function_with_namespace(
            "call_read".to_string(),
            Some("workspace".to_string()),
            "read_file".to_string(),
            r#"{"path":"src/main.rs"}"#.to_string(),
        );

        let json = serde_json::to_value(&call).expect("tool call should serialize");
        let function = &json["function"];
        assert_eq!(function["namespace"], "workspace");
        assert_eq!(function["name"], "read_file");
    }
}