Skip to main content

aix_core/
traits.rs

1//! Core traits and abstractions for AI providers.
2//!
3//! This module defines the `AiProvider` trait that all providers must implement,
4//! along with related types and capabilities.
5
6use crate::error::{AixError, AixResult};
7use crate::types::{ChatRequest, ChatResponse};
8use crate::streaming::TokenStream;
9use async_trait::async_trait;
10
/// Feature flags and token limits advertised by a model/provider.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelCapabilities {
    /// `true` when the provider can stream responses
    pub supports_streaming: bool,
    /// `true` when the provider offers function/tool calling
    pub supports_function_calling: bool,
    /// `true` when the provider accepts vision/image inputs
    pub supports_vision: bool,
    /// Upper bound on the number of tokens that can be generated
    pub max_tokens: u32,
    /// Upper bound on the context window (prompt + completion)
    pub max_context_window: u32,
}

impl ModelCapabilities {
    /// Build a capabilities value from explicit feature flags and limits.
    ///
    /// Each argument is stored verbatim in the field of the same name.
    pub fn new(
        supports_streaming: bool,
        supports_function_calling: bool,
        supports_vision: bool,
        max_tokens: u32,
        max_context_window: u32,
    ) -> Self {
        ModelCapabilities {
            supports_streaming,
            supports_function_calling,
            supports_vision,
            max_tokens,
            max_context_window,
        }
    }

    /// Preset for a plain text model: every feature flag is off.
    pub fn basic_text(max_tokens: u32, max_context_window: u32) -> Self {
        ModelCapabilities {
            supports_streaming: false,
            supports_function_calling: false,
            supports_vision: false,
            max_tokens,
            max_context_window,
        }
    }

    /// Preset for a full-featured model: every feature flag is on.
    pub fn full_featured(max_tokens: u32, max_context_window: u32) -> Self {
        ModelCapabilities {
            supports_streaming: true,
            supports_function_calling: true,
            supports_vision: true,
            max_tokens,
            max_context_window,
        }
    }

