use serde::{Deserialize, Serialize};
/// Version number of the shape schema; the `constants_are_frozen` test pins it to 1.
///
/// NOTE(review): presumably bumped whenever the serialized layout of
/// [`ShapeDescriptor`] changes — confirm with the code that reads this value.
pub const SHAPE_SCHEMA_VERSION: u16 = 1;
/// Number of `u16` buckets in [`ShapeDescriptor::cf_histogram`].
pub const CF_BUCKET_COUNT: usize = 15;
/// Number of `u32` lanes in [`ShapeDescriptor::minhash`]; also the fixed tuple
/// length used when that field is serialized and deserialized.
pub const MINHASH_LANES: usize = 64;
/// Token-count threshold related to hashability.
///
/// NOTE(review): not referenced anywhere in this file; presumably bodies with
/// fewer tokens than this are marked unhashable — confirm against the producer.
pub const MIN_HASHABLE_TOKENS: u16 = 4;
/// A 128-bit hash value stored as two 64-bit halves.
///
/// The derived `Default` is the all-zero value, which `is_zero` reports as `true`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub struct ShapeHash128 {
/// Upper 64 bits of the 128-bit value (see `as_u128`).
pub high: u64,
/// Lower 64 bits of the 128-bit value (see `as_u128`).
pub low: u64,
}
impl ShapeHash128 {
    /// Packs both halves into one `u128`, with `high` occupying the upper 64 bits.
    #[must_use]
    pub const fn as_u128(self) -> u128 {
        // `as` is used because the `From` conversions are not callable in `const fn`.
        let upper = (self.high as u128) << 64;
        let lower = self.low as u128;
        upper | lower
    }

    /// Splits a `u128` into its upper and lower 64-bit halves.
    ///
    /// This is the inverse of [`ShapeHash128::as_u128`].
    #[must_use]
    pub const fn from_u128(value: u128) -> Self {
        // Selects the low 64 bits so the narrowing cast below never drops set bits.
        const LOW_HALF_MASK: u128 = 0xFFFF_FFFF_FFFF_FFFF;
        let high = (value >> 64) as u64;
        let low = (value & LOW_HALF_MASK) as u64;
        Self { high, low }
    }

    /// Returns `true` when both halves are zero.
    #[must_use]
    pub const fn is_zero(self) -> bool {
        matches!(self, Self { high: 0, low: 0 })
    }
}
/// Renders the hash as 32 lowercase hex digits: `high` first, then `low`,
/// each zero-padded to 16 digits.
impl std::fmt::Display for ShapeHash128 {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { high, low } = self;
        write!(f, "{high:016x}{low:016x}")
    }
}
/// Compact, copyable summary of a callable's signature.
///
/// The derived `Default` has both arities at 0 and every flag `false`.
///
/// NOTE(review): the per-field meanings below are inferred from the field
/// names only — confirm against the code that fills this struct in.
// NOTE(review): the lint is presumably silenced because the four flags are
// independent yes/no facts rather than an encoded state — confirm.
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub struct SignatureShape {
/// Presumably the number of positional parameters.
pub arity_positional: u16,
/// Presumably the number of keyword-only parameters.
pub arity_keyword_only: u16,
/// Presumably set when at least one parameter has a default value.
pub has_defaults: bool,
/// Presumably set when the signature accepts variadic positional arguments.
pub has_varargs: bool,
/// Presumably set when the signature accepts variadic keyword arguments.
pub has_kwargs: bool,
/// Presumably set when the signature declares a return annotation.
pub has_return_annotation: bool,
}
/// Summary of a callable's callees; either not resolved or a small set of counters.
///
/// The default value is [`CalleeShape::Unresolved`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum CalleeShape {
/// No callee information is available; this is the default variant.
#[default]
Unresolved,
/// Callee information is available.
Resolved {
/// NOTE(review): presumably the number of resolved callees — confirm.
count: u16,
/// Four `u16` counters.
/// NOTE(review): presumably a histogram of callee degree; the bucket
/// boundaries are not defined in this file — confirm with the producer.
degree_buckets: [u16; 4],
},
}
/// Bit set of descriptor markers packed into a single byte.
///
/// The bits used in this file are `ShapeFlags::UNHASHABLE` and
/// `ShapeFlags::TRUNCATED`; the derived `Default` has no bits set.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub struct ShapeFlags(pub u8);
impl ShapeFlags {
    /// Bit 0: the "unhashable" marker.
    pub const UNHASHABLE: u8 = 0b0000_0001;
    /// Bit 1: the "truncated" marker.
    pub const TRUNCATED: u8 = 0b0000_0010;

    /// Returns a flag set with no bits raised.
    #[must_use]
    pub const fn empty() -> Self {
        Self(0)
    }

    /// Reports whether the [`ShapeFlags::UNHASHABLE`] bit is raised.
    #[must_use]
    pub const fn is_unhashable(self) -> bool {
        self.has_any(Self::UNHASHABLE)
    }

