brainwires_agent_network/
server.rs1use anyhow::Result;
2use brainwires_mcp::{InitializeParams, InitializeResult, JsonRpcRequest, JsonRpcResponse};
3use serde_json::{Value, json};
4use tracing;
5
6use crate::connection::{ClientInfo, RequestContext};
7use crate::error::AgentNetworkError;
8use crate::handler::McpHandler;
9use crate::mcp_transport::{ServerTransport, StdioServerTransport};
10use crate::middleware::{Middleware, MiddlewareChain};
11
12pub struct McpServer<H: McpHandler> {
14 handler: H,
15 middleware: MiddlewareChain,
16 transport: Box<dyn ServerTransport>,
17}
18
19impl<H: McpHandler> McpServer<H> {
20 pub fn new(handler: H) -> Self {
22 Self {
23 handler,
24 middleware: MiddlewareChain::new(),
25 transport: Box::new(StdioServerTransport::new()),
26 }
27 }
28
29 pub fn with_transport(mut self, transport: impl ServerTransport + 'static) -> Self {
31 self.transport = Box::new(transport);
32 self
33 }
34
35 pub fn with_middleware(mut self, mw: impl Middleware) -> Self {
37 self.middleware.add(mw);
38 self
39 }
40
41 pub async fn run(mut self) -> Result<()> {
43 let mut ctx = RequestContext::new(json!(null));
44 tracing::info!("MCP Relay server starting");
45
46 loop {
47 let line = match self.transport.read_request().await {
48 Ok(Some(line)) => line,
49 Ok(None) => {
50 tracing::debug!("Transport closed (EOF)");
51 break;
52 }
53 Err(e) => {
54 tracing::error!("Transport read error: {}", e);
55 break;
56 }
57 };
58
59 let request: JsonRpcRequest = match serde_json::from_str(&line) {
60 Ok(req) => req,
61 Err(e) => {
62 let error = AgentNetworkError::ParseError(e.to_string());
63 let response = JsonRpcResponse {
64 jsonrpc: "2.0".to_string(),
65 id: json!(null),
66 result: None,
67 error: Some(error.to_json_rpc_error()),
68 };
69 self.write_response(&response).await?;
70 continue;
71 }
72 };
73
74 ctx.request_id = request.id.clone();
75
76 if let Err(err) = self.middleware.process_request(&request, &mut ctx).await {
78 let response = JsonRpcResponse {
79 jsonrpc: "2.0".to_string(),
80 id: request.id.clone(),
81 result: None,
82 error: Some(err),
83 };
84 self.write_response(&response).await?;
85 continue;
86 }
87
88 let response = self.handle_request(&request, &mut ctx).await;
90
91 let mut response = response;
93 self.middleware.process_response(&mut response, &ctx).await;
94
95 self.write_response(&response).await?;
96 }
97
98 self.handler.on_shutdown().await?;
99 tracing::info!("MCP Relay server shut down");
100 Ok(())
101 }
102
103 async fn handle_request(
104 &self,
105 request: &JsonRpcRequest,
106 ctx: &mut RequestContext,
107 ) -> JsonRpcResponse {
108 match request.method.as_str() {
109 "initialize" => self.handle_initialize(request, ctx).await,
110 "notifications/initialized" => {
111 JsonRpcResponse {
113 jsonrpc: "2.0".to_string(),
114 id: request.id.clone(),
115 result: Some(json!({})),
116 error: None,
117 }
118 }
119 "tools/list" => self.handle_list_tools(request).await,
120 "tools/call" => self.handle_call_tool(request, ctx).await,
121 _ => {
122 let error = AgentNetworkError::MethodNotFound(request.method.clone());
123 JsonRpcResponse {
124 jsonrpc: "2.0".to_string(),
125 id: request.id.clone(),
126 result: None,
127 error: Some(error.to_json_rpc_error()),
128 }
129 }
130 }
131 }
132
133 async fn handle_initialize(
134 &self,
135 request: &JsonRpcRequest,
136 ctx: &mut RequestContext,
137 ) -> JsonRpcResponse {
138 let params: InitializeParams = match request
139 .params
140 .as_ref()
141 .and_then(|p| serde_json::from_value(p.clone()).ok())
142 {
143 Some(p) => p,
144 None => {
145 InitializeParams {
147 protocol_version: "2024-11-05".to_string(),
148 capabilities: Default::default(),
149 client_info: brainwires_mcp::ClientInfo {
150 name: "unknown".to_string(),
151 version: "0.6.0".to_string(),
152 },
153 }
154 }
155 };
156
157 ctx.client_info = Some(ClientInfo {
158 name: params.client_info.name.clone(),
159 version: params.client_info.version.clone(),
160 });
161 ctx.set_initialized();
162
163 if let Err(e) = self.handler.on_initialize(¶ms).await {
164 tracing::error!("Handler on_initialize failed: {}", e);
165 }
166
167 let info = self.handler.server_info();
168 let capabilities = self.handler.capabilities();
169
170 let result = InitializeResult {
171 protocol_version: "2024-11-05".to_string(),
172 capabilities,
173 server_info: info,
174 };
175
176 JsonRpcResponse {
177 jsonrpc: "2.0".to_string(),
178 id: request.id.clone(),
179 result: serde_json::to_value(result).ok(),
180 error: None,
181 }
182 }
183
184 async fn handle_list_tools(&self, request: &JsonRpcRequest) -> JsonRpcResponse {
185 let tool_defs = self.handler.list_tools();
186
187 let tools: Vec<Value> = tool_defs
188 .iter()
189 .map(|t| {
190 json!({
191 "name": t.name,
192 "description": t.description,
193 "inputSchema": t.input_schema,
194 })
195 })
196 .collect();
197
198 JsonRpcResponse {
199 jsonrpc: "2.0".to_string(),
200 id: request.id.clone(),
201 result: Some(json!({ "tools": tools })),
202 error: None,
203 }
204 }
205
206 async fn handle_call_tool(
207 &self,
208 request: &JsonRpcRequest,
209 ctx: &RequestContext,
210 ) -> JsonRpcResponse {
211 let params = match &request.params {
212 Some(p) => p,
213 None => {
214 let error =
215 AgentNetworkError::InvalidParams("Missing params for tools/call".to_string());
216 return JsonRpcResponse {
217 jsonrpc: "2.0".to_string(),
218 id: request.id.clone(),
219 result: None,
220 error: Some(error.to_json_rpc_error()),
221 };
222 }
223 };
224
225 let tool_name = match params.get("name").and_then(|n| n.as_str()) {
226 Some(name) => name,
227 None => {
228 let error =
229 AgentNetworkError::InvalidParams("Missing 'name' in tools/call".to_string());
230 return JsonRpcResponse {
231 jsonrpc: "2.0".to_string(),
232 id: request.id.clone(),
233 result: None,
234 error: Some(error.to_json_rpc_error()),
235 };
236 }
237 };
238
239 let args = params.get("arguments").cloned().unwrap_or(json!({}));
240
241 match self.handler.call_tool(tool_name, args, ctx).await {
242 Ok(result) => {
243 let result_value = serde_json::to_value(result).unwrap_or(json!({}));
244 JsonRpcResponse {
245 jsonrpc: "2.0".to_string(),
246 id: request.id.clone(),
247 result: Some(result_value),
248 error: None,
249 }
250 }
251 Err(e) => {
252 let error = AgentNetworkError::Internal(e);
253 JsonRpcResponse {
254 jsonrpc: "2.0".to_string(),
255 id: request.id.clone(),
256 result: None,
257 error: Some(error.to_json_rpc_error()),
258 }
259 }
260 }
261 }
262
263 async fn write_response(&mut self, response: &JsonRpcResponse) -> Result<()> {
264 let json = serde_json::to_string(response)?;
265 self.transport.write_response(&json).await
266 }
267}