use once_cell::sync::Lazy;
use rayon::prelude::*;
use scribe_core::Result;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tree_sitter::{Language, Node, Parser, Tree, TreeCursor};
/// One import discovered in a source file.
///
/// Instances are produced by the language-specific extractors of
/// `SimpleAstParser` and carry only the imported text plus its location.
#[derive(Debug, Clone)]
pub struct SimpleImport {
/// Source text of the imported module/path, as sliced from the parsed content
/// (JS/TS and Go extractors strip the surrounding quote characters first).
pub module: String,
/// Line of the import, computed by the extractors as the node's start row plus one.
pub line_number: usize,
}
/// Languages for which import extraction is implemented.
///
/// The enum is `Copy` and hashable, so it can be used directly as the key
/// of the parser pool map.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ImportLanguage {
/// Python sources (`py`, `pyi`, `pyw` in `from_extension`).
Python,
/// JavaScript sources (`js`, `mjs`, `cjs` in `from_extension`).
JavaScript,
/// TypeScript sources (`ts`, `mts`, `cts` in `from_extension`).
TypeScript,
/// Go sources (`go` in `from_extension`).
Go,
/// Rust sources (`rs` in `from_extension`).
Rust,
}
impl ImportLanguage {
    /// Returns the tree-sitter grammar that parses this language.
    pub fn tree_sitter_language(&self) -> Language {
        match *self {
            Self::Python => tree_sitter_python::language(),
            Self::JavaScript => tree_sitter_javascript::language(),
            Self::TypeScript => tree_sitter_typescript::language_typescript(),
            Self::Go => tree_sitter_go::language(),
            Self::Rust => tree_sitter_rust::language(),
        }
    }

