use crate::runtime::{Result, RobomotionError};
use std::net::TcpListener;
use std::time::Duration;
use tokio::sync::mpsc;
/// Redirect URI sent as `redirect_uri` in the authorization request. Its port must stay in
/// sync with `OAUTH2_CALLBACK_PORT`, and its path with the `/oauth2/callback` check in the
/// callback parsers. Presumably it must also be registered with each provider — confirm.
pub const OAUTH2_REDIRECT_URL: &str = "http://localhost:9876/oauth2/callback";
/// Port the callback listener binds on `127.0.0.1`; must match the port in `OAUTH2_REDIRECT_URL`.
pub const OAUTH2_CALLBACK_PORT: u16 = 9876;
/// Delay between bind attempts while the callback port is unavailable.
pub const OAUTH2_PORT_RETRY_INTERVAL: Duration = Duration::from_secs(5);
/// How long to keep retrying the bind before giving up.
pub const OAUTH2_PORT_MAX_TIMEOUT: Duration = Duration::from_secs(300);
/// How long `open_oauth_dialog` waits for the browser to be redirected back to the callback.
pub const OAUTH2_AUTH_TIMEOUT: Duration = Duration::from_secs(300);
/// OAuth2 client settings for the authorization-code flow started by `open_oauth_dialog`.
///
/// Only `Clone` is derived — there is no `Debug`. NOTE(review): presumably this keeps
/// `client_secret` out of debug/log output; confirm before adding a derive.
#[derive(Clone)]
pub struct OAuth2Config {
/// Client identifier, sent (percent-encoded) as `client_id` in the authorization URL.
pub client_id: String,
/// Client secret. Not used when building the authorization URL; presumably consumed by a
/// token exchange elsewhere — confirm.
pub client_secret: String,
/// Provider authorization endpoint the user's browser is sent to.
pub auth_url: String,
/// Provider token endpoint. Not used when building the authorization URL; presumably
/// consumed by a token exchange elsewhere — confirm.
pub token_url: String,
/// Requested scopes; joined with single spaces into the `scope` parameter.
pub scopes: Vec<String>,
}
pub async fn open_oauth_dialog(config: &OAuth2Config) -> Result<String> {
let listener = acquire_oauth_port().await?;
let (tx, mut rx) = mpsc::channel::<std::result::Result<String, String>>(1);
let auth_url = format!(
"{}?client_id={}&redirect_uri={}&response_type=code&scope={}&access_type=offline&prompt=consent&state=state",
config.auth_url,
urlencoding::encode(&config.client_id),
urlencoding::encode(OAUTH2_REDIRECT_URL),
urlencoding::encode(&config.scopes.join(" "))
);
if let Err(e) = webbrowser::open(&auth_url) {
return Err(RobomotionError::OAuth(format!(
"Failed to open browser: {}",
e
)));
}
let tx_clone = tx.clone();
let server_handle = tokio::spawn(async move {
let listener = tokio::net::TcpListener::from_std(listener).unwrap();
loop {
match listener.accept().await {
Ok((mut stream, _)) => {
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
let mut reader = BufReader::new(&mut stream);
let mut request_line = String::new();
if reader.read_line(&mut request_line).await.is_ok() {
if let Some(code) = extract_code_from_request(&request_line) {
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n{}",
SUCCESS_HTML
);
let _ = stream.write_all(response.as_bytes()).await;
let _ = tx_clone.send(Ok(code)).await;
break;
} else if let Some(error) = extract_error_from_request(&request_line) {
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n{}",
ERROR_HTML
);
let _ = stream.write_all(response.as_bytes()).await;
let _ = tx_clone.send(Err(error)).await;
break;
}
}
}
Err(_) => break,
}
}
});
let result = tokio::select! {
result = rx.recv() => {
match result {
Some(Ok(code)) => Ok(code),
Some(Err(e)) => Err(RobomotionError::OAuth(e)),
None => Err(RobomotionError::OAuth("Channel closed".to_string())),
}
}
_ = tokio::time::sleep(OAUTH2_AUTH_TIMEOUT) => {
Err(RobomotionError::OAuth(format!(
"OAuth authorization timed out after {:?} waiting for user to complete authorization",
OAUTH2_AUTH_TIMEOUT
)))
}
};
server_handle.abort();
result
}
/// Binds the OAuth2 callback listener on `127.0.0.1:OAUTH2_CALLBACK_PORT`, waiting for the
/// port to become free if necessary.
///
/// A failed bind is retried every `OAUTH2_PORT_RETRY_INTERVAL` until `OAUTH2_PORT_MAX_TIMEOUT`
/// has elapsed. The last sleep is shortened so the final attempt happens at the deadline
/// rather than after it. The returned listener is already in non-blocking mode, which
/// `tokio::net::TcpListener::from_std` requires of its caller.
///
/// # Errors
///
/// Returns `RobomotionError::OAuth` if:
/// - the port is still unavailable at the deadline (the message includes the last OS error);
/// - the listener cannot be switched to non-blocking mode.
async fn acquire_oauth_port() -> Result<std::net::TcpListener> {
    let addr = (std::net::Ipv4Addr::LOCALHOST, OAUTH2_CALLBACK_PORT);
    let start_time = std::time::Instant::now();
    let mut attempt: u32 = 0;
    loop {
        attempt += 1;
        let bind_error = match TcpListener::bind(addr) {
            Ok(listener) => {
                // A blocking socket handed to tokio would stall a runtime worker on `accept`,
                // so this failure must not be ignored.
                listener.set_nonblocking(true).map_err(|e| {
                    RobomotionError::OAuth(format!(
                        "Failed to configure OAuth callback listener on port {}: {}",
                        OAUTH2_CALLBACK_PORT, e
                    ))
                })?;
                if attempt > 1 {
                    tracing::info!(
                        "OAuth: Successfully acquired port {} after {} attempts (elapsed: {:?})",
                        OAUTH2_CALLBACK_PORT,
                        attempt,
                        start_time.elapsed()
                    );
                }
                return Ok(listener);
            }
            Err(e) => e,
        };

        let elapsed = start_time.elapsed();
        if elapsed >= OAUTH2_PORT_MAX_TIMEOUT {
            return Err(RobomotionError::OAuth(format!(
                "Could not bind to port {} after {:?} ({} attempts). \
                Port is in use by another process. Last error: {}",
                OAUTH2_CALLBACK_PORT, OAUTH2_PORT_MAX_TIMEOUT, attempt, bind_error
            )));
        }
        if attempt == 1 {
            tracing::info!(
                "OAuth: Port {} is busy ({}), waiting for it to become available...",
                OAUTH2_CALLBACK_PORT,
                bind_error
            );
        }
        // `elapsed < OAUTH2_PORT_MAX_TIMEOUT` here, so the subtraction cannot underflow.
        let remaining = OAUTH2_PORT_MAX_TIMEOUT - elapsed;
        tokio::time::sleep(OAUTH2_PORT_RETRY_INTERVAL.min(remaining)).await;
    }
}
/// Pulls the authorization `code` out of an HTTP request line aimed at the OAuth2 callback.
///
/// The line must mention `/oauth2/callback`. The query is the text between the first `?` and
/// the next space. The first `code=` parameter is percent-decoded and returned. Returns `None`
/// if any of those pieces is missing, or if that first `code` value is not valid UTF-8 once
/// decoded.
fn extract_code_from_request(request: &str) -> Option<String> {
    if !request.contains("/oauth2/callback") {
        return None;
    }
    let (_, after_mark) = request.split_once('?')?;
    let (query, _) = after_mark.split_once(' ')?;
    let raw_code = query
        .split('&')
        .filter_map(|param| param.split_once('='))
        .find(|(key, _)| *key == "code")
        .map(|(_, value)| value)?;
    urlencoding::decode(raw_code)
        .ok()
        .map(|decoded| decoded.into_owned())
}
fn extract_error_from_request(request: &str) -> Option<String> {
if request.contains("/oauth2/callback") {
if let Some(query_start) = request.find('?') {
let query = &request[query_start + 1..];
if let Some(end) = query.find(' ') {
let query = &query[..end];
let mut error = None;
let mut error_desc = None;
for param in query.split('&') {
let parts: Vec<&str> = param.splitn(2, '=').collect();
if parts.len() == 2 {
match parts[0] {
"error" => error = urlencoding::decode(parts[1]).ok().map(|s| s.into_owned()),
"error_description" => {
error_desc = urlencoding::decode(parts[1]).ok().map(|s| s.into_owned())
}
_ => {}
}
}
}
if let Some(err) = error {
return Some(match error_desc {
Some(desc) => format!("{}: {}", err, desc),
None => err,
});
}
}
}
}
None
}
/// Page returned to the browser when the OAuth2 callback carries an authorization code.
/// Self-contained (inline CSS, no external assets), so it renders from the one response.
const SUCCESS_HTML: &str = r#"<!DOCTYPE html>
<html>
<head>
<title>Authorization Successful</title>
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; display: flex; justify-content: center; align-items: center; height: 100vh; margin: 0; background: #f5f5f5; }
.container { text-align: center; padding: 40px; background: white; border-radius: 8px; box-shadow: 0 2px 10px rgba(0,0,0,0.1); }
h1 { color: #4CAF50; margin-bottom: 16px; }
p { color: #666; }
</style>
</head>
<body>
<div class="container">
<h1>Authorization Successful</h1>
<p>You can close this window and return to the application.</p>
</div>
</body>
</html>"#;
/// Page returned to the browser when the OAuth2 callback does not succeed, e.g. when it
/// carries an `error` parameter. Self-contained (inline CSS, no external assets).
const ERROR_HTML: &str = r#"<!DOCTYPE html>
<html>
<head>
<title>Authorization Failed</title>
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; display: flex; justify-content: center; align-items: center; height: 100vh; margin: 0; background: #f5f5f5; }
.container { text-align: center; padding: 40px; background: white; border-radius: 8px; box-shadow: 0 2px 10px rgba(0,0,0,0.1); }
h1 { color: #f44336; margin-bottom: 16px; }
p { color: #666; }
</style>
</head>
<body>
<div class="container">
<h1>Authorization Failed</h1>
<p>An error occurred during authorization.</p>
<p>You can close this window and try again.</p>
</div>
</body>
</html>"#;