Skip to main content

aft/
imports.rs

1//! Import analysis engine: parsing, grouping, deduplication, and insertion.
2//!
3//! Provides per-language import handling dispatched by `LangId`. Each language
4//! implementation extracts imports from tree-sitter ASTs, classifies them into
5//! groups, and generates import text.
6//!
//! Currently supports: TypeScript, TSX, JavaScript, Python, Rust, and Go.
8
9use std::ops::Range;
10
11use tree_sitter::{Node, Parser, Tree};
12
13use crate::parser::{grammar_for, LangId};
14
15// ---------------------------------------------------------------------------
16// Shared types
17// ---------------------------------------------------------------------------
18
/// Classification of an import statement by what it brings into scope.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ImportKind {
    /// A runtime binding, e.g. `import X from 'y'` or `import { X } from 'y'`.
    Value,
    /// A type-only binding, e.g. `import type { X } from 'y'`.
    Type,
    /// An import evaluated only for its side effects, e.g. `import './side-effect'`.
    SideEffect,
}
29
/// Logical bucket an import is sorted into (language-specific).
///
/// Variants are declared in sort order, so the derived `Ord` gives
/// `Stdlib < External < Internal`, matching conventional import grouping.
///
/// Per-language mapping:
///   - TS/JS/TSX: External (no `.` prefix), Internal (`.`/`..` prefix)
///   - Python:    Stdlib, External (third-party), Internal (relative `.`/`..`)
///   - Rust:      Stdlib (std/core/alloc), External (crates), Internal (crate/self/super)
///   - Go:        Stdlib (no dots in path), External (dots in path)
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum ImportGroup {
    /// Standard library (Python stdlib, Rust std/core/alloc, Go stdlib).
    /// Never produced for TS/JS.
    Stdlib,
    /// Third-party / external packages.
    External,
    /// Project-local imports (TS relative, Python relative, Rust crate/self/super).
    Internal,
}