    /// Maps a file extension (without the leading dot) to a language.
    ///
    /// The comparison is case-insensitive because the extension is
    /// lowercased before matching. Unknown extensions yield `None`.
    pub fn from_extension(ext: &str) -> Option<Self> {
        let normalized = ext.to_lowercase();
        let language = match normalized.as_str() {
            "py" | "pyi" | "pyw" => Self::Python,
            "js" | "mjs" | "cjs" => Self::JavaScript,
            "ts" | "mts" | "cts" => Self::TypeScript,
            "go" => Self::Go,
            "rs" => Self::Rust,
            _ => return None,
        };
        Some(language)
    }
}
/// Process-wide pool of idle tree-sitter parsers, keyed by language.
///
/// `get_parser` pops a parser from the per-language `Vec` (or builds a new
/// one when none is idle) and `return_parser` pushes it back, so parsers are
/// reused across calls. The map is created empty on first access and guarded
/// by a `Mutex`.
static PARSER_POOL: Lazy<Arc<Mutex<HashMap<ImportLanguage, Vec<Parser>>>>> =
Lazy::new(|| Arc::new(Mutex::new(HashMap::new())));
/// Node kinds that `node_can_contain_imports` always accepts when deciding
/// whether the traversal should visit a node and its children.
///
/// The list mixes the import node kinds themselves with `source_file` and
/// `module`, which are presumably the root kinds of some grammars (verify
/// against the grammars in use).
const IMPORT_NODE_TYPES: &[&str] = &[
"import_statement",
"import_from_statement",
"use_declaration",
"import_declaration",
"import_spec",
"source_file",
"module",
];
/// Import extractor backed by tree-sitter.
///
/// The struct has no fields: parsers live in the `PARSER_POOL` static and
/// are borrowed per call, so a value of this type is only a handle for the
/// extraction methods.
pub struct SimpleAstParser {
}
impl std::fmt::Debug for SimpleAstParser {
    /// Formats the parser as a struct with a single placeholder `parsers`
    /// field, since the struct itself holds no data.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("SimpleAstParser");
        repr.field("parsers", &"[reusable pool]");
        repr.finish()
    }
}
impl SimpleAstParser {
/// Creates a parser handle, making sure the shared parser pool has an
/// entry for every supported language first.
///
/// # Errors
///
/// Propagates the error from pool initialisation when a grammar cannot be
/// installed on a parser.
pub fn new() -> Result<Self> {
    Self::ensure_parser_pool_initialized().map(|()| Self {})
}
fn ensure_parser_pool_initialized() -> Result<()> {
let mut pool = PARSER_POOL.lock().unwrap();
for language in [
ImportLanguage::Python,
ImportLanguage::JavaScript,
ImportLanguage::TypeScript,
ImportLanguage::Go,
ImportLanguage::Rust,
] {
if !pool.contains_key(&language) {
let mut parser = Parser::new();
parser
.set_language(language.tree_sitter_language())
.map_err(|e| {
scribe_core::ScribeError::parse(format!(
"Failed to set tree-sitter language: {}",
e
))
})?;
pool.insert(language, vec![parser]);
}
}
Ok(())
}
/// Takes a parser for `language` out of the shared pool, building a fresh
/// one when no idle parser is available.
///
/// The caller owns the returned parser and is expected to hand it back
/// with `return_parser` once done.
///
/// # Errors
///
/// Returns a parse error if a new parser has to be built and tree-sitter
/// rejects the grammar.
///
/// # Panics
///
/// Panics if the pool mutex is poisoned.
fn get_parser(&self, language: ImportLanguage) -> Result<Parser> {
    let mut pool = PARSER_POOL.lock().unwrap();
    match pool.get_mut(&language).and_then(|idle| idle.pop()) {
        Some(reused) => Ok(reused),
        None => {
            let mut fresh = Parser::new();
            fresh
                .set_language(language.tree_sitter_language())
                .map_err(|e| {
                    scribe_core::ScribeError::parse(format!(
                        "Failed to set tree-sitter language: {}",
                        e
                    ))
                })?;
            Ok(fresh)
        }
    }
}
/// Puts `parser` back on the idle stack for `language`, creating the
/// stack if the pool has no entry for that language yet.
///
/// # Panics
///
/// Panics if the pool mutex is poisoned.
fn return_parser(&self, language: ImportLanguage, parser: Parser) {
    let mut pool = PARSER_POOL.lock().unwrap();
    let idle = pool.entry(language).or_default();
    idle.push(parser);
}
/// Parses `content` as `language` and returns every import found.
///
/// A parser is borrowed from the shared pool for the duration of the parse
/// only. The resulting tree is an owned value, so the parser is handed
/// back before the tree is walked; this keeps it in the pool on every
/// exit path, including traversal errors.
///
/// # Errors
///
/// Returns a parse error when no parser can be obtained, when tree-sitter
/// produces no tree for `content`, or when a node extractor fails.
///
/// # Panics
///
/// Panics if the pool mutex is poisoned.
pub fn extract_imports(
    &self,
    content: &str,
    language: ImportLanguage,
) -> Result<Vec<SimpleImport>> {
    let mut parser = self.get_parser(language)?;
    let parsed = parser.parse(content, None);
    if parsed.is_none() {
        // A parse that produced no tree may have left partial state behind;
        // clear it so the next user of this parser starts from scratch.
        parser.reset();
    }
    self.return_parser(language, parser);

    let tree = parsed
        .ok_or_else(|| scribe_core::ScribeError::parse("Failed to parse content"))?;
    let mut imports = Vec::new();
    let mut cursor = tree.walk();
    self.extract_imports_with_cursor(&mut cursor, content, language, &mut imports)?;
    Ok(imports)
}
/// Recursively visits the node under `cursor` and its descendants,
/// appending every import it recognises to `imports`.
///
/// Subtrees whose root kind is rejected by `node_can_contain_imports` are
/// skipped entirely. On success the cursor is left on the node it started
/// on; when an extractor fails the error is propagated immediately and the
/// cursor position is not restored.
fn extract_imports_with_cursor(
    &self,
    cursor: &mut TreeCursor,
    content: &str,
    language: ImportLanguage,
    imports: &mut Vec<SimpleImport>,
) -> Result<()> {
    let current = cursor.node();
    let kind = current.kind();

    if !self.node_can_contain_imports(kind) {
        return Ok(());
    }
    if self.is_import_node(kind) {
        self.extract_import_from_node(current, content, language, imports)?;
    }

    // Leaf node: nothing below to inspect.
    if !cursor.goto_first_child() {
        return Ok(());
    }
    let mut has_sibling = true;
    while has_sibling {
        self.extract_imports_with_cursor(cursor, content, language, imports)?;
        has_sibling = cursor.goto_next_sibling();
    }
    cursor.goto_parent();
    Ok(())
}
/// Decides whether the traversal should visit a node of the given kind.
///
/// A kind is accepted when it is listed in `IMPORT_NODE_TYPES`, is one of
/// a few extra container kinds, or merely contains the substring `import`
/// or `use`. The substring test is broad: it also accepts any kind whose
/// name happens to include those letters, such as names ending in
/// `clause`.
fn node_can_contain_imports(&self, kind: &str) -> bool {
    const EXTRA_CONTAINERS: [&str; 4] =
        ["program", "translation_unit", "block", "statement_block"];
    const SUBSTRINGS: [&str; 2] = ["import", "use"];

    let listed = IMPORT_NODE_TYPES
        .iter()
        .chain(EXTRA_CONTAINERS.iter())
        .any(|known| *known == kind);
    listed || SUBSTRINGS.iter().any(|needle| kind.contains(*needle))
}
/// Returns `true` when `kind` names a node that is itself an import
/// construct and should be passed to a language-specific extractor.
fn is_import_node(&self, kind: &str) -> bool {
    const IMPORT_KINDS: [&str; 5] = [
        "import_statement",
        "import_from_statement",
        "use_declaration",
        "import_declaration",
        "import_spec",
    ];
    IMPORT_KINDS.contains(&kind)
}
/// Dispatches an import node to the extractor for `language`.
///
/// JavaScript and TypeScript share one extractor; every other language
/// has its own.
fn extract_import_from_node(
    &self,
    node: Node,
    content: &str,
    language: ImportLanguage,
    imports: &mut Vec<SimpleImport>,
) -> Result<()> {
    match language {
        ImportLanguage::Python => self.extract_python_import_node(node, content, imports),
        ImportLanguage::JavaScript | ImportLanguage::TypeScript => {
            self.extract_js_ts_import_node(node, content, imports)
        }
        ImportLanguage::Go => self.extract_go_import_node(node, content, imports),
        ImportLanguage::Rust => self.extract_rust_import_node(node, content, imports),
    }
}
/// Extracts module names from a Python import node.
///
/// * `import_statement`: one entry per imported name. Plain names
///   (`import a.b`) are direct `dotted_name`/`identifier` children.
///   Aliased names (`import a.b as c`) are wrapped in an `aliased_import`
///   node, whose `name` field holds the module; the alias itself is not
///   recorded.
/// * `import_from_statement`: a single entry for the `module_name` field
///   (the part after `from`).
///
/// Nodes of any other kind are ignored. Line numbers are the start row of
/// the relevant node plus one.
fn extract_python_import_node(
    &self,
    node: Node,
    content: &str,
    imports: &mut Vec<SimpleImport>,
) -> Result<()> {
    match node.kind() {
        "import_statement" => {
            let mut cursor = node.walk();
            for child in node.children(&mut cursor) {
                let name_node = match child.kind() {
                    "dotted_name" | "identifier" => Some(child),
                    // `import x as y`: the module is the alias node's `name` field.
                    "aliased_import" => child.child_by_field_name("name"),
                    _ => None,
                };
                if let Some(name_node) = name_node {
                    imports.push(SimpleImport {
                        module: self.node_text(name_node, content),
                        line_number: name_node.start_position().row + 1,
                    });
                }
            }
        }
        "import_from_statement" => {
            if let Some(module_node) = node.child_by_field_name("module_name") {
                imports.push(SimpleImport {
                    module: self.node_text(module_node, content),
                    line_number: node.start_position().row + 1,
                });
            }
        }
        _ => {}
    }
    Ok(())
}
/// Extracts the module specifier from a JavaScript/TypeScript
/// `import_statement`.
///
/// The first direct child of kind `string` is taken as the specifier;
/// surrounding double quotes, then single quotes, are stripped. At most
/// one import is recorded per statement, using the statement's start row
/// plus one as the line number. Other node kinds are ignored.
fn extract_js_ts_import_node(
    &self,
    node: Node,
    content: &str,
    imports: &mut Vec<SimpleImport>,
) -> Result<()> {
    if node.kind() != "import_statement" {
        return Ok(());
    }

    let mut walker = node.walk();
    let mut has_child = walker.goto_first_child();
    while has_child {
        let candidate = walker.node();
        if candidate.kind() == "string" {
            let raw = self.node_text(candidate, content);
            let specifier = raw.trim_matches('"').trim_matches('\'');
            imports.push(SimpleImport {
                module: specifier.to_string(),
                line_number: node.start_position().row + 1,
            });
            // Only the first string child is the specifier.
            return Ok(());
        }
        has_child = walker.goto_next_sibling();
    }
    Ok(())
}
/// Extracts the import path from a Go `import_spec` node.
///
/// Both spellings of a Go string literal are recognised:
/// `interpreted_string_literal` (`"fmt"`) and `raw_string_literal`
/// (`` `fmt` ``). The surrounding delimiter is stripped from the recorded
/// module, and the line number is the literal's start row plus one.
/// Nodes of any other kind are ignored.
fn extract_go_import_node(
    &self,
    node: Node,
    content: &str,
    imports: &mut Vec<SimpleImport>,
) -> Result<()> {
    if node.kind() != "import_spec" {
        return Ok(());
    }

    let mut cursor = node.walk();
    if !cursor.goto_first_child() {
        return Ok(());
    }
    loop {
        let child = cursor.node();
        // Pick the delimiter to strip based on which literal form was used.
        let delimiter = match child.kind() {
            "interpreted_string_literal" => Some('"'),
            "raw_string_literal" => Some('`'),
            _ => None,
        };
        if let Some(delimiter) = delimiter {
            let literal = self.node_text(child, content);
            imports.push(SimpleImport {
                module: literal.trim_matches(delimiter).to_string(),
                line_number: child.start_position().row + 1,
            });
        }
        if !cursor.goto_next_sibling() {
            break;
        }
    }
    Ok(())
}
/// Extracts the path from a Rust `use_declaration` node.
///
/// The full text of the declaration's `argument` field is recorded as the
/// module, with the declaration's start row plus one as the line number.
/// Other node kinds, and declarations without that field, add nothing.
fn extract_rust_import_node(
    &self,
    node: Node,
    content: &str,
    imports: &mut Vec<SimpleImport>,
) -> Result<()> {
    if node.kind() != "use_declaration" {
        return Ok(());
    }
    if let Some(argument) = node.child_by_field_name("argument") {
        imports.push(SimpleImport {
            module: self.node_text(argument, content),
            line_number: node.start_position().row + 1,
        });
    }
    Ok(())
}
/// Returns an owned copy of the source text covered by `node`.
///
/// # Panics
///
/// Panics if the node's byte range is out of bounds for `content` or does
/// not fall on UTF-8 character boundaries.
fn node_text(&self, node: Node, content: &str) -> String {
    let (start, end) = (node.start_byte(), node.end_byte());
    String::from(&content[start..end])
}
/// Extracts imports from many files in parallel.
///
/// Each element of `files` is `(path, content, language)`. The result
/// pairs every path with the imports found in its content.
///
/// # Errors
///
/// Returns an error if extraction fails for any file.
pub fn extract_imports_parallel(
    &self,
    files: &[(String, String, ImportLanguage)],
) -> Result<Vec<(String, Vec<SimpleImport>)>> {
    files
        .par_iter()
        .map(|entry| {
            let (path, content, language) = entry;
            self.extract_imports(content, *language)
                .map(|found| (path.clone(), found))
        })
        .collect()
}
/// Extracts imports from several sources of the same language in
/// parallel, returning one import list per input.
///
/// # Errors
///
/// Returns an error if extraction fails for any input.
pub fn extract_imports_batch(
    &self,
    contents: &[&str],
    language: ImportLanguage,
) -> Result<Vec<Vec<SimpleImport>>> {
    let per_source: Result<Vec<Vec<SimpleImport>>> = contents
        .par_iter()
        .map(|source| self.extract_imports(*source, language))
        .collect();
    per_source
}
}
impl Default for SimpleAstParser {
    /// Builds a parser via `SimpleAstParser::new`.
    ///
    /// # Panics
    ///
    /// Panics if `new` returns an error.
    fn default() -> Self {
        let created = Self::new();
        created.expect("Failed to create SimpleAstParser")
    }
}