// ai_lib/provider/models.rs

use crate::types::AiLibError;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
5
6/// Model information structure for custom model management
7///
8/// This struct provides detailed information about AI models,
9/// allowing developers to build custom model managers and arrays.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ModelInfo {
12    /// Model name/identifier
13    pub name: String,
14    /// Display name for user interface
15    pub display_name: String,
16    /// Model description
17    pub description: String,
18    /// Model capabilities
19    pub capabilities: ModelCapabilities,
20    /// Pricing information
21    pub pricing: PricingInfo,
22    /// Performance metrics
23    pub performance: PerformanceMetrics,
24    /// Provider-specific metadata
25    pub metadata: HashMap<String, String>,
26}
27
28/// Model capabilities enumeration
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct ModelCapabilities {
31    /// Chat capabilities
32    pub chat: bool,
33    /// Code generation capabilities
34    pub code_generation: bool,
35    /// Multimodal capabilities (text + image/audio)
36    pub multimodal: bool,
37    /// Function calling capabilities
38    pub function_calling: bool,
39    /// Tool use capabilities
40    pub tool_use: bool,
41    /// Multilingual support
42    pub multilingual: bool,
43    /// Context window size in tokens
44    pub context_window: Option<u32>,
45}
46
47impl ModelCapabilities {
48    /// Create new capabilities with default values
49    pub fn new() -> Self {
50        Self {
51            chat: true,
52            code_generation: false,
53            multimodal: false,
54            function_calling: false,
55            tool_use: false,
56            multilingual: false,
57            context_window: None,
58        }
59    }
60
61    /// Enable chat capabilities
62    pub fn with_chat(mut self) -> Self {
63        self.chat = true;
64        self
65    }
66
67    /// Enable code generation capabilities
68    pub fn with_code_generation(mut self) -> Self {
69        self.code_generation = true;
70        self
71    }
72
73    /// Enable multimodal capabilities
74    pub fn with_multimodal(mut self) -> Self {
75        self.multimodal = true;
76        self
77    }
78
79    /// Enable function calling capabilities
80    pub fn with_function_calling(mut self) -> Self {
81        self.function_calling = true;
82        self
83    }
84
85    /// Enable tool use capabilities
86    pub fn with_tool_use(mut self) -> Self {
87        self.tool_use = true;
88        self
89    }
90
91    /// Enable multilingual support
92    pub fn with_multilingual(mut self) -> Self {
93        self.multilingual = true;
94        self
95    }
96
97    /// Set context window size
98    pub fn with_context_window(mut self, size: u32) -> Self {
99        self.context_window = Some(size);
100        self
101    }
102
103    /// Check if model supports a specific capability
104    pub fn supports(&self, capability: &str) -> bool {
105        match capability {
106            "chat" => self.chat,
107            "code_generation" => self.code_generation,
108            "multimodal" => self.multimodal,
109            "function_calling" => self.function_calling,
110            "tool_use" => self.tool_use,
111            "multilingual" => self.multilingual,
112            _ => false,
113        }
114    }
115}
116
117/// Pricing information for models
118#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct PricingInfo {
120    /// Cost per input token (in USD)
121    pub input_cost_per_1k: f64,
122    /// Cost per output token (in USD)
123    pub output_cost_per_1k: f64,
124    /// Currency (default: USD)
125    pub currency: String,
126}
127
128impl PricingInfo {
129    /// Create new pricing information
130    pub fn new(input_cost_per_1k: f64, output_cost_per_1k: f64) -> Self {
131        Self {
132            input_cost_per_1k,
133            output_cost_per_1k,
134            currency: "USD".to_string(),
135        }
136    }
137
138    /// Set custom currency
139    pub fn with_currency(mut self, currency: &str) -> Self {
140        self.currency = currency.to_string();
141        self
142    }
143
144    /// Calculate cost for a given number of tokens
145    pub fn calculate_cost(&self, input_tokens: u32, output_tokens: u32) -> f64 {
146        let input_cost = (input_tokens as f64 / 1000.0) * self.input_cost_per_1k;
147        let output_cost = (output_tokens as f64 / 1000.0) * self.output_cost_per_1k;
148        input_cost + output_cost
149    }
150}
151
152/// Performance metrics for models
153#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct PerformanceMetrics {
155    /// Speed tier classification
156    pub speed: SpeedTier,
157    /// Quality tier classification
158    pub quality: QualityTier,
159    /// Average response time
160    pub avg_response_time: Option<Duration>,
161    /// Throughput (requests per second)
162    pub throughput: Option<f64>,
163}
164
165#[derive(Debug, Clone, Serialize, Deserialize)]
166pub enum SpeedTier {
167    Fast,
168    Balanced,
169    Slow,
170}
171
172#[derive(Debug, Clone, Serialize, Deserialize)]
173pub enum QualityTier {
174    Basic,
175    Good,
176    Excellent,
177}
178
179impl PerformanceMetrics {
180    /// Create new performance metrics
181    pub fn new() -> Self {
182        Self {
183            speed: SpeedTier::Balanced,
184            quality: QualityTier::Good,
185            avg_response_time: None,
186            throughput: None,
187        }
188    }
189
190    /// Set speed tier
191    pub fn with_speed(mut self, speed: SpeedTier) -> Self {
192        self.speed = speed;
193        self
194    }
195
196    /// Set quality tier
197    pub fn with_quality(mut self, quality: QualityTier) -> Self {
198        self.quality = quality;
199        self
200    }
201
202    /// Set average response time
203    pub fn with_avg_response_time(mut self, time: Duration) -> Self {
204        self.avg_response_time = Some(time);
205        self
206    }
207
208    /// Set throughput
209    pub fn with_throughput(mut self, tps: f64) -> Self {
210        self.throughput = Some(tps);
211        self
212    }
213}
214
215/// Custom model manager for developers
216///
217/// This struct allows developers to build their own model management systems,
218/// including model discovery, selection, and load balancing.
219#[derive(Clone)]
220pub struct CustomModelManager {
221    /// Provider identifier
222    pub provider: String,
223    /// Available models
224    pub models: HashMap<String, ModelInfo>,
225    /// Model selection strategy
226    pub selection_strategy: ModelSelectionStrategy,
227}
228
/// Strategy used by `CustomModelManager::select_model` to pick a model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelSelectionStrategy {
    /// Round-robin selection
    RoundRobin,
    /// Weighted selection based on performance
    Weighted,
    /// Least connections selection
    LeastConnections,
    /// Performance-based selection
    PerformanceBased,
    /// Cost-based selection
    CostBased,
}
242
243impl CustomModelManager {
244    /// Create new model manager
245    pub fn new(provider: &str) -> Self {
246        Self {
247            provider: provider.to_string(),
248            models: HashMap::new(),
249            selection_strategy: ModelSelectionStrategy::RoundRobin,
250        }
251    }
252
253    /// Add a model to the manager
254    pub fn add_model(&mut self, model: ModelInfo) {
255        self.models.insert(model.name.clone(), model);
256    }
257
258    /// Remove a model from the manager
259    pub fn remove_model(&mut self, model_name: &str) -> Option<ModelInfo> {
260        self.models.remove(model_name)
261    }
262
263    /// Get model information
264    pub fn get_model(&self, model_name: &str) -> Option<&ModelInfo> {
265        self.models.get(model_name)
266    }
267
268    /// List all available models
269    pub fn list_models(&self) -> Vec<&ModelInfo> {
270        self.models.values().collect()
271    }
272
273    /// Set selection strategy
274    pub fn with_strategy(mut self, strategy: ModelSelectionStrategy) -> Self {
275        self.selection_strategy = strategy;
276        self
277    }
278
279    /// Select model based on current strategy
280    pub fn select_model(&self) -> Option<&ModelInfo> {
281        if self.models.is_empty() {
282            return None;
283        }
284
285        match self.selection_strategy {
286            ModelSelectionStrategy::RoundRobin => {
287                // Simple round-robin implementation
288                let models: Vec<&ModelInfo> = self.models.values().collect();
289                let index = (std::time::SystemTime::now()
290                    .duration_since(std::time::UNIX_EPOCH)
291                    .unwrap()
292                    .as_secs() as usize)
293                    % models.len();
294                Some(models[index])
295            }
296            ModelSelectionStrategy::Weighted => {
297                // Weighted selection based on performance
298                self.models.values().max_by_key(|model| {
299                    let speed_score = match model.performance.speed {
300                        SpeedTier::Fast => 3,
301                        SpeedTier::Balanced => 2,
302                        SpeedTier::Slow => 1,
303                    };
304                    let quality_score = match model.performance.quality {
305                        QualityTier::Excellent => 3,
306                        QualityTier::Good => 2,
307                        QualityTier::Basic => 1,
308                    };
309                    speed_score + quality_score
310                })
311            }
312            ModelSelectionStrategy::LeastConnections => {
313                // For now, return first available model
314                // In a real implementation, this would track connection counts
315                self.models.values().next()
316            }
317            ModelSelectionStrategy::PerformanceBased => {
318                // Select model with best performance metrics
319                self.models.values().max_by_key(|model| match model.performance.speed {
320                    SpeedTier::Fast => 3,
321                    SpeedTier::Balanced => 2,
322                    SpeedTier::Slow => 1,
323                })
324            }
325            ModelSelectionStrategy::CostBased => {
326                // Select model with lowest cost
327                self.models.values().min_by(|a, b| {
328                    let a_cost = a.pricing.input_cost_per_1k + a.pricing.output_cost_per_1k;
329                    let b_cost = b.pricing.input_cost_per_1k + b.pricing.output_cost_per_1k;
330                    a_cost
331                        .partial_cmp(&b_cost)
332                        .unwrap_or(std::cmp::Ordering::Equal)
333                })
334            }
335        }
336    }
337
338    /// Recommend model for specific use case
339    pub fn recommend_for(&self, use_case: &str) -> Option<&ModelInfo> {
340        let supported_models: Vec<&ModelInfo> = self
341            .models
342            .values()
343            .filter(|model| model.capabilities.supports(use_case))
344            .collect();
345
346        if supported_models.is_empty() {
347            return None;
348        }
349
350        // For now, return the first supported model
351        // In a real implementation, this could use more sophisticated logic
352        supported_models.first().copied()
353    }
354
355    /// Load models from configuration file
356    pub fn load_from_config(&mut self, config_path: &str) -> Result<(), AiLibError> {
357        let config_content = std::fs::read_to_string(config_path)
358            .map_err(|e| AiLibError::ConfigurationError(format!("Failed to read config: {}", e)))?;
359
360        let models: Vec<ModelInfo> = serde_json::from_str(&config_content).map_err(|e| {
361            AiLibError::ConfigurationError(format!("Failed to parse config: {}", e))
362        })?;
363
364        for model in models {
365            self.add_model(model);
366        }
367
368        Ok(())
369    }
370
371    /// Save current model configuration to file
372    pub fn save_to_config(&self, config_path: &str) -> Result<(), AiLibError> {
373        let models: Vec<&ModelInfo> = self.models.values().collect();
374        let config_content = serde_json::to_string_pretty(&models).map_err(|e| {
375            AiLibError::ConfigurationError(format!("Failed to serialize config: {}", e))
376        })?;
377
378        std::fs::write(config_path, config_content).map_err(|e| {
379            AiLibError::ConfigurationError(format!("Failed to write config: {}", e))
380        })?;
381
382        Ok(())
383    }
384}
385
386/// Model array for load balancing and A/B testing
387///
388/// This struct allows developers to build model arrays with multiple endpoints,
389/// supporting various load balancing strategies.
390#[derive(Clone)]
391pub struct ModelArray {
392    /// Array name/identifier
393    pub name: String,
394    /// Model endpoints in the array
395    pub endpoints: Vec<ModelEndpoint>,
396    /// Load balancing strategy
397    pub strategy: LoadBalancingStrategy,
398    /// Health check configuration
399    pub health_check: HealthCheckConfig,
400}
401
/// Model endpoint in an array.
#[derive(Debug, Clone, PartialEq)]
pub struct ModelEndpoint {
    /// Endpoint name; `ModelArray::mark_healthy` / `mark_unhealthy` look
    /// endpoints up by this value.
    pub name: String,
    /// Model name
    pub model_name: String,
    /// Endpoint URL
    pub url: String,
    /// Weight for weighted load balancing
    pub weight: f32,
    /// Health status; `ModelArray::select_endpoint` skips endpoints where
    /// this is `false`.
    pub healthy: bool,
    /// Connection count, compared by the least-connections strategy.
    /// No code in this module updates it, so callers have to maintain it.
    pub connection_count: u32,
}
418
/// Load balancing strategies used by `ModelArray::select_endpoint`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalancingStrategy {
    /// Round-robin load balancing
    RoundRobin,
    /// Weighted load balancing
    Weighted,
    /// Least connections load balancing
    LeastConnections,
    /// Health-based load balancing
    HealthBased,
}
431
/// Health check configuration.
///
/// `ModelArray` only stores these settings; this module contains no code
/// that performs the probing itself.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HealthCheckConfig {
    /// Health check endpoint
    pub endpoint: String,
    /// Health check interval
    pub interval: Duration,
    /// Health check timeout
    pub timeout: Duration,
    /// Maximum consecutive failures
    pub max_failures: u32,
}
444
445impl ModelArray {
446    /// Create new model array
447    pub fn new(name: &str) -> Self {
448        Self {
449            name: name.to_string(),
450            endpoints: Vec::new(),
451            strategy: LoadBalancingStrategy::RoundRobin,
452            health_check: HealthCheckConfig {
453                endpoint: "/health".to_string(),
454                interval: Duration::from_secs(30),
455                timeout: Duration::from_secs(5),
456                max_failures: 3,
457            },
458        }
459    }
460
461    /// Add endpoint to the array
462    pub fn add_endpoint(&mut self, endpoint: ModelEndpoint) {
463        self.endpoints.push(endpoint);
464    }
465
466    /// Set load balancing strategy
467    pub fn with_strategy(mut self, strategy: LoadBalancingStrategy) -> Self {
468        self.strategy = strategy;
469        self
470    }
471
472    /// Configure health check
473    pub fn with_health_check(mut self, config: HealthCheckConfig) -> Self {
474        self.health_check = config;
475        self
476    }
477
478    /// Select next endpoint based on strategy
479    pub fn select_endpoint(&mut self) -> Option<&mut ModelEndpoint> {
480        if self.endpoints.is_empty() {
481            return None;
482        }
483
484        // Get indices of healthy endpoints
485        let healthy_indices: Vec<usize> = self
486            .endpoints
487            .iter()
488            .enumerate()
489            .filter(|(_, endpoint)| endpoint.healthy)
490            .map(|(index, _)| index)
491            .collect();
492
493        if healthy_indices.is_empty() {
494            return None;
495        }
496
497        match self.strategy {
498            LoadBalancingStrategy::RoundRobin => {
499                // Simple round-robin implementation
500                let index = (std::time::SystemTime::now()
501                    .duration_since(std::time::UNIX_EPOCH)
502                    .unwrap()
503                    .as_secs() as usize)
504                    % healthy_indices.len();
505                let endpoint_index = healthy_indices[index];
506                Some(&mut self.endpoints[endpoint_index])
507            }
508            LoadBalancingStrategy::Weighted => {
509                // Weighted selection - simplified implementation
510                let total_weight: f32 = healthy_indices
511                    .iter()
512                    .map(|&idx| self.endpoints[idx].weight)
513                    .sum();
514                let mut current_weight = 0.0;
515
516                for &idx in &healthy_indices {
517                    current_weight += self.endpoints[idx].weight;
518                    if current_weight >= total_weight / 2.0 {
519                        return Some(&mut self.endpoints[idx]);
520                    }
521                }
522
523                // Fallback to first healthy endpoint
524                let endpoint_index = healthy_indices[0];
525                Some(&mut self.endpoints[endpoint_index])
526            }
527            LoadBalancingStrategy::LeastConnections => {
528                // Select endpoint with least connections
529                healthy_indices
530                    .iter()
531                    .min_by_key(|&&idx| self.endpoints[idx].connection_count)
532                    .map(|&idx| &mut self.endpoints[idx])
533            }
534            LoadBalancingStrategy::HealthBased => {
535                // Select first healthy endpoint
536                let endpoint_index = healthy_indices[0];
537                Some(&mut self.endpoints[endpoint_index])
538            }
539        }
540    }
541
542    /// Mark endpoint as unhealthy
543    pub fn mark_unhealthy(&mut self, endpoint_name: &str) {
544        if let Some(endpoint) = self.endpoints.iter_mut().find(|e| e.name == endpoint_name) {
545            endpoint.healthy = false;
546        }
547    }
548
549    /// Mark endpoint as healthy
550    pub fn mark_healthy(&mut self, endpoint_name: &str) {
551        if let Some(endpoint) = self.endpoints.iter_mut().find(|e| e.name == endpoint_name) {
552            endpoint.healthy = true;
553        }
554    }
555
556    /// Get array health status
557    pub fn is_healthy(&self) -> bool {
558        self.endpoints.iter().any(|endpoint| endpoint.healthy)
559    }
560}
561
562impl Default for ModelCapabilities {
563    fn default() -> Self {
564        Self::new()
565    }
566}
567
568impl Default for PerformanceMetrics {
569    fn default() -> Self {
570        Self::new()
571    }
572}