use anyhow::Result;
use streaming_iterator::StreamingIterator;
use tree_sitter::{Language, Node, Parser, QueryCursor};
use crate::types::{symbol_id, Edge, EdgeKind, Symbol, SymbolKind, Visibility};
use super::queries::{is_inside_nested_scope, CachedQuery};
use super::{node_text, ExtractionResult, Extractor};
/// Extracts symbols and edges from Python source using tree-sitter.
///
/// Owns one reusable parser plus the queries run during extraction; all of
/// them are built once in [`PythonExtractor::new`].
pub struct PythonExtractor {
    /// Captures `@callee`: the function expression of every call.
    call_query: CachedQuery,
    /// Captures `@exception`: what a `raise` statement raises.
    raise_query: CachedQuery,
    /// Captures `@exception_type`: types named in `except` clauses.
    except_query: CachedQuery,
    /// Captures `@type_ref`: identifiers and attributes inside annotations.
    type_ref_query: CachedQuery,
    /// Parser configured with the tree-sitter-python grammar.
    parser: Parser,
}
impl PythonExtractor {
    /// Builds an extractor with the Python grammar loaded and every query compiled.
    ///
    /// # Panics
    ///
    /// Panics if the bundled grammar cannot be installed in the parser
    /// (tree-sitter rejects grammars built for an incompatible ABI version).
    pub fn new() -> Self {
        let grammar = Language::new(tree_sitter_python::LANGUAGE);
        let mut parser = Parser::new();
        parser
            .set_language(&grammar)
            .expect("Python grammar should always load");
        Self {
            call_query: CachedQuery::new(
                &grammar,
                "(call function: [(identifier) (attribute)] @callee)",
            ),
            raise_query: CachedQuery::new(
                &grammar,
                r#"(raise_statement
[(call function: [(identifier) (attribute)] @exception)
(identifier) @exception
(attribute) @exception])"#,
            ),
            except_query: CachedQuery::new(
                &grammar,
                r#"(except_clause
[(identifier) @exception_type
(tuple (identifier) @exception_type)
(attribute) @exception_type])"#,
            ),
            type_ref_query: CachedQuery::new(
                &grammar,
                r#"[(identifier) @type_ref
(attribute) @type_ref]"#,
            ),
            parser,
        }
    }
}
impl Default for PythonExtractor {
fn default() -> Self {
Self::new()
}
}
/// Borrowed view of the extractor's compiled queries.
///
/// It is threaded through the recursive tree walk so that helpers need only
/// the queries, not the whole (mutably borrowed) extractor.
struct Queries<'a> {
    /// Annotation type references (`@type_ref`).
    type_ref: &'a CachedQuery,
    /// Call sites (`@callee`).
    call: &'a CachedQuery,
    /// Raised exceptions (`@exception`).
    raise: &'a CachedQuery,
    /// Exception types in `except` clauses (`@exception_type`).
    except: &'a CachedQuery,
}
impl Extractor for PythonExtractor {
    /// Parses `source` and walks the syntax tree from the module root,
    /// collecting every symbol and edge into one [`ExtractionResult`].
    ///
    /// # Errors
    ///
    /// Returns an error only when the parser produces no tree at all; source
    /// with syntax errors still yields whatever could be extracted.
    fn extract(&mut self, source: &str, file_path: &str) -> Result<ExtractionResult> {
        let Some(tree) = self.parser.parse(source, None) else {
            return Err(anyhow::anyhow!("Failed to parse {file_path}"));
        };
        let queries = Queries {
            type_ref: &self.type_ref_query,
            call: &self.call_query,
            raise: &self.raise_query,
            except: &self.except_query,
        };
        let mut result = ExtractionResult {
            symbols: Vec::new(),
            edges: Vec::new(),
        };
        // Module level: no enclosing symbol id and no qualified-name prefix.
        extract_node(
            &queries,
            tree.root_node(),
            source,
            file_path,
            None,
            None,
            &mut result.symbols,
            &mut result.edges,
        );
        Ok(result)
    }
}
#[allow(clippy::too_many_arguments)]
fn extract_node(
queries: &Queries,
node: Node,
source: &str,
file_path: &str,
parent_id: Option<&str>,
parent_qname: Option<&str>,
symbols: &mut Vec<Symbol>,
edges: &mut Vec<Edge>,
) {
match node.kind() {
"function_definition" => {
extract_function(
queries,
node,
source,
file_path,
parent_id,
parent_qname,
symbols,
edges,
);
}
"class_definition" => {
extract_class(
queries,
node,
source,
file_path,
parent_id,
parent_qname,
symbols,
edges,
);
}
"decorated_definition" => {
let mut def_sym_id = None;
for child in node.named_children(&mut node.walk()) {
if child.kind() == "function_definition" || child.kind() == "class_definition" {
if let Some(name_node) = child.child_by_field_name("name") {
let name = node_text(name_node, source);
let kind = if child.kind() == "class_definition" {
"class"
} else if parent_id.is_some() {
"method"
} else {
"function"
};
def_sym_id = Some(symbol_id(file_path, kind, name, parent_qname));
}
}
}
for child in node.named_children(&mut node.walk()) {
if child.kind() == "decorator" {
extract_decorator_ref(child, source, file_path, def_sym_id.as_deref(), edges);
} else if child.kind() == "function_definition"
|| child.kind() == "class_definition"
{
extract_node(
queries,
child,
source,
file_path,
parent_id,
parent_qname,
symbols,
edges,
);
}
}
}
"import_statement" | "import_from_statement" => {
extract_import(
node,
source,
file_path,
parent_id,
parent_qname,
symbols,
edges,
);
}
"expression_statement" => {
for child in node.named_children(&mut node.walk()) {
if child.kind() == "assignment" {
extract_assignment(child, source, file_path, parent_id, parent_qname, symbols);
}
}
walk_for_calls_and_raises_q(queries, node, source, file_path, parent_id, edges);
}
_ => {
for child in node.named_children(&mut node.walk()) {
extract_node(
queries,
child,
source,
file_path,
parent_id,
parent_qname,
symbols,
edges,
);
}
}
}
}
/// Records a `function_definition` node as a symbol and emits its edges.
///
/// The symbol is a [`SymbolKind::Method`] whenever `parent_id` is set and a
/// [`SymbolKind::Function`] otherwise; its id comes from [`symbol_id`] with
/// `parent_qname` as prefix. After the symbol is pushed, this emits
/// `References` edges for parameter/return annotations and call/raise edges
/// for the body. It then recurses into function, class and decorated
/// definitions that are direct children of the body, with this function as
/// their parent.
///
/// Nodes without a `name` field (possible after error recovery) are skipped.
#[allow(clippy::too_many_arguments)] // traversal state is passed explicitly, shared with `extract_node`
fn extract_function(
    queries: &Queries,
    node: Node,
    source: &str,
    file_path: &str,
    parent_id: Option<&str>,
    parent_qname: Option<&str>,
    symbols: &mut Vec<Symbol>,
    edges: &mut Vec<Edge>,
) {
    let Some(name_node) = node.child_by_field_name("name") else {
        return;
    };
    let name = node_text(name_node, source).to_string();
    let kind = if parent_id.is_some() {
        SymbolKind::Method
    } else {
        SymbolKind::Function
    };
    let visibility = python_visibility(&name);
    // tree-sitter-python emits the `async` keyword as the anonymous first token
    // of `function_definition`, so inspect exactly that token. Scanning raw
    // bytes around the node misfires on text that merely precedes `def`, such
    // as a `@run_async` decorator or a `# async` comment on the line above.
    let is_async = node
        .child(0)
        .is_some_and(|kw| node_text_slice(source, kw.start_byte(), kw.end_byte()) == "async");
    let signature = extract_signature(node, source);
    let docstring = extract_docstring(node, source);
    // Computed before `kind` is moved into the symbol.
    let sym_id = symbol_id(file_path, kind.as_str(), &name, parent_qname);
    let mut sym = Symbol::new(
        &name,
        kind,
        file_path,
        node.start_position().row as u32 + 1,
        node.end_position().row as u32 + 1,
        node.start_byte() as u32,
        node.end_byte() as u32,
        parent_qname,
    )
    .with_parent(parent_id)
    .with_signature(signature);
    if visibility != Visibility::Public {
        sym = sym.with_visibility(visibility);
    }
    if is_async {
        sym = sym.with_async(true);
    }
    sym = sym.with_docstring(docstring);
    symbols.push(sym);

    extract_fn_type_refs(queries, node, source, file_path, &sym_id, edges);

    let Some(body) = node.child_by_field_name("body") else {
        return;
    };
    walk_for_calls_and_raises_q(queries, body, source, file_path, Some(&sym_id), edges);
    let child_qname = match parent_qname {
        Some(pq) => format!("{pq}.{name}"),
        None => name,
    };
    // Only definitions directly in the body are visited; other statements were
    // already covered by the call/raise scan above.
    for child in body.named_children(&mut body.walk()) {
        if matches!(
            child.kind(),
            "function_definition" | "class_definition" | "decorated_definition"
        ) {
            extract_node(
                queries,
                child,
                source,
                file_path,
                Some(&sym_id),
                Some(&child_qname),
                symbols,
                edges,
            );
        }
    }
}
/// Records a `class_definition` node as a [`SymbolKind::Class`] symbol, emits
/// an `Inherits` edge per base class, and extracts the class body with the
/// class as parent (qualified name `parent_qname.Name`).
///
/// Only real bases become `Inherits` edges. Keyword arguments such as
/// `metaclass=ABCMeta`, `*`/`**` splats and comments inside the superclass
/// list are skipped. Nodes without a `name` field are ignored.
#[allow(clippy::too_many_arguments)] // traversal state is passed explicitly, shared with `extract_node`
fn extract_class(
    queries: &Queries,
    node: Node,
    source: &str,
    file_path: &str,
    parent_id: Option<&str>,
    parent_qname: Option<&str>,
    symbols: &mut Vec<Symbol>,
    edges: &mut Vec<Edge>,
) {
    let Some(name_node) = node.child_by_field_name("name") else {
        return;
    };
    let name = node_text(name_node, source).to_string();
    let visibility = python_visibility(&name);
    let docstring = extract_docstring(node, source);
    let sym_id = symbol_id(file_path, SymbolKind::Class.as_str(), &name, parent_qname);
    let mut sym = Symbol::new(
        &name,
        SymbolKind::Class,
        file_path,
        node.start_position().row as u32 + 1,
        node.end_position().row as u32 + 1,
        node.start_byte() as u32,
        node.end_byte() as u32,
        parent_qname,
    )
    .with_parent(parent_id)
    .with_docstring(docstring);
    if visibility != Visibility::Public {
        sym = sym.with_visibility(visibility);
    }
    symbols.push(sym);

    if let Some(bases) = node.child_by_field_name("superclasses") {
        for base in bases.named_children(&mut bases.walk()) {
            // The superclass list is an ordinary argument list; these entries
            // are not base classes.
            if matches!(
                base.kind(),
                "keyword_argument" | "list_splat" | "dictionary_splat" | "comment"
            ) {
                continue;
            }
            let base_name = node_text(base, source);
            if !base_name.is_empty() {
                edges.push(Edge::new(
                    sym_id.clone(),
                    base_name,
                    EdgeKind::Inherits,
                    file_path,
                    base.start_position().row as u32 + 1,
                ));
            }
        }
    }

    if let Some(body) = node.child_by_field_name("body") {
        let child_qname = match parent_qname {
            Some(pq) => format!("{pq}.{name}"),
            None => name,
        };
        for child in body.named_children(&mut body.walk()) {
            extract_node(
                queries,
                child,
                source,
                file_path,
                Some(&sym_id),
                Some(&child_qname),
                symbols,
                edges,
            );
        }
    }
}
/// Records an import statement as one [`SymbolKind::Import`] symbol and emits
/// an `Imports` edge for each imported name.
///
/// The symbol is named after the module returned by [`extract_import_module`]
/// and carries the full statement text as its signature. Statements whose
/// module cannot be determined are ignored.
fn extract_import(
    node: Node,
    source: &str,
    file_path: &str,
    parent_id: Option<&str>,
    parent_qname: Option<&str>,
    symbols: &mut Vec<Symbol>,
    edges: &mut Vec<Edge>,
) {
    let module_name = extract_import_module(node, source);
    if module_name.is_empty() {
        return;
    }
    let line = node.start_position().row as u32 + 1;
    let sym_id = symbol_id(
        file_path,
        SymbolKind::Import.as_str(),
        &module_name,
        parent_qname,
    );
    let statement_text = node_text(node, source).to_string();
    let import_sym = Symbol::new(
        &module_name,
        SymbolKind::Import,
        file_path,
        line,
        line,
        node.start_byte() as u32,
        node.end_byte() as u32,
        parent_qname,
    )
    .with_parent(parent_id)
    .with_signature(Some(statement_text));
    symbols.push(import_sym);
    edges.extend(
        extract_imported_names(node, source)
            .into_iter()
            .map(|imported| {
                Edge::new(sym_id.clone(), imported, EdgeKind::Imports, file_path, line)
            }),
    );
}
/// Records an `assignment` whose target is a single identifier as a
/// [`SymbolKind::Variable`] spanning the whole assignment.
///
/// Tuple-unpacking, attribute and subscript targets are ignored. Visibility
/// follows the underscore conventions in [`python_visibility`].
fn extract_assignment(
    node: Node,
    source: &str,
    file_path: &str,
    parent_id: Option<&str>,
    parent_qname: Option<&str>,
    symbols: &mut Vec<Symbol>,
) {
    let Some(target) = node.child_by_field_name("left") else {
        return;
    };
    if target.kind() != "identifier" {
        return;
    }
    let name = node_text(target, source).to_string();
    let visibility = python_visibility(&name);
    let mut variable = Symbol::new(
        &name,
        SymbolKind::Variable,
        file_path,
        node.start_position().row as u32 + 1,
        node.end_position().row as u32 + 1,
        node.start_byte() as u32,
        node.end_byte() as u32,
        parent_qname,
    )
    .with_parent(parent_id);
    if visibility != Visibility::Public {
        variable = variable.with_visibility(visibility);
    }
    symbols.push(variable);
}
fn walk_for_calls_and_raises_q(
queries: &Queries,
node: Node,
source: &str,
file_path: &str,
context_id: Option<&str>,
edges: &mut Vec<Edge>,
) {
let Some(ctx) = context_id else { return };
let callee_idx = queries.call.capture_index("callee");
let exception_idx = queries.raise.capture_index("exception");
let except_type_idx = queries.except.capture_index("exception_type");
{
let mut cursor = QueryCursor::new();
let mut matches = cursor.matches(&queries.call.query, node, source.as_bytes());
while let Some(m) = matches.next() {
for capture in m.captures {
if capture.index == callee_idx {
if is_inside_nested_scope(capture.node, node, PY_SCOPE_KINDS) {
continue;
}
let callee_name = node_text(capture.node, source);
if !callee_name.is_empty() {
edges.push(Edge::new(
ctx,
callee_name,
EdgeKind::Calls,
file_path,
capture.node.start_position().row as u32 + 1,
));
}
}
}
}
}
{
let mut cursor = QueryCursor::new();
let mut seen_raises = std::collections::HashSet::new();
let mut matches = cursor.matches(&queries.raise.query, node, source.as_bytes());
while let Some(m) = matches.next() {
for capture in m.captures {
if capture.index == exception_idx {
if is_inside_nested_scope(capture.node, node, PY_SCOPE_KINDS) {
continue;
}
let line = capture.node.start_position().row as u32 + 1;
let exc_name = node_text(capture.node, source);
if !exc_name.is_empty() && seen_raises.insert((exc_name.to_string(), line)) {
edges.push(Edge::new(ctx, exc_name, EdgeKind::Raises, file_path, line));
}
}
}
}
}
{
let mut cursor = QueryCursor::new();
let mut matches = cursor.matches(&queries.except.query, node, source.as_bytes());
while let Some(m) = matches.next() {
for capture in m.captures {
if capture.index == except_type_idx {
if is_inside_nested_scope(capture.node, node, PY_SCOPE_KINDS) {
continue;
}
let type_name = node_text(capture.node, source);
if !type_name.is_empty()
&& type_name.chars().next().is_some_and(|c| c.is_uppercase())
{
edges.push(Edge::new(
ctx,
type_name,
EdgeKind::References,
file_path,
capture.node.start_position().row as u32 + 1,
));
}
}
}
}
}
}
const PY_SCOPE_KINDS: &[&str] = &["function_definition", "class_definition"];
/// Emits `References` edges from `sym_id` for a function's type annotations:
/// each annotated parameter first (in order), then the return annotation.
fn extract_fn_type_refs(
    queries: &Queries,
    node: Node,
    source: &str,
    file_path: &str,
    sym_id: &str,
    edges: &mut Vec<Edge>,
) {
    let mut annotations = Vec::new();
    if let Some(params) = node.child_by_field_name("parameters") {
        let mut cursor = params.walk();
        annotations.extend(
            params
                .named_children(&mut cursor)
                .filter_map(|param| param.child_by_field_name("type")),
        );
    }
    annotations.extend(node.child_by_field_name("return_type"));
    for annotation in annotations {
        collect_type_refs(queries, annotation, source, file_path, sym_id, edges);
    }
}
/// Emits a `References` edge from `sym_id` for each type named inside the
/// annotation subtree `node`.
///
/// Bare identifiers count only when they start with an uppercase letter, which
/// skips builtins such as `int` and `str`. Dotted names (`models.User`) are
/// always kept. Only the outermost node of a dotted chain is reported, so
/// `models.User` yields a single edge instead of also yielding `User` (and
/// `a.b` for `a.b.C`).
fn collect_type_refs(
    queries: &Queries,
    node: Node,
    source: &str,
    file_path: &str,
    sym_id: &str,
    edges: &mut Vec<Edge>,
) {
    let type_ref_idx = queries.type_ref.capture_index("type_ref");
    let mut cursor = QueryCursor::new();
    let mut matches = cursor.matches(&queries.type_ref.query, node, source.as_bytes());
    while let Some(m) = matches.next() {
        for capture in m.captures {
            if capture.index != type_ref_idx {
                continue;
            }
            let ref_node = capture.node;
            // Pieces of a dotted name are already covered by the enclosing
            // `attribute` capture.
            if ref_node.parent().is_some_and(|p| p.kind() == "attribute") {
                continue;
            }
            let name = node_text(ref_node, source);
            let keep = match ref_node.kind() {
                "identifier" => name.chars().next().is_some_and(char::is_uppercase),
                "attribute" => !name.is_empty(),
                _ => false,
            };
            if keep {
                edges.push(Edge::new(
                    sym_id,
                    name,
                    EdgeKind::References,
                    file_path,
                    ref_node.start_position().row as u32 + 1,
                ));
            }
        }
    }
}
/// Emits a `References` edge from `context_id` for the decorator `node`.
///
/// A bare name or dotted name (`@login_required`, `@app.cache`) is used as-is;
/// for a call (`@app.route("/api")`) the callee text is used. Every edge is
/// placed on the decorator's first line. Nothing is emitted when `context_id`
/// is `None`.
fn extract_decorator_ref(
    node: Node,
    source: &str,
    file_path: &str,
    context_id: Option<&str>,
    edges: &mut Vec<Edge>,
) {
    let Some(ctx) = context_id else { return };
    let decorator_line = node.start_position().row as u32 + 1;
    let mut cursor = node.walk();
    for expr in node.named_children(&mut cursor) {
        let target = match expr.kind() {
            "identifier" | "attribute" => Some(expr),
            "call" => expr.child_by_field_name("function"),
            _ => continue,
        };
        let name = target
            .map(|t| node_text(t, source).to_string())
            .unwrap_or_default();
        if name.is_empty() {
            continue;
        }
        edges.push(Edge::new(
            ctx,
            name,
            EdgeKind::References,
            file_path,
            decorator_line,
        ));
    }
}
/// Returns `source[start..end]` with both offsets clamped to `source.len()`.
///
/// Yields `""` instead of panicking when the clamped range is inverted or an
/// end falls inside a multi-byte UTF-8 character.
fn node_text_slice(source: &str, start: usize, end: usize) -> &str {
    let len = source.len();
    let clamped = start.min(len)..end.min(len);
    source.get(clamped).unwrap_or_default()
}
/// Maps Python's underscore naming conventions to a [`Visibility`].
///
/// * dunder names (`__init__`, and any name both starting and ending with
///   `__`) → `Public`
/// * two or more leading underscores otherwise (`__secret`) → `Private`
/// * one leading underscore (`_helper`) → `Protected`
/// * everything else → `Public`
fn python_visibility(name: &str) -> Visibility {
    let is_dunder = name.starts_with("__") && name.ends_with("__");
    if is_dunder {
        return Visibility::Public;
    }
    match name.strip_prefix('_') {
        Some(rest) if rest.starts_with('_') => Visibility::Private,
        Some(_) => Visibility::Protected,
        None => Visibility::Public,
    }
}
/// Builds a function's display signature: the parameter list text, followed
/// by ` -> <return type>` when a return annotation is present.
///
/// Returns `None` when the node has no `parameters` field.
fn extract_signature(node: Node, source: &str) -> Option<String> {
    let params = node.child_by_field_name("parameters")?;
    let mut signature = node_text(params, source).to_string();
    if let Some(ret) = node.child_by_field_name("return_type") {
        signature.push_str(" -> ");
        signature.push_str(node_text(ret, source));
    }
    Some(signature)
}
/// Returns the docstring of a function or class definition, if it has one.
///
/// The docstring is the first statement of the `body` when that statement is
/// a single string literal. Triple- and single-quoted literals are accepted,
/// optionally prefixed with `r`/`u` in either case. `b`/`f`-prefixed literals
/// are not docstrings in Python and yield `None`. The quotes are removed and
/// the contents trimmed; a blank docstring yields `None`.
fn extract_docstring(node: Node, source: &str) -> Option<String> {
    let body = node.child_by_field_name("body")?;
    let first = body.named_child(0)?;
    if first.kind() != "expression_statement" {
        return None;
    }
    let expr = first.named_child(0)?;
    if expr.kind() != "string" {
        return None;
    }
    let inner = strip_string_quotes(node_text(expr, source))?;
    let trimmed = inner.trim();
    (!trimmed.is_empty()).then(|| trimmed.to_string())
}

