use std::sync::Arc;
use std::time::Duration;
use rmcp::ErrorData;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use zendriver::{IdleOptions, ReadyState, ReloadOptions};
use crate::errors::{McpServerError, map_error};
use crate::snapshot::html_trim;
use crate::state::SessionState;
use crate::tools::actions::AckOutput;
use crate::tools::common::{EmptyInput, current_tab, lookup_frame};
/// Result of a navigation tool: the tab's URL and title after the action, plus an optional
/// trimmed HTML snapshot of the page.
#[derive(Debug, Serialize, JsonSchema)]
pub struct NavOutput {
    /// URL of the current tab; empty if it could not be read.
    pub url: String,
    /// Title of the current tab; empty if it could not be read.
    pub title: String,
    /// Trimmed `document.documentElement.outerHTML`; only present when a snapshot was requested.
    // Omitted from the serialized JSON entirely (not `null`) when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub snapshot: Option<String>,
}
/// Builds the [`NavOutput`] reported after a navigation action on `tab`.
///
/// URL and title are best-effort: if either lookup fails, an empty string is reported
/// instead of failing the whole tool call. When `return_snapshot` is set, a trimmed HTML
/// snapshot is captured after the URL and title have been read.
///
/// # Errors
/// Only snapshot capture can fail; its error is propagated as-is.
async fn nav_output(tab: &zendriver::Tab, return_snapshot: bool) -> Result<NavOutput, ErrorData> {
    // Struct-literal fields are evaluated in source order: URL first, then title.
    let mut output = NavOutput {
        url: tab.url().await.map(|u| u.to_string()).unwrap_or_default(),
        title: tab.title().await.unwrap_or_default(),
        snapshot: None,
    };
    if return_snapshot {
        output.snapshot = Some(snapshot_now(tab).await?);
    }
    Ok(output)
}
/// Captures the page's root-element HTML via `evaluate_main` and shrinks it with
/// `html_trim::trim`.
///
/// # Errors
/// Script evaluation failures are converted through `McpServerError` into an `ErrorData`.
async fn snapshot_now(tab: &zendriver::Tab) -> Result<String, ErrorData> {
    let script = "document.documentElement.outerHTML";
    let outer_html: String = match tab.evaluate_main(script).await {
        Ok(value) => value,
        Err(e) => return Err(map_error(McpServerError::from(e))),
    };
    Ok(html_trim::trim(&outer_html))
}
/// Arguments for navigating the current tab to a URL.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct GotoInput {
    /// Address to navigate the current tab to.
    pub url: String,
    /// What to wait for after navigating: `load`, `idle`, or `none`. Defaults to `load`.
    #[serde(default = "default_wait")]
    pub wait_for: WaitFor,
    /// When true, include a trimmed HTML snapshot of the page in the response.
    #[serde(default)]
    pub return_snapshot: bool,
}
/// How long `goto` waits after navigating; defaults to `load`.
#[derive(Debug, Clone, Copy, Default, Deserialize, Serialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum WaitFor {
    // Variant notes are deliberately plain `//` comments: `///` on variants would likely make
    // schemars emit a `oneOf` instead of a flat string enum. NOTE(review): confirm before changing.
    //
    // `load`: `goto` awaits the tab's `wait_for_load`.
    #[default]
    Load,
    // `idle`: `goto` awaits `wait_for_idle_with` (5000 ms timeout, 500 ms quiet window).
    Idle,
    // `none`: `goto` performs no extra wait after the navigation call returns.
    None,
}
/// Serde default for `GotoInput::wait_for`; matches the `#[default]` variant of [`WaitFor`].
const fn default_wait() -> WaitFor {
    WaitFor::Load
}
pub async fn goto(
state: Arc<Mutex<SessionState>>,
input: GotoInput,
) -> Result<NavOutput, ErrorData> {
let s = state.lock().await;
let tab = current_tab(&s).await?;
tab.goto(&input.url)
.await
.map_err(|e| map_error(McpServerError::from(e)))?;
match input.wait_for {
WaitFor::Load => tab
.wait_for_load()
.await
.map_err(|e| map_error(McpServerError::from(e)))?,
WaitFor::Idle => tab
.wait_for_idle_with(Duration::from_millis(5000), Duration::from_millis(500))
.await
.map_err(|e| map_error(McpServerError::from(e)))?,
WaitFor::None => {}
}
nav_output(&tab, input.return_snapshot).await
}
/// Arguments for the `back` and `forward` history tools.
#[derive(Debug, Default, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct HistoryInput {
    /// When true, include a trimmed HTML snapshot of the page in the response.
    #[serde(default)]
    pub return_snapshot: bool,
}
pub async fn back(
state: Arc<Mutex<SessionState>>,
input: HistoryInput,
) -> Result<NavOutput, ErrorData> {
let s = state.lock().await;
let tab = current_tab(&s).await?;
tab.back()
.await
.map_err(|e| map_error(McpServerError::from(e)))?;
nav_output(&tab, input.return_snapshot).await
}
/// Goes forward in the current tab's history and reports the resulting page.
///
/// No extra load wait is performed here beyond whatever the tab's `forward()` does itself.
///
/// # Errors
/// Fails when no browser/tab is available, when `forward()` fails, or when a requested
/// snapshot cannot be captured.
pub async fn forward(
    state: Arc<Mutex<SessionState>>,
    input: HistoryInput,
) -> Result<NavOutput, ErrorData> {
    let want_snapshot = input.return_snapshot;
    let session = state.lock().await;
    let tab = current_tab(&session).await?;
    if let Err(e) = tab.forward().await {
        return Err(map_error(McpServerError::from(e)));
    }
    nav_output(&tab, want_snapshot).await
}
/// Arguments for reloading the current tab.
#[derive(Debug, Default, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct ReloadInput {
    /// When true, reload with the browser cache ignored.
    #[serde(default)]
    pub ignore_cache: bool,
    /// When true, include a trimmed HTML snapshot of the page in the response.
    #[serde(default)]
    pub return_snapshot: bool,
}
/// Reloads the current tab and reports the resulting page.
///
/// A plain `reload()` is used unless `ignore_cache` is set. In that case `reload_with` is
/// called with `ignore_cache: true` and every other option left at `ReloadOptions::default()`.
///
/// # Errors
/// Fails when no browser/tab is available, when the reload fails, or when a requested
/// snapshot cannot be captured.
pub async fn reload(
    state: Arc<Mutex<SessionState>>,
    input: ReloadInput,
) -> Result<NavOutput, ErrorData> {
    let session = state.lock().await;
    let tab = current_tab(&session).await?;
    if !input.ignore_cache {
        tab.reload()
            .await
            .map_err(|e| map_error(McpServerError::from(e)))?;
    } else {
        let options = ReloadOptions {
            ignore_cache: true,
            ..ReloadOptions::default()
        };
        tab.reload_with(options)
            .await
            .map_err(|e| map_error(McpServerError::from(e)))?;
    }
    nav_output(&tab, input.return_snapshot).await
}
/// Document readiness level to wait for: `interactive` or `complete`.
// Client-facing subset of zendriver's `ReadyState`; converted by the `From` impl below.
// Variant notes stay `//` so the schema remains a flat string enum.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ReadyStateArg {
    // Maps to `ReadyState::Interactive`.
    Interactive,
    // Maps to `ReadyState::Complete`.
    Complete,
}
impl From<ReadyStateArg> for ReadyState {
    /// Converts the tool argument into zendriver's `ReadyState`, variant for variant.
    fn from(arg: ReadyStateArg) -> Self {
        match arg {
            ReadyStateArg::Complete => Self::Complete,
            ReadyStateArg::Interactive => Self::Interactive,
        }
    }
}
/// Arguments for waiting on a page load.
///
/// With `frame_id`, waits for that frame to load. Otherwise waits for the tab to reach
/// `ready_state`, or for the tab's load when no `ready_state` is given.
// NOTE(review): this struct only derives Deserialize, so `skip_serializing_if` has no effect on
// parsing. It may still shape the generated JSON schema (e.g. suppress a `"default": null`),
// so confirm the schema output before removing it.
#[derive(Debug, Default, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct WaitForLoadInput {
    /// Readiness level to wait for on the tab. Ignored when `frame_id` is set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub ready_state: Option<ReadyStateArg>,
    /// If set, wait for this frame to load instead of the tab.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub frame_id: Option<String>,
}
/// Waits for a load condition on the current tab (or one of its frames), then reports the page.
///
/// Precedence: `frame_id` wins and `ready_state` is then ignored. Next comes `ready_state`
/// on the tab. The fallback is the tab's plain `wait_for_load`. The response never carries a
/// snapshot.
///
/// # Errors
/// Fails when no browser/tab is available, when `frame_id` cannot be resolved, or when the
/// chosen wait fails.
pub async fn wait_for_load(
    state: Arc<Mutex<SessionState>>,
    input: WaitForLoadInput,
) -> Result<NavOutput, ErrorData> {
    let session = state.lock().await;
    let tab = current_tab(&session).await?;
    match (input.frame_id.as_deref(), input.ready_state) {
        (Some(frame_id), _) => {
            let target_frame = lookup_frame(&tab, frame_id).await?;
            target_frame
                .wait_for_load()
                .await
                .map_err(|e| map_error(McpServerError::from(e)))?;
        }
        (None, Some(level)) => {
            tab.wait_for_ready_state(level.into())
                .await
                .map_err(|e| map_error(McpServerError::from(e)))?;
        }
        (None, None) => {
            tab.wait_for_load()
                .await
                .map_err(|e| map_error(McpServerError::from(e)))?;
        }
    }
    nav_output(&tab, false).await
}
/// Calls zendriver's `bypass_insecure_connection_warning` on the current tab and acknowledges
/// with `ok: true`.
///
/// # Errors
/// Fails when no browser/tab is available or when the bypass call fails.
pub async fn bypass_insecure_warning(
    state: Arc<Mutex<SessionState>>,
    _: EmptyInput,
) -> Result<AckOutput, ErrorData> {
    let session = state.lock().await;
    let tab = current_tab(&session).await?;
    tab.bypass_insecure_connection_warning()
        .await
        .map(|_| AckOutput { ok: true })
        .map_err(|e| map_error(McpServerError::from(e)))
}
/// Arguments for waiting until the current tab goes idle.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct IdleInput {
    /// Maximum time to wait, in milliseconds. Defaults to 5000.
    #[serde(default = "default_idle_timeout")]
    pub timeout_ms: u64,
    /// Optional limit, in milliseconds, forwarded as the idle wait's `max_inflight_age`.
    // NOTE(review): presumably in-flight requests older than this stop counting against
    // idleness. Confirm against zendriver's `IdleOptions` docs.
    #[serde(default)]
    pub max_inflight_age_ms: Option<u64>,
}
/// Default idle-wait timeout in milliseconds (five seconds); serde default for `IdleInput::timeout_ms`.
const fn default_idle_timeout() -> u64 {
    5_000
}
impl Default for IdleInput {
    /// Matches the serde defaults: `default_idle_timeout()` ms and `max_inflight_age_ms` unset.
    fn default() -> Self {
        let timeout_ms = default_idle_timeout();
        Self {
            max_inflight_age_ms: None,
            timeout_ms,
        }
    }
}
/// Result of waiting for the tab to go idle.
#[derive(Debug, Serialize, JsonSchema)]
pub struct IdleOutput {
    /// `true` once the idle wait completes. A failed wait is returned as an error, not as `false`.
    pub idle: bool,
}
/// Waits until the current tab is idle, using a fixed 500 ms quiet window.
///
/// `timeout_ms` and the optional `max_inflight_age_ms` come from `input`.
///
/// # Errors
/// Fails when no browser/tab is available or when the idle wait fails.
pub async fn wait_for_idle(
    state: Arc<Mutex<SessionState>>,
    input: IdleInput,
) -> Result<IdleOutput, ErrorData> {
    let session = state.lock().await;
    let tab = current_tab(&session).await?;
    let options = IdleOptions {
        max_inflight_age: input.max_inflight_age_ms.map(Duration::from_millis),
        quiet_window: Duration::from_millis(500),
        timeout: Duration::from_millis(input.timeout_ms),
    };
    tab.wait_for_idle_opts(options)
        .await
        .map(|_| IdleOutput { idle: true })
        .map_err(|e| map_error(McpServerError::from(e)))
}
impl From<EmptyInput> for HistoryInput {
    /// An empty argument object becomes the default history input (`return_snapshot: false`).
    fn from(_empty: EmptyInput) -> Self {
        Default::default()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh session state with no browser open.
    fn empty_state() -> Arc<Mutex<SessionState>> {
        Arc::new(Mutex::new(SessionState::new()))
    }

    /// Checks the error returned when no browser session exists: both the message and the
    /// structured `suggested_next` hint must point the caller at `browser_open`.
    fn assert_suggests_browser_open(err: &ErrorData) {
        assert!(err.message.contains("browser_open"), "msg: {}", err.message);
        let data = err.data.as_ref().expect("data populated");
        assert_eq!(data["suggested_next"], "browser_open");
    }

    #[tokio::test]
    async fn goto_with_no_browser_suggests_browser_open() {
        let err = goto(
            empty_state(),
            GotoInput {
                url: "https://example.com".into(),
                wait_for: WaitFor::Load,
                return_snapshot: false,
            },
        )
        .await
        .expect_err("must error without an open browser");
        assert_suggests_browser_open(&err);
    }

    #[tokio::test]
    async fn back_with_no_browser_suggests_browser_open() {
        let err = back(empty_state(), HistoryInput::default())
            .await
            .expect_err("must error without an open browser");
        assert_suggests_browser_open(&err);
    }

    #[tokio::test]
    async fn forward_with_no_browser_suggests_browser_open() {
        let err = forward(empty_state(), HistoryInput::default())
            .await
            .expect_err("must error without an open browser");
        assert_suggests_browser_open(&err);
    }

    #[tokio::test]
    async fn reload_with_no_browser_suggests_browser_open() {
        let err = reload(empty_state(), ReloadInput::default())
            .await
            .expect_err("must error without an open browser");
        assert_suggests_browser_open(&err);
    }

    #[tokio::test]
    async fn wait_for_load_with_no_browser_suggests_browser_open() {
        let err = wait_for_load(empty_state(), WaitForLoadInput::default())
            .await
            .expect_err("must error without an open browser");
        assert_suggests_browser_open(&err);
    }

    #[tokio::test]
    async fn wait_for_idle_with_no_browser_suggests_browser_open() {
        let err = wait_for_idle(empty_state(), IdleInput::default())
            .await
            .expect_err("must error without an open browser");
        assert_suggests_browser_open(&err);
    }

    // `WaitFor` has two sources of "default" (`#[default]` and the serde `default_wait`
    // function); they must not drift apart.
    #[test]
    fn wait_for_defaults_agree() {
        assert!(matches!(WaitFor::default(), WaitFor::Load));
        assert!(matches!(default_wait(), WaitFor::Load));
    }

    // `IdleInput::default()` is hand-written and must match the serde field defaults.
    #[test]
    fn idle_input_default_matches_serde_defaults() {
        let input = IdleInput::default();
        assert_eq!(input.timeout_ms, default_idle_timeout());
        assert!(input.max_inflight_age_ms.is_none());
    }
}