1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
6pub enum BackendKind {
7 Gemini,
8 OpenAI,
9 Anthropic,
10 DeepSeek,
11 OpenRouter,
12 Ollama,
13 ZAI,
14 Moonshot,
15 HuggingFace,
16 Minimax,
17}
18
19#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
20pub struct Usage {
21 pub prompt_tokens: u32,
22 pub completion_tokens: u32,
23 pub total_tokens: u32,
24 pub cached_prompt_tokens: Option<u32>,
25 pub cache_creation_tokens: Option<u32>,
26 pub cache_read_tokens: Option<u32>,
27}
28
29impl Usage {
30 #[inline]
31 pub fn cache_hit_rate(&self) -> Option<f64> {
32 let read = self.cache_read_tokens? as f64;
33 let creation = self.cache_creation_tokens? as f64;
34 let total = read + creation;
35 if total > 0.0 {
36 Some((read / total) * 100.0)
37 } else {
38 None
39 }
40 }
41
42 #[inline]
43 pub fn is_cache_hit(&self) -> Option<bool> {
44 Some(self.cache_read_tokens? > 0)
45 }
46
47 #[inline]
48 pub fn is_cache_miss(&self) -> Option<bool> {
49 Some(self.cache_creation_tokens? > 0 && self.cache_read_tokens? == 0)
50 }
51
52 #[inline]
53 pub fn total_cache_tokens(&self) -> u32 {
54 let read = self.cache_read_tokens.unwrap_or(0);
55 let creation = self.cache_creation_tokens.unwrap_or(0);
56 read + creation
57 }
58
59 #[inline]
60 pub fn cache_savings_ratio(&self) -> Option<f64> {
61 let read = self.cache_read_tokens? as f64;
62 let prompt = self.prompt_tokens as f64;
63 if prompt > 0.0 {
64 Some(read / prompt)
65 } else {
66 None
67 }
68 }
69}
70
71#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
72pub enum FinishReason {
73 #[default]
74 Stop,
75 Length,
76 ToolCalls,
77 ContentFilter,
78 Pause,
79 Refusal,
80 Error(String),
81}
82
83#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
85pub struct ToolCall {
86 pub id: String,
88
89 #[serde(rename = "type")]
91 pub call_type: String,
92
93 #[serde(skip_serializing_if = "Option::is_none")]
95 pub function: Option<FunctionCall>,
96
97 #[serde(skip_serializing_if = "Option::is_none")]
99 pub text: Option<String>,
100
101 #[serde(skip_serializing_if = "Option::is_none")]
103 pub thought_signature: Option<String>,
104}
105
106#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
108pub struct FunctionCall {
109 pub name: String,
111
112 pub arguments: String,
114}
115
116impl ToolCall {
117 pub fn function(id: String, name: String, arguments: String) -> Self {
119 Self {
120 id,
121 call_type: "function".to_owned(),
122 function: Some(FunctionCall { name, arguments }),
123 text: None,
124 thought_signature: None,
125 }
126 }
127
128 pub fn custom(id: String, name: String, text: String) -> Self {
130 Self {
131 id,
132 call_type: "custom".to_owned(),
133 function: Some(FunctionCall {
134 name,
135 arguments: text.clone(),
136 }),
137 text: Some(text),
138 thought_signature: None,
139 }
140 }
141
142 pub fn parsed_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
144 if let Some(ref func) = self.function {
145 parse_tool_arguments(&func.arguments)
146 } else {
147 serde_json::from_str("")
149 }
150 }
151
152 pub fn validate(&self) -> Result<(), String> {
154 if self.id.is_empty() {
155 return Err("Tool call ID cannot be empty".to_owned());
156 }
157
158 match self.call_type.as_str() {
159 "function" => {
160 if let Some(func) = &self.function {
161 if func.name.is_empty() {
162 return Err("Function name cannot be empty".to_owned());
163 }
164 if let Err(e) = self.parsed_arguments() {
166 return Err(format!("Invalid JSON in function arguments: {}", e));
167 }
168 } else {
169 return Err("Function tool call missing function details".to_owned());
170 }
171 }
172 "custom" => {
173 if let Some(func) = &self.function {
175 if func.name.is_empty() {
176 return Err("Custom tool name cannot be empty".to_owned());
177 }
178 } else {
179 return Err("Custom tool call missing function details".to_owned());
180 }
181 }
182 _ => return Err(format!("Unsupported tool call type: {}", self.call_type)),
183 }
184
185 Ok(())
186 }
187}
188
189fn parse_tool_arguments(raw_arguments: &str) -> Result<serde_json::Value, serde_json::Error> {
190 let trimmed = raw_arguments.trim();
191 match serde_json::from_str(trimmed) {
192 Ok(parsed) => Ok(parsed),
193 Err(primary_error) => {
194 if let Some(candidate) = extract_balanced_json(trimmed)
195 && let Ok(parsed) = serde_json::from_str(candidate)
196 {
197 return Ok(parsed);
198 }
199 Err(primary_error)
200 }
201 }
202}
203
/// Returns the first balanced `{...}` or `[...]` region of `input`, if any.
///
/// Scanning starts at the first `{` or `[`. Only that bracket kind is counted
/// for nesting, and brackets inside double-quoted strings (with backslash
/// escapes) are ignored. Returns `None` when there is no opening bracket or the
/// region never closes. The result is not checked for being valid JSON.
fn extract_balanced_json(input: &str) -> Option<&str> {
    // Every delimiter of interest is ASCII, and ASCII bytes never occur inside
    // a multi-byte UTF-8 sequence, so scanning bytes is safe.
    let (start, opening) = input
        .bytes()
        .enumerate()
        .find(|&(_, byte)| byte == b'{' || byte == b'[')?;
    let closing = if opening == b'{' { b'}' } else { b']' };

    let mut depth = 0usize;
    let mut in_string = false;
    let mut escaped = false;

    for (index, byte) in input.bytes().enumerate().skip(start) {
        if in_string {
            if escaped {
                // The byte following a backslash is taken literally.
                escaped = false;
            } else if byte == b'\\' {
                escaped = true;
            } else if byte == b'"' {
                in_string = false;
            }
            continue;
        }

        if byte == b'"' {
            in_string = true;
        } else if byte == opening {
            depth += 1;
        } else if byte == closing {
            depth = depth.saturating_sub(1);
            if depth == 0 {
                // `index` is an ASCII byte, so the inclusive end is a char boundary.
                return input.get(start..=index);
            }
        }
    }

    None
}
249
250#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
252pub struct LLMResponse {
253 pub content: Option<String>,
255
256 pub tool_calls: Option<Vec<ToolCall>>,
258
259 pub model: String,
261
262 pub usage: Option<Usage>,
264
265 pub finish_reason: FinishReason,
267
268 pub reasoning: Option<String>,
270
271 pub reasoning_details: Option<Vec<String>>,
273
274 pub tool_references: Vec<String>,
276
277 pub request_id: Option<String>,
279
280 pub organization_id: Option<String>,
282}
283
284impl LLMResponse {
285 pub fn new(model: impl Into<String>, content: impl Into<String>) -> Self {
287 Self {
288 content: Some(content.into()),
289 tool_calls: None,
290 model: model.into(),
291 usage: None,
292 finish_reason: FinishReason::Stop,
293 reasoning: None,
294 reasoning_details: None,
295 tool_references: Vec::new(),
296 request_id: None,
297 organization_id: None,
298 }
299 }
300
301 pub fn content_text(&self) -> &str {
303 self.content.as_deref().unwrap_or("")
304 }
305
306 pub fn content_string(&self) -> String {
308 self.content.clone().unwrap_or_default()
309 }
310}
311
312#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
313pub struct LLMErrorMetadata {
314 pub provider: Option<String>,
315 pub status: Option<u16>,
316 pub code: Option<String>,
317 pub request_id: Option<String>,
318 pub organization_id: Option<String>,
319 pub retry_after: Option<String>,
320 pub message: Option<String>,
321}
322
323impl LLMErrorMetadata {
324 pub fn new(
325 provider: impl Into<String>,
326 status: Option<u16>,
327 code: Option<String>,
328 request_id: Option<String>,
329 organization_id: Option<String>,
330 retry_after: Option<String>,
331 message: Option<String>,
332 ) -> Box<Self> {
333 Box::new(Self {
334 provider: Some(provider.into()),
335 status,
336 code,
337 request_id,
338 organization_id,
339 retry_after,
340 message,
341 })
342 }
343}
344
345#[derive(Debug, thiserror::Error, Serialize, Deserialize, Clone)]
347#[serde(tag = "type", rename_all = "snake_case")]
348pub enum LLMError {
349 #[error("Authentication failed: {message}")]
350 Authentication {
351 message: String,
352 metadata: Option<Box<LLMErrorMetadata>>,
353 },
354 #[error("Rate limit exceeded")]
355 RateLimit {
356 metadata: Option<Box<LLMErrorMetadata>>,
357 },
358 #[error("Invalid request: {message}")]
359 InvalidRequest {
360 message: String,
361 metadata: Option<Box<LLMErrorMetadata>>,
362 },
363 #[error("Network error: {message}")]
364 Network {
365 message: String,
366 metadata: Option<Box<LLMErrorMetadata>>,
367 },
368 #[error("Provider error: {message}")]
369 Provider {
370 message: String,
371 metadata: Option<Box<LLMErrorMetadata>>,
372 },
373}
374
#[cfg(test)]
mod tests {
    use super::ToolCall;
    use serde_json::json;

