Skip to main content

aft/
imports.rs

1//! Import analysis engine: parsing, grouping, deduplication, and insertion.
2//!
3//! Provides per-language import handling dispatched by `LangId`. Each language
4//! implementation extracts imports from tree-sitter ASTs, classifies them into
5//! groups, and generates import text.
6//!
//! Currently supports: TypeScript, TSX, JavaScript, Python, Rust, and Go.
8
9use std::ops::Range;
10
11use tree_sitter::{Node, Parser, Tree};
12
13use crate::parser::{grammar_for, LangId};
14
15// ---------------------------------------------------------------------------
16// Shared types
17// ---------------------------------------------------------------------------
18
/// Classification of an import statement by what it brings into scope.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum ImportKind {
    /// Runtime binding, e.g. `import { X } from 'y'` or `import X from 'y'`.
    Value,
    /// Type-only binding, e.g. `import type { X } from 'y'`.
    Type,
    /// Module evaluated only for its side effects, e.g. `import './side-effect'`.
    SideEffect,
}
29
/// Logical bucket an import is sorted into; the meaning is language-specific.
///
/// Variants are declared in sort order, so the derived `Ord` gives
/// `Stdlib < External < Internal`: standard library first, local code last.
///
/// Per-language mapping:
///   - TS/JS/TSX: External (no `.` prefix), Internal (`.`/`..` prefix)
///   - Python:    Stdlib, External (third-party), Internal (relative `.`/`..`)
///   - Rust:      Stdlib (std/core/alloc), External (crates), Internal (crate/self/super)
///   - Go:        Stdlib (no dots in path), External (dots in path)
#[derive(Debug, Clone, Copy, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub enum ImportGroup {
    /// Language standard library (Python stdlib, Rust std/core/alloc, Go stdlib).
    /// Never produced for TS/JS.
    Stdlib,
    /// Third-party / external packages.
    External,
    /// Project-local imports (TS relative, Python relative, Rust crate/self/super).
    Internal,
}

