use crate::interceptor::{BoxError, HttpBody};
use crate::proxy::body_codec::process_body;
use hyper::body::{Body, Bytes, Frame, SizeHint};
use relay_core_api::flow::{BodyData, Direction, FlowUpdate};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::sync::mpsc::Sender;
/// A pass-through [`Body`] wrapper that forwards every frame of `inner`
/// unchanged while capturing at most `limit` bytes of its data frames.
///
/// Once the wrapped body is finished, the captured prefix is run through
/// [`process_body`] and reported exactly once on `on_flow` as a
/// [`FlowUpdate::HttpBody`]. If any data byte could not be captured because the
/// budget was exhausted, a [`FlowUpdate::BodyBudgetExceeded`] follows.
/// Reporting is best-effort: `try_send` is used and its result ignored, so a
/// full or closed channel never blocks or fails the proxied stream.
pub struct TapBody {
    /// The body being proxied.
    inner: HttpBody,
    /// Identifier of the flow the captured body is reported under.
    flow_id: String,
    /// Channel the end-of-body reports are sent on.
    on_flow: Sender<FlowUpdate>,
    /// Direction tag attached to the reports.
    direction: Direction,
    /// Captured prefix of the data frames; never grows beyond `limit` bytes.
    buffer: Vec<u8>,
    /// Maximum number of bytes captured into `buffer`.
    limit: usize,
    /// Headers handed to `process_body` together with the captured bytes
    /// (presumably used to choose a decoding — confirm in `body_codec`).
    headers: Vec<(String, String)>,
    /// Set once the end-of-body reports have been emitted; guarantees they are
    /// sent at most once even if end-of-stream is observed more than once.
    finished: bool,
    /// `true` once at least one forwarded data byte could not be captured
    /// because the capture budget (`limit`) was already used up.
    pub budget_exceeded: bool,
    /// Total number of data bytes forwarded so far, captured or not.
    pub total_bytes: u64,
}

impl TapBody {
    /// Wraps `inner` so its data is captured (up to `limit` bytes) and
    /// reported on `on_flow` under `flow_id`/`direction` when the body ends.
    ///
    /// Also increments the "stream mode: tap" metric.
    pub fn new(
        inner: HttpBody,
        flow_id: String,
        on_flow: Sender<FlowUpdate>,
        direction: Direction,
        limit: usize,
        headers: Vec<(String, String)>,
    ) -> Self {
        crate::metrics::inc_proxy_stream_mode_tap();
        Self {
            inner,
            flow_id,
            on_flow,
            direction,
            buffer: Vec::new(),
            limit,
            headers,
            finished: false,
            budget_exceeded: false,
            total_bytes: 0,
        }
    }

    /// Accounts for one forwarded data chunk and copies as much of it into the
    /// capture buffer as the remaining budget allows.
    fn capture(&mut self, data: &[u8]) {
        self.total_bytes += data.len() as u64;
        let room = self.limit.saturating_sub(self.buffer.len());
        let take = room.min(data.len());
        self.buffer.extend_from_slice(&data[..take]);
        // Only flag the budget when bytes were actually dropped; a body of
        // exactly `limit` bytes is captured in full and is not degraded.
        if take < data.len() {
            self.budget_exceeded = true;
        }
    }

    /// Emits the end-of-body reports; subsequent calls are no-ops.
    fn finish(&mut self) {
        if self.finished {
            return;
        }
        self.finished = true;

        let (encoding, content) = process_body(&self.buffer, &self.headers);
        let body_data = BodyData {
            encoding,
            content,
            size: self.total_bytes,
        };
        // Best-effort: a full or closed channel must not disturb the stream.
        let _ = self.on_flow.try_send(FlowUpdate::HttpBody {
            flow_id: self.flow_id.clone(),
            direction: self.direction.clone(),
            body: body_data,
        });

        if self.budget_exceeded {
            crate::metrics::inc_proxy_body_degraded();
            crate::metrics::inc_proxy_stream_mode_degrade();
            let _ = self.on_flow.try_send(FlowUpdate::BodyBudgetExceeded {
                flow_id: self.flow_id.clone(),
                direction: self.direction.clone(),
            });
        }
    }
}

