use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
pub const COPILOT_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
pub const COPILOT_CHAT_URL: &str = "https://api.githubcopilot.com/chat/completions";
pub const COPILOT_MODELS_URL: &str = "https://api.githubcopilot.com/models";
const COPILOT_TOKEN_URL: &str = "https://api.github.com/copilot_internal/v2/token";
const DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
const OAUTH_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
#[derive(Debug, Clone, serde::Deserialize)]
pub struct DeviceFlowResponse {
pub device_code: String,
pub user_code: String,
pub verification_uri: String,
pub expires_in: u64,
pub interval: u64,
}
#[derive(Debug, serde::Deserialize)]
struct CopilotTokenResponse {
token: String,
expires_at: u64,
}
#[derive(Debug, serde::Deserialize)]
struct OAuthPollResponse {
#[serde(default)]
access_token: Option<String>,
#[serde(default)]
error: Option<String>,
}
fn copilot_headers() -> Vec<(&'static str, &'static str)> {
vec![
("copilot-integration-id", "vscode-chat"),
("editor-version", "vscode/1.99.0"),
("editor-plugin-version", "copilot-chat/0.26.0"),
("user-agent", "GitHubCopilotChat/0.26.0"),
]
}
pub async fn start_device_flow() -> anyhow::Result<DeviceFlowResponse> {
let client = reqwest::Client::new();
let resp = client
.post(DEVICE_CODE_URL)
.header("content-type", "application/json")
.header("accept", "application/json")
.json(&serde_json::json!({
"client_id": COPILOT_CLIENT_ID,
"scope": "read:user"
}))
.send()
.await?;
if !resp.status().is_success() {
let body = resp.text().await.unwrap_or_default();
anyhow::bail!("Device flow request failed: {}", body);
}
Ok(resp.json::<DeviceFlowResponse>().await?)
}
pub async fn poll_for_oauth_token(device_code: &str, interval: u64) -> anyhow::Result<String> {
let client = reqwest::Client::new();
let mut poll_interval = Duration::from_secs(interval.max(5));
loop {
tokio::time::sleep(poll_interval).await;
let resp = client
.post(OAUTH_TOKEN_URL)
.header("content-type", "application/json")
.header("accept", "application/json")
.json(&serde_json::json!({
"client_id": COPILOT_CLIENT_ID,
"device_code": device_code,
"grant_type": "urn:ietf:params:oauth:grant-type:device_code"
}))
.send()
.await?;
let poll: OAuthPollResponse = resp.json().await?;
if let Some(token) = poll.access_token
&& !token.is_empty()
{
return Ok(token);
}
match poll.error.as_deref() {
Some("authorization_pending") => continue,
Some("slow_down") => {
poll_interval += Duration::from_secs(5);
continue;
}
Some("expired_token") => anyhow::bail!("Device code expired. Please try again."),
Some("access_denied") => anyhow::bail!("Authorization denied by user."),
Some(other) => anyhow::bail!("OAuth error: {}", other),
None => continue,
}
}
}
pub struct CopilotTokenManager {
oauth_token: String,
copilot_token: Arc<RwLock<String>>,
expires_at: Arc<RwLock<Instant>>,
}
impl CopilotTokenManager {
pub fn new(oauth_token: String) -> Self {
Self {
oauth_token,
copilot_token: Arc::new(RwLock::new(String::new())),
expires_at: Arc::new(RwLock::new(Instant::now())),
}
}
pub async fn ensure_token(&self) -> anyhow::Result<String> {
{
let token = self.copilot_token.read().unwrap();
let expires = self.expires_at.read().unwrap();
if !token.is_empty() && *expires > Instant::now() + Duration::from_secs(120) {
return Ok(token.clone());
}
}
self.refresh().await?;
Ok(self.copilot_token.read().unwrap().clone())
}
pub fn get_cached_token(&self) -> String {
self.copilot_token.read().unwrap().clone()
}
pub async fn refresh(&self) -> anyhow::Result<()> {
let client = reqwest::Client::new();
let mut req = client
.get(COPILOT_TOKEN_URL)
.header("authorization", format!("token {}", self.oauth_token))
.header("accept", "application/json");
for (k, v) in copilot_headers() {
req = req.header(k, v);
}
let resp = req.send().await?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
anyhow::bail!(
"Copilot token refresh failed ({}): {}",
status,
&body[..body.floor_char_boundary(300)]
);
}
let token_resp: CopilotTokenResponse = resp.json().await?;
let now_unix = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs();
let ttl = token_resp.expires_at.saturating_sub(now_unix);
{
let mut token = self.copilot_token.write().unwrap();
*token = token_resp.token;
}
{
let mut expires = self.expires_at.write().unwrap();
*expires = Instant::now() + Duration::from_secs(ttl);
}
tracing::debug!("Copilot token refreshed, TTL {}s", ttl);
Ok(())
}
pub fn start_background_refresh(self: Arc<Self>) {
tokio::spawn(async move {
if let Err(e) = self.refresh().await {
tracing::warn!("Copilot initial token refresh failed: {}", e);
}
loop {
let sleep_secs = {
let expires = self.expires_at.read().unwrap();
let remaining = expires.saturating_duration_since(Instant::now());
remaining.as_secs().saturating_sub(120).max(60)
};
tokio::time::sleep(Duration::from_secs(sleep_secs)).await;
if let Err(e) = self.refresh().await {
tracing::warn!("Copilot token background refresh failed: {}", e);
tokio::time::sleep(Duration::from_secs(30)).await;
}
}
});
}
}
pub fn copilot_extra_headers() -> Vec<(String, String)> {
copilot_headers()
.into_iter()
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect()
}
pub async fn fetch_copilot_models(oauth_token: &str) -> anyhow::Result<Vec<String>> {
let manager = CopilotTokenManager::new(oauth_token.to_string());
let copilot_token = manager.ensure_token().await?;
let client = reqwest::Client::new();
let mut req = client
.get(COPILOT_MODELS_URL)
.header("authorization", format!("Bearer {}", copilot_token));
for (k, v) in copilot_headers() {
req = req.header(k, v);
}
let resp = req.send().await?;
if !resp.status().is_success() {
anyhow::bail!("Failed to fetch Copilot models: {}", resp.status());
}
#[derive(serde::Deserialize)]
struct ModelEntry {
id: String,
#[serde(default)]
created: i64,
}
#[derive(serde::Deserialize)]
struct ModelsResponse {
data: Vec<ModelEntry>,
}
let body: ModelsResponse = resp.json().await?;
let mut entries = body.data;
entries.sort_by_key(|e| std::cmp::Reverse(e.created));
Ok(entries.into_iter().map(|m| m.id).collect())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn copilot_client_id_is_well_known() {
assert_eq!(COPILOT_CLIENT_ID, "Iv1.b507a08c87ecfe98");
}
#[test]
fn copilot_urls_are_correct() {
assert!(COPILOT_CHAT_URL.starts_with("https://api.githubcopilot.com"));
assert!(COPILOT_MODELS_URL.starts_with("https://api.githubcopilot.com"));
assert!(COPILOT_TOKEN_URL.contains("copilot_internal"));
assert!(DEVICE_CODE_URL.contains("login/device"));
assert!(OAUTH_TOKEN_URL.contains("login/oauth"));
}
#[test]
fn copilot_extra_headers_include_required_fields() {
let headers = copilot_extra_headers();
let keys: Vec<&str> = headers.iter().map(|(k, _)| k.as_str()).collect();
assert!(keys.contains(&"copilot-integration-id"));
assert!(keys.contains(&"editor-version"));
assert!(keys.contains(&"user-agent"));
}
#[test]
fn token_manager_starts_with_empty_token() {
let manager = CopilotTokenManager::new("gho_test".to_string());
assert!(manager.get_cached_token().is_empty());
}
#[test]
fn device_flow_response_deserializes() {
let json = r#"{
"device_code": "abc123",
"user_code": "ABCD-1234",
"verification_uri": "https://github.com/login/device",
"expires_in": 900,
"interval": 5
}"#;
let resp: DeviceFlowResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.user_code, "ABCD-1234");
assert_eq!(resp.interval, 5);
assert_eq!(resp.expires_in, 900);
}
#[test]
fn oauth_poll_response_handles_pending() {
let json = r#"{"error": "authorization_pending"}"#;
let resp: OAuthPollResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.error.as_deref(), Some("authorization_pending"));
assert!(resp.access_token.is_none());
}
#[test]
fn oauth_poll_response_handles_success() {
let json =
r#"{"access_token": "gho_abc123", "token_type": "bearer", "scope": "read:user"}"#;
let resp: OAuthPollResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.access_token.as_deref(), Some("gho_abc123"));
assert!(resp.error.is_none());
}
#[test]
fn copilot_token_response_deserializes() {
let json = r#"{"token": "tid=abc;exp=9999999999", "expires_at": 9999999999}"#;
let resp: CopilotTokenResponse = serde_json::from_str(json).unwrap();
assert!(resp.token.starts_with("tid="));
assert_eq!(resp.expires_at, 9999999999);
}
}