use tree_sitter::{Node, Tree};
pub use trusty_symgraph::contracts::EdgeKind;
pub use trusty_symgraph::{fact_hash_str, EntityType, RawEntity};
/// Re-exports every public item of `trusty_symgraph::contracts::tables`
/// so callers can reach them through this module's `tables` path.
pub mod tables {
pub use trusty_symgraph::contracts::tables::*;
}
/// Returns the source text covered by `node` as an owned `String`.
///
/// The text is the slice of `src` between the node's start and end byte
/// offsets. An empty string is returned when that range does not lie inside
/// `src` (for example when the tree was parsed from a different buffer) or
/// when the bytes are not valid UTF-8, so this helper never panics.
fn node_text(node: Node<'_>, src: &[u8]) -> String {
    // `get` instead of indexing: a range outside `src` degrades to "" just
    // like invalid UTF-8 does, rather than aborting the whole extraction.
    src.get(node.start_byte()..node.end_byte())
        .and_then(|bytes| std::str::from_utf8(bytes).ok())
        .unwrap_or("")
        .to_owned()
}
/// Unit struct that carries no state; it serves as the type on which the
/// associated `extract` function is defined.
pub struct EntityExtractor;
impl EntityExtractor {
/// Extracts entities from a parsed syntax tree.
///
/// Forwards all four arguments unchanged to [`extract_entities`] and
/// returns its result.
pub fn extract(tree: &Tree, src: &[u8], file: &str, lang: &str) -> Vec<RawEntity> {
extract_entities(tree, src, file, lang)
}
}
/// Dispatches entity extraction on the language tag `lang`.
///
/// * `"rust"` uses the Rust-specific extractor.
/// * `"python"`, `"javascript"`, `"typescript"`, `"go"`, `"java"`, `"c"` and
///   `"cpp"` log a debug message and use the generic extractor.
/// * Any other tag yields an empty vector.
///
/// `tree` is the parsed syntax tree, `src` the bytes it was parsed from and
/// `file` the file name recorded on every emitted entity.
pub fn extract_entities(tree: &Tree, src: &[u8], file: &str, lang: &str) -> Vec<RawEntity> {
    if lang == "rust" {
        return extract_rust(tree, src, file);
    }

    let handled_generically = matches!(
        lang,
        "python" | "javascript" | "typescript" | "go" | "java" | "c" | "cpp"
    );
    if !handled_generically {
        return Vec::new();
    }

    tracing::debug!("entity extraction not fully implemented for {lang}");
    extract_universal(tree, src, file)
}
/// Language-agnostic extraction driven purely by node kinds.
///
/// Every node of `tree` is visited using an explicit stack (children are
/// pushed in order, so the last-pushed child is visited first). Two kinds of
/// entities are produced:
///
/// * `NamedType` for `type_identifier` / `type` nodes whose text is non-empty;
/// * `ModulePath` for `scoped_identifier` / `qualified_identifier` nodes whose
///   text contains `::` or `.`.
///
/// Each entity records the node's byte span, `file`, and the node's start row
/// plus one as its line.
fn extract_universal(tree: &Tree, src: &[u8], file: &str) -> Vec<RawEntity> {
    let mut entities = Vec::new();
    let mut pending: Vec<Node> = vec![tree.root_node()];

    while let Some(current) = pending.pop() {
        // Decide whether this node yields an entity and, if so, which one.
        let classified = match current.kind() {
            "type_identifier" | "type" => {
                let text = node_text(current, src);
                (!text.is_empty()).then(|| (EntityType::NamedType, text))
            }
            "scoped_identifier" | "qualified_identifier" => {
                let text = node_text(current, src);
                let is_path = text.contains("::") || text.contains('.');
                is_path.then(|| (EntityType::ModulePath, text))
            }
            _ => None,
        };

        if let Some((entity_type, text)) = classified {
            entities.push(RawEntity::new(
                entity_type,
                text,
                (current.start_byte(), current.end_byte()),
                file,
                current.start_position().row + 1,
            ));
        }

        // Schedule all children, matched or not, for a later visit.
        let mut cursor = current.walk();
        pending.extend(current.children(&mut cursor));
    }

    entities
}
/// Rust-specific entity extraction.
///
/// Runs in two phases:
///
/// 1. Every *top-level* `use_declaration` yields a `ModulePath` entity whose
///    text is the imported path (without the `use` keyword, any visibility
///    modifier, or the trailing `;`). If the first path segment is not one of
///    `crate`, `super`, `self`, `std`, `core` or `alloc`, an `ExternalCrate`
///    entity with that segment is emitted first.
/// 2. The whole tree is then handed to `walk_rust` for the remaining
///    entity kinds.
///
/// The imported path is read from the declaration's `argument` field so that
/// `pub use foo::bar;` / `pub(crate) use foo::bar;` are handled correctly;
/// stripping the literal prefix `"use "` from the raw text would leave the
/// visibility modifier in place and report `pub` as an external crate.
///
/// Both entities of a declaration share the declaration's byte span and its
/// start row plus one as the line.
fn extract_rust(tree: &Tree, src: &[u8], file: &str) -> Vec<RawEntity> {
    let mut out = Vec::new();
    let root = tree.root_node();
    let mut top_cursor = root.walk();

    for decl in root.children(&mut top_cursor) {
        if decl.kind() != "use_declaration" {
            continue;
        }

        // Prefer the grammar's `argument` field (the use clause itself).
        // Fall back to textual stripping if the field is unavailable, e.g.
        // on a node produced by error recovery.
        let clause = match decl.child_by_field_name("argument") {
            Some(argument) => node_text(argument, src),
            None => node_text(decl, src)
                .trim_start_matches("use ")
                .trim_end_matches(';')
                .to_string(),
        };
        let path = clause.trim();

        let line = decl.start_position().row + 1;
        let span = (decl.start_byte(), decl.end_byte());

        // First non-empty token of the path: skips a leading `::` or `{`.
        let root_segment = path
            .split(|c: char| c == ':' || c.is_whitespace() || c == '{' || c == ',')
            .find(|segment| !segment.is_empty());

        if let Some(segment) = root_segment {
            let is_builtin_root =
                matches!(segment, "crate" | "super" | "self" | "std" | "core" | "alloc");
            if !is_builtin_root {
                out.push(RawEntity::new(
                    EntityType::ExternalCrate,
                    segment.to_string(),
                    span,
                    file,
                    line,
                ));
            }
        }

        out.push(RawEntity::new(
            EntityType::ModulePath,
            path.to_string(),
            span,
            file,
            line,
        ));
    }

    walk_rust(root, src, file, false, &mut out);
    out
}
fn walk_rust(node: Node<'_>, src: &[u8], file: &str, in_test_fn: bool, out: &mut Vec<RawEntity>) {
let kind = node.kind();
let line = node.start_position().row + 1;
let span = (node.start_byte(), node.end_byte());
match kind {
"type_identifier" => {
let t = node_text(node, src);
if !t.is_empty() {
out.push(RawEntity::new(EntityType::NamedType, t, span, file, line));
}
}
"trait_bounds" => {
let t = node_text(node, src);
out.push(RawEntity::new(EntityType::TraitBound, t, span, file, line));
}
"scoped_identifier" => {
let t = node_text(node, src);
if t.contains("::") {
out.push(RawEntity::new(EntityType::ModulePath, t, span, file, line));
}
}
"macro_invocation" => {
if let Some(name_node) = node.child_by_field_name("macro") {
let name = node_text(name_node, src);
let last = name.rsplit("::").next().unwrap_or(&name).trim();
if matches!(last, "bail" | "anyhow" | "panic" | "unwrap" | "expect") {
let t = node_text(node, src);
out.push(RawEntity::new(
EntityType::ErrorVariant,
t,
span,
file,
line,
));
}
}
}
"call_expression" => {
if let Some(func) = node.child_by_field_name("function") {
let txt = node_text(func, src);
let last = txt.rsplit('.').next().unwrap_or(&txt);
if matches!(last, "unwrap" | "expect") {
let t = node_text(node, src);
out.push(RawEntity::new(
EntityType::ErrorVariant,
t,
span,
file,
line,
));
}
}
}
"attribute_item" | "inner_attribute_item" => {
let t = node_text(node, src);
out.push(RawEntity::new(EntityType::Annotation, t, span, file, line));
}
"string_literal" => {
let t = node_text(node, src);
let inner = t.trim_matches('"');
if inner.len() > 10 {
out.push(RawEntity::new(
EntityType::LiteralString,
t,
span,
file,
line,
));
}
}
"type_item" => {
let t = node_text(node, src);
out.push(RawEntity::new(EntityType::TypeAlias, t, span, file, line));
}
"identifier" if in_test_fn => {
let t = node_text(node, src);
if !t.is_empty() {
out.push(RawEntity::new(
EntityType::TestRelation,
t,
span,
file,
line,
));
}
}
_ => {}
}
let entering_test_fn = kind == "function_item" && function_has_test_attr(node, src);
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
walk_rust(child, src, file, in_test_fn || entering_test_fn, out);
}
}
fn function_has_test_attr(node: Node<'_>, src: &[u8]) -> bool {
let mut prev = node.prev_sibling();
while let Some(p) = prev {
let k = p.kind();
if k == "attribute_item" || k == "inner_attribute_item" {
let t = node_text(p, src);
if t.contains("test") {
return true;
}
prev = p.prev_sibling();
} else if k == "line_comment" || k == "block_comment" {
prev = p.prev_sibling();
} else {
break;
}
}
false
}
#[cfg(test)]
mod tests {
    use super::*;
    use tree_sitter::Parser;

