use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use crate::error::ClientResult;
/// An outgoing client call, as handed to [`CallInterceptor::before`].
///
/// `Clone`/`PartialEq` are derived so callers can snapshot or compare a
/// request (e.g. before and after interceptors have mutated it).
#[derive(Debug, Clone, PartialEq)]
pub struct ClientRequest {
    /// Name of the remote method being invoked (e.g. `"message/send"`).
    pub method: String,
    /// Parameters of the call as a JSON value.
    pub params: serde_json::Value,
    /// Extra header name/value pairs attached to the request. Empty when the
    /// request is built with [`ClientRequest::new`]; `before` hooks receive
    /// the request mutably and may add entries here.
    pub extra_headers: HashMap<String, String>,
}
impl ClientRequest {
    /// Builds a request for `method` carrying `params`.
    ///
    /// The request starts with no extra headers; interceptors can add them
    /// later through `extra_headers`.
    #[must_use]
    pub fn new(method: impl Into<String>, params: serde_json::Value) -> Self {
        let method: String = method.into();
        let extra_headers = HashMap::new();
        Self {
            method,
            params,
            extra_headers,
        }
    }
}
/// The outcome of a client call, as handed to [`CallInterceptor::after`].
///
/// `Clone`/`PartialEq` are derived so interceptors and callers can keep or
/// compare a copy of the response they only receive by reference.
#[derive(Debug, Clone, PartialEq)]
pub struct ClientResponse {
    /// Name of the remote method that produced this response.
    pub method: String,
    /// Result payload of the call as a JSON value.
    pub result: serde_json::Value,
    /// Status code associated with the call.
    /// NOTE(review): presumably the transport (HTTP) status — confirm.
    pub status_code: u16,
}
/// A hook that runs around client calls.
///
/// Both hooks return `Send` futures. Each hook has a no-op default, so an
/// implementor only needs to override the phase it actually cares about.
pub trait CallInterceptor: Send + Sync + 'static {
    /// Inspects, and may mutate, an outgoing request.
    ///
    /// When run through [`InterceptorChain::run_before`], hooks run in
    /// registration order, and the first `Err` stops the remaining hooks and
    /// is returned to the caller.
    ///
    /// The default implementation leaves `req` untouched.
    ///
    /// # Errors
    ///
    /// Implementations may return an error; the default never does.
    fn before<'a>(
        &'a self,
        req: &'a mut ClientRequest,
    ) -> impl Future<Output = ClientResult<()>> + Send + 'a {
        // Default: accept the request unchanged.
        let _ = req;
        async { Ok(()) }
    }

    /// Observes a response.
    ///
    /// When run through [`InterceptorChain::run_after`], hooks run in reverse
    /// registration order, and the first `Err` stops the remaining hooks and
    /// is returned to the caller.
    ///
    /// The default implementation does nothing.
    ///
    /// # Errors
    ///
    /// Implementations may return an error; the default never does.
    fn after<'a>(
        &'a self,
        resp: &'a ClientResponse,
    ) -> impl Future<Output = ClientResult<()>> + Send + 'a {
        // Default: ignore the response.
        let _ = resp;
        async { Ok(()) }
    }
}
/// Object-safe mirror of [`CallInterceptor`].
///
/// A trait whose methods return `impl Future` cannot be used as a trait
/// object. This trait returns pinned, boxed futures instead, so interceptors
/// of different concrete types can be stored together behind
/// `dyn CallInterceptorBoxed`.
pub(crate) trait CallInterceptorBoxed: Send + Sync + 'static {
    /// Boxed counterpart of [`CallInterceptor::before`].
    fn before_boxed<'r>(
        &'r self,
        req: &'r mut ClientRequest,
    ) -> Pin<Box<dyn Future<Output = ClientResult<()>> + Send + 'r>>;

    /// Boxed counterpart of [`CallInterceptor::after`].
    fn after_boxed<'r>(
        &'r self,
        resp: &'r ClientResponse,
    ) -> Pin<Box<dyn Future<Output = ClientResult<()>> + Send + 'r>>;
}
impl<T: CallInterceptor> CallInterceptorBoxed for T {
fn before_boxed<'a>(
&'a self,
req: &'a mut ClientRequest,
) -> Pin<Box<dyn Future<Output = ClientResult<()>> + Send + 'a>> {
Box::pin(self.before(req))
}
fn after_boxed<'a>(
&'a self,
resp: &'a ClientResponse,
) -> Pin<Box<dyn Future<Output = ClientResult<()>> + Send + 'a>> {
Box::pin(self.after(resp))
}
}
/// Lets an already type-erased interceptor be treated as a
/// `CallInterceptorBoxed` itself (e.g. a `Box<Box<dyn CallInterceptorBoxed>>`
/// coerced to `Box<dyn CallInterceptorBoxed>`). Every call is forwarded to the
/// inner trait object.
impl CallInterceptorBoxed for Box<dyn CallInterceptorBoxed> {
    fn before_boxed<'r>(
        &'r self,
        req: &'r mut ClientRequest,
    ) -> Pin<Box<dyn Future<Output = ClientResult<()>> + Send + 'r>> {
        // Deref through the Box to the inner trait object. Calling the method
        // on `self` directly would resolve back to this impl and recurse.
        let inner = &**self;
        inner.before_boxed(req)
    }

    fn after_boxed<'r>(
        &'r self,
        resp: &'r ClientResponse,
    ) -> Pin<Box<dyn Future<Output = ClientResult<()>> + Send + 'r>> {
        let inner = &**self;
        inner.after_boxed(resp)
    }
}
/// An ordered collection of interceptors.
///
/// `run_before` visits interceptors in registration order and `run_after`
/// visits them in reverse. Interceptors are held behind `Arc`, so cloning a
/// chain is cheap: the clone shares the same interceptor instances instead of
/// copying them.
#[derive(Clone, Default)]
pub struct InterceptorChain {
    /// Registered interceptors, oldest first.
    interceptors: Vec<Arc<dyn CallInterceptorBoxed>>,
}
impl InterceptorChain {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn push<I: CallInterceptor>(&mut self, interceptor: I) {
self.interceptors.push(Arc::new(interceptor));
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.interceptors.is_empty()
}
pub async fn run_before(&self, req: &mut ClientRequest) -> ClientResult<()> {
for interceptor in &self.interceptors {
interceptor.before_boxed(req).await?;
}
Ok(())
}
pub async fn run_after(&self, resp: &ClientResponse) -> ClientResult<()> {
for interceptor in self.interceptors.iter().rev() {
interceptor.after_boxed(resp).await?;
}
Ok(())
}
}
impl std::fmt::Debug for InterceptorChain {
    // Only the number of interceptors is shown, because the trait objects
    // themselves carry no `Debug` bound.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.interceptors.len();
        let mut builder = f.debug_struct("InterceptorChain");
        builder.field("count", &count);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;

