Skip to main content

browser_control/mcp/
server.rs

//! Minimal hand-rolled MCP JSON-RPC server over stdio.
//!
//! This is the wave-3 skeleton; a future task may replace it with a more
//! capable framework (e.g. `rmcp`). The protocol surface is small:
//! newline-delimited JSON-RPC 2.0 messages over stdin/stdout, handling
//! `initialize`, `ping`, `tools/list`, and `tools/call`.
6
7use anyhow::Result;
8use serde_json::{json, Value};
9use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
10
11use crate::cli::env_resolver::ResolvedBrowser;
12
13/// Persistent BiDi session, opened lazily on first use. Reused across all
14/// tool calls because Firefox limits concurrent BiDi sessions per browser
15/// to one.
16pub type BidiCache =
17    std::sync::Arc<tokio::sync::Mutex<Option<(std::sync::Arc<crate::bidi::BidiClient>, String)>>>;
18
19/// State carried by the server. Tools reach into this for the resolved
20/// browser endpoint and any cached engine clients.
21#[derive(Clone)]
22pub struct ServerState {
23    pub browser: ResolvedBrowser,
24    pub bidi: BidiCache,
25}
26
27impl ServerState {
28    pub fn new(browser: ResolvedBrowser) -> Self {
29        Self {
30            browser,
31            bidi: std::sync::Arc::new(tokio::sync::Mutex::new(None)),
32        }
33    }
34}
35
36impl std::fmt::Debug for ServerState {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        f.debug_struct("ServerState")
39            .field("browser", &self.browser)
40            .finish()
41    }
42}
43
44/// Handler signature: takes `(state, params)` and returns a tool result.
45pub type ToolHandler = std::sync::Arc<
46    dyn Fn(ServerState, Value) -> futures_util::future::BoxFuture<'static, Result<Value>>
47        + Send
48        + Sync,
49>;
50
51pub struct RegisteredTool {
52    pub name: String,
53    pub description: String,
54    pub input_schema: Value,
55    pub handler: ToolHandler,
56}
57
58#[derive(Clone, Default)]
59pub struct ToolRegistry {
60    inner: std::sync::Arc<std::sync::Mutex<Vec<RegisteredTool>>>,
61}
62
63impl ToolRegistry {
64    pub fn new() -> Self {
65        Self::default()
66    }
67
68    pub fn register(&self, t: RegisteredTool) {
69        self.inner.lock().unwrap().push(t);
70    }
71
72    pub fn list(&self) -> Vec<Value> {
73        self.inner
74            .lock()
75            .unwrap()
76            .iter()
77            .map(|t| {
78                json!({
79                    "name": t.name,
80                    "description": t.description,
81                    "inputSchema": t.input_schema,
82                })
83            })
84            .collect()
85    }
86
87    pub fn handler(&self, name: &str) -> Option<ToolHandler> {
88        self.inner
89            .lock()
90            .unwrap()
91            .iter()
92            .find(|t| t.name == name)
93            .map(|t| t.handler.clone())
94    }
95}
96
97/// Run the server using the real stdin/stdout.
98pub async fn run(state: ServerState, tools: ToolRegistry) -> Result<()> {
99    run_with_streams(state, tools, tokio::io::stdin(), tokio::io::stdout()).await
100}
101
102/// Run the server with injected I/O streams (used by tests).
103pub async fn run_with_streams<I, O>(
104    state: ServerState,
105    tools: ToolRegistry,
106    stdin: I,
107    mut stdout: O,
108) -> Result<()>
109where
110    I: tokio::io::AsyncRead + Unpin,
111    O: tokio::io::AsyncWrite + Unpin,
112{
113    let mut lines = BufReader::new(stdin).lines();
114    while let Some(line) = lines.next_line().await? {
115        if line.trim().is_empty() {
116            continue;
117        }
118        let req: Value = match serde_json::from_str(&line) {
119            Ok(v) => v,
120            Err(e) => {
121                write_error(
122                    &mut stdout,
123                    Value::Null,
124                    -32700,
125                    &format!("parse error: {e}"),
126                )
127                .await?;
128                continue;
129            }
130        };
131        let id = req.get("id").cloned().unwrap_or(Value::Null);
132        let method = req.get("method").and_then(|m| m.as_str()).unwrap_or("");
133        let params = req.get("params").cloned().unwrap_or(Value::Null);
134
135        // Notifications: no id, no response.
136        if id.is_null() && method == "notifications/initialized" {
137            continue;
138        }
139
140        let result = match method {
141            "initialize" => Ok(json!({
142                "protocolVersion": "2024-11-05",
143                "capabilities": {"tools": {}},
144                "serverInfo": {
145                    "name": "browser-control",
146                    "version": env!("CARGO_PKG_VERSION"),
147                },
148            })),
149            "ping" => Ok(json!({})),
150            "tools/list" => Ok(json!({"tools": tools.list()})),
151            "tools/call" => {
152                let name = params.get("name").and_then(|v| v.as_str()).unwrap_or("");
153                let args = params.get("arguments").cloned().unwrap_or(Value::Null);
154                match tools.handler(name) {
155                    Some(h) => h(state.clone(), args).await,
156                    None => Err(anyhow::anyhow!("tool not found: {name}")),
157                }
158            }
159            _ => {
160                write_error(
161                    &mut stdout,
162                    id,
163                    -32601,
164                    &format!("method not found: {method}"),
165                )
166                .await?;
167                continue;
168            }
169        };
170
171        match result {
172            Ok(v) => write_result(&mut stdout, id, v).await?,
173            Err(e) => write_error(&mut stdout, id, -32000, &e.to_string()).await?,
174        }
175    }
176    Ok(())
177}
178
179async fn write_result<O: tokio::io::AsyncWrite + Unpin>(
180    out: &mut O,
181    id: Value,
182    result: Value,
183) -> Result<()> {
184    let resp = json!({"jsonrpc": "2.0", "id": id, "result": result});
185    let mut s = serde_json::to_vec(&resp)?;
186    s.push(b'\n');
187    out.write_all(&s).await?;
188    out.flush().await?;
189    Ok(())
190}
191
192async fn write_error<O: tokio::io::AsyncWrite + Unpin>(
193    out: &mut O,
194    id: Value,
195    code: i64,
196    message: &str,
197) -> Result<()> {
198    let resp = json!({
199        "jsonrpc": "2.0",
200        "id": id,
201        "error": {"code": code, "message": message},
202    });
203    let mut s = serde_json::to_vec(&resp)?;
204    s.push(b'\n');
205    out.write_all(&s).await?;
206    out.flush().await?;
207    Ok(())
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cli::env_resolver::Source;
    use crate::detect::Engine;
    use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};

    fn dummy_resolved() -> ResolvedBrowser {
        ResolvedBrowser {
            endpoint: "ws://localhost:9999".into(),
            engine: Engine::Cdp,
            source: Source::External,
        }
    }

    fn dummy_state() -> ServerState {
        ServerState::new(dummy_resolved())
    }

    /// Run a server over in-memory pipes and return every response it writes
    /// before exiting.
    ///
    /// Each entry of `lines` is fed verbatim, followed by `\n`; the server's
    /// input is then closed. Input is written from a separate task while this
    /// task drains output, so neither 8 KiB duplex pipe can fill up and
    /// deadlock the exchange.
    async fn exchange(tools: ToolRegistry, lines: Vec<String>) -> Vec<Value> {
        let (mut client_w, server_r) = tokio::io::duplex(8192);
        let (server_w, client_r) = tokio::io::duplex(8192);
        let state = dummy_state();
        let server = tokio::spawn(async move {
            let _ = run_with_streams(state, tools, server_r, server_w).await;
        });
        let feeder = tokio::spawn(async move {
            for line in lines {
                let mut frame = line.into_bytes();
                frame.push(b'\n');
                client_w.write_all(&frame).await.unwrap();
            }
            // `client_w` is dropped here, which ends the server loop after it
            // drains the remaining input.
        });

        let mut reader = BufReader::new(client_r);
        let mut responses = Vec::new();
        let mut line = String::new();
        while reader.read_line(&mut line).await.unwrap() != 0 {
            responses.push(serde_json::from_str(line.trim()).unwrap());
            line.clear();
        }
        feeder.await.unwrap();
        let _ = server.await;
        responses
    }

    /// Serialize each request onto its own line and run `exchange`.
    async fn send_recv(tools: ToolRegistry, requests: &[Value]) -> Vec<Value> {
        let lines = requests.iter().map(Value::to_string).collect();
        exchange(tools, lines).await
    }

    fn echo_tool() -> RegisteredTool {
        RegisteredTool {
            name: "echo".to_string(),
            description: "Echo arguments back".to_string(),
            input_schema: json!({"type": "object"}),
            handler: std::sync::Arc::new(|_state, args| {
                Box::pin(async move { Ok(json!({"echoed": args})) })
            }),
        }
    }

    #[tokio::test]
    async fn initialize_round_trip() {
        let resp = send_recv(
            ToolRegistry::new(),
            &[json!({"jsonrpc":"2.0","id":1,"method":"initialize","params":{}})],
        )
        .await;
        assert_eq!(resp.len(), 1);
        assert_eq!(resp[0]["id"], 1);
        assert_eq!(resp[0]["result"]["protocolVersion"], "2024-11-05");
        assert_eq!(resp[0]["result"]["serverInfo"]["name"], "browser-control");
    }

    #[tokio::test]
    async fn tools_list_empty() {
        let resp = send_recv(
            ToolRegistry::new(),
            &[json!({"jsonrpc":"2.0","id":2,"method":"tools/list"})],
        )
        .await;
        assert_eq!(resp[0]["result"]["tools"], json!([]));
    }

    #[tokio::test]
    async fn tools_list_after_register() {
        let tools = ToolRegistry::new();
        tools.register(echo_tool());
        let resp = send_recv(
            tools,
            &[json!({"jsonrpc":"2.0","id":3,"method":"tools/list"})],
        )
        .await;
        let list = resp[0]["result"]["tools"].as_array().unwrap();
        assert_eq!(list.len(), 1);
        assert_eq!(list[0]["name"], "echo");
    }

    #[tokio::test]
    async fn tools_call_unknown_errors() {
        let resp = send_recv(
            ToolRegistry::new(),
            &[json!({
                "jsonrpc":"2.0","id":4,"method":"tools/call",
                "params":{"name":"nope","arguments":{}}
            })],
        )
        .await;
        assert!(resp[0]["error"].is_object());
        assert!(resp[0]["error"]["message"]
            .as_str()
            .unwrap()
            .contains("nope"));
    }

    #[tokio::test]
    async fn tools_call_registered_returns_result() {
        let tools = ToolRegistry::new();
        tools.register(echo_tool());
        let resp = send_recv(
            tools,
            &[json!({
                "jsonrpc":"2.0","id":5,"method":"tools/call",
                "params":{"name":"echo","arguments":{"hello":"world"}}
            })],
        )
        .await;
        assert_eq!(resp[0]["result"]["echoed"], json!({"hello":"world"}));
    }

    #[tokio::test]
    async fn unknown_method_returns_minus_32601() {
        let resp = send_recv(
            ToolRegistry::new(),
            &[json!({"jsonrpc":"2.0","id":6,"method":"bogus"})],
        )
        .await;
        assert_eq!(resp[0]["error"]["code"], -32601);
    }

    #[tokio::test]
    async fn ping_returns_empty_object() {
        let resp = send_recv(
            ToolRegistry::new(),
            &[json!({"jsonrpc":"2.0","id":7,"method":"ping"})],
        )
        .await;
        assert_eq!(resp[0]["result"], json!({}));
    }

    #[tokio::test]
    async fn initialized_notification_is_silently_ignored() {
        // Send a notification, then a real request; only the response to the
        // real request should come back.
        let resp = send_recv(
            ToolRegistry::new(),
            &[
                json!({"jsonrpc":"2.0","method":"notifications/initialized"}),
                json!({"jsonrpc":"2.0","id":8,"method":"ping"}),
            ],
        )
        .await;
        assert_eq!(resp.len(), 1);
        assert_eq!(resp[0]["id"], 8);
    }

    #[tokio::test]
    async fn malformed_json_returns_parse_error_with_null_id() {
        let resp = exchange(ToolRegistry::new(), vec!["{not json".to_string()]).await;
        assert_eq!(resp.len(), 1);
        assert_eq!(resp[0]["id"], Value::Null);
        assert_eq!(resp[0]["error"]["code"], -32700);
    }

    #[tokio::test]
    async fn blank_lines_are_skipped() {
        let resp = exchange(
            ToolRegistry::new(),
            vec![
                String::new(),
                "   ".to_string(),
                r#"{"jsonrpc":"2.0","id":9,"method":"ping"}"#.to_string(),
            ],
        )
        .await;
        assert_eq!(resp.len(), 1);
        assert_eq!(resp[0]["id"], 9);
    }
}