use crate::error::{AixError, AixResult};
use crate::types::{ChatRequest, ChatResponse};
use crate::streaming::TokenStream;
use async_trait::async_trait;
/// Feature flags and token limits describing a model/provider.
///
/// `AiProvider::capabilities` returns this value, and the trait's default
/// accessor methods read their answers from its fields.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelCapabilities {
/// Flag surfaced by `AiProvider::supports_streaming`.
pub supports_streaming: bool,
/// Flag surfaced by `AiProvider::supports_function_calling`.
pub supports_function_calling: bool,
/// Flag surfaced by `AiProvider::supports_vision`.
pub supports_vision: bool,
/// Limit that `AiProvider::validate_request` enforces on a request's
/// `config.max_tokens`; also the fallback completion budget used by
/// `AiProvider::fits_in_context` when the request sets none.
pub max_tokens: u32,
/// Upper bound that `AiProvider::fits_in_context` compares the estimated
/// prompt tokens plus the completion budget against.
pub max_context_window: u32,
}
impl ModelCapabilities {
pub fn new(
supports_streaming: bool,
supports_function_calling: bool,
supports_vision: bool,
max_tokens: u32,
max_context_window: u32,
) -> Self {
Self {
supports_streaming,
supports_function_calling,
supports_vision,
max_tokens,
max_context_window,
}
}
pub fn basic_text(max_tokens: u32, max_context_window: u32) -> Self {
Self::new(false, false, false, max_tokens, max_context_window)
}
pub fn full_featured(max_tokens: u32, max_context_window: u32) -> Self {
Self::new(true, true, true, max_tokens, max_context_window)
}
pub fn streaming_text(max_tokens: u32, max_context_window: u32) -> Self {
Self::new(true, false, false, max_tokens, max_context_window)
}
}
/// Interface implemented by every AI backend.
///
/// Implementors provide `chat`, `chat_stream`, `provider_name` and
/// `capabilities`. All remaining methods have default implementations that are
/// derived from `capabilities()` and may be overridden.
#[async_trait]
pub trait AiProvider: Send + Sync {
    /// Performs a chat request and resolves to a `ChatResponse`.
    async fn chat(&self, request: ChatRequest) -> AixResult<ChatResponse>;

    /// Performs a chat request and resolves to a `TokenStream`.
    async fn chat_stream(&self, request: ChatRequest) -> AixResult<TokenStream>;

    /// Name identifying this provider.
    fn provider_name(&self) -> &str;

    /// Feature flags and limits of this provider.
    fn capabilities(&self) -> ModelCapabilities;

    /// Shortcut for `capabilities().supports_streaming`.
    fn supports_streaming(&self) -> bool {
        self.capabilities().supports_streaming
    }

    /// Shortcut for `capabilities().supports_function_calling`.
    fn supports_function_calling(&self) -> bool {
        self.capabilities().supports_function_calling
    }

    /// Shortcut for `capabilities().supports_vision`.
    fn supports_vision(&self) -> bool {
        self.capabilities().supports_vision
    }

    /// Shortcut for `capabilities().max_tokens`.
    fn max_tokens(&self) -> u32 {
        self.capabilities().max_tokens
    }

    /// Shortcut for `capabilities().max_context_window`.
    fn max_context_window(&self) -> u32 {
        self.capabilities().max_context_window
    }

    /// Checks that `request` is well-formed for this provider.
    ///
    /// # Errors
    ///
    /// Returns an `AixError::config` error for the first failing check, in
    /// this order:
    /// 1. the model name is empty,
    /// 2. the message list is empty,
    /// 3. `config.max_tokens` is set and exceeds `self.max_tokens()`,
    /// 4. some message has empty content (reported with its 1-based position).
    fn validate_request(&self, request: &ChatRequest) -> AixResult<()> {
        if request.model.is_empty() {
            return Err(AixError::config("Model name cannot be empty"));
        }
        if request.messages.is_empty() {
            return Err(AixError::config("Messages cannot be empty"));
        }
        if let Some(requested) = request.config.max_tokens {
            // Query the limit once; the default accessor clones the whole
            // capability set on every call.
            let limit = self.max_tokens();
            if requested > limit {
                return Err(AixError::config(format!(
                    "Requested max_tokens ({}) exceeds provider limit ({})",
                    requested, limit
                )));
            }
        }
        if let Some(index) = request
            .messages
            .iter()
            .position(|message| message.content.is_empty())
        {
            return Err(AixError::config(format!(
                "Message {} has empty content",
                index + 1
            )));
        }
        Ok(())
    }

    /// Rough token estimate: the summed `content.len()` of all messages
    /// divided by four (integer division).
    ///
    /// The result saturates at `u32::MAX` instead of silently truncating when
    /// the quotient does not fit in a `u32`.
    fn estimate_tokens(&self, request: &ChatRequest) -> u32 {
        let total_len: usize = request
            .messages
            .iter()
            .map(|message| message.content.len())
            .sum();
        u32::try_from(total_len / 4).unwrap_or(u32::MAX)
    }