impl ImportGroup {
    /// Lowercase display name of the group.
    pub fn label(&self) -> &'static str {
        match *self {
            Self::Stdlib => "stdlib",
            Self::External => "external",
            Self::Internal => "internal",
        }
    }
}
61
62/// A single parsed import statement.
63#[derive(Debug, Clone)]
64pub struct ImportStatement {
65    /// The module path (e.g., `react`, `./utils`, `../config`).
66    pub module_path: String,
67    /// Named imports (e.g., `["useState", "useEffect"]`).
68    pub names: Vec<String>,
69    /// Default import name (e.g., `React` from `import React from 'react'`).
70    pub default_import: Option<String>,
71    /// Namespace import name (e.g., `path` from `import * as path from 'path'`).
72    pub namespace_import: Option<String>,
73    /// What kind: value, type, or side-effect.
74    pub kind: ImportKind,
75    /// Which group this import belongs to.
76    pub group: ImportGroup,
77    /// Byte range in the original source.
78    pub byte_range: Range<usize>,
79    /// Raw text of the import statement.
80    pub raw_text: String,
81}
82
83/// A block of parsed imports from a file.
84#[derive(Debug, Clone)]
85pub struct ImportBlock {
86    /// All parsed import statements, in source order.
87    pub imports: Vec<ImportStatement>,
88    /// Overall byte range covering all import statements (start of first to end of last).
89    /// `None` if no imports found.
90    pub byte_range: Option<Range<usize>>,
91}
92
93impl ImportBlock {
94    pub fn empty() -> Self {
95        ImportBlock {
96            imports: Vec::new(),
97            byte_range: None,
98        }
99    }
100}
101
102fn import_byte_range(imports: &[ImportStatement]) -> Option<Range<usize>> {
103    imports.first().zip(imports.last()).map(|(first, last)| {
104        let start = first.byte_range.start;
105        let end = last.byte_range.end;
106        start..end
107    })
108}
109
110// ---------------------------------------------------------------------------
111// Core API
112// ---------------------------------------------------------------------------
113
114/// Parse imports from source using the provided tree-sitter tree.
115pub fn parse_imports(source: &str, tree: &Tree, lang: LangId) -> ImportBlock {
116    match lang {
117        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => parse_ts_imports(source, tree),
118        LangId::Python => parse_py_imports(source, tree),
119        LangId::Rust => parse_rs_imports(source, tree),
120        LangId::Go => parse_go_imports(source, tree),
121        LangId::Markdown => ImportBlock::empty(),
122    }
123}
124
125/// Check if an import with the given module + name combination already exists.
126///
127/// For dedup: same module path AND (same named import OR same default import).
128/// Side-effect imports match on module path alone.
129pub fn is_duplicate(
130    block: &ImportBlock,
131    module_path: &str,
132    names: &[String],
133    default_import: Option<&str>,
134    type_only: bool,
135) -> bool {
136    let target_kind = if type_only {
137        ImportKind::Type
138    } else {
139        ImportKind::Value
140    };
141
142    for imp in &block.imports {
143        if imp.module_path != module_path {
144            continue;
145        }
146
147        // For side-effect imports or whole-module imports (no names, no default):
148        // module path match alone is sufficient.
149        if names.is_empty()
150            && default_import.is_none()
151            && imp.names.is_empty()
152            && imp.default_import.is_none()
153        {
154            return true;
155        }
156
157        // For side-effect imports specifically (TS/JS): module match is enough
158        if names.is_empty() && default_import.is_none() && imp.kind == ImportKind::SideEffect {
159            return true;
160        }
161
162        // Kind must match for dedup (value imports don't dedup against type imports)
163        if imp.kind != target_kind && imp.kind != ImportKind::SideEffect {
164            continue;
165        }
166
167        // Check default import match
168        if let Some(def) = default_import {
169            if imp.default_import.as_deref() == Some(def) {
170                return true;
171            }
172        }
173
174        // Check named imports — if ALL requested names already exist
175        if !names.is_empty() && names.iter().all(|n| imp.names.contains(n)) {
176            return true;
177        }
178    }
179
180    false
181}
182
183/// Find the byte offset where a new import should be inserted.
184///
185/// Strategy:
186/// - Find all existing imports in the same group.
187/// - Within that group, find the alphabetical position by module path.
188/// - Type imports sort after value imports within the same group and module-sort position.
189/// - If no imports exist in the target group, insert after the last import of the
190///   nearest preceding group (or before the first import of the nearest following
191///   group, or at file start if no groups exist).
192/// - Returns (byte_offset, needs_newline_before, needs_newline_after)
193pub fn find_insertion_point(
194    source: &str,
195    block: &ImportBlock,
196    group: ImportGroup,
197    module_path: &str,
198    type_only: bool,
199) -> (usize, bool, bool) {
200    if block.imports.is_empty() {
201        // No imports at all — insert at start of file
202        return (0, false, source.is_empty().then_some(false).unwrap_or(true));
203    }
204
205    let target_kind = if type_only {
206        ImportKind::Type
207    } else {
208        ImportKind::Value
209    };
210
211    // Collect imports in the target group
212    let group_imports: Vec<&ImportStatement> =
213        block.imports.iter().filter(|i| i.group == group).collect();
214
215    if group_imports.is_empty() {
216        // No imports in this group yet — find nearest neighbor group
217        // Try preceding groups (lower ordinal) first
218        let preceding_last = block.imports.iter().filter(|i| i.group < group).last();
219
220        if let Some(last) = preceding_last {
221            let end = last.byte_range.end;
222            let insert_at = skip_newline(source, end);
223            return (insert_at, true, true);
224        }
225
226        // No preceding group — try following groups (higher ordinal)
227        let following_first = block.imports.iter().find(|i| i.group > group);
228
229        if let Some(first) = following_first {
230            return (first.byte_range.start, false, true);
231        }
232
233        // Shouldn't reach here if block is non-empty, but handle gracefully
234        let first_byte = import_byte_range(&block.imports)
235            .map(|range| range.start)
236            .unwrap_or(0);
237        return (first_byte, false, true);
238    }
239
240    // Find position within the group (alphabetical by module path, type after value)
241    for imp in &group_imports {
242        let cmp = module_path.cmp(&imp.module_path);
243        match cmp {
244            std::cmp::Ordering::Less => {
245                // Insert before this import
246                return (imp.byte_range.start, false, false);
247            }
248            std::cmp::Ordering::Equal => {
249                // Same module — type imports go after value imports
250                if target_kind == ImportKind::Type && imp.kind == ImportKind::Value {
251                    // Insert after this value import
252                    let end = imp.byte_range.end;
253                    let insert_at = skip_newline(source, end);
254                    return (insert_at, false, false);
255                }
256                // Insert before (or it's a duplicate, caller should have checked)
257                return (imp.byte_range.start, false, false);
258            }
259            std::cmp::Ordering::Greater => continue,
260        }
261    }
262
263    // Module path sorts after all existing imports in this group — insert at end
264    let Some(last) = group_imports.last() else {
265        return (
266            import_byte_range(&block.imports)
267                .map(|range| range.end)
268                .unwrap_or(0),
269            false,
270            false,
271        );
272    };
273    let end = last.byte_range.end;
274    let insert_at = skip_newline(source, end);
275    (insert_at, false, false)
276}
277
278/// Generate an import line for the given language.
279pub fn generate_import_line(
280    lang: LangId,
281    module_path: &str,
282    names: &[String],
283    default_import: Option<&str>,
284    type_only: bool,
285) -> String {
286    match lang {
287        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
288            generate_ts_import_line(module_path, names, default_import, type_only)
289        }
290        LangId::Python => generate_py_import_line(module_path, names, default_import),
291        LangId::Rust => generate_rs_import_line(module_path, names, type_only),
292        LangId::Go => generate_go_import_line(module_path, default_import, false),
293        LangId::Markdown => String::new(),
294    }
295}
296
297/// Check if the given language is supported by the import engine.
298pub fn is_supported(lang: LangId) -> bool {
299    matches!(
300        lang,
301        LangId::TypeScript
302            | LangId::Tsx
303            | LangId::JavaScript
304            | LangId::Python
305            | LangId::Rust
306            | LangId::Go
307    )
308}
309
310/// Classify a module path into a group for TS/JS/TSX.
311pub fn classify_group_ts(module_path: &str) -> ImportGroup {
312    if module_path.starts_with('.') {
313        ImportGroup::Internal
314    } else {
315        ImportGroup::External
316    }
317}
318
319/// Classify a module path into a group for the given language.
320pub fn classify_group(lang: LangId, module_path: &str) -> ImportGroup {
321    match lang {
322        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => classify_group_ts(module_path),
323        LangId::Python => classify_group_py(module_path),
324        LangId::Rust => classify_group_rs(module_path),
325        LangId::Go => classify_group_go(module_path),
326        LangId::Markdown => ImportGroup::External,
327    }
328}
329
330/// Parse a file from disk and return its import block.
331/// Convenience wrapper that handles parsing.
332pub fn parse_file_imports(
333    path: &std::path::Path,
334    lang: LangId,
335) -> Result<(String, Tree, ImportBlock), crate::error::AftError> {
336    let source =
337        std::fs::read_to_string(path).map_err(|e| crate::error::AftError::FileNotFound {
338            path: format!("{}: {}", path.display(), e),
339        })?;
340
341    let grammar = grammar_for(lang);
342    let mut parser = Parser::new();
343    parser
344        .set_language(&grammar)
345        .map_err(|e| crate::error::AftError::ParseError {
346            message: format!("grammar init failed for {:?}: {}", lang, e),
347        })?;
348
349    let tree = parser
350        .parse(&source, None)
351        .ok_or_else(|| crate::error::AftError::ParseError {
352            message: format!("tree-sitter parse returned None for {}", path.display()),
353        })?;
354
355    let block = parse_imports(&source, &tree, lang);
356    Ok((source, tree, block))
357}
358
359// ---------------------------------------------------------------------------
360// TS/JS/TSX implementation
361// ---------------------------------------------------------------------------
362
363/// Parse imports from a TS/JS/TSX file.
364///
365/// Walks the AST root's direct children looking for `import_statement` nodes (D041).
366fn parse_ts_imports(source: &str, tree: &Tree) -> ImportBlock {
367    let root = tree.root_node();
368    let mut imports = Vec::new();
369
370    let mut cursor = root.walk();
371    if !cursor.goto_first_child() {
372        return ImportBlock::empty();
373    }
374
375    loop {
376        let node = cursor.node();
377        if node.kind() == "import_statement" {
378            if let Some(imp) = parse_single_ts_import(source, &node) {
379                imports.push(imp);
380            }
381        }
382        if !cursor.goto_next_sibling() {
383            break;
384        }
385    }
386
387    let byte_range = import_byte_range(&imports);
388
389    ImportBlock {
390        imports,
391        byte_range,
392    }
393}
394
395/// Parse a single `import_statement` node into an `ImportStatement`.
396fn parse_single_ts_import(source: &str, node: &Node) -> Option<ImportStatement> {
397    let raw_text = source[node.byte_range()].to_string();
398    let byte_range = node.byte_range();
399
400    // Find the source module (string/string_fragment child of the import)
401    let module_path = extract_module_path(source, node)?;
402
403    // Determine if this is a type-only import: `import type ...`
404    let is_type_only = has_type_keyword(node);
405
406    // Extract import clause details
407    let mut names = Vec::new();
408    let mut default_import = None;
409    let mut namespace_import = None;
410
411    let mut child_cursor = node.walk();
412    if child_cursor.goto_first_child() {
413        loop {
414            let child = child_cursor.node();
415            match child.kind() {
416                "import_clause" => {
417                    extract_import_clause(
418                        source,
419                        &child,
420                        &mut names,
421                        &mut default_import,
422                        &mut namespace_import,
423                    );
424                }
425                // In some grammars, the default import is a direct identifier child
426                "identifier" => {
427                    let text = &source[child.byte_range()];
428                    if text != "import" && text != "from" && text != "type" {
429                        default_import = Some(text.to_string());
430                    }
431                }
432                _ => {}
433            }
434            if !child_cursor.goto_next_sibling() {
435                break;
436            }
437        }
438    }
439
440    // Classify kind
441    let kind = if names.is_empty() && default_import.is_none() && namespace_import.is_none() {
442        ImportKind::SideEffect
443    } else if is_type_only {
444        ImportKind::Type
445    } else {
446        ImportKind::Value
447    };
448
449    let group = classify_group_ts(&module_path);
450
451    Some(ImportStatement {
452        module_path,
453        names,
454        default_import,
455        namespace_import,
456        kind,
457        group,
458        byte_range,
459        raw_text,
460    })
461}
462
463/// Extract the module path string from an import_statement node.
464///
465/// Looks for a `string` child node and extracts the content without quotes.
466fn extract_module_path(source: &str, node: &Node) -> Option<String> {
467    let mut cursor = node.walk();
468    if !cursor.goto_first_child() {
469        return None;
470    }
471
472    loop {
473        let child = cursor.node();
474        if child.kind() == "string" {
475            // Get the text and strip quotes
476            let text = &source[child.byte_range()];
477            let stripped = text
478                .trim_start_matches(|c| c == '\'' || c == '"')
479                .trim_end_matches(|c| c == '\'' || c == '"');
480            return Some(stripped.to_string());
481        }
482        if !cursor.goto_next_sibling() {
483            break;
484        }
485    }
486    None
487}
488
489/// Check if the import_statement has a `type` keyword (import type ...).
490///
491/// In tree-sitter-typescript, `import type { X } from 'y'` produces a `type`
492/// node as a direct child of `import_statement`, between `import` and `import_clause`.
493fn has_type_keyword(node: &Node) -> bool {
494    let mut cursor = node.walk();
495    if !cursor.goto_first_child() {
496        return false;
497    }
498
499    loop {
500        let child = cursor.node();
501        if child.kind() == "type" {
502            return true;
503        }
504        if !cursor.goto_next_sibling() {
505            break;
506        }
507    }
508
509    false
510}
511
512/// Extract named imports, default import, and namespace import from an import_clause.
513fn extract_import_clause(
514    source: &str,
515    node: &Node,
516    names: &mut Vec<String>,
517    default_import: &mut Option<String>,
518    namespace_import: &mut Option<String>,
519) {
520    let mut cursor = node.walk();
521    if !cursor.goto_first_child() {
522        return;
523    }
524
525    loop {
526        let child = cursor.node();
527        match child.kind() {
528            "identifier" => {
529                // This is a default import: `import Foo from 'bar'`
530                let text = &source[child.byte_range()];
531                if text != "type" {
532                    *default_import = Some(text.to_string());
533                }
534            }
535            "named_imports" => {
536                // `{ name1, name2 }`
537                extract_named_imports(source, &child, names);
538            }
539            "namespace_import" => {
540                // `* as name`
541                extract_namespace_import(source, &child, namespace_import);
542            }
543            _ => {}
544        }
545        if !cursor.goto_next_sibling() {
546            break;
547        }
548    }
549}
550
551/// Extract individual names from a named_imports node (`{ a, b, c }`).
552fn extract_named_imports(source: &str, node: &Node, names: &mut Vec<String>) {
553    let mut cursor = node.walk();
554    if !cursor.goto_first_child() {
555        return;
556    }
557
558    loop {
559        let child = cursor.node();
560        if child.kind() == "import_specifier" {
561            // import_specifier can have `name` (the imported name) and optional `alias`
562            if let Some(name_node) = child.child_by_field_name("name") {
563                names.push(source[name_node.byte_range()].to_string());
564            } else {
565                // Fallback: first identifier child
566                let mut spec_cursor = child.walk();
567                if spec_cursor.goto_first_child() {
568                    loop {
569                        if spec_cursor.node().kind() == "identifier"
570                            || spec_cursor.node().kind() == "type_identifier"
571                        {
572                            names.push(source[spec_cursor.node().byte_range()].to_string());
573                            break;
574                        }
575                        if !spec_cursor.goto_next_sibling() {
576                            break;
577                        }
578                    }
579                }
580            }
581        }
582        if !cursor.goto_next_sibling() {
583            break;
584        }
585    }
586}
587
588/// Extract the alias name from a namespace_import node (`* as name`).
589fn extract_namespace_import(source: &str, node: &Node, namespace_import: &mut Option<String>) {
590    let mut cursor = node.walk();
591    if !cursor.goto_first_child() {
592        return;
593    }
594
595    loop {
596        let child = cursor.node();
597        if child.kind() == "identifier" {
598            *namespace_import = Some(source[child.byte_range()].to_string());
599            return;
600        }
601        if !cursor.goto_next_sibling() {
602            break;
603        }
604    }
605}
606
/// Generate a TS/JS/TSX import statement.
///
/// - No names and no default: a side-effect import (`import 'm';`). Here
///   `type_only` is ignored, because `import type 'm'` is not valid syntax.
/// - Named imports are sorted and de-duplicated.
/// - `type_only` with BOTH a default and named imports emits two statements
///   joined by `\n`. TypeScript rejects `import type D, { N } from 'm'`
///   (TS1363: a type-only import can specify a default import or named
///   bindings, but not both).
fn generate_ts_import_line(
    module_path: &str,
    names: &[String],
    default_import: Option<&str>,
    type_only: bool,
) -> String {
    let type_prefix = if type_only { "type " } else { "" };

    let mut sorted_names = names.to_vec();
    sorted_names.sort();
    sorted_names.dedup();
    let named = (!sorted_names.is_empty()).then(|| format!("{{ {} }}", sorted_names.join(", ")));

    match (default_import, named) {
        (None, None) => format!("import '{module_path}';"),
        (Some(def), None) => format!("import {type_prefix}{def} from '{module_path}';"),
        (None, Some(named)) => format!("import {type_prefix}{named} from '{module_path}';"),
        (Some(def), Some(named)) if type_only => format!(
            "import type {def} from '{module_path}';\nimport type {named} from '{module_path}';"
        ),
        (Some(def), Some(named)) => format!("import {def}, {named} from '{module_path}';"),
    }
}
647
648// ---------------------------------------------------------------------------
649// Python implementation
650// ---------------------------------------------------------------------------
651
/// Top-level module names shipped with CPython 3.x, used to place Python
/// imports in the `Stdlib` group.
///
/// Modules removed in recent releases (e.g. `asynchat`, `distutils`, `imp`)
/// and platform-specific ones (e.g. `msvcrt`, `winreg`, `posix`) are kept, so
/// older or platform-specific code is still grouped correctly. The list is
/// not exhaustive, and unknown modules are assumed third-party. Python 3.10+
/// exposes the authoritative list as `sys.stdlib_module_names`.
const PYTHON_STDLIB: &[&str] = &[
    "__future__",
    "_thread",
    "abc",
    "aifc",
    "argparse",
    "array",
    "ast",
    "asynchat",
    "asyncio",
    "asyncore",
    "atexit",
    "audioop",
    "base64",
    "bdb",
    "binascii",
    "binhex",
    "bisect",
    "builtins",
    "bz2",
    "calendar",
    "cgi",
    "cgitb",
    "chunk",
    "cmath",
    "cmd",
    "code",
    "codecs",
    "codeop",
    "collections",
    "colorsys",
    "compileall",
    "concurrent",
    "configparser",
    "contextlib",
    "contextvars",
    "copy",
    "copyreg",
    "cProfile",
    "crypt",
    "csv",
    "ctypes",
    "curses",
    "dataclasses",
    "datetime",
    "dbm",
    "decimal",
    "difflib",
    "dis",
    "distutils",
    "doctest",
    "email",
    "encodings",
    "ensurepip",
    "enum",
    "errno",
    "faulthandler",
    "fcntl",
    "filecmp",
    "fileinput",
    "fnmatch",
    "fractions",
    "ftplib",
    "functools",
    "gc",
    "genericpath",
    "getopt",
    "getpass",
    "gettext",
    "glob",
    "graphlib",
    "grp",
    "gzip",
    "hashlib",
    "heapq",
    "hmac",
    "html",
    "http",
    "idlelib",
    "imaplib",
    "imghdr",
    "imp",
    "importlib",
    "inspect",
    "io",
    "ipaddress",
    "itertools",
    "json",
    "keyword",
    "lib2to3",
    "linecache",
    "locale",
    "logging",
    "lzma",
    "mailbox",
    "mailcap",
    "marshal",
    "math",
    "mimetypes",
    "mmap",
    "modulefinder",
    "msilib",
    "msvcrt",
    "multiprocessing",
    "netrc",
    "nis",
    "nntplib",
    "nt",
    "ntpath",
    "numbers",
    "opcode",
    "operator",
    "optparse",
    "os",
    "ossaudiodev",
    "pathlib",
    "pdb",
    "pickle",
    "pickletools",
    "pipes",
    "pkgutil",
    "platform",
    "plistlib",
    "poplib",
    "posix",
    "posixpath",
    "pprint",
    "profile",
    "pstats",
    "pty",
    "pwd",
    "py_compile",
    "pyclbr",
    "pydoc",
    "pyexpat",
    "queue",
    "quopri",
    "random",
    "re",
    "readline",
    "reprlib",
    "resource",
    "rlcompleter",
    "runpy",
    "sched",
    "secrets",
    "select",
    "selectors",
    "shelve",
    "shlex",
    "shutil",
    "signal",
    "site",
    "smtpd",
    "smtplib",
    "sndhdr",
    "socket",
    "socketserver",
    "spwd",
    "sqlite3",
    "sre_compile",
    "sre_constants",
    "sre_parse",
    "ssl",
    "stat",
    "statistics",
    "string",
    "stringprep",
    "struct",
    "subprocess",
    "sunau",
    "symtable",
    "sys",
    "sysconfig",
    "syslog",
    "tabnanny",
    "tarfile",
    "telnetlib",
    "tempfile",
    "termios",
    "textwrap",
    "threading",
    "time",
    "timeit",
    "tkinter",
    "token",
    "tokenize",
    "tomllib",
    "trace",
    "traceback",
    "tracemalloc",
    "tty",
    "turtle",
    "types",
    "typing",
    "unicodedata",
    "unittest",
    "urllib",
    "uu",
    "uuid",
    "venv",
    "warnings",
    "wave",
    "weakref",
    "webbrowser",
    "winreg",
    "winsound",
    "wsgiref",
    "xdrlib",
    "xml",
    "xmlrpc",
    "zipapp",
    "zipfile",
    "zipimport",
    "zlib",
    "zoneinfo",
];
845
846/// Classify a Python import into a group.
847pub fn classify_group_py(module_path: &str) -> ImportGroup {
848    // Relative imports start with '.'
849    if module_path.starts_with('.') {
850        return ImportGroup::Internal;
851    }
852    // Check stdlib: use the top-level module name (before first '.')
853    let top_module = module_path.split('.').next().unwrap_or(module_path);
854    if PYTHON_STDLIB.contains(&top_module) {
855        ImportGroup::Stdlib
856    } else {
857        ImportGroup::External
858    }
859}
860
861/// Parse imports from a Python file.
862fn parse_py_imports(source: &str, tree: &Tree) -> ImportBlock {
863    let root = tree.root_node();
864    let mut imports = Vec::new();
865
866    let mut cursor = root.walk();
867    if !cursor.goto_first_child() {
868        return ImportBlock::empty();
869    }
870
871    loop {
872        let node = cursor.node();
873        match node.kind() {
874            "import_statement" => {
875                if let Some(imp) = parse_py_import_statement(source, &node) {
876                    imports.push(imp);
877                }
878            }
879            "import_from_statement" => {
880                if let Some(imp) = parse_py_import_from_statement(source, &node) {
881                    imports.push(imp);
882                }
883            }
884            _ => {}
885        }
886        if !cursor.goto_next_sibling() {
887            break;
888        }
889    }
890
891    let byte_range = import_byte_range(&imports);
892
893    ImportBlock {
894        imports,
895        byte_range,
896    }
897}
898
899/// Parse `import X` or `import X.Y` Python statements.
900fn parse_py_import_statement(source: &str, node: &Node) -> Option<ImportStatement> {
901    let raw_text = source[node.byte_range()].to_string();
902    let byte_range = node.byte_range();
903
904    // Find the dotted_name child (the module name)
905    let mut module_path = String::new();
906    let mut c = node.walk();
907    if c.goto_first_child() {
908        loop {
909            if c.node().kind() == "dotted_name" {
910                module_path = source[c.node().byte_range()].to_string();
911                break;
912            }
913            if !c.goto_next_sibling() {
914                break;
915            }
916        }
917    }
918    if module_path.is_empty() {
919        return None;
920    }
921
922    let group = classify_group_py(&module_path);
923
924    Some(ImportStatement {
925        module_path,
926        names: Vec::new(),
927        default_import: None,
928        namespace_import: None,
929        kind: ImportKind::Value,
930        group,
931        byte_range,
932        raw_text,
933    })
934}
935
936/// Parse `from X import Y, Z` or `from . import Y` Python statements.
937fn parse_py_import_from_statement(source: &str, node: &Node) -> Option<ImportStatement> {
938    let raw_text = source[node.byte_range()].to_string();
939    let byte_range = node.byte_range();
940
941    let mut module_path = String::new();
942    let mut names = Vec::new();
943
944    let mut c = node.walk();
945    if c.goto_first_child() {
946        loop {
947            let child = c.node();
948            match child.kind() {
949                "dotted_name" => {
950                    // Could be the module name or an imported name
951                    // The module name comes right after `from`, imported names come after `import`
952                    // Use position: if we haven't set module_path yet and this comes
953                    // before the `import` keyword, it's the module.
954                    if module_path.is_empty()
955                        && !has_seen_import_keyword(source, node, child.start_byte())
956                    {
957                        module_path = source[child.byte_range()].to_string();
958                    } else {
959                        // It's an imported name
960                        names.push(source[child.byte_range()].to_string());
961                    }
962                }
963                "relative_import" => {
964                    // from . import X or from ..module import X
965                    module_path = source[child.byte_range()].to_string();
966                }
967                _ => {}
968            }
969            if !c.goto_next_sibling() {
970                break;
971            }
972        }
973    }
974
975    // module_path must be non-empty for a valid import
976    if module_path.is_empty() {
977        return None;
978    }
979
980    let group = classify_group_py(&module_path);
981
982    Some(ImportStatement {
983        module_path,
984        names,
985        default_import: None,
986        namespace_import: None,
987        kind: ImportKind::Value,
988        group,
989        byte_range,
990        raw_text,
991    })
992}
993
994/// Check if the `import` keyword appears before the given byte position in a from...import node.
995fn has_seen_import_keyword(_source: &str, parent: &Node, before_byte: usize) -> bool {
996    let mut c = parent.walk();
997    if c.goto_first_child() {
998        loop {
999            let child = c.node();
1000            if child.kind() == "import" && child.start_byte() < before_byte {
1001                return true;
1002            }
1003            if child.start_byte() >= before_byte {
1004                return false;
1005            }
1006            if !c.goto_next_sibling() {
1007                break;
1008            }
1009        }
1010    }
1011    false
1012}
1013
/// Render a Python import statement (no trailing newline).
///
/// With no `names` the result is `import <module_path>`; otherwise it is
/// `from <module_path> import a, b, ...` with the names in sorted order.
/// `_default_import` is accepted for signature parity and ignored.
fn generate_py_import_line(
    module_path: &str,
    names: &[String],
    _default_import: Option<&str>,
) -> String {
    if names.is_empty() {
        return format!("import {module_path}");
    }
    let mut ordered: Vec<&str> = names.iter().map(String::as_str).collect();
    ordered.sort();
    format!("from {module_path} import {}", ordered.join(", "))
}
1031
1032// ---------------------------------------------------------------------------
1033// Rust implementation
1034// ---------------------------------------------------------------------------
1035
1036/// Classify a Rust use path into a group.
1037pub fn classify_group_rs(module_path: &str) -> ImportGroup {
1038    // Extract the first path segment (before ::)
1039    let first_seg = module_path.split("::").next().unwrap_or(module_path);
1040    match first_seg {
1041        "std" | "core" | "alloc" => ImportGroup::Stdlib,
1042        "crate" | "self" | "super" => ImportGroup::Internal,
1043        _ => ImportGroup::External,
1044    }
1045}
1046
1047/// Parse imports from a Rust file.
1048fn parse_rs_imports(source: &str, tree: &Tree) -> ImportBlock {
1049    let root = tree.root_node();
1050    let mut imports = Vec::new();
1051
1052    let mut cursor = root.walk();
1053    if !cursor.goto_first_child() {
1054        return ImportBlock::empty();
1055    }
1056
1057    loop {
1058        let node = cursor.node();
1059        if node.kind() == "use_declaration" {
1060            if let Some(imp) = parse_rs_use_declaration(source, &node) {
1061                imports.push(imp);
1062            }
1063        }
1064        if !cursor.goto_next_sibling() {
1065            break;
1066        }
1067    }
1068
1069    let byte_range = import_byte_range(&imports);
1070
1071    ImportBlock {
1072        imports,
1073        byte_range,
1074    }
1075}
1076
1077/// Parse a single `use` declaration from Rust.
1078fn parse_rs_use_declaration(source: &str, node: &Node) -> Option<ImportStatement> {
1079    let raw_text = source[node.byte_range()].to_string();
1080    let byte_range = node.byte_range();
1081
1082    // Check for `pub` visibility modifier
1083    let mut has_pub = false;
1084    let mut use_path = String::new();
1085    let mut names = Vec::new();
1086
1087    let mut c = node.walk();
1088    if c.goto_first_child() {
1089        loop {
1090            let child = c.node();
1091            match child.kind() {
1092                "visibility_modifier" => {
1093                    has_pub = true;
1094                }
1095                "scoped_identifier" | "identifier" | "use_as_clause" => {
1096                    // Full path like `std::collections::HashMap` or just `serde`
1097                    use_path = source[child.byte_range()].to_string();
1098                }
1099                "scoped_use_list" => {
1100                    // e.g. `serde::{Deserialize, Serialize}`
1101                    use_path = source[child.byte_range()].to_string();
1102                    // Also extract the individual names from the use_list
1103                    extract_rs_use_list_names(source, &child, &mut names);
1104                }
1105                _ => {}
1106            }
1107            if !c.goto_next_sibling() {
1108                break;
1109            }
1110        }
1111    }
1112
1113    if use_path.is_empty() {
1114        return None;
1115    }
1116
1117    let group = classify_group_rs(&use_path);
1118
1119    Some(ImportStatement {
1120        module_path: use_path,
1121        names,
1122        default_import: if has_pub {
1123            Some("pub".to_string())
1124        } else {
1125            None
1126        },
1127        namespace_import: None,
1128        kind: ImportKind::Value,
1129        group,
1130        byte_range,
1131        raw_text,
1132    })
1133}
1134
1135/// Extract individual names from a Rust `scoped_use_list` node.
1136fn extract_rs_use_list_names(source: &str, node: &Node, names: &mut Vec<String>) {
1137    let mut c = node.walk();
1138    if c.goto_first_child() {
1139        loop {
1140            let child = c.node();
1141            if child.kind() == "use_list" {
1142                // Walk into the use_list to find identifiers
1143                let mut lc = child.walk();
1144                if lc.goto_first_child() {
1145                    loop {
1146                        let lchild = lc.node();
1147                        if lchild.kind() == "identifier" || lchild.kind() == "scoped_identifier" {
1148                            names.push(source[lchild.byte_range()].to_string());
1149                        }
1150                        if !lc.goto_next_sibling() {
1151                            break;
1152                        }
1153                    }
1154                }
1155            }
1156            if !c.goto_next_sibling() {
1157                break;
1158            }
1159        }
1160    }
1161}
1162
/// Generate a Rust `use` line (no trailing newline).
///
/// - No `names`: `use <module_path>;`
/// - One name: `use <module_path>::<name>;`
/// - Several names: `use <module_path>::{A, B};` (sorted, duplicates removed)
///
/// Callers may also pass the full path of the item as `module_path`
/// (e.g. `"std::collections::HashMap"` with names `["HashMap"]`). For that
/// reason `module_path` is emitted verbatim whenever it already names what is
/// imported: a braced list (`{`), a glob (`*`), a rename (` as `), or a path
/// whose last segment equals the single requested name.
/// `_type_only` has no meaning for Rust and is ignored.
fn generate_rs_import_line(module_path: &str, names: &[String], _type_only: bool) -> String {
    let mut items: Vec<&str> = names.iter().map(String::as_str).collect();
    items.sort_unstable();
    items.dedup();

    let last_segment = module_path
        .rsplit("::")
        .next()
        .unwrap_or(module_path)
        .trim();
    let path_is_complete = module_path.contains('{')
        || module_path.contains('*')
        || module_path.contains(" as ")
        || matches!(items.as_slice(), [only] if *only == last_segment);

    if items.is_empty() || path_is_complete {
        return format!("use {module_path};");
    }

    match items.as_slice() {
        [only] => format!("use {module_path}::{only};"),
        many => format!("use {module_path}::{{{}}};", many.join(", ")),
    }
}
1174
1175// ---------------------------------------------------------------------------
1176// Go implementation
1177// ---------------------------------------------------------------------------
1178
1179/// Classify a Go import path into a group.
1180pub fn classify_group_go(module_path: &str) -> ImportGroup {
1181    // stdlib paths don't contain dots (e.g., "fmt", "os", "net/http")
1182    // external paths contain dots (e.g., "github.com/pkg/errors")
1183    if module_path.contains('.') {
1184        ImportGroup::External
1185    } else {
1186        ImportGroup::Stdlib
1187    }
1188}
1189
1190/// Parse imports from a Go file.
1191fn parse_go_imports(source: &str, tree: &Tree) -> ImportBlock {
1192    let root = tree.root_node();
1193    let mut imports = Vec::new();
1194
1195    let mut cursor = root.walk();
1196    if !cursor.goto_first_child() {
1197        return ImportBlock::empty();
1198    }
1199
1200    loop {
1201        let node = cursor.node();
1202        if node.kind() == "import_declaration" {
1203            parse_go_import_declaration(source, &node, &mut imports);
1204        }
1205        if !cursor.goto_next_sibling() {
1206            break;
1207        }
1208    }
1209
1210    let byte_range = import_byte_range(&imports);
1211
1212    ImportBlock {
1213        imports,
1214        byte_range,
1215    }
1216}
1217
1218/// Parse a single Go import_declaration (may contain one or multiple specs).
1219fn parse_go_import_declaration(source: &str, node: &Node, imports: &mut Vec<ImportStatement>) {
1220    let mut c = node.walk();
1221    if c.goto_first_child() {
1222        loop {
1223            let child = c.node();
1224            match child.kind() {
1225                "import_spec" => {
1226                    if let Some(imp) = parse_go_import_spec(source, &child) {
1227                        imports.push(imp);
1228                    }
1229                }
1230                "import_spec_list" => {
1231                    // Grouped imports: walk into the list
1232                    let mut lc = child.walk();
1233                    if lc.goto_first_child() {
1234                        loop {
1235                            if lc.node().kind() == "import_spec" {
1236                                if let Some(imp) = parse_go_import_spec(source, &lc.node()) {
1237                                    imports.push(imp);
1238                                }
1239                            }
1240                            if !lc.goto_next_sibling() {
1241                                break;
1242                            }
1243                        }
1244                    }
1245                }
1246                _ => {}
1247            }
1248            if !c.goto_next_sibling() {
1249                break;
1250            }
1251        }
1252    }
1253}
1254
1255/// Parse a single Go import_spec node.
1256fn parse_go_import_spec(source: &str, node: &Node) -> Option<ImportStatement> {
1257    let raw_text = source[node.byte_range()].to_string();
1258    let byte_range = node.byte_range();
1259
1260    let mut import_path = String::new();
1261    let mut alias = None;
1262
1263    let mut c = node.walk();
1264    if c.goto_first_child() {
1265        loop {
1266            let child = c.node();
1267            match child.kind() {
1268                "interpreted_string_literal" => {
1269                    // Extract the path without quotes
1270                    let text = source[child.byte_range()].to_string();
1271                    import_path = text.trim_matches('"').to_string();
1272                }
1273                "identifier" | "blank_identifier" | "dot" => {
1274                    // This is an alias (e.g., `alias "path"` or `. "path"` or `_ "path"`)
1275                    alias = Some(source[child.byte_range()].to_string());
1276                }
1277                _ => {}
1278            }
1279            if !c.goto_next_sibling() {
1280                break;
1281            }
1282        }
1283    }
1284
1285    if import_path.is_empty() {
1286        return None;
1287    }
1288
1289    let group = classify_group_go(&import_path);
1290
1291    Some(ImportStatement {
1292        module_path: import_path,
1293        names: Vec::new(),
1294        default_import: alias,
1295        namespace_import: None,
1296        kind: ImportKind::Value,
1297        group,
1298        byte_range,
1299        raw_text,
1300    })
1301}
1302
1303/// Public API for Go import line generation (used by add_import handler).
1304pub fn generate_go_import_line_pub(
1305    module_path: &str,
1306    alias: Option<&str>,
1307    in_group: bool,
1308) -> String {
1309    generate_go_import_line(module_path, alias, in_group)
1310}
1311
/// Render a single Go import for `module_path` (no trailing newline),
/// optionally prefixed by a package name `alias` (e.g. `errs`, `_`, `.`).
///
/// When `in_group` is true the result is a tab-indented spec for insertion
/// into an existing grouped import (`\t"path"`). Otherwise it is a standalone
/// declaration (`import "path"`). Module-private; external callers go through
/// `generate_go_import_line_pub`.
fn generate_go_import_line(module_path: &str, alias: Option<&str>, in_group: bool) -> String {
    let lead = if in_group { "\t" } else { "import " };
    match alias {
        Some(name) => format!("{lead}{name} \"{module_path}\""),
        None => format!("{lead}\"{module_path}\""),
    }
}
1331
1332/// Check if a Go import block has a grouped import declaration.
1333/// Returns the byte range of the import_spec_list if found.
1334pub fn go_has_grouped_import(_source: &str, tree: &Tree) -> Option<Range<usize>> {
1335    let root = tree.root_node();
1336    let mut cursor = root.walk();
1337    if !cursor.goto_first_child() {
1338        return None;
1339    }
1340
1341    loop {
1342        let node = cursor.node();
1343        if node.kind() == "import_declaration" {
1344            let mut c = node.walk();
1345            if c.goto_first_child() {
1346                loop {
1347                    if c.node().kind() == "import_spec_list" {
1348                        return Some(c.node().byte_range());
1349                    }
1350                    if !c.goto_next_sibling() {
1351                        break;
1352                    }
1353                }
1354            }
1355        }
1356        if !cursor.goto_next_sibling() {
1357            break;
1358        }
1359    }
1360    None
1361}
1362
/// Return the byte offset just past a line terminator starting at `pos`.
///
/// Recognises `\n`, `\r\n` and a lone `\r`. If there is no terminator at
/// `pos`, including when `pos` is at or beyond the end of `source`, `pos` is
/// returned unchanged.
fn skip_newline(source: &str, pos: usize) -> usize {
    match source.as_bytes().get(pos..) {
        Some([b'\n', ..]) => pos + 1,
        Some([b'\r', b'\n', ..]) => pos + 2,
        Some([b'\r', ..]) => pos + 1,
        _ => pos,
    }
}
1379
1380// ---------------------------------------------------------------------------
1381// Unit tests
1382// ---------------------------------------------------------------------------
1383
#[cfg(test)]
mod tests {
    // Unit tests for import parsing, group classification, deduplication,
    // line generation and insertion-point lookup for TypeScript/JavaScript,
    // Python, Rust and Go. Each `parse_*` helper below parses a source string
    // with the matching tree-sitter grammar and runs `parse_imports` on it.
    use super::*;

