//! rustauth-axum 0.3.0
//!
//! Axum integration for RustAuth.
//! Documentation
#![allow(dead_code)]

use axum::body::{to_bytes, Body};
use axum::http::{header, HeaderValue, Method, Request};
use rustauth::api::{create_auth_endpoint, ApiResponse, AsyncAuthEndpoint, AuthEndpointOptions};
use rustauth::db::{DbValue, MemoryAdapter};
use rustauth::error::RustAuthError;
use rustauth::oauth::oauth2::{
    OAuth2Tokens, OAuth2UserInfo, OAuthError, ProviderOptions, SocialAuthorizationCodeRequest,
    SocialAuthorizationUrlRequest, SocialIdTokenRequest, SocialOAuthProvider, SocialProviderFuture,
};
use rustauth::options::RustAuthOptions;
use rustauth::RustAuth;
use serde_json::Value;
use url::Url;

/// Secret passed to every `RustAuth` builder created by the helpers in this module.
pub const SECRET: &str = "test-secret-123456789012345678901234";
/// Upper bound, in bytes, on how much of a response body the body helpers will read (10 MiB).
pub const BODY_LIMIT: usize = 10 << 20;

/// Label stored in a response's extensions so tests can check that extensions set by an
/// endpoint reach the caller.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ResponseExtensionMarker(pub &'static str);

/// Label that tests put into a request's extensions; `request_extension_endpoint` echoes
/// it back in the response body.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct RequestExtensionMarker(pub &'static str);

/// Runs `options` through `rustauth_core`'s shared integration-test defaults.
fn with_test_defaults(options: RustAuthOptions) -> RustAuthOptions {
    use rustauth_core::test_utils::with_integration_test_defaults as apply_integration_defaults;

    apply_integration_defaults(options)
}

pub async fn auth_with_options(
    options: RustAuthOptions,
) -> Result<RustAuth, rustauth::error::RustAuthError> {
    RustAuth::builder()
        .options(with_test_defaults(options))
        .secret(SECRET)
        .build()
        .await
}

pub async fn auth_with_adapter(
    adapter: MemoryAdapter,
    options: RustAuthOptions,
) -> Result<RustAuth, rustauth::error::RustAuthError> {
    RustAuth::builder()
        .options(with_test_defaults(options))
        .secret(SECRET)
        .adapter(adapter)
        .build()
        .await
}

pub async fn auth_with_async_endpoint(
    endpoint: AsyncAuthEndpoint,
) -> Result<RustAuth, rustauth::error::RustAuthError> {
    RustAuth::builder()
        .secret(SECRET)
        .async_endpoint(endpoint)
        .build()
        .await
}

/// Creates a `GET` endpoint at `path` that always answers `200 OK` with the body `CUSTOM`.
pub fn custom_endpoint(path: &'static str) -> AsyncAuthEndpoint {
    let options = AuthEndpointOptions::new();
    create_auth_endpoint(path, Method::GET, options, |_context, _request| async move {
        let payload: Vec<u8> = "CUSTOM".as_bytes().to_vec();
        let mut reply = ApiResponse::new(payload);
        *reply.status_mut() = axum::http::StatusCode::OK;
        Ok(reply)
    })
}

pub fn request_extension_endpoint(path: &'static str) -> AsyncAuthEndpoint {
    create_auth_endpoint(
        path,
        Method::GET,
        AuthEndpointOptions::new(),
        |_context, request| async move {
            let marker = request
                .extensions()
                .get::<RequestExtensionMarker>()
                .map(|marker| marker.0)
                .unwrap_or("missing");
            let mut response = ApiResponse::new(format!("request={marker}").into_bytes());
            *response.status_mut() = axum::http::StatusCode::OK;
            Ok(response)
        },
    )
}

