use crate::error::Result;
use crate::shared::TransportMessage;
use crate::types::{JSONRPCNotification, JSONRPCRequest, JSONRPCResponse};
use async_trait::async_trait;
use dashmap::DashMap;
use parking_lot::RwLock;
use std::fmt;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Per-invocation state shared by every middleware that runs in a chain.
///
/// Cloning is cheap: `metadata` and `metrics` are reference-counted, so all
/// clones read and write the same map and the same counters.
#[derive(Debug, Clone)]
pub struct MiddlewareContext {
    /// Optional identifier of the request being processed.
    pub request_id: Option<String>,
    /// Free-form key/value pairs middlewares use to hand data to each other.
    pub metadata: Arc<DashMap<String, String>>,
    /// Counters and named metrics recorded while the chain runs.
    pub metrics: Arc<PerformanceMetrics>,
    /// Creation time of this context; the origin for [`MiddlewareContext::elapsed`].
    pub start_time: Instant,
    /// Optional transport priority of the message being processed.
    pub priority: Option<crate::shared::transport::MessagePriority>,
}

impl Default for MiddlewareContext {
    /// An empty context: no request id, no metadata, fresh metrics, clock started now.
    fn default() -> Self {
        let metadata = Arc::new(DashMap::new());
        let metrics = Arc::new(PerformanceMetrics::new());
        Self {
            request_id: None,
            metadata,
            metrics,
            start_time: Instant::now(),
            priority: None,
        }
    }
}

impl MiddlewareContext {
    /// Builds a default context tagged with `request_id`.
    pub fn with_request_id(request_id: String) -> Self {
        let mut ctx = Self::default();
        ctx.request_id = Some(request_id);
        ctx
    }

    /// Inserts `value` under `key`, replacing any previous value.
    pub fn set_metadata(&self, key: String, value: String) {
        self.metadata.insert(key, value);
    }

    /// Returns a copy of the value stored under `key`, if present.
    pub fn get_metadata(&self, key: &str) -> Option<String> {
        self.metadata.get(key).map(|entry| entry.value().clone())
    }

    /// Records a named metric on the shared store (overwrites an earlier value).
    pub fn record_metric(&self, name: String, value: f64) {
        self.metrics.record(name, value);
    }

    /// Wall-clock time since this context was created.
    pub fn elapsed(&self) -> Duration {
        self.start_time.elapsed()
    }
}
/// Lock-free request/error/time counters plus a map of named `f64` metrics.
#[derive(Debug, Default)]
pub struct PerformanceMetrics {
    /// Named metrics; recording the same name again overwrites the old value.
    metrics: DashMap<String, f64>,
    /// Incremented by `inc_requests`.
    request_count: AtomicU64,
    /// Incremented by `inc_errors`.
    error_count: AtomicU64,
    /// Sum of all durations passed to `add_time`, in microseconds.
    total_time_us: AtomicU64,
}

impl PerformanceMetrics {
    /// Creates an all-zero, empty metrics store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `value` under `name`, replacing any earlier value.
    pub fn record(&self, name: String, value: f64) {
        self.metrics.insert(name, value);
    }

    /// Returns the last value recorded under `name`.
    pub fn get(&self, name: &str) -> Option<f64> {
        self.metrics.get(name).map(|entry| *entry.value())
    }

