Skip to main content

autoagents_llm/
lib.rs

1//! AutoAgents LLM is a unified interface for interacting with Large Language Model providers.
2//!
3//! # Overview
4//! This crate provides a consistent API for working with different LLM backends by abstracting away
5//! provider-specific implementation details. It supports:
6//!
7//! - Chat-based interactions
8//! - Text completion
9//! - Embeddings generation
10//! - Multiple providers (OpenAI, Anthropic, etc.)
11//! - Request validation and retry logic
12//!
13//! # Architecture
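//! The crate is organized into modules, each handling a different aspect of LLM interactions
//! (chat, completion, embeddings, provider backends, and more); see the module declarations below.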
15
16use std::fmt::Display;
17
18use serde::{Deserialize, Serialize};
19
/// Backend implementations for supported LLM providers like OpenAI, Anthropic, etc.
pub mod backends;

/// Builder pattern for configuring and instantiating LLM providers.
pub mod builder;

/// Chat-based interactions with language models (e.g. ChatGPT style).
pub mod chat;

/// Text completion capabilities (e.g. GPT-3 style completion).
pub mod completion;

/// Vector embeddings generation for text.
pub mod embedding;

/// Error types and handling.
pub mod error;

/// Evaluator for LLM providers.
pub mod evaluator;

/// Secret store for storing API keys and other sensitive information.
///
/// Only compiled for non-WASM targets.
#[cfg(not(target_arch = "wasm32"))]
pub mod secret_store;

/// Support for listing the models a provider exposes.
pub mod models;

// Crate-private module; NOTE(review): contents not visible from this file — document once reviewed.
mod protocol;
// NOTE(review): public module with no doc comment; its contents are not visible from this file.
pub mod providers;

/// Composable optimization pipeline for LLM providers.
pub mod pipeline;

/// Built-in optimization passes (cache, etc.).
///
/// Only compiled for non-WASM targets with the `optim` cargo feature enabled.
#[cfg(all(not(target_arch = "wasm32"), feature = "optim"))]
pub mod optim;

// Re-export for convenience; presumably so implementors of the async provider
// traits can use `#[async_trait]` through this crate (TODO confirm intent).
pub use async_trait::async_trait;
60
/// Unit config for providers with no provider-specific options.
///
/// This is a fieldless (zero-sized) type. It derives `Copy`, `PartialEq`, `Eq`
/// and `Hash` in addition to `Debug`/`Default`/`Clone`, so it can be duplicated,
/// compared and used as a map/set key like any other trivially-empty value.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NoConfig;
64
/// Associates a provider type with the configuration type its builder accepts.
///
/// Implement this next to [`LLMProvider`] to expose provider-specific builder
/// options; providers with nothing extra to configure can use [`NoConfig`].
pub trait HasConfig {
    /// Configuration type for this provider.
    ///
    /// It must be constructible via `Default`, be `Send + Sync`, and contain no
    /// non-`'static` borrows.
    type Config: 'static + Send + Sync + Default;
}
73
74/// Core trait that all LLM providers must implement, combining chat, completion
75/// and embedding capabilities into a unified interface
76pub trait LLMProvider:
77    chat::ChatProvider
78    + completion::CompletionProvider
79    + embedding::EmbeddingProvider
80    + models::ModelsProvider
81    + Send
82    + Sync
83    + 'static
84{
85}
86
87/// Tool call represents a function call that an LLM wants to make.
88/// This is a standardized structure used across all providers.
89#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
90pub struct ToolCall {
91    /// The ID of the tool call.
92    pub id: String,
93    /// The type of the tool call (usually "function").
94    #[serde(rename = "type")]
95    pub call_type: String,
96    /// The function to call.
97    pub function: FunctionCall,
98}
99
100impl Display for ToolCall {
101    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102        write!(
103            f,
104            "ToolCall {{ id: {}, type: {}, function: {:?} }}",
105            self.id, self.call_type, self.function
106        )
107    }
108}
109
110/// FunctionCall contains details about which function to call and with what arguments.
111#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
112pub struct FunctionCall {
113    /// The name of the function to call.
114    pub name: String,
115    /// The arguments to pass to the function, typically serialized as a JSON string.
116    pub arguments: String,
117}
118
/// Default value for the `call_type` field of [`ToolCall`]: the string `"function"`.
pub fn default_call_type() -> String {
    String::from("function")
}
123
#[cfg(test)]
mod tests {
    //! Unit tests for the crate-root types ([`ToolCall`], [`FunctionCall`],
    //! [`default_call_type`]) and for the [`LLMProvider`] trait bundle, exercised
    //! through a mock provider.

    // `super::*` already brings `HasConfig`, `NoConfig`, `ToolCall`, etc. into scope.
    use super::*;
    use crate::chat::{ChatMessage, ChatProvider, ChatResponse, StructuredOutputFormat, Tool};
    use crate::completion::CompletionProvider;
    use crate::embedding::EmbeddingProvider;
    use crate::error::LLMError;
    use async_trait::async_trait;
    use serde_json::json;

    #[test]
    fn test_tool_call_creation() {
        let tool_call = ToolCall {
            id: "call_123".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "test_function".to_string(),
                arguments: "{\"param\": \"value\"}".to_string(),
            },
        };

        assert_eq!(tool_call.id, "call_123");
        assert_eq!(tool_call.call_type, "function");
        assert_eq!(tool_call.function.name, "test_function");
        assert_eq!(tool_call.function.arguments, "{\"param\": \"value\"}");
    }

    #[test]
    fn test_tool_call_serialization() {
        let tool_call = ToolCall {
            id: "call_456".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "serialize_test".to_string(),
                arguments: "{\"test\": true}".to_string(),
            },
        };

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.id, "call_456");
        assert_eq!(deserialized.call_type, "function");
        assert_eq!(deserialized.function.name, "serialize_test");
        assert_eq!(deserialized.function.arguments, "{\"test\": true}");
    }

    #[test]
    fn test_tool_call_serializes_call_type_as_type() {
        let tool_call = ToolCall {
            id: "rename_test".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "rename_function".to_string(),
                arguments: "{}".to_string(),
            },
        };

        // `call_type` is renamed to the wire key "type" via `#[serde(rename)]`.
        let value = serde_json::to_value(&tool_call).unwrap();
        assert_eq!(value["type"], json!("function"));
        assert!(value.get("call_type").is_none());
        assert_eq!(value["function"]["name"], json!("rename_function"));
    }

    #[test]
    fn test_tool_call_equality() {
        let tool_call1 = ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "equal_test".to_string(),
                arguments: "{}".to_string(),
            },
        };

        let tool_call2 = ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "equal_test".to_string(),
                arguments: "{}".to_string(),
            },
        };

        let tool_call3 = ToolCall {
            id: "call_2".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "equal_test".to_string(),
                arguments: "{}".to_string(),
            },
        };

        assert_eq!(tool_call1, tool_call2);
        assert_ne!(tool_call1, tool_call3);
    }

    #[test]
    fn test_tool_call_clone() {
        let tool_call = ToolCall {
            id: "clone_test".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "test_clone".to_string(),
                arguments: "{\"clone\": true}".to_string(),
            },
        };

        let cloned = tool_call.clone();
        assert_eq!(tool_call, cloned);
        assert_eq!(tool_call.id, cloned.id);
        assert_eq!(tool_call.function.name, cloned.function.name);
    }

    #[test]
    fn test_tool_call_debug() {
        let tool_call = ToolCall {
            id: "debug_test".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "debug_function".to_string(),
                arguments: "{}".to_string(),
            },
        };

        let debug_str = format!("{tool_call:?}");
        assert!(debug_str.contains("ToolCall"));
        assert!(debug_str.contains("debug_test"));
        assert!(debug_str.contains("debug_function"));
    }

    #[test]
    fn test_tool_call_display() {
        let tool_call = ToolCall {
            id: "display_test".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "display_function".to_string(),
                arguments: "{}".to_string(),
            },
        };

        // `Display` prints id/type verbatim and the nested FunctionCall via `Debug`.
        assert_eq!(
            tool_call.to_string(),
            "ToolCall { id: display_test, type: function, function: FunctionCall { name: \"display_function\", arguments: \"{}\" } }"
        );
    }

    #[test]
    fn test_default_call_type() {
        assert_eq!(default_call_type(), "function");
    }

    #[test]
    fn test_function_call_creation() {
        let function_call = FunctionCall {
            name: "test_function".to_string(),
            arguments: "{\"param1\": \"value1\", \"param2\": 42}".to_string(),
        };

        assert_eq!(function_call.name, "test_function");
        assert_eq!(
            function_call.arguments,
            "{\"param1\": \"value1\", \"param2\": 42}"
        );
    }

    #[test]
    fn test_function_call_serialization() {
        let function_call = FunctionCall {
            name: "serialize_function".to_string(),
            arguments: "{\"data\": [1, 2, 3]}".to_string(),
        };

        let serialized = serde_json::to_string(&function_call).unwrap();
        let deserialized: FunctionCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.name, "serialize_function");
        assert_eq!(deserialized.arguments, "{\"data\": [1, 2, 3]}");
    }

    #[test]
    fn test_function_call_equality() {
        let func1 = FunctionCall {
            name: "equal_func".to_string(),
            arguments: "{}".to_string(),
        };

        let func2 = FunctionCall {
            name: "equal_func".to_string(),
            arguments: "{}".to_string(),
        };

        let func3 = FunctionCall {
            name: "different_func".to_string(),
            arguments: "{}".to_string(),
        };

        assert_eq!(func1, func2);
        assert_ne!(func1, func3);
    }

    #[test]
    fn test_function_call_clone() {
        let function_call = FunctionCall {
            name: "clone_func".to_string(),
            arguments: "{\"clone\": \"test\"}".to_string(),
        };

        let cloned = function_call.clone();
        assert_eq!(function_call, cloned);
        assert_eq!(function_call.name, cloned.name);
        assert_eq!(function_call.arguments, cloned.arguments);
    }

    #[test]
    fn test_function_call_debug() {
        let function_call = FunctionCall {
            name: "debug_func".to_string(),
            arguments: "{}".to_string(),
        };

        let debug_str = format!("{function_call:?}");
        assert!(debug_str.contains("FunctionCall"));
        assert!(debug_str.contains("debug_func"));
    }

    #[test]
    fn test_tool_call_with_empty_values() {
        let tool_call = ToolCall {
            id: String::default(),
            call_type: String::default(),
            function: FunctionCall {
                name: String::default(),
                arguments: String::default(),
            },
        };

        assert!(tool_call.id.is_empty());
        assert!(tool_call.call_type.is_empty());
        assert!(tool_call.function.name.is_empty());
        assert!(tool_call.function.arguments.is_empty());
    }

    #[test]
    fn test_tool_call_with_complex_arguments() {
        let complex_args = json!({
            "nested": {
                "array": [1, 2, 3],
                "object": {
                    "key": "value"
                }
            },
            "simple": "string"
        });

        let tool_call = ToolCall {
            id: "complex_call".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "complex_function".to_string(),
                arguments: complex_args.to_string(),
            },
        };

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized, tool_call);
        // Arguments are preserved as an opaque string that still parses to the same JSON.
        let reparsed: serde_json::Value =
            serde_json::from_str(&deserialized.function.arguments).unwrap();
        assert_eq!(reparsed, complex_args);
    }

    #[test]
    fn test_tool_call_with_unicode() {
        let tool_call = ToolCall {
            id: "unicode_call".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "unicode_function".to_string(),
                arguments: "{\"message\": \"Hello 世界! 🌍\"}".to_string(),
            },
        };

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized, tool_call);
        assert!(deserialized.function.arguments.contains("Hello 世界! 🌍"));
    }

    #[test]
    fn test_tool_call_large_arguments() {
        let large_arg = "x".repeat(10000);
        let tool_call = ToolCall {
            id: "large_call".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "large_function".to_string(),
                arguments: format!("{{\"large_param\": \"{large_arg}\"}}"),
            },
        };

        let serialized = serde_json::to_string(&tool_call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&serialized).unwrap();

        // Full equality proves the large payload survived the round trip byte-for-byte.
        assert_eq!(deserialized, tool_call);
        assert!(deserialized.function.arguments.contains(&large_arg));
    }

    // Mock LLM provider for testing
    struct MockLLMProvider;

    #[async_trait]
    impl chat::ChatProvider for MockLLMProvider {
        async fn chat(
            &self,
            _messages: &[ChatMessage],
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".into()),
            }))
        }

        async fn chat_with_tools(
            &self,
            _messages: &[ChatMessage],
            _tools: Option<&[Tool]>,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".into()),
            }))
        }
    }

    #[async_trait]
    impl completion::CompletionProvider for MockLLMProvider {
        async fn complete(
            &self,
            _req: &completion::CompletionRequest,
            _json_schema: Option<chat::StructuredOutputFormat>,
        ) -> Result<completion::CompletionResponse, error::LLMError> {
            Ok(completion::CompletionResponse {
                text: "Mock completion".to_string(),
            })
        }
    }

    #[async_trait]
    impl embedding::EmbeddingProvider for MockLLMProvider {
        async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, error::LLMError> {
            // One deterministic two-dimensional vector per input: [i, i + 1].
            Ok((0..input.len())
                .map(|i| vec![i as f32, (i + 1) as f32])
                .collect())
        }
    }

    #[async_trait]
    impl models::ModelsProvider for MockLLMProvider {}

    impl LLMProvider for MockLLMProvider {}

    impl HasConfig for MockLLMProvider {
        type Config = NoConfig;
    }

    struct MockChatResponse {
        text: Option<String>,
    }

    impl chat::ChatResponse for MockChatResponse {
        fn text(&self) -> Option<String> {
            self.text.clone()
        }

        fn tool_calls(&self) -> Option<Vec<ToolCall>> {
            None
        }
    }

    impl std::fmt::Debug for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "MockChatResponse")
        }
    }

    impl std::fmt::Display for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.text.as_deref().unwrap_or(""))
        }
    }

    #[tokio::test]
    async fn test_llm_provider_trait_chat() {
        let provider = MockLLMProvider;
        let messages = vec![chat::ChatMessage::user().content("Test").build()];

        let response = provider.chat(&messages, None).await.unwrap();
        assert_eq!(response.text(), Some("Mock response".to_string()));
    }

    #[tokio::test]
    async fn test_llm_provider_trait_completion() {
        let provider = MockLLMProvider;
        let request = completion::CompletionRequest::new("Test prompt");

        let response = provider.complete(&request, None).await.unwrap();
        assert_eq!(response.text, "Mock completion");
    }

    #[tokio::test]
    async fn test_llm_provider_trait_embedding() {
        let provider = MockLLMProvider;
        let input = vec!["First".to_string(), "Second".to_string()];

        let embeddings = provider.embed(input).await.unwrap();
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0], vec![0.0, 1.0]);
        assert_eq!(embeddings[1], vec![1.0, 2.0]);
    }
}