Skip to main content

codec_rs/
compression.rs

1// SPDX-License-Identifier: MIT
2//! Client-side helpers for the Codec compression contract.
3//!
4//! Rust twin of `codecai.compression` (Python): the server emits
5//! `Codec-Zstd-Dict: sha256:<hex>` on every zstd response, the client
6//! validates that header against locally-loaded dicts before
7//! decompressing. See `spec/PROTOCOL.md` "Codec-Zstd-Dict response
8//! header" (stable since v0.2) for the full contract.
9//!
10//! The actual zstd decompression is intentionally out of scope here —
11//! callers usually already have an HTTP stack and pick their own
12//! zstd binding (`zstd` crate, `zstd-safe`, an FFI wrapper, etc.).
13//! This module just gives you the small piece that's specific to
14//! Codec: matching a response's declared dict hash to the dict you've
15//! loaded, with the case-insensitive header lookup and fail-fast
16//! semantics the spec mandates.
17//!
18//! Wrong-dict decompression produces garbage bytes that downstream
19//! msgpack / protobuf parsers will silently misinterpret — fail fast
20//! at the dict-select boundary instead.
21
22use std::collections::HashMap;
23use std::fmt;
24
25use sha2::{Digest, Sha256};
26
27/// Compute the canonical `Codec-Zstd-Dict` hash for `dict_bytes`.
28///
29/// Returns `sha256:<lowercase hex>` — same shape as the server-side
30/// header value and the `hash` field in tokenizer-map
31/// `zstd_dictionaries[]` entries.
32pub fn hash_zstd_dict(dict_bytes: &[u8]) -> String {
33    let mut hasher = Sha256::new();
34    hasher.update(dict_bytes);
35    let digest = hasher.finalize();
36    format!("sha256:{}", hex::encode(digest))
37}
38
/// Why no zstd dictionary could be chosen for a `Content-Encoding: zstd`
/// response.
///
/// Decompressing with the wrong dictionary yields garbage bytes that
/// downstream parsers would misread, so dict selection refuses outright in
/// each of these cases instead of guessing.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CodecZstdDictError {
    /// The response was zstd-encoded but had no usable `Codec-Zstd-Dict`
    /// header. spec/PROTOCOL.md requires the server to name the dict it
    /// used, so the client will not guess.
    MissingHeader,
    /// The `Codec-Zstd-Dict` value was not of the form
    /// `sha256:<64 hex chars>`. Holds the trimmed raw header value for
    /// diagnostics.
    MalformedHash(String),
    /// The `Codec-Zstd-Dict` value names a dict that is not loaded locally.
    /// Holds the declared hash. To recover, load the tokenizer map's
    /// `zstd_dictionaries[]` entry whose `hash` matches, or retry with
    /// `Accept-Encoding: gzip` to take the no-dict path.
    UnknownHash(String),
}

impl fmt::Display for CodecZstdDictError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::MissingHeader => f.write_str(
                "Response is Content-Encoding: zstd but no Codec-Zstd-Dict header \
                 was present. Per spec/PROTOCOL.md the server MUST name the dict it \
                 used. Refusing to guess.",
            ),
            Self::MalformedHash(raw) => write!(
                f,
                "Malformed Codec-Zstd-Dict value: {raw:?}. Expected 'sha256:<64 hex chars>'."
            ),
            Self::UnknownHash(declared) => write!(
                f,
                "Server used zstd dict {declared} but it isn't loaded locally. \
                 Fetch it from the tokenizer map's zstd_dictionaries[] entry \
                 (the entry whose hash matches), or send Accept-Encoding: gzip to downgrade."
            ),
        }
    }
}

