Skip to main content

browser_control/mcp/
server.rs

1//! Minimal hand-rolled MCP JSON-RPC server over stdio.
2//!
3//! This is the wave-3 skeleton. A future task may replace this with a more
4//! capable framework (e.g. `rmcp`). The protocol surface is small:
5//! newline-delimited JSON-RPC 2.0 over stdin/stdout.
6
7use anyhow::Result;
8use serde_json::{json, Value};
9use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
10
11use crate::cli::env_resolver::ResolvedBrowser;
12
13/// Persistent BiDi client, opened lazily on first use. Reused across all
14/// tool calls because Firefox limits concurrent BiDi sessions per browser
15/// to one. The browsing-context id is resolved per call so multiple tabs
16/// can be addressed once URL-regex selection is added.
17pub type BidiCache =
18    std::sync::Arc<tokio::sync::Mutex<Option<std::sync::Arc<crate::bidi::BidiClient>>>>;
19
20/// State carried by the server. Tools reach into this for the resolved
21/// browser endpoint and any cached engine clients.
22#[derive(Clone)]
23pub struct ServerState {
24    pub browser: ResolvedBrowser,
25    pub bidi: BidiCache,
26}
27
28impl ServerState {
29    pub fn new(browser: ResolvedBrowser) -> Self {
30        Self {
31            browser,
32            bidi: std::sync::Arc::new(tokio::sync::Mutex::new(None)),
33        }
34    }
35}
36
37impl std::fmt::Debug for ServerState {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        f.debug_struct("ServerState")
40            .field("browser", &self.browser)
41            .finish()
42    }
43}
44
45/// Handler signature: takes `(state, params)` and returns a tool result.
46pub type ToolHandler = std::sync::Arc<
47    dyn Fn(ServerState, Value) -> futures_util::future::BoxFuture<'static, Result<Value>>
48        + Send
49        + Sync,
50>;
51
52pub struct RegisteredTool {
53    pub name: String,
54    pub description: String,
55    pub input_schema: Value,
56    pub handler: ToolHandler,
57}
58
59#[derive(Clone, Default)]
60pub struct ToolRegistry {
61    inner: std::sync::Arc<std::sync::Mutex<Vec<RegisteredTool>>>,
62}
63
64impl ToolRegistry {
65    pub fn new() -> Self {
66        Self::default()
67    }
68
69    pub fn register(&self, t: RegisteredTool) {
70        self.inner.lock().unwrap().push(t);
71    }
72
73    pub fn list(&self) -> Vec<Value> {
74        self.inner
75            .lock()
76            .unwrap()
77            .iter()
78            .map(|t| {
79                json!({
80                    "name": t.name,
81                    "description": t.description,
82                    "inputSchema": t.input_schema,
83                })
84            })
85            .collect()
86    }
87
88    pub fn handler(&self, name: &str) -> Option<ToolHandler> {
89        self.inner
90            .lock()
91            .unwrap()
92            .iter()
93            .find(|t| t.name == name)
94            .map(|t| t.handler.clone())
95    }
96}
97
98/// Run the server using the real stdin/stdout.
99pub async fn run(state: ServerState, tools: ToolRegistry) -> Result<()> {
100    run_with_streams(state, tools, tokio::io::stdin(), tokio::io::stdout()).await
101}
102
103/// Run the server with injected I/O streams (used by tests).
104pub async fn run_with_streams<I, O>(
105    state: ServerState,
106    tools: ToolRegistry,
107    stdin: I,
108    mut stdout: O,
109) -> Result<()>
110where
111    I: tokio::io::AsyncRead + Unpin,
112    O: tokio::io::AsyncWrite + Unpin,
113{
114    let mut lines = BufReader::new(stdin).lines();
115    while let Some(line) = lines.next_line().await? {
116        if line.trim().is_empty() {
117            continue;
118        }
119        let req: Value = match serde_json::from_str(&line) {
120            Ok(v) => v,
121            Err(e) => {
122                write_error(
123                    &mut stdout,
124                    Value::Null,
125                    -32700,
126                    &format!("parse error: {e}"),
127                )
128                .await?;
129                continue;
130            }
131        };
132        let id = req.get("id").cloned().unwrap_or(Value::Null);
133        let method = req.get("method").and_then(|m| m.as_str()).unwrap_or("");
134        let params = req.get("params").cloned().unwrap_or(Value::Null);
135
136        // Notifications: no id, no response.
137        if id.is_null() && method == "notifications/initialized" {
138            continue;
139        }
140
141        let result = match method {
142            "initialize" => Ok(json!({
143                "protocolVersion": "2024-11-05",
144                "capabilities": {"tools": {}},
145                "serverInfo": {
146                    "name": "browser-control",
147                    "version": env!("CARGO_PKG_VERSION"),
148                },
149            })),
150            "ping" => Ok(json!({})),
151            "tools/list" => Ok(json!({"tools": tools.list()})),
152            "tools/call" => {
153                let name = params.get("name").and_then(|v| v.as_str()).unwrap_or("");
154                let args = params.get("arguments").cloned().unwrap_or(Value::Null);
155                match tools.handler(name) {
156                    Some(h) => h(state.clone(), args).await,
157                    None => Err(anyhow::anyhow!("tool not found: {name}")),
158                }
159            }
160            _ => {
161                write_error(
162                    &mut stdout,
163                    id,
164                    -32601,
165                    &format!("method not found: {method}"),
166                )
167                .await?;
168                continue;
169            }
170        };
171
172        match result {
173            Ok(v) => write_result(&mut stdout, id, v).await?,
174            Err(e) => write_error(&mut stdout, id, -32000, &e.to_string()).await?,
175        }
176    }
177    Ok(())
178}
179
180async fn write_result<O: tokio::io::AsyncWrite + Unpin>(
181    out: &mut O,
182    id: Value,
183    result: Value,
184) -> Result<()> {
185    let resp = json!({"jsonrpc": "2.0", "id": id, "result": result});
186    let mut s = serde_json::to_vec(&resp)?;
187    s.push(b'\n');
188    out.write_all(&s).await?;
189    out.flush().await?;
190    Ok(())
191}
192
193async fn write_error<O: tokio::io::AsyncWrite + Unpin>(
194    out: &mut O,
195    id: Value,
196    code: i64,
197    message: &str,
198) -> Result<()> {
199    let resp = json!({
200        "jsonrpc": "2.0",
201        "id": id,
202        "error": {"code": code, "message": message},
203    });
204    let mut s = serde_json::to_vec(&resp)?;
205    s.push(b'\n');
206    out.write_all(&s).await?;
207    out.flush().await?;
208    Ok(())
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cli::env_resolver::Source;
    use crate::detect::Engine;
    use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};

    fn dummy_resolved() -> ResolvedBrowser {
        ResolvedBrowser {
            endpoint: "ws://localhost:9999".into(),
            engine: Engine::Cdp,
            source: Source::External,
        }
    }

    fn dummy_state() -> ServerState {
        ServerState::new(dummy_resolved())
    }

    /// Feed `requests` to an in-process server and collect every response
    /// line it writes before shutting down.
    async fn send_recv(tools: ToolRegistry, requests: &[Value]) -> Vec<Value> {
        let (mut client_w, server_r) = tokio::io::duplex(8192);
        let (server_w, client_r) = tokio::io::duplex(8192);
        let state = dummy_state();
        let server =
            tokio::spawn(async move { run_with_streams(state, tools, server_r, server_w).await });

        let mut payload = Vec::new();
        for req in requests {
            serde_json::to_writer(&mut payload, req).unwrap();
            payload.push(b'\n');
        }
        // Write from a separate task while this one drains responses. Both
        // pipes are bounded, so writing every request before reading any
        // response deadlocks once the responses fill the return pipe.
        let writer = tokio::spawn(async move {
            client_w.write_all(&payload).await.unwrap();
            // `client_w` is dropped here; the EOF ends the server loop after
            // it drains.
        });

        let mut reader = BufReader::new(client_r);
        let mut responses = Vec::new();
        let mut line = String::new();
        while reader.read_line(&mut line).await.unwrap() != 0 {
            responses.push(serde_json::from_str(line.trim()).unwrap());
            line.clear();
        }
        writer.await.unwrap();
        // Surface server-side I/O errors instead of silently discarding them.
        server.await.unwrap().unwrap();
        responses
    }

    fn echo_tool() -> RegisteredTool {
        RegisteredTool {
            name: "echo".to_string(),
            description: "Echo arguments back".to_string(),
            input_schema: json!({"type": "object"}),
            handler: std::sync::Arc::new(|_state, args| {
                Box::pin(async move { Ok(json!({"echoed": args})) })
            }),
        }
    }

    #[tokio::test]
    async fn initialize_round_trip() {
        let resp = send_recv(
            ToolRegistry::new(),
            &[json!({"jsonrpc":"2.0","id":1,"method":"initialize","params":{}})],
        )
        .await;
        assert_eq!(resp.len(), 1);
        assert_eq!(resp[0]["id"], 1);
        assert_eq!(resp[0]["result"]["protocolVersion"], "2024-11-05");
        assert_eq!(resp[0]["result"]["serverInfo"]["name"], "browser-control");
    }

    #[tokio::test]
    async fn tools_list_empty() {
        let resp = send_recv(
            ToolRegistry::new(),
            &[json!({"jsonrpc":"2.0","id":2,"method":"tools/list"})],
        )
        .await;
        assert_eq!(resp[0]["result"]["tools"], json!([]));
    }

    #[tokio::test]
    async fn tools_list_after_register() {
        let tools = ToolRegistry::new();
        tools.register(echo_tool());
        let resp = send_recv(
            tools,
            &[json!({"jsonrpc":"2.0","id":3,"method":"tools/list"})],
        )
        .await;
        let list = resp[0]["result"]["tools"].as_array().unwrap();
        assert_eq!(list.len(), 1);
        assert_eq!(list[0]["name"], "echo");
    }

    #[tokio::test]
    async fn tools_call_unknown_errors() {
        let resp = send_recv(
            ToolRegistry::new(),
            &[json!({
                "jsonrpc":"2.0","id":4,"method":"tools/call",
                "params":{"name":"nope","arguments":{}}
            })],
        )
        .await;
        assert!(resp[0]["error"].is_object());
        assert!(resp[0]["error"]["message"]
            .as_str()
            .unwrap()
            .contains("nope"));
    }

    #[tokio::test]
    async fn tools_call_registered_returns_result() {
        let tools = ToolRegistry::new();
        tools.register(echo_tool());
        let resp = send_recv(
            tools,
            &[json!({
                "jsonrpc":"2.0","id":5,"method":"tools/call",
                "params":{"name":"echo","arguments":{"hello":"world"}}
            })],
        )
        .await;
        assert_eq!(resp[0]["result"]["echoed"], json!({"hello":"world"}));
    }

    #[tokio::test]
    async fn unknown_method_returns_minus_32601() {
        let resp = send_recv(
            ToolRegistry::new(),
            &[json!({"jsonrpc":"2.0","id":6,"method":"bogus"})],
        )
        .await;
        assert_eq!(resp[0]["error"]["code"], -32601);
    }

    #[tokio::test]
    async fn ping_returns_empty_object() {
        let resp = send_recv(
            ToolRegistry::new(),
            &[json!({"jsonrpc":"2.0","id":7,"method":"ping"})],
        )
        .await;
        assert_eq!(resp[0]["result"], json!({}));
    }

    #[tokio::test]
    async fn initialized_notification_is_silently_ignored() {
        // Send notification, then a real request; we should only see the
        // response to the real request.
        let resp = send_recv(
            ToolRegistry::new(),
            &[
                json!({"jsonrpc":"2.0","method":"notifications/initialized"}),
                json!({"jsonrpc":"2.0","id":8,"method":"ping"}),
            ],
        )
        .await;
        assert_eq!(resp.len(), 1);
        assert_eq!(resp[0]["id"], 8);
    }
}