splice 2.6.4

Span-safe refactoring kernel for 7 languages with Magellan code graph integration
Documentation
//! JavaScript/TypeScript import statement extraction.
//!
//! Uses tree-sitter-javascript to parse and extract ES6 `import` statements and CommonJS `require()` calls.

use crate::error::{Result, SpliceError};
use crate::ingest::imports::ImportKind;
use std::path::Path;

/// Remove one pair of matching surrounding quotes (`"…"` or `'…'`) from `text`.
///
/// The input is returned unchanged (as an owned `String`) when it is not
/// wrapped in the same quote character on both ends; this includes a lone
/// quote character and the empty string. Prefix/suffix stripping operates on
/// whole characters, so multi-byte UTF-8 content is never split.
fn strip_quotes(text: &str) -> String {
    // Try each supported quote character; the first one that wraps the whole
    // text on both sides yields the inner slice.
    let inner = ['"', '\'']
        .iter()
        .find_map(|&quote| text.strip_prefix(quote)?.strip_suffix(quote));

    inner.unwrap_or(text).to_string()
}

/// Collect the imports declared in a JavaScript/TypeScript source buffer.
///
/// The buffer is parsed with the tree-sitter JavaScript grammar and the
/// resulting syntax tree is walked to gather:
/// - ES6 `import` statements (named, default, namespace, side-effect)
/// - CommonJS `require()` calls
///
/// `path` is only used to label errors.
///
/// # Errors
///
/// Returns `SpliceError::Parse` when the JavaScript grammar cannot be loaded
/// into the parser or when the parser returns no tree.
///
/// # Examples
///
/// ```
/// # use splice::ingest::imports::{extract_javascript_imports, ImportKind};
/// # use std::path::Path;
/// let source = b"import { foo } from 'bar';\n";
/// let imports = extract_javascript_imports(Path::new("test.js"), source)?;
/// assert_eq!(imports[0].import_kind, ImportKind::JsImport);
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
pub fn extract_javascript_imports(path: &Path, source: &[u8]) -> Result<Vec<super::ImportFact>> {
    // Both failure modes are reported as a parse error tagged with this file.
    let parse_error = |message: String| SpliceError::Parse {
        file: path.to_path_buf(),
        message,
    };

    // Configure a parser for the JavaScript grammar.
    let mut parser = tree_sitter::Parser::new();
    if let Err(e) = parser.set_language(&tree_sitter_javascript::language()) {
        return Err(parse_error(format!(
            "Failed to set JavaScript language: {:?}",
            e
        )));
    }

    // Build the syntax tree for the whole buffer.
    let tree = match parser.parse(source, None) {
        Some(tree) => tree,
        None => return Err(parse_error("Parse failed - no tree returned".to_string())),
    };

    // Walk the tree from the root, accumulating every import found.
    let mut imports = Vec::new();
    extract_import_statements(tree.root_node(), source, &mut imports);
    Ok(imports)
}

/// Recursively walk `node` and append every import found to `imports`.
///
/// - `import_statement` nodes are handed to `extract_import_statement` and
///   never descended into.
/// - `variable_declarator` nodes are handed to `extract_require_call`. When
///   that yields an import the declarator is considered fully handled;
///   otherwise the walk continues into the declarator's children so that
///   `require()` bindings nested inside its initializer (for example inside a
///   function assigned to a variable) are still discovered.
/// - Every other node is simply descended into.
fn extract_import_statements(
    node: tree_sitter::Node,
    source: &[u8],
    imports: &mut Vec<super::ImportFact>,
) {
    match node.kind() {
        "import_statement" => {
            if let Some(import) = extract_import_statement(node, source) {
                imports.push(import);
            }
            // Nothing below an import statement can be another import.
            return;
        }
        "variable_declarator" => {
            if let Some(import) = extract_require_call(node, source) {
                imports.push(import);
                // The declarator was a `require()` binding; it is done.
                return;
            }
            // Not a require binding: fall through and keep searching its
            // children, e.g. `const run = () => { const fs = require('fs'); }`.
        }
        _ => {}
    }

    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        extract_import_statements(child, source, imports);
    }
}

