Skip to main content

relay_core_lib/proxy/
budget.rs

1use crate::interceptor::{BoxError, HttpBody};
2use bytes::Bytes;
3use hyper::body::{Body, Frame, SizeHint};
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7/// Wraps an HttpBody and accumulates up to `budget` bytes while
8/// passing through all data. Tracks whether the budget was exceeded.
9pub struct BudgetedBody {
10    inner: HttpBody,
11    budget: usize,
12    accumulated: Vec<u8>,
13    /// Set to true when accumulated bytes exceed the budget.
14    pub budget_exceeded: bool,
15    /// Total bytes passed through.
16    pub total_bytes: u64,
17    done: bool,
18}
19
20impl BudgetedBody {
21    pub fn new(inner: HttpBody, budget: usize) -> Self {
22        Self {
23            inner,
24            budget,
25            accumulated: Vec::with_capacity(budget.min(65_536)),
26            budget_exceeded: false,
27            total_bytes: 0,
28            done: false,
29        }
30    }
31
32    /// Returns a copy of the accumulated body data (up to budget).
33    pub fn accumulated_data(&self) -> Bytes {
34        Bytes::from(self.accumulated.clone())
35    }
36}
37
38impl Body for BudgetedBody {
39    type Data = Bytes;
40    type Error = BoxError;
41
42    fn poll_frame(
43        mut self: Pin<&mut Self>,
44        cx: &mut Context<'_>,
45    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
46        if self.done {
47            return Poll::Ready(None);
48        }
49
50        let frame = match Pin::new(&mut self.inner).poll_frame(cx) {
51            Poll::Ready(Some(Ok(frame))) => frame,
52            other => {
53                self.done = true;
54                return other;
55            }
56        };
57
58        // Track total bytes
59        if let Some(data) = frame.data_ref() {
60            self.total_bytes += data.len() as u64;
61
62            // Accumulate up to budget for rule inspection
63            let remaining = self.budget.saturating_sub(self.accumulated.len());
64            if remaining > 0 {
65                let take = data.len().min(remaining);
66                self.accumulated.extend_from_slice(&data[..take]);
67            }
68            if self.accumulated.len() >= self.budget {
69                self.budget_exceeded = true;
70            }
71        }
72
73        // Check for end-of-stream via trailers frame
74        if frame.is_trailers() {
75            self.done = true;
76        }
77
78        Poll::Ready(Some(Ok(frame)))
79    }
80
81    fn is_end_stream(&self) -> bool {
82        self.done || self.inner.is_end_stream()
83    }
84
85    fn size_hint(&self) -> SizeHint {
86        self.inner.size_hint()
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http_body_util::{BodyExt, Full};

    /// Builds a boxed `HttpBody` that yields `payload` as a single data frame.
    fn full_body(payload: Bytes) -> HttpBody {
        Full::new(payload)
            .map_err(|e| -> BoxError { Box::new(e) })
            .boxed()
    }

    /// A body smaller than the budget is forwarded and captured in full.
    #[tokio::test]
    async fn test_budgeted_body_within_limit() {
        let payload = Bytes::from("hello world");
        let mut wrapper = BudgetedBody::new(full_body(payload.clone()), 100);

        let forwarded = (&mut wrapper).collect().await.unwrap().to_bytes();

        assert_eq!(forwarded, payload);
        assert!(!wrapper.budget_exceeded);
        assert_eq!(wrapper.accumulated_data(), payload);
    }

    /// A body larger than the budget is still forwarded in full, but only its
    /// first `budget` bytes are captured.
    #[tokio::test]
    async fn test_budgeted_body_exceeds_limit() {
        let payload = Bytes::from("this is a long message that exceeds the budget");
        let mut wrapper = BudgetedBody::new(full_body(payload.clone()), 10);

        let forwarded = (&mut wrapper).collect().await.unwrap().to_bytes();

        assert_eq!(forwarded, payload);
        assert!(wrapper.budget_exceeded);
        assert_eq!(wrapper.accumulated_data(), Bytes::from("this is a "));
    }
}
124}