1use crate::{Provider, ProviderDef, build_provider};
5use anyhow::{Result, anyhow, bail};
6use async_stream::try_stream;
7use futures_core::Stream;
8use futures_util::StreamExt;
9use std::collections::BTreeMap;
10use std::sync::{Arc, RwLock};
11use wcore::model::{Model, Response, StreamChunk, default_context_limit};
12
13pub struct ProviderRegistry {
19 inner: Arc<RwLock<Inner>>,
20}
21
22struct Inner {
23 providers: BTreeMap<String, Provider>,
25 active: String,
27 client: reqwest::Client,
29}
30
/// Snapshot of one registered provider, as produced by `ProviderRegistry::list`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProviderEntry {
    /// Name the provider is registered under.
    pub name: String,
    /// Whether this entry was the active model when the snapshot was taken.
    pub active: bool,
}
39
40impl ProviderRegistry {
41 pub fn new(active: impl Into<String>) -> Self {
45 Self {
46 inner: Arc::new(RwLock::new(Inner {
47 providers: BTreeMap::new(),
48 active: active.into(),
49 client: reqwest::Client::new(),
50 })),
51 }
52 }
53
54 pub fn from_providers(
59 active: String,
60 providers: &BTreeMap<String, ProviderDef>,
61 ) -> Result<Self> {
62 let registry = Self::new(active);
63 for def in providers.values() {
64 registry.add_def(def)?;
65 }
66 Ok(registry)
67 }
68
69 pub fn add_provider(&self, name: impl Into<String>, provider: Provider) -> Result<()> {
71 let mut inner = self
72 .inner
73 .write()
74 .map_err(|_| anyhow!("provider lock poisoned"))?;
75 inner.providers.insert(name.into(), provider);
76 Ok(())
77 }
78
79 pub fn add_def(&self, def: &ProviderDef) -> Result<()> {
81 let client = {
82 let inner = self
83 .inner
84 .read()
85 .map_err(|_| anyhow!("provider lock poisoned"))?;
86 inner.client.clone()
87 };
88 for model_name in &def.models {
89 let provider = build_provider(def, model_name, client.clone())?;
90 let mut inner = self
91 .inner
92 .write()
93 .map_err(|_| anyhow!("provider lock poisoned"))?;
94 inner.providers.insert(model_name.to_string(), provider);
95 }
96 Ok(())
97 }
98
99 pub fn active(&self) -> Result<Provider> {
101 let inner = self
102 .inner
103 .read()
104 .map_err(|_| anyhow!("provider lock poisoned"))?;
105 Ok(inner.providers[&inner.active].clone())
106 }
107
108 pub fn active_model_name(&self) -> Result<String> {
110 let inner = self
111 .inner
112 .read()
113 .map_err(|_| anyhow!("provider lock poisoned"))?;
114 Ok(inner.active.clone())
115 }
116
117 pub fn switch(&self, model: &str) -> Result<()> {
120 let mut inner = self
121 .inner
122 .write()
123 .map_err(|_| anyhow!("provider lock poisoned"))?;
124 if !inner.providers.contains_key(model) {
125 bail!("provider '{}' not found", model);
126 }
127 inner.active = model.to_owned();
128 Ok(())
129 }
130
131 pub fn remove(&self, model: &str) -> Result<()> {
134 let mut inner = self
135 .inner
136 .write()
137 .map_err(|_| anyhow!("provider lock poisoned"))?;
138 if inner.active == model {
139 bail!("cannot remove the active provider '{}'", model);
140 }
141 if inner.providers.remove(model).is_none() {
142 bail!("provider '{}' not found", model);
143 }
144 Ok(())
145 }
146
147 pub fn list(&self) -> Result<Vec<ProviderEntry>> {
149 let inner = self
150 .inner
151 .read()
152 .map_err(|_| anyhow!("provider lock poisoned"))?;
153 Ok(inner
154 .providers
155 .keys()
156 .map(|name| ProviderEntry {
157 name: name.clone(),
158 active: *name == inner.active,
159 })
160 .collect())
161 }
162
163 fn provider_for(&self, model: &str) -> Result<Provider> {
166 let inner = self
167 .inner
168 .read()
169 .map_err(|_| anyhow!("provider lock poisoned"))?;
170 inner
171 .providers
172 .get(model)
173 .cloned()
174 .ok_or_else(|| anyhow!("model '{}' not found in registry", model))
175 }
176
177 pub fn context_limit(&self, model: &str) -> usize {
181 default_context_limit(model)
182 }
183}
184
185impl Model for ProviderRegistry {
186 async fn send(&self, request: &wcore::model::Request) -> Result<Response> {
187 let provider = self.provider_for(&request.model)?;
188 provider.send(request).await
189 }
190
191 fn stream(
192 &self,
193 request: wcore::model::Request,
194 ) -> impl Stream<Item = Result<StreamChunk>> + Send {
195 let result = self.provider_for(&request.model);
196 try_stream! {
197 let provider = result?;
198 let mut stream = std::pin::pin!(provider.stream(request));
199 while let Some(chunk) = stream.next().await {
200 yield chunk?;
201 }
202 }
203 }
204
205 fn context_limit(&self, model: &str) -> usize {
206 ProviderRegistry::context_limit(self, model)
207 }
208
209 fn active_model(&self) -> String {
210 self.active_model_name()
211 .unwrap_or_else(|_| "unknown".to_owned())
212 }
213}
214
215impl std::fmt::Debug for ProviderRegistry {
216 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
217 match self.inner.read() {
218 Ok(inner) => f
219 .debug_struct("ProviderRegistry")
220 .field("active", &inner.active)
221 .field("count", &inner.providers.len())
222 .finish(),
223 Err(_) => f
224 .debug_struct("ProviderRegistry")
225 .field("error", &"lock poisoned")
226 .finish(),
227 }
228 }
229}
230
231impl Clone for ProviderRegistry {
232 fn clone(&self) -> Self {
233 Self {
234 inner: Arc::clone(&self.inner),
235 }
236 }
237}