a2a_protocol_client/
auth.rs1use std::collections::HashMap;
28use std::fmt;
29use std::sync::{Arc, RwLock};
30
31use crate::error::ClientResult;
32use crate::interceptor::{CallInterceptor, ClientRequest, ClientResponse};
33
/// Opaque identifier of a client session.
///
/// Credentials in a [`CredentialsStore`] are keyed first by this identifier,
/// so independent sessions can hold different credentials for the same scheme.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SessionId(String);

impl SessionId {
    /// Builds a session identifier from anything convertible into a `String`.
    #[must_use]
    pub fn new(s: impl Into<String>) -> Self {
        let id: String = s.into();
        SessionId(id)
    }
}

impl fmt::Display for SessionId {
    /// Writes the raw identifier text.
    ///
    /// `Formatter::write_str` bypasses width/alignment flags, so `{:>10}`
    /// prints the identifier unpadded.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.0.as_str())
    }
}

impl From<String> for SessionId {
    /// Takes ownership of `s` without copying.
    fn from(s: String) -> Self {
        Self::new(s)
    }
}

impl From<&str> for SessionId {
    /// Copies `s` into a newly allocated identifier.
    fn from(s: &str) -> Self {
        Self::new(s)
    }
}
68
69pub trait CredentialsStore: Send + Sync + 'static {
77 fn get(&self, session: &SessionId, scheme: &str) -> Option<String>;
79
80 fn set(&self, session: SessionId, scheme: &str, credential: String);
82
83 fn remove(&self, session: &SessionId, scheme: &str);
85}
86
87pub struct InMemoryCredentialsStore {
94 inner: RwLock<HashMap<SessionId, HashMap<String, String>>>,
95}
96
97impl InMemoryCredentialsStore {
98 #[must_use]
100 pub fn new() -> Self {
101 Self {
102 inner: RwLock::new(HashMap::new()),
103 }
104 }
105}
106
107impl Default for InMemoryCredentialsStore {
108 fn default() -> Self {
109 Self::new()
110 }
111}
112
113impl fmt::Debug for InMemoryCredentialsStore {
114 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
115 let count = self.inner.read().map(|g| g.len()).unwrap_or(0);
117 f.debug_struct("InMemoryCredentialsStore")
118 .field("sessions", &count)
119 .finish()
120 }
121}
122
123impl CredentialsStore for InMemoryCredentialsStore {
124 fn get(&self, session: &SessionId, scheme: &str) -> Option<String> {
125 self.inner.read().ok()?.get(session)?.get(scheme).cloned()
126 }
127
128 fn set(&self, session: SessionId, scheme: &str, credential: String) {
129 if let Ok(mut guard) = self.inner.write() {
130 guard
131 .entry(session)
132 .or_default()
133 .insert(scheme.to_owned(), credential);
134 }
135 }
136
137 fn remove(&self, session: &SessionId, scheme: &str) {
138 if let Ok(mut guard) = self.inner.write() {
139 if let Some(schemes) = guard.get_mut(session) {
140 schemes.remove(scheme);
141 }
142 }
143 }
144}
145
146pub struct AuthInterceptor {
160 store: Arc<dyn CredentialsStore>,
161 session: SessionId,
162 scheme: String,
164}
165
166impl AuthInterceptor {
167 #[must_use]
169 pub fn new(store: Arc<dyn CredentialsStore>, session: SessionId) -> Self {
170 Self {
171 store,
172 session,
173 scheme: "bearer".to_owned(),
174 }
175 }
176
177 #[must_use]
179 pub fn with_scheme(
180 store: Arc<dyn CredentialsStore>,
181 session: SessionId,
182 scheme: impl Into<String>,
183 ) -> Self {
184 Self {
185 store,
186 session,
187 scheme: scheme.into(),
188 }
189 }
190}
191
192#[allow(clippy::missing_fields_in_debug)]
193impl fmt::Debug for AuthInterceptor {
194 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195 f.debug_struct("AuthInterceptor")
197 .field("session", &self.session)
198 .field("scheme", &self.scheme)
199 .finish()
200 }
201}
202
203impl CallInterceptor for AuthInterceptor {
204 #[allow(clippy::manual_async_fn)]
205 fn before<'a>(
206 &'a self,
207 req: &'a mut ClientRequest,
208 ) -> impl std::future::Future<Output = ClientResult<()>> + Send + 'a {
209 async move {
210 if let Some(credential) = self.store.get(&self.session, &self.scheme) {
211 let header_value = if self.scheme.eq_ignore_ascii_case("bearer") {
212 format!("Bearer {credential}")
213 } else if self.scheme.eq_ignore_ascii_case("basic") {
214 format!("Basic {credential}")
215 } else {
216 credential
217 };
218 req.extra_headers
219 .insert("authorization".to_owned(), header_value);
220 }
221 Ok(())
222 }
223 }
224
225 #[allow(clippy::manual_async_fn)]
226 fn after<'a>(
227 &'a self,
228 _resp: &'a ClientResponse,
229 ) -> impl std::future::Future<Output = ClientResult<()>> + Send + 'a {
230 async move { Ok(()) }
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn credentials_store_set_get_remove() {
        let store = InMemoryCredentialsStore::new();
        let session = SessionId::new("sess-1");

        assert!(store.get(&session, "bearer").is_none());

        store.set(session.clone(), "bearer", "my-token".into());
        assert_eq!(store.get(&session, "bearer").as_deref(), Some("my-token"));

        store.remove(&session, "bearer");
        assert!(store.get(&session, "bearer").is_none());
    }

    #[test]
    fn credentials_store_set_overwrites() {
        let store = InMemoryCredentialsStore::new();
        let session = SessionId::new("sess-1");

        store.set(session.clone(), "bearer", "old".into());
        store.set(session.clone(), "bearer", "new".into());
        assert_eq!(store.get(&session, "bearer").as_deref(), Some("new"));
    }

    #[test]
    fn credentials_store_isolates_sessions_and_schemes() {
        let store = InMemoryCredentialsStore::new();
        let alice = SessionId::new("alice");
        let bob = SessionId::new("bob");

        store.set(alice.clone(), "bearer", "alice-token".into());
        assert!(store.get(&bob, "bearer").is_none());
        assert!(store.get(&alice, "basic").is_none());

        // Removing from an unknown session or scheme must be a harmless no-op.
        store.remove(&bob, "bearer");
        store.remove(&alice, "basic");
        assert_eq!(store.get(&alice, "bearer").as_deref(), Some("alice-token"));
    }

    #[tokio::test]
    async fn auth_interceptor_injects_bearer() {
        let store = Arc::new(InMemoryCredentialsStore::new());
        let session = SessionId::new("test");
        store.set(session.clone(), "bearer", "my-secret-token".into());

        let interceptor = AuthInterceptor::new(store, session);
        let mut req = ClientRequest::new("message/send", serde_json::Value::Null);

        interceptor.before(&mut req).await.unwrap();

        assert_eq!(
            req.extra_headers.get("authorization").map(String::as_str),
            Some("Bearer my-secret-token")
        );
    }

    #[tokio::test]
    async fn auth_interceptor_injects_basic() {
        let store = Arc::new(InMemoryCredentialsStore::new());
        let session = SessionId::new("test");
        store.set(session.clone(), "basic", "dXNlcjpwYXNz".into());

        let interceptor = AuthInterceptor::with_scheme(store, session, "basic");
        let mut req = ClientRequest::new("message/send", serde_json::Value::Null);

        interceptor.before(&mut req).await.unwrap();

        assert_eq!(
            req.extra_headers.get("authorization").map(String::as_str),
            Some("Basic dXNlcjpwYXNz")
        );
    }

    #[tokio::test]
    async fn auth_interceptor_scheme_prefix_is_case_insensitive() {
        let store = Arc::new(InMemoryCredentialsStore::new());
        let session = SessionId::new("test");
        // The scheme string is used verbatim as the store key.
        store.set(session.clone(), "BEARER", "tok".into());

        let interceptor = AuthInterceptor::with_scheme(store, session, "BEARER");
        let mut req = ClientRequest::new("message/send", serde_json::Value::Null);

        interceptor.before(&mut req).await.unwrap();

        assert_eq!(
            req.extra_headers.get("authorization").map(String::as_str),
            Some("Bearer tok")
        );
    }

    #[tokio::test]
    async fn auth_interceptor_custom_scheme_passes_credential_through() {
        let store = Arc::new(InMemoryCredentialsStore::new());
        let session = SessionId::new("test");
        store.set(session.clone(), "api-key", "raw-value".into());

        let interceptor = AuthInterceptor::with_scheme(store, session, "api-key");
        let mut req = ClientRequest::new("message/send", serde_json::Value::Null);

        interceptor.before(&mut req).await.unwrap();

        assert_eq!(
            req.extra_headers.get("authorization").map(String::as_str),
            Some("raw-value")
        );
    }

    #[tokio::test]
    async fn auth_interceptor_no_credential_no_header() {
        let store = Arc::new(InMemoryCredentialsStore::new());
        let session = SessionId::new("empty");
        let interceptor = AuthInterceptor::new(store, session);

        let mut req = ClientRequest::new("message/send", serde_json::Value::Null);
        interceptor.before(&mut req).await.unwrap();

        assert!(!req.extra_headers.contains_key("authorization"));
    }
}
282}