#![cfg(feature = "expect")]
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use base64::Engine as _;
use base64::engine::general_purpose::STANDARD as BASE64;
use rmcp::ErrorData;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use tokio::sync::{Mutex, oneshot};
use zendriver::{
DialogType, MatchedDialog, MatchedDownload, MatchedRequest, MatchedResponse, UrlMatcher,
ZendriverError,
};
use crate::errors::{McpServerError, map_error};
use crate::state::{ExpectationHandle, ExpectationId, SessionState};
use crate::tools::common::current_tab;
/// Kind of event an expectation waits for; selects which zendriver `expect_*` call
/// `register` makes.
///
/// Deserialized from the snake_case names `request`, `response`, `dialog`, `download`.
#[derive(Debug, Clone, Copy, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum ExpectKind {
/// Waits via `expect_request`, filtered by the URL matcher.
Request,
/// Waits via `expect_response`, filtered by the URL matcher.
Response,
/// Waits via `expect_dialog`; the URL matcher is not used.
Dialog,
/// Waits via `expect_download`; the URL matcher is not used.
Download,
}
impl ExpectKind {
fn label(self) -> &'static str {
match self {
Self::Request => "request",
Self::Response => "response",
Self::Dialog => "dialog",
Self::Download => "download",
}
}
}
/// What `register` should do with a matched dialog (`dialog` expectations only).
///
/// When no action is supplied, the dialog handle is dropped without an explicit
/// accept/dismiss and the event reports `"driven": "default"`.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum DialogAction {
/// Accept the dialog, forwarding `dialog_prompt_text` when one was given.
Accept,
/// Dismiss the dialog.
Dismiss,
}
/// URL filter for `request` / `response` expectations.
///
/// `url_regex` takes precedence when both fields are set. With neither set, the matcher
/// is an empty substring, which the tests treat as "matches any URL". Unknown fields
/// are rejected.
#[derive(Debug, Default, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct ExpectMatcher {
/// Substring to match the URL against (ignored when `url_regex` is set).
#[serde(default)]
pub url_substr: Option<String>,
/// Regex (`regex` crate syntax) for the URL; an invalid pattern fails registration.
#[serde(default)]
pub url_regex: Option<String>,
}
impl ExpectMatcher {
    /// Turns the wire-level matcher into a zendriver [`UrlMatcher`].
    ///
    /// Precedence: a compiled `url_regex` wins over `url_substr`. With neither set,
    /// the empty substring is returned, so every URL matches.
    ///
    /// # Errors
    /// Returns an `invalid_request` error naming the pattern when `url_regex`
    /// does not compile.
    fn into_url_matcher(self) -> Result<UrlMatcher, ErrorData> {
        match (self.url_regex, self.url_substr) {
            (Some(pattern), _) => regex::Regex::new(&pattern)
                .map(UrlMatcher::Regex)
                .map_err(|err| {
                    ErrorData::invalid_request(
                        format!("invalid url_regex `{pattern}`: {err}"),
                        None,
                    )
                }),
            (None, substr) => Ok(UrlMatcher::Substring(substr.unwrap_or_default())),
        }
    }
}
/// Arguments for `register`. Unknown fields are rejected.
///
/// Only the fields relevant to `kind` are used; the rest are accepted but ignored.
// NOTE(review): the `skip_serializing_if` attributes have no serde effect on a
// Deserialize-only struct, but they may still shape the generated JSON schema
// (e.g. whether a `default` is emitted) — confirm before removing them.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct RegisterInput {
/// Which event to wait for.
pub kind: ExpectKind,
/// URL filter; used only for `request` and `response` kinds. Absent means "any URL".
#[serde(default)]
pub matcher: Option<ExpectMatcher>,
/// Timeout handed to zendriver's expectation via `.timeout(..)`, in milliseconds
/// (default 60 000). Separate from the timeout given to `await_expectation`.
#[serde(default = "default_pre_timeout")]
pub pre_await_timeout_ms: u64,
/// `dialog` kind only: accept or dismiss the dialog once matched.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub dialog_action: Option<DialogAction>,
/// `dialog` kind only: text passed along when the action is `accept`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub dialog_prompt_text: Option<String>,
/// `response` kind only: also fetch the body and return it base64-encoded.
#[serde(default)]
pub fetch_body: bool,
/// `download` kind only: path the downloaded file is saved to.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub save_to: Option<String>,
}
/// Serde default for `RegisterInput::pre_await_timeout_ms`: one minute, in milliseconds.
fn default_pre_timeout() -> u64 {
    const ONE_MINUTE_MS: u64 = 60 * 1_000;
    ONE_MINUTE_MS
}
/// Result of `register`: the id to pass to `await_expectation` or `cancel`.
#[derive(Debug, Serialize, JsonSchema)]
pub struct RegisterOutput {
/// Freshly generated UUID v4 string keying the expectation in the session.
pub expectation_id: ExpectationId,
}
/// Registers a one-shot expectation on the current tab and returns its id immediately.
///
/// The waiting happens in a spawned task. Its outcome (a JSON event or an error string)
/// is sent over a oneshot channel whose receiver is stored, together with the task handle
/// and the kind label, in `SessionState::expectations` under a new UUID v4. A later
/// `await_expectation` or `cancel` call consumes that entry.
///
/// # Errors
/// - `invalid_request` if `matcher.url_regex` does not compile (checked before the tab lookup).
/// - Whatever `current_tab` returns when no tab is available (the tests expect a
///   "Browser not open" message).
/// - A mapped `McpServerError` if setting up a download expectation fails.
pub async fn register(
state: Arc<Mutex<SessionState>>,
input: RegisterInput,
) -> Result<RegisterOutput, ErrorData> {
let pre_await = Duration::from_millis(input.pre_await_timeout_ms);
// Validate the matcher first so a bad regex is reported even without a browser.
let matcher = input.matcher.unwrap_or_default().into_url_matcher()?;
// Move each input piece out up front so the per-kind `async move` tasks can own them.
let kind = input.kind;
let dialog_action = input.dialog_action;
let dialog_prompt = input.dialog_prompt_text;
let fetch_body = input.fetch_body;
let save_to = input.save_to;
// Hold the session lock only for the tab lookup; it is re-acquired below to insert.
let tab = {
let s = state.lock().await;
current_tab(&s).await?
};
// The sender moves into the spawned task; the receiver goes into the stored handle.
let (tx, rx) = oneshot::channel::<Result<Value, String>>();
let task: tokio::task::JoinHandle<()> = match kind {
ExpectKind::Request => {
// NOTE(review): presumably `.timeout(pre_await)` bounds how long zendriver waits
// for a match — confirm against zendriver's expectation API.
let exp = tab.expect_request(matcher).timeout(pre_await);
tokio::spawn(async move {
let msg = exp.matched().await.map(request_to_json).map_err(err_to_str);
// A failed send means the receiver was already dropped; nothing left to notify.
let _ = tx.send(msg);
})
}
ExpectKind::Response => {
let exp = tab.expect_response(matcher).timeout(pre_await);
tokio::spawn(async move {
let msg = match exp.matched().await {
Ok(m) => response_to_json(m, fetch_body).await,
Err(e) => Err(err_to_str(e)),
};
let _ = tx.send(msg);
})
}
ExpectKind::Dialog => {
// Dialog expectations take no URL matcher; `matcher` is unused here.
let exp = tab.expect_dialog().timeout(pre_await);
tokio::spawn(async move {
let msg = match exp.matched().await {
Ok(m) => dialog_to_json(m, dialog_action, dialog_prompt).await,
Err(e) => Err(err_to_str(e)),
};
let _ = tx.send(msg);
})
}
ExpectKind::Download => {
// Download setup is async and fallible, so it is awaited here: setup failures
// surface from `register` itself rather than through the channel.
let exp = tab
.expect_download()
.await
.map_err(|e| map_error(McpServerError::from(e)))?
.timeout(pre_await);
tokio::spawn(async move {
let msg = match exp.matched().await {
Ok(m) => download_to_json(m, save_to).await,
Err(e) => Err(err_to_str(e)),
};
let _ = tx.send(msg);
})
}
};
let id: ExpectationId = uuid::Uuid::new_v4().to_string();
{
let mut s = state.lock().await;
s.expectations.insert(
id.clone(),
ExpectationHandle {
kind: kind.label(),
task,
rx,
},
);
}
Ok(RegisterOutput { expectation_id: id })
}
/// Flattens a [`ZendriverError`] into its `Display` text.
///
/// The expectation tasks report failures as plain `String`s over their oneshot channel,
/// and this is the adapter used in their `map_err` calls.
fn err_to_str(e: ZendriverError) -> String {
    ToString::to_string(&e)
}
/// Arguments for `await_expectation`. Unknown fields are rejected.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct AwaitInput {
/// Id returned by `register`.
pub expectation_id: ExpectationId,
/// How long to wait for the outcome, in milliseconds (default 30 000). On expiry the
/// expectation's task is aborted and the expectation is discarded.
#[serde(default = "default_await_timeout")]
pub timeout_ms: u64,
}
/// Serde default for `AwaitInput::timeout_ms`: thirty seconds, in milliseconds.
fn default_await_timeout() -> u64 {
    const THIRTY_SECONDS_MS: u64 = 30 * 1_000;
    THIRTY_SECONDS_MS
}
/// Successful result of `await_expectation`.
#[derive(Debug, Serialize, JsonSchema)]
pub struct AwaitOutput {
/// The id that was awaited.
pub expectation_id: ExpectationId,
/// JSON event built by the matching `*_to_json` serializer; each sets a `kind` field.
pub event: Value,
}
pub async fn await_expectation(
state: Arc<Mutex<SessionState>>,
input: AwaitInput,
) -> Result<AwaitOutput, ErrorData> {
let handle = {
let mut s = state.lock().await;
s.expectations
.remove(&input.expectation_id)
.ok_or_else(|| {
map_error(McpServerError::ExpectationNotFound(
input.expectation_id.clone(),
))
})?
};
let ExpectationHandle { rx, task, .. } = handle;
let outer = Duration::from_millis(input.timeout_ms);
match tokio::time::timeout(outer, rx).await {
Err(_elapsed) => {
task.abort();
Err(ErrorData::invalid_request(
format!(
"expectation `{}` await_timed_out after {:?}",
input.expectation_id, outer
),
None,
))
}
Ok(Err(_recv)) => {
task.abort();
Err(ErrorData::invalid_request(
format!(
"expectation `{}` channel_closed (spawned task ended without sending)",
input.expectation_id
),
None,
))
}
Ok(Ok(Err(err_str))) => {
Err(ErrorData::invalid_request(
format!("expectation `{}` failed: {err_str}", input.expectation_id),
None,
))
}
Ok(Ok(Ok(event))) => Ok(AwaitOutput {
expectation_id: input.expectation_id,
event,
}),
}
}
/// Arguments for `cancel`. Unknown fields are rejected.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct CancelInput {
/// Id returned by `register`.
pub expectation_id: ExpectationId,
}
/// Result of `cancel`.
#[derive(Debug, Serialize, JsonSchema)]
pub struct CancelOutput {
/// Always `true` on success: `cancel` sets it unconditionally once the id is found,
/// so it does not indicate whether the event had already fired.
pub cancelled: bool,
}
pub async fn cancel(
state: Arc<Mutex<SessionState>>,
input: CancelInput,
) -> Result<CancelOutput, ErrorData> {
let mut s = state.lock().await;
let handle = s
.expectations
.remove(&input.expectation_id)
.ok_or_else(|| {
map_error(McpServerError::ExpectationNotFound(
input.expectation_id.clone(),
))
})?;
handle.task.abort();
Ok(CancelOutput { cancelled: true })
}
/// Serializes a matched request into the `"kind": "request"` event.
///
/// Only the length of any POST payload is reported (`null` when there is none),
/// never the payload bytes themselves.
fn request_to_json(m: MatchedRequest) -> Value {
    let post_data_len = m.post_data.as_ref().map(|payload| payload.len());
    json!({
        "kind": "request",
        "url": m.url,
        "method": m.method,
        "headers": m.headers,
        "request_id": m.request_id,
        "post_data_len": post_data_len,
    })
}
async fn response_to_json(m: MatchedResponse, fetch_body: bool) -> Result<Value, String> {
let mut v = json!({
"kind": "response",
"url": m.url,
"status": m.status,
"status_text": m.status_text,
"headers": m.headers,
"request_id": m.request_id,
});
if fetch_body {
let bytes = m.body().await.map_err(err_to_str)?;
v["body_len"] = json!(bytes.len());
v["body_base64"] = json!(BASE64.encode(&bytes));
}
Ok(v)
}
/// Lowercase name used in events for a zendriver [`DialogType`].
fn dialog_type_str(d: &DialogType) -> &'static str {
    match *d {
        DialogType::Prompt => "prompt",
        DialogType::Confirm => "confirm",
        DialogType::Beforeunload => "beforeunload",
        DialogType::Alert => "alert",
    }
}
/// Resolves a matched dialog according to `action` and reports the result as a
/// `"kind": "dialog"` event.
///
/// `prompt` is forwarded only on accept. With no `action`, the handle is dropped
/// without an explicit choice and the event reports `"driven": "default"`.
///
/// # Errors
/// The stringified zendriver error if accepting or dismissing the dialog fails.
async fn dialog_to_json(
    m: MatchedDialog,
    action: Option<DialogAction>,
    prompt: Option<String>,
) -> Result<Value, String> {
    // Snapshot the reportable fields before the handle is acted on or dropped.
    let kind = dialog_type_str(&m.dialog_type);
    let (message, default_prompt, url) =
        (m.message.clone(), m.default_prompt.clone(), m.url.clone());

    let driven = match action {
        None => {
            drop(m);
            "default"
        }
        Some(DialogAction::Dismiss) => {
            m.dismiss().await.map_err(err_to_str)?;
            "dismiss"
        }
        Some(DialogAction::Accept) => {
            m.accept(prompt).await.map_err(err_to_str)?;
            "accept"
        }
    };

    Ok(json!({
        "kind": "dialog",
        "dialog_type": kind,
        "message": message,
        "default_prompt": default_prompt,
        "url": url,
        "driven": driven,
    }))
}
/// Describes a matched download as a `"kind": "download"` event.
///
/// When `save_to` is given, the file is saved there first and `saved_path` echoes
/// that path verbatim; otherwise `saved_path` is `null`.
///
/// # Errors
/// The stringified zendriver error if saving the download fails.
async fn download_to_json(m: MatchedDownload, save_to: Option<String>) -> Result<Value, String> {
    let (url, suggested_filename, guid) = (
        m.url.clone(),
        m.suggested_filename.clone(),
        m.guid.clone(),
    );
    let download_dir = String::from(m.download_dir.to_string_lossy());

    let saved_path = match save_to {
        None => None,
        Some(target) => {
            m.save_to(PathBuf::from(&target)).await.map_err(err_to_str)?;
            Some(target)
        }
    };

    Ok(json!({
        "kind": "download",
        "url": url,
        "suggested_filename": suggested_filename,
        "guid": guid,
        "download_dir": download_dir,
        "saved_path": saved_path,
    }))
}
#[cfg(test)]
#[allow(clippy::panic, clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Fresh session with no browser attached, used by the error-path tests.
    fn empty_state() -> Arc<Mutex<SessionState>> {
        Arc::new(Mutex::new(SessionState::new()))
    }

    /// Asserts the error carries the "register an expectation first" hint.
    fn assert_suggests_register(err: &ErrorData) {
        let data = err.data.as_ref().expect("data populated");
        assert_eq!(data["suggested_next"], "browser_expect_register");
    }

    #[test]
    fn matcher_default_is_match_any_substring() {
        let matcher = ExpectMatcher::default().into_url_matcher().unwrap();
        let UrlMatcher::Substring(needle) = matcher else {
            panic!("expected Substring(\"\")");
        };
        assert!(
            needle.is_empty(),
            "default matcher must be empty substring (matches any url)",
        );
    }

    #[test]
    fn matcher_url_regex_wins_over_url_substr() {
        let both = ExpectMatcher {
            url_substr: Some("/foo/".into()),
            url_regex: Some(r"^https://api\.".into()),
        };
        let UrlMatcher::Regex(re) = both.into_url_matcher().unwrap() else {
            panic!("expected Regex variant");
        };
        assert!(re.is_match("https://api.example.com/v1"));
    }

    #[test]
    fn matcher_invalid_regex_errors() {
        let broken = ExpectMatcher {
            url_regex: Some("[".into()),
            url_substr: None,
        };
        let err = broken
            .into_url_matcher()
            .expect_err("expected invalid regex");
        assert!(err.message.contains("invalid url_regex"));
    }

    #[tokio::test]
    async fn await_unknown_expectation_id_surfaces_not_found() {
        let request = AwaitInput {
            expectation_id: "nope".into(),
            timeout_ms: 100,
        };
        let err = await_expectation(empty_state(), request)
            .await
            .expect_err("expected ExpectationNotFound");
        assert_suggests_register(&err);
    }

    #[tokio::test]
    async fn cancel_unknown_expectation_id_surfaces_not_found() {
        let request = CancelInput {
            expectation_id: "nope".into(),
        };
        let err = cancel(empty_state(), request)
            .await
            .expect_err("expected ExpectationNotFound");
        assert_suggests_register(&err);
    }

    #[tokio::test]
    async fn register_with_no_browser_errors() {
        let request = RegisterInput {
            kind: ExpectKind::Request,
            matcher: None,
            pre_await_timeout_ms: 1_000,
            dialog_action: None,
            dialog_prompt_text: None,
            fetch_body: false,
            save_to: None,
        };
        let err = register(empty_state(), request)
            .await
            .expect_err("expected BrowserNotOpen");
        assert!(err.message.contains("Browser not open"));
    }
}