use std::path::{Path, PathBuf};
use streaming_iterator::StreamingIterator;
use tree_sitter::{Node, Parser, Query, QueryCursor};
/// Category of an [`EntryPoint`]: the reason a function (or file) can be
/// reached from outside the code that defines it.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum EntryPointKind {
    /// Program start, e.g. Rust `fn main`, a Python
    /// `if __name__ == "__main__":` block, or Go `func main` in `package main`.
    Main,
    /// Public API of a library, e.g. Rust `pub fn` in `lib.rs`/`mod.rs`,
    /// names listed in Python `__all__`, or exported Go identifiers.
    LibraryExport,
    /// Function run by a test/benchmark harness.
    Test,
    /// Function exposed to foreign code (e.g. `#[no_mangle]` / `extern "C"`).
    Ffi,
    /// Rust procedural-macro definition (`#[proc_macro*]`).
    ProcMacro,
    /// Package initializer (Go `func init`).
    Init,
    /// Cargo build script (`build.rs`), reported once per file.
    BuildScript,
}
/// A single detected entry point in a source file.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EntryPoint {
    /// Function name, or a synthetic name chosen by the detector
    /// (`"__main__"` for a Python main block, `"build.rs"` for a build script).
    pub name: String,
    /// Why this location is considered an entry point.
    pub kind: EntryPointKind,
    /// Path of the file, exactly as passed to [`EntryPointDetector::detect`].
    pub file_path: PathBuf,
    /// 1-based line where the entry point starts; saturates at `u32::MAX`.
    pub line: u32,
}
/// Language-specific scanner that finds entry points in one source file.
pub trait EntryPointDetector {
    /// Scans `source` (the contents of `file_path`) and returns every entry
    /// point found. A function may appear more than once with different
    /// [`EntryPointKind`]s. The detectors in this module return an empty
    /// vector when the source cannot be parsed.
    fn detect(&self, source: &str, file_path: &Path) -> Vec<EntryPoint>;
}
/// Entry-point detector for Rust source files.
///
/// A file named `build.rs` is reported as a whole ([`EntryPointKind::BuildScript`]
/// at line 1). Every `fn` item anywhere in the syntax tree is then classified
/// by `rust_classify_function`; `pub` functions only count as library
/// exports when the file is named `lib.rs` or `mod.rs`.
#[derive(Debug, Default, Clone, Copy)]
pub struct RustEntryDetector;

