use tree_sitter::{Node, Parser};
use crate::error::{CodegraphError, Result};
use crate::graph::types::{
Binding, BindingKind, ByteSpan, EntryPoint, FileFacts, RefRole, Reference, Scope, ScopeId,
ScopeKind, Symbol, SymbolKind, TypeRefContext, Visibility,
};
use crate::lang::Language;
use crate::symbol::Descriptor;
use super::{
ExtractCtx, Extractor, MIN_REF_LEN, attach_reference_scopes, collect_call_references,
definition_bindings, field_text, import_bindings, make_symbol, node_span, node_text,
one_line_signature, push_binding, push_ref, push_scope, push_type_ref,
};
/// Tree-sitter query passed to `collect_call_references` by `PythonExtractor::extract`.
///
/// It matches `call` nodes and captures, as `@callee`, either the bare identifier in the
/// `function` position or the `attribute` identifier of an attribute expression there.
const CALL_QUERY: &str = r#"
(call
function: [
(identifier) @callee
(attribute attribute: (identifier) @callee)
]
)
"#;
/// Extractor that turns one Python source file into a [`FileFacts`] record.
pub struct PythonExtractor;

impl Extractor for PythonExtractor {
    /// The language handled by this extractor.
    fn lang(&self) -> Language {
        Language::Python
    }

