Skip to main content

a2a_rs/adapter/auth/
authenticator.rs

1//! Authentication adapter implementations
2
3use std::collections::HashMap;
4#[cfg(feature = "http-server")]
5use std::sync::Arc;
6
7use async_trait::async_trait;
8#[cfg(feature = "http-server")]
9use axum::{
10    extract::State,
11    http::{HeaderMap, Request, StatusCode},
12    middleware::Next,
13    response::Response,
14};
15
16#[cfg(not(feature = "http-server"))]
17type HeaderMap = std::collections::HashMap<String, String>;
18
19use crate::{
20    domain::{A2AError, core::agent::SecurityScheme},
21    port::authenticator::{AuthContext, AuthContextExtractor, AuthPrincipal, Authenticator},
22};
23
24/// HTTP Bearer token authenticator
25#[derive(Clone)]
26pub struct BearerTokenAuthenticator {
27    /// The valid tokens
28    tokens: Vec<String>,
29    /// The security scheme configuration
30    scheme: SecurityScheme,
31}
32
33impl BearerTokenAuthenticator {
34    /// Create a new bearer token authenticator with the given tokens
35    pub fn new(tokens: Vec<String>) -> Self {
36        Self {
37            tokens,
38            scheme: SecurityScheme::http(
39                "bearer".to_string(),
40                None,
41                Some("Bearer token authentication".to_string()),
42            ),
43        }
44    }
45
46    /// Create with a specific bearer format
47    pub fn with_format(tokens: Vec<String>, format: String) -> Self {
48        Self {
49            tokens,
50            scheme: SecurityScheme::http(
51                "bearer".to_string(),
52                Some(format),
53                Some("Bearer token authentication".to_string()),
54            ),
55        }
56    }
57}
58
59#[async_trait]
60impl Authenticator for BearerTokenAuthenticator {
61    async fn authenticate(&self, context: &AuthContext) -> Result<AuthPrincipal, A2AError> {
62        self.validate_context(context)?;
63
64        if self.tokens.contains(&context.credential) {
65            Ok(AuthPrincipal::new(
66                context.credential.clone(),
67                "bearer".to_string(),
68            ))
69        } else {
70            Err(A2AError::Internal(
71                "Invalid authentication token".to_string(),
72            ))
73        }
74    }
75
76    fn security_scheme(&self) -> &SecurityScheme {
77        &self.scheme
78    }
79
80    fn validate_context(&self, context: &AuthContext) -> Result<(), A2AError> {
81        if context.scheme_type != "bearer" {
82            return Err(A2AError::Internal(format!(
83                "Invalid authentication scheme: expected 'bearer', got '{}'",
84                context.scheme_type
85            )));
86        }
87        Ok(())
88    }
89}
90
91/// HTTP context extractor for Bearer tokens
92#[derive(Clone)]
93pub struct BearerTokenExtractor;
94
95#[async_trait]
96impl AuthContextExtractor for BearerTokenExtractor {
97    #[cfg(feature = "http-server")]
98    async fn extract_from_headers(&self, headers: &HeaderMap) -> Option<AuthContext> {
99        headers
100            .get(axum::http::header::AUTHORIZATION)
101            .and_then(|h| h.to_str().ok())
102            .and_then(|auth| {
103                let parts: Vec<&str> = auth.splitn(2, ' ').collect();
104                if parts.len() == 2 && parts[0].to_lowercase() == "bearer" {
105                    Some(AuthContext::new("bearer".to_string(), parts[1].to_string()))
106                } else {
107                    None
108                }
109            })
110    }
111
112    #[cfg(not(feature = "http-server"))]
113    async fn extract_from_headers(&self, headers: &HeaderMap) -> Option<AuthContext> {
114        headers
115            .get("authorization")
116            .or_else(|| headers.get("Authorization"))
117            .and_then(|auth| {
118                let parts: Vec<&str> = auth.splitn(2, ' ').collect();
119                if parts.len() == 2 && parts[0].to_lowercase() == "bearer" {
120                    Some(AuthContext::new("bearer".to_string(), parts[1].to_string()))
121                } else {
122                    None
123                }
124            })
125    }
126
127    async fn extract_from_query(&self, _params: &HashMap<String, String>) -> Option<AuthContext> {
128        // Bearer tokens are not typically passed in query parameters
129        None
130    }
131
132    async fn extract_from_cookies(&self, _cookies: &str) -> Option<AuthContext> {
133        // Bearer tokens are not typically passed in cookies
134        None
135    }
136}
137
138/// API Key authenticator
139#[derive(Clone)]
140pub struct ApiKeyAuthenticator {
141    /// Valid API keys
142    api_keys: Vec<String>,
143    /// The security scheme configuration
144    scheme: SecurityScheme,
145}
146
147impl ApiKeyAuthenticator {
148    /// Create a new API key authenticator
149    pub fn new(api_keys: Vec<String>, location: String, name: String) -> Self {
150        Self {
151            api_keys,
152            scheme: SecurityScheme::api_key(
153                name,
154                location,
155                Some("API key authentication".to_string()),
156            ),
157        }
158    }
159
160    /// Create for header-based API key
161    pub fn header(api_keys: Vec<String>, header_name: String) -> Self {
162        Self::new(api_keys, "header".to_string(), header_name)
163    }
164
165    /// Create for query parameter-based API key
166    pub fn query(api_keys: Vec<String>, param_name: String) -> Self {
167        Self::new(api_keys, "query".to_string(), param_name)
168    }
169
170    /// Create for cookie-based API key
171    pub fn cookie(api_keys: Vec<String>, cookie_name: String) -> Self {
172        Self::new(api_keys, "cookie".to_string(), cookie_name)
173    }
174}
175
176#[async_trait]
177impl Authenticator for ApiKeyAuthenticator {
178    async fn authenticate(&self, context: &AuthContext) -> Result<AuthPrincipal, A2AError> {
179        self.validate_context(context)?;
180
181        if self.api_keys.contains(&context.credential) {
182            Ok(
183                AuthPrincipal::new(context.credential.clone(), "apikey".to_string())
184                    .with_attribute(
185                        "location".to_string(),
186                        context
187                            .metadata
188                            .get("location")
189                            .unwrap_or(&String::new())
190                            .clone(),
191                    ),
192            )
193        } else {
194            Err(A2AError::Internal("Invalid API key".to_string()))
195        }
196    }
197
198    fn security_scheme(&self) -> &SecurityScheme {
199        &self.scheme
200    }
201
202    fn validate_context(&self, context: &AuthContext) -> Result<(), A2AError> {
203        if context.scheme_type != "apikey" {
204            return Err(A2AError::Internal(format!(
205                "Invalid authentication scheme: expected 'apikey', got '{}'",
206                context.scheme_type
207            )));
208        }
209        Ok(())
210    }
211}
212
213/// API Key context extractor
214#[derive(Clone)]
215pub struct ApiKeyExtractor {
216    location: String,
217    name: String,
218}
219
220impl ApiKeyExtractor {
221    pub fn new(location: String, name: String) -> Self {
222        Self { location, name }
223    }
224}
225
226#[async_trait]
227impl AuthContextExtractor for ApiKeyExtractor {
228    #[cfg(feature = "http-server")]
229    async fn extract_from_headers(&self, headers: &HeaderMap) -> Option<AuthContext> {
230        if self.location != "header" {
231            return None;
232        }
233
234        headers
235            .get(axum::http::HeaderName::from_bytes(self.name.as_bytes()).ok()?)
236            .and_then(|h| h.to_str().ok())
237            .map(|value| {
238                AuthContext::new("apikey".to_string(), value.to_string())
239                    .with_metadata("location".to_string(), "header".to_string())
240                    .with_metadata("name".to_string(), self.name.clone())
241            })
242    }
243
244    #[cfg(not(feature = "http-server"))]
245    async fn extract_from_headers(&self, headers: &HeaderMap) -> Option<AuthContext> {
246        if self.location != "header" {
247            return None;
248        }
249
250        headers.get(&self.name).map(|value| {
251            AuthContext::new("apikey".to_string(), value.clone())
252                .with_metadata("location".to_string(), "header".to_string())
253                .with_metadata("name".to_string(), self.name.clone())
254        })
255    }
256
257    async fn extract_from_query(&self, params: &HashMap<String, String>) -> Option<AuthContext> {
258        if self.location != "query" {
259            return None;
260        }
261
262        params.get(&self.name).map(|value| {
263            AuthContext::new("apikey".to_string(), value.clone())
264                .with_metadata("location".to_string(), "query".to_string())
265                .with_metadata("name".to_string(), self.name.clone())
266        })
267    }
268
269    async fn extract_from_cookies(&self, cookies: &str) -> Option<AuthContext> {
270        if self.location != "cookie" {
271            return None;
272        }
273
274        // Simple cookie parsing - in production, use a proper cookie parser
275        cookies
276            .split(';')
277            .map(|cookie| cookie.trim())
278            .find_map(|cookie| {
279                let parts: Vec<&str> = cookie.splitn(2, '=').collect();
280                if parts.len() == 2 && parts[0] == self.name {
281                    Some(
282                        AuthContext::new("apikey".to_string(), parts[1].to_string())
283                            .with_metadata("location".to_string(), "cookie".to_string())
284                            .with_metadata("name".to_string(), self.name.clone()),
285                    )
286                } else {
287                    None
288                }
289            })
290    }
291}
292
293/// No-op authenticator that allows all requests
294#[derive(Clone)]
295pub struct NoopAuthenticator {
296    scheme: SecurityScheme,
297}
298
299impl NoopAuthenticator {
300    /// Create a new no-op authenticator
301    pub fn new() -> Self {
302        Self {
303            scheme: SecurityScheme::http(
304                "none".to_string(),
305                None,
306                Some("No authentication required".to_string()),
307            ),
308        }
309    }
310}
311
312impl Default for NoopAuthenticator {
313    fn default() -> Self {
314        Self::new()
315    }
316}
317
318#[async_trait]
319impl Authenticator for NoopAuthenticator {
320    async fn authenticate(&self, _context: &AuthContext) -> Result<AuthPrincipal, A2AError> {
321        Ok(AuthPrincipal::new(
322            "anonymous".to_string(),
323            "none".to_string(),
324        ))
325    }
326
327    fn security_scheme(&self) -> &SecurityScheme {
328        &self.scheme
329    }
330
331    fn validate_context(&self, _context: &AuthContext) -> Result<(), A2AError> {
332        // No validation needed for no-op
333        Ok(())
334    }
335}
336
#[cfg(feature = "http-server")]
mod http_auth {
    use super::*;

