use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use crate::error::KnafehError;
use crate::rpc::message::{RpcRequest, RpcResponse};
/// A hook that can inspect or mutate RPC traffic on its way through a
/// [`MiddlewareStack`].
///
/// Both methods have default implementations that leave the message untouched
/// and return `Ok(())`, so an implementor only overrides the direction it
/// cares about. Returning `Err` from either hook propagates the error to the
/// caller of the stack.
#[async_trait]
pub trait Interceptor: Send + Sync + 'static {
    /// Called with a mutable view of an outgoing request.
    ///
    /// The default implementation does nothing.
    ///
    /// # Errors
    ///
    /// The default never fails; overriding implementations may return a
    /// [`KnafehError`].
    async fn on_request(&self, request: &mut RpcRequest) -> Result<(), KnafehError> {
        // Bind the argument only to mark it as intentionally unused.
        let _unused = request;
        Ok(())
    }

    /// Called with a mutable view of an incoming response.
    ///
    /// The default implementation does nothing.
    ///
    /// # Errors
    ///
    /// The default never fails; overriding implementations may return a
    /// [`KnafehError`].
    async fn on_response(&self, response: &mut RpcResponse) -> Result<(), KnafehError> {
        // Bind the argument only to mark it as intentionally unused.
        let _unused = response;
        Ok(())
    }
}
/// An ordered list of shared [`Interceptor`]s.
///
/// Interceptors are stored in the order they were added; the order matters
/// because the request and response passes walk the list in opposite
/// directions.
pub struct MiddlewareStack {
    /// Registered interceptors, oldest registration first.
    interceptors: Vec<Arc<dyn Interceptor>>,
}
impl MiddlewareStack {
    /// Creates a stack with no interceptors registered.
    pub fn new() -> Self {
        let interceptors = Vec::new();
        Self { interceptors }
    }

    /// Appends `interceptor` to the end of the stack.
    pub fn add(&mut self, interceptor: Arc<dyn Interceptor>) {
        self.interceptors.push(interceptor);
    }

    /// Runs every interceptor's `on_request` hook against `request`, in the
    /// order the interceptors were added.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error produced by a hook; later
    /// interceptors are not invoked.
    #[inline]
    pub async fn apply_request(&self, request: &mut RpcRequest) -> Result<(), KnafehError> {
        // An empty stack simply yields no iterations, so no separate
        // fast-path check is needed.
        for hook in self.interceptors.iter() {
            hook.on_request(request).await?;
        }
        Ok(())
    }

    /// Runs every interceptor's `on_response` hook against `response`, in the
    /// reverse of the order the interceptors were added.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error produced by a hook; the
    /// remaining interceptors are not invoked.
    #[inline]
    pub async fn apply_response(&self, response: &mut RpcResponse) -> Result<(), KnafehError> {
        // Last-added interceptor sees the response first.
        for hook in self.interceptors.iter().rev() {
            hook.on_response(response).await?;
        }
        Ok(())
    }

    /// Returns `true` when no interceptors have been added.
    pub fn is_empty(&self) -> bool {
        self.interceptors.is_empty()
    }
}
impl Default for MiddlewareStack {
fn default() -> Self {
Self::new()
}
}
/// Stateless interceptor marker type.
///
/// Carries no data, so it is cheap to copy and can be created with
/// `LoggingInterceptor` or `LoggingInterceptor::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct LoggingInterceptor;
/// Emits one `tracing` info-level event per request and per response.
///
/// Neither hook modifies the message, and both always return `Ok(())`.
#[async_trait]
impl Interceptor for LoggingInterceptor {
    /// Logs the request's method (via `Display`) and its body length.
    async fn on_request(&self, request: &mut RpcRequest) -> Result<(), KnafehError> {
        tracing::info!(
            method = %request.method,
            body_len = request.body.len(),
            "RPC request"
        );
        Ok(())
    }

    /// Logs the response's status code (via `Debug`) and its body length.
    async fn on_response(&self, response: &mut RpcResponse) -> Result<(), KnafehError> {
        tracing::info!(
            status = ?response.status.code,
            body_len = response.body.len(),
            "RPC response"
        );
        Ok(())
    }
}
/// Interceptor configuration holding a single timeout value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TimeoutInterceptor {
    /// The timeout this interceptor carries.
    pub duration: Duration,
}

impl TimeoutInterceptor {
    /// Creates an interceptor carrying `duration`.
    ///
    /// This is a `const fn`, so it can also initialise constants and statics.
    #[must_use]
    pub const fn new(duration: Duration) -> Self {
        Self { duration }
    }
}
/// Stamps the configured timeout onto each outgoing request's metadata.
///
/// Only `on_request` is overridden; responses go through the trait's default
/// hook. This impl only records the value; it does not itself enforce any
/// deadline.
#[async_trait]
impl Interceptor for TimeoutInterceptor {
    /// Inserts the timeout, in whole milliseconds rendered as a decimal
    /// string, under the `x-rpc-timeout-ms` metadata key.
    ///
    /// Always returns `Ok(())`.
    async fn on_request(&self, request: &mut RpcRequest) -> Result<(), KnafehError> {
        let key = "x-rpc-timeout-ms".to_string();
        // `as_millis` drops any sub-millisecond remainder.
        let timeout_ms = self.duration.as_millis().to_string();
        request.metadata.insert(key, timeout_ms);
        Ok(())
    }
}