1use std::ops::Range;
10
11use tree_sitter::{Node, Parser, Tree};
12
13use crate::parser::{grammar_for, LangId};
14
/// How an import statement binds what it imports.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ImportKind {
    /// A regular import that introduces runtime bindings.
    Value,
    /// A type-only import (a TypeScript statement carrying the `type` keyword).
    Type,
    /// An import with no named, default, or namespace binding,
    /// such as `import './styles.css'`.
    SideEffect,
}
29
/// Coarse classification used to keep related imports together.
///
/// The derived `Ord` follows declaration order (`Stdlib < External < Internal`);
/// the insertion-point search in this file compares groups with `<` and `>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ImportGroup {
    /// Language standard library (Rust `std`/`core`/`alloc`, known Python
    /// stdlib modules, Go paths without a dot).
    Stdlib,
    /// Third-party or otherwise non-local modules.
    External,
    /// Project-local modules (relative TS/Python paths, Rust `crate`/`self`/`super`).
    Internal,
}
50
51impl ImportGroup {
52 pub fn label(&self) -> &'static str {
54 match self {
55 ImportGroup::Stdlib => "stdlib",
56 ImportGroup::External => "external",
57 ImportGroup::Internal => "internal",
58 }
59 }
60}
61
/// One parsed import statement (or, for Go, one import spec).
#[derive(Debug, Clone)]
pub struct ImportStatement {
    /// Module or path text being imported, with surrounding quotes removed
    /// for TS/JS and Go.
    pub module_path: String,
    /// Named specifiers. TS entries keep their raw text (e.g. `foo as bar`);
    /// Rust entries come from a `use` list; Python entries are the imported names.
    pub names: Vec<String>,
    /// TS default binding. The Go parser stores the spec alias here, and the
    /// Rust parser stores `"pub"` when the `use` has a visibility modifier.
    pub default_import: Option<String>,
    /// Identifier of a TS namespace import (`import * as ns`).
    pub namespace_import: Option<String>,
    /// Binding kind; only the TS parser produces anything other than `Value`.
    pub kind: ImportKind,
    /// Group this import was classified into.
    pub group: ImportGroup,
    /// Byte range of the originating syntax node in the source.
    pub byte_range: Range<usize>,
    /// Source text covered by `byte_range`.
    pub raw_text: String,
}
82
/// All imports found among the direct children of a file's root node.
#[derive(Debug, Clone)]
pub struct ImportBlock {
    /// Parsed imports in source order.
    pub imports: Vec<ImportStatement>,
    /// Span from the start of the first import to the end of the last one;
    /// `None` when no imports were found.
    pub byte_range: Option<Range<usize>>,
}
92
93impl ImportBlock {
94 pub fn empty() -> Self {
95 ImportBlock {
96 imports: Vec::new(),
97 byte_range: None,
98 }
99 }
100}
101
102fn import_byte_range(imports: &[ImportStatement]) -> Option<Range<usize>> {
103 imports.first().zip(imports.last()).map(|(first, last)| {
104 let start = first.byte_range.start;
105 let end = last.byte_range.end;
106 start..end
107 })
108}
109
/// Returns the name a specifier binds locally.
///
/// A leading `type ` marker is dropped; for `a as b` the text after the first
/// ` as ` is returned, otherwise the specifier itself.
pub fn specifier_local_name(spec: &str) -> &str {
    let spec = spec.trim();
    let body = match spec.strip_prefix("type ") {
        Some(rest) => rest.trim_start(),
        None => spec,
    };
    match body.split_once(" as ") {
        Some((_, alias)) => alias.trim(),
        None => body,
    }
}
137
/// Returns the name a specifier imports from the module.
///
/// A leading `type ` marker is dropped; for `a as b` the text before the first
/// ` as ` is returned, otherwise the specifier itself.
pub fn specifier_imported_name(spec: &str) -> &str {
    let spec = spec.trim();
    let body = spec.strip_prefix("type ").map_or(spec, str::trim_start);
    if let Some((original, _)) = body.split_once(" as ") {
        original.trim()
    } else {
        body
    }
}
157
158pub fn specifier_matches(spec: &str, target: &str) -> bool {
163 specifier_imported_name(spec) == target || specifier_local_name(spec) == target
164}
165
166pub fn parse_imports(source: &str, tree: &Tree, lang: LangId) -> ImportBlock {
172 match lang {
173 LangId::TypeScript | LangId::Tsx | LangId::JavaScript => parse_ts_imports(source, tree),
174 LangId::Python => parse_py_imports(source, tree),
175 LangId::Rust => parse_rs_imports(source, tree),
176 LangId::Go => parse_go_imports(source, tree),
177 LangId::C
178 | LangId::Cpp
179 | LangId::Zig
180 | LangId::CSharp
181 | LangId::Bash
182 | LangId::Solidity => ImportBlock::empty(),
183 LangId::Html | LangId::Markdown => ImportBlock::empty(),
184 }
185}
186
187pub fn is_duplicate(
193 block: &ImportBlock,
194 module_path: &str,
195 names: &[String],
196 default_import: Option<&str>,
197 type_only: bool,
198) -> bool {
199 is_duplicate_with_namespace(block, module_path, names, default_import, None, type_only)
200}
201
202pub fn is_duplicate_with_namespace(
204 block: &ImportBlock,
205 module_path: &str,
206 names: &[String],
207 default_import: Option<&str>,
208 namespace_import: Option<&str>,
209 type_only: bool,
210) -> bool {
211 let target_kind = if type_only {
212 ImportKind::Type
213 } else {
214 ImportKind::Value
215 };
216
217 for imp in &block.imports {
218 if imp.module_path != module_path {
219 continue;
220 }
221
222 if names.is_empty()
228 && default_import.is_none()
229 && namespace_import.is_none()
230 && imp.names.is_empty()
231 && imp.default_import.is_none()
232 && imp.namespace_import.is_none()
233 {
234 return true;
235 }
236
237 if names.is_empty()
239 && default_import.is_none()
240 && namespace_import.is_none()
241 && imp.kind == ImportKind::SideEffect
242 {
243 return true;
244 }
245
246 if names.is_empty()
247 && default_import.is_none()
248 && namespace_import.is_some()
249 && imp.names.is_empty()
250 && imp.default_import.is_none()
251 && imp.namespace_import.as_deref() == namespace_import
252 {
253 return true;
254 }
255
256 if imp.kind != target_kind && imp.kind != ImportKind::SideEffect {
258 continue;
259 }
260
261 if let Some(def) = default_import {
263 if imp.default_import.as_deref() == Some(def) && imp.namespace_import.is_none() {
264 return true;
265 }
266 }
267
268 if !names.is_empty()
274 && names
275 .iter()
276 .all(|n| imp.names.iter().any(|stored| specifier_matches(stored, n)))
277 {
278 return true;
279 }
280 }
281
282 false
283}
284
285fn sort_named_specifiers(names: &mut [String]) {
286 names.sort_by(|a, b| {
287 specifier_imported_name(a)
288 .cmp(specifier_imported_name(b))
289 .then_with(|| a.cmp(b))
290 });
291}
292
293pub fn find_insertion_point(
304 source: &str,
305 block: &ImportBlock,
306 group: ImportGroup,
307 module_path: &str,
308 type_only: bool,
309) -> (usize, bool, bool) {
310 if block.imports.is_empty() {
311 return (0, false, source.is_empty().then_some(false).unwrap_or(true));
313 }
314
315 let target_kind = if type_only {
316 ImportKind::Type
317 } else {
318 ImportKind::Value
319 };
320
321 let group_imports: Vec<&ImportStatement> =
323 block.imports.iter().filter(|i| i.group == group).collect();
324
325 if group_imports.is_empty() {
326 let preceding_last = block.imports.iter().filter(|i| i.group < group).last();
329
330 if let Some(last) = preceding_last {
331 let end = last.byte_range.end;
332 let insert_at = skip_newline(source, end);
333 return (insert_at, true, true);
334 }
335
336 let following_first = block.imports.iter().find(|i| i.group > group);
338
339 if let Some(first) = following_first {
340 return (first.byte_range.start, false, true);
341 }
342
343 let first_byte = import_byte_range(&block.imports)
345 .map(|range| range.start)
346 .unwrap_or(0);
347 return (first_byte, false, true);
348 }
349
350 for imp in &group_imports {
352 let cmp = module_path.cmp(&imp.module_path);
353 match cmp {
354 std::cmp::Ordering::Less => {
355 return (imp.byte_range.start, false, false);
357 }
358 std::cmp::Ordering::Equal => {
359 if target_kind == ImportKind::Type && imp.kind == ImportKind::Value {
361 let end = imp.byte_range.end;
363 let insert_at = skip_newline(source, end);
364 return (insert_at, false, false);
365 }
366 return (imp.byte_range.start, false, false);
368 }
369 std::cmp::Ordering::Greater => continue,
370 }
371 }
372
373 let Some(last) = group_imports.last() else {
375 return (
376 import_byte_range(&block.imports)
377 .map(|range| range.end)
378 .unwrap_or(0),
379 false,
380 false,
381 );
382 };
383 let end = last.byte_range.end;
384 let insert_at = skip_newline(source, end);
385 (insert_at, false, false)
386}
387
388pub fn generate_import_line(
390 lang: LangId,
391 module_path: &str,
392 names: &[String],
393 default_import: Option<&str>,
394 type_only: bool,
395) -> String {
396 match lang {
397 LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
398 generate_ts_import_line(module_path, names, default_import, type_only)
399 }
400 LangId::Python => generate_py_import_line(module_path, names, default_import),
401 LangId::Rust => generate_rs_import_line(module_path, names, type_only),
402 LangId::Go => generate_go_import_line(module_path, default_import, false),
403 LangId::C
404 | LangId::Cpp
405 | LangId::Zig
406 | LangId::CSharp
407 | LangId::Bash
408 | LangId::Solidity => String::new(),
409 LangId::Html | LangId::Markdown => String::new(),
410 }
411}
412
413pub fn is_supported(lang: LangId) -> bool {
415 matches!(
416 lang,
417 LangId::TypeScript
418 | LangId::Tsx
419 | LangId::JavaScript
420 | LangId::Python
421 | LangId::Rust
422 | LangId::Go
423 )
424}
425
426pub fn classify_group_ts(module_path: &str) -> ImportGroup {
428 if module_path.starts_with('.') {
429 ImportGroup::Internal
430 } else {
431 ImportGroup::External
432 }
433}
434
435pub fn classify_group(lang: LangId, module_path: &str) -> ImportGroup {
437 match lang {
438 LangId::TypeScript | LangId::Tsx | LangId::JavaScript => classify_group_ts(module_path),
439 LangId::Python => classify_group_py(module_path),
440 LangId::Rust => classify_group_rs(module_path),
441 LangId::Go => classify_group_go(module_path),
442 LangId::C
443 | LangId::Cpp
444 | LangId::Zig
445 | LangId::CSharp
446 | LangId::Bash
447 | LangId::Solidity => ImportGroup::External,
448 LangId::Html | LangId::Markdown => ImportGroup::External,
449 }
450}
451
452pub fn parse_file_imports(
455 path: &std::path::Path,
456 lang: LangId,
457) -> Result<(String, Tree, ImportBlock), crate::error::AftError> {
458 let source =
459 std::fs::read_to_string(path).map_err(|e| crate::error::AftError::FileNotFound {
460 path: format!("{}: {}", path.display(), e),
461 })?;
462
463 let grammar = grammar_for(lang);
464 let mut parser = Parser::new();
465 parser
466 .set_language(&grammar)
467 .map_err(|e| crate::error::AftError::ParseError {
468 message: format!("grammar init failed for {:?}: {}", lang, e),
469 })?;
470
471 let tree = parser
472 .parse(&source, None)
473 .ok_or_else(|| crate::error::AftError::ParseError {
474 message: format!("tree-sitter parse returned None for {}", path.display()),
475 })?;
476
477 let block = parse_imports(&source, &tree, lang);
478 Ok((source, tree, block))
479}
480
481fn parse_ts_imports(source: &str, tree: &Tree) -> ImportBlock {
489 let root = tree.root_node();
490 let mut imports = Vec::new();
491
492 let mut cursor = root.walk();
493 if !cursor.goto_first_child() {
494 return ImportBlock::empty();
495 }
496
497 loop {
498 let node = cursor.node();
499 if node.kind() == "import_statement" {
500 if let Some(imp) = parse_single_ts_import(source, &node) {
501 imports.push(imp);
502 }
503 }
504 if !cursor.goto_next_sibling() {
505 break;
506 }
507 }
508
509 let byte_range = import_byte_range(&imports);
510
511 ImportBlock {
512 imports,
513 byte_range,
514 }
515}
516
517fn parse_single_ts_import(source: &str, node: &Node) -> Option<ImportStatement> {
519 let raw_text = source[node.byte_range()].to_string();
520 let byte_range = node.byte_range();
521
522 let module_path = extract_module_path(source, node)?;
524
525 let is_type_only = has_type_keyword(node);
527
528 let mut names = Vec::new();
530 let mut default_import = None;
531 let mut namespace_import = None;
532
533 let mut child_cursor = node.walk();
534 if child_cursor.goto_first_child() {
535 loop {
536 let child = child_cursor.node();
537 match child.kind() {
538 "import_clause" => {
539 extract_import_clause(
540 source,
541 &child,
542 &mut names,
543 &mut default_import,
544 &mut namespace_import,
545 );
546 }
547 "identifier" => {
549 let text = &source[child.byte_range()];
550 if text != "import" && text != "from" && text != "type" {
551 default_import = Some(text.to_string());
552 }
553 }
554 _ => {}
555 }
556 if !child_cursor.goto_next_sibling() {
557 break;
558 }
559 }
560 }
561
562 let kind = if names.is_empty() && default_import.is_none() && namespace_import.is_none() {
564 ImportKind::SideEffect
565 } else if is_type_only {
566 ImportKind::Type
567 } else {
568 ImportKind::Value
569 };
570
571 let group = classify_group_ts(&module_path);
572
573 Some(ImportStatement {
574 module_path,
575 names,
576 default_import,
577 namespace_import,
578 kind,
579 group,
580 byte_range,
581 raw_text,
582 })
583}
584
585fn extract_module_path(source: &str, node: &Node) -> Option<String> {
589 let mut cursor = node.walk();
590 if !cursor.goto_first_child() {
591 return None;
592 }
593
594 loop {
595 let child = cursor.node();
596 if child.kind() == "string" {
597 let text = &source[child.byte_range()];
599 let stripped = text
600 .trim_start_matches(|c| c == '\'' || c == '"')
601 .trim_end_matches(|c| c == '\'' || c == '"');
602 return Some(stripped.to_string());
603 }
604 if !cursor.goto_next_sibling() {
605 break;
606 }
607 }
608 None
609}
610
611fn has_type_keyword(node: &Node) -> bool {
616 let mut cursor = node.walk();
617 if !cursor.goto_first_child() {
618 return false;
619 }
620
621 loop {
622 let child = cursor.node();
623 if child.kind() == "type" {
624 return true;
625 }
626 if !cursor.goto_next_sibling() {
627 break;
628 }
629 }
630
631 false
632}
633
634fn extract_import_clause(
636 source: &str,
637 node: &Node,
638 names: &mut Vec<String>,
639 default_import: &mut Option<String>,
640 namespace_import: &mut Option<String>,
641) {
642 let mut cursor = node.walk();
643 if !cursor.goto_first_child() {
644 return;
645 }
646
647 loop {
648 let child = cursor.node();
649 match child.kind() {
650 "identifier" => {
651 let text = &source[child.byte_range()];
653 if text != "type" {
654 *default_import = Some(text.to_string());
655 }
656 }
657 "named_imports" => {
658 extract_named_imports(source, &child, names);
660 }
661 "namespace_import" => {
662 extract_namespace_import(source, &child, namespace_import);
664 }
665 _ => {}
666 }
667 if !cursor.goto_next_sibling() {
668 break;
669 }
670 }
671}
672
673fn extract_named_imports(source: &str, node: &Node, names: &mut Vec<String>) {
692 let mut cursor = node.walk();
693 if !cursor.goto_first_child() {
694 return;
695 }
696
697 loop {
698 let child = cursor.node();
699 if child.kind() == "import_specifier" {
700 let raw = source[child.byte_range()].trim().to_string();
705 if !raw.is_empty() {
706 names.push(raw);
707 } else if let Some(name_node) = child.child_by_field_name("name") {
708 names.push(source[name_node.byte_range()].to_string());
709 }
710 }
711 if !cursor.goto_next_sibling() {
712 break;
713 }
714 }
715}
716
717fn extract_namespace_import(source: &str, node: &Node, namespace_import: &mut Option<String>) {
719 let mut cursor = node.walk();
720 if !cursor.goto_first_child() {
721 return;
722 }
723
724 loop {
725 let child = cursor.node();
726 if child.kind() == "identifier" {
727 *namespace_import = Some(source[child.byte_range()].to_string());
728 return;
729 }
730 if !cursor.goto_next_sibling() {
731 break;
732 }
733 }
734}
735
736fn generate_ts_import_line(
738 module_path: &str,
739 names: &[String],
740 default_import: Option<&str>,
741 type_only: bool,
742) -> String {
743 let type_prefix = if type_only { "type " } else { "" };
744
745 if names.is_empty() && default_import.is_none() {
747 return format!("import '{module_path}';");
748 }
749
750 if names.is_empty() {
752 if let Some(def) = default_import {
753 return format!("import {type_prefix}{def} from '{module_path}';");
754 }
755 }
756
757 if default_import.is_none() {
759 let mut sorted_names = names.to_vec();
760 sort_named_specifiers(&mut sorted_names);
761 let names_str = sorted_names.join(", ");
762 return format!("import {type_prefix}{{ {names_str} }} from '{module_path}';");
763 }
764
765 if let Some(def) = default_import {
767 let mut sorted_names = names.to_vec();
768 sort_named_specifiers(&mut sorted_names);
769 let names_str = sorted_names.join(", ");
770 return format!("import {type_prefix}{def}, {{ {names_str} }} from '{module_path}';");
771 }
772
773 format!("import '{module_path}';")
775}
776
/// Top-level module names treated as part of the Python standard library.
///
/// Modules that newer Python releases have deprecated or removed (for example
/// `asyncore`, `cgi`, `telnetlib`) are kept so that older code still
/// classifies them as stdlib. Entries are kept in case-insensitive
/// alphabetical order for readability; lookups are a linear scan, so order
/// does not affect behavior.
const PYTHON_STDLIB: &[&str] = &[
    "__future__",
    "_thread",
    "abc",
    "aifc",
    "argparse",
    "array",
    "ast",
    "asynchat",
    "asyncio",
    "asyncore",
    "atexit",
    "audioop",
    "base64",
    "bdb",
    "binascii",
    "bisect",
    "builtins",
    "bz2",
    "calendar",
    "cgi",
    "cgitb",
    "chunk",
    "cmath",
    "cmd",
    "code",
    "codecs",
    "codeop",
    "collections",
    "colorsys",
    "compileall",
    "concurrent",
    "configparser",
    "contextlib",
    "contextvars",
    "copy",
    "copyreg",
    "cProfile",
    "crypt",
    "csv",
    "ctypes",
    "curses",
    "dataclasses",
    "datetime",
    "dbm",
    "decimal",
    "difflib",
    "dis",
    "distutils",
    "doctest",
    "email",
    "encodings",
    "ensurepip",
    "enum",
    "errno",
    "faulthandler",
    "fcntl",
    "filecmp",
    "fileinput",
    "fnmatch",
    "fractions",
    "ftplib",
    "functools",
    "gc",
    "getopt",
    "getpass",
    "gettext",
    "glob",
    "graphlib",
    "grp",
    "gzip",
    "hashlib",
    "heapq",
    "hmac",
    "html",
    "http",
    "idlelib",
    "imaplib",
    "imghdr",
    "imp",
    "importlib",
    "inspect",
    "io",
    "ipaddress",
    "itertools",
    "json",
    "keyword",
    "lib2to3",
    "linecache",
    "locale",
    "logging",
    "lzma",
    "mailbox",
    "mailcap",
    "marshal",
    "math",
    "mimetypes",
    "mmap",
    "modulefinder",
    "msilib",
    "msvcrt",
    "multiprocessing",
    "netrc",
    "nis",
    "nntplib",
    "ntpath",
    "numbers",
    "opcode",
    "operator",
    "optparse",
    "os",
    "ossaudiodev",
    "pathlib",
    "pdb",
    "pickle",
    "pickletools",
    "pipes",
    "pkgutil",
    "platform",
    "plistlib",
    "poplib",
    "posix",
    "posixpath",
    "pprint",
    "profile",
    "pstats",
    "pty",
    "pwd",
    "py_compile",
    "pyclbr",
    "pydoc",
    "queue",
    "quopri",
    "random",
    "re",
    "readline",
    "reprlib",
    "resource",
    "rlcompleter",
    "runpy",
    "sched",
    "secrets",
    "select",
    "selectors",
    "shelve",
    "shlex",
    "shutil",
    "signal",
    "site",
    "smtpd",
    "smtplib",
    "sndhdr",
    "socket",
    "socketserver",
    "spwd",
    "sqlite3",
    "ssl",
    "stat",
    "statistics",
    "string",
    "stringprep",
    "struct",
    "subprocess",
    "sunau",
    "symtable",
    "sys",
    "sysconfig",
    "syslog",
    "tabnanny",
    "tarfile",
    "telnetlib",
    "tempfile",
    "termios",
    "textwrap",
    "threading",
    "time",
    "timeit",
    "tkinter",
    "token",
    "tokenize",
    "tomllib",
    "trace",
    "traceback",
    "tracemalloc",
    "tty",
    "turtle",
    "turtledemo",
    "types",
    "typing",
    "unicodedata",
    "unittest",
    "urllib",
    "uu",
    "uuid",
    "venv",
    "warnings",
    "wave",
    "weakref",
    "webbrowser",
    "winreg",
    "winsound",
    "wsgiref",
    "xdrlib",
    "xml",
    "xmlrpc",
    "zipapp",
    "zipfile",
    "zipimport",
    "zlib",
    "zoneinfo",
];
974
975pub fn classify_group_py(module_path: &str) -> ImportGroup {
977 if module_path.starts_with('.') {
979 return ImportGroup::Internal;
980 }
981 let top_module = module_path.split('.').next().unwrap_or(module_path);
983 if PYTHON_STDLIB.contains(&top_module) {
984 ImportGroup::Stdlib
985 } else {
986 ImportGroup::External
987 }
988}
989
990fn parse_py_imports(source: &str, tree: &Tree) -> ImportBlock {
992 let root = tree.root_node();
993 let mut imports = Vec::new();
994
995 let mut cursor = root.walk();
996 if !cursor.goto_first_child() {
997 return ImportBlock::empty();
998 }
999
1000 loop {
1001 let node = cursor.node();
1002 match node.kind() {
1003 "import_statement" => {
1004 if let Some(imp) = parse_py_import_statement(source, &node) {
1005 imports.push(imp);
1006 }
1007 }
1008 "import_from_statement" => {
1009 if let Some(imp) = parse_py_import_from_statement(source, &node) {
1010 imports.push(imp);
1011 }
1012 }
1013 _ => {}
1014 }
1015 if !cursor.goto_next_sibling() {
1016 break;
1017 }
1018 }
1019
1020 let byte_range = import_byte_range(&imports);
1021
1022 ImportBlock {
1023 imports,
1024 byte_range,
1025 }
1026}
1027
1028fn parse_py_import_statement(source: &str, node: &Node) -> Option<ImportStatement> {
1030 let raw_text = source[node.byte_range()].to_string();
1031 let byte_range = node.byte_range();
1032
1033 let mut module_path = String::new();
1035 let mut c = node.walk();
1036 if c.goto_first_child() {
1037 loop {
1038 if c.node().kind() == "dotted_name" {
1039 module_path = source[c.node().byte_range()].to_string();
1040 break;
1041 }
1042 if !c.goto_next_sibling() {
1043 break;
1044 }
1045 }
1046 }
1047 if module_path.is_empty() {
1048 return None;
1049 }
1050
1051 let group = classify_group_py(&module_path);
1052
1053 Some(ImportStatement {
1054 module_path,
1055 names: Vec::new(),
1056 default_import: None,
1057 namespace_import: None,
1058 kind: ImportKind::Value,
1059 group,
1060 byte_range,
1061 raw_text,
1062 })
1063}
1064
1065fn parse_py_import_from_statement(source: &str, node: &Node) -> Option<ImportStatement> {
1067 let raw_text = source[node.byte_range()].to_string();
1068 let byte_range = node.byte_range();
1069
1070 let mut module_path = String::new();
1071 let mut names = Vec::new();
1072
1073 let mut c = node.walk();
1074 if c.goto_first_child() {
1075 loop {
1076 let child = c.node();
1077 match child.kind() {
1078 "dotted_name" => {
1079 if module_path.is_empty()
1084 && !has_seen_import_keyword(source, node, child.start_byte())
1085 {
1086 module_path = source[child.byte_range()].to_string();
1087 } else {
1088 names.push(source[child.byte_range()].to_string());
1090 }
1091 }
1092 "relative_import" => {
1093 module_path = source[child.byte_range()].to_string();
1095 }
1096 _ => {}
1097 }
1098 if !c.goto_next_sibling() {
1099 break;
1100 }
1101 }
1102 }
1103
1104 if module_path.is_empty() {
1106 return None;
1107 }
1108
1109 let group = classify_group_py(&module_path);
1110
1111 Some(ImportStatement {
1112 module_path,
1113 names,
1114 default_import: None,
1115 namespace_import: None,
1116 kind: ImportKind::Value,
1117 group,
1118 byte_range,
1119 raw_text,
1120 })
1121}
1122
1123fn has_seen_import_keyword(_source: &str, parent: &Node, before_byte: usize) -> bool {
1125 let mut c = parent.walk();
1126 if c.goto_first_child() {
1127 loop {
1128 let child = c.node();
1129 if child.kind() == "import" && child.start_byte() < before_byte {
1130 return true;
1131 }
1132 if child.start_byte() >= before_byte {
1133 return false;
1134 }
1135 if !c.goto_next_sibling() {
1136 break;
1137 }
1138 }
1139 }
1140 false
1141}
1142
/// Renders a Python import line.
///
/// Without names this is `import <module>`; otherwise it is
/// `from <module> import <names>` with the names sorted. `_default_import`
/// is unused.
fn generate_py_import_line(
    module_path: &str,
    names: &[String],
    _default_import: Option<&str>,
) -> String {
    if names.is_empty() {
        return format!("import {module_path}");
    }

    let mut ordered: Vec<&str> = names.iter().map(String::as_str).collect();
    ordered.sort_unstable();
    let names_str = ordered.join(", ");
    format!("from {module_path} import {names_str}")
}
1160
1161pub fn classify_group_rs(module_path: &str) -> ImportGroup {
1167 let first_seg = module_path.split("::").next().unwrap_or(module_path);
1169 match first_seg {
1170 "std" | "core" | "alloc" => ImportGroup::Stdlib,
1171 "crate" | "self" | "super" => ImportGroup::Internal,
1172 _ => ImportGroup::External,
1173 }
1174}
1175
1176fn parse_rs_imports(source: &str, tree: &Tree) -> ImportBlock {
1178 let root = tree.root_node();
1179 let mut imports = Vec::new();
1180
1181 let mut cursor = root.walk();
1182 if !cursor.goto_first_child() {
1183 return ImportBlock::empty();
1184 }
1185
1186 loop {
1187 let node = cursor.node();
1188 if node.kind() == "use_declaration" {
1189 if let Some(imp) = parse_rs_use_declaration(source, &node) {
1190 imports.push(imp);
1191 }
1192 }
1193 if !cursor.goto_next_sibling() {
1194 break;
1195 }
1196 }
1197
1198 let byte_range = import_byte_range(&imports);
1199
1200 ImportBlock {
1201 imports,
1202 byte_range,
1203 }
1204}
1205
1206fn parse_rs_use_declaration(source: &str, node: &Node) -> Option<ImportStatement> {
1208 let raw_text = source[node.byte_range()].to_string();
1209 let byte_range = node.byte_range();
1210
1211 let mut has_pub = false;
1213 let mut use_path = String::new();
1214 let mut names = Vec::new();
1215
1216 let mut c = node.walk();
1217 if c.goto_first_child() {
1218 loop {
1219 let child = c.node();
1220 match child.kind() {
1221 "visibility_modifier" => {
1222 has_pub = true;
1223 }
1224 "scoped_identifier" | "identifier" | "use_as_clause" => {
1225 use_path = source[child.byte_range()].to_string();
1227 }
1228 "scoped_use_list" => {
1229 use_path = source[child.byte_range()].to_string();
1231 extract_rs_use_list_names(source, &child, &mut names);
1233 }
1234 _ => {}
1235 }
1236 if !c.goto_next_sibling() {
1237 break;
1238 }
1239 }
1240 }
1241
1242 if use_path.is_empty() {
1243 return None;
1244 }
1245
1246 let group = classify_group_rs(&use_path);
1247
1248 Some(ImportStatement {
1249 module_path: use_path,
1250 names,
1251 default_import: if has_pub {
1252 Some("pub".to_string())
1253 } else {
1254 None
1255 },
1256 namespace_import: None,
1257 kind: ImportKind::Value,
1258 group,
1259 byte_range,
1260 raw_text,
1261 })
1262}
1263
1264fn extract_rs_use_list_names(source: &str, node: &Node, names: &mut Vec<String>) {
1266 let mut c = node.walk();
1267 if c.goto_first_child() {
1268 loop {
1269 let child = c.node();
1270 if child.kind() == "use_list" {
1271 let mut lc = child.walk();
1273 if lc.goto_first_child() {
1274 loop {
1275 let lchild = lc.node();
1276 if lchild.kind() == "identifier" || lchild.kind() == "scoped_identifier" {
1277 names.push(source[lchild.byte_range()].to_string());
1278 }
1279 if !lc.goto_next_sibling() {
1280 break;
1281 }
1282 }
1283 }
1284 }
1285 if !c.goto_next_sibling() {
1286 break;
1287 }
1288 }
1289 }
1290}
1291
/// Renders a Rust `use` declaration as `use <module_path>;`.
///
/// `module_path` is emitted verbatim, so any brace list must already be part
/// of it (e.g. `std::{fmt, io}`). `names` and `_type_only` do not influence
/// the output; the previous implementation had two identical branches keyed
/// on `names.is_empty()`.
fn generate_rs_import_line(module_path: &str, names: &[String], _type_only: bool) -> String {
    // Accepted for signature parity with the other generators.
    let _ = names;
    format!("use {module_path};")
}
1303
1304pub fn classify_group_go(module_path: &str) -> ImportGroup {
1310 if module_path.contains('.') {
1313 ImportGroup::External
1314 } else {
1315 ImportGroup::Stdlib
1316 }
1317}
1318
1319fn parse_go_imports(source: &str, tree: &Tree) -> ImportBlock {
1321 let root = tree.root_node();
1322 let mut imports = Vec::new();
1323
1324 let mut cursor = root.walk();
1325 if !cursor.goto_first_child() {
1326 return ImportBlock::empty();
1327 }
1328
1329 loop {
1330 let node = cursor.node();
1331 if node.kind() == "import_declaration" {
1332 parse_go_import_declaration(source, &node, &mut imports);
1333 }
1334 if !cursor.goto_next_sibling() {
1335 break;
1336 }
1337 }
1338
1339 let byte_range = import_byte_range(&imports);
1340
1341 ImportBlock {
1342 imports,
1343 byte_range,
1344 }
1345}
1346
1347fn parse_go_import_declaration(source: &str, node: &Node, imports: &mut Vec<ImportStatement>) {
1349 let mut c = node.walk();
1350 if c.goto_first_child() {
1351 loop {
1352 let child = c.node();
1353 match child.kind() {
1354 "import_spec" => {
1355 if let Some(imp) = parse_go_import_spec(source, &child) {
1356 imports.push(imp);
1357 }
1358 }
1359 "import_spec_list" => {
1360 let mut lc = child.walk();
1362 if lc.goto_first_child() {
1363 loop {
1364 if lc.node().kind() == "import_spec" {
1365 if let Some(imp) = parse_go_import_spec(source, &lc.node()) {
1366 imports.push(imp);
1367 }
1368 }
1369 if !lc.goto_next_sibling() {
1370 break;
1371 }
1372 }
1373 }
1374 }
1375 _ => {}
1376 }
1377 if !c.goto_next_sibling() {
1378 break;
1379 }
1380 }
1381 }
1382}
1383
1384fn parse_go_import_spec(source: &str, node: &Node) -> Option<ImportStatement> {
1386 let raw_text = source[node.byte_range()].to_string();
1387 let byte_range = node.byte_range();
1388
1389 let mut import_path = String::new();
1390 let mut alias = None;
1391
1392 let mut c = node.walk();
1393 if c.goto_first_child() {
1394 loop {
1395 let child = c.node();
1396 match child.kind() {
1397 "interpreted_string_literal" => {
1398 let text = source[child.byte_range()].to_string();
1400 import_path = text.trim_matches('"').to_string();
1401 }
1402 "identifier" | "blank_identifier" | "dot" => {
1403 alias = Some(source[child.byte_range()].to_string());
1405 }
1406 _ => {}
1407 }
1408 if !c.goto_next_sibling() {
1409 break;
1410 }
1411 }
1412 }
1413
1414 if import_path.is_empty() {
1415 return None;
1416 }
1417
1418 let group = classify_group_go(&import_path);
1419
1420 Some(ImportStatement {
1421 module_path: import_path,
1422 names: Vec::new(),
1423 default_import: alias,
1424 namespace_import: None,
1425 kind: ImportKind::Value,
1426 group,
1427 byte_range,
1428 raw_text,
1429 })
1430}
1431
/// Public entry point for rendering a Go import line.
///
/// Forwards all arguments unchanged to `generate_go_import_line`: with
/// `in_group` set the result is a tab-indented spec, otherwise a standalone
/// `import` statement. `alias`, when present, precedes the quoted path.
pub fn generate_go_import_line_pub(
    module_path: &str,
    alias: Option<&str>,
    in_group: bool,
) -> String {
    generate_go_import_line(module_path, alias, in_group)
}
1440
/// Renders a Go import line.
///
/// With `in_group` the line is a tab-indented spec; otherwise it is a
/// standalone `import` statement. An alias, when given, is placed before the
/// quoted path.
fn generate_go_import_line(module_path: &str, alias: Option<&str>, in_group: bool) -> String {
    match (in_group, alias) {
        (true, Some(a)) => format!("\t{a} \"{module_path}\""),
        (true, None) => format!("\t\"{module_path}\""),
        (false, Some(a)) => format!("import {a} \"{module_path}\""),
        (false, None) => format!("import \"{module_path}\""),
    }
}
1460
1461pub fn go_has_grouped_import(_source: &str, tree: &Tree) -> Option<Range<usize>> {
1464 let root = tree.root_node();
1465 let mut cursor = root.walk();
1466 if !cursor.goto_first_child() {
1467 return None;
1468 }
1469
1470 loop {
1471 let node = cursor.node();
1472 if node.kind() == "import_declaration" {
1473 let mut c = node.walk();
1474 if c.goto_first_child() {
1475 loop {
1476 if c.node().kind() == "import_spec_list" {
1477 return Some(c.node().byte_range());
1478 }
1479 if !c.goto_next_sibling() {
1480 break;
1481 }
1482 }
1483 }
1484 }
1485 if !cursor.goto_next_sibling() {
1486 break;
1487 }
1488 }
1489 None
1490}
1491
/// Advances `pos` past a single line break (`\n`, `\r\n`, or a lone `\r`)
/// that starts at `pos`.
///
/// Returns `pos` unchanged when it is out of bounds or not at a line break.
fn skip_newline(source: &str, pos: usize) -> usize {
    let bytes = source.as_bytes();
    match bytes.get(pos) {
        Some(b'\n') => pos + 1,
        Some(b'\r') if bytes.get(pos + 1) == Some(&b'\n') => pos + 2,
        Some(b'\r') => pos + 1,
        _ => pos,
    }
}
1508
1509#[cfg(test)]
1514mod tests {
1515 use super::*;
1516
1517 fn parse_ts(source: &str) -> (Tree, ImportBlock) {
1518 let grammar = grammar_for(LangId::TypeScript);
1519 let mut parser = Parser::new();
1520 parser.set_language(&grammar).unwrap();
1521 let tree = parser.parse(source, None).unwrap();
1522 let block = parse_imports(source, &tree, LangId::TypeScript);
1523 (tree, block)
1524 }
1525
1526 fn parse_js(source: &str) -> (Tree, ImportBlock) {
1527 let grammar = grammar_for(LangId::JavaScript);
1528 let mut parser = Parser::new();
1529 parser.set_language(&grammar).unwrap();
1530 let tree = parser.parse(source, None).unwrap();
1531 let block = parse_imports(source, &tree, LangId::JavaScript);
1532 (tree, block)
1533 }
1534
1535 #[test]
1538 fn parse_ts_named_imports() {
1539 let source = "import { useState, useEffect } from 'react';\n";
1540 let (_, block) = parse_ts(source);
1541 assert_eq!(block.imports.len(), 1);
1542 let imp = &block.imports[0];
1543 assert_eq!(imp.module_path, "react");
1544 assert!(imp.names.contains(&"useState".to_string()));
1545 assert!(imp.names.contains(&"useEffect".to_string()));
1546 assert_eq!(imp.kind, ImportKind::Value);
1547 assert_eq!(imp.group, ImportGroup::External);
1548 }
1549
1550 #[test]
1551 fn parse_ts_default_import() {
1552 let source = "import React from 'react';\n";
1553 let (_, block) = parse_ts(source);
1554 assert_eq!(block.imports.len(), 1);
1555 let imp = &block.imports[0];
1556 assert_eq!(imp.default_import.as_deref(), Some("React"));
1557 assert_eq!(imp.kind, ImportKind::Value);
1558 }
1559
1560 #[test]
1561 fn parse_ts_side_effect_import() {
1562 let source = "import './styles.css';\n";
1563 let (_, block) = parse_ts(source);
1564 assert_eq!(block.imports.len(), 1);
1565 assert_eq!(block.imports[0].kind, ImportKind::SideEffect);
1566 assert_eq!(block.imports[0].module_path, "./styles.css");
1567 }
1568
1569 #[test]
1570 fn parse_ts_relative_import() {
1571 let source = "import { helper } from './utils';\n";
1572 let (_, block) = parse_ts(source);
1573 assert_eq!(block.imports.len(), 1);
1574 assert_eq!(block.imports[0].group, ImportGroup::Internal);
1575 }
1576
1577 #[test]
1578 fn parse_ts_multiple_groups() {
1579 let source = "\
1580import React from 'react';
1581import { useState } from 'react';
1582import { helper } from './utils';
1583import { Config } from '../config';
1584";
1585 let (_, block) = parse_ts(source);
1586 assert_eq!(block.imports.len(), 4);
1587
1588 let external: Vec<_> = block
1589 .imports
1590 .iter()
1591 .filter(|i| i.group == ImportGroup::External)
1592 .collect();
1593 let relative: Vec<_> = block
1594 .imports
1595 .iter()
1596 .filter(|i| i.group == ImportGroup::Internal)
1597 .collect();
1598 assert_eq!(external.len(), 2);
1599 assert_eq!(relative.len(), 2);
1600 }
1601
/// `import * as X` records `X` as the namespace binding with the `Value` kind.
#[test]
fn parse_ts_namespace_import() {
    let parsed = parse_ts("import * as path from 'path';\n").1;
    assert_eq!(parsed.imports.len(), 1);

    let only = &parsed.imports[0];
    assert_eq!(only.namespace_import.as_deref(), Some("path"));
    assert_eq!(only.kind, ImportKind::Value);
}
1611
/// The JavaScript parser path groups a bare specifier as `External` and a
/// relative one as `Internal`, in source order.
#[test]
fn parse_js_imports() {
    let source = concat!(
        "import { readFile } from 'fs';\n",
        "import { helper } from './helper';\n",
    );
    let (_, parsed) = parse_js(source);
    assert_eq!(parsed.imports.len(), 2);

    let groups: Vec<ImportGroup> = parsed.imports.iter().map(|stmt| stmt.group).collect();
    assert_eq!(groups, [ImportGroup::External, ImportGroup::Internal]);
}
1620
/// Bare, scoped, and sub-path package specifiers all classify as `External`.
#[test]
fn classify_external() {
    for specifier in ["react", "@scope/pkg", "lodash/map"] {
        // Pair the input with the result so a failure names the specifier.
        assert_eq!(
            (specifier, classify_group_ts(specifier)),
            (specifier, ImportGroup::External)
        );
    }
}
1629
/// Specifiers starting with `./` or `../` classify as `Internal`.
#[test]
fn classify_relative() {
    for specifier in ["./utils", "../config", "./"] {
        // Pair the input with the result so a failure names the specifier.
        assert_eq!(
            (specifier, classify_group_ts(specifier)),
            (specifier, ImportGroup::Internal)
        );
    }
}
1636
/// Requesting a named import that is already present is reported as a duplicate.
#[test]
fn dedup_detects_same_named_import() {
    let (_, existing) = parse_ts("import { useState } from 'react';\n");
    let requested = ["useState".to_string()];

    let already_present = is_duplicate(&existing, "react", &requested, None, false);
    assert!(already_present);
}
1651
/// A different name from the same module is not treated as a duplicate.
#[test]
fn dedup_misses_different_name() {
    let (_, existing) = parse_ts("import { useState } from 'react';\n");
    let requested = ["useEffect".to_string()];

    let already_present = is_duplicate(&existing, "react", &requested, None, false);
    assert!(!already_present);
}
1664
/// An existing default import matches a request for the same default binding.
#[test]
fn dedup_detects_default_import() {
    let existing = parse_ts("import React from 'react';\n").1;
    let already_present = is_duplicate(&existing, "react", &[], Some("React"), false);
    assert!(already_present);
}
1671
/// A side-effect import matches a request with no names and no default.
#[test]
fn dedup_side_effect() {
    let existing = parse_ts("import './styles.css';\n").1;
    let already_present = is_duplicate(&existing, "./styles.css", &[], None, false);
    assert!(already_present);
}
1678
/// Namespace imports and side-effect imports of the same module must not be
/// confused with each other, and a namespace request only matches the same alias.
#[test]
fn dedup_namespace_import_distinct_from_side_effect_import() {
    // File containing only `import 'fs'`: a namespace request is not a duplicate.
    let (_, side_effect_block) = parse_ts("import 'fs';\n");
    let side_effect_matches_namespace =
        is_duplicate_with_namespace(&side_effect_block, "fs", &[], None, Some("fs"), false);
    assert!(!side_effect_matches_namespace);

    // File containing only `import * as fs from 'fs'`.
    let (_, namespace_block) = parse_ts("import * as fs from 'fs';\n");
    let namespace_request = |alias: &str| {
        is_duplicate_with_namespace(&namespace_block, "fs", &[], None, Some(alias), false)
    };

    // A plain side-effect request does not match the namespace import...
    assert!(!is_duplicate(&namespace_block, "fs", &[], None, false));
    // ...the same alias does...
    assert!(namespace_request("fs"));
    // ...and a different alias does not.
    assert!(!namespace_request("other"));
}
1712
/// A type-only request is not a duplicate of an existing value import of
/// the same name.
#[test]
fn dedup_type_vs_value() {
    let (_, existing) = parse_ts("import { FC } from 'react';\n");
    let requested = ["FC".to_string()];

    let already_present = is_duplicate(&existing, "react", &requested, None, true);
    assert!(!already_present);
}
1726
/// Named imports are emitted in sorted order inside the braces.
#[test]
fn generate_named_import() {
    let names = ["useState", "useEffect"].map(String::from);
    let rendered = generate_import_line(LangId::TypeScript, "react", &names, None, false);
    assert_eq!(rendered, "import { useEffect, useState } from 'react';");
}
1740
/// Specifiers carrying a `type ` prefix or an ` as ` alias are ordered by
/// the imported name, not by the raw specifier text.
#[test]
fn generate_named_import_sorts_by_imported_name() {
    let names = ["useState", "type Foo", "stdin as input", "type Bar"].map(String::from);
    let rendered = generate_import_line(LangId::TypeScript, "x", &names, None, false);
    assert_eq!(
        rendered,
        "import { type Bar, type Foo, stdin as input, useState } from 'x';"
    );
}
1760
/// A default-only request renders without braces.
#[test]
fn generate_default_import() {
    assert_eq!(
        generate_import_line(LangId::TypeScript, "react", &[], Some("React"), false),
        "import React from 'react';"
    );
}
1766
/// The type-only flag produces an `import type` statement.
#[test]
fn generate_type_import() {
    let names = ["FC".to_string()];
    let rendered = generate_import_line(LangId::TypeScript, "react", &names, None, true);
    assert_eq!(rendered, "import type { FC } from 'react';");
}
1773
/// With no names and no default, only the module path is emitted.
#[test]
fn generate_side_effect_import() {
    assert_eq!(
        generate_import_line(LangId::TypeScript, "./styles.css", &[], None, false),
        "import './styles.css';"
    );
}
1779
/// A default binding and named bindings are combined in one statement,
/// default first.
#[test]
fn generate_default_and_named() {
    let names = ["useState".to_string()];
    let rendered = generate_import_line(LangId::TypeScript, "react", &names, Some("React"), false);
    assert_eq!(rendered, "import React, { useState } from 'react';");
}
1791
/// `import type { ... }` is parsed with the `Type` kind while still
/// collecting names and the group.
#[test]
fn parse_ts_type_import() {
    let parsed = parse_ts("import type { FC } from 'react';\n").1;
    assert_eq!(parsed.imports.len(), 1);

    let only = &parsed.imports[0];
    assert_eq!(only.kind, ImportKind::Type);
    assert!(only.names.iter().any(|name| name == "FC"));
    assert_eq!(only.group, ImportGroup::External);
}
1802
/// With an empty source the insertion offset is the start of the file.
#[test]
fn insertion_empty_file() {
    let source = "";
    let parsed = parse_ts(source).1;

    let placement = find_insertion_point(source, &parsed, ImportGroup::External, "react", false);
    assert_eq!(placement.0, 0);
}
1813
/// A new module sorts between existing ones: `bravo` must be inserted at
/// the byte offset where the `charlie` import starts.
#[test]
fn insertion_alphabetical_within_group() {
    let source = concat!(
        "import { a } from 'alpha';\n",
        "import { c } from 'charlie';\n",
    );
    let parsed = parse_ts(source).1;

    let expected_offset = source.find("import { c }").unwrap();
    let placement = find_insertion_point(source, &parsed, ImportGroup::External, "bravo", false);
    assert_eq!(placement.0, expected_offset);
}
1827
1828 fn parse_py(source: &str) -> (Tree, ImportBlock) {
1831 let grammar = grammar_for(LangId::Python);
1832 let mut parser = Parser::new();
1833 parser.set_language(&grammar).unwrap();
1834 let tree = parser.parse(source, None).unwrap();
1835 let block = parse_imports(source, &tree, LangId::Python);
1836 (tree, block)
1837 }
1838
/// Plain `import x` statements are captured in order; `os` is in the
/// stdlib group.
#[test]
fn parse_py_import_statement() {
    let parsed = parse_py("import os\nimport sys\n").1;
    assert_eq!(parsed.imports.len(), 2);

    let (first, second) = (&parsed.imports[0], &parsed.imports[1]);
    assert_eq!(first.module_path, "os");
    assert_eq!(second.module_path, "sys");
    assert_eq!(first.group, ImportGroup::Stdlib);
}
1848
/// `from m import a, b` records the module path and every imported name.
#[test]
fn parse_py_from_import() {
    let source = concat!(
        "from collections import OrderedDict\n",
        "from typing import List, Optional\n",
    );
    let parsed = parse_py(source).1;
    assert_eq!(parsed.imports.len(), 2);

    let collections = &parsed.imports[0];
    assert_eq!(collections.module_path, "collections");
    assert!(collections.names.iter().any(|name| name == "OrderedDict"));
    assert_eq!(collections.group, ImportGroup::Stdlib);

    let typing = &parsed.imports[1];
    assert_eq!(typing.module_path, "typing");
    assert!(typing.names.iter().any(|name| name == "List"));
    assert!(typing.names.iter().any(|name| name == "Optional"));
}
1861
/// Dotted relative imports keep their leading dots in the module path and
/// are grouped as `Internal`.
#[test]
fn parse_py_relative_import() {
    let parsed = parse_py("from . import utils\nfrom ..config import Settings\n").1;
    assert_eq!(parsed.imports.len(), 2);

    let current_pkg = &parsed.imports[0];
    assert_eq!(current_pkg.module_path, ".");
    assert!(current_pkg.names.iter().any(|name| name == "utils"));
    assert_eq!(current_pkg.group, ImportGroup::Internal);

    let parent_pkg = &parsed.imports[1];
    assert_eq!(parent_pkg.module_path, "..config");
    assert_eq!(parent_pkg.group, ImportGroup::Internal);
}
1873
/// Table of Python module paths and the group each must classify into.
#[test]
fn classify_py_groups() {
    let expectations = [
        ("os", ImportGroup::Stdlib),
        ("sys", ImportGroup::Stdlib),
        ("json", ImportGroup::Stdlib),
        ("collections", ImportGroup::Stdlib),
        ("os.path", ImportGroup::Stdlib),
        ("requests", ImportGroup::External),
        ("flask", ImportGroup::External),
        (".", ImportGroup::Internal),
        ("..config", ImportGroup::Internal),
        (".utils", ImportGroup::Internal),
    ];
    for (module, expected) in expectations {
        // Pair the input with the result so a failure names the module.
        assert_eq!((module, classify_group_py(module)), (module, expected));
    }
}
1887
/// A file with stdlib, third-party, and relative imports yields members
/// in all three groups.
#[test]
fn parse_py_three_groups() {
    let source = concat!(
        "import os\n",
        "import sys\n",
        "\n",
        "import requests\n",
        "\n",
        "from . import utils\n",
    );
    let parsed = parse_py(source).1;

    // Count the statements assigned to a given group.
    let count_in = |group: ImportGroup| {
        parsed
            .imports
            .iter()
            .filter(|stmt| stmt.group == group)
            .count()
    };
    assert_eq!(count_in(ImportGroup::Stdlib), 2);
    assert_eq!(count_in(ImportGroup::External), 1);
    assert_eq!(count_in(ImportGroup::Internal), 1);
}
1911
/// With no names, Python output is a plain `import <module>`.
#[test]
fn generate_py_import() {
    assert_eq!(
        generate_import_line(LangId::Python, "os", &[], None, false),
        "import os"
    );
}
1917
/// A single name renders as `from <module> import <name>`.
#[test]
fn generate_py_from_import() {
    let names = ["OrderedDict".to_string()];
    let rendered = generate_import_line(LangId::Python, "collections", &names, None, false);
    assert_eq!(rendered, "from collections import OrderedDict");
}
1929
/// Multiple names are emitted sorted, regardless of the order requested.
#[test]
fn generate_py_from_import_multiple() {
    let names = ["Optional", "List"].map(String::from);
    let rendered = generate_import_line(LangId::Python, "typing", &names, None, false);
    assert_eq!(rendered, "from typing import List, Optional");
}
1941
1942 fn parse_rust(source: &str) -> (Tree, ImportBlock) {
1945 let grammar = grammar_for(LangId::Rust);
1946 let mut parser = Parser::new();
1947 parser.set_language(&grammar).unwrap();
1948 let tree = parser.parse(source, None).unwrap();
1949 let block = parse_imports(source, &tree, LangId::Rust);
1950 (tree, block)
1951 }
1952
/// `use std::...` declarations keep the full path and are grouped as `Stdlib`.
#[test]
fn parse_rs_use_std() {
    let parsed = parse_rust("use std::collections::HashMap;\nuse std::io::Read;\n").1;
    assert_eq!(parsed.imports.len(), 2);

    let (first, second) = (&parsed.imports[0], &parsed.imports[1]);
    assert_eq!(first.module_path, "std::collections::HashMap");
    assert_eq!(first.group, ImportGroup::Stdlib);
    assert_eq!(second.group, ImportGroup::Stdlib);
}
1962
/// A braced `use` list from a third-party crate is `External` and exposes
/// each listed name.
#[test]
fn parse_rs_use_external() {
    let parsed = parse_rust("use serde::{Deserialize, Serialize};\n").1;
    assert_eq!(parsed.imports.len(), 1);

    let only = &parsed.imports[0];
    assert_eq!(only.group, ImportGroup::External);
    for expected in ["Deserialize", "Serialize"] {
        assert!(only.names.iter().any(|name| name == expected));
    }
}
1972
/// Paths rooted at `crate` or `super` are grouped as `Internal`.
#[test]
fn parse_rs_use_crate() {
    let parsed = parse_rust("use crate::config::Settings;\nuse super::parent::Thing;\n").1;
    assert_eq!(parsed.imports.len(), 2);

    let groups: Vec<ImportGroup> = parsed.imports.iter().map(|stmt| stmt.group).collect();
    assert_eq!(groups, [ImportGroup::Internal, ImportGroup::Internal]);
}
1981
/// For a `pub use` declaration the parsed statement carries `"pub"` in its
/// `default_import` field.
#[test]
fn parse_rs_pub_use() {
    let parsed = parse_rust("pub use super::parent::Thing;\n").1;
    assert_eq!(parsed.imports.len(), 1);

    let only = &parsed.imports[0];
    assert_eq!(only.default_import.as_deref(), Some("pub"));
}
1990
/// Table of Rust `use` paths and the group each must classify into.
#[test]
fn classify_rs_groups() {
    let expectations = [
        ("std::collections::HashMap", ImportGroup::Stdlib),
        ("core::mem", ImportGroup::Stdlib),
        ("alloc::vec", ImportGroup::Stdlib),
        ("serde::Deserialize", ImportGroup::External),
        ("tokio::runtime", ImportGroup::External),
        ("crate::config", ImportGroup::Internal),
        ("self::utils", ImportGroup::Internal),
        ("super::parent", ImportGroup::Internal),
    ];
    for (path, expected) in expectations {
        // Pair the input with the result so a failure names the path.
        assert_eq!((path, classify_group_rs(path)), (path, expected));
    }
}
2008
/// Rust output is a `use <path>;` declaration.
#[test]
fn generate_rs_use() {
    assert_eq!(
        generate_import_line(LangId::Rust, "std::fmt::Display", &[], None, false),
        "use std::fmt::Display;"
    );
}
2014
2015 fn parse_go(source: &str) -> (Tree, ImportBlock) {
2018 let grammar = grammar_for(LangId::Go);
2019 let mut parser = Parser::new();
2020 parser.set_language(&grammar).unwrap();
2021 let tree = parser.parse(source, None).unwrap();
2022 let block = parse_imports(source, &tree, LangId::Go);
2023 (tree, block)
2024 }
2025
/// A standalone `import "fmt"` yields one stdlib import with the quotes
/// stripped from the path.
#[test]
fn parse_go_single_import() {
    let parsed = parse_go("package main\n\nimport \"fmt\"\n").1;
    assert_eq!(parsed.imports.len(), 1);

    let only = &parsed.imports[0];
    assert_eq!(only.module_path, "fmt");
    assert_eq!(only.group, ImportGroup::Stdlib);
}
2034
/// Every spec inside a parenthesised `import ( ... )` block becomes its
/// own import, in order, with its own group.
#[test]
fn parse_go_grouped_import() {
    let source = concat!(
        "package main\n",
        "\n",
        "import (\n",
        "\t\"fmt\"\n",
        "\t\"os\"\n",
        "\n",
        "\t\"github.com/pkg/errors\"\n",
        ")\n",
    );
    let parsed = parse_go(source).1;
    assert_eq!(parsed.imports.len(), 3);

    let observed: Vec<(&str, ImportGroup)> = parsed
        .imports
        .iter()
        .map(|stmt| (stmt.module_path.as_str(), stmt.group))
        .collect();
    assert_eq!(
        observed,
        [
            ("fmt", ImportGroup::Stdlib),
            ("os", ImportGroup::Stdlib),
            ("github.com/pkg/errors", ImportGroup::External),
        ]
    );
}
2048
/// A standalone import followed by a grouped block yields all three imports.
#[test]
fn parse_go_mixed_imports() {
    let source = concat!(
        "package main\n",
        "\n",
        "import \"fmt\"\n",
        "\n",
        "import (\n",
        "\t\"os\"\n",
        "\t\"github.com/pkg/errors\"\n",
        ")\n",
    );
    let parsed = parse_go(source).1;
    assert_eq!(parsed.imports.len(), 3);
}
2056
/// Table of Go import paths and the group each must classify into.
#[test]
fn classify_go_groups() {
    let expectations = [
        ("fmt", ImportGroup::Stdlib),
        ("os", ImportGroup::Stdlib),
        ("net/http", ImportGroup::Stdlib),
        ("encoding/json", ImportGroup::Stdlib),
        ("github.com/pkg/errors", ImportGroup::External),
        ("golang.org/x/tools", ImportGroup::External),
    ];
    for (path, expected) in expectations {
        // Pair the input with the result so a failure names the path.
        assert_eq!((path, classify_group_go(path)), (path, expected));
    }
}
2072
/// Outside a group, the line carries the `import` keyword.
#[test]
fn generate_go_standalone() {
    assert_eq!(generate_go_import_line("fmt", None, false), "import \"fmt\"");
}
2078
/// Inside a group, the line is a tab-indented quoted path with no keyword.
#[test]
fn generate_go_grouped_spec() {
    assert_eq!(generate_go_import_line("fmt", None, true), "\t\"fmt\"");
}
2084
/// An alias is placed between the `import` keyword and the quoted path.
#[test]
fn generate_go_with_alias() {
    assert_eq!(
        generate_go_import_line("github.com/pkg/errors", Some("errs"), false),
        "import errs \"github.com/pkg/errors\""
    );
}
2090}