1use std::collections::HashMap;
5use std::future::Future;
6use std::sync::{OnceLock, RwLock};
7use std::time::{Duration as StdDuration, Instant};
8
9use jsonwebtoken::jwk::JwkSet;
10use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
11use serde::de::DeserializeOwned;
12use serde_json::Value as JsonValue;
13
14use super::{hmac, Connector, ConnectorError};
15
/// Default lifetime of a JWKS fetched from a URL before it is re-fetched: one day.
const DEFAULT_JWKS_CACHE_TTL: StdDuration = StdDuration::from_secs(24 * 60 * 60);
17
18pub trait ConnectorBase: Connector {}
23
24impl<T: Connector + ?Sized> ConnectorBase for T {}
25
26#[derive(Clone, Copy, Debug, PartialEq, Eq)]
27pub enum HmacSignatureAlgorithm {
28 LegacySha1,
29 Sha256,
30}
31
32impl HmacSignatureAlgorithm {
33 pub fn parse(raw: &str) -> Result<Self, ConnectorError> {
34 Self::parse_with_legacy_sha1(raw, false)
35 }
36
37 pub fn parse_with_legacy_sha1(
38 raw: &str,
39 allow_legacy_sha1: bool,
40 ) -> Result<Self, ConnectorError> {
41 match raw.trim().to_ascii_lowercase().as_str() {
42 "sha1" | "hmac-sha1" if allow_legacy_sha1 => Ok(Self::LegacySha1),
43 "sha1" | "hmac-sha1" => Err(ConnectorError::Unsupported(
44 "HMAC-SHA1 is legacy; set `allow_legacy_sha1: true` for an existing provider"
45 .to_string(),
46 )),
47 "sha256" | "hmac-sha256" | "" => Ok(Self::Sha256),
48 other => Err(ConnectorError::Unsupported(format!(
49 "unsupported HMAC signature algorithm `{other}`"
50 ))),
51 }
52 }
53}
54
55pub fn verify_hmac_signature(
61 body: &[u8],
62 signature: &str,
63 secret: &str,
64 algorithm: HmacSignatureAlgorithm,
65) -> Result<bool, ConnectorError> {
66 let signature = signature.trim();
67 let signature = signature
68 .strip_prefix("sha256=")
69 .or_else(|| signature.strip_prefix("sha1="))
70 .unwrap_or(signature);
71 let provided = hex::decode(signature).map_err(|error| ConnectorError::InvalidHeader {
72 name: "signature".to_string(),
73 detail: error.to_string(),
74 })?;
75 let expected = match algorithm {
76 HmacSignatureAlgorithm::LegacySha1 => hmac::hmac_sha1(secret.as_bytes(), body),
77 HmacSignatureAlgorithm::Sha256 => hmac::hmac_sha256(secret.as_bytes(), body),
78 };
79 Ok(hmac::secure_eq(&expected, &provided))
80}
81
82#[derive(Clone, Debug)]
83pub enum JwtKeySource<'a> {
84 Inline(&'a JwkSet),
85 Url(&'a str),
86}
87
88#[derive(Clone, Debug)]
89pub struct JwtVerificationOptions {
90 pub issuer: Option<String>,
91 pub audience: Option<String>,
92 pub required_spec_claims: Vec<String>,
93 pub jwks_cache_ttl: StdDuration,
94 pub egress_label: &'static str,
95 pub expected_algorithm: Algorithm,
104}
105
106impl Default for JwtVerificationOptions {
107 fn default() -> Self {
108 Self {
109 issuer: None,
110 audience: None,
111 required_spec_claims: Vec::new(),
112 jwks_cache_ttl: DEFAULT_JWKS_CACHE_TTL,
113 egress_label: "connector:jwks",
114 expected_algorithm: Algorithm::RS256,
115 }
116 }
117}
118
119impl JwtVerificationOptions {
120 pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
121 self.issuer = Some(issuer.into());
122 self
123 }
124
125 pub fn with_audience(mut self, audience: impl Into<String>) -> Self {
126 self.audience = Some(audience.into());
127 self
128 }
129
130 pub fn require_spec_claims(
131 mut self,
132 claims: impl IntoIterator<Item = impl Into<String>>,
133 ) -> Self {
134 self.required_spec_claims = claims.into_iter().map(Into::into).collect();
135 self
136 }
137
138 pub fn with_egress_label(mut self, egress_label: &'static str) -> Self {
139 self.egress_label = egress_label;
140 self
141 }
142
143 pub fn with_jwks_cache_ttl(mut self, ttl: StdDuration) -> Self {
144 self.jwks_cache_ttl = ttl;
145 self
146 }
147
148 pub fn with_algorithm(mut self, algorithm: Algorithm) -> Self {
153 self.expected_algorithm = algorithm;
154 self
155 }
156}
157
158#[derive(Clone, Debug)]
159struct CachedJwks {
160 fetched_at: Instant,
161 jwks: JwkSet,
162}
163
164static JWKS_CACHE: OnceLock<RwLock<HashMap<String, CachedJwks>>> = OnceLock::new();
165
166pub async fn resolve_jwks(
167 http: &reqwest::Client,
168 source: JwtKeySource<'_>,
169 options: &JwtVerificationOptions,
170) -> Result<JwkSet, ConnectorError> {
171 match source {
172 JwtKeySource::Inline(jwks) => Ok(jwks.clone()),
173 JwtKeySource::Url(jwks_url) => fetch_cached_jwks(http, jwks_url, options).await,
174 }
175}
176
177pub async fn verify_jwt_claims<T>(
178 http: &reqwest::Client,
179 token: &str,
180 source: JwtKeySource<'_>,
181 options: &JwtVerificationOptions,
182) -> Result<T, ConnectorError>
183where
184 T: DeserializeOwned,
185{
186 let header = decode_header(token)
187 .map_err(|error| ConnectorError::invalid_signature(error.to_string()))?;
188 if header.alg != options.expected_algorithm {
195 return Err(ConnectorError::invalid_signature(format!(
196 "JWT header alg {:?} does not match expected {:?}",
197 header.alg, options.expected_algorithm
198 )));
199 }
200 let jwks = resolve_jwks(http, source.clone(), options).await?;
201 let jwks = match (source, header.kid.as_deref()) {
202 (JwtKeySource::Url(jwks_url), Some(kid)) if jwks.find(kid).is_none() => {
203 fetch_uncached_jwks(http, jwks_url, options).await?
204 }
205 _ => jwks,
206 };
207 let jwk = jwk_for_header(&jwks, header.kid.as_deref())?;
208 let key = DecodingKey::from_jwk(jwk)
209 .map_err(|error| ConnectorError::invalid_signature(error.to_string()))?;
210 let mut validation = Validation::new(options.expected_algorithm);
211 if !options.required_spec_claims.is_empty() {
212 let claims = options
213 .required_spec_claims
214 .iter()
215 .map(String::as_str)
216 .collect::<Vec<_>>();
217 validation.set_required_spec_claims(&claims);
218 }
219 if let Some(issuer) = options.issuer.as_deref() {
220 validation.set_issuer(&[issuer]);
221 }
222 if let Some(audience) = options.audience.as_deref() {
223 validation.set_audience(&[audience]);
224 }
225 decode::<T>(token, &key, &validation)
226 .map(|token| token.claims)
227 .map_err(|error| ConnectorError::invalid_signature(error.to_string()))
228}
229
230fn jwk_for_header<'a>(
231 jwks: &'a JwkSet,
232 kid: Option<&str>,
233) -> Result<&'a jsonwebtoken::jwk::Jwk, ConnectorError> {
234 match kid {
235 Some(kid) => jwks.find(kid).ok_or_else(|| {
236 ConnectorError::invalid_signature(format!("JWT kid `{kid}` was not found in JWKS"))
237 }),
238 None if jwks.keys.len() == 1 => Ok(&jwks.keys[0]),
239 None => Err(ConnectorError::invalid_signature(
240 "JWT missing kid and JWKS contains multiple keys",
241 )),
242 }
243}
244
245pub async fn verify_jwt_json(
246 http: &reqwest::Client,
247 token: &str,
248 source: JwtKeySource<'_>,
249 options: &JwtVerificationOptions,
250) -> Result<JsonValue, ConnectorError> {
251 verify_jwt_claims(http, token, source, options).await
252}
253
254async fn fetch_cached_jwks(
255 http: &reqwest::Client,
256 jwks_url: &str,
257 options: &JwtVerificationOptions,
258) -> Result<JwkSet, ConnectorError> {
259 if let Some(cached) = cached_jwks(jwks_url, options.jwks_cache_ttl) {
260 return Ok(cached);
261 }
262 fetch_uncached_jwks(http, jwks_url, options).await
263}
264
265async fn fetch_uncached_jwks(
266 http: &reqwest::Client,
267 jwks_url: &str,
268 options: &JwtVerificationOptions,
269) -> Result<JwkSet, ConnectorError> {
270 if let Some(error) = crate::egress::connector_error_for_url(options.egress_label, jwks_url) {
271 return Err(error);
272 }
273 let jwks = http
274 .get(jwks_url)
275 .send()
276 .await
277 .map_err(|error| {
278 ConnectorError::Activation(format!(
279 "fetch JWKS: {}",
280 crate::egress::redact_reqwest_error(&error)
281 ))
282 })?
283 .error_for_status()
284 .map_err(|error| {
285 ConnectorError::Activation(format!(
286 "fetch JWKS: {}",
287 crate::egress::redact_reqwest_error(&error)
288 ))
289 })?
290 .json::<JwkSet>()
291 .await
292 .map_err(|error| ConnectorError::Activation(format!("decode JWKS: {error}")))?;
293 store_cached_jwks(jwks_url, jwks.clone());
294 Ok(jwks)
295}
296
297fn cached_jwks(url: &str, ttl: StdDuration) -> Option<JwkSet> {
298 let cache = JWKS_CACHE.get_or_init(|| RwLock::new(HashMap::new()));
299 let cache = cache.read().expect("connector JWKS cache poisoned");
300 let cached = cache.get(url)?;
301 (cached.fetched_at.elapsed() < ttl).then(|| cached.jwks.clone())
302}
303
304fn store_cached_jwks(url: &str, jwks: JwkSet) {
305 let cache = JWKS_CACHE.get_or_init(|| RwLock::new(HashMap::new()));
306 cache
307 .write()
308 .expect("connector JWKS cache poisoned")
309 .insert(
310 url.to_string(),
311 CachedJwks {
312 fetched_at: Instant::now(),
313 jwks,
314 },
315 );
316}
317
318#[derive(Clone, Debug, PartialEq, Eq)]
319pub struct CursorPage {
320 pub items: Vec<JsonValue>,
321 pub next_cursor: Option<String>,
322 pub has_more: bool,
323}
324
325pub async fn paginate_cursor<F, Fut>(
328 initial_cursor: Option<String>,
329 max_pages: Option<usize>,
330 mut fetch: F,
331) -> Result<Vec<JsonValue>, ConnectorError>
332where
333 F: FnMut(Option<String>) -> Fut,
334 Fut: Future<Output = Result<CursorPage, ConnectorError>>,
335{
336 let mut cursor = initial_cursor;
337 let mut pages = 0usize;
338 let mut results = Vec::new();
339 loop {
340 if max_pages.is_some_and(|limit| pages >= limit) {
341 break;
342 }
343 let page = fetch(cursor.clone()).await?;
344 results.extend(page.items);
345 pages += 1;
346 if !page.has_more {
347 break;
348 }
349 cursor = page.next_cursor;
350 if cursor.as_deref().is_none_or(str::is_empty) {
351 return Err(ConnectorError::Json(
352 "cursor-paginated connector response set has_more without next_cursor".to_string(),
353 ));
354 }
355 }
356 Ok(results)
357}
358
#[cfg(test)]
mod tests {
    //! Unit tests for the shared connector helpers. They cover HMAC parsing and
    //! verification, JWT verification against inline JWKS, and cursor pagination.

