1use crate::{Provider, ProviderConfig, build_provider};
5use anyhow::{Result, anyhow, bail};
6use async_stream::try_stream;
7use compact_str::CompactString;
8use futures_core::Stream;
9use futures_util::StreamExt;
10use std::collections::BTreeMap;
11use std::sync::{Arc, RwLock};
12use wcore::model::{Model, Response, StreamChunk, default_context_limit};
13
/// Registry of model providers keyed by name, with one name marked active.
///
/// All state lives behind an `Arc<RwLock<_>>`; the `Clone` impl copies the
/// `Arc`, so clones of a `ProviderManager` share one registry and one active
/// selection.
pub struct ProviderManager {
    /// Shared, lock-protected registry state.
    inner: Arc<RwLock<Inner>>,
}
22
/// Lock-protected state behind [`ProviderManager`].
struct Inner {
    /// Registered providers, keyed by the name they were added under
    /// (`config.model` for providers added through `add_config`).
    providers: BTreeMap<CompactString, Provider>,
    /// Name of the active provider. Not guaranteed to be a key of
    /// `providers`: `ProviderManager::new` sets it before any provider exists.
    active: CompactString,
    /// HTTP client; `add_config` passes a clone of it to `build_provider`.
    client: reqwest::Client,
}
31
32#[derive(Debug, Clone)]
34pub struct ProviderEntry {
35 pub name: CompactString,
37 pub active: bool,
39}
40
41impl ProviderManager {
42 pub fn new(active: impl Into<CompactString>) -> Self {
46 Self {
47 inner: Arc::new(RwLock::new(Inner {
48 providers: BTreeMap::new(),
49 active: active.into(),
50 client: reqwest::Client::new(),
51 })),
52 }
53 }
54
55 pub async fn from_configs(configs: &[ProviderConfig]) -> Result<Self> {
61 if configs.is_empty() {
62 bail!("at least one provider config is required");
63 }
64 let manager = Self::new(configs[0].model.clone());
65 for config in configs {
66 manager.add_config(config).await?;
67 }
68 Ok(manager)
69 }
70
71 pub fn add_provider(&self, name: impl Into<CompactString>, provider: Provider) -> Result<()> {
73 let mut inner = self
74 .inner
75 .write()
76 .map_err(|_| anyhow!("provider lock poisoned"))?;
77 inner.providers.insert(name.into(), provider);
78 Ok(())
79 }
80
81 pub async fn add_config(&self, config: &ProviderConfig) -> Result<()> {
83 config.validate()?;
84 let client = {
85 let inner = self
86 .inner
87 .read()
88 .map_err(|_| anyhow!("provider lock poisoned"))?;
89 inner.client.clone()
90 };
91 let provider = build_provider(config, client).await?;
92 let mut inner = self
93 .inner
94 .write()
95 .map_err(|_| anyhow!("provider lock poisoned"))?;
96 inner.providers.insert(config.model.clone(), provider);
97 Ok(())
98 }
99
100 pub fn active(&self) -> Result<Provider> {
102 let inner = self
103 .inner
104 .read()
105 .map_err(|_| anyhow!("provider lock poisoned"))?;
106 Ok(inner.providers[&inner.active].clone())
107 }
108
109 pub fn active_model_name(&self) -> Result<CompactString> {
111 let inner = self
112 .inner
113 .read()
114 .map_err(|_| anyhow!("provider lock poisoned"))?;
115 Ok(inner.active.clone())
116 }
117
118 pub fn switch(&self, model: &str) -> Result<()> {
121 let mut inner = self
122 .inner
123 .write()
124 .map_err(|_| anyhow!("provider lock poisoned"))?;
125 if !inner.providers.contains_key(model) {
126 bail!("provider '{}' not found", model);
127 }
128 inner.active = CompactString::from(model);
129 Ok(())
130 }
131
132 pub fn remove(&self, model: &str) -> Result<()> {
135 let mut inner = self
136 .inner
137 .write()
138 .map_err(|_| anyhow!("provider lock poisoned"))?;
139 if inner.active == model {
140 bail!("cannot remove the active provider '{}'", model);
141 }
142 if inner.providers.remove(model).is_none() {
143 bail!("provider '{}' not found", model);
144 }
145 Ok(())
146 }
147
148 pub fn list(&self) -> Result<Vec<ProviderEntry>> {
150 let inner = self
151 .inner
152 .read()
153 .map_err(|_| anyhow!("provider lock poisoned"))?;
154 Ok(inner
155 .providers
156 .keys()
157 .map(|name| ProviderEntry {
158 name: name.clone(),
159 active: *name == inner.active,
160 })
161 .collect())
162 }
163
164 fn provider_for(&self, model: &str) -> Result<Provider> {
167 let inner = self
168 .inner
169 .read()
170 .map_err(|_| anyhow!("provider lock poisoned"))?;
171 inner
172 .providers
173 .get(model)
174 .cloned()
175 .ok_or_else(|| anyhow!("model '{}' not found in registry", model))
176 }
177
178 pub async fn wait_until_ready(&self) -> Result<()> {
183 let mut provider = self.active()?;
184 provider.wait_until_ready().await
185 }
186
187 pub fn context_limit(&self, model: &str) -> usize {
192 let Ok(inner) = self.inner.read() else {
193 return default_context_limit(model);
194 };
195 if let Some(provider) = inner.providers.get(model)
196 && let Some(limit) = provider.context_length(model)
197 {
198 return limit;
199 }
200 default_context_limit(model)
201 }
202}
203
204impl Model for ProviderManager {
205 async fn send(&self, request: &wcore::model::Request) -> Result<Response> {
206 let provider = self.provider_for(&request.model)?;
207 provider.send(request).await
208 }
209
210 fn stream(
211 &self,
212 request: wcore::model::Request,
213 ) -> impl Stream<Item = Result<StreamChunk>> + Send {
214 let result = self.provider_for(&request.model);
215 try_stream! {
216 let provider = result?;
217 let mut stream = std::pin::pin!(provider.stream(request));
218 while let Some(chunk) = stream.next().await {
219 yield chunk?;
220 }
221 }
222 }
223
224 fn context_limit(&self, model: &str) -> usize {
225 ProviderManager::context_limit(self, model)
226 }
227
228 fn active_model(&self) -> CompactString {
229 self.active_model_name()
230 .unwrap_or_else(|_| CompactString::const_new("unknown"))
231 }
232}
233
234impl std::fmt::Debug for ProviderManager {
235 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236 match self.inner.read() {
237 Ok(inner) => f
238 .debug_struct("ProviderManager")
239 .field("active", &inner.active)
240 .field("count", &inner.providers.len())
241 .finish(),
242 Err(_) => f
243 .debug_struct("ProviderManager")
244 .field("error", &"lock poisoned")
245 .finish(),
246 }
247 }
248}
249
250impl Clone for ProviderManager {
251 fn clone(&self) -> Self {
252 Self {
253 inner: Arc::clone(&self.inner),
254 }
255 }
256}