1mod auth;
24
25use crate::client::{
26 self, ApiKey, Capabilities, Capable, DebugExt, ModelLister, Nothing, Provider, ProviderBuilder,
27 ProviderClient, Transport,
28};
29use crate::completion::{self, CompletionError, GetTokenUsage};
30use crate::embeddings::{self, EmbeddingError};
31use crate::http_client::{self, HttpClientExt};
32use crate::model::{Model, ModelList, ModelListingError};
33use crate::providers::internal::openai_chat_completions_compatible::{
34 self, CompatibleChoiceData, CompatibleChunk, CompatibleFinishReason, CompatibleStreamProfile,
35 CompatibleToolCallChunk,
36};
37use crate::providers::openai;
38use crate::providers::openai::responses_api::{self, CompletionRequest as ResponsesRequest};
39use crate::streaming::{self, RawStreamingChoice, StreamingCompletionResponse};
40use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
41use async_stream::stream;
42use futures::StreamExt;
43use http::Request;
44use serde::{Deserialize, Serialize};
45use serde_json::json;
46use std::borrow::Cow;
47use std::collections::HashMap;
48use std::fmt::Debug;
49use std::path::{Path, PathBuf};
50use tracing::info_span;
51use tracing_futures::Instrument as _;
52
/// Default Copilot API endpoint. At request time a per-account endpoint from
/// the auth handshake may be substituted — see `runtime_base_url`.
const GITHUB_COPILOT_API_BASE_URL: &str = "https://api.githubcopilot.com";
/// Editor-identification values sent with every request (see `default_headers`);
/// Copilot expects requests to look like they come from a VS Code chat session.
const EDITOR_PLUGIN_VERSION: &str = "copilot-chat/0.26.7";
const USER_AGENT: &str = "GitHubCopilotChat/0.26.7";
/// Value of the `x-github-api-version` header.
const API_VERSION: &str = "2025-04-01";

// Completion model identifiers accepted by the Copilot API. Models whose name
// contains "codex" are routed to the Responses API (see `route_for_model`).
pub const GPT_4: &str = "gpt-4";
pub const GPT_4O: &str = "gpt-4o";
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
pub const GPT_4_1: &str = "gpt-4.1";
pub const GPT_4_1_MINI: &str = "gpt-4.1-mini";
pub const GPT_4_1_NANO: &str = "gpt-4.1-nano";
pub const GPT_5_3_CODEX: &str = "gpt-5.3-codex";
pub const GPT_5_1_CODEX: &str = "gpt-5.1-codex";
pub const GPT_5_5: &str = "gpt-5.5";
pub const GPT_5_4: &str = "gpt-5.4";
pub const CLAUDE_SONNET_4: &str = "claude-sonnet-4";
pub const CLAUDE_SONNET_4_6: &str = "claude-sonnet-4.6";
pub const CLAUDE_OPUS_4_6: &str = "claude-opus-4.6";
pub const CLAUDE_OPUS_4_7: &str = "claude-opus-4.7";
pub const CLAUDE_3_5_SONNET: &str = "claude-3.5-sonnet";
pub const GEMINI_3_FLASH: &str = "gemini-3-flash-preview";
pub const GEMINI_3_1_PRO_FLASH: &str = "gemini-3.1-pro-preview";
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash-001";
pub const O3_MINI: &str = "o3-mini";

// Embedding model identifiers.
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";

// Re-exported so callers can configure embedding encoding without importing
// the openai provider module directly.
pub use openai::EncodingFormat;
104
/// How the client authenticates against the Copilot API.
#[derive(Clone)]
pub enum CopilotAuth {
    /// A ready-to-use Copilot API key (becomes `auth::AuthSource::ApiKey`).
    ApiKey(String),
    /// A GitHub access token (becomes `auth::AuthSource::GitHubAccessToken`).
    GitHubAccessToken(String),
    /// No pre-existing credential; the `auth` module's OAuth device-code flow
    /// is used instead.
    OAuth,
}
111
// Marker impl so `CopilotAuth` can stand in wherever the client stack expects
// an API-key type.
impl ApiKey for CopilotAuth {}
113
/// Any string-like value converts to the plain API-key variant, so
/// `ClientBuilder::api_key("…")` keeps working with raw strings.
impl<S> From<S> for CopilotAuth
where
    S: Into<String>,
{
    fn from(value: S) -> Self {
        Self::ApiKey(value.into())
    }
}
122
123impl Debug for CopilotAuth {
124 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125 match self {
126 Self::ApiKey(_) => f.write_str("ApiKey(<redacted>)"),
127 Self::GitHubAccessToken(_) => f.write_str("GitHubAccessToken(<redacted>)"),
128 Self::OAuth => f.write_str("OAuth"),
129 }
130 }
131}
132
/// Provider-specific configuration collected while building a Copilot client.
#[derive(Debug, Clone)]
pub struct CopilotBuilder {
    // Path handed to `auth::Authenticator` for the GitHub access token;
    // `None` when no token directory could be determined.
    access_token_file: Option<PathBuf>,
    // Path handed to `auth::Authenticator` for the Copilot API key.
    api_key_file: Option<PathBuf>,
    // Callback invoked with the device-code prompt during the OAuth flow.
    device_code_handler: auth::DeviceCodeHandler,
}
139
/// Runtime extension state attached to a built Copilot client.
#[derive(Clone)]
pub struct CopilotExt {
    // Produces the `auth::AuthContext` (API key + optional API base) used for
    // each request.
    auth: auth::Authenticator,
}
144
145impl Debug for CopilotExt {
146 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147 f.debug_struct("CopilotExt")
148 .field("auth", &self.auth)
149 .finish()
150 }
151}
152
/// Copilot client over an HTTP transport (`reqwest` by default).
pub type Client<H = reqwest::Client> = client::Client<CopilotExt, H>;
/// Builder for [`Client`] with the Copilot extension and auth types plugged in.
pub type ClientBuilder<H = crate::markers::Missing> =
    client::ClientBuilder<CopilotBuilder, CopilotAuth, H>;
156
157impl Default for CopilotBuilder {
158 fn default() -> Self {
159 let token_dir = default_token_dir();
160 Self {
161 access_token_file: token_dir.as_ref().map(|dir| dir.join("access-token")),
162 api_key_file: token_dir.map(|dir| dir.join("api-key.json")),
163 device_code_handler: auth::DeviceCodeHandler::default(),
164 }
165 }
166}
167
impl Provider for CopilotExt {
    type Builder = CopilotBuilder;

    // NOTE(review): empty path — presumably this disables endpoint-based key
    // verification for this provider; confirm against the `Provider` trait docs.
    const VERIFY_PATH: &'static str = "";
}
173
/// Capability map: Copilot exposes chat completions, embeddings and model
/// listing; transcription, image and audio generation are not supported.
impl<H> Capabilities<H> for CopilotExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type Transcription = Nothing;
    type ModelListing = Capable<CopilotModelLister<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
184
// Marker impl: accept the default `DebugExt` behavior.
impl DebugExt for CopilotExt {}
186
impl ProviderBuilder for CopilotBuilder {
    type Extension<H>
        = CopilotExt
    where
        H: HttpClientExt;
    type ApiKey = CopilotAuth;

    const BASE_URL: &'static str = GITHUB_COPILOT_API_BASE_URL;

    /// Finalizes the provider extension: translates the configured
    /// [`CopilotAuth`] into an `auth::AuthSource` and wires it, together with
    /// the credential-file paths and device-code handler, into an
    /// `auth::Authenticator`.
    fn build<H>(
        builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        let auth = match builder.get_api_key() {
            CopilotAuth::ApiKey(api_key) => auth::AuthSource::ApiKey(api_key.clone()),
            CopilotAuth::GitHubAccessToken(access_token) => {
                auth::AuthSource::GitHubAccessToken(access_token.clone())
            }
            CopilotAuth::OAuth => auth::AuthSource::OAuth,
        };

        let ext = builder.ext();
        Ok(CopilotExt {
            auth: auth::Authenticator::new(
                auth,
                ext.access_token_file.clone(),
                ext.api_key_file.clone(),
                ext.device_code_handler.clone(),
            ),
        })
    }
}
221
222impl ProviderClient for Client {
223 type Input = CopilotAuth;
224 type Error = crate::client::ProviderClientError;
225
226 fn from_env() -> Result<Self, Self::Error> {
227 let mut builder = Self::builder();
228 fn get(name: &str) -> Option<String> {
229 std::env::var(name).ok()
230 }
231
232 if let Some(base_url) = env_base_url(&get) {
233 builder = builder.base_url(base_url);
234 }
235
236 if let Some(api_key) = env_api_key(&get) {
237 builder.api_key(api_key).build().map_err(Into::into)
238 } else if let Some(access_token) = env_github_access_token(&get) {
239 builder
240 .github_access_token(access_token)
241 .build()
242 .map_err(Into::into)
243 } else {
244 builder.oauth().build().map_err(Into::into)
245 }
246 }
247
248 fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
249 Self::builder().api_key(input).build().map_err(Into::into)
250 }
251}
252
impl<H> client::ClientBuilder<CopilotBuilder, crate::markers::Missing, H> {
    /// Authenticates with a GitHub access token instead of a Copilot API key.
    pub fn github_access_token(
        self,
        access_token: impl Into<String>,
    ) -> client::ClientBuilder<CopilotBuilder, CopilotAuth, H> {
        self.api_key(CopilotAuth::GitHubAccessToken(access_token.into()))
    }

