1use std::collections::HashMap;
4#[cfg(feature = "http-server")]
5use std::sync::Arc;
6
7use async_trait::async_trait;
8#[cfg(feature = "http-server")]
9use axum::{
10 extract::State,
11 http::{HeaderMap, Request, StatusCode},
12 middleware::Next,
13 response::Response,
14};
15
// Fallback header representation used when the `http-server` feature (and with it
// axum's `HeaderMap`) is disabled: a plain map from header name to header value.
// NOTE(review): keys are matched as stored, so callers must pick a casing convention.
#[cfg(not(feature = "http-server"))]
type HeaderMap = std::collections::HashMap<String, String>;
18
19use crate::{
20 domain::{A2AError, core::agent::SecurityScheme},
21 port::authenticator::{AuthContext, AuthContextExtractor, AuthPrincipal, Authenticator},
22};
23
24#[derive(Clone)]
26pub struct BearerTokenAuthenticator {
27 tokens: Vec<String>,
29 scheme: SecurityScheme,
31}
32
33impl BearerTokenAuthenticator {
34 pub fn new(tokens: Vec<String>) -> Self {
36 Self {
37 tokens,
38 scheme: SecurityScheme::http(
39 "bearer".to_string(),
40 None,
41 Some("Bearer token authentication".to_string()),
42 ),
43 }
44 }
45
46 pub fn with_format(tokens: Vec<String>, format: String) -> Self {
48 Self {
49 tokens,
50 scheme: SecurityScheme::http(
51 "bearer".to_string(),
52 Some(format),
53 Some("Bearer token authentication".to_string()),
54 ),
55 }
56 }
57}
58
59#[async_trait]
60impl Authenticator for BearerTokenAuthenticator {
61 async fn authenticate(&self, context: &AuthContext) -> Result<AuthPrincipal, A2AError> {
62 self.validate_context(context)?;
63
64 if self.tokens.contains(&context.credential) {
65 Ok(AuthPrincipal::new(
66 context.credential.clone(),
67 "bearer".to_string(),
68 ))
69 } else {
70 Err(A2AError::Internal(
71 "Invalid authentication token".to_string(),
72 ))
73 }
74 }
75
76 fn security_scheme(&self) -> &SecurityScheme {
77 &self.scheme
78 }
79
80 fn validate_context(&self, context: &AuthContext) -> Result<(), A2AError> {
81 if context.scheme_type != "bearer" {
82 return Err(A2AError::Internal(format!(
83 "Invalid authentication scheme: expected 'bearer', got '{}'",
84 context.scheme_type
85 )));
86 }
87 Ok(())
88 }
89}
90
91#[derive(Clone)]
93pub struct BearerTokenExtractor;
94
95#[async_trait]
96impl AuthContextExtractor for BearerTokenExtractor {
97 #[cfg(feature = "http-server")]
98 async fn extract_from_headers(&self, headers: &HeaderMap) -> Option<AuthContext> {
99 headers
100 .get(axum::http::header::AUTHORIZATION)
101 .and_then(|h| h.to_str().ok())
102 .and_then(|auth| {
103 let parts: Vec<&str> = auth.splitn(2, ' ').collect();
104 if parts.len() == 2 && parts[0].to_lowercase() == "bearer" {
105 Some(AuthContext::new("bearer".to_string(), parts[1].to_string()))
106 } else {
107 None
108 }
109 })
110 }
111
112 #[cfg(not(feature = "http-server"))]
113 async fn extract_from_headers(&self, headers: &HeaderMap) -> Option<AuthContext> {
114 headers
115 .get("authorization")
116 .or_else(|| headers.get("Authorization"))
117 .and_then(|auth| {
118 let parts: Vec<&str> = auth.splitn(2, ' ').collect();
119 if parts.len() == 2 && parts[0].to_lowercase() == "bearer" {
120 Some(AuthContext::new("bearer".to_string(), parts[1].to_string()))
121 } else {
122 None
123 }
124 })
125 }
126
127 async fn extract_from_query(&self, _params: &HashMap<String, String>) -> Option<AuthContext> {
128 None
130 }
131
132 async fn extract_from_cookies(&self, _cookies: &str) -> Option<AuthContext> {
133 None
135 }
136}
137
138#[derive(Clone)]
140pub struct ApiKeyAuthenticator {
141 api_keys: Vec<String>,
143 scheme: SecurityScheme,
145}
146
147impl ApiKeyAuthenticator {
148 pub fn new(api_keys: Vec<String>, location: String, name: String) -> Self {
150 Self {
151 api_keys,
152 scheme: SecurityScheme::api_key(
153 name,
154 location,
155 Some("API key authentication".to_string()),
156 ),
157 }
158 }
159
160 pub fn header(api_keys: Vec<String>, header_name: String) -> Self {
162 Self::new(api_keys, "header".to_string(), header_name)
163 }
164
165 pub fn query(api_keys: Vec<String>, param_name: String) -> Self {
167 Self::new(api_keys, "query".to_string(), param_name)
168 }
169
170 pub fn cookie(api_keys: Vec<String>, cookie_name: String) -> Self {
172 Self::new(api_keys, "cookie".to_string(), cookie_name)
173 }
174}
175
176#[async_trait]
177impl Authenticator for ApiKeyAuthenticator {
178 async fn authenticate(&self, context: &AuthContext) -> Result<AuthPrincipal, A2AError> {
179 self.validate_context(context)?;
180
181 if self.api_keys.contains(&context.credential) {
182 Ok(
183 AuthPrincipal::new(context.credential.clone(), "apikey".to_string())
184 .with_attribute(
185 "location".to_string(),
186 context
187 .metadata
188 .get("location")
189 .unwrap_or(&String::new())
190 .clone(),
191 ),
192 )
193 } else {
194 Err(A2AError::Internal("Invalid API key".to_string()))
195 }
196 }
197
198 fn security_scheme(&self) -> &SecurityScheme {
199 &self.scheme
200 }
201
202 fn validate_context(&self, context: &AuthContext) -> Result<(), A2AError> {
203 if context.scheme_type != "apikey" {
204 return Err(A2AError::Internal(format!(
205 "Invalid authentication scheme: expected 'apikey', got '{}'",
206 context.scheme_type
207 )));
208 }
209 Ok(())
210 }
211}
212
213#[derive(Clone)]
215pub struct ApiKeyExtractor {
216 location: String,
217 name: String,
218}
219
220impl ApiKeyExtractor {
221 pub fn new(location: String, name: String) -> Self {
222 Self { location, name }
223 }
224}
225
226#[async_trait]
227impl AuthContextExtractor for ApiKeyExtractor {
228 #[cfg(feature = "http-server")]
229 async fn extract_from_headers(&self, headers: &HeaderMap) -> Option<AuthContext> {
230 if self.location != "header" {
231 return None;
232 }
233
234 headers
235 .get(axum::http::HeaderName::from_bytes(self.name.as_bytes()).ok()?)
236 .and_then(|h| h.to_str().ok())
237 .map(|value| {
238 AuthContext::new("apikey".to_string(), value.to_string())
239 .with_metadata("location".to_string(), "header".to_string())
240 .with_metadata("name".to_string(), self.name.clone())
241 })
242 }
243
244 #[cfg(not(feature = "http-server"))]
245 async fn extract_from_headers(&self, headers: &HeaderMap) -> Option<AuthContext> {
246 if self.location != "header" {
247 return None;
248 }
249
250 headers.get(&self.name).map(|value| {
251 AuthContext::new("apikey".to_string(), value.clone())
252 .with_metadata("location".to_string(), "header".to_string())
253 .with_metadata("name".to_string(), self.name.clone())
254 })
255 }
256
257 async fn extract_from_query(&self, params: &HashMap<String, String>) -> Option<AuthContext> {
258 if self.location != "query" {
259 return None;
260 }
261
262 params.get(&self.name).map(|value| {
263 AuthContext::new("apikey".to_string(), value.clone())
264 .with_metadata("location".to_string(), "query".to_string())
265 .with_metadata("name".to_string(), self.name.clone())
266 })
267 }
268
269 async fn extract_from_cookies(&self, cookies: &str) -> Option<AuthContext> {
270 if self.location != "cookie" {
271 return None;
272 }
273
274 cookies
276 .split(';')
277 .map(|cookie| cookie.trim())
278 .find_map(|cookie| {
279 let parts: Vec<&str> = cookie.splitn(2, '=').collect();
280 if parts.len() == 2 && parts[0] == self.name {
281 Some(
282 AuthContext::new("apikey".to_string(), parts[1].to_string())
283 .with_metadata("location".to_string(), "cookie".to_string())
284 .with_metadata("name".to_string(), self.name.clone()),
285 )
286 } else {
287 None
288 }
289 })
290 }
291}
292
293#[derive(Clone)]
295pub struct NoopAuthenticator {
296 scheme: SecurityScheme,
297}
298
299impl NoopAuthenticator {
300 pub fn new() -> Self {
302 Self {
303 scheme: SecurityScheme::http(
304 "none".to_string(),
305 None,
306 Some("No authentication required".to_string()),
307 ),
308 }
309 }
310}
311
312impl Default for NoopAuthenticator {
313 fn default() -> Self {
314 Self::new()
315 }
316}
317
318#[async_trait]
319impl Authenticator for NoopAuthenticator {
320 async fn authenticate(&self, _context: &AuthContext) -> Result<AuthPrincipal, A2AError> {
321 Ok(AuthPrincipal::new(
322 "anonymous".to_string(),
323 "none".to_string(),
324 ))
325 }
326
327 fn security_scheme(&self) -> &SecurityScheme {
328 &self.scheme
329 }
330
331 fn validate_context(&self, _context: &AuthContext) -> Result<(), A2AError> {
332 Ok(())
334 }
335}
336
#[cfg(feature = "http-server")]
mod http_auth {
    use super::*;

