Skip to main content

crabtalk_model/
manager.rs

1//! `ProviderRegistry` — concurrent-safe named provider registry with model
2//! routing and active-provider swapping.
3
4use crate::{Provider, ProviderDef, build_provider};
5use anyhow::{Result, anyhow, bail};
6use async_stream::try_stream;
7use futures_core::Stream;
8use futures_util::StreamExt;
9use std::collections::BTreeMap;
10use std::sync::{Arc, RwLock};
11use wcore::model::{Model, Response, StreamChunk, default_context_limit};
12
13/// Manages a set of named providers with an active selection.
14///
15/// All methods that read or mutate the inner state acquire the `RwLock`.
16/// `active()` returns a clone of the current `Provider` — callers do not
17/// hold the lock while performing LLM calls.
18pub struct ProviderRegistry {
19    inner: Arc<RwLock<Inner>>,
20}
21
22struct Inner {
23    /// Provider instances keyed by model name.
24    providers: BTreeMap<String, Provider>,
25    /// Model name of the currently active provider.
26    active: String,
27    /// Shared HTTP client for constructing new providers.
28    client: reqwest::Client,
29}
30
/// One row of the listing produced by `list()`.
#[derive(Clone, Debug)]
pub struct ProviderEntry {
    /// Model name the provider is registered under (its registry key).
    pub name: String,
    /// `true` when this entry is the currently active provider.
    pub active: bool,
}
39
40impl ProviderRegistry {
41    /// Create an empty manager with the given active model name.
42    ///
43    /// Use `add_provider()` or `add_model()` to populate.
44    pub fn new(active: impl Into<String>) -> Self {
45        Self {
46            inner: Arc::new(RwLock::new(Inner {
47                providers: BTreeMap::new(),
48                active: active.into(),
49                client: reqwest::Client::new(),
50            })),
51        }
52    }
53
54    /// Build a registry from a map of provider definitions and an active model.
55    ///
56    /// Iterates each provider def, building a `Provider` instance per model
57    /// in its `models` list.
58    pub fn from_providers(
59        active: String,
60        providers: &BTreeMap<String, ProviderDef>,
61    ) -> Result<Self> {
62        let registry = Self::new(active);
63        for def in providers.values() {
64            registry.add_def(def)?;
65        }
66        Ok(registry)
67    }
68
69    /// Add a pre-built provider directly (e.g. local models from registry).
70    pub fn add_provider(&self, name: impl Into<String>, provider: Provider) -> Result<()> {
71        let mut inner = self
72            .inner
73            .write()
74            .map_err(|_| anyhow!("provider lock poisoned"))?;
75        inner.providers.insert(name.into(), provider);
76        Ok(())
77    }
78
79    /// Add all models from a provider definition. Builds a `Provider` per model.
80    pub fn add_def(&self, def: &ProviderDef) -> Result<()> {
81        let client = {
82            let inner = self
83                .inner
84                .read()
85                .map_err(|_| anyhow!("provider lock poisoned"))?;
86            inner.client.clone()
87        };
88        for model_name in &def.models {
89            let provider = build_provider(def, model_name, client.clone())?;
90            let mut inner = self
91                .inner
92                .write()
93                .map_err(|_| anyhow!("provider lock poisoned"))?;
94            inner.providers.insert(model_name.to_string(), provider);
95        }
96        Ok(())
97    }
98
99    /// Get a clone of the active provider.
100    pub fn active(&self) -> Result<Provider> {
101        let inner = self
102            .inner
103            .read()
104            .map_err(|_| anyhow!("provider lock poisoned"))?;
105        Ok(inner.providers[&inner.active].clone())
106    }
107
108    /// Get the model name of the active provider (also its key).
109    pub fn active_model_name(&self) -> Result<String> {
110        let inner = self
111            .inner
112            .read()
113            .map_err(|_| anyhow!("provider lock poisoned"))?;
114        Ok(inner.active.clone())
115    }
116
117    /// Switch to a different provider by model name. Returns an error if the
118    /// name is not found.
119    pub fn switch(&self, model: &str) -> Result<()> {
120        let mut inner = self
121            .inner
122            .write()
123            .map_err(|_| anyhow!("provider lock poisoned"))?;
124        if !inner.providers.contains_key(model) {
125            bail!("provider '{}' not found", model);
126        }
127        inner.active = model.to_owned();
128        Ok(())
129    }
130
131    /// Remove a provider by model name. Fails if the provider is currently
132    /// active.
133    pub fn remove(&self, model: &str) -> Result<()> {
134        let mut inner = self
135            .inner
136            .write()
137            .map_err(|_| anyhow!("provider lock poisoned"))?;
138        if inner.active == model {
139            bail!("cannot remove the active provider '{}'", model);
140        }
141        if inner.providers.remove(model).is_none() {
142            bail!("provider '{}' not found", model);
143        }
144        Ok(())
145    }
146
147    /// List all providers with their active status.
148    pub fn list(&self) -> Result<Vec<ProviderEntry>> {
149        let inner = self
150            .inner
151            .read()
152            .map_err(|_| anyhow!("provider lock poisoned"))?;
153        Ok(inner
154            .providers
155            .keys()
156            .map(|name| ProviderEntry {
157                name: name.clone(),
158                active: *name == inner.active,
159            })
160            .collect())
161    }
162
163    /// Look up a provider by model name. Returns a clone so callers don't
164    /// hold the lock during LLM calls.
165    fn provider_for(&self, model: &str) -> Result<Provider> {
166        let inner = self
167            .inner
168            .read()
169            .map_err(|_| anyhow!("provider lock poisoned"))?;
170        inner
171            .providers
172            .get(model)
173            .cloned()
174            .ok_or_else(|| anyhow!("model '{}' not found in registry", model))
175    }
176
177    /// Resolve the context limit for a model.
178    ///
179    /// Uses the static map in `wcore::model::default_context_limit`.
180    pub fn context_limit(&self, model: &str) -> usize {
181        default_context_limit(model)
182    }
183}
184
185impl Model for ProviderRegistry {
186    async fn send(&self, request: &wcore::model::Request) -> Result<Response> {
187        let provider = self.provider_for(&request.model)?;
188        provider.send(request).await
189    }
190
191    fn stream(
192        &self,
193        request: wcore::model::Request,
194    ) -> impl Stream<Item = Result<StreamChunk>> + Send {
195        let result = self.provider_for(&request.model);
196        try_stream! {
197            let provider = result?;
198            let mut stream = std::pin::pin!(provider.stream(request));
199            while let Some(chunk) = stream.next().await {
200                yield chunk?;
201            }
202        }
203    }
204
205    fn context_limit(&self, model: &str) -> usize {
206        ProviderRegistry::context_limit(self, model)
207    }
208
209    fn active_model(&self) -> String {
210        self.active_model_name()
211            .unwrap_or_else(|_| "unknown".to_owned())
212    }
213}
214
215impl std::fmt::Debug for ProviderRegistry {
216    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
217        match self.inner.read() {
218            Ok(inner) => f
219                .debug_struct("ProviderRegistry")
220                .field("active", &inner.active)
221                .field("count", &inner.providers.len())
222                .finish(),
223            Err(_) => f
224                .debug_struct("ProviderRegistry")
225                .field("error", &"lock poisoned")
226                .finish(),
227        }
228    }
229}
230
231impl Clone for ProviderRegistry {
232    fn clone(&self) -> Self {
233        Self {
234            inner: Arc::clone(&self.inner),
235        }
236    }
237}