aix-core 0.1.0

Core abstractions and types for the AIX library
Documentation
//! Core traits and abstractions for AI providers.
//!
//! This module defines the `AiProvider` trait that all providers must implement,
//! along with related types and capabilities.

use crate::error::{AixError, AixResult};
use crate::types::{ChatRequest, ChatResponse};
use crate::streaming::TokenStream;
use async_trait::async_trait;

/// Capabilities of a model/provider.
///
/// Describes which optional features a provider offers (streaming, tool
/// calling, image inputs) together with its token limits. Use one of the
/// preset constructors ([`basic_text`](Self::basic_text),
/// [`streaming_text`](Self::streaming_text),
/// [`full_featured`](Self::full_featured)) or [`new`](Self::new) for a custom mix.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelCapabilities {
    /// Whether the provider supports streaming responses
    pub supports_streaming: bool,
    /// Whether the provider supports function/tool calling
    pub supports_function_calling: bool,
    /// Whether the provider supports vision/image inputs
    pub supports_vision: bool,
    /// Maximum number of tokens that can be generated
    pub max_tokens: u32,
    /// Maximum context window size (prompt + completion)
    pub max_context_window: u32,
}

impl ModelCapabilities {
    /// Build a capabilities value from explicit feature flags and limits.
    pub fn new(
        supports_streaming: bool,
        supports_function_calling: bool,
        supports_vision: bool,
        max_tokens: u32,
        max_context_window: u32,
    ) -> Self {
        Self {
            supports_streaming,
            supports_function_calling,
            supports_vision,
            max_tokens,
            max_context_window,
        }
    }

    /// Text-only preset: no streaming, no tool calling, no vision.
    pub fn basic_text(max_tokens: u32, max_context_window: u32) -> Self {
        Self {
            supports_streaming: false,
            supports_function_calling: false,
            supports_vision: false,
            max_tokens,
            max_context_window,
        }
    }

    /// Every feature enabled: streaming, tool calling and vision.
    pub fn full_featured(max_tokens: u32, max_context_window: u32) -> Self {
        // Start from the streaming preset and switch on the remaining features.
        Self {
            supports_function_calling: true,
            supports_vision: true,
            ..Self::streaming_text(max_tokens, max_context_window)
        }
    }

    /// Streaming text preset: streaming only, no tool calling, no vision.
    pub fn streaming_text(max_tokens: u32, max_context_window: u32) -> Self {
        // Identical to the basic text preset apart from the streaming flag.
        Self {
            supports_streaming: true,
            ..Self::basic_text(max_tokens, max_context_window)
        }
    }
}

/// Core trait that all AI providers must implement.
///
/// This trait provides a unified interface for interacting with different
/// AI providers (OpenAI, Anthropic, etc.) while allowing each provider
/// to handle its own specifics internally.
///
/// Implementors must supply [`chat`](AiProvider::chat),
/// [`chat_stream`](AiProvider::chat_stream),
/// [`provider_name`](AiProvider::provider_name) and
/// [`capabilities`](AiProvider::capabilities); every other method has a default
/// implementation built on top of those.
#[async_trait]
pub trait AiProvider: Send + Sync {
    /// Execute a chat completion request.
    ///
    /// # Arguments
    /// * `request` - The chat completion request
    ///
    /// # Returns
    /// A `ChatResponse` containing the generated completion
    ///
    /// # Errors
    /// Returns an `AixError` if the request fails
    async fn chat(&self, request: ChatRequest) -> AixResult<ChatResponse>;

    /// Execute a streaming chat completion request.
    ///
    /// # Arguments
    /// * `request` - The chat completion request with streaming enabled
    ///
    /// # Returns
    /// A `TokenStream` that yields `StreamChunk` items as they are generated
    ///
    /// # Errors
    /// Returns an `AixError` if the stream cannot be established
    async fn chat_stream(&self, request: ChatRequest) -> AixResult<TokenStream>;

    /// Get the name of this provider.
    ///
    /// # Returns
    /// A string slice containing the provider name (e.g., "openai", "anthropic")
    fn provider_name(&self) -> &str;

    /// Get the capabilities of this provider.
    ///
    /// # Returns
    /// A `ModelCapabilities` struct describing what this provider supports
    fn capabilities(&self) -> ModelCapabilities;

