/// Schema version tag that [`MinHashSig::empty`] writes into [`MinHashSig::schema`].
///
/// The in-file test `schema_version_is_frozen` pins this value at `1`.
/// NOTE(review): presumably bumped only on an incompatible layout/encoding change — confirm policy.
pub const SCHEMA_VERSION: u16 = 1;
/// Fixed-size signature made of a schema tag followed by `H` 64-bit hash slots.
///
/// The struct is `#[repr(C)]`, so fields are laid out in declaration order:
/// 2 bytes of `schema`, 6 explicit bytes of `_pad`, then `H * 8` bytes of
/// `hashes`. The explicit padding field means the compiler never has to insert
/// implicit padding, which is what lets the type be viewed as raw bytes.
#[repr(C)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct MinHashSig<const H: usize> {
/// Schema version tag; [`MinHashSig::empty`] sets it to [`SCHEMA_VERSION`].
pub schema: u16,
/// Explicit filler bytes that bring the header to 8 bytes before `hashes`.
/// Constructors in this file always write zeros here. Because the derives
/// above cover every field, these bytes take part in `PartialEq` and `Hash`.
pub _pad: [u8; 6],
/// The `H` hash slots. [`MinHashSig::empty`] initialises each one to `u64::MAX`.
pub hashes: [u64; H],
}
// SAFETY: every field is an integer or an array of integers, so the all-zero
// bit pattern is a valid `MinHashSig<H>`.
unsafe impl<const H: usize> bytemuck::Zeroable for MinHashSig<H> {}
// SAFETY: the struct is `#[repr(C)]` and `Copy`, contains only `u16`, `[u8; 6]`
// and `[u64; H]` (all valid for any bit pattern, no pointers or interior
// mutability), and the explicit `_pad` field places `hashes` at offset 8, so
// the total size is `8 + 8 * H` with no implicit padding bytes for any `H`.
unsafe impl<const H: usize> bytemuck::Pod for MinHashSig<H> {}
impl<const H: usize> MinHashSig<H> {
    /// Builds a signature tagged with [`SCHEMA_VERSION`], with zeroed padding
    /// and every hash slot set to `u64::MAX`.
    #[inline]
    #[must_use]
    pub const fn empty() -> Self {
        let hashes = [u64::MAX; H];
        let _pad = [0_u8; 6];
        Self {
            schema: SCHEMA_VERSION,
            _pad,
            hashes,
        }
    }

    /// Returns the number of hash slots, i.e. the const parameter `H`.
    #[inline]
    #[must_use]
    pub const fn slot_count(&self) -> usize {
        // The array length is fixed by the type, so this is always `H`.
        self.hashes.len()
    }

    /// Borrows the whole struct (schema, padding, and hashes) as raw bytes
    /// via `bytemuck::bytes_of`.
    #[inline]
    #[must_use]
    pub fn as_bytes(&self) -> &[u8] {
        bytemuck::bytes_of::<Self>(self)
    }
}
#[cfg(feature = "serde")]
const _: () = {
use serde::de::{self, MapAccess, SeqAccess, Visitor};
use serde::ser::{SerializeStruct, SerializeTuple};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
impl<const H: usize> Serialize for MinHashSig<H> {
    /// Serializes the signature in one of two shapes, chosen by the serializer:
    ///
    /// * human-readable formats get a struct named `MinHashSig` with the two
    ///   fields `schema` and `hashes` (the latter as a sequence);
    /// * all other formats get a flat tuple of `1 + H` elements: the schema
    ///   followed by each hash slot in order.
    ///
    /// The `_pad` bytes are never written.
    fn serialize<S: Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
        if !ser.is_human_readable() {
            let mut tuple = ser.serialize_tuple(H + 1)?;
            tuple.serialize_element(&self.schema)?;
            self.hashes
                .iter()
                .try_for_each(|slot| tuple.serialize_element(slot))?;
            return tuple.end();
        }

        let mut record = ser.serialize_struct("MinHashSig", 2)?;
        record.serialize_field("schema", &self.schema)?;
        record.serialize_field("hashes", &SliceSer(self.hashes.as_slice()))?;
        record.end()
    }
}
/// Borrowed view over hash slots that serializes as a length-prefixed sequence.
struct SliceSer<'a>(&'a [u64]);
impl Serialize for SliceSer<'_> {
    /// Emits the slice as a sequence whose length hint is the exact slice length.
    fn serialize<S: Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
        use serde::ser::SerializeSeq;
        let SliceSer(slots) = self;
        let mut list = ser.serialize_seq(Some(slots.len()))?;
        slots
            .iter()
            .try_for_each(|slot| list.serialize_element(slot))?;
        list.end()
    }
}
impl<'de, const H: usize> Deserialize<'de> for MinHashSig<H> {
    /// Mirrors the two shapes produced by `Serialize`: a `schema`/`hashes`
    /// struct for human-readable formats, otherwise a `1 + H` element tuple.
    fn deserialize<D: Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
        const FIELDS: &[&str] = &["schema", "hashes"];

