hb_auth/
jwt.rs

1use std::collections::HashMap;
2use std::sync::{OnceLock, RwLock};
3
4use base64::engine::general_purpose::URL_SAFE_NO_PAD;
5use base64::Engine as _;
6use js_sys::{Array, Date, Object, Reflect, Uint8Array};
7use serde::{de::DeserializeOwned, Deserialize, Serialize};
8use wasm_bindgen::JsCast;
9use wasm_bindgen_futures::JsFuture;
10use web_sys::{CryptoKey, SubtleCrypto};
11use worker::{Error, Fetch, Method, Request as WorkerRequest};
12
13#[cfg(feature = "kv")]
14use crate::cache::{get_cached_jwks, set_cached_jwks, CachedJwk};
15use crate::config::AuthConfig;
16
/// Shorthand for the `worker` crate's result type, used by every fallible function in this module.
type WorkerResult<T> = worker::Result<T>;
18
/// Maximum age, in milliseconds, of an in-memory JWKS cache entry before the
/// key set is fetched again (600 000 ms = ten minutes).
const JWKS_CACHE_TTL_MS: f64 = 600_000.0;
20
21#[derive(Clone, Debug, Serialize, Deserialize)]
22pub struct Claims {
23    pub aud: Vec<String>,
24    pub email: String,
25    pub exp: i64,
26    pub iss: String,
27    pub sub: String,
28    pub name: Option<String>,
29    #[serde(default)]
30    pub groups: Vec<String>,
31}
32
/// JSON Web Key Set document, as parsed from the team's `/cdn-cgi/access/certs` response.
#[derive(Clone, Deserialize)]
struct Jwks {
    /// Published public keys; the JWT header's `kid` selects one of them.
    keys: Vec<Jwk>,
}
37
38#[derive(Clone, Deserialize)]
39struct Jwk {
40    kty: String,
41    kid: String,
42    n: String,
43    e: String,
44}
45
/// The subset of the JWT (JOSE) header that this module inspects.
#[derive(Deserialize)]
struct JwtHeader {
    /// Signing algorithm; the verifiers only accept `"RS256"`.
    alg: String,
    /// Id of the JWKS key that signed the token. Required: a header without `kid` fails to parse.
    kid: String,
}
51
/// An entry in the in-memory JWKS cache.
#[derive(Clone)]
struct CachedKeys {
    /// `Date::now()` value (milliseconds since the epoch) at which the keys were fetched.
    fetched_at_ms: f64,
    /// The keys returned by that fetch.
    keys: Vec<Jwk>,
}
57
58static JWKS_CACHE: OnceLock<RwLock<HashMap<String, CachedKeys>>> = OnceLock::new();
59
60fn cache() -> &'static RwLock<HashMap<String, CachedKeys>> {
61    JWKS_CACHE.get_or_init(|| RwLock::new(HashMap::new()))
62}
63
64#[worker::send]
65pub async fn verify_access_jwt(token: &str, config: &AuthConfig) -> WorkerResult<Claims> {
66    let token = token.trim();
67    let token = token.strip_prefix("Bearer ").unwrap_or(token);
68    let (header_b64, payload_b64, signature_b64) = split_jwt(token)?;
69
70    let header: JwtHeader = decode_segment(header_b64)?;
71    if header.alg != "RS256" {
72        return Err(auth_error("unsupported JWT algorithm"));
73    }
74
75    let jwk = find_jwk(config, &header.kid).await?;
76    verify_signature(header_b64, payload_b64, signature_b64, &jwk).await?;
77
78    let claims: Claims = decode_segment(payload_b64)?;
79    validate_claims(&claims, config)?;
80    Ok(claims)
81}
82
83#[cfg(feature = "kv")]
84#[worker::send]
85pub async fn verify_access_jwt_cached(
86    token: &str,
87    config: &AuthConfig,
88    kv: &worker::kv::KvStore,
89) -> WorkerResult<Claims> {
90    let token = token.trim();
91    let token = token.strip_prefix("Bearer ").unwrap_or(token);
92    let (header_b64, payload_b64, signature_b64) = split_jwt(token)?;
93
94    let header: JwtHeader = decode_segment(header_b64)?;
95    if header.alg != "RS256" {
96        return Err(auth_error("unsupported JWT algorithm"));
97    }
98
99    let jwk = find_jwk_cached(config, &header.kid, kv).await?;
100    verify_signature(header_b64, payload_b64, signature_b64, &jwk).await?;
101
102    let claims: Claims = decode_segment(payload_b64)?;
103    validate_claims(&claims, config)?;
104    Ok(claims)
105}
106
107fn validate_claims(claims: &Claims, config: &AuthConfig) -> WorkerResult<()> {
108    let aud_match = claims.aud.iter().any(|aud| aud == &*config.audience);
109    if !aud_match {
110        return Err(auth_error("audience mismatch"));
111    }
112
113    if claims.iss != config.issuer() {
114        return Err(auth_error("issuer mismatch"));
115    }
116
117    let now = Date::now() / 1000.0;
118    if (claims.exp as f64) <= now {
119        return Err(auth_error("token expired"));
120    }
121
122    Ok(())
123}
124
125async fn verify_signature(
126    header_b64: &str,
127    payload_b64: &str,
128    signature_b64: &str,
129    jwk: &Jwk,
130) -> WorkerResult<()> {
131    let crypto = get_subtle_crypto()?;
132    let crypto_key = import_jwk_as_crypto_key(&crypto, jwk).await?;
133
134    let signing_input = format!("{header_b64}.{payload_b64}");
135    let signature_bytes = decode_segment_raw(signature_b64)?;
136
137    let algorithm = Object::new();
138    Reflect::set(&algorithm, &"name".into(), &"RSASSA-PKCS1-v1_5".into())
139        .map_err(|_| auth_error("failed to set algorithm"))?;
140
141    let data = Uint8Array::from(signing_input.as_bytes());
142    let signature = Uint8Array::from(signature_bytes.as_slice());
143
144    let result = JsFuture::from(
145        crypto
146            .verify_with_object_and_buffer_source_and_buffer_source(
147                &algorithm,
148                &crypto_key,
149                &signature,
150                &data,
151            )
152            .map_err(|_| auth_error("verify call failed"))?,
153    )
154    .await
155    .map_err(|_| auth_error("signature verification failed"))?;
156
157    if !result.as_bool().unwrap_or(false) {
158        return Err(auth_error("JWT signature verification failed"));
159    }
160
161    Ok(())
162}
163
164async fn import_jwk_as_crypto_key(crypto: &SubtleCrypto, jwk: &Jwk) -> WorkerResult<CryptoKey> {
165    if jwk.kty != "RSA" {
166        return Err(auth_error("unexpected JWK kty"));
167    }
168
169    let jwk_obj = Object::new();
170    Reflect::set(&jwk_obj, &"kty".into(), &jwk.kty.as_str().into())
171        .map_err(|_| auth_error("failed to set kty"))?;
172    Reflect::set(&jwk_obj, &"n".into(), &jwk.n.as_str().into())
173        .map_err(|_| auth_error("failed to set n"))?;
174    Reflect::set(&jwk_obj, &"e".into(), &jwk.e.as_str().into())
175        .map_err(|_| auth_error("failed to set e"))?;
176    Reflect::set(&jwk_obj, &"alg".into(), &"RS256".into())
177        .map_err(|_| auth_error("failed to set alg"))?;
178
179    let algorithm = Object::new();
180    Reflect::set(&algorithm, &"name".into(), &"RSASSA-PKCS1-v1_5".into())
181        .map_err(|_| auth_error("failed to set algorithm name"))?;
182    Reflect::set(&algorithm, &"hash".into(), &"SHA-256".into())
183        .map_err(|_| auth_error("failed to set hash"))?;
184
185    let key_usages = Array::new();
186    key_usages.push(&"verify".into());
187
188    let promise = crypto
189        .import_key_with_object("jwk", &jwk_obj, &algorithm, false, &key_usages)
190        .map_err(|_| auth_error("import_key call failed"))?;
191
192    JsFuture::from(promise)
193        .await
194        .map_err(|_| auth_error("failed to import JWK"))?
195        .dyn_into::<CryptoKey>()
196        .map_err(|_| auth_error("failed to cast to CryptoKey"))
197}
198
199fn get_subtle_crypto() -> WorkerResult<SubtleCrypto> {
200    let global = js_sys::global();
201    let crypto =
202        Reflect::get(&global, &"crypto".into()).map_err(|_| auth_error("crypto not available"))?;
203    let subtle = Reflect::get(&crypto, &"subtle".into())
204        .map_err(|_| auth_error("subtle crypto not available"))?;
205    subtle
206        .dyn_into::<SubtleCrypto>()
207        .map_err(|_| auth_error("invalid SubtleCrypto"))
208}
209
210#[worker::send]
211async fn find_jwk(config: &AuthConfig, kid: &str) -> WorkerResult<Jwk> {
212    let keys = load_jwks(config).await?;
213    keys.into_iter()
214        .find(|key| key.kid == kid)
215        .ok_or_else(|| auth_error("kid not found in JWKS"))
216}
217
218#[worker::send]
219async fn load_jwks(config: &AuthConfig) -> WorkerResult<Vec<Jwk>> {
220    {
221        let c = cache()
222            .read()
223            .map_err(|_| auth_error("failed to read JWKS cache"))?;
224        if let Some(entry) = c.get(config.team_domain.as_ref()) {
225            if Date::now() - entry.fetched_at_ms <= JWKS_CACHE_TTL_MS {
226                return Ok(entry.keys.clone());
227            }
228        }
229    }
230
231    let url = format!("{}/cdn-cgi/access/certs", config.team_domain.as_ref());
232    let request = WorkerRequest::new(&url, Method::Get)?;
233    let mut resp = Fetch::Request(request).send().await?;
234    let status = resp.status_code();
235    if !(200..=299).contains(&status) {
236        return Err(auth_error(format!(
237            "unable to fetch Access JWKS (status {status})"
238        )));
239    }
240    let body = resp.text().await?;
241    let jwks: Jwks =
242        serde_json::from_str(&body).map_err(|err| auth_error(format!("invalid JWKS: {err}")))?;
243
244    {
245        let mut c = cache()
246            .write()
247            .map_err(|_| auth_error("failed to write JWKS cache"))?;
248        c.insert(
249            config.team_domain.as_ref().clone(),
250            CachedKeys {
251                fetched_at_ms: Date::now(),
252                keys: jwks.keys.clone(),
253            },
254        );
255    }
256
257    Ok(jwks.keys)
258}
259
260#[cfg(feature = "kv")]
261#[worker::send]
262async fn find_jwk_cached(
263    config: &AuthConfig,
264    kid: &str,
265    kv: &worker::kv::KvStore,
266) -> WorkerResult<Jwk> {
267    let keys = load_jwks_cached(config, kv).await?;
268    keys.into_iter()
269        .find(|key| key.kid == kid)
270        .ok_or_else(|| auth_error("kid not found in JWKS"))
271}
272
273#[cfg(feature = "kv")]
274#[worker::send]
275async fn load_jwks_cached(config: &AuthConfig, kv: &worker::kv::KvStore) -> WorkerResult<Vec<Jwk>> {
276    if let Some(cached) = get_cached_jwks(kv, config.team_domain.as_ref()).await {
277        return Ok(cached
278            .keys
279            .into_iter()
280            .map(|k| Jwk {
281                kty: k.kty,
282                kid: k.kid,
283                n: k.n,
284                e: k.e,
285            })
286            .collect());
287    }
288
289    let url = format!("{}/cdn-cgi/access/certs", config.team_domain.as_ref());
290    let request = WorkerRequest::new(&url, Method::Get)?;
291    let mut resp = Fetch::Request(request).send().await?;
292    let status = resp.status_code();
293    if !(200..=299).contains(&status) {
294        return Err(auth_error(format!(
295            "unable to fetch Access JWKS (status {status})"
296        )));
297    }
298    let body = resp.text().await?;
299    let jwks: Jwks =
300        serde_json::from_str(&body).map_err(|err| auth_error(format!("invalid JWKS: {err}")))?;
301
302    let cached_keys: Vec<CachedJwk> = jwks
303        .keys
304        .iter()
305        .map(|k| CachedJwk {
306            kty: k.kty.clone(),
307            kid: k.kid.clone(),
308            n: k.n.clone(),
309            e: k.e.clone(),
310        })
311        .collect();
312
313    if let Err(e) = set_cached_jwks(kv, config.team_domain.as_ref(), cached_keys).await {
314        tracing::warn!("Failed to cache JWKS in KV: {e:?}");
315    }
316
317    Ok(jwks.keys)
318}
319
320fn split_jwt(token: &str) -> WorkerResult<(&str, &str, &str)> {
321    let mut segments = token.split('.');
322    match (
323        segments.next(),
324        segments.next(),
325        segments.next(),
326        segments.next(),
327    ) {
328        (Some(h), Some(p), Some(s), None) => Ok((h, p, s)),
329        _ => Err(auth_error("malformed JWT")),
330    }
331}
332
333fn decode_segment<T>(segment: &str) -> WorkerResult<T>
334where
335    T: DeserializeOwned,
336{
337    let bytes = decode_segment_raw(segment)?;
338    serde_json::from_slice(&bytes).map_err(|err| auth_error(format!("invalid JSON: {err}")))
339}
340
341fn decode_segment_raw(segment: &str) -> WorkerResult<Vec<u8>> {
342    URL_SAFE_NO_PAD
343        .decode(segment.as_bytes())
344        .map_err(|_| auth_error("invalid base64 segment"))
345}
346
347fn auth_error<T: Into<String>>(message: T) -> Error {
348    Error::RustError(format!("auth: {}", message.into()))
349}