ez-token 0.1.0

CLI tool for generating OAuth2 access tokens via the Authorization Code (with PKCE) and Client Credentials flows for Microsoft Entra ID and Auth0
Documentation
use crate::cli::output::{
    AppEmoji, finish_spinner_error, finish_spinner_success, print_step, start_spinner,
};
use crate::services::authentication::authenticator::Authenticator;
use crate::services::authentication::urls::IdentityProvider;
use crate::services::http_client::client::create_http_client;
use crate::services::local_server::server::start_local_server;
use miette::{Context, IntoDiagnostic, Result};
use oauth2::{
    AuthUrl, AuthorizationCode, ClientId, CsrfToken, PkceCodeChallenge, RedirectUrl, Scope,
    TokenResponse, TokenUrl, basic::BasicClient,
};
use tokio::sync::mpsc::Receiver;

/// An OAuth2 Authorization Code flow with PKCE for supported identity providers.
///
/// This flow is designed for interactive user authentication. It opens the
/// system browser to the provider's authorization endpoint, starts a local
/// HTTP server to receive the callback, and exchanges the authorization
/// code for an access token.
///
/// No client secret is used: the client is a public client, and the PKCE
/// code verifier is sent with the token request instead.
///
/// # Provider Requirements
///
/// ## Microsoft Entra ID
/// - The application must have **Delegated permissions**
/// - The redirect URI `http://localhost:{port}/callback` must be registered
///   under **Mobile and desktop applications** in the Azure Portal
/// - Public client flows must be enabled
///
/// ## Auth0
/// - **Application Type** must be set to **Native**
/// - The redirect URI `http://localhost:{port}/callback` must be listed
///   under **Allowed Callback URLs**
/// - **Token Endpoint Authentication Method** must be set to **None**
/// - **Non-Verifiable Callback URI End-User Confirmation** must be **off**
pub struct AuthorizationCodeFlow {
    /// The resolved identity provider. It supplies the authorization and token
    /// endpoints and, where the provider defines one, the `audience` value.
    pub provider: IdentityProvider,

    /// The client ID of the application registered with the identity provider:
    /// the Application (Client) ID in Entra ID, or the Client ID of the Auth0
    /// application.
    pub client_id: String,

    /// The list of OAuth2 scopes to request. Each entry is added as its own scope.
    ///
    /// For Microsoft use delegated scopes (e.g. `User.Read`).
    /// For Auth0 use explicit scopes (e.g. `read:ez`).
    pub scopes: Vec<String>,

    /// The local port to listen on for the OAuth2 redirect callback.
    ///
    /// Must match the redirect URI registered with your identity provider
    /// (e.g. `http://localhost:3000/callback`).
    pub port: u16,
}

impl Authenticator for AuthorizationCodeFlow {
    /// Performs the full Authorization Code + PKCE flow and returns an access token.
    ///
    /// **Steps:**
    /// 1. Builds the authorization URL with a PKCE challenge
    /// 2. Starts a local HTTP server to receive the callback
    /// 3. Opens the system browser to the authorization URL (or prints the URL
    ///    if the browser cannot be launched)
    /// 4. Waits for the authorization code via the local server (120s timeout)
    /// 5. Stops the local server, whether or not a code was received
    /// 6. Exchanges the code (plus the PKCE verifier) for an access token
    ///
    /// When the provider defines an audience (e.g. Auth0), it is added as an
    /// `audience` parameter to the authorization request.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - The provider produces an invalid authorization, token, or redirect URL
    /// - The local server fails to bind to the given port
    /// - The browser authorization is denied or times out
    /// - The token exchange with the identity provider fails
    async fn get_token(&self) -> Result<String> {
        let auth_uri = AuthUrl::new(self.provider.auth_url())
            .into_diagnostic()
            .wrap_err("Invalid authorization URL")?;

        let token_uri = TokenUrl::new(self.provider.token_url())
            .into_diagnostic()
            .wrap_err("Invalid token URL")?;

        let redirect_url = RedirectUrl::new(format!("http://localhost:{}/callback", self.port))
            .into_diagnostic()
            .wrap_err("Invalid redirect URL")?;

        let client = BasicClient::new(ClientId::new(self.client_id.clone()))
            .set_auth_uri(auth_uri)
            .set_token_uri(token_uri)
            .set_redirect_uri(redirect_url);

        let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();

        let mut auth_req = client
            .authorize_url(CsrfToken::new_random)
            .set_pkce_challenge(pkce_challenge);

        if let Some(audience) = self.provider.audience() {
            auth_req = auth_req.add_extra_param("audience", audience);
        }

        for scope in &self.scopes {
            auth_req = auth_req.add_scope(Scope::new(scope.clone()));
        }

        // NOTE(review): the random CSRF `state` generated here is discarded, and
        // `start_local_server` only receives the port. The callback's `state` is
        // therefore never compared against the value we sent. Consider passing
        // the expected token to the server and rejecting mismatched callbacks.
        let (authorize_url, _csrf_state) = auth_req.url();
        let (mut rx, server_handle) = start_local_server(self.port).await?;

        print_step(AppEmoji::Rocket, "Opening browser...");
        if webbrowser::open(authorize_url.as_str()).is_err() {
            println!("Please open: {}", authorize_url);
        }

        // Stop the callback server before inspecting the outcome. Otherwise a
        // timeout or denied authorization would return early via `?` and leave
        // the server task running on the port.
        let code_result = self.wait_for_code(&mut rx).await;
        server_handle.abort();
        let code = code_result?;

        let http_client = create_http_client()?;
        let token_result = client
            .exchange_code(AuthorizationCode::new(code))
            .set_pkce_verifier(pkce_verifier)
            .request_async(&http_client)
            .await
            .into_diagnostic()
            .wrap_err("Failed to exchange Authorization Code for Access Token")?;

        Ok(token_result.access_token().secret().clone())
    }
}