/// Strips an optional `r`/`u` prefix and the matching quote delimiters from a
/// Python string literal, returning the text between the quotes.
///
/// Returns `None` for other prefixes (`b`, `f`, ...) or mismatched quotes.
fn strip_string_quotes(literal: &str) -> Option<&str> {
    let unprefixed = literal.trim_start_matches(['r', 'R', 'u', 'U']);
    // Triple quotes must be tried first: `"""x"""` also starts with `"`.
    let quote = ["\"\"\"", "'''", "\"", "'"]
        .iter()
        .copied()
        .find(|q| unprefixed.starts_with(*q))?;
    unprefixed.strip_prefix(quote)?.strip_suffix(quote)
}
/// Returns the module an import statement refers to, or `""` if none is found.
///
/// * `import a.b, c` → `"a.b"` (the first entry; for `import x as y`, `"x"`)
/// * `from pkg.mod import x` → `"pkg.mod"`; relative imports return the
///   relative-import text (e.g. `"."`)
/// * any other node kind → `""`
fn extract_import_module(node: Node, source: &str) -> String {
    let mut cursor = node.walk();
    let module = match node.kind() {
        "import_statement" => node
            .named_children(&mut cursor)
            .find_map(|child| match child.kind() {
                "dotted_name" => Some(child),
                "aliased_import" => child.child_by_field_name("name"),
                _ => None,
            }),
        "import_from_statement" => node.child_by_field_name("module_name").or_else(|| {
            node.named_children(&mut cursor)
                .find(|c| matches!(c.kind(), "dotted_name" | "relative_import"))
        }),
        _ => None,
    };
    module
        .map(|m| node_text(m, source).to_string())
        .unwrap_or_default()
}
/// Returns the names an import statement brings in, in source order.
///
/// * `import a.b, c as d` → `["a.b", "c"]`
/// * `from m import x, y as z` → `["x", "y"]` (aliases report the original name)
/// * `from m import *` → `[]`
///
/// Any other node kind yields an empty list.
fn extract_imported_names(node: Node, source: &str) -> Vec<String> {
    if !matches!(node.kind(), "import_statement" | "import_from_statement") {
        return Vec::new();
    }
    // Each imported entry carries the `name` field, while the `from` module
    // carries `module_name`. A kind-based scan cannot tell a plain imported
    // `dotted_name` from the module's, which lost `a` in `from m import a, b as c`.
    let mut cursor = node.walk();
    let mut names: Vec<String> = node
        .children_by_field_name("name", &mut cursor)
        .filter_map(|entry| {
            let target = match entry.kind() {
                "dotted_name" | "identifier" => entry,
                "aliased_import" => entry.child_by_field_name("name")?,
                _ => return None,
            };
            Some(node_text(target, source).to_string())
        })
        .collect();
    // Fallback for error-recovered trees that lack the field: take the bare
    // names following the `import` keyword.
    if names.is_empty() && node.kind() == "import_from_statement" {
        names.extend(
            node.children(&mut cursor)
                .skip_while(|c| c.kind() != "import")
                .skip(1)
                .filter(|c| matches!(c.kind(), "dotted_name" | "identifier"))
                .map(|c| node_text(c, source).to_string()),
        );
    }
    names
}
#[cfg(test)]
mod tests {
    // Each test runs a small Python snippet through `PythonExtractor` and
    // inspects the extracted symbols/edges. The snippets are raw strings, so
    // their indentation is significant Python syntax.
    use super::*;

