brainwires_mcp_server/middleware/
mod.rs1pub mod auth;
3pub mod logging;
5#[cfg(feature = "oauth")]
7pub mod oauth;
8pub mod rate_limit;
10pub mod tool_filter;
12
13use anyhow::Result;
14use async_trait::async_trait;
15use brainwires_mcp::{JsonRpcError, JsonRpcRequest, JsonRpcResponse};
16
17use crate::connection::RequestContext;
18
19pub enum MiddlewareResult {
21 Continue,
23 Reject(JsonRpcError),
25}
26
27#[async_trait]
29pub trait Middleware: Send + Sync + 'static {
30 async fn process_request(
32 &self,
33 request: &JsonRpcRequest,
34 ctx: &mut RequestContext,
35 ) -> MiddlewareResult;
36
37 async fn process_response(&self, _response: &mut JsonRpcResponse, _ctx: &RequestContext) {}
39}
40
41pub struct MiddlewareChain {
43 layers: Vec<Box<dyn Middleware>>,
44}
45
46impl MiddlewareChain {
47 pub fn new() -> Self {
49 Self { layers: Vec::new() }
50 }
51
52 pub fn add(&mut self, middleware: impl Middleware) {
54 self.layers.push(Box::new(middleware));
55 }
56
57 pub async fn process_request(
59 &self,
60 request: &JsonRpcRequest,
61 ctx: &mut RequestContext,
62 ) -> Result<(), JsonRpcError> {
63 for layer in &self.layers {
64 match layer.process_request(request, ctx).await {
65 MiddlewareResult::Continue => continue,
66 MiddlewareResult::Reject(err) => return Err(err),
67 }
68 }
69 Ok(())
70 }
71
72 pub async fn process_response(&self, response: &mut JsonRpcResponse, ctx: &RequestContext) {
74 for layer in &self.layers {
75 layer.process_response(response, ctx).await;
76 }
77 }
78}
79
80impl Default for MiddlewareChain {
81 fn default() -> Self {
82 Self::new()
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Middleware that always lets the request through.
    struct PassMiddleware;

    #[async_trait]
    impl Middleware for PassMiddleware {
        async fn process_request(
            &self,
            _request: &JsonRpcRequest,
            _ctx: &mut RequestContext,
        ) -> MiddlewareResult {
            MiddlewareResult::Continue
        }
    }

    /// Middleware that always rejects with JSON-RPC error code -32003.
    struct RejectMiddleware;

    #[async_trait]
    impl Middleware for RejectMiddleware {
        async fn process_request(
            &self,
            _request: &JsonRpcRequest,
            _ctx: &mut RequestContext,
        ) -> MiddlewareResult {
            MiddlewareResult::Reject(JsonRpcError {
                code: -32003,
                message: "Rejected".to_string(),
                data: None,
            })
        }
    }

    /// Pass-through middleware that counts its invocations, so tests can
    /// observe whether the chain actually reached it.
    struct CountingMiddleware {
        calls: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl Middleware for CountingMiddleware {
        async fn process_request(
            &self,
            _request: &JsonRpcRequest,
            _ctx: &mut RequestContext,
        ) -> MiddlewareResult {
            self.calls.fetch_add(1, Ordering::SeqCst);
            MiddlewareResult::Continue
        }
    }

    /// Minimal JSON-RPC request shared by the tests.
    fn test_request() -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: json!(1),
            method: "test".to_string(),
            params: None,
        }
    }

    #[tokio::test]
    async fn test_empty_chain_passes() {
        let chain = MiddlewareChain::default();
        let request = test_request();
        let mut ctx = RequestContext::new(json!(1));
        assert!(chain.process_request(&request, &mut ctx).await.is_ok());
    }

    #[tokio::test]
    async fn test_chain_all_pass() {
        let calls = Arc::new(AtomicUsize::new(0));
        let mut chain = MiddlewareChain::new();
        chain.add(PassMiddleware);
        chain.add(PassMiddleware);
        chain.add(CountingMiddleware {
            calls: Arc::clone(&calls),
        });

        let request = test_request();
        let mut ctx = RequestContext::new(json!(1));
        assert!(chain.process_request(&request, &mut ctx).await.is_ok());
        // Every layer must have been visited, including the last one.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_chain_reject_stops() {
        let before = Arc::new(AtomicUsize::new(0));
        let after = Arc::new(AtomicUsize::new(0));
        let mut chain = MiddlewareChain::new();
        chain.add(CountingMiddleware {
            calls: Arc::clone(&before),
        });
        chain.add(RejectMiddleware);
        chain.add(CountingMiddleware {
            calls: Arc::clone(&after),
        });

        let request = test_request();
        let mut ctx = RequestContext::new(json!(1));
        let result = chain.process_request(&request, &mut ctx).await;
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code, -32003);
        // Layers before the rejection ran; layers after it must not have.
        assert_eq!(before.load(Ordering::SeqCst), 1);
        assert_eq!(after.load(Ordering::SeqCst), 0);
    }
}