1use chat::Tool;
17use serde::{Deserialize, Serialize};
18
// Crate submodules. Within this file, `chat`, `completion`, `embedding`, and
// `models` supply the supertraits that `LLMProvider` combines, `chat` also
// supplies `Tool`, and `error` supplies `LLMError` (used by the test mocks).
// NOTE(review): declaration order is left untouched in case any module relies
// on textual ordering (e.g. `macro_rules!` scoping) — confirm before reordering.
pub mod backends;

pub mod builder;

pub mod chat;

pub mod completion;

pub mod embedding;

pub mod error;

pub mod evaluator;

pub mod secret_store;

pub mod models;
45
46pub trait LLMProvider:
49 chat::ChatProvider
50 + completion::CompletionProvider
51 + embedding::EmbeddingProvider
52 + models::ModelsProvider
53 + Send
54 + Sync
55 + 'static
56{
57 fn tools(&self) -> Option<&[Tool]> {
58 None
59 }
60}
61
62#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
65pub struct ToolCall {
66 pub id: String,
68 #[serde(rename = "type")]
70 pub call_type: String,
71 pub function: FunctionCall,
73}
74
75#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
77pub struct FunctionCall {
78 pub name: String,
80 pub arguments: String,
82}
83
// Unit tests for the `ToolCall` / `FunctionCall` data types and for the
// `LLMProvider` umbrella trait, exercised through a minimal mock provider.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::ChatProvider;
    use crate::completion::CompletionProvider;
    use crate::embedding::EmbeddingProvider;
    use async_trait::async_trait;
    use serde_json::json;

    #[test]
    fn test_tool_call_creation() {
        let tool_call = ToolCall {
            id: "call_123".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "test_function".to_string(),
                arguments: "{\"param\": \"value\"}".to_string(),
            },
        };

        assert_eq!(tool_call.id, "call_123");
        assert_eq!(tool_call.call_type, "function");
        assert_eq!(tool_call.function.name, "test_function");
        assert_eq!(tool_call.function.arguments, "{\"param\": \"value\"}");
    }

    #[test]
    fn test_tool_call_serialization() {
        let tool_call = ToolCall {
            id: "call_456".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "serialize_test".to_string(),
                arguments: "{\"test\": true}".to_string(),
            },
        };

        // The serde rename must put `call_type` on the wire as `"type"`.
        let value = serde_json::to_value(&tool_call).unwrap();
        assert_eq!(value["type"], "function");
        assert!(value.get("call_type").is_none());

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.id, "call_456");
        assert_eq!(deserialized.call_type, "function");
        assert_eq!(deserialized.function.name, "serialize_test");
        assert_eq!(deserialized.function.arguments, "{\"test\": true}");
        assert_eq!(deserialized, tool_call);
    }

    #[test]
    fn test_tool_call_deserializes_wire_type_key() {
        // Provider payloads carry the call kind under `"type"`, not `"call_type"`.
        let wire = json!({
            "id": "call_789",
            "type": "function",
            "function": {
                "name": "wire_function",
                "arguments": "{\"x\": 1}"
            }
        });

        let tool_call: ToolCall = serde_json::from_value(wire).unwrap();

        assert_eq!(tool_call.id, "call_789");
        assert_eq!(tool_call.call_type, "function");
        assert_eq!(tool_call.function.name, "wire_function");
        assert_eq!(tool_call.function.arguments, "{\"x\": 1}");
    }

    #[test]
    fn test_tool_call_equality() {
        let tool_call1 = ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "equal_test".to_string(),
                arguments: "{}".to_string(),
            },
        };

        let tool_call2 = ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "equal_test".to_string(),
                arguments: "{}".to_string(),
            },
        };

        let tool_call3 = ToolCall {
            id: "call_2".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "equal_test".to_string(),
                arguments: "{}".to_string(),
            },
        };

        assert_eq!(tool_call1, tool_call2);
        assert_ne!(tool_call1, tool_call3);
    }

    #[test]
    fn test_tool_call_clone() {
        let tool_call = ToolCall {
            id: "clone_test".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "test_clone".to_string(),
                arguments: "{\"clone\": true}".to_string(),
            },
        };

        let cloned = tool_call.clone();
        assert_eq!(tool_call, cloned);
        assert_eq!(tool_call.id, cloned.id);
        assert_eq!(tool_call.function.name, cloned.function.name);
    }

    #[test]
    fn test_tool_call_debug() {
        let tool_call = ToolCall {
            id: "debug_test".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "debug_function".to_string(),
                arguments: "{}".to_string(),
            },
        };

        let debug_str = format!("{tool_call:?}");
        assert!(debug_str.contains("ToolCall"));
        assert!(debug_str.contains("debug_test"));
        assert!(debug_str.contains("debug_function"));
    }

    #[test]
    fn test_function_call_creation() {
        let function_call = FunctionCall {
            name: "test_function".to_string(),
            arguments: "{\"param1\": \"value1\", \"param2\": 42}".to_string(),
        };

        assert_eq!(function_call.name, "test_function");
        assert_eq!(
            function_call.arguments,
            "{\"param1\": \"value1\", \"param2\": 42}"
        );
    }

    #[test]
    fn test_function_call_serialization() {
        let function_call = FunctionCall {
            name: "serialize_function".to_string(),
            arguments: "{\"data\": [1, 2, 3]}".to_string(),
        };

        let serialized = serde_json::to_string(&function_call).unwrap();
        let deserialized: FunctionCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.name, "serialize_function");
        assert_eq!(deserialized.arguments, "{\"data\": [1, 2, 3]}");
    }

    #[test]
    fn test_function_call_equality() {
        let func1 = FunctionCall {
            name: "equal_func".to_string(),
            arguments: "{}".to_string(),
        };

        let func2 = FunctionCall {
            name: "equal_func".to_string(),
            arguments: "{}".to_string(),
        };

        let func3 = FunctionCall {
            name: "different_func".to_string(),
            arguments: "{}".to_string(),
        };

        assert_eq!(func1, func2);
        assert_ne!(func1, func3);
    }

    #[test]
    fn test_function_call_clone() {
        let function_call = FunctionCall {
            name: "clone_func".to_string(),
            arguments: "{\"clone\": \"test\"}".to_string(),
        };

        let cloned = function_call.clone();
        assert_eq!(function_call, cloned);
        assert_eq!(function_call.name, cloned.name);
        assert_eq!(function_call.arguments, cloned.arguments);
    }

    #[test]
    fn test_function_call_debug() {
        let function_call = FunctionCall {
            name: "debug_func".to_string(),
            arguments: "{}".to_string(),
        };

        let debug_str = format!("{function_call:?}");
        assert!(debug_str.contains("FunctionCall"));
        assert!(debug_str.contains("debug_func"));
    }

    #[test]
    fn test_tool_call_with_empty_values() {
        let tool_call = ToolCall {
            id: String::new(),
            call_type: String::new(),
            function: FunctionCall {
                name: String::new(),
                arguments: String::new(),
            },
        };

        assert!(tool_call.id.is_empty());
        assert!(tool_call.call_type.is_empty());
        assert!(tool_call.function.name.is_empty());
        assert!(tool_call.function.arguments.is_empty());
    }

    #[test]
    fn test_tool_call_with_complex_arguments() {
        let complex_args = json!({
            "nested": {
                "array": [1, 2, 3],
                "object": {
                    "key": "value"
                }
            },
            "simple": "string"
        });

        let tool_call = ToolCall {
            id: "complex_call".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "complex_function".to_string(),
                arguments: complex_args.to_string(),
            },
        };

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.id, "complex_call");
        assert_eq!(deserialized.function.name, "complex_function");
        assert!(deserialized.function.arguments.contains("nested"));
        assert!(deserialized.function.arguments.contains("array"));

        // The argument string must still parse back to the exact same JSON value.
        let reparsed: serde_json::Value =
            serde_json::from_str(&deserialized.function.arguments).unwrap();
        assert_eq!(reparsed, complex_args);
    }

    #[test]
    fn test_tool_call_with_unicode() {
        // Mixes multi-byte CJK characters with a 4-byte (astral-plane) emoji.
        let tool_call = ToolCall {
            id: "unicode_call".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "unicode_function".to_string(),
                arguments: "{\"message\": \"Hello 世界! 🌍\"}".to_string(),
            },
        };

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.id, "unicode_call");
        assert_eq!(deserialized.function.name, "unicode_function");
        assert!(deserialized.function.arguments.contains("Hello 世界! 🌍"));
        assert_eq!(deserialized, tool_call);
    }

    #[test]
    fn test_tool_call_large_arguments() {
        let large_arg = "x".repeat(10000);
        let tool_call = ToolCall {
            id: "large_call".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "large_function".to_string(),
                arguments: format!("{{\"large_param\": \"{large_arg}\"}}"),
            },
        };

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.id, "large_call");
        assert_eq!(deserialized.function.name, "large_function");
        assert!(deserialized.function.arguments.len() > 10000);
        // A length check alone would miss truncation or corruption; compare exactly.
        assert_eq!(deserialized.function.arguments, tool_call.function.arguments);
    }

    /// Minimal provider that returns canned results for every capability.
    struct MockLLMProvider;

    #[async_trait]
    impl chat::ChatProvider for MockLLMProvider {
        async fn chat_with_tools(
            &self,
            _messages: &[chat::ChatMessage],
            _tools: Option<&[chat::Tool]>,
            _json_schema: Option<chat::StructuredOutputFormat>,
        ) -> Result<Box<dyn chat::ChatResponse>, error::LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".to_string()),
            }))
        }
    }

    #[async_trait]
    impl completion::CompletionProvider for MockLLMProvider {
        async fn complete(
            &self,
            _req: &completion::CompletionRequest,
            _json_schema: Option<chat::StructuredOutputFormat>,
        ) -> Result<completion::CompletionResponse, error::LLMError> {
            Ok(completion::CompletionResponse {
                text: "Mock completion".to_string(),
            })
        }
    }

    #[async_trait]
    impl embedding::EmbeddingProvider for MockLLMProvider {
        async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, error::LLMError> {
            // One deterministic 2-d vector per input, `[i, i + 1]`, so tests can
            // check both the count and the ordering of the results.
            Ok((0..input.len())
                .map(|i| vec![i as f32, (i + 1) as f32])
                .collect())
        }
    }

    #[async_trait]
    impl models::ModelsProvider for MockLLMProvider {}

    // No `tools()` override: `test_llm_provider_tools` exercises the trait default.
    impl LLMProvider for MockLLMProvider {}

    struct MockChatResponse {
        text: Option<String>,
    }

    impl chat::ChatResponse for MockChatResponse {
        fn text(&self) -> Option<String> {
            self.text.clone()
        }

        fn tool_calls(&self) -> Option<Vec<ToolCall>> {
            None
        }
    }

    impl std::fmt::Debug for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "MockChatResponse")
        }
    }

    impl std::fmt::Display for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.text.as_deref().unwrap_or(""))
        }
    }

    #[tokio::test]
    async fn test_llm_provider_trait_chat() {
        let provider = MockLLMProvider;
        let messages = vec![chat::ChatMessage::user().content("Test").build()];

        let response = provider.chat(&messages, None).await.unwrap();
        assert_eq!(response.text(), Some("Mock response".to_string()));
    }

    #[tokio::test]
    async fn test_llm_provider_trait_completion() {
        let provider = MockLLMProvider;
        let request = completion::CompletionRequest::new("Test prompt");

        let response = provider.complete(&request, None).await.unwrap();
        assert_eq!(response.text, "Mock completion");
    }

    #[tokio::test]
    async fn test_llm_provider_trait_embedding() {
        let provider = MockLLMProvider;
        let input = vec!["First".to_string(), "Second".to_string()];

        let embeddings = provider.embed(input).await.unwrap();
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0], vec![0.0, 1.0]);
        assert_eq!(embeddings[1], vec![1.0, 2.0]);
    }

    #[test]
    fn test_llm_provider_tools() {
        let provider = MockLLMProvider;
        assert!(provider.tools().is_none());
    }
}