//! rustauth-axum 0.3.0 — Axum integration for RustAuth.
//!
//! Routing tests: how RustAuth endpoints are mounted on an Axum `Router`.
mod common;

use std::sync::Arc;

use axum::http::{Method, StatusCode};
use axum::Router;
use common::*;
use rustauth::db::MemoryAdapter;
use rustauth::options::{AdvancedOptions, DeleteUserOptions, RustAuthOptions, UserOptions};
use rustauth::plugin::AuthPlugin;
use rustauth::RustAuth;
use rustauth_axum::{RustAuthAxumError, RustAuthAxumExt, RustAuthAxumOptions};
use tower::ServiceExt;

/// With default options, the built-in `/ok` route answers `200 OK` with the
/// body `OK` at `/api/auth/ok`.
#[tokio::test]
async fn ok_route_is_mounted_under_default_base_path() -> Result<(), Box<dyn std::error::Error>> {
    let auth = auth_with_options(RustAuthOptions::default()).await?;
    let router = auth.mount_at_base_path(RustAuthAxumOptions::default())?;

    let req = request(Method::GET, "/api/auth/ok", "", None)?;
    let res = router.oneshot(req).await?;

    assert_eq!(res.status(), StatusCode::OK);
    let text = body_text(res).await?;
    assert_eq!(text, "OK");
    Ok(())
}

/// Requests to the bare default base path, with and without a trailing slash
/// (`/api/auth` and `/api/auth/`), are both expected to return `404 Not Found`.
///
/// NOTE(review): the name says "accepts", but both requests are asserted to
/// be 404. Consider renaming so the name matches what is checked.
#[tokio::test]
async fn default_base_path_accepts_trailing_slash_root() -> Result<(), Box<dyn std::error::Error>> {
    let auth = auth_with_options(RustAuthOptions::default()).await?;
    let router = auth.mount_at_base_path(RustAuthAxumOptions::default())?;

    let bare_req = request(Method::GET, "/api/auth", "", None)?;
    let bare_res = router.clone().oneshot(bare_req).await?;
    assert_eq!(bare_res.status(), StatusCode::NOT_FOUND);

    let slashed_req = request(Method::GET, "/api/auth/", "", None)?;
    let slashed_res = router.oneshot(slashed_req).await?;
    assert_eq!(slashed_res.status(), StatusCode::NOT_FOUND);
    Ok(())
}

/// When the core is built with `skip_trailing_slashes(true)`, a request with
/// a trailing slash (`/api/auth/ok/`) still reaches the `/ok` handler through
/// the Axum router.
#[tokio::test]
async fn skip_trailing_slashes_reaches_core_routes_over_axum(
) -> Result<(), Box<dyn std::error::Error>> {
    let advanced = AdvancedOptions::new().skip_trailing_slashes(true);
    let options = RustAuthOptions::default().advanced(advanced);
    let router = auth_with_options(options)
        .await?
        .mount_at_base_path(RustAuthAxumOptions::default())?;

    let req = request(Method::GET, "/api/auth/ok/", "", None)?;
    let res = router.oneshot(req).await?;

    assert_eq!(res.status(), StatusCode::OK);
    let text = body_text(res).await?;
    assert_eq!(text, "OK");
    Ok(())
}

#[tokio::test]
async fn custom_base_path_mounts_all_auth_routes() -> Result<(), Box<dyn std::error::Error>> {
    let app = RustAuth::builder()
        .secret(SECRET)
        .base_path("/auth")
        .build()
        .await?
        .mount_at_base_path(RustAuthAxumOptions::default())?;

    let response = app
        .oneshot(request(Method::GET, "/auth/ok", "", None)?)
        .await?;

    assert_eq!(response.status(), StatusCode::OK);
    Ok(())
}

#[tokio::test]
async fn root_base_path_mounts_auth_routes_at_root() -> Result<(), Box<dyn std::error::Error>> {
    let app = RustAuth::builder()
        .secret(SECRET)
        .base_path("/")
        .build()
        .await?
        .mount_at_base_path(RustAuthAxumOptions::default())?;

    let response = app.oneshot(request(Method::GET, "/ok", "", None)?).await?;

    assert_eq!(response.status(), StatusCode::OK);
    Ok(())
}

