// relay_core_lib/proxy/tap.rs
use crate::interceptor::{BoxError, HttpBody};
use crate::proxy::body_codec::process_body;
use hyper::body::{Body, Bytes, Frame, SizeHint};
use relay_core_api::flow::{BodyData, Direction, FlowUpdate};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::sync::mpsc::Sender;

9pub struct TapBody {
10 inner: HttpBody,
11 flow_id: String,
12 on_flow: Sender<FlowUpdate>,
13 direction: Direction,
14 buffer: Vec<u8>,
15 limit: usize,
16 headers: Vec<(String, String)>,
17 pub budget_exceeded: bool,
19 pub total_bytes: u64,
21}
22
23impl TapBody {
24 pub fn new(
25 inner: HttpBody,
26 flow_id: String,
27 on_flow: Sender<FlowUpdate>,
28 direction: Direction,
29 limit: usize,
30 headers: Vec<(String, String)>,
31 ) -> Self {
32 crate::metrics::inc_proxy_stream_mode_tap();
33 Self {
34 inner,
35 flow_id,
36 on_flow,
37 direction,
38 buffer: Vec::new(),
39 limit,
40 headers,
41 budget_exceeded: false,
42 total_bytes: 0,
43 }
44 }
45}
46
47impl Body for TapBody {
48 type Data = Bytes;
49 type Error = BoxError;
50
51 fn poll_frame(
52 mut self: Pin<&mut Self>,
53 cx: &mut Context<'_>,
54 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
55 match Pin::new(&mut self.inner).poll_frame(cx) {
56 Poll::Ready(Some(Ok(frame))) => {
57 if let Some(data) = frame.data_ref() {
58 self.total_bytes += data.len() as u64;
59 if self.buffer.len() < self.limit {
60 let len = std::cmp::min(data.len(), self.limit - self.buffer.len());
61 self.buffer.extend_from_slice(&data[..len]);
62 }
63 if self.buffer.len() >= self.limit {
64 self.budget_exceeded = true;
65 }
66 }
67 Poll::Ready(Some(Ok(frame)))
68 }
69 Poll::Ready(None) => {
70 let (encoding, content) = process_body(&self.buffer, &self.headers);
71 let body_data = BodyData {
72 encoding,
73 content,
74 size: self.total_bytes, };
76
77 let _ = self.on_flow.try_send(FlowUpdate::HttpBody {
78 flow_id: self.flow_id.clone(),
79 direction: self.direction.clone(),
80 body: body_data,
81 });
82
83 if self.budget_exceeded {
85 crate::metrics::inc_proxy_body_degraded();
86 crate::metrics::inc_proxy_stream_mode_degrade();
87 let _ = self.on_flow.try_send(FlowUpdate::BodyBudgetExceeded {
88 flow_id: self.flow_id.clone(),
89 direction: self.direction.clone(),
90 });
91 }
92
93 Poll::Ready(None)
94 }
95 other => other,
96 }
97 }
98
99 fn is_end_stream(&self) -> bool {
100 self.inner.is_end_stream()
101 }
102
103 fn size_hint(&self) -> SizeHint {
104 self.inner.size_hint()
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http_body_util::BodyExt;
    use hyper::body::Frame;
    use relay_core_api::flow::Direction;
    use std::pin::Pin;
    use std::task::{Context, Poll, Waker};

    /// Mock body that yields one data frame (`"hello"`), then one trailers
    /// frame (`x-trailer: value`), then ends.
    struct DataThenTrailers {
        /// Number of frames handed out so far (0, 1, or 2 once finished).
        phase: u8,
    }

    impl Body for DataThenTrailers {
        type Data = Bytes;
        type Error = BoxError;

        fn poll_frame(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            let step = self.phase;
            if step >= 2 {
                return Poll::Ready(None);
            }
            self.phase = step + 1;

            let frame = if step == 0 {
                Frame::data(Bytes::from("hello"))
            } else {
                let mut trailers = hyper::HeaderMap::new();
                trailers.insert("x-trailer", "value".parse().unwrap());
                Frame::trailers(trailers)
            };
            Poll::Ready(Some(Ok(frame)))
        }
    }

    /// Trailers must pass through the tap untouched, and the emitted body
    /// update must count only the data bytes.
    #[tokio::test]
    async fn test_tap_body_passes_trailers() {
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);
        let source: HttpBody = DataThenTrailers { phase: 0 }.boxed();
        let mut tap = TapBody::new(
            source,
            "test-flow".to_string(),
            tx,
            Direction::ServerToClient,
            4096,
            vec![],
        );

        // The mock body is always ready, so a no-op waker is sufficient.
        let waker = Waker::noop();
        let mut cx = Context::from_waker(&waker);

        let mut data_frames = 0;
        let mut seen_trailers: Vec<hyper::HeaderMap> = Vec::new();

        loop {
            let frame = match Pin::new(&mut tap).poll_frame(&mut cx) {
                Poll::Pending => panic!("unexpected pending"),
                Poll::Ready(None) => break,
                Poll::Ready(Some(Err(e))) => panic!("unexpected error: {}", e),
                Poll::Ready(Some(Ok(frame))) => frame,
            };

            if frame.data_ref().is_some() {
                data_frames += 1;
            }
            if let Some(t) = frame.trailers_ref() {
                seen_trailers.push(t.clone());
            }
        }

        assert_eq!(data_frames, 1, "should forward 1 data frame");
        assert_eq!(seen_trailers.len(), 1, "should forward 1 trailers frame");
        let trailers = seen_trailers.pop().expect("trailers should be present");
        assert_eq!(
            trailers.get("x-trailer").and_then(|v| v.to_str().ok()),
            Some("value"),
            "trailer x-trailer should be preserved"
        );

        match rx.try_recv().expect("should emit HttpBody event") {
            FlowUpdate::HttpBody { body, .. } => {
                assert_eq!(body.size, 5, "body size should match data");
            }
            other => panic!("expected HttpBody, got {:?}", other),
        }
    }
}