/// Build an `ImportFact` from one ES6 `import_statement` node.
///
/// The module specifier is read from the statement's `source` field; the
/// locally bound names are gathered from its `import_clause` child:
/// - a bare identifier is a default import (`import foo from 'bar'`),
/// - `named_imports` contributes each specifier's alias when present,
///   otherwise its name (`import { a, b as c } from 'bar'`),
/// - `namespace_import` contributes its identifier and marks the import as a
///   glob (`import * as ns from 'bar'`).
///
/// A statement that binds no names is classified as a side-effect import.
/// Returns `None` when no non-empty module specifier could be read.
fn extract_import_statement(node: tree_sitter::Node, source: &[u8]) -> Option<super::ImportFact> {
    // Module specifier: text of the `string_fragment` inside the source
    // string literal (the last readable fragment wins if there are several).
    let mut module_specifier = String::new();
    if let Some(literal) = node.child_by_field_name("source") {
        let mut literal_cursor = literal.walk();
        for part in literal.children(&mut literal_cursor) {
            if part.kind() != "string_fragment" {
                continue;
            }
            if let Ok(fragment) = part.utf8_text(source) {
                module_specifier = fragment.to_string();
            }
        }
    }

    // Without a specifier there is nothing to record.
    if module_specifier.is_empty() {
        return None;
    }

    let mut bound_names: Vec<String> = Vec::new();
    let mut kind = ImportKind::JsImport;
    let mut glob = false;

    // The import clause is a plain child of the statement, not a named field.
    let mut statement_cursor = node.walk();
    for clause in node
        .children(&mut statement_cursor)
        .filter(|child| child.kind() == "import_clause")
    {
        let mut clause_cursor = clause.walk();
        for binding in clause.children(&mut clause_cursor) {
            match binding.kind() {
                // Default import: `import foo from 'bar'`
                "identifier" => {
                    if let Ok(name) = binding.utf8_text(source) {
                        bound_names.push(name.to_string());
                        kind = ImportKind::JsDefaultImport;
                    }
                }
                // Named imports: `import { foo, bar } from 'baz'`
                "named_imports" => {
                    kind = ImportKind::JsImport;
                    let mut specifier_cursor = binding.walk();
                    for specifier in binding.children(&mut specifier_cursor) {
                        if specifier.kind() != "import_specifier" {
                            continue;
                        }

                        // Prefer the local alias (`x as y` -> `y`), then the
                        // plain name.
                        let local = specifier
                            .child_by_field_name("alias")
                            .or_else(|| specifier.child_by_field_name("name"));

                        match local {
                            Some(local_node) => {
                                if let Ok(name) = local_node.utf8_text(source) {
                                    bound_names.push(name.to_string());
                                }
                            }
                            None => {
                                // Neither field is present: take every
                                // identifier-like child instead.
                                let mut fallback_cursor = specifier.walk();
                                for candidate in specifier.children(&mut fallback_cursor) {
                                    let candidate_kind = candidate.kind();
                                    if candidate_kind != "identifier"
                                        && candidate_kind != "property_identifier"
                                    {
                                        continue;
                                    }
                                    if let Ok(name) = candidate.utf8_text(source) {
                                        bound_names.push(name.to_string());
                                    }
                                }
                            }
                        }
                    }
                }
                // Namespace import: `import * as foo from 'bar'`
                "namespace_import" => {
                    kind = ImportKind::JsNamespaceImport;
                    glob = true;
                    let mut namespace_cursor = binding.walk();
                    for part in binding.children(&mut namespace_cursor) {
                        if part.kind() != "identifier" {
                            continue;
                        }
                        if let Ok(name) = part.utf8_text(source) {
                            bound_names.push(name.to_string());
                        }
                    }
                }
                _ => {}
            }
        }
    }

    // Side-effect import (`import 'bar'`): a specifier but no bound names.
    if bound_names.is_empty() {
        kind = ImportKind::JsSideEffectImport;
    }

    Some(super::ImportFact {
        file_path: std::path::PathBuf::new(),
        import_kind: kind,
        // The specifier is split on `/` into its path segments.
        path: module_specifier.split('/').map(str::to_string).collect(),
        imported_names: bound_names,
        is_glob: glob,
        is_reexport: false,
        byte_span: (node.start_byte(), node.end_byte()),
    })
}

