use std::time::Duration;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::TcpListener;
/// OAuth client credentials that are compiled into the binary.
///
/// Both fields are `&'static str` because the values are compile-time
/// constants (see [`builtin_credentials`]). The type is `Copy`, and its
/// `Debug` output never includes `client_secret`.
#[derive(Clone, Copy)]
pub struct OAuthCredentials {
    /// OAuth client identifier sent to the provider.
    pub client_id: &'static str,
    /// OAuth client secret paired with `client_id`.
    pub client_secret: &'static str,
}

impl std::fmt::Debug for OAuthCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Redact the secret so `{:?}` in logs or panics cannot expose it.
        f.debug_struct("OAuthCredentials")
            .field("client_id", &self.client_id)
            .field("client_secret", &"<redacted>")
            .finish()
    }
}
/// Google OAuth client ID.
///
/// Read from `IRONCLAW_GOOGLE_CLIENT_ID` at compile time (`option_env!`) when
/// that variable is set; otherwise the literal fallback below is used.
const GOOGLE_CLIENT_ID: &str = if let Some(id) = option_env!("IRONCLAW_GOOGLE_CLIENT_ID") {
    id
} else {
    "564604149681-efo25d43rs85v0tibdepsmdv5dsrhhr0.apps.googleusercontent.com"
};

