use std::fs;
use std::path::PathBuf;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use tracing::{debug, warn};
use super::config::GitHubCopilotConfig;
use super::error::GitHubCopilotError;
use crate::core::providers::unified_provider::ProviderError;
const GITHUB_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
const GITHUB_API_KEY_URL: &str = "https://api.github.com/copilot_internal/v2/token";
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiKeyInfo {
pub token: String,
pub expires_at: u64,
#[serde(default)]
pub endpoints: Endpoints,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Endpoints {
pub api: Option<String>,
}
#[derive(Debug, Deserialize)]
struct DeviceCodeResponse {
device_code: String,
user_code: String,
verification_uri: String,
#[serde(default)]
interval: u64,
}
#[derive(Debug, Deserialize)]
struct AccessTokenResponse {
access_token: Option<String>,
error: Option<String>,
}
#[derive(Debug, Clone)]
pub struct CopilotAuthenticator {
token_dir: PathBuf,
access_token_path: PathBuf,
api_key_path: PathBuf,
}
impl CopilotAuthenticator {
pub fn new(config: &GitHubCopilotConfig) -> Self {
let token_dir = PathBuf::from(config.get_token_dir());
let access_token_path = token_dir.join(config.get_access_token_file());
let api_key_path = token_dir.join(config.get_api_key_file());
Self {
token_dir,
access_token_path,
api_key_path,
}
}
fn ensure_token_dir(&self) -> Result<(), GitHubCopilotError> {
if !self.token_dir.exists() {
fs::create_dir_all(&self.token_dir).map_err(|e| {
ProviderError::configuration(
"github_copilot",
format!("Failed to create token directory: {}", e),
)
})?;
}
Ok(())
}
pub async fn get_access_token(&self) -> Result<String, GitHubCopilotError> {
if let Ok(token) = fs::read_to_string(&self.access_token_path) {
let token = token.trim().to_string();
if !token.is_empty() {
return Ok(token);
}
}
warn!("No cached access token found, need to authenticate");
for attempt in 1..=3 {
debug!("Access token acquisition attempt {}/3", attempt);
match self.perform_device_flow().await {
Ok(token) => {
self.ensure_token_dir()?;
if let Err(e) = fs::write(&self.access_token_path, &token) {
warn!("Failed to cache access token: {}", e);
}
return Ok(token);
}
Err(e) => {
warn!("Device flow attempt {} failed: {}", attempt, e);
if attempt == 3 {
return Err(ProviderError::authentication(
"github_copilot",
"Access token error: Failed to get access token after 3 attempts",
));
}
}
}
}
Err(ProviderError::authentication(
"github_copilot",
"Access token error: Failed to get access token",
))
}
pub async fn get_api_key(&self) -> Result<String, GitHubCopilotError> {
if let Ok(content) = fs::read_to_string(&self.api_key_path)
&& let Ok(api_key_info) = serde_json::from_str::<ApiKeyInfo>(&content)
{
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or(Duration::from_secs(0))
.as_secs();
if api_key_info.expires_at > now {
return Ok(api_key_info.token);
}
debug!("API key expired, refreshing...");
}
self.refresh_api_key().await
}
pub fn get_api_base(&self) -> Option<String> {
if let Ok(content) = fs::read_to_string(&self.api_key_path)
&& let Ok(api_key_info) = serde_json::from_str::<ApiKeyInfo>(&content)
{
return api_key_info.endpoints.api;
}
None
}
async fn refresh_api_key(&self) -> Result<String, GitHubCopilotError> {
let access_token = self.get_access_token().await?;
let headers = self.get_github_headers(Some(&access_token));
let client = reqwest::Client::new();
for attempt in 1..=3 {
let response = client
.get(GITHUB_API_KEY_URL)
.headers(headers.clone())
.send()
.await
.map_err(|e| {
ProviderError::authentication(
"github_copilot",
format!("Refresh error: HTTP error: {}", e),
)
})?;
if !response.status().is_success() {
warn!(
"API key refresh attempt {}/3 failed with status: {}",
attempt,
response.status()
);
if attempt == 3 {
return Err(ProviderError::authentication(
"github_copilot",
"Refresh error: Failed to refresh API key after 3 attempts",
));
}
continue;
}
let api_key_info: ApiKeyInfo = response.json().await.map_err(|e| {
ProviderError::authentication(
"github_copilot",
format!("Refresh error: Failed to parse response: {}", e),
)
})?;
self.ensure_token_dir()?;
if let Ok(json) = serde_json::to_string(&api_key_info)
&& let Err(e) = fs::write(&self.api_key_path, json)
{
warn!("Failed to cache API key: {}", e);
}
return Ok(api_key_info.token);
}
Err(ProviderError::authentication(
"github_copilot",
"Refresh error: Failed to refresh API key",
))
}
async fn perform_device_flow(&self) -> Result<String, GitHubCopilotError> {
let client = reqwest::Client::new();
let headers = self.get_github_headers(None);
let response = client
.post(GITHUB_DEVICE_CODE_URL)
.headers(headers.clone())
.json(&serde_json::json!({
"client_id": GITHUB_CLIENT_ID,
"scope": "read:user"
}))
.send()
.await
.map_err(|e| {
ProviderError::authentication(
"github_copilot",
format!("Device code error: HTTP error: {}", e),
)
})?;
if !response.status().is_success() {
return Err(ProviderError::authentication(
"github_copilot",
format!(
"Device code error: Failed to get device code: {}",
response.status()
),
));
}
let device_code_response: DeviceCodeResponse = response.json().await.map_err(|e| {
ProviderError::authentication(
"github_copilot",
format!("Device code error: Failed to parse response: {}", e),
)
})?;
println!(
"\nPlease visit {} and enter code {} to authenticate.\n",
device_code_response.verification_uri, device_code_response.user_code
);
let interval = if device_code_response.interval > 0 {
device_code_response.interval
} else {
5
};
let max_attempts = 60 / interval as usize;
for _attempt in 0..max_attempts {
tokio::time::sleep(Duration::from_secs(interval)).await;
let response = client
.post(GITHUB_ACCESS_TOKEN_URL)
.headers(headers.clone())
.json(&serde_json::json!({
"client_id": GITHUB_CLIENT_ID,
"device_code": device_code_response.device_code,
"grant_type": "urn:ietf:params:oauth:grant-type:device_code"
}))
.send()
.await
.map_err(|e| {
ProviderError::authentication(
"github_copilot",
format!("Access token error: HTTP error: {}", e),
)
})?;
let token_response: AccessTokenResponse = response.json().await.map_err(|e| {
ProviderError::authentication(
"github_copilot",
format!("Access token error: Failed to parse response: {}", e),
)
})?;
if let Some(access_token) = token_response.access_token {
debug!("Authentication successful!");
return Ok(access_token);
}
if let Some(error) = &token_response.error
&& error != "authorization_pending"
{
return Err(ProviderError::authentication(
"github_copilot",
format!("Access token error: OAuth error: {}", error),
));
}
}
Err(ProviderError::authentication(
"github_copilot",
"Access token error: Timed out waiting for user authorization",
))
}
fn get_github_headers(&self, access_token: Option<&str>) -> reqwest::header::HeaderMap {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"accept",
"application/json"
.parse()
.expect("static header value 'application/json' is always valid"),
);
headers.insert(
"content-type",
"application/json"
.parse()
.expect("static header value 'application/json' is always valid"),
);
headers.insert(
"editor-version",
"vscode/1.85.1"
.parse()
.expect("static header value 'vscode/1.85.1' is always valid"),
);
headers.insert(
"editor-plugin-version",
"copilot/1.155.0"
.parse()
.expect("static header value 'copilot/1.155.0' is always valid"),
);
headers.insert(
"user-agent",
"GithubCopilot/1.155.0"
.parse()
.expect("static header value 'GithubCopilot/1.155.0' is always valid"),
);
headers.insert(
"accept-encoding",
"gzip,deflate,br"
.parse()
.expect("static header value 'gzip,deflate,br' is always valid"),
);
if let Some(token) = access_token {
if let Ok(value) = format!("token {}", token).parse() {
headers.insert("authorization", value);
} else {
tracing::warn!(
"GitHub Copilot access token contains invalid header characters, skipping authorization header"
);
}
}
headers
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_authenticator_creation() {
let config = GitHubCopilotConfig::default();
let auth = CopilotAuthenticator::new(&config);
assert!(auth.token_dir.to_string_lossy().contains("github_copilot"));
}
#[test]
fn test_api_key_info_serialization() {
let info = ApiKeyInfo {
token: "test-token".to_string(),
expires_at: 1234567890,
endpoints: Endpoints {
api: Some("https://api.example.com".to_string()),
},
};
let json = serde_json::to_string(&info).unwrap();
assert!(json.contains("test-token"));
assert!(json.contains("1234567890"));
let deserialized: ApiKeyInfo = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized.token, "test-token");
assert_eq!(deserialized.expires_at, 1234567890);
}
#[test]
fn test_api_key_info_deserialization_with_defaults() {
let json = r#"{"token": "test", "expires_at": 123}"#;
let info: ApiKeyInfo = serde_json::from_str(json).unwrap();
assert_eq!(info.token, "test");
assert!(info.endpoints.api.is_none());
}
#[test]
fn test_get_github_headers() {
let config = GitHubCopilotConfig::default();
let auth = CopilotAuthenticator::new(&config);
let headers = auth.get_github_headers(None);
assert!(headers.get("accept").is_some());
assert!(headers.get("user-agent").is_some());
assert!(headers.get("authorization").is_none());
let headers = auth.get_github_headers(Some("test-token"));
assert!(headers.get("authorization").is_some());
}
}