1mod auth;
24
25use crate::client::{
26 self, ApiKey, Capabilities, Capable, DebugExt, ModelLister, Nothing, Provider, ProviderBuilder,
27 ProviderClient, Transport,
28};
29use crate::completion::{self, CompletionError, GetTokenUsage};
30use crate::embeddings::{self, EmbeddingError};
31use crate::http_client::{self, HttpClientExt};
32use crate::model::{Model, ModelList, ModelListingError};
33use crate::providers::internal::openai_chat_completions_compatible::{
34 self, CompatibleChoiceData, CompatibleChunk, CompatibleFinishReason, CompatibleStreamProfile,
35 CompatibleToolCallChunk,
36};
37use crate::providers::openai;
38use crate::providers::openai::responses_api::{self, CompletionRequest as ResponsesRequest};
39use crate::streaming::{self, RawStreamingChoice, StreamingCompletionResponse};
40use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
41use async_stream::stream;
42use futures::StreamExt;
43use http::Request;
44use serde::{Deserialize, Serialize};
45use serde_json::json;
46use std::borrow::Cow;
47use std::collections::HashMap;
48use std::fmt::Debug;
49use std::path::{Path, PathBuf};
50use tracing::info_span;
51use tracing_futures::Instrument as _;
52
/// Default API base URL; exposed as `ProviderBuilder::BASE_URL` and compared against the
/// client's configured base URL in `runtime_base_url`.
const GITHUB_COPILOT_API_BASE_URL: &str = "https://api.githubcopilot.com";
/// Value sent in the `editor-plugin-version` request header.
const EDITOR_PLUGIN_VERSION: &str = "copilot-chat/0.26.7";
/// Value sent in the `user-agent` request header.
const USER_AGENT: &str = "GitHubCopilotChat/0.26.7";
/// Value sent in the `x-github-api-version` request header.
const API_VERSION: &str = "2025-04-01";
57
/// `gpt-4` model name.
pub const GPT_4: &str = "gpt-4";
/// `gpt-4o` model name.
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` model name.
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4.1` model name.
pub const GPT_4_1: &str = "gpt-4.1";
/// `gpt-4.1-mini` model name.
pub const GPT_4_1_MINI: &str = "gpt-4.1-mini";
/// `gpt-4.1-nano` model name.
pub const GPT_4_1_NANO: &str = "gpt-4.1-nano";
/// `gpt-5.3-codex` model name (contains `codex`, so `route_for_model` selects the Responses route).
pub const GPT_5_3_CODEX: &str = "gpt-5.3-codex";
/// `gpt-5.1-codex` model name (contains `codex`, so `route_for_model` selects the Responses route).
pub const GPT_5_1_CODEX: &str = "gpt-5.1-codex";
/// `gpt-5.5` model name.
pub const GPT_5_5: &str = "gpt-5.5";
/// `gpt-5.4` model name.
pub const GPT_5_4: &str = "gpt-5.4";
/// `claude-sonnet-4` model name.
pub const CLAUDE_SONNET_4: &str = "claude-sonnet-4";
/// `claude-sonnet-4.6` model name.
pub const CLAUDE_SONNET_4_6: &str = "claude-sonnet-4.6";
/// `claude-opus-4.6` model name.
pub const CLAUDE_OPUS_4_6: &str = "claude-opus-4.6";
/// `claude-opus-4.7` model name.
pub const CLAUDE_OPUS_4_7: &str = "claude-opus-4.7";
/// `claude-3.5-sonnet` model name.
pub const CLAUDE_3_5_SONNET: &str = "claude-3.5-sonnet";
/// `gemini-3-flash-preview` model name.
pub const GEMINI_3_FLASH: &str = "gemini-3-flash-preview";
/// `gemini-3.1-pro-preview` model name.
///
/// NOTE(review): the constant is named `..._PRO_FLASH` but the value is the `pro-preview`
/// identifier — confirm which one is intended.
pub const GEMINI_3_1_PRO_FLASH: &str = "gemini-3.1-pro-preview";
/// `gemini-2.0-flash-001` model name.
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash-001";
/// `o3-mini` model name.
pub const O3_MINI: &str = "o3-mini";
/// `text-embedding-3-small` embedding model name.
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-3-large` embedding model name.
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-ada-002` embedding model name.
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
102
103pub use openai::EncodingFormat;
104
/// Credential choice supplied to the client builder.
///
/// `ProviderBuilder::build` maps each variant onto the matching `auth::AuthSource`.
/// The hand-written `Debug` impl redacts the secret payloads.
#[derive(Clone)]
pub enum CopilotAuth {
    /// A ready-to-use API key; also what `From<impl Into<String>>` produces.
    ApiKey(String),
    /// A GitHub access token, forwarded as `auth::AuthSource::GitHubAccessToken`.
    GitHubAccessToken(String),
    /// No secret supplied up front; forwarded as `auth::AuthSource::OAuth`.
    OAuth,
}
111
// Marker impl: `CopilotAuth` is the credential type named by `ProviderBuilder::ApiKey`.
impl ApiKey for CopilotAuth {}
113
114impl<S> From<S> for CopilotAuth
115where
116 S: Into<String>,
117{
118 fn from(value: S) -> Self {
119 Self::ApiKey(value.into())
120 }
121}
122
123impl Debug for CopilotAuth {
124 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125 match self {
126 Self::ApiKey(_) => f.write_str("ApiKey(<redacted>)"),
127 Self::GitHubAccessToken(_) => f.write_str("GitHubAccessToken(<redacted>)"),
128 Self::OAuth => f.write_str("OAuth"),
129 }
130 }
131}
132
/// Builder-side configuration that `ProviderBuilder::build` hands to `auth::Authenticator::new`.
#[derive(Debug, Clone)]
pub struct CopilotBuilder {
    /// Path of the `access-token` file; `None` when no token directory is available.
    access_token_file: Option<PathBuf>,
    /// Path of the `api-key.json` file; `None` when no token directory is available.
    api_key_file: Option<PathBuf>,
    /// Handler receiving the `auth::DeviceCodePrompt`; replaceable via `on_device_code`.
    device_code_handler: auth::DeviceCodeHandler,
}
139
/// Provider extension stored inside the client; holds the authenticator used to obtain an
/// `auth::AuthContext` before each request.
#[derive(Clone)]
pub struct CopilotExt {
    /// Authenticator constructed in `ProviderBuilder::build`.
    auth: auth::Authenticator,
}
144
145impl Debug for CopilotExt {
146 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147 f.debug_struct("CopilotExt")
148 .field("auth", &self.auth)
149 .finish()
150 }
151}
152
/// Copilot client: the generic `client::Client` specialised with [`CopilotExt`].
/// The HTTP backend defaults to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<CopilotExt, H>;
/// Builder for [`Client`], using [`CopilotBuilder`] as extension and [`CopilotAuth`] as
/// credential type. The HTTP backend marker defaults to `crate::markers::Missing`.
pub type ClientBuilder<H = crate::markers::Missing> =
    client::ClientBuilder<CopilotBuilder, CopilotAuth, H>;
156
157impl Default for CopilotBuilder {
158 fn default() -> Self {
159 let token_dir = default_token_dir();
160 Self {
161 access_token_file: token_dir.as_ref().map(|dir| dir.join("access-token")),
162 api_key_file: token_dir.map(|dir| dir.join("api-key.json")),
163 device_code_handler: auth::DeviceCodeHandler::default(),
164 }
165 }
166}
167
/// Registers [`CopilotExt`] as a provider whose builder is [`CopilotBuilder`].
impl Provider for CopilotExt {
    type Builder = CopilotBuilder;

