Skip to main content

ati/core/
jwt.rs

1//! JWT-based authentication for ATI.
2//!
3//! ES256-signed JWTs carry identity + scopes + expiry in a single tamper-proof
4//! credential. The orchestrator signs with a private key; the proxy validates
5//! with the corresponding public key (served via JWKS).
6//!
7//! Supports ES256 (recommended) and HS256 (simpler, for single-machine setups).
8
9use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use thiserror::Error;
13
14#[derive(Error, Debug)]
15pub enum JwtError {
16    #[error("JWT encoding failed: {0}")]
17    Encode(#[from] jsonwebtoken::errors::Error),
18    #[error("Invalid PEM key: {0}")]
19    InvalidKey(String),
20    #[error("No encoding key configured (private key required for issuance)")]
21    NoEncodingKey,
22    #[error("No decoding key configured (public key required for validation)")]
23    NoDecodingKey,
24    #[error("Base64 decode error: {0}")]
25    Base64(String),
26}
27
28/// Configuration for JWT validation and (optionally) issuance.
29#[derive(Clone)]
30pub struct JwtConfig {
31    /// Public key for validation.
32    pub decoding_key: DecodingKey,
33    /// Private key for issuance (only on orchestrator).
34    pub encoding_key: Option<EncodingKey>,
35    /// Signing algorithm (ES256 or HS256).
36    pub algorithm: Algorithm,
37    /// Expected `iss` claim (optional — skipped if None).
38    pub required_issuer: Option<String>,
39    /// Expected `aud` claim.
40    pub required_audience: String,
41    /// Clock skew tolerance in seconds.
42    pub leeway_secs: u64,
43    /// Raw public key PEM bytes (for JWKS endpoint).
44    pub public_key_pem: Option<Vec<u8>>,
45}
46
47impl std::fmt::Debug for JwtConfig {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        f.debug_struct("JwtConfig")
50            .field("algorithm", &self.algorithm)
51            .field("required_issuer", &self.required_issuer)
52            .field("required_audience", &self.required_audience)
53            .field("leeway_secs", &self.leeway_secs)
54            .field("has_encoding_key", &self.encoding_key.is_some())
55            .finish()
56    }
57}
58
59/// ATI-specific namespace in JWT claims.
60#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct AtiNamespace {
62    /// Claims schema version.
63    pub v: u8,
64    /// Per-tool-pattern rate limits (e.g. {"tool:github:*": "10/hour"}).
65    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
66    pub rate: HashMap<String, String>,
67}
68
69/// JWT claims per RFC 9068.
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct TokenClaims {
72    /// Issuer (who signed this token).
73    #[serde(skip_serializing_if = "Option::is_none")]
74    pub iss: Option<String>,
75    /// Subject (agent identity).
76    pub sub: String,
77    /// Audience (target service, e.g. "ati-proxy").
78    pub aud: String,
79    /// Issued-at timestamp (Unix seconds).
80    pub iat: u64,
81    /// Expiry timestamp (Unix seconds).
82    pub exp: u64,
83    /// Unique token ID (UUID) for replay detection.
84    #[serde(skip_serializing_if = "Option::is_none")]
85    pub jti: Option<String>,
86    /// Space-delimited scopes per RFC 9068 §2.2.3.
87    pub scope: String,
88    /// ATI-specific claims namespace.
89    #[serde(skip_serializing_if = "Option::is_none")]
90    pub ati: Option<AtiNamespace>,
91    /// Job identifier (set by orchestrator provisioner).
92    #[serde(default, skip_serializing_if = "Option::is_none")]
93    pub job_id: Option<String>,
94    /// Sandbox identifier (set by orchestrator provisioner).
95    #[serde(default, skip_serializing_if = "Option::is_none")]
96    pub sandbox_id: Option<String>,
97}
98
99impl TokenClaims {
100    /// Parse the space-delimited scope string into a Vec.
101    pub fn scopes(&self) -> Vec<String> {
102        self.scope.split_whitespace().map(String::from).collect()
103    }
104}
105
106/// Validate a JWT token string and return the claims.
107pub fn validate(token: &str, config: &JwtConfig) -> Result<TokenClaims, JwtError> {
108    let mut validation = Validation::new(config.algorithm);
109    validation.set_audience(&[&config.required_audience]);
110    validation.leeway = config.leeway_secs;
111
112    if let Some(ref issuer) = config.required_issuer {
113        validation.set_issuer(&[issuer]);
114    } else {
115        // Don't require issuer validation if not configured
116        validation.set_required_spec_claims(&["exp", "sub", "aud"]);
117    }
118
119    let token_data: TokenData<TokenClaims> =
120        jsonwebtoken::decode(token, &config.decoding_key, &validation)?;
121
122    Ok(token_data.claims)
123}
124
125/// Issue (sign) a JWT token from claims.
126pub fn issue(claims: &TokenClaims, config: &JwtConfig) -> Result<String, JwtError> {
127    let encoding_key = config
128        .encoding_key
129        .as_ref()
130        .ok_or(JwtError::NoEncodingKey)?;
131
132    let header = Header::new(config.algorithm);
133    let token = jsonwebtoken::encode(&header, claims, encoding_key)?;
134    Ok(token)
135}
136
137/// Decode a JWT without verifying the signature (for inspection only).
138pub fn inspect(token: &str) -> Result<TokenClaims, JwtError> {
139    let mut validation = Validation::default();
140    validation.insecure_disable_signature_validation();
141    validation.validate_aud = false;
142    validation.validate_exp = false;
143    validation.set_required_spec_claims::<&str>(&[]);
144
145    // Use a dummy key since we're not validating
146    let key = DecodingKey::from_secret(b"unused");
147    let token_data: TokenData<TokenClaims> = jsonwebtoken::decode(token, &key, &validation)?;
148
149    Ok(token_data.claims)
150}
151
152/// Load an ES256 or RS256 public key from PEM bytes.
153pub fn load_public_key_pem(pem: &[u8], alg: Algorithm) -> Result<DecodingKey, JwtError> {
154    match alg {
155        Algorithm::ES256 | Algorithm::ES384 => {
156            DecodingKey::from_ec_pem(pem).map_err(|e| JwtError::InvalidKey(e.to_string()))
157        }
158        Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => {
159            DecodingKey::from_rsa_pem(pem).map_err(|e| JwtError::InvalidKey(e.to_string()))
160        }
161        _ => Err(JwtError::InvalidKey(format!(
162            "Unsupported algorithm for PEM: {alg:?}"
163        ))),
164    }
165}
166
167/// Load an ES256 or RS256 private key from PEM bytes.
168pub fn load_private_key_pem(pem: &[u8], alg: Algorithm) -> Result<EncodingKey, JwtError> {
169    match alg {
170        Algorithm::ES256 | Algorithm::ES384 => {
171            EncodingKey::from_ec_pem(pem).map_err(|e| JwtError::InvalidKey(e.to_string()))
172        }
173        Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => {
174            EncodingKey::from_rsa_pem(pem).map_err(|e| JwtError::InvalidKey(e.to_string()))
175        }
176        _ => Err(JwtError::InvalidKey(format!(
177            "Unsupported algorithm for PEM: {alg:?}"
178        ))),
179    }
180}
181
182/// Create a JwtConfig from an HS256 shared secret.
183pub fn config_from_secret(secret: &[u8], issuer: Option<String>, audience: String) -> JwtConfig {
184    JwtConfig {
185        decoding_key: DecodingKey::from_secret(secret),
186        encoding_key: Some(EncodingKey::from_secret(secret)),
187        algorithm: Algorithm::HS256,
188        required_issuer: issuer,
189        required_audience: audience,
190        leeway_secs: 60,
191        public_key_pem: None,
192    }
193}
194
195/// Create a JwtConfig from PEM key files.
196pub fn config_from_pem(
197    public_pem: &[u8],
198    private_pem: Option<&[u8]>,
199    alg: Algorithm,
200    issuer: Option<String>,
201    audience: String,
202) -> Result<JwtConfig, JwtError> {
203    let decoding_key = load_public_key_pem(public_pem, alg)?;
204    let encoding_key = match private_pem {
205        Some(pem) => Some(load_private_key_pem(pem, alg)?),
206        None => None,
207    };
208
209    Ok(JwtConfig {
210        decoding_key,
211        encoding_key,
212        algorithm: alg,
213        required_issuer: issuer,
214        required_audience: audience,
215        leeway_secs: 60,
216        public_key_pem: Some(public_pem.to_vec()),
217    })
218}
219
220/// Generate a JWKS JSON object from a public key PEM.
221/// Returns the JWKS `keys` array suitable for `/.well-known/jwks.json`.
222pub fn public_key_to_jwks(
223    pem: &[u8],
224    alg: Algorithm,
225    kid: &str,
226) -> Result<serde_json::Value, JwtError> {
227    // Parse the PEM to extract the raw key bytes
228    let pem_str = std::str::from_utf8(pem).map_err(|e| JwtError::InvalidKey(e.to_string()))?;
229
230    // Extract base64 content between PEM headers
231    let key_type = match alg {
232        Algorithm::ES256 | Algorithm::ES384 => "EC",
233        Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => "RSA",
234        _ => {
235            return Err(JwtError::InvalidKey(
236                "Unsupported algorithm for JWKS".into(),
237            ))
238        }
239    };
240
241    let alg_str = match alg {
242        Algorithm::ES256 => "ES256",
243        Algorithm::ES384 => "ES384",
244        Algorithm::RS256 => "RS256",
245        Algorithm::RS384 => "RS384",
246        Algorithm::RS512 => "RS512",
247        _ => "unknown",
248    };
249
250    // For JWKS, we encode the full DER of the public key as x5c or use raw coordinates.
251    // Simpler approach: encode the entire PEM-decoded DER as a base64url x5c entry.
252    let der_b64: String = pem_str
253        .lines()
254        .filter(|line| !line.starts_with("-----"))
255        .collect::<Vec<_>>()
256        .join("");
257
258    let jwk = serde_json::json!({
259        "kty": key_type,
260        "use": "sig",
261        "alg": alg_str,
262        "kid": kid,
263        "x5c": [der_b64],
264    });
265
266    Ok(serde_json::json!({
267        "keys": [jwk]
268    }))
269}
270
271/// Build a JwtConfig from environment variables.
272///
273/// Priority:
274/// 1. `ATI_JWT_PUBLIC_KEY` (PEM file) → ES256
275/// 2. `ATI_JWT_SECRET` (hex string) → HS256
276/// 3. Neither → None (JWT disabled)
277pub fn config_from_env() -> Result<Option<JwtConfig>, JwtError> {
278    let issuer = std::env::var("ATI_JWT_ISSUER").ok();
279    let audience = std::env::var("ATI_JWT_AUDIENCE").unwrap_or_else(|_| "ati-proxy".to_string());
280
281    // Try ES256 first
282    if let Ok(pub_key_path) = std::env::var("ATI_JWT_PUBLIC_KEY") {
283        let public_pem = std::fs::read(&pub_key_path)
284            .map_err(|e| JwtError::InvalidKey(format!("Cannot read {pub_key_path}: {e}")))?;
285
286        let private_pem = std::env::var("ATI_JWT_PRIVATE_KEY")
287            .ok()
288            .and_then(|path| std::fs::read(&path).ok());
289
290        let mut config = config_from_pem(
291            &public_pem,
292            private_pem.as_deref(),
293            Algorithm::ES256,
294            issuer,
295            audience,
296        )?;
297
298        // Store raw PEM for JWKS endpoint
299        config.public_key_pem = Some(public_pem);
300
301        return Ok(Some(config));
302    }
303
304    // Try HS256 fallback
305    if let Ok(secret_hex) = std::env::var("ATI_JWT_SECRET") {
306        let secret_bytes = hex::decode(&secret_hex)
307            .map_err(|e| JwtError::InvalidKey(format!("ATI_JWT_SECRET is not valid hex: {e}")))?;
308
309        return Ok(Some(config_from_secret(&secret_bytes, issuer, audience)));
310    }
311
312    Ok(None)
313}
314
/// Whole seconds elapsed since the Unix epoch, or `0` if the system clock reads earlier than it.
pub fn now_secs() -> u64 {
    std::time::UNIX_EPOCH
        .elapsed()
        .map_or(0, |since_epoch| since_epoch.as_secs())
}
322
#[cfg(test)]
mod tests {
    use super::*;

