// aster/providers/githubcopilot.rs

1use crate::config::paths::Paths;
2use crate::providers::api_client::{ApiClient, AuthMethod};
3use crate::providers::utils::{handle_status_openai_compat, stream_openai_compat};
4use anyhow::{anyhow, Context, Result};
5use async_trait::async_trait;
6use axum::http;
7use chrono::{DateTime, Utc};
8use reqwest::{Client, Response};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::cell::RefCell;
12use std::collections::HashMap;
13use std::path::PathBuf;
14use std::time::Duration;
15
16use super::base::{Provider, ProviderMetadata, ProviderUsage, Usage};
17use super::errors::ProviderError;
18use super::formats::openai::{create_request, get_usage, response_to_message};
19use super::retry::ProviderRetry;
20use super::utils::{get_model, handle_response_openai_compat, ImageFormat, RequestLog};
21
22use crate::config::{Config, ConfigError};
23use crate::conversation::message::Message;
24
25use crate::model::ModelConfig;
26use crate::providers::base::{ConfigKey, MessageStream};
27use rmcp::model::Tool;
28
/// Default model advertised in the provider metadata.
pub const GITHUB_COPILOT_DEFAULT_MODEL: &str = "gpt-4.1";

/// Models advertised as known in the provider metadata.
pub const GITHUB_COPILOT_KNOWN_MODELS: &[&str] = &[
    "gpt-4.1",
    "gpt-5-mini",
    "gpt-5",
    "gpt-4o",
    "grok-code-fast-1",
    "gpt-5-codex",
    "claude-sonnet-4",
    "claude-sonnet-4.5",
    "claude-haiku-4.5",
    "gemini-2.5-pro",
];

/// Model-name prefixes for which the provider reports streaming support.
pub const GITHUB_COPILOT_STREAM_MODELS: &[&str] = &[
    "gpt-4.1",
    "gpt-5",
    "gpt-5-mini",
    "gpt-5-codex",
    "gemini-2.5-pro",
    "grok-code-fast-1",
];

/// Documentation link passed to the provider metadata.
const GITHUB_COPILOT_DOC_URL: &str =
    "https://docs.github.com/en/copilot/using-github-copilot/ai-models";
