1use std::collections::HashMap;
23use std::fmt;
24
25use sha2::{Digest, Sha256};
26
27pub fn hash_zstd_dict(dict_bytes: &[u8]) -> String {
33 let mut hasher = Sha256::new();
34 hasher.update(dict_bytes);
35 let digest = hasher.finalize();
36 format!("sha256:{}", hex::encode(digest))
37}
38
/// Failures reported while choosing the zstd dictionary for a response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CodecZstdDictError {
    /// The response was zstd-encoded but carried no (or an empty)
    /// `Codec-Zstd-Dict` header.
    MissingHeader,
    /// The `Codec-Zstd-Dict` value was not of the form `sha256:<64 hex chars>`;
    /// carries the offending value.
    MalformedHash(String),
    /// The declared hash is well-formed but no dictionary with that hash is
    /// loaded; carries the declared hash.
    UnknownHash(String),
}

impl fmt::Display for CodecZstdDictError {
    /// Writes the human-readable explanation for each variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // No interpolation needed, so the text is written directly.
            Self::MissingHeader => f.write_str(
                "Response is Content-Encoding: zstd but no Codec-Zstd-Dict header was \
                 present. Per spec/PROTOCOL.md the server MUST name the dict it used. \
                 Refusing to guess.",
            ),
            // Debug formatting quotes the raw value so stray whitespace is visible.
            Self::MalformedHash(raw) => write!(
                f,
                "Malformed Codec-Zstd-Dict value: {:?}. Expected 'sha256:<64 hex chars>'.",
                raw
            ),
            Self::UnknownHash(digest) => write!(
                f,
                "Server used zstd dict {} but it isn't loaded locally. Fetch it from \
                 the tokenizer map's zstd_dictionaries[] entry (the entry whose hash \
                 matches), or send Accept-Encoding: gzip to downgrade.",
                digest
            ),
        }
    }
}

