1use std::collections::HashMap;
30use std::fmt;
31use std::sync::{Arc, RwLock};
32
33use crate::error::ClientResult;
34use crate::interceptor::{CallInterceptor, ClientRequest, ClientResponse};
35
/// Opaque identifier naming the session whose credentials are looked up in a
/// [`CredentialsStore`].
///
/// The wrapped text is compared, hashed and displayed verbatim.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SessionId(String);

impl SessionId {
    /// Builds a session identifier from anything convertible into a `String`.
    #[must_use]
    pub fn new(s: impl Into<String>) -> Self {
        let id: String = s.into();
        Self(id)
    }
}

impl fmt::Display for SessionId {
    /// Writes the raw identifier text. Width/fill flags are not applied because
    /// the text goes straight to the formatter via `write_str`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.0.as_str())
    }
}

impl From<String> for SessionId {
    /// Wraps an owned string without copying it.
    fn from(s: String) -> Self {
        Self::new(s)
    }
}

impl From<&str> for SessionId {
    /// Copies a borrowed string into a new identifier.
    fn from(s: &str) -> Self {
        Self::new(s)
    }
}
70
/// Storage backend for per-session authentication credentials.
///
/// A credential is addressed by a [`SessionId`] plus a scheme name, for
/// example `"bearer"`, `"basic"`, or a custom name such as `"api-key"`.
/// Values are the raw credential. [`AuthInterceptor`] adds the `Bearer ` /
/// `Basic ` header prefix itself and performs no encoding, so a `"basic"`
/// credential must already be encoded by the caller.
///
/// The `Send + Sync + 'static` bounds allow a store to be shared as
/// `Arc<dyn CredentialsStore>`, which is how [`AuthInterceptor`] holds it.
pub trait CredentialsStore: Send + Sync + 'static {
    /// Returns the credential stored for `session` under `scheme`, or `None`
    /// if there is none.
    fn get(&self, session: &SessionId, scheme: &str) -> Option<String>;

    /// Stores `credential` for `session` under `scheme`.
    ///
    /// [`InMemoryCredentialsStore`] overwrites any existing value for the same
    /// pair. NOTE(review): confirm overwrite is the intended contract for all
    /// implementations.
    fn set(&self, session: SessionId, scheme: &str, credential: String);

    /// Removes the credential stored for `session` under `scheme`, if any.
    fn remove(&self, session: &SessionId, scheme: &str);
}
88
89pub struct InMemoryCredentialsStore {
102 inner: RwLock<HashMap<SessionId, HashMap<String, String>>>,
103}
104
105impl InMemoryCredentialsStore {
106 #[must_use]
108 pub fn new() -> Self {
109 Self {
110 inner: RwLock::new(HashMap::new()),
111 }
112 }
113}
114
115impl Default for InMemoryCredentialsStore {
116 fn default() -> Self {
117 Self::new()
118 }
119}
120
121impl fmt::Debug for InMemoryCredentialsStore {
122 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123 let count = self
125 .inner
126 .read()
127 .expect("credentials store lock poisoned")
128 .len();
129 f.debug_struct("InMemoryCredentialsStore")
130 .field("sessions", &count)
131 .finish()
132 }
133}
134
135impl CredentialsStore for InMemoryCredentialsStore {
136 fn get(&self, session: &SessionId, scheme: &str) -> Option<String> {
137 let guard = self.inner.read().expect("credentials store lock poisoned");
139 guard.get(session)?.get(scheme).cloned()
140 }
141
142 fn set(&self, session: SessionId, scheme: &str, credential: String) {
143 let mut guard = self.inner.write().expect("credentials store lock poisoned");
144 guard
145 .entry(session)
146 .or_default()
147 .insert(scheme.to_owned(), credential);
148 }
149
150 fn remove(&self, session: &SessionId, scheme: &str) {
151 let mut guard = self.inner.write().expect("credentials store lock poisoned");
152 if let Some(schemes) = guard.get_mut(session) {
153 schemes.remove(scheme);
154 }
155 }
156}
157
158pub struct AuthInterceptor {
172 store: Arc<dyn CredentialsStore>,
173 session: SessionId,
174 scheme: String,
176}
177
178impl AuthInterceptor {
179 #[must_use]
181 pub fn new(store: Arc<dyn CredentialsStore>, session: SessionId) -> Self {
182 Self {
183 store,
184 session,
185 scheme: "bearer".to_owned(),
186 }
187 }
188
189 #[must_use]
191 pub fn with_scheme(
192 store: Arc<dyn CredentialsStore>,
193 session: SessionId,
194 scheme: impl Into<String>,
195 ) -> Self {
196 Self {
197 store,
198 session,
199 scheme: scheme.into(),
200 }
201 }
202}
203
204#[allow(clippy::missing_fields_in_debug)]
205impl fmt::Debug for AuthInterceptor {
206 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207 f.debug_struct("AuthInterceptor")
209 .field("session", &self.session)
210 .field("scheme", &self.scheme)
211 .finish()
212 }
213}
214
215impl CallInterceptor for AuthInterceptor {
216 #[allow(clippy::manual_async_fn)]
217 fn before<'a>(
218 &'a self,
219 req: &'a mut ClientRequest,
220 ) -> impl std::future::Future<Output = ClientResult<()>> + Send + 'a {
221 async move {
222 if let Some(credential) = self.store.get(&self.session, &self.scheme) {
223 let header_value = if self.scheme.eq_ignore_ascii_case("bearer") {
224 format!("Bearer {credential}")
225 } else if self.scheme.eq_ignore_ascii_case("basic") {
226 format!("Basic {credential}")
227 } else {
228 credential
229 };
230 req.extra_headers
231 .insert("authorization".to_owned(), header_value);
232 }
233 Ok(())
234 }
235 }
236
237 #[allow(clippy::manual_async_fn)]
238 fn after<'a>(
239 &'a self,
240 _resp: &'a ClientResponse,
241 ) -> impl std::future::Future<Output = ClientResult<()>> + Send + 'a {
242 async move { Ok(()) }
243 }
244}
245
#[cfg(test)]
mod tests {
    //! Unit tests for the credentials store and the auth interceptor.

