use std::collections::HashMap;
use std::fmt;
use sha2::{Digest, Sha256};
/// Computes the identifier of a zstd dictionary from its raw bytes.
///
/// The identifier is the literal prefix `sha256:` followed by the
/// hex-encoded SHA-256 digest of `dict_bytes`.
pub fn hash_zstd_dict(dict_bytes: &[u8]) -> String {
    // One-shot digest: equivalent to new() + update() + finalize().
    let digest_hex = hex::encode(Sha256::digest(dict_bytes));
    format!("sha256:{digest_hex}")
}
/// Reasons a zstd-encoded response cannot be matched to a local dictionary.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CodecZstdDictError {
    /// The dictionary header was not supplied (or was blank).
    MissingHeader,
    /// The header value is not of the form `sha256:<64 hex chars>`; carries
    /// the offending value.
    MalformedHash(String),
    /// The header names a dictionary that is not in the local set; carries
    /// the declared hash.
    UnknownHash(String),
}

impl fmt::Display for CodecZstdDictError {
    /// Writes the human-readable explanation for each failure, including a
    /// hint about how to recover where one exists.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // No interpolation needed, so the text is written directly.
            Self::MissingHeader => f.write_str(
                "Response is Content-Encoding: zstd but no Codec-Zstd-Dict \
                 header was present. Per spec/PROTOCOL.md the server MUST \
                 name the dict it used. Refusing to guess.",
            ),
            // Debug formatting quotes the value so stray whitespace is visible.
            Self::MalformedHash(value) => write!(
                f,
                "Malformed Codec-Zstd-Dict value: {:?}. Expected \
                 'sha256:<64 hex chars>'.",
                value
            ),
            Self::UnknownHash(hash) => write!(
                f,
                "Server used zstd dict {} but it isn't loaded \
                 locally. Fetch it from the tokenizer map's \
                 zstd_dictionaries[] entry (the entry whose hash \
                 matches), or send Accept-Encoding: gzip to downgrade.",
                hash
            ),
        }
    }
}

