use std::collections::HashMap;
use std::fmt;
use std::sync::{Arc, RwLock};
use crate::error::ClientResult;
use crate::interceptor::{CallInterceptor, ClientRequest, ClientResponse};
/// Opaque identifier that partitions credentials in a [`CredentialsStore`].
///
/// Equality and hashing compare the wrapped string exactly (case-sensitive),
/// which lets the identifier serve directly as a map key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SessionId(String);

impl SessionId {
    /// Wraps any value convertible into a `String` as a session identifier.
    #[must_use]
    pub fn new(s: impl Into<String>) -> Self {
        let id: String = s.into();
        Self(id)
    }
}

impl fmt::Display for SessionId {
    /// Writes the raw identifier text. Width, fill and precision flags are
    /// not applied because the text goes straight to the formatter's buffer.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.0.as_str())
    }
}

impl From<String> for SessionId {
    /// Takes ownership of `s` without copying it.
    fn from(s: String) -> Self {
        Self::new(s)
    }
}

impl From<&str> for SessionId {
    /// Copies `s` into a newly allocated identifier.
    fn from(s: &str) -> Self {
        Self::new(s)
    }
}
/// Pluggable storage for per-session authentication credentials.
///
/// A credential is addressed by a [`SessionId`] plus a scheme name (for
/// example `"bearer"`). The `Send + Sync + 'static` bounds allow a store to be
/// shared as an `Arc<dyn CredentialsStore>`.
pub trait CredentialsStore: Send + Sync + 'static {
    /// Returns the credential stored for `scheme` under `session`, if any.
    fn get(&self, session: &SessionId, scheme: &str) -> Option<String>;
    /// Stores `credential` for `scheme` under `session`.
    fn set(&self, session: SessionId, scheme: &str, credential: String);
    /// Removes the credential for `scheme` under `session`.
    fn remove(&self, session: &SessionId, scheme: &str);
}

/// Thread-safe [`CredentialsStore`] that keeps every credential in process memory.
///
/// Credentials sit in a two-level map (session -> scheme -> credential) behind
/// an [`RwLock`]. Nothing is persisted, so dropping the store discards all
/// credentials.
pub struct InMemoryCredentialsStore {
    inner: RwLock<HashMap<SessionId, HashMap<String, String>>>,
}

impl InMemoryCredentialsStore {
    /// Creates an empty store.
    #[must_use]
    pub fn new() -> Self {
        Self {
            inner: RwLock::new(HashMap::new()),
        }
    }

    /// Locks the session map for reading.
    ///
    /// A poisoned lock is recovered instead of propagated as a panic. The
    /// store's own methods never panic while holding the guard, so the map is
    /// still consistent, and `Debug` formatting must not panic.
    fn read_sessions(
        &self,
    ) -> std::sync::RwLockReadGuard<'_, HashMap<SessionId, HashMap<String, String>>> {
        self.inner
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Locks the session map for writing. Poisoning is recovered as in
    /// [`Self::read_sessions`].
    fn write_sessions(
        &self,
    ) -> std::sync::RwLockWriteGuard<'_, HashMap<SessionId, HashMap<String, String>>> {
        self.inner
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

impl Default for InMemoryCredentialsStore {
    fn default() -> Self {
        Self::new()
    }
}

impl fmt::Debug for InMemoryCredentialsStore {
    /// Reports only the number of sessions, so credential values never end
    /// up in logs.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The guard is a temporary and is released before formatting starts.
        let sessions = self.read_sessions().len();
        f.debug_struct("InMemoryCredentialsStore")
            .field("sessions", &sessions)
            .finish()
    }
}

impl CredentialsStore for InMemoryCredentialsStore {
    /// Returns a clone of the credential, or `None` if the session or the
    /// scheme is unknown.
    fn get(&self, session: &SessionId, scheme: &str) -> Option<String> {
        let sessions = self.read_sessions();
        sessions.get(session)?.get(scheme).cloned()
    }

    /// Inserts or overwrites the credential. The session entry is created on
    /// first use.
    fn set(&self, session: SessionId, scheme: &str, credential: String) {
        self.write_sessions()
            .entry(session)
            .or_default()
            .insert(scheme.to_owned(), credential);
    }

    /// Deletes the credential if present; unknown sessions or schemes are ignored.
    ///
    /// When the last scheme of a session is removed, the session entry itself
    /// is dropped. Otherwise empty per-session maps would accumulate forever
    /// and keep inflating the `Debug` session count.
    fn remove(&self, session: &SessionId, scheme: &str) {
        let mut sessions = self.write_sessions();
        if let Some(schemes) = sessions.get_mut(session) {
            schemes.remove(scheme);
            if schemes.is_empty() {
                sessions.remove(session);
            }
        }
    }
}
/// [`CallInterceptor`] that attaches a stored credential to outgoing requests.
///
/// Before each call it reads the credential registered for this interceptor's
/// session and scheme from the [`CredentialsStore`] and inserts it into the
/// request's extra headers under `authorization`. If no credential is stored,
/// the request is left untouched.
pub struct AuthInterceptor {
    /// Shared credential source, consulted on every request.
    store: Arc<dyn CredentialsStore>,
    /// Session whose credentials are injected.
    session: SessionId,
    /// Store key for the credential; also selects the header prefix.
    scheme: String,
}

