// axum_supabase_auth/middleware/extractor.rs

1use super::{AuthState, Claims};
2use crate::AuthTypes;
3use axum::extract::{FromRef, FromRequestParts};
4use axum::http::request::Parts;
5use axum::http::StatusCode;
6use axum::response::{IntoResponse, Response};
7use axum::{async_trait, Json};
8use axum_extra::extract::CookieJar;
9use serde_json::json;
10use std::fmt::Debug;
11use tracing::{trace, warn, Span};
12
13pub type AuthClaims<T> =
14    Claims<<T as AuthTypes>::AppData, <T as AuthTypes>::UserData, <T as AuthTypes>::AdditionalData>;
15
16pub struct User<T: AuthTypes>(pub AuthClaims<T>);
17pub struct MaybeUser<T: AuthTypes>(pub Option<AuthClaims<T>>);
18
/// Cookie name for the auth (access) token.
///
/// NOTE(review): the `MaybeUser` extractor in this file looks the cookie up via
/// `state.cookies().auth_cookie_name()` rather than this constant; presumably this is
/// the default value of that setting — confirm.
pub const AUTH_COOKIE_NAME: &str = "portal-auth";
/// Cookie name for the refresh token (not read anywhere in this file).
pub const REFRESH_COOKIE_NAME: &str = "portal-refresh";
/// Cookie name for the token code verifier (not read anywhere in this file).
///
/// NOTE(review): presumably the PKCE code verifier for the OAuth code flow — confirm.
pub const CSRF_VERIFIER_COOKIE_NAME: &str = "portal-token-code-verifier";
22
23#[async_trait]
24impl<S, T> FromRequestParts<S> for User<T>
25where
26    S: Send + Sync,
27    T: AuthTypes,
28    MaybeUser<T>: FromRequestParts<S, Rejection = AuthError>,
29{
30    type Rejection = AuthError;
31
32    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
33        let user = MaybeUser::from_request_parts(parts, state).await?;
34        if let Some(user) = user.0 {
35            Ok(User(user))
36        } else {
37            Err(AuthError::MissingCredentials)
38        }
39    }
40}
41
42#[async_trait]
43impl<S, T> FromRequestParts<S> for MaybeUser<T>
44where
45    S: Send + Sync,
46    T: AuthTypes,
47    AuthState<T>: FromRef<S>,
48{
49    type Rejection = AuthError;
50
51    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
52        let jar = match CookieJar::from_request_parts(parts, state).await {
53            Ok(jar) => jar,
54            Err(err) => match err {},
55        };
56
57        let state = AuthState::<T>::from_ref(state);
58
59        let token = jar.get(state.cookies().auth_cookie_name());
60        let token = match token {
61            Some(token) => token,
62            None => {
63                trace!("no auth cookie found");
64                return Ok(MaybeUser(None));
65            }
66        };
67
68        let claims = state.decode(token.value_trimmed()).map_err(|error| {
69            warn!(error = ?error, "invalid token");
70            AuthError::InvalidToken
71        })?;
72
73        trace!(claims = ?claims, "extracted user from cookie");
74        Span::current().record("user_id", &claims.sub);
75
76        Ok(MaybeUser(Some(claims)))
77    }
78}
79
/// Errors produced by the authentication extractors and token handling.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`]. This lets it be boxed as
/// `dyn Error` or propagated with `?` into generic error types. The display text is the
/// same message the HTTP response reports in its `"error"` field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthError {
    /// Credentials were supplied but rejected.
    WrongCredentials,
    /// A request that requires a user carried no credentials (no auth cookie).
    MissingCredentials,
    /// A token could not be created.
    TokenCreation,
    /// An auth token was present but could not be decoded or validated.
    InvalidToken,
}

impl std::fmt::Display for AuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::WrongCredentials => "Wrong credentials",
            Self::MissingCredentials => "Missing credentials",
            Self::TokenCreation => "Token creation error",
            Self::InvalidToken => "Invalid token",
        };
        f.write_str(message)
    }
}

// No underlying source error: every variant is a leaf condition.
impl std::error::Error for AuthError {}
88
89impl IntoResponse for AuthError {
90    fn into_response(self) -> Response {
91        let (status, error_message) = match self {
92            AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
93            AuthError::MissingCredentials => (StatusCode::BAD_REQUEST, "Missing credentials"),
94            AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"),
95            AuthError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"),
96        };
97        let body = Json(json!({ "error": error_message }));
98        (status, body).into_response()
99    }
100}