impl EntryPointDetector for RustEntryDetector {
    fn detect(&self, source: &str, file_path: &Path) -> Vec<EntryPoint> {
        let Some(tree) = parse_with(source, &tree_sitter_rust::LANGUAGE.into()) else {
            return Vec::new();
        };
        let file_name = file_path.file_name().and_then(|s| s.to_str());

        let mut entries = Vec::new();
        if file_name == Some("build.rs") {
            entries.push(EntryPoint {
                name: "build.rs".to_string(),
                kind: EntryPointKind::BuildScript,
                file_path: file_path.to_path_buf(),
                line: 1,
            });
        }

        let exports_pub_fns = matches!(file_name, Some("lib.rs" | "mod.rs"));
        visit_rust_node(
            &tree.root_node(),
            source.as_bytes(),
            file_path,
            exports_pub_fns,
            &mut entries,
        );
        entries
    }
}
/// Walks the subtree rooted at `node` in pre-order (source order) and
/// classifies every `function_item` it contains, including functions nested
/// in `impl` blocks, modules and other function bodies.
///
/// The walk is iterative, driven by a single tree cursor. The previous
/// recursive version used one stack frame per tree level. Deeply left-nested
/// syntax, such as a long generated `a + b + c + ...` chain, could then
/// overflow the thread stack and abort the whole process.
fn visit_rust_node(
    node: &Node<'_>,
    bytes: &[u8],
    file_path: &Path,
    is_lib_or_mod_rs: bool,
    out: &mut Vec<EntryPoint>,
) {
    // The cursor treats `node` as its root: it never moves to `node`'s
    // siblings or parent, so the walk stays inside this subtree.
    let mut cursor = node.walk();
    'walk: loop {
        let current = cursor.node();
        if current.kind() == "function_item" {
            rust_classify_function(&current, bytes, file_path, is_lib_or_mod_rs, out);
        }
        if cursor.goto_first_child() {
            continue;
        }
        // Leaf reached: advance to the next sibling, climbing back up through
        // finished ancestors; stop once we are back at the root.
        while !cursor.goto_next_sibling() {
            if !cursor.goto_parent() {
                break 'walk;
            }
        }
    }
}
fn rust_classify_function(
node: &Node<'_>,
bytes: &[u8],
file_path: &Path,
is_lib_or_mod_rs: bool,
out: &mut Vec<EntryPoint>,
) {
let name_node = node.child_by_field_name("name");
let Some(name_node) = name_node else { return };
let Ok(name) = std::str::from_utf8(&bytes[name_node.start_byte()..name_node.end_byte()]) else {
return;
};
let line = u32::try_from(node.start_position().row + 1).unwrap_or(u32::MAX);
let attrs = collect_preceding_rust_attrs(node, bytes);
if attrs.iter().any(|a| {
a.starts_with("proc_macro_derive")
|| a.starts_with("proc_macro_attribute")
|| a == "proc_macro"
|| a.starts_with("proc_macro(")
}) {
out.push(EntryPoint {
name: name.to_string(),
kind: EntryPointKind::ProcMacro,
file_path: file_path.to_path_buf(),
line,
});
}
if attrs.iter().any(|a| a == "test" || a == "bench") {
out.push(EntryPoint {
name: name.to_string(),
kind: EntryPointKind::Test,
file_path: file_path.to_path_buf(),
line,
});
}
let function_text =
std::str::from_utf8(&bytes[node.start_byte()..node.end_byte()]).unwrap_or("");
let has_extern_c =
rust_function_has_extern_c(node, bytes) || function_text.contains("extern \"C\"");
if attrs.iter().any(|a| a == "no_mangle") || has_extern_c {
out.push(EntryPoint {
name: name.to_string(),
kind: EntryPointKind::Ffi,
file_path: file_path.to_path_buf(),
line,
});
}
if name == "main" {
out.push(EntryPoint {
name: name.to_string(),
kind: EntryPointKind::Main,
file_path: file_path.to_path_buf(),
line,
});
}
if is_lib_or_mod_rs && rust_function_is_pub(node, bytes) {
out.push(EntryPoint {
name: name.to_string(),
kind: EntryPointKind::LibraryExport,
file_path: file_path.to_path_buf(),
line,
});
}
}
fn collect_preceding_rust_attrs(node: &Node<'_>, bytes: &[u8]) -> Vec<String> {
let mut attrs = Vec::new();
let mut prev = node.prev_sibling();
while let Some(p) = prev {
if p.kind() == "attribute_item" || p.kind() == "inner_attribute_item" {
let mut cursor = p.walk();
let mut attr_text: Option<String> = None;
for child in p.children(&mut cursor) {
if child.kind() == "attribute"
&& let Ok(text) =
std::str::from_utf8(&bytes[child.start_byte()..child.end_byte()])
{
attr_text = Some(text.to_string());
}
}
if let Some(t) = attr_text {
attrs.push(t);
}
prev = p.prev_sibling();
} else if p.kind().starts_with("line_comment") || p.kind().starts_with("block_comment") {
prev = p.prev_sibling();
} else {
break;
}
}
attrs
}
fn rust_function_is_pub(node: &Node<'_>, bytes: &[u8]) -> bool {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.kind() == "visibility_modifier"
&& let Ok(text) = std::str::from_utf8(&bytes[child.start_byte()..child.end_byte()])
{
return text.starts_with("pub");
}
}
false
}
fn rust_function_has_extern_c(node: &Node<'_>, bytes: &[u8]) -> bool {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.kind() != "function_modifiers" {
continue;
}
let mut inner = child.walk();
for grandchild in child.children(&mut inner) {
if grandchild.kind() == "extern_modifier"
&& let Ok(text) =
std::str::from_utf8(&bytes[grandchild.start_byte()..grandchild.end_byte()])
&& text.contains("\"C\"")
{
return true;
}
}
}
false
}
#[derive(Debug, Default, Clone, Copy)]
pub struct PythonEntryDetector;
impl EntryPointDetector for PythonEntryDetector {
fn detect(&self, source: &str, file_path: &Path) -> Vec<EntryPoint> {
let mut entries = Vec::new();
let Some(tree) = parse_with(source, &tree_sitter_python::LANGUAGE.into()) else {
return entries;
};
let root = tree.root_node();
let bytes = source.as_bytes();
let is_test_file = python_is_test_file(file_path);
let mut cursor = root.walk();
for child in root.children(&mut cursor) {
match child.kind() {
"if_statement" if python_is_dunder_main_block(&child, bytes) => {
let line = u32::try_from(child.start_position().row + 1).unwrap_or(u32::MAX);
entries.push(EntryPoint {
name: "__main__".to_string(),
kind: EntryPointKind::Main,
file_path: file_path.to_path_buf(),
line,
});
}
"expression_statement" => {
if let Some(names) = python_extract_dunder_all(&child, bytes) {
let line =
u32::try_from(child.start_position().row + 1).unwrap_or(u32::MAX);
for n in names {
entries.push(EntryPoint {
name: n,
kind: EntryPointKind::LibraryExport,
file_path: file_path.to_path_buf(),
line,
});
}
}
}
"function_definition" | "decorated_definition" => {
let fn_node = if child.kind() == "decorated_definition" {
child.child_by_field_name("definition")
} else {
Some(child)
};
if let Some(fn_node) = fn_node
&& fn_node.kind() == "function_definition"
&& let Some(name_node) = fn_node.child_by_field_name("name")
&& let Ok(name) = std::str::from_utf8(
&bytes[name_node.start_byte()..name_node.end_byte()],
)
&& is_test_file
&& name.starts_with("test_")
{
let line =
u32::try_from(fn_node.start_position().row + 1).unwrap_or(u32::MAX);
entries.push(EntryPoint {
name: name.to_string(),
kind: EntryPointKind::Test,
file_path: file_path.to_path_buf(),
line,
});
}
}
_ => {}
}
}
entries
}
}
/// Returns `true` if `file_path` looks like a Python test module.
///
/// The file must have a `.py` extension (ASCII case-insensitive). It then
/// counts as a test module if its stem starts with `test_` or ends with
/// `_test`, or if any path component is exactly `tests`.
fn python_is_test_file(file_path: &Path) -> bool {
    let Some(file_name) = file_path.file_name().and_then(|s| s.to_str()) else {
        return false;
    };
    let file = Path::new(file_name);
    if !file
        .extension()
        .is_some_and(|ext| ext.eq_ignore_ascii_case("py"))
    {
        return false;
    }
    let stem = file.file_stem().and_then(|s| s.to_str()).unwrap_or("");
    let named_like_test = stem.starts_with("test_") || stem.ends_with("_test");
    named_like_test
        || file_path
            .components()
            .any(|part| part.as_os_str() == std::ffi::OsStr::new("tests"))
}
/// Returns `true` if the `if_statement` `node` tests `__name__` against
/// `"__main__"` (either operand order, single or double quotes).
///
/// Spaces in the condition text are removed before a substring search, so
/// `__name__ == "__main__"` and `__name__=="__main__"` both match. Other
/// whitespace (e.g. tabs) is not normalized.
fn python_is_dunder_main_block(node: &Node<'_>, bytes: &[u8]) -> bool {
    const PATTERNS: [&str; 4] = [
        "__name__==\"__main__\"",
        "__name__=='__main__'",
        "\"__main__\"==__name__",
        "'__main__'==__name__",
    ];
    node.child_by_field_name("condition")
        .and_then(|cond| std::str::from_utf8(&bytes[cond.start_byte()..cond.end_byte()]).ok())
        .map(|text| text.replace(' ', ""))
        .is_some_and(|squashed| PATTERNS.iter().any(|p| squashed.contains(p)))
}
fn python_extract_dunder_all(node: &Node<'_>, bytes: &[u8]) -> Option<Vec<String>> {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.kind() == "assignment" {
let left = child.child_by_field_name("left")?;
let right = child.child_by_field_name("right")?;
let left_text = std::str::from_utf8(&bytes[left.start_byte()..left.end_byte()]).ok()?;
if left_text.trim() != "__all__" {
return None;
}
let mut names = Vec::new();
let mut inner = right.walk();
for grandchild in right.children(&mut inner) {
if grandchild.kind() != "string" {
continue;
}
let mut sc = grandchild.walk();
let mut content_text: Option<String> = None;
for sg in grandchild.children(&mut sc) {
if sg.kind() == "string_content"
&& let Ok(t) = std::str::from_utf8(&bytes[sg.start_byte()..sg.end_byte()])
{
content_text = Some(t.to_string());
}
}
if let Some(t) = content_text {
names.push(t);
} else if let Ok(raw) =
std::str::from_utf8(&bytes[grandchild.start_byte()..grandchild.end_byte()])
{
let trimmed = raw.trim_matches(|c| c == '"' || c == '\'');
names.push(trimmed.to_string());
}
}
return Some(names);
}
}
None
}
/// Entry-point detector for Go source files.
///
/// Only top-level declarations are inspected. Plain functions are classified
/// by `go_classify`. Methods are reported as [`EntryPointKind::LibraryExport`]
/// when their name is exported and the file is not in `package main`.
#[derive(Debug, Default, Clone, Copy)]
pub struct GoEntryDetector;

