1use std::fmt::Display;
17
18use serde::{Deserialize, Serialize};
19
20pub mod backends;
22
23pub mod builder;
25
26pub mod chat;
28
29pub mod completion;
31
32pub mod embedding;
34
35pub mod error;
37
38pub mod evaluator;
40
41#[cfg(not(target_arch = "wasm32"))]
43pub mod secret_store;
44
45pub mod models;
47
48pub mod providers;
49
50pub use async_trait::async_trait;
52
53pub trait LLMProvider:
56 chat::ChatProvider
57 + completion::CompletionProvider
58 + embedding::EmbeddingProvider
59 + models::ModelsProvider
60 + Send
61 + Sync
62 + 'static
63{
64}
65
66#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
69pub struct ToolCall {
70 pub id: String,
72 #[serde(rename = "type")]
74 pub call_type: String,
75 pub function: FunctionCall,
77}
78
79impl Display for ToolCall {
80 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81 write!(
82 f,
83 "ToolCall {{ id: {}, type: {}, function: {:?} }}",
84 self.id, self.call_type, self.function
85 )
86 }
87}
88
/// The function part of a [`ToolCall`]: which function to invoke, and the
/// arguments to pass to it.
#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
pub struct FunctionCall {
    /// Name of the function to call.
    pub name: String,
    /// Arguments as an unparsed string. This type neither validates nor parses
    /// it; callers must do that themselves.
    /// NOTE(review): presumably JSON text (the tests in this file use JSON),
    /// but that is not enforced here — confirm against providers.
    pub arguments: String,
}
97
/// Returns the default tool-call type, the string `"function"`.
///
/// It has the `fn() -> String` shape that `#[serde(default = "...")]` expects,
/// so it can supply a value for a missing `call_type` during deserialization.
pub fn default_call_type() -> String {
    String::from("function")
}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::{ChatMessage, ChatProvider, ChatResponse, StructuredOutputFormat, Tool};
    use crate::completion::CompletionProvider;
    use crate::embedding::EmbeddingProvider;
    use crate::error::LLMError;
    use async_trait::async_trait;
    use serde_json::json;

    /// Builds a `ToolCall` whose type is `"function"`, so each test only spells
    /// out the fields it actually exercises.
    fn function_tool_call(id: &str, name: &str, arguments: &str) -> ToolCall {
        ToolCall {
            id: id.to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: name.to_string(),
                arguments: arguments.to_string(),
            },
        }
    }

    #[test]
    fn test_tool_call_creation() {
        let tool_call = ToolCall {
            id: "call_123".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "test_function".to_string(),
                arguments: "{\"param\": \"value\"}".to_string(),
            },
        };

        assert_eq!(tool_call.id, "call_123");
        assert_eq!(tool_call.call_type, "function");
        assert_eq!(tool_call.function.name, "test_function");
        assert_eq!(tool_call.function.arguments, "{\"param\": \"value\"}");
    }

    #[test]
    fn test_tool_call_serialization() {
        let tool_call = function_tool_call("call_456", "serialize_test", "{\"test\": true}");

        let serialized = serde_json::to_string(&tool_call).unwrap();

        // A plain round trip would pass even without `#[serde(rename = "type")]`,
        // so check the serialized key explicitly.
        let wire: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        assert_eq!(wire["type"], "function");
        assert!(wire.get("call_type").is_none());
        assert_eq!(wire["function"]["name"], "serialize_test");

        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized, tool_call);
    }

    #[test]
    fn test_tool_call_equality() {
        let tool_call1 = function_tool_call("call_1", "equal_test", "{}");
        let tool_call2 = function_tool_call("call_1", "equal_test", "{}");
        let tool_call3 = function_tool_call("call_2", "equal_test", "{}");

        assert_eq!(tool_call1, tool_call2);
        assert_ne!(tool_call1, tool_call3);
    }

    #[test]
    fn test_tool_call_clone() {
        let tool_call = function_tool_call("clone_test", "test_clone", "{\"clone\": true}");

        let cloned = tool_call.clone();
        assert_eq!(tool_call, cloned);
        assert_eq!(tool_call.id, cloned.id);
        assert_eq!(tool_call.function.name, cloned.function.name);
    }

    #[test]
    fn test_tool_call_debug() {
        let tool_call = function_tool_call("debug_test", "debug_function", "{}");

        let debug_str = format!("{tool_call:?}");
        assert!(debug_str.contains("ToolCall"));
        assert!(debug_str.contains("debug_test"));
        assert!(debug_str.contains("debug_function"));
    }

    #[test]
    fn test_tool_call_display() {
        let tool_call = function_tool_call("display_test", "display_function", "{}");

        assert_eq!(
            tool_call.to_string(),
            r#"ToolCall { id: display_test, type: function, function: FunctionCall { name: "display_function", arguments: "{}" } }"#
        );
    }

    #[test]
    fn test_default_call_type() {
        assert_eq!(default_call_type(), "function");
    }

    #[test]
    fn test_function_call_creation() {
        let function_call = FunctionCall {
            name: "test_function".to_string(),
            arguments: "{\"param1\": \"value1\", \"param2\": 42}".to_string(),
        };

        assert_eq!(function_call.name, "test_function");
        assert_eq!(
            function_call.arguments,
            "{\"param1\": \"value1\", \"param2\": 42}"
        );
    }

    #[test]
    fn test_function_call_serialization() {
        let function_call = FunctionCall {
            name: "serialize_function".to_string(),
            arguments: "{\"data\": [1, 2, 3]}".to_string(),
        };

        let serialized = serde_json::to_string(&function_call).unwrap();
        let deserialized: FunctionCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.name, "serialize_function");
        assert_eq!(deserialized.arguments, "{\"data\": [1, 2, 3]}");
    }

    #[test]
    fn test_function_call_equality() {
        let func1 = FunctionCall {
            name: "equal_func".to_string(),
            arguments: "{}".to_string(),
        };

        let func2 = FunctionCall {
            name: "equal_func".to_string(),
            arguments: "{}".to_string(),
        };

        let func3 = FunctionCall {
            name: "different_func".to_string(),
            arguments: "{}".to_string(),
        };

        assert_eq!(func1, func2);
        assert_ne!(func1, func3);
    }

    #[test]
    fn test_function_call_clone() {
        let function_call = FunctionCall {
            name: "clone_func".to_string(),
            arguments: "{\"clone\": \"test\"}".to_string(),
        };

        let cloned = function_call.clone();
        assert_eq!(function_call, cloned);
        assert_eq!(function_call.name, cloned.name);
        assert_eq!(function_call.arguments, cloned.arguments);
    }

    #[test]
    fn test_function_call_debug() {
        let function_call = FunctionCall {
            name: "debug_func".to_string(),
            arguments: "{}".to_string(),
        };

        let debug_str = format!("{function_call:?}");
        assert!(debug_str.contains("FunctionCall"));
        assert!(debug_str.contains("debug_func"));
    }

    #[test]
    fn test_tool_call_with_empty_values() {
        let tool_call = ToolCall {
            id: String::new(),
            call_type: String::new(),
            function: FunctionCall {
                name: String::new(),
                arguments: String::new(),
            },
        };

        assert!(tool_call.id.is_empty());
        assert!(tool_call.call_type.is_empty());
        assert!(tool_call.function.name.is_empty());
        assert!(tool_call.function.arguments.is_empty());
    }

    #[test]
    fn test_tool_call_with_complex_arguments() {
        let complex_args = json!({
            "nested": {
                "array": [1, 2, 3],
                "object": {
                    "key": "value"
                }
            },
            "simple": "string"
        });

        let tool_call =
            function_tool_call("complex_call", "complex_function", &complex_args.to_string());

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.id, "complex_call");
        assert_eq!(deserialized.function.name, "complex_function");
        // The argument string must still parse to the very same JSON document.
        let args: serde_json::Value =
            serde_json::from_str(&deserialized.function.arguments).unwrap();
        assert_eq!(args, complex_args);
    }

    #[test]
    fn test_tool_call_with_unicode() {
        // Mix of CJK and an astral-plane emoji (4-byte UTF-8 sequence).
        let arguments = "{\"message\": \"Hello 世界! 🌍\"}";
        let tool_call = function_tool_call("unicode_call", "unicode_function", arguments);

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.id, "unicode_call");
        assert_eq!(deserialized.function.name, "unicode_function");
        assert_eq!(deserialized.function.arguments, arguments);
        assert!(deserialized.function.arguments.contains("Hello 世界! 🌍"));
    }

    #[test]
    fn test_tool_call_large_arguments() {
        let large_arg = "x".repeat(10000);
        let arguments = format!("{{\"large_param\": \"{large_arg}\"}}");
        let tool_call = function_tool_call("large_call", "large_function", &arguments);

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.id, "large_call");
        assert_eq!(deserialized.function.name, "large_function");
        // Exact equality also rules out truncation, not just "long enough".
        assert_eq!(deserialized.function.arguments, arguments);
    }

    /// Provider stub that returns fixed, deterministic results for every capability.
    struct MockLLMProvider;

    #[async_trait]
    impl chat::ChatProvider for MockLLMProvider {
        async fn chat(
            &self,
            _messages: &[ChatMessage],
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".into()),
            }))
        }

        async fn chat_with_tools(
            &self,
            _messages: &[ChatMessage],
            _tools: Option<&[Tool]>,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".into()),
            }))
        }
    }

    #[async_trait]
    impl completion::CompletionProvider for MockLLMProvider {
        async fn complete(
            &self,
            _req: &completion::CompletionRequest,
            _json_schema: Option<chat::StructuredOutputFormat>,
        ) -> Result<completion::CompletionResponse, error::LLMError> {
            Ok(completion::CompletionResponse {
                text: "Mock completion".to_string(),
            })
        }
    }

    #[async_trait]
    impl embedding::EmbeddingProvider for MockLLMProvider {
        async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, error::LLMError> {
            // One 2-d vector per input, [i, i + 1], so tests can check ordering.
            Ok((0..input.len())
                .map(|i| vec![i as f32, (i + 1) as f32])
                .collect())
        }
    }

    #[async_trait]
    impl models::ModelsProvider for MockLLMProvider {}

    impl LLMProvider for MockLLMProvider {}

    struct MockChatResponse {
        text: Option<String>,
    }

    impl chat::ChatResponse for MockChatResponse {
        fn text(&self) -> Option<String> {
            self.text.clone()
        }

        fn tool_calls(&self) -> Option<Vec<ToolCall>> {
            None
        }
    }

    impl std::fmt::Debug for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "MockChatResponse")
        }
    }

    impl std::fmt::Display for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.text.as_deref().unwrap_or(""))
        }
    }

    #[tokio::test]
    async fn test_llm_provider_trait_chat() {
        let provider = MockLLMProvider;
        let messages = vec![chat::ChatMessage::user().content("Test").build()];

        let response = provider.chat(&messages, None).await.unwrap();
        assert_eq!(response.text(), Some("Mock response".to_string()));
    }

    #[tokio::test]
    async fn test_llm_provider_trait_completion() {
        let provider = MockLLMProvider;
        let request = completion::CompletionRequest::new("Test prompt");

        let response = provider.complete(&request, None).await.unwrap();
        assert_eq!(response.text, "Mock completion");
    }

    #[tokio::test]
    async fn test_llm_provider_trait_embedding() {
        let provider = MockLLMProvider;
        let input = vec!["First".to_string(), "Second".to_string()];

        let embeddings = provider.embed(input).await.unwrap();
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0], vec![0.0, 1.0]);
        assert_eq!(embeddings[1], vec![1.0, 2.0]);
    }
}