use std::sync::OnceLock;
use tree_sitter::{Language, Node, Query, QueryCursor, StreamingIterator, Tree};
/// Category of a relationship emitted by `extract_relationships`.
///
/// Each variant corresponds to one capture (or capture pair) in the
/// tree-sitter query sources defined in this file.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RelationshipKind {
/// Direct call whose callee is a plain identifier, e.g. `foo(...)` (`@callee_name`).
Calls,
/// Call through a member expression, e.g. `obj.method(...)`; only the
/// property name is captured (`@method_name`).
MethodCall,
/// `class Foo extends Bar` (`@class_name` + `@extends_name`).
Extends,
/// `class Foo implements IBar` (`@class_name` + `@implements_name`).
Implements,
/// `interface IFoo extends IBar` (`@iface_name` + `@parent_iface_name`).
InterfaceExtends,
/// A `type_identifier` captured inside a type annotation (`@type_ref`).
TypeReference,
}
/// One relationship found in a source file by `extract_relationships`.
#[derive(Debug, Clone)]
pub struct RelationshipInfo {
/// Name of the declaring class/interface for the inheritance kinds;
/// `None` for calls and type references, which are extracted without
/// their enclosing context.
pub from_name: Option<String>,
/// Source text of the captured target node (callee, method, parent type, ...).
pub to_name: String,
/// Which kind of relationship this entry describes.
pub kind: RelationshipKind,
/// Start row of the captured target node plus one.
pub line: usize,
}
// Tree-sitter query source for call sites. `calls_query` compiles this same
// text for the TypeScript, TSX and JavaScript language groups.
// Captures: `@callee_name` (direct call) and `@method_name` (member call).
const CALLS_QUERY: &str = r#"
; Direct call: foo(...)
(call_expression
function: (identifier) @callee_name
arguments: (arguments))
; Method call: obj.method(...)
(call_expression
function: (member_expression
property: (property_identifier) @method_name)
arguments: (arguments))
"#;
// Tree-sitter query source for class/interface inheritance. `inheritance_query`
// compiles it for the TypeScript and TSX groups only; the JavaScript group uses
// its own pattern defined inside `inheritance_query`.
// Captures: `@class_name` with `@extends_name` or `@implements_name`, and
// `@iface_name` with `@parent_iface_name`.
// NOTE(review): the extends pattern only accepts a plain `identifier` as the
// base, so a qualified base such as `ns.Base` presumably does not match — confirm.
const INHERITANCE_QUERY: &str = r#"
; class Foo extends Bar
(class_declaration
name: (type_identifier) @class_name
(class_heritage
(extends_clause
value: (identifier) @extends_name)))
; class Foo implements IBar
(class_declaration
name: (type_identifier) @class_name
(class_heritage
(implements_clause
(type_identifier) @implements_name)))
; interface IFoo extends IBar
(interface_declaration
name: (type_identifier) @iface_name
(extends_type_clause
(type_identifier) @parent_iface_name))
"#;
// Tree-sitter query source for type references. `type_ref_query` compiles it
// for the TypeScript and TSX groups; JavaScript gets no type-reference query.
// Captures: `@type_ref`, a `type_identifier` that is a direct child of a
// `type_annotation`.
// NOTE(review): identifiers nested deeper inside the annotation (generics,
// unions, arrays) presumably are not captured by this pattern — confirm.
const TYPE_REF_QUERY: &str = r#"
; Type annotation: const x: SomeType, param: SomeType
(type_annotation
(type_identifier) @type_ref)
"#;
// Lazily compiled queries, one cell per (language group, query kind).
// Each cell is filled on first use by the matching `*_query` accessor, which
// compiles the query against the `Language` passed to that first call; later
// calls reuse the stored `Query`.
// There is no JavaScript type-reference cell because `type_ref_query` returns
// `None` for the JavaScript group.
static TS_CALLS_QUERY: OnceLock<Query> = OnceLock::new();
static TS_INHERITANCE_QUERY: OnceLock<Query> = OnceLock::new();
static TS_TYPE_REF_QUERY: OnceLock<Query> = OnceLock::new();
static TSX_CALLS_QUERY: OnceLock<Query> = OnceLock::new();
static TSX_INHERITANCE_QUERY: OnceLock<Query> = OnceLock::new();
static TSX_TYPE_REF_QUERY: OnceLock<Query> = OnceLock::new();
static JS_CALLS_QUERY: OnceLock<Query> = OnceLock::new();
static JS_INHERITANCE_QUERY: OnceLock<Query> = OnceLock::new();
/// Grammar family used to pick which cached query cell (and which query
/// source) applies to a given `Language`.
enum LangGroup {
/// Selected for any non-JavaScript language when `is_tsx` is false.
TypeScript,
/// Selected for any non-JavaScript language when `is_tsx` is true.
Tsx,
/// Selected when the language reports the name `"javascript"`.
JavaScript,
}
/// Classifies `language` into the grammar family used for query selection.
///
/// A language that reports the name `"javascript"` is always
/// [`LangGroup::JavaScript`]. Everything else (including a language that
/// reports no name at all) is treated as TypeScript, with `is_tsx` choosing
/// between the TSX and plain TypeScript groups.
fn lang_group(language: &Language, is_tsx: bool) -> LangGroup {
    // A missing name falls through to the TypeScript/TSX branch below.
    if matches!(language.name(), Some("javascript")) {
        return LangGroup::JavaScript;
    }
    if is_tsx {
        LangGroup::Tsx
    } else {
        LangGroup::TypeScript
    }
}
/// Returns the cached call-site query for the group `language` belongs to.
///
/// The query is compiled from `CALLS_QUERY` against `language` the first time
/// a given group is requested and reused afterwards.
///
/// # Panics
///
/// Panics if `CALLS_QUERY` does not compile for `language`.
fn calls_query(language: &Language, is_tsx: bool) -> &'static Query {
    // Pick the per-group cache cell together with its panic message, then
    // run the (identical) compile step once.
    let (cell, panic_msg) = match lang_group(language, is_tsx) {
        LangGroup::TypeScript => (&TS_CALLS_QUERY, "invalid TS calls query"),
        LangGroup::Tsx => (&TSX_CALLS_QUERY, "invalid TSX calls query"),
        LangGroup::JavaScript => (&JS_CALLS_QUERY, "invalid JS calls query"),
    };
    cell.get_or_init(|| Query::new(language, CALLS_QUERY).expect(panic_msg))
}
/// Returns the cached inheritance query for the group `language` belongs to.
///
/// TypeScript and TSX share `INHERITANCE_QUERY`; JavaScript uses a dedicated
/// pattern because its `class_heritage` node holds the base identifier
/// directly. The result is currently always `Some`.
///
/// # Panics
///
/// Panics if the selected query source does not compile for `language`.
fn inheritance_query(language: &Language, is_tsx: bool) -> Option<&'static Query> {
    // JavaScript-only pattern; named distinctly from the static cache cell.
    const JS_CLASS_HERITAGE_SRC: &str = r#"