    /// State shared with [`http_auth_middleware`]: the authenticator that validates
    /// credentials and the ordered list of extractors used to find them.
    #[derive(Clone)]
    pub struct AuthState {
        authenticator: Arc<dyn Authenticator>,
        extractors: Vec<Arc<dyn AuthContextExtractor>>,
    }

    impl AuthState {
        /// Creates state that only recognizes `Authorization: Bearer <token>` credentials.
        pub fn new(authenticator: impl Authenticator + 'static) -> Self {
            let bearer: Arc<dyn AuthContextExtractor> = Arc::new(BearerTokenExtractor);
            Self::with_extractors(authenticator, vec![bearer])
        }

        /// Creates state with a custom, ordered list of credential extractors.
        pub fn with_extractors(
            authenticator: impl Authenticator + 'static,
            extractors: Vec<Arc<dyn AuthContextExtractor>>,
        ) -> Self {
            Self {
                authenticator: Arc::new(authenticator),
                extractors,
            }
        }
    }

    /// Axum middleware that lets a request through only if it carries a valid credential.
    ///
    /// Each extractor is consulted in order. For each one, the request headers are
    /// checked first, then the query parameters, then the cookies. The first
    /// credential found goes to the authenticator. On success the request continues
    /// to `next`. If authentication fails, or no credential is found,
    /// `401 Unauthorized` is returned.
    pub async fn http_auth_middleware(
        State(state): State<AuthState>,
        req: Request<axum::body::Body>,
        next: Next,
    ) -> Result<Response, StatusCode> {
        // A malformed query string contributes no credentials rather than failing
        // the request on its own.
        let query: HashMap<String, String> =
            axum::extract::Query::<HashMap<String, String>>::try_from_uri(req.uri())
                .map(|params| params.0)
                .unwrap_or_default();
        let cookies = collect_cookies(req.headers());

        let context = find_context(&state.extractors, req.headers(), &query, &cookies)
            .await
            .ok_or(StatusCode::UNAUTHORIZED)?;

        state
            .authenticator
            .authenticate(&context)
            .await
            .map_err(|_| StatusCode::UNAUTHORIZED)?;

        Ok(next.run(req).await)
    }

