rig_core/providers/copilot/mod.rs

//! GitHub Copilot provider.
//!
//! Supports Chat Completions, Responses, and Embeddings against
//! `https://api.githubcopilot.com`.
//!
//! `Client::completion_model(...)` automatically routes Codex-class models
//! through `/responses` and conversational models through
//! `/chat/completions`.
//!
//! # Example
//! ```no_run
//! use rig_core::client::{CompletionClient, ProviderClient};
//! use rig_core::providers::copilot;
//!
//! # fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let client = copilot::Client::from_env()?;
//! let model = client.completion_model(copilot::GPT_4O);
//! # let _ = model;
//! # Ok(())
//! # }
//! ```

mod auth;

use crate::client::{
    self, ApiKey, Capabilities, Capable, DebugExt, ModelLister, Nothing, Provider, ProviderBuilder,
    ProviderClient, Transport,
};
use crate::completion::{self, CompletionError, GetTokenUsage};
use crate::embeddings::{self, EmbeddingError};
use crate::http_client::{self, HttpClientExt};
use crate::model::{Model, ModelList, ModelListingError};
use crate::providers::internal::openai_chat_completions_compatible::{
    self, CompatibleChoiceData, CompatibleChunk, CompatibleFinishReason, CompatibleStreamProfile,
    CompatibleToolCallChunk,
};
use crate::providers::openai;
use crate::providers::openai::responses_api::{self, CompletionRequest as ResponsesRequest};
use crate::streaming::{self, RawStreamingChoice, StreamingCompletionResponse};
use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
use async_stream::stream;
use futures::StreamExt;
use http::Request;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt::Debug;
use std::path::{Path, PathBuf};
use tracing::info_span;
use tracing_futures::Instrument as _;

const GITHUB_COPILOT_API_BASE_URL: &str = "https://api.githubcopilot.com";
const EDITOR_PLUGIN_VERSION: &str = "copilot-chat/0.26.7";
const USER_AGENT: &str = "GitHubCopilotChat/0.26.7";
const API_VERSION: &str = "2025-04-01";

/// `gpt-4`
pub const GPT_4: &str = "gpt-4";
/// `gpt-4o`
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini`
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4.1`
pub const GPT_4_1: &str = "gpt-4.1";
/// `gpt-4.1-mini`
pub const GPT_4_1_MINI: &str = "gpt-4.1-mini";
/// `gpt-4.1-nano`
pub const GPT_4_1_NANO: &str = "gpt-4.1-nano";
/// `gpt-5.3-codex`
pub const GPT_5_3_CODEX: &str = "gpt-5.3-codex";
/// `gpt-5.1-codex`
pub const GPT_5_1_CODEX: &str = "gpt-5.1-codex";
/// `gpt-5.5`
pub const GPT_5_5: &str = "gpt-5.5";
/// `gpt-5.4`
pub const GPT_5_4: &str = "gpt-5.4";
/// `claude-sonnet-4` completion model (Anthropic, via Copilot)
pub const CLAUDE_SONNET_4: &str = "claude-sonnet-4";
/// `claude-sonnet-4.6`
pub const CLAUDE_SONNET_4_6: &str = "claude-sonnet-4.6";
/// `claude-opus-4.6`
pub const CLAUDE_OPUS_4_6: &str = "claude-opus-4.6";
/// `claude-opus-4.7`
pub const CLAUDE_OPUS_4_7: &str = "claude-opus-4.7";
/// `claude-3.5-sonnet` completion model (Anthropic, via Copilot)
pub const CLAUDE_3_5_SONNET: &str = "claude-3.5-sonnet";
/// `gemini-3-flash-preview` completion model (Google, via Copilot)
pub const GEMINI_3_FLASH: &str = "gemini-3-flash-preview";
/// `gemini-3.1-pro-preview` completion model (Google, via Copilot)
pub const GEMINI_3_1_PRO: &str = "gemini-3.1-pro-preview";
/// `gemini-2.0-flash-001` completion model (Google, via Copilot)
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash-001";
/// `o3-mini` reasoning model (OpenAI, via Copilot)
pub const O3_MINI: &str = "o3-mini";
/// `text-embedding-3-small`
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-3-large`
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-ada-002`
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";

pub use openai::EncodingFormat;

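/// How the client authenticates with the Copilot API.
///
/// A minimal sketch of selecting each mode (the token values below are
/// placeholders, not real credentials):
///
/// ```ignore
/// use rig_core::providers::copilot;
///
/// // Direct Copilot API key.
/// let _ = copilot::Client::builder().api_key("copilot-api-key");
/// // GitHub access token, exchanged for a Copilot API key on demand.
/// let _ = copilot::Client::builder().github_access_token("ghu_example");
/// // Interactive GitHub device-code OAuth flow.
/// let _ = copilot::Client::builder().oauth();
/// ```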
#[derive(Clone)]
pub enum CopilotAuth {
    ApiKey(String),
    GitHubAccessToken(String),
    OAuth,
}

impl ApiKey for CopilotAuth {}

impl<S> From<S> for CopilotAuth
where
    S: Into<String>,
{
    fn from(value: S) -> Self {
        Self::ApiKey(value.into())
    }
}

impl Debug for CopilotAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ApiKey(_) => f.write_str("ApiKey(<redacted>)"),
            Self::GitHubAccessToken(_) => f.write_str("GitHubAccessToken(<redacted>)"),
            Self::OAuth => f.write_str("OAuth"),
        }
    }
}

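/// Builder-side configuration: where to cache the GitHub access token and
/// the exchanged Copilot API key, plus the handler invoked with the
/// device-code prompt during OAuth.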
#[derive(Debug, Clone)]
pub struct CopilotBuilder {
    access_token_file: Option<PathBuf>,
    api_key_file: Option<PathBuf>,
    device_code_handler: auth::DeviceCodeHandler,
}

#[derive(Clone)]
pub struct CopilotExt {
    auth: auth::Authenticator,
}

impl Debug for CopilotExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CopilotExt")
            .field("auth", &self.auth)
            .finish()
    }
}

pub type Client<H = reqwest::Client> = client::Client<CopilotExt, H>;
pub type ClientBuilder<H = crate::markers::Missing> =
    client::ClientBuilder<CopilotBuilder, CopilotAuth, H>;

impl Default for CopilotBuilder {
    fn default() -> Self {
        let token_dir = default_token_dir();
        Self {
            access_token_file: token_dir.as_ref().map(|dir| dir.join("access-token")),
            api_key_file: token_dir.map(|dir| dir.join("api-key.json")),
            device_code_handler: auth::DeviceCodeHandler::default(),
        }
    }
}

impl Provider for CopilotExt {
    type Builder = CopilotBuilder;

    const VERIFY_PATH: &'static str = "";
}

impl<H> Capabilities<H> for CopilotExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type Transcription = Nothing;
    type ModelListing = Capable<CopilotModelLister<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for CopilotExt {}

impl ProviderBuilder for CopilotBuilder {
    type Extension<H>
        = CopilotExt
    where
        H: HttpClientExt;
    type ApiKey = CopilotAuth;

    const BASE_URL: &'static str = GITHUB_COPILOT_API_BASE_URL;

    fn build<H>(
        builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        let auth = match builder.get_api_key() {
            CopilotAuth::ApiKey(api_key) => auth::AuthSource::ApiKey(api_key.clone()),
            CopilotAuth::GitHubAccessToken(access_token) => {
                auth::AuthSource::GitHubAccessToken(access_token.clone())
            }
            CopilotAuth::OAuth => auth::AuthSource::OAuth,
        };

        let ext = builder.ext();
        Ok(CopilotExt {
            auth: auth::Authenticator::new(
                auth,
                ext.access_token_file.clone(),
                ext.api_key_file.clone(),
                ext.device_code_handler.clone(),
            ),
        })
    }
}

impl ProviderClient for Client {
    type Input = CopilotAuth;
    type Error = crate::client::ProviderClientError;

    fn from_env() -> Result<Self, Self::Error> {
        let mut builder = Self::builder();
        fn get(name: &str) -> Option<String> {
            std::env::var(name).ok()
        }

        if let Some(base_url) = env_base_url(&get) {
            builder = builder.base_url(base_url);
        }

        if let Some(api_key) = env_api_key(&get) {
            builder.api_key(api_key).build().map_err(Into::into)
        } else if let Some(access_token) = env_github_access_token(&get) {
            builder
                .github_access_token(access_token)
                .build()
                .map_err(Into::into)
        } else {
            builder.oauth().build().map_err(Into::into)
        }
    }

    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
        Self::builder().api_key(input).build().map_err(Into::into)
    }
}

impl<H> client::ClientBuilder<CopilotBuilder, crate::markers::Missing, H> {
    pub fn github_access_token(
        self,
        access_token: impl Into<String>,
    ) -> client::ClientBuilder<CopilotBuilder, CopilotAuth, H> {
        self.api_key(CopilotAuth::GitHubAccessToken(access_token.into()))
    }

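    /// Use the interactive GitHub device-code OAuth flow.
    ///
    /// A minimal sketch combining this with the customization hooks defined
    /// below (the path is illustrative):
    ///
    /// ```ignore
    /// let client = copilot::Client::builder()
    ///     .oauth()
    ///     .token_dir("/tmp/copilot-tokens")
    ///     .on_device_code(|prompt| {
    ///         // Surface the verification URL and one-time code to the user.
    ///         let _ = prompt;
    ///     })
    ///     .build()?;
    /// ```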
    pub fn oauth(self) -> client::ClientBuilder<CopilotBuilder, CopilotAuth, H> {
        self.api_key(CopilotAuth::OAuth)
    }
}

impl<H> ClientBuilder<H> {
    pub fn on_device_code<F>(self, handler: F) -> Self
    where
        F: Fn(auth::DeviceCodePrompt) + Send + Sync + 'static,
    {
        self.over_ext(|mut ext| {
            ext.device_code_handler = auth::DeviceCodeHandler::new(handler);
            ext
        })
    }

    pub fn token_dir(self, path: impl AsRef<Path>) -> Self {
        let path = path.as_ref();
        self.over_ext(|mut ext| {
            ext.access_token_file = Some(path.join("access-token"));
            ext.api_key_file = Some(path.join("api-key.json"));
            ext
        })
    }

    pub fn access_token_file(self, path: impl AsRef<Path>) -> Self {
        let path = path.as_ref().to_path_buf();
        self.over_ext(|mut ext| {
            ext.access_token_file = Some(path);
            ext
        })
    }

    pub fn api_key_file(self, path: impl AsRef<Path>) -> Self {
        let path = path.as_ref().to_path_buf();
        self.over_ext(|mut ext| {
            ext.api_key_file = Some(path);
            ext
        })
    }
}

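// Environment lookup helpers used by `from_env`. The base URL comes from
// `GITHUB_COPILOT_API_BASE` or `COPILOT_BASE_URL`; authentication falls back
// in order from `GITHUB_COPILOT_API_KEY`/`COPILOT_API_KEY` to
// `COPILOT_GITHUB_ACCESS_TOKEN`/`GITHUB_TOKEN` to the interactive OAuth flow.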
fn env_value<F>(get: &F, name: &str) -> Option<String>
where
    F: Fn(&str) -> Option<String>,
{
    get(name).filter(|value| !value.trim().is_empty())
}

fn first_env_value<F>(get: &F, keys: &[&str]) -> Option<String>
where
    F: Fn(&str) -> Option<String>,
{
    keys.iter().find_map(|key| env_value(get, key))
}

fn env_api_key<F>(get: &F) -> Option<String>
where
    F: Fn(&str) -> Option<String>,
{
    first_env_value(get, &["GITHUB_COPILOT_API_KEY", "COPILOT_API_KEY"])
}

fn env_github_access_token<F>(get: &F) -> Option<String>
where
    F: Fn(&str) -> Option<String>,
{
    first_env_value(get, &["COPILOT_GITHUB_ACCESS_TOKEN", "GITHUB_TOKEN"])
}

fn env_base_url<F>(get: &F) -> Option<String>
where
    F: Fn(&str) -> Option<String>,
{
    first_env_value(get, &["GITHUB_COPILOT_API_BASE", "COPILOT_BASE_URL"])
}

impl<H> Client<H>
where
    H: HttpClientExt + Clone + Debug + Default + WasmCompatSend + WasmCompatSync + 'static,
{
    pub async fn authorize(&self) -> Result<(), auth::AuthError> {
        self.ext().auth.auth_context().await.map(|_| ())
    }
}

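/// Headers mimicking the VS Code Copilot Chat client: editor and plugin
/// versions, a fresh request id, `X-Initiator` marking the request as
/// user- or agent-initiated, and a vision flag when images are attached.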
fn default_headers(
    api_key: &str,
    initiator: &'static str,
    has_vision: bool,
) -> Vec<(&'static str, String)> {
    let mut headers = vec![
        (
            http::header::AUTHORIZATION.as_str(),
            format!("Bearer {api_key}"),
        ),
        ("copilot-integration-id", "vscode-chat".to_string()),
        ("editor-version", "vscode/1.95.0".to_string()),
        ("editor-plugin-version", EDITOR_PLUGIN_VERSION.to_string()),
        ("user-agent", USER_AGENT.to_string()),
        ("openai-intent", "conversation-panel".to_string()),
        ("x-github-api-version", API_VERSION.to_string()),
        ("x-request-id", nanoid::nanoid!()),
        (
            "x-vscode-user-agent-library-version",
            "electron-fetch".to_string(),
        ),
        ("X-Initiator", initiator.to_string()),
    ];

    if has_vision {
        headers.push(("copilot-vision-request", "true".to_string()));
    }

    headers
}

fn apply_headers(
    builder: http_client::Builder,
    headers: &[(&'static str, String)],
) -> http_client::Builder {
    headers
        .iter()
        .fold(builder, |builder, (key, value)| builder.header(*key, value))
}

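/// Prefer the API base advertised during token exchange, but only while the
/// client still points at the default endpoint; an explicitly configured
/// base URL always wins.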
fn runtime_base_url<'a, H>(client: &'a Client<H>, auth: &'a auth::AuthContext) -> Cow<'a, str> {
    if client.base_url() == GITHUB_COPILOT_API_BASE_URL {
        auth.api_base
            .as_deref()
            .map(Cow::Borrowed)
            .unwrap_or_else(|| Cow::Borrowed(client.base_url()))
    } else {
        Cow::Borrowed(client.base_url())
    }
}

fn post_with_auth_base<H>(
    client: &Client<H>,
    auth: &auth::AuthContext,
    path: &str,
    transport: Transport,
) -> http_client::Result<http_client::Builder> {
    let uri = client
        .ext()
        .build_uri(runtime_base_url(client, auth).as_ref(), path, transport);
    let mut req = Request::post(uri);

    if let Some(headers) = req.headers_mut() {
        headers.extend(client.headers().iter().map(|(k, v)| (k.clone(), v.clone())));
    }

    client.ext().with_custom(req)
}

fn get_with_auth_base<H>(
    client: &Client<H>,
    auth: &auth::AuthContext,
    path: &str,
    transport: Transport,
) -> http_client::Result<http_client::Builder> {
    let uri = client
        .ext()
        .build_uri(runtime_base_url(client, auth).as_ref(), path, transport);
    let mut req = Request::get(uri);

    if let Some(headers) = req.headers_mut() {
        headers.extend(client.headers().iter().map(|(k, v)| (k.clone(), v.clone())));
    }

    client.ext().with_custom(req)
}

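/// `X-Initiator` heuristic: a request counts as "agent"-initiated if the
/// history contains any assistant message or tool result; otherwise it is
/// "user"-initiated.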
fn request_initiator(request: &completion::CompletionRequest) -> &'static str {
    for message in request.chat_history.iter() {
        match message {
            crate::completion::Message::Assistant { .. } => return "agent",
            crate::completion::Message::User { content } => {
                if content
                    .iter()
                    .any(|item| matches!(item, crate::message::UserContent::ToolResult(_)))
                {
                    return "agent";
                }
            }
            crate::completion::Message::System { .. } => {}
        }
    }

    "user"
}

fn request_has_vision(request: &completion::CompletionRequest) -> bool {
    request.chat_history.iter().any(|message| match message {
        crate::completion::Message::User { content } => content
            .iter()
            .any(|item| matches!(item, crate::message::UserContent::Image(_))),
        _ => false,
    })
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum CompletionRoute {
    ChatCompletions,
    Responses,
}

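/// Any model id containing "codex" (case-insensitive) is served by the
/// Responses API; everything else goes through Chat Completions.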
fn route_for_model(model: &str) -> CompletionRoute {
    if model.to_ascii_lowercase().contains("codex") {
        CompletionRoute::Responses
    } else {
        CompletionRoute::ChatCompletions
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "api", rename_all = "snake_case")]
pub enum CopilotCompletionResponse {
    Chat(ChatCompletionResponse),
    Responses(Box<responses_api::CompletionResponse>),
}

#[derive(Clone, Serialize, Deserialize)]
#[serde(tag = "api", rename_all = "snake_case")]
pub enum CopilotStreamingResponse {
    Chat(openai::completion::streaming::StreamingCompletionResponse),
    Responses(responses_api::streaming::StreamingCompletionResponse),
}

impl GetTokenUsage for CopilotStreamingResponse {
    fn token_usage(&self) -> Option<completion::Usage> {
        match self {
            Self::Chat(response) => response.token_usage(),
            Self::Responses(response) => response.token_usage(),
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    #[serde(default)]
    pub object: Option<String>,
    #[serde(default)]
    pub created: Option<u64>,
    pub model: String,
    pub system_fingerprint: Option<String>,
    pub choices: Vec<ChatChoice>,
    pub usage: Option<openai::completion::Usage>,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ChatChoice {
    #[serde(default)]
    pub index: usize,
    pub message: openai::completion::Message,
    pub logprobs: Option<serde_json::Value>,
    #[serde(default)]
    pub finish_reason: Option<String>,
}

impl TryFrom<ChatCompletionResponse> for completion::CompletionResponse<ChatCompletionResponse> {
    type Error = CompletionError;

    fn try_from(response: ChatCompletionResponse) -> Result<Self, Self::Error> {
        let choice = response.choices.first().ok_or_else(|| {
            CompletionError::ResponseError("Response contained no choices".to_owned())
        })?;

        let content = match &choice.message {
            openai::completion::Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                let mut content = content
                    .iter()
                    .filter_map(|c| {
                        let s = match c {
                            openai::completion::AssistantContent::Text { text } => text,
                            openai::completion::AssistantContent::Refusal { refusal } => refusal,
                        };
                        if s.is_empty() {
                            None
                        } else {
                            Some(completion::AssistantContent::text(s))
                        }
                    })
                    .collect::<Vec<_>>();

                content.extend(
                    tool_calls
                        .iter()
                        .map(|call| {
                            completion::AssistantContent::tool_call(
                                &call.id,
                                &call.function.name,
                                call.function.arguments.clone(),
                            )
                        })
                        .collect::<Vec<_>>(),
                );
                Ok(content)
            }
            _ => Err(CompletionError::ResponseError(
                "Response did not contain a valid message or tool call".into(),
            )),
        }?;

        let choice = crate::OneOrMany::many(content).map_err(|_| {
            CompletionError::ResponseError(
                "Response contained no message or tool call (empty)".to_owned(),
            )
        })?;

        let usage = response
            .usage
            .as_ref()
            .map(|usage| completion::Usage {
                input_tokens: usage.prompt_tokens as u64,
                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
                total_tokens: usage.total_tokens as u64,
                cached_input_tokens: usage
                    .prompt_tokens_details
                    .as_ref()
                    .map(|d| d.cached_tokens as u64)
                    .unwrap_or(0),
                cache_creation_input_tokens: 0,
                reasoning_tokens: 0,
            })
            .unwrap_or_default();

        Ok(completion::CompletionResponse {
            choice,
            usage,
            raw_response: response,
            message_id: None,
        })
    }
}

#[derive(Debug, Deserialize)]
pub struct ChatApiErrorResponse {
    #[serde(default)]
    pub message: Option<String>,
    #[serde(default)]
    pub error: Option<String>,
}

impl ChatApiErrorResponse {
    pub fn error_message(&self) -> &str {
        self.message
            .as_deref()
            .or(self.error.as_deref())
            .unwrap_or("unknown error")
    }
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ChatApiResponse<T> {
    Ok(T),
    Err(ChatApiErrorResponse),
}

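/// Completion model handle; requests are routed per [`route_for_model`].
///
/// A minimal sketch of enabling the optional request toggles (both methods
/// consume and return `self`):
///
/// ```ignore
/// let model = client
///     .completion_model(copilot::GPT_4O)
///     .with_strict_tools()
///     .with_tool_result_array_content();
/// ```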
#[derive(Clone)]
pub struct CompletionModel<H = reqwest::Client> {
    client: Client<H>,
    pub model: String,
    pub strict_tools: bool,
    pub tool_result_array_content: bool,
}

impl<H> CompletionModel<H>
where
    Client<H>: HttpClientExt + Clone + Debug + 'static,
    H: Clone + Default + Debug + WasmCompatSend + WasmCompatSync + 'static,
{
    pub fn new(client: Client<H>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
            strict_tools: false,
            tool_result_array_content: false,
        }
    }

    pub fn with_strict_tools(mut self) -> Self {
        self.strict_tools = true;
        self
    }

    pub fn with_tool_result_array_content(mut self) -> Self {
        self.tool_result_array_content = true;
        self
    }

    fn route(&self) -> CompletionRoute {
        route_for_model(&self.model)
    }

    async fn auth_context(&self) -> Result<auth::AuthContext, CompletionError> {
        self.client
            .ext()
            .auth
            .auth_context()
            .await
            .map_err(|err| CompletionError::ProviderError(err.to_string()))
    }

    fn chat_request(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<openai::completion::CompletionRequest, CompletionError> {
        openai::completion::CompletionRequest::try_from(openai::completion::OpenAIRequestParams {
            model: self.model.clone(),
            request: completion_request,
            strict_tools: self.strict_tools,
            tool_result_array_content: self.tool_result_array_content,
        })
    }

    fn responses_request(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<ResponsesRequest, CompletionError> {
        ResponsesRequest::try_from((self.model.clone(), completion_request))
    }

    async fn completion_chat(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<CopilotCompletionResponse>, CompletionError> {
        let initiator = request_initiator(&completion_request);
        let has_vision = request_has_vision(&completion_request);
        let request = self.chat_request(completion_request)?;
        let body = serde_json::to_vec(&request)?;
        let auth = self.auth_context().await?;

        let headers = default_headers(&auth.api_key, initiator, has_vision);
        let req = apply_headers(
            post_with_auth_base(&self.client, &auth, "/chat/completions", Transport::Http)?,
            &headers,
        )
        .body(body)
        .map_err(|err| CompletionError::HttpError(err.into()))?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "copilot",
                gen_ai.request.model = self.model,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        async move {
            let response = self.client.send(req).await?;

            if response.status().is_success() {
                let body = http_client::text(response).await?;
                match serde_json::from_str::<ChatApiResponse<ChatCompletionResponse>>(&body)? {
                    ChatApiResponse::Ok(response) => {
                        let core = completion::CompletionResponse::try_from(response.clone())?;
                        let span = tracing::Span::current();
                        span.record("gen_ai.response.id", response.id.as_str());
                        span.record("gen_ai.response.model", response.model.as_str());
                        if let Some(usage) = &response.usage {
                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
                            span.record(
                                "gen_ai.usage.output_tokens",
                                usage.total_tokens - usage.prompt_tokens,
                            );
                            span.record(
                                "gen_ai.usage.cache_read.input_tokens",
                                usage
                                    .prompt_tokens_details
                                    .as_ref()
                                    .map(|details| details.cached_tokens)
                                    .unwrap_or(0),
                            );
                        }

                        Ok(completion::CompletionResponse {
                            choice: core.choice,
                            usage: core.usage,
                            raw_response: CopilotCompletionResponse::Chat(response),
                            message_id: core.message_id,
                        })
                    }
                    ChatApiResponse::Err(err) => Err(CompletionError::ProviderError(
                        err.error_message().to_string(),
                    )),
                }
            } else {
                let body = http_client::text(response).await?;
                Err(CompletionError::ProviderError(body))
            }
        }
        .instrument(span)
        .await
    }

    async fn completion_responses(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<CopilotCompletionResponse>, CompletionError> {
        let initiator = request_initiator(&completion_request);
        let has_vision = request_has_vision(&completion_request);
        let request = self.responses_request(completion_request)?;
        let auth = self.auth_context().await?;

        let headers = default_headers(&auth.api_key, initiator, has_vision);
        let req = apply_headers(
            post_with_auth_base(&self.client, &auth, "/responses", Transport::Http)?,
            &headers,
        )
        .body(serde_json::to_vec(&request)?)
        .map_err(|err| CompletionError::HttpError(err.into()))?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "copilot",
                gen_ai.request.model = self.model,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        async move {
            let response = self.client.send(req).await?;
            if response.status().is_success() {
                let body = http_client::text(response).await?;
                let response = serde_json::from_str::<responses_api::CompletionResponse>(&body)?;
                let core = completion::CompletionResponse::try_from(response.clone())?;

                let span = tracing::Span::current();
                span.record("gen_ai.response.id", response.id.as_str());
                span.record("gen_ai.response.model", response.model.as_str());
                if let Some(usage) = &response.usage {
                    span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                    span.record("gen_ai.usage.output_tokens", usage.output_tokens);
                    span.record(
                        "gen_ai.usage.cache_read.input_tokens",
                        usage
                            .input_tokens_details
                            .as_ref()
                            .map(|details| details.cached_tokens)
                            .unwrap_or(0),
                    );
                }

                Ok(completion::CompletionResponse {
                    choice: core.choice,
                    usage: core.usage,
                    raw_response: CopilotCompletionResponse::Responses(Box::new(response)),
                    message_id: core.message_id,
                })
            } else {
                let body = http_client::text(response).await?;
                Err(CompletionError::ProviderError(body))
            }
        }
        .instrument(span)
        .await
    }

    async fn stream_chat(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<StreamingCompletionResponse<CopilotStreamingResponse>, CompletionError> {
        let initiator = request_initiator(&completion_request);
        let has_vision = request_has_vision(&completion_request);
        let request = self.chat_request(completion_request)?;
        let auth = self.auth_context().await?;
        let headers = default_headers(&auth.api_key, initiator, has_vision);
        let mut request_json = serde_json::to_value(&request)?;
        let request_object = request_json.as_object_mut().ok_or_else(|| {
            CompletionError::ResponseError("copilot request body must be a JSON object".into())
        })?;
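        // Ask the API to stream and to append a final usage chunk
        // (`stream_options.include_usage`) so token counts can be recorded.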
        request_object.insert("stream".to_owned(), json!(true));
        request_object.insert(
            "stream_options".to_owned(),
            json!({ "include_usage": true }),
        );

        let req = apply_headers(
            post_with_auth_base(&self.client, &auth, "/chat/completions", Transport::Sse)?,
            &headers,
        )
        .body(serde_json::to_vec(&request_json)?)
        .map_err(|err| CompletionError::HttpError(err.into()))?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "copilot",
                gen_ai.request.model = self.model,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        tracing::Instrument::instrument(
            send_copilot_chat_streaming_request(self.client.clone(), req),
            span,
        )
        .await
    }

    async fn stream_responses(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<StreamingCompletionResponse<CopilotStreamingResponse>, CompletionError> {
        let initiator = request_initiator(&completion_request);
        let has_vision = request_has_vision(&completion_request);
        let mut request = self.responses_request(completion_request)?;
        request.stream = Some(true);
        let auth = self.auth_context().await?;

        let headers = default_headers(&auth.api_key, initiator, has_vision);
        let req = apply_headers(
            post_with_auth_base(&self.client, &auth, "/responses", Transport::Sse)?,
            &headers,
        )
        .body(serde_json::to_vec(&request)?)
        .map_err(|err| CompletionError::HttpError(err.into()))?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "copilot",
                gen_ai.request.model = self.model,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let client = self.client.clone();
        let mut event_source = crate::http_client::sse::GenericEventSource::new(client, req);

        let stream = tracing_futures::Instrument::instrument(
            stream! {
                let mut final_usage = responses_api::ResponsesUsage::new();
                let mut tool_calls: Vec<streaming::RawStreamingChoice<CopilotStreamingResponse>> = Vec::new();
                let mut tool_call_internal_ids: HashMap<String, String> = HashMap::new();
                let span = tracing::Span::current();

                let mut terminated_with_error = false;

                while let Some(event_result) = event_source.next().await {
                    match event_result {
                        Ok(crate::http_client::sse::Event::Open) => continue,
                        Ok(crate::http_client::sse::Event::Message(evt)) => {
                            if evt.data.trim().is_empty() {
                                continue;
                            }

                            let Ok(data) = serde_json::from_str::<responses_api::streaming::StreamingCompletionChunk>(&evt.data) else {
                                continue;
                            };

                            if let responses_api::streaming::StreamingCompletionChunk::Delta(chunk) = &data {
                                use responses_api::streaming::{ItemChunkKind, StreamingItemDoneOutput};

                                match &chunk.data {
                                    ItemChunkKind::OutputItemAdded(message) => {
                                        if let StreamingItemDoneOutput { item: responses_api::Output::FunctionCall(func), .. } = message {
                                            let internal_call_id = tool_call_internal_ids
                                                .entry(func.id.clone())
                                                .or_insert_with(|| nanoid::nanoid!())
                                                .clone();
                                            yield Ok(RawStreamingChoice::ToolCallDelta {
                                                id: func.id.clone(),
                                                internal_call_id,
                                                content: streaming::ToolCallDeltaContent::Name(func.name.clone()),
                                            });
                                        }
                                    }
                                    ItemChunkKind::OutputItemDone(message) => match message {
                                        StreamingItemDoneOutput { item: responses_api::Output::FunctionCall(func), .. } => {
                                            let internal_id = tool_call_internal_ids
                                                .entry(func.id.clone())
                                                .or_insert_with(|| nanoid::nanoid!())
                                                .clone();
                                            let raw_tool_call = streaming::RawStreamingToolCall::new(
                                                func.id.clone(),
                                                func.name.clone(),
                                                func.arguments.clone(),
                                            )
                                            .with_internal_call_id(internal_id)
                                            .with_call_id(func.call_id.clone());
                                            tool_calls.push(RawStreamingChoice::ToolCall(raw_tool_call));
                                        }
                                        StreamingItemDoneOutput { item: responses_api::Output::Reasoning { summary, id, encrypted_content, .. }, .. } => {
                                            for reasoning_choice in responses_api::streaming::reasoning_choices_from_done_item(
                                                id,
                                                summary,
                                                encrypted_content.as_deref(),
                                            ) {
                                                match reasoning_choice {
                                                    RawStreamingChoice::Reasoning { id, content } => {
                                                        yield Ok(RawStreamingChoice::Reasoning { id, content });
                                                    }
                                                    RawStreamingChoice::ReasoningDelta { id, reasoning } => {
                                                        yield Ok(RawStreamingChoice::ReasoningDelta { id, reasoning });
                                                    }
                                                    _ => {}
                                                }
                                            }
                                        }
                                        StreamingItemDoneOutput { item: responses_api::Output::Message(msg), .. } => {
                                            yield Ok(RawStreamingChoice::MessageId(msg.id.clone()));
                                        }
                                        StreamingItemDoneOutput { item: responses_api::Output::Unknown, .. } => {}
                                    },
                                    ItemChunkKind::OutputTextDelta(delta) => {
                                        yield Ok(RawStreamingChoice::Message(delta.delta.clone()))
                                    }
                                    ItemChunkKind::ReasoningSummaryTextDelta(delta) => {
                                        yield Ok(RawStreamingChoice::ReasoningDelta { id: None, reasoning: delta.delta.clone() })
                                    }
                                    ItemChunkKind::RefusalDelta(delta) => {
                                        yield Ok(RawStreamingChoice::Message(delta.delta.clone()))
                                    }
                                    ItemChunkKind::FunctionCallArgsDelta(delta) => {
                                        let internal_call_id = tool_call_internal_ids
                                            .entry(delta.item_id.clone())
                                            .or_insert_with(|| nanoid::nanoid!())
                                            .clone();
                                        yield Ok(RawStreamingChoice::ToolCallDelta {
                                            id: delta.item_id.clone(),
                                            internal_call_id,
                                            content: streaming::ToolCallDeltaContent::Delta(delta.delta.clone())
                                        })
                                    }
                                    _ => continue,
                                }
                            }

                            if let responses_api::streaming::StreamingCompletionChunk::Response(chunk) = data {
                                let responses_api::streaming::ResponseChunk { kind, response, .. } = *chunk;
                                match kind {
                                    responses_api::streaming::ResponseChunkKind::ResponseCompleted => {
                                        span.record("gen_ai.response.id", response.id.as_str());
                                        span.record("gen_ai.response.model", response.model.as_str());
                                        if let Some(usage) = response.usage {
                                            final_usage = usage;
                                        }
                                    }
                                    responses_api::streaming::ResponseChunkKind::ResponseFailed
                                    | responses_api::streaming::ResponseChunkKind::ResponseIncomplete => {
                                        let error = response
                                            .error
                                            .as_ref()
                                            .map(|err| err.message.clone())
                                            .unwrap_or_else(|| "Copilot response stream failed".into());
                                        terminated_with_error = true;
                                        yield Err(CompletionError::ProviderError(error));
                                        break;
                                    }
                                    _ => continue,
                                }
                            }
                        }
                        Err(crate::http_client::Error::StreamEnded) => {
                            break;
                        }
                        Err(error) => {
                            terminated_with_error = true;
                            yield Err(CompletionError::ProviderError(error.to_string()));
                            break;
                        }
                    }
                }

                event_source.close();

                if terminated_with_error {
                    return;
                }

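                // Completed tool calls were buffered during the stream; emit
                // them only after a clean termination.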
                for tool_call in &tool_calls {
                    yield Ok(tool_call.to_owned())
                }

                span.record("gen_ai.usage.input_tokens", final_usage.input_tokens);
                span.record("gen_ai.usage.output_tokens", final_usage.output_tokens);
                span.record(
                    "gen_ai.usage.cache_read.input_tokens",
                    final_usage
                        .input_tokens_details
                        .as_ref()
                        .map(|details| details.cached_tokens)
                        .unwrap_or(0),
                );

                yield Ok(RawStreamingChoice::FinalResponse(
                    CopilotStreamingResponse::Responses(
                        responses_api::streaming::StreamingCompletionResponse { usage: final_usage }
                    )
                ));
            },
            span,
        );

        Ok(StreamingCompletionResponse::stream(Box::pin(stream)))
    }
}

impl<H> completion::CompletionModel for CompletionModel<H>
where
    Client<H>: HttpClientExt + Clone + Debug + 'static,
    H: Clone + Default + Debug + WasmCompatSend + WasmCompatSync + 'static,
{
    type Response = CopilotCompletionResponse;
    type StreamingResponse = CopilotStreamingResponse;
    type Client = Client<H>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    async fn completion(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
        match self.route() {
            CompletionRoute::ChatCompletions => self.completion_chat(completion_request).await,
            CompletionRoute::Responses => self.completion_responses(completion_request).await,
        }
    }

    async fn stream(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        match self.route() {
            CompletionRoute::ChatCompletions => self.stream_chat(completion_request).await,
            CompletionRoute::Responses => self.stream_responses(completion_request).await,
        }
    }
}

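/// Embedding model handle for `/embeddings`.
///
/// A minimal usage sketch; the `embedding_model` accessor name is assumed by
/// analogy with `completion_model`:
///
/// ```ignore
/// let model = client.embedding_model(copilot::TEXT_EMBEDDING_3_SMALL);
/// let embeddings = model.embed_texts(vec!["hello world".to_string()]).await?;
/// ```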
#[derive(Clone)]
pub struct EmbeddingModel<H = reqwest::Client> {
    client: Client<H>,
    pub model: String,
    pub encoding_format: Option<openai::EncodingFormat>,
    pub user: Option<String>,
    ndims: usize,
}

#[derive(Deserialize)]
struct CopilotEmbeddingResponse {
    data: Vec<CopilotEmbeddingData>,
}

#[derive(Deserialize)]
struct CopilotEmbeddingData {
    embedding: Vec<serde_json::Number>,
}

impl<H> EmbeddingModel<H>
where
    Client<H>: HttpClientExt + Clone + Debug + 'static,
    H: Clone + Default + Debug + 'static,
{
    pub fn new(client: Client<H>, model: impl Into<String>, ndims: usize) -> Self {
        Self {
            client,
            model: model.into(),
            encoding_format: None,
            user: None,
            ndims,
        }
    }
}

impl<H> embeddings::EmbeddingModel for EmbeddingModel<H>
where
    Client<H>: HttpClientExt + Clone + Debug + WasmCompatSend + WasmCompatSync + 'static,
    H: Clone + Default + Debug + WasmCompatSend + WasmCompatSync + 'static,
{
    const MAX_DOCUMENTS: usize = 1024;
    type Client = Client<H>;

    fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
        let model = model.into();
        let dims = ndims.unwrap_or(match model.as_str() {
            TEXT_EMBEDDING_3_LARGE => 3072,
            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
            _ => 0,
        });
        Self::new(client.clone(), model, dims)
    }

    fn ndims(&self) -> usize {
        self.ndims
    }

    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let documents = documents.into_iter().collect::<Vec<_>>();
        let auth = self
            .client
            .ext()
            .auth
            .auth_context()
            .await
            .map_err(|err| EmbeddingError::ProviderError(err.to_string()))?;

        let headers = default_headers(&auth.api_key, "user", false);
        let mut body = json!({
            "model": self.model,
            "input": documents,
        });

        let body_object = body.as_object_mut().ok_or_else(|| {
            EmbeddingError::ResponseError("embedding request body must be a JSON object".into())
        })?;

        if self.ndims > 0 && self.model.as_str() != TEXT_EMBEDDING_ADA_002 {
            body_object.insert("dimensions".to_owned(), json!(self.ndims));
        }
        if let Some(encoding_format) = &self.encoding_format {
            body_object.insert("encoding_format".to_owned(), json!(encoding_format));
        }
        if let Some(user) = &self.user {
            body_object.insert("user".to_owned(), json!(user));
        }

        let req = apply_headers(
            post_with_auth_base(&self.client, &auth, "/embeddings", Transport::Http)?,
            &headers,
        )
        .body(serde_json::to_vec(&body)?)
        .map_err(|err| EmbeddingError::HttpError(err.into()))?;

        let response = self.client.send(req).await?;
        if response.status().is_success() {
            let body: Vec<u8> = response.into_body().await?;
            #[derive(Deserialize)]
            struct NestedApiError {
                error: NestedApiErrorMessage,
            }

            #[derive(Deserialize)]
            struct NestedApiErrorMessage {
                message: String,
            }

            let body: CopilotEmbeddingResponse = match serde_json::from_slice(&body) {
                Ok(parsed) => parsed,
                Err(parse_error) => {
                    if let Ok(err) = serde_json::from_slice::<NestedApiError>(&body) {
                        return Err(EmbeddingError::ProviderError(err.error.message));
                    }

                    let preview = String::from_utf8_lossy(&body);
                    let preview = if preview.len() > 512 {
                        // Back up to a char boundary so slicing cannot panic
                        // on multi-byte UTF-8 content.
                        let mut cut = 512;
                        while !preview.is_char_boundary(cut) {
                            cut -= 1;
                        }
                        format!("{}...", &preview[..cut])
                    } else {
                        preview.into_owned()
                    };

                    return Err(EmbeddingError::ProviderError(format!(
                        "Failed to parse Copilot embeddings response: {parse_error}; body: {preview}"
                    )));
                }
            };

            Ok(body
                .data
                .into_iter()
                .zip(documents.into_iter())
                .map(|(embedding, document)| embeddings::Embedding {
                    document,
                    vec: embedding
                        .embedding
                        .into_iter()
                        .filter_map(|n| n.as_f64())
                        .collect(),
                })
                .collect())
        } else {
            let text = http_client::text(response).await?;
            Err(EmbeddingError::ProviderError(text))
        }
    }
}

const MODEL_LISTING_PATH: &str = "/models";
const MODEL_LISTING_PROVIDER: &str = "Copilot";

#[derive(Debug, Deserialize)]
struct ListModelsResponse {
    data: Vec<ListModelEntry>,
}

#[derive(Debug, Deserialize)]
struct ListModelEntry {
    id: String,
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    vendor: Option<String>,
    #[serde(default)]
    capabilities: Option<ListModelEntryCapabilities>,
}

#[derive(Debug, Deserialize)]
struct ListModelEntryCapabilities {
    #[serde(default, rename = "type")]
    r#type: Option<String>,
}

impl From<ListModelEntry> for Model {
    fn from(value: ListModelEntry) -> Self {
        let mut model = Model::from_id(value.id);
        model.name = value.name;
        model.owned_by = value.vendor;
        if let Some(caps) = value.capabilities {
            model.r#type = caps.r#type;
        }
        model
    }
}

/// [`ModelLister`] implementation for the GitHub Copilot API (`GET /models`).
#[derive(Clone)]
pub struct CopilotModelLister<H = reqwest::Client> {
    client: Client<H>,
}

impl<H> ModelLister<H> for CopilotModelLister<H>
where
    H: HttpClientExt + Clone + Debug + Default + WasmCompatSend + WasmCompatSync + 'static,
{
    type Client = Client<H>;

    fn new(client: Self::Client) -> Self {
        Self { client }
    }

    async fn list_all(&self) -> Result<ModelList, ModelListingError> {
        let auth = self.client.ext().auth.auth_context().await.map_err(|err| {
            ModelListingError::AuthError {
                message: err.to_string(),
            }
        })?;

        let headers = default_headers(&auth.api_key, "user", false);
        let req = apply_headers(
            get_with_auth_base(&self.client, &auth, MODEL_LISTING_PATH, Transport::Http)?,
            &headers,
        )
        .body(http_client::NoBody)?;

        let response = self.client.send::<_, Vec<u8>>(req).await?;

        if !response.status().is_success() {
            let status_code = response.status().as_u16();
            let body = response.into_body().await?;
            return Err(ModelListingError::api_error_with_context(
                MODEL_LISTING_PROVIDER,
                MODEL_LISTING_PATH,
                status_code,
                &body,
            ));
        }

        let body = response.into_body().await?;
        let api_resp: ListModelsResponse = serde_json::from_slice(&body).map_err(|error| {
            ModelListingError::parse_error_with_context(
                MODEL_LISTING_PROVIDER,
                MODEL_LISTING_PATH,
                &error,
                &body,
            )
        })?;
        let models = api_resp.data.into_iter().map(Model::from).collect();

        Ok(ModelList::new(models))
    }
}

#[derive(Deserialize, Debug)]
struct ChatStreamingFunction {
    name: Option<String>,
    arguments: Option<String>,
}

#[derive(Deserialize, Debug)]
struct ChatStreamingToolCall {
    index: usize,
    id: Option<String>,
    function: ChatStreamingFunction,
}

impl From<&ChatStreamingToolCall> for CompatibleToolCallChunk {
    fn from(value: &ChatStreamingToolCall) -> Self {
        Self {
            index: value.index,
            id: value.id.clone(),
            name: value.function.name.clone(),
            arguments: value.function.arguments.clone(),
        }
    }
}

#[derive(Deserialize, Debug, Default)]
struct ChatStreamingDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>,
    #[serde(default, deserialize_with = "crate::json_utils::null_or_vec")]
    tool_calls: Vec<ChatStreamingToolCall>,
}

#[derive(Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")]
enum ChatFinishReason {
    ToolCalls,
    Stop,
    ContentFilter,
    Length,
    #[serde(untagged)]
    Other(String),
}

#[derive(Deserialize, Debug)]
struct ChatStreamingChoice {
    delta: ChatStreamingDelta,
    finish_reason: Option<ChatFinishReason>,
}

#[derive(Deserialize, Debug)]
struct ChatStreamingChunk {
    id: Option<String>,
    model: Option<String>,
    choices: Vec<ChatStreamingChoice>,
    usage: Option<openai::completion::Usage>,
}

1439#[derive(Clone, Copy)]
1440struct CopilotChatCompatibleProfile;
1441
1442impl CompatibleStreamProfile for CopilotChatCompatibleProfile {
1443    type Usage = openai::completion::Usage;
1444    type Detail = ();
1445    type FinalResponse = CopilotStreamingResponse;
1446
1447    fn normalize_chunk(
1448        &self,
1449        data: &str,
1450    ) -> Result<Option<CompatibleChunk<Self::Usage, Self::Detail>>, CompletionError> {
1451        let data = match serde_json::from_str::<ChatStreamingChunk>(data) {
1452            Ok(data) => data,
1453            Err(error) => {
1454                tracing::debug!(?error, "Couldn't parse Copilot chat SSE payload");
1455                return Ok(None);
1456            }
1457        };
1458
1459        Ok(Some(
1460            openai_chat_completions_compatible::normalize_first_choice_chunk(
1461                data.id,
1462                data.model,
1463                data.usage,
1464                &data.choices,
1465                |choice| CompatibleChoiceData {
1466                    finish_reason: if choice.finish_reason == Some(ChatFinishReason::ToolCalls) {
1467                        CompatibleFinishReason::ToolCalls
1468                    } else {
1469                        CompatibleFinishReason::Other
1470                    },
1471                    text: choice.delta.content.clone(),
1472                    reasoning: choice.delta.reasoning_content.clone(),
1473                    tool_calls: openai_chat_completions_compatible::tool_call_chunks(
1474                        &choice.delta.tool_calls,
1475                    ),
1476                    details: Vec::new(),
1477                },
1478            ),
1479        ))
1480    }
1481
1482    fn build_final_response(&self, usage: Self::Usage) -> Self::FinalResponse {
1483        CopilotStreamingResponse::Chat(openai::completion::streaming::StreamingCompletionResponse {
1484            usage,
1485        })
1486    }
1487
1488    fn uses_distinct_tool_call_eviction(&self) -> bool {
1489        true
1490    }
1491}
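
// Illustrative SSE `data:` payload accepted by `normalize_chunk` above (field
// names mirror the serde structs; the values are assumptions). Payloads that
// fail to parse are logged at debug level and skipped via `Ok(None)` instead
// of aborting the stream:
//
//     {"id":"chatcmpl-1","model":"gpt-4o",
//      "choices":[{"delta":{"content":"Hi"},"finish_reason":null}],
//      "usage":null}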

/// Streams a Copilot chat request over the shared OpenAI-compatible SSE
/// transport, normalizing chunks with [`CopilotChatCompatibleProfile`].
async fn send_copilot_chat_streaming_request<T>(
    http_client: T,
    req: Request<Vec<u8>>,
) -> Result<StreamingCompletionResponse<CopilotStreamingResponse>, CompletionError>
where
    T: HttpClientExt + Clone + 'static,
{
    openai_chat_completions_compatible::send_compatible_streaming_request(
        http_client,
        req,
        CopilotChatCompatibleProfile,
    )
    .await
}

/// Default directory for Copilot token storage: `<config_dir>/github_copilot`.
fn default_token_dir() -> Option<PathBuf> {
    config_dir().map(|dir| dir.join("github_copilot"))
}

/// Platform configuration directory: `%APPDATA%` on Windows, otherwise
/// `$XDG_CONFIG_HOME`, falling back to `$HOME/.config`.
fn config_dir() -> Option<PathBuf> {
    #[cfg(target_os = "windows")]
    {
        std::env::var_os("APPDATA").map(PathBuf::from)
    }

    #[cfg(not(target_os = "windows"))]
    {
        std::env::var_os("XDG_CONFIG_HOME")
            .map(PathBuf::from)
            .or_else(|| std::env::var_os("HOME").map(|home| PathBuf::from(home).join(".config")))
    }
}

#[cfg(test)]
mod tests {
    use super::{
        ChatApiErrorResponse, ChatCompletionResponse, Client, CompletionRoute,
        TEXT_EMBEDDING_3_SMALL, env_api_key, env_base_url, env_github_access_token,
        route_for_model,
    };
    use crate::client::CompletionClient;
    use crate::completion::CompletionModel;
    use crate::http_client;
    use crate::providers::internal::openai_chat_completions_compatible::test_support::{
        sse_bytes_from_data_lines, sse_bytes_from_json_events,
    };
    use crate::streaming::StreamedAssistantContent;
    use crate::test_utils::{
        MockStreamingClient, RecordingHttpClient, SequencedStreamingHttpClient,
    };
    use futures::StreamExt;
    use std::collections::HashMap;

    fn env_map(entries: &[(&str, &str)]) -> HashMap<String, String> {
        entries
            .iter()
            .map(|(key, value)| ((*key).to_string(), (*value).to_string()))
            .collect()
    }

    fn minimal_chat_response() -> &'static str {
        r#"{
            "id": "chatcmpl-123",
            "model": "gpt-4o",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "hello"
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 4,
                "total_tokens": 7
            }
        }"#
    }

    fn minimal_responses_response() -> &'static str {
        r#"{
            "id": "resp_123",
            "object": "response",
            "created_at": 1700000000,
            "status": "completed",
            "error": null,
            "incomplete_details": null,
            "instructions": null,
            "max_output_tokens": null,
            "model": "gpt-5.3-codex",
            "usage": {
                "input_tokens": 4,
                "input_tokens_details": {
                    "cached_tokens": 0
                },
                "output_tokens": 3,
                "output_tokens_details": {
                    "reasoning_tokens": 0
                },
                "total_tokens": 7
            },
            "output": [{
                "type": "message",
                "id": "msg_123",
                "role": "assistant",
                "status": "completed",
                "content": [{
                    "type": "output_text",
                    "text": "hello"
                }]
            }],
            "tools": []
        }"#
    }

    fn minimal_embeddings_response() -> &'static str {
        r#"{
            "data": [
                {
                    "embedding": [0.1, 0.2, 0.3]
                },
                {
                    "embedding": [0.4, 0.5, 0.6]
                }
            ]
        }"#
    }
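
    // Note: compared with stock OpenAI payloads, the fixtures above omit
    // fields such as `object`, `created`, and per-item `index`; the tests
    // below assert the provider tolerates these minimal Copilot shapes.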

    #[test]
    fn deserialize_standard_openai_response() {
        let json = r#"{
            "id": "chatcmpl-abc123",
            "object": "chat.completion",
            "created": 1700000000,
            "model": "gpt-4o",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Hello!"
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        }"#;

        let response: ChatCompletionResponse =
            serde_json::from_str(json).expect("standard OpenAI response should deserialize");
        assert_eq!(response.id, "chatcmpl-abc123");
        assert_eq!(response.object.as_deref(), Some("chat.completion"));
        assert_eq!(response.created, Some(1700000000));
        assert_eq!(response.model, "gpt-4o");
        assert_eq!(response.choices.len(), 1);
        assert_eq!(response.choices[0].finish_reason.as_deref(), Some("stop"));
    }

    #[test]
    fn deserialize_copilot_response_without_object_and_created() {
        let response: ChatCompletionResponse = serde_json::from_str(minimal_chat_response())
            .expect("Copilot response should deserialize");

        assert_eq!(response.id, "chatcmpl-123");
        assert_eq!(response.object, None);
        assert_eq!(response.created, None);
        assert_eq!(response.model, "gpt-4o");
        assert_eq!(response.choices.len(), 1);
    }

    #[test]
    fn deserialize_copilot_response_without_finish_reason() {
        let json = r#"{
            "id": "chatcmpl-claude-001",
            "model": "claude-3.5-sonnet",
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": "Here is my analysis."
                }
            }],
            "usage": {
                "prompt_tokens": 50,
                "total_tokens": 80
            }
        }"#;

        let response: ChatCompletionResponse =
            serde_json::from_str(json).expect("Claude-via-Copilot response should deserialize");

        assert_eq!(response.model, "claude-3.5-sonnet");
        assert_eq!(response.choices[0].finish_reason, None);
        assert_eq!(response.choices[0].index, 0);
    }

    #[test]
    fn error_response_with_message_field() {
        let json = r#"{"message": "rate limit exceeded"}"#;
        let err: ChatApiErrorResponse = serde_json::from_str(json).expect("message-shaped error");

        assert_eq!(err.error_message(), "rate limit exceeded");
    }

    #[test]
    fn error_response_with_error_field() {
        let json = r#"{"error": "model not found"}"#;
        let err: ChatApiErrorResponse = serde_json::from_str(json).expect("error-shaped error");

        assert_eq!(err.error_message(), "model not found");
    }

    #[test]
    fn routes_codex_models_to_responses() {
        assert_eq!(route_for_model("gpt-5.3-codex"), CompletionRoute::Responses);
        assert_eq!(
            route_for_model("gpt-5.1-CODEX-mini"),
            CompletionRoute::Responses
        );
        assert_eq!(route_for_model("gpt-5.2"), CompletionRoute::ChatCompletions);
        assert_eq!(
            route_for_model("claude-sonnet-4.5"),
            CompletionRoute::ChatCompletions
        );
    }

    #[tokio::test]
    async fn completion_model_routes_chat_requests_to_chat_completions() {
        let http_client = RecordingHttpClient::new(minimal_chat_response());
        let client = Client::builder()
            .api_key("copilot-token")
            .http_client(http_client.clone())
            .build()
            .expect("build client");
        let model = client.completion_model("gpt-4o");
        let request = model.completion_request("hello").build();

        let _response = model.completion(request).await.expect("chat completion");

        let requests = http_client.requests();
        assert_eq!(requests.len(), 1);
        assert!(requests[0].uri.ends_with("/chat/completions"));
        assert!(String::from_utf8_lossy(&requests[0].body).contains("\"model\":\"gpt-4o\""));
    }

    #[tokio::test]
    async fn completion_model_routes_codex_requests_to_responses() {
        let http_client = RecordingHttpClient::new(minimal_responses_response());
        let client = Client::builder()
            .api_key("copilot-token")
            .http_client(http_client.clone())
            .build()
            .expect("build client");
        let model = client.completion_model("gpt-5.3-codex");
        let request = model.completion_request("hello").build();

        let _response = model
            .completion(request)
            .await
            .expect("responses completion");

        let requests = http_client.requests();
        assert_eq!(requests.len(), 1);
        assert!(requests[0].uri.ends_with("/responses"));
        assert!(String::from_utf8_lossy(&requests[0].body).contains("\"model\":\"gpt-5.3-codex\""));
    }

    #[tokio::test]
    async fn embeddings_accept_minimal_copilot_response_shape() {
        use crate::client::EmbeddingsClient;
        use crate::embeddings::EmbeddingModel as _;

        let http_client = RecordingHttpClient::new(minimal_embeddings_response());
        let client = Client::builder()
            .api_key("copilot-token")
            .http_client(http_client.clone())
            .build()
            .expect("build client");
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);

        let embeddings = model
            .embed_texts(["one".to_string(), "two".to_string()])
            .await
            .expect("embeddings should deserialize");

        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].vec, vec![0.1, 0.2, 0.3]);
        assert_eq!(embeddings[1].vec, vec![0.4, 0.5, 0.6]);

        let requests = http_client.requests();
        assert_eq!(requests.len(), 1);
        assert!(requests[0].uri.ends_with("/embeddings"));
        assert!(
            String::from_utf8_lossy(&requests[0].body)
                .contains("\"model\":\"text-embedding-3-small\"")
        );
    }

    #[tokio::test]
    async fn responses_stream_terminates_after_terminal_error() {
        let tool_call_done = serde_json::json!({
            "type": "response.output_item.done",
            "sequence_number": 1,
            "item": {
                "type": "function_call",
                "id": "fc_123",
                "arguments": "{}",
                "call_id": "call_123",
                "name": "example_tool",
                "status": "completed"
            }
        });
        let failed = serde_json::json!({
            "type": "response.failed",
            "sequence_number": 2,
            "response": {
                "id": "resp_123",
                "object": "response",
                "created_at": 1700000000,
                "status": "failed",
                "error": {
                    "code": "server_error",
                    "message": "Copilot response stream failed"
                },
                "incomplete_details": null,
                "instructions": null,
                "max_output_tokens": null,
                "model": "gpt-5.3-codex",
                "usage": null,
                "output": [],
                "tools": []
            }
        });
        let http_client = MockStreamingClient {
            sse_bytes: sse_bytes_from_json_events(&[tool_call_done, failed]),
        };
        let client = Client::builder()
            .api_key("copilot-token")
            .http_client(http_client)
            .build()
            .expect("build client");
        let model = client.completion_model("gpt-5.3-codex");
        let request = model.completion_request("hello").build();
        let mut stream = model.stream(request).await.expect("stream should start");

        let err = match stream.next().await.expect("stream should yield an item") {
            Ok(_) => panic!("stream should surface a provider error"),
            Err(err) => err,
        };
        assert_eq!(
            err.to_string(),
            "ProviderError: Copilot response stream failed"
        );
        assert!(
            stream.next().await.is_none(),
            "responses stream should terminate immediately after a terminal error"
        );
    }

    #[tokio::test]
    async fn chat_stream_terminates_after_transport_error() {
        let chunks = vec![
            Ok(sse_bytes_from_data_lines([
                "{\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_123\",\"function\":{\"name\":\"ping\",\"arguments\":\"\"}}]},\"finish_reason\":null}],\"usage\":null}",
            ])),
            Err(http_client::Error::InvalidStatusCode(
                http::StatusCode::BAD_GATEWAY,
            )),
        ];

        let http_client = SequencedStreamingHttpClient::new(chunks);
        let client = Client::builder()
            .api_key("copilot-token")
            .http_client(http_client)
            .build()
            .expect("build client");
        let model = client.completion_model("gpt-4o");
        let request = model.completion_request("hello").build();
        let mut stream = model.stream(request).await.expect("stream should start");

        let mut saw_error = false;
        while let Some(item) = stream.next().await {
            match item {
                Ok(StreamedAssistantContent::ToolCallDelta { .. }) => {}
                Err(err) => {
                    assert_eq!(
                        err.to_string(),
                        "ProviderError: Invalid status code: 502 Bad Gateway"
                    );
                    saw_error = true;
                    break;
                }
                Ok(_) => panic!("unexpected non-error stream item before transport failure"),
            }
        }

        assert!(saw_error, "stream should surface the transport error");
        assert!(
            stream.next().await.is_none(),
            "chat stream should terminate immediately after a transport error"
        );
    }

    #[test]
    fn env_api_key_prefers_github_prefixed_vars() {
        let env = env_map(&[
            ("COPILOT_API_KEY", "copilot-key"),
            ("GITHUB_COPILOT_API_KEY", "github-key"),
            ("GITHUB_TOKEN", "bootstrap-token"),
        ]);
        let get = |name: &str| env.get(name).cloned();

        assert_eq!(env_api_key(&get).as_deref(), Some("github-key"));
    }

    #[test]
    fn env_github_access_token_prefers_explicit_bootstrap_var() {
        let env = env_map(&[
            ("COPILOT_GITHUB_ACCESS_TOKEN", "explicit-bootstrap"),
            ("GITHUB_TOKEN", "fallback-bootstrap"),
        ]);
        let get = |name: &str| env.get(name).cloned();

        assert_eq!(
            env_github_access_token(&get).as_deref(),
            Some("explicit-bootstrap")
        );
    }

    #[test]
    fn env_base_url_prefers_github_prefixed_vars() {
        let env = env_map(&[
            ("COPILOT_BASE_URL", "https://copilot.example"),
            ("GITHUB_COPILOT_API_BASE", "https://github.example"),
        ]);
        let get = |name: &str| env.get(name).cloned();

        assert_eq!(
            env_base_url(&get).as_deref(),
            Some("https://github.example")
        );
    }

    #[test]
    fn env_without_api_key_falls_back_to_oauth() {
        let env = env_map(&[("COPILOT_BASE_URL", "https://copilot.example")]);
        let get = |name: &str| env.get(name).cloned();

        assert!(env_api_key(&get).is_none());
        assert!(env_github_access_token(&get).is_none());
        assert_eq!(
            env_base_url(&get).as_deref(),
            Some("https://copilot.example")
        );
    }

    #[test]
    fn env_github_token_is_not_treated_as_copilot_api_key() {
        let env = env_map(&[("GITHUB_TOKEN", "bootstrap-token")]);
        let get = |name: &str| env.get(name).cloned();

        assert!(env_api_key(&get).is_none());
        assert_eq!(
            env_github_access_token(&get).as_deref(),
            Some("bootstrap-token")
        );
    }
}