// http_nu/compression.rs

1use brotli::enc::backward_references::BrotliEncoderParams;
2use brotli::enc::encode::{BrotliEncoderOperation, BrotliEncoderStateStruct};
3use brotli::enc::StandardAlloc;
4use bytes::Bytes;
5use headers::Header;
6use http_body_util::{combinators::BoxBody, BodyExt, StreamBody};
7use http_encoding_headers::{AcceptEncoding, Encoding};
8use hyper::body::Frame;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use tokio::sync::mpsc;
12use tokio_stream::wrappers::ReceiverStream;
13use tokio_stream::{Stream, StreamExt};
14
/// Boxed, thread-safe error type used by the streaming compression code.
type BoxError = Box<dyn std::error::Error + Send + Sync>;

/// Brotli quality level applied to both the streaming and the one-shot encoder.
const BROTLI_QUALITY: i32 = 4;
/// Size in bytes (16 KiB) of the compressor's output buffers.
const OUTBUF_CAP: usize = 16 << 10;
19
20/// Check if the request accepts brotli encoding.
21///
22/// Parses the `Accept-Encoding` header respecting quality values.
23/// Returns `true` only if `br` is present with quality > 0.
24#[must_use]
25pub fn accepts_brotli(headers: &hyper::header::HeaderMap) -> bool {
26    let Ok(accept) =
27        AcceptEncoding::decode(&mut headers.get_all(hyper::header::ACCEPT_ENCODING).iter())
28    else {
29        return false;
30    };
31    accept.preferred_allowed([Encoding::Br].iter()).is_some()
32}
33
34/// A streaming brotli compressor that flushes per chunk.
35pub struct BrotliStream<S> {
36    inner: S,
37    encoder: BrotliEncoderStateStruct<StandardAlloc>,
38    out_scratch: Vec<u8>,
39    tmp: Vec<u8>,
40    finished: bool,
41}
42
43impl<S> BrotliStream<S> {
44    pub fn new(inner: S) -> Self {
45        let params = BrotliEncoderParams {
46            quality: BROTLI_QUALITY,
47            ..Default::default()
48        };
49
50        let mut encoder = BrotliEncoderStateStruct::new(StandardAlloc::default());
51        encoder.params = params;
52
53        Self {
54            inner,
55            encoder,
56            out_scratch: Vec::with_capacity(OUTBUF_CAP),
57            tmp: vec![0u8; OUTBUF_CAP],
58            finished: false,
59        }
60    }
61
62    /// Unified Brotli driver for PROCESS/FLUSH/FINISH.
63    fn encode(&mut self, input: &[u8], op: BrotliEncoderOperation) -> Result<Bytes, BoxError> {
64        self.out_scratch.clear();
65        let mut in_offset = 0usize;
66
67        loop {
68            let mut avail_in = input.len().saturating_sub(in_offset);
69            let mut avail_out = self.tmp.len();
70            let mut out_offset = 0usize;
71
72            let ok = self.encoder.compress_stream(
73                op,
74                &mut avail_in,
75                &input[in_offset..],
76                &mut in_offset,
77                &mut avail_out,
78                &mut self.tmp,
79                &mut out_offset,
80                &mut None,
81                &mut |_, _, _, _| (),
82            );
83
84            if !ok {
85                return Err("brotli compression failed".into());
86            }
87
88            if out_offset > 0 {
89                self.out_scratch.extend_from_slice(&self.tmp[..out_offset]);
90            }
91
92            let done = match op {
93                BrotliEncoderOperation::BROTLI_OPERATION_FINISH => self.encoder.is_finished(),
94                BrotliEncoderOperation::BROTLI_OPERATION_FLUSH => !self.encoder.has_more_output(),
95                BrotliEncoderOperation::BROTLI_OPERATION_PROCESS => {
96                    in_offset >= input.len() && !self.encoder.has_more_output()
97                }
98                _ => unreachable!("unexpected Brotli operation"),
99            };
100
101            if done {
102                break;
103            }
104        }
105
106        // Take ownership while preserving capacity for next call
107        let result = std::mem::replace(&mut self.out_scratch, Vec::with_capacity(OUTBUF_CAP));
108        Ok(Bytes::from(result))
109    }
110}
111
112impl<S> Stream for BrotliStream<S>
113where
114    S: Stream<Item = Result<Vec<u8>, BoxError>> + Unpin,
115{
116    type Item = Result<Frame<Bytes>, BoxError>;
117
118    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
119        if self.finished {
120            return Poll::Ready(None);
121        }
122
123        match Pin::new(&mut self.inner).poll_next(cx) {
124            Poll::Ready(Some(Ok(chunk))) => {
125                match self.encode(&chunk, BrotliEncoderOperation::BROTLI_OPERATION_FLUSH) {
126                    Ok(compressed) => {
127                        if compressed.is_empty() {
128                            // FLUSH on non-empty input should always produce output,
129                            // but handle defensively
130                            cx.waker().wake_by_ref();
131                            Poll::Pending
132                        } else {
133                            Poll::Ready(Some(Ok(Frame::data(compressed))))
134                        }
135                    }
136                    Err(e) => Poll::Ready(Some(Err(e))),
137                }
138            }
139
140            // Inner errored (e.g. SSE cancel): propagate the error, mark
141            // ourselves finished so the next poll yields None. We do NOT run
142            // the brotli FINISH op -- the deliberately truncated body lets
143            // the client see this as a fetch error and auto-retry.
144            Poll::Ready(Some(Err(e))) => {
145                self.finished = true;
146                Poll::Ready(Some(Err(e)))
147            }
148
149            Poll::Ready(None) => {
150                self.finished = true;
151                match self.encode(&[], BrotliEncoderOperation::BROTLI_OPERATION_FINISH) {
152                    Ok(final_data) => {
153                        if final_data.is_empty() {
154                            Poll::Ready(None)
155                        } else {
156                            Poll::Ready(Some(Ok(Frame::data(final_data))))
157                        }
158                    }
159                    Err(e) => Poll::Ready(Some(Err(e))),
160                }
161            }
162
163            Poll::Pending => Poll::Pending,
164        }
165    }
166}
167
168/// Wrap a streaming response body with brotli compression.
169pub fn compress_stream(rx: mpsc::Receiver<Vec<u8>>) -> BoxBody<Bytes, BoxError> {
170    let stream = ReceiverStream::new(rx).map(Ok::<Vec<u8>, BoxError>);
171    let brotli_stream = BrotliStream::new(stream);
172    StreamBody::new(brotli_stream).boxed()
173}
174
175/// Compress an entire body eagerly.
176pub fn compress_full(data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
177    let mut output = Vec::new();
178    let params = BrotliEncoderParams {
179        quality: BROTLI_QUALITY,
180        ..Default::default()
181    };
182    brotli::BrotliCompress(&mut &*data, &mut output, &params)?;
183    Ok(output)
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use hyper::header::{HeaderMap, HeaderValue, ACCEPT_ENCODING};

    /// Build a header map carrying a single `Accept-Encoding` value.
    fn with_accept_encoding(value: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(ACCEPT_ENCODING, HeaderValue::from_static(value));
        headers
    }

    #[test]
    fn test_accepts_brotli_simple() {
        let headers = with_accept_encoding("gzip, deflate, br");
        assert!(accepts_brotli(&headers));
    }

    #[test]
    fn test_rejects_brotli_quality_zero() {
        let headers = with_accept_encoding("gzip, br;q=0");
        assert!(!accepts_brotli(&headers));
    }

    #[test]
    fn test_no_brotli() {
        let headers = with_accept_encoding("gzip, deflate");
        assert!(!accepts_brotli(&headers));
    }

    #[test]
    fn test_no_accept_encoding_header() {
        // No header at all must be treated as "brotli not accepted".
        let headers = HeaderMap::new();
        assert!(!accepts_brotli(&headers));
    }

    #[test]
    fn test_brotli_only() {
        let headers = with_accept_encoding("br");
        assert!(accepts_brotli(&headers));
    }
}