; class Foo extends Bar (JS class_heritage layout differs from TS)
(class_declaration
name: (identifier) @class_name
(class_heritage
(identifier) @extends_name))
"#;

    let (cell, pattern, panic_msg) = match lang_group(language, is_tsx) {
        LangGroup::TypeScript => (
            &TS_INHERITANCE_QUERY,
            INHERITANCE_QUERY,
            "invalid TS inheritance query",
        ),
        LangGroup::Tsx => (
            &TSX_INHERITANCE_QUERY,
            INHERITANCE_QUERY,
            "invalid TSX inheritance query",
        ),
        LangGroup::JavaScript => (
            &JS_INHERITANCE_QUERY,
            JS_CLASS_HERITAGE_SRC,
            "invalid JS inheritance query",
        ),
    };
    Some(cell.get_or_init(|| Query::new(language, pattern).expect(panic_msg)))
}
/// Returns the cached type-reference query for the group `language` belongs
/// to, or `None` for the JavaScript group, which has no type-reference query.
///
/// # Panics
///
/// Panics if `TYPE_REF_QUERY` does not compile for `language`.
fn type_ref_query(language: &Language, is_tsx: bool) -> Option<&'static Query> {
    let (cell, panic_msg) = match lang_group(language, is_tsx) {
        LangGroup::JavaScript => return None,
        LangGroup::TypeScript => (&TS_TYPE_REF_QUERY, "invalid TS type_ref query"),
        LangGroup::Tsx => (&TSX_TYPE_REF_QUERY, "invalid TSX type_ref query"),
    };
    Some(cell.get_or_init(|| Query::new(language, TYPE_REF_QUERY).expect(panic_msg)))
}
/// Returns the text `node` covers in `source`, or an empty string when that
/// text cannot be read as UTF-8.
fn node_text<'a>(node: Node<'a>, source: &'a [u8]) -> &'a str {
    match node.utf8_text(source) {
        Ok(text) => text,
        Err(_) => "",
    }
}
pub fn extract_relationships(
tree: &Tree,
source: &[u8],
language: &Language,
is_tsx: bool,
) -> Vec<RelationshipInfo> {
let mut results: Vec<RelationshipInfo> = Vec::new();
let mut seen: std::collections::HashSet<(String, usize, String)> =
std::collections::HashSet::new();
macro_rules! push_rel {
($info:expr) => {{
let info: RelationshipInfo = $info;
let key = (info.to_name.clone(), info.line, format!("{:?}", info.kind));
if seen.insert(key) {
results.push(info);
}
}};
}
{
let query = calls_query(language, is_tsx);
let callee_idx = query
.capture_index_for_name("callee_name")
.expect("calls query must have @callee_name");
let method_idx = query
.capture_index_for_name("method_name")
.expect("calls query must have @method_name");
let mut cursor = QueryCursor::new();
let mut matches = cursor.matches(query, tree.root_node(), source);
while let Some(m) = matches.next() {
for capture in m.captures {
let text = node_text(capture.node, source);
let line = capture.node.start_position().row + 1;
if capture.index == callee_idx {
push_rel!(RelationshipInfo {
from_name: None,
to_name: text.to_owned(),
kind: RelationshipKind::Calls,
line,
});
} else if capture.index == method_idx {
push_rel!(RelationshipInfo {
from_name: None,
to_name: text.to_owned(),
kind: RelationshipKind::MethodCall,
line,
});
}
}
}
}
if let Some(query) = inheritance_query(language, is_tsx) {
let class_name_idx = query.capture_index_for_name("class_name");
let extends_idx = query.capture_index_for_name("extends_name");
let implements_idx = query.capture_index_for_name("implements_name");
let iface_name_idx = query.capture_index_for_name("iface_name");
let parent_iface_idx = query.capture_index_for_name("parent_iface_name");
let mut cursor = QueryCursor::new();
let mut matches = cursor.matches(query, tree.root_node(), source);
while let Some(m) = matches.next() {
let mut class_name: Option<String> = None;
let mut extends_name: Option<(String, usize)> = None;
let mut implements_name: Option<(String, usize)> = None;
let mut iface_name: Option<String> = None;
let mut parent_iface: Option<(String, usize)> = None;
for capture in m.captures {
let text = node_text(capture.node, source).to_owned();
let line = capture.node.start_position().row + 1;
if class_name_idx == Some(capture.index) {
class_name = Some(text);
} else if extends_idx == Some(capture.index) {
extends_name = Some((text, line));
} else if implements_idx == Some(capture.index) {
implements_name = Some((text, line));
} else if iface_name_idx == Some(capture.index) {
iface_name = Some(text);
} else if parent_iface_idx == Some(capture.index) {
parent_iface = Some((text, line));
}
}
if let (Some(from), Some((to, line))) = (&class_name, &extends_name) {
push_rel!(RelationshipInfo {
from_name: Some(from.clone()),
to_name: to.clone(),
kind: RelationshipKind::Extends,
line: *line,
});
}
if let (Some(from), Some((to, line))) = (&class_name, &implements_name) {
push_rel!(RelationshipInfo {
from_name: Some(from.clone()),
to_name: to.clone(),
kind: RelationshipKind::Implements,
line: *line,
});
}
if let (Some(from), Some((to, line))) = (&iface_name, &parent_iface) {
push_rel!(RelationshipInfo {
from_name: Some(from.clone()),
to_name: to.clone(),
kind: RelationshipKind::InterfaceExtends,
line: *line,
});
}
}
}
if let Some(query) = type_ref_query(language, is_tsx) {
let type_ref_idx = query
.capture_index_for_name("type_ref")
.expect("type_ref query must have @type_ref");
let mut cursor = QueryCursor::new();
let mut matches = cursor.matches(query, tree.root_node(), source);
while let Some(m) = matches.next() {
for capture in m.captures {
if capture.index == type_ref_idx {
let text = node_text(capture.node, source);
let line = capture.node.start_position().row + 1;
push_rel!(RelationshipInfo {
from_name: None,
to_name: text.to_owned(),
kind: RelationshipKind::TypeReference,
line,
});
}
}
}
}
results
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::languages::language_for_extension;

    /// Parses `source` with `lang` and returns the tree together with the language.
    fn parse_with(lang: Language, source: &str) -> (tree_sitter::Tree, Language) {
        let mut parser = tree_sitter::Parser::new();
        parser.set_language(&lang).unwrap();
        let tree = parser.parse(source.as_bytes(), None).unwrap();
        (tree, lang)
    }

    /// Parses `source` with the grammar registered for `.ts` files.
    fn parse_ts(source: &str) -> (tree_sitter::Tree, Language) {
        parse_with(language_for_extension("ts").unwrap(), source)
    }

    /// Parses `source` with the grammar registered for `.js` files.
    fn parse_js(source: &str) -> (tree_sitter::Tree, Language) {
        parse_with(language_for_extension("js").unwrap(), source)
    }

    /// Parses `source` with the grammar registered for `.tsx` files.
    fn parse_tsx(source: &str) -> (tree_sitter::Tree, Language) {
        parse_with(language_for_extension("tsx").unwrap(), source)
    }

    /// Parses `src` as TypeScript and runs the extractor on it.
    fn extract_ts(src: &str) -> Vec<RelationshipInfo> {
        let (tree, lang) = parse_ts(src);
        extract_relationships(&tree, src.as_bytes(), &lang, false)
    }

    /// Returns the relationships in `rels` whose kind equals `kind`.
    fn of_kind(rels: &[RelationshipInfo], kind: RelationshipKind) -> Vec<&RelationshipInfo> {
        rels.iter().filter(|r| r.kind == kind).collect()
    }

    #[test]
    fn test_direct_call_extraction() {
        let rels = extract_ts("foo(); bar();");
        let calls = of_kind(&rels, RelationshipKind::Calls);
        assert_eq!(
            calls.len(),
            2,
            "expected 2 Calls relationships, got {}",
            calls.len()
        );
        let names: Vec<&str> = calls.iter().map(|r| r.to_name.as_str()).collect();
        assert!(names.contains(&"foo"), "missing 'foo' call");
        assert!(names.contains(&"bar"), "missing 'bar' call");
        assert!(
            calls.iter().all(|r| r.from_name.is_none()),
            "from_name should be None for context-free extraction"
        );
    }

    #[test]
    fn test_method_call_extraction() {
        let rels = extract_ts("obj.method(); this.render();");
        let method_calls = of_kind(&rels, RelationshipKind::MethodCall);
        assert_eq!(method_calls.len(), 2, "expected 2 MethodCall relationships");
        let names: Vec<&str> = method_calls.iter().map(|r| r.to_name.as_str()).collect();
        assert!(names.contains(&"method"), "missing 'method' call");
        assert!(names.contains(&"render"), "missing 'render' call");
    }

    #[test]
    fn test_class_extends_extraction() {
        let rels = extract_ts("class Dog extends Animal {}");
        let extends = of_kind(&rels, RelationshipKind::Extends);
        assert_eq!(extends.len(), 1, "expected 1 Extends relationship");
        let rel = extends[0];
        assert_eq!(
            rel.from_name.as_deref(),
            Some("Dog"),
            "from_name should be 'Dog'"
        );
        assert_eq!(rel.to_name, "Animal", "to_name should be 'Animal'");
    }

    #[test]
    fn test_class_implements_extraction() {
        let rels = extract_ts("class UserService implements IService {}");
        let impls = of_kind(&rels, RelationshipKind::Implements);
        assert_eq!(impls.len(), 1, "expected 1 Implements relationship");
        let rel = impls[0];
        assert_eq!(
            rel.from_name.as_deref(),
            Some("UserService"),
            "from_name should be 'UserService'"
        );
        assert_eq!(rel.to_name, "IService", "to_name should be 'IService'");
    }

    #[test]
    fn test_interface_extends_extraction() {
        let rels = extract_ts("interface Admin extends User {}");
        let iface_extends = of_kind(&rels, RelationshipKind::InterfaceExtends);
        assert_eq!(
            iface_extends.len(),
            1,
            "expected 1 InterfaceExtends relationship"
        );
        let rel = iface_extends[0];
        assert_eq!(
            rel.from_name.as_deref(),
            Some("Admin"),
            "from_name should be 'Admin'"
        );
        assert_eq!(rel.to_name, "User", "to_name should be 'User'");
    }

    #[test]
    fn test_type_reference_extraction() {
        let rels = extract_ts("const x: MyType = {};");
        let type_refs = of_kind(&rels, RelationshipKind::TypeReference);
        assert_eq!(type_refs.len(), 1, "expected 1 TypeReference relationship");
        assert_eq!(type_refs[0].to_name, "MyType", "to_name should be 'MyType'");
        assert!(type_refs[0].from_name.is_none(), "from_name should be None");
    }

    #[test]
    fn test_combined_relationship_extraction() {
        let src = r#"
class Dog extends Animal implements IPet {
bark() {
console.log("Woof");
this.move();
}
}
interface IPet extends IAnimal {}
const owner: Person = {};
"#;
        let rels = extract_ts(src);
        let calls = of_kind(&rels, RelationshipKind::Calls);
        let method_calls = of_kind(&rels, RelationshipKind::MethodCall);
        let extends = of_kind(&rels, RelationshipKind::Extends);
        let impls = of_kind(&rels, RelationshipKind::Implements);
        let iface_extends = of_kind(&rels, RelationshipKind::InterfaceExtends);
        let type_refs = of_kind(&rels, RelationshipKind::TypeReference);
        assert!(
            !calls.is_empty() || !method_calls.is_empty(),
            "should find some calls"
        );
        assert_eq!(extends.len(), 1, "should find class extends Animal");
        assert_eq!(impls.len(), 1, "should find class implements IPet");
        assert_eq!(
            iface_extends.len(),
            1,
            "should find interface extends IAnimal"
        );
        assert!(
            !type_refs.is_empty(),
            "should find type reference to Person"
        );
        let extends_rel = extends[0];
        assert_eq!(extends_rel.from_name.as_deref(), Some("Dog"));
        assert_eq!(extends_rel.to_name, "Animal");
        let impl_rel = impls[0];
        assert_eq!(impl_rel.from_name.as_deref(), Some("Dog"));
        assert_eq!(impl_rel.to_name, "IPet");
        let iface_rel = iface_extends[0];
        assert_eq!(iface_rel.from_name.as_deref(), Some("IPet"));
        assert_eq!(iface_rel.to_name, "IAnimal");
    }

    #[test]
    fn test_empty_file_no_relationships() {
        let rels = extract_ts("");
        assert!(
            rels.is_empty(),
            "empty file should produce no relationships"
        );
    }

    #[test]
    fn test_no_relationships_in_plain_file() {
        let rels = extract_ts("const x = 42;\nconst y = 'hello';");
        let has_inheritance = rels.iter().any(|r| {
            matches!(
                r.kind,
                RelationshipKind::Extends
                    | RelationshipKind::Implements
                    | RelationshipKind::InterfaceExtends
            )
        });
        assert!(
            !has_inheritance,
            "plain variable declarations should not produce inheritance relationships"
        );
    }

    #[test]
    fn test_deduplication() {
        // Counts the `Calls` entries that target `foo`.
        let foo_calls = |rels: &[RelationshipInfo]| {
            rels.iter()
                .filter(|r| r.to_name == "foo" && r.kind == RelationshipKind::Calls)
                .count()
        };

        let single = extract_ts("foo();");
        assert_eq!(
            foo_calls(&single),
            1,
            "foo() on one line should produce exactly 1 Calls entry"
        );

        // The same callee twice on one line is what deduplication collapses.
        let repeated = extract_ts("foo(); foo();");
        assert_eq!(
            foo_calls(&repeated),
            1,
            "repeated foo() on the same line should be deduplicated"
        );

        // The line number is part of the key, so separate lines stay separate.
        let two_lines = extract_ts("foo();\nfoo();");
        assert_eq!(
            foo_calls(&two_lines),
            2,
            "foo() on two different lines should produce 2 Calls entries"
        );
    }

    #[test]
    fn test_tsx_relationships() {
        let src = "class Button extends Component { render() { this.setState(); } }";
        let (tree, lang) = parse_tsx(src);
        let rels = extract_relationships(&tree, src.as_bytes(), &lang, true);
        let extends = of_kind(&rels, RelationshipKind::Extends);
        assert!(
            !extends.is_empty(),
            "TSX should find class extends relationship"
        );
        assert_eq!(extends[0].to_name, "Component");
    }

    #[test]
    fn test_js_class_extends() {
        let src = "class Foo extends Bar {}";
        let (tree, lang) = parse_js(src);
        let rels = extract_relationships(&tree, src.as_bytes(), &lang, false);
        let extends = of_kind(&rels, RelationshipKind::Extends);
        assert_eq!(extends.len(), 1, "JS class extends should be extracted");
        assert_eq!(extends[0].from_name.as_deref(), Some("Foo"));
        assert_eq!(extends[0].to_name, "Bar");
    }
}