    /// Authenticates via the interactive OAuth device-code flow.
    pub fn oauth(self) -> client::ClientBuilder<CopilotBuilder, CopilotAuth, H> {
        self.api_key(CopilotAuth::OAuth)
    }
}
265
impl<H> ClientBuilder<H> {
    /// Registers a callback invoked with the device-code prompt during the
    /// OAuth flow (e.g. to display the code to the user).
    pub fn on_device_code<F>(self, handler: F) -> Self
    where
        F: Fn(auth::DeviceCodePrompt) + Send + Sync + 'static,
    {
        self.over_ext(|mut ext| {
            ext.device_code_handler = auth::DeviceCodeHandler::new(handler);
            ext
        })
    }

    /// Sets the directory for both credential files, using the same file
    /// names as the defaults ("access-token" and "api-key.json").
    pub fn token_dir(self, path: impl AsRef<Path>) -> Self {
        let path = path.as_ref();
        self.over_ext(|mut ext| {
            ext.access_token_file = Some(path.join("access-token"));
            ext.api_key_file = Some(path.join("api-key.json"));
            ext
        })
    }

    /// Overrides the file used for the GitHub access token.
    pub fn access_token_file(self, path: impl AsRef<Path>) -> Self {
        let path = path.as_ref().to_path_buf();
        self.over_ext(|mut ext| {
            ext.access_token_file = Some(path);
            ext
        })
    }

