Skip to main content

harn_vm/connectors/
shared.rs

1//! Shared connector building blocks that are useful to both Rust shims and
2//! Harn-authored connector packages.
3
4use std::collections::HashMap;
5use std::future::Future;
6use std::sync::{OnceLock, RwLock};
7use std::time::{Duration as StdDuration, Instant};
8
9use jsonwebtoken::jwk::JwkSet;
10use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
11use serde::de::DeserializeOwned;
12use serde_json::Value as JsonValue;
13
14use super::{hmac, Connector, ConnectorError};
15
/// Default `jwks_cache_ttl` used by `JwtVerificationOptions::default()`: one day.
const DEFAULT_JWKS_CACHE_TTL: StdDuration = StdDuration::from_secs(60 * 60 * 24);
17
18/// Base connector contract name for shared runtime code.
19///
20/// This stays as a blanket extension over `Connector` so there is one
21/// object-safe implementation contract for registry and adapter code.
22pub trait ConnectorBase: Connector {}
23
24impl<T: Connector + ?Sized> ConnectorBase for T {}
25
26#[derive(Clone, Copy, Debug, PartialEq, Eq)]
27pub enum HmacSignatureAlgorithm {
28    Sha1,
29    Sha256,
30}
31
32impl HmacSignatureAlgorithm {
33    pub fn parse(raw: &str) -> Result<Self, ConnectorError> {
34        match raw.trim().to_ascii_lowercase().as_str() {
35            "sha1" | "hmac-sha1" => Ok(Self::Sha1),
36            "sha256" | "hmac-sha256" | "" => Ok(Self::Sha256),
37            other => Err(ConnectorError::Unsupported(format!(
38                "unsupported HMAC signature algorithm `{other}`"
39            ))),
40        }
41    }
42}
43
44/// Verify a raw HMAC signature value using constant-time comparison.
45///
46/// `signature` may be a bare hex digest or a provider-style `sha256=<hex>` /
47/// `sha1=<hex>` value. Provider-specific timestamp and canonical-message
48/// checks belong in `hmac::verify_hmac_signed`.
49pub fn verify_hmac_signature(
50    body: &[u8],
51    signature: &str,
52    secret: &str,
53    algorithm: HmacSignatureAlgorithm,
54) -> Result<bool, ConnectorError> {
55    let signature = signature.trim();
56    let signature = signature
57        .strip_prefix("sha256=")
58        .or_else(|| signature.strip_prefix("sha1="))
59        .unwrap_or(signature);
60    let provided = hex::decode(signature).map_err(|error| ConnectorError::InvalidHeader {
61        name: "signature".to_string(),
62        detail: error.to_string(),
63    })?;
64    let expected = match algorithm {
65        HmacSignatureAlgorithm::Sha1 => hmac::hmac_sha1(secret.as_bytes(), body),
66        HmacSignatureAlgorithm::Sha256 => hmac::hmac_sha256(secret.as_bytes(), body),
67    };
68    Ok(hmac::secure_eq(&expected, &provided))
69}
70
71#[derive(Clone, Debug)]
72pub enum JwtKeySource<'a> {
73    Inline(&'a JwkSet),
74    Url(&'a str),
75}
76
/// Settings for [`verify_jwt_claims`] / [`verify_jwt_json`] and JWKS fetching.
#[derive(Clone, Debug)]
pub struct JwtVerificationOptions {
    /// Expected `iss` claim. When `None`, no issuer is configured for the check.
    pub issuer: Option<String>,
    /// Expected `aud` claim. When `None`, no audience is configured.
    /// NOTE(review): `jsonwebtoken` may still reject tokens that carry an
    /// `aud` claim when none is configured — confirm against the crate version.
    pub audience: Option<String>,
    /// Registered claims that must be present. An empty list leaves the
    /// defaults of `jsonwebtoken::Validation::new` untouched.
    pub required_spec_claims: Vec<String>,
    /// How long a JWKS fetched from a URL is served from the process-wide cache.
    pub jwks_cache_ttl: StdDuration,
    /// Label passed to the egress check performed before fetching a JWKS URL.
    pub egress_label: &'static str,
}
85
86impl Default for JwtVerificationOptions {
87    fn default() -> Self {
88        Self {
89            issuer: None,
90            audience: None,
91            required_spec_claims: Vec::new(),
92            jwks_cache_ttl: DEFAULT_JWKS_CACHE_TTL,
93            egress_label: "connector:jwks",
94        }
95    }
96}
97
98impl JwtVerificationOptions {
99    pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
100        self.issuer = Some(issuer.into());
101        self
102    }
103
104    pub fn with_audience(mut self, audience: impl Into<String>) -> Self {
105        self.audience = Some(audience.into());
106        self
107    }
108
109    pub fn require_spec_claims(
110        mut self,
111        claims: impl IntoIterator<Item = impl Into<String>>,
112    ) -> Self {
113        self.required_spec_claims = claims.into_iter().map(Into::into).collect();
114        self
115    }
116
117    pub fn with_egress_label(mut self, egress_label: &'static str) -> Self {
118        self.egress_label = egress_label;
119        self
120    }
121
122    pub fn with_jwks_cache_ttl(mut self, ttl: StdDuration) -> Self {
123        self.jwks_cache_ttl = ttl;
124        self
125    }
126}
127
128#[derive(Clone, Debug)]
129struct CachedJwks {
130    fetched_at: Instant,
131    jwks: JwkSet,
132}
133
134static JWKS_CACHE: OnceLock<RwLock<HashMap<String, CachedJwks>>> = OnceLock::new();
135
136pub async fn resolve_jwks(
137    http: &reqwest::Client,
138    source: JwtKeySource<'_>,
139    options: &JwtVerificationOptions,
140) -> Result<JwkSet, ConnectorError> {
141    match source {
142        JwtKeySource::Inline(jwks) => Ok(jwks.clone()),
143        JwtKeySource::Url(jwks_url) => fetch_cached_jwks(http, jwks_url, options).await,
144    }
145}
146
147pub async fn verify_jwt_claims<T>(
148    http: &reqwest::Client,
149    token: &str,
150    source: JwtKeySource<'_>,
151    options: &JwtVerificationOptions,
152) -> Result<T, ConnectorError>
153where
154    T: DeserializeOwned,
155{
156    let header = decode_header(token)
157        .map_err(|error| ConnectorError::invalid_signature(error.to_string()))?;
158    let jwks = resolve_jwks(http, source.clone(), options).await?;
159    let jwks = match (source, header.kid.as_deref()) {
160        (JwtKeySource::Url(jwks_url), Some(kid)) if jwks.find(kid).is_none() => {
161            fetch_uncached_jwks(http, jwks_url, options).await?
162        }
163        _ => jwks,
164    };
165    let jwk = jwk_for_header(&jwks, header.kid.as_deref())?;
166    let key = DecodingKey::from_jwk(jwk)
167        .map_err(|error| ConnectorError::invalid_signature(error.to_string()))?;
168    let mut validation = Validation::new(header.alg);
169    if !options.required_spec_claims.is_empty() {
170        let claims = options
171            .required_spec_claims
172            .iter()
173            .map(String::as_str)
174            .collect::<Vec<_>>();
175        validation.set_required_spec_claims(&claims);
176    }
177    if let Some(issuer) = options.issuer.as_deref() {
178        validation.set_issuer(&[issuer]);
179    }
180    if let Some(audience) = options.audience.as_deref() {
181        validation.set_audience(&[audience]);
182    }
183    decode::<T>(token, &key, &validation)
184        .map(|token| token.claims)
185        .map_err(|error| ConnectorError::invalid_signature(error.to_string()))
186}
187
188fn jwk_for_header<'a>(
189    jwks: &'a JwkSet,
190    kid: Option<&str>,
191) -> Result<&'a jsonwebtoken::jwk::Jwk, ConnectorError> {
192    match kid {
193        Some(kid) => jwks.find(kid).ok_or_else(|| {
194            ConnectorError::invalid_signature(format!("JWT kid `{kid}` was not found in JWKS"))
195        }),
196        None if jwks.keys.len() == 1 => Ok(&jwks.keys[0]),
197        None => Err(ConnectorError::invalid_signature(
198            "JWT missing kid and JWKS contains multiple keys",
199        )),
200    }
201}
202
203pub async fn verify_jwt_json(
204    http: &reqwest::Client,
205    token: &str,
206    source: JwtKeySource<'_>,
207    options: &JwtVerificationOptions,
208) -> Result<JsonValue, ConnectorError> {
209    verify_jwt_claims(http, token, source, options).await
210}
211
212async fn fetch_cached_jwks(
213    http: &reqwest::Client,
214    jwks_url: &str,
215    options: &JwtVerificationOptions,
216) -> Result<JwkSet, ConnectorError> {
217    if let Some(cached) = cached_jwks(jwks_url, options.jwks_cache_ttl) {
218        return Ok(cached);
219    }
220    fetch_uncached_jwks(http, jwks_url, options).await
221}
222
223async fn fetch_uncached_jwks(
224    http: &reqwest::Client,
225    jwks_url: &str,
226    options: &JwtVerificationOptions,
227) -> Result<JwkSet, ConnectorError> {
228    if let Some(error) = crate::egress::connector_error_for_url(options.egress_label, jwks_url) {
229        return Err(error);
230    }
231    let jwks = http
232        .get(jwks_url)
233        .send()
234        .await
235        .map_err(|error| ConnectorError::Activation(format!("fetch JWKS: {error}")))?
236        .error_for_status()
237        .map_err(|error| ConnectorError::Activation(format!("fetch JWKS: {error}")))?
238        .json::<JwkSet>()
239        .await
240        .map_err(|error| ConnectorError::Activation(format!("decode JWKS: {error}")))?;
241    store_cached_jwks(jwks_url, jwks.clone());
242    Ok(jwks)
243}
244
245fn cached_jwks(url: &str, ttl: StdDuration) -> Option<JwkSet> {
246    let cache = JWKS_CACHE.get_or_init(|| RwLock::new(HashMap::new()));
247    let cache = cache.read().expect("connector JWKS cache poisoned");
248    let cached = cache.get(url)?;
249    (cached.fetched_at.elapsed() < ttl).then(|| cached.jwks.clone())
250}
251
252fn store_cached_jwks(url: &str, jwks: JwkSet) {
253    let cache = JWKS_CACHE.get_or_init(|| RwLock::new(HashMap::new()));
254    cache
255        .write()
256        .expect("connector JWKS cache poisoned")
257        .insert(
258            url.to_string(),
259            CachedJwks {
260                fetched_at: Instant::now(),
261                jwks,
262            },
263        );
264}
265
266#[derive(Clone, Debug, PartialEq)]
267pub struct CursorPage {
268    pub items: Vec<JsonValue>,
269    pub next_cursor: Option<String>,
270    pub has_more: bool,
271}
272
273/// Collect cursor-paginated results without baking in a provider response
274/// schema. The caller owns page construction; this helper owns loop safety.
275pub async fn paginate_cursor<F, Fut>(
276    initial_cursor: Option<String>,
277    max_pages: Option<usize>,
278    mut fetch: F,
279) -> Result<Vec<JsonValue>, ConnectorError>
280where
281    F: FnMut(Option<String>) -> Fut,
282    Fut: Future<Output = Result<CursorPage, ConnectorError>>,
283{
284    let mut cursor = initial_cursor;
285    let mut pages = 0usize;
286    let mut results = Vec::new();
287    loop {
288        if max_pages.is_some_and(|limit| pages >= limit) {
289            break;
290        }
291        let page = fetch(cursor.clone()).await?;
292        results.extend(page.items);
293        pages += 1;
294        if !page.has_more {
295            break;
296        }
297        cursor = page.next_cursor;
298        if cursor.as_deref().is_none_or(str::is_empty) {
299            return Err(ConnectorError::Json(
300                "cursor-paginated connector response set has_more without next_cursor".to_string(),
301            ));
302        }
303    }
304    Ok(results)
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Known-good HMAC-SHA256 fixture: this body signed with this secret
    /// produces this hex digest.
    const HMAC_FIXTURE_BODY: &[u8] = b"Hello, World!";
    const HMAC_FIXTURE_SECRET: &str = "It's a Secret to Everybody";
    const HMAC_FIXTURE_DIGEST: &str =
        "757107ea0eb2509fc211221cce984b8a37570b6d7586c22c46f4379c8b043e17";

    #[derive(Debug, Deserialize, Serialize)]
    struct Claims {
        iss: String,
        aud: String,
        exp: i64,
        jti: String,
    }

    /// Inline JWKS whose single `oct` key is base64url("secret").
    fn hs_jwks() -> JwkSet {
        serde_json::from_value(json!({
            "keys": [{
                "kty": "oct",
                "kid": "test-key",
                "alg": "HS256",
                "k": "c2VjcmV0"
            }]
        }))
        .unwrap()
    }

    /// HS256 token with kid `test-key`, signed with `secret`.
    fn hs_token_signed_with(secret: &[u8]) -> String {
        let mut header = Header::new(Algorithm::HS256);
        header.kid = Some("test-key".to_string());
        encode(
            &header,
            &Claims {
                iss: "issuer".to_string(),
                aud: "audience".to_string(),
                exp: 4_102_444_800,
                jti: "jwt-1".to_string(),
            },
            &EncodingKey::from_secret(secret),
        )
        .unwrap()
    }

    /// Token signed with the same secret as the key in `hs_jwks`.
    fn hs_token() -> String {
        hs_token_signed_with(b"secret")
    }

    /// Options matching the claims minted by `hs_token_signed_with`.
    fn hs_options() -> JwtVerificationOptions {
        JwtVerificationOptions::default()
            .with_issuer("issuer")
            .with_audience("audience")
            .require_spec_claims(["exp", "iss", "aud"])
    }

    #[test]
    fn hmac_signature_accepts_provider_prefixed_hex() {
        let signature = format!("sha256={HMAC_FIXTURE_DIGEST}");
        assert!(verify_hmac_signature(
            HMAC_FIXTURE_BODY,
            &signature,
            HMAC_FIXTURE_SECRET,
            HmacSignatureAlgorithm::Sha256,
        )
        .unwrap());
    }

    #[test]
    fn hmac_signature_accepts_bare_hex_with_surrounding_whitespace() {
        let signature = format!("  {HMAC_FIXTURE_DIGEST}\n");
        assert!(verify_hmac_signature(
            HMAC_FIXTURE_BODY,
            &signature,
            HMAC_FIXTURE_SECRET,
            HmacSignatureAlgorithm::Sha256,
        )
        .unwrap());
    }

    #[test]
    fn hmac_signature_rejects_wrong_secret() {
        assert!(!verify_hmac_signature(
            HMAC_FIXTURE_BODY,
            HMAC_FIXTURE_DIGEST,
            "not the secret",
            HmacSignatureAlgorithm::Sha256,
        )
        .unwrap());
    }

    #[test]
    fn hmac_signature_reports_non_hex_digest_as_invalid_header() {
        let result = verify_hmac_signature(
            HMAC_FIXTURE_BODY,
            "sha256=not-hex",
            HMAC_FIXTURE_SECRET,
            HmacSignatureAlgorithm::Sha256,
        );
        assert!(matches!(result, Err(ConnectorError::InvalidHeader { .. })));
    }

    #[test]
    fn hmac_algorithm_parse_ignores_case_and_whitespace() {
        assert_eq!(
            HmacSignatureAlgorithm::parse(" SHA1 ").unwrap(),
            HmacSignatureAlgorithm::Sha1
        );
        assert_eq!(
            HmacSignatureAlgorithm::parse("HMAC-SHA256").unwrap(),
            HmacSignatureAlgorithm::Sha256
        );
        assert_eq!(
            HmacSignatureAlgorithm::parse("").unwrap(),
            HmacSignatureAlgorithm::Sha256
        );
        assert!(HmacSignatureAlgorithm::parse("md5").is_err());
    }

    #[tokio::test]
    async fn jwt_claims_verify_against_inline_jwks() {
        let http = reqwest::Client::new();
        let claims: Claims = verify_jwt_claims(
            &http,
            &hs_token(),
            JwtKeySource::Inline(&hs_jwks()),
            &hs_options(),
        )
        .await
        .unwrap();
        assert_eq!(claims.jti, "jwt-1");
    }

    #[tokio::test]
    async fn jwt_claims_reject_unexpected_audience() {
        let http = reqwest::Client::new();
        let result = verify_jwt_claims::<Claims>(
            &http,
            &hs_token(),
            JwtKeySource::Inline(&hs_jwks()),
            &hs_options().with_audience("someone-else"),
        )
        .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn jwt_claims_reject_token_signed_with_other_secret() {
        let http = reqwest::Client::new();
        let result = verify_jwt_claims::<Claims>(
            &http,
            &hs_token_signed_with(b"not-the-secret"),
            JwtKeySource::Inline(&hs_jwks()),
            &hs_options(),
        )
        .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn paginate_cursor_collects_until_has_more_is_false() {
        let pages = [
            CursorPage {
                items: vec![json!({"id": 1})],
                next_cursor: Some("b".to_string()),
                has_more: true,
            },
            CursorPage {
                items: vec![json!({"id": 2})],
                next_cursor: None,
                has_more: false,
            },
        ];
        let mut index = 0usize;
        let results = paginate_cursor(None, None, |_cursor| {
            let page = pages[index].clone();
            index += 1;
            async move { Ok(page) }
        })
        .await
        .unwrap();
        assert_eq!(results, vec![json!({"id": 1}), json!({"id": 2})]);
    }

    #[tokio::test]
    async fn paginate_cursor_stops_at_max_pages() {
        let mut requested = Vec::new();
        let results = paginate_cursor(Some("a".to_string()), Some(1), |cursor| {
            requested.push(cursor);
            async move {
                Ok(CursorPage {
                    items: vec![json!({"id": 1})],
                    next_cursor: Some("b".to_string()),
                    has_more: true,
                })
            }
        })
        .await
        .unwrap();
        assert_eq!(results, vec![json!({"id": 1})]);
        assert_eq!(requested, vec![Some("a".to_string())]);
    }

    #[tokio::test]
    async fn paginate_cursor_with_zero_max_pages_fetches_nothing() {
        let mut calls = 0usize;
        let results = paginate_cursor(None, Some(0), |_cursor| {
            calls += 1;
            async {
                Ok(CursorPage {
                    items: vec![json!({"id": 1})],
                    next_cursor: None,
                    has_more: false,
                })
            }
        })
        .await
        .unwrap();
        assert!(results.is_empty());
        assert_eq!(calls, 0);
    }

    #[tokio::test]
    async fn paginate_cursor_rejects_has_more_without_usable_cursor() {
        let missing = paginate_cursor(None, None, |_cursor| async {
            Ok(CursorPage {
                items: Vec::new(),
                next_cursor: None,
                has_more: true,
            })
        })
        .await;
        assert!(matches!(missing, Err(ConnectorError::Json(_))));

        let empty = paginate_cursor(None, None, |_cursor| async {
            Ok(CursorPage {
                items: Vec::new(),
                next_cursor: Some(String::new()),
                has_more: true,
            })
        })
        .await;
        assert!(matches!(empty, Err(ConnectorError::Json(_))));
    }
}