Skip to main content

victauri_plugin/mcp/
mod.rs

1mod backend_params;
2mod compound_params;
3mod helpers;
4mod other_params;
5mod server;
6mod verification_params;
7mod webview_params;
8mod window_params;
9
10use std::collections::HashSet;
11use std::sync::Arc;
12use std::sync::atomic::{AtomicBool, Ordering};
13
14use rmcp::handler::server::tool::ToolCallContext;
15use rmcp::handler::server::wrapper::Parameters;
16use rmcp::model::{
17    AnnotateAble, CallToolRequestParams, CallToolResult, Content, ListResourcesResult,
18    ListToolsResult, PaginatedRequestParams, RawResource, ReadResourceRequestParams,
19    ReadResourceResult, ResourceContents, ServerCapabilities, ServerInfo, SubscribeRequestParams,
20    Tool, UnsubscribeRequestParams,
21};
22use rmcp::service::RequestContext;
23use rmcp::{ErrorData, RoleServer, ServerHandler, tool, tool_router};
24use tokio::sync::Mutex;
25
26use crate::VictauriState;
27use crate::bridge::WebviewBridge;
28
29use helpers::{
30    js_string, json_result, missing_param, sanitize_css_color, tool_disabled, tool_error,
31    validate_url,
32};
33
34pub use backend_params::*;
35pub use compound_params::*;
36pub use other_params::{
37    FindElementsParams, ResolveCommandParams, SemanticAssertParams, WaitCondition, WaitForParams,
38};
39pub use server::*;
40pub use verification_params::*;
41pub use webview_params::*;
42pub use window_params::*;
43
44// ── MCP Handler ──────────────────────────────────────────────────────────────
45
/// Maximum number of in-flight JavaScript eval requests. Prevents unbounded
/// growth of the `pending_evals` map if callbacks are never resolved.
pub(crate) const MAX_PENDING_EVALS: usize = 100;

/// URI naming the IPC-log resource.
// NOTE(review): presumably advertised through the MCP resource handlers
// defined later in this file — confirm against `list_resources`/`read_resource`.
const RESOURCE_URI_IPC_LOG: &str = "victauri://ipc-log";
/// URI naming the window-state resource.
const RESOURCE_URI_WINDOWS: &str = "victauri://windows";
/// URI naming the plugin-state resource.
const RESOURCE_URI_STATE: &str = "victauri://state";

/// Bridge version string, reported as `bridge_version` by `get_plugin_info`.
// NOTE(review): presumably must match the version of the injected JS bridge —
// confirm where this is compared.
const BRIDGE_VERSION: &str = "0.3.0";
55
/// MCP tool handler that dispatches tool calls to the webview bridge and state.
///
/// The handler is `Clone`; every field is behind an `Arc`, so clones share the
/// same underlying state, bridge, subscription set and flag.
#[derive(Clone)]
pub struct VictauriMcpHandler {
    /// Shared plugin state. The tools in this file read its privacy settings,
    /// command registry, port, event log, eval timeout, invocation counter and
    /// start time.
    state: Arc<VictauriState>,
    /// Abstraction over the running webview/windows; used for native window
    /// handles, window state queries and window management.
    bridge: Arc<dyn WebviewBridge>,
    /// Set of strings guarded by an async mutex.
    // NOTE(review): presumably the resource URIs a client has subscribed to —
    // confirm against the subscribe/unsubscribe handlers.
    subscriptions: Arc<Mutex<HashSet<String>>>,
    /// Shared atomic flag.
    // NOTE(review): presumably records that the bridge version check has
    // already run — its use is not visible in this part of the file; confirm.
    bridge_checked: Arc<AtomicBool>,
}
64
65#[tool_router]
66impl VictauriMcpHandler {
67    // ── Standalone Tools ────────────────────────────────────────────────────
68
69    #[tool(
70        description = "Evaluate JavaScript in the Tauri webview and return the result. Async expressions are wrapped automatically.",
71        annotations(
72            read_only_hint = false,
73            destructive_hint = true,
74            idempotent_hint = false,
75            open_world_hint = false
76        )
77    )]
78    async fn eval_js(&self, Parameters(params): Parameters<EvalJsParams>) -> CallToolResult {
79        if !self.state.privacy.is_tool_enabled("eval_js") {
80            return tool_disabled("eval_js");
81        }
82        match self
83            .eval_with_return(&params.code, params.webview_label.as_deref())
84            .await
85        {
86            Ok(result) => self.redact_result(result),
87            Err(e) => tool_error(e),
88        }
89    }
90
91    #[tool(
92        description = "Get the DOM snapshot with stable ref handles. Default: compact accessible text (70-80%% fewer tokens). Set format=\"json\" for full tree. Returns tree + stale_refs (refs invalidated since last snapshot).",
93        annotations(
94            read_only_hint = true,
95            destructive_hint = false,
96            idempotent_hint = true,
97            open_world_hint = false
98        )
99    )]
100    async fn dom_snapshot(&self, Parameters(params): Parameters<SnapshotParams>) -> CallToolResult {
101        let format = params.format.unwrap_or(SnapshotFormat::Compact);
102        let format_str = match format {
103            SnapshotFormat::Compact => "compact",
104            SnapshotFormat::Json => "json",
105        };
106        let code = format!(
107            "return window.__VICTAURI__?.snapshot({})",
108            js_string(format_str)
109        );
110        self.eval_bridge(&code, params.webview_label.as_deref())
111            .await
112    }
113
114    #[tool(
115        description = "Search for elements by text, role, test_id, CSS selector, or accessible name without a full snapshot. Returns lightweight matches with ref handles.",
116        annotations(
117            read_only_hint = true,
118            destructive_hint = false,
119            idempotent_hint = true,
120            open_world_hint = false
121        )
122    )]
123    async fn find_elements(
124        &self,
125        Parameters(params): Parameters<FindElementsParams>,
126    ) -> CallToolResult {
127        let mut parts: Vec<String> = Vec::new();
128        if let Some(t) = &params.text {
129            parts.push(format!("text: {}", js_string(t)));
130        }
131        if let Some(r) = &params.role {
132            parts.push(format!("role: {}", js_string(r)));
133        }
134        if let Some(tid) = &params.test_id {
135            parts.push(format!("test_id: {}", js_string(tid)));
136        }
137        if let Some(c) = &params.css {
138            parts.push(format!("css: {}", js_string(c)));
139        }
140        if let Some(n) = &params.name {
141            parts.push(format!("name: {}", js_string(n)));
142        }
143        if let Some(max) = params.max_results {
144            parts.push(format!("max_results: {max}"));
145        }
146        let code = format!(
147            "return window.__VICTAURI__?.findElements({{ {} }})",
148            parts.join(", ")
149        );
150        self.eval_bridge(&code, params.webview_label.as_deref())
151            .await
152    }
153
154    #[tool(
155        description = "Invoke a registered Tauri command via IPC, just like the frontend would. Goes through the real IPC pipeline so calls are logged and verifiable. Returns the command's result. Subject to privacy command filtering.",
156        annotations(
157            read_only_hint = false,
158            destructive_hint = true,
159            idempotent_hint = false,
160            open_world_hint = false
161        )
162    )]
163    async fn invoke_command(
164        &self,
165        Parameters(params): Parameters<InvokeCommandParams>,
166    ) -> CallToolResult {
167        if !self.state.privacy.is_command_allowed(&params.command) {
168            return tool_error(format!(
169                "command '{}' is blocked by privacy configuration",
170                params.command
171            ));
172        }
173        let args_json = params.args.unwrap_or(serde_json::json!({}));
174        let args_str = serde_json::to_string(&args_json).unwrap_or_else(|_| "{}".to_string());
175        let code = format!(
176            "return window.__TAURI__.core.invoke({}, {args_str})",
177            js_string(&params.command)
178        );
179        match self
180            .eval_with_return(&code, params.webview_label.as_deref())
181            .await
182        {
183            Ok(result) => self.redact_result(result),
184            Err(e) => tool_error(format!("invoke_command failed: {e}")),
185        }
186    }
187
188    #[tool(
189        description = "Capture a screenshot of a Tauri window as a base64-encoded PNG image. Works on Windows (PrintWindow), macOS (CGWindowListCreateImage), and Linux (X11/Wayland).",
190        annotations(
191            read_only_hint = true,
192            destructive_hint = false,
193            idempotent_hint = true,
194            open_world_hint = false
195        )
196    )]
197    async fn screenshot(&self, Parameters(params): Parameters<ScreenshotParams>) -> CallToolResult {
198        self.track_tool_call();
199        if !self.state.privacy.is_tool_enabled("screenshot") {
200            return tool_disabled("screenshot");
201        }
202        match self
203            .bridge
204            .get_native_handle(params.window_label.as_deref())
205        {
206            Ok(hwnd) => match crate::screenshot::capture_window(hwnd).await {
207                Ok(png_bytes) => {
208                    use base64::Engine;
209                    let b64 = base64::engine::general_purpose::STANDARD.encode(&png_bytes);
210                    CallToolResult::success(vec![Content::image(b64, "image/png")])
211                }
212                Err(e) => tool_error(format!("screenshot capture failed: {e}")),
213            },
214            Err(e) => tool_error(format!("cannot get window handle: {e}")),
215        }
216    }
217
218    #[tool(
219        description = "Compare frontend state (evaluated via JS expression) against backend state to detect divergences. Returns a VerificationResult with any mismatches.",
220        annotations(
221            read_only_hint = true,
222            destructive_hint = false,
223            idempotent_hint = true,
224            open_world_hint = false
225        )
226    )]
227    async fn verify_state(
228        &self,
229        Parameters(params): Parameters<VerifyStateParams>,
230    ) -> CallToolResult {
231        let code = format!("return ({})", params.frontend_expr);
232        let frontend_json = match self
233            .eval_with_return(&code, params.webview_label.as_deref())
234            .await
235        {
236            Ok(result) => result,
237            Err(e) => return tool_error(format!("failed to evaluate frontend expression: {e}")),
238        };
239
240        let frontend_state: serde_json::Value = match serde_json::from_str(&frontend_json) {
241            Ok(v) => v,
242            Err(e) => {
243                return tool_error(format!(
244                    "frontend expression did not return valid JSON: {e}"
245                ));
246            }
247        };
248
249        let result = victauri_core::verify_state(frontend_state, params.backend_state);
250        json_result(&result)
251    }
252
253    #[tool(
254        description = "Detect ghost commands — commands invoked from the frontend that have no backend handler, or registered backend commands never called. Reads from the JS-side IPC interception log.",
255        annotations(
256            read_only_hint = true,
257            destructive_hint = false,
258            idempotent_hint = true,
259            open_world_hint = false
260        )
261    )]
262    async fn detect_ghost_commands(
263        &self,
264        Parameters(params): Parameters<GhostCommandParams>,
265    ) -> CallToolResult {
266        let code = "return window.__VICTAURI__?.getIpcLog()";
267        let ipc_json = match self
268            .eval_with_return(code, params.webview_label.as_deref())
269            .await
270        {
271            Ok(r) => r,
272            Err(e) => return tool_error(format!("failed to read IPC log: {e}")),
273        };
274
275        let ipc_calls: Vec<serde_json::Value> = match serde_json::from_str(&ipc_json) {
276            Ok(v) => v,
277            Err(e) => return tool_error(format!("failed to parse IPC log JSON: {e}")),
278        };
279        let frontend_commands: Vec<String> = ipc_calls
280            .iter()
281            .filter_map(|c| c.get("command").and_then(|v| v.as_str()).map(String::from))
282            .collect::<std::collections::HashSet<_>>()
283            .into_iter()
284            .collect();
285
286        let report = victauri_core::detect_ghost_commands(&frontend_commands, &self.state.registry);
287        json_result(&report)
288    }
289
    /// Summarises the health of the JS-side IPC log.
    ///
    /// Builds a self-contained JS function that reads the bridge's IPC log
    /// (an empty list when the bridge or log is unavailable), classifies
    /// entries by `status`, and returns counts plus up to 20 stale and 20
    /// errored entries. The script is dispatched through `eval_bridge`.
    #[tool(
        description = "Check IPC round-trip integrity: find stale (stuck) pending calls and errored calls. Returns health status and lists of problematic IPC calls.",
        annotations(
            read_only_hint = true,
            destructive_hint = false,
            idempotent_hint = true,
            open_world_hint = false
        )
    )]
    async fn check_ipc_integrity(
        &self,
        Parameters(params): Parameters<IpcIntegrityParams>,
    ) -> CallToolResult {
        // A pending call older than this many milliseconds counts as stale;
        // defaults to 5000 when the caller gives no threshold.
        let threshold = params.stale_threshold_ms.unwrap_or(5000);
        // `{{` / `}}` are escaped braces for `format!`; `{threshold}` is the
        // only interpolation in the script.
        let code = format!(
            r"return (function() {{
                var log = window.__VICTAURI__?.getIpcLog() || [];
                var now = Date.now();
                var threshold = {threshold};
                var pending = log.filter(function(c) {{ return c.status === 'pending'; }});
                var stale = pending.filter(function(c) {{ return (now - c.timestamp) > threshold; }});
                var errored = log.filter(function(c) {{ return c.status === 'error'; }});
                return {{
                    healthy: stale.length === 0 && errored.length === 0,
                    total_calls: log.length,
                    pending_count: pending.length,
                    stale_count: stale.length,
                    error_count: errored.length,
                    stale_calls: stale.slice(0, 20),
                    errored_calls: errored.slice(0, 20)
                }};
            }})()"
        );
        self.eval_bridge(&code, params.webview_label.as_deref())
            .await
    }