    use super::*;
    use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    #[derive(Debug, Deserialize, Serialize)]
    struct Claims {
        iss: String,
        aud: String,
        exp: i64,
        jti: String,
    }

    /// One-key JWKS holding the HS256 secret `secret` (base64url `c2VjcmV0`) as `test-key`.
    fn hs_jwks() -> JwkSet {
        serde_json::from_value(json!({
            "keys": [{
                "kty": "oct",
                "kid": "test-key",
                "alg": "HS256",
                "k": "c2VjcmV0"
            }]
        }))
        .unwrap()
    }

    /// HS256 token signed with `secret`, expiring 2100-01-01T00:00:00Z, with header `kid`.
    fn hs_token_with_kid(kid: Option<&str>) -> String {
        let mut header = Header::new(Algorithm::HS256);
        header.kid = kid.map(str::to_string);
        encode(
            &header,
            &Claims {
                iss: "issuer".to_string(),
                aud: "audience".to_string(),
                exp: 4_102_444_800,
                jti: "jwt-1".to_string(),
            },
            &EncodingKey::from_secret(b"secret"),
        )
        .unwrap()
    }

    /// HS256 token whose `kid` matches the key in `hs_jwks`.
    fn hs_token() -> String {
        hs_token_with_kid(Some("test-key"))
    }