    /// Parse `source` as TypeScript; returns the syntax tree and its import block.
    fn parse_ts(source: &str) -> (Tree, ImportBlock) {
        let grammar = grammar_for(LangId::TypeScript);
        let mut parser = Parser::new();
        parser.set_language(&grammar).unwrap();
        let tree = parser.parse(source, None).unwrap();
        let block = parse_imports(source, &tree, LangId::TypeScript);
        (tree, block)
    }

    /// Parse `source` as JavaScript; returns the syntax tree and its import block.
    fn parse_js(source: &str) -> (Tree, ImportBlock) {
        let grammar = grammar_for(LangId::JavaScript);
        let mut parser = Parser::new();
        parser.set_language(&grammar).unwrap();
        let tree = parser.parse(source, None).unwrap();
        let block = parse_imports(source, &tree, LangId::JavaScript);
        (tree, block)
    }

    // --- Basic parsing ---

    #[test]
    fn parse_ts_named_imports() {
        let source = "import { useState, useEffect } from 'react';\n";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 1);
        let imp = &block.imports[0];
        assert_eq!(imp.module_path, "react");
        assert!(imp.names.contains(&"useState".to_string()));
        assert!(imp.names.contains(&"useEffect".to_string()));
        assert_eq!(imp.kind, ImportKind::Value);
        assert_eq!(imp.group, ImportGroup::External);
    }

    #[test]
    fn parse_ts_default_import() {
        let source = "import React from 'react';\n";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 1);
        let imp = &block.imports[0];
        assert_eq!(imp.default_import.as_deref(), Some("React"));
        assert_eq!(imp.kind, ImportKind::Value);
    }

    #[test]
    fn parse_ts_side_effect_import() {
        let source = "import './styles.css';\n";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 1);
        assert_eq!(block.imports[0].kind, ImportKind::SideEffect);
        assert_eq!(block.imports[0].module_path, "./styles.css");
    }

    #[test]
    fn parse_ts_relative_import() {
        let source = "import { helper } from './utils';\n";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 1);
        assert_eq!(block.imports[0].group, ImportGroup::Internal);
    }

    #[test]
    fn parse_ts_multiple_groups() {
        let source = "\
import React from 'react';
import { useState } from 'react';
import { helper } from './utils';
import { Config } from '../config';
";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 4);

        let external: Vec<_> = block
            .imports
            .iter()
            .filter(|i| i.group == ImportGroup::External)
            .collect();
        let relative: Vec<_> = block
            .imports
            .iter()
            .filter(|i| i.group == ImportGroup::Internal)
            .collect();
        assert_eq!(external.len(), 2);
        assert_eq!(relative.len(), 2);
    }

    #[test]
    fn parse_ts_namespace_import() {
        let source = "import * as path from 'path';\n";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 1);
        let imp = &block.imports[0];
        assert_eq!(imp.namespace_import.as_deref(), Some("path"));
        assert_eq!(imp.kind, ImportKind::Value);
    }

    #[test]
    fn parse_js_imports() {
        let source = "import { readFile } from 'fs';\nimport { helper } from './helper';\n";
        let (_, block) = parse_js(source);
        assert_eq!(block.imports.len(), 2);
        assert_eq!(block.imports[0].group, ImportGroup::External);
        assert_eq!(block.imports[1].group, ImportGroup::Internal);
    }

    // --- Group classification ---

    #[test]
    fn classify_external() {
        assert_eq!(classify_group_ts("react"), ImportGroup::External);
        assert_eq!(classify_group_ts("@scope/pkg"), ImportGroup::External);
        assert_eq!(classify_group_ts("lodash/map"), ImportGroup::External);
    }

    #[test]
    fn classify_relative() {
        assert_eq!(classify_group_ts("./utils"), ImportGroup::Internal);
        assert_eq!(classify_group_ts("../config"), ImportGroup::Internal);
        assert_eq!(classify_group_ts("./"), ImportGroup::Internal);
    }

    // --- Dedup ---

    #[test]
    fn dedup_detects_same_named_import() {
        let source = "import { useState } from 'react';\n";
        let (_, block) = parse_ts(source);
        assert!(is_duplicate(
            &block,
            "react",
            &["useState".to_string()],
            None,
            false
        ));
    }

    #[test]
    fn dedup_misses_different_name() {
        let source = "import { useState } from 'react';\n";
        let (_, block) = parse_ts(source);
        assert!(!is_duplicate(
            &block,
            "react",
            &["useEffect".to_string()],
            None,
            false
        ));
    }

    #[test]
    fn dedup_detects_default_import() {
        let source = "import React from 'react';\n";
        let (_, block) = parse_ts(source);
        assert!(is_duplicate(&block, "react", &[], Some("React"), false));
    }

    #[test]
    fn dedup_side_effect() {
        let source = "import './styles.css';\n";
        let (_, block) = parse_ts(source);
        assert!(is_duplicate(&block, "./styles.css", &[], None, false));
    }

    #[test]
    fn dedup_type_vs_value() {
        let source = "import { FC } from 'react';\n";
        let (_, block) = parse_ts(source);
        // Type import should NOT match a value import of the same name
        assert!(!is_duplicate(
            &block,
            "react",
            &["FC".to_string()],
            None,
            true
        ));
    }

    // --- Generation ---

    #[test]
    fn generate_named_import() {
        let line = generate_import_line(
            LangId::TypeScript,
            "react",
            &["useState".to_string(), "useEffect".to_string()],
            None,
            false,
        );
        assert_eq!(line, "import { useEffect, useState } from 'react';");
    }

    #[test]
    fn generate_default_import() {
        let line = generate_import_line(LangId::TypeScript, "react", &[], Some("React"), false);
        assert_eq!(line, "import React from 'react';");
    }

    #[test]
    fn generate_type_import() {
        let line =
            generate_import_line(LangId::TypeScript, "react", &["FC".to_string()], None, true);
        assert_eq!(line, "import type { FC } from 'react';");
    }

    #[test]
    fn generate_side_effect_import() {
        let line = generate_import_line(LangId::TypeScript, "./styles.css", &[], None, false);
        assert_eq!(line, "import './styles.css';");
    }

    #[test]
    fn generate_default_and_named() {
        let line = generate_import_line(
            LangId::TypeScript,
            "react",
            &["useState".to_string()],
            Some("React"),
            false,
        );
        assert_eq!(line, "import React, { useState } from 'react';");
    }

    #[test]
    fn parse_ts_type_import() {
        let source = "import type { FC } from 'react';\n";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 1);
        let imp = &block.imports[0];
        assert_eq!(imp.kind, ImportKind::Type);
        assert!(imp.names.contains(&"FC".to_string()));
        assert_eq!(imp.group, ImportGroup::External);
    }

    // --- Insertion point ---
    // Only the first element of `find_insertion_point`'s result (the byte
    // offset) is checked by these tests.

    #[test]
    fn insertion_empty_file() {
        let source = "";
        let (_, block) = parse_ts(source);
        let (offset, _, _) =
            find_insertion_point(source, &block, ImportGroup::External, "react", false);
        assert_eq!(offset, 0);
    }

    #[test]
    fn insertion_alphabetical_within_group() {
        let source = "\
import { a } from 'alpha';
import { c } from 'charlie';
";
        let (_, block) = parse_ts(source);
        let (offset, _, _) =
            find_insertion_point(source, &block, ImportGroup::External, "bravo", false);
        // Should insert before 'charlie' (which starts at line 2)
        let before_charlie = source.find("import { c }").unwrap();
        assert_eq!(offset, before_charlie);
    }

    // --- Python parsing ---

    /// Parse `source` as Python; returns the syntax tree and its import block.
    fn parse_py(source: &str) -> (Tree, ImportBlock) {
        let grammar = grammar_for(LangId::Python);
        let mut parser = Parser::new();
        parser.set_language(&grammar).unwrap();
        let tree = parser.parse(source, None).unwrap();
        let block = parse_imports(source, &tree, LangId::Python);
        (tree, block)
    }

    #[test]
    fn parse_py_import_statement() {
        let source = "import os\nimport sys\n";
        let (_, block) = parse_py(source);
        assert_eq!(block.imports.len(), 2);
        assert_eq!(block.imports[0].module_path, "os");
        assert_eq!(block.imports[1].module_path, "sys");
        assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
    }

    #[test]
    fn parse_py_from_import() {
        let source = "from collections import OrderedDict\nfrom typing import List, Optional\n";
        let (_, block) = parse_py(source);
        assert_eq!(block.imports.len(), 2);
        assert_eq!(block.imports[0].module_path, "collections");
        assert!(block.imports[0].names.contains(&"OrderedDict".to_string()));
        assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
        assert_eq!(block.imports[1].module_path, "typing");
        assert!(block.imports[1].names.contains(&"List".to_string()));
        assert!(block.imports[1].names.contains(&"Optional".to_string()));
    }

    #[test]
    fn parse_py_relative_import() {
        let source = "from . import utils\nfrom ..config import Settings\n";
        let (_, block) = parse_py(source);
        assert_eq!(block.imports.len(), 2);
        assert_eq!(block.imports[0].module_path, ".");
        assert!(block.imports[0].names.contains(&"utils".to_string()));
        assert_eq!(block.imports[0].group, ImportGroup::Internal);
        assert_eq!(block.imports[1].module_path, "..config");
        assert_eq!(block.imports[1].group, ImportGroup::Internal);
    }

    #[test]
    fn classify_py_groups() {
        assert_eq!(classify_group_py("os"), ImportGroup::Stdlib);
        assert_eq!(classify_group_py("sys"), ImportGroup::Stdlib);
        assert_eq!(classify_group_py("json"), ImportGroup::Stdlib);
        assert_eq!(classify_group_py("collections"), ImportGroup::Stdlib);
        assert_eq!(classify_group_py("os.path"), ImportGroup::Stdlib);
        assert_eq!(classify_group_py("requests"), ImportGroup::External);
        assert_eq!(classify_group_py("flask"), ImportGroup::External);
        assert_eq!(classify_group_py("."), ImportGroup::Internal);
        assert_eq!(classify_group_py("..config"), ImportGroup::Internal);
        assert_eq!(classify_group_py(".utils"), ImportGroup::Internal);
    }

    #[test]
    fn parse_py_three_groups() {
        let source = "import os\nimport sys\n\nimport requests\n\nfrom . import utils\n";
        let (_, block) = parse_py(source);
        let stdlib: Vec<_> = block
            .imports
            .iter()
            .filter(|i| i.group == ImportGroup::Stdlib)
            .collect();
        let external: Vec<_> = block
            .imports
            .iter()
            .filter(|i| i.group == ImportGroup::External)
            .collect();
        let internal: Vec<_> = block
            .imports
            .iter()
            .filter(|i| i.group == ImportGroup::Internal)
            .collect();
        assert_eq!(stdlib.len(), 2);
        assert_eq!(external.len(), 1);
        assert_eq!(internal.len(), 1);
    }

    #[test]
    fn generate_py_import() {
        let line = generate_import_line(LangId::Python, "os", &[], None, false);
        assert_eq!(line, "import os");
    }

    #[test]
    fn generate_py_from_import() {
        let line = generate_import_line(
            LangId::Python,
            "collections",
            &["OrderedDict".to_string()],
            None,
            false,
        );
        assert_eq!(line, "from collections import OrderedDict");
    }

    #[test]
    fn generate_py_from_import_multiple() {
        let line = generate_import_line(
            LangId::Python,
            "typing",
            &["Optional".to_string(), "List".to_string()],
            None,
            false,
        );
        assert_eq!(line, "from typing import List, Optional");
    }

    // --- Rust parsing ---

    /// Parse `source` as Rust; returns the syntax tree and its import block.
    fn parse_rust(source: &str) -> (Tree, ImportBlock) {
        let grammar = grammar_for(LangId::Rust);
        let mut parser = Parser::new();
        parser.set_language(&grammar).unwrap();
        let tree = parser.parse(source, None).unwrap();
        let block = parse_imports(source, &tree, LangId::Rust);
        (tree, block)
    }

    #[test]
    fn parse_rs_use_std() {
        let source = "use std::collections::HashMap;\nuse std::io::Read;\n";
        let (_, block) = parse_rust(source);
        assert_eq!(block.imports.len(), 2);
        assert_eq!(block.imports[0].module_path, "std::collections::HashMap");
        assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
        assert_eq!(block.imports[1].group, ImportGroup::Stdlib);
    }

    #[test]
    fn parse_rs_use_external() {
        let source = "use serde::{Deserialize, Serialize};\n";
        let (_, block) = parse_rust(source);
        assert_eq!(block.imports.len(), 1);
        assert_eq!(block.imports[0].group, ImportGroup::External);
        assert!(block.imports[0].names.contains(&"Deserialize".to_string()));
        assert!(block.imports[0].names.contains(&"Serialize".to_string()));
    }

    #[test]
    fn parse_rs_use_crate() {
        let source = "use crate::config::Settings;\nuse super::parent::Thing;\n";
        let (_, block) = parse_rust(source);
        assert_eq!(block.imports.len(), 2);
        assert_eq!(block.imports[0].group, ImportGroup::Internal);
        assert_eq!(block.imports[1].group, ImportGroup::Internal);
    }

    #[test]
    fn parse_rs_pub_use() {
        let source = "pub use super::parent::Thing;\n";
        let (_, block) = parse_rust(source);
        assert_eq!(block.imports.len(), 1);
        // `pub` is stored in default_import as a marker
        assert_eq!(block.imports[0].default_import.as_deref(), Some("pub"));
    }

    #[test]
    fn classify_rs_groups() {
        assert_eq!(
            classify_group_rs("std::collections::HashMap"),
            ImportGroup::Stdlib
        );
        assert_eq!(classify_group_rs("core::mem"), ImportGroup::Stdlib);
        assert_eq!(classify_group_rs("alloc::vec"), ImportGroup::Stdlib);
        assert_eq!(
            classify_group_rs("serde::Deserialize"),
            ImportGroup::External
        );
        assert_eq!(classify_group_rs("tokio::runtime"), ImportGroup::External);
        assert_eq!(classify_group_rs("crate::config"), ImportGroup::Internal);
        assert_eq!(classify_group_rs("self::utils"), ImportGroup::Internal);
        assert_eq!(classify_group_rs("super::parent"), ImportGroup::Internal);
    }

    #[test]
    fn generate_rs_use() {
        let line = generate_import_line(LangId::Rust, "std::fmt::Display", &[], None, false);
        assert_eq!(line, "use std::fmt::Display;");
    }

    // --- Go parsing ---

    /// Parse `source` as Go; returns the syntax tree and its import block.
    fn parse_go(source: &str) -> (Tree, ImportBlock) {
        let grammar = grammar_for(LangId::Go);
        let mut parser = Parser::new();
        parser.set_language(&grammar).unwrap();
        let tree = parser.parse(source, None).unwrap();
        let block = parse_imports(source, &tree, LangId::Go);
        (tree, block)
    }

    #[test]
    fn parse_go_single_import() {
        let source = "package main\n\nimport \"fmt\"\n";
        let (_, block) = parse_go(source);
        assert_eq!(block.imports.len(), 1);
        assert_eq!(block.imports[0].module_path, "fmt");
        assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
    }

    #[test]
    fn parse_go_grouped_import() {
        let source =
            "package main\n\nimport (\n\t\"fmt\"\n\t\"os\"\n\n\t\"github.com/pkg/errors\"\n)\n";
        let (_, block) = parse_go(source);
        assert_eq!(block.imports.len(), 3);
        assert_eq!(block.imports[0].module_path, "fmt");
        assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
        assert_eq!(block.imports[1].module_path, "os");
        assert_eq!(block.imports[1].group, ImportGroup::Stdlib);
        assert_eq!(block.imports[2].module_path, "github.com/pkg/errors");
        assert_eq!(block.imports[2].group, ImportGroup::External);
    }

    #[test]
    fn parse_go_mixed_imports() {
        // Single + grouped
        let source = "package main\n\nimport \"fmt\"\n\nimport (\n\t\"os\"\n\t\"github.com/pkg/errors\"\n)\n";
        let (_, block) = parse_go(source);
        assert_eq!(block.imports.len(), 3);
    }

    #[test]
    fn classify_go_groups() {
        assert_eq!(classify_group_go("fmt"), ImportGroup::Stdlib);
        assert_eq!(classify_group_go("os"), ImportGroup::Stdlib);
        assert_eq!(classify_group_go("net/http"), ImportGroup::Stdlib);
        assert_eq!(classify_group_go("encoding/json"), ImportGroup::Stdlib);
        assert_eq!(
            classify_group_go("github.com/pkg/errors"),
            ImportGroup::External
        );
        assert_eq!(
            classify_group_go("golang.org/x/tools"),
            ImportGroup::External
        );
    }

    #[test]
    fn generate_go_standalone() {
        let line = generate_go_import_line("fmt", None, false);
        assert_eq!(line, "import \"fmt\"");
    }

    #[test]
    fn generate_go_grouped_spec() {
        let line = generate_go_import_line("fmt", None, true);
        assert_eq!(line, "\t\"fmt\"");
    }

    #[test]
    fn generate_go_with_alias() {
        let line = generate_go_import_line("github.com/pkg/errors", Some("errs"), false);
        assert_eq!(line, "import errs \"github.com/pkg/errors\"");
    }
}