    // Left empty here. NOTE(review): presumably means no verification endpoint is
    // configured for this provider — confirm against the `Provider` trait.
    const VERIFY_PATH: &'static str = "";
}
173
/// Capability table: completion, embeddings and model listing are provided;
/// transcription, image generation and audio generation are marked `Nothing`.
impl<H> Capabilities<H> for CopilotExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type Transcription = Nothing;
    type ModelListing = Capable<CopilotModelLister<H>>;
    // Only present when the `image` feature is enabled.
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    // Only present when the `audio` feature is enabled.
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
184
// Empty impl: relies entirely on whatever defaults the `DebugExt` trait provides.
impl DebugExt for CopilotExt {}
186
187impl ProviderBuilder for CopilotBuilder {
188 type Extension<H>
189 = CopilotExt
190 where
191 H: HttpClientExt;
192 type ApiKey = CopilotAuth;
193
194 const BASE_URL: &'static str = GITHUB_COPILOT_API_BASE_URL;
195
196 fn build<H>(
197 builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
198 ) -> http_client::Result<Self::Extension<H>>
199 where
200 H: HttpClientExt,
201 {
202 let auth = match builder.get_api_key() {
203 CopilotAuth::ApiKey(api_key) => auth::AuthSource::ApiKey(api_key.clone()),
204 CopilotAuth::GitHubAccessToken(access_token) => {
205 auth::AuthSource::GitHubAccessToken(access_token.clone())
206 }
207 CopilotAuth::OAuth => auth::AuthSource::OAuth,
208 };
209
210 let ext = builder.ext();
211 Ok(CopilotExt {
212 auth: auth::Authenticator::new(
213 auth,
214 ext.access_token_file.clone(),
215 ext.api_key_file.clone(),
216 ext.device_code_handler.clone(),
217 ),
218 })
219 }
220}
221
222impl ProviderClient for Client {
223 type Input = CopilotAuth;
224 type Error = crate::client::ProviderClientError;
225
226 fn from_env() -> Result<Self, Self::Error> {
227 let mut builder = Self::builder();
228 fn get(name: &str) -> Option<String> {
229 std::env::var(name).ok()
230 }
231
232 if let Some(base_url) = env_base_url(&get) {
233 builder = builder.base_url(base_url);
234 }
235
236 if let Some(api_key) = env_api_key(&get) {
237 builder.api_key(api_key).build().map_err(Into::into)
238 } else if let Some(access_token) = env_github_access_token(&get) {
239 builder
240 .github_access_token(access_token)
241 .build()
242 .map_err(Into::into)
243 } else {
244 builder.oauth().build().map_err(Into::into)
245 }
246 }
247
248 fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
249 Self::builder().api_key(input).build().map_err(Into::into)
250 }
251}
252
253impl<H> client::ClientBuilder<CopilotBuilder, crate::markers::Missing, H> {
254 pub fn github_access_token(
255 self,
256 access_token: impl Into<String>,
257 ) -> client::ClientBuilder<CopilotBuilder, CopilotAuth, H> {
258 self.api_key(CopilotAuth::GitHubAccessToken(access_token.into()))
259 }
260
261 pub fn oauth(self) -> client::ClientBuilder<CopilotBuilder, CopilotAuth, H> {
262 self.api_key(CopilotAuth::OAuth)
263 }
264}
265
266impl<H> ClientBuilder<H> {
267 pub fn on_device_code<F>(self, handler: F) -> Self
268 where
269 F: Fn(auth::DeviceCodePrompt) + Send + Sync + 'static,
270 {
271 self.over_ext(|mut ext| {
272 ext.device_code_handler = auth::DeviceCodeHandler::new(handler);
273 ext
274 })
275 }
276
277 pub fn token_dir(self, path: impl AsRef<Path>) -> Self {
278 let path = path.as_ref();
279 self.over_ext(|mut ext| {
280 ext.access_token_file = Some(path.join("access-token"));
281 ext.api_key_file = Some(path.join("api-key.json"));
282 ext
283 })
284 }
285
286 pub fn access_token_file(self, path: impl AsRef<Path>) -> Self {
287 let path = path.as_ref().to_path_buf();
288 self.over_ext(|mut ext| {
289 ext.access_token_file = Some(path);
290 ext
291 })
292 }
293
294 pub fn api_key_file(self, path: impl AsRef<Path>) -> Self {
295 let path = path.as_ref().to_path_buf();
296 self.over_ext(|mut ext| {
297 ext.api_key_file = Some(path);
298 ext
299 })
300 }
301}
302
/// Looks up `name` through `get`, treating an empty or whitespace-only value as absent.
/// A non-blank value is returned unchanged (it is not trimmed).
fn env_value<F>(get: &F, name: &str) -> Option<String>
where
    F: Fn(&str) -> Option<String>,
{
    match get(name) {
        Some(raw) if raw.trim().is_empty() => None,
        found => found,
    }
}
309
310fn first_env_value<F>(get: &F, keys: &[&str]) -> Option<String>
311where
312 F: Fn(&str) -> Option<String>,
313{
314 keys.iter().find_map(|key| env_value(get, key))
315}
316
317fn env_api_key<F>(get: &F) -> Option<String>
318where
319 F: Fn(&str) -> Option<String>,
320{
321 first_env_value(get, &["GITHUB_COPILOT_API_KEY", "COPILOT_API_KEY"])
322}
323
324fn env_github_access_token<F>(get: &F) -> Option<String>
325where
326 F: Fn(&str) -> Option<String>,
327{
328 first_env_value(get, &["COPILOT_GITHUB_ACCESS_TOKEN", "GITHUB_TOKEN"])
329}
330
331fn env_base_url<F>(get: &F) -> Option<String>
332where
333 F: Fn(&str) -> Option<String>,
334{
335 first_env_value(get, &["GITHUB_COPILOT_API_BASE", "COPILOT_BASE_URL"])
336}
337
338impl<H> Client<H>
339where
340 H: HttpClientExt + Clone + Debug + Default + WasmCompatSend + WasmCompatSync + 'static,
341{
342 pub async fn authorize(&self) -> Result<(), auth::AuthError> {
343 self.ext().auth.auth_context().await.map(|_| ())
344 }
345}
346
347fn default_headers(
348 api_key: &str,
349 initiator: &'static str,
350 has_vision: bool,
351) -> Vec<(&'static str, String)> {
352 let mut headers = vec![
353 (
354 http::header::AUTHORIZATION.as_str(),
355 format!("Bearer {api_key}"),
356 ),
357 ("copilot-integration-id", "vscode-chat".to_string()),
358 ("editor-version", "vscode/1.95.0".to_string()),
359 ("editor-plugin-version", EDITOR_PLUGIN_VERSION.to_string()),
360 ("user-agent", USER_AGENT.to_string()),
361 ("openai-intent", "conversation-panel".to_string()),
362 ("x-github-api-version", API_VERSION.to_string()),
363 ("x-request-id", nanoid::nanoid!()),
364 (
365 "x-vscode-user-agent-library-version",
366 "electron-fetch".to_string(),
367 ),
368 ("X-Initiator", initiator.to_string()),
369 ];
370
371 if has_vision {
372 headers.push(("copilot-vision-request", "true".to_string()));
373 }
374
375 headers
376}
377
378fn apply_headers(
379 builder: http_client::Builder,
380 headers: &[(&'static str, String)],
381) -> http_client::Builder {
382 headers
383 .iter()
384 .fold(builder, |builder, (key, value)| builder.header(*key, value))
385}
386
387fn runtime_base_url<'a, H>(client: &'a Client<H>, auth: &'a auth::AuthContext) -> Cow<'a, str> {
388 if client.base_url() == GITHUB_COPILOT_API_BASE_URL {
389 auth.api_base
390 .as_deref()
391 .map(Cow::Borrowed)
392 .unwrap_or_else(|| Cow::Borrowed(client.base_url()))
393 } else {
394 Cow::Borrowed(client.base_url())
395 }
396}
397
398fn post_with_auth_base<H>(
399 client: &Client<H>,
400 auth: &auth::AuthContext,
401 path: &str,
402 transport: Transport,
403) -> http_client::Result<http_client::Builder> {
404 let uri = client
405 .ext()
406 .build_uri(runtime_base_url(client, auth).as_ref(), path, transport);
407 let mut req = Request::post(uri);
408
409 if let Some(headers) = req.headers_mut() {
410 headers.extend(client.headers().iter().map(|(k, v)| (k.clone(), v.clone())));
411 }
412
413 client.ext().with_custom(req)
414}
415
416fn get_with_auth_base<H>(
417 client: &Client<H>,
418 auth: &auth::AuthContext,
419 path: &str,
420 transport: Transport,
421) -> http_client::Result<http_client::Builder> {
422 let uri = client
423 .ext()
424 .build_uri(runtime_base_url(client, auth).as_ref(), path, transport);
425 let mut req = Request::get(uri);
426
427 if let Some(headers) = req.headers_mut() {
428 headers.extend(client.headers().iter().map(|(k, v)| (k.clone(), v.clone())));
429 }
430
431 client.ext().with_custom(req)
432}
433
434fn request_initiator(request: &completion::CompletionRequest) -> &'static str {
435 for message in request.chat_history.iter() {
436 match message {
437 crate::completion::Message::Assistant { .. } => return "agent",
438 crate::completion::Message::User { content } => {
439 if content
440 .iter()
441 .any(|item| matches!(item, crate::message::UserContent::ToolResult(_)))
442 {
443 return "agent";
444 }
445 }
446 crate::completion::Message::System { .. } => {}
447 }
448 }
449
450 "user"
451}
452
453fn request_has_vision(request: &completion::CompletionRequest) -> bool {
454 request.chat_history.iter().any(|message| match message {
455 crate::completion::Message::User { content } => content
456 .iter()
457 .any(|item| matches!(item, crate::message::UserContent::Image(_))),
458 _ => false,
459 })
460}
461
/// Which Copilot endpoint family serves a completion request.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum CompletionRoute {
    /// The `/chat/completions` endpoint.
    ChatCompletions,
    /// The `/responses` endpoint.
    Responses,
}

