autoagents_llm/
lib.rs

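//! AutoAgents LLM is a unified interface for interacting with Large Language Model providers.
//!
//! # Overview
//! This crate provides a consistent API for working with different LLM backends by abstracting away
//! provider-specific implementation details. It supports:
//!
//! - Chat-based interactions
//! - Text completion
//! - Embeddings generation
//! - Multiple providers (OpenAI, Anthropic, etc.)
//! - Request validation and retry logic
//!
//! # Architecture
//! The crate is organized into modules that each handle one aspect of LLM interactions:
//!
//! - [`backends`]: implementations for the supported LLM providers
//! - [`builder`]: builder for configuring and instantiating providers
//! - [`chat`], [`completion`], [`embedding`], [`models`]: the capabilities that
//!   [`LLMProvider`] combines into a single trait
//! - [`error`]: error types and handling
//! - [`evaluator`]: evaluator for LLM providers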
15
16use serde::{Deserialize, Serialize};
17
/// Backend implementations for supported LLM providers like OpenAI, Anthropic, etc.
pub mod backends;

/// Builder pattern for configuring and instantiating LLM providers
pub mod builder;

/// Chat-based interactions with language models (e.g. ChatGPT style)
pub mod chat;

/// Text completion capabilities (e.g. GPT-3 style completion)
pub mod completion;

/// Vector embeddings generation for text
pub mod embedding;

/// Error types and handling
pub mod error;

/// Evaluator for LLM providers
pub mod evaluator;

/// Secret store for storing API keys and other sensitive information
///
/// Only compiled for non-`wasm32` targets.
#[cfg(not(target_arch = "wasm32"))]
pub mod secret_store;

/// Support for listing available models
pub mod models;

// NOTE(review): `providers` has no `///` summary and its contents are not visible
// from lib.rs — add one once its purpose is confirmed.
pub mod providers;

// Re-exported for convenience, so implementors of the provider traits can write
// `#[async_trait]` through this crate.
pub use async_trait::async_trait;
50
51/// Core trait that all LLM providers must implement, combining chat, completion
52/// and embedding capabilities into a unified interface
53pub trait LLMProvider:
54    chat::ChatProvider
55    + completion::CompletionProvider
56    + embedding::EmbeddingProvider
57    + models::ModelsProvider
58    + Send
59    + Sync
60    + 'static
61{
62}
63
64/// Tool call represents a function call that an LLM wants to make.
65/// This is a standardized structure used across all providers.
66#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
67pub struct ToolCall {
68    /// The ID of the tool call.
69    pub id: String,
70    /// The type of the tool call (usually "function").
71    #[serde(rename = "type")]
72    pub call_type: String,
73    /// The function to call.
74    pub function: FunctionCall,
75}
76
77/// FunctionCall contains details about which function to call and with what arguments.
78#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
79pub struct FunctionCall {
80    /// The name of the function to call.
81    pub name: String,
82    /// The arguments to pass to the function, typically serialized as a JSON string.
83    pub arguments: String,
84}
85
/// Default value for [`ToolCall::call_type`]: the string `"function"`.
pub fn default_call_type() -> String {
    String::from("function")
}
90
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::{ChatMessage, ChatProvider, ChatResponse, StructuredOutputFormat, Tool};
    use crate::completion::CompletionProvider;
    use crate::embedding::EmbeddingProvider;
    use crate::error::LLMError;
    use async_trait::async_trait;
    use serde_json::json;

    /// Builds a `"function"`-typed [`ToolCall`] so each test focuses on the
    /// property it checks rather than on struct boilerplate.
    fn function_tool_call(id: &str, name: &str, arguments: &str) -> ToolCall {
        ToolCall {
            id: id.to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: name.to_string(),
                arguments: arguments.to_string(),
            },
        }
    }

    /// Serializes `value` to a JSON string and parses it back (panics on failure).
    fn json_round_trip<T>(value: &T) -> T
    where
        T: serde::Serialize + serde::de::DeserializeOwned,
    {
        let serialized = serde_json::to_string(value).unwrap();
        serde_json::from_str(&serialized).unwrap()
    }

    #[test]
    fn test_tool_call_creation() {
        let tool_call = function_tool_call("call_123", "test_function", "{\"param\": \"value\"}");

        assert_eq!(tool_call.id, "call_123");
        assert_eq!(tool_call.call_type, "function");
        assert_eq!(tool_call.function.name, "test_function");
        assert_eq!(tool_call.function.arguments, "{\"param\": \"value\"}");
    }

    #[test]
    fn test_tool_call_serialization() {
        let tool_call = function_tool_call("call_456", "serialize_test", "{\"test\": true}");

        // `call_type` must be emitted under the provider-facing key `"type"`.
        let value = serde_json::to_value(&tool_call).unwrap();
        assert_eq!(value["type"], "function");
        assert!(value.get("call_type").is_none());

        let deserialized = json_round_trip(&tool_call);
        assert_eq!(deserialized.id, "call_456");
        assert_eq!(deserialized.call_type, "function");
        assert_eq!(deserialized.function.name, "serialize_test");
        assert_eq!(deserialized.function.arguments, "{\"test\": true}");
        assert_eq!(deserialized, tool_call);
    }

    #[test]
    fn test_tool_call_deserializes_provider_json() {
        // Raw payload in the shape providers send: the type lives under `"type"`.
        let raw = r#"{"id":"call_789","type":"function","function":{"name":"lookup","arguments":"{\"q\":1}"}}"#;
        let tool_call: ToolCall = serde_json::from_str(raw).unwrap();

        assert_eq!(tool_call, function_tool_call("call_789", "lookup", "{\"q\":1}"));
    }

    #[test]
    fn test_tool_call_equality() {
        let tool_call1 = function_tool_call("call_1", "equal_test", "{}");
        let tool_call2 = function_tool_call("call_1", "equal_test", "{}");
        let tool_call3 = function_tool_call("call_2", "equal_test", "{}");

        assert_eq!(tool_call1, tool_call2);
        assert_ne!(tool_call1, tool_call3);
    }

    #[test]
    fn test_tool_call_clone() {
        let tool_call = function_tool_call("clone_test", "test_clone", "{\"clone\": true}");

        let cloned = tool_call.clone();
        assert_eq!(tool_call, cloned);
        assert_eq!(tool_call.id, cloned.id);
        assert_eq!(tool_call.function.name, cloned.function.name);
    }

    #[test]
    fn test_tool_call_debug() {
        let tool_call = function_tool_call("debug_test", "debug_function", "{}");

        let debug_str = format!("{tool_call:?}");
        assert!(debug_str.contains("ToolCall"));
        assert!(debug_str.contains("debug_test"));
        assert!(debug_str.contains("debug_function"));
    }

    #[test]
    fn test_function_call_creation() {
        let function_call = FunctionCall {
            name: "test_function".to_string(),
            arguments: "{\"param1\": \"value1\", \"param2\": 42}".to_string(),
        };

        assert_eq!(function_call.name, "test_function");
        assert_eq!(
            function_call.arguments,
            "{\"param1\": \"value1\", \"param2\": 42}"
        );
    }

    #[test]
    fn test_function_call_serialization() {
        let function_call = FunctionCall {
            name: "serialize_function".to_string(),
            arguments: "{\"data\": [1, 2, 3]}".to_string(),
        };

        let deserialized = json_round_trip(&function_call);
        assert_eq!(deserialized.name, "serialize_function");
        assert_eq!(deserialized.arguments, "{\"data\": [1, 2, 3]}");
        assert_eq!(deserialized, function_call);
    }

    #[test]
    fn test_function_call_equality() {
        let func1 = FunctionCall {
            name: "equal_func".to_string(),
            arguments: "{}".to_string(),
        };
        let func2 = func1.clone();
        let func3 = FunctionCall {
            name: "different_func".to_string(),
            arguments: "{}".to_string(),
        };
        let func4 = FunctionCall {
            name: "equal_func".to_string(),
            arguments: "{\"x\": 1}".to_string(),
        };

        assert_eq!(func1, func2);
        assert_ne!(func1, func3);
        // Equality covers the arguments too, not just the name.
        assert_ne!(func1, func4);
    }

    #[test]
    fn test_function_call_clone() {
        let function_call = FunctionCall {
            name: "clone_func".to_string(),
            arguments: "{\"clone\": \"test\"}".to_string(),
        };

        let cloned = function_call.clone();
        assert_eq!(function_call, cloned);
        assert_eq!(function_call.name, cloned.name);
        assert_eq!(function_call.arguments, cloned.arguments);
    }

    #[test]
    fn test_function_call_debug() {
        let function_call = FunctionCall {
            name: "debug_func".to_string(),
            arguments: "{}".to_string(),
        };

        let debug_str = format!("{function_call:?}");
        assert!(debug_str.contains("FunctionCall"));
        assert!(debug_str.contains("debug_func"));
    }

    #[test]
    fn test_tool_call_with_empty_values() {
        let tool_call = ToolCall {
            id: String::new(),
            call_type: String::new(),
            function: FunctionCall {
                name: String::new(),
                arguments: String::new(),
            },
        };

        assert!(tool_call.id.is_empty());
        assert!(tool_call.call_type.is_empty());
        assert!(tool_call.function.name.is_empty());
        assert!(tool_call.function.arguments.is_empty());
    }

    #[test]
    fn test_tool_call_with_complex_arguments() {
        let complex_args = json!({
            "nested": {
                "array": [1, 2, 3],
                "object": {
                    "key": "value"
                }
            },
            "simple": "string"
        });

        let tool_call =
            function_tool_call("complex_call", "complex_function", &complex_args.to_string());

        let deserialized = json_round_trip(&tool_call);
        assert_eq!(deserialized.id, "complex_call");
        assert_eq!(deserialized.function.name, "complex_function");
        // Arguments travel as an opaque JSON string and must parse back to the same value.
        let parsed: serde_json::Value =
            serde_json::from_str(&deserialized.function.arguments).unwrap();
        assert_eq!(parsed, complex_args);
    }

    #[test]
    fn test_tool_call_with_unicode() {
        let tool_call = function_tool_call(
            "unicode_call",
            "unicode_function",
            "{\"message\": \"Hello 世界! 🌍\"}",
        );

        let deserialized = json_round_trip(&tool_call);
        assert_eq!(deserialized.id, "unicode_call");
        assert_eq!(deserialized.function.name, "unicode_function");
        assert!(deserialized.function.arguments.contains("Hello 世界! 🌍"));
        assert_eq!(deserialized, tool_call);
    }

    #[test]
    fn test_tool_call_large_arguments() {
        let large_arg = "x".repeat(10000);
        let tool_call = function_tool_call(
            "large_call",
            "large_function",
            &format!("{{\"large_param\": \"{large_arg}\"}}"),
        );

        let deserialized = json_round_trip(&tool_call);
        assert_eq!(deserialized.id, "large_call");
        assert_eq!(deserialized.function.name, "large_function");
        assert!(deserialized.function.arguments.len() > 10000);
        assert_eq!(deserialized.function.arguments, tool_call.function.arguments);
    }

    // Mock LLM provider for testing
    struct MockLLMProvider;

    #[async_trait]
    impl ChatProvider for MockLLMProvider {
        async fn chat(
            &self,
            _messages: &[ChatMessage],
            _tools: Option<&[Tool]>,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".into()),
            }))
        }
    }

    #[async_trait]
    impl CompletionProvider for MockLLMProvider {
        async fn complete(
            &self,
            _req: &completion::CompletionRequest,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<completion::CompletionResponse, LLMError> {
            Ok(completion::CompletionResponse {
                text: "Mock completion".to_string(),
            })
        }
    }

    #[async_trait]
    impl EmbeddingProvider for MockLLMProvider {
        async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
            // One deterministic 2-d vector per input: [i, i + 1].
            Ok((0..input.len())
                .map(|i| vec![i as f32, (i + 1) as f32])
                .collect())
        }
    }

    #[async_trait]
    impl models::ModelsProvider for MockLLMProvider {}

    impl LLMProvider for MockLLMProvider {}

    struct MockChatResponse {
        text: Option<String>,
    }

    impl ChatResponse for MockChatResponse {
        fn text(&self) -> Option<String> {
            self.text.clone()
        }

        fn tool_calls(&self) -> Option<Vec<ToolCall>> {
            None
        }
    }

    impl std::fmt::Debug for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "MockChatResponse")
        }
    }

    impl std::fmt::Display for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.text.as_deref().unwrap_or(""))
        }
    }

    #[tokio::test]
    async fn test_llm_provider_trait_chat() {
        let provider = MockLLMProvider;
        let messages = vec![ChatMessage::user().content("Test").build()];

        let response = provider.chat(&messages, None, None).await.unwrap();
        assert_eq!(response.text(), Some("Mock response".to_string()));
    }

    #[tokio::test]
    async fn test_llm_provider_trait_completion() {
        let provider = MockLLMProvider;
        let request = completion::CompletionRequest::new("Test prompt");

        let response = provider.complete(&request, None).await.unwrap();
        assert_eq!(response.text, "Mock completion");
    }

    #[tokio::test]
    async fn test_llm_provider_trait_embedding() {
        let provider = MockLLMProvider;
        let input = vec!["First".to_string(), "Second".to_string()];

        let embeddings = provider.embed(input).await.unwrap();
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0], vec![0.0, 1.0]);
        assert_eq!(embeddings[1], vec![1.0, 2.0]);
    }
}