1use std::collections::HashMap;
5use std::future::Future;
6use std::sync::{OnceLock, RwLock};
7use std::time::{Duration as StdDuration, Instant};
8
9use jsonwebtoken::jwk::JwkSet;
10use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
11use serde::de::DeserializeOwned;
12use serde_json::Value as JsonValue;
13
14use super::{hmac, Connector, ConnectorError};
15
/// Default lifetime of a JWKS fetched from a URL before it is downloaded again (24 hours).
const DEFAULT_JWKS_CACHE_TTL: StdDuration = StdDuration::from_secs(24 * 60 * 60);
17
18pub trait ConnectorBase: Connector {}
23
24impl<T: Connector + ?Sized> ConnectorBase for T {}
25
26#[derive(Clone, Copy, Debug, PartialEq, Eq)]
27pub enum HmacSignatureAlgorithm {
28 LegacySha1,
29 Sha256,
30}
31
32impl HmacSignatureAlgorithm {
33 pub fn parse(raw: &str) -> Result<Self, ConnectorError> {
34 Self::parse_with_legacy_sha1(raw, false)
35 }
36
37 pub fn parse_with_legacy_sha1(
38 raw: &str,
39 allow_legacy_sha1: bool,
40 ) -> Result<Self, ConnectorError> {
41 match raw.trim().to_ascii_lowercase().as_str() {
42 "sha1" | "hmac-sha1" if allow_legacy_sha1 => Ok(Self::LegacySha1),
43 "sha1" | "hmac-sha1" => Err(ConnectorError::Unsupported(
44 "HMAC-SHA1 is legacy; set `allow_legacy_sha1: true` for an existing provider"
45 .to_string(),
46 )),
47 "sha256" | "hmac-sha256" | "" => Ok(Self::Sha256),
48 other => Err(ConnectorError::Unsupported(format!(
49 "unsupported HMAC signature algorithm `{other}`"
50 ))),
51 }
52 }
53}
54
55pub fn verify_hmac_signature(
61 body: &[u8],
62 signature: &str,
63 secret: &str,
64 algorithm: HmacSignatureAlgorithm,
65) -> Result<bool, ConnectorError> {
66 let signature = signature.trim();
67 let signature = signature
68 .strip_prefix("sha256=")
69 .or_else(|| signature.strip_prefix("sha1="))
70 .unwrap_or(signature);
71 let provided = hex::decode(signature).map_err(|error| ConnectorError::InvalidHeader {
72 name: "signature".to_string(),
73 detail: error.to_string(),
74 })?;
75 let expected = match algorithm {
76 HmacSignatureAlgorithm::LegacySha1 => hmac::hmac_sha1(secret.as_bytes(), body),
77 HmacSignatureAlgorithm::Sha256 => hmac::hmac_sha256(secret.as_bytes(), body),
78 };
79 Ok(hmac::secure_eq(&expected, &provided))
80}
81
82#[derive(Clone, Debug)]
83pub enum JwtKeySource<'a> {
84 Inline(&'a JwkSet),
85 Url(&'a str),
86}
87
88#[derive(Clone, Debug)]
89pub struct JwtVerificationOptions {
90 pub issuer: Option<String>,
91 pub audience: Option<String>,
92 pub required_spec_claims: Vec<String>,
93 pub jwks_cache_ttl: StdDuration,
94 pub egress_label: &'static str,
95 pub expected_algorithm: Algorithm,
104}
105
106impl Default for JwtVerificationOptions {
107 fn default() -> Self {
108 Self {
109 issuer: None,
110 audience: None,
111 required_spec_claims: Vec::new(),
112 jwks_cache_ttl: DEFAULT_JWKS_CACHE_TTL,
113 egress_label: "connector:jwks",
114 expected_algorithm: Algorithm::RS256,
115 }
116 }
117}
118
119impl JwtVerificationOptions {
120 pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
121 self.issuer = Some(issuer.into());
122 self
123 }
124
125 pub fn with_audience(mut self, audience: impl Into<String>) -> Self {
126 self.audience = Some(audience.into());
127 self
128 }
129
130 pub fn require_spec_claims(
131 mut self,
132 claims: impl IntoIterator<Item = impl Into<String>>,
133 ) -> Self {
134 self.required_spec_claims = claims.into_iter().map(Into::into).collect();
135 self
136 }
137
138 pub fn with_egress_label(mut self, egress_label: &'static str) -> Self {
139 self.egress_label = egress_label;
140 self
141 }
142
143 pub fn with_jwks_cache_ttl(mut self, ttl: StdDuration) -> Self {
144 self.jwks_cache_ttl = ttl;
145 self
146 }
147
148 pub fn with_algorithm(mut self, algorithm: Algorithm) -> Self {
153 self.expected_algorithm = algorithm;
154 self
155 }
156}
157
158#[derive(Clone, Debug)]
159struct CachedJwks {
160 fetched_at: Instant,
161 jwks: JwkSet,
162}
163
164static JWKS_CACHE: OnceLock<RwLock<HashMap<String, CachedJwks>>> = OnceLock::new();
165
166pub async fn resolve_jwks(
167 http: &reqwest::Client,
168 source: JwtKeySource<'_>,
169 options: &JwtVerificationOptions,
170) -> Result<JwkSet, ConnectorError> {
171 match source {
172 JwtKeySource::Inline(jwks) => Ok(jwks.clone()),
173 JwtKeySource::Url(jwks_url) => fetch_cached_jwks(http, jwks_url, options).await,
174 }
175}
176
177pub async fn verify_jwt_claims<T>(
178 http: &reqwest::Client,
179 token: &str,
180 source: JwtKeySource<'_>,
181 options: &JwtVerificationOptions,
182) -> Result<T, ConnectorError>
183where
184 T: DeserializeOwned,
185{
186 let header = decode_header(token)
187 .map_err(|error| ConnectorError::invalid_signature(error.to_string()))?;
188 if header.alg != options.expected_algorithm {
195 return Err(ConnectorError::invalid_signature(format!(
196 "JWT header alg {:?} does not match expected {:?}",
197 header.alg, options.expected_algorithm
198 )));
199 }
200 let jwks = resolve_jwks(http, source.clone(), options).await?;
201 let jwks = match (source, header.kid.as_deref()) {
202 (JwtKeySource::Url(jwks_url), Some(kid)) if jwks.find(kid).is_none() => {
203 fetch_uncached_jwks(http, jwks_url, options).await?
204 }
205 _ => jwks,
206 };
207 let jwk = jwk_for_header(&jwks, header.kid.as_deref())?;
208 let key = DecodingKey::from_jwk(jwk)
209 .map_err(|error| ConnectorError::invalid_signature(error.to_string()))?;
210 let mut validation = Validation::new(options.expected_algorithm);
211 if !options.required_spec_claims.is_empty() {
212 let claims = options
213 .required_spec_claims
214 .iter()
215 .map(String::as_str)
216 .collect::<Vec<_>>();
217 validation.set_required_spec_claims(&claims);
218 }
219 if let Some(issuer) = options.issuer.as_deref() {
220 validation.set_issuer(&[issuer]);
221 }
222 if let Some(audience) = options.audience.as_deref() {
223 validation.set_audience(&[audience]);
224 }
225 decode::<T>(token, &key, &validation)
226 .map(|token| token.claims)
227 .map_err(|error| ConnectorError::invalid_signature(error.to_string()))
228}
229
230fn jwk_for_header<'a>(
231 jwks: &'a JwkSet,
232 kid: Option<&str>,
233) -> Result<&'a jsonwebtoken::jwk::Jwk, ConnectorError> {
234 match kid {
235 Some(kid) => jwks.find(kid).ok_or_else(|| {
236 ConnectorError::invalid_signature(format!("JWT kid `{kid}` was not found in JWKS"))
237 }),
238 None if jwks.keys.len() == 1 => Ok(&jwks.keys[0]),
239 None => Err(ConnectorError::invalid_signature(
240 "JWT missing kid and JWKS contains multiple keys",
241 )),
242 }
243}
244
245pub async fn verify_jwt_json(
246 http: &reqwest::Client,
247 token: &str,
248 source: JwtKeySource<'_>,
249 options: &JwtVerificationOptions,
250) -> Result<JsonValue, ConnectorError> {
251 verify_jwt_claims(http, token, source, options).await
252}
253
254async fn fetch_cached_jwks(
255 http: &reqwest::Client,
256 jwks_url: &str,
257 options: &JwtVerificationOptions,
258) -> Result<JwkSet, ConnectorError> {
259 if let Some(cached) = cached_jwks(jwks_url, options.jwks_cache_ttl) {
260 return Ok(cached);
261 }
262 fetch_uncached_jwks(http, jwks_url, options).await
263}
264
265async fn fetch_uncached_jwks(
266 http: &reqwest::Client,
267 jwks_url: &str,
268 options: &JwtVerificationOptions,
269) -> Result<JwkSet, ConnectorError> {
270 if let Some(error) = crate::egress::connector_error_for_url(options.egress_label, jwks_url) {
271 return Err(error);
272 }
273 let jwks = http
274 .get(jwks_url)
275 .send()
276 .await
277 .map_err(|error| ConnectorError::Activation(format!("fetch JWKS: {error}")))?
278 .error_for_status()
279 .map_err(|error| ConnectorError::Activation(format!("fetch JWKS: {error}")))?
280 .json::<JwkSet>()
281 .await
282 .map_err(|error| ConnectorError::Activation(format!("decode JWKS: {error}")))?;
283 store_cached_jwks(jwks_url, jwks.clone());
284 Ok(jwks)
285}
286
287fn cached_jwks(url: &str, ttl: StdDuration) -> Option<JwkSet> {
288 let cache = JWKS_CACHE.get_or_init(|| RwLock::new(HashMap::new()));
289 let cache = cache.read().expect("connector JWKS cache poisoned");
290 let cached = cache.get(url)?;
291 (cached.fetched_at.elapsed() < ttl).then(|| cached.jwks.clone())
292}
293
294fn store_cached_jwks(url: &str, jwks: JwkSet) {
295 let cache = JWKS_CACHE.get_or_init(|| RwLock::new(HashMap::new()));
296 cache
297 .write()
298 .expect("connector JWKS cache poisoned")
299 .insert(
300 url.to_string(),
301 CachedJwks {
302 fetched_at: Instant::now(),
303 jwks,
304 },
305 );
306}
307
308#[derive(Clone, Debug, PartialEq, Eq)]
309pub struct CursorPage {
310 pub items: Vec<JsonValue>,
311 pub next_cursor: Option<String>,
312 pub has_more: bool,
313}
314
315pub async fn paginate_cursor<F, Fut>(
318 initial_cursor: Option<String>,
319 max_pages: Option<usize>,
320 mut fetch: F,
321) -> Result<Vec<JsonValue>, ConnectorError>
322where
323 F: FnMut(Option<String>) -> Fut,
324 Fut: Future<Output = Result<CursorPage, ConnectorError>>,
325{
326 let mut cursor = initial_cursor;
327 let mut pages = 0usize;
328 let mut results = Vec::new();
329 loop {
330 if max_pages.is_some_and(|limit| pages >= limit) {
331 break;
332 }
333 let page = fetch(cursor.clone()).await?;
334 results.extend(page.items);
335 pages += 1;
336 if !page.has_more {
337 break;
338 }
339 cursor = page.next_cursor;
340 if cursor.as_deref().is_none_or(str::is_empty) {
341 return Err(ConnectorError::Json(
342 "cursor-paginated connector response set has_more without next_cursor".to_string(),
343 ));
344 }
345 }
346 Ok(results)
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Key id shared by the fixture JWKS and the fixture token header.
    const TEST_KID: &str = "test-key";

    /// Claims carried by the HS256 fixture token.
    #[derive(Debug, Deserialize, Serialize)]
    struct Claims {
        iss: String,
        aud: String,
        exp: i64,
        jti: String,
    }

    /// One-key JWKS holding the symmetric secret `secret` (`c2VjcmV0` is its base64url form).
    fn hs_jwks() -> JwkSet {
        let document = json!({
            "keys": [{
                "kty": "oct",
                "kid": TEST_KID,
                "alg": "HS256",
                "k": "c2VjcmV0"
            }]
        });
        serde_json::from_value(document).unwrap()
    }

    /// HS256 token signed with `secret`; `exp` is 2100-01-01T00:00:00Z.
    fn hs_token() -> String {
        let mut header = Header::new(Algorithm::HS256);
        header.kid = Some(TEST_KID.to_owned());
        let claims = Claims {
            iss: "issuer".to_owned(),
            aud: "audience".to_owned(),
            exp: 4_102_444_800,
            jti: "jwt-1".to_owned(),
        };
        encode(&header, &claims, &EncodingKey::from_secret(b"secret")).unwrap()
    }

    /// Options that pin `algorithm` plus the fixture's issuer and audience.
    fn fixture_options(algorithm: Algorithm) -> JwtVerificationOptions {
        JwtVerificationOptions::default()
            .with_algorithm(algorithm)
            .with_issuer("issuer")
            .with_audience("audience")
    }

    #[test]
    fn hmac_signature_accepts_provider_prefixed_hex() {
        // Known-answer HMAC-SHA256 vector (appears to be the example from GitHub's webhook
        // validation docs).
        let signature = "sha256=757107ea0eb2509fc211221cce984b8a37570b6d7586c22c46f4379c8b043e17";
        let verified = verify_hmac_signature(
            b"Hello, World!",
            signature,
            "It's a Secret to Everybody",
            HmacSignatureAlgorithm::Sha256,
        )
        .unwrap();
        assert!(verified);
    }

    #[test]
    fn hmac_signature_sha1_requires_explicit_legacy_algorithm() {
        let gated = HmacSignatureAlgorithm::parse("sha1").expect_err("sha1 is gated");
        assert!(gated.to_string().contains("allow_legacy_sha1"));
        let opted_in = HmacSignatureAlgorithm::parse_with_legacy_sha1("sha1", true).unwrap();
        assert_eq!(opted_in, HmacSignatureAlgorithm::LegacySha1);

        let body = b"legacy";
        let signature = format!(
            "sha1={}",
            hex::encode(hmac::hmac_sha1(b"legacy-secret", body))
        );
        let verified = verify_hmac_signature(
            body,
            &signature,
            "legacy-secret",
            HmacSignatureAlgorithm::LegacySha1,
        )
        .unwrap();
        assert!(verified);
    }

    #[tokio::test]
    async fn jwt_claims_verify_against_inline_jwks() {
        let http = reqwest::Client::new();
        let jwks = hs_jwks();
        let options = fixture_options(Algorithm::HS256).require_spec_claims(["exp", "iss", "aud"]);
        let claims: Claims =
            verify_jwt_claims(&http, &hs_token(), JwtKeySource::Inline(&jwks), &options)
                .await
                .unwrap();
        assert_eq!(claims.jti, "jwt-1");
    }

    #[tokio::test]
    async fn jwt_claims_reject_alg_confusion() {
        // An HS256 token must be refused outright when the caller expects RS256, before the
        // symmetric JWKS key could ever be used.
        let http = reqwest::Client::new();
        let jwks = hs_jwks();
        let options = fixture_options(Algorithm::RS256);
        let error =
            verify_jwt_claims::<Claims>(&http, &hs_token(), JwtKeySource::Inline(&jwks), &options)
                .await
                .expect_err("HS256 token should not verify under RS256");
        let message = error.to_string();
        assert!(
            message.contains("alg") && message.contains("expected"),
            "unexpected error: {message}"
        );
    }

    #[tokio::test]
    async fn paginate_cursor_collects_until_has_more_is_false() {
        let mut scripted = vec![
            CursorPage {
                items: vec![json!({"id": 1})],
                next_cursor: Some("b".to_owned()),
                has_more: true,
            },
            CursorPage {
                items: vec![json!({"id": 2})],
                next_cursor: None,
                has_more: false,
            },
        ]
        .into_iter();
        let collected = paginate_cursor(None, None, move |_cursor| {
            let page = scripted
                .next()
                .expect("paginate_cursor fetched more pages than scripted");
            async move { Ok(page) }
        })
        .await
        .unwrap();
        assert_eq!(collected, vec![json!({"id": 1}), json!({"id": 2})]);
    }
}