impl ImportGroup {
    /// Short lowercase label for the group: `"stdlib"`, `"external"` or `"internal"`.
    pub fn label(&self) -> &'static str {
        match *self {
            Self::Stdlib => "stdlib",
            Self::External => "external",
            Self::Internal => "internal",
        }
    }
}
61
62/// A single parsed import statement.
63#[derive(Debug, Clone)]
64pub struct ImportStatement {
65    /// The module path (e.g., `react`, `./utils`, `../config`).
66    pub module_path: String,
67    /// Named imports (e.g., `["useState", "useEffect"]`).
68    pub names: Vec<String>,
69    /// Default import name (e.g., `React` from `import React from 'react'`).
70    pub default_import: Option<String>,
71    /// Namespace import name (e.g., `path` from `import * as path from 'path'`).
72    pub namespace_import: Option<String>,
73    /// What kind: value, type, or side-effect.
74    pub kind: ImportKind,
75    /// Which group this import belongs to.
76    pub group: ImportGroup,
77    /// Byte range in the original source.
78    pub byte_range: Range<usize>,
79    /// Raw text of the import statement.
80    pub raw_text: String,
81}
82
83/// A block of parsed imports from a file.
84#[derive(Debug, Clone)]
85pub struct ImportBlock {
86    /// All parsed import statements, in source order.
87    pub imports: Vec<ImportStatement>,
88    /// Overall byte range covering all import statements (start of first to end of last).
89    /// `None` if no imports found.
90    pub byte_range: Option<Range<usize>>,
91}
92
93impl ImportBlock {
94    pub fn empty() -> Self {
95        ImportBlock {
96            imports: Vec::new(),
97            byte_range: None,
98        }
99    }
100}
101
102fn import_byte_range(imports: &[ImportStatement]) -> Option<Range<usize>> {
103    imports.first().zip(imports.last()).map(|(first, last)| {
104        let start = first.byte_range.start;
105        let end = last.byte_range.end;
106        start..end
107    })
108}
109
110// ---------------------------------------------------------------------------
111// Core API
112// ---------------------------------------------------------------------------
113
114/// Parse imports from source using the provided tree-sitter tree.
115pub fn parse_imports(source: &str, tree: &Tree, lang: LangId) -> ImportBlock {
116    match lang {
117        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => parse_ts_imports(source, tree),
118        LangId::Python => parse_py_imports(source, tree),
119        LangId::Rust => parse_rs_imports(source, tree),
120        LangId::Go => parse_go_imports(source, tree),
121        LangId::C | LangId::Cpp | LangId::Zig | LangId::CSharp | LangId::Bash => {
122            ImportBlock::empty()
123        }
124        LangId::Html | LangId::Markdown => ImportBlock::empty(),
125    }
126}
127
128/// Check if an import with the given module + name combination already exists.
129///
130/// For dedup: same module path AND (same named import OR same default import).
131/// Side-effect imports match on module path alone.
132pub fn is_duplicate(
133    block: &ImportBlock,
134    module_path: &str,
135    names: &[String],
136    default_import: Option<&str>,
137    type_only: bool,
138) -> bool {
139    let target_kind = if type_only {
140        ImportKind::Type
141    } else {
142        ImportKind::Value
143    };
144
145    for imp in &block.imports {
146        if imp.module_path != module_path {
147            continue;
148        }
149
150        // For side-effect imports or whole-module imports (no names, no default):
151        // module path match alone is sufficient.
152        if names.is_empty()
153            && default_import.is_none()
154            && imp.names.is_empty()
155            && imp.default_import.is_none()
156        {
157            return true;
158        }
159
160        // For side-effect imports specifically (TS/JS): module match is enough
161        if names.is_empty() && default_import.is_none() && imp.kind == ImportKind::SideEffect {
162            return true;
163        }
164
165        // Kind must match for dedup (value imports don't dedup against type imports)
166        if imp.kind != target_kind && imp.kind != ImportKind::SideEffect {
167            continue;
168        }
169
170        // Check default import match
171        if let Some(def) = default_import {
172            if imp.default_import.as_deref() == Some(def) {
173                return true;
174            }
175        }
176
177        // Check named imports — if ALL requested names already exist
178        if !names.is_empty() && names.iter().all(|n| imp.names.contains(n)) {
179            return true;
180        }
181    }
182
183    false
184}
185
186/// Find the byte offset where a new import should be inserted.
187///
188/// Strategy:
189/// - Find all existing imports in the same group.
190/// - Within that group, find the alphabetical position by module path.
191/// - Type imports sort after value imports within the same group and module-sort position.
192/// - If no imports exist in the target group, insert after the last import of the
193///   nearest preceding group (or before the first import of the nearest following
194///   group, or at file start if no groups exist).
195/// - Returns (byte_offset, needs_newline_before, needs_newline_after)
196pub fn find_insertion_point(
197    source: &str,
198    block: &ImportBlock,
199    group: ImportGroup,
200    module_path: &str,
201    type_only: bool,
202) -> (usize, bool, bool) {
203    if block.imports.is_empty() {
204        // No imports at all — insert at start of file
205        return (0, false, source.is_empty().then_some(false).unwrap_or(true));
206    }
207
208    let target_kind = if type_only {
209        ImportKind::Type
210    } else {
211        ImportKind::Value
212    };
213
214    // Collect imports in the target group
215    let group_imports: Vec<&ImportStatement> =
216        block.imports.iter().filter(|i| i.group == group).collect();
217
218    if group_imports.is_empty() {
219        // No imports in this group yet — find nearest neighbor group
220        // Try preceding groups (lower ordinal) first
221        let preceding_last = block.imports.iter().filter(|i| i.group < group).last();
222
223        if let Some(last) = preceding_last {
224            let end = last.byte_range.end;
225            let insert_at = skip_newline(source, end);
226            return (insert_at, true, true);
227        }
228
229        // No preceding group — try following groups (higher ordinal)
230        let following_first = block.imports.iter().find(|i| i.group > group);
231
232        if let Some(first) = following_first {
233            return (first.byte_range.start, false, true);
234        }
235
236        // Shouldn't reach here if block is non-empty, but handle gracefully
237        let first_byte = import_byte_range(&block.imports)
238            .map(|range| range.start)
239            .unwrap_or(0);
240        return (first_byte, false, true);
241    }
242
243    // Find position within the group (alphabetical by module path, type after value)
244    for imp in &group_imports {
245        let cmp = module_path.cmp(&imp.module_path);
246        match cmp {
247            std::cmp::Ordering::Less => {
248                // Insert before this import
249                return (imp.byte_range.start, false, false);
250            }
251            std::cmp::Ordering::Equal => {
252                // Same module — type imports go after value imports
253                if target_kind == ImportKind::Type && imp.kind == ImportKind::Value {
254                    // Insert after this value import
255                    let end = imp.byte_range.end;
256                    let insert_at = skip_newline(source, end);
257                    return (insert_at, false, false);
258                }
259                // Insert before (or it's a duplicate, caller should have checked)
260                return (imp.byte_range.start, false, false);
261            }
262            std::cmp::Ordering::Greater => continue,
263        }
264    }
265
266    // Module path sorts after all existing imports in this group — insert at end
267    let Some(last) = group_imports.last() else {
268        return (
269            import_byte_range(&block.imports)
270                .map(|range| range.end)
271                .unwrap_or(0),
272            false,
273            false,
274        );
275    };
276    let end = last.byte_range.end;
277    let insert_at = skip_newline(source, end);
278    (insert_at, false, false)
279}
280
281/// Generate an import line for the given language.
282pub fn generate_import_line(
283    lang: LangId,
284    module_path: &str,
285    names: &[String],
286    default_import: Option<&str>,
287    type_only: bool,
288) -> String {
289    match lang {
290        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
291            generate_ts_import_line(module_path, names, default_import, type_only)
292        }
293        LangId::Python => generate_py_import_line(module_path, names, default_import),
294        LangId::Rust => generate_rs_import_line(module_path, names, type_only),
295        LangId::Go => generate_go_import_line(module_path, default_import, false),
296        LangId::C | LangId::Cpp | LangId::Zig | LangId::CSharp | LangId::Bash => String::new(),
297        LangId::Html | LangId::Markdown => String::new(),
298    }
299}
300
301/// Check if the given language is supported by the import engine.
302pub fn is_supported(lang: LangId) -> bool {
303    matches!(
304        lang,
305        LangId::TypeScript
306            | LangId::Tsx
307            | LangId::JavaScript
308            | LangId::Python
309            | LangId::Rust
310            | LangId::Go
311    )
312}
313
314/// Classify a module path into a group for TS/JS/TSX.
315pub fn classify_group_ts(module_path: &str) -> ImportGroup {
316    if module_path.starts_with('.') {
317        ImportGroup::Internal
318    } else {
319        ImportGroup::External
320    }
321}
322
323/// Classify a module path into a group for the given language.
324pub fn classify_group(lang: LangId, module_path: &str) -> ImportGroup {
325    match lang {
326        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => classify_group_ts(module_path),
327        LangId::Python => classify_group_py(module_path),
328        LangId::Rust => classify_group_rs(module_path),
329        LangId::Go => classify_group_go(module_path),
330        LangId::C | LangId::Cpp | LangId::Zig | LangId::CSharp | LangId::Bash => {
331            ImportGroup::External
332        }
333        LangId::Html | LangId::Markdown => ImportGroup::External,
334    }
335}
336
337/// Parse a file from disk and return its import block.
338/// Convenience wrapper that handles parsing.
339pub fn parse_file_imports(
340    path: &std::path::Path,
341    lang: LangId,
342) -> Result<(String, Tree, ImportBlock), crate::error::AftError> {
343    let source =
344        std::fs::read_to_string(path).map_err(|e| crate::error::AftError::FileNotFound {
345            path: format!("{}: {}", path.display(), e),
346        })?;
347
348    let grammar = grammar_for(lang);
349    let mut parser = Parser::new();
350    parser
351        .set_language(&grammar)
352        .map_err(|e| crate::error::AftError::ParseError {
353            message: format!("grammar init failed for {:?}: {}", lang, e),
354        })?;
355
356    let tree = parser
357        .parse(&source, None)
358        .ok_or_else(|| crate::error::AftError::ParseError {
359            message: format!("tree-sitter parse returned None for {}", path.display()),
360        })?;
361
362    let block = parse_imports(&source, &tree, lang);
363    Ok((source, tree, block))
364}
365
366// ---------------------------------------------------------------------------
367// TS/JS/TSX implementation
368// ---------------------------------------------------------------------------
369
370/// Parse imports from a TS/JS/TSX file.
371///
372/// Walks the AST root's direct children looking for `import_statement` nodes (D041).
373fn parse_ts_imports(source: &str, tree: &Tree) -> ImportBlock {
374    let root = tree.root_node();
375    let mut imports = Vec::new();
376
377    let mut cursor = root.walk();
378    if !cursor.goto_first_child() {
379        return ImportBlock::empty();
380    }
381
382    loop {
383        let node = cursor.node();
384        if node.kind() == "import_statement" {
385            if let Some(imp) = parse_single_ts_import(source, &node) {
386                imports.push(imp);
387            }
388        }
389        if !cursor.goto_next_sibling() {
390            break;
391        }
392    }
393
394    let byte_range = import_byte_range(&imports);
395
396    ImportBlock {
397        imports,
398        byte_range,
399    }
400}
401
402/// Parse a single `import_statement` node into an `ImportStatement`.
403fn parse_single_ts_import(source: &str, node: &Node) -> Option<ImportStatement> {
404    let raw_text = source[node.byte_range()].to_string();
405    let byte_range = node.byte_range();
406
407    // Find the source module (string/string_fragment child of the import)
408    let module_path = extract_module_path(source, node)?;
409
410    // Determine if this is a type-only import: `import type ...`
411    let is_type_only = has_type_keyword(node);
412
413    // Extract import clause details
414    let mut names = Vec::new();
415    let mut default_import = None;
416    let mut namespace_import = None;
417
418    let mut child_cursor = node.walk();
419    if child_cursor.goto_first_child() {
420        loop {
421            let child = child_cursor.node();
422            match child.kind() {
423                "import_clause" => {
424                    extract_import_clause(
425                        source,
426                        &child,
427                        &mut names,
428                        &mut default_import,
429                        &mut namespace_import,
430                    );
431                }
432                // In some grammars, the default import is a direct identifier child
433                "identifier" => {
434                    let text = &source[child.byte_range()];
435                    if text != "import" && text != "from" && text != "type" {
436                        default_import = Some(text.to_string());
437                    }
438                }
439                _ => {}
440            }
441            if !child_cursor.goto_next_sibling() {
442                break;
443            }
444        }
445    }
446
447    // Classify kind
448    let kind = if names.is_empty() && default_import.is_none() && namespace_import.is_none() {
449        ImportKind::SideEffect
450    } else if is_type_only {
451        ImportKind::Type
452    } else {
453        ImportKind::Value
454    };
455
456    let group = classify_group_ts(&module_path);
457
458    Some(ImportStatement {
459        module_path,
460        names,
461        default_import,
462        namespace_import,
463        kind,
464        group,
465        byte_range,
466        raw_text,
467    })
468}
469
470/// Extract the module path string from an import_statement node.
471///
472/// Looks for a `string` child node and extracts the content without quotes.
473fn extract_module_path(source: &str, node: &Node) -> Option<String> {
474    let mut cursor = node.walk();
475    if !cursor.goto_first_child() {
476        return None;
477    }
478
479    loop {
480        let child = cursor.node();
481        if child.kind() == "string" {
482            // Get the text and strip quotes
483            let text = &source[child.byte_range()];
484            let stripped = text
485                .trim_start_matches(|c| c == '\'' || c == '"')
486                .trim_end_matches(|c| c == '\'' || c == '"');
487            return Some(stripped.to_string());
488        }
489        if !cursor.goto_next_sibling() {
490            break;
491        }
492    }
493    None
494}
495
496/// Check if the import_statement has a `type` keyword (import type ...).
497///
498/// In tree-sitter-typescript, `import type { X } from 'y'` produces a `type`
499/// node as a direct child of `import_statement`, between `import` and `import_clause`.
500fn has_type_keyword(node: &Node) -> bool {
501    let mut cursor = node.walk();
502    if !cursor.goto_first_child() {
503        return false;
504    }
505
506    loop {
507        let child = cursor.node();
508        if child.kind() == "type" {
509            return true;
510        }
511        if !cursor.goto_next_sibling() {
512            break;
513        }
514    }
515
516    false
517}
518
519/// Extract named imports, default import, and namespace import from an import_clause.
520fn extract_import_clause(
521    source: &str,
522    node: &Node,
523    names: &mut Vec<String>,
524    default_import: &mut Option<String>,
525    namespace_import: &mut Option<String>,
526) {
527    let mut cursor = node.walk();
528    if !cursor.goto_first_child() {
529        return;
530    }
531
532    loop {
533        let child = cursor.node();
534        match child.kind() {
535            "identifier" => {
536                // This is a default import: `import Foo from 'bar'`
537                let text = &source[child.byte_range()];
538                if text != "type" {
539                    *default_import = Some(text.to_string());
540                }
541            }
542            "named_imports" => {
543                // `{ name1, name2 }`
544                extract_named_imports(source, &child, names);
545            }
546            "namespace_import" => {
547                // `* as name`
548                extract_namespace_import(source, &child, namespace_import);
549            }
550            _ => {}
551        }
552        if !cursor.goto_next_sibling() {
553            break;
554        }
555    }
556}
557
558/// Extract individual names from a named_imports node (`{ a, b, c }`).
559fn extract_named_imports(source: &str, node: &Node, names: &mut Vec<String>) {
560    let mut cursor = node.walk();
561    if !cursor.goto_first_child() {
562        return;
563    }
564
565    loop {
566        let child = cursor.node();
567        if child.kind() == "import_specifier" {
568            // import_specifier can have `name` (the imported name) and optional `alias`
569            if let Some(name_node) = child.child_by_field_name("name") {
570                names.push(source[name_node.byte_range()].to_string());
571            } else {
572                // Fallback: first identifier child
573                let mut spec_cursor = child.walk();
574                if spec_cursor.goto_first_child() {
575                    loop {
576                        if spec_cursor.node().kind() == "identifier"
577                            || spec_cursor.node().kind() == "type_identifier"
578                        {
579                            names.push(source[spec_cursor.node().byte_range()].to_string());
580                            break;
581                        }
582                        if !spec_cursor.goto_next_sibling() {
583                            break;
584                        }
585                    }
586                }
587            }
588        }
589        if !cursor.goto_next_sibling() {
590            break;
591        }
592    }
593}
594
595/// Extract the alias name from a namespace_import node (`* as name`).
596fn extract_namespace_import(source: &str, node: &Node, namespace_import: &mut Option<String>) {
597    let mut cursor = node.walk();
598    if !cursor.goto_first_child() {
599        return;
600    }
601
602    loop {
603        let child = cursor.node();
604        if child.kind() == "identifier" {
605            *namespace_import = Some(source[child.byte_range()].to_string());
606            return;
607        }
608        if !cursor.goto_next_sibling() {
609            break;
610        }
611    }
612}
613
/// Render a TS/JS/TSX import statement (single quotes, trailing `;`).
///
/// Named imports are emitted sorted. Forms produced:
/// - no names, no default: `import 'mod';` (side effect; `type_only` ignored)
/// - default only:         `import [type ]Def from 'mod';`
/// - names only:           `import [type ]{ a, b } from 'mod';`
/// - default and names:    `import Def, { a, b } from 'mod';`
///
/// TypeScript rejects `import type Def, { a } from 'mod'` (TS1363: a type-only
/// import can specify a default import or named bindings, but not both). A
/// type-only import with both is therefore written with the default as a named
/// binding: `import type { default as Def, a, b } from 'mod';`.
fn generate_ts_import_line(
    module_path: &str,
    names: &[String],
    default_import: Option<&str>,
    type_only: bool,
) -> String {
    let type_prefix = if type_only { "type " } else { "" };

    if names.is_empty() {
        return match default_import {
            // Side-effect import.
            None => format!("import '{module_path}';"),
            // Default import only.
            Some(def) => format!("import {type_prefix}{def} from '{module_path}';"),
        };
    }

    let mut sorted_names = names.to_vec();
    sorted_names.sort();
    let names_str = sorted_names.join(", ");

    match default_import {
        None => format!("import {type_prefix}{{ {names_str} }} from '{module_path}';"),
        Some(def) if type_only => {
            format!("import type {{ default as {def}, {names_str} }} from '{module_path}';")
        }
        Some(def) => format!("import {def}, {{ {names_str} }} from '{module_path}';"),
    }
}
654
655// ---------------------------------------------------------------------------
656// Python implementation
657// ---------------------------------------------------------------------------
658
/// Top-level module names of the Python 3 standard library.
///
/// Used by `classify_group_py` to place stdlib imports in the `Stdlib` group.
/// Names not listed here are assumed to be third-party.
///
/// Coverage:
/// - The public modules (cf. CPython's `sys.stdlib_module_names`), including
///   platform-specific ones such as `msvcrt`, `winreg` and `posix`.
/// - Modules removed in recent releases (e.g. `asynchat`, `distutils`, `imp`),
///   kept so older code is still classified correctly.
/// - Of the private/internal modules, only `__future__` and `_thread`.
///
/// Entries are kept in case-insensitive alphabetical order.
const PYTHON_STDLIB: &[&str] = &[
    "__future__",
    "_thread",
    "abc",
    "aifc",
    "argparse",
    "array",
    "ast",
    "asynchat",
    "asyncio",
    "asyncore",
    "atexit",
    "audioop",
    "base64",
    "bdb",
    "binascii",
    "binhex",
    "bisect",
    "builtins",
    "bz2",
    "calendar",
    "cgi",
    "cgitb",
    "chunk",
    "cmath",
    "cmd",
    "code",
    "codecs",
    "codeop",
    "collections",
    "colorsys",
    "compileall",
    "concurrent",
    "configparser",
    "contextlib",
    "contextvars",
    "copy",
    "copyreg",
    "cProfile",
    "crypt",
    "csv",
    "ctypes",
    "curses",
    "dataclasses",
    "datetime",
    "dbm",
    "decimal",
    "difflib",
    "dis",
    "distutils",
    "doctest",
    "email",
    "encodings",
    "ensurepip",
    "enum",
    "errno",
    "faulthandler",
    "fcntl",
    "filecmp",
    "fileinput",
    "fnmatch",
    "fractions",
    "ftplib",
    "functools",
    "gc",
    "getopt",
    "getpass",
    "gettext",
    "glob",
    "graphlib",
    "grp",
    "gzip",
    "hashlib",
    "heapq",
    "hmac",
    "html",
    "http",
    "idlelib",
    "imaplib",
    "imghdr",
    "imp",
    "importlib",
    "inspect",
    "io",
    "ipaddress",
    "itertools",
    "json",
    "keyword",
    "lib2to3",
    "linecache",
    "locale",
    "logging",
    "lzma",
    "mailbox",
    "mailcap",
    "marshal",
    "math",
    "mimetypes",
    "mmap",
    "modulefinder",
    "msilib",
    "msvcrt",
    "multiprocessing",
    "netrc",
    "nis",
    "nntplib",
    "ntpath",
    "numbers",
    "operator",
    "optparse",
    "os",
    "ossaudiodev",
    "pathlib",
    "pdb",
    "pickle",
    "pickletools",
    "pipes",
    "pkgutil",
    "platform",
    "plistlib",
    "poplib",
    "posix",
    "posixpath",
    "pprint",
    "profile",
    "pstats",
    "pty",
    "pwd",
    "py_compile",
    "pyclbr",
    "pydoc",
    "pyexpat",
    "queue",
    "quopri",
    "random",
    "re",
    "readline",
    "reprlib",
    "resource",
    "rlcompleter",
    "runpy",
    "sched",
    "secrets",
    "select",
    "selectors",
    "shelve",
    "shlex",
    "shutil",
    "signal",
    "site",
    "smtpd",
    "smtplib",
    "sndhdr",
    "socket",
    "socketserver",
    "spwd",
    "sqlite3",
    "ssl",
    "stat",
    "statistics",
    "string",
    "stringprep",
    "struct",
    "subprocess",
    "sunau",
    "symtable",
    "sys",
    "sysconfig",
    "syslog",
    "tabnanny",
    "tarfile",
    "telnetlib",
    "tempfile",
    "termios",
    "textwrap",
    "threading",
    "time",
    "timeit",
    "tkinter",
    "token",
    "tokenize",
    "tomllib",
    "trace",
    "traceback",
    "tracemalloc",
    "tty",
    "turtle",
    "types",
    "typing",
    "unicodedata",
    "unittest",
    "urllib",
    "uu",
    "uuid",
    "venv",
    "warnings",
    "wave",
    "weakref",
    "webbrowser",
    "winreg",
    "winsound",
    "wsgiref",
    "xdrlib",
    "xml",
    "xmlrpc",
    "zipapp",
    "zipfile",
    "zipimport",
    "zlib",
    "zoneinfo",
];
852
853/// Classify a Python import into a group.
854pub fn classify_group_py(module_path: &str) -> ImportGroup {
855    // Relative imports start with '.'
856    if module_path.starts_with('.') {
857        return ImportGroup::Internal;
858    }
859    // Check stdlib: use the top-level module name (before first '.')
860    let top_module = module_path.split('.').next().unwrap_or(module_path);
861    if PYTHON_STDLIB.contains(&top_module) {
862        ImportGroup::Stdlib
863    } else {
864        ImportGroup::External
865    }
866}
867
868/// Parse imports from a Python file.
869fn parse_py_imports(source: &str, tree: &Tree) -> ImportBlock {
870    let root = tree.root_node();
871    let mut imports = Vec::new();
872
873    let mut cursor = root.walk();
874    if !cursor.goto_first_child() {
875        return ImportBlock::empty();
876    }
877
878    loop {
879        let node = cursor.node();
880        match node.kind() {
881            "import_statement" => {
882                if let Some(imp) = parse_py_import_statement(source, &node) {
883                    imports.push(imp);
884                }
885            }
886            "import_from_statement" => {
887                if let Some(imp) = parse_py_import_from_statement(source, &node) {
888                    imports.push(imp);
889                }
890            }
891            _ => {}
892        }
893        if !cursor.goto_next_sibling() {
894            break;
895        }
896    }
897
898    let byte_range = import_byte_range(&imports);
899
900    ImportBlock {
901        imports,
902        byte_range,
903    }
904}
905
906/// Parse `import X` or `import X.Y` Python statements.
907fn parse_py_import_statement(source: &str, node: &Node) -> Option<ImportStatement> {
908    let raw_text = source[node.byte_range()].to_string();
909    let byte_range = node.byte_range();
910
911    // Find the dotted_name child (the module name)
912    let mut module_path = String::new();
913    let mut c = node.walk();
914    if c.goto_first_child() {
915        loop {
916            if c.node().kind() == "dotted_name" {
917                module_path = source[c.node().byte_range()].to_string();
918                break;
919            }
920            if !c.goto_next_sibling() {
921                break;
922            }
923        }
924    }
925    if module_path.is_empty() {
926        return None;
927    }
928
929    let group = classify_group_py(&module_path);
930
931    Some(ImportStatement {
932        module_path,
933        names: Vec::new(),
934        default_import: None,
935        namespace_import: None,
936        kind: ImportKind::Value,
937        group,
938        byte_range,
939        raw_text,
940    })
941}
942
943/// Parse `from X import Y, Z` or `from . import Y` Python statements.
944fn parse_py_import_from_statement(source: &str, node: &Node) -> Option<ImportStatement> {
945    let raw_text = source[node.byte_range()].to_string();
946    let byte_range = node.byte_range();
947
948    let mut module_path = String::new();
949    let mut names = Vec::new();
950
951    let mut c = node.walk();
952    if c.goto_first_child() {
953        loop {
954            let child = c.node();
955            match child.kind() {
956                "dotted_name" => {
957                    // Could be the module name or an imported name
958                    // The module name comes right after `from`, imported names come after `import`
959                    // Use position: if we haven't set module_path yet and this comes
960                    // before the `import` keyword, it's the module.
961                    if module_path.is_empty()
962                        && !has_seen_import_keyword(source, node, child.start_byte())
963                    {
964                        module_path = source[child.byte_range()].to_string();
965                    } else {
966                        // It's an imported name
967                        names.push(source[child.byte_range()].to_string());
968                    }
969                }
970                "relative_import" => {
971                    // from . import X or from ..module import X
972                    module_path = source[child.byte_range()].to_string();
973                }
974                _ => {}
975            }
976            if !c.goto_next_sibling() {
977                break;
978            }
979        }
980    }
981
982    // module_path must be non-empty for a valid import
983    if module_path.is_empty() {
984        return None;
985    }
986
987    let group = classify_group_py(&module_path);
988
989    Some(ImportStatement {
990        module_path,
991        names,
992        default_import: None,
993        namespace_import: None,
994        kind: ImportKind::Value,
995        group,
996        byte_range,
997        raw_text,
998    })
999}
1000
1001/// Check if the `import` keyword appears before the given byte position in a from...import node.
1002fn has_seen_import_keyword(_source: &str, parent: &Node, before_byte: usize) -> bool {
1003    let mut c = parent.walk();
1004    if c.goto_first_child() {
1005        loop {
1006            let child = c.node();
1007            if child.kind() == "import" && child.start_byte() < before_byte {
1008                return true;
1009            }
1010            if child.start_byte() >= before_byte {
1011                return false;
1012            }
1013            if !c.goto_next_sibling() {
1014                break;
1015            }
1016        }
1017    }
1018    false
1019}
1020
/// Build a Python import statement.
///
/// With no `names`, the result is `import <module_path>`. Otherwise it is
/// `from <module_path> import a, b`, with the names sorted lexicographically.
/// `_default_import` is accepted for signature parity and is not used.
fn generate_py_import_line(
    module_path: &str,
    names: &[String],
    _default_import: Option<&str>,
) -> String {
    if names.is_empty() {
        return format!("import {module_path}");
    }

    let mut ordered: Vec<&str> = names.iter().map(String::as_str).collect();
    ordered.sort();
    format!("from {module_path} import {}", ordered.join(", "))
}
1038
1039// ---------------------------------------------------------------------------
1040// Rust implementation
1041// ---------------------------------------------------------------------------
1042
1043/// Classify a Rust use path into a group.
1044pub fn classify_group_rs(module_path: &str) -> ImportGroup {
1045    // Extract the first path segment (before ::)
1046    let first_seg = module_path.split("::").next().unwrap_or(module_path);
1047    match first_seg {
1048        "std" | "core" | "alloc" => ImportGroup::Stdlib,
1049        "crate" | "self" | "super" => ImportGroup::Internal,
1050        _ => ImportGroup::External,
1051    }
1052}
1053
1054/// Parse imports from a Rust file.
1055fn parse_rs_imports(source: &str, tree: &Tree) -> ImportBlock {
1056    let root = tree.root_node();
1057    let mut imports = Vec::new();
1058
1059    let mut cursor = root.walk();
1060    if !cursor.goto_first_child() {
1061        return ImportBlock::empty();
1062    }
1063
1064    loop {
1065        let node = cursor.node();
1066        if node.kind() == "use_declaration" {
1067            if let Some(imp) = parse_rs_use_declaration(source, &node) {
1068                imports.push(imp);
1069            }
1070        }
1071        if !cursor.goto_next_sibling() {
1072            break;
1073        }
1074    }
1075
1076    let byte_range = import_byte_range(&imports);
1077
1078    ImportBlock {
1079        imports,
1080        byte_range,
1081    }
1082}
1083
1084/// Parse a single `use` declaration from Rust.
1085fn parse_rs_use_declaration(source: &str, node: &Node) -> Option<ImportStatement> {
1086    let raw_text = source[node.byte_range()].to_string();
1087    let byte_range = node.byte_range();
1088
1089    // Check for `pub` visibility modifier
1090    let mut has_pub = false;
1091    let mut use_path = String::new();
1092    let mut names = Vec::new();
1093
1094    let mut c = node.walk();
1095    if c.goto_first_child() {
1096        loop {
1097            let child = c.node();
1098            match child.kind() {
1099                "visibility_modifier" => {
1100                    has_pub = true;
1101                }
1102                "scoped_identifier" | "identifier" | "use_as_clause" => {
1103                    // Full path like `std::collections::HashMap` or just `serde`
1104                    use_path = source[child.byte_range()].to_string();
1105                }
1106                "scoped_use_list" => {
1107                    // e.g. `serde::{Deserialize, Serialize}`
1108                    use_path = source[child.byte_range()].to_string();
1109                    // Also extract the individual names from the use_list
1110                    extract_rs_use_list_names(source, &child, &mut names);
1111                }
1112                _ => {}
1113            }
1114            if !c.goto_next_sibling() {
1115                break;
1116            }
1117        }
1118    }
1119
1120    if use_path.is_empty() {
1121        return None;
1122    }
1123
1124    let group = classify_group_rs(&use_path);
1125
1126    Some(ImportStatement {
1127        module_path: use_path,
1128        names,
1129        default_import: if has_pub {
1130            Some("pub".to_string())
1131        } else {
1132            None
1133        },
1134        namespace_import: None,
1135        kind: ImportKind::Value,
1136        group,
1137        byte_range,
1138        raw_text,
1139    })
1140}
1141
1142/// Extract individual names from a Rust `scoped_use_list` node.
1143fn extract_rs_use_list_names(source: &str, node: &Node, names: &mut Vec<String>) {
1144    let mut c = node.walk();
1145    if c.goto_first_child() {
1146        loop {
1147            let child = c.node();
1148            if child.kind() == "use_list" {
1149                // Walk into the use_list to find identifiers
1150                let mut lc = child.walk();
1151                if lc.goto_first_child() {
1152                    loop {
1153                        let lchild = lc.node();
1154                        if lchild.kind() == "identifier" || lchild.kind() == "scoped_identifier" {
1155                            names.push(source[lchild.byte_range()].to_string());
1156                        }
1157                        if !lc.goto_next_sibling() {
1158                            break;
1159                        }
1160                    }
1161                }
1162            }
1163            if !c.goto_next_sibling() {
1164                break;
1165            }
1166        }
1167    }
1168}
1169
/// Generate a Rust `use` line.
///
/// - No `names`: `use <module_path>;`
/// - One name: `use <module_path>::<name>;`
/// - Several names: `use <module_path>::{a, b};`
///
/// Names are deduplicated. `self` sorts first (as rustfmt does) and the rest
/// are ordered lexicographically. Callers may pass the full item path as
/// `module_path` (e.g. `serde::Deserialize`) together with that item's name. A
/// name equal to the path's last segment is therefore not appended again.
/// `_type_only` has no Rust equivalent and is ignored.
fn generate_rs_import_line(module_path: &str, names: &[String], _type_only: bool) -> String {
    let last_segment = module_path
        .rsplit_once("::")
        .map_or(module_path, |(_, tail)| tail);

    let mut items: Vec<&str> = names
        .iter()
        .map(String::as_str)
        .filter(|name| *name != last_segment)
        .collect();
    items.sort_unstable_by_key(|name| (*name != "self", *name));
    items.dedup();

    match items.as_slice() {
        [] => format!("use {module_path};"),
        [single] => format!("use {module_path}::{single};"),
        many => format!("use {module_path}::{{{}}};", many.join(", ")),
    }
}
1181
1182// ---------------------------------------------------------------------------
1183// Go implementation
1184// ---------------------------------------------------------------------------
1185
1186/// Classify a Go import path into a group.
1187pub fn classify_group_go(module_path: &str) -> ImportGroup {
1188    // stdlib paths don't contain dots (e.g., "fmt", "os", "net/http")
1189    // external paths contain dots (e.g., "github.com/pkg/errors")
1190    if module_path.contains('.') {
1191        ImportGroup::External
1192    } else {
1193        ImportGroup::Stdlib
1194    }
1195}
1196
1197/// Parse imports from a Go file.
1198fn parse_go_imports(source: &str, tree: &Tree) -> ImportBlock {
1199    let root = tree.root_node();
1200    let mut imports = Vec::new();
1201
1202    let mut cursor = root.walk();
1203    if !cursor.goto_first_child() {
1204        return ImportBlock::empty();
1205    }
1206
1207    loop {
1208        let node = cursor.node();
1209        if node.kind() == "import_declaration" {
1210            parse_go_import_declaration(source, &node, &mut imports);
1211        }
1212        if !cursor.goto_next_sibling() {
1213            break;
1214        }
1215    }
1216
1217    let byte_range = import_byte_range(&imports);
1218
1219    ImportBlock {
1220        imports,
1221        byte_range,
1222    }
1223}
1224
1225/// Parse a single Go import_declaration (may contain one or multiple specs).
1226fn parse_go_import_declaration(source: &str, node: &Node, imports: &mut Vec<ImportStatement>) {
1227    let mut c = node.walk();
1228    if c.goto_first_child() {
1229        loop {
1230            let child = c.node();
1231            match child.kind() {
1232                "import_spec" => {
1233                    if let Some(imp) = parse_go_import_spec(source, &child) {
1234                        imports.push(imp);
1235                    }
1236                }
1237                "import_spec_list" => {
1238                    // Grouped imports: walk into the list
1239                    let mut lc = child.walk();
1240                    if lc.goto_first_child() {
1241                        loop {
1242                            if lc.node().kind() == "import_spec" {
1243                                if let Some(imp) = parse_go_import_spec(source, &lc.node()) {
1244                                    imports.push(imp);
1245                                }
1246                            }
1247                            if !lc.goto_next_sibling() {
1248                                break;
1249                            }
1250                        }
1251                    }
1252                }
1253                _ => {}
1254            }
1255            if !c.goto_next_sibling() {
1256                break;
1257            }
1258        }
1259    }
1260}
1261
1262/// Parse a single Go import_spec node.
1263fn parse_go_import_spec(source: &str, node: &Node) -> Option<ImportStatement> {
1264    let raw_text = source[node.byte_range()].to_string();
1265    let byte_range = node.byte_range();
1266
1267    let mut import_path = String::new();
1268    let mut alias = None;
1269
1270    let mut c = node.walk();
1271    if c.goto_first_child() {
1272        loop {
1273            let child = c.node();
1274            match child.kind() {
1275                "interpreted_string_literal" => {
1276                    // Extract the path without quotes
1277                    let text = source[child.byte_range()].to_string();
1278                    import_path = text.trim_matches('"').to_string();
1279                }
1280                "identifier" | "blank_identifier" | "dot" => {
1281                    // This is an alias (e.g., `alias "path"` or `. "path"` or `_ "path"`)
1282                    alias = Some(source[child.byte_range()].to_string());
1283                }
1284                _ => {}
1285            }
1286            if !c.goto_next_sibling() {
1287                break;
1288            }
1289        }
1290    }
1291
1292    if import_path.is_empty() {
1293        return None;
1294    }
1295
1296    let group = classify_group_go(&import_path);
1297
1298    Some(ImportStatement {
1299        module_path: import_path,
1300        names: Vec::new(),
1301        default_import: alias,
1302        namespace_import: None,
1303        kind: ImportKind::Value,
1304        group,
1305        byte_range,
1306        raw_text,
1307    })
1308}
1309
1310/// Public API for Go import line generation (used by add_import handler).
1311pub fn generate_go_import_line_pub(
1312    module_path: &str,
1313    alias: Option<&str>,
1314    in_group: bool,
1315) -> String {
1316    generate_go_import_line(module_path, alias, in_group)
1317}
1318
/// Render one Go import (crate-private; exposed via `generate_go_import_line_pub`).
///
/// With `in_group` set, the output is a tab-indented spec (`\t"path"`) meant for
/// insertion into an existing grouped import. Otherwise it is a standalone
/// declaration (`import "path"`). If `alias` is given, it precedes the quoted
/// path, separated by one space.
fn generate_go_import_line(module_path: &str, alias: Option<&str>, in_group: bool) -> String {
    let spec = match alias {
        Some(name) => format!("{name} \"{module_path}\""),
        None => format!("\"{module_path}\""),
    };
    let prefix = if in_group { "\t" } else { "import " };
    format!("{prefix}{spec}")
}
1338
1339/// Check if a Go import block has a grouped import declaration.
1340/// Returns the byte range of the import_spec_list if found.
1341pub fn go_has_grouped_import(_source: &str, tree: &Tree) -> Option<Range<usize>> {
1342    let root = tree.root_node();
1343    let mut cursor = root.walk();
1344    if !cursor.goto_first_child() {
1345        return None;
1346    }
1347
1348    loop {
1349        let node = cursor.node();
1350        if node.kind() == "import_declaration" {
1351            let mut c = node.walk();
1352            if c.goto_first_child() {
1353                loop {
1354                    if c.node().kind() == "import_spec_list" {
1355                        return Some(c.node().byte_range());
1356                    }
1357                    if !c.goto_next_sibling() {
1358                        break;
1359                    }
1360                }
1361            }
1362        }
1363        if !cursor.goto_next_sibling() {
1364            break;
1365        }
1366    }
1367    None
1368}
1369
/// Return the offset just past a line terminator starting at `pos`.
///
/// `\r\n` advances by two bytes, and a lone `\n` or `\r` advances by one. If
/// `pos` is at or beyond the end of `source`, or the byte there is not a line
/// terminator, `pos` is returned unchanged.
fn skip_newline(source: &str, pos: usize) -> usize {
    match source.as_bytes().get(pos..) {
        Some([b'\r', b'\n', ..]) => pos + 2,
        Some([b'\n' | b'\r', ..]) => pos + 1,
        _ => pos,
    }
}
1386
1387// ---------------------------------------------------------------------------
1388// Unit tests
1389// ---------------------------------------------------------------------------
1390
// Unit tests for the import engine: per-language parsing, group
// classification, duplicate detection, import-line generation, and insertion
// points. Each `parse_*` helper parses a source string with the matching
// grammar and returns the tree together with the extracted `ImportBlock`.
#[cfg(test)]
mod tests {
    use super::*;