    /// Overrides the file used for the Copilot API key.
    pub fn api_key_file(self, path: impl AsRef<Path>) -> Self {
        let path = path.as_ref().to_path_buf();
        self.over_ext(|mut ext| {
            ext.api_key_file = Some(path);
            ext
        })
    }
}
302
/// Looks up `name` via `get`, treating unset and whitespace-only values the
/// same way: both count as absent.
fn env_value<F>(get: &F, name: &str) -> Option<String>
where
    F: Fn(&str) -> Option<String>,
{
    match get(name) {
        Some(value) if !value.trim().is_empty() => Some(value),
        _ => None,
    }
}
309
310fn first_env_value<F>(get: &F, keys: &[&str]) -> Option<String>
311where
312 F: Fn(&str) -> Option<String>,
313{
314 keys.iter().find_map(|key| env_value(get, key))
315}
316
/// Copilot API key from `GITHUB_COPILOT_API_KEY`, falling back to
/// `COPILOT_API_KEY`.
fn env_api_key<F>(get: &F) -> Option<String>
where
    F: Fn(&str) -> Option<String>,
{
    first_env_value(get, &["GITHUB_COPILOT_API_KEY", "COPILOT_API_KEY"])
}
323
/// GitHub access token from `COPILOT_GITHUB_ACCESS_TOKEN`, falling back to
/// the conventional `GITHUB_TOKEN`.
fn env_github_access_token<F>(get: &F) -> Option<String>
where
    F: Fn(&str) -> Option<String>,
{
    first_env_value(get, &["COPILOT_GITHUB_ACCESS_TOKEN", "GITHUB_TOKEN"])
}
330
/// Base-URL override from `GITHUB_COPILOT_API_BASE`, falling back to
/// `COPILOT_BASE_URL`.
fn env_base_url<F>(get: &F) -> Option<String>
where
    F: Fn(&str) -> Option<String>,
{
    first_env_value(get, &["GITHUB_COPILOT_API_BASE", "COPILOT_BASE_URL"])
}
337
impl<H> Client<H>
where
    H: HttpClientExt + Clone + Debug + Default + WasmCompatSend + WasmCompatSync + 'static,
{
    /// Eagerly runs authentication (including the OAuth device flow when
    /// configured) so later requests don't have to; the resolved auth context
    /// is discarded here.
    pub async fn authorize(&self) -> Result<(), auth::AuthError> {
        self.ext().auth.auth_context().await.map(|_| ())
    }
}
346
347fn default_headers(
348 api_key: &str,
349 initiator: &'static str,
350 has_vision: bool,
351) -> Vec<(&'static str, String)> {
352 let mut headers = vec![
353 (
354 http::header::AUTHORIZATION.as_str(),
355 format!("Bearer {api_key}"),
356 ),
357 ("copilot-integration-id", "vscode-chat".to_string()),
358 ("editor-version", "vscode/1.95.0".to_string()),
359 ("editor-plugin-version", EDITOR_PLUGIN_VERSION.to_string()),
360 ("user-agent", USER_AGENT.to_string()),
361 ("openai-intent", "conversation-panel".to_string()),
362 ("x-github-api-version", API_VERSION.to_string()),
363 ("x-request-id", nanoid::nanoid!()),
364 (
365 "x-vscode-user-agent-library-version",
366 "electron-fetch".to_string(),
367 ),
368 ("X-Initiator", initiator.to_string()),
369 ];
370
371 if has_vision {
372 headers.push(("copilot-vision-request", "true".to_string()));
373 }
374
375 headers
376}
377
378fn apply_headers(
379 builder: http_client::Builder,
380 headers: &[(&'static str, String)],
381) -> http_client::Builder {
382 headers
383 .iter()
384 .fold(builder, |builder, (key, value)| builder.header(*key, value))
385}
386
387fn runtime_base_url<'a, H>(client: &'a Client<H>, auth: &'a auth::AuthContext) -> Cow<'a, str> {
388 if client.base_url() == GITHUB_COPILOT_API_BASE_URL {
389 auth.api_base
390 .as_deref()
391 .map(Cow::Borrowed)
392 .unwrap_or_else(|| Cow::Borrowed(client.base_url()))
393 } else {
394 Cow::Borrowed(client.base_url())
395 }
396}
397
/// Starts a POST request against `path`, resolved against the runtime base
/// URL (see [`runtime_base_url`]), copying the client's default headers and
/// then applying the extension's `with_custom` hook.
fn post_with_auth_base<H>(
    client: &Client<H>,
    auth: &auth::AuthContext,
    path: &str,
    transport: Transport,
) -> http_client::Result<http_client::Builder> {
    let uri = client
        .ext()
        .build_uri(runtime_base_url(client, auth).as_ref(), path, transport);
    let mut req = Request::post(uri);

    // Carry over the client-level default headers onto this request.
    if let Some(headers) = req.headers_mut() {
        headers.extend(client.headers().iter().map(|(k, v)| (k.clone(), v.clone())));
    }

    client.ext().with_custom(req)
}
415
/// GET counterpart of [`post_with_auth_base`]: same base-URL resolution,
/// default-header copy and `with_custom` hook, but an HTTP GET.
fn get_with_auth_base<H>(
    client: &Client<H>,
    auth: &auth::AuthContext,
    path: &str,
    transport: Transport,
) -> http_client::Result<http_client::Builder> {
    let uri = client
        .ext()
        .build_uri(runtime_base_url(client, auth).as_ref(), path, transport);
    let mut req = Request::get(uri);

    // Carry over the client-level default headers onto this request.
    if let Some(headers) = req.headers_mut() {
        headers.extend(client.headers().iter().map(|(k, v)| (k.clone(), v.clone())));
    }

    client.ext().with_custom(req)
}
433
434fn request_initiator(request: &completion::CompletionRequest) -> &'static str {
435 for message in request.chat_history.iter() {
436 match message {
437 crate::completion::Message::Assistant { .. } => return "agent",
438 crate::completion::Message::User { content } => {
439 if content
440 .iter()
441 .any(|item| matches!(item, crate::message::UserContent::ToolResult(_)))
442 {
443 return "agent";
444 }
445 }
446 crate::completion::Message::System { .. } => {}
447 }
448 }
449
450 "user"
451}
452
453fn request_has_vision(request: &completion::CompletionRequest) -> bool {
454 request.chat_history.iter().any(|message| match message {
455 crate::completion::Message::User { content } => content
456 .iter()
457 .any(|item| matches!(item, crate::message::UserContent::Image(_))),
458 _ => false,
459 })
460}
461
/// Which Copilot endpoint a completion request is sent to.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum CompletionRoute {
    /// `/chat/completions` (OpenAI chat-completions shape).
    ChatCompletions,
    /// `/responses` (OpenAI Responses API shape).
    Responses,
}
467
468fn route_for_model(model: &str) -> CompletionRoute {
469 if model.to_ascii_lowercase().contains("codex") {
470 CompletionRoute::Responses
471 } else {
472 CompletionRoute::ChatCompletions
473 }
474}
475
/// Raw response preserved alongside the parsed completion, tagged by which
/// Copilot API produced it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "api", rename_all = "snake_case")]
pub enum CopilotCompletionResponse {
    /// Payload from `/chat/completions`.
    Chat(ChatCompletionResponse),
    /// Payload from `/responses` (boxed: the responses payload is large).
    Responses(Box<responses_api::CompletionResponse>),
}
482
/// Final streaming payload, tagged by which Copilot API produced the stream.
#[derive(Clone, Serialize, Deserialize)]
#[serde(tag = "api", rename_all = "snake_case")]
pub enum CopilotStreamingResponse {
    /// Stream summary from `/chat/completions`.
    Chat(openai::completion::streaming::StreamingCompletionResponse),
    /// Stream summary from `/responses`.
    Responses(responses_api::streaming::StreamingCompletionResponse),
}
489
impl GetTokenUsage for CopilotStreamingResponse {
    /// Delegates token accounting to whichever API variant produced the stream.
    fn token_usage(&self) -> Option<completion::Usage> {
        match self {
            Self::Chat(response) => response.token_usage(),
            Self::Responses(response) => response.token_usage(),
        }
    }
}
498
/// Raw `/chat/completions` response body (OpenAI-compatible shape).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    #[serde(default)]
    pub object: Option<String>,
    #[serde(default)]
    pub created: Option<u64>,
    pub model: String,
    pub system_fingerprint: Option<String>,
    pub choices: Vec<ChatChoice>,
    /// Token accounting; absent on some responses.
    pub usage: Option<openai::completion::Usage>,
}
511
/// One candidate completion inside [`ChatCompletionResponse`]; only the first
/// choice is consumed when converting to the provider-agnostic response.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ChatChoice {
    #[serde(default)]
    pub index: usize,
    pub message: openai::completion::Message,
    pub logprobs: Option<serde_json::Value>,
    #[serde(default)]
    pub finish_reason: Option<String>,
}
521
522impl TryFrom<ChatCompletionResponse> for completion::CompletionResponse<ChatCompletionResponse> {
523 type Error = CompletionError;
524
525 fn try_from(response: ChatCompletionResponse) -> Result<Self, Self::Error> {
526 let choice = response.choices.first().ok_or_else(|| {
527 CompletionError::ResponseError("Response contained no choices".to_owned())
528 })?;
529
530 let content = match &choice.message {
531 openai::completion::Message::Assistant {
532 content,
533 tool_calls,
534 ..
535 } => {
536 let mut content = content
537 .iter()
538 .filter_map(|c| {
539 let s = match c {
540 openai::completion::AssistantContent::Text { text } => text,
541 openai::completion::AssistantContent::Refusal { refusal } => refusal,
542 };
543 if s.is_empty() {
544 None
545 } else {
546 Some(completion::AssistantContent::text(s))
547 }
548 })
549 .collect::<Vec<_>>();
550
551 content.extend(
552 tool_calls
553 .iter()
554 .map(|call| {
555 completion::AssistantContent::tool_call(
556 &call.id,
557 &call.function.name,
558 call.function.arguments.clone(),
559 )
560 })
561 .collect::<Vec<_>>(),
562 );
563 Ok(content)
564 }
565 _ => Err(CompletionError::ResponseError(
566 "Response did not contain a valid message or tool call".into(),
567 )),
568 }?;
569
570 let choice = crate::OneOrMany::many(content).map_err(|_| {
571 CompletionError::ResponseError(
572 "Response contained no message or tool call (empty)".to_owned(),
573 )
574 })?;
575
576 let usage = response
577 .usage
578 .as_ref()
579 .map(|usage| completion::Usage {
580 input_tokens: usage.prompt_tokens as u64,
581 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
582 total_tokens: usage.total_tokens as u64,
583 cached_input_tokens: usage
584 .prompt_tokens_details
585 .as_ref()
586 .map(|d| d.cached_tokens as u64)
587 .unwrap_or(0),
588 cache_creation_input_tokens: 0,
589 reasoning_tokens: 0,
590 })
591 .unwrap_or_default();
592
593 Ok(completion::CompletionResponse {
594 choice,
595 usage,
596 raw_response: response,
597 message_id: None,
598 })
599 }
600}
601
/// Error body returned by the chat endpoint. Both fields are optional;
/// [`ChatApiErrorResponse::error_message`] picks whichever is present.
#[derive(Debug, Deserialize)]
pub struct ChatApiErrorResponse {
    #[serde(default)]
    pub message: Option<String>,
    #[serde(default)]
    pub error: Option<String>,
}
609
610impl ChatApiErrorResponse {
611 pub fn error_message(&self) -> &str {
612 self.message
613 .as_deref()
614 .or(self.error.as_deref())
615 .unwrap_or("unknown error")
616 }
617}
618
/// Untagged envelope: the chat endpoint returns either the expected payload
/// or an error body with a 2xx status; serde tries the `Ok` variant first.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ChatApiResponse<T> {
    Ok(T),
    Err(ChatApiErrorResponse),
}
625
/// A Copilot completion model bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel<H = reqwest::Client> {
    client: Client<H>,
    /// Model identifier (e.g. [`GPT_4O`]); also decides the endpoint route
    /// (see `route_for_model`).
    pub model: String,
    /// Forwarded to the OpenAI-compatible request builder; enabled via
    /// `with_strict_tools`.
    pub strict_tools: bool,
    /// Forwarded to the OpenAI-compatible request builder; enabled via
    /// `with_tool_result_array_content`.
    pub tool_result_array_content: bool,
}
633
impl<H> CompletionModel<H>
where
    Client<H>: HttpClientExt + Clone + Debug + 'static,
    H: Clone + Default + Debug + WasmCompatSend + WasmCompatSync + 'static,
{
    /// Creates a model handle; both tool-serialization flags default to off.
    pub fn new(client: Client<H>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
            strict_tools: false,
            tool_result_array_content: false,
        }
    }

    /// Enables strict tool serialization in the OpenAI-compatible request.
    pub fn with_strict_tools(mut self) -> Self {
        self.strict_tools = true;
        self
    }

    /// Sends tool results in array-content form in the OpenAI-compatible
    /// request.
    pub fn with_tool_result_array_content(mut self) -> Self {
        self.tool_result_array_content = true;
        self
    }

    // Endpoint selection is derived purely from the model name.
    fn route(&self) -> CompletionRoute {
        route_for_model(&self.model)
    }

    /// Resolves credentials (possibly running the OAuth flow), mapping any
    /// auth failure into a `CompletionError::ProviderError`.
    async fn auth_context(&self) -> Result<auth::AuthContext, CompletionError> {
        self.client
            .ext()
            .auth
            .auth_context()
            .await
            .map_err(|err| CompletionError::ProviderError(err.to_string()))
    }

    // Builds the `/chat/completions` request body from a generic request.
    fn chat_request(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<openai::completion::CompletionRequest, CompletionError> {
        openai::completion::CompletionRequest::try_from(openai::completion::OpenAIRequestParams {
            model: self.model.clone(),
            request: completion_request,
            strict_tools: self.strict_tools,
            tool_result_array_content: self.tool_result_array_content,
        })
    }

    // Builds the `/responses` request body from a generic request.
    fn responses_request(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<ResponsesRequest, CompletionError> {
        ResponsesRequest::try_from((self.model.clone(), completion_request))
    }

    /// Non-streaming completion via `/chat/completions`.
    ///
    /// Records `gen_ai.*` fields on the surrounding tracing span, creating a
    /// fresh span only when the caller isn't already tracing.
    async fn completion_chat(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<CopilotCompletionResponse>, CompletionError> {
        // Initiator/vision must be derived before the request is consumed.
        let initiator = request_initiator(&completion_request);
        let has_vision = request_has_vision(&completion_request);
        let request = self.chat_request(completion_request)?;
        let body = serde_json::to_vec(&request)?;
        let auth = self.auth_context().await?;

        let headers = default_headers(&auth.api_key, initiator, has_vision);
        let req = apply_headers(
            post_with_auth_base(&self.client, &auth, "/chat/completions", Transport::Http)?,
            &headers,
        )
        .body(body)
        .map_err(|err| CompletionError::HttpError(err.into()))?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "copilot",
                gen_ai.request.model = self.model,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        async move {
            let response = self.client.send(req).await?;

            if response.status().is_success() {
                // 2xx bodies can still be an error payload — hence the
                // untagged `ChatApiResponse` envelope.
                let body = http_client::text(response).await?;
                match serde_json::from_str::<ChatApiResponse<ChatCompletionResponse>>(&body)? {
                    ChatApiResponse::Ok(response) => {
                        let core = completion::CompletionResponse::try_from(response.clone())?;
                        let span = tracing::Span::current();
                        span.record("gen_ai.response.id", response.id.as_str());
                        span.record("gen_ai.response.model", response.model.as_str());
                        if let Some(usage) = &response.usage {
                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
                            // NOTE(review): assumes total_tokens >= prompt_tokens;
                            // a payload violating that would underflow here.
                            span.record(
                                "gen_ai.usage.output_tokens",
                                usage.total_tokens - usage.prompt_tokens,
                            );
                            span.record(
                                "gen_ai.usage.cache_read.input_tokens",
                                usage
                                    .prompt_tokens_details
                                    .as_ref()
                                    .map(|details| details.cached_tokens)
                                    .unwrap_or(0),
                            );
                        }

                        Ok(completion::CompletionResponse {
                            choice: core.choice,
                            usage: core.usage,
                            raw_response: CopilotCompletionResponse::Chat(response),
                            message_id: core.message_id,
                        })
                    }
                    ChatApiResponse::Err(err) => Err(CompletionError::ProviderError(
                        err.error_message().to_string(),
                    )),
                }
            } else {
                // Non-2xx: surface the raw body as the provider error.
                let body = http_client::text(response).await?;
                Err(CompletionError::ProviderError(body))
            }
        }
        .instrument(span)
        .await
    }

    /// Non-streaming completion via `/responses` (codex-family models).
    ///
    /// Same span handling as [`Self::completion_chat`].
    async fn completion_responses(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<CopilotCompletionResponse>, CompletionError> {
        let initiator = request_initiator(&completion_request);
        let has_vision = request_has_vision(&completion_request);
        let request = self.responses_request(completion_request)?;
        let auth = self.auth_context().await?;

        let headers = default_headers(&auth.api_key, initiator, has_vision);
        let req = apply_headers(
            post_with_auth_base(&self.client, &auth, "/responses", Transport::Http)?,
            &headers,
        )
        .body(serde_json::to_vec(&request)?)
        .map_err(|err| CompletionError::HttpError(err.into()))?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "copilot",
                gen_ai.request.model = self.model,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        async move {
            let response = self.client.send(req).await?;
            if response.status().is_success() {
                let body = http_client::text(response).await?;
                let response = serde_json::from_str::<responses_api::CompletionResponse>(&body)?;
                let core = completion::CompletionResponse::try_from(response.clone())?;

                let span = tracing::Span::current();
                span.record("gen_ai.response.id", response.id.as_str());
                span.record("gen_ai.response.model", response.model.as_str());
                if let Some(usage) = &response.usage {
                    span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                    span.record("gen_ai.usage.output_tokens", usage.output_tokens);
                    span.record(
                        "gen_ai.usage.cache_read.input_tokens",
                        usage
                            .input_tokens_details
                            .as_ref()
                            .map(|details| details.cached_tokens)
                            .unwrap_or(0),
                    );
                }

                Ok(completion::CompletionResponse {
                    choice: core.choice,
                    usage: core.usage,
                    raw_response: CopilotCompletionResponse::Responses(Box::new(response)),
                    message_id: core.message_id,
                })
            } else {
                let body = http_client::text(response).await?;
                Err(CompletionError::ProviderError(body))
            }
        }
        .instrument(span)
        .await
    }

    /// Streaming completion via `/chat/completions` (SSE).
    ///
    /// Injects `stream: true` and `stream_options.include_usage` into the
    /// serialized request body, then delegates the SSE handling to
    /// `send_copilot_chat_streaming_request`.
    async fn stream_chat(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<StreamingCompletionResponse<CopilotStreamingResponse>, CompletionError> {
        let initiator = request_initiator(&completion_request);
        let has_vision = request_has_vision(&completion_request);
        let request = self.chat_request(completion_request)?;
        let auth = self.auth_context().await?;
        let headers = default_headers(&auth.api_key, initiator, has_vision);
        // Round-trip through `serde_json::Value` so the streaming flags can be
        // added without changing the request type.
        let mut request_json = serde_json::to_value(&request)?;
        let request_object = request_json.as_object_mut().ok_or_else(|| {
            CompletionError::ResponseError("copilot request body must be a JSON object".into())
        })?;
        request_object.insert("stream".to_owned(), json!(true));
        request_object.insert(
            "stream_options".to_owned(),
            json!({ "include_usage": true }),
        );

        let req = apply_headers(
            post_with_auth_base(&self.client, &auth, "/chat/completions", Transport::Sse)?,
            &headers,
        )
        .body(serde_json::to_vec(&request_json)?)
        .map_err(|err| CompletionError::HttpError(err.into()))?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "copilot",
                gen_ai.request.model = self.model,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        tracing::Instrument::instrument(
            send_copilot_chat_streaming_request(self.client.clone(), req),
            span,
        )
        .await
    }

    /// Streaming completion via `/responses` (SSE), for codex-family models.
    ///
    /// Translates Responses-API stream events into `RawStreamingChoice`
    /// values. Tool calls are buffered until the stream ends and only emitted
    /// when the stream terminated cleanly; text/reasoning deltas are yielded
    /// as they arrive.
    async fn stream_responses(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<StreamingCompletionResponse<CopilotStreamingResponse>, CompletionError> {
        let initiator = request_initiator(&completion_request);
        let has_vision = request_has_vision(&completion_request);
        let mut request = self.responses_request(completion_request)?;
        request.stream = Some(true);
        let auth = self.auth_context().await?;

        let headers = default_headers(&auth.api_key, initiator, has_vision);
        let req = apply_headers(
            post_with_auth_base(&self.client, &auth, "/responses", Transport::Sse)?,
            &headers,
        )
        .body(serde_json::to_vec(&request)?)
        .map_err(|err| CompletionError::HttpError(err.into()))?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "copilot",
                gen_ai.request.model = self.model,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let client = self.client.clone();
        let mut event_source = crate::http_client::sse::GenericEventSource::new(client, req);

        let stream = tracing_futures::Instrument::instrument(
            stream! {
                let mut final_usage = responses_api::ResponsesUsage::new();
                // Completed tool calls, replayed only after a clean stream end.
                let mut tool_calls: Vec<streaming::RawStreamingChoice<CopilotStreamingResponse>> = Vec::new();
                // Maps provider item ids to locally generated internal call ids
                // so deltas and the final call share the same internal id.
                let mut tool_call_internal_ids: HashMap<String, String> = HashMap::new();
                let span = tracing::Span::current();

                let mut terminated_with_error = false;

                while let Some(event_result) = event_source.next().await {
                    match event_result {
                        Ok(crate::http_client::sse::Event::Open) => continue,
                        Ok(crate::http_client::sse::Event::Message(evt)) => {
                            if evt.data.trim().is_empty() {
                                continue;
                            }

                            // Unparseable events are skipped rather than fatal.
                            let Ok(data) = serde_json::from_str::<responses_api::streaming::StreamingCompletionChunk>(&evt.data) else {
                                continue;
                            };

                            if let responses_api::streaming::StreamingCompletionChunk::Delta(chunk) = &data {
                                use responses_api::streaming::{ItemChunkKind, StreamingItemDoneOutput};

                                match &chunk.data {
                                    // New function-call item: announce its name.
                                    ItemChunkKind::OutputItemAdded(message) => {
                                        if let StreamingItemDoneOutput { item: responses_api::Output::FunctionCall(func), .. } = message {
                                            let internal_call_id = tool_call_internal_ids
                                                .entry(func.id.clone())
                                                .or_insert_with(|| nanoid::nanoid!())
                                                .clone();
                                            yield Ok(RawStreamingChoice::ToolCallDelta {
                                                id: func.id.clone(),
                                                internal_call_id,
                                                content: streaming::ToolCallDeltaContent::Name(func.name.clone()),
                                            });
                                        }
                                    }
                                    ItemChunkKind::OutputItemDone(message) => match message {
                                        // Finished tool call: buffer for replay after the stream ends.
                                        StreamingItemDoneOutput { item: responses_api::Output::FunctionCall(func), .. } => {
                                            let internal_id = tool_call_internal_ids
                                                .entry(func.id.clone())
                                                .or_insert_with(|| nanoid::nanoid!())
                                                .clone();
                                            let raw_tool_call = streaming::RawStreamingToolCall::new(
                                                func.id.clone(),
                                                func.name.clone(),
                                                func.arguments.clone(),
                                            )
                                            .with_internal_call_id(internal_id)
                                            .with_call_id(func.call_id.clone());
                                            tool_calls.push(RawStreamingChoice::ToolCall(raw_tool_call));
                                        }
                                        // Finished reasoning item: forward its choices immediately.
                                        StreamingItemDoneOutput { item: responses_api::Output::Reasoning { summary, id, encrypted_content, .. }, .. } => {
                                            for reasoning_choice in responses_api::streaming::reasoning_choices_from_done_item(
                                                id,
                                                summary,
                                                encrypted_content.as_deref(),
                                            ) {
                                                match reasoning_choice {
                                                    RawStreamingChoice::Reasoning { id, content } => {
                                                        yield Ok(RawStreamingChoice::Reasoning { id, content });
                                                    }
                                                    RawStreamingChoice::ReasoningDelta { id, reasoning } => {
                                                        yield Ok(RawStreamingChoice::ReasoningDelta { id, reasoning });
                                                    }
                                                    _ => {}
                                                }
                                            }
                                        }
                                        StreamingItemDoneOutput { item: responses_api::Output::Message(msg), .. } => {
                                            yield Ok(RawStreamingChoice::MessageId(msg.id.clone()));
                                        }
                                        StreamingItemDoneOutput { item: responses_api::Output::Unknown, .. } => {}
                                    },
                                    ItemChunkKind::OutputTextDelta(delta) => {
                                        yield Ok(RawStreamingChoice::Message(delta.delta.clone()))
                                    }
                                    ItemChunkKind::ReasoningSummaryTextDelta(delta) => {
                                        yield Ok(RawStreamingChoice::ReasoningDelta { id: None, reasoning: delta.delta.clone() })
                                    }
                                    // Refusals are surfaced as plain message text.
                                    ItemChunkKind::RefusalDelta(delta) => {
                                        yield Ok(RawStreamingChoice::Message(delta.delta.clone()))
                                    }
                                    ItemChunkKind::FunctionCallArgsDelta(delta) => {
                                        let internal_call_id = tool_call_internal_ids
                                            .entry(delta.item_id.clone())
                                            .or_insert_with(|| nanoid::nanoid!())
                                            .clone();
                                        yield Ok(RawStreamingChoice::ToolCallDelta {
                                            id: delta.item_id.clone(),
                                            internal_call_id,
                                            content: streaming::ToolCallDeltaContent::Delta(delta.delta.clone())
                                        })
                                    }
                                    _ => continue,
                                }
                            }

                            // Response-level lifecycle events: completion or failure.
                            if let responses_api::streaming::StreamingCompletionChunk::Response(chunk) = data {
                                let responses_api::streaming::ResponseChunk { kind, response, .. } = *chunk;
                                match kind {
                                    responses_api::streaming::ResponseChunkKind::ResponseCompleted => {
                                        span.record("gen_ai.response.id", response.id.as_str());
                                        span.record("gen_ai.response.model", response.model.as_str());
                                        if let Some(usage) = response.usage {
                                            final_usage = usage;
                                        }
                                    }
                                    responses_api::streaming::ResponseChunkKind::ResponseFailed
                                    | responses_api::streaming::ResponseChunkKind::ResponseIncomplete => {
                                        let error = response
                                            .error
                                            .as_ref()
                                            .map(|err| err.message.clone())
                                            .unwrap_or_else(|| "Copilot response stream failed".into());
                                        terminated_with_error = true;
                                        yield Err(CompletionError::ProviderError(error));
                                        break;
                                    }
                                    _ => continue,
                                }
                            }
                        }
                        // Normal end of the SSE stream.
                        Err(crate::http_client::Error::StreamEnded) => {
                            break;
                        }
                        Err(error) => {
                            terminated_with_error = true;
                            yield Err(CompletionError::ProviderError(error.to_string()));
                            break;
                        }
                    }
                }

                event_source.close();

                // On error, buffered tool calls and usage are deliberately dropped.
                if terminated_with_error {
                    return;
                }

                for tool_call in &tool_calls {
                    yield Ok(tool_call.to_owned())
                }

                span.record("gen_ai.usage.input_tokens", final_usage.input_tokens);
                span.record("gen_ai.usage.output_tokens", final_usage.output_tokens);
                span.record(
                    "gen_ai.usage.cache_read.input_tokens",
                    final_usage
                        .input_tokens_details
                        .as_ref()
                        .map(|details| details.cached_tokens)
                        .unwrap_or(0),
                );

                yield Ok(RawStreamingChoice::FinalResponse(
                    CopilotStreamingResponse::Responses(
                        responses_api::streaming::StreamingCompletionResponse { usage: final_usage }
                    )
                ));
            },
            span,
        );

        Ok(StreamingCompletionResponse::stream(Box::pin(stream)))
    }
}
1100
1101impl<H> completion::CompletionModel for CompletionModel<H>
1102where
1103 Client<H>: HttpClientExt + Clone + Debug + 'static,
1104 H: Clone + Default + Debug + WasmCompatSend + WasmCompatSync + 'static,
1105{
1106 type Response = CopilotCompletionResponse;
1107 type StreamingResponse = CopilotStreamingResponse;
1108 type Client = Client<H>;
1109
1110 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
1111 Self::new(client.clone(), model)
1112 }
1113
1114 async fn completion(
1115 &self,
1116 completion_request: completion::CompletionRequest,
1117 ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
1118 match self.route() {
1119 CompletionRoute::ChatCompletions => self.completion_chat(completion_request).await,
1120 CompletionRoute::Responses => self.completion_responses(completion_request).await,
1121 }
1122 }
1123
1124 async fn stream(
1125 &self,
1126 completion_request: completion::CompletionRequest,
1127 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
1128 match self.route() {
1129 CompletionRoute::ChatCompletions => self.stream_chat(completion_request).await,
1130 CompletionRoute::Responses => self.stream_responses(completion_request).await,
1131 }
1132 }
1133}
1134
/// Copilot embedding model bound to a [`Client`].
#[derive(Clone)]
pub struct EmbeddingModel<H = reqwest::Client> {
    // Client used to send `/embeddings` requests.
    client: Client<H>,
    /// Embedding model identifier (e.g. `text-embedding-3-small`).
    pub model: String,
    /// Optional wire encoding for returned vectors; omitted from the request when `None`.
    pub encoding_format: Option<openai::EncodingFormat>,
    /// Optional end-user identifier forwarded in the request body when set.
    pub user: Option<String>,
    // Output dimensionality; 0 means "unknown" and suppresses the
    // `dimensions` request field (see `embed_texts`).
    ndims: usize,
}
1143
/// Minimal shape of the Copilot `/embeddings` response body; only the
/// fields this module reads are modeled.
#[derive(Deserialize)]
struct CopilotEmbeddingResponse {
    data: Vec<CopilotEmbeddingData>,
}
1148
/// A single embedding entry. Numbers are kept as `serde_json::Number`
/// and converted to `f64` when building the final embeddings.
#[derive(Deserialize)]
struct CopilotEmbeddingData {
    embedding: Vec<serde_json::Number>,
}
1153
1154impl<H> EmbeddingModel<H>
1155where
1156 Client<H>: HttpClientExt + Clone + Debug + 'static,
1157 H: Clone + Default + Debug + 'static,
1158{
1159 pub fn new(client: Client<H>, model: impl Into<String>, ndims: usize) -> Self {
1160 Self {
1161 client,
1162 model: model.into(),
1163 encoding_format: None,
1164 user: None,
1165 ndims,
1166 }
1167 }
1168}
1169
1170impl<H> embeddings::EmbeddingModel for EmbeddingModel<H>
1171where
1172 Client<H>: HttpClientExt + Clone + Debug + WasmCompatSend + WasmCompatSync + 'static,
1173 H: Clone + Default + Debug + WasmCompatSend + WasmCompatSync + 'static,
1174{
1175 const MAX_DOCUMENTS: usize = 1024;
1176 type Client = Client<H>;
1177
1178 fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
1179 let model = model.into();
1180 let dims = ndims.unwrap_or(match model.as_str() {
1181 TEXT_EMBEDDING_3_LARGE => 3072,
1182 TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
1183 _ => 0,
1184 });
1185 Self::new(client.clone(), model, dims)
1186 }
1187
1188 fn ndims(&self) -> usize {
1189 self.ndims
1190 }
1191
1192 async fn embed_texts(
1193 &self,
1194 documents: impl IntoIterator<Item = String>,
1195 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
1196 let documents = documents.into_iter().collect::<Vec<_>>();
1197 let auth = self
1198 .client
1199 .ext()
1200 .auth
1201 .auth_context()
1202 .await
1203 .map_err(|err| EmbeddingError::ProviderError(err.to_string()))?;
1204
1205 let headers = default_headers(&auth.api_key, "user", false);
1206 let mut body = json!({
1207 "model": self.model,
1208 "input": documents,
1209 });
1210
1211 let body_object = body.as_object_mut().ok_or_else(|| {
1212 EmbeddingError::ResponseError("embedding request body must be a JSON object".into())
1213 })?;
1214
1215 if self.ndims > 0 && self.model.as_str() != TEXT_EMBEDDING_ADA_002 {
1216 body_object.insert("dimensions".to_owned(), json!(self.ndims));
1217 }
1218 if let Some(encoding_format) = &self.encoding_format {
1219 body_object.insert("encoding_format".to_owned(), json!(encoding_format));
1220 }
1221 if let Some(user) = &self.user {
1222 body_object.insert("user".to_owned(), json!(user));
1223 }
1224
1225 let req = apply_headers(
1226 post_with_auth_base(&self.client, &auth, "/embeddings", Transport::Http)?,
1227 &headers,
1228 )
1229 .body(serde_json::to_vec(&body)?)
1230 .map_err(|err| EmbeddingError::HttpError(err.into()))?;
1231
1232 let response = self.client.send(req).await?;
1233 if response.status().is_success() {
1234 let body: Vec<u8> = response.into_body().await?;
1235 #[derive(Deserialize)]
1236 struct NestedApiError {
1237 error: NestedApiErrorMessage,
1238 }
1239
1240 #[derive(Deserialize)]
1241 struct NestedApiErrorMessage {
1242 message: String,
1243 }
1244
1245 let body: CopilotEmbeddingResponse = match serde_json::from_slice(&body) {
1246 Ok(parsed) => parsed,
1247 Err(parse_error) => {
1248 if let Ok(err) = serde_json::from_slice::<NestedApiError>(&body) {
1249 return Err(EmbeddingError::ProviderError(err.error.message));
1250 }
1251
1252 let preview = String::from_utf8_lossy(&body);
1253 let preview = if preview.len() > 512 {
1254 format!("{}...", &preview[..512])
1255 } else {
1256 preview.into_owned()
1257 };
1258
1259 return Err(EmbeddingError::ProviderError(format!(
1260 "Failed to parse Copilot embeddings response: {parse_error}; body: {preview}"
1261 )));
1262 }
1263 };
1264
1265 Ok(body
1266 .data
1267 .into_iter()
1268 .zip(documents.into_iter())
1269 .map(|(embedding, document)| embeddings::Embedding {
1270 document,
1271 vec: embedding
1272 .embedding
1273 .into_iter()
1274 .filter_map(|n| n.as_f64())
1275 .collect(),
1276 })
1277 .collect())
1278 } else {
1279 let text = http_client::text(response).await?;
1280 Err(EmbeddingError::ProviderError(text))
1281 }
1282 }
1283}
1284
/// Path of the Copilot model-listing endpoint, relative to the API base URL.
const MODEL_LISTING_PATH: &str = "/models";
/// Provider label attached to model-listing error context.
const MODEL_LISTING_PROVIDER: &str = "Copilot";
1287
/// Envelope of the Copilot `/models` listing response.
#[derive(Debug, Deserialize)]
struct ListModelsResponse {
    data: Vec<ListModelEntry>,
}
1292
/// One model entry from the Copilot listing; every optional field defaults
/// to `None` when absent from the payload.
#[derive(Debug, Deserialize)]
struct ListModelEntry {
    id: String,
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    vendor: Option<String>,
    #[serde(default)]
    capabilities: Option<ListModelEntryCapabilities>,
}
1303
/// Capability metadata for a listed model; only the `type` field is read.
#[derive(Debug, Deserialize)]
struct ListModelEntryCapabilities {
    // `type` is a Rust keyword, hence the rename + raw identifier.
    #[serde(default, rename = "type")]
    r#type: Option<String>,
}
1309
1310impl From<ListModelEntry> for Model {
1311 fn from(value: ListModelEntry) -> Self {
1312 let mut model = Model::from_id(value.id);
1313 model.name = value.name;
1314 model.owned_by = value.vendor;
1315 if let Some(caps) = value.capabilities {
1316 model.r#type = caps.r#type;
1317 }
1318 model
1319 }
1320}
1321
/// Lists models available through the Copilot `/models` endpoint.
#[derive(Clone)]
pub struct CopilotModelLister<H = reqwest::Client> {
    client: Client<H>,
}
1327
1328impl<H> ModelLister<H> for CopilotModelLister<H>
1329where
1330 H: HttpClientExt + Clone + Debug + Default + WasmCompatSend + WasmCompatSync + 'static,
1331{
1332 type Client = Client<H>;
1333
1334 fn new(client: Self::Client) -> Self {
1335 Self { client }
1336 }
1337
1338 async fn list_all(&self) -> Result<ModelList, ModelListingError> {
1339 let auth = self.client.ext().auth.auth_context().await.map_err(|err| {
1340 ModelListingError::AuthError {
1341 message: err.to_string(),
1342 }
1343 })?;
1344
1345 let headers = default_headers(&auth.api_key, "user", false);
1346 let req = apply_headers(
1347 get_with_auth_base(&self.client, &auth, MODEL_LISTING_PATH, Transport::Http)?,
1348 &headers,
1349 )
1350 .body(http_client::NoBody)?;
1351
1352 let response = self.client.send::<_, Vec<u8>>(req).await?;
1353
1354 if !response.status().is_success() {
1355 let status_code = response.status().as_u16();
1356 let body = response.into_body().await?;
1357 return Err(ModelListingError::api_error_with_context(
1358 MODEL_LISTING_PROVIDER,
1359 MODEL_LISTING_PATH,
1360 status_code,
1361 &body,
1362 ));
1363 }
1364
1365 let body = response.into_body().await?;
1366 let api_resp: ListModelsResponse = serde_json::from_slice(&body).map_err(|error| {
1367 ModelListingError::parse_error_with_context(
1368 MODEL_LISTING_PROVIDER,
1369 MODEL_LISTING_PATH,
1370 &error,
1371 &body,
1372 )
1373 })?;
1374 let models = api_resp.data.into_iter().map(Model::from).collect();
1375
1376 Ok(ModelList::new(models))
1377 }
1378}
1379
/// Function payload inside a streamed tool-call fragment; either field may
/// be absent on any given chunk.
#[derive(Deserialize, Debug)]
struct ChatStreamingFunction {
    name: Option<String>,
    arguments: Option<String>,
}
1385
/// One tool-call fragment from a chat streaming delta, addressed by `index`.
#[derive(Deserialize, Debug)]
struct ChatStreamingToolCall {
    index: usize,
    id: Option<String>,
    function: ChatStreamingFunction,
}
1392
1393impl From<&ChatStreamingToolCall> for CompatibleToolCallChunk {
1394 fn from(value: &ChatStreamingToolCall) -> Self {
1395 Self {
1396 index: value.index,
1397 id: value.id.clone(),
1398 name: value.function.name.clone(),
1399 arguments: value.function.arguments.clone(),
1400 }
1401 }
1402}
1403
/// Incremental delta payload of a chat streaming choice.
#[derive(Deserialize, Debug, Default)]
struct ChatStreamingDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>,
    // Per its name, `null_or_vec` treats an explicit JSON `null` as an empty
    // list — see `crate::json_utils` for the exact semantics.
    #[serde(default, deserialize_with = "crate::json_utils::null_or_vec")]
    tool_calls: Vec<ChatStreamingToolCall>,
}
1413
/// Finish reason reported by the chat streaming API. Unrecognized values
/// are captured verbatim through the untagged `Other` variant instead of
/// failing deserialization.
#[derive(Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")]
enum ChatFinishReason {
    ToolCalls,
    Stop,
    ContentFilter,
    Length,
    #[serde(untagged)]
    Other(String),
}
1424
/// A single streamed choice: the incremental delta plus an optional
/// terminal finish reason.
#[derive(Deserialize, Debug)]
struct ChatStreamingChoice {
    delta: ChatStreamingDelta,
    finish_reason: Option<ChatFinishReason>,
}
1430
/// Top-level SSE chunk of the chat-completions stream; all metadata fields
/// are optional since Copilot payloads omit some of them.
#[derive(Deserialize, Debug)]
struct ChatStreamingChunk {
    id: Option<String>,
    model: Option<String>,
    choices: Vec<ChatStreamingChoice>,
    usage: Option<openai::completion::Usage>,
}
1438
/// Stream profile adapting Copilot's chat-completions SSE format to the
/// shared OpenAI-compatible streaming pipeline.
#[derive(Clone, Copy)]
struct CopilotChatCompatibleProfile;
1441
1442impl CompatibleStreamProfile for CopilotChatCompatibleProfile {
1443 type Usage = openai::completion::Usage;
1444 type Detail = ();
1445 type FinalResponse = CopilotStreamingResponse;
1446
1447 fn normalize_chunk(
1448 &self,
1449 data: &str,
1450 ) -> Result<Option<CompatibleChunk<Self::Usage, Self::Detail>>, CompletionError> {
1451 let data = match serde_json::from_str::<ChatStreamingChunk>(data) {
1452 Ok(data) => data,
1453 Err(error) => {
1454 tracing::debug!(?error, "Couldn't parse Copilot chat SSE payload");
1455 return Ok(None);
1456 }
1457 };
1458
1459 Ok(Some(
1460 openai_chat_completions_compatible::normalize_first_choice_chunk(
1461 data.id,
1462 data.model,
1463 data.usage,
1464 &data.choices,
1465 |choice| CompatibleChoiceData {
1466 finish_reason: if choice.finish_reason == Some(ChatFinishReason::ToolCalls) {
1467 CompatibleFinishReason::ToolCalls
1468 } else {
1469 CompatibleFinishReason::Other
1470 },
1471 text: choice.delta.content.clone(),
1472 reasoning: choice.delta.reasoning_content.clone(),
1473 tool_calls: openai_chat_completions_compatible::tool_call_chunks(
1474 &choice.delta.tool_calls,
1475 ),
1476 details: Vec::new(),
1477 },
1478 ),
1479 ))
1480 }
1481
1482 fn build_final_response(&self, usage: Self::Usage) -> Self::FinalResponse {
1483 CopilotStreamingResponse::Chat(openai::completion::streaming::StreamingCompletionResponse {
1484 usage,
1485 })
1486 }
1487
1488 fn uses_distinct_tool_call_eviction(&self) -> bool {
1489 true
1490 }
1491}
1492
1493async fn send_copilot_chat_streaming_request<T>(
1494 http_client: T,
1495 req: Request<Vec<u8>>,
1496) -> Result<StreamingCompletionResponse<CopilotStreamingResponse>, CompletionError>
1497where
1498 T: HttpClientExt + Clone + 'static,
1499{
1500 openai_chat_completions_compatible::send_compatible_streaming_request(
1501 http_client,
1502 req,
1503 CopilotChatCompatibleProfile,
1504 )
1505 .await
1506}
1507
1508fn default_token_dir() -> Option<PathBuf> {
1509 config_dir().map(|dir| dir.join("github_copilot"))
1510}
1511
/// Resolves the platform-specific user configuration directory:
/// `%APPDATA%` on Windows, otherwise `$XDG_CONFIG_HOME`, falling back to
/// `$HOME/.config`. Returns `None` when no relevant variable is set.
fn config_dir() -> Option<PathBuf> {
    #[cfg(target_os = "windows")]
    {
        std::env::var_os("APPDATA").map(PathBuf::from)
    }

    #[cfg(not(target_os = "windows"))]
    {
        if let Some(xdg) = std::env::var_os("XDG_CONFIG_HOME") {
            return Some(PathBuf::from(xdg));
        }
        let home = std::env::var_os("HOME")?;
        Some(PathBuf::from(home).join(".config"))
    }
}
1525
// Unit tests covering: response-shape deserialization (standard OpenAI vs.
// Copilot's trimmed payloads), chat-vs-responses request routing, streaming
// termination on terminal/transport errors, and env-var resolution precedence.
#[cfg(test)]
mod tests {
    use super::{
        ChatApiErrorResponse, ChatCompletionResponse, Client, CompletionRoute,
        TEXT_EMBEDDING_3_SMALL, env_api_key, env_base_url, env_github_access_token,
        route_for_model,
    };
    use crate::client::CompletionClient;
    use crate::completion::CompletionModel;
    use crate::http_client;
    use crate::providers::internal::openai_chat_completions_compatible::test_support::{
        sse_bytes_from_data_lines, sse_bytes_from_json_events,
    };
    use crate::streaming::StreamedAssistantContent;
    use crate::test_utils::MockStreamingClient;
    use crate::test_utils::{RecordingHttpClient, SequencedStreamingHttpClient};
    use futures::StreamExt;
    use std::collections::HashMap;

    // Builds an owned key/value map for simulating environment lookups.
    fn env_map(entries: &[(&str, &str)]) -> HashMap<String, String> {
        entries
            .iter()
            .map(|(key, value)| ((*key).to_string(), (*value).to_string()))
            .collect()
    }

    // Copilot-shaped chat response: omits `object`, `created`, and
    // `completion_tokens`, which the deserializer must tolerate.
    fn minimal_chat_response() -> &'static str {
        r#"{
            "id": "chatcmpl-123",
            "model": "gpt-4o",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "hello"
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 4,
                "total_tokens": 7
            }
        }"#
    }

    // Minimal Responses-API payload used by the codex routing test.
    fn minimal_responses_response() -> &'static str {
        r#"{
            "id": "resp_123",
            "object": "response",
            "created_at": 1700000000,
            "status": "completed",
            "error": null,
            "incomplete_details": null,
            "instructions": null,
            "max_output_tokens": null,
            "model": "gpt-5.3-codex",
            "usage": {
                "input_tokens": 4,
                "input_tokens_details": {
                    "cached_tokens": 0
                },
                "output_tokens": 3,
                "output_tokens_details": {
                    "reasoning_tokens": 0
                },
                "total_tokens": 7
            },
            "output": [{
                "type": "message",
                "id": "msg_123",
                "role": "assistant",
                "status": "completed",
                "content": [{
                    "type": "output_text",
                    "text": "hello"
                }]
            }],
            "tools": []
        }"#
    }

    // Bare embeddings payload: only `data[].embedding` is present.
    fn minimal_embeddings_response() -> &'static str {
        r#"{
            "data": [
                {
                    "embedding": [0.1, 0.2, 0.3]
                },
                {
                    "embedding": [0.4, 0.5, 0.6]
                }
            ]
        }"#
    }

    #[test]
    fn deserialize_standard_openai_response() {
        let json = r#"{
            "id": "chatcmpl-abc123",
            "object": "chat.completion",
            "created": 1700000000,
            "model": "gpt-4o",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Hello!"
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        }"#;

        let response: ChatCompletionResponse =
            serde_json::from_str(json).expect("standard OpenAI response should deserialize");
        assert_eq!(response.id, "chatcmpl-abc123");
        assert_eq!(response.object.as_deref(), Some("chat.completion"));
        assert_eq!(response.created, Some(1700000000));
        assert_eq!(response.model, "gpt-4o");
        assert_eq!(response.choices.len(), 1);
        assert_eq!(response.choices[0].finish_reason.as_deref(), Some("stop"));
    }

    // Copilot omits `object`/`created`; both must come back as None.
    #[test]
    fn deserialize_copilot_response_without_object_and_created() {
        let response: ChatCompletionResponse = serde_json::from_str(minimal_chat_response())
            .expect("Copilot response should deserialize");

        assert_eq!(response.id, "chatcmpl-123");
        assert_eq!(response.object, None);
        assert_eq!(response.created, None);
        assert_eq!(response.model, "gpt-4o");
        assert_eq!(response.choices.len(), 1);
    }

    // Missing `finish_reason` and `index` must not break deserialization;
    // `index` defaults to 0.
    #[test]
    fn deserialize_copilot_response_without_finish_reason() {
        let json = r#"{
            "id": "chatcmpl-claude-001",
            "model": "claude-3.5-sonnet",
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": "Here is my analysis."
                }
            }],
            "usage": {
                "prompt_tokens": 50,
                "total_tokens": 80
            }
        }"#;

        let response: ChatCompletionResponse =
            serde_json::from_str(json).expect("Claude-via-Copilot response should deserialize");

        assert_eq!(response.model, "claude-3.5-sonnet");
        assert_eq!(response.choices[0].finish_reason, None);
        assert_eq!(response.choices[0].index, 0);
    }

    // Error payloads come in two shapes; both must yield the message.
    #[test]
    fn error_response_with_message_field() {
        let json = r#"{"message": "rate limit exceeded"}"#;
        let err: ChatApiErrorResponse = serde_json::from_str(json).expect("message-shaped error");

        assert_eq!(err.error_message(), "rate limit exceeded");
    }

    #[test]
    fn error_response_with_error_field() {
        let json = r#"{"error": "model not found"}"#;
        let err: ChatApiErrorResponse = serde_json::from_str(json).expect("error-shaped error");

        assert_eq!(err.error_message(), "model not found");
    }

    // Routing is case-insensitive on "codex"; everything else goes to chat.
    #[test]
    fn routes_codex_models_to_responses() {
        assert_eq!(route_for_model("gpt-5.3-codex"), CompletionRoute::Responses);
        assert_eq!(
            route_for_model("gpt-5.1-CODEX-mini"),
            CompletionRoute::Responses
        );
        assert_eq!(route_for_model("gpt-5.2"), CompletionRoute::ChatCompletions);
        assert_eq!(
            route_for_model("claude-sonnet-4.5"),
            CompletionRoute::ChatCompletions
        );
    }

    #[tokio::test]
    async fn completion_model_routes_chat_requests_to_chat_completions() {
        let http_client = RecordingHttpClient::new(minimal_chat_response());
        let client = Client::builder()
            .api_key("copilot-token")
            .http_client(http_client.clone())
            .build()
            .expect("build client");
        let model = client.completion_model("gpt-4o");
        let request = model.completion_request("hello").build();

        let _response = model.completion(request).await.expect("chat completion");

        let requests = http_client.requests();
        assert_eq!(requests.len(), 1);
        assert!(requests[0].uri.ends_with("/chat/completions"));
        assert!(String::from_utf8_lossy(&requests[0].body).contains("\"model\":\"gpt-4o\""));
    }

    #[tokio::test]
    async fn completion_model_routes_codex_requests_to_responses() {
        let http_client = RecordingHttpClient::new(minimal_responses_response());
        let client = Client::builder()
            .api_key("copilot-token")
            .http_client(http_client.clone())
            .build()
            .expect("build client");
        let model = client.completion_model("gpt-5.3-codex");
        let request = model.completion_request("hello").build();

        let _response = model
            .completion(request)
            .await
            .expect("responses completion");

        let requests = http_client.requests();
        assert_eq!(requests.len(), 1);
        assert!(requests[0].uri.ends_with("/responses"));
        assert!(String::from_utf8_lossy(&requests[0].body).contains("\"model\":\"gpt-5.3-codex\""));
    }

    #[tokio::test]
    async fn embeddings_accept_minimal_copilot_response_shape() {
        use crate::client::EmbeddingsClient;
        use crate::embeddings::EmbeddingModel as _;

        let http_client = RecordingHttpClient::new(minimal_embeddings_response());
        let client = Client::builder()
            .api_key("copilot-token")
            .http_client(http_client.clone())
            .build()
            .expect("build client");
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);

        let embeddings = model
            .embed_texts(["one".to_string(), "two".to_string()])
            .await
            .expect("embeddings should deserialize");

        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].vec, vec![0.1, 0.2, 0.3]);
        assert_eq!(embeddings[1].vec, vec![0.4, 0.5, 0.6]);

        let requests = http_client.requests();
        assert_eq!(requests.len(), 1);
        assert!(requests[0].uri.ends_with("/embeddings"));
        assert!(
            String::from_utf8_lossy(&requests[0].body)
                .contains("\"model\":\"text-embedding-3-small\"")
        );
    }

    // A `response.failed` event must surface exactly one error and then end
    // the stream — no trailing buffered tool calls.
    #[tokio::test]
    async fn responses_stream_terminates_after_terminal_error() {
        let tool_call_done = serde_json::json!({
            "type": "response.output_item.done",
            "sequence_number": 1,
            "item": {
                "type": "function_call",
                "id": "fc_123",
                "arguments": "{}",
                "call_id": "call_123",
                "name": "example_tool",
                "status": "completed"
            }
        });
        let failed = serde_json::json!({
            "type": "response.failed",
            "sequence_number": 2,
            "response": {
                "id": "resp_123",
                "object": "response",
                "created_at": 1700000000,
                "status": "failed",
                "error": {
                    "code": "server_error",
                    "message": "Copilot response stream failed"
                },
                "incomplete_details": null,
                "instructions": null,
                "max_output_tokens": null,
                "model": "gpt-5.3-codex",
                "usage": null,
                "output": [],
                "tools": []
            }
        });
        let http_client = MockStreamingClient {
            sse_bytes: sse_bytes_from_json_events(&[tool_call_done, failed]),
        };
        let client = Client::builder()
            .api_key("copilot-token")
            .http_client(http_client)
            .build()
            .expect("build client");
        let model = client.completion_model("gpt-5.3-codex");
        let request = model.completion_request("hello").build();
        let mut stream = model.stream(request).await.expect("stream should start");

        let err = match stream.next().await.expect("stream should yield an item") {
            Ok(_) => panic!("stream should surface a provider error"),
            Err(err) => err,
        };
        assert_eq!(
            err.to_string(),
            "ProviderError: Copilot response stream failed"
        );
        assert!(
            stream.next().await.is_none(),
            "responses stream should terminate immediately after a terminal error"
        );
    }

    // A mid-stream transport failure must surface one error and stop the
    // chat stream; no final-response or tool-call items may follow.
    #[tokio::test]
    async fn chat_stream_terminates_after_transport_error() {
        let chunks = vec![
            Ok(sse_bytes_from_data_lines([
                "{\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_123\",\"function\":{\"name\":\"ping\",\"arguments\":\"\"}}]},\"finish_reason\":null}],\"usage\":null}",
            ])),
            Err(http_client::Error::InvalidStatusCode(
                http::StatusCode::BAD_GATEWAY,
            )),
        ];

        let http_client = SequencedStreamingHttpClient::new(chunks);
        let client = Client::builder()
            .api_key("copilot-token")
            .http_client(http_client)
            .build()
            .expect("build client");
        let model = client.completion_model("gpt-4o");
        let request = model.completion_request("hello").build();
        let mut stream = model.stream(request).await.expect("stream should start");

        let mut saw_error = false;
        while let Some(item) = stream.next().await {
            match item {
                Ok(StreamedAssistantContent::ToolCallDelta { .. }) => {}
                Err(err) => {
                    assert_eq!(
                        err.to_string(),
                        "ProviderError: Invalid status code: 502 Bad Gateway"
                    );
                    saw_error = true;
                    break;
                }
                Ok(_) => panic!("unexpected non-error stream item before transport failure"),
            }
        }

        assert!(saw_error, "stream should surface the transport error");
        assert!(
            stream.next().await.is_none(),
            "chat stream should terminate immediately after a transport error"
        );
    }

    // Env precedence: GITHUB_COPILOT_* beats COPILOT_*; GITHUB_TOKEN is a
    // bootstrap credential, never a Copilot API key.
    #[test]
    fn env_api_key_prefers_github_prefixed_vars() {
        let env = env_map(&[
            ("COPILOT_API_KEY", "copilot-key"),
            ("GITHUB_COPILOT_API_KEY", "github-key"),
            ("GITHUB_TOKEN", "bootstrap-token"),
        ]);
        let get = |name: &str| env.get(name).cloned();

        assert_eq!(env_api_key(&get).as_deref(), Some("github-key"));
    }

    #[test]
    fn env_github_access_token_prefers_explicit_bootstrap_var() {
        let env = env_map(&[
            ("COPILOT_GITHUB_ACCESS_TOKEN", "explicit-bootstrap"),
            ("GITHUB_TOKEN", "fallback-bootstrap"),
        ]);
        let get = |name: &str| env.get(name).cloned();

        assert_eq!(
            env_github_access_token(&get).as_deref(),
            Some("explicit-bootstrap")
        );
    }

    #[test]
    fn env_base_url_prefers_github_prefixed_vars() {
        let env = env_map(&[
            ("COPILOT_BASE_URL", "https://copilot.example"),
            ("GITHUB_COPILOT_API_BASE", "https://github.example"),
        ]);
        let get = |name: &str| env.get(name).cloned();

        assert_eq!(
            env_base_url(&get).as_deref(),
            Some("https://github.example")
        );
    }

    #[test]
    fn env_without_api_key_falls_back_to_oauth() {
        let env = env_map(&[("COPILOT_BASE_URL", "https://copilot.example")]);
        let get = |name: &str| env.get(name).cloned();

        assert!(env_api_key(&get).is_none());
        assert!(env_github_access_token(&get).is_none());
        assert_eq!(
            env_base_url(&get).as_deref(),
            Some("https://copilot.example")
        );
    }

    #[test]
    fn env_github_token_is_not_treated_as_copilot_api_key() {
        let env = env_map(&[("GITHUB_TOKEN", "bootstrap-token")]);
        let get = |name: &str| env.get(name).cloned();

        assert!(env_api_key(&get).is_none());
        assert_eq!(
            env_github_access_token(&get).as_deref(),
            Some("bootstrap-token")
        );
    }
}