1use crate::config::paths::Paths;
2use crate::providers::api_client::{ApiClient, AuthMethod};
3use crate::providers::utils::{handle_status_openai_compat, stream_openai_compat};
4use anyhow::{anyhow, Context, Result};
5use async_trait::async_trait;
6use axum::http;
7use chrono::{DateTime, Utc};
8use reqwest::{Client, Response};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::cell::RefCell;
12use std::collections::HashMap;
13use std::path::PathBuf;
14use std::time::Duration;
15
16use super::base::{Provider, ProviderMetadata, ProviderUsage, Usage};
17use super::errors::ProviderError;
18use super::formats::openai::{create_request, get_usage, response_to_message};
19use super::retry::ProviderRetry;
20use super::utils::{get_model, handle_response_openai_compat, ImageFormat, RequestLog};
21
22use crate::config::{Config, ConfigError};
23use crate::conversation::message::Message;
24
25use crate::model::ModelConfig;
26use crate::providers::base::{ConfigKey, MessageStream};
27use rmcp::model::Tool;
28
/// Model selected when the user has not configured one explicitly.
pub const GITHUB_COPILOT_DEFAULT_MODEL: &str = "gpt-4.1";

/// Models advertised in the provider metadata, in display order.
pub const GITHUB_COPILOT_KNOWN_MODELS: &[&str] = &[
    // OpenAI family
    "gpt-4.1",
    "gpt-5-mini",
    "gpt-5",
    "gpt-4o",
    // xAI
    "grok-code-fast-1",
    // OpenAI codex
    "gpt-5-codex",
    // Anthropic family
    "claude-sonnet-4",
    "claude-sonnet-4.5",
    "claude-haiku-4.5",
    // Google
    "gemini-2.5-pro",
];

/// Model-name prefixes for which streaming responses are enabled.
///
/// `supports_streaming` checks whether the configured model name *starts with* one of
/// these entries, so e.g. `"gpt-5"` also matches `"gpt-5-mini"`.
pub const GITHUB_COPILOT_STREAM_MODELS: &[&str] = &[
    "gpt-4.1",
    "gpt-5",
    "gpt-5-mini",
    "gpt-5-codex",
    "gemini-2.5-pro",
    "grok-code-fast-1",
];
51
/// Public documentation listing the models available through Copilot.
const GITHUB_COPILOT_DOC_URL: &str = "https://docs.github.com/en/copilot/using-github-copilot/ai-models";

/// OAuth client id sent with the device-code and access-token requests.
const GITHUB_COPILOT_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";

// --- GitHub OAuth device flow endpoints ---
/// Issues the device/user code pair shown to the user.
const GITHUB_COPILOT_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
/// Polled until the user approves the device code; returns the GitHub access token.
const GITHUB_COPILOT_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";

