Skip to main content

aft/
imports.rs

1//! Import analysis engine: parsing, grouping, deduplication, and insertion.
2//!
3//! Provides per-language import handling dispatched by `LangId`. Each language
4//! implementation extracts imports from tree-sitter ASTs, classifies them into
5//! groups, and generates import text.
6//!
//! Currently supports: TypeScript, TSX, JavaScript, Python, Rust, and Go. Other
//! languages (C, C++, Zig, C#, Markdown) yield an empty import block.
8
9use std::ops::Range;
10
11use tree_sitter::{Node, Parser, Tree};
12
13use crate::parser::{grammar_for, LangId};
14
15// ---------------------------------------------------------------------------
16// Shared types
17// ---------------------------------------------------------------------------
18
/// Classification of an import statement by what it brings into scope.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ImportKind {
    /// A regular value import, e.g. `import { X } from 'y'` or `import X from 'y'`.
    Value,
    /// A type-only import, e.g. `import type { X } from 'y'`.
    Type,
    /// An import that binds nothing and runs only for effect, e.g. `import './side-effect'`.
    SideEffect,
}
29
/// Logical sorting bucket for an import; the meaning is language-specific.
///
/// The derived ordering follows the conventional layout of an import block:
///   Stdlib (first) < External < Internal (last)
///
/// How each language maps onto the buckets:
///   - TS/JS/TSX: External (no `.` prefix), Internal (`.`/`..` prefix)
///   - Python:    Stdlib, External (third-party), Internal (relative `.`/`..`)
///   - Rust:      Stdlib (std/core/alloc), External (crates), Internal (crate/self/super)
///   - Go:        Stdlib (no dots in path), External (dots in path)
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum ImportGroup {
    /// Standard-library modules (Python stdlib, Rust std/core/alloc, Go stdlib).
    /// Not produced for TS/JS.
    Stdlib,
    /// Third-party / external packages.
    External,
    /// Project-local imports (TS relative, Python relative, Rust crate/self/super).
    Internal,
}
50
51impl ImportGroup {
52    /// Human-readable label for the group.
53    pub fn label(&self) -> &'static str {
54        match self {
55            ImportGroup::Stdlib => "stdlib",
56            ImportGroup::External => "external",
57            ImportGroup::Internal => "internal",
58        }
59    }
60}
61
62/// A single parsed import statement.
63#[derive(Debug, Clone)]
64pub struct ImportStatement {
65    /// The module path (e.g., `react`, `./utils`, `../config`).
66    pub module_path: String,
67    /// Named imports (e.g., `["useState", "useEffect"]`).
68    pub names: Vec<String>,
69    /// Default import name (e.g., `React` from `import React from 'react'`).
70    pub default_import: Option<String>,
71    /// Namespace import name (e.g., `path` from `import * as path from 'path'`).
72    pub namespace_import: Option<String>,
73    /// What kind: value, type, or side-effect.
74    pub kind: ImportKind,
75    /// Which group this import belongs to.
76    pub group: ImportGroup,
77    /// Byte range in the original source.
78    pub byte_range: Range<usize>,
79    /// Raw text of the import statement.
80    pub raw_text: String,
81}
82
83/// A block of parsed imports from a file.
84#[derive(Debug, Clone)]
85pub struct ImportBlock {
86    /// All parsed import statements, in source order.
87    pub imports: Vec<ImportStatement>,
88    /// Overall byte range covering all import statements (start of first to end of last).
89    /// `None` if no imports found.
90    pub byte_range: Option<Range<usize>>,
91}
92
93impl ImportBlock {
94    pub fn empty() -> Self {
95        ImportBlock {
96            imports: Vec::new(),
97            byte_range: None,
98        }
99    }
100}
101
102fn import_byte_range(imports: &[ImportStatement]) -> Option<Range<usize>> {
103    imports.first().zip(imports.last()).map(|(first, last)| {
104        let start = first.byte_range.start;
105        let end = last.byte_range.end;
106        start..end
107    })
108}
109
110// ---------------------------------------------------------------------------
111// Core API
112// ---------------------------------------------------------------------------
113
114/// Parse imports from source using the provided tree-sitter tree.
115pub fn parse_imports(source: &str, tree: &Tree, lang: LangId) -> ImportBlock {
116    match lang {
117        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => parse_ts_imports(source, tree),
118        LangId::Python => parse_py_imports(source, tree),
119        LangId::Rust => parse_rs_imports(source, tree),
120        LangId::Go => parse_go_imports(source, tree),
121        LangId::C | LangId::Cpp | LangId::Zig | LangId::CSharp => ImportBlock::empty(),
122        LangId::Markdown => ImportBlock::empty(),
123    }
124}
125
126/// Check if an import with the given module + name combination already exists.
127///
128/// For dedup: same module path AND (same named import OR same default import).
129/// Side-effect imports match on module path alone.
130pub fn is_duplicate(
131    block: &ImportBlock,
132    module_path: &str,
133    names: &[String],
134    default_import: Option<&str>,
135    type_only: bool,
136) -> bool {
137    let target_kind = if type_only {
138        ImportKind::Type
139    } else {
140        ImportKind::Value
141    };
142
143    for imp in &block.imports {
144        if imp.module_path != module_path {
145            continue;
146        }
147
148        // For side-effect imports or whole-module imports (no names, no default):
149        // module path match alone is sufficient.
150        if names.is_empty()
151            && default_import.is_none()
152            && imp.names.is_empty()
153            && imp.default_import.is_none()
154        {
155            return true;
156        }
157
158        // For side-effect imports specifically (TS/JS): module match is enough
159        if names.is_empty() && default_import.is_none() && imp.kind == ImportKind::SideEffect {
160            return true;
161        }
162
163        // Kind must match for dedup (value imports don't dedup against type imports)
164        if imp.kind != target_kind && imp.kind != ImportKind::SideEffect {
165            continue;
166        }
167
168        // Check default import match
169        if let Some(def) = default_import {
170            if imp.default_import.as_deref() == Some(def) {
171                return true;
172            }
173        }
174
175        // Check named imports — if ALL requested names already exist
176        if !names.is_empty() && names.iter().all(|n| imp.names.contains(n)) {
177            return true;
178        }
179    }
180
181    false
182}
183
184/// Find the byte offset where a new import should be inserted.
185///
186/// Strategy:
187/// - Find all existing imports in the same group.
188/// - Within that group, find the alphabetical position by module path.
189/// - Type imports sort after value imports within the same group and module-sort position.
190/// - If no imports exist in the target group, insert after the last import of the
191///   nearest preceding group (or before the first import of the nearest following
192///   group, or at file start if no groups exist).
193/// - Returns (byte_offset, needs_newline_before, needs_newline_after)
194pub fn find_insertion_point(
195    source: &str,
196    block: &ImportBlock,
197    group: ImportGroup,
198    module_path: &str,
199    type_only: bool,
200) -> (usize, bool, bool) {
201    if block.imports.is_empty() {
202        // No imports at all — insert at start of file
203        return (0, false, source.is_empty().then_some(false).unwrap_or(true));
204    }
205
206    let target_kind = if type_only {
207        ImportKind::Type
208    } else {
209        ImportKind::Value
210    };
211
212    // Collect imports in the target group
213    let group_imports: Vec<&ImportStatement> =
214        block.imports.iter().filter(|i| i.group == group).collect();
215
216    if group_imports.is_empty() {
217        // No imports in this group yet — find nearest neighbor group
218        // Try preceding groups (lower ordinal) first
219        let preceding_last = block.imports.iter().filter(|i| i.group < group).last();
220
221        if let Some(last) = preceding_last {
222            let end = last.byte_range.end;
223            let insert_at = skip_newline(source, end);
224            return (insert_at, true, true);
225        }
226
227        // No preceding group — try following groups (higher ordinal)
228        let following_first = block.imports.iter().find(|i| i.group > group);
229
230        if let Some(first) = following_first {
231            return (first.byte_range.start, false, true);
232        }
233
234        // Shouldn't reach here if block is non-empty, but handle gracefully
235        let first_byte = import_byte_range(&block.imports)
236            .map(|range| range.start)
237            .unwrap_or(0);
238        return (first_byte, false, true);
239    }
240
241    // Find position within the group (alphabetical by module path, type after value)
242    for imp in &group_imports {
243        let cmp = module_path.cmp(&imp.module_path);
244        match cmp {
245            std::cmp::Ordering::Less => {
246                // Insert before this import
247                return (imp.byte_range.start, false, false);
248            }
249            std::cmp::Ordering::Equal => {
250                // Same module — type imports go after value imports
251                if target_kind == ImportKind::Type && imp.kind == ImportKind::Value {
252                    // Insert after this value import
253                    let end = imp.byte_range.end;
254                    let insert_at = skip_newline(source, end);
255                    return (insert_at, false, false);
256                }
257                // Insert before (or it's a duplicate, caller should have checked)
258                return (imp.byte_range.start, false, false);
259            }
260            std::cmp::Ordering::Greater => continue,
261        }
262    }
263
264    // Module path sorts after all existing imports in this group — insert at end
265    let Some(last) = group_imports.last() else {
266        return (
267            import_byte_range(&block.imports)
268                .map(|range| range.end)
269                .unwrap_or(0),
270            false,
271            false,
272        );
273    };
274    let end = last.byte_range.end;
275    let insert_at = skip_newline(source, end);
276    (insert_at, false, false)
277}
278
279/// Generate an import line for the given language.
280pub fn generate_import_line(
281    lang: LangId,
282    module_path: &str,
283    names: &[String],
284    default_import: Option<&str>,
285    type_only: bool,
286) -> String {
287    match lang {
288        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
289            generate_ts_import_line(module_path, names, default_import, type_only)
290        }
291        LangId::Python => generate_py_import_line(module_path, names, default_import),
292        LangId::Rust => generate_rs_import_line(module_path, names, type_only),
293        LangId::Go => generate_go_import_line(module_path, default_import, false),
294        LangId::C | LangId::Cpp | LangId::Zig | LangId::CSharp => String::new(),
295        LangId::Markdown => String::new(),
296    }
297}
298
299/// Check if the given language is supported by the import engine.
300pub fn is_supported(lang: LangId) -> bool {
301    matches!(
302        lang,
303        LangId::TypeScript
304            | LangId::Tsx
305            | LangId::JavaScript
306            | LangId::Python
307            | LangId::Rust
308            | LangId::Go
309    )
310}
311
312/// Classify a module path into a group for TS/JS/TSX.
313pub fn classify_group_ts(module_path: &str) -> ImportGroup {
314    if module_path.starts_with('.') {
315        ImportGroup::Internal
316    } else {
317        ImportGroup::External
318    }
319}
320
321/// Classify a module path into a group for the given language.
322pub fn classify_group(lang: LangId, module_path: &str) -> ImportGroup {
323    match lang {
324        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => classify_group_ts(module_path),
325        LangId::Python => classify_group_py(module_path),
326        LangId::Rust => classify_group_rs(module_path),
327        LangId::Go => classify_group_go(module_path),
328        LangId::C | LangId::Cpp | LangId::Zig | LangId::CSharp => ImportGroup::External,
329        LangId::Markdown => ImportGroup::External,
330    }
331}
332
333/// Parse a file from disk and return its import block.
334/// Convenience wrapper that handles parsing.
335pub fn parse_file_imports(
336    path: &std::path::Path,
337    lang: LangId,
338) -> Result<(String, Tree, ImportBlock), crate::error::AftError> {
339    let source =
340        std::fs::read_to_string(path).map_err(|e| crate::error::AftError::FileNotFound {
341            path: format!("{}: {}", path.display(), e),
342        })?;
343
344    let grammar = grammar_for(lang);
345    let mut parser = Parser::new();
346    parser
347        .set_language(&grammar)
348        .map_err(|e| crate::error::AftError::ParseError {
349            message: format!("grammar init failed for {:?}: {}", lang, e),
350        })?;
351
352    let tree = parser
353        .parse(&source, None)
354        .ok_or_else(|| crate::error::AftError::ParseError {
355            message: format!("tree-sitter parse returned None for {}", path.display()),
356        })?;
357
358    let block = parse_imports(&source, &tree, lang);
359    Ok((source, tree, block))
360}
361
362// ---------------------------------------------------------------------------
363// TS/JS/TSX implementation
364// ---------------------------------------------------------------------------
365
366/// Parse imports from a TS/JS/TSX file.
367///
368/// Walks the AST root's direct children looking for `import_statement` nodes (D041).
369fn parse_ts_imports(source: &str, tree: &Tree) -> ImportBlock {
370    let root = tree.root_node();
371    let mut imports = Vec::new();
372
373    let mut cursor = root.walk();
374    if !cursor.goto_first_child() {
375        return ImportBlock::empty();
376    }
377
378    loop {
379        let node = cursor.node();
380        if node.kind() == "import_statement" {
381            if let Some(imp) = parse_single_ts_import(source, &node) {
382                imports.push(imp);
383            }
384        }
385        if !cursor.goto_next_sibling() {
386            break;
387        }
388    }
389
390    let byte_range = import_byte_range(&imports);
391
392    ImportBlock {
393        imports,
394        byte_range,
395    }
396}
397
398/// Parse a single `import_statement` node into an `ImportStatement`.
399fn parse_single_ts_import(source: &str, node: &Node) -> Option<ImportStatement> {
400    let raw_text = source[node.byte_range()].to_string();
401    let byte_range = node.byte_range();
402
403    // Find the source module (string/string_fragment child of the import)
404    let module_path = extract_module_path(source, node)?;
405
406    // Determine if this is a type-only import: `import type ...`
407    let is_type_only = has_type_keyword(node);
408
409    // Extract import clause details
410    let mut names = Vec::new();
411    let mut default_import = None;
412    let mut namespace_import = None;
413
414    let mut child_cursor = node.walk();
415    if child_cursor.goto_first_child() {
416        loop {
417            let child = child_cursor.node();
418            match child.kind() {
419                "import_clause" => {
420                    extract_import_clause(
421                        source,
422                        &child,
423                        &mut names,
424                        &mut default_import,
425                        &mut namespace_import,
426                    );
427                }
428                // In some grammars, the default import is a direct identifier child
429                "identifier" => {
430                    let text = &source[child.byte_range()];
431                    if text != "import" && text != "from" && text != "type" {
432                        default_import = Some(text.to_string());
433                    }
434                }
435                _ => {}
436            }
437            if !child_cursor.goto_next_sibling() {
438                break;
439            }
440        }
441    }
442
443    // Classify kind
444    let kind = if names.is_empty() && default_import.is_none() && namespace_import.is_none() {
445        ImportKind::SideEffect
446    } else if is_type_only {
447        ImportKind::Type
448    } else {
449        ImportKind::Value
450    };
451
452    let group = classify_group_ts(&module_path);
453
454    Some(ImportStatement {
455        module_path,
456        names,
457        default_import,
458        namespace_import,
459        kind,
460        group,
461        byte_range,
462        raw_text,
463    })
464}
465
466/// Extract the module path string from an import_statement node.
467///
468/// Looks for a `string` child node and extracts the content without quotes.
469fn extract_module_path(source: &str, node: &Node) -> Option<String> {
470    let mut cursor = node.walk();
471    if !cursor.goto_first_child() {
472        return None;
473    }
474
475    loop {
476        let child = cursor.node();
477        if child.kind() == "string" {
478            // Get the text and strip quotes
479            let text = &source[child.byte_range()];
480            let stripped = text
481                .trim_start_matches(|c| c == '\'' || c == '"')
482                .trim_end_matches(|c| c == '\'' || c == '"');
483            return Some(stripped.to_string());
484        }
485        if !cursor.goto_next_sibling() {
486            break;
487        }
488    }
489    None
490}
491
492/// Check if the import_statement has a `type` keyword (import type ...).
493///
494/// In tree-sitter-typescript, `import type { X } from 'y'` produces a `type`
495/// node as a direct child of `import_statement`, between `import` and `import_clause`.
496fn has_type_keyword(node: &Node) -> bool {
497    let mut cursor = node.walk();
498    if !cursor.goto_first_child() {
499        return false;
500    }
501
502    loop {
503        let child = cursor.node();
504        if child.kind() == "type" {
505            return true;
506        }
507        if !cursor.goto_next_sibling() {
508            break;
509        }
510    }
511
512    false
513}
514
515/// Extract named imports, default import, and namespace import from an import_clause.
516fn extract_import_clause(
517    source: &str,
518    node: &Node,
519    names: &mut Vec<String>,
520    default_import: &mut Option<String>,
521    namespace_import: &mut Option<String>,
522) {
523    let mut cursor = node.walk();
524    if !cursor.goto_first_child() {
525        return;
526    }
527
528    loop {
529        let child = cursor.node();
530        match child.kind() {
531            "identifier" => {
532                // This is a default import: `import Foo from 'bar'`
533                let text = &source[child.byte_range()];
534                if text != "type" {
535                    *default_import = Some(text.to_string());
536                }
537            }
538            "named_imports" => {
539                // `{ name1, name2 }`
540                extract_named_imports(source, &child, names);
541            }
542            "namespace_import" => {
543                // `* as name`
544                extract_namespace_import(source, &child, namespace_import);
545            }
546            _ => {}
547        }
548        if !cursor.goto_next_sibling() {
549            break;
550        }
551    }
552}
553
554/// Extract individual names from a named_imports node (`{ a, b, c }`).
555fn extract_named_imports(source: &str, node: &Node, names: &mut Vec<String>) {
556    let mut cursor = node.walk();
557    if !cursor.goto_first_child() {
558        return;
559    }
560
561    loop {
562        let child = cursor.node();
563        if child.kind() == "import_specifier" {
564            // import_specifier can have `name` (the imported name) and optional `alias`
565            if let Some(name_node) = child.child_by_field_name("name") {
566                names.push(source[name_node.byte_range()].to_string());
567            } else {
568                // Fallback: first identifier child
569                let mut spec_cursor = child.walk();
570                if spec_cursor.goto_first_child() {
571                    loop {
572                        if spec_cursor.node().kind() == "identifier"
573                            || spec_cursor.node().kind() == "type_identifier"
574                        {
575                            names.push(source[spec_cursor.node().byte_range()].to_string());
576                            break;
577                        }
578                        if !spec_cursor.goto_next_sibling() {
579                            break;
580                        }
581                    }
582                }
583            }
584        }
585        if !cursor.goto_next_sibling() {
586            break;
587        }
588    }
589}
590
591/// Extract the alias name from a namespace_import node (`* as name`).
592fn extract_namespace_import(source: &str, node: &Node, namespace_import: &mut Option<String>) {
593    let mut cursor = node.walk();
594    if !cursor.goto_first_child() {
595        return;
596    }
597
598    loop {
599        let child = cursor.node();
600        if child.kind() == "identifier" {
601            *namespace_import = Some(source[child.byte_range()].to_string());
602            return;
603        }
604        if !cursor.goto_next_sibling() {
605            break;
606        }
607    }
608}
609
/// Generate an import statement for TS/JS/TSX.
///
/// Named imports are emitted in sorted order. `type_only` adds the `type` modifier
/// (`import type ...`); it is not applied to side-effect imports, which bind nothing.
///
/// TypeScript rejects a type-only import that has both a default and named bindings
/// ("A type-only import can specify a default import or named bindings, but not
/// both"). For that combination, two statements separated by `\n` are returned:
/// the default import first, then the named imports.
fn generate_ts_import_line(
    module_path: &str,
    names: &[String],
    default_import: Option<&str>,
    type_only: bool,
) -> String {
    let type_prefix = if type_only { "type " } else { "" };

    // `{ a, b }` with the names sorted.
    let named_clause = || {
        let mut sorted: Vec<&str> = names.iter().map(String::as_str).collect();
        sorted.sort_unstable();
        format!("{{ {} }}", sorted.join(", "))
    };

    match (default_import, names.is_empty()) {
        // Side-effect import.
        (None, true) => format!("import '{module_path}';"),
        // Default import only.
        (Some(def), true) => format!("import {type_prefix}{def} from '{module_path}';"),
        // Named imports only.
        (None, false) => {
            let named = named_clause();
            format!("import {type_prefix}{named} from '{module_path}';")
        }
        // Type-only default + named: TypeScript requires separate statements.
        (Some(def), false) if type_only => {
            let named = named_clause();
            format!(
                "import type {def} from '{module_path}';\n\
                 import type {named} from '{module_path}';"
            )
        }
        // Value default + named.
        (Some(def), false) => {
            let named = named_clause();
            format!("import {def}, {named} from '{module_path}';")
        }
    }
}
650
651// ---------------------------------------------------------------------------
652// Python implementation
653// ---------------------------------------------------------------------------
654
/// Python 3.x standard-library top-level module names, used for import group
/// classification.
///
/// Besides the common modules, this includes platform-specific modules (e.g. `winreg`,
/// `msvcrt`, `posix`) and modules removed in recent releases (e.g. the PEP 594
/// "dead batteries" such as `telnetlib` or `uu`). That way code targeting older
/// interpreters is still grouped correctly. Unknown modules are assumed third-party.
const PYTHON_STDLIB: &[&str] = &[
    "__future__",
    "_thread",
    "abc",
    "aifc",
    "argparse",
    "array",
    "ast",
    "asynchat",
    "asyncio",
    "asyncore",
    "atexit",
    "audioop",
    "base64",
    "bdb",
    "binascii",
    "binhex",
    "bisect",
    "builtins",
    "bz2",
    "calendar",
    "cgi",
    "cgitb",
    "chunk",
    "cmath",
    "cmd",
    "code",
    "codecs",
    "codeop",
    "collections",
    "colorsys",
    "compileall",
    "concurrent",
    "configparser",
    "contextlib",
    "contextvars",
    "copy",
    "copyreg",
    "cProfile",
    "crypt",
    "csv",
    "ctypes",
    "curses",
    "dataclasses",
    "datetime",
    "dbm",
    "decimal",
    "difflib",
    "dis",
    "distutils",
    "doctest",
    "email",
    "encodings",
    "ensurepip",
    "enum",
    "errno",
    "faulthandler",
    "fcntl",
    "filecmp",
    "fileinput",
    "fnmatch",
    "fractions",
    "ftplib",
    "functools",
    "gc",
    "getopt",
    "getpass",
    "gettext",
    "glob",
    "graphlib",
    "grp",
    "gzip",
    "hashlib",
    "heapq",
    "hmac",
    "html",
    "http",
    "idlelib",
    "imaplib",
    "imghdr",
    "imp",
    "importlib",
    "inspect",
    "io",
    "ipaddress",
    "itertools",
    "json",
    "keyword",
    "lib2to3",
    "linecache",
    "locale",
    "logging",
    "lzma",
    "mailbox",
    "mailcap",
    "marshal",
    "math",
    "mimetypes",
    "mmap",
    "modulefinder",
    "msilib",
    "msvcrt",
    "multiprocessing",
    "netrc",
    "nis",
    "nntplib",
    "nt",
    "ntpath",
    "numbers",
    "opcode",
    "operator",
    "optparse",
    "os",
    "ossaudiodev",
    "pathlib",
    "pdb",
    "pickle",
    "pickletools",
    "pipes",
    "pkgutil",
    "platform",
    "plistlib",
    "poplib",
    "posix",
    "posixpath",
    "pprint",
    "profile",
    "pstats",
    "pty",
    "pwd",
    "py_compile",
    "pyclbr",
    "pydoc",
    "pyexpat",
    "queue",
    "quopri",
    "random",
    "re",
    "readline",
    "reprlib",
    "resource",
    "rlcompleter",
    "runpy",
    "sched",
    "secrets",
    "select",
    "selectors",
    "shelve",
    "shlex",
    "shutil",
    "signal",
    "site",
    "smtpd",
    "smtplib",
    "sndhdr",
    "socket",
    "socketserver",
    "spwd",
    "sqlite3",
    "ssl",
    "stat",
    "statistics",
    "string",
    "stringprep",
    "struct",
    "subprocess",
    "sunau",
    "symtable",
    "sys",
    "sysconfig",
    "syslog",
    "tabnanny",
    "tarfile",
    "telnetlib",
    "tempfile",
    "termios",
    "textwrap",
    "threading",
    "time",
    "timeit",
    "tkinter",
    "token",
    "tokenize",
    "tomllib",
    "trace",
    "traceback",
    "tracemalloc",
    "tty",
    "turtle",
    "types",
    "typing",
    "unicodedata",
    "unittest",
    "urllib",
    "uu",
    "uuid",
    "venv",
    "warnings",
    "wave",
    "weakref",
    "webbrowser",
    "winreg",
    "winsound",
    "wsgiref",
    "xdrlib",
    "xml",
    "xmlrpc",
    "zipapp",
    "zipfile",
    "zipimport",
    "zlib",
    "zoneinfo",
];
848
849/// Classify a Python import into a group.
850pub fn classify_group_py(module_path: &str) -> ImportGroup {
851    // Relative imports start with '.'
852    if module_path.starts_with('.') {
853        return ImportGroup::Internal;
854    }
855    // Check stdlib: use the top-level module name (before first '.')
856    let top_module = module_path.split('.').next().unwrap_or(module_path);
857    if PYTHON_STDLIB.contains(&top_module) {
858        ImportGroup::Stdlib
859    } else {
860        ImportGroup::External
861    }
862}
863
864/// Parse imports from a Python file.
865fn parse_py_imports(source: &str, tree: &Tree) -> ImportBlock {
866    let root = tree.root_node();
867    let mut imports = Vec::new();
868
869    let mut cursor = root.walk();
870    if !cursor.goto_first_child() {
871        return ImportBlock::empty();
872    }
873
874    loop {
875        let node = cursor.node();
876        match node.kind() {
877            "import_statement" => {
878                if let Some(imp) = parse_py_import_statement(source, &node) {
879                    imports.push(imp);
880                }
881            }
882            "import_from_statement" => {
883                if let Some(imp) = parse_py_import_from_statement(source, &node) {
884                    imports.push(imp);
885                }
886            }
887            _ => {}
888        }
889        if !cursor.goto_next_sibling() {
890            break;
891        }
892    }
893
894    let byte_range = import_byte_range(&imports);
895
896    ImportBlock {
897        imports,
898        byte_range,
899    }
900}
901
902/// Parse `import X` or `import X.Y` Python statements.
903fn parse_py_import_statement(source: &str, node: &Node) -> Option<ImportStatement> {
904    let raw_text = source[node.byte_range()].to_string();
905    let byte_range = node.byte_range();
906
907    // Find the dotted_name child (the module name)
908    let mut module_path = String::new();
909    let mut c = node.walk();
910    if c.goto_first_child() {
911        loop {
912            if c.node().kind() == "dotted_name" {
913                module_path = source[c.node().byte_range()].to_string();
914                break;
915            }
916            if !c.goto_next_sibling() {
917                break;
918            }
919        }
920    }
921    if module_path.is_empty() {
922        return None;
923    }
924
925    let group = classify_group_py(&module_path);
926
927    Some(ImportStatement {
928        module_path,
929        names: Vec::new(),
930        default_import: None,
931        namespace_import: None,
932        kind: ImportKind::Value,
933        group,
934        byte_range,
935        raw_text,
936    })
937}
938
939/// Parse `from X import Y, Z` or `from . import Y` Python statements.
940fn parse_py_import_from_statement(source: &str, node: &Node) -> Option<ImportStatement> {
941    let raw_text = source[node.byte_range()].to_string();
942    let byte_range = node.byte_range();
943
944    let mut module_path = String::new();
945    let mut names = Vec::new();
946
947    let mut c = node.walk();
948    if c.goto_first_child() {
949        loop {
950            let child = c.node();
951            match child.kind() {
952                "dotted_name" => {
953                    // Could be the module name or an imported name
954                    // The module name comes right after `from`, imported names come after `import`
955                    // Use position: if we haven't set module_path yet and this comes
956                    // before the `import` keyword, it's the module.
957                    if module_path.is_empty()
958                        && !has_seen_import_keyword(source, node, child.start_byte())
959                    {
960                        module_path = source[child.byte_range()].to_string();
961                    } else {
962                        // It's an imported name
963                        names.push(source[child.byte_range()].to_string());
964                    }
965                }
966                "relative_import" => {
967                    // from . import X or from ..module import X
968                    module_path = source[child.byte_range()].to_string();
969                }
970                _ => {}
971            }
972            if !c.goto_next_sibling() {
973                break;
974            }
975        }
976    }
977
978    // module_path must be non-empty for a valid import
979    if module_path.is_empty() {
980        return None;
981    }
982
983    let group = classify_group_py(&module_path);
984
985    Some(ImportStatement {
986        module_path,
987        names,
988        default_import: None,
989        namespace_import: None,
990        kind: ImportKind::Value,
991        group,
992        byte_range,
993        raw_text,
994    })
995}
996
997/// Check if the `import` keyword appears before the given byte position in a from...import node.
998fn has_seen_import_keyword(_source: &str, parent: &Node, before_byte: usize) -> bool {
999    let mut c = parent.walk();
1000    if c.goto_first_child() {
1001        loop {
1002            let child = c.node();
1003            if child.kind() == "import" && child.start_byte() < before_byte {
1004                return true;
1005            }
1006            if child.start_byte() >= before_byte {
1007                return false;
1008            }
1009            if !c.goto_next_sibling() {
1010                break;
1011            }
1012        }
1013    }
1014    false
1015}
1016
/// Build a single Python import line.
///
/// With no `names`, emits `import <module_path>`. Otherwise emits
/// `from <module_path> import a, b, ...` with the names sorted lexicographically.
/// `_default_import` exists only for signature parity with other languages and is ignored.
fn generate_py_import_line(
    module_path: &str,
    names: &[String],
    _default_import: Option<&str>,
) -> String {
    if names.is_empty() {
        return format!("import {module_path}");
    }
    let mut ordered: Vec<&str> = names.iter().map(String::as_str).collect();
    ordered.sort();
    format!("from {module_path} import {}", ordered.join(", "))
}
1034
1035// ---------------------------------------------------------------------------
1036// Rust implementation
1037// ---------------------------------------------------------------------------
1038
1039/// Classify a Rust use path into a group.
1040pub fn classify_group_rs(module_path: &str) -> ImportGroup {
1041    // Extract the first path segment (before ::)
1042    let first_seg = module_path.split("::").next().unwrap_or(module_path);
1043    match first_seg {
1044        "std" | "core" | "alloc" => ImportGroup::Stdlib,
1045        "crate" | "self" | "super" => ImportGroup::Internal,
1046        _ => ImportGroup::External,
1047    }
1048}
1049
1050/// Parse imports from a Rust file.
1051fn parse_rs_imports(source: &str, tree: &Tree) -> ImportBlock {
1052    let root = tree.root_node();
1053    let mut imports = Vec::new();
1054
1055    let mut cursor = root.walk();
1056    if !cursor.goto_first_child() {
1057        return ImportBlock::empty();
1058    }
1059
1060    loop {
1061        let node = cursor.node();
1062        if node.kind() == "use_declaration" {
1063            if let Some(imp) = parse_rs_use_declaration(source, &node) {
1064                imports.push(imp);
1065            }
1066        }
1067        if !cursor.goto_next_sibling() {
1068            break;
1069        }
1070    }
1071
1072    let byte_range = import_byte_range(&imports);
1073
1074    ImportBlock {
1075        imports,
1076        byte_range,
1077    }
1078}
1079
1080/// Parse a single `use` declaration from Rust.
1081fn parse_rs_use_declaration(source: &str, node: &Node) -> Option<ImportStatement> {
1082    let raw_text = source[node.byte_range()].to_string();
1083    let byte_range = node.byte_range();
1084
1085    // Check for `pub` visibility modifier
1086    let mut has_pub = false;
1087    let mut use_path = String::new();
1088    let mut names = Vec::new();
1089
1090    let mut c = node.walk();
1091    if c.goto_first_child() {
1092        loop {
1093            let child = c.node();
1094            match child.kind() {
1095                "visibility_modifier" => {
1096                    has_pub = true;
1097                }
1098                "scoped_identifier" | "identifier" | "use_as_clause" => {
1099                    // Full path like `std::collections::HashMap` or just `serde`
1100                    use_path = source[child.byte_range()].to_string();
1101                }
1102                "scoped_use_list" => {
1103                    // e.g. `serde::{Deserialize, Serialize}`
1104                    use_path = source[child.byte_range()].to_string();
1105                    // Also extract the individual names from the use_list
1106                    extract_rs_use_list_names(source, &child, &mut names);
1107                }
1108                _ => {}
1109            }
1110            if !c.goto_next_sibling() {
1111                break;
1112            }
1113        }
1114    }
1115
1116    if use_path.is_empty() {
1117        return None;
1118    }
1119
1120    let group = classify_group_rs(&use_path);
1121
1122    Some(ImportStatement {
1123        module_path: use_path,
1124        names,
1125        default_import: if has_pub {
1126            Some("pub".to_string())
1127        } else {
1128            None
1129        },
1130        namespace_import: None,
1131        kind: ImportKind::Value,
1132        group,
1133        byte_range,
1134        raw_text,
1135    })
1136}
1137
1138/// Extract individual names from a Rust `scoped_use_list` node.
1139fn extract_rs_use_list_names(source: &str, node: &Node, names: &mut Vec<String>) {
1140    let mut c = node.walk();
1141    if c.goto_first_child() {
1142        loop {
1143            let child = c.node();
1144            if child.kind() == "use_list" {
1145                // Walk into the use_list to find identifiers
1146                let mut lc = child.walk();
1147                if lc.goto_first_child() {
1148                    loop {
1149                        let lchild = lc.node();
1150                        if lchild.kind() == "identifier" || lchild.kind() == "scoped_identifier" {
1151                            names.push(source[lchild.byte_range()].to_string());
1152                        }
1153                        if !lc.goto_next_sibling() {
1154                            break;
1155                        }
1156                    }
1157                }
1158            }
1159            if !c.goto_next_sibling() {
1160                break;
1161            }
1162        }
1163    }
1164}
1165
/// Generate a Rust `use` line.
///
/// Always emits `use <module_path>;`. `module_path` is expected to already hold
/// the complete use tree, for example `std::fmt::Display` or
/// `serde::{Deserialize, Serialize}`. Because of that, `names` is not consulted
/// even when non-empty. `_type_only` has no meaning for Rust and is ignored.
fn generate_rs_import_line(module_path: &str, names: &[String], _type_only: bool) -> String {
    // NOTE(review): callers that pass item names separately must fold them into
    // `module_path` themselves; this function never expands `names` into `{...}`.
    let _ = names;
    format!("use {module_path};")
}
1177
1178// ---------------------------------------------------------------------------
1179// Go implementation
1180// ---------------------------------------------------------------------------
1181
1182/// Classify a Go import path into a group.
1183pub fn classify_group_go(module_path: &str) -> ImportGroup {
1184    // stdlib paths don't contain dots (e.g., "fmt", "os", "net/http")
1185    // external paths contain dots (e.g., "github.com/pkg/errors")
1186    if module_path.contains('.') {
1187        ImportGroup::External
1188    } else {
1189        ImportGroup::Stdlib
1190    }
1191}
1192
1193/// Parse imports from a Go file.
1194fn parse_go_imports(source: &str, tree: &Tree) -> ImportBlock {
1195    let root = tree.root_node();
1196    let mut imports = Vec::new();
1197
1198    let mut cursor = root.walk();
1199    if !cursor.goto_first_child() {
1200        return ImportBlock::empty();
1201    }
1202
1203    loop {
1204        let node = cursor.node();
1205        if node.kind() == "import_declaration" {
1206            parse_go_import_declaration(source, &node, &mut imports);
1207        }
1208        if !cursor.goto_next_sibling() {
1209            break;
1210        }
1211    }
1212
1213    let byte_range = import_byte_range(&imports);
1214
1215    ImportBlock {
1216        imports,
1217        byte_range,
1218    }
1219}
1220
1221/// Parse a single Go import_declaration (may contain one or multiple specs).
1222fn parse_go_import_declaration(source: &str, node: &Node, imports: &mut Vec<ImportStatement>) {
1223    let mut c = node.walk();
1224    if c.goto_first_child() {
1225        loop {
1226            let child = c.node();
1227            match child.kind() {
1228                "import_spec" => {
1229                    if let Some(imp) = parse_go_import_spec(source, &child) {
1230                        imports.push(imp);
1231                    }
1232                }
1233                "import_spec_list" => {
1234                    // Grouped imports: walk into the list
1235                    let mut lc = child.walk();
1236                    if lc.goto_first_child() {
1237                        loop {
1238                            if lc.node().kind() == "import_spec" {
1239                                if let Some(imp) = parse_go_import_spec(source, &lc.node()) {
1240                                    imports.push(imp);
1241                                }
1242                            }
1243                            if !lc.goto_next_sibling() {
1244                                break;
1245                            }
1246                        }
1247                    }
1248                }
1249                _ => {}
1250            }
1251            if !c.goto_next_sibling() {
1252                break;
1253            }
1254        }
1255    }
1256}
1257
1258/// Parse a single Go import_spec node.
1259fn parse_go_import_spec(source: &str, node: &Node) -> Option<ImportStatement> {
1260    let raw_text = source[node.byte_range()].to_string();
1261    let byte_range = node.byte_range();
1262
1263    let mut import_path = String::new();
1264    let mut alias = None;
1265
1266    let mut c = node.walk();
1267    if c.goto_first_child() {
1268        loop {
1269            let child = c.node();
1270            match child.kind() {
1271                "interpreted_string_literal" => {
1272                    // Extract the path without quotes
1273                    let text = source[child.byte_range()].to_string();
1274                    import_path = text.trim_matches('"').to_string();
1275                }
1276                "identifier" | "blank_identifier" | "dot" => {
1277                    // This is an alias (e.g., `alias "path"` or `. "path"` or `_ "path"`)
1278                    alias = Some(source[child.byte_range()].to_string());
1279                }
1280                _ => {}
1281            }
1282            if !c.goto_next_sibling() {
1283                break;
1284            }
1285        }
1286    }
1287
1288    if import_path.is_empty() {
1289        return None;
1290    }
1291
1292    let group = classify_group_go(&import_path);
1293
1294    Some(ImportStatement {
1295        module_path: import_path,
1296        names: Vec::new(),
1297        default_import: alias,
1298        namespace_import: None,
1299        kind: ImportKind::Value,
1300        group,
1301        byte_range,
1302        raw_text,
1303    })
1304}
1305
1306/// Public API for Go import line generation (used by add_import handler).
1307pub fn generate_go_import_line_pub(
1308    module_path: &str,
1309    alias: Option<&str>,
1310    in_group: bool,
1311) -> String {
1312    generate_go_import_line(module_path, alias, in_group)
1313}
1314
/// Build a Go import line.
///
/// When `in_group` is set, the result is a spec for insertion inside an existing
/// `import ( ... )` block: a leading tab, the optional alias, then the quoted
/// path (`\t"fmt"`, `\terrs "github.com/pkg/errors"`). Otherwise the result is a
/// standalone declaration (`import "fmt"`, `import errs "github.com/pkg/errors"`).
fn generate_go_import_line(module_path: &str, alias: Option<&str>, in_group: bool) -> String {
    let prefix = if in_group { "\t" } else { "import " };
    let quoted = format!("\"{module_path}\"");
    match alias {
        Some(name) => format!("{prefix}{name} {quoted}"),
        None => format!("{prefix}{quoted}"),
    }
}
1334
1335/// Check if a Go import block has a grouped import declaration.
1336/// Returns the byte range of the import_spec_list if found.
1337pub fn go_has_grouped_import(_source: &str, tree: &Tree) -> Option<Range<usize>> {
1338    let root = tree.root_node();
1339    let mut cursor = root.walk();
1340    if !cursor.goto_first_child() {
1341        return None;
1342    }
1343
1344    loop {
1345        let node = cursor.node();
1346        if node.kind() == "import_declaration" {
1347            let mut c = node.walk();
1348            if c.goto_first_child() {
1349                loop {
1350                    if c.node().kind() == "import_spec_list" {
1351                        return Some(c.node().byte_range());
1352                    }
1353                    if !c.goto_next_sibling() {
1354                        break;
1355                    }
1356                }
1357            }
1358        }
1359        if !cursor.goto_next_sibling() {
1360            break;
1361        }
1362    }
1363    None
1364}
1365
/// Return the offset just past a line terminator that starts at `pos`.
///
/// Recognises `\n`, `\r\n` and a lone `\r`. If `pos` is out of bounds or does
/// not point at a line terminator, `pos` is returned unchanged.
fn skip_newline(source: &str, pos: usize) -> usize {
    let bytes = source.as_bytes();
    match bytes.get(pos) {
        Some(b'\n') => pos + 1,
        Some(b'\r') if bytes.get(pos + 1) == Some(&b'\n') => pos + 2,
        Some(b'\r') => pos + 1,
        _ => pos,
    }
}
1382
1383// ---------------------------------------------------------------------------
1384// Unit tests
1385// ---------------------------------------------------------------------------
1386
// Unit tests for import parsing, group classification, deduplication, line
// generation, and insertion-point lookup across TS/JS, Python, Rust and Go.
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `source` as TypeScript and run `parse_imports` on the resulting tree.
    fn parse_ts(source: &str) -> (Tree, ImportBlock) {
        let grammar = grammar_for(LangId::TypeScript);
        let mut parser = Parser::new();
        parser.set_language(&grammar).unwrap();
        let tree = parser.parse(source, None).unwrap();
        let block = parse_imports(source, &tree, LangId::TypeScript);
        (tree, block)
    }

    /// Parse `source` as JavaScript and run `parse_imports` on the resulting tree.
    fn parse_js(source: &str) -> (Tree, ImportBlock) {
        let grammar = grammar_for(LangId::JavaScript);
        let mut parser = Parser::new();
        parser.set_language(&grammar).unwrap();
        let tree = parser.parse(source, None).unwrap();
        let block = parse_imports(source, &tree, LangId::JavaScript);
        (tree, block)
    }

    // --- Basic parsing ---

    #[test]
    fn parse_ts_named_imports() {
        let source = "import { useState, useEffect } from 'react';\n";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 1);
        let imp = &block.imports[0];
        assert_eq!(imp.module_path, "react");
        assert!(imp.names.contains(&"useState".to_string()));
        assert!(imp.names.contains(&"useEffect".to_string()));
        assert_eq!(imp.kind, ImportKind::Value);
        assert_eq!(imp.group, ImportGroup::External);
    }

    #[test]
    fn parse_ts_default_import() {
        let source = "import React from 'react';\n";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 1);
        let imp = &block.imports[0];
        assert_eq!(imp.default_import.as_deref(), Some("React"));
        assert_eq!(imp.kind, ImportKind::Value);
    }

    #[test]
    fn parse_ts_side_effect_import() {
        let source = "import './styles.css';\n";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 1);
        assert_eq!(block.imports[0].kind, ImportKind::SideEffect);
        assert_eq!(block.imports[0].module_path, "./styles.css");
    }

    #[test]
    fn parse_ts_relative_import() {
        let source = "import { helper } from './utils';\n";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 1);
        assert_eq!(block.imports[0].group, ImportGroup::Internal);
    }

    #[test]
    fn parse_ts_multiple_groups() {
        let source = "\
import React from 'react';
import { useState } from 'react';
import { helper } from './utils';
import { Config } from '../config';
";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 4);

        // Two package imports (External) and two relative imports (Internal).
        let external: Vec<_> = block
            .imports
            .iter()
            .filter(|i| i.group == ImportGroup::External)
            .collect();
        let relative: Vec<_> = block
            .imports
            .iter()
            .filter(|i| i.group == ImportGroup::Internal)
            .collect();
        assert_eq!(external.len(), 2);
        assert_eq!(relative.len(), 2);
    }

    #[test]
    fn parse_ts_namespace_import() {
        let source = "import * as path from 'path';\n";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 1);
        let imp = &block.imports[0];
        assert_eq!(imp.namespace_import.as_deref(), Some("path"));
        assert_eq!(imp.kind, ImportKind::Value);
    }

    #[test]
    fn parse_js_imports() {
        let source = "import { readFile } from 'fs';\nimport { helper } from './helper';\n";
        let (_, block) = parse_js(source);
        assert_eq!(block.imports.len(), 2);
        assert_eq!(block.imports[0].group, ImportGroup::External);
        assert_eq!(block.imports[1].group, ImportGroup::Internal);
    }

    // --- Group classification ---

    #[test]
    fn classify_external() {
        assert_eq!(classify_group_ts("react"), ImportGroup::External);
        assert_eq!(classify_group_ts("@scope/pkg"), ImportGroup::External);
        assert_eq!(classify_group_ts("lodash/map"), ImportGroup::External);
    }

    #[test]
    fn classify_relative() {
        assert_eq!(classify_group_ts("./utils"), ImportGroup::Internal);
        assert_eq!(classify_group_ts("../config"), ImportGroup::Internal);
        assert_eq!(classify_group_ts("./"), ImportGroup::Internal);
    }

    // --- Dedup ---
    // `is_duplicate` arguments: block, module path, names, default import, type-only flag.

    #[test]
    fn dedup_detects_same_named_import() {
        let source = "import { useState } from 'react';\n";
        let (_, block) = parse_ts(source);
        assert!(is_duplicate(
            &block,
            "react",
            &["useState".to_string()],
            None,
            false
        ));
    }

    #[test]
    fn dedup_misses_different_name() {
        let source = "import { useState } from 'react';\n";
        let (_, block) = parse_ts(source);
        assert!(!is_duplicate(
            &block,
            "react",
            &["useEffect".to_string()],
            None,
            false
        ));
    }

    #[test]
    fn dedup_detects_default_import() {
        let source = "import React from 'react';\n";
        let (_, block) = parse_ts(source);
        assert!(is_duplicate(&block, "react", &[], Some("React"), false));
    }

    #[test]
    fn dedup_side_effect() {
        let source = "import './styles.css';\n";
        let (_, block) = parse_ts(source);
        assert!(is_duplicate(&block, "./styles.css", &[], None, false));
    }

    #[test]
    fn dedup_type_vs_value() {
        let source = "import { FC } from 'react';\n";
        let (_, block) = parse_ts(source);
        // Type import should NOT match a value import of the same name
        assert!(!is_duplicate(
            &block,
            "react",
            &["FC".to_string()],
            None,
            true
        ));
    }

    // --- Generation ---

    #[test]
    fn generate_named_import() {
        // Named imports are emitted in sorted order regardless of input order.
        let line = generate_import_line(
            LangId::TypeScript,
            "react",
            &["useState".to_string(), "useEffect".to_string()],
            None,
            false,
        );
        assert_eq!(line, "import { useEffect, useState } from 'react';");
    }

    #[test]
    fn generate_default_import() {
        let line = generate_import_line(LangId::TypeScript, "react", &[], Some("React"), false);
        assert_eq!(line, "import React from 'react';");
    }

    #[test]
    fn generate_type_import() {
        let line =
            generate_import_line(LangId::TypeScript, "react", &["FC".to_string()], None, true);
        assert_eq!(line, "import type { FC } from 'react';");
    }

    #[test]
    fn generate_side_effect_import() {
        let line = generate_import_line(LangId::TypeScript, "./styles.css", &[], None, false);
        assert_eq!(line, "import './styles.css';");
    }

    #[test]
    fn generate_default_and_named() {
        let line = generate_import_line(
            LangId::TypeScript,
            "react",
            &["useState".to_string()],
            Some("React"),
            false,
        );
        assert_eq!(line, "import React, { useState } from 'react';");
    }

    #[test]
    fn parse_ts_type_import() {
        let source = "import type { FC } from 'react';\n";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 1);
        let imp = &block.imports[0];
        assert_eq!(imp.kind, ImportKind::Type);
        assert!(imp.names.contains(&"FC".to_string()));
        assert_eq!(imp.group, ImportGroup::External);
    }

    // --- Insertion point ---

    #[test]
    fn insertion_empty_file() {
        // With no existing imports, the insertion point is the start of the file.
        let source = "";
        let (_, block) = parse_ts(source);
        let (offset, _, _) =
            find_insertion_point(source, &block, ImportGroup::External, "react", false);
        assert_eq!(offset, 0);
    }

    #[test]
    fn insertion_alphabetical_within_group() {
        let source = "\
import { a } from 'alpha';
import { c } from 'charlie';
";
        let (_, block) = parse_ts(source);
        let (offset, _, _) =
            find_insertion_point(source, &block, ImportGroup::External, "bravo", false);
        // Should insert before 'charlie' (which starts at line 2)
        let before_charlie = source.find("import { c }").unwrap();
        assert_eq!(offset, before_charlie);
    }

    // --- Python parsing ---

    /// Parse `source` as Python and run `parse_imports` on the resulting tree.
    fn parse_py(source: &str) -> (Tree, ImportBlock) {
        let grammar = grammar_for(LangId::Python);
        let mut parser = Parser::new();
        parser.set_language(&grammar).unwrap();
        let tree = parser.parse(source, None).unwrap();
        let block = parse_imports(source, &tree, LangId::Python);
        (tree, block)
    }

    #[test]
    fn parse_py_import_statement() {
        let source = "import os\nimport sys\n";
        let (_, block) = parse_py(source);
        assert_eq!(block.imports.len(), 2);
        assert_eq!(block.imports[0].module_path, "os");
        assert_eq!(block.imports[1].module_path, "sys");
        assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
    }

    #[test]
    fn parse_py_from_import() {
        let source = "from collections import OrderedDict\nfrom typing import List, Optional\n";
        let (_, block) = parse_py(source);
        assert_eq!(block.imports.len(), 2);
        assert_eq!(block.imports[0].module_path, "collections");
        assert!(block.imports[0].names.contains(&"OrderedDict".to_string()));
        assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
        assert_eq!(block.imports[1].module_path, "typing");
        assert!(block.imports[1].names.contains(&"List".to_string()));
        assert!(block.imports[1].names.contains(&"Optional".to_string()));
    }

    #[test]
    fn parse_py_relative_import() {
        let source = "from . import utils\nfrom ..config import Settings\n";
        let (_, block) = parse_py(source);
        assert_eq!(block.imports.len(), 2);
        assert_eq!(block.imports[0].module_path, ".");
        assert!(block.imports[0].names.contains(&"utils".to_string()));
        assert_eq!(block.imports[0].group, ImportGroup::Internal);
        assert_eq!(block.imports[1].module_path, "..config");
        assert_eq!(block.imports[1].group, ImportGroup::Internal);
    }

    #[test]
    fn classify_py_groups() {
        assert_eq!(classify_group_py("os"), ImportGroup::Stdlib);
        assert_eq!(classify_group_py("sys"), ImportGroup::Stdlib);
        assert_eq!(classify_group_py("json"), ImportGroup::Stdlib);
        assert_eq!(classify_group_py("collections"), ImportGroup::Stdlib);
        assert_eq!(classify_group_py("os.path"), ImportGroup::Stdlib);
        assert_eq!(classify_group_py("requests"), ImportGroup::External);
        assert_eq!(classify_group_py("flask"), ImportGroup::External);
        assert_eq!(classify_group_py("."), ImportGroup::Internal);
        assert_eq!(classify_group_py("..config"), ImportGroup::Internal);
        assert_eq!(classify_group_py(".utils"), ImportGroup::Internal);
    }

    #[test]
    fn parse_py_three_groups() {
        let source = "import os\nimport sys\n\nimport requests\n\nfrom . import utils\n";
        let (_, block) = parse_py(source);
        let stdlib: Vec<_> = block
            .imports
            .iter()
            .filter(|i| i.group == ImportGroup::Stdlib)
            .collect();
        let external: Vec<_> = block
            .imports
            .iter()
            .filter(|i| i.group == ImportGroup::External)
            .collect();
        let internal: Vec<_> = block
            .imports
            .iter()
            .filter(|i| i.group == ImportGroup::Internal)
            .collect();
        assert_eq!(stdlib.len(), 2);
        assert_eq!(external.len(), 1);
        assert_eq!(internal.len(), 1);
    }

    #[test]
    fn generate_py_import() {
        let line = generate_import_line(LangId::Python, "os", &[], None, false);
        assert_eq!(line, "import os");
    }

    #[test]
    fn generate_py_from_import() {
        let line = generate_import_line(
            LangId::Python,
            "collections",
            &["OrderedDict".to_string()],
            None,
            false,
        );
        assert_eq!(line, "from collections import OrderedDict");
    }

    #[test]
    fn generate_py_from_import_multiple() {
        // Names are sorted in the generated `from ... import ...` line.
        let line = generate_import_line(
            LangId::Python,
            "typing",
            &["Optional".to_string(), "List".to_string()],
            None,
            false,
        );
        assert_eq!(line, "from typing import List, Optional");
    }

    // --- Rust parsing ---

    /// Parse `source` as Rust and run `parse_imports` on the resulting tree.
    fn parse_rust(source: &str) -> (Tree, ImportBlock) {
        let grammar = grammar_for(LangId::Rust);
        let mut parser = Parser::new();
        parser.set_language(&grammar).unwrap();
        let tree = parser.parse(source, None).unwrap();
        let block = parse_imports(source, &tree, LangId::Rust);
        (tree, block)
    }

    #[test]
    fn parse_rs_use_std() {
        let source = "use std::collections::HashMap;\nuse std::io::Read;\n";
        let (_, block) = parse_rust(source);
        assert_eq!(block.imports.len(), 2);
        assert_eq!(block.imports[0].module_path, "std::collections::HashMap");
        assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
        assert_eq!(block.imports[1].group, ImportGroup::Stdlib);
    }

    #[test]
    fn parse_rs_use_external() {
        let source = "use serde::{Deserialize, Serialize};\n";
        let (_, block) = parse_rust(source);
        assert_eq!(block.imports.len(), 1);
        assert_eq!(block.imports[0].group, ImportGroup::External);
        assert!(block.imports[0].names.contains(&"Deserialize".to_string()));
        assert!(block.imports[0].names.contains(&"Serialize".to_string()));
    }

    #[test]
    fn parse_rs_use_crate() {
        let source = "use crate::config::Settings;\nuse super::parent::Thing;\n";
        let (_, block) = parse_rust(source);
        assert_eq!(block.imports.len(), 2);
        assert_eq!(block.imports[0].group, ImportGroup::Internal);
        assert_eq!(block.imports[1].group, ImportGroup::Internal);
    }

    #[test]
    fn parse_rs_pub_use() {
        let source = "pub use super::parent::Thing;\n";
        let (_, block) = parse_rust(source);
        assert_eq!(block.imports.len(), 1);
        // `pub` is stored in default_import as a marker
        assert_eq!(block.imports[0].default_import.as_deref(), Some("pub"));
    }

    #[test]
    fn classify_rs_groups() {
        assert_eq!(
            classify_group_rs("std::collections::HashMap"),
            ImportGroup::Stdlib
        );
        assert_eq!(classify_group_rs("core::mem"), ImportGroup::Stdlib);
        assert_eq!(classify_group_rs("alloc::vec"), ImportGroup::Stdlib);
        assert_eq!(
            classify_group_rs("serde::Deserialize"),
            ImportGroup::External
        );
        assert_eq!(classify_group_rs("tokio::runtime"), ImportGroup::External);
        assert_eq!(classify_group_rs("crate::config"), ImportGroup::Internal);
        assert_eq!(classify_group_rs("self::utils"), ImportGroup::Internal);
        assert_eq!(classify_group_rs("super::parent"), ImportGroup::Internal);
    }

    #[test]
    fn generate_rs_use() {
        let line = generate_import_line(LangId::Rust, "std::fmt::Display", &[], None, false);
        assert_eq!(line, "use std::fmt::Display;");
    }

    // --- Go parsing ---

    /// Parse `source` as Go and run `parse_imports` on the resulting tree.
    fn parse_go(source: &str) -> (Tree, ImportBlock) {
        let grammar = grammar_for(LangId::Go);
        let mut parser = Parser::new();
        parser.set_language(&grammar).unwrap();
        let tree = parser.parse(source, None).unwrap();
        let block = parse_imports(source, &tree, LangId::Go);
        (tree, block)
    }

    #[test]
    fn parse_go_single_import() {
        let source = "package main\n\nimport \"fmt\"\n";
        let (_, block) = parse_go(source);
        assert_eq!(block.imports.len(), 1);
        assert_eq!(block.imports[0].module_path, "fmt");
        assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
    }

    #[test]
    fn parse_go_grouped_import() {
        // A blank line inside the group must not break spec extraction.
        let source =
            "package main\n\nimport (\n\t\"fmt\"\n\t\"os\"\n\n\t\"github.com/pkg/errors\"\n)\n";
        let (_, block) = parse_go(source);
        assert_eq!(block.imports.len(), 3);
        assert_eq!(block.imports[0].module_path, "fmt");
        assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
        assert_eq!(block.imports[1].module_path, "os");
        assert_eq!(block.imports[1].group, ImportGroup::Stdlib);
        assert_eq!(block.imports[2].module_path, "github.com/pkg/errors");
        assert_eq!(block.imports[2].group, ImportGroup::External);
    }

    #[test]
    fn parse_go_mixed_imports() {
        // Single + grouped
        let source = "package main\n\nimport \"fmt\"\n\nimport (\n\t\"os\"\n\t\"github.com/pkg/errors\"\n)\n";
        let (_, block) = parse_go(source);
        assert_eq!(block.imports.len(), 3);
    }

    #[test]
    fn classify_go_groups() {
        assert_eq!(classify_group_go("fmt"), ImportGroup::Stdlib);
        assert_eq!(classify_group_go("os"), ImportGroup::Stdlib);
        assert_eq!(classify_group_go("net/http"), ImportGroup::Stdlib);
        assert_eq!(classify_group_go("encoding/json"), ImportGroup::Stdlib);
        assert_eq!(
            classify_group_go("github.com/pkg/errors"),
            ImportGroup::External
        );
        assert_eq!(
            classify_group_go("golang.org/x/tools"),
            ImportGroup::External
        );
    }

    #[test]
    fn generate_go_standalone() {
        let line = generate_go_import_line("fmt", None, false);
        assert_eq!(line, "import \"fmt\"");
    }

    #[test]
    fn generate_go_grouped_spec() {
        // In-group specs are tab-indented and carry no `import` keyword.
        let line = generate_go_import_line("fmt", None, true);
        assert_eq!(line, "\t\"fmt\"");
    }

    #[test]
    fn generate_go_with_alias() {
        let line = generate_go_import_line("github.com/pkg/errors", Some("errs"), false);
        assert_eq!(line, "import errs \"github.com/pkg/errors\"");
    }
}