/// Creates a `GET` endpoint at `path` whose response exercises several parts of the
/// response contract at once:
///
/// - status `201 Created` and version `HTTP/2`;
/// - body `query=<raw query string>` (empty after `=` when the URI has no query);
/// - two `Set-Cookie` values and two `x-rustauth-test` values, appended in order;
/// - a [`ResponseExtensionMarker`] labelled `response-contract` in the extensions.
pub fn response_contract_endpoint(path: &'static str) -> AsyncAuthEndpoint {
    let options = AuthEndpointOptions::new();
    create_auth_endpoint(path, Method::GET, options, |_context, request| async move {
        let text = format!("query={}", request.uri().query().unwrap_or_default());
        let mut reply = ApiResponse::new(text.into_bytes());
        *reply.status_mut() = axum::http::StatusCode::CREATED;
        *reply.version_mut() = axum::http::Version::HTTP_2;

        let headers = reply.headers_mut();
        // `append` (not `insert`) so each header name ends up carrying two values.
        for cookie in ["a=1; Path=/; HttpOnly", "b=2; Path=/; HttpOnly"] {
            headers.append(header::SET_COOKIE, HeaderValue::from_static(cookie));
        }
        for marker in ["one", "two"] {
            headers.append("x-rustauth-test", HeaderValue::from_static(marker));
        }

        reply
            .extensions_mut()
            .insert(ResponseExtensionMarker("response-contract"));
        Ok(reply)
    })
}

/// Creates a `GET` endpoint at `path` that answers `204 No Content` with an empty body.
pub fn empty_response_endpoint(path: &'static str) -> AsyncAuthEndpoint {
    let options = AuthEndpointOptions::new();
    create_auth_endpoint(path, Method::GET, options, |_context, _request| async move {
        let mut reply = ApiResponse::new(Vec::<u8>::new());
        *reply.status_mut() = axum::http::StatusCode::NO_CONTENT;
        Ok(reply)
    })
}

/// Creates a `GET` endpoint at `path` whose handler always fails with
/// `RustAuthError::Api("simulated internal failure")`.
pub fn failing_endpoint(path: &'static str) -> AsyncAuthEndpoint {
    let options = AuthEndpointOptions::new();
    create_auth_endpoint(path, Method::GET, options, |_context, _request| async move {
        let failure = RustAuthError::Api(String::from("simulated internal failure"));
        Err(failure)
    })
}

/// Like [`request`], but also sets `Content-Type: application/json`.
///
/// # Errors
///
/// Returns the `axum::http::Error` produced while building the underlying request.
pub fn json_request(
    method: Method,
    path: &str,
    body: &str,
    cookie: Option<&str>,
) -> Result<Request<Body>, axum::http::Error> {
    let plain = request(method, path, body, cookie)?;
    plain.with_header(header::CONTENT_TYPE, "application/json")
}

/// Builds a request with the given `method`, `path` (as URI) and `body`, adding a
/// `Cookie` header when `cookie` is `Some`.
///
/// # Errors
///
/// Returns the `axum::http::Error` from the request builder, e.g. for an invalid URI or
/// cookie header value.
pub fn request(
    method: Method,
    path: &str,
    body: &str,
    cookie: Option<&str>,
) -> Result<Request<Body>, axum::http::Error> {
    let base = Request::builder().method(method).uri(path);
    let builder = match cookie {
        Some(cookie_value) => base.header(header::COOKIE, cookie_value),
        None => base,
    };
    builder.body(Body::from(String::from(body)))
}

/// Extension trait for attaching a header to an already-built `Request<Body>`.
pub trait RequestHeaderExt {
    /// Consumes the request and returns it with header `name` carrying the static `value`.
    ///
    /// NOTE(review): the `Result` return presumably leaves room for fallible
    /// implementations; the impl for `Request<Body>` in this module never returns `Err`.
    fn with_header(
        self,
        name: header::HeaderName,
        value: &'static str,
    ) -> Result<Request<Body>, axum::http::Error>;
}

impl RequestHeaderExt for Request<Body> {
    /// Sets header `name` to `value`, replacing any values already present under that name.
    ///
    /// Never returns `Err`. `HeaderValue::from_static` panics if `value` contains bytes
    /// that are invalid in a header value.
    fn with_header(
        mut self,
        name: header::HeaderName,
        value: &'static str,
    ) -> Result<Request<Body>, axum::http::Error> {
        self.headers_mut()
            .insert(name, HeaderValue::from_static(value));
        Ok(self)
    }
}

/// Reads up to [`BODY_LIMIT`] bytes of `response`'s body and parses them as JSON.
///
/// # Errors
///
/// Fails if the body cannot be collected (including exceeding the limit) or is not valid
/// JSON.
pub async fn body_json(
    response: axum::response::Response,
) -> Result<Value, Box<dyn std::error::Error>> {
    let collected = to_bytes(response.into_body(), BODY_LIMIT).await?;
    let parsed: Value = serde_json::from_slice(&collected)?;
    Ok(parsed)
}