    /// Runs a fresh extractor over `source`, reporting it as `test.py`.
    fn extract(source: &str) -> ExtractionResult {
        let mut ext = PythonExtractor::new();
        ext.extract(source, "test.py").unwrap()
    }

    // A top-level `def`: kind, parameter/return signature text and docstring.
    #[test]
    fn test_simple_function() {
        let result = extract(
            r#"
def hello(name: str) -> str:
    """Greet someone."""
    return f"Hello, {name}!"
"#,
        );
        assert_eq!(result.symbols.len(), 1);
        assert_eq!(result.symbols[0].name, "hello");
        assert_eq!(result.symbols[0].kind, SymbolKind::Function);
        assert_eq!(
            result.symbols[0].signature.as_deref(),
            Some("(name: str) -> str")
        );
        assert_eq!(
            result.symbols[0].docstring.as_deref(),
            Some("Greet someone.")
        );
    }

    // Methods are nested under their class and get `Method` kind and a parent.
    #[test]
    fn test_class_with_methods() {
        let result = extract(
            r#"
class UserService:
    """Manages users."""
    def __init__(self, db):
        self.db = db
    def get_user(self, user_id: int) -> User:
        return self.db.find(user_id)
    def _internal_method(self):
        pass
"#,
        );
        let class = result.symbols.iter().find(|s| s.name == "UserService");
        assert!(class.is_some());
        assert_eq!(class.unwrap().kind, SymbolKind::Class);
        let init = result.symbols.iter().find(|s| s.name == "__init__");
        assert!(init.is_some());
        assert_eq!(init.unwrap().kind, SymbolKind::Method);
        assert!(init.unwrap().parent_id.is_some());
        let internal = result.symbols.iter().find(|s| s.name == "_internal_method");
        assert!(internal.is_some());
        assert_eq!(internal.unwrap().visibility, Visibility::Protected);
    }

