Skip to main content

autoagents_llm/
lib.rs

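//! AutoAgents LLM is a unified interface for interacting with Large Language Model providers.
//!
//! # Overview
//! This crate provides a consistent API for working with different LLM backends by abstracting away
//! provider-specific implementation details. It supports:
//!
//! - Chat-based interactions
//! - Text completion
//! - Embeddings generation
//! - Multiple providers (OpenAI, Anthropic, etc.)
//! - Request validation and retry logic
//!
//! # Architecture
//! The crate is organized into modules that handle different aspects of LLM interactions:
//!
//! - [`backends`]: implementations for the supported LLM providers
//! - [`builder`]: builder for configuring and instantiating providers
//! - [`chat`], [`completion`], [`embedding`], [`models`]: the capability traits that
//!   [`LLMProvider`] combines (chat, text completion, embeddings, model listing)
//! - [`error`]: error types
//! - [`evaluator`]: evaluation of LLM providers
//! - `secret_store` (not available on `wasm32`): storage for API keys and other secrets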
15
16use std::fmt::Display;
17
18use serde::{Deserialize, Serialize};
19
/// Backend implementations for supported LLM providers like OpenAI, Anthropic, etc.
pub mod backends;

/// Builder pattern for configuring and instantiating LLM providers
pub mod builder;

/// Chat-based interactions with language models (e.g. ChatGPT style)
pub mod chat;

/// Text completion capabilities (e.g. GPT-3 style completion)
pub mod completion;

/// Vector embeddings generation for text
pub mod embedding;

/// Error types and handling
pub mod error;

/// Evaluator for LLM providers
pub mod evaluator;

/// Secret store for storing API keys and other sensitive information.
///
/// Not compiled for `wasm32` targets.
#[cfg(not(target_arch = "wasm32"))]
pub mod secret_store;

/// Support for listing the models a provider offers
pub mod models;

// Crate-private module; not part of the public API.
mod protocol;
// TODO(review): add a `///` doc comment describing what this public module exposes.
pub mod providers;

/// Re-exported for convenience so implementors of the provider traits can write
/// `#[async_trait]` through this crate.
pub use async_trait::async_trait;
53
/// Unit config for providers with no provider-specific options.
///
/// Use it as the [`HasConfig::Config`] type for providers that have no special
/// builder options. As a zero-sized unit struct it is trivially `Copy`,
/// comparable and hashable.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NoConfig;
57
/// Associates an LLM provider with the type holding its builder-level options.
///
/// Implement this next to [`LLMProvider`] to expose provider-specific builder
/// options; providers without any should pick [`NoConfig`].
pub trait HasConfig {
    /// The provider's configuration type. It must offer a `Default` value and be
    /// `Send + Sync + 'static`.
    type Config: 'static + Send + Sync + Default;
}
66
67/// Core trait that all LLM providers must implement, combining chat, completion
68/// and embedding capabilities into a unified interface
69pub trait LLMProvider:
70    chat::ChatProvider
71    + completion::CompletionProvider
72    + embedding::EmbeddingProvider
73    + models::ModelsProvider
74    + Send
75    + Sync
76    + 'static
77{
78}
79
80/// Tool call represents a function call that an LLM wants to make.
81/// This is a standardized structure used across all providers.
82#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
83pub struct ToolCall {
84    /// The ID of the tool call.
85    pub id: String,
86    /// The type of the tool call (usually "function").
87    #[serde(rename = "type")]
88    pub call_type: String,
89    /// The function to call.
90    pub function: FunctionCall,
91}
92
93impl Display for ToolCall {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        write!(
96            f,
97            "ToolCall {{ id: {}, type: {}, function: {:?} }}",
98            self.id, self.call_type, self.function
99        )
100    }
101}
102
103/// FunctionCall contains details about which function to call and with what arguments.
104#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
105pub struct FunctionCall {
106    /// The name of the function to call.
107    pub name: String,
108    /// The arguments to pass to the function, typically serialized as a JSON string.
109    pub arguments: String,
110}
111
/// Returns `"function"`, the default value for the `call_type` field of [`ToolCall`].
///
/// Exposed as a function so it can be referenced by path, e.g. as a serde
/// `default = "..."` attribute.
pub fn default_call_type() -> String {
    String::from("function")
}
116
#[cfg(test)]
mod tests {
    // Unit tests for the crate-root types (`ToolCall`, `FunctionCall`, `NoConfig`,
    // `default_call_type`) plus a mock provider exercising the `LLMProvider` trait family.
    use super::*;
    use crate::chat::{ChatMessage, ChatProvider, ChatResponse, StructuredOutputFormat, Tool};
    use crate::completion::CompletionProvider;
    use crate::embedding::EmbeddingProvider;
    use crate::error::LLMError;
    use async_trait::async_trait;
    use serde_json::json;

    #[test]
    fn test_tool_call_creation() {
        let tool_call = ToolCall {
            id: "call_123".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "test_function".to_string(),
                arguments: "{\"param\": \"value\"}".to_string(),
            },
        };

        assert_eq!(tool_call.id, "call_123");
        assert_eq!(tool_call.call_type, "function");
        assert_eq!(tool_call.function.name, "test_function");
        assert_eq!(tool_call.function.arguments, "{\"param\": \"value\"}");
    }

    #[test]
    fn test_tool_call_serialization() {
        let tool_call = ToolCall {
            id: "call_456".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "serialize_test".to_string(),
                arguments: "{\"test\": true}".to_string(),
            },
        };

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.id, "call_456");
        assert_eq!(deserialized.call_type, "function");
        assert_eq!(deserialized.function.name, "serialize_test");
        assert_eq!(deserialized.function.arguments, "{\"test\": true}");
    }

    #[test]
    fn test_tool_call_equality() {
        let tool_call1 = ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "equal_test".to_string(),
                arguments: "{}".to_string(),
            },
        };

        let tool_call2 = ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "equal_test".to_string(),
                arguments: "{}".to_string(),
            },
        };

        let tool_call3 = ToolCall {
            id: "call_2".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "equal_test".to_string(),
                arguments: "{}".to_string(),
            },
        };

        assert_eq!(tool_call1, tool_call2);
        assert_ne!(tool_call1, tool_call3);
    }

    #[test]
    fn test_tool_call_clone() {
        let tool_call = ToolCall {
            id: "clone_test".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "test_clone".to_string(),
                arguments: "{\"clone\": true}".to_string(),
            },
        };

        let cloned = tool_call.clone();
        assert_eq!(tool_call, cloned);
        assert_eq!(tool_call.id, cloned.id);
        assert_eq!(tool_call.function.name, cloned.function.name);
    }

    #[test]
    fn test_tool_call_debug() {
        let tool_call = ToolCall {
            id: "debug_test".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "debug_function".to_string(),
                arguments: "{}".to_string(),
            },
        };

        let debug_str = format!("{tool_call:?}");
        assert!(debug_str.contains("ToolCall"));
        assert!(debug_str.contains("debug_test"));
        assert!(debug_str.contains("debug_function"));
    }

    #[test]
    fn test_tool_call_display() {
        let tool_call = ToolCall {
            id: "display_test".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "display_function".to_string(),
                arguments: "{}".to_string(),
            },
        };

        // Display embeds the Debug form of the nested FunctionCall.
        assert_eq!(
            tool_call.to_string(),
            "ToolCall { id: display_test, type: function, function: FunctionCall { name: \"display_function\", arguments: \"{}\" } }"
        );
    }

    #[test]
    fn test_default_call_type() {
        assert_eq!(default_call_type(), "function");
    }

    #[test]
    fn test_function_call_creation() {
        let function_call = FunctionCall {
            name: "test_function".to_string(),
            arguments: "{\"param1\": \"value1\", \"param2\": 42}".to_string(),
        };

        assert_eq!(function_call.name, "test_function");
        assert_eq!(
            function_call.arguments,
            "{\"param1\": \"value1\", \"param2\": 42}"
        );
    }

    #[test]
    fn test_function_call_serialization() {
        let function_call = FunctionCall {
            name: "serialize_function".to_string(),
            arguments: "{\"data\": [1, 2, 3]}".to_string(),
        };

        let serialized = serde_json::to_string(&function_call).unwrap();
        let deserialized: FunctionCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.name, "serialize_function");
        assert_eq!(deserialized.arguments, "{\"data\": [1, 2, 3]}");
    }

    #[test]
    fn test_function_call_equality() {
        let func1 = FunctionCall {
            name: "equal_func".to_string(),
            arguments: "{}".to_string(),
        };

        let func2 = FunctionCall {
            name: "equal_func".to_string(),
            arguments: "{}".to_string(),
        };

        let func3 = FunctionCall {
            name: "different_func".to_string(),
            arguments: "{}".to_string(),
        };

        assert_eq!(func1, func2);
        assert_ne!(func1, func3);
    }

    #[test]
    fn test_function_call_clone() {
        let function_call = FunctionCall {
            name: "clone_func".to_string(),
            arguments: "{\"clone\": \"test\"}".to_string(),
        };

        let cloned = function_call.clone();
        assert_eq!(function_call, cloned);
        assert_eq!(function_call.name, cloned.name);
        assert_eq!(function_call.arguments, cloned.arguments);
    }

    #[test]
    fn test_function_call_debug() {
        let function_call = FunctionCall {
            name: "debug_func".to_string(),
            arguments: "{}".to_string(),
        };

        let debug_str = format!("{function_call:?}");
        assert!(debug_str.contains("FunctionCall"));
        assert!(debug_str.contains("debug_func"));
    }

    #[test]
    fn test_tool_call_with_empty_values() {
        let tool_call = ToolCall {
            id: String::default(),
            call_type: String::default(),
            function: FunctionCall {
                name: String::default(),
                arguments: String::default(),
            },
        };

        assert!(tool_call.id.is_empty());
        assert!(tool_call.call_type.is_empty());
        assert!(tool_call.function.name.is_empty());
        assert!(tool_call.function.arguments.is_empty());
    }

    #[test]
    fn test_tool_call_with_complex_arguments() {
        let complex_args = json!({
            "nested": {
                "array": [1, 2, 3],
                "object": {
                    "key": "value"
                }
            },
            "simple": "string"
        });

        let tool_call = ToolCall {
            id: "complex_call".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "complex_function".to_string(),
                arguments: complex_args.to_string(),
            },
        };

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.id, "complex_call");
        assert_eq!(deserialized.function.name, "complex_function");
        // Arguments should be preserved as string
        assert!(deserialized.function.arguments.contains("nested"));
        assert!(deserialized.function.arguments.contains("array"));
    }

    #[test]
    fn test_tool_call_with_unicode() {
        let tool_call = ToolCall {
            id: "unicode_call".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "unicode_function".to_string(),
                arguments: "{\"message\": \"Hello 世界! 🌍\"}".to_string(),
            },
        };

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.id, "unicode_call");
        assert_eq!(deserialized.function.name, "unicode_function");
        assert!(deserialized.function.arguments.contains("Hello 世界! 🌍"));
    }

    #[test]
    fn test_tool_call_large_arguments() {
        let large_arg = "x".repeat(10000);
        let tool_call = ToolCall {
            id: "large_call".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "large_function".to_string(),
                arguments: format!("{{\"large_param\": \"{large_arg}\"}}"),
            },
        };

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.id, "large_call");
        assert_eq!(deserialized.function.name, "large_function");
        assert!(deserialized.function.arguments.len() > 10000);
    }

    // Mock LLM provider for testing
    struct MockLLMProvider;

    #[async_trait]
    impl chat::ChatProvider for MockLLMProvider {
        async fn chat(
            &self,
            _messages: &[ChatMessage],
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".into()),
            }))
        }

        async fn chat_with_tools(
            &self,
            _messages: &[ChatMessage],
            _tools: Option<&[Tool]>,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".into()),
            }))
        }
    }

    #[async_trait]
    impl completion::CompletionProvider for MockLLMProvider {
        async fn complete(
            &self,
            _req: &completion::CompletionRequest,
            _json_schema: Option<chat::StructuredOutputFormat>,
        ) -> Result<completion::CompletionResponse, error::LLMError> {
            Ok(completion::CompletionResponse {
                text: "Mock completion".to_string(),
            })
        }
    }

    #[async_trait]
    impl embedding::EmbeddingProvider for MockLLMProvider {
        async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, error::LLMError> {
            // Deterministic fake vectors: the i-th input maps to [i, i + 1].
            Ok((0..input.len())
                .map(|i| vec![i as f32, (i + 1) as f32])
                .collect())
        }
    }

    #[async_trait]
    impl models::ModelsProvider for MockLLMProvider {}

    impl LLMProvider for MockLLMProvider {}

    impl HasConfig for MockLLMProvider {
        type Config = NoConfig;
    }

    struct MockChatResponse {
        text: Option<String>,
    }

    impl chat::ChatResponse for MockChatResponse {
        fn text(&self) -> Option<String> {
            self.text.clone()
        }

        fn tool_calls(&self) -> Option<Vec<ToolCall>> {
            None
        }
    }

    impl std::fmt::Debug for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "MockChatResponse")
        }
    }

    impl std::fmt::Display for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.text.as_deref().unwrap_or(""))
        }
    }

    #[test]
    fn test_mock_provider_config_is_no_config() {
        // Compile-time check that the mock's associated config type is `NoConfig`.
        let _config: NoConfig = <MockLLMProvider as HasConfig>::Config::default();
    }

    #[tokio::test]
    async fn test_llm_provider_trait_chat() {
        let provider = MockLLMProvider;
        let messages = vec![chat::ChatMessage::user().content("Test").build()];

        let response = provider.chat(&messages, None).await.unwrap();
        assert_eq!(response.text(), Some("Mock response".to_string()));
    }

    #[tokio::test]
    async fn test_llm_provider_trait_completion() {
        let provider = MockLLMProvider;
        let request = completion::CompletionRequest::new("Test prompt");

        let response = provider.complete(&request, None).await.unwrap();
        assert_eq!(response.text, "Mock completion");
    }

    #[tokio::test]
    async fn test_llm_provider_trait_embedding() {
        let provider = MockLLMProvider;
        let input = vec!["First".to_string(), "Second".to_string()];

        let embeddings = provider.embed(input).await.unwrap();
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0], vec![0.0, 1.0]);
        assert_eq!(embeddings[1], vec![1.0, 2.0]);
    }
}