/// OAuth client id sent with the device-code and access-token requests.
const GITHUB_COPILOT_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
/// Endpoint that issues a device code / user code pair.
const GITHUB_COPILOT_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
/// Endpoint polled to exchange an approved device code for an access token.
const GITHUB_COPILOT_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
/// Endpoint queried with the GitHub access token to obtain Copilot token info.
const GITHUB_COPILOT_API_KEY_URL: &str = "https://api.github.com/copilot_internal/v2/token";
58
59#[derive(Debug, Deserialize)]
60struct DeviceCodeInfo {
61    device_code: String,
62    user_code: String,
63    verification_uri: String,
64}
65
66#[derive(Debug, Serialize, Deserialize, Clone)]
67struct CopilotTokenEndpoints {
68    api: String,
69    #[serde(flatten)]
70    _extra: HashMap<String, Value>,
71}
72
73#[derive(Debug, Serialize, Deserialize, Clone)]
74#[allow(dead_code)] // useful for debugging
75struct CopilotTokenInfo {
76    token: String,
77    expires_at: i64,
78    refresh_in: i64,
79    endpoints: CopilotTokenEndpoints,
80    #[serde(flatten)]
81    _extra: HashMap<String, Value>,
82}
83
84#[derive(Debug, Serialize, Deserialize, Clone)]
85struct CopilotState {
86    expires_at: DateTime<Utc>,
87    info: CopilotTokenInfo,
88}
89
/// File-backed store for the most recent [`CopilotState`].
#[derive(Debug)]
struct DiskCache {
    /// Location of the JSON file holding the serialized state.
    cache_path: PathBuf,
}
94
95impl DiskCache {
96    fn new() -> Self {
97        let cache_path = Paths::in_config_dir("githubcopilot/info.json");
98        Self { cache_path }
99    }
100
101    async fn load(&self) -> Option<CopilotState> {
102        if let Ok(contents) = tokio::fs::read_to_string(&self.cache_path).await {
103            if let Ok(info) = serde_json::from_str::<CopilotState>(&contents) {
104                return Some(info);
105            }
106        }
107        None
108    }
109
110    async fn save(&self, info: &CopilotState) -> Result<()> {
111        if let Some(parent) = self.cache_path.parent() {
112            tokio::fs::create_dir_all(parent).await?;
113        }
114        let contents = serde_json::to_string(info)?;
115        tokio::fs::write(&self.cache_path, contents).await?;
116        Ok(())
117    }
118}
119
120#[derive(Debug, serde::Serialize)]
121pub struct GithubCopilotProvider {
122    #[serde(skip)]
123    client: Client,
124    #[serde(skip)]
125    cache: DiskCache,
126    #[serde(skip)]
127    mu: tokio::sync::Mutex<RefCell<Option<CopilotState>>>,
128    model: ModelConfig,
129    #[serde(skip)]
130    name: String,
131}
132
133impl GithubCopilotProvider {
134    fn payload_contains_image(payload: &Value) -> bool {
135        payload
136            .get("messages")
137            .and_then(|m| m.as_array())
138            .is_some_and(|messages| {
139                messages.iter().any(|msg| {
140                    msg.get("content").is_some_and(|content| {
141                        content
142                            .as_array()
143                            .map(|arr| arr.iter().collect::<Vec<_>>())
144                            .unwrap_or_else(|| vec![content])
145                            .iter()
146                            .any(|item| {
147                                matches!(
148                                    item.get("type").and_then(|v| v.as_str()),
149                                    Some("image_url") | Some("image")
150                                )
151                            })
152                    })
153                })
154            })
155    }
156
157    pub async fn from_env(model: ModelConfig) -> Result<Self> {
158        let client = Client::builder()
159            .timeout(Duration::from_secs(600))
160            .build()?;
161        let cache = DiskCache::new();
162        let mu = tokio::sync::Mutex::new(RefCell::new(None));
163        Ok(Self {
164            client,
165            cache,
166            mu,
167            model,
168            name: Self::metadata().name,
169        })
170    }
171
172    async fn post(&self, payload: &mut Value) -> Result<Response, ProviderError> {
173        let (endpoint, token) = self.get_api_info().await?;
174        let auth = AuthMethod::BearerToken(token);
175        let mut headers = self.get_github_headers();
176        if Self::payload_contains_image(payload) {
177            headers.insert("Copilot-Vision-Request", "true".parse().unwrap());
178        }
179        let api_client = ApiClient::new(endpoint.clone(), auth)?.with_headers(headers)?;
180
181        api_client
182            .response_post("chat/completions", payload)
183            .await
184            .map_err(|e| e.into())
185    }
186
187    async fn get_api_info(&self) -> Result<(String, String)> {
188        let guard = self.mu.lock().await;
189
190        if let Some(state) = guard.borrow().as_ref() {
191            if state.expires_at > Utc::now() {
192                return Ok((state.info.endpoints.api.clone(), state.info.token.clone()));
193            }
194        }
195
196        if let Some(state) = self.cache.load().await {
197            if guard.borrow().is_none() {
198                guard.replace(Some(state.clone()));
199            }
200            if state.expires_at > Utc::now() {
201                return Ok((state.info.endpoints.api, state.info.token));
202            }
203        }
204
205        const MAX_ATTEMPTS: i32 = 3;
206        for attempt in 0..MAX_ATTEMPTS {
207            tracing::trace!("attempt {} to refresh api info", attempt + 1);
208            let info = match self.refresh_api_info().await {
209                Ok(data) => data,
210                Err(err) => {
211                    tracing::warn!("failed to refresh api info: {}", err);
212                    continue;
213                }
214            };
215            let expires_at = Utc::now() + chrono::Duration::seconds(info.refresh_in);
216            let new_state = CopilotState { info, expires_at };
217            self.cache.save(&new_state).await?;
218            guard.replace(Some(new_state.clone()));
219            return Ok((new_state.info.endpoints.api, new_state.info.token));
220        }
221        Err(anyhow!("failed to get api info after 3 attempts"))
222    }
223
224    async fn refresh_api_info(&self) -> Result<CopilotTokenInfo> {
225        let config = Config::global();
226        let token = match config.get_secret::<String>("GITHUB_COPILOT_TOKEN") {
227            Ok(token) => token,
228            Err(err) => match err {
229                ConfigError::NotFound(_) => {
230                    let token = self
231                        .get_access_token()
232                        .await
233                        .context("unable to login into github")?;
234                    config.set_secret("GITHUB_COPILOT_TOKEN", &token)?;
235                    token
236                }
237                _ => return Err(err.into()),
238            },
239        };
240        let resp = self
241            .client
242            .get(GITHUB_COPILOT_API_KEY_URL)
243            .headers(self.get_github_headers())
244            .header(http::header::AUTHORIZATION, format!("bearer {}", &token))
245            .send()
246            .await?
247            .error_for_status()?
248            .text()
249            .await?;
250        tracing::trace!("copilot token response: {}", resp);
251        let info: CopilotTokenInfo = serde_json::from_str(&resp)?;
252        Ok(info)
253    }
254
255    async fn get_access_token(&self) -> Result<String> {
256        for attempt in 0..3 {
257            tracing::trace!("attempt {} to get access token", attempt + 1);
258            match self.login().await {
259                Ok(token) => return Ok(token),
260                Err(err) => tracing::warn!("failed to get access token: {}", err),
261            }
262        }
263        Err(anyhow!("failed to get access token after 3 attempts"))
264    }
265
266    async fn login(&self) -> Result<String> {
267        let device_code_info = self.get_device_code().await?;
268
269        println!(
270            "Please visit {} and enter code {}",
271            device_code_info.verification_uri, device_code_info.user_code
272        );
273
274        self.poll_for_access_token(&device_code_info.device_code)
275            .await
276    }
277
278    async fn get_device_code(&self) -> Result<DeviceCodeInfo> {
279        #[derive(Serialize)]
280        struct DeviceCodeRequest {
281            client_id: String,
282            scope: String,
283        }
284        self.client
285            .post(GITHUB_COPILOT_DEVICE_CODE_URL)
286            .headers(self.get_github_headers())
287            .json(&DeviceCodeRequest {
288                client_id: GITHUB_COPILOT_CLIENT_ID.to_string(),
289                scope: "read:user".to_string(),
290            })
291            .send()
292            .await
293            .context("failed to send request to get device code")?
294            .error_for_status()
295            .context("failed to get device code")?
296            .json::<DeviceCodeInfo>()
297            .await
298            .context("failed to parse device code response")
299    }
300
301    async fn poll_for_access_token(&self, device_code: &str) -> Result<String> {
302        #[derive(Serialize)]
303        struct AccessTokenRequest {
304            client_id: String,
305            device_code: String,
306            grant_type: String,
307        }
308        #[derive(Debug, Deserialize)]
309        struct AccessTokenResponse {
310            access_token: Option<String>,
311            error: Option<String>,
312            #[serde(flatten)]
313            _extra: HashMap<String, Value>,
314        }
315
316        const MAX_ATTEMPTS: i32 = 36;
317        for attempt in 0..MAX_ATTEMPTS {
318            let resp = self
319                .client
320                .post(GITHUB_COPILOT_ACCESS_TOKEN_URL)
321                .headers(self.get_github_headers())
322                .json(&AccessTokenRequest {
323                    client_id: GITHUB_COPILOT_CLIENT_ID.to_string(),
324                    device_code: device_code.to_string(),
325                    grant_type: "urn:ietf:params:oauth:grant-type:device_code".to_string(),
326                })
327                .send()
328                .await
329                .context("failed to make request while polling for access token")?
330                .error_for_status()
331                .context("error polling for access token")?
332                .json::<AccessTokenResponse>()
333                .await
334                .context("failed to parse response while polling for access token")?;
335            if resp.access_token.is_some() {
336                tracing::trace!("successful authorization: {:#?}", resp,);
337            }
338            if let Some(access_token) = resp.access_token {
339                return Ok(access_token);
340            } else if resp
341                .error
342                .as_ref()
343                .is_some_and(|err| err == "authorization_pending")
344            {
345                tracing::debug!(
346                    "authorization pending (attempt {}/{})",
347                    attempt + 1,
348                    MAX_ATTEMPTS
349                );
350            } else {
351                tracing::debug!("unexpected response: {:#?}", resp);
352            }
353            tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
354        }
355        Err(anyhow!("failed to get access token"))
356    }
357
358    fn get_github_headers(&self) -> http::HeaderMap {
359        let mut headers = http::HeaderMap::new();
360        headers.insert(http::header::ACCEPT, "application/json".parse().unwrap());
361        headers.insert(
362            http::header::CONTENT_TYPE,
363            "application/json".parse().unwrap(),
364        );
365        headers.insert(
366            http::header::USER_AGENT,
367            "GithubCopilot/1.155.0".parse().unwrap(),
368        );
369        headers.insert("editor-version", "vscode/1.85.1".parse().unwrap());
370        headers.insert("editor-plugin-version", "copilot/1.155.0".parse().unwrap());
371        headers
372    }
373}
374
375#[async_trait]
376impl Provider for GithubCopilotProvider {
377    fn metadata() -> ProviderMetadata {
378        ProviderMetadata::new(
379            "github_copilot",
380            "GitHub Copilot",
381            "GitHub Copilot. Run `aster configure` and select copilot to set up.",
382            GITHUB_COPILOT_DEFAULT_MODEL,
383            GITHUB_COPILOT_KNOWN_MODELS.to_vec(),
384            GITHUB_COPILOT_DOC_URL,
385            vec![ConfigKey::new_oauth(
386                "GITHUB_COPILOT_TOKEN",
387                true,
388                true,
389                None,
390            )],
391        )
392    }
393
394    fn get_name(&self) -> &str {
395        &self.name
396    }
397
398    fn get_model_config(&self) -> ModelConfig {
399        self.model.clone()
400    }
401
402    fn supports_streaming(&self) -> bool {
403        GITHUB_COPILOT_STREAM_MODELS
404            .iter()
405            .any(|prefix| self.model.model_name.starts_with(prefix))
406    }
407
408    #[tracing::instrument(
409        skip(self, model_config, system, messages, tools),
410        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
411    )]
412    async fn complete_with_model(
413        &self,
414        model_config: &ModelConfig,
415        system: &str,
416        messages: &[Message],
417        tools: &[Tool],
418    ) -> Result<(Message, ProviderUsage), ProviderError> {
419        let payload = create_request(
420            model_config,
421            system,
422            messages,
423            tools,
424            &ImageFormat::OpenAi,
425            false,
426        )?;
427        let mut log = RequestLog::start(model_config, &payload)?;
428
429        // Make request with retry
430        let response = self
431            .with_retry(|| async {
432                let mut payload_clone = payload.clone();
433                self.post(&mut payload_clone).await
434            })
435            .await?;
436        let response = handle_response_openai_compat(response).await?;
437
438        let response = promote_tool_choice(response);
439
440        // Parse response
441        let message = response_to_message(&response)?;
442        let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
443            tracing::debug!("Failed to get usage data");
444            Usage::default()
445        });
446        let response_model = get_model(&response);
447        log.write(&response, Some(&usage))?;
448        Ok((message, ProviderUsage::new(response_model, usage)))
449    }
450
451    async fn stream(
452        &self,
453        system: &str,
454        messages: &[Message],
455        tools: &[Tool],
456    ) -> Result<MessageStream, ProviderError> {
457        let payload = create_request(
458            &self.model,
459            system,
460            messages,
461            tools,
462            &ImageFormat::OpenAi,
463            true,
464        )?;
465        let mut log = RequestLog::start(&self.model, &payload)?;
466
467        let response = self
468            .with_retry(|| async {
469                let mut payload_clone = payload.clone();
470                let resp = self.post(&mut payload_clone).await?;
471                handle_status_openai_compat(resp).await
472            })
473            .await
474            .inspect_err(|e| {
475                let _ = log.error(e);
476            })?;
477
478        stream_openai_compat(response, log)
479    }
480
481    async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
482        let (endpoint, token) = self.get_api_info().await?;
483        let url = format!("{}/models", endpoint);
484
485        let mut headers = http::HeaderMap::new();
486        headers.insert(http::header::ACCEPT, "application/json".parse().unwrap());
487        headers.insert(
488            http::header::CONTENT_TYPE,
489            "application/json".parse().unwrap(),
490        );
491        headers.insert("Copilot-Integration-Id", "vscode-chat".parse().unwrap());
492        headers.insert(
493            http::header::AUTHORIZATION,
494            format!("Bearer {}", token).parse().unwrap(),
495        );
496
497        let response = self.client.get(url).headers(headers).send().await?;
498
499        let json: serde_json::Value = response.json().await?;
500
501        let arr = match json.get("data").and_then(|v| v.as_array()) {
502            Some(arr) => arr,
503            None => return Ok(None),
504        };
505        let mut models: Vec<String> = arr
506            .iter()
507            .filter_map(|m| {
508                if let Some(s) = m.as_str() {
509                    Some(s.to_string())
510                } else if let Some(obj) = m.as_object() {
511                    obj.get("id").and_then(|v| v.as_str()).map(str::to_string)
512                } else {
513                    None
514                }
515            })
516            .collect();
517        models.sort();
518        Ok(Some(models))
519    }
520
521    async fn configure_oauth(&self) -> Result<(), ProviderError> {
522        let config = Config::global();
523
524        // Check if token already exists and is valid
525        if config.get_secret::<String>("GITHUB_COPILOT_TOKEN").is_ok() {
526            // Try to refresh API info to validate the token
527            match self.refresh_api_info().await {
528                Ok(_) => return Ok(()), // Token is valid
529                Err(_) => {
530                    // Token is invalid, continue with OAuth flow
531                    tracing::debug!("Existing token is invalid, starting OAuth flow");
532                }
533            }
534        }
535
536        // Start OAuth device code flow
537        let token = self
538            .get_access_token()
539            .await
540            .map_err(|e| ProviderError::Authentication(format!("OAuth flow failed: {}", e)))?;
541
542        // Save the token
543        config
544            .set_secret("GITHUB_COPILOT_TOKEN", &token)
545            .map_err(|e| ProviderError::ExecutionError(format!("Failed to save token: {}", e)))?;
546
547        Ok(())
548    }
549}
550
551// Copilot sometimes returns multiple choices in a completion response for
552// Claude models and places the `tool_calls` payload in a non-zero index choice.
553// Example:
554// - Choice 0: {"finish_reason":"stop","message":{"content":"I'll check the Desktop directory…"}}
555// - Choice 1: {"finish_reason":"tool_calls","message":{"tool_calls":[{"function":{"arguments":"{\"command\":
556//   \"ls -1 ~/Desktop | wc -l\"}","name":"developer__shell"},…}]}}
557// This function ensures the first choice contains tool metadata so the shared formatter emits a
558// `ToolRequest` instead of returning only the plain-text choice.
559fn promote_tool_choice(response: Value) -> Value {
560    let Some(choices) = response.get("choices").and_then(|c| c.as_array()) else {
561        return response;
562    };
563
564    let tool_choice_idx = choices.iter().position(|choice| {
565        choice
566            .get("message")
567            .and_then(|m| m.get("tool_calls"))
568            .and_then(|tc| tc.as_array())
569            .map(|arr| !arr.is_empty())
570            .unwrap_or(false)
571    });
572
573    if let Some(idx) = tool_choice_idx {
574        if idx != 0 {
575            let mut new_response = response;
576            if let Some(new_choices) = new_response
577                .get_mut("choices")
578                .and_then(|c| c.as_array_mut())
579            {
580                let choice = new_choices.remove(idx);
581                new_choices.insert(0, choice);
582            }
583            return new_response;
584        }
585    }
586
587    response
588}
589
#[cfg(test)]
mod tests {
    use super::promote_tool_choice;
    use serde_json::json;

    /// A tool-call choice at index 1 must end up at index 0 without any
    /// choice being dropped.
    #[test]
    fn promotes_choice_with_tool_call() {
        let promoted = promote_tool_choice(json!({
            "choices": [
                {"message": {"content": "plain text"}},
                {"message": {"tool_calls": [{"function": {"name": "foo", "arguments": "{}"}}]}}
            ]
        }));

        let choices = promoted.get("choices").and_then(|c| c.as_array());
        assert_eq!(choices.map(|c| c.len()), Some(2));

        let first_choice = choices.and_then(|c| c.first()).unwrap();
        let tool_calls = first_choice
            .get("message")
            .and_then(|m| m.get("tool_calls"));
        assert!(tool_calls.is_some());
    }

    /// A response whose first choice already carries the tool calls must come
    /// back unchanged.
    #[test]
    fn leaves_response_when_tool_choice_first() {
        let original = json!({
            "choices": [
                {"message": {"tool_calls": [{"function": {"name": "foo", "arguments": "{}"}}]}},
                {"message": {"content": "plain text"}}
            ]
        });

        let promoted = promote_tool_choice(original.clone());
        assert_eq!(promoted, original);
    }
}