use crate::error::Result;
use crate::ingest::imports::{ImportExtractor, ImportKind};
use crate::symbol::Language;
use std::path::Path;
/// Import extractor for Python sources, built on the tree-sitter Python grammar.
pub struct PythonExtractor;

impl ImportExtractor for PythonExtractor {
    /// Returns `Language::Python`, this extractor's language identifier.
    fn language_enum() -> Language {
        Language::Python
    }

    /// Returns the tree-sitter Python grammar (`tree_sitter_python::language()`).
    fn language() -> tree_sitter::Language {
        tree_sitter_python::language()
    }

    /// Appends the import facts found anywhere in the subtree rooted at `node`
    /// to `imports`.
    fn extract_from_node(
        node: tree_sitter::Node,
        source: &[u8],
        imports: &mut Vec<super::ImportFact>,
    ) {
        extract_import_statements(node, source, imports)
    }
}
pub fn extract_python_imports(path: &Path, source: &[u8]) -> Result<Vec<super::ImportFact>> {
PythonExtractor::extract(path, source)
}
/// Collects every Python import in the subtree rooted at `node`, in source order.
///
/// Each `import_statement` and `import_from_statement` node is converted into
/// `ImportFact`s and is not descended into. Every other node is searched
/// through its children.
///
/// The walk is an explicit pre-order traversal with a `TreeCursor` instead of
/// recursion. Call-stack usage therefore does not grow with the depth of the
/// syntax tree; deeply nested trees, e.g. very long chained expressions in
/// generated code, previously risked a stack overflow. The visit order matches
/// the recursive `node.children()` walk, so results are unchanged.
fn extract_import_statements(
    node: tree_sitter::Node,
    source: &[u8],
    imports: &mut Vec<super::ImportFact>,
) {
    let mut cursor = node.walk();
    // How many levels below `node` the cursor currently is. Tracking this
    // explicitly keeps the walk from ever moving beside or above `node`.
    let mut depth = 0usize;
    loop {
        let current = cursor.node();
        let descend = match current.kind() {
            "import_statement" => {
                if let Some(mut stmt_imports) = extract_import_statement(current, source) {
                    imports.append(&mut stmt_imports);
                }
                false
            }
            "import_from_statement" => {
                if let Some(import) = extract_import_from_statement(current, source) {
                    imports.push(import);
                }
                false
            }
            _ => true,
        };
        if descend && cursor.goto_first_child() {
            depth += 1;
            continue;
        }
        // Nothing (more) to visit below this node: move to the next sibling,
        // climbing back up while the current level is exhausted. Reaching
        // depth 0 means we are back at `node` and the walk is complete.
        loop {
            if depth == 0 {
                return;
            }
            if cursor.goto_next_sibling() {
                break;
            }
            let moved_up = cursor.goto_parent();
            debug_assert!(moved_up, "cursor below the start node must have a parent");
            depth -= 1;
        }
    }
}
/// Converts a plain `import ...` statement into one `ImportFact` per entry.
///
/// - A `dotted_name` entry such as `import a.b.c` yields path `["a", "b", "c"]`
///   with its first segment, `a`, as the single imported name. Entries with no
///   decodable segments are skipped.
/// - An `aliased_import` entry (`import a.b as c`) is delegated to
///   `extract_aliased_import` with kind `ImportKind::PythonImport`.
///
/// Every fact carries the byte span of the whole statement and an empty
/// `file_path`. Returns `None` when the statement yields no facts.
fn extract_import_statement(
    node: tree_sitter::Node,
    source: &[u8],
) -> Option<Vec<super::ImportFact>> {
    let span = (node.start_byte(), node.end_byte());
    let mut facts = Vec::new();
    let mut walker = node.walk();
    for entry in node.children(&mut walker) {
        match entry.kind() {
            "aliased_import" => facts.extend(extract_aliased_import(
                entry,
                source,
                span.0,
                span.1,
                ImportKind::PythonImport,
            )),
            "dotted_name" => {
                let segments = extract_dotted_name_path(entry, source);
                if let Some(first_segment) = segments.first().cloned() {
                    facts.push(super::ImportFact {
                        file_path: std::path::PathBuf::new(),
                        import_kind: ImportKind::PythonImport,
                        path: segments,
                        imported_names: vec![first_segment],
                        is_glob: false,
                        is_reexport: false,
                        byte_span: span,
                    });
                }
            }
            _ => {}
        }
    }
    Some(facts).filter(|found| !found.is_empty())
}
/// Builds one `ImportFact` from a `from <module> import <names>` statement.
///
/// The resulting fact has:
/// * `path`: the module path segments. For a relative module reference the
///   first element is the run of leading dots (e.g. `"."` or `".."`), followed
///   by the module's own segments when it names one.
/// * `imported_names`: one entry per imported name. Each is the last segment
///   of a plain name, the alias of an `x as y` entry, or `"*"` for a wildcard
///   import (which also sets `is_glob`).
/// * `import_kind`: `PythonFrom` for an absolute module. Otherwise it is
///   `PythonFromRelative`, `PythonFromParent` or `PythonFromAncestor` for one,
///   two, or three-or-more leading dots respectively.
///
/// Returns `None` only when the text of a relative module reference is not
/// valid UTF-8.
fn extract_import_from_statement(
    node: tree_sitter::Node,
    source: &[u8],
) -> Option<super::ImportFact> {
    // Which part of the statement the children being visited belong to.
    #[derive(Clone, Copy)]
    enum Stage {
        // Before the `from` keyword.
        Start,
        // Between `from` and `import`: the module reference.
        Module,
        // After `import`: the imported names.
        Names,
    }

    let byte_start = node.start_byte();
    let byte_end = node.end_byte();
    let mut cursor = node.walk();
    let mut path = Vec::new();
    let mut import_kind = ImportKind::PythonFrom;
    let mut imported_names = Vec::new();
    let mut is_glob = false;
    let mut stage = Stage::Start;
    for child in node.children(&mut cursor) {
        match child.kind() {
            "from" => stage = Stage::Module,
            "import" => stage = Stage::Names,
            "relative_import" => {
                let text = child.utf8_text(source).ok()?;
                // Only the leading dots set the level. `.foo.bar` is a
                // single-dot relative import even though its text holds two dots.
                let relative_level = relative_import_level(text);
                import_kind = match relative_level {
                    1 => ImportKind::PythonFromRelative,
                    2 => ImportKind::PythonFromParent,
                    _ => ImportKind::PythonFromAncestor,
                };
                path.push(".".repeat(relative_level));
                let mut module_cursor = child.walk();
                for sub_child in child.children(&mut module_cursor) {
                    if sub_child.kind() == "dotted_name" {
                        path.extend(extract_dotted_name_path(sub_child, source));
                    }
                }
            }
            "dotted_name" => {
                let name_path = extract_dotted_name_path(child, source);
                match stage {
                    Stage::Module => path.extend(name_path),
                    Stage::Names => imported_names.extend(name_path.last().cloned()),
                    Stage::Start => {}
                }
            }
            "aliased_import" => {
                if matches!(stage, Stage::Names) {
                    if let Some(alias) = extract_alias_name(child, source) {
                        imported_names.push(alias);
                    }
                }
            }
            "wildcard_import" => {
                imported_names.push("*".to_string());
                is_glob = true;
            }
            _ => {}
        }
    }
    Some(super::ImportFact {
        file_path: std::path::PathBuf::new(),
        import_kind,
        path,
        imported_names,
        is_glob,
        is_reexport: false,
        byte_span: (byte_start, byte_end),
    })
}

