use std::{
env,
future::Future,
pin::Pin,
sync::LazyLock,
task::{Context, Poll},
};
use jsonrpsee::server::{HttpBody, HttpRequest, HttpResponse};
use tower::{Layer, Service};
use tower_http::compression::predicate::SizeAbove;
use tower_http::compression::{Compression, CompressionLayer as TowerCompressionLayer};
/// Name of the environment variable that configures the minimum response body
/// size (in bytes) used as the compression threshold.
const COMPRESS_MIN_BODY_SIZE_VAR: &str = "FOREST_RPC_COMPRESS_MIN_BODY_SIZE";

/// Compression threshold resolved from [`COMPRESS_MIN_BODY_SIZE_VAR`].
///
/// The environment is read once, on first access. An unset variable (or one
/// that `env::var` cannot return, e.g. non-Unicode content) is handed to the
/// parser as `None`; the interpretation of the value itself is delegated to
/// `parse_compress_min_body_size`.
pub(crate) static COMPRESS_MIN_BODY_SIZE: LazyLock<Option<u16>> = LazyLock::new(|| {
    let configured = env::var(COMPRESS_MIN_BODY_SIZE_VAR).ok();
    parse_compress_min_body_size(configured.as_deref())
});
/// Interprets the raw value of the compression-threshold environment variable.
///
/// Returns:
/// * `Some(1024)` when the variable is unset (`raw` is `None`);
/// * `Some(1024)`, after logging a warning, when the value is not an integer;
/// * `None` when the value is negative;
/// * `Some(u16::MAX)`, after logging a warning, when the value exceeds
///   `u16::MAX`;
/// * `Some(value)` otherwise.
///
/// Leading and trailing whitespace is ignored, so values such as `" 2048 "`
/// (easily produced by shell quoting or env files) are accepted. Integers too
/// large in magnitude for `i128` are treated like any other out-of-range
/// integer (clamped when positive, `None` when negative) instead of being
/// reported as "not a valid integer".
fn parse_compress_min_body_size(raw: Option<&str>) -> Option<u16> {
    const DEFAULT: u16 = 1024;

    let Some(raw) = raw else {
        return Some(DEFAULT);
    };

    let max = i128::from(u16::MAX);
    let trimmed = raw.trim();

    // Parse into a wide signed type so that negative and oversized inputs can
    // be told apart from garbage and handled explicitly below.
    let parsed = match trimmed.parse::<i128>() {
        Ok(parsed) => parsed,
        Err(err) => match err.kind() {
            // Still an integer, just below `i128::MIN`: negative.
            std::num::IntErrorKind::NegOverflow => return None,
            // Still an integer, just above `i128::MAX`: clamp like any other
            // value that is too large.
            std::num::IntErrorKind::PosOverflow => {
                tracing::warn!(
                    "{COMPRESS_MIN_BODY_SIZE_VAR}={trimmed} exceeds the maximum of {max}; \
                     clamping to {max} bytes"
                );
                return Some(u16::MAX);
            }
            _ => {
                tracing::warn!(
                    "{COMPRESS_MIN_BODY_SIZE_VAR}={raw:?} is not a valid integer; \
                     falling back to default ({DEFAULT} bytes)"
                );
                return Some(DEFAULT);
            }
        },
    };

    if parsed < 0 {
        return None;
    }

    if parsed > max {
        tracing::warn!(
            "{COMPRESS_MIN_BODY_SIZE_VAR}={parsed} exceeds the maximum of {max}; \
             clamping to {max} bytes"
        );
    }

    Some(u16::try_from(parsed.min(max)).expect("bounded above to u16::MAX"))
}
/// Tower layer that adds response compression, gated on response body size.
///
/// Thin wrapper around `tower_http`'s compression layer configured with a
/// [`SizeAbove`] predicate.
#[derive(Clone)]
pub(crate) struct CompressionLayer {
    /// The wrapped `tower_http` compression layer.
    inner: TowerCompressionLayer<SizeAbove>,
}

impl CompressionLayer {
    /// Creates a layer whose compression decision is made by
    /// `SizeAbove::new(min_body_size)`.
    pub(crate) fn new(min_body_size: u16) -> Self {
        let size_predicate = SizeAbove::new(min_body_size);
        let inner = TowerCompressionLayer::new().compress_when(size_predicate);
        Self { inner }
    }
}
impl<S> Layer<S> for CompressionLayer {
    type Service = CompressionService<S>;

