use crate::graph::body_hash::{SHAPE_SEED_0, SHAPE_SEED_1};
use crate::graph::unified::storage::shape::{
CF_BUCKET_COUNT, CalleeShape, MIN_HASHABLE_TOKENS, MINHASH_LANES, ShapeDescriptor, ShapeFlags,
ShapeHash128, SignatureShape,
};
use xxhash_rust::xxh64::xxh64;
/// Token substituted for every named node that has no children (see
/// `canonical_token`). Other tokens are kind ids widened from `u16`, so
/// `u32::MAX` can never collide with one of them.
const LEAF_TOKEN: u32 = u32::MAX;
/// Visited-node cap that `ShapeBudget::default()` uses.
pub const DEFAULT_SHAPE_NODE_BUDGET: u32 = 4096;
/// Multiplier `lane_seed` applies to the lane index before XOR-ing it into
/// `SHAPE_SEED_0`, giving each MinHash lane its own xxh64 seed. The value is
/// the 64-bit golden-ratio constant (2^64 / phi), which is odd.
const MINHASH_LANE_MIX: u64 = 0x9E37_79B9_7F4A_7C15;
/// Control-flow category that a syntax node can be counted under.
///
/// The explicit discriminants double as slot indices into the
/// `[u16; CF_BUCKET_COUNT]` histogram (see `CfBucket::index`), so they must
/// stay dense, start at zero and keep this order; the
/// `cf_bucket_discriminants_are_frozen` test pins that. Which grammar node
/// kinds land in which bucket is decided by each `ShapeMapping`
/// implementation, not by this enum.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CfBucket {
Branch = 0,
Loop = 1,
Match = 2,
Try = 3,
Catch = 4,
Throw = 5,
Resource = 6,
Return = 7,
Yield = 8,
Await = 9,
BreakContinue = 10,
Call = 11,
Assign = 12,
Closure = 13,
Comprehension = 14,
}
impl CfBucket {
    /// Number of buckets; re-exports the storage-layer `CF_BUCKET_COUNT`.
    pub const COUNT: usize = CF_BUCKET_COUNT;

    /// Every bucket, listed in ascending discriminant order, so that
    /// `ALL[i].index() == i` for each position `i`.
    pub const ALL: [Self; CF_BUCKET_COUNT] = [
        Self::Branch,
        Self::Loop,
        Self::Match,
        Self::Try,
        Self::Catch,
        Self::Throw,
        Self::Resource,
        Self::Return,
        Self::Yield,
        Self::Await,
        Self::BreakContinue,
        Self::Call,
        Self::Assign,
        Self::Closure,
        Self::Comprehension,
    ];