/// Reads up to [`BODY_LIMIT`] bytes of `response`'s body and decodes them as UTF-8.
///
/// # Errors
///
/// Fails if the body cannot be collected (including exceeding the limit) or is not valid
/// UTF-8.
pub async fn body_text(
    response: axum::response::Response,
) -> Result<String, Box<dyn std::error::Error>> {
    let collected = to_bytes(response.into_body(), BODY_LIMIT).await?;
    let owned_bytes: Vec<u8> = collected.to_vec();
    let text = String::from_utf8(owned_bytes)?;
    Ok(text)
}

/// Turns the `Set-Cookie` headers of `response` into a single `Cookie` request-header value.
///
/// Each `Set-Cookie` value is cut at its first `;`, dropping attributes such as `Path` or
/// `HttpOnly`. The remaining `name=value` pairs are joined with `"; "`. Values that are
/// not visible ASCII are skipped. Returns `None` when no usable `Set-Cookie` header exists.
pub fn cookie_header(response: &axum::response::Response) -> Option<String> {
    let mut pairs: Vec<&str> = Vec::new();
    for raw in response.headers().get_all(header::SET_COOKIE) {
        if let Ok(set_cookie) = raw.to_str() {
            let name_value = set_cookie
                .split_once(';')
                .map_or(set_cookie, |(head, _attributes)| head);
            pairs.push(name_value);
        }
    }
    if pairs.is_empty() {
        None
    } else {
        Some(pairs.join("; "))
    }
}

/// Polls `value` until some other task stores `Some(_)` in it, then returns a clone of it.
///
/// The mutex is checked up to 200 times, sleeping 2 ms (via `tokio::time::sleep`) between
/// attempts. The lock is released before every sleep, so the producer is never blocked
/// across an `.await`.
///
/// # Errors
///
/// - Returns an error immediately if the mutex is poisoned, i.e. a thread panicked while
///   holding it. Previously this was silently treated as "not yet set" and reported, after
///   the full timeout, as a missing value.
/// - Returns `"missing captured outbound value"` if no value appears before the attempts
///   run out.
pub async fn wait_for_mutex_option<T: Clone>(
    value: &std::sync::Mutex<Option<T>>,
) -> Result<T, Box<dyn std::error::Error>> {
    const ATTEMPTS: usize = 200;
    const POLL_INTERVAL: std::time::Duration = std::time::Duration::from_millis(2);

    for _ in 0..ATTEMPTS {
        // The guard is a temporary of this statement, so it is dropped before the sleep.
        let captured = value
            .lock()
            .map_err(|_| "captured outbound value mutex was poisoned")?
            .clone();
        if let Some(captured) = captured {
            return Ok(captured);
        }
        tokio::time::sleep(POLL_INTERVAL).await;
    }
    Err("missing captured outbound value".into())
}

/// Extracts a password-reset token from `adapter`'s `verification` records.
///
/// Reset verifications carry an `identifier` of the form `reset-password:<token>`. This
/// returns `<token>` from the first such record, in the order `records` yields them.
/// Records without a string identifier, or with an identifier for another flow, are
/// skipped instead of aborting the lookup. Previously only the very first record was
/// examined.
///
/// # Errors
///
/// - `"missing verification"` if there are no verification records at all;
/// - `"missing verification identifier"` if no record has a string `identifier`;
/// - `"unexpected verification identifier"` if identifiers exist but none starts with
///   `reset-password:`.
pub async fn reset_token(adapter: &MemoryAdapter) -> Result<String, Box<dyn std::error::Error>> {
    const RESET_PREFIX: &str = "reset-password:";

    let records = adapter.records("verification").await;
    if records.is_empty() {
        return Err("missing verification".into());
    }

    // Tracks whether any record had a usable identifier, so the final error can still
    // distinguish "no identifiers" from "identifiers, but none for password reset".
    let mut saw_identifier = false;
    for record in records.iter() {
        if let Some(DbValue::String(identifier)) = record.get("identifier") {
            saw_identifier = true;
            if let Some(token) = identifier.strip_prefix(RESET_PREFIX) {
                return Ok(token.to_owned());
            }
        }
    }

    let message = if saw_identifier {
        "unexpected verification identifier"
    } else {
        "missing verification identifier"
    };
    Err(message.into())
}