    /// Reports whether the [`ShapeFlags::TRUNCATED`] bit is raised.
    #[must_use]
    pub const fn is_truncated(self) -> bool {
        self.has_any(Self::TRUNCATED)
    }

    /// Raises the [`ShapeFlags::UNHASHABLE`] bit, leaving other bits untouched.
    pub const fn set_unhashable(&mut self) {
        self.raise(Self::UNHASHABLE);
    }

    /// Raises the [`ShapeFlags::TRUNCATED`] bit, leaving other bits untouched.
    pub const fn set_truncated(&mut self) {
        self.raise(Self::TRUNCATED);
    }

    /// True when at least one bit of `mask` is raised in this set.
    const fn has_any(self, mask: u8) -> bool {
        self.0 & mask != 0
    }

    /// ORs `mask` into the stored byte.
    const fn raise(&mut self, mask: u8) {
        self.0 |= mask;
    }
}
/// Fixed-size descriptor bundling every shape feature of one item: a histogram,
/// signature summary, callee summary, 128-bit hash, `MinHash` lanes and flags.
///
/// `Default` is implemented by hand below and yields an all-zero descriptor
/// with an unresolved callee shape and no flags set.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ShapeDescriptor {
/// `CF_BUCKET_COUNT` counters.
/// NOTE(review): "cf" presumably stands for control flow; the meaning of
/// each bucket is not defined in this file — confirm with the producer.
pub cf_histogram: [u16; CF_BUCKET_COUNT],
/// Signature summary; the only field `ShapeDescriptor::unhashable` takes from its caller.
pub signature_shape: SignatureShape,
/// Callee summary; `CalleeShape::Unresolved` in the default descriptor.
pub callee_shape: CalleeShape,
/// 128-bit hash; all zero in default and unhashable descriptors.
pub shape_hash: ShapeHash128,
/// `MINHASH_LANES` lanes, (de)serialized by `minhash_serde` as a fixed-length tuple.
/// NOTE(review): the custom module is presumably needed because serde has no
/// built-in impls for arrays longer than 32 elements — confirm.
#[serde(with = "minhash_serde")]
pub minhash: [u32; MINHASH_LANES],
/// Marker bits; see `ShapeFlags`.
pub flags: ShapeFlags,
}
impl Default for ShapeDescriptor {
fn default() -> Self {
Self {
cf_histogram: [0; CF_BUCKET_COUNT],
signature_shape: SignatureShape::default(),
callee_shape: CalleeShape::default(),
shape_hash: ShapeHash128::default(),
minhash: [0; MINHASH_LANES],
flags: ShapeFlags::empty(),
}
}
}
impl ShapeDescriptor {
#[must_use]
pub fn unhashable(signature_shape: SignatureShape) -> Self {
let mut flags = ShapeFlags::empty();
flags.set_unhashable();
Self {
signature_shape,
flags,
..Self::default()
}
}
#[must_use]
pub const fn is_unhashable(&self) -> bool {
self.flags.is_unhashable()
}
#[must_use]
pub const fn is_truncated(&self) -> bool {
self.flags.is_truncated()
}
}
/// Serde adapter for the fixed-length `MinHash` lane array: the lanes are
/// written and read as a tuple of exactly `MINHASH_LANES` `u32` elements.
mod minhash_serde {
    use super::MINHASH_LANES;
    use serde::de::{self, Deserializer, SeqAccess, Visitor};
    use serde::ser::{SerializeTuple, Serializer};
    use std::fmt;

    /// Writes every lane, in order, as one element of a `MINHASH_LANES`-tuple.
    pub fn serialize<S>(value: &[u32; MINHASH_LANES], serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut tuple = serializer.serialize_tuple(MINHASH_LANES)?;
        // Stops at, and returns, the first element error.
        value
            .iter()
            .try_for_each(|lane| tuple.serialize_element(lane))?;
        tuple.end()
    }

