use tree_sitter::{Node, Parser};
use crate::error::{CodegraphError, Result};
use crate::graph::types::{
Binding, BindingKind, ByteSpan, FileFacts, RefRole, Reference, Scope, ScopeId, ScopeKind,
Symbol, SymbolKind, Visibility,
};
use crate::lang::Language;
use crate::symbol::Descriptor;
use super::{
ExtractCtx, Extractor, MIN_REF_LEN, attach_reference_scopes, collect_call_references,
definition_bindings, field_text, import_bindings, innermost_scope, make_symbol, node_span,
node_text, one_line_signature, push_binding, push_import_ref, push_ref, push_scope,
};
/// Tree-sitter query handed to `collect_call_references` to find call sites.
///
/// It matches three `function_call` shapes and captures the called name as
/// `@callee`:
/// - a bare identifier callee;
/// - a `dot_index_expression` callee whose table is an identifier;
/// - a `method_index_expression` callee whose table is an identifier.
///
/// For the last two shapes the table identifier is captured as `@qualifier`.
const CALL_QUERY: &str = r#"
[
(function_call name: (identifier) @callee)
(function_call name: (dot_index_expression table: (identifier) @qualifier field: (identifier) @callee))
(function_call name: (method_index_expression table: (identifier) @qualifier method: (identifier) @callee))
]
"#;
/// Extractor for Lua sources; all work is delegated to [`extract_lua_family`].
pub struct LuaExtractor;

impl Extractor for LuaExtractor {
    /// The language this extractor reports: always [`Language::Lua`].
    fn lang(&self) -> Language {
        Language::Lua
    }

