use std::sync::Arc;
use rmcp::ErrorData;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use zendriver::Browser;
use zendriver::stealth::{Platform, StealthProfile};
use crate::errors::{McpServerError, map_error};
use crate::state::{SessionState, StealthOverrides, StealthPlatformChoice, StealthProfileChoice};
use crate::tools::common::EmptyInput;
/// Parameters accepted by [`open`] when launching a browser for the session.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct OpenInput {
    /// Launch the browser in headless mode. Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub headless: bool,
    /// Stealth profile to launch with. When omitted, the session's current
    /// profile choice is reused.
    #[serde(default)]
    pub stealth_profile: Option<StealthProfileChoice>,
    /// Browser preferences, forwarded key-by-key to the launch builder.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub preferences: Option<std::collections::HashMap<String, serde_json::Value>>,
    /// Persona definition as a JSON value; rejected with `invalid_params` if
    /// zendriver cannot parse it.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub persona: Option<serde_json::Value>,
    /// Enable tracker blocking. A `true` value is rejected when the server was
    /// built without the `tracker-blocking` feature.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub block_trackers: Option<bool>,
    /// Custom tracker blocklist source. Rejected when the server was built
    /// without the `tracker-blocking` feature.
    // NOTE(review): `skip_serializing_if` has no serialization effect on this
    // Deserialize-only struct; presumably kept for schema generation — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tracker_blocklist: Option<TrackerBlocklist>,
}
/// Where a custom tracker blocklist comes from.
///
/// Internally tagged on `source` with snake_case variant names, e.g.
/// `{"source": "url", "url": "..."}` or `{"source": "domains", "domains": [...]}`.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
#[serde(tag = "source", rename_all = "snake_case", deny_unknown_fields)]
pub enum TrackerBlocklist {
    /// Blocklist identified by URL (handed to the builder's `tracker_blocklist_url`).
    Url { url: String },
    /// Blocklist at a local filesystem path (handed to `tracker_blocklist_file`).
    File { path: String },
    /// Explicit domains added to the blocklist via `tracker_blocklist_add`.
    Domains { domains: Vec<String> },
}
/// Serde default for [`OpenInput::headless`]: a missing field means `true`.
///
/// `#[serde(default = "...")]` takes a function path rather than a literal,
/// which is why this trivial function exists.
const fn default_true() -> bool {
    true
}
/// Result returned by [`open`] after a successful launch.
#[derive(Debug, Serialize, JsonSchema)]
pub struct OpenOutput {
    /// Version of the launched Chrome browser.
    // NOTE(review): `open` currently fills this with `String::new()`, so callers
    // always receive an empty string.
    pub chrome_version: String,
    /// Echo of the `headless` flag the browser was launched with.
    pub headless: bool,
    /// Stealth profile the browser was launched with; also stored on the session.
    pub profile: StealthProfileChoice,
}
/// Launches a browser for this session and records it in `state`.
///
/// The session lock is taken up front and held for the rest of the call,
/// including the launch itself, so two concurrent `open` requests cannot both
/// get past the "already open" check.
///
/// The stealth profile is the request's `stealth_profile` if given, otherwise the
/// session's current choice. The session's stealth overrides are layered on top.
/// On success the first reported tab becomes the current tab, and the chosen
/// profile is persisted on the session.
///
/// # Errors
///
/// - `BrowserAlreadyOpen` (via `map_error`) if the session already has a browser.
/// - `invalid_params` if `persona` cannot be parsed, or if tracker blocking is
///   requested on a build without the `tracker-blocking` feature.
/// - Any launch failure, converted through `McpServerError`.
pub async fn open(
    state: Arc<Mutex<SessionState>>,
    input: OpenInput,
) -> Result<OpenOutput, ErrorData> {
    let mut session = state.lock().await;
    if session.browser.is_some() {
        return Err(map_error(McpServerError::BrowserAlreadyOpen));
    }

    // A profile named in the request wins; otherwise reuse the session's last choice.
    let profile = match input.stealth_profile {
        Some(requested) => requested,
        None => session.stealth_profile_choice,
    };
    let base_stealth = stealth_profile_for(profile);
    let stealth = apply_overrides(base_stealth, &session.stealth_overrides);
    let mut builder = Browser::builder()
        .headless(input.headless)
        .stealth(stealth);

    // `Option::iter().flatten()` yields nothing when no preferences were sent.
    for (key, value) in input.preferences.iter().flatten() {
        builder = builder.preference(key.clone(), value.clone());
    }

    if let Some(raw_persona) = input.persona.as_ref() {
        // The persona parser takes a JSON string, so re-serialize the value first.
        let parsed = zendriver::Persona::try_from_json(&raw_persona.to_string())
            .map_err(|e| ErrorData::invalid_params(format!("invalid persona JSON: {e}"), None))?;
        builder = builder.persona(parsed);
    }

    #[cfg(feature = "tracker-blocking")]
    {
        if input.block_trackers == Some(true) {
            builder = builder.block_trackers(true);
        }
        if let Some(source) = input.tracker_blocklist.as_ref() {
            builder = match source {
                TrackerBlocklist::Url { url } => builder.tracker_blocklist_url(url.clone()),
                TrackerBlocklist::File { path } => {
                    builder.tracker_blocklist_file(std::path::PathBuf::from(path))
                }
                TrackerBlocklist::Domains { domains } => {
                    builder.tracker_blocklist_add(domains.clone())
                }
            };
        }
    }
    #[cfg(not(feature = "tracker-blocking"))]
    {
        // Refuse rather than silently launching without the requested blocking.
        let tracker_blocking_requested =
            input.block_trackers == Some(true) || input.tracker_blocklist.is_some();
        if tracker_blocking_requested {
            return Err(ErrorData::invalid_params(
                "tracker blocking requested but this server was built without the `tracker-blocking` feature".to_string(),
                None,
            ));
        }
    }

    let browser = builder
        .launch()
        .await
        .map_err(|e| map_error(McpServerError::from(e)))?;

    // Point the session at whichever tab the fresh browser reports first, if any.
    let tabs = browser.tabs().await;
    let first_tab_id = tabs.first().map(|tab| tab.target_id().to_string());

    session.current_tab_id = first_tab_id;
    session.browser = Some(browser);
    session.stealth_profile_choice = profile;

    // NOTE(review): no version lookup is performed, so `chrome_version` is always
    // empty here. Wire it to the browser's reported version once an API is chosen.
    Ok(OpenOutput {
        chrome_version: String::new(),
        headless: input.headless,
        profile,
    })
}
/// Maps a user-facing [`StealthProfileChoice`] onto a concrete [`StealthProfile`].
///
/// `Auto` and `Native` both use the native profile. Each `Spoof*` choice starts
/// from a spoofed profile and pins the corresponding [`Platform`].
fn stealth_profile_for(choice: StealthProfileChoice) -> StealthProfile {
    let spoofed_platform = match choice {
        StealthProfileChoice::Auto | StealthProfileChoice::Native => {
            return StealthProfile::native();
        }
        StealthProfileChoice::SpoofMacos => Platform::MacIntel,
        StealthProfileChoice::SpoofLinux => Platform::LinuxX86_64,
        StealthProfileChoice::SpoofWindows => Platform::Win32,
    };
    StealthProfile::spoofed().platform(spoofed_platform)
}
/// Translates the server's platform override choice into zendriver's [`Platform`].
///
/// Variants map one-to-one by name.
impl From<StealthPlatformChoice> for Platform {
    fn from(choice: StealthPlatformChoice) -> Self {
        match choice {
            StealthPlatformChoice::MacIntel => Self::MacIntel,
            StealthPlatformChoice::LinuxX86_64 => Self::LinuxX86_64,
            StealthPlatformChoice::Win32 => Self::Win32,
        }
    }
}
fn apply_overrides(mut profile: StealthProfile, overrides: &StealthOverrides) -> StealthProfile {
if let Some(platform) = overrides.platform {
profile = profile.platform(platform.into());
}
#[cfg(feature = "geo")]
if let Some(ref cc) = overrides.geo_country {
match zendriver_stealth::geo::Country::try_from(cc.as_str()) {
Ok(country) => {
let derived = zendriver_stealth::geo::persona(country);
if let Some(locale) = derived.locale {
profile = profile.locale(locale);
}
if let Some(langs) = derived.languages {
profile = profile.languages(langs);
}
}
Err(_) => tracing::warn!("geo_country {cc:?} is not a valid country code; ignoring"),
}
}
if let Some(ref locale) = overrides.locale {
profile = profile.locale(locale);
}
if let Some(ref timezone) = overrides.timezone {
profile = profile.timezone(timezone);
}
if let Some(memory_gb) = overrides.memory_gb {
profile = profile.memory_gb(memory_gb);
}
if let Some(cpu_count) = overrides.cpu_count {
profile = profile.cpu_count(cpu_count);
}
if let Some(chrome_version) = overrides.chrome_version {
profile = profile.chrome_version(chrome_version);
}
if let Some(ref user_agent) = overrides.user_agent {
profile = profile.user_agent(user_agent);
}
if let Some(bypass_csp) = overrides.bypass_csp {
profile = profile.bypass_csp(bypass_csp);
}
profile
}
/// Result returned by [`close`].
#[derive(Debug, Serialize, JsonSchema)]
pub struct CloseOutput {
    /// Always `true` when `close` returns successfully.
    pub ok: bool,
}
/// Tears down the session: drops session-scoped state and closes the browser, if any.
///
/// Always clears the current-tab pointer and the feature-gated per-session
/// registries: pending expectation tasks are aborted, and interception rules and
/// monitors are dropped. Closing a session that has no browser is a successful no-op.
///
/// # Errors
///
/// Returns the mapped zendriver error if closing the browser fails. The browser
/// handle has already been removed from the session at that point, so the
/// session ends up closed (and without a current tab) either way.
pub async fn close(
    state: Arc<Mutex<SessionState>>,
    _: EmptyInput,
) -> Result<CloseOutput, ErrorData> {
    let mut s = state.lock().await;
    #[cfg(feature = "expect")]
    {
        // Abort explicitly: dropping a tokio JoinHandle only detaches the task.
        for (_, h) in s.expectations.drain() {
            h.task.abort();
        }
    }
    #[cfg(feature = "interception")]
    {
        s.rules.clear();
    }
    #[cfg(feature = "monitor")]
    {
        s.monitors.clear();
    }
    // Clear the tab pointer *before* awaiting the close. If `close()` fails, the
    // `?` below returns early, and the session must not keep a stale tab id once
    // its browser handle has been taken.
    s.current_tab_id = None;
    if let Some(browser) = s.browser.take() {
        browser
            .close()
            .await
            .map_err(|e| map_error(McpServerError::from(e)))?;
    }
    Ok(CloseOutput { ok: true })
}
/// Lightweight description of a browser tab, as reported by [`status`].
#[derive(Debug, Serialize, JsonSchema)]
pub struct TabSummary {
    /// The tab's target id.
    pub id: String,
    /// Current URL of the tab; `status` reports an empty string if it cannot be read.
    pub url: String,
    /// Page title; `status` reports an empty string if it cannot be read.
    pub title: String,
}
/// Snapshot of the session returned by [`status`].
#[derive(Debug, Serialize, JsonSchema)]
pub struct StatusOutput {
    /// Whether the session currently holds a browser.
    pub open: bool,
    /// Number of tabs the browser reports; `0` when no browser is open.
    pub tab_count: usize,
    /// The session's current tab, or `None` if none is selected or it no longer exists.
    pub current_tab: Option<TabSummary>,
    /// Inspector URL of the current tab when available; omitted from the JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub inspector_url: Option<String>,
    /// The session's stealth profile choice, reported even when no browser is open.
    pub profile: StealthProfileChoice,
}
/// Reports whether a browser is open and, if so, describes the current tab.
///
/// Always returns `Ok`. Per-tab lookups that fail degrade instead of
/// propagating: URL and title become empty strings, and the inspector URL
/// becomes `None`. If the session's current tab id matches no live tab,
/// `current_tab` and `inspector_url` are both `None`.
pub async fn status(
    state: Arc<Mutex<SessionState>>,
    _: EmptyInput,
) -> Result<StatusOutput, ErrorData> {
    let session = state.lock().await;
    let profile = session.stealth_profile_choice;

    let browser = match session.browser.as_ref() {
        Some(browser) => browser,
        None => {
            return Ok(StatusOutput {
                open: false,
                tab_count: 0,
                current_tab: None,
                inspector_url: None,
                profile,
            });
        }
    };

    let tabs = browser.tabs().await;
    let selected = session
        .current_tab_id
        .as_ref()
        .and_then(|wanted| tabs.iter().find(|tab| tab.target_id() == wanted));

    let (current_tab, inspector_url) = match selected {
        None => (None, None),
        Some(tab) => {
            let url = tab.url().await.map(|u| u.to_string()).unwrap_or_default();
            let title = tab.title().await.unwrap_or_default();
            let inspector = tab.inspector_url().ok();
            let summary = TabSummary {
                id: tab.target_id().to_string(),
                url,
                title,
            };
            (Some(summary), inspector)
        }
    };

    Ok(StatusOutput {
        open: true,
        tab_count: tabs.len(),
        current_tab,
        inspector_url,
        profile,
    })
}
#[cfg(test)]
#[allow(clippy::panic, clippy::unwrap_used)]
mod tests {
    // Tests for the session lifecycle tools and the stealth-override helper.
    // Every test starts from `SessionState::new()`, and none of them calls `open`.
    use super::*;