impl EntryPointDetector for GoEntryDetector {
    fn detect(&self, source: &str, file_path: &Path) -> Vec<EntryPoint> {
        let Some(tree) = parse_with(source, &tree_sitter_go::LANGUAGE.into()) else {
            return Vec::new();
        };
        let bytes = source.as_bytes();
        let file = tree.root_node();
        let in_main_package = go_package_name(&file, bytes).as_deref() == Some("main");

        let mut found = Vec::new();
        let mut walker = file.walk();
        for decl in file.children(&mut walker) {
            let is_function = match decl.kind() {
                "function_declaration" => true,
                "method_declaration" => false,
                _ => continue,
            };
            let Some(name_node) = decl.child_by_field_name("name") else {
                continue;
            };
            let Ok(name) =
                std::str::from_utf8(&bytes[name_node.start_byte()..name_node.end_byte()])
            else {
                continue;
            };
            let line = u32::try_from(decl.start_position().row + 1).unwrap_or(u32::MAX);
            if is_function {
                go_classify(name, line, in_main_package, file_path, &mut found);
            } else if !in_main_package && go_is_exported(name) {
                found.push(EntryPoint {
                    name: name.to_string(),
                    kind: EntryPointKind::LibraryExport,
                    file_path: file_path.to_path_buf(),
                    line,
                });
            }
        }
        found
    }
}
fn go_package_name(root: &Node<'_>, bytes: &[u8]) -> Option<String> {
let mut cursor = root.walk();
for child in root.children(&mut cursor) {
if child.kind() != "package_clause" {
continue;
}
let mut inner = child.walk();
for grandchild in child.children(&mut inner) {
if grandchild.kind() == "package_identifier"
&& let Ok(text) =
std::str::from_utf8(&bytes[grandchild.start_byte()..grandchild.end_byte()])
{
return Some(text.to_string());
}
}
}
None
}
/// Classifies a top-level Go function (not a method) and appends at most one
/// entry to `out`.
///
/// Precedence:
/// 1. `main` in package `main` becomes [`EntryPointKind::Main`];
/// 2. `init` becomes [`EntryPointKind::Init`];
/// 3. a `go test` function name (`Test`, `Benchmark`, `Example`, `Fuzz`
///    prefix per `go_is_test_name`) becomes [`EntryPointKind::Test`];
/// 4. an exported name outside package `main` becomes [`EntryPointKind::LibraryExport`].
///
/// Anything else is ignored.
fn go_classify(
    name: &str,
    line: u32,
    is_main_package: bool,
    file_path: &Path,
    out: &mut Vec<EntryPoint>,
) {
    const TEST_PREFIXES: [&str; 4] = ["Test", "Benchmark", "Example", "Fuzz"];
    let kind = if name == "main" && is_main_package {
        EntryPointKind::Main
    } else if name == "init" {
        EntryPointKind::Init
    } else if TEST_PREFIXES.iter().any(|p| go_is_test_name(name, p)) {
        EntryPointKind::Test
    } else if !is_main_package && go_is_exported(name) {
        EntryPointKind::LibraryExport
    } else {
        return;
    };
    out.push(EntryPoint {
        name: name.to_string(),
        kind,
        file_path: file_path.to_path_buf(),
        line,
    });
}