/// Extract a CommonJS `require()` binding from a `variable_declarator` node.
///
/// Matches declarators whose initializer is a direct call to the bare
/// identifier `require` with a string-literal argument, e.g.
/// `const fs = require('fs')`. The callee is validated through the call's
/// `function` field, so calls through any other callee — a different
/// function (`load('x')`) or a member expression
/// (`document.getElementById('app')`, `path.join('a')`) — are rejected rather
/// than being reported as imports.
///
/// Returns `None` when the declarator has no call initializer, the callee is
/// not `require`, or no non-empty string argument was found.
///
/// The returned fact spans the declarator itself; its path is the module
/// string split on `/`, and its single imported name is the declared
/// identifier.
fn extract_require_call(node: tree_sitter::Node, source: &[u8]) -> Option<super::ImportFact> {
    let byte_start = node.start_byte();
    let byte_end = node.end_byte();

    let mut source_path = String::new();
    // NOTE(review): stays empty for destructuring declarations such as
    // `const { a } = require('x')`, where the binding is a pattern rather
    // than an identifier — confirm whether consumers expect that.
    let mut variable_name = String::new();

    // Look for pattern: const x = require('module')
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        match child.kind() {
            "identifier" => {
                if let Ok(name) = child.utf8_text(source) {
                    variable_name = name.to_string();
                }
            }
            "call_expression" => {
                // Only a call whose callee is exactly the identifier
                // `require` counts; anything else is an ordinary call.
                let callee = child.child_by_field_name("function")?;
                let is_require = callee.kind() == "identifier"
                    && callee.utf8_text(source).ok() == Some("require");
                if !is_require {
                    return None;
                }

                // Read the module name from the string-literal argument.
                let mut call_cursor = child.walk();
                for part in child.children(&mut call_cursor) {
                    if part.kind() != "arguments" {
                        continue;
                    }
                    let mut arg_cursor = part.walk();
                    for arg in part.children(&mut arg_cursor) {
                        if arg.kind() == "string" {
                            if let Ok(text) = arg.utf8_text(source) {
                                source_path = strip_quotes(text);
                            }
                        }
                    }
                }
            }
            _ => {}
        }
    }

    if source_path.is_empty() {
        return None;
    }

    let path_parts: Vec<String> = source_path.split('/').map(str::to_string).collect();

    Some(super::ImportFact {
        file_path: std::path::PathBuf::new(),
        import_kind: ImportKind::JsRequire,
        path: path_parts,
        imported_names: vec![variable_name],
        is_glob: false,
        is_reexport: false,
        byte_span: (byte_start, byte_end),
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Return type shared by every test so failures can be propagated with `?`.
    type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;

    /// Run the extractor on `source` as if it were the file `test.js`,
    /// asserting that extraction succeeded before handing the result back.
    fn imports_of(source: &[u8]) -> Result<Vec<super::super::ImportFact>> {
        let outcome = extract_javascript_imports(Path::new("test.js"), source);
        assert!(outcome.is_ok());
        outcome
    }

    // `import { foo } from 'bar'` is a named import of `foo` from `bar`.
    #[test]
    fn test_extract_named_import() -> TestResult {
        let imports = imports_of(b"import { foo } from 'bar';\n")?;
        assert_eq!(imports.len(), 1);
        let first = &imports[0];
        assert_eq!(first.import_kind, ImportKind::JsImport);
        assert_eq!(first.path, vec!["bar"]);
        assert_eq!(first.imported_names, vec!["foo"]);
        Ok(())
    }

    // `import baz from 'module'` is a default import bound to `baz`.
    #[test]
    fn test_extract_default_import() -> TestResult {
        let imports = imports_of(b"import baz from 'module';\n")?;
        assert_eq!(imports.len(), 1);
        let first = &imports[0];
        assert_eq!(first.import_kind, ImportKind::JsDefaultImport);
        assert_eq!(first.imported_names, vec!["baz"]);
        Ok(())
    }

    // `import * as utils from './utils'` is a glob import bound to `utils`.
    #[test]
    fn test_extract_namespace_import() -> TestResult {
        let imports = imports_of(b"import * as utils from './utils';\n")?;
        assert_eq!(imports.len(), 1);
        let first = &imports[0];
        assert_eq!(first.import_kind, ImportKind::JsNamespaceImport);
        assert!(first.is_glob);
        assert_eq!(first.imported_names, vec!["utils"]);
        Ok(())
    }

    // `import 'polyfills'` binds nothing and is a side-effect import.
    #[test]
    fn test_extract_side_effect_import() -> TestResult {
        let imports = imports_of(b"import 'polyfills';\n")?;
        assert_eq!(imports.len(), 1);
        let first = &imports[0];
        assert_eq!(first.import_kind, ImportKind::JsSideEffectImport);
        assert_eq!(first.path, vec!["polyfills"]);
        Ok(())
    }

    // `const fs = require('fs')` is a CommonJS require bound to `fs`.
    #[test]
    fn test_extract_require_call() -> TestResult {
        let imports = imports_of(b"const fs = require('fs');\n")?;
        assert_eq!(imports.len(), 1);
        let first = &imports[0];
        assert_eq!(first.import_kind, ImportKind::JsRequire);
        assert_eq!(first.path, vec!["fs"]);
        assert_eq!(first.imported_names, vec!["fs"]);
        Ok(())
    }

    // Two import statements in one file yield two facts.
    #[test]
    fn test_extract_multiple_imports() -> TestResult {
        let imports =
            imports_of(b"import { foo, bar } from 'baz';\nimport qux from 'module';\n")?;
        assert_eq!(imports.len(), 2);
        Ok(())
    }

    // A specifier containing `/` is split into its path segments.
    #[test]
    fn test_extract_nested_path_import() -> TestResult {
        let imports = imports_of(b"import { Component } from '@react/core';\n")?;
        assert_eq!(imports.len(), 1);
        assert_eq!(imports[0].path, vec!["@react", "core"]);
        Ok(())
    }
}