rustauth-plugins 0.3.0

Official RustAuth plugin modules.
Documentation
use std::sync::Arc;

use http::{header, Method, StatusCode};
use rustauth_core::db::{DbValue, MemoryAdapter};
use rustauth_core::options::RustAuthOptions;
use serde_json::Value;

use super::helpers::{
    find_user_by_email, json_request, router_with_plugin_options, set_cookie_values, FakeProvider,
};
use rustauth_plugins::last_login_method::{LastLoginMethodOptions, DEFAULT_COOKIE_NAME};

/// A successful `github` OAuth callback must answer with a redirect, set the last-login-method
/// cookie to `github`, and store `github` in the user's `last_login_method` field.
#[tokio::test]
async fn oauth_callback_persists_provider_id_and_sets_last_method_cookie(
) -> Result<(), Box<dyn std::error::Error>> {
    let adapter = Arc::new(MemoryAdapter::new());
    let auth_options = RustAuthOptions {
        base_url: Some("http://localhost:3000/api/auth".to_owned()),
        social_providers: vec![Arc::new(FakeProvider::new("github"))],
        ..RustAuthOptions::default()
    };
    let router = router_with_plugin_options(
        adapter.clone(),
        LastLoginMethodOptions::default(),
        auth_options,
    )?;

    // Step 1: start the social sign-in to obtain the `state` parameter and the state cookie.
    let start_request = json_request(
        Method::POST,
        "/api/auth/sign-in/social",
        r#"{"provider":"github","callbackURL":"/dashboard"}"#,
        None,
    )?;
    let sign_in = router.handle_async(start_request).await?;
    let body: Value = serde_json::from_slice(sign_in.body())?;
    let authorize_url = body["url"].as_str().ok_or("missing url")?;
    let state = query_value(authorize_url, "state").ok_or("missing state")?;
    let oauth_state_cookie = oauth_state_cookie_header(&sign_in)?;

    // Step 2: hit the provider callback, echoing back the state and its cookie.
    let callback_path = format!("/api/auth/callback/github?code=ok&state={state}");
    let callback_request = json_request(
        Method::GET,
        &callback_path,
        "",
        Some(&oauth_state_cookie),
    )?;
    let callback = router.handle_async(callback_request).await?;

    assert_eq!(callback.status(), StatusCode::FOUND);

    // The response must carry the plugin cookie naming the provider.
    let expected_cookie_prefix = format!("{DEFAULT_COOKIE_NAME}=github");
    let response_cookies = set_cookie_values(&callback);
    assert!(response_cookies
        .iter()
        .any(|cookie| cookie.starts_with(&expected_cookie_prefix)));

    // The same provider id must also have been written to the user record.
    let user = find_user_by_email(adapter.as_ref(), "ada.oauth@example.com")
        .await?
        .ok_or("missing OAuth user")?;
    let stored_method = user.get("last_login_method");
    assert_eq!(
        stored_method,
        Some(&DbValue::String("github".to_owned()))
    );
    Ok(())
}

/// A callback carrying `error=access_denied` must redirect to the error callback URL without
/// setting the last-login-method cookie and without creating the OAuth user.
#[tokio::test]
async fn failed_oauth_callback_does_not_persist_or_set_cookie(
) -> Result<(), Box<dyn std::error::Error>> {
    let adapter = Arc::new(MemoryAdapter::new());
    let auth_options = RustAuthOptions {
        base_url: Some("http://localhost:3000/api/auth".to_owned()),
        social_providers: vec![Arc::new(FakeProvider::new("github"))],
        ..RustAuthOptions::default()
    };
    let router = router_with_plugin_options(
        adapter.clone(),
        LastLoginMethodOptions::default(),
        auth_options,
    )?;

    // Step 1: start the social sign-in, this time registering an error callback URL.
    let start_request = json_request(
        Method::POST,
        "/api/auth/sign-in/social",
        r#"{"provider":"github","callbackURL":"/dashboard","errorCallbackURL":"/error"}"#,
        None,
    )?;
    let sign_in = router.handle_async(start_request).await?;
    let body: Value = serde_json::from_slice(sign_in.body())?;
    let authorize_url = body["url"].as_str().ok_or("missing url")?;
    let state = query_value(authorize_url, "state").ok_or("missing state")?;
    let oauth_state_cookie = oauth_state_cookie_header(&sign_in)?;

    // Step 2: simulate the provider reporting a denied authorization.
    let callback_path = format!("/api/auth/callback/github?error=access_denied&state={state}");
    let callback_request = json_request(
        Method::GET,
        &callback_path,
        "",
        Some(&oauth_state_cookie),
    )?;
    let callback = router.handle_async(callback_request).await?;

    assert_eq!(callback.status(), StatusCode::FOUND);

    // No cookie on the response may start with the plugin's cookie name.
    let plugin_cookie_was_set = set_cookie_values(&callback)
        .iter()
        .any(|cookie| cookie.starts_with(DEFAULT_COOKIE_NAME));
    assert!(!plugin_cookie_was_set);

    let location = callback
        .headers()
        .get(header::LOCATION)
        .ok_or("missing location")?;
    assert_eq!(location, "/error?error=access_denied");

    // The failed flow must not have left a user record behind.
    let user = find_user_by_email(adapter.as_ref(), "ada.oauth@example.com").await?;
    assert!(user.is_none());
    Ok(())
}

/// Returns the value of the first query parameter in `url` whose name equals `key`.
///
/// Yields `None` when `url` cannot be parsed or when no parameter with that name is present.
fn query_value(url: &str, key: &str) -> Option<String> {
    let parsed = url::Url::parse(url).ok()?;
    for (name, value) in parsed.query_pairs() {
        if name == key {
            return Some(value.into_owned());
        }
    }
    None
}

/// Builds a `name=value` cookie string from the first OAuth state cookie found in the
/// `Set-Cookie` headers of `response`.
///
/// Both the plain and the `__Secure-` prefixed cookie names are accepted. Everything after the
/// first `;` (the cookie attributes) is dropped. Header values that are not valid visible ASCII
/// strings, or that contain no `=`, are skipped.
///
/// # Errors
///
/// Returns an error when none of the `Set-Cookie` headers carries an OAuth state cookie.
fn oauth_state_cookie_header(
    response: &http::Response<Vec<u8>>,
) -> Result<String, Box<dyn std::error::Error>> {
    for raw_header in response.headers().get_all(header::SET_COOKIE).iter() {
        let name_and_rest = raw_header
            .to_str()
            .ok()
            .and_then(|cookie| cookie.split_once('='));
        if let Some((name, rest)) = name_and_rest {
            if matches!(
                name,
                "rustauth.oauth_state" | "__Secure-rustauth.oauth_state"
            ) {
                // Keep only the value; attributes such as `Path` follow the first `;`.
                let value = rest.split(';').next().unwrap_or_default();
                return Ok(format!("{name}={value}"));
            }
        }
    }
    Err("missing oauth_state cookie".into())
}