Skip to main content

brainwires_mcp_server/middleware/mod.rs

1/// Authentication middleware.
2pub mod auth;
3/// Request logging middleware.
4pub mod logging;
5/// OAuth 2.1 JWT validation middleware.
6#[cfg(feature = "oauth")]
7pub mod oauth;
8/// Rate limiting middleware.
9pub mod rate_limit;
10/// Tool filtering middleware.
11pub mod tool_filter;
12
13use anyhow::Result;
14use async_trait::async_trait;
15use brainwires_mcp::{JsonRpcError, JsonRpcRequest, JsonRpcResponse};
16
17use crate::connection::RequestContext;
18
19/// Result of middleware processing.
20pub enum MiddlewareResult {
21    /// Allow the request to continue.
22    Continue,
23    /// Reject the request with an error.
24    Reject(JsonRpcError),
25}
26
27/// Trait for request/response middleware.
28#[async_trait]
29pub trait Middleware: Send + Sync + 'static {
30    /// Process an incoming request. Return `Continue` or `Reject`.
31    async fn process_request(
32        &self,
33        request: &JsonRpcRequest,
34        ctx: &mut RequestContext,
35    ) -> MiddlewareResult;
36
37    /// Optionally process the outgoing response (no-op by default).
38    async fn process_response(&self, _response: &mut JsonRpcResponse, _ctx: &RequestContext) {}
39}
40
41/// Ordered chain of middleware layers.
42pub struct MiddlewareChain {
43    layers: Vec<Box<dyn Middleware>>,
44}
45
46impl MiddlewareChain {
47    /// Create a new empty middleware chain.
48    pub fn new() -> Self {
49        Self { layers: Vec::new() }
50    }
51
52    /// Add a middleware layer to the chain.
53    pub fn add(&mut self, middleware: impl Middleware) {
54        self.layers.push(Box::new(middleware));
55    }
56
57    /// Run all middleware on the request, stopping on first reject.
58    pub async fn process_request(
59        &self,
60        request: &JsonRpcRequest,
61        ctx: &mut RequestContext,
62    ) -> Result<(), JsonRpcError> {
63        for layer in &self.layers {
64            match layer.process_request(request, ctx).await {
65                MiddlewareResult::Continue => continue,
66                MiddlewareResult::Reject(err) => return Err(err),
67            }
68        }
69        Ok(())
70    }
71
72    /// Run all middleware on the response.
73    pub async fn process_response(&self, response: &mut JsonRpcResponse, ctx: &RequestContext) {
74        for layer in &self.layers {
75            layer.process_response(response, ctx).await;
76        }
77    }
78}
79
80impl Default for MiddlewareChain {
81    fn default() -> Self {
82        Self::new()
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Always lets the request through.
    struct PassMiddleware;

    #[async_trait]
    impl Middleware for PassMiddleware {
        async fn process_request(
            &self,
            _request: &JsonRpcRequest,
            _ctx: &mut RequestContext,
        ) -> MiddlewareResult {
            MiddlewareResult::Continue
        }
    }

    /// Always rejects with error code -32003.
    struct RejectMiddleware;

    #[async_trait]
    impl Middleware for RejectMiddleware {
        async fn process_request(
            &self,
            _request: &JsonRpcRequest,
            _ctx: &mut RequestContext,
        ) -> MiddlewareResult {
            MiddlewareResult::Reject(JsonRpcError {
                code: -32003,
                message: "Rejected".to_string(),
                data: None,
            })
        }
    }

    /// Lets the request through and records how many times it was invoked,
    /// so tests can verify which layers actually ran.
    struct CountingMiddleware {
        calls: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl Middleware for CountingMiddleware {
        async fn process_request(
            &self,
            _request: &JsonRpcRequest,
            _ctx: &mut RequestContext,
        ) -> MiddlewareResult {
            self.calls.fetch_add(1, Ordering::SeqCst);
            MiddlewareResult::Continue
        }
    }

    /// Minimal JSON-RPC request shared by the tests.
    fn test_request() -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: json!(1),
            method: "test".to_string(),
            params: None,
        }
    }

    #[tokio::test]
    async fn test_empty_chain_passes() {
        let chain = MiddlewareChain::default();
        let mut ctx = RequestContext::new(json!(1));
        assert!(chain.process_request(&test_request(), &mut ctx).await.is_ok());
    }

    #[tokio::test]
    async fn test_chain_all_pass() {
        let calls = Arc::new(AtomicUsize::new(0));
        let mut chain = MiddlewareChain::new();
        chain.add(PassMiddleware);
        chain.add(PassMiddleware);
        chain.add(CountingMiddleware {
            calls: Arc::clone(&calls),
        });

        let mut ctx = RequestContext::new(json!(1));
        assert!(chain.process_request(&test_request(), &mut ctx).await.is_ok());
        // Every layer must have been reached.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_chain_reject_stops() {
        let before = Arc::new(AtomicUsize::new(0));
        let after = Arc::new(AtomicUsize::new(0));
        let mut chain = MiddlewareChain::new();
        chain.add(CountingMiddleware {
            calls: Arc::clone(&before),
        });
        chain.add(RejectMiddleware);
        chain.add(CountingMiddleware {
            calls: Arc::clone(&after),
        });

        let mut ctx = RequestContext::new(json!(1));
        let result = chain.process_request(&test_request(), &mut ctx).await;
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code, -32003);
        // Layers before the rejecting one ran; layers after it must not.
        assert_eq!(before.load(Ordering::SeqCst), 1);
        assert_eq!(after.load(Ordering::SeqCst), 0);
    }
}