/// Returns the decoded value of the first `key` query parameter in the absolute URL `url`.
///
/// Returns `None` when `url` does not parse as an absolute URL or has no such parameter.
pub fn query_value(url: &str, key: &str) -> Option<String> {
    let parsed = Url::parse(url).ok()?;
    let (_, found) = parsed.query_pairs().find(|(name, _)| name == key)?;
    Some(found.into_owned())
}

/// Test stand-in for a [`SocialOAuthProvider`]. Its trait methods return canned URLs,
/// tokens and user info instead of talking to a real provider.
#[derive(Debug)]
pub struct FakeProvider {
    /// Returned by both `id()` and `name()`.
    id: String,
    /// Cloned out by `provider_options()`; `FakeProvider::new` sets it to the default.
    options: ProviderOptions,
}

impl FakeProvider {
    /// Creates a fake provider identified by `id`, with default [`ProviderOptions`].
    pub fn new(id: &str) -> Self {
        let options = ProviderOptions::default();
        Self {
            id: String::from(id),
            options,
        }
    }
}

impl SocialOAuthProvider for FakeProvider {
    /// The id given to [`FakeProvider::new`].
    fn id(&self) -> &str {
        self.id.as_str()
    }

    /// The fake has no separate display name, so it reuses the id.
    fn name(&self) -> &str {
        self.id.as_str()
    }

    /// A clone of the options stored on this provider.
    fn provider_options(&self) -> ProviderOptions {
        let options = &self.options;
        options.clone()
    }

    /// Builds `https://provider.example/authorize?state=..&redirect_uri=..` from `request`.
    fn create_authorization_url(
        &self,
        request: SocialAuthorizationUrlRequest,
    ) -> Result<Url, OAuthError> {
        let mut authorize = Url::parse("https://provider.example/authorize")?;
        {
            // The serializer writes the query back into `authorize` when it is dropped at
            // the end of this scope.
            let mut query = authorize.query_pairs_mut();
            query.append_pair("state", &request.state);
            query.append_pair("redirect_uri", &request.redirect_uri);
        }
        Ok(authorize)
    }

    /// Accepts any authorization code and returns a fixed bearer `access-token`.
    fn validate_authorization_code<'a>(
        &'a self,
        _request: SocialAuthorizationCodeRequest,
    ) -> SocialProviderFuture<'a, OAuth2Tokens> {
        Box::pin(async {
            Ok(OAuth2Tokens {
                token_type: Some(String::from("Bearer")),
                access_token: Some(String::from("access-token")),
                refresh_token: None,
                access_token_expires_at: None,
                refresh_token_expires_at: None,
                scopes: vec![],
                id_token: None,
                raw: Value::Null,
            })
        })
    }

    /// Treats every ID token as valid.
    fn verify_id_token<'a>(
        &'a self,
        _request: SocialIdTokenRequest,
    ) -> SocialProviderFuture<'a, bool> {
        Box::pin(async { Ok(true) })
    }

    /// Always reports the same verified user, `provider-user-1` (Ada, ada@example.com).
    fn get_user_info<'a>(
        &'a self,
        _tokens: OAuth2Tokens,
        _provider_user: Option<Value>,
    ) -> SocialProviderFuture<'a, Option<OAuth2UserInfo>> {
        Box::pin(async {
            let user = OAuth2UserInfo {
                id: String::from("provider-user-1"),
                name: Some(String::from("Ada")),
                email: Some(String::from("ada@example.com")),
                image: None,
                email_verified: true,
            };
            Ok(Some(user))
        })
    }

    /// Refreshes only the literal token `stored-refresh-token`; any other value yields
    /// `OAuthError::InvalidResponse("bad refresh token")`.
    fn refresh_access_token<'a>(
        &'a self,
        refresh_token: String,
    ) -> SocialProviderFuture<'a, OAuth2Tokens> {
        Box::pin(async move {
            if refresh_token == "stored-refresh-token" {
                Ok(OAuth2Tokens {
                    token_type: Some(String::from("Bearer")),
                    access_token: Some(String::from("new-access-token")),
                    refresh_token: Some(String::from("new-refresh-token")),
                    access_token_expires_at: None,
                    refresh_token_expires_at: None,
                    scopes: vec![String::from("read:user")],
                    id_token: Some(String::from("new-id-token")),
                    raw: Value::Null,
                })
            } else {
                Err(OAuthError::InvalidResponse(String::from("bad refresh token")))
            }
        })
    }
}