1use a2a::A2AError;
4use async_trait::async_trait;
5use std::collections::HashMap;
6use std::sync::RwLock;
7
8use crate::middleware::CallInterceptor;
9use crate::transport::ServiceParams;
10
11#[async_trait]
13pub trait CredentialsStore: Send + Sync {
14 async fn get(&self, scheme: &str) -> Option<String>;
16}
17
/// Process-local credentials store backed by a `HashMap`.
///
/// Credentials are keyed by scheme name; setting a scheme again replaces the
/// previous credential. Nothing is persisted beyond the lifetime of the value.
pub struct InMemoryCredentialsStore {
    /// Scheme name -> credential. The `RwLock` lets lookups share the lock
    /// while `set` takes it exclusively.
    credentials: RwLock<HashMap<String, String>>,
}

impl InMemoryCredentialsStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        InMemoryCredentialsStore {
            credentials: RwLock::new(HashMap::new()),
        }
    }

    /// Registers `credential` for `scheme`, replacing any existing entry.
    ///
    /// A poisoned lock (left behind by a thread that panicked while holding
    /// it) is recovered instead of turning into a panic here. Within this
    /// type the map is only ever modified by a single `insert`, so a
    /// panicking holder cannot leave a half-applied update behind.
    pub fn set(&self, scheme: &str, credential: &str) {
        self.credentials
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .insert(scheme.to_string(), credential.to_string());
    }
}
37
38impl Default for InMemoryCredentialsStore {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44#[async_trait]
45impl CredentialsStore for InMemoryCredentialsStore {
46 async fn get(&self, scheme: &str) -> Option<String> {
47 self.credentials.read().unwrap().get(scheme).cloned()
48 }
49}
50
/// Interceptor that attaches one fixed authentication header to every call.
pub struct AuthInterceptor {
    /// Name of the header/parameter to set, e.g. `Authorization`.
    header_name: String,
    /// Complete header value. It contains the secret credential, so take care
    /// not to log it or expose it through a `Debug` implementation.
    header_value: String,
}

impl AuthInterceptor {
    /// Builds an interceptor that sends `Authorization: Bearer <token>`.
    pub fn bearer(token: impl Into<String>) -> Self {
        let value = format!("Bearer {}", token.into());
        Self::custom("Authorization", value)
    }

    /// Builds an interceptor that sends `header_name` with `header_value`
    /// exactly as given.
    pub fn custom(header_name: impl Into<String>, header_value: impl Into<String>) -> Self {
        Self {
            header_name: header_name.into(),
            header_value: header_value.into(),
        }
    }
}
74
75#[async_trait]
76impl CallInterceptor for AuthInterceptor {
77 async fn before(&self, _method: &str, params: &mut ServiceParams) -> Result<(), A2AError> {
78 params
79 .entry(self.header_name.clone())
80 .or_default()
81 .push(self.header_value.clone());
82 Ok(())
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of credentials currently held by `store`.
    fn stored_count(store: &InMemoryCredentialsStore) -> usize {
        store.credentials.read().unwrap().len()
    }

    /// Runs `interceptor.before` against fresh params and hands them back.
    async fn intercepted(interceptor: &AuthInterceptor) -> ServiceParams {
        let mut params = ServiceParams::new();
        interceptor.before("test", &mut params).await.unwrap();
        params
    }

    #[test]
    fn test_in_memory_credentials_store_new() {
        assert_eq!(stored_count(&InMemoryCredentialsStore::new()), 0);
    }

    #[test]
    fn test_in_memory_credentials_store_default() {
        assert_eq!(stored_count(&InMemoryCredentialsStore::default()), 0);
    }

    #[test]
    fn test_in_memory_credentials_store_set_get() {
        let store = InMemoryCredentialsStore::new();
        store.set("bearer", "token123");
        let stored = store.credentials.read().unwrap().get("bearer").cloned();
        assert_eq!(stored.as_deref(), Some("token123"));
    }

    #[tokio::test]
    async fn test_credentials_store_get() {
        let store = InMemoryCredentialsStore::new();
        store.set("api-key", "secret");
        let hit = store.get("api-key").await;
        let miss = store.get("nonexistent").await;
        assert_eq!(hit, Some("secret".to_string()));
        assert_eq!(miss, None);
    }

    #[tokio::test]
    async fn test_auth_interceptor_bearer() {
        let params = intercepted(&AuthInterceptor::bearer("mytoken")).await;
        let expected = vec!["Bearer mytoken".to_string()];
        assert_eq!(params.get("Authorization").unwrap(), &expected);
    }

    #[tokio::test]
    async fn test_auth_interceptor_custom() {
        let params = intercepted(&AuthInterceptor::custom("X-API-Key", "key123")).await;
        let expected = vec!["key123".to_string()];
        assert_eq!(params.get("X-API-Key").unwrap(), &expected);
    }
}