    /// Extracts the facts of `source` (located at `file`) with the Lua grammar.
    fn extract(&self, source: &str, file: &str) -> Result<FileFacts> {
        let grammar = crate::grammar::lua();
        extract_lua_family(source, file, Language::Lua, grammar)
    }
}
/// Parses `source` with `grammar` and assembles the [`FileFacts`] for `file`.
///
/// The pipeline runs in a fixed order: symbols (plus one module symbol for the
/// file), call references with `require` entries dropped, import references
/// for `require` calls, read and write references, scopes, and finally the
/// bindings (locals/params, definitions, imports).
///
/// # Errors
///
/// Returns `CodegraphError::Parse` when the grammar cannot be installed on the
/// parser or when parsing yields no tree, and propagates any error from
/// `collect_call_references`.
pub(crate) fn extract_lua_family(
    source: &str,
    file: &str,
    lang: Language,
    grammar: tree_sitter::Language,
) -> Result<FileFacts> {
    // Parser setup and parsing failures are reported with the same error value.
    let parse_error = || CodegraphError::Parse {
        path: file.to_owned(),
    };

    let mut parser = Parser::new();
    parser.set_language(&grammar).map_err(|_| parse_error())?;
    let tree = parser.parse(source, None).ok_or_else(parse_error)?;

    let root = tree.root_node();
    let bytes = source.as_bytes();
    let namespaces = lua_namespaces(file);
    let ctx = ExtractCtx { bytes, file, lang };

    // Definition bindings are derived before the module symbol is appended,
    // so the module symbol itself never contributes a binding.
    let mut symbols = collect_symbols(&root, &ctx, &namespaces);
    let def_bindings = definition_bindings(&symbols);
    let file_module = super::module_symbol(lang, &namespaces, file, source.len());
    let module_id = file_module.id.to_scip_string();
    symbols.push(file_module);

    // Every call reference named `require` is dropped here; `require` calls
    // are re-added below as import references instead.
    let mut references = collect_call_references(&root, &grammar, CALL_QUERY, lang, bytes, file)?;
    references.retain(|r| r.name != "require");
    collect_require_imports(&root, bytes, file, &mut references, &module_id);
    collect_read_references(&root, bytes, file, &mut references);
    collect_write_references(&root, bytes, file, &mut references);

    let scopes = collect_scopes(&root, source.len());
    attach_reference_scopes(&mut references, &scopes);

    let mut bindings = collect_bindings(&root, bytes, &scopes);
    bindings.extend(def_bindings);
    bindings.extend(import_bindings(&references, &scopes));

    Ok(FileFacts {
        file: file.to_owned(),
        lang: lang.as_str().to_owned(),
        symbols,
        references,
        scopes,
        bindings,
        ffi_exports: Vec::new(),
    })
}
/// Derives the namespace segments for `file` from its path.
///
/// One trailing `.luau` or `.lua` extension is removed, then at most one
/// leading `luau/`, `lua/` or `src/` directory (tried in that order). The rest
/// is split on `/`, with empty segments discarded.
fn lua_namespaces(file: &str) -> Vec<String> {
    let mut stem = file;
    for ext in [".luau", ".lua"] {
        if let Some(trimmed) = file.strip_suffix(ext) {
            stem = trimmed;
            break;
        }
    }

    let mut relative = stem;
    for root_dir in ["luau/", "lua/", "src/"] {
        if let Some(trimmed) = stem.strip_prefix(root_dir) {
            relative = trimmed;
            break;
        }
    }

    let mut segments = Vec::new();
    for segment in relative.split('/') {
        if !segment.is_empty() {
            segments.push(segment.to_owned());
        }
    }
    segments
}
/// Collects the symbols declared by the direct children of `root`.
///
/// Each entry of `namespaces` becomes a `Descriptor::Namespace`, and that list
/// is the descriptor prefix passed to `collect_chunk`.
fn collect_symbols(root: &Node, ctx: &ExtractCtx, namespaces: &[String]) -> Vec<Symbol> {
    let mut prefix = Vec::with_capacity(namespaces.len());
    for ns in namespaces {
        prefix.push(Descriptor::Namespace(ns.clone()));
    }

    let mut symbols = Vec::new();
    collect_chunk(root, ctx, &prefix, &mut symbols);
    symbols
}
/// Dispatches each direct child of `node` to the collector for its node kind.
///
/// Only immediate children are inspected; this function does not recurse.
/// Kinds other than the four handled below are ignored.
fn collect_chunk(node: &Node, ctx: &ExtractCtx, prefix: &[Descriptor], out: &mut Vec<Symbol>) {
    let mut cursor = node.walk();
    for stmt in node.children(&mut cursor) {
        let kind = stmt.kind();
        if kind == "function_declaration" {
            collect_function_declaration(&stmt, ctx, prefix, out);
        } else if kind == "variable_declaration" {
            collect_variable_declaration(&stmt, ctx, prefix, out);
        } else if kind == "assignment_statement" {
            collect_assignment(&stmt, ctx, prefix, out);
        } else if kind == "type_definition" {
            collect_type_definition(&stmt, ctx, prefix, out);
        }
    }
}
/// Emits one symbol for a `function_declaration` node.
///
/// The shape of the `name` field decides the result:
/// - `identifier`: a `Function` whose descriptors are `prefix` plus a method
///   descriptor for the name;
/// - `dot_index_expression` / `method_index_expression`: a `Method` whose
///   descriptors are `prefix`, a `Type` descriptor for the table text and a
///   method descriptor for the field/method text.
///
/// Nothing is emitted when the `name` field is missing, has another kind, or
/// the table or leaf text cannot be read.
fn collect_function_declaration(
    node: &Node,
    ctx: &ExtractCtx,
    prefix: &[Descriptor],
    out: &mut Vec<Symbol>,
) {
    let Some(target) = node.child_by_field_name("name") else {
        return;
    };

    // `None` marks a plain function; `Some(field)` names the field holding the
    // member's leaf name.
    let leaf_field = match target.kind() {
        "identifier" => None,
        "dot_index_expression" => Some("field"),
        "method_index_expression" => Some("method"),
        _ => return,
    };

    let mut descriptors = prefix.to_vec();
    let (name, kind) = match leaf_field {
        None => (
            node_text(&target, ctx.bytes).to_owned(),
            SymbolKind::Function,
        ),
        Some(leaf) => {
            let Some(table) = field_text(&target, "table", ctx.bytes) else {
                return;
            };
            let Some(method) = field_text(&target, leaf, ctx.bytes) else {
                return;
            };
            descriptors.push(Descriptor::Type(table));
            (method, SymbolKind::Method)
        }
    };

    descriptors.push(Descriptor::Method {
        name: name.clone(),
        disambiguator: String::new(),
    });
    let sig = one_line_signature(node_text(node, ctx.bytes), &['{', '(']);
    out.push(make_symbol(
        ctx,
        node,
        name,
        kind,
        Visibility::Unknown,
        descriptors,
        sig,
    ));
}
/// Forwards every `assignment_statement` child of a `variable_declaration`
/// node to `collect_assignment`.
///
/// A declaration with no `assignment_statement` child yields no symbols.
fn collect_variable_declaration(
    node: &Node,
    ctx: &ExtractCtx,
    prefix: &[Descriptor],
    out: &mut Vec<Symbol>,
) {
    let mut cursor = node.walk();
    node.children(&mut cursor)
        .filter(|part| part.kind() == "assignment_statement")
        .for_each(|assignment| collect_assignment(&assignment, ctx, prefix, out));
}
/// Emits one symbol per identifier target of an assignment.
///
/// Targets are the `identifier` children of the first `variable_list`. Values
/// are the children of the first `expression_list`, minus `,` tokens. Target
/// `i` is paired with value `i` when that value exists; the pair is passed to
/// `emit_local_symbol` with the whole assignment as the declaring node.
fn collect_assignment(node: &Node, ctx: &ExtractCtx, prefix: &[Descriptor], out: &mut Vec<Symbol>) {
    let mut names: Vec<Node> = Vec::new();
    let mut values: Vec<Node> = Vec::new();
    // Only the first list of each kind is consulted.
    let mut have_names = false;
    let mut have_values = false;

    for part in node.children(&mut node.walk()) {
        match part.kind() {
            "variable_list" if !have_names => {
                have_names = true;
                names.extend(
                    part.children(&mut part.walk())
                        .filter(|c| c.kind() == "identifier"),
                );
            }
            "expression_list" if !have_values => {
                have_values = true;
                values.extend(
                    part.children(&mut part.walk())
                        .filter(|c| !matches!(c.kind(), "," | " ")),
                );
            }
            _ => {}
        }
    }

    for (idx, target) in names.iter().enumerate() {
        emit_local_symbol(
            node_text(target, ctx.bytes).to_owned(),
            values.get(idx),
            node,
            ctx,
            prefix,
            out,
        );
    }
}
/// Emits the symbol for one assigned name, classified by the assigned value.
///
/// - `function_definition` value: `Function` with a method descriptor;
/// - `table_constructor` value: `Module` with a `Type` descriptor, followed by
///   the table's fields via `collect_table_fields`;
/// - any other value, or none: `Static` with a `Term` descriptor.
///
/// The symbol is anchored on `decl_node`, the whole declaration.
fn emit_local_symbol(
    name: String,
    value_opt: Option<&Node>,
    decl_node: &Node,
    ctx: &ExtractCtx,
    prefix: &[Descriptor],
    out: &mut Vec<Symbol>,
) {
    let value_kind = value_opt.map(|v| v.kind());
    let is_table = value_kind == Some("table_constructor");

    let (kind, descriptor) = if value_kind == Some("function_definition") {
        let method = Descriptor::Method {
            name: name.clone(),
            disambiguator: String::new(),
        };
        (SymbolKind::Function, method)
    } else if is_table {
        (SymbolKind::Module, Descriptor::Type(name.clone()))
    } else {
        (SymbolKind::Static, Descriptor::Term(name.clone()))
    };

    let mut path = prefix.to_vec();
    path.push(descriptor);
    let sig = one_line_signature(node_text(decl_node, ctx.bytes), &['{', '=']);
    out.push(make_symbol(
        ctx,
        decl_node,
        name,
        kind,
        Visibility::Unknown,
        path.clone(),
        sig,
    ));

    // Table fields are emitted beneath the table's own descriptor path.
    if let Some(table) = value_opt.filter(|_| is_table) {
        collect_table_fields(table, &path, ctx, out);
    }
}
/// Emits one symbol per named `field` child of a table constructor.
///
/// A field whose value is a `function_definition` becomes a `Method`; every
/// other named field becomes a `Static`. Fields whose `name` text cannot be
/// read are skipped. Descriptors are `type_prefix` plus one descriptor for the
/// field.
fn collect_table_fields(
    node: &Node,
    type_prefix: &[Descriptor],
    ctx: &ExtractCtx,
    out: &mut Vec<Symbol>,
) {
    let mut cursor = node.walk();
    let fields = node.children(&mut cursor).filter(|c| c.kind() == "field");

    for entry in fields {
        let fname = match field_text(&entry, "name", ctx.bytes) {
            Some(text) => text,
            None => continue,
        };

        let value_kind = entry.child_by_field_name("value").map(|v| v.kind());
        let (kind, descriptor) = match value_kind {
            Some("function_definition") => (
                SymbolKind::Method,
                Descriptor::Method {
                    name: fname.clone(),
                    disambiguator: String::new(),
                },
            ),
            _ => (SymbolKind::Static, Descriptor::Term(fname.clone())),
        };

        let mut descriptors = type_prefix.to_vec();
        descriptors.push(descriptor);
        let sig = one_line_signature(node_text(&entry, ctx.bytes), &['{', '=']);
        out.push(make_symbol(
            ctx,
            &entry,
            fname,
            kind,
            Visibility::Unknown,
            descriptors,
            sig,
        ));
    }
}
/// Emits a `TypeAlias` symbol for a `type_definition` node.
///
/// The symbol's name is the text of the `name` field and its descriptors are
/// `prefix` plus a `Type` descriptor. A node without a `name` field yields
/// nothing.
fn collect_type_definition(
    node: &Node,
    ctx: &ExtractCtx,
    prefix: &[Descriptor],
    out: &mut Vec<Symbol>,
) {
    let name = match node.child_by_field_name("name") {
        Some(name_node) => node_text(&name_node, ctx.bytes).to_owned(),
        None => return,
    };

    let descriptors: Vec<Descriptor> = prefix
        .iter()
        .cloned()
        .chain(std::iter::once(Descriptor::Type(name.clone())))
        .collect();
    let sig = one_line_signature(node_text(node, ctx.bytes), &['=']);

    out.push(make_symbol(
        ctx,
        node,
        name,
        SymbolKind::TypeAlias,
        Visibility::Unknown,
        descriptors,
        sig,
    ));
}
/// Walks the subtree under `node` and records an import reference for every
/// `require(...)` call whose argument can be read as a path.
///
/// The imported name is the last `.`-separated segment of the path. Paths
/// whose last segment is shorter than `MIN_REF_LEN` are skipped. The reference
/// is anchored on the `require` identifier.
fn collect_require_imports(
    node: &Node,
    bytes: &[u8],
    file: &str,
    out: &mut Vec<Reference>,
    module_id: &str,
) {
    if let Some((callee, from_path)) = require_call_target(node, bytes) {
        let leaf = from_path.rsplit('.').next().unwrap_or(&from_path);
        if leaf.len() >= MIN_REF_LEN {
            push_import_ref(out, leaf, &callee, file, module_id, &from_path);
        }
    }

    for child in node.children(&mut node.walk()) {
        collect_require_imports(&child, bytes, file, out, module_id);
    }
}