#[tokio::test]
async fn empty_base_path_mounts_auth_routes_at_root() -> Result<(), Box<dyn std::error::Error>> {
    let app = RustAuth::builder()
        .secret(SECRET)
        .base_path("")
        .build()
        .await?
        .mount_at_base_path(RustAuthAxumOptions::default())?;

    let response = app.oneshot(request(Method::GET, "/ok", "", None)?).await?;

    assert_eq!(response.status(), StatusCode::OK);
    Ok(())
}

#[tokio::test]
async fn trailing_slash_base_path_is_mounted_without_panicking(
) -> Result<(), Box<dyn std::error::Error>> {
    let app = RustAuth::builder()
        .secret(SECRET)
        .base_path("/api/auth/")
        .build()
        .await?
        .mount_at_base_path(RustAuthAxumOptions::default())?;

    let response = app
        .oneshot(request(Method::GET, "/api/auth/ok", "", None)?)
        .await?;

    assert_eq!(response.status(), StatusCode::OK);
    Ok(())
}

#[tokio::test]
async fn invalid_base_paths_are_rejected_before_mounting() -> Result<(), Box<dyn std::error::Error>>
{
    for base_path in [
        "api/auth",
        "/api/{auth}",
        "/api/*auth",
        "/api/auth?x=1",
        "/api/auth#x",
    ] {
        let result = RustAuth::builder()
            .secret(SECRET)
            .base_path(base_path)
            .build()
            .await?
            .mount_at_base_path(RustAuthAxumOptions::default());
        let Err(error) = result else {
            return Err(std::io::Error::other(format!("{base_path} should be rejected")).into());
        };
        assert!(
            matches!(error, RustAuthAxumError::InvalidBasePath(_)),
            "{base_path} should produce InvalidBasePath"
        );
    }
    Ok(())
}

/// A `base_url` that is not a URL (`"not-a-url"`) gets past the builder, but
/// mounting rejects it with `RustAuthAxumError::InvalidBaseUrl`.
#[tokio::test]
async fn invalid_base_url_is_rejected_before_mounting() -> Result<(), Box<dyn std::error::Error>> {
    let auth = RustAuth::builder()
        .secret(SECRET)
        .base_path("/api/auth")
        .base_url("not-a-url")
        .build()
        .await?;

    match auth.mount_at_base_path(RustAuthAxumOptions::default()) {
        Ok(_) => Err(std::io::Error::other("invalid base_url should be rejected").into()),
        Err(error) => {
            assert!(matches!(error, RustAuthAxumError::InvalidBaseUrl(_)));
            Ok(())
        }
    }
}

/// In this setup, `base_path("/api/auth")` and the path in
/// `base_url("http://localhost:3000/wrong")` disagree. Mounting must fail with
/// `RustAuthAxumError::InconsistentBaseUrlPath`.
#[tokio::test]
async fn inconsistent_base_url_path_is_rejected_before_mounting(
) -> Result<(), Box<dyn std::error::Error>> {
    let auth = RustAuth::builder()
        .secret(SECRET)
        .base_path("/api/auth")
        .base_url("http://localhost:3000/wrong")
        .build()
        .await?;

    match auth.mount_at_base_path(RustAuthAxumOptions::default()) {
        Ok(_) => Err(std::io::Error::other("mismatched base_url should be rejected").into()),
        Err(error) => {
            assert!(matches!(
                error,
                RustAuthAxumError::InconsistentBaseUrlPath { .. }
            ));
            Ok(())
        }
    }
}

