1use std::collections::HashMap;
7
8use async_trait::async_trait;
9use base64::Engine;
10
11use super::device_code::{PollOptions, PollStatus, poll_device_code_flow};
12use super::{DeviceCodeInfo, OAuthCredentials, OAuthLoginCallbacks, OAuthPrompt, OAuthProvider};
13
/// Base64-encoded OAuth client id; decoded at runtime by `client_id()`.
const CLIENT_ID_ENCODED: &str = "SXYxLmI1MDdhMDhjODdlY2ZlOTg=";

/// Fixed headers added to Copilot API requests in this file (token exchange,
/// model listing and model policy calls).
const COPILOT_HEADERS: &[(&str, &str)] = &[
    ("User-Agent", "GitHubCopilotChat/0.35.0"),
    ("Editor-Version", "vscode/1.107.0"),
    ("Editor-Plugin-Version", "copilot-chat/0.35.0"),
    ("Copilot-Integration-Id", "vscode-chat"),
];
/// Value sent in the `X-GitHub-Api-Version` header when listing models.
const COPILOT_API_VERSION: &str = "2026-06-01";
23
24fn client_id() -> String {
25 String::from_utf8(
26 base64::engine::general_purpose::STANDARD
27 .decode(CLIENT_ID_ENCODED)
28 .expect("valid base64"),
29 )
30 .expect("valid utf8")
31}
32
33#[allow(dead_code)]
34fn decode(s: &str) -> String {
35 String::from_utf8(
36 base64::engine::general_purpose::STANDARD
37 .decode(s)
38 .unwrap_or_default(),
39 )
40 .unwrap_or_default()
41}
42
43pub fn normalize_domain(input: &str) -> Option<String> {
44 let trimmed = input.trim();
45 if trimmed.is_empty() {
46 return None;
47 }
48 let url_str = if trimmed.contains("://") {
49 trimmed.to_string()
50 } else {
51 format!("https://{}", trimmed)
52 };
53 url::Url::parse(&url_str)
54 .ok()
55 .map(|u| u.host_str().unwrap_or("").to_string())
56}
57
/// Builds the three endpoint URLs used for `domain`.
///
/// Returns `(device_code_url, access_token_url, copilot_token_url)`; the
/// Copilot token endpoint lives on the `api.` subdomain of `domain`.
fn get_urls(domain: &str) -> (String, String, String) {
    let device_code_url = format!("https://{}/login/device/code", domain);
    let access_token_url = format!("https://{}/login/oauth/access_token", domain);
    let copilot_token_url = format!("https://api.{}/copilot_internal/v2/token", domain);
    (device_code_url, access_token_url, copilot_token_url)
}
65
/// Derives the Copilot API base URL from the `proxy-ep=` field of a token.
///
/// The token is treated as a `;`-separated list of fields. The first field
/// that starts with `proxy-ep=` and has a non-empty value supplies the host.
/// A leading `proxy.` label on that host is replaced by `api.`; any other
/// host is used unchanged. Only the leading label is rewritten, so a host
/// such as `reverse-proxy.example.com` is left intact.
///
/// Returns `None` when no usable `proxy-ep=` field is present.
fn get_base_url_from_token(token: &str) -> Option<String> {
    token
        .split(';')
        .filter_map(|field| field.strip_prefix("proxy-ep="))
        // An empty endpoint would produce the useless URL "https://".
        .find(|host| !host.is_empty())
        .map(|host| {
            let api_host = match host.strip_prefix("proxy.") {
                Some(rest) => format!("api.{}", rest),
                None => host.to_string(),
            };
            format!("https://{}", api_host)
        })
}
76
77pub fn get_copilot_base_url(token: Option<&str>, enterprise_domain: Option<&str>) -> String {
79 if let Some(t) = token
80 && let Some(url) = get_base_url_from_token(t)
81 {
82 return url;
83 }
84 if let Some(domain) = enterprise_domain {
85 return format!("https://copilot-api.{}", domain);
86 }
87 "https://api.individual.githubcopilot.com".to_string()
88}
89
90async fn fetch_json(url: &str, headers: &[(&str, &str)]) -> Result<serde_json::Value, String> {
92 let client = reqwest::Client::new();
93 let mut req = client.get(url);
94 for (k, v) in headers {
95 req = req.header(*k, *v);
96 }
97 let resp = req.send().await.map_err(|e| format!("HTTP error: {}", e))?;
98 let status = resp.status();
99 if !status.is_success() {
100 let text = resp.text().await.unwrap_or_default();
101 return Err(format!("HTTP {}: {}", status, text));
102 }
103 resp.json().await.map_err(|e| format!("JSON error: {}", e))
104}
105
106#[allow(dead_code)]
108async fn post_json(
109 url: &str,
110 headers: &[(&str, &str)],
111 body: &serde_json::Value,
112) -> Result<serde_json::Value, String> {
113 let client = reqwest::Client::new();
114 let mut req = client.post(url).json(body);
115 for (k, v) in headers {
116 req = req.header(*k, *v);
117 }
118 let resp = req.send().await.map_err(|e| format!("HTTP error: {}", e))?;
119 let status = resp.status();
120 if !status.is_success() {
121 let text = resp.text().await.unwrap_or_default();
122 return Err(format!("HTTP {}: {}", status, text));
123 }
124 resp.json().await.map_err(|e| format!("JSON error: {}", e))
125}
126
127async fn post_form(
129 url: &str,
130 headers: &[(&str, &str)],
131 form: &[(&str, &str)],
132) -> Result<serde_json::Value, String> {
133 let client = reqwest::Client::new();
134 let mut req = client.post(url);
135 for (k, v) in headers {
136 req = req.header(*k, *v);
137 }
138 let params: Vec<(&str, &str)> = form.to_vec();
139 let resp = req
140 .form(¶ms)
141 .send()
142 .await
143 .map_err(|e| format!("HTTP error: {}", e))?;
144 let status = resp.status();
145 if !status.is_success() {
146 let text = resp.text().await.unwrap_or_default();
147 return Err(format!("HTTP {}: {}", status, text));
148 }
149 resp.json().await.map_err(|e| format!("JSON error: {}", e))
150}
151
152async fn start_device_flow(domain: &str) -> Result<serde_json::Value, String> {
154 let (device_code_url, _, _) = get_urls(domain);
155 post_form(
156 &device_code_url,
157 &[
158 ("Accept", "application/json"),
159 ("User-Agent", "GitHubCopilotChat/0.35.0"),
160 ],
161 &[("client_id", &client_id()), ("scope", "read:user")],
162 )
163 .await
164}
165
166async fn poll_for_github_access_token(
168 domain: &str,
169 device_code: &str,
170 interval: Option<u32>,
171 expires_in: Option<u32>,
172 cancel: Option<tokio_util::sync::CancellationToken>,
173) -> Result<String, String> {
174 let (_, access_token_url, _) = get_urls(domain);
175 let client_id = client_id();
176 let device_code = device_code.to_string();
177
178 poll_device_code_flow(PollOptions {
179 interval_seconds: interval,
180 expires_in_seconds: expires_in,
181 cancel,
182 poll: Box::new(move || {
183 let access_token_url = access_token_url.clone();
184 let client_id = client_id.clone();
185 let device_code = device_code.clone();
186 Box::pin(async move {
187 let raw = post_form(
188 &access_token_url,
189 &[
190 ("Accept", "application/json"),
191 ("User-Agent", "GitHubCopilotChat/0.35.0"),
192 ],
193 &[
194 ("client_id", &client_id),
195 ("device_code", &device_code),
196 ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
197 ],
198 )
199 .await?;
200
201 if let Some(token) = raw.get("access_token").and_then(|t| t.as_str()) {
202 return Ok(PollStatus::Complete(token.to_string()));
203 }
204
205 if let Some(error) = raw.get("error").and_then(|e| e.as_str()) {
206 match error {
207 "authorization_pending" => return Ok(PollStatus::Pending),
208 "slow_down" => return Ok(PollStatus::SlowDown),
209 _ => {
210 let desc = raw
211 .get("error_description")
212 .and_then(|d| d.as_str())
213 .unwrap_or("");
214 return Ok(PollStatus::Failed(format!(
215 "Device flow failed: {}{}",
216 error,
217 if desc.is_empty() {
218 String::new()
219 } else {
220 format!(": {}", desc)
221 }
222 )));
223 }
224 }
225 }
226
227 Ok(PollStatus::Failed(
228 "Invalid device token response".to_string(),
229 ))
230 })
231 }),
232 })
233 .await
234}
235
236async fn exchange_for_copilot_token(
238 github_token: &str,
239 enterprise_domain: Option<&str>,
240) -> Result<serde_json::Value, String> {
241 let domain = enterprise_domain.unwrap_or("github.com");
242 let (_, _, copilot_token_url) = get_urls(domain);
243
244 let auth_val = format!("Bearer {}", github_token);
245 let mut headers: Vec<(&str, &str)> =
246 vec![("Accept", "application/json"), ("Authorization", &auth_val)];
247 for (k, v) in COPILOT_HEADERS {
248 headers.push((k, v));
249 }
250
251 fetch_json(&copilot_token_url, &headers).await
252}
253
254async fn refresh_copilot_access_token(
256 refresh_token: &str,
257 enterprise_domain: Option<&str>,
258) -> Result<serde_json::Value, String> {
259 let domain = enterprise_domain.unwrap_or("github.com");
260 let (_, _, copilot_token_url) = get_urls(domain);
261
262 let auth_val = format!("Bearer {}", refresh_token);
263 let mut headers: Vec<(&str, &str)> =
264 vec![("Accept", "application/json"), ("Authorization", &auth_val)];
265 for (k, v) in COPILOT_HEADERS {
266 headers.push((k, v));
267 }
268
269 fetch_json(&copilot_token_url, &headers).await
270}
271
272async fn fetch_available_model_ids(
274 copilot_token: &str,
275 enterprise_domain: Option<&str>,
276) -> Result<Vec<String>, String> {
277 let base_url = get_copilot_base_url(Some(copilot_token), enterprise_domain);
278 let url = format!("{}/models", base_url);
279
280 let auth_val = format!("Bearer {}", copilot_token);
281 let mut headers: Vec<(&str, &str)> =
282 vec![("Accept", "application/json"), ("Authorization", &auth_val)];
283 for (k, v) in COPILOT_HEADERS {
284 headers.push((k, v));
285 }
286 headers.push(("X-GitHub-Api-Version", COPILOT_API_VERSION));
287
288 let raw = fetch_json(&url, &headers).await?;
289
290 let data = raw.get("data").and_then(|d| d.as_array());
292 match data {
293 Some(items) => {
294 let ids: Vec<String> = items
295 .iter()
296 .filter(|item| {
297 let policy = item.get("policy").and_then(|p| p.as_object());
298 let capabilities = item.get("capabilities").and_then(|c| c.as_object());
299 let supports = capabilities
300 .and_then(|c| c.get("supports"))
301 .and_then(|s| s.as_object());
302 let model_picker_enabled = item
303 .get("model_picker_enabled")
304 .and_then(|v| v.as_bool())
305 .unwrap_or(false);
306 let policy_enabled =
307 policy.and_then(|p| p.get("state")).and_then(|s| s.as_str())
308 != Some("disabled");
309 let supports_tool_calls = supports
310 .and_then(|s| s.get("tool_calls"))
311 .and_then(|v| v.as_bool())
312 .unwrap_or(true);
313 model_picker_enabled && policy_enabled && supports_tool_calls
314 })
315 .filter_map(|item| {
316 item.get("id")
317 .and_then(|id| id.as_str())
318 .map(|s| s.to_string())
319 })
320 .collect();
321 Ok(ids)
322 }
323 None => Err("Invalid Copilot models response: missing data array".to_string()),
324 }
325}
326
327async fn enable_model(
329 copilot_token: &str,
330 model_id: &str,
331 enterprise_domain: Option<&str>,
332) -> Result<bool, String> {
333 let base_url = get_copilot_base_url(Some(copilot_token), enterprise_domain);
334 let url = format!("{}/models/{}/policy", base_url, model_id);
335
336 let client = reqwest::Client::new();
337 let auth_header = format!("Bearer {}", copilot_token);
338 let mut req = client
339 .post(&url)
340 .header("Content-Type", "application/json")
341 .header("Authorization", &auth_header)
342 .header("openai-intent", "chat-policy")
343 .header("x-interaction-type", "chat-policy");
344 for (k, v) in COPILOT_HEADERS {
345 req = req.header(*k, *v);
346 }
347 let body = serde_json::json!({"state": "enabled"});
348 let resp = req.json(&body).send().await;
349 Ok(resp.map(|r| r.status().is_success()).unwrap_or(false))
350}
351
352async fn enable_all_models(
354 copilot_token: &str,
355 enterprise_domain: Option<&str>,
356 on_progress: &mut (dyn FnMut(String, bool) + Send),
357) {
358 let known_models = [
360 "claude-sonnet-4-20250514",
361 "claude-sonnet-4.5-preview-20250619",
362 "claude-opus-4-20250514",
363 "claude-opus-4.5-preview-20250619",
364 "claude-haiku-4-20250514",
365 "claude-haiku-4.5-preview-20250619",
366 "claude-fable-5",
367 "claude-haiku-4.5",
368 "claude-opus-4.5",
369 "claude-sonnet-4",
370 "gpt-4o",
371 "gpt-4o-mini",
372 "o3",
373 "o4-mini",
374 "gemini-2.5-flash-001",
375 "gemini-2.5-pro-001",
376 ];
377
378 use futures::future::join_all;
380 let tasks: Vec<_> = known_models
381 .iter()
382 .map(|model_id| {
383 let token = copilot_token.to_string();
384 let domain = enterprise_domain.map(|s| s.to_string());
385 let mid = model_id.to_string();
386 async move {
387 let success = enable_model(&token, &mid, domain.as_deref())
388 .await
389 .unwrap_or(false);
390 (mid, success)
391 }
392 })
393 .collect();
394
395 let results = join_all(tasks).await;
396 for (model_id, success) in results {
397 on_progress(model_id, success);
398 }
399}
400
/// Stateless [`OAuthProvider`] implementation for GitHub Copilot, using the
/// device-code login flow.
pub struct GitHubCopilotOAuth;
404
405#[async_trait]
406impl OAuthProvider for GitHubCopilotOAuth {
407 fn id(&self) -> &str {
408 "github-copilot"
409 }
410
411 fn name(&self) -> &str {
412 "GitHub Copilot"
413 }
414
415 async fn login(
416 &self,
417 callbacks: &mut OAuthLoginCallbacks<'_>,
418 ) -> Result<OAuthCredentials, String> {
419 let input = (callbacks.on_prompt)(OAuthPrompt::Text {
421 message: "GitHub Enterprise URL/domain (blank for github.com)".to_string(),
422 placeholder: Some("company.ghe.com".to_string()),
423 allow_empty: true,
424 })?;
425
426 if let Some(ref cancel) = callbacks.signal
427 && cancel.is_cancelled()
428 {
429 return Err("Login cancelled".to_string());
430 }
431
432 let trimmed = input.trim().to_string();
433 let enterprise_domain = if trimmed.is_empty() {
434 None
435 } else {
436 normalize_domain(&trimmed)
437 };
438 if !trimmed.is_empty() && enterprise_domain.is_none() {
439 return Err("Invalid GitHub Enterprise URL/domain".to_string());
440 }
441 let domain = enterprise_domain
442 .clone()
443 .unwrap_or_else(|| "github.com".to_string());
444
445 let device_resp = start_device_flow(&domain).await?;
447
448 let device_code = device_resp
449 .get("device_code")
450 .and_then(|v| v.as_str())
451 .ok_or_else(|| "Missing device_code in response".to_string())?
452 .to_string();
453 let user_code = device_resp
454 .get("user_code")
455 .and_then(|v| v.as_str())
456 .ok_or_else(|| "Missing user_code in response".to_string())?
457 .to_string();
458 let verification_uri = device_resp
459 .get("verification_uri")
460 .and_then(|v| v.as_str())
461 .ok_or_else(|| "Missing verification_uri in response".to_string())?
462 .to_string();
463 let interval = device_resp
464 .get("interval")
465 .and_then(|v| v.as_u64())
466 .map(|v| v as u32);
467 let expires_in = device_resp
468 .get("expires_in")
469 .and_then(|v| v.as_u64())
470 .map(|v| v as u32);
471
472 if let Ok(parsed) = url::Url::parse(&verification_uri) {
474 if parsed.scheme() != "https" && parsed.scheme() != "http" {
475 return Err("Untrusted verification_uri in device code response".to_string());
476 }
477 } else {
478 return Err("Invalid verification_uri in device code response".to_string());
479 }
480
481 (callbacks.on_device_code)(DeviceCodeInfo {
483 user_code: user_code.clone(),
484 verification_uri: verification_uri.clone(),
485 interval_seconds: interval,
486 expires_in_seconds: expires_in,
487 });
488
489 let cancel = callbacks.signal.clone();
491 let github_access_token =
492 poll_for_github_access_token(&domain, &device_code, interval, expires_in, cancel)
493 .await?;
494
495 let copilot_resp =
497 exchange_for_copilot_token(&github_access_token, enterprise_domain.as_deref()).await?;
498
499 let token = copilot_resp
500 .get("token")
501 .and_then(|v| v.as_str())
502 .ok_or_else(|| "Missing token in Copilot response".to_string())?
503 .to_string();
504 let expires_at = copilot_resp
505 .get("expires_at")
506 .and_then(|v| v.as_f64())
507 .ok_or_else(|| "Missing expires_at in Copilot response".to_string())?
508 as i64;
509
510 (callbacks.on_progress)("Enabling models...".to_string());
512 enable_all_models(
513 &token,
514 enterprise_domain.as_deref(),
515 &mut |model, success| {
516 (callbacks.on_progress)(format!(
517 "Model {}: {}",
518 model,
519 if success { "enabled" } else { "skipped" }
520 ));
521 },
522 )
523 .await;
524
525 let available_ids = fetch_available_model_ids(&token, enterprise_domain.as_deref())
527 .await
528 .unwrap_or_default();
529
530 let mut extra = HashMap::new();
531 extra.insert("availableModelIds".to_string(), available_ids.join(","));
532 if let Some(ref ed) = enterprise_domain {
533 extra.insert("enterpriseUrl".to_string(), ed.clone());
534 }
535
536 Ok(OAuthCredentials {
537 access: token.clone(),
538 refresh: github_access_token,
539 expires: (expires_at * 1000) - (5 * 60 * 1000), enterprise_url: enterprise_domain,
541 extra,
542 })
543 }
544
545 async fn refresh_token(
546 &self,
547 credentials: &OAuthCredentials,
548 ) -> Result<OAuthCredentials, String> {
549 let enterprise_domain = credentials.enterprise_url.as_deref();
550 let raw = refresh_copilot_access_token(&credentials.refresh, enterprise_domain).await?;
551
552 let token = raw
553 .get("token")
554 .and_then(|v| v.as_str())
555 .ok_or_else(|| "Missing token in Copilot refresh response".to_string())?
556 .to_string();
557 let expires_at = raw
558 .get("expires_at")
559 .and_then(|v| v.as_f64())
560 .ok_or_else(|| "Missing expires_at in Copilot refresh response".to_string())?
561 as i64;
562
563 let available_ids = fetch_available_model_ids(&token, enterprise_domain)
565 .await
566 .unwrap_or_default();
567
568 let mut extra = credentials.extra.clone();
569 extra.insert("availableModelIds".to_string(), available_ids.join(","));
570
571 Ok(OAuthCredentials {
572 access: token,
573 refresh: credentials.refresh.clone(),
574 expires: (expires_at * 1000) - (5 * 60 * 1000),
575 enterprise_url: credentials.enterprise_url.clone(),
576 extra,
577 })
578 }
579
580 fn get_api_key<'a>(&self, credentials: &'a OAuthCredentials) -> &'a str {
581 &credentials.access
582 }
583}