    /// Parses `source` with the Python grammar and gathers symbols, references, scopes and
    /// bindings for `file`.
    ///
    /// # Errors
    ///
    /// Returns [`CodegraphError::Parse`] (carrying `file`) when the grammar cannot be
    /// installed on the parser or when the parser yields no tree, and propagates any error
    /// reported by `collect_call_references`.
    fn extract(&self, source: &str, file: &str) -> Result<FileFacts> {
        let parse_failure = || CodegraphError::Parse {
            path: file.to_owned(),
        };

        let grammar = crate::grammar::python();
        let mut parser = Parser::new();
        parser
            .set_language(&grammar)
            .map_err(|_| parse_failure())?;
        let tree = parser.parse(source, None).ok_or_else(parse_failure)?;

        let root = tree.root_node();
        let bytes = source.as_bytes();
        let namespaces = python_namespaces(file);
        let ctx = ExtractCtx {
            bytes,
            file,
            lang: Language::Python,
        };

        // Symbols: top-level definitions first, the module symbol last.
        let mut symbols = collect_symbols(&root, &ctx, &namespaces);
        let def_bindings = definition_bindings(&symbols);
        let mut module = super::module_symbol(Language::Python, &namespaces, file, source.len());
        let module_id = module.id.to_scip_string();
        if module_is_main_entry(&root, bytes) {
            module.entry_points.push(EntryPoint::Main);
        }
        symbols.push(module);

        // References: the collectors append to one list, so their order is the output order.
        let mut references = collect_call_references(
            &root,
            &grammar,
            CALL_QUERY,
            Language::Python,
            bytes,
            file,
        )?;
        collect_inheritance(&root, bytes, file, &mut references);
        collect_imports(&root, bytes, file, &mut references, &module_id);
        collect_type_references(&root, bytes, file, &mut references);
        collect_read_references(&root, bytes, file, &mut references);
        collect_write_references(&root, bytes, file, &mut references);

        // Scopes must exist before references and bindings can be attached to them.
        let scopes = collect_scopes(&root, source.len());
        attach_reference_scopes(&mut references, &scopes);

        let mut bindings = collect_bindings(&root, bytes, &scopes);
        bindings.extend(def_bindings);
        bindings.extend(import_bindings(&references, &scopes));

        Ok(FileFacts {
            file: file.to_owned(),
            lang: Language::Python.as_str().to_owned(),
            symbols,
            references,
            scopes,
            bindings,
            ffi_exports: Vec::new(),
        })
    }
}
/// Derives the namespace segments for a Python file path.
///
/// A leading `src/` is dropped, the path is split on `/` (empty segments are ignored), and
/// the final segment loses a `.pyi` or `.py` suffix. A final segment whose stem is
/// `__init__` is omitted entirely, so a package's `__init__.py` maps to the package itself.
fn python_namespaces(file: &str) -> Vec<String> {
    let relative = file.strip_prefix("src/").unwrap_or(file);
    let mut segments: Vec<&str> = relative
        .split('/')
        .filter(|segment| !segment.is_empty())
        .collect();

    let Some(file_name) = segments.pop() else {
        return Vec::new();
    };
    let module = file_name
        .strip_suffix(".pyi")
        .or_else(|| file_name.strip_suffix(".py"))
        .unwrap_or(file_name);
    if module != "__init__" {
        segments.push(module);
    }

    segments.into_iter().map(str::to_owned).collect()
}
/// Decorator callee names that `entry_points_for` treats as route registrations.
///
/// The comparison is made against the terminal name of the decorator call: the attribute
/// name for `obj.name(...)`, or the identifier itself for `name(...)`.
const PY_ROUTE_VERBS: &[&str] = &[
"get",
"post",
"put",
"delete",
"patch",
"head",
"options",
"trace",
"route",
"websocket",
"ws",
];
/// Computes the entry-point markers for a function symbol.
///
/// * A function named `main` gets [`EntryPoint::Main`].
/// * When `outer_node` is a `decorated_definition`, every decorator that is a call whose
///   terminal callee name appears in [`PY_ROUTE_VERBS`] adds an [`EntryPoint::HttpRoute`]
///   carrying the full callee text (for example `app.get`).
///
/// Decorators that are not calls, or whose callee is neither an identifier nor an
/// attribute expression, contribute nothing.
fn entry_points_for(fn_name: &str, outer_node: &Node, bytes: &[u8]) -> Vec<EntryPoint> {
    let mut found: Vec<EntryPoint> = Vec::new();
    if fn_name == "main" {
        found.push(EntryPoint::Main);
    }

    if outer_node.kind() == "decorated_definition" {
        let mut cursor = outer_node.walk();
        let decorators = outer_node
            .children(&mut cursor)
            .filter(|child| child.kind() == "decorator");
        for decorator in decorators {
            let callee = decorator
                .children(&mut decorator.walk())
                .find(|c| c.kind() == "call")
                .and_then(|call| call.child_by_field_name("function"));
            let Some(callee) = callee else {
                continue;
            };

            // The name compared against the verb table is the last path segment.
            let verb = match callee.kind() {
                "identifier" => node_text(&callee, bytes),
                "attribute" => callee
                    .child_by_field_name("attribute")
                    .map_or("", |attr| node_text(&attr, bytes)),
                _ => continue,
            };

            if PY_ROUTE_VERBS.contains(&verb) {
                found.push(EntryPoint::HttpRoute(node_text(&callee, bytes).to_owned()));
            }
        }
    }

    found
}
/// Reports whether the module has a top-level `if` statement whose condition is a
/// comparison accepted by [`is_name_eq_main`].
///
/// Only direct children of `root` are inspected; nested `if` statements are not considered.
fn module_is_main_entry(root: &tree_sitter::Node, bytes: &[u8]) -> bool {
    let mut cursor = root.walk();
    for stmt in root.children(&mut cursor) {
        if stmt.kind() != "if_statement" {
            continue;
        }
        let Some(condition) = stmt.child_by_field_name("condition") else {
            continue;
        };
        if condition.kind() == "comparison_operator" && is_name_eq_main(&condition, bytes) {
            return true;
        }
    }
    false
}
/// Checks that a comparison node has exactly one anonymous `==` token and exactly two named
/// operands, one being the identifier `__name__` and the other the string `__main__`
/// (in either order).
fn is_name_eq_main(cond: &Node, bytes: &[u8]) -> bool {
    let equality_tokens = cond
        .children(&mut cond.walk())
        .filter(|token| token.kind() == "==" && !token.is_named())
        .count();
    if equality_tokens != 1 {
        return false;
    }

    let operands: Vec<Node> = cond.named_children(&mut cond.walk()).collect();
    match operands.as_slice() {
        [lhs, rhs] => {
            is_dunder_name_pair(lhs, rhs, bytes) || is_dunder_name_pair(rhs, lhs, bytes)
        }
        _ => false,
    }
}
/// True when `name_node` is the identifier `__name__` and `str_node` is a string node whose
/// content is `__main__`.
fn is_dunder_name_pair(name_node: &Node, str_node: &Node, bytes: &[u8]) -> bool {
    if name_node.kind() != "identifier" || str_node.kind() != "string" {
        return false;
    }
    node_text(name_node, bytes) == "__name__" && string_content(str_node, bytes) == "__main__"
}
/// Returns the text inside a string node.
///
/// Uses the node's `string_content` child when there is one; otherwise falls back to the
/// node's own text with surrounding `"` and then `'` characters trimmed.
fn string_content<'a>(string_node: &Node, bytes: &'a [u8]) -> &'a str {
    let inner = string_node
        .children(&mut string_node.walk())
        .find(|child| child.kind() == "string_content");
    match inner {
        Some(content) => node_text(&content, bytes),
        None => node_text(string_node, bytes)
            .trim_matches('"')
            .trim_matches('\''),
    }
}
/// Builds a symbol for every top-level definition under `root`.
///
/// Only direct children of `root` are considered: functions, classes, decorated
/// functions/classes, and assignments accepted by [`const_of`]. Each symbol is public, its
/// descriptor path is `namespaces` followed by the definition's own descriptor, and function
/// symbols additionally receive the markers computed by [`entry_points_for`].
fn collect_symbols(root: &Node, ctx: &ExtractCtx, namespaces: &[String]) -> Vec<Symbol> {
    let mut symbols = Vec::new();
    let mut cursor = root.walk();

    for stmt in root.children(&mut cursor) {
        let candidate = match stmt.kind() {
            "function_definition" => def_of(&stmt, &stmt, ctx.bytes, true),
            "class_definition" => def_of(&stmt, &stmt, ctx.bytes, false),
            // The span covers the decorators; the name and signature come from the inner
            // definition.
            "decorated_definition" => stmt
                .children(&mut stmt.walk())
                .find(|c| matches!(c.kind(), "function_definition" | "class_definition"))
                .and_then(|inner| {
                    let is_fn = inner.kind() == "function_definition";
                    def_of(&stmt, &inner, ctx.bytes, is_fn)
                }),
            "expression_statement" | "assignment" => const_of(&stmt, ctx.bytes),
            _ => None,
        };
        let Some((span_node, sig_node, name, kind, leaf)) = candidate else {
            continue;
        };

        let descriptors: Vec<Descriptor> = namespaces
            .iter()
            .cloned()
            .map(Descriptor::Namespace)
            .chain(std::iter::once(leaf))
            .collect();
        let signature = one_line_signature(node_text(&sig_node, ctx.bytes), &[':']);

        let mut symbol = make_symbol(
            ctx,
            &span_node,
            name,
            kind,
            Visibility::Public,
            descriptors,
            signature,
        );
        if symbol.kind == SymbolKind::Function {
            symbol.entry_points = entry_points_for(&symbol.name, &span_node, ctx.bytes);
        }
        symbols.push(symbol);
    }

    symbols
}
/// Describes a function or class definition.
///
/// The name is the text of the first `identifier` child of `sig_node`. Returns `None` when
/// there is no such child or when the name consists solely of underscores. On success the
/// tuple holds copies of `span_node` and `sig_node`, the name, the symbol kind and the leaf
/// descriptor: a method descriptor with an empty disambiguator when `is_fn` is true, a type
/// descriptor otherwise.
fn def_of<'a>(
    span_node: &Node<'a>,
    sig_node: &Node<'a>,
    bytes: &[u8],
    is_fn: bool,
) -> Option<(Node<'a>, Node<'a>, String, SymbolKind, Descriptor)> {
    let ident = sig_node
        .children(&mut sig_node.walk())
        .find(|child| child.kind() == "identifier")?;
    let name = node_text(&ident, bytes).to_owned();

    // Reject placeholder names such as `_` or `__`.
    if !name.chars().any(|ch| ch != '_') {
        return None;
    }

    let kind = if is_fn {
        SymbolKind::Function
    } else {
        SymbolKind::Class
    };
    let leaf = if is_fn {
        Descriptor::Method {
            name: name.clone(),
            disambiguator: String::new(),
        }
    } else {
        Descriptor::Type(name.clone())
    };

    Some((*span_node, *sig_node, name, kind, leaf))
}
/// Describes a module-level constant assignment.
///
/// `node` is either an `assignment` or a statement with an `assignment` child. The
/// assignment qualifies when its `left` field is a bare identifier at least three bytes long
/// made up only of upper-case letters, digits and underscores. On success both returned
/// nodes are `node` itself, the kind is [`SymbolKind::Const`] and the leaf descriptor is a
/// term.
///
/// The target is read from the `left` field rather than taken as the first identifier child.
/// With a non-identifier target (`obj.attr = LIMIT`, `a, b = LIMITS`) the first identifier
/// child is the right-hand value, which would be reported as a constant defined here.
fn const_of<'a>(
    node: &Node<'a>,
    bytes: &[u8],
) -> Option<(Node<'a>, Node<'a>, String, SymbolKind, Descriptor)> {
    let assign = if node.kind() == "assignment" {
        *node
    } else {
        node.children(&mut node.walk())
            .find(|c| c.kind() == "assignment")?
    };

    // Only a plain `NAME = ...` / `NAME: T = ...` target can introduce a constant.
    let lhs = assign
        .child_by_field_name("left")
        .filter(|target| target.kind() == "identifier")?;
    let name = node_text(&lhs, bytes).to_owned();

    let looks_constant = name.len() >= 3
        && name
            .chars()
            .all(|c| c.is_uppercase() || c == '_' || c.is_numeric());
    if !looks_constant {
        return None;
    }

    Some((
        *node,
        *node,
        name.clone(),
        SymbolKind::Const,
        Descriptor::Term(name),
    ))
}
/// Walks the tree and records references for `import` and `from ... import` statements.
///
/// * `from M import a, b as c`: the identifier children of `M` become module-path references
///   (via [`emit_module_path_refs`] with `skip_last = false`), and each imported name produces
///   an import reference for its leaf segment with `M`'s text as the from-path (empty when
///   the statement has no `module_name` field).
/// * `import a.b.c`: every segment but the last becomes a module-path reference and the leaf
///   becomes an import reference whose from-path is the full dotted text.
/// * `import a.b as x`: only the import reference is produced; no module-path references
///   are emitted for the aliased form.
///
/// Import statements are not descended into; every other node is searched recursively.
fn collect_imports(
    node: &Node,
    bytes: &[u8],
    file: &str,
    out: &mut Vec<Reference>,
    module_id: &str,
) {
    let kind = node.kind();
    if kind != "import_from_statement" && kind != "import_statement" {
        for child in node.children(&mut node.walk()) {
            collect_imports(&child, bytes, file, out, module_id);
        }
        return;
    }

    let is_from = kind == "import_from_statement";
    let module_name = if is_from {
        node.child_by_field_name("module_name")
    } else {
        None
    };
    let from_path = module_name.map_or("", |n| node_text(&n, bytes));
    if let Some(path) = module_name {
        emit_module_path_refs(&path, false, bytes, file, out);
    }

    let mut cursor = node.walk();
    for entry in node.children_by_field_name("name", &mut cursor) {
        // Resolve the node carrying the real (un-aliased) imported name.
        let (name_node, aliased) = match entry.kind() {
            "dotted_name" => (entry, false),
            "aliased_import" => match entry.child_by_field_name("name") {
                Some(real) => (real, true),
                None => continue,
            },
            _ => continue,
        };

        let text = node_text(&name_node, bytes);
        let leaf = super::simple_type_name(text, ".");
        if is_from {
            super::push_import_ref(out, leaf, &name_node, file, module_id, from_path);
        } else {
            if !aliased {
                emit_module_path_refs(&name_node, true, bytes, file, out);
            }
            super::push_import_ref(out, leaf, &name_node, file, module_id, text);
        }
    }
}
/// Emits a [`RefRole::ModuleRef`] reference for the identifier children of `dotted`.
///
/// With `skip_last` set, the final identifier is left out (so a single-segment path emits
/// nothing); otherwise every identifier child is emitted.
fn emit_module_path_refs(
    dotted: &Node,
    skip_last: bool,
    bytes: &[u8],
    file: &str,
    out: &mut Vec<Reference>,
) {
    let segments: Vec<Node> = dotted
        .children(&mut dotted.walk())
        .filter(|child| child.kind() == "identifier")
        .collect();
    let emitted = if skip_last {
        segments.len().saturating_sub(1)
    } else {
        segments.len()
    };

    for segment in &segments[..emitted] {
        push_ref(out, node_text(segment, bytes), segment, file, RefRole::ModuleRef);
    }
}
/// Records a [`RefRole::IsImplementation`] reference for each base class of every class
/// definition in the tree.
///
/// Within the `superclasses` field, a bare identifier is recorded by its own text and an
/// attribute expression by the text of its `attribute` field; any other named child is
/// ignored. The walk continues into all children, so nested classes are covered too.
fn collect_inheritance(node: &Node, bytes: &[u8], file: &str, out: &mut Vec<Reference>) {
    let bases = if node.kind() == "class_definition" {
        node.child_by_field_name("superclasses")
    } else {
        None
    };

    if let Some(bases) = bases {
        let mut cursor = bases.walk();
        for base in bases.children(&mut cursor).filter(|b| b.is_named()) {
            match base.kind() {
                "identifier" => {
                    super::push_ref(
                        out,
                        node_text(&base, bytes),
                        &base,
                        file,
                        RefRole::IsImplementation,
                    );
                }
                "attribute" => {
                    if let Some(name) = field_text(&base, "attribute", bytes) {
                        super::push_ref(out, &name, &base, file, RefRole::IsImplementation);
                    }
                }
                _ => {}
            }
        }
    }

    for child in node.children(&mut node.walk()) {
        collect_inheritance(&child, bytes, file, out);
    }
}
/// Emits type references for the names found in a type annotation node.
///
/// * `identifier`: one reference with `ctx`.
/// * `type` / `union_type`: each named child is processed with the same `ctx`.
/// * `generic_type`: the first named child, when it is an identifier, is emitted with `ctx`;
///   the named children of its `type_parameter` child are processed with
///   [`TypeRefContext::GenericArg`].
/// * `member_type`: only the last identifier child is emitted.
///
/// Any other node kind produces nothing.
fn emit_type_node(
    node: &Node,
    bytes: &[u8],
    file: &str,
    ctx: TypeRefContext,
    out: &mut Vec<Reference>,
) {
    match node.kind() {
        "identifier" => {
            push_type_ref(out, node_text(node, bytes), node, file, ctx);
        }
        "type" | "union_type" => {
            for part in node.named_children(&mut node.walk()) {
                emit_type_node(&part, bytes, file, ctx, out);
            }
        }
        "generic_type" => {
            let parts: Vec<Node> = node.named_children(&mut node.walk()).collect();

            let head = parts.first().filter(|h| h.kind() == "identifier");
            if let Some(head) = head {
                push_type_ref(out, node_text(head, bytes), head, file, ctx);
            }

            let params = parts.iter().find(|p| p.kind() == "type_parameter");
            if let Some(params) = params {
                for arg in params.named_children(&mut params.walk()) {
                    emit_type_node(&arg, bytes, file, TypeRefContext::GenericArg, out);
                }
            }
        }
        "member_type" => {
            let leaf = node
                .named_children(&mut node.walk())
                .filter(|c| c.kind() == "identifier")
                .last();
            if let Some(leaf) = leaf {
                push_type_ref(out, node_text(&leaf, bytes), &leaf, file, ctx);
            }
        }
        _ => {}
    }
}
/// Walks the tree and emits type references for annotations.
///
/// The annotation field and context depend on the node kind:
/// typed (default) parameters use `type` with [`TypeRefContext::ParameterType`], function
/// definitions use `return_type` with [`TypeRefContext::ReturnType`], and assignments use
/// `type` with [`TypeRefContext::Field`]. Every node's children are then visited once.
fn collect_type_references(node: &Node, bytes: &[u8], file: &str, out: &mut Vec<Reference>) {
    let annotation = match node.kind() {
        "typed_parameter" | "typed_default_parameter" => {
            Some(("type", TypeRefContext::ParameterType))
        }
        "function_definition" => Some(("return_type", TypeRefContext::ReturnType)),
        "assignment" => Some(("type", TypeRefContext::Field)),
        _ => None,
    };

    if let Some((field, ctx)) = annotation {
        if let Some(annotated) = node.child_by_field_name(field) {
            emit_type_node(&annotated, bytes, file, ctx, out);
        }
    }

    for child in node.children(&mut node.walk()) {
        collect_type_references(&child, bytes, file, out);
    }
}
/// Decides whether `node` sits somewhere that must not be reported as a read.
///
/// A node without a parent is never a read. Otherwise the answer depends on the parent kind:
/// callee of a call, name of a function/class/default parameter, left side of an assignment
/// and attribute name of an attribute expression are excluded, as is anything directly
/// under a parameter list, an import construct or a type construct. Inside a typed or splat
/// parameter, an identifier is excluded unless it is the parent's `type` field.
fn is_non_read_position(node: &Node) -> bool {
    let Some(parent) = node.parent() else {
        return true;
    };
    // True when `node` is exactly the child stored in the parent's `field`.
    let occupies = |field: &str| parent.child_by_field_name(field) == Some(*node);

    match parent.kind() {
        "parameters"
        | "import_statement"
        | "import_from_statement"
        | "dotted_name"
        | "aliased_import"
        | "type"
        | "generic_type"
        | "union_type"
        | "member_type" => true,
        "call" => occupies("function"),
        "function_definition"
        | "class_definition"
        | "default_parameter"
        | "typed_default_parameter" => occupies("name"),
        "typed_parameter" | "list_splat_pattern" | "dictionary_splat_pattern" => {
            node.kind() == "identifier" && !occupies("type")
        }
        "assignment" => occupies("left"),
        "attribute" => occupies("attribute"),
        _ => false,
    }
}
/// Emits a [`RefRole::Read`] reference for every identifier that is at least `MIN_REF_LEN`
/// bytes long and is not in a position rejected by [`is_non_read_position`].
///
/// Identifiers are leaves of the walk; all other nodes are searched recursively.
fn collect_read_references(node: &Node, bytes: &[u8], file: &str, out: &mut Vec<Reference>) {
    if node.kind() != "identifier" {
        for child in node.children(&mut node.walk()) {
            collect_read_references(&child, bytes, file, out);
        }
        return;
    }

    let name = node_text(node, bytes);
    let long_enough = name.len() >= MIN_REF_LEN;
    if long_enough && !is_non_read_position(node) {
        push_ref(out, name, node, file, RefRole::Read);
    }
}
/// Emits a [`RefRole::Write`] reference for every assignment whose `left` field is a bare
/// identifier of at least `MIN_REF_LEN` bytes, then continues into all children.
fn collect_write_references(node: &Node, bytes: &[u8], file: &str, out: &mut Vec<Reference>) {
    if node.kind() == "assignment" {
        let target = node
            .child_by_field_name("left")
            .filter(|lhs| lhs.kind() == "identifier");
        if let Some(target) = target {
            let name = node_text(&target, bytes);
            if name.len() >= MIN_REF_LEN {
                push_ref(out, name, &target, file, RefRole::Write);
            }
        }
    }

    for child in node.children(&mut node.walk()) {
        collect_write_references(&child, bytes, file, out);
    }
}
/// Builds the scope list for a file.
///
/// The first scope pushed is the parentless module scope spanning bytes `0..source_len`;
/// function scopes found below `root` are added by [`scope_dfs`], starting with parent
/// id `0`.
fn collect_scopes(root: &Node, source_len: usize) -> Vec<Scope> {
    let whole_file = ByteSpan {
        start: 0,
        end: source_len,
    };
    let mut scopes = Vec::new();
    push_scope(&mut scopes, None, whole_file, ScopeKind::Module);

    let mut cursor = root.walk();
    for top_level in root.children(&mut cursor) {
        scope_dfs(&top_level, 0, &mut scopes);
    }
    scopes
}
/// Depth-first scope discovery.
///
/// A `function_definition` opens a [`ScopeKind::Function`] scope under `parent_id`, and only
/// the children of its `body` field are searched further, with the new scope as their
/// parent. Every other node opens no scope and passes `parent_id` through to its children.
fn scope_dfs(node: &Node, parent_id: ScopeId, scopes: &mut Vec<Scope>) {
    if node.kind() != "function_definition" {
        for child in node.children(&mut node.walk()) {
            scope_dfs(&child, parent_id, scopes);
        }
        return;
    }

    let fn_id = push_scope(
        scopes,
        Some(parent_id),
        node_span(node),
        ScopeKind::Function,
    );
    let Some(body) = node.child_by_field_name("body") else {
        return;
    };
    for stmt in body.children(&mut body.walk()) {
        scope_dfs(&stmt, fn_id, scopes);
    }
}
/// Collects the parameter and local bindings found anywhere below `root`.
///
/// Thin entry point around [`collect_bindings_dfs`], which fills the returned vector.
fn collect_bindings(root: &Node, bytes: &[u8], scopes: &[Scope]) -> Vec<Binding> {
    let mut bindings: Vec<Binding> = Vec::new();
    collect_bindings_dfs(root, bytes, scopes, &mut bindings);
    bindings
}
/// Depth-first collection of bindings.
///
/// * A `function_definition` contributes its parameters through [`collect_params`].
/// * An `assignment` whose `left` field is a bare identifier contributes a
///   [`BindingKind::Local`] binding introduced at that identifier's start byte.
///
/// Every node's children are visited afterwards, whatever its kind.
fn collect_bindings_dfs(node: &Node, bytes: &[u8], scopes: &[Scope], out: &mut Vec<Binding>) {
    match node.kind() {
        "function_definition" => {
            if let Some(params) = node.child_by_field_name("parameters") {
                collect_params(&params, bytes, scopes, out);
            }
        }
        "assignment" => {
            let target = node
                .child_by_field_name("left")
                .filter(|left| left.kind() == "identifier");
            if let Some(left) = target {
                let intro = left.start_byte();
                let name = node_text(&left, bytes).to_owned();
                push_binding(out, name, intro, BindingKind::Local, scopes);
            }
        }
        _ => {}
    }

    // The recursion is the same for all node kinds, so it is written once.
    for child in node.children(&mut node.walk()) {
        collect_bindings_dfs(&child, bytes, scopes, out);
    }
}
/// Adds a [`BindingKind::Param`] binding for each parameter in a parameter list.
///
/// The name node is the parameter itself for a bare identifier, the `name` field for
/// default parameters, and the first identifier child for typed and splat parameters.
/// Parameters of any other kind, or whose name node is not an identifier, are skipped.
fn collect_params(params: &Node, bytes: &[u8], scopes: &[Scope], out: &mut Vec<Binding>) {
    let mut cursor = params.walk();
    for param in params.named_children(&mut cursor) {
        let name_node = match param.kind() {
            "identifier" => Some(param),
            "default_parameter" | "typed_default_parameter" => param.child_by_field_name("name"),
            "typed_parameter" | "list_splat_pattern" | "dictionary_splat_pattern" => param
                .named_children(&mut param.walk())
                .find(|c| c.kind() == "identifier"),
            _ => None,
        }
        .filter(|candidate| candidate.kind() == "identifier");

        let Some(name_node) = name_node else {
            continue;
        };
        let intro = name_node.start_byte();
        let name = node_text(&name_node, bytes).to_owned();
        push_binding(out, name, intro, BindingKind::Param, scopes);
    }
}
// Unit tests for the Python extractor. Each test feeds a small source snippet through
// `PythonExtractor.extract` and inspects the resulting `FileFacts`.
#[cfg(test)]
mod tests {
use super::*;
// Top-level `def`, `class`, `async def` and an upper-case assignment each yield a symbol;
// the function id is namespaced from the path `src/auth/jwt.py` as `auth/jwt`.
#[test]
fn extracts_defs_with_dotted_module() {
let src = "\
def validate_token(tok):
return helper()
class Config:
pass
async def fetch_data():
pass
MAX_RETRIES = 3
";
let facts = PythonExtractor.extract(src, "src/auth/jwt.py").unwrap();
let by_name = |n: &str| facts.symbols.iter().find(|s| s.name == n).cloned();
let vt = by_name("validate_token").unwrap();
assert_eq!(
vt.id.to_scip_string(),
"codegraph . . . auth/jwt/validate_token()."
);
assert_eq!(vt.kind, SymbolKind::Function);
assert_eq!(by_name("Config").unwrap().kind, SymbolKind::Class);
assert!(by_name("fetch_data").is_some());
assert_eq!(by_name("MAX_RETRIES").unwrap().kind, SymbolKind::Const);
}
// `__init__.py` adds no namespace segment of its own: the id is `auth/helper()`.
#[test]
fn init_collapses_to_package() {
let facts = PythonExtractor
.extract("def helper(): pass", "src/auth/__init__.py")
.unwrap();
assert_eq!(
facts.symbols[0].id.to_scip_string(),
"codegraph . . . auth/helper()."
);
}
// One function produces exactly two scopes (module, function) plus Param, Local and
// Definition bindings.
#[test]
fn emits_function_scope_and_bindings() {
let src = "def run(arg):\n local = 1\n helper(arg)\n";
let facts = PythonExtractor.extract(src, "src/main.py").unwrap();
assert_eq!(facts.scopes.len(), 2, "expected module + function scope");
assert_eq!(facts.scopes[0].kind, ScopeKind::Module);
assert_eq!(facts.scopes[1].kind, ScopeKind::Function);
let has = |name: &str, kind: BindingKind| {
facts
.bindings
.iter()
.any(|b| b.name == name && b.kind == kind)
};
assert!(has("arg", BindingKind::Param), "param binding missing");
assert!(has("local", BindingKind::Local), "local binding missing");
assert!(has("run", BindingKind::Definition), "def binding missing");
}
// A class body opens no scope: the method's function scope has the module (id 0) as parent.
#[test]
fn class_body_opens_no_scope_legb() {
let src = "class Foo:\n def method(self):\n pass\n";
let facts = PythonExtractor.extract(src, "src/m.py").unwrap();
let fn_scopes: Vec<_> = facts
.scopes
.iter()
.filter(|s| s.kind == ScopeKind::Function)
.collect();
assert_eq!(fn_scopes.len(), 1, "only the method opens a scope");
assert!(
!facts.scopes.iter().any(|s| s.kind == ScopeKind::Type),
"class body must not open a Type scope in Python"
);
assert_eq!(
fn_scopes[0].parent,
Some(0),
"method's enclosing scope is the module (class skipped)"
);
}
// Called function names appear among the references.
#[test]
fn extracts_call_references() {
let facts = PythonExtractor
.extract(
"def main():\n validate_token('t')\n helper()\n",
"src/main.py",
)
.unwrap();
let names: Vec<&str> = facts.references.iter().map(|r| r.name.as_str()).collect();
assert!(names.contains(&"validate_token"));
assert!(names.contains(&"helper"));
}
// A single base class yields exactly one IsImplementation reference.
#[test]
fn extracts_single_base_class_inherit_reference() {
let src = "class Sub(Base):\n pass\n";
let facts = PythonExtractor.extract(src, "src/mod.py").unwrap();
let inherit_names: Vec<&str> = facts
.references
.iter()
.filter(|r| r.role == RefRole::IsImplementation)
.map(|r| r.name.as_str())
.collect();
assert_eq!(
inherit_names,
vec!["Base"],
"expected ['Base'] in {inherit_names:?}"
);
}
// Each of several base classes yields an IsImplementation reference.
#[test]
fn extracts_multiple_base_classes_inherit_references() {
let src = "class Multi(A, B):\n pass\n";
let facts = PythonExtractor.extract(src, "src/mod.py").unwrap();
let inherit_names: Vec<&str> = facts
.references
.iter()
.filter(|r| r.role == RefRole::IsImplementation)
.map(|r| r.name.as_str())
.collect();
assert!(
inherit_names.contains(&"A"),
"expected 'A' in {inherit_names:?}"
);
assert!(
inherit_names.contains(&"B"),
"expected 'B' in {inherit_names:?}"
);
}
// A dotted base class (`mod.Base`) is recorded by its last segment only.
#[test]
fn extracts_dotted_base_class_leaf_segment() {
let src = "class Dotted(mod.Base):\n pass\n";
let facts = PythonExtractor.extract(src, "src/mod.py").unwrap();
let inherit_names: Vec<&str> = facts
.references
.iter()
.filter(|r| r.role == RefRole::IsImplementation)
.map(|r| r.name.as_str())
.collect();
assert_eq!(
inherit_names,
vec!["Base"],
"expected ['Base'] in {inherit_names:?}"
);
}
// `from pkg.models import Config` yields a single Import reference named `Config`.
#[test]
fn import_from_statement_emits_leaf_name() {
let src = "from pkg.models import Config\n";
let facts = PythonExtractor.extract(src, "src/app.py").unwrap();
let import_names: Vec<&str> = facts
.references
.iter()
.filter(|r| r.role == RefRole::Import)
.map(|r| r.name.as_str())
.collect();
assert_eq!(
import_names,
vec!["Config"],
"expected ['Config'] in {import_names:?}"
);
}
// Plain `import` statements yield an Import reference for the last path segment.
#[test]
fn import_statement_emits_module_leaf() {
let src = "import os\nimport foo.bar\n";
let facts = PythonExtractor.extract(src, "src/mod.py").unwrap();
let import_names: Vec<&str> = facts
.references
.iter()
.filter(|r| r.role == RefRole::Import)
.map(|r| r.name.as_str())
.collect();
assert!(
import_names.contains(&"os"),
"expected 'os' in {import_names:?}"
);
assert!(
import_names.contains(&"bar"),
"expected 'bar' in {import_names:?}"
);
}
// Every name in a multi-name `from` import yields an Import reference.
#[test]
fn import_from_statement_multiple_names() {
let src = "from x import A, B\n";
let facts = PythonExtractor.extract(src, "src/mod.py").unwrap();
let import_names: Vec<&str> = facts
.references
.iter()
.filter(|r| r.role == RefRole::Import)
.map(|r| r.name.as_str())
.collect();
assert!(
import_names.contains(&"A"),
"expected 'A' in {import_names:?}"
);
assert!(
import_names.contains(&"B"),
"expected 'B' in {import_names:?}"
);
}
// An aliased import is recorded under the real name, never under the alias.
#[test]
fn import_alias_emits_real_name_not_alias() {
let src = "from pkg import Thing as T\n";
let facts = PythonExtractor.extract(src, "src/mod.py").unwrap();
let import_names: Vec<&str> = facts
.references
.iter()
.filter(|r| r.role == RefRole::Import)
.map(|r| r.name.as_str())
.collect();
assert!(
import_names.contains(&"Thing"),
"expected 'Thing' in {import_names:?}"
);
assert!(
!import_names.contains(&"T"),
"alias 'T' must NOT appear in {import_names:?}"
);
}
// A wildcard import yields no Import references.
#[test]
fn wildcard_import_emits_nothing() {
let src = "from x import *\n";
let facts = PythonExtractor.extract(src, "src/mod.py").unwrap();
let import_refs: Vec<&str> = facts
.references
.iter()
.filter(|r| r.role == RefRole::Import)
.map(|r| r.name.as_str())
.collect();
assert!(
import_refs.is_empty(),
"expected no Import refs for wildcard, got {import_refs:?}"
);
}
// Import references carry the importing file's module symbol id in `source_module`.
#[test]
fn import_refs_carry_source_module() {
let src = "from pkg.models import Config\n";
let file = "src/app.py";
let facts = PythonExtractor.extract(src, file).unwrap();
let namespaces = python_namespaces(file);
let expected_module_id =
crate::extract::module_symbol(Language::Python, &namespaces, file, src.len())
.id
.to_scip_string();
let import_refs: Vec<_> = facts
.references
.iter()
.filter(|r| r.role == RefRole::Import)
.collect();
assert!(!import_refs.is_empty(), "expected at least one Import ref");
for r in &import_refs {
assert_eq!(
r.source_module,
Some(expected_module_id.clone()),
"Import ref '{}' should carry source_module = {:?}",
r.name,
expected_module_id
);
}
}
// Call references leave `source_module` unset.
#[test]
fn call_refs_have_no_source_module() {
let src = "def main():\n helper()\n";
let facts = PythonExtractor.extract(src, "src/main.py").unwrap();
let call_refs: Vec<_> = facts
.references
.iter()
.filter(|r| r.role == RefRole::Call)
.collect();
assert!(!call_refs.is_empty(), "expected at least one Call ref");
for r in &call_refs {
assert_eq!(
r.source_module, None,
"Call ref '{}' must have source_module = None",
r.name
);
}
}
// A parameter annotation yields a TypeRef with the ParameterType context.
#[test]
fn py_param_type_ref_emitted() {
let src = "def f(c: Config): pass\n";
let facts = PythonExtractor.extract(src, "src/main.py").unwrap();
let r = facts
.references
.iter()
.find(|r| r.role == RefRole::TypeRef && r.name == "Config")
.expect("expected TypeRef ref for 'Config'");
assert_eq!(
r.type_ref_ctx,
Some(TypeRefContext::ParameterType),
"expected ParameterType ctx, got {:?}",
r.type_ref_ctx
);
}
// A return annotation yields a TypeRef with the ReturnType context.
#[test]
fn py_return_type_ref_emitted() {
let src = "def f() -> Config: pass\n";
let facts = PythonExtractor.extract(src, "src/main.py").unwrap();
let r = facts
.references
.iter()
.find(|r| r.role == RefRole::TypeRef && r.name == "Config")
.expect("expected TypeRef ref for 'Config'");
assert_eq!(
r.type_ref_ctx,
Some(TypeRefContext::ReturnType),
"expected ReturnType ctx, got {:?}",
r.type_ref_ctx
);
}
// An annotated class attribute yields a TypeRef with the Field context.
#[test]
fn py_annotated_field_type_ref_emitted() {
let src = "class C:\n name: Config\n";
let facts = PythonExtractor.extract(src, "src/main.py").unwrap();
let r = facts
.references
.iter()
.find(|r| r.role == RefRole::TypeRef && r.name == "Config")
.expect("expected TypeRef ref for 'Config'");
assert_eq!(
r.type_ref_ctx,
Some(TypeRefContext::Field),
"expected Field ctx, got {:?}",
r.type_ref_ctx
);
}
// A Read reference for `base` exists at a byte offset above 20, i.e. past the
// assignment earlier in the snippet.
#[test]
fn py_read_ref_emitted_for_use_not_declaration() {
let src = "def f():\n base = 1\n return base\n";
let facts = PythonExtractor.extract(src, "src/main.py").unwrap();
let read_refs: Vec<_> = facts
.references
.iter()
.filter(|r| r.role == RefRole::Read && r.name == "base")
.collect();
assert!(
!read_refs.is_empty(),
"expected at least one Read ref for 'base', got none"
);
let use_ref = read_refs
.iter()
.find(|r| r.occ.byte > 20)
.expect("expected Read ref for 'base' in the return statement (byte > 20)");
assert!(
use_ref.occ.byte > 20,
"Read ref should be at the use site, not the declaration"
);
}
// Assigning to a name yields a Write reference for it.
#[test]
fn py_write_ref_emitted_for_assignment() {
let src = "def f():\n xxx = 5\n";
let facts = PythonExtractor.extract(src, "src/main.py").unwrap();
let write_refs: Vec<_> = facts
.references
.iter()
.filter(|r| r.role == RefRole::Write && r.name == "xxx")
.collect();
assert!(
!write_refs.is_empty(),
"expected at least one Write ref for 'xxx', got none — all refs: {:?}",
facts
.references
.iter()
.map(|r| (&r.name, r.role))
.collect::<Vec<_>>()
);
}
// The callee of a call is reported as a Call reference and not additionally as a Read.
#[test]
fn py_call_not_also_read() {
let src = "def run():\n helper()\n";
let facts = PythonExtractor.extract(src, "src/main.py").unwrap();
let call_refs: Vec<_> = facts
.references
.iter()
.filter(|r| r.role == RefRole::Call && r.name == "helper")
.collect();
assert!(!call_refs.is_empty(), "expected a Call ref for 'helper'");
let read_refs: Vec<_> = facts
.references
.iter()
.filter(|r| r.role == RefRole::Read && r.name == "helper")
.collect();
assert!(
read_refs.is_empty(),
"helper() must NOT produce a Read ref; got: {read_refs:?}"
);
}
// The attribute name in `obj.foo` is not reported as a Read.
#[test]
fn py_attribute_not_a_read_of_property() {
let src = "def run():\n return obj.foo\n";
let facts = PythonExtractor.extract(src, "src/main.py").unwrap();
let foo_reads: Vec<_> = facts
.references
.iter()
.filter(|r| r.role == RefRole::Read && r.name == "foo")
.collect();
assert!(
foo_reads.is_empty(),
"attribute 'foo' must NOT be a Read ref; got: {foo_reads:?}"
);
}
// Test helper: the sorted names of all ModuleRef references in `facts`.
fn module_ref_names(facts: &FileFacts) -> Vec<String> {
let mut names: Vec<String> = facts
.references
.iter()
.filter(|r| r.role == RefRole::ModuleRef)
.map(|r| r.name.clone())
.collect();
names.sort();
names
}
// `from a.b import c`: ModuleRefs for `a` and `b`, Import reference for `c`.
#[test]
fn from_import_emits_module_refs_for_path_segments() {
let src = "from a.b import c\n";
let facts = PythonExtractor.extract(src, "src/app.py").unwrap();
assert_eq!(
module_ref_names(&facts),
vec!["a".to_owned(), "b".to_owned()],
);
assert!(
facts
.references
.iter()
.any(|r| r.role == RefRole::Import && r.name == "c"),
"expected Import ref 'c'"
);
}
// `from a import c`: a single ModuleRef for `a`, Import reference for `c`.
#[test]
fn from_single_module_import() {
let src = "from a import c\n";
let facts = PythonExtractor.extract(src, "src/app.py").unwrap();
assert_eq!(module_ref_names(&facts), vec!["a".to_owned()]);
assert!(
facts
.references
.iter()
.any(|r| r.role == RefRole::Import && r.name == "c"),
"expected Import ref 'c'"
);
}
// `import a.b.c`: ModuleRefs for `a` and `b` only; the leaf `c` is an Import reference.
#[test]
fn import_dotted_emits_module_refs_except_leaf() {
let src = "import a.b.c\n";
let facts = PythonExtractor.extract(src, "src/app.py").unwrap();
assert_eq!(
module_ref_names(&facts),
vec!["a".to_owned(), "b".to_owned()],
);
assert!(
facts
.references
.iter()
.any(|r| r.role == RefRole::Import && r.name == "c"),
"expected Import ref 'c' (the leaf segment) to still be present"
);
}
// A relative `from . import x` extracts successfully, emits no ModuleRef and produces no
// reference with an empty name.
#[test]
fn relative_import_no_crash() {
let src = "from . import x\n";
let facts = PythonExtractor.extract(src, "src/app.py").unwrap();
let module_refs = module_ref_names(&facts);
assert!(
module_refs.is_empty(),
"relative `from . import x` must emit no ModuleRef, got {:?}",
module_refs
);
assert!(
facts.references.iter().all(|r| !r.name.is_empty()),
"no reference should have an empty name"
);
}
// The Import reference of a `from` import records the module path in `from_path`.
#[test]
fn import_from_statement_carries_from_path() {
let src = "from pkg.models import Config\n";
let facts = PythonExtractor.extract(src, "src/app.py").unwrap();
let r = facts
.references
.iter()
.find(|r| r.role == RefRole::Import && r.name == "Config")
.expect("expected Import ref for 'Config'");
assert_eq!(
r.from_path,
Some("pkg.models".to_owned()),
"from_path should be 'pkg.models', got {:?}",
r.from_path
);
}
// The Import reference of a plain `import` records the full dotted path in `from_path`.
#[test]
fn plain_import_statement_carries_from_path() {
let src = "import os\nimport foo.bar\n";
let facts = PythonExtractor.extract(src, "src/mod.py").unwrap();
let os_ref = facts
.references
.iter()
.find(|r| r.role == RefRole::Import && r.name == "os")
.expect("expected Import ref for 'os'");
assert_eq!(
os_ref.from_path,
Some("os".to_owned()),
"from_path for 'import os' should be 'os', got {:?}",
os_ref.from_path
);
let bar_ref = facts
.references
.iter()
.find(|r| r.role == RefRole::Import && r.name == "bar")
.expect("expected Import ref for 'bar'");
assert_eq!(
bar_ref.from_path,
Some("foo.bar".to_owned()),
"from_path for 'import foo.bar' should be 'foo.bar', got {:?}",
bar_ref.from_path
);
}
// Test helper: a clone of the symbol called `name`; panics listing all symbol names when
// it is absent.
fn sym_by_name(facts: &FileFacts, name: &str) -> Symbol {
facts
.symbols
.iter()
.find(|s| s.name == name)
.unwrap_or_else(|| {
panic!("symbol '{name}' not found; symbols: {:?}", {
let names: Vec<&str> = facts.symbols.iter().map(|s| s.name.as_str()).collect();
names
})
})
.clone()
}
// Test helper: renders entry points as a comma-separated string for assertion messages.
fn ep_str(eps: &[EntryPoint]) -> String {
eps.iter()
.map(|ep| match ep {
EntryPoint::Main => "Main".to_owned(),
EntryPoint::HttpRoute(m) => format!("HttpRoute({m})"),
})
.collect::<Vec<_>>()
.join(", ")
}
// `@app.get(...)` marks the function with exactly one HttpRoute("app.get").
#[test]
fn entry_point_app_get_route() {
let src = "@app.get(\"/users\")\ndef list_users():\n pass\n";
let facts = PythonExtractor.extract(src, "src/routes.py").unwrap();
let sym = sym_by_name(&facts, "list_users");
assert_eq!(
sym.entry_points.len(),
1,
"expected exactly 1 entry point, got [{}]",
ep_str(&sym.entry_points)
);
assert!(
matches!(&sym.entry_points[0], EntryPoint::HttpRoute(m) if m == "app.get"),
"expected HttpRoute(\"app.get\"), got [{}]",
ep_str(&sym.entry_points)
);
}
// `@app.route(..., methods=[...])` is recorded as HttpRoute("app.route"); the `methods`
// argument does not change the recorded callee.
#[test]
fn entry_point_app_route_with_methods_arg() {
let src = "@app.route(\"/x\", methods=[\"POST\"])\ndef handler():\n pass\n";
let facts = PythonExtractor.extract(src, "src/routes.py").unwrap();
let sym = sym_by_name(&facts, "handler");
assert_eq!(
sym.entry_points.len(),
1,
"expected exactly 1 entry point, got [{}]",
ep_str(&sym.entry_points)
);
assert!(
matches!(&sym.entry_points[0], EntryPoint::HttpRoute(m) if m == "app.route"),
"expected HttpRoute(\"app.route\"), got [{}]",
ep_str(&sym.entry_points)
);
}
// A call decorator whose name is not a route verb (`lru_cache`) yields no entry points.
#[test]
fn entry_point_non_route_decorator_ignored() {
let src =
"import functools\n@functools.lru_cache(maxsize=128)\ndef compute(x):\n pass\n";
let facts = PythonExtractor.extract(src, "src/util.py").unwrap();
let sym2 = sym_by_name(&facts, "compute");
assert!(
sym2.entry_points.is_empty(),
"non-route call decorator must not produce entry points; got [{}]",
ep_str(&sym2.entry_points)
);
}
// A function named `main` is marked with exactly one Main entry point.
#[test]
fn entry_point_main_function() {
let src = "def main():\n pass\n";
let facts = PythonExtractor.extract(src, "src/main.py").unwrap();
let sym = sym_by_name(&facts, "main");
assert_eq!(
sym.entry_points.len(),
1,
"expected exactly 1 entry point, got [{}]",
ep_str(&sym.entry_points)
);
assert!(
matches!(&sym.entry_points[0], EntryPoint::Main),
"expected Main, got [{}]",
ep_str(&sym.entry_points)
);
}
// An undecorated function not named `main` has no entry points.
#[test]
fn entry_point_plain_function_empty() {
let src = "def process(data):\n return data\n";
let facts = PythonExtractor.extract(src, "src/util.py").unwrap();
let sym = sym_by_name(&facts, "process");
assert!(
sym.entry_points.is_empty(),
"plain function must have no entry points; got [{}]",
ep_str(&sym.entry_points)
);
}
// `@router.post(...)` is recorded as HttpRoute("router.post").
#[test]
fn entry_point_fastapi_router_post() {
let src = "@router.post(\"/items\")\ndef create_item():\n pass\n";
let facts = PythonExtractor.extract(src, "src/items.py").unwrap();
let sym = sym_by_name(&facts, "create_item");
assert_eq!(
sym.entry_points.len(),
1,
"expected exactly 1 entry point, got [{}]",
ep_str(&sym.entry_points)
);
assert!(
matches!(&sym.entry_points[0], EntryPoint::HttpRoute(m) if m == "router.post"),
"expected HttpRoute(\"router.post\"), got [{}]",
ep_str(&sym.entry_points)
);
}
// `@bp.websocket(...)` is recorded as HttpRoute("bp.websocket").
#[test]
fn entry_point_websocket_route() {
let src = "@bp.websocket(\"/ws\")\ndef ws_handler():\n pass\n";
let facts = PythonExtractor.extract(src, "src/ws.py").unwrap();
let sym = sym_by_name(&facts, "ws_handler");
assert_eq!(
sym.entry_points.len(),
1,
"expected exactly 1 entry point, got [{}]",
ep_str(&sym.entry_points)
);
assert!(
matches!(&sym.entry_points[0], EntryPoint::HttpRoute(m) if m == "bp.websocket"),
"expected HttpRoute(\"bp.websocket\"), got [{}]",
ep_str(&sym.entry_points)
);
}
// A top-level `if __name__ == "__main__":` guard marks the Module symbol as Main.
#[test]
fn entry_point_main_guard_marks_module() {
let src = "if __name__ == \"__main__\":\n pass\n";
let facts = PythonExtractor.extract(src, "src/app.py").unwrap();
let module = facts
.symbols
.iter()
.find(|s| s.kind == SymbolKind::Module)
.expect("expected a Module symbol");
assert!(
module
.entry_points
.iter()
.any(|ep| matches!(ep, EntryPoint::Main)),
"module guard must mark the module Main; got [{}]",
ep_str(&module.entry_points)
);
}
// A guard comparing `__name__` with a different string leaves the Module symbol unmarked.
#[test]
fn entry_point_non_main_guard_ignored() {
let src = "if __name__ == \"__other__\":\n pass\n";
let facts = PythonExtractor.extract(src, "src/app.py").unwrap();
let module = facts
.symbols
.iter()
.find(|s| s.kind == SymbolKind::Module)
.expect("expected a Module symbol");
assert!(
module.entry_points.is_empty(),
"non-main guard must not mark the module; got [{}]",
ep_str(&module.entry_points)
);
}
}