    fn hs256_config() -> JwtConfig {
        config_from_secret(
            b"test-secret-key-32-bytes-long!!!",
            None,
            "ati-proxy".into(),
        )
    }

    fn hs256_config_with_issuer() -> JwtConfig {
        config_from_secret(
            b"test-secret-key-32-bytes-long!!!",
            Some("ati-orchestrator".into()),
            "ati-proxy".into(),
        )
    }

    fn make_claims(scope: &str) -> TokenClaims {
        let now = now_secs();
        TokenClaims {
            iss: Some("ati-orchestrator".into()),
            sub: "agent-7".into(),
            aud: "ati-proxy".into(),
            iat: now,
            exp: now + 1800,
            jti: Some(uuid::Uuid::new_v4().to_string()),
            scope: scope.into(),
            ati: Some(AtiNamespace {
                v: 1,
                rate: HashMap::new(),
            }),
            job_id: None,
            sandbox_id: None,
        }
    }

    #[test]
    fn test_hs256_round_trip() {
        let config = hs256_config();
        let claims = make_claims("tool:web_search tool:github:*");

        let token = issue(&claims, &config).unwrap();
        let decoded = validate(&token, &config).unwrap();

        assert_eq!(decoded.sub, "agent-7");
        assert_eq!(decoded.aud, "ati-proxy");
        assert_eq!(decoded.scope, "tool:web_search tool:github:*");
        assert_eq!(decoded.scopes(), vec!["tool:web_search", "tool:github:*"]);
        assert_eq!(decoded.iss, Some("ati-orchestrator".into()));
    }

