use std::sync::Arc;
use rmcp::ErrorData;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use zendriver::{ClickOptions, Key, KeySequence, MouseButton, SpecialKey};
use crate::errors::{McpServerError, map_error};
use crate::selectors::Selector;
use crate::snapshot::html_trim;
use crate::state::SessionState;
use crate::tools::common::{ModifierArg, current_tab, modifiers_to_bits};
use crate::tools::find::resolve;
/// Response of an element action that can optionally carry a page snapshot.
///
/// Returned by `click`, `hover`, `type_text`, `press`, `key_sequence`,
/// `set_value` and `clear`.
#[derive(Debug, Serialize, JsonSchema)]
pub struct ActionOutput {
/// Success flag; the helpers in this file only ever construct it as `true`
/// (failures are reported as errors instead).
pub ok: bool,
/// Trimmed HTML of the page, present only when the caller asked for a
/// snapshot. Left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub snapshot: Option<String>,
}
/// Minimal acknowledgement returned by actions that never include a snapshot
/// (`focus`, `scroll_into_view`, `upload`).
#[derive(Debug, Serialize, JsonSchema)]
pub struct AckOutput {
/// Success flag; constructed as `true` by every function in this file that
/// returns this type.
pub ok: bool,
}
/// Captures the current markup of `tab` as a trimmed HTML string.
///
/// Evaluates `document.documentElement.outerHTML` through
/// `Tab::evaluate_main` and runs the result through `html_trim::trim`.
///
/// # Errors
/// An evaluation failure is converted with `McpServerError::from` and
/// `map_error` and returned.
async fn snapshot_now(tab: &zendriver::Tab) -> Result<String, ErrorData> {
    let evaluated: Result<String, _> = tab
        .evaluate_main("document.documentElement.outerHTML")
        .await;
    match evaluated {
        Ok(markup) => Ok(html_trim::trim(&markup)),
        Err(cause) => Err(map_error(McpServerError::from(cause))),
    }
}
/// Builds the success response for an action, attaching a snapshot on request.
///
/// When `return_snapshot` is `false` no page evaluation happens and the
/// `snapshot` field stays `None`.
///
/// # Errors
/// Propagates any error from `snapshot_now`.
async fn ok_with_snapshot(
    tab: &zendriver::Tab,
    return_snapshot: bool,
) -> Result<ActionOutput, ErrorData> {
    let mut output = ActionOutput {
        ok: true,
        snapshot: None,
    };
    if return_snapshot {
        output.snapshot = Some(snapshot_now(tab).await?);
    }
    Ok(output)
}
/// Mouse button selectable in tool input.
///
/// Variant names are (de)serialized in lowercase (`left`, `middle`, `right`).
#[derive(Debug, Clone, Copy, Default, Deserialize, Serialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum MouseButtonArg {
/// The default button.
#[default]
Left,
Middle,
Right,
}
impl From<MouseButtonArg> for MouseButton {
fn from(b: MouseButtonArg) -> Self {
match b {
MouseButtonArg::Left => MouseButton::Left,
MouseButtonArg::Middle => MouseButton::Middle,
MouseButtonArg::Right => MouseButton::Right,
}
}
}
pub fn parse_key(s: &str) -> Result<Key, ErrorData> {
let mut chars = s.chars();
if let (Some(c), None) = (chars.next(), chars.next()) {
return Ok(Key::Char(c));
}
let special = match s.to_ascii_lowercase().as_str() {
"enter" | "return" => SpecialKey::Enter,
"tab" => SpecialKey::Tab,
"escape" | "esc" => SpecialKey::Escape,
"backspace" => SpecialKey::Backspace,
"delete" | "del" => SpecialKey::Delete,
"space" => SpecialKey::Space,
"arrowup" | "up" => SpecialKey::ArrowUp,
"arrowdown" | "down" => SpecialKey::ArrowDown,
"arrowleft" | "left" => SpecialKey::ArrowLeft,
"arrowright" | "right" => SpecialKey::ArrowRight,
"home" => SpecialKey::Home,
"end" => SpecialKey::End,
"pageup" => SpecialKey::PageUp,
"pagedown" => SpecialKey::PageDown,
"insert" | "ins" => SpecialKey::Insert,
"capslock" => SpecialKey::CapsLock,
"numlock" => SpecialKey::NumLock,
"scrolllock" => SpecialKey::ScrollLock,
"printscreen" => SpecialKey::PrintScreen,
"pause" => SpecialKey::Pause,
"contextmenu" => SpecialKey::ContextMenu,
"f1" => SpecialKey::F1,
"f2" => SpecialKey::F2,
"f3" => SpecialKey::F3,
"f4" => SpecialKey::F4,
"f5" => SpecialKey::F5,
"f6" => SpecialKey::F6,
"f7" => SpecialKey::F7,
"f8" => SpecialKey::F8,
"f9" => SpecialKey::F9,
"f10" => SpecialKey::F10,
"f11" => SpecialKey::F11,
"f12" => SpecialKey::F12,
_ => {
return Err(ErrorData::invalid_params(
format!(
"Unknown key `{s}`. Accepted special keys: Enter, Tab, Escape, Backspace, Delete, Space, ArrowUp, ArrowDown, ArrowLeft, ArrowRight, Home, End, PageUp, PageDown, Insert, CapsLock, NumLock, ScrollLock, PrintScreen, Pause, ContextMenu, F1..F12. Single characters (e.g. `a`, `?`) are typed as `Key::Char`."
),
None,
));
}
};
Ok(Key::Special(special))
}
/// Input of the `click` tool. Unknown fields are rejected.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct ClickInput {
/// Element to click; the selector's fields are flattened into this object.
#[serde(flatten)]
pub selector: Selector,
/// Mouse button to use. When omitted, `click` falls back to the default
/// button (`left`).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub button: Option<MouseButtonArg>,
/// Click count handed to the click options. When omitted, `click` uses 1.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub click_count: Option<u32>,
/// When `true`, the response includes a trimmed HTML snapshot taken after
/// the click. Defaults to `false`.
#[serde(default)]
pub return_snapshot: bool,
}
pub async fn click(
state: Arc<Mutex<SessionState>>,
input: ClickInput,
) -> Result<ActionOutput, ErrorData> {
let s = state.lock().await;
let tab = current_tab(&s).await?;
let el = resolve(&tab, &input.selector).await?;
let opts = ClickOptions {
button: input.button.unwrap_or_default().into(),
click_count: input.click_count.unwrap_or(1),
..Default::default()
};
el.click_with(opts)
.await
.map_err(|e| map_error(McpServerError::from(e)))?;
ok_with_snapshot(&tab, input.return_snapshot).await
}
/// Input of the `hover` tool. Unknown fields are rejected.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct HoverInput {
/// Element to hover; the selector's fields are flattened into this object.
#[serde(flatten)]
pub selector: Selector,
/// When `true`, the response includes a trimmed HTML snapshot taken after
/// the hover. Defaults to `false`.
#[serde(default)]
pub return_snapshot: bool,
}
pub async fn hover(
state: Arc<Mutex<SessionState>>,
input: HoverInput,
) -> Result<ActionOutput, ErrorData> {
let s = state.lock().await;
let tab = current_tab(&s).await?;
let el = resolve(&tab, &input.selector).await?;
el.hover()
.await
.map_err(|e| map_error(McpServerError::from(e)))?;
ok_with_snapshot(&tab, input.return_snapshot).await
}
/// Input of the `type_text` tool. Unknown fields are rejected.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct TypeInput {
/// Element to type into; the selector's fields are flattened into this object.
#[serde(flatten)]
pub selector: Selector,
/// Text passed to the element's `type_text`. Required.
pub text: String,
/// When `true`, the element's `clear` is called before typing.
/// Defaults to `false`.
#[serde(default)]
pub clear_first: bool,
/// When `true`, the response includes a trimmed HTML snapshot taken after
/// typing. Defaults to `false`.
#[serde(default)]
pub return_snapshot: bool,
}
pub async fn type_text(
state: Arc<Mutex<SessionState>>,
input: TypeInput,
) -> Result<ActionOutput, ErrorData> {
let s = state.lock().await;
let tab = current_tab(&s).await?;
let el = resolve(&tab, &input.selector).await?;
if input.clear_first {
el.clear()
.await
.map_err(|e| map_error(McpServerError::from(e)))?;
}
el.type_text(&input.text)
.await
.map_err(|e| map_error(McpServerError::from(e)))?;
ok_with_snapshot(&tab, input.return_snapshot).await
}
/// Input of the `press` tool. Unknown fields are rejected.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct PressInput {
/// Element that receives the key press; the selector's fields are flattened
/// into this object.
#[serde(flatten)]
pub selector: Selector,
/// Key to press, interpreted by `parse_key`: either a single character or a
/// special key name such as `Enter` (matched ASCII case-insensitively).
pub key: String,
/// Modifier keys to combine with `key`. Defaults to an empty list.
#[serde(default)]
pub modifiers: Vec<ModifierArg>,
/// When `true`, the response includes a trimmed HTML snapshot taken after
/// the key press. Defaults to `false`.
#[serde(default)]
pub return_snapshot: bool,
}
pub async fn press(
state: Arc<Mutex<SessionState>>,
input: PressInput,
) -> Result<ActionOutput, ErrorData> {
let key = parse_key(&input.key)?;
let mods = modifiers_to_bits(&input.modifiers);
let s = state.lock().await;
let tab = current_tab(&s).await?;
let el = resolve(&tab, &input.selector).await?;
if mods.is_empty() {
el.press(key)
.await
.map_err(|e| map_error(McpServerError::from(e)))?;
} else {
el.press_with(key, mods)
.await
.map_err(|e| map_error(McpServerError::from(e)))?;
}
ok_with_snapshot(&tab, input.return_snapshot).await
}
/// One step of a key sequence. Unknown fields are rejected.
///
/// `key_sequence` requires exactly one of `text` or `key` to be set and
/// rejects steps that set both or neither.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct KeyStep {
/// Literal text to add to the sequence.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
/// Key to press, interpreted by `parse_key` (single character or special
/// key name).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub key: Option<String>,
/// Modifier keys combined with `key`. Defaults to an empty list.
/// NOTE(review): `key_sequence` only reads this for `key` steps; modifiers
/// given alongside `text` are not used.
#[serde(default)]
pub modifiers: Vec<ModifierArg>,
}
/// Input of the `key_sequence` tool. Unknown fields are rejected.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct KeySequenceInput {
/// Element that receives the sequence; the selector's fields are flattened
/// into this object.
#[serde(flatten)]
pub selector: Selector,
/// Steps appended to the key sequence in the order given. Required.
pub sequence: Vec<KeyStep>,
/// When `true`, the response includes a trimmed HTML snapshot taken after
/// the sequence was sent. Defaults to `false`.
#[serde(default)]
pub return_snapshot: bool,
}
pub async fn key_sequence(
state: Arc<Mutex<SessionState>>,
input: KeySequenceInput,
) -> Result<ActionOutput, ErrorData> {
let mut seq = KeySequence::new();
for step in &input.sequence {
match (&step.text, &step.key) {
(Some(text), None) => seq = seq.text(text.clone()),
(None, Some(key_str)) => {
let key = parse_key(key_str)?;
let mods = modifiers_to_bits(&step.modifiers);
if mods.is_empty() {
match key {
Key::Special(sk) => seq = seq.key(sk),
Key::Char(_) => seq = seq.chord(key, mods),
}
} else {
seq = seq.chord(key, mods);
}
}
_ => {
return Err(ErrorData::invalid_params(
"each sequence step must set exactly one of `text` or `key`".to_string(),
None,
));
}
}
}
let s = state.lock().await;
let tab = current_tab(&s).await?;
let el = resolve(&tab, &input.selector).await?;
el.type_keys(seq)
.await
.map_err(|e| map_error(McpServerError::from(e)))?;
ok_with_snapshot(&tab, input.return_snapshot).await
}
/// Selects which element setter `set_value` calls.
///
/// Variant names are (de)serialized in snake case (`value`, `text`).
#[derive(Debug, Clone, Copy, Default, Deserialize, Serialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SetValueMode {
/// Default: `set_value` calls the element's `set_value`.
#[default]
Value,
/// `set_value` calls the element's `set_text`.
Text,
}
/// Input of the `set_value` tool. Unknown fields are rejected.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct SetValueInput {
/// Element to modify; the selector's fields are flattened into this object.
#[serde(flatten)]
pub selector: Selector,
/// String handed to the setter chosen by `mode`. Required.
pub value: String,
/// Which setter to use. Defaults to `value`.
#[serde(default)]
pub mode: SetValueMode,
/// When `true`, the response includes a trimmed HTML snapshot taken after
/// the value was set. Defaults to `false`.
#[serde(default)]
pub return_snapshot: bool,
}
pub async fn set_value(
state: Arc<Mutex<SessionState>>,
input: SetValueInput,
) -> Result<ActionOutput, ErrorData> {
let s = state.lock().await;
let tab = current_tab(&s).await?;
let el = resolve(&tab, &input.selector).await?;
match input.mode {
SetValueMode::Value => el.set_value(&input.value).await,
SetValueMode::Text => el.set_text(&input.value).await,
}
.map_err(|e| map_error(McpServerError::from(e)))?;
ok_with_snapshot(&tab, input.return_snapshot).await
}
/// Selects how `clear` empties an element.
///
/// Variant names are (de)serialized in snake case (`value`, `backspace`).
#[derive(Debug, Clone, Copy, Default, Deserialize, Serialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ClearMode {
/// Default: `clear` calls the element's `clear`.
#[default]
Value,
/// `clear` calls the element's `clear_by_deleting`.
Backspace,
}
/// Input of the `clear` tool. Unknown fields are rejected.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct ClearInput {
/// Element to clear; the selector's fields are flattened into this object.
#[serde(flatten)]
pub selector: Selector,
/// Clearing strategy. Defaults to `value`.
#[serde(default)]
pub mode: ClearMode,
/// When `true`, the response includes a trimmed HTML snapshot taken after
/// clearing. Defaults to `false`.
#[serde(default)]
pub return_snapshot: bool,
}
pub async fn clear(
state: Arc<Mutex<SessionState>>,
input: ClearInput,
) -> Result<ActionOutput, ErrorData> {
let s = state.lock().await;
let tab = current_tab(&s).await?;
let el = resolve(&tab, &input.selector).await?;
match input.mode {
ClearMode::Value => el.clear().await,
ClearMode::Backspace => el.clear_by_deleting().await,
}
.map_err(|e| map_error(McpServerError::from(e)))?;
ok_with_snapshot(&tab, input.return_snapshot).await
}
/// Input of the `focus` tool. Unknown fields are rejected.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct FocusInput {
/// Element to focus; the selector's fields are flattened into this object.
#[serde(flatten)]
pub selector: Selector,
}
pub async fn focus(
state: Arc<Mutex<SessionState>>,
input: FocusInput,
) -> Result<AckOutput, ErrorData> {
let s = state.lock().await;
let tab = current_tab(&s).await?;
let el = resolve(&tab, &input.selector).await?;
el.focus()
.await
.map_err(|e| map_error(McpServerError::from(e)))?;
Ok(AckOutput { ok: true })
}
/// Input of the `scroll_into_view` tool. Unknown fields are rejected.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct ScrollInput {
/// Element to scroll into view; the selector's fields are flattened into
/// this object.
#[serde(flatten)]
pub selector: Selector,
}
pub async fn scroll_into_view(
state: Arc<Mutex<SessionState>>,
input: ScrollInput,
) -> Result<AckOutput, ErrorData> {
let s = state.lock().await;
let tab = current_tab(&s).await?;
let el = resolve(&tab, &input.selector).await?;
el.scroll_into_view()
.await
.map_err(|e| map_error(McpServerError::from(e)))?;
Ok(AckOutput { ok: true })
}
/// Input of the `upload` tool. Unknown fields are rejected.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct UploadInput {
/// Element that receives the files; the selector's fields are flattened
/// into this object.
#[serde(flatten)]
pub selector: Selector,
/// Paths handed unchanged to the element's `upload_files`. Required.
pub paths: Vec<String>,
}
pub async fn upload(
state: Arc<Mutex<SessionState>>,
input: UploadInput,
) -> Result<AckOutput, ErrorData> {
let s = state.lock().await;
let tab = current_tab(&s).await?;
let el = resolve(&tab, &input.selector).await?;
el.upload_files(&input.paths)
.await
.map_err(|e| map_error(McpServerError::from(e)))?;
Ok(AckOutput { ok: true })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Panic message used when an action unexpectedly succeeds.
    const NO_BROWSER: &str = "must error without an open browser";

    /// Wraps a freshly constructed `SessionState` the way the tool functions
    /// expect it. The tests below rely on such a state having no open browser.
    fn empty_session() -> Arc<Mutex<SessionState>> {
        let state = SessionState::new();
        Arc::new(Mutex::new(state))
    }

    /// Builds a selector that only sets the CSS query, with the same fixed
    /// values for every other field.
    fn css_selector(query: &str) -> Selector {
        Selector {
            css: Some(query.into()),
            xpath: None,
            text: None,
            text_exact: None,
            text_regex: None,
            role: None,
            role_name: None,
            tag: None,
            attrs: Vec::new(),
            nth: None,
            visible_only: true,
            timeout_ms: 5000,
            frame_id: None,
        }
    }

    /// Asserts that `outcome` is an error pointing the caller at
    /// `browser_open`, both in the message and in the structured data.
    fn expect_browser_open_hint<T: std::fmt::Debug>(outcome: Result<T, ErrorData>) {
        let err = outcome.expect_err(NO_BROWSER);
        assert!(err.message.contains("browser_open"), "msg: {}", err.message);
        let data = err.data.as_ref().expect("data populated");
        assert_eq!(data["suggested_next"], "browser_open");
    }

    #[tokio::test]
    async fn click_with_no_browser_suggests_browser_open() {
        let input = ClickInput {
            selector: css_selector("button"),
            button: None,
            click_count: None,
            return_snapshot: false,
        };
        expect_browser_open_hint(click(empty_session(), input).await);
    }

    #[tokio::test]
    async fn hover_with_no_browser_suggests_browser_open() {
        let input = HoverInput {
            selector: css_selector("a"),
            return_snapshot: false,
        };
        expect_browser_open_hint(hover(empty_session(), input).await);
    }

    #[tokio::test]
    async fn type_with_no_browser_suggests_browser_open() {
        let input = TypeInput {
            selector: css_selector("input"),
            text: "hello".into(),
            clear_first: false,
            return_snapshot: false,
        };
        expect_browser_open_hint(type_text(empty_session(), input).await);
    }

    #[tokio::test]
    async fn press_with_no_browser_suggests_browser_open() {
        let input = PressInput {
            selector: css_selector("input"),
            key: "Enter".into(),
            modifiers: Vec::new(),
            return_snapshot: false,
        };
        expect_browser_open_hint(press(empty_session(), input).await);
    }

    #[tokio::test]
    async fn set_value_with_no_browser_suggests_browser_open() {
        let input = SetValueInput {
            selector: css_selector("input"),
            value: "rust async".into(),
            mode: SetValueMode::Value,
            return_snapshot: false,
        };
        expect_browser_open_hint(set_value(empty_session(), input).await);
    }

    #[tokio::test]
    async fn clear_with_no_browser_suggests_browser_open() {
        let input = ClearInput {
            selector: css_selector("input"),
            mode: ClearMode::Value,
            return_snapshot: false,
        };
        expect_browser_open_hint(clear(empty_session(), input).await);
    }

    #[tokio::test]
    async fn focus_with_no_browser_suggests_browser_open() {
        let input = FocusInput {
            selector: css_selector("input"),
        };
        expect_browser_open_hint(focus(empty_session(), input).await);
    }

    #[tokio::test]
    async fn scroll_into_view_with_no_browser_suggests_browser_open() {
        let input = ScrollInput {
            selector: css_selector("footer"),
        };
        expect_browser_open_hint(scroll_into_view(empty_session(), input).await);
    }

    #[tokio::test]
    async fn upload_with_no_browser_suggests_browser_open() {
        let input = UploadInput {
            selector: css_selector("input[type=file]"),
            paths: vec!["/tmp/a.txt".into()],
        };
        expect_browser_open_hint(upload(empty_session(), input).await);
    }

    #[test]
    fn parse_key_accepts_canonical_special_names() {
        let cases = [
            ("Enter", SpecialKey::Enter),
            ("Tab", SpecialKey::Tab),
            ("Backspace", SpecialKey::Backspace),
            ("ArrowUp", SpecialKey::ArrowUp),
            ("F5", SpecialKey::F5),
        ];
        for (name, expected) in cases {
            assert_eq!(parse_key(name).unwrap(), Key::Special(expected));
        }
    }

    #[test]
    fn parse_key_is_case_insensitive_for_special_names() {
        let cases = [
            ("enter", SpecialKey::Enter),
            ("ENTER", SpecialKey::Enter),
            ("arrowdown", SpecialKey::ArrowDown),
            ("ESC", SpecialKey::Escape),
        ];
        for (name, expected) in cases {
            assert_eq!(parse_key(name).unwrap(), Key::Special(expected));
        }
    }

    #[test]
    fn parse_key_treats_single_char_as_char_variant() {
        // Includes a non-ASCII char that is one `char` but two bytes in UTF-8.
        for ch in ['a', '?', 'é'] {
            assert_eq!(parse_key(&ch.to_string()).unwrap(), Key::Char(ch));
        }
    }

    #[test]
    fn parse_key_rejects_unknown_special_with_accepted_list() {
        let err = parse_key("Zomg").expect_err("unknown key must error");
        for fragment in ["Unknown key `Zomg`", "Enter", "Backspace", "F1..F12"] {
            assert!(err.message.contains(fragment), "msg: {}", err.message);
        }
    }

    #[test]
    fn parse_key_accepts_arrow_aliases_and_return_alias() {
        let cases = [
            ("return", SpecialKey::Enter),
            ("up", SpecialKey::ArrowUp),
            ("down", SpecialKey::ArrowDown),
            ("left", SpecialKey::ArrowLeft),
            ("right", SpecialKey::ArrowRight),
        ];
        for (alias, expected) in cases {
            assert_eq!(parse_key(alias).unwrap(), Key::Special(expected));
        }
    }

    #[test]
    fn mouse_button_arg_maps_to_zendriver_variants() {
        let cases = [
            (MouseButtonArg::Left, MouseButton::Left),
            (MouseButtonArg::Middle, MouseButton::Middle),
            (MouseButtonArg::Right, MouseButton::Right),
        ];
        for (arg, expected) in cases {
            assert_eq!(MouseButton::from(arg), expected);
        }
    }
}