use std::collections::BTreeMap;
use std::sync::Arc;
use std::time::Duration;
use rmcp::ErrorData;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use zendriver::query::FindAllBuilder;
use zendriver::{AriaRole, Element, FindBuilder, Tab};
use crate::errors::{McpServerError, map_error};
use crate::selectors::{AttrOp, Selector};
use crate::state::SessionState;
use crate::tools::common::{current_tab, lookup_frame};
/// Resolves `sel` to a single element on `tab`.
///
/// The selector is validated first. When it carries a `frame_id`, the frame is
/// looked up with `lookup_frame` and the query is scoped to it via `in_frame`.
///
/// # Errors
///
/// * `invalid_params` when `sel.validate()` rejects the selector.
/// * Whatever `lookup_frame` or the selector translation returns on failure.
/// * The query error, converted through `McpServerError` and `map_error`,
///   when `.one()` fails.
pub async fn resolve(tab: &Tab, sel: &Selector) -> Result<Element, ErrorData> {
    if let Err(e) = sel.validate() {
        return Err(ErrorData::invalid_params(e.to_string(), None));
    }

    // Look the frame up before building the query so a bad `frame_id` fails early.
    let frame = match sel.frame_id.as_deref() {
        Some(fid) => Some(lookup_frame(tab, fid).await?),
        None => None,
    };

    let query = apply_modifiers(apply_selector(tab.find(), sel)?, sel);
    let query = match frame.as_ref() {
        Some(f) => query.in_frame(f),
        None => query,
    };

    query
        .one()
        .await
        .map_err(|e| map_error(McpServerError::from(e)))
}
/// Resolves `sel` to every matching element on `tab`, keeping at most `limit`.
///
/// The selector is validated first. When it carries a `frame_id`, the frame is
/// looked up with `lookup_frame` and the query is scoped to it via `in_frame`.
/// The cap is applied after the query returns, by truncating the result.
///
/// # Errors
///
/// * `invalid_params` when `sel.validate()` rejects the selector.
/// * Whatever `lookup_frame` or the selector translation returns on failure.
/// * The query error, converted through `McpServerError` and `map_error`,
///   when `.many_or_empty()` fails.
pub async fn resolve_all(
    tab: &Tab,
    sel: &Selector,
    limit: usize,
) -> Result<Vec<Element>, ErrorData> {
    if let Err(e) = sel.validate() {
        return Err(ErrorData::invalid_params(e.to_string(), None));
    }

    // Look the frame up before building the query so a bad `frame_id` fails early.
    let frame = match sel.frame_id.as_deref() {
        Some(fid) => Some(lookup_frame(tab, fid).await?),
        None => None,
    };

    let query = apply_modifiers_all(apply_selector_all(tab.find_all(), sel)?, sel);
    let query = match frame.as_ref() {
        Some(f) => query.in_frame(f),
        None => query,
    };

    let mut matches = query
        .many_or_empty()
        .await
        .map_err(|e| map_error(McpServerError::from(e)))?;
    // `Vec::truncate` has no effect when `limit` is not below the current length.
    matches.truncate(limit);
    Ok(matches)
}
// Selector-to-builder translation shared by `apply_selector` and
// `apply_selector_all`. Expands to an expression of type `Result<_, ErrorData>`;
// the `?` operators inside propagate out of the *calling* function.
//
// Dispatch:
// * `tag` set or `attrs` non-empty: structured query made of the tag, every
//   attribute predicate, then the optional text filters. `css`, `xpath` and
//   `role` are not consulted on this path.
// * otherwise: the first populated field among `css`, `xpath`, `text`,
//   `text_exact`, `text_regex`, `role`.
macro_rules! apply_selector_to {
    ($builder:expr, $sel:expr) => {{
        let mut query = $builder;
        let selector = $sel;
        let is_structured = selector.tag.is_some() || !selector.attrs.is_empty();
        if is_structured {
            if let Some(tag) = selector.tag.as_deref() {
                query = query.tag(tag);
            }
            for pred in &selector.attrs {
                let attr_name = pred.name.as_str();
                // A missing value becomes ""; `Has` does not look at it at all.
                let attr_value = pred.value.as_deref().unwrap_or("");
                query = match pred.op {
                    AttrOp::Eq => query.attr(attr_name, attr_value),
                    AttrOp::Contains => query.attr_contains(attr_name, attr_value),
                    AttrOp::StartsWith => query.attr_starts_with(attr_name, attr_value),
                    AttrOp::EndsWith => query.attr_ends_with(attr_name, attr_value),
                    AttrOp::Has => query.has_attr(attr_name),
                    AttrOp::Regex => query.attr_regex(attr_name, attr_value),
                };
            }
            // Text fields act as additional filters on the structured query.
            if let Some(needle) = selector.text.as_deref() {
                query = query.containing_text(needle);
            }
            if let Some(exact) = selector.text_exact.as_deref() {
                query = query.text_equals(exact);
            }
            // The pattern is handed over as a string here, whereas the
            // standalone `text_regex` path below compiles it first.
            if let Some(pattern) = selector.text_regex.as_deref() {
                query = query.text_matches(pattern);
            }
            Ok(query)
        } else if let Some(css) = selector.css.as_deref() {
            Ok(query.css(css))
        } else if let Some(xpath) = selector.xpath.as_deref() {
            Ok(query.xpath(xpath))
        } else if let Some(needle) = selector.text.as_deref() {
            Ok(query.text(needle))
        } else if let Some(exact) = selector.text_exact.as_deref() {
            Ok(query.text_exact(exact))
        } else if let Some(pattern) = selector.text_regex.as_deref() {
            let compiled = compile_regex(pattern)?;
            Ok(query.text_regex(compiled))
        } else if let Some(role_str) = selector.role.as_deref() {
            let role = parse_role(role_str)?;
            if let Some(accessible_name) = selector.role_name.as_deref() {
                Ok(query.role_named(role, accessible_name))
            } else {
                Ok(query.role(role))
            }
        } else {
            Err(ErrorData::invalid_params(
                "Selector has no selector kind set (validate() should have caught this)"
                    .to_string(),
                None,
            ))
        }
    }};
}
/// Translates `sel` into calls on a single-element `FindBuilder`.
///
/// The whole body is the `apply_selector_to!` expansion, so the dispatch rules
/// and the error for a selector with no kind set are defined by that macro.
///
/// # Errors
///
/// Returns the `ErrorData` produced inside the macro: an invalid `text_regex`
/// pattern, an unknown ARIA role, or a selector with no kind set.
fn apply_selector<'a>(
builder: FindBuilder<'a>,
sel: &Selector,
) -> Result<FindBuilder<'a>, ErrorData> {
apply_selector_to!(builder, sel)
}
/// Translates `sel` into calls on a multi-element `FindAllBuilder`.
///
/// The whole body is the `apply_selector_to!` expansion, so the dispatch rules
/// and the error for a selector with no kind set are defined by that macro.
///
/// # Errors
///
/// Returns the `ErrorData` produced inside the macro: an invalid `text_regex`
/// pattern, an unknown ARIA role, or a selector with no kind set.
fn apply_selector_all<'a>(
builder: FindAllBuilder<'a>,
sel: &Selector,
) -> Result<FindAllBuilder<'a>, ErrorData> {
apply_selector_to!(builder, sel)
}
/// Applies the selector's non-matching options to a single-element query:
/// the optional `nth` index, the `visible_only` flag, and the timeout taken
/// from `timeout_ms` (milliseconds).
fn apply_modifiers<'a>(builder: FindBuilder<'a>, sel: &Selector) -> FindBuilder<'a> {
    // `nth` is only forwarded when the selector sets it.
    let builder = match sel.nth {
        Some(n) => builder.nth(n),
        None => builder,
    };
    builder
        .visible_only(sel.visible_only)
        .timeout(Duration::from_millis(sel.timeout_ms))
}
/// Applies the selector's non-matching options to a multi-element query:
/// the `visible_only` flag and the timeout taken from `timeout_ms`
/// (milliseconds). `sel.nth` is not forwarded on this path.
fn apply_modifiers_all<'a>(builder: FindAllBuilder<'a>, sel: &Selector) -> FindAllBuilder<'a> {
    builder
        .visible_only(sel.visible_only)
        .timeout(Duration::from_millis(sel.timeout_ms))
}
/// Compiles `pat` into a `regex::Regex`.
///
/// # Errors
///
/// Returns an `invalid_params` error whose message starts with
/// ``Invalid `text_regex` pattern:`` followed by the compiler's diagnostic.
fn compile_regex(pat: &str) -> Result<regex::Regex, ErrorData> {
    match regex::Regex::new(pat) {
        Ok(compiled) => Ok(compiled),
        Err(e) => Err(ErrorData::invalid_params(
            format!("Invalid `text_regex` pattern: {e}"),
            None,
        )),
    }
}
/// Maps a role name to its `AriaRole` variant.
///
/// Matching is exact: the input must equal one of the lowercase names below.
///
/// # Errors
///
/// Returns an `invalid_params` error that echoes the rejected input and lists
/// every accepted role name.
fn parse_role(s: &str) -> Result<AriaRole, ErrorData> {
    let role = match s {
        "button" => AriaRole::Button,
        "link" => AriaRole::Link,
        "textbox" => AriaRole::Textbox,
        "combobox" => AriaRole::Combobox,
        "checkbox" => AriaRole::Checkbox,
        "radio" => AriaRole::Radio,
        "tab" => AriaRole::Tab,
        "menu" => AriaRole::Menu,
        "menuitem" => AriaRole::Menuitem,
        "dialog" => AriaRole::Dialog,
        "heading" => AriaRole::Heading,
        "banner" => AriaRole::Banner,
        "navigation" => AriaRole::Navigation,
        "main" => AriaRole::Main,
        "article" => AriaRole::Article,
        "list" => AriaRole::List,
        "listitem" => AriaRole::Listitem,
        "row" => AriaRole::Row,
        "cell" => AriaRole::Cell,
        "columnheader" => AriaRole::Columnheader,
        "rowheader" => AriaRole::Rowheader,
        other => {
            // Keep the accepted list in sync with the arms above.
            return Err(ErrorData::invalid_params(
                format!(
                    "Unknown ARIA role `{other}`. Accepted: button, link, textbox, combobox, checkbox, radio, tab, menu, menuitem, dialog, heading, banner, navigation, main, article, list, listitem, row, cell, columnheader, rowheader."
                ),
                None,
            ));
        }
    };
    Ok(role)
}
// Serializable rectangle reported for an element; populated from
// `zendriver::BoundingBox` by the `From` impl in this file, field for field.
// NOTE(review): units and coordinate origin are whatever zendriver reports —
// confirm before documenting them for clients.
// Plain `//` comments on purpose: `///` docs on a `JsonSchema` type would be
// emitted as schema descriptions and change the published schema.
#[derive(Debug, Clone, Copy, Serialize, JsonSchema, PartialEq)]
pub struct BoundingBox {
pub x: f64,
pub y: f64,
pub width: f64,
pub height: f64,
}
/// Converts the driver's bounding box into the serializable local type by
/// copying `x`, `y`, `width` and `height` unchanged.
impl From<zendriver::BoundingBox> for BoundingBox {
    fn from(b: zendriver::BoundingBox) -> Self {
        let zendriver::BoundingBox {
            x,
            y,
            width,
            height,
            ..
        } = b;
        Self {
            x,
            y,
            width,
            height,
        }
    }
}
// Snapshot of one matched element, as built by `describe`.
// Plain `//` comments on purpose: `///` docs on a `JsonSchema` type would be
// emitted as schema descriptions and change the published schema.
#[derive(Debug, Clone, Serialize, JsonSchema)]
pub struct ElementDescriptor {
// Lowercased tag name; omitted from the JSON when it could not be read.
#[serde(skip_serializing_if = "Option::is_none")]
pub tag: Option<String>,
// At most the first 200 characters of the element's inner text.
pub text_snippet: String,
// Element attributes, ordered by name (BTreeMap).
pub attrs: BTreeMap<String, String>,
// Omitted from the JSON when no box was available.
#[serde(skip_serializing_if = "Option::is_none")]
pub bounding_box: Option<BoundingBox>,
pub visible: bool,
pub enabled: bool,
}
/// Builds an [`ElementDescriptor`] snapshot of `el`.
///
/// Tag name, inner text and bounding box are best-effort: a failure there
/// yields `None`, an empty snippet, or `None` respectively. The text snippet
/// is capped at 200 characters.
///
/// # Errors
///
/// Fails when reading the attributes, the visibility, or the enabled state
/// fails; the driver error is converted through `McpServerError` and
/// `map_error`.
pub async fn describe(el: &Element) -> Result<ElementDescriptor, ErrorData> {
    // Best-effort: a failed evaluation just leaves the tag out.
    let tag = el
        .evaluate::<String>("el.tagName.toLowerCase()")
        .await
        .ok();

    // Best-effort: unreadable text becomes an empty snippet.
    let full_text = el.inner_text().await.unwrap_or_default();
    let text_snippet: String = full_text.chars().take(200).collect();

    let raw_attrs = el
        .attrs()
        .await
        .map_err(|e| map_error(McpServerError::from(e)))?;
    let attrs = raw_attrs
        .into_iter()
        .collect::<BTreeMap<String, String>>();

    let visible = el
        .is_visible()
        .await
        .map_err(|e| map_error(McpServerError::from(e)))?;
    let enabled = el
        .is_enabled()
        .await
        .map_err(|e| map_error(McpServerError::from(e)))?;

    // Best-effort: both a query error and "no box" collapse to `None`.
    let bounding_box = match el.bounding_box().await {
        Ok(Some(rect)) => Some(BoundingBox::from(rect)),
        _ => None,
    };

    Ok(ElementDescriptor {
        tag,
        text_snippet,
        attrs,
        bounding_box,
        visible,
        enabled,
    })
}
// Input of the `find` tool: the selector's fields inlined at the top level.
// NOTE(review): serde documents `deny_unknown_fields` as unsupported in
// combination with `flatten` — confirm unknown keys are rejected as intended.
// Plain `//` comments on purpose: `///` docs on a `JsonSchema` type would be
// emitted as schema descriptions and change the published schema.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct FindInput {
#[serde(flatten)]
pub selector: Selector,
}
// Result of the `find` tool. `found` is false and `element` is omitted from
// the JSON when nothing matched; otherwise `element` describes the match.
// Plain `//` comments on purpose: `///` docs on a `JsonSchema` type would be
// emitted as schema descriptions and change the published schema.
#[derive(Debug, Serialize, JsonSchema)]
pub struct FindOutput {
pub found: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub element: Option<ElementDescriptor>,
}
/// Looks up a single element on the current tab and describes it.
///
/// A "no element matched" failure (as recognised by `is_not_found`) is not an
/// error: it is reported as `found: false` with no element. The session lock
/// is held until the function returns.
///
/// # Errors
///
/// Propagates the error from `current_tab`, any other `resolve` error, and
/// any `describe` error.
pub async fn find(
    state: Arc<Mutex<SessionState>>,
    input: FindInput,
) -> Result<FindOutput, ErrorData> {
    let session = state.lock().await;
    let tab = current_tab(&session).await?;

    let element = match resolve(&tab, &input.selector).await {
        Ok(el) => el,
        Err(err) if is_not_found(&err) => {
            return Ok(FindOutput {
                found: false,
                element: None,
            });
        }
        Err(err) => return Err(err),
    };

    let descriptor = describe(&element).await?;
    Ok(FindOutput {
        found: true,
        element: Some(descriptor),
    })
}
/// Reports whether `err` is the "nothing matched" failure.
///
/// Both conditions must hold: the error data's `suggested_next` is the string
/// `"browser_html"`, and the message contains `"No element matched"`.
fn is_not_found(err: &ErrorData) -> bool {
    let suggested_next = err
        .data
        .as_ref()
        .and_then(|data| data.get("suggested_next"))
        .and_then(|next| next.as_str());
    let suggests_html = suggested_next == Some("browser_html");
    let says_no_match = err.message.contains("No element matched");
    suggests_html && says_no_match
}
// Input of the `find_all` tool: the selector's fields inlined at the top
// level, plus `limit`, which falls back to `default_limit()` when absent.
// NOTE(review): serde documents `deny_unknown_fields` as unsupported in
// combination with `flatten` — confirm unknown keys are rejected as intended.
// Plain `//` comments on purpose: `///` docs on a `JsonSchema` type would be
// emitted as schema descriptions and change the published schema.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct FindAllInput {
#[serde(flatten)]
pub selector: Selector,
#[serde(default = "default_limit")]
pub limit: usize,
}
/// Serde default for `FindAllInput::limit`: the number of elements returned
/// when the caller does not specify a limit.
const fn default_limit() -> usize {
    const DEFAULT_FIND_ALL_LIMIT: usize = 50;
    DEFAULT_FIND_ALL_LIMIT
}
// Result of the `find_all` tool: one descriptor per matched element, in the
// order the query returned them; empty when nothing matched.
// Plain `//` comments on purpose: `///` docs on a `JsonSchema` type would be
// emitted as schema descriptions and change the published schema.
#[derive(Debug, Serialize, JsonSchema)]
pub struct FindAllOutput {
pub elements: Vec<ElementDescriptor>,
}
/// Looks up every matching element on the current tab (capped at
/// `input.limit`) and describes each one in order.
///
/// Elements are described one after another; the first failure aborts the
/// whole call. The session lock is held until the function returns.
///
/// # Errors
///
/// Propagates the error from `current_tab`, `resolve_all`, or `describe`.
pub async fn find_all(
    state: Arc<Mutex<SessionState>>,
    input: FindAllInput,
) -> Result<FindAllOutput, ErrorData> {
    let session = state.lock().await;
    let tab = current_tab(&session).await?;
    let matched = resolve_all(&tab, &input.selector, input.limit).await?;

    let mut elements = Vec::with_capacity(matched.len());
    for el in matched.iter() {
        let descriptor = describe(el).await?;
        elements.push(descriptor);
    }
    Ok(FindAllOutput { elements })
}
// Unit tests that run without a browser: the tool entry points are exercised
// only on their "no browser open" error path, and the pure helpers
// (`compile_regex`, `parse_role`) are tested directly.
#[cfg(test)]
mod tests {
use super::*;
// Fresh session state with no browser attached.
fn fresh() -> Arc<Mutex<SessionState>> {
Arc::new(Mutex::new(SessionState::new()))
}
// Minimal CSS-only selector: visible elements only, 5000 ms timeout.
fn css_sel(s: &str) -> Selector {
Selector {
css: Some(s.into()),
xpath: None,
text: None,
text_exact: None,
text_regex: None,
role: None,
role_name: None,
tag: None,
attrs: vec![],
nth: None,
visible_only: true,
timeout_ms: 5000,
frame_id: None,
}
}
// `find` without a browser must fail and point the caller at `browser_open`.
#[tokio::test]
async fn find_with_no_browser_suggests_browser_open() {
let err = find(
fresh(),
FindInput {
selector: css_sel("h1"),
},
)
.await
.expect_err("must error without an open browser");
assert!(err.message.contains("browser_open"), "msg: {}", err.message);
let data = err.data.as_ref().expect("data populated");
assert_eq!(data["suggested_next"], "browser_open");
}
// `find_all` without a browser must fail the same way.
#[tokio::test]
async fn find_all_with_no_browser_suggests_browser_open() {
let err = find_all(
fresh(),
FindAllInput {
selector: css_sel("a"),
limit: 10,
},
)
.await
.expect_err("must error without an open browser");
assert!(err.message.contains("browser_open"));
let data = err.data.as_ref().expect("data populated");
assert_eq!(data["suggested_next"], "browser_open");
}
// A malformed pattern is reported with the `text_regex` prefix.
#[test]
fn compile_regex_rejects_malformed_pattern() {
let err = compile_regex("[unclosed").expect_err("malformed regex must error");
assert!(
err.message.contains("Invalid `text_regex`"),
"msg: {}",
err.message
);
}
// A valid pattern compiles and matches as expected.
#[test]
fn compile_regex_accepts_valid_pattern() {
let re = compile_regex(r"^hello\s+world$").expect("valid regex");
assert!(re.is_match("hello world"));
}
// An unknown role echoes the input and lists accepted names.
#[test]
fn parse_role_rejects_unknown_role_with_enum_list() {
let err = parse_role("buttn").expect_err("typo must error");
assert!(
err.message.contains("Unknown ARIA role `buttn`"),
"msg: {}",
err.message
);
assert!(err.message.contains("button"));
assert!(err.message.contains("textbox"));
}
// Every role name accepted by `parse_role` maps to its matching variant.
#[test]
fn parse_role_covers_every_lib_variant() {
let cases = [
("button", AriaRole::Button),
("link", AriaRole::Link),
("textbox", AriaRole::Textbox),
("combobox", AriaRole::Combobox),
("checkbox", AriaRole::Checkbox),
("radio", AriaRole::Radio),
("tab", AriaRole::Tab),
("menu", AriaRole::Menu),
("menuitem", AriaRole::Menuitem),
("dialog", AriaRole::Dialog),
("heading", AriaRole::Heading),
("banner", AriaRole::Banner),
("navigation", AriaRole::Navigation),
("main", AriaRole::Main),
("article", AriaRole::Article),
("list", AriaRole::List),
("listitem", AriaRole::Listitem),
("row", AriaRole::Row),
("cell", AriaRole::Cell),
("columnheader", AriaRole::Columnheader),
("rowheader", AriaRole::Rowheader),
];
for (s, expect) in cases {
assert_eq!(parse_role(s).expect("ok"), expect, "input: {s}");
}
}
}