    // Each base class in the superclass list yields one `Inherits` edge.
    #[test]
    fn test_inheritance() {
        let result = extract(
            r#"
class AdminService(UserService, BaseService):
    pass
"#,
        );
        let inherits: Vec<_> = result
            .edges
            .iter()
            .filter(|e| e.kind == EdgeKind::Inherits)
            .collect();
        assert_eq!(inherits.len(), 2);
        let targets: Vec<&str> = inherits.iter().map(|e| e.target_name.as_str()).collect();
        assert!(targets.contains(&"UserService"));
        assert!(targets.contains(&"BaseService"));
    }

    // Every call site in a function body yields one `Calls` edge.
    #[test]
    fn test_function_calls() {
        let result = extract(
            r#"
def process():
    data = fetch_data()
    result = transform(data)
    save(result)
"#,
        );
        let calls: Vec<_> = result
            .edges
            .iter()
            .filter(|e| e.kind == EdgeKind::Calls)
            .collect();
        assert_eq!(calls.len(), 3);
        let targets: Vec<&str> = calls.iter().map(|e| e.target_name.as_str()).collect();
        assert!(targets.contains(&"fetch_data"));
        assert!(targets.contains(&"transform"));
        assert!(targets.contains(&"save"));
    }

    // Import statements produce `Import` symbols and `Imports` edges.
    #[test]
    fn test_imports() {
        let result = extract(
            r#"
import os
from pathlib import Path
from typing import Optional, List
"#,
        );
        let imports: Vec<_> = result
            .symbols
            .iter()
            .filter(|s| s.kind == SymbolKind::Import)
            .collect();
        assert!(!imports.is_empty());
        let import_edges: Vec<_> = result
            .edges
            .iter()
            .filter(|e| e.kind == EdgeKind::Imports)
            .collect();
        assert!(!import_edges.is_empty());
    }