    #[test]
    fn test_expired_token_rejected() {
        let config = hs256_config();
        let mut claims = make_claims("tool:web_search");
        claims.exp = 1; // Expired long ago

        let token = issue(&claims, &config).unwrap();
        let result = validate(&token, &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_wrong_secret_rejected() {
        let config1 = hs256_config();
        let config2 =
            config_from_secret(b"different-secret-key-32-bytes!!", None, "ati-proxy".into());

        let claims = make_claims("tool:web_search");
        let token = issue(&claims, &config1).unwrap();
        let result = validate(&token, &config2);
        assert!(result.is_err());
    }

    #[test]
    fn test_wrong_audience_rejected() {
        let config = hs256_config();
        let mut claims = make_claims("tool:web_search");
        claims.aud = "wrong-audience".into();

        let token = issue(&claims, &config).unwrap();
        let result = validate(&token, &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_wrong_issuer_rejected() {
        let config = hs256_config_with_issuer();
        let mut claims = make_claims("tool:web_search");
        claims.iss = Some("evil-orchestrator".into());

        let token = issue(&claims, &config).unwrap();
        let result = validate(&token, &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_tampered_payload_rejected() {
        let config = hs256_config();
        let claims = make_claims("tool:web_search");
        let token = issue(&claims, &config).unwrap();

        // Tamper with the payload by substituting one base64url character in the middle of
        // the payload section, so the signed bytes no longer match the signature.
        let parts: Vec<&str> = token.split('.').collect();
        assert_eq!(parts.len(), 3);
        let mid = parts[1].len() / 2;
        let replacement = if parts[1].as_bytes()[mid] == b'A' { "B" } else { "A" };
        let mut tampered_payload = parts[1].to_string();
        tampered_payload.replace_range(mid..=mid, replacement);
        assert_ne!(tampered_payload, parts[1]);
        let tampered = format!("{}.{}.{}", parts[0], tampered_payload, parts[2]);

        let result = validate(&tampered, &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_malformed_token_rejected() {
        let config = hs256_config();
        let result = validate("not.a.jwt.token.at.all", &config);
        assert!(result.is_err());

        let result = validate("", &config);
        assert!(result.is_err());

        let result = validate("just-a-string", &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_inspect_decodes_without_key() {
        let config = hs256_config();
        let claims = make_claims("tool:web_search skill:research-*");
        let token = issue(&claims, &config).unwrap();

        let decoded = inspect(&token).unwrap();
        assert_eq!(decoded.sub, "agent-7");
        assert_eq!(decoded.scope, "tool:web_search skill:research-*");
    }

    #[test]
    fn test_inspect_ignores_expiry() {
        let config = hs256_config();
        let mut claims = make_claims("tool:web_search");
        claims.exp = 1; // Expired long ago

        let token = issue(&claims, &config).unwrap();
        assert!(validate(&token, &config).is_err());

        let decoded = inspect(&token).unwrap();
        assert_eq!(decoded.exp, 1);
    }

    #[test]
    fn test_scope_parsing() {
        let claims = make_claims("tool:web_search tool:github:* skill:research-* help");
        let scopes = claims.scopes();
        assert_eq!(
            scopes,
            vec![
                "tool:web_search",
                "tool:github:*",
                "skill:research-*",
                "help"
            ]
        );
    }

    #[test]
    fn test_scope_parsing_collapses_whitespace() {
        let claims = make_claims("  tool:a\ttool:b   ");
        assert_eq!(claims.scopes(), vec!["tool:a", "tool:b"]);
    }

    #[test]
    fn test_empty_scope() {
        let claims = make_claims("");
        assert!(claims.scopes().is_empty());
    }

    #[test]
    fn test_single_scope() {
        let claims = make_claims("*");
        assert_eq!(claims.scopes(), vec!["*"]);
    }

    #[test]
    fn test_no_encoding_key_fails() {
        let config = JwtConfig {
            decoding_key: DecodingKey::from_secret(b"test"),
            encoding_key: None,
            algorithm: Algorithm::HS256,
            required_issuer: None,
            required_audience: "ati-proxy".into(),
            leeway_secs: 60,
            public_key_pem: None,
        };

        let claims = make_claims("tool:web_search");
        let result = issue(&claims, &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_issuer_not_required_when_none() {
        let config = hs256_config(); // No required_issuer
        let mut claims = make_claims("tool:web_search");
        claims.iss = None;

        let token = issue(&claims, &config).unwrap();
        let decoded = validate(&token, &config).unwrap();
        assert_eq!(decoded.iss, None);
    }

    #[test]
    fn test_jti_preserved() {
        let config = hs256_config();
        let claims = make_claims("tool:web_search");
        let jti = claims.jti.clone();

        let token = issue(&claims, &config).unwrap();
        let decoded = validate(&token, &config).unwrap();
        assert_eq!(decoded.jti, jti);
    }

    #[test]
    fn test_ati_namespace_preserved() {
        let config = hs256_config();
        let claims = make_claims("tool:web_search");

        let token = issue(&claims, &config).unwrap();
        let decoded = validate(&token, &config).unwrap();
        assert!(decoded.ati.is_some());
        assert_eq!(decoded.ati.unwrap().v, 1);
    }
}