autoagents_llm/
lib.rs

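//! AutoAgents LLM is a unified interface for interacting with Large Language Model providers.
//!
//! # Overview
//! This crate provides a consistent API for working with different LLM backends by abstracting away
//! provider-specific implementation details. It supports:
//!
//! - Chat-based interactions
//! - Text completion
//! - Embeddings generation
//! - Listing the models a provider offers
//! - Multiple providers (OpenAI, Anthropic, etc.)
//! - Request validation and retry logic
//!
//! # Architecture
//! The crate is organized into modules that handle different aspects of LLM interactions:
//!
//! - `backends`: implementations for the supported LLM providers
//! - `builder`: builder for configuring and instantiating providers
//! - `chat`: chat-based interactions with language models
//! - `completion`: text completion
//! - `embedding`: vector embeddings generation
//! - `error`: error types and handling
//! - `evaluator`: evaluation of LLM providers
//! - `models`: listing available models
//! - `secret_store`: storage for API keys and other sensitive information
//!
//! Every provider implements the [`LLMProvider`] trait, which combines these capabilities into a
//! single interface.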
15
16use chat::Tool;
17use serde::{Deserialize, Serialize};
18
19/// Backend implementations for supported LLM providers like OpenAI, Anthropic, etc.
20pub mod backends;
21
22/// Builder pattern for configuring and instantiating LLM providers
23pub mod builder;
24
25/// Chat-based interactions with language models (e.g. ChatGPT style)
26pub mod chat;
27
28/// Text completion capabilities (e.g. GPT-3 style completion)
29pub mod completion;
30
31/// Vector embeddings generation for text
32pub mod embedding;
33
34/// Error types and handling
35pub mod error;
36
37/// Evaluator for LLM providers
38pub mod evaluator;
39
40/// Secret store for storing API keys and other sensitive information
41pub mod secret_store;
42
43/// Listing models support
44pub mod models;
45
46/// Core trait that all LLM providers must implement, combining chat, completion
47/// and embedding capabilities into a unified interface
48pub trait LLMProvider:
49    chat::ChatProvider
50    + completion::CompletionProvider
51    + embedding::EmbeddingProvider
52    + models::ModelsProvider
53    + Send
54    + Sync
55    + 'static
56{
57    fn tools(&self) -> Option<&[Tool]> {
58        None
59    }
60}
61
62/// Tool call represents a function call that an LLM wants to make.
63/// This is a standardized structure used across all providers.
64#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
65pub struct ToolCall {
66    /// The ID of the tool call.
67    pub id: String,
68    /// The type of the tool call (usually "function").
69    #[serde(rename = "type")]
70    pub call_type: String,
71    /// The function to call.
72    pub function: FunctionCall,
73}
74
75/// FunctionCall contains details about which function to call and with what arguments.
76#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
77pub struct FunctionCall {
78    /// The name of the function to call.
79    pub name: String,
80    /// The arguments to pass to the function, typically serialized as a JSON string.
81    pub arguments: String,
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::ChatProvider;
    use crate::completion::CompletionProvider;
    use crate::embedding::EmbeddingProvider;
    use async_trait::async_trait;
    use serde_json::json;

    /// Builds a `ToolCall` whose `call_type` is `"function"`.
    fn function_tool_call(id: &str, name: &str, arguments: &str) -> ToolCall {
        ToolCall {
            id: id.to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: name.to_string(),
                arguments: arguments.to_string(),
            },
        }
    }

    /// Serializes `value` to a JSON string and parses it back into the same type.
    fn json_roundtrip<T>(value: &T) -> T
    where
        T: serde::Serialize + serde::de::DeserializeOwned,
    {
        let serialized = serde_json::to_string(value).unwrap();
        serde_json::from_str(&serialized).unwrap()
    }

    #[test]
    fn test_tool_call_creation() {
        let tool_call = ToolCall {
            id: "call_123".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "test_function".to_string(),
                arguments: r#"{"param": "value"}"#.to_string(),
            },
        };

        assert_eq!(tool_call.id, "call_123");
        assert_eq!(tool_call.call_type, "function");
        assert_eq!(tool_call.function.name, "test_function");
        assert_eq!(tool_call.function.arguments, r#"{"param": "value"}"#);
    }

    #[test]
    fn test_tool_call_serialization() {
        let tool_call = function_tool_call("call_456", "serialize_test", r#"{"test": true}"#);

        let deserialized = json_roundtrip(&tool_call);

        assert_eq!(deserialized.id, "call_456");
        assert_eq!(deserialized.call_type, "function");
        assert_eq!(deserialized.function.name, "serialize_test");
        assert_eq!(deserialized.function.arguments, r#"{"test": true}"#);
        assert_eq!(deserialized, tool_call);
    }

    #[test]
    fn test_tool_call_serializes_call_type_as_type() {
        // `call_type` must appear on the wire as `"type"`, matching provider payloads.
        let tool_call = function_tool_call("call_789", "rename_test", "{}");

        let value = serde_json::to_value(&tool_call).unwrap();
        assert_eq!(value["type"], "function");
        assert!(value.get("call_type").is_none());
        assert_eq!(value["function"]["name"], "rename_test");

        let parsed: ToolCall = serde_json::from_value(json!({
            "id": "call_789",
            "type": "function",
            "function": { "name": "rename_test", "arguments": "{}" }
        }))
        .unwrap();
        assert_eq!(parsed, tool_call);
    }

    #[test]
    fn test_tool_call_equality() {
        let tool_call1 = function_tool_call("call_1", "equal_test", "{}");
        let tool_call2 = function_tool_call("call_1", "equal_test", "{}");
        let tool_call3 = function_tool_call("call_2", "equal_test", "{}");

        assert_eq!(tool_call1, tool_call2);
        assert_ne!(tool_call1, tool_call3);
    }

    #[test]
    fn test_tool_call_clone() {
        let tool_call = function_tool_call("clone_test", "test_clone", r#"{"clone": true}"#);

        let cloned = tool_call.clone();
        assert_eq!(tool_call, cloned);
        assert_eq!(tool_call.id, cloned.id);
        assert_eq!(tool_call.function.name, cloned.function.name);
    }

    #[test]
    fn test_tool_call_debug() {
        let tool_call = function_tool_call("debug_test", "debug_function", "{}");

        let debug_str = format!("{tool_call:?}");
        assert!(debug_str.contains("ToolCall"));
        assert!(debug_str.contains("debug_test"));
        assert!(debug_str.contains("debug_function"));
    }

    #[test]
    fn test_function_call_creation() {
        let function_call = FunctionCall {
            name: "test_function".to_string(),
            arguments: r#"{"param1": "value1", "param2": 42}"#.to_string(),
        };

        assert_eq!(function_call.name, "test_function");
        assert_eq!(
            function_call.arguments,
            r#"{"param1": "value1", "param2": 42}"#
        );
    }

    #[test]
    fn test_function_call_serialization() {
        let function_call = FunctionCall {
            name: "serialize_function".to_string(),
            arguments: r#"{"data": [1, 2, 3]}"#.to_string(),
        };

        let deserialized = json_roundtrip(&function_call);

        assert_eq!(deserialized.name, "serialize_function");
        assert_eq!(deserialized.arguments, r#"{"data": [1, 2, 3]}"#);
        assert_eq!(deserialized, function_call);
    }

    #[test]
    fn test_function_call_equality() {
        let make = |name: &str| FunctionCall {
            name: name.to_string(),
            arguments: "{}".to_string(),
        };

        let func1 = make("equal_func");
        let func2 = make("equal_func");
        let func3 = make("different_func");

        assert_eq!(func1, func2);
        assert_ne!(func1, func3);
    }

    #[test]
    fn test_function_call_clone() {
        let function_call = FunctionCall {
            name: "clone_func".to_string(),
            arguments: r#"{"clone": "test"}"#.to_string(),
        };

        let cloned = function_call.clone();
        assert_eq!(function_call, cloned);
        assert_eq!(function_call.name, cloned.name);
        assert_eq!(function_call.arguments, cloned.arguments);
    }

    #[test]
    fn test_function_call_debug() {
        let function_call = FunctionCall {
            name: "debug_func".to_string(),
            arguments: "{}".to_string(),
        };

        let debug_str = format!("{function_call:?}");
        assert!(debug_str.contains("FunctionCall"));
        assert!(debug_str.contains("debug_func"));
    }

    #[test]
    fn test_tool_call_with_empty_values() {
        let tool_call = ToolCall {
            id: String::new(),
            call_type: String::new(),
            function: FunctionCall {
                name: String::new(),
                arguments: String::new(),
            },
        };

        assert!(tool_call.id.is_empty());
        assert!(tool_call.call_type.is_empty());
        assert!(tool_call.function.name.is_empty());
        assert!(tool_call.function.arguments.is_empty());
    }

    #[test]
    fn test_tool_call_with_complex_arguments() {
        let complex_args = json!({
            "nested": {
                "array": [1, 2, 3],
                "object": {
                    "key": "value"
                }
            },
            "simple": "string"
        });

        let tool_call =
            function_tool_call("complex_call", "complex_function", &complex_args.to_string());

        let deserialized = json_roundtrip(&tool_call);

        assert_eq!(deserialized.id, "complex_call");
        assert_eq!(deserialized.function.name, "complex_function");
        // Arguments are preserved verbatim as a string...
        assert!(deserialized.function.arguments.contains("nested"));
        assert!(deserialized.function.arguments.contains("array"));
        // ...and still parse back into the original JSON document.
        let reparsed: serde_json::Value =
            serde_json::from_str(&deserialized.function.arguments).unwrap();
        assert_eq!(reparsed, complex_args);
    }

    #[test]
    fn test_tool_call_with_unicode() {
        let tool_call = function_tool_call(
            "unicode_call",
            "unicode_function",
            r#"{"message": "Hello 世界! 🌍"}"#,
        );

        let deserialized = json_roundtrip(&tool_call);

        assert_eq!(deserialized.id, "unicode_call");
        assert_eq!(deserialized.function.name, "unicode_function");
        assert!(deserialized.function.arguments.contains("Hello 世界! 🌍"));
        assert_eq!(deserialized, tool_call);
    }

    #[test]
    fn test_tool_call_large_arguments() {
        let large_arg = "x".repeat(10_000);
        let arguments = format!(r#"{{"large_param": "{large_arg}"}}"#);
        let tool_call = function_tool_call("large_call", "large_function", &arguments);

        let deserialized = json_roundtrip(&tool_call);

        assert_eq!(deserialized.id, "large_call");
        assert_eq!(deserialized.function.name, "large_function");
        assert!(deserialized.function.arguments.len() > 10_000);
        assert_eq!(deserialized, tool_call);
    }

    /// Minimal provider used to exercise the `LLMProvider` trait surface.
    struct MockLLMProvider;

    #[async_trait]
    impl chat::ChatProvider for MockLLMProvider {
        async fn chat_with_tools(
            &self,
            _messages: &[chat::ChatMessage],
            _tools: Option<&[chat::Tool]>,
            _json_schema: Option<chat::StructuredOutputFormat>,
        ) -> Result<Box<dyn chat::ChatResponse>, error::LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".to_string()),
            }))
        }
    }

    #[async_trait]
    impl completion::CompletionProvider for MockLLMProvider {
        async fn complete(
            &self,
            _req: &completion::CompletionRequest,
            _json_schema: Option<chat::StructuredOutputFormat>,
        ) -> Result<completion::CompletionResponse, error::LLMError> {
            Ok(completion::CompletionResponse {
                text: "Mock completion".to_string(),
            })
        }
    }

    #[async_trait]
    impl embedding::EmbeddingProvider for MockLLMProvider {
        /// Returns `[i, i + 1]` for the i-th input so tests can check ordering.
        async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, error::LLMError> {
            Ok((0..input.len())
                .map(|i| vec![i as f32, (i + 1) as f32])
                .collect())
        }
    }

    #[async_trait]
    impl models::ModelsProvider for MockLLMProvider {}

    // Deliberately relies on the trait's default `tools()` so that
    // `test_llm_provider_tools` covers the default implementation.
    impl LLMProvider for MockLLMProvider {}

    struct MockChatResponse {
        text: Option<String>,
    }

    impl chat::ChatResponse for MockChatResponse {
        fn text(&self) -> Option<String> {
            self.text.clone()
        }

        fn tool_calls(&self) -> Option<Vec<ToolCall>> {
            None
        }
    }

    impl std::fmt::Debug for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "MockChatResponse")
        }
    }

    impl std::fmt::Display for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.text.as_deref().unwrap_or(""))
        }
    }

    #[tokio::test]
    async fn test_llm_provider_trait_chat() {
        let provider = MockLLMProvider;
        let messages = vec![chat::ChatMessage::user().content("Test").build()];

        let response = provider.chat(&messages, None).await.unwrap();
        assert_eq!(response.text(), Some("Mock response".to_string()));
    }

    #[tokio::test]
    async fn test_llm_provider_trait_completion() {
        let provider = MockLLMProvider;
        let request = completion::CompletionRequest::new("Test prompt");

        let response = provider.complete(&request, None).await.unwrap();
        assert_eq!(response.text, "Mock completion");
    }

    #[tokio::test]
    async fn test_llm_provider_trait_embedding() {
        let provider = MockLLMProvider;
        let input = vec!["First".to_string(), "Second".to_string()];

        let embeddings = provider.embed(input).await.unwrap();
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0], vec![0.0, 1.0]);
        assert_eq!(embeddings[1], vec![1.0, 2.0]);
    }

    #[test]
    fn test_llm_provider_tools() {
        // The mock does not override `tools`, so this checks the trait default.
        let provider = MockLLMProvider;
        assert!(provider.tools().is_none());
    }
}