use std::{collections::HashMap, sync::Arc};

use anyhow::{Result, anyhow};
use async_trait::async_trait;
use futures_util::{
    StreamExt,
    stream::{self, BoxStream},
};
use gunmetal_core::{
    ChatCompletionRequest, ChatCompletionResult, ModelDescriptor, ModelMetadata,
    ProviderAuthStatus, ProviderContext, ProviderKind, ProviderLoginSession, ProviderProfile,
    TokenUsage,
};
use gunmetal_storage::AppPaths;
use reqwest::{Client, Response};
use serde::Deserialize;
use serde_json::Value;
use tokio::sync::Mutex;

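/// Coarse grouping of a provider, presumably a subscription service, an
/// aggregator gateway, or a direct API vendor, used to order and present
/// providers.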
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderClass {
    Subscription,
    Gateway,
    Direct,
}

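/// How a provider authenticates: an interactive browser session or a stored
/// API key.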
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderAuthMethod {
    BrowserSession,
    ApiKey,
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProviderCapabilities {
    pub auth_method: ProviderAuthMethod,
    pub supports_base_url: bool,
    pub supports_model_sync: bool,
    pub supports_chat_completions: bool,
    pub supports_responses_api: bool,
    pub supports_streaming: bool,
}

impl ProviderCapabilities {
    pub fn supports_browser_login(&self) -> bool {
        matches!(self.auth_method, ProviderAuthMethod::BrowserSession)
    }

    pub fn requires_api_key(&self) -> bool {
        matches!(self.auth_method, ProviderAuthMethod::ApiKey)
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProviderUxHints {
    pub helper_title: &'static str,
    pub helper_body: &'static str,
    pub suggested_name: &'static str,
    pub base_url_placeholder: &'static str,
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProviderDefinition {
    pub kind: ProviderKind,
    pub label: &'static str,
    pub class: ProviderClass,
    pub priority: usize,
    pub capabilities: ProviderCapabilities,
    pub ux: ProviderUxHints,
}

impl ProviderDefinition {
    pub fn supports_browser_login(&self) -> bool {
        self.capabilities.supports_browser_login()
    }

    pub fn requires_api_key(&self) -> bool {
        self.capabilities.requires_api_key()
    }
}

#[derive(Debug, Clone)]
pub struct ProviderAuthResult {
    pub credentials: Option<Value>,
    pub status: ProviderAuthStatus,
}

#[derive(Debug, Clone)]
pub struct ProviderLoginResult {
    pub credentials: Option<Value>,
    pub session: ProviderLoginSession,
}

#[derive(Debug, Clone)]
pub struct ProviderModelSyncResult {
    pub credentials: Option<Value>,
    pub models: Vec<ModelDescriptor>,
}

#[derive(Debug, Clone)]
pub struct ProviderChatResult {
    pub completion: ChatCompletionResult,
    pub credentials: Option<Value>,
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProviderStreamEvent {
    TextDelta(String),
    Complete {
        model: String,
        finish_reason: String,
        usage: TokenUsage,
    },
}

pub type ProviderEventStream = BoxStream<'static, Result<ProviderStreamEvent>>;
pub type ProviderByteStream = BoxStream<'static, Result<Vec<u8>>>;

pub struct ProviderStreamResult {
    pub stream: ProviderEventStream,
    pub credentials: Option<Value>,
}

pub struct ProviderRawSseResult {
    pub stream: ProviderByteStream,
    pub credentials: Option<Value>,
}

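/// A single provider backend. Implementors supply auth, model sync, and chat
/// completion; the streaming methods have default implementations that
/// synthesize a stream from the non-streaming `chat_completion` path, so
/// adapters without native streaming still work.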
#[async_trait]
pub trait ProviderAdapter: Send + Sync {
    fn definition(&self) -> ProviderDefinition;

    async fn auth_status(
        &self,
        profile: &ProviderProfile,
        context: &dyn ProviderContext,
    ) -> Result<ProviderAuthResult>;

    async fn login(
        &self,
        profile: &ProviderProfile,
        context: &dyn ProviderContext,
        open_browser: bool,
    ) -> Result<ProviderLoginResult>;

    async fn logout(
        &self,
        profile: &ProviderProfile,
        context: &dyn ProviderContext,
    ) -> Result<Option<Value>>;

    async fn sync_models(
        &self,
        profile: &ProviderProfile,
        context: &dyn ProviderContext,
    ) -> Result<ProviderModelSyncResult>;

    async fn chat_completion(
        &self,
        profile: &ProviderProfile,
        context: &dyn ProviderContext,
        request: &ChatCompletionRequest,
    ) -> Result<ProviderChatResult>;

    async fn stream_chat_completion(
        &self,
        profile: &ProviderProfile,
        context: &dyn ProviderContext,
        request: &ChatCompletionRequest,
    ) -> Result<ProviderStreamResult> {
        let result = self.chat_completion(profile, context, request).await?;
        Ok(ProviderStreamResult {
            credentials: result.credentials,
            stream: synthetic_completion_stream(result.completion),
        })
    }

    async fn raw_stream_chat_completion(
        &self,
        profile: &ProviderProfile,
        context: &dyn ProviderContext,
        request: &ChatCompletionRequest,
    ) -> Result<ProviderRawSseResult> {
        let result = self
            .stream_chat_completion(profile, context, request)
            .await?;
        Ok(ProviderRawSseResult {
            credentials: result.credentials,
            stream: synthetic_chat_sse_stream(request.model.clone(), result.stream),
        })
    }
}

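/// Maps each [`ProviderKind`] to its registered adapter.
///
/// A minimal wiring sketch (`MyAdapter` is hypothetical; any
/// [`ProviderAdapter`] implementor works):
/// ```ignore
/// let mut registry = ProviderRegistry::default();
/// registry.register(MyAdapter::default());
/// let hub = ProviderHub::new(paths, registry);
/// ```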
#[derive(Clone, Default)]
pub struct ProviderRegistry {
    adapters: HashMap<ProviderKind, Arc<dyn ProviderAdapter>>,
}

impl ProviderRegistry {
    pub fn register<A>(&mut self, adapter: A)
    where
        A: ProviderAdapter + 'static,
    {
        let adapter = Arc::new(adapter);
        let kind = adapter.definition().kind;
        self.adapters.insert(kind, adapter);
    }

    pub fn get(&self, kind: &ProviderKind) -> Option<Arc<dyn ProviderAdapter>> {
        self.adapters.get(kind).cloned()
    }

    pub fn definition(&self, kind: &ProviderKind) -> Option<ProviderDefinition> {
        self.adapters.get(kind).map(|adapter| adapter.definition())
    }

    pub fn definitions(&self) -> Vec<ProviderDefinition> {
        let mut definitions = self
            .adapters
            .values()
            .map(|adapter| adapter.definition())
            .collect::<Vec<_>>();
        definitions.sort_by_key(|item| item.priority);
        definitions
    }
}

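/// Front door for provider operations. Resolves the adapter for a profile,
/// forwards the call, and persists any refreshed credentials the adapter
/// returns before handing back the result.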
#[derive(Clone)]
pub struct ProviderHub {
    paths: AppPaths,
    registry: ProviderRegistry,
    models_dev: ModelsDevCatalog,
}

impl ProviderHub {
    pub fn new(paths: AppPaths, registry: ProviderRegistry) -> Self {
        Self {
            paths,
            registry,
            models_dev: ModelsDevCatalog::default(),
        }
    }

    pub fn with_registry(paths: AppPaths, registry: ProviderRegistry) -> Self {
        Self::new(paths, registry)
    }

    pub fn with_registry_and_models_dev(
        paths: AppPaths,
        registry: ProviderRegistry,
        models_dev: ModelsDevCatalog,
    ) -> Self {
        Self {
            paths,
            registry,
            models_dev,
        }
    }

    pub async fn auth_status(&self, profile: &ProviderProfile) -> Result<ProviderAuthStatus> {
        let adapter = self.adapter(&profile.provider)?;
        let result = adapter.auth_status(profile, &self.paths).await?;
        self.persist_credentials(profile.id, result.credentials)?;
        Ok(result.status)
    }

    pub async fn login(
        &self,
        profile: &ProviderProfile,
        open_browser: bool,
    ) -> Result<ProviderLoginSession> {
        let adapter = self.adapter(&profile.provider)?;
        let result = adapter.login(profile, &self.paths, open_browser).await?;
        self.persist_credentials(profile.id, result.credentials)?;
        Ok(result.session)
    }

    pub async fn logout(&self, profile: &ProviderProfile) -> Result<()> {
        let adapter = self.adapter(&profile.provider)?;
        let credentials = adapter.logout(profile, &self.paths).await?;
        self.persist_credentials(profile.id, credentials)
    }

    pub async fn sync_models(&self, profile: &ProviderProfile) -> Result<Vec<ModelDescriptor>> {
        let adapter = self.adapter(&profile.provider)?;
        let mut result = adapter.sync_models(profile, &self.paths).await?;
        self.persist_credentials(profile.id, result.credentials)?;
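        // Enrichment is best-effort: a models.dev outage must not fail the sync.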
        let _ = self.models_dev.enrich(&mut result.models).await;
        Ok(result.models)
    }

    pub async fn chat_completion(
        &self,
        profile: &ProviderProfile,
        request: &ChatCompletionRequest,
    ) -> Result<ChatCompletionResult> {
        let adapter = self.adapter(&profile.provider)?;
        let result = adapter
            .chat_completion(profile, &self.paths, request)
            .await?;
        self.persist_credentials(profile.id, result.credentials)?;
        Ok(result.completion)
    }

    pub async fn stream_chat_completion(
        &self,
        profile: &ProviderProfile,
        request: &ChatCompletionRequest,
    ) -> Result<ProviderEventStream> {
        let adapter = self.adapter(&profile.provider)?;
        let result = adapter
            .stream_chat_completion(profile, &self.paths, request)
            .await?;
        self.persist_credentials(profile.id, result.credentials)?;
        Ok(result.stream)
    }

    pub async fn raw_stream_chat_completion(
        &self,
        profile: &ProviderProfile,
        request: &ChatCompletionRequest,
    ) -> Result<ProviderByteStream> {
        let adapter = self.adapter(&profile.provider)?;
        let result = adapter
            .raw_stream_chat_completion(profile, &self.paths, request)
            .await?;
        self.persist_credentials(profile.id, result.credentials)?;
        Ok(result.stream)
    }

    pub fn definitions(&self) -> Vec<ProviderDefinition> {
        self.registry.definitions()
    }

    pub fn definition(&self, kind: &ProviderKind) -> Option<ProviderDefinition> {
        self.registry.definition(kind)
    }

    fn adapter(&self, kind: &ProviderKind) -> Result<Arc<dyn ProviderAdapter>> {
        self.registry
            .get(kind)
            .ok_or_else(|| anyhow!("provider '{}' not implemented yet", kind))
    }

    fn persist_credentials(
        &self,
        profile_id: uuid::Uuid,
        credentials: Option<serde_json::Value>,
    ) -> Result<()> {
        let Some(credentials) = credentials else {
            return Ok(());
        };
        self.paths
            .storage_handle()?
            .update_profile_credentials(profile_id, Some(credentials))
    }
}

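/// Replays a finished completion as a synthetic event stream: the message
/// content is split into small text deltas, followed by a single `Complete`
/// event carrying the model, finish reason, and usage.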
fn synthetic_completion_stream(completion: ChatCompletionResult) -> ProviderEventStream {
    let mut events = text_chunks(&completion.message.content)
        .into_iter()
        .map(ProviderStreamEvent::TextDelta)
        .collect::<Vec<_>>();
    events.push(ProviderStreamEvent::Complete {
        model: completion.model,
        finish_reason: completion.finish_reason,
        usage: completion.usage,
    });
    stream::iter(events.into_iter().map(Ok)).boxed()
}

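/// Re-encodes a provider event stream as OpenAI-style `chat.completion.chunk`
/// SSE frames: an assistant-role preamble, one chunk per text delta, a final
/// chunk carrying the finish reason and usage, then a `data: [DONE]` sentinel.
/// Upstream errors are surfaced in-band as `event: error` frames.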
pub fn synthetic_chat_sse_stream(model: String, stream: ProviderEventStream) -> ProviderByteStream {
    let id = format!("chatcmpl-{}", uuid::Uuid::new_v4().simple());
    let created = chrono::Utc::now().timestamp();
    let first = stream::once(async move {
        Ok::<Vec<u8>, anyhow::Error>(
            format!(
                "data: {}\n\n",
                serde_json::json!({
                    "id": id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model,
                    "choices": [{
                        "index": 0,
                        "delta": { "role": "assistant" },
                        "finish_reason": Value::Null
                    }]
                })
            )
            .into_bytes(),
        )
    });

    let content = stream.map(move |item| match item {
        Ok(ProviderStreamEvent::TextDelta(chunk)) => Ok(format!(
            "data: {}\n\n",
            serde_json::json!({
                "id": format!("chatcmpl-{}", uuid::Uuid::new_v4().simple()),
                "object": "chat.completion.chunk",
                "created": chrono::Utc::now().timestamp(),
                "choices": [{
                    "index": 0,
                    "delta": { "content": chunk },
                    "finish_reason": Value::Null
                }]
            })
        )
        .into_bytes()),
        Ok(ProviderStreamEvent::Complete {
            model,
            finish_reason,
            usage,
        }) => Ok(format!(
            "data: {}\n\n",
            serde_json::json!({
                "id": format!("chatcmpl-{}", uuid::Uuid::new_v4().simple()),
                "object": "chat.completion.chunk",
                "created": chrono::Utc::now().timestamp(),
                "model": model,
                "choices": [{
                    "index": 0,
                    "delta": {},
                    "finish_reason": finish_reason
                }],
                "usage": usage
            })
        )
        .into_bytes()),
        Err(error) => Ok(format!(
            "event: error\ndata: {}\n\n",
            serde_json::json!({ "error": { "message": error.to_string() } })
        )
        .into_bytes()),
    });

    let done = stream::once(async { Ok::<Vec<u8>, anyhow::Error>(b"data: [DONE]\n\n".to_vec()) });
    first.chain(content).chain(done).boxed()
}

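/// Decodes an OpenAI-compatible SSE response into provider events. Any model
/// name reported by the upstream is passed through `normalize_model`;
/// `fallback_model` is used until the upstream announces one.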
pub fn openai_compatible_event_stream<F>(
    response: Response,
    fallback_model: String,
    normalize_model: F,
) -> ProviderEventStream
where
    F: Fn(&str) -> String + Send + Sync + 'static,
{
    let normalize_model = Arc::new(normalize_model);
    async_stream::try_stream! {
        let mut upstream = response.bytes_stream();
        let mut decoder = SseDecoder::default();
        let mut current_model = fallback_model;

        while let Some(chunk) = upstream.next().await {
            let chunk = chunk?;
            decoder.push(&chunk);

            while let Some(event) = decoder.next_event() {
                if event == "[DONE]" {
                    continue;
                }

                for parsed in parse_openai_compatible_event(
                    &event,
                    &mut current_model,
                    normalize_model.as_ref(),
                )? {
                    yield parsed;
                }
            }
        }
    }
    .boxed()
}

fn parse_openai_compatible_event(
    event: &str,
    current_model: &mut String,
    normalize_model: &dyn Fn(&str) -> String,
) -> Result<Vec<ProviderStreamEvent>> {
    let payload = serde_json::from_str::<OpenAiCompatibleStreamChunk>(event)?;
    if let Some(model) = payload.model.as_deref() {
        *current_model = normalize_model(model);
    }

    let mut events = Vec::new();
    let usage = payload.usage.map(to_token_usage);
    for choice in payload.choices {
        if let Some(delta) = choice.delta.and_then(|delta| delta.content)
            && !delta.is_empty()
        {
            events.push(ProviderStreamEvent::TextDelta(delta));
        }

        if let Some(finish_reason) = choice.finish_reason {
            events.push(ProviderStreamEvent::Complete {
                model: current_model.clone(),
                finish_reason,
                usage: usage.clone().unwrap_or(TokenUsage {
                    input_tokens: None,
                    output_tokens: None,
                    total_tokens: None,
                }),
            });
        }
    }

    Ok(events)
}

fn to_token_usage(usage: OpenAiCompatibleUsage) -> TokenUsage {
    let input_tokens = usage.prompt_tokens.map(to_u32);
    let output_tokens = usage.completion_tokens.map(to_u32);
    let total_tokens = usage
        .total_tokens
        .map(to_u32)
        .or_else(|| match (input_tokens, output_tokens) {
            (Some(input), Some(output)) => Some(input.saturating_add(output)),
            _ => None,
        });

    TokenUsage {
        input_tokens,
        output_tokens,
        total_tokens,
    }
}

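/// Minimal incremental SSE decoder: buffers raw bytes, normalizes CRLF to LF,
/// and yields the joined `data:` payload of each complete
/// (blank-line-terminated) frame.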
#[derive(Default)]
struct SseDecoder {
    buffer: String,
}

impl SseDecoder {
    fn push(&mut self, chunk: &[u8]) {
        let chunk = String::from_utf8_lossy(chunk);
        let chunk = chunk.replace("\r\n", "\n");
        self.buffer.push_str(&chunk);
    }

    fn next_event(&mut self) -> Option<String> {
        // Loop so a buffered frame without any `data:` lines (e.g. a comment
        // or bare `event:` frame) does not stall events queued behind it.
        loop {
            let separator = self.buffer.find("\n\n")?;
            let frame = self.buffer[..separator].to_owned();
            self.buffer.drain(..separator + 2);

            let data = frame
                .lines()
                .filter_map(|line| line.strip_prefix("data:"))
                .map(str::trim_start)
                .collect::<Vec<_>>()
                .join("\n");
            if !data.is_empty() {
                return Some(data);
            }
        }
    }
}

#[derive(Debug, Deserialize)]
struct OpenAiCompatibleStreamChunk {
    #[serde(default)]
    model: Option<String>,
    #[serde(default)]
    choices: Vec<OpenAiCompatibleStreamChoice>,
    #[serde(default)]
    usage: Option<OpenAiCompatibleUsage>,
}

#[derive(Debug, Deserialize)]
struct OpenAiCompatibleStreamChoice {
    #[serde(default)]
    delta: Option<OpenAiCompatibleStreamDelta>,
    #[serde(default)]
    finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
struct OpenAiCompatibleStreamDelta {
    #[serde(default)]
    content: Option<String>,
}

#[derive(Debug, Clone, Deserialize)]
struct OpenAiCompatibleUsage {
    #[serde(default)]
    prompt_tokens: Option<u64>,
    #[serde(default)]
    completion_tokens: Option<u64>,
    #[serde(default)]
    total_tokens: Option<u64>,
}

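/// Splits text into chunks of at most 24 characters so synthetic streams look
/// like incremental deltas; empty input yields a single empty chunk so the
/// stream still emits at least one `TextDelta`.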
fn text_chunks(value: &str) -> Vec<String> {
    if value.is_empty() {
        return vec![String::new()];
    }

    let mut chunks = Vec::new();
    let mut current = String::new();
    let mut count = 0usize;
    for ch in value.chars() {
        current.push(ch);
        count += 1;
        if count >= 24 {
            chunks.push(std::mem::take(&mut current));
            count = 0;
        }
    }

    if !current.is_empty() {
        chunks.push(current);
    }

    chunks
}

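/// Lazily fetched, in-memory cached view of the models.dev catalog, used to
/// backfill metadata on synced models that the provider itself did not
/// describe.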
#[derive(Clone)]
pub struct ModelsDevCatalog {
    catalog_url: String,
    http: Client,
    cache: Arc<Mutex<Option<ModelsDevIndex>>>,
}

impl Default for ModelsDevCatalog {
    fn default() -> Self {
        Self::new("https://models.dev/api.json")
    }
}

impl ModelsDevCatalog {
    pub fn new(catalog_url: impl Into<String>) -> Self {
        Self {
            catalog_url: catalog_url.into(),
            http: Client::builder()
                .connect_timeout(std::time::Duration::from_secs(2))
                .timeout(std::time::Duration::from_secs(4))
                .build()
                .expect("reqwest client"),
            cache: Arc::new(Mutex::new(None)),
        }
    }

    async fn enrich(&self, models: &mut [ModelDescriptor]) -> Result<()> {
        let index = self.index().await?;
        for model in models {
            if model.metadata.is_some() {
                continue;
            }

            let aliases = provider_aliases(&model.provider);
            let metadata = aliases
                .iter()
                .find_map(|alias| index.by_provider.get(*alias))
                .and_then(|models| models.get(&model.upstream_name).cloned())
                .or_else(|| index.by_model_id.get(&model.upstream_name).cloned());
            model.metadata = metadata;
        }
        Ok(())
    }

    async fn index(&self) -> Result<ModelsDevIndex> {
        {
            let cache = self.cache.lock().await;
            if let Some(index) = cache.as_ref() {
                return Ok(index.clone());
            }
        }

        let payload = self
            .http
            .get(&self.catalog_url)
            .send()
            .await?
            .error_for_status()?
            .json::<HashMap<String, ModelsDevProvider>>()
            .await?;
        let index = ModelsDevIndex::from_payload(payload);
        let mut cache = self.cache.lock().await;
        *cache = Some(index.clone());
        Ok(index)
    }
}

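/// The catalog payload indexed two ways: per provider id, and globally by
/// model id as a fallback (the first provider to claim an id wins).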
#[derive(Debug, Clone, Default)]
struct ModelsDevIndex {
    by_model_id: HashMap<String, ModelMetadata>,
    by_provider: HashMap<String, HashMap<String, ModelMetadata>>,
}

impl ModelsDevIndex {
    fn from_payload(payload: HashMap<String, ModelsDevProvider>) -> Self {
        let mut index = Self::default();
        for (provider, envelope) in payload {
            let mut provider_models = HashMap::new();
            for (model_id, model) in envelope.models {
                let metadata = ModelMetadata {
                    family: model.family,
                    release_date: model.release_date,
                    last_updated: model.last_updated,
                    input_modalities: model.modalities.input,
                    output_modalities: model.modalities.output,
                    context_window: model.limit.context.map(to_u32),
                    max_output_tokens: model.limit.output.map(to_u32),
                    supports_attachments: model.attachment,
                    supports_reasoning: model.reasoning,
                    supports_tools: model.tool_call,
                    open_weights: model.open_weights,
                };
                provider_models.insert(model_id.clone(), metadata.clone());
                index.by_model_id.entry(model_id).or_insert(metadata);
            }
            index.by_provider.insert(provider, provider_models);
        }
        index
    }
}

#[derive(Debug, Clone, Deserialize, Default)]
struct ModelsDevProvider {
    #[serde(default)]
    models: HashMap<String, ModelsDevModel>,
}

#[derive(Debug, Clone, Deserialize, Default)]
struct ModelsDevModel {
    family: Option<String>,
    attachment: Option<bool>,
    reasoning: Option<bool>,
    tool_call: Option<bool>,
    open_weights: Option<bool>,
    release_date: Option<String>,
    last_updated: Option<String>,
    #[serde(default)]
    modalities: ModelsDevModalities,
    #[serde(default)]
    limit: ModelsDevLimits,
}

#[derive(Debug, Clone, Deserialize, Default)]
struct ModelsDevModalities {
    #[serde(default)]
    input: Vec<String>,
    #[serde(default)]
    output: Vec<String>,
}

#[derive(Debug, Clone, Deserialize, Default)]
struct ModelsDevLimits {
    context: Option<u64>,
    output: Option<u64>,
}

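/// models.dev provider ids that may carry metadata for a given
/// [`ProviderKind`], in lookup order.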
fn provider_aliases(provider: &ProviderKind) -> &'static [&'static str] {
    match provider {
        ProviderKind::Codex => &["codex", "openai"],
        ProviderKind::Copilot => &["copilot", "github"],
        ProviderKind::OpenRouter => &["openrouter"],
        ProviderKind::Zen => &["zen", "opencode", "zenmux"],
        ProviderKind::OpenAi => &["openai"],
        ProviderKind::Azure => &["azure", "azure-cognitive-services"],
        ProviderKind::Nvidia => &["nvidia"],
        ProviderKind::Custom(_) => &[],
    }
}

fn to_u32(value: u64) -> u32 {
    u32::try_from(value).unwrap_or(u32::MAX)
}

#[cfg(test)]
mod tests {
    use anyhow::{Result, bail};
    use gunmetal_core::{
        ChatMessage, ChatRole, NewProviderProfile, ProviderAuthState, RequestOptions,
    };
    use serde_json::json;
    use tempfile::TempDir;
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    use super::*;

    #[tokio::test]
    async fn provider_hub_uses_registered_adapter_and_persists_credentials() {
        let temp = TempDir::new().unwrap();
        let paths = AppPaths::from_root(temp.path().join("gunmetal-home")).unwrap();
        let storage = paths.storage_handle().unwrap();
        let profile = storage
            .create_profile(NewProviderProfile {
                provider: ProviderKind::Custom("mock".to_owned()),
                name: "mock".to_owned(),
                base_url: None,
                enabled: true,
                credentials: None,
            })
            .unwrap();

        let mut registry = ProviderRegistry::default();
        registry.register(MockAdapter);
        let hub = ProviderHub::new(paths.clone(), registry);

        let status = hub.auth_status(&profile).await.unwrap();
        assert_eq!(status.state, ProviderAuthState::Connected);

        let synced = hub.sync_models(&profile).await.unwrap();
        assert_eq!(synced[0].id, "mock/model-1");

        let completion = hub
            .chat_completion(
                &profile,
                &ChatCompletionRequest {
                    model: "mock/model-1".to_owned(),
                    messages: vec![ChatMessage {
                        role: ChatRole::User,
                        content: "ping".to_owned(),
                    }],
                    stream: false,
                    options: RequestOptions::default(),
                },
            )
            .await
            .unwrap();
        assert_eq!(completion.message.content, "hello from mock");

        let updated = storage.get_profile(profile.id).unwrap().unwrap();
        assert_eq!(updated.credentials, Some(json!({ "token": "updated" })));
    }

    #[tokio::test]
    async fn models_dev_enriches_synced_models() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/api.json"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "openai": {
                    "models": {
                        "gpt-5.1": {
                            "family": "gpt",
                            "attachment": true,
                            "reasoning": true,
                            "tool_call": true,
                            "open_weights": false,
                            "release_date": "2025-01-01",
                            "last_updated": "2025-02-01",
                            "modalities": { "input": ["text"], "output": ["text"] },
                            "limit": { "context": 272000, "output": 16384 }
                        }
                    }
                }
            })))
            .mount(&server)
            .await;

        let temp = TempDir::new().unwrap();
        let paths = AppPaths::from_root(temp.path().join("gunmetal-home")).unwrap();
        let storage = paths.storage_handle().unwrap();
        let profile = storage
            .create_profile(NewProviderProfile {
                provider: ProviderKind::Codex,
                name: "codex".to_owned(),
                base_url: None,
                enabled: true,
                credentials: None,
            })
            .unwrap();

        let mut registry = ProviderRegistry::default();
        registry.register(MockCodexAdapter);
        let hub = ProviderHub::with_registry_and_models_dev(
            paths,
            registry,
            ModelsDevCatalog::new(format!("{}/api.json", server.uri())),
        );

        let models = hub.sync_models(&profile).await.unwrap();
        assert_eq!(
            models[0]
                .metadata
                .as_ref()
                .and_then(|value| value.family.as_deref()),
            Some("gpt")
        );
        assert_eq!(
            models[0]
                .metadata
                .as_ref()
                .and_then(|value| value.context_window),
            Some(272_000)
        );
    }

    #[test]
    fn provider_hub_exposes_definition_metadata() {
        let temp = TempDir::new().unwrap();
        let paths = AppPaths::from_root(temp.path().join("gunmetal-home")).unwrap();
        let mut registry = ProviderRegistry::default();
        registry.register(MockAdapter);
        let hub = ProviderHub::new(paths, registry);

        let definition = hub
            .definition(&ProviderKind::Custom("mock".to_owned()))
            .unwrap();
        assert_eq!(definition.label, "mock");
        assert!(definition.requires_api_key());
        assert!(definition.capabilities.supports_responses_api);
    }

    #[tokio::test]
    async fn synthetic_chat_sse_stream_emits_expected_events() {
        let events: ProviderEventStream = stream::iter(vec![
            Ok(ProviderStreamEvent::TextDelta("Hello".to_owned())),
            Ok(ProviderStreamEvent::TextDelta(" world".to_owned())),
            Ok(ProviderStreamEvent::Complete {
                model: "gpt-4".to_owned(),
                finish_reason: "stop".to_owned(),
                usage: TokenUsage {
                    input_tokens: Some(1),
                    output_tokens: Some(2),
                    total_tokens: Some(3),
                },
            }),
        ])
        .boxed();

        let byte_stream = synthetic_chat_sse_stream("gpt-4".to_owned(), events);
        let chunks: Vec<Vec<u8>> = byte_stream
            .collect::<Vec<_>>()
            .await
            .into_iter()
            .collect::<Result<Vec<_>>>()
            .unwrap();
        let output = String::from_utf8(chunks.concat()).unwrap();

        assert!(output.contains("chat.completion.chunk"));
        assert!(output.contains("\"role\":\"assistant\""));
        assert!(output.contains("Hello"));
        assert!(output.contains(" world"));
        assert!(output.contains("[DONE]"));
        assert!(output.contains("\"finish_reason\":\"stop\""));
    }

    #[tokio::test]
    async fn synthetic_completion_stream_emits_text_then_complete() {
        let completion = ChatCompletionResult {
            model: "test-model".to_owned(),
            message: ChatMessage {
                role: ChatRole::Assistant,
                content: "Hello world".to_owned(),
            },
            finish_reason: "stop".to_owned(),
            usage: TokenUsage {
                input_tokens: Some(1),
                output_tokens: Some(1),
                total_tokens: Some(2),
            },
        };

        let stream = synthetic_completion_stream(completion);
        let events: Vec<ProviderStreamEvent> = stream
            .collect::<Vec<_>>()
            .await
            .into_iter()
            .collect::<Result<Vec<_>>>()
            .unwrap();

        assert_eq!(events.len(), 2);
        assert_eq!(
            events[0],
            ProviderStreamEvent::TextDelta("Hello world".to_owned())
        );
        match &events[1] {
            ProviderStreamEvent::Complete {
                model,
                finish_reason,
                usage,
            } => {
                assert_eq!(model, "test-model");
                assert_eq!(finish_reason, "stop");
                assert_eq!(usage.total_tokens, Some(2));
            }
            _ => panic!("expected Complete event"),
        }
    }

    #[tokio::test]
    async fn synthetic_completion_stream_empty_content() {
        let completion = ChatCompletionResult {
            model: "m".to_owned(),
            message: ChatMessage {
                role: ChatRole::Assistant,
                content: "".to_owned(),
            },
            finish_reason: "stop".to_owned(),
            usage: TokenUsage {
                input_tokens: None,
                output_tokens: None,
                total_tokens: None,
            },
        };

        let stream = synthetic_completion_stream(completion);
        let events: Vec<ProviderStreamEvent> = stream
            .collect::<Vec<_>>()
            .await
            .into_iter()
            .collect::<Result<Vec<_>>>()
            .unwrap();

        assert_eq!(events.len(), 2);
        assert_eq!(events[0], ProviderStreamEvent::TextDelta("".to_owned()));
    }

    #[test]
    fn sse_decoder_complete_event() {
        let mut decoder = SseDecoder::default();
        decoder.push(b"data: hello\n\n");
        assert_eq!(decoder.next_event(), Some("hello".to_owned()));
        assert_eq!(decoder.next_event(), None);
    }

    #[test]
    fn sse_decoder_multiple_events() {
        let mut decoder = SseDecoder::default();
        decoder.push(b"data: first\n\ndata: second\n\n");
        assert_eq!(decoder.next_event(), Some("first".to_owned()));
        assert_eq!(decoder.next_event(), Some("second".to_owned()));
        assert_eq!(decoder.next_event(), None);
    }

    #[test]
    fn sse_decoder_partial_chunks() {
        let mut decoder = SseDecoder::default();
        decoder.push(b"data: hel");
        assert_eq!(decoder.next_event(), None);
        decoder.push(b"lo\n\n");
        assert_eq!(decoder.next_event(), Some("hello".to_owned()));
    }

    #[test]
    fn sse_decoder_malformed_no_data_prefix() {
        let mut decoder = SseDecoder::default();
        decoder.push(b"event: message\n\n");
        assert_eq!(decoder.next_event(), None);
    }

    #[test]
    fn sse_decoder_empty_chunk() {
        let mut decoder = SseDecoder::default();
        decoder.push(b"");
        assert_eq!(decoder.next_event(), None);
    }

    #[test]
    fn sse_decoder_multiline_data() {
        let mut decoder = SseDecoder::default();
        decoder.push(b"data: line1\ndata: line2\n\n");
        assert_eq!(decoder.next_event(), Some("line1\nline2".to_owned()));
    }

    #[test]
    fn sse_decoder_carriage_return() {
        let mut decoder = SseDecoder::default();
        decoder.push(b"data: hello\r\n\r\n");
        assert_eq!(decoder.next_event(), Some("hello".to_owned()));
    }
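
    // Illustrative additions: small sanity tests for the pure helpers above.
    // The assertions follow directly from the definitions of text_chunks,
    // to_token_usage, and to_u32 in this file.
    #[test]
    fn text_chunks_splits_on_24_char_boundaries() {
        let chunks = text_chunks(&"a".repeat(50));
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].len(), 24);
        assert_eq!(chunks[1].len(), 24);
        assert_eq!(chunks[2].len(), 2);
        assert_eq!(text_chunks(""), vec![String::new()]);
    }

    #[test]
    fn to_token_usage_derives_missing_total() {
        let usage = to_token_usage(OpenAiCompatibleUsage {
            prompt_tokens: Some(3),
            completion_tokens: Some(4),
            total_tokens: None,
        });
        assert_eq!(usage.input_tokens, Some(3));
        assert_eq!(usage.output_tokens, Some(4));
        assert_eq!(usage.total_tokens, Some(7));
    }

    #[test]
    fn to_u32_saturates_at_u32_max() {
        assert_eq!(to_u32(42), 42);
        assert_eq!(to_u32(u64::from(u32::MAX) + 1), u32::MAX);
    }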

    #[tokio::test]
    async fn openai_compatible_event_stream_parses_text_and_complete() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/stream"))
            .respond_with(ResponseTemplate::new(200).set_body_string(
                "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n\
                 data: {\"choices\":[{\"delta\":{\"content\":\" world\"},\"finish_reason\":null}]}\n\n\
                 data: {\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":1,\"completion_tokens\":2,\"total_tokens\":3}}\n\n\
                 data: [DONE]\n\n",
            ))
            .mount(&server)
            .await;

        let client = reqwest::Client::new();
        let response = client
            .get(format!("{}/stream", server.uri()))
            .send()
            .await
            .unwrap();

        let stream =
            openai_compatible_event_stream(response, "fallback-model".to_owned(), |s| s.to_owned());
        let events: Vec<ProviderStreamEvent> = stream
            .collect::<Vec<_>>()
            .await
            .into_iter()
            .collect::<Result<Vec<_>>>()
            .unwrap();

        assert_eq!(events.len(), 3);
        assert_eq!(events[0], ProviderStreamEvent::TextDelta("Hello".to_owned()));
        assert_eq!(events[1], ProviderStreamEvent::TextDelta(" world".to_owned()));
        match &events[2] {
            ProviderStreamEvent::Complete {
                model,
                finish_reason,
                usage,
            } => {
                assert_eq!(model, "fallback-model");
                assert_eq!(finish_reason, "stop");
                assert_eq!(usage.total_tokens, Some(3));
            }
            _ => panic!("expected Complete"),
        }
    }

    #[tokio::test]
    async fn models_dev_http_failure_returns_error() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/api.json"))
            .respond_with(ResponseTemplate::new(500))
            .mount(&server)
            .await;

        let catalog = ModelsDevCatalog::new(format!("{}/api.json", server.uri()));
        let mut models = vec![ModelDescriptor {
            id: "test".to_owned(),
            provider: ProviderKind::OpenAi,
            profile_id: None,
            upstream_name: "gpt-4".to_owned(),
            display_name: "GPT-4".to_owned(),
            metadata: None,
        }];

        let result = catalog.enrich(&mut models).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn models_dev_cache_reuses_index() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/api.json"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "openai": {
                    "models": {
                        "gpt-4": {
                            "family": "gpt",
                            "modalities": { "input": ["text"], "output": ["text"] },
                            "limit": { "context": 8192, "output": 4096 }
                        }
                    }
                }
            })))
            .expect(1)
            .mount(&server)
            .await;

        let catalog = ModelsDevCatalog::new(format!("{}/api.json", server.uri()));

        let mut models = vec![ModelDescriptor {
            id: "openai/gpt-4".to_owned(),
            provider: ProviderKind::OpenAi,
            profile_id: None,
            upstream_name: "gpt-4".to_owned(),
            display_name: "GPT-4".to_owned(),
            metadata: None,
        }];

        catalog.enrich(&mut models).await.unwrap();
        assert_eq!(
            models[0].metadata.as_ref().unwrap().family,
            Some("gpt".to_owned())
        );

        let mut models2 = vec![ModelDescriptor {
            id: "openai/gpt-4".to_owned(),
            provider: ProviderKind::OpenAi,
            profile_id: None,
            upstream_name: "gpt-4".to_owned(),
            display_name: "GPT-4".to_owned(),
            metadata: None,
        }];
        catalog.enrich(&mut models2).await.unwrap();
        assert_eq!(
            models2[0].metadata.as_ref().unwrap().family,
            Some("gpt".to_owned())
        );
    }

    #[derive(Default)]
    struct MockAdapter;

    #[async_trait]
    impl ProviderAdapter for MockAdapter {
        fn definition(&self) -> ProviderDefinition {
            ProviderDefinition {
                kind: ProviderKind::Custom("mock".to_owned()),
                label: "mock",
                class: ProviderClass::Direct,
                priority: 99,
                capabilities: ProviderCapabilities {
                    auth_method: ProviderAuthMethod::ApiKey,
                    supports_base_url: true,
                    supports_model_sync: true,
                    supports_chat_completions: true,
                    supports_responses_api: true,
                    supports_streaming: true,
                },
                ux: ProviderUxHints {
                    helper_title: "Direct provider",
                    helper_body: "Save the upstream API key here.",
                    suggested_name: "mock",
                    base_url_placeholder: "optional override",
                },
            }
        }

        async fn auth_status(
            &self,
            _profile: &ProviderProfile,
            _context: &dyn ProviderContext,
        ) -> Result<ProviderAuthResult> {
            Ok(ProviderAuthResult {
                credentials: Some(json!({ "token": "updated" })),
                status: ProviderAuthStatus {
                    state: ProviderAuthState::Connected,
                    label: "mock".to_owned(),
                },
            })
        }

        async fn login(
            &self,
            _profile: &ProviderProfile,
            _context: &dyn ProviderContext,
            _open_browser: bool,
        ) -> Result<ProviderLoginResult> {
            bail!("not implemented")
        }

        async fn logout(
            &self,
            _profile: &ProviderProfile,
            _context: &dyn ProviderContext,
        ) -> Result<Option<Value>> {
            Ok(None)
        }

        async fn sync_models(
            &self,
            profile: &ProviderProfile,
            _context: &dyn ProviderContext,
        ) -> Result<ProviderModelSyncResult> {
            Ok(ProviderModelSyncResult {
                credentials: Some(json!({ "token": "updated" })),
                models: vec![ModelDescriptor {
                    id: "mock/model-1".to_owned(),
                    provider: profile.provider.clone(),
                    profile_id: Some(profile.id),
                    upstream_name: "model-1".to_owned(),
                    display_name: "Model 1".to_owned(),
                    metadata: None,
                }],
            })
        }

        async fn chat_completion(
            &self,
            _profile: &ProviderProfile,
            _context: &dyn ProviderContext,
            request: &ChatCompletionRequest,
        ) -> Result<ProviderChatResult> {
            Ok(ProviderChatResult {
                credentials: Some(json!({ "token": "updated" })),
                completion: ChatCompletionResult {
                    model: request.model.clone(),
                    message: ChatMessage {
                        role: ChatRole::Assistant,
                        content: "hello from mock".to_owned(),
                    },
                    finish_reason: "stop".to_owned(),
                    usage: gunmetal_core::TokenUsage {
                        input_tokens: Some(1),
                        output_tokens: Some(1),
                        total_tokens: Some(2),
                    },
                },
            })
        }
    }

    struct MockCodexAdapter;

    #[async_trait]
    impl ProviderAdapter for MockCodexAdapter {
        fn definition(&self) -> ProviderDefinition {
            ProviderDefinition {
                kind: ProviderKind::Codex,
                label: "codex",
                class: ProviderClass::Subscription,
                priority: 1,
                capabilities: ProviderCapabilities {
                    auth_method: ProviderAuthMethod::BrowserSession,
                    supports_base_url: false,
                    supports_model_sync: true,
                    supports_chat_completions: true,
                    supports_responses_api: true,
                    supports_streaming: true,
                },
                ux: ProviderUxHints {
                    helper_title: "Browser sign-in provider",
                    helper_body: "Save the provider, then finish auth in the browser.",
                    suggested_name: "codex",
                    base_url_placeholder: "not used for this provider",
                },
            }
        }

        async fn auth_status(
            &self,
            _profile: &ProviderProfile,
            _context: &dyn ProviderContext,
        ) -> Result<ProviderAuthResult> {
            Ok(ProviderAuthResult {
                credentials: None,
                status: ProviderAuthStatus {
                    state: ProviderAuthState::Connected,
                    label: "codex".to_owned(),
                },
            })
        }

        async fn login(
            &self,
            _profile: &ProviderProfile,
            _context: &dyn ProviderContext,
            _open_browser: bool,
        ) -> Result<ProviderLoginResult> {
            bail!("not implemented")
        }

        async fn logout(
            &self,
            _profile: &ProviderProfile,
            _context: &dyn ProviderContext,
        ) -> Result<Option<Value>> {
            Ok(None)
        }

        async fn sync_models(
            &self,
            profile: &ProviderProfile,
            _context: &dyn ProviderContext,
        ) -> Result<ProviderModelSyncResult> {
            Ok(ProviderModelSyncResult {
                credentials: None,
                models: vec![ModelDescriptor {
                    id: "codex/gpt-5.1".to_owned(),
                    provider: ProviderKind::Codex,
                    profile_id: Some(profile.id),
                    upstream_name: "gpt-5.1".to_owned(),
                    display_name: "GPT-5.1".to_owned(),
                    metadata: None,
                }],
            })
        }

        async fn chat_completion(
            &self,
            _profile: &ProviderProfile,
            _context: &dyn ProviderContext,
            request: &ChatCompletionRequest,
        ) -> Result<ProviderChatResult> {
            Ok(ProviderChatResult {
                credentials: None,
                completion: ChatCompletionResult {
                    model: request.model.clone(),
                    message: ChatMessage {
                        role: ChatRole::Assistant,
                        content: "hello".to_owned(),
                    },
                    finish_reason: "stop".to_owned(),
                    usage: gunmetal_core::TokenUsage {
                        input_tokens: Some(1),
                        output_tokens: Some(1),
                        total_tokens: Some(2),
                    },
                },
            })
        }
    }
}