    /// Check if this provider supports streaming.
    ///
    /// # Returns
    /// `true` if streaming is supported, `false` otherwise
    fn supports_streaming(&self) -> bool {
        self.capabilities().supports_streaming
    }

    /// Check if this provider supports function calling.
    ///
    /// # Returns
    /// `true` if function calling is supported, `false` otherwise
    fn supports_function_calling(&self) -> bool {
        self.capabilities().supports_function_calling
    }

    /// Check if this provider supports vision/image inputs.
    ///
    /// # Returns
    /// `true` if vision is supported, `false` otherwise
    fn supports_vision(&self) -> bool {
        self.capabilities().supports_vision
    }

    /// Get the maximum number of tokens this provider can generate.
    ///
    /// # Returns
    /// The maximum number of tokens for a single completion
    fn max_tokens(&self) -> u32 {
        self.capabilities().max_tokens
    }

    /// Get the maximum context window size.
    ///
    /// # Returns
    /// The maximum number of tokens (prompt + completion) that can be processed
    fn max_context_window(&self) -> u32 {
        self.capabilities().max_context_window
    }

    /// Validate a chat request before sending it.
    ///
    /// This method allows providers to perform provider-specific validation.
    /// The default implementation performs basic validation that applies to all providers.
    ///
    /// # Arguments
    /// * `request` - The chat request to validate
    ///
    /// # Returns
    /// `Ok(())` if the request is valid
    ///
    /// # Errors
    /// Returns an `AixError::config` error when:
    /// * the model name is empty,
    /// * the request has no messages,
    /// * `config.max_tokens` is set and exceeds [`max_tokens`](AiProvider::max_tokens),
    /// * any message has empty content (the error names its 1-based position).
    fn validate_request(&self, request: &ChatRequest) -> AixResult<()> {
        // Basic validation that applies to all providers
        if request.model.is_empty() {
            return Err(AixError::config("Model name cannot be empty"));
        }

        if request.messages.is_empty() {
            return Err(AixError::config("Messages cannot be empty"));
        }

        // Check that we don't exceed the max tokens if specified.
        // The limit is read once so `capabilities()` is not cloned twice.
        if let Some(requested) = request.config.max_tokens {
            let limit = self.max_tokens();
            if requested > limit {
                return Err(AixError::config(format!(
                    "Requested max_tokens ({}) exceeds provider limit ({})",
                    requested, limit
                )));
            }
        }

        // Reject the first message whose content is empty; report it 1-based.
        if let Some(index) = request
            .messages
            .iter()
            .position(|message| message.content.is_empty())
        {
            return Err(AixError::config(format!(
                "Message {} has empty content",
                index + 1
            )));
        }

        Ok(())
    }

    /// Estimate the number of tokens in a request.
    ///
    /// This is a rough estimate and should not be relied upon for exact token counting.
    /// Different providers may use different tokenization methods, so providers
    /// should override this with their own tokenizer when one is available.
    ///
    /// The default sums the `len()` of every message's content and assumes
    /// roughly four units per token. For `String` content `len()` is the UTF-8
    /// byte length, so non-ASCII text is counted more heavily than its character
    /// count. The result saturates at `u32::MAX` instead of silently truncating.
    ///
    /// # Arguments
    /// * `request` - The chat request to estimate tokens for
    ///
    /// # Returns
    /// An estimated token count
    fn estimate_tokens(&self, request: &ChatRequest) -> u32 {
        let total_len: usize = request.messages.iter().map(|m| m.content.len()).sum();
        // `as u32` would wrap for enormous inputs; clamp instead.
        u32::try_from(total_len / 4).unwrap_or(u32::MAX)
    }

    /// Check if a request is likely to fit within the context window.
    ///
    /// Adds the estimated prompt size to the requested completion budget
    /// (`config.max_tokens`, or the provider's [`max_tokens`](AiProvider::max_tokens)
    /// when unset) and compares the total with the context window.
    ///
    /// # Arguments
    /// * `request` - The chat request to check
    ///
    /// # Returns
    /// `true` if the request is likely to fit, `false` otherwise
    fn fits_in_context(&self, request: &ChatRequest) -> bool {
        let prompt_tokens = self.estimate_tokens(request);
        // Lazily fall back so `capabilities()` is only queried when needed.
        let completion_tokens = request
            .config
            .max_tokens
            .unwrap_or_else(|| self.max_tokens());
        // Widen to u64: the u32 sum can overflow (panic in debug, wrap to a
        // small value in release and wrongly report that the request fits).
        u64::from(prompt_tokens) + u64::from(completion_tokens)
            <= u64::from(self.max_context_window())
    }
}

