Skip to main content

gunmetal_providers/
lib.rs

1use std::{collections::HashMap, sync::Arc};
2
3use anyhow::{Result, bail};
4use async_trait::async_trait;
5use gunmetal_core::{ChatCompletionRequest, ProviderContext, ProviderKind, ProviderProfile};
6use gunmetal_storage::AppPaths;
7use serde_json::Value;
8use tokio::sync::Mutex;
9
10mod codex;
11mod copilot;
12mod openai;
13mod openrouter;
14mod zen;
15
16pub use codex::{CodexClient, CodexClientOptions};
17pub use copilot::{CopilotClient, CopilotClientOptions, CopilotSession};
18pub use gunmetal_sdk::{
19    ModelsDevCatalog, ProviderAdapter, ProviderAuthMethod, ProviderAuthResult, ProviderByteStream,
20    ProviderCapabilities, ProviderChatResult, ProviderClass, ProviderDefinition,
21    ProviderEventStream, ProviderHub, ProviderLoginResult, ProviderModelSyncResult,
22    ProviderRawSseResult, ProviderRegistry, ProviderStreamEvent, ProviderStreamResult,
23    ProviderUxHints, openai_compatible_event_stream, synthetic_chat_sse_stream,
24};
25pub use openai::{OpenAiClient, OpenAiClientOptions};
26pub use openrouter::{OpenRouterClient, OpenRouterClientOptions};
27pub use zen::{ZenClient, ZenClientOptions};
28
29#[derive(Clone, Default)]
30struct CodexAdapter {
31    clients: Arc<Mutex<HashMap<uuid::Uuid, Arc<Mutex<CodexClient>>>>>,
32}
33
34struct CopilotAdapter;
35struct OpenRouterAdapter;
36struct ZenAdapter;
37struct OpenAiAdapter;
38
39impl CodexAdapter {
40    async fn cached_client(
41        &self,
42        profile: &ProviderProfile,
43        context: &dyn ProviderContext,
44    ) -> Result<Arc<Mutex<CodexClient>>> {
45        {
46            let clients = self.clients.lock().await;
47            if let Some(client) = clients.get(&profile.id) {
48                return Ok(client.clone());
49            }
50        }
51
52        let client = Arc::new(Mutex::new(
53            CodexClient::spawn(CodexClientOptions::from_profile(profile, context)).await?,
54        ));
55        let mut clients = self.clients.lock().await;
56        Ok(clients
57            .entry(profile.id)
58            .or_insert_with(|| client.clone())
59            .clone())
60    }
61
62    async fn evict_client(&self, profile_id: uuid::Uuid) {
63        self.clients.lock().await.remove(&profile_id);
64    }
65}
66
67pub fn builtin_registry() -> ProviderRegistry {
68    let mut registry = ProviderRegistry::default();
69    registry.register(CodexAdapter::default());
70    registry.register(CopilotAdapter);
71    registry.register(OpenRouterAdapter);
72    registry.register(ZenAdapter);
73    registry.register(OpenAiAdapter);
74    registry
75}
76
77pub fn builtin_provider_hub(paths: AppPaths) -> ProviderHub {
78    ProviderHub::new(paths, builtin_registry())
79}
80
81pub fn builtin_providers() -> Vec<ProviderDefinition> {
82    builtin_registry().definitions()
83}
84
85#[async_trait]
86impl ProviderAdapter for CodexAdapter {
87    fn definition(&self) -> ProviderDefinition {
88        ProviderDefinition {
89            kind: ProviderKind::Codex,
90            label: "codex",
91            class: ProviderClass::Subscription,
92            priority: 1,
93            capabilities: ProviderCapabilities {
94                auth_method: ProviderAuthMethod::BrowserSession,
95                supports_base_url: false,
96                supports_model_sync: true,
97                supports_chat_completions: true,
98                supports_responses_api: true,
99                supports_streaming: true,
100            },
101            ux: ProviderUxHints {
102                helper_title: "Browser sign-in provider",
103                helper_body: "Save the provider, then auth it through the browser flow. Base URL and API key are not needed here.",
104                suggested_name: "codex",
105                base_url_placeholder: "not used for this provider",
106            },
107        }
108    }
109
110    async fn auth_status(
111        &self,
112        profile: &ProviderProfile,
113        context: &dyn ProviderContext,
114    ) -> Result<ProviderAuthResult> {
115        let client = self.cached_client(profile, context).await?;
116        let client = client.lock().await;
117        Ok(ProviderAuthResult {
118            credentials: None,
119            status: client.auth_status().await?,
120        })
121    }
122
123    async fn login(
124        &self,
125        profile: &ProviderProfile,
126        context: &dyn ProviderContext,
127        open_browser: bool,
128    ) -> Result<ProviderLoginResult> {
129        let client = self.cached_client(profile, context).await?;
130        let client = client.lock().await;
131        let session = client.login().await?;
132        if open_browser {
133            let _ = webbrowser::open(&session.auth_url);
134        }
135        Ok(ProviderLoginResult {
136            credentials: None,
137            session,
138        })
139    }
140
141    async fn logout(
142        &self,
143        profile: &ProviderProfile,
144        context: &dyn ProviderContext,
145    ) -> Result<Option<Value>> {
146        let client = self.cached_client(profile, context).await?;
147        let client = client.lock().await;
148        client.logout().await?;
149        drop(client);
150        self.evict_client(profile.id).await;
151        Ok(None)
152    }
153
154    async fn sync_models(
155        &self,
156        profile: &ProviderProfile,
157        context: &dyn ProviderContext,
158    ) -> Result<ProviderModelSyncResult> {
159        let client = self.cached_client(profile, context).await?;
160        let client = client.lock().await;
161        Ok(ProviderModelSyncResult {
162            credentials: None,
163            models: client.list_models(profile.id).await?,
164        })
165    }
166
167    async fn chat_completion(
168        &self,
169        profile: &ProviderProfile,
170        context: &dyn ProviderContext,
171        request: &ChatCompletionRequest,
172    ) -> Result<ProviderChatResult> {
173        let client = self.cached_client(profile, context).await?;
174        let client = client.lock().await;
175        Ok(ProviderChatResult {
176            credentials: None,
177            completion: client.chat_completion(profile.id, request).await?,
178        })
179    }
180}
181
182#[async_trait]
183impl ProviderAdapter for CopilotAdapter {
184    fn definition(&self) -> ProviderDefinition {
185        ProviderDefinition {
186            kind: ProviderKind::Copilot,
187            label: "copilot",
188            class: ProviderClass::Subscription,
189            priority: 2,
190            capabilities: ProviderCapabilities {
191                auth_method: ProviderAuthMethod::BrowserSession,
192                supports_base_url: false,
193                supports_model_sync: true,
194                supports_chat_completions: true,
195                supports_responses_api: true,
196                supports_streaming: true,
197            },
198            ux: ProviderUxHints {
199                helper_title: "Browser sign-in provider",
200                helper_body: "Save the provider, then auth it through the browser flow. Base URL and API key are not needed here.",
201                suggested_name: "copilot",
202                base_url_placeholder: "not used for this provider",
203            },
204        }
205    }
206
207    async fn auth_status(
208        &self,
209        profile: &ProviderProfile,
210        _context: &dyn ProviderContext,
211    ) -> Result<ProviderAuthResult> {
212        let result = CopilotClient::with_options(CopilotClientOptions::from_profile(profile))
213            .auth_status(profile)
214            .await?;
215        Ok(ProviderAuthResult {
216            credentials: result.credentials,
217            status: result.status,
218        })
219    }
220
221    async fn login(
222        &self,
223        profile: &ProviderProfile,
224        _context: &dyn ProviderContext,
225        open_browser: bool,
226    ) -> Result<ProviderLoginResult> {
227        let result = CopilotClient::with_options(CopilotClientOptions::from_profile(profile))
228            .login(profile, open_browser)
229            .await?;
230        Ok(ProviderLoginResult {
231            credentials: result.credentials,
232            session: result.session,
233        })
234    }
235
236    async fn logout(
237        &self,
238        _profile: &ProviderProfile,
239        _paths: &dyn ProviderContext,
240    ) -> Result<Option<Value>> {
241        Ok(None)
242    }
243
244    async fn sync_models(
245        &self,
246        profile: &ProviderProfile,
247        _context: &dyn ProviderContext,
248    ) -> Result<ProviderModelSyncResult> {
249        let result = CopilotClient::with_options(CopilotClientOptions::from_profile(profile))
250            .list_models(profile)
251            .await?;
252        Ok(ProviderModelSyncResult {
253            credentials: result.credentials,
254            models: result.models,
255        })
256    }
257
258    async fn chat_completion(
259        &self,
260        profile: &ProviderProfile,
261        _context: &dyn ProviderContext,
262        request: &ChatCompletionRequest,
263    ) -> Result<ProviderChatResult> {
264        let result = CopilotClient::with_options(CopilotClientOptions::from_profile(profile))
265            .chat_completion(profile, request)
266            .await?;
267        Ok(ProviderChatResult {
268            credentials: result.credentials,
269            completion: result.completion,
270        })
271    }
272}
273
274#[async_trait]
275impl ProviderAdapter for OpenRouterAdapter {
276    fn definition(&self) -> ProviderDefinition {
277        ProviderDefinition {
278            kind: ProviderKind::OpenRouter,
279            label: "openrouter",
280            class: ProviderClass::Gateway,
281            priority: 3,
282            capabilities: ProviderCapabilities {
283                auth_method: ProviderAuthMethod::ApiKey,
284                supports_base_url: true,
285                supports_model_sync: true,
286                supports_chat_completions: true,
287                supports_responses_api: true,
288                supports_streaming: true,
289            },
290            ux: ProviderUxHints {
291                helper_title: "Gateway provider",
292                helper_body: "Save your upstream API key here. Base URL usually stays on the default OpenRouter endpoint.",
293                suggested_name: "openrouter",
294                base_url_placeholder: "https://openrouter.ai/api/v1",
295            },
296        }
297    }
298
299    async fn auth_status(
300        &self,
301        profile: &ProviderProfile,
302        _context: &dyn ProviderContext,
303    ) -> Result<ProviderAuthResult> {
304        let result = OpenRouterClient::with_options(OpenRouterClientOptions::from_profile(profile))
305            .auth_status(profile)
306            .await?;
307        Ok(ProviderAuthResult {
308            credentials: result.credentials,
309            status: result.status,
310        })
311    }
312
313    async fn login(
314        &self,
315        profile: &ProviderProfile,
316        _context: &dyn ProviderContext,
317        _open_browser: bool,
318    ) -> Result<ProviderLoginResult> {
319        bail!(
320            "provider '{}' does not support browser login",
321            profile.provider
322        )
323    }
324
325    async fn logout(
326        &self,
327        profile: &ProviderProfile,
328        _paths: &dyn ProviderContext,
329    ) -> Result<Option<Value>> {
330        Ok(
331            OpenRouterClient::with_options(OpenRouterClientOptions::from_profile(profile))
332                .clear_credentials(),
333        )
334    }
335
336    async fn sync_models(
337        &self,
338        profile: &ProviderProfile,
339        _context: &dyn ProviderContext,
340    ) -> Result<ProviderModelSyncResult> {
341        let result = OpenRouterClient::with_options(OpenRouterClientOptions::from_profile(profile))
342            .list_models(profile)
343            .await?;
344        Ok(ProviderModelSyncResult {
345            credentials: result.credentials,
346            models: result.models,
347        })
348    }
349
350    async fn chat_completion(
351        &self,
352        profile: &ProviderProfile,
353        _context: &dyn ProviderContext,
354        request: &ChatCompletionRequest,
355    ) -> Result<ProviderChatResult> {
356        let result = OpenRouterClient::with_options(OpenRouterClientOptions::from_profile(profile))
357            .chat_completion(profile, request)
358            .await?;
359        Ok(ProviderChatResult {
360            credentials: result.credentials,
361            completion: result.completion,
362        })
363    }
364
365    async fn stream_chat_completion(
366        &self,
367        profile: &ProviderProfile,
368        _context: &dyn ProviderContext,
369        request: &ChatCompletionRequest,
370    ) -> Result<ProviderStreamResult> {
371        let result = OpenRouterClient::with_options(OpenRouterClientOptions::from_profile(profile))
372            .stream_chat_completion(profile, request)
373            .await?;
374        Ok(ProviderStreamResult {
375            credentials: result.credentials,
376            stream: result.stream,
377        })
378    }
379
380    async fn raw_stream_chat_completion(
381        &self,
382        profile: &ProviderProfile,
383        _context: &dyn ProviderContext,
384        request: &ChatCompletionRequest,
385    ) -> Result<ProviderRawSseResult> {
386        let client = OpenRouterClient::with_options(OpenRouterClientOptions::from_profile(profile));
387        Ok(ProviderRawSseResult {
388            credentials: None,
389            stream: client.raw_stream_chat_completion(profile, request).await?,
390        })
391    }
392}
393
394#[async_trait]
395impl ProviderAdapter for ZenAdapter {
396    fn definition(&self) -> ProviderDefinition {
397        ProviderDefinition {
398            kind: ProviderKind::Zen,
399            label: "zen",
400            class: ProviderClass::Gateway,
401            priority: 4,
402            capabilities: ProviderCapabilities {
403                auth_method: ProviderAuthMethod::ApiKey,
404                supports_base_url: true,
405                supports_model_sync: true,
406                supports_chat_completions: true,
407                supports_responses_api: true,
408                supports_streaming: true,
409            },
410            ux: ProviderUxHints {
411                helper_title: "Gateway provider",
412                helper_body: "Save your upstream API key here. Base URL usually stays on the default Zen endpoint.",
413                suggested_name: "zen",
414                base_url_placeholder: "https://opencode.ai/zen/v1",
415            },
416        }
417    }
418
419    async fn auth_status(
420        &self,
421        profile: &ProviderProfile,
422        _context: &dyn ProviderContext,
423    ) -> Result<ProviderAuthResult> {
424        let result = ZenClient::with_options(ZenClientOptions::from_profile(profile))
425            .auth_status(profile)
426            .await?;
427        Ok(ProviderAuthResult {
428            credentials: result.credentials,
429            status: result.status,
430        })
431    }
432
433    async fn login(
434        &self,
435        profile: &ProviderProfile,
436        _context: &dyn ProviderContext,
437        _open_browser: bool,
438    ) -> Result<ProviderLoginResult> {
439        bail!(
440            "provider '{}' does not support browser login",
441            profile.provider
442        )
443    }
444
445    async fn logout(
446        &self,
447        profile: &ProviderProfile,
448        _paths: &dyn ProviderContext,
449    ) -> Result<Option<Value>> {
450        Ok(ZenClient::with_options(ZenClientOptions::from_profile(profile)).clear_credentials())
451    }
452
453    async fn sync_models(
454        &self,
455        profile: &ProviderProfile,
456        _context: &dyn ProviderContext,
457    ) -> Result<ProviderModelSyncResult> {
458        let result = ZenClient::with_options(ZenClientOptions::from_profile(profile))
459            .list_models(profile)
460            .await?;
461        Ok(ProviderModelSyncResult {
462            credentials: result.credentials,
463            models: result.models,
464        })
465    }
466
467    async fn chat_completion(
468        &self,
469        profile: &ProviderProfile,
470        _context: &dyn ProviderContext,
471        request: &ChatCompletionRequest,
472    ) -> Result<ProviderChatResult> {
473        let result = ZenClient::with_options(ZenClientOptions::from_profile(profile))
474            .chat_completion(profile, request)
475            .await?;
476        Ok(ProviderChatResult {
477            credentials: result.credentials,
478            completion: result.completion,
479        })
480    }
481}
482
483#[async_trait]
484impl ProviderAdapter for OpenAiAdapter {
485    fn definition(&self) -> ProviderDefinition {
486        ProviderDefinition {
487            kind: ProviderKind::OpenAi,
488            label: "openai",
489            class: ProviderClass::Direct,
490            priority: 5,
491            capabilities: ProviderCapabilities {
492                auth_method: ProviderAuthMethod::ApiKey,
493                supports_base_url: true,
494                supports_model_sync: true,
495                supports_chat_completions: true,
496                supports_responses_api: true,
497                supports_streaming: true,
498            },
499            ux: ProviderUxHints {
500                helper_title: "Direct provider",
501                helper_body: "Save your upstream API key here. Base URL is optional unless you need a custom endpoint.",
502                suggested_name: "openai",
503                base_url_placeholder: "https://api.openai.com/v1",
504            },
505        }
506    }
507
508    async fn auth_status(
509        &self,
510        profile: &ProviderProfile,
511        _context: &dyn ProviderContext,
512    ) -> Result<ProviderAuthResult> {
513        Ok(ProviderAuthResult {
514            credentials: profile.credentials.clone(),
515            status: OpenAiClient::with_options(OpenAiClientOptions::from_profile(profile))
516                .auth_status(profile)
517                .await?,
518        })
519    }
520
521    async fn login(
522        &self,
523        profile: &ProviderProfile,
524        _context: &dyn ProviderContext,
525        _open_browser: bool,
526    ) -> Result<ProviderLoginResult> {
527        bail!(
528            "provider '{}' does not support browser login",
529            profile.provider
530        )
531    }
532
533    async fn logout(
534        &self,
535        profile: &ProviderProfile,
536        _paths: &dyn ProviderContext,
537    ) -> Result<Option<Value>> {
538        Ok(
539            OpenAiClient::with_options(OpenAiClientOptions::from_profile(profile))
540                .clear_credentials(),
541        )
542    }
543
544    async fn sync_models(
545        &self,
546        profile: &ProviderProfile,
547        _context: &dyn ProviderContext,
548    ) -> Result<ProviderModelSyncResult> {
549        Ok(ProviderModelSyncResult {
550            credentials: profile.credentials.clone(),
551            models: OpenAiClient::with_options(OpenAiClientOptions::from_profile(profile))
552                .list_models(profile)
553                .await?,
554        })
555    }
556
557    async fn chat_completion(
558        &self,
559        profile: &ProviderProfile,
560        _context: &dyn ProviderContext,
561        request: &ChatCompletionRequest,
562    ) -> Result<ProviderChatResult> {
563        Ok(ProviderChatResult {
564            credentials: profile.credentials.clone(),
565            completion: OpenAiClient::with_options(OpenAiClientOptions::from_profile(profile))
566                .chat_completion(profile, request)
567                .await?,
568        })
569    }
570
571    async fn stream_chat_completion(
572        &self,
573        profile: &ProviderProfile,
574        _context: &dyn ProviderContext,
575        request: &ChatCompletionRequest,
576    ) -> Result<ProviderStreamResult> {
577        Ok(ProviderStreamResult {
578            credentials: profile.credentials.clone(),
579            stream: OpenAiClient::with_options(OpenAiClientOptions::from_profile(profile))
580                .stream_chat_completion(profile, request)
581                .await?,
582        })
583    }
584
585    async fn raw_stream_chat_completion(
586        &self,
587        profile: &ProviderProfile,
588        _context: &dyn ProviderContext,
589        request: &ChatCompletionRequest,
590    ) -> Result<ProviderRawSseResult> {
591        let client = OpenAiClient::with_options(OpenAiClientOptions::from_profile(profile));
592        Ok(ProviderRawSseResult {
593            credentials: profile.credentials.clone(),
594            stream: client.raw_stream_chat_completion(profile, request).await?,
595        })
596    }
597}
598
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an enabled profile of the given `provider` kind.
    /// The profile has a random id, no base URL and no credentials.
    fn test_profile(provider: ProviderKind, name: &str) -> ProviderProfile {
        let now = chrono::Utc::now();
        ProviderProfile {
            id: uuid::Uuid::new_v4(),
            provider,
            name: name.to_owned(),
            base_url: None,
            enabled: true,
            credentials: None,
            created_at: now,
            updated_at: now,
        }
    }

    /// Returns `AppPaths` rooted in a fresh temporary directory.
    ///
    /// This keeps tests hermetic: they never resolve or touch the real user data directory.
    /// The returned `TempDir` must be kept alive for as long as the paths are used.
    fn temp_paths() -> (tempfile::TempDir, AppPaths) {
        let temp = tempfile::tempdir().unwrap();
        let paths = AppPaths::from_root(temp.path().to_path_buf()).unwrap();
        (temp, paths)
    }

    /// Asserts that `result` is the "does not support browser login" rejection.
    fn assert_browser_login_unsupported(result: Result<ProviderLoginResult>) {
        match result {
            Ok(_) => panic!("expected browser login to be rejected"),
            Err(err) => assert!(
                err.to_string().contains("does not support browser login"),
                "unexpected error: {err}"
            ),
        }
    }

    #[test]
    fn builtin_provider_order_matches_product_priority() {
        let providers = builtin_providers();
        let expected = [
            ProviderKind::Codex,
            ProviderKind::Copilot,
            ProviderKind::OpenRouter,
            ProviderKind::Zen,
            ProviderKind::OpenAi,
        ];
        // Check the count first: a missing or extra provider should fail with a clear
        // message instead of an index-out-of-bounds panic below.
        assert_eq!(providers.len(), expected.len());
        for (provider, kind) in providers.iter().zip(expected.iter()) {
            assert_eq!(&provider.kind, kind);
        }
        // Listing order must agree with the declared priorities.
        assert!(
            providers
                .windows(2)
                .all(|pair| pair[0].priority < pair[1].priority)
        );

        assert!(providers[0].supports_browser_login());
        assert!(!providers[0].capabilities.supports_base_url);
        assert!(providers[2].requires_api_key());
        assert_eq!(
            providers[2].ux.base_url_placeholder,
            "https://openrouter.ai/api/v1"
        );
    }

    #[tokio::test]
    async fn codex_cached_client_reuses_existing_client() {
        let (_temp, paths) = temp_paths();
        let profile = test_profile(ProviderKind::Codex, "codex");
        let adapter = CodexAdapter::default();
        let mock_client = Arc::new(Mutex::new(CodexClient::mock("test")));
        adapter
            .clients
            .lock()
            .await
            .insert(profile.id, Arc::clone(&mock_client));

        let cached = adapter.cached_client(&profile, &paths).await.unwrap();
        assert!(Arc::ptr_eq(&cached, &mock_client));
    }

    #[tokio::test]
    async fn codex_evict_client_removes_cached_entry() {
        let adapter = CodexAdapter::default();
        let profile_id = uuid::Uuid::new_v4();
        adapter
            .clients
            .lock()
            .await
            .insert(profile_id, Arc::new(Mutex::new(CodexClient::mock("test"))));
        assert!(adapter.clients.lock().await.contains_key(&profile_id));

        adapter.evict_client(profile_id).await;
        assert!(!adapter.clients.lock().await.contains_key(&profile_id));
    }

    #[tokio::test]
    async fn copilot_logout_returns_none() {
        // Previously used `AppPaths::resolve()`, which depends on the real user environment;
        // a temp root keeps the test hermetic.
        let (_temp, paths) = temp_paths();
        let profile = test_profile(ProviderKind::Copilot, "copilot");
        let adapter = CopilotAdapter;

        let result = adapter.logout(&profile, &paths).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn api_key_providers_reject_browser_login() {
        let (_temp, paths) = temp_paths();

        let openrouter = test_profile(ProviderKind::OpenRouter, "openrouter");
        assert_browser_login_unsupported(OpenRouterAdapter.login(&openrouter, &paths, true).await);

        let zen = test_profile(ProviderKind::Zen, "zen");
        assert_browser_login_unsupported(ZenAdapter.login(&zen, &paths, true).await);

        let openai = test_profile(ProviderKind::OpenAi, "openai");
        assert_browser_login_unsupported(OpenAiAdapter.login(&openai, &paths, true).await);
    }
}