use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use tokio::sync::oneshot;
use crate::hooks::{HookContext, HookDecision, HookEvent};
use crate::{Error, Result};
/// Protocol version number sent as the `v` field of the `Hello` message
/// by `BridgeHookAdapter::handshake`.
pub const HOOK_BRIDGE_PROTOCOL_VERSION: u32 = 1;
/// Serializable snapshot of a `HookContext`, embedded in `HookInvoke` messages.
///
/// Built by `WireContext::from_context`; URL-like values are carried as strings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WireContext {
/// The context URL, rendered with `to_string()`.
pub url: String,
/// Copied from the context's `depth`.
pub depth: u32,
/// Copied from the context's `response_status`.
pub response_status: Option<u16>,
/// Response headers as plain strings; header values whose `to_str()`
/// conversion fails are left out of the map.
pub response_headers: Option<HashMap<String, String>>,
/// `true` when the context's `html_post_js` is set. The HTML itself is not
/// part of this struct.
pub html_present: bool,
/// `len()` of the context body, when a body is present.
pub body_size: Option<usize>,
/// The context's captured URLs, each rendered with `to_string()`.
pub captured_urls: Vec<String>,
/// The context's proxy rendered with `to_string()`, if any.
pub proxy: Option<String>,
/// Copied from the context's `retry_count`.
pub retry_count: u32,
/// Copied from the context's `allow_retry`.
pub allow_retry: bool,
/// Copied from the context's `robots_allowed`.
pub robots_allowed: Option<bool>,
/// Clone of the context's `user_data` map.
pub user_data: HashMap<String, serde_json::Value>,
/// Clone of the context's `error`.
pub error: Option<String>,
}
impl WireContext {
pub fn from_context(ctx: &HookContext) -> Self {
Self {
url: ctx.url.to_string(),
depth: ctx.depth,
response_status: ctx.response_status,
response_headers: ctx.response_headers.as_ref().map(|h| {
h.iter()
.filter_map(|(k, v)| v.to_str().ok().map(|s| (k.to_string(), s.to_string())))
.collect()
}),
html_present: ctx.html_post_js.is_some(),
body_size: ctx.body.as_ref().map(|b| b.len()),
captured_urls: ctx.captured_urls.iter().map(|u| u.to_string()).collect(),
proxy: ctx.proxy.as_ref().map(|u| u.to_string()),
retry_count: ctx.retry_count,
allow_retry: ctx.allow_retry,
robots_allowed: ctx.robots_allowed,
user_data: ctx.user_data.clone(),
error: ctx.error.clone(),
}
}
}
/// Returns the snake_case name used on the wire for `event`.
///
/// `event_from_wire` performs the reverse lookup.
pub fn event_wire_name(event: HookEvent) -> &'static str {
    let wire_name: &'static str = match event {
        // Job-level events.
        HookEvent::OnJobStart => "on_job_start",
        HookEvent::OnJobEnd => "on_job_end",
        HookEvent::OnError => "on_error",
        // Request / response events.
        HookEvent::BeforeEachRequest => "before_each_request",
        HookEvent::AfterDnsResolve => "after_dns_resolve",
        HookEvent::AfterTlsHandshake => "after_tls_handshake",
        HookEvent::AfterFirstByte => "after_first_byte",
        HookEvent::OnResponseBody => "on_response_body",
        // Page and crawl events.
        HookEvent::AfterLoad => "after_load",
        HookEvent::AfterIdle => "after_idle",
        HookEvent::OnDiscovery => "on_discovery",
        HookEvent::OnRobotsDecision => "on_robots_decision",
    };
    wire_name
}
/// Messages this side sends through `BridgeChannel::send`.
///
/// Serialized with an internal `kind` tag in snake_case, i.e. `"hello"` and
/// `"hook_invoke"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum BridgeOutbound {
/// Greeting sent by `BridgeHookAdapter::handshake`, carrying the protocol
/// version `v` and a protocol name.
Hello { v: u32, protocol: String },
/// Request to run a hook, sent by `BridgeHookAdapter::invoke`.
HookInvoke {
/// Correlation id; the matching `HookResult` carries the same id.
id: u64,
/// Event name as produced by `event_wire_name`.
event: String,
/// Snapshot of the hook context at invocation time.
ctx: WireContext,
},
}
/// Decision carried in a `HookResult`, serialized as a snake_case string
/// (`"continue"`, `"skip"`, `"retry"`, `"abort"`).
///
/// Each variant converts to the `HookDecision` variant of the same name.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum WireDecision {
Continue,
Skip,
Retry,
Abort,
}
impl From<WireDecision> for HookDecision {
    /// Maps each wire decision onto the hook decision of the same name.
    fn from(wire: WireDecision) -> Self {
        match wire {
            WireDecision::Continue => Self::Continue,
            WireDecision::Skip => Self::Skip,
            WireDecision::Retry => Self::Retry,
            WireDecision::Abort => Self::Abort,
        }
    }
}
/// Optional changes to a `HookContext`, returned alongside a decision.
///
/// Every field is optional: a missing field deserializes to `None`, and
/// `None` fields are omitted when serializing. `ContextPatch::apply` only
/// touches the context fields whose patch value is `Some`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ContextPatch {
/// Replacement list of captured URLs, as strings.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub captured_urls: Option<Vec<String>>,
/// Replacement for the whole `user_data` map.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub user_data: Option<HashMap<String, serde_json::Value>>,
/// New value for `robots_allowed`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub robots_allowed: Option<bool>,
/// New value for `allow_retry`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub allow_retry: Option<bool>,
}
impl ContextPatch {
pub fn apply(self, ctx: &mut HookContext) {
if let Some(urls) = self.captured_urls {
ctx.captured_urls = urls
.into_iter()
.filter_map(|u| url::Url::parse(&u).ok())
.collect();
}
if let Some(ud) = self.user_data {
ctx.user_data = ud;
}
if let Some(r) = self.robots_allowed {
ctx.robots_allowed = Some(r);
}
if let Some(r) = self.allow_retry {
ctx.allow_retry = r;
}
}
}
/// Messages received through `BridgeChannel::recv`.
///
/// Serialized with an internal `kind` tag in snake_case, i.e. `"subscribe"`
/// and `"hook_result"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum BridgeInbound {
/// Event names the peer wants to receive. `BridgeHookAdapter::pump_once`
/// replaces the current subscription list with these, ignoring names that
/// `event_from_wire` does not recognise.
Subscribe { subscribed: Vec<String> },
/// Reply to a `HookInvoke` with the same `id`.
HookResult {
id: u64,
decision: WireDecision,
/// Context changes to apply; defaults to an empty patch when omitted.
#[serde(default)]
patch: ContextPatch,
},
}
/// Transport used by `BridgeHookAdapter` to exchange bridge messages.
///
/// Implementations must be `Send + Sync`; the adapter holds the channel as
/// `Arc<dyn BridgeChannel>`.
#[async_trait::async_trait]
pub trait BridgeChannel: Send + Sync {
/// Sends one outbound message.
async fn send(&self, msg: &BridgeOutbound) -> Result<()>;
/// Receives the next inbound message.
async fn recv(&self) -> Result<BridgeInbound>;
}
/// In-flight invocations: invocation id -> sender that delivers the matching
/// `HookResult` payload to the waiting `invoke` call.
type Pending = HashMap<u64, oneshot::Sender<(WireDecision, ContextPatch)>>;
/// Request/reply adapter that forwards hook events over a `BridgeChannel`.
///
/// `invoke` sends a `HookInvoke` and waits for the reply; `pump_once` reads
/// inbound messages and routes each `HookResult` to the waiting invocation.
pub struct BridgeHookAdapter {
/// Transport used for both directions.
channel: Arc<dyn BridgeChannel>,
/// Invocations that are still waiting for a `HookResult`, keyed by id.
pending: Arc<Mutex<Pending>>,
/// Source of invocation ids; starts at 1 and is incremented per invocation.
next_id: AtomicU64,
/// Events from the most recent `Subscribe` message.
subscribed: parking_lot::RwLock<Vec<HookEvent>>,
/// Maximum time `invoke` waits for a reply.
timeout: std::time::Duration,
}
impl BridgeHookAdapter {
pub fn new(channel: Arc<dyn BridgeChannel>) -> Self {
Self {
channel,
pending: Arc::new(Mutex::new(HashMap::new())),
next_id: AtomicU64::new(1),
subscribed: parking_lot::RwLock::new(Vec::new()),
timeout: std::time::Duration::from_secs(5),
}
}
pub fn with_timeout(mut self, d: std::time::Duration) -> Self {
self.timeout = d;
self
}
pub async fn handshake(&self) -> Result<()> {
self.channel
.send(&BridgeOutbound::Hello {
v: HOOK_BRIDGE_PROTOCOL_VERSION,
protocol: "crawlex.hooks".into(),
})
.await
}
pub async fn pump_once(&self) -> Result<()> {
let msg = self.channel.recv().await?;
match msg {
BridgeInbound::Subscribe { subscribed } => {
let parsed: Vec<HookEvent> = subscribed
.into_iter()
.filter_map(|s| event_from_wire(&s))
.collect();
*self.subscribed.write() = parsed;
}
BridgeInbound::HookResult {
id,
decision,
patch,
} => {
if let Some(tx) = self.pending.lock().remove(&id) {
let _ = tx.send((decision, patch));
} else {
tracing::debug!(id, "hook bridge: no pending entry for id");
}
}
}
Ok(())
}
pub async fn invoke(&self, event: HookEvent, ctx: &mut HookContext) -> Result<HookDecision> {
let subscribed_to = self.subscribed.read().contains(&event);
if !subscribed_to {
return Ok(HookDecision::Continue);
}
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
let (tx, rx) = oneshot::channel();
self.pending.lock().insert(id, tx);
self.channel
.send(&BridgeOutbound::HookInvoke {
id,
event: event_wire_name(event).to_string(),
ctx: WireContext::from_context(ctx),
})
.await?;
let resp = tokio::time::timeout(self.timeout, rx)
.await
.map_err(|_| {
self.pending.lock().remove(&id);
Error::Hook(format!(
"hook bridge timeout: event={} id={id} budget_ms={}",
event_wire_name(event),
self.timeout.as_millis()
))
})?
.map_err(|_| Error::Hook("hook bridge channel closed before reply".into()))?;
let (decision, patch) = resp;
patch.apply(ctx);
Ok(decision.into())
}
}
/// `BridgeChannel` over the process's stdin/stdout, one JSON message per line.
///
/// Each half sits behind its own async mutex, so sends and receives are
/// serialized independently of each other.
pub struct StdioBridgeChannel {
/// Buffered stdout; flushed after every message by `send`.
writer: tokio::sync::Mutex<tokio::io::BufWriter<tokio::io::Stdout>>,
/// Buffered stdin, read one line at a time by `recv`.
reader: tokio::sync::Mutex<tokio::io::BufReader<tokio::io::Stdin>>,
}
impl Default for StdioBridgeChannel {
fn default() -> Self {
Self::new()
}
}
impl StdioBridgeChannel {
    /// Creates a channel bound to this process's stdout and stdin, each
    /// wrapped in a buffer and an async mutex.
    pub fn new() -> Self {
        let buffered_out = tokio::io::BufWriter::new(tokio::io::stdout());
        let buffered_in = tokio::io::BufReader::new(tokio::io::stdin());
        Self {
            writer: tokio::sync::Mutex::new(buffered_out),
            reader: tokio::sync::Mutex::new(buffered_in),
        }
    }
}
#[async_trait::async_trait]
impl BridgeChannel for StdioBridgeChannel {
async fn send(&self, msg: &BridgeOutbound) -> Result<()> {
use tokio::io::AsyncWriteExt;
let line =
serde_json::to_string(msg).map_err(|e| Error::Hook(format!("serialize: {e}")))?;
let mut w = self.writer.lock().await;
w.write_all(line.as_bytes())
.await
.map_err(|e| Error::Hook(format!("stdout write: {e}")))?;
w.write_all(b"\n")
.await
.map_err(|e| Error::Hook(format!("stdout newline: {e}")))?;
w.flush()
.await
.map_err(|e| Error::Hook(format!("stdout flush: {e}")))?;
Ok(())
}
async fn recv(&self) -> Result<BridgeInbound> {
use tokio::io::AsyncBufReadExt;
let mut buf = String::new();
let mut r = self.reader.lock().await;
let n = r
.read_line(&mut buf)
.await
.map_err(|e| Error::Hook(format!("stdin read: {e}")))?;
if n == 0 {
return Err(Error::Hook("stdin closed (EOF)".into()));
}
serde_json::from_str(buf.trim())
.map_err(|e| Error::Hook(format!("stdin parse: {e} (line={buf:?})")))
}
}
pub fn parse_bridge_spec(spec: &str) -> Result<Arc<dyn BridgeChannel>> {
match spec.trim() {
"stdio" => Ok(Arc::new(StdioBridgeChannel::new())),
other => Err(Error::Hook(format!(
"unsupported --hook-bridge spec: {other:?} (expected `stdio`)"
))),
}
}
/// Parses a wire event name back into a `HookEvent`.
///
/// This is the inverse of `event_wire_name`. The match is exact, and any
/// unrecognised name yields `None`.
pub fn event_from_wire(s: &str) -> Option<HookEvent> {
    match s {
        "before_each_request" => Some(HookEvent::BeforeEachRequest),
        "after_dns_resolve" => Some(HookEvent::AfterDnsResolve),
        "after_tls_handshake" => Some(HookEvent::AfterTlsHandshake),
        "after_first_byte" => Some(HookEvent::AfterFirstByte),
        "on_response_body" => Some(HookEvent::OnResponseBody),
        "after_load" => Some(HookEvent::AfterLoad),
        "after_idle" => Some(HookEvent::AfterIdle),
        "on_discovery" => Some(HookEvent::OnDiscovery),
        "on_job_start" => Some(HookEvent::OnJobStart),
        "on_job_end" => Some(HookEvent::OnJobEnd),
        "on_error" => Some(HookEvent::OnError),
        "on_robots_decision" => Some(HookEvent::OnRobotsDecision),
        _ => None,
    }
}
// Unit tests for `BridgeHookAdapter`, driven through an in-memory channel
// instead of stdio.
#[cfg(test)]
mod tests {
use super::*;
use tokio::sync::mpsc;
use tokio::sync::Mutex as TokioMutex;
// In-memory `BridgeChannel`: outbound messages are pushed to an unbounded
// mpsc sender the test can inspect, and inbound messages are pulled from an
// unbounded mpsc receiver the test feeds.
struct ChannelPair {
outbound_tx: mpsc::UnboundedSender<BridgeOutbound>,
inbound_rx: TokioMutex<mpsc::UnboundedReceiver<BridgeInbound>>,
}
#[async_trait::async_trait]
impl BridgeChannel for ChannelPair {
// Clones the message into the outbound queue.
async fn send(&self, msg: &BridgeOutbound) -> Result<()> {
self.outbound_tx
.send(msg.clone())
.map_err(|e| Error::Hook(format!("test send: {e}")))
}
// Waits for the next queued inbound message; errors once the sender side
// has been dropped and the queue is empty.
async fn recv(&self) -> Result<BridgeInbound> {
self.inbound_rx
.lock()
.await
.recv()
.await
.ok_or_else(|| Error::Hook("test channel closed".into()))
}
}
// Fresh context for a fixed test URL at depth 0.
fn ctx() -> HookContext {
HookContext::new(url::Url::parse("https://example.test/p").unwrap(), 0)
}
// With no `Subscribe` processed, `invoke` must return `Continue` and must
// not send anything on the channel.
#[tokio::test]
async fn invoke_short_circuits_unsubscribed_events() {
let (out_tx, mut out_rx) = mpsc::unbounded_channel();
let (_in_tx, in_rx) = mpsc::unbounded_channel();
let channel = Arc::new(ChannelPair {
outbound_tx: out_tx,
inbound_rx: TokioMutex::new(in_rx),
});
let adapter = BridgeHookAdapter::new(channel);
let mut cx = ctx();
let d = adapter
.invoke(HookEvent::AfterFirstByte, &mut cx)
.await
.unwrap();
assert_eq!(d, HookDecision::Continue);
// Nothing was written to the outbound queue.
assert!(out_rx.try_recv().is_err());
}
// Full round trip: subscribe, invoke, reply with a patch, then check that
// both the decision and the patch reached the caller.
#[tokio::test]
async fn invoke_round_trip_applies_patch_and_decision() {
let (out_tx, mut out_rx) = mpsc::unbounded_channel();
let (in_tx, in_rx) = mpsc::unbounded_channel();
let channel = Arc::new(ChannelPair {
outbound_tx: out_tx,
inbound_rx: TokioMutex::new(in_rx),
});
let adapter = Arc::new(BridgeHookAdapter::new(channel));
// Queue the subscription before the pump task starts reading.
in_tx
.send(BridgeInbound::Subscribe {
subscribed: vec!["on_discovery".into()],
})
.unwrap();
let pump = adapter.clone();
// Background pump: processes up to 10 inbound messages, stopping early on
// the first error.
tokio::spawn(async move {
for _ in 0..10 {
if pump.pump_once().await.is_err() {
break;
}
}
});
let invoke_adapter = adapter.clone();
// Run `invoke` in its own task so this test body can play the peer.
let invoke_task = tokio::spawn(async move {
let mut cx = ctx();
// Pre-existing captured URL that the patch below is expected to replace.
cx.captured_urls
.push(url::Url::parse("https://example.test/keep").unwrap());
let d = invoke_adapter
.invoke(HookEvent::OnDiscovery, &mut cx)
.await
.unwrap();
(d, cx)
});
// Poll the outbound queue until the first message appears.
let outbound = loop {
if let Ok(msg) = out_rx.try_recv() {
break msg;
}
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
};
// Extract the correlation id so the reply can target this invocation.
let id = match outbound {
BridgeOutbound::HookInvoke { id, .. } => id,
_ => panic!("expected hook.invoke first"),
};
// Reply as the peer: skip, swap the captured URLs, and disallow robots.
in_tx
.send(BridgeInbound::HookResult {
id,
decision: WireDecision::Skip,
patch: ContextPatch {
captured_urls: Some(vec!["https://example.test/swap".into()]),
robots_allowed: Some(false),
..Default::default()
},
})
.unwrap();
let (decision, cx) = invoke_task.await.unwrap();
assert_eq!(decision, HookDecision::Skip);
// The patch replaced the captured URL list rather than appending to it.
assert_eq!(cx.captured_urls.len(), 1);
assert_eq!(cx.captured_urls[0].path(), "/swap");
assert_eq!(cx.robots_allowed, Some(false));
}
}