use crate::{Provider, ProviderDef, build_provider};
use anyhow::{Result, anyhow, bail};
use async_stream::try_stream;
use futures_core::Stream;
use futures_util::StreamExt;
use std::collections::BTreeMap;
use std::sync::{Arc, RwLock};
use wcore::model::{Model, Response, StreamChunk, default_context_limit};
/// Registry of model providers keyed by model name, with one model marked as active.
///
/// All state lives behind an `Arc<RwLock<_>>`, so every clone of the registry
/// shares the same providers and the same active selection.
pub struct ProviderRegistry {
    /// Shared, lock-protected registry state.
    inner: Arc<RwLock<Inner>>,
}
/// Mutable registry state guarded by `ProviderRegistry`'s lock.
struct Inner {
    /// Providers keyed by model name; the `BTreeMap` keeps iteration sorted by name.
    providers: BTreeMap<String, Provider>,
    /// Name of the currently selected model. This may name a model that has no
    /// registered provider yet (for example, right after `ProviderRegistry::new`).
    active: String,
    /// HTTP client whose clones are handed to providers built from definitions.
    client: reqwest::Client,
}
/// Summary of one registered provider, as reported by `ProviderRegistry::list`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProviderEntry {
    /// Model name the provider is registered under.
    pub name: String,
    /// Whether this model is the registry's currently active one.
    pub active: bool,
}
impl ProviderRegistry {
pub fn new(active: impl Into<String>) -> Self {
Self {
inner: Arc::new(RwLock::new(Inner {
providers: BTreeMap::new(),
active: active.into(),
client: reqwest::Client::new(),
})),
}
}
pub fn from_providers(
active: String,
providers: &BTreeMap<String, ProviderDef>,
) -> Result<Self> {
let registry = Self::new(active);
for def in providers.values() {
registry.add_def(def)?;
}
Ok(registry)
}
pub fn add_provider(&self, name: impl Into<String>, provider: Provider) -> Result<()> {
let mut inner = self
.inner
.write()
.map_err(|_| anyhow!("provider lock poisoned"))?;
inner.providers.insert(name.into(), provider);
Ok(())
}
pub fn add_def(&self, def: &ProviderDef) -> Result<()> {
let client = {
let inner = self
.inner
.read()
.map_err(|_| anyhow!("provider lock poisoned"))?;
inner.client.clone()
};
for model_name in &def.models {
let provider = build_provider(def, model_name, client.clone())?;
let mut inner = self
.inner
.write()
.map_err(|_| anyhow!("provider lock poisoned"))?;
inner.providers.insert(model_name.to_string(), provider);
}
Ok(())
}
pub fn active(&self) -> Result<Provider> {
let inner = self
.inner
.read()
.map_err(|_| anyhow!("provider lock poisoned"))?;
Ok(inner.providers[&inner.active].clone())
}
pub fn active_model_name(&self) -> Result<String> {
let inner = self
.inner
.read()
.map_err(|_| anyhow!("provider lock poisoned"))?;
Ok(inner.active.clone())
}
pub fn switch(&self, model: &str) -> Result<()> {
let mut inner = self
.inner
.write()
.map_err(|_| anyhow!("provider lock poisoned"))?;
if !inner.providers.contains_key(model) {
bail!("provider '{}' not found", model);
}
inner.active = model.to_owned();
Ok(())
}
pub fn remove(&self, model: &str) -> Result<()> {
let mut inner = self
.inner
.write()
.map_err(|_| anyhow!("provider lock poisoned"))?;
if inner.active == model {
bail!("cannot remove the active provider '{}'", model);
}
if inner.providers.remove(model).is_none() {
bail!("provider '{}' not found", model);
}
Ok(())
}
pub fn list(&self) -> Result<Vec<ProviderEntry>> {
let inner = self
.inner
.read()
.map_err(|_| anyhow!("provider lock poisoned"))?;
Ok(inner
.providers
.keys()
.map(|name| ProviderEntry {
name: name.clone(),
active: *name == inner.active,
})
.collect())
}
fn provider_for(&self, model: &str) -> Result<Provider> {
let inner = self
.inner
.read()
.map_err(|_| anyhow!("provider lock poisoned"))?;
inner
.providers
.get(model)
.cloned()
.ok_or_else(|| anyhow!("model '{}' not found in registry", model))
}
pub fn context_limit(&self, model: &str) -> usize {
default_context_limit(model)
}
}
impl Model for ProviderRegistry {
    /// Forwards `request` to the provider registered under `request.model`.
    ///
    /// Fails if the model is unknown or the registry lock is poisoned. Otherwise
    /// it returns whatever the provider returns.
    async fn send(&self, request: &wcore::model::Request) -> Result<Response> {
        self.provider_for(&request.model)?.send(request).await
    }

    /// Streams the response from the provider registered under `request.model`.
    ///
    /// The lookup happens before the stream is built. The stream therefore owns
    /// only the lookup result and the request, and holds no registry lock while
    /// it is being polled. A failed lookup is not returned directly: the `?`
    /// inside `try_stream!` yields it as an `Err` item.
    fn stream(
        &self,
        request: wcore::model::Request,
    ) -> impl Stream<Item = Result<StreamChunk>> + Send {
        let lookup = self.provider_for(&request.model);
        try_stream! {
            let provider = lookup?;
            let mut chunks = std::pin::pin!(provider.stream(request));
            while let Some(item) = chunks.next().await {
                let chunk = item?;
                yield chunk;
            }
        }
    }

    /// Delegates to the inherent `ProviderRegistry::context_limit`.
    fn context_limit(&self, model: &str) -> usize {
        ProviderRegistry::context_limit(self, model)
    }

    /// Name of the active model, or `"unknown"` if it cannot be read.
    fn active_model(&self) -> String {
        match self.active_model_name() {
            Ok(name) => name,
            // Best-effort: a poisoned lock is reported as "unknown" instead of propagated.
            Err(_) => "unknown".to_owned(),
        }
    }
}
impl std::fmt::Debug for ProviderRegistry {
    /// Shows the active model name and the number of registered providers.
    /// If the lock is poisoned, it shows an `error` field instead.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ProviderRegistry");
        if let Ok(state) = self.inner.read() {
            out.field("active", &state.active)
                .field("count", &state.providers.len());
        } else {
            out.field("error", &"lock poisoned");
        }
        out.finish()
    }
}
impl Clone for ProviderRegistry {
    /// Returns a handle to the same underlying registry.
    ///
    /// Only the `Arc` is cloned, so providers and the active selection are shared
    /// and any change made through one handle is visible through all of them.
    fn clone(&self) -> Self {
        let shared = Arc::clone(&self.inner);
        Self { inner: shared }
    }
}