326
327    #[tool(
328        description = "Wait for a condition to be met. Polls at regular intervals until satisfied or timeout. Conditions: text (text appears), text_gone (text disappears), selector (CSS selector matches), selector_gone, url (URL contains value), ipc_idle (no pending IPC calls), network_idle (no pending network requests).",
329        annotations(
330            read_only_hint = true,
331            destructive_hint = false,
332            idempotent_hint = true,
333            open_world_hint = false
334        )
335    )]
336    async fn wait_for(&self, Parameters(params): Parameters<WaitForParams>) -> CallToolResult {
337        let value = params
338            .value
339            .as_ref()
340            .map_or_else(|| "null".to_string(), |v| js_string(v));
341        let timeout_ms = params.timeout_ms.unwrap_or(10000);
342        let poll = params.poll_ms.unwrap_or(200);
343        let code = format!(
344            "return window.__VICTAURI__?.waitFor({{ condition: {}, value: {value}, timeout_ms: {timeout_ms}, poll_ms: {poll} }})",
345            js_string(params.condition.as_str())
346        );
347        let eval_timeout = std::time::Duration::from_millis(timeout_ms + 5000);
348        match self
349            .eval_with_return_timeout(&code, params.webview_label.as_deref(), eval_timeout)
350            .await
351        {
352            Ok(result) => CallToolResult::success(vec![Content::text(result)]),
353            Err(e) => tool_error(e),
354        }
355    }
356
357    #[tool(
358        description = "Run a semantic assertion: evaluate a JS expression and check the result against an expected condition. Conditions: equals, not_equals, contains, greater_than, less_than, truthy, falsy, exists, type_is.",
359        annotations(
360            read_only_hint = true,
361            destructive_hint = false,
362            idempotent_hint = true,
363            open_world_hint = false
364        )
365    )]
366    async fn assert_semantic(
367        &self,
368        Parameters(params): Parameters<SemanticAssertParams>,
369    ) -> CallToolResult {
370        let code = format!("return ({})", params.expression);
371        let actual_json = match self
372            .eval_with_return(&code, params.webview_label.as_deref())
373            .await
374        {
375            Ok(result) => result,
376            Err(e) => return tool_error(format!("failed to evaluate expression: {e}")),
377        };
378
379        let actual: serde_json::Value = match serde_json::from_str(&actual_json) {
380            Ok(v) => v,
381            Err(e) => return tool_error(format!("expression did not return valid JSON: {e}")),
382        };
383
384        let assertion = victauri_core::SemanticAssertion {
385            label: params.label,
386            condition: params.condition,
387            expected: params.expected,
388        };
389
390        let result = victauri_core::evaluate_assertion(actual, &assertion);
391        json_result(&result)
392    }
393
394    #[tool(
395        description = "Resolve a natural language query to matching Tauri commands. Returns scored results ranked by relevance, using command names, descriptions, intents, categories, and examples.",
396        annotations(
397            read_only_hint = true,
398            destructive_hint = false,
399            idempotent_hint = true,
400            open_world_hint = false
401        )
402    )]
403    async fn resolve_command(
404        &self,
405        Parameters(params): Parameters<ResolveCommandParams>,
406    ) -> CallToolResult {
407        self.track_tool_call();
408        let limit = params.limit.unwrap_or(5);
409        let mut results = self.state.registry.resolve(&params.query);
410        results.truncate(limit);
411        json_result(&results)
412    }
413
414    #[tool(
415        description = "List or search all registered Tauri commands with their argument schemas. Pass query to filter by name/description substring. Commands are registered via #[inspectable] macro.",
416        annotations(
417            read_only_hint = true,
418            destructive_hint = false,
419            idempotent_hint = true,
420            open_world_hint = false
421        )
422    )]
423    async fn get_registry(&self, Parameters(params): Parameters<RegistryParams>) -> CallToolResult {
424        self.track_tool_call();
425        let commands = match params.query {
426            Some(q) => self.state.registry.search(&q),
427            None => self.state.registry.list(),
428        };
429        json_result(&commands)
430    }
431
432    #[tool(
433        description = "Get real-time process memory statistics from the OS (working set, page file usage). On Windows returns detailed metrics; on Linux returns virtual/resident size.",
434        annotations(
435            read_only_hint = true,
436            destructive_hint = false,
437            idempotent_hint = true,
438            open_world_hint = false
439        )
440    )]
441    async fn get_memory_stats(&self) -> CallToolResult {
442        self.track_tool_call();
443        let stats = crate::memory::current_stats();
444        json_result(&stats)
445    }
446
447    #[tool(
448        description = "Inspect the Victauri plugin's own configuration: port, enabled/disabled tools, command filters, privacy settings, capacities, and version. Useful for agents to understand their capabilities before acting.",
449        annotations(
450            read_only_hint = true,
451            destructive_hint = false,
452            idempotent_hint = true,
453            open_world_hint = false
454        )
455    )]
456    async fn get_plugin_info(&self) -> CallToolResult {
457        self.track_tool_call();
458        let disabled: Vec<&str> = self
459            .state
460            .privacy
461            .disabled_tools
462            .iter()
463            .map(std::string::String::as_str)
464            .collect();
465        let blocklist: Vec<&str> = self
466            .state
467            .privacy
468            .command_blocklist
469            .iter()
470            .map(std::string::String::as_str)
471            .collect();
472        let allowlist: Option<Vec<&str>> = self
473            .state
474            .privacy
475            .command_allowlist
476            .as_ref()
477            .map(|s| s.iter().map(std::string::String::as_str).collect());
478        let all_tools = Self::tool_router().list_all();
479        let enabled_tools: Vec<&str> = all_tools
480            .iter()
481            .filter(|t| self.state.privacy.is_tool_enabled(t.name.as_ref()))
482            .map(|t| t.name.as_ref())
483            .collect();
484
485        let result = serde_json::json!({
486            "version": env!("CARGO_PKG_VERSION"),
487            "bridge_version": BRIDGE_VERSION,
488            "port": self.state.port.load(Ordering::Relaxed),
489            "tools": {
490                "total": all_tools.len(),
491                "enabled": enabled_tools.len(),
492                "enabled_list": enabled_tools,
493                "disabled_list": disabled,
494            },
495            "commands": {
496                "allowlist": allowlist,
497                "blocklist": blocklist,
498            },
499            "privacy": {
500                "redaction_enabled": self.state.privacy.redaction_enabled,
501            },
502            "capacities": {
503                "event_log": self.state.event_log.capacity(),
504                "eval_timeout_secs": self.state.eval_timeout.as_secs(),
505            },
506            "registered_commands": self.state.registry.count(),
507            "tool_invocations": self.state.tool_invocations.load(std::sync::atomic::Ordering::Relaxed),
508            "uptime_secs": self.state.started_at.elapsed().as_secs(),
509        });
510        json_result(&result)
511    }
512
513    // ── Compound Tools ──────────────────────────────────────────────────────
514
515    #[tool(
516        description = "DOM element interactions. Actions: click, double_click, hover, focus, scroll_into_view, select_option. Requires ref_id from a dom_snapshot for most actions.",
517        annotations(
518            read_only_hint = false,
519            destructive_hint = false,
520            idempotent_hint = false,
521            open_world_hint = false
522        )
523    )]
524    async fn interact(&self, Parameters(params): Parameters<InteractParams>) -> CallToolResult {
525        match params.action {
526            InteractAction::Click => {
527                let Some(ref_id) = &params.ref_id else {
528                    return missing_param("ref_id", "click");
529                };
530                let code = format!("return window.__VICTAURI__?.click({})", js_string(ref_id));
531                self.eval_bridge(&code, params.webview_label.as_deref())
532                    .await
533            }
534            InteractAction::DoubleClick => {
535                let Some(ref_id) = &params.ref_id else {
536                    return missing_param("ref_id", "double_click");
537                };
538                let code = format!(
539                    "return window.__VICTAURI__?.doubleClick({})",
540                    js_string(ref_id)
541                );
542                self.eval_bridge(&code, params.webview_label.as_deref())
543                    .await
544            }
545            InteractAction::Hover => {
546                let Some(ref_id) = &params.ref_id else {
547                    return missing_param("ref_id", "hover");
548                };
549                let code = format!("return window.__VICTAURI__?.hover({})", js_string(ref_id));
550                self.eval_bridge(&code, params.webview_label.as_deref())
551                    .await
552            }
553            InteractAction::Focus => {
554                let Some(ref_id) = &params.ref_id else {
555                    return missing_param("ref_id", "focus");
556                };
557                let code = format!(
558                    "return window.__VICTAURI__?.focusElement({})",
559                    js_string(ref_id)
560                );
561                self.eval_bridge(&code, params.webview_label.as_deref())
562                    .await
563            }
564            InteractAction::ScrollIntoView => {
565                let ref_arg = params
566                    .ref_id
567                    .as_ref()
568                    .map_or_else(|| "null".to_string(), |r| js_string(r));
569                let x = params.x.unwrap_or(0.0);
570                let y = params.y.unwrap_or(0.0);
571                let code = format!("return window.__VICTAURI__?.scrollTo({ref_arg}, {x}, {y})");
572                self.eval_bridge(&code, params.webview_label.as_deref())
573                    .await
574            }
575            InteractAction::SelectOption => {
576                let Some(ref_id) = &params.ref_id else {
577                    return missing_param("ref_id", "select_option");
578                };
579                let values = params.values.as_deref().unwrap_or(&[]);
580                let values_json =
581                    serde_json::to_string(values).unwrap_or_else(|_| "[]".to_string());
582                let code = format!(
583                    "return window.__VICTAURI__?.selectOption({}, {})",
584                    js_string(ref_id),
585                    values_json
586                );
587                self.eval_bridge(&code, params.webview_label.as_deref())
588                    .await
589            }
590        }
591    }
592
593    #[tool(
594        description = "Text and keyboard input. Actions: fill (set input value), type_text (character-by-character typing), press_key (trigger a keyboard key). Subject to privacy controls.",
595        annotations(
596            read_only_hint = false,
597            destructive_hint = false,
598            idempotent_hint = false,
599            open_world_hint = false
600        )
601    )]
602    async fn input(&self, Parameters(params): Parameters<InputParams>) -> CallToolResult {
603        match params.action {
604            InputAction::Fill => {
605                if !self.state.privacy.is_tool_enabled("fill") {
606                    return tool_disabled("fill");
607                }
608                let Some(ref_id) = &params.ref_id else {
609                    return missing_param("ref_id", "fill");
610                };
611                let Some(value) = &params.value else {
612                    return missing_param("value", "fill");
613                };
614                let code = format!(
615                    "return window.__VICTAURI__?.fill({}, {})",
616                    js_string(ref_id),
617                    js_string(value)
618                );
619                self.eval_bridge(&code, params.webview_label.as_deref())
620                    .await
621            }
622            InputAction::TypeText => {
623                if !self.state.privacy.is_tool_enabled("type_text") {
624                    return tool_disabled("type_text");
625                }
626                let Some(ref_id) = &params.ref_id else {
627                    return missing_param("ref_id", "type_text");
628                };
629                let Some(text) = &params.text else {
630                    return missing_param("text", "type_text");
631                };
632                let code = format!(
633                    "return window.__VICTAURI__?.type({}, {})",
634                    js_string(ref_id),
635                    js_string(text)
636                );
637                self.eval_bridge(&code, params.webview_label.as_deref())
638                    .await
639            }
640            InputAction::PressKey => {
641                let Some(key) = &params.key else {
642                    return missing_param("key", "press_key");
643                };
644                let code = format!("return window.__VICTAURI__?.pressKey({})", js_string(key));
645                self.eval_bridge(&code, params.webview_label.as_deref())
646                    .await
647            }
648        }
649    }
650
651    #[tool(
652        description = "Window management. Actions: get_state (window positions/sizes/visibility), list (all window labels), manage (minimize/maximize/close/focus/show/hide/fullscreen/always_on_top), resize, move_to, set_title.",
653        annotations(
654            read_only_hint = false,
655            destructive_hint = false,
656            idempotent_hint = true,
657            open_world_hint = false
658        )
659    )]
660    async fn window(&self, Parameters(params): Parameters<WindowParams>) -> CallToolResult {
661        self.track_tool_call();
662        match params.action {
663            WindowAction::GetState => {
664                let states = self.bridge.get_window_states(params.label.as_deref());
665                json_result(&states)
666            }
667            WindowAction::List => {
668                let labels = self.bridge.list_window_labels();
669                json_result(&labels)
670            }
671            WindowAction::Manage => {
672                let Some(manage_action) = &params.manage_action else {
673                    return missing_param("manage_action", "manage");
674                };
675                match self
676                    .bridge
677                    .manage_window(params.label.as_deref(), manage_action.as_str())
678                {
679                    Ok(msg) => CallToolResult::success(vec![Content::text(msg)]),
680                    Err(e) => tool_error(e),
681                }
682            }
683            WindowAction::Resize => {
684                let Some(width) = params.width else {
685                    return missing_param("width", "resize");
686                };
687                let Some(height) = params.height else {
688                    return missing_param("height", "resize");
689                };
690                match self
691                    .bridge
692                    .resize_window(params.label.as_deref(), width, height)
693                {
694                    Ok(()) => {
695                        let result =
696                            serde_json::json!({"ok": true, "width": width, "height": height});
697                        CallToolResult::success(vec![Content::text(result.to_string())])
698                    }
699                    Err(e) => tool_error(e),
700                }
701            }
702            WindowAction::MoveTo => {
703                let Some(x) = params.x else {
704                    return missing_param("x", "move_to");
705                };
706                let Some(y) = params.y else {
707                    return missing_param("y", "move_to");
708                };
709                match self.bridge.move_window(params.label.as_deref(), x, y) {
710                    Ok(()) => {
711                        let result = serde_json::json!({"ok": true, "x": x, "y": y});
712                        CallToolResult::success(vec![Content::text(result.to_string())])
713                    }
714                    Err(e) => tool_error(e),
715                }
716            }
717            WindowAction::SetTitle => {
718                let Some(title) = &params.title else {
719                    return missing_param("title", "set_title");
720                };
721                match self.bridge.set_window_title(params.label.as_deref(), title) {
722                    Ok(()) => {
723                        let result = serde_json::json!({"ok": true, "title": title});
724                        CallToolResult::success(vec![Content::text(result.to_string())])
725                    }
726                    Err(e) => tool_error(e),
727                }
728            }
729        }
730    }
731
732    #[tool(
733        description = "Browser storage operations. Actions: get (read localStorage/sessionStorage), set (write), delete (remove key), get_cookies. Subject to privacy controls for set and delete.",
734        annotations(
735            read_only_hint = false,
736            destructive_hint = true,
737            idempotent_hint = false,
738            open_world_hint = false
739        )
740    )]
741    async fn storage(&self, Parameters(params): Parameters<StorageParams>) -> CallToolResult {
742        match params.action {
743            StorageAction::Get => {
744                let method = match params.storage_type.unwrap_or(StorageType::Local) {
745                    StorageType::Session => "getSessionStorage",
746                    StorageType::Local => "getLocalStorage",
747                };
748                let key_arg = params
749                    .key
750                    .as_ref()
751                    .map(|k| js_string(k))
752                    .unwrap_or_default();
753                let code = format!("return window.__VICTAURI__?.{method}({key_arg})");
754                self.eval_bridge_redacted(&code, params.webview_label.as_deref())
755                    .await
756            }
757            StorageAction::Set => {
758                if !self.state.privacy.is_tool_enabled("set_storage") {
759                    return tool_disabled("set_storage");
760                }
761                let method = match params.storage_type.unwrap_or(StorageType::Local) {
762                    StorageType::Session => "setSessionStorage",
763                    StorageType::Local => "setLocalStorage",
764                };
765                let Some(key) = &params.key else {
766                    return missing_param("key", "set");
767                };
768                let value = params
769                    .value
770                    .as_ref()
771                    .cloned()
772                    .unwrap_or(serde_json::Value::Null);
773                let value_json =
774                    serde_json::to_string(&value).unwrap_or_else(|_| "null".to_string());
775                let code = format!(
776                    "return window.__VICTAURI__?.{method}({}, {value_json})",
777                    js_string(key)
778                );
779                self.eval_bridge(&code, params.webview_label.as_deref())
780                    .await
781            }
782            StorageAction::Delete => {
783                if !self.state.privacy.is_tool_enabled("delete_storage") {
784                    return tool_disabled("delete_storage");
785                }
786                let method = match params.storage_type.unwrap_or(StorageType::Local) {
787                    StorageType::Session => "deleteSessionStorage",
788                    StorageType::Local => "deleteLocalStorage",
789                };
790                let Some(key) = &params.key else {
791                    return missing_param("key", "delete");
792                };
793                let code = format!("return window.__VICTAURI__?.{method}({})", js_string(key));
794                self.eval_bridge(&code, params.webview_label.as_deref())
795                    .await
796            }
797            StorageAction::GetCookies => {
798                self.eval_bridge_redacted(
799                    "return window.__VICTAURI__?.getCookies()",
800                    params.webview_label.as_deref(),
801                )
802                .await
803            }
804        }
805    }
806
807    #[tool(
808        description = "Navigation and dialog control. Actions: go_to (navigate to URL), go_back (browser back), get_history (navigation log), set_dialog_response (auto-respond to alert/confirm/prompt), get_dialog_log (captured dialog events). Subject to privacy controls for go_to and set_dialog_response.",
809        annotations(
810            read_only_hint = false,
811            destructive_hint = false,
812            idempotent_hint = false,
813            open_world_hint = false
814        )
815    )]
816    async fn navigate(&self, Parameters(params): Parameters<NavigateParams>) -> CallToolResult {
817        match params.action {
818            NavigateAction::GoTo => {
819                if !self.state.privacy.is_tool_enabled("navigate") {
820                    return tool_disabled("navigate");
821                }
822                let Some(url) = &params.url else {
823                    return missing_param("url", "go_to");
824                };
825                if let Err(e) = validate_url(url) {
826                    return tool_error(e);
827                }
828                let code = format!("return window.__VICTAURI__?.navigate({})", js_string(url));
829                self.eval_bridge(&code, params.webview_label.as_deref())
830                    .await
831            }
832            NavigateAction::GoBack => {
833                self.eval_bridge(
834                    "return window.__VICTAURI__?.navigateBack()",
835                    params.webview_label.as_deref(),
836                )
837                .await
838            }
839            NavigateAction::GetHistory => {
840                self.eval_bridge(
841                    "return window.__VICTAURI__?.getNavigationLog()",
842                    params.webview_label.as_deref(),
843                )
844                .await
845            }
846            NavigateAction::SetDialogResponse => {
847                if !self.state.privacy.is_tool_enabled("set_dialog_response") {
848                    return tool_disabled("set_dialog_response");
849                }
850                let Some(dialog_type) = params.dialog_type else {
851                    return missing_param("dialog_type", "set_dialog_response");
852                };
853                let Some(dialog_action) = params.dialog_action else {
854                    return missing_param("dialog_action", "set_dialog_response");
855                };
856                let text_arg = params
857                    .text
858                    .as_ref()
859                    .map_or_else(|| "undefined".to_string(), |t| js_string(t));
860                let code = format!(
861                    "return window.__VICTAURI__?.setDialogAutoResponse({}, {}, {text_arg})",
862                    js_string(dialog_type.as_str()),
863                    js_string(dialog_action.as_str())
864                );
865                self.eval_bridge(&code, params.webview_label.as_deref())
866                    .await
867            }
868            NavigateAction::GetDialogLog => {
869                self.eval_bridge(
870                    "return window.__VICTAURI__?.getDialogLog()",
871                    params.webview_label.as_deref(),
872                )
873                .await
874            }
875        }
876    }
877
878    #[tool(
879        description = "Time-travel recording. Actions: start (begin recording), stop (end and return session), checkpoint (save state snapshot), list_checkpoints, get_events (since index), events_between (two checkpoints), get_replay (IPC replay sequence), export (session as JSON), import (load session from JSON).",
880        annotations(
881            read_only_hint = false,
882            destructive_hint = false,
883            idempotent_hint = false,
884            open_world_hint = false
885        )
886    )]
887    async fn recording(&self, Parameters(params): Parameters<RecordingParams>) -> CallToolResult {
888        self.track_tool_call();
889        match params.action {
890            RecordingAction::Start => {
891                let session_id = params
892                    .session_id
893                    .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
894                match self.state.recorder.start(session_id.clone()) {
895                    Ok(()) => {
896                        let result = serde_json::json!({
897                            "started": true,
898                            "session_id": session_id,
899                        });
900                        CallToolResult::success(vec![Content::text(result.to_string())])
901                    }
902                    Err(e) => tool_error(e.to_string()),
903                }
904            }
905            RecordingAction::Stop => match self.state.recorder.stop() {
906                Some(session) => json_result(&session),
907                None => tool_error("no recording is active"),
908            },
909            RecordingAction::Checkpoint => {
910                let Some(id) = params.checkpoint_id else {
911                    return missing_param("checkpoint_id", "checkpoint");
912                };
913                let state = params.state.unwrap_or(serde_json::Value::Null);
914                match self
915                    .state
916                    .recorder
917                    .checkpoint(id.clone(), params.checkpoint_label, state)
918                {
919                    Ok(()) => {
920                        let result = serde_json::json!({
921                            "created": true,
922                            "checkpoint_id": id,
923                            "event_index": self.state.recorder.event_count(),
924                        });
925                        CallToolResult::success(vec![Content::text(result.to_string())])
926                    }
927                    Err(e) => tool_error(e.to_string()),
928                }
929            }
930            RecordingAction::ListCheckpoints => {
931                let checkpoints = self.state.recorder.get_checkpoints();
932                json_result(&checkpoints)
933            }
934            RecordingAction::GetEvents => {
935                let events = self
936                    .state
937                    .recorder
938                    .events_since(params.since_index.unwrap_or(0));
939                json_result(&events)
940            }
941            RecordingAction::EventsBetween => {
942                let Some(from) = &params.from else {
943                    return missing_param("from", "events_between");
944                };
945                let Some(to) = &params.to else {
946                    return missing_param("to", "events_between");
947                };
948                match self.state.recorder.events_between_checkpoints(from, to) {
949                    Ok(events) => json_result(&events),
950                    Err(e) => tool_error(e.to_string()),
951                }
952            }
953            RecordingAction::GetReplay => {
954                let calls = self.state.recorder.ipc_replay_sequence();
955                json_result(&calls)
956            }
957            RecordingAction::Export => match self.state.recorder.export() {
958                Some(s) => {
959                    let json = serde_json::to_string_pretty(&s)
960                        .unwrap_or_else(|e| format!("{{\"error\": \"{e}\"}}"));
961                    CallToolResult::success(vec![Content::text(json)])
962                }
963                None => tool_error("no recording is active — start one first"),
964            },
965            RecordingAction::Import => {
966                let Some(session_json) = &params.session_json else {
967                    return missing_param("session_json", "import");
968                };
969                let session: victauri_core::RecordedSession =
970                    match serde_json::from_str(session_json) {
971                        Ok(s) => s,
972                        Err(e) => return tool_error(format!("invalid session JSON: {e}")),
973                    };
974
975                let result = serde_json::json!({
976                    "imported": true,
977                    "session_id": session.id,
978                    "event_count": session.events.len(),
979                    "checkpoint_count": session.checkpoints.len(),
980                    "started_at": session.started_at.to_rfc3339(),
981                });
982                self.state.recorder.import(session);
983                CallToolResult::success(vec![Content::text(result.to_string())])
984            }
985        }
986    }
987
988    #[tool(
989        description = "CSS and visual inspection. Actions: get_styles (computed CSS for element), get_bounding_boxes (layout rects), highlight (debug overlay), clear_highlights, audit_accessibility (a11y audit), get_performance (timing/heap/DOM metrics).",
990        annotations(
991            read_only_hint = true,
992            destructive_hint = false,
993            idempotent_hint = true,
994            open_world_hint = false
995        )
996    )]
997    async fn inspect(&self, Parameters(params): Parameters<InspectParams>) -> CallToolResult {
998        match params.action {
999            InspectAction::GetStyles => {
1000                let Some(ref_id) = &params.ref_id else {
1001                    return missing_param("ref_id", "get_styles");
1002                };
1003                let props_arg = match &params.properties {
1004                    Some(props) => {
1005                        let arr: Vec<String> = props.iter().map(|p| js_string(p)).collect();
1006                        format!("[{}]", arr.join(","))
1007                    }
1008                    None => "null".to_string(),
1009                };
1010                let code = format!(
1011                    "return window.__VICTAURI__?.getStyles({}, {})",
1012                    js_string(ref_id),
1013                    props_arg
1014                );
1015                self.eval_bridge(&code, params.webview_label.as_deref())
1016                    .await
1017            }
1018            InspectAction::GetBoundingBoxes => {
1019                let Some(ref_ids) = &params.ref_ids else {
1020                    return missing_param("ref_ids", "get_bounding_boxes");
1021                };
1022                let refs: Vec<String> = ref_ids.iter().map(|r| js_string(r)).collect();
1023                let code = format!(
1024                    "return window.__VICTAURI__?.getBoundingBoxes([{}])",
1025                    refs.join(",")
1026                );
1027                self.eval_bridge(&code, params.webview_label.as_deref())
1028                    .await
1029            }
1030            InspectAction::Highlight => {
1031                let Some(ref_id) = &params.ref_id else {
1032                    return missing_param("ref_id", "highlight");
1033                };
1034                let color_arg = match &params.color {
1035                    Some(c) => match sanitize_css_color(c) {
1036                        Ok(safe) => format!("\"{safe}\""),
1037                        Err(e) => return tool_error(e),
1038                    },
1039                    None => "null".to_string(),
1040                };
1041                let label_arg = match &params.label {
1042                    Some(l) => js_string(l),
1043                    None => "null".to_string(),
1044                };
1045                let code = format!(
1046                    "return window.__VICTAURI__?.highlightElement({}, {}, {})",
1047                    js_string(ref_id),
1048                    color_arg,
1049                    label_arg
1050                );
1051                self.eval_bridge(&code, params.webview_label.as_deref())
1052                    .await
1053            }
1054            InspectAction::ClearHighlights => {
1055                self.eval_bridge(
1056                    "return window.__VICTAURI__?.clearHighlights()",
1057                    params.webview_label.as_deref(),
1058                )
1059                .await
1060            }
1061            InspectAction::AuditAccessibility => {
1062                self.eval_bridge(
1063                    "return window.__VICTAURI__?.auditAccessibility()",
1064                    params.webview_label.as_deref(),
1065                )
1066                .await
1067            }
1068            InspectAction::GetPerformance => {
1069                self.eval_bridge(
1070                    "return window.__VICTAURI__?.getPerformanceMetrics()",
1071                    params.webview_label.as_deref(),
1072                )
1073                .await
1074            }
1075        }
1076    }
1077
1078    #[tool(
1079        description = "CSS injection. Actions: inject (add custom CSS to page), remove (remove previously injected CSS). Subject to privacy controls.",
1080        annotations(
1081            read_only_hint = false,
1082            destructive_hint = false,
1083            idempotent_hint = true,
1084            open_world_hint = false
1085        )
1086    )]
1087    async fn css(&self, Parameters(params): Parameters<CssParams>) -> CallToolResult {
1088        match params.action {
1089            CssAction::Inject => {
1090                if !self.state.privacy.is_tool_enabled("inject_css") {
1091                    return tool_disabled("inject_css");
1092                }
1093                let Some(css) = &params.css else {
1094                    return missing_param("css", "inject");
1095                };
1096                let code = format!("return window.__VICTAURI__?.injectCss({})", js_string(css));
1097                self.eval_bridge(&code, params.webview_label.as_deref())
1098                    .await
1099            }
1100            CssAction::Remove => {
1101                self.eval_bridge(
1102                    "return window.__VICTAURI__?.removeInjectedCss()",
1103                    params.webview_label.as_deref(),
1104                )
1105                .await
1106            }
1107        }
1108    }
1109
1110    #[tool(
1111        description = "Application logs and monitoring. Actions: console (captured console.log/warn/error), network (intercepted fetch/XHR), ipc (IPC call log), navigation (URL change history), dialogs (alert/confirm/prompt events), events (combined event stream), slow_ipc (find slow IPC calls).",
1112        annotations(
1113            read_only_hint = true,
1114            destructive_hint = false,
1115            idempotent_hint = true,
1116            open_world_hint = false
1117        )
1118    )]
1119    async fn logs(&self, Parameters(params): Parameters<LogsParams>) -> CallToolResult {
1120        match params.action {
1121            LogsAction::Console => {
1122                let since_arg = params.since.map(|ts| format!("{ts}")).unwrap_or_default();
1123                let code = if since_arg.is_empty() {
1124                    "return window.__VICTAURI__?.getConsoleLogs()".to_string()
1125                } else {
1126                    format!("return window.__VICTAURI__?.getConsoleLogs({since_arg})")
1127                };
1128                self.eval_bridge_redacted(&code, params.webview_label.as_deref())
1129                    .await
1130            }
1131            LogsAction::Network => {
1132                let filter_arg = params
1133                    .filter
1134                    .as_ref()
1135                    .map_or_else(|| "null".to_string(), |f| js_string(f));
1136                let limit_arg = params
1137                    .limit
1138                    .map_or_else(|| "null".to_string(), |l| l.to_string());
1139                let code =
1140                    format!("return window.__VICTAURI__?.getNetworkLog({filter_arg}, {limit_arg})");
1141                self.eval_bridge_redacted(&code, params.webview_label.as_deref())
1142                    .await
1143            }
1144            LogsAction::Ipc => {
1145                let limit_arg = params.limit.map(|l| format!("{l}")).unwrap_or_default();
1146                let code = if limit_arg.is_empty() {
1147                    "return window.__VICTAURI__?.getIpcLog()".to_string()
1148                } else {
1149                    format!("return window.__VICTAURI__?.getIpcLog({limit_arg})")
1150                };
1151                self.eval_bridge_redacted(&code, params.webview_label.as_deref())
1152                    .await
1153            }
1154            LogsAction::Navigation => {
1155                self.eval_bridge(
1156                    "return window.__VICTAURI__?.getNavigationLog()",
1157                    params.webview_label.as_deref(),
1158                )
1159                .await
1160            }
1161            LogsAction::Dialogs => {
1162                self.eval_bridge(
1163                    "return window.__VICTAURI__?.getDialogLog()",
1164                    params.webview_label.as_deref(),
1165                )
1166                .await
1167            }
1168            LogsAction::Events => {
1169                let since_arg = params.since.map(|ts| format!("{ts}")).unwrap_or_default();
1170                let code = if since_arg.is_empty() {
1171                    "return window.__VICTAURI__?.getEventStream()".to_string()
1172                } else {
1173                    format!("return window.__VICTAURI__?.getEventStream({since_arg})")
1174                };
1175                self.eval_bridge(&code, params.webview_label.as_deref())
1176                    .await
1177            }
1178            LogsAction::SlowIpc => {
1179                let Some(threshold) = params.threshold_ms else {
1180                    return missing_param("threshold_ms", "slow_ipc");
1181                };
1182                let limit = params.limit.unwrap_or(20);
1183                let code = format!(
1184                    r"return (function() {{
1185                        var log = window.__VICTAURI__?.getIpcLog() || [];
1186                        var slow = log.filter(function(c) {{ return (c.duration_ms || 0) > {threshold}; }});
1187                        slow.sort(function(a, b) {{ return (b.duration_ms || 0) - (a.duration_ms || 0); }});
1188                        return {{ threshold_ms: {threshold}, count: Math.min(slow.length, {limit}), calls: slow.slice(0, {limit}) }};
1189                    }})()",
1190                );
1191                self.eval_bridge_redacted(&code, None).await
1192            }
1193        }
1194    }
1195}
1196
1197impl VictauriMcpHandler {
1198    /// Create a new handler backed by the given state and webview bridge.
1199    pub fn new(state: Arc<VictauriState>, bridge: Arc<dyn WebviewBridge>) -> Self {
1200        Self {
1201            state,
1202            bridge,
1203            subscriptions: Arc::new(Mutex::new(HashSet::new())),
1204            bridge_checked: Arc::new(AtomicBool::new(false)),
1205        }
1206    }
1207
1208    fn track_tool_call(&self) {
1209        self.state.tool_invocations.fetch_add(1, Ordering::Relaxed);
1210    }
1211
1212    async fn eval_bridge(&self, code: &str, webview_label: Option<&str>) -> CallToolResult {
1213        match self.eval_with_return(code, webview_label).await {
1214            Ok(result) => CallToolResult::success(vec![Content::text(result)]),
1215            Err(e) => tool_error(e),
1216        }
1217    }
1218
1219    async fn eval_bridge_redacted(
1220        &self,
1221        code: &str,
1222        webview_label: Option<&str>,
1223    ) -> CallToolResult {
1224        match self.eval_with_return(code, webview_label).await {
1225            Ok(result) => self.redact_result(result),
1226            Err(e) => tool_error(e),
1227        }
1228    }
1229
1230    fn redact_result(&self, output: String) -> CallToolResult {
1231        let redacted = self.state.privacy.redact_output(&output);
1232        CallToolResult::success(vec![Content::text(redacted)])
1233    }
1234
1235    async fn eval_with_return(
1236        &self,
1237        code: &str,
1238        webview_label: Option<&str>,
1239    ) -> Result<String, String> {
1240        self.eval_with_return_timeout(code, webview_label, self.state.eval_timeout)
1241            .await
1242    }
1243
1244    async fn eval_with_return_timeout(
1245        &self,
1246        code: &str,
1247        webview_label: Option<&str>,
1248        timeout: std::time::Duration,
1249    ) -> Result<String, String> {
1250        self.track_tool_call();
1251        let id = uuid::Uuid::new_v4().to_string();
1252        let (tx, rx) = tokio::sync::oneshot::channel();
1253
1254        {
1255            let mut pending = self.state.pending_evals.lock().await;
1256            if pending.len() >= MAX_PENDING_EVALS {
1257                return Err(format!(
1258                    "too many concurrent eval requests (limit: {MAX_PENDING_EVALS})"
1259                ));
1260            }
1261            pending.insert(id.clone(), tx);
1262        }
1263
1264        // Auto-prepend `return` so bare expressions produce a value.
1265        // Only skip for code that starts with a statement keyword where
1266        // prepending `return` would be a syntax error.
1267        let code = code.trim();
1268        let needs_return = !code.starts_with("return ")
1269            && !code.starts_with("return;")
1270            && !code.starts_with('{')
1271            && !code.starts_with("if ")
1272            && !code.starts_with("if(")
1273            && !code.starts_with("for ")
1274            && !code.starts_with("for(")
1275            && !code.starts_with("while ")
1276            && !code.starts_with("while(")
1277            && !code.starts_with("switch ")
1278            && !code.starts_with("try ")
1279            && !code.starts_with("const ")
1280            && !code.starts_with("let ")
1281            && !code.starts_with("var ")
1282            && !code.starts_with("function ")
1283            && !code.starts_with("class ")
1284            && !code.starts_with("throw ");
1285        let code = if needs_return {
1286            format!("return {code}")
1287        } else {
1288            code.to_string()
1289        };
1290
1291        let inject = format!(
1292            r"
1293            (async () => {{
1294                try {{
1295                    const __result = await (async () => {{ {code} }})();
1296                    await window.__TAURI__.core.invoke('plugin:victauri|victauri_eval_callback', {{
1297                        id: '{id}',
1298                        result: JSON.stringify(__result)
1299                    }});
1300                }} catch (e) {{
1301                    await window.__TAURI__.core.invoke('plugin:victauri|victauri_eval_callback', {{
1302                        id: '{id}',
1303                        result: JSON.stringify({{ __error: e.message }})
1304                    }});
1305                }}
1306            }})();
1307            "
1308        );
1309
1310        if let Err(e) = self.bridge.eval_webview(webview_label, &inject) {
1311            self.state.pending_evals.lock().await.remove(&id);
1312            return Err(format!("eval injection failed: {e}"));
1313        }
1314
1315        match tokio::time::timeout(timeout, rx).await {
1316            Ok(Ok(result)) => {
1317                self.check_bridge_version_once();
1318                Ok(result)
1319            }
1320            Ok(Err(_)) => Err("eval callback channel closed".to_string()),
1321            Err(_) => {
1322                self.state.pending_evals.lock().await.remove(&id);
1323                Err(format!("eval timed out after {}s", timeout.as_secs()))
1324            }
1325        }
1326    }
1327
1328    fn check_bridge_version_once(&self) {
1329        if self.bridge_checked.swap(true, Ordering::Relaxed) {
1330            return;
1331        }
1332        let handler = self.clone();
1333        tokio::spawn(async move {
1334            match handler
1335                .eval_with_return_timeout(
1336                    "window.__VICTAURI__?.version",
1337                    None,
1338                    std::time::Duration::from_secs(5),
1339                )
1340                .await
1341            {
1342                Ok(v) => {
1343                    let v = v.trim_matches('"');
1344                    if v == BRIDGE_VERSION {
1345                        tracing::debug!("Bridge version verified: {v}");
1346                    } else {
1347                        tracing::warn!(
1348                            "Bridge version mismatch: Rust expects {BRIDGE_VERSION}, JS reports {v}"
1349                        );
1350                    }
1351                }
1352                Err(e) => tracing::debug!("Bridge version check skipped: {e}"),
1353            }
1354        });
1355    }
1356}
1357
/// Usage instructions for the server: lists the compound tools with their
/// actions, followed by the standalone tools. Passed to
/// `ServerInfo::with_instructions` in `get_info`.
const SERVER_INSTRUCTIONS: &str = concat!(
    "Victauri gives you X-ray vision and hands inside a running Tauri application. ",
    "Use compound tools with an 'action' parameter to interact with the app: ",
    "'interact' (click, hover, focus, scroll, select), ",
    "'input' (fill, type_text, press_key), ",
    "'window' (get_state, list, manage, resize, move_to, set_title), ",
    "'storage' (get, set, delete, get_cookies), ",
    "'navigate' (go_to, go_back, get_history, set_dialog_response, get_dialog_log), ",
    "'recording' (start, stop, checkpoint, list_checkpoints, get_events, events_between, ",
    "get_replay, export, import), ",
    "'inspect' (get_styles, get_bounding_boxes, highlight, clear_highlights, ",
    "audit_accessibility, get_performance), ",
    "'css' (inject, remove), ",
    "'logs' (console, network, ipc, navigation, dialogs, events, slow_ipc). ",
    "Standalone tools: eval_js, dom_snapshot, invoke_command, screenshot, verify_state, ",
    "detect_ghost_commands, check_ipc_integrity, wait_for, assert_semantic, resolve_command, ",
    "get_registry, get_memory_stats, get_plugin_info.",
);
1370
1371impl ServerHandler for VictauriMcpHandler {
1372    fn get_info(&self) -> ServerInfo {
1373        ServerInfo::new(
1374            ServerCapabilities::builder()
1375                .enable_tools()
1376                .enable_resources()
1377                .enable_resources_subscribe()
1378                .build(),
1379        )
1380        .with_instructions(SERVER_INSTRUCTIONS)
1381    }
1382
1383    async fn list_tools(
1384        &self,
1385        _request: Option<PaginatedRequestParams>,
1386        _context: RequestContext<RoleServer>,
1387    ) -> Result<ListToolsResult, ErrorData> {
1388        let all_tools = Self::tool_router().list_all();
1389        let filtered: Vec<Tool> = all_tools
1390            .into_iter()
1391            .filter(|t| self.state.privacy.is_tool_enabled(t.name.as_ref()))
1392            .collect();
1393        Ok(ListToolsResult {
1394            tools: filtered,
1395            ..Default::default()
1396        })
1397    }
1398
1399    async fn call_tool(
1400        &self,
1401        request: CallToolRequestParams,
1402        context: RequestContext<RoleServer>,
1403    ) -> Result<CallToolResult, ErrorData> {
1404        let tool_name: String = request.name.as_ref().to_owned();
1405        if !self.state.privacy.is_tool_enabled(&tool_name) {
1406            tracing::debug!(tool = %tool_name, "tool call blocked by privacy config");
1407            return Ok(tool_disabled(&tool_name));
1408        }
1409        self.state
1410            .tool_invocations
1411            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1412        let start = std::time::Instant::now();
1413        tracing::debug!(tool = %tool_name, "tool invocation started");
1414        let ctx = ToolCallContext::new(self, request, context);
1415        let result = Self::tool_router().call(ctx).await;
1416        let elapsed = start.elapsed();
1417        tracing::debug!(
1418            tool = %tool_name,
1419            elapsed_ms = elapsed.as_millis() as u64,
1420            is_error = result.as_ref().map_or(true, |r| r.is_error.unwrap_or(false)),
1421            "tool invocation completed"
1422        );
1423        result
1424    }
1425
1426    fn get_tool(&self, name: &str) -> Option<Tool> {
1427        if !self.state.privacy.is_tool_enabled(name) {
1428            return None;
1429        }
1430        Self::tool_router().get(name).cloned()
1431    }
1432
1433    async fn list_resources(
1434        &self,
1435        _request: Option<PaginatedRequestParams>,
1436        _context: RequestContext<RoleServer>,
1437    ) -> Result<ListResourcesResult, ErrorData> {
1438        Ok(ListResourcesResult {
1439            resources: vec![
1440                RawResource::new(RESOURCE_URI_IPC_LOG, "ipc-log")
1441                    .with_description(
1442                        "Live IPC call log — all commands invoked between frontend and backend",
1443                    )
1444                    .with_mime_type("application/json")
1445                    .no_annotation(),
1446                RawResource::new(RESOURCE_URI_WINDOWS, "windows")
1447                    .with_description(
1448                        "Current state of all Tauri windows — position, size, visibility, focus",
1449                    )
1450                    .with_mime_type("application/json")
1451                    .no_annotation(),
1452                RawResource::new(RESOURCE_URI_STATE, "state")
1453                    .with_description(
1454                        "Victauri plugin state — event count, registered commands, memory stats",
1455                    )
1456                    .with_mime_type("application/json")
1457                    .no_annotation(),
1458            ],
1459            ..Default::default()
1460        })
1461    }
1462
1463    async fn read_resource(
1464        &self,
1465        request: ReadResourceRequestParams,
1466        _context: RequestContext<RoleServer>,
1467    ) -> Result<ReadResourceResult, ErrorData> {
1468        let uri = &request.uri;
1469        let json = match uri.as_str() {
1470            RESOURCE_URI_IPC_LOG => {
1471                if let Ok(json) = self
1472                    .eval_with_return("return window.__VICTAURI__?.getIpcLog()", None)
1473                    .await
1474                {
1475                    json
1476                } else {
1477                    let calls = self.state.event_log.ipc_calls();
1478                    serde_json::to_string_pretty(&calls)
1479                        .map_err(|e| ErrorData::internal_error(e.to_string(), None))?
1480                }
1481            }
1482            RESOURCE_URI_WINDOWS => {
1483                let states = self.bridge.get_window_states(None);
1484                serde_json::to_string_pretty(&states)
1485                    .map_err(|e| ErrorData::internal_error(e.to_string(), None))?
1486            }
1487            RESOURCE_URI_STATE => {
1488                let state_json = serde_json::json!({
1489                    "events_captured": self.state.event_log.len(),
1490                    "commands_registered": self.state.registry.count(),
1491                    "memory": crate::memory::current_stats(),
1492                    "port": self.state.port.load(Ordering::Relaxed),
1493                });
1494                serde_json::to_string_pretty(&state_json)
1495                    .map_err(|e| ErrorData::internal_error(e.to_string(), None))?
1496            }
1497            _ => {
1498                return Err(ErrorData::resource_not_found(
1499                    format!("unknown resource: {uri}"),
1500                    None,
1501                ));
1502            }
1503        };
1504
1505        Ok(ReadResourceResult::new(vec![ResourceContents::text(
1506            json, uri,
1507        )]))
1508    }
1509
1510    async fn subscribe(
1511        &self,
1512        request: SubscribeRequestParams,
1513        _context: RequestContext<RoleServer>,
1514    ) -> Result<(), ErrorData> {
1515        let uri = &request.uri;
1516        match uri.as_str() {
1517            RESOURCE_URI_IPC_LOG | RESOURCE_URI_WINDOWS | RESOURCE_URI_STATE => {
1518                self.subscriptions.lock().await.insert(uri.clone());
1519                tracing::info!("Client subscribed to resource: {uri}");
1520                Ok(())
1521            }
1522            _ => Err(ErrorData::resource_not_found(
1523                format!("unknown resource: {uri}"),
1524                None,
1525            )),
1526        }
1527    }
1528
1529    async fn unsubscribe(
1530        &self,
1531        request: UnsubscribeRequestParams,
1532        _context: RequestContext<RoleServer>,
1533    ) -> Result<(), ErrorData> {
1534        self.subscriptions.lock().await.remove(&request.uri);
1535        tracing::info!("Client unsubscribed from resource: {}", request.uri);
1536        Ok(())
1537    }
1538}
1539
#[cfg(test)]
mod tests {
    use super::*;

