Skip to main content

a2a_client/
auth.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3use a2a::A2AError;
4use async_trait::async_trait;
5use std::collections::HashMap;
6use std::sync::RwLock;
7
8use crate::middleware::CallInterceptor;
9use crate::transport::ServiceParams;
10
11/// Trait for providing credentials for authentication.
12#[async_trait]
13pub trait CredentialsStore: Send + Sync {
14    /// Get credentials for the given scheme name.
15    async fn get(&self, scheme: &str) -> Option<String>;
16}
17
/// Simple in-memory credentials store.
///
/// Credentials are kept in a map keyed by scheme name and guarded by an
/// [`RwLock`], so the store can be read and updated through a shared reference.
pub struct InMemoryCredentialsStore {
    /// Scheme name -> credential value.
    credentials: RwLock<HashMap<String, String>>,
}

impl InMemoryCredentialsStore {
    /// Creates an empty store.
    #[must_use]
    pub fn new() -> Self {
        InMemoryCredentialsStore {
            credentials: RwLock::new(HashMap::new()),
        }
    }

    /// Stores `credential` under `scheme`, replacing any previous value for that scheme.
    ///
    /// If the lock was poisoned by a panic in another thread, the poison is
    /// ignored rather than turned into a new panic. Entries in the map are
    /// independent key/value pairs with no cross-entry invariant, so the data
    /// remains usable after such a panic.
    pub fn set(&self, scheme: &str, credential: &str) {
        self.credentials
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .insert(scheme.to_string(), credential.to_string());
    }
}

impl Default for InMemoryCredentialsStore {
    /// Equivalent to [`InMemoryCredentialsStore::new`]: an empty store.
    fn default() -> Self {
        Self::new()
    }
}
43
44#[async_trait]
45impl CredentialsStore for InMemoryCredentialsStore {
46    async fn get(&self, scheme: &str) -> Option<String> {
47        self.credentials.read().unwrap().get(scheme).cloned()
48    }
49}
50
/// Interceptor that injects a fixed authentication header into every request.
pub struct AuthInterceptor {
    /// Name of the header to add (e.g. `Authorization`).
    header_name: String,
    /// Full value written into that header.
    header_value: String,
}

impl AuthInterceptor {
    /// Builds an interceptor that sends `Authorization: Bearer <token>`.
    pub fn bearer(token: impl Into<String>) -> Self {
        let value = format!("Bearer {}", token.into());
        Self::custom("Authorization", value)
    }

    /// Builds an interceptor that sends an arbitrary `header_name: header_value` pair.
    pub fn custom(header_name: impl Into<String>, header_value: impl Into<String>) -> Self {
        Self {
            header_name: header_name.into(),
            header_value: header_value.into(),
        }
    }
}
74
75#[async_trait]
76impl CallInterceptor for AuthInterceptor {
77    async fn before(&self, _method: &str, params: &mut ServiceParams) -> Result<(), A2AError> {
78        params
79            .entry(self.header_name.clone())
80            .or_default()
81            .push(self.header_value.clone());
82        Ok(())
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;

    /// Copies the raw credential map out of `store` for white-box assertions.
    fn snapshot(store: &InMemoryCredentialsStore) -> HashMap<String, String> {
        store.credentials.read().unwrap().clone()
    }

    /// Runs `interceptor.before` against a fresh parameter set and returns it.
    async fn intercept(interceptor: &AuthInterceptor) -> ServiceParams {
        let mut params = ServiceParams::new();
        interceptor.before("test", &mut params).await.unwrap();
        params
    }

    #[test]
    fn test_in_memory_credentials_store_new() {
        assert!(snapshot(&InMemoryCredentialsStore::new()).is_empty());
    }

    #[test]
    fn test_in_memory_credentials_store_default() {
        assert!(snapshot(&InMemoryCredentialsStore::default()).is_empty());
    }

    #[test]
    fn test_in_memory_credentials_store_set_get() {
        let store = InMemoryCredentialsStore::new();
        store.set("bearer", "token123");
        let map = snapshot(&store);
        assert_eq!(map.get("bearer").map(String::as_str), Some("token123"));
    }

    #[tokio::test]
    async fn test_credentials_store_get() {
        let store = InMemoryCredentialsStore::new();
        store.set("api-key", "secret");
        let hit = store.get("api-key").await;
        let miss = store.get("nonexistent").await;
        assert_eq!(hit.as_deref(), Some("secret"));
        assert!(miss.is_none());
    }

    #[tokio::test]
    async fn test_auth_interceptor_bearer() {
        let params = intercept(&AuthInterceptor::bearer("mytoken")).await;
        let expected = vec!["Bearer mytoken".to_string()];
        assert_eq!(params.get("Authorization").unwrap(), &expected);
    }

    #[tokio::test]
    async fn test_auth_interceptor_custom() {
        let params = intercept(&AuthInterceptor::custom("X-API-Key", "key123")).await;
        let expected = vec!["key123".to_string()];
        assert_eq!(params.get("X-API-Key").unwrap(), &expected);
    }
}