use crate::config::OAuthConfig;
use colored::Colorize;
use serde::{Deserialize, Serialize};
use serde_urlencoded;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
fn now_secs() -> u64 {
crate::utils::time::now_secs()
}
/// Successful JSON body returned by GitHub's `POST /login/device/code`
/// (requested by `start_device_flow`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceCodeResponse {
    /// Code identifying this authorization; sent back when polling for the token.
    pub device_code: String,
    /// Short code the user must type at `verification_uri`.
    pub user_code: String,
    /// URL the user visits to enter `user_code`.
    pub verification_uri: String,
    /// Lifetime of the codes in seconds; added to `now_secs()` to get the expiry.
    pub expires_in: u64,
    /// Minimum number of seconds to wait between token polls.
    pub interval: u64,
}
/// Successful body returned by GitHub's `POST /login/oauth/access_token` in the
/// device flow (requested by `poll_for_device_token`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceTokenResponse {
    /// The OAuth access token; this is what `execute_device_flow` returns.
    pub access_token: String,
    /// Token type; an empty string when the field is absent from a JSON body.
    #[serde(default)]
    pub token_type: String,
    /// Granted scopes as a single string; empty when absent.
    #[serde(default)]
    pub scope: String,
    /// Refresh token, present only if GitHub issued one.
    #[serde(default)]
    pub refresh_token: Option<String>,
}
/// Error object returned by GitHub's device-flow endpoints, e.g.
/// `{"error": "authorization_pending", "error_description": "..."}`.
///
/// `error` is required, so a success body never deserializes as this type.
/// `poll_for_device_token` relies on that to detect errors before looking at
/// the HTTP status.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceFlowErrorResponse {
    /// Machine-readable error code, such as `authorization_pending`,
    /// `slow_down`, `expired_token` or `access_denied`.
    pub error: String,
    /// Optional human-readable explanation of `error`.
    #[serde(default)]
    pub error_description: Option<String>,
    /// Optional link with more information about the error (not used in this module).
    #[serde(default)]
    pub error_uri: Option<String>,
    /// New minimum polling interval in seconds, sent alongside `slow_down`.
    #[serde(default)]
    pub interval: Option<u64>,
}
/// A device authorization that has been started but not yet completed.
///
/// Entries are cached per server name so that a repeated authorization attempt
/// reuses the same user code while it is still valid.
#[derive(Debug, Clone)]
pub struct PendingDeviceAuth {
    /// Code sent to GitHub's token endpoint when polling.
    pub device_code: String,
    /// Code shown to the user, to be entered at `verification_uri`.
    pub user_code: String,
    /// URL where the user enters `user_code`.
    pub verification_uri: String,
    /// Expiry time in the same seconds-based clock as `now_secs()`.
    pub expires_at: u64,
    /// Minimum number of seconds between token polls.
    pub interval: u64,
    /// Set when the entry is created. The polling loop in this module keeps its
    /// own timer rather than reading this field.
    pub last_poll: std::time::Instant,
}
lazy_static::lazy_static! {
    /// Device authorizations in progress, keyed by server name.
    ///
    /// `get_or_create_device_auth` reuses an unexpired entry instead of
    /// requesting a new device code. `execute_device_flow` removes the entry
    /// once the flow succeeds or fails terminally.
    /// This is a `tokio::sync::Mutex`, so locking is awaited.
    // NOTE(review): the `Arc` adds nothing for a `static` (it is never cloned
    // here); a plain `Mutex<HashMap<..>>` would behave the same.
    static ref PENDING_DEVICE_AUTHS: Arc<Mutex<HashMap<String, PendingDeviceAuth>>> =
        Arc::new(Mutex::new(HashMap::new()));
}
/// Requests a new device code and user code from GitHub (step 1 of the OAuth
/// device flow).
///
/// The request sends `config.client_id` and `config.scopes` joined with single
/// spaces, and asks for a JSON response.
///
/// # Errors
///
/// Returns a human-readable message in any of these cases:
/// - the request cannot be sent, or the body cannot be read;
/// - the body is a device-flow error object, rendered as
///   `"<error> - <description>"`, or just `"<error>"` when there is no description;
/// - the HTTP status is not successful;
/// - the body is not a valid [`DeviceCodeResponse`].
pub async fn start_device_flow(config: &OAuthConfig) -> Result<DeviceCodeResponse, String> {
    let client = reqwest::Client::new();
    let scope = config.scopes.join(" ");
    crate::log_debug!(
        "Starting GitHub Device Flow - client_id: {}, scopes: {}",
        config.client_id,
        scope
    );

    let params = [
        ("client_id", config.client_id.as_str()),
        ("scope", scope.as_str()),
    ];
    let form_body =
        serde_urlencoded::to_string(params).map_err(|e| format!("Form error: {}", e))?;

    let response = client
        .post("https://github.com/login/device/code")
        .header(reqwest::header::ACCEPT, "application/json")
        .header(
            reqwest::header::CONTENT_TYPE,
            "application/x-www-form-urlencoded",
        )
        .body(form_body)
        .send()
        .await
        .map_err(|e| format!("Network error: {}", e))?;

    let status = response.status();
    let text = response
        .text()
        .await
        .map_err(|e| format!("Read error: {}", e))?;
    crate::log_debug!("Device code response - status: {}, body: {}", status, text);

    // GitHub's OAuth endpoints may report errors in a 200 response, so check for
    // an error object regardless of status, as `poll_for_device_token` does.
    // A success body has no `error` field and therefore never matches.
    if let Ok(flow_err) = serde_json::from_str::<DeviceFlowErrorResponse>(&text) {
        return Err(match flow_err.error_description.filter(|d| !d.is_empty()) {
            Some(description) => format!("{} - {}", flow_err.error, description),
            None => flow_err.error,
        });
    }
    if !status.is_success() {
        return Err(format!("Device code request failed: {} - {}", status, text));
    }
    serde_json::from_str(&text).map_err(|e| format!("Invalid response: {}", e))
}
pub async fn poll_for_device_token(
config: &OAuthConfig,
device_code: &str,
) -> Result<DeviceTokenResponse, String> {
let client = reqwest::Client::new();
let params = [
("client_id", config.client_id.as_str()),
("device_code", device_code),
("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
];
let form_body =
serde_urlencoded::to_string(params).map_err(|e| format!("Form error: {}", e))?;
let response = client
.post("https://github.com/login/oauth/access_token")
.header(reqwest::header::ACCEPT, "application/json")
.header(
reqwest::header::CONTENT_TYPE,
"application/x-www-form-urlencoded",
)
.body(form_body)
.send()
.await
.map_err(|e| format!("Network error: {}", e))?;
let status = response.status();
let text = response
.text()
.await
.map_err(|e| format!("Read error: {}", e))?;
crate::log_debug!(
"Device token poll response - status: {}, body: {}",
status,
text
);
if let Ok(error_response) = serde_json::from_str::<DeviceFlowErrorResponse>(&text) {
return Err(match error_response.error.as_str() {
"authorization_pending" => "authorization_pending".to_string(),
"slow_down" => {
if let Some(new_interval) = error_response.interval {
format!("slow_down:{}", new_interval)
} else {
"slow_down".to_string()
}
}
"expired_token" => "expired_token".to_string(),
"access_denied" => "access_denied".to_string(),
_ => format!(
"{} - {}",
error_response.error,
error_response.error_description.unwrap_or_default()
),
});
}
if !status.is_success() {
return Err(format!("Token request failed: {} - {}", status, text));
}
match serde_json::from_str::<DeviceTokenResponse>(&text) {
Ok(token) => Ok(token),
Err(e) => {
crate::log_debug!("Failed to parse as JSON: {}, trying URL-encoded format", e);
let params: std::collections::HashMap<String, String> =
serde_urlencoded::from_str(&text).map_err(|parse_err| {
format!(
"Invalid response format: JSON error: {}, URL-encoded error: {}",
e, parse_err
)
})?;
let access_token = params
.get("access_token")
.ok_or_else(|| format!("Missing access_token in response: {}", text))?
.clone();
let token_type = params
.get("token_type")
.unwrap_or(&"bearer".to_string())
.clone();
let scope = params.get("scope").unwrap_or(&"".to_string()).clone();
Ok(DeviceTokenResponse {
access_token,
token_type,
scope,
refresh_token: params.get("refresh_token").cloned(),
})
}
}
}
/// Returns the cached pending device authorization for `server_name`, or starts
/// a new device flow and caches it.
///
/// A cached entry is reused only while its `expires_at` is still in the future.
/// An expired entry is evicted. The cache lock is not held while the request to
/// GitHub is in flight.
///
/// # Errors
///
/// Propagates any error from [`start_device_flow`].
async fn get_or_create_device_auth(
    config: &OAuthConfig,
    server_name: &str,
) -> Result<PendingDeviceAuth, String> {
    // Fast path: reuse a still-valid authorization; evict a stale one.
    {
        let mut cache = PENDING_DEVICE_AUTHS.lock().await;
        match cache.get(server_name) {
            Some(existing) if existing.expires_at > now_secs() => {
                return Ok(existing.clone());
            }
            Some(_) => {
                cache.remove(server_name);
            }
            None => {}
        }
    }

    let started = start_device_flow(config).await?;
    let fresh = PendingDeviceAuth {
        device_code: started.device_code,
        user_code: started.user_code,
        verification_uri: started.verification_uri,
        expires_at: now_secs() + started.expires_in,
        interval: started.interval,
        last_poll: std::time::Instant::now(),
    };

    PENDING_DEVICE_AUTHS
        .lock()
        .await
        .insert(server_name.to_string(), fresh.clone());
    Ok(fresh)
}
pub async fn execute_device_flow(
config: &OAuthConfig,
server_name: &str,
) -> Result<String, String> {
let pending_auth = get_or_create_device_auth(config, server_name).await?;
println!("\n");
println!("{}", "═".repeat(70));
println!("\\n");
println!("{}", "═".repeat(70));
println!(
"{}",
"🔐 GITHUB AUTHORIZATION REQUIRED".bright_cyan().bold()
);
println!("{}", "═".repeat(70));
println!();
println!(
"Please visit: {}",
pending_auth.verification_uri.bright_white()
);
println!();
println!(
"And enter this code: {}",
pending_auth.user_code.bright_green().bold()
);
println!();
println!(
"This code expires in {} minutes.",
(pending_auth.expires_at - now_secs()) / 60
);
println!();
println!("Waiting for authorization... (press Ctrl+C to cancel)");
println!("{}", "─".repeat(70));
println!();
let mut interval_seconds = pending_auth.interval;
let expires_at_timestamp = pending_auth.expires_at; let mut last_poll_time = std::time::Instant::now();
println!(
"🔍 Starting polling loop with interval: {}s",
interval_seconds
);
loop {
if now_secs() >= expires_at_timestamp {
let mut auths = PENDING_DEVICE_AUTHS.lock().await;
auths.remove(server_name);
return Err("Authorization timed out. Please try again.".to_string());
}
let elapsed_since_last_poll = last_poll_time.elapsed();
if elapsed_since_last_poll < Duration::from_secs(interval_seconds) {
tokio::time::sleep(Duration::from_secs(interval_seconds) - elapsed_since_last_poll)
.await;
}
last_poll_time = std::time::Instant::now();
match poll_for_device_token(config, &pending_auth.device_code).await {
Ok(token_response) => {
println!();
println!("✅ Authorization successful!");
println!();
let mut auths = PENDING_DEVICE_AUTHS.lock().await;
auths.remove(server_name);
return Ok(token_response.access_token);
}
Err(e) => {
if e.starts_with("slow_down") {
if let Some(new_interval_str) = e.strip_prefix("slow_down:") {
if let Ok(new_interval) = new_interval_str.parse::<u64>() {
interval_seconds = new_interval;
crate::log_debug!(
"slow_down: using new interval from GitHub: {}s",
interval_seconds
);
} else {
interval_seconds += 5;
}
} else {
interval_seconds += 5;
}
println!(
"\nRate limited - slowing down polling (new interval: {}s)...",
interval_seconds
);
} else {
match e.as_str() {
"authorization_pending" => {
print!(".");
let _ = std::io::Write::flush(&mut std::io::stdout());
}
"access_denied" => {
let mut auths: tokio::sync::MutexGuard<
'_,
HashMap<String, PendingDeviceAuth>,
> = PENDING_DEVICE_AUTHS.lock().await;
auths.remove(server_name);
return Err("Authorization was denied. Please try again.".to_string());
}
"expired_token" => {
let mut auths: tokio::sync::MutexGuard<
'_,
HashMap<String, PendingDeviceAuth>,
> = PENDING_DEVICE_AUTHS.lock().await;
auths.remove(server_name);
return Err("Authorization code expired. Please try again.".to_string());
}
_ => {
crate::log_debug!("Device flow error: {}", e);
print!(".");
let _ = std::io::Write::flush(&mut std::io::stdout());
}
}
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_device_code_response() {
        let json = r#"{
            "device_code": "3584d83530557fdd1f46af8289938c8ef79f9dc5",
            "user_code": "WDJB-MJHT",
            "verification_uri": "https://github.com/login/device",
            "expires_in": 900,
            "interval": 5
        }"#;
        let response: DeviceCodeResponse = serde_json::from_str(json).unwrap();
        assert_eq!(
            response.device_code,
            "3584d83530557fdd1f46af8289938c8ef79f9dc5"
        );
        assert_eq!(response.user_code, "WDJB-MJHT");
        assert_eq!(response.verification_uri, "https://github.com/login/device");
        assert_eq!(response.expires_in, 900);
        assert_eq!(response.interval, 5);
    }

    #[test]
    fn test_parse_device_token_response() {
        let json = r#"{
            "access_token": "gho_16C7e42F292c6912E7710c838347Ae178B4a",
            "token_type": "bearer",
            "scope": "repo,gist"
        }"#;
        let response: DeviceTokenResponse = serde_json::from_str(json).unwrap();
        assert_eq!(
            response.access_token,
            "gho_16C7e42F292c6912E7710c838347Ae178B4a"
        );
        assert_eq!(response.token_type, "bearer");
        assert_eq!(response.scope, "repo,gist");
    }

    #[test]
    fn test_parse_device_token_response_defaults() {
        // Only `access_token` is required; the other fields fall back to their defaults.
        let json = r#"{"access_token": "gho_abc"}"#;
        let response: DeviceTokenResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.access_token, "gho_abc");
        assert_eq!(response.token_type, "");
        assert_eq!(response.scope, "");
        assert_eq!(response.refresh_token, None);
    }

    #[test]
    fn test_parse_device_flow_error() {
        let json = r#"{
            "error": "authorization_pending",
            "error_description": "The authorization request is still pending"
        }"#;
        let error: DeviceFlowErrorResponse = serde_json::from_str(json).unwrap();
        assert_eq!(error.error, "authorization_pending");
        assert_eq!(
            error.error_description,
            Some("The authorization request is still pending".to_string())
        );
    }

    #[test]
    fn test_parse_slow_down_with_interval() {
        let json = r#"{"error": "slow_down", "interval": 10}"#;
        let error: DeviceFlowErrorResponse = serde_json::from_str(json).unwrap();
        assert_eq!(error.error, "slow_down");
        assert_eq!(error.interval, Some(10));
        assert_eq!(error.error_description, None);
        assert_eq!(error.error_uri, None);
    }

    #[test]
    fn test_success_bodies_are_not_flow_errors() {
        // poll_for_device_token parses every body as an error object first,
        // so a success body must never deserialize as one.
        let token = r#"{"access_token": "gho_abc", "token_type": "bearer", "scope": ""}"#;
        assert!(serde_json::from_str::<DeviceFlowErrorResponse>(token).is_err());
        let code = r#"{
            "device_code": "d",
            "user_code": "WDJB-MJHT",
            "verification_uri": "https://github.com/login/device",
            "expires_in": 900,
            "interval": 5
        }"#;
        assert!(serde_json::from_str::<DeviceFlowErrorResponse>(code).is_err());
    }
}