Skip to main content

macp_auth/auth/resolvers/
jwt_bearer.rs

1use crate::auth::resolver::{AuthError, AuthResolver, ResolvedIdentity};
2use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
3use serde::Deserialize;
4use std::sync::Arc;
5use tokio::sync::RwLock;
6use tonic::metadata::MetadataMap;
7
8#[derive(Debug, Clone, Deserialize)]
9struct MACPClaims {
10    sub: String,
11    #[serde(default)]
12    macp_scopes: Option<MACPScopes>,
13}
14
15#[derive(Debug, Clone, Deserialize, Default)]
16struct MACPScopes {
17    #[serde(default)]
18    can_start_sessions: Option<bool>,
19    #[serde(default)]
20    can_manage_mode_registry: Option<bool>,
21    #[serde(default)]
22    is_observer: Option<bool>,
23    #[serde(default)]
24    allowed_modes: Option<Vec<String>>,
25    #[serde(default)]
26    max_open_sessions: Option<usize>,
27}
28
29#[derive(Debug, Clone)]
30pub struct JwtConfig {
31    pub issuer: String,
32    pub audience: String,
33    pub algorithms: Vec<Algorithm>,
34}
35
36struct CachedKeys {
37    keys: Vec<DecodingKey>,
38    fetched_at: std::time::Instant,
39}
40
41pub struct JwtBearerResolver {
42    config: JwtConfig,
43    jwks_source: JwksSource,
44    cached_keys: Arc<RwLock<Option<CachedKeys>>>,
45    cache_ttl: std::time::Duration,
46}
47
48enum JwksSource {
49    Inline(Vec<DecodingKey>),
50    Url(String),
51}
52
53impl JwtBearerResolver {
54    pub fn from_inline_json(config: JwtConfig, jwks_json: &str) -> Result<Self, String> {
55        let jwks: serde_json::Value =
56            serde_json::from_str(jwks_json).map_err(|e| format!("invalid JWKS JSON: {e}"))?;
57        let keys = Self::parse_jwks(&jwks)?;
58        tracing::info!(
59            keys = keys.len(),
60            issuer = %config.issuer,
61            "JWT resolver initialized with inline JWKS"
62        );
63        Ok(Self {
64            config,
65            jwks_source: JwksSource::Inline(keys.clone()),
66            cached_keys: Arc::new(RwLock::new(Some(CachedKeys {
67                keys,
68                fetched_at: std::time::Instant::now(),
69            }))),
70            cache_ttl: std::time::Duration::from_secs(u64::MAX),
71        })
72    }
73
74    pub fn from_url(config: JwtConfig, url: String, cache_ttl_secs: u64) -> Self {
75        tracing::info!(
76            url = %url,
77            issuer = %config.issuer,
78            cache_ttl_secs,
79            "JWT resolver initialized with JWKS URL"
80        );
81        Self {
82            config,
83            jwks_source: JwksSource::Url(url),
84            cached_keys: Arc::new(RwLock::new(None)),
85            cache_ttl: std::time::Duration::from_secs(cache_ttl_secs),
86        }
87    }
88
89    fn extract_bearer(metadata: &MetadataMap) -> Option<String> {
90        metadata
91            .get("authorization")
92            .and_then(|v| v.to_str().ok())
93            .and_then(|v| v.strip_prefix("Bearer "))
94            .map(str::to_string)
95    }
96
97    async fn get_keys(&self) -> Result<Vec<DecodingKey>, AuthError> {
98        {
99            let guard = self.cached_keys.read().await;
100            if let Some(cached) = guard.as_ref() {
101                if cached.fetched_at.elapsed() < self.cache_ttl {
102                    return Ok(cached.keys.clone());
103                }
104            }
105        }
106
107        match &self.jwks_source {
108            JwksSource::Inline(keys) => Ok(keys.clone()),
109            JwksSource::Url(url) => {
110                let keys = self.fetch_jwks(url).await?;
111                let mut guard = self.cached_keys.write().await;
112                *guard = Some(CachedKeys {
113                    keys: keys.clone(),
114                    fetched_at: std::time::Instant::now(),
115                });
116                Ok(keys)
117            }
118        }
119    }
120
121    async fn fetch_jwks(&self, url: &str) -> Result<Vec<DecodingKey>, AuthError> {
122        let resp = reqwest::get(url)
123            .await
124            .map_err(|e| AuthError::FetchFailed(format!("JWKS fetch failed: {e}")))?;
125        let jwks: serde_json::Value = resp
126            .json()
127            .await
128            .map_err(|e| AuthError::FetchFailed(format!("JWKS parse failed: {e}")))?;
129        Self::parse_jwks(&jwks).map_err(AuthError::FetchFailed)
130    }
131
132    fn parse_jwks(jwks: &serde_json::Value) -> Result<Vec<DecodingKey>, String> {
133        let keys_arr = jwks
134            .get("keys")
135            .and_then(|k| k.as_array())
136            .ok_or_else(|| "JWKS missing 'keys' array".to_string())?;
137
138        let mut decoding_keys = Vec::new();
139        for key in keys_arr {
140            let kty = key.get("kty").and_then(|v| v.as_str()).unwrap_or("");
141            match kty {
142                "RSA" => {
143                    let n = key.get("n").and_then(|v| v.as_str()).unwrap_or("");
144                    let e = key.get("e").and_then(|v| v.as_str()).unwrap_or("");
145                    if !n.is_empty() && !e.is_empty() {
146                        if let Ok(dk) = DecodingKey::from_rsa_components(n, e) {
147                            decoding_keys.push(dk);
148                        }
149                    }
150                }
151                "EC" => {
152                    let x = key.get("x").and_then(|v| v.as_str()).unwrap_or("");
153                    let y = key.get("y").and_then(|v| v.as_str()).unwrap_or("");
154                    let crv = key.get("crv").and_then(|v| v.as_str()).unwrap_or("P-256");
155                    if !x.is_empty() && !y.is_empty() {
156                        if let Ok(dk) = DecodingKey::from_ec_components(x, y) {
157                            let _ = crv;
158                            decoding_keys.push(dk);
159                        }
160                    }
161                }
162                "oct" => {
163                    if let Some(k_val) = key.get("k").and_then(|v| v.as_str()) {
164                        decoding_keys.push(
165                            DecodingKey::from_base64_secret(k_val)
166                                .unwrap_or_else(|_| DecodingKey::from_secret(k_val.as_bytes())),
167                        );
168                    }
169                }
170                _ => {}
171            }
172        }
173
174        if decoding_keys.is_empty() {
175            return Err("no usable keys found in JWKS".to_string());
176        }
177        Ok(decoding_keys)
178    }
179}
180
181#[async_trait::async_trait]
182impl AuthResolver for JwtBearerResolver {
183    fn name(&self) -> &str {
184        "jwt_bearer"
185    }
186
187    async fn resolve(&self, metadata: &MetadataMap) -> Result<Option<ResolvedIdentity>, AuthError> {
188        let token = match Self::extract_bearer(metadata) {
189            Some(t) => t,
190            None => return Ok(None),
191        };
192
193        // Only handle JWT-shaped tokens (contain dots)
194        if !token.contains('.') {
195            return Ok(None);
196        }
197
198        let keys = self.get_keys().await?;
199
200        // Inspect the token header to pick a single algorithm to validate against.
201        // jsonwebtoken 9 requires every algorithm in validation.algorithms to match
202        // the DecodingKey's family, so a mixed list (RS256 + HS256) with one key
203        // would always fail with InvalidAlgorithm. We still gate on the configured
204        // allowlist — if the token's alg isn't configured, we reject it.
205        let header = decode_header(&token)
206            .map_err(|e| AuthError::InvalidCredential(format!("malformed JWT header: {e}")))?;
207        if !self.config.algorithms.contains(&header.alg) {
208            return Err(AuthError::InvalidCredential(format!(
209                "JWT algorithm {:?} is not in the configured allowlist",
210                header.alg
211            )));
212        }
213        let mut validation = Validation::new(header.alg);
214        validation.set_issuer(&[&self.config.issuer]);
215        validation.set_audience(&[&self.config.audience]);
216        validation.algorithms = vec![header.alg];
217
218        let mut last_err = None;
219        for key in &keys {
220            match decode::<MACPClaims>(&token, key, &validation) {
221                Ok(token_data) => {
222                    let claims = token_data.claims;
223                    let scopes = claims.macp_scopes.unwrap_or_default();
224
225                    return Ok(Some(ResolvedIdentity {
226                        sender: claims.sub,
227                        allowed_modes: scopes.allowed_modes.map(|m| m.into_iter().collect()),
228                        can_start_sessions: scopes.can_start_sessions.unwrap_or(true),
229                        max_open_sessions: scopes.max_open_sessions,
230                        can_manage_mode_registry: scopes.can_manage_mode_registry.unwrap_or(false),
231                        is_observer: scopes.is_observer.unwrap_or(false),
232                        resolver: "jwt_bearer".to_string(),
233                    }));
234                }
235                Err(e) => {
236                    last_err = Some(e);
237                    continue;
238                }
239            }
240        }
241
242        match last_err {
243            Some(e) => {
244                use jsonwebtoken::errors::ErrorKind;
245                match e.kind() {
246                    ErrorKind::ExpiredSignature => Err(AuthError::Expired),
247                    ErrorKind::InvalidIssuer => {
248                        Err(AuthError::InvalidCredential("invalid issuer".to_string()))
249                    }
250                    ErrorKind::InvalidAudience => {
251                        Err(AuthError::InvalidCredential("invalid audience".to_string()))
252                    }
253                    _ => Err(AuthError::InvalidCredential(format!(
254                        "JWT validation failed: {e}"
255                    ))),
256                }
257            }
258            None => Err(AuthError::InvalidCredential(
259                "no keys available to validate JWT".to_string(),
260            )),
261        }
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine;
    use jsonwebtoken::{encode, EncodingKey, Header};
    use serde::Serialize;

    const ISSUER: &str = "https://issuer.test";
    const AUDIENCE: &str = "macp-runtime";
    const SECRET: &[u8] = b"super-secret-symmetric-key-32-by";

    /// Payload minted into test tokens.
    #[derive(Serialize)]
    struct TestClaims<'a> {
        sub: &'a str,
        iss: &'a str,
        aud: &'a str,
        exp: i64,
        #[serde(skip_serializing_if = "Option::is_none")]
        macp_scopes: Option<serde_json::Value>,
    }

    impl<'a> TestClaims<'a> {
        /// Well-formed claims for `sub`: expected issuer and audience, no
        /// scopes, expiring `exp_offset_secs` from now (negative = in the past).
        fn for_subject(sub: &'a str, exp_offset_secs: i64) -> Self {
            Self {
                sub,
                iss: ISSUER,
                aud: AUDIENCE,
                exp: chrono::Utc::now().timestamp() + exp_offset_secs,
                macp_scopes: None,
            }
        }
    }

    /// JWKS JSON containing `SECRET` as a single standard-base64 `oct` key.
    fn jwks_inline() -> String {
        let encoded = base64::engine::general_purpose::STANDARD.encode(SECRET);
        let oct_key = serde_json::json!({ "kty": "oct", "alg": "HS256", "k": encoded });
        serde_json::json!({ "keys": [oct_key] }).to_string()
    }

    /// Config for `ISSUER`/`AUDIENCE` that allows exactly `algorithms`.
    fn config_with(algorithms: Vec<Algorithm>) -> JwtConfig {
        JwtConfig {
            issuer: ISSUER.to_string(),
            audience: AUDIENCE.to_string(),
            algorithms,
        }
    }

    /// Default config: HS256 only.
    fn config() -> JwtConfig {
        config_with(vec![Algorithm::HS256])
    }

    /// Resolver over the inline HS256 JWKS with the default config.
    fn hs256_resolver() -> JwtBearerResolver {
        JwtBearerResolver::from_inline_json(config(), &jwks_inline()).unwrap()
    }

    /// Signs `claims` with `SECRET` using HS256 and kid `test-key`.
    fn sign(claims: &TestClaims) -> String {
        let mut header = Header::new(Algorithm::HS256);
        header.kid = Some("test-key".into());
        encode(&header, claims, &EncodingKey::from_secret(SECRET)).unwrap()
    }

    /// Request metadata carrying `authorization: Bearer <token>`.
    fn bearer(token: &str) -> MetadataMap {
        let mut metadata = MetadataMap::new();
        metadata.insert("authorization", format!("Bearer {token}").parse().unwrap());
        metadata
    }

    #[tokio::test]
    async fn valid_jwt_resolves_to_identity_with_scopes() {
        let resolver = hs256_resolver();
        let claims = TestClaims {
            macp_scopes: Some(serde_json::json!({
                "allowed_modes": ["macp.mode.decision.v1"],
                "can_start_sessions": true,
                "max_open_sessions": 5,
                "can_manage_mode_registry": false,
                "is_observer": false,
            })),
            ..TestClaims::for_subject("agent://alice", 300)
        };
        let token = sign(&claims);

        let id = resolver
            .resolve(&bearer(&token))
            .await
            .expect("ok")
            .expect("some");
        assert_eq!(id.sender, "agent://alice");
        assert_eq!(id.resolver, "jwt_bearer");
        assert!(id.can_start_sessions);
        assert_eq!(id.max_open_sessions, Some(5));
        assert!(!id.can_manage_mode_registry);
        assert!(!id.is_observer);
        let modes = id.allowed_modes.unwrap();
        assert!(modes.contains("macp.mode.decision.v1"));
    }

    #[tokio::test]
    async fn jwt_without_scopes_defaults_to_permissive_sender() {
        let resolver = hs256_resolver();
        let token = sign(&TestClaims::for_subject("agent://bob", 300));

        let id = resolver.resolve(&bearer(&token)).await.unwrap().unwrap();
        assert_eq!(id.sender, "agent://bob");
        assert!(id.can_start_sessions); // permissive default when the scope is absent
        assert!(id.allowed_modes.is_none());
        assert!(!id.is_observer);
    }

    #[tokio::test]
    async fn expired_jwt_returns_expired_error() {
        let resolver = hs256_resolver();
        // Ten minutes in the past comfortably exceeds jsonwebtoken's default 60s leeway.
        let token = sign(&TestClaims::for_subject("agent://alice", -600));

        let err = resolver.resolve(&bearer(&token)).await.unwrap_err();
        assert!(matches!(err, AuthError::Expired), "got {err:?}");
    }

    #[tokio::test]
    async fn wrong_issuer_rejected() {
        let resolver = hs256_resolver();
        let token = sign(&TestClaims {
            iss: "https://other.example",
            ..TestClaims::for_subject("agent://alice", 300)
        });

        let err = resolver.resolve(&bearer(&token)).await.unwrap_err();
        assert!(
            matches!(err, AuthError::InvalidCredential(ref m) if m.contains("issuer")),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn wrong_audience_rejected() {
        let resolver = hs256_resolver();
        let token = sign(&TestClaims {
            aud: "other-audience",
            ..TestClaims::for_subject("agent://alice", 300)
        });

        let err = resolver.resolve(&bearer(&token)).await.unwrap_err();
        assert!(
            matches!(err, AuthError::InvalidCredential(ref m) if m.contains("audience")),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn bad_signature_rejected() {
        let resolver = hs256_resolver();
        // Signed with a key the JWKS does not contain, so verification must fail.
        let claims = TestClaims::for_subject("agent://alice", 300);
        let foreign_key = EncodingKey::from_secret(b"different-key-bytes-0123456789!!");
        let bad_token = encode(&Header::new(Algorithm::HS256), &claims, &foreign_key).unwrap();

        let err = resolver.resolve(&bearer(&bad_token)).await.unwrap_err();
        assert!(
            matches!(err, AuthError::InvalidCredential(_)),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn opaque_bearer_token_is_not_claimed() {
        let resolver = hs256_resolver();
        // A token without dots is not JWT-shaped and must be left to other resolvers.
        let outcome = resolver
            .resolve(&bearer("static-opaque-token"))
            .await
            .unwrap();
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn missing_authorization_header_is_not_claimed() {
        let resolver = hs256_resolver();
        let outcome = resolver.resolve(&MetadataMap::new()).await.unwrap();
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn server_env_algorithms_accept_hs256_tokens() {
        // Mirrors the server's SecurityLayer::from_env() allowlist: RS256, ES256, HS256.
        let resolver = JwtBearerResolver::from_inline_json(
            config_with(vec![Algorithm::RS256, Algorithm::ES256, Algorithm::HS256]),
            &jwks_inline(),
        )
        .unwrap();
        let token = sign(&TestClaims::for_subject("agent://alice", 300));

        let id = resolver
            .resolve(&bearer(&token))
            .await
            .expect("ok")
            .expect("some");
        assert_eq!(id.sender, "agent://alice");
    }
}