    /// Adds one to the request counter.
    pub fn inc_requests(&self) {
        self.request_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Adds one to the error counter.
    pub fn inc_errors(&self) {
        self.error_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Adds `duration` (truncated to whole microseconds) to the running total.
    pub fn add_time(&self, duration: Duration) {
        let micros = duration.as_micros() as u64;
        self.total_time_us.fetch_add(micros, Ordering::Relaxed);
    }

    /// Current value of the request counter.
    pub fn request_count(&self) -> u64 {
        self.request_count.load(Ordering::Relaxed)
    }

    /// Current value of the error counter.
    pub fn error_count(&self) -> u64 {
        self.error_count.load(Ordering::Relaxed)
    }

    /// Total accumulated time divided by the request count, or zero when no
    /// request has been counted yet.
    pub fn average_time(&self) -> Duration {
        let total_us = self.total_time_us.load(Ordering::Relaxed);
        match self.request_count.load(Ordering::Relaxed) {
            0 => Duration::ZERO,
            count => Duration::from_micros(total_us / count),
        }
    }
}
/// Scheduling rank of an advanced middleware inside a priority-sorted chain.
///
/// The derived ordering follows the discriminants, so `Critical` (0) sorts
/// first and `Lowest` (4) last. The default rank is `Normal`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)]
pub enum MiddlewarePriority {
    /// Sorted ahead of every other rank.
    Critical = 0,
    /// Sorted after `Critical`.
    High = 1,
    /// The default rank.
    #[default]
    Normal = 2,
    /// Sorted after `Normal`.
    Low = 3,
    /// Sorted after every other rank.
    Lowest = 4,
}
/// Context-aware middleware driven by [`EnhancedMiddlewareChain`].
///
/// Every hook has a no-op default, so implementors override only the hooks
/// they care about.
#[async_trait]
pub trait AdvancedMiddleware: Send + Sync {
    /// Rank used when the chain sorts by priority; lower ranks run first.
    /// Defaults to `Normal`.
    fn priority(&self) -> MiddlewarePriority {
        MiddlewarePriority::Normal
    }

    /// Human-readable identifier. Defaults to `"unknown"`.
    fn name(&self) -> &'static str {
        "unknown"
    }

    /// Gate consulted by the chain before each hook; `false` skips this
    /// middleware for that hook. Defaults to `true`.
    async fn should_execute(&self, _context: &MiddlewareContext) -> bool {
        true
    }

    /// Inspects or mutates an outgoing request. Default: no-op.
    async fn on_request_with_context(
        &self,
        _request: &mut JSONRPCRequest,
        _context: &MiddlewareContext,
    ) -> Result<()> {
        Ok(())
    }

    /// Inspects or mutates an incoming response. Default: no-op.
    async fn on_response_with_context(
        &self,
        _response: &mut JSONRPCResponse,
        _context: &MiddlewareContext,
    ) -> Result<()> {
        Ok(())
    }

    /// Observes a message about to be sent. Default: no-op.
    async fn on_send_with_context(
        &self,
        _message: &TransportMessage,
        _context: &MiddlewareContext,
    ) -> Result<()> {
        Ok(())
    }

    /// Observes a message that was received. Default: no-op.
    async fn on_receive_with_context(
        &self,
        _message: &TransportMessage,
        _context: &MiddlewareContext,
    ) -> Result<()> {
        Ok(())
    }

    /// Inspects or mutates a notification. Default: no-op.
    async fn on_notification_with_context(
        &self,
        _notification: &mut JSONRPCNotification,
        _context: &MiddlewareContext,
    ) -> Result<()> {
        Ok(())
    }

    /// Called by the chain before request hooks run. Default: no-op.
    async fn on_chain_start(&self, _context: &MiddlewareContext) -> Result<()> {
        Ok(())
    }

    /// Called by the chain after all request hooks succeeded. Default: no-op.
    async fn on_chain_complete(&self, _context: &MiddlewareContext) -> Result<()> {
        Ok(())
    }

    /// Notified when a hook in the chain fails. Default: no-op.
    async fn on_error(
        &self,
        _error: &crate::error::Error,
        _context: &MiddlewareContext,
    ) -> Result<()> {
        Ok(())
    }
}
/// Simple, context-free middleware run by [`MiddlewareChain`].
///
/// All hooks default to no-ops.
#[async_trait]
pub trait Middleware: Send + Sync {
    /// Inspects or mutates an outgoing request. Default: no-op.
    async fn on_request(&self, _request: &mut JSONRPCRequest) -> Result<()> {
        Ok(())
    }

    /// Inspects or mutates an incoming response. Default: no-op.
    async fn on_response(&self, _response: &mut JSONRPCResponse) -> Result<()> {
        Ok(())
    }

    /// Observes a message about to be sent. Default: no-op.
    async fn on_send(&self, _message: &TransportMessage) -> Result<()> {
        Ok(())
    }

    /// Observes a message that was received. Default: no-op.
    async fn on_receive(&self, _message: &TransportMessage) -> Result<()> {
        Ok(())
    }

    /// Inspects or mutates a notification. Default: no-op.
    async fn on_notification(&self, _notification: &mut JSONRPCNotification) -> Result<()> {
        Ok(())
    }
}
/// Ordered collection of [`AdvancedMiddleware`]s sharing one [`MiddlewareContext`].
pub struct EnhancedMiddlewareChain {
    /// Registered middlewares, in execution order.
    middlewares: Vec<Arc<dyn AdvancedMiddleware>>,
    /// Whether `add` re-sorts by priority after every insertion.
    auto_sort: bool,
}

impl fmt::Debug for EnhancedMiddlewareChain {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The trait objects carry no Debug bound, so only summarize the chain.
        let count = self.middlewares.len();
        let mut out = f.debug_struct("EnhancedMiddlewareChain");
        out.field("count", &count);
        out.field("auto_sort", &self.auto_sort);
        out.finish()
    }
}