impl std::error::Error for CodecZstdDictError {}
88
89/// Pick the zstd dict to decompress this response with.
90///
91/// # Parameters
92///
93/// - `response_headers`: HTTP response header map. Keys are looked up
94///   case-insensitively, matching the way `reqwest::HeaderMap` and most
95///   HTTP libraries treat headers on the wire.
96/// - `loaded_dicts`: `{ "sha256:<hex>" -> dict bytes }`. Populate this
97///   from your tokenizer map's `zstd_dictionaries[]` entries; keys
98///   follow the same canonical shape the server emits.
99///
100/// # Returns
101///
102/// - `Ok(Some(&dict_bytes))` when the response is
103///   `Content-Encoding: zstd` and the server's `Codec-Zstd-Dict`
104///   header points at a loaded dict — pass these bytes to your zstd
105///   decoder (e.g. `zstd::stream::Decoder::with_dictionary`).
106/// - `Ok(None)` when the response isn't zstd. The caller should pass
107///   the body through identity / let its HTTP stack auto-decompress
108///   gzip / brotli.
109///
110/// # Errors
111///
112/// Returns `CodecZstdDictError` when the response is zstd but the
113/// header is missing, malformed, or names a dict the client hasn't
114/// loaded. Wrong-dict decompression is never attempted — see the
115/// spec rationale at `spec/PROTOCOL.md`.
116pub fn select_zstd_dict_for_response<'a>(
117    response_headers: &HashMap<String, String>,
118    loaded_dicts: &'a HashMap<String, Vec<u8>>,
119) -> Result<Option<&'a [u8]>, CodecZstdDictError> {
120    let enc = header(response_headers, "content-encoding");
121    match enc.map(|v| v.trim().to_ascii_lowercase()) {
122        Some(ref v) if v == "zstd" => {}
123        _ => return Ok(None), // caller's HTTP stack handles gzip/br/identity
124    }
125
126    let declared = match header(response_headers, "codec-zstd-dict") {
127        Some(v) => v.trim().to_string(),
128        None => return Err(CodecZstdDictError::MissingHeader),
129    };
130    if declared.is_empty() {
131        return Err(CodecZstdDictError::MissingHeader);
132    }
133
134    if !is_canonical_sha256(&declared) {
135        return Err(CodecZstdDictError::MalformedHash(declared));
136    }
137
138    match loaded_dicts.get(&declared) {
139        Some(bytes) => Ok(Some(bytes.as_slice())),
140        None => Err(CodecZstdDictError::UnknownHash(declared)),
141    }
142}
143
/// Case-insensitive header lookup over a plain `HashMap<String, String>`.
///
/// Typed header maps (`reqwest::HeaderMap`, `http::HeaderMap`) already treat
/// names case-insensitively. This is the fallback for callers that pass a
/// plain map taken from their own request plumbing.
///
/// An exact key match is tried first (one hash lookup). Otherwise the first
/// key that equals `name` ignoring ASCII case, in the map's iteration order,
/// wins. Returns `None` when no key matches.
fn header<'a>(headers: &'a HashMap<String, String>, name: &str) -> Option<&'a str> {
    if let Some(value) = headers.get(name) {
        return Some(value.as_str());
    }
    // Compare in place instead of building lowercased copies of `name` and
    // of every key while scanning.
    headers
        .iter()
        .find(|(key, _)| key.eq_ignore_ascii_case(name))
        .map(|(_, value)| value.as_str())
}
161
/// True iff `value` is exactly `sha256:` followed by 64 lowercase hex digits.
///
/// Uppercase hex, a different prefix, or a bare digest are all rejected.
fn is_canonical_sha256(value: &str) -> bool {
    let is_lower_hex = |b: u8| b.is_ascii_digit() || (b'a'..=b'f').contains(&b);
    match value.strip_prefix("sha256:") {
        Some(digest) => digest.len() == 64 && digest.bytes().all(is_lower_hex),
        None => false,
    }
}
170
171// ── Discoverable zstd dictionaries (.well-known/codec/dicts/<sha>.zstd, v0.5+) ──
172
173/// Errors raised by the v0.5 zstd-dictionary discovery surface.
174///
175/// The discovery path is hard-fail by design (no silent fallback to identity
176/// bytes) — see `spec/WELL_KNOWN_DISCOVERY.md § Resolution failures`. Silent
177/// dict-load failure was the v0.4.1 sglang COPY-dicts regression class this
178/// surface eliminates.
179#[derive(Debug, thiserror::Error)]
180pub enum ZstdDictDiscoveryError {
181    /// Hash input was not `sha256:<64 hex>` or bare `<64 hex>`.
182    #[error("Invalid dict hash {hash:?}: expected 'sha256:<64 hex>' or '<64 hex>'")]
183    InvalidHash { hash: String },
184    /// `.well-known/codec/dicts/<hex>.zstd` returned HTTP 404.
185    #[error("No zstd dict at {url} (HTTP 404)")]
186    NotFound { url: String },
187    /// Fetched bytes did not hash to the `<hex>` path component in the URL.
188    /// Treat as byte-tampering: never decompress.
189    #[error("Zstd dict hash mismatch at {url}\n  expected: {expected}\n  actual:   {actual}")]
190    HashMismatch {
191        url: String,
192        expected: String,
193        actual: String,
194    },
195    /// HTTP transport-layer failure (reqwest). Surfaced separately from 404 so
196    /// callers can distinguish "origin doesn't publish this dict" from "we
197    /// couldn't reach the origin at all."
198    #[cfg(feature = "http")]
199    #[error("HTTP error fetching {url}: {source}")]
200    Http {
201        url: String,
202        #[source]
203        source: reqwest::Error,
204    },
205}
206
207/// Validate + normalise an sha256 dict hash to bare lowercase hex.
208///
209/// Accepts either `sha256:<hex>` or bare `<hex>`. Used both as a URL builder
210/// guard and as the expected verifier in [`discover_zstd_dict_blocking`].
211fn parse_dict_hash(hash: &str) -> Result<String, ZstdDictDiscoveryError> {
212    let s = hash.trim();
213    let stripped = s.strip_prefix("sha256:").unwrap_or(s);
214    let lower = stripped.to_ascii_lowercase();
215    if lower.len() != 64 || !lower.bytes().all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f')) {
216        return Err(ZstdDictDiscoveryError::InvalidHash {
217            hash: hash.to_string(),
218        });
219    }
220    Ok(lower)
221}
222
223/// Per-dict document URL for an origin + sha256 hash (v0.5+).
224///
225/// Returns `<origin>/.well-known/codec/dicts/<sha256-hex>.zstd`. The URL is
226/// fully derived from the hash — there is no mutable per-id form for dicts.
227///
228/// # Errors
229///
230/// [`ZstdDictDiscoveryError::InvalidHash`] when the hash is not the expected
231/// `sha256:<hex>` / bare `<hex>` shape.
232pub fn well_known_dict_url(origin: &str, hash: &str) -> Result<String, ZstdDictDiscoveryError> {
233    let hex = parse_dict_hash(hash)?;
234    let origin = origin.strip_suffix('/').unwrap_or(origin);
235    Ok(format!("{origin}/.well-known/codec/dicts/{hex}.zstd"))
236}
237
/// Lowercase hex SHA-256 of `bytes`, without any `sha256:` prefix.
#[cfg(feature = "http")]
fn sha256_hex_bytes(bytes: &[u8]) -> String {
    hex::encode(Sha256::digest(bytes))
}
244
/// Fetch and verify a zstd dictionary from
/// `.well-known/codec/dicts/<hex>.zstd`, blocking the current thread (v0.5+).
///
/// The hash is validated before any network traffic. The dictionary is then
/// downloaded from `<origin>/.well-known/codec/dicts/<sha256-hex>.zstd`.
/// The bytes are returned only if their SHA-256 equals `<hex>`, ready for
/// `zstd::dict::DecoderDictionary::copy(...)` or an equivalent constructor.
///
/// # Example
///
/// ```no_run
/// use codec_rs::discover_zstd_dict_blocking;
///
/// let dict = discover_zstd_dict_blocking(
///     "https://codec.example",
///     "sha256:abc1230000000000000000000000000000000000000000000000000000000000",
/// )?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
///
/// # Errors
///
/// - [`ZstdDictDiscoveryError::InvalidHash`]: malformed `hash`; no HTTP
///   request is made
/// - [`ZstdDictDiscoveryError::NotFound`]: the origin answered HTTP 404
/// - [`ZstdDictDiscoveryError::HashMismatch`]: the served bytes hash to
///   something else
/// - [`ZstdDictDiscoveryError::Http`]: client construction, transport, any
///   other error status, or reading the body failed
#[cfg(feature = "http")]
pub fn discover_zstd_dict_blocking(
    origin: &str,
    hash: &str,
) -> Result<Vec<u8>, ZstdDictDiscoveryError> {
    let expected_hex = parse_dict_hash(hash)?;
    // `expected_hex` is already canonical bare hex, so this cannot fail
    // differently from the check above.
    let dict_url = well_known_dict_url(origin, &expected_hex)?;
    // Captures `dict_url` by shared reference only, so the closure is `Copy`
    // and can be handed to every `map_err` below.
    let to_http_err = |source: reqwest::Error| ZstdDictDiscoveryError::Http {
        url: dict_url.clone(),
        source,
    };

    let client = reqwest::blocking::Client::builder()
        .user_agent("codec-rs/0.4")
        .build()
        .map_err(to_http_err)?;
    let response = client.get(&dict_url).send().map_err(to_http_err)?;

    // Report 404 as its own variant before the generic status check.
    if response.status() == reqwest::StatusCode::NOT_FOUND {
        return Err(ZstdDictDiscoveryError::NotFound { url: dict_url });
    }

    let body = response
        .error_for_status()
        .map_err(to_http_err)?
        .bytes()
        .map_err(to_http_err)?;

    let actual_hex = sha256_hex_bytes(&body);
    if actual_hex == expected_hex {
        Ok(body.to_vec())
    } else {
        Err(ZstdDictDiscoveryError::HashMismatch {
            url: dict_url,
            expected: expected_hex,
            actual: actual_hex,
        })
    }
}
317
/// Async variant of [`discover_zstd_dict_blocking`]. Requires a Tokio runtime.
///
/// Same contract: the hash is validated before any request, the dictionary
/// is fetched from `<origin>/.well-known/codec/dicts/<sha256-hex>.zstd`, and
/// the bytes are returned only if their SHA-256 equals `<hex>`.
///
/// The async `reqwest::Client` has no request timeout by default, unlike the
/// blocking client (30 s) used by the blocking variant. Without one, a
/// stalled origin could make this future hang forever. A 30-second total
/// timeout is therefore set explicitly, so both variants fail the same way
/// (as [`ZstdDictDiscoveryError::Http`]).
///
/// # Errors
///
/// - [`ZstdDictDiscoveryError::InvalidHash`]: malformed `hash`; no HTTP
///   request is made
/// - [`ZstdDictDiscoveryError::NotFound`]: the origin answered HTTP 404
/// - [`ZstdDictDiscoveryError::HashMismatch`]: the served bytes hash to
///   something else
/// - [`ZstdDictDiscoveryError::Http`]: client construction, transport
///   (including timeout), any other error status, or reading the body failed
#[cfg(feature = "http")]
pub async fn discover_zstd_dict(
    origin: &str,
    hash: &str,
) -> Result<Vec<u8>, ZstdDictDiscoveryError> {
    // Matches the blocking client's default so both variants behave alike.
    const REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);

    let expected = parse_dict_hash(hash)?;
    // `expected` is already canonical bare hex; no need to re-parse `hash`.
    let url = well_known_dict_url(origin, &expected)?;
    // Borrows `url` only, so the closure is `Copy` and reusable below.
    let to_http_err = |source: reqwest::Error| ZstdDictDiscoveryError::Http {
        url: url.clone(),
        source,
    };

    let client = reqwest::Client::builder()
        .user_agent("codec-rs/0.4")
        .timeout(REQUEST_TIMEOUT)
        .build()
        .map_err(to_http_err)?;

    let resp = client.get(&url).send().await.map_err(to_http_err)?;
    // Report 404 as its own variant before the generic status check.
    if resp.status() == reqwest::StatusCode::NOT_FOUND {
        return Err(ZstdDictDiscoveryError::NotFound { url });
    }
    let bytes = resp
        .error_for_status()
        .map_err(to_http_err)?
        .bytes()
        .await
        .map_err(to_http_err)?;

    let actual = sha256_hex_bytes(&bytes);
    if actual != expected {
        return Err(ZstdDictDiscoveryError::HashMismatch {
            url,
            expected,
            actual,
        });
    }
    Ok(bytes.to_vec())
}
370
371// ── unit tests ────────────────────────────────────────────────────────────
372
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn hash_zstd_dict_matches_python_reference() {
        // "hello world" sha256 — same digest both languages produce.
        let got = hash_zstd_dict(b"hello world");
        assert_eq!(
            got,
            "sha256:b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"
        );
    }

    #[test]
    fn select_returns_none_when_not_zstd() {
        let mut headers = HashMap::new();
        headers.insert("content-encoding".into(), "gzip".into());
        let dicts: HashMap<String, Vec<u8>> = HashMap::new();
        assert_eq!(
            select_zstd_dict_for_response(&headers, &dicts).unwrap(),
            None
        );
    }

    #[test]
    fn select_returns_none_when_no_encoding() {
        let headers: HashMap<String, String> = HashMap::new();
        let dicts: HashMap<String, Vec<u8>> = HashMap::new();
        assert_eq!(
            select_zstd_dict_for_response(&headers, &dicts).unwrap(),
            None
        );
    }

    #[test]
    fn select_missing_header_is_error() {
        let mut headers = HashMap::new();
        headers.insert("Content-Encoding".into(), "zstd".into());
        let dicts: HashMap<String, Vec<u8>> = HashMap::new();
        assert_eq!(
            select_zstd_dict_for_response(&headers, &dicts),
            Err(CodecZstdDictError::MissingHeader)
        );
    }

    #[test]
    fn select_malformed_hash_is_error() {
        let mut headers = HashMap::new();
        headers.insert("content-encoding".into(), "zstd".into());
        headers.insert("codec-zstd-dict".into(), "md5:abc".into());
        let dicts: HashMap<String, Vec<u8>> = HashMap::new();
        match select_zstd_dict_for_response(&headers, &dicts) {
            Err(CodecZstdDictError::MalformedHash(v)) => assert_eq!(v, "md5:abc"),
            other => panic!("expected MalformedHash, got {other:?}"),
        }
    }

    // ── select_zstd_dict_for_response: success path and remaining errors ────

    #[test]
    fn select_returns_loaded_dict_when_hash_matches() {
        // Mixed-case header names and padded values exercise the
        // case-insensitive lookup and the trimming of both headers.
        let dict: Vec<u8> = b"dict-bytes".to_vec();
        let hash = hash_zstd_dict(&dict);
        let mut headers: HashMap<String, String> = HashMap::new();
        headers.insert("Content-Encoding".into(), " ZSTD ".into());
        headers.insert("CODEC-ZSTD-DICT".into(), format!("  {hash}\t"));
        let mut dicts: HashMap<String, Vec<u8>> = HashMap::new();
        dicts.insert(hash, dict.clone());
        assert_eq!(
            select_zstd_dict_for_response(&headers, &dicts).unwrap(),
            Some(dict.as_slice())
        );
    }

    #[test]
    fn select_unknown_hash_is_error() {
        let declared = hash_zstd_dict(b"not loaded");
        let mut headers: HashMap<String, String> = HashMap::new();
        headers.insert("content-encoding".into(), "zstd".into());
        headers.insert("codec-zstd-dict".into(), declared.clone());
        let dicts: HashMap<String, Vec<u8>> = HashMap::new();
        assert_eq!(
            select_zstd_dict_for_response(&headers, &dicts),
            Err(CodecZstdDictError::UnknownHash(declared))
        );
    }

    #[test]
    fn select_blank_header_is_missing() {
        let mut headers: HashMap<String, String> = HashMap::new();
        headers.insert("content-encoding".into(), "zstd".into());
        headers.insert("codec-zstd-dict".into(), "   ".into());
        let dicts: HashMap<String, Vec<u8>> = HashMap::new();
        assert_eq!(
            select_zstd_dict_for_response(&headers, &dicts),
            Err(CodecZstdDictError::MissingHeader)
        );
    }

    #[test]
    fn select_uppercase_hex_is_malformed() {
        // The header form is canonical lowercase only.
        let upper = format!("sha256:{}", "A".repeat(64));
        let mut headers: HashMap<String, String> = HashMap::new();
        headers.insert("content-encoding".into(), "zstd".into());
        headers.insert("codec-zstd-dict".into(), upper.clone());
        let dicts: HashMap<String, Vec<u8>> = HashMap::new();
        assert_eq!(
            select_zstd_dict_for_response(&headers, &dicts),
            Err(CodecZstdDictError::MalformedHash(upper))
        );
    }

    // ── well_known_dict_url / parse_dict_hash (v0.5) ─────────────────────────

    #[test]
    fn well_known_dict_url_strips_sha256_prefix() {
        let h = "a".repeat(64);
        assert_eq!(
            well_known_dict_url("https://codec.example", &format!("sha256:{h}")).unwrap(),
            format!("https://codec.example/.well-known/codec/dicts/{h}.zstd"),
        );
    }

    #[test]
    fn well_known_dict_url_accepts_bare_hex() {
        let h = "b".repeat(64);
        assert_eq!(
            well_known_dict_url("https://codec.example", &h).unwrap(),
            format!("https://codec.example/.well-known/codec/dicts/{h}.zstd"),
        );
    }

    #[test]
    fn well_known_dict_url_strips_trailing_slash() {
        let h = "c".repeat(64);
        assert_eq!(
            well_known_dict_url("https://codec.example/", &h).unwrap(),
            format!("https://codec.example/.well-known/codec/dicts/{h}.zstd"),
        );
    }

    #[test]
    fn well_known_dict_url_normalises_uppercase_hex() {
        let upper = "D".repeat(64);
        let expected = "d".repeat(64);
        assert_eq!(
            well_known_dict_url("https://codec.example", &upper).unwrap(),
            format!("https://codec.example/.well-known/codec/dicts/{expected}.zstd"),
        );
    }

    #[test]
    fn well_known_dict_url_trims_surrounding_whitespace() {
        let h = "e".repeat(64);
        assert_eq!(
            well_known_dict_url("https://codec.example", &format!("  sha256:{h}\n")).unwrap(),
            format!("https://codec.example/.well-known/codec/dicts/{h}.zstd"),
        );
    }

    #[test]
    fn well_known_dict_url_rejects_short_hash() {
        let err = well_known_dict_url("https://codec.example", "deadbeef").unwrap_err();
        assert!(matches!(err, ZstdDictDiscoveryError::InvalidHash { .. }));
    }

    #[test]
    fn well_known_dict_url_invalid_hash_reports_raw_input() {
        // The error echoes the caller's input untrimmed, for diagnostics.
        match well_known_dict_url("https://codec.example", " deadbeef ") {
            Err(ZstdDictDiscoveryError::InvalidHash { hash }) => assert_eq!(hash, " deadbeef "),
            other => panic!("expected InvalidHash, got {other:?}"),
        }
    }

    #[test]
    fn well_known_dict_url_rejects_wrong_algorithm() {
        let err = well_known_dict_url(
            "https://codec.example",
            &format!("md5:{}", "a".repeat(32)),
        )
        .unwrap_err();
        assert!(matches!(err, ZstdDictDiscoveryError::InvalidHash { .. }));
    }

    #[test]
    fn well_known_dict_url_rejects_nonhex_chars() {
        let err = well_known_dict_url("https://codec.example", &"z".repeat(64)).unwrap_err();
        assert!(matches!(err, ZstdDictDiscoveryError::InvalidHash { .. }));
    }
}