    /// Histogram slot for this bucket, i.e. its `repr(u8)` discriminant
    /// widened to `usize`.
    #[must_use]
    pub const fn index(self) -> usize {
        self as usize
    }
}
/// Language-specific hooks consumed by the shape computation in this module.
///
/// Implementors must be `Send + Sync`.
pub trait ShapeMapping: Send + Sync {
/// Maps a tree-sitter node kind id to the control-flow bucket it should be
/// counted under, or `None` when the kind is not counted. `walk_body` calls
/// this once per visited node and bumps the matching histogram slot.
fn cf_bucket(&self, ts_node_kind_id: u16) -> Option<CfBucket>;
/// Builds the signature summary for `fn_node`. `src` is passed through
/// unchanged from `compute_shape_descriptor`.
// NOTE(review): `src` is presumably the source bytes the tree was parsed from — confirm with callers.
fn signature_shape(&self, fn_node: tree_sitter::Node, src: &[u8]) -> SignatureShape;
}
/// Work limit applied while walking one function body.
#[derive(Debug, Clone, Copy)]
pub struct ShapeBudget {
/// Maximum number of non-extra nodes `walk_body` visits before it stops and
/// marks the result as truncated.
pub max_visited_nodes: u32,
}
impl Default for ShapeBudget {
    /// Budget whose visit cap is `DEFAULT_SHAPE_NODE_BUDGET`.
    fn default() -> Self {
        let max_visited_nodes = DEFAULT_SHAPE_NODE_BUDGET;
        Self { max_visited_nodes }
    }
}
/// Returns the token that stands for `node` in shingles and labels.
///
/// A named node without children collapses to `LEAF_TOKEN`; every other node
/// keeps its grammar kind id, widened to `u32`.
fn canonical_token(node: tree_sitter::Node) -> u32 {
    // Anonymous nodes and nodes with children keep their own kind.
    if !node.is_named() || node.child_count() != 0 {
        return u32::from(node.kind_id());
    }
    LEAF_TOKEN
}
/// Derives the xxh64 seed for MinHash lane `i`: `SHAPE_SEED_0` XOR-ed with the
/// lane index times `MINHASH_LANE_MIX` (wrapping). Lane 0 therefore uses
/// `SHAPE_SEED_0` unchanged.
fn lane_seed(i: usize) -> u64 {
    let lane_offset = MINHASH_LANE_MIX.wrapping_mul(i as u64);
    SHAPE_SEED_0 ^ lane_offset
}
/// Hashes a node's own token together with the sorted tokens of its children.
///
/// `child_tokens` is sorted in place, so the label does not depend on child
/// order. The hashed payload is `own` followed by each sorted child token,
/// all as little-endian `u32`, fed to xxh64 seeded with `SHAPE_SEED_0`.
fn wl_label(own: u32, child_tokens: &mut [u32]) -> u64 {
    child_tokens.sort_unstable();

    let token_width = std::mem::size_of::<u32>();
    let mut payload = Vec::with_capacity(token_width * (1 + child_tokens.len()));
    payload.extend_from_slice(&own.to_le_bytes());
    payload.extend(child_tokens.iter().flat_map(|token| token.to_le_bytes()));

    xxh64(&payload, SHAPE_SEED_0)
}
/// Picks the subtree to walk: the node in the `body` field of `fn_node` when
/// one exists, otherwise `fn_node` itself.
fn body_subtree(fn_node: tree_sitter::Node) -> tree_sitter::Node {
    let Some(body) = fn_node.child_by_field_name("body") else {
        return fn_node;
    };
    body
}
/// Everything `walk_body` collects during one traversal.
struct WalkState {
/// Number of visited nodes that have no children (saturating).
token_count: u32,
/// Per-bucket visit counts, indexed by `CfBucket::index` (saturating).
histogram: [u16; CF_BUCKET_COUNT],
/// One `(parent token, child token)` pair per non-extra child of each visited node.
shingles: Vec<(u32, u32)>,
/// One `wl_label` hash per visited node.
wl_labels: Vec<u64>,
/// Set when the walk stopped because the node budget was exceeded.
truncated: bool,
}
/// Walks `body` depth-first with an explicit stack and collects the raw
/// material for a shape descriptor.
///
/// For every visited (non-extra) node the walk:
/// * bumps the histogram slot returned by `mapping.cf_bucket`, if any;
/// * counts the node as a token when it has no children;
/// * records one `(own token, child token)` shingle per non-extra child;
/// * records one `wl_label` hash of its own token plus its child tokens.
///
/// At most `budget.max_visited_nodes` nodes are visited. When another node is
/// still pending after that, the walk stops and `truncated` is set; the data
/// gathered so far is returned as is.
///
/// Nodes are popped in reverse child order. The consumers in this file sort
/// the shingles and take minima over the labels, so that order only decides
/// which nodes are reached before a truncation.
fn walk_body(
    body: tree_sitter::Node,
    mapping: &dyn ShapeMapping,
    budget: &ShapeBudget,
) -> WalkState {
    let mut state = WalkState {
        token_count: 0,
        histogram: [0; CF_BUCKET_COUNT],
        shingles: Vec::new(),
        wl_labels: Vec::new(),
        truncated: false,
    };
    let mut stack = vec![body];
    let mut visited: u32 = 0;

    // One cursor and one scratch buffer are reused for every node instead of
    // looking children up by index and allocating a fresh Vec per node.
    let mut cursor = body.walk();
    let mut child_tokens: Vec<u32> = Vec::new();

    while let Some(node) = stack.pop() {
        if node.is_extra() {
            continue;
        }

        // Checking before incrementing keeps `visited <= max_visited_nodes`,
        // so the counter cannot overflow even for a budget of `u32::MAX`.
        if visited >= budget.max_visited_nodes {
            state.truncated = true;
            break;
        }
        visited += 1;

        if let Some(bucket) = mapping.cf_bucket(node.kind_id()) {
            let slot = &mut state.histogram[bucket.index()];
            *slot = slot.saturating_add(1);
        }

        let own = canonical_token(node);
        child_tokens.clear();

        if node.child_count() == 0 {
            state.token_count = state.token_count.saturating_add(1);
        } else {
            // `children` steps a cursor from sibling to sibling, which avoids
            // the repeated per-index lookup of `node.child(i)`.
            for child in node.children(&mut cursor) {
                if child.is_extra() {
                    continue;
                }
                let child_token = canonical_token(child);
                child_tokens.push(child_token);
                state.shingles.push((own, child_token));
                stack.push(child);
            }
        }

        // Childless nodes still get a label, hashed from their own token alone.
        state.wl_labels.push(wl_label(own, &mut child_tokens));
    }
    state
}
/// Computes the 128-bit shape hash over the multiset of shingles.
///
/// The slice is sorted in place so the hash does not depend on traversal
/// order. Each pair is then serialised as two little-endian `u32` values and
/// the resulting bytes are hashed twice with xxh64: `SHAPE_SEED_0` yields the
/// high half, `SHAPE_SEED_1` the low half.
fn shape_hash_of(shingles: &mut [(u32, u32)]) -> ShapeHash128 {
    shingles.sort_unstable();

    let bytes_per_pair = 2 * std::mem::size_of::<u32>();
    let mut encoded = Vec::with_capacity(shingles.len() * bytes_per_pair);
    encoded.extend(
        shingles
            .iter()
            .flat_map(|&(parent, child)| [parent, child])
            .flat_map(u32::to_le_bytes),
    );

    let high = xxh64(&encoded, SHAPE_SEED_0);
    let low = xxh64(&encoded, SHAPE_SEED_1);
    ShapeHash128 { high, low }
}
/// Builds the MinHash sketch of the label set.
///
/// Lane `i` holds the smallest value, over all labels, of the low 32 bits of
/// xxh64 applied to the label's little-endian bytes with seed `lane_seed(i)`.
/// With no labels every lane stays at `u32::MAX`.
fn minhash_of(wl_labels: &[u64]) -> [u32; MINHASH_LANES] {
    let mut sketch = [u32::MAX; MINHASH_LANES];
    for (lane, slot) in sketch.iter_mut().enumerate() {
        let seed = lane_seed(lane);
        for label in wl_labels {
            // Keeping only the low 32 bits of the 64-bit hash is intentional.
            #[allow(clippy::cast_possible_truncation)]
            let candidate = xxh64(&label.to_le_bytes(), seed) as u32;
            *slot = (*slot).min(candidate);
        }
    }
    sketch
}
#[must_use]
pub fn compute_shape_descriptor(
fn_node: tree_sitter::Node,
src: &[u8],
mapping: &dyn ShapeMapping,
budget: &ShapeBudget,
) -> ShapeDescriptor {
let signature_shape = mapping.signature_shape(fn_node, src);
let body = body_subtree(fn_node);
let mut state = walk_body(body, mapping, budget);
if !state.truncated && state.token_count < u32::from(MIN_HASHABLE_TOKENS) {
return ShapeDescriptor::unhashable(signature_shape);
}
let shape_hash = shape_hash_of(&mut state.shingles);
let minhash = minhash_of(&state.wl_labels);
let mut flags = ShapeFlags::empty();
if state.truncated {
flags.set_truncated();
}
ShapeDescriptor {
cf_histogram: state.histogram,
signature_shape,
callee_shape: CalleeShape::Unresolved,
shape_hash,
minhash,
flags,
}
}
// Unit tests that drive the shape pipeline end to end on small Rust snippets
// parsed with the tree-sitter Rust grammar.
#[cfg(test)]
mod tests {
use super::*;
/// Minimal `ShapeMapping` for the tree-sitter Rust grammar, used only by these tests.
struct TestRustMapping;
impl ShapeMapping for TestRustMapping {
/// Resolves the kind id to its grammar name and buckets the control-flow
/// kinds listed below; any other kind, or an id without a name, gives `None`.
fn cf_bucket(&self, ts_node_kind_id: u16) -> Option<CfBucket> {
// The language handle is rebuilt on every call.
let lang: tree_sitter::Language = tree_sitter_rust::LANGUAGE.into();
let name = lang.node_kind_for_id(ts_node_kind_id)?;
let bucket = match name {
"if_expression" => CfBucket::Branch,
"for_expression" | "while_expression" | "loop_expression" => CfBucket::Loop,
"match_expression" => CfBucket::Match,
"return_expression" => CfBucket::Return,
"await_expression" => CfBucket::Await,
"break_expression" | "continue_expression" => CfBucket::BreakContinue,
"call_expression" | "macro_invocation" => CfBucket::Call,
"let_declaration" | "assignment_expression" => CfBucket::Assign,
"closure_expression" => CfBucket::Closure,
_ => return None,
};
Some(bucket)
}
/// Counts `parameter` and `self_parameter` named children of the
/// `parameters` field as positional arity and records whether a
/// `return_type` field is present; other fields keep their defaults.
fn signature_shape(&self, fn_node: tree_sitter::Node, _src: &[u8]) -> SignatureShape {
let mut shape = SignatureShape::default();
if let Some(params) = fn_node.child_by_field_name("parameters") {
let mut cursor = params.walk();
for child in params.named_children(&mut cursor) {
if child.kind() == "parameter" || child.kind() == "self_parameter" {
shape.arity_positional = shape.arity_positional.saturating_add(1);
}
}
}
shape.has_return_annotation = fn_node.child_by_field_name("return_type").is_some();
shape
}
}
/// Parses `src` with the Rust grammar; panics if the grammar cannot be
/// loaded or the parser returns no tree.
fn parse(src: &str) -> tree_sitter::Tree {
let mut parser = tree_sitter::Parser::new();
let lang: tree_sitter::Language = tree_sitter_rust::LANGUAGE.into();
parser.set_language(&lang).expect("load rust grammar");
parser.parse(src, None).expect("parse")
}
/// Returns the first top-level `function_item` of `tree`; panics if there is none.
fn first_function<'t>(tree: &'t tree_sitter::Tree) -> tree_sitter::Node<'t> {
let root = tree.root_node();
let mut cursor = root.walk();
for child in root.named_children(&mut cursor) {
if child.kind() == "function_item" {
return child;
}
}
panic!("no function_item in source");
}
/// Parses `src` and computes the descriptor of its first function using
/// `TestRustMapping` and the default budget.
fn descriptor_of(src: &str) -> ShapeDescriptor {
let tree = parse(src);
let func = first_function(&tree);
compute_shape_descriptor(
func,
src.as_bytes(),
&TestRustMapping,
&ShapeBudget::default(),
)
}
/// `CfBucket::ALL` must list every bucket at the position equal to its
/// discriminant, starting at 0 and ending at `CF_BUCKET_COUNT - 1`.
#[test]
fn cf_bucket_discriminants_are_frozen() {
assert_eq!(CfBucket::COUNT, CF_BUCKET_COUNT);
assert_eq!(CfBucket::ALL.len(), CF_BUCKET_COUNT);
for (i, bucket) in CfBucket::ALL.iter().enumerate() {
assert_eq!(bucket.index(), i, "discriminant must equal histogram index");
}
assert_eq!(CfBucket::Branch.index(), 0);
assert_eq!(CfBucket::Comprehension.index(), CF_BUCKET_COUNT - 1);
}
/// Two functions that differ only in identifier names and literal values
/// must produce the same histogram, shape hash and MinHash sketch.
#[test]
fn ac2_rename_invariance_identifiers_and_literals() {
let a = r#"
fn original(input: i32) -> i32 {
let total = input + 42;
if total > 100 {
return total;
}
helper(total)
}
"#;
let b = r#"
fn renamed(arg: i32) -> i32 {
let sum = arg + 7;
if sum > 999 {
return sum;
}
other(sum)
}
"#;
let da = descriptor_of(a);
let db = descriptor_of(b);
assert_eq!(
da.cf_histogram, db.cf_histogram,
"histogram must be rename-invariant"
);
assert_eq!(
da.shape_hash, db.shape_hash,
"shape_hash must be rename-invariant"
);
assert_eq!(da.minhash, db.minhash, "minhash must be rename-invariant");
assert!(!da.is_unhashable());
}
/// The same function written on one line and written with line breaks and
/// comments must produce the same histogram, shape hash and MinHash sketch.
#[test]
fn ac3_whitespace_and_comment_invariance() {
let plain = "fn f(x: i32) -> i32 { let y = x + 1; if y > 0 { return y; } y }";
let formatted = r#"
fn f(x: i32) -> i32 {
// leading comment
let y = x + 1; // trailing comment
if y > 0 {
/* block comment */
return y;
}
y
}
"#;
let dp = descriptor_of(plain);
let df = descriptor_of(formatted);
assert_eq!(dp.cf_histogram, df.cf_histogram);
assert_eq!(
dp.shape_hash, df.shape_hash,
"comments/whitespace must not change shape_hash"
);
assert_eq!(dp.minhash, df.minhash);
}
/// A `for` loop body and an `if` body must differ in both shape hash and histogram.
#[test]
fn different_structure_changes_shape_hash() {
let loops = descriptor_of("fn a(n: i32) { for i in 0..n { sink(i); } }");
let branch = descriptor_of("fn a(n: i32) { if n > 0 { sink(n); } }");
assert_ne!(loops.shape_hash, branch.shape_hash);
assert_ne!(loops.cf_histogram, branch.cf_histogram);
}
/// One `if`, one `for` and one `return` must each be counted once, and the
/// call inside the loop must be counted at least once.
#[test]
fn histogram_counts_control_flow_kinds() {
let d = descriptor_of(
"fn a(n: i32) -> i32 { if n > 0 { return n; } for i in 0..n { f(i); } n }",
);
assert_eq!(d.cf_histogram[CfBucket::Branch.index()], 1);
assert_eq!(d.cf_histogram[CfBucket::Loop.index()], 1);
assert_eq!(d.cf_histogram[CfBucket::Return.index()], 1);
assert!(d.cf_histogram[CfBucket::Call.index()] >= 1);
}
/// The descriptor must carry the signature summary produced by the mapping:
/// two positional parameters and a return annotation.
#[test]
fn signature_shape_reads_parameters_and_return() {
let d = descriptor_of("fn a(x: i32, y: i32) -> i32 { x + y + 1 }");
assert_eq!(d.signature_shape.arity_positional, 2);
assert!(d.signature_shape.has_return_annotation);
}
/// An empty body has too few tokens to hash, so the descriptor must be the
/// unhashable marker with a zero shape hash and an all-zero histogram.
#[test]
fn sub_four_token_body_is_unhashable_not_none() {
let d = descriptor_of("fn tiny() {}");
assert!(
d.is_unhashable(),
"tiny body must carry the honest unhashable marker"
);
assert!(d.shape_hash.is_zero());
assert_eq!(d.cf_histogram, [0; CF_BUCKET_COUNT]);
}
/// Computing the descriptor twice for the same source must give equal results.
#[test]
fn determinism_two_computes_match() {
let src = "fn a(n: i32) -> i32 { let mut s = 0; for i in 0..n { s += g(i); } s }";
assert_eq!(descriptor_of(src), descriptor_of(src));
}
/// With a budget of only three visited nodes the walk cannot finish, so the
/// descriptor must be flagged as truncated.
#[test]
fn over_budget_body_is_truncated() {
let src = "fn a(n: i32) -> i32 { let y = n + 1; if y > 0 { return y; } y }";
let tree = parse(src);
let func = first_function(&tree);
let tight = ShapeBudget {
max_visited_nodes: 3,
};
let d = compute_shape_descriptor(func, src.as_bytes(), &TestRustMapping, &tight);
assert!(
d.is_truncated(),
"a body past the node budget must be marked truncated"
);
}
/// `compute_shape_descriptor` must leave `callee_shape` as `Unresolved`.
#[test]
fn callee_shape_is_unresolved_this_effort() {
let d = descriptor_of("fn a(n: i32) -> i32 { if n > 0 { return n; } n }");
assert_eq!(d.callee_shape, CalleeShape::Unresolved);
}
}