    use super::*;

    #[test]
    fn credentials_store_set_get_remove() {
        let store = InMemoryCredentialsStore::new();
        let session = SessionId::new("sess-1");

        assert!(store.get(&session, "bearer").is_none());

        store.set(session.clone(), "bearer", "my-token".into());
        assert_eq!(store.get(&session, "bearer").as_deref(), Some("my-token"));

        store.remove(&session, "bearer");
        assert!(store.get(&session, "bearer").is_none());
    }

    #[test]
    fn credentials_store_remove_missing_is_noop() {
        let store = InMemoryCredentialsStore::new();
        let session = SessionId::new("never-set");

        // Removing from an unknown session must not panic or create entries.
        store.remove(&session, "bearer");
        assert!(store.get(&session, "bearer").is_none());

        // Removing an unknown scheme must leave other schemes intact.
        store.set(session.clone(), "bearer", "tok".into());
        store.remove(&session, "basic");
        assert_eq!(store.get(&session, "bearer").as_deref(), Some("tok"));
    }

    #[tokio::test]
    async fn auth_interceptor_injects_bearer() {
        let store = Arc::new(InMemoryCredentialsStore::new());
        let session = SessionId::new("test");
        store.set(session.clone(), "bearer", "my-secret-token".into());

        let interceptor = AuthInterceptor::new(store, session);
        let mut req = ClientRequest::new("message/send", serde_json::Value::Null);

        interceptor.before(&mut req).await.unwrap();

        assert_eq!(
            req.extra_headers.get("authorization").map(String::as_str),
            Some("Bearer my-secret-token")
        );
    }

    #[tokio::test]
    async fn auth_interceptor_no_credential_no_header() {
        let store = Arc::new(InMemoryCredentialsStore::new());
        let session = SessionId::new("empty");
        let interceptor = AuthInterceptor::new(store, session);

        let mut req = ClientRequest::new("message/send", serde_json::Value::Null);
        interceptor.before(&mut req).await.unwrap();

        assert!(!req.extra_headers.contains_key("authorization"));
    }

    #[test]
    fn credentials_store_multiple_sessions() {
        let store = InMemoryCredentialsStore::new();
        let s1 = SessionId::new("session-1");
        let s2 = SessionId::new("session-2");

        store.set(s1.clone(), "bearer", "token-1".into());
        store.set(s2.clone(), "bearer", "token-2".into());

        assert_eq!(store.get(&s1, "bearer").as_deref(), Some("token-1"));
        assert_eq!(store.get(&s2, "bearer").as_deref(), Some("token-2"));

        // Removing from one session must not affect another.
        store.remove(&s1, "bearer");
        assert!(store.get(&s1, "bearer").is_none());
        assert_eq!(store.get(&s2, "bearer").as_deref(), Some("token-2"));
    }

