use anyhow::{Context, Result};
use std::net::TcpListener;
use tokio::sync::oneshot;
/// First port probed when looking for a free local OAuth callback port.
const DEFAULT_PORT_RANGE_START: u16 = 8787;
/// Last port (inclusive) probed when looking for a free local OAuth callback port.
const DEFAULT_PORT_RANGE_END: u16 = DEFAULT_PORT_RANGE_START + 100;
/// Parameters delivered to the local callback endpoint by the OAuth provider's redirect.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthCallbackData {
    /// Value of the `code` query parameter (the authorization code).
    pub code: String,
    /// Value of the `state` query parameter, to be checked against the value sent
    /// in the authorization request.
    pub state: String,
    /// Value of the optional `url` query parameter, if present.
    /// NOTE(review): meaning is provider/application specific — confirm with the caller.
    pub callback_url: Option<String>,
}
pub struct OAuthCallbackServer {
port: u16,
shutdown_tx: oneshot::Sender<()>,
}
impl OAuthCallbackServer {
pub fn new(port: u16) -> Self {
let (shutdown_tx, _) = oneshot::channel();
Self {
port,
shutdown_tx,
}
}
pub fn with_available_port() -> Result<Self> {
let port = find_available_port(DEFAULT_PORT_RANGE_START, DEFAULT_PORT_RANGE_END)
.context("No available port in callback range")?;
Ok(Self::new(port))
}
pub fn redirect_uri(&self) -> String {
format!("http://localhost:{}/callback", self.port)
}
pub fn port(&self) -> u16 {
self.port
}
pub async fn start(self) -> Result<OAuthCallbackData> {
let listener = TcpListener::bind(("127.0.0.1", self.port))
.context(format!("Failed to bind to port {}", self.port))?;
listener.set_nonblocking(true)?;
let (tx, rx) = oneshot::channel::<Result<OAuthCallbackData, OAuthError>>();
tokio::task::spawn_local(async move {
if let Err(e) = run_server(listener, tx).await {
eprintln!("OAuth callback server error: {}", e);
}
});
let result = rx.await.map_err(|e| anyhow::anyhow!("OAuth callback error: {}", e))?;
result.map_err(|e| anyhow::anyhow!("OAuth error: {}", e))
}
pub fn is_running(&self) -> bool {
!self.shutdown_tx.is_closed()
}
}
/// Errors that can end the local OAuth callback flow.
///
/// `Display` text comes from each variant's `#[error(...)]` attribute.
#[derive(Debug, thiserror::Error)]
pub enum OAuthError {
    /// Underlying I/O failure; `#[from]` lets `?` convert a `std::io::Error` into it.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// A callback request that could not be accepted; the string is shown in the message.
    #[error("Invalid callback URL: {0}")]
    InvalidCallback(String),
    /// The callback carried no authorization code.
    #[error("Missing authorization code")]
    MissingCode,
    /// The callback carried no `state` parameter.
    #[error("Missing state parameter")]
    MissingState,
    /// The server stopped before delivering a callback.
    /// NOTE(review): not constructed in this file — confirm it is still needed.
    #[error("Server shutdown")]
    Shutdown,
    /// No callback arrived before the callback server's wait expired.
    #[error("Callback timeout")]
    Timeout,
    /// URL parsing failure; `#[from]` converts a `url::ParseError` into it.
    #[error("HTTP parse error: {0}")]
    HttpParse(#[from] url::ParseError),
}
/// Returns the first port in `start..=end` (inclusive) that can be bound on
/// `127.0.0.1` right now. Returns `None` when every port is taken or the range is
/// empty (`start > end`).
///
/// Each probe listener is dropped as soon as the bind succeeds, so the port is free
/// again when this returns. Another process may grab it before the caller binds it.
fn find_available_port(start: u16, end: u16) -> Option<u16> {
    (start..=end).find(|&candidate| {
        let probe = TcpListener::bind(("127.0.0.1", candidate));
        probe.is_ok()
    })
}
/// Serves HTTP on `listener` until the OAuth redirect arrives, then reports the
/// outcome through `tx`.
///
/// Connections are handled one at a time:
/// - A `GET /callback?...` request accepted by `parse_oauth_callback` gets a success
///   page and resolves the flow with `Ok(data)`.
/// - Any other request whose target starts with `/callback` (for example a provider
///   redirect carrying `error=access_denied`) gets a 400 page. It resolves the flow
///   with `Err(OAuthError::InvalidCallback(target))`.
/// - Requests for other paths (such as `/favicon.ico`) get a 404 and are otherwise
///   ignored.
/// - Connections that fail, close, or send nothing for a few seconds are ignored.
///   Browsers may open speculative pre-connections.
/// - If nothing resolves the flow within 600 seconds, `Err(OAuthError::Timeout)` is
///   sent.
///
/// Every path sends exactly one outcome. A failure to register `listener` with the
/// tokio runtime is sent as `OAuthError::Io`, so this function itself returns `Ok(())`.
async fn run_server(
    listener: TcpListener,
    tx: oneshot::Sender<Result<OAuthCallbackData, OAuthError>>,
) -> Result<()> {
    let outcome = match tokio::net::TcpListener::from_std(listener) {
        Ok(listener) => {
            let overall_timeout = std::time::Duration::from_secs(600);
            tokio::time::timeout(overall_timeout, serve_until_callback(&listener))
                .await
                .unwrap_or_else(|_elapsed| Err(OAuthError::Timeout))
        }
        Err(e) => Err(OAuthError::Io(e)),
    };
    // The receiver may already be gone (the caller stopped waiting); nothing to do then.
    let _ = tx.send(outcome);
    Ok(())
}

/// Accepts connections until one of them resolves the OAuth flow, returning its outcome.
async fn serve_until_callback(
    listener: &tokio::net::TcpListener,
) -> Result<OAuthCallbackData, OAuthError> {
    const SUCCESS_RESPONSE: &str = "HTTP/1.1 200 OK\r\n\
        Content-Type: text/html; charset=utf-8\r\n\
        Connection: close\r\n\
        \r\n\
        <!DOCTYPE html>\
        <html><head><title>OAuth Callback</title></head>\
        <body style=\"font-family: system-ui; padding: 40px; text-align: center;\">\
        <h2>Authentication Successful</h2>\
        <p>You can close this window and return to the terminal.</p>\
        <script>window.close();</script>\
        </body></html>";
    const BAD_REQUEST_RESPONSE: &str = "HTTP/1.1 400 Bad Request\r\n\
        Content-Type: text/html\r\n\
        Connection: close\r\n\
        \r\n\
        <!DOCTYPE html>\
        <html><head><title>OAuth Error</title></head>\
        <body><h2>Invalid OAuth Callback</h2></body></html>";
    const NOT_FOUND_RESPONSE: &str = "HTTP/1.1 404 Not Found\r\n\
        Content-Type: text/plain; charset=utf-8\r\n\
        Connection: close\r\n\
        \r\n\
        Not Found";
    // A real browser sends its request right after connecting; this only bounds how
    // long an idle connection can hold up the (sequential) accept loop.
    let read_timeout = std::time::Duration::from_secs(5);

    loop {
        let mut stream = match listener.accept().await {
            Ok((stream, _)) => stream,
            Err(e) => {
                // Accept errors concern a single connection or transient resource
                // exhaustion. Back off briefly and keep waiting for the redirect.
                eprintln!("Connection error: {}", e);
                tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                continue;
            }
        };

        let request = match tokio::time::timeout(read_timeout, read_request_head(&mut stream)).await
        {
            Ok(Ok(request)) if !request.is_empty() => request,
            // Failed, closed, or idle connection: not the redirect, keep listening.
            _ => continue,
        };

        if let Some(callback_data) = parse_oauth_callback(&request) {
            write_http_response(&mut stream, SUCCESS_RESPONSE).await;
            return Ok(callback_data);
        }

        let target = request
            .lines()
            .next()
            .and_then(|line| line.split_whitespace().nth(1))
            .unwrap_or("");
        if target.starts_with("/callback") {
            // The redirect itself arrived but is unusable (e.g. the user denied access).
            // Waiting longer cannot help, so end the flow with the received target.
            write_http_response(&mut stream, BAD_REQUEST_RESPONSE).await;
            return Err(OAuthError::InvalidCallback(target.to_string()));
        }

        write_http_response(&mut stream, NOT_FOUND_RESPONSE).await;
    }
}

/// Reads the head of an HTTP request and returns it as lossily-decoded UTF-8.
/// Reading stops at the blank line that ends the headers, at EOF, or after 4 KiB.
async fn read_request_head(stream: &mut tokio::net::TcpStream) -> std::io::Result<String> {
    use tokio::io::AsyncReadExt;

    let mut buf = [0u8; 4096];
    let mut len = 0;
    while len < buf.len() {
        let n = stream.read(&mut buf[len..]).await?;
        if n == 0 {
            break;
        }
        len += n;
        if buf[..len].windows(4).any(|w| w == b"\r\n\r\n") {
            break;
        }
    }
    Ok(String::from_utf8_lossy(&buf[..len]).into_owned())
}

/// Writes a complete HTTP response and shuts down the write half, ignoring I/O errors.
async fn write_http_response(stream: &mut tokio::net::TcpStream, response: &str) {
    use tokio::io::AsyncWriteExt;

    // Best effort: the flow's outcome is decided by the request itself. A browser
    // that goes away before reading the page must not change it.
    let _ = stream.write_all(response.as_bytes()).await;
    let _ = stream.flush().await;
    let _ = stream.shutdown().await;
}
/// Extracts the OAuth redirect parameters from a raw HTTP request.
///
/// Only `GET` requests whose target starts with `/callback` are accepted. The query
/// string (everything after the first `?`) is split on `&`. Each pair is split at its
/// first `=` (a bare `key` has an empty value), and both halves are decoded as
/// `application/x-www-form-urlencoded`: `+` becomes a space and `%XX` becomes a byte.
/// Recognized keys are `code`, `state` and `url` (stored as `callback_url`). Other
/// keys are ignored, and a repeated key keeps its last value.
///
/// Returns `None` when the request is not such a `GET`, has no query string, has no
/// non-empty `code`, or has no `state`.
fn parse_oauth_callback(request: &str) -> Option<OAuthCallbackData> {
    let target = request
        .lines()
        .next()?
        .strip_prefix("GET ")?
        .split_whitespace()
        .next()?;
    if !target.starts_with("/callback") {
        return None;
    }
    // Split at the first `?` only: `?` is legal inside the query itself.
    let (_, query) = target.split_once('?')?;

    let mut code = None;
    let mut state = None;
    let mut callback_url = None;
    for pair in query.split('&').filter(|pair| !pair.is_empty()) {
        // Split at the first `=` so values containing unescaped `=` (e.g. base64
        // padding) stay intact; a valueless flag must not abort the whole parse.
        let (raw_key, raw_value) = pair.split_once('=').unwrap_or((pair, ""));
        let value = decode_query_component(raw_value);
        match decode_query_component(raw_key).as_str() {
            "code" => code = Some(value),
            "state" => state = Some(value),
            "url" => callback_url = Some(value),
            _ => {}
        }
    }

    Some(OAuthCallbackData {
        // An empty authorization code can never be exchanged; treat it as missing.
        code: code.filter(|c: &String| !c.is_empty())?,
        state: state?,
        callback_url,
    })
}

/// Decodes one `application/x-www-form-urlencoded` component.
///
/// `+` becomes a space and each valid `%XX` escape (hex digits in either case) becomes
/// its byte. Malformed escapes are kept verbatim, and bytes that are not valid UTF-8
/// are replaced with U+FFFD.
fn decode_query_component(raw: &str) -> String {
    fn hex_digit(byte: u8) -> Option<u8> {
        char::from(byte)
            .to_digit(16)
            .and_then(|d| u8::try_from(d).ok())
    }

    let bytes = raw.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'+' => {
                decoded.push(b' ');
                i += 1;
            }
            b'%' => {
                let escaped = bytes
                    .get(i + 1..i + 3)
                    .and_then(|hex| Some((hex_digit(hex[0])? << 4) | hex_digit(hex[1])?));
                match escaped {
                    Some(byte) => {
                        decoded.push(byte);
                        i += 3;
                    }
                    None => {
                        decoded.push(b'%');
                        i += 1;
                    }
                }
            }
            other => {
                decoded.push(other);
                i += 1;
            }
        }
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
/// Launches the platform's handler for `url` (normally the default web browser).
///
/// Returns the spawned launcher process; the caller decides whether to wait on it.
/// A successful spawn only proves the launcher started, not that a window opened.
///
/// Platform strategy:
/// - Windows: `rundll32 url.dll,FileProtocolHandler <url>`, which hands the URL to
///   the shell without going through `cmd.exe`. The old `cmd /C start` parser treated
///   the `&` between query parameters as a command separator and truncated the URL.
/// - macOS: `open <url>`.
/// - Other Unix systems: the first launcher from a fixed list that can be spawned.
/// - Any other platform: an `Unsupported` error.
///
/// # Errors
/// Returns the spawn error on Windows and macOS. Returns `NotFound` when no Unix
/// launcher could be spawned, and `Unsupported` on other platforms.
pub fn open_browser(url: &str) -> std::io::Result<std::process::Child> {
    #[cfg(target_os = "windows")]
    {
        std::process::Command::new("rundll32")
            .arg("url.dll,FileProtocolHandler")
            .arg(url)
            .spawn()
    }
    #[cfg(target_os = "macos")]
    {
        std::process::Command::new("open").arg(url).spawn()
    }
    #[cfg(all(unix, not(target_os = "macos")))]
    {
        const LAUNCHERS: [&str; 6] = [
            "xdg-open",
            "gnome-open",
            "kde-open",
            "x-www-browser",
            "firefox",
            "google-chrome",
        ];
        LAUNCHERS
            .iter()
            .find_map(|launcher| std::process::Command::new(launcher).arg(url).spawn().ok())
            .ok_or_else(|| {
                std::io::Error::new(std::io::ErrorKind::NotFound, "No suitable browser found")
            })
    }
    #[cfg(not(any(target_os = "windows", unix)))]
    {
        Err(std::io::Error::new(
            std::io::ErrorKind::Unsupported,
            "Unsupported platform",
        ))
    }
}
/// Opens `auth_url` in the user's browser and waits for the OAuth redirect on a
/// local callback server.
///
/// The callback port is chosen here (the first free port in the default range), so
/// `auth_url` must already carry a `redirect_uri` that points at that port.
/// NOTE(review): the caller cannot learn the chosen port before building `auth_url`.
/// The two only line up when the first free port is the one the caller assumed, so
/// consider building the URL from [`OAuthCallbackServer::redirect_uri`].
///
/// # Errors
/// Fails if no callback port is free, if the browser cannot be launched, or if
/// [`OAuthCallbackServer::start`] fails while waiting for the callback.
pub async fn authorize_with_browser(auth_url: &str) -> Result<OAuthCallbackData> {
    // Pick the callback port before launching the browser. Otherwise a failure here
    // would leave the user on an authorization page that can never redirect back.
    let server = OAuthCallbackServer::with_available_port()
        .context("Failed to create callback server")?;
    let port = server.port();
    open_browser(auth_url).map_err(|e| anyhow::anyhow!("Failed to open browser: {}", e))?;
    tracing::info!("OAuth callback server listening on port {}", port);
    server.start().await
}