use async_trait::async_trait;
use bytes::Bytes;
use http_body_util::{BodyExt, Limited};
use hyper::{Request, Response};
use proxyapi_models::{ProxiedRequest, ProxiedResponse};
use std::sync::Arc;
use tokio::sync::mpsc;
use crate::body::{self, ProxyBody};
use crate::event::{next_id, ProxyEvent};
use crate::intercept::{InterceptConfig, InterceptDecision};
use crate::{HttpContext, HttpHandler, RequestOrResponse};
/// Upper bound (100 MiB) on how many body bytes are buffered in memory per request or
/// response; bodies that exceed it fail collection and are replaced by an empty body.
const MAX_BODY_SIZE: usize = 100 << 20;
/// Current wall-clock time as milliseconds since the Unix epoch, used to timestamp
/// captured requests and responses.
///
/// # Panics
/// If the system clock reports a time before the Unix epoch (same message chrono uses).
fn now_millis() -> i64 {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .expect("system time before Unix epoch");
    // A u128 millisecond count only exceeds i64 hundreds of millions of years from now.
    i64::try_from(since_epoch.as_millis()).unwrap_or(i64::MAX)
}
/// Finalises a buffered upstream response: optionally runs the response script,
/// emits a `RequestComplete` event pairing it with the captured request, and rebuilds
/// the response that is returned to the client.
///
/// * With the `scripting` feature and an engine installed, the script may replace the
///   status, headers and body; script errors are logged and the response passes through.
/// * The event is emitted only if `handler` still holds a captured request; its id is
///   the one reserved in `handle_request`, or a fresh one if none was reserved.
pub fn collect_and_emit(
    handler: &mut CapturingHandler,
    // `mut` is only exercised when the `scripting` feature rewrites the response;
    // without that feature the bindings are never reassigned, hence the allowance.
    #[allow(unused_mut)] mut parts: http::response::Parts,
    #[allow(unused_mut)] mut body_bytes: Bytes,
) -> Response<ProxyBody> {
    #[cfg(feature = "scripting")]
    if let Some(ref engine) = handler.script_engine {
        let (req_method, req_url) = handler
            .captured_request
            .as_ref()
            .map(|r| (r.method().as_str().to_owned(), r.uri().to_string()))
            .unwrap_or_default();
        match engine.on_response(
            &req_method,
            &req_url,
            parts.status.as_u16(),
            &parts.headers,
            &body_bytes,
        ) {
            Ok(crate::scripting::ScriptResponseAction::Modified {
                status,
                headers,
                body,
            }) => {
                if let Ok(s) = http::StatusCode::from_u16(status) {
                    parts.status = s;
                }
                // Hyper honours an explicit Content-Length over the body's own size, so a
                // rewritten body of a different length must carry an updated header or the
                // client receives a malformed/truncated response.
                let length_changed = body.len() != body_bytes.len();
                parts.headers = headers;
                body_bytes = body;
                if length_changed && parts.headers.contains_key(http::header::CONTENT_LENGTH) {
                    parts.headers.insert(
                        http::header::CONTENT_LENGTH,
                        http::HeaderValue::from(body_bytes.len()),
                    );
                }
            }
            Ok(crate::scripting::ScriptResponseAction::PassThrough) => {}
            Err(e) => {
                tracing::warn!("Lua on_response error (passing through): {e}");
            }
        }
    }
    let proxied_response = ProxiedResponse::new(
        parts.status,
        parts.version,
        parts.headers.clone(),
        body_bytes.clone(),
        now_millis(),
    );
    if let Some(request) = handler.take_captured_request() {
        let id = handler.pending_id.take().unwrap_or_else(next_id);
        let event = ProxyEvent::RequestComplete {
            id,
            request: Box::new(request),
            response: Box::new(proxied_response),
        };
        handler.send_event(event);
    }
    Response::from_parts(parts, body::full(body_bytes))
}
/// Buffers an upstream response body into memory, capped at `MAX_BODY_SIZE`.
///
/// Collection is best-effort: if reading fails (including when the cap is exceeded)
/// a warning is logged and an empty body is returned instead of an error.
pub async fn collect_body(body: hyper::body::Incoming) -> Bytes {
    let capped = Limited::new(body, MAX_BODY_SIZE);
    match capped.collect().await {
        Ok(collected) => collected.to_bytes(),
        Err(e) => {
            tracing::warn!(
                "Failed to collect response body (possibly exceeds {}MB limit): {e}",
                MAX_BODY_SIZE / (1024 * 1024)
            );
            Bytes::new()
        }
    }
}
/// Proxy handler that buffers each request/response pair and reports it as
/// [`ProxyEvent`]s, optionally pausing requests for interception and (with the
/// `scripting` feature) passing them through a script engine.
#[derive(Clone)]
pub struct CapturingHandler {
    /// Channel on which captured and intercepted traffic is reported.
    event_tx: mpsc::Sender<ProxyEvent>,
    /// Snapshot of the in-flight request; taken when its response is emitted.
    captured_request: Option<ProxiedRequest>,
    /// Event id reserved in `handle_request` and reused for the matching completion event.
    pending_id: Option<u64>,
    /// Shared interception settings; `None` disables interception entirely.
    intercept: Option<Arc<InterceptConfig>>,
    /// Optional script engine consulted on requests and responses.
    #[cfg(feature = "scripting")]
    script_engine: Option<Arc<crate::scripting::ScriptEngine>>,
}
/// Debug output lists the channel, captured request and pending id; the intercept
/// config and script engine are elided (shown as `..`).
impl std::fmt::Debug for CapturingHandler {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CapturingHandler");
        out.field("event_tx", &self.event_tx);
        out.field("captured_request", &self.captured_request);
        out.field("pending_id", &self.pending_id);
        out.finish_non_exhaustive()
    }
}
impl CapturingHandler {
#[must_use]
pub fn new(event_tx: mpsc::Sender<ProxyEvent>) -> Self {
Self {
event_tx,
captured_request: None,
pending_id: None,
intercept: None,
#[cfg(feature = "scripting")]
script_engine: None,
}
}
#[must_use]
pub fn with_intercept(mut self, cfg: Arc<InterceptConfig>) -> Self {
self.intercept = Some(cfg);
self
}
#[cfg(feature = "scripting")]
#[must_use]
pub fn with_script_engine(mut self, engine: Arc<crate::scripting::ScriptEngine>) -> Self {
self.script_engine = Some(engine);
self
}
pub(crate) fn take_captured_request(&mut self) -> Option<ProxiedRequest> {
self.captured_request.take()
}
pub(crate) fn send_event(&self, event: ProxyEvent) {
match self.event_tx.try_send(event) {
Ok(()) => {}
Err(mpsc::error::TrySendError::Full(_)) => {
tracing::warn!("Event channel full, dropping event");
}
Err(mpsc::error::TrySendError::Closed(_)) => {
tracing::debug!("Event channel closed");
}
}
}
}
#[async_trait]
impl HttpHandler for CapturingHandler {
async fn handle_request(
&mut self,
_ctx: &HttpContext,
req: Request<hyper::body::Incoming>,
) -> RequestOrResponse {
let id = next_id();
self.pending_id = Some(id);
let (mut parts, incoming) = req.into_parts();
let mut body_bytes = Limited::new(incoming, MAX_BODY_SIZE)
.collect()
.await
.map(http_body_util::Collected::to_bytes)
.unwrap_or_else(|e| {
tracing::warn!("Failed to collect request body: {e}");
Bytes::new()
});
#[cfg(feature = "scripting")]
if let Some(ref engine) = self.script_engine {
match engine.on_request(
parts.method.as_str(),
&parts.uri.to_string(),
&parts.headers,
&body_bytes,
) {
Ok(crate::scripting::ScriptRequestAction::Forward {
method,
url,
headers,
body,
}) => {
if let Ok(m) = method.parse() {
parts.method = m;
}
if let Ok(u) = url.parse() {
parts.uri = u;
}
parts.headers = headers;
body_bytes = body;
}
Ok(crate::scripting::ScriptRequestAction::ShortCircuit {
status,
headers,
body,
}) => {
let proxied_request = ProxiedRequest::new(
parts.method.clone(),
parts.uri.clone(),
parts.version,
parts.headers.clone(),
body_bytes.clone(),
now_millis(),
);
self.captured_request = Some(proxied_request);
let status_code = http::StatusCode::from_u16(status)
.unwrap_or(http::StatusCode::INTERNAL_SERVER_ERROR);
let mut builder = Response::builder().status(status_code);
if let Some(h) = builder.headers_mut() {
*h = headers;
}
let response = builder
.body(body::full(body))
.unwrap_or_else(|_| Response::new(body::empty()));
return RequestOrResponse::Response(response);
}
Ok(crate::scripting::ScriptRequestAction::PassThrough) => {}
Err(e) => {
tracing::warn!("Lua on_request error (passing through): {e}");
}
}
}
if let Some(ref cfg) = self.intercept {
if cfg.is_enabled() {
let snapshot = ProxiedRequest::new(
parts.method.clone(),
parts.uri.clone(),
parts.version,
parts.headers.clone(),
body_bytes.clone(),
now_millis(),
);
let rx = cfg.register(id);
let event = ProxyEvent::RequestIntercepted {
id,
request: Box::new(snapshot.clone()),
};
if self.event_tx.try_send(event).is_err() {
cfg.resolve(id, InterceptDecision::Forward);
tracing::warn!("Event channel full, skipping intercept for id={id}");
} else {
self.captured_request = Some(snapshot);
match tokio::time::timeout(std::time::Duration::from_secs(300), rx).await {
Ok(Ok(InterceptDecision::Forward)) => {
}
Ok(Ok(InterceptDecision::Modified {
method,
uri,
headers,
body,
})) => {
if let Ok(m) = method.parse() {
parts.method = m;
}
if let Ok(u) = uri.parse() {
parts.uri = u;
}
parts.headers = headers;
body_bytes = body;
self.captured_request = Some(ProxiedRequest::new(
parts.method.clone(),
parts.uri.clone(),
parts.version,
parts.headers.clone(),
body_bytes.clone(),
now_millis(),
));
}
Ok(Ok(InterceptDecision::Block { status, body })) => {
let status_code = http::StatusCode::from_u16(status)
.unwrap_or(http::StatusCode::BAD_GATEWAY);
let response = Response::builder()
.status(status_code)
.body(body::full(body))
.unwrap_or_else(|_| Response::new(body::empty()));
return RequestOrResponse::Response(response);
}
_ => {
tracing::warn!("Intercept timed out for id={id}, returning 504");
let response = Response::builder()
.status(http::StatusCode::GATEWAY_TIMEOUT)
.body(body::empty())
.unwrap_or_else(|_| Response::new(body::empty()));
return RequestOrResponse::Response(response);
}
}
let req = Request::from_parts(parts, body::full(body_bytes));
return RequestOrResponse::Request(req);
}
}
}
let proxied_request = ProxiedRequest::new(
parts.method.clone(),
parts.uri.clone(),
parts.version,
parts.headers.clone(),
body_bytes.clone(),
now_millis(),
);
self.captured_request = Some(proxied_request);
let req = Request::from_parts(parts, body::full(body_bytes));
RequestOrResponse::Request(req)
}
async fn handle_response(
&mut self,
_ctx: &HttpContext,
res: Response<hyper::body::Incoming>,
) -> Response<ProxyBody> {
let (parts, incoming) = res.into_parts();
let body_bytes = collect_body(incoming).await;
collect_and_emit(self, parts, body_bytes)
}
}