    /// Returns `true` when the estimated prompt tokens plus the completion
    /// budget do not exceed `max_context_window()`.
    ///
    /// The completion budget is the request's `config.max_tokens` when set,
    /// otherwise `self.max_tokens()`.
    fn fits_in_context(&self, request: &ChatRequest) -> bool {
        let prompt_tokens = u64::from(self.estimate_tokens(request));
        let completion_tokens = u64::from(
            request
                .config
                .max_tokens
                .unwrap_or_else(|| self.max_tokens()),
        );
        // Sum in u64: two u32 values cannot overflow it, whereas a u32 sum
        // would panic in debug builds and wrap (reporting a false "fits") in
        // release builds.
        prompt_tokens + completion_tokens <= u64::from(self.max_context_window())
    }
}
pub trait AiProviderExt: AiProvider {
async fn chat_simple<S: Into<String>, M: Into<String>>(
&self,
model: S,
message: M,
) -> AixResult<ChatResponse> {
let request = crate::types::ChatRequest::simple(model, message);
self.chat(request).await
}
async fn chat_stream_simple<S: Into<String>, M: Into<String>>(
&self,
model: S,
message: M,
) -> AixResult<TokenStream> {
let request = crate::types::ChatRequest::new(model)
.message(crate::types::ChatMessage::user(message))
.stream(true)
.build();
self.chat_stream(request).await
}
}
/// Blanket implementation: every `AiProvider` gets the `AiProviderExt` helpers.
impl<T> AiProviderExt for T where T: AiProvider {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::TokenStream;
    use crate::types::{ChatMessage, ModelConfig};

    /// Provider stub that returns canned data and reports the capabilities it
    /// was constructed with.
    struct MockProvider {
        name: String,
        capabilities: ModelCapabilities,
    }

    /// Builds a mock named "test" with the given capabilities.
    fn mock_with(capabilities: ModelCapabilities) -> MockProvider {
        MockProvider {
            name: String::from("test"),
            capabilities,
        }
    }

    #[async_trait]
    impl AiProvider for MockProvider {
        async fn chat(&self, _request: ChatRequest) -> AixResult<ChatResponse> {
            let canned = ChatResponse::new(
                "test-id",
                "test-model",
                "Test response",
                crate::types::Role::Assistant,
                crate::types::Usage::new(10, 20),
            );
            Ok(canned)
        }

        async fn chat_stream(&self, _request: ChatRequest) -> AixResult<TokenStream> {
            Ok(crate::streaming::from_iter(std::iter::empty()))
        }

        fn provider_name(&self) -> &str {
            self.name.as_str()
        }

        fn capabilities(&self) -> ModelCapabilities {
            self.capabilities.clone()
        }
    }

    /// The default accessors must mirror the values in `capabilities()`.
    #[tokio::test]
    async fn test_provider_capabilities() {
        let mock = mock_with(ModelCapabilities::full_featured(4096, 8192));

        assert!(mock.supports_streaming());
        assert!(mock.supports_function_calling());
        assert!(mock.supports_vision());
        assert_eq!(mock.max_tokens(), 4096);
        assert_eq!(mock.max_context_window(), 8192);
    }

    /// A well-formed request passes; an empty model name or an empty message
    /// list is rejected.
    #[tokio::test]
    async fn test_provider_validation() {
        let mock = mock_with(ModelCapabilities::basic_text(4096, 8192));

        let well_formed = ChatRequest::simple("test-model", "Hello, world!");
        assert!(mock.validate_request(&well_formed).is_ok());

        let missing_model = ChatRequest {
            model: String::new(),
            messages: vec![ChatMessage::user("Hello")],
            config: ModelConfig::default(),
            stream: false,
        };
        assert!(mock.validate_request(&missing_model).is_err());

        let no_messages = ChatRequest {
            model: String::from("test-model"),
            messages: vec![],
            config: ModelConfig::default(),
            stream: false,
        };
        assert!(mock.validate_request(&no_messages).is_err());
    }

    /// The extension helpers must reach the mock's `chat` / `chat_stream`.
    #[tokio::test]
    async fn test_provider_extension_methods() {
        let mock = mock_with(ModelCapabilities::basic_text(4096, 8192));

        let reply = mock
            .chat_simple("test-model", "Hello, world!")
            .await
            .unwrap();
        assert_eq!(reply.content, "Test response");

        let token_stream = mock
            .chat_stream_simple("test-model", "Hello, world!")
            .await
            .unwrap();
        drop(token_stream);
    }

    /// Each preset constructor must set exactly its advertised feature flags.
    #[test]
    fn test_capabilities_constructors() {
        let text_only = ModelCapabilities::basic_text(2048, 4096);
        assert!(!text_only.supports_streaming);
        assert!(!text_only.supports_function_calling);
        assert!(!text_only.supports_vision);

        let everything = ModelCapabilities::full_featured(4096, 8192);
        assert!(everything.supports_streaming);
        assert!(everything.supports_function_calling);
        assert!(everything.supports_vision);

        let stream_only = ModelCapabilities::streaming_text(2048, 4096);
        assert!(stream_only.supports_streaming);
        assert!(!stream_only.supports_function_calling);
        assert!(!stream_only.supports_vision);
    }
}