use crate::handler::{
extract_body_bytes, extract_response_body_bytes, put_body_back, put_response_body_back,
BoxBody, Buffered, Dropped, RequestHandler,
};
use bytes::Bytes;
use hyper::header::HeaderMap;
use hyper::{Method, Request, Response, StatusCode, Uri, Version};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::mpsc;
use std::sync::Arc;
use tracing::info;
/// Identifier attached to each intercepted request or response.
pub type InterceptId = u64;
/// A captured HTTP message handed to the consumer of the intercept channel.
///
/// Each variant carries a snapshot of the message plus a `reply` sender on
/// which the consumer sends back a [`Verdict`].
pub enum InterceptedItem {
/// A captured request.
Request {
/// Identifier assigned by the handler when the request was captured.
id: InterceptId,
/// Request method at capture time.
method: Method,
/// Request URI at capture time.
uri: Uri,
/// HTTP version of the request.
version: Version,
/// Copy of the request headers.
headers: HeaderMap,
/// Copy of the request body bytes.
body: Bytes,
/// Channel on which the verdict for this request is returned.
reply: mpsc::Sender<Verdict>,
},
/// A captured response.
Response {
/// Identifier assigned by the handler when the response was captured.
id: InterceptId,
/// Response status at capture time.
status: StatusCode,
/// HTTP version of the response.
version: Version,
/// Copy of the response headers.
headers: HeaderMap,
/// Copy of the response body bytes.
body: Bytes,
/// Channel on which the verdict for this response is returned.
reply: mpsc::Sender<Verdict>,
},
}
/// Decision sent back for an intercepted message.
pub enum Verdict {
/// Let the message continue with the given headers and body.
Forward {
/// Headers that replace the message's header map.
headers: Box<HeaderMap>,
/// Body that replaces the message's body.
body: Bytes,
/// Request method carried with the verdict, if any; responses ignore it.
method: Option<Method>,
/// Request URI carried with the verdict, if any; responses ignore it.
uri: Option<Uri>,
/// Replacement status for a response; `None` keeps the original status.
/// Requests ignore it.
status: Option<StatusCode>,
},
/// Discard the message: the handler tags it with the `Dropped` extension
/// and empties its body.
Drop,
}
/// Request/response handler that forwards buffered messages over a channel
/// and waits for a [`Verdict`] before letting them continue.
pub struct InterceptHandler {
// Bounded channel on which intercepted items are sent.
tx: mpsc::SyncSender<InterceptedItem>,
// Shared on/off switch; when false, messages pass through untouched. The
// handler clears it itself when sending on `tx` fails.
active: Arc<AtomicBool>,
// Source of `InterceptId`s, shared by requests and responses; starts at 1.
next_id: AtomicU64,
}
impl InterceptHandler {
pub fn new(tx: mpsc::SyncSender<InterceptedItem>, active: Arc<AtomicBool>) -> Self {
Self {
tx,
active,
next_id: AtomicU64::new(1),
}
}
}
impl RequestHandler for InterceptHandler {
fn handle_request(&self, req: &mut Request<BoxBody>) {
let path = req.uri().path();
let display_uri = if req.uri().query().is_some() {
format!("{path}?...")
} else {
path.to_string()
};
info!(">> {} {} {:?}", req.method(), display_uri, req.version());
if !self.active.load(Ordering::Relaxed)
|| req.extensions().get::<Buffered>().is_none()
{
return;
}
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
let body_bytes = extract_body_bytes(req);
let (reply_tx, reply_rx) = mpsc::channel();
let item = InterceptedItem::Request {
id,
method: req.method().clone(),
uri: req.uri().clone(),
version: req.version(),
headers: req.headers().clone(),
body: body_bytes.clone(),
reply: reply_tx,
};
let send_result = tokio::task::block_in_place(|| self.tx.send(item).map_err(Box::new));
match send_result {
Ok(()) => {}
Err(_) => {
tracing::warn!("TUI disconnected, disabling interception");
self.active.store(false, Ordering::Relaxed);
put_body_back(req, body_bytes);
return;
}
}
match tokio::task::block_in_place(|| reply_rx.recv()) {
Ok(Verdict::Forward {
headers,
body,
..
}) => {
*req.headers_mut() = *headers;
let changed = body != body_bytes;
put_body_back(req, body.clone());
fix_headers_after_edit(req.headers_mut(), body.len(), changed);
}
Ok(Verdict::Drop) => {
req.extensions_mut().insert(Dropped);
put_body_back(req, Bytes::new());
fix_headers_after_edit(req.headers_mut(), 0, true);
}
Err(_) => {
put_body_back(req, body_bytes);
}
}
}
fn handle_response(&self, res: &mut Response<BoxBody>) {
info!("<< {}", res.status());
if !self.active.load(Ordering::Relaxed)
|| res.extensions().get::<Buffered>().is_none()
{
return;
}
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
let body_bytes = extract_response_body_bytes(res);
let (reply_tx, reply_rx) = mpsc::channel();
let item = InterceptedItem::Response {
id,
status: res.status(),
version: res.version(),
headers: res.headers().clone(),
body: body_bytes.clone(),
reply: reply_tx,
};
let send_result = tokio::task::block_in_place(|| self.tx.send(item).map_err(Box::new));
match send_result {
Ok(()) => {}
Err(_) => {
tracing::warn!("TUI disconnected, disabling interception");
self.active.store(false, Ordering::Relaxed);
put_response_body_back(res, body_bytes);
return;
}
}
match tokio::task::block_in_place(|| reply_rx.recv()) {
Ok(Verdict::Forward {
headers,
body,
status,
..
}) => {
*res.headers_mut() = *headers;
if let Some(s) = status {
*res.status_mut() = s;
}
let changed = body != body_bytes;
put_response_body_back(res, body.clone());
fix_headers_after_edit(res.headers_mut(), body.len(), changed);
}
Ok(Verdict::Drop) => {
res.extensions_mut().insert(Dropped);
put_response_body_back(res, Bytes::new());
fix_headers_after_edit(res.headers_mut(), 0, true);
}
Err(_) => {
put_response_body_back(res, body_bytes);
}
}
}
}
/// Cleans up a header map after a message has been reviewed.
///
/// Hop-by-hop headers are always removed. When `body_changed` is `true`,
/// `Transfer-Encoding` and `Content-Encoding` are removed as well and
/// `Content-Length` is set to `body_len`, or removed when `body_len` is zero.
fn fix_headers_after_edit(headers: &mut HeaderMap, body_len: usize, body_changed: bool) {
    let hop_by_hop = [
        hyper::header::CONNECTION,
        hyper::header::PROXY_AUTHORIZATION,
        hyper::header::PROXY_AUTHENTICATE,
        hyper::header::TE,
        hyper::header::TRAILER,
        hyper::header::UPGRADE,
    ];
    for header_name in hop_by_hop.iter() {
        headers.remove(header_name);
    }
    headers.remove("keep-alive");

    if body_changed {
        // The old framing and encoding no longer describe the new body.
        headers.remove(hyper::header::TRANSFER_ENCODING);
        headers.remove(hyper::header::CONTENT_ENCODING);
        if body_len == 0 {
            headers.remove(hyper::header::CONTENT_LENGTH);
        } else {
            headers.insert(
                hyper::header::CONTENT_LENGTH,
                hyper::header::HeaderValue::from(body_len),
            );
        }
    }
}
/// Returns `true` when `body` is empty or consists entirely of valid UTF-8.
pub fn is_text_body(body: &Bytes) -> bool {
    if body.is_empty() {
        return true;
    }
    std::str::from_utf8(body).is_ok()
}
/// Renders a request as text: request line, one `name: value` line per
/// header, a blank line, then the body.
///
/// Lines end with CRLF. Header values that are not visible ASCII are shown
/// as `<binary>`. A non-empty body that is not valid UTF-8 is replaced by a
/// `<binary N bytes>` placeholder.
pub fn serialize_request(
    method: &Method,
    uri: &Uri,
    version: Version,
    headers: &HeaderMap,
    body: &Bytes,
) -> String {
    let mut out = format!("{method} {uri} {version:?}\r\n");
    for (name, value) in headers {
        out.push_str(name.as_str());
        out.push_str(": ");
        out.push_str(value.to_str().unwrap_or("<binary>"));
        out.push_str("\r\n");
    }
    out.push_str("\r\n");

    if body.is_empty() {
        return out;
    }
    if let Ok(text) = std::str::from_utf8(body) {
        out.push_str(text);
    } else {
        out.push_str(&format!("<binary {} bytes>", body.len()));
    }
    out
}
/// Renders a response as text: status line, one `name: value` line per
/// header, a blank line, then the body.
///
/// Lines end with CRLF. Header values that are not visible ASCII are shown
/// as `<binary>`. A non-empty body that is not valid UTF-8 is replaced by a
/// `<binary N bytes>` placeholder.
pub fn serialize_response(
    status: StatusCode,
    version: Version,
    headers: &HeaderMap,
    body: &Bytes,
) -> String {
    let mut out = format!("{version:?} {status}\r\n");
    for (name, value) in headers {
        out.push_str(name.as_str());
        out.push_str(": ");
        out.push_str(value.to_str().unwrap_or("<binary>"));
        out.push_str("\r\n");
    }
    out.push_str("\r\n");

    if body.is_empty() {
        return out;
    }
    if let Ok(text) = std::str::from_utf8(body) {
        out.push_str(text);
    } else {
        out.push_str(&format!("<binary {} bytes>", body.len()));
    }
    out
}
/// Parses request text (request line, headers, blank line, body) into its
/// method, URI, headers and body.
///
/// Line endings may be CRLF or bare LF; the head ends at the first blank
/// line of either kind and everything after it is the body, taken verbatim.
/// The HTTP version on the request line is ignored. Header lines are split
/// at the first `:` and the value is trimmed, so `Name:value` and
/// `Name: value` are both accepted. Header lines that have no colon or fail
/// to parse are skipped.
///
/// Returns `None` when the request line is missing or its method or URI
/// cannot be parsed.
pub fn parse_request_text(text: &str) -> Option<(Method, Uri, HeaderMap, Bytes)> {
    // Use whichever blank-line separator comes first, so LF-only text from
    // an editor still yields a body instead of stray "header" lines.
    let (head, body) = match (text.find("\r\n\r\n"), text.find("\n\n")) {
        (Some(crlf), Some(lf)) if lf < crlf => (&text[..lf], &text[lf + 2..]),
        (Some(crlf), _) => (&text[..crlf], &text[crlf + 4..]),
        (None, Some(lf)) => (&text[..lf], &text[lf + 2..]),
        (None, None) => (text, ""),
    };

    let mut lines = head.lines();
    let request_line = lines.next()?;
    let mut parts = request_line.splitn(3, ' ');
    let method: Method = parts.next()?.parse().ok()?;
    let uri: Uri = parts.next()?.parse().ok()?;

    let mut headers = HeaderMap::new();
    for line in lines {
        if let Some((name, value)) = line.split_once(':') {
            if let (Ok(n), Ok(v)) = (
                name.parse::<hyper::header::HeaderName>(),
                value.trim().parse::<hyper::header::HeaderValue>(),
            ) {
                headers.append(n, v);
            }
        }
    }
    Some((method, uri, headers, Bytes::from(body.to_string())))
}
/// Parses response text (status line, headers, blank line, body) into its
/// status, headers and body.
///
/// Line endings may be CRLF or bare LF; the head ends at the first blank
/// line of either kind and everything after it is the body, taken verbatim.
/// The status code is the first whitespace-separated token after the first
/// space on the status line; the version and reason phrase are ignored.
/// Header lines are split at the first `:` and the value is trimmed, so
/// `Name:value` and `Name: value` are both accepted. Header lines that have
/// no colon or fail to parse are skipped.
///
/// Returns `None` when the status line is missing or does not contain a
/// valid status code.
pub fn parse_response_text(text: &str) -> Option<(StatusCode, HeaderMap, Bytes)> {
    // Use whichever blank-line separator comes first, so LF-only text from
    // an editor still yields a body instead of stray "header" lines.
    let (head, body) = match (text.find("\r\n\r\n"), text.find("\n\n")) {
        (Some(crlf), Some(lf)) if lf < crlf => (&text[..lf], &text[lf + 2..]),
        (Some(crlf), _) => (&text[..crlf], &text[crlf + 4..]),
        (None, Some(lf)) => (&text[..lf], &text[lf + 2..]),
        (None, None) => (text, ""),
    };

    let mut lines = head.lines();
    let status_line = lines.next()?;
    let status_str = status_line.split_once(' ')?.1;
    let status_code: u16 = status_str.split_whitespace().next()?.parse().ok()?;
    let status = StatusCode::from_u16(status_code).ok()?;

    let mut headers = HeaderMap::new();
    for line in lines {
        if let Some((name, value)) = line.split_once(':') {
            if let (Ok(n), Ok(v)) = (
                name.parse::<hyper::header::HeaderName>(),
                value.trim().parse::<hyper::header::HeaderValue>(),
            ) {
                headers.append(n, v);
            }
        }
    }
    Some((status, headers, Bytes::from(body.to_string())))
}