codec_rs/compression.rs
1// SPDX-License-Identifier: MIT
2//! Client-side helpers for the Codec compression contract.
3//!
4//! Rust twin of `codecai.compression` (Python): the server emits
5//! `Codec-Zstd-Dict: sha256:<hex>` on every zstd response, the client
6//! validates that header against locally-loaded dicts before
7//! decompressing. See `spec/PROTOCOL.md` "Codec-Zstd-Dict response
8//! header" (stable since v0.2) for the full contract.
9//!
10//! The actual zstd decompression is intentionally out of scope here —
11//! callers usually already have an HTTP stack and pick their own
12//! zstd binding (`zstd` crate, `zstd-safe`, an FFI wrapper, etc.).
13//! This module just gives you the small piece that's specific to
14//! Codec: matching a response's declared dict hash to the dict you've
15//! loaded, with the case-insensitive header lookup and fail-fast
16//! semantics the spec mandates.
17//!
18//! Wrong-dict decompression produces garbage bytes that downstream
19//! msgpack / protobuf parsers will silently misinterpret — fail fast
20//! at the dict-select boundary instead.
21
22use std::collections::HashMap;
23use std::fmt;
24
25use sha2::{Digest, Sha256};
26
27/// Compute the canonical `Codec-Zstd-Dict` hash for `dict_bytes`.
28///
29/// Returns `sha256:<lowercase hex>` — same shape as the server-side
30/// header value and the `hash` field in tokenizer-map
31/// `zstd_dictionaries[]` entries.
32pub fn hash_zstd_dict(dict_bytes: &[u8]) -> String {
33 let mut hasher = Sha256::new();
34 hasher.update(dict_bytes);
35 let digest = hasher.finalize();
36 format!("sha256:{}", hex::encode(digest))
37}
38
39/// Raised when the server's `Codec-Zstd-Dict` header doesn't match any
40/// dict the client has loaded, or is missing on a zstd response.
41///
42/// A wrong-dict decompression would produce garbage bytes that
43/// downstream parsers would misinterpret — fail fast instead.
44#[derive(Debug, Clone, PartialEq, Eq)]
45pub enum CodecZstdDictError {
46 /// Response was `Content-Encoding: zstd` but the server omitted
47 /// the `Codec-Zstd-Dict` header. Per spec/PROTOCOL.md the server
48 /// MUST name the dict it used; we refuse to guess.
49 MissingHeader,
50 /// `Codec-Zstd-Dict` value was not in the canonical
51 /// `sha256:<64 hex chars>` shape. The wrapped string is the raw
52 /// header value (trimmed) for diagnostics.
53 MalformedHash(String),
54 /// `Codec-Zstd-Dict` named a dict we haven't loaded. The caller
55 /// should fetch it from the tokenizer map's `zstd_dictionaries[]`
56 /// entry (the one whose `hash` matches) or retry the request with
57 /// `Accept-Encoding: gzip` to downgrade to a no-dict path. The
58 /// wrapped string is the declared hash.
59 UnknownHash(String),
60}
61
62impl fmt::Display for CodecZstdDictError {
63 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64 match self {
65 CodecZstdDictError::MissingHeader => write!(
66 f,
67 "Response is Content-Encoding: zstd but no Codec-Zstd-Dict \
68 header was present. Per spec/PROTOCOL.md the server MUST \
69 name the dict it used. Refusing to guess."
70 ),
71 CodecZstdDictError::MalformedHash(value) => write!(
72 f,
73 "Malformed Codec-Zstd-Dict value: {value:?}. Expected \
74 'sha256:<64 hex chars>'."
75 ),
76 CodecZstdDictError::UnknownHash(hash) => write!(
77 f,
78 "Server used zstd dict {hash} but it isn't loaded \
79 locally. Fetch it from the tokenizer map's \
80 zstd_dictionaries[] entry (the entry whose hash \
81 matches), or send Accept-Encoding: gzip to downgrade."
82 ),
83 }
84 }
85}
86
87impl std::error::Error for CodecZstdDictError {}
88
89/// Pick the zstd dict to decompress this response with.
90///
91/// # Parameters
92///
93/// - `response_headers`: HTTP response header map. Keys are looked up
94/// case-insensitively, matching the way `reqwest::HeaderMap` and most
95/// HTTP libraries treat headers on the wire.
96/// - `loaded_dicts`: `{ "sha256:<hex>" -> dict bytes }`. Populate this
97/// from your tokenizer map's `zstd_dictionaries[]` entries; keys
98/// follow the same canonical shape the server emits.
99///
100/// # Returns
101///
102/// - `Ok(Some(&dict_bytes))` when the response is
103/// `Content-Encoding: zstd` and the server's `Codec-Zstd-Dict`
104/// header points at a loaded dict — pass these bytes to your zstd
105/// decoder (e.g. `zstd::stream::Decoder::with_dictionary`).
106/// - `Ok(None)` when the response isn't zstd. The caller should pass
107/// the body through identity / let its HTTP stack auto-decompress
108/// gzip / brotli.
109///
110/// # Errors
111///
112/// Returns `CodecZstdDictError` when the response is zstd but the
113/// header is missing, malformed, or names a dict the client hasn't
114/// loaded. Wrong-dict decompression is never attempted — see the
115/// spec rationale at `spec/PROTOCOL.md`.
116pub fn select_zstd_dict_for_response<'a>(
117 response_headers: &HashMap<String, String>,
118 loaded_dicts: &'a HashMap<String, Vec<u8>>,
119) -> Result<Option<&'a [u8]>, CodecZstdDictError> {
120 let enc = header(response_headers, "content-encoding");
121 match enc.map(|v| v.trim().to_ascii_lowercase()) {
122 Some(ref v) if v == "zstd" => {}
123 _ => return Ok(None), // caller's HTTP stack handles gzip/br/identity
124 }
125
126 let declared = match header(response_headers, "codec-zstd-dict") {
127 Some(v) => v.trim().to_string(),
128 None => return Err(CodecZstdDictError::MissingHeader),
129 };
130 if declared.is_empty() {
131 return Err(CodecZstdDictError::MissingHeader);
132 }
133
134 if !is_canonical_sha256(&declared) {
135 return Err(CodecZstdDictError::MalformedHash(declared));
136 }
137
138 match loaded_dicts.get(&declared) {
139 Some(bytes) => Ok(Some(bytes.as_slice())),
140 None => Err(CodecZstdDictError::UnknownHash(declared)),
141 }
142}
143
144/// Case-insensitive header lookup. Most idiomatic HTTP libraries
145/// (`reqwest::HeaderMap`, `http::HeaderMap`) already treat header
146/// names case-insensitively; this is the defensive fallback for
147/// callers that hand us a plain `HashMap<String, String>` lifted
148/// out of their own request plumbing.
149fn header<'a>(headers: &'a HashMap<String, String>, name: &str) -> Option<&'a str> {
150 if let Some(v) = headers.get(name) {
151 return Some(v.as_str());
152 }
153 let lower = name.to_ascii_lowercase();
154 for (k, v) in headers.iter() {
155 if k.to_ascii_lowercase() == lower {
156 return Some(v.as_str());
157 }
158 }
159 None
160}
161
162fn is_canonical_sha256(value: &str) -> bool {
163 const PREFIX: &str = "sha256:";
164 if !value.starts_with(PREFIX) {
165 return false;
166 }
167 let hex = &value[PREFIX.len()..];
168 hex.len() == 64 && hex.bytes().all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f'))
169}
170
171// ── unit tests ────────────────────────────────────────────────────────────
172
173#[cfg(test)]
174mod tests {
175 use super::*;
176
177 #[test]
178 fn hash_zstd_dict_matches_python_reference() {
179 // "hello world" sha256 — same digest both languages produce.
180 let got = hash_zstd_dict(b"hello world");
181 assert_eq!(
182 got,
183 "sha256:b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"
184 );
185 }
186
187 #[test]
188 fn select_returns_none_when_not_zstd() {
189 let mut headers = HashMap::new();
190 headers.insert("content-encoding".into(), "gzip".into());
191 let dicts: HashMap<String, Vec<u8>> = HashMap::new();
192 assert_eq!(
193 select_zstd_dict_for_response(&headers, &dicts).unwrap(),
194 None
195 );
196 }
197
198 #[test]
199 fn select_returns_none_when_no_encoding() {
200 let headers: HashMap<String, String> = HashMap::new();
201 let dicts: HashMap<String, Vec<u8>> = HashMap::new();
202 assert_eq!(
203 select_zstd_dict_for_response(&headers, &dicts).unwrap(),
204 None
205 );
206 }
207
208 #[test]
209 fn select_missing_header_is_error() {
210 let mut headers = HashMap::new();
211 headers.insert("Content-Encoding".into(), "zstd".into());
212 let dicts: HashMap<String, Vec<u8>> = HashMap::new();
213 assert_eq!(
214 select_zstd_dict_for_response(&headers, &dicts),
215 Err(CodecZstdDictError::MissingHeader)
216 );
217 }
218
219 #[test]
220 fn select_malformed_hash_is_error() {
221 let mut headers = HashMap::new();
222 headers.insert("content-encoding".into(), "zstd".into());
223 headers.insert("codec-zstd-dict".into(), "md5:abc".into());
224 let dicts: HashMap<String, Vec<u8>> = HashMap::new();
225 match select_zstd_dict_for_response(&headers, &dicts) {
226 Err(CodecZstdDictError::MalformedHash(v)) => assert_eq!(v, "md5:abc"),
227 other => panic!("expected MalformedHash, got {other:?}"),
228 }
229 }
230}