Skip to main content

autoagents_llm/
lib.rs

1//! AutoAgents LLM is a unified interface for interacting with Large Language Model providers.
2//!
3//! # Overview
4//! This crate provides a consistent API for working with different LLM backends by abstracting away
5//! provider-specific implementation details. It supports:
6//!
7//! - Chat-based interactions
8//! - Text completion
9//! - Embeddings generation
10//! - Multiple providers (OpenAI, Anthropic, etc.)
11//! - Request validation and retry logic
12//!
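//! # Architecture
//! The crate is organized into modules that handle different aspects of LLM interactions,
//! including:
//!
//! - [`backends`]: implementations for supported LLM providers (OpenAI, Anthropic, etc.)
//! - [`builder`]: configuring and instantiating LLM providers
//! - [`chat`]: chat-based interactions
//! - [`completion`]: text completion
//! - [`embedding`]: vector embedding generation
//! - [`error`]: error types and handling
//! - [`evaluator`]: evaluating LLM providers
//! - [`models`]: listing the models a provider offers
//! - `secret_store`: storage for API keys and other secrets (not built for `wasm32`)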
15
16use std::fmt::Display;
17
18use serde::{Deserialize, Serialize};
19
/// Backend implementations for supported LLM providers like OpenAI, Anthropic, etc.
pub mod backends;

/// Builder pattern for configuring and instantiating LLM providers.
pub mod builder;

/// Chat-based interactions with language models (e.g. ChatGPT-style conversations).
pub mod chat;

/// Text completion capabilities (e.g. GPT-3-style prompt completion).
pub mod completion;

/// Vector embedding generation for text.
pub mod embedding;

/// Error types and handling.
pub mod error;

/// Evaluator for LLM providers.
pub mod evaluator;

/// Secret store for API keys and other sensitive information.
///
/// Only compiled for targets other than `wasm32`.
#[cfg(not(target_arch = "wasm32"))]
pub mod secret_store;

/// Support for listing the models a provider offers.
pub mod models;

// Private module (not `pub`); its contents are not visible from this file.
mod protocol;

// NOTE(review): `providers` is public but undocumented here — TODO add a `///` summary
// once its contents are confirmed.
pub mod providers;

// Re-export for convenience: presumably so downstream crates can apply `#[async_trait]`
// to their provider impls without a direct dependency on the `async_trait` crate.
pub use async_trait::async_trait;
53
54/// Core trait that all LLM providers must implement, combining chat, completion
55/// and embedding capabilities into a unified interface
56pub trait LLMProvider:
57    chat::ChatProvider
58    + completion::CompletionProvider
59    + embedding::EmbeddingProvider
60    + models::ModelsProvider
61    + Send
62    + Sync
63    + 'static
64{
65}
66
67/// Tool call represents a function call that an LLM wants to make.
68/// This is a standardized structure used across all providers.
69#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
70pub struct ToolCall {
71    /// The ID of the tool call.
72    pub id: String,
73    /// The type of the tool call (usually "function").
74    #[serde(rename = "type")]
75    pub call_type: String,
76    /// The function to call.
77    pub function: FunctionCall,
78}
79
80impl Display for ToolCall {
81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82        write!(
83            f,
84            "ToolCall {{ id: {}, type: {}, function: {:?} }}",
85            self.id, self.call_type, self.function
86        )
87    }
88}
89
90/// FunctionCall contains details about which function to call and with what arguments.
91#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
92pub struct FunctionCall {
93    /// The name of the function to call.
94    pub name: String,
95    /// The arguments to pass to the function, typically serialized as a JSON string.
96    pub arguments: String,
97}
98
/// Default value for the `call_type` field of [`ToolCall`]: always `"function"`.
pub fn default_call_type() -> String {
    String::from("function")
}
103
/// Unit tests for the tool-call data types and for the `LLMProvider` trait bundle,
/// the latter exercised through a mock provider that returns canned values.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::{ChatMessage, ChatProvider, ChatResponse, StructuredOutputFormat, Tool};
    use crate::completion::CompletionProvider;
    use crate::embedding::EmbeddingProvider;
    use crate::error::LLMError;
    use async_trait::async_trait;
    use serde_json::json;

    #[test]
    fn test_tool_call_creation() {
        let tool_call = ToolCall {
            id: "call_123".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "test_function".to_string(),
                arguments: "{\"param\": \"value\"}".to_string(),
            },
        };

        assert_eq!(tool_call.id, "call_123");
        assert_eq!(tool_call.call_type, "function");
        assert_eq!(tool_call.function.name, "test_function");
        assert_eq!(tool_call.function.arguments, "{\"param\": \"value\"}");
    }

    #[test]
    fn test_tool_call_serialization() {
        let tool_call = ToolCall {
            id: "call_456".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "serialize_test".to_string(),
                arguments: "{\"test\": true}".to_string(),
            },
        };

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.id, "call_456");
        assert_eq!(deserialized.call_type, "function");
        assert_eq!(deserialized.function.name, "serialize_test");
        assert_eq!(deserialized.function.arguments, "{\"test\": true}");
    }

    #[test]
    fn test_tool_call_serializes_type_key() {
        let tool_call = ToolCall {
            id: "wire_call".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "wire_function".to_string(),
                arguments: "{}".to_string(),
            },
        };

        // `call_type` must appear on the wire under the provider-facing key `type`.
        let value = serde_json::to_value(&tool_call).unwrap();
        assert_eq!(
            value,
            json!({
                "id": "wire_call",
                "type": "function",
                "function": {
                    "name": "wire_function",
                    "arguments": "{}"
                }
            })
        );
    }

    #[test]
    fn test_tool_call_display() {
        let tool_call = ToolCall {
            id: "display_call".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "display_function".to_string(),
                arguments: "{}".to_string(),
            },
        };

        // `Display` prints `id`/`type` verbatim and the nested function via `Debug`.
        assert_eq!(
            tool_call.to_string(),
            "ToolCall { id: display_call, type: function, function: FunctionCall { name: \"display_function\", arguments: \"{}\" } }"
        );
    }

    #[test]
    fn test_default_call_type() {
        assert_eq!(default_call_type(), "function");
    }

    #[test]
    fn test_tool_call_equality() {
        let tool_call1 = ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "equal_test".to_string(),
                arguments: "{}".to_string(),
            },
        };

        let tool_call2 = ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "equal_test".to_string(),
                arguments: "{}".to_string(),
            },
        };

        let tool_call3 = ToolCall {
            id: "call_2".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "equal_test".to_string(),
                arguments: "{}".to_string(),
            },
        };

        assert_eq!(tool_call1, tool_call2);
        assert_ne!(tool_call1, tool_call3);
    }

    #[test]
    fn test_tool_call_clone() {
        let tool_call = ToolCall {
            id: "clone_test".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "test_clone".to_string(),
                arguments: "{\"clone\": true}".to_string(),
            },
        };

        let cloned = tool_call.clone();
        assert_eq!(tool_call, cloned);
        assert_eq!(tool_call.id, cloned.id);
        assert_eq!(tool_call.function.name, cloned.function.name);
    }

    #[test]
    fn test_tool_call_debug() {
        let tool_call = ToolCall {
            id: "debug_test".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "debug_function".to_string(),
                arguments: "{}".to_string(),
            },
        };

        let debug_str = format!("{tool_call:?}");
        assert!(debug_str.contains("ToolCall"));
        assert!(debug_str.contains("debug_test"));
        assert!(debug_str.contains("debug_function"));
    }

    #[test]
    fn test_function_call_creation() {
        let function_call = FunctionCall {
            name: "test_function".to_string(),
            arguments: "{\"param1\": \"value1\", \"param2\": 42}".to_string(),
        };

        assert_eq!(function_call.name, "test_function");
        assert_eq!(
            function_call.arguments,
            "{\"param1\": \"value1\", \"param2\": 42}"
        );
    }

    #[test]
    fn test_function_call_serialization() {
        let function_call = FunctionCall {
            name: "serialize_function".to_string(),
            arguments: "{\"data\": [1, 2, 3]}".to_string(),
        };

        let serialized = serde_json::to_string(&function_call).unwrap();
        let deserialized: FunctionCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.name, "serialize_function");
        assert_eq!(deserialized.arguments, "{\"data\": [1, 2, 3]}");
    }

    #[test]
    fn test_function_call_equality() {
        let func1 = FunctionCall {
            name: "equal_func".to_string(),
            arguments: "{}".to_string(),
        };

        let func2 = FunctionCall {
            name: "equal_func".to_string(),
            arguments: "{}".to_string(),
        };

        let func3 = FunctionCall {
            name: "different_func".to_string(),
            arguments: "{}".to_string(),
        };

        assert_eq!(func1, func2);
        assert_ne!(func1, func3);
    }

    #[test]
    fn test_function_call_clone() {
        let function_call = FunctionCall {
            name: "clone_func".to_string(),
            arguments: "{\"clone\": \"test\"}".to_string(),
        };

        let cloned = function_call.clone();
        assert_eq!(function_call, cloned);
        assert_eq!(function_call.name, cloned.name);
        assert_eq!(function_call.arguments, cloned.arguments);
    }

    #[test]
    fn test_function_call_debug() {
        let function_call = FunctionCall {
            name: "debug_func".to_string(),
            arguments: "{}".to_string(),
        };

        let debug_str = format!("{function_call:?}");
        assert!(debug_str.contains("FunctionCall"));
        assert!(debug_str.contains("debug_func"));
    }

    #[test]
    fn test_tool_call_with_empty_values() {
        let tool_call = ToolCall {
            id: String::new(),
            call_type: String::new(),
            function: FunctionCall {
                name: String::new(),
                arguments: String::new(),
            },
        };

        assert!(tool_call.id.is_empty());
        assert!(tool_call.call_type.is_empty());
        assert!(tool_call.function.name.is_empty());
        assert!(tool_call.function.arguments.is_empty());
    }

    #[test]
    fn test_tool_call_with_complex_arguments() {
        let complex_args = json!({
            "nested": {
                "array": [1, 2, 3],
                "object": {
                    "key": "value"
                }
            },
            "simple": "string"
        });

        let tool_call = ToolCall {
            id: "complex_call".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "complex_function".to_string(),
                arguments: complex_args.to_string(),
            },
        };

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.id, "complex_call");
        assert_eq!(deserialized.function.name, "complex_function");
        // Arguments should be preserved as string
        assert!(deserialized.function.arguments.contains("nested"));
        assert!(deserialized.function.arguments.contains("array"));
    }

    #[test]
    fn test_tool_call_with_unicode() {
        let tool_call = ToolCall {
            id: "unicode_call".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "unicode_function".to_string(),
                arguments: "{\"message\": \"Hello 世界! 🌍\"}".to_string(),
            },
        };

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.id, "unicode_call");
        assert_eq!(deserialized.function.name, "unicode_function");
        assert!(deserialized.function.arguments.contains("Hello 世界! 🌍"));
    }

    #[test]
    fn test_tool_call_large_arguments() {
        let large_arg = "x".repeat(10000);
        let tool_call = ToolCall {
            id: "large_call".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "large_function".to_string(),
                arguments: format!("{{\"large_param\": \"{large_arg}\"}}"),
            },
        };

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.id, "large_call");
        assert_eq!(deserialized.function.name, "large_function");
        assert!(deserialized.function.arguments.len() > 10000);
    }

    // Mock provider whose methods return fixed values, so the trait plumbing can be
    // tested without any real backend.
    struct MockLLMProvider;

    #[async_trait]
    impl chat::ChatProvider for MockLLMProvider {
        async fn chat(
            &self,
            _messages: &[ChatMessage],
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".into()),
            }))
        }

        async fn chat_with_tools(
            &self,
            _messages: &[ChatMessage],
            _tools: Option<&[Tool]>,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".into()),
            }))
        }
    }

    #[async_trait]
    impl completion::CompletionProvider for MockLLMProvider {
        async fn complete(
            &self,
            _req: &completion::CompletionRequest,
            _json_schema: Option<chat::StructuredOutputFormat>,
        ) -> Result<completion::CompletionResponse, error::LLMError> {
            Ok(completion::CompletionResponse {
                text: "Mock completion".to_string(),
            })
        }
    }

    #[async_trait]
    impl embedding::EmbeddingProvider for MockLLMProvider {
        async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, error::LLMError> {
            // One deterministic 2-d vector per input string: `[i, i + 1]` for index `i`.
            Ok((0..input.len())
                .map(|i| vec![i as f32, (i + 1) as f32])
                .collect())
        }
    }

    #[async_trait]
    impl models::ModelsProvider for MockLLMProvider {}

    impl LLMProvider for MockLLMProvider {}

    // Canned chat response: optional text and never any tool calls.
    struct MockChatResponse {
        text: Option<String>,
    }

    impl chat::ChatResponse for MockChatResponse {
        fn text(&self) -> Option<String> {
            self.text.clone()
        }

        fn tool_calls(&self) -> Option<Vec<ToolCall>> {
            None
        }
    }

    impl std::fmt::Debug for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "MockChatResponse")
        }
    }

    impl std::fmt::Display for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.text.as_deref().unwrap_or(""))
        }
    }

    #[tokio::test]
    async fn test_llm_provider_trait_chat() {
        let provider = MockLLMProvider;
        let messages = vec![chat::ChatMessage::user().content("Test").build()];

        let response = provider.chat(&messages, None).await.unwrap();
        assert_eq!(response.text(), Some("Mock response".to_string()));
    }

    #[tokio::test]
    async fn test_llm_provider_trait_chat_with_tools() {
        let provider = MockLLMProvider;
        let messages = vec![chat::ChatMessage::user().content("Test").build()];

        let response = provider
            .chat_with_tools(&messages, None, None)
            .await
            .unwrap();
        assert_eq!(response.text(), Some("Mock response".to_string()));
        assert!(response.tool_calls().is_none());
    }

    #[tokio::test]
    async fn test_llm_provider_trait_completion() {
        let provider = MockLLMProvider;
        let request = completion::CompletionRequest::new("Test prompt");

        let response = provider.complete(&request, None).await.unwrap();
        assert_eq!(response.text, "Mock completion");
    }

    #[tokio::test]
    async fn test_llm_provider_trait_embedding() {
        let provider = MockLLMProvider;
        let input = vec!["First".to_string(), "Second".to_string()];

        let embeddings = provider.embed(input).await.unwrap();
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0], vec![0.0, 1.0]);
        assert_eq!(embeddings[1], vec![1.0, 2.0]);
    }
}
485        assert_eq!(embeddings.len(), 2);
486        assert_eq!(embeddings[0], vec![0.0, 1.0]);
487        assert_eq!(embeddings[1], vec![1.0, 2.0]);
488    }
489}