// Marker impl: no underlying source error, the defaults suffice.
impl std::error::Error for CodecZstdDictError {}
/// Picks the locally loaded zstd dictionary that a response says it used.
///
/// Header names are matched case-insensitively and header values are trimmed
/// before use.
///
/// Returns `Ok(None)` when the response has no `Content-Encoding` header or
/// its value is anything other than `zstd` (compared ASCII
/// case-insensitively). Otherwise returns the bytes stored in `loaded_dicts`
/// under the exact value of the `Codec-Zstd-Dict` header.
///
/// # Errors
///
/// * [`CodecZstdDictError::MissingHeader`] - the response is zstd-encoded but
///   `Codec-Zstd-Dict` is absent or blank.
/// * [`CodecZstdDictError::MalformedHash`] - the header value is not
///   `sha256:` followed by 64 lowercase hex characters.
/// * [`CodecZstdDictError::UnknownHash`] - the value is well-formed but is
///   not a key of `loaded_dicts`.
pub fn select_zstd_dict_for_response<'a>(
    response_headers: &HashMap<String, String>,
    loaded_dicts: &'a HashMap<String, Vec<u8>>,
) -> Result<Option<&'a [u8]>, CodecZstdDictError> {
    let is_zstd = header(response_headers, "content-encoding")
        .map_or(false, |coding| coding.trim().eq_ignore_ascii_case("zstd"));
    if !is_zstd {
        // Not a zstd body: there is no dictionary to select.
        return Ok(None);
    }

    // A blank value is treated the same as an absent header.
    let declared = header(response_headers, "codec-zstd-dict")
        .map(str::trim)
        .filter(|value| !value.is_empty())
        .ok_or(CodecZstdDictError::MissingHeader)?;

    if !is_canonical_sha256(declared) {
        return Err(CodecZstdDictError::MalformedHash(declared.to_string()));
    }

    loaded_dicts
        .get(declared)
        .map(|bytes| Some(bytes.as_slice()))
        .ok_or_else(|| CodecZstdDictError::UnknownHash(declared.to_string()))
}
/// Looks up a header value by name, ignoring ASCII case.
///
/// An exact-key hit is tried first; only when that misses are the keys
/// scanned for an ASCII case-insensitive match. Returns `None` when no key
/// matches.
fn header<'a>(headers: &'a HashMap<String, String>, name: &str) -> Option<&'a str> {
    headers
        .get(name)
        .or_else(|| {
            // Fallback scan: first key (in map iteration order) that equals
            // `name` once ASCII case is ignored.
            headers
                .iter()
                .find(|(key, _)| key.eq_ignore_ascii_case(name))
                .map(|(_, value)| value)
        })
        .map(String::as_str)
}
/// Reports whether `value` is exactly `sha256:` followed by 64 lowercase
/// hexadecimal characters (`0-9`, `a-f`).
///
/// Uppercase hex digits, a missing prefix, or any other length are rejected.
fn is_canonical_sha256(value: &str) -> bool {
    value.strip_prefix("sha256:").map_or(false, |digest| {
        digest.len() == 64
            && digest
                .bytes()
                .all(|b| b.is_ascii_digit() || (b'a'..=b'f').contains(&b))
    })
}
/// Failures that can occur while locating or downloading a zstd dictionary
/// from an origin's well-known dictionary URL.
#[derive(Debug, thiserror::Error)]
pub enum ZstdDictDiscoveryError {
/// The supplied hash is not 64 hex characters, with or without a leading
/// `sha256:`; `hash` holds the value exactly as the caller passed it.
#[error("Invalid dict hash {hash:?}: expected 'sha256:<64 hex>' or '<64 hex>'")]
InvalidHash { hash: String },
/// The server answered HTTP 404 for the dictionary URL.
#[error("No zstd dict at {url} (HTTP 404)")]
NotFound { url: String },
/// The downloaded bytes do not hash to the requested value. `expected` and
/// `actual` are bare lowercase hex digests (no `sha256:` prefix).
#[error("Zstd dict hash mismatch at {url}\n expected: {expected}\n actual: {actual}")]
HashMismatch {
url: String,
expected: String,
actual: String,
},
/// Any other transport or HTTP-status failure reported by `reqwest`.
/// Only present when the `http` feature is enabled.
#[cfg(feature = "http")]
#[error("HTTP error fetching {url}: {source}")]
Http {
url: String,
#[source]
source: reqwest::Error,
},
}
/// Normalises a dictionary hash to bare lowercase hex.
///
/// Surrounding whitespace is ignored and a leading `sha256:` is optional.
/// What remains must be exactly 64 hexadecimal characters; upper- and
/// lowercase digits are both accepted and the result is lowercased.
///
/// # Errors
///
/// Returns [`ZstdDictDiscoveryError::InvalidHash`] carrying the original,
/// untrimmed input when the value does not have that shape.
fn parse_dict_hash(hash: &str) -> Result<String, ZstdDictDiscoveryError> {
    let trimmed = hash.trim();
    let digest = trimmed.strip_prefix("sha256:").unwrap_or(trimmed);

    // Checking hex digits of either case before lowercasing accepts exactly
    // the same inputs as lowercasing first and then checking `0-9a-f`.
    let well_formed = digest.len() == 64 && digest.bytes().all(|b| b.is_ascii_hexdigit());
    if well_formed {
        Ok(digest.to_ascii_lowercase())
    } else {
        Err(ZstdDictDiscoveryError::InvalidHash {
            hash: hash.to_string(),
        })
    }
}
/// Builds the well-known URL at which `origin` serves the dictionary with
/// the given hash: `{origin}/.well-known/codec/dicts/{hex}.zstd`.
///
/// `hash` may be given with or without the `sha256:` prefix and in either
/// hex case; the URL always contains the bare lowercase digest. A single
/// trailing `/` on `origin` is dropped so the path is not doubled up.
///
/// # Errors
///
/// Returns [`ZstdDictDiscoveryError::InvalidHash`] when `hash` is not a
/// 64-character hex digest.
pub fn well_known_dict_url(origin: &str, hash: &str) -> Result<String, ZstdDictDiscoveryError> {
    parse_dict_hash(hash).map(|digest| {
        let base = origin.strip_suffix('/').unwrap_or(origin);
        format!("{base}/.well-known/codec/dicts/{digest}.zstd")
    })
}
/// Returns the hex-encoded SHA-256 digest of `bytes`, without any prefix.
#[cfg(feature = "http")]
fn sha256_hex_bytes(bytes: &[u8]) -> String {
    // One-shot digest: equivalent to new() + update() + finalize().
    hex::encode(Sha256::digest(bytes))
}
/// Downloads a zstd dictionary from `origin`'s well-known URL using a
/// blocking HTTP client and verifies it against `hash`.
///
/// `hash` may be written with or without the `sha256:` prefix. The response
/// body is returned only when its SHA-256 digest equals the requested one.
///
/// # Errors
///
/// * [`ZstdDictDiscoveryError::InvalidHash`] - `hash` is not a 64-character
///   hex digest (checked before any request is made).
/// * [`ZstdDictDiscoveryError::NotFound`] - the server replied with HTTP 404.
/// * [`ZstdDictDiscoveryError::Http`] - building the client, sending the
///   request, a non-success status other than 404, or reading the body
///   failed.
/// * [`ZstdDictDiscoveryError::HashMismatch`] - the body's digest differs
///   from the requested hash.
#[cfg(feature = "http")]
pub fn discover_zstd_dict_blocking(
    origin: &str,
    hash: &str,
) -> Result<Vec<u8>, ZstdDictDiscoveryError> {
    let expected = parse_dict_hash(hash)?;
    let url = well_known_dict_url(origin, hash)?;

    // Shared wrapper so every reqwest failure is reported against the URL.
    let http_error = |source: reqwest::Error| ZstdDictDiscoveryError::Http {
        url: url.clone(),
        source,
    };

    let client = reqwest::blocking::Client::builder()
        .user_agent("codec-rs/0.4")
        .build()
        .map_err(http_error)?;
    let resp = client.get(&url).send().map_err(http_error)?;

    // 404 gets its own variant; it is checked before the generic status test.
    if resp.status() == reqwest::StatusCode::NOT_FOUND {
        return Err(ZstdDictDiscoveryError::NotFound { url });
    }

    let body = resp
        .error_for_status()
        .and_then(|ok| ok.bytes())
        .map_err(http_error)?;

    // Never hand back bytes that do not match the digest that was asked for.
    let actual = sha256_hex_bytes(&body);
    if actual == expected {
        Ok(body.to_vec())
    } else {
        Err(ZstdDictDiscoveryError::HashMismatch {
            url,
            expected,
            actual,
        })
    }
}
/// Async counterpart of the blocking discovery call: downloads a zstd
/// dictionary from `origin`'s well-known URL and verifies it against `hash`.
///
/// `hash` may be written with or without the `sha256:` prefix. The response
/// body is returned only when its SHA-256 digest equals the requested one.
///
/// # Errors
///
/// * [`ZstdDictDiscoveryError::InvalidHash`] - `hash` is not a 64-character
///   hex digest (checked before any request is made).
/// * [`ZstdDictDiscoveryError::NotFound`] - the server replied with HTTP 404.
/// * [`ZstdDictDiscoveryError::Http`] - building the client, sending the
///   request, a non-success status other than 404, or reading the body
///   failed.
/// * [`ZstdDictDiscoveryError::HashMismatch`] - the body's digest differs
///   from the requested hash.
#[cfg(feature = "http")]
pub async fn discover_zstd_dict(
    origin: &str,
    hash: &str,
) -> Result<Vec<u8>, ZstdDictDiscoveryError> {
    let expected = parse_dict_hash(hash)?;
    let url = well_known_dict_url(origin, hash)?;

    // Shared wrapper so every reqwest failure is reported against the URL.
    let http_error = |source: reqwest::Error| ZstdDictDiscoveryError::Http {
        url: url.clone(),
        source,
    };

    let client = reqwest::Client::builder()
        .user_agent("codec-rs/0.4")
        .build()
        .map_err(http_error)?;
    let resp = client.get(&url).send().await.map_err(http_error)?;

    // 404 gets its own variant; it is checked before the generic status test.
    if resp.status() == reqwest::StatusCode::NOT_FOUND {
        return Err(ZstdDictDiscoveryError::NotFound { url });
    }

    let resp = resp.error_for_status().map_err(http_error)?;
    let body = resp.bytes().await.map_err(http_error)?;

    // Never hand back bytes that do not match the digest that was asked for.
    let actual = sha256_hex_bytes(&body);
    if actual == expected {
        Ok(body.to_vec())
    } else {
        Err(ZstdDictDiscoveryError::HashMismatch {
            url,
            expected,
            actual,
        })
    }
}
/// Unit tests for dictionary hashing, response-header dictionary selection
/// and well-known URL construction. Nothing here touches the network.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- hash_zstd_dict ---------------------------------------------------

    /// Pins the output format against a known SHA-256 vector.
    #[test]
    fn hash_zstd_dict_matches_python_reference() {
        let got = hash_zstd_dict(b"hello world");
        assert_eq!(
            got,
            "sha256:b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"
        );
    }

    // ---- select_zstd_dict_for_response ------------------------------------

    #[test]
    fn select_returns_none_when_not_zstd() {
        let mut headers = HashMap::new();
        headers.insert("content-encoding".into(), "gzip".into());
        let dicts: HashMap<String, Vec<u8>> = HashMap::new();
        assert_eq!(
            select_zstd_dict_for_response(&headers, &dicts).unwrap(),
            None
        );
    }

    #[test]
    fn select_returns_none_when_no_encoding() {
        let headers: HashMap<String, String> = HashMap::new();
        let dicts: HashMap<String, Vec<u8>> = HashMap::new();
        assert_eq!(
            select_zstd_dict_for_response(&headers, &dicts).unwrap(),
            None
        );
    }

    /// Mixed-case header name also exercises the case-insensitive lookup.
    #[test]
    fn select_missing_header_is_error() {
        let mut headers = HashMap::new();
        headers.insert("Content-Encoding".into(), "zstd".into());
        let dicts: HashMap<String, Vec<u8>> = HashMap::new();
        assert_eq!(
            select_zstd_dict_for_response(&headers, &dicts),
            Err(CodecZstdDictError::MissingHeader)
        );
    }

    /// A header that is present but blank is treated like a missing one.
    #[test]
    fn select_blank_dict_header_is_missing_header() {
        let mut headers = HashMap::new();
        headers.insert("content-encoding".into(), "zstd".into());
        headers.insert("codec-zstd-dict".into(), "   ".into());
        let dicts: HashMap<String, Vec<u8>> = HashMap::new();
        assert_eq!(
            select_zstd_dict_for_response(&headers, &dicts),
            Err(CodecZstdDictError::MissingHeader)
        );
    }

    /// The encoding value is trimmed and compared case-insensitively, so this
    /// response is recognised as zstd and then fails on the absent dict header.
    #[test]
    fn select_tolerates_encoding_case_and_whitespace() {
        let mut headers = HashMap::new();
        headers.insert("content-encoding".into(), " ZSTD ".into());
        let dicts: HashMap<String, Vec<u8>> = HashMap::new();
        assert_eq!(
            select_zstd_dict_for_response(&headers, &dicts),
            Err(CodecZstdDictError::MissingHeader)
        );
    }

    #[test]
    fn select_malformed_hash_is_error() {
        let mut headers = HashMap::new();
        headers.insert("content-encoding".into(), "zstd".into());
        headers.insert("codec-zstd-dict".into(), "md5:abc".into());
        let dicts: HashMap<String, Vec<u8>> = HashMap::new();
        match select_zstd_dict_for_response(&headers, &dicts) {
            Err(CodecZstdDictError::MalformedHash(v)) => assert_eq!(v, "md5:abc"),
            other => panic!("expected MalformedHash, got {other:?}"),
        }
    }

    /// Unlike the discovery helpers, the response header must already be in
    /// canonical lowercase form.
    #[test]
    fn select_uppercase_hex_is_malformed() {
        let declared = format!("sha256:{}", "A".repeat(64));
        let mut headers = HashMap::new();
        headers.insert("content-encoding".into(), "zstd".into());
        headers.insert("codec-zstd-dict".into(), declared.clone());
        let dicts: HashMap<String, Vec<u8>> = HashMap::new();
        assert_eq!(
            select_zstd_dict_for_response(&headers, &dicts),
            Err(CodecZstdDictError::MalformedHash(declared))
        );
    }

    /// Well-formed hash that simply is not in the loaded set.
    #[test]
    fn select_unknown_hash_is_error() {
        let declared = format!("sha256:{}", "e".repeat(64));
        let mut headers = HashMap::new();
        headers.insert("content-encoding".into(), "zstd".into());
        headers.insert("codec-zstd-dict".into(), declared.clone());
        let dicts: HashMap<String, Vec<u8>> = HashMap::new();
        assert_eq!(
            select_zstd_dict_for_response(&headers, &dicts),
            Err(CodecZstdDictError::UnknownHash(declared))
        );
    }

    /// Success path: the declared hash is loaded, so its bytes come back.
    /// Mixed-case header names and padded values must not get in the way.
    #[test]
    fn select_returns_dict_when_hash_is_loaded() {
        let dict_bytes: &[u8] = b"pretend this is a trained dictionary";
        let hash = hash_zstd_dict(dict_bytes);

        let mut headers = HashMap::new();
        headers.insert("Content-Encoding".into(), "zstd".into());
        headers.insert("Codec-Zstd-Dict".into(), format!(" {hash} "));

        let mut dicts: HashMap<String, Vec<u8>> = HashMap::new();
        dicts.insert(hash, dict_bytes.to_vec());
        dicts.insert(format!("sha256:{}", "0".repeat(64)), b"other".to_vec());

        assert_eq!(
            select_zstd_dict_for_response(&headers, &dicts),
            Ok(Some(dict_bytes))
        );
    }

    // ---- well_known_dict_url ----------------------------------------------

    #[test]
    fn well_known_dict_url_strips_sha256_prefix() {
        let h = "a".repeat(64);
        assert_eq!(
            well_known_dict_url("https://codec.example", &format!("sha256:{h}")).unwrap(),
            format!("https://codec.example/.well-known/codec/dicts/{h}.zstd"),
        );
    }

    #[test]
    fn well_known_dict_url_accepts_bare_hex() {
        let h = "b".repeat(64);
        assert_eq!(
            well_known_dict_url("https://codec.example", &h).unwrap(),
            format!("https://codec.example/.well-known/codec/dicts/{h}.zstd"),
        );
    }

    #[test]
    fn well_known_dict_url_strips_trailing_slash() {
        let h = "c".repeat(64);
        assert_eq!(
            well_known_dict_url("https://codec.example/", &h).unwrap(),
            format!("https://codec.example/.well-known/codec/dicts/{h}.zstd"),
        );
    }

    #[test]
    fn well_known_dict_url_normalises_uppercase_hex() {
        let upper = "D".repeat(64);
        let expected = "d".repeat(64);
        assert_eq!(
            well_known_dict_url("https://codec.example", &upper).unwrap(),
            format!("https://codec.example/.well-known/codec/dicts/{expected}.zstd"),
        );
    }

    #[test]
    fn well_known_dict_url_rejects_short_hash() {
        let err = well_known_dict_url("https://codec.example", "deadbeef").unwrap_err();
        assert!(matches!(err, ZstdDictDiscoveryError::InvalidHash { .. }));
    }

    #[test]
    fn well_known_dict_url_rejects_wrong_algorithm() {
        let err = well_known_dict_url(
            "https://codec.example",
            &format!("md5:{}", "a".repeat(32)),
        )
        .unwrap_err();
        assert!(matches!(err, ZstdDictDiscoveryError::InvalidHash { .. }));
    }

    #[test]
    fn well_known_dict_url_rejects_nonhex_chars() {
        let err = well_known_dict_url("https://codec.example", &"z".repeat(64)).unwrap_err();
        assert!(matches!(err, ZstdDictDiscoveryError::InvalidHash { .. }));
    }
}