Skip to main content

rab/provider/oauth/
github_copilot.rs

1//! GitHub Copilot OAuth provider — matching pi's github-copilot.ts exactly.
2//!
3//! Uses the device code flow (RFC 8628) to authenticate with GitHub.
//! After login, enables a fixed list of known models and then fetches the IDs of the models
//! available to the account.
5
6use std::collections::HashMap;
7
8use async_trait::async_trait;
9use base64::Engine;
10
11use super::device_code::{PollOptions, PollStatus, poll_device_code_flow};
12use super::{DeviceCodeInfo, OAuthCredentials, OAuthLoginCallbacks, OAuthPrompt, OAuthProvider};
13
/// Base64-encoded OAuth client ID of the Copilot app; decoded at runtime by `client_id()`.
const CLIENT_ID_ENCODED: &str = "SXYxLmI1MDdhMDhjODdlY2ZlOTg=";

/// Value sent as `X-GitHub-Api-Version` when listing Copilot models.
const COPILOT_API_VERSION: &str = "2026-06-01";

/// Editor-identification headers attached to Copilot API requests.
const COPILOT_HEADERS: &[(&str, &str)] = &[
    ("User-Agent", "GitHubCopilotChat/0.35.0"),
    ("Editor-Version", "vscode/1.107.0"),
    ("Editor-Plugin-Version", "copilot-chat/0.35.0"),
    ("Copilot-Integration-Id", "vscode-chat"),
];
23
24fn client_id() -> String {
25    String::from_utf8(
26        base64::engine::general_purpose::STANDARD
27            .decode(CLIENT_ID_ENCODED)
28            .expect("valid base64"),
29    )
30    .expect("valid utf8")
31}
32
33#[allow(dead_code)]
34fn decode(s: &str) -> String {
35    String::from_utf8(
36        base64::engine::general_purpose::STANDARD
37            .decode(s)
38            .unwrap_or_default(),
39    )
40    .unwrap_or_default()
41}
42
43pub fn normalize_domain(input: &str) -> Option<String> {
44    let trimmed = input.trim();
45    if trimmed.is_empty() {
46        return None;
47    }
48    let url_str = if trimmed.contains("://") {
49        trimmed.to_string()
50    } else {
51        format!("https://{}", trimmed)
52    };
53    url::Url::parse(&url_str)
54        .ok()
55        .map(|u| u.host_str().unwrap_or("").to_string())
56}
57
/// Build the GitHub endpoints used by the login flow for `domain`.
///
/// Returns `(device_code_url, access_token_url, copilot_token_url)`. The first two live on
/// `domain` itself; the Copilot token endpoint lives on its `api.` subdomain.
fn get_urls(domain: &str) -> (String, String, String) {
    let web_root = format!("https://{domain}");
    let device_code_url = format!("{web_root}/login/device/code");
    let access_token_url = format!("{web_root}/login/oauth/access_token");
    let copilot_token_url = format!("https://api.{domain}/copilot_internal/v2/token");
    (device_code_url, access_token_url, copilot_token_url)
}
65
/// Extract the Copilot API base URL from the `proxy-ep` field of a Copilot token.
///
/// The token is scanned as `;`-separated parts for the first non-empty `proxy-ep=<host>`.
/// A *leading* `proxy.` label is rewritten to `api.`, so
/// `proxy.individual.githubcopilot.com` becomes `api.individual.githubcopilot.com`.
/// Hosts that merely contain `proxy.` elsewhere (e.g. `copilot-proxy.example.com`) are
/// kept unchanged; the previous `replacen` rewrote them.
///
/// Returns `None` if no non-empty `proxy-ep` value is present.
fn get_base_url_from_token(token: &str) -> Option<String> {
    let host = token.split(';').find_map(|part| {
        part.strip_prefix("proxy-ep=")
            .filter(|value| !value.is_empty())
    })?;
    let api_host = match host.strip_prefix("proxy.") {
        Some(rest) => format!("api.{rest}"),
        None => host.to_string(),
    };
    Some(format!("https://{api_host}"))
}
76
77/// Get the GitHub Copilot API base URL.
78pub fn get_copilot_base_url(token: Option<&str>, enterprise_domain: Option<&str>) -> String {
79    if let Some(t) = token
80        && let Some(url) = get_base_url_from_token(t)
81    {
82        return url;
83    }
84    if let Some(domain) = enterprise_domain {
85        return format!("https://copilot-api.{}", domain);
86    }
87    "https://api.individual.githubcopilot.com".to_string()
88}
89
90/// Fetch JSON from a URL with headers.
91async fn fetch_json(url: &str, headers: &[(&str, &str)]) -> Result<serde_json::Value, String> {
92    let client = reqwest::Client::new();
93    let mut req = client.get(url);
94    for (k, v) in headers {
95        req = req.header(*k, *v);
96    }
97    let resp = req.send().await.map_err(|e| format!("HTTP error: {}", e))?;
98    let status = resp.status();
99    if !status.is_success() {
100        let text = resp.text().await.unwrap_or_default();
101        return Err(format!("HTTP {}: {}", status, text));
102    }
103    resp.json().await.map_err(|e| format!("JSON error: {}", e))
104}
105
106/// Post JSON-encoded body to a URL.
107#[allow(dead_code)]
108async fn post_json(
109    url: &str,
110    headers: &[(&str, &str)],
111    body: &serde_json::Value,
112) -> Result<serde_json::Value, String> {
113    let client = reqwest::Client::new();
114    let mut req = client.post(url).json(body);
115    for (k, v) in headers {
116        req = req.header(*k, *v);
117    }
118    let resp = req.send().await.map_err(|e| format!("HTTP error: {}", e))?;
119    let status = resp.status();
120    if !status.is_success() {
121        let text = resp.text().await.unwrap_or_default();
122        return Err(format!("HTTP {}: {}", status, text));
123    }
124    resp.json().await.map_err(|e| format!("JSON error: {}", e))
125}
126
127/// Post form-encoded body to a URL.
128async fn post_form(
129    url: &str,
130    headers: &[(&str, &str)],
131    form: &[(&str, &str)],
132) -> Result<serde_json::Value, String> {
133    let client = reqwest::Client::new();
134    let mut req = client.post(url);
135    for (k, v) in headers {
136        req = req.header(*k, *v);
137    }
138    let params: Vec<(&str, &str)> = form.to_vec();
139    let resp = req
140        .form(&params)
141        .send()
142        .await
143        .map_err(|e| format!("HTTP error: {}", e))?;
144    let status = resp.status();
145    if !status.is_success() {
146        let text = resp.text().await.unwrap_or_default();
147        return Err(format!("HTTP {}: {}", status, text));
148    }
149    resp.json().await.map_err(|e| format!("JSON error: {}", e))
150}
151
152/// Start the device code flow.
153async fn start_device_flow(domain: &str) -> Result<serde_json::Value, String> {
154    let (device_code_url, _, _) = get_urls(domain);
155    post_form(
156        &device_code_url,
157        &[
158            ("Accept", "application/json"),
159            ("User-Agent", "GitHubCopilotChat/0.35.0"),
160        ],
161        &[("client_id", &client_id()), ("scope", "read:user")],
162    )
163    .await
164}
165
166/// Poll for the GitHub access token.
167async fn poll_for_github_access_token(
168    domain: &str,
169    device_code: &str,
170    interval: Option<u32>,
171    expires_in: Option<u32>,
172    cancel: Option<tokio_util::sync::CancellationToken>,
173) -> Result<String, String> {
174    let (_, access_token_url, _) = get_urls(domain);
175    let client_id = client_id();
176    let device_code = device_code.to_string();
177
178    poll_device_code_flow(PollOptions {
179        interval_seconds: interval,
180        expires_in_seconds: expires_in,
181        cancel,
182        poll: Box::new(move || {
183            let access_token_url = access_token_url.clone();
184            let client_id = client_id.clone();
185            let device_code = device_code.clone();
186            Box::pin(async move {
187                let raw = post_form(
188                    &access_token_url,
189                    &[
190                        ("Accept", "application/json"),
191                        ("User-Agent", "GitHubCopilotChat/0.35.0"),
192                    ],
193                    &[
194                        ("client_id", &client_id),
195                        ("device_code", &device_code),
196                        ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
197                    ],
198                )
199                .await?;
200
201                if let Some(token) = raw.get("access_token").and_then(|t| t.as_str()) {
202                    return Ok(PollStatus::Complete(token.to_string()));
203                }
204
205                if let Some(error) = raw.get("error").and_then(|e| e.as_str()) {
206                    match error {
207                        "authorization_pending" => return Ok(PollStatus::Pending),
208                        "slow_down" => return Ok(PollStatus::SlowDown),
209                        _ => {
210                            let desc = raw
211                                .get("error_description")
212                                .and_then(|d| d.as_str())
213                                .unwrap_or("");
214                            return Ok(PollStatus::Failed(format!(
215                                "Device flow failed: {}{}",
216                                error,
217                                if desc.is_empty() {
218                                    String::new()
219                                } else {
220                                    format!(": {}", desc)
221                                }
222                            )));
223                        }
224                    }
225                }
226
227                Ok(PollStatus::Failed(
228                    "Invalid device token response".to_string(),
229                ))
230            })
231        }),
232    })
233    .await
234}
235
236/// Exchange GitHub access token for a Copilot token.
237async fn exchange_for_copilot_token(
238    github_token: &str,
239    enterprise_domain: Option<&str>,
240) -> Result<serde_json::Value, String> {
241    let domain = enterprise_domain.unwrap_or("github.com");
242    let (_, _, copilot_token_url) = get_urls(domain);
243
244    let auth_val = format!("Bearer {}", github_token);
245    let mut headers: Vec<(&str, &str)> =
246        vec![("Accept", "application/json"), ("Authorization", &auth_val)];
247    for (k, v) in COPILOT_HEADERS {
248        headers.push((k, v));
249    }
250
251    fetch_json(&copilot_token_url, &headers).await
252}
253
254/// Refresh the Copilot token using the refresh token.
255async fn refresh_copilot_access_token(
256    refresh_token: &str,
257    enterprise_domain: Option<&str>,
258) -> Result<serde_json::Value, String> {
259    let domain = enterprise_domain.unwrap_or("github.com");
260    let (_, _, copilot_token_url) = get_urls(domain);
261
262    let auth_val = format!("Bearer {}", refresh_token);
263    let mut headers: Vec<(&str, &str)> =
264        vec![("Accept", "application/json"), ("Authorization", &auth_val)];
265    for (k, v) in COPILOT_HEADERS {
266        headers.push((k, v));
267    }
268
269    fetch_json(&copilot_token_url, &headers).await
270}
271
272/// Fetch available Copilot model IDs.
273async fn fetch_available_model_ids(
274    copilot_token: &str,
275    enterprise_domain: Option<&str>,
276) -> Result<Vec<String>, String> {
277    let base_url = get_copilot_base_url(Some(copilot_token), enterprise_domain);
278    let url = format!("{}/models", base_url);
279
280    let auth_val = format!("Bearer {}", copilot_token);
281    let mut headers: Vec<(&str, &str)> =
282        vec![("Accept", "application/json"), ("Authorization", &auth_val)];
283    for (k, v) in COPILOT_HEADERS {
284        headers.push((k, v));
285    }
286    headers.push(("X-GitHub-Api-Version", COPILOT_API_VERSION));
287
288    let raw = fetch_json(&url, &headers).await?;
289
290    // Parse model list
291    let data = raw.get("data").and_then(|d| d.as_array());
292    match data {
293        Some(items) => {
294            let ids: Vec<String> = items
295                .iter()
296                .filter(|item| {
297                    let policy = item.get("policy").and_then(|p| p.as_object());
298                    let capabilities = item.get("capabilities").and_then(|c| c.as_object());
299                    let supports = capabilities
300                        .and_then(|c| c.get("supports"))
301                        .and_then(|s| s.as_object());
302                    let model_picker_enabled = item
303                        .get("model_picker_enabled")
304                        .and_then(|v| v.as_bool())
305                        .unwrap_or(false);
306                    let policy_enabled =
307                        policy.and_then(|p| p.get("state")).and_then(|s| s.as_str())
308                            != Some("disabled");
309                    let supports_tool_calls = supports
310                        .and_then(|s| s.get("tool_calls"))
311                        .and_then(|v| v.as_bool())
312                        .unwrap_or(true);
313                    model_picker_enabled && policy_enabled && supports_tool_calls
314                })
315                .filter_map(|item| {
316                    item.get("id")
317                        .and_then(|id| id.as_str())
318                        .map(|s| s.to_string())
319                })
320                .collect();
321            Ok(ids)
322        }
323        None => Err("Invalid Copilot models response: missing data array".to_string()),
324    }
325}
326
327/// Enable a model via the policy endpoint.
328async fn enable_model(
329    copilot_token: &str,
330    model_id: &str,
331    enterprise_domain: Option<&str>,
332) -> Result<bool, String> {
333    let base_url = get_copilot_base_url(Some(copilot_token), enterprise_domain);
334    let url = format!("{}/models/{}/policy", base_url, model_id);
335
336    let client = reqwest::Client::new();
337    let auth_header = format!("Bearer {}", copilot_token);
338    let mut req = client
339        .post(&url)
340        .header("Content-Type", "application/json")
341        .header("Authorization", &auth_header)
342        .header("openai-intent", "chat-policy")
343        .header("x-interaction-type", "chat-policy");
344    for (k, v) in COPILOT_HEADERS {
345        req = req.header(*k, *v);
346    }
347    let body = serde_json::json!({"state": "enabled"});
348    let resp = req.json(&body).send().await;
349    Ok(resp.map(|r| r.status().is_success()).unwrap_or(false))
350}
351
352/// Enable all known GitHub Copilot models after login.
353async fn enable_all_models(
354    copilot_token: &str,
355    enterprise_domain: Option<&str>,
356    on_progress: &mut (dyn FnMut(String, bool) + Send),
357) {
358    // Known Copilot model IDs from GITHUB_COPILOT_MODELS
359    let known_models = [
360        "claude-sonnet-4-20250514",
361        "claude-sonnet-4.5-preview-20250619",
362        "claude-opus-4-20250514",
363        "claude-opus-4.5-preview-20250619",
364        "claude-haiku-4-20250514",
365        "claude-haiku-4.5-preview-20250619",
366        "claude-fable-5",
367        "claude-haiku-4.5",
368        "claude-opus-4.5",
369        "claude-sonnet-4",
370        "gpt-4o",
371        "gpt-4o-mini",
372        "o3",
373        "o4-mini",
374        "gemini-2.5-flash-001",
375        "gemini-2.5-pro-001",
376    ];
377
378    // Pi-compatible parallel enabling via join_all
379    use futures::future::join_all;
380    let tasks: Vec<_> = known_models
381        .iter()
382        .map(|model_id| {
383            let token = copilot_token.to_string();
384            let domain = enterprise_domain.map(|s| s.to_string());
385            let mid = model_id.to_string();
386            async move {
387                let success = enable_model(&token, &mid, domain.as_deref())
388                    .await
389                    .unwrap_or(false);
390                (mid, success)
391            }
392        })
393        .collect();
394
395    let results = join_all(tasks).await;
396    for (model_id, success) in results {
397        on_progress(model_id, success);
398    }
399}
400
401// ── OAuthProvider implementation ───────────────────────────────────
402
403pub struct GitHubCopilotOAuth;
404
405#[async_trait]
406impl OAuthProvider for GitHubCopilotOAuth {
407    fn id(&self) -> &str {
408        "github-copilot"
409    }
410
411    fn name(&self) -> &str {
412        "GitHub Copilot"
413    }
414
415    async fn login(
416        &self,
417        callbacks: &mut OAuthLoginCallbacks<'_>,
418    ) -> Result<OAuthCredentials, String> {
419        // 1. Prompt for enterprise domain
420        let input = (callbacks.on_prompt)(OAuthPrompt::Text {
421            message: "GitHub Enterprise URL/domain (blank for github.com)".to_string(),
422            placeholder: Some("company.ghe.com".to_string()),
423            allow_empty: true,
424        })?;
425
426        if let Some(ref cancel) = callbacks.signal
427            && cancel.is_cancelled()
428        {
429            return Err("Login cancelled".to_string());
430        }
431
432        let trimmed = input.trim().to_string();
433        let enterprise_domain = if trimmed.is_empty() {
434            None
435        } else {
436            normalize_domain(&trimmed)
437        };
438        if !trimmed.is_empty() && enterprise_domain.is_none() {
439            return Err("Invalid GitHub Enterprise URL/domain".to_string());
440        }
441        let domain = enterprise_domain
442            .clone()
443            .unwrap_or_else(|| "github.com".to_string());
444
445        // 2. Start device flow
446        let device_resp = start_device_flow(&domain).await?;
447
448        let device_code = device_resp
449            .get("device_code")
450            .and_then(|v| v.as_str())
451            .ok_or_else(|| "Missing device_code in response".to_string())?
452            .to_string();
453        let user_code = device_resp
454            .get("user_code")
455            .and_then(|v| v.as_str())
456            .ok_or_else(|| "Missing user_code in response".to_string())?
457            .to_string();
458        let verification_uri = device_resp
459            .get("verification_uri")
460            .and_then(|v| v.as_str())
461            .ok_or_else(|| "Missing verification_uri in response".to_string())?
462            .to_string();
463        let interval = device_resp
464            .get("interval")
465            .and_then(|v| v.as_u64())
466            .map(|v| v as u32);
467        let expires_in = device_resp
468            .get("expires_in")
469            .and_then(|v| v.as_u64())
470            .map(|v| v as u32);
471
472        // Validate verification_uri is a trusted URL
473        if let Ok(parsed) = url::Url::parse(&verification_uri) {
474            if parsed.scheme() != "https" && parsed.scheme() != "http" {
475                return Err("Untrusted verification_uri in device code response".to_string());
476            }
477        } else {
478            return Err("Invalid verification_uri in device code response".to_string());
479        }
480
481        // 3. Notify user with device code info
482        (callbacks.on_device_code)(DeviceCodeInfo {
483            user_code: user_code.clone(),
484            verification_uri: verification_uri.clone(),
485            interval_seconds: interval,
486            expires_in_seconds: expires_in,
487        });
488
489        // 4. Poll for GitHub access token
490        let cancel = callbacks.signal.clone();
491        let github_access_token =
492            poll_for_github_access_token(&domain, &device_code, interval, expires_in, cancel)
493                .await?;
494
495        // 5. Exchange for Copilot token
496        let copilot_resp =
497            exchange_for_copilot_token(&github_access_token, enterprise_domain.as_deref()).await?;
498
499        let token = copilot_resp
500            .get("token")
501            .and_then(|v| v.as_str())
502            .ok_or_else(|| "Missing token in Copilot response".to_string())?
503            .to_string();
504        let expires_at = copilot_resp
505            .get("expires_at")
506            .and_then(|v| v.as_f64())
507            .ok_or_else(|| "Missing expires_at in Copilot response".to_string())?
508            as i64;
509
510        // 6. Enable all models
511        (callbacks.on_progress)("Enabling models...".to_string());
512        enable_all_models(
513            &token,
514            enterprise_domain.as_deref(),
515            &mut |model, success| {
516                (callbacks.on_progress)(format!(
517                    "Model {}: {}",
518                    model,
519                    if success { "enabled" } else { "skipped" }
520                ));
521            },
522        )
523        .await;
524
525        // 7. Fetch available model IDs
526        let available_ids = fetch_available_model_ids(&token, enterprise_domain.as_deref())
527            .await
528            .unwrap_or_default();
529
530        let mut extra = HashMap::new();
531        extra.insert("availableModelIds".to_string(), available_ids.join(","));
532        if let Some(ref ed) = enterprise_domain {
533            extra.insert("enterpriseUrl".to_string(), ed.clone());
534        }
535
536        Ok(OAuthCredentials {
537            access: token.clone(),
538            refresh: github_access_token,
539            expires: (expires_at * 1000) - (5 * 60 * 1000), // 5 min buffer
540            enterprise_url: enterprise_domain,
541            extra,
542        })
543    }
544
545    async fn refresh_token(
546        &self,
547        credentials: &OAuthCredentials,
548    ) -> Result<OAuthCredentials, String> {
549        let enterprise_domain = credentials.enterprise_url.as_deref();
550        let raw = refresh_copilot_access_token(&credentials.refresh, enterprise_domain).await?;
551
552        let token = raw
553            .get("token")
554            .and_then(|v| v.as_str())
555            .ok_or_else(|| "Missing token in Copilot refresh response".to_string())?
556            .to_string();
557        let expires_at = raw
558            .get("expires_at")
559            .and_then(|v| v.as_f64())
560            .ok_or_else(|| "Missing expires_at in Copilot refresh response".to_string())?
561            as i64;
562
563        // Fetch available model IDs
564        let available_ids = fetch_available_model_ids(&token, enterprise_domain)
565            .await
566            .unwrap_or_default();
567
568        let mut extra = credentials.extra.clone();
569        extra.insert("availableModelIds".to_string(), available_ids.join(","));
570
571        Ok(OAuthCredentials {
572            access: token,
573            refresh: credentials.refresh.clone(),
574            expires: (expires_at * 1000) - (5 * 60 * 1000),
575            enterprise_url: credentials.enterprise_url.clone(),
576            extra,
577        })
578    }
579
580    fn get_api_key<'a>(&self, credentials: &'a OAuthCredentials) -> &'a str {
581        &credentials.access
582    }
583}