    /// Preset for a text model that streams but has no function calling or vision.
    pub fn streaming_text(max_tokens: u32, max_context_window: u32) -> Self {
        ModelCapabilities {
            supports_streaming: true,
            supports_function_calling: false,
            supports_vision: false,
            max_tokens,
            max_context_window,
        }
    }
}
59
60/// Core trait that all AI providers must implement.
61///
62/// This trait provides a unified interface for interacting with different
63/// AI providers (OpenAI, Anthropic, etc.) while allowing each provider
64/// to handle its own specifics internally.
65#[async_trait]
66pub trait AiProvider: Send + Sync {
67    /// Execute a chat completion request.
68    ///
69    /// # Arguments
70    /// * `request` - The chat completion request
71    ///
72    /// # Returns
73    /// A `ChatResponse` containing the generated completion
74    ///
75    /// # Errors
76    /// Returns an `AixError` if the request fails
77    async fn chat(&self, request: ChatRequest) -> AixResult<ChatResponse>;
78
79    /// Execute a streaming chat completion request.
80    ///
81    /// # Arguments
82    /// * `request` - The chat completion request with streaming enabled
83    ///
84    /// # Returns
85    /// A `TokenStream` that yields `StreamChunk` items as they are generated
86    ///
87    /// # Errors
88    /// Returns an `AixError` if the stream cannot be established
89    async fn chat_stream(&self, request: ChatRequest) -> AixResult<TokenStream>;
90
91    /// Get the name of this provider.
92    ///
93    /// # Returns
94    /// A string slice containing the provider name (e.g., "openai", "anthropic")
95    fn provider_name(&self) -> &str;
96
97    /// Get the capabilities of this provider.
98    ///
99    /// # Returns
100    /// A `ModelCapabilities` struct describing what this provider supports
101    fn capabilities(&self) -> ModelCapabilities;
102
103    /// Check if this provider supports streaming.
104    ///
105    /// # Returns
106    /// `true` if streaming is supported, `false` otherwise
107    fn supports_streaming(&self) -> bool {
108        self.capabilities().supports_streaming
109    }
110
111    /// Check if this provider supports function calling.
112    ///
113    /// # Returns
114    /// `true` if function calling is supported, `false` otherwise
115    fn supports_function_calling(&self) -> bool {
116        self.capabilities().supports_function_calling
117    }
118
119    /// Check if this provider supports vision/image inputs.
120    ///
121    /// # Returns
122    /// `true` if vision is supported, `false` otherwise
123    fn supports_vision(&self) -> bool {
124        self.capabilities().supports_vision
125    }
126
127    /// Get the maximum number of tokens this provider can generate.
128    ///
129    /// # Returns
130    /// The maximum number of tokens for a single completion
131    fn max_tokens(&self) -> u32 {
132        self.capabilities().max_tokens
133    }
134
135    /// Get the maximum context window size.
136    ///
137    /// # Returns
138    /// The maximum number of tokens (prompt + completion) that can be processed
139    fn max_context_window(&self) -> u32 {
140        self.capabilities().max_context_window
141    }
142
143    /// Validate a chat request before sending it.
144    ///
145    /// This method allows providers to perform provider-specific validation.
146    /// The default implementation performs basic validation that applies to all providers.
147    ///
148    /// # Arguments
149    /// * `request` - The chat request to validate
150    ///
151    /// # Returns
152    /// `Ok(())` if the request is valid, or an `AixError` if validation fails
153    fn validate_request(&self, request: &ChatRequest) -> AixResult<()> {
154        // Basic validation that applies to all providers
155        if request.model.is_empty() {
156            return Err(AixError::config("Model name cannot be empty"));
157        }
158
159        if request.messages.is_empty() {
160            return Err(AixError::config("Messages cannot be empty"));
161        }
162
163        // Check that we don't exceed the max tokens if specified
164        if let Some(max_tokens) = request.config.max_tokens {
165            if max_tokens > self.max_tokens() {
166                return Err(AixError::config(format!(
167                    "Requested max_tokens ({}) exceeds provider limit ({})",
168                    max_tokens,
169                    self.max_tokens()
170                )));
171            }
172        }
173
174        // Validate message content
175        for (i, message) in request.messages.iter().enumerate() {
176            if message.content.is_empty() {
177                return Err(AixError::config(format!(
178                    "Message {} has empty content",
179                    i + 1
180                )));
181            }
182        }
183
184        Ok(())
185    }
186
187    /// Estimate the number of tokens in a request.
188    ///
189    /// This is a rough estimate and should not be relied upon for exact token counting.
190    /// Different providers may use different tokenization methods.
191    ///
192    /// # Arguments
193    /// * `request` - The chat request to estimate tokens for
194    ///
195    /// # Returns
196    /// An estimated token count
197    fn estimate_tokens(&self, request: &ChatRequest) -> u32 {
198        // Simple estimation: roughly 4 characters per token
199        // This is a very rough estimate and providers should override this
200        // with their own tokenization if available
201        let total_chars: usize = request.messages.iter().map(|m| m.content.len()).sum();
202        (total_chars / 4) as u32
203    }
204
205    /// Check if a request is likely to fit within the context window.
206    ///
207    /// # Arguments
208    /// * `request` - The chat request to check
209    ///
210    /// # Returns
211    /// `true` if the request is likely to fit, `false` otherwise
212    fn fits_in_context(&self, request: &ChatRequest) -> bool {
213        let estimated_tokens = self.estimate_tokens(request);
214        let max_completion_tokens = request.config.max_tokens.unwrap_or(self.max_tokens());
215        estimated_tokens + max_completion_tokens <= self.max_context_window()
216    }
217}
218
219/// Extension trait for `AiProvider` that provides convenience methods.
220pub trait AiProviderExt: AiProvider {
221    /// Execute a simple chat request with a single user message.
222    ///
223    /// # Arguments
224    /// * `model` - The model to use
225    /// * `message` - The user message
226    ///
227    /// # Returns
228    /// A `ChatResponse` containing the generated completion
229    ///
230    /// # Errors
231    /// Returns an `AixError` if the request fails
232    async fn chat_simple<S: Into<String>, M: Into<String>>(
233        &self,
234        model: S,
235        message: M,
236    ) -> AixResult<ChatResponse> {
237        let request = crate::types::ChatRequest::simple(model, message);
238        self.chat(request).await
239    }
240
241    /// Execute a streaming chat request with a single user message.
242    ///
243    /// # Arguments
244    /// * `model` - The model to use
245    /// * `message` - The user message
246    ///
247    /// # Returns
248    /// A `TokenStream` that yields `StreamChunk` items as they are generated
249    ///
250    /// # Errors
251    /// Returns an `AixError` if the stream cannot be established
252    async fn chat_stream_simple<S: Into<String>, M: Into<String>>(
253        &self,
254        model: S,
255        message: M,
256    ) -> AixResult<TokenStream> {
257        let request = crate::types::ChatRequest::new(model)
258            .message(crate::types::ChatMessage::user(message))
259            .stream(true)
260            .build();
261        self.chat_stream(request).await
262    }
263}
264
265// Blanket implementation for all types that implement AiProvider
266impl<T: AiProvider> AiProviderExt for T {}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::TokenStream;
    use crate::types::{ChatMessage, ModelConfig};

    /// Test double: returns a canned response and an empty stream, and
    /// reports whatever capabilities it was constructed with.
    struct MockProvider {
        name: String,
        capabilities: ModelCapabilities,
    }

    impl MockProvider {
        /// Build a mock named "test" that advertises `capabilities`.
        fn with_capabilities(capabilities: ModelCapabilities) -> Self {
            Self {
                name: "test".to_string(),
                capabilities,
            }
        }
    }

    #[async_trait]
    impl AiProvider for MockProvider {
        async fn chat(&self, _request: ChatRequest) -> AixResult<ChatResponse> {
            let canned = ChatResponse::new(
                "test-id",
                "test-model",
                "Test response",
                crate::types::Role::Assistant,
                crate::types::Usage::new(10, 20),
            );
            Ok(canned)
        }

        async fn chat_stream(&self, _request: ChatRequest) -> AixResult<TokenStream> {
            // An empty stream is enough for these tests.
            let nothing = std::iter::empty();
            Ok(crate::streaming::from_iter(nothing))
        }

        fn provider_name(&self) -> &str {
            &self.name
        }

        fn capabilities(&self) -> ModelCapabilities {
            self.capabilities.clone()
        }
    }

    /// The default accessor methods mirror the fields of `capabilities()`.
    #[tokio::test]
    async fn test_provider_capabilities() {
        let provider =
            MockProvider::with_capabilities(ModelCapabilities::full_featured(4096, 8192));

        assert!(provider.supports_streaming());
        assert!(provider.supports_function_calling());
        assert!(provider.supports_vision());
        assert_eq!(provider.max_tokens(), 4096);
        assert_eq!(provider.max_context_window(), 8192);
    }

    /// Default validation accepts a simple request and rejects an empty
    /// model name or an empty message list.
    #[tokio::test]
    async fn test_provider_validation() {
        let provider = MockProvider::with_capabilities(ModelCapabilities::basic_text(4096, 8192));

        // A well-formed request is accepted.
        let well_formed = ChatRequest::simple("test-model", "Hello, world!");
        assert!(provider.validate_request(&well_formed).is_ok());

        // A blank model name is rejected.
        let blank_model = ChatRequest {
            model: String::new(),
            messages: vec![ChatMessage::user("Hello")],
            config: ModelConfig::default(),
            stream: false,
        };
        assert!(provider.validate_request(&blank_model).is_err());

        // A request without messages is rejected.
        let no_messages = ChatRequest {
            model: "test-model".to_string(),
            messages: vec![],
            config: ModelConfig::default(),
            stream: false,
        };
        assert!(provider.validate_request(&no_messages).is_err());
    }

    /// The blanket `AiProviderExt` helpers reach the mock's `chat` and
    /// `chat_stream` implementations.
    #[tokio::test]
    async fn test_provider_extension_methods() {
        let provider = MockProvider::with_capabilities(ModelCapabilities::basic_text(4096, 8192));

        let reply = provider
            .chat_simple("test-model", "Hello, world!")
            .await
            .unwrap();
        assert_eq!(reply.content, "Test response");

        // Establishing the stream must succeed; the mock's stream is empty.
        let token_stream = provider
            .chat_stream_simple("test-model", "Hello, world!")
            .await
            .unwrap();
        drop(token_stream);
    }

    /// Each preset constructor sets the expected combination of flags.
    #[test]
    fn test_capabilities_constructors() {
        let basic = ModelCapabilities::basic_text(2048, 4096);
        assert!(!basic.supports_streaming);
        assert!(!basic.supports_function_calling);
        assert!(!basic.supports_vision);

        let full = ModelCapabilities::full_featured(4096, 8192);
        assert!(full.supports_streaming);
        assert!(full.supports_function_calling);
        assert!(full.supports_vision);

        let streaming = ModelCapabilities::streaming_text(2048, 4096);
        assert!(streaming.supports_streaming);
        assert!(!streaming.supports_function_calling);
        assert!(!streaming.supports_vision);
    }
}