Skip to main content

harn_vm/connectors/
shared.rs

1//! Shared connector building blocks that are useful to both Rust shims and
2//! Harn-authored connector packages.
3
4use std::collections::HashMap;
5use std::future::Future;
6use std::sync::{OnceLock, RwLock};
7use std::time::{Duration as StdDuration, Instant};
8
9use jsonwebtoken::jwk::JwkSet;
10use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
11use serde::de::DeserializeOwned;
12use serde_json::Value as JsonValue;
13
14use super::{hmac, Connector, ConnectorError};
15
/// How long a JWKS fetched from a URL is served from the cache before it is
/// fetched again (one day).
const DEFAULT_JWKS_CACHE_TTL: StdDuration = StdDuration::from_secs(24 * 60 * 60);
17
18/// Base connector contract name for shared runtime code.
19///
20/// This stays as a blanket extension over `Connector` so there is one
21/// object-safe implementation contract for registry and adapter code.
22pub trait ConnectorBase: Connector {}
23
24impl<T: Connector + ?Sized> ConnectorBase for T {}
25
26#[derive(Clone, Copy, Debug, PartialEq, Eq)]
27pub enum HmacSignatureAlgorithm {
28    LegacySha1,
29    Sha256,
30}
31
32impl HmacSignatureAlgorithm {
33    pub fn parse(raw: &str) -> Result<Self, ConnectorError> {
34        Self::parse_with_legacy_sha1(raw, false)
35    }
36
37    pub fn parse_with_legacy_sha1(
38        raw: &str,
39        allow_legacy_sha1: bool,
40    ) -> Result<Self, ConnectorError> {
41        match raw.trim().to_ascii_lowercase().as_str() {
42            "sha1" | "hmac-sha1" if allow_legacy_sha1 => Ok(Self::LegacySha1),
43            "sha1" | "hmac-sha1" => Err(ConnectorError::Unsupported(
44                "HMAC-SHA1 is legacy; set `allow_legacy_sha1: true` for an existing provider"
45                    .to_string(),
46            )),
47            "sha256" | "hmac-sha256" | "" => Ok(Self::Sha256),
48            other => Err(ConnectorError::Unsupported(format!(
49                "unsupported HMAC signature algorithm `{other}`"
50            ))),
51        }
52    }
53}
54
55/// Verify a raw HMAC signature value using constant-time comparison.
56///
57/// `signature` may be a bare hex digest or a provider-style `sha256=<hex>` /
58/// `sha1=<hex>` value. Provider-specific timestamp and canonical-message
59/// checks belong in `hmac::verify_hmac_signed`.
60pub fn verify_hmac_signature(
61    body: &[u8],
62    signature: &str,
63    secret: &str,
64    algorithm: HmacSignatureAlgorithm,
65) -> Result<bool, ConnectorError> {
66    let signature = signature.trim();
67    let signature = signature
68        .strip_prefix("sha256=")
69        .or_else(|| signature.strip_prefix("sha1="))
70        .unwrap_or(signature);
71    let provided = hex::decode(signature).map_err(|error| ConnectorError::InvalidHeader {
72        name: "signature".to_string(),
73        detail: error.to_string(),
74    })?;
75    let expected = match algorithm {
76        HmacSignatureAlgorithm::LegacySha1 => hmac::hmac_sha1(secret.as_bytes(), body),
77        HmacSignatureAlgorithm::Sha256 => hmac::hmac_sha256(secret.as_bytes(), body),
78    };
79    Ok(hmac::secure_eq(&expected, &provided))
80}
81
82#[derive(Clone, Debug)]
83pub enum JwtKeySource<'a> {
84    Inline(&'a JwkSet),
85    Url(&'a str),
86}
87
88#[derive(Clone, Debug)]
89pub struct JwtVerificationOptions {
90    pub issuer: Option<String>,
91    pub audience: Option<String>,
92    pub required_spec_claims: Vec<String>,
93    pub jwks_cache_ttl: StdDuration,
94    pub egress_label: &'static str,
95    /// Signing algorithm the caller expects this token to be signed
96    /// with. The verifier asserts `header.alg == expected_algorithm`
97    /// BEFORE constructing the `Validation` so an attacker-controlled
98    /// `alg` header cannot down/up-grade the verification algorithm
99    /// (the canonical "alg confusion" footgun). Defaults to RS256
100    /// (`Algorithm::RS256`) — change it explicitly via
101    /// [`JwtVerificationOptions::with_algorithm`] when a connector
102    /// uses HS256 / ES256 / etc.
103    pub expected_algorithm: Algorithm,
104}
105
106impl Default for JwtVerificationOptions {
107    fn default() -> Self {
108        Self {
109            issuer: None,
110            audience: None,
111            required_spec_claims: Vec::new(),
112            jwks_cache_ttl: DEFAULT_JWKS_CACHE_TTL,
113            egress_label: "connector:jwks",
114            expected_algorithm: Algorithm::RS256,
115        }
116    }
117}
118
119impl JwtVerificationOptions {
120    pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
121        self.issuer = Some(issuer.into());
122        self
123    }
124
125    pub fn with_audience(mut self, audience: impl Into<String>) -> Self {
126        self.audience = Some(audience.into());
127        self
128    }
129
130    pub fn require_spec_claims(
131        mut self,
132        claims: impl IntoIterator<Item = impl Into<String>>,
133    ) -> Self {
134        self.required_spec_claims = claims.into_iter().map(Into::into).collect();
135        self
136    }
137
138    pub fn with_egress_label(mut self, egress_label: &'static str) -> Self {
139        self.egress_label = egress_label;
140        self
141    }
142
143    pub fn with_jwks_cache_ttl(mut self, ttl: StdDuration) -> Self {
144        self.jwks_cache_ttl = ttl;
145        self
146    }
147
148    /// Override the expected signing algorithm. Use this when the
149    /// connector you're verifying is known to sign with something
150    /// other than RS256 (e.g. HS256 for Slack-style shared secrets,
151    /// ES256 for some Apple flows).
152    pub fn with_algorithm(mut self, algorithm: Algorithm) -> Self {
153        self.expected_algorithm = algorithm;
154        self
155    }
156}
157
158#[derive(Clone, Debug)]
159struct CachedJwks {
160    fetched_at: Instant,
161    jwks: JwkSet,
162}
163
164static JWKS_CACHE: OnceLock<RwLock<HashMap<String, CachedJwks>>> = OnceLock::new();
165
166pub async fn resolve_jwks(
167    http: &reqwest::Client,
168    source: JwtKeySource<'_>,
169    options: &JwtVerificationOptions,
170) -> Result<JwkSet, ConnectorError> {
171    match source {
172        JwtKeySource::Inline(jwks) => Ok(jwks.clone()),
173        JwtKeySource::Url(jwks_url) => fetch_cached_jwks(http, jwks_url, options).await,
174    }
175}
176
177pub async fn verify_jwt_claims<T>(
178    http: &reqwest::Client,
179    token: &str,
180    source: JwtKeySource<'_>,
181    options: &JwtVerificationOptions,
182) -> Result<T, ConnectorError>
183where
184    T: DeserializeOwned,
185{
186    let header = decode_header(token)
187        .map_err(|error| ConnectorError::invalid_signature(error.to_string()))?;
188    // Defense-in-depth against JWT "alg confusion": refuse to accept a
189    // token whose self-declared algorithm does not match what the
190    // caller said to expect. jsonwebtoken 10.x cross-checks JWK key
191    // type against `header.alg`, which closes the immediate bypass,
192    // but the canonical defense is to never trust `header.alg` for
193    // algorithm selection in the first place.
194    if header.alg != options.expected_algorithm {
195        return Err(ConnectorError::invalid_signature(format!(
196            "JWT header alg {:?} does not match expected {:?}",
197            header.alg, options.expected_algorithm
198        )));
199    }
200    let jwks = resolve_jwks(http, source.clone(), options).await?;
201    let jwks = match (source, header.kid.as_deref()) {
202        (JwtKeySource::Url(jwks_url), Some(kid)) if jwks.find(kid).is_none() => {
203            fetch_uncached_jwks(http, jwks_url, options).await?
204        }
205        _ => jwks,
206    };
207    let jwk = jwk_for_header(&jwks, header.kid.as_deref())?;
208    let key = DecodingKey::from_jwk(jwk)
209        .map_err(|error| ConnectorError::invalid_signature(error.to_string()))?;
210    let mut validation = Validation::new(options.expected_algorithm);
211    if !options.required_spec_claims.is_empty() {
212        let claims = options
213            .required_spec_claims
214            .iter()
215            .map(String::as_str)
216            .collect::<Vec<_>>();
217        validation.set_required_spec_claims(&claims);
218    }
219    if let Some(issuer) = options.issuer.as_deref() {
220        validation.set_issuer(&[issuer]);
221    }
222    if let Some(audience) = options.audience.as_deref() {
223        validation.set_audience(&[audience]);
224    }
225    decode::<T>(token, &key, &validation)
226        .map(|token| token.claims)
227        .map_err(|error| ConnectorError::invalid_signature(error.to_string()))
228}
229
230fn jwk_for_header<'a>(
231    jwks: &'a JwkSet,
232    kid: Option<&str>,
233) -> Result<&'a jsonwebtoken::jwk::Jwk, ConnectorError> {
234    match kid {
235        Some(kid) => jwks.find(kid).ok_or_else(|| {
236            ConnectorError::invalid_signature(format!("JWT kid `{kid}` was not found in JWKS"))
237        }),
238        None if jwks.keys.len() == 1 => Ok(&jwks.keys[0]),
239        None => Err(ConnectorError::invalid_signature(
240            "JWT missing kid and JWKS contains multiple keys",
241        )),
242    }
243}
244
245pub async fn verify_jwt_json(
246    http: &reqwest::Client,
247    token: &str,
248    source: JwtKeySource<'_>,
249    options: &JwtVerificationOptions,
250) -> Result<JsonValue, ConnectorError> {
251    verify_jwt_claims(http, token, source, options).await
252}
253
254async fn fetch_cached_jwks(
255    http: &reqwest::Client,
256    jwks_url: &str,
257    options: &JwtVerificationOptions,
258) -> Result<JwkSet, ConnectorError> {
259    if let Some(cached) = cached_jwks(jwks_url, options.jwks_cache_ttl) {
260        return Ok(cached);
261    }
262    fetch_uncached_jwks(http, jwks_url, options).await
263}
264
265async fn fetch_uncached_jwks(
266    http: &reqwest::Client,
267    jwks_url: &str,
268    options: &JwtVerificationOptions,
269) -> Result<JwkSet, ConnectorError> {
270    if let Some(error) = crate::egress::connector_error_for_url(options.egress_label, jwks_url) {
271        return Err(error);
272    }
273    let jwks = http
274        .get(jwks_url)
275        .send()
276        .await
277        .map_err(|error| {
278            ConnectorError::Activation(format!(
279                "fetch JWKS: {}",
280                crate::egress::redact_reqwest_error(&error)
281            ))
282        })?
283        .error_for_status()
284        .map_err(|error| {
285            ConnectorError::Activation(format!(
286                "fetch JWKS: {}",
287                crate::egress::redact_reqwest_error(&error)
288            ))
289        })?
290        .json::<JwkSet>()
291        .await
292        .map_err(|error| ConnectorError::Activation(format!("decode JWKS: {error}")))?;
293    store_cached_jwks(jwks_url, jwks.clone());
294    Ok(jwks)
295}
296
297fn cached_jwks(url: &str, ttl: StdDuration) -> Option<JwkSet> {
298    let cache = JWKS_CACHE.get_or_init(|| RwLock::new(HashMap::new()));
299    let cache = cache.read().expect("connector JWKS cache poisoned");
300    let cached = cache.get(url)?;
301    (cached.fetched_at.elapsed() < ttl).then(|| cached.jwks.clone())
302}
303
304fn store_cached_jwks(url: &str, jwks: JwkSet) {
305    let cache = JWKS_CACHE.get_or_init(|| RwLock::new(HashMap::new()));
306    cache
307        .write()
308        .expect("connector JWKS cache poisoned")
309        .insert(
310            url.to_string(),
311            CachedJwks {
312                fetched_at: Instant::now(),
313                jwks,
314            },
315        );
316}
317
318#[derive(Clone, Debug, PartialEq, Eq)]
319pub struct CursorPage {
320    pub items: Vec<JsonValue>,
321    pub next_cursor: Option<String>,
322    pub has_more: bool,
323}
324
325/// Collect cursor-paginated results without baking in a provider response
326/// schema. The caller owns page construction; this helper owns loop safety.
327pub async fn paginate_cursor<F, Fut>(
328    initial_cursor: Option<String>,
329    max_pages: Option<usize>,
330    mut fetch: F,
331) -> Result<Vec<JsonValue>, ConnectorError>
332where
333    F: FnMut(Option<String>) -> Fut,
334    Fut: Future<Output = Result<CursorPage, ConnectorError>>,
335{
336    let mut cursor = initial_cursor;
337    let mut pages = 0usize;
338    let mut results = Vec::new();
339    loop {
340        if max_pages.is_some_and(|limit| pages >= limit) {
341            break;
342        }
343        let page = fetch(cursor.clone()).await?;
344        results.extend(page.items);
345        pages += 1;
346        if !page.has_more {
347            break;
348        }
349        cursor = page.next_cursor;
350        if cursor.as_deref().is_none_or(str::is_empty) {
351            return Err(ConnectorError::Json(
352                "cursor-paginated connector response set has_more without next_cursor".to_string(),
353            ));
354        }
355    }
356    Ok(results)
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Key id shared by the fixture JWKS and the tokens signed for it.
    const FIXTURE_KID: &str = "test-key";

    /// Claims carried by fixture tokens; `jti` lets a test confirm the round trip.
    #[derive(Debug, Deserialize, Serialize)]
    struct Claims {
        iss: String,
        aud: String,
        exp: i64,
        jti: String,
    }

    /// One-key symmetric JWKS whose `k` is base64url of `secret`.
    fn shared_secret_jwks() -> JwkSet {
        let document = json!({
            "keys": [{
                "kty": "oct",
                "kid": FIXTURE_KID,
                "alg": "HS256",
                "k": "c2VjcmV0"
            }]
        });
        serde_json::from_value(document).unwrap()
    }

    /// HS256 token signed with `secret`, issued by `issuer` for `audience`,
    /// expiring at 2100-01-01T00:00:00Z.
    fn signed_hs256_token() -> String {
        let claims = Claims {
            iss: "issuer".to_string(),
            aud: "audience".to_string(),
            exp: 4_102_444_800,
            jti: "jwt-1".to_string(),
        };
        let mut header = Header::new(Algorithm::HS256);
        header.kid = Some(FIXTURE_KID.to_string());
        encode(&header, &claims, &EncodingKey::from_secret(b"secret")).unwrap()
    }

    #[test]
    fn hmac_signature_accepts_provider_prefixed_hex() {
        // Known-good SHA-256 vector in the provider `sha256=<hex>` format.
        let verified = verify_hmac_signature(
            b"Hello, World!",
            "sha256=757107ea0eb2509fc211221cce984b8a37570b6d7586c22c46f4379c8b043e17",
            "It's a Secret to Everybody",
            HmacSignatureAlgorithm::Sha256,
        )
        .unwrap();
        assert!(verified);
    }

    #[test]
    fn hmac_signature_sha1_requires_explicit_legacy_algorithm() {
        // Without the opt-in SHA-1 is refused, and the error names the flag.
        let refused = HmacSignatureAlgorithm::parse("sha1").expect_err("sha1 is gated");
        assert!(refused.to_string().contains("allow_legacy_sha1"));

        // With the opt-in it parses to the legacy variant...
        let parsed = HmacSignatureAlgorithm::parse_with_legacy_sha1("sha1", true).unwrap();
        assert_eq!(parsed, HmacSignatureAlgorithm::LegacySha1);

        // ...and a `sha1=`-prefixed digest verifies.
        let body = b"legacy";
        let signature = format!(
            "sha1={}",
            hex::encode(hmac::hmac_sha1(b"legacy-secret", body))
        );
        let verified = verify_hmac_signature(
            body,
            &signature,
            "legacy-secret",
            HmacSignatureAlgorithm::LegacySha1,
        )
        .unwrap();
        assert!(verified);
    }

    #[tokio::test]
    async fn jwt_claims_verify_against_inline_jwks() {
        let http = reqwest::Client::new();
        let jwks = shared_secret_jwks();
        let options = JwtVerificationOptions::default()
            .with_algorithm(Algorithm::HS256)
            .with_issuer("issuer")
            .with_audience("audience")
            .require_spec_claims(["exp", "iss", "aud"]);
        let claims: Claims = verify_jwt_claims(
            &http,
            &signed_hs256_token(),
            JwtKeySource::Inline(&jwks),
            &options,
        )
        .await
        .unwrap();
        assert_eq!(claims.jti, "jwt-1");
    }

    #[tokio::test]
    async fn jwt_claims_reject_alg_confusion() {
        // The token is HS256 but the verifier expects RS256. jsonwebtoken
        // would also reject it later (the JWK is symmetric), but the up-front
        // header check must refuse it before any `Validation` is built.
        let http = reqwest::Client::new();
        let jwks = shared_secret_jwks();
        let options = JwtVerificationOptions::default()
            .with_algorithm(Algorithm::RS256)
            .with_issuer("issuer")
            .with_audience("audience");
        let error = verify_jwt_claims::<Claims>(
            &http,
            &signed_hs256_token(),
            JwtKeySource::Inline(&jwks),
            &options,
        )
        .await
        .expect_err("HS256 token should not verify under RS256");
        let message = error.to_string();
        assert!(
            message.contains("alg") && message.contains("expected"),
            "unexpected error: {message}"
        );
    }

    #[tokio::test]
    async fn paginate_cursor_collects_until_has_more_is_false() {
        let mut scripted_pages = vec![
            CursorPage {
                items: vec![json!({"id": 1})],
                next_cursor: Some("b".to_string()),
                has_more: true,
            },
            CursorPage {
                items: vec![json!({"id": 2})],
                next_cursor: None,
                has_more: false,
            },
        ]
        .into_iter();
        let results = paginate_cursor(None, None, |_cursor| {
            let page = scripted_pages
                .next()
                .expect("paginate_cursor fetched past the final page");
            async move { Ok(page) }
        })
        .await
        .unwrap();
        assert_eq!(results, vec![json!({"id": 1}), json!({"id": 2})]);
    }
}