1use crate::error::{AixError, AixResult};
7use crate::types::{ChatRequest, ChatResponse};
8use crate::streaming::TokenStream;
9use async_trait::async_trait;
10
/// Feature flags and token limits describing what a provider's model can do.
///
/// Returned by [`AiProvider::capabilities`]; the trait's default accessor
/// methods and request checks read their values from here.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ModelCapabilities {
    /// Whether the model supports streamed responses.
    pub supports_streaming: bool,
    /// Whether the model supports function calling.
    pub supports_function_calling: bool,
    /// Whether the model supports vision inputs.
    pub supports_vision: bool,
    /// Largest `max_tokens` value a request may ask for.
    pub max_tokens: u32,
    /// Size of the context window, in tokens.
    pub max_context_window: u32,
}

impl ModelCapabilities {
    /// Creates a capability set from explicit feature flags and limits.
    pub fn new(
        supports_streaming: bool,
        supports_function_calling: bool,
        supports_vision: bool,
        max_tokens: u32,
        max_context_window: u32,
    ) -> Self {
        ModelCapabilities {
            supports_streaming,
            supports_function_calling,
            supports_vision,
            max_tokens,
            max_context_window,
        }
    }

    /// Text-only model: streaming, function calling and vision all disabled.
    pub fn basic_text(max_tokens: u32, max_context_window: u32) -> Self {
        ModelCapabilities {
            supports_streaming: false,
            supports_function_calling: false,
            supports_vision: false,
            max_tokens,
            max_context_window,
        }
    }

    /// Model with every optional feature enabled.
    pub fn full_featured(max_tokens: u32, max_context_window: u32) -> Self {
        ModelCapabilities {
            supports_streaming: true,
            supports_function_calling: true,
            supports_vision: true,
            ..Self::basic_text(max_tokens, max_context_window)
        }
    }