    #[test]
    fn credentials_store_multiple_schemes() {
        let store = InMemoryCredentialsStore::new();
        let session = SessionId::new("multi-scheme");

        store.set(session.clone(), "bearer", "bearer-tok".into());
        store.set(session.clone(), "api-key", "key-123".into());

        assert_eq!(store.get(&session, "bearer").as_deref(), Some("bearer-tok"));
        assert_eq!(store.get(&session, "api-key").as_deref(), Some("key-123"));
    }

    #[test]
    fn credentials_store_overwrite() {
        let store = InMemoryCredentialsStore::new();
        let session = SessionId::new("overwrite");

        store.set(session.clone(), "bearer", "old-token".into());
        store.set(session.clone(), "bearer", "new-token".into());

        assert_eq!(store.get(&session, "bearer").as_deref(), Some("new-token"));
    }

    #[test]
    fn credentials_store_debug_hides_values() {
        let store = InMemoryCredentialsStore::new();
        let session = SessionId::new("secret");
        store.set(session, "bearer", "super-secret-token".into());

        let debug_output = format!("{store:?}");
        assert!(
            !debug_output.contains("super-secret"),
            "debug output should not expose credentials: {debug_output}"
        );
        assert!(debug_output.contains("sessions"));
    }

    #[tokio::test]
    async fn auth_interceptor_basic_scheme() {
        let store = Arc::new(InMemoryCredentialsStore::new());
        let session = SessionId::new("basic-test");
        store.set(session.clone(), "basic", "dXNlcjpwYXNz".into());

        let interceptor = AuthInterceptor::with_scheme(store, session, "basic");
        let mut req = ClientRequest::new("message/send", serde_json::Value::Null);
        interceptor.before(&mut req).await.unwrap();

        assert_eq!(
            req.extra_headers.get("authorization").map(String::as_str),
            Some("Basic dXNlcjpwYXNz")
        );
    }

    #[tokio::test]
    async fn auth_interceptor_scheme_prefix_is_case_insensitive() {
        let store = Arc::new(InMemoryCredentialsStore::new());
        let session = SessionId::new("case-test");
        // The store key is the scheme exactly as configured on the interceptor.
        store.set(session.clone(), "BEARER", "tok".into());

        let interceptor = AuthInterceptor::with_scheme(store, session, "BEARER");
        let mut req = ClientRequest::new("message/send", serde_json::Value::Null);
        interceptor.before(&mut req).await.unwrap();

        assert_eq!(
            req.extra_headers.get("authorization").map(String::as_str),
            Some("Bearer tok")
        );
    }

    #[tokio::test]
    async fn auth_interceptor_custom_scheme() {
        let store = Arc::new(InMemoryCredentialsStore::new());
        let session = SessionId::new("custom-test");
        store.set(session.clone(), "api-key", "my-api-key".into());

        let interceptor = AuthInterceptor::with_scheme(store, session, "api-key");
        let mut req = ClientRequest::new("message/send", serde_json::Value::Null);
        interceptor.before(&mut req).await.unwrap();

        // Unrecognised schemes send the stored credential without a prefix.
        assert_eq!(
            req.extra_headers.get("authorization").map(String::as_str),
            Some("my-api-key")
        );
    }

    #[test]
    fn session_id_display() {
        let session = SessionId::new("my-session");
        assert_eq!(session.to_string(), "my-session");
    }

    #[test]
    fn session_id_from_string() {
        let session: SessionId = "test".into();
        assert_eq!(session, SessionId::new("test"));

        let session: SessionId = String::from("owned").into();
        assert_eq!(session, SessionId::new("owned"));
    }

    #[test]
    fn credentials_store_default_impl() {
        let store = InMemoryCredentialsStore::default();
        let session = SessionId::new("test");
        assert!(store.get(&session, "bearer").is_none());
    }

    #[tokio::test]
    async fn auth_interceptor_after_is_noop() {
        let store = Arc::new(InMemoryCredentialsStore::new());
        let session = SessionId::new("test");
        let interceptor = AuthInterceptor::new(store, session);
        let resp = ClientResponse {
            method: "test".into(),
            result: serde_json::Value::Null,
            status_code: 200,
        };
        interceptor.after(&resp).await.unwrap();
    }

    #[test]
    fn auth_interceptor_debug_contains_fields() {
        let store = Arc::new(InMemoryCredentialsStore::new());
        let session = SessionId::new("debug-session");
        let interceptor = AuthInterceptor::new(store, session);
        let debug = format!("{interceptor:?}");
        assert!(
            debug.contains("AuthInterceptor"),
            "debug output missing struct name: {debug}"
        );
        assert!(
            debug.contains("debug-session"),
            "debug output missing session: {debug}"
        );
        assert!(
            debug.contains("bearer"),
            "debug output missing scheme: {debug}"
        );
    }
}