impl std::error::Error for CodecZstdDictError {}
88
89pub fn select_zstd_dict_for_response<'a>(
117 response_headers: &HashMap<String, String>,
118 loaded_dicts: &'a HashMap<String, Vec<u8>>,
119) -> Result<Option<&'a [u8]>, CodecZstdDictError> {
120 let enc = header(response_headers, "content-encoding");
121 match enc.map(|v| v.trim().to_ascii_lowercase()) {
122 Some(ref v) if v == "zstd" => {}
123 _ => return Ok(None), }
125
126 let declared = match header(response_headers, "codec-zstd-dict") {
127 Some(v) => v.trim().to_string(),
128 None => return Err(CodecZstdDictError::MissingHeader),
129 };
130 if declared.is_empty() {
131 return Err(CodecZstdDictError::MissingHeader);
132 }
133
134 if !is_canonical_sha256(&declared) {
135 return Err(CodecZstdDictError::MalformedHash(declared));
136 }
137
138 match loaded_dicts.get(&declared) {
139 Some(bytes) => Ok(Some(bytes.as_slice())),
140 None => Err(CodecZstdDictError::UnknownHash(declared)),
141 }
142}
143
/// Looks up a header value by name, ignoring ASCII case.
///
/// An exact-case match is tried first with a single hash lookup. If that
/// misses, the map is scanned and the first key equal to `name` under ASCII
/// case folding wins. Returns `None` when no key matches.
fn header<'a>(headers: &'a HashMap<String, String>, name: &str) -> Option<&'a str> {
    if let Some(value) = headers.get(name) {
        return Some(value.as_str());
    }
    // `eq_ignore_ascii_case` compares in place, so the scan does not allocate
    // a lowercased copy of every key.
    headers
        .iter()
        .find(|(key, _)| key.eq_ignore_ascii_case(name))
        .map(|(_, value)| value.as_str())
}
161
/// Reports whether `value` is exactly `sha256:` followed by 64 lowercase
/// hexadecimal characters (`0-9`, `a-f`).
///
/// Uppercase hex digits, a differently-cased prefix, or any other length are
/// all rejected.
fn is_canonical_sha256(value: &str) -> bool {
    match value.strip_prefix("sha256:") {
        Some(digits) => {
            digits.len() == 64
                && digits
                    .bytes()
                    .all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f'))
        }
        None => false,
    }
}
170
/// Failures reported while locating or downloading a zstd dictionary from a
/// well-known URL.
#[derive(Debug, thiserror::Error)]
pub enum ZstdDictDiscoveryError {
    /// The supplied hash was not 64 hex characters, with or without a
    /// `sha256:` prefix. Carries the hash exactly as the caller passed it.
    #[error("Invalid dict hash {hash:?}: expected 'sha256:<64 hex>' or '<64 hex>'")]
    InvalidHash { hash: String },
    /// The server answered HTTP 404 for the dictionary URL.
    #[error("No zstd dict at {url} (HTTP 404)")]
    NotFound { url: String },
    /// The downloaded bytes hashed to something other than the requested hash.
    /// `expected` and `actual` are bare hex digests (no `sha256:` prefix).
    #[error("Zstd dict hash mismatch at {url}\n expected: {expected}\n actual: {actual}")]
    HashMismatch {
        url: String,
        expected: String,
        actual: String,
    },
    /// Building the client, sending the request, a non-success status other
    /// than 404, or reading the body failed. Only present with the `http`
    /// feature.
    #[cfg(feature = "http")]
    #[error("HTTP error fetching {url}: {source}")]
    Http {
        url: String,
        #[source]
        source: reqwest::Error,
    },
}
206
207fn parse_dict_hash(hash: &str) -> Result<String, ZstdDictDiscoveryError> {
212 let s = hash.trim();
213 let stripped = s.strip_prefix("sha256:").unwrap_or(s);
214 let lower = stripped.to_ascii_lowercase();
215 if lower.len() != 64 || !lower.bytes().all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f')) {
216 return Err(ZstdDictDiscoveryError::InvalidHash {
217 hash: hash.to_string(),
218 });
219 }
220 Ok(lower)
221}
222
223pub fn well_known_dict_url(origin: &str, hash: &str) -> Result<String, ZstdDictDiscoveryError> {
233 let hex = parse_dict_hash(hash)?;
234 let origin = origin.strip_suffix('/').unwrap_or(origin);
235 Ok(format!("{origin}/.well-known/codec/dicts/{hex}.zstd"))
236}
237
/// Returns the SHA-256 digest of `bytes` as a hex string (no `sha256:` prefix).
#[cfg(feature = "http")]
fn sha256_hex_bytes(bytes: &[u8]) -> String {
    hex::encode(Sha256::digest(bytes))
}
244
/// Downloads the dictionary named by `hash` from `origin`'s well-known URL
/// using a blocking HTTP client, and verifies its SHA-256 before returning it.
///
/// # Errors
///
/// * [`ZstdDictDiscoveryError::InvalidHash`] if `hash` is not a valid hash.
/// * [`ZstdDictDiscoveryError::NotFound`] if the server answers HTTP 404.
/// * [`ZstdDictDiscoveryError::Http`] if the client cannot be built, the
///   request fails, another error status is returned, or the body cannot be
///   read.
/// * [`ZstdDictDiscoveryError::HashMismatch`] if the body's digest differs
///   from the requested hash.
#[cfg(feature = "http")]
pub fn discover_zstd_dict_blocking(
    origin: &str,
    hash: &str,
) -> Result<Vec<u8>, ZstdDictDiscoveryError> {
    // Pairs a transport error with the URL that was being fetched.
    fn http_failure(url: &str, source: reqwest::Error) -> ZstdDictDiscoveryError {
        ZstdDictDiscoveryError::Http {
            url: url.to_string(),
            source,
        }
    }

    let expected = parse_dict_hash(hash)?;
    let url = well_known_dict_url(origin, hash)?;

    let client = reqwest::blocking::Client::builder()
        .user_agent("codec-rs/0.4")
        .build()
        .map_err(|e| http_failure(&url, e))?;

    let response = client
        .get(&url)
        .send()
        .map_err(|e| http_failure(&url, e))?;

    // 404 gets its own variant; every other error status is reported as Http.
    if response.status() == reqwest::StatusCode::NOT_FOUND {
        return Err(ZstdDictDiscoveryError::NotFound { url });
    }
    let response = response
        .error_for_status()
        .map_err(|e| http_failure(&url, e))?;

    let body = response.bytes().map_err(|e| http_failure(&url, e))?;

    // Never hand back bytes whose digest differs from what was asked for.
    let actual = sha256_hex_bytes(&body);
    if actual == expected {
        Ok(body.to_vec())
    } else {
        Err(ZstdDictDiscoveryError::HashMismatch {
            url,
            expected,
            actual,
        })
    }
}
317
/// Async counterpart of the blocking discovery function: downloads the
/// dictionary named by `hash` from `origin`'s well-known URL and verifies its
/// SHA-256 before returning it.
///
/// # Errors
///
/// * [`ZstdDictDiscoveryError::InvalidHash`] if `hash` is not a valid hash.
/// * [`ZstdDictDiscoveryError::NotFound`] if the server answers HTTP 404.
/// * [`ZstdDictDiscoveryError::Http`] if the client cannot be built, the
///   request fails, another error status is returned, or the body cannot be
///   read.
/// * [`ZstdDictDiscoveryError::HashMismatch`] if the body's digest differs
///   from the requested hash.
#[cfg(feature = "http")]
pub async fn discover_zstd_dict(
    origin: &str,
    hash: &str,
) -> Result<Vec<u8>, ZstdDictDiscoveryError> {
    // Pairs a transport error with the URL that was being fetched.
    fn http_failure(url: &str, source: reqwest::Error) -> ZstdDictDiscoveryError {
        ZstdDictDiscoveryError::Http {
            url: url.to_string(),
            source,
        }
    }

    let expected = parse_dict_hash(hash)?;
    let url = well_known_dict_url(origin, hash)?;

    let client = reqwest::Client::builder()
        .user_agent("codec-rs/0.4")
        .build()
        .map_err(|e| http_failure(&url, e))?;

    let response = client
        .get(&url)
        .send()
        .await
        .map_err(|e| http_failure(&url, e))?;

    // 404 gets its own variant; every other error status is reported as Http.
    if response.status() == reqwest::StatusCode::NOT_FOUND {
        return Err(ZstdDictDiscoveryError::NotFound { url });
    }
    let response = response
        .error_for_status()
        .map_err(|e| http_failure(&url, e))?;

    let body = response
        .bytes()
        .await
        .map_err(|e| http_failure(&url, e))?;

    // Never hand back bytes whose digest differs from what was asked for.
    let actual = sha256_hex_bytes(&body);
    if actual == expected {
        Ok(body.to_vec())
    } else {
        Err(ZstdDictDiscoveryError::HashMismatch {
            url,
            expected,
            actual,
        })
    }
}
370
/// Unit tests for dictionary hashing, response dictionary selection and
/// well-known URL construction. None of these tests perform network I/O.
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the digest of a fixed input so the hash format cannot drift.
    #[test]
    fn hash_zstd_dict_matches_python_reference() {
        let got = hash_zstd_dict(b"hello world");
        assert_eq!(
            got,
            "sha256:b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"
        );
    }

    #[test]
    fn select_returns_none_when_not_zstd() {
        let mut headers = HashMap::new();
        headers.insert("content-encoding".into(), "gzip".into());
        let dicts: HashMap<String, Vec<u8>> = HashMap::new();
        assert_eq!(
            select_zstd_dict_for_response(&headers, &dicts).unwrap(),
            None
        );
    }

    #[test]
    fn select_returns_none_when_no_encoding() {
        let headers: HashMap<String, String> = HashMap::new();
        let dicts: HashMap<String, Vec<u8>> = HashMap::new();
        assert_eq!(
            select_zstd_dict_for_response(&headers, &dicts).unwrap(),
            None
        );
    }

    #[test]
    fn select_missing_header_is_error() {
        let mut headers = HashMap::new();
        headers.insert("Content-Encoding".into(), "zstd".into());
        let dicts: HashMap<String, Vec<u8>> = HashMap::new();
        assert_eq!(
            select_zstd_dict_for_response(&headers, &dicts),
            Err(CodecZstdDictError::MissingHeader)
        );
    }

    /// A header that is present but blank is reported like an absent one.
    #[test]
    fn select_blank_header_is_missing_header() {
        let mut headers = HashMap::new();
        headers.insert("content-encoding".into(), "zstd".into());
        headers.insert("codec-zstd-dict".into(), "   ".into());
        let dicts: HashMap<String, Vec<u8>> = HashMap::new();
        assert_eq!(
            select_zstd_dict_for_response(&headers, &dicts),
            Err(CodecZstdDictError::MissingHeader)
        );
    }

    #[test]
    fn select_malformed_hash_is_error() {
        let mut headers = HashMap::new();
        headers.insert("content-encoding".into(), "zstd".into());
        headers.insert("codec-zstd-dict".into(), "md5:abc".into());
        let dicts: HashMap<String, Vec<u8>> = HashMap::new();
        match select_zstd_dict_for_response(&headers, &dicts) {
            Err(CodecZstdDictError::MalformedHash(v)) => assert_eq!(v, "md5:abc"),
            other => panic!("expected MalformedHash, got {other:?}"),
        }
    }

    /// The response header must carry lowercase hex; uppercase is malformed.
    #[test]
    fn select_uppercase_hex_is_malformed() {
        let declared = format!("sha256:{}", "A".repeat(64));
        let mut headers = HashMap::new();
        headers.insert("content-encoding".into(), "zstd".into());
        headers.insert("codec-zstd-dict".into(), declared.clone());
        let dicts: HashMap<String, Vec<u8>> = HashMap::new();
        assert_eq!(
            select_zstd_dict_for_response(&headers, &dicts),
            Err(CodecZstdDictError::MalformedHash(declared))
        );
    }

    /// A well-formed hash that is not loaded locally is a distinct error.
    #[test]
    fn select_unknown_hash_is_error() {
        let declared = format!("sha256:{}", "0".repeat(64));
        let mut headers = HashMap::new();
        headers.insert("content-encoding".into(), "zstd".into());
        headers.insert("codec-zstd-dict".into(), declared.clone());
        let dicts: HashMap<String, Vec<u8>> = HashMap::new();
        assert_eq!(
            select_zstd_dict_for_response(&headers, &dicts),
            Err(CodecZstdDictError::UnknownHash(declared))
        );
    }

    /// Happy path, also covering mixed-case header names, a mixed-case
    /// encoding token and surrounding whitespace in both values.
    #[test]
    fn select_returns_loaded_dict() {
        let dict = b"dictionary bytes".to_vec();
        let key = hash_zstd_dict(&dict);
        let mut headers = HashMap::new();
        headers.insert("Content-Encoding".into(), " ZSTD ".into());
        headers.insert("Codec-Zstd-Dict".into(), format!(" {key} "));
        let mut dicts: HashMap<String, Vec<u8>> = HashMap::new();
        dicts.insert(key, dict.clone());
        assert_eq!(
            select_zstd_dict_for_response(&headers, &dicts).unwrap(),
            Some(dict.as_slice())
        );
    }

    #[test]
    fn well_known_dict_url_strips_sha256_prefix() {
        let h = "a".repeat(64);
        assert_eq!(
            well_known_dict_url("https://codec.example", &format!("sha256:{h}")).unwrap(),
            format!("https://codec.example/.well-known/codec/dicts/{h}.zstd"),
        );
    }

    #[test]
    fn well_known_dict_url_accepts_bare_hex() {
        let h = "b".repeat(64);
        assert_eq!(
            well_known_dict_url("https://codec.example", &h).unwrap(),
            format!("https://codec.example/.well-known/codec/dicts/{h}.zstd"),
        );
    }

    #[test]
    fn well_known_dict_url_strips_trailing_slash() {
        let h = "c".repeat(64);
        assert_eq!(
            well_known_dict_url("https://codec.example/", &h).unwrap(),
            format!("https://codec.example/.well-known/codec/dicts/{h}.zstd"),
        );
    }

    #[test]
    fn well_known_dict_url_normalises_uppercase_hex() {
        let upper = "D".repeat(64);
        let expected = "d".repeat(64);
        assert_eq!(
            well_known_dict_url("https://codec.example", &upper).unwrap(),
            format!("https://codec.example/.well-known/codec/dicts/{expected}.zstd"),
        );
    }

    #[test]
    fn well_known_dict_url_rejects_short_hash() {
        let err = well_known_dict_url("https://codec.example", "deadbeef").unwrap_err();
        assert!(matches!(err, ZstdDictDiscoveryError::InvalidHash { .. }));
    }

    #[test]
    fn well_known_dict_url_rejects_wrong_algorithm() {
        let err = well_known_dict_url(
            "https://codec.example",
            &format!("md5:{}", "a".repeat(32)),
        )
        .unwrap_err();
        assert!(matches!(err, ZstdDictDiscoveryError::InvalidHash { .. }));
    }

    #[test]
    fn well_known_dict_url_rejects_nonhex_chars() {
        let err = well_known_dict_url("https://codec.example", &"z".repeat(64)).unwrap_err();
        assert!(matches!(err, ZstdDictDiscoveryError::InvalidHash { .. }));
    }
}
491}