relay_core_lib/proxy/
budget.rs1use crate::interceptor::{BoxError, HttpBody};
2use bytes::Bytes;
3use hyper::body::{Body, Frame, SizeHint};
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7pub struct BudgetedBody {
10 inner: HttpBody,
11 budget: usize,
12 accumulated: Vec<u8>,
13 pub budget_exceeded: bool,
15 pub total_bytes: u64,
17 done: bool,
18}
19
20impl BudgetedBody {
21 pub fn new(inner: HttpBody, budget: usize) -> Self {
22 Self {
23 inner,
24 budget,
25 accumulated: Vec::with_capacity(budget.min(65_536)),
26 budget_exceeded: false,
27 total_bytes: 0,
28 done: false,
29 }
30 }
31
32 pub fn accumulated_data(&self) -> Bytes {
34 Bytes::from(self.accumulated.clone())
35 }
36}
37
38impl Body for BudgetedBody {
39 type Data = Bytes;
40 type Error = BoxError;
41
42 fn poll_frame(
43 mut self: Pin<&mut Self>,
44 cx: &mut Context<'_>,
45 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
46 if self.done {
47 return Poll::Ready(None);
48 }
49
50 let frame = match Pin::new(&mut self.inner).poll_frame(cx) {
51 Poll::Ready(Some(Ok(frame))) => frame,
52 other => {
53 self.done = true;
54 return other;
55 }
56 };
57
58 if let Some(data) = frame.data_ref() {
60 self.total_bytes += data.len() as u64;
61
62 let remaining = self.budget.saturating_sub(self.accumulated.len());
64 if remaining > 0 {
65 let take = data.len().min(remaining);
66 self.accumulated.extend_from_slice(&data[..take]);
67 }
68 if self.accumulated.len() >= self.budget {
69 self.budget_exceeded = true;
70 }
71 }
72
73 if frame.is_trailers() {
75 self.done = true;
76 }
77
78 Poll::Ready(Some(Ok(frame)))
79 }
80
81 fn is_end_stream(&self) -> bool {
82 self.done || self.inner.is_end_stream()
83 }
84
85 fn size_hint(&self) -> SizeHint {
86 self.inner.size_hint()
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http_body_util::{BodyExt, Full};

    /// Builds a boxed single-frame body holding `data`.
    fn full_body(data: Bytes) -> HttpBody {
        Full::new(data)
            .map_err(|e| -> BoxError { Box::new(e) })
            .boxed()
    }

    #[tokio::test]
    async fn test_budgeted_body_within_limit() {
        let payload = Bytes::from("hello world");
        let mut budgeted = BudgetedBody::new(full_body(payload.clone()), 100);

        let forwarded = (&mut budgeted).collect().await.unwrap().to_bytes();

        // The whole body fits, so it is both forwarded and captured in full.
        assert_eq!(forwarded, payload);
        assert!(!budgeted.budget_exceeded);
        assert_eq!(budgeted.accumulated_data(), payload);
    }

    #[tokio::test]
    async fn test_budgeted_body_exceeds_limit() {
        let payload = Bytes::from("this is a long message that exceeds the budget");
        let mut budgeted = BudgetedBody::new(full_body(payload.clone()), 10);

        let forwarded = (&mut budgeted).collect().await.unwrap().to_bytes();

        // The forwarded body is never truncated, even past the budget...
        assert_eq!(forwarded, payload);
        // ...but only the first `budget` bytes end up in the capture.
        assert!(budgeted.budget_exceeded);
        assert_eq!(budgeted.accumulated_data(), Bytes::from("this is a "));
    }
}