/// Google OAuth client secret paired with [`GOOGLE_CLIENT_ID`].
///
/// Read from `IRONCLAW_GOOGLE_CLIENT_SECRET` at compile time when set. The
/// fallback literal is embedded in every build, so it must not be treated as
/// confidential.
const GOOGLE_CLIENT_SECRET: &str = if let Some(secret) = option_env!("IRONCLAW_GOOGLE_CLIENT_SECRET") {
    secret
} else {
    "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"
};
/// Returns the OAuth client credentials built into this binary for
/// `secret_name`, if any.
///
/// Only `"google_oauth_token"` has built-in credentials (the Google client id
/// and secret constants). Every other name yields `None`.
pub fn builtin_credentials(secret_name: &str) -> Option<OAuthCredentials> {
    const GOOGLE_SECRET_NAME: &str = "google_oauth_token";

    if secret_name != GOOGLE_SECRET_NAME {
        return None;
    }
    Some(OAuthCredentials {
        client_id: GOOGLE_CLIENT_ID,
        client_secret: GOOGLE_CLIENT_SECRET,
    })
}
/// Fixed local TCP port on which the OAuth callback listener is bound
/// (loopback only; see [`bind_callback_listener`]).
pub const OAUTH_CALLBACK_PORT: u16 = 9_876;
#[derive(Debug, thiserror::Error)]
pub enum OAuthCallbackError {
#[error("Port {0} is in use (another auth flow running?): {1}")]
PortInUse(u16, String),
#[error("Authorization denied by user")]
Denied,
#[error("Timed out waiting for authorization")]
Timeout,
#[error("IO error: {0}")]
Io(String),
}
/// Binds the OAuth redirect listener on the loopback interface at
/// [`OAUTH_CALLBACK_PORT`].
///
/// IPv4 `127.0.0.1` is tried first. IPv6 `::1` is tried only if the IPv4 bind
/// fails for a reason other than the port being taken.
///
/// # Errors
///
/// - [`OAuthCallbackError::PortInUse`] if the port is already bound on the
///   address being tried.
/// - [`OAuthCallbackError::Io`] if both binds fail for other reasons. The
///   message carries both the IPv4 and the IPv6 error, so the first failure is
///   not lost.
pub async fn bind_callback_listener() -> Result<TcpListener, OAuthCallbackError> {
    let addr_in_use = |e: &std::io::Error| e.kind() == std::io::ErrorKind::AddrInUse;

    let ipv4 = TcpListener::bind((std::net::Ipv4Addr::LOCALHOST, OAUTH_CALLBACK_PORT)).await;
    let ipv4_err = match ipv4 {
        Ok(listener) => return Ok(listener),
        // A taken port is a conflict to report, not a reason to try IPv6.
        Err(e) if addr_in_use(&e) => {
            return Err(OAuthCallbackError::PortInUse(OAUTH_CALLBACK_PORT, e.to_string()));
        }
        // Other IPv4 failures (e.g. no IPv4 loopback) fall back to IPv6, but
        // the error is kept for the final report.
        Err(e) => e,
    };

    TcpListener::bind((std::net::Ipv6Addr::LOCALHOST, OAUTH_CALLBACK_PORT))
        .await
        .map_err(|e| {
            if addr_in_use(&e) {
                OAuthCallbackError::PortInUse(OAUTH_CALLBACK_PORT, e.to_string())
            } else {
                OAuthCallbackError::Io(format!(
                    "binding 127.0.0.1:{OAUTH_CALLBACK_PORT} failed ({ipv4_err}); \
                     binding [::1]:{OAUTH_CALLBACK_PORT} failed ({e})"
                ))
            }
        })
}
/// Waits for the browser to be redirected to the local callback listener and
/// returns the value of one query parameter from that request.
///
/// Connections on `listener` are handled one at a time. Each request whose
/// path starts with `path_prefix` is handled as follows:
///
/// - If its query string has an `error` parameter, the browser gets a failure
///   page (`400`) and the wait ends with [`OAuthCallbackError::Denied`].
/// - Otherwise, if the query string has `param_name`, the browser gets the
///   success page and the percent-decoded value is returned. The raw value is
///   kept if decoding fails.
///
/// Every other request (for example a `/favicon.ico` probe) is answered with
/// `404 Not Found` and the wait continues. Connections that send no request
/// line within 10 seconds, drop early, or send non-UTF-8 data are closed and
/// skipped, so a stray connection can neither stall nor abort the flow.
///
/// # Errors
///
/// - [`OAuthCallbackError::Denied`] if the provider redirected with `error=...`.
/// - [`OAuthCallbackError::Timeout`] if no usable callback arrives within
///   300 seconds.
/// - [`OAuthCallbackError::Io`] if accepting a connection fails.
pub async fn wait_for_callback(
    listener: TcpListener,
    path_prefix: &str,
    param_name: &str,
    display_name: &str,
) -> Result<String, OAuthCallbackError> {
    // Overall budget for the user to finish the browser flow.
    const CALLBACK_TIMEOUT: Duration = Duration::from_secs(300);
    // Budget for one connection to deliver its request line. Connections are
    // served sequentially, so one that never sends data would otherwise block
    // every later connection until CALLBACK_TIMEOUT expires.
    const REQUEST_LINE_TIMEOUT: Duration = Duration::from_secs(10);

    tokio::time::timeout(CALLBACK_TIMEOUT, async {
        loop {
            let (mut socket, _) = listener
                .accept()
                .await
                .map_err(|e| OAuthCallbackError::Io(e.to_string()))?;

            let mut request_line = String::new();
            let mut reader = BufReader::new(&mut socket);
            let read =
                tokio::time::timeout(REQUEST_LINE_TIMEOUT, reader.read_line(&mut request_line))
                    .await;
            // Timed out, reset, or not UTF-8: drop this connection, keep waiting.
            if !matches!(read, Ok(Ok(_))) {
                continue;
            }

            if let Some(query) = callback_query(&request_line, path_prefix) {
                if query_param(query, "error").is_some() {
                    let page = landing_html(display_name, false);
                    send_and_close(&mut socket, &html_response("400 Bad Request", &page)).await;
                    return Err(OAuthCallbackError::Denied);
                }
                if let Some(raw) = query_param(query, param_name) {
                    let value = urlencoding::decode(raw)
                        .map(|decoded| decoded.into_owned())
                        .unwrap_or_else(|_| raw.to_owned());
                    let page = landing_html(display_name, true);
                    send_and_close(&mut socket, &html_response("200 OK", &page)).await;
                    return Ok(value);
                }
            }

            send_and_close(
                &mut socket,
                "HTTP/1.1 404 Not Found\r\nConnection: close\r\n\r\n",
            )
            .await;
        }
    })
    .await
    .map_err(|_| OAuthCallbackError::Timeout)?
}

