// aster/providers/provider_registry.rs

use super::base::{ModelInfo, Provider, ProviderMetadata, ProviderType};
use crate::config::DeclarativeProviderConfig;
use crate::model::ModelConfig;
use anyhow::{Context, Result};
use futures::future::BoxFuture;
use std::collections::HashMap;
use std::sync::Arc;

9type ProviderConstructor =
10    Arc<dyn Fn(ModelConfig) -> BoxFuture<'static, Result<Arc<dyn Provider>>> + Send + Sync>;
11
12#[derive(Clone)]
13pub struct ProviderEntry {
14    metadata: ProviderMetadata,
15    pub(crate) constructor: ProviderConstructor,
16    provider_type: ProviderType,
17}
18
19impl ProviderEntry {
20    pub async fn create_with_default_model(&self) -> Result<Arc<dyn Provider>> {
21        let default_model = &self.metadata.default_model;
22        let model_config = ModelConfig::new(default_model.as_str())?;
23        (self.constructor)(model_config).await
24    }
25}
26
27#[derive(Default)]
28pub struct ProviderRegistry {
29    pub(crate) entries: HashMap<String, ProviderEntry>,
30}
31
32impl ProviderRegistry {
33    pub fn new() -> Self {
34        Self {
35            entries: HashMap::new(),
36        }
37    }
38
39    pub fn register<P, F>(&mut self, constructor: F, preferred: bool)
40    where
41        P: Provider + 'static,
42        F: Fn(ModelConfig) -> BoxFuture<'static, Result<P>> + Send + Sync + 'static,
43    {
44        let metadata = P::metadata();
45        let name = metadata.name.clone();
46
47        self.entries.insert(
48            name,
49            ProviderEntry {
50                metadata,
51                constructor: Arc::new(move |model| {
52                    let fut = constructor(model);
53                    Box::pin(async move {
54                        let provider = fut.await?;
55                        Ok(Arc::new(provider) as Arc<dyn Provider>)
56                    })
57                }),
58                provider_type: if preferred {
59                    ProviderType::Preferred
60                } else {
61                    ProviderType::Builtin
62                },
63            },
64        );
65    }
66
67    pub fn register_with_name<P, F>(
68        &mut self,
69        config: &DeclarativeProviderConfig,
70        provider_type: ProviderType,
71        constructor: F,
72    ) where
73        P: Provider + 'static,
74        F: Fn(ModelConfig) -> Result<P> + Send + Sync + 'static,
75    {
76        let base_metadata = P::metadata();
77        let description = config
78            .description
79            .clone()
80            .unwrap_or_else(|| format!("Custom {} provider", config.display_name));
81        let default_model = config
82            .models
83            .first()
84            .map(|m| m.name.clone())
85            .unwrap_or_default();
86        let known_models: Vec<ModelInfo> = config
87            .models
88            .iter()
89            .map(|m| ModelInfo {
90                name: m.name.clone(),
91                context_limit: m.context_limit,
92                input_token_cost: m.input_token_cost,
93                output_token_cost: m.output_token_cost,
94                currency: m.currency.clone(),
95                supports_cache_control: Some(m.supports_cache_control.unwrap_or(false)),
96            })
97            .collect();
98
99        let mut config_keys = base_metadata.config_keys.clone();
100
101        if let Some(api_key_index) = config_keys
102            .iter()
103            .position(|key| key.required && key.secret)
104        {
105            config_keys[api_key_index] =
106                super::base::ConfigKey::new(&config.api_key_env, true, true, None);
107        }
108
109        let custom_metadata = ProviderMetadata {
110            name: config.name.clone(),
111            display_name: config.display_name.clone(),
112            description,
113            default_model,
114            known_models,
115            model_doc_link: base_metadata.model_doc_link,
116            config_keys,
117        };
118
119        self.entries.insert(
120            config.name.clone(),
121            ProviderEntry {
122                metadata: custom_metadata,
123                constructor: Arc::new(move |model| {
124                    let result = constructor(model);
125                    Box::pin(async move {
126                        let provider = result?;
127                        Ok(Arc::new(provider) as Arc<dyn Provider>)
128                    })
129                }),
130                provider_type,
131            },
132        );
133    }
134
135    pub fn with_providers<F>(mut self, setup: F) -> Self
136    where
137        F: FnOnce(&mut Self),
138    {
139        setup(&mut self);
140        self
141    }
142
143    pub async fn create(&self, name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {
144        let entry = self
145            .entries
146            .get(name)
147            .ok_or_else(|| anyhow::anyhow!("Unknown provider: {}", name))?;
148
149        (entry.constructor)(model).await
150    }
151
152    pub fn all_metadata_with_types(&self) -> Vec<(ProviderMetadata, ProviderType)> {
153        self.entries
154            .values()
155            .map(|e| (e.metadata.clone(), e.provider_type))
156            .collect()
157    }
158
159    pub fn remove_custom_providers(&mut self) {
160        self.entries.retain(|name, _| !name.starts_with("custom_"));
161    }
162}