Skip to main content

inference_gateway_adk/server/
auth.rs

//! Authentication primitives for the A2A HTTP surface.
//!
//! The middleware defined here protects `POST /a2a` (and any other route
//! that opts in via [`auth_middleware`]) by validating an
//! `Authorization: Bearer <token>` JWT against an OIDC issuer's JWKS.
//! `GET /health` and `GET /.well-known/agent.json` are intentionally left
//! public so that health probes and discovery clients keep working without
//! a credential.
//!
//! The middleware is only attached when [`crate::config::AuthConfig`]
//! has `enable == true` and an [`AuthVerifier`] is registered on
//! [`AppState`]; otherwise the routes behave exactly as before, which
//! preserves backwards compatibility for callers that have not opted in
//! to authentication.
//!
//! The wire shape (bearer JWT) matches the Go ADK's middleware. Broader
//! `securitySchemes`-driven negotiation (API key, mTLS, OAuth2
//! authorization-code) is intentionally out of scope; this module is the
//! foundation those schemes can build on.
20
21use std::collections::HashMap;
22use std::sync::Arc;
23use std::time::Duration;
24
25use anyhow::{Result, anyhow};
26use async_trait::async_trait;
27use axum::{
28    body::Body,
29    extract::{Request, State},
30    http::{StatusCode, header},
31    middleware::Next,
32    response::{IntoResponse, Json, Response},
33};
34use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
35use serde::{Deserialize, Serialize};
36use serde_json::Value;
37use tokio::sync::RwLock;
38use tracing::{debug, warn};
39
40use crate::config::AuthConfig;
41
42use super::protocol::AppState;
43
44/// Subject claims extracted from a validated bearer token.
45///
46/// Plumbed through Axum request extensions so JSON-RPC handlers can
47/// surface or scope behaviour by tenant in a follow-up (e.g. filtering
48/// the extended agent card).
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct AuthenticatedPrincipal {
51    /// `sub` claim from the JWT.
52    pub subject: String,
53    /// Tenant claim, lifted from the first of `tenant`, `tid`, or
54    /// `organization` that is present in the token. May be empty when
55    /// no claim is set - downstream code should treat this as
56    /// "unspecified" rather than "anonymous".
57    pub tenant: String,
58    /// Issuer (`iss`) claim - already validated against
59    /// [`AuthConfig::issuer_url`] before this struct is constructed.
60    pub issuer: String,
61    /// All claims, retained so handlers can inspect provider-specific
62    /// fields (e.g. `groups`, `roles`) without re-decoding the token.
63    pub claims: HashMap<String, Value>,
64}
65
66impl AuthenticatedPrincipal {
67    fn from_claims(issuer: String, claims: HashMap<String, Value>) -> Self {
68        let subject = claims
69            .get("sub")
70            .and_then(|v| v.as_str())
71            .unwrap_or_default()
72            .to_string();
73        let tenant = ["tenant", "tid", "organization"]
74            .iter()
75            .find_map(|k| claims.get(*k).and_then(|v| v.as_str()))
76            .unwrap_or_default()
77            .to_string();
78        Self {
79            subject,
80            tenant,
81            issuer,
82            claims,
83        }
84    }
85}
86
87/// Error surface for token verification. The middleware maps any
88/// variant to HTTP 401 - the granularity exists for logs and to give
89/// tests something concrete to match on.
90#[derive(Debug, thiserror::Error)]
91pub enum AuthError {
92    #[error("missing Authorization header")]
93    MissingHeader,
94    #[error("Authorization header must use the Bearer scheme")]
95    MalformedHeader,
96    #[error("token is empty")]
97    EmptyToken,
98    #[error("token validation failed: {0}")]
99    InvalidToken(String),
100    #[error("OIDC discovery failed: {0}")]
101    DiscoveryFailed(String),
102    #[error("JWKS fetch failed: {0}")]
103    JwksFetchFailed(String),
104    #[error("signing key not found for token")]
105    UnknownKid,
106    #[error("internal auth error: {0}")]
107    Internal(String),
108}
109
110/// Pluggable bearer-token verifier. Implementing this trait lets callers
111/// plug a custom backend (e.g. a static signing key, a mock for tests)
112/// in place of the bundled OIDC verifier.
113#[async_trait]
114pub trait AuthVerifier: Send + Sync + std::fmt::Debug {
115    /// Validate a raw bearer token (no `Bearer ` prefix) and return the
116    /// authenticated principal on success.
117    async fn verify(&self, token: &str) -> Result<AuthenticatedPrincipal, AuthError>;
118}
119
120/// JWT verifier that pulls the JWKS from an OIDC issuer's discovery
121/// document and caches the keys in memory. Verifies token signature,
122/// `iss`, `exp`, and (when `client_id` is configured) `aud`.
123#[derive(Debug)]
124pub struct OidcJwtVerifier {
125    issuer_url: String,
126    audience: Option<String>,
127    http: reqwest::Client,
128    cache: RwLock<JwksCache>,
129}
130
131#[derive(Debug, Default)]
132struct JwksCache {
133    jwks_uri: Option<String>,
134    /// `kid -> (DecodingKey, alg)` for keys advertised by the JWKS.
135    keys: HashMap<String, (DecodingKey, Algorithm)>,
136}
137
138#[derive(Debug, Deserialize)]
139struct DiscoveryDocument {
140    jwks_uri: String,
141}
142
143#[derive(Debug, Deserialize)]
144struct JwksDocument {
145    keys: Vec<jsonwebtoken::jwk::Jwk>,
146}
147
148impl OidcJwtVerifier {
149    /// Build a verifier from the [`AuthConfig`]. The HTTP client uses a
150    /// 5s timeout for discovery + JWKS fetches.
151    pub fn from_config(config: &AuthConfig) -> Result<Self> {
152        if config.issuer_url.trim().is_empty() {
153            return Err(anyhow!("AUTH_ISSUER_URL is required when AUTH_ENABLE=true"));
154        }
155        let http = reqwest::Client::builder()
156            .timeout(Duration::from_secs(5))
157            .build()
158            .map_err(|e| anyhow!("failed to build OIDC HTTP client: {e}"))?;
159        let audience = if config.client_id.trim().is_empty() {
160            None
161        } else {
162            Some(config.client_id.clone())
163        };
164        Ok(Self {
165            issuer_url: config.issuer_url.trim_end_matches('/').to_string(),
166            audience,
167            http,
168            cache: RwLock::new(JwksCache::default()),
169        })
170    }
171
172    /// Discovery URL: `<issuer>/.well-known/openid-configuration`.
173    fn discovery_url(&self) -> String {
174        format!("{}/.well-known/openid-configuration", self.issuer_url)
175    }
176
177    /// Returns the cached `jwks_uri`, fetching the discovery document
178    /// if the cache is cold.
179    async fn jwks_uri(&self) -> Result<String, AuthError> {
180        if let Some(uri) = self.cache.read().await.jwks_uri.clone() {
181            return Ok(uri);
182        }
183        let url = self.discovery_url();
184        debug!("fetching OIDC discovery document from {url}");
185        let doc: DiscoveryDocument = self
186            .http
187            .get(&url)
188            .send()
189            .await
190            .map_err(|e| AuthError::DiscoveryFailed(e.to_string()))?
191            .error_for_status()
192            .map_err(|e| AuthError::DiscoveryFailed(e.to_string()))?
193            .json()
194            .await
195            .map_err(|e| AuthError::DiscoveryFailed(e.to_string()))?;
196        let mut cache = self.cache.write().await;
197        cache.jwks_uri = Some(doc.jwks_uri.clone());
198        Ok(doc.jwks_uri)
199    }
200
201    /// Refresh the cached JWKS by fetching `jwks_uri`. Existing entries
202    /// are replaced wholesale.
203    async fn refresh_jwks(&self) -> Result<(), AuthError> {
204        let uri = self.jwks_uri().await?;
205        debug!("fetching JWKS from {uri}");
206        let doc: JwksDocument = self
207            .http
208            .get(&uri)
209            .send()
210            .await
211            .map_err(|e| AuthError::JwksFetchFailed(e.to_string()))?
212            .error_for_status()
213            .map_err(|e| AuthError::JwksFetchFailed(e.to_string()))?
214            .json()
215            .await
216            .map_err(|e| AuthError::JwksFetchFailed(e.to_string()))?;
217
218        let mut keys = HashMap::new();
219        for jwk in &doc.keys {
220            let Some(kid) = jwk.common.key_id.clone() else {
221                continue;
222            };
223            let Some(alg) = jwk.common.key_algorithm.and_then(map_key_algorithm) else {
224                continue;
225            };
226            let Ok(decoding) = DecodingKey::from_jwk(jwk) else {
227                continue;
228            };
229            keys.insert(kid, (decoding, alg));
230        }
231
232        let mut cache = self.cache.write().await;
233        cache.keys = keys;
234        Ok(())
235    }
236
237    async fn key_for_kid(&self, kid: &str) -> Option<(DecodingKey, Algorithm)> {
238        self.cache.read().await.keys.get(kid).cloned()
239    }
240}
241
242fn map_key_algorithm(alg: jsonwebtoken::jwk::KeyAlgorithm) -> Option<Algorithm> {
243    use jsonwebtoken::jwk::KeyAlgorithm as K;
244    match alg {
245        K::HS256 => Some(Algorithm::HS256),
246        K::HS384 => Some(Algorithm::HS384),
247        K::HS512 => Some(Algorithm::HS512),
248        K::ES256 => Some(Algorithm::ES256),
249        K::ES384 => Some(Algorithm::ES384),
250        K::RS256 => Some(Algorithm::RS256),
251        K::RS384 => Some(Algorithm::RS384),
252        K::RS512 => Some(Algorithm::RS512),
253        K::PS256 => Some(Algorithm::PS256),
254        K::PS384 => Some(Algorithm::PS384),
255        K::PS512 => Some(Algorithm::PS512),
256        K::EdDSA => Some(Algorithm::EdDSA),
257        _ => None,
258    }
259}
260
261#[async_trait]
262impl AuthVerifier for OidcJwtVerifier {
263    async fn verify(&self, token: &str) -> Result<AuthenticatedPrincipal, AuthError> {
264        if token.is_empty() {
265            return Err(AuthError::EmptyToken);
266        }
267        let header = decode_header(token).map_err(|e| AuthError::InvalidToken(e.to_string()))?;
268        let Some(kid) = header.kid else {
269            return Err(AuthError::InvalidToken(
270                "token header missing `kid`".to_string(),
271            ));
272        };
273
274        let key_and_alg = match self.key_for_kid(&kid).await {
275            Some(found) => found,
276            None => {
277                self.refresh_jwks().await?;
278                self.key_for_kid(&kid).await.ok_or(AuthError::UnknownKid)?
279            }
280        };
281        let (decoding, alg) = key_and_alg;
282
283        let mut validation = Validation::new(alg);
284        validation.set_issuer(&[self.issuer_url.as_str()]);
285        if let Some(aud) = self.audience.as_ref() {
286            validation.set_audience(&[aud.as_str()]);
287        } else {
288            validation.validate_aud = false;
289        }
290
291        let data = decode::<HashMap<String, Value>>(token, &decoding, &validation)
292            .map_err(|e| AuthError::InvalidToken(e.to_string()))?;
293
294        Ok(AuthenticatedPrincipal::from_claims(
295            self.issuer_url.clone(),
296            data.claims,
297        ))
298    }
299}
300
301/// Axum middleware that enforces a valid bearer token before allowing
302/// the request to reach the wrapped handler. The middleware is a
303/// no-op when [`AppState::auth_verifier`] is `None`, which is what the
304/// builder produces when `AuthConfig.enable == false`.
305pub(crate) async fn auth_middleware(
306    State(state): State<Arc<AppState>>,
307    mut req: Request,
308    next: Next,
309) -> Result<Response, Response> {
310    let Some(verifier) = state.auth_verifier.as_ref().cloned() else {
311        return Ok(next.run(req).await);
312    };
313
314    let raw = match req.headers().get(header::AUTHORIZATION) {
315        Some(value) => value
316            .to_str()
317            .map_err(|_| reject(AuthError::MalformedHeader))?,
318        None => return Err(reject(AuthError::MissingHeader)),
319    };
320    let Some(token) = raw
321        .strip_prefix("Bearer ")
322        .or_else(|| raw.strip_prefix("bearer "))
323    else {
324        return Err(reject(AuthError::MalformedHeader));
325    };
326    let token = token.trim();
327
328    match verifier.verify(token).await {
329        Ok(principal) => {
330            req.extensions_mut().insert(principal);
331            Ok(next.run(req).await)
332        }
333        Err(e) => {
334            warn!("auth middleware rejected request: {e}");
335            Err(reject(e))
336        }
337    }
338}
339
340fn reject(err: AuthError) -> Response {
341    let body = Json(serde_json::json!({
342        "error": "Unauthorized",
343        "message": err.to_string(),
344    }));
345    let mut response = (StatusCode::UNAUTHORIZED, body).into_response();
346    response.headers_mut().insert(
347        header::WWW_AUTHENTICATE,
348        axum::http::HeaderValue::from_static("Bearer realm=\"a2a\""),
349    );
350    let _ = std::marker::PhantomData::<Body>;
351    response
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that accepts exactly one token value.
    #[derive(Debug)]
    struct AcceptToken(&'static str);

    #[async_trait]
    impl AuthVerifier for AcceptToken {
        async fn verify(&self, token: &str) -> Result<AuthenticatedPrincipal, AuthError> {
            if token == self.0 {
                let mut claims = HashMap::new();
                claims.insert("sub".to_string(), Value::String("test-user".to_string()));
                claims.insert(
                    "tenant".to_string(),
                    Value::String("test-tenant".to_string()),
                );
                Ok(AuthenticatedPrincipal::from_claims(
                    "https://example.test".to_string(),
                    claims,
                ))
            } else {
                Err(AuthError::InvalidToken("nope".to_string()))
            }
        }
    }

    /// Builds an enabled [`AuthConfig`] with the given issuer and client id.
    fn auth_config(issuer_url: &str, client_id: &str) -> AuthConfig {
        AuthConfig {
            enable: true,
            issuer_url: issuer_url.to_string(),
            client_id: client_id.to_string(),
            client_secret: String::new(),
        }
    }

    /// Builds a claim map whose values are all JSON strings.
    fn string_claims(pairs: &[(&str, &str)]) -> HashMap<String, Value> {
        pairs
            .iter()
            .map(|(k, v)| ((*k).to_string(), Value::String((*v).to_string())))
            .collect()
    }

    #[tokio::test]
    async fn principal_extracts_known_claims() {
        let verifier = AcceptToken("good");
        let p = verifier.verify("good").await.expect("ok");
        assert_eq!(p.subject, "test-user");
        assert_eq!(p.tenant, "test-tenant");
        assert_eq!(p.issuer, "https://example.test");
        assert!(p.claims.contains_key("sub"));
    }

    #[tokio::test]
    async fn rejects_unknown_token() {
        let verifier = AcceptToken("good");
        let err = verifier.verify("bad").await.expect_err("must reject");
        assert!(matches!(err, AuthError::InvalidToken(_)));
    }

    #[test]
    fn tenant_falls_back_to_tid_then_organization() {
        let p = AuthenticatedPrincipal::from_claims(
            "https://example.test".to_string(),
            string_claims(&[("tid", "from-tid"), ("organization", "from-org")]),
        );
        assert_eq!(p.tenant, "from-tid");

        let p = AuthenticatedPrincipal::from_claims(
            "https://example.test".to_string(),
            string_claims(&[("organization", "from-org")]),
        );
        assert_eq!(p.tenant, "from-org");
    }

    #[test]
    fn missing_claims_yield_empty_fields() {
        let p = AuthenticatedPrincipal::from_claims("https://example.test".to_string(), HashMap::new());
        assert!(p.subject.is_empty());
        assert!(p.tenant.is_empty());
        assert!(p.claims.is_empty());
    }

    #[test]
    fn from_config_requires_issuer_url() {
        let err = OidcJwtVerifier::from_config(&auth_config("", "")).expect_err("issuer required");
        assert!(err.to_string().contains("AUTH_ISSUER_URL"));
    }

    #[tokio::test]
    async fn from_config_builds_discovery_url_without_double_slash() {
        let verifier = OidcJwtVerifier::from_config(&auth_config("https://idp.example/", ""))
            .expect("valid config");
        assert_eq!(
            verifier.discovery_url(),
            "https://idp.example/.well-known/openid-configuration"
        );
        assert!(verifier.audience.is_none());
    }

    #[tokio::test]
    async fn from_config_uses_client_id_as_audience() {
        let verifier = OidcJwtVerifier::from_config(&auth_config("https://idp.example", "a2a-client"))
            .expect("valid config");
        assert_eq!(verifier.audience.as_deref(), Some("a2a-client"));
    }

    #[tokio::test]
    async fn oidc_verify_rejects_bad_input_before_network() {
        let verifier = OidcJwtVerifier::from_config(&auth_config("https://idp.invalid", ""))
            .expect("valid config");
        assert!(matches!(verifier.verify("").await, Err(AuthError::EmptyToken)));
        assert!(matches!(
            verifier.verify("not-a-jwt").await,
            Err(AuthError::InvalidToken(_))
        ));
    }
}