1mod auth;
24
25use crate::client::{
26 self, ApiKey, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
27 ProviderClient, Transport,
28};
29use crate::completion::{self, CompletionError, GetTokenUsage};
30use crate::embeddings::{self, EmbeddingError};
31use crate::http_client::{self, HttpClientExt};
32use crate::providers::internal::openai_chat_completions_compatible::{
33 self, CompatibleChoiceData, CompatibleChunk, CompatibleFinishReason, CompatibleStreamProfile,
34 CompatibleToolCallChunk,
35};
36use crate::providers::openai;
37use crate::providers::openai::responses_api::{self, CompletionRequest as ResponsesRequest};
38use crate::streaming::{self, RawStreamingChoice, StreamingCompletionResponse};
39use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
40use async_stream::stream;
41use futures::StreamExt;
42use http::Request;
43use serde::{Deserialize, Serialize};
44use serde_json::json;
45use std::borrow::Cow;
46use std::collections::HashMap;
47use std::fmt::Debug;
48use std::path::{Path, PathBuf};
49use tracing::info_span;
50use tracing_futures::Instrument as _;
51
/// Default Copilot API endpoint; may be replaced at runtime by the
/// `api_base` carried in the auth context (see `runtime_base_url`).
const GITHUB_COPILOT_API_BASE_URL: &str = "https://api.githubcopilot.com";
/// Sent as the `editor-plugin-version` header on every request.
const EDITOR_PLUGIN_VERSION: &str = "copilot-chat/0.26.7";
/// Sent as the `user-agent` header on every request.
const USER_AGENT: &str = "GitHubCopilotChat/0.26.7";
/// Sent as the `x-github-api-version` header on every request.
const API_VERSION: &str = "2025-04-01";

// Model identifiers accepted by the Copilot API. Names containing "codex"
// are routed to the Responses API (see `route_for_model`).
pub const GPT_4: &str = "gpt-4";
pub const GPT_4O: &str = "gpt-4o";
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
pub const GPT_4_1: &str = "gpt-4.1";
pub const GPT_4_1_MINI: &str = "gpt-4.1-mini";
pub const GPT_4_1_NANO: &str = "gpt-4.1-nano";
pub const GPT_5_3_CODEX: &str = "gpt-5.3-codex";
pub const GPT_5_1_CODEX: &str = "gpt-5.1-codex";
pub const CLAUDE_SONNET_4: &str = "claude-sonnet-4";
pub const CLAUDE_3_5_SONNET: &str = "claude-3.5-sonnet";
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash-001";
pub const O3_MINI: &str = "o3-mini";

// Embedding model identifiers.
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";

// Re-exported so callers can configure embedding encoding without importing
// the openai provider module directly.
pub use openai::EncodingFormat;
89
90#[derive(Clone)]
91pub enum CopilotAuth {
92 ApiKey(String),
93 GitHubAccessToken(String),
94 OAuth,
95}
96
97impl ApiKey for CopilotAuth {}
98
99impl<S> From<S> for CopilotAuth
100where
101 S: Into<String>,
102{
103 fn from(value: S) -> Self {
104 Self::ApiKey(value.into())
105 }
106}
107
108impl Debug for CopilotAuth {
109 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110 match self {
111 Self::ApiKey(_) => f.write_str("ApiKey(<redacted>)"),
112 Self::GitHubAccessToken(_) => f.write_str("GitHubAccessToken(<redacted>)"),
113 Self::OAuth => f.write_str("OAuth"),
114 }
115 }
116}
117
/// Provider-builder state for Copilot: where credential files live on disk
/// and how to surface the OAuth device-code prompt.
#[derive(Debug, Clone)]
pub struct CopilotBuilder {
    // Where the GitHub access token is stored; defaults to
    // `<token_dir>/access-token` (see `Default` below). `None` when no
    // token directory could be determined.
    access_token_file: Option<PathBuf>,
    // Where the Copilot API key is stored; defaults to
    // `<token_dir>/api-key.json`.
    api_key_file: Option<PathBuf>,
    // Callback invoked with the device-code prompt during the OAuth flow.
    device_code_handler: auth::DeviceCodeHandler,
}

/// Extension carried by the built client; owns the authenticator that
/// supplies per-request credentials (see `CompletionModel::auth_context`).
#[derive(Clone)]
pub struct CopilotExt {
    auth: auth::Authenticator,
}

// Manual Debug that simply delegates to the authenticator's Debug output.
impl Debug for CopilotExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CopilotExt")
            .field("auth", &self.auth)
            .finish()
    }
}

// Convenience aliases; the HTTP backend defaults to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<CopilotExt, H>;
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<CopilotBuilder, CopilotAuth, H>;

impl Default for CopilotBuilder {
    /// Points both credential files at the default token directory, when one
    /// is available (`default_token_dir` may return `None`).
    fn default() -> Self {
        let token_dir = default_token_dir();
        Self {
            access_token_file: token_dir.as_ref().map(|dir| dir.join("access-token")),
            api_key_file: token_dir.map(|dir| dir.join("api-key.json")),
            device_code_handler: auth::DeviceCodeHandler::default(),
        }
    }
}
151
impl Provider for CopilotExt {
    type Builder = CopilotBuilder;

    // Empty: no verification path is configured for this provider.
    const VERIFY_PATH: &'static str = "";
}

impl<H> Capabilities<H> for CopilotExt {
    // Completions and embeddings are supported; all other capabilities are
    // marked absent via `Nothing`.
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for CopilotExt {}
170
171impl ProviderBuilder for CopilotBuilder {
172 type Extension<H>
173 = CopilotExt
174 where
175 H: HttpClientExt;
176 type ApiKey = CopilotAuth;
177
178 const BASE_URL: &'static str = GITHUB_COPILOT_API_BASE_URL;
179
180 fn build<H>(
181 builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
182 ) -> http_client::Result<Self::Extension<H>>
183 where
184 H: HttpClientExt,
185 {
186 let auth = match builder.get_api_key() {
187 CopilotAuth::ApiKey(api_key) => auth::AuthSource::ApiKey(api_key.clone()),
188 CopilotAuth::GitHubAccessToken(access_token) => {
189 auth::AuthSource::GitHubAccessToken(access_token.clone())
190 }
191 CopilotAuth::OAuth => auth::AuthSource::OAuth,
192 };
193
194 let ext = builder.ext();
195 Ok(CopilotExt {
196 auth: auth::Authenticator::new(
197 auth,
198 ext.access_token_file.clone(),
199 ext.api_key_file.clone(),
200 ext.device_code_handler.clone(),
201 ),
202 })
203 }
204}
205
206impl ProviderClient for Client {
207 type Input = CopilotAuth;
208 type Error = crate::client::ProviderClientError;
209
210 fn from_env() -> Result<Self, Self::Error> {
211 let mut builder = Self::builder();
212 fn get(name: &str) -> Option<String> {
213 std::env::var(name).ok()
214 }
215
216 if let Some(base_url) = env_base_url(&get) {
217 builder = builder.base_url(base_url);
218 }
219
220 if let Some(api_key) = env_api_key(&get) {
221 builder.api_key(api_key).build().map_err(Into::into)
222 } else if let Some(access_token) = env_github_access_token(&get) {
223 builder
224 .github_access_token(access_token)
225 .build()
226 .map_err(Into::into)
227 } else {
228 builder.oauth().build().map_err(Into::into)
229 }
230 }
231
232 fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
233 Self::builder().api_key(input).build().map_err(Into::into)
234 }
235}
236
// Credential-selection helpers, available only while the builder is still in
// the `Missing`-key marker state (selecting one transitions it to
// `CopilotAuth`).
impl<H> client::ClientBuilder<CopilotBuilder, crate::markers::Missing, H> {
    /// Authenticates with a GitHub access token instead of a Copilot API key.
    pub fn github_access_token(
        self,
        access_token: impl Into<String>,
    ) -> client::ClientBuilder<CopilotBuilder, CopilotAuth, H> {
        self.api_key(CopilotAuth::GitHubAccessToken(access_token.into()))
    }