impl Default for EnhancedMiddlewareChain {
    /// Equivalent to `EnhancedMiddlewareChain::new()` (auto-sorting on).
    fn default() -> Self {
        Self::new()
    }
}
impl EnhancedMiddlewareChain {
pub fn new() -> Self {
Self {
middlewares: Vec::new(),
auto_sort: true,
}
}
pub fn new_no_sort() -> Self {
Self {
middlewares: Vec::new(),
auto_sort: false,
}
}
pub fn add(&mut self, middleware: Arc<dyn AdvancedMiddleware>) {
self.middlewares.push(middleware);
if self.auto_sort {
self.sort_by_priority();
}
}
pub fn sort_by_priority(&mut self) {
self.middlewares.sort_by_key(|m| m.priority());
}
pub fn len(&self) -> usize {
self.middlewares.len()
}
pub fn is_empty(&self) -> bool {
self.middlewares.is_empty()
}
pub async fn process_request_with_context(
&self,
request: &mut JSONRPCRequest,
context: &MiddlewareContext,
) -> Result<()> {
context.metrics.inc_requests();
let start_time = Instant::now();
for middleware in &self.middlewares {
if middleware.should_execute(context).await {
middleware.on_chain_start(context).await?;
}
}
for middleware in &self.middlewares {
if middleware.should_execute(context).await {
if let Err(e) = middleware.on_request_with_context(request, context).await {
context.metrics.inc_errors();
for m in &self.middlewares {
if m.should_execute(context).await {
let _ = m.on_error(&e, context).await;
}
}
return Err(e);
}
}
}
for middleware in &self.middlewares {
if middleware.should_execute(context).await {
middleware.on_chain_complete(context).await?;
}
}
context.metrics.add_time(start_time.elapsed());
Ok(())
}
pub async fn process_response_with_context(
&self,
response: &mut JSONRPCResponse,
context: &MiddlewareContext,
) -> Result<()> {
let start_time = Instant::now();
for middleware in self.middlewares.iter().rev() {
if middleware.should_execute(context).await {
if let Err(e) = middleware.on_response_with_context(response, context).await {
context.metrics.inc_errors();
for m in &self.middlewares {
if m.should_execute(context).await {
let _ = m.on_error(&e, context).await;
}
}
return Err(e);
}
}
}
context.metrics.add_time(start_time.elapsed());
Ok(())
}
pub async fn process_send_with_context(
&self,
message: &TransportMessage,
context: &MiddlewareContext,
) -> Result<()> {
let start_time = Instant::now();
for middleware in &self.middlewares {
if middleware.should_execute(context).await {
if let Err(e) = middleware.on_send_with_context(message, context).await {
context.metrics.inc_errors();
for m in &self.middlewares {
if m.should_execute(context).await {
let _ = m.on_error(&e, context).await;
}
}
return Err(e);
}
}
}
context.metrics.add_time(start_time.elapsed());
Ok(())
}
pub async fn process_receive_with_context(
&self,
message: &TransportMessage,
context: &MiddlewareContext,
) -> Result<()> {
let start_time = Instant::now();
for middleware in &self.middlewares {
if middleware.should_execute(context).await {
if let Err(e) = middleware.on_receive_with_context(message, context).await {
context.metrics.inc_errors();
for m in &self.middlewares {
if m.should_execute(context).await {
let _ = m.on_error(&e, context).await;
}
}
return Err(e);
}
}
}
context.metrics.add_time(start_time.elapsed());
Ok(())
}
pub async fn process_notification_with_context(
&self,
notification: &mut JSONRPCNotification,
context: &MiddlewareContext,
) -> Result<()> {
let start_time = Instant::now();
for middleware in &self.middlewares {
if middleware.should_execute(context).await {
if let Err(e) = middleware
.on_notification_with_context(notification, context)
.await
{
context.metrics.inc_errors();
for m in &self.middlewares {
if m.should_execute(context).await {
let _ = m.on_error(&e, context).await;
}
}
return Err(e);
}
}
}
context.metrics.add_time(start_time.elapsed());
Ok(())
}
pub fn get_metrics(&self) -> Vec<Arc<PerformanceMetrics>> {
Vec::new()
}
}
/// Ordered list of legacy [`Middleware`]s, executed in insertion order.
pub struct MiddlewareChain {
    /// Registered middlewares, in the order they were added.
    middlewares: Vec<Arc<dyn Middleware>>,
}

impl fmt::Debug for MiddlewareChain {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Trait objects are not required to be Debug; report only the size.
        let count = self.middlewares.len();
        let mut out = f.debug_struct("MiddlewareChain");
        out.field("count", &count);
        out.finish()
    }
}

impl Default for MiddlewareChain {
    /// An empty chain; same as `MiddlewareChain::new()`.
    fn default() -> Self {
        Self::new()
    }
}
impl MiddlewareChain {
    /// Creates an empty chain.
    pub fn new() -> Self {
        let middlewares = Vec::new();
        Self { middlewares }
    }

    /// Appends `middleware` to the end of the chain.
    pub fn add(&mut self, middleware: Arc<dyn Middleware>) {
        self.middlewares.push(middleware);
    }

    /// Passes `request` to every middleware in order, stopping at the first error.
    pub async fn process_request(&self, request: &mut JSONRPCRequest) -> Result<()> {
        for mw in self.middlewares.iter() {
            mw.on_request(request).await?;
        }
        Ok(())
    }

