use axum::Router;
use axum::extract::Query;
use axum::response::Html;
use axum::routing::get;
use miette::Result;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
/// Outcome of one callback request: `Ok` carries the value of the `code` query
/// parameter, `Err` carries a human-readable failure message.
type CallbackResult = Result<String, String>;
/// Shared handle to the sending half of the channel on which the callback
/// handler reports its [`CallbackResult`].
type CallbackSender = Arc<mpsc::Sender<CallbackResult>>;
/// Starts a loopback-only HTTP server that waits for a request on `/callback`.
///
/// The server listens on `127.0.0.1:<port>` and runs on a spawned Tokio task.
/// Each `GET /callback` request is passed to `handle_callback`, which reports
/// its outcome on a channel with room for a single buffered message.
///
/// Returns the receiving half of that channel together with the handle of the
/// serving task. If the server stops with an error, the error is printed to
/// stderr from inside that task.
///
/// # Errors
///
/// Returns a diagnostic with code `ez_token::server::bind` when the listener
/// cannot be bound to the requested port.
pub async fn start_local_server(
    port: u16,
) -> Result<(mpsc::Receiver<CallbackResult>, JoinHandle<()>)> {
    // Claim the port up front; nothing else is useful if this fails.
    let loopback = SocketAddr::from(([127, 0, 0, 1], port));
    let listener = match TcpListener::bind(loopback).await {
        Ok(bound) => bound,
        Err(bind_err) => {
            return Err(miette::miette!(
                help = "Try using the `--port` flag (e.g., `ez-token login --port 3001`)",
                code = "ez_token::server::bind",
                "Failed to bind to port {}: {}",
                port,
                bind_err
            ));
        }
    };

    let (sender, receiver) = mpsc::channel::<CallbackResult>(1);
    let sender: CallbackSender = Arc::new(sender);
    let router = Router::new().route(
        "/callback",
        get(move |query| handle_callback(query, sender)),
    );

    let server_task = tokio::spawn(async move {
        let outcome = axum::serve(listener, router).await;
        if let Err(serve_err) = outcome {
            eprintln!("Local server error: {}", serve_err);
        }
    });

    Ok((receiver, server_task))
}
/// Handles one request to `/callback`: reports the outcome on `tx` and returns
/// the HTML page shown in the browser.
///
/// The query parameters are interpreted by [`extract_callback`]. On success
/// the page contains a script that attempts to close its own window; on
/// failure it shows a static "Login Failed" message.
///
/// * `params` - the request's query string, decoded into key/value pairs.
/// * `tx` - sender on which the resulting [`CallbackResult`] is delivered.
async fn handle_callback(
    Query(params): Query<HashMap<String, String>>,
    tx: CallbackSender,
) -> Html<&'static str> {
    let result = extract_callback(&params);
    // Choose the page while `result` is still available; it is moved into the
    // channel below.
    let html = match &result {
        Ok(_) => Html(
            r#"<script>window.close();</script><h1>Login Successful, you can close this window.</h1>"#,
        ),
        Err(_) => Html("<h1>Login Failed</h1><p>You can close this window.</p>"),
    };
    // Best effort: `send` only fails when the receiving half has been closed or
    // dropped, and the browser should still get its response in that case.
    let _ = tx.send(result).await;
    html
}
/// Interprets the query parameters of a callback request.
///
/// Resolution order:
/// 1. A `code` parameter yields `Ok` with a copy of its value, regardless of
///    any other parameter.
/// 2. Otherwise an `error` parameter yields `Err` holding `error_description`
///    when that is present, or the `error` value itself when it is not.
/// 3. With neither parameter, `Err` holds a fixed explanatory message.
fn extract_callback(params: &HashMap<String, String>) -> CallbackResult {
    match (params.get("code"), params.get("error")) {
        (Some(auth_code), _) => Ok(auth_code.to_owned()),
        (None, Some(error_kind)) => {
            let message = params.get("error_description").unwrap_or(error_kind);
            Err(message.to_owned())
        }
        (None, None) => Err(String::from(
            "Callback received neither code nor error details",
        )),
    }
}
/// Unit tests for `extract_callback`, one per branch of its resolution order.
#[cfg(test)]
mod tests {
    use super::*;

    /// A `code` parameter is returned as the success value; unrelated
    /// parameters such as `state` are ignored.
    #[test]
    fn test_extract_callback_success_code() {
        let mut params = HashMap::new();
        params.insert("code".to_string(), "success_code_123".to_string());
        params.insert("state".to_string(), "xyz".to_string());
        let result = extract_callback(&params);
        assert_eq!(result, Ok("success_code_123".to_string()));
    }

    /// When both `error` and `error_description` are present, the description
    /// is the reported message.
    #[test]
    fn test_extract_callback_error_with_description() {
        let mut params = HashMap::new();
        params.insert("error".to_string(), "access_denied".to_string());
        params.insert(
            "error_description".to_string(),
            "User denied access".to_string(),
        );
        let result = extract_callback(&params);
        assert_eq!(result, Err("User denied access".to_string()));
    }

    /// Without a description, the raw `error` value is the reported message.
    #[test]
    fn test_extract_callback_error_without_description() {
        let mut params = HashMap::new();
        params.insert("error".to_string(), "access_denied".to_string());
        let result = extract_callback(&params);
        assert_eq!(result, Err("access_denied".to_string()));
    }

    /// A callback carrying neither `code` nor `error` is reported as a failure
    /// with the fixed fallback message.
    #[test]
    fn test_extract_callback_missing_code_and_error() {
        let params = HashMap::new();
        let result = extract_callback(&params);
        assert_eq!(
            result,
            Err("Callback received neither code nor error details".to_string())
        );
    }
}