mod backend_params;
mod compound_params;
mod helpers;
mod other_params;
mod server;
mod verification_params;
mod webview_params;
mod window_params;
use std::collections::HashSet;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use rmcp::handler::server::tool::ToolCallContext;
use rmcp::handler::server::wrapper::Parameters;
use rmcp::model::{
AnnotateAble, CallToolRequestParams, CallToolResult, Content, ListResourcesResult,
ListToolsResult, PaginatedRequestParams, RawResource, ReadResourceRequestParams,
ReadResourceResult, ResourceContents, ServerCapabilities, ServerInfo, SubscribeRequestParams,
Tool, UnsubscribeRequestParams,
};
use rmcp::service::RequestContext;
use rmcp::{ErrorData, RoleServer, ServerHandler, tool, tool_router};
use tokio::sync::Mutex;
use crate::VictauriState;
use crate::bridge::WebviewBridge;
use helpers::{
js_string, json_result, missing_param, sanitize_css_color, tool_disabled, tool_error,
validate_url,
};
pub use backend_params::*;
pub use compound_params::*;
pub use other_params::{
FindElementsParams, ResolveCommandParams, SemanticAssertParams, WaitCondition, WaitForParams,
};
pub use server::*;
pub use verification_params::*;
pub use webview_params::*;
pub use window_params::*;
pub(crate) const MAX_PENDING_EVALS: usize = 100;
const RESOURCE_URI_IPC_LOG: &str = "victauri://ipc-log";
const RESOURCE_URI_WINDOWS: &str = "victauri://windows";
const RESOURCE_URI_STATE: &str = "victauri://state";
const BRIDGE_VERSION: &str = "0.3.0";
#[derive(Clone)]
pub struct VictauriMcpHandler {
state: Arc<VictauriState>,
bridge: Arc<dyn WebviewBridge>,
subscriptions: Arc<Mutex<HashSet<String>>>,
bridge_checked: Arc<AtomicBool>,
}
#[tool_router]
impl VictauriMcpHandler {
#[tool(
description = "Evaluate JavaScript in the Tauri webview and return the result. Async expressions are wrapped automatically.",
annotations(
read_only_hint = false,
destructive_hint = true,
idempotent_hint = false,
open_world_hint = false
)
)]
async fn eval_js(&self, Parameters(params): Parameters<EvalJsParams>) -> CallToolResult {
if !self.state.privacy.is_tool_enabled("eval_js") {
return tool_disabled("eval_js");
}
match self
.eval_with_return(¶ms.code, params.webview_label.as_deref())
.await
{
Ok(result) => self.redact_result(result),
Err(e) => tool_error(e),
}
}
#[tool(
description = "Get the DOM snapshot with stable ref handles. Default: compact accessible text (70-80%% fewer tokens). Set format=\"json\" for full tree. Returns tree + stale_refs (refs invalidated since last snapshot).",
annotations(
read_only_hint = true,
destructive_hint = false,
idempotent_hint = true,
open_world_hint = false
)
)]
async fn dom_snapshot(&self, Parameters(params): Parameters<SnapshotParams>) -> CallToolResult {
let format = params.format.unwrap_or(SnapshotFormat::Compact);
let format_str = match format {
SnapshotFormat::Compact => "compact",
SnapshotFormat::Json => "json",
};
let code = format!(
"return window.__VICTAURI__?.snapshot({})",
js_string(format_str)
);
self.eval_bridge(&code, params.webview_label.as_deref())
.await
}
#[tool(
description = "Search for elements by text, role, test_id, CSS selector, or accessible name without a full snapshot. Returns lightweight matches with ref handles.",
annotations(
read_only_hint = true,
destructive_hint = false,
idempotent_hint = true,
open_world_hint = false
)
)]
async fn find_elements(
&self,
Parameters(params): Parameters<FindElementsParams>,
) -> CallToolResult {
let mut parts: Vec<String> = Vec::new();
if let Some(t) = ¶ms.text {
parts.push(format!("text: {}", js_string(t)));
}
if let Some(r) = ¶ms.role {
parts.push(format!("role: {}", js_string(r)));
}
if let Some(tid) = ¶ms.test_id {
parts.push(format!("test_id: {}", js_string(tid)));
}
if let Some(c) = ¶ms.css {
parts.push(format!("css: {}", js_string(c)));
}
if let Some(n) = ¶ms.name {
parts.push(format!("name: {}", js_string(n)));
}
if let Some(max) = params.max_results {
parts.push(format!("max_results: {max}"));
}
let code = format!(
"return window.__VICTAURI__?.findElements({{ {} }})",
parts.join(", ")
);
self.eval_bridge(&code, params.webview_label.as_deref())
.await
}
#[tool(
description = "Invoke a registered Tauri command via IPC, just like the frontend would. Goes through the real IPC pipeline so calls are logged and verifiable. Returns the command's result. Subject to privacy command filtering.",
annotations(
read_only_hint = false,
destructive_hint = true,
idempotent_hint = false,
open_world_hint = false
)
)]
async fn invoke_command(
&self,
Parameters(params): Parameters<InvokeCommandParams>,
) -> CallToolResult {
if !self.state.privacy.is_command_allowed(¶ms.command) {
return tool_error(format!(
"command '{}' is blocked by privacy configuration",
params.command
));
}
let args_json = params.args.unwrap_or(serde_json::json!({}));
let args_str = serde_json::to_string(&args_json).unwrap_or_else(|_| "{}".to_string());
let code = format!(
"return window.__TAURI__.core.invoke({}, {args_str})",
js_string(¶ms.command)
);
match self
.eval_with_return(&code, params.webview_label.as_deref())
.await
{
Ok(result) => self.redact_result(result),
Err(e) => tool_error(format!("invoke_command failed: {e}")),
}
}
#[tool(
description = "Capture a screenshot of a Tauri window as a base64-encoded PNG image. Works on Windows (PrintWindow), macOS (CGWindowListCreateImage), and Linux (X11/Wayland).",
annotations(
read_only_hint = true,
destructive_hint = false,
idempotent_hint = true,
open_world_hint = false
)
)]
async fn screenshot(&self, Parameters(params): Parameters<ScreenshotParams>) -> CallToolResult {
self.track_tool_call();
if !self.state.privacy.is_tool_enabled("screenshot") {
return tool_disabled("screenshot");
}
match self
.bridge
.get_native_handle(params.window_label.as_deref())
{
Ok(hwnd) => match crate::screenshot::capture_window(hwnd).await {
Ok(png_bytes) => {
use base64::Engine;
let b64 = base64::engine::general_purpose::STANDARD.encode(&png_bytes);
CallToolResult::success(vec![Content::image(b64, "image/png")])
}
Err(e) => tool_error(format!("screenshot capture failed: {e}")),
},
Err(e) => tool_error(format!("cannot get window handle: {e}")),
}
}
#[tool(
description = "Compare frontend state (evaluated via JS expression) against backend state to detect divergences. Returns a VerificationResult with any mismatches.",
annotations(
read_only_hint = true,
destructive_hint = false,
idempotent_hint = true,
open_world_hint = false
)
)]
async fn verify_state(
&self,
Parameters(params): Parameters<VerifyStateParams>,
) -> CallToolResult {
let code = format!("return ({})", params.frontend_expr);
let frontend_json = match self
.eval_with_return(&code, params.webview_label.as_deref())
.await
{
Ok(result) => result,
Err(e) => return tool_error(format!("failed to evaluate frontend expression: {e}")),
};
let frontend_state: serde_json::Value = match serde_json::from_str(&frontend_json) {
Ok(v) => v,
Err(e) => {
return tool_error(format!(
"frontend expression did not return valid JSON: {e}"
));
}
};
let result = victauri_core::verify_state(frontend_state, params.backend_state);
json_result(&result)
}
#[tool(
description = "Detect ghost commands — commands invoked from the frontend that have no backend handler, or registered backend commands never called. Reads from the JS-side IPC interception log.",
annotations(
read_only_hint = true,
destructive_hint = false,
idempotent_hint = true,
open_world_hint = false
)
)]
async fn detect_ghost_commands(
&self,
Parameters(params): Parameters<GhostCommandParams>,
) -> CallToolResult {
let code = "return window.__VICTAURI__?.getIpcLog()";
let ipc_json = match self
.eval_with_return(code, params.webview_label.as_deref())
.await
{
Ok(r) => r,
Err(e) => return tool_error(format!("failed to read IPC log: {e}")),
};
let ipc_calls: Vec<serde_json::Value> = match serde_json::from_str(&ipc_json) {
Ok(v) => v,
Err(e) => return tool_error(format!("failed to parse IPC log JSON: {e}")),
};
let frontend_commands: Vec<String> = ipc_calls
.iter()
.filter_map(|c| c.get("command").and_then(|v| v.as_str()).map(String::from))
.collect::<std::collections::HashSet<_>>()
.into_iter()
.collect();
let report = victauri_core::detect_ghost_commands(&frontend_commands, &self.state.registry);
json_result(&report)
}
#[tool(
description = "Check IPC round-trip integrity: find stale (stuck) pending calls and errored calls. Returns health status and lists of problematic IPC calls.",
annotations(
read_only_hint = true,
destructive_hint = false,
idempotent_hint = true,
open_world_hint = false
)
)]
async fn check_ipc_integrity(
&self,
Parameters(params): Parameters<IpcIntegrityParams>,
) -> CallToolResult {
let threshold = params.stale_threshold_ms.unwrap_or(5000);
let code = format!(
r"return (function() {{
var log = window.__VICTAURI__?.getIpcLog() || [];
var now = Date.now();
var threshold = {threshold};
var pending = log.filter(function(c) {{ return c.status === 'pending'; }});
var stale = pending.filter(function(c) {{ return (now - c.timestamp) > threshold; }});
var errored = log.filter(function(c) {{ return c.status === 'error'; }});
return {{
healthy: stale.length === 0 && errored.length === 0,
total_calls: log.length,
pending_count: pending.length,
stale_count: stale.length,
error_count: errored.length,
stale_calls: stale.slice(0, 20),
errored_calls: errored.slice(0, 20)
}};
}})()"
);
self.eval_bridge(&code, params.webview_label.as_deref())
.await
}
#[tool(
description = "Wait for a condition to be met. Polls at regular intervals until satisfied or timeout. Conditions: text (text appears), text_gone (text disappears), selector (CSS selector matches), selector_gone, url (URL contains value), ipc_idle (no pending IPC calls), network_idle (no pending network requests).",
annotations(
read_only_hint = true,
destructive_hint = false,
idempotent_hint = true,
open_world_hint = false
)
)]
async fn wait_for(&self, Parameters(params): Parameters<WaitForParams>) -> CallToolResult {
let value = params
.value
.as_ref()
.map_or_else(|| "null".to_string(), |v| js_string(v));
let timeout_ms = params.timeout_ms.unwrap_or(10000);
let poll = params.poll_ms.unwrap_or(200);
let code = format!(
"return window.__VICTAURI__?.waitFor({{ condition: {}, value: {value}, timeout_ms: {timeout_ms}, poll_ms: {poll} }})",
js_string(params.condition.as_str())
);
let eval_timeout = std::time::Duration::from_millis(timeout_ms + 5000);
match self
.eval_with_return_timeout(&code, params.webview_label.as_deref(), eval_timeout)
.await
{
Ok(result) => CallToolResult::success(vec![Content::text(result)]),
Err(e) => tool_error(e),
}
}
#[tool(
description = "Run a semantic assertion: evaluate a JS expression and check the result against an expected condition. Conditions: equals, not_equals, contains, greater_than, less_than, truthy, falsy, exists, type_is.",
annotations(
read_only_hint = true,
destructive_hint = false,
idempotent_hint = true,
open_world_hint = false
)
)]
async fn assert_semantic(
&self,
Parameters(params): Parameters<SemanticAssertParams>,
) -> CallToolResult {
let code = format!("return ({})", params.expression);
let actual_json = match self
.eval_with_return(&code, params.webview_label.as_deref())
.await
{
Ok(result) => result,
Err(e) => return tool_error(format!("failed to evaluate expression: {e}")),
};
let actual: serde_json::Value = match serde_json::from_str(&actual_json) {
Ok(v) => v,
Err(e) => return tool_error(format!("expression did not return valid JSON: {e}")),
};
let assertion = victauri_core::SemanticAssertion {
label: params.label,
condition: params.condition,
expected: params.expected,
};
let result = victauri_core::evaluate_assertion(actual, &assertion);
json_result(&result)
}
#[tool(
description = "Resolve a natural language query to matching Tauri commands. Returns scored results ranked by relevance, using command names, descriptions, intents, categories, and examples.",
annotations(
read_only_hint = true,
destructive_hint = false,
idempotent_hint = true,
open_world_hint = false
)
)]
async fn resolve_command(
&self,
Parameters(params): Parameters<ResolveCommandParams>,
) -> CallToolResult {
self.track_tool_call();
let limit = params.limit.unwrap_or(5);
let mut results = self.state.registry.resolve(¶ms.query);
results.truncate(limit);
json_result(&results)
}
#[tool(
description = "List or search all registered Tauri commands with their argument schemas. Pass query to filter by name/description substring. Commands are registered via #[inspectable] macro.",
annotations(
read_only_hint = true,
destructive_hint = false,
idempotent_hint = true,
open_world_hint = false
)
)]
async fn get_registry(&self, Parameters(params): Parameters<RegistryParams>) -> CallToolResult {
self.track_tool_call();
let commands = match params.query {
Some(q) => self.state.registry.search(&q),
None => self.state.registry.list(),
};
json_result(&commands)
}
#[tool(
description = "Get real-time process memory statistics from the OS (working set, page file usage). On Windows returns detailed metrics; on Linux returns virtual/resident size.",
annotations(
read_only_hint = true,
destructive_hint = false,
idempotent_hint = true,
open_world_hint = false
)
)]
async fn get_memory_stats(&self) -> CallToolResult {
self.track_tool_call();
let stats = crate::memory::current_stats();
json_result(&stats)
}
#[tool(
description = "Inspect the Victauri plugin's own configuration: port, enabled/disabled tools, command filters, privacy settings, capacities, and version. Useful for agents to understand their capabilities before acting.",
annotations(
read_only_hint = true,
destructive_hint = false,
idempotent_hint = true,
open_world_hint = false
)
)]
async fn get_plugin_info(&self) -> CallToolResult {
self.track_tool_call();
let disabled: Vec<&str> = self
.state
.privacy
.disabled_tools
.iter()
.map(std::string::String::as_str)
.collect();
let blocklist: Vec<&str> = self
.state
.privacy
.command_blocklist
.iter()
.map(std::string::String::as_str)
.collect();
let allowlist: Option<Vec<&str>> = self
.state
.privacy
.command_allowlist
.as_ref()
.map(|s| s.iter().map(std::string::String::as_str).collect());
let all_tools = Self::tool_router().list_all();
let enabled_tools: Vec<&str> = all_tools
.iter()
.filter(|t| self.state.privacy.is_tool_enabled(t.name.as_ref()))
.map(|t| t.name.as_ref())
.collect();
let result = serde_json::json!({
"version": env!("CARGO_PKG_VERSION"),
"bridge_version": BRIDGE_VERSION,
"port": self.state.port.load(Ordering::Relaxed),
"tools": {
"total": all_tools.len(),
"enabled": enabled_tools.len(),
"enabled_list": enabled_tools,
"disabled_list": disabled,
},
"commands": {
"allowlist": allowlist,
"blocklist": blocklist,
},
"privacy": {
"redaction_enabled": self.state.privacy.redaction_enabled,
},
"capacities": {
"event_log": self.state.event_log.capacity(),
"eval_timeout_secs": self.state.eval_timeout.as_secs(),
},
"registered_commands": self.state.registry.count(),
"tool_invocations": self.state.tool_invocations.load(std::sync::atomic::Ordering::Relaxed),
"uptime_secs": self.state.started_at.elapsed().as_secs(),
});
json_result(&result)
}
#[tool(
description = "DOM element interactions. Actions: click, double_click, hover, focus, scroll_into_view, select_option. Requires ref_id from a dom_snapshot for most actions.",
annotations(
read_only_hint = false,
destructive_hint = false,
idempotent_hint = false,
open_world_hint = false
)
)]
async fn interact(&self, Parameters(params): Parameters<InteractParams>) -> CallToolResult {
match params.action {
InteractAction::Click => {
let Some(ref_id) = ¶ms.ref_id else {
return missing_param("ref_id", "click");
};
let code = format!("return window.__VICTAURI__?.click({})", js_string(ref_id));
self.eval_bridge(&code, params.webview_label.as_deref())
.await
}
InteractAction::DoubleClick => {
let Some(ref_id) = ¶ms.ref_id else {
return missing_param("ref_id", "double_click");
};
let code = format!(
"return window.__VICTAURI__?.doubleClick({})",
js_string(ref_id)
);
self.eval_bridge(&code, params.webview_label.as_deref())
.await
}
InteractAction::Hover => {
let Some(ref_id) = ¶ms.ref_id else {
return missing_param("ref_id", "hover");
};
let code = format!("return window.__VICTAURI__?.hover({})", js_string(ref_id));
self.eval_bridge(&code, params.webview_label.as_deref())
.await
}
InteractAction::Focus => {
let Some(ref_id) = ¶ms.ref_id else {
return missing_param("ref_id", "focus");
};
let code = format!(
"return window.__VICTAURI__?.focusElement({})",
js_string(ref_id)
);
self.eval_bridge(&code, params.webview_label.as_deref())
.await
}
InteractAction::ScrollIntoView => {
let ref_arg = params
.ref_id
.as_ref()
.map_or_else(|| "null".to_string(), |r| js_string(r));
let x = params.x.unwrap_or(0.0);
let y = params.y.unwrap_or(0.0);
let code = format!("return window.__VICTAURI__?.scrollTo({ref_arg}, {x}, {y})");
self.eval_bridge(&code, params.webview_label.as_deref())
.await
}
InteractAction::SelectOption => {
let Some(ref_id) = ¶ms.ref_id else {
return missing_param("ref_id", "select_option");
};
let values = params.values.as_deref().unwrap_or(&[]);
let values_json =
serde_json::to_string(values).unwrap_or_else(|_| "[]".to_string());
let code = format!(
"return window.__VICTAURI__?.selectOption({}, {})",
js_string(ref_id),
values_json
);
self.eval_bridge(&code, params.webview_label.as_deref())
.await
}
}
}
#[tool(
description = "Text and keyboard input. Actions: fill (set input value), type_text (character-by-character typing), press_key (trigger a keyboard key). Subject to privacy controls.",
annotations(
read_only_hint = false,
destructive_hint = false,
idempotent_hint = false,
open_world_hint = false
)
)]
async fn input(&self, Parameters(params): Parameters<InputParams>) -> CallToolResult {
match params.action {
InputAction::Fill => {
if !self.state.privacy.is_tool_enabled("fill") {
return tool_disabled("fill");
}
let Some(ref_id) = ¶ms.ref_id else {
return missing_param("ref_id", "fill");
};
let Some(value) = ¶ms.value else {
return missing_param("value", "fill");
};
let code = format!(
"return window.__VICTAURI__?.fill({}, {})",
js_string(ref_id),
js_string(value)
);
self.eval_bridge(&code, params.webview_label.as_deref())
.await
}
InputAction::TypeText => {
if !self.state.privacy.is_tool_enabled("type_text") {
return tool_disabled("type_text");
}
let Some(ref_id) = ¶ms.ref_id else {
return missing_param("ref_id", "type_text");
};
let Some(text) = ¶ms.text else {
return missing_param("text", "type_text");
};
let code = format!(
"return window.__VICTAURI__?.type({}, {})",
js_string(ref_id),
js_string(text)
);
self.eval_bridge(&code, params.webview_label.as_deref())
.await
}
InputAction::PressKey => {
let Some(key) = ¶ms.key else {
return missing_param("key", "press_key");
};
let code = format!("return window.__VICTAURI__?.pressKey({})", js_string(key));
self.eval_bridge(&code, params.webview_label.as_deref())
.await
}
}
}
#[tool(
description = "Window management. Actions: get_state (window positions/sizes/visibility), list (all window labels), manage (minimize/maximize/close/focus/show/hide/fullscreen/always_on_top), resize, move_to, set_title.",
annotations(
read_only_hint = false,
destructive_hint = false,
idempotent_hint = true,
open_world_hint = false
)
)]
async fn window(&self, Parameters(params): Parameters<WindowParams>) -> CallToolResult {
self.track_tool_call();
match params.action {
WindowAction::GetState => {
let states = self.bridge.get_window_states(params.label.as_deref());
json_result(&states)
}
WindowAction::List => {
let labels = self.bridge.list_window_labels();
json_result(&labels)
}
WindowAction::Manage => {
let Some(manage_action) = ¶ms.manage_action else {
return missing_param("manage_action", "manage");
};
match self
.bridge
.manage_window(params.label.as_deref(), manage_action.as_str())
{
Ok(msg) => CallToolResult::success(vec![Content::text(msg)]),
Err(e) => tool_error(e),
}
}
WindowAction::Resize => {
let Some(width) = params.width else {
return missing_param("width", "resize");
};
let Some(height) = params.height else {
return missing_param("height", "resize");
};
match self
.bridge
.resize_window(params.label.as_deref(), width, height)
{
Ok(()) => {
let result =
serde_json::json!({"ok": true, "width": width, "height": height});
CallToolResult::success(vec![Content::text(result.to_string())])
}
Err(e) => tool_error(e),
}
}
WindowAction::MoveTo => {
let Some(x) = params.x else {
return missing_param("x", "move_to");
};
let Some(y) = params.y else {
return missing_param("y", "move_to");
};
match self.bridge.move_window(params.label.as_deref(), x, y) {
Ok(()) => {
let result = serde_json::json!({"ok": true, "x": x, "y": y});
CallToolResult::success(vec![Content::text(result.to_string())])
}
Err(e) => tool_error(e),
}
}
WindowAction::SetTitle => {
let Some(title) = ¶ms.title else {
return missing_param("title", "set_title");
};
match self.bridge.set_window_title(params.label.as_deref(), title) {
Ok(()) => {
let result = serde_json::json!({"ok": true, "title": title});
CallToolResult::success(vec![Content::text(result.to_string())])
}
Err(e) => tool_error(e),
}
}
}
}
#[tool(
description = "Browser storage operations. Actions: get (read localStorage/sessionStorage), set (write), delete (remove key), get_cookies. Subject to privacy controls for set and delete.",
annotations(
read_only_hint = false,
destructive_hint = true,
idempotent_hint = false,
open_world_hint = false
)
)]
async fn storage(&self, Parameters(params): Parameters<StorageParams>) -> CallToolResult {
match params.action {
StorageAction::Get => {
let method = match params.storage_type.unwrap_or(StorageType::Local) {
StorageType::Session => "getSessionStorage",
StorageType::Local => "getLocalStorage",
};
let key_arg = params
.key
.as_ref()
.map(|k| js_string(k))
.unwrap_or_default();
let code = format!("return window.__VICTAURI__?.{method}({key_arg})");
self.eval_bridge_redacted(&code, params.webview_label.as_deref())
.await
}
StorageAction::Set => {
if !self.state.privacy.is_tool_enabled("set_storage") {
return tool_disabled("set_storage");
}
let method = match params.storage_type.unwrap_or(StorageType::Local) {
StorageType::Session => "setSessionStorage",
StorageType::Local => "setLocalStorage",
};
let Some(key) = ¶ms.key else {
return missing_param("key", "set");
};
let value = params
.value
.as_ref()
.cloned()
.unwrap_or(serde_json::Value::Null);
let value_json =
serde_json::to_string(&value).unwrap_or_else(|_| "null".to_string());
let code = format!(
"return window.__VICTAURI__?.{method}({}, {value_json})",
js_string(key)
);
self.eval_bridge(&code, params.webview_label.as_deref())
.await
}
StorageAction::Delete => {
if !self.state.privacy.is_tool_enabled("delete_storage") {
return tool_disabled("delete_storage");
}
let method = match params.storage_type.unwrap_or(StorageType::Local) {
StorageType::Session => "deleteSessionStorage",
StorageType::Local => "deleteLocalStorage",
};
let Some(key) = ¶ms.key else {
return missing_param("key", "delete");
};
let code = format!("return window.__VICTAURI__?.{method}({})", js_string(key));
self.eval_bridge(&code, params.webview_label.as_deref())
.await
}
StorageAction::GetCookies => {
self.eval_bridge_redacted(
"return window.__VICTAURI__?.getCookies()",
params.webview_label.as_deref(),
)
.await
}
}
}
#[tool(
description = "Navigation and dialog control. Actions: go_to (navigate to URL), go_back (browser back), get_history (navigation log), set_dialog_response (auto-respond to alert/confirm/prompt), get_dialog_log (captured dialog events). Subject to privacy controls for go_to and set_dialog_response.",
annotations(
read_only_hint = false,
destructive_hint = false,
idempotent_hint = false,
open_world_hint = false
)
)]
async fn navigate(&self, Parameters(params): Parameters<NavigateParams>) -> CallToolResult {
match params.action {
NavigateAction::GoTo => {
if !self.state.privacy.is_tool_enabled("navigate") {
return tool_disabled("navigate");
}
let Some(url) = ¶ms.url else {
return missing_param("url", "go_to");
};
if let Err(e) = validate_url(url) {
return tool_error(e);
}
let code = format!("return window.__VICTAURI__?.navigate({})", js_string(url));
self.eval_bridge(&code, params.webview_label.as_deref())
.await
}
NavigateAction::GoBack => {
self.eval_bridge(
"return window.__VICTAURI__?.navigateBack()",
params.webview_label.as_deref(),
)
.await
}
NavigateAction::GetHistory => {
self.eval_bridge(
"return window.__VICTAURI__?.getNavigationLog()",
params.webview_label.as_deref(),
)
.await
}
NavigateAction::SetDialogResponse => {
if !self.state.privacy.is_tool_enabled("set_dialog_response") {
return tool_disabled("set_dialog_response");
}
let Some(dialog_type) = params.dialog_type else {
return missing_param("dialog_type", "set_dialog_response");
};
let Some(dialog_action) = params.dialog_action else {
return missing_param("dialog_action", "set_dialog_response");
};
let text_arg = params
.text
.as_ref()
.map_or_else(|| "undefined".to_string(), |t| js_string(t));
let code = format!(
"return window.__VICTAURI__?.setDialogAutoResponse({}, {}, {text_arg})",
js_string(dialog_type.as_str()),
js_string(dialog_action.as_str())
);
self.eval_bridge(&code, params.webview_label.as_deref())
.await
}
NavigateAction::GetDialogLog => {
self.eval_bridge(
"return window.__VICTAURI__?.getDialogLog()",
params.webview_label.as_deref(),
)
.await
}
}
}
#[tool(
description = "Time-travel recording. Actions: start (begin recording), stop (end and return session), checkpoint (save state snapshot), list_checkpoints, get_events (since index), events_between (two checkpoints), get_replay (IPC replay sequence), export (session as JSON), import (load session from JSON).",
annotations(
read_only_hint = false,
destructive_hint = false,
idempotent_hint = false,
open_world_hint = false
)
)]
async fn recording(&self, Parameters(params): Parameters<RecordingParams>) -> CallToolResult {
self.track_tool_call();
match params.action {
RecordingAction::Start => {
let session_id = params
.session_id
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
match self.state.recorder.start(session_id.clone()) {
Ok(()) => {
let result = serde_json::json!({
"started": true,
"session_id": session_id,
});
CallToolResult::success(vec![Content::text(result.to_string())])
}
Err(e) => tool_error(e.to_string()),
}
}
RecordingAction::Stop => match self.state.recorder.stop() {
Some(session) => json_result(&session),
None => tool_error("no recording is active"),
},
RecordingAction::Checkpoint => {
let Some(id) = params.checkpoint_id else {
return missing_param("checkpoint_id", "checkpoint");
};
let state = params.state.unwrap_or(serde_json::Value::Null);
match self
.state
.recorder
.checkpoint(id.clone(), params.checkpoint_label, state)
{
Ok(()) => {
let result = serde_json::json!({
"created": true,
"checkpoint_id": id,
"event_index": self.state.recorder.event_count(),
});
CallToolResult::success(vec![Content::text(result.to_string())])
}
Err(e) => tool_error(e.to_string()),
}
}
RecordingAction::ListCheckpoints => {
let checkpoints = self.state.recorder.get_checkpoints();
json_result(&checkpoints)
}
RecordingAction::GetEvents => {
let events = self
.state
.recorder
.events_since(params.since_index.unwrap_or(0));
json_result(&events)
}
RecordingAction::EventsBetween => {
let Some(from) = ¶ms.from else {
return missing_param("from", "events_between");
};
let Some(to) = ¶ms.to else {
return missing_param("to", "events_between");
};
match self.state.recorder.events_between_checkpoints(from, to) {
Ok(events) => json_result(&events),
Err(e) => tool_error(e.to_string()),
}
}
RecordingAction::GetReplay => {
let calls = self.state.recorder.ipc_replay_sequence();
json_result(&calls)
}
RecordingAction::Export => match self.state.recorder.export() {
Some(s) => {
let json = serde_json::to_string_pretty(&s)
.unwrap_or_else(|e| format!("{{\"error\": \"{e}\"}}"));
CallToolResult::success(vec![Content::text(json)])
}
None => tool_error("no recording is active — start one first"),
},
RecordingAction::Import => {
let Some(session_json) = ¶ms.session_json else {
return missing_param("session_json", "import");
};
let session: victauri_core::RecordedSession =
match serde_json::from_str(session_json) {
Ok(s) => s,
Err(e) => return tool_error(format!("invalid session JSON: {e}")),
};
let result = serde_json::json!({
"imported": true,
"session_id": session.id,
"event_count": session.events.len(),
"checkpoint_count": session.checkpoints.len(),
"started_at": session.started_at.to_rfc3339(),
});
self.state.recorder.import(session);
CallToolResult::success(vec![Content::text(result.to_string())])
}
}
}
#[tool(
description = "CSS and visual inspection. Actions: get_styles (computed CSS for element), get_bounding_boxes (layout rects), highlight (debug overlay), clear_highlights, audit_accessibility (a11y audit), get_performance (timing/heap/DOM metrics).",
annotations(
read_only_hint = true,
destructive_hint = false,
idempotent_hint = true,
open_world_hint = false
)
)]
async fn inspect(&self, Parameters(params): Parameters<InspectParams>) -> CallToolResult {
match params.action {
InspectAction::GetStyles => {
let Some(ref_id) = ¶ms.ref_id else {
return missing_param("ref_id", "get_styles");
};
let props_arg = match ¶ms.properties {
Some(props) => {
let arr: Vec<String> = props.iter().map(|p| js_string(p)).collect();
format!("[{}]", arr.join(","))
}
None => "null".to_string(),
};
let code = format!(
"return window.__VICTAURI__?.getStyles({}, {})",
js_string(ref_id),
props_arg
);
self.eval_bridge(&code, params.webview_label.as_deref())
.await
}
InspectAction::GetBoundingBoxes => {
let Some(ref_ids) = ¶ms.ref_ids else {
return missing_param("ref_ids", "get_bounding_boxes");
};
let refs: Vec<String> = ref_ids.iter().map(|r| js_string(r)).collect();
let code = format!(
"return window.__VICTAURI__?.getBoundingBoxes([{}])",
refs.join(",")
);
self.eval_bridge(&code, params.webview_label.as_deref())
.await
}
InspectAction::Highlight => {
let Some(ref_id) = ¶ms.ref_id else {
return missing_param("ref_id", "highlight");
};
let color_arg = match ¶ms.color {
Some(c) => match sanitize_css_color(c) {
Ok(safe) => format!("\"{safe}\""),
Err(e) => return tool_error(e),
},
None => "null".to_string(),
};
let label_arg = match ¶ms.label {
Some(l) => js_string(l),
None => "null".to_string(),
};
let code = format!(
"return window.__VICTAURI__?.highlightElement({}, {}, {})",
js_string(ref_id),
color_arg,
label_arg
);
self.eval_bridge(&code, params.webview_label.as_deref())
.await
}
InspectAction::ClearHighlights => {
self.eval_bridge(
"return window.__VICTAURI__?.clearHighlights()",
params.webview_label.as_deref(),
)
.await
}
InspectAction::AuditAccessibility => {
self.eval_bridge(
"return window.__VICTAURI__?.auditAccessibility()",
params.webview_label.as_deref(),
)
.await
}
InspectAction::GetPerformance => {
self.eval_bridge(
"return window.__VICTAURI__?.getPerformanceMetrics()",
params.webview_label.as_deref(),
)
.await
}
}
}
#[tool(
description = "CSS injection. Actions: inject (add custom CSS to page), remove (remove previously injected CSS). Subject to privacy controls.",
annotations(
read_only_hint = false,
destructive_hint = false,
idempotent_hint = true,
open_world_hint = false
)
)]
async fn css(&self, Parameters(params): Parameters<CssParams>) -> CallToolResult {
match params.action {
CssAction::Inject => {
if !self.state.privacy.is_tool_enabled("inject_css") {
return tool_disabled("inject_css");
}
let Some(css) = ¶ms.css else {
return missing_param("css", "inject");
};
let code = format!("return window.__VICTAURI__?.injectCss({})", js_string(css));
self.eval_bridge(&code, params.webview_label.as_deref())
.await
}
CssAction::Remove => {
self.eval_bridge(
"return window.__VICTAURI__?.removeInjectedCss()",
params.webview_label.as_deref(),
)
.await
}
}
}
#[tool(
description = "Application logs and monitoring. Actions: console (captured console.log/warn/error), network (intercepted fetch/XHR), ipc (IPC call log), navigation (URL change history), dialogs (alert/confirm/prompt events), events (combined event stream), slow_ipc (find slow IPC calls).",
annotations(
read_only_hint = true,
destructive_hint = false,
idempotent_hint = true,
open_world_hint = false
)
)]
async fn logs(&self, Parameters(params): Parameters<LogsParams>) -> CallToolResult {
match params.action {
LogsAction::Console => {
let since_arg = params.since.map(|ts| format!("{ts}")).unwrap_or_default();
let code = if since_arg.is_empty() {
"return window.__VICTAURI__?.getConsoleLogs()".to_string()
} else {
format!("return window.__VICTAURI__?.getConsoleLogs({since_arg})")
};
self.eval_bridge_redacted(&code, params.webview_label.as_deref())
.await
}
LogsAction::Network => {
let filter_arg = params
.filter
.as_ref()
.map_or_else(|| "null".to_string(), |f| js_string(f));
let limit_arg = params
.limit
.map_or_else(|| "null".to_string(), |l| l.to_string());
let code =
format!("return window.__VICTAURI__?.getNetworkLog({filter_arg}, {limit_arg})");
self.eval_bridge_redacted(&code, params.webview_label.as_deref())
.await
}
LogsAction::Ipc => {
let limit_arg = params.limit.map(|l| format!("{l}")).unwrap_or_default();
let code = if limit_arg.is_empty() {
"return window.__VICTAURI__?.getIpcLog()".to_string()
} else {
format!("return window.__VICTAURI__?.getIpcLog({limit_arg})")
};
self.eval_bridge_redacted(&code, params.webview_label.as_deref())
.await
}
LogsAction::Navigation => {
self.eval_bridge(
"return window.__VICTAURI__?.getNavigationLog()",
params.webview_label.as_deref(),
)
.await
}
LogsAction::Dialogs => {
self.eval_bridge(
"return window.__VICTAURI__?.getDialogLog()",
params.webview_label.as_deref(),
)
.await
}
LogsAction::Events => {
let since_arg = params.since.map(|ts| format!("{ts}")).unwrap_or_default();
let code = if since_arg.is_empty() {
"return window.__VICTAURI__?.getEventStream()".to_string()
} else {
format!("return window.__VICTAURI__?.getEventStream({since_arg})")
};
self.eval_bridge(&code, params.webview_label.as_deref())
.await
}
LogsAction::SlowIpc => {
let Some(threshold) = params.threshold_ms else {
return missing_param("threshold_ms", "slow_ipc");
};
let limit = params.limit.unwrap_or(20);
let code = format!(
r"return (function() {{
var log = window.__VICTAURI__?.getIpcLog() || [];
var slow = log.filter(function(c) {{ return (c.duration_ms || 0) > {threshold}; }});
slow.sort(function(a, b) {{ return (b.duration_ms || 0) - (a.duration_ms || 0); }});
return {{ threshold_ms: {threshold}, count: Math.min(slow.length, {limit}), calls: slow.slice(0, {limit}) }};
}})()",
);
self.eval_bridge_redacted(&code, None).await
}
}
}
}
impl VictauriMcpHandler {
pub fn new(state: Arc<VictauriState>, bridge: Arc<dyn WebviewBridge>) -> Self {
Self {
state,
bridge,
subscriptions: Arc::new(Mutex::new(HashSet::new())),
bridge_checked: Arc::new(AtomicBool::new(false)),
}
}
fn track_tool_call(&self) {
self.state.tool_invocations.fetch_add(1, Ordering::Relaxed);
}
async fn eval_bridge(&self, code: &str, webview_label: Option<&str>) -> CallToolResult {
match self.eval_with_return(code, webview_label).await {
Ok(result) => CallToolResult::success(vec![Content::text(result)]),
Err(e) => tool_error(e),
}
}
async fn eval_bridge_redacted(
&self,
code: &str,
webview_label: Option<&str>,
) -> CallToolResult {
match self.eval_with_return(code, webview_label).await {
Ok(result) => self.redact_result(result),
Err(e) => tool_error(e),
}
}
fn redact_result(&self, output: String) -> CallToolResult {
let redacted = self.state.privacy.redact_output(&output);
CallToolResult::success(vec![Content::text(redacted)])
}
async fn eval_with_return(
&self,
code: &str,
webview_label: Option<&str>,
) -> Result<String, String> {
self.eval_with_return_timeout(code, webview_label, self.state.eval_timeout)
.await
}
async fn eval_with_return_timeout(
&self,
code: &str,
webview_label: Option<&str>,
timeout: std::time::Duration,
) -> Result<String, String> {
self.track_tool_call();
let id = uuid::Uuid::new_v4().to_string();
let (tx, rx) = tokio::sync::oneshot::channel();
{
let mut pending = self.state.pending_evals.lock().await;
if pending.len() >= MAX_PENDING_EVALS {
return Err(format!(
"too many concurrent eval requests (limit: {MAX_PENDING_EVALS})"
));
}
pending.insert(id.clone(), tx);
}
let code = code.trim();
let needs_return = !code.starts_with("return ")
&& !code.starts_with("return;")
&& !code.starts_with('{')
&& !code.starts_with("if ")
&& !code.starts_with("if(")
&& !code.starts_with("for ")
&& !code.starts_with("for(")
&& !code.starts_with("while ")
&& !code.starts_with("while(")
&& !code.starts_with("switch ")
&& !code.starts_with("try ")
&& !code.starts_with("const ")
&& !code.starts_with("let ")
&& !code.starts_with("var ")
&& !code.starts_with("function ")
&& !code.starts_with("class ")
&& !code.starts_with("throw ");
let code = if needs_return {
format!("return {code}")
} else {
code.to_string()
};
let inject = format!(
r"
(async () => {{
try {{
const __result = await (async () => {{ {code} }})();
await window.__TAURI__.core.invoke('plugin:victauri|victauri_eval_callback', {{
id: '{id}',
result: JSON.stringify(__result)
}});
}} catch (e) {{
await window.__TAURI__.core.invoke('plugin:victauri|victauri_eval_callback', {{
id: '{id}',
result: JSON.stringify({{ __error: e.message }})
}});
}}
}})();
"
);
if let Err(e) = self.bridge.eval_webview(webview_label, &inject) {
self.state.pending_evals.lock().await.remove(&id);
return Err(format!("eval injection failed: {e}"));
}
match tokio::time::timeout(timeout, rx).await {
Ok(Ok(result)) => {
self.check_bridge_version_once();
Ok(result)
}
Ok(Err(_)) => Err("eval callback channel closed".to_string()),
Err(_) => {
self.state.pending_evals.lock().await.remove(&id);
Err(format!("eval timed out after {}s", timeout.as_secs()))
}
}
}
fn check_bridge_version_once(&self) {
if self.bridge_checked.swap(true, Ordering::Relaxed) {
return;
}
let handler = self.clone();
tokio::spawn(async move {
match handler
.eval_with_return_timeout(
"window.__VICTAURI__?.version",
None,
std::time::Duration::from_secs(5),
)
.await
{
Ok(v) => {
let v = v.trim_matches('"');
if v == BRIDGE_VERSION {
tracing::debug!("Bridge version verified: {v}");
} else {
tracing::warn!(
"Bridge version mismatch: Rust expects {BRIDGE_VERSION}, JS reports {v}"
);
}
}
Err(e) => tracing::debug!("Bridge version check skipped: {e}"),
}
});
}
}
const SERVER_INSTRUCTIONS: &str = "Victauri gives you X-ray vision and hands inside a running Tauri application. \
Use compound tools with an 'action' parameter to interact with the app: \
'interact' (click, hover, focus, scroll, select), 'input' (fill, type_text, press_key), \
'window' (get_state, list, manage, resize, move_to, set_title), \
'storage' (get, set, delete, get_cookies), 'navigate' (go_to, go_back, get_history, \
set_dialog_response, get_dialog_log), 'recording' (start, stop, checkpoint, list_checkpoints, \
get_events, events_between, get_replay, export, import), 'inspect' (get_styles, \
get_bounding_boxes, highlight, clear_highlights, audit_accessibility, get_performance), \
'css' (inject, remove), 'logs' (console, network, ipc, navigation, dialogs, events, slow_ipc). \
Standalone tools: eval_js, dom_snapshot, invoke_command, screenshot, verify_state, \
detect_ghost_commands, check_ipc_integrity, wait_for, assert_semantic, resolve_command, \
get_registry, get_memory_stats, get_plugin_info.";
impl ServerHandler for VictauriMcpHandler {
fn get_info(&self) -> ServerInfo {
ServerInfo::new(
ServerCapabilities::builder()
.enable_tools()
.enable_resources()
.enable_resources_subscribe()
.build(),
)
.with_instructions(SERVER_INSTRUCTIONS)
}
async fn list_tools(
&self,
_request: Option<PaginatedRequestParams>,
_context: RequestContext<RoleServer>,
) -> Result<ListToolsResult, ErrorData> {
let all_tools = Self::tool_router().list_all();
let filtered: Vec<Tool> = all_tools
.into_iter()
.filter(|t| self.state.privacy.is_tool_enabled(t.name.as_ref()))
.collect();
Ok(ListToolsResult {
tools: filtered,
..Default::default()
})
}
async fn call_tool(
&self,
request: CallToolRequestParams,
context: RequestContext<RoleServer>,
) -> Result<CallToolResult, ErrorData> {
let tool_name: String = request.name.as_ref().to_owned();
if !self.state.privacy.is_tool_enabled(&tool_name) {
tracing::debug!(tool = %tool_name, "tool call blocked by privacy config");
return Ok(tool_disabled(&tool_name));
}
self.state
.tool_invocations
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let start = std::time::Instant::now();
tracing::debug!(tool = %tool_name, "tool invocation started");
let ctx = ToolCallContext::new(self, request, context);
let result = Self::tool_router().call(ctx).await;
let elapsed = start.elapsed();
tracing::debug!(
tool = %tool_name,
elapsed_ms = elapsed.as_millis() as u64,
is_error = result.as_ref().map_or(true, |r| r.is_error.unwrap_or(false)),
"tool invocation completed"
);
result
}
fn get_tool(&self, name: &str) -> Option<Tool> {
if !self.state.privacy.is_tool_enabled(name) {
return None;
}
Self::tool_router().get(name).cloned()
}
async fn list_resources(
&self,
_request: Option<PaginatedRequestParams>,
_context: RequestContext<RoleServer>,
) -> Result<ListResourcesResult, ErrorData> {
Ok(ListResourcesResult {
resources: vec![
RawResource::new(RESOURCE_URI_IPC_LOG, "ipc-log")
.with_description(
"Live IPC call log — all commands invoked between frontend and backend",
)
.with_mime_type("application/json")
.no_annotation(),
RawResource::new(RESOURCE_URI_WINDOWS, "windows")
.with_description(
"Current state of all Tauri windows — position, size, visibility, focus",
)
.with_mime_type("application/json")
.no_annotation(),
RawResource::new(RESOURCE_URI_STATE, "state")
.with_description(
"Victauri plugin state — event count, registered commands, memory stats",
)
.with_mime_type("application/json")
.no_annotation(),
],
..Default::default()
})
}
async fn read_resource(
&self,
request: ReadResourceRequestParams,
_context: RequestContext<RoleServer>,
) -> Result<ReadResourceResult, ErrorData> {
let uri = &request.uri;
let json = match uri.as_str() {
RESOURCE_URI_IPC_LOG => {
if let Ok(json) = self
.eval_with_return("return window.__VICTAURI__?.getIpcLog()", None)
.await
{
json
} else {
let calls = self.state.event_log.ipc_calls();
serde_json::to_string_pretty(&calls)
.map_err(|e| ErrorData::internal_error(e.to_string(), None))?
}
}
RESOURCE_URI_WINDOWS => {
let states = self.bridge.get_window_states(None);
serde_json::to_string_pretty(&states)
.map_err(|e| ErrorData::internal_error(e.to_string(), None))?
}
RESOURCE_URI_STATE => {
let state_json = serde_json::json!({
"events_captured": self.state.event_log.len(),
"commands_registered": self.state.registry.count(),
"memory": crate::memory::current_stats(),
"port": self.state.port.load(Ordering::Relaxed),
});
serde_json::to_string_pretty(&state_json)
.map_err(|e| ErrorData::internal_error(e.to_string(), None))?
}
_ => {
return Err(ErrorData::resource_not_found(
format!("unknown resource: {uri}"),
None,
));
}
};
Ok(ReadResourceResult::new(vec![ResourceContents::text(
json, uri,
)]))
}
async fn subscribe(
&self,
request: SubscribeRequestParams,
_context: RequestContext<RoleServer>,
) -> Result<(), ErrorData> {
let uri = &request.uri;
match uri.as_str() {
RESOURCE_URI_IPC_LOG | RESOURCE_URI_WINDOWS | RESOURCE_URI_STATE => {
self.subscriptions.lock().await.insert(uri.clone());
tracing::info!("Client subscribed to resource: {uri}");
Ok(())
}
_ => Err(ErrorData::resource_not_found(
format!("unknown resource: {uri}"),
None,
)),
}
}
async fn unsubscribe(
&self,
request: UnsubscribeRequestParams,
_context: RequestContext<RoleServer>,
) -> Result<(), ErrorData> {
self.subscriptions.lock().await.remove(&request.uri);
tracing::info!("Client unsubscribed from resource: {}", request.uri);
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn js_string_simple() {
assert_eq!(js_string("hello"), "\"hello\"");
}
#[test]
fn js_string_single_quotes() {
let result = js_string("it's a test");
assert!(result.contains("it's a test"));
}
#[test]
fn js_string_double_quotes() {
let result = js_string(r#"say "hello""#);
assert!(result.contains(r#"\""#));
}
#[test]
fn js_string_backslashes() {
let result = js_string(r"path\to\file");
assert!(result.contains(r"\\"));
}
#[test]
fn js_string_newlines_and_tabs() {
let result = js_string("line1\nline2\ttab");
assert!(result.contains(r"\n"));
assert!(result.contains(r"\t"));
assert!(!result.contains('\n'));
}
#[test]
fn js_string_null_bytes() {
let input = String::from_utf8(b"before\x00after".to_vec()).unwrap();
let result = js_string(&input);
assert!(result.contains("\\u0000"));
assert!(!result.contains('\0'));
}
#[test]
fn js_string_template_literal_injection() {
let result = js_string("`${alert(1)}`");
assert!(result.starts_with('"'));
assert!(result.ends_with('"'));
}
#[test]
fn js_string_unicode_separators() {
let result = js_string("a\u{2028}b\u{2029}c");
let decoded: String = serde_json::from_str(&result).unwrap();
assert_eq!(decoded, "a\u{2028}b\u{2029}c");
}
#[test]
fn js_string_empty() {
assert_eq!(js_string(""), "\"\"");
}
#[test]
fn js_string_html_script_close() {
let result = js_string("</script><img onerror=alert(1)>");
assert!(result.starts_with('"'));
let decoded: String = serde_json::from_str(&result).unwrap();
assert_eq!(decoded, "</script><img onerror=alert(1)>");
}
#[test]
fn js_string_very_long() {
let long = "a".repeat(100_000);
let result = js_string(&long);
assert!(result.len() >= 100_002); }
#[test]
fn url_allows_http() {
assert!(validate_url("http://example.com").is_ok());
}
#[test]
fn url_allows_https() {
assert!(validate_url("https://example.com/path?q=1").is_ok());
}
#[test]
fn url_allows_file() {
assert!(validate_url("file:///tmp/test.html").is_ok());
}
#[test]
fn url_blocks_javascript() {
assert!(validate_url("javascript:alert(1)").is_err());
}
#[test]
fn url_blocks_javascript_case_insensitive() {
assert!(validate_url("JAVASCRIPT:alert(1)").is_err());
}
#[test]
fn url_blocks_data_scheme() {
assert!(validate_url("data:text/html,<script>alert(1)</script>").is_err());
}
#[test]
fn url_blocks_vbscript() {
assert!(validate_url("vbscript:MsgBox(1)").is_err());
}
#[test]
fn url_rejects_invalid() {
assert!(validate_url("not a url at all").is_err());
}
#[test]
fn url_strips_control_chars() {
let input = format!("http://example{}com", '\0');
assert!(validate_url(&input).is_ok());
}
#[test]
fn css_color_valid_hex() {
assert_eq!(sanitize_css_color("#ff0000").unwrap(), "#ff0000");
assert_eq!(sanitize_css_color("#FFF").unwrap(), "#FFF");
assert_eq!(sanitize_css_color("#12345678").unwrap(), "#12345678");
}
#[test]
fn css_color_valid_rgb() {
assert_eq!(
sanitize_css_color("rgb(255, 0, 0)").unwrap(),
"rgb(255, 0, 0)"
);
assert_eq!(
sanitize_css_color("rgba(0, 0, 0, 0.5)").unwrap(),
"rgba(0, 0, 0, 0.5)"
);
}
#[test]
fn css_color_valid_named() {
assert_eq!(sanitize_css_color("red").unwrap(), "red");
assert_eq!(sanitize_css_color("transparent").unwrap(), "transparent");
}
#[test]
fn css_color_valid_hsl() {
assert_eq!(
sanitize_css_color("hsl(120, 50%, 50%)").unwrap(),
"hsl(120, 50%, 50%)"
);
}
#[test]
fn css_color_rejects_too_long() {
let long = "a".repeat(101);
assert!(sanitize_css_color(&long).is_err());
}
#[test]
fn css_color_rejects_backslash_escapes() {
assert!(sanitize_css_color(r"red\00").is_err());
assert!(sanitize_css_color(r"\72\65\64").is_err());
}
#[test]
fn css_color_rejects_url_injection() {
assert!(sanitize_css_color("url(http://evil.com)").is_err());
assert!(sanitize_css_color("URL(http://evil.com)").is_err());
}
#[test]
fn css_color_rejects_expression_injection() {
assert!(sanitize_css_color("expression(alert(1))").is_err());
assert!(sanitize_css_color("EXPRESSION(alert(1))").is_err());
}
#[test]
fn css_color_rejects_import() {
assert!(sanitize_css_color("@import url(evil.css)").is_err());
}
#[test]
fn css_color_rejects_semicolons_and_braces() {
assert!(sanitize_css_color("red; background: url(evil)").is_err());
assert!(sanitize_css_color("red} body { color: blue").is_err());
}
#[test]
fn css_color_rejects_special_chars() {
assert!(sanitize_css_color("red<script>").is_err());
assert!(sanitize_css_color("red\"onload=alert").is_err());
assert!(sanitize_css_color("red'onclick=alert").is_err());
}
#[test]
fn css_color_trims_whitespace() {
assert_eq!(sanitize_css_color(" red ").unwrap(), "red");
}
#[test]
fn css_color_empty_string() {
assert_eq!(sanitize_css_color("").unwrap(), "");
}
}