    // `close` on a fresh session must still succeed.
    #[tokio::test]
    async fn close_with_no_browser_is_noop() {
        let state = Arc::new(Mutex::new(SessionState::new()));
        let out = close(state, EmptyInput {}).await.expect("close ok");
        assert!(out.ok);
    }

    // A fresh session reports closed with no tabs. This also pins the default
    // profile choice of `SessionState::new()` to `Auto`.
    #[tokio::test]
    async fn status_with_no_browser_reports_closed() {
        let state = Arc::new(Mutex::new(SessionState::new()));
        let out = status(state, EmptyInput {}).await.expect("status ok");
        assert!(!out.open);
        assert_eq!(out.tab_count, 0);
        assert!(out.current_tab.is_none());
        assert_eq!(out.profile, StealthProfileChoice::Auto);
    }

    // `close` must empty the expectations map and abort the tasks behind it.
    #[cfg(feature = "expect")]
    #[tokio::test]
    async fn close_drains_and_aborts_expectations() {
        use crate::state::ExpectationHandle;
        let state = Arc::new(Mutex::new(SessionState::new()));
        // The sender is bound to a name (not `_`) so it lives until the end of the
        // test, which means `rx` is never observed as closed.
        let (_tx_keep_alive, rx) =
            tokio::sync::oneshot::channel::<Result<serde_json::Value, String>>();
        // This task never completes on its own, so it can only finish by being aborted.
        let task = tokio::spawn(async move {
            std::future::pending::<()>().await;
        });
        // Despite the name, this is an `AbortHandle`; it is used only for `is_finished`.
        let join_handle_for_check = task.abort_handle();
        {
            let mut s = state.lock().await;
            s.expectations.insert(
                "test-id".into(),
                ExpectationHandle {
                    kind: "request",
                    task,
                    rx,
                },
            );
            assert_eq!(s.expectations.len(), 1, "precondition: expectation present");
        }
        let out = close(state.clone(), EmptyInput {}).await.expect("close ok");
        assert!(out.ok);
        let s = state.lock().await;
        assert!(
            s.expectations.is_empty(),
            "expectations map must be empty after close (was: {})",
            s.expectations.len(),
        );
        drop(s);
        // Give the runtime a chance to process the cancellation.
        // NOTE(review): this assumes one `yield_now` is enough for the aborted task
        // to be marked finished. That may be timing-sensitive; confirm.
        tokio::task::yield_now().await;
        assert!(
            join_handle_for_check.is_finished(),
            "expectation task should be aborted (and thus finished) after close",
        );
    }