/// Exchanges the GitHub access token for a short-lived Copilot API token.
const GITHUB_COPILOT_API_KEY_URL: &str = "https://api.github.com/copilot_internal/v2/token";
58
59#[derive(Debug, Deserialize)]
60struct DeviceCodeInfo {
61 device_code: String,
62 user_code: String,
63 verification_uri: String,
64}
65
66#[derive(Debug, Serialize, Deserialize, Clone)]
67struct CopilotTokenEndpoints {
68 api: String,
69 #[serde(flatten)]
70 _extra: HashMap<String, Value>,
71}
72
73#[derive(Debug, Serialize, Deserialize, Clone)]
74#[allow(dead_code)] struct CopilotTokenInfo {
76 token: String,
77 expires_at: i64,
78 refresh_in: i64,
79 endpoints: CopilotTokenEndpoints,
80 #[serde(flatten)]
81 _extra: HashMap<String, Value>,
82}
83
84#[derive(Debug, Serialize, Deserialize, Clone)]
85struct CopilotState {
86 expires_at: DateTime<Utc>,
87 info: CopilotTokenInfo,
88}
89
/// File-backed copy of the Copilot token state.
///
/// A still-valid token can be reused when the in-memory state is empty or stale.
#[derive(Debug)]
struct DiskCache {
    /// Location of the JSON cache file.
    cache_path: PathBuf,
}
94
95impl DiskCache {
96 fn new() -> Self {
97 let cache_path = Paths::in_config_dir("githubcopilot/info.json");
98 Self { cache_path }
99 }
100
101 async fn load(&self) -> Option<CopilotState> {
102 if let Ok(contents) = tokio::fs::read_to_string(&self.cache_path).await {
103 if let Ok(info) = serde_json::from_str::<CopilotState>(&contents) {
104 return Some(info);
105 }
106 }
107 None
108 }
109
110 async fn save(&self, info: &CopilotState) -> Result<()> {
111 if let Some(parent) = self.cache_path.parent() {
112 tokio::fs::create_dir_all(parent).await?;
113 }
114 let contents = serde_json::to_string(info)?;
115 tokio::fs::write(&self.cache_path, contents).await?;
116 Ok(())
117 }
118}
119
120#[derive(Debug, serde::Serialize)]
121pub struct GithubCopilotProvider {
122 #[serde(skip)]
123 client: Client,
124 #[serde(skip)]
125 cache: DiskCache,
126 #[serde(skip)]
127 mu: tokio::sync::Mutex<RefCell<Option<CopilotState>>>,
128 model: ModelConfig,
129 #[serde(skip)]
130 name: String,
131}
132
133impl GithubCopilotProvider {
134 fn payload_contains_image(payload: &Value) -> bool {
135 payload
136 .get("messages")
137 .and_then(|m| m.as_array())
138 .is_some_and(|messages| {
139 messages.iter().any(|msg| {
140 msg.get("content").is_some_and(|content| {
141 content
142 .as_array()
143 .map(|arr| arr.iter().collect::<Vec<_>>())
144 .unwrap_or_else(|| vec![content])
145 .iter()
146 .any(|item| {
147 matches!(
148 item.get("type").and_then(|v| v.as_str()),
149 Some("image_url") | Some("image")
150 )
151 })
152 })
153 })
154 })
155 }
156
157 pub async fn from_env(model: ModelConfig) -> Result<Self> {
158 let client = Client::builder()
159 .timeout(Duration::from_secs(600))
160 .build()?;
161 let cache = DiskCache::new();
162 let mu = tokio::sync::Mutex::new(RefCell::new(None));
163 Ok(Self {
164 client,
165 cache,
166 mu,
167 model,
168 name: Self::metadata().name,
169 })
170 }
171
172 async fn post(&self, payload: &mut Value) -> Result<Response, ProviderError> {
173 let (endpoint, token) = self.get_api_info().await?;
174 let auth = AuthMethod::BearerToken(token);
175 let mut headers = self.get_github_headers();
176 if Self::payload_contains_image(payload) {
177 headers.insert("Copilot-Vision-Request", "true".parse().unwrap());
178 }
179 let api_client = ApiClient::new(endpoint.clone(), auth)?.with_headers(headers)?;
180
181 api_client
182 .response_post("chat/completions", payload)
183 .await
184 .map_err(|e| e.into())
185 }
186
187 async fn get_api_info(&self) -> Result<(String, String)> {
188 let guard = self.mu.lock().await;
189
190 if let Some(state) = guard.borrow().as_ref() {
191 if state.expires_at > Utc::now() {
192 return Ok((state.info.endpoints.api.clone(), state.info.token.clone()));
193 }
194 }
195
196 if let Some(state) = self.cache.load().await {
197 if guard.borrow().is_none() {
198 guard.replace(Some(state.clone()));
199 }
200 if state.expires_at > Utc::now() {
201 return Ok((state.info.endpoints.api, state.info.token));
202 }
203 }
204
205 const MAX_ATTEMPTS: i32 = 3;
206 for attempt in 0..MAX_ATTEMPTS {
207 tracing::trace!("attempt {} to refresh api info", attempt + 1);
208 let info = match self.refresh_api_info().await {
209 Ok(data) => data,
210 Err(err) => {
211 tracing::warn!("failed to refresh api info: {}", err);
212 continue;
213 }
214 };
215 let expires_at = Utc::now() + chrono::Duration::seconds(info.refresh_in);
216 let new_state = CopilotState { info, expires_at };
217 self.cache.save(&new_state).await?;
218 guard.replace(Some(new_state.clone()));
219 return Ok((new_state.info.endpoints.api, new_state.info.token));
220 }
221 Err(anyhow!("failed to get api info after 3 attempts"))
222 }
223
224 async fn refresh_api_info(&self) -> Result<CopilotTokenInfo> {
225 let config = Config::global();
226 let token = match config.get_secret::<String>("GITHUB_COPILOT_TOKEN") {
227 Ok(token) => token,
228 Err(err) => match err {
229 ConfigError::NotFound(_) => {
230 let token = self
231 .get_access_token()
232 .await
233 .context("unable to login into github")?;
234 config.set_secret("GITHUB_COPILOT_TOKEN", &token)?;
235 token
236 }
237 _ => return Err(err.into()),
238 },
239 };
240 let resp = self
241 .client
242 .get(GITHUB_COPILOT_API_KEY_URL)
243 .headers(self.get_github_headers())
244 .header(http::header::AUTHORIZATION, format!("bearer {}", &token))
245 .send()
246 .await?
247 .error_for_status()?
248 .text()
249 .await?;
250 tracing::trace!("copilot token response: {}", resp);
251 let info: CopilotTokenInfo = serde_json::from_str(&resp)?;
252 Ok(info)
253 }
254
255 async fn get_access_token(&self) -> Result<String> {
256 for attempt in 0..3 {
257 tracing::trace!("attempt {} to get access token", attempt + 1);
258 match self.login().await {
259 Ok(token) => return Ok(token),
260 Err(err) => tracing::warn!("failed to get access token: {}", err),
261 }
262 }
263 Err(anyhow!("failed to get access token after 3 attempts"))
264 }
265
266 async fn login(&self) -> Result<String> {
267 let device_code_info = self.get_device_code().await?;
268
269 println!(
270 "Please visit {} and enter code {}",
271 device_code_info.verification_uri, device_code_info.user_code
272 );
273
274 self.poll_for_access_token(&device_code_info.device_code)
275 .await
276 }
277
278 async fn get_device_code(&self) -> Result<DeviceCodeInfo> {
279 #[derive(Serialize)]
280 struct DeviceCodeRequest {
281 client_id: String,
282 scope: String,
283 }
284 self.client
285 .post(GITHUB_COPILOT_DEVICE_CODE_URL)
286 .headers(self.get_github_headers())
287 .json(&DeviceCodeRequest {
288 client_id: GITHUB_COPILOT_CLIENT_ID.to_string(),
289 scope: "read:user".to_string(),
290 })
291 .send()
292 .await
293 .context("failed to send request to get device code")?
294 .error_for_status()
295 .context("failed to get device code")?
296 .json::<DeviceCodeInfo>()
297 .await
298 .context("failed to parse device code response")
299 }
300
301 async fn poll_for_access_token(&self, device_code: &str) -> Result<String> {
302 #[derive(Serialize)]
303 struct AccessTokenRequest {
304 client_id: String,
305 device_code: String,
306 grant_type: String,
307 }
308 #[derive(Debug, Deserialize)]
309 struct AccessTokenResponse {
310 access_token: Option<String>,
311 error: Option<String>,
312 #[serde(flatten)]
313 _extra: HashMap<String, Value>,
314 }
315
316 const MAX_ATTEMPTS: i32 = 36;
317 for attempt in 0..MAX_ATTEMPTS {
318 let resp = self
319 .client
320 .post(GITHUB_COPILOT_ACCESS_TOKEN_URL)
321 .headers(self.get_github_headers())
322 .json(&AccessTokenRequest {
323 client_id: GITHUB_COPILOT_CLIENT_ID.to_string(),
324 device_code: device_code.to_string(),
325 grant_type: "urn:ietf:params:oauth:grant-type:device_code".to_string(),
326 })
327 .send()
328 .await
329 .context("failed to make request while polling for access token")?
330 .error_for_status()
331 .context("error polling for access token")?
332 .json::<AccessTokenResponse>()
333 .await
334 .context("failed to parse response while polling for access token")?;
335 if resp.access_token.is_some() {
336 tracing::trace!("successful authorization: {:#?}", resp,);
337 }
338 if let Some(access_token) = resp.access_token {
339 return Ok(access_token);
340 } else if resp
341 .error
342 .as_ref()
343 .is_some_and(|err| err == "authorization_pending")
344 {
345 tracing::debug!(
346 "authorization pending (attempt {}/{})",
347 attempt + 1,
348 MAX_ATTEMPTS
349 );
350 } else {
351 tracing::debug!("unexpected response: {:#?}", resp);
352 }
353 tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
354 }
355 Err(anyhow!("failed to get access token"))
356 }
357
358 fn get_github_headers(&self) -> http::HeaderMap {
359 let mut headers = http::HeaderMap::new();
360 headers.insert(http::header::ACCEPT, "application/json".parse().unwrap());
361 headers.insert(
362 http::header::CONTENT_TYPE,
363 "application/json".parse().unwrap(),
364 );
365 headers.insert(
366 http::header::USER_AGENT,
367 "GithubCopilot/1.155.0".parse().unwrap(),
368 );
369 headers.insert("editor-version", "vscode/1.85.1".parse().unwrap());
370 headers.insert("editor-plugin-version", "copilot/1.155.0".parse().unwrap());
371 headers
372 }
373}
374
375#[async_trait]
376impl Provider for GithubCopilotProvider {
377 fn metadata() -> ProviderMetadata {
378 ProviderMetadata::new(
379 "github_copilot",
380 "GitHub Copilot",
381 "GitHub Copilot. Run `aster configure` and select copilot to set up.",
382 GITHUB_COPILOT_DEFAULT_MODEL,
383 GITHUB_COPILOT_KNOWN_MODELS.to_vec(),
384 GITHUB_COPILOT_DOC_URL,
385 vec![ConfigKey::new_oauth(
386 "GITHUB_COPILOT_TOKEN",
387 true,
388 true,
389 None,
390 )],
391 )
392 }
393
394 fn get_name(&self) -> &str {
395 &self.name
396 }
397
398 fn get_model_config(&self) -> ModelConfig {
399 self.model.clone()
400 }
401
402 fn supports_streaming(&self) -> bool {
403 GITHUB_COPILOT_STREAM_MODELS
404 .iter()
405 .any(|prefix| self.model.model_name.starts_with(prefix))
406 }
407
408 #[tracing::instrument(
409 skip(self, model_config, system, messages, tools),
410 fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
411 )]
412 async fn complete_with_model(
413 &self,
414 model_config: &ModelConfig,
415 system: &str,
416 messages: &[Message],
417 tools: &[Tool],
418 ) -> Result<(Message, ProviderUsage), ProviderError> {
419 let payload = create_request(
420 model_config,
421 system,
422 messages,
423 tools,
424 &ImageFormat::OpenAi,
425 false,
426 )?;
427 let mut log = RequestLog::start(model_config, &payload)?;
428
429 let response = self
431 .with_retry(|| async {
432 let mut payload_clone = payload.clone();
433 self.post(&mut payload_clone).await
434 })
435 .await?;
436 let response = handle_response_openai_compat(response).await?;
437
438 let response = promote_tool_choice(response);
439
440 let message = response_to_message(&response)?;
442 let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
443 tracing::debug!("Failed to get usage data");
444 Usage::default()
445 });
446 let response_model = get_model(&response);
447 log.write(&response, Some(&usage))?;
448 Ok((message, ProviderUsage::new(response_model, usage)))
449 }
450
451 async fn stream(
452 &self,
453 system: &str,
454 messages: &[Message],
455 tools: &[Tool],
456 ) -> Result<MessageStream, ProviderError> {
457 let payload = create_request(
458 &self.model,
459 system,
460 messages,
461 tools,
462 &ImageFormat::OpenAi,
463 true,
464 )?;
465 let mut log = RequestLog::start(&self.model, &payload)?;
466
467 let response = self
468 .with_retry(|| async {
469 let mut payload_clone = payload.clone();
470 let resp = self.post(&mut payload_clone).await?;
471 handle_status_openai_compat(resp).await
472 })
473 .await
474 .inspect_err(|e| {
475 let _ = log.error(e);
476 })?;
477
478 stream_openai_compat(response, log)
479 }
480
481 async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
482 let (endpoint, token) = self.get_api_info().await?;
483 let url = format!("{}/models", endpoint);
484
485 let mut headers = http::HeaderMap::new();
486 headers.insert(http::header::ACCEPT, "application/json".parse().unwrap());
487 headers.insert(
488 http::header::CONTENT_TYPE,
489 "application/json".parse().unwrap(),
490 );
491 headers.insert("Copilot-Integration-Id", "vscode-chat".parse().unwrap());
492 headers.insert(
493 http::header::AUTHORIZATION,
494 format!("Bearer {}", token).parse().unwrap(),
495 );
496
497 let response = self.client.get(url).headers(headers).send().await?;
498
499 let json: serde_json::Value = response.json().await?;
500
501 let arr = match json.get("data").and_then(|v| v.as_array()) {
502 Some(arr) => arr,
503 None => return Ok(None),
504 };
505 let mut models: Vec<String> = arr
506 .iter()
507 .filter_map(|m| {
508 if let Some(s) = m.as_str() {
509 Some(s.to_string())
510 } else if let Some(obj) = m.as_object() {
511 obj.get("id").and_then(|v| v.as_str()).map(str::to_string)
512 } else {
513 None
514 }
515 })
516 .collect();
517 models.sort();
518 Ok(Some(models))
519 }
520
521 async fn configure_oauth(&self) -> Result<(), ProviderError> {
522 let config = Config::global();
523
524 if config.get_secret::<String>("GITHUB_COPILOT_TOKEN").is_ok() {
526 match self.refresh_api_info().await {
528 Ok(_) => return Ok(()), Err(_) => {
530 tracing::debug!("Existing token is invalid, starting OAuth flow");
532 }
533 }
534 }
535
536 let token = self
538 .get_access_token()
539 .await
540 .map_err(|e| ProviderError::Authentication(format!("OAuth flow failed: {}", e)))?;
541
542 config
544 .set_secret("GITHUB_COPILOT_TOKEN", &token)
545 .map_err(|e| ProviderError::ExecutionError(format!("Failed to save token: {}", e)))?;
546
547 Ok(())
548 }
549}
550
551fn promote_tool_choice(response: Value) -> Value {
560 let Some(choices) = response.get("choices").and_then(|c| c.as_array()) else {
561 return response;
562 };
563
564 let tool_choice_idx = choices.iter().position(|choice| {
565 choice
566 .get("message")
567 .and_then(|m| m.get("tool_calls"))
568 .and_then(|tc| tc.as_array())
569 .map(|arr| !arr.is_empty())
570 .unwrap_or(false)
571 });
572
573 if let Some(idx) = tool_choice_idx {
574 if idx != 0 {
575 let mut new_response = response;
576 if let Some(new_choices) = new_response
577 .get_mut("choices")
578 .and_then(|c| c.as_array_mut())
579 {
580 let choice = new_choices.remove(idx);
581 new_choices.insert(0, choice);
582 }
583 return new_response;
584 }
585 }
586
587 response
588}
589
#[cfg(test)]
mod tests {
    use super::promote_tool_choice;
    use serde_json::{json, Value};

