Skip to main content

harn_vm/connectors/
shared.rs

//! Shared connector building blocks that are useful to both Rust shims and
//! Harn-authored connector packages.
3
4use std::collections::HashMap;
5use std::future::Future;
6use std::sync::{OnceLock, RwLock};
7use std::time::{Duration as StdDuration, Instant};
8
9use jsonwebtoken::jwk::JwkSet;
10use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
11use serde::de::DeserializeOwned;
12use serde_json::Value as JsonValue;
13
14use super::{hmac, Connector, ConnectorError};
15
/// Default lifetime of a cached JWKS document: 24 hours.
///
/// Spelled with `from_secs` so the constant builds on toolchains that predate
/// `Duration::from_hours`; the value is unchanged.
const DEFAULT_JWKS_CACHE_TTL: StdDuration = StdDuration::from_secs(24 * 60 * 60);
17
18/// Base connector contract name for shared runtime code.
19///
20/// This stays as a blanket extension over `Connector` so there is one
21/// object-safe implementation contract for registry and adapter code.
22pub trait ConnectorBase: Connector {}
23
24impl<T: Connector + ?Sized> ConnectorBase for T {}
25
/// Hash function used when computing an HMAC webhook signature.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum HmacSignatureAlgorithm {
    /// HMAC-SHA1. Only produced by the parser when the caller explicitly
    /// allows legacy SHA-1.
    LegacySha1,
    /// HMAC-SHA256. The parser's choice for an empty algorithm string.
    Sha256,
}
31
32impl HmacSignatureAlgorithm {
33    pub fn parse(raw: &str) -> Result<Self, ConnectorError> {
34        Self::parse_with_legacy_sha1(raw, false)
35    }
36
37    pub fn parse_with_legacy_sha1(
38        raw: &str,
39        allow_legacy_sha1: bool,
40    ) -> Result<Self, ConnectorError> {
41        match raw.trim().to_ascii_lowercase().as_str() {
42            "sha1" | "hmac-sha1" if allow_legacy_sha1 => Ok(Self::LegacySha1),
43            "sha1" | "hmac-sha1" => Err(ConnectorError::Unsupported(
44                "HMAC-SHA1 is legacy; set `allow_legacy_sha1: true` for an existing provider"
45                    .to_string(),
46            )),
47            "sha256" | "hmac-sha256" | "" => Ok(Self::Sha256),
48            other => Err(ConnectorError::Unsupported(format!(
49                "unsupported HMAC signature algorithm `{other}`"
50            ))),
51        }
52    }
53}
54
55/// Verify a raw HMAC signature value using constant-time comparison.
56///
57/// `signature` may be a bare hex digest or a provider-style `sha256=<hex>` /
58/// `sha1=<hex>` value. Provider-specific timestamp and canonical-message
59/// checks belong in `hmac::verify_hmac_signed`.
60pub fn verify_hmac_signature(
61    body: &[u8],
62    signature: &str,
63    secret: &str,
64    algorithm: HmacSignatureAlgorithm,
65) -> Result<bool, ConnectorError> {
66    let signature = signature.trim();
67    let signature = signature
68        .strip_prefix("sha256=")
69        .or_else(|| signature.strip_prefix("sha1="))
70        .unwrap_or(signature);
71    let provided = hex::decode(signature).map_err(|error| ConnectorError::InvalidHeader {
72        name: "signature".to_string(),
73        detail: error.to_string(),
74    })?;
75    let expected = match algorithm {
76        HmacSignatureAlgorithm::LegacySha1 => hmac::hmac_sha1(secret.as_bytes(), body),
77        HmacSignatureAlgorithm::Sha256 => hmac::hmac_sha256(secret.as_bytes(), body),
78    };
79    Ok(hmac::secure_eq(&expected, &provided))
80}
81
82#[derive(Clone, Debug)]
83pub enum JwtKeySource<'a> {
84    Inline(&'a JwkSet),
85    Url(&'a str),
86}
87
88#[derive(Clone, Debug)]
89pub struct JwtVerificationOptions {
90    pub issuer: Option<String>,
91    pub audience: Option<String>,
92    pub required_spec_claims: Vec<String>,
93    pub jwks_cache_ttl: StdDuration,
94    pub egress_label: &'static str,
95    /// Signing algorithm the caller expects this token to be signed
96    /// with. The verifier asserts `header.alg == expected_algorithm`
97    /// BEFORE constructing the `Validation` so an attacker-controlled
98    /// `alg` header cannot down/up-grade the verification algorithm
99    /// (the canonical "alg confusion" footgun). Defaults to RS256
100    /// (`Algorithm::RS256`) — change it explicitly via
101    /// [`JwtVerificationOptions::with_algorithm`] when a connector
102    /// uses HS256 / ES256 / etc.
103    pub expected_algorithm: Algorithm,
104}
105
106impl Default for JwtVerificationOptions {
107    fn default() -> Self {
108        Self {
109            issuer: None,
110            audience: None,
111            required_spec_claims: Vec::new(),
112            jwks_cache_ttl: DEFAULT_JWKS_CACHE_TTL,
113            egress_label: "connector:jwks",
114            expected_algorithm: Algorithm::RS256,
115        }
116    }
117}
118
119impl JwtVerificationOptions {
120    pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
121        self.issuer = Some(issuer.into());
122        self
123    }
124
125    pub fn with_audience(mut self, audience: impl Into<String>) -> Self {
126        self.audience = Some(audience.into());
127        self
128    }
129
130    pub fn require_spec_claims(
131        mut self,
132        claims: impl IntoIterator<Item = impl Into<String>>,
133    ) -> Self {
134        self.required_spec_claims = claims.into_iter().map(Into::into).collect();
135        self
136    }
137
138    pub fn with_egress_label(mut self, egress_label: &'static str) -> Self {
139        self.egress_label = egress_label;
140        self
141    }
142
143    pub fn with_jwks_cache_ttl(mut self, ttl: StdDuration) -> Self {
144        self.jwks_cache_ttl = ttl;
145        self
146    }
147
148    /// Override the expected signing algorithm. Use this when the
149    /// connector you're verifying is known to sign with something
150    /// other than RS256 (e.g. HS256 for Slack-style shared secrets,
151    /// ES256 for some Apple flows).
152    pub fn with_algorithm(mut self, algorithm: Algorithm) -> Self {
153        self.expected_algorithm = algorithm;
154        self
155    }
156}
157
158#[derive(Clone, Debug)]
159struct CachedJwks {
160    fetched_at: Instant,
161    jwks: JwkSet,
162}
163
164static JWKS_CACHE: OnceLock<RwLock<HashMap<String, CachedJwks>>> = OnceLock::new();
165
166pub async fn resolve_jwks(
167    http: &reqwest::Client,
168    source: JwtKeySource<'_>,
169    options: &JwtVerificationOptions,
170) -> Result<JwkSet, ConnectorError> {
171    match source {
172        JwtKeySource::Inline(jwks) => Ok(jwks.clone()),
173        JwtKeySource::Url(jwks_url) => fetch_cached_jwks(http, jwks_url, options).await,
174    }
175}
176
177pub async fn verify_jwt_claims<T>(
178    http: &reqwest::Client,
179    token: &str,
180    source: JwtKeySource<'_>,
181    options: &JwtVerificationOptions,
182) -> Result<T, ConnectorError>
183where
184    T: DeserializeOwned,
185{
186    let header = decode_header(token)
187        .map_err(|error| ConnectorError::invalid_signature(error.to_string()))?;
188    // Defense-in-depth against JWT "alg confusion": refuse to accept a
189    // token whose self-declared algorithm does not match what the
190    // caller said to expect. jsonwebtoken 10.x cross-checks JWK key
191    // type against `header.alg`, which closes the immediate bypass,
192    // but the canonical defense is to never trust `header.alg` for
193    // algorithm selection in the first place.
194    if header.alg != options.expected_algorithm {
195        return Err(ConnectorError::invalid_signature(format!(
196            "JWT header alg {:?} does not match expected {:?}",
197            header.alg, options.expected_algorithm
198        )));
199    }
200    let jwks = resolve_jwks(http, source.clone(), options).await?;
201    let jwks = match (source, header.kid.as_deref()) {
202        (JwtKeySource::Url(jwks_url), Some(kid)) if jwks.find(kid).is_none() => {
203            fetch_uncached_jwks(http, jwks_url, options).await?
204        }
205        _ => jwks,
206    };
207    let jwk = jwk_for_header(&jwks, header.kid.as_deref())?;
208    let key = DecodingKey::from_jwk(jwk)
209        .map_err(|error| ConnectorError::invalid_signature(error.to_string()))?;
210    let mut validation = Validation::new(options.expected_algorithm);
211    if !options.required_spec_claims.is_empty() {
212        let claims = options
213            .required_spec_claims
214            .iter()
215            .map(String::as_str)
216            .collect::<Vec<_>>();
217        validation.set_required_spec_claims(&claims);
218    }
219    if let Some(issuer) = options.issuer.as_deref() {
220        validation.set_issuer(&[issuer]);
221    }
222    if let Some(audience) = options.audience.as_deref() {
223        validation.set_audience(&[audience]);
224    }
225    decode::<T>(token, &key, &validation)
226        .map(|token| token.claims)
227        .map_err(|error| ConnectorError::invalid_signature(error.to_string()))
228}
229
230fn jwk_for_header<'a>(
231    jwks: &'a JwkSet,
232    kid: Option<&str>,
233) -> Result<&'a jsonwebtoken::jwk::Jwk, ConnectorError> {
234    match kid {
235        Some(kid) => jwks.find(kid).ok_or_else(|| {
236            ConnectorError::invalid_signature(format!("JWT kid `{kid}` was not found in JWKS"))
237        }),
238        None if jwks.keys.len() == 1 => Ok(&jwks.keys[0]),
239        None => Err(ConnectorError::invalid_signature(
240            "JWT missing kid and JWKS contains multiple keys",
241        )),
242    }
243}
244
245pub async fn verify_jwt_json(
246    http: &reqwest::Client,
247    token: &str,
248    source: JwtKeySource<'_>,
249    options: &JwtVerificationOptions,
250) -> Result<JsonValue, ConnectorError> {
251    verify_jwt_claims(http, token, source, options).await
252}
253
254async fn fetch_cached_jwks(
255    http: &reqwest::Client,
256    jwks_url: &str,
257    options: &JwtVerificationOptions,
258) -> Result<JwkSet, ConnectorError> {
259    if let Some(cached) = cached_jwks(jwks_url, options.jwks_cache_ttl) {
260        return Ok(cached);
261    }
262    fetch_uncached_jwks(http, jwks_url, options).await
263}
264
265async fn fetch_uncached_jwks(
266    http: &reqwest::Client,
267    jwks_url: &str,
268    options: &JwtVerificationOptions,
269) -> Result<JwkSet, ConnectorError> {
270    if let Some(error) = crate::egress::connector_error_for_url(options.egress_label, jwks_url) {
271        return Err(error);
272    }
273    let jwks = http
274        .get(jwks_url)
275        .send()
276        .await
277        .map_err(|error| ConnectorError::Activation(format!("fetch JWKS: {error}")))?
278        .error_for_status()
279        .map_err(|error| ConnectorError::Activation(format!("fetch JWKS: {error}")))?
280        .json::<JwkSet>()
281        .await
282        .map_err(|error| ConnectorError::Activation(format!("decode JWKS: {error}")))?;
283    store_cached_jwks(jwks_url, jwks.clone());
284    Ok(jwks)
285}
286
287fn cached_jwks(url: &str, ttl: StdDuration) -> Option<JwkSet> {
288    let cache = JWKS_CACHE.get_or_init(|| RwLock::new(HashMap::new()));
289    let cache = cache.read().expect("connector JWKS cache poisoned");
290    let cached = cache.get(url)?;
291    (cached.fetched_at.elapsed() < ttl).then(|| cached.jwks.clone())
292}
293
294fn store_cached_jwks(url: &str, jwks: JwkSet) {
295    let cache = JWKS_CACHE.get_or_init(|| RwLock::new(HashMap::new()));
296    cache
297        .write()
298        .expect("connector JWKS cache poisoned")
299        .insert(
300            url.to_string(),
301            CachedJwks {
302                fetched_at: Instant::now(),
303                jwks,
304            },
305        );
306}
307
308#[derive(Clone, Debug, PartialEq, Eq)]
309pub struct CursorPage {
310    pub items: Vec<JsonValue>,
311    pub next_cursor: Option<String>,
312    pub has_more: bool,
313}
314
315/// Collect cursor-paginated results without baking in a provider response
316/// schema. The caller owns page construction; this helper owns loop safety.
317pub async fn paginate_cursor<F, Fut>(
318    initial_cursor: Option<String>,
319    max_pages: Option<usize>,
320    mut fetch: F,
321) -> Result<Vec<JsonValue>, ConnectorError>
322where
323    F: FnMut(Option<String>) -> Fut,
324    Fut: Future<Output = Result<CursorPage, ConnectorError>>,
325{
326    let mut cursor = initial_cursor;
327    let mut pages = 0usize;
328    let mut results = Vec::new();
329    loop {
330        if max_pages.is_some_and(|limit| pages >= limit) {
331            break;
332        }
333        let page = fetch(cursor.clone()).await?;
334        results.extend(page.items);
335        pages += 1;
336        if !page.has_more {
337            break;
338        }
339        cursor = page.next_cursor;
340        if cursor.as_deref().is_none_or(str::is_empty) {
341            return Err(ConnectorError::Json(
342                "cursor-paginated connector response set has_more without next_cursor".to_string(),
343            ));
344        }
345    }
346    Ok(results)
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Claim set carried by the test tokens.
    #[derive(Debug, Deserialize, Serialize)]
    struct Claims {
        iss: String,
        aud: String,
        exp: i64,
        jti: String,
    }

    /// Single-key symmetric JWKS whose `k` is the base64url form of `secret`.
    fn hs_jwks() -> JwkSet {
        let document = json!({
            "keys": [{
                "kty": "oct",
                "kid": "test-key",
                "alg": "HS256",
                "k": "c2VjcmV0"
            }]
        });
        serde_json::from_value(document).unwrap()
    }

    /// HS256 token signed with `secret` and tagged with kid `test-key`.
    fn hs_token() -> String {
        let claims = Claims {
            iss: "issuer".to_string(),
            aud: "audience".to_string(),
            exp: 4_102_444_800,
            jti: "jwt-1".to_string(),
        };
        let mut header = Header::new(Algorithm::HS256);
        header.kid = Some("test-key".to_string());
        encode(&header, &claims, &EncodingKey::from_secret(b"secret")).unwrap()
    }

    #[test]
    fn hmac_signature_accepts_provider_prefixed_hex() {
        let verified = verify_hmac_signature(
            b"Hello, World!",
            "sha256=757107ea0eb2509fc211221cce984b8a37570b6d7586c22c46f4379c8b043e17",
            "It's a Secret to Everybody",
            HmacSignatureAlgorithm::Sha256,
        )
        .unwrap();
        assert!(verified);
    }

    #[test]
    fn hmac_signature_sha1_requires_explicit_legacy_algorithm() {
        // Plain `parse` refuses SHA-1 and points at the opt-in flag.
        let parse_error = HmacSignatureAlgorithm::parse("sha1").expect_err("sha1 is gated");
        assert!(parse_error.to_string().contains("allow_legacy_sha1"));

        // Opting in yields the legacy variant.
        let parsed = HmacSignatureAlgorithm::parse_with_legacy_sha1("sha1", true).unwrap();
        assert_eq!(parsed, HmacSignatureAlgorithm::LegacySha1);

        // A `sha1=`-prefixed digest verifies under the legacy algorithm.
        let body = b"legacy";
        let signature = format!(
            "sha1={}",
            hex::encode(hmac::hmac_sha1(b"legacy-secret", body))
        );
        let verified = verify_hmac_signature(
            body,
            &signature,
            "legacy-secret",
            HmacSignatureAlgorithm::LegacySha1,
        )
        .unwrap();
        assert!(verified);
    }

    #[tokio::test]
    async fn jwt_claims_verify_against_inline_jwks() {
        let http = reqwest::Client::new();
        let jwks = hs_jwks();
        let token = hs_token();
        let options = JwtVerificationOptions::default()
            .with_algorithm(Algorithm::HS256)
            .with_issuer("issuer")
            .with_audience("audience")
            .require_spec_claims(["exp", "iss", "aud"]);

        let claims: Claims =
            verify_jwt_claims(&http, &token, JwtKeySource::Inline(&jwks), &options)
                .await
                .unwrap();

        assert_eq!(claims.jti, "jwt-1");
    }

    #[tokio::test]
    async fn jwt_claims_reject_alg_confusion() {
        // Token is signed with HS256; verifier is told to expect
        // RS256. Even though jsonwebtoken would catch this downstream
        // because the JWK is symmetric, our up-front guard refuses
        // before constructing `Validation`, which is the canonical
        // defense against alg-confusion exploits.
        let http = reqwest::Client::new();
        let jwks = hs_jwks();
        let token = hs_token();
        let options = JwtVerificationOptions::default()
            .with_algorithm(Algorithm::RS256)
            .with_issuer("issuer")
            .with_audience("audience");

        let outcome =
            verify_jwt_claims::<Claims>(&http, &token, JwtKeySource::Inline(&jwks), &options)
                .await;

        let message = outcome
            .expect_err("HS256 token should not verify under RS256")
            .to_string();
        assert!(
            message.contains("alg") && message.contains("expected"),
            "unexpected error: {message}"
        );
    }

    #[tokio::test]
    async fn paginate_cursor_collects_until_has_more_is_false() {
        let first = CursorPage {
            items: vec![json!({"id": 1})],
            next_cursor: Some("b".to_string()),
            has_more: true,
        };
        let last = CursorPage {
            items: vec![json!({"id": 2})],
            next_cursor: None,
            has_more: false,
        };
        let scripted = [first, last];
        let mut served = 0usize;

        let results = paginate_cursor(None, None, |_cursor| {
            let page = scripted[served].clone();
            served += 1;
            async move { Ok(page) }
        })
        .await
        .unwrap();

        let expected = vec![json!({"id": 1}), json!({"id": 2})];
        assert_eq!(results, expected);
    }
}