Skip to main content

ati/core/
jwt.rs

1//! JWT-based authentication for ATI.
2//!
3//! ES256-signed JWTs carry identity + scopes + expiry in a single tamper-proof
4//! credential. The orchestrator signs with a private key; the proxy validates
5//! with the corresponding public key (served via JWKS).
6//!
7//! Supports ES256 (recommended) and HS256 (simpler, for single-machine setups).
8
9use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use thiserror::Error;
13
14#[derive(Error, Debug)]
15pub enum JwtError {
16    #[error("JWT encoding failed: {0}")]
17    Encode(#[from] jsonwebtoken::errors::Error),
18    #[error("Invalid PEM key: {0}")]
19    InvalidKey(String),
20    #[error("No encoding key configured (private key required for issuance)")]
21    NoEncodingKey,
22    #[error("No decoding key configured (public key required for validation)")]
23    NoDecodingKey,
24    #[error("Base64 decode error: {0}")]
25    Base64(String),
26}
27
28/// Configuration for JWT validation and (optionally) issuance.
29#[derive(Clone)]
30pub struct JwtConfig {
31    /// Public key for validation.
32    pub decoding_key: DecodingKey,
33    /// Private key for issuance (only on orchestrator).
34    pub encoding_key: Option<EncodingKey>,
35    /// Signing algorithm (ES256 or HS256).
36    pub algorithm: Algorithm,
37    /// Expected `iss` claim (optional — skipped if None).
38    pub required_issuer: Option<String>,
39    /// Accepted `aud` claim values. A token whose `aud` matches **any** entry
40    /// in this list passes audience validation; the per-tool scope check
41    /// downstream still gates the actual call.
42    ///
43    /// Single-element vec preserves v0.7.x single-audience behaviour. For
44    /// multi-audience deployments (e.g. proxy accepting both `ati-proxy` and
45    /// per-MCP-audience tokens — see issue #121), populate from
46    /// `ATI_JWT_ACCEPTED_AUDIENCES` (CSV env var).
47    pub accepted_audiences: Vec<String>,
48    /// Clock skew tolerance in seconds.
49    pub leeway_secs: u64,
50    /// Raw public key PEM bytes (for JWKS endpoint).
51    pub public_key_pem: Option<Vec<u8>>,
52}
53
54impl std::fmt::Debug for JwtConfig {
55    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56        f.debug_struct("JwtConfig")
57            .field("algorithm", &self.algorithm)
58            .field("required_issuer", &self.required_issuer)
59            .field("accepted_audiences", &self.accepted_audiences)
60            .field("leeway_secs", &self.leeway_secs)
61            .field("has_encoding_key", &self.encoding_key.is_some())
62            .finish()
63    }
64}
65
66/// ATI-specific namespace in JWT claims.
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct AtiNamespace {
69    /// Claims schema version.
70    pub v: u8,
71    /// Per-tool-pattern rate limits (e.g. {"tool:github:*": "10/hour"}).
72    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
73    pub rate: HashMap<String, String>,
74}
75
76/// JWT claims per RFC 9068.
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct TokenClaims {
79    /// Issuer (who signed this token).
80    #[serde(skip_serializing_if = "Option::is_none")]
81    pub iss: Option<String>,
82    /// Subject (agent identity).
83    pub sub: String,
84    /// Audience (target service, e.g. "ati-proxy").
85    pub aud: String,
86    /// Issued-at timestamp (Unix seconds).
87    pub iat: u64,
88    /// Expiry timestamp (Unix seconds).
89    pub exp: u64,
90    /// Unique token ID (UUID) for replay detection.
91    #[serde(skip_serializing_if = "Option::is_none")]
92    pub jti: Option<String>,
93    /// Space-delimited scopes per RFC 9068 §2.2.3.
94    pub scope: String,
95    /// ATI-specific claims namespace.
96    #[serde(skip_serializing_if = "Option::is_none")]
97    pub ati: Option<AtiNamespace>,
98    /// Job identifier (set by orchestrator provisioner).
99    #[serde(default, skip_serializing_if = "Option::is_none")]
100    pub job_id: Option<String>,
101    /// Sandbox identifier (set by orchestrator provisioner).
102    #[serde(default, skip_serializing_if = "Option::is_none")]
103    pub sandbox_id: Option<String>,
104}
105
106impl TokenClaims {
107    /// Parse the space-delimited scope string into a Vec.
108    pub fn scopes(&self) -> Vec<String> {
109        self.scope.split_whitespace().map(String::from).collect()
110    }
111}
112
113/// Validate a JWT token string and return the claims.
114///
115/// A token whose `aud` claim matches **any** entry in
116/// `config.accepted_audiences` passes. Per-tool authorization (scope check)
117/// is done separately by callers via `ScopeConfig`.
118pub fn validate(token: &str, config: &JwtConfig) -> Result<TokenClaims, JwtError> {
119    // Defense in depth: jsonwebtoken's `Validation::set_audience(&[])` silently
120    // disables audience validation, so an empty `accepted_audiences` Vec would
121    // accept any `aud` claim — a quiet security regression if a future caller
122    // builds a `JwtConfig` literal with an empty Vec or hand-crafts one in a
123    // test. `parse_audiences_env()` already guards this at the env-loading
124    // layer (Greptile P1 on #121); this guard locks the API boundary so the
125    // invariant can't drift independently. See issue #121 review.
126    if config.accepted_audiences.is_empty() {
127        return Err(JwtError::InvalidKey(
128            "accepted_audiences must not be empty; configure at least one audience".into(),
129        ));
130    }
131    let mut validation = Validation::new(config.algorithm);
132    // jsonwebtoken's set_audience uses "any-match" semantics: token.aud is
133    // accepted iff it matches at least one entry. Passing a slice of &str
134    // borrows from the Vec<String> without allocating an intermediate owned
135    // collection.
136    let auds: Vec<&str> = config
137        .accepted_audiences
138        .iter()
139        .map(String::as_str)
140        .collect();
141    validation.set_audience(&auds);
142    validation.leeway = config.leeway_secs;
143
144    if let Some(ref issuer) = config.required_issuer {
145        validation.set_issuer(&[issuer]);
146    } else {
147        // Don't require issuer validation if not configured
148        validation.set_required_spec_claims(&["exp", "sub", "aud"]);
149    }
150
151    let token_data: TokenData<TokenClaims> =
152        jsonwebtoken::decode(token, &config.decoding_key, &validation)?;
153
154    Ok(token_data.claims)
155}
156
157/// Issue (sign) a JWT token from claims.
158pub fn issue(claims: &TokenClaims, config: &JwtConfig) -> Result<String, JwtError> {
159    let encoding_key = config
160        .encoding_key
161        .as_ref()
162        .ok_or(JwtError::NoEncodingKey)?;
163
164    let header = Header::new(config.algorithm);
165    let token = jsonwebtoken::encode(&header, claims, encoding_key)?;
166    Ok(token)
167}
168
169/// Decode a JWT without verifying the signature (for inspection only).
170pub fn inspect(token: &str) -> Result<TokenClaims, JwtError> {
171    let mut validation = Validation::default();
172    validation.insecure_disable_signature_validation();
173    validation.validate_aud = false;
174    validation.validate_exp = false;
175    validation.set_required_spec_claims::<&str>(&[]);
176
177    // Use a dummy key since we're not validating
178    let key = DecodingKey::from_secret(b"unused");
179    let token_data: TokenData<TokenClaims> = jsonwebtoken::decode(token, &key, &validation)?;
180
181    Ok(token_data.claims)
182}
183
184/// Load an ES256 or RS256 public key from PEM bytes.
185pub fn load_public_key_pem(pem: &[u8], alg: Algorithm) -> Result<DecodingKey, JwtError> {
186    match alg {
187        Algorithm::ES256 | Algorithm::ES384 => {
188            DecodingKey::from_ec_pem(pem).map_err(|e| JwtError::InvalidKey(e.to_string()))
189        }
190        Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => {
191            DecodingKey::from_rsa_pem(pem).map_err(|e| JwtError::InvalidKey(e.to_string()))
192        }
193        _ => Err(JwtError::InvalidKey(format!(
194            "Unsupported algorithm for PEM: {alg:?}"
195        ))),
196    }
197}
198
199/// Load an ES256 or RS256 private key from PEM bytes.
200pub fn load_private_key_pem(pem: &[u8], alg: Algorithm) -> Result<EncodingKey, JwtError> {
201    match alg {
202        Algorithm::ES256 | Algorithm::ES384 => {
203            EncodingKey::from_ec_pem(pem).map_err(|e| JwtError::InvalidKey(e.to_string()))
204        }
205        Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => {
206            EncodingKey::from_rsa_pem(pem).map_err(|e| JwtError::InvalidKey(e.to_string()))
207        }
208        _ => Err(JwtError::InvalidKey(format!(
209            "Unsupported algorithm for PEM: {alg:?}"
210        ))),
211    }
212}
213
214/// Create a JwtConfig from an HS256 shared secret.
215///
216/// `audiences` is the allowlist of acceptable `aud` claim values; a token
217/// matches if its `aud` equals any entry. Single-element vec for the
218/// common single-audience case. **Empty vec is an error in v0.7.x callers
219/// that hit it indirectly — every config-creation path normalises empty to
220/// `["ati-proxy"]` (the historical default) so the proxy fails loud rather
221/// than accepting any aud silently.** See `parse_audiences_env`.
222pub fn config_from_secret(
223    secret: &[u8],
224    issuer: Option<String>,
225    audiences: Vec<String>,
226) -> JwtConfig {
227    JwtConfig {
228        decoding_key: DecodingKey::from_secret(secret),
229        encoding_key: Some(EncodingKey::from_secret(secret)),
230        algorithm: Algorithm::HS256,
231        required_issuer: issuer,
232        accepted_audiences: audiences,
233        leeway_secs: 60,
234        public_key_pem: None,
235    }
236}
237
238/// Create a JwtConfig from PEM key files. See [`config_from_secret`] for
239/// the `audiences` contract.
240pub fn config_from_pem(
241    public_pem: &[u8],
242    private_pem: Option<&[u8]>,
243    alg: Algorithm,
244    issuer: Option<String>,
245    audiences: Vec<String>,
246) -> Result<JwtConfig, JwtError> {
247    let decoding_key = load_public_key_pem(public_pem, alg)?;
248    let encoding_key = match private_pem {
249        Some(pem) => Some(load_private_key_pem(pem, alg)?),
250        None => None,
251    };
252
253    Ok(JwtConfig {
254        decoding_key,
255        encoding_key,
256        algorithm: alg,
257        required_issuer: issuer,
258        accepted_audiences: audiences,
259        leeway_secs: 60,
260        public_key_pem: Some(public_pem.to_vec()),
261    })
262}
263
264/// Generate a JWKS JSON object from a public key PEM.
265/// Returns the JWKS `keys` array suitable for `/.well-known/jwks.json`.
266pub fn public_key_to_jwks(
267    pem: &[u8],
268    alg: Algorithm,
269    kid: &str,
270) -> Result<serde_json::Value, JwtError> {
271    // Parse the PEM to extract the raw key bytes
272    let pem_str = std::str::from_utf8(pem).map_err(|e| JwtError::InvalidKey(e.to_string()))?;
273
274    // Extract base64 content between PEM headers
275    let key_type = match alg {
276        Algorithm::ES256 | Algorithm::ES384 => "EC",
277        Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => "RSA",
278        _ => {
279            return Err(JwtError::InvalidKey(
280                "Unsupported algorithm for JWKS".into(),
281            ))
282        }
283    };
284
285    let alg_str = match alg {
286        Algorithm::ES256 => "ES256",
287        Algorithm::ES384 => "ES384",
288        Algorithm::RS256 => "RS256",
289        Algorithm::RS384 => "RS384",
290        Algorithm::RS512 => "RS512",
291        _ => "unknown",
292    };
293
294    // For JWKS, we encode the full DER of the public key as x5c or use raw coordinates.
295    // Simpler approach: encode the entire PEM-decoded DER as a base64url x5c entry.
296    let der_b64: String = pem_str
297        .lines()
298        .filter(|line| !line.starts_with("-----"))
299        .collect::<Vec<_>>()
300        .join("");
301
302    let jwk = serde_json::json!({
303        "kty": key_type,
304        "use": "sig",
305        "alg": alg_str,
306        "kid": kid,
307        "x5c": [der_b64],
308    });
309
310    Ok(serde_json::json!({
311        "keys": [jwk]
312    }))
313}
314
/// Resolve the accepted-audience allowlist from the environment.
///
/// Priority (first usable source wins):
/// 1. `ATI_JWT_ACCEPTED_AUDIENCES` (CSV) — operator-declared allowlist,
///    e.g. `"ati-proxy,parcha-custom-tools"`, for proxies that accept
///    several `aud` values (#121).
/// 2. `ATI_JWT_AUDIENCE` (singular) — back-compat with v0.7.x single-aud
///    deployments; wrapped in a one-element vec.
/// 3. Default: `["ati-proxy"]`.
///
/// Values from both variables are trimmed, and empty or whitespace-only
/// values are treated as unset and fall through to the next source. The
/// returned vec is therefore never empty and never contains an empty
/// string, which upholds `validate()`'s non-empty-allowlist requirement.
pub fn parse_audiences_env() -> Vec<String> {
    if let Ok(csv) = std::env::var("ATI_JWT_ACCEPTED_AUDIENCES") {
        let allowlist: Vec<String> = csv
            .split(',')
            .map(str::trim)
            .filter(|entry| !entry.is_empty())
            .map(String::from)
            .collect();
        if !allowlist.is_empty() {
            return allowlist;
        }
    }

    // A set-but-blank `ATI_JWT_AUDIENCE` (e.g. `ATI_JWT_AUDIENCE=` in an env
    // file) must not become the audience "": fall back to the default.
    let single = std::env::var("ATI_JWT_AUDIENCE")
        .ok()
        .map(|aud| aud.trim().to_string())
        .filter(|aud| !aud.is_empty())
        .unwrap_or_else(|| "ati-proxy".to_string());
    vec![single]
}
343
344/// Build a JwtConfig from environment variables.
345///
346/// Priority:
347/// 1. `ATI_JWT_PUBLIC_KEY` (PEM file) → ES256
348/// 2. `ATI_JWT_SECRET` (hex string) → HS256
349/// 3. Neither → None (JWT disabled)
350///
351/// The audience allowlist is sourced via [`parse_audiences_env`] —
352/// `ATI_JWT_ACCEPTED_AUDIENCES` (CSV) > `ATI_JWT_AUDIENCE` (singular) >
353/// `["ati-proxy"]` default.
354pub fn config_from_env() -> Result<Option<JwtConfig>, JwtError> {
355    let issuer = std::env::var("ATI_JWT_ISSUER").ok();
356    let audiences = parse_audiences_env();
357
358    // Try ES256 first
359    if let Ok(pub_key_path) = std::env::var("ATI_JWT_PUBLIC_KEY") {
360        let public_pem = std::fs::read(&pub_key_path)
361            .map_err(|e| JwtError::InvalidKey(format!("Cannot read {pub_key_path}: {e}")))?;
362
363        let private_pem = std::env::var("ATI_JWT_PRIVATE_KEY")
364            .ok()
365            .and_then(|path| std::fs::read(&path).ok());
366
367        let mut config = config_from_pem(
368            &public_pem,
369            private_pem.as_deref(),
370            Algorithm::ES256,
371            issuer,
372            audiences,
373        )?;
374
375        // Store raw PEM for JWKS endpoint
376        config.public_key_pem = Some(public_pem);
377
378        return Ok(Some(config));
379    }
380
381    // Try HS256 fallback
382    if let Ok(secret_hex) = std::env::var("ATI_JWT_SECRET") {
383        let secret_bytes = hex::decode(&secret_hex)
384            .map_err(|e| JwtError::InvalidKey(format!("ATI_JWT_SECRET is not valid hex: {e}")))?;
385
386        return Ok(Some(config_from_secret(&secret_bytes, issuer, audiences)));
387    }
388
389    Ok(None)
390}
391
/// Current time as whole seconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
pub fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
399
400#[cfg(test)]
401mod tests {
402    use super::*;
403
404    fn hs256_config() -> JwtConfig {
405        config_from_secret(
406            b"test-secret-key-32-bytes-long!!!",
407            None,
408            vec!["ati-proxy".into()],
409        )
410    }
411
412    fn hs256_config_with_issuer() -> JwtConfig {
413        config_from_secret(
414            b"test-secret-key-32-bytes-long!!!",
415            Some("ati-orchestrator".into()),
416            vec!["ati-proxy".into()],
417        )
418    }
419
420    fn make_claims(scope: &str) -> TokenClaims {
421        let now = now_secs();
422        TokenClaims {
423            iss: Some("ati-orchestrator".into()),
424            sub: "agent-7".into(),
425            aud: "ati-proxy".into(),
426            iat: now,
427            exp: now + 1800,
428            jti: Some(uuid::Uuid::new_v4().to_string()),
429            scope: scope.into(),
430            ati: Some(AtiNamespace {
431                v: 1,
432                rate: HashMap::new(),
433            }),
434            job_id: None,
435            sandbox_id: None,
436        }
437    }
438
439    #[test]
440    fn test_hs256_round_trip() {
441        let config = hs256_config();
442        let claims = make_claims("tool:web_search tool:github:*");
443
444        let token = issue(&claims, &config).unwrap();
445        let decoded = validate(&token, &config).unwrap();
446
447        assert_eq!(decoded.sub, "agent-7");
448        assert_eq!(decoded.aud, "ati-proxy");
449        assert_eq!(decoded.scope, "tool:web_search tool:github:*");
450        assert_eq!(decoded.scopes(), vec!["tool:web_search", "tool:github:*"]);
451        assert_eq!(decoded.iss, Some("ati-orchestrator".into()));
452    }
453
454    #[test]
455    fn test_expired_token_rejected() {
456        let config = hs256_config();
457        let mut claims = make_claims("tool:web_search");
458        claims.exp = 1; // Expired long ago
459
460        let token = issue(&claims, &config).unwrap();
461        let result = validate(&token, &config);
462        assert!(result.is_err());
463    }
464
465    #[test]
466    fn test_wrong_secret_rejected() {
467        let config1 = hs256_config();
468        let config2 = config_from_secret(
469            b"different-secret-key-32-bytes!!",
470            None,
471            vec!["ati-proxy".into()],
472        );
473
474        let claims = make_claims("tool:web_search");
475        let token = issue(&claims, &config1).unwrap();
476        let result = validate(&token, &config2);
477        assert!(result.is_err());
478    }
479
480    #[test]
481    fn test_wrong_audience_rejected() {
482        let config = hs256_config();
483        let mut claims = make_claims("tool:web_search");
484        claims.aud = "wrong-audience".into();
485
486        let token = issue(&claims, &config).unwrap();
487        let result = validate(&token, &config);
488        assert!(result.is_err());
489    }
490
491    #[test]
492    fn test_wrong_issuer_rejected() {
493        let config = hs256_config_with_issuer();
494        let mut claims = make_claims("tool:web_search");
495        claims.iss = Some("evil-orchestrator".into());
496
497        let token = issue(&claims, &config).unwrap();
498        let result = validate(&token, &config);
499        assert!(result.is_err());
500    }
501
502    #[test]
503    fn test_tampered_payload_rejected() {
504        let config = hs256_config();
505        let claims = make_claims("tool:web_search");
506        let token = issue(&claims, &config).unwrap();
507
508        // Tamper with the payload: change a character in the middle section
509        let parts: Vec<&str> = token.split('.').collect();
510        assert_eq!(parts.len(), 3);
511        let mut tampered_payload = parts[1].to_string();
512        // Flip a character
513        if tampered_payload.ends_with('A') {
514            tampered_payload.push('B');
515        } else {
516            tampered_payload.push('A');
517        }
518        let tampered = format!("{}.{}.{}", parts[0], tampered_payload, parts[2]);
519
520        let result = validate(&tampered, &config);
521        assert!(result.is_err());
522    }
523
524    #[test]
525    fn test_malformed_token_rejected() {
526        let config = hs256_config();
527        let result = validate("not.a.jwt.token.at.all", &config);
528        assert!(result.is_err());
529
530        let result = validate("", &config);
531        assert!(result.is_err());
532
533        let result = validate("just-a-string", &config);
534        assert!(result.is_err());
535    }
536
537    #[test]
538    fn test_inspect_decodes_without_key() {
539        let config = hs256_config();
540        let claims = make_claims("tool:web_search skill:research-*");
541        let token = issue(&claims, &config).unwrap();
542
543        let decoded = inspect(&token).unwrap();
544        assert_eq!(decoded.sub, "agent-7");
545        assert_eq!(decoded.scope, "tool:web_search skill:research-*");
546    }
547
548    #[test]
549    fn test_scope_parsing() {
550        let claims = make_claims("tool:web_search tool:github:* skill:research-* help");
551        let scopes = claims.scopes();
552        assert_eq!(
553            scopes,
554            vec![
555                "tool:web_search",
556                "tool:github:*",
557                "skill:research-*",
558                "help"
559            ]
560        );
561    }
562
563    #[test]
564    fn test_empty_scope() {
565        let claims = make_claims("");
566        assert!(claims.scopes().is_empty());
567    }
568
569    #[test]
570    fn test_single_scope() {
571        let claims = make_claims("*");
572        assert_eq!(claims.scopes(), vec!["*"]);
573    }
574
575    #[test]
576    fn test_no_encoding_key_fails() {
577        let config = JwtConfig {
578            decoding_key: DecodingKey::from_secret(b"test"),
579            encoding_key: None,
580            algorithm: Algorithm::HS256,
581            required_issuer: None,
582            accepted_audiences: vec!["ati-proxy".into()],
583            leeway_secs: 60,
584            public_key_pem: None,
585        };
586
587        let claims = make_claims("tool:web_search");
588        let result = issue(&claims, &config);
589        assert!(result.is_err());
590    }
591
592    #[test]
593    fn test_issuer_not_required_when_none() {
594        let config = hs256_config(); // No required_issuer
595        let mut claims = make_claims("tool:web_search");
596        claims.iss = None;
597
598        let token = issue(&claims, &config).unwrap();
599        let decoded = validate(&token, &config).unwrap();
600        assert_eq!(decoded.iss, None);
601    }
602
603    #[test]
604    fn test_jti_preserved() {
605        let config = hs256_config();
606        let claims = make_claims("tool:web_search");
607        let jti = claims.jti.clone();
608
609        let token = issue(&claims, &config).unwrap();
610        let decoded = validate(&token, &config).unwrap();
611        assert_eq!(decoded.jti, jti);
612    }
613
614    #[test]
615    fn test_ati_namespace_preserved() {
616        let config = hs256_config();
617        let claims = make_claims("tool:web_search");
618
619        let token = issue(&claims, &config).unwrap();
620        let decoded = validate(&token, &config).unwrap();
621        assert!(decoded.ati.is_some());
622        assert_eq!(decoded.ati.unwrap().v, 1);
623    }
624
625    // -------------------------------------------------------------------------
626    // Multi-audience validation (issue #121).
627    //
628    // The proxy needs to accept JWTs minted for different downstream services
629    // (each with their own `aud` claim) so the sandbox can forward a
630    // per-provider-scoped token without re-minting at the proxy boundary.
631    //
632    // jsonwebtoken's `Validation::set_audience` is natively any-match across
633    // the slice, so the change is one-line. Tests here lock in the behaviour
634    // so a future refactor can't regress to single-audience matching.
635    // -------------------------------------------------------------------------
636
637    fn hs256_config_multi(audiences: Vec<String>) -> JwtConfig {
638        config_from_secret(b"test-secret-key-32-bytes-long!!!", None, audiences)
639    }
640
641    #[test]
642    fn test_multi_audience_accepts_first() {
643        let config = hs256_config_multi(vec!["ati-proxy".into(), "parcha-tools".into()]);
644        let mut claims = make_claims("tool:web_search");
645        claims.aud = "ati-proxy".into();
646
647        let token = issue(&claims, &config).unwrap();
648        let decoded = validate(&token, &config).expect("aud=ati-proxy should pass");
649        assert_eq!(decoded.aud, "ati-proxy");
650    }
651
652    #[test]
653    fn test_multi_audience_accepts_second() {
654        let config = hs256_config_multi(vec!["ati-proxy".into(), "parcha-tools".into()]);
655        let mut claims = make_claims("tool:web_search");
656        claims.aud = "parcha-tools".into();
657
658        let token = issue(&claims, &config).unwrap();
659        let decoded = validate(&token, &config).expect("aud=parcha-tools should pass");
660        assert_eq!(decoded.aud, "parcha-tools");
661    }
662
663    #[test]
664    fn test_multi_audience_rejects_out_of_list() {
665        let config = hs256_config_multi(vec!["ati-proxy".into(), "parcha-tools".into()]);
666        let mut claims = make_claims("tool:web_search");
667        claims.aud = "evil-aud".into();
668
669        let token = issue(&claims, &config).unwrap();
670        let result = validate(&token, &config);
671        assert!(result.is_err(), "aud not in allowlist must be rejected");
672    }
673
674    #[test]
675    fn test_single_audience_back_compat() {
676        // A one-element vec must behave exactly as the v0.7.x single-aud
677        // String case did — same tokens accepted, same tokens rejected.
678        let config = hs256_config_multi(vec!["ati-proxy".into()]);
679        let mut claims = make_claims("tool:web_search");
680
681        claims.aud = "ati-proxy".into();
682        let token = issue(&claims, &config).unwrap();
683        assert!(validate(&token, &config).is_ok());
684
685        claims.aud = "wrong".into();
686        let token = issue(&claims, &config).unwrap();
687        assert!(validate(&token, &config).is_err());
688    }
689
690    #[test]
691    fn test_empty_audiences_vec_rejected_not_bypassed() {
692        // Greptile P1 / security on #121: jsonwebtoken's
693        // Validation::set_audience(&[]) silently bypasses audience validation.
694        // validate() must hard-error on an empty accepted_audiences rather
695        // than accepting any aud — this is the API-boundary guard that
696        // mirrors parse_audiences_env()'s env-loading guard.
697        //
698        // Even a token with the "right" aud must be rejected because the
699        // config is broken at construction time.
700        let config = hs256_config_multi(vec![]);
701        let mut claims = make_claims("tool:web_search");
702        claims.aud = "ati-proxy".into();
703
704        // Issue with a separate one-element config so the token is
705        // structurally valid; the rejection should come from validate's
706        // empty-allowlist guard, not from missing aud.
707        let issuer = hs256_config_multi(vec!["ati-proxy".into()]);
708        let token = issue(&claims, &issuer).unwrap();
709
710        let err = validate(&token, &config).expect_err("empty allowlist must reject");
711        match err {
712            JwtError::InvalidKey(msg) => assert!(
713                msg.contains("accepted_audiences"),
714                "error must mention accepted_audiences; got: {msg}"
715            ),
716            other => panic!("expected InvalidKey, got {other:?}"),
717        }
718    }
719
720    // -------------------------------------------------------------------------
721    // parse_audiences_env (issue #121).
722    //
723    // Test using env vars requires serializing across tests; reuse the same
724    // Mutex pattern as core::token::tests::ENV_LOCK rather than re-rolling.
725    // -------------------------------------------------------------------------
726
727    use std::sync::Mutex;
728    static ENV_LOCK: Mutex<()> = Mutex::new(());
729
730    struct EnvGuard {
731        prev: Vec<(&'static str, Option<String>)>,
732    }
733
734    impl EnvGuard {
735        fn set(pairs: &[(&'static str, Option<&str>)]) -> Self {
736            let mut prev = Vec::new();
737            for (k, v) in pairs {
738                prev.push((*k, std::env::var(k).ok()));
739                match v {
740                    Some(val) => std::env::set_var(k, val),
741                    None => std::env::remove_var(k),
742                }
743            }
744            Self { prev }
745        }
746    }
747
748    impl Drop for EnvGuard {
749        fn drop(&mut self) {
750            for (k, v) in &self.prev {
751                match v {
752                    Some(val) => std::env::set_var(k, val),
753                    None => std::env::remove_var(k),
754                }
755            }
756        }
757    }
758
759    #[test]
760    fn test_parse_audiences_env_csv_wins() {
761        let _g = ENV_LOCK.lock().unwrap();
762        let _e = EnvGuard::set(&[
763            ("ATI_JWT_ACCEPTED_AUDIENCES", Some("a, b ,c")),
764            ("ATI_JWT_AUDIENCE", Some("ignored-singular")),
765        ]);
766        assert_eq!(parse_audiences_env(), vec!["a", "b", "c"]);
767    }
768
769    #[test]
770    fn test_parse_audiences_env_falls_back_to_singular() {
771        let _g = ENV_LOCK.lock().unwrap();
772        let _e = EnvGuard::set(&[
773            ("ATI_JWT_ACCEPTED_AUDIENCES", None),
774            ("ATI_JWT_AUDIENCE", Some("custom-aud")),
775        ]);
776        assert_eq!(parse_audiences_env(), vec!["custom-aud"]);
777    }
778
779    #[test]
780    fn test_parse_audiences_env_default_is_ati_proxy() {
781        let _g = ENV_LOCK.lock().unwrap();
782        let _e = EnvGuard::set(&[
783            ("ATI_JWT_ACCEPTED_AUDIENCES", None),
784            ("ATI_JWT_AUDIENCE", None),
785        ]);
786        assert_eq!(parse_audiences_env(), vec!["ati-proxy"]);
787    }
788
789    #[test]
790    fn test_parse_audiences_env_csv_all_empty_falls_back() {
791        // Pathological config: "ATI_JWT_ACCEPTED_AUDIENCES=  ,  ,  ". If we
792        // honoured the empty list we'd silently accept any aud — instead
793        // fall through to the singular env / default. validate()'s contract
794        // is that the allowlist is never empty.
795        let _g = ENV_LOCK.lock().unwrap();
796        let _e = EnvGuard::set(&[
797            ("ATI_JWT_ACCEPTED_AUDIENCES", Some("  ,  ,  ")),
798            ("ATI_JWT_AUDIENCE", Some("fallback-aud")),
799        ]);
800        assert_eq!(parse_audiences_env(), vec!["fallback-aud"]);
801    }
802}