    // Parse `source` as TypeScript; returns the tree and its import block.
    fn parse_ts(source: &str) -> (Tree, ImportBlock) {
        let grammar = grammar_for(LangId::TypeScript);
        let mut parser = Parser::new();
        parser.set_language(&grammar).unwrap();
        let tree = parser.parse(source, None).unwrap();
        let block = parse_imports(source, &tree, LangId::TypeScript);
        (tree, block)
    }

    // Parse `source` as JavaScript; returns the tree and its import block.
    fn parse_js(source: &str) -> (Tree, ImportBlock) {
        let grammar = grammar_for(LangId::JavaScript);
        let mut parser = Parser::new();
        parser.set_language(&grammar).unwrap();
        let tree = parser.parse(source, None).unwrap();
        let block = parse_imports(source, &tree, LangId::JavaScript);
        (tree, block)
    }

    // --- Basic parsing ---

    #[test]
    fn parse_ts_named_imports() {
        let source = "import { useState, useEffect } from 'react';\n";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 1);
        let imp = &block.imports[0];
        assert_eq!(imp.module_path, "react");
        assert!(imp.names.contains(&"useState".to_string()));
        assert!(imp.names.contains(&"useEffect".to_string()));
        assert_eq!(imp.kind, ImportKind::Value);
        assert_eq!(imp.group, ImportGroup::External);
    }

    #[test]
    fn parse_ts_default_import() {
        let source = "import React from 'react';\n";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 1);
        let imp = &block.imports[0];
        assert_eq!(imp.default_import.as_deref(), Some("React"));
        assert_eq!(imp.kind, ImportKind::Value);
    }

    #[test]
    fn parse_ts_side_effect_import() {
        let source = "import './styles.css';\n";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 1);
        assert_eq!(block.imports[0].kind, ImportKind::SideEffect);
        assert_eq!(block.imports[0].module_path, "./styles.css");
    }

    #[test]
    fn parse_ts_relative_import() {
        let source = "import { helper } from './utils';\n";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 1);
        assert_eq!(block.imports[0].group, ImportGroup::Internal);
    }

    // Two package imports and two relative imports should split 2/2 between
    // the External and Internal groups.
    #[test]
    fn parse_ts_multiple_groups() {
        let source = "\
import React from 'react';
import { useState } from 'react';
import { helper } from './utils';
import { Config } from '../config';
";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 4);

        let external: Vec<_> = block
            .imports
            .iter()
            .filter(|i| i.group == ImportGroup::External)
            .collect();
        let relative: Vec<_> = block
            .imports
            .iter()
            .filter(|i| i.group == ImportGroup::Internal)
            .collect();
        assert_eq!(external.len(), 2);
        assert_eq!(relative.len(), 2);
    }

    #[test]
    fn parse_ts_namespace_import() {
        let source = "import * as path from 'path';\n";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 1);
        let imp = &block.imports[0];
        assert_eq!(imp.namespace_import.as_deref(), Some("path"));
        assert_eq!(imp.kind, ImportKind::Value);
    }

    #[test]
    fn parse_js_imports() {
        let source = "import { readFile } from 'fs';\nimport { helper } from './helper';\n";
        let (_, block) = parse_js(source);
        assert_eq!(block.imports.len(), 2);
        assert_eq!(block.imports[0].group, ImportGroup::External);
        assert_eq!(block.imports[1].group, ImportGroup::Internal);
    }

    // --- Group classification ---

    #[test]
    fn classify_external() {
        assert_eq!(classify_group_ts("react"), ImportGroup::External);
        assert_eq!(classify_group_ts("@scope/pkg"), ImportGroup::External);
        assert_eq!(classify_group_ts("lodash/map"), ImportGroup::External);
    }

    #[test]
    fn classify_relative() {
        assert_eq!(classify_group_ts("./utils"), ImportGroup::Internal);
        assert_eq!(classify_group_ts("../config"), ImportGroup::Internal);
        assert_eq!(classify_group_ts("./"), ImportGroup::Internal);
    }

    // --- Dedup ---
    // `is_duplicate(block, module, names, default, type_only)` is exercised
    // against a single parsed import in each case.

    #[test]
    fn dedup_detects_same_named_import() {
        let source = "import { useState } from 'react';\n";
        let (_, block) = parse_ts(source);
        assert!(is_duplicate(
            &block,
            "react",
            &["useState".to_string()],
            None,
            false
        ));
    }

    #[test]
    fn dedup_misses_different_name() {
        let source = "import { useState } from 'react';\n";
        let (_, block) = parse_ts(source);
        assert!(!is_duplicate(
            &block,
            "react",
            &["useEffect".to_string()],
            None,
            false
        ));
    }

    #[test]
    fn dedup_detects_default_import() {
        let source = "import React from 'react';\n";
        let (_, block) = parse_ts(source);
        assert!(is_duplicate(&block, "react", &[], Some("React"), false));
    }

    #[test]
    fn dedup_side_effect() {
        let source = "import './styles.css';\n";
        let (_, block) = parse_ts(source);
        assert!(is_duplicate(&block, "./styles.css", &[], None, false));
    }

    #[test]
    fn dedup_type_vs_value() {
        let source = "import { FC } from 'react';\n";
        let (_, block) = parse_ts(source);
        // Type import should NOT match a value import of the same name
        assert!(!is_duplicate(
            &block,
            "react",
            &["FC".to_string()],
            None,
            true
        ));
    }

    // --- Generation ---

    // Named imports are emitted in sorted order regardless of input order.
    #[test]
    fn generate_named_import() {
        let line = generate_import_line(
            LangId::TypeScript,
            "react",
            &["useState".to_string(), "useEffect".to_string()],
            None,
            false,
        );
        assert_eq!(line, "import { useEffect, useState } from 'react';");
    }

    #[test]
    fn generate_default_import() {
        let line = generate_import_line(LangId::TypeScript, "react", &[], Some("React"), false);
        assert_eq!(line, "import React from 'react';");
    }

    #[test]
    fn generate_type_import() {
        let line =
            generate_import_line(LangId::TypeScript, "react", &["FC".to_string()], None, true);
        assert_eq!(line, "import type { FC } from 'react';");
    }

    #[test]
    fn generate_side_effect_import() {
        let line = generate_import_line(LangId::TypeScript, "./styles.css", &[], None, false);
        assert_eq!(line, "import './styles.css';");
    }

    #[test]
    fn generate_default_and_named() {
        let line = generate_import_line(
            LangId::TypeScript,
            "react",
            &["useState".to_string()],
            Some("React"),
            false,
        );
        assert_eq!(line, "import React, { useState } from 'react';");
    }

    #[test]
    fn parse_ts_type_import() {
        let source = "import type { FC } from 'react';\n";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 1);
        let imp = &block.imports[0];
        assert_eq!(imp.kind, ImportKind::Type);
        assert!(imp.names.contains(&"FC".to_string()));
        assert_eq!(imp.group, ImportGroup::External);
    }

    // --- Insertion point ---
    // Only the byte offset (first tuple element) is checked here.

    #[test]
    fn insertion_empty_file() {
        let source = "";
        let (_, block) = parse_ts(source);
        let (offset, _, _) =
            find_insertion_point(source, &block, ImportGroup::External, "react", false);
        assert_eq!(offset, 0);
    }

    #[test]
    fn insertion_alphabetical_within_group() {
        let source = "\
import { a } from 'alpha';
import { c } from 'charlie';
";
        let (_, block) = parse_ts(source);
        let (offset, _, _) =
            find_insertion_point(source, &block, ImportGroup::External, "bravo", false);
        // Should insert before 'charlie' (which starts at line 2)
        let before_charlie = source.find("import { c }").unwrap();
        assert_eq!(offset, before_charlie);
    }

    // --- Python parsing ---

    // Parse `source` as Python; returns the tree and its import block.
    fn parse_py(source: &str) -> (Tree, ImportBlock) {
        let grammar = grammar_for(LangId::Python);
        let mut parser = Parser::new();
        parser.set_language(&grammar).unwrap();
        let tree = parser.parse(source, None).unwrap();
        let block = parse_imports(source, &tree, LangId::Python);
        (tree, block)
    }

    #[test]
    fn parse_py_import_statement() {
        let source = "import os\nimport sys\n";
        let (_, block) = parse_py(source);
        assert_eq!(block.imports.len(), 2);
        assert_eq!(block.imports[0].module_path, "os");
        assert_eq!(block.imports[1].module_path, "sys");
        assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
    }

    #[test]
    fn parse_py_from_import() {
        let source = "from collections import OrderedDict\nfrom typing import List, Optional\n";
        let (_, block) = parse_py(source);
        assert_eq!(block.imports.len(), 2);
        assert_eq!(block.imports[0].module_path, "collections");
        assert!(block.imports[0].names.contains(&"OrderedDict".to_string()));
        assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
        assert_eq!(block.imports[1].module_path, "typing");
        assert!(block.imports[1].names.contains(&"List".to_string()));
        assert!(block.imports[1].names.contains(&"Optional".to_string()));
    }

    // Relative imports keep their leading dots in `module_path`.
    #[test]
    fn parse_py_relative_import() {
        let source = "from . import utils\nfrom ..config import Settings\n";
        let (_, block) = parse_py(source);
        assert_eq!(block.imports.len(), 2);
        assert_eq!(block.imports[0].module_path, ".");
        assert!(block.imports[0].names.contains(&"utils".to_string()));
        assert_eq!(block.imports[0].group, ImportGroup::Internal);
        assert_eq!(block.imports[1].module_path, "..config");
        assert_eq!(block.imports[1].group, ImportGroup::Internal);
    }

    #[test]
    fn classify_py_groups() {
        assert_eq!(classify_group_py("os"), ImportGroup::Stdlib);
        assert_eq!(classify_group_py("sys"), ImportGroup::Stdlib);
        assert_eq!(classify_group_py("json"), ImportGroup::Stdlib);
        assert_eq!(classify_group_py("collections"), ImportGroup::Stdlib);
        assert_eq!(classify_group_py("os.path"), ImportGroup::Stdlib);
        assert_eq!(classify_group_py("requests"), ImportGroup::External);
        assert_eq!(classify_group_py("flask"), ImportGroup::External);
        assert_eq!(classify_group_py("."), ImportGroup::Internal);
        assert_eq!(classify_group_py("..config"), ImportGroup::Internal);
        assert_eq!(classify_group_py(".utils"), ImportGroup::Internal);
    }

    #[test]
    fn parse_py_three_groups() {
        let source = "import os\nimport sys\n\nimport requests\n\nfrom . import utils\n";
        let (_, block) = parse_py(source);
        let stdlib: Vec<_> = block
            .imports
            .iter()
            .filter(|i| i.group == ImportGroup::Stdlib)
            .collect();
        let external: Vec<_> = block
            .imports
            .iter()
            .filter(|i| i.group == ImportGroup::External)
            .collect();
        let internal: Vec<_> = block
            .imports
            .iter()
            .filter(|i| i.group == ImportGroup::Internal)
            .collect();
        assert_eq!(stdlib.len(), 2);
        assert_eq!(external.len(), 1);
        assert_eq!(internal.len(), 1);
    }

    #[test]
    fn generate_py_import() {
        let line = generate_import_line(LangId::Python, "os", &[], None, false);
        assert_eq!(line, "import os");
    }

    #[test]
    fn generate_py_from_import() {
        let line = generate_import_line(
            LangId::Python,
            "collections",
            &["OrderedDict".to_string()],
            None,
            false,
        );
        assert_eq!(line, "from collections import OrderedDict");
    }

    // Names are sorted in the generated `from ... import` line.
    #[test]
    fn generate_py_from_import_multiple() {
        let line = generate_import_line(
            LangId::Python,
            "typing",
            &["Optional".to_string(), "List".to_string()],
            None,
            false,
        );
        assert_eq!(line, "from typing import List, Optional");
    }

    // --- Rust parsing ---

    // Parse `source` as Rust; returns the tree and its import block.
    fn parse_rust(source: &str) -> (Tree, ImportBlock) {
        let grammar = grammar_for(LangId::Rust);
        let mut parser = Parser::new();
        parser.set_language(&grammar).unwrap();
        let tree = parser.parse(source, None).unwrap();
        let block = parse_imports(source, &tree, LangId::Rust);
        (tree, block)
    }

    #[test]
    fn parse_rs_use_std() {
        let source = "use std::collections::HashMap;\nuse std::io::Read;\n";
        let (_, block) = parse_rust(source);
        assert_eq!(block.imports.len(), 2);
        assert_eq!(block.imports[0].module_path, "std::collections::HashMap");
        assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
        assert_eq!(block.imports[1].group, ImportGroup::Stdlib);
    }

    #[test]
    fn parse_rs_use_external() {
        let source = "use serde::{Deserialize, Serialize};\n";
        let (_, block) = parse_rust(source);
        assert_eq!(block.imports.len(), 1);
        assert_eq!(block.imports[0].group, ImportGroup::External);
        assert!(block.imports[0].names.contains(&"Deserialize".to_string()));
        assert!(block.imports[0].names.contains(&"Serialize".to_string()));
    }

    #[test]
    fn parse_rs_use_crate() {
        let source = "use crate::config::Settings;\nuse super::parent::Thing;\n";
        let (_, block) = parse_rust(source);
        assert_eq!(block.imports.len(), 2);
        assert_eq!(block.imports[0].group, ImportGroup::Internal);
        assert_eq!(block.imports[1].group, ImportGroup::Internal);
    }

    #[test]
    fn parse_rs_pub_use() {
        let source = "pub use super::parent::Thing;\n";
        let (_, block) = parse_rust(source);
        assert_eq!(block.imports.len(), 1);
        // `pub` is stored in default_import as a marker
        assert_eq!(block.imports[0].default_import.as_deref(), Some("pub"));
    }

    #[test]
    fn classify_rs_groups() {
        assert_eq!(
            classify_group_rs("std::collections::HashMap"),
            ImportGroup::Stdlib
        );
        assert_eq!(classify_group_rs("core::mem"), ImportGroup::Stdlib);
        assert_eq!(classify_group_rs("alloc::vec"), ImportGroup::Stdlib);
        assert_eq!(
            classify_group_rs("serde::Deserialize"),
            ImportGroup::External
        );
        assert_eq!(classify_group_rs("tokio::runtime"), ImportGroup::External);
        assert_eq!(classify_group_rs("crate::config"), ImportGroup::Internal);
        assert_eq!(classify_group_rs("self::utils"), ImportGroup::Internal);
        assert_eq!(classify_group_rs("super::parent"), ImportGroup::Internal);
    }

    #[test]
    fn generate_rs_use() {
        let line = generate_import_line(LangId::Rust, "std::fmt::Display", &[], None, false);
        assert_eq!(line, "use std::fmt::Display;");
    }

    // --- Go parsing ---

    // Parse `source` as Go; returns the tree and its import block.
    fn parse_go(source: &str) -> (Tree, ImportBlock) {
        let grammar = grammar_for(LangId::Go);
        let mut parser = Parser::new();
        parser.set_language(&grammar).unwrap();
        let tree = parser.parse(source, None).unwrap();
        let block = parse_imports(source, &tree, LangId::Go);
        (tree, block)
    }

    #[test]
    fn parse_go_single_import() {
        let source = "package main\n\nimport \"fmt\"\n";
        let (_, block) = parse_go(source);
        assert_eq!(block.imports.len(), 1);
        assert_eq!(block.imports[0].module_path, "fmt");
        assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
    }

    // Specs inside `import ( ... )` are flattened in source order; the blank
    // line between groups does not affect parsing.
    #[test]
    fn parse_go_grouped_import() {
        let source =
            "package main\n\nimport (\n\t\"fmt\"\n\t\"os\"\n\n\t\"github.com/pkg/errors\"\n)\n";
        let (_, block) = parse_go(source);
        assert_eq!(block.imports.len(), 3);
        assert_eq!(block.imports[0].module_path, "fmt");
        assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
        assert_eq!(block.imports[1].module_path, "os");
        assert_eq!(block.imports[1].group, ImportGroup::Stdlib);
        assert_eq!(block.imports[2].module_path, "github.com/pkg/errors");
        assert_eq!(block.imports[2].group, ImportGroup::External);
    }

    #[test]
    fn parse_go_mixed_imports() {
        // Single + grouped
        let source = "package main\n\nimport \"fmt\"\n\nimport (\n\t\"os\"\n\t\"github.com/pkg/errors\"\n)\n";
        let (_, block) = parse_go(source);
        assert_eq!(block.imports.len(), 3);
    }

    #[test]
    fn classify_go_groups() {
        assert_eq!(classify_group_go("fmt"), ImportGroup::Stdlib);
        assert_eq!(classify_group_go("os"), ImportGroup::Stdlib);
        assert_eq!(classify_group_go("net/http"), ImportGroup::Stdlib);
        assert_eq!(classify_group_go("encoding/json"), ImportGroup::Stdlib);
        assert_eq!(
            classify_group_go("github.com/pkg/errors"),
            ImportGroup::External
        );
        assert_eq!(
            classify_group_go("golang.org/x/tools"),
            ImportGroup::External
        );
    }

    #[test]
    fn generate_go_standalone() {
        let line = generate_go_import_line("fmt", None, false);
        assert_eq!(line, "import \"fmt\"");
    }

    #[test]
    fn generate_go_grouped_spec() {
        let line = generate_go_import_line("fmt", None, true);
        assert_eq!(line, "\t\"fmt\"");
    }

    #[test]
    fn generate_go_with_alias() {
        let line = generate_go_import_line("github.com/pkg/errors", Some("errs"), false);
        assert_eq!(line, "import errs \"github.com/pkg/errors\"");
    }
}