fast_dav_rs/common/
compression.rs

//! Compression utilities for HTTP content encoding.
//!
//! This module detects and undoes `Content-Encoding` on HTTP response bodies
//! (brotli, gzip, zstd), picks a request compression from a server's
//! `Accept-Encoding` header, and compresses outgoing request payloads.
5
6use anyhow::Result;
7use async_compression::tokio::bufread::{BrotliDecoder, GzipDecoder, ZstdDecoder};
8use bytes::Bytes;
9use futures_util::TryStreamExt;
10use http_body_util::BodyStream;
11use hyper::body::Incoming;
12use hyper::{HeaderMap, header, http};
13use std::io::Cursor;
14use tokio::io::{AsyncBufRead, AsyncReadExt, BufReader};
15use tokio_util::io::StreamReader;
16
/// Content codings this module can decode (and, except for `Identity`, encode).
///
/// Every variant corresponds to an HTTP `Content-Encoding` / `Accept-Encoding` token, available
/// through [`ContentEncoding::as_str`]. The decompression helpers use the variant to decide which
/// decoder to stack on top of a body reader.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContentEncoding {
    /// No transformation; token `identity`.
    Identity,
    /// Brotli; token `br`.
    Br,
    /// Gzip; token `gzip`.
    Gzip,
    /// Zstandard; token `zstd`.
    Zstd,
}