    /// Parses `src` with the tree-sitter Rust grammar, panicking if the
    /// grammar cannot be installed or parsing yields no tree.
    fn parse_rust(src: &str) -> tree_sitter::Tree {
        let grammar: tree_sitter::Language = tree_sitter_rust::LANGUAGE.into();
        let mut parser = Parser::new();
        parser.set_language(&grammar).expect("set rust language");
        parser.parse(src, None).expect("parse")
    }

    /// A small snippet with a `use`, a struct and a `#[test]` function must
    /// produce a `NamedType` for `MyType`, at least one `ModulePath` and at
    /// least one `TestRelation`.
    #[test]
    fn test_extractor_emits_named_type_modulepath_and_test_relation() {
        let src = "use std::sync::Arc;\n\
                   struct MyType { v: u32 }\n\
                   #[test]\n\
                   fn it_works() { let _ = Arc::new(MyType { v: 1 }); }\n";
        let tree = parse_rust(src);
        let ents = EntityExtractor::extract(&tree, src.as_bytes(), "x.rs", "rust");

        let has_my_type = ents
            .iter()
            .any(|e| matches!(e.entity_type, EntityType::NamedType) && e.text == "MyType");
        assert!(has_my_type, "expected NamedType=MyType in {ents:?}");

        let has_module_path = ents
            .iter()
            .any(|e| matches!(e.entity_type, EntityType::ModulePath));
        assert!(
            has_module_path,
            "expected at least one ModulePath in {ents:?}"
        );

        let has_test_relation = ents
            .iter()
            .any(|e| matches!(e.entity_type, EntityType::TestRelation));
        assert!(
            has_test_relation,
            "expected TestRelation identifiers from #[test] fn body in {ents:?}"
        );
    }
}