use async_trait::async_trait;
use bytes::Bytes;
use http_body_util::{BodyExt, Limited};
use hyper::{Request, Response};
use proxyapi_models::{ProxiedRequest, ProxiedResponse};
use tokio::sync::mpsc;
#[cfg(feature = "scripting")]
use std::sync::Arc;
use crate::body::{self, ProxyBody};
use crate::event::{next_id, ProxyEvent};
use crate::{HttpContext, HttpHandler, RequestOrResponse};
/// Upper bound (100 MiB) on the number of body bytes buffered for a single
/// request or response; bodies that cannot be collected within it are replaced
/// by an empty body.
const MAX_BODY_SIZE: usize = 100 << 20;
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// Passed as the final argument to the `ProxiedRequest` / `ProxiedResponse`
/// constructors in this module.
fn now_millis() -> i64 {
    let now = chrono::Local::now();
    now.timestamp_millis()
}
/// Finishes an intercepted response and returns it ready to send to the client.
///
/// Steps, in order:
/// 1. With the `scripting` feature and an attached engine, runs the Lua
///    `on_response` hook. A `Modified` result replaces status (if it is a valid
///    code), headers and body. Script errors are logged and the response passes
///    through unchanged.
/// 2. Keeps an existing `Content-Length` header in step with the body that will
///    actually be sent (see [`sync_response_content_length`]). The body can
///    differ from what upstream announced when the script rewrote it, or when
///    [`collect_body`] gave up and returned an empty body.
/// 3. If `handler` still holds a captured request, takes it and sends a
///    [`ProxyEvent::RequestComplete`] pairing it with this response. If no
///    request is held, no event is sent.
pub fn collect_and_emit(
    handler: &mut CapturingHandler,
    mut parts: http::response::Parts,
    // Only reassigned by the scripting hook; `mut` is unused without that feature.
    #[allow(unused_mut)] mut body_bytes: Bytes,
) -> Response<ProxyBody> {
    #[cfg(feature = "scripting")]
    if let Some(ref engine) = handler.script_engine {
        let (req_method, req_url) = handler
            .captured_request
            .as_ref()
            .map(|r| (r.method().as_str().to_owned(), r.uri().to_string()))
            .unwrap_or_default();
        match engine.on_response(
            &req_method,
            &req_url,
            parts.status.as_u16(),
            &parts.headers,
            &body_bytes,
        ) {
            Ok(crate::scripting::ScriptResponseAction::Modified {
                status,
                headers,
                body,
            }) => {
                // An out-of-range status from the script keeps the upstream status.
                if let Ok(s) = http::StatusCode::from_u16(status) {
                    parts.status = s;
                }
                parts.headers = headers;
                body_bytes = body;
            }
            Ok(crate::scripting::ScriptResponseAction::PassThrough) => {}
            Err(e) => {
                tracing::warn!("Lua on_response error (passing through): {e}");
            }
        }
    }

    // Must be read before `take_captured_request` below empties the slot.
    let is_head_request = matches!(
        handler.captured_request.as_ref(),
        Some(r) if r.method().as_str() == "HEAD"
    );
    sync_response_content_length(&mut parts, is_head_request, body_bytes.len());

    let proxied_response = ProxiedResponse::new(
        parts.status,
        parts.version,
        parts.headers.clone(),
        body_bytes.clone(),
        now_millis(),
    );
    if let Some(request) = handler.take_captured_request() {
        let event = ProxyEvent::RequestComplete {
            id: next_id(),
            request: Box::new(request),
            response: Box::new(proxied_response),
        };
        handler.send_event(event);
    }
    Response::from_parts(parts, body::full(body_bytes))
}

