aster/providers/
provider_registry.rs1use super::base::{ModelInfo, Provider, ProviderMetadata, ProviderType};
2use crate::config::DeclarativeProviderConfig;
3use crate::model::ModelConfig;
4use anyhow::Result;
5use futures::future::BoxFuture;
6use std::collections::HashMap;
7use std::sync::Arc;
8
9type ProviderConstructor =
10 Arc<dyn Fn(ModelConfig) -> BoxFuture<'static, Result<Arc<dyn Provider>>> + Send + Sync>;
11
12#[derive(Clone)]
13pub struct ProviderEntry {
14 metadata: ProviderMetadata,
15 pub(crate) constructor: ProviderConstructor,
16 provider_type: ProviderType,
17}
18
19impl ProviderEntry {
20 pub async fn create_with_default_model(&self) -> Result<Arc<dyn Provider>> {
21 let default_model = &self.metadata.default_model;
22 let model_config = ModelConfig::new(default_model.as_str())?;
23 (self.constructor)(model_config).await
24 }
25}
26
27#[derive(Default)]
28pub struct ProviderRegistry {
29 pub(crate) entries: HashMap<String, ProviderEntry>,
30}
31
32impl ProviderRegistry {
33 pub fn new() -> Self {
34 Self {
35 entries: HashMap::new(),
36 }
37 }
38
39 pub fn register<P, F>(&mut self, constructor: F, preferred: bool)
40 where
41 P: Provider + 'static,
42 F: Fn(ModelConfig) -> BoxFuture<'static, Result<P>> + Send + Sync + 'static,
43 {
44 let metadata = P::metadata();
45 let name = metadata.name.clone();
46
47 self.entries.insert(
48 name,
49 ProviderEntry {
50 metadata,
51 constructor: Arc::new(move |model| {
52 let fut = constructor(model);
53 Box::pin(async move {
54 let provider = fut.await?;
55 Ok(Arc::new(provider) as Arc<dyn Provider>)
56 })
57 }),
58 provider_type: if preferred {
59 ProviderType::Preferred
60 } else {
61 ProviderType::Builtin
62 },
63 },
64 );
65 }
66
67 pub fn register_with_name<P, F>(
68 &mut self,
69 config: &DeclarativeProviderConfig,
70 provider_type: ProviderType,
71 constructor: F,
72 ) where
73 P: Provider + 'static,
74 F: Fn(ModelConfig) -> Result<P> + Send + Sync + 'static,
75 {
76 let base_metadata = P::metadata();
77 let description = config
78 .description
79 .clone()
80 .unwrap_or_else(|| format!("Custom {} provider", config.display_name));
81 let default_model = config
82 .models
83 .first()
84 .map(|m| m.name.clone())
85 .unwrap_or_default();
86 let known_models: Vec<ModelInfo> = config
87 .models
88 .iter()
89 .map(|m| ModelInfo {
90 name: m.name.clone(),
91 context_limit: m.context_limit,
92 input_token_cost: m.input_token_cost,
93 output_token_cost: m.output_token_cost,
94 currency: m.currency.clone(),
95 supports_cache_control: Some(m.supports_cache_control.unwrap_or(false)),
96 })
97 .collect();
98
99 let mut config_keys = base_metadata.config_keys.clone();
100
101 if let Some(api_key_index) = config_keys
102 .iter()
103 .position(|key| key.required && key.secret)
104 {
105 config_keys[api_key_index] =
106 super::base::ConfigKey::new(&config.api_key_env, true, true, None);
107 }
108
109 let custom_metadata = ProviderMetadata {
110 name: config.name.clone(),
111 display_name: config.display_name.clone(),
112 description,
113 default_model,
114 known_models,
115 model_doc_link: base_metadata.model_doc_link,
116 config_keys,
117 };
118
119 self.entries.insert(
120 config.name.clone(),
121 ProviderEntry {
122 metadata: custom_metadata,
123 constructor: Arc::new(move |model| {
124 let result = constructor(model);
125 Box::pin(async move {
126 let provider = result?;
127 Ok(Arc::new(provider) as Arc<dyn Provider>)
128 })
129 }),
130 provider_type,
131 },
132 );
133 }
134
135 pub fn with_providers<F>(mut self, setup: F) -> Self
136 where
137 F: FnOnce(&mut Self),
138 {
139 setup(&mut self);
140 self
141 }
142
143 pub async fn create(&self, name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {
144 let entry = self
145 .entries
146 .get(name)
147 .ok_or_else(|| anyhow::anyhow!("Unknown provider: {}", name))?;
148
149 (entry.constructor)(model).await
150 }
151
152 pub fn all_metadata_with_types(&self) -> Vec<(ProviderMetadata, ProviderType)> {
153 self.entries
154 .values()
155 .map(|e| (e.metadata.clone(), e.provider_type))
156 .collect()
157 }
158
159 pub fn remove_custom_providers(&mut self) {
160 self.entries.retain(|name, _| !name.starts_with("custom_"));
161 }
162}