ai_lib/provider/
models.rs

1use crate::types::AiLibError;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::time::Duration;
5
6/// Model information structure for custom model management
7///
8/// This struct provides detailed information about AI models,
9/// allowing developers to build custom model managers and arrays.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ModelInfo {
12    /// Model name/identifier
13    pub name: String,
14    /// Display name for user interface
15    pub display_name: String,
16    /// Model description
17    pub description: String,
18    /// Model capabilities
19    pub capabilities: ModelCapabilities,
20    /// Pricing information
21    pub pricing: PricingInfo,
22    /// Performance metrics
23    pub performance: PerformanceMetrics,
24    /// Provider-specific metadata
25    pub metadata: HashMap<String, String>,
26}
27
28/// Model capabilities enumeration
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct ModelCapabilities {
31    /// Chat capabilities
32    pub chat: bool,
33    /// Code generation capabilities
34    pub code_generation: bool,
35    /// Multimodal capabilities (text + image/audio)
36    pub multimodal: bool,
37    /// Function calling capabilities
38    pub function_calling: bool,
39    /// Tool use capabilities
40    pub tool_use: bool,
41    /// Multilingual support
42    pub multilingual: bool,
43    /// Context window size in tokens
44    pub context_window: Option<u32>,
45}
46
47impl ModelCapabilities {
48    /// Create new capabilities with default values
49    pub fn new() -> Self {
50        Self {
51            chat: true,
52            code_generation: false,
53            multimodal: false,
54            function_calling: false,
55            tool_use: false,
56            multilingual: false,
57            context_window: None,
58        }
59    }
60
61    /// Enable chat capabilities
62    pub fn with_chat(mut self) -> Self {
63        self.chat = true;
64        self
65    }
66
67    /// Enable code generation capabilities
68    pub fn with_code_generation(mut self) -> Self {
69        self.code_generation = true;
70        self
71    }
72
73    /// Enable multimodal capabilities
74    pub fn with_multimodal(mut self) -> Self {
75        self.multimodal = true;
76        self
77    }
78
79    /// Enable function calling capabilities
80    pub fn with_function_calling(mut self) -> Self {
81        self.function_calling = true;
82        self
83    }
84
85    /// Enable tool use capabilities
86    pub fn with_tool_use(mut self) -> Self {
87        self.tool_use = true;
88        self
89    }
90
91    /// Enable multilingual support
92    pub fn with_multilingual(mut self) -> Self {
93        self.multilingual = true;
94        self
95    }
96
97    /// Set context window size
98    pub fn with_context_window(mut self, size: u32) -> Self {
99        self.context_window = Some(size);
100        self
101    }
102
103    /// Check if model supports a specific capability
104    pub fn supports(&self, capability: &str) -> bool {
105        match capability {
106            "chat" => self.chat,
107            "code_generation" => self.code_generation,
108            "multimodal" => self.multimodal,
109            "function_calling" => self.function_calling,
110            "tool_use" => self.tool_use,
111            "multilingual" => self.multilingual,
112            _ => false,
113        }
114    }
115}
116
117/// Pricing information for models
118#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct PricingInfo {
120    /// Cost per input token (in USD)
121    pub input_cost_per_1k: f64,
122    /// Cost per output token (in USD)
123    pub output_cost_per_1k: f64,
124    /// Currency (default: USD)
125    pub currency: String,
126}
127
128impl PricingInfo {
129    /// Create new pricing information
130    pub fn new(input_cost_per_1k: f64, output_cost_per_1k: f64) -> Self {
131        Self {
132            input_cost_per_1k,
133            output_cost_per_1k,
134            currency: "USD".to_string(),
135        }
136    }
137
138    /// Set custom currency
139    pub fn with_currency(mut self, currency: &str) -> Self {
140        self.currency = currency.to_string();
141        self
142    }
143
144    /// Calculate cost for a given number of tokens
145    pub fn calculate_cost(&self, input_tokens: u32, output_tokens: u32) -> f64 {
146        let input_cost = (input_tokens as f64 / 1000.0) * self.input_cost_per_1k;
147        let output_cost = (output_tokens as f64 / 1000.0) * self.output_cost_per_1k;
148        input_cost + output_cost
149    }
150}
151
152/// Performance metrics for models
153#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct PerformanceMetrics {
155    /// Speed tier classification
156    pub speed: SpeedTier,
157    /// Quality tier classification
158    pub quality: QualityTier,
159    /// Average response time
160    pub avg_response_time: Option<Duration>,
161    /// Throughput (requests per second)
162    pub throughput: Option<f64>,
163}
164
165#[derive(Debug, Clone, Serialize, Deserialize)]
166pub enum SpeedTier {
167    Fast,
168    Balanced,
169    Slow,
170}
171
172#[derive(Debug, Clone, Serialize, Deserialize)]
173pub enum QualityTier {
174    Basic,
175    Good,
176    Excellent,
177}
178
179impl PerformanceMetrics {
180    /// Create new performance metrics
181    pub fn new() -> Self {
182        Self {
183            speed: SpeedTier::Balanced,
184            quality: QualityTier::Good,
185            avg_response_time: None,
186            throughput: None,
187        }
188    }
189
190    /// Set speed tier
191    pub fn with_speed(mut self, speed: SpeedTier) -> Self {
192        self.speed = speed;
193        self
194    }
195
196    /// Set quality tier
197    pub fn with_quality(mut self, quality: QualityTier) -> Self {
198        self.quality = quality;
199        self
200    }
201
202    /// Set average response time
203    pub fn with_avg_response_time(mut self, time: Duration) -> Self {
204        self.avg_response_time = Some(time);
205        self
206    }
207
208    /// Set throughput
209    pub fn with_throughput(mut self, tps: f64) -> Self {
210        self.throughput = Some(tps);
211        self
212    }
213}
214
215/// Custom model manager for developers
216///
217/// This struct allows developers to build their own model management systems,
218/// including model discovery, selection, and load balancing.
219#[derive(Clone)]
220pub struct CustomModelManager {
221    /// Provider identifier
222    pub provider: String,
223    /// Available models
224    pub models: HashMap<String, ModelInfo>,
225    /// Model selection strategy
226    pub selection_strategy: ModelSelectionStrategy,
227}
228
/// Strategies `CustomModelManager::select_model` can apply.
#[derive(Debug, Clone)]
pub enum ModelSelectionStrategy {
    /// Rotate through the registered models.
    RoundRobin,
    /// Prefer the model with the best combined speed and quality tiers.
    Weighted,
    /// Prefer the model with the fewest connections.
    LeastConnections,
    /// Prefer the model with the best performance tier.
    PerformanceBased,
    /// Prefer the cheapest model.
    CostBased,
}
242
243impl CustomModelManager {
244    /// Create new model manager
245    pub fn new(provider: &str) -> Self {
246        Self {
247            provider: provider.to_string(),
248            models: HashMap::new(),
249            selection_strategy: ModelSelectionStrategy::RoundRobin,
250        }
251    }
252
253    /// Add a model to the manager
254    pub fn add_model(&mut self, model: ModelInfo) {
255        self.models.insert(model.name.clone(), model);
256    }
257
258    /// Remove a model from the manager
259    pub fn remove_model(&mut self, model_name: &str) -> Option<ModelInfo> {
260        self.models.remove(model_name)
261    }
262
263    /// Get model information
264    pub fn get_model(&self, model_name: &str) -> Option<&ModelInfo> {
265        self.models.get(model_name)
266    }
267
268    /// List all available models
269    pub fn list_models(&self) -> Vec<&ModelInfo> {
270        self.models.values().collect()
271    }
272
273    /// Set selection strategy
274    pub fn with_strategy(mut self, strategy: ModelSelectionStrategy) -> Self {
275        self.selection_strategy = strategy;
276        self
277    }
278
279    /// Select model based on current strategy
280    pub fn select_model(&self) -> Option<&ModelInfo> {
281        if self.models.is_empty() {
282            return None;
283        }
284
285        match self.selection_strategy {
286            ModelSelectionStrategy::RoundRobin => {
287                // Simple round-robin implementation
288                let models: Vec<&ModelInfo> = self.models.values().collect();
289                let index = (std::time::SystemTime::now()
290                    .duration_since(std::time::UNIX_EPOCH)
291                    .unwrap()
292                    .as_secs() as usize)
293                    % models.len();
294                Some(models[index])
295            }
296            ModelSelectionStrategy::Weighted => {
297                // Weighted selection based on performance
298                self.models.values().max_by_key(|model| {
299                    let speed_score = match model.performance.speed {
300                        SpeedTier::Fast => 3,
301                        SpeedTier::Balanced => 2,
302                        SpeedTier::Slow => 1,
303                    };
304                    let quality_score = match model.performance.quality {
305                        QualityTier::Excellent => 3,
306                        QualityTier::Good => 2,
307                        QualityTier::Basic => 1,
308                    };
309                    speed_score + quality_score
310                })
311            }
312            ModelSelectionStrategy::LeastConnections => {
313                // For now, return first available model
314                // In a real implementation, this would track connection counts
315                self.models.values().next()
316            }
317            ModelSelectionStrategy::PerformanceBased => {
318                // Select model with best performance metrics
319                self.models.values().max_by_key(|model| {
320                    let speed_score = match model.performance.speed {
321                        SpeedTier::Fast => 3,
322                        SpeedTier::Balanced => 2,
323                        SpeedTier::Slow => 1,
324                    };
325                    speed_score
326                })
327            }
328            ModelSelectionStrategy::CostBased => {
329                // Select model with lowest cost
330                self.models.values().min_by(|a, b| {
331                    let a_cost = a.pricing.input_cost_per_1k + a.pricing.output_cost_per_1k;
332                    let b_cost = b.pricing.input_cost_per_1k + b.pricing.output_cost_per_1k;
333                    a_cost
334                        .partial_cmp(&b_cost)
335                        .unwrap_or(std::cmp::Ordering::Equal)
336                })
337            }
338        }
339    }
340
341    /// Recommend model for specific use case
342    pub fn recommend_for(&self, use_case: &str) -> Option<&ModelInfo> {
343        let supported_models: Vec<&ModelInfo> = self
344            .models
345            .values()
346            .filter(|model| model.capabilities.supports(use_case))
347            .collect();
348
349        if supported_models.is_empty() {
350            return None;
351        }
352
353        // For now, return the first supported model
354        // In a real implementation, this could use more sophisticated logic
355        supported_models.first().copied()
356    }
357
358    /// Load models from configuration file
359    pub fn load_from_config(&mut self, config_path: &str) -> Result<(), AiLibError> {
360        let config_content = std::fs::read_to_string(config_path)
361            .map_err(|e| AiLibError::ConfigurationError(format!("Failed to read config: {}", e)))?;
362
363        let models: Vec<ModelInfo> = serde_json::from_str(&config_content).map_err(|e| {
364            AiLibError::ConfigurationError(format!("Failed to parse config: {}", e))
365        })?;
366
367        for model in models {
368            self.add_model(model);
369        }
370
371        Ok(())
372    }
373
374    /// Save current model configuration to file
375    pub fn save_to_config(&self, config_path: &str) -> Result<(), AiLibError> {
376        let models: Vec<&ModelInfo> = self.models.values().collect();
377        let config_content = serde_json::to_string_pretty(&models).map_err(|e| {
378            AiLibError::ConfigurationError(format!("Failed to serialize config: {}", e))
379        })?;
380
381        std::fs::write(config_path, config_content).map_err(|e| {
382            AiLibError::ConfigurationError(format!("Failed to write config: {}", e))
383        })?;
384
385        Ok(())
386    }
387}
388
389/// Model array for load balancing and A/B testing
390///
391/// This struct allows developers to build model arrays with multiple endpoints,
392/// supporting various load balancing strategies.
393pub struct ModelArray {
394    /// Array name/identifier
395    pub name: String,
396    /// Model endpoints in the array
397    pub endpoints: Vec<ModelEndpoint>,
398    /// Load balancing strategy
399    pub strategy: LoadBalancingStrategy,
400    /// Health check configuration
401    pub health_check: HealthCheckConfig,
402}
403
/// One endpoint inside a `ModelArray`.
#[derive(Debug, Clone)]
pub struct ModelEndpoint {
    /// Name of the endpoint; `ModelArray::mark_healthy` and
    /// `ModelArray::mark_unhealthy` look endpoints up by it.
    pub name: String,
    /// Name of the model served by the endpoint.
    pub model_name: String,
    /// URL of the endpoint.
    pub url: String,
    /// Weight read by the weighted load balancing strategy.
    pub weight: f32,
    /// Whether the endpoint is currently eligible for selection.
    pub healthy: bool,
    /// Connection count read by the least-connections strategy.
    pub connection_count: u32,
}
420
/// Strategies `ModelArray::select_endpoint` can apply to the healthy endpoints.
#[derive(Debug, Clone)]
pub enum LoadBalancingStrategy {
    /// Rotate through the healthy endpoints.
    RoundRobin,
    /// Choose according to the endpoints' weights.
    Weighted,
    /// Choose the endpoint with the fewest connections.
    LeastConnections,
    /// Choose the first healthy endpoint.
    HealthBased,
}
433
/// Settings describing how the endpoints of a `ModelArray` should be health-checked.
#[derive(Debug, Clone)]
pub struct HealthCheckConfig {
    /// Endpoint to probe (`"/health"` in the configuration built by `ModelArray::new`).
    pub endpoint: String,
    /// Time between two health checks.
    pub interval: Duration,
    /// Time allowed for a single health check.
    pub timeout: Duration,
    /// Maximum number of consecutive failures.
    pub max_failures: u32,
}
446
447impl ModelArray {
448    /// Create new model array
449    pub fn new(name: &str) -> Self {
450        Self {
451            name: name.to_string(),
452            endpoints: Vec::new(),
453            strategy: LoadBalancingStrategy::RoundRobin,
454            health_check: HealthCheckConfig {
455                endpoint: "/health".to_string(),
456                interval: Duration::from_secs(30),
457                timeout: Duration::from_secs(5),
458                max_failures: 3,
459            },
460        }
461    }
462
463    /// Add endpoint to the array
464    pub fn add_endpoint(&mut self, endpoint: ModelEndpoint) {
465        self.endpoints.push(endpoint);
466    }
467
468    /// Set load balancing strategy
469    pub fn with_strategy(mut self, strategy: LoadBalancingStrategy) -> Self {
470        self.strategy = strategy;
471        self
472    }
473
474    /// Configure health check
475    pub fn with_health_check(mut self, config: HealthCheckConfig) -> Self {
476        self.health_check = config;
477        self
478    }
479
480    /// Select next endpoint based on strategy
481    pub fn select_endpoint(&mut self) -> Option<&mut ModelEndpoint> {
482        if self.endpoints.is_empty() {
483            return None;
484        }
485
486        // Get indices of healthy endpoints
487        let healthy_indices: Vec<usize> = self
488            .endpoints
489            .iter()
490            .enumerate()
491            .filter(|(_, endpoint)| endpoint.healthy)
492            .map(|(index, _)| index)
493            .collect();
494
495        if healthy_indices.is_empty() {
496            return None;
497        }
498
499        match self.strategy {
500            LoadBalancingStrategy::RoundRobin => {
501                // Simple round-robin implementation
502                let index = (std::time::SystemTime::now()
503                    .duration_since(std::time::UNIX_EPOCH)
504                    .unwrap()
505                    .as_secs() as usize)
506                    % healthy_indices.len();
507                let endpoint_index = healthy_indices[index];
508                Some(&mut self.endpoints[endpoint_index])
509            }
510            LoadBalancingStrategy::Weighted => {
511                // Weighted selection - simplified implementation
512                let total_weight: f32 = healthy_indices
513                    .iter()
514                    .map(|&idx| self.endpoints[idx].weight)
515                    .sum();
516                let mut current_weight = 0.0;
517
518                for &idx in &healthy_indices {
519                    current_weight += self.endpoints[idx].weight;
520                    if current_weight >= total_weight / 2.0 {
521                        return Some(&mut self.endpoints[idx]);
522                    }
523                }
524
525                // Fallback to first healthy endpoint
526                let endpoint_index = healthy_indices[0];
527                Some(&mut self.endpoints[endpoint_index])
528            }
529            LoadBalancingStrategy::LeastConnections => {
530                // Select endpoint with least connections
531                healthy_indices
532                    .iter()
533                    .min_by_key(|&&idx| self.endpoints[idx].connection_count)
534                    .map(|&idx| &mut self.endpoints[idx])
535            }
536            LoadBalancingStrategy::HealthBased => {
537                // Select first healthy endpoint
538                let endpoint_index = healthy_indices[0];
539                Some(&mut self.endpoints[endpoint_index])
540            }
541        }
542    }
543
544    /// Mark endpoint as unhealthy
545    pub fn mark_unhealthy(&mut self, endpoint_name: &str) {
546        if let Some(endpoint) = self.endpoints.iter_mut().find(|e| e.name == endpoint_name) {
547            endpoint.healthy = false;
548        }
549    }
550
551    /// Mark endpoint as healthy
552    pub fn mark_healthy(&mut self, endpoint_name: &str) {
553        if let Some(endpoint) = self.endpoints.iter_mut().find(|e| e.name == endpoint_name) {
554            endpoint.healthy = true;
555        }
556    }
557
558    /// Get array health status
559    pub fn is_healthy(&self) -> bool {
560        self.endpoints.iter().any(|endpoint| endpoint.healthy)
561    }
562}
563
564impl Default for ModelCapabilities {
565    fn default() -> Self {
566        Self::new()
567    }
568}
569
570impl Default for PerformanceMetrics {
571    fn default() -> Self {
572        Self::new()
573    }
574}