/// Extension trait for `AiProvider` that provides convenience methods.
///
/// Every `AiProvider` receives these methods through a blanket implementation,
/// so this trait is not meant to be implemented by hand.
///
/// The methods return `impl Future + Send` rather than being declared as
/// `async fn`. An `async fn` in a public trait cannot promise that its future
/// is `Send`, which would prevent callers from spawning it on a multi-threaded
/// runtime and trips the `async_fn_in_trait` lint.
pub trait AiProviderExt: AiProvider {
    /// Execute a simple chat request with a single user message.
    ///
    /// # Arguments
    /// * `model` - The model to use
    /// * `message` - The user message
    ///
    /// # Returns
    /// A `Send` future resolving to a `ChatResponse` containing the generated completion
    ///
    /// # Errors
    /// The future resolves to an `AixError` if the request fails
    fn chat_simple<S: Into<String>, M: Into<String>>(
        &self,
        model: S,
        message: M,
    ) -> impl std::future::Future<Output = AixResult<ChatResponse>> + Send {
        // Build the request up front so the returned future is exactly the
        // provider's `#[async_trait]` boxed future, which is `Send`; `S` and `M`
        // are consumed here and never need to be `Send` themselves.
        let request = crate::types::ChatRequest::simple(model, message);
        self.chat(request)
    }

    /// Execute a streaming chat request with a single user message.
    ///
    /// # Arguments
    /// * `model` - The model to use
    /// * `message` - The user message
    ///
    /// # Returns
    /// A `Send` future resolving to a `TokenStream` that yields `StreamChunk`
    /// items as they are generated
    ///
    /// # Errors
    /// The future resolves to an `AixError` if the stream cannot be established
    fn chat_stream_simple<S: Into<String>, M: Into<String>>(
        &self,
        model: S,
        message: M,
    ) -> impl std::future::Future<Output = AixResult<TokenStream>> + Send {
        // Same eager construction as `chat_simple`, with streaming enabled.
        let request = crate::types::ChatRequest::new(model)
            .message(crate::types::ChatMessage::user(message))
            .stream(true)
            .build();
        self.chat_stream(request)
    }
}

