ez-token 0.1.0

CLI tool for generating OAuth2 access tokens via PKCE and Client Credentials for Microsoft Entra ID and Auth0
Documentation
use axum::Router;
use axum::extract::Query;
use axum::response::Html;
use axum::routing::get;
use miette::Result;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;

/// The result of an OAuth2 callback — either an authorization code or an error message.
///
/// `Ok` carries the authorization code taken from the redirect's query string.
/// `Err` carries a human-readable message describing why no code was obtained.
///
/// `Result` here is the `miette::Result` alias imported at the top of the file.
/// Both type parameters are given explicitly, so this is a plain `Result<String, String>`
/// rather than a `miette::Report`-based result.
type CallbackResult = Result<String, String>;

/// A thread-safe sender for communicating the callback result back to the main flow.
///
/// This is the sending half of the channel created in `start_local_server`. It is
/// wrapped in an `Arc` so that the `/callback` route closure can hand it to
/// `handle_callback`.
type CallbackSender = Arc<mpsc::Sender<CallbackResult>>;

/// Spawns a loopback HTTP server that waits for the OAuth2 redirect.
///
/// The server listens on `127.0.0.1:{port}` and serves a single route, `/callback`.
/// Each request to that route is parsed, and its outcome (an authorization code or an
/// error message) is pushed into the channel whose [`mpsc::Receiver`] is returned here.
///
/// Serving happens on a background [`tokio`] task. Nothing stops that task
/// automatically, so the caller should abort the returned [`JoinHandle`] once the
/// callback has arrived.
///
/// # Arguments
///
/// * `port` - Local port to listen on. It has to agree with the redirect URI
///   registered with the identity provider (Entra ID or Auth0), e.g.
///   `http://localhost:3000/callback`.
///
/// # Errors
///
/// Fails when the loopback address cannot be bound, for example because the port is
/// already taken. The diagnostic carries the code `ez_token::server::bind` and a hint
/// suggesting the `--port` flag to choose another port.
pub async fn start_local_server(
    port: u16,
) -> Result<(mpsc::Receiver<CallbackResult>, JoinHandle<()>)> {
    let bind_addr = SocketAddr::from(([127, 0, 0, 1], port));
    let listener = match TcpListener::bind(bind_addr).await {
        Ok(bound) => bound,
        Err(bind_err) => {
            return Err(miette::miette!(
                help = "Try using the `--port` flag (e.g., `ez-token login --port 3001`)",
                code = "ez_token::server::bind",
                "Failed to bind to port {}: {}",
                port,
                bind_err
            ));
        }
    };

    // Single-slot channel: the login flow only ever waits for one callback.
    let (sender, receiver) = mpsc::channel::<CallbackResult>(1);
    let shared_sender: CallbackSender = Arc::new(sender);
    let router = Router::new().route(
        "/callback",
        get(move |query| handle_callback(query, shared_sender)),
    );

    let server_task = tokio::spawn(async move {
        let outcome = axum::serve(listener, router).await;
        if let Err(serve_err) = outcome {
            eprintln!("Local server error: {}", serve_err);
        }
    });

    Ok((receiver, server_task))
}

/// Axum handler for the `/callback` endpoint.
///
/// Parses the query parameters with [`extract_callback`], forwards the outcome to the
/// flow waiting on the receiver returned by [`start_local_server`], and returns a small
/// HTML page for the browser to display.
///
/// The outcome is forwarded with [`mpsc::Sender::try_send`] rather than `send().await`.
/// The channel holds a single value, so a repeated hit on `/callback` (e.g. the user
/// reloading the page) while an earlier result is still unconsumed would otherwise keep
/// this request pending until the receiver made room, leaving the browser tab hanging.
/// If a result is already waiting in the channel, later callbacks are dropped instead
/// of queued.
async fn handle_callback(
    Query(params): Query<HashMap<String, String>>,
    tx: CallbackSender,
) -> Html<&'static str> {
    let result = extract_callback(&params);
    let html = match &result {
        Ok(_) => Html(
            r#"<script>window.close();</script><h1>Login Successful, you can close this window.</h1>"#,
        ),
        Err(_) => Html("<h1>Login Failed</h1><p>You can close this window.</p>"),
    };
    // Best-effort delivery. `Full` means an earlier result is still waiting to be
    // consumed, and `Closed` means the flow stopped listening. Neither case should
    // delay or fail the browser response.
    let _ = tx.try_send(result);
    html
}

/// Parses the OAuth2 callback query parameters into a [`CallbackResult`].
///
/// Parameters are checked in this order:
/// - `code`: authorization succeeded, and the code is returned as `Ok`.
/// - `error`: authorization failed. Returns `error_description` when it is present and
///   non-empty, otherwise the `error` code itself. If both are empty, a generic
///   "failed without details" message is returned.
/// - Neither: returns a generic error indicating an unexpected callback.
///
/// Empty values (e.g. `?code=`) are treated as absent. An empty code can never be
/// exchanged for a token, and an empty description carries no information. All other
/// parameters are ignored.
///
/// NOTE(review): `state` is ignored here, so this function performs no CSRF check on
/// the redirect. Confirm that the caller validates it, or that PKCE alone is
/// considered sufficient.
fn extract_callback(params: &HashMap<String, String>) -> CallbackResult {
    // Look up a parameter, treating `key=` the same as a missing key.
    let non_empty = |key: &str| params.get(key).filter(|value| !value.is_empty());

    if let Some(code) = non_empty("code") {
        return Ok(code.clone());
    }
    if params.contains_key("error") {
        // Prefer the human-readable description, then fall back to the error code.
        let message = non_empty("error_description")
            .or_else(|| non_empty("error"))
            .map_or("Authorization failed without error details", String::as_str);
        return Err(message.to_string());
    }
    Err("Callback received neither code nor error details".to_string())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a query-parameter map from `(key, value)` pairs.
    fn query(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_string(), value.to_string()))
            .collect()
    }

    #[test]
    fn test_extract_callback_success_code() {
        let mut params = HashMap::new();
        params.insert("code".to_string(), "success_code_123".to_string());
        params.insert("state".to_string(), "xyz".to_string()); // Ignored param

        let result = extract_callback(&params);
        assert_eq!(result, Ok("success_code_123".to_string()));
    }

    #[test]
    fn test_extract_callback_error_with_description() {
        let mut params = HashMap::new();
        params.insert("error".to_string(), "access_denied".to_string());
        params.insert(
            "error_description".to_string(),
            "User denied access".to_string(),
        );

        let result = extract_callback(&params);
        assert_eq!(result, Err("User denied access".to_string()));
    }

    #[test]
    fn test_extract_callback_error_without_description() {
        // Without `error_description`, the raw error code is reported.
        let result = extract_callback(&query(&[("error", "access_denied")]));
        assert_eq!(result, Err("access_denied".to_string()));
    }

    #[test]
    fn test_extract_callback_code_takes_precedence_over_error() {
        let result = extract_callback(&query(&[
            ("code", "success_code_123"),
            ("error", "access_denied"),
        ]));
        assert_eq!(result, Ok("success_code_123".to_string()));
    }

    #[test]
    fn test_extract_callback_no_code_or_error() {
        let result = extract_callback(&query(&[("state", "xyz")]));
        assert_eq!(
            result,
            Err("Callback received neither code nor error details".to_string())
        );
    }
}