use std::collections::HashMap;
use std::fmt;
use sha2::{Digest, Sha256};
/// Computes the identifier string for a zstd dictionary.
///
/// The identifier is the literal prefix `sha256:` followed by the hex
/// encoding of the SHA-256 digest of `dict_bytes`.
pub fn hash_zstd_dict(dict_bytes: &[u8]) -> String {
    // One-shot digest: equivalent to new() + update() + finalize().
    let digest_hex = hex::encode(Sha256::digest(dict_bytes));
    format!("sha256:{}", digest_hex)
}
/// Reasons a `Content-Encoding: zstd` response cannot be matched to a
/// locally loaded dictionary.
///
/// Produced by `select_zstd_dict_for_response`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CodecZstdDictError {
/// The response is zstd-encoded but the `Codec-Zstd-Dict` header is absent,
/// or is empty once surrounding whitespace is trimmed.
MissingHeader,
/// The `Codec-Zstd-Dict` value is not of the form `sha256:<64 lowercase hex
/// chars>`. Carries the trimmed header value as received.
MalformedHash(String),
/// The header value is well-formed but no dictionary is loaded under that
/// hash. Carries the declared hash.
UnknownHash(String),
}
impl fmt::Display for CodecZstdDictError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
CodecZstdDictError::MissingHeader => write!(
f,
"Response is Content-Encoding: zstd but no Codec-Zstd-Dict \
header was present. Per spec/PROTOCOL.md the server MUST \
name the dict it used. Refusing to guess."
),
CodecZstdDictError::MalformedHash(value) => write!(
f,
"Malformed Codec-Zstd-Dict value: {value:?}. Expected \
'sha256:<64 hex chars>'."
),
CodecZstdDictError::UnknownHash(hash) => write!(
f,
"Server used zstd dict {hash} but it isn't loaded \
locally. Fetch it from the tokenizer map's \
zstd_dictionaries[] entry (the entry whose hash \
matches), or send Accept-Encoding: gzip to downgrade."
),
}
}
}
// Empty impl: the trait's default methods are used, so the error reports no
// underlying source; its text comes from the type's `Display` and `Debug` impls.
impl std::error::Error for CodecZstdDictError {}
/// Picks the locally loaded zstd dictionary that a response was encoded with.
///
/// Header names are matched without regard to ASCII case, and header values
/// are trimmed before being inspected.
///
/// Returns `Ok(None)` when the response has no `Content-Encoding` header or
/// its value is anything other than `zstd` (compared ASCII-case-insensitively).
/// Otherwise returns the bytes stored in `loaded_dicts` under the hash named
/// by the `Codec-Zstd-Dict` header.
///
/// # Errors
///
/// * [`CodecZstdDictError::MissingHeader`] if `Codec-Zstd-Dict` is absent or
///   blank.
/// * [`CodecZstdDictError::MalformedHash`] if the value is not
///   `sha256:` followed by 64 lowercase hex characters.
/// * [`CodecZstdDictError::UnknownHash`] if no entry in `loaded_dicts` has
///   that exact key.
pub fn select_zstd_dict_for_response<'a>(
    response_headers: &HashMap<String, String>,
    loaded_dicts: &'a HashMap<String, Vec<u8>>,
) -> Result<Option<&'a [u8]>, CodecZstdDictError> {
    // Only zstd-encoded responses involve a dictionary at all.
    let is_zstd = header(response_headers, "content-encoding")
        .map_or(false, |encoding| encoding.trim().eq_ignore_ascii_case("zstd"));
    if !is_zstd {
        return Ok(None);
    }

    // A blank header is treated exactly like an absent one.
    let declared = header(response_headers, "codec-zstd-dict")
        .map(str::trim)
        .filter(|value| !value.is_empty())
        .ok_or(CodecZstdDictError::MissingHeader)?;

    if !is_canonical_sha256(declared) {
        return Err(CodecZstdDictError::MalformedHash(declared.to_string()));
    }

    loaded_dicts
        .get(declared)
        .map(|bytes| Some(bytes.as_slice()))
        .ok_or_else(|| CodecZstdDictError::UnknownHash(declared.to_string()))
}
/// Looks up a header value by name, ignoring ASCII case in the name.
///
/// An exact-key lookup is tried first. If that misses, every entry is scanned
/// and the first key equal to `name` under ASCII case folding wins. When
/// several keys differ only by case, which one is returned depends on the
/// map's iteration order.
///
/// Returns `None` when no key matches.
fn header<'a>(headers: &'a HashMap<String, String>, name: &str) -> Option<&'a str> {
    headers
        .get(name)
        .or_else(|| {
            // `eq_ignore_ascii_case` compares in place, so the scan does not
            // allocate a lowercased copy of each key.
            headers
                .iter()
                .find(|(key, _)| key.eq_ignore_ascii_case(name))
                .map(|(_, value)| value)
        })
        .map(String::as_str)
}
/// Reports whether `value` has the exact form `sha256:` followed by 64
/// lowercase hexadecimal characters.
///
/// Uppercase hex digits, a differently-cased prefix, and any other digest
/// length are all rejected.
fn is_canonical_sha256(value: &str) -> bool {
    match value.strip_prefix("sha256:") {
        Some(digest) => {
            digest.len() == 64
                && digest
                    .bytes()
                    .all(|b| b.is_ascii_digit() || (b'a'..=b'f').contains(&b))
        }
        None => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned header map from borrowed name/value pairs.
    fn headers_from(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(name, value)| (name.to_owned(), value.to_owned()))
            .collect()
    }

    /// A dictionary store with nothing loaded.
    fn no_dicts() -> HashMap<String, Vec<u8>> {
        HashMap::new()
    }

    #[test]
    fn hash_zstd_dict_matches_python_reference() {
        // Expected identifier for the bytes b"hello world".
        let expected = "sha256:b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9";
        assert_eq!(hash_zstd_dict(b"hello world"), expected);
    }

    #[test]
    fn select_returns_none_when_not_zstd() {
        let headers = headers_from(&[("content-encoding", "gzip")]);
        let dicts = no_dicts();
        assert_eq!(select_zstd_dict_for_response(&headers, &dicts), Ok(None));
    }

    #[test]
    fn select_returns_none_when_no_encoding() {
        let headers = headers_from(&[]);
        let dicts = no_dicts();
        assert_eq!(select_zstd_dict_for_response(&headers, &dicts), Ok(None));
    }

    #[test]
    fn select_missing_header_is_error() {
        // Mixed-case header name exercises the case-insensitive lookup.
        let headers = headers_from(&[("Content-Encoding", "zstd")]);
        let dicts = no_dicts();
        assert_eq!(
            select_zstd_dict_for_response(&headers, &dicts),
            Err(CodecZstdDictError::MissingHeader)
        );
    }

    #[test]
    fn select_malformed_hash_is_error() {
        let headers = headers_from(&[
            ("content-encoding", "zstd"),
            ("codec-zstd-dict", "md5:abc"),
        ]);
        let dicts = no_dicts();
        assert_eq!(
            select_zstd_dict_for_response(&headers, &dicts),
            Err(CodecZstdDictError::MalformedHash("md5:abc".to_owned()))
        );
    }
}