// Blanket implementation for every type that implements `AiProvider`.
// `?Sized` extends it to trait objects (`dyn AiProvider`), so a
// `Box<dyn AiProvider>` or `&dyn AiProvider` can also call the convenience methods.
impl<T: AiProvider + ?Sized> AiProviderExt for T {}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ChatMessage, ModelConfig};
    use crate::streaming::TokenStream;

    /// Mock provider for testing: returns a canned response and an empty stream.
    struct MockProvider {
        name: String,
        capabilities: ModelCapabilities,
    }

    impl MockProvider {
        /// Build a mock named "test" with the given capabilities.
        fn with_capabilities(capabilities: ModelCapabilities) -> Self {
            Self {
                name: "test".to_string(),
                capabilities,
            }
        }
    }

    #[async_trait]
    impl AiProvider for MockProvider {
        async fn chat(&self, _request: ChatRequest) -> AixResult<ChatResponse> {
            Ok(ChatResponse::new(
                "test-id",
                "test-model",
                "Test response",
                crate::types::Role::Assistant,
                crate::types::Usage::new(10, 20),
            ))
        }

        async fn chat_stream(&self, _request: ChatRequest) -> AixResult<TokenStream> {
            // Return an empty stream for testing
            Ok(crate::streaming::from_iter(std::iter::empty()))
        }

        fn provider_name(&self) -> &str {
            &self.name
        }

        fn capabilities(&self) -> ModelCapabilities {
            self.capabilities.clone()
        }
    }

    /// Build a request for "test-model" with explicit messages and an optional
    /// `max_tokens` override on top of the default config.
    fn request_with(messages: Vec<ChatMessage>, max_tokens: Option<u32>) -> ChatRequest {
        let mut config = ModelConfig::default();
        config.max_tokens = max_tokens;
        ChatRequest {
            model: "test-model".to_string(),
            messages,
            config,
            stream: false,
        }
    }

    #[test]
    fn test_provider_capabilities() {
        let provider = MockProvider::with_capabilities(ModelCapabilities::full_featured(4096, 8192));

        assert!(provider.supports_streaming());
        assert!(provider.supports_function_calling());
        assert!(provider.supports_vision());
        assert_eq!(provider.max_tokens(), 4096);
        assert_eq!(provider.max_context_window(), 8192);
    }

    #[test]
    fn test_provider_validation() {
        let provider = MockProvider::with_capabilities(ModelCapabilities::basic_text(4096, 8192));

        // Valid request should pass
        let valid_request = ChatRequest::simple("test-model", "Hello, world!");
        assert!(provider.validate_request(&valid_request).is_ok());

        // Empty model should fail
        let invalid_request = ChatRequest {
            model: String::new(),
            messages: vec![ChatMessage::user("Hello")],
            config: ModelConfig::default(),
            stream: false,
        };
        assert!(provider.validate_request(&invalid_request).is_err());

        // Empty messages should fail
        let empty_messages_request = ChatRequest {
            model: "test-model".to_string(),
            messages: vec![],
            config: ModelConfig::default(),
            stream: false,
        };
        assert!(provider.validate_request(&empty_messages_request).is_err());
    }

    #[test]
    fn test_validation_enforces_max_tokens_limit() {
        let provider = MockProvider::with_capabilities(ModelCapabilities::basic_text(4096, 8192));

        // Exactly at the limit is allowed; one above is rejected.
        let at_limit = request_with(vec![ChatMessage::user("Hello")], Some(4096));
        assert!(provider.validate_request(&at_limit).is_ok());

        let over_limit = request_with(vec![ChatMessage::user("Hello")], Some(4097));
        assert!(provider.validate_request(&over_limit).is_err());
    }

    #[test]
    fn test_validation_rejects_empty_message_content() {
        let provider = MockProvider::with_capabilities(ModelCapabilities::basic_text(4096, 8192));

        let only_empty = request_with(vec![ChatMessage::user("")], None);
        assert!(provider.validate_request(&only_empty).is_err());

        // An empty message anywhere in the list is rejected.
        let second_empty = request_with(
            vec![ChatMessage::user("Hello"), ChatMessage::user("")],
            None,
        );
        assert!(provider.validate_request(&second_empty).is_err());
    }

    #[test]
    fn test_estimate_tokens_and_fits_in_context() {
        let provider = MockProvider::with_capabilities(ModelCapabilities::basic_text(4096, 8192));

        // 8 bytes of content at ~4 per token.
        let request = request_with(vec![ChatMessage::user("12345678")], None);
        assert_eq!(provider.estimate_tokens(&request), 2);
        // 2 prompt + 4096 default completion <= 8192.
        assert!(provider.fits_in_context(&request));

        // 2 prompt + 8191 requested completion = 8193 > 8192.
        let too_big = request_with(vec![ChatMessage::user("12345678")], Some(8191));
        assert!(!provider.fits_in_context(&too_big));
    }

    #[tokio::test]
    async fn test_provider_extension_methods() {
        let provider = MockProvider::with_capabilities(ModelCapabilities::basic_text(4096, 8192));

        let response = provider
            .chat_simple("test-model", "Hello, world!")
            .await
            .unwrap();
        assert_eq!(response.content, "Test response");

        let stream = provider
            .chat_stream_simple("test-model", "Hello, world!")
            .await
            .unwrap();
        // Stream should be valid (empty in this mock case)
        drop(stream);
    }

    #[test]
    fn test_capabilities_constructors() {
        let basic = ModelCapabilities::basic_text(2048, 4096);
        assert!(!basic.supports_streaming);
        assert!(!basic.supports_function_calling);
        assert!(!basic.supports_vision);

        let full = ModelCapabilities::full_featured(4096, 8192);
        assert!(full.supports_streaming);
        assert!(full.supports_function_calling);
        assert!(full.supports_vision);

        let streaming = ModelCapabilities::streaming_text(2048, 4096);
        assert!(streaming.supports_streaming);
        assert!(!streaming.supports_function_calling);
        assert!(!streaming.supports_vision);
    }
}