/// Requests outside the mounted routes get `404 Not Found`:
/// - a path that merely shares the prefix (`/api/authentication/ok`);
/// - `POST`, `HEAD` and `OPTIONS` requests to `/api/auth/ok`.
#[tokio::test]
async fn non_auth_paths_and_wrong_methods_return_not_found(
) -> Result<(), Box<dyn std::error::Error>> {
    let auth = auth_with_options(RustAuthOptions::default()).await?;
    let router = auth.mount_at_base_path(RustAuthAxumOptions::default())?;

    let req = request(Method::GET, "/api/authentication/ok", "", None)?;
    let res = router.clone().oneshot(req).await?;
    assert_eq!(res.status(), StatusCode::NOT_FOUND);

    let req = request(Method::POST, "/api/auth/ok", "{}", None)?;
    let res = router.clone().oneshot(req).await?;
    assert_eq!(res.status(), StatusCode::NOT_FOUND);

    let req = request(Method::HEAD, "/api/auth/ok", "", None)?;
    let res = router.clone().oneshot(req).await?;
    assert_eq!(res.status(), StatusCode::NOT_FOUND);

    let req = request(Method::OPTIONS, "/api/auth/ok", "", None)?;
    let res = router.oneshot(req).await?;
    assert_eq!(res.status(), StatusCode::NOT_FOUND);
    Ok(())
}

/// Calling `mount_routes` on an `Arc<RustAuth>` returns a router that can be
/// nested under `/api/auth`. The shared handle stays usable afterwards.
#[tokio::test]
async fn mount_routes_can_be_nested_without_consuming_auth(
) -> Result<(), Box<dyn std::error::Error>> {
    let shared = Arc::new(auth_with_options(RustAuthOptions::default()).await?);
    let auth_routes = shared.mount_routes(RustAuthAxumOptions::default())?;
    let app = Router::new().nest("/api/auth", auth_routes);

    let req = request(Method::GET, "/api/auth/ok", "", None)?;
    let res = app.oneshot(req).await?;

    assert_eq!(res.status(), StatusCode::OK);
    let text = body_text(res).await?;
    assert_eq!(text, "OK");
    // After mounting, the handle can still be read: its endpoint registry is
    // still accessible and non-empty.
    assert!(!shared.endpoint_registry().is_empty());
    Ok(())
}

/// Calling `mount_routes` on an owned `RustAuth` returns a router that the
/// caller can nest by hand under `/api/auth`.
#[tokio::test]
async fn mount_routes_can_be_nested_manually_on_owned_auth(
) -> Result<(), Box<dyn std::error::Error>> {
    let owned = auth_with_options(RustAuthOptions::default()).await?;
    let auth_routes = owned.mount_routes(RustAuthAxumOptions::default())?;
    let app = Router::new().nest("/api/auth", auth_routes);

    let req = request(Method::GET, "/api/auth/ok", "", None)?;
    let res = app.oneshot(req).await?;

    assert_eq!(res.status(), StatusCode::OK);
    let text = body_text(res).await?;
    assert_eq!(text, "OK");
    Ok(())
}

/// An extra endpoint at `/plugin/custom`, passed in through
/// `auth_with_async_endpoint`, is reachable at `/api/auth/plugin/custom` and
/// returns the body `CUSTOM`.
#[tokio::test]
async fn extra_async_endpoint_is_reachable_through_catch_all(
) -> Result<(), Box<dyn std::error::Error>> {
    let endpoint = custom_endpoint("/plugin/custom");
    let router = auth_with_async_endpoint(endpoint)
        .await?
        .mount_at_base_path(RustAuthAxumOptions::default())?;

    let req = request(Method::GET, "/api/auth/plugin/custom", "", None)?;
    let res = router.oneshot(req).await?;

    assert_eq!(res.status(), StatusCode::OK);
    let text = body_text(res).await?;
    assert_eq!(text, "CUSTOM");
    Ok(())
}

/// An endpoint contributed by an `AuthPlugin` at `/plugin/hello` is reachable
/// at `/api/auth/plugin/hello` and returns the body `CUSTOM`.
#[tokio::test]
async fn plugin_endpoint_is_reachable_through_catch_all() -> Result<(), Box<dyn std::error::Error>>
{
    let endpoint = custom_endpoint("/plugin/hello");
    let plugin = AuthPlugin::new("route-plugin").with_endpoint(endpoint);
    let options = RustAuthOptions::default().plugin(plugin);
    let router = auth_with_options(options)
        .await?
        .mount_at_base_path(RustAuthAxumOptions::default())?;

    let req = request(Method::GET, "/api/auth/plugin/hello", "", None)?;
    let res = router.oneshot(req).await?;

    assert_eq!(res.status(), StatusCode::OK);
    let text = body_text(res).await?;
    assert_eq!(text, "CUSTOM");
    Ok(())
}