    // Each `raise` of a distinct exception yields a `Raises` edge.
    #[test]
    fn test_raises() {
        let result = extract(
            r#"
def validate(x):
    if x < 0:
        raise ValueError("negative")
    if x > 100:
        raise RuntimeError("too large")
"#,
        );
        let raises: Vec<_> = result
            .edges
            .iter()
            .filter(|e| e.kind == EdgeKind::Raises)
            .collect();
        assert_eq!(raises.len(), 2);
    }

    // Underscore conventions map to Public / Protected / Private; dunders are Public.
    #[test]
    fn test_private_naming() {
        let result = extract(
            r#"
class Foo:
    def public_method(self): pass
    def _protected_method(self): pass
    def __private_method(self): pass
    def __dunder__(self): pass
"#,
        );
        let public = result.symbols.iter().find(|s| s.name == "public_method");
        assert_eq!(public.unwrap().visibility, Visibility::Public);
        let protected = result
            .symbols
            .iter()
            .find(|s| s.name == "_protected_method");
        assert_eq!(protected.unwrap().visibility, Visibility::Protected);
        let private = result.symbols.iter().find(|s| s.name == "__private_method");
        assert_eq!(private.unwrap().visibility, Visibility::Private);
        let dunder = result.symbols.iter().find(|s| s.name == "__dunder__");
        assert_eq!(dunder.unwrap().visibility, Visibility::Public);
    }

