1use brotli::enc::backward_references::BrotliEncoderParams;
2use brotli::enc::encode::{BrotliEncoderOperation, BrotliEncoderStateStruct};
3use brotli::enc::StandardAlloc;
4use bytes::Bytes;
5use headers::Header;
6use http_body_util::{combinators::BoxBody, BodyExt, StreamBody};
7use http_encoding_headers::{AcceptEncoding, Encoding};
8use hyper::body::Frame;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use tokio::sync::mpsc;
12use tokio_stream::wrappers::ReceiverStream;
13use tokio_stream::Stream;
14
/// Boxed, thread-safe error type used as the error of the compressed body stream.
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;

/// Brotli quality level applied to both the streaming encoder and one-shot compression.
const BROTLI_QUALITY: i32 = 4;
/// Length in bytes (16 KiB) of the scratch output buffer handed to the encoder per call.
const OUTBUF_CAP: usize = 16 << 10;
19
20#[must_use]
25pub fn accepts_brotli(headers: &hyper::header::HeaderMap) -> bool {
26 let Ok(accept) =
27 AcceptEncoding::decode(&mut headers.get_all(hyper::header::ACCEPT_ENCODING).iter())
28 else {
29 return false;
30 };
31 accept.preferred_allowed([Encoding::Br].iter()).is_some()
32}
33
34pub struct BrotliStream<S> {
36 inner: S,
37 encoder: BrotliEncoderStateStruct<StandardAlloc>,
38 out_scratch: Vec<u8>,
39 tmp: Vec<u8>,
40 finished: bool,
41}
42
43impl<S> BrotliStream<S> {
44 pub fn new(inner: S) -> Self {
45 let params = BrotliEncoderParams {
46 quality: BROTLI_QUALITY,
47 ..Default::default()
48 };
49
50 let mut encoder = BrotliEncoderStateStruct::new(StandardAlloc::default());
51 encoder.params = params;
52
53 Self {
54 inner,
55 encoder,
56 out_scratch: Vec::with_capacity(OUTBUF_CAP),
57 tmp: vec![0u8; OUTBUF_CAP],
58 finished: false,
59 }
60 }
61
62 fn encode(&mut self, input: &[u8], op: BrotliEncoderOperation) -> Result<Bytes, BoxError> {
64 self.out_scratch.clear();
65 let mut in_offset = 0usize;
66
67 loop {
68 let mut avail_in = input.len().saturating_sub(in_offset);
69 let mut avail_out = self.tmp.len();
70 let mut out_offset = 0usize;
71
72 let ok = self.encoder.compress_stream(
73 op,
74 &mut avail_in,
75 &input[in_offset..],
76 &mut in_offset,
77 &mut avail_out,
78 &mut self.tmp,
79 &mut out_offset,
80 &mut None,
81 &mut |_, _, _, _| (),
82 );
83
84 if !ok {
85 return Err("brotli compression failed".into());
86 }
87
88 if out_offset > 0 {
89 self.out_scratch.extend_from_slice(&self.tmp[..out_offset]);
90 }
91
92 let done = match op {
93 BrotliEncoderOperation::BROTLI_OPERATION_FINISH => self.encoder.is_finished(),
94 BrotliEncoderOperation::BROTLI_OPERATION_FLUSH => !self.encoder.has_more_output(),
95 BrotliEncoderOperation::BROTLI_OPERATION_PROCESS => {
96 in_offset >= input.len() && !self.encoder.has_more_output()
97 }
98 _ => unreachable!("unexpected Brotli operation"),
99 };
100
101 if done {
102 break;
103 }
104 }
105
106 let result = std::mem::replace(&mut self.out_scratch, Vec::with_capacity(OUTBUF_CAP));
108 Ok(Bytes::from(result))
109 }
110}
111
112impl<S> Stream for BrotliStream<S>
113where
114 S: Stream<Item = Vec<u8>> + Unpin,
115{
116 type Item = Result<Frame<Bytes>, BoxError>;
117
118 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
119 if self.finished {
120 return Poll::Ready(None);
121 }
122
123 match Pin::new(&mut self.inner).poll_next(cx) {
124 Poll::Ready(Some(chunk)) => {
125 match self.encode(&chunk, BrotliEncoderOperation::BROTLI_OPERATION_FLUSH) {
126 Ok(compressed) => {
127 if compressed.is_empty() {
128 cx.waker().wake_by_ref();
131 Poll::Pending
132 } else {
133 Poll::Ready(Some(Ok(Frame::data(compressed))))
134 }
135 }
136 Err(e) => Poll::Ready(Some(Err(e))),
137 }
138 }
139
140 Poll::Ready(None) => {
141 self.finished = true;
142 match self.encode(&[], BrotliEncoderOperation::BROTLI_OPERATION_FINISH) {
143 Ok(final_data) => {
144 if final_data.is_empty() {
145 Poll::Ready(None)
146 } else {
147 Poll::Ready(Some(Ok(Frame::data(final_data))))
148 }
149 }
150 Err(e) => Poll::Ready(Some(Err(e))),
151 }
152 }
153
154 Poll::Pending => Poll::Pending,
155 }
156 }
157}
158
159pub fn compress_stream(rx: mpsc::Receiver<Vec<u8>>) -> BoxBody<Bytes, BoxError> {
161 let stream = ReceiverStream::new(rx);
162 let brotli_stream = BrotliStream::new(stream);
163 StreamBody::new(brotli_stream).boxed()
164}
165
166pub fn compress_full(data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
168 let mut output = Vec::new();
169 let params = BrotliEncoderParams {
170 quality: BROTLI_QUALITY,
171 ..Default::default()
172 };
173 brotli::BrotliCompress(&mut &*data, &mut output, ¶ms)?;
174 Ok(output)
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use hyper::header::{HeaderMap, HeaderValue, ACCEPT_ENCODING};

    /// Builds a header map whose only entry is `Accept-Encoding: <value>`.
    fn with_accept_encoding(value: &'static str) -> HeaderMap {
        let mut map = HeaderMap::new();
        map.insert(ACCEPT_ENCODING, HeaderValue::from_static(value));
        map
    }

    #[test]
    fn test_accepts_brotli_simple() {
        assert!(accepts_brotli(&with_accept_encoding("gzip, deflate, br")));
    }

    #[test]
    fn test_rejects_brotli_quality_zero() {
        assert!(!accepts_brotli(&with_accept_encoding("gzip, br;q=0")));
    }

    #[test]
    fn test_no_brotli() {
        assert!(!accepts_brotli(&with_accept_encoding("gzip, deflate")));
    }

    #[test]
    fn test_no_accept_encoding_header() {
        assert!(!accepts_brotli(&HeaderMap::new()));
    }

    #[test]
    fn test_brotli_only() {
        assert!(accepts_brotli(&with_accept_encoding("br")));
    }
}