    /// Passes `response` to every middleware in order, stopping at the first error.
    pub async fn process_response(&self, response: &mut JSONRPCResponse) -> Result<()> {
        for mw in self.middlewares.iter() {
            mw.on_response(response).await?;
        }
        Ok(())
    }

    /// Shows an outgoing `message` to every middleware, stopping at the first error.
    pub async fn process_send(&self, message: &TransportMessage) -> Result<()> {
        for mw in self.middlewares.iter() {
            mw.on_send(message).await?;
        }
        Ok(())
    }

    /// Shows an incoming `message` to every middleware, stopping at the first error.
    pub async fn process_receive(&self, message: &TransportMessage) -> Result<()> {
        for mw in self.middlewares.iter() {
            mw.on_receive(message).await?;
        }
        Ok(())
    }

    /// Passes `notification` to every middleware in order, stopping at the first error.
    pub async fn process_notification(&self, notification: &mut JSONRPCNotification) -> Result<()> {
        for mw in self.middlewares.iter() {
            mw.on_notification(notification).await?;
        }
        Ok(())
    }
}
#[derive(Debug)]
pub struct LoggingMiddleware {
level: tracing::Level,
}
impl LoggingMiddleware {
pub fn new(level: tracing::Level) -> Self {
Self { level }
}
}
impl Default for LoggingMiddleware {
fn default() -> Self {
Self::new(tracing::Level::DEBUG)
}
}
#[async_trait]
impl Middleware for LoggingMiddleware {
async fn on_request(&self, request: &mut JSONRPCRequest) -> Result<()> {
match self.level {
tracing::Level::TRACE => tracing::trace!("Sending request: {:?}", request),
tracing::Level::DEBUG => tracing::debug!("Sending request: {}", request.method),
tracing::Level::INFO => tracing::info!("Sending request: {}", request.method),
tracing::Level::WARN => tracing::warn!("Sending request: {}", request.method),
tracing::Level::ERROR => tracing::error!("Sending request: {}", request.method),
}
Ok(())
}
async fn on_response(&self, response: &mut JSONRPCResponse) -> Result<()> {
match self.level {
tracing::Level::TRACE => tracing::trace!("Received response: {:?}", response),
tracing::Level::DEBUG => tracing::debug!("Received response for: {:?}", response.id),
tracing::Level::INFO => tracing::info!("Received response"),
tracing::Level::WARN => tracing::warn!("Received response"),
tracing::Level::ERROR => tracing::error!("Received response"),
}
Ok(())
}
}
/// Legacy middleware holding an authentication token for outgoing requests.
///
/// `Debug` is implemented by hand so the secret token never ends up in logs
/// or panic messages; it is always shown as `<redacted>`.
pub struct AuthMiddleware {
    // Stored for attaching to requests; nothing reads it yet, hence the allow.
    #[allow(dead_code)]
    auth_token: String,
}

impl fmt::Debug for AuthMiddleware {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("AuthMiddleware")
            .field("auth_token", &"<redacted>")
            .finish()
    }
}

impl AuthMiddleware {
    /// Creates the middleware with the given secret token.
    pub fn new(auth_token: String) -> Self {
        Self { auth_token }
    }
}
#[async_trait]
impl Middleware for AuthMiddleware {
    /// Logs the method of the request being authenticated.
    ///
    /// NOTE(review): the stored token is not written into `request`; this
    /// hook currently only logs.
    async fn on_request(&self, request: &mut JSONRPCRequest) -> Result<()> {
        let method = &request.method;
        tracing::debug!("Adding authentication to request: {}", method);
        Ok(())
    }
}
/// Legacy middleware carrying a retry policy.
///
/// NOTE(review): only `max_retries` is read, and only for a debug log in
/// `on_request`; this middleware performs no retries itself.
#[derive(Debug)]
pub struct RetryMiddleware {
    /// Maximum number of retries, reported in the debug log.
    max_retries: u32,
    // Configured but not read anywhere yet, hence the allows.
    #[allow(dead_code)]
    initial_delay_ms: u64,
    #[allow(dead_code)]
    max_delay_ms: u64,
}

impl RetryMiddleware {
    /// Creates a policy with the given retry count and delay bounds (milliseconds).
    pub fn new(max_retries: u32, initial_delay_ms: u64, max_delay_ms: u64) -> Self {
        Self {
            max_retries,
            initial_delay_ms,
            max_delay_ms,
        }
    }
}

impl Default for RetryMiddleware {
    /// 3 retries, 1000 ms initial delay, 30000 ms maximum delay.
    fn default() -> Self {
        let (retries, initial_ms, max_ms) = (3, 1000, 30000);
        Self::new(retries, initial_ms, max_ms)
    }
}

