#![cfg(feature = "interception")]
use std::collections::BTreeMap;
use std::sync::Arc;
use rmcp::ErrorData;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use zendriver::{RequestOverrides, ResponseInfo, ResponseOverrides, ZendriverError};
use crate::errors::{McpServerError, map_error};
use crate::state::{InterceptRuleHandle, RuleId, SessionState};
use crate::tools::common::current_tab;
/// What an interception rule does with requests that match the rule's `pattern`.
///
/// Deserialized from an internally tagged object: the `kind` field selects the
/// variant (`block`, `redirect`, `respond`, `modify_request`, `modify_response`)
/// and the remaining fields are that variant's parameters.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields, tag = "kind", rename_all = "snake_case")]
pub enum InterceptAction {
    /// Block matching requests.
    Block,
    /// Redirect matching requests.
    Redirect {
        /// Where matching requests are redirected to.
        to: String,
    },
    /// Fulfil matching requests with the canned response described here.
    Respond {
        /// Status code of the canned response.
        status: u16,
        /// Response body; its UTF-8 bytes are sent as-is.
        body: String,
        /// Optional value for the `content-type` response header.
        #[serde(default)]
        content_type: Option<String>,
        /// Extra response headers, as name -> value.
        #[serde(default)]
        headers: BTreeMap<String, String>,
    },
    /// Let matching requests continue with added or replaced request headers.
    ModifyRequest {
        /// Headers to set; an existing request header with the same name
        /// (compared case-insensitively) is replaced.
        #[serde(default)]
        headers: BTreeMap<String, String>,
    },
    /// Alter the response of matching requests.
    ModifyResponse {
        /// If set, replaces the response status code.
        #[serde(default)]
        status: Option<u16>,
        /// Headers to set; an existing response header with the same name
        /// (compared case-insensitively) is replaced. Leave empty to keep the
        /// response headers unchanged.
        #[serde(default)]
        headers: BTreeMap<String, String>,
    },
}
/// Input for adding an interception rule to the current tab.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct AddRuleInput {
    /// Request pattern the rule applies to; handed to the zendriver interception builder.
    pub pattern: String,
    /// What to do with requests matching `pattern`.
    pub action: InterceptAction,
}
/// Result of adding an interception rule.
#[derive(Debug, Serialize, JsonSchema)]
pub struct AddRuleOutput {
    /// Freshly generated UUID identifying the new rule; use it to remove the rule later.
    pub rule_id: RuleId,
}
pub async fn add_rule(
state: Arc<Mutex<SessionState>>,
input: AddRuleInput,
) -> Result<AddRuleOutput, ErrorData> {
let mut s = state.lock().await;
let tab = current_tab(&s).await?;
let (handle, action_kind): (zendriver::InterceptHandle, &'static str) = match input.action {
InterceptAction::Block => {
let h = tab
.intercept()
.block(input.pattern.clone())
.map_err(zendriver_err)?
.start();
(h, "block")
}
InterceptAction::Redirect { to } => {
let h = tab
.intercept()
.redirect(input.pattern.clone(), to)
.map_err(zendriver_err)?
.start();
(h, "redirect")
}
InterceptAction::Respond {
status,
body,
content_type,
mut headers,
} => {
if let Some(ct) = content_type {
headers.insert("content-type".into(), ct);
}
let header_vec: Vec<(String, String)> = headers.into_iter().collect();
let h = tab
.intercept()
.respond(input.pattern.clone(), status, header_vec, body.into_bytes())
.map_err(zendriver_err)?
.start();
(h, "respond")
}
InterceptAction::ModifyRequest { headers } => {
let overlay = Arc::new(headers);
let h = tab
.intercept()
.modify_request(input.pattern.clone(), move |req| {
merge_headers(&req.headers, &overlay)
})
.map_err(zendriver_err)?
.start();
(h, "modify_request")
}
InterceptAction::ModifyResponse { status, headers } => {
let overlay = Arc::new(headers);
let h = tab
.intercept()
.modify_response(input.pattern.clone(), move |resp: &ResponseInfo| {
let headers = if overlay.is_empty() {
None
} else {
Some(merge_header_list(&resp.headers, &overlay))
};
ResponseOverrides {
status,
headers,
..ResponseOverrides::default()
}
})
.map_err(zendriver_err)?
.start();
(h, "modify_response")
}
};
let id: RuleId = uuid::Uuid::new_v4().to_string();
s.rules.insert(
id.clone(),
InterceptRuleHandle {
pattern: input.pattern,
action_kind,
_handle: handle,
},
);
Ok(AddRuleOutput { rule_id: id })
}
/// Builds [`RequestOverrides`] that apply `overlay` on top of the request's
/// `original` headers, using the case-insensitive replacement rules of
/// `merge_header_list`.
///
/// An empty `overlay` leaves `headers` as `None`, so no header override is
/// requested. This matches the `modify_response` path, which also only overrides
/// headers when there is something to overlay, instead of re-sending the observed
/// header list verbatim.
fn merge_headers(
    original: &[(String, String)],
    overlay: &BTreeMap<String, String>,
) -> RequestOverrides {
    let headers = (!overlay.is_empty()).then(|| merge_header_list(original, overlay));
    RequestOverrides {
        headers,
        ..RequestOverrides::default()
    }
}
/// Overlays `overlay` onto `original` header pairs.
///
/// Every original header whose name matches an overlay key (compared
/// case-insensitively) is dropped. The surviving originals keep their order and are
/// followed by all overlay entries in the map's (sorted) key order.
fn merge_header_list(
    original: &[(String, String)],
    overlay: &BTreeMap<String, String>,
) -> Vec<(String, String)> {
    // Lower-cased names the overlay sets; originals with any of these names are replaced.
    let replaced: std::collections::HashSet<String> = overlay
        .keys()
        .map(|name| name.to_ascii_lowercase())
        .collect();
    let mut merged: Vec<(String, String)> = Vec::with_capacity(original.len() + overlay.len());
    merged.extend(
        original
            .iter()
            .filter(|(name, _)| !replaced.contains(&name.to_ascii_lowercase()))
            .cloned(),
    );
    merged.extend(
        overlay
            .iter()
            .map(|(name, value)| (name.clone(), value.clone())),
    );
    merged
}
/// Converts an interception-builder error into the MCP [`ErrorData`] returned to clients.
///
/// The error is routed through `ZendriverError` and then `McpServerError`, so it
/// receives the same `map_error` treatment as other server errors.
fn zendriver_err(e: zendriver::InterceptionError) -> ErrorData {
    let driver_error = ZendriverError::from(e);
    let server_error = McpServerError::from(driver_error);
    map_error(server_error)
}
/// Input for removing a previously added interception rule.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct RemoveRuleInput {
    /// Identifier returned when the rule was added.
    pub rule_id: RuleId,
}
/// Result of removing an interception rule.
#[derive(Debug, Serialize, JsonSchema)]
pub struct RemoveRuleOutput {
    /// Always `true` on success; an unknown `rule_id` is reported as an error instead.
    pub removed: bool,
}
/// Removes the interception rule identified by `input.rule_id` from the session.
///
/// The removed `InterceptRuleHandle` is dropped while the session lock is still
/// held. Presumably the handle's `Drop` is what stops the interception; confirm in
/// zendriver.
///
/// # Errors
///
/// Returns the mapped `McpServerError::RuleNotFound` error when no rule has that id.
pub async fn remove_rule(
    state: Arc<Mutex<SessionState>>,
    input: RemoveRuleInput,
) -> Result<RemoveRuleOutput, ErrorData> {
    let mut guard = state.lock().await;
    let removed = guard.rules.remove(&input.rule_id);
    match removed {
        Some(_handle) => Ok(RemoveRuleOutput { removed: true }),
        None => Err(map_error(McpServerError::RuleNotFound(input.rule_id))),
    }
}
/// Result of listing interception rules.
#[derive(Debug, Serialize, JsonSchema)]
pub struct ListRulesOutput {
    /// All rules registered in the session, sorted by `rule_id`.
    pub rules: Vec<RuleSummary>,
}
/// Summary of one registered interception rule.
#[derive(Debug, Serialize, JsonSchema)]
pub struct RuleSummary {
    /// Identifier of the rule.
    pub rule_id: RuleId,
    /// Pattern the rule was added with.
    pub pattern: String,
    /// Action the rule performs: one of `block`, `redirect`, `respond`,
    /// `modify_request`, `modify_response`.
    pub action_kind: String,
}
/// Lists every interception rule registered in the session, ordered by rule id.
pub async fn list_rules(
    state: Arc<Mutex<SessionState>>,
    _: crate::tools::common::EmptyInput,
) -> Result<ListRulesOutput, ErrorData> {
    let guard = state.lock().await;
    let mut summaries = Vec::with_capacity(guard.rules.len());
    for (rule_id, handle) in guard.rules.iter() {
        summaries.push(RuleSummary {
            rule_id: rule_id.clone(),
            pattern: handle.pattern.clone(),
            action_kind: handle.action_kind.to_string(),
        });
    }
    // Rule ids are unique map keys, so an unstable sort yields a deterministic order.
    summaries.sort_unstable_by(|left, right| left.rule_id.cmp(&right.rule_id));
    Ok(ListRulesOutput { rules: summaries })
}
/// Result of clearing all interception rules.
#[derive(Debug, Serialize, JsonSchema)]
pub struct ClearRulesOutput {
    /// Number of rules that were removed.
    pub cleared: usize,
}
/// Removes every interception rule from the session and reports how many there were.
pub async fn clear_rules(
    state: Arc<Mutex<SessionState>>,
    _: crate::tools::common::EmptyInput,
) -> Result<ClearRulesOutput, ErrorData> {
    let mut guard = state.lock().await;
    let rules = &mut guard.rules;
    let cleared = rules.len();
    rules.clear();
    Ok(ClearRulesOutput { cleared })
}
#[cfg(test)]
#[allow(clippy::panic, clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::tools::common::EmptyInput;

    /// A brand-new session with no browser and no rules.
    fn empty_session() -> Arc<Mutex<SessionState>> {
        Arc::new(Mutex::new(SessionState::new()))
    }

    /// Converts borrowed `(name, value)` pairs into owned header tuples.
    fn owned_pairs(pairs: &[(&str, &str)]) -> Vec<(String, String)> {
        pairs
            .iter()
            .map(|(name, value)| ((*name).to_string(), (*value).to_string()))
            .collect()
    }

    #[tokio::test]
    async fn list_rules_empty_returns_empty_vec() {
        let listed = list_rules(empty_session(), EmptyInput {})
            .await
            .expect("list_rules ok");
        assert!(listed.rules.is_empty());
    }

    #[tokio::test]
    async fn clear_rules_empty_returns_zero() {
        let outcome = clear_rules(empty_session(), EmptyInput {})
            .await
            .expect("clear_rules ok");
        assert_eq!(outcome.cleared, 0);
    }

    #[tokio::test]
    async fn remove_unknown_rule_surfaces_rule_not_found() {
        let request = RemoveRuleInput {
            rule_id: "nope".into(),
        };
        let failure = remove_rule(empty_session(), request)
            .await
            .expect_err("expected RuleNotFound");
        let payload = failure.data.as_ref().expect("data populated");
        assert_eq!(payload["suggested_next"], "browser_intercept_add_rule");
    }

    #[tokio::test]
    async fn add_rule_with_no_browser_errors() {
        let request = AddRuleInput {
            pattern: "*".into(),
            action: InterceptAction::Block,
        };
        let failure = add_rule(empty_session(), request)
            .await
            .expect_err("expected BrowserNotOpen");
        assert!(failure.message.contains("Browser not open"));
    }

    #[test]
    fn merge_headers_overlays_case_insensitively_and_drops_originals() {
        let original = owned_pairs(&[
            ("Host", "example.com"),
            ("User-Agent", "old"),
            ("Accept", "*/*"),
        ]);
        let overlay: BTreeMap<String, String> =
            owned_pairs(&[("user-agent", "new"), ("X-Marker", "yes")])
                .into_iter()
                .collect();
        let merged = merge_headers(&original, &overlay)
            .headers
            .expect("headers populated");
        let lowered: Vec<String> = merged
            .iter()
            .map(|(name, _)| name.to_ascii_lowercase())
            .collect();
        for wanted in ["host", "accept", "user-agent", "x-marker"] {
            assert!(
                lowered.iter().any(|name| name.as_str() == wanted),
                "missing {}",
                wanted
            );
        }
        let ua_count = lowered
            .iter()
            .filter(|name| name.as_str() == "user-agent")
            .count();
        assert_eq!(ua_count, 1);
        let ua = merged
            .iter()
            .find_map(|(name, value)| {
                name.eq_ignore_ascii_case("user-agent")
                    .then(|| value.as_str())
            })
            .expect("user-agent present");
        assert_eq!(ua, "new");
    }
}