autoagents_llm/
lib.rs

1//! AutoAgents LLM is a unified interface for interacting with Large Language Model providers.
2//!
3//! # Overview
4//! This crate provides a consistent API for working with different LLM backends by abstracting away
5//! provider-specific implementation details. It supports:
6//!
7//! - Chat-based interactions
8//! - Text completion
9//! - Embeddings generation
10//! - Multiple providers (OpenAI, Anthropic, etc.)
11//! - Request validation and retry logic
12//!
13//! # Architecture
14//! The crate is organized into modules that handle different aspects of LLM interactions:
15
16use serde::{Deserialize, Serialize};
17
18/// Backend implementations for supported LLM providers like OpenAI, Anthropic, etc.
19pub mod backends;
20
21/// Builder pattern for configuring and instantiating LLM providers
22pub mod builder;
23
24/// Chat-based interactions with language models (e.g. ChatGPT style)
25pub mod chat;
26
27/// Text completion capabilities (e.g. GPT-3 style completion)
28pub mod completion;
29
30/// Vector embeddings generation for text
31pub mod embedding;
32
33/// Error types and handling
34pub mod error;
35
36/// Evaluator for LLM providers
37pub mod evaluator;
38
39/// Secret store for storing API keys and other sensitive information
40#[cfg(not(target_arch = "wasm32"))]
41pub mod secret_store;
42
/// Support for listing available models
44pub mod models;
45
46/// Core trait that all LLM providers must implement, combining chat, completion
47/// and embedding capabilities into a unified interface
48pub trait LLMProvider:
49    chat::ChatProvider
50    + completion::CompletionProvider
51    + embedding::EmbeddingProvider
52    + models::ModelsProvider
53    + Send
54    + Sync
55    + 'static
56{
57}
58
59/// Tool call represents a function call that an LLM wants to make.
60/// This is a standardized structure used across all providers.
61#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
62pub struct ToolCall {
63    /// The ID of the tool call.
64    pub id: String,
65    /// The type of the tool call (usually "function").
66    #[serde(rename = "type")]
67    pub call_type: String,
68    /// The function to call.
69    pub function: FunctionCall,
70}
71
72/// FunctionCall contains details about which function to call and with what arguments.
73#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
74pub struct FunctionCall {
75    /// The name of the function to call.
76    pub name: String,
77    /// The arguments to pass to the function, typically serialized as a JSON string.
78    pub arguments: String,
79}
80
81#[cfg(test)]
82mod tests {
83    use super::*;
84    use crate::chat::{ChatMessage, ChatProvider, ChatResponse, StructuredOutputFormat, Tool};
85    use crate::completion::CompletionProvider;
86    use crate::embedding::EmbeddingProvider;
87    use crate::error::LLMError;
88    use async_trait::async_trait;
89    use serde_json::json;
90
91    #[test]
92    fn test_tool_call_creation() {
93        let tool_call = ToolCall {
94            id: "call_123".to_string(),
95            call_type: "function".to_string(),
96            function: FunctionCall {
97                name: "test_function".to_string(),
98                arguments: "{\"param\": \"value\"}".to_string(),
99            },
100        };
101
102        assert_eq!(tool_call.id, "call_123");
103        assert_eq!(tool_call.call_type, "function");
104        assert_eq!(tool_call.function.name, "test_function");
105        assert_eq!(tool_call.function.arguments, "{\"param\": \"value\"}");
106    }
107
108    #[test]
109    fn test_tool_call_serialization() {
110        let tool_call = ToolCall {
111            id: "call_456".to_string(),
112            call_type: "function".to_string(),
113            function: FunctionCall {
114                name: "serialize_test".to_string(),
115                arguments: "{\"test\": true}".to_string(),
116            },
117        };
118
119        let serialized = serde_json::to_string(&tool_call).unwrap();
120        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();
121
122        assert_eq!(deserialized.id, "call_456");
123        assert_eq!(deserialized.call_type, "function");
124        assert_eq!(deserialized.function.name, "serialize_test");
125        assert_eq!(deserialized.function.arguments, "{\"test\": true}");
126    }
127
128    #[test]
129    fn test_tool_call_equality() {
130        let tool_call1 = ToolCall {
131            id: "call_1".to_string(),
132            call_type: "function".to_string(),
133            function: FunctionCall {
134                name: "equal_test".to_string(),
135                arguments: "{}".to_string(),
136            },
137        };
138
139        let tool_call2 = ToolCall {
140            id: "call_1".to_string(),
141            call_type: "function".to_string(),
142            function: FunctionCall {
143                name: "equal_test".to_string(),
144                arguments: "{}".to_string(),
145            },
146        };
147
148        let tool_call3 = ToolCall {
149            id: "call_2".to_string(),
150            call_type: "function".to_string(),
151            function: FunctionCall {
152                name: "equal_test".to_string(),
153                arguments: "{}".to_string(),
154            },
155        };
156
157        assert_eq!(tool_call1, tool_call2);
158        assert_ne!(tool_call1, tool_call3);
159    }
160
161    #[test]
162    fn test_tool_call_clone() {
163        let tool_call = ToolCall {
164            id: "clone_test".to_string(),
165            call_type: "function".to_string(),
166            function: FunctionCall {
167                name: "test_clone".to_string(),
168                arguments: "{\"clone\": true}".to_string(),
169            },
170        };
171
172        let cloned = tool_call.clone();
173        assert_eq!(tool_call, cloned);
174        assert_eq!(tool_call.id, cloned.id);
175        assert_eq!(tool_call.function.name, cloned.function.name);
176    }
177
178    #[test]
179    fn test_tool_call_debug() {
180        let tool_call = ToolCall {
181            id: "debug_test".to_string(),
182            call_type: "function".to_string(),
183            function: FunctionCall {
184                name: "debug_function".to_string(),
185                arguments: "{}".to_string(),
186            },
187        };
188
189        let debug_str = format!("{tool_call:?}");
190        assert!(debug_str.contains("ToolCall"));
191        assert!(debug_str.contains("debug_test"));
192        assert!(debug_str.contains("debug_function"));
193    }
194
195    #[test]
196    fn test_function_call_creation() {
197        let function_call = FunctionCall {
198            name: "test_function".to_string(),
199            arguments: "{\"param1\": \"value1\", \"param2\": 42}".to_string(),
200        };
201
202        assert_eq!(function_call.name, "test_function");
203        assert_eq!(
204            function_call.arguments,
205            "{\"param1\": \"value1\", \"param2\": 42}"
206        );
207    }
208
209    #[test]
210    fn test_function_call_serialization() {
211        let function_call = FunctionCall {
212            name: "serialize_function".to_string(),
213            arguments: "{\"data\": [1, 2, 3]}".to_string(),
214        };
215
216        let serialized = serde_json::to_string(&function_call).unwrap();
217        let deserialized: FunctionCall = serde_json::from_str(&serialized).unwrap();
218
219        assert_eq!(deserialized.name, "serialize_function");
220        assert_eq!(deserialized.arguments, "{\"data\": [1, 2, 3]}");
221    }
222
223    #[test]
224    fn test_function_call_equality() {
225        let func1 = FunctionCall {
226            name: "equal_func".to_string(),
227            arguments: "{}".to_string(),
228        };
229
230        let func2 = FunctionCall {
231            name: "equal_func".to_string(),
232            arguments: "{}".to_string(),
233        };
234
235        let func3 = FunctionCall {
236            name: "different_func".to_string(),
237            arguments: "{}".to_string(),
238        };
239
240        assert_eq!(func1, func2);
241        assert_ne!(func1, func3);
242    }
243
244    #[test]
245    fn test_function_call_clone() {
246        let function_call = FunctionCall {
247            name: "clone_func".to_string(),
248            arguments: "{\"clone\": \"test\"}".to_string(),
249        };
250
251        let cloned = function_call.clone();
252        assert_eq!(function_call, cloned);
253        assert_eq!(function_call.name, cloned.name);
254        assert_eq!(function_call.arguments, cloned.arguments);
255    }
256
257    #[test]
258    fn test_function_call_debug() {
259        let function_call = FunctionCall {
260            name: "debug_func".to_string(),
261            arguments: "{}".to_string(),
262        };
263
264        let debug_str = format!("{function_call:?}");
265        assert!(debug_str.contains("FunctionCall"));
266        assert!(debug_str.contains("debug_func"));
267    }
268
269    #[test]
270    fn test_tool_call_with_empty_values() {
271        let tool_call = ToolCall {
272            id: String::new(),
273            call_type: String::new(),
274            function: FunctionCall {
275                name: String::new(),
276                arguments: String::new(),
277            },
278        };
279
280        assert!(tool_call.id.is_empty());
281        assert!(tool_call.call_type.is_empty());
282        assert!(tool_call.function.name.is_empty());
283        assert!(tool_call.function.arguments.is_empty());
284    }
285
286    #[test]
287    fn test_tool_call_with_complex_arguments() {
288        let complex_args = json!({
289            "nested": {
290                "array": [1, 2, 3],
291                "object": {
292                    "key": "value"
293                }
294            },
295            "simple": "string"
296        });
297
298        let tool_call = ToolCall {
299            id: "complex_call".to_string(),
300            call_type: "function".to_string(),
301            function: FunctionCall {
302                name: "complex_function".to_string(),
303                arguments: complex_args.to_string(),
304            },
305        };
306
307        let serialized = serde_json::to_string(&tool_call).unwrap();
308        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();
309
310        assert_eq!(deserialized.id, "complex_call");
311        assert_eq!(deserialized.function.name, "complex_function");
312        // Arguments should be preserved as string
313        assert!(deserialized.function.arguments.contains("nested"));
314        assert!(deserialized.function.arguments.contains("array"));
315    }
316
317    #[test]
318    fn test_tool_call_with_unicode() {
319        let tool_call = ToolCall {
320            id: "unicode_call".to_string(),
321            call_type: "function".to_string(),
322            function: FunctionCall {
323                name: "unicode_function".to_string(),
324                arguments: "{\"message\": \"Hello δΈ–η•Œ! 🌍\"}".to_string(),
325            },
326        };
327
328        let serialized = serde_json::to_string(&tool_call).unwrap();
329        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();
330
331        assert_eq!(deserialized.id, "unicode_call");
332        assert_eq!(deserialized.function.name, "unicode_function");
333        assert!(deserialized.function.arguments.contains("Hello δΈ–η•Œ! 🌍"));
334    }
335
336    #[test]
337    fn test_tool_call_large_arguments() {
338        let large_arg = "x".repeat(10000);
339        let tool_call = ToolCall {
340            id: "large_call".to_string(),
341            call_type: "function".to_string(),
342            function: FunctionCall {
343                name: "large_function".to_string(),
344                arguments: format!("{{\"large_param\": \"{large_arg}\"}}"),
345            },
346        };
347
348        let serialized = serde_json::to_string(&tool_call).unwrap();
349        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();
350
351        assert_eq!(deserialized.id, "large_call");
352        assert_eq!(deserialized.function.name, "large_function");
353        assert!(deserialized.function.arguments.len() > 10000);
354    }
355
356    // Mock LLM provider for testing
357    struct MockLLMProvider;
358
359    #[async_trait]
360    impl chat::ChatProvider for MockLLMProvider {
361        async fn chat(
362            &self,
363            _messages: &[ChatMessage],
364            _tools: Option<&[Tool]>,
365            _json_schema: Option<StructuredOutputFormat>,
366        ) -> Result<Box<dyn ChatResponse>, LLMError> {
367            Ok(Box::new(MockChatResponse {
368                text: Some("Mock response".into()),
369            }))
370        }
371    }
372
373    #[async_trait]
374    impl completion::CompletionProvider for MockLLMProvider {
375        async fn complete(
376            &self,
377            _req: &completion::CompletionRequest,
378            _json_schema: Option<chat::StructuredOutputFormat>,
379        ) -> Result<completion::CompletionResponse, error::LLMError> {
380            Ok(completion::CompletionResponse {
381                text: "Mock completion".to_string(),
382            })
383        }
384    }
385
386    #[async_trait]
387    impl embedding::EmbeddingProvider for MockLLMProvider {
388        async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, error::LLMError> {
389            let mut embeddings = Vec::new();
390            for (i, _) in input.iter().enumerate() {
391                embeddings.push(vec![i as f32, (i + 1) as f32]);
392            }
393            Ok(embeddings)
394        }
395    }
396
397    #[async_trait]
398    impl models::ModelsProvider for MockLLMProvider {}
399
400    impl LLMProvider for MockLLMProvider {}
401
402    struct MockChatResponse {
403        text: Option<String>,
404    }
405
406    impl chat::ChatResponse for MockChatResponse {
407        fn text(&self) -> Option<String> {
408            self.text.clone()
409        }
410
411        fn tool_calls(&self) -> Option<Vec<ToolCall>> {
412            None
413        }
414    }
415
416    impl std::fmt::Debug for MockChatResponse {
417        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
418            write!(f, "MockChatResponse")
419        }
420    }
421
422    impl std::fmt::Display for MockChatResponse {
423        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
424            write!(f, "{}", self.text.as_deref().unwrap_or(""))
425        }
426    }
427
428    #[tokio::test]
429    async fn test_llm_provider_trait_chat() {
430        let provider = MockLLMProvider;
431        let messages = vec![chat::ChatMessage::user().content("Test").build()];
432
433        let response = provider.chat(&messages, None, None).await.unwrap();
434        assert_eq!(response.text(), Some("Mock response".to_string()));
435    }
436
437    #[tokio::test]
438    async fn test_llm_provider_trait_completion() {
439        let provider = MockLLMProvider;
440        let request = completion::CompletionRequest::new("Test prompt");
441
442        let response = provider.complete(&request, None).await.unwrap();
443        assert_eq!(response.text, "Mock completion");
444    }
445
446    #[tokio::test]
447    async fn test_llm_provider_trait_embedding() {
448        let provider = MockLLMProvider;
449        let input = vec!["First".to_string(), "Second".to_string()];
450
451        let embeddings = provider.embed(input).await.unwrap();
452        assert_eq!(embeddings.len(), 2);
453        assert_eq!(embeddings[0], vec![0.0, 1.0]);
454        assert_eq!(embeddings[1], vec![1.0, 2.0]);
455    }
456}