#[async_trait]
impl Middleware for RetryMiddleware {
    async fn on_request(&self, request: &mut JSONRPCRequest) -> Result<()> {
        let retries = self.max_retries;
        tracing::debug!(
            "Request {} configured with max {} retries",
            request.method,
            retries
        );
        Ok(())
    }
}
/// Token-bucket rate limiter.
///
/// The bucket starts full with `bucket_size` tokens. Every whole elapsed
/// `refill_duration` adds `max_requests` tokens, capped at `bucket_size`.
/// Each admitted request consumes one token.
#[derive(Debug)]
pub struct RateLimitMiddleware {
    /// Tokens credited per elapsed `refill_duration`.
    max_requests: u32,
    /// Capacity of the bucket.
    bucket_size: u32,
    /// Length of one refill period.
    refill_duration: Duration,
    /// Tokens currently available.
    tokens: Arc<AtomicUsize>,
    /// Start of the refill period that has not yet been credited.
    last_refill: Arc<RwLock<Instant>>,
}

impl RateLimitMiddleware {
    /// Creates a limiter whose bucket starts full.
    pub fn new(max_requests: u32, bucket_size: u32, refill_duration: Duration) -> Self {
        Self {
            max_requests,
            bucket_size,
            refill_duration,
            tokens: Arc::new(AtomicUsize::new(bucket_size as usize)),
            last_refill: Arc::new(RwLock::new(Instant::now())),
        }
    }

    /// Credits tokens for every whole elapsed period, then tries to take one.
    ///
    /// Returns `true` if a token was taken (request admitted) and `false` if
    /// the bucket is empty.
    fn check_rate_limit(&self) -> bool {
        let now = Instant::now();
        // The guard lives until the function returns. Refill and consumption
        // are therefore serialized, and the refill `store` can never
        // overwrite a concurrent token decrement.
        let mut last_refill = self.last_refill.write();
        let period_ns = self.refill_duration.as_nanos();
        if period_ns == 0 {
            // A zero-length period refills the bucket on every call. The
            // division below needs a non-zero divisor, and nanosecond
            // resolution keeps sub-millisecond periods from hitting zero.
            self.tokens.store(self.bucket_size as usize, Ordering::Relaxed);
            *last_refill = now;
        } else {
            let periods = now.duration_since(*last_refill).as_nanos() / period_ns;
            if periods > 0 {
                // u128 + saturating arithmetic: no overflow for long idle times.
                let tokens_to_add = periods
                    .saturating_mul(u128::from(self.max_requests))
                    .min(u128::from(self.bucket_size)) as usize;
                let refilled = self
                    .tokens
                    .load(Ordering::Relaxed)
                    .saturating_add(tokens_to_add)
                    .min(self.bucket_size as usize);
                self.tokens.store(refilled, Ordering::Relaxed);
                // Advance by whole periods only. The partially elapsed period
                // then counts toward the next refill instead of being thrown
                // away. `periods * refill_duration <= elapsed`, so this cannot
                // overflow or move past `now`.
                *last_refill = if periods <= u128::from(u32::MAX) {
                    *last_refill + self.refill_duration * periods as u32
                } else {
                    now
                };
            }
        }
        loop {
            let current = self.tokens.load(Ordering::Relaxed);
            if current == 0 {
                return false;
            }
            if self
                .tokens
                .compare_exchange_weak(current, current - 1, Ordering::Relaxed, Ordering::Relaxed)
                .is_ok()
            {
                return true;
            }
        }
    }
}
#[async_trait]
impl AdvancedMiddleware for RateLimitMiddleware {
fn name(&self) -> &'static str {
"rate_limit"
}
fn priority(&self) -> MiddlewarePriority {
MiddlewarePriority::High
}
async fn on_request_with_context(
&self,
request: &mut JSONRPCRequest,
context: &MiddlewareContext,
) -> Result<()> {
if !self.check_rate_limit() {
tracing::warn!("Rate limit exceeded for request: {}", request.method);
context.record_metric("rate_limit_exceeded".to_string(), 1.0);
return Err(crate::error::Error::RateLimited);
}
tracing::debug!("Rate limit check passed for request: {}", request.method);
context.record_metric("rate_limit_passed".to_string(), 1.0);
Ok(())
}
}
/// Circuit breaker that rejects requests after repeated failures.
///
/// `record_failure` counts failures. Once `failure_threshold` is reached, the
/// circuit opens and `should_allow_request` returns `false`. The first
/// request arriving more than `timeout_duration` after the circuit opened
/// closes it again and resets the count. A count whose last failure is older
/// than `time_window` is also reset.
#[derive(Debug)]
pub struct CircuitBreakerMiddleware {
    /// Failures needed to open the circuit.
    failure_threshold: u32,
    /// Age after which the last failure no longer counts toward the threshold.
    time_window: Duration,
    /// How long the circuit stays open before letting a request through.
    timeout_duration: Duration,
    /// Failures since the last reset.
    failure_count: Arc<AtomicU64>,
    /// Time of the most recent failure.
    last_failure: Arc<RwLock<Option<Instant>>>,
    /// When the circuit opened; `None` while closed.
    circuit_open_time: Arc<RwLock<Option<Instant>>>,
}