/// Selects the route for `model`: any name containing `codex` (ASCII case-insensitive)
/// goes to the Responses route, everything else to chat completions.
fn route_for_model(model: &str) -> CompletionRoute {
    let is_codex = model.to_ascii_lowercase().contains("codex");
    match is_codex {
        true => CompletionRoute::Responses,
        false => CompletionRoute::ChatCompletions,
    }
}
475
/// Raw non-streaming response, tagged by the API that produced it.
///
/// Serialized with an internal `api` tag whose value is the snake_case variant name.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "api", rename_all = "snake_case")]
pub enum CopilotCompletionResponse {
    /// Response from the `/chat/completions` route.
    Chat(ChatCompletionResponse),
    /// Response from the `/responses` route (boxed).
    Responses(Box<responses_api::CompletionResponse>),
}
482
/// Final streaming response, tagged by the API that produced it.
///
/// Serialized with an internal `api` tag whose value is the snake_case variant name.
#[derive(Clone, Serialize, Deserialize)]
#[serde(tag = "api", rename_all = "snake_case")]
pub enum CopilotStreamingResponse {
    /// Final response of a chat-completions stream.
    Chat(openai::completion::streaming::StreamingCompletionResponse),
    /// Final response of a Responses-API stream.
    Responses(responses_api::streaming::StreamingCompletionResponse),
}
489
490impl GetTokenUsage for CopilotStreamingResponse {
491 fn token_usage(&self) -> Option<completion::Usage> {
492 match self {
493 Self::Chat(response) => response.token_usage(),
494 Self::Responses(response) => response.token_usage(),
495 }
496 }
497}
498
/// Body of a successful non-streaming `/chat/completions` reply, as parsed by
/// `CompletionModel::completion_chat`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
    /// Response identifier; recorded as `gen_ai.response.id` on the tracing span.
    pub id: String,
    /// Optional `object` field; defaults to `None` when absent.
    #[serde(default)]
    pub object: Option<String>,
    /// Optional `created` field; defaults to `None` when absent.
    #[serde(default)]
    pub created: Option<u64>,
    /// Model name reported by the provider; recorded as `gen_ai.response.model`.
    pub model: String,
    /// Optional `system_fingerprint` field.
    pub system_fingerprint: Option<String>,
    /// Returned choices; only the first one is converted into the core response.
    pub choices: Vec<ChatChoice>,
    /// Token accounting, when the provider includes it.
    pub usage: Option<openai::completion::Usage>,
}
511
/// One entry of [`ChatCompletionResponse::choices`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ChatChoice {
    /// Choice index; defaults to `0` when absent.
    #[serde(default)]
    pub index: usize,
    /// The message for this choice; conversion expects the assistant variant.
    pub message: openai::completion::Message,
    /// Optional `logprobs` payload, kept as untyped JSON.
    pub logprobs: Option<serde_json::Value>,
    /// Optional `finish_reason`; defaults to `None` when absent.
    #[serde(default)]
    pub finish_reason: Option<String>,
}
521
522impl TryFrom<ChatCompletionResponse> for completion::CompletionResponse<ChatCompletionResponse> {
523 type Error = CompletionError;
524
525 fn try_from(response: ChatCompletionResponse) -> Result<Self, Self::Error> {
526 let choice = response.choices.first().ok_or_else(|| {
527 CompletionError::ResponseError("Response contained no choices".to_owned())
528 })?;
529
530 let content = match &choice.message {
531 openai::completion::Message::Assistant {
532 content,
533 tool_calls,
534 ..
535 } => {
536 let mut content = content
537 .iter()
538 .filter_map(|c| {
539 let s = match c {
540 openai::completion::AssistantContent::Text { text } => text,
541 openai::completion::AssistantContent::Refusal { refusal } => refusal,
542 };
543 if s.is_empty() {
544 None
545 } else {
546 Some(completion::AssistantContent::text(s))
547 }
548 })
549 .collect::<Vec<_>>();
550
551 content.extend(
552 tool_calls
553 .iter()
554 .map(|call| {
555 completion::AssistantContent::tool_call(
556 &call.id,
557 &call.function.name,
558 call.function.arguments.clone(),
559 )
560 })
561 .collect::<Vec<_>>(),
562 );
563 Ok(content)
564 }
565 _ => Err(CompletionError::ResponseError(
566 "Response did not contain a valid message or tool call".into(),
567 )),
568 }?;
569
570 let choice = crate::OneOrMany::many(content).map_err(|_| {
571 CompletionError::ResponseError(
572 "Response contained no message or tool call (empty)".to_owned(),
573 )
574 })?;
575
576 let usage = response
577 .usage
578 .as_ref()
579 .map(|usage| completion::Usage {
580 input_tokens: usage.prompt_tokens as u64,
581 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
582 total_tokens: usage.total_tokens as u64,
583 cached_input_tokens: usage
584 .prompt_tokens_details
585 .as_ref()
586 .map(|d| d.cached_tokens as u64)
587 .unwrap_or(0),
588 cache_creation_input_tokens: 0,
589 tool_use_prompt_tokens: 0,
590 reasoning_tokens: 0,
591 })
592 .unwrap_or_default();
593
594 Ok(completion::CompletionResponse {
595 choice,
596 usage,
597 raw_response: response,
598 message_id: None,
599 })
600 }
601}
602
/// Error-shaped body returned by the chat endpoint; both fields are optional and default
/// to `None` when absent.
#[derive(Debug, Deserialize)]
pub struct ChatApiErrorResponse {
    /// Error text carried in a `message` field, if any.
    #[serde(default)]
    pub message: Option<String>,
    /// Error text carried in an `error` field, if any.
    #[serde(default)]
    pub error: Option<String>,
}
610
611impl ChatApiErrorResponse {
612 pub fn error_message(&self) -> &str {
613 self.message
614 .as_deref()
615 .or(self.error.as_deref())
616 .unwrap_or("unknown error")
617 }
618}
619
/// Body of a 2xx chat reply, which may still be an error payload.
///
/// Untagged: serde tries the variants in declaration order, so `Ok(T)` is attempted
/// before `Err`. Because every field of [`ChatApiErrorResponse`] is optional, `Err` acts
/// as the fallback for bodies that do not parse as `T`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ChatApiResponse<T> {
    /// The expected success payload.
    Ok(T),
    /// An error payload.
    Err(ChatApiErrorResponse),
}
626
/// Completion model bound to a Copilot [`Client`] and a model name.
#[derive(Clone)]
pub struct CompletionModel<H = reqwest::Client> {
    /// Client used for authentication and for sending requests.
    client: Client<H>,
    /// Model name; also decides the route via `route_for_model`.
    pub model: String,
    /// Forwarded as `strict_tools` when building chat-completions requests.
    pub strict_tools: bool,
    /// Forwarded as `tool_result_array_content` when building chat-completions requests.
    pub tool_result_array_content: bool,
}
634
635impl<H> CompletionModel<H>
636where
637 Client<H>: HttpClientExt + Clone + Debug + 'static,
638 H: Clone + Default + Debug + WasmCompatSend + WasmCompatSync + 'static,
639{
640 pub fn new(client: Client<H>, model: impl Into<String>) -> Self {
641 Self {
642 client,
643 model: model.into(),
644 strict_tools: false,
645 tool_result_array_content: false,
646 }
647 }
648
649 pub fn with_strict_tools(mut self) -> Self {
650 self.strict_tools = true;
651 self
652 }
653
654 pub fn with_tool_result_array_content(mut self) -> Self {
655 self.tool_result_array_content = true;
656 self
657 }
658
659 fn route(&self) -> CompletionRoute {
660 route_for_model(&self.model)
661 }
662
663 async fn auth_context(&self) -> Result<auth::AuthContext, CompletionError> {
664 self.client
665 .ext()
666 .auth
667 .auth_context()
668 .await
669 .map_err(|err| CompletionError::ProviderError(err.to_string()))
670 }
671
672 fn chat_request(
673 &self,
674 completion_request: completion::CompletionRequest,
675 ) -> Result<openai::completion::CompletionRequest, CompletionError> {
676 openai::completion::CompletionRequest::try_from(openai::completion::OpenAIRequestParams {
677 model: self.model.clone(),
678 request: completion_request,
679 strict_tools: self.strict_tools,
680 tool_result_array_content: self.tool_result_array_content,
681 })
682 }
683
684 fn responses_request(
685 &self,
686 completion_request: completion::CompletionRequest,
687 ) -> Result<ResponsesRequest, CompletionError> {
688 ResponsesRequest::try_from((self.model.clone(), completion_request))
689 }
690
691 async fn completion_chat(
692 &self,
693 completion_request: completion::CompletionRequest,
694 ) -> Result<completion::CompletionResponse<CopilotCompletionResponse>, CompletionError> {
695 let initiator = request_initiator(&completion_request);
696 let has_vision = request_has_vision(&completion_request);
697 let request = self.chat_request(completion_request)?;
698 let body = serde_json::to_vec(&request)?;
699 let auth = self.auth_context().await?;
700
701 let headers = default_headers(&auth.api_key, initiator, has_vision);
702 let req = apply_headers(
703 post_with_auth_base(&self.client, &auth, "/chat/completions", Transport::Http)?,
704 &headers,
705 )
706 .body(body)
707 .map_err(|err| CompletionError::HttpError(err.into()))?;
708
709 let span = if tracing::Span::current().is_disabled() {
710 info_span!(
711 target: "rig::completions",
712 "chat",
713 gen_ai.operation.name = "chat",
714 gen_ai.provider.name = "copilot",
715 gen_ai.request.model = self.model,
716 gen_ai.response.id = tracing::field::Empty,
717 gen_ai.response.model = tracing::field::Empty,
718 gen_ai.usage.output_tokens = tracing::field::Empty,
719 gen_ai.usage.input_tokens = tracing::field::Empty,
720 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
721 )
722 } else {
723 tracing::Span::current()
724 };
725
726 async move {
727 let response = self.client.send(req).await?;
728
729 if response.status().is_success() {
730 let body = http_client::text(response).await?;
731 match serde_json::from_str::<ChatApiResponse<ChatCompletionResponse>>(&body)? {
732 ChatApiResponse::Ok(response) => {
733 let core = completion::CompletionResponse::try_from(response.clone())?;
734 let span = tracing::Span::current();
735 span.record("gen_ai.response.id", response.id.as_str());
736 span.record("gen_ai.response.model", response.model.as_str());
737 if let Some(usage) = &response.usage {
738 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
739 span.record(
740 "gen_ai.usage.output_tokens",
741 usage.total_tokens - usage.prompt_tokens,
742 );
743 span.record(
744 "gen_ai.usage.cache_read.input_tokens",
745 usage
746 .prompt_tokens_details
747 .as_ref()
748 .map(|details| details.cached_tokens)
749 .unwrap_or(0),
750 );
751 }
752
753 Ok(completion::CompletionResponse {
754 choice: core.choice,
755 usage: core.usage,
756 raw_response: CopilotCompletionResponse::Chat(response),
757 message_id: core.message_id,
758 })
759 }
760 ChatApiResponse::Err(err) => Err(CompletionError::ProviderError(
761 err.error_message().to_string(),
762 )),
763 }
764 } else {
765 let body = http_client::text(response).await?;
766 Err(CompletionError::ProviderError(body))
767 }
768 }
769 .instrument(span)
770 .await
771 }
772
773 async fn completion_responses(
774 &self,
775 completion_request: completion::CompletionRequest,
776 ) -> Result<completion::CompletionResponse<CopilotCompletionResponse>, CompletionError> {
777 let initiator = request_initiator(&completion_request);
778 let has_vision = request_has_vision(&completion_request);
779 let request = self.responses_request(completion_request)?;
780 let auth = self.auth_context().await?;
781
782 let headers = default_headers(&auth.api_key, initiator, has_vision);
783 let req = apply_headers(
784 post_with_auth_base(&self.client, &auth, "/responses", Transport::Http)?,
785 &headers,
786 )
787 .body(serde_json::to_vec(&request)?)
788 .map_err(|err| CompletionError::HttpError(err.into()))?;
789
790 let span = if tracing::Span::current().is_disabled() {
791 info_span!(
792 target: "rig::completions",
793 "chat",
794 gen_ai.operation.name = "chat",
795 gen_ai.provider.name = "copilot",
796 gen_ai.request.model = self.model,
797 gen_ai.response.id = tracing::field::Empty,
798 gen_ai.response.model = tracing::field::Empty,
799 gen_ai.usage.output_tokens = tracing::field::Empty,
800 gen_ai.usage.input_tokens = tracing::field::Empty,
801 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
802 )
803 } else {
804 tracing::Span::current()
805 };
806
807 async move {
808 let response = self.client.send(req).await?;
809 if response.status().is_success() {
810 let body = http_client::text(response).await?;
811 let response = serde_json::from_str::<responses_api::CompletionResponse>(&body)?;
812 let core = completion::CompletionResponse::try_from(response.clone())?;
813
814 let span = tracing::Span::current();
815 span.record("gen_ai.response.id", response.id.as_str());
816 span.record("gen_ai.response.model", response.model.as_str());
817 if let Some(usage) = &response.usage {
818 span.record("gen_ai.usage.input_tokens", usage.input_tokens);
819 span.record("gen_ai.usage.output_tokens", usage.output_tokens);
820 span.record(
821 "gen_ai.usage.cache_read.input_tokens",
822 usage
823 .input_tokens_details
824 .as_ref()
825 .map(|details| details.cached_tokens)
826 .unwrap_or(0),
827 );
828 }
829
830 Ok(completion::CompletionResponse {
831 choice: core.choice,
832 usage: core.usage,
833 raw_response: CopilotCompletionResponse::Responses(Box::new(response)),
834 message_id: core.message_id,
835 })
836 } else {
837 let body = http_client::text(response).await?;
838 Err(CompletionError::ProviderError(body))
839 }
840 }
841 .instrument(span)
842 .await
843 }
844
845 async fn stream_chat(
846 &self,
847 completion_request: completion::CompletionRequest,
848 ) -> Result<StreamingCompletionResponse<CopilotStreamingResponse>, CompletionError> {
849 let initiator = request_initiator(&completion_request);
850 let has_vision = request_has_vision(&completion_request);
851 let request = self.chat_request(completion_request)?;
852 let auth = self.auth_context().await?;
853 let headers = default_headers(&auth.api_key, initiator, has_vision);
854 let mut request_json = serde_json::to_value(&request)?;
855 let request_object = request_json.as_object_mut().ok_or_else(|| {
856 CompletionError::ResponseError("copilot request body must be a JSON object".into())
857 })?;
858 request_object.insert("stream".to_owned(), json!(true));
859 request_object.insert(
860 "stream_options".to_owned(),
861 json!({ "include_usage": true }),
862 );
863
864 let req = apply_headers(
865 post_with_auth_base(&self.client, &auth, "/chat/completions", Transport::Sse)?,
866 &headers,
867 )
868 .body(serde_json::to_vec(&request_json)?)
869 .map_err(|err| CompletionError::HttpError(err.into()))?;
870
871 let span = if tracing::Span::current().is_disabled() {
872 info_span!(
873 target: "rig::completions",
874 "chat_streaming",
875 gen_ai.operation.name = "chat_streaming",
876 gen_ai.provider.name = "copilot",
877 gen_ai.request.model = self.model,
878 gen_ai.response.id = tracing::field::Empty,
879 gen_ai.response.model = tracing::field::Empty,
880 gen_ai.usage.output_tokens = tracing::field::Empty,
881 gen_ai.usage.input_tokens = tracing::field::Empty,
882 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
883 )
884 } else {
885 tracing::Span::current()
886 };
887
888 tracing::Instrument::instrument(
889 send_copilot_chat_streaming_request(self.client.clone(), req),
890 span,
891 )
892 .await
893 }
894
    /// Starts a streaming request against `/responses` and adapts its SSE events into
    /// `RawStreamingChoice` items.
    ///
    /// Text, refusal, reasoning and tool-call argument deltas are yielded as they arrive.
    /// Completed function calls are buffered and yielded only after the event stream ends,
    /// followed by a single `FinalResponse` carrying the usage from the completed response.
    /// If the stream terminates with an error, the error is yielded and neither the
    /// buffered tool calls nor the final response are emitted.
    ///
    /// # Errors
    /// Request building, serialization and auth failures are returned directly; failures
    /// during streaming are delivered as `Err` items of the returned stream.
    async fn stream_responses(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<StreamingCompletionResponse<CopilotStreamingResponse>, CompletionError> {
        let initiator = request_initiator(&completion_request);
        let has_vision = request_has_vision(&completion_request);
        let mut request = self.responses_request(completion_request)?;
        request.stream = Some(true);
        let auth = self.auth_context().await?;

        let headers = default_headers(&auth.api_key, initiator, has_vision);
        let req = apply_headers(
            post_with_auth_base(&self.client, &auth, "/responses", Transport::Sse)?,
            &headers,
        )
        .body(serde_json::to_vec(&request)?)
        .map_err(|err| CompletionError::HttpError(err.into()))?;

        // Reuse the caller's span when one is active; otherwise open a dedicated one.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "copilot",
                gen_ai.request.model = self.model,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let client = self.client.clone();
        let mut event_source = crate::http_client::sse::GenericEventSource::new(client, req);

        let stream = tracing_futures::Instrument::instrument(
            stream! {
                let mut final_usage = responses_api::ResponsesUsage::new();
                // Completed tool calls, held back until the event loop finishes.
                let mut tool_calls: Vec<streaming::RawStreamingChoice<CopilotStreamingResponse>> = Vec::new();
                // Maps a provider item id to a locally generated id so that all deltas and
                // the final call for the same item share one `internal_call_id`.
                let mut tool_call_internal_ids: HashMap<String, String> = HashMap::new();
                let span = tracing::Span::current();

                let mut terminated_with_error = false;

                while let Some(event_result) = event_source.next().await {
                    match event_result {
                        Ok(crate::http_client::sse::Event::Open) => continue,
                        Ok(crate::http_client::sse::Event::Message(evt)) => {
                            if evt.data.trim().is_empty() {
                                continue;
                            }

                            // Events that do not parse as a known chunk are skipped silently.
                            let Ok(data) = serde_json::from_str::<responses_api::streaming::StreamingCompletionChunk>(&evt.data) else {
                                continue;
                            };

                            if let responses_api::streaming::StreamingCompletionChunk::Delta(chunk) = &data {
                                use responses_api::streaming::{ItemChunkKind, StreamingItemDoneOutput};

                                match &chunk.data {
                                    // A new function-call item announces the tool name.
                                    ItemChunkKind::OutputItemAdded(message) => {
                                        if let StreamingItemDoneOutput { item: responses_api::Output::FunctionCall(func), .. } = message {
                                            let internal_call_id = tool_call_internal_ids
                                                .entry(func.id.clone())
                                                .or_insert_with(|| nanoid::nanoid!())
                                                .clone();
                                            yield Ok(RawStreamingChoice::ToolCallDelta {
                                                id: func.id.clone(),
                                                internal_call_id,
                                                content: streaming::ToolCallDeltaContent::Name(func.name.clone()),
                                            });
                                        }
                                    }
                                    ItemChunkKind::OutputItemDone(message) => match message {
                                        // Finished function call: buffer it for emission after the loop.
                                        StreamingItemDoneOutput { item: responses_api::Output::FunctionCall(func), .. } => {
                                            let internal_id = tool_call_internal_ids
                                                .entry(func.id.clone())
                                                .or_insert_with(|| nanoid::nanoid!())
                                                .clone();
                                            let raw_tool_call = streaming::RawStreamingToolCall::new(
                                                func.id.clone(),
                                                func.name.clone(),
                                                func.arguments.clone(),
                                            )
                                                .with_internal_call_id(internal_id)
                                                .with_call_id(func.call_id.clone());
                                            tool_calls.push(RawStreamingChoice::ToolCall(raw_tool_call));
                                        }
                                        // Finished reasoning item: forward only the reasoning choices.
                                        StreamingItemDoneOutput { item: responses_api::Output::Reasoning { summary, id, encrypted_content, .. }, .. } => {
                                            for reasoning_choice in responses_api::streaming::reasoning_choices_from_done_item(
                                                id,
                                                summary,
                                                encrypted_content.as_deref(),
                                            ) {
                                                match reasoning_choice {
                                                    RawStreamingChoice::Reasoning { id, content } => {
                                                        yield Ok(RawStreamingChoice::Reasoning { id, content });
                                                    }
                                                    RawStreamingChoice::ReasoningDelta { id, reasoning } => {
                                                        yield Ok(RawStreamingChoice::ReasoningDelta { id, reasoning });
                                                    }
                                                    _ => {}
                                                }
                                            }
                                        }
                                        // Finished message item: report its id.
                                        StreamingItemDoneOutput { item: responses_api::Output::Message(msg), .. } => {
                                            yield Ok(RawStreamingChoice::MessageId(msg.id.clone()));
                                        }
                                        StreamingItemDoneOutput { item: responses_api::Output::Unknown, .. } => {}
                                    },
                                    ItemChunkKind::OutputTextDelta(delta) => {
                                        yield Ok(RawStreamingChoice::Message(delta.delta.clone()))
                                    }
                                    ItemChunkKind::ReasoningSummaryTextDelta(delta) => {
                                        yield Ok(RawStreamingChoice::ReasoningDelta { id: None, reasoning: delta.delta.clone() })
                                    }
                                    // Refusal text is surfaced as ordinary message text.
                                    ItemChunkKind::RefusalDelta(delta) => {
                                        yield Ok(RawStreamingChoice::Message(delta.delta.clone()))
                                    }
                                    // Argument deltas without an item id are dropped.
                                    ItemChunkKind::FunctionCallArgsDelta(delta) => {
                                        if let Some(item_id) = chunk.item_id.as_ref() {
                                            let internal_call_id = tool_call_internal_ids
                                                .entry(item_id.clone())
                                                .or_insert_with(|| nanoid::nanoid!())
                                                .clone();
                                            yield Ok(RawStreamingChoice::ToolCallDelta {
                                                id: item_id.clone(),
                                                internal_call_id,
                                                content: streaming::ToolCallDeltaContent::Delta(delta.delta.clone())
                                            })
                                        }
                                    }
                                    _ => continue,
                                }
                            }

                            if let responses_api::streaming::StreamingCompletionChunk::Response(chunk) = data {
                                let responses_api::streaming::ResponseChunk { kind, response, .. } = *chunk;
                                match kind {
                                    // Completion: record identifiers and keep the reported usage.
                                    responses_api::streaming::ResponseChunkKind::ResponseCompleted => {
                                        span.record("gen_ai.response.id", response.id.as_str());
                                        span.record("gen_ai.response.model", response.model.as_str());
                                        if let Some(usage) = response.usage {
                                            final_usage = usage;
                                        }
                                    }
                                    // Failed and incomplete responses both end the stream with an error.
                                    responses_api::streaming::ResponseChunkKind::ResponseFailed
                                    | responses_api::streaming::ResponseChunkKind::ResponseIncomplete => {
                                        let error = response
                                            .error
                                            .as_ref()
                                            .map(|err| err.message.clone())
                                            .unwrap_or_else(|| "Copilot response stream failed".into());
                                        terminated_with_error = true;
                                        yield Err(CompletionError::ProviderError(error));
                                        break;
                                    }
                                    _ => continue,
                                }
                            }
                        }
                        // Normal end of the event stream.
                        Err(crate::http_client::Error::StreamEnded) => {
                            break;
                        }
                        Err(error) => {
                            terminated_with_error = true;
                            yield Err(CompletionError::ProviderError(error.to_string()));
                            break;
                        }
                    }
                }

                event_source.close();

                // After an error, skip the buffered tool calls and the final response.
                if terminated_with_error {
                    return;
                }

                for tool_call in &tool_calls {
                    yield Ok(tool_call.to_owned())
                }

                span.record("gen_ai.usage.input_tokens", final_usage.input_tokens);
                span.record("gen_ai.usage.output_tokens", final_usage.output_tokens);
                span.record(
                    "gen_ai.usage.cache_read.input_tokens",
                    final_usage
                        .input_tokens_details
                        .as_ref()
                        .map(|details| details.cached_tokens)
                        .unwrap_or(0),
                );

                yield Ok(RawStreamingChoice::FinalResponse(
                    CopilotStreamingResponse::Responses(
                        responses_api::streaming::StreamingCompletionResponse { usage: final_usage }
                    )
                ));
            },
            span,
        );

        Ok(StreamingCompletionResponse::stream(Box::pin(stream)))
    }
1102}
1103
1104impl<H> completion::CompletionModel for CompletionModel<H>
1105where
1106 Client<H>: HttpClientExt + Clone + Debug + 'static,
1107 H: Clone + Default + Debug + WasmCompatSend + WasmCompatSync + 'static,
1108{
1109 type Response = CopilotCompletionResponse;
1110 type StreamingResponse = CopilotStreamingResponse;
1111 type Client = Client<H>;
1112
1113 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
1114 Self::new(client.clone(), model)
1115 }
1116
1117 async fn completion(
1118 &self,
1119 completion_request: completion::CompletionRequest,
1120 ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
1121 match self.route() {
1122 CompletionRoute::ChatCompletions => self.completion_chat(completion_request).await,
1123 CompletionRoute::Responses => self.completion_responses(completion_request).await,
1124 }
1125 }
1126
1127 async fn stream(
1128 &self,
1129 completion_request: completion::CompletionRequest,
1130 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
1131 match self.route() {
1132 CompletionRoute::ChatCompletions => self.stream_chat(completion_request).await,
1133 CompletionRoute::Responses => self.stream_responses(completion_request).await,
1134 }
1135 }
1136}
1137
/// Handle for requesting embeddings from Copilot with a fixed model.
#[derive(Clone)]
pub struct EmbeddingModel<H = reqwest::Client> {
    /// Client used to obtain the auth context and send the request.
    client: Client<H>,
    /// Model identifier sent as the `model` field of the request body.
    pub model: String,
    /// When set, sent as the `encoding_format` field of the request body.
    pub encoding_format: Option<openai::EncodingFormat>,
    /// When set, sent as the `user` field of the request body.
    pub user: Option<String>,
    /// Embedding dimensionality reported by `ndims()`; `0` means no
    /// `dimensions` field is added to the request.
    ndims: usize,
}
1146
/// Subset of the embeddings response body that this module reads.
#[derive(Deserialize)]
struct CopilotEmbeddingResponse {
    /// One entry per returned embedding.
    data: Vec<CopilotEmbeddingData>,
}
1151
/// A single embedding entry from the response `data` array.
#[derive(Deserialize)]
struct CopilotEmbeddingData {
    /// Vector components kept as raw JSON numbers; they are converted to
    /// `f64` when the public embedding values are built.
    embedding: Vec<serde_json::Number>,
}
1156
1157impl<H> EmbeddingModel<H>
1158where
1159 Client<H>: HttpClientExt + Clone + Debug + 'static,
1160 H: Clone + Default + Debug + 'static,
1161{
1162 pub fn new(client: Client<H>, model: impl Into<String>, ndims: usize) -> Self {
1163 Self {
1164 client,
1165 model: model.into(),
1166 encoding_format: None,
1167 user: None,
1168 ndims,
1169 }
1170 }
1171}
1172
1173impl<H> embeddings::EmbeddingModel for EmbeddingModel<H>
1174where
1175 Client<H>: HttpClientExt + Clone + Debug + WasmCompatSend + WasmCompatSync + 'static,
1176 H: Clone + Default + Debug + WasmCompatSend + WasmCompatSync + 'static,
1177{
1178 const MAX_DOCUMENTS: usize = 1024;
1179 type Client = Client<H>;
1180
1181 fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
1182 let model = model.into();
1183 let dims = ndims.unwrap_or(match model.as_str() {
1184 TEXT_EMBEDDING_3_LARGE => 3072,
1185 TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
1186 _ => 0,
1187 });
1188 Self::new(client.clone(), model, dims)
1189 }
1190
1191 fn ndims(&self) -> usize {
1192 self.ndims
1193 }
1194
1195 async fn embed_texts(
1196 &self,
1197 documents: impl IntoIterator<Item = String>,
1198 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
1199 let documents = documents.into_iter().collect::<Vec<_>>();
1200 let auth = self
1201 .client
1202 .ext()
1203 .auth
1204 .auth_context()
1205 .await
1206 .map_err(|err| EmbeddingError::ProviderError(err.to_string()))?;
1207
1208 let headers = default_headers(&auth.api_key, "user", false);
1209 let mut body = json!({
1210 "model": self.model,
1211 "input": documents,
1212 });
1213
1214 let body_object = body.as_object_mut().ok_or_else(|| {
1215 EmbeddingError::ResponseError("embedding request body must be a JSON object".into())
1216 })?;
1217
1218 if self.ndims > 0 && self.model.as_str() != TEXT_EMBEDDING_ADA_002 {
1219 body_object.insert("dimensions".to_owned(), json!(self.ndims));
1220 }
1221 if let Some(encoding_format) = &self.encoding_format {
1222 body_object.insert("encoding_format".to_owned(), json!(encoding_format));
1223 }
1224 if let Some(user) = &self.user {
1225 body_object.insert("user".to_owned(), json!(user));
1226 }
1227
1228 let req = apply_headers(
1229 post_with_auth_base(&self.client, &auth, "/embeddings", Transport::Http)?,
1230 &headers,
1231 )
1232 .body(serde_json::to_vec(&body)?)
1233 .map_err(|err| EmbeddingError::HttpError(err.into()))?;
1234
1235 let response = self.client.send(req).await?;
1236 if response.status().is_success() {
1237 let body: Vec<u8> = response.into_body().await?;
1238 #[derive(Deserialize)]
1239 struct NestedApiError {
1240 error: NestedApiErrorMessage,
1241 }
1242
1243 #[derive(Deserialize)]
1244 struct NestedApiErrorMessage {
1245 message: String,
1246 }
1247
1248 let body: CopilotEmbeddingResponse = match serde_json::from_slice(&body) {
1249 Ok(parsed) => parsed,
1250 Err(parse_error) => {
1251 if let Ok(err) = serde_json::from_slice::<NestedApiError>(&body) {
1252 return Err(EmbeddingError::ProviderError(err.error.message));
1253 }
1254
1255 let preview = String::from_utf8_lossy(&body);
1256 let preview = if preview.len() > 512 {
1257 format!("{}...", &preview[..512])
1258 } else {
1259 preview.into_owned()
1260 };
1261
1262 return Err(EmbeddingError::ProviderError(format!(
1263 "Failed to parse Copilot embeddings response: {parse_error}; body: {preview}"
1264 )));
1265 }
1266 };
1267
1268 Ok(body
1269 .data
1270 .into_iter()
1271 .zip(documents.into_iter())
1272 .map(|(embedding, document)| embeddings::Embedding {
1273 document,
1274 vec: embedding
1275 .embedding
1276 .into_iter()
1277 .filter_map(|n| n.as_f64())
1278 .collect(),
1279 })
1280 .collect())
1281 } else {
1282 let text = http_client::text(response).await?;
1283 Err(EmbeddingError::ProviderError(text))
1284 }
1285 }
1286}
1287
// Request path used to list models; also passed as context to listing errors.
const MODEL_LISTING_PATH: &str = "/models";
// Provider label passed as context to model-listing errors.
const MODEL_LISTING_PROVIDER: &str = "Copilot";
1290
/// Subset of the model-listing response body that this module reads.
#[derive(Debug, Deserialize)]
struct ListModelsResponse {
    /// The listed models.
    data: Vec<ListModelEntry>,
}
1295
/// One model entry from the listing response.
///
/// Only `id` is required; every other field defaults to `None` when absent.
#[derive(Debug, Deserialize)]
struct ListModelEntry {
    /// Model identifier.
    id: String,
    /// Optional name, copied to `Model::name`.
    #[serde(default)]
    name: Option<String>,
    /// Optional vendor, copied to `Model::owned_by`.
    #[serde(default)]
    vendor: Option<String>,
    /// Optional capability block; only its `type` is used.
    #[serde(default)]
    capabilities: Option<ListModelEntryCapabilities>,
}
1306
/// Capability block of a listed model; only the `type` key is read.
#[derive(Debug, Deserialize)]
struct ListModelEntryCapabilities {
    /// Value of the JSON `type` key, if present.
    #[serde(default, rename = "type")]
    r#type: Option<String>,
}
1312
1313impl From<ListModelEntry> for Model {
1314 fn from(value: ListModelEntry) -> Self {
1315 let mut model = Model::from_id(value.id);
1316 model.name = value.name;
1317 model.owned_by = value.vendor;
1318 if let Some(caps) = value.capabilities {
1319 model.r#type = caps.r#type;
1320 }
1321 model
1322 }
1323}
1324
/// Lists the models available through a Copilot [`Client`].
#[derive(Clone)]
pub struct CopilotModelLister<H = reqwest::Client> {
    /// Client used to obtain the auth context and send the listing request.
    client: Client<H>,
}
1330
1331impl<H> ModelLister<H> for CopilotModelLister<H>
1332where
1333 H: HttpClientExt + Clone + Debug + Default + WasmCompatSend + WasmCompatSync + 'static,
1334{
1335 type Client = Client<H>;
1336
1337 fn new(client: Self::Client) -> Self {
1338 Self { client }
1339 }
1340
1341 async fn list_all(&self) -> Result<ModelList, ModelListingError> {
1342 let auth = self.client.ext().auth.auth_context().await.map_err(|err| {
1343 ModelListingError::AuthError {
1344 message: err.to_string(),
1345 }
1346 })?;
1347
1348 let headers = default_headers(&auth.api_key, "user", false);
1349 let req = apply_headers(
1350 get_with_auth_base(&self.client, &auth, MODEL_LISTING_PATH, Transport::Http)?,
1351 &headers,
1352 )
1353 .body(http_client::NoBody)?;
1354
1355 let response = self.client.send::<_, Vec<u8>>(req).await?;
1356
1357 if !response.status().is_success() {
1358 let status_code = response.status().as_u16();
1359 let body = response.into_body().await?;
1360 return Err(ModelListingError::api_error_with_context(
1361 MODEL_LISTING_PROVIDER,
1362 MODEL_LISTING_PATH,
1363 status_code,
1364 &body,
1365 ));
1366 }
1367
1368 let body = response.into_body().await?;
1369 let api_resp: ListModelsResponse = serde_json::from_slice(&body).map_err(|error| {
1370 ModelListingError::parse_error_with_context(
1371 MODEL_LISTING_PROVIDER,
1372 MODEL_LISTING_PATH,
1373 &error,
1374 &body,
1375 )
1376 })?;
1377 let models = api_resp.data.into_iter().map(Model::from).collect();
1378
1379 Ok(ModelList::new(models))
1380 }
1381}
1382
/// Function part of a streamed tool-call fragment; both fields are optional.
#[derive(Deserialize, Debug)]
struct ChatStreamingFunction {
    /// Function name, when this fragment carries it.
    name: Option<String>,
    /// Arguments text, when this fragment carries it.
    arguments: Option<String>,
}
1388
/// One tool-call fragment inside a streamed chat delta.
#[derive(Deserialize, Debug)]
struct ChatStreamingToolCall {
    /// Index of the tool call as reported in the chunk.
    index: usize,
    /// Tool-call id, when this fragment carries it.
    id: Option<String>,
    /// Function name/arguments fragment.
    function: ChatStreamingFunction,
}
1395
1396impl From<&ChatStreamingToolCall> for CompatibleToolCallChunk {
1397 fn from(value: &ChatStreamingToolCall) -> Self {
1398 Self {
1399 index: value.index,
1400 id: value.id.clone(),
1401 name: value.function.name.clone(),
1402 arguments: value.function.arguments.clone(),
1403 }
1404 }
1405}
1406
/// Incremental payload of a streamed chat choice; every field may be absent.
#[derive(Deserialize, Debug, Default)]
struct ChatStreamingDelta {
    /// Text fragment, if any.
    #[serde(default)]
    content: Option<String>,
    /// Reasoning fragment, if any.
    #[serde(default)]
    reasoning_content: Option<String>,
    /// Tool-call fragments; defaults to empty when the key is absent.
    // NOTE(review): `null_or_vec` presumably maps JSON `null` to an empty
    // vector; confirm against `crate::json_utils`.
    #[serde(default, deserialize_with = "crate::json_utils::null_or_vec")]
    tool_calls: Vec<ChatStreamingToolCall>,
}
1416
/// Finish reason reported on a streamed chat choice.
///
/// The named variants deserialize from their snake_case spelling
/// (`tool_calls`, `stop`, `content_filter`, `length`).
#[derive(Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")]
enum ChatFinishReason {
    ToolCalls,
    Stop,
    ContentFilter,
    Length,
    /// Untagged fallback holding any other string value verbatim.
    #[serde(untagged)]
    Other(String),
}
1427
/// One choice inside a streamed chat chunk.
#[derive(Deserialize, Debug)]
struct ChatStreamingChoice {
    /// Incremental content for this choice.
    delta: ChatStreamingDelta,
    /// Finish reason, when the chunk reports one.
    finish_reason: Option<ChatFinishReason>,
}
1433
/// A single chat-completions SSE payload as deserialized from its JSON data.
#[derive(Deserialize, Debug)]
struct ChatStreamingChunk {
    /// Response id, when present.
    id: Option<String>,
    /// Model name, when present.
    model: Option<String>,
    /// Choices carried by this chunk.
    choices: Vec<ChatStreamingChoice>,
    /// Token usage, when the chunk reports it.
    usage: Option<openai::completion::Usage>,
}
1441
/// Field-less profile type that adapts Copilot chat streaming to the shared
/// OpenAI-compatible streaming helper.
#[derive(Clone, Copy)]
struct CopilotChatCompatibleProfile;
1444
1445impl CompatibleStreamProfile for CopilotChatCompatibleProfile {
1446 type Usage = openai::completion::Usage;
1447 type Detail = ();
1448 type FinalResponse = CopilotStreamingResponse;
1449
1450 fn normalize_chunk(
1451 &self,
1452 data: &str,
1453 ) -> Result<Option<CompatibleChunk<Self::Usage, Self::Detail>>, CompletionError> {
1454 let data = match serde_json::from_str::<ChatStreamingChunk>(data) {
1455 Ok(data) => data,
1456 Err(error) => {
1457 tracing::debug!(?error, "Couldn't parse Copilot chat SSE payload");
1458 return Ok(None);
1459 }
1460 };
1461
1462 Ok(Some(
1463 openai_chat_completions_compatible::normalize_first_choice_chunk(
1464 data.id,
1465 data.model,
1466 data.usage,
1467 &data.choices,
1468 |choice| CompatibleChoiceData {
1469 finish_reason: if choice.finish_reason == Some(ChatFinishReason::ToolCalls) {
1470 CompatibleFinishReason::ToolCalls
1471 } else {
1472 CompatibleFinishReason::Other
1473 },
1474 text: choice.delta.content.clone(),
1475 reasoning: choice.delta.reasoning_content.clone(),
1476 tool_calls: openai_chat_completions_compatible::tool_call_chunks(
1477 &choice.delta.tool_calls,
1478 ),
1479 details: Vec::new(),
1480 },
1481 ),
1482 ))
1483 }
1484
1485 fn build_final_response(&self, usage: Self::Usage) -> Self::FinalResponse {
1486 CopilotStreamingResponse::Chat(openai::completion::streaming::StreamingCompletionResponse {
1487 usage,
1488 })
1489 }
1490
1491 fn uses_distinct_tool_call_eviction(&self) -> bool {
1492 true
1493 }
1494}
1495
1496async fn send_copilot_chat_streaming_request<T>(
1497 http_client: T,
1498 req: Request<Vec<u8>>,
1499) -> Result<StreamingCompletionResponse<CopilotStreamingResponse>, CompletionError>
1500where
1501 T: HttpClientExt + Clone + 'static,
1502{
1503 openai_chat_completions_compatible::send_compatible_streaming_request(
1504 http_client,
1505 req,
1506 CopilotChatCompatibleProfile,
1507 )
1508 .await
1509}
1510
1511fn default_token_dir() -> Option<PathBuf> {
1512 config_dir().map(|dir| dir.join("github_copilot"))
1513}
1514
/// Resolves the per-user configuration directory from the environment.
///
/// * Windows: `%APPDATA%`.
/// * Elsewhere: `$XDG_CONFIG_HOME`, falling back to `$HOME/.config`.
///
/// A variable that is set but empty is treated as unset. The XDG Base
/// Directory specification requires this for `XDG_CONFIG_HOME`, and an empty
/// `HOME` would otherwise yield the relative path `.config`.
///
/// Returns `None` when no usable variable is found.
fn config_dir() -> Option<PathBuf> {
    /// Reads `name` as a path, ignoring unset and empty values.
    fn non_empty_var(name: &str) -> Option<PathBuf> {
        std::env::var_os(name)
            .filter(|value| !value.is_empty())
            .map(PathBuf::from)
    }

    #[cfg(target_os = "windows")]
    {
        non_empty_var("APPDATA")
    }

    #[cfg(not(target_os = "windows"))]
    {
        non_empty_var("XDG_CONFIG_HOME")
            .or_else(|| non_empty_var("HOME").map(|home| home.join(".config")))
    }
}
1528
// Unit tests for response deserialization, route selection, request
// dispatch, stream termination and environment-variable resolution.
#[cfg(test)]
mod tests {
    use super::{
        ChatApiErrorResponse, ChatCompletionResponse, Client, CompletionRoute,
        TEXT_EMBEDDING_3_SMALL, env_api_key, env_base_url, env_github_access_token,
        route_for_model,
    };
    use crate::client::CompletionClient;
    use crate::completion::CompletionModel;
    use crate::http_client;
    use crate::providers::internal::openai_chat_completions_compatible::test_support::{
        sse_bytes_from_data_lines, sse_bytes_from_json_events,
    };
    use crate::streaming::StreamedAssistantContent;
    use crate::test_utils::MockStreamingClient;
    use crate::test_utils::{RecordingHttpClient, SequencedStreamingHttpClient};
    use futures::StreamExt;
    use std::collections::HashMap;

