Skip to main content

walrus_model/
manager.rs

1//! `ProviderManager` — concurrent-safe named provider registry with model
2//! routing and active-provider swapping.
3
4use crate::{Provider, ProviderConfig, build_provider};
5use anyhow::{Result, bail};
6use async_stream::try_stream;
7use compact_str::CompactString;
8use futures_core::Stream;
9use futures_util::StreamExt;
10use std::collections::BTreeMap;
11use std::sync::{Arc, RwLock};
12use wcore::model::{Model, Response, StreamChunk, default_context_limit};
13
14/// Manages a set of named providers with an active selection.
15///
16/// All methods that read or mutate the inner state acquire the `RwLock`.
17/// `active()` returns a clone of the current `Provider` — callers do not
18/// hold the lock while performing LLM calls.
19pub struct ProviderManager {
20    inner: Arc<RwLock<Inner>>,
21}
22
23struct Inner {
24    /// Provider instances keyed by model name.
25    providers: BTreeMap<CompactString, (ProviderConfig, Provider)>,
26    /// Model name of the currently active provider.
27    active: CompactString,
28    /// Shared HTTP client for constructing new providers.
29    client: reqwest::Client,
30}
31
32/// Info about a single provider entry returned by `list()`.
33#[derive(Debug, Clone)]
34pub struct ProviderEntry {
35    /// Provider model name (key).
36    pub name: CompactString,
37    /// Whether this is the active provider.
38    pub active: bool,
39}
40
41impl ProviderManager {
42    /// Create a new manager from a list of provider configs.
43    ///
44    /// The first element becomes the active provider.
45    /// Returns an error if the slice is empty, any config fails validation, or
46    /// any provider fails to build.
47    pub async fn from_configs(configs: &[ProviderConfig]) -> Result<Self> {
48        if configs.is_empty() {
49            bail!("at least one provider config is required");
50        }
51
52        let client = reqwest::Client::new();
53        let mut providers = BTreeMap::new();
54
55        for config in configs {
56            config.validate()?;
57            let provider = build_provider(config, client.clone()).await?;
58            providers.insert(config.model.clone(), (config.clone(), provider));
59        }
60
61        let active = configs[0].model.clone();
62
63        Ok(Self {
64            inner: Arc::new(RwLock::new(Inner {
65                providers,
66                active,
67                client,
68            })),
69        })
70    }
71
72    /// Create a manager with a single provider.
73    pub fn single(config: ProviderConfig, provider: Provider) -> Self {
74        let model = config.model.clone();
75        let mut providers = BTreeMap::new();
76        providers.insert(model.clone(), (config, provider));
77        Self {
78            inner: Arc::new(RwLock::new(Inner {
79                providers,
80                active: model,
81                client: reqwest::Client::new(),
82            })),
83        }
84    }
85
86    /// Get a clone of the active provider.
87    pub fn active(&self) -> Provider {
88        let inner = self.inner.read().expect("provider lock poisoned");
89        inner.providers[&inner.active].1.clone()
90    }
91
92    /// Get the model name of the active provider (also its key).
93    pub fn active_model(&self) -> CompactString {
94        let inner = self.inner.read().expect("provider lock poisoned");
95        inner.active.clone()
96    }
97
98    /// Get a clone of the active provider's config.
99    pub fn active_config(&self) -> ProviderConfig {
100        let inner = self.inner.read().expect("provider lock poisoned");
101        inner.providers[&inner.active].0.clone()
102    }
103
104    /// Switch to a different provider by model name. Returns an error if the
105    /// name is not found.
106    pub fn switch(&self, model: &str) -> Result<()> {
107        let mut inner = self.inner.write().expect("provider lock poisoned");
108        if !inner.providers.contains_key(model) {
109            bail!("provider '{}' not found", model);
110        }
111        inner.active = CompactString::from(model);
112        Ok(())
113    }
114
115    /// Add a new provider. Validates config first. Replaces any existing
116    /// provider with the same model name.
117    pub async fn add(&self, config: &ProviderConfig) -> Result<()> {
118        config.validate()?;
119        let client = {
120            let inner = self.inner.read().expect("provider lock poisoned");
121            inner.client.clone()
122        };
123        let provider = build_provider(config, client).await?;
124        let mut inner = self.inner.write().expect("provider lock poisoned");
125        inner
126            .providers
127            .insert(config.model.clone(), (config.clone(), provider));
128        Ok(())
129    }
130
131    /// Remove a provider by model name. Fails if the provider is currently
132    /// active.
133    pub fn remove(&self, model: &str) -> Result<()> {
134        let mut inner = self.inner.write().expect("provider lock poisoned");
135        if inner.active == model {
136            bail!("cannot remove the active provider '{}'", model);
137        }
138        if inner.providers.remove(model).is_none() {
139            bail!("provider '{}' not found", model);
140        }
141        Ok(())
142    }
143
144    /// List all providers with their active status.
145    pub fn list(&self) -> Vec<ProviderEntry> {
146        let inner = self.inner.read().expect("provider lock poisoned");
147        inner
148            .providers
149            .keys()
150            .map(|name| ProviderEntry {
151                name: name.clone(),
152                active: *name == inner.active,
153            })
154            .collect()
155    }
156
157    /// Look up a provider by model name. Returns a clone so callers don't
158    /// hold the lock during LLM calls.
159    fn provider_for(&self, model: &str) -> Result<Provider> {
160        let inner = self.inner.read().expect("provider lock poisoned");
161        inner
162            .providers
163            .get(model)
164            .map(|(_, p)| p.clone())
165            .ok_or_else(|| anyhow::anyhow!("model '{}' not found in registry", model))
166    }
167
168    /// Resolve the context limit for a model.
169    ///
170    /// Resolution chain: provider reports limit → static map → 8192 default.
171    pub fn context_limit(&self, model: &str) -> usize {
172        let inner = self.inner.read().expect("provider lock poisoned");
173        if let Some((_, provider)) = inner.providers.get(model)
174            && let Some(limit) = provider.context_length(model)
175        {
176            return limit;
177        }
178        default_context_limit(model)
179    }
180}
181
182impl Model for ProviderManager {
183    async fn send(&self, request: &wcore::model::Request) -> Result<Response> {
184        let provider = self.provider_for(&request.model)?;
185        provider.send(request).await
186    }
187
188    fn stream(
189        &self,
190        request: wcore::model::Request,
191    ) -> impl Stream<Item = Result<StreamChunk>> + Send {
192        let result = self.provider_for(&request.model);
193        try_stream! {
194            let provider = result?;
195            let mut stream = std::pin::pin!(provider.stream(request));
196            while let Some(chunk) = stream.next().await {
197                yield chunk?;
198            }
199        }
200    }
201
202    fn context_limit(&self, model: &str) -> usize {
203        ProviderManager::context_limit(self, model)
204    }
205
206    fn active_model(&self) -> CompactString {
207        ProviderManager::active_model(self)
208    }
209}
210
211impl std::fmt::Debug for ProviderManager {
212    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
213        let inner = self.inner.read().expect("provider lock poisoned");
214        f.debug_struct("ProviderManager")
215            .field("active", &inner.active)
216            .field("count", &inner.providers.len())
217            .finish()
218    }
219}
220
221impl Clone for ProviderManager {
222    fn clone(&self) -> Self {
223        Self {
224            inner: Arc::clone(&self.inner),
225        }
226    }
227}