Skip to main content

http_nu/
compression.rs

1use brotli::enc::backward_references::BrotliEncoderParams;
2use brotli::enc::encode::{BrotliEncoderOperation, BrotliEncoderStateStruct};
3use brotli::enc::StandardAlloc;
4use bytes::Bytes;
5use headers::Header;
6use http_body_util::{combinators::BoxBody, BodyExt, StreamBody};
7use http_encoding_headers::{AcceptEncoding, Encoding};
8use hyper::body::Frame;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use tokio::sync::mpsc;
12use tokio_stream::wrappers::ReceiverStream;
13use tokio_stream::Stream;
14
/// Boxed, thread-safe error type carried by the compressed body stream.
type BoxError = Box<dyn std::error::Error + Send + Sync>;

/// Brotli quality level used for both streaming and eager compression.
const BROTLI_QUALITY: i32 = 4;
/// Capacity, in bytes, of the encoder's scratch output buffers (16 KiB).
const OUTBUF_CAP: usize = 16 << 10;
19
20/// Check if the request accepts brotli encoding.
21///
22/// Parses the `Accept-Encoding` header respecting quality values.
23/// Returns `true` only if `br` is present with quality > 0.
24#[must_use]
25pub fn accepts_brotli(headers: &hyper::header::HeaderMap) -> bool {
26    let Ok(accept) =
27        AcceptEncoding::decode(&mut headers.get_all(hyper::header::ACCEPT_ENCODING).iter())
28    else {
29        return false;
30    };
31    accept.preferred_allowed([Encoding::Br].iter()).is_some()
32}
33
34/// A streaming brotli compressor that flushes per chunk.
35pub struct BrotliStream<S> {
36    inner: S,
37    encoder: BrotliEncoderStateStruct<StandardAlloc>,
38    out_scratch: Vec<u8>,
39    tmp: Vec<u8>,
40    finished: bool,
41}
42
43impl<S> BrotliStream<S> {
44    pub fn new(inner: S) -> Self {
45        let params = BrotliEncoderParams {
46            quality: BROTLI_QUALITY,
47            ..Default::default()
48        };
49
50        let mut encoder = BrotliEncoderStateStruct::new(StandardAlloc::default());
51        encoder.params = params;
52
53        Self {
54            inner,
55            encoder,
56            out_scratch: Vec::with_capacity(OUTBUF_CAP),
57            tmp: vec![0u8; OUTBUF_CAP],
58            finished: false,
59        }
60    }
61
62    /// Unified Brotli driver for PROCESS/FLUSH/FINISH.
63    fn encode(&mut self, input: &[u8], op: BrotliEncoderOperation) -> Result<Bytes, BoxError> {
64        self.out_scratch.clear();
65        let mut in_offset = 0usize;
66
67        loop {
68            let mut avail_in = input.len().saturating_sub(in_offset);
69            let mut avail_out = self.tmp.len();
70            let mut out_offset = 0usize;
71
72            let ok = self.encoder.compress_stream(
73                op,
74                &mut avail_in,
75                &input[in_offset..],
76                &mut in_offset,
77                &mut avail_out,
78                &mut self.tmp,
79                &mut out_offset,
80                &mut None,
81                &mut |_, _, _, _| (),
82            );
83
84            if !ok {
85                return Err("brotli compression failed".into());
86            }
87
88            if out_offset > 0 {
89                self.out_scratch.extend_from_slice(&self.tmp[..out_offset]);
90            }
91
92            let done = match op {
93                BrotliEncoderOperation::BROTLI_OPERATION_FINISH => self.encoder.is_finished(),
94                BrotliEncoderOperation::BROTLI_OPERATION_FLUSH => !self.encoder.has_more_output(),
95                BrotliEncoderOperation::BROTLI_OPERATION_PROCESS => {
96                    in_offset >= input.len() && !self.encoder.has_more_output()
97                }
98                _ => unreachable!("unexpected Brotli operation"),
99            };
100
101            if done {
102                break;
103            }
104        }
105
106        // Take ownership while preserving capacity for next call
107        let result = std::mem::replace(&mut self.out_scratch, Vec::with_capacity(OUTBUF_CAP));
108        Ok(Bytes::from(result))
109    }
110}
111
112impl<S> Stream for BrotliStream<S>
113where
114    S: Stream<Item = Vec<u8>> + Unpin,
115{
116    type Item = Result<Frame<Bytes>, BoxError>;
117
118    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
119        if self.finished {
120            return Poll::Ready(None);
121        }
122
123        match Pin::new(&mut self.inner).poll_next(cx) {
124            Poll::Ready(Some(chunk)) => {
125                match self.encode(&chunk, BrotliEncoderOperation::BROTLI_OPERATION_FLUSH) {
126                    Ok(compressed) => {
127                        if compressed.is_empty() {
128                            // FLUSH on non-empty input should always produce output,
129                            // but handle defensively
130                            cx.waker().wake_by_ref();
131                            Poll::Pending
132                        } else {
133                            Poll::Ready(Some(Ok(Frame::data(compressed))))
134                        }
135                    }
136                    Err(e) => Poll::Ready(Some(Err(e))),
137                }
138            }
139
140            Poll::Ready(None) => {
141                self.finished = true;
142                match self.encode(&[], BrotliEncoderOperation::BROTLI_OPERATION_FINISH) {
143                    Ok(final_data) => {
144                        if final_data.is_empty() {
145                            Poll::Ready(None)
146                        } else {
147                            Poll::Ready(Some(Ok(Frame::data(final_data))))
148                        }
149                    }
150                    Err(e) => Poll::Ready(Some(Err(e))),
151                }
152            }
153
154            Poll::Pending => Poll::Pending,
155        }
156    }
157}
158
159/// Wrap a streaming response body with brotli compression.
160pub fn compress_stream(rx: mpsc::Receiver<Vec<u8>>) -> BoxBody<Bytes, BoxError> {
161    let stream = ReceiverStream::new(rx);
162    let brotli_stream = BrotliStream::new(stream);
163    StreamBody::new(brotli_stream).boxed()
164}
165
166/// Compress an entire body eagerly.
167pub fn compress_full(data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
168    let mut output = Vec::new();
169    let params = BrotliEncoderParams {
170        quality: BROTLI_QUALITY,
171        ..Default::default()
172    };
173    brotli::BrotliCompress(&mut &*data, &mut output, &params)?;
174    Ok(output)
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use hyper::header::{HeaderMap, HeaderValue, ACCEPT_ENCODING};

    /// Build a header map holding a single `Accept-Encoding` value.
    fn with_accept_encoding(value: &'static str) -> HeaderMap {
        let mut map = HeaderMap::new();
        map.insert(ACCEPT_ENCODING, HeaderValue::from_static(value));
        map
    }

    #[test]
    fn test_accepts_brotli_simple() {
        assert!(accepts_brotli(&with_accept_encoding("gzip, deflate, br")));
    }

    #[test]
    fn test_rejects_brotli_quality_zero() {
        assert!(!accepts_brotli(&with_accept_encoding("gzip, br;q=0")));
    }

    #[test]
    fn test_no_brotli() {
        assert!(!accepts_brotli(&with_accept_encoding("gzip, deflate")));
    }

    #[test]
    fn test_no_accept_encoding_header() {
        assert!(!accepts_brotli(&HeaderMap::new()));
    }

    #[test]
    fn test_brotli_only() {
        assert!(accepts_brotli(&with_accept_encoding("br")));
    }
}