1use std::borrow::Cow;
8use std::sync::RwLock;
9
10use serde::{Deserialize, Serialize};
11
12use super::Provider;
13use crate::error::{LiterLlmError, Result};
14
15static CUSTOM_PROVIDERS: RwLock<Vec<CustomProviderConfig>> = RwLock::new(Vec::new());
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct CustomProviderConfig {
27 pub name: String,
29 pub base_url: String,
31 pub auth_header: AuthHeaderFormat,
33 pub model_prefixes: Vec<String>,
35}
36
37#[derive(Debug, Clone, Default, Serialize, Deserialize)]
39pub enum AuthHeaderFormat {
40 #[default]
42 Bearer,
43 ApiKey(String),
45 None,
47}
48
49pub fn register_custom_provider(config: CustomProviderConfig) -> Result<()> {
61 validate_config(&config)?;
62
63 let mut providers = CUSTOM_PROVIDERS.write().map_err(|e| LiterLlmError::ServerError {
64 message: format!("custom provider registry lock poisoned: {e}"),
65 })?;
66
67 if let Some(existing) = providers.iter_mut().find(|p| p.name == config.name) {
69 *existing = config;
70 } else {
71 providers.push(config);
72 }
73
74 Ok(())
75}
76
77pub fn unregister_custom_provider(name: &str) -> Result<bool> {
86 let mut providers = CUSTOM_PROVIDERS.write().map_err(|e| LiterLlmError::ServerError {
87 message: format!("custom provider registry lock poisoned: {e}"),
88 })?;
89
90 let before = providers.len();
91 providers.retain(|p| p.name != name);
92 Ok(providers.len() < before)
93}
94
95pub(crate) fn detect_custom_provider(model: &str) -> Option<Box<dyn Provider>> {
101 let providers = CUSTOM_PROVIDERS.read().ok()?;
102
103 for cfg in providers.iter() {
104 let matches = cfg
105 .model_prefixes
106 .iter()
107 .any(|prefix| model.starts_with(prefix.as_str()));
108
109 if matches {
110 return Some(Box::new(CustomProvider { config: cfg.clone() }));
111 }
112 }
113
114 None
115}
116
/// Test helper that removes every registered custom provider.
///
/// A poisoned lock is recovered rather than skipped. Every registry mutation in this module
/// is a single `Vec` operation (push, in-place assignment, `retain`), so the data stays
/// consistent after a panic. Skipping the clear would instead leak registrations into later
/// tests.
#[cfg(test)]
pub(crate) fn clear_custom_providers() {
    let mut providers = CUSTOM_PROVIDERS.write().unwrap_or_else(|poisoned| poisoned.into_inner());
    providers.clear();
}
124
125fn validate_config(config: &CustomProviderConfig) -> Result<()> {
128 if config.name.trim().is_empty() {
129 return Err(LiterLlmError::BadRequest {
130 message: "custom provider name must not be empty or whitespace-only".into(),
131 });
132 }
133 if config.base_url.trim().is_empty() {
134 return Err(LiterLlmError::BadRequest {
135 message: "custom provider base_url must not be empty or whitespace-only".into(),
136 });
137 }
138 if config.model_prefixes.is_empty() {
139 return Err(LiterLlmError::BadRequest {
140 message: "custom provider must have at least one model prefix".into(),
141 });
142 }
143 for prefix in &config.model_prefixes {
144 if prefix.is_empty() {
145 return Err(LiterLlmError::BadRequest {
146 message: "custom provider model prefix must not be empty (would match all models)".into(),
147 });
148 }
149 }
150 Ok(())
151}
152
153struct CustomProvider {
160 config: CustomProviderConfig,
161}
162
163impl Provider for CustomProvider {
164 fn name(&self) -> &str {
165 &self.config.name
166 }
167
168 fn base_url(&self) -> &str {
169 &self.config.base_url
170 }
171
172 fn auth_header<'a>(&'a self, api_key: &'a str) -> Option<(Cow<'static, str>, Cow<'a, str>)> {
173 match &self.config.auth_header {
174 AuthHeaderFormat::Bearer => Some((Cow::Borrowed("Authorization"), Cow::Owned(format!("Bearer {api_key}")))),
175 AuthHeaderFormat::ApiKey(header_name) => Some((Cow::Owned(header_name.clone()), Cow::Borrowed(api_key))),
176 AuthHeaderFormat::None => None,
177 }
178 }
179
180 fn matches_model(&self, model: &str) -> bool {
181 self.config
182 .model_prefixes
183 .iter()
184 .any(|prefix| model.starts_with(prefix.as_str()))
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes tests that touch the process-wide registry. Without it, tests running in
    /// parallel would observe each other's registrations.
    static TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Takes [`TEST_LOCK`] and empties the registry.
    ///
    /// The lock is acquired even when an earlier failing test poisoned it.
    fn setup() -> std::sync::MutexGuard<'static, ()> {
        let guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        clear_custom_providers();
        guard
    }

    #[test]
    fn register_and_detect_by_model_prefix() {
        let _guard = setup();

        let config = CustomProviderConfig {
            name: "my-provider".into(),
            base_url: "https://api.my-provider.com/v1".into(),
            auth_header: AuthHeaderFormat::Bearer,
            model_prefixes: vec!["my-".into(), "my-provider/".into()],
        };

        register_custom_provider(config).expect("registration should succeed");

        let provider = detect_custom_provider("my-model-7b");
        assert!(provider.is_some(), "should detect custom provider by prefix 'my-'");
        let provider = provider.unwrap();
        assert_eq!(provider.name(), "my-provider");
        assert_eq!(provider.base_url(), "https://api.my-provider.com/v1");

        let provider2 = detect_custom_provider("my-provider/llama-70b");
        assert!(provider2.is_some(), "should detect custom provider by slash prefix");

        let none = detect_custom_provider("gpt-4");
        assert!(none.is_none(), "should not match unrelated model");
    }

    #[test]
    fn unregister_removes_provider() {
        let _guard = setup();

        let config = CustomProviderConfig {
            name: "ephemeral".into(),
            base_url: "https://api.ephemeral.com/v1".into(),
            auth_header: AuthHeaderFormat::Bearer,
            model_prefixes: vec!["eph-".into()],
        };

        register_custom_provider(config).expect("registration should succeed");
        assert!(detect_custom_provider("eph-model").is_some());

        let removed = unregister_custom_provider("ephemeral").expect("unregister should succeed");
        assert!(removed, "should return true when provider was found");

        assert!(
            detect_custom_provider("eph-model").is_none(),
            "should no longer detect after unregister"
        );

        let removed_again = unregister_custom_provider("ephemeral").expect("unregister should succeed");
        assert!(!removed_again, "should return false when provider not found");
    }

    #[test]
    fn unregister_unknown_provider_returns_false() {
        let _guard = setup();

        let removed = unregister_custom_provider("never-registered").expect("unregister should succeed");
        assert!(!removed, "should return false for a name that was never registered");
    }

    #[test]
    fn custom_provider_with_api_key_auth() {
        let _guard = setup();

        let config = CustomProviderConfig {
            name: "secure-provider".into(),
            base_url: "https://api.secure.com/v1".into(),
            auth_header: AuthHeaderFormat::ApiKey("X-Custom-Auth".into()),
            model_prefixes: vec!["secure/".into()],
        };

        register_custom_provider(config).expect("registration should succeed");

        let provider = detect_custom_provider("secure/model-1").expect("should detect provider");
        let (header_name, header_value) = provider
            .auth_header("my-secret-key")
            .expect("should return auth header");
        assert_eq!(header_name.as_ref(), "X-Custom-Auth");
        assert_eq!(header_value.as_ref(), "my-secret-key");
    }

    #[test]
    fn custom_provider_with_no_auth() {
        let _guard = setup();

        let config = CustomProviderConfig {
            name: "local-provider".into(),
            base_url: "http://localhost:8080/v1".into(),
            auth_header: AuthHeaderFormat::None,
            model_prefixes: vec!["local/".into()],
        };

        register_custom_provider(config).expect("registration should succeed");

        let provider = detect_custom_provider("local/model").expect("should detect provider");
        assert!(
            provider.auth_header("unused").is_none(),
            "no-auth provider should return None"
        );
    }

    #[test]
    fn custom_provider_bearer_auth() {
        let _guard = setup();

        let config = CustomProviderConfig {
            name: "bearer-provider".into(),
            base_url: "https://api.bearer.com/v1".into(),
            auth_header: AuthHeaderFormat::Bearer,
            model_prefixes: vec!["bearer/".into()],
        };

        register_custom_provider(config).expect("registration should succeed");

        let provider = detect_custom_provider("bearer/model").expect("should detect provider");
        let (header_name, header_value) = provider.auth_header("my-token").expect("should return auth header");
        assert_eq!(header_name.as_ref(), "Authorization");
        assert_eq!(header_value.as_ref(), "Bearer my-token");
    }

    #[test]
    fn register_replaces_existing_provider() {
        let _guard = setup();

        let config1 = CustomProviderConfig {
            name: "updatable".into(),
            base_url: "https://old.example.com/v1".into(),
            auth_header: AuthHeaderFormat::Bearer,
            model_prefixes: vec!["upd/".into()],
        };
        register_custom_provider(config1).expect("first registration should succeed");

        let config2 = CustomProviderConfig {
            name: "updatable".into(),
            base_url: "https://new.example.com/v1".into(),
            auth_header: AuthHeaderFormat::Bearer,
            model_prefixes: vec!["upd/".into()],
        };
        register_custom_provider(config2).expect("second registration should succeed");

        let provider = detect_custom_provider("upd/model").expect("should detect provider");
        assert_eq!(
            provider.base_url(),
            "https://new.example.com/v1",
            "should use the updated config"
        );
    }

    #[test]
    fn validation_rejects_empty_name() {
        let _guard = setup();

        let config = CustomProviderConfig {
            name: String::new(),
            base_url: "https://example.com".into(),
            auth_header: AuthHeaderFormat::Bearer,
            model_prefixes: vec!["x/".into()],
        };
        let result = register_custom_provider(config);
        assert!(result.is_err(), "should reject empty name");
    }

    #[test]
    fn validation_rejects_empty_base_url() {
        let _guard = setup();

        let config = CustomProviderConfig {
            name: "valid-name".into(),
            base_url: String::new(),
            auth_header: AuthHeaderFormat::Bearer,
            model_prefixes: vec!["x/".into()],
        };
        let result = register_custom_provider(config);
        assert!(result.is_err(), "should reject empty base_url");
    }

    #[test]
    fn validation_rejects_whitespace_only_name_and_base_url() {
        let _guard = setup();

        let blank_name = CustomProviderConfig {
            name: "   ".into(),
            base_url: "https://example.com".into(),
            auth_header: AuthHeaderFormat::Bearer,
            model_prefixes: vec!["x/".into()],
        };
        assert!(
            matches!(register_custom_provider(blank_name), Err(LiterLlmError::BadRequest { .. })),
            "should reject whitespace-only name"
        );

        let blank_url = CustomProviderConfig {
            name: "valid-name".into(),
            base_url: " \t".into(),
            auth_header: AuthHeaderFormat::Bearer,
            model_prefixes: vec!["x/".into()],
        };
        assert!(
            matches!(register_custom_provider(blank_url), Err(LiterLlmError::BadRequest { .. })),
            "should reject whitespace-only base_url"
        );
    }

    #[test]
    fn validation_rejects_no_prefixes() {
        let _guard = setup();

        let config = CustomProviderConfig {
            name: "valid-name".into(),
            base_url: "https://example.com".into(),
            auth_header: AuthHeaderFormat::Bearer,
            model_prefixes: vec![],
        };
        let result = register_custom_provider(config);
        assert!(result.is_err(), "should reject empty model_prefixes");
    }

    #[test]
    fn validation_rejects_empty_prefix_string() {
        let _guard = setup();

        let config = CustomProviderConfig {
            name: "catch-all".into(),
            base_url: "https://example.com".into(),
            auth_header: AuthHeaderFormat::Bearer,
            model_prefixes: vec!["ok/".into(), String::new()],
        };
        assert!(
            matches!(register_custom_provider(config), Err(LiterLlmError::BadRequest { .. })),
            "should reject an empty prefix, which would match every model"
        );
        assert!(
            detect_custom_provider("ok/model").is_none(),
            "a rejected config must not be registered"
        );
    }

    #[test]
    fn config_serde_round_trip() {
        let config = CustomProviderConfig {
            name: "serde-test".into(),
            base_url: "https://example.com/v1".into(),
            auth_header: AuthHeaderFormat::ApiKey("X-Api-Key".into()),
            model_prefixes: vec!["serde/".into()],
        };

        let json = serde_json::to_string(&config).expect("should serialize");
        let parsed: CustomProviderConfig = serde_json::from_str(&json).expect("should deserialize");

        assert_eq!(parsed.name, "serde-test");
        assert_eq!(parsed.base_url, "https://example.com/v1");
        assert!(
            matches!(&parsed.auth_header, AuthHeaderFormat::ApiKey(header) if header == "X-Api-Key"),
            "auth_header should survive the round trip"
        );
        assert_eq!(parsed.model_prefixes, vec!["serde/"]);
    }
}