Skip to main content

relay_core_lib/proxy/
tap.rs

1use crate::interceptor::{BoxError, HttpBody};
2use crate::proxy::body_codec::process_body;
3use hyper::body::{Body, Bytes, Frame, SizeHint};
4use relay_core_api::flow::{BodyData, Direction, FlowUpdate};
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use tokio::sync::mpsc::Sender;
8
9pub struct TapBody {
10    inner: HttpBody,
11    flow_id: String,
12    on_flow: Sender<FlowUpdate>,
13    direction: Direction,
14    buffer: Vec<u8>,
15    limit: usize,
16    headers: Vec<(String, String)>,
17    /// Set to true when accumulated bytes exceed the limit.
18    pub budget_exceeded: bool,
19    /// Total bytes passed through.
20    pub total_bytes: u64,
21}
22
23impl TapBody {
24    pub fn new(
25        inner: HttpBody,
26        flow_id: String,
27        on_flow: Sender<FlowUpdate>,
28        direction: Direction,
29        limit: usize,
30        headers: Vec<(String, String)>,
31    ) -> Self {
32        crate::metrics::inc_proxy_stream_mode_tap();
33        Self {
34            inner,
35            flow_id,
36            on_flow,
37            direction,
38            buffer: Vec::new(),
39            limit,
40            headers,
41            budget_exceeded: false,
42            total_bytes: 0,
43        }
44    }
45}
46
47impl Body for TapBody {
48    type Data = Bytes;
49    type Error = BoxError;
50
51    fn poll_frame(
52        mut self: Pin<&mut Self>,
53        cx: &mut Context<'_>,
54    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
55        match Pin::new(&mut self.inner).poll_frame(cx) {
56            Poll::Ready(Some(Ok(frame))) => {
57                if let Some(data) = frame.data_ref() {
58                    self.total_bytes += data.len() as u64;
59                    if self.buffer.len() < self.limit {
60                        let len = std::cmp::min(data.len(), self.limit - self.buffer.len());
61                        self.buffer.extend_from_slice(&data[..len]);
62                    }
63                    if self.buffer.len() >= self.limit {
64                        self.budget_exceeded = true;
65                    }
66                }
67                Poll::Ready(Some(Ok(frame)))
68            }
69            Poll::Ready(None) => {
70                let (encoding, content) = process_body(&self.buffer, &self.headers);
71                let body_data = BodyData {
72                    encoding,
73                    content,
74                    size: self.total_bytes, // Report actual transfer size, not truncated buffer
75                };
76
77                let _ = self.on_flow.try_send(FlowUpdate::HttpBody {
78                    flow_id: self.flow_id.clone(),
79                    direction: self.direction.clone(),
80                    body: body_data,
81                });
82
83                // P1: Notify budget exceeded for streaming-first pipeline
84                if self.budget_exceeded {
85                    crate::metrics::inc_proxy_body_degraded();
86                    crate::metrics::inc_proxy_stream_mode_degrade();
87                    let _ = self.on_flow.try_send(FlowUpdate::BodyBudgetExceeded {
88                        flow_id: self.flow_id.clone(),
89                        direction: self.direction.clone(),
90                    });
91                }
92
93                Poll::Ready(None)
94            }
95            other => other,
96        }
97    }
98
99    fn is_end_stream(&self) -> bool {
100        self.inner.is_end_stream()
101    }
102
103    fn size_hint(&self) -> SizeHint {
104        self.inner.size_hint()
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http_body_util::BodyExt;
    use hyper::body::Frame;
    use relay_core_api::flow::Direction;
    use std::pin::Pin;
    use std::task::{Context, Poll, Waker};

    /// Simple Data + Trailers body used in both tests below. It yields one `"hello"` data
    /// frame, then one trailers frame (`x-trailer: value`), then end-of-stream.
    struct DataThenTrailers {
        phase: u8,
    }

    impl Body for DataThenTrailers {
        type Data = Bytes;
        type Error = BoxError;

        fn poll_frame(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            match self.phase {
                0 => {
                    self.phase = 1;
                    Poll::Ready(Some(Ok(Frame::data(Bytes::from("hello")))))
                }
                1 => {
                    self.phase = 2;
                    let mut trailers = hyper::HeaderMap::new();
                    trailers.insert("x-trailer", "value".parse().unwrap());
                    Poll::Ready(Some(Ok(Frame::trailers(trailers))))
                }
                _ => Poll::Ready(None),
            }
        }
    }

    /// Frames observed while draining a `TapBody`.
    #[derive(Default)]
    struct Drained {
        data_frames: usize,
        data: Vec<u8>,
        trailers: Vec<hyper::HeaderMap>,
    }

    /// Wraps a fresh `DataThenTrailers` body in a `TapBody` with the given capture limit.
    fn new_tap(limit: usize) -> (TapBody, tokio::sync::mpsc::Receiver<FlowUpdate>) {
        let body: HttpBody = DataThenTrailers { phase: 0 }.boxed();
        let (tx, rx) = tokio::sync::mpsc::channel(8);
        let tap = TapBody::new(
            body,
            "test-flow".to_string(),
            tx,
            Direction::ServerToClient,
            limit,
            vec![],
        );
        (tap, rx)
    }

    /// Polls `tap` to completion with a no-op waker. Panics on errors or `Pending`, which
    /// the test body never returns.
    fn drain(tap: &mut TapBody) -> Drained {
        let waker = Waker::noop();
        let mut cx = Context::from_waker(waker);
        let mut drained = Drained::default();

        loop {
            match Pin::new(&mut *tap).poll_frame(&mut cx) {
                Poll::Ready(Some(Ok(frame))) => {
                    if let Some(d) = frame.data_ref() {
                        drained.data_frames += 1;
                        drained.data.extend_from_slice(d);
                    }
                    if let Some(t) = frame.trailers_ref() {
                        drained.trailers.push(t.clone());
                    }
                }
                Poll::Ready(Some(Err(e))) => panic!("unexpected error: {}", e),
                Poll::Ready(None) => break drained,
                Poll::Pending => panic!("unexpected pending"),
            }
        }
    }

    /// Verify TapBody passes data and trailers through unchanged. Under budget it must
    /// emit exactly one HttpBody event and no BodyBudgetExceeded event.
    #[tokio::test]
    async fn test_tap_body_passes_trailers() {
        let (mut tap, mut rx) = new_tap(4096);
        let drained = drain(&mut tap);

        // Verify data and trailers forwarded
        assert_eq!(drained.data_frames, 1, "should forward 1 data frame");
        assert_eq!(drained.data, b"hello", "data should be forwarded intact");
        assert_eq!(drained.trailers.len(), 1, "should forward 1 trailers frame");
        assert_eq!(
            drained.trailers[0]
                .get("x-trailer")
                .and_then(|v| v.to_str().ok()),
            Some("value"),
            "trailer x-trailer should be preserved"
        );
        assert!(!tap.budget_exceeded, "5 bytes fit in a 4096-byte budget");
        assert_eq!(tap.total_bytes, 5);

        // Verify TapBody still sent HttpBody event
        match rx.try_recv().expect("should emit HttpBody event") {
            FlowUpdate::HttpBody { flow_id, body, .. } => {
                assert_eq!(flow_id, "test-flow");
                assert_eq!(body.size, 5, "body size should match data");
            }
            other => panic!("expected HttpBody, got {:?}", other),
        }
        assert!(
            rx.try_recv().is_err(),
            "no further events (e.g. BodyBudgetExceeded) expected under budget"
        );
    }

    /// Verify a body larger than the capture limit is still forwarded intact and reports
    /// its full transfer size. It must emit BodyBudgetExceeded after HttpBody.
    #[tokio::test]
    async fn test_tap_body_reports_budget_exceeded() {
        let (mut tap, mut rx) = new_tap(3);
        let drained = drain(&mut tap);

        assert_eq!(drained.data_frames, 1, "should forward 1 data frame");
        assert_eq!(drained.data, b"hello", "capture limit must not truncate forwarding");
        assert_eq!(drained.trailers.len(), 1, "should forward 1 trailers frame");
        assert!(tap.budget_exceeded, "5 bytes exceed a 3-byte budget");
        assert_eq!(tap.total_bytes, 5);

        match rx.try_recv().expect("should emit HttpBody event") {
            FlowUpdate::HttpBody { body, .. } => {
                assert_eq!(body.size, 5, "size should report transferred bytes");
            }
            other => panic!("expected HttpBody, got {:?}", other),
        }
        match rx.try_recv().expect("should emit BodyBudgetExceeded event") {
            FlowUpdate::BodyBudgetExceeded { flow_id, .. } => {
                assert_eq!(flow_id, "test-flow");
            }
            other => panic!("expected BodyBudgetExceeded, got {:?}", other),
        }
    }
}