impl AuthorizationCodeFlow {
    /// How long (in seconds) to wait for the browser to redirect back to the
    /// local callback server before giving up.
    const CALLBACK_TIMEOUT_SECS: u64 = 120;

    /// Waits for the authorization code to arrive via the local callback server.
    ///
    /// Displays a spinner while waiting. The local HTTP server sends either the
    /// code (`Ok`) or an error message (`Err`) over `rx` once the browser
    /// completes the authorization. Times out after `CALLBACK_TIMEOUT_SECS`
    /// (120) seconds. This covers cases where the identity provider shows an
    /// error page instead of redirecting back to the callback.
    ///
    /// The spinner is always finished before returning. It shows a success
    /// message when a code arrives and an error message on every failure path.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - Authentication times out after `CALLBACK_TIMEOUT_SECS` seconds
    /// - The channel closes before a result is received
    /// - The browser authorization is denied or the provider returns an error
    async fn wait_for_code(&self, rx: &mut Receiver<Result<String, String>>) -> Result<String> {
        let spinner = start_spinner("Waiting for authentication...")?;

        let limit = std::time::Duration::from_secs(Self::CALLBACK_TIMEOUT_SECS);
        let outcome = match tokio::time::timeout(limit, rx.recv()).await {
            Ok(Some(Ok(code))) => Ok(code),
            Ok(Some(Err(err_msg))) => {
                Err(miette::miette!("Browser authentication error: {}", err_msg))
            }
            // `recv` yields `None` once every sender is dropped: the local
            // server went away without reporting a result.
            Ok(None) => Err(miette::miette!("Failed to receive communication from local server")),
            Err(_elapsed) => Err(miette::miette!(
                help = "Check your browser and try again",
                "Authentication timed out after {} seconds",
                Self::CALLBACK_TIMEOUT_SECS
            )),
        };

        // Finish the spinner on every path, including timeout and closed
        // channel, so it never keeps running after this function returns.
        if outcome.is_ok() {
            finish_spinner_success(&spinner, "Authentication successful!");
        } else {
            finish_spinner_error(&spinner, "Authentication failed!");
        }

        outcome
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::mpsc;

    /// Builds a flow with placeholder values. `wait_for_code` never reads the
    /// provider, client ID, scopes, or port.
    fn create_dummy_flow() -> AuthorizationCodeFlow {
        let provider = IdentityProvider::Microsoft {
            tenant_id: String::from("common"),
        };

        AuthorizationCodeFlow {
            provider,
            client_id: String::from("dummy_client"),
            scopes: Vec::new(),
            port: 3000,
        }
    }

    /// Disables terminal colors, then drives `wait_for_code` against `rx`.
    async fn run_wait(
        rx: &mut mpsc::Receiver<Result<String, String>>,
    ) -> miette::Result<String> {
        console::set_colors_enabled(false);
        create_dummy_flow().wait_for_code(rx).await
    }

    #[tokio::test]
    async fn test_wait_for_code_success() {
        let (tx, mut rx) = mpsc::channel(1);
        tx.send(Ok("valid_auth_code_123".to_string()))
            .await
            .unwrap();

        let code = run_wait(&mut rx)
            .await
            .expect("a delivered code should be returned");

        assert_eq!(code, "valid_auth_code_123");
    }

    #[tokio::test]
    async fn test_wait_for_code_server_error() {
        let (tx, mut rx) = mpsc::channel(1);
        tx.send(Err("access_denied".to_string())).await.unwrap();

        let message = run_wait(&mut rx).await.unwrap_err().to_string();

        assert!(message.contains("access_denied"));
    }

    #[tokio::test]
    async fn test_wait_for_code_channel_dropped_prematurely() {
        let (tx, mut rx) = mpsc::channel::<Result<String, String>>(1);
        // Closing the only sender makes `recv` resolve to `None`.
        drop(tx);

        let message = run_wait(&mut rx).await.unwrap_err().to_string();

        assert!(message.contains("Failed to receive communication"));
    }

    // With the clock paused, tokio auto-advances virtual time once every task
    // is idle, so the 120s timeout fires immediately.
    #[tokio::test(start_paused = true)]
    async fn test_wait_for_code_timeout() {
        // Keep the sender alive so the channel never closes on its own.
        let (_tx, mut rx) = mpsc::channel::<Result<String, String>>(1);

        let message = run_wait(&mut rx).await.unwrap_err().to_string();

        assert!(message.contains("timed out"));
    }
}