    /// Returns the first credential any extractor yields. For each extractor in turn
    /// it checks headers, then query parameters, then cookies.
    async fn find_context(
        extractors: &[Arc<dyn AuthContextExtractor>],
        headers: &HeaderMap,
        query: &HashMap<String, String>,
        cookies: &str,
    ) -> Option<AuthContext> {
        for extractor in extractors {
            if let Some(context) = extractor.extract_from_headers(headers).await {
                return Some(context);
            }
            if let Some(context) = extractor.extract_from_query(query).await {
                return Some(context);
            }
            if let Some(context) = extractor.extract_from_cookies(cookies).await {
                return Some(context);
            }
        }
        None
    }

    /// Joins every `Cookie` header value into one `name=value; …` string.
    ///
    /// HTTP/2 clients may split cookies across several `Cookie` fields (RFC 9113
    /// §8.2.3). Values that are not visible ASCII are skipped.
    fn collect_cookies(headers: &HeaderMap) -> String {
        headers
            .get_all(axum::http::header::COOKIE)
            .iter()
            .filter_map(|value| value.to_str().ok())
            .collect::<Vec<_>>()
            .join("; ")
    }

    /// Wraps `router` so every route requires a credential accepted by `authenticator`.
    /// Only `Authorization: Bearer` credentials are recognized.
    pub fn with_auth<R>(router: R, authenticator: impl Authenticator + 'static) -> axum::Router
    where
        R: Into<axum::Router>,
    {
        let auth_state = AuthState::new(authenticator);
        let router = router.into();

        router.layer(axum::middleware::from_fn_with_state(
            auth_state,
            http_auth_middleware,
        ))
    }
}
416
417#[cfg(feature = "http-server")]
418pub use http_auth::with_auth;