impl AuthInterceptor {
    /// Creates an interceptor for the `"bearer"` scheme.
    ///
    /// Equivalent to `AuthInterceptor::with_scheme(store, session, "bearer")`.
    #[must_use]
    pub fn new(store: Arc<dyn CredentialsStore>, session: SessionId) -> Self {
        Self::with_scheme(store, session, "bearer")
    }

    /// Creates an interceptor for an arbitrary scheme.
    ///
    /// `scheme` is used verbatim (case-sensitively) as the store key. The
    /// header value gets a prefix only for `"bearer"` (`Bearer `) and `"basic"`
    /// (`Basic `), matched ASCII case-insensitively. Any other scheme sends the
    /// stored credential unchanged.
    #[must_use]
    pub fn with_scheme(
        store: Arc<dyn CredentialsStore>,
        session: SessionId,
        scheme: impl Into<String>,
    ) -> Self {
        Self {
            store,
            session,
            scheme: scheme.into(),
        }
    }

    /// Formats `credential` as the `authorization` header value for this scheme.
    fn authorization_value(&self, credential: String) -> String {
        if self.scheme.eq_ignore_ascii_case("bearer") {
            format!("Bearer {credential}")
        } else if self.scheme.eq_ignore_ascii_case("basic") {
            format!("Basic {credential}")
        } else {
            credential
        }
    }
}

impl fmt::Debug for AuthInterceptor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `store` is deliberately omitted: `dyn CredentialsStore` has no `Debug`
        // bound and may hold secrets. `finish_non_exhaustive` marks the omission
        // with `..` instead of silently hiding it.
        f.debug_struct("AuthInterceptor")
            .field("session", &self.session)
            .field("scheme", &self.scheme)
            .finish_non_exhaustive()
    }
}

