1use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use thiserror::Error;
13
14#[derive(Error, Debug)]
15pub enum JwtError {
16 #[error("JWT encoding failed: {0}")]
17 Encode(#[from] jsonwebtoken::errors::Error),
18 #[error("Invalid PEM key: {0}")]
19 InvalidKey(String),
20 #[error("No encoding key configured (private key required for issuance)")]
21 NoEncodingKey,
22 #[error("No decoding key configured (public key required for validation)")]
23 NoDecodingKey,
24 #[error("Base64 decode error: {0}")]
25 Base64(String),
26}
27
28#[derive(Clone)]
30pub struct JwtConfig {
31 pub decoding_key: DecodingKey,
33 pub encoding_key: Option<EncodingKey>,
35 pub algorithm: Algorithm,
37 pub required_issuer: Option<String>,
39 pub required_audience: String,
41 pub leeway_secs: u64,
43 pub public_key_pem: Option<Vec<u8>>,
45}
46
47impl std::fmt::Debug for JwtConfig {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 f.debug_struct("JwtConfig")
50 .field("algorithm", &self.algorithm)
51 .field("required_issuer", &self.required_issuer)
52 .field("required_audience", &self.required_audience)
53 .field("leeway_secs", &self.leeway_secs)
54 .field("has_encoding_key", &self.encoding_key.is_some())
55 .finish()
56 }
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct AtiNamespace {
62 pub v: u8,
64 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
66 pub rate: HashMap<String, String>,
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct TokenClaims {
72 #[serde(skip_serializing_if = "Option::is_none")]
74 pub iss: Option<String>,
75 pub sub: String,
77 pub aud: String,
79 pub iat: u64,
81 pub exp: u64,
83 #[serde(skip_serializing_if = "Option::is_none")]
85 pub jti: Option<String>,
86 pub scope: String,
88 #[serde(skip_serializing_if = "Option::is_none")]
90 pub ati: Option<AtiNamespace>,
91}
92
93impl TokenClaims {
94 pub fn scopes(&self) -> Vec<String> {
96 self.scope.split_whitespace().map(String::from).collect()
97 }
98}
99
100pub fn validate(token: &str, config: &JwtConfig) -> Result<TokenClaims, JwtError> {
102 let mut validation = Validation::new(config.algorithm);
103 validation.set_audience(&[&config.required_audience]);
104 validation.leeway = config.leeway_secs;
105
106 if let Some(ref issuer) = config.required_issuer {
107 validation.set_issuer(&[issuer]);
108 } else {
109 validation.set_required_spec_claims(&["exp", "sub", "aud"]);
111 }
112
113 let token_data: TokenData<TokenClaims> =
114 jsonwebtoken::decode(token, &config.decoding_key, &validation)?;
115
116 Ok(token_data.claims)
117}
118
119pub fn issue(claims: &TokenClaims, config: &JwtConfig) -> Result<String, JwtError> {
121 let encoding_key = config
122 .encoding_key
123 .as_ref()
124 .ok_or(JwtError::NoEncodingKey)?;
125
126 let header = Header::new(config.algorithm);
127 let token = jsonwebtoken::encode(&header, claims, encoding_key)?;
128 Ok(token)
129}
130
131pub fn inspect(token: &str) -> Result<TokenClaims, JwtError> {
133 let mut validation = Validation::default();
134 validation.insecure_disable_signature_validation();
135 validation.validate_aud = false;
136 validation.validate_exp = false;
137 validation.set_required_spec_claims::<&str>(&[]);
138
139 let key = DecodingKey::from_secret(b"unused");
141 let token_data: TokenData<TokenClaims> = jsonwebtoken::decode(token, &key, &validation)?;
142
143 Ok(token_data.claims)
144}
145
146pub fn load_public_key_pem(pem: &[u8], alg: Algorithm) -> Result<DecodingKey, JwtError> {
148 match alg {
149 Algorithm::ES256 | Algorithm::ES384 => {
150 DecodingKey::from_ec_pem(pem).map_err(|e| JwtError::InvalidKey(e.to_string()))
151 }
152 Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => {
153 DecodingKey::from_rsa_pem(pem).map_err(|e| JwtError::InvalidKey(e.to_string()))
154 }
155 _ => Err(JwtError::InvalidKey(format!(
156 "Unsupported algorithm for PEM: {alg:?}"
157 ))),
158 }
159}
160
161pub fn load_private_key_pem(pem: &[u8], alg: Algorithm) -> Result<EncodingKey, JwtError> {
163 match alg {
164 Algorithm::ES256 | Algorithm::ES384 => {
165 EncodingKey::from_ec_pem(pem).map_err(|e| JwtError::InvalidKey(e.to_string()))
166 }
167 Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => {
168 EncodingKey::from_rsa_pem(pem).map_err(|e| JwtError::InvalidKey(e.to_string()))
169 }
170 _ => Err(JwtError::InvalidKey(format!(
171 "Unsupported algorithm for PEM: {alg:?}"
172 ))),
173 }
174}
175
176pub fn config_from_secret(secret: &[u8], issuer: Option<String>, audience: String) -> JwtConfig {
178 JwtConfig {
179 decoding_key: DecodingKey::from_secret(secret),
180 encoding_key: Some(EncodingKey::from_secret(secret)),
181 algorithm: Algorithm::HS256,
182 required_issuer: issuer,
183 required_audience: audience,
184 leeway_secs: 60,
185 public_key_pem: None,
186 }
187}
188
189pub fn config_from_pem(
191 public_pem: &[u8],
192 private_pem: Option<&[u8]>,
193 alg: Algorithm,
194 issuer: Option<String>,
195 audience: String,
196) -> Result<JwtConfig, JwtError> {
197 let decoding_key = load_public_key_pem(public_pem, alg)?;
198 let encoding_key = match private_pem {
199 Some(pem) => Some(load_private_key_pem(pem, alg)?),
200 None => None,
201 };
202
203 Ok(JwtConfig {
204 decoding_key,
205 encoding_key,
206 algorithm: alg,
207 required_issuer: issuer,
208 required_audience: audience,
209 leeway_secs: 60,
210 public_key_pem: Some(public_pem.to_vec()),
211 })
212}
213
214pub fn public_key_to_jwks(
217 pem: &[u8],
218 alg: Algorithm,
219 kid: &str,
220) -> Result<serde_json::Value, JwtError> {
221 let pem_str = std::str::from_utf8(pem).map_err(|e| JwtError::InvalidKey(e.to_string()))?;
223
224 let key_type = match alg {
226 Algorithm::ES256 | Algorithm::ES384 => "EC",
227 Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => "RSA",
228 _ => {
229 return Err(JwtError::InvalidKey(
230 "Unsupported algorithm for JWKS".into(),
231 ))
232 }
233 };
234
235 let alg_str = match alg {
236 Algorithm::ES256 => "ES256",
237 Algorithm::ES384 => "ES384",
238 Algorithm::RS256 => "RS256",
239 Algorithm::RS384 => "RS384",
240 Algorithm::RS512 => "RS512",
241 _ => "unknown",
242 };
243
244 let der_b64: String = pem_str
247 .lines()
248 .filter(|line| !line.starts_with("-----"))
249 .collect::<Vec<_>>()
250 .join("");
251
252 let jwk = serde_json::json!({
253 "kty": key_type,
254 "use": "sig",
255 "alg": alg_str,
256 "kid": kid,
257 "x5c": [der_b64],
258 });
259
260 Ok(serde_json::json!({
261 "keys": [jwk]
262 }))
263}
264
265pub fn config_from_env() -> Result<Option<JwtConfig>, JwtError> {
272 let issuer = std::env::var("ATI_JWT_ISSUER").ok();
273 let audience = std::env::var("ATI_JWT_AUDIENCE").unwrap_or_else(|_| "ati-proxy".to_string());
274
275 if let Ok(pub_key_path) = std::env::var("ATI_JWT_PUBLIC_KEY") {
277 let public_pem = std::fs::read(&pub_key_path)
278 .map_err(|e| JwtError::InvalidKey(format!("Cannot read {pub_key_path}: {e}")))?;
279
280 let private_pem = std::env::var("ATI_JWT_PRIVATE_KEY")
281 .ok()
282 .and_then(|path| std::fs::read(&path).ok());
283
284 let mut config = config_from_pem(
285 &public_pem,
286 private_pem.as_deref(),
287 Algorithm::ES256,
288 issuer,
289 audience,
290 )?;
291
292 config.public_key_pem = Some(public_pem);
294
295 return Ok(Some(config));
296 }
297
298 if let Ok(secret_hex) = std::env::var("ATI_JWT_SECRET") {
300 let secret_bytes = hex::decode(&secret_hex)
301 .map_err(|e| JwtError::InvalidKey(format!("ATI_JWT_SECRET is not valid hex: {e}")))?;
302
303 return Ok(Some(config_from_secret(&secret_bytes, issuer, audience)));
304 }
305
306 Ok(None)
307}
308
/// Returns the current system time as whole seconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
pub fn now_secs() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared HS256 secret used by the default test configurations.
    const SECRET: &[u8] = b"test-secret-key-32-bytes-long!!!";

    /// HS256 configuration with no issuer requirement.
    fn hs256_config() -> JwtConfig {
        config_from_secret(SECRET, None, "ati-proxy".into())
    }

    /// HS256 configuration that requires the `ati-orchestrator` issuer.
    fn hs256_config_with_issuer() -> JwtConfig {
        config_from_secret(SECRET, Some("ati-orchestrator".into()), "ati-proxy".into())
    }

    /// Claims valid for 30 minutes from now, carrying the given scope string.
    fn make_claims(scope: &str) -> TokenClaims {
        let issued_at = now_secs();
        let namespace = AtiNamespace {
            v: 1,
            rate: HashMap::new(),
        };

        TokenClaims {
            iss: Some("ati-orchestrator".into()),
            sub: "agent-7".into(),
            aud: "ati-proxy".into(),
            iat: issued_at,
            exp: issued_at + 1800,
            jti: Some(uuid::Uuid::new_v4().to_string()),
            scope: scope.into(),
            ati: Some(namespace),
        }
    }

    #[test]
    fn test_hs256_round_trip() {
        let cfg = hs256_config();
        let original = make_claims("tool:web_search tool:github:*");

        let jwt = issue(&original, &cfg).unwrap();
        let round_tripped = validate(&jwt, &cfg).unwrap();

        assert_eq!(round_tripped.sub, "agent-7");
        assert_eq!(round_tripped.aud, "ati-proxy");
        assert_eq!(round_tripped.scope, "tool:web_search tool:github:*");
        assert_eq!(
            round_tripped.scopes(),
            vec!["tool:web_search", "tool:github:*"]
        );
        assert_eq!(round_tripped.iss, Some("ati-orchestrator".into()));
    }

    #[test]
    fn test_expired_token_rejected() {
        let cfg = hs256_config();
        let mut stale = make_claims("tool:web_search");
        // One second after the epoch: far outside any leeway.
        stale.exp = 1;

        let jwt = issue(&stale, &cfg).unwrap();
        assert!(validate(&jwt, &cfg).is_err());
    }

    #[test]
    fn test_wrong_secret_rejected() {
        let signer = hs256_config();
        let verifier =
            config_from_secret(b"different-secret-key-32-bytes!!", None, "ati-proxy".into());

        let jwt = issue(&make_claims("tool:web_search"), &signer).unwrap();
        assert!(validate(&jwt, &verifier).is_err());
    }

    #[test]
    fn test_wrong_audience_rejected() {
        let cfg = hs256_config();
        let mut misaddressed = make_claims("tool:web_search");
        misaddressed.aud = "wrong-audience".into();

        let jwt = issue(&misaddressed, &cfg).unwrap();
        assert!(validate(&jwt, &cfg).is_err());
    }

    #[test]
    fn test_wrong_issuer_rejected() {
        let cfg = hs256_config_with_issuer();
        let mut forged = make_claims("tool:web_search");
        forged.iss = Some("evil-orchestrator".into());

        let jwt = issue(&forged, &cfg).unwrap();
        assert!(validate(&jwt, &cfg).is_err());
    }

    #[test]
    fn test_tampered_payload_rejected() {
        let cfg = hs256_config();
        let jwt = issue(&make_claims("tool:web_search"), &cfg).unwrap();

        let segments: Vec<&str> = jwt.split('.').collect();
        assert_eq!(segments.len(), 3);

        // Append one extra character to the payload segment, choosing one
        // that differs from the segment's current last character.
        let extra = if segments[1].ends_with('A') { 'B' } else { 'A' };
        let forged = format!("{}.{}{}.{}", segments[0], segments[1], extra, segments[2]);

        assert!(validate(&forged, &cfg).is_err());
    }

    #[test]
    fn test_malformed_token_rejected() {
        let cfg = hs256_config();

        assert!(validate("not.a.jwt.token.at.all", &cfg).is_err());
        assert!(validate("", &cfg).is_err());
        assert!(validate("just-a-string", &cfg).is_err());
    }

    #[test]
    fn test_inspect_decodes_without_key() {
        let cfg = hs256_config();
        let jwt = issue(&make_claims("tool:web_search skill:research-*"), &cfg).unwrap();

        let peeked = inspect(&jwt).unwrap();
        assert_eq!(peeked.sub, "agent-7");
        assert_eq!(peeked.scope, "tool:web_search skill:research-*");
    }

    #[test]
    fn test_scope_parsing() {
        let parsed = make_claims("tool:web_search tool:github:* skill:research-* help").scopes();
        assert_eq!(
            parsed,
            vec![
                "tool:web_search",
                "tool:github:*",
                "skill:research-*",
                "help"
            ]
        );
    }

    #[test]
    fn test_empty_scope() {
        assert!(make_claims("").scopes().is_empty());
    }

    #[test]
    fn test_single_scope() {
        assert_eq!(make_claims("*").scopes(), vec!["*"]);
    }

    #[test]
    fn test_no_encoding_key_fails() {
        let validate_only = JwtConfig {
            decoding_key: DecodingKey::from_secret(b"test"),
            encoding_key: None,
            algorithm: Algorithm::HS256,
            required_issuer: None,
            required_audience: "ati-proxy".into(),
            leeway_secs: 60,
            public_key_pem: None,
        };

        assert!(issue(&make_claims("tool:web_search"), &validate_only).is_err());
    }

    #[test]
    fn test_issuer_not_required_when_none() {
        // The configuration has no required issuer, so a token without
        // `iss` must still validate.
        let cfg = hs256_config();
        let mut anonymous = make_claims("tool:web_search");
        anonymous.iss = None;

        let jwt = issue(&anonymous, &cfg).unwrap();
        let round_tripped = validate(&jwt, &cfg).unwrap();
        assert_eq!(round_tripped.iss, None);
    }

    #[test]
    fn test_jti_preserved() {
        let cfg = hs256_config();
        let original = make_claims("tool:web_search");
        let expected_jti = original.jti.clone();

        let jwt = issue(&original, &cfg).unwrap();
        let round_tripped = validate(&jwt, &cfg).unwrap();
        assert_eq!(round_tripped.jti, expected_jti);
    }

    #[test]
    fn test_ati_namespace_preserved() {
        let cfg = hs256_config();
        let jwt = issue(&make_claims("tool:web_search"), &cfg).unwrap();

        let round_tripped = validate(&jwt, &cfg).unwrap();
        assert_eq!(round_tripped.ati.map(|ns| ns.v), Some(1));
    }
}