/// Returns the `require` identifier and the requested path when `node` is a
/// call to a bare `require` with a readable argument.
///
/// A string-literal argument is preferred; a dotted-expression argument is
/// only consulted when no string literal is found.
fn require_call_target<'t>(node: &Node<'t>, bytes: &[u8]) -> Option<(Node<'t>, String)> {
    if node.kind() != "function_call" {
        return None;
    }
    let callee = node.child_by_field_name("name")?;
    if callee.kind() != "identifier" || node_text(&callee, bytes) != "require" {
        return None;
    }
    let args = node.child_by_field_name("arguments")?;
    let path = extract_string_arg(&args, bytes).or_else(|| extract_dot_expr_arg(&args, bytes))?;
    Some((callee, path))
}
/// Returns the source text of the first `dot_index_expression` that is a
/// direct child of `args`, or `None` when there is none.
fn extract_dot_expr_arg(args: &Node, bytes: &[u8]) -> Option<String> {
    let mut cursor = args.walk();
    let dotted = args
        .children(&mut cursor)
        .find(|c| c.kind() == "dot_index_expression");
    dotted.map(|expr| node_text(&expr, bytes).to_owned())
}
/// Returns the contents of the first usable `string` child of `args`.
///
/// For each `string` child, in order:
/// 1. if it has a `string_content` child, that child's text is returned;
/// 2. otherwise the literal's raw text is stripped of its delimiters and
///    returned, unless the result is empty, in which case the next `string`
///    child is tried.
///
/// Returns `None` when no `string` child produces a non-empty path.
fn extract_string_arg(args: &Node, bytes: &[u8]) -> Option<String> {
    for child in args.children(&mut args.walk()) {
        if child.kind() != "string" {
            continue;
        }

        let mut cursor = child.walk();
        let content = child
            .children(&mut cursor)
            .find(|inner| inner.kind() == "string_content");
        if let Some(content) = content {
            return Some(node_text(&content, bytes).to_owned());
        }

        // No content node: fall back to the raw text of the literal.
        let stripped = strip_string_delimiters(node_text(&child, bytes));
        if !stripped.is_empty() {
            return Some(stripped.to_owned());
        }
    }
    None
}

