//! Bearer-token authentication for the server: an OIDC/JWKS-backed JWT verifier
//! and the axum middleware that enforces it.
//!
//! Source path: `inference_gateway_adk/server/auth.rs`.

use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};

use anyhow::{Result, anyhow};
use async_trait::async_trait;
use axum::{
    body::Body,
    extract::{Request, State},
    http::{StatusCode, header},
    middleware::Next,
    response::{IntoResponse, Json, Response},
};
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::RwLock;
use tracing::{debug, warn};

use crate::config::AuthConfig;

use super::protocol::AppState;
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct AuthenticatedPrincipal {
51 pub subject: String,
53 pub tenant: String,
58 pub issuer: String,
61 pub claims: HashMap<String, Value>,
64}
65
66impl AuthenticatedPrincipal {
67 fn from_claims(issuer: String, claims: HashMap<String, Value>) -> Self {
68 let subject = claims
69 .get("sub")
70 .and_then(|v| v.as_str())
71 .unwrap_or_default()
72 .to_string();
73 let tenant = ["tenant", "tid", "organization"]
74 .iter()
75 .find_map(|k| claims.get(*k).and_then(|v| v.as_str()))
76 .unwrap_or_default()
77 .to_string();
78 Self {
79 subject,
80 tenant,
81 issuer,
82 claims,
83 }
84 }
85}
86
87#[derive(Debug, thiserror::Error)]
91pub enum AuthError {
92 #[error("missing Authorization header")]
93 MissingHeader,
94 #[error("Authorization header must use the Bearer scheme")]
95 MalformedHeader,
96 #[error("token is empty")]
97 EmptyToken,
98 #[error("token validation failed: {0}")]
99 InvalidToken(String),
100 #[error("OIDC discovery failed: {0}")]
101 DiscoveryFailed(String),
102 #[error("JWKS fetch failed: {0}")]
103 JwksFetchFailed(String),
104 #[error("signing key not found for token")]
105 UnknownKid,
106 #[error("internal auth error: {0}")]
107 Internal(String),
108}
109
110#[async_trait]
114pub trait AuthVerifier: Send + Sync + std::fmt::Debug {
115 async fn verify(&self, token: &str) -> Result<AuthenticatedPrincipal, AuthError>;
118}
119
120#[derive(Debug)]
124pub struct OidcJwtVerifier {
125 issuer_url: String,
126 audience: Option<String>,
127 http: reqwest::Client,
128 cache: RwLock<JwksCache>,
129}
130
131#[derive(Debug, Default)]
132struct JwksCache {
133 jwks_uri: Option<String>,
134 keys: HashMap<String, (DecodingKey, Algorithm)>,
136}
137
138#[derive(Debug, Deserialize)]
139struct DiscoveryDocument {
140 jwks_uri: String,
141}
142
143#[derive(Debug, Deserialize)]
144struct JwksDocument {
145 keys: Vec<jsonwebtoken::jwk::Jwk>,
146}
147
148impl OidcJwtVerifier {
149 pub fn from_config(config: &AuthConfig) -> Result<Self> {
152 if config.issuer_url.trim().is_empty() {
153 return Err(anyhow!("AUTH_ISSUER_URL is required when AUTH_ENABLE=true"));
154 }
155 let http = reqwest::Client::builder()
156 .timeout(Duration::from_secs(5))
157 .build()
158 .map_err(|e| anyhow!("failed to build OIDC HTTP client: {e}"))?;
159 let audience = if config.client_id.trim().is_empty() {
160 None
161 } else {
162 Some(config.client_id.clone())
163 };
164 Ok(Self {
165 issuer_url: config.issuer_url.trim_end_matches('/').to_string(),
166 audience,
167 http,
168 cache: RwLock::new(JwksCache::default()),
169 })
170 }
171
172 fn discovery_url(&self) -> String {
174 format!("{}/.well-known/openid-configuration", self.issuer_url)
175 }
176
177 async fn jwks_uri(&self) -> Result<String, AuthError> {
180 if let Some(uri) = self.cache.read().await.jwks_uri.clone() {
181 return Ok(uri);
182 }
183 let url = self.discovery_url();
184 debug!("fetching OIDC discovery document from {url}");
185 let doc: DiscoveryDocument = self
186 .http
187 .get(&url)
188 .send()
189 .await
190 .map_err(|e| AuthError::DiscoveryFailed(e.to_string()))?
191 .error_for_status()
192 .map_err(|e| AuthError::DiscoveryFailed(e.to_string()))?
193 .json()
194 .await
195 .map_err(|e| AuthError::DiscoveryFailed(e.to_string()))?;
196 let mut cache = self.cache.write().await;
197 cache.jwks_uri = Some(doc.jwks_uri.clone());
198 Ok(doc.jwks_uri)
199 }
200
201 async fn refresh_jwks(&self) -> Result<(), AuthError> {
204 let uri = self.jwks_uri().await?;
205 debug!("fetching JWKS from {uri}");
206 let doc: JwksDocument = self
207 .http
208 .get(&uri)
209 .send()
210 .await
211 .map_err(|e| AuthError::JwksFetchFailed(e.to_string()))?
212 .error_for_status()
213 .map_err(|e| AuthError::JwksFetchFailed(e.to_string()))?
214 .json()
215 .await
216 .map_err(|e| AuthError::JwksFetchFailed(e.to_string()))?;
217
218 let mut keys = HashMap::new();
219 for jwk in &doc.keys {
220 let Some(kid) = jwk.common.key_id.clone() else {
221 continue;
222 };
223 let Some(alg) = jwk.common.key_algorithm.and_then(map_key_algorithm) else {
224 continue;
225 };
226 let Ok(decoding) = DecodingKey::from_jwk(jwk) else {
227 continue;
228 };
229 keys.insert(kid, (decoding, alg));
230 }
231
232 let mut cache = self.cache.write().await;
233 cache.keys = keys;
234 Ok(())
235 }
236
237 async fn key_for_kid(&self, kid: &str) -> Option<(DecodingKey, Algorithm)> {
238 self.cache.read().await.keys.get(kid).cloned()
239 }
240}
241
242fn map_key_algorithm(alg: jsonwebtoken::jwk::KeyAlgorithm) -> Option<Algorithm> {
243 use jsonwebtoken::jwk::KeyAlgorithm as K;
244 match alg {
245 K::HS256 => Some(Algorithm::HS256),
246 K::HS384 => Some(Algorithm::HS384),
247 K::HS512 => Some(Algorithm::HS512),
248 K::ES256 => Some(Algorithm::ES256),
249 K::ES384 => Some(Algorithm::ES384),
250 K::RS256 => Some(Algorithm::RS256),
251 K::RS384 => Some(Algorithm::RS384),
252 K::RS512 => Some(Algorithm::RS512),
253 K::PS256 => Some(Algorithm::PS256),
254 K::PS384 => Some(Algorithm::PS384),
255 K::PS512 => Some(Algorithm::PS512),
256 K::EdDSA => Some(Algorithm::EdDSA),
257 _ => None,
258 }
259}
260
261#[async_trait]
262impl AuthVerifier for OidcJwtVerifier {
263 async fn verify(&self, token: &str) -> Result<AuthenticatedPrincipal, AuthError> {
264 if token.is_empty() {
265 return Err(AuthError::EmptyToken);
266 }
267 let header = decode_header(token).map_err(|e| AuthError::InvalidToken(e.to_string()))?;
268 let Some(kid) = header.kid else {
269 return Err(AuthError::InvalidToken(
270 "token header missing `kid`".to_string(),
271 ));
272 };
273
274 let key_and_alg = match self.key_for_kid(&kid).await {
275 Some(found) => found,
276 None => {
277 self.refresh_jwks().await?;
278 self.key_for_kid(&kid).await.ok_or(AuthError::UnknownKid)?
279 }
280 };
281 let (decoding, alg) = key_and_alg;
282
283 let mut validation = Validation::new(alg);
284 validation.set_issuer(&[self.issuer_url.as_str()]);
285 if let Some(aud) = self.audience.as_ref() {
286 validation.set_audience(&[aud.as_str()]);
287 } else {
288 validation.validate_aud = false;
289 }
290
291 let data = decode::<HashMap<String, Value>>(token, &decoding, &validation)
292 .map_err(|e| AuthError::InvalidToken(e.to_string()))?;
293
294 Ok(AuthenticatedPrincipal::from_claims(
295 self.issuer_url.clone(),
296 data.claims,
297 ))
298 }
299}
300
301pub(crate) async fn auth_middleware(
306 State(state): State<Arc<AppState>>,
307 mut req: Request,
308 next: Next,
309) -> Result<Response, Response> {
310 let Some(verifier) = state.auth_verifier.as_ref().cloned() else {
311 return Ok(next.run(req).await);
312 };
313
314 let raw = match req.headers().get(header::AUTHORIZATION) {
315 Some(value) => value
316 .to_str()
317 .map_err(|_| reject(AuthError::MalformedHeader))?,
318 None => return Err(reject(AuthError::MissingHeader)),
319 };
320 let Some(token) = raw
321 .strip_prefix("Bearer ")
322 .or_else(|| raw.strip_prefix("bearer "))
323 else {
324 return Err(reject(AuthError::MalformedHeader));
325 };
326 let token = token.trim();
327
328 match verifier.verify(token).await {
329 Ok(principal) => {
330 req.extensions_mut().insert(principal);
331 Ok(next.run(req).await)
332 }
333 Err(e) => {
334 warn!("auth middleware rejected request: {e}");
335 Err(reject(e))
336 }
337 }
338}
339
340fn reject(err: AuthError) -> Response {
341 let body = Json(serde_json::json!({
342 "error": "Unauthorized",
343 "message": err.to_string(),
344 }));
345 let mut response = (StatusCode::UNAUTHORIZED, body).into_response();
346 response.headers_mut().insert(
347 header::WWW_AUTHENTICATE,
348 axum::http::HeaderValue::from_static("Bearer realm=\"a2a\""),
349 );
350 let _ = std::marker::PhantomData::<Body>;
351 response
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Stub verifier that accepts exactly one fixed token.
    #[derive(Debug)]
    struct AcceptToken(&'static str);

    #[async_trait]
    impl AuthVerifier for AcceptToken {
        async fn verify(&self, token: &str) -> Result<AuthenticatedPrincipal, AuthError> {
            if token != self.0 {
                return Err(AuthError::InvalidToken("nope".to_string()));
            }
            let claims: HashMap<String, Value> = [("sub", "test-user"), ("tenant", "test-tenant")]
                .into_iter()
                .map(|(name, value)| (name.to_string(), Value::String(value.to_string())))
                .collect();
            Ok(AuthenticatedPrincipal::from_claims(
                "https://example.test".to_string(),
                claims,
            ))
        }
    }

    #[tokio::test]
    async fn principal_extracts_known_claims() {
        let principal = AcceptToken("good").verify("good").await.expect("ok");
        assert_eq!(principal.subject, "test-user");
        assert_eq!(principal.tenant, "test-tenant");
        assert_eq!(principal.issuer, "https://example.test");
        assert!(principal.claims.contains_key("sub"));
    }

    #[tokio::test]
    async fn rejects_unknown_token() {
        let error = AcceptToken("good")
            .verify("bad")
            .await
            .expect_err("must reject");
        assert!(matches!(error, AuthError::InvalidToken(_)));
    }

    #[test]
    fn from_config_requires_issuer_url() {
        let config = AuthConfig {
            enable: true,
            issuer_url: String::new(),
            client_id: String::new(),
            client_secret: String::new(),
        };
        let error = OidcJwtVerifier::from_config(&config).expect_err("issuer required");
        assert!(error.to_string().contains("AUTH_ISSUER_URL"));
    }
}