        if !de.is_human_readable() {
            return de.deserialize_tuple(H + 1, TupleVisitor::<H>);
        }
        de.deserialize_struct("MinHashSig", FIELDS, StructVisitor::<H>)
    }
}
struct StructVisitor<const H: usize>;
impl<'de, const H: usize> Visitor<'de> for StructVisitor<H> {
type Value = MinHashSig<H>;
fn expecting(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "a MinHashSig<{H}> struct")
}
fn visit_map<A: MapAccess<'de>>(self, mut map: A) -> Result<Self::Value, A::Error> {
let mut schema: Option<u16> = None;
let mut hashes: Option<[u64; H]> = None;
while let Some(key) = map.next_key::<alloc::string::String>()? {
match key.as_str() {
"schema" => schema = Some(map.next_value()?),
"hashes" => hashes = Some(map.next_value::<HashesArray<H>>()?.0),
other => {
return Err(de::Error::unknown_field(other, &["schema", "hashes"]));
}
}
}
let schema = schema.ok_or_else(|| de::Error::missing_field("schema"))?;
let hashes = hashes.ok_or_else(|| de::Error::missing_field("hashes"))?;
Ok(MinHashSig {
schema,
_pad: [0; 6],
hashes,
})
}
}
/// Private newtype around `[u64; H]` so the `hashes` field can be read from a
/// sequence with an exact-length check.
struct HashesArray<const H: usize>([u64; H]);
impl<'de, const H: usize> Deserialize<'de> for HashesArray<H> {
    /// Requests a sequence and hands it to `HashesArrayVisitor`.
    fn deserialize<D: Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
        let visitor = HashesArrayVisitor::<H>;
        de.deserialize_seq(visitor)
    }
}
struct HashesArrayVisitor<const H: usize>;
impl<'de, const H: usize> Visitor<'de> for HashesArrayVisitor<H> {
type Value = HashesArray<H>;
fn expecting(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "an array of exactly {H} u64 hash slots")
}
fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
let mut out = [0_u64; H];
for (i, slot) in out.iter_mut().enumerate() {
*slot = seq
.next_element::<u64>()?
.ok_or_else(|| de::Error::invalid_length(i, &self))?;
}
if seq.next_element::<u64>()?.is_some() {
return Err(de::Error::invalid_length(H + 1, &self));
}
Ok(HashesArray(out))
}
}
struct TupleVisitor<const H: usize>;
impl<'de, const H: usize> Visitor<'de> for TupleVisitor<H> {
type Value = MinHashSig<H>;
fn expecting(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "a {}-element MinHashSig<{H}> tuple", 1 + H)
}
fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
let schema: u16 = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(0, &self))?;
let mut hashes = [0_u64; H];
for (i, slot) in hashes.iter_mut().enumerate() {
*slot = seq
.next_element::<u64>()?
.ok_or_else(|| de::Error::invalid_length(1 + i, &self))?;
}
Ok(MinHashSig {
schema,
_pad: [0; 6],
hashes,
})
}
}
};
#[cfg(test)]
mod tests {
    use super::*;

    /// `empty()` stamps the current schema, zeroes the padding, and sets every
    /// slot to `u64::MAX`.
    #[test]
    fn empty_signature_has_schema_set() {
        let sig = MinHashSig::<128>::empty();
        assert_eq!(sig.schema, SCHEMA_VERSION);
        assert_eq!(sig._pad, [0; 6]);
        assert!(sig.hashes.iter().all(|&slot| slot == u64::MAX));
    }

    /// `slot_count()` reports the const generic parameter.
    #[test]
    fn slot_count_matches_const_generic() {
        let sig = MinHashSig::<64>::empty();
        assert_eq!(sig.slot_count(), 64);
    }

    /// The byte view is `8 + 8 * H` bytes long and casts back to an equal value.
    #[test]
    fn pod_roundtrip_through_bytes() {
        let original = MinHashSig::<8> {
            schema: SCHEMA_VERSION,
            _pad: [0; 6],
            hashes: [1, 2, 3, 4, 5, 6, 7, 8],
        };
        let raw = original.as_bytes();
        assert_eq!(raw.len(), 8 + 8 * 8);
        let decoded: &MinHashSig<8> = bytemuck::from_bytes(raw);
        assert_eq!(original, *decoded);
    }

    /// Guards against an accidental change of the schema constant.
    #[test]
    fn schema_version_is_frozen() {
        assert_eq!(SCHEMA_VERSION, 1);
    }

    /// Compile-time check that the expected trait impls exist.
    #[test]
    fn signatures_are_pod_eq_hash() {
        fn require_pod<T: bytemuck::Pod>() {}
        fn require_eq_hash<T: Eq + core::hash::Hash>() {}
        require_pod::<MinHashSig<128>>();
        require_eq_hash::<MinHashSig<128>>();
    }

    /// JSON output has the expected field encoding and round-trips.
    #[cfg(feature = "serde")]
    #[test]
    fn serde_json_round_trip() {
        let original = MinHashSig::<8> {
            schema: SCHEMA_VERSION,
            _pad: [0; 6],
            hashes: [11, 22, 33, 44, 55, 66, 77, 88],
        };
        let encoded = serde_json::to_string(&original).unwrap();
        assert!(encoded.contains("\"schema\":1"));
        assert!(encoded.contains("\"hashes\":[11,22,33,44,55,66,77,88]"));
        let decoded: MinHashSig<8> = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    /// Round-trip with `H = 128`.
    #[cfg(feature = "serde")]
    #[test]
    fn serde_json_h128_round_trip() {
        let original = MinHashSig::<128>::empty();
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: MinHashSig<128> = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    /// A `hashes` array with 7 entries must not deserialize as `MinHashSig<8>`.
    #[cfg(feature = "serde")]
    #[test]
    fn serde_rejects_wrong_length_in_human_format() {
        let too_short = "{\"schema\":1,\"hashes\":[1,2,3,4,5,6,7]}";
        let outcome = serde_json::from_str::<MinHashSig<8>>(too_short);
        assert!(outcome.is_err());
    }
}