/// Removes the delimiters of a Lua string literal from `raw`.
///
/// Handles matching single quotes, matching double quotes, and long brackets
/// of any level (`[[...]]`, `[=[...]=]`, ...). Text that matches none of these
/// forms is returned unchanged.
///
/// Long brackets matter because an empty long string such as `[[]]` would
/// otherwise be reported as the path `[[]]`.
fn strip_string_delimiters(raw: &str) -> &str {
    for quote in ['\'', '"'] {
        let unquoted = raw
            .strip_prefix(quote)
            .and_then(|rest| rest.strip_suffix(quote));
        if let Some(inner) = unquoted {
            return inner;
        }
    }

    // Long bracket: `[`, N `=` signs, `[` ... `]`, N `=` signs, `]`.
    if let Some(rest) = raw.strip_prefix('[') {
        let level = rest.bytes().take_while(|&b| b == b'=').count();
        if let Some(body) = rest[level..].strip_prefix('[') {
            let closer = format!("]{}]", "=".repeat(level));
            let inner = body.strip_suffix(closer.as_str());
            if let Some(inner) = inner {
                return inner;
            }
        }
    }

    raw
}
/// Reports whether the identifier `node` sits in a position that must not be
/// recorded as a `Read` reference.
///
/// Returns `true` when `node`:
/// - has no parent;
/// - is the `name` of a `function_call` (the callee), `function_declaration`,
///   table `field`, or numeric `for` clause (the loop control variable);
/// - is the `field` of a `dot_index_expression` or the `method` of a
///   `method_index_expression` (the member leaf);
/// - is a direct child of a `variable_list` or `parameters` node;
/// - is a `goto` target, a label name, or a local attribute name
///   (`<const>` / `<close>`), none of which name a variable.
///
/// NOTE(review): the `for_numeric_clause`, `goto_statement`, `label_statement`
/// and `attribute` kinds assume the tree-sitter-lua grammar's node names —
/// confirm against the grammar returned by `crate::grammar`. Under a grammar
/// without these kinds the extra arms never match.
fn is_non_read_position(node: &Node) -> bool {
    let Some(parent) = node.parent() else {
        return true;
    };
    // True when `node` is exactly the child stored in `field` of the parent.
    let occupies = |field: &str| parent.child_by_field_name(field).as_ref() == Some(node);

    match parent.kind() {
        "function_call" | "function_declaration" | "field" | "for_numeric_clause" => {
            occupies("name")
        }
        "dot_index_expression" => occupies("field"),
        "method_index_expression" => occupies("method"),
        "variable_list" | "parameters" => true,
        "goto_statement" | "label_statement" | "attribute" => true,
        _ => false,
    }
}
/// Walks the subtree under `node` and records a `Read` reference for every
/// identifier that is at least `MIN_REF_LEN` bytes long and for which
/// `is_non_read_position` returns `false`.
///
/// Identifiers are leaves of this walk: their children are never visited.
fn collect_read_references(node: &Node, bytes: &[u8], file: &str, out: &mut Vec<Reference>) {
    if node.kind() != "identifier" {
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            collect_read_references(&child, bytes, file, out);
        }
        return;
    }

    let name = node_text(node, bytes);
    let long_enough = name.len() >= MIN_REF_LEN;
    if long_enough && !is_non_read_position(node) {
        push_ref(out, name, node, file, RefRole::Read);
    }
}
/// Walks the subtree under `node` and records a `Write` reference for each
/// identifier target of an assignment that is not part of a declaration.
///
/// An `assignment_statement` whose parent is a `variable_declaration` or
/// `local_declaration` introduces names, so it is skipped. Only `identifier`
/// children of the first `variable_list` are considered, and names shorter
/// than `MIN_REF_LEN` are ignored.
fn collect_write_references(node: &Node, bytes: &[u8], file: &str, out: &mut Vec<Reference>) {
    let declares = matches!(
        node.parent().map(|p| p.kind()),
        Some("variable_declaration") | Some("local_declaration")
    );

    if node.kind() == "assignment_statement" && !declares {
        let mut cursor = node.walk();
        let targets = node
            .children(&mut cursor)
            .find(|c| c.kind() == "variable_list");

        if let Some(list) = targets {
            let mut list_cursor = list.walk();
            let identifiers = list
                .children(&mut list_cursor)
                .filter(|t| t.kind() == "identifier");
            for id in identifiers {
                let name = node_text(&id, bytes);
                if name.len() >= MIN_REF_LEN {
                    push_ref(out, name, &id, file, RefRole::Write);
                }
            }
        }
    }

    for child in node.children(&mut node.walk()) {
        collect_write_references(&child, bytes, file, out);
    }
}
/// Builds the scope list for a file.
///
/// The first scope pushed is a parentless `Module` scope spanning bytes
/// `0..source_len`. Each child of `root` is then walked by `scope_dfs` with
/// scope id `0` as its parent.
fn collect_scopes(root: &Node, source_len: usize) -> Vec<Scope> {
    let whole_file = ByteSpan {
        start: 0,
        end: source_len,
    };

    let mut scopes = Vec::new();
    push_scope(&mut scopes, None, whole_file, ScopeKind::Module);

    let mut cursor = root.walk();
    for top_level in root.children(&mut cursor) {
        scope_dfs(&top_level, 0, &mut scopes);
    }
    scopes
}
/// Depth-first walk that records the scopes opened at or below `node`.
///
/// - `function_declaration` / `function_definition`: opens a `Function` scope
///   under `parent_id`, then walks only the children of the `body` field with
///   that scope as parent. The body node itself opens no scope.
/// - `block`: opens a `Block` scope and walks its children inside it.
/// - anything else: walks the children under the unchanged `parent_id`.
fn scope_dfs(node: &Node, parent_id: ScopeId, scopes: &mut Vec<Scope>) {
    let kind = node.kind();

    if kind == "function_declaration" || kind == "function_definition" {
        let fn_id = push_scope(
            scopes,
            Some(parent_id),
            node_span(node),
            ScopeKind::Function,
        );
        let Some(body) = node.child_by_field_name("body") else {
            return;
        };
        let mut cursor = body.walk();
        for stmt in body.children(&mut cursor) {
            scope_dfs(&stmt, fn_id, scopes);
        }
        return;
    }

    // A block nests its children one level deeper; other nodes are transparent.
    let enclosing = if kind == "block" {
        push_scope(scopes, Some(parent_id), node_span(node), ScopeKind::Block)
    } else {
        parent_id
    };

    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        scope_dfs(&child, enclosing, scopes);
    }
}
/// Collects the bindings found by `collect_bindings_dfs` in the tree rooted at
/// `root` and returns them as a fresh vector.
fn collect_bindings(root: &Node, bytes: &[u8], scopes: &[Scope]) -> Vec<Binding> {
    let mut bindings: Vec<Binding> = Vec::new();
    collect_bindings_dfs(root, bytes, scopes, &mut bindings);
    bindings
}
fn collect_bindings_dfs(node: &Node, bytes: &[u8], scopes: &[Scope], out: &mut Vec<Binding>) {
match node.kind() {
"function_declaration" | "function_definition" => {
if let Some(params) = node.child_by_field_name("parameters") {
for child in params.named_children(&mut params.walk()) {
if child.kind() == "identifier" {
let name = node_text(&child, bytes).to_owned();
let intro = child.start_byte();
if name.len() >= MIN_REF_LEN && innermost_scope(intro, scopes) != Some(0) {
push_binding(out, name, intro, BindingKind::Param, scopes);
}
}
}
}
}
"local_declaration" => {
let mut cur1 = node.walk();
for inner in node.children(&mut cur1) {
if inner.kind() == "variable_declaration" {
let mut cur2 = inner.walk();
let vl_opt = inner
.children(&mut cur2)
.find(|c| c.kind() == "variable_list");
if let Some(vl) = vl_opt {
let mut cur3 = vl.walk();
for id in vl.children(&mut cur3) {
if id.kind() == "identifier" {
let name = node_text(&id, bytes).to_owned();
let intro = id.start_byte();
if name.len() >= MIN_REF_LEN
&& innermost_scope(intro, scopes) != Some(0)
{
push_binding(out, name, intro, BindingKind::Local, scopes);
}
}
}
}
}
}
}
_ => {}
}
for child in node.children(&mut node.walk()) {
collect_bindings_dfs(&child, bytes, scopes, out);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::types::RefRole;

    /// Path given to the extractor by every test in this module.
    const UTIL_FILE: &str = "src/util.lua";

    /// Runs the Lua extractor over `src` as the contents of `src/util.lua`.
    fn facts_for(src: &str) -> FileFacts {
        LuaExtractor.extract(src, UTIL_FILE).unwrap()
    }

    /// First extracted symbol called `name`, if any.
    fn symbol_named<'f>(facts: &'f FileFacts, name: &str) -> Option<&'f Symbol> {
        facts.symbols.iter().find(|s| s.name == name)
    }

    /// Whether any reference has both the given `role` and `name`.
    fn has_ref(facts: &FileFacts, role: RefRole, name: &str) -> bool {
        facts
            .references
            .iter()
            .filter(|r| r.role == role)
            .any(|r| r.name == name)
    }

    /// Shared body of the `M.x` / `M:x` tests: `method` must be a `Method`
    /// whose SCIP string places it under the type `M`.
    fn assert_method_under_m(src: &str, method: &str) {
        let facts = facts_for(src);
        let sym = symbol_named(&facts, method).unwrap();
        assert_eq!(sym.kind, SymbolKind::Method);
        let scip = sym.id.to_scip_string();
        assert!(
            scip.contains("M#") && scip.contains(method),
            "unexpected SCIP string: {scip}"
        );
    }

    #[test]
    fn global_function_is_extracted() {
        let facts = facts_for("function foo() end");
        let foo = symbol_named(&facts, "foo").unwrap();
        assert_eq!(foo.kind, SymbolKind::Function);
        let scip = foo.id.to_scip_string();
        assert!(
            scip.contains("util") && scip.contains("foo"),
            "unexpected SCIP string: {scip}"
        );
        assert_eq!(facts.lang, "lua");
    }

    #[test]
    fn table_dot_method_is_extracted_as_method_under_type() {
        assert_method_under_m("function M.baz(x) end", "baz");
    }

    #[test]
    fn table_colon_method_is_extracted_as_method_under_type() {
        assert_method_under_m("function M:qux() end", "qux");
    }

    #[test]
    fn local_function_is_extracted_as_function() {
        let facts = facts_for("local function bar() end");
        let bar = symbol_named(&facts, "bar").unwrap();
        assert_eq!(bar.kind, SymbolKind::Function);
    }

    #[test]
    fn local_table_is_extracted_as_module() {
        let facts = facts_for("local M = {}");
        let table = symbol_named(&facts, "M").unwrap();
        assert_eq!(table.kind, SymbolKind::Module);
    }

    #[test]
    fn free_call_is_captured_as_call_ref() {
        let facts = facts_for("function run() foo() end");
        let call = facts.references.iter().find(|r| r.name == "foo").unwrap();
        assert_eq!(call.role, RefRole::Call);
    }

    #[test]
    fn member_call_captures_qualifier() {
        let facts = facts_for("function run() a.bar() end");
        let call = facts
            .references
            .iter()
            .find(|r| r.name == "bar")
            .expect("expected Call ref for 'bar'");
        assert_eq!(call.role, RefRole::Call);
        assert_eq!(
            call.qualifier.as_deref(),
            Some("a"),
            "expected qualifier 'a' on the bar call ref"
        );
    }

    #[test]
    fn require_produces_import_reference() {
        let facts = facts_for("local sub = require('pkg.sub')");
        let import = facts
            .references
            .iter()
            .find(|r| r.role == RefRole::Import)
            .expect("expected an Import ref from require");
        assert_eq!(import.name, "sub");
        assert!(
            import
                .from_path
                .as_deref()
                .is_some_and(|p| p.contains("pkg.sub")),
            "from_path should contain 'pkg.sub', got {:?}",
            import.from_path
        );
    }

    #[test]
    fn require_is_not_emitted_as_plain_call() {
        let facts = facts_for("local sub = require('pkg.sub')");
        assert!(
            !has_ref(&facts, RefRole::Call, "require"),
            "require should not appear as a Call ref"
        );
    }

    #[test]
    fn reassignment_emits_write_and_reads() {
        let facts = facts_for("function run() local count = 0 count = count + bonus end");
        assert!(
            has_ref(&facts, RefRole::Write, "count"),
            "expected a Write ref for the assignment target 'count'"
        );
        assert!(
            has_ref(&facts, RefRole::Read, "bonus"),
            "expected a Read ref for 'bonus' on the RHS"
        );
        assert!(
            has_ref(&facts, RefRole::Read, "count"),
            "expected a Read ref for the RHS use of 'count'"
        );
    }

    #[test]
    fn local_declaration_is_not_a_write() {
        let facts = facts_for("function run() local total = compute() end");
        assert!(
            !has_ref(&facts, RefRole::Write, "total"),
            "a `local` declaration must not emit a Write ref"
        );
        assert!(
            !has_ref(&facts, RefRole::Read, "total"),
            "the declared name must not be emitted as a Read"
        );
    }

    #[test]
    fn read_of_global_in_call_arg() {
        let facts = facts_for("function run() print(config) end");
        assert!(
            has_ref(&facts, RefRole::Read, "config"),
            "expected a Read ref for the argument 'config'"
        );
        assert!(
            !has_ref(&facts, RefRole::Read, "print"),
            "a call callee must not also be a Read ref"
        );
    }

    #[test]
    fn member_access_base_is_read_leaf_is_not() {
        let facts = facts_for("function run() value = source.field end");
        assert!(
            has_ref(&facts, RefRole::Read, "source"),
            "the base of a member access should be a Read ref"
        );
        assert!(
            !has_ref(&facts, RefRole::Read, "field"),
            "the leaf of a member access must not be a Read ref"
        );
    }
}