    /// Options that accept the tokens produced by `hs_token_with_kid`.
    fn hs_options() -> JwtVerificationOptions {
        JwtVerificationOptions::default()
            .with_algorithm(Algorithm::HS256)
            .with_issuer("issuer")
            .with_audience("audience")
    }

    #[test]
    fn hmac_signature_accepts_provider_prefixed_hex() {
        // Example payload, secret and signature from GitHub's webhook documentation.
        let body = b"Hello, World!";
        let signature = "sha256=757107ea0eb2509fc211221cce984b8a37570b6d7586c22c46f4379c8b043e17";
        assert!(verify_hmac_signature(
            body,
            signature,
            "It's a Secret to Everybody",
            HmacSignatureAlgorithm::Sha256,
        )
        .unwrap());
    }

    #[test]
    fn hmac_signature_rejects_tampered_body() {
        let signature = "sha256=757107ea0eb2509fc211221cce984b8a37570b6d7586c22c46f4379c8b043e17";
        assert!(!verify_hmac_signature(
            b"Hello, World?",
            signature,
            "It's a Secret to Everybody",
            HmacSignatureAlgorithm::Sha256,
        )
        .unwrap());
    }

    #[test]
    fn hmac_signature_rejects_non_hex() {
        assert!(verify_hmac_signature(
            b"body",
            "sha256=not-hex",
            "secret",
            HmacSignatureAlgorithm::Sha256,
        )
        .is_err());
    }

