use tree_sitter::{Node, Query, QueryCursor, StreamingIterator};
use crate::lang::CustomEdgeDef;
use crate::model::{Relation, RelationKind, Span};
pub fn extract_relations(file: &str, source: &[u8], root: Node, query: &Query) -> Vec<Relation> {
extract_relations_with_custom_edges(file, source, root, query, &[])
}
pub fn extract_relations_with_custom_edges(
file: &str,
source: &[u8],
root: Node,
query: &Query,
custom_edges: &[CustomEdgeDef],
) -> Vec<Relation> {
let mut cursor = QueryCursor::new();
let mut matches = cursor.matches(query, root, source);
let capture_names = query.capture_names();
let mut relations = Vec::new();
while let Some(m) = matches.next() {
let mut rel_kind = None;
let mut source_name = None;
let mut target_name = None;
let mut site_node = None;
let mut receiver_text = None;
let mut custom_source: Option<(String, String)> = None; let mut custom_target: Option<(String, String)> = None; let mut custom_site_node: Option<Node> = None;
let mut custom_edge_name: Option<String> = None;
for capture in m.captures {
let idx = capture.index as usize;
let cap_name = capture_names[idx];
let node = capture.node;
let text = node_text(node, source);
match cap_name {
"call.func" => {
target_name = Some(text);
rel_kind = Some(RelationKind::Calls);
}
"call.site" => {
site_node = Some(node);
}
"call.caller" => {
source_name = Some(text);
}
"call.receiver" => {
receiver_text = Some(text);
}
"import.module" => {
target_name = Some(text);
rel_kind = Some(RelationKind::Imports);
source_name = Some(file.to_string());
}
"import.name" => {
target_name = Some(text);
rel_kind = Some(RelationKind::Imports);
source_name = Some(file.to_string());
}
"inherit.child" => {
source_name = Some(text);
if rel_kind.is_none() {
rel_kind = Some(RelationKind::Inherits);
}
}
"inherit.parent" => {
target_name = Some(text);
rel_kind = Some(RelationKind::Inherits);
}
other => {
if let Some((prefix, suffix)) = other.split_once('.') {
if let Some(edge_def) = custom_edges.iter().find(|e| e.capture == prefix) {
custom_edge_name = Some(edge_def.name.clone());
match suffix {
"source" => {
custom_source = Some((edge_def.name.clone(), text));
custom_site_node = Some(node);
}
"target" => {
custom_target = Some((edge_def.name.clone(), text));
}
"site" => {
custom_site_node = Some(node);
}
_ => {}
}
}
}
}
}
}
if let Some((_, tgt_text)) = custom_target {
let edge_name = if let Some((name, _)) = &custom_source {
name.clone()
} else {
custom_edge_name.unwrap_or_default()
};
if edge_name.is_empty() {
} else {
let src_text = if let Some((_, src)) = custom_source {
src
} else if let Some(site) = custom_site_node {
find_enclosing_function(site, source).unwrap_or_else(|| file.to_string())
} else {
file.to_string()
};
let span = custom_site_node.map(|n| Span {
file: file.to_string(),
start_line: n.start_position().row as u32 + 1,
start_col: n.start_position().column as u32,
end_line: n.end_position().row as u32 + 1,
end_col: n.end_position().column as u32,
});
let source_id = format!("{}::{}", file, src_text);
let target_id = format!("{}::{}", file, tgt_text);
relations.push(Relation {
source_id,
target_id,
kind: RelationKind::Custom(edge_name),
span,
receiver: None,
});
continue;
}
}
if rel_kind == Some(RelationKind::Calls) && source_name.is_none() {
if let Some(site) = site_node {
source_name =
find_enclosing_function(site, source).or_else(|| Some(file.to_string()));
}
}
if rel_kind == Some(RelationKind::Calls) {
if let Some(ref recv) = receiver_text {
if recv == "self" || recv == "this" || recv == "@" {
if let Some(site) = site_node {
if let Some(cls) = find_enclosing_class(site, source) {
receiver_text = Some(cls);
}
}
}
}
}
if let (Some(kind), Some(src), Some(tgt)) = (rel_kind, source_name, target_name) {
let span = site_node.map(|n| Span {
file: file.to_string(),
start_line: n.start_position().row as u32 + 1,
start_col: n.start_position().column as u32,
end_line: n.end_position().row as u32 + 1,
end_col: n.end_position().column as u32,
});
let source_id = if kind == RelationKind::Imports {
src
} else {
format!("{}::{}", file, src)
};
let target_id = format!("{}::{}", file, tgt);
relations.push(Relation {
source_id,
target_id,
kind,
span,
receiver: receiver_text.clone(),
});
}
}
relations
}
fn find_enclosing_function(node: Node, source: &[u8]) -> Option<String> {
let func_kinds = [
"function_definition", "function_item", "function_declaration", "method_declaration", "method_definition", "func_literal", "sub_definition", "property_definition", ];
let sql_container_kinds = [
"create_table", "insert", ];
let mut current = node.parent();
while let Some(n) = current {
if func_kinds.contains(&n.kind()) {
if let Some(name_node) = n.child_by_field_name("name") {
return Some(node_text(name_node, source));
}
}
if n.kind() == "defProc" {
if let Some(header) = n.child_by_field_name("header") {
if let Some(name_node) = header.child_by_field_name("name") {
if name_node.kind() == "genericDot" {
if let Some(rhs) = name_node.child_by_field_name("rhs") {
return Some(node_text(rhs, source));
}
}
return Some(node_text(name_node, source));
}
}
}
if sql_container_kinds.contains(&n.kind()) {
if let Some(obj_ref) = n.child_by_field_name("name") {
return Some(node_text(obj_ref, source));
}
let mut i = 0;
while let Some(child) = n.child(i) {
if child.kind() == "object_reference" {
if let Some(id) = child.child_by_field_name("name") {
return Some(node_text(id, source));
}
}
i += 1;
}
}
if n.kind() == "cte" {
if let Some(id) = n.child(0) {
if id.kind() == "identifier" {
return Some(node_text(id, source));
}
}
}
current = n.parent();
}
None
}
/// Walks the ancestors of `node` (starting at its parent) and returns the name of the
/// nearest class-like container. It is used to resolve `self`/`this`/`@` receivers.
///
/// At each ancestor:
/// - kinds in `class_kinds`: their `name` field.
/// - Rust `impl_item`: this node has no `name` field, so the implementing type in its
///   `type` field is used. Generic arguments are stripped, so `impl<T> Foo<T>` gives
///   `Foo`, and `impl Trait for Foo` gives `Foo`.
/// - `declClass` / `declIntf`: the `name` field of a parent `declType` node.
///
/// Returns `None` when no ancestor yields a name.
fn find_enclosing_class(node: Node, source: &[u8]) -> Option<String> {
    // Container kinds whose own `name` field names the class.
    let class_kinds = [
        "class_definition",
        "class_declaration",
        "class",
        "class_specifier",
        "struct_item",
        "defmodule",
    ];
    let mut current = node.parent();
    while let Some(n) = current {
        let kind = n.kind();
        if class_kinds.contains(&kind) {
            if let Some(name_node) = n.child_by_field_name("name") {
                return Some(node_text(name_node, source));
            }
        }
        // Previously `impl_item` was looked up via a `name` field it does not have, so
        // `self` inside Rust impl blocks was never resolved.
        if kind == "impl_item" {
            if let Some(ty) = n.child_by_field_name("type") {
                let base = if ty.kind() == "generic_type" {
                    ty.child_by_field_name("type").unwrap_or(ty)
                } else {
                    ty
                };
                return Some(node_text(base, source));
            }
        }
        if kind == "declClass" || kind == "declIntf" {
            if let Some(parent) = n.parent() {
                if parent.kind() == "declType" {
                    if let Some(name_node) = parent.child_by_field_name("name") {
                        return Some(node_text(name_node, source));
                    }
                }
            }
        }
        current = n.parent();
    }
    None
}
/// Returns the source text spanned by `node`.
///
/// Invalid UTF-8 sequences are replaced with U+FFFD instead of discarding the whole
/// text. Otherwise a symbol in a non-UTF-8 file would become `""`, and every such symbol
/// would collapse onto the same `"{file}::"` id. A byte range lying outside `source`
/// yields an empty string instead of panicking.
fn node_text(node: Node, source: &[u8]) -> String {
    source
        .get(node.byte_range())
        .map(|bytes| String::from_utf8_lossy(bytes).into_owned())
        .unwrap_or_default()
}