    /// Reads a `MINHASH_LANES`-tuple of `u32` back into an array.
    ///
    /// # Errors
    ///
    /// Returns an `invalid_length` error when the sequence ends before all
    /// lanes have been read, and forwards any element decoding error.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<[u32; MINHASH_LANES], D::Error>
    where
        D: Deserializer<'de>,
    {
        /// Visitor that fills the lane array from a sequence, front to back.
        struct MinHashVisitor;

        impl<'de> Visitor<'de> for MinHashVisitor {
            type Value = [u32; MINHASH_LANES];

            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(formatter, "an array of {MINHASH_LANES} u32 `MinHash` lanes")
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: SeqAccess<'de>,
            {
                let mut lanes = [0u32; MINHASH_LANES];
                for (index, lane) in lanes.iter_mut().enumerate() {
                    // `index` is the number of lanes read so far, which is
                    // what `invalid_length` reports.
                    let Some(decoded) = seq.next_element::<u32>()? else {
                        return Err(de::Error::invalid_length(index, &self));
                    };
                    *lane = decoded;
                }
                Ok(lanes)
            }
        }

        deserializer.deserialize_tuple(MINHASH_LANES, MinHashVisitor)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Packing into a `u128` and unpacking again gives back the same halves.
    #[test]
    fn shape_hash_u128_roundtrip() {
        let original = ShapeHash128 {
            high: 0x1234_5678_9abc_def0,
            low: 0x0fed_cba9_8765_4321,
        };
        let packed = original.as_u128();
        assert_eq!(ShapeHash128::from_u128(packed), original);
    }

    /// `Display` prints `high` then `low` as 32 lowercase hex digits.
    #[test]
    fn shape_hash_display_is_32_hex() {
        let rendered = ShapeHash128 {
            high: 0x1234_5678_90ab_cdef,
            low: 0xfedc_ba09_8765_4321,
        }
        .to_string();
        assert_eq!(rendered, "1234567890abcdeffedcba0987654321");
        assert_eq!(rendered.len(), 32);
    }

    /// Only the all-zero hash counts as zero.
    #[test]
    fn shape_hash_zero_sentinel() {
        let zero = ShapeHash128::default();
        let nonzero = ShapeHash128 { high: 0, low: 1 };
        assert!(zero.is_zero());
        assert!(!nonzero.is_zero());
    }

    /// The default callee shape is the unresolved variant.
    #[test]
    fn callee_shape_defaults_unresolved() {
        assert!(matches!(CalleeShape::default(), CalleeShape::Unresolved));
    }

    /// Each setter raises exactly its own bit.
    #[test]
    fn flags_set_and_query() {
        let mut flags = ShapeFlags::empty();
        assert_eq!((flags.is_unhashable(), flags.is_truncated()), (false, false));
        flags.set_unhashable();
        assert_eq!((flags.is_unhashable(), flags.is_truncated()), (true, false));
        flags.set_truncated();
        assert_eq!((flags.is_unhashable(), flags.is_truncated()), (true, true));
    }

    /// The default descriptor is fully zeroed, unresolved and unflagged.
    #[test]
    fn default_descriptor_is_empty_but_present() {
        let descriptor = ShapeDescriptor::default();
        assert_eq!(descriptor.cf_histogram, [0; CF_BUCKET_COUNT]);
        assert_eq!(descriptor.minhash, [0; MINHASH_LANES]);
        assert_eq!(descriptor.callee_shape, CalleeShape::Unresolved);
        assert!(descriptor.shape_hash.is_zero());
        assert!(!descriptor.is_unhashable());
        assert!(!descriptor.is_truncated());
    }

    /// `unhashable` keeps the signature, sets the marker and zeroes the rest.
    #[test]
    fn unhashable_constructor_sets_marker_and_keeps_signature() {
        let signature = SignatureShape {
            arity_positional: 2,
            has_return_annotation: true,
            ..SignatureShape::default()
        };
        let descriptor = ShapeDescriptor::unhashable(signature);
        assert!(descriptor.is_unhashable());
        assert_eq!(descriptor.signature_shape, signature);
        assert!(descriptor.shape_hash.is_zero());
        assert_eq!(descriptor.cf_histogram, [0; CF_BUCKET_COUNT]);
    }

    /// A descriptor with every field populated survives a postcard round trip.
    #[test]
    fn descriptor_postcard_roundtrip_with_full_minhash() {
        // Give every lane a distinct value so a misplaced lane is detected.
        let mut minhash = [0u32; MINHASH_LANES];
        for (index, lane) in (0u32..).zip(minhash.iter_mut()) {
            *lane = index.wrapping_mul(2_654_435_761);
        }

        let mut cf_histogram = [0u16; CF_BUCKET_COUNT];
        cf_histogram[0] = 7;
        cf_histogram[CF_BUCKET_COUNT - 1] = 3;

        let mut flags = ShapeFlags::empty();
        flags.set_truncated();

        let original = ShapeDescriptor {
            cf_histogram,
            signature_shape: SignatureShape {
                arity_positional: 4,
                has_kwargs: true,
                ..SignatureShape::default()
            },
            callee_shape: CalleeShape::Resolved {
                count: 9,
                degree_buckets: [1, 2, 3, 4],
            },
            shape_hash: ShapeHash128 { high: 11, low: 22 },
            minhash,
            flags,
        };

        let encoded = postcard::to_allocvec(&original).expect("serialize");
        let decoded: ShapeDescriptor = postcard::from_bytes(&encoded).expect("deserialize");
        assert_eq!(original, decoded);
    }

    /// The default descriptor survives a JSON round trip.
    #[test]
    fn descriptor_json_roundtrip() {
        let original = ShapeDescriptor::default();
        let text = serde_json::to_string(&original).expect("to json");
        let decoded: ShapeDescriptor = serde_json::from_str(&text).expect("from json");
        assert_eq!(original, decoded);
    }

    /// Pins the public constants so any change to them fails this test.
    #[test]
    fn constants_are_frozen() {
        assert_eq!(SHAPE_SCHEMA_VERSION, 1);
        assert_eq!(CF_BUCKET_COUNT, 15);
        assert_eq!(MINHASH_LANES, 64);
        assert_eq!(MIN_HASHABLE_TOKENS, 4);
    }
}