/// Every endpoint in the core registry must be mounted on the Axum router.
///
/// The auth instance is configured with delete-user enabled and a fake
/// `github` social provider. Each registry entry becomes a concrete request,
/// and each request may fail in any way except `404 Not Found`.
#[tokio::test]
async fn every_core_auth_route_is_mounted_through_axum() -> Result<(), Box<dyn std::error::Error>> {
    let auth = auth_with_adapter(
        MemoryAdapter::new(),
        RustAuthOptions::default()
            .base_url("http://localhost:3000/api/auth")
            .user(UserOptions::default().delete_user(DeleteUserOptions::default().enabled(true)))
            .social_provider(FakeProvider::new("github")),
    )
    .await?;
    let cases = auth
        .endpoint_registry()
        .into_iter()
        .map(RouteCase::from_endpoint)
        .collect::<Vec<_>>();
    // Guard against a vacuous pass. With an empty registry the loop below
    // would not run, and the test would succeed without probing any route.
    if cases.is_empty() {
        return Err(
            std::io::Error::other("endpoint registry is empty; no routes were checked").into(),
        );
    }
    let app = auth.mount_at_base_path(RustAuthAxumOptions::default())?;

    for case in cases {
        let response = app
            .clone()
            .oneshot(request(
                case.method.clone(),
                &case.path,
                case.body,
                case.cookie,
            )?)
            .await?;
        assert_ne!(
            response.status(),
            StatusCode::NOT_FOUND,
            "{} {} should be mounted",
            case.method,
            case.path
        );
    }
    Ok(())
}

/// One synthetic request built from a registered endpoint. It is used to check
/// that the endpoint's route is mounted on the Axum router.
///
/// `Debug` is derived so a failing case can be printed with `{:?}`.
#[derive(Debug)]
struct RouteCase {
    /// HTTP method the endpoint is registered under.
    method: Method,
    /// Full request path, including the base path and any query string.
    path: String,
    /// Request body passed to the `request` helper.
    body: &'static str,
    /// Optional cookie value passed to the `request` helper.
    cookie: Option<&'static str>,
}

impl RouteCase {
    /// Builds a probe request for `endpoint`:
    /// - path parameters are filled in by `materialize_route_path`;
    /// - a few endpoints get a fixed query string;
    /// - `POST` requests carry the body `{}`, all other methods an empty body;
    /// - `/sign-out` is sent with the cookie value `x=1`.
    fn from_endpoint(endpoint: rustauth::api::EndpointInfo) -> Self {
        // Query strings for specific endpoints, keyed by their registered
        // (unmaterialized) path. Presumably these are parameters the handlers
        // expect; TODO confirm against the core handlers.
        let query = match endpoint.path.as_str() {
            "/callback/:id" => Some("state=missing"),
            "/error" => Some("error=invalid_request"),
            "/reset-password/:token" => Some("callbackURL=/reset"),
            "/verify-email" | "/delete-user/callback" => Some("token=missing"),
            _ => None,
        };
        let base = materialize_route_path(&endpoint.path);
        let path = match query {
            Some(query) => format!("{base}?{query}"),
            None => base,
        };

        let body = if endpoint.method != Method::POST {
            ""
        } else {
            "{}"
        };
        let cookie = if endpoint.path == "/sign-out" {
            Some("x=1")
        } else {
            None
        };

        Self {
            path,
            body,
            cookie,
            method: endpoint.method,
        }
    }
}

/// Turns a registered route pattern into a concrete request path.
///
/// Every `:id` is replaced with `github` and every `:token` with `missing`,
/// in that order. The result is then prefixed with the default base path
/// `/api/auth`.
fn materialize_route_path(path: &str) -> String {
    const SUBSTITUTIONS: [(&str, &str); 2] = [(":id", "github"), (":token", "missing")];
    let concrete = SUBSTITUTIONS
        .iter()
        .fold(path.to_owned(), |acc, &(param, value)| acc.replace(param, value));
    ["/api/auth", concrete.as_str()].concat()
}