    /// Text model that can stream, without function calling or vision.
    pub fn streaming_text(max_tokens: u32, max_context_window: u32) -> Self {
        ModelCapabilities {
            supports_streaming: true,
            ..Self::basic_text(max_tokens, max_context_window)
        }
    }
}
59
60#[async_trait]
66pub trait AiProvider: Send + Sync {
67 async fn chat(&self, request: ChatRequest) -> AixResult<ChatResponse>;
78
79 async fn chat_stream(&self, request: ChatRequest) -> AixResult<TokenStream>;
90
91 fn provider_name(&self) -> &str;
96
97 fn capabilities(&self) -> ModelCapabilities;
102
103 fn supports_streaming(&self) -> bool {
108 self.capabilities().supports_streaming
109 }
110
111 fn supports_function_calling(&self) -> bool {
116 self.capabilities().supports_function_calling
117 }
118
119 fn supports_vision(&self) -> bool {
124 self.capabilities().supports_vision
125 }
126
127 fn max_tokens(&self) -> u32 {
132 self.capabilities().max_tokens
133 }
134
135 fn max_context_window(&self) -> u32 {
140 self.capabilities().max_context_window
141 }
142
143 fn validate_request(&self, request: &ChatRequest) -> AixResult<()> {
154 if request.model.is_empty() {
156 return Err(AixError::config("Model name cannot be empty"));
157 }
158
159 if request.messages.is_empty() {
160 return Err(AixError::config("Messages cannot be empty"));
161 }
162
163 if let Some(max_tokens) = request.config.max_tokens {
165 if max_tokens > self.max_tokens() {
166 return Err(AixError::config(format!(
167 "Requested max_tokens ({}) exceeds provider limit ({})",
168 max_tokens,
169 self.max_tokens()
170 )));
171 }
172 }
173
174 for (i, message) in request.messages.iter().enumerate() {
176 if message.content.is_empty() {
177 return Err(AixError::config(format!(
178 "Message {} has empty content",
179 i + 1
180 )));
181 }
182 }
183
184 Ok(())
185 }
186
187 fn estimate_tokens(&self, request: &ChatRequest) -> u32 {
198 let total_chars: usize = request.messages.iter().map(|m| m.content.len()).sum();
202 (total_chars / 4) as u32
203 }
204
205 fn fits_in_context(&self, request: &ChatRequest) -> bool {
213 let estimated_tokens = self.estimate_tokens(request);
214 let max_completion_tokens = request.config.max_tokens.unwrap_or(self.max_tokens());
215 estimated_tokens + max_completion_tokens <= self.max_context_window()
216 }
217}
218
219pub trait AiProviderExt: AiProvider {
221 async fn chat_simple<S: Into<String>, M: Into<String>>(
233 &self,
234 model: S,
235 message: M,
236 ) -> AixResult<ChatResponse> {
237 let request = crate::types::ChatRequest::simple(model, message);
238 self.chat(request).await
239 }
240
241 async fn chat_stream_simple<S: Into<String>, M: Into<String>>(
253 &self,
254 model: S,
255 message: M,
256 ) -> AixResult<TokenStream> {
257 let request = crate::types::ChatRequest::new(model)
258 .message(crate::types::ChatMessage::user(message))
259 .stream(true)
260 .build();
261 self.chat_stream(request).await
262 }
263}
264
265impl<T: AiProvider> AiProviderExt for T {}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ChatMessage, ModelConfig};

    /// Provider stub returning canned data; only its capabilities vary per test.
    struct MockProvider {
        name: String,
        capabilities: ModelCapabilities,
    }

    impl MockProvider {
        fn new(capabilities: ModelCapabilities) -> Self {
            Self {
                name: "test".to_string(),
                capabilities,
            }
        }
    }

    #[async_trait]
    impl AiProvider for MockProvider {
        async fn chat(&self, _request: ChatRequest) -> AixResult<ChatResponse> {
            Ok(ChatResponse::new(
                "test-id",
                "test-model",
                "Test response",
                crate::types::Role::Assistant,
                crate::types::Usage::new(10, 20),
            ))
        }

        async fn chat_stream(&self, _request: ChatRequest) -> AixResult<TokenStream> {
            Ok(crate::streaming::from_iter(std::iter::empty()))
        }

        fn provider_name(&self) -> &str {
            &self.name
        }

        fn capabilities(&self) -> ModelCapabilities {
            self.capabilities.clone()
        }
    }

    /// Builds a non-streaming request with the default model configuration.
    fn request(model: &str, messages: Vec<ChatMessage>) -> ChatRequest {
        ChatRequest {
            model: model.to_string(),
            messages,
            config: ModelConfig::default(),
            stream: false,
        }
    }

    #[test]
    fn test_provider_capabilities() {
        let provider = MockProvider::new(ModelCapabilities::full_featured(4096, 8192));

        assert!(provider.supports_streaming());
        assert!(provider.supports_function_calling());
        assert!(provider.supports_vision());
        assert_eq!(provider.max_tokens(), 4096);
        assert_eq!(provider.max_context_window(), 8192);
    }

    #[test]
    fn test_provider_validation() {
        let provider = MockProvider::new(ModelCapabilities::basic_text(4096, 8192));

        let valid_request = ChatRequest::simple("test-model", "Hello, world!");
        assert!(provider.validate_request(&valid_request).is_ok());

        let empty_model = request("", vec![ChatMessage::user("Hello")]);
        assert!(provider.validate_request(&empty_model).is_err());

        let empty_messages = request("test-model", vec![]);
        assert!(provider.validate_request(&empty_messages).is_err());

        let empty_content = request(
            "test-model",
            vec![ChatMessage::user("Hello"), ChatMessage::user("")],
        );
        assert!(provider.validate_request(&empty_content).is_err());

        let mut over_limit = request("test-model", vec![ChatMessage::user("Hello")]);
        over_limit.config.max_tokens = Some(4097);
        assert!(provider.validate_request(&over_limit).is_err());
    }

    #[test]
    fn test_token_estimate_and_context_fit() {
        let provider = MockProvider::new(ModelCapabilities::basic_text(4096, 8192));

        // 4 + 8 = 12 bytes of content -> 12 / 4 = 3 estimated tokens.
        let two_messages = request(
            "test-model",
            vec![ChatMessage::user("abcd"), ChatMessage::user("abcdefgh")],
        );
        assert_eq!(provider.estimate_tokens(&two_messages), 3);

        // 8 bytes -> 2 tokens; 2 + 8190 exactly fills the 8192-token window.
        let mut exact_fit = request("test-model", vec![ChatMessage::user("abcdefgh")]);
        exact_fit.config.max_tokens = Some(8190);
        assert!(provider.fits_in_context(&exact_fit));

        // One more completion token no longer fits.
        let mut one_over = request("test-model", vec![ChatMessage::user("abcdefgh")]);
        one_over.config.max_tokens = Some(8191);
        assert!(!provider.fits_in_context(&one_over));
    }

    #[tokio::test]
    async fn test_provider_extension_methods() {
        let provider = MockProvider::new(ModelCapabilities::basic_text(4096, 8192));

        let response = provider
            .chat_simple("test-model", "Hello, world!")
            .await
            .unwrap();
        assert_eq!(response.content, "Test response");

        let stream = provider
            .chat_stream_simple("test-model", "Hello, world!")
            .await
            .unwrap();
        drop(stream);
    }

    #[test]
    fn test_capabilities_constructors() {
        let basic = ModelCapabilities::basic_text(2048, 4096);
        assert!(!basic.supports_streaming);
        assert!(!basic.supports_function_calling);
        assert!(!basic.supports_vision);

        let full = ModelCapabilities::full_featured(4096, 8192);
        assert!(full.supports_streaming);
        assert!(full.supports_function_calling);
        assert!(full.supports_vision);

        let streaming = ModelCapabilities::streaming_text(2048, 4096);
        assert!(streaming.supports_streaming);
        assert!(!streaming.supports_function_calling);
        assert!(!streaming.supports_vision);
    }
}