impl Body for TapBody {
    type Data = Bytes;
    type Error = BoxError;

    /// Forwards the next frame of the inner body, capturing data frames and
    /// emitting the reports once the inner body is exhausted.
    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        let this = self.get_mut();
        match Pin::new(&mut this.inner).poll_frame(cx) {
            Poll::Ready(Some(Ok(frame))) => {
                if let Some(data) = frame.data_ref() {
                    this.capture(data);
                }
                // A consumer may stop polling as soon as `is_end_stream()` is
                // true, so `Ready(None)` is not guaranteed to be observed:
                // finalize as soon as the inner body says it is done.
                if this.inner.is_end_stream() {
                    this.finish();
                }
                Poll::Ready(Some(Ok(frame)))
            }
            Poll::Ready(None) => {
                this.finish();
                Poll::Ready(None)
            }
            other => other,
        }
    }

    /// Reports end-of-stream only after the reports have been emitted.
    ///
    /// Returning `false` is always allowed by the `Body` contract. It forces
    /// consumers to poll through to the point where `finish` runs, e.g. for a
    /// body that is empty from the start.
    fn is_end_stream(&self) -> bool {
        self.finished
    }

    fn size_hint(&self) -> SizeHint {
        self.inner.size_hint()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http_body_util::BodyExt;
    use hyper::body::Frame;
    use relay_core_api::flow::Direction;
    use std::pin::Pin;
    use std::task::{Context, Poll, Waker};

    /// Fixture body: one `"hello"` data frame, then a trailers frame carrying
    /// `x-trailer: value`, then end-of-stream. It is never `Pending`.
    struct DataThenTrailers {
        /// 0 = data next, 1 = trailers next, anything else = exhausted.
        phase: u8,
    }

    impl Body for DataThenTrailers {
        type Data = Bytes;
        type Error = BoxError;

        fn poll_frame(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            let next = match self.phase {
                0 => Some(Frame::data(Bytes::from("hello"))),
                1 => {
                    let mut trailers = hyper::HeaderMap::new();
                    trailers.insert("x-trailer", "value".parse().unwrap());
                    Some(Frame::trailers(trailers))
                }
                _ => None,
            };
            // Advance only while frames remain; once exhausted, stay there.
            if next.is_some() {
                self.phase += 1;
            }
            Poll::Ready(next.map(Ok))
        }
    }

    #[tokio::test]
    async fn test_tap_body_passes_trailers() {
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);
        let inner: HttpBody = DataThenTrailers { phase: 0 }.boxed();
        let mut tap = TapBody::new(
            inner,
            "test-flow".to_string(),
            tx,
            Direction::ServerToClient,
            4096,
            vec![],
        );
        // The fixture never returns `Pending`, so a no-op waker is enough.
        let mut cx = Context::from_waker(Waker::noop());

        let mut data_frames = 0;
        let mut seen_trailers = Vec::new();
        loop {
            let frame = match Pin::new(&mut tap).poll_frame(&mut cx) {
                Poll::Ready(Some(Ok(frame))) => frame,
                Poll::Ready(Some(Err(e))) => panic!("unexpected error: {}", e),
                Poll::Ready(None) => break,
                Poll::Pending => panic!("unexpected pending"),
            };
            if frame.is_data() {
                data_frames += 1;
            }
            if let Some(t) = frame.trailers_ref() {
                seen_trailers.push(t.clone());
            }
        }

        assert_eq!(data_frames, 1, "should forward 1 data frame");
        assert_eq!(seen_trailers.len(), 1, "should forward 1 trailers frame");
        let trailers = seen_trailers.pop().expect("trailers should be present");
        assert_eq!(
            trailers.get("x-trailer").and_then(|v| v.to_str().ok()),
            Some("value"),
            "trailer x-trailer should be preserved"
        );

        match rx.try_recv().expect("should emit HttpBody event") {
            FlowUpdate::HttpBody { body, .. } => {
                assert_eq!(body.size, 5, "body size should match data");
            }
            other => panic!("expected HttpBody, got {:?}", other),
        }
    }
}