impl CircuitBreakerMiddleware {
    /// Creates a closed circuit breaker.
    pub fn new(failure_threshold: u32, time_window: Duration, timeout_duration: Duration) -> Self {
        Self {
            failure_threshold,
            time_window,
            timeout_duration,
            failure_count: Arc::new(AtomicU64::new(0)),
            last_failure: Arc::new(RwLock::new(None)),
            circuit_open_time: Arc::new(RwLock::new(None)),
        }
    }

    /// Decides whether a request may proceed, closing an open circuit whose
    /// timeout has elapsed and resetting a failure count that has gone stale.
    fn should_allow_request(&self) -> bool {
        let now = Instant::now();
        let open_time_value = *self.circuit_open_time.read();
        if let Some(open_time) = open_time_value {
            if now.duration_since(open_time) > self.timeout_duration {
                *self.circuit_open_time.write() = None;
                self.failure_count.store(0, Ordering::Relaxed);
                return true;
            }
            return false;
        }
        let last_failure_value = *self.last_failure.read();
        if let Some(last_failure) = last_failure_value {
            if now.duration_since(last_failure) > self.time_window {
                self.failure_count.store(0, Ordering::Relaxed);
            }
        }
        self.failure_count.load(Ordering::Relaxed) < self.failure_threshold as u64
    }

    /// Counts one failure and opens the circuit once the threshold is reached.
    ///
    /// Only the failure that actually opens the circuit stamps
    /// `circuit_open_time`. Failures recorded while it is already open must
    /// not push the reopen deadline forward. Otherwise steady traffic would
    /// keep the circuit open indefinitely.
    fn record_failure(&self) {
        let now = Instant::now();
        let failures = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
        *self.last_failure.write() = Some(now);
        if failures >= self.failure_threshold as u64 {
            let mut open_time = self.circuit_open_time.write();
            if open_time.is_none() {
                *open_time = Some(now);
                tracing::warn!("Circuit breaker opened due to {} failures", failures);
            }
        }
    }
}
#[async_trait]
impl AdvancedMiddleware for CircuitBreakerMiddleware {
fn name(&self) -> &'static str {
"circuit_breaker"
}
fn priority(&self) -> MiddlewarePriority {
MiddlewarePriority::High
}
async fn on_request_with_context(
&self,
request: &mut JSONRPCRequest,
context: &MiddlewareContext,
) -> Result<()> {
if !self.should_allow_request() {
tracing::warn!(
"Circuit breaker open, rejecting request: {}",
request.method
);
context.record_metric("circuit_breaker_open".to_string(), 1.0);
return Err(crate::error::Error::CircuitBreakerOpen);
}
context.record_metric("circuit_breaker_allowed".to_string(), 1.0);
Ok(())
}
async fn on_error(
&self,
_error: &crate::error::Error,
_context: &MiddlewareContext,
) -> Result<()> {
self.record_failure();
Ok(())
}
}
/// Per-method request counts, accumulated durations and error counts.
#[derive(Debug)]
pub struct MetricsMiddleware {
    /// Label written to logs and to the `service_name` context metadata.
    service_name: String,
    /// Requests seen, keyed by method name.
    request_counts: Arc<DashMap<String, AtomicU64>>,
    /// Summed response durations, keyed by method name.
    request_durations: Arc<DashMap<String, AtomicU64>>,
    /// Errors seen, keyed by method name.
    error_counts: Arc<DashMap<String, AtomicU64>>,
}

impl MetricsMiddleware {
    /// Creates an empty metrics collector labelled `service_name`.
    pub fn new(service_name: String) -> Self {
        Self {
            service_name,
            request_counts: Arc::new(DashMap::new()),
            request_durations: Arc::new(DashMap::new()),
            error_counts: Arc::new(DashMap::new()),
        }
    }

    /// Reads the counter stored for `method` in `table`, or 0 if absent.
    fn load_counter(table: &DashMap<String, AtomicU64>, method: &str) -> u64 {
        table
            .get(method)
            .map_or(0, |counter| counter.value().load(Ordering::Relaxed))
    }

    /// Number of requests recorded for `method`.
    pub fn get_request_count(&self, method: &str) -> u64 {
        Self::load_counter(&self.request_counts, method)
    }

    /// Number of errors recorded for `method`.
    pub fn get_error_count(&self, method: &str) -> u64 {
        Self::load_counter(&self.error_counts, method)
    }