    /// Builds an owned `String -> String` map from borrowed key/value pairs,
    /// used as a fake environment.
    fn env_map(entries: &[(&str, &str)]) -> HashMap<String, String> {
        entries
            .iter()
            .map(|(key, value)| ((*key).to_string(), (*value).to_string()))
            .collect()
    }

    /// Chat-completion fixture without `object`, `created` or
    /// `completion_tokens`.
    fn minimal_chat_response() -> &'static str {
        r#"{
            "id": "chatcmpl-123",
            "model": "gpt-4o",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "hello"
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 4,
                "total_tokens": 7
            }
        }"#
    }

    /// Completed responses-API fixture with a single assistant text message.
    fn minimal_responses_response() -> &'static str {
        r#"{
            "id": "resp_123",
            "object": "response",
            "created_at": 1700000000,
            "status": "completed",
            "error": null,
            "incomplete_details": null,
            "instructions": null,
            "max_output_tokens": null,
            "model": "gpt-5.3-codex",
            "usage": {
                "input_tokens": 4,
                "input_tokens_details": {
                    "cached_tokens": 0
                },
                "output_tokens": 3,
                "output_tokens_details": {
                    "reasoning_tokens": 0
                },
                "total_tokens": 7
            },
            "output": [{
                "type": "message",
                "id": "msg_123",
                "role": "assistant",
                "status": "completed",
                "content": [{
                    "type": "output_text",
                    "text": "hello"
                }]
            }],
            "tools": []
        }"#
    }

    /// Embeddings fixture whose entries carry only the `embedding` array.
    fn minimal_embeddings_response() -> &'static str {
        r#"{
            "data": [
                {
                    "embedding": [0.1, 0.2, 0.3]
                },
                {
                    "embedding": [0.4, 0.5, 0.6]
                }
            ]
        }"#
    }

    /// A fully populated OpenAI-style chat response deserializes with all
    /// optional fields present.
    #[test]
    fn deserialize_standard_openai_response() {
        let json = r#"{
            "id": "chatcmpl-abc123",
            "object": "chat.completion",
            "created": 1700000000,
            "model": "gpt-4o",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Hello!"
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        }"#;

        let response: ChatCompletionResponse =
            serde_json::from_str(json).expect("standard OpenAI response should deserialize");
        assert_eq!(response.id, "chatcmpl-abc123");
        assert_eq!(response.object.as_deref(), Some("chat.completion"));
        assert_eq!(response.created, Some(1700000000));
        assert_eq!(response.model, "gpt-4o");
        assert_eq!(response.choices.len(), 1);
        assert_eq!(response.choices[0].finish_reason.as_deref(), Some("stop"));
    }

    /// Missing `object` and `created` deserialize to `None` instead of failing.
    #[test]
    fn deserialize_copilot_response_without_object_and_created() {
        let response: ChatCompletionResponse = serde_json::from_str(minimal_chat_response())
            .expect("Copilot response should deserialize");

        assert_eq!(response.id, "chatcmpl-123");
        assert_eq!(response.object, None);
        assert_eq!(response.created, None);
        assert_eq!(response.model, "gpt-4o");
        assert_eq!(response.choices.len(), 1);
    }

    /// A choice without `finish_reason` or `index` deserializes with
    /// `finish_reason == None` and `index == 0`.
    #[test]
    fn deserialize_copilot_response_without_finish_reason() {
        let json = r#"{
            "id": "chatcmpl-claude-001",
            "model": "claude-3.5-sonnet",
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": "Here is my analysis."
                }
            }],
            "usage": {
                "prompt_tokens": 50,
                "total_tokens": 80
            }
        }"#;

        let response: ChatCompletionResponse =
            serde_json::from_str(json).expect("Claude-via-Copilot response should deserialize");

        assert_eq!(response.model, "claude-3.5-sonnet");
        assert_eq!(response.choices[0].finish_reason, None);
        assert_eq!(response.choices[0].index, 0);
    }

    /// An error body shaped as `{"message": ..}` exposes that message.
    #[test]
    fn error_response_with_message_field() {
        let json = r#"{"message": "rate limit exceeded"}"#;
        let err: ChatApiErrorResponse = serde_json::from_str(json).expect("message-shaped error");

        assert_eq!(err.error_message(), "rate limit exceeded");
    }

    /// An error body shaped as `{"error": ..}` exposes that message.
    #[test]
    fn error_response_with_error_field() {
        let json = r#"{"error": "model not found"}"#;
        let err: ChatApiErrorResponse = serde_json::from_str(json).expect("error-shaped error");

        assert_eq!(err.error_message(), "model not found");
    }

    /// Codex model names (including a mixed-case one) route to the responses
    /// API; the other sampled names route to chat completions.
    #[test]
    fn routes_codex_models_to_responses() {
        assert_eq!(route_for_model("gpt-5.3-codex"), CompletionRoute::Responses);
        assert_eq!(
            route_for_model("gpt-5.1-CODEX-mini"),
            CompletionRoute::Responses
        );
        assert_eq!(route_for_model("gpt-5.2"), CompletionRoute::ChatCompletions);
        assert_eq!(
            route_for_model("claude-sonnet-4.5"),
            CompletionRoute::ChatCompletions
        );
    }

    /// A `gpt-4o` completion issues exactly one request to `/chat/completions`.
    #[tokio::test]
    async fn completion_model_routes_chat_requests_to_chat_completions() {
        let http_client = RecordingHttpClient::new(minimal_chat_response());
        let client = Client::builder()
            .api_key("copilot-token")
            .http_client(http_client.clone())
            .build()
            .expect("build client");
        let model = client.completion_model("gpt-4o");
        let request = model.completion_request("hello").build();

        let _response = model.completion(request).await.expect("chat completion");

        let requests = http_client.requests();
        assert_eq!(requests.len(), 1);
        assert!(requests[0].uri.ends_with("/chat/completions"));
        assert!(String::from_utf8_lossy(&requests[0].body).contains("\"model\":\"gpt-4o\""));
    }

    /// A codex completion issues exactly one request to `/responses`.
    #[tokio::test]
    async fn completion_model_routes_codex_requests_to_responses() {
        let http_client = RecordingHttpClient::new(minimal_responses_response());
        let client = Client::builder()
            .api_key("copilot-token")
            .http_client(http_client.clone())
            .build()
            .expect("build client");
        let model = client.completion_model("gpt-5.3-codex");
        let request = model.completion_request("hello").build();

        let _response = model
            .completion(request)
            .await
            .expect("responses completion");

        let requests = http_client.requests();
        assert_eq!(requests.len(), 1);
        assert!(requests[0].uri.ends_with("/responses"));
        assert!(String::from_utf8_lossy(&requests[0].body).contains("\"model\":\"gpt-5.3-codex\""));
    }

    /// Embedding entries that carry only `embedding` still parse, and the
    /// request goes to `/embeddings` with the expected model name.
    #[tokio::test]
    async fn embeddings_accept_minimal_copilot_response_shape() {
        use crate::client::EmbeddingsClient;
        use crate::embeddings::EmbeddingModel as _;

        let http_client = RecordingHttpClient::new(minimal_embeddings_response());
        let client = Client::builder()
            .api_key("copilot-token")
            .http_client(http_client.clone())
            .build()
            .expect("build client");
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);

        let embeddings = model
            .embed_texts(["one".to_string(), "two".to_string()])
            .await
            .expect("embeddings should deserialize");

        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].vec, vec![0.1, 0.2, 0.3]);
        assert_eq!(embeddings[1].vec, vec![0.4, 0.5, 0.6]);

        let requests = http_client.requests();
        assert_eq!(requests.len(), 1);
        assert!(requests[0].uri.ends_with("/embeddings"));
        assert!(
            String::from_utf8_lossy(&requests[0].body)
                .contains("\"model\":\"text-embedding-3-small\"")
        );
    }

    /// After a completed tool call and a `response.failed` event, the first
    /// item the stream yields is the provider error, and nothing follows it.
    #[tokio::test]
    async fn responses_stream_terminates_after_terminal_error() {
        let tool_call_done = serde_json::json!({
            "type": "response.output_item.done",
            "sequence_number": 1,
            "item": {
                "type": "function_call",
                "id": "fc_123",
                "arguments": "{}",
                "call_id": "call_123",
                "name": "example_tool",
                "status": "completed"
            }
        });
        let failed = serde_json::json!({
            "type": "response.failed",
            "sequence_number": 2,
            "response": {
                "id": "resp_123",
                "object": "response",
                "created_at": 1700000000,
                "status": "failed",
                "error": {
                    "code": "server_error",
                    "message": "Copilot response stream failed"
                },
                "incomplete_details": null,
                "instructions": null,
                "max_output_tokens": null,
                "model": "gpt-5.3-codex",
                "usage": null,
                "output": [],
                "tools": []
            }
        });
        let http_client = MockStreamingClient {
            sse_bytes: sse_bytes_from_json_events(&[tool_call_done, failed]),
        };
        let client = Client::builder()
            .api_key("copilot-token")
            .http_client(http_client)
            .build()
            .expect("build client");
        let model = client.completion_model("gpt-5.3-codex");
        let request = model.completion_request("hello").build();
        let mut stream = model.stream(request).await.expect("stream should start");

        let err = match stream.next().await.expect("stream should yield an item") {
            Ok(_) => panic!("stream should surface a provider error"),
            Err(err) => err,
        };
        assert_eq!(
            err.to_string(),
            "ProviderError: Copilot response stream failed"
        );
        assert!(
            stream.next().await.is_none(),
            "responses stream should terminate immediately after a terminal error"
        );
    }

    /// A transport error after a tool-call delta is surfaced as a provider
    /// error, only tool-call deltas may precede it, and the stream then ends.
    #[tokio::test]
    async fn chat_stream_terminates_after_transport_error() {
        let chunks = vec![
            Ok(sse_bytes_from_data_lines([
                "{\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_123\",\"function\":{\"name\":\"ping\",\"arguments\":\"\"}}]},\"finish_reason\":null}],\"usage\":null}",
            ])),
            Err(http_client::Error::InvalidStatusCode(
                http::StatusCode::BAD_GATEWAY,
            )),
        ];

        let http_client = SequencedStreamingHttpClient::new(chunks);
        let client = Client::builder()
            .api_key("copilot-token")
            .http_client(http_client)
            .build()
            .expect("build client");
        let model = client.completion_model("gpt-4o");
        let request = model.completion_request("hello").build();
        let mut stream = model.stream(request).await.expect("stream should start");

        let mut saw_error = false;
        while let Some(item) = stream.next().await {
            match item {
                Ok(StreamedAssistantContent::ToolCallDelta { .. }) => {}
                Err(err) => {
                    assert_eq!(
                        err.to_string(),
                        "ProviderError: Invalid status code: 502 Bad Gateway"
                    );
                    saw_error = true;
                    break;
                }
                Ok(_) => panic!("unexpected non-error stream item before transport failure"),
            }
        }

        assert!(saw_error, "stream should surface the transport error");
        assert!(
            stream.next().await.is_none(),
            "chat stream should terminate immediately after a transport error"
        );
    }

    /// `GITHUB_COPILOT_API_KEY` wins over `COPILOT_API_KEY` and `GITHUB_TOKEN`.
    #[test]
    fn env_api_key_prefers_github_prefixed_vars() {
        let env = env_map(&[
            ("COPILOT_API_KEY", "copilot-key"),
            ("GITHUB_COPILOT_API_KEY", "github-key"),
            ("GITHUB_TOKEN", "bootstrap-token"),
        ]);
        let get = |name: &str| env.get(name).cloned();

        assert_eq!(env_api_key(&get).as_deref(), Some("github-key"));
    }

    /// `COPILOT_GITHUB_ACCESS_TOKEN` wins over `GITHUB_TOKEN`.
    #[test]
    fn env_github_access_token_prefers_explicit_bootstrap_var() {
        let env = env_map(&[
            ("COPILOT_GITHUB_ACCESS_TOKEN", "explicit-bootstrap"),
            ("GITHUB_TOKEN", "fallback-bootstrap"),
        ]);
        let get = |name: &str| env.get(name).cloned();

        assert_eq!(
            env_github_access_token(&get).as_deref(),
            Some("explicit-bootstrap")
        );
    }

    /// `GITHUB_COPILOT_API_BASE` wins over `COPILOT_BASE_URL`.
    #[test]
    fn env_base_url_prefers_github_prefixed_vars() {
        let env = env_map(&[
            ("COPILOT_BASE_URL", "https://copilot.example"),
            ("GITHUB_COPILOT_API_BASE", "https://github.example"),
        ]);
        let get = |name: &str| env.get(name).cloned();

        assert_eq!(
            env_base_url(&get).as_deref(),
            Some("https://github.example")
        );
    }

    /// With only a base URL set, no API key or GitHub access token is found,
    /// while the base URL is still picked up.
    #[test]
    fn env_without_api_key_falls_back_to_oauth() {
        let env = env_map(&[("COPILOT_BASE_URL", "https://copilot.example")]);
        let get = |name: &str| env.get(name).cloned();

        assert!(env_api_key(&get).is_none());
        assert!(env_github_access_token(&get).is_none());
        assert_eq!(
            env_base_url(&get).as_deref(),
            Some("https://copilot.example")
        );
    }

    /// `GITHUB_TOKEN` is used as the GitHub access token only, never as the
    /// Copilot API key.
    #[test]
    fn env_github_token_is_not_treated_as_copilot_api_key() {
        let env = env_map(&[("GITHUB_TOKEN", "bootstrap-token")]);
        let get = |name: &str| env.get(name).cloned();

        assert!(env_api_key(&get).is_none());
        assert_eq!(
            env_github_access_token(&get).as_deref(),
            Some("bootstrap-token")
        );
    }
}