Skip to main content

gunmetal_sdk/
lib.rs

1use std::{collections::HashMap, sync::Arc};
2
3use anyhow::{Result, anyhow};
4use async_trait::async_trait;
5use futures_util::{
6    StreamExt,
7    stream::{self, BoxStream},
8};
9use gunmetal_core::{
10    ChatCompletionRequest, ChatCompletionResult, ModelDescriptor, ModelMetadata,
11    ProviderAuthStatus, ProviderContext, ProviderKind, ProviderLoginSession, ProviderProfile,
12    TokenUsage,
13};
14use gunmetal_storage::AppPaths;
15use reqwest::{Client, Response};
16use serde::Deserialize;
17use serde_json::Value;
18use tokio::sync::Mutex;
19
/// Broad category a provider is classified under in its [`ProviderDefinition`].
///
/// NOTE(review): the per-variant descriptions below are inferred from the
/// variant names; confirm against the adapters that report them.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProviderClass {
    /// Presumably a provider reached through an end-user subscription.
    Subscription,
    /// Presumably an aggregating gateway in front of several upstreams.
    Gateway,
    /// Presumably a provider's own first-party API.
    Direct,
}
26
/// How a provider authenticates a profile.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProviderAuthMethod {
    /// Authentication happens through an interactive browser login.
    BrowserSession,
    /// Authentication uses an API key supplied by the user.
    ApiKey,
}
32
33#[derive(Debug, Clone, PartialEq, Eq)]
34pub struct ProviderCapabilities {
35    pub auth_method: ProviderAuthMethod,
36    pub supports_base_url: bool,
37    pub supports_model_sync: bool,
38    pub supports_chat_completions: bool,
39    pub supports_responses_api: bool,
40    pub supports_streaming: bool,
41}
42
43impl ProviderCapabilities {
44    pub fn supports_browser_login(&self) -> bool {
45        matches!(self.auth_method, ProviderAuthMethod::BrowserSession)
46    }
47
48    pub fn requires_api_key(&self) -> bool {
49        matches!(self.auth_method, ProviderAuthMethod::ApiKey)
50    }
51}
52
/// Static copy shown to users when configuring a provider profile.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ProviderUxHints {
    /// Heading of the setup helper.
    pub helper_title: &'static str,
    /// Explanatory body text of the setup helper.
    pub helper_body: &'static str,
    /// Default profile name to pre-fill.
    pub suggested_name: &'static str,
    /// Placeholder text for the base-URL input.
    pub base_url_placeholder: &'static str,
}
60
61#[derive(Debug, Clone, PartialEq, Eq)]
62pub struct ProviderDefinition {
63    pub kind: ProviderKind,
64    pub label: &'static str,
65    pub class: ProviderClass,
66    pub priority: usize,
67    pub capabilities: ProviderCapabilities,
68    pub ux: ProviderUxHints,
69}
70
71impl ProviderDefinition {
72    pub fn supports_browser_login(&self) -> bool {
73        self.capabilities.supports_browser_login()
74    }
75
76    pub fn requires_api_key(&self) -> bool {
77        self.capabilities.requires_api_key()
78    }
79}
80
81#[derive(Debug, Clone)]
82pub struct ProviderAuthResult {
83    pub credentials: Option<Value>,
84    pub status: ProviderAuthStatus,
85}
86
87#[derive(Debug, Clone)]
88pub struct ProviderLoginResult {
89    pub credentials: Option<Value>,
90    pub session: ProviderLoginSession,
91}
92
93#[derive(Debug, Clone)]
94pub struct ProviderModelSyncResult {
95    pub credentials: Option<Value>,
96    pub models: Vec<ModelDescriptor>,
97}
98
99#[derive(Debug, Clone)]
100pub struct ProviderChatResult {
101    pub completion: ChatCompletionResult,
102    pub credentials: Option<Value>,
103}
104
105#[derive(Debug, Clone, PartialEq, Eq)]
106pub enum ProviderStreamEvent {
107    TextDelta(String),
108    Complete {
109        model: String,
110        finish_reason: String,
111        usage: TokenUsage,
112    },
113}
114
115pub type ProviderEventStream = BoxStream<'static, Result<ProviderStreamEvent>>;
116pub type ProviderByteStream = BoxStream<'static, Result<Vec<u8>>>;
117
118pub struct ProviderStreamResult {
119    pub stream: ProviderEventStream,
120    pub credentials: Option<Value>,
121}
122
123pub struct ProviderRawSseResult {
124    pub stream: ProviderByteStream,
125    pub credentials: Option<Value>,
126}
127
128#[async_trait]
129pub trait ProviderAdapter: Send + Sync {
130    fn definition(&self) -> ProviderDefinition;
131
132    async fn auth_status(
133        &self,
134        profile: &ProviderProfile,
135        context: &dyn ProviderContext,
136    ) -> Result<ProviderAuthResult>;
137
138    async fn login(
139        &self,
140        profile: &ProviderProfile,
141        context: &dyn ProviderContext,
142        open_browser: bool,
143    ) -> Result<ProviderLoginResult>;
144
145    async fn logout(
146        &self,
147        profile: &ProviderProfile,
148        context: &dyn ProviderContext,
149    ) -> Result<Option<Value>>;
150
151    async fn sync_models(
152        &self,
153        profile: &ProviderProfile,
154        context: &dyn ProviderContext,
155    ) -> Result<ProviderModelSyncResult>;
156
157    async fn chat_completion(
158        &self,
159        profile: &ProviderProfile,
160        context: &dyn ProviderContext,
161        request: &ChatCompletionRequest,
162    ) -> Result<ProviderChatResult>;
163
164    async fn stream_chat_completion(
165        &self,
166        profile: &ProviderProfile,
167        context: &dyn ProviderContext,
168        request: &ChatCompletionRequest,
169    ) -> Result<ProviderStreamResult> {
170        let result = self.chat_completion(profile, context, request).await?;
171        Ok(ProviderStreamResult {
172            credentials: result.credentials,
173            stream: synthetic_completion_stream(result.completion),
174        })
175    }
176
177    async fn raw_stream_chat_completion(
178        &self,
179        profile: &ProviderProfile,
180        context: &dyn ProviderContext,
181        request: &ChatCompletionRequest,
182    ) -> Result<ProviderRawSseResult> {
183        let result = self
184            .stream_chat_completion(profile, context, request)
185            .await?;
186        Ok(ProviderRawSseResult {
187            credentials: result.credentials,
188            stream: synthetic_chat_sse_stream(request.model.clone(), result.stream),
189        })
190    }
191}
192
193#[derive(Clone, Default)]
194pub struct ProviderRegistry {
195    adapters: HashMap<ProviderKind, Arc<dyn ProviderAdapter>>,
196}
197
198impl ProviderRegistry {
199    pub fn register<A>(&mut self, adapter: A)
200    where
201        A: ProviderAdapter + 'static,
202    {
203        let adapter = Arc::new(adapter);
204        self.adapters
205            .insert(adapter.definition().kind.clone(), adapter);
206    }
207
208    pub fn get(&self, kind: &ProviderKind) -> Option<Arc<dyn ProviderAdapter>> {
209        self.adapters.get(kind).cloned()
210    }
211
212    pub fn definition(&self, kind: &ProviderKind) -> Option<ProviderDefinition> {
213        self.adapters.get(kind).map(|adapter| adapter.definition())
214    }
215
216    pub fn definitions(&self) -> Vec<ProviderDefinition> {
217        let mut definitions = self
218            .adapters
219            .values()
220            .map(|adapter| adapter.definition())
221            .collect::<Vec<_>>();
222        definitions.sort_by_key(|item| item.priority);
223        definitions
224    }
225}
226
227#[derive(Clone)]
228pub struct ProviderHub {
229    paths: AppPaths,
230    registry: ProviderRegistry,
231    models_dev: ModelsDevCatalog,
232}
233
234impl ProviderHub {
235    pub fn new(paths: AppPaths, registry: ProviderRegistry) -> Self {
236        Self {
237            paths,
238            registry,
239            models_dev: ModelsDevCatalog::default(),
240        }
241    }
242
243    pub fn with_registry(paths: AppPaths, registry: ProviderRegistry) -> Self {
244        Self::new(paths, registry)
245    }
246
247    pub fn with_registry_and_models_dev(
248        paths: AppPaths,
249        registry: ProviderRegistry,
250        models_dev: ModelsDevCatalog,
251    ) -> Self {
252        Self {
253            paths,
254            registry,
255            models_dev,
256        }
257    }
258
259    pub async fn auth_status(&self, profile: &ProviderProfile) -> Result<ProviderAuthStatus> {
260        let adapter = self.adapter(&profile.provider)?;
261        let result = adapter.auth_status(profile, &self.paths).await?;
262        self.persist_credentials(profile.id, result.credentials)?;
263        Ok(result.status)
264    }
265
266    pub async fn login(
267        &self,
268        profile: &ProviderProfile,
269        open_browser: bool,
270    ) -> Result<ProviderLoginSession> {
271        let adapter = self.adapter(&profile.provider)?;
272        let result = adapter.login(profile, &self.paths, open_browser).await?;
273        self.persist_credentials(profile.id, result.credentials)?;
274        Ok(result.session)
275    }
276
277    pub async fn logout(&self, profile: &ProviderProfile) -> Result<()> {
278        let adapter = self.adapter(&profile.provider)?;
279        let credentials = adapter.logout(profile, &self.paths).await?;
280        self.persist_credentials(profile.id, credentials)
281    }
282
283    pub async fn sync_models(&self, profile: &ProviderProfile) -> Result<Vec<ModelDescriptor>> {
284        let adapter = self.adapter(&profile.provider)?;
285        let mut result = adapter.sync_models(profile, &self.paths).await?;
286        self.persist_credentials(profile.id, result.credentials)?;
287        if let Err(error) = self.models_dev.enrich(&mut result.models).await {
288            let _ = error;
289        }
290        Ok(result.models)
291    }
292
293    pub async fn chat_completion(
294        &self,
295        profile: &ProviderProfile,
296        request: &ChatCompletionRequest,
297    ) -> Result<ChatCompletionResult> {
298        let adapter = self.adapter(&profile.provider)?;
299        let result = adapter
300            .chat_completion(profile, &self.paths, request)
301            .await?;
302        self.persist_credentials(profile.id, result.credentials)?;
303        Ok(result.completion)
304    }
305
306    pub async fn stream_chat_completion(
307        &self,
308        profile: &ProviderProfile,
309        request: &ChatCompletionRequest,
310    ) -> Result<ProviderEventStream> {
311        let adapter = self.adapter(&profile.provider)?;
312        let result = adapter
313            .stream_chat_completion(profile, &self.paths, request)
314            .await?;
315        self.persist_credentials(profile.id, result.credentials)?;
316        Ok(result.stream)
317    }
318
319    pub async fn raw_stream_chat_completion(
320        &self,
321        profile: &ProviderProfile,
322        request: &ChatCompletionRequest,
323    ) -> Result<ProviderByteStream> {
324        let adapter = self.adapter(&profile.provider)?;
325        let result = adapter
326            .raw_stream_chat_completion(profile, &self.paths, request)
327            .await?;
328        self.persist_credentials(profile.id, result.credentials)?;
329        Ok(result.stream)
330    }
331
332    pub fn definitions(&self) -> Vec<ProviderDefinition> {
333        self.registry.definitions()
334    }
335
336    pub fn definition(&self, kind: &ProviderKind) -> Option<ProviderDefinition> {
337        self.registry.definition(kind)
338    }
339
340    fn adapter(&self, kind: &ProviderKind) -> Result<Arc<dyn ProviderAdapter>> {
341        self.registry
342            .get(kind)
343            .ok_or_else(|| anyhow!("provider '{}' not implemented yet", kind))
344    }
345
346    fn persist_credentials(
347        &self,
348        profile_id: uuid::Uuid,
349        credentials: Option<serde_json::Value>,
350    ) -> Result<()> {
351        let Some(credentials) = credentials else {
352            return Ok(());
353        };
354        self.paths
355            .storage_handle()?
356            .update_profile_credentials(profile_id, Some(credentials))
357    }
358}
359
360fn synthetic_completion_stream(completion: ChatCompletionResult) -> ProviderEventStream {
361    let mut events = text_chunks(&completion.message.content)
362        .into_iter()
363        .map(ProviderStreamEvent::TextDelta)
364        .collect::<Vec<_>>();
365    events.push(ProviderStreamEvent::Complete {
366        model: completion.model,
367        finish_reason: completion.finish_reason,
368        usage: completion.usage,
369    });
370    stream::iter(events.into_iter().map(Ok)).boxed()
371}
372
373pub fn synthetic_chat_sse_stream(model: String, stream: ProviderEventStream) -> ProviderByteStream {
374    let id = format!("chatcmpl-{}", uuid::Uuid::new_v4().simple());
375    let created = chrono::Utc::now().timestamp();
376    let first = stream::once(async move {
377        Ok::<Vec<u8>, anyhow::Error>(
378            format!(
379                "data: {}\n\n",
380                serde_json::json!({
381                    "id": id,
382                    "object": "chat.completion.chunk",
383                    "created": created,
384                    "model": model,
385                    "choices": [{
386                        "index": 0,
387                        "delta": { "role": "assistant" },
388                        "finish_reason": Value::Null
389                    }]
390                })
391            )
392            .into_bytes(),
393        )
394    });
395
396    let content = stream.map(move |item| match item {
397        Ok(ProviderStreamEvent::TextDelta(chunk)) => Ok(format!(
398            "data: {}\n\n",
399            serde_json::json!({
400                "id": format!("chatcmpl-{}", uuid::Uuid::new_v4().simple()),
401                "object": "chat.completion.chunk",
402                "created": chrono::Utc::now().timestamp(),
403                "choices": [{
404                    "index": 0,
405                    "delta": { "content": chunk },
406                    "finish_reason": Value::Null
407                }]
408            })
409        )
410        .into_bytes()),
411        Ok(ProviderStreamEvent::Complete {
412            model,
413            finish_reason,
414            usage,
415        }) => Ok(format!(
416            "data: {}\n\n",
417            serde_json::json!({
418                "id": format!("chatcmpl-{}", uuid::Uuid::new_v4().simple()),
419                "object": "chat.completion.chunk",
420                "created": chrono::Utc::now().timestamp(),
421                "model": model,
422                "choices": [{
423                    "index": 0,
424                    "delta": {},
425                    "finish_reason": finish_reason
426                }],
427                "usage": usage
428            })
429        )
430        .into_bytes()),
431        Err(error) => Ok(format!(
432            "event: error\ndata: {}\n\n",
433            serde_json::json!({ "error": { "message": error.to_string() } })
434        )
435        .into_bytes()),
436    });
437
438    let done = stream::once(async { Ok::<Vec<u8>, anyhow::Error>(b"data: [DONE]\n\n".to_vec()) });
439    first.chain(content).chain(done).boxed()
440}
441
442pub fn openai_compatible_event_stream<F>(
443    response: Response,
444    fallback_model: String,
445    normalize_model: F,
446) -> ProviderEventStream
447where
448    F: Fn(&str) -> String + Send + Sync + 'static,
449{
450    let normalize_model = Arc::new(normalize_model);
451    async_stream::try_stream! {
452        let mut upstream = response.bytes_stream();
453        let mut decoder = SseDecoder::default();
454        let mut current_model = fallback_model;
455
456        while let Some(chunk) = upstream.next().await {
457            let chunk = chunk?;
458            decoder.push(&chunk);
459
460            while let Some(event) = decoder.next_event() {
461                if event == "[DONE]" {
462                    continue;
463                }
464
465                for parsed in parse_openai_compatible_event(
466                    &event,
467                    &mut current_model,
468                    normalize_model.as_ref(),
469                )? {
470                    yield parsed;
471                }
472            }
473        }
474    }
475    .boxed()
476}
477
478fn parse_openai_compatible_event(
479    event: &str,
480    current_model: &mut String,
481    normalize_model: &dyn Fn(&str) -> String,
482) -> Result<Vec<ProviderStreamEvent>> {
483    let payload = serde_json::from_str::<OpenAiCompatibleStreamChunk>(event)?;
484    if let Some(model) = payload.model.as_deref() {
485        *current_model = normalize_model(model);
486    }
487
488    let mut events = Vec::new();
489    let usage = payload.usage.map(to_token_usage);
490    for choice in payload.choices {
491        if let Some(delta) = choice.delta.and_then(|delta| delta.content)
492            && !delta.is_empty()
493        {
494            events.push(ProviderStreamEvent::TextDelta(delta));
495        }
496
497        if let Some(finish_reason) = choice.finish_reason {
498            events.push(ProviderStreamEvent::Complete {
499                model: current_model.clone(),
500                finish_reason,
501                usage: usage.clone().unwrap_or(TokenUsage {
502                    input_tokens: None,
503                    output_tokens: None,
504                    total_tokens: None,
505                }),
506            });
507        }
508    }
509
510    Ok(events)
511}
512
513fn to_token_usage(usage: OpenAiCompatibleUsage) -> TokenUsage {
514    let input_tokens = usage.prompt_tokens.map(to_u32);
515    let output_tokens = usage.completion_tokens.map(to_u32);
516    let total_tokens =
517        usage
518            .total_tokens
519            .map(to_u32)
520            .or_else(|| match (input_tokens, output_tokens) {
521                (Some(input), Some(output)) => Some(input.saturating_add(output)),
522                _ => None,
523            });
524
525    TokenUsage {
526        input_tokens,
527        output_tokens,
528        total_tokens,
529    }
530}
531
/// Incremental decoder for a `text/event-stream` body that yields the `data`
/// payload of each complete event.
///
/// Raw bytes are buffered undecoded until a full line is available. This
/// keeps UTF-8 sequences and `\r\n` terminators that straddle network chunks
/// intact. Accepted line terminators are `\n`, `\r\n` and a bare `\r`. Only
/// `data:` fields are interpreted; other fields and `:` comment lines are
/// skipped.
#[derive(Default)]
struct SseDecoder {
    /// Raw bytes received but not yet split into lines.
    buffer: Vec<u8>,
    /// `data:` values collected for the event currently being assembled.
    data_lines: Vec<String>,
    /// Set when the previous line ended in a `\r` at the very end of the
    /// buffer: a `\n` arriving first in the next chunk completes that `\r\n`.
    skip_leading_lf: bool,
}

