use super::types::OAuthError;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::time::{timeout, Duration};
#[derive(Debug, Clone)]
pub struct CallbackResult {
pub code: String,
#[allow(dead_code)]
pub state: String,
}
pub async fn run_callback_server(
port: u16,
expected_state: String,
timeout_secs: u64,
) -> Result<CallbackResult, OAuthError> {
let bind_addr = format!("127.0.0.1:{}", port);
let listener = TcpListener::bind(&bind_addr)
.await
.map_err(|e| OAuthError::ServerError(format!("Failed to bind to port {}: {}", port, e)))?;
let actual_port = listener.local_addr().map(|a| a.port()).unwrap_or(port);
println!(
"Listening for OAuth callback on http://127.0.0.1:{}",
actual_port
);
let result: Arc<Mutex<Option<Result<CallbackResult, OAuthError>>>> = Arc::new(Mutex::new(None));
let server_result = result.clone();
let server_task = async move {
loop {
let (mut socket, _) = match listener.accept().await {
Ok(conn) => conn,
Err(e) => {
let mut res = server_result.lock().unwrap();
*res = Some(Err(OAuthError::ServerError(format!(
"Failed to accept connection: {}",
e
))));
break;
}
};
let mut buffer = vec![0; 4096];
let n = match socket.read(&mut buffer).await {
Ok(n) if n > 0 => n,
_ => continue,
};
let request = String::from_utf8_lossy(&buffer[..n]);
if let Some(first_line) = request.lines().next() {
if let Some(path_part) = first_line.split_whitespace().nth(1) {
if let Some(query_start) = path_part.find('?') {
let query = &path_part[query_start + 1..];
let params = parse_query_string(query);
let response = if let (Some(code), Some(state)) =
(params.get("code"), params.get("state"))
{
if state != &expected_state {
let mut res = server_result.lock().unwrap();
*res = Some(Err(OAuthError::StateMismatch {
expected: expected_state.clone(),
actual: state.clone(),
}));
create_error_response("State mismatch - possible CSRF attack")
} else {
let mut res = server_result.lock().unwrap();
*res = Some(Ok(CallbackResult {
code: code.clone(),
state: state.clone(),
}));
create_success_response()
}
} else if let Some(error) = params.get("error") {
let mut res = server_result.lock().unwrap();
*res = Some(Err(OAuthError::SlackError(error.clone())));
create_error_response(&format!("OAuth error: {}", error))
} else {
create_error_response("Missing required parameters")
};
let _ = socket.write_all(response.as_bytes()).await;
let _ = socket.flush().await;
break;
}
}
}
}
};
match timeout(Duration::from_secs(timeout_secs), server_task).await {
Ok(_) => {
let res = result.lock().unwrap();
match res.as_ref() {
Some(Ok(callback_result)) => Ok(callback_result.clone()),
Some(Err(e)) => Err(format_oauth_error(e)),
None => Err(OAuthError::ServerError("No result received".to_string())),
}
}
Err(_) => Err(OAuthError::ServerError(format!(
"Timeout after {} seconds waiting for callback",
timeout_secs
))),
}
}
fn format_oauth_error(err: &OAuthError) -> OAuthError {
match err {
OAuthError::ConfigError(msg) => OAuthError::ConfigError(msg.clone()),
OAuthError::NetworkError(msg) => OAuthError::NetworkError(msg.clone()),
OAuthError::HttpError(code, msg) => OAuthError::HttpError(*code, msg.clone()),
OAuthError::ParseError(msg) => OAuthError::ParseError(msg.clone()),
OAuthError::SlackError(msg) => OAuthError::SlackError(msg.clone()),
OAuthError::StateMismatch { expected, actual } => OAuthError::StateMismatch {
expected: expected.clone(),
actual: actual.clone(),
},
OAuthError::ServerError(msg) => OAuthError::ServerError(msg.clone()),
OAuthError::BrowserError(msg) => OAuthError::BrowserError(msg.clone()),
}
}
fn parse_query_string(query: &str) -> HashMap<String, String> {
query
.split('&')
.filter_map(|pair| {
let mut parts = pair.split('=');
match (parts.next(), parts.next()) {
(Some(key), Some(value)) => Some((
key.to_string(),
urlencoding::decode(value).ok()?.to_string(),
)),
_ => None,
}
})
.collect()
}
fn create_success_response() -> String {
"HTTP/1.1 200 OK\r\n\
Content-Type: text/html; charset=utf-8\r\n\
Connection: close\r\n\
\r\n\
<html>\
<head><title>Authentication Successful</title></head>\
<body>\
<h1>✓ Authentication Successful</h1>\
<p>You can close this window and return to the CLI.</p>\
</body>\
</html>"
.to_string()
}
fn create_error_response(message: &str) -> String {
format!(
"HTTP/1.1 400 Bad Request\r\n\
Content-Type: text/html; charset=utf-8\r\n\
Connection: close\r\n\
\r\n\
<html>\
<head><title>Authentication Failed</title></head>\
<body>\
<h1>✗ Authentication Failed</h1>\
<p>{}</p>\
</body>\
</html>",
message
)
}
mod urlencoding {
pub fn decode(s: &str) -> Result<String, ()> {
let mut result = String::new();
let mut chars = s.chars();
while let Some(c) = chars.next() {
match c {
'%' => {
let hex: String = chars.by_ref().take(2).collect();
if hex.len() == 2 {
if let Ok(byte) = u8::from_str_radix(&hex, 16) {
result.push(byte as char);
} else {
return Err(());
}
} else {
return Err(());
}
}
'+' => result.push(' '),
c => result.push(c),
}
}
Ok(result)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_query_string() {
let query = "code=test_code&state=test_state&foo=bar";
let params = parse_query_string(query);
assert_eq!(params.get("code"), Some(&"test_code".to_string()));
assert_eq!(params.get("state"), Some(&"test_state".to_string()));
assert_eq!(params.get("foo"), Some(&"bar".to_string()));
}
#[test]
fn test_parse_query_string_with_encoding() {
let query = "message=hello+world&name=test%20user";
let params = parse_query_string(query);
assert_eq!(params.get("message"), Some(&"hello world".to_string()));
assert_eq!(params.get("name"), Some(&"test user".to_string()));
}
#[tokio::test]
async fn test_callback_server_timeout() {
let state = "test_state".to_string();
let result = run_callback_server(0, state, 1).await;
assert!(result.is_err());
match result {
Err(OAuthError::ServerError(msg)) => {
assert!(msg.contains("Timeout"));
}
_ => panic!("Expected ServerError with timeout"),
}
}
}