Skip to main content

a2a_protocol_client/
auth.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2026 Tom F.
3
4//! Authentication interceptor and credential storage.
5//!
6//! [`AuthInterceptor`] injects `Authorization` headers from a
7//! [`CredentialsStore`] before each request. [`InMemoryCredentialsStore`]
8//! provides a simple in-process credential store.
9//!
10//! # Usage
11//!
12//! ```rust,no_run
13//! use std::sync::Arc;
14//! use a2a_protocol_client::auth::{
15//!     InMemoryCredentialsStore, AuthInterceptor, SessionId, CredentialsStore,
16//! };
17//! use a2a_protocol_client::ClientBuilder;
18//!
19//! let store = Arc::new(InMemoryCredentialsStore::new());
20//! let session = SessionId::new("my-session");
21//! store.set(session.clone(), "bearer", "my-token".into());
22//!
23//! let _builder = ClientBuilder::new("http://localhost:8080")
24//!     .with_interceptor(AuthInterceptor::new(store, session));
25//! ```
26
27use std::collections::HashMap;
28use std::fmt;
29use std::sync::{Arc, RwLock};
30
31use crate::error::ClientResult;
32use crate::interceptor::{CallInterceptor, ClientRequest, ClientResponse};
33
34// ── SessionId ─────────────────────────────────────────────────────────────────
35
/// Opaque label naming one client authentication session.
///
/// A session partitions a credential store, so one store can hold independent
/// credentials for several client instances at the same time.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SessionId(String);

impl SessionId {
    /// Builds a [`SessionId`] from anything convertible into a `String`.
    #[must_use]
    pub fn new(s: impl Into<String>) -> Self {
        let id: String = s.into();
        Self(id)
    }
}

impl fmt::Display for SessionId {
    /// Writes the raw session string with no decoration.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(id) = self;
        f.write_str(id)
    }
}

impl From<String> for SessionId {
    fn from(s: String) -> Self {
        Self::new(s)
    }
}

