1use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use thiserror::Error;
13
14#[derive(Error, Debug)]
15pub enum JwtError {
16 #[error("JWT encoding failed: {0}")]
17 Encode(#[from] jsonwebtoken::errors::Error),
18 #[error("Invalid PEM key: {0}")]
19 InvalidKey(String),
20 #[error("No encoding key configured (private key required for issuance)")]
21 NoEncodingKey,
22 #[error("No decoding key configured (public key required for validation)")]
23 NoDecodingKey,
24 #[error("Base64 decode error: {0}")]
25 Base64(String),
26}
27
28#[derive(Clone)]
30pub struct JwtConfig {
31 pub decoding_key: DecodingKey,
33 pub encoding_key: Option<EncodingKey>,
35 pub algorithm: Algorithm,
37 pub required_issuer: Option<String>,
39 pub accepted_audiences: Vec<String>,
48 pub leeway_secs: u64,
50 pub public_key_pem: Option<Vec<u8>>,
52}
53
54impl std::fmt::Debug for JwtConfig {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 f.debug_struct("JwtConfig")
57 .field("algorithm", &self.algorithm)
58 .field("required_issuer", &self.required_issuer)
59 .field("accepted_audiences", &self.accepted_audiences)
60 .field("leeway_secs", &self.leeway_secs)
61 .field("has_encoding_key", &self.encoding_key.is_some())
62 .finish()
63 }
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct AtiNamespace {
69 pub v: u8,
71 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
73 pub rate: HashMap<String, String>,
74}
75
76#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct TokenClaims {
79 #[serde(skip_serializing_if = "Option::is_none")]
81 pub iss: Option<String>,
82 pub sub: String,
84 pub aud: String,
86 pub iat: u64,
88 pub exp: u64,
90 #[serde(skip_serializing_if = "Option::is_none")]
92 pub jti: Option<String>,
93 pub scope: String,
95 #[serde(skip_serializing_if = "Option::is_none")]
97 pub ati: Option<AtiNamespace>,
98 #[serde(default, skip_serializing_if = "Option::is_none")]
100 pub job_id: Option<String>,
101 #[serde(default, skip_serializing_if = "Option::is_none")]
103 pub sandbox_id: Option<String>,
104}
105
106impl TokenClaims {
107 pub fn scopes(&self) -> Vec<String> {
109 self.scope.split_whitespace().map(String::from).collect()
110 }
111}
112
113pub fn validate(token: &str, config: &JwtConfig) -> Result<TokenClaims, JwtError> {
119 if config.accepted_audiences.is_empty() {
127 return Err(JwtError::InvalidKey(
128 "accepted_audiences must not be empty; configure at least one audience".into(),
129 ));
130 }
131 let mut validation = Validation::new(config.algorithm);
132 let auds: Vec<&str> = config
137 .accepted_audiences
138 .iter()
139 .map(String::as_str)
140 .collect();
141 validation.set_audience(&auds);
142 validation.leeway = config.leeway_secs;
143
144 if let Some(ref issuer) = config.required_issuer {
145 validation.set_issuer(&[issuer]);
146 } else {
147 validation.set_required_spec_claims(&["exp", "sub", "aud"]);
149 }
150
151 let token_data: TokenData<TokenClaims> =
152 jsonwebtoken::decode(token, &config.decoding_key, &validation)?;
153
154 Ok(token_data.claims)
155}
156
157pub fn issue(claims: &TokenClaims, config: &JwtConfig) -> Result<String, JwtError> {
159 let encoding_key = config
160 .encoding_key
161 .as_ref()
162 .ok_or(JwtError::NoEncodingKey)?;
163
164 let header = Header::new(config.algorithm);
165 let token = jsonwebtoken::encode(&header, claims, encoding_key)?;
166 Ok(token)
167}
168
169pub fn inspect(token: &str) -> Result<TokenClaims, JwtError> {
171 let mut validation = Validation::default();
172 validation.insecure_disable_signature_validation();
173 validation.validate_aud = false;
174 validation.validate_exp = false;
175 validation.set_required_spec_claims::<&str>(&[]);
176
177 let key = DecodingKey::from_secret(b"unused");
179 let token_data: TokenData<TokenClaims> = jsonwebtoken::decode(token, &key, &validation)?;
180
181 Ok(token_data.claims)
182}
183
184pub fn load_public_key_pem(pem: &[u8], alg: Algorithm) -> Result<DecodingKey, JwtError> {
186 match alg {
187 Algorithm::ES256 | Algorithm::ES384 => {
188 DecodingKey::from_ec_pem(pem).map_err(|e| JwtError::InvalidKey(e.to_string()))
189 }
190 Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => {
191 DecodingKey::from_rsa_pem(pem).map_err(|e| JwtError::InvalidKey(e.to_string()))
192 }
193 _ => Err(JwtError::InvalidKey(format!(
194 "Unsupported algorithm for PEM: {alg:?}"
195 ))),
196 }
197}
198
199pub fn load_private_key_pem(pem: &[u8], alg: Algorithm) -> Result<EncodingKey, JwtError> {
201 match alg {
202 Algorithm::ES256 | Algorithm::ES384 => {
203 EncodingKey::from_ec_pem(pem).map_err(|e| JwtError::InvalidKey(e.to_string()))
204 }
205 Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => {
206 EncodingKey::from_rsa_pem(pem).map_err(|e| JwtError::InvalidKey(e.to_string()))
207 }
208 _ => Err(JwtError::InvalidKey(format!(
209 "Unsupported algorithm for PEM: {alg:?}"
210 ))),
211 }
212}
213
214pub fn config_from_secret(
223 secret: &[u8],
224 issuer: Option<String>,
225 audiences: Vec<String>,
226) -> JwtConfig {
227 JwtConfig {
228 decoding_key: DecodingKey::from_secret(secret),
229 encoding_key: Some(EncodingKey::from_secret(secret)),
230 algorithm: Algorithm::HS256,
231 required_issuer: issuer,
232 accepted_audiences: audiences,
233 leeway_secs: 60,
234 public_key_pem: None,
235 }
236}
237
238pub fn config_from_pem(
241 public_pem: &[u8],
242 private_pem: Option<&[u8]>,
243 alg: Algorithm,
244 issuer: Option<String>,
245 audiences: Vec<String>,
246) -> Result<JwtConfig, JwtError> {
247 let decoding_key = load_public_key_pem(public_pem, alg)?;
248 let encoding_key = match private_pem {
249 Some(pem) => Some(load_private_key_pem(pem, alg)?),
250 None => None,
251 };
252
253 Ok(JwtConfig {
254 decoding_key,
255 encoding_key,
256 algorithm: alg,
257 required_issuer: issuer,
258 accepted_audiences: audiences,
259 leeway_secs: 60,
260 public_key_pem: Some(public_pem.to_vec()),
261 })
262}
263
264pub fn public_key_to_jwks(
267 pem: &[u8],
268 alg: Algorithm,
269 kid: &str,
270) -> Result<serde_json::Value, JwtError> {
271 let pem_str = std::str::from_utf8(pem).map_err(|e| JwtError::InvalidKey(e.to_string()))?;
273
274 let key_type = match alg {
276 Algorithm::ES256 | Algorithm::ES384 => "EC",
277 Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => "RSA",
278 _ => {
279 return Err(JwtError::InvalidKey(
280 "Unsupported algorithm for JWKS".into(),
281 ))
282 }
283 };
284
285 let alg_str = match alg {
286 Algorithm::ES256 => "ES256",
287 Algorithm::ES384 => "ES384",
288 Algorithm::RS256 => "RS256",
289 Algorithm::RS384 => "RS384",
290 Algorithm::RS512 => "RS512",
291 _ => "unknown",
292 };
293
294 let der_b64: String = pem_str
297 .lines()
298 .filter(|line| !line.starts_with("-----"))
299 .collect::<Vec<_>>()
300 .join("");
301
302 let jwk = serde_json::json!({
303 "kty": key_type,
304 "use": "sig",
305 "alg": alg_str,
306 "kid": kid,
307 "x5c": [der_b64],
308 });
309
310 Ok(serde_json::json!({
311 "keys": [jwk]
312 }))
313}
314
/// Resolves the accepted JWT audiences from the environment.
///
/// Precedence:
/// 1. `ATI_JWT_ACCEPTED_AUDIENCES`: a comma-separated list. Entries are trimmed and empty
///    entries are dropped. Used if at least one non-empty entry remains.
/// 2. `ATI_JWT_AUDIENCE`: a single audience, trimmed. Used if non-empty after trimming.
/// 3. The default audience `"ati-proxy"`.
///
/// The result is never empty and never contains an empty string.
pub fn parse_audiences_env() -> Vec<String> {
    const DEFAULT_AUDIENCE: &str = "ati-proxy";

    if let Ok(csv) = std::env::var("ATI_JWT_ACCEPTED_AUDIENCES") {
        let list: Vec<String> = csv
            .split(',')
            .map(str::trim)
            .filter(|s| !s.is_empty())
            .map(String::from)
            .collect();
        if !list.is_empty() {
            return list;
        }
    }

    // Treat a blank singular value like an unset one. `vec![""]` would pass the
    // non-empty-allowlist guard in `validate` yet only match tokens with `aud: ""`.
    let single = std::env::var("ATI_JWT_AUDIENCE")
        .ok()
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty())
        .unwrap_or_else(|| DEFAULT_AUDIENCE.to_string());
    vec![single]
}
343
344pub fn config_from_env() -> Result<Option<JwtConfig>, JwtError> {
355 let issuer = std::env::var("ATI_JWT_ISSUER").ok();
356 let audiences = parse_audiences_env();
357
358 if let Ok(pub_key_path) = std::env::var("ATI_JWT_PUBLIC_KEY") {
360 let public_pem = std::fs::read(&pub_key_path)
361 .map_err(|e| JwtError::InvalidKey(format!("Cannot read {pub_key_path}: {e}")))?;
362
363 let private_pem = std::env::var("ATI_JWT_PRIVATE_KEY")
364 .ok()
365 .and_then(|path| std::fs::read(&path).ok());
366
367 let mut config = config_from_pem(
368 &public_pem,
369 private_pem.as_deref(),
370 Algorithm::ES256,
371 issuer,
372 audiences,
373 )?;
374
375 config.public_key_pem = Some(public_pem);
377
378 return Ok(Some(config));
379 }
380
381 if let Ok(secret_hex) = std::env::var("ATI_JWT_SECRET") {
383 let secret_bytes = hex::decode(&secret_hex)
384 .map_err(|e| JwtError::InvalidKey(format!("ATI_JWT_SECRET is not valid hex: {e}")))?;
385
386 return Ok(Some(config_from_secret(&secret_bytes, issuer, audiences)));
387 }
388
389 Ok(None)
390}
391
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, returns `0` instead of panicking.
pub fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
399
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// HMAC secret shared by the HS256 fixtures.
    const TEST_SECRET: &[u8] = b"test-secret-key-32-bytes-long!!!";

    /// HS256 config accepting `aud = "ati-proxy"`, with no issuer requirement.
    fn hs256_config() -> JwtConfig {
        config_from_secret(TEST_SECRET, None, vec!["ati-proxy".into()])
    }

    /// HS256 config that additionally requires `iss = "ati-orchestrator"`.
    fn hs256_config_with_issuer() -> JwtConfig {
        config_from_secret(
            TEST_SECRET,
            Some("ati-orchestrator".into()),
            vec!["ati-proxy".into()],
        )
    }

    /// HS256 config with an arbitrary (possibly empty) audience allowlist.
    fn hs256_config_multi(audiences: Vec<String>) -> JwtConfig {
        config_from_secret(TEST_SECRET, None, audiences)
    }

    /// Claims issued now, valid for 30 minutes, from `ati-orchestrator` to `agent-7`
    /// for the `ati-proxy` audience, carrying the given space-separated `scope`.
    fn make_claims(scope: &str) -> TokenClaims {
        let now = now_secs();
        TokenClaims {
            iss: Some("ati-orchestrator".into()),
            sub: "agent-7".into(),
            aud: "ati-proxy".into(),
            iat: now,
            exp: now + 1800,
            jti: Some(uuid::Uuid::new_v4().to_string()),
            scope: scope.into(),
            ati: Some(AtiNamespace {
                v: 1,
                rate: HashMap::new(),
            }),
            job_id: None,
            sandbox_id: None,
        }
    }

    #[test]
    fn test_hs256_round_trip() {
        let config = hs256_config();
        let claims = make_claims("tool:web_search tool:github:*");

        let token = issue(&claims, &config).unwrap();
        let decoded = validate(&token, &config).unwrap();

        assert_eq!(decoded.sub, "agent-7");
        assert_eq!(decoded.aud, "ati-proxy");
        assert_eq!(decoded.scope, "tool:web_search tool:github:*");
        assert_eq!(decoded.scopes(), vec!["tool:web_search", "tool:github:*"]);
        assert_eq!(decoded.iss, Some("ati-orchestrator".into()));
    }

    #[test]
    fn test_expired_token_rejected() {
        let config = hs256_config();
        let mut claims = make_claims("tool:web_search");
        // Far in the past, well beyond the 60 s leeway.
        claims.exp = 1;

        let token = issue(&claims, &config).unwrap();
        let result = validate(&token, &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_wrong_secret_rejected() {
        let config1 = hs256_config();
        let config2 = config_from_secret(
            b"different-secret-key-32-bytes!!",
            None,
            vec!["ati-proxy".into()],
        );

        let claims = make_claims("tool:web_search");
        let token = issue(&claims, &config1).unwrap();
        let result = validate(&token, &config2);
        assert!(result.is_err());
    }

    #[test]
    fn test_wrong_audience_rejected() {
        let config = hs256_config();
        let mut claims = make_claims("tool:web_search");
        claims.aud = "wrong-audience".into();

        let token = issue(&claims, &config).unwrap();
        let result = validate(&token, &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_wrong_issuer_rejected() {
        let config = hs256_config_with_issuer();
        let mut claims = make_claims("tool:web_search");
        claims.iss = Some("evil-orchestrator".into());

        let token = issue(&claims, &config).unwrap();
        let result = validate(&token, &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_tampered_payload_rejected() {
        let config = hs256_config();
        let claims = make_claims("tool:web_search");
        let token = issue(&claims, &config).unwrap();

        let parts: Vec<&str> = token.split('.').collect();
        assert_eq!(parts.len(), 3);
        let mut tampered_payload = parts[1].to_string();
        if tampered_payload.ends_with('A') {
            tampered_payload.push('B');
        } else {
            tampered_payload.push('A');
        }
        let tampered = format!("{}.{}.{}", parts[0], tampered_payload, parts[2]);

        let result = validate(&tampered, &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_malformed_token_rejected() {
        let config = hs256_config();
        let result = validate("not.a.jwt.token.at.all", &config);
        assert!(result.is_err());

        let result = validate("", &config);
        assert!(result.is_err());

        let result = validate("just-a-string", &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_inspect_decodes_without_key() {
        let config = hs256_config();
        let claims = make_claims("tool:web_search skill:research-*");
        let token = issue(&claims, &config).unwrap();

        let decoded = inspect(&token).unwrap();
        assert_eq!(decoded.sub, "agent-7");
        assert_eq!(decoded.scope, "tool:web_search skill:research-*");
    }

    #[test]
    fn test_scope_parsing() {
        let claims = make_claims("tool:web_search tool:github:* skill:research-* help");
        let scopes = claims.scopes();
        assert_eq!(
            scopes,
            vec![
                "tool:web_search",
                "tool:github:*",
                "skill:research-*",
                "help"
            ]
        );
    }

    #[test]
    fn test_empty_scope() {
        let claims = make_claims("");
        assert!(claims.scopes().is_empty());
    }

    #[test]
    fn test_single_scope() {
        let claims = make_claims("*");
        assert_eq!(claims.scopes(), vec!["*"]);
    }

    #[test]
    fn test_no_encoding_key_fails() {
        let config = JwtConfig {
            decoding_key: DecodingKey::from_secret(b"test"),
            encoding_key: None,
            algorithm: Algorithm::HS256,
            required_issuer: None,
            accepted_audiences: vec!["ati-proxy".into()],
            leeway_secs: 60,
            public_key_pem: None,
        };

        let claims = make_claims("tool:web_search");
        let result = issue(&claims, &config);
        assert!(
            matches!(result, Err(JwtError::NoEncodingKey)),
            "expected NoEncodingKey, got {result:?}"
        );
    }

    #[test]
    fn test_issuer_not_required_when_none() {
        let config = hs256_config();
        let mut claims = make_claims("tool:web_search");
        claims.iss = None;

        let token = issue(&claims, &config).unwrap();
        let decoded = validate(&token, &config).unwrap();
        assert_eq!(decoded.iss, None);
    }

    #[test]
    fn test_jti_preserved() {
        let config = hs256_config();
        let claims = make_claims("tool:web_search");
        let jti = claims.jti.clone();

        let token = issue(&claims, &config).unwrap();
        let decoded = validate(&token, &config).unwrap();
        assert_eq!(decoded.jti, jti);
    }

    #[test]
    fn test_ati_namespace_preserved() {
        let config = hs256_config();
        let claims = make_claims("tool:web_search");

        let token = issue(&claims, &config).unwrap();
        let decoded = validate(&token, &config).unwrap();
        assert!(decoded.ati.is_some());
        assert_eq!(decoded.ati.unwrap().v, 1);
    }

    #[test]
    fn test_multi_audience_accepts_first() {
        let config = hs256_config_multi(vec!["ati-proxy".into(), "parcha-tools".into()]);
        let mut claims = make_claims("tool:web_search");
        claims.aud = "ati-proxy".into();

        let token = issue(&claims, &config).unwrap();
        let decoded = validate(&token, &config).expect("aud=ati-proxy should pass");
        assert_eq!(decoded.aud, "ati-proxy");
    }

    #[test]
    fn test_multi_audience_accepts_second() {
        let config = hs256_config_multi(vec!["ati-proxy".into(), "parcha-tools".into()]);
        let mut claims = make_claims("tool:web_search");
        claims.aud = "parcha-tools".into();

        let token = issue(&claims, &config).unwrap();
        let decoded = validate(&token, &config).expect("aud=parcha-tools should pass");
        assert_eq!(decoded.aud, "parcha-tools");
    }

    #[test]
    fn test_multi_audience_rejects_out_of_list() {
        let config = hs256_config_multi(vec!["ati-proxy".into(), "parcha-tools".into()]);
        let mut claims = make_claims("tool:web_search");
        claims.aud = "evil-aud".into();

        let token = issue(&claims, &config).unwrap();
        let result = validate(&token, &config);
        assert!(result.is_err(), "aud not in allowlist must be rejected");
    }

    #[test]
    fn test_single_audience_back_compat() {
        let config = hs256_config_multi(vec!["ati-proxy".into()]);
        let mut claims = make_claims("tool:web_search");

        claims.aud = "ati-proxy".into();
        let token = issue(&claims, &config).unwrap();
        assert!(validate(&token, &config).is_ok());

        claims.aud = "wrong".into();
        let token = issue(&claims, &config).unwrap();
        assert!(validate(&token, &config).is_err());
    }

    #[test]
    fn test_empty_audiences_vec_rejected_not_bypassed() {
        let config = hs256_config_multi(vec![]);
        let mut claims = make_claims("tool:web_search");
        claims.aud = "ati-proxy".into();

        // Sign with a config that has an audience, so the only reason to fail is the
        // empty allowlist on the validating side.
        let issuer = hs256_config_multi(vec!["ati-proxy".into()]);
        let token = issue(&claims, &issuer).unwrap();

        let err = validate(&token, &config).expect_err("empty allowlist must reject");
        match err {
            JwtError::InvalidKey(msg) => assert!(
                msg.contains("accepted_audiences"),
                "error must mention accepted_audiences; got: {msg}"
            ),
            other => panic!("expected InvalidKey, got {other:?}"),
        }
    }

    // Environment variables are process-global and tests run in parallel, so every test
    // that touches them serializes on this lock.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `ENV_LOCK` and recovers from poisoning. Without this, a single failing
    /// env test would make every later env test panic on `PoisonError`.
    fn env_lock() -> MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Sets or removes env vars and restores their previous values when dropped.
    struct EnvGuard {
        prev: Vec<(&'static str, Option<String>)>,
    }

    impl EnvGuard {
        fn set(pairs: &[(&'static str, Option<&str>)]) -> Self {
            let mut prev = Vec::new();
            for (k, v) in pairs {
                prev.push((*k, std::env::var(k).ok()));
                match v {
                    Some(val) => std::env::set_var(k, val),
                    None => std::env::remove_var(k),
                }
            }
            Self { prev }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            for (k, v) in &self.prev {
                match v {
                    Some(val) => std::env::set_var(k, val),
                    None => std::env::remove_var(k),
                }
            }
        }
    }

    // In the env tests below, `_e` is declared after `_g`, so it drops first. The
    // variables are therefore restored before the lock is released.

    #[test]
    fn test_parse_audiences_env_csv_wins() {
        let _g = env_lock();
        let _e = EnvGuard::set(&[
            ("ATI_JWT_ACCEPTED_AUDIENCES", Some("a, b ,c")),
            ("ATI_JWT_AUDIENCE", Some("ignored-singular")),
        ]);
        assert_eq!(parse_audiences_env(), vec!["a", "b", "c"]);
    }

    #[test]
    fn test_parse_audiences_env_falls_back_to_singular() {
        let _g = env_lock();
        let _e = EnvGuard::set(&[
            ("ATI_JWT_ACCEPTED_AUDIENCES", None),
            ("ATI_JWT_AUDIENCE", Some("custom-aud")),
        ]);
        assert_eq!(parse_audiences_env(), vec!["custom-aud"]);
    }

    #[test]
    fn test_parse_audiences_env_default_is_ati_proxy() {
        let _g = env_lock();
        let _e = EnvGuard::set(&[
            ("ATI_JWT_ACCEPTED_AUDIENCES", None),
            ("ATI_JWT_AUDIENCE", None),
        ]);
        assert_eq!(parse_audiences_env(), vec!["ati-proxy"]);
    }

    #[test]
    fn test_parse_audiences_env_csv_all_empty_falls_back() {
        let _g = env_lock();
        let _e = EnvGuard::set(&[
            ("ATI_JWT_ACCEPTED_AUDIENCES", Some(" , , ")),
            ("ATI_JWT_AUDIENCE", Some("fallback-aud")),
        ]);
        assert_eq!(parse_audiences_env(), vec!["fallback-aud"]);
    }
}