1use crate::types::AiLibError;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::time::Duration;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ModelInfo {
12 pub name: String,
14 pub display_name: String,
16 pub description: String,
18 pub capabilities: ModelCapabilities,
20 pub pricing: PricingInfo,
22 pub performance: PerformanceMetrics,
24 pub metadata: HashMap<String, String>,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct ModelCapabilities {
31 pub chat: bool,
33 pub code_generation: bool,
35 pub multimodal: bool,
37 pub function_calling: bool,
39 pub tool_use: bool,
41 pub multilingual: bool,
43 pub context_window: Option<u32>,
45}
46
47impl ModelCapabilities {
48 pub fn new() -> Self {
50 Self {
51 chat: true,
52 code_generation: false,
53 multimodal: false,
54 function_calling: false,
55 tool_use: false,
56 multilingual: false,
57 context_window: None,
58 }
59 }
60
61 pub fn with_chat(mut self) -> Self {
63 self.chat = true;
64 self
65 }
66
67 pub fn with_code_generation(mut self) -> Self {
69 self.code_generation = true;
70 self
71 }
72
73 pub fn with_multimodal(mut self) -> Self {
75 self.multimodal = true;
76 self
77 }
78
79 pub fn with_function_calling(mut self) -> Self {
81 self.function_calling = true;
82 self
83 }
84
85 pub fn with_tool_use(mut self) -> Self {
87 self.tool_use = true;
88 self
89 }
90
91 pub fn with_multilingual(mut self) -> Self {
93 self.multilingual = true;
94 self
95 }
96
97 pub fn with_context_window(mut self, size: u32) -> Self {
99 self.context_window = Some(size);
100 self
101 }
102
103 pub fn supports(&self, capability: &str) -> bool {
105 match capability {
106 "chat" => self.chat,
107 "code_generation" => self.code_generation,
108 "multimodal" => self.multimodal,
109 "function_calling" => self.function_calling,
110 "tool_use" => self.tool_use,
111 "multilingual" => self.multilingual,
112 _ => false,
113 }
114 }
115}
116
117#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct PricingInfo {
120 pub input_cost_per_1k: f64,
122 pub output_cost_per_1k: f64,
124 pub currency: String,
126}
127
128impl PricingInfo {
129 pub fn new(input_cost_per_1k: f64, output_cost_per_1k: f64) -> Self {
131 Self {
132 input_cost_per_1k,
133 output_cost_per_1k,
134 currency: "USD".to_string(),
135 }
136 }
137
138 pub fn with_currency(mut self, currency: &str) -> Self {
140 self.currency = currency.to_string();
141 self
142 }
143
144 pub fn calculate_cost(&self, input_tokens: u32, output_tokens: u32) -> f64 {
146 let input_cost = (input_tokens as f64 / 1000.0) * self.input_cost_per_1k;
147 let output_cost = (output_tokens as f64 / 1000.0) * self.output_cost_per_1k;
148 input_cost + output_cost
149 }
150}
151
152#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct PerformanceMetrics {
155 pub speed: SpeedTier,
157 pub quality: QualityTier,
159 pub avg_response_time: Option<Duration>,
161 pub throughput: Option<f64>,
163}
164
165#[derive(Debug, Clone, Serialize, Deserialize)]
166pub enum SpeedTier {
167 Fast,
168 Balanced,
169 Slow,
170}
171
172#[derive(Debug, Clone, Serialize, Deserialize)]
173pub enum QualityTier {
174 Basic,
175 Good,
176 Excellent,
177}
178
179impl PerformanceMetrics {
180 pub fn new() -> Self {
182 Self {
183 speed: SpeedTier::Balanced,
184 quality: QualityTier::Good,
185 avg_response_time: None,
186 throughput: None,
187 }
188 }
189
190 pub fn with_speed(mut self, speed: SpeedTier) -> Self {
192 self.speed = speed;
193 self
194 }
195
196 pub fn with_quality(mut self, quality: QualityTier) -> Self {
198 self.quality = quality;
199 self
200 }
201
202 pub fn with_avg_response_time(mut self, time: Duration) -> Self {
204 self.avg_response_time = Some(time);
205 self
206 }
207
208 pub fn with_throughput(mut self, tps: f64) -> Self {
210 self.throughput = Some(tps);
211 self
212 }
213}
214
215#[derive(Clone)]
220pub struct CustomModelManager {
221 pub provider: String,
223 pub models: HashMap<String, ModelInfo>,
225 pub selection_strategy: ModelSelectionStrategy,
227}
228
/// Strategy used by `CustomModelManager::select_model` to choose a model.
///
/// The enum is fieldless, so it is `Copy` and comparable with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelSelectionStrategy {
    /// Pick by an index derived from the current wall-clock second.
    RoundRobin,
    /// Pick the model with the highest combined speed and quality score.
    Weighted,
    /// Currently returns the first model in map iteration order
    /// (no connection counts are tracked per model).
    LeastConnections,
    /// Pick the model with the highest speed score.
    PerformanceBased,
    /// Pick the model with the lowest sum of input and output cost per 1k tokens.
    CostBased,
}
242
243impl CustomModelManager {
244 pub fn new(provider: &str) -> Self {
246 Self {
247 provider: provider.to_string(),
248 models: HashMap::new(),
249 selection_strategy: ModelSelectionStrategy::RoundRobin,
250 }
251 }
252
253 pub fn add_model(&mut self, model: ModelInfo) {
255 self.models.insert(model.name.clone(), model);
256 }
257
258 pub fn remove_model(&mut self, model_name: &str) -> Option<ModelInfo> {
260 self.models.remove(model_name)
261 }
262
263 pub fn get_model(&self, model_name: &str) -> Option<&ModelInfo> {
265 self.models.get(model_name)
266 }
267
268 pub fn list_models(&self) -> Vec<&ModelInfo> {
270 self.models.values().collect()
271 }
272
273 pub fn with_strategy(mut self, strategy: ModelSelectionStrategy) -> Self {
275 self.selection_strategy = strategy;
276 self
277 }
278
279 pub fn select_model(&self) -> Option<&ModelInfo> {
281 if self.models.is_empty() {
282 return None;
283 }
284
285 match self.selection_strategy {
286 ModelSelectionStrategy::RoundRobin => {
287 let models: Vec<&ModelInfo> = self.models.values().collect();
289 let index = (std::time::SystemTime::now()
290 .duration_since(std::time::UNIX_EPOCH)
291 .unwrap()
292 .as_secs() as usize)
293 % models.len();
294 Some(models[index])
295 }
296 ModelSelectionStrategy::Weighted => {
297 self.models.values().max_by_key(|model| {
299 let speed_score = match model.performance.speed {
300 SpeedTier::Fast => 3,
301 SpeedTier::Balanced => 2,
302 SpeedTier::Slow => 1,
303 };
304 let quality_score = match model.performance.quality {
305 QualityTier::Excellent => 3,
306 QualityTier::Good => 2,
307 QualityTier::Basic => 1,
308 };
309 speed_score + quality_score
310 })
311 }
312 ModelSelectionStrategy::LeastConnections => {
313 self.models.values().next()
316 }
317 ModelSelectionStrategy::PerformanceBased => {
318 self.models.values().max_by_key(|model| match model.performance.speed {
320 SpeedTier::Fast => 3,
321 SpeedTier::Balanced => 2,
322 SpeedTier::Slow => 1,
323 })
324 }
325 ModelSelectionStrategy::CostBased => {
326 self.models.values().min_by(|a, b| {
328 let a_cost = a.pricing.input_cost_per_1k + a.pricing.output_cost_per_1k;
329 let b_cost = b.pricing.input_cost_per_1k + b.pricing.output_cost_per_1k;
330 a_cost
331 .partial_cmp(&b_cost)
332 .unwrap_or(std::cmp::Ordering::Equal)
333 })
334 }
335 }
336 }
337
338 pub fn recommend_for(&self, use_case: &str) -> Option<&ModelInfo> {
340 let supported_models: Vec<&ModelInfo> = self
341 .models
342 .values()
343 .filter(|model| model.capabilities.supports(use_case))
344 .collect();
345
346 if supported_models.is_empty() {
347 return None;
348 }
349
350 supported_models.first().copied()
353 }
354
355 pub fn load_from_config(&mut self, config_path: &str) -> Result<(), AiLibError> {
357 let config_content = std::fs::read_to_string(config_path)
358 .map_err(|e| AiLibError::ConfigurationError(format!("Failed to read config: {}", e)))?;
359
360 let models: Vec<ModelInfo> = serde_json::from_str(&config_content).map_err(|e| {
361 AiLibError::ConfigurationError(format!("Failed to parse config: {}", e))
362 })?;
363
364 for model in models {
365 self.add_model(model);
366 }
367
368 Ok(())
369 }
370
371 pub fn save_to_config(&self, config_path: &str) -> Result<(), AiLibError> {
373 let models: Vec<&ModelInfo> = self.models.values().collect();
374 let config_content = serde_json::to_string_pretty(&models).map_err(|e| {
375 AiLibError::ConfigurationError(format!("Failed to serialize config: {}", e))
376 })?;
377
378 std::fs::write(config_path, config_content).map_err(|e| {
379 AiLibError::ConfigurationError(format!("Failed to write config: {}", e))
380 })?;
381
382 Ok(())
383 }
384}
385
386#[derive(Clone)]
391pub struct ModelArray {
392 pub name: String,
394 pub endpoints: Vec<ModelEndpoint>,
396 pub strategy: LoadBalancingStrategy,
398 pub health_check: HealthCheckConfig,
400}
401
/// One endpoint inside a `ModelArray`.
#[derive(Clone, Debug)]
pub struct ModelEndpoint {
    /// Endpoint name; `mark_healthy` / `mark_unhealthy` look endpoints up by it.
    pub name: String,
    /// Name of the model served by this endpoint.
    pub model_name: String,
    /// Endpoint URL.
    pub url: String,
    /// Weight read by the `Weighted` load-balancing strategy.
    pub weight: f32,
    /// Whether the endpoint is eligible for selection.
    pub healthy: bool,
    /// Connection count read by the `LeastConnections` strategy. The methods
    /// of `ModelArray` in this file never update it; callers maintain it.
    pub connection_count: u32,
}
418
/// Strategy used by `ModelArray::select_endpoint` to choose among healthy endpoints.
///
/// The enum is fieldless, so it is `Copy` and comparable with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalancingStrategy {
    /// Pick by an index derived from the current wall-clock second.
    RoundRobin,
    /// Pick the first healthy endpoint whose cumulative weight reaches half
    /// of the total healthy weight.
    Weighted,
    /// Pick the healthy endpoint with the smallest `connection_count`.
    LeastConnections,
    /// Pick the first healthy endpoint.
    HealthBased,
}
431
/// Health-check settings attached to a `ModelArray`.
///
/// No code in this file reads these values; the field notes below describe
/// the apparent intent and should be confirmed against the health checker.
#[derive(Clone, Debug)]
pub struct HealthCheckConfig {
    /// Health-check endpoint; `ModelArray::new` uses `"/health"`.
    pub endpoint: String,
    /// Interval setting; `ModelArray::new` uses 30 seconds.
    pub interval: Duration,
    /// Timeout setting; `ModelArray::new` uses 5 seconds.
    pub timeout: Duration,
    /// Failure limit (presumably before an endpoint is marked unhealthy);
    /// `ModelArray::new` uses 3.
    pub max_failures: u32,
}
444
445impl ModelArray {
446 pub fn new(name: &str) -> Self {
448 Self {
449 name: name.to_string(),
450 endpoints: Vec::new(),
451 strategy: LoadBalancingStrategy::RoundRobin,
452 health_check: HealthCheckConfig {
453 endpoint: "/health".to_string(),
454 interval: Duration::from_secs(30),
455 timeout: Duration::from_secs(5),
456 max_failures: 3,
457 },
458 }
459 }
460
461 pub fn add_endpoint(&mut self, endpoint: ModelEndpoint) {
463 self.endpoints.push(endpoint);
464 }
465
466 pub fn with_strategy(mut self, strategy: LoadBalancingStrategy) -> Self {
468 self.strategy = strategy;
469 self
470 }
471
472 pub fn with_health_check(mut self, config: HealthCheckConfig) -> Self {
474 self.health_check = config;
475 self
476 }
477
478 pub fn select_endpoint(&mut self) -> Option<&mut ModelEndpoint> {
480 if self.endpoints.is_empty() {
481 return None;
482 }
483
484 let healthy_indices: Vec<usize> = self
486 .endpoints
487 .iter()
488 .enumerate()
489 .filter(|(_, endpoint)| endpoint.healthy)
490 .map(|(index, _)| index)
491 .collect();
492
493 if healthy_indices.is_empty() {
494 return None;
495 }
496
497 match self.strategy {
498 LoadBalancingStrategy::RoundRobin => {
499 let index = (std::time::SystemTime::now()
501 .duration_since(std::time::UNIX_EPOCH)
502 .unwrap()
503 .as_secs() as usize)
504 % healthy_indices.len();
505 let endpoint_index = healthy_indices[index];
506 Some(&mut self.endpoints[endpoint_index])
507 }
508 LoadBalancingStrategy::Weighted => {
509 let total_weight: f32 = healthy_indices
511 .iter()
512 .map(|&idx| self.endpoints[idx].weight)
513 .sum();
514 let mut current_weight = 0.0;
515
516 for &idx in &healthy_indices {
517 current_weight += self.endpoints[idx].weight;
518 if current_weight >= total_weight / 2.0 {
519 return Some(&mut self.endpoints[idx]);
520 }
521 }
522
523 let endpoint_index = healthy_indices[0];
525 Some(&mut self.endpoints[endpoint_index])
526 }
527 LoadBalancingStrategy::LeastConnections => {
528 healthy_indices
530 .iter()
531 .min_by_key(|&&idx| self.endpoints[idx].connection_count)
532 .map(|&idx| &mut self.endpoints[idx])
533 }
534 LoadBalancingStrategy::HealthBased => {
535 let endpoint_index = healthy_indices[0];
537 Some(&mut self.endpoints[endpoint_index])
538 }
539 }
540 }
541
542 pub fn mark_unhealthy(&mut self, endpoint_name: &str) {
544 if let Some(endpoint) = self.endpoints.iter_mut().find(|e| e.name == endpoint_name) {
545 endpoint.healthy = false;
546 }
547 }
548
549 pub fn mark_healthy(&mut self, endpoint_name: &str) {
551 if let Some(endpoint) = self.endpoints.iter_mut().find(|e| e.name == endpoint_name) {
552 endpoint.healthy = true;
553 }
554 }
555
556 pub fn is_healthy(&self) -> bool {
558 self.endpoints.iter().any(|endpoint| endpoint.healthy)
559 }
560}
561
562impl Default for ModelCapabilities {
563 fn default() -> Self {
564 Self::new()
565 }
566}
567
568impl Default for PerformanceMetrics {
569 fn default() -> Self {
570 Self::new()
571 }
572}