impl From<&str> for SessionId {
    fn from(s: &str) -> Self {
        Self::new(s)
    }
}
68
69// ── CredentialsStore ──────────────────────────────────────────────────────────
70
71/// Persistent storage for auth credentials, keyed by session + scheme.
72///
73/// Schemes follow the A2A / HTTP convention: `"bearer"`, `"basic"`,
74/// `"api-key"`, etc. The stored value is the raw credential (e.g. the raw
75/// token string, not including the scheme prefix).
76pub trait CredentialsStore: Send + Sync + 'static {
77    /// Returns the credential for the given session and scheme, if present.
78    fn get(&self, session: &SessionId, scheme: &str) -> Option<String>;
79
80    /// Stores a credential for the given session and scheme.
81    fn set(&self, session: SessionId, scheme: &str, credential: String);
82
83    /// Removes the credential for the given session and scheme.
84    fn remove(&self, session: &SessionId, scheme: &str);
85}
86
87// ── InMemoryCredentialsStore ──────────────────────────────────────────────────
88
89/// An in-memory [`CredentialsStore`] backed by an `RwLock<HashMap>`.
90///
91/// Suitable for single-process deployments. Credentials are lost when the
92/// process exits.
93pub struct InMemoryCredentialsStore {
94    inner: RwLock<HashMap<SessionId, HashMap<String, String>>>,
95}
96
97impl InMemoryCredentialsStore {
98    /// Creates an empty credential store.
99    #[must_use]
100    pub fn new() -> Self {
101        Self {
102            inner: RwLock::new(HashMap::new()),
103        }
104    }
105}
106
107impl Default for InMemoryCredentialsStore {
108    fn default() -> Self {
109        Self::new()
110    }
111}
112
113impl fmt::Debug for InMemoryCredentialsStore {
114    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
115        // Don't expose credential values in debug output.
116        let count = self.inner.read().map(|g| g.len()).unwrap_or(0);
117        f.debug_struct("InMemoryCredentialsStore")
118            .field("sessions", &count)
119            .finish()
120    }
121}
122
123impl CredentialsStore for InMemoryCredentialsStore {
124    fn get(&self, session: &SessionId, scheme: &str) -> Option<String> {
125        self.inner.read().ok()?.get(session)?.get(scheme).cloned()
126    }
127
128    fn set(&self, session: SessionId, scheme: &str, credential: String) {
129        if let Ok(mut guard) = self.inner.write() {
130            guard
131                .entry(session)
132                .or_default()
133                .insert(scheme.to_owned(), credential);
134        }
135    }
136
137    fn remove(&self, session: &SessionId, scheme: &str) {
138        if let Ok(mut guard) = self.inner.write() {
139            if let Some(schemes) = guard.get_mut(session) {
140                schemes.remove(scheme);
141            }
142        }
143    }
144}
145
146// ── AuthInterceptor ───────────────────────────────────────────────────────────
147
148/// A [`CallInterceptor`] that injects `Authorization` headers from a
149/// [`CredentialsStore`].
150///
151/// On each `before()` call it looks up the credential for the current session
152/// using the configured scheme (default: `"bearer"`). If found, it adds:
153///
154/// ```text
155/// Authorization: Bearer <token>
156/// ```
157///
158/// to `req.extra_headers`.
159pub struct AuthInterceptor {
160    store: Arc<dyn CredentialsStore>,
161    session: SessionId,
162    /// The auth scheme to look up (e.g. `"bearer"`, `"api-key"`).
163    scheme: String,
164}
165
166impl AuthInterceptor {
167    /// Creates an [`AuthInterceptor`] that injects bearer tokens.
168    #[must_use]
169    pub fn new(store: Arc<dyn CredentialsStore>, session: SessionId) -> Self {
170        Self {
171            store,
172            session,
173            scheme: "bearer".to_owned(),
174        }
175    }
176
177    /// Creates an [`AuthInterceptor`] with a custom auth scheme.
178    #[must_use]
179    pub fn with_scheme(
180        store: Arc<dyn CredentialsStore>,
181        session: SessionId,
182        scheme: impl Into<String>,
183    ) -> Self {
184        Self {
185            store,
186            session,
187            scheme: scheme.into(),
188        }
189    }
190}
191
192#[allow(clippy::missing_fields_in_debug)]
193impl fmt::Debug for AuthInterceptor {
194    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195        // Intentionally omit `store` to avoid exposing credential internals.
196        f.debug_struct("AuthInterceptor")
197            .field("session", &self.session)
198            .field("scheme", &self.scheme)
199            .finish()
200    }
201}
202
203impl CallInterceptor for AuthInterceptor {
204    #[allow(clippy::manual_async_fn)]
205    fn before<'a>(
206        &'a self,
207        req: &'a mut ClientRequest,
208    ) -> impl std::future::Future<Output = ClientResult<()>> + Send + 'a {
209        async move {
210            if let Some(credential) = self.store.get(&self.session, &self.scheme) {
211                let header_value = if self.scheme.eq_ignore_ascii_case("bearer") {
212                    format!("Bearer {credential}")
213                } else if self.scheme.eq_ignore_ascii_case("basic") {
214                    format!("Basic {credential}")
215                } else {
216                    credential
217                };
218                req.extra_headers
219                    .insert("authorization".to_owned(), header_value);
220            }
221            Ok(())
222        }
223    }
224
225    #[allow(clippy::manual_async_fn)]
226    fn after<'a>(
227        &'a self,
228        _resp: &'a ClientResponse,
229    ) -> impl std::future::Future<Output = ClientResult<()>> + Send + 'a {
230        async move { Ok(()) }
231    }
232}
233
234// ── Tests ─────────────────────────────────────────────────────────────────────
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `before()` for an interceptor configured with `scheme` whose store
    /// holds `credential`, returning the resulting `authorization` header.
    async fn authorization_header(scheme: &str, credential: &str) -> Option<String> {
        let store = Arc::new(InMemoryCredentialsStore::new());
        let session = SessionId::new("helper");
        store.set(session.clone(), scheme, credential.to_owned());

        let interceptor = AuthInterceptor::with_scheme(store, session, scheme);
        let mut req = ClientRequest::new("message/send", serde_json::Value::Null);
        interceptor.before(&mut req).await.unwrap();
        req.extra_headers.get("authorization").cloned()
    }

    #[test]
    fn credentials_store_set_get_remove() {
        let store = InMemoryCredentialsStore::new();
        let session = SessionId::new("sess-1");

        assert!(store.get(&session, "bearer").is_none());

        store.set(session.clone(), "bearer", "my-token".into());
        assert_eq!(store.get(&session, "bearer").as_deref(), Some("my-token"));

        store.remove(&session, "bearer");
        assert!(store.get(&session, "bearer").is_none());
    }

    #[test]
    fn credentials_store_isolates_sessions_and_schemes() {
        let store = InMemoryCredentialsStore::new();
        let a = SessionId::new("a");
        let b = SessionId::from("b");

        store.set(a.clone(), "bearer", "token-a".into());
        store.set(a.clone(), "basic", "basic-a".into());

        // Another session sees nothing.
        assert!(store.get(&b, "bearer").is_none());

        // Removing one scheme leaves the others for the same session intact.
        store.remove(&a, "bearer");
        assert!(store.get(&a, "bearer").is_none());
        assert_eq!(store.get(&a, "basic").as_deref(), Some("basic-a"));
    }

    #[test]
    fn debug_output_hides_credentials() {
        let store = InMemoryCredentialsStore::new();
        store.set(SessionId::new("s"), "bearer", "super-secret".into());
        assert!(!format!("{store:?}").contains("super-secret"));

        let interceptor = AuthInterceptor::new(Arc::new(store), SessionId::new("s"));
        assert!(!format!("{interceptor:?}").contains("super-secret"));
    }

    #[tokio::test]
    async fn auth_interceptor_injects_bearer() {
        let store = Arc::new(InMemoryCredentialsStore::new());
        let session = SessionId::new("test");
        store.set(session.clone(), "bearer", "my-secret-token".into());

        let interceptor = AuthInterceptor::new(store, session);
        let mut req = ClientRequest::new("message/send", serde_json::Value::Null);

        interceptor.before(&mut req).await.unwrap();

        assert_eq!(
            req.extra_headers.get("authorization").map(String::as_str),
            Some("Bearer my-secret-token")
        );
    }

    #[tokio::test]
    async fn auth_interceptor_prefixes_basic_credentials() {
        assert_eq!(
            authorization_header("basic", "dXNlcjpwYXNz").await.as_deref(),
            Some("Basic dXNlcjpwYXNz")
        );
    }

    #[tokio::test]
    async fn auth_interceptor_scheme_prefix_is_case_insensitive() {
        assert_eq!(
            authorization_header("BEARER", "tok").await.as_deref(),
            Some("Bearer tok")
        );
    }

    #[tokio::test]
    async fn auth_interceptor_passes_custom_scheme_through_raw() {
        assert_eq!(
            authorization_header("api-key", "k-123").await.as_deref(),
            Some("k-123")
        );
    }

    #[tokio::test]
    async fn auth_interceptor_no_credential_no_header() {
        let store = Arc::new(InMemoryCredentialsStore::new());
        let session = SessionId::new("empty");
        let interceptor = AuthInterceptor::new(store, session);

        let mut req = ClientRequest::new("message/send", serde_json::Value::Null);
        interceptor.before(&mut req).await.unwrap();

        assert!(!req.extra_headers.contains_key("authorization"));
    }
}