1use crate::{Provider, ProviderDef, build_provider};
5use anyhow::{Result, anyhow, bail};
6use async_stream::try_stream;
7use compact_str::CompactString;
8use futures_core::Stream;
9use futures_util::StreamExt;
10use std::collections::BTreeMap;
11use std::sync::{Arc, RwLock};
12use wcore::model::{Model, Response, StreamChunk, default_context_limit};
13
14pub struct ProviderManager {
20 inner: Arc<RwLock<Inner>>,
21}
22
23struct Inner {
24 providers: BTreeMap<CompactString, Provider>,
26 active: CompactString,
28 client: reqwest::Client,
30}
31
32#[derive(Debug, Clone)]
34pub struct ProviderEntry {
35 pub name: CompactString,
37 pub active: bool,
39}
40
41impl ProviderManager {
42 pub fn new(active: impl Into<CompactString>) -> Self {
46 Self {
47 inner: Arc::new(RwLock::new(Inner {
48 providers: BTreeMap::new(),
49 active: active.into(),
50 client: reqwest::Client::new(),
51 })),
52 }
53 }
54
55 pub async fn from_providers(
60 active: CompactString,
61 providers: &BTreeMap<CompactString, ProviderDef>,
62 ) -> Result<Self> {
63 let manager = Self::new(active);
64 for def in providers.values() {
65 manager.add_def(def).await?;
66 }
67 Ok(manager)
68 }
69
70 pub fn add_provider(&self, name: impl Into<CompactString>, provider: Provider) -> Result<()> {
72 let mut inner = self
73 .inner
74 .write()
75 .map_err(|_| anyhow!("provider lock poisoned"))?;
76 inner.providers.insert(name.into(), provider);
77 Ok(())
78 }
79
80 pub async fn add_def(&self, def: &ProviderDef) -> Result<()> {
82 let client = {
83 let inner = self
84 .inner
85 .read()
86 .map_err(|_| anyhow!("provider lock poisoned"))?;
87 inner.client.clone()
88 };
89 for model_name in &def.models {
90 let provider = build_provider(def, model_name, client.clone()).await?;
91 let mut inner = self
92 .inner
93 .write()
94 .map_err(|_| anyhow!("provider lock poisoned"))?;
95 inner.providers.insert(model_name.clone(), provider);
96 }
97 Ok(())
98 }
99
100 pub fn active(&self) -> Result<Provider> {
102 let inner = self
103 .inner
104 .read()
105 .map_err(|_| anyhow!("provider lock poisoned"))?;
106 Ok(inner.providers[&inner.active].clone())
107 }
108
109 pub fn active_model_name(&self) -> Result<CompactString> {
111 let inner = self
112 .inner
113 .read()
114 .map_err(|_| anyhow!("provider lock poisoned"))?;
115 Ok(inner.active.clone())
116 }
117
118 pub fn switch(&self, model: &str) -> Result<()> {
121 let mut inner = self
122 .inner
123 .write()
124 .map_err(|_| anyhow!("provider lock poisoned"))?;
125 if !inner.providers.contains_key(model) {
126 bail!("provider '{}' not found", model);
127 }
128 inner.active = CompactString::from(model);
129 Ok(())
130 }
131
132 pub fn remove(&self, model: &str) -> Result<()> {
135 let mut inner = self
136 .inner
137 .write()
138 .map_err(|_| anyhow!("provider lock poisoned"))?;
139 if inner.active == model {
140 bail!("cannot remove the active provider '{}'", model);
141 }
142 if inner.providers.remove(model).is_none() {
143 bail!("provider '{}' not found", model);
144 }
145 Ok(())
146 }
147
148 pub fn list(&self) -> Result<Vec<ProviderEntry>> {
150 let inner = self
151 .inner
152 .read()
153 .map_err(|_| anyhow!("provider lock poisoned"))?;
154 Ok(inner
155 .providers
156 .keys()
157 .map(|name| ProviderEntry {
158 name: name.clone(),
159 active: *name == inner.active,
160 })
161 .collect())
162 }
163
164 fn provider_for(&self, model: &str) -> Result<Provider> {
167 let inner = self
168 .inner
169 .read()
170 .map_err(|_| anyhow!("provider lock poisoned"))?;
171 inner
172 .providers
173 .get(model)
174 .cloned()
175 .ok_or_else(|| anyhow!("model '{}' not found in registry", model))
176 }
177
178 pub fn context_limit(&self, model: &str) -> usize {
182 default_context_limit(model)
183 }
184}
185
186impl Model for ProviderManager {
187 async fn send(&self, request: &wcore::model::Request) -> Result<Response> {
188 let provider = self.provider_for(&request.model)?;
189 provider.send(request).await
190 }
191
192 fn stream(
193 &self,
194 request: wcore::model::Request,
195 ) -> impl Stream<Item = Result<StreamChunk>> + Send {
196 let result = self.provider_for(&request.model);
197 try_stream! {
198 let provider = result?;
199 let mut stream = std::pin::pin!(provider.stream(request));
200 while let Some(chunk) = stream.next().await {
201 yield chunk?;
202 }
203 }
204 }
205
206 fn context_limit(&self, model: &str) -> usize {
207 ProviderManager::context_limit(self, model)
208 }
209
210 fn active_model(&self) -> CompactString {
211 self.active_model_name()
212 .unwrap_or_else(|_| CompactString::const_new("unknown"))
213 }
214}
215
216impl std::fmt::Debug for ProviderManager {
217 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
218 match self.inner.read() {
219 Ok(inner) => f
220 .debug_struct("ProviderManager")
221 .field("active", &inner.active)
222 .field("count", &inner.providers.len())
223 .finish(),
224 Err(_) => f
225 .debug_struct("ProviderManager")
226 .field("error", &"lock poisoned")
227 .finish(),
228 }
229 }
230}
231
232impl Clone for ProviderManager {
233 fn clone(&self) -> Self {
234 Self {
235 inner: Arc::clone(&self.inner),
236 }
237 }
238}