1#[cfg(feature = "axum")]
2use axum::{
3 extract::FromRequestParts,
4 http::{header::COOKIE, request::Parts, StatusCode},
5};
6use tracing::error;
7
8use crate::{
9 config::AuthConfig,
10 jwt::{verify_access_jwt, Claims},
11};
12
13#[cfg(feature = "kv")]
14use crate::jwt::verify_access_jwt_cached;
15
16pub trait HasAuthConfig {
17 fn auth_config(&self) -> &AuthConfig;
18}
19
/// Implemented by application state that may provide a KV store for the
/// cached JWT verification path (`verify_access_jwt_cached`).
///
/// Only compiled when the `kv` feature is enabled.
#[cfg(feature = "kv")]
pub trait HasJwksCache {
    /// Returns the KV store to use for cached verification, or `None` when
    /// no store is available.
    fn jwks_kv(&self) -> Option<worker::kv::KvStore>;
}
24
25pub trait RoleMapper: Sized + Send + Sync + 'static {
26 fn from_claims(claims: &Claims) -> Vec<Self>;
27}
28
29impl RoleMapper for () {
30 fn from_claims(_: &Claims) -> Vec<Self> {
31 vec![]
32 }
33}
34
35#[derive(Debug, Clone)]
36pub struct User<R: RoleMapper = ()> {
37 pub claims: Claims,
38 pub roles: Vec<R>,
39 pub token: String,
40}
41
42impl<R: RoleMapper> User<R> {
43 pub fn has_role(&self, role: R) -> bool
44 where
45 R: PartialEq,
46 {
47 self.roles.contains(&role)
48 }
49
50 pub fn email(&self) -> &str {
52 &self.claims.email
53 }
54
55 pub fn sub(&self) -> &str {
56 &self.claims.sub
57 }
58
59 pub async fn from_worker_request(
60 req: &worker::Request,
61 config: &AuthConfig,
62 ) -> Result<Self, String> {
63 let token = extract_token_worker(req)
64 .or_else(|| extract_token_from_cookies_worker(req))
65 .ok_or_else(|| "missing access token".to_string())?;
66
67 let claims = verify_access_jwt(&token, config).await.map_err(|err| {
68 error!("JWT verification failed: {err:?}");
69 "invalid or expired token".to_string()
70 })?;
71
72 let roles = R::from_claims(&claims);
73
74 Ok(User {
75 claims,
76 roles,
77 token,
78 })
79 }
80
81 #[cfg(feature = "kv")]
82 pub async fn from_worker_request_cached(
83 req: &worker::Request,
84 config: &AuthConfig,
85 kv: &worker::kv::KvStore,
86 ) -> Result<Self, String> {
87 let token = extract_token_worker(req)
88 .or_else(|| extract_token_from_cookies_worker(req))
89 .ok_or_else(|| "missing access token".to_string())?;
90
91 let claims = verify_access_jwt_cached(&token, config, kv)
92 .await
93 .map_err(|err| {
94 error!("JWT verification failed: {err:?}");
95 "invalid or expired token".to_string()
96 })?;
97
98 let roles = R::from_claims(&claims);
99
100 Ok(User {
101 claims,
102 roles,
103 token,
104 })
105 }
106}
107
/// Axum extractor for [`User`], compiled when the `axum` feature is on and
/// the `kv` feature is off.
///
/// The token is taken from the request headers, falling back to the `Cookie`
/// header, and verified with `verify_access_jwt` using the state's
/// [`AuthConfig`].
///
/// Rejections are always `401 Unauthorized`, with the body
/// `"missing access token"` or `"invalid or expired token"`.
#[cfg(all(feature = "axum", not(feature = "kv")))]
impl<S, R> FromRequestParts<S> for User<R>
where
    S: HasAuthConfig + Send + Sync,
    R: RoleMapper,
{
    type Rejection = (StatusCode, String);

    #[worker::send]
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let token = extract_token(&parts.headers)
            .or_else(|| extract_token_from_cookies(&parts.headers))
            // Build the rejection lazily so the success path does not
            // allocate an error string on every request.
            .ok_or_else(|| (StatusCode::UNAUTHORIZED, "missing access token".to_string()))?;

        let config = state.auth_config();

        let claims = verify_access_jwt(&token, config).await.map_err(|err| {
            // The detailed error is only logged; the client gets a generic message.
            error!("JWT verification failed: {err:?}");
            (
                StatusCode::UNAUTHORIZED,
                "invalid or expired token".to_string(),
            )
        })?;

        let roles = R::from_claims(&claims);

        Ok(User {
            claims,
            roles,
            token,
        })
    }
}
141
/// Axum extractor for [`User`], compiled when both the `axum` and `kv`
/// features are on.
///
/// The token is taken from the request headers, falling back to the `Cookie`
/// header. It is verified with `verify_access_jwt_cached` when the state
/// provides a KV store, and with `verify_access_jwt` otherwise.
///
/// Rejections are always `401 Unauthorized`, with the body
/// `"missing access token"` or `"invalid or expired token"`.
#[cfg(all(feature = "axum", feature = "kv"))]
impl<S, R> FromRequestParts<S> for User<R>
where
    S: HasAuthConfig + HasJwksCache + Send + Sync,
    R: RoleMapper,
{
    type Rejection = (StatusCode, String);

    #[worker::send]
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let token = extract_token(&parts.headers)
            .or_else(|| extract_token_from_cookies(&parts.headers))
            // Build the rejection lazily so the success path does not
            // allocate an error string on every request.
            .ok_or_else(|| (StatusCode::UNAUTHORIZED, "missing access token".to_string()))?;

        let config = state.auth_config();

        // Use the cached verifier only when the state supplies a KV store.
        let verified = match state.jwks_kv() {
            Some(ref kv) => verify_access_jwt_cached(&token, config, kv).await,
            None => verify_access_jwt(&token, config).await,
        };

        let claims = verified.map_err(|err| {
            // The detailed error is only logged; the client gets a generic message.
            error!("JWT verification failed: {err:?}");
            (
                StatusCode::UNAUTHORIZED,
                "invalid or expired token".to_string(),
            )
        })?;

        let roles = R::from_claims(&claims);

        Ok(User {
            claims,
            roles,
            token,
        })
    }
}
180
/// Reads the access token from the request headers.
///
/// The `CF_Authorization` header is preferred; `Cf-Access-Jwt-Assertion` is
/// consulted only when the first is absent. Returns `None` when neither is
/// present or when the selected value cannot be read as a string.
#[cfg(feature = "axum")]
fn extract_token(headers: &axum::http::HeaderMap) -> Option<String> {
    let raw = match headers.get("CF_Authorization") {
        Some(value) => value,
        None => headers.get("Cf-Access-Jwt-Assertion")?,
    };

    // A value that fails `to_str` yields `None`; there is no retry with the
    // other header.
    let text = raw.to_str().ok()?;
    Some(text.to_owned())
}
189
/// Reads the `CF_Authorization` cookie from the `Cookie` request header.
///
/// The header is split on `;`, each piece is trimmed, and the first
/// `name=value` pair whose name is exactly `CF_Authorization` wins. Returns
/// `None` when the header is missing, cannot be read as a string, or holds no
/// such cookie.
#[cfg(feature = "axum")]
fn extract_token_from_cookies(headers: &axum::http::HeaderMap) -> Option<String> {
    let cookie_header = headers.get(COOKIE)?.to_str().ok()?;

    for pair in cookie_header.split(';') {
        // Pieces without an `=` are not name/value pairs and are skipped.
        if let Some((name, value)) = pair.trim().split_once('=') {
            if name == "CF_Authorization" {
                return Some(value.to_string());
            }
        }
    }

    None
}
208
209fn extract_token_worker(req: &worker::Request) -> Option<String> {
210 let headers = req.headers();
211 headers
212 .get("CF_Authorization")
213 .ok()
214 .flatten()
215 .or_else(|| headers.get("Cf-Access-Jwt-Assertion").ok().flatten())
216}
217
218fn extract_token_from_cookies_worker(req: &worker::Request) -> Option<String> {
219 req.headers()
220 .get("Cookie")
221 .ok()
222 .flatten()
223 .and_then(|cookie_header| {
224 cookie_header
225 .split(';')
226 .map(|kv| kv.trim())
227 .find_map(|pair| {
228 let mut parts = pair.splitn(2, '=');
229 match (parts.next(), parts.next()) {
230 (Some("CF_Authorization"), Some(token)) => Some(token.to_string()),
231 _ => None,
232 }
233 })
234 })
235}