autoagents_llm/
lib.rs

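//! AutoAgents LLM is a unified interface for interacting with Large Language Model providers.
//!
//! # Overview
//! This crate provides a consistent API for working with different LLM backends by abstracting away
//! provider-specific implementation details. It supports:
//!
//! - Chat-based interactions
//! - Text completion
//! - Embeddings generation
//! - Multiple providers (OpenAI, Anthropic, etc.)
//! - Request validation and retry logic
//!
//! # Architecture
//! The crate is organized into modules that handle different aspects of LLM interactions:
//!
//! - `backends`: implementations for supported LLM providers such as OpenAI and Anthropic
//! - `builder`: builder for configuring and instantiating LLM providers
//! - `chat`: chat-based interactions with language models
//! - `completion`: text completion
//! - `embedding`: vector embedding generation for text
//! - `error`: error types and handling
//! - `evaluator`: evaluation of LLM providers
//! - `models`: listing available models
//! - `secret_store`: storage for API keys and other sensitive information (not built for `wasm32`)
//! - `providers`
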
16use std::fmt::Display;
17
18use serde::{Deserialize, Serialize};
19
20/// Backend implementations for supported LLM providers like OpenAI, Anthropic, etc.
21pub mod backends;
22
23/// Builder pattern for configuring and instantiating LLM providers
24pub mod builder;
25
26/// Chat-based interactions with language models (e.g. ChatGPT style)
27pub mod chat;
28
29/// Text completion capabilities (e.g. GPT-3 style completion)
30pub mod completion;
31
32/// Vector embeddings generation for text
33pub mod embedding;
34
35/// Error types and handling
36pub mod error;
37
38/// Evaluator for LLM providers
39pub mod evaluator;
40
41/// Secret store for storing API keys and other sensitive information
42#[cfg(not(target_arch = "wasm32"))]
43pub mod secret_store;
44
45/// Listing models support
46pub mod models;
47
48pub mod providers;
49
50//Re-export for convenience
51pub use async_trait::async_trait;
52
53/// Core trait that all LLM providers must implement, combining chat, completion
54/// and embedding capabilities into a unified interface
55pub trait LLMProvider:
56    chat::ChatProvider
57    + completion::CompletionProvider
58    + embedding::EmbeddingProvider
59    + models::ModelsProvider
60    + Send
61    + Sync
62    + 'static
63{
64}
65
66/// Tool call represents a function call that an LLM wants to make.
67/// This is a standardized structure used across all providers.
68#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
69pub struct ToolCall {
70    /// The ID of the tool call.
71    pub id: String,
72    /// The type of the tool call (usually "function").
73    #[serde(rename = "type")]
74    pub call_type: String,
75    /// The function to call.
76    pub function: FunctionCall,
77}
78
79impl Display for ToolCall {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        write!(
82            f,
83            "ToolCall {{ id: {}, type: {}, function: {:?} }}",
84            self.id, self.call_type, self.function
85        )
86    }
87}
88
89/// FunctionCall contains details about which function to call and with what arguments.
90#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
91pub struct FunctionCall {
92    /// The name of the function to call.
93    pub name: String,
94    /// The arguments to pass to the function, typically serialized as a JSON string.
95    pub arguments: String,
96}
97
/// Returns the default `call_type` for a [`ToolCall`]: the string `"function"`.
pub fn default_call_type() -> String {
    String::from("function")
}
102
// Unit tests for the shared tool-call types, plus a mock `LLMProvider` implementation that
// checks the capability traits can be implemented and called together.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::{ChatMessage, ChatProvider, ChatResponse, StructuredOutputFormat, Tool};
    use crate::completion::CompletionProvider;
    use crate::embedding::EmbeddingProvider;
    use crate::error::LLMError;
    use async_trait::async_trait;
    use serde_json::json;

    #[test]
    fn test_tool_call_creation() {
        let tool_call = ToolCall {
            id: "call_123".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "test_function".to_string(),
                arguments: "{\"param\": \"value\"}".to_string(),
            },
        };

        assert_eq!(tool_call.id, "call_123");
        assert_eq!(tool_call.call_type, "function");
        assert_eq!(tool_call.function.name, "test_function");
        assert_eq!(tool_call.function.arguments, "{\"param\": \"value\"}");
    }

    #[test]
    fn test_tool_call_serialization() {
        let tool_call = ToolCall {
            id: "call_456".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "serialize_test".to_string(),
                arguments: "{\"test\": true}".to_string(),
            },
        };

        let serialized = serde_json::to_string(&tool_call).unwrap();

        // `call_type` must appear on the wire under the `"type"` key (serde rename),
        // never under its Rust field name.
        let wire: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        assert_eq!(wire["type"], "function");
        assert!(wire.get("call_type").is_none());
        assert_eq!(wire["id"], "call_456");
        assert_eq!(wire["function"]["name"], "serialize_test");

        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.id, "call_456");
        assert_eq!(deserialized.call_type, "function");
        assert_eq!(deserialized.function.name, "serialize_test");
        assert_eq!(deserialized.function.arguments, "{\"test\": true}");
        assert_eq!(deserialized, tool_call);
    }

    #[test]
    fn test_tool_call_equality() {
        let tool_call1 = ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "equal_test".to_string(),
                arguments: "{}".to_string(),
            },
        };

        let tool_call2 = ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "equal_test".to_string(),
                arguments: "{}".to_string(),
            },
        };

        let tool_call3 = ToolCall {
            id: "call_2".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "equal_test".to_string(),
                arguments: "{}".to_string(),
            },
        };

        assert_eq!(tool_call1, tool_call2);
        assert_ne!(tool_call1, tool_call3);
    }

    #[test]
    fn test_tool_call_clone() {
        let tool_call = ToolCall {
            id: "clone_test".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "test_clone".to_string(),
                arguments: "{\"clone\": true}".to_string(),
            },
        };

        let cloned = tool_call.clone();
        assert_eq!(tool_call, cloned);
        assert_eq!(tool_call.id, cloned.id);
        assert_eq!(tool_call.function.name, cloned.function.name);
    }

    #[test]
    fn test_tool_call_debug() {
        let tool_call = ToolCall {
            id: "debug_test".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "debug_function".to_string(),
                arguments: "{}".to_string(),
            },
        };

        let debug_str = format!("{tool_call:?}");
        assert!(debug_str.contains("ToolCall"));
        assert!(debug_str.contains("debug_test"));
        assert!(debug_str.contains("debug_function"));
    }

    #[test]
    fn test_function_call_creation() {
        let function_call = FunctionCall {
            name: "test_function".to_string(),
            arguments: "{\"param1\": \"value1\", \"param2\": 42}".to_string(),
        };

        assert_eq!(function_call.name, "test_function");
        assert_eq!(
            function_call.arguments,
            "{\"param1\": \"value1\", \"param2\": 42}"
        );
    }

    #[test]
    fn test_function_call_serialization() {
        let function_call = FunctionCall {
            name: "serialize_function".to_string(),
            arguments: "{\"data\": [1, 2, 3]}".to_string(),
        };

        let serialized = serde_json::to_string(&function_call).unwrap();
        let deserialized: FunctionCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.name, "serialize_function");
        assert_eq!(deserialized.arguments, "{\"data\": [1, 2, 3]}");
        assert_eq!(deserialized, function_call);
    }

    #[test]
    fn test_function_call_equality() {
        let func1 = FunctionCall {
            name: "equal_func".to_string(),
            arguments: "{}".to_string(),
        };

        let func2 = FunctionCall {
            name: "equal_func".to_string(),
            arguments: "{}".to_string(),
        };

        let func3 = FunctionCall {
            name: "different_func".to_string(),
            arguments: "{}".to_string(),
        };

        assert_eq!(func1, func2);
        assert_ne!(func1, func3);
    }

    #[test]
    fn test_function_call_clone() {
        let function_call = FunctionCall {
            name: "clone_func".to_string(),
            arguments: "{\"clone\": \"test\"}".to_string(),
        };

        let cloned = function_call.clone();
        assert_eq!(function_call, cloned);
        assert_eq!(function_call.name, cloned.name);
        assert_eq!(function_call.arguments, cloned.arguments);
    }

    #[test]
    fn test_function_call_debug() {
        let function_call = FunctionCall {
            name: "debug_func".to_string(),
            arguments: "{}".to_string(),
        };

        let debug_str = format!("{function_call:?}");
        assert!(debug_str.contains("FunctionCall"));
        assert!(debug_str.contains("debug_func"));
    }

    #[test]
    fn test_tool_call_with_empty_values() {
        let tool_call = ToolCall {
            id: String::new(),
            call_type: String::new(),
            function: FunctionCall {
                name: String::new(),
                arguments: String::new(),
            },
        };

        assert!(tool_call.id.is_empty());
        assert!(tool_call.call_type.is_empty());
        assert!(tool_call.function.name.is_empty());
        assert!(tool_call.function.arguments.is_empty());
    }

    #[test]
    fn test_tool_call_with_complex_arguments() {
        let complex_args = json!({
            "nested": {
                "array": [1, 2, 3],
                "object": {
                    "key": "value"
                }
            },
            "simple": "string"
        });

        let tool_call = ToolCall {
            id: "complex_call".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "complex_function".to_string(),
                arguments: complex_args.to_string(),
            },
        };

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.id, "complex_call");
        assert_eq!(deserialized.function.name, "complex_function");
        // Arguments must survive the round trip as the exact same string...
        assert_eq!(deserialized.function.arguments, complex_args.to_string());
        // ...and still parse back to the original JSON structure.
        let reparsed: serde_json::Value =
            serde_json::from_str(&deserialized.function.arguments).unwrap();
        assert_eq!(reparsed, complex_args);
    }

    #[test]
    fn test_tool_call_with_unicode() {
        let tool_call = ToolCall {
            id: "unicode_call".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "unicode_function".to_string(),
                arguments: "{\"message\": \"Hello 世界! 🌍\"}".to_string(),
            },
        };

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.id, "unicode_call");
        assert_eq!(deserialized.function.name, "unicode_function");
        assert!(deserialized.function.arguments.contains("Hello 世界! 🌍"));
        assert_eq!(deserialized, tool_call);
    }

    #[test]
    fn test_tool_call_large_arguments() {
        let large_arg = "x".repeat(10000);
        let tool_call = ToolCall {
            id: "large_call".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "large_function".to_string(),
                arguments: format!("{{\"large_param\": \"{large_arg}\"}}"),
            },
        };

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.id, "large_call");
        assert_eq!(deserialized.function.name, "large_function");
        assert!(deserialized.function.arguments.len() > 10000);
        // The large payload must come back intact, not merely long enough.
        assert_eq!(deserialized, tool_call);
    }

    // Mock LLM provider for testing
    struct MockLLMProvider;

    #[async_trait]
    impl chat::ChatProvider for MockLLMProvider {
        async fn chat(
            &self,
            _messages: &[ChatMessage],
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".into()),
            }))
        }

        async fn chat_with_tools(
            &self,
            _messages: &[ChatMessage],
            _tools: Option<&[Tool]>,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".into()),
            }))
        }
    }

    #[async_trait]
    impl completion::CompletionProvider for MockLLMProvider {
        async fn complete(
            &self,
            _req: &completion::CompletionRequest,
            _json_schema: Option<chat::StructuredOutputFormat>,
        ) -> Result<completion::CompletionResponse, error::LLMError> {
            Ok(completion::CompletionResponse {
                text: "Mock completion".to_string(),
            })
        }
    }

    #[async_trait]
    impl embedding::EmbeddingProvider for MockLLMProvider {
        async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, error::LLMError> {
            // Deterministic fake vectors: the i-th input maps to [i, i + 1].
            Ok((0..input.len())
                .map(|i| vec![i as f32, (i + 1) as f32])
                .collect())
        }
    }

    #[async_trait]
    impl models::ModelsProvider for MockLLMProvider {}

    impl LLMProvider for MockLLMProvider {}

    struct MockChatResponse {
        text: Option<String>,
    }

    impl chat::ChatResponse for MockChatResponse {
        fn text(&self) -> Option<String> {
            self.text.clone()
        }

        fn tool_calls(&self) -> Option<Vec<ToolCall>> {
            None
        }
    }

    impl std::fmt::Debug for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "MockChatResponse")
        }
    }

    impl std::fmt::Display for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.text.as_deref().unwrap_or(""))
        }
    }

    #[tokio::test]
    async fn test_llm_provider_trait_chat() {
        let provider = MockLLMProvider;
        let messages = vec![chat::ChatMessage::user().content("Test").build()];

        let response = provider.chat(&messages, None).await.unwrap();
        assert_eq!(response.text(), Some("Mock response".to_string()));
    }

    #[tokio::test]
    async fn test_llm_provider_trait_completion() {
        let provider = MockLLMProvider;
        let request = completion::CompletionRequest::new("Test prompt");

        let response = provider.complete(&request, None).await.unwrap();
        assert_eq!(response.text, "Mock completion");
    }

    #[tokio::test]
    async fn test_llm_provider_trait_embedding() {
        let provider = MockLLMProvider;
        let input = vec!["First".to_string(), "Second".to_string()];

        let embeddings = provider.embed(input).await.unwrap();
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0], vec![0.0, 1.0]);
        assert_eq!(embeddings[1], vec![1.0, 2.0]);
    }
}