    /// Adds 1 to a shared counter in `before` and 10 in `after`, so a single
    /// total shows how many hooks of each kind ran.
    struct CountingInterceptor(Arc<AtomicUsize>);

    impl CallInterceptor for CountingInterceptor {
        // The explicit `impl Future + Send` form mirrors the trait signature.
        #[allow(clippy::manual_async_fn)]
        fn before<'a>(
            &'a self,
            _req: &'a mut ClientRequest,
        ) -> impl std::future::Future<Output = ClientResult<()>> + Send + 'a {
            async move {
                self.0.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
        }

        // The explicit `impl Future + Send` form mirrors the trait signature.
        #[allow(clippy::manual_async_fn)]
        fn after<'a>(
            &'a self,
            _resp: &'a ClientResponse,
        ) -> impl std::future::Future<Output = ClientResult<()>> + Send + 'a {
            async move {
                self.0.fetch_add(10, Ordering::SeqCst);
                Ok(())
            }
        }
    }

    /// Appends `"<phase>:<label>"` to a shared log on every hook call, so tests
    /// can assert the order in which the chain invoked interceptors. A counter
    /// alone cannot distinguish orderings.
    struct RecordingInterceptor {
        label: &'static str,
        log: Arc<Mutex<Vec<String>>>,
    }

    impl RecordingInterceptor {
        fn new(label: &'static str, log: &Arc<Mutex<Vec<String>>>) -> Self {
            Self {
                label,
                log: Arc::clone(log),
            }
        }

        fn record(&self, phase: &str) {
            self.log
                .lock()
                .expect("log mutex poisoned")
                .push(format!("{phase}:{}", self.label));
        }
    }

    impl CallInterceptor for RecordingInterceptor {
        // The explicit `impl Future + Send` form mirrors the trait signature.
        #[allow(clippy::manual_async_fn)]
        fn before<'a>(
            &'a self,
            _req: &'a mut ClientRequest,
        ) -> impl std::future::Future<Output = ClientResult<()>> + Send + 'a {
            async move {
                self.record("before");
                Ok(())
            }
        }

        // The explicit `impl Future + Send` form mirrors the trait signature.
        #[allow(clippy::manual_async_fn)]
        fn after<'a>(
            &'a self,
            _resp: &'a ClientResponse,
        ) -> impl std::future::Future<Output = ClientResult<()>> + Send + 'a {
            async move {
                self.record("after");
                Ok(())
            }
        }
    }

    #[test]
    fn chain_is_empty_when_new() {
        let chain = InterceptorChain::new();
        assert!(chain.is_empty(), "new chain should be empty");
    }

    #[test]
    fn chain_is_not_empty_after_push() {
        let counter = Arc::new(AtomicUsize::new(0));
        let mut chain = InterceptorChain::new();
        chain.push(CountingInterceptor(Arc::clone(&counter)));
        assert!(
            !chain.is_empty(),
            "chain with one interceptor should not be empty"
        );
    }

    #[tokio::test]
    async fn chain_runs_before_in_order() {
        let counter = Arc::new(AtomicUsize::new(0));
        let mut chain = InterceptorChain::new();
        chain.push(CountingInterceptor(Arc::clone(&counter)));
        chain.push(CountingInterceptor(Arc::clone(&counter)));
        let mut req = ClientRequest::new("message/send", serde_json::Value::Null);
        chain.run_before(&mut req).await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 2);

        // Verify the actual ordering: registration order for `before`.
        let log = Arc::new(Mutex::new(Vec::new()));
        let mut ordered = InterceptorChain::new();
        ordered.push(RecordingInterceptor::new("first", &log));
        ordered.push(RecordingInterceptor::new("second", &log));
        ordered.run_before(&mut req).await.unwrap();
        assert_eq!(
            *log.lock().unwrap(),
            ["before:first", "before:second"],
            "before hooks should run in registration order"
        );
    }

    #[tokio::test]
    async fn chain_runs_after_in_reverse_order() {
        let counter = Arc::new(AtomicUsize::new(0));
        let mut chain = InterceptorChain::new();
        chain.push(CountingInterceptor(Arc::clone(&counter)));
        let resp = ClientResponse {
            method: "message/send".into(),
            result: serde_json::Value::Null,
            status_code: 200,
        };
        chain.run_after(&resp).await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 10);

        // Verify the actual ordering: reverse registration order for `after`.
        let log = Arc::new(Mutex::new(Vec::new()));
        let mut ordered = InterceptorChain::new();
        ordered.push(RecordingInterceptor::new("first", &log));
        ordered.push(RecordingInterceptor::new("second", &log));
        ordered.run_after(&resp).await.unwrap();
        assert_eq!(
            *log.lock().unwrap(),
            ["after:second", "after:first"],
            "after hooks should run in reverse registration order"
        );
    }

    #[tokio::test]
    async fn boxed_interceptor_delegates_before_and_after() {
        let counter = Arc::new(AtomicUsize::new(0));
        let interceptor = CountingInterceptor(Arc::clone(&counter));
        let boxed: Box<dyn CallInterceptorBoxed> = Box::new(interceptor);
        let mut req = ClientRequest::new("test", serde_json::Value::Null);
        boxed.before_boxed(&mut req).await.unwrap();
        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "before_boxed should delegate"
        );
        let resp = ClientResponse {
            method: "test".into(),
            result: serde_json::Value::Null,
            status_code: 200,
        };
        boxed.after_boxed(&resp).await.unwrap();
        assert_eq!(
            counter.load(Ordering::SeqCst),
            11,
            "after_boxed should delegate"
        );
        let double_boxed: Box<dyn CallInterceptorBoxed> = Box::new(boxed);
        double_boxed.before_boxed(&mut req).await.unwrap();
        assert_eq!(
            counter.load(Ordering::SeqCst),
            12,
            "double-boxed before should delegate"
        );
        double_boxed.after_boxed(&resp).await.unwrap();
        assert_eq!(
            counter.load(Ordering::SeqCst),
            22,
            "double-boxed after should delegate"
        );
    }
}