1use std::collections::HashMap;
5use std::future::Future;
6use std::sync::{OnceLock, RwLock};
7use std::time::{Duration as StdDuration, Instant};
8
9use jsonwebtoken::jwk::JwkSet;
10use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
11use serde::de::DeserializeOwned;
12use serde_json::Value as JsonValue;
13
14use super::{hmac, Connector, ConnectorError};
15
/// Default maximum age of a cached JWKS entry: 24 hours, expressed in seconds.
const DEFAULT_JWKS_CACHE_TTL: StdDuration = StdDuration::from_secs(24 * 60 * 60);
17
18pub trait ConnectorBase: Connector {}
23
24impl<T: Connector + ?Sized> ConnectorBase for T {}
25
/// Digest used when computing an HMAC for signature verification.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HmacSignatureAlgorithm {
    /// HMAC computed with SHA-1.
    Sha1,
    /// HMAC computed with SHA-256.
    Sha256,
}
31
32impl HmacSignatureAlgorithm {
33 pub fn parse(raw: &str) -> Result<Self, ConnectorError> {
34 match raw.trim().to_ascii_lowercase().as_str() {
35 "sha1" | "hmac-sha1" => Ok(Self::Sha1),
36 "sha256" | "hmac-sha256" | "" => Ok(Self::Sha256),
37 other => Err(ConnectorError::Unsupported(format!(
38 "unsupported HMAC signature algorithm `{other}`"
39 ))),
40 }
41 }
42}
43
44pub fn verify_hmac_signature(
50 body: &[u8],
51 signature: &str,
52 secret: &str,
53 algorithm: HmacSignatureAlgorithm,
54) -> Result<bool, ConnectorError> {
55 let signature = signature.trim();
56 let signature = signature
57 .strip_prefix("sha256=")
58 .or_else(|| signature.strip_prefix("sha1="))
59 .unwrap_or(signature);
60 let provided = hex::decode(signature).map_err(|error| ConnectorError::InvalidHeader {
61 name: "signature".to_string(),
62 detail: error.to_string(),
63 })?;
64 let expected = match algorithm {
65 HmacSignatureAlgorithm::Sha1 => hmac::hmac_sha1(secret.as_bytes(), body),
66 HmacSignatureAlgorithm::Sha256 => hmac::hmac_sha256(secret.as_bytes(), body),
67 };
68 Ok(hmac::secure_eq(&expected, &provided))
69}
70
/// Where the JWK set used to verify a JWT comes from.
#[derive(Clone, Debug)]
pub enum JwtKeySource<'a> {
    /// A caller-supplied key set; it is cloned and used as-is, with no fetch.
    Inline(&'a JwkSet),
    /// URL of a JWKS document; the set is fetched over HTTP and cached per URL.
    Url(&'a str),
}
76
/// Tunables for JWT verification and JWKS retrieval.
#[derive(Clone, Debug)]
pub struct JwtVerificationOptions {
    /// When set, passed to the validator as the single accepted issuer.
    pub issuer: Option<String>,
    /// When set, passed to the validator as the single accepted audience.
    pub audience: Option<String>,
    /// Claims the validator must require. When empty, the validator's own
    /// default required-claim set is left untouched.
    pub required_spec_claims: Vec<String>,
    /// Maximum age of a cached URL-sourced JWKS before it is fetched again.
    pub jwks_cache_ttl: StdDuration,
    /// Label handed to the egress check that runs before a JWKS URL is fetched.
    pub egress_label: &'static str,
}
85
86impl Default for JwtVerificationOptions {
87 fn default() -> Self {
88 Self {
89 issuer: None,
90 audience: None,
91 required_spec_claims: Vec::new(),
92 jwks_cache_ttl: DEFAULT_JWKS_CACHE_TTL,
93 egress_label: "connector:jwks",
94 }
95 }
96}
97
98impl JwtVerificationOptions {
99 pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
100 self.issuer = Some(issuer.into());
101 self
102 }
103
104 pub fn with_audience(mut self, audience: impl Into<String>) -> Self {
105 self.audience = Some(audience.into());
106 self
107 }
108
109 pub fn require_spec_claims(
110 mut self,
111 claims: impl IntoIterator<Item = impl Into<String>>,
112 ) -> Self {
113 self.required_spec_claims = claims.into_iter().map(Into::into).collect();
114 self
115 }
116
117 pub fn with_egress_label(mut self, egress_label: &'static str) -> Self {
118 self.egress_label = egress_label;
119 self
120 }
121
122 pub fn with_jwks_cache_ttl(mut self, ttl: StdDuration) -> Self {
123 self.jwks_cache_ttl = ttl;
124 self
125 }
126}
127
/// A JWKS document together with the moment it was stored in the cache.
#[derive(Clone, Debug)]
struct CachedJwks {
    /// Monotonic timestamp taken when the entry was inserted; compared
    /// against the caller's TTL on lookup.
    fetched_at: Instant,
    /// The key set that was fetched.
    jwks: JwkSet,
}
133
/// Process-wide JWKS cache keyed by JWKS URL, created lazily on first use.
static JWKS_CACHE: OnceLock<RwLock<HashMap<String, CachedJwks>>> = OnceLock::new();
135
136pub async fn resolve_jwks(
137 http: &reqwest::Client,
138 source: JwtKeySource<'_>,
139 options: &JwtVerificationOptions,
140) -> Result<JwkSet, ConnectorError> {
141 match source {
142 JwtKeySource::Inline(jwks) => Ok(jwks.clone()),
143 JwtKeySource::Url(jwks_url) => fetch_cached_jwks(http, jwks_url, options).await,
144 }
145}
146
147pub async fn verify_jwt_claims<T>(
148 http: &reqwest::Client,
149 token: &str,
150 source: JwtKeySource<'_>,
151 options: &JwtVerificationOptions,
152) -> Result<T, ConnectorError>
153where
154 T: DeserializeOwned,
155{
156 let header = decode_header(token)
157 .map_err(|error| ConnectorError::invalid_signature(error.to_string()))?;
158 let jwks = resolve_jwks(http, source.clone(), options).await?;
159 let jwks = match (source, header.kid.as_deref()) {
160 (JwtKeySource::Url(jwks_url), Some(kid)) if jwks.find(kid).is_none() => {
161 fetch_uncached_jwks(http, jwks_url, options).await?
162 }
163 _ => jwks,
164 };
165 let jwk = jwk_for_header(&jwks, header.kid.as_deref())?;
166 let key = DecodingKey::from_jwk(jwk)
167 .map_err(|error| ConnectorError::invalid_signature(error.to_string()))?;
168 let mut validation = Validation::new(header.alg);
169 if !options.required_spec_claims.is_empty() {
170 let claims = options
171 .required_spec_claims
172 .iter()
173 .map(String::as_str)
174 .collect::<Vec<_>>();
175 validation.set_required_spec_claims(&claims);
176 }
177 if let Some(issuer) = options.issuer.as_deref() {
178 validation.set_issuer(&[issuer]);
179 }
180 if let Some(audience) = options.audience.as_deref() {
181 validation.set_audience(&[audience]);
182 }
183 decode::<T>(token, &key, &validation)
184 .map(|token| token.claims)
185 .map_err(|error| ConnectorError::invalid_signature(error.to_string()))
186}
187
188fn jwk_for_header<'a>(
189 jwks: &'a JwkSet,
190 kid: Option<&str>,
191) -> Result<&'a jsonwebtoken::jwk::Jwk, ConnectorError> {
192 match kid {
193 Some(kid) => jwks.find(kid).ok_or_else(|| {
194 ConnectorError::invalid_signature(format!("JWT kid `{kid}` was not found in JWKS"))
195 }),
196 None if jwks.keys.len() == 1 => Ok(&jwks.keys[0]),
197 None => Err(ConnectorError::invalid_signature(
198 "JWT missing kid and JWKS contains multiple keys",
199 )),
200 }
201}
202
203pub async fn verify_jwt_json(
204 http: &reqwest::Client,
205 token: &str,
206 source: JwtKeySource<'_>,
207 options: &JwtVerificationOptions,
208) -> Result<JsonValue, ConnectorError> {
209 verify_jwt_claims(http, token, source, options).await
210}
211
212async fn fetch_cached_jwks(
213 http: &reqwest::Client,
214 jwks_url: &str,
215 options: &JwtVerificationOptions,
216) -> Result<JwkSet, ConnectorError> {
217 if let Some(cached) = cached_jwks(jwks_url, options.jwks_cache_ttl) {
218 return Ok(cached);
219 }
220 fetch_uncached_jwks(http, jwks_url, options).await
221}
222
223async fn fetch_uncached_jwks(
224 http: &reqwest::Client,
225 jwks_url: &str,
226 options: &JwtVerificationOptions,
227) -> Result<JwkSet, ConnectorError> {
228 if let Some(error) = crate::egress::connector_error_for_url(options.egress_label, jwks_url) {
229 return Err(error);
230 }
231 let jwks = http
232 .get(jwks_url)
233 .send()
234 .await
235 .map_err(|error| ConnectorError::Activation(format!("fetch JWKS: {error}")))?
236 .error_for_status()
237 .map_err(|error| ConnectorError::Activation(format!("fetch JWKS: {error}")))?
238 .json::<JwkSet>()
239 .await
240 .map_err(|error| ConnectorError::Activation(format!("decode JWKS: {error}")))?;
241 store_cached_jwks(jwks_url, jwks.clone());
242 Ok(jwks)
243}
244
245fn cached_jwks(url: &str, ttl: StdDuration) -> Option<JwkSet> {
246 let cache = JWKS_CACHE.get_or_init(|| RwLock::new(HashMap::new()));
247 let cache = cache.read().expect("connector JWKS cache poisoned");
248 let cached = cache.get(url)?;
249 (cached.fetched_at.elapsed() < ttl).then(|| cached.jwks.clone())
250}
251
252fn store_cached_jwks(url: &str, jwks: JwkSet) {
253 let cache = JWKS_CACHE.get_or_init(|| RwLock::new(HashMap::new()));
254 cache
255 .write()
256 .expect("connector JWKS cache poisoned")
257 .insert(
258 url.to_string(),
259 CachedJwks {
260 fetched_at: Instant::now(),
261 jwks,
262 },
263 );
264}
265
/// One page of results returned by a cursor-paginated fetch.
#[derive(Clone, Debug, PartialEq)]
pub struct CursorPage {
    /// Items contained in this page.
    pub items: Vec<JsonValue>,
    /// Cursor to pass to the next fetch; must be non-empty when `has_more`
    /// is `true`.
    pub next_cursor: Option<String>,
    /// Whether further pages should be requested after this one.
    pub has_more: bool,
}
272
273pub async fn paginate_cursor<F, Fut>(
276 initial_cursor: Option<String>,
277 max_pages: Option<usize>,
278 mut fetch: F,
279) -> Result<Vec<JsonValue>, ConnectorError>
280where
281 F: FnMut(Option<String>) -> Fut,
282 Fut: Future<Output = Result<CursorPage, ConnectorError>>,
283{
284 let mut cursor = initial_cursor;
285 let mut pages = 0usize;
286 let mut results = Vec::new();
287 loop {
288 if max_pages.is_some_and(|limit| pages >= limit) {
289 break;
290 }
291 let page = fetch(cursor.clone()).await?;
292 results.extend(page.items);
293 pages += 1;
294 if !page.has_more {
295 break;
296 }
297 cursor = page.next_cursor;
298 if cursor.as_deref().is_none_or(str::is_empty) {
299 return Err(ConnectorError::Json(
300 "cursor-paginated connector response set has_more without next_cursor".to_string(),
301 ));
302 }
303 }
304 Ok(results)
305}
306
// Unit tests for the HMAC, JWT and pagination helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Claim set used by the JWT round-trip test.
    #[derive(Debug, Deserialize, Serialize)]
    struct Claims {
        iss: String,
        aud: String,
        exp: i64,
        jti: String,
    }

    /// Builds a single-key JWKS holding a symmetric key with id `test-key`.
    fn hs_jwks() -> JwkSet {
        serde_json::from_value(json!({
            "keys": [{
                "kty": "oct",
                "kid": "test-key",
                "alg": "HS256",
                // Base64 encoding of the bytes `secret`.
                "k": "c2VjcmV0"
            }]
        }))
        .unwrap()
    }

    /// Signs an HS256 token with `kid = test-key` using the secret `secret`.
    fn hs_token() -> String {
        let mut header = Header::new(Algorithm::HS256);
        header.kid = Some("test-key".to_string());
        encode(
            &header,
            &Claims {
                iss: "issuer".to_string(),
                aud: "audience".to_string(),
                // Far-future expiry (2100-01-01 UTC) so the token stays valid.
                exp: 4_102_444_800,
                jti: "jwt-1".to_string(),
            },
            &EncodingKey::from_secret(b"secret"),
        )
        .unwrap()
    }

    /// A `sha256=`-prefixed hex signature is accepted and verifies.
    #[test]
    fn hmac_signature_accepts_provider_prefixed_hex() {
        let body = b"Hello, World!";
        let signature = "sha256=757107ea0eb2509fc211221cce984b8a37570b6d7586c22c46f4379c8b043e17";
        assert!(verify_hmac_signature(
            body,
            signature,
            "It's a Secret to Everybody",
            HmacSignatureAlgorithm::Sha256,
        )
        .unwrap());
    }

    /// A token signed with the inline key verifies and yields its claims.
    #[tokio::test]
    async fn jwt_claims_verify_against_inline_jwks() {
        // The client is required by the signature; an inline source never
        // triggers a fetch.
        let http = reqwest::Client::new();
        let claims: Claims = verify_jwt_claims(
            &http,
            &hs_token(),
            JwtKeySource::Inline(&hs_jwks()),
            &JwtVerificationOptions::default()
                .with_issuer("issuer")
                .with_audience("audience")
                .require_spec_claims(["exp", "iss", "aud"]),
        )
        .await
        .unwrap();
        assert_eq!(claims.jti, "jwt-1");
    }

    /// Pagination concatenates pages and stops at the first `has_more == false`.
    #[tokio::test]
    async fn paginate_cursor_collects_until_has_more_is_false() {
        let pages = [
            CursorPage {
                items: vec![json!({"id": 1})],
                next_cursor: Some("b".to_string()),
                has_more: true,
            },
            CursorPage {
                items: vec![json!({"id": 2})],
                next_cursor: None,
                has_more: false,
            },
        ];
        // The fake fetcher ignores the cursor and serves pages in order.
        let mut index = 0usize;
        let results = paginate_cursor(None, None, |_cursor| {
            let page = pages[index].clone();
            index += 1;
            async move { Ok(page) }
        })
        .await
        .unwrap();
        assert_eq!(results, vec![json!({"id": 1}), json!({"id": 2})]);
    }
}