    /// Builds a `read_file` function call whose raw argument payload is `arguments`.
    fn read_file_call(arguments: &str) -> ToolCall {
        ToolCall::function(
            "call_read".to_string(),
            "read_file".to_string(),
            arguments.to_string(),
        )
    }

    // Text after a complete JSON object is ignored.
    #[test]
    fn parsed_arguments_accepts_trailing_characters() {
        let parsed = read_file_call(r#"{"path":"src/main.rs"} trailing text"#)
            .parsed_arguments()
            .expect("arguments with trailing text should recover");

        assert_eq!(parsed, json!({"path":"src/main.rs"}));
    }

    // A Markdown code fence around the JSON is tolerated.
    #[test]
    fn parsed_arguments_accepts_code_fenced_json() {
        let parsed = read_file_call("```json\n{\"path\":\"src/lib.rs\",\"limit\":25}\n```")
            .parsed_arguments()
            .expect("code-fenced arguments should recover");

        assert_eq!(parsed, json!({"path":"src/lib.rs","limit":25}));
    }

    // An object that never closes cannot be recovered.
    #[test]
    fn parsed_arguments_rejects_incomplete_json() {
        let outcome = read_file_call(r#"{"path":"src/main.rs""#).parsed_arguments();

        assert!(outcome.is_err());
    }
}