    // With `geo_country = "US"` and no explicit locale, the country-derived
    // locale should appear as a `--lang=en-US` launch flag.
    #[cfg(feature = "geo")]
    #[test]
    fn geo_country_sets_locale_and_languages() {
        let overrides = StealthOverrides {
            geo_country: Some("US".into()),
            ..Default::default()
        };
        let profile = apply_overrides(StealthProfile::native(), &overrides);
        let flags = profile.build_flags();
        assert!(
            flags.iter().any(|f| f == "--lang=en-US"),
            "expected --lang=en-US in flags: {flags:?}",
        );
    }

    // `close` must drop every registered interception rule. The `for_tests()`
    // handles stand in for real interception registrations.
    #[cfg(feature = "interception")]
    #[tokio::test]
    async fn close_clears_interception_rules() {
        use crate::state::InterceptRuleHandle;
        let state = Arc::new(Mutex::new(SessionState::new()));
        {
            let mut s = state.lock().await;
            s.rules.insert(
                "rule-a".into(),
                InterceptRuleHandle {
                    pattern: "*/ads/*".into(),
                    action_kind: "block",
                    _handle: zendriver_interception::InterceptHandle::for_tests(),
                },
            );
            s.rules.insert(
                "rule-b".into(),
                InterceptRuleHandle {
                    pattern: "*/api/*".into(),
                    action_kind: "respond",
                    _handle: zendriver_interception::InterceptHandle::for_tests(),
                },
            );
            assert_eq!(s.rules.len(), 2, "precondition: rules present");
        }
        let out = close(state.clone(), EmptyInput {}).await.expect("close ok");
        assert!(out.ok);
        let s = state.lock().await;
        assert!(
            s.rules.is_empty(),
            "rules map must be empty after close (was: {})",
            s.rules.len(),
        );
    }
}