1use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use thiserror::Error;
13
14#[derive(Error, Debug)]
15pub enum JwtError {
16 #[error("JWT encoding failed: {0}")]
17 Encode(#[from] jsonwebtoken::errors::Error),
18 #[error("Invalid PEM key: {0}")]
19 InvalidKey(String),
20 #[error("No encoding key configured (private key required for issuance)")]
21 NoEncodingKey,
22 #[error("No decoding key configured (public key required for validation)")]
23 NoDecodingKey,
24 #[error("Base64 decode error: {0}")]
25 Base64(String),
26}
27
/// Key material and policy used to validate (and optionally issue) JWTs.
///
/// Built by [`config_from_secret`] (HS256), [`config_from_pem`] (asymmetric),
/// or [`config_from_env`].
#[derive(Clone)]
pub struct JwtConfig {
    /// Key used by [`validate`] to verify token signatures.
    pub decoding_key: DecodingKey,
    /// Key used by [`issue`] to sign tokens; `None` yields a validate-only config
    /// (and [`issue`] then fails with [`JwtError::NoEncodingKey`]).
    pub encoding_key: Option<EncodingKey>,
    /// Algorithm written into issued token headers and the only algorithm
    /// accepted by [`validate`].
    pub algorithm: Algorithm,
    /// When `Some`, [`validate`] checks the token's `iss` claim against this value.
    pub required_issuer: Option<String>,
    /// Audience that [`validate`] checks the token's `aud` claim against.
    pub required_audience: String,
    /// Leeway in seconds handed to the validator (clock-skew tolerance for
    /// time-based claims such as `exp`).
    pub leeway_secs: u64,
    /// Raw public-key PEM bytes when the config was built from PEM; `None` for
    /// HMAC configs. NOTE(review): presumably retained so the key can be
    /// published (e.g. via `public_key_to_jwks`) — confirm with callers.
    pub public_key_pem: Option<Vec<u8>>,
}
46
47impl std::fmt::Debug for JwtConfig {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 f.debug_struct("JwtConfig")
50 .field("algorithm", &self.algorithm)
51 .field("required_issuer", &self.required_issuer)
52 .field("required_audience", &self.required_audience)
53 .field("leeway_secs", &self.leeway_secs)
54 .field("has_encoding_key", &self.encoding_key.is_some())
55 .finish()
56 }
57}
58
/// ATI-specific claims carried in [`TokenClaims::ati`] (the `ati` claim).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AtiNamespace {
    /// Namespace version. NOTE(review): presumably a schema version (tests
    /// use `1`) — confirm.
    pub v: u8,
    /// String-to-string map; treated as empty when absent on decode and
    /// omitted from the token when empty. NOTE(review): key/value semantics
    /// (likely rate-limit settings, per the name) are not visible here.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub rate: HashMap<String, String>,
}
68
/// Claims payload of tokens issued by [`issue`] and returned by
/// [`validate`] / [`inspect`].
///
/// Field declaration order is the JSON serialization order of issued tokens.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenClaims {
    /// Issuer; omitted from the token when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub iss: Option<String>,
    /// Subject (the principal the token was issued to).
    pub sub: String,
    /// Audience. NOTE(review): RFC 7519 also allows `aud` to be a JSON array;
    /// such tokens fail to deserialize into this `String` field.
    pub aud: String,
    /// Issued-at time as a JWT NumericDate (seconds since the Unix epoch,
    /// as produced by [`now_secs`]).
    pub iat: u64,
    /// Expiry time as a JWT NumericDate (seconds since the Unix epoch).
    pub exp: u64,
    /// Optional token id; omitted from the token when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jti: Option<String>,
    /// Whitespace-separated scope list; see [`TokenClaims::scopes`].
    pub scope: String,
    /// Optional ATI namespace claim; omitted from the token when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ati: Option<AtiNamespace>,
    /// Optional job id; `None` when absent and omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub job_id: Option<String>,
    /// Optional sandbox id; `None` when absent and omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub sandbox_id: Option<String>,
}
98
99impl TokenClaims {
100 pub fn scopes(&self) -> Vec<String> {
102 self.scope.split_whitespace().map(String::from).collect()
103 }
104}
105
106pub fn validate(token: &str, config: &JwtConfig) -> Result<TokenClaims, JwtError> {
108 let mut validation = Validation::new(config.algorithm);
109 validation.set_audience(&[&config.required_audience]);
110 validation.leeway = config.leeway_secs;
111
112 if let Some(ref issuer) = config.required_issuer {
113 validation.set_issuer(&[issuer]);
114 } else {
115 validation.set_required_spec_claims(&["exp", "sub", "aud"]);
117 }
118
119 let token_data: TokenData<TokenClaims> =
120 jsonwebtoken::decode(token, &config.decoding_key, &validation)?;
121
122 Ok(token_data.claims)
123}
124
125pub fn issue(claims: &TokenClaims, config: &JwtConfig) -> Result<String, JwtError> {
127 let encoding_key = config
128 .encoding_key
129 .as_ref()
130 .ok_or(JwtError::NoEncodingKey)?;
131
132 let header = Header::new(config.algorithm);
133 let token = jsonwebtoken::encode(&header, claims, encoding_key)?;
134 Ok(token)
135}
136
137pub fn inspect(token: &str) -> Result<TokenClaims, JwtError> {
139 let mut validation = Validation::default();
140 validation.insecure_disable_signature_validation();
141 validation.validate_aud = false;
142 validation.validate_exp = false;
143 validation.set_required_spec_claims::<&str>(&[]);
144
145 let key = DecodingKey::from_secret(b"unused");
147 let token_data: TokenData<TokenClaims> = jsonwebtoken::decode(token, &key, &validation)?;
148
149 Ok(token_data.claims)
150}
151
152pub fn load_public_key_pem(pem: &[u8], alg: Algorithm) -> Result<DecodingKey, JwtError> {
154 match alg {
155 Algorithm::ES256 | Algorithm::ES384 => {
156 DecodingKey::from_ec_pem(pem).map_err(|e| JwtError::InvalidKey(e.to_string()))
157 }
158 Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => {
159 DecodingKey::from_rsa_pem(pem).map_err(|e| JwtError::InvalidKey(e.to_string()))
160 }
161 _ => Err(JwtError::InvalidKey(format!(
162 "Unsupported algorithm for PEM: {alg:?}"
163 ))),
164 }
165}
166
167pub fn load_private_key_pem(pem: &[u8], alg: Algorithm) -> Result<EncodingKey, JwtError> {
169 match alg {
170 Algorithm::ES256 | Algorithm::ES384 => {
171 EncodingKey::from_ec_pem(pem).map_err(|e| JwtError::InvalidKey(e.to_string()))
172 }
173 Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => {
174 EncodingKey::from_rsa_pem(pem).map_err(|e| JwtError::InvalidKey(e.to_string()))
175 }
176 _ => Err(JwtError::InvalidKey(format!(
177 "Unsupported algorithm for PEM: {alg:?}"
178 ))),
179 }
180}
181
182pub fn config_from_secret(secret: &[u8], issuer: Option<String>, audience: String) -> JwtConfig {
184 JwtConfig {
185 decoding_key: DecodingKey::from_secret(secret),
186 encoding_key: Some(EncodingKey::from_secret(secret)),
187 algorithm: Algorithm::HS256,
188 required_issuer: issuer,
189 required_audience: audience,
190 leeway_secs: 60,
191 public_key_pem: None,
192 }
193}
194
195pub fn config_from_pem(
197 public_pem: &[u8],
198 private_pem: Option<&[u8]>,
199 alg: Algorithm,
200 issuer: Option<String>,
201 audience: String,
202) -> Result<JwtConfig, JwtError> {
203 let decoding_key = load_public_key_pem(public_pem, alg)?;
204 let encoding_key = match private_pem {
205 Some(pem) => Some(load_private_key_pem(pem, alg)?),
206 None => None,
207 };
208
209 Ok(JwtConfig {
210 decoding_key,
211 encoding_key,
212 algorithm: alg,
213 required_issuer: issuer,
214 required_audience: audience,
215 leeway_secs: 60,
216 public_key_pem: Some(public_pem.to_vec()),
217 })
218}
219
220pub fn public_key_to_jwks(
223 pem: &[u8],
224 alg: Algorithm,
225 kid: &str,
226) -> Result<serde_json::Value, JwtError> {
227 let pem_str = std::str::from_utf8(pem).map_err(|e| JwtError::InvalidKey(e.to_string()))?;
229
230 let key_type = match alg {
232 Algorithm::ES256 | Algorithm::ES384 => "EC",
233 Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => "RSA",
234 _ => {
235 return Err(JwtError::InvalidKey(
236 "Unsupported algorithm for JWKS".into(),
237 ))
238 }
239 };
240
241 let alg_str = match alg {
242 Algorithm::ES256 => "ES256",
243 Algorithm::ES384 => "ES384",
244 Algorithm::RS256 => "RS256",
245 Algorithm::RS384 => "RS384",
246 Algorithm::RS512 => "RS512",
247 _ => "unknown",
248 };
249
250 let der_b64: String = pem_str
253 .lines()
254 .filter(|line| !line.starts_with("-----"))
255 .collect::<Vec<_>>()
256 .join("");
257
258 let jwk = serde_json::json!({
259 "kty": key_type,
260 "use": "sig",
261 "alg": alg_str,
262 "kid": kid,
263 "x5c": [der_b64],
264 });
265
266 Ok(serde_json::json!({
267 "keys": [jwk]
268 }))
269}
270
271pub fn config_from_env() -> Result<Option<JwtConfig>, JwtError> {
278 let issuer = std::env::var("ATI_JWT_ISSUER").ok();
279 let audience = std::env::var("ATI_JWT_AUDIENCE").unwrap_or_else(|_| "ati-proxy".to_string());
280
281 if let Ok(pub_key_path) = std::env::var("ATI_JWT_PUBLIC_KEY") {
283 let public_pem = std::fs::read(&pub_key_path)
284 .map_err(|e| JwtError::InvalidKey(format!("Cannot read {pub_key_path}: {e}")))?;
285
286 let private_pem = std::env::var("ATI_JWT_PRIVATE_KEY")
287 .ok()
288 .and_then(|path| std::fs::read(&path).ok());
289
290 let mut config = config_from_pem(
291 &public_pem,
292 private_pem.as_deref(),
293 Algorithm::ES256,
294 issuer,
295 audience,
296 )?;
297
298 config.public_key_pem = Some(public_pem);
300
301 return Ok(Some(config));
302 }
303
304 if let Ok(secret_hex) = std::env::var("ATI_JWT_SECRET") {
306 let secret_bytes = hex::decode(&secret_hex)
307 .map_err(|e| JwtError::InvalidKey(format!("ATI_JWT_SECRET is not valid hex: {e}")))?;
308
309 return Ok(Some(config_from_secret(&secret_bytes, issuer, audience)));
310 }
311
312 Ok(None)
313}
314
/// Returns the current wall-clock time in whole seconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
pub fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
322
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::errors::ErrorKind;

    /// Shared HS256 secret used by the symmetric test configs.
    const TEST_SECRET: &[u8] = b"test-secret-key-32-bytes-long!!!";

    /// HS256 config with no required issuer and audience `ati-proxy`.
    fn hs256_config() -> JwtConfig {
        config_from_secret(TEST_SECRET, None, "ati-proxy".into())
    }

    /// HS256 config that also requires issuer `ati-orchestrator`.
    fn hs256_config_with_issuer() -> JwtConfig {
        config_from_secret(
            TEST_SECRET,
            Some("ati-orchestrator".into()),
            "ati-proxy".into(),
        )
    }

    /// Unexpired claims (30-minute lifetime) for `agent-7` with the given scope.
    fn make_claims(scope: &str) -> TokenClaims {
        let now = now_secs();
        TokenClaims {
            iss: Some("ati-orchestrator".into()),
            sub: "agent-7".into(),
            aud: "ati-proxy".into(),
            iat: now,
            exp: now + 1800,
            jti: Some(uuid::Uuid::new_v4().to_string()),
            scope: scope.into(),
            ati: Some(AtiNamespace {
                v: 1,
                rate: HashMap::new(),
            }),
            job_id: None,
            sandbox_id: None,
        }
    }

    #[test]
    fn test_hs256_round_trip() {
        let config = hs256_config();
        let claims = make_claims("tool:web_search tool:github:*");

        let token = issue(&claims, &config).unwrap();
        let decoded = validate(&token, &config).unwrap();

        assert_eq!(decoded.sub, "agent-7");
        assert_eq!(decoded.aud, "ati-proxy");
        assert_eq!(decoded.scope, "tool:web_search tool:github:*");
        assert_eq!(decoded.scopes(), vec!["tool:web_search", "tool:github:*"]);
        assert_eq!(decoded.iss, Some("ati-orchestrator".into()));
    }

    #[test]
    fn test_expired_token_rejected() {
        let config = hs256_config();
        let mut claims = make_claims("tool:web_search");
        claims.exp = 1;

        let token = issue(&claims, &config).unwrap();
        let result = validate(&token, &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_wrong_secret_rejected() {
        let config1 = hs256_config();
        let config2 =
            config_from_secret(b"different-secret-key-32-bytes!!", None, "ati-proxy".into());

        let claims = make_claims("tool:web_search");
        let token = issue(&claims, &config1).unwrap();
        let result = validate(&token, &config2);
        assert!(result.is_err());
    }

    #[test]
    fn test_wrong_audience_rejected() {
        let config = hs256_config();
        let mut claims = make_claims("tool:web_search");
        claims.aud = "wrong-audience".into();

        let token = issue(&claims, &config).unwrap();
        let result = validate(&token, &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_wrong_issuer_rejected() {
        let config = hs256_config_with_issuer();
        let mut claims = make_claims("tool:web_search");
        claims.iss = Some("evil-orchestrator".into());

        let token = issue(&claims, &config).unwrap();
        let result = validate(&token, &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_tampered_payload_rejected() {
        // Splice the payload of a broader-scoped token onto the header and
        // signature of a narrower one. Every segment is still valid
        // base64/JSON, so only signature verification can reject it. Appending
        // junk characters instead may fail earlier, at base64 decoding.
        let config = hs256_config();
        let narrow = issue(&make_claims("tool:web_search"), &config).unwrap();
        let broad = issue(&make_claims("*"), &config).unwrap();

        let narrow_parts: Vec<&str> = narrow.split('.').collect();
        let broad_parts: Vec<&str> = broad.split('.').collect();
        assert_eq!(narrow_parts.len(), 3);
        assert_eq!(broad_parts.len(), 3);

        let tampered = format!(
            "{}.{}.{}",
            narrow_parts[0], broad_parts[1], narrow_parts[2]
        );

        let result = validate(&tampered, &config);
        assert!(
            matches!(
                result,
                Err(JwtError::Encode(ref e)) if matches!(e.kind(), ErrorKind::InvalidSignature)
            ),
            "expected InvalidSignature, got {result:?}"
        );
    }

    #[test]
    fn test_malformed_token_rejected() {
        let config = hs256_config();
        let result = validate("not.a.jwt.token.at.all", &config);
        assert!(result.is_err());

        let result = validate("", &config);
        assert!(result.is_err());

        let result = validate("just-a-string", &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_inspect_decodes_without_key() {
        let config = hs256_config();
        let claims = make_claims("tool:web_search skill:research-*");
        let token = issue(&claims, &config).unwrap();

        let decoded = inspect(&token).unwrap();
        assert_eq!(decoded.sub, "agent-7");
        assert_eq!(decoded.scope, "tool:web_search skill:research-*");
    }

    #[test]
    fn test_inspect_ignores_expiry() {
        let mut claims = make_claims("tool:web_search");
        claims.exp = 1;
        let token = issue(&claims, &hs256_config()).unwrap();

        let decoded = inspect(&token).unwrap();
        assert_eq!(decoded.exp, 1);
    }

    #[test]
    fn test_scope_parsing() {
        let claims = make_claims("tool:web_search tool:github:* skill:research-* help");
        let scopes = claims.scopes();
        assert_eq!(
            scopes,
            vec![
                "tool:web_search",
                "tool:github:*",
                "skill:research-*",
                "help"
            ]
        );
    }

    #[test]
    fn test_empty_scope() {
        let claims = make_claims("");
        assert!(claims.scopes().is_empty());
    }

    #[test]
    fn test_single_scope() {
        let claims = make_claims("*");
        assert_eq!(claims.scopes(), vec!["*"]);
    }

    #[test]
    fn test_no_encoding_key_fails() {
        let config = JwtConfig {
            decoding_key: DecodingKey::from_secret(b"test"),
            encoding_key: None,
            algorithm: Algorithm::HS256,
            required_issuer: None,
            required_audience: "ati-proxy".into(),
            leeway_secs: 60,
            public_key_pem: None,
        };

        let claims = make_claims("tool:web_search");
        let result = issue(&claims, &config);
        assert!(matches!(result, Err(JwtError::NoEncodingKey)));
    }

    #[test]
    fn test_issuer_not_required_when_none() {
        let config = hs256_config();
        let mut claims = make_claims("tool:web_search");
        claims.iss = None;

        let token = issue(&claims, &config).unwrap();
        let decoded = validate(&token, &config).unwrap();
        assert_eq!(decoded.iss, None);
    }

    #[test]
    fn test_jti_preserved() {
        let config = hs256_config();
        let claims = make_claims("tool:web_search");
        let jti = claims.jti.clone();

        let token = issue(&claims, &config).unwrap();
        let decoded = validate(&token, &config).unwrap();
        assert_eq!(decoded.jti, jti);
    }

    #[test]
    fn test_ati_namespace_preserved() {
        let config = hs256_config();
        let claims = make_claims("tool:web_search");

        let token = issue(&claims, &config).unwrap();
        let decoded = validate(&token, &config).unwrap();
        assert!(decoded.ati.is_some());
        assert_eq!(decoded.ati.unwrap().v, 1);
    }

    #[test]
    fn test_unsupported_pem_algorithm_rejected() {
        assert!(matches!(
            load_public_key_pem(b"not a key", Algorithm::HS256),
            Err(JwtError::InvalidKey(_))
        ));
        assert!(matches!(
            load_private_key_pem(b"not a key", Algorithm::HS256),
            Err(JwtError::InvalidKey(_))
        ));
    }

    #[test]
    fn test_garbage_pem_rejected() {
        assert!(matches!(
            load_public_key_pem(b"garbage", Algorithm::ES256),
            Err(JwtError::InvalidKey(_))
        ));
        assert!(matches!(
            load_private_key_pem(b"garbage", Algorithm::RS256),
            Err(JwtError::InvalidKey(_))
        ));
    }

    #[test]
    fn test_jwks_rejects_symmetric_algorithm() {
        let result = public_key_to_jwks(b"", Algorithm::HS256, "kid-1");
        assert!(matches!(result, Err(JwtError::InvalidKey(_))));
    }

    #[test]
    fn test_jwks_extracts_pem_body() {
        let pem = b"-----BEGIN PUBLIC KEY-----\nAAAA\nBBBB\n-----END PUBLIC KEY-----\n";
        let jwks = public_key_to_jwks(pem, Algorithm::ES256, "kid-1").unwrap();

        let key = &jwks["keys"][0];
        assert_eq!(key["kty"], "EC");
        assert_eq!(key["use"], "sig");
        assert_eq!(key["alg"], "ES256");
        assert_eq!(key["kid"], "kid-1");
        assert_eq!(key["x5c"][0], "AAAABBBB");
    }
}