    /// Wraps `service` with the underlying `tower_http` compression layer and
    /// returns it inside a [`CompressionService`].
    fn layer(&self, service: S) -> Self::Service {
        let inner = self.inner.layer(service);
        CompressionService { inner }
    }
}
#[derive(Clone)]
pub(crate) struct CompressionService<S> {
inner: Compression<S, SizeAbove>,
}
impl<S, ReqBody> Service<HttpRequest<ReqBody>> for CompressionService<S>
where
S: Service<HttpRequest<ReqBody>, Response = HttpResponse>,
S::Future: Send + 'static,
{
type Response = HttpResponse;
type Error = S::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: HttpRequest<ReqBody>) -> Self::Future {
let fut = self.inner.call(req);
Box::pin(async move {
let resp = fut.await?;
let (parts, compressed_body) = resp.into_parts();
Ok(Self::Response::from_parts(
parts,
HttpBody::new(compressed_body),
))
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use http::header::{ACCEPT_ENCODING, CONTENT_ENCODING};
    use std::{convert::Infallible, future::ready};

    /// Fragment repeated to build the mock response body.
    const TEST_DATA: &str = "cthulhu fhtagn ";
    /// Number of times [`TEST_DATA`] is repeated in the mock response body.
    const REPEAT_COUNT: usize = 1000;

    /// Inner service that is always ready and answers every request with
    /// `TEST_DATA` repeated `REPEAT_COUNT` times.
    #[derive(Clone)]
    struct MockService;

    impl Service<HttpRequest> for MockService {
        type Response = HttpResponse;
        type Error = Infallible;
        type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: HttpRequest) -> Self::Future {
            let payload = TEST_DATA.repeat(REPEAT_COUNT);
            let response = HttpResponse::builder()
                .body(HttpBody::from(payload))
                .unwrap();
            ready(Ok(response))
        }
    }

    /// Collects the whole response body and returns its length in bytes.
    async fn body_size(resp: HttpResponse) -> usize {
        let body = axum::body::Body::new(resp.into_body());
        let bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
        bytes.len()
    }

    /// Length in bytes of the body `MockService` produces.
    fn uncompressed_size() -> usize {
        TEST_DATA.len() * REPEAT_COUNT
    }

    /// Sends one empty-bodied request through a fresh compression service
    /// built with `min_body_size`, optionally setting `Accept-Encoding`.
    async fn send(min_body_size: u16, accept_encoding: Option<&str>) -> HttpResponse {
        let mut svc = CompressionLayer::new(min_body_size).layer(MockService);

        let mut builder = HttpRequest::builder();
        if let Some(encoding) = accept_encoding {
            builder = builder.header(ACCEPT_ENCODING, encoding);
        }
        let req = builder.body(HttpBody::empty()).unwrap();

        svc.call(req).await.unwrap()
    }

    #[tokio::test]
    async fn gzip_compresses_when_requested() {
        let resp = send(0, Some("gzip")).await;
        assert_eq!(resp.headers().get(CONTENT_ENCODING).unwrap(), "gzip");
        assert!(body_size(resp).await < uncompressed_size());
    }

    #[tokio::test]
    async fn passthrough_when_encoding_not_requested() {
        let resp = send(0, None).await;
        assert!(resp.headers().get(CONTENT_ENCODING).is_none());
        assert_eq!(body_size(resp).await, uncompressed_size());
    }

    #[tokio::test]
    async fn below_threshold_is_not_compressed() {
        // The mock body is smaller than `u16::MAX` bytes.
        let resp = send(u16::MAX, Some("gzip")).await;
        assert!(resp.headers().get(CONTENT_ENCODING).is_none());
        assert_eq!(body_size(resp).await, uncompressed_size());
    }

    #[test]
    fn parse_defaults_when_unset() {
        assert_eq!(parse_compress_min_body_size(None), Some(1024));
    }

    #[test]
    fn parse_negative_disables() {
        let inputs = ["-1", "-999999", "-2147483648", "-9223372036854775808"];
        for raw in inputs {
            assert_eq!(
                parse_compress_min_body_size(Some(raw)),
                None,
                "input {raw:?}"
            );
        }
    }

    #[test]
    fn parse_accepts_in_range_values() {
        let cases = [("0", 0), ("512", 512), ("1024", 1024), ("65535", u16::MAX)];
        for (raw, expected) in cases {
            assert_eq!(
                parse_compress_min_body_size(Some(raw)),
                Some(expected),
                "input {raw:?}"
            );
        }
    }

    #[test]
    fn parse_clamps_above_u16_max() {
        let inputs = [
            "65536",
            "1000000",
            "2147483647",
            "99999999999",
            "9223372036854775807",
        ];
        for raw in inputs {
            assert_eq!(
                parse_compress_min_body_size(Some(raw)),
                Some(u16::MAX),
                "input {raw:?}"
            );
        }
    }

    #[test]
    fn parse_invalid_falls_back_to_default() {
        for raw in ["", "lots", "1.5"] {
            assert_eq!(
                parse_compress_min_body_size(Some(raw)),
                Some(1024),
                "input {raw:?}"
            );
        }
    }
}