1use crate::{Error, Result};
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::env;
12
13use super::{
14 Message, GenerateOptions, GenerateResult, StreamChunk, FinishReason,
15 Usage, ToolDefinition, MessageRole, MessageContent
16};
17
18#[async_trait]
20pub trait Provider: Send + Sync {
21 fn id(&self) -> &str;
23
24 fn name(&self) -> &str;
26
27 fn base_url(&self) -> &str;
29
30 fn api_version(&self) -> &str;
32
33 async fn list_models(&self) -> Result<Vec<ModelInfo>>;
35
36 async fn get_model(&self, model_id: &str) -> Result<Arc<dyn Model>>;
38
39 async fn health_check(&self) -> Result<ProviderHealth>;
41
42 fn get_config(&self) -> &ProviderConfig;
44
45 async fn update_config(&mut self, config: ProviderConfig) -> Result<()>;
47
48 async fn get_rate_limits(&self) -> Result<RateLimitInfo>;
50
51 async fn get_usage(&self) -> Result<UsageStats>;
53}
54
55#[async_trait]
57pub trait Model: Send + Sync {
58 fn id(&self) -> &str;
60
61 fn name(&self) -> &str;
63
64 fn provider_id(&self) -> &str;
66
67 fn capabilities(&self) -> &ModelCapabilities;
69
70 fn config(&self) -> &ModelConfig;
72
73 async fn generate(
75 &self,
76 messages: Vec<Message>,
77 options: GenerateOptions,
78 ) -> Result<GenerateResult>;
79
80 async fn stream(
82 &self,
83 messages: Vec<Message>,
84 options: GenerateOptions,
85 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send>>>;
86
87 async fn count_tokens(&self, messages: &[Message]) -> Result<u32>;
89
90 async fn estimate_cost(&self, input_tokens: u32, output_tokens: u32) -> Result<f64>;
92
93 fn metadata(&self) -> &ModelMetadata;
95}
96
97use futures::Stream;
98use std::pin::Pin;
99
100#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct ModelInfo {
103 pub id: String,
104 pub name: String,
105 pub description: Option<String>,
106 pub capabilities: ModelCapabilities,
107 pub limits: ModelLimits,
108 pub pricing: ModelPricing,
109 pub release_date: Option<chrono::DateTime<chrono::Utc>>,
110 pub status: ModelStatus,
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct ModelCapabilities {
116 pub text_generation: bool,
118
119 pub tool_calling: bool,
121
122 pub vision: bool,
124
125 pub streaming: bool,
127
128 pub caching: bool,
130
131 pub json_mode: bool,
133
134 pub reasoning: bool,
136
137 pub code_generation: bool,
139
140 pub multilingual: bool,
142
143 pub custom: HashMap<String, serde_json::Value>,
145}
146
147impl Default for ModelCapabilities {
148 fn default() -> Self {
149 Self {
150 text_generation: true,
151 tool_calling: false,
152 vision: false,
153 streaming: true,
154 caching: false,
155 json_mode: false,
156 reasoning: false,
157 code_generation: false,
158 multilingual: false,
159 custom: HashMap::new(),
160 }
161 }
162}
163
164#[derive(Debug, Clone, Serialize, Deserialize)]
166pub struct ModelLimits {
167 pub max_context_tokens: u32,
169
170 pub max_output_tokens: u32,
172
173 pub max_image_size_bytes: Option<u64>,
175
176 pub max_images_per_request: Option<u32>,
178
179 pub max_tool_calls: Option<u32>,
181
182 pub rate_limits: RateLimitInfo,
184}
185
186#[derive(Debug, Clone, Serialize, Deserialize)]
188pub struct ModelPricing {
189 pub input_cost_per_1k: f64,
191
192 pub output_cost_per_1k: f64,
194
195 pub cache_read_cost_per_1k: Option<f64>,
197
198 pub cache_write_cost_per_1k: Option<f64>,
200
201 pub currency: String,
203}
204
205impl Default for ModelPricing {
206 fn default() -> Self {
207 Self {
208 input_cost_per_1k: 0.0,
209 output_cost_per_1k: 0.0,
210 cache_read_cost_per_1k: None,
211 cache_write_cost_per_1k: None,
212 currency: "USD".to_string(),
213 }
214 }
215}
216
217#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
219#[serde(rename_all = "lowercase")]
220pub enum ModelStatus {
221 Active,
223
224 Deprecated,
226
227 Beta,
229
230 Unavailable,
232
233 Discontinued,
235}
236
237#[derive(Debug, Clone, Serialize, Deserialize)]
239pub struct RateLimitInfo {
240 pub requests_per_minute: Option<u32>,
242
243 pub tokens_per_minute: Option<u32>,
245
246 pub tokens_per_day: Option<u32>,
248
249 pub concurrent_requests: Option<u32>,
251
252 pub current_usage: Option<CurrentUsage>,
254}
255
256#[derive(Debug, Clone, Serialize, Deserialize)]
258pub struct CurrentUsage {
259 pub requests_this_minute: u32,
261
262 pub tokens_this_minute: u32,
264
265 pub tokens_today: u32,
267
268 pub active_requests: u32,
270}
271
272#[derive(Debug, Clone, Serialize, Deserialize)]
274pub struct ProviderHealth {
275 pub available: bool,
277
278 pub latency_ms: Option<u64>,
280
281 pub error: Option<String>,
283
284 pub last_check: chrono::DateTime<chrono::Utc>,
286
287 pub details: HashMap<String, serde_json::Value>,
289}
290
291#[derive(Debug, Clone, Serialize, Deserialize)]
293pub struct ProviderConfig {
294 pub provider_id: String,
296
297 pub api_key: Option<String>,
299
300 pub base_url_override: Option<String>,
302
303 pub api_version_override: Option<String>,
305
306 pub timeout_seconds: u64,
308
309 pub max_retries: u32,
311
312 pub retry_delay_ms: u64,
314
315 pub custom_headers: HashMap<String, String>,
317
318 pub organization_id: Option<String>,
320
321 pub project_id: Option<String>,
323
324 pub extra: HashMap<String, serde_json::Value>,
326}
327
328impl Default for ProviderConfig {
329 fn default() -> Self {
330 Self {
331 provider_id: String::new(),
332 api_key: None,
333 base_url_override: None,
334 api_version_override: None,
335 timeout_seconds: 60,
336 max_retries: 3,
337 retry_delay_ms: 1000,
338 custom_headers: HashMap::new(),
339 organization_id: None,
340 project_id: None,
341 extra: HashMap::new(),
342 }
343 }
344}
345
346#[derive(Debug, Clone, Serialize, Deserialize)]
348pub struct ModelConfig {
349 pub model_id: String,
351
352 pub default_temperature: Option<f32>,
354
355 pub default_max_tokens: Option<u32>,
357
358 pub default_top_p: Option<f32>,
360
361 pub default_stop_sequences: Vec<String>,
363
364 pub use_caching: bool,
366
367 pub options: HashMap<String, serde_json::Value>,
369}
370
371impl Default for ModelConfig {
372 fn default() -> Self {
373 Self {
374 model_id: String::new(),
375 default_temperature: None,
376 default_max_tokens: None,
377 default_top_p: None,
378 default_stop_sequences: Vec::new(),
379 use_caching: false,
380 options: HashMap::new(),
381 }
382 }
383}
384
385#[derive(Debug, Clone, Serialize, Deserialize)]
387pub struct ModelMetadata {
388 pub family: String,
390
391 pub parameters: Option<String>,
393
394 pub training_cutoff: Option<chrono::DateTime<chrono::Utc>>,
396
397 pub version: Option<String>,
399
400 pub extra: HashMap<String, serde_json::Value>,
402}
403
404impl Default for ModelMetadata {
405 fn default() -> Self {
406 Self {
407 family: String::new(),
408 parameters: None,
409 training_cutoff: None,
410 version: None,
411 extra: HashMap::new(),
412 }
413 }
414}
415
416#[derive(Debug, Clone, Serialize, Deserialize)]
418pub struct UsageStats {
419 pub total_requests: u64,
421
422 pub total_tokens: u64,
424
425 pub total_cost: f64,
427
428 pub currency: String,
430
431 pub by_model: HashMap<String, ModelUsage>,
433
434 pub by_period: HashMap<String, PeriodUsage>,
436}
437
438#[derive(Debug, Clone, Serialize, Deserialize)]
440pub struct ModelUsage {
441 pub requests: u64,
443
444 pub input_tokens: u64,
446
447 pub output_tokens: u64,
449
450 pub cache_hits: u64,
452
453 pub cost: f64,
455
456 pub avg_latency_ms: f64,
458}
459
460#[derive(Debug, Clone, Serialize, Deserialize)]
462pub struct PeriodUsage {
463 pub start: chrono::DateTime<chrono::Utc>,
465
466 pub end: chrono::DateTime<chrono::Utc>,
468
469 pub requests: u64,
471
472 pub tokens: u64,
474
475 pub cost: f64,
477}
478
479#[derive(Debug, Clone, Serialize, Deserialize)]
481pub struct Cost {
482 pub input_per_1k: f64,
484
485 pub output_per_1k: f64,
487
488 pub currency: String,
490}
491
492#[derive(Debug, Clone, Serialize, Deserialize)]
494pub struct Limits {
495 pub max_context_tokens: u32,
497
498 pub max_output_tokens: u32,
500}
501
502#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
504#[serde(rename_all = "lowercase")]
505pub enum ProviderSource {
506 Official,
508
509 Community,
511
512 Custom,
514}
515
516#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
518#[serde(rename_all = "lowercase")]
519pub enum ProviderStatus {
520 Active,
522
523 Beta,
525
526 Deprecated,
528
529 Unavailable,
531}
532
533#[derive(Debug, Clone, Serialize, Deserialize)]
535pub struct RetryConfig {
536 pub max_retries: u32,
538
539 pub initial_delay_ms: u64,
541
542 pub max_delay_ms: u64,
544
545 pub multiplier: f32,
547}
548
549impl Default for RetryConfig {
550 fn default() -> Self {
551 Self {
552 max_retries: 3,
553 initial_delay_ms: 1000,
554 max_delay_ms: 10000,
555 multiplier: 2.0,
556 }
557 }
558}
559
560pub async fn retry_with_backoff<F, T, E>(
562 config: &RetryConfig,
563 operation: F,
564) -> Result<T>
565where
566 F: Fn() -> futures::future::BoxFuture<'static, Result<T>>,
567{
568 use tokio::time::{sleep, Duration};
569
570 let mut attempts = 0;
571 let mut delay = config.initial_delay_ms;
572
573 loop {
574 match operation().await {
575 Ok(result) => return Ok(result),
576 Err(e) if attempts < config.max_retries => {
577 attempts += 1;
578 sleep(Duration::from_millis(delay)).await;
579 delay = (delay as f32 * config.multiplier) as u64;
580 delay = delay.min(config.max_delay_ms);
581 }
582 Err(e) => return Err(e),
583 }
584 }
585}
586
587pub struct ProviderRegistry {
589 providers: HashMap<String, Arc<dyn Provider>>,
590 models: HashMap<String, Arc<dyn Model>>,
591 default_provider: Option<String>,
592 storage: Arc<dyn crate::auth::AuthStorage>,
593}
594
595impl ProviderRegistry {
596 pub fn new(storage: Arc<dyn crate::auth::AuthStorage>) -> Self {
598 Self {
599 providers: HashMap::new(),
600 models: HashMap::new(),
601 default_provider: None,
602 storage,
603 }
604 }
605
606 pub fn register_provider(&mut self, provider: Arc<dyn Provider>) -> Result<()> {
608 let provider_id = provider.id().to_string();
609
610 if self.providers.contains_key(&provider_id) {
611 return Err(Error::Other(anyhow::anyhow!(
612 "Provider {} is already registered",
613 provider_id
614 )));
615 }
616
617 self.providers.insert(provider_id, provider);
618 Ok(())
619 }
620
621 pub fn get_provider(&self, provider_id: &str) -> Result<Arc<dyn Provider>> {
623 self.providers
624 .get(provider_id)
625 .cloned()
626 .ok_or_else(|| Error::Other(anyhow::anyhow!("Provider {} not found", provider_id)))
627 }
628
629 pub fn list_providers(&self) -> Vec<String> {
631 self.providers.keys().cloned().collect()
632 }
633
634 pub async fn get_model(&mut self, provider_id: &str, model_id: &str) -> Result<Arc<dyn Model>> {
636 let key = format!("{}/{}", provider_id, model_id);
637
638 if let Some(model) = self.models.get(&key) {
640 return Ok(model.clone());
641 }
642
643 let provider = self.get_provider(provider_id)?;
645 let model = provider.get_model(model_id).await?;
646
647 self.models.insert(key, model.clone());
649
650 Ok(model)
651 }
652
653 pub fn parse_model_string(&self, model_string: &str) -> Result<(String, String)> {
655 if let Some((provider, model)) = model_string.split_once('/') {
656 Ok((provider.to_string(), model.to_string()))
657 } else if let Some((provider, model)) = model_string.split_once(':') {
658 Ok((provider.to_string(), model.to_string()))
659 } else {
660 if let Some(default_provider) = &self.default_provider {
662 Ok((default_provider.clone(), model_string.to_string()))
663 } else {
664 Err(Error::Other(anyhow::anyhow!(
665 "Invalid model string format: {}. Expected 'provider/model' or 'provider:model'",
666 model_string
667 )))
668 }
669 }
670 }
671
672 pub fn set_default_provider(&mut self, provider_id: &str) -> Result<()> {
674 if !self.providers.contains_key(provider_id) {
675 return Err(Error::Other(anyhow::anyhow!(
676 "Provider {} is not registered",
677 provider_id
678 )));
679 }
680
681 self.default_provider = Some(provider_id.to_string());
682 Ok(())
683 }
684
685 pub fn get_default_provider(&self) -> Option<&str> {
687 self.default_provider.as_deref()
688 }
689
690 pub async fn list_all_models(&self) -> Result<Vec<ModelInfo>> {
692 let mut all_models = Vec::new();
693
694 for provider in self.providers.values() {
695 match provider.list_models().await {
696 Ok(models) => all_models.extend(models),
697 Err(e) => {
698 tracing::warn!("Failed to list models for provider {}: {}", provider.id(), e);
699 }
700 }
701 }
702
703 Ok(all_models)
704 }
705
706 pub async fn get_all_provider_health(&self) -> HashMap<String, ProviderHealth> {
708 let mut health_status = HashMap::new();
709
710 for (id, provider) in &self.providers {
711 match provider.health_check().await {
712 Ok(health) => {
713 health_status.insert(id.clone(), health);
714 }
715 Err(e) => {
716 health_status.insert(
717 id.clone(),
718 ProviderHealth {
719 available: false,
720 latency_ms: None,
721 error: Some(e.to_string()),
722 last_check: chrono::Utc::now(),
723 details: HashMap::new(),
724 },
725 );
726 }
727 }
728 }
729
730 health_status
731 }
732
733 pub fn clear_model_cache(&mut self) {
735 self.models.clear();
736 }
737
738 pub fn remove_provider(&mut self, provider_id: &str) -> Result<()> {
740 if !self.providers.contains_key(provider_id) {
741 return Err(Error::Other(anyhow::anyhow!(
742 "Provider {} is not registered",
743 provider_id
744 )));
745 }
746
747 self.providers.remove(provider_id);
749
750 self.models.retain(|key, _| !key.starts_with(&format!("{}/", provider_id)));
752
753 if self.default_provider.as_deref() == Some(provider_id) {
755 self.default_provider = None;
756 }
757
758 Ok(())
759 }
760
761 pub async fn discover_from_env(&mut self) -> Result<()> {
763 if env::var("ANTHROPIC_API_KEY").is_ok() {
765 if let Ok(provider) = self.create_anthropic_provider().await {
766 self.register_provider(provider)?;
767 }
768 }
769
770 if env::var("OPENAI_API_KEY").is_ok() {
772 if let Ok(provider) = self.create_openai_provider().await {
773 self.register_provider(provider)?;
774 }
775 }
776
777 if env::var("GITHUB_TOKEN").is_ok() || env::var("GITHUB_COPILOT_TOKEN").is_ok() {
779 if let Ok(provider) = self.create_github_copilot_provider().await {
780 self.register_provider(provider)?;
781 }
782 }
783
784 Ok(())
785 }
786
787 pub async fn discover_from_storage(&mut self) -> Result<()> {
789 if let Ok(Some(_)) = self.storage.get("anthropic").await {
791 if let Ok(provider) = self.create_anthropic_provider().await {
792 self.register_provider(provider)?;
793 }
794 }
795
796 if let Ok(Some(_)) = self.storage.get("openai").await {
798 if let Ok(provider) = self.create_openai_provider().await {
799 self.register_provider(provider)?;
800 }
801 }
802
803 if let Ok(Some(_)) = self.storage.get("github-copilot").await {
805 if let Ok(provider) = self.create_github_copilot_provider().await {
806 self.register_provider(provider)?;
807 }
808 }
809
810 Ok(())
811 }
812
813 pub async fn initialize_all(&mut self) -> Result<()> {
815 let provider_ids: Vec<String> = self.providers.keys().cloned().collect();
816
817 for provider_id in provider_ids {
818 match self.providers.get(&provider_id) {
819 Some(provider) => {
820 if let Err(e) = provider.health_check().await {
822 tracing::warn!("Failed to initialize provider {}: {}", provider_id, e);
823 }
824 }
825 None => continue,
826 }
827 }
828
829 Ok(())
830 }
831
832 pub async fn load_models_dev(&mut self) -> Result<()> {
834 tracing::info!("Loading models from models.dev (using built-in configs for now)");
837 Ok(())
838 }
839
840 pub async fn load_configs(&mut self, path: &str) -> Result<()> {
842 use std::path::Path;
843 use tokio::fs;
844
845 let path = Path::new(path);
846 if !path.exists() {
847 return Err(Error::Other(anyhow::anyhow!(
848 "Configuration file not found: {}",
849 path.display()
850 )));
851 }
852
853 let contents = fs::read_to_string(path).await?;
854 let configs: HashMap<String, ProviderConfig> = serde_json::from_str(&contents)?;
855
856 for (provider_id, config) in configs {
857 if self.providers.contains_key(&provider_id) {
860 tracing::warn!("Cannot update config for provider {} - providers are immutable through Arc", provider_id);
861 }
862 }
863
864 Ok(())
865 }
866
867 pub async fn get(&self, provider_id: &str) -> Option<Arc<dyn Provider>> {
869 self.providers.get(provider_id).cloned()
870 }
871
872 pub fn parse_model(model_str: &str) -> (String, String) {
874 if let Some((provider, model)) = model_str.split_once('/') {
875 (provider.to_string(), model.to_string())
876 } else if let Some((provider, model)) = model_str.split_once(':') {
877 (provider.to_string(), model.to_string())
878 } else {
879 ("anthropic".to_string(), model_str.to_string())
881 }
882 }
883
884 pub async fn get_default_model(&self, provider_id: &str) -> Result<Arc<dyn Model>> {
886 let provider = self.get_provider(provider_id)?;
887
888 let models = provider.list_models().await?;
890 if let Some(default_model) = models.iter().find(|m| m.status == ModelStatus::Active) {
891 provider.get_model(&default_model.id).await
892 } else if let Some(first_model) = models.first() {
893 provider.get_model(&first_model.id).await
894 } else {
895 Err(Error::Other(anyhow::anyhow!(
896 "Provider {} has no available models",
897 provider_id
898 )))
899 }
900 }
901
902 pub async fn available(&self) -> Vec<String> {
904 let mut available = Vec::new();
905
906 for (id, provider) in &self.providers {
907 if let Ok(health) = provider.health_check().await {
908 if health.available {
909 available.push(id.clone());
910 }
911 }
912 }
913
914 available
915 }
916
917 pub async fn list(&self) -> Vec<String> {
919 self.providers.keys().cloned().collect()
920 }
921
922 pub async fn register(&mut self, provider: Arc<dyn Provider>) {
924 let provider_id = provider.id().to_string();
925 self.providers.insert(provider_id, provider);
926 }
927
928 async fn create_anthropic_provider(&self) -> Result<Arc<dyn Provider>> {
930 Err(Error::Other(anyhow::anyhow!("Anthropic provider creation not implemented in this context")))
933 }
934
935 async fn create_openai_provider(&self) -> Result<Arc<dyn Provider>> {
936 Err(Error::Other(anyhow::anyhow!("OpenAI provider creation not implemented in this context")))
939 }
940
941 async fn create_github_copilot_provider(&self) -> Result<Arc<dyn Provider>> {
942 Err(Error::Other(anyhow::anyhow!("GitHub Copilot provider creation not implemented in this context")))
945 }
946}
947
#[cfg(test)]
mod tests {
    use super::*;

    /// `parse_model` handles both separators and falls back to "anthropic"
    /// for a bare model name.
    #[test]
    fn test_parse_model_string() {
        let cases = [
            ("anthropic/claude-3-opus", "anthropic", "claude-3-opus"),
            ("openai:gpt-4", "openai", "gpt-4"),
            ("claude-3-opus", "anthropic", "claude-3-opus"),
        ];

        for &(input, expected_provider, expected_model) in cases.iter() {
            let (provider, model) = ProviderRegistry::parse_model(input);
            assert_eq!(provider, expected_provider);
            assert_eq!(model, expected_model);
        }
    }

    /// Default capabilities enable text generation and streaming only.
    #[test]
    fn test_model_capabilities_default() {
        let ModelCapabilities {
            text_generation,
            tool_calling,
            vision,
            streaming,
            ..
        } = ModelCapabilities::default();

        assert!(text_generation);
        assert!(!tool_calling);
        assert!(!vision);
        assert!(streaming);
    }

    /// Default provider config carries the documented timeout/retry values.
    #[test]
    fn test_provider_config_default() {
        let defaults = ProviderConfig::default();
        assert_eq!(defaults.timeout_seconds, 60);
        assert_eq!(defaults.max_retries, 3);
        assert_eq!(defaults.retry_delay_ms, 1000);
    }
}