    /// Selects the interactive OAuth (device-code) flow.
    pub fn oauth(self) -> client::ClientBuilder<CopilotBuilder, CopilotAuth, H> {
        self.api_key(CopilotAuth::OAuth)
    }
}
249
250impl<H> ClientBuilder<H> {
251 pub fn on_device_code<F>(self, handler: F) -> Self
252 where
253 F: Fn(auth::DeviceCodePrompt) + Send + Sync + 'static,
254 {
255 self.over_ext(|mut ext| {
256 ext.device_code_handler = auth::DeviceCodeHandler::new(handler);
257 ext
258 })
259 }
260
261 pub fn token_dir(self, path: impl AsRef<Path>) -> Self {
262 let path = path.as_ref();
263 self.over_ext(|mut ext| {
264 ext.access_token_file = Some(path.join("access-token"));
265 ext.api_key_file = Some(path.join("api-key.json"));
266 ext
267 })
268 }
269
270 pub fn access_token_file(self, path: impl AsRef<Path>) -> Self {
271 let path = path.as_ref().to_path_buf();
272 self.over_ext(|mut ext| {
273 ext.access_token_file = Some(path);
274 ext
275 })
276 }
277
278 pub fn api_key_file(self, path: impl AsRef<Path>) -> Self {
279 let path = path.as_ref().to_path_buf();
280 self.over_ext(|mut ext| {
281 ext.api_key_file = Some(path);
282 ext
283 })
284 }
285}
286
/// Reads `name` through `get`, treating unset and whitespace-only values as
/// absent.
fn env_value<F>(get: &F, name: &str) -> Option<String>
where
    F: Fn(&str) -> Option<String>,
{
    match get(name) {
        Some(value) if !value.trim().is_empty() => Some(value),
        _ => None,
    }
}

/// Returns the first non-empty value among `keys`, checked in order.
fn first_env_value<F>(get: &F, keys: &[&str]) -> Option<String>
where
    F: Fn(&str) -> Option<String>,
{
    for key in keys {
        if let Some(value) = env_value(get, key) {
            return Some(value);
        }
    }
    None
}

/// Copilot API key from the environment; the Copilot-specific variable wins.
fn env_api_key<F>(get: &F) -> Option<String>
where
    F: Fn(&str) -> Option<String>,
{
    first_env_value(get, &["GITHUB_COPILOT_API_KEY", "COPILOT_API_KEY"])
}

/// GitHub access token from the environment; the Copilot-specific variable
/// wins over the generic `GITHUB_TOKEN`.
fn env_github_access_token<F>(get: &F) -> Option<String>
where
    F: Fn(&str) -> Option<String>,
{
    first_env_value(get, &["COPILOT_GITHUB_ACCESS_TOKEN", "GITHUB_TOKEN"])
}

/// Base-URL override from the environment.
fn env_base_url<F>(get: &F) -> Option<String>
where
    F: Fn(&str) -> Option<String>,
{
    first_env_value(get, &["GITHUB_COPILOT_API_BASE", "COPILOT_BASE_URL"])
}
321
322impl<H> Client<H>
323where
324 H: HttpClientExt + Clone + Debug + Default + WasmCompatSend + WasmCompatSync + 'static,
325{
326 pub async fn authorize(&self) -> Result<(), auth::AuthError> {
327 self.ext().auth.auth_context().await.map(|_| ())
328 }
329}
330
331fn default_headers(
332 api_key: &str,
333 initiator: &'static str,
334 has_vision: bool,
335) -> Vec<(&'static str, String)> {
336 let mut headers = vec![
337 (
338 http::header::AUTHORIZATION.as_str(),
339 format!("Bearer {api_key}"),
340 ),
341 ("copilot-integration-id", "vscode-chat".to_string()),
342 ("editor-version", "vscode/1.95.0".to_string()),
343 ("editor-plugin-version", EDITOR_PLUGIN_VERSION.to_string()),
344 ("user-agent", USER_AGENT.to_string()),
345 ("openai-intent", "conversation-panel".to_string()),
346 ("x-github-api-version", API_VERSION.to_string()),
347 ("x-request-id", nanoid::nanoid!()),
348 (
349 "x-vscode-user-agent-library-version",
350 "electron-fetch".to_string(),
351 ),
352 ("X-Initiator", initiator.to_string()),
353 ];
354
355 if has_vision {
356 headers.push(("copilot-vision-request", "true".to_string()));
357 }
358
359 headers
360}
361
362fn apply_headers(
363 builder: http_client::Builder,
364 headers: &[(&'static str, String)],
365) -> http_client::Builder {
366 headers
367 .iter()
368 .fold(builder, |builder, (key, value)| builder.header(*key, value))
369}
370
371fn runtime_base_url<'a, H>(client: &'a Client<H>, auth: &'a auth::AuthContext) -> Cow<'a, str> {
372 if client.base_url() == GITHUB_COPILOT_API_BASE_URL {
373 auth.api_base
374 .as_deref()
375 .map(Cow::Borrowed)
376 .unwrap_or_else(|| Cow::Borrowed(client.base_url()))
377 } else {
378 Cow::Borrowed(client.base_url())
379 }
380}
381
/// Starts a POST request against the auth-aware base URL.
///
/// Joins the runtime base URL (which may come from the auth context) with
/// `path`, copies the client's configured headers onto the request, then
/// finishes via the extension's `with_custom` hook.
fn post_with_auth_base<H>(
    client: &Client<H>,
    auth: &auth::AuthContext,
    path: &str,
    transport: Transport,
) -> http_client::Result<http_client::Builder> {
    let uri = client
        .ext()
        .build_uri(runtime_base_url(client, auth).as_ref(), path, transport);
    let mut req = Request::post(uri);

    // Carry over the client-level default headers; per-request Copilot
    // headers are added later by `apply_headers`.
    if let Some(headers) = req.headers_mut() {
        headers.extend(client.headers().iter().map(|(k, v)| (k.clone(), v.clone())));
    }

    client.ext().with_custom(req)
}
399
400fn request_initiator(request: &completion::CompletionRequest) -> &'static str {
401 for message in request.chat_history.iter() {
402 match message {
403 crate::completion::Message::Assistant { .. } => return "agent",
404 crate::completion::Message::User { content } => {
405 if content
406 .iter()
407 .any(|item| matches!(item, crate::message::UserContent::ToolResult(_)))
408 {
409 return "agent";
410 }
411 }
412 crate::completion::Message::System { .. } => {}
413 }
414 }
415
416 "user"
417}
418
419fn request_has_vision(request: &completion::CompletionRequest) -> bool {
420 request.chat_history.iter().any(|message| match message {
421 crate::completion::Message::User { content } => content
422 .iter()
423 .any(|item| matches!(item, crate::message::UserContent::Image(_))),
424 _ => false,
425 })
426}
427
/// Which Copilot HTTP API a model is served through.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum CompletionRoute {
    /// The OpenAI-style `/chat/completions` endpoint.
    ChatCompletions,
    /// The `/responses` (Responses API) endpoint.
    Responses,
}

/// Codex-family models go through the Responses API; everything else uses
/// classic chat completions. The check is case-insensitive.
fn route_for_model(model: &str) -> CompletionRoute {
    let is_codex = model.to_ascii_lowercase().contains("codex");
    if is_codex {
        CompletionRoute::Responses
    } else {
        CompletionRoute::ChatCompletions
    }
}
441
/// Raw completion payload, tagged by which API produced it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "api", rename_all = "snake_case")]
pub enum CopilotCompletionResponse {
    /// Payload from `/chat/completions`.
    Chat(ChatCompletionResponse),
    /// Payload from `/responses`.
    Responses(Box<responses_api::CompletionResponse>),
}

/// Final streaming payload, tagged by which API produced it.
#[derive(Clone, Serialize, Deserialize)]
#[serde(tag = "api", rename_all = "snake_case")]
pub enum CopilotStreamingResponse {
    Chat(openai::completion::streaming::StreamingCompletionResponse),
    Responses(responses_api::streaming::StreamingCompletionResponse),
}

// Usage extraction simply delegates to whichever API variant is present.
impl GetTokenUsage for CopilotStreamingResponse {
    fn token_usage(&self) -> Option<completion::Usage> {
        match self {
            Self::Chat(response) => response.token_usage(),
            Self::Responses(response) => response.token_usage(),
        }
    }
}
464
/// Wire format of a non-streaming `/chat/completions` response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    // Defaulted so payloads missing these metadata fields still deserialize.
    #[serde(default)]
    pub object: Option<String>,
    #[serde(default)]
    pub created: Option<u64>,
    pub model: String,
    pub system_fingerprint: Option<String>,
    pub choices: Vec<ChatChoice>,
    pub usage: Option<openai::completion::Usage>,
}