    /// Summed duration for `method` divided (integer division) by its request
    /// count; 0 when no request was recorded.
    pub fn get_average_duration(&self, method: &str) -> u64 {
        let total = Self::load_counter(&self.request_durations, method);
        match self.get_request_count(method) {
            0 => 0,
            count => total / count,
        }
    }
}
#[async_trait]
impl AdvancedMiddleware for MetricsMiddleware {
    fn name(&self) -> &'static str {
        "metrics"
    }

    fn priority(&self) -> MiddlewarePriority {
        MiddlewarePriority::Low
    }

    /// Counts the request for its method. It also stores the method name,
    /// the time since context creation and the service name in the context
    /// metadata.
    async fn on_request_with_context(
        &self,
        request: &mut JSONRPCRequest,
        context: &MiddlewareContext,
    ) -> Result<()> {
        self.request_counts
            .entry(request.method.clone())
            .or_insert_with(|| AtomicU64::new(0))
            .fetch_add(1, Ordering::Relaxed);
        // `on_response_with_context` and `on_error` attribute durations and
        // errors to the method stored under this key. Without it they never
        // recorded anything. This assumes the same context is later passed to
        // those hooks.
        context.set_metadata("method".to_string(), request.method.clone());
        context.set_metadata(
            "request_start_time".to_string(),
            context.start_time.elapsed().as_micros().to_string(),
        );
        context.set_metadata("service_name".to_string(), self.service_name.clone());
        tracing::debug!(
            "Metrics recorded for request: {} (service: {})",
            request.method,
            self.service_name
        );
        Ok(())
    }

    /// Adds the time since context creation (µs) to the method's duration
    /// total, when the context carries a `method` entry.
    async fn on_response_with_context(
        &self,
        response: &mut JSONRPCResponse,
        context: &MiddlewareContext,
    ) -> Result<()> {
        let duration_us = context.elapsed().as_micros() as u64;
        if let Some(method) = context.get_metadata("method") {
            self.request_durations
                .entry(method)
                .or_insert_with(|| AtomicU64::new(0))
                .fetch_add(duration_us, Ordering::Relaxed);
        }
        tracing::debug!(
            "Response metrics recorded for ID: {:?} ({}μs)",
            response.id,
            duration_us
        );
        Ok(())
    }

    /// Counts the error against the method stored in the context, if any.
    /// Errors raised before this middleware's request hook ran find no
    /// `method` entry and are only logged.
    async fn on_error(
        &self,
        error: &crate::error::Error,
        context: &MiddlewareContext,
    ) -> Result<()> {
        if let Some(method) = context.get_metadata("method") {
            self.error_counts
                .entry(method)
                .or_insert_with(|| AtomicU64::new(0))
                .fetch_add(1, Ordering::Relaxed);
        }
        tracing::warn!("Error recorded in metrics: {:?}", error);
        Ok(())
    }
}
/// Compression algorithm selected for outgoing messages.
#[derive(Debug, Clone, Copy)]
pub enum CompressionType {
    /// Compression disabled.
    None,
    /// Gzip.
    Gzip,
    /// Deflate.
    Deflate,
}

/// Size-gated compression settings for outgoing messages.
#[derive(Debug)]
pub struct CompressionMiddleware {
    /// Algorithm to apply; `None` disables compression entirely.
    compression_type: CompressionType,
    /// Smallest payload size (bytes) that qualifies for compression.
    min_size: usize,
}

impl CompressionMiddleware {
    /// Creates the middleware with the given algorithm and size threshold.
    pub fn new(compression_type: CompressionType, min_size: usize) -> Self {
        Self {
            compression_type,
            min_size,
        }
    }