    #[test]
    fn hmac_algorithm_parse_normalizes_and_defaults() {
        assert_eq!(
            HmacSignatureAlgorithm::parse("  HMAC-SHA256 ").unwrap(),
            HmacSignatureAlgorithm::Sha256
        );
        assert_eq!(
            HmacSignatureAlgorithm::parse("").unwrap(),
            HmacSignatureAlgorithm::Sha256
        );
        assert!(HmacSignatureAlgorithm::parse("md5").is_err());
    }

    #[test]
    fn hmac_signature_sha1_requires_explicit_legacy_algorithm() {
        let parse_error = HmacSignatureAlgorithm::parse("sha1").expect_err("sha1 is gated");
        assert!(parse_error.to_string().contains("allow_legacy_sha1"));
        assert_eq!(
            HmacSignatureAlgorithm::parse_with_legacy_sha1("sha1", true).unwrap(),
            HmacSignatureAlgorithm::LegacySha1
        );

        let body = b"legacy";
        let digest = hmac::hmac_sha1(b"legacy-secret", body);
        let signature = format!("sha1={}", hex::encode(digest));
        assert!(verify_hmac_signature(
            body,
            &signature,
            "legacy-secret",
            HmacSignatureAlgorithm::LegacySha1,
        )
        .unwrap());
    }

    #[tokio::test]
    async fn jwt_claims_verify_against_inline_jwks() {
        let http = reqwest::Client::new();
        let claims: Claims = verify_jwt_claims(
            &http,
            &hs_token(),
            JwtKeySource::Inline(&hs_jwks()),
            &JwtVerificationOptions::default()
                .with_algorithm(Algorithm::HS256)
                .with_issuer("issuer")
                .with_audience("audience")
                .require_spec_claims(["exp", "iss", "aud"]),
        )
        .await
        .unwrap();
        assert_eq!(claims.jti, "jwt-1");
    }

    #[tokio::test]
    async fn jwt_claims_accept_missing_kid_with_single_key() {
        let http = reqwest::Client::new();
        let claims: Claims = verify_jwt_claims(
            &http,
            &hs_token_with_kid(None),
            JwtKeySource::Inline(&hs_jwks()),
            &hs_options(),
        )
        .await
        .unwrap();
        assert_eq!(claims.jti, "jwt-1");
    }