/// Rewrites an existing `Content-Length` header on `parts` to `body_len`.
///
/// A header that disagrees with the body actually sent makes clients truncate
/// or wait for bytes that never arrive. Responses that carry no body on the
/// wire are left alone, because there the header legitimately describes a
/// representation that is not sent. These are replies to `HEAD` and 1xx, 204
/// and 304 responses. Absent headers are not added.
fn sync_response_content_length(
    parts: &mut http::response::Parts,
    is_head_request: bool,
    body_len: usize,
) {
    let status = parts.status;
    let bodiless = is_head_request
        || status.is_informational()
        || status == http::StatusCode::NO_CONTENT
        || status == http::StatusCode::NOT_MODIFIED;
    if !bodiless && parts.headers.contains_key(http::header::CONTENT_LENGTH) {
        // `insert` replaces every existing value, collapsing duplicate headers too.
        parts
            .headers
            .insert(http::header::CONTENT_LENGTH, http::HeaderValue::from(body_len));
    }
}
/// Buffers an entire response body into memory, reading at most
/// [`MAX_BODY_SIZE`] bytes.
///
/// Best-effort: if collection fails (for example because the body is larger
/// than the cap), a warning is logged and an empty `Bytes` is returned instead
/// of an error.
pub async fn collect_body(body: hyper::body::Incoming) -> Bytes {
    let capped = Limited::new(body, MAX_BODY_SIZE);
    match capped.collect().await {
        Ok(collected) => collected.to_bytes(),
        Err(e) => {
            tracing::warn!(
                "Failed to collect response body (possibly exceeds {}MB limit): {e}",
                MAX_BODY_SIZE / (1024 * 1024)
            );
            Bytes::new()
        }
    }
}
/// HTTP handler that buffers request and response bodies, optionally passes
/// them through a Lua script engine (`scripting` feature), and reports each
/// completed exchange as a [`ProxyEvent`] on a channel.
///
/// Note: the derived `Clone` also clones any request currently held in
/// `captured_request`.
#[derive(Clone)]
pub struct CapturingHandler {
    /// Channel for exchange events. Sends are non-blocking: when the channel is
    /// full the event is dropped with a warning (see `send_event`).
    event_tx: mpsc::Sender<ProxyEvent>,
    /// Snapshot of the last request seen by `handle_request`. It is taken, and
    /// cleared, by `collect_and_emit` when a response is finished. Presumably
    /// this is the response to the same request; that pairing is not enforced here.
    captured_request: Option<ProxiedRequest>,
    /// Script hooks invoked as `on_request` / `on_response` when present.
    #[cfg(feature = "scripting")]
    script_engine: Option<Arc<crate::scripting::ScriptEngine>>,
}
impl std::fmt::Debug for CapturingHandler {
    /// Formats the handler's fields.
    ///
    /// With the `scripting` feature, the script engine is shown only as a
    /// presence flag (`has_script_engine`). This keeps the field visible
    /// instead of silently omitting it, without making this impl depend on
    /// `ScriptEngine` implementing `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CapturingHandler");
        out.field("event_tx", &self.event_tx)
            .field("captured_request", &self.captured_request);
        #[cfg(feature = "scripting")]
        out.field("has_script_engine", &self.script_engine.is_some());
        out.finish()
    }
}
impl CapturingHandler {
#[must_use]
pub fn new(event_tx: mpsc::Sender<ProxyEvent>) -> Self {
Self {
event_tx,
captured_request: None,
#[cfg(feature = "scripting")]
script_engine: None,
}
}
#[cfg(feature = "scripting")]
#[must_use]
pub fn with_script_engine(mut self, engine: Arc<crate::scripting::ScriptEngine>) -> Self {
self.script_engine = Some(engine);
self
}
pub(crate) fn take_captured_request(&mut self) -> Option<ProxiedRequest> {
self.captured_request.take()
}
pub(crate) fn send_event(&self, event: ProxyEvent) {
match self.event_tx.try_send(event) {
Ok(()) => {}
Err(mpsc::error::TrySendError::Full(_)) => {
tracing::warn!("Event channel full, dropping event");
}
Err(mpsc::error::TrySendError::Closed(_)) => {
tracing::debug!("Event channel closed");
}
}
}
}
#[async_trait]
impl HttpHandler for CapturingHandler {
async fn handle_request(
&mut self,
_ctx: &HttpContext,
req: Request<hyper::body::Incoming>,
) -> RequestOrResponse {
#[allow(unused_mut)]
let (mut parts, incoming) = req.into_parts();
#[allow(unused_mut)]
let mut body_bytes = Limited::new(incoming, MAX_BODY_SIZE)
.collect()
.await
.map(http_body_util::Collected::to_bytes)
.unwrap_or_else(|e| {
tracing::warn!("Failed to collect request body: {e}");
Bytes::new()
});
#[cfg(feature = "scripting")]
if let Some(ref engine) = self.script_engine {
match engine.on_request(
parts.method.as_str(),
&parts.uri.to_string(),
&parts.headers,
&body_bytes,
) {
Ok(crate::scripting::ScriptRequestAction::Forward {
method,
url,
headers,
body,
}) => {
if let Ok(m) = method.parse() {
parts.method = m;
}
if let Ok(u) = url.parse() {
parts.uri = u;
}
parts.headers = headers;
body_bytes = body;
}
Ok(crate::scripting::ScriptRequestAction::ShortCircuit {
status,
headers,
body,
}) => {
let proxied_request = ProxiedRequest::new(
parts.method.clone(),
parts.uri.clone(),
parts.version,
parts.headers.clone(),
body_bytes.clone(),
now_millis(),
);
self.captured_request = Some(proxied_request);
let status_code = http::StatusCode::from_u16(status)
.unwrap_or(http::StatusCode::INTERNAL_SERVER_ERROR);
let mut builder = Response::builder().status(status_code);
if let Some(h) = builder.headers_mut() {
*h = headers;
}
let response = builder
.body(body::full(body))
.unwrap_or_else(|_| Response::new(body::empty()));
return RequestOrResponse::Response(response);
}
Ok(crate::scripting::ScriptRequestAction::PassThrough) => {}
Err(e) => {
tracing::warn!("Lua on_request error (passing through): {e}");
}
}
}
let proxied_request = ProxiedRequest::new(
parts.method.clone(),
parts.uri.clone(),
parts.version,
parts.headers.clone(),
body_bytes.clone(),
now_millis(),
);
self.captured_request = Some(proxied_request);
let req = Request::from_parts(parts, body::full(body_bytes));
RequestOrResponse::Request(req)
}
async fn handle_response(
&mut self,
_ctx: &HttpContext,
res: Response<hyper::body::Incoming>,
) -> Response<ProxyBody> {
let (parts, incoming) = res.into_parts();
let body_bytes = collect_body(incoming).await;
collect_and_emit(self, parts, body_bytes)
}
}