1use crate::auth::resolver::{AuthError, AuthResolver, ResolvedIdentity};
2use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
3use serde::Deserialize;
4use std::sync::Arc;
5use tokio::sync::RwLock;
6use tonic::metadata::MetadataMap;
7
8#[derive(Debug, Clone, Deserialize)]
9struct MACPClaims {
10 sub: String,
11 #[serde(default)]
12 macp_scopes: Option<MACPScopes>,
13}
14
15#[derive(Debug, Clone, Deserialize, Default)]
16struct MACPScopes {
17 #[serde(default)]
18 can_start_sessions: Option<bool>,
19 #[serde(default)]
20 can_manage_mode_registry: Option<bool>,
21 #[serde(default)]
22 is_observer: Option<bool>,
23 #[serde(default)]
24 allowed_modes: Option<Vec<String>>,
25 #[serde(default)]
26 max_open_sessions: Option<usize>,
27}
28
29#[derive(Debug, Clone)]
30pub struct JwtConfig {
31 pub issuer: String,
32 pub audience: String,
33 pub algorithms: Vec<Algorithm>,
34}
35
36struct CachedKeys {
37 keys: Vec<DecodingKey>,
38 fetched_at: std::time::Instant,
39}
40
41pub struct JwtBearerResolver {
42 config: JwtConfig,
43 jwks_source: JwksSource,
44 cached_keys: Arc<RwLock<Option<CachedKeys>>>,
45 cache_ttl: std::time::Duration,
46}
47
48enum JwksSource {
49 Inline(Vec<DecodingKey>),
50 Url(String),
51}
52
53impl JwtBearerResolver {
54 pub fn from_inline_json(config: JwtConfig, jwks_json: &str) -> Result<Self, String> {
55 let jwks: serde_json::Value =
56 serde_json::from_str(jwks_json).map_err(|e| format!("invalid JWKS JSON: {e}"))?;
57 let keys = Self::parse_jwks(&jwks)?;
58 tracing::info!(
59 keys = keys.len(),
60 issuer = %config.issuer,
61 "JWT resolver initialized with inline JWKS"
62 );
63 Ok(Self {
64 config,
65 jwks_source: JwksSource::Inline(keys.clone()),
66 cached_keys: Arc::new(RwLock::new(Some(CachedKeys {
67 keys,
68 fetched_at: std::time::Instant::now(),
69 }))),
70 cache_ttl: std::time::Duration::from_secs(u64::MAX),
71 })
72 }
73
74 pub fn from_url(config: JwtConfig, url: String, cache_ttl_secs: u64) -> Self {
75 tracing::info!(
76 url = %url,
77 issuer = %config.issuer,
78 cache_ttl_secs,
79 "JWT resolver initialized with JWKS URL"
80 );
81 Self {
82 config,
83 jwks_source: JwksSource::Url(url),
84 cached_keys: Arc::new(RwLock::new(None)),
85 cache_ttl: std::time::Duration::from_secs(cache_ttl_secs),
86 }
87 }
88
89 fn extract_bearer(metadata: &MetadataMap) -> Option<String> {
90 metadata
91 .get("authorization")
92 .and_then(|v| v.to_str().ok())
93 .and_then(|v| v.strip_prefix("Bearer "))
94 .map(str::to_string)
95 }
96
97 async fn get_keys(&self) -> Result<Vec<DecodingKey>, AuthError> {
98 {
99 let guard = self.cached_keys.read().await;
100 if let Some(cached) = guard.as_ref() {
101 if cached.fetched_at.elapsed() < self.cache_ttl {
102 return Ok(cached.keys.clone());
103 }
104 }
105 }
106
107 match &self.jwks_source {
108 JwksSource::Inline(keys) => Ok(keys.clone()),
109 JwksSource::Url(url) => {
110 let keys = self.fetch_jwks(url).await?;
111 let mut guard = self.cached_keys.write().await;
112 *guard = Some(CachedKeys {
113 keys: keys.clone(),
114 fetched_at: std::time::Instant::now(),
115 });
116 Ok(keys)
117 }
118 }
119 }
120
121 async fn fetch_jwks(&self, url: &str) -> Result<Vec<DecodingKey>, AuthError> {
122 let resp = reqwest::get(url)
123 .await
124 .map_err(|e| AuthError::FetchFailed(format!("JWKS fetch failed: {e}")))?;
125 let jwks: serde_json::Value = resp
126 .json()
127 .await
128 .map_err(|e| AuthError::FetchFailed(format!("JWKS parse failed: {e}")))?;
129 Self::parse_jwks(&jwks).map_err(AuthError::FetchFailed)
130 }
131
132 fn parse_jwks(jwks: &serde_json::Value) -> Result<Vec<DecodingKey>, String> {
133 let keys_arr = jwks
134 .get("keys")
135 .and_then(|k| k.as_array())
136 .ok_or_else(|| "JWKS missing 'keys' array".to_string())?;
137
138 let mut decoding_keys = Vec::new();
139 for key in keys_arr {
140 let kty = key.get("kty").and_then(|v| v.as_str()).unwrap_or("");
141 match kty {
142 "RSA" => {
143 let n = key.get("n").and_then(|v| v.as_str()).unwrap_or("");
144 let e = key.get("e").and_then(|v| v.as_str()).unwrap_or("");
145 if !n.is_empty() && !e.is_empty() {
146 if let Ok(dk) = DecodingKey::from_rsa_components(n, e) {
147 decoding_keys.push(dk);
148 }
149 }
150 }
151 "EC" => {
152 let x = key.get("x").and_then(|v| v.as_str()).unwrap_or("");
153 let y = key.get("y").and_then(|v| v.as_str()).unwrap_or("");
154 let crv = key.get("crv").and_then(|v| v.as_str()).unwrap_or("P-256");
155 if !x.is_empty() && !y.is_empty() {
156 if let Ok(dk) = DecodingKey::from_ec_components(x, y) {
157 let _ = crv;
158 decoding_keys.push(dk);
159 }
160 }
161 }
162 "oct" => {
163 if let Some(k_val) = key.get("k").and_then(|v| v.as_str()) {
164 decoding_keys.push(
165 DecodingKey::from_base64_secret(k_val)
166 .unwrap_or_else(|_| DecodingKey::from_secret(k_val.as_bytes())),
167 );
168 }
169 }
170 _ => {}
171 }
172 }
173
174 if decoding_keys.is_empty() {
175 return Err("no usable keys found in JWKS".to_string());
176 }
177 Ok(decoding_keys)
178 }
179}
180
181#[async_trait::async_trait]
182impl AuthResolver for JwtBearerResolver {
183 fn name(&self) -> &str {
184 "jwt_bearer"
185 }
186
187 async fn resolve(&self, metadata: &MetadataMap) -> Result<Option<ResolvedIdentity>, AuthError> {
188 let token = match Self::extract_bearer(metadata) {
189 Some(t) => t,
190 None => return Ok(None),
191 };
192
193 if !token.contains('.') {
195 return Ok(None);
196 }
197
198 let keys = self.get_keys().await?;
199
200 let header = decode_header(&token)
206 .map_err(|e| AuthError::InvalidCredential(format!("malformed JWT header: {e}")))?;
207 if !self.config.algorithms.contains(&header.alg) {
208 return Err(AuthError::InvalidCredential(format!(
209 "JWT algorithm {:?} is not in the configured allowlist",
210 header.alg
211 )));
212 }
213 let mut validation = Validation::new(header.alg);
214 validation.set_issuer(&[&self.config.issuer]);
215 validation.set_audience(&[&self.config.audience]);
216 validation.algorithms = vec![header.alg];
217
218 let mut last_err = None;
219 for key in &keys {
220 match decode::<MACPClaims>(&token, key, &validation) {
221 Ok(token_data) => {
222 let claims = token_data.claims;
223 let scopes = claims.macp_scopes.unwrap_or_default();
224
225 return Ok(Some(ResolvedIdentity {
226 sender: claims.sub,
227 allowed_modes: scopes.allowed_modes.map(|m| m.into_iter().collect()),
228 can_start_sessions: scopes.can_start_sessions.unwrap_or(true),
229 max_open_sessions: scopes.max_open_sessions,
230 can_manage_mode_registry: scopes.can_manage_mode_registry.unwrap_or(false),
231 is_observer: scopes.is_observer.unwrap_or(false),
232 resolver: "jwt_bearer".to_string(),
233 }));
234 }
235 Err(e) => {
236 last_err = Some(e);
237 continue;
238 }
239 }
240 }
241
242 match last_err {
243 Some(e) => {
244 use jsonwebtoken::errors::ErrorKind;
245 match e.kind() {
246 ErrorKind::ExpiredSignature => Err(AuthError::Expired),
247 ErrorKind::InvalidIssuer => {
248 Err(AuthError::InvalidCredential("invalid issuer".to_string()))
249 }
250 ErrorKind::InvalidAudience => {
251 Err(AuthError::InvalidCredential("invalid audience".to_string()))
252 }
253 _ => Err(AuthError::InvalidCredential(format!(
254 "JWT validation failed: {e}"
255 ))),
256 }
257 }
258 None => Err(AuthError::InvalidCredential(
259 "no keys available to validate JWT".to_string(),
260 )),
261 }
262 }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine;
    use jsonwebtoken::{encode, EncodingKey, Header};
    use serde::Serialize;

    const ISSUER: &str = "https://issuer.test";
    const AUDIENCE: &str = "macp-runtime";
    const SECRET: &[u8] = b"super-secret-symmetric-key-32-by";

    /// Claim set used to mint test tokens; `macp_scopes` is left out of the payload when `None`.
    #[derive(Serialize)]
    struct TestClaims<'a> {
        sub: &'a str,
        iss: &'a str,
        aud: &'a str,
        exp: i64,
        #[serde(skip_serializing_if = "Option::is_none")]
        macp_scopes: Option<serde_json::Value>,
    }

    /// Inline JWKS holding a single HS256 key whose `k` is `SECRET` in standard base64.
    fn jwks_inline() -> String {
        let k = base64::engine::general_purpose::STANDARD.encode(SECRET);
        serde_json::json!({
            "keys": [
                { "kty": "oct", "alg": "HS256", "k": k }
            ]
        })
        .to_string()
    }

    /// Resolver configuration matching the claims minted by these tests.
    fn config() -> JwtConfig {
        JwtConfig {
            issuer: ISSUER.to_string(),
            audience: AUDIENCE.to_string(),
            algorithms: vec![Algorithm::HS256],
        }
    }

    /// Signs `claims` with HS256 and `SECRET`.
    fn sign(claims: &TestClaims) -> String {
        let mut header = Header::new(Algorithm::HS256);
        header.kid = Some("test-key".into());
        encode(&header, claims, &EncodingKey::from_secret(SECRET)).unwrap()
    }

    /// Request metadata carrying `authorization: Bearer <token>`.
    fn bearer(token: &str) -> MetadataMap {
        let mut m = MetadataMap::new();
        m.insert("authorization", format!("Bearer {token}").parse().unwrap());
        m
    }

    #[tokio::test]
    async fn valid_jwt_resolves_to_identity_with_scopes() {
        let resolver = JwtBearerResolver::from_inline_json(config(), &jwks_inline()).unwrap();
        let token = sign(&TestClaims {
            sub: "agent://alice",
            iss: ISSUER,
            aud: AUDIENCE,
            exp: chrono::Utc::now().timestamp() + 300,
            macp_scopes: Some(serde_json::json!({
                "allowed_modes": ["macp.mode.decision.v1"],
                "can_start_sessions": true,
                "max_open_sessions": 5,
                "can_manage_mode_registry": false,
                "is_observer": false,
            })),
        });

        let id = resolver
            .resolve(&bearer(&token))
            .await
            .expect("ok")
            .expect("some");
        assert_eq!(id.sender, "agent://alice");
        assert_eq!(id.resolver, "jwt_bearer");
        assert!(id.can_start_sessions);
        assert_eq!(id.max_open_sessions, Some(5));
        assert!(!id.can_manage_mode_registry);
        assert!(!id.is_observer);
        let modes = id.allowed_modes.unwrap();
        assert!(modes.contains("macp.mode.decision.v1"));
    }

    #[tokio::test]
    async fn jwt_without_scopes_defaults_to_permissive_sender() {
        let resolver = JwtBearerResolver::from_inline_json(config(), &jwks_inline()).unwrap();
        let token = sign(&TestClaims {
            sub: "agent://bob",
            iss: ISSUER,
            aud: AUDIENCE,
            exp: chrono::Utc::now().timestamp() + 300,
            macp_scopes: None,
        });
        let id = resolver.resolve(&bearer(&token)).await.unwrap().unwrap();
        assert_eq!(id.sender, "agent://bob");
        assert!(id.can_start_sessions);
        assert!(id.allowed_modes.is_none());
        assert!(id.max_open_sessions.is_none());
        assert!(!id.can_manage_mode_registry);
        assert!(!id.is_observer);
    }

    #[tokio::test]
    async fn expired_jwt_returns_expired_error() {
        let resolver = JwtBearerResolver::from_inline_json(config(), &jwks_inline()).unwrap();
        let token = sign(&TestClaims {
            sub: "agent://alice",
            iss: ISSUER,
            aud: AUDIENCE,
            exp: chrono::Utc::now().timestamp() - 600,
            macp_scopes: None,
        });
        let err = resolver.resolve(&bearer(&token)).await.unwrap_err();
        assert!(matches!(err, AuthError::Expired), "got {err:?}");
    }

    #[tokio::test]
    async fn wrong_issuer_rejected() {
        let resolver = JwtBearerResolver::from_inline_json(config(), &jwks_inline()).unwrap();
        let token = sign(&TestClaims {
            sub: "agent://alice",
            iss: "https://other.example",
            aud: AUDIENCE,
            exp: chrono::Utc::now().timestamp() + 300,
            macp_scopes: None,
        });
        let err = resolver.resolve(&bearer(&token)).await.unwrap_err();
        assert!(
            matches!(err, AuthError::InvalidCredential(ref m) if m.contains("issuer")),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn wrong_audience_rejected() {
        let resolver = JwtBearerResolver::from_inline_json(config(), &jwks_inline()).unwrap();
        let token = sign(&TestClaims {
            sub: "agent://alice",
            iss: ISSUER,
            aud: "other-audience",
            exp: chrono::Utc::now().timestamp() + 300,
            macp_scopes: None,
        });
        let err = resolver.resolve(&bearer(&token)).await.unwrap_err();
        assert!(
            matches!(err, AuthError::InvalidCredential(ref m) if m.contains("audience")),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn bad_signature_rejected() {
        let resolver = JwtBearerResolver::from_inline_json(config(), &jwks_inline()).unwrap();
        let claims = TestClaims {
            sub: "agent://alice",
            iss: ISSUER,
            aud: AUDIENCE,
            exp: chrono::Utc::now().timestamp() + 300,
            macp_scopes: None,
        };
        let bad_token = encode(
            &Header::new(Algorithm::HS256),
            &claims,
            &EncodingKey::from_secret(b"different-key-bytes-0123456789!!"),
        )
        .unwrap();
        let err = resolver.resolve(&bearer(&bad_token)).await.unwrap_err();
        assert!(matches!(err, AuthError::InvalidCredential(_)), "got {err:?}");
    }

    #[tokio::test]
    async fn algorithm_outside_allowlist_rejected() {
        let cfg = JwtConfig {
            issuer: ISSUER.to_string(),
            audience: AUDIENCE.to_string(),
            algorithms: vec![Algorithm::RS256],
        };
        let resolver = JwtBearerResolver::from_inline_json(cfg, &jwks_inline()).unwrap();
        let token = sign(&TestClaims {
            sub: "agent://alice",
            iss: ISSUER,
            aud: AUDIENCE,
            exp: chrono::Utc::now().timestamp() + 300,
            macp_scopes: None,
        });
        let err = resolver.resolve(&bearer(&token)).await.unwrap_err();
        assert!(
            matches!(err, AuthError::InvalidCredential(ref m) if m.contains("allowlist")),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn malformed_jwt_header_rejected() {
        let resolver = JwtBearerResolver::from_inline_json(config(), &jwks_inline()).unwrap();
        let err = resolver
            .resolve(&bearer("garbage.garbage.garbage"))
            .await
            .unwrap_err();
        assert!(
            matches!(err, AuthError::InvalidCredential(ref m) if m.contains("header")),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn opaque_bearer_token_is_not_claimed() {
        let resolver = JwtBearerResolver::from_inline_json(config(), &jwks_inline()).unwrap();
        let outcome = resolver
            .resolve(&bearer("static-opaque-token"))
            .await
            .unwrap();
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn missing_authorization_header_is_not_claimed() {
        let resolver = JwtBearerResolver::from_inline_json(config(), &jwks_inline()).unwrap();
        let outcome = resolver.resolve(&MetadataMap::new()).await.unwrap();
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn server_env_algorithms_accept_hs256_tokens() {
        let cfg = JwtConfig {
            issuer: ISSUER.to_string(),
            audience: AUDIENCE.to_string(),
            algorithms: vec![Algorithm::RS256, Algorithm::ES256, Algorithm::HS256],
        };
        let resolver = JwtBearerResolver::from_inline_json(cfg, &jwks_inline()).unwrap();
        let token = sign(&TestClaims {
            sub: "agent://alice",
            iss: ISSUER,
            aud: AUDIENCE,
            exp: chrono::Utc::now().timestamp() + 300,
            macp_scopes: None,
        });
        let id = resolver
            .resolve(&bearer(&token))
            .await
            .expect("ok")
            .expect("some");
        assert_eq!(id.sender, "agent://alice");
    }
}