Skip to main content

brainwires_agent_network/middleware/
mod.rs

1/// Authentication middleware.
2pub mod auth;
3/// Request logging middleware.
4pub mod logging;
5/// Rate limiting middleware.
6pub mod rate_limit;
7/// Tool filtering middleware.
8pub mod tool_filter;
9
10use anyhow::Result;
11use async_trait::async_trait;
12use brainwires_mcp::{JsonRpcError, JsonRpcRequest, JsonRpcResponse};
13
14use crate::connection::RequestContext;
15
16/// Result of middleware processing.
17pub enum MiddlewareResult {
18    /// Allow the request to continue.
19    Continue,
20    /// Reject the request with an error.
21    Reject(JsonRpcError),
22}
23
24/// Trait for request/response middleware.
25#[async_trait]
26pub trait Middleware: Send + Sync + 'static {
27    /// Process an incoming request. Return `Continue` or `Reject`.
28    async fn process_request(
29        &self,
30        request: &JsonRpcRequest,
31        ctx: &mut RequestContext,
32    ) -> MiddlewareResult;
33
34    /// Optionally process the outgoing response (no-op by default).
35    async fn process_response(&self, _response: &mut JsonRpcResponse, _ctx: &RequestContext) {}
36}
37
38/// Ordered chain of middleware layers.
39pub struct MiddlewareChain {
40    layers: Vec<Box<dyn Middleware>>,
41}
42
43impl MiddlewareChain {
44    /// Create a new empty middleware chain.
45    pub fn new() -> Self {
46        Self { layers: Vec::new() }
47    }
48
49    /// Add a middleware layer to the chain.
50    pub fn add(&mut self, middleware: impl Middleware) {
51        self.layers.push(Box::new(middleware));
52    }
53
54    /// Run all middleware on the request, stopping on first reject.
55    pub async fn process_request(
56        &self,
57        request: &JsonRpcRequest,
58        ctx: &mut RequestContext,
59    ) -> Result<(), JsonRpcError> {
60        for layer in &self.layers {
61            match layer.process_request(request, ctx).await {
62                MiddlewareResult::Continue => continue,
63                MiddlewareResult::Reject(err) => return Err(err),
64            }
65        }
66        Ok(())
67    }
68
69    /// Run all middleware on the response.
70    pub async fn process_response(&self, response: &mut JsonRpcResponse, ctx: &RequestContext) {
71        for layer in &self.layers {
72            layer.process_response(response, ctx).await;
73        }
74    }
75}
76
77impl Default for MiddlewareChain {
78    fn default() -> Self {
79        Self::new()
80    }
81}
82
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Middleware that accepts every request.
    struct PassMiddleware;

    #[async_trait]
    impl Middleware for PassMiddleware {
        async fn process_request(
            &self,
            _request: &JsonRpcRequest,
            _ctx: &mut RequestContext,
        ) -> MiddlewareResult {
            MiddlewareResult::Continue
        }
    }

    /// Middleware that rejects every request with error code `-32003`.
    struct RejectMiddleware;

    #[async_trait]
    impl Middleware for RejectMiddleware {
        async fn process_request(
            &self,
            _request: &JsonRpcRequest,
            _ctx: &mut RequestContext,
        ) -> MiddlewareResult {
            MiddlewareResult::Reject(JsonRpcError {
                code: -32003,
                message: "Rejected".to_string(),
                data: None,
            })
        }
    }

    /// Middleware that accepts every request and records how many times it was invoked,
    /// so tests can observe which layers the chain actually reached.
    struct CountingMiddleware {
        calls: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl Middleware for CountingMiddleware {
        async fn process_request(
            &self,
            _request: &JsonRpcRequest,
            _ctx: &mut RequestContext,
        ) -> MiddlewareResult {
            self.calls.fetch_add(1, Ordering::SeqCst);
            MiddlewareResult::Continue
        }
    }

    /// Minimal request fixture shared by the tests below.
    fn sample_request() -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: json!(1),
            method: "test".to_string(),
            params: None,
        }
    }

    #[tokio::test]
    async fn test_chain_all_pass() {
        let mut chain = MiddlewareChain::new();
        chain.add(PassMiddleware);
        chain.add(PassMiddleware);

        let request = sample_request();
        let mut ctx = RequestContext::new(json!(1));
        assert!(chain.process_request(&request, &mut ctx).await.is_ok());
    }

    #[tokio::test]
    async fn test_empty_chain_passes() {
        let chain = MiddlewareChain::new();

        let request = sample_request();
        let mut ctx = RequestContext::new(json!(1));
        assert!(chain.process_request(&request, &mut ctx).await.is_ok());
    }

    #[tokio::test]
    async fn test_chain_reject_stops() {
        let before = Arc::new(AtomicUsize::new(0));
        let after = Arc::new(AtomicUsize::new(0));

        let mut chain = MiddlewareChain::new();
        chain.add(CountingMiddleware {
            calls: Arc::clone(&before),
        });
        chain.add(RejectMiddleware);
        chain.add(CountingMiddleware {
            calls: Arc::clone(&after),
        });

        let request = sample_request();
        let mut ctx = RequestContext::new(json!(1));
        let err = chain
            .process_request(&request, &mut ctx)
            .await
            .expect_err("chain containing a rejecting layer must fail");
        assert_eq!(err.code, -32003);

        // The layer ahead of the rejection runs exactly once; the one behind it never runs.
        assert_eq!(before.load(Ordering::SeqCst), 1);
        assert_eq!(after.load(Ordering::SeqCst), 0);
    }
}