/// A single choice within a chat completion response.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ChatChoice {
    #[serde(default)]
    pub index: usize,
    pub message: openai::completion::Message,
    pub logprobs: Option<serde_json::Value>,
    #[serde(default)]
    pub finish_reason: Option<String>,
}
487
488impl TryFrom<ChatCompletionResponse> for completion::CompletionResponse<ChatCompletionResponse> {
489 type Error = CompletionError;
490
491 fn try_from(response: ChatCompletionResponse) -> Result<Self, Self::Error> {
492 let choice = response.choices.first().ok_or_else(|| {
493 CompletionError::ResponseError("Response contained no choices".to_owned())
494 })?;
495
496 let content = match &choice.message {
497 openai::completion::Message::Assistant {
498 content,
499 tool_calls,
500 ..
501 } => {
502 let mut content = content
503 .iter()
504 .filter_map(|c| {
505 let s = match c {
506 openai::completion::AssistantContent::Text { text } => text,
507 openai::completion::AssistantContent::Refusal { refusal } => refusal,
508 };
509 if s.is_empty() {
510 None
511 } else {
512 Some(completion::AssistantContent::text(s))
513 }
514 })
515 .collect::<Vec<_>>();
516
517 content.extend(
518 tool_calls
519 .iter()
520 .map(|call| {
521 completion::AssistantContent::tool_call(
522 &call.id,
523 &call.function.name,
524 call.function.arguments.clone(),
525 )
526 })
527 .collect::<Vec<_>>(),
528 );
529 Ok(content)
530 }
531 _ => Err(CompletionError::ResponseError(
532 "Response did not contain a valid message or tool call".into(),
533 )),
534 }?;
535
536 let choice = crate::OneOrMany::many(content).map_err(|_| {
537 CompletionError::ResponseError(
538 "Response contained no message or tool call (empty)".to_owned(),
539 )
540 })?;
541
542 let usage = response
543 .usage
544 .as_ref()
545 .map(|usage| completion::Usage {
546 input_tokens: usage.prompt_tokens as u64,
547 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
548 total_tokens: usage.total_tokens as u64,
549 cached_input_tokens: usage
550 .prompt_tokens_details
551 .as_ref()
552 .map(|d| d.cached_tokens as u64)
553 .unwrap_or(0),
554 cache_creation_input_tokens: 0,
555 })
556 .unwrap_or_default();
557
558 Ok(completion::CompletionResponse {
559 choice,
560 usage,
561 raw_response: response,
562 message_id: None,
563 })
564 }
565}
566
567#[derive(Debug, Deserialize)]
568pub struct ChatApiErrorResponse {
569 #[serde(default)]
570 pub message: Option<String>,
571 #[serde(default)]
572 pub error: Option<String>,
573}
574
575impl ChatApiErrorResponse {
576 pub fn error_message(&self) -> &str {
577 self.message
578 .as_deref()
579 .or(self.error.as_deref())
580 .unwrap_or("unknown error")
581 }
582}
583
584#[derive(Debug, Deserialize)]
585#[serde(untagged)]
586enum ChatApiResponse<T> {
587 Ok(T),
588 Err(ChatApiErrorResponse),
589}
590
/// Completion model bound to a Copilot client.
#[derive(Clone)]
pub struct CompletionModel<H = reqwest::Client> {
    client: Client<H>,
    /// Model identifier, e.g. one of the constants at the top of this module.
    pub model: String,
    /// Forwarded to the OpenAI request builder as `strict_tools`.
    pub strict_tools: bool,
    /// Forwarded to the OpenAI request builder as `tool_result_array_content`.
    pub tool_result_array_content: bool,
}
598
599impl<H> CompletionModel<H>
600where
601 Client<H>: HttpClientExt + Clone + Debug + 'static,
602 H: Clone + Default + Debug + WasmCompatSend + WasmCompatSync + 'static,
603{
    /// Creates a completion model for `model` on `client`; both tool-encoding
    /// flags start disabled.
    pub fn new(client: Client<H>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
            strict_tools: false,
            tool_result_array_content: false,
        }
    }

    /// Enables `strict_tools` (forwarded to the OpenAI request builder).
    pub fn with_strict_tools(mut self) -> Self {
        self.strict_tools = true;
        self
    }

    /// Enables `tool_result_array_content` (forwarded to the OpenAI request
    /// builder).
    pub fn with_tool_result_array_content(mut self) -> Self {
        self.tool_result_array_content = true;
        self
    }

    /// Which endpoint this model is served through (codex models use the
    /// Responses API).
    fn route(&self) -> CompletionRoute {
        route_for_model(&self.model)
    }