    // ── JS string encoding tests ────────────────────────────────────────────

    /// A plain ASCII word is wrapped in double quotes and otherwise untouched.
    #[test]
    fn js_string_simple() {
        assert_eq!(js_string("hello"), r#""hello""#);
    }

    /// Single quotes need no escaping inside a double-quoted literal.
    #[test]
    fn js_string_single_quotes() {
        assert!(js_string("it's a test").contains("it's a test"));
    }

    /// Embedded double quotes must come out backslash-escaped.
    #[test]
    fn js_string_double_quotes() {
        let encoded = js_string("say \"hello\"");
        assert!(encoded.contains("\\\""));
    }

    /// Backslashes in the input are doubled in the output.
    #[test]
    fn js_string_backslashes() {
        let encoded = js_string("path\\to\\file");
        assert!(encoded.contains("\\\\"));
    }

    /// Newlines and tabs are emitted as escape sequences, never as raw characters.
    #[test]
    fn js_string_newlines_and_tabs() {
        let encoded = js_string("line1\nline2\ttab");
        for escape in ["\\n", "\\t"] {
            assert!(encoded.contains(escape), "missing escape {escape:?}");
        }
        assert!(!encoded.contains('\n'));
    }

    /// A NUL byte must appear as the six-character escape `\u0000`, not as a raw byte.
    #[test]
    fn js_string_null_bytes() {
        let encoded = js_string("before\0after");
        assert!(encoded.contains("\\u0000"));
        assert!(!encoded.contains('\0'));
    }

    /// Template-literal syntax in the input is harmless because the output is
    /// a double-quoted literal, so backticks carry no special meaning there.
    #[test]
    fn js_string_template_literal_injection() {
        let encoded = js_string("`${alert(1)}`");
        assert!(encoded.starts_with('"') && encoded.ends_with('"'));
    }

    /// U+2028 (Line Separator) and U+2029 (Paragraph Separator) are legal inside
    /// JSON strings per RFC 8259 and are accepted inside JS string literals since
    /// ES2019. Only the round-trip is asserted here: decoding the output as JSON
    /// must reproduce the input exactly.
    #[test]
    fn js_string_unicode_separators() {
        let original = "a\u{2028}b\u{2029}c";
        let encoded = js_string(original);
        let round_tripped = serde_json::from_str::<String>(&encoded).unwrap();
        assert_eq!(round_tripped, original);
    }

    /// The empty string encodes to an empty pair of double quotes.
    #[test]
    fn js_string_empty() {
        assert_eq!(js_string(""), r#""""#);
    }

    /// `</script>` inside a JS string embedded in HTML could close the script
    /// tag; check the output is still a quoted literal that round-trips as JSON.
    #[test]
    fn js_string_html_script_close() {
        let payload = "</script><img onerror=alert(1)>";
        let encoded = js_string(payload);
        assert!(encoded.starts_with('"'));
        let round_tripped = serde_json::from_str::<String>(&encoded).unwrap();
        assert_eq!(round_tripped, payload);
    }

    /// Large inputs are encoded in full: the content plus two quote characters.
    #[test]
    fn js_string_very_long() {
        let encoded = js_string(&"a".repeat(100_000));
        assert!(encoded.len() >= 100_002);
    }

    // ── URL validation tests ────────────────────────────────────────────────

    #[test]
    fn url_allows_http() {
        let verdict = validate_url("http://example.com");
        assert!(verdict.is_ok());
    }

    #[test]
    fn url_allows_https() {
        let verdict = validate_url("https://example.com/path?q=1");
        assert!(verdict.is_ok());
    }

    #[test]
    fn url_allows_file() {
        let verdict = validate_url("file:///tmp/test.html");
        assert!(verdict.is_ok());
    }

    #[test]
    fn url_blocks_javascript() {
        let verdict = validate_url("javascript:alert(1)");
        assert!(verdict.is_err());
    }

    /// The scheme check must not be defeated by upper-casing the scheme.
    #[test]
    fn url_blocks_javascript_case_insensitive() {
        let verdict = validate_url("JAVASCRIPT:alert(1)");
        assert!(verdict.is_err());
    }

    #[test]
    fn url_blocks_data_scheme() {
        let verdict = validate_url("data:text/html,<script>alert(1)</script>");
        assert!(verdict.is_err());
    }

    #[test]
    fn url_blocks_vbscript() {
        let verdict = validate_url("vbscript:MsgBox(1)");
        assert!(verdict.is_err());
    }

    #[test]
    fn url_rejects_invalid() {
        let verdict = validate_url("not a url at all");
        assert!(verdict.is_err());
    }

    /// A URL with an embedded NUL is expected to validate: control characters
    /// should be stripped, leaving a valid URL behind.
    #[test]
    fn url_strips_control_chars() {
        let verdict = validate_url("http://example\0com");
        assert!(verdict.is_ok());
    }

    // ── CSS color sanitization tests ───────────────────────────────────────

    /// 6-, 3- and 8-digit hex colors pass through unchanged.
    #[test]
    fn css_color_valid_hex() {
        for color in ["#ff0000", "#FFF", "#12345678"] {
            assert_eq!(sanitize_css_color(color).unwrap(), color);
        }
    }

    /// `rgb()` and `rgba()` functional notation passes through unchanged.
    #[test]
    fn css_color_valid_rgb() {
        for color in ["rgb(255, 0, 0)", "rgba(0, 0, 0, 0.5)"] {
            assert_eq!(sanitize_css_color(color).unwrap(), color);
        }
    }

    /// Named colors and keywords pass through unchanged.
    #[test]
    fn css_color_valid_named() {
        for color in ["red", "transparent"] {
            assert_eq!(sanitize_css_color(color).unwrap(), color);
        }
    }

    /// `hsl()` notation, including percent signs, passes through unchanged.
    #[test]
    fn css_color_valid_hsl() {
        let color = "hsl(120, 50%, 50%)";
        assert_eq!(sanitize_css_color(color).unwrap(), color);
    }

    /// A 101-character value is rejected as too long.
    #[test]
    fn css_color_rejects_too_long() {
        let oversized = "a".repeat(101);
        assert!(sanitize_css_color(&oversized).is_err());
    }

    /// CSS backslash escapes are rejected.
    #[test]
    fn css_color_rejects_backslash_escapes() {
        for attack in ["red\\00", "\\72\\65\\64"] {
            assert!(sanitize_css_color(attack).is_err(), "accepted {attack:?}");
        }
    }

    /// `url(...)` is rejected regardless of letter case.
    #[test]
    fn css_color_rejects_url_injection() {
        for attack in ["url(http://evil.com)", "URL(http://evil.com)"] {
            assert!(sanitize_css_color(attack).is_err(), "accepted {attack:?}");
        }
    }

    /// `expression(...)` is rejected regardless of letter case.
    #[test]
    fn css_color_rejects_expression_injection() {
        for attack in ["expression(alert(1))", "EXPRESSION(alert(1))"] {
            assert!(sanitize_css_color(attack).is_err(), "accepted {attack:?}");
        }
    }

    #[test]
    fn css_color_rejects_import() {
        assert!(sanitize_css_color("@import url(evil.css)").is_err());
    }

    /// Characters that could terminate the declaration or rule are rejected.
    #[test]
    fn css_color_rejects_semicolons_and_braces() {
        for attack in ["red; background: url(evil)", "red} body { color: blue"] {
            assert!(sanitize_css_color(attack).is_err(), "accepted {attack:?}");
        }
    }

    /// Angle brackets and both quote styles are rejected.
    #[test]
    fn css_color_rejects_special_chars() {
        for attack in ["red<script>", r#"red"onload=alert"#, "red'onclick=alert"] {
            assert!(sanitize_css_color(attack).is_err(), "accepted {attack:?}");
        }
    }

    /// Surrounding whitespace is trimmed from an otherwise valid color.
    #[test]
    fn css_color_trims_whitespace() {
        let cleaned = sanitize_css_color("  red  ").unwrap();
        assert_eq!(cleaned, "red");
    }

    /// The empty string is accepted and returned as-is.
    #[test]
    fn css_color_empty_string() {
        let cleaned = sanitize_css_color("").unwrap();
        assert_eq!(cleaned, "");
    }
}