/// Returns the query string of an HTTP request line such as
/// `GET /path?x=1 HTTP/1.1`, if its target starts with `path_prefix`.
fn callback_query<'a>(request_line: &'a str, path_prefix: &str) -> Option<&'a str> {
    let target = request_line.split_whitespace().nth(1)?;
    if !target.starts_with(path_prefix) {
        return None;
    }
    // Everything after the first `?`, including any further `?` characters.
    target.split_once('?').map(|(_, query)| query)
}

/// Returns the raw (still percent-encoded) value of the first `key=value`
/// pair in `query` whose key is exactly `name`.
fn query_param<'a>(query: &'a str, name: &str) -> Option<&'a str> {
    query
        .split('&')
        .filter_map(|pair| pair.split_once('='))
        .find_map(|(key, value)| (key == name).then_some(value))
}

/// Builds a complete `HTTP/1.1` response with the given status line suffix
/// (e.g. `"200 OK"`) carrying `body` as UTF-8 HTML.
fn html_response(status: &str, body: &str) -> String {
    format!(
        "HTTP/1.1 {status}\r\n\
         Content-Type: text/html; charset=utf-8\r\n\
         Content-Length: {}\r\n\
         Connection: close\r\n\
         \r\n\
         {body}",
        body.len()
    )
}

/// Best-effort write of `response` followed by a shutdown of the write side.
/// The browser may already have gone away, so failures are deliberately ignored.
async fn send_and_close(socket: &mut tokio::net::TcpStream, response: &str) {
    let _ = socket.write_all(response.as_bytes()).await;
    let _ = socket.shutdown().await;
}
/// Escapes `s` for safe interpolation into HTML text content and quoted
/// attribute values.
///
/// The five characters with special meaning in HTML (`&`, `<`, `>`, `"`,
/// `'`) are replaced by character references. Every other character,
/// including non-ASCII, is copied unchanged.
fn html_escape(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&#x27;"),
            _ => out.push(c),
        }
    }
    out
}
/// Renders the self-contained HTML page (inline CSS and SVG) shown in the
/// browser after the OAuth redirect has been handled.
///
/// `provider_name` is passed through `html_escape` before it is placed in the
/// page `<title>` and `<h1>`. When `success` is `true`, the page shows a green
/// check mark and the heading `"{provider_name} Connected"`. Otherwise it shows
/// a red cross and `"Authorization Failed"`.
pub fn landing_html(provider_name: &str, success: bool) -> String {
    // Circular badge markup for each outcome. The continuation lines of these
    // raw strings are page content, so they stay unindented.
    const SUCCESS_ICON: &str = r##"<div style="width:64px;height:64px;border-radius:50%;background:#22c55e;display:flex;align-items:center;justify-content:center;margin:0 auto 24px">
<svg width="32" height="32" viewBox="0 0 24 24" fill="none" stroke="#fff" stroke-width="3" stroke-linecap="round" stroke-linejoin="round"><polyline points="20 6 9 17 4 12"/></svg>
</div>"##;
    const FAILURE_ICON: &str = r##"<div style="width:64px;height:64px;border-radius:50%;background:#ef4444;display:flex;align-items:center;justify-content:center;margin:0 auto 24px">
<svg width="32" height="32" viewBox="0 0 24 24" fill="none" stroke="#fff" stroke-width="3" stroke-linecap="round" stroke-linejoin="round"><line x1="18" y1="6" x2="6" y2="18"/><line x1="6" y1="6" x2="18" y2="18"/></svg>
</div>"##;

    let (icon, heading, subtitle, accent) = match success {
        true => (
            SUCCESS_ICON,
            format!("{} Connected", html_escape(provider_name)),
            "You can close this window and return to your terminal.",
            "#22c55e",
        ),
        false => (
            FAILURE_ICON,
            String::from("Authorization Failed"),
            "The request was denied. You can close this window and try again.",
            "#ef4444",
        ),
    };

    // `{{`/`}}` are literal CSS braces; `{heading}`, `{icon}`, `{subtitle}`
    // and `{accent}` are captured from the locals above.
    format!(
        r#"<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width,initial-scale=1">
<title>IronClaw - {heading}</title>
<style>
* {{ margin:0; padding:0; box-sizing:border-box }}
body {{
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
background: #0a0a0a;
color: #e5e5e5;
display: flex;
justify-content: center;
align-items: center;
min-height: 100vh;
}}
.card {{
text-align: center;
padding: 48px 40px;
max-width: 420px;
border: 1px solid #262626;
border-radius: 16px;
background: #141414;
}}
h1 {{
font-size: 22px;
font-weight: 600;
margin-bottom: 8px;
color: #fafafa;
}}
p {{
font-size: 14px;
color: #a3a3a3;
line-height: 1.5;
}}
.accent {{ color: {accent}; }}
.brand {{
margin-top: 32px;
font-size: 12px;
color: #525252;
letter-spacing: 0.5px;
text-transform: uppercase;
}}
</style>
</head>
<body>
<div class="card">
{icon}
<h1>{heading}</h1>
<p>{subtitle}</p>
<div class="brand">IronClaw</div>
</div>
</body>
</html>"#
    )
}
#[cfg(test)]
mod tests {
    // `super` resolves to this module wherever it sits in the crate and also
    // exposes the private `html_escape` helper.
    use super::{builtin_credentials, html_escape, landing_html};

