axum_jwt_sessions/extractors/
refresh.rs1use axum::{
2 extract::FromRequestParts,
3 http::{header, request::Parts},
4};
5use serde::de::DeserializeOwned;
6use uuid::Uuid;
7
8use crate::{
9 error::{AuthError, Result},
10 middleware::AuthState,
11 refresher::SessionDataRefresher,
12 storage::SessionStorage,
13};
14
15pub struct RefreshSession<T> {
24 pub session_id: Uuid,
25 pub data: T,
26}
27
28pub struct OptionalRefreshSession<T>(pub Option<RefreshSession<T>>);
30
31impl<S, R, T> FromRequestParts<AuthState<S, R>> for RefreshSession<T>
32where
33 S: SessionStorage,
34 R: SessionDataRefresher,
35 T: Send + Sync + DeserializeOwned,
36{
37 type Rejection = AuthError;
38
39 async fn from_request_parts(
40 parts: &mut Parts,
41 state: &AuthState<S, R>,
42 ) -> std::result::Result<Self, Self::Rejection> {
43 let refresh_token = if let Some(header_value) = parts.headers.get("X-Refresh-Token") {
45 header_value
46 .to_str()
47 .map_err(|_| AuthError::InvalidAuthHeaderFormat)?
48 .to_string()
49 } else {
50 extract_token_from_authorization_header(parts)?
52 };
53
54 let refresh_claims = state.token_generator.verify_refresh_token(&refresh_token)?;
56
57 let user_id = refresh_claims
59 .user_id
60 .as_ref()
61 .ok_or(AuthError::InvalidRefreshToken)?;
62
63 if !state
64 .storage
65 .user_session_exists(user_id, &refresh_claims.sub)
66 .await?
67 {
68 return Err(AuthError::SessionNotFound);
69 }
70
71 let auth_header = parts
74 .headers
75 .get(header::AUTHORIZATION)
76 .ok_or(AuthError::MissingAuthHeader)?;
77
78 let auth_str = auth_header
79 .to_str()
80 .map_err(|_| AuthError::InvalidAuthHeaderFormat)?;
81
82 let token = auth_str
83 .strip_prefix("Bearer ")
84 .ok_or(AuthError::InvalidAuthHeaderFormat)?;
85
86 let access_claims = state
88 .token_generator
89 .verify_access_token(token)
90 .map_err(|_| AuthError::InvalidToken)?;
91
92 if access_claims.sub != refresh_claims.sub {
94 return Err(AuthError::InvalidToken);
95 }
96
97 let session_data = access_claims
98 .session_data
99 .ok_or(AuthError::SessionNotFound)?;
100
101 let session_data =
102 serde_json::from_value::<T>(session_data).map_err(|_| AuthError::InvalidToken)?;
103
104 Ok(RefreshSession {
105 session_id: refresh_claims.sub,
106 data: session_data,
107 })
108 }
109}
110
111impl<S, R, T> FromRequestParts<AuthState<S, R>> for OptionalRefreshSession<T>
112where
113 S: SessionStorage,
114 R: SessionDataRefresher,
115 T: Send + Sync + DeserializeOwned,
116{
117 type Rejection = AuthError;
118
119 async fn from_request_parts(
120 parts: &mut Parts,
121 state: &AuthState<S, R>,
122 ) -> std::result::Result<Self, Self::Rejection> {
123 match RefreshSession::<T>::from_request_parts(parts, state).await {
125 Ok(session) => Ok(OptionalRefreshSession(Some(session))),
126 Err(_) => Ok(OptionalRefreshSession(None)),
127 }
128 }
129}
130
131fn extract_token_from_authorization_header(parts: &Parts) -> Result<String> {
132 let header_value = parts
133 .headers
134 .get(header::AUTHORIZATION)
135 .ok_or(AuthError::MissingAuthHeader)?;
136
137 let auth_str = header_value
138 .to_str()
139 .map_err(|_| AuthError::InvalidAuthHeaderFormat)?;
140
141 if !auth_str.starts_with("Bearer ") {
142 return Err(AuthError::InvalidAuthHeaderFormat);
143 }
144
145 Ok(auth_str[7..].to_string())
146}