Skip to main content

brainwires_agent_network/
server.rs

1use anyhow::Result;
2use brainwires_mcp::{InitializeParams, InitializeResult, JsonRpcRequest, JsonRpcResponse};
3use serde_json::{Value, json};
4use tracing;
5
6use crate::connection::{ClientInfo, RequestContext};
7use crate::error::AgentNetworkError;
8use crate::handler::McpHandler;
9use crate::mcp_transport::{ServerTransport, StdioServerTransport};
10use crate::middleware::{Middleware, MiddlewareChain};
11
12/// MCP server that processes JSON-RPC requests via a transport.
13pub struct McpServer<H: McpHandler> {
14    handler: H,
15    middleware: MiddlewareChain,
16    transport: Box<dyn ServerTransport>,
17}
18
19impl<H: McpHandler> McpServer<H> {
20    /// Create a new server with the given handler and stdio transport.
21    pub fn new(handler: H) -> Self {
22        Self {
23            handler,
24            middleware: MiddlewareChain::new(),
25            transport: Box::new(StdioServerTransport::new()),
26        }
27    }
28
29    /// Set a custom transport.
30    pub fn with_transport(mut self, transport: impl ServerTransport + 'static) -> Self {
31        self.transport = Box::new(transport);
32        self
33    }
34
35    /// Add a middleware to the processing pipeline.
36    pub fn with_middleware(mut self, mw: impl Middleware) -> Self {
37        self.middleware.add(mw);
38        self
39    }
40
41    /// Run the server event loop until the transport closes.
42    pub async fn run(mut self) -> Result<()> {
43        let mut ctx = RequestContext::new(json!(null));
44        tracing::info!("MCP Relay server starting");
45
46        loop {
47            let line = match self.transport.read_request().await {
48                Ok(Some(line)) => line,
49                Ok(None) => {
50                    tracing::debug!("Transport closed (EOF)");
51                    break;
52                }
53                Err(e) => {
54                    tracing::error!("Transport read error: {}", e);
55                    break;
56                }
57            };
58
59            let request: JsonRpcRequest = match serde_json::from_str(&line) {
60                Ok(req) => req,
61                Err(e) => {
62                    let error = AgentNetworkError::ParseError(e.to_string());
63                    let response = JsonRpcResponse {
64                        jsonrpc: "2.0".to_string(),
65                        id: json!(null),
66                        result: None,
67                        error: Some(error.to_json_rpc_error()),
68                    };
69                    self.write_response(&response).await?;
70                    continue;
71                }
72            };
73
74            ctx.request_id = request.id.clone();
75
76            // Run middleware chain
77            if let Err(err) = self.middleware.process_request(&request, &mut ctx).await {
78                let response = JsonRpcResponse {
79                    jsonrpc: "2.0".to_string(),
80                    id: request.id.clone(),
81                    result: None,
82                    error: Some(err),
83                };
84                self.write_response(&response).await?;
85                continue;
86            }
87
88            // Dispatch to handler
89            let response = self.handle_request(&request, &mut ctx).await;
90
91            // Run response middleware
92            let mut response = response;
93            self.middleware.process_response(&mut response, &ctx).await;
94
95            self.write_response(&response).await?;
96        }
97
98        self.handler.on_shutdown().await?;
99        tracing::info!("MCP Relay server shut down");
100        Ok(())
101    }
102
103    async fn handle_request(
104        &self,
105        request: &JsonRpcRequest,
106        ctx: &mut RequestContext,
107    ) -> JsonRpcResponse {
108        match request.method.as_str() {
109            "initialize" => self.handle_initialize(request, ctx).await,
110            "notifications/initialized" => {
111                // Client confirming initialization - no response needed but we return success
112                JsonRpcResponse {
113                    jsonrpc: "2.0".to_string(),
114                    id: request.id.clone(),
115                    result: Some(json!({})),
116                    error: None,
117                }
118            }
119            "tools/list" => self.handle_list_tools(request).await,
120            "tools/call" => self.handle_call_tool(request, ctx).await,
121            _ => {
122                let error = AgentNetworkError::MethodNotFound(request.method.clone());
123                JsonRpcResponse {
124                    jsonrpc: "2.0".to_string(),
125                    id: request.id.clone(),
126                    result: None,
127                    error: Some(error.to_json_rpc_error()),
128                }
129            }
130        }
131    }
132
133    async fn handle_initialize(
134        &self,
135        request: &JsonRpcRequest,
136        ctx: &mut RequestContext,
137    ) -> JsonRpcResponse {
138        let params: InitializeParams = match request
139            .params
140            .as_ref()
141            .and_then(|p| serde_json::from_value(p.clone()).ok())
142        {
143            Some(p) => p,
144            None => {
145                // Allow initialize without params for compatibility
146                InitializeParams {
147                    protocol_version: "2024-11-05".to_string(),
148                    capabilities: Default::default(),
149                    client_info: brainwires_mcp::ClientInfo {
150                        name: "unknown".to_string(),
151                        version: "0.6.0".to_string(),
152                    },
153                }
154            }
155        };
156
157        ctx.client_info = Some(ClientInfo {
158            name: params.client_info.name.clone(),
159            version: params.client_info.version.clone(),
160        });
161        ctx.set_initialized();
162
163        if let Err(e) = self.handler.on_initialize(&params).await {
164            tracing::error!("Handler on_initialize failed: {}", e);
165        }
166
167        let info = self.handler.server_info();
168        let capabilities = self.handler.capabilities();
169
170        let result = InitializeResult {
171            protocol_version: "2024-11-05".to_string(),
172            capabilities,
173            server_info: info,
174        };
175
176        JsonRpcResponse {
177            jsonrpc: "2.0".to_string(),
178            id: request.id.clone(),
179            result: serde_json::to_value(result).ok(),
180            error: None,
181        }
182    }
183
184    async fn handle_list_tools(&self, request: &JsonRpcRequest) -> JsonRpcResponse {
185        let tool_defs = self.handler.list_tools();
186
187        let tools: Vec<Value> = tool_defs
188            .iter()
189            .map(|t| {
190                json!({
191                    "name": t.name,
192                    "description": t.description,
193                    "inputSchema": t.input_schema,
194                })
195            })
196            .collect();
197
198        JsonRpcResponse {
199            jsonrpc: "2.0".to_string(),
200            id: request.id.clone(),
201            result: Some(json!({ "tools": tools })),
202            error: None,
203        }
204    }
205
206    async fn handle_call_tool(
207        &self,
208        request: &JsonRpcRequest,
209        ctx: &RequestContext,
210    ) -> JsonRpcResponse {
211        let params = match &request.params {
212            Some(p) => p,
213            None => {
214                let error =
215                    AgentNetworkError::InvalidParams("Missing params for tools/call".to_string());
216                return JsonRpcResponse {
217                    jsonrpc: "2.0".to_string(),
218                    id: request.id.clone(),
219                    result: None,
220                    error: Some(error.to_json_rpc_error()),
221                };
222            }
223        };
224
225        let tool_name = match params.get("name").and_then(|n| n.as_str()) {
226            Some(name) => name,
227            None => {
228                let error =
229                    AgentNetworkError::InvalidParams("Missing 'name' in tools/call".to_string());
230                return JsonRpcResponse {
231                    jsonrpc: "2.0".to_string(),
232                    id: request.id.clone(),
233                    result: None,
234                    error: Some(error.to_json_rpc_error()),
235                };
236            }
237        };
238
239        let args = params.get("arguments").cloned().unwrap_or(json!({}));
240
241        match self.handler.call_tool(tool_name, args, ctx).await {
242            Ok(result) => {
243                let result_value = serde_json::to_value(result).unwrap_or(json!({}));
244                JsonRpcResponse {
245                    jsonrpc: "2.0".to_string(),
246                    id: request.id.clone(),
247                    result: Some(result_value),
248                    error: None,
249                }
250            }
251            Err(e) => {
252                let error = AgentNetworkError::Internal(e);
253                JsonRpcResponse {
254                    jsonrpc: "2.0".to_string(),
255                    id: request.id.clone(),
256                    result: None,
257                    error: Some(error.to_json_rpc_error()),
258                }
259            }
260        }
261    }
262
263    async fn write_response(&mut self, response: &JsonRpcResponse) -> Result<()> {
264        let json = serde_json::to_string(response)?;
265        self.transport.write_response(&json).await
266    }
267}