    // `async def` sets `is_async` for both free functions and methods.
    #[test]
    fn test_async_function() {
        let result = extract(
            r#"
async def fetch(url: str) -> Response:
    return await http.get(url)
class Service:
    async def process(self, data):
        pass
"#,
        );
        let fetch = result.symbols.iter().find(|s| s.name == "fetch").unwrap();
        assert_eq!(fetch.kind, SymbolKind::Function);
        assert!(fetch.is_async);
        let process = result.symbols.iter().find(|s| s.name == "process").unwrap();
        assert_eq!(process.kind, SymbolKind::Method);
        assert!(process.is_async);
    }

    // Module-level `name = value` assignments become `Variable` symbols.
    #[test]
    fn test_module_level_assignment() {
        let result = extract(
            r#"
MAX_RETRIES = 3
_internal_cache = {}
__private_lock = None
"#,
        );
        let vars: Vec<_> = result
            .symbols
            .iter()
            .filter(|s| s.kind == SymbolKind::Variable)
            .collect();
        assert_eq!(vars.len(), 3);
        let max = vars.iter().find(|s| s.name == "MAX_RETRIES").unwrap();
        assert_eq!(max.visibility, Visibility::Public);
        let internal = vars.iter().find(|s| s.name == "_internal_cache").unwrap();
        assert_eq!(internal.visibility, Visibility::Protected);
        let private = vars.iter().find(|s| s.name == "__private_lock").unwrap();
        assert_eq!(private.visibility, Visibility::Private);
    }

