1use std::collections::HashMap;
2use std::sync::{OnceLock, RwLock};
3
4use base64::engine::general_purpose::URL_SAFE_NO_PAD;
5use base64::Engine as _;
6use js_sys::{Array, Date, Object, Reflect, Uint8Array};
7use serde::{de::DeserializeOwned, Deserialize, Serialize};
8use wasm_bindgen::JsCast;
9use wasm_bindgen_futures::JsFuture;
10use web_sys::{CryptoKey, SubtleCrypto};
11use worker::{Error, Fetch, Method, Request as WorkerRequest};
12
13#[cfg(feature = "kv")]
14use crate::cache::{get_cached_jwks, set_cached_jwks, CachedJwk};
15use crate::config::AuthConfig;
16
17type WorkerResult<T> = worker::Result<T>;
18
19const JWKS_CACHE_TTL_MS: f64 = 10.0 * 60.0 * 1000.0; #[derive(Clone, Debug, Serialize, Deserialize)]
22pub struct Claims {
23 pub aud: Vec<String>,
24 pub email: String,
25 pub exp: i64,
26 pub iss: String,
27 pub sub: String,
28 pub name: Option<String>,
29 #[serde(default)]
30 pub groups: Vec<String>,
31}
32
33#[derive(Clone, Deserialize)]
34struct Jwks {
35 keys: Vec<Jwk>,
36}
37
38#[derive(Clone, Deserialize)]
39struct Jwk {
40 kty: String,
41 kid: String,
42 n: String,
43 e: String,
44}
45
46#[derive(Deserialize)]
47struct JwtHeader {
48 alg: String,
49 kid: String,
50}
51
52#[derive(Clone)]
53struct CachedKeys {
54 fetched_at_ms: f64,
55 keys: Vec<Jwk>,
56}
57
58static JWKS_CACHE: OnceLock<RwLock<HashMap<String, CachedKeys>>> = OnceLock::new();
59
60fn cache() -> &'static RwLock<HashMap<String, CachedKeys>> {
61 JWKS_CACHE.get_or_init(|| RwLock::new(HashMap::new()))
62}
63
64#[worker::send]
65pub async fn verify_access_jwt(token: &str, config: &AuthConfig) -> WorkerResult<Claims> {
66 let token = token.trim();
67 let token = token.strip_prefix("Bearer ").unwrap_or(token);
68 let (header_b64, payload_b64, signature_b64) = split_jwt(token)?;
69
70 let header: JwtHeader = decode_segment(header_b64)?;
71 if header.alg != "RS256" {
72 return Err(auth_error("unsupported JWT algorithm"));
73 }
74
75 let jwk = find_jwk(config, &header.kid).await?;
76 verify_signature(header_b64, payload_b64, signature_b64, &jwk).await?;
77
78 let claims: Claims = decode_segment(payload_b64)?;
79 validate_claims(&claims, config)?;
80 Ok(claims)
81}
82
/// Same as `verify_access_jwt`, but the JWKS is resolved through the KV-backed cache
/// (`get_cached_jwks` / `set_cached_jwks`) instead of the in-memory one.
///
/// # Errors
/// Returns an `auth: …` error for malformed tokens, unsupported algorithms, unknown keys,
/// bad signatures, or failed claim checks. Fetch errors are propagated when the JWKS has
/// to be downloaded.
#[cfg(feature = "kv")]
#[worker::send]
pub async fn verify_access_jwt_cached(
    token: &str,
    config: &AuthConfig,
    kv: &worker::kv::KvStore,
) -> WorkerResult<Claims> {
    let raw = token.trim();
    let raw = raw.strip_prefix("Bearer ").unwrap_or(raw);
    let (header_part, payload_part, signature_part) = split_jwt(raw)?;

    let JwtHeader { alg, kid } = decode_segment(header_part)?;
    if alg != "RS256" {
        return Err(auth_error("unsupported JWT algorithm"));
    }

    let signing_key = find_jwk_cached(config, &kid, kv).await?;
    verify_signature(header_part, payload_part, signature_part, &signing_key).await?;

    let claims = decode_segment::<Claims>(payload_part)?;
    validate_claims(&claims, config).map(|()| claims)
}
106
107fn validate_claims(claims: &Claims, config: &AuthConfig) -> WorkerResult<()> {
108 let aud_match = claims.aud.iter().any(|aud| aud == &*config.audience);
109 if !aud_match {
110 return Err(auth_error("audience mismatch"));
111 }
112
113 if claims.iss != config.issuer() {
114 return Err(auth_error("issuer mismatch"));
115 }
116
117 let now = Date::now() / 1000.0;
118 if (claims.exp as f64) <= now {
119 return Err(auth_error("token expired"));
120 }
121
122 Ok(())
123}
124
125async fn verify_signature(
126 header_b64: &str,
127 payload_b64: &str,
128 signature_b64: &str,
129 jwk: &Jwk,
130) -> WorkerResult<()> {
131 let crypto = get_subtle_crypto()?;
132 let crypto_key = import_jwk_as_crypto_key(&crypto, jwk).await?;
133
134 let signing_input = format!("{header_b64}.{payload_b64}");
135 let signature_bytes = decode_segment_raw(signature_b64)?;
136
137 let algorithm = Object::new();
138 Reflect::set(&algorithm, &"name".into(), &"RSASSA-PKCS1-v1_5".into())
139 .map_err(|_| auth_error("failed to set algorithm"))?;
140
141 let data = Uint8Array::from(signing_input.as_bytes());
142 let signature = Uint8Array::from(signature_bytes.as_slice());
143
144 let result = JsFuture::from(
145 crypto
146 .verify_with_object_and_buffer_source_and_buffer_source(
147 &algorithm,
148 &crypto_key,
149 &signature,
150 &data,
151 )
152 .map_err(|_| auth_error("verify call failed"))?,
153 )
154 .await
155 .map_err(|_| auth_error("signature verification failed"))?;
156
157 if !result.as_bool().unwrap_or(false) {
158 return Err(auth_error("JWT signature verification failed"));
159 }
160
161 Ok(())
162}
163
164async fn import_jwk_as_crypto_key(crypto: &SubtleCrypto, jwk: &Jwk) -> WorkerResult<CryptoKey> {
165 if jwk.kty != "RSA" {
166 return Err(auth_error("unexpected JWK kty"));
167 }
168
169 let jwk_obj = Object::new();
170 Reflect::set(&jwk_obj, &"kty".into(), &jwk.kty.as_str().into())
171 .map_err(|_| auth_error("failed to set kty"))?;
172 Reflect::set(&jwk_obj, &"n".into(), &jwk.n.as_str().into())
173 .map_err(|_| auth_error("failed to set n"))?;
174 Reflect::set(&jwk_obj, &"e".into(), &jwk.e.as_str().into())
175 .map_err(|_| auth_error("failed to set e"))?;
176 Reflect::set(&jwk_obj, &"alg".into(), &"RS256".into())
177 .map_err(|_| auth_error("failed to set alg"))?;
178
179 let algorithm = Object::new();
180 Reflect::set(&algorithm, &"name".into(), &"RSASSA-PKCS1-v1_5".into())
181 .map_err(|_| auth_error("failed to set algorithm name"))?;
182 Reflect::set(&algorithm, &"hash".into(), &"SHA-256".into())
183 .map_err(|_| auth_error("failed to set hash"))?;
184
185 let key_usages = Array::new();
186 key_usages.push(&"verify".into());
187
188 let promise = crypto
189 .import_key_with_object("jwk", &jwk_obj, &algorithm, false, &key_usages)
190 .map_err(|_| auth_error("import_key call failed"))?;
191
192 JsFuture::from(promise)
193 .await
194 .map_err(|_| auth_error("failed to import JWK"))?
195 .dyn_into::<CryptoKey>()
196 .map_err(|_| auth_error("failed to cast to CryptoKey"))
197}
198
199fn get_subtle_crypto() -> WorkerResult<SubtleCrypto> {
200 let global = js_sys::global();
201 let crypto =
202 Reflect::get(&global, &"crypto".into()).map_err(|_| auth_error("crypto not available"))?;
203 let subtle = Reflect::get(&crypto, &"subtle".into())
204 .map_err(|_| auth_error("subtle crypto not available"))?;
205 subtle
206 .dyn_into::<SubtleCrypto>()
207 .map_err(|_| auth_error("invalid SubtleCrypto"))
208}
209
210#[worker::send]
211async fn find_jwk(config: &AuthConfig, kid: &str) -> WorkerResult<Jwk> {
212 let keys = load_jwks(config).await?;
213 keys.into_iter()
214 .find(|key| key.kid == kid)
215 .ok_or_else(|| auth_error("kid not found in JWKS"))
216}
217
218#[worker::send]
219async fn load_jwks(config: &AuthConfig) -> WorkerResult<Vec<Jwk>> {
220 {
221 let c = cache()
222 .read()
223 .map_err(|_| auth_error("failed to read JWKS cache"))?;
224 if let Some(entry) = c.get(config.team_domain.as_ref()) {
225 if Date::now() - entry.fetched_at_ms <= JWKS_CACHE_TTL_MS {
226 return Ok(entry.keys.clone());
227 }
228 }
229 }
230
231 let url = format!("{}/cdn-cgi/access/certs", config.team_domain.as_ref());
232 let request = WorkerRequest::new(&url, Method::Get)?;
233 let mut resp = Fetch::Request(request).send().await?;
234 let status = resp.status_code();
235 if !(200..=299).contains(&status) {
236 return Err(auth_error(format!(
237 "unable to fetch Access JWKS (status {status})"
238 )));
239 }
240 let body = resp.text().await?;
241 let jwks: Jwks =
242 serde_json::from_str(&body).map_err(|err| auth_error(format!("invalid JWKS: {err}")))?;
243
244 {
245 let mut c = cache()
246 .write()
247 .map_err(|_| auth_error("failed to write JWKS cache"))?;
248 c.insert(
249 config.team_domain.as_ref().clone(),
250 CachedKeys {
251 fetched_at_ms: Date::now(),
252 keys: jwks.keys.clone(),
253 },
254 );
255 }
256
257 Ok(jwks.keys)
258}
259
/// Looks up the JWK with the given `kid` in the KV-cached JWKS.
///
/// # Errors
/// Propagates errors from `load_jwks_cached`. Returns `auth: kid not found in JWKS` when
/// no key matches.
#[cfg(feature = "kv")]
#[worker::send]
async fn find_jwk_cached(
    config: &AuthConfig,
    kid: &str,
    kv: &worker::kv::KvStore,
) -> WorkerResult<Jwk> {
    load_jwks_cached(config, kv)
        .await?
        .into_iter()
        .find(|candidate| candidate.kid == kid)
        .ok_or_else(|| auth_error("kid not found in JWKS"))
}
272
/// Returns the JWKS for `config.team_domain`, backed by KV instead of process memory.
///
/// A KV hit from `get_cached_jwks` is converted back into `Jwk`s and returned. On a miss,
/// the certs endpoint is fetched and parsed, and the keys are written back with
/// `set_cached_jwks`. A failure to write back is only logged; the freshly fetched keys
/// are still returned.
///
/// # Errors
/// Returns an `auth: …` error on a non-2xx response or an unparseable body. Request/fetch
/// errors from the runtime are propagated.
#[cfg(feature = "kv")]
#[worker::send]
async fn load_jwks_cached(config: &AuthConfig, kv: &worker::kv::KvStore) -> WorkerResult<Vec<Jwk>> {
    if let Some(hit) = get_cached_jwks(kv, config.team_domain.as_ref()).await {
        let keys = hit
            .keys
            .into_iter()
            .map(|stored| Jwk {
                kty: stored.kty,
                kid: stored.kid,
                n: stored.n,
                e: stored.e,
            })
            .collect();
        return Ok(keys);
    }

    let certs_url = format!("{}/cdn-cgi/access/certs", config.team_domain.as_ref());
    let mut response = Fetch::Request(WorkerRequest::new(&certs_url, Method::Get)?)
        .send()
        .await?;
    let status = response.status_code();
    if !matches!(status, 200..=299) {
        return Err(auth_error(format!(
            "unable to fetch Access JWKS (status {status})"
        )));
    }
    let body = response.text().await?;
    let jwks: Jwks =
        serde_json::from_str(&body).map_err(|err| auth_error(format!("invalid JWKS: {err}")))?;

    let to_store = jwks
        .keys
        .iter()
        .map(|key| CachedJwk {
            kty: key.kty.clone(),
            kid: key.kid.clone(),
            n: key.n.clone(),
            e: key.e.clone(),
        })
        .collect::<Vec<CachedJwk>>();

    // Best effort: a KV write failure must not fail verification.
    if let Err(e) = set_cached_jwks(kv, config.team_domain.as_ref(), to_store).await {
        tracing::warn!("Failed to cache JWKS in KV: {e:?}");
    }

    Ok(jwks.keys)
}
319
320fn split_jwt(token: &str) -> WorkerResult<(&str, &str, &str)> {
321 let mut segments = token.split('.');
322 match (
323 segments.next(),
324 segments.next(),
325 segments.next(),
326 segments.next(),
327 ) {
328 (Some(h), Some(p), Some(s), None) => Ok((h, p, s)),
329 _ => Err(auth_error("malformed JWT")),
330 }
331}
332
333fn decode_segment<T>(segment: &str) -> WorkerResult<T>
334where
335 T: DeserializeOwned,
336{
337 let bytes = decode_segment_raw(segment)?;
338 serde_json::from_slice(&bytes).map_err(|err| auth_error(format!("invalid JSON: {err}")))
339}
340
341fn decode_segment_raw(segment: &str) -> WorkerResult<Vec<u8>> {
342 URL_SAFE_NO_PAD
343 .decode(segment.as_bytes())
344 .map_err(|_| auth_error("invalid base64 segment"))
345}
346
347fn auth_error<T: Into<String>>(message: T) -> Error {
348 Error::RustError(format!("auth: {}", message.into()))
349}