    /// `true` when an algorithm is selected and `content_size` is at least
    /// `min_size`.
    fn should_compress(&self, content_size: usize) -> bool {
        match self.compression_type {
            CompressionType::None => false,
            CompressionType::Gzip | CompressionType::Deflate => content_size >= self.min_size,
        }
    }
}
#[async_trait]
impl AdvancedMiddleware for CompressionMiddleware {
    fn name(&self) -> &'static str {
        "compression"
    }

    fn priority(&self) -> MiddlewarePriority {
        MiddlewarePriority::Normal
    }

    /// Measures the JSON-serialized size of `message`. If it qualifies for
    /// compression, records the algorithm in the context metadata and the
    /// original size as a metric. The message itself is not modified.
    ///
    /// A message that fails to serialize is treated as 0 bytes.
    async fn on_send_with_context(
        &self,
        message: &TransportMessage,
        context: &MiddlewareContext,
    ) -> Result<()> {
        let content_size = match serde_json::to_string(message) {
            Ok(json) => json.len(),
            Err(_) => 0,
        };
        if !self.should_compress(content_size) {
            return Ok(());
        }
        context.set_metadata(
            "compression_type".to_string(),
            format!("{:?}", self.compression_type),
        );
        context.record_metric("compression_original_size".to_string(), content_size as f64);
        tracing::debug!("Compression applied to message of {} bytes", content_size);
        Ok(())
    }
}
// Unit tests for the legacy `MiddlewareChain`, `AuthMiddleware`, and the
// notification path of `EnhancedMiddlewareChain`.
#[cfg(test)]
mod tests {
use super::*;
use crate::types::RequestId;
// A legacy chain with only the logging middleware accepts a plain request.
#[tokio::test]
async fn test_middleware_chain() {
let mut chain = MiddlewareChain::new();
chain.add(Arc::new(LoggingMiddleware::default()));
let mut request = JSONRPCRequest {
jsonrpc: "2.0".to_string(),
id: RequestId::from(1i64),
method: "test".to_string(),
params: None,
};
assert!(chain.process_request(&mut request).await.is_ok());
}
// `AuthMiddleware::on_request` succeeds on a plain request.
#[tokio::test]
async fn test_auth_middleware() {
let middleware = AuthMiddleware::new("test-token".to_string());
let mut request = JSONRPCRequest {
jsonrpc: "2.0".to_string(),
id: RequestId::from(1i64),
method: "test".to_string(),
params: None,
};
assert!(middleware.on_request(&mut request).await.is_ok());
}
// Notifications pass through a legacy chain; LoggingMiddleware uses the
// no-op default `on_notification`.
#[tokio::test]
async fn test_notification_middleware_legacy() {
let mut chain = MiddlewareChain::new();
chain.add(Arc::new(LoggingMiddleware::default()));
let mut notification = JSONRPCNotification::new(
"notifications/progress",
Some(serde_json::json!({
"progressToken": "test-123",
"progress": 50,
"total": 100
})),
);
assert!(chain.process_notification(&mut notification).await.is_ok());
}
// Notification processing in the enhanced chain must not bump the request
// counter (only `process_request_with_context` does that).
#[tokio::test]
async fn test_notification_middleware_enhanced() {
let mut chain = EnhancedMiddlewareChain::new();
chain.add(Arc::new(MetricsMiddleware::new("test-service".to_string())));
let context = MiddlewareContext::with_request_id("notif-001".to_string());
let mut notification = JSONRPCNotification::new(
"notifications/resourceUpdated",
Some(serde_json::json!({
"uri": "file:///test.txt",
"type": "modified"
})),
);
assert!(chain
.process_notification_with_context(&mut notification, &context)
.await
.is_ok());
let stats = context.metrics;
assert_eq!(stats.request_count(), 0);
}
// Test helper: copies the notification method and a fixed tag into the
// context metadata so the test can observe that the hook ran.
struct NotificationMetadataMiddleware {
tag: String,
}
#[async_trait::async_trait]
impl AdvancedMiddleware for NotificationMetadataMiddleware {
fn name(&self) -> &'static str {
"notification_metadata"
}
async fn on_notification_with_context(
&self,
notification: &mut JSONRPCNotification,
context: &MiddlewareContext,
) -> Result<()> {
context.set_metadata(
"notification_method".to_string(),
notification.method.clone(),
);
context.set_metadata("middleware_tag".to_string(), self.tag.clone());
Ok(())
}
}
// Metadata written by a notification hook is visible on the caller's context.
#[tokio::test]
async fn test_notification_metadata_middleware() {
let mut chain = EnhancedMiddlewareChain::new();
chain.add(Arc::new(NotificationMetadataMiddleware {
tag: "test-tag".to_string(),
}));
let context = MiddlewareContext::with_request_id("notif-002".to_string());
let mut notification = JSONRPCNotification::new(
"notifications/cancelled",
Some(serde_json::json!({
"requestId": "req-123",
"reason": "user cancelled"
})),
);
chain
.process_notification_with_context(&mut notification, &context)
.await
.unwrap();
assert_eq!(
context.get_metadata("notification_method"),
Some("notifications/cancelled".to_string())
);
assert_eq!(
context.get_metadata("middleware_tag"),
Some("test-tag".to_string())
);
}
// A failing notification hook propagates its error and increments the
// shared error counter exactly once; a non-failing one leaves it untouched.
#[tokio::test]
async fn test_notification_error_handling() {
// Fails only for the "notifications/error" method.
struct FailingNotificationMiddleware;
#[async_trait::async_trait]
impl AdvancedMiddleware for FailingNotificationMiddleware {
fn name(&self) -> &'static str {
"failing_notification"
}
async fn on_notification_with_context(
&self,
notification: &mut JSONRPCNotification,
_context: &MiddlewareContext,
) -> Result<()> {
if notification.method == "notifications/error" {
return Err(crate::Error::internal("notification processing failed"));
}
Ok(())
}
}
let mut chain = EnhancedMiddlewareChain::new();
chain.add(Arc::new(FailingNotificationMiddleware));
let context = MiddlewareContext::default();
let mut ok_notification =
JSONRPCNotification::new("notifications/ok", None::<serde_json::Value>);
assert!(chain
.process_notification_with_context(&mut ok_notification, &context)
.await
.is_ok());
let mut error_notification =
JSONRPCNotification::new("notifications/error", None::<serde_json::Value>);
let result = chain
.process_notification_with_context(&mut error_notification, &context)
.await;
assert!(result.is_err());
assert_eq!(context.metrics.error_count(), 1);
}
}