Skip to main content

liter_llm/provider/
custom.rs

1//! Runtime registration of custom LLM providers.
2//!
3//! Allows users to register providers that are not part of the built-in
4//! `providers.json` registry.  Custom providers are checked **first** during
5//! model detection, so they can override built-in routing.
6
7use std::borrow::Cow;
8use std::sync::RwLock;
9
10use serde::{Deserialize, Serialize};
11
12use super::Provider;
13use crate::error::{LiterLlmError, Result};
14
15// ── Global custom-provider registry ──────────────────────────────────────────
16
17/// Thread-safe registry of runtime-registered custom providers.
18///
19/// Uses `RwLock` so that reads (the hot path inside `detect_provider`) only
20/// take a shared lock, while mutations (`register` / `unregister`) take an
21/// exclusive lock.
22static CUSTOM_PROVIDERS: RwLock<Vec<CustomProviderConfig>> = RwLock::new(Vec::new());
23
24/// Configuration for registering a custom LLM provider at runtime.
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct CustomProviderConfig {
27    /// Unique name for this provider (e.g., "my-provider").
28    pub name: String,
29    /// Base URL for the provider's API (e.g., "https://api.my-provider.com/v1").
30    pub base_url: String,
31    /// Authentication header format.
32    pub auth_header: AuthHeaderFormat,
33    /// Model name prefixes that route to this provider (e.g., ["my-"]).
34    pub model_prefixes: Vec<String>,
35}
36
37/// How the API key is sent in the HTTP request.
38#[derive(Debug, Clone, Default, Serialize, Deserialize)]
39pub enum AuthHeaderFormat {
40    /// Bearer token: `Authorization: Bearer <key>`
41    #[default]
42    Bearer,
43    /// Custom header: e.g., `X-Api-Key: <key>`
44    ApiKey(String),
45    /// No authentication required.
46    None,
47}
48
49// ── Public API ───────────────────────────────────────────────────────────────
50
51/// Register a custom provider in the global runtime registry.
52///
53/// The provider will be checked **before** all built-in providers during model
54/// detection. If a provider with the same `name` already exists it is replaced.
55///
56/// # Errors
57///
58/// Returns an error if the config is invalid (empty name, empty base_url, or
59/// no model prefixes).
60pub fn register_custom_provider(config: CustomProviderConfig) -> Result<()> {
61    validate_config(&config)?;
62
63    let mut providers = CUSTOM_PROVIDERS.write().map_err(|e| LiterLlmError::ServerError {
64        message: format!("custom provider registry lock poisoned: {e}"),
65    })?;
66
67    // Replace existing entry with the same name, or append.
68    if let Some(existing) = providers.iter_mut().find(|p| p.name == config.name) {
69        *existing = config;
70    } else {
71        providers.push(config);
72    }
73
74    Ok(())
75}
76
77/// Remove a previously registered custom provider by name.
78///
79/// Returns `true` if a provider with the given name was found and removed,
80/// `false` if no such provider existed.
81///
82/// # Errors
83///
84/// Returns an error only if the internal lock is poisoned.
85pub fn unregister_custom_provider(name: &str) -> Result<bool> {
86    let mut providers = CUSTOM_PROVIDERS.write().map_err(|e| LiterLlmError::ServerError {
87        message: format!("custom provider registry lock poisoned: {e}"),
88    })?;
89
90    let before = providers.len();
91    providers.retain(|p| p.name != name);
92    Ok(providers.len() < before)
93}
94
95/// Try to match a model name against the custom-provider registry.
96///
97/// Returns a boxed [`Provider`] if a custom provider claims the model,
98/// `None` otherwise.  This is called at the **top** of `detect_provider`
99/// so custom providers always take priority over built-in ones.
100pub(crate) fn detect_custom_provider(model: &str) -> Option<Box<dyn Provider>> {
101    let providers = CUSTOM_PROVIDERS.read().ok()?;
102
103    for cfg in providers.iter() {
104        let matches = cfg
105            .model_prefixes
106            .iter()
107            .any(|prefix| model.starts_with(prefix.as_str()));
108
109        if matches {
110            return Some(Box::new(CustomProvider { config: cfg.clone() }));
111        }
112    }
113
114    None
115}
116
/// Remove every custom provider from the registry. Intended for test isolation only.
///
/// A poisoned lock is recovered instead of being ignored. Otherwise a test
/// that panicked while holding the write lock would leave its registrations
/// behind for the next test.
#[cfg(test)]
pub(crate) fn clear_custom_providers() {
    CUSTOM_PROVIDERS
        .write()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .clear();
}
124
125// ── Validation ───────────────────────────────────────────────────────────────
126
127fn validate_config(config: &CustomProviderConfig) -> Result<()> {
128    if config.name.trim().is_empty() {
129        return Err(LiterLlmError::BadRequest {
130            message: "custom provider name must not be empty or whitespace-only".into(),
131        });
132    }
133    if config.base_url.trim().is_empty() {
134        return Err(LiterLlmError::BadRequest {
135            message: "custom provider base_url must not be empty or whitespace-only".into(),
136        });
137    }
138    if config.model_prefixes.is_empty() {
139        return Err(LiterLlmError::BadRequest {
140            message: "custom provider must have at least one model prefix".into(),
141        });
142    }
143    for prefix in &config.model_prefixes {
144        if prefix.is_empty() {
145            return Err(LiterLlmError::BadRequest {
146                message: "custom provider model prefix must not be empty (would match all models)".into(),
147            });
148        }
149    }
150    Ok(())
151}
152
153// ── Provider implementation ──────────────────────────────────────────────────
154
155/// A runtime-registered custom provider.
156///
157/// Wraps a [`CustomProviderConfig`] and implements the [`Provider`] trait so
158/// the client can use it exactly like a built-in provider.
159struct CustomProvider {
160    config: CustomProviderConfig,
161}
162
163impl Provider for CustomProvider {
164    fn name(&self) -> &str {
165        &self.config.name
166    }
167
168    fn base_url(&self) -> &str {
169        &self.config.base_url
170    }
171
172    fn auth_header<'a>(&'a self, api_key: &'a str) -> Option<(Cow<'static, str>, Cow<'a, str>)> {
173        match &self.config.auth_header {
174            AuthHeaderFormat::Bearer => Some((Cow::Borrowed("Authorization"), Cow::Owned(format!("Bearer {api_key}")))),
175            AuthHeaderFormat::ApiKey(header_name) => Some((Cow::Owned(header_name.clone()), Cow::Borrowed(api_key))),
176            AuthHeaderFormat::None => None,
177        }
178    }
179
180    fn matches_model(&self, model: &str) -> bool {
181        self.config
182            .model_prefixes
183            .iter()
184            .any(|prefix| model.starts_with(prefix.as_str()))
185    }
186}
187
188// ── Tests ────────────────────────────────────────────────────────────────────
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Mutex to serialize tests that share the global custom-provider registry.
    static TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquire the test lock and clear the registry.
    ///
    /// A poisoned `TEST_LOCK` (left behind by a test that panicked while
    /// holding it) is recovered, so one failing test does not make every
    /// later test fail too.
    fn setup() -> std::sync::MutexGuard<'static, ()> {
        let guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        clear_custom_providers();
        guard
    }

    /// Build a bearer-auth config with a single model prefix.
    fn bearer_config(name: &str, prefix: &str) -> CustomProviderConfig {
        CustomProviderConfig {
            name: name.into(),
            base_url: "https://example.com/v1".into(),
            auth_header: AuthHeaderFormat::Bearer,
            model_prefixes: vec![prefix.into()],
        }
    }

    #[test]
    fn register_and_detect_by_model_prefix() {
        let _guard = setup();

        let config = CustomProviderConfig {
            name: "my-provider".into(),
            base_url: "https://api.my-provider.com/v1".into(),
            auth_header: AuthHeaderFormat::Bearer,
            model_prefixes: vec!["my-".into(), "my-provider/".into()],
        };

        register_custom_provider(config).expect("registration should succeed");

        let provider = detect_custom_provider("my-model-7b");
        assert!(provider.is_some(), "should detect custom provider by prefix 'my-'");
        let provider = provider.unwrap();
        assert_eq!(provider.name(), "my-provider");
        assert_eq!(provider.base_url(), "https://api.my-provider.com/v1");

        // Also detect via slash-prefix routing.
        let provider2 = detect_custom_provider("my-provider/llama-70b");
        assert!(provider2.is_some(), "should detect custom provider by slash prefix");

        // Non-matching model should not detect.
        let none = detect_custom_provider("gpt-4");
        assert!(none.is_none(), "should not match unrelated model");
    }

    #[test]
    fn unregister_removes_provider() {
        let _guard = setup();

        let config = CustomProviderConfig {
            name: "ephemeral".into(),
            base_url: "https://api.ephemeral.com/v1".into(),
            auth_header: AuthHeaderFormat::Bearer,
            model_prefixes: vec!["eph-".into()],
        };

        register_custom_provider(config).expect("registration should succeed");
        assert!(detect_custom_provider("eph-model").is_some());

        let removed = unregister_custom_provider("ephemeral").expect("unregister should succeed");
        assert!(removed, "should return true when provider was found");

        assert!(
            detect_custom_provider("eph-model").is_none(),
            "should no longer detect after unregister"
        );

        // Unregistering again returns false.
        let removed_again = unregister_custom_provider("ephemeral").expect("unregister should succeed");
        assert!(!removed_again, "should return false when provider not found");
    }

    #[test]
    fn custom_provider_with_api_key_auth() {
        let _guard = setup();

        let config = CustomProviderConfig {
            name: "secure-provider".into(),
            base_url: "https://api.secure.com/v1".into(),
            auth_header: AuthHeaderFormat::ApiKey("X-Custom-Auth".into()),
            model_prefixes: vec!["secure/".into()],
        };

        register_custom_provider(config).expect("registration should succeed");

        let provider = detect_custom_provider("secure/model-1").expect("should detect provider");
        let (header_name, header_value) = provider
            .auth_header("my-secret-key")
            .expect("should return auth header");
        assert_eq!(header_name.as_ref(), "X-Custom-Auth");
        assert_eq!(header_value.as_ref(), "my-secret-key");
    }

    #[test]
    fn custom_provider_with_no_auth() {
        let _guard = setup();

        let config = CustomProviderConfig {
            name: "local-provider".into(),
            base_url: "http://localhost:8080/v1".into(),
            auth_header: AuthHeaderFormat::None,
            model_prefixes: vec!["local/".into()],
        };

        register_custom_provider(config).expect("registration should succeed");

        let provider = detect_custom_provider("local/model").expect("should detect provider");
        assert!(
            provider.auth_header("unused").is_none(),
            "no-auth provider should return None"
        );
    }

    #[test]
    fn custom_provider_bearer_auth() {
        let _guard = setup();

        let config = CustomProviderConfig {
            name: "bearer-provider".into(),
            base_url: "https://api.bearer.com/v1".into(),
            auth_header: AuthHeaderFormat::Bearer,
            model_prefixes: vec!["bearer/".into()],
        };

        register_custom_provider(config).expect("registration should succeed");

        let provider = detect_custom_provider("bearer/model").expect("should detect provider");
        let (header_name, header_value) = provider.auth_header("my-token").expect("should return auth header");
        assert_eq!(header_name.as_ref(), "Authorization");
        assert_eq!(header_value.as_ref(), "Bearer my-token");
    }

    #[test]
    fn register_replaces_existing_provider() {
        let _guard = setup();

        let config1 = CustomProviderConfig {
            name: "updatable".into(),
            base_url: "https://old.example.com/v1".into(),
            auth_header: AuthHeaderFormat::Bearer,
            model_prefixes: vec!["upd/".into()],
        };
        register_custom_provider(config1).expect("first registration should succeed");

        let config2 = CustomProviderConfig {
            name: "updatable".into(),
            base_url: "https://new.example.com/v1".into(),
            auth_header: AuthHeaderFormat::Bearer,
            model_prefixes: vec!["upd/".into()],
        };
        register_custom_provider(config2).expect("second registration should succeed");

        let provider = detect_custom_provider("upd/model").expect("should detect provider");
        assert_eq!(
            provider.base_url(),
            "https://new.example.com/v1",
            "should use the updated config"
        );
    }

    #[test]
    fn validation_rejects_empty_name() {
        let _guard = setup();

        let config = CustomProviderConfig {
            name: String::new(),
            base_url: "https://example.com".into(),
            auth_header: AuthHeaderFormat::Bearer,
            model_prefixes: vec!["x/".into()],
        };
        let result = register_custom_provider(config);
        assert!(result.is_err(), "should reject empty name");
    }

    #[test]
    fn validation_rejects_empty_base_url() {
        let _guard = setup();

        let config = CustomProviderConfig {
            name: "valid-name".into(),
            base_url: String::new(),
            auth_header: AuthHeaderFormat::Bearer,
            model_prefixes: vec!["x/".into()],
        };
        let result = register_custom_provider(config);
        assert!(result.is_err(), "should reject empty base_url");
    }

    #[test]
    fn validation_rejects_no_prefixes() {
        let _guard = setup();

        let config = CustomProviderConfig {
            name: "valid-name".into(),
            base_url: "https://example.com".into(),
            auth_header: AuthHeaderFormat::Bearer,
            model_prefixes: vec![],
        };
        let result = register_custom_provider(config);
        assert!(result.is_err(), "should reject empty model_prefixes");
    }

    #[test]
    fn validation_rejects_whitespace_only_name_and_base_url() {
        let _guard = setup();

        let mut config = bearer_config("   ", "ws/");
        assert!(
            matches!(
                register_custom_provider(config.clone()),
                Err(LiterLlmError::BadRequest { .. })
            ),
            "should reject whitespace-only name"
        );

        config.name = "valid-name".into();
        config.base_url = " \t ".into();
        assert!(
            matches!(register_custom_provider(config), Err(LiterLlmError::BadRequest { .. })),
            "should reject whitespace-only base_url"
        );

        assert!(
            detect_custom_provider("ws/model").is_none(),
            "rejected configs must not be registered"
        );
    }

    #[test]
    fn validation_rejects_empty_prefix() {
        let _guard = setup();

        let config = CustomProviderConfig {
            name: "catch-all".into(),
            base_url: "https://example.com".into(),
            auth_header: AuthHeaderFormat::Bearer,
            model_prefixes: vec!["ok/".into(), String::new()],
        };
        let result = register_custom_provider(config);
        assert!(
            matches!(result, Err(LiterLlmError::BadRequest { .. })),
            "should reject an empty prefix"
        );
        assert!(
            detect_custom_provider("gpt-4").is_none(),
            "an empty prefix must never be installed as a catch-all"
        );
    }

    #[test]
    fn unregister_unknown_provider_returns_false() {
        let _guard = setup();

        let removed = unregister_custom_provider("never-registered").expect("unregister should succeed");
        assert!(!removed, "should return false for a name that was never registered");
    }

    #[test]
    fn detected_provider_matches_its_own_prefixes() {
        let _guard = setup();

        register_custom_provider(bearer_config("matcher", "match/")).expect("registration should succeed");

        let provider = detect_custom_provider("match/model").expect("should detect provider");
        assert!(provider.matches_model("match/other"));
        assert!(!provider.matches_model("other/model"));
    }

    #[test]
    fn config_serde_round_trip() {
        for auth_header in [
            AuthHeaderFormat::Bearer,
            AuthHeaderFormat::ApiKey("X-Api-Key".into()),
            AuthHeaderFormat::None,
        ] {
            let config = CustomProviderConfig {
                name: "serde-test".into(),
                base_url: "https://example.com/v1".into(),
                auth_header,
                model_prefixes: vec!["serde/".into()],
            };

            let json = serde_json::to_string(&config).expect("should serialize");
            let parsed: CustomProviderConfig = serde_json::from_str(&json).expect("should deserialize");

            assert_eq!(parsed.name, "serde-test");
            assert_eq!(parsed.base_url, "https://example.com/v1");
            assert_eq!(parsed.model_prefixes, vec!["serde/"]);

            // Compare variants structurally so this test does not rely on
            // `AuthHeaderFormat` implementing `PartialEq`.
            let same_auth = match (&config.auth_header, &parsed.auth_header) {
                (AuthHeaderFormat::Bearer, AuthHeaderFormat::Bearer)
                | (AuthHeaderFormat::None, AuthHeaderFormat::None) => true,
                (AuthHeaderFormat::ApiKey(a), AuthHeaderFormat::ApiKey(b)) => a == b,
                _ => false,
            };
            assert!(same_auth, "auth_header should survive the round trip: {json}");
        }
    }
}