    // Import symbols are named after the module, not the alias.
    #[test]
    fn test_aliased_import() {
        let result = extract(
            r#"
import numpy as np
from collections import OrderedDict as ODict
"#,
        );
        let imports: Vec<_> = result
            .symbols
            .iter()
            .filter(|s| s.kind == SymbolKind::Import)
            .collect();
        assert_eq!(imports.len(), 2);
        let names: Vec<&str> = imports.iter().map(|s| s.name.as_str()).collect();
        assert!(names.contains(&"numpy"));
        assert!(names.contains(&"collections"));
    }

    // Empty input parses fine and yields nothing.
    #[test]
    fn test_empty_file() {
        let result = extract("");
        assert!(result.symbols.is_empty());
        assert!(result.edges.is_empty());
    }

    // Broken syntax must still return Ok (the `extract` helper unwraps it).
    #[test]
    fn test_syntax_error_partial_parse() {
        let result = extract("def broken(:\n pass");
        let _ = result.symbols.len();
    }

    // Capitalised annotation names become `References`; builtins like `int` do not.
    #[test]
    fn test_type_annotation_refs() {
        let result = extract(
            r#"
def process(user: User, count: int) -> Response:
    pass
"#,
        );
        let refs: Vec<_> = result
            .edges
            .iter()
            .filter(|e| e.kind == EdgeKind::References)
            .collect();
        let targets: Vec<&str> = refs.iter().map(|e| e.target_name.as_str()).collect();
        assert!(targets.contains(&"User"));
        assert!(targets.contains(&"Response"));
        assert!(!targets.contains(&"int"));
    }

