brainwires_agent_network/middleware/
mod.rs1pub mod auth;
3pub mod logging;
5pub mod rate_limit;
7pub mod tool_filter;
9
10use anyhow::Result;
11use async_trait::async_trait;
12use brainwires_mcp::{JsonRpcError, JsonRpcRequest, JsonRpcResponse};
13
14use crate::connection::RequestContext;
15
16pub enum MiddlewareResult {
18 Continue,
20 Reject(JsonRpcError),
22}
23
24#[async_trait]
26pub trait Middleware: Send + Sync + 'static {
27 async fn process_request(
29 &self,
30 request: &JsonRpcRequest,
31 ctx: &mut RequestContext,
32 ) -> MiddlewareResult;
33
34 async fn process_response(&self, _response: &mut JsonRpcResponse, _ctx: &RequestContext) {}
36}
37
38pub struct MiddlewareChain {
40 layers: Vec<Box<dyn Middleware>>,
41}
42
43impl MiddlewareChain {
44 pub fn new() -> Self {
46 Self { layers: Vec::new() }
47 }
48
49 pub fn add(&mut self, middleware: impl Middleware) {
51 self.layers.push(Box::new(middleware));
52 }
53
54 pub async fn process_request(
56 &self,
57 request: &JsonRpcRequest,
58 ctx: &mut RequestContext,
59 ) -> Result<(), JsonRpcError> {
60 for layer in &self.layers {
61 match layer.process_request(request, ctx).await {
62 MiddlewareResult::Continue => continue,
63 MiddlewareResult::Reject(err) => return Err(err),
64 }
65 }
66 Ok(())
67 }
68
69 pub async fn process_response(&self, response: &mut JsonRpcResponse, ctx: &RequestContext) {
71 for layer in &self.layers {
72 layer.process_response(response, ctx).await;
73 }
74 }
75}
76
77impl Default for MiddlewareChain {
78 fn default() -> Self {
79 Self::new()
80 }
81}
82
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Middleware that lets every request through.
    struct PassMiddleware;

    #[async_trait]
    impl Middleware for PassMiddleware {
        async fn process_request(
            &self,
            _request: &JsonRpcRequest,
            _ctx: &mut RequestContext,
        ) -> MiddlewareResult {
            MiddlewareResult::Continue
        }
    }

    /// Middleware that rejects every request with code `-32003`.
    struct RejectMiddleware;

    #[async_trait]
    impl Middleware for RejectMiddleware {
        async fn process_request(
            &self,
            _request: &JsonRpcRequest,
            _ctx: &mut RequestContext,
        ) -> MiddlewareResult {
            MiddlewareResult::Reject(JsonRpcError {
                code: -32003,
                message: "Rejected".to_string(),
                data: None,
            })
        }
    }

    /// Pass-through middleware that records how many requests it has seen, so
    /// tests can observe which layers of a chain were actually invoked.
    struct CountingMiddleware {
        calls: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl Middleware for CountingMiddleware {
        async fn process_request(
            &self,
            _request: &JsonRpcRequest,
            _ctx: &mut RequestContext,
        ) -> MiddlewareResult {
            self.calls.fetch_add(1, Ordering::SeqCst);
            MiddlewareResult::Continue
        }
    }

    /// Builds the minimal request shared by all tests.
    fn sample_request() -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: json!(1),
            method: "test".to_string(),
            params: None,
        }
    }

    #[tokio::test]
    async fn test_chain_all_pass() {
        let mut chain = MiddlewareChain::new();
        chain.add(PassMiddleware);
        chain.add(PassMiddleware);

        let request = sample_request();
        let mut ctx = RequestContext::new(json!(1));
        assert!(chain.process_request(&request, &mut ctx).await.is_ok());
    }

    #[tokio::test]
    async fn test_chain_empty_passes() {
        let chain = MiddlewareChain::default();

        let request = sample_request();
        let mut ctx = RequestContext::new(json!(1));
        assert!(chain.process_request(&request, &mut ctx).await.is_ok());
    }

    #[tokio::test]
    async fn test_chain_reject_stops() {
        let before = Arc::new(AtomicUsize::new(0));
        let after = Arc::new(AtomicUsize::new(0));

        let mut chain = MiddlewareChain::new();
        chain.add(CountingMiddleware {
            calls: Arc::clone(&before),
        });
        chain.add(RejectMiddleware);
        chain.add(CountingMiddleware {
            calls: Arc::clone(&after),
        });

        let request = sample_request();
        let mut ctx = RequestContext::new(json!(1));
        let result = chain.process_request(&request, &mut ctx).await;
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code, -32003);

        // The layer ahead of the rejecter ran exactly once; the layer behind it
        // must never have been reached.
        assert_eq!(before.load(Ordering::SeqCst), 1);
        assert_eq!(after.load(Ordering::SeqCst), 0);
    }
}