use std::io::{BufRead, BufReader, Write};
use std::net::TcpListener;
use anyhow::{anyhow, Context, Result};
use oauth2::basic::BasicClient;
use oauth2::{
AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, Scope,
TokenResponse, TokenUrl,
};
use reqwest::redirect::Policy;
use crate::constants::{OAUTH_AUTH_URL, OAUTH_REDIRECT_URI, OAUTH_SCOPES, OAUTH_TOKEN_URL};
pub struct AuthHandler {
client_id: String,
client_secret: String,
}
impl AuthHandler {
pub fn new(client_id: String, client_secret: String) -> Self {
Self {
client_id,
client_secret,
}
}
pub fn get_auth_url(&self) -> Result<(String, CsrfToken)> {
let client = self.create_oauth_client()?;
let mut auth_request = client.authorize_url(CsrfToken::new_random);
for scope in OAUTH_SCOPES {
auth_request = auth_request.add_scope(Scope::new((*scope).to_string()));
}
let (auth_url, csrf_token) = auth_request.url();
Ok((auth_url.to_string(), csrf_token))
}
pub async fn run_oauth_flow(&self) -> Result<String> {
let (auth_url, csrf_token) = self.get_auth_url()?;
let _ = webbrowser::open(&auth_url);
let code = self.capture_callback(csrf_token)?;
let token = self.exchange_code(&code).await?;
Ok(token)
}
fn capture_callback(&self, expected_csrf: CsrfToken) -> Result<String> {
let listener = TcpListener::bind("127.0.0.1:8080")
.context("Failed to bind to localhost:8080. Is another process using this port?")?;
let (mut stream, _) = listener
.accept()
.context("Failed to accept OAuth callback connection")?;
let mut reader = BufReader::new(&stream);
let mut request_line = String::new();
reader
.read_line(&mut request_line)
.context("Failed to read OAuth callback request")?;
let (code, state) = parse_callback_request(&request_line)?;
if state != *expected_csrf.secret() {
let response = create_error_response("CSRF token mismatch - possible security issue");
stream.write_all(response.as_bytes())?;
return Err(anyhow!(
"CSRF token mismatch - authorization may have been tampered with"
));
}
let response = create_success_response();
stream.write_all(response.as_bytes())?;
Ok(code)
}
async fn exchange_code(&self, code: &str) -> Result<String> {
let client = self.create_oauth_client()?;
let http_client = reqwest::Client::builder()
.redirect(Policy::none())
.build()
.context("Failed to create HTTP client")?;
let token_result = client
.exchange_code(AuthorizationCode::new(code.to_string()))
.request_async(&http_client)
.await
.context("Failed to exchange authorization code for token")?;
Ok(token_result.access_token().secret().clone())
}
#[allow(clippy::type_complexity)]
fn create_oauth_client(
&self,
) -> Result<
oauth2::Client<
oauth2::StandardErrorResponse<oauth2::basic::BasicErrorResponseType>,
oauth2::StandardTokenResponse<
oauth2::EmptyExtraTokenFields,
oauth2::basic::BasicTokenType,
>,
oauth2::StandardTokenIntrospectionResponse<
oauth2::EmptyExtraTokenFields,
oauth2::basic::BasicTokenType,
>,
oauth2::StandardRevocableToken,
oauth2::StandardErrorResponse<oauth2::RevocationErrorResponseType>,
oauth2::EndpointSet,
oauth2::EndpointNotSet,
oauth2::EndpointNotSet,
oauth2::EndpointNotSet,
oauth2::EndpointSet,
>,
> {
let client = BasicClient::new(ClientId::new(self.client_id.clone()))
.set_client_secret(ClientSecret::new(self.client_secret.clone()))
.set_auth_uri(
AuthUrl::new(OAUTH_AUTH_URL.to_string()).context("Invalid authorization URL")?,
)
.set_token_uri(TokenUrl::new(OAUTH_TOKEN_URL.to_string()).context("Invalid token URL")?)
.set_redirect_uri(
RedirectUrl::new(OAUTH_REDIRECT_URI.to_string()).context("Invalid redirect URI")?,
);
Ok(client)
}
}
fn parse_callback_request(request_line: &str) -> Result<(String, String)> {
let parts: Vec<&str> = request_line.split_whitespace().collect();
if parts.len() < 2 {
return Err(anyhow!("Invalid callback request format"));
}
let path = parts[1];
if path.contains("error=") {
let error_desc = extract_param(path, "error_description")
.or_else(|| extract_param(path, "error"))
.unwrap_or_else(|| "Unknown authorization error".to_string());
return Err(anyhow!("Authorization failed: {}", error_desc));
}
let code =
extract_param(path, "code").ok_or_else(|| anyhow!("No authorization code in callback"))?;
let state =
extract_param(path, "state").ok_or_else(|| anyhow!("No state parameter in callback"))?;
Ok((code, state))
}
fn extract_param(path: &str, param: &str) -> Option<String> {
let query = path.split('?').nth(1)?;
for pair in query.split('&') {
let mut kv = pair.split('=');
if let (Some(key), Some(value)) = (kv.next(), kv.next()) {
if key == param {
return Some(urlencoding_decode(value));
}
}
}
None
}
fn urlencoding_decode(s: &str) -> String {
s.replace("%20", " ")
.replace("%21", "!")
.replace("%2B", "+")
.replace("%3D", "=")
.replace("%26", "&")
.replace("%3F", "?")
.replace("%2F", "/")
.replace("%3A", ":")
}
fn create_success_response() -> String {
let body = r#"<!DOCTYPE html>
<html>
<head>
<title>Authorization Successful</title>
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
display: flex; justify-content: center; align-items: center; height: 100vh;
margin: 0; background: #f5f5f5; }
.container { text-align: center; padding: 40px; background: white;
border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }
h1 { color: #22c55e; margin-bottom: 16px; }
p { color: #666; }
</style>
</head>
<body>
<div class="container">
<h1>Authorization Successful</h1>
<p>You can close this window and return to the terminal.</p>
</div>
</body>
</html>"#;
format!(
"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
body.len(),
body
)
}
fn create_error_response(message: &str) -> String {
let body = format!(
r#"<!DOCTYPE html>
<html>
<head>
<title>Authorization Failed</title>
<style>
body {{ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
display: flex; justify-content: center; align-items: center; height: 100vh;
margin: 0; background: #f5f5f5; }}
.container {{ text-align: center; padding: 40px; background: white;
border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }}
h1 {{ color: #ef4444; margin-bottom: 16px; }}
p {{ color: #666; }}
</style>
</head>
<body>
<div class="container">
<h1>Authorization Failed</h1>
<p>{}</p>
</div>
</body>
</html>"#,
message
);
format!(
"HTTP/1.1 400 Bad Request\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
body.len(),
body
)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_callback_request_success() {
let request = "GET /?code=abc123&state=xyz789 HTTP/1.1";
let (code, state) = parse_callback_request(request).unwrap();
assert_eq!(code, "abc123");
assert_eq!(state, "xyz789");
}
#[test]
fn test_parse_callback_request_error() {
let request = "GET /?error=access_denied&error_description=User%20denied HTTP/1.1";
let result = parse_callback_request(request);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("Authorization failed"));
}
#[test]
fn test_parse_callback_request_missing_code() {
let request = "GET /?state=xyz789 HTTP/1.1";
let result = parse_callback_request(request);
assert!(result.is_err());
}
#[test]
fn test_parse_callback_request_missing_state() {
let request = "GET /?code=abc123 HTTP/1.1";
let result = parse_callback_request(request);
assert!(result.is_err());
}
#[test]
fn test_extract_param() {
let path = "/?code=abc&state=xyz&other=123";
assert_eq!(extract_param(path, "code"), Some("abc".to_string()));
assert_eq!(extract_param(path, "state"), Some("xyz".to_string()));
assert_eq!(extract_param(path, "other"), Some("123".to_string()));
assert_eq!(extract_param(path, "missing"), None);
}
#[test]
fn test_urlencoding_decode() {
assert_eq!(urlencoding_decode("hello%20world"), "hello world");
assert_eq!(urlencoding_decode("test%3Dvalue"), "test=value");
}
#[test]
fn test_create_success_response() {
let response = create_success_response();
assert!(response.starts_with("HTTP/1.1 200 OK"));
assert!(response.contains("Authorization Successful"));
}
#[test]
fn test_create_error_response() {
let response = create_error_response("Test error");
assert!(response.starts_with("HTTP/1.1 400 Bad Request"));
assert!(response.contains("Authorization Failed"));
assert!(response.contains("Test error"));
}
}