/// Counts the leading dots of a `relative_import` node's text, i.e. its
/// relative level (`"."` -> 1, `"..pkg.mod"` -> 2).
///
/// Counting stops at the first character that is neither a dot nor
/// whitespace, so dots inside the module name are ignored. Whitespace
/// between the leading dots is skipped rather than ending the count.
fn relative_import_level(text: &str) -> usize {
    text.chars()
        .take_while(|c| *c == '.' || c.is_whitespace())
        .filter(|&c| c == '.')
        .count()
}
/// Returns the text of each `identifier` child of `node`, in order.
///
/// For a `dotted_name` such as `a.b.c` this yields `["a", "b", "c"]`.
/// Non-identifier children, such as the `.` separators, are skipped.
/// Identifiers whose bytes are not valid UTF-8 are silently omitted.
fn extract_dotted_name_path(node: tree_sitter::Node, source: &[u8]) -> Vec<String> {
    let mut walker = node.walk();
    let segments: Vec<String> = node
        .children(&mut walker)
        .filter(|child| child.kind() == "identifier")
        .filter_map(|ident| ident.utf8_text(source).ok())
        .map(str::to_owned)
        .collect();
    segments
}
/// Returns the text of the third child of `node`.
///
/// For an `aliased_import` (`name as alias`) the third child is the alias.
/// Returns `None` if `node` has fewer than three children or the text is not
/// valid UTF-8.
fn extract_alias_name(node: tree_sitter::Node, source: &[u8]) -> Option<String> {
    let mut walker = node.walk();
    let alias_node = node.children(&mut walker).nth(2)?;
    alias_node.utf8_text(source).ok().map(str::to_owned)
}
/// Builds an `ImportFact` of the given `kind` from an `aliased_import` node
/// (`a.b as c`).
///
/// - The fact's `path` comes from the node's first child.
/// - Its single imported name is the text of the third child (the alias) when
///   one exists. Otherwise it is the first path segment, or an empty string
///   if the path is empty.
/// - `byte_start`/`byte_end` are stored as the fact's byte span.
///
/// Returns `None` if the alias text is not valid UTF-8.
fn extract_aliased_import(
    node: tree_sitter::Node,
    source: &[u8],
    byte_start: usize,
    byte_end: usize,
    kind: ImportKind,
) -> Option<super::ImportFact> {
    let mut walker = node.walk();
    let parts: Vec<_> = node.children(&mut walker).collect();
    let path = parts
        .first()
        .map(|&target| extract_dotted_name_path(target, source))
        .unwrap_or_default();
    let imported_name = match parts.get(2) {
        Some(alias) => alias.utf8_text(source).ok()?.to_string(),
        None => path.first().cloned().unwrap_or_default(),
    };
    Some(super::ImportFact {
        file_path: std::path::PathBuf::new(),
        import_kind: kind,
        path,
        imported_names: vec![imported_name],
        is_glob: false,
        is_reexport: false,
        byte_span: (byte_start, byte_end),
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;

    #[test]
    fn test_extract_simple_import_basic() -> TestResult {
        let source = b"import os\n";
        let path = Path::new("test.py");
        let result = extract_python_imports(path, source)?;
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].import_kind, ImportKind::PythonImport);
        Ok(())
    }

    #[test]
    fn test_extract_dotted_and_aliased_imports() -> TestResult {
        // `import a.b` records its first segment; `import x as y` records the alias.
        let source = b"import os.path, numpy as np\n";
        let result = extract_python_imports(Path::new("test.py"), source)?;
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].import_kind, ImportKind::PythonImport);
        assert_eq!(result[0].path, vec!["os", "path"]);
        assert_eq!(result[0].imported_names, vec!["os"]);
        assert_eq!(result[1].import_kind, ImportKind::PythonImport);
        assert_eq!(result[1].path, vec!["numpy"]);
        assert_eq!(result[1].imported_names, vec!["np"]);
        Ok(())
    }

    #[test]
    fn test_extract_from_import_with_alias_and_glob() -> TestResult {
        let source = b"from foo.bar import a, b as c\nfrom baz import *\n";
        let result = extract_python_imports(Path::new("test.py"), source)?;
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].import_kind, ImportKind::PythonFrom);
        assert_eq!(result[0].path, vec!["foo", "bar"]);
        assert_eq!(result[0].imported_names, vec!["a", "c"]);
        assert!(!result[0].is_glob);
        assert_eq!(result[1].import_kind, ImportKind::PythonFrom);
        assert_eq!(result[1].path, vec!["baz"]);
        assert_eq!(result[1].imported_names, vec!["*"]);
        assert!(result[1].is_glob);
        Ok(())
    }

    #[test]
    fn test_extract_relative_from_imports() -> TestResult {
        let source = b"from . import x\nfrom ..pkg import y\n";
        let result = extract_python_imports(Path::new("test.py"), source)?;
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].import_kind, ImportKind::PythonFromRelative);
        assert_eq!(result[0].path, vec!["."]);
        assert_eq!(result[0].imported_names, vec!["x"]);
        assert_eq!(result[1].import_kind, ImportKind::PythonFromParent);
        assert_eq!(result[1].path, vec!["..", "pkg"]);
        assert_eq!(result[1].imported_names, vec!["y"]);
        Ok(())
    }

    #[test]
    fn test_empty_path_handling() {
        // An incomplete `import` must not yield a fact that names something
        // without a module path. An extraction error is tolerated here: this
        // test only constrains the facts that are produced.
        let source = b"import\n";
        let path = Path::new("test.py");
        if let Ok(imports) = extract_python_imports(path, source) {
            for imp in imports.iter().filter(|imp| imp.path.is_empty()) {
                assert!(
                    imp.imported_names.is_empty()
                        || imp.imported_names.iter().all(|n| n.is_empty())
                );
            }
        }
    }
}