    #[test]
    fn test_unknown_provider_returns_none() {
        assert!(builtin_credentials("unknown_token").is_none());
    }

    #[test]
    fn test_google_returns_based_on_compile_env() {
        // Values come from IRONCLAW_GOOGLE_* at build time or the built-in
        // fallbacks; either way they must be non-empty.
        let creds = builtin_credentials("google_oauth_token")
            .expect("google_oauth_token should have built-in credentials");
        assert!(!creds.client_id.is_empty());
        assert!(!creds.client_secret.is_empty());
    }

    #[test]
    fn test_html_escape_replaces_special_characters() {
        let escaped = html_escape(r#"<b class="x">Tom & 'Jerry'</b>"#);
        assert!(escaped.starts_with("&lt;b class=&quot;x&quot;&gt;Tom &amp; "));
        assert!(escaped.ends_with("&lt;/b&gt;"));
        // Only require that no raw quote survives; the exact apostrophe
        // reference form is not pinned here.
        assert!(!escaped.contains('\''));
        assert!(!escaped.contains('"'));
        assert_eq!(html_escape("plain text"), "plain text");
        assert_eq!(html_escape(""), "");
    }

    #[test]
    fn test_landing_html_success_contains_key_elements() {
        let html = landing_html("Google", true);
        assert!(html.contains("Google Connected"));
        assert!(html.contains("charset"));
        assert!(html.contains("IronClaw"));
        assert!(html.contains("#22c55e"));
        assert!(!html.contains("Failed"));
    }

    #[test]
    fn test_landing_html_escapes_provider_name() {
        let html = landing_html("<script>alert(1)</script>", true);
        assert!(!html.contains("<script>"));
        assert!(html.contains("&lt;script&gt;alert(1)&lt;/script&gt;"));
    }

    #[test]
    fn test_landing_html_error_contains_key_elements() {
        let html = landing_html("Notion", false);
        assert!(html.contains("Authorization Failed"));
        assert!(html.contains("charset"));
        assert!(html.contains("IronClaw"));
        assert!(html.contains("#ef4444"));
        assert!(!html.contains("Connected"));
    }
}