1use crate::types::AiLibError;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::time::Duration;
5
/// Descriptive record for a single model: identity, capabilities, pricing and
/// performance characteristics.
///
/// The type is serde-serializable; `CustomModelManager::load_from_config` and
/// `CustomModelManager::save_to_config` read and write JSON arrays of it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Model identifier. `CustomModelManager::add_model` uses it as the map key.
    pub name: String,
    /// Presumably a human-readable label; not read by any code in this file.
    pub display_name: String,
    /// Free-form description; not read by any code in this file.
    pub description: String,
    /// Feature flags queried by `CustomModelManager::recommend_for`.
    pub capabilities: ModelCapabilities,
    /// Per-1k-token prices, used by the `CostBased` selection strategy.
    pub pricing: PricingInfo,
    /// Speed/quality tiers, used by the `Weighted` and `PerformanceBased`
    /// selection strategies.
    pub performance: PerformanceMetrics,
    /// Arbitrary string key/value pairs; not interpreted by any code in this file.
    pub metadata: HashMap<String, String>,
}
27
/// Set of boolean feature flags describing what a model can do, plus an
/// optional context-window size.
///
/// Each boolean flag can be queried by name through `ModelCapabilities::supports`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCapabilities {
    /// Chat support. `ModelCapabilities::new` enables this by default.
    pub chat: bool,
    /// Code-generation support.
    pub code_generation: bool,
    /// Multimodal support.
    pub multimodal: bool,
    /// Function-calling support.
    pub function_calling: bool,
    /// Tool-use support.
    pub tool_use: bool,
    /// Multilingual support.
    pub multilingual: bool,
    /// Context window size, if known (presumably measured in tokens — TODO confirm).
    pub context_window: Option<u32>,
}
46
47impl ModelCapabilities {
48 pub fn new() -> Self {
50 Self {
51 chat: true,
52 code_generation: false,
53 multimodal: false,
54 function_calling: false,
55 tool_use: false,
56 multilingual: false,
57 context_window: None,
58 }
59 }
60
61 pub fn with_chat(mut self) -> Self {
63 self.chat = true;
64 self
65 }
66
67 pub fn with_code_generation(mut self) -> Self {
69 self.code_generation = true;
70 self
71 }
72
73 pub fn with_multimodal(mut self) -> Self {
75 self.multimodal = true;
76 self
77 }
78
79 pub fn with_function_calling(mut self) -> Self {
81 self.function_calling = true;
82 self
83 }
84
85 pub fn with_tool_use(mut self) -> Self {
87 self.tool_use = true;
88 self
89 }
90
91 pub fn with_multilingual(mut self) -> Self {
93 self.multilingual = true;
94 self
95 }
96
97 pub fn with_context_window(mut self, size: u32) -> Self {
99 self.context_window = Some(size);
100 self
101 }
102
103 pub fn supports(&self, capability: &str) -> bool {
105 match capability {
106 "chat" => self.chat,
107 "code_generation" => self.code_generation,
108 "multimodal" => self.multimodal,
109 "function_calling" => self.function_calling,
110 "tool_use" => self.tool_use,
111 "multilingual" => self.multilingual,
112 _ => false,
113 }
114 }
115}
116
/// Token pricing for a model, expressed per 1,000 tokens.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PricingInfo {
    /// Price per 1,000 input tokens (see `PricingInfo::calculate_cost`).
    pub input_cost_per_1k: f64,
    /// Price per 1,000 output tokens (see `PricingInfo::calculate_cost`).
    pub output_cost_per_1k: f64,
    /// Currency label for the prices; `PricingInfo::new` sets it to `"USD"`.
    pub currency: String,
}
127
128impl PricingInfo {
129 pub fn new(input_cost_per_1k: f64, output_cost_per_1k: f64) -> Self {
131 Self {
132 input_cost_per_1k,
133 output_cost_per_1k,
134 currency: "USD".to_string(),
135 }
136 }
137
138 pub fn with_currency(mut self, currency: &str) -> Self {
140 self.currency = currency.to_string();
141 self
142 }
143
144 pub fn calculate_cost(&self, input_tokens: u32, output_tokens: u32) -> f64 {
146 let input_cost = (input_tokens as f64 / 1000.0) * self.input_cost_per_1k;
147 let output_cost = (output_tokens as f64 / 1000.0) * self.output_cost_per_1k;
148 input_cost + output_cost
149 }
150}
151
/// Coarse performance description of a model: speed and quality tiers plus
/// optional measured figures.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceMetrics {
    /// Speed tier; `PerformanceMetrics::new` sets it to `SpeedTier::Balanced`.
    pub speed: SpeedTier,
    /// Quality tier; `PerformanceMetrics::new` sets it to `QualityTier::Good`.
    pub quality: QualityTier,
    /// Average response time, if known.
    pub avg_response_time: Option<Duration>,
    /// Throughput, if known. The builder names its argument `tps`, so this is
    /// presumably tokens per second — TODO confirm the unit.
    pub throughput: Option<f64>,
}
164
165#[derive(Debug, Clone, Serialize, Deserialize)]
166pub enum SpeedTier {
167 Fast,
168 Balanced,
169 Slow,
170}
171
172#[derive(Debug, Clone, Serialize, Deserialize)]
173pub enum QualityTier {
174 Basic,
175 Good,
176 Excellent,
177}
178
179impl PerformanceMetrics {
180 pub fn new() -> Self {
182 Self {
183 speed: SpeedTier::Balanced,
184 quality: QualityTier::Good,
185 avg_response_time: None,
186 throughput: None,
187 }
188 }
189
190 pub fn with_speed(mut self, speed: SpeedTier) -> Self {
192 self.speed = speed;
193 self
194 }
195
196 pub fn with_quality(mut self, quality: QualityTier) -> Self {
198 self.quality = quality;
199 self
200 }
201
202 pub fn with_avg_response_time(mut self, time: Duration) -> Self {
204 self.avg_response_time = Some(time);
205 self
206 }
207
208 pub fn with_throughput(mut self, tps: f64) -> Self {
210 self.throughput = Some(tps);
211 self
212 }
213}
214
215#[derive(Clone)]
220pub struct CustomModelManager {
221 pub provider: String,
223 pub models: HashMap<String, ModelInfo>,
225 pub selection_strategy: ModelSelectionStrategy,
227}
228
/// Policy used by `CustomModelManager::select_model` to choose a model.
///
/// Derives `PartialEq`/`Eq` so strategies can be compared directly with `==`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ModelSelectionStrategy {
    /// Rotate through the registered models over time.
    RoundRobin,
    /// Prefer the model with the highest combined speed and quality score.
    Weighted,
    /// Intended to prefer the least-loaded model; the manager tracks no
    /// connection counts, so it currently yields the first model iterated.
    LeastConnections,
    /// Prefer the model with the highest speed score.
    PerformanceBased,
    /// Prefer the model with the lowest combined input and output price.
    CostBased,
}
242
243impl CustomModelManager {
244 pub fn new(provider: &str) -> Self {
246 Self {
247 provider: provider.to_string(),
248 models: HashMap::new(),
249 selection_strategy: ModelSelectionStrategy::RoundRobin,
250 }
251 }
252
253 pub fn add_model(&mut self, model: ModelInfo) {
255 self.models.insert(model.name.clone(), model);
256 }
257
258 pub fn remove_model(&mut self, model_name: &str) -> Option<ModelInfo> {
260 self.models.remove(model_name)
261 }
262
263 pub fn get_model(&self, model_name: &str) -> Option<&ModelInfo> {
265 self.models.get(model_name)
266 }
267
268 pub fn list_models(&self) -> Vec<&ModelInfo> {
270 self.models.values().collect()
271 }
272
273 pub fn with_strategy(mut self, strategy: ModelSelectionStrategy) -> Self {
275 self.selection_strategy = strategy;
276 self
277 }
278
279 pub fn select_model(&self) -> Option<&ModelInfo> {
281 if self.models.is_empty() {
282 return None;
283 }
284
285 match self.selection_strategy {
286 ModelSelectionStrategy::RoundRobin => {
287 let models: Vec<&ModelInfo> = self.models.values().collect();
289 let index = (std::time::SystemTime::now()
290 .duration_since(std::time::UNIX_EPOCH)
291 .unwrap()
292 .as_secs() as usize)
293 % models.len();
294 Some(models[index])
295 }
296 ModelSelectionStrategy::Weighted => {
297 self.models.values().max_by_key(|model| {
299 let speed_score = match model.performance.speed {
300 SpeedTier::Fast => 3,
301 SpeedTier::Balanced => 2,
302 SpeedTier::Slow => 1,
303 };
304 let quality_score = match model.performance.quality {
305 QualityTier::Excellent => 3,
306 QualityTier::Good => 2,
307 QualityTier::Basic => 1,
308 };
309 speed_score + quality_score
310 })
311 }
312 ModelSelectionStrategy::LeastConnections => {
313 self.models.values().next()
316 }
317 ModelSelectionStrategy::PerformanceBased => {
318 self.models.values().max_by_key(|model| {
320 let speed_score = match model.performance.speed {
321 SpeedTier::Fast => 3,
322 SpeedTier::Balanced => 2,
323 SpeedTier::Slow => 1,
324 };
325 speed_score
326 })
327 }
328 ModelSelectionStrategy::CostBased => {
329 self.models.values().min_by(|a, b| {
331 let a_cost = a.pricing.input_cost_per_1k + a.pricing.output_cost_per_1k;
332 let b_cost = b.pricing.input_cost_per_1k + b.pricing.output_cost_per_1k;
333 a_cost
334 .partial_cmp(&b_cost)
335 .unwrap_or(std::cmp::Ordering::Equal)
336 })
337 }
338 }
339 }
340
341 pub fn recommend_for(&self, use_case: &str) -> Option<&ModelInfo> {
343 let supported_models: Vec<&ModelInfo> = self
344 .models
345 .values()
346 .filter(|model| model.capabilities.supports(use_case))
347 .collect();
348
349 if supported_models.is_empty() {
350 return None;
351 }
352
353 supported_models.first().copied()
356 }
357
358 pub fn load_from_config(&mut self, config_path: &str) -> Result<(), AiLibError> {
360 let config_content = std::fs::read_to_string(config_path)
361 .map_err(|e| AiLibError::ConfigurationError(format!("Failed to read config: {}", e)))?;
362
363 let models: Vec<ModelInfo> = serde_json::from_str(&config_content).map_err(|e| {
364 AiLibError::ConfigurationError(format!("Failed to parse config: {}", e))
365 })?;
366
367 for model in models {
368 self.add_model(model);
369 }
370
371 Ok(())
372 }
373
374 pub fn save_to_config(&self, config_path: &str) -> Result<(), AiLibError> {
376 let models: Vec<&ModelInfo> = self.models.values().collect();
377 let config_content = serde_json::to_string_pretty(&models).map_err(|e| {
378 AiLibError::ConfigurationError(format!("Failed to serialize config: {}", e))
379 })?;
380
381 std::fs::write(config_path, config_content).map_err(|e| {
382 AiLibError::ConfigurationError(format!("Failed to write config: {}", e))
383 })?;
384
385 Ok(())
386 }
387}
388
389pub struct ModelArray {
394 pub name: String,
396 pub endpoints: Vec<ModelEndpoint>,
398 pub strategy: LoadBalancingStrategy,
400 pub health_check: HealthCheckConfig,
402}
403
/// One endpoint inside a `ModelArray`.
#[derive(Debug, Clone)]
pub struct ModelEndpoint {
    /// Endpoint name; `ModelArray::mark_healthy` / `mark_unhealthy` look
    /// endpoints up by this value.
    pub name: String,
    /// Presumably the name of the model served by this endpoint; not read by
    /// any code in this file.
    pub model_name: String,
    /// Endpoint address; not interpreted by any code in this file.
    pub url: String,
    /// Weight used by the `Weighted` load-balancing strategy.
    pub weight: f32,
    /// Whether the endpoint is eligible for selection; unhealthy endpoints are
    /// skipped by `ModelArray::select_endpoint`.
    pub healthy: bool,
    /// Connection count compared by the `LeastConnections` strategy. Nothing in
    /// this file updates it; callers are presumably expected to maintain it.
    pub connection_count: u32,
}
420
/// Policy used by `ModelArray::select_endpoint` to choose among healthy
/// endpoints.
///
/// Derives `PartialEq`/`Eq` so strategies can be compared directly with `==`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoadBalancingStrategy {
    /// Rotate through the healthy endpoints over time.
    RoundRobin,
    /// Choose based on endpoint weights.
    Weighted,
    /// Choose the healthy endpoint with the smallest connection count.
    LeastConnections,
    /// Choose the first healthy endpoint.
    HealthBased,
}
433
/// Health-check settings attached to a `ModelArray`.
///
/// `ModelArray::new` fills in `"/health"`, a 30-second interval, a 5-second
/// timeout and 3 allowed failures. No code in this file performs the checks.
#[derive(Debug, Clone)]
pub struct HealthCheckConfig {
    /// Health-check endpoint (presumably a path appended to the endpoint URL —
    /// TODO confirm against the code that performs the check).
    pub endpoint: String,
    /// Interval between checks.
    pub interval: Duration,
    /// Timeout for a single check.
    pub timeout: Duration,
    /// Failure threshold (presumably consecutive failures before an endpoint is
    /// marked unhealthy — TODO confirm).
    pub max_failures: u32,
}
446
447impl ModelArray {
448 pub fn new(name: &str) -> Self {
450 Self {
451 name: name.to_string(),
452 endpoints: Vec::new(),
453 strategy: LoadBalancingStrategy::RoundRobin,
454 health_check: HealthCheckConfig {
455 endpoint: "/health".to_string(),
456 interval: Duration::from_secs(30),
457 timeout: Duration::from_secs(5),
458 max_failures: 3,
459 },
460 }
461 }
462
463 pub fn add_endpoint(&mut self, endpoint: ModelEndpoint) {
465 self.endpoints.push(endpoint);
466 }
467
468 pub fn with_strategy(mut self, strategy: LoadBalancingStrategy) -> Self {
470 self.strategy = strategy;
471 self
472 }
473
474 pub fn with_health_check(mut self, config: HealthCheckConfig) -> Self {
476 self.health_check = config;
477 self
478 }
479
480 pub fn select_endpoint(&mut self) -> Option<&mut ModelEndpoint> {
482 if self.endpoints.is_empty() {
483 return None;
484 }
485
486 let healthy_indices: Vec<usize> = self
488 .endpoints
489 .iter()
490 .enumerate()
491 .filter(|(_, endpoint)| endpoint.healthy)
492 .map(|(index, _)| index)
493 .collect();
494
495 if healthy_indices.is_empty() {
496 return None;
497 }
498
499 match self.strategy {
500 LoadBalancingStrategy::RoundRobin => {
501 let index = (std::time::SystemTime::now()
503 .duration_since(std::time::UNIX_EPOCH)
504 .unwrap()
505 .as_secs() as usize)
506 % healthy_indices.len();
507 let endpoint_index = healthy_indices[index];
508 Some(&mut self.endpoints[endpoint_index])
509 }
510 LoadBalancingStrategy::Weighted => {
511 let total_weight: f32 = healthy_indices
513 .iter()
514 .map(|&idx| self.endpoints[idx].weight)
515 .sum();
516 let mut current_weight = 0.0;
517
518 for &idx in &healthy_indices {
519 current_weight += self.endpoints[idx].weight;
520 if current_weight >= total_weight / 2.0 {
521 return Some(&mut self.endpoints[idx]);
522 }
523 }
524
525 let endpoint_index = healthy_indices[0];
527 Some(&mut self.endpoints[endpoint_index])
528 }
529 LoadBalancingStrategy::LeastConnections => {
530 healthy_indices
532 .iter()
533 .min_by_key(|&&idx| self.endpoints[idx].connection_count)
534 .map(|&idx| &mut self.endpoints[idx])
535 }
536 LoadBalancingStrategy::HealthBased => {
537 let endpoint_index = healthy_indices[0];
539 Some(&mut self.endpoints[endpoint_index])
540 }
541 }
542 }
543
544 pub fn mark_unhealthy(&mut self, endpoint_name: &str) {
546 if let Some(endpoint) = self.endpoints.iter_mut().find(|e| e.name == endpoint_name) {
547 endpoint.healthy = false;
548 }
549 }
550
551 pub fn mark_healthy(&mut self, endpoint_name: &str) {
553 if let Some(endpoint) = self.endpoints.iter_mut().find(|e| e.name == endpoint_name) {
554 endpoint.healthy = true;
555 }
556 }
557
558 pub fn is_healthy(&self) -> bool {
560 self.endpoints.iter().any(|endpoint| endpoint.healthy)
561 }
562}
563
564impl Default for ModelCapabilities {
565 fn default() -> Self {
566 Self::new()
567 }
568}
569
570impl Default for PerformanceMetrics {
571 fn default() -> Self {
572 Self::new()
573 }
574}