impl CallInterceptor for AuthInterceptor {
    // `manual_async_fn` is allowed because the explicit `impl Future + Send + 'a`
    // return type spells out the `Send` bound, which an `async fn` would hide.
    #[allow(clippy::manual_async_fn)]
    fn before<'a>(
        &'a self,
        req: &'a mut ClientRequest,
    ) -> impl std::future::Future<Output = ClientResult<()>> + Send + 'a {
        async move {
            if let Some(credential) = self.store.get(&self.session, &self.scheme) {
                let header_value = self.authorization_value(credential);
                req.extra_headers
                    .insert("authorization".to_owned(), header_value);
            }
            Ok(())
        }
    }

    // See `before` for why `manual_async_fn` is allowed.
    #[allow(clippy::manual_async_fn)]
    fn after<'a>(
        &'a self,
        _resp: &'a ClientResponse,
    ) -> impl std::future::Future<Output = ClientResult<()>> + Send + 'a {
        // Authentication only touches the outgoing request; nothing to do here.
        async move { Ok(()) }
    }
}
// Unit tests for `SessionId`, `InMemoryCredentialsStore` and `AuthInterceptor`.
#[cfg(test)]
mod tests {
use super::*;
// Round-trip: absent -> set -> present -> removed, for one session and scheme.
#[test]
fn credentials_store_set_get_remove() {
let store = InMemoryCredentialsStore::new();
let session = SessionId::new("sess-1");
assert!(store.get(&session, "bearer").is_none());
store.set(session.clone(), "bearer", "my-token".into());
assert_eq!(store.get(&session, "bearer").as_deref(), Some("my-token"));
store.remove(&session, "bearer");
assert!(store.get(&session, "bearer").is_none());
}
// The default ("bearer") interceptor prefixes the stored token with `Bearer `.
#[tokio::test]
async fn auth_interceptor_injects_bearer() {
let store = Arc::new(InMemoryCredentialsStore::new());
let session = SessionId::new("test");
store.set(session.clone(), "bearer", "my-secret-token".into());
let interceptor = AuthInterceptor::new(store, session);
let mut req = ClientRequest::new("message/send", serde_json::Value::Null);
interceptor.before(&mut req).await.unwrap();
assert_eq!(
req.extra_headers.get("authorization").map(String::as_str),
Some("Bearer my-secret-token")
);
}
// With no stored credential, `before` succeeds and adds no authorization header.
#[tokio::test]
async fn auth_interceptor_no_credential_no_header() {
let store = Arc::new(InMemoryCredentialsStore::new());
let session = SessionId::new("empty");
let interceptor = AuthInterceptor::new(store, session);
let mut req = ClientRequest::new("message/send", serde_json::Value::Null);
interceptor.before(&mut req).await.unwrap();
assert!(!req.extra_headers.contains_key("authorization"));
}
// Sessions are isolated: removing one session's credential leaves the other intact.
#[test]
fn credentials_store_multiple_sessions() {
let store = InMemoryCredentialsStore::new();
let s1 = SessionId::new("session-1");
let s2 = SessionId::new("session-2");
store.set(s1.clone(), "bearer", "token-1".into());
store.set(s2.clone(), "bearer", "token-2".into());
assert_eq!(store.get(&s1, "bearer").as_deref(), Some("token-1"));
assert_eq!(store.get(&s2, "bearer").as_deref(), Some("token-2"));
store.remove(&s1, "bearer");
assert!(store.get(&s1, "bearer").is_none());
assert_eq!(store.get(&s2, "bearer").as_deref(), Some("token-2"));
}
// One session can hold credentials for several schemes at once.
#[test]
fn credentials_store_multiple_schemes() {
let store = InMemoryCredentialsStore::new();
let session = SessionId::new("multi-scheme");
store.set(session.clone(), "bearer", "bearer-tok".into());
store.set(session.clone(), "api-key", "key-123".into());
assert_eq!(store.get(&session, "bearer").as_deref(), Some("bearer-tok"));
assert_eq!(store.get(&session, "api-key").as_deref(), Some("key-123"));
}
// Setting the same session/scheme twice keeps the most recent credential.
#[test]
fn credentials_store_overwrite() {
let store = InMemoryCredentialsStore::new();
let session = SessionId::new("overwrite");
store.set(session.clone(), "bearer", "old-token".into());
store.set(session.clone(), "bearer", "new-token".into());
assert_eq!(store.get(&session, "bearer").as_deref(), Some("new-token"));
}
// `Debug` must not leak credential values; it only reports a `sessions` field.
#[test]
fn credentials_store_debug_hides_values() {
let store = InMemoryCredentialsStore::new();
let session = SessionId::new("secret");
store.set(session, "bearer", "super-secret-token".into());
let debug_output = format!("{store:?}");
assert!(
!debug_output.contains("super-secret"),
"debug output should not expose credentials: {debug_output}"
);
assert!(debug_output.contains("sessions"));
}
// The "basic" scheme prefixes the stored value with `Basic `.
#[tokio::test]
async fn auth_interceptor_basic_scheme() {
let store = Arc::new(InMemoryCredentialsStore::new());
let session = SessionId::new("basic-test");
store.set(session.clone(), "basic", "dXNlcjpwYXNz".into());
let interceptor = AuthInterceptor::with_scheme(store, session, "basic");
let mut req = ClientRequest::new("message/send", serde_json::Value::Null);
interceptor.before(&mut req).await.unwrap();
assert_eq!(
req.extra_headers.get("authorization").map(String::as_str),
Some("Basic dXNlcjpwYXNz")
);
}
// Unrecognised schemes send the stored credential without any prefix.
#[tokio::test]
async fn auth_interceptor_custom_scheme() {
let store = Arc::new(InMemoryCredentialsStore::new());
let session = SessionId::new("custom-test");
store.set(session.clone(), "api-key", "my-api-key".into());
let interceptor = AuthInterceptor::with_scheme(store, session, "api-key");
let mut req = ClientRequest::new("message/send", serde_json::Value::Null);
interceptor.before(&mut req).await.unwrap();
assert_eq!(
req.extra_headers.get("authorization").map(String::as_str),
Some("my-api-key")
);
}
// `Display` renders the raw identifier text.
#[test]
fn session_id_display() {
let session = SessionId::new("my-session");
assert_eq!(session.to_string(), "my-session");
}
// `From<&str>` and `From<String>` both agree with `SessionId::new`.
#[test]
fn session_id_from_string() {
let session: SessionId = "test".into();
assert_eq!(session, SessionId::new("test"));
let session: SessionId = String::from("owned").into();
assert_eq!(session, SessionId::new("owned"));
}
// `Default` yields an empty store.
#[test]
fn credentials_store_default_impl() {
let store = InMemoryCredentialsStore::default();
let session = SessionId::new("test");
assert!(store.get(&session, "bearer").is_none());
}
// `after` completes successfully regardless of the response.
#[tokio::test]
async fn auth_interceptor_after_is_noop() {
let store = Arc::new(InMemoryCredentialsStore::new());
let session = SessionId::new("test");
let interceptor = AuthInterceptor::new(store, session);
let resp = ClientResponse {
method: "test".into(),
result: serde_json::Value::Null,
status_code: 200,
};
interceptor.after(&resp).await.unwrap();
}
// `Debug` for the interceptor shows the type name, session and scheme.
#[test]
fn auth_interceptor_debug_contains_fields() {
let store = Arc::new(InMemoryCredentialsStore::new());
let session = SessionId::new("debug-session");
let interceptor = AuthInterceptor::new(store, session);
let debug = format!("{interceptor:?}");
assert!(
debug.contains("AuthInterceptor"),
"debug output missing struct name: {debug}"
);
assert!(
debug.contains("debug-session"),
"debug output missing session: {debug}"
);
assert!(
debug.contains("bearer"),
"debug output missing scheme: {debug}"
);
}
}