rig/providers/copilot/mod.rs

//! GitHub Copilot provider.
//!
//! Supports Chat Completions, Responses, and Embeddings against
//! `https://api.githubcopilot.com`.
//!
//! `Client::completion_model(...)` automatically routes Codex-class models
//! through `/responses` and conversational models through
//! `/chat/completions`.
//!
//! # Example
//! ```no_run
//! use rig::client::{CompletionClient, ProviderClient};
//! use rig::providers::copilot;
//!
//! # fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let client = copilot::Client::from_env()?;
//! let model = client.completion_model(copilot::GPT_4O);
//! # let _ = model;
//! # Ok(())
//! # }
//! ```
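//!
//! # Authentication
//!
//! `from_env` looks for an API key, then a GitHub access token, and finally
//! falls back to GitHub's device-code OAuth flow. A minimal sketch of the
//! explicit OAuth route, using the builder hooks defined in this module:
//! ```no_run
//! use rig::providers::copilot;
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let client = copilot::Client::builder()
//!     .oauth()
//!     .on_device_code(|prompt| {
//!         // Surface the verification URL and user code to the user here.
//!         let _ = prompt;
//!     })
//!     .build()?;
//!
//! // Runs the device-code flow (or reuses a cached token) before any request.
//! client.authorize().await?;
//! # Ok(())
//! # }
//! ```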

mod auth;

use crate::client::{
    self, ApiKey, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
    ProviderClient, Transport,
};
use crate::completion::{self, CompletionError, GetTokenUsage};
use crate::embeddings::{self, EmbeddingError};
use crate::http_client::{self, HttpClientExt};
use crate::providers::internal::openai_chat_completions_compatible::{
    self, CompatibleChoiceData, CompatibleChunk, CompatibleFinishReason, CompatibleStreamProfile,
    CompatibleToolCallChunk,
};
use crate::providers::openai;
use crate::providers::openai::responses_api::{self, CompletionRequest as ResponsesRequest};
use crate::streaming::{self, RawStreamingChoice, StreamingCompletionResponse};
use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
use async_stream::stream;
use futures::StreamExt;
use http::Request;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt::Debug;
use std::path::{Path, PathBuf};
use tracing::info_span;
use tracing_futures::Instrument as _;

const GITHUB_COPILOT_API_BASE_URL: &str = "https://api.githubcopilot.com";
const EDITOR_PLUGIN_VERSION: &str = "copilot-chat/0.26.7";
const USER_AGENT: &str = "GitHubCopilotChat/0.26.7";
const API_VERSION: &str = "2025-04-01";

/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4.1` completion model
pub const GPT_4_1: &str = "gpt-4.1";
/// `gpt-4.1-mini` completion model
pub const GPT_4_1_MINI: &str = "gpt-4.1-mini";
/// `gpt-4.1-nano` completion model
pub const GPT_4_1_NANO: &str = "gpt-4.1-nano";
/// `gpt-5.3-codex` completion model (routed through `/responses`)
pub const GPT_5_3_CODEX: &str = "gpt-5.3-codex";
/// `gpt-5.1-codex` completion model (routed through `/responses`)
pub const GPT_5_1_CODEX: &str = "gpt-5.1-codex";
/// `claude-sonnet-4` completion model (Anthropic, via Copilot)
pub const CLAUDE_SONNET_4: &str = "claude-sonnet-4";
/// `claude-3.5-sonnet` completion model (Anthropic, via Copilot)
pub const CLAUDE_3_5_SONNET: &str = "claude-3.5-sonnet";
/// `gemini-2.0-flash-001` completion model (Google, via Copilot)
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash-001";
/// `o3-mini` reasoning model (OpenAI, via Copilot)
pub const O3_MINI: &str = "o3-mini";
/// `text-embedding-3-small` embedding model
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-3-large` embedding model
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-ada-002` embedding model
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";

pub use openai::EncodingFormat;

/// Credential used to authenticate against the Copilot API.
#[derive(Clone)]
pub enum CopilotAuth {
    /// A Copilot API key, used directly.
    ApiKey(String),
    /// A GitHub access token, exchanged for a Copilot API key.
    GitHubAccessToken(String),
    /// GitHub's interactive device-code OAuth flow.
    OAuth,
}

impl ApiKey for CopilotAuth {}

impl<S> From<S> for CopilotAuth
where
    S: Into<String>,
{
    fn from(value: S) -> Self {
        Self::ApiKey(value.into())
    }
}

impl Debug for CopilotAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ApiKey(_) => f.write_str("ApiKey(<redacted>)"),
            Self::GitHubAccessToken(_) => f.write_str("GitHubAccessToken(<redacted>)"),
            Self::OAuth => f.write_str("OAuth"),
        }
    }
}

/// Builder-side configuration: credential-cache locations and the
/// device-code prompt handler.
#[derive(Debug, Clone)]
pub struct CopilotBuilder {
    access_token_file: Option<PathBuf>,
    api_key_file: Option<PathBuf>,
    device_code_handler: auth::DeviceCodeHandler,
}

/// Provider extension holding the Copilot authenticator.
#[derive(Clone)]
pub struct CopilotExt {
    auth: auth::Authenticator,
}

impl Debug for CopilotExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CopilotExt")
            .field("auth", &self.auth)
            .finish()
    }
}

pub type Client<H = reqwest::Client> = client::Client<CopilotExt, H>;
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<CopilotBuilder, CopilotAuth, H>;

impl Default for CopilotBuilder {
    fn default() -> Self {
        let token_dir = default_token_dir();
        Self {
            access_token_file: token_dir.as_ref().map(|dir| dir.join("access-token")),
            api_key_file: token_dir.map(|dir| dir.join("api-key.json")),
            device_code_handler: auth::DeviceCodeHandler::default(),
        }
    }
}

impl Provider for CopilotExt {
    type Builder = CopilotBuilder;

    const VERIFY_PATH: &'static str = "";
}

impl<H> Capabilities<H> for CopilotExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for CopilotExt {}

impl ProviderBuilder for CopilotBuilder {
    type Extension<H>
        = CopilotExt
    where
        H: HttpClientExt;
    type ApiKey = CopilotAuth;

    const BASE_URL: &'static str = GITHUB_COPILOT_API_BASE_URL;

    fn build<H>(
        builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        let auth = match builder.get_api_key() {
            CopilotAuth::ApiKey(api_key) => auth::AuthSource::ApiKey(api_key.clone()),
            CopilotAuth::GitHubAccessToken(access_token) => {
                auth::AuthSource::GitHubAccessToken(access_token.clone())
            }
            CopilotAuth::OAuth => auth::AuthSource::OAuth,
        };

        let ext = builder.ext();
        Ok(CopilotExt {
            auth: auth::Authenticator::new(
                auth,
                ext.access_token_file.clone(),
                ext.api_key_file.clone(),
                ext.device_code_handler.clone(),
            ),
        })
    }
}

impl ProviderClient for Client {
    type Input = CopilotAuth;
    type Error = crate::client::ProviderClientError;

    /// Builds a client from the environment.
    ///
    /// Reads `GITHUB_COPILOT_API_BASE`/`COPILOT_BASE_URL` for a base-URL
    /// override, then tries `GITHUB_COPILOT_API_KEY`/`COPILOT_API_KEY` for an
    /// API key, then `COPILOT_GITHUB_ACCESS_TOKEN`/`GITHUB_TOKEN` for a GitHub
    /// access token, and finally falls back to the device-code OAuth flow.
    fn from_env() -> Result<Self, Self::Error> {
        let mut builder = Self::builder();
        fn get(name: &str) -> Option<String> {
            std::env::var(name).ok()
        }

        if let Some(base_url) = env_base_url(&get) {
            builder = builder.base_url(base_url);
        }

        if let Some(api_key) = env_api_key(&get) {
            builder.api_key(api_key).build().map_err(Into::into)
        } else if let Some(access_token) = env_github_access_token(&get) {
            builder
                .github_access_token(access_token)
                .build()
                .map_err(Into::into)
        } else {
            builder.oauth().build().map_err(Into::into)
        }
    }

    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
        Self::builder().api_key(input).build().map_err(Into::into)
    }
}
impl<H> client::ClientBuilder<CopilotBuilder, crate::markers::Missing, H> {
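    /// Authenticates with a GitHub access token, which the authenticator
    /// exchanges for a Copilot API key. A minimal sketch, reading the same
    /// `GITHUB_TOKEN` variable that `from_env` consults:
    ///
    /// ```no_run
    /// use rig::providers::copilot;
    ///
    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let client = copilot::Client::builder()
    ///     .github_access_token(std::env::var("GITHUB_TOKEN")?)
    ///     .build()?;
    /// # let _ = client;
    /// # Ok(())
    /// # }
    /// ```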
    pub fn github_access_token(
        self,
        access_token: impl Into<String>,
    ) -> client::ClientBuilder<CopilotBuilder, CopilotAuth, H> {
        self.api_key(CopilotAuth::GitHubAccessToken(access_token.into()))
    }

    /// Authenticates via GitHub's device-code OAuth flow instead of a
    /// pre-provisioned key or token.
    pub fn oauth(self) -> client::ClientBuilder<CopilotBuilder, CopilotAuth, H> {
        self.api_key(CopilotAuth::OAuth)
    }
}

impl<H> ClientBuilder<H> {
    /// Registers a callback that receives the device-code prompt during the
    /// OAuth flow.
    pub fn on_device_code<F>(self, handler: F) -> Self
    where
        F: Fn(auth::DeviceCodePrompt) + Send + Sync + 'static,
    {
        self.over_ext(|mut ext| {
            ext.device_code_handler = auth::DeviceCodeHandler::new(handler);
            ext
        })
    }

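    /// Stores both cached credential files (`access-token` and
    /// `api-key.json`) under `path` instead of the default `github_copilot`
    /// directory inside the platform config dir. A minimal sketch:
    ///
    /// ```no_run
    /// use rig::providers::copilot;
    ///
    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let client = copilot::Client::builder()
    ///     .oauth()
    ///     .token_dir("/tmp/copilot-tokens")
    ///     .build()?;
    /// # let _ = client;
    /// # Ok(())
    /// # }
    /// ```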
    pub fn token_dir(self, path: impl AsRef<Path>) -> Self {
        let path = path.as_ref();
        self.over_ext(|mut ext| {
            ext.access_token_file = Some(path.join("access-token"));
            ext.api_key_file = Some(path.join("api-key.json"));
            ext
        })
    }

    /// Overrides only the cached GitHub access-token file location.
    pub fn access_token_file(self, path: impl AsRef<Path>) -> Self {
        let path = path.as_ref().to_path_buf();
        self.over_ext(|mut ext| {
            ext.access_token_file = Some(path);
            ext
        })
    }

    /// Overrides only the cached Copilot API-key file location.
    pub fn api_key_file(self, path: impl AsRef<Path>) -> Self {
        let path = path.as_ref().to_path_buf();
        self.over_ext(|mut ext| {
            ext.api_key_file = Some(path);
            ext
        })
    }
}

fn env_value<F>(get: &F, name: &str) -> Option<String>
where
    F: Fn(&str) -> Option<String>,
{
    get(name).filter(|value| !value.trim().is_empty())
}

fn first_env_value<F>(get: &F, keys: &[&str]) -> Option<String>
where
    F: Fn(&str) -> Option<String>,
{
    keys.iter().find_map(|key| env_value(get, key))
}

fn env_api_key<F>(get: &F) -> Option<String>
where
    F: Fn(&str) -> Option<String>,
{
    first_env_value(get, &["GITHUB_COPILOT_API_KEY", "COPILOT_API_KEY"])
}

fn env_github_access_token<F>(get: &F) -> Option<String>
where
    F: Fn(&str) -> Option<String>,
{
    first_env_value(get, &["COPILOT_GITHUB_ACCESS_TOKEN", "GITHUB_TOKEN"])
}

fn env_base_url<F>(get: &F) -> Option<String>
where
    F: Fn(&str) -> Option<String>,
{
    first_env_value(get, &["GITHUB_COPILOT_API_BASE", "COPILOT_BASE_URL"])
}

impl<H> Client<H>
where
    H: HttpClientExt + Clone + Debug + Default + WasmCompatSend + WasmCompatSync + 'static,
{
    /// Eagerly authenticates, running the device-code flow if required, so
    /// that later requests don't have to.
    pub async fn authorize(&self) -> Result<(), auth::AuthError> {
        self.ext().auth.auth_context().await.map(|_| ())
    }
}

fn default_headers(
    api_key: &str,
    initiator: &'static str,
    has_vision: bool,
) -> Vec<(&'static str, String)> {
    let mut headers = vec![
        (
            http::header::AUTHORIZATION.as_str(),
            format!("Bearer {api_key}"),
        ),
        ("copilot-integration-id", "vscode-chat".to_string()),
        ("editor-version", "vscode/1.95.0".to_string()),
        ("editor-plugin-version", EDITOR_PLUGIN_VERSION.to_string()),
        ("user-agent", USER_AGENT.to_string()),
        ("openai-intent", "conversation-panel".to_string()),
        ("x-github-api-version", API_VERSION.to_string()),
        ("x-request-id", nanoid::nanoid!()),
        (
            "x-vscode-user-agent-library-version",
            "electron-fetch".to_string(),
        ),
        ("X-Initiator", initiator.to_string()),
    ];

    if has_vision {
        headers.push(("copilot-vision-request", "true".to_string()));
    }

    headers
}

fn apply_headers(
    builder: http_client::Builder,
    headers: &[(&'static str, String)],
) -> http_client::Builder {
    headers
        .iter()
        .fold(builder, |builder, (key, value)| builder.header(*key, value))
}

/// Prefers the API base advertised by the auth handshake unless the client
/// was configured with a custom base URL.
fn runtime_base_url<'a, H>(client: &'a Client<H>, auth: &'a auth::AuthContext) -> Cow<'a, str> {
    if client.base_url() == GITHUB_COPILOT_API_BASE_URL {
        auth.api_base
            .as_deref()
            .map(Cow::Borrowed)
            .unwrap_or_else(|| Cow::Borrowed(client.base_url()))
    } else {
        Cow::Borrowed(client.base_url())
    }
}

fn post_with_auth_base<H>(
    client: &Client<H>,
    auth: &auth::AuthContext,
    path: &str,
    transport: Transport,
) -> http_client::Result<http_client::Builder> {
    let uri = client
        .ext()
        .build_uri(runtime_base_url(client, auth).as_ref(), path, transport);
    let mut req = Request::post(uri);

    if let Some(headers) = req.headers_mut() {
        headers.extend(client.headers().iter().map(|(k, v)| (k.clone(), v.clone())));
    }

    client.ext().with_custom(req)
}

/// Returns `"agent"` when the history already contains assistant turns or
/// tool results, and `"user"` otherwise.
fn request_initiator(request: &completion::CompletionRequest) -> &'static str {
    for message in request.chat_history.iter() {
        match message {
            crate::completion::Message::Assistant { .. } => return "agent",
            crate::completion::Message::User { content } => {
                if content
                    .iter()
                    .any(|item| matches!(item, crate::message::UserContent::ToolResult(_)))
                {
                    return "agent";
                }
            }
            crate::completion::Message::System { .. } => {}
        }
    }

    "user"
}

fn request_has_vision(request: &completion::CompletionRequest) -> bool {
    request.chat_history.iter().any(|message| match message {
        crate::completion::Message::User { content } => content
            .iter()
            .any(|item| matches!(item, crate::message::UserContent::Image(_))),
        _ => false,
    })
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum CompletionRoute {
    ChatCompletions,
    Responses,
}

/// Codex-class models (any name containing `codex`, case-insensitively) are
/// routed through `/responses`; everything else uses `/chat/completions`.
fn route_for_model(model: &str) -> CompletionRoute {
    if model.to_ascii_lowercase().contains("codex") {
        CompletionRoute::Responses
    } else {
        CompletionRoute::ChatCompletions
    }
}

/// Raw response from either completion API, tagged by origin.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "api", rename_all = "snake_case")]
pub enum CopilotCompletionResponse {
    Chat(ChatCompletionResponse),
    Responses(Box<responses_api::CompletionResponse>),
}

/// Final streaming response from either completion API, tagged by origin.
#[derive(Clone, Serialize, Deserialize)]
#[serde(tag = "api", rename_all = "snake_case")]
pub enum CopilotStreamingResponse {
    Chat(openai::completion::streaming::StreamingCompletionResponse),
    Responses(responses_api::streaming::StreamingCompletionResponse),
}

impl GetTokenUsage for CopilotStreamingResponse {
    fn token_usage(&self) -> Option<completion::Usage> {
        match self {
            Self::Chat(response) => response.token_usage(),
            Self::Responses(response) => response.token_usage(),
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    #[serde(default)]
    pub object: Option<String>,
    #[serde(default)]
    pub created: Option<u64>,
    pub model: String,
    pub system_fingerprint: Option<String>,
    pub choices: Vec<ChatChoice>,
    pub usage: Option<openai::completion::Usage>,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ChatChoice {
    #[serde(default)]
    pub index: usize,
    pub message: openai::completion::Message,
    pub logprobs: Option<serde_json::Value>,
    #[serde(default)]
    pub finish_reason: Option<String>,
}

impl TryFrom<ChatCompletionResponse> for completion::CompletionResponse<ChatCompletionResponse> {
    type Error = CompletionError;

    fn try_from(response: ChatCompletionResponse) -> Result<Self, Self::Error> {
        let choice = response.choices.first().ok_or_else(|| {
            CompletionError::ResponseError("Response contained no choices".to_owned())
        })?;

        let content = match &choice.message {
            openai::completion::Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                let mut content = content
                    .iter()
                    .filter_map(|c| {
                        let s = match c {
                            openai::completion::AssistantContent::Text { text } => text,
                            openai::completion::AssistantContent::Refusal { refusal } => refusal,
                        };
                        if s.is_empty() {
                            None
                        } else {
                            Some(completion::AssistantContent::text(s))
                        }
                    })
                    .collect::<Vec<_>>();

                content.extend(
                    tool_calls
                        .iter()
                        .map(|call| {
                            completion::AssistantContent::tool_call(
                                &call.id,
                                &call.function.name,
                                call.function.arguments.clone(),
                            )
                        })
                        .collect::<Vec<_>>(),
                );
                Ok(content)
            }
            _ => Err(CompletionError::ResponseError(
                "Response did not contain a valid message or tool call".into(),
            )),
        }?;

        let choice = crate::OneOrMany::many(content).map_err(|_| {
            CompletionError::ResponseError(
                "Response contained no message or tool call (empty)".to_owned(),
            )
        })?;

        let usage = response
            .usage
            .as_ref()
            .map(|usage| completion::Usage {
                input_tokens: usage.prompt_tokens as u64,
                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
                total_tokens: usage.total_tokens as u64,
                cached_input_tokens: usage
                    .prompt_tokens_details
                    .as_ref()
                    .map(|d| d.cached_tokens as u64)
                    .unwrap_or(0),
                cache_creation_input_tokens: 0,
            })
            .unwrap_or_default();

        Ok(completion::CompletionResponse {
            choice,
            usage,
            raw_response: response,
            message_id: None,
        })
    }
}

/// Error payload the chat API can return, sometimes with a 200 status.
#[derive(Debug, Deserialize)]
pub struct ChatApiErrorResponse {
    #[serde(default)]
    pub message: Option<String>,
    #[serde(default)]
    pub error: Option<String>,
}

impl ChatApiErrorResponse {
    pub fn error_message(&self) -> &str {
        self.message
            .as_deref()
            .or(self.error.as_deref())
            .unwrap_or("unknown error")
    }
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ChatApiResponse<T> {
    Ok(T),
    Err(ChatApiErrorResponse),
}

/// Completion model that routes to `/chat/completions` or `/responses`
/// based on the model name.
#[derive(Clone)]
pub struct CompletionModel<H = reqwest::Client> {
    client: Client<H>,
    pub model: String,
    pub strict_tools: bool,
    pub tool_result_array_content: bool,
}

impl<H> CompletionModel<H>
where
    Client<H>: HttpClientExt + Clone + Debug + 'static,
    H: Clone + Default + Debug + WasmCompatSend + WasmCompatSync + 'static,
{
    pub fn new(client: Client<H>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
            strict_tools: false,
            tool_result_array_content: false,
        }
    }

    pub fn with_strict_tools(mut self) -> Self {
        self.strict_tools = true;
        self
    }

    pub fn with_tool_result_array_content(mut self) -> Self {
        self.tool_result_array_content = true;
        self
    }

    fn route(&self) -> CompletionRoute {
        route_for_model(&self.model)
    }

    async fn auth_context(&self) -> Result<auth::AuthContext, CompletionError> {
        self.client
            .ext()
            .auth
            .auth_context()
            .await
            .map_err(|err| CompletionError::ProviderError(err.to_string()))
    }

    fn chat_request(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<openai::completion::CompletionRequest, CompletionError> {
        openai::completion::CompletionRequest::try_from(openai::completion::OpenAIRequestParams {
            model: self.model.clone(),
            request: completion_request,
            strict_tools: self.strict_tools,
            tool_result_array_content: self.tool_result_array_content,
        })
    }

    fn responses_request(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<ResponsesRequest, CompletionError> {
        ResponsesRequest::try_from((self.model.clone(), completion_request))
    }

    async fn completion_chat(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<CopilotCompletionResponse>, CompletionError> {
        let initiator = request_initiator(&completion_request);
        let has_vision = request_has_vision(&completion_request);
        let request = self.chat_request(completion_request)?;
        let body = serde_json::to_vec(&request)?;
        let auth = self.auth_context().await?;

        let headers = default_headers(&auth.api_key, initiator, has_vision);
        let req = apply_headers(
            post_with_auth_base(&self.client, &auth, "/chat/completions", Transport::Http)?,
            &headers,
        )
        .body(body)
        .map_err(|err| CompletionError::HttpError(err.into()))?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "copilot",
                gen_ai.request.model = self.model,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        async move {
            let response = self.client.send(req).await?;

            if response.status().is_success() {
                let body = http_client::text(response).await?;
                match serde_json::from_str::<ChatApiResponse<ChatCompletionResponse>>(&body)? {
                    ChatApiResponse::Ok(response) => {
                        let core = completion::CompletionResponse::try_from(response.clone())?;
                        let span = tracing::Span::current();
                        span.record("gen_ai.response.id", response.id.as_str());
                        span.record("gen_ai.response.model", response.model.as_str());
                        if let Some(usage) = &response.usage {
                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
                            span.record(
                                "gen_ai.usage.output_tokens",
                                usage.total_tokens - usage.prompt_tokens,
                            );
                            span.record(
                                "gen_ai.usage.cache_read.input_tokens",
                                usage
                                    .prompt_tokens_details
                                    .as_ref()
                                    .map(|details| details.cached_tokens)
                                    .unwrap_or(0),
                            );
                        }

                        Ok(completion::CompletionResponse {
                            choice: core.choice,
                            usage: core.usage,
                            raw_response: CopilotCompletionResponse::Chat(response),
                            message_id: core.message_id,
                        })
                    }
                    ChatApiResponse::Err(err) => Err(CompletionError::ProviderError(
                        err.error_message().to_string(),
                    )),
                }
            } else {
                let body = http_client::text(response).await?;
                Err(CompletionError::ProviderError(body))
            }
        }
        .instrument(span)
        .await
    }

    async fn completion_responses(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<CopilotCompletionResponse>, CompletionError> {
        let initiator = request_initiator(&completion_request);
        let has_vision = request_has_vision(&completion_request);
        let request = self.responses_request(completion_request)?;
        let auth = self.auth_context().await?;

        let headers = default_headers(&auth.api_key, initiator, has_vision);
        let req = apply_headers(
            post_with_auth_base(&self.client, &auth, "/responses", Transport::Http)?,
            &headers,
        )
        .body(serde_json::to_vec(&request)?)
        .map_err(|err| CompletionError::HttpError(err.into()))?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "copilot",
                gen_ai.request.model = self.model,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        async move {
            let response = self.client.send(req).await?;
            if response.status().is_success() {
                let body = http_client::text(response).await?;
                let response = serde_json::from_str::<responses_api::CompletionResponse>(&body)?;
                let core = completion::CompletionResponse::try_from(response.clone())?;

                let span = tracing::Span::current();
                span.record("gen_ai.response.id", response.id.as_str());
                span.record("gen_ai.response.model", response.model.as_str());
                if let Some(usage) = &response.usage {
                    span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                    span.record("gen_ai.usage.output_tokens", usage.output_tokens);
                    span.record(
                        "gen_ai.usage.cache_read.input_tokens",
                        usage
                            .input_tokens_details
                            .as_ref()
                            .map(|details| details.cached_tokens)
                            .unwrap_or(0),
                    );
                }

                Ok(completion::CompletionResponse {
                    choice: core.choice,
                    usage: core.usage,
                    raw_response: CopilotCompletionResponse::Responses(Box::new(response)),
                    message_id: core.message_id,
                })
            } else {
                let body = http_client::text(response).await?;
                Err(CompletionError::ProviderError(body))
            }
        }
        .instrument(span)
        .await
    }

    async fn stream_chat(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<StreamingCompletionResponse<CopilotStreamingResponse>, CompletionError> {
        let initiator = request_initiator(&completion_request);
        let has_vision = request_has_vision(&completion_request);
        let request = self.chat_request(completion_request)?;
        let auth = self.auth_context().await?;
        let headers = default_headers(&auth.api_key, initiator, has_vision);
        let mut request_json = serde_json::to_value(&request)?;
        let request_object = request_json.as_object_mut().ok_or_else(|| {
            CompletionError::ResponseError("copilot request body must be a JSON object".into())
        })?;
        request_object.insert("stream".to_owned(), json!(true));
        request_object.insert(
            "stream_options".to_owned(),
            json!({ "include_usage": true }),
        );

        let req = apply_headers(
            post_with_auth_base(&self.client, &auth, "/chat/completions", Transport::Sse)?,
            &headers,
        )
        .body(serde_json::to_vec(&request_json)?)
        .map_err(|err| CompletionError::HttpError(err.into()))?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "copilot",
                gen_ai.request.model = self.model,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        tracing::Instrument::instrument(
            send_copilot_chat_streaming_request(self.client.clone(), req),
            span,
        )
        .await
    }

    async fn stream_responses(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<StreamingCompletionResponse<CopilotStreamingResponse>, CompletionError> {
        let initiator = request_initiator(&completion_request);
        let has_vision = request_has_vision(&completion_request);
        let mut request = self.responses_request(completion_request)?;
        request.stream = Some(true);
        let auth = self.auth_context().await?;

        let headers = default_headers(&auth.api_key, initiator, has_vision);
        let req = apply_headers(
            post_with_auth_base(&self.client, &auth, "/responses", Transport::Sse)?,
            &headers,
        )
        .body(serde_json::to_vec(&request)?)
        .map_err(|err| CompletionError::HttpError(err.into()))?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "copilot",
                gen_ai.request.model = self.model,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let client = self.client.clone();
        let mut event_source = crate::http_client::sse::GenericEventSource::new(client, req);

        let stream = tracing_futures::Instrument::instrument(
            stream! {
                let mut final_usage = responses_api::ResponsesUsage::new();
                let mut tool_calls: Vec<streaming::RawStreamingChoice<CopilotStreamingResponse>> = Vec::new();
                let mut tool_call_internal_ids: HashMap<String, String> = HashMap::new();
                let span = tracing::Span::current();

                let mut terminated_with_error = false;

                while let Some(event_result) = event_source.next().await {
                    match event_result {
                        Ok(crate::http_client::sse::Event::Open) => continue,
                        Ok(crate::http_client::sse::Event::Message(evt)) => {
                            if evt.data.trim().is_empty() {
                                continue;
                            }

                            let Ok(data) = serde_json::from_str::<responses_api::streaming::StreamingCompletionChunk>(&evt.data) else {
                                continue;
                            };

                            if let responses_api::streaming::StreamingCompletionChunk::Delta(chunk) = &data {
                                use responses_api::streaming::{ItemChunkKind, StreamingItemDoneOutput};

                                match &chunk.data {
                                    ItemChunkKind::OutputItemAdded(message) => {
                                        if let StreamingItemDoneOutput { item: responses_api::Output::FunctionCall(func), .. } = message {
                                            let internal_call_id = tool_call_internal_ids
                                                .entry(func.id.clone())
                                                .or_insert_with(|| nanoid::nanoid!())
                                                .clone();
                                            yield Ok(RawStreamingChoice::ToolCallDelta {
                                                id: func.id.clone(),
                                                internal_call_id,
                                                content: streaming::ToolCallDeltaContent::Name(func.name.clone()),
                                            });
                                        }
                                    }
                                    ItemChunkKind::OutputItemDone(message) => match message {
                                        StreamingItemDoneOutput { item: responses_api::Output::FunctionCall(func), .. } => {
                                            let internal_id = tool_call_internal_ids
                                                .entry(func.id.clone())
                                                .or_insert_with(|| nanoid::nanoid!())
                                                .clone();
                                            let raw_tool_call = streaming::RawStreamingToolCall::new(
                                                func.id.clone(),
                                                func.name.clone(),
                                                func.arguments.clone(),
                                            )
                                            .with_internal_call_id(internal_id)
                                            .with_call_id(func.call_id.clone());
                                            tool_calls.push(RawStreamingChoice::ToolCall(raw_tool_call));
                                        }
                                        StreamingItemDoneOutput { item: responses_api::Output::Reasoning { summary, id, encrypted_content, .. }, .. } => {
                                            for reasoning_choice in responses_api::streaming::reasoning_choices_from_done_item(
                                                id,
                                                summary,
                                                encrypted_content.as_deref(),
                                            ) {
                                                match reasoning_choice {
                                                    RawStreamingChoice::Reasoning { id, content } => {
                                                        yield Ok(RawStreamingChoice::Reasoning { id, content });
                                                    }
                                                    RawStreamingChoice::ReasoningDelta { id, reasoning } => {
                                                        yield Ok(RawStreamingChoice::ReasoningDelta { id, reasoning });
                                                    }
                                                    _ => {}
                                                }
                                            }
                                        }
                                        StreamingItemDoneOutput { item: responses_api::Output::Message(msg), .. } => {
                                            yield Ok(RawStreamingChoice::MessageId(msg.id.clone()));
                                        }
                                        StreamingItemDoneOutput { item: responses_api::Output::Unknown, .. } => {}
                                    },
                                    ItemChunkKind::OutputTextDelta(delta) => {
                                        yield Ok(RawStreamingChoice::Message(delta.delta.clone()))
                                    }
                                    ItemChunkKind::ReasoningSummaryTextDelta(delta) => {
                                        yield Ok(RawStreamingChoice::ReasoningDelta { id: None, reasoning: delta.delta.clone() })
                                    }
                                    ItemChunkKind::RefusalDelta(delta) => {
                                        yield Ok(RawStreamingChoice::Message(delta.delta.clone()))
                                    }
                                    ItemChunkKind::FunctionCallArgsDelta(delta) => {
                                        let internal_call_id = tool_call_internal_ids
                                            .entry(delta.item_id.clone())
                                            .or_insert_with(|| nanoid::nanoid!())
                                            .clone();
                                        yield Ok(RawStreamingChoice::ToolCallDelta {
                                            id: delta.item_id.clone(),
                                            internal_call_id,
                                            content: streaming::ToolCallDeltaContent::Delta(delta.delta.clone())
                                        })
                                    }
                                    _ => continue,
                                }
                            }

                            if let responses_api::streaming::StreamingCompletionChunk::Response(chunk) = data {
                                let responses_api::streaming::ResponseChunk { kind, response, .. } = *chunk;
                                match kind {
                                    responses_api::streaming::ResponseChunkKind::ResponseCompleted => {
                                        span.record("gen_ai.response.id", response.id.as_str());
                                        span.record("gen_ai.response.model", response.model.as_str());
                                        if let Some(usage) = response.usage {
                                            final_usage = usage;
                                        }
                                    }
                                    responses_api::streaming::ResponseChunkKind::ResponseFailed
                                    | responses_api::streaming::ResponseChunkKind::ResponseIncomplete => {
                                        let error = response
                                            .error
                                            .as_ref()
                                            .map(|err| err.message.clone())
                                            .unwrap_or_else(|| "Copilot response stream failed".into());
                                        terminated_with_error = true;
                                        yield Err(CompletionError::ProviderError(error));
                                        break;
                                    }
                                    _ => continue,
                                }
                            }
                        }
                        Err(crate::http_client::Error::StreamEnded) => {
                            break;
                        }
                        Err(error) => {
                            terminated_with_error = true;
                            yield Err(CompletionError::ProviderError(error.to_string()));
                            break;
                        }
                    }
                }

                event_source.close();

                if terminated_with_error {
                    return;
                }

                for tool_call in &tool_calls {
                    yield Ok(tool_call.to_owned())
                }

                span.record("gen_ai.usage.input_tokens", final_usage.input_tokens);
                span.record("gen_ai.usage.output_tokens", final_usage.output_tokens);
                span.record(
                    "gen_ai.usage.cache_read.input_tokens",
                    final_usage
                        .input_tokens_details
                        .as_ref()
                        .map(|details| details.cached_tokens)
                        .unwrap_or(0),
                );

                yield Ok(RawStreamingChoice::FinalResponse(
                    CopilotStreamingResponse::Responses(
                        responses_api::streaming::StreamingCompletionResponse { usage: final_usage }
                    )
                ));
            },
            span,
        );

        Ok(StreamingCompletionResponse::stream(Box::pin(stream)))
    }
}

impl<H> completion::CompletionModel for CompletionModel<H>
where
    Client<H>: HttpClientExt + Clone + Debug + 'static,
    H: Clone + Default + Debug + WasmCompatSend + WasmCompatSync + 'static,
{
    type Response = CopilotCompletionResponse;
    type StreamingResponse = CopilotStreamingResponse;
    type Client = Client<H>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    async fn completion(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
        match self.route() {
            CompletionRoute::ChatCompletions => self.completion_chat(completion_request).await,
            CompletionRoute::Responses => self.completion_responses(completion_request).await,
        }
    }

    async fn stream(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        match self.route() {
            CompletionRoute::ChatCompletions => self.stream_chat(completion_request).await,
            CompletionRoute::Responses => self.stream_responses(completion_request).await,
        }
    }
}

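/// Embedding model backed by Copilot's `/embeddings` endpoint.
///
/// A minimal construction sketch using only items from this module; 1536
/// matches the default dimension count that `make` assigns to
/// `text-embedding-3-small`:
///
/// ```no_run
/// use rig::client::ProviderClient;
/// use rig::providers::copilot;
///
/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let client = copilot::Client::from_env()?;
/// let model = copilot::EmbeddingModel::new(client, copilot::TEXT_EMBEDDING_3_SMALL, 1536);
/// # let _ = model;
/// # Ok(())
/// # }
/// ```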
#[derive(Clone)]
pub struct EmbeddingModel<H = reqwest::Client> {
    client: Client<H>,
    pub model: String,
    pub encoding_format: Option<openai::EncodingFormat>,
    pub user: Option<String>,
    ndims: usize,
}

#[derive(Deserialize)]
struct CopilotEmbeddingResponse {
    data: Vec<CopilotEmbeddingData>,
}

#[derive(Deserialize)]
struct CopilotEmbeddingData {
    embedding: Vec<serde_json::Number>,
}

impl<H> EmbeddingModel<H>
where
    Client<H>: HttpClientExt + Clone + Debug + 'static,
    H: Clone + Default + Debug + 'static,
{
    pub fn new(client: Client<H>, model: impl Into<String>, ndims: usize) -> Self {
        Self {
            client,
            model: model.into(),
            encoding_format: None,
            user: None,
            ndims,
        }
    }
}

impl<H> embeddings::EmbeddingModel for EmbeddingModel<H>
where
    Client<H>: HttpClientExt + Clone + Debug + WasmCompatSend + WasmCompatSync + 'static,
    H: Clone + Default + Debug + WasmCompatSend + WasmCompatSync + 'static,
{
    const MAX_DOCUMENTS: usize = 1024;
    type Client = Client<H>;

    fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
        let model = model.into();
        let dims = ndims.unwrap_or(match model.as_str() {
            TEXT_EMBEDDING_3_LARGE => 3072,
            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
            _ => 0,
        });
        Self::new(client.clone(), model, dims)
    }

    fn ndims(&self) -> usize {
        self.ndims
    }

    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let documents = documents.into_iter().collect::<Vec<_>>();
        let auth = self
            .client
            .ext()
            .auth
            .auth_context()
            .await
            .map_err(|err| EmbeddingError::ProviderError(err.to_string()))?;

        let headers = default_headers(&auth.api_key, "user", false);
        let mut body = json!({
            "model": self.model,
            "input": documents,
        });

        let body_object = body.as_object_mut().ok_or_else(|| {
            EmbeddingError::ResponseError("embedding request body must be a JSON object".into())
        })?;

        if self.ndims > 0 && self.model.as_str() != TEXT_EMBEDDING_ADA_002 {
            body_object.insert("dimensions".to_owned(), json!(self.ndims));
        }
        if let Some(encoding_format) = &self.encoding_format {
            body_object.insert("encoding_format".to_owned(), json!(encoding_format));
        }
        if let Some(user) = &self.user {
            body_object.insert("user".to_owned(), json!(user));
        }

        let req = apply_headers(
            post_with_auth_base(&self.client, &auth, "/embeddings", Transport::Http)?,
            &headers,
        )
        .body(serde_json::to_vec(&body)?)
        .map_err(|err| EmbeddingError::HttpError(err.into()))?;

        let response = self.client.send(req).await?;
        if response.status().is_success() {
            let body: Vec<u8> = response.into_body().await?;
            #[derive(Deserialize)]
            struct NestedApiError {
                error: NestedApiErrorMessage,
            }

            #[derive(Deserialize)]
            struct NestedApiErrorMessage {
                message: String,
            }

            let body: CopilotEmbeddingResponse = match serde_json::from_slice(&body) {
                Ok(parsed) => parsed,
                Err(parse_error) => {
                    if let Ok(err) = serde_json::from_slice::<NestedApiError>(&body) {
                        return Err(EmbeddingError::ProviderError(err.error.message));
                    }

                    let preview = String::from_utf8_lossy(&body);
                    let preview = if preview.len() > 512 {
                        format!("{}...", &preview[..512])
                    } else {
                        preview.into_owned()
                    };

                    return Err(EmbeddingError::ProviderError(format!(
                        "Failed to parse Copilot embeddings response: {parse_error}; body: {preview}"
                    )));
                }
            };

            Ok(body
                .data
                .into_iter()
                .zip(documents.into_iter())
                .map(|(embedding, document)| embeddings::Embedding {
                    document,
                    vec: embedding
                        .embedding
                        .into_iter()
                        .filter_map(|n| n.as_f64())
                        .collect(),
                })
                .collect())
        } else {
            let text = http_client::text(response).await?;
            Err(EmbeddingError::ProviderError(text))
        }
    }
}

#[derive(Deserialize, Debug)]
struct ChatStreamingFunction {
    name: Option<String>,
    arguments: Option<String>,
}

#[derive(Deserialize, Debug)]
struct ChatStreamingToolCall {
    index: usize,
    id: Option<String>,
    function: ChatStreamingFunction,
}

impl From<&ChatStreamingToolCall> for CompatibleToolCallChunk {
    fn from(value: &ChatStreamingToolCall) -> Self {
        Self {
            index: value.index,
            id: value.id.clone(),
            name: value.function.name.clone(),
            arguments: value.function.arguments.clone(),
        }
    }
}

#[derive(Deserialize, Debug, Default)]
struct ChatStreamingDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>,
    #[serde(default, deserialize_with = "crate::json_utils::null_or_vec")]
    tool_calls: Vec<ChatStreamingToolCall>,
}

#[derive(Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")]
enum ChatFinishReason {
    ToolCalls,
    Stop,
    ContentFilter,
    Length,
    #[serde(untagged)]
    Other(String),
}

#[derive(Deserialize, Debug)]
struct ChatStreamingChoice {
    delta: ChatStreamingDelta,
    finish_reason: Option<ChatFinishReason>,
}

#[derive(Deserialize, Debug)]
struct ChatStreamingChunk {
    id: Option<String>,
    model: Option<String>,
    choices: Vec<ChatStreamingChoice>,
    usage: Option<openai::completion::Usage>,
}

#[derive(Clone, Copy)]
struct CopilotChatCompatibleProfile;

impl CompatibleStreamProfile for CopilotChatCompatibleProfile {
    type Usage = openai::completion::Usage;
    type Detail = ();
    type FinalResponse = CopilotStreamingResponse;

    fn normalize_chunk(
        &self,
        data: &str,
    ) -> Result<Option<CompatibleChunk<Self::Usage, Self::Detail>>, CompletionError> {
        let data = match serde_json::from_str::<ChatStreamingChunk>(data) {
            Ok(data) => data,
            Err(error) => {
                tracing::debug!(?error, "Couldn't parse Copilot chat SSE payload");
                return Ok(None);
            }
        };

        Ok(Some(
            openai_chat_completions_compatible::normalize_first_choice_chunk(
                data.id,
                data.model,
                data.usage,
                &data.choices,
                |choice| CompatibleChoiceData {
                    finish_reason: if choice.finish_reason == Some(ChatFinishReason::ToolCalls) {
                        CompatibleFinishReason::ToolCalls
                    } else {
                        CompatibleFinishReason::Other
                    },
                    text: choice.delta.content.clone(),
                    reasoning: choice.delta.reasoning_content.clone(),
                    tool_calls: openai_chat_completions_compatible::tool_call_chunks(
                        &choice.delta.tool_calls,
                    ),
                    details: Vec::new(),
                },
            ),
        ))
    }

    fn build_final_response(&self, usage: Self::Usage) -> Self::FinalResponse {
        CopilotStreamingResponse::Chat(openai::completion::streaming::StreamingCompletionResponse {
            usage,
        })
    }

    fn uses_distinct_tool_call_eviction(&self) -> bool {
        true
    }
}

async fn send_copilot_chat_streaming_request<T>(
    http_client: T,
    req: Request<Vec<u8>>,
) -> Result<StreamingCompletionResponse<CopilotStreamingResponse>, CompletionError>
where
    T: HttpClientExt + Clone + 'static,
{
    openai_chat_completions_compatible::send_compatible_streaming_request(
        http_client,
        req,
        CopilotChatCompatibleProfile,
    )
    .await
}

fn default_token_dir() -> Option<PathBuf> {
    config_dir().map(|dir| dir.join("github_copilot"))
}

fn config_dir() -> Option<PathBuf> {
    #[cfg(target_os = "windows")]
    {
        std::env::var_os("APPDATA").map(PathBuf::from)
    }

    #[cfg(not(target_os = "windows"))]
    {
        std::env::var_os("XDG_CONFIG_HOME")
            .map(PathBuf::from)
            .or_else(|| std::env::var_os("HOME").map(|home| PathBuf::from(home).join(".config")))
    }
}
1395
1396#[cfg(test)]
1397mod tests {
1398    use super::{
1399        ChatApiErrorResponse, ChatCompletionResponse, Client, CompletionRoute,
1400        TEXT_EMBEDDING_3_SMALL, env_api_key, env_base_url, env_github_access_token,
1401        route_for_model,
1402    };
1403    use crate::client::CompletionClient;
1404    use crate::completion::CompletionModel;
1405    use crate::http_client::mock::MockStreamingClient;
1406    use crate::http_client::{self, HttpClientExt, LazyBody, MultipartForm, Request, Response};
1407    use crate::providers::internal::openai_chat_completions_compatible::test_support::{
1408        sse_bytes_from_data_lines, sse_bytes_from_json_events,
1409    };
1410    use crate::streaming::StreamedAssistantContent;
1411    use bytes::Bytes;
1412    use futures::StreamExt;
1413    use std::collections::HashMap;
1414    use std::future::{self, Future};
1415    use std::sync::{Arc, Mutex};
1416
1417    #[derive(Debug, Clone, PartialEq, Eq)]
1418    struct CapturedRequest {
1419        uri: String,
1420        body: Bytes,
1421    }
1422
    #[derive(Debug, Clone, Default)]
    struct RecordingHttpClient {
        requests: Arc<Mutex<Vec<CapturedRequest>>>,
        response_body: Bytes,
    }

    impl RecordingHttpClient {
        fn new(response_body: impl Into<Bytes>) -> Self {
            Self {
                requests: Arc::new(Mutex::new(Vec::new())),
                response_body: response_body.into(),
            }
        }

        fn requests(&self) -> Vec<CapturedRequest> {
            self.requests.lock().expect("requests lock").clone()
        }
    }

    impl HttpClientExt for RecordingHttpClient {
        fn send<T, U>(
            &self,
            req: Request<T>,
        ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>>
        + crate::wasm_compat::WasmCompatSend
        + 'static
        where
            T: Into<Bytes> + crate::wasm_compat::WasmCompatSend,
            U: From<Bytes> + crate::wasm_compat::WasmCompatSend + 'static,
        {
            let requests = Arc::clone(&self.requests);
            let response_body = self.response_body.clone();
            let (parts, body) = req.into_parts();
            let body = body.into();

            requests
                .lock()
                .expect("requests lock")
                .push(CapturedRequest {
                    uri: parts.uri.to_string(),
                    body,
                });

            async move {
                let body: LazyBody<U> = Box::pin(async move { Ok(U::from(response_body)) });
                Response::builder()
                    .status(http::StatusCode::OK)
                    .body(body)
                    .map_err(http_client::Error::Protocol)
            }
        }

        fn send_multipart<U>(
            &self,
            _req: Request<MultipartForm>,
        ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>>
        + crate::wasm_compat::WasmCompatSend
        + 'static
        where
            U: From<Bytes> + crate::wasm_compat::WasmCompatSend + 'static,
        {
            future::ready(Err(http_client::Error::InvalidStatusCode(
                http::StatusCode::NOT_IMPLEMENTED,
            )))
        }

        fn send_streaming<T>(
            &self,
            _req: Request<T>,
        ) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>>
        + crate::wasm_compat::WasmCompatSend
        where
            T: Into<Bytes>,
        {
            future::ready(Err(http_client::Error::InvalidStatusCode(
                http::StatusCode::NOT_IMPLEMENTED,
            )))
        }
    }

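    /// Streaming test client that replays a fixed sequence of SSE chunks
    /// (including mid-stream transport errors) exactly once.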
    #[derive(Debug, Clone, Default)]
    struct SequencedStreamingHttpClient {
        chunks: Arc<Mutex<Option<Vec<http_client::Result<Bytes>>>>>,
    }

    impl SequencedStreamingHttpClient {
        fn new(chunks: Vec<http_client::Result<Bytes>>) -> Self {
            Self {
                chunks: Arc::new(Mutex::new(Some(chunks))),
            }
        }
    }

    impl HttpClientExt for SequencedStreamingHttpClient {
        fn send<T, U>(
            &self,
            _req: Request<T>,
        ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>>
        + crate::wasm_compat::WasmCompatSend
        + 'static
        where
            T: Into<Bytes> + crate::wasm_compat::WasmCompatSend,
            U: From<Bytes> + crate::wasm_compat::WasmCompatSend + 'static,
        {
            future::ready(Err(http_client::Error::InvalidStatusCode(
                http::StatusCode::NOT_IMPLEMENTED,
            )))
        }

        fn send_multipart<U>(
            &self,
            _req: Request<MultipartForm>,
        ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>>
        + crate::wasm_compat::WasmCompatSend
        + 'static
        where
            U: From<Bytes> + crate::wasm_compat::WasmCompatSend + 'static,
        {
            future::ready(Err(http_client::Error::InvalidStatusCode(
                http::StatusCode::NOT_IMPLEMENTED,
            )))
        }

        fn send_streaming<T>(
            &self,
            _req: Request<T>,
        ) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>>
        + crate::wasm_compat::WasmCompatSend
        where
            T: Into<Bytes>,
        {
            let chunks = self
                .chunks
                .lock()
                .expect("chunks lock")
                .take()
                .expect("streaming chunks should only be consumed once");

            async move {
                let byte_stream = futures::stream::iter(chunks);
                let boxed_stream: http_client::sse::BoxedStream = Box::pin(byte_stream);

                Response::builder()
                    .status(http::StatusCode::OK)
                    .header(http::header::CONTENT_TYPE, "text/event-stream")
                    .body(boxed_stream)
                    .map_err(http_client::Error::Protocol)
            }
        }
    }

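    /// Builds an owned `HashMap` from borrowed (key, value) pairs for the
    /// env-lookup tests below.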
    fn env_map(entries: &[(&str, &str)]) -> HashMap<String, String> {
        entries
            .iter()
            .map(|(key, value)| ((*key).to_string(), (*value).to_string()))
            .collect()
    }

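    /// Minimal Copilot-shaped chat payload: unlike a stock OpenAI response it
    /// omits `object`, `created`, and `completion_tokens`.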
    fn minimal_chat_response() -> &'static str {
        r#"{
            "id": "chatcmpl-123",
            "model": "gpt-4o",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "hello"
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 4,
                "total_tokens": 7
            }
        }"#
    }

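    /// Minimal `/responses` payload for a Codex-routed model, covering the
    /// fields the Responses API deserializer requires.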
    fn minimal_responses_response() -> &'static str {
        r#"{
            "id": "resp_123",
            "object": "response",
            "created_at": 1700000000,
            "status": "completed",
            "error": null,
            "incomplete_details": null,
            "instructions": null,
            "max_output_tokens": null,
            "model": "gpt-5.3-codex",
            "usage": {
                "input_tokens": 4,
                "input_tokens_details": {
                    "cached_tokens": 0
                },
                "output_tokens": 3,
                "output_tokens_details": {
                    "reasoning_tokens": 0
                },
                "total_tokens": 7
            },
            "output": [{
                "type": "message",
                "id": "msg_123",
                "role": "assistant",
                "status": "completed",
                "content": [{
                    "type": "output_text",
                    "text": "hello"
                }]
            }],
            "tools": []
        }"#
    }

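    /// Minimal embeddings payload containing only the `data[].embedding`
    /// vectors, with no `object`, `index`, or `usage` fields.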
    fn minimal_embeddings_response() -> &'static str {
        r#"{
            "data": [
                {
                    "embedding": [0.1, 0.2, 0.3]
                },
                {
                    "embedding": [0.4, 0.5, 0.6]
                }
            ]
        }"#
    }

    #[test]
    fn deserialize_standard_openai_response() {
        let json = r#"{
            "id": "chatcmpl-abc123",
            "object": "chat.completion",
            "created": 1700000000,
            "model": "gpt-4o",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Hello!"
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        }"#;

        let response: ChatCompletionResponse =
            serde_json::from_str(json).expect("standard OpenAI response should deserialize");
        assert_eq!(response.id, "chatcmpl-abc123");
        assert_eq!(response.object.as_deref(), Some("chat.completion"));
        assert_eq!(response.created, Some(1700000000));
        assert_eq!(response.model, "gpt-4o");
        assert_eq!(response.choices.len(), 1);
        assert_eq!(response.choices[0].finish_reason.as_deref(), Some("stop"));
    }

    #[test]
    fn deserialize_copilot_response_without_object_and_created() {
        let response: ChatCompletionResponse = serde_json::from_str(minimal_chat_response())
            .expect("Copilot response should deserialize");

        assert_eq!(response.id, "chatcmpl-123");
        assert_eq!(response.object, None);
        assert_eq!(response.created, None);
        assert_eq!(response.model, "gpt-4o");
        assert_eq!(response.choices.len(), 1);
    }

    #[test]
    fn deserialize_copilot_response_without_finish_reason() {
        let json = r#"{
            "id": "chatcmpl-claude-001",
            "model": "claude-3.5-sonnet",
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": "Here is my analysis."
                }
            }],
            "usage": {
                "prompt_tokens": 50,
                "total_tokens": 80
            }
        }"#;

        let response: ChatCompletionResponse =
            serde_json::from_str(json).expect("Claude-via-Copilot response should deserialize");

        assert_eq!(response.model, "claude-3.5-sonnet");
        assert_eq!(response.choices[0].finish_reason, None);
        assert_eq!(response.choices[0].index, 0);
    }

    #[test]
    fn error_response_with_message_field() {
        let json = r#"{"message": "rate limit exceeded"}"#;
        let err: ChatApiErrorResponse = serde_json::from_str(json).expect("message-shaped error");

        assert_eq!(err.error_message(), "rate limit exceeded");
    }

    #[test]
    fn error_response_with_error_field() {
        let json = r#"{"error": "model not found"}"#;
        let err: ChatApiErrorResponse = serde_json::from_str(json).expect("error-shaped error");

        assert_eq!(err.error_message(), "model not found");
    }

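    // Routing matches "codex" case-insensitively anywhere in the model name;
    // every other model name goes to `/chat/completions`.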
    #[test]
    fn routes_codex_models_to_responses() {
        assert_eq!(route_for_model("gpt-5.3-codex"), CompletionRoute::Responses);
        assert_eq!(
            route_for_model("gpt-5.1-CODEX-mini"),
            CompletionRoute::Responses
        );
        assert_eq!(route_for_model("gpt-5.2"), CompletionRoute::ChatCompletions);
        assert_eq!(
            route_for_model("claude-sonnet-4.5"),
            CompletionRoute::ChatCompletions
        );
    }

    #[tokio::test]
    async fn completion_model_routes_chat_requests_to_chat_completions() {
        let http_client = RecordingHttpClient::new(minimal_chat_response());
        let client = Client::builder()
            .api_key("copilot-token")
            .http_client(http_client.clone())
            .build()
            .expect("build client");
        let model = client.completion_model("gpt-4o");
        let request = model.completion_request("hello").build();

        let _response = model.completion(request).await.expect("chat completion");

        let requests = http_client.requests();
        assert_eq!(requests.len(), 1);
        assert!(requests[0].uri.ends_with("/chat/completions"));
        assert!(String::from_utf8_lossy(&requests[0].body).contains("\"model\":\"gpt-4o\""));
    }

    #[tokio::test]
    async fn completion_model_routes_codex_requests_to_responses() {
        let http_client = RecordingHttpClient::new(minimal_responses_response());
        let client = Client::builder()
            .api_key("copilot-token")
            .http_client(http_client.clone())
            .build()
            .expect("build client");
        let model = client.completion_model("gpt-5.3-codex");
        let request = model.completion_request("hello").build();

        let _response = model
            .completion(request)
            .await
            .expect("responses completion");

        let requests = http_client.requests();
        assert_eq!(requests.len(), 1);
        assert!(requests[0].uri.ends_with("/responses"));
        assert!(String::from_utf8_lossy(&requests[0].body).contains("\"model\":\"gpt-5.3-codex\""));
    }

    #[tokio::test]
    async fn embeddings_accept_minimal_copilot_response_shape() {
        use crate::client::EmbeddingsClient;
        use crate::embeddings::EmbeddingModel as _;

        let http_client = RecordingHttpClient::new(minimal_embeddings_response());
        let client = Client::builder()
            .api_key("copilot-token")
            .http_client(http_client.clone())
            .build()
            .expect("build client");
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);

        let embeddings = model
            .embed_texts(["one".to_string(), "two".to_string()])
            .await
            .expect("embeddings should deserialize");

        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].vec, vec![0.1, 0.2, 0.3]);
        assert_eq!(embeddings[1].vec, vec![0.4, 0.5, 0.6]);

        let requests = http_client.requests();
        assert_eq!(requests.len(), 1);
        assert!(requests[0].uri.ends_with("/embeddings"));
        assert!(
            String::from_utf8_lossy(&requests[0].body)
                .contains("\"model\":\"text-embedding-3-small\"")
        );
    }

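    // A `response.failed` event is terminal: the stream must surface the
    // provider error exactly once and then end rather than hang or repeat it.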
    #[tokio::test]
    async fn responses_stream_terminates_after_terminal_error() {
        let tool_call_done = serde_json::json!({
            "type": "response.output_item.done",
            "sequence_number": 1,
            "item": {
                "type": "function_call",
                "id": "fc_123",
                "arguments": "{}",
                "call_id": "call_123",
                "name": "example_tool",
                "status": "completed"
            }
        });
        let failed = serde_json::json!({
            "type": "response.failed",
            "sequence_number": 2,
            "response": {
                "id": "resp_123",
                "object": "response",
                "created_at": 1700000000,
                "status": "failed",
                "error": {
                    "code": "server_error",
                    "message": "Copilot response stream failed"
                },
                "incomplete_details": null,
                "instructions": null,
                "max_output_tokens": null,
                "model": "gpt-5.3-codex",
                "usage": null,
                "output": [],
                "tools": []
            }
        });
        let http_client = MockStreamingClient {
            sse_bytes: sse_bytes_from_json_events(&[tool_call_done, failed]),
        };
        let client = Client::builder()
            .api_key("copilot-token")
            .http_client(http_client)
            .build()
            .expect("build client");
        let model = client.completion_model("gpt-5.3-codex");
        let request = model.completion_request("hello").build();
        let mut stream = model.stream(request).await.expect("stream should start");

        let err = match stream.next().await.expect("stream should yield an item") {
            Ok(_) => panic!("stream should surface a provider error"),
            Err(err) => err,
        };
        assert_eq!(
            err.to_string(),
            "ProviderError: Copilot response stream failed"
        );
        assert!(
            stream.next().await.is_none(),
            "responses stream should terminate immediately after a terminal error"
        );
    }

    #[tokio::test]
    async fn chat_stream_terminates_after_transport_error() {
        let chunks = vec![
            Ok(sse_bytes_from_data_lines([
                "{\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_123\",\"function\":{\"name\":\"ping\",\"arguments\":\"\"}}]},\"finish_reason\":null}],\"usage\":null}",
            ])),
            Err(http_client::Error::InvalidStatusCode(
                http::StatusCode::BAD_GATEWAY,
            )),
        ];

        let http_client = SequencedStreamingHttpClient::new(chunks);
        let client = Client::builder()
            .api_key("copilot-token")
            .http_client(http_client)
            .build()
            .expect("build client");
        let model = client.completion_model("gpt-4o");
        let request = model.completion_request("hello").build();
        let mut stream = model.stream(request).await.expect("stream should start");

        let mut saw_error = false;
        while let Some(item) = stream.next().await {
            match item {
                Ok(StreamedAssistantContent::ToolCallDelta { .. }) => {}
                Err(err) => {
                    assert_eq!(
                        err.to_string(),
                        "ProviderError: Invalid status code: 502 Bad Gateway"
                    );
                    saw_error = true;
                    break;
                }
                Ok(_) => panic!("unexpected non-error stream item before transport failure"),
            }
        }

        assert!(saw_error, "stream should surface the transport error");
        assert!(
            stream.next().await.is_none(),
            "chat stream should terminate immediately after a transport error"
        );
    }

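    // Env-var precedence checks: `GITHUB_COPILOT_*` variables win over their
    // `COPILOT_*` counterparts, and `GITHUB_TOKEN` only ever serves as a
    // GitHub bootstrap token, never as a Copilot API key.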
    #[test]
    fn env_api_key_prefers_github_prefixed_vars() {
        let env = env_map(&[
            ("COPILOT_API_KEY", "copilot-key"),
            ("GITHUB_COPILOT_API_KEY", "github-key"),
            ("GITHUB_TOKEN", "bootstrap-token"),
        ]);
        let get = |name: &str| env.get(name).cloned();

        assert_eq!(env_api_key(&get).as_deref(), Some("github-key"));
    }

    #[test]
    fn env_github_access_token_prefers_explicit_bootstrap_var() {
        let env = env_map(&[
            ("COPILOT_GITHUB_ACCESS_TOKEN", "explicit-bootstrap"),
            ("GITHUB_TOKEN", "fallback-bootstrap"),
        ]);
        let get = |name: &str| env.get(name).cloned();

        assert_eq!(
            env_github_access_token(&get).as_deref(),
            Some("explicit-bootstrap")
        );
    }

    #[test]
    fn env_base_url_prefers_github_prefixed_vars() {
        let env = env_map(&[
            ("COPILOT_BASE_URL", "https://copilot.example"),
            ("GITHUB_COPILOT_API_BASE", "https://github.example"),
        ]);
        let get = |name: &str| env.get(name).cloned();

        assert_eq!(
            env_base_url(&get).as_deref(),
            Some("https://github.example")
        );
    }

    #[test]
    fn env_without_api_key_falls_back_to_oauth() {
        let env = env_map(&[("COPILOT_BASE_URL", "https://copilot.example")]);
        let get = |name: &str| env.get(name).cloned();

        assert!(env_api_key(&get).is_none());
        assert!(env_github_access_token(&get).is_none());
        assert_eq!(
            env_base_url(&get).as_deref(),
            Some("https://copilot.example")
        );
    }

    #[test]
    fn env_github_token_is_not_treated_as_copilot_api_key() {
        let env = env_map(&[("GITHUB_TOKEN", "bootstrap-token")]);
        let get = |name: &str| env.get(name).cloned();

        assert!(env_api_key(&get).is_none());
        assert_eq!(
            env_github_access_token(&get).as_deref(),
            Some("bootstrap-token")
        );
    }
}