impl SseDecoder {
    /// Appends a chunk of raw response bytes.
    fn push(&mut self, chunk: &[u8]) {
        self.buffer.extend_from_slice(chunk);
    }

    /// Returns the next complete event's data, with multiple `data:` lines
    /// joined by `\n`. Returns `None` once no further complete event is
    /// buffered.
    ///
    /// Events without data (e.g. keep-alive comment frames) are skipped
    /// instead of ending iteration, so events queued behind them are still
    /// delivered.
    fn next_event(&mut self) -> Option<String> {
        while let Some(line) = self.next_line() {
            if line.is_empty() {
                // A blank line dispatches the event assembled so far.
                let data = std::mem::take(&mut self.data_lines).join("\n");
                if !data.is_empty() {
                    return Some(data);
                }
            } else if let Some(value) = line.strip_prefix("data:") {
                self.data_lines.push(value.trim_start().to_owned());
            }
        }
        None
    }

    /// Pops the next complete line (without its terminator) off the buffer.
    fn next_line(&mut self) -> Option<String> {
        if self.skip_leading_lf && !self.buffer.is_empty() {
            self.skip_leading_lf = false;
            if self.buffer[0] == b'\n' {
                self.buffer.remove(0);
            }
        }

        let end = self
            .buffer
            .iter()
            .position(|&byte| byte == b'\n' || byte == b'\r')?;
        let mut consumed = end + 1;
        if self.buffer[end] == b'\r' {
            match self.buffer.get(end + 1) {
                Some(b'\n') => consumed += 1,
                Some(_) => {}
                // Can't tell yet whether a `\n` follows; drop it if it does.
                None => self.skip_leading_lf = true,
            }
        }

        let line = String::from_utf8_lossy(&self.buffer[..end]).into_owned();
        self.buffer.drain(..consumed);
        Some(line)
    }
}
558
559#[derive(Debug, Deserialize)]
560struct OpenAiCompatibleStreamChunk {
561    #[serde(default)]
562    model: Option<String>,
563    #[serde(default)]
564    choices: Vec<OpenAiCompatibleStreamChoice>,
565    #[serde(default)]
566    usage: Option<OpenAiCompatibleUsage>,
567}
568
569#[derive(Debug, Deserialize)]
570struct OpenAiCompatibleStreamChoice {
571    #[serde(default)]
572    delta: Option<OpenAiCompatibleStreamDelta>,
573    #[serde(default)]
574    finish_reason: Option<String>,
575}
576
577#[derive(Debug, Deserialize)]
578struct OpenAiCompatibleStreamDelta {
579    #[serde(default)]
580    content: Option<String>,
581}
582
583#[derive(Debug, Clone, Deserialize)]
584struct OpenAiCompatibleUsage {
585    #[serde(default)]
586    prompt_tokens: Option<u64>,
587    #[serde(default)]
588    completion_tokens: Option<u64>,
589    #[serde(default)]
590    total_tokens: Option<u64>,
591}
592
/// Splits `value` into pieces of at most 24 characters (not bytes), so
/// multi-byte characters are never cut.
///
/// An empty input yields a single empty piece, so callers always emit at
/// least one text delta.
fn text_chunks(value: &str) -> Vec<String> {
    const CHARS_PER_CHUNK: usize = 24;

    if value.is_empty() {
        return vec![String::new()];
    }

    let chars: Vec<char> = value.chars().collect();
    chars
        .chunks(CHARS_PER_CHUNK)
        .map(|piece| piece.iter().collect::<String>())
        .collect()
}
616
617#[derive(Clone)]
618pub struct ModelsDevCatalog {
619    catalog_url: String,
620    http: Client,
621    cache: Arc<Mutex<Option<ModelsDevIndex>>>,
622}
623
624impl Default for ModelsDevCatalog {
625    fn default() -> Self {
626        Self::new("https://models.dev/api.json")
627    }
628}
629
630impl ModelsDevCatalog {
631    pub fn new(catalog_url: impl Into<String>) -> Self {
632        Self {
633            catalog_url: catalog_url.into(),
634            http: Client::builder()
635                .connect_timeout(std::time::Duration::from_secs(2))
636                .timeout(std::time::Duration::from_secs(4))
637                .build()
638                .expect("reqwest client"),
639            cache: Arc::new(Mutex::new(None)),
640        }
641    }
642
643    async fn enrich(&self, models: &mut [ModelDescriptor]) -> Result<()> {
644        let index = self.index().await?;
645        for model in models {
646            if model.metadata.is_some() {
647                continue;
648            }
649
650            let aliases = provider_aliases(&model.provider);
651            let metadata = aliases
652                .iter()
653                .find_map(|alias| index.by_provider.get(*alias))
654                .and_then(|models| models.get(&model.upstream_name).cloned())
655                .or_else(|| index.by_model_id.get(&model.upstream_name).cloned());
656            model.metadata = metadata;
657        }
658        Ok(())
659    }
660
661    async fn index(&self) -> Result<ModelsDevIndex> {
662        {
663            let cache = self.cache.lock().await;
664            if let Some(index) = cache.as_ref() {
665                return Ok(index.clone());
666            }
667        }
668
669        let payload = self
670            .http
671            .get(&self.catalog_url)
672            .send()
673            .await?
674            .error_for_status()?
675            .json::<HashMap<String, ModelsDevProvider>>()
676            .await?;
677        let index = ModelsDevIndex::from_payload(payload);
678        let mut cache = self.cache.lock().await;
679        *cache = Some(index.clone());
680        Ok(index)
681    }
682}
683
684#[derive(Debug, Clone, Default)]
685struct ModelsDevIndex {
686    by_model_id: HashMap<String, ModelMetadata>,
687    by_provider: HashMap<String, HashMap<String, ModelMetadata>>,
688}
689
690impl ModelsDevIndex {
691    fn from_payload(payload: HashMap<String, ModelsDevProvider>) -> Self {
692        let mut index = Self::default();
693        for (provider, envelope) in payload {
694            let mut provider_models = HashMap::new();
695            for (model_id, model) in envelope.models {
696                let metadata = ModelMetadata {
697                    family: model.family,
698                    release_date: model.release_date,
699                    last_updated: model.last_updated,
700                    input_modalities: model.modalities.input,
701                    output_modalities: model.modalities.output,
702                    context_window: model.limit.context.map(to_u32),
703                    max_output_tokens: model.limit.output.map(to_u32),
704                    supports_attachments: model.attachment,
705                    supports_reasoning: model.reasoning,
706                    supports_tools: model.tool_call,
707                    open_weights: model.open_weights,
708                };
709                provider_models.insert(model_id.clone(), metadata.clone());
710                index.by_model_id.entry(model_id).or_insert(metadata);
711            }
712            index.by_provider.insert(provider, provider_models);
713        }
714        index
715    }
716}
717
718#[derive(Debug, Clone, Deserialize, Default)]
719struct ModelsDevProvider {
720    #[serde(default)]
721    models: HashMap<String, ModelsDevModel>,
722}
723
724#[derive(Debug, Clone, Deserialize, Default)]
725struct ModelsDevModel {
726    family: Option<String>,
727    attachment: Option<bool>,
728    reasoning: Option<bool>,
729    tool_call: Option<bool>,
730    open_weights: Option<bool>,
731    release_date: Option<String>,
732    last_updated: Option<String>,
733    #[serde(default)]
734    modalities: ModelsDevModalities,
735    #[serde(default)]
736    limit: ModelsDevLimits,
737}
738
739#[derive(Debug, Clone, Deserialize, Default)]
740struct ModelsDevModalities {
741    #[serde(default)]
742    input: Vec<String>,
743    #[serde(default)]
744    output: Vec<String>,
745}
746
747#[derive(Debug, Clone, Deserialize, Default)]
748struct ModelsDevLimits {
749    context: Option<u64>,
750    output: Option<u64>,
751}
752
753fn provider_aliases(provider: &ProviderKind) -> &'static [&'static str] {
754    match provider {
755        ProviderKind::Codex => &["codex", "openai"],
756        ProviderKind::Copilot => &["copilot", "github"],
757        ProviderKind::OpenRouter => &["openrouter"],
758        ProviderKind::Zen => &["zen", "opencode", "zenmux"],
759        ProviderKind::OpenAi => &["openai"],
760        ProviderKind::Azure => &["azure", "azure-cognitive-services"],
761        ProviderKind::Nvidia => &["nvidia"],
762        ProviderKind::Custom(_) => &[],
763    }
764}
765
/// Narrows a `u64` counter to `u32`, saturating at `u32::MAX` instead of truncating.
fn to_u32(value: u64) -> u32 {
    match u32::try_from(value) {
        Ok(narrowed) => narrowed,
        Err(_) => u32::MAX,
    }
}
769
770#[cfg(test)]
771mod tests {
772    use anyhow::{Result, bail};
773    use gunmetal_core::{
774        ChatMessage, ChatRole, NewProviderProfile, ProviderAuthState, RequestOptions,
775    };
776    use serde_json::json;
777    use tempfile::TempDir;
778    use wiremock::{
779        Mock, MockServer, ResponseTemplate,
780        matchers::{method, path},
781    };
782
783    use super::*;
784
785    #[tokio::test]
786    async fn provider_hub_uses_registered_adapter_and_persists_credentials() {
787        let temp = TempDir::new().unwrap();
788        let paths = AppPaths::from_root(temp.path().join("gunmetal-home")).unwrap();
789        let storage = paths.storage_handle().unwrap();
790        let profile = storage
791            .create_profile(NewProviderProfile {
792                provider: ProviderKind::Custom("mock".to_owned()),
793                name: "mock".to_owned(),
794                base_url: None,
795                enabled: true,
796                credentials: None,
797            })
798            .unwrap();
799
800        let mut registry = ProviderRegistry::default();
801        registry.register(MockAdapter);
802        let hub = ProviderHub::new(paths.clone(), registry);
803
804        let status = hub.auth_status(&profile).await.unwrap();
805        assert_eq!(status.state, ProviderAuthState::Connected);
806
807        let synced = hub.sync_models(&profile).await.unwrap();
808        assert_eq!(synced[0].id, "mock/model-1");
809
810        let completion = hub
811            .chat_completion(
812                &profile,
813                &ChatCompletionRequest {
814                    model: "mock/model-1".to_owned(),
815                    messages: vec![ChatMessage {
816                        role: ChatRole::User,
817                        content: "ping".to_owned(),
818                    }],
819                    stream: false,
820                    options: RequestOptions::default(),
821                },
822            )
823            .await
824            .unwrap();
825        assert_eq!(completion.message.content, "hello from mock");
826
827        let updated = storage.get_profile(profile.id).unwrap().unwrap();
828        assert_eq!(updated.credentials, Some(json!({ "token": "updated" })));
829    }
830
831    #[tokio::test]
832    async fn models_dev_enriches_synced_models() {
833        let server = MockServer::start().await;
834        Mock::given(method("GET"))
835            .and(path("/api.json"))
836            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
837                "openai": {
838                    "models": {
839                        "gpt-5.1": {
840                            "family": "gpt",
841                            "attachment": true,
842                            "reasoning": true,
843                            "tool_call": true,
844                            "open_weights": false,
845                            "release_date": "2025-01-01",
846                            "last_updated": "2025-02-01",
847                            "modalities": { "input": ["text"], "output": ["text"] },
848                            "limit": { "context": 272000, "output": 16384 }
849                        }
850                    }
851                }
852            })))
853            .mount(&server)
854            .await;
855
856        let temp = TempDir::new().unwrap();
857        let paths = AppPaths::from_root(temp.path().join("gunmetal-home")).unwrap();
858        let storage = paths.storage_handle().unwrap();
859        let profile = storage
860            .create_profile(NewProviderProfile {
861                provider: ProviderKind::Codex,
862                name: "codex".to_owned(),
863                base_url: None,
864                enabled: true,
865                credentials: None,
866            })
867            .unwrap();
868
869        let mut registry = ProviderRegistry::default();
870        registry.register(MockCodexAdapter);
871        let hub = ProviderHub::with_registry_and_models_dev(
872            paths,
873            registry,
874            ModelsDevCatalog::new(format!("{}/api.json", server.uri())),
875        );
876
877        let models = hub.sync_models(&profile).await.unwrap();
878        assert_eq!(
879            models[0]
880                .metadata
881                .as_ref()
882                .and_then(|value| value.family.as_deref()),
883            Some("gpt")
884        );
885        assert_eq!(
886            models[0]
887                .metadata
888                .as_ref()
889                .and_then(|value| value.context_window),
890            Some(272_000)
891        );
892    }
893
894    #[test]
895    fn provider_hub_exposes_definition_metadata() {
896        let temp = TempDir::new().unwrap();
897        let paths = AppPaths::from_root(temp.path().join("gunmetal-home")).unwrap();
898        let mut registry = ProviderRegistry::default();
899        registry.register(MockAdapter);
900        let hub = ProviderHub::new(paths, registry);
901
902        let definition = hub
903            .definition(&ProviderKind::Custom("mock".to_owned()))
904            .unwrap();
905        assert_eq!(definition.label, "mock");
906        assert!(definition.requires_api_key());
907        assert!(definition.capabilities.supports_responses_api);
908    }
909
910    #[tokio::test]
911    async fn synthetic_chat_sse_stream_emits_expected_events() {
912        let events: ProviderEventStream = stream::iter(vec![
913            Ok(ProviderStreamEvent::TextDelta("Hello".to_owned())),
914            Ok(ProviderStreamEvent::TextDelta(" world".to_owned())),
915            Ok(ProviderStreamEvent::Complete {
916                model: "gpt-4".to_owned(),
917                finish_reason: "stop".to_owned(),
918                usage: TokenUsage {
919                    input_tokens: Some(1),
920                    output_tokens: Some(2),
921                    total_tokens: Some(3),
922                },
923            }),
924        ])
925        .boxed();
926
927        let byte_stream = synthetic_chat_sse_stream("gpt-4".to_owned(), events);
928        let chunks: Vec<Vec<u8>> = byte_stream
929            .collect::<Vec<_>>()
930            .await
931            .into_iter()
932            .collect::<Result<Vec<_>>>()
933            .unwrap();
934        let output = String::from_utf8(chunks.concat()).unwrap();
935
936        assert!(output.contains("chat.completion.chunk"));
937        assert!(output.contains("\"role\":\"assistant\""));
938        assert!(output.contains("Hello"));
939        assert!(output.contains(" world"));
940        assert!(output.contains("[DONE]"));
941        assert!(output.contains("\"finish_reason\":\"stop\""));
942    }
943
944    #[tokio::test]
945    async fn synthetic_completion_stream_emits_text_then_complete() {
946        let completion = ChatCompletionResult {
947            model: "test-model".to_owned(),
948            message: ChatMessage {
949                role: ChatRole::Assistant,
950                content: "Hello world".to_owned(),
951            },
952            finish_reason: "stop".to_owned(),
953            usage: TokenUsage {
954                input_tokens: Some(1),
955                output_tokens: Some(1),
956                total_tokens: Some(2),
957            },
958        };
959
960        let stream = synthetic_completion_stream(completion);
961        let events: Vec<ProviderStreamEvent> = stream
962            .collect::<Vec<_>>()
963            .await
964            .into_iter()
965            .collect::<Result<Vec<_>>>()
966            .unwrap();
967
968        assert_eq!(events.len(), 2);
969        assert_eq!(
970            events[0],
971            ProviderStreamEvent::TextDelta("Hello world".to_owned())
972        );
973        match &events[1] {
974            ProviderStreamEvent::Complete {
975                model,
976                finish_reason,
977                usage,
978            } => {
979                assert_eq!(model, "test-model");
980                assert_eq!(finish_reason, "stop");
981                assert_eq!(usage.total_tokens, Some(2));
982            }
983            _ => panic!("expected Complete event"),
984        }
985    }
986
987    #[tokio::test]
988    async fn synthetic_completion_stream_empty_content() {
989        let completion = ChatCompletionResult {
990            model: "m".to_owned(),
991            message: ChatMessage {
992                role: ChatRole::Assistant,
993                content: "".to_owned(),
994            },
995            finish_reason: "stop".to_owned(),
996            usage: TokenUsage {
997                input_tokens: None,
998                output_tokens: None,
999                total_tokens: None,
1000            },
1001        };
1002
1003        let stream = synthetic_completion_stream(completion);
1004        let events: Vec<ProviderStreamEvent> = stream
1005            .collect::<Vec<_>>()
1006            .await
1007            .into_iter()
1008            .collect::<Result<Vec<_>>>()
1009            .unwrap();
1010
1011        assert_eq!(events.len(), 2);
1012        assert_eq!(events[0], ProviderStreamEvent::TextDelta("".to_owned()));
1013    }
1014
1015    #[test]
1016    fn sse_decoder_complete_event() {
1017        let mut decoder = SseDecoder::default();
1018        decoder.push(b"data: hello\n\n");
1019        assert_eq!(decoder.next_event(), Some("hello".to_owned()));
1020        assert_eq!(decoder.next_event(), None);
1021    }
1022
1023    #[test]
1024    fn sse_decoder_multiple_events() {
1025        let mut decoder = SseDecoder::default();
1026        decoder.push(b"data: first\n\ndata: second\n\n");
1027        assert_eq!(decoder.next_event(), Some("first".to_owned()));
1028        assert_eq!(decoder.next_event(), Some("second".to_owned()));
1029        assert_eq!(decoder.next_event(), None);
1030    }
1031
1032    #[test]
1033    fn sse_decoder_partial_chunks() {
1034        let mut decoder = SseDecoder::default();
1035        decoder.push(b"data: hel");
1036        assert_eq!(decoder.next_event(), None);
1037        decoder.push(b"lo\n\n");
1038        assert_eq!(decoder.next_event(), Some("hello".to_owned()));
1039    }
1040
1041    #[test]
1042    fn sse_decoder_malformed_no_data_prefix() {
1043        let mut decoder = SseDecoder::default();
1044        decoder.push(b"event: message\n\n");
1045        assert_eq!(decoder.next_event(), None);
1046    }
1047
1048    #[test]
1049    fn sse_decoder_empty_chunk() {
1050        let mut decoder = SseDecoder::default();
1051        decoder.push(b"");
1052        assert_eq!(decoder.next_event(), None);
1053    }
1054
1055    #[test]
1056    fn sse_decoder_multiline_data() {
1057        let mut decoder = SseDecoder::default();
1058        decoder.push(b"data: line1\ndata: line2\n\n");
1059        assert_eq!(decoder.next_event(), Some("line1\nline2".to_owned()));
1060    }
1061
1062    #[test]
1063    fn sse_decoder_carriage_return() {
1064        let mut decoder = SseDecoder::default();
1065        decoder.push(b"data: hello\r\n\r\n");
1066        assert_eq!(decoder.next_event(), Some("hello".to_owned()));
1067    }
1068
1069    #[tokio::test]
1070    async fn openai_compatible_event_stream_parses_text_and_complete() {
1071        let server = MockServer::start().await;
1072        Mock::given(method("GET"))
1073            .and(path("/stream"))
1074            .respond_with(ResponseTemplate::new(200).set_body_string(
1075                "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n\
1076                 data: {\"choices\":[{\"delta\":{\"content\":\" world\"},\"finish_reason\":null}]}\n\n\
1077                 data: {\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":1,\"completion_tokens\":2,\"total_tokens\":3}}\n\n\
1078                 data: [DONE]\n\n",
1079            ))
1080            .mount(&server)
1081            .await;
1082
1083        let client = reqwest::Client::new();
1084        let response = client
1085            .get(format!("{}/stream", server.uri()))
1086            .send()
1087            .await
1088            .unwrap();
1089
1090        let stream =
1091            openai_compatible_event_stream(response, "fallback-model".to_owned(), |s| s.to_owned());
1092        let events: Vec<ProviderStreamEvent> = stream
1093            .collect::<Vec<_>>()
1094            .await
1095            .into_iter()
1096            .collect::<Result<Vec<_>>>()
1097            .unwrap();
1098
1099        assert_eq!(events.len(), 3);
1100        assert_eq!(
1101            events[0],
1102            ProviderStreamEvent::TextDelta("Hello".to_owned())
1103        );
1104        assert_eq!(
1105            events[1],
1106            ProviderStreamEvent::TextDelta(" world".to_owned())
1107        );
1108        match &events[2] {
1109            ProviderStreamEvent::Complete {
1110                model,
1111                finish_reason,
1112                usage,
1113            } => {
1114                assert_eq!(model, "fallback-model");
1115                assert_eq!(finish_reason, "stop");
1116                assert_eq!(usage.total_tokens, Some(3));
1117            }
1118            _ => panic!("expected Complete"),
1119        }
1120    }
1121
1122    #[tokio::test]
1123    async fn models_dev_http_failure_returns_error() {
1124        let server = MockServer::start().await;
1125        Mock::given(method("GET"))
1126            .and(path("/api.json"))
1127            .respond_with(ResponseTemplate::new(500))
1128            .mount(&server)
1129            .await;
1130
1131        let catalog = ModelsDevCatalog::new(format!("{}/api.json", server.uri()));
1132        let mut models = vec![ModelDescriptor {
1133            id: "test".to_owned(),
1134            provider: ProviderKind::OpenAi,
1135            profile_id: None,
1136            upstream_name: "gpt-4".to_owned(),
1137            display_name: "GPT-4".to_owned(),
1138            metadata: None,
1139        }];
1140
1141        let result = catalog.enrich(&mut models).await;
1142        assert!(result.is_err());
1143    }
1144
1145    #[tokio::test]
1146    async fn models_dev_cache_reuses_index() {
1147        let server = MockServer::start().await;
1148        Mock::given(method("GET"))
1149            .and(path("/api.json"))
1150            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1151                "openai": {
1152                    "models": {
1153                        "gpt-4": {
1154                            "family": "gpt",
1155                            "modalities": { "input": ["text"], "output": ["text"] },
1156                            "limit": { "context": 8192, "output": 4096 }
1157                        }
1158                    }
1159                }
1160            })))
1161            .expect(1)
1162            .mount(&server)
1163            .await;
1164
1165        let catalog = ModelsDevCatalog::new(format!("{}/api.json", server.uri()));
1166
1167        let mut models = vec![ModelDescriptor {
1168            id: "openai/gpt-4".to_owned(),
1169            provider: ProviderKind::OpenAi,
1170            profile_id: None,
1171            upstream_name: "gpt-4".to_owned(),
1172            display_name: "GPT-4".to_owned(),
1173            metadata: None,
1174        }];
1175
1176        catalog.enrich(&mut models).await.unwrap();
1177        assert_eq!(
1178            models[0].metadata.as_ref().unwrap().family,
1179            Some("gpt".to_owned())
1180        );
1181
1182        let mut models2 = vec![ModelDescriptor {
1183            id: "openai/gpt-4".to_owned(),
1184            provider: ProviderKind::OpenAi,
1185            profile_id: None,
1186            upstream_name: "gpt-4".to_owned(),
1187            display_name: "GPT-4".to_owned(),
1188            metadata: None,
1189        }];
1190        catalog.enrich(&mut models2).await.unwrap();
1191        assert_eq!(
1192            models2[0].metadata.as_ref().unwrap().family,
1193            Some("gpt".to_owned())
1194        );
1195    }
1196
    /// Test double registered as `ProviderKind::Custom("mock")`: a direct,
    /// API-key provider whose `ProviderAdapter` impl in this module returns
    /// canned data (its `login` always fails).
    #[derive(Default)]
    struct MockAdapter;