/// Mirrors the `go test` naming rule: `name` is `prefix` alone, or `prefix`
/// followed by a character that is not a lowercase letter (`TestFoo`,
/// `Test_foo`, `Example`). Ordinary functions such as `Testify`, `Examples`
/// or `Fuzzy` therefore do not count as tests.
fn go_is_test_name(name: &str, prefix: &str) -> bool {
    name.strip_prefix(prefix)
        .is_some_and(|rest| rest.chars().next().is_none_or(|c| !c.is_lowercase()))
}
/// Reports whether a Go identifier is exported, i.e. whether its first
/// character is an uppercase letter.
///
/// The Go spec defines this as Unicode category `Lu`, not just ASCII.
/// Identifiers such as `Ñandu` or `Δt` are therefore exported; the previous
/// ASCII-only check missed them. Empty names and names starting with `_` are
/// not exported.
fn go_is_exported(name: &str) -> bool {
    name.chars().next().is_some_and(char::is_uppercase)
}
/// Returns the entry-point detector for a language name or file extension.
///
/// Accepted (case-sensitive) spellings are `rust`/`rs`, `python`/`py`/`pyi`
/// and `go`. Any other value yields `None`.
#[must_use]
pub fn detector_for(language: &str) -> Option<Box<dyn EntryPointDetector>> {
    let detector: Box<dyn EntryPointDetector> = match language {
        "rust" | "rs" => Box::new(RustEntryDetector),
        "python" | "py" | "pyi" => Box::new(PythonEntryDetector),
        "go" => Box::new(GoEntryDetector),
        _ => return None,
    };
    Some(detector)
}
/// Parses `source` with a fresh tree-sitter parser set to `language`.
///
/// Returns `None` if the parser rejects the language (`set_language` fails)
/// or if `parse` produces no tree.
fn parse_with(source: &str, language: &tree_sitter::Language) -> Option<tree_sitter::Tree> {
    let mut parser = Parser::new();
    match parser.set_language(language) {
        Ok(()) => parser.parse(source, None),
        Err(_) => None,
    }
}
/// Parses `source` and returns the 1-based start line of every capture of
/// every match of `query`, in match order and without deduplication.
///
/// Lines saturate at `u32::MAX`. Returns an empty vector if the source cannot
/// be parsed.
// Not referenced from the code visible in this module, hence the
// `dead_code` allowance. TODO(review): confirm whether it is still needed.
#[allow(dead_code)]
pub(crate) fn query_match_lines(
    source: &str,
    language: &tree_sitter::Language,
    query: &Query,
) -> Vec<u32> {
    let Some(tree) = parse_with(source, language) else {
        return Vec::new();
    };
    let mut query_cursor = QueryCursor::new();
    let mut stream = query_cursor.matches(query, tree.root_node(), source.as_bytes());
    let mut lines = Vec::new();
    // `matches` yields a streaming iterator, so a `for` loop cannot be used.
    while let Some(found) = stream.next() {
        lines.extend(found.captures.iter().map(|cap| {
            u32::try_from(cap.node.start_position().row + 1).unwrap_or(u32::MAX)
        }));
    }
    lines
}