// axum_jwt_sessions/extractors/refresh.rs

1use axum::{
2    extract::FromRequestParts,
3    http::{header, request::Parts},
4};
5use serde::de::DeserializeOwned;
6use uuid::Uuid;
7
8use crate::{
9    error::{AuthError, Result},
10    middleware::AuthState,
11    refresher::SessionDataRefresher,
12    storage::SessionStorage,
13};
14
15/// Extractor that requires a valid refresh token.
16///
17/// This extractor validates the refresh token from either:
18/// - The `X-Refresh-Token` header
19/// - The `Authorization` header with `Bearer` prefix
20///
21/// If a valid refresh token is provided, the session is loaded.
22/// This is useful for sensitive operations that require extra authentication.
23pub struct RefreshSession<T> {
24    pub session_id: Uuid,
25    pub data: T,
26}
27
28/// Optional version of RefreshSession that doesn't fail if no refresh token is provided
29pub struct OptionalRefreshSession<T>(pub Option<RefreshSession<T>>);
30
31impl<S, R, T> FromRequestParts<AuthState<S, R>> for RefreshSession<T>
32where
33    S: SessionStorage,
34    R: SessionDataRefresher,
35    T: Send + Sync + DeserializeOwned,
36{
37    type Rejection = AuthError;
38
39    async fn from_request_parts(
40        parts: &mut Parts,
41        state: &AuthState<S, R>,
42    ) -> std::result::Result<Self, Self::Rejection> {
43        // Extract refresh token
44        let refresh_token = if let Some(header_value) = parts.headers.get("X-Refresh-Token") {
45            header_value
46                .to_str()
47                .map_err(|_| AuthError::InvalidAuthHeaderFormat)?
48                .to_string()
49        } else {
50            // Fall back to Authorization header
51            extract_token_from_authorization_header(parts)?
52        };
53
54        // Verify refresh token
55        let refresh_claims = state.token_generator.verify_refresh_token(&refresh_token)?;
56
57        // Check if refresh token exists in storage (for revocation)
58        let user_id = refresh_claims
59            .user_id
60            .as_ref()
61            .ok_or(AuthError::InvalidRefreshToken)?;
62
63        if !state
64            .storage
65            .user_session_exists(user_id, &refresh_claims.sub)
66            .await?
67        {
68            return Err(AuthError::SessionNotFound);
69        }
70
71        // For refresh token paths, we need to get session data from the access token
72        // Extract access token from Authorization header
73        let auth_header = parts
74            .headers
75            .get(header::AUTHORIZATION)
76            .ok_or(AuthError::MissingAuthHeader)?;
77
78        let auth_str = auth_header
79            .to_str()
80            .map_err(|_| AuthError::InvalidAuthHeaderFormat)?;
81
82        let token = auth_str
83            .strip_prefix("Bearer ")
84            .ok_or(AuthError::InvalidAuthHeaderFormat)?;
85
86        // Verify access token and extract session data
87        let access_claims = state
88            .token_generator
89            .verify_access_token(token)
90            .map_err(|_| AuthError::InvalidToken)?;
91
92        // Verify that the subjects match
93        if access_claims.sub != refresh_claims.sub {
94            return Err(AuthError::InvalidToken);
95        }
96
97        let session_data = access_claims
98            .session_data
99            .ok_or(AuthError::SessionNotFound)?;
100
101        let session_data =
102            serde_json::from_value::<T>(session_data).map_err(|_| AuthError::InvalidToken)?;
103
104        Ok(RefreshSession {
105            session_id: refresh_claims.sub,
106            data: session_data,
107        })
108    }
109}
110
111impl<S, R, T> FromRequestParts<AuthState<S, R>> for OptionalRefreshSession<T>
112where
113    S: SessionStorage,
114    R: SessionDataRefresher,
115    T: Send + Sync + DeserializeOwned,
116{
117    type Rejection = AuthError;
118
119    async fn from_request_parts(
120        parts: &mut Parts,
121        state: &AuthState<S, R>,
122    ) -> std::result::Result<Self, Self::Rejection> {
123        // Try to extract the RefreshSession using the existing implementation
124        match RefreshSession::<T>::from_request_parts(parts, state).await {
125            Ok(session) => Ok(OptionalRefreshSession(Some(session))),
126            Err(_) => Ok(OptionalRefreshSession(None)),
127        }
128    }
129}
130
131fn extract_token_from_authorization_header(parts: &Parts) -> Result<String> {
132    let header_value = parts
133        .headers
134        .get(header::AUTHORIZATION)
135        .ok_or(AuthError::MissingAuthHeader)?;
136
137    let auth_str = header_value
138        .to_str()
139        .map_err(|_| AuthError::InvalidAuthHeaderFormat)?;
140
141    if !auth_str.starts_with("Bearer ") {
142        return Err(AuthError::InvalidAuthHeaderFormat);
143    }
144
145    Ok(auth_str[7..].to_string())
146}