1199
1200    #[async_trait]
1201    impl ProviderAdapter for MockAdapter {
1202        fn definition(&self) -> ProviderDefinition {
1203            ProviderDefinition {
1204                kind: ProviderKind::Custom("mock".to_owned()),
1205                label: "mock",
1206                class: ProviderClass::Direct,
1207                priority: 99,
1208                capabilities: ProviderCapabilities {
1209                    auth_method: ProviderAuthMethod::ApiKey,
1210                    supports_base_url: true,
1211                    supports_model_sync: true,
1212                    supports_chat_completions: true,
1213                    supports_responses_api: true,
1214                    supports_streaming: true,
1215                },
1216                ux: ProviderUxHints {
1217                    helper_title: "Direct provider",
1218                    helper_body: "Save the upstream API key here.",
1219                    suggested_name: "mock",
1220                    base_url_placeholder: "optional override",
1221                },
1222            }
1223        }
1224
1225        async fn auth_status(
1226            &self,
1227            _profile: &ProviderProfile,
1228            _context: &dyn ProviderContext,
1229        ) -> Result<ProviderAuthResult> {
1230            Ok(ProviderAuthResult {
1231                credentials: Some(json!({ "token": "updated" })),
1232                status: ProviderAuthStatus {
1233                    state: ProviderAuthState::Connected,
1234                    label: "mock".to_owned(),
1235                },
1236            })
1237        }
1238
1239        async fn login(
1240            &self,
1241            _profile: &ProviderProfile,
1242            _context: &dyn ProviderContext,
1243            _open_browser: bool,
1244        ) -> Result<ProviderLoginResult> {
1245            bail!("not implemented")
1246        }
1247
1248        async fn logout(
1249            &self,
1250            _profile: &ProviderProfile,
1251            _context: &dyn ProviderContext,
1252        ) -> Result<Option<Value>> {
1253            Ok(None)
1254        }
1255
1256        async fn sync_models(
1257            &self,
1258            profile: &ProviderProfile,
1259            _context: &dyn ProviderContext,
1260        ) -> Result<ProviderModelSyncResult> {
1261            Ok(ProviderModelSyncResult {
1262                credentials: Some(json!({ "token": "updated" })),
1263                models: vec![ModelDescriptor {
1264                    id: "mock/model-1".to_owned(),
1265                    provider: profile.provider.clone(),
1266                    profile_id: Some(profile.id),
1267                    upstream_name: "model-1".to_owned(),
1268                    display_name: "Model 1".to_owned(),
1269                    metadata: None,
1270                }],
1271            })
1272        }
1273
1274        async fn chat_completion(
1275            &self,
1276            _profile: &ProviderProfile,
1277            _context: &dyn ProviderContext,
1278            request: &ChatCompletionRequest,
1279        ) -> Result<ProviderChatResult> {
1280            Ok(ProviderChatResult {
1281                credentials: Some(json!({ "token": "updated" })),
1282                completion: ChatCompletionResult {
1283                    model: request.model.clone(),
1284                    message: ChatMessage {
1285                        role: ChatRole::Assistant,
1286                        content: "hello from mock".to_owned(),
1287                    },
1288                    finish_reason: "stop".to_owned(),
1289                    usage: gunmetal_core::TokenUsage {
1290                        input_tokens: Some(1),
1291                        output_tokens: Some(1),
1292                        total_tokens: Some(2),
1293                    },
1294                },
1295            })
1296        }
1297    }
1298
1299    struct MockCodexAdapter;
1300
1301    #[async_trait]
1302    impl ProviderAdapter for MockCodexAdapter {
1303        fn definition(&self) -> ProviderDefinition {
1304            ProviderDefinition {
1305                kind: ProviderKind::Codex,
1306                label: "codex",
1307                class: ProviderClass::Subscription,
1308                priority: 1,
1309                capabilities: ProviderCapabilities {
1310                    auth_method: ProviderAuthMethod::BrowserSession,
1311                    supports_base_url: false,
1312                    supports_model_sync: true,
1313                    supports_chat_completions: true,
1314                    supports_responses_api: true,
1315                    supports_streaming: true,
1316                },
1317                ux: ProviderUxHints {
1318                    helper_title: "Browser sign-in provider",
1319                    helper_body: "Save the provider, then finish auth in the browser.",
1320                    suggested_name: "codex",
1321                    base_url_placeholder: "not used for this provider",
1322                },
1323            }
1324        }
1325
1326        async fn auth_status(
1327            &self,
1328            _profile: &ProviderProfile,
1329            _context: &dyn ProviderContext,
1330        ) -> Result<ProviderAuthResult> {
1331            Ok(ProviderAuthResult {
1332                credentials: None,
1333                status: ProviderAuthStatus {
1334                    state: ProviderAuthState::Connected,
1335                    label: "codex".to_owned(),
1336                },
1337            })
1338        }
1339
1340        async fn login(
1341            &self,
1342            _profile: &ProviderProfile,
1343            _context: &dyn ProviderContext,
1344            _open_browser: bool,
1345        ) -> Result<ProviderLoginResult> {
1346            bail!("not implemented")
1347        }
1348
1349        async fn logout(
1350            &self,
1351            _profile: &ProviderProfile,
1352            _context: &dyn ProviderContext,
1353        ) -> Result<Option<Value>> {
1354            Ok(None)
1355        }
1356
1357        async fn sync_models(
1358            &self,
1359            profile: &ProviderProfile,
1360            _context: &dyn ProviderContext,
1361        ) -> Result<ProviderModelSyncResult> {
1362            Ok(ProviderModelSyncResult {
1363                credentials: None,
1364                models: vec![ModelDescriptor {
1365                    id: "codex/gpt-5.1".to_owned(),
1366                    provider: ProviderKind::Codex,
1367                    profile_id: Some(profile.id),
1368                    upstream_name: "gpt-5.1".to_owned(),
1369                    display_name: "GPT-5.1".to_owned(),
1370                    metadata: None,
1371                }],
1372            })
1373        }
1374
1375        async fn chat_completion(
1376            &self,
1377            _profile: &ProviderProfile,
1378            _context: &dyn ProviderContext,
1379            request: &ChatCompletionRequest,
1380        ) -> Result<ProviderChatResult> {
1381            Ok(ProviderChatResult {
1382                credentials: None,
1383                completion: ChatCompletionResult {
1384                    model: request.model.clone(),
1385                    message: ChatMessage {
1386                        role: ChatRole::Assistant,
1387                        content: "hello".to_owned(),
1388                    },
1389                    finish_reason: "stop".to_owned(),
1390                    usage: gunmetal_core::TokenUsage {
1391                        input_tokens: Some(1),
1392                        output_tokens: Some(1),
1393                        total_tokens: Some(2),
1394                    },
1395                },
1396            })
1397        }
1398    }
1399}