impl ContentEncoding {
    /// Return the lowercase HTTP token for this coding (for example `"gzip"`).
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Identity => "identity",
            Self::Br => "br",
            Self::Gzip => "gzip",
            Self::Zstd => "zstd",
        }
    }
}
39
40/// Detect the response `Content-Encoding` header and return the ordered chain of encodings.
41///
42/// The vector is ordered from outermost encoding to innermost (as received). When empty, the
43/// payload is identity encoded.
44pub fn detect_encodings(headers: &HeaderMap) -> Vec<ContentEncoding> {
45    let Some(val) = headers.get(header::CONTENT_ENCODING) else {
46        return Vec::new();
47    };
48
49    let Ok(raw) = val.to_str() else {
50        return Vec::new();
51    };
52
53    raw.split(',')
54        .filter_map(|token| {
55            let enc = token.trim().to_ascii_lowercase();
56            Some(match enc.as_str() {
57                "br" => ContentEncoding::Br,
58                "gzip" => ContentEncoding::Gzip,
59                "zstd" | "zst" => ContentEncoding::Zstd,
60                "identity" => return None,
61                _ => return None,
62            })
63        })
64        .collect()
65}
66
67/// Insert an `Accept-Encoding` header (`br, zstd, gzip`) if not already present.
68///
69/// This hints to the server that the client supports compressed responses.
70pub fn add_accept_encoding(h: &mut HeaderMap) {
71    if !h.contains_key(header::ACCEPT_ENCODING) {
72        h.insert(
73            header::ACCEPT_ENCODING,
74            http::HeaderValue::from_static("br, zstd, gzip"),
75        );
76    }
77}
78
79/// Detect the most efficient request compression supported by the server.
80///
81/// This inspects the server's `Accept-Encoding` response header and applies
82/// quality factors (`q=` weights) to pick the optimal [`ContentEncoding`]
83/// supported by both parties. Returns `None` when the header is absent or when
84/// no mutually supported encoding is advertised.
85pub fn detect_request_compression_preference(headers: &HeaderMap) -> Option<ContentEncoding> {
86    let raw = headers.get(header::ACCEPT_ENCODING)?.to_str().ok()?;
87
88    let mut wildcard_q: Option<f32> = None;
89    let mut identity_q: f32 = 1.0;
90    let mut identity_explicit = false;
91    let mut entries: Vec<(String, f32)> = Vec::new();
92
93    for part in raw.split(',') {
94        let trimmed = part.trim();
95        if trimmed.is_empty() {
96            continue;
97        }
98
99        let mut segments = trimmed.split(';');
100        let token = segments.next().unwrap().trim().to_ascii_lowercase();
101        if token.is_empty() {
102            continue;
103        }
104
105        let mut weight = 1.0_f32;
106        for param in segments {
107            if let Some((key, value)) = param.split_once('=')
108                && key.trim().eq_ignore_ascii_case("q")
109                && let Ok(parsed) = value.trim().parse::<f32>()
110            {
111                weight = parsed.clamp(0.0, 1.0);
112            }
113        }
114
115        match token.as_str() {
116            "identity" => {
117                identity_q = weight;
118                identity_explicit = true;
119            }
120            "*" => {
121                wildcard_q = Some(weight);
122            }
123            other => entries.push((other.to_string(), weight)),
124        }
125    }
126
127    if !identity_explicit && let Some(q) = wildcard_q {
128        identity_q = q;
129    }
130
131    let mut best: Option<(ContentEncoding, f32)> = None;
132    for candidate in [
133        ContentEncoding::Br,
134        ContentEncoding::Zstd,
135        ContentEncoding::Gzip,
136    ] {
137        let direct_q = entries.iter().find_map(|(name, q)| {
138            if name == candidate.as_str() {
139                Some(*q)
140            } else {
141                None
142            }
143        });
144        let effective_q = direct_q.or(wildcard_q);
145
146        if let Some(q) = effective_q {
147            if q <= 0.0 {
148                continue;
149            }
150
151            let should_replace = best
152                .map(|(_, best_q)| q > best_q + f32::EPSILON)
153                .unwrap_or(true);
154            if should_replace {
155                best = Some((candidate, q));
156            }
157        }
158    }
159
160    if let Some((encoding, _)) = best {
161        return Some(encoding);
162    }
163
164    if identity_q > 0.0 {
165        return Some(ContentEncoding::Identity);
166    }
167
168    None
169}
170
171/// Backwards-compatible helper that returns the first encoding in the chain or identity when none.
172pub fn detect_encoding(headers: &HeaderMap) -> ContentEncoding {
173    detect_encodings(headers)
174        .into_iter()
175        .next()
176        .unwrap_or(ContentEncoding::Identity)
177}
178
179/// Decompress a response body based on the content encoding.
180///
181/// This function takes an aggregated response body and decompresses it according
182/// to the specified encoding.
183pub async fn decompress_body(body: Incoming, encodings: &[ContentEncoding]) -> Result<Bytes> {
184    let stream = BodyStream::new(body)
185        .map_ok(|frame| frame.into_data().unwrap_or_default())
186        .map_err(std::io::Error::other);
187    let reader = StreamReader::new(stream);
188    let reader = BufReader::new(reader);
189    let mut out = Vec::with_capacity(32 * 1024);
190    let mut current: Box<dyn AsyncBufRead + Unpin + Send> = Box::new(reader);
191
192    for encoding in encodings.iter().rev() {
193        current = match encoding {
194            ContentEncoding::Identity => current,
195            ContentEncoding::Br => Box::new(BufReader::new(BrotliDecoder::new(current))),
196            ContentEncoding::Gzip => Box::new(BufReader::new(GzipDecoder::new(current))),
197            ContentEncoding::Zstd => Box::new(BufReader::new(ZstdDecoder::new(current))),
198        };
199    }
200
201    let mut decoder = current;
202    decoder.read_to_end(&mut out).await?;
203
204    Ok(Bytes::from(out))
205}
206
207/// Create a buffered reader with decompression support for streaming.
208///
209/// This function wraps a stream with the appropriate decompression decoder
210/// based on the content encoding.
211pub fn decompress_stream(
212    body: Incoming,
213    encodings: &[ContentEncoding],
214) -> Result<Box<dyn AsyncBufRead + Unpin + Send>> {
215    let stream = BodyStream::new(body)
216        .map_ok(|frame| frame.into_data().unwrap_or_default())
217        .map_err(std::io::Error::other);
218    let reader: Box<dyn AsyncBufRead + Unpin + Send> =
219        Box::new(BufReader::new(StreamReader::new(stream)));
220
221    let mut current = reader;
222    for encoding in encodings.iter().rev() {
223        current = match encoding {
224            ContentEncoding::Identity => current,
225            ContentEncoding::Br => Box::new(BufReader::new(BrotliDecoder::new(current))),
226            ContentEncoding::Gzip => Box::new(BufReader::new(GzipDecoder::new(current))),
227            ContentEncoding::Zstd => Box::new(BufReader::new(ZstdDecoder::new(current))),
228        };
229    }
230
231    Ok(current)
232}
233
234/// Compress a byte payload using the specified encoding.
235///
236/// This function takes a byte payload and compresses it according to the
237/// specified encoding algorithm.
238///
239/// # Arguments
240///
241/// * `data` - The data to compress
242/// * `encoding` - The compression algorithm to use
243///
244/// # Returns
245///
246/// The compressed data as Bytes, or the original data if encoding is Identity
247///
248/// # Example
249///
250/// ```
251/// use fast_dav_rs::compression::{compress_payload, ContentEncoding};
252/// use bytes::Bytes;
253///
254/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
255/// let data = Bytes::from("Hello, compressed world!");
256/// let compressed = compress_payload(data, ContentEncoding::Gzip).await?;
257/// // compressed now contains gzipped data
258/// # Ok(())
259/// # }
260/// ```
261pub async fn compress_payload(data: Bytes, encoding: ContentEncoding) -> Result<Bytes> {
262    match encoding {
263        ContentEncoding::Identity => Ok(data),
264        ContentEncoding::Br => {
265            use async_compression::tokio::bufread::BrotliEncoder;
266
267            let mut encoder = BrotliEncoder::new(BufReader::new(Cursor::new(data)));
268            let mut compressed = Vec::new();
269            encoder.read_to_end(&mut compressed).await?;
270            Ok(Bytes::from(compressed))
271        }
272        ContentEncoding::Gzip => {
273            use async_compression::tokio::bufread::GzipEncoder;
274
275            let mut encoder = GzipEncoder::new(BufReader::new(Cursor::new(data)));
276            let mut compressed = Vec::new();
277            encoder.read_to_end(&mut compressed).await?;
278            Ok(Bytes::from(compressed))
279        }
280        ContentEncoding::Zstd => {
281            use async_compression::tokio::bufread::ZstdEncoder;
282
283            let mut encoder = ZstdEncoder::new(BufReader::new(Cursor::new(data)));
284            let mut compressed = Vec::new();
285            encoder.read_to_end(&mut compressed).await?;
286            Ok(Bytes::from(compressed))
287        }
288    }
289}
290
291/// Add a Content-Encoding header for outgoing requests that will be compressed.
292///
293/// This function adds the appropriate Content-Encoding header to indicate
294/// to the server how the request body is compressed.
295///
296/// # Arguments
297///
298/// * `headers` - The header map to modify
299/// * `encoding` - The compression algorithm being used
300///
301/// # Example
302///
303/// ```
304/// use fast_dav_rs::compression::{add_content_encoding, ContentEncoding};
305/// use hyper::HeaderMap;
306///
307/// let mut headers = HeaderMap::new();
308/// add_content_encoding(&mut headers, ContentEncoding::Gzip);
309/// assert_eq!(headers.get("Content-Encoding").unwrap(), "gzip");
310/// ```
311pub fn add_content_encoding(headers: &mut HeaderMap, encoding: ContentEncoding) {
312    if encoding != ContentEncoding::Identity
313        && let Ok(value) = http::HeaderValue::from_str(encoding.as_str())
314    {
315        headers.insert("Content-Encoding", value);
316    }
317}