forest/rpc/
compression_layer.rs1use std::{
11 env,
12 future::Future,
13 pin::Pin,
14 sync::LazyLock,
15 task::{Context, Poll},
16};
17
18use jsonrpsee::server::{HttpBody, HttpRequest, HttpResponse};
19use tower::{Layer, Service};
20use tower_http::compression::predicate::SizeAbove;
21use tower_http::compression::{Compression, CompressionLayer as TowerCompressionLayer};
22
23const COMPRESS_MIN_BODY_SIZE_VAR: &str = "FOREST_RPC_COMPRESS_MIN_BODY_SIZE";
24
25pub(crate) static COMPRESS_MIN_BODY_SIZE: LazyLock<Option<u16>> = LazyLock::new(|| {
37 parse_compress_min_body_size(env::var(COMPRESS_MIN_BODY_SIZE_VAR).ok().as_deref())
38});
39
40fn parse_compress_min_body_size(raw: Option<&str>) -> Option<u16> {
47 const DEFAULT: u16 = 1024;
50 let Some(raw) = raw else {
51 return Some(DEFAULT);
52 };
53 let Ok(parsed) = raw.parse::<i128>() else {
57 tracing::warn!(
58 "{COMPRESS_MIN_BODY_SIZE_VAR}={raw:?} is not a valid integer; \
59 falling back to default ({DEFAULT} bytes)"
60 );
61 return Some(DEFAULT);
62 };
63 if parsed < 0 {
64 return None;
65 }
66 let max = i128::from(u16::MAX);
67 if parsed > max {
68 tracing::warn!(
69 "{COMPRESS_MIN_BODY_SIZE_VAR}={parsed} exceeds the maximum of {max}; \
70 clamping to {max} bytes"
71 );
72 }
73 Some(u16::try_from(parsed.min(max)).expect("bounded above to u16::MAX"))
75}
76
77#[derive(Clone)]
79pub(crate) struct CompressionLayer {
80 inner: TowerCompressionLayer<SizeAbove>,
81}
82
83impl CompressionLayer {
84 pub(crate) fn new(min_body_size: u16) -> Self {
86 Self {
87 inner: TowerCompressionLayer::new().compress_when(SizeAbove::new(min_body_size)),
88 }
89 }
90}
91
92impl<S> Layer<S> for CompressionLayer {
93 type Service = CompressionService<S>;
94
95 fn layer(&self, inner: S) -> Self::Service {
96 CompressionService {
97 inner: self.inner.layer(inner),
98 }
99 }
100}
101
102#[derive(Clone)]
103pub(crate) struct CompressionService<S> {
104 inner: Compression<S, SizeAbove>,
105}
106
107impl<S, ReqBody> Service<HttpRequest<ReqBody>> for CompressionService<S>
108where
109 S: Service<HttpRequest<ReqBody>, Response = HttpResponse>,
110 S::Future: Send + 'static,
111{
112 type Response = HttpResponse;
113 type Error = S::Error;
114 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
115
116 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
117 self.inner.poll_ready(cx)
118 }
119
120 fn call(&mut self, req: HttpRequest<ReqBody>) -> Self::Future {
121 let fut = self.inner.call(req);
122 Box::pin(async move {
123 let resp = fut.await?;
125 let (parts, compressed_body) = resp.into_parts();
126 Ok(Self::Response::from_parts(
127 parts,
128 HttpBody::new(compressed_body),
129 ))
130 })
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use http::header::{ACCEPT_ENCODING, CONTENT_ENCODING};
    use std::{convert::Infallible, future::ready};

    const TEST_DATA: &str = "cthulhu fhtagn ";
    const REPEAT_COUNT: usize = 1000;

    /// Inner service that ignores the request and always answers with
    /// `TEST_DATA` repeated `REPEAT_COUNT` times (a highly compressible body).
    #[derive(Clone)]
    struct MockService;

    impl Service<HttpRequest> for MockService {
        type Response = HttpResponse;
        type Error = Infallible;
        type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: HttpRequest) -> Self::Future {
            let payload = HttpBody::from(TEST_DATA.repeat(REPEAT_COUNT));
            ready(Ok(HttpResponse::builder().body(payload).unwrap()))
        }
    }

    /// Builds an empty-bodied request, optionally advertising
    /// `Accept-Encoding: gzip`.
    fn request(accept_gzip: bool) -> HttpRequest {
        let builder = HttpRequest::builder();
        let builder = if accept_gzip {
            builder.header(ACCEPT_ENCODING, "gzip")
        } else {
            builder
        };
        builder.body(HttpBody::empty()).unwrap()
    }

    /// Drains the response body and returns its length in bytes.
    async fn body_size(resp: HttpResponse) -> usize {
        let body = axum::body::Body::new(resp.into_body());
        axum::body::to_bytes(body, usize::MAX).await.unwrap().len()
    }

    /// Length of the payload produced by [`MockService`] before compression.
    fn uncompressed_size() -> usize {
        TEST_DATA.len() * REPEAT_COUNT
    }

    #[tokio::test]
    async fn gzip_compresses_when_requested() {
        let mut svc = CompressionLayer::new(0).layer(MockService);
        let resp = svc.call(request(true)).await.unwrap();
        assert_eq!(resp.headers().get(CONTENT_ENCODING).unwrap(), "gzip");
        assert!(body_size(resp).await < uncompressed_size());
    }

    #[tokio::test]
    async fn passthrough_when_encoding_not_requested() {
        let mut svc = CompressionLayer::new(0).layer(MockService);
        let resp = svc.call(request(false)).await.unwrap();
        assert!(resp.headers().get(CONTENT_ENCODING).is_none());
        assert_eq!(body_size(resp).await, uncompressed_size());
    }

    #[tokio::test]
    async fn below_threshold_is_not_compressed() {
        let mut svc = CompressionLayer::new(u16::MAX).layer(MockService);
        let resp = svc.call(request(true)).await.unwrap();
        assert!(resp.headers().get(CONTENT_ENCODING).is_none());
        assert_eq!(body_size(resp).await, uncompressed_size());
    }

    #[test]
    fn parse_defaults_when_unset() {
        assert_eq!(parse_compress_min_body_size(None), Some(1024));
    }

    #[test]
    fn parse_negative_disables() {
        let inputs = ["-1", "-999999", "-2147483648", "-9223372036854775808"];
        for raw in inputs {
            assert_eq!(
                parse_compress_min_body_size(Some(raw)),
                None,
                "input {raw:?}"
            );
        }
    }

    #[test]
    fn parse_accepts_in_range_values() {
        let cases = [("0", 0), ("512", 512), ("1024", 1024), ("65535", u16::MAX)];
        for (raw, expected) in cases {
            assert_eq!(
                parse_compress_min_body_size(Some(raw)),
                Some(expected),
                "input {raw:?}"
            );
        }
    }

    #[test]
    fn parse_clamps_above_u16_max() {
        let inputs = [
            "65536",
            "1000000",
            "2147483647",
            "99999999999",
            "9223372036854775807",
        ];
        for raw in inputs {
            assert_eq!(
                parse_compress_min_body_size(Some(raw)),
                Some(u16::MAX),
                "input {raw:?}"
            );
        }
    }

    #[test]
    fn parse_invalid_falls_back_to_default() {
        for raw in ["", "lots", "1.5"] {
            assert_eq!(
                parse_compress_min_body_size(Some(raw)),
                Some(1024),
                "input {raw:?}"
            );
        }
    }
}