626
    /// Obtains fresh request credentials from the authenticator, mapping any
    /// auth failure into a provider error.
    async fn auth_context(&self) -> Result<auth::AuthContext, CompletionError> {
        self.client
            .ext()
            .auth
            .auth_context()
            .await
            .map_err(|err| CompletionError::ProviderError(err.to_string()))
    }

    /// Builds the `/chat/completions` request body from a generic completion
    /// request, applying this model's tool-encoding flags.
    fn chat_request(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<openai::completion::CompletionRequest, CompletionError> {
        openai::completion::CompletionRequest::try_from(openai::completion::OpenAIRequestParams {
            model: self.model.clone(),
            request: completion_request,
            strict_tools: self.strict_tools,
            tool_result_array_content: self.tool_result_array_content,
        })
    }

    /// Builds the `/responses` request body from a generic completion request.
    fn responses_request(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<ResponsesRequest, CompletionError> {
        ResponsesRequest::try_from((self.model.clone(), completion_request))
    }
654
    /// Non-streaming completion via `/chat/completions`.
    async fn completion_chat(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<CopilotCompletionResponse>, CompletionError> {
        // Header hints must be derived before the request is moved below.
        let initiator = request_initiator(&completion_request);
        let has_vision = request_has_vision(&completion_request);
        let request = self.chat_request(completion_request)?;
        let body = serde_json::to_vec(&request)?;
        let auth = self.auth_context().await?;

        let headers = default_headers(&auth.api_key, initiator, has_vision);
        let req = apply_headers(
            post_with_auth_base(&self.client, &auth, "/chat/completions", Transport::Http)?,
            &headers,
        )
        .body(body)
        .map_err(|err| CompletionError::HttpError(err.into()))?;

        // Open a gen_ai span only when the caller has not already set one up.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "copilot",
                gen_ai.request.model = self.model,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        async move {
            let response = self.client.send(req).await?;

            if response.status().is_success() {
                let body = http_client::text(response).await?;
                // Even 2xx bodies are parsed via the untagged success-or-error
                // envelope, so error-shaped payloads are still surfaced.
                match serde_json::from_str::<ChatApiResponse<ChatCompletionResponse>>(&body)? {
                    ChatApiResponse::Ok(response) => {
                        let core = completion::CompletionResponse::try_from(response.clone())?;
                        let span = tracing::Span::current();
                        span.record("gen_ai.response.id", response.id.as_str());
                        span.record("gen_ai.response.model", response.model.as_str());
                        if let Some(usage) = &response.usage {
                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
                            // NOTE(review): this subtraction underflows if the
                            // server ever reports total_tokens < prompt_tokens.
                            span.record(
                                "gen_ai.usage.output_tokens",
                                usage.total_tokens - usage.prompt_tokens,
                            );
                            span.record(
                                "gen_ai.usage.cache_read.input_tokens",
                                usage
                                    .prompt_tokens_details
                                    .as_ref()
                                    .map(|details| details.cached_tokens)
                                    .unwrap_or(0),
                            );
                        }

                        Ok(completion::CompletionResponse {
                            choice: core.choice,
                            usage: core.usage,
                            raw_response: CopilotCompletionResponse::Chat(response),
                            message_id: core.message_id,
                        })
                    }
                    ChatApiResponse::Err(err) => Err(CompletionError::ProviderError(
                        err.error_message().to_string(),
                    )),
                }
            } else {
                // Non-2xx: surface the raw response body as the error.
                let body = http_client::text(response).await?;
                Err(CompletionError::ProviderError(body))
            }
        }
        .instrument(span)
        .await
    }
736
    /// Non-streaming completion via the Responses API (`/responses`).
    async fn completion_responses(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<CopilotCompletionResponse>, CompletionError> {
        // Header hints must be derived before the request is moved below.
        let initiator = request_initiator(&completion_request);
        let has_vision = request_has_vision(&completion_request);
        let request = self.responses_request(completion_request)?;
        let auth = self.auth_context().await?;

        let headers = default_headers(&auth.api_key, initiator, has_vision);
        let req = apply_headers(
            post_with_auth_base(&self.client, &auth, "/responses", Transport::Http)?,
            &headers,
        )
        .body(serde_json::to_vec(&request)?)
        .map_err(|err| CompletionError::HttpError(err.into()))?;

        // Open a gen_ai span only when the caller has not already set one up.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "copilot",
                gen_ai.request.model = self.model,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        async move {
            let response = self.client.send(req).await?;
            if response.status().is_success() {
                let body = http_client::text(response).await?;
                let response = serde_json::from_str::<responses_api::CompletionResponse>(&body)?;
                let core = completion::CompletionResponse::try_from(response.clone())?;

                let span = tracing::Span::current();
                span.record("gen_ai.response.id", response.id.as_str());
                span.record("gen_ai.response.model", response.model.as_str());
                if let Some(usage) = &response.usage {
                    span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                    span.record("gen_ai.usage.output_tokens", usage.output_tokens);
                    span.record(
                        "gen_ai.usage.cache_read.input_tokens",
                        usage
                            .input_tokens_details
                            .as_ref()
                            .map(|details| details.cached_tokens)
                            .unwrap_or(0),
                    );
                }

                Ok(completion::CompletionResponse {
                    choice: core.choice,
                    usage: core.usage,
                    raw_response: CopilotCompletionResponse::Responses(Box::new(response)),
                    message_id: core.message_id,
                })
            } else {
                // Non-2xx: surface the raw response body as the error.
                let body = http_client::text(response).await?;
                Err(CompletionError::ProviderError(body))
            }
        }
        .instrument(span)
        .await
    }
808
    /// Streaming completion via `/chat/completions` (SSE).
    async fn stream_chat(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<StreamingCompletionResponse<CopilotStreamingResponse>, CompletionError> {
        let initiator = request_initiator(&completion_request);
        let has_vision = request_has_vision(&completion_request);
        let request = self.chat_request(completion_request)?;
        let auth = self.auth_context().await?;
        let headers = default_headers(&auth.api_key, initiator, has_vision);
        // Patch the serialized request to enable streaming and per-stream
        // usage reporting before sending.
        let mut request_json = serde_json::to_value(&request)?;
        let request_object = request_json.as_object_mut().ok_or_else(|| {
            CompletionError::ResponseError("copilot request body must be a JSON object".into())
        })?;
        request_object.insert("stream".to_owned(), json!(true));
        request_object.insert(
            "stream_options".to_owned(),
            json!({ "include_usage": true }),
        );

        let req = apply_headers(
            post_with_auth_base(&self.client, &auth, "/chat/completions", Transport::Sse)?,
            &headers,
        )
        .body(serde_json::to_vec(&request_json)?)
        .map_err(|err| CompletionError::HttpError(err.into()))?;

        // Open a gen_ai span only when the caller has not already set one up.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "copilot",
                gen_ai.request.model = self.model,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        // SSE decoding is delegated to `send_copilot_chat_streaming_request`.
        tracing::Instrument::instrument(
            send_copilot_chat_streaming_request(self.client.clone(), req),
            span,
        )
        .await
    }
858
    /// Streaming completion via the Responses API (`/responses`, SSE).
    ///
    /// Text/reasoning deltas are yielded as they arrive; completed tool calls
    /// are buffered and emitted after the event stream ends, followed by the
    /// final usage payload.
    async fn stream_responses(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<StreamingCompletionResponse<CopilotStreamingResponse>, CompletionError> {
        let initiator = request_initiator(&completion_request);
        let has_vision = request_has_vision(&completion_request);
        let mut request = self.responses_request(completion_request)?;
        request.stream = Some(true);
        let auth = self.auth_context().await?;

        let headers = default_headers(&auth.api_key, initiator, has_vision);
        let req = apply_headers(
            post_with_auth_base(&self.client, &auth, "/responses", Transport::Sse)?,
            &headers,
        )
        .body(serde_json::to_vec(&request)?)
        .map_err(|err| CompletionError::HttpError(err.into()))?;

        // Open a gen_ai span only when the caller has not already set one up.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "copilot",
                gen_ai.request.model = self.model,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let client = self.client.clone();
        let mut event_source = crate::http_client::sse::GenericEventSource::new(client, req);

        let stream = tracing_futures::Instrument::instrument(
            stream! {
                let mut final_usage = responses_api::ResponsesUsage::new();
                // Completed tool calls, replayed after the stream ends.
                let mut tool_calls: Vec<streaming::RawStreamingChoice<CopilotStreamingResponse>> = Vec::new();
                // Maps server item ids to locally generated internal call ids,
                // so deltas for the same call share one internal id.
                let mut tool_call_internal_ids: HashMap<String, String> = HashMap::new();
                let span = tracing::Span::current();

                let mut terminated_with_error = false;

                while let Some(event_result) = event_source.next().await {
                    match event_result {
                        Ok(crate::http_client::sse::Event::Open) => continue,
                        Ok(crate::http_client::sse::Event::Message(evt)) => {
                            if evt.data.trim().is_empty() {
                                continue;
                            }

                            // Chunks that don't parse are skipped silently.
                            let Ok(data) = serde_json::from_str::<responses_api::streaming::StreamingCompletionChunk>(&evt.data) else {
                                continue;
                            };

                            if let responses_api::streaming::StreamingCompletionChunk::Delta(chunk) = &data {
                                use responses_api::streaming::{ItemChunkKind, StreamingItemDoneOutput};

                                match &chunk.data {
                                    // A new function-call item: announce its name.
                                    ItemChunkKind::OutputItemAdded(message) => {
                                        if let StreamingItemDoneOutput { item: responses_api::Output::FunctionCall(func), .. } = message {
                                            let internal_call_id = tool_call_internal_ids
                                                .entry(func.id.clone())
                                                .or_insert_with(|| nanoid::nanoid!())
                                                .clone();
                                            yield Ok(RawStreamingChoice::ToolCallDelta {
                                                id: func.id.clone(),
                                                internal_call_id,
                                                content: streaming::ToolCallDeltaContent::Name(func.name.clone()),
                                            });
                                        }
                                    }
                                    ItemChunkKind::OutputItemDone(message) => match message {
                                        // Completed tool call: buffer it for replay
                                        // after the stream finishes.
                                        StreamingItemDoneOutput { item: responses_api::Output::FunctionCall(func), .. } => {
                                            let internal_id = tool_call_internal_ids
                                                .entry(func.id.clone())
                                                .or_insert_with(|| nanoid::nanoid!())
                                                .clone();
                                            let raw_tool_call = streaming::RawStreamingToolCall::new(
                                                func.id.clone(),
                                                func.name.clone(),
                                                func.arguments.clone(),
                                            )
                                            .with_internal_call_id(internal_id)
                                            .with_call_id(func.call_id.clone());
                                            tool_calls.push(RawStreamingChoice::ToolCall(raw_tool_call));
                                        }
                                        // Completed reasoning item: re-emit as
                                        // reasoning choices/deltas.
                                        StreamingItemDoneOutput { item: responses_api::Output::Reasoning { summary, id, encrypted_content, .. }, .. } => {
                                            for reasoning_choice in responses_api::streaming::reasoning_choices_from_done_item(
                                                id,
                                                summary,
                                                encrypted_content.as_deref(),
                                            ) {
                                                match reasoning_choice {
                                                    RawStreamingChoice::Reasoning { id, content } => {
                                                        yield Ok(RawStreamingChoice::Reasoning { id, content });
                                                    }
                                                    RawStreamingChoice::ReasoningDelta { id, reasoning } => {
                                                        yield Ok(RawStreamingChoice::ReasoningDelta { id, reasoning });
                                                    }
                                                    _ => {}
                                                }
                                            }
                                        }
                                        // Completed message item: surface its id.
                                        StreamingItemDoneOutput { item: responses_api::Output::Message(msg), .. } => {
                                            yield Ok(RawStreamingChoice::MessageId(msg.id.clone()));
                                        }
                                        StreamingItemDoneOutput { item: responses_api::Output::Unknown, .. } => {}
                                    },
                                    ItemChunkKind::OutputTextDelta(delta) => {
                                        yield Ok(RawStreamingChoice::Message(delta.delta.clone()))
                                    }
                                    ItemChunkKind::ReasoningSummaryTextDelta(delta) => {
                                        yield Ok(RawStreamingChoice::ReasoningDelta { id: None, reasoning: delta.delta.clone() })
                                    }
                                    // Refusals are forwarded as plain message text.
                                    ItemChunkKind::RefusalDelta(delta) => {
                                        yield Ok(RawStreamingChoice::Message(delta.delta.clone()))
                                    }
                                    // Incremental tool-call arguments.
                                    ItemChunkKind::FunctionCallArgsDelta(delta) => {
                                        let internal_call_id = tool_call_internal_ids
                                            .entry(delta.item_id.clone())
                                            .or_insert_with(|| nanoid::nanoid!())
                                            .clone();
                                        yield Ok(RawStreamingChoice::ToolCallDelta {
                                            id: delta.item_id.clone(),
                                            internal_call_id,
                                            content: streaming::ToolCallDeltaContent::Delta(delta.delta.clone())
                                        })
                                    }
                                    _ => continue,
                                }
                            }

                            // Response-level lifecycle chunks: capture usage on
                            // completion, abort the stream on failure.
                            if let responses_api::streaming::StreamingCompletionChunk::Response(chunk) = data {
                                let responses_api::streaming::ResponseChunk { kind, response, .. } = *chunk;
                                match kind {
                                    responses_api::streaming::ResponseChunkKind::ResponseCompleted => {
                                        span.record("gen_ai.response.id", response.id.as_str());
                                        span.record("gen_ai.response.model", response.model.as_str());
                                        if let Some(usage) = response.usage {
                                            final_usage = usage;
                                        }
                                    }
                                    responses_api::streaming::ResponseChunkKind::ResponseFailed
                                    | responses_api::streaming::ResponseChunkKind::ResponseIncomplete => {
                                        let error = response
                                            .error
                                            .as_ref()
                                            .map(|err| err.message.clone())
                                            .unwrap_or_else(|| "Copilot response stream failed".into());
                                        terminated_with_error = true;
                                        yield Err(CompletionError::ProviderError(error));
                                        break;
                                    }
                                    _ => continue,
                                }
                            }
                        }
                        // Normal end-of-stream.
                        Err(crate::http_client::Error::StreamEnded) => {
                            break;
                        }
                        Err(error) => {
                            terminated_with_error = true;
                            yield Err(CompletionError::ProviderError(error.to_string()));
                            break;
                        }
                    }
                }

                event_source.close();

                // After an error, skip tool-call replay and the final payload.
                if terminated_with_error {
                    return;
                }

                for tool_call in &tool_calls {
                    yield Ok(tool_call.to_owned())
                }

                span.record("gen_ai.usage.input_tokens", final_usage.input_tokens);
                span.record("gen_ai.usage.output_tokens", final_usage.output_tokens);
                span.record(
                    "gen_ai.usage.cache_read.input_tokens",
                    final_usage
                        .input_tokens_details
                        .as_ref()
                        .map(|details| details.cached_tokens)
                        .unwrap_or(0),
                );

                yield Ok(RawStreamingChoice::FinalResponse(
                    CopilotStreamingResponse::Responses(
                        responses_api::streaming::StreamingCompletionResponse { usage: final_usage }
                    )
                ));
            },
            span,
        );

        Ok(StreamingCompletionResponse::stream(Box::pin(stream)))
    }
1064}
1065
impl<H> completion::CompletionModel for CompletionModel<H>
where
    Client<H>: HttpClientExt + Clone + Debug + 'static,
    H: Clone + Default + Debug + WasmCompatSend + WasmCompatSync + 'static,
{
    type Response = CopilotCompletionResponse;
    type StreamingResponse = CopilotStreamingResponse;
    type Client = Client<H>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Dispatches to the chat-completions or Responses implementation based
    /// on the model name (see `route_for_model`).
    async fn completion(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
        match self.route() {
            CompletionRoute::ChatCompletions => self.completion_chat(completion_request).await,
            CompletionRoute::Responses => self.completion_responses(completion_request).await,
        }
    }

    /// Streaming counterpart of `completion`, with the same routing rule.
    async fn stream(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        match self.route() {
            CompletionRoute::ChatCompletions => self.stream_chat(completion_request).await,
            CompletionRoute::Responses => self.stream_responses(completion_request).await,
        }
    }
}
1099
/// Handle for a Copilot embedding model, generic over the HTTP client `H`.
#[derive(Clone)]
pub struct EmbeddingModel<H = reqwest::Client> {
    client: Client<H>,
    /// Embedding model identifier, e.g. `text-embedding-3-small`.
    pub model: String,
    /// Optional `encoding_format` request parameter; omitted when `None`.
    pub encoding_format: Option<openai::EncodingFormat>,
    /// Optional end-user identifier forwarded as the `user` request field.
    pub user: Option<String>,
    // Vector dimensionality; 0 means "unknown" and suppresses the
    // `dimensions` request parameter.
    ndims: usize,
}
1108
/// Minimal shape of a successful Copilot `/embeddings` response body.
#[derive(Deserialize)]
struct CopilotEmbeddingResponse {
    data: Vec<CopilotEmbeddingData>,
}

/// One embedding entry from the response `data` array.
#[derive(Deserialize)]
struct CopilotEmbeddingData {
    // Kept as raw JSON numbers; converted to f64 when building the result.
    embedding: Vec<serde_json::Number>,
}
1118
1119impl<H> EmbeddingModel<H>
1120where
1121 Client<H>: HttpClientExt + Clone + Debug + 'static,
1122 H: Clone + Default + Debug + 'static,
1123{
1124 pub fn new(client: Client<H>, model: impl Into<String>, ndims: usize) -> Self {
1125 Self {
1126 client,
1127 model: model.into(),
1128 encoding_format: None,
1129 user: None,
1130 ndims,
1131 }
1132 }
1133}
1134
1135impl<H> embeddings::EmbeddingModel for EmbeddingModel<H>
1136where
1137 Client<H>: HttpClientExt + Clone + Debug + WasmCompatSend + WasmCompatSync + 'static,
1138 H: Clone + Default + Debug + WasmCompatSend + WasmCompatSync + 'static,
1139{
1140 const MAX_DOCUMENTS: usize = 1024;
1141 type Client = Client<H>;
1142
1143 fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
1144 let model = model.into();
1145 let dims = ndims.unwrap_or(match model.as_str() {
1146 TEXT_EMBEDDING_3_LARGE => 3072,
1147 TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
1148 _ => 0,
1149 });
1150 Self::new(client.clone(), model, dims)
1151 }
1152
1153 fn ndims(&self) -> usize {
1154 self.ndims
1155 }
1156
1157 async fn embed_texts(
1158 &self,
1159 documents: impl IntoIterator<Item = String>,
1160 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
1161 let documents = documents.into_iter().collect::<Vec<_>>();
1162 let auth = self
1163 .client
1164 .ext()
1165 .auth
1166 .auth_context()
1167 .await
1168 .map_err(|err| EmbeddingError::ProviderError(err.to_string()))?;
1169
1170 let headers = default_headers(&auth.api_key, "user", false);
1171 let mut body = json!({
1172 "model": self.model,
1173 "input": documents,
1174 });
1175
1176 let body_object = body.as_object_mut().ok_or_else(|| {
1177 EmbeddingError::ResponseError("embedding request body must be a JSON object".into())
1178 })?;
1179
1180 if self.ndims > 0 && self.model.as_str() != TEXT_EMBEDDING_ADA_002 {
1181 body_object.insert("dimensions".to_owned(), json!(self.ndims));
1182 }
1183 if let Some(encoding_format) = &self.encoding_format {
1184 body_object.insert("encoding_format".to_owned(), json!(encoding_format));
1185 }
1186 if let Some(user) = &self.user {
1187 body_object.insert("user".to_owned(), json!(user));
1188 }
1189
1190 let req = apply_headers(
1191 post_with_auth_base(&self.client, &auth, "/embeddings", Transport::Http)?,
1192 &headers,
1193 )
1194 .body(serde_json::to_vec(&body)?)
1195 .map_err(|err| EmbeddingError::HttpError(err.into()))?;
1196
1197 let response = self.client.send(req).await?;
1198 if response.status().is_success() {
1199 let body: Vec<u8> = response.into_body().await?;
1200 #[derive(Deserialize)]
1201 struct NestedApiError {
1202 error: NestedApiErrorMessage,
1203 }
1204
1205 #[derive(Deserialize)]
1206 struct NestedApiErrorMessage {
1207 message: String,
1208 }
1209
1210 let body: CopilotEmbeddingResponse = match serde_json::from_slice(&body) {
1211 Ok(parsed) => parsed,
1212 Err(parse_error) => {
1213 if let Ok(err) = serde_json::from_slice::<NestedApiError>(&body) {
1214 return Err(EmbeddingError::ProviderError(err.error.message));
1215 }
1216
1217 let preview = String::from_utf8_lossy(&body);
1218 let preview = if preview.len() > 512 {
1219 format!("{}...", &preview[..512])
1220 } else {
1221 preview.into_owned()
1222 };
1223
1224 return Err(EmbeddingError::ProviderError(format!(
1225 "Failed to parse Copilot embeddings response: {parse_error}; body: {preview}"
1226 )));
1227 }
1228 };
1229
1230 Ok(body
1231 .data
1232 .into_iter()
1233 .zip(documents.into_iter())
1234 .map(|(embedding, document)| embeddings::Embedding {
1235 document,
1236 vec: embedding
1237 .embedding
1238 .into_iter()
1239 .filter_map(|n| n.as_f64())
1240 .collect(),
1241 })
1242 .collect())
1243 } else {
1244 let text = http_client::text(response).await?;
1245 Err(EmbeddingError::ProviderError(text))
1246 }
1247 }
1248}
1249
/// Incremental `function` fragment inside a streamed tool-call delta; `name`
/// and `arguments` arrive piecemeal across chunks.
#[derive(Deserialize, Debug)]
struct ChatStreamingFunction {
    name: Option<String>,
    arguments: Option<String>,
}

/// One entry of a streamed `tool_calls` delta array.
#[derive(Deserialize, Debug)]
struct ChatStreamingToolCall {
    // Position of the tool call within the message, used to stitch fragments.
    index: usize,
    id: Option<String>,
    function: ChatStreamingFunction,
}
1262
1263impl From<&ChatStreamingToolCall> for CompatibleToolCallChunk {
1264 fn from(value: &ChatStreamingToolCall) -> Self {
1265 Self {
1266 index: value.index,
1267 id: value.id.clone(),
1268 name: value.function.name.clone(),
1269 arguments: value.function.arguments.clone(),
1270 }
1271 }
1272}
1273
/// Delta payload of one streamed choice: text, optional reasoning text, and
/// tool-call fragments. Every field is optional because chunks are sparse.
#[derive(Deserialize, Debug, Default)]
struct ChatStreamingDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>,
    // Copilot may send `null` instead of an empty array here.
    #[serde(default, deserialize_with = "crate::json_utils::null_or_vec")]
    tool_calls: Vec<ChatStreamingToolCall>,
}

/// Why the streamed choice ended; unknown values are captured verbatim in
/// `Other` instead of failing deserialization.
#[derive(Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")]
enum ChatFinishReason {
    ToolCalls,
    Stop,
    ContentFilter,
    Length,
    #[serde(untagged)]
    Other(String),
}

/// One choice inside a streamed chat-completions chunk.
#[derive(Deserialize, Debug)]
struct ChatStreamingChoice {
    delta: ChatStreamingDelta,
    finish_reason: Option<ChatFinishReason>,
}

/// A single SSE payload from the chat-completions streaming endpoint.
#[derive(Deserialize, Debug)]
struct ChatStreamingChunk {
    id: Option<String>,
    model: Option<String>,
    choices: Vec<ChatStreamingChoice>,
    usage: Option<openai::completion::Usage>,
}

/// Marker type implementing the OpenAI-compatible stream profile for
/// Copilot's chat-completions endpoint.
#[derive(Clone, Copy)]
struct CopilotChatCompatibleProfile;
1311
1312impl CompatibleStreamProfile for CopilotChatCompatibleProfile {
1313 type Usage = openai::completion::Usage;
1314 type Detail = ();
1315 type FinalResponse = CopilotStreamingResponse;
1316
1317 fn normalize_chunk(
1318 &self,
1319 data: &str,
1320 ) -> Result<Option<CompatibleChunk<Self::Usage, Self::Detail>>, CompletionError> {
1321 let data = match serde_json::from_str::<ChatStreamingChunk>(data) {
1322 Ok(data) => data,
1323 Err(error) => {
1324 tracing::debug!(?error, "Couldn't parse Copilot chat SSE payload");
1325 return Ok(None);
1326 }
1327 };
1328
1329 Ok(Some(
1330 openai_chat_completions_compatible::normalize_first_choice_chunk(
1331 data.id,
1332 data.model,
1333 data.usage,
1334 &data.choices,
1335 |choice| CompatibleChoiceData {
1336 finish_reason: if choice.finish_reason == Some(ChatFinishReason::ToolCalls) {
1337 CompatibleFinishReason::ToolCalls
1338 } else {
1339 CompatibleFinishReason::Other
1340 },
1341 text: choice.delta.content.clone(),
1342 reasoning: choice.delta.reasoning_content.clone(),
1343 tool_calls: openai_chat_completions_compatible::tool_call_chunks(
1344 &choice.delta.tool_calls,
1345 ),
1346 details: Vec::new(),
1347 },
1348 ),
1349 ))
1350 }
1351
1352 fn build_final_response(&self, usage: Self::Usage) -> Self::FinalResponse {
1353 CopilotStreamingResponse::Chat(openai::completion::streaming::StreamingCompletionResponse {
1354 usage,
1355 })
1356 }
1357
1358 fn uses_distinct_tool_call_eviction(&self) -> bool {
1359 true
1360 }
1361}
1362
/// Issues a Copilot chat-completions streaming request by delegating to the
/// shared OpenAI-compatible SSE pipeline with the Copilot-specific profile.
async fn send_copilot_chat_streaming_request<T>(
    http_client: T,
    req: Request<Vec<u8>>,
) -> Result<StreamingCompletionResponse<CopilotStreamingResponse>, CompletionError>
where
    T: HttpClientExt + Clone + 'static,
{
    openai_chat_completions_compatible::send_compatible_streaming_request(
        http_client,
        req,
        CopilotChatCompatibleProfile,
    )
    .await
}
1377
1378fn default_token_dir() -> Option<PathBuf> {
1379 config_dir().map(|dir| dir.join("github_copilot"))
1380}
1381
/// Resolves the platform configuration directory: `%APPDATA%` on Windows,
/// `$XDG_CONFIG_HOME` (falling back to `$HOME/.config`) everywhere else.
/// Returns `None` when none of the relevant environment variables are set.
fn config_dir() -> Option<PathBuf> {
    #[cfg(target_os = "windows")]
    {
        std::env::var_os("APPDATA").map(PathBuf::from)
    }

    #[cfg(not(target_os = "windows"))]
    {
        if let Some(xdg) = std::env::var_os("XDG_CONFIG_HOME") {
            return Some(PathBuf::from(xdg));
        }
        let home = std::env::var_os("HOME")?;
        Some(PathBuf::from(home).join(".config"))
    }
}
1395
#[cfg(test)]
mod tests {
    use super::{
        ChatApiErrorResponse, ChatCompletionResponse, Client, CompletionRoute,
        TEXT_EMBEDDING_3_SMALL, env_api_key, env_base_url, env_github_access_token,
        route_for_model,
    };
    use crate::client::CompletionClient;
    use crate::completion::CompletionModel;
    use crate::http_client::mock::MockStreamingClient;
    use crate::http_client::{self, HttpClientExt, LazyBody, MultipartForm, Request, Response};
    use crate::providers::internal::openai_chat_completions_compatible::test_support::{
        sse_bytes_from_data_lines, sse_bytes_from_json_events,
    };
    use crate::streaming::StreamedAssistantContent;
    use bytes::Bytes;
    use futures::StreamExt;
    use std::collections::HashMap;
    use std::future::{self, Future};
    use std::sync::{Arc, Mutex};

    // One HTTP request observed by `RecordingHttpClient`.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct CapturedRequest {
        uri: String,
        body: Bytes,
    }

    // Test double that records every outgoing request and answers each one
    // with the same fixed 200 response body.
    #[derive(Debug, Clone, Default)]
    struct RecordingHttpClient {
        requests: Arc<Mutex<Vec<CapturedRequest>>>,
        response_body: Bytes,
    }

    impl RecordingHttpClient {
        fn new(response_body: impl Into<Bytes>) -> Self {
            Self {
                requests: Arc::new(Mutex::new(Vec::new())),
                response_body: response_body.into(),
            }
        }

        // Snapshot of all requests captured so far.
        fn requests(&self) -> Vec<CapturedRequest> {
            self.requests.lock().expect("requests lock").clone()
        }
    }

    // `send` records the request synchronously, then resolves to an OK
    // response; multipart and streaming are unsupported and return 501.
    impl HttpClientExt for RecordingHttpClient {
        fn send<T, U>(
            &self,
            req: Request<T>,
        ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>>
        + crate::wasm_compat::WasmCompatSend
        + 'static
        where
            T: Into<Bytes> + crate::wasm_compat::WasmCompatSend,
            U: From<Bytes> + crate::wasm_compat::WasmCompatSend + 'static,
        {
            let requests = Arc::clone(&self.requests);
            let response_body = self.response_body.clone();
            let (parts, body) = req.into_parts();
            let body = body.into();

            // Capture eagerly so assertions work even if the future is dropped.
            requests
                .lock()
                .expect("requests lock")
                .push(CapturedRequest {
                    uri: parts.uri.to_string(),
                    body,
                });

            async move {
                let body: LazyBody<U> = Box::pin(async move { Ok(U::from(response_body)) });
                Response::builder()
                    .status(http::StatusCode::OK)
                    .body(body)
                    .map_err(http_client::Error::Protocol)
            }
        }

        fn send_multipart<U>(
            &self,
            _req: Request<MultipartForm>,
        ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>>
        + crate::wasm_compat::WasmCompatSend
        + 'static
        where
            U: From<Bytes> + crate::wasm_compat::WasmCompatSend + 'static,
        {
            future::ready(Err(http_client::Error::InvalidStatusCode(
                http::StatusCode::NOT_IMPLEMENTED,
            )))
        }

        fn send_streaming<T>(
            &self,
            _req: Request<T>,
        ) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>>
        + crate::wasm_compat::WasmCompatSend
        where
            T: Into<Bytes>,
        {
            future::ready(Err(http_client::Error::InvalidStatusCode(
                http::StatusCode::NOT_IMPLEMENTED,
            )))
        }
    }

    // Test double that serves a predetermined sequence of SSE byte chunks
    // (including mid-stream transport errors) for exactly one streaming call.
    #[derive(Debug, Clone, Default)]
    struct SequencedStreamingHttpClient {
        chunks: Arc<Mutex<Option<Vec<http_client::Result<Bytes>>>>>,
    }

    impl SequencedStreamingHttpClient {
        fn new(chunks: Vec<http_client::Result<Bytes>>) -> Self {
            Self {
                chunks: Arc::new(Mutex::new(Some(chunks))),
            }
        }
    }

    // Only `send_streaming` is implemented; plain/multipart sends return 501.
    impl HttpClientExt for SequencedStreamingHttpClient {
        fn send<T, U>(
            &self,
            _req: Request<T>,
        ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>>
        + crate::wasm_compat::WasmCompatSend
        + 'static
        where
            T: Into<Bytes> + crate::wasm_compat::WasmCompatSend,
            U: From<Bytes> + crate::wasm_compat::WasmCompatSend + 'static,
        {
            future::ready(Err(http_client::Error::InvalidStatusCode(
                http::StatusCode::NOT_IMPLEMENTED,
            )))
        }

        fn send_multipart<U>(
            &self,
            _req: Request<MultipartForm>,
        ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>>
        + crate::wasm_compat::WasmCompatSend
        + 'static
        where
            U: From<Bytes> + crate::wasm_compat::WasmCompatSend + 'static,
        {
            future::ready(Err(http_client::Error::InvalidStatusCode(
                http::StatusCode::NOT_IMPLEMENTED,
            )))
        }

        fn send_streaming<T>(
            &self,
            _req: Request<T>,
        ) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>>
        + crate::wasm_compat::WasmCompatSend
        where
            T: Into<Bytes>,
        {
            // `take()` enforces single use; a second call panics the test.
            let chunks = self
                .chunks
                .lock()
                .expect("chunks lock")
                .take()
                .expect("streaming chunks should only be consumed once");

            async move {
                let byte_stream = futures::stream::iter(chunks);
                let boxed_stream: http_client::sse::BoxedStream = Box::pin(byte_stream);

                Response::builder()
                    .status(http::StatusCode::OK)
                    .header(http::header::CONTENT_TYPE, "text/event-stream")
                    .body(boxed_stream)
                    .map_err(http_client::Error::Protocol)
            }
        }
    }

    // Builds an owned env-var map for the env_* lookup helpers.
    fn env_map(entries: &[(&str, &str)]) -> HashMap<String, String> {
        entries
            .iter()
            .map(|(key, value)| ((*key).to_string(), (*value).to_string()))
            .collect()
    }

    // Smallest chat-completions body Copilot actually sends: note the absent
    // `object`, `created`, and `completion_tokens` fields.
    fn minimal_chat_response() -> &'static str {
        r#"{
            "id": "chatcmpl-123",
            "model": "gpt-4o",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "hello"
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 4,
                "total_tokens": 7
            }
        }"#
    }

    // Minimal completed Responses-API body for codex-model tests.
    fn minimal_responses_response() -> &'static str {
        r#"{
            "id": "resp_123",
            "object": "response",
            "created_at": 1700000000,
            "status": "completed",
            "error": null,
            "incomplete_details": null,
            "instructions": null,
            "max_output_tokens": null,
            "model": "gpt-5.3-codex",
            "usage": {
                "input_tokens": 4,
                "input_tokens_details": {
                    "cached_tokens": 0
                },
                "output_tokens": 3,
                "output_tokens_details": {
                    "reasoning_tokens": 0
                },
                "total_tokens": 7
            },
            "output": [{
                "type": "message",
                "id": "msg_123",
                "role": "assistant",
                "status": "completed",
                "content": [{
                    "type": "output_text",
                    "text": "hello"
                }]
            }],
            "tools": []
        }"#
    }

    // Bare-bones embeddings body: only the `data[].embedding` arrays.
    fn minimal_embeddings_response() -> &'static str {
        r#"{
            "data": [
                {
                    "embedding": [0.1, 0.2, 0.3]
                },
                {
                    "embedding": [0.4, 0.5, 0.6]
                }
            ]
        }"#
    }

    #[test]
    fn deserialize_standard_openai_response() {
        let json = r#"{
            "id": "chatcmpl-abc123",
            "object": "chat.completion",
            "created": 1700000000,
            "model": "gpt-4o",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Hello!"
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        }"#;

        let response: ChatCompletionResponse =
            serde_json::from_str(json).expect("standard OpenAI response should deserialize");
        assert_eq!(response.id, "chatcmpl-abc123");
        assert_eq!(response.object.as_deref(), Some("chat.completion"));
        assert_eq!(response.created, Some(1700000000));
        assert_eq!(response.model, "gpt-4o");
        assert_eq!(response.choices.len(), 1);
        assert_eq!(response.choices[0].finish_reason.as_deref(), Some("stop"));
    }

    // Copilot omits `object`/`created`; deserialization must tolerate that.
    #[test]
    fn deserialize_copilot_response_without_object_and_created() {
        let response: ChatCompletionResponse = serde_json::from_str(minimal_chat_response())
            .expect("Copilot response should deserialize");

        assert_eq!(response.id, "chatcmpl-123");
        assert_eq!(response.object, None);
        assert_eq!(response.created, None);
        assert_eq!(response.model, "gpt-4o");
        assert_eq!(response.choices.len(), 1);
    }

    // Claude-backed models may omit `finish_reason` and `index` entirely.
    #[test]
    fn deserialize_copilot_response_without_finish_reason() {
        let json = r#"{
            "id": "chatcmpl-claude-001",
            "model": "claude-3.5-sonnet",
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": "Here is my analysis."
                }
            }],
            "usage": {
                "prompt_tokens": 50,
                "total_tokens": 80
            }
        }"#;

        let response: ChatCompletionResponse =
            serde_json::from_str(json).expect("Claude-via-Copilot response should deserialize");

        assert_eq!(response.model, "claude-3.5-sonnet");
        assert_eq!(response.choices[0].finish_reason, None);
        assert_eq!(response.choices[0].index, 0);
    }

    #[test]
    fn error_response_with_message_field() {
        let json = r#"{"message": "rate limit exceeded"}"#;
        let err: ChatApiErrorResponse = serde_json::from_str(json).expect("message-shaped error");

        assert_eq!(err.error_message(), "rate limit exceeded");
    }

    #[test]
    fn error_response_with_error_field() {
        let json = r#"{"error": "model not found"}"#;
        let err: ChatApiErrorResponse = serde_json::from_str(json).expect("error-shaped error");

        assert_eq!(err.error_message(), "model not found");
    }

    // Routing is case-insensitive on "codex" and falls back to chat.
    #[test]
    fn routes_codex_models_to_responses() {
        assert_eq!(route_for_model("gpt-5.3-codex"), CompletionRoute::Responses);
        assert_eq!(
            route_for_model("gpt-5.1-CODEX-mini"),
            CompletionRoute::Responses
        );
        assert_eq!(route_for_model("gpt-5.2"), CompletionRoute::ChatCompletions);
        assert_eq!(
            route_for_model("claude-sonnet-4.5"),
            CompletionRoute::ChatCompletions
        );
    }

    #[tokio::test]
    async fn completion_model_routes_chat_requests_to_chat_completions() {
        let http_client = RecordingHttpClient::new(minimal_chat_response());
        let client = Client::builder()
            .api_key("copilot-token")
            .http_client(http_client.clone())
            .build()
            .expect("build client");
        let model = client.completion_model("gpt-4o");
        let request = model.completion_request("hello").build();

        let _response = model.completion(request).await.expect("chat completion");

        let requests = http_client.requests();
        assert_eq!(requests.len(), 1);
        assert!(requests[0].uri.ends_with("/chat/completions"));
        assert!(String::from_utf8_lossy(&requests[0].body).contains("\"model\":\"gpt-4o\""));
    }

    #[tokio::test]
    async fn completion_model_routes_codex_requests_to_responses() {
        let http_client = RecordingHttpClient::new(minimal_responses_response());
        let client = Client::builder()
            .api_key("copilot-token")
            .http_client(http_client.clone())
            .build()
            .expect("build client");
        let model = client.completion_model("gpt-5.3-codex");
        let request = model.completion_request("hello").build();

        let _response = model
            .completion(request)
            .await
            .expect("responses completion");

        let requests = http_client.requests();
        assert_eq!(requests.len(), 1);
        assert!(requests[0].uri.ends_with("/responses"));
        assert!(String::from_utf8_lossy(&requests[0].body).contains("\"model\":\"gpt-5.3-codex\""));
    }

    #[tokio::test]
    async fn embeddings_accept_minimal_copilot_response_shape() {
        use crate::client::EmbeddingsClient;
        use crate::embeddings::EmbeddingModel as _;

        let http_client = RecordingHttpClient::new(minimal_embeddings_response());
        let client = Client::builder()
            .api_key("copilot-token")
            .http_client(http_client.clone())
            .build()
            .expect("build client");
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);

        let embeddings = model
            .embed_texts(["one".to_string(), "two".to_string()])
            .await
            .expect("embeddings should deserialize");

        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].vec, vec![0.1, 0.2, 0.3]);
        assert_eq!(embeddings[1].vec, vec![0.4, 0.5, 0.6]);

        let requests = http_client.requests();
        assert_eq!(requests.len(), 1);
        assert!(requests[0].uri.ends_with("/embeddings"));
        assert!(
            String::from_utf8_lossy(&requests[0].body)
                .contains("\"model\":\"text-embedding-3-small\"")
        );
    }

    // After a `response.failed` event the stream must yield exactly one error
    // and then end — no trailing tool calls or final-response item.
    #[tokio::test]
    async fn responses_stream_terminates_after_terminal_error() {
        let tool_call_done = serde_json::json!({
            "type": "response.output_item.done",
            "sequence_number": 1,
            "item": {
                "type": "function_call",
                "id": "fc_123",
                "arguments": "{}",
                "call_id": "call_123",
                "name": "example_tool",
                "status": "completed"
            }
        });
        let failed = serde_json::json!({
            "type": "response.failed",
            "sequence_number": 2,
            "response": {
                "id": "resp_123",
                "object": "response",
                "created_at": 1700000000,
                "status": "failed",
                "error": {
                    "code": "server_error",
                    "message": "Copilot response stream failed"
                },
                "incomplete_details": null,
                "instructions": null,
                "max_output_tokens": null,
                "model": "gpt-5.3-codex",
                "usage": null,
                "output": [],
                "tools": []
            }
        });
        let http_client = MockStreamingClient {
            sse_bytes: sse_bytes_from_json_events(&[tool_call_done, failed]),
        };
        let client = Client::builder()
            .api_key("copilot-token")
            .http_client(http_client)
            .build()
            .expect("build client");
        let model = client.completion_model("gpt-5.3-codex");
        let request = model.completion_request("hello").build();
        let mut stream = model.stream(request).await.expect("stream should start");

        let err = match stream.next().await.expect("stream should yield an item") {
            Ok(_) => panic!("stream should surface a provider error"),
            Err(err) => err,
        };
        assert_eq!(
            err.to_string(),
            "ProviderError: Copilot response stream failed"
        );
        assert!(
            stream.next().await.is_none(),
            "responses stream should terminate immediately after a terminal error"
        );
    }

    // A transport-level failure mid-stream must surface as one error and then
    // end the stream, even after valid tool-call deltas were already emitted.
    #[tokio::test]
    async fn chat_stream_terminates_after_transport_error() {
        let chunks = vec![
            Ok(sse_bytes_from_data_lines([
                "{\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_123\",\"function\":{\"name\":\"ping\",\"arguments\":\"\"}}]},\"finish_reason\":null}],\"usage\":null}",
            ])),
            Err(http_client::Error::InvalidStatusCode(
                http::StatusCode::BAD_GATEWAY,
            )),
        ];

        let http_client = SequencedStreamingHttpClient::new(chunks);
        let client = Client::builder()
            .api_key("copilot-token")
            .http_client(http_client)
            .build()
            .expect("build client");
        let model = client.completion_model("gpt-4o");
        let request = model.completion_request("hello").build();
        let mut stream = model.stream(request).await.expect("stream should start");

        let mut saw_error = false;
        while let Some(item) = stream.next().await {
            match item {
                Ok(StreamedAssistantContent::ToolCallDelta { .. }) => {}
                Err(err) => {
                    assert_eq!(
                        err.to_string(),
                        "ProviderError: Invalid status code: 502 Bad Gateway"
                    );
                    saw_error = true;
                    break;
                }
                Ok(_) => panic!("unexpected non-error stream item before transport failure"),
            }
        }

        assert!(saw_error, "stream should surface the transport error");
        assert!(
            stream.next().await.is_none(),
            "chat stream should terminate immediately after a transport error"
        );
    }

    #[test]
    fn env_api_key_prefers_github_prefixed_vars() {
        let env = env_map(&[
            ("COPILOT_API_KEY", "copilot-key"),
            ("GITHUB_COPILOT_API_KEY", "github-key"),
            ("GITHUB_TOKEN", "bootstrap-token"),
        ]);
        let get = |name: &str| env.get(name).cloned();

        assert_eq!(env_api_key(&get).as_deref(), Some("github-key"));
    }

    #[test]
    fn env_github_access_token_prefers_explicit_bootstrap_var() {
        let env = env_map(&[
            ("COPILOT_GITHUB_ACCESS_TOKEN", "explicit-bootstrap"),
            ("GITHUB_TOKEN", "fallback-bootstrap"),
        ]);
        let get = |name: &str| env.get(name).cloned();

        assert_eq!(
            env_github_access_token(&get).as_deref(),
            Some("explicit-bootstrap")
        );
    }

    #[test]
    fn env_base_url_prefers_github_prefixed_vars() {
        let env = env_map(&[
            ("COPILOT_BASE_URL", "https://copilot.example"),
            ("GITHUB_COPILOT_API_BASE", "https://github.example"),
        ]);
        let get = |name: &str| env.get(name).cloned();

        assert_eq!(
            env_base_url(&get).as_deref(),
            Some("https://github.example")
        );
    }

    #[test]
    fn env_without_api_key_falls_back_to_oauth() {
        let env = env_map(&[("COPILOT_BASE_URL", "https://copilot.example")]);
        let get = |name: &str| env.get(name).cloned();

        assert!(env_api_key(&get).is_none());
        assert!(env_github_access_token(&get).is_none());
        assert_eq!(
            env_base_url(&get).as_deref(),
            Some("https://copilot.example")
        );
    }

    // GITHUB_TOKEN is a bootstrap credential, never a ready-to-use API key.
    #[test]
    fn env_github_token_is_not_treated_as_copilot_api_key() {
        let env = env_map(&[("GITHUB_TOKEN", "bootstrap-token")]);
        let get = |name: &str| env.get(name).cloned();

        assert!(env_api_key(&get).is_none());
        assert_eq!(
            env_github_access_token(&get).as_deref(),
            Some("bootstrap-token")
        );
    }
}