    /// Returns the first entry of `response["choices"]`; panics if absent (test helper).
    fn first_choice(response: &Value) -> &Value {
        response
            .get("choices")
            .and_then(|c| c.as_array())
            .and_then(|c| c.first())
            .expect("response should have at least one choice")
    }

    #[test]
    fn promotes_choice_with_tool_call() {
        let response = json!({
            "choices": [
                {"message": {"content": "plain text"}},
                {"message": {"tool_calls": [{"function": {"name": "foo", "arguments": "{}"}}]}}
            ]
        });

        let promoted = promote_tool_choice(response);
        assert_eq!(
            promoted
                .get("choices")
                .and_then(|c| c.as_array())
                .map(|c| c.len()),
            Some(2)
        );
        assert!(first_choice(&promoted)
            .get("message")
            .and_then(|m| m.get("tool_calls"))
            .is_some());
    }

    #[test]
    fn leaves_response_when_tool_choice_first() {
        let response = json!({
            "choices": [
                {"message": {"tool_calls": [{"function": {"name": "foo", "arguments": "{}"}}]}},
                {"message": {"content": "plain text"}}
            ]
        });

        let promoted = promote_tool_choice(response.clone());
        assert_eq!(promoted, response);
    }

    #[test]
    fn leaves_response_without_tool_calls() {
        // An empty `tool_calls` array does not count as a tool call.
        let response = json!({
            "choices": [
                {"message": {"content": "a"}},
                {"message": {"content": "b", "tool_calls": []}}
            ]
        });

        assert_eq!(promote_tool_choice(response.clone()), response);
    }

    #[test]
    fn leaves_response_without_choice_array() {
        let missing = json!({"error": "boom"});
        assert_eq!(promote_tool_choice(missing.clone()), missing);

        let not_an_array = json!({"choices": "oops"});
        assert_eq!(promote_tool_choice(not_an_array.clone()), not_an_array);
    }

    #[test]
    fn keeps_relative_order_of_other_choices() {
        let response = json!({
            "choices": [
                {"message": {"content": "a"}},
                {"message": {"content": "b"}},
                {"message": {"tool_calls": [{"function": {"name": "foo", "arguments": "{}"}}]}},
                {"message": {"content": "c"}}
            ]
        });
        let expected = json!({
            "choices": [
                {"message": {"tool_calls": [{"function": {"name": "foo", "arguments": "{}"}}]}},
                {"message": {"content": "a"}},
                {"message": {"content": "b"}},
                {"message": {"content": "c"}}
            ]
        });

        assert_eq!(promote_tool_choice(response), expected);
    }
}