    // Plain and called decorators both yield `References` edges.
    #[test]
    fn test_decorator_refs() {
        let result = extract(
            r#"
@login_required
def protected():
    pass
@app.route("/api")
def endpoint():
    pass
"#,
        );
        let refs: Vec<_> = result
            .edges
            .iter()
            .filter(|e| e.kind == EdgeKind::References)
            .collect();
        let targets: Vec<&str> = refs.iter().map(|e| e.target_name.as_str()).collect();
        assert!(targets.contains(&"login_required"));
        assert!(targets.contains(&"app.route"));
    }

    // Exception types in `except` clauses, including tuples, are referenced.
    #[test]
    fn test_except_clause_refs() {
        let result = extract(
            r#"
def risky():
    try:
        pass
    except ValueError:
        pass
    except (TypeError, KeyError):
        pass
"#,
        );
        let refs: Vec<_> = result
            .edges
            .iter()
            .filter(|e| e.kind == EdgeKind::References)
            .collect();
        let targets: Vec<&str> = refs.iter().map(|e| e.target_name.as_str()).collect();
        assert!(targets.contains(&"ValueError"));
        assert!(targets.contains(&"TypeError"));
        assert!(targets.contains(&"KeyError"));
    }

    // `from m import a, b` yields an `Imports` edge per imported name.
    #[test]
    fn test_imports_specific_names() {
        let result = extract(
            r#"
from typing import Optional, List
"#,
        );
        let import_edges: Vec<_> = result
            .edges
            .iter()
            .filter(|e| e.kind == EdgeKind::Imports)
            .collect();
        let targets: Vec<&str> = import_edges
            .iter()
            .map(|e| e.target_name.as_str())
            .collect();
        assert!(targets.contains(&"Optional"));
        assert!(targets.contains(&"List"));
    }

    // Symbol ids must not depend on line numbers, so they survive edits above them.
    #[test]
    fn stable_id_survives_line_movement() {
        let source_v1 = r#"
def foo():
    pass
def bar():
    pass
"#;
        let source_v2 = r#"
# added comment
def foo():
    pass
# another comment
def bar():
    pass
"#;
        let r1 = extract(source_v1);
        let r2 = extract(source_v2);
        let foo_v1 = r1.symbols.iter().find(|s| s.name == "foo").unwrap();
        let foo_v2 = r2.symbols.iter().find(|s| s.name == "foo").unwrap();
        assert_eq!(
            foo_v1.id, foo_v2.id,
            "ID should be stable across line moves"
        );
        assert_ne!(foo_v1.start_line, foo_v2.start_line, "lines should differ");
        let bar_v1 = r1.symbols.iter().find(|s| s.name == "bar").unwrap();
        let bar_v2 = r2.symbols.iter().find(|s| s.name == "bar").unwrap();
        assert_eq!(bar_v1.id, bar_v2.id);
    }

    // Pins the exact id format for a method: `<file>:method:<Class>.<name>`.
    #[test]
    fn stable_id_method_includes_class_name() {
        let result = extract(
            r#"
class MyService:
    def handle(self):
        pass
"#,
        );
        let method = result.symbols.iter().find(|s| s.name == "handle").unwrap();
        assert_eq!(method.id, "test.py:method:MyService.handle");
        assert_eq!(method.kind, SymbolKind::Method);
    }
}