    /// State handed to the Axum authentication middleware
    #[derive(Clone)]
    pub struct AuthState {
        /// Authenticator that judges extracted credentials
        authenticator: Arc<dyn Authenticator>,
        /// Extractors consulted, in order, to find credentials in a request
        extractors: Vec<Arc<dyn AuthContextExtractor>>,
    }

    impl AuthState {
        /// Build a state around `authenticator` that looks for credentials
        /// with a single `BearerTokenExtractor`.
        pub fn new(authenticator: impl Authenticator + 'static) -> Self {
            let bearer: Arc<dyn AuthContextExtractor> = Arc::new(BearerTokenExtractor);
            Self::with_extractors(authenticator, vec![bearer])
        }

        /// Build a state around `authenticator` with a caller-supplied list
        /// of extractors.
        #[allow(dead_code)]
        pub fn with_extractors(
            authenticator: impl Authenticator + 'static,
            extractors: Vec<Arc<dyn AuthContextExtractor>>,
        ) -> Self {
            let authenticator: Arc<dyn Authenticator> = Arc::new(authenticator);
            Self {
                authenticator,
                extractors,
            }
        }
    }

    /// Axum middleware that gates a request on header-based authentication.
    ///
    /// Extractors are asked in order for credentials in the request headers.
    /// The first extractor that yields a context decides the outcome: the
    /// request is forwarded when the authenticator accepts that context and
    /// rejected with `401 Unauthorized` when it does not. Later extractors
    /// are not consulted. A request for which no extractor finds credentials
    /// is also rejected with `401 Unauthorized`.
    pub async fn http_auth_middleware(
        State(state): State<AuthState>,
        req: Request<axum::body::Body>,
        next: Next,
    ) -> Result<Response, StatusCode> {
        let headers = req.headers();

        for extractor in &state.extractors {
            let context = match extractor.extract_from_headers(headers).await {
                Some(found) => found,
                // This extractor saw no credentials; let the next one look.
                None => continue,
            };

            // Credentials were presented, so their validity settles the request.
            if state.authenticator.authenticate(&context).await.is_err() {
                return Err(StatusCode::UNAUTHORIZED);
            }

            // The authenticated principal is not attached to the request;
            // the request is simply passed on.
            return Ok(next.run(req).await);
        }

        // No extractor found any credentials
        Err(StatusCode::UNAUTHORIZED)
    }

    /// Wrap `router` in the authentication middleware, using `authenticator`
    /// with the default extractor set of [`AuthState::new`].
    pub fn with_auth<R>(router: R, authenticator: impl Authenticator + 'static) -> axum::Router
    where
        R: Into<axum::Router>,
    {
        let router: axum::Router = router.into();
        let layer = axum::middleware::from_fn_with_state(
            AuthState::new(authenticator),
            http_auth_middleware,
        );

        router.layer(layer)
    }
}
416
417#[cfg(feature = "http-server")]
418pub use http_auth::with_auth;