Skip to main content

http_nu/
compression.rs

1use brotli::enc::backward_references::BrotliEncoderParams;
2use brotli::enc::encode::{BrotliEncoderOperation, BrotliEncoderStateStruct};
3use brotli::enc::StandardAlloc;
4use bytes::Bytes;
5use headers::Header;
6use http_body_util::{combinators::BoxBody, BodyExt, StreamBody};
7use http_encoding_headers::{AcceptEncoding, Encoding};
8use hyper::body::Frame;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use tokio::sync::mpsc;
12use tokio_stream::wrappers::ReceiverStream;
13use tokio_stream::{Stream, StreamExt};
14
/// Boxed, thread-safe, owned error type used for stream and body errors in this module.
type BoxError = Box<dyn std::error::Error + Send + Sync>;

/// Brotli quality level applied to both the streaming and the eager compressor.
const BROTLI_QUALITY: i32 = 4;
/// Capacity in bytes (16 KiB) used to size the encoder's output buffers.
const OUTBUF_CAP: usize = 16 << 10;
19
20/// Check if the request accepts brotli encoding.
21///
22/// Parses the `Accept-Encoding` header respecting quality values.
23/// Returns `true` only if `br` is present with quality > 0.
24#[must_use]
25pub fn accepts_brotli(headers: &hyper::header::HeaderMap) -> bool {
26    let Ok(accept) =
27        AcceptEncoding::decode(&mut headers.get_all(hyper::header::ACCEPT_ENCODING).iter())
28    else {
29        return false;
30    };
31    accept.preferred_allowed([Encoding::Br].iter()).is_some()
32}
33
34/// A streaming brotli compressor that flushes per chunk.
35pub struct BrotliStream<S> {
36    inner: S,
37    encoder: BrotliEncoderStateStruct<StandardAlloc>,
38    out_scratch: Vec<u8>,
39    tmp: Vec<u8>,
40    finished: bool,
41}
42
43impl<S> BrotliStream<S> {
44    pub fn new(inner: S) -> Self {
45        let params = BrotliEncoderParams {
46            quality: BROTLI_QUALITY,
47            ..Default::default()
48        };
49
50        let mut encoder = BrotliEncoderStateStruct::new(StandardAlloc::default());
51        encoder.params = params;
52
53        Self {
54            inner,
55            encoder,
56            out_scratch: Vec::with_capacity(OUTBUF_CAP),
57            tmp: vec![0u8; OUTBUF_CAP],
58            finished: false,
59        }
60    }
61
62    /// Unified Brotli driver for PROCESS/FLUSH/FINISH.
63    fn encode(&mut self, input: &[u8], op: BrotliEncoderOperation) -> Result<Bytes, BoxError> {
64        self.out_scratch.clear();
65        let mut in_offset = 0usize;
66
67        loop {
68            let mut avail_in = input.len().saturating_sub(in_offset);
69            let mut avail_out = self.tmp.len();
70            let mut out_offset = 0usize;
71
72            let ok = self.encoder.compress_stream(
73                op,
74                &mut avail_in,
75                &input[in_offset..],
76                &mut in_offset,
77                &mut avail_out,
78                &mut self.tmp,
79                &mut out_offset,
80                &mut None,
81                &mut |_, _, _, _| (),
82            );
83
84            if !ok {
85                return Err("brotli compression failed".into());
86            }
87
88            if out_offset > 0 {
89                self.out_scratch.extend_from_slice(&self.tmp[..out_offset]);
90            }
91
92            let done = match op {
93                BrotliEncoderOperation::BROTLI_OPERATION_FINISH => self.encoder.is_finished(),
94                BrotliEncoderOperation::BROTLI_OPERATION_FLUSH => !self.encoder.has_more_output(),
95                BrotliEncoderOperation::BROTLI_OPERATION_PROCESS => {
96                    in_offset >= input.len() && !self.encoder.has_more_output()
97                }
98                _ => unreachable!("unexpected Brotli operation"),
99            };
100
101            if done {
102                break;
103            }
104        }
105
106        // Take ownership while preserving capacity for next call
107        let result = std::mem::replace(&mut self.out_scratch, Vec::with_capacity(OUTBUF_CAP));
108        Ok(Bytes::from(result))
109    }
110}
111
112impl<S> Stream for BrotliStream<S>
113where
114    S: Stream<Item = Result<Vec<u8>, BoxError>> + Unpin,
115{
116    type Item = Result<Frame<Bytes>, BoxError>;
117
118    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
119        if self.finished {
120            return Poll::Ready(None);
121        }
122
123        // Drain ready chunks with PROCESS (lets brotli batch for compression),
124        // then FLUSH once at the boundary -- when the source goes idle, ends,
125        // or the per-poll budget is hit. One brotli sync-point per burst
126        // instead of per chunk.
127        const DRAIN_BUDGET: usize = 32;
128        let mut accumulated: Vec<u8> = Vec::new();
129        let mut drained = 0usize;
130
131        loop {
132            match Pin::new(&mut self.inner).poll_next(cx) {
133                Poll::Ready(Some(Ok(chunk))) => {
134                    match self.encode(&chunk, BrotliEncoderOperation::BROTLI_OPERATION_PROCESS) {
135                        Ok(out) => accumulated.extend_from_slice(&out),
136                        Err(e) => return Poll::Ready(Some(Err(e))),
137                    }
138                    drained += 1;
139                    if drained >= DRAIN_BUDGET {
140                        // Cap the loop so a chatty source can't starve other
141                        // tasks. Flush what we have and yield.
142                        match self.encode(&[], BrotliEncoderOperation::BROTLI_OPERATION_FLUSH) {
143                            Ok(out) => accumulated.extend_from_slice(&out),
144                            Err(e) => return Poll::Ready(Some(Err(e))),
145                        }
146                        cx.waker().wake_by_ref();
147                        return Poll::Ready(Some(Ok(Frame::data(Bytes::from(accumulated)))));
148                    }
149                }
150
151                // Source idle: flush accumulated data so the client gets it
152                // immediately. If nothing is buffered, we're truly Pending --
153                // the inner already registered our waker.
154                Poll::Pending => {
155                    match self.encode(&[], BrotliEncoderOperation::BROTLI_OPERATION_FLUSH) {
156                        Ok(out) => accumulated.extend_from_slice(&out),
157                        Err(e) => return Poll::Ready(Some(Err(e))),
158                    }
159                    if accumulated.is_empty() {
160                        return Poll::Pending;
161                    }
162                    return Poll::Ready(Some(Ok(Frame::data(Bytes::from(accumulated)))));
163                }
164
165                // Inner errored (e.g. SSE cancel): propagate the error, mark
166                // ourselves finished so the next poll yields None. We do NOT
167                // run the brotli FINISH op -- the deliberately truncated body
168                // lets the client see this as a fetch error and auto-retry.
169                Poll::Ready(Some(Err(e))) => {
170                    self.finished = true;
171                    return Poll::Ready(Some(Err(e)));
172                }
173
174                Poll::Ready(None) => {
175                    self.finished = true;
176                    match self.encode(&[], BrotliEncoderOperation::BROTLI_OPERATION_FINISH) {
177                        Ok(out) => accumulated.extend_from_slice(&out),
178                        Err(e) => return Poll::Ready(Some(Err(e))),
179                    }
180                    if accumulated.is_empty() {
181                        return Poll::Ready(None);
182                    }
183                    return Poll::Ready(Some(Ok(Frame::data(Bytes::from(accumulated)))));
184                }
185            }
186        }
187    }
188}
189
190/// Wrap a streaming response body with brotli compression.
191pub fn compress_stream(rx: mpsc::Receiver<Vec<u8>>) -> BoxBody<Bytes, BoxError> {
192    let stream = ReceiverStream::new(rx).map(Ok::<Vec<u8>, BoxError>);
193    let brotli_stream = BrotliStream::new(stream);
194    StreamBody::new(brotli_stream).boxed()
195}
196
197/// Compress an entire body eagerly.
198pub fn compress_full(data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
199    let mut output = Vec::new();
200    let params = BrotliEncoderParams {
201        quality: BROTLI_QUALITY,
202        ..Default::default()
203    };
204    brotli::BrotliCompress(&mut &*data, &mut output, &params)?;
205    Ok(output)
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use hyper::header::{HeaderMap, HeaderValue, ACCEPT_ENCODING};

    /// Build a header map carrying a single `Accept-Encoding` value.
    fn with_accept_encoding(value: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(ACCEPT_ENCODING, HeaderValue::from_static(value));
        headers
    }

    #[test]
    fn test_accepts_brotli_simple() {
        let headers = with_accept_encoding("gzip, deflate, br");
        assert!(accepts_brotli(&headers));
    }

    #[test]
    fn test_rejects_brotli_quality_zero() {
        let headers = with_accept_encoding("gzip, br;q=0");
        assert!(!accepts_brotli(&headers));
    }

    #[test]
    fn test_no_brotli() {
        let headers = with_accept_encoding("gzip, deflate");
        assert!(!accepts_brotli(&headers));
    }

    #[test]
    fn test_no_accept_encoding_header() {
        // No header at all must be treated as "brotli not accepted".
        assert!(!accepts_brotli(&HeaderMap::new()));
    }

    #[test]
    fn test_brotli_only() {
        let headers = with_accept_encoding("br");
        assert!(accepts_brotli(&headers));
    }
}