    #[tokio::test]
    async fn jwt_claims_reject_alg_confusion() {
        let http = reqwest::Client::new();
        // An HS256 token must not be accepted when the verifier is pinned to RS256, even
        // though the inline JWKS holds a matching symmetric key.
        let result = verify_jwt_claims::<Claims>(
            &http,
            &hs_token(),
            JwtKeySource::Inline(&hs_jwks()),
            &JwtVerificationOptions::default()
                .with_algorithm(Algorithm::RS256)
                .with_issuer("issuer")
                .with_audience("audience"),
        )
        .await;
        let error = result.expect_err("HS256 token should not verify under RS256");
        let message = error.to_string();
        assert!(
            message.contains("alg") && message.contains("expected"),
            "unexpected error: {message}"
        );
    }

    #[tokio::test]
    async fn jwt_claims_reject_wrong_audience() {
        let http = reqwest::Client::new();
        let result = verify_jwt_claims::<Claims>(
            &http,
            &hs_token(),
            JwtKeySource::Inline(&hs_jwks()),
            &JwtVerificationOptions::default()
                .with_algorithm(Algorithm::HS256)
                .with_issuer("issuer")
                .with_audience("someone-else"),
        )
        .await;
        assert!(result.is_err(), "token for another audience must be rejected");
    }

    #[tokio::test]
    async fn jwt_claims_reject_unknown_kid() {
        let http = reqwest::Client::new();
        let result = verify_jwt_claims::<Claims>(
            &http,
            &hs_token_with_kid(Some("rotated-key")),
            JwtKeySource::Inline(&hs_jwks()),
            &hs_options(),
        )
        .await;
        let message = result
            .expect_err("unknown kid must be rejected")
            .to_string();
        assert!(message.contains("rotated-key"), "unexpected error: {message}");
    }

    #[test]
    fn jwk_for_header_requires_kid_when_jwks_has_multiple_keys() {
        let jwks: JwkSet = serde_json::from_value(json!({
            "keys": [
                {"kty": "oct", "kid": "a", "alg": "HS256", "k": "c2VjcmV0"},
                {"kty": "oct", "kid": "b", "alg": "HS256", "k": "c2VjcmV0"}
            ]
        }))
        .unwrap();
        assert!(jwk_for_header(&jwks, None).is_err());
        assert!(jwk_for_header(&jwks, Some("b")).is_ok());
        assert!(jwk_for_header(&hs_jwks(), None).is_ok());
    }

    #[tokio::test]
    async fn paginate_cursor_collects_until_has_more_is_false() {
        let pages = [
            CursorPage {
                items: vec![json!({"id": 1})],
                next_cursor: Some("b".to_string()),
                has_more: true,
            },
            CursorPage {
                items: vec![json!({"id": 2})],
                next_cursor: None,
                has_more: false,
            },
        ];
        let mut index = 0usize;
        let results = paginate_cursor(None, None, |_cursor| {
            let page = pages[index].clone();
            index += 1;
            async move { Ok(page) }
        })
        .await
        .unwrap();
        assert_eq!(results, vec![json!({"id": 1}), json!({"id": 2})]);
    }

    #[tokio::test]
    async fn paginate_cursor_stops_at_max_pages_and_forwards_cursors() {
        let mut requested = Vec::new();
        let results = paginate_cursor(Some("a".to_string()), Some(2), |cursor: Option<String>| {
            let next = format!("{}+", cursor.as_deref().unwrap_or_default());
            requested.push(cursor);
            async move {
                Ok(CursorPage {
                    items: vec![json!({ "next": next.clone() })],
                    next_cursor: Some(next),
                    has_more: true,
                })
            }
        })
        .await
        .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(
            requested,
            vec![Some("a".to_string()), Some("a+".to_string())]
        );
    }

    #[tokio::test]
    async fn paginate_cursor_rejects_has_more_without_next_cursor() {
        for next_cursor in [None, Some(String::new())] {
            let result = paginate_cursor(None, None, |_cursor| {
                let next_cursor = next_cursor.clone();
                async move {
                    Ok(CursorPage {
                        items: Vec::new(),
                        next_cursor,
                        has_more: true,
                    })
                }
            })
            .await;
            assert!(result.is_err(), "has_more with {next_cursor:?} must be rejected");
        }
    }
}
504}