hb_auth/
extractor.rs

1#[cfg(feature = "axum")]
2use axum::{
3    extract::FromRequestParts,
4    http::{header::COOKIE, request::Parts, StatusCode},
5};
6use tracing::error;
7
8use crate::{
9    config::AuthConfig,
10    jwt::{verify_access_jwt, Claims},
11};
12
13#[cfg(feature = "kv")]
14use crate::jwt::verify_access_jwt_cached;
15
16pub trait HasAuthConfig {
17    fn auth_config(&self) -> &AuthConfig;
18}
19
/// Application state that may supply a Workers KV store for the cached token verifier.
#[cfg(feature = "kv")]
pub trait HasJwksCache {
    /// The KV store handed to `verify_access_jwt_cached`.
    ///
    /// Returning `None` makes the axum extractor fall back to the uncached
    /// `verify_access_jwt`.
    fn jwks_kv(&self) -> Option<worker::kv::KvStore>;
}
24
25pub trait RoleMapper: Sized + Send + Sync + 'static {
26    fn from_claims(claims: &Claims) -> Vec<Self>;
27}
28
29impl RoleMapper for () {
30    fn from_claims(_: &Claims) -> Vec<Self> {
31        vec![]
32    }
33}
34
35#[derive(Debug, Clone)]
36pub struct User<R: RoleMapper = ()> {
37    pub claims: Claims,
38    pub roles: Vec<R>,
39    pub token: String,
40}
41
42impl<R: RoleMapper> User<R> {
43    pub fn has_role(&self, role: R) -> bool
44    where
45        R: PartialEq,
46    {
47        self.roles.contains(&role)
48    }
49
50    // Helper accessors
51    pub fn email(&self) -> &str {
52        &self.claims.email
53    }
54
55    pub fn sub(&self) -> &str {
56        &self.claims.sub
57    }
58
59    pub async fn from_worker_request(
60        req: &worker::Request,
61        config: &AuthConfig,
62    ) -> Result<Self, String> {
63        let token = extract_token_worker(req)
64            .or_else(|| extract_token_from_cookies_worker(req))
65            .ok_or_else(|| "missing access token".to_string())?;
66
67        let claims = verify_access_jwt(&token, config).await.map_err(|err| {
68            error!("JWT verification failed: {err:?}");
69            "invalid or expired token".to_string()
70        })?;
71
72        let roles = R::from_claims(&claims);
73
74        Ok(User {
75            claims,
76            roles,
77            token,
78        })
79    }
80
81    #[cfg(feature = "kv")]
82    pub async fn from_worker_request_cached(
83        req: &worker::Request,
84        config: &AuthConfig,
85        kv: &worker::kv::KvStore,
86    ) -> Result<Self, String> {
87        let token = extract_token_worker(req)
88            .or_else(|| extract_token_from_cookies_worker(req))
89            .ok_or_else(|| "missing access token".to_string())?;
90
91        let claims = verify_access_jwt_cached(&token, config, kv)
92            .await
93            .map_err(|err| {
94                error!("JWT verification failed: {err:?}");
95                "invalid or expired token".to_string()
96            })?;
97
98        let roles = R::from_claims(&claims);
99
100        Ok(User {
101            claims,
102            roles,
103            token,
104        })
105    }
106}
107
/// Axum extractor that rejects with `401 Unauthorized` unless the request carries a valid
/// access token (headers first, then the `CF_Authorization` cookie).
///
/// Compiled when the `kv` feature is disabled; verification always uses `verify_access_jwt`.
#[cfg(all(feature = "axum", not(feature = "kv")))]
impl<S, R> FromRequestParts<S> for User<R>
where
    S: HasAuthConfig + Send + Sync,
    R: RoleMapper,
{
    type Rejection = (StatusCode, String);

    #[worker::send]
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let token = extract_token(&parts.headers)
            .or_else(|| extract_token_from_cookies(&parts.headers))
            // Build the rejection lazily so the success path does not allocate it.
            .ok_or_else(|| (StatusCode::UNAUTHORIZED, "missing access token".to_string()))?;

        let config = state.auth_config();

        let claims = verify_access_jwt(&token, config).await.map_err(|err| {
            // Log the real cause; the client only sees a generic message.
            error!("JWT verification failed: {err:?}");
            (
                StatusCode::UNAUTHORIZED,
                "invalid or expired token".to_string(),
            )
        })?;

        let roles = R::from_claims(&claims);

        Ok(User {
            claims,
            roles,
            token,
        })
    }
}
141
/// Axum extractor that rejects with `401 Unauthorized` unless the request carries a valid
/// access token (headers first, then the `CF_Authorization` cookie).
///
/// Compiled when the `kv` feature is enabled. It verifies with `verify_access_jwt_cached`
/// when the state supplies a KV store, and with `verify_access_jwt` otherwise.
#[cfg(all(feature = "axum", feature = "kv"))]
impl<S, R> FromRequestParts<S> for User<R>
where
    S: HasAuthConfig + HasJwksCache + Send + Sync,
    R: RoleMapper,
{
    type Rejection = (StatusCode, String);

    #[worker::send]
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let token = extract_token(&parts.headers)
            .or_else(|| extract_token_from_cookies(&parts.headers))
            // Build the rejection lazily so the success path does not allocate it.
            .ok_or_else(|| (StatusCode::UNAUTHORIZED, "missing access token".to_string()))?;

        let config = state.auth_config();
        let kv = state.jwks_kv();

        let verified = match kv.as_ref() {
            Some(store) => verify_access_jwt_cached(&token, config, store).await,
            None => verify_access_jwt(&token, config).await,
        };

        let claims = verified.map_err(|err| {
            // Log the real cause; the client only sees a generic message.
            error!("JWT verification failed: {err:?}");
            (
                StatusCode::UNAUTHORIZED,
                "invalid or expired token".to_string(),
            )
        })?;

        let roles = R::from_claims(&claims);

        Ok(User {
            claims,
            roles,
            token,
        })
    }
}
180
/// Reads the access token from axum request headers.
///
/// Checks `CF_Authorization`, then `Cf-Access-Jwt-Assertion`, and returns the first value
/// that converts with `HeaderValue::to_str` and is non-empty. Returns `None` when neither
/// header yields a usable value, so the caller can fall back to the cookie.
#[cfg(feature = "axum")]
fn extract_token(headers: &axum::http::HeaderMap) -> Option<String> {
    ["CF_Authorization", "Cf-Access-Jwt-Assertion"]
        .iter()
        // A present-but-unusable first header must not hide a valid second one.
        .filter_map(|name| headers.get(*name))
        .filter_map(|value| value.to_str().ok())
        .find(|token| !token.is_empty())
        .map(str::to_owned)
}
189
/// Reads the access token from the `CF_Authorization` cookie of an axum request.
///
/// Every `Cookie` header field is scanned, because HTTP/2 allows the cookie list to be
/// split across several fields. Each field is split on `;` into `name=value` pairs, and
/// the value of the first non-empty `CF_Authorization` cookie is returned. Returns `None`
/// if no such cookie exists.
#[cfg(feature = "axum")]
fn extract_token_from_cookies(headers: &axum::http::HeaderMap) -> Option<String> {
    headers
        .get_all(COOKIE)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|cookie_header| cookie_header.split(';'))
        .filter_map(|pair| pair.trim().split_once('='))
        // Skip empty values so a blank entry cannot shadow a real token.
        .find(|&(name, value)| name == "CF_Authorization" && !value.is_empty())
        .map(|(_, value)| value.to_string())
}
208
209fn extract_token_worker(req: &worker::Request) -> Option<String> {
210    let headers = req.headers();
211    headers
212        .get("CF_Authorization")
213        .ok()
214        .flatten()
215        .or_else(|| headers.get("Cf-Access-Jwt-Assertion").ok().flatten())
216}
217
218fn extract_token_from_cookies_worker(req: &worker::Request) -> Option<String> {
219    req.headers()
220        .get("Cookie")
221        .ok()
222        .flatten()
223        .and_then(|cookie_header| {
224            cookie_header
225                .split(';')
226                .map(|kv| kv.trim())
227                .find_map(|pair| {
228                    let mut parts = pair.splitn(2, '=');
229                    match (parts.next(), parts.next()) {
230                        (Some("CF_Authorization"), Some(token)) => Some(token.to_string()),
231                        _ => None,
232                    }
233                })
234        })
235}