1use std::ops::Range;
10
11use tree_sitter::{Node, Parser, Tree};
12
13use crate::parser::{grammar_for, LangId};
14
/// How an import statement binds names in the importing file.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ImportKind {
    /// An ordinary import that introduces runtime bindings.
    Value,
    /// An import marked with the `type` keyword.
    Type,
    /// An import that binds no names, default, or namespace.
    SideEffect,
}
29
/// Coarse classification of an import's origin.
///
/// The derived ordering follows declaration order
/// (`Stdlib < External < Internal`); `find_insertion_point` relies on it
/// to decide which groups precede or follow a new import.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum ImportGroup {
    /// Modules shipped with the language itself.
    Stdlib,
    /// Third-party packages.
    External,
    /// Project-local modules (relative paths, `crate::`, ...).
    Internal,
}

impl ImportGroup {
    /// Returns the lowercase name of the group.
    pub fn label(&self) -> &'static str {
        match *self {
            Self::Stdlib => "stdlib",
            Self::External => "external",
            Self::Internal => "internal",
        }
    }
}
61
62#[derive(Debug, Clone)]
64pub struct ImportStatement {
65 pub module_path: String,
67 pub names: Vec<String>,
69 pub default_import: Option<String>,
71 pub namespace_import: Option<String>,
73 pub kind: ImportKind,
75 pub group: ImportGroup,
77 pub byte_range: Range<usize>,
79 pub raw_text: String,
81}
82
83#[derive(Debug, Clone)]
85pub struct ImportBlock {
86 pub imports: Vec<ImportStatement>,
88 pub byte_range: Option<Range<usize>>,
91}
92
93impl ImportBlock {
94 pub fn empty() -> Self {
95 ImportBlock {
96 imports: Vec::new(),
97 byte_range: None,
98 }
99 }
100}
101
102fn import_byte_range(imports: &[ImportStatement]) -> Option<Range<usize>> {
103 imports.first().zip(imports.last()).map(|(first, last)| {
104 let start = first.byte_range.start;
105 let end = last.byte_range.end;
106 start..end
107 })
108}
109
/// Returns the name a specifier binds locally.
///
/// A leading `type ` marker is skipped. For `a as b` the result is `b`;
/// otherwise it is the specifier itself. Surrounding whitespace is trimmed.
pub fn specifier_local_name(spec: &str) -> &str {
    let spec = spec.trim();
    let body = match spec.strip_prefix("type ") {
        Some(rest) => rest.trim_start(),
        None => spec,
    };
    body.split_once(" as ")
        .map_or(body, |(_, local)| local.trim())
}
137
/// Returns the name a specifier imports from the module.
///
/// A leading `type ` marker is skipped. For `a as b` the result is `a`;
/// otherwise it is the specifier itself. Surrounding whitespace is trimmed.
pub fn specifier_imported_name(spec: &str) -> &str {
    let spec = spec.trim();
    let body = match spec.strip_prefix("type ") {
        Some(rest) => rest.trim_start(),
        None => spec,
    };
    match body.split_once(" as ") {
        Some((imported, _)) => imported.trim(),
        None => body,
    }
}
157
158pub fn specifier_matches(spec: &str, target: &str) -> bool {
163 specifier_imported_name(spec) == target || specifier_local_name(spec) == target
164}
165
166pub fn parse_imports(source: &str, tree: &Tree, lang: LangId) -> ImportBlock {
172 match lang {
173 LangId::TypeScript | LangId::Tsx | LangId::JavaScript => parse_ts_imports(source, tree),
174 LangId::Python => parse_py_imports(source, tree),
175 LangId::Rust => parse_rs_imports(source, tree),
176 LangId::Go => parse_go_imports(source, tree),
177 LangId::C
178 | LangId::Cpp
179 | LangId::Zig
180 | LangId::CSharp
181 | LangId::Bash
182 | LangId::Solidity
183 | LangId::Vue => ImportBlock::empty(),
184 LangId::Html | LangId::Markdown => ImportBlock::empty(),
185 }
186}
187
188pub fn is_duplicate(
194 block: &ImportBlock,
195 module_path: &str,
196 names: &[String],
197 default_import: Option<&str>,
198 type_only: bool,
199) -> bool {
200 is_duplicate_with_namespace(block, module_path, names, default_import, None, type_only)
201}
202
203pub fn is_duplicate_with_namespace(
205 block: &ImportBlock,
206 module_path: &str,
207 names: &[String],
208 default_import: Option<&str>,
209 namespace_import: Option<&str>,
210 type_only: bool,
211) -> bool {
212 let target_kind = if type_only {
213 ImportKind::Type
214 } else {
215 ImportKind::Value
216 };
217
218 for imp in &block.imports {
219 if imp.module_path != module_path {
220 continue;
221 }
222
223 if names.is_empty()
229 && default_import.is_none()
230 && namespace_import.is_none()
231 && imp.names.is_empty()
232 && imp.default_import.is_none()
233 && imp.namespace_import.is_none()
234 {
235 return true;
236 }
237
238 if names.is_empty()
240 && default_import.is_none()
241 && namespace_import.is_none()
242 && imp.kind == ImportKind::SideEffect
243 {
244 return true;
245 }
246
247 if names.is_empty()
248 && default_import.is_none()
249 && namespace_import.is_some()
250 && imp.names.is_empty()
251 && imp.default_import.is_none()
252 && imp.namespace_import.as_deref() == namespace_import
253 {
254 return true;
255 }
256
257 if imp.kind != target_kind && imp.kind != ImportKind::SideEffect {
259 continue;
260 }
261
262 if let Some(def) = default_import {
264 if imp.default_import.as_deref() == Some(def) && imp.namespace_import.is_none() {
265 return true;
266 }
267 }
268
269 if !names.is_empty()
275 && names
276 .iter()
277 .all(|n| imp.names.iter().any(|stored| specifier_matches(stored, n)))
278 {
279 return true;
280 }
281 }
282
283 false
284}
285
286fn sort_named_specifiers(names: &mut [String]) {
287 names.sort_by(|a, b| {
288 specifier_imported_name(a)
289 .cmp(specifier_imported_name(b))
290 .then_with(|| a.cmp(b))
291 });
292}
293
294pub fn find_insertion_point(
305 source: &str,
306 block: &ImportBlock,
307 group: ImportGroup,
308 module_path: &str,
309 type_only: bool,
310) -> (usize, bool, bool) {
311 if block.imports.is_empty() {
312 return (0, false, source.is_empty().then_some(false).unwrap_or(true));
314 }
315
316 let target_kind = if type_only {
317 ImportKind::Type
318 } else {
319 ImportKind::Value
320 };
321
322 let group_imports: Vec<&ImportStatement> =
324 block.imports.iter().filter(|i| i.group == group).collect();
325
326 if group_imports.is_empty() {
327 let preceding_last = block.imports.iter().filter(|i| i.group < group).last();
330
331 if let Some(last) = preceding_last {
332 let end = last.byte_range.end;
333 let insert_at = skip_newline(source, end);
334 return (insert_at, true, true);
335 }
336
337 let following_first = block.imports.iter().find(|i| i.group > group);
339
340 if let Some(first) = following_first {
341 return (first.byte_range.start, false, true);
342 }
343
344 let first_byte = import_byte_range(&block.imports)
346 .map(|range| range.start)
347 .unwrap_or(0);
348 return (first_byte, false, true);
349 }
350
351 for imp in &group_imports {
353 let cmp = module_path.cmp(&imp.module_path);
354 match cmp {
355 std::cmp::Ordering::Less => {
356 return (imp.byte_range.start, false, false);
358 }
359 std::cmp::Ordering::Equal => {
360 if target_kind == ImportKind::Type && imp.kind == ImportKind::Value {
362 let end = imp.byte_range.end;
364 let insert_at = skip_newline(source, end);
365 return (insert_at, false, false);
366 }
367 return (imp.byte_range.start, false, false);
369 }
370 std::cmp::Ordering::Greater => continue,
371 }
372 }
373
374 let Some(last) = group_imports.last() else {
376 return (
377 import_byte_range(&block.imports)
378 .map(|range| range.end)
379 .unwrap_or(0),
380 false,
381 false,
382 );
383 };
384 let end = last.byte_range.end;
385 let insert_at = skip_newline(source, end);
386 (insert_at, false, false)
387}
388
389pub fn generate_import_line(
391 lang: LangId,
392 module_path: &str,
393 names: &[String],
394 default_import: Option<&str>,
395 type_only: bool,
396) -> String {
397 match lang {
398 LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
399 generate_ts_import_line(module_path, names, default_import, type_only)
400 }
401 LangId::Python => generate_py_import_line(module_path, names, default_import),
402 LangId::Rust => generate_rs_import_line(module_path, names, type_only),
403 LangId::Go => generate_go_import_line(module_path, default_import, false),
404 LangId::C
405 | LangId::Cpp
406 | LangId::Zig
407 | LangId::CSharp
408 | LangId::Bash
409 | LangId::Solidity
410 | LangId::Vue => String::new(),
411 LangId::Html | LangId::Markdown => String::new(),
412 }
413}
414
415pub fn is_supported(lang: LangId) -> bool {
417 matches!(
418 lang,
419 LangId::TypeScript
420 | LangId::Tsx
421 | LangId::JavaScript
422 | LangId::Python
423 | LangId::Rust
424 | LangId::Go
425 )
426}
427
428pub fn classify_group_ts(module_path: &str) -> ImportGroup {
430 if module_path.starts_with('.') {
431 ImportGroup::Internal
432 } else {
433 ImportGroup::External
434 }
435}
436
437pub fn classify_group(lang: LangId, module_path: &str) -> ImportGroup {
439 match lang {
440 LangId::TypeScript | LangId::Tsx | LangId::JavaScript => classify_group_ts(module_path),
441 LangId::Python => classify_group_py(module_path),
442 LangId::Rust => classify_group_rs(module_path),
443 LangId::Go => classify_group_go(module_path),
444 LangId::C
445 | LangId::Cpp
446 | LangId::Zig
447 | LangId::CSharp
448 | LangId::Bash
449 | LangId::Solidity
450 | LangId::Vue => ImportGroup::External,
451 LangId::Html | LangId::Markdown => ImportGroup::External,
452 }
453}
454
455pub fn parse_file_imports(
458 path: &std::path::Path,
459 lang: LangId,
460) -> Result<(String, Tree, ImportBlock), crate::error::AftError> {
461 let source =
462 std::fs::read_to_string(path).map_err(|e| crate::error::AftError::FileNotFound {
463 path: format!("{}: {}", path.display(), e),
464 })?;
465
466 let grammar = grammar_for(lang);
467 let mut parser = Parser::new();
468 parser
469 .set_language(&grammar)
470 .map_err(|e| crate::error::AftError::ParseError {
471 message: format!("grammar init failed for {:?}: {}", lang, e),
472 })?;
473
474 let tree = parser
475 .parse(&source, None)
476 .ok_or_else(|| crate::error::AftError::ParseError {
477 message: format!("tree-sitter parse returned None for {}", path.display()),
478 })?;
479
480 let block = parse_imports(&source, &tree, lang);
481 Ok((source, tree, block))
482}
483
484fn parse_ts_imports(source: &str, tree: &Tree) -> ImportBlock {
492 let root = tree.root_node();
493 let mut imports = Vec::new();
494
495 let mut cursor = root.walk();
496 if !cursor.goto_first_child() {
497 return ImportBlock::empty();
498 }
499
500 loop {
501 let node = cursor.node();
502 if node.kind() == "import_statement" {
503 if let Some(imp) = parse_single_ts_import(source, &node) {
504 imports.push(imp);
505 }
506 }
507 if !cursor.goto_next_sibling() {
508 break;
509 }
510 }
511
512 let byte_range = import_byte_range(&imports);
513
514 ImportBlock {
515 imports,
516 byte_range,
517 }
518}
519
520fn parse_single_ts_import(source: &str, node: &Node) -> Option<ImportStatement> {
522 let raw_text = source[node.byte_range()].to_string();
523 let byte_range = node.byte_range();
524
525 let module_path = extract_module_path(source, node)?;
527
528 let is_type_only = has_type_keyword(node);
530
531 let mut names = Vec::new();
533 let mut default_import = None;
534 let mut namespace_import = None;
535
536 let mut child_cursor = node.walk();
537 if child_cursor.goto_first_child() {
538 loop {
539 let child = child_cursor.node();
540 match child.kind() {
541 "import_clause" => {
542 extract_import_clause(
543 source,
544 &child,
545 &mut names,
546 &mut default_import,
547 &mut namespace_import,
548 );
549 }
550 "identifier" => {
552 let text = &source[child.byte_range()];
553 if text != "import" && text != "from" && text != "type" {
554 default_import = Some(text.to_string());
555 }
556 }
557 _ => {}
558 }
559 if !child_cursor.goto_next_sibling() {
560 break;
561 }
562 }
563 }
564
565 let kind = if names.is_empty() && default_import.is_none() && namespace_import.is_none() {
567 ImportKind::SideEffect
568 } else if is_type_only {
569 ImportKind::Type
570 } else {
571 ImportKind::Value
572 };
573
574 let group = classify_group_ts(&module_path);
575
576 Some(ImportStatement {
577 module_path,
578 names,
579 default_import,
580 namespace_import,
581 kind,
582 group,
583 byte_range,
584 raw_text,
585 })
586}
587
588fn extract_module_path(source: &str, node: &Node) -> Option<String> {
592 let mut cursor = node.walk();
593 if !cursor.goto_first_child() {
594 return None;
595 }
596
597 loop {
598 let child = cursor.node();
599 if child.kind() == "string" {
600 let text = &source[child.byte_range()];
602 let stripped = text
603 .trim_start_matches(|c| c == '\'' || c == '"')
604 .trim_end_matches(|c| c == '\'' || c == '"');
605 return Some(stripped.to_string());
606 }
607 if !cursor.goto_next_sibling() {
608 break;
609 }
610 }
611 None
612}
613
614fn has_type_keyword(node: &Node) -> bool {
619 let mut cursor = node.walk();
620 if !cursor.goto_first_child() {
621 return false;
622 }
623
624 loop {
625 let child = cursor.node();
626 if child.kind() == "type" {
627 return true;
628 }
629 if !cursor.goto_next_sibling() {
630 break;
631 }
632 }
633
634 false
635}
636
637fn extract_import_clause(
639 source: &str,
640 node: &Node,
641 names: &mut Vec<String>,
642 default_import: &mut Option<String>,
643 namespace_import: &mut Option<String>,
644) {
645 let mut cursor = node.walk();
646 if !cursor.goto_first_child() {
647 return;
648 }
649
650 loop {
651 let child = cursor.node();
652 match child.kind() {
653 "identifier" => {
654 let text = &source[child.byte_range()];
656 if text != "type" {
657 *default_import = Some(text.to_string());
658 }
659 }
660 "named_imports" => {
661 extract_named_imports(source, &child, names);
663 }
664 "namespace_import" => {
665 extract_namespace_import(source, &child, namespace_import);
667 }
668 _ => {}
669 }
670 if !cursor.goto_next_sibling() {
671 break;
672 }
673 }
674}
675
676fn extract_named_imports(source: &str, node: &Node, names: &mut Vec<String>) {
695 let mut cursor = node.walk();
696 if !cursor.goto_first_child() {
697 return;
698 }
699
700 loop {
701 let child = cursor.node();
702 if child.kind() == "import_specifier" {
703 let raw = source[child.byte_range()].trim().to_string();
708 if !raw.is_empty() {
709 names.push(raw);
710 } else if let Some(name_node) = child.child_by_field_name("name") {
711 names.push(source[name_node.byte_range()].to_string());
712 }
713 }
714 if !cursor.goto_next_sibling() {
715 break;
716 }
717 }
718}
719
720fn extract_namespace_import(source: &str, node: &Node, namespace_import: &mut Option<String>) {
722 let mut cursor = node.walk();
723 if !cursor.goto_first_child() {
724 return;
725 }
726
727 loop {
728 let child = cursor.node();
729 if child.kind() == "identifier" {
730 *namespace_import = Some(source[child.byte_range()].to_string());
731 return;
732 }
733 if !cursor.goto_next_sibling() {
734 break;
735 }
736 }
737}
738
739fn generate_ts_import_line(
741 module_path: &str,
742 names: &[String],
743 default_import: Option<&str>,
744 type_only: bool,
745) -> String {
746 let type_prefix = if type_only { "type " } else { "" };
747
748 if names.is_empty() && default_import.is_none() {
750 return format!("import '{module_path}';");
751 }
752
753 if names.is_empty() {
755 if let Some(def) = default_import {
756 return format!("import {type_prefix}{def} from '{module_path}';");
757 }
758 }
759
760 if default_import.is_none() {
762 let mut sorted_names = names.to_vec();
763 sort_named_specifiers(&mut sorted_names);
764 let names_str = sorted_names.join(", ");
765 return format!("import {type_prefix}{{ {names_str} }} from '{module_path}';");
766 }
767
768 if let Some(def) = default_import {
770 let mut sorted_names = names.to_vec();
771 sort_named_specifiers(&mut sorted_names);
772 let names_str = sorted_names.join(", ");
773 return format!("import {type_prefix}{def}, {{ {names_str} }} from '{module_path}';");
774 }
775
776 format!("import '{module_path}';")
778}
779
/// Top-level module names that `classify_group_py` treats as Python
/// standard library.
///
/// Modules that have been removed from recent Python releases are kept, so
/// that imports in older code are still grouped as stdlib. Entries are in
/// case-insensitive alphabetical order.
const PYTHON_STDLIB: &[&str] = &[
    "__future__",
    "_thread",
    "abc",
    "aifc",
    "argparse",
    "array",
    "ast",
    "asynchat",
    "asyncio",
    "asyncore",
    "atexit",
    "audioop",
    "base64",
    "bdb",
    "binascii",
    "binhex",
    "bisect",
    "builtins",
    "bz2",
    "calendar",
    "cgi",
    "cgitb",
    "chunk",
    "cmath",
    "cmd",
    "code",
    "codecs",
    "codeop",
    "collections",
    "colorsys",
    "compileall",
    "concurrent",
    "configparser",
    "contextlib",
    "contextvars",
    "copy",
    "copyreg",
    "cProfile",
    "crypt",
    "csv",
    "ctypes",
    "curses",
    "dataclasses",
    "datetime",
    "dbm",
    "decimal",
    "difflib",
    "dis",
    "distutils",
    "doctest",
    "email",
    "encodings",
    "ensurepip",
    "enum",
    "errno",
    "faulthandler",
    "fcntl",
    "filecmp",
    "fileinput",
    "fnmatch",
    "fractions",
    "ftplib",
    "functools",
    "gc",
    "getopt",
    "getpass",
    "gettext",
    "glob",
    "graphlib",
    "grp",
    "gzip",
    "hashlib",
    "heapq",
    "hmac",
    "html",
    "http",
    "idlelib",
    "imaplib",
    "imghdr",
    "imp",
    "importlib",
    "inspect",
    "io",
    "ipaddress",
    "itertools",
    "json",
    "keyword",
    "lib2to3",
    "linecache",
    "locale",
    "logging",
    "lzma",
    "mailbox",
    "mailcap",
    "marshal",
    "math",
    "mimetypes",
    "mmap",
    "modulefinder",
    "msilib",
    "msvcrt",
    "multiprocessing",
    "netrc",
    "nis",
    "nntplib",
    "ntpath",
    "numbers",
    "operator",
    "optparse",
    "os",
    "ossaudiodev",
    "pathlib",
    "pdb",
    "pickle",
    "pickletools",
    "pipes",
    "pkgutil",
    "platform",
    "plistlib",
    "poplib",
    "posix",
    "posixpath",
    "pprint",
    "profile",
    "pstats",
    "pty",
    "pwd",
    "py_compile",
    "pyclbr",
    "pydoc",
    "queue",
    "quopri",
    "random",
    "re",
    "readline",
    "reprlib",
    "resource",
    "rlcompleter",
    "runpy",
    "sched",
    "secrets",
    "select",
    "selectors",
    "shelve",
    "shlex",
    "shutil",
    "signal",
    "site",
    "smtpd",
    "smtplib",
    "sndhdr",
    "socket",
    "socketserver",
    "spwd",
    "sqlite3",
    "ssl",
    "stat",
    "statistics",
    "string",
    "stringprep",
    "struct",
    "subprocess",
    "sunau",
    "symtable",
    "sys",
    "sysconfig",
    "syslog",
    "tabnanny",
    "tarfile",
    "telnetlib",
    "tempfile",
    "termios",
    "textwrap",
    "threading",
    "time",
    "timeit",
    "tkinter",
    "token",
    "tokenize",
    "tomllib",
    "trace",
    "traceback",
    "tracemalloc",
    "tty",
    "turtle",
    "types",
    "typing",
    "unicodedata",
    "unittest",
    "urllib",
    "uu",
    "uuid",
    "venv",
    "warnings",
    "wave",
    "weakref",
    "webbrowser",
    "winreg",
    "winsound",
    "wsgiref",
    "xdrlib",
    "xml",
    "xmlrpc",
    "zipapp",
    "zipfile",
    "zipimport",
    "zlib",
    "zoneinfo",
];
977
978pub fn classify_group_py(module_path: &str) -> ImportGroup {
980 if module_path.starts_with('.') {
982 return ImportGroup::Internal;
983 }
984 let top_module = module_path.split('.').next().unwrap_or(module_path);
986 if PYTHON_STDLIB.contains(&top_module) {
987 ImportGroup::Stdlib
988 } else {
989 ImportGroup::External
990 }
991}
992
993fn parse_py_imports(source: &str, tree: &Tree) -> ImportBlock {
995 let root = tree.root_node();
996 let mut imports = Vec::new();
997
998 let mut cursor = root.walk();
999 if !cursor.goto_first_child() {
1000 return ImportBlock::empty();
1001 }
1002
1003 loop {
1004 let node = cursor.node();
1005 match node.kind() {
1006 "import_statement" => {
1007 if let Some(imp) = parse_py_import_statement(source, &node) {
1008 imports.push(imp);
1009 }
1010 }
1011 "import_from_statement" => {
1012 if let Some(imp) = parse_py_import_from_statement(source, &node) {
1013 imports.push(imp);
1014 }
1015 }
1016 _ => {}
1017 }
1018 if !cursor.goto_next_sibling() {
1019 break;
1020 }
1021 }
1022
1023 let byte_range = import_byte_range(&imports);
1024
1025 ImportBlock {
1026 imports,
1027 byte_range,
1028 }
1029}
1030
1031fn parse_py_import_statement(source: &str, node: &Node) -> Option<ImportStatement> {
1033 let raw_text = source[node.byte_range()].to_string();
1034 let byte_range = node.byte_range();
1035
1036 let mut module_path = String::new();
1038 let mut c = node.walk();
1039 if c.goto_first_child() {
1040 loop {
1041 if c.node().kind() == "dotted_name" {
1042 module_path = source[c.node().byte_range()].to_string();
1043 break;
1044 }
1045 if !c.goto_next_sibling() {
1046 break;
1047 }
1048 }
1049 }
1050 if module_path.is_empty() {
1051 return None;
1052 }
1053
1054 let group = classify_group_py(&module_path);
1055
1056 Some(ImportStatement {
1057 module_path,
1058 names: Vec::new(),
1059 default_import: None,
1060 namespace_import: None,
1061 kind: ImportKind::Value,
1062 group,
1063 byte_range,
1064 raw_text,
1065 })
1066}
1067
1068fn parse_py_import_from_statement(source: &str, node: &Node) -> Option<ImportStatement> {
1070 let raw_text = source[node.byte_range()].to_string();
1071 let byte_range = node.byte_range();
1072
1073 let mut module_path = String::new();
1074 let mut names = Vec::new();
1075
1076 let mut c = node.walk();
1077 if c.goto_first_child() {
1078 loop {
1079 let child = c.node();
1080 match child.kind() {
1081 "dotted_name" => {
1082 if module_path.is_empty()
1087 && !has_seen_import_keyword(source, node, child.start_byte())
1088 {
1089 module_path = source[child.byte_range()].to_string();
1090 } else {
1091 names.push(source[child.byte_range()].to_string());
1093 }
1094 }
1095 "relative_import" => {
1096 module_path = source[child.byte_range()].to_string();
1098 }
1099 _ => {}
1100 }
1101 if !c.goto_next_sibling() {
1102 break;
1103 }
1104 }
1105 }
1106
1107 if module_path.is_empty() {
1109 return None;
1110 }
1111
1112 let group = classify_group_py(&module_path);
1113
1114 Some(ImportStatement {
1115 module_path,
1116 names,
1117 default_import: None,
1118 namespace_import: None,
1119 kind: ImportKind::Value,
1120 group,
1121 byte_range,
1122 raw_text,
1123 })
1124}
1125
1126fn has_seen_import_keyword(_source: &str, parent: &Node, before_byte: usize) -> bool {
1128 let mut c = parent.walk();
1129 if c.goto_first_child() {
1130 loop {
1131 let child = c.node();
1132 if child.kind() == "import" && child.start_byte() < before_byte {
1133 return true;
1134 }
1135 if child.start_byte() >= before_byte {
1136 return false;
1137 }
1138 if !c.goto_next_sibling() {
1139 break;
1140 }
1141 }
1142 }
1143 false
1144}
1145
/// Renders a Python import.
///
/// Produces `import mod` when `names` is empty, otherwise
/// `from mod import a, b` with the names sorted. `_default_import` is
/// ignored.
fn generate_py_import_line(
    module_path: &str,
    names: &[String],
    _default_import: Option<&str>,
) -> String {
    if names.is_empty() {
        return format!("import {module_path}");
    }

    let mut ordered: Vec<&str> = names.iter().map(String::as_str).collect();
    ordered.sort();
    format!("from {module_path} import {}", ordered.join(", "))
}
1163
1164pub fn classify_group_rs(module_path: &str) -> ImportGroup {
1170 let first_seg = module_path.split("::").next().unwrap_or(module_path);
1172 match first_seg {
1173 "std" | "core" | "alloc" => ImportGroup::Stdlib,
1174 "crate" | "self" | "super" => ImportGroup::Internal,
1175 _ => ImportGroup::External,
1176 }
1177}
1178
1179fn parse_rs_imports(source: &str, tree: &Tree) -> ImportBlock {
1181 let root = tree.root_node();
1182 let mut imports = Vec::new();
1183
1184 let mut cursor = root.walk();
1185 if !cursor.goto_first_child() {
1186 return ImportBlock::empty();
1187 }
1188
1189 loop {
1190 let node = cursor.node();
1191 if node.kind() == "use_declaration" {
1192 if let Some(imp) = parse_rs_use_declaration(source, &node) {
1193 imports.push(imp);
1194 }
1195 }
1196 if !cursor.goto_next_sibling() {
1197 break;
1198 }
1199 }
1200
1201 let byte_range = import_byte_range(&imports);
1202
1203 ImportBlock {
1204 imports,
1205 byte_range,
1206 }
1207}
1208
1209fn parse_rs_use_declaration(source: &str, node: &Node) -> Option<ImportStatement> {
1211 let raw_text = source[node.byte_range()].to_string();
1212 let byte_range = node.byte_range();
1213
1214 let mut has_pub = false;
1216 let mut use_path = String::new();
1217 let mut names = Vec::new();
1218
1219 let mut c = node.walk();
1220 if c.goto_first_child() {
1221 loop {
1222 let child = c.node();
1223 match child.kind() {
1224 "visibility_modifier" => {
1225 has_pub = true;
1226 }
1227 "scoped_identifier" | "identifier" | "use_as_clause" => {
1228 use_path = source[child.byte_range()].to_string();
1230 }
1231 "scoped_use_list" => {
1232 use_path = source[child.byte_range()].to_string();
1234 extract_rs_use_list_names(source, &child, &mut names);
1236 }
1237 _ => {}
1238 }
1239 if !c.goto_next_sibling() {
1240 break;
1241 }
1242 }
1243 }
1244
1245 if use_path.is_empty() {
1246 return None;
1247 }
1248
1249 let group = classify_group_rs(&use_path);
1250
1251 Some(ImportStatement {
1252 module_path: use_path,
1253 names,
1254 default_import: if has_pub {
1255 Some("pub".to_string())
1256 } else {
1257 None
1258 },
1259 namespace_import: None,
1260 kind: ImportKind::Value,
1261 group,
1262 byte_range,
1263 raw_text,
1264 })
1265}
1266
1267fn extract_rs_use_list_names(source: &str, node: &Node, names: &mut Vec<String>) {
1269 let mut c = node.walk();
1270 if c.goto_first_child() {
1271 loop {
1272 let child = c.node();
1273 if child.kind() == "use_list" {
1274 let mut lc = child.walk();
1276 if lc.goto_first_child() {
1277 loop {
1278 let lchild = lc.node();
1279 if lchild.kind() == "identifier" || lchild.kind() == "scoped_identifier" {
1280 names.push(source[lchild.byte_range()].to_string());
1281 }
1282 if !lc.goto_next_sibling() {
1283 break;
1284 }
1285 }
1286 }
1287 }
1288 if !c.goto_next_sibling() {
1289 break;
1290 }
1291 }
1292 }
1293}
1294
/// Renders a Rust `use` declaration as `use <module_path>;`.
///
/// The output depends only on `module_path`: `names` and `_type_only` do
/// not affect it.
fn generate_rs_import_line(module_path: &str, names: &[String], _type_only: bool) -> String {
    // Accepted for signature parity with the other generators.
    let _ = names;
    format!("use {module_path};")
}
1306
1307pub fn classify_group_go(module_path: &str) -> ImportGroup {
1313 if module_path.contains('.') {
1316 ImportGroup::External
1317 } else {
1318 ImportGroup::Stdlib
1319 }
1320}
1321
1322fn parse_go_imports(source: &str, tree: &Tree) -> ImportBlock {
1324 let root = tree.root_node();
1325 let mut imports = Vec::new();
1326
1327 let mut cursor = root.walk();
1328 if !cursor.goto_first_child() {
1329 return ImportBlock::empty();
1330 }
1331
1332 loop {
1333 let node = cursor.node();
1334 if node.kind() == "import_declaration" {
1335 parse_go_import_declaration(source, &node, &mut imports);
1336 }
1337 if !cursor.goto_next_sibling() {
1338 break;
1339 }
1340 }
1341
1342 let byte_range = import_byte_range(&imports);
1343
1344 ImportBlock {
1345 imports,
1346 byte_range,
1347 }
1348}
1349
1350fn parse_go_import_declaration(source: &str, node: &Node, imports: &mut Vec<ImportStatement>) {
1352 let mut c = node.walk();
1353 if c.goto_first_child() {
1354 loop {
1355 let child = c.node();
1356 match child.kind() {
1357 "import_spec" => {
1358 if let Some(imp) = parse_go_import_spec(source, &child) {
1359 imports.push(imp);
1360 }
1361 }
1362 "import_spec_list" => {
1363 let mut lc = child.walk();
1365 if lc.goto_first_child() {
1366 loop {
1367 if lc.node().kind() == "import_spec" {
1368 if let Some(imp) = parse_go_import_spec(source, &lc.node()) {
1369 imports.push(imp);
1370 }
1371 }
1372 if !lc.goto_next_sibling() {
1373 break;
1374 }
1375 }
1376 }
1377 }
1378 _ => {}
1379 }
1380 if !c.goto_next_sibling() {
1381 break;
1382 }
1383 }
1384 }
1385}
1386
1387fn parse_go_import_spec(source: &str, node: &Node) -> Option<ImportStatement> {
1389 let raw_text = source[node.byte_range()].to_string();
1390 let byte_range = node.byte_range();
1391
1392 let mut import_path = String::new();
1393 let mut alias = None;
1394
1395 let mut c = node.walk();
1396 if c.goto_first_child() {
1397 loop {
1398 let child = c.node();
1399 match child.kind() {
1400 "interpreted_string_literal" => {
1401 let text = source[child.byte_range()].to_string();
1403 import_path = text.trim_matches('"').to_string();
1404 }
1405 "identifier" | "blank_identifier" | "dot" => {
1406 alias = Some(source[child.byte_range()].to_string());
1408 }
1409 _ => {}
1410 }
1411 if !c.goto_next_sibling() {
1412 break;
1413 }
1414 }
1415 }
1416
1417 if import_path.is_empty() {
1418 return None;
1419 }
1420
1421 let group = classify_group_go(&import_path);
1422
1423 Some(ImportStatement {
1424 module_path: import_path,
1425 names: Vec::new(),
1426 default_import: alias,
1427 namespace_import: None,
1428 kind: ImportKind::Value,
1429 group,
1430 byte_range,
1431 raw_text,
1432 })
1433}
1434
1435pub fn generate_go_import_line_pub(
1437 module_path: &str,
1438 alias: Option<&str>,
1439 in_group: bool,
1440) -> String {
1441 generate_go_import_line(module_path, alias, in_group)
1442}
1443
/// Renders a Go import.
///
/// With `in_group` set the line is a tab-indented spec meant for an
/// `import ( ... )` block; otherwise it is a standalone `import`
/// statement. An alias, when given, precedes the quoted path.
fn generate_go_import_line(module_path: &str, alias: Option<&str>, in_group: bool) -> String {
    let lead = if in_group { "\t" } else { "import " };
    match alias {
        Some(a) => format!("{lead}{a} \"{module_path}\""),
        None => format!("{lead}\"{module_path}\""),
    }
}
1463
1464pub fn go_has_grouped_import(_source: &str, tree: &Tree) -> Option<Range<usize>> {
1467 let root = tree.root_node();
1468 let mut cursor = root.walk();
1469 if !cursor.goto_first_child() {
1470 return None;
1471 }
1472
1473 loop {
1474 let node = cursor.node();
1475 if node.kind() == "import_declaration" {
1476 let mut c = node.walk();
1477 if c.goto_first_child() {
1478 loop {
1479 if c.node().kind() == "import_spec_list" {
1480 return Some(node.byte_range());
1481 }
1482 if !c.goto_next_sibling() {
1483 break;
1484 }
1485 }
1486 }
1487 }
1488 if !cursor.goto_next_sibling() {
1489 break;
1490 }
1491 }
1492 None
1493}
1494
/// Advances `pos` past one line terminator, if one starts there.
///
/// Recognises `\n`, `\r\n`, and a lone `\r`. Returns `pos` unchanged when
/// it is at or past the end of `source` or not on a terminator.
fn skip_newline(source: &str, pos: usize) -> usize {
    let bytes = source.as_bytes();
    match bytes.get(pos) {
        Some(b'\n') => pos + 1,
        Some(b'\r') => {
            if bytes.get(pos + 1) == Some(&b'\n') {
                pos + 2
            } else {
                pos + 1
            }
        }
        _ => pos,
    }
}
1511
1512#[cfg(test)]
1517mod tests {
1518 use super::*;
1519
1520 fn parse_ts(source: &str) -> (Tree, ImportBlock) {
1521 let grammar = grammar_for(LangId::TypeScript);
1522 let mut parser = Parser::new();
1523 parser.set_language(&grammar).unwrap();
1524 let tree = parser.parse(source, None).unwrap();
1525 let block = parse_imports(source, &tree, LangId::TypeScript);
1526 (tree, block)
1527 }
1528
1529 fn parse_js(source: &str) -> (Tree, ImportBlock) {
1530 let grammar = grammar_for(LangId::JavaScript);
1531 let mut parser = Parser::new();
1532 parser.set_language(&grammar).unwrap();
1533 let tree = parser.parse(source, None).unwrap();
1534 let block = parse_imports(source, &tree, LangId::JavaScript);
1535 (tree, block)
1536 }
1537
1538 #[test]
1541 fn parse_ts_named_imports() {
1542 let source = "import { useState, useEffect } from 'react';\n";
1543 let (_, block) = parse_ts(source);
1544 assert_eq!(block.imports.len(), 1);
1545 let imp = &block.imports[0];
1546 assert_eq!(imp.module_path, "react");
1547 assert!(imp.names.contains(&"useState".to_string()));
1548 assert!(imp.names.contains(&"useEffect".to_string()));
1549 assert_eq!(imp.kind, ImportKind::Value);
1550 assert_eq!(imp.group, ImportGroup::External);
1551 }
1552
1553 #[test]
1554 fn parse_ts_default_import() {
1555 let source = "import React from 'react';\n";
1556 let (_, block) = parse_ts(source);
1557 assert_eq!(block.imports.len(), 1);
1558 let imp = &block.imports[0];
1559 assert_eq!(imp.default_import.as_deref(), Some("React"));
1560 assert_eq!(imp.kind, ImportKind::Value);
1561 }
1562
1563 #[test]
1564 fn parse_ts_side_effect_import() {
1565 let source = "import './styles.css';\n";
1566 let (_, block) = parse_ts(source);
1567 assert_eq!(block.imports.len(), 1);
1568 assert_eq!(block.imports[0].kind, ImportKind::SideEffect);
1569 assert_eq!(block.imports[0].module_path, "./styles.css");
1570 }
1571
1572 #[test]
1573 fn parse_ts_relative_import() {
1574 let source = "import { helper } from './utils';\n";
1575 let (_, block) = parse_ts(source);
1576 assert_eq!(block.imports.len(), 1);
1577 assert_eq!(block.imports[0].group, ImportGroup::Internal);
1578 }
1579
1580 #[test]
1581 fn parse_ts_multiple_groups() {
1582 let source = "\
1583import React from 'react';
1584import { useState } from 'react';
1585import { helper } from './utils';
1586import { Config } from '../config';
1587";
1588 let (_, block) = parse_ts(source);
1589 assert_eq!(block.imports.len(), 4);
1590
1591 let external: Vec<_> = block
1592 .imports
1593 .iter()
1594 .filter(|i| i.group == ImportGroup::External)
1595 .collect();
1596 let relative: Vec<_> = block
1597 .imports
1598 .iter()
1599 .filter(|i| i.group == ImportGroup::Internal)
1600 .collect();
1601 assert_eq!(external.len(), 2);
1602 assert_eq!(relative.len(), 2);
1603 }
1604
/// `import * as path` records the alias in `namespace_import` and is a value import.
#[test]
fn parse_ts_namespace_import() {
    let parsed = parse_ts("import * as path from 'path';\n").1;
    assert_eq!(parsed.imports.len(), 1);

    let stmt = &parsed.imports[0];
    assert_eq!(stmt.namespace_import.as_deref(), Some("path"));
    assert_eq!(stmt.kind, ImportKind::Value);
}
1614
/// JavaScript sources: `'fs'` is grouped as `External`, `'./helper'` as `Internal`.
#[test]
fn parse_js_imports() {
    let source = "import { readFile } from 'fs';\nimport { helper } from './helper';\n";
    let (_, parsed) = parse_js(source);
    assert_eq!(parsed.imports.len(), 2);

    let groups: Vec<ImportGroup> = parsed.imports.iter().map(|imp| imp.group).collect();
    assert_eq!(groups, [ImportGroup::External, ImportGroup::Internal]);
}
1623
/// Bare, scoped, and sub-path package specifiers all classify as `External`.
#[test]
fn classify_external() {
    for specifier in ["react", "@scope/pkg", "lodash/map"] {
        assert_eq!(classify_group_ts(specifier), ImportGroup::External);
    }
}
1632
/// Specifiers starting with `./` or `../` classify as `Internal`.
#[test]
fn classify_relative() {
    for specifier in ["./utils", "../config", "./"] {
        assert_eq!(classify_group_ts(specifier), ImportGroup::Internal);
    }
}
1639
/// Requesting a named import that already exists for the module is a duplicate.
#[test]
fn dedup_detects_same_named_import() {
    let parsed = parse_ts("import { useState } from 'react';\n").1;
    let requested = ["useState".to_string()];

    let already_present = is_duplicate(&parsed, "react", &requested, None, false);
    assert!(already_present);
}
1654
/// A different named import from the same module is not a duplicate.
#[test]
fn dedup_misses_different_name() {
    let parsed = parse_ts("import { useState } from 'react';\n").1;
    let requested = ["useEffect".to_string()];

    let already_present = is_duplicate(&parsed, "react", &requested, None, false);
    assert!(!already_present);
}
1667
/// An existing default import matches a request for the same default name.
#[test]
fn dedup_detects_default_import() {
    let parsed = parse_ts("import React from 'react';\n").1;

    let already_present = is_duplicate(&parsed, "react", &[], Some("React"), false);
    assert!(already_present);
}
1674
/// A side-effect import is a duplicate of a request with no names and no default.
#[test]
fn dedup_side_effect() {
    let parsed = parse_ts("import './styles.css';\n").1;

    let already_present = is_duplicate(&parsed, "./styles.css", &[], None, false);
    assert!(already_present);
}
1681
/// Namespace imports and side-effect imports of the same module must not be
/// mistaken for one another, and namespace aliases must match exactly.
#[test]
fn dedup_namespace_import_distinct_from_side_effect_import() {
    // Asks whether `block` already covers a namespace import of 'fs' under `alias`.
    let has_namespace = |block: &ImportBlock, alias: &str| {
        is_duplicate_with_namespace(block, "fs", &[], None, Some(alias), false)
    };

    // A plain side-effect import does not satisfy a namespace request.
    let side_effect_block = parse_ts("import 'fs';\n").1;
    assert!(!has_namespace(&side_effect_block, "fs"));

    // A namespace import does not satisfy a side-effect request, matches its
    // own alias, and does not match a different alias.
    let namespace_block = parse_ts("import * as fs from 'fs';\n").1;
    assert!(!is_duplicate(&namespace_block, "fs", &[], None, false));
    assert!(has_namespace(&namespace_block, "fs"));
    assert!(!has_namespace(&namespace_block, "other"));
}
1715
/// An existing value import of `FC` is not reported as a duplicate when the
/// final flag is `true` (presumably a type-only request; confirm against
/// `is_duplicate`).
#[test]
fn dedup_type_vs_value() {
    let parsed = parse_ts("import { FC } from 'react';\n").1;
    let requested = ["FC".to_string()];

    let already_present = is_duplicate(&parsed, "react", &requested, None, true);
    assert!(!already_present);
}
1729
/// Named specifiers are emitted sorted inside the braces.
#[test]
fn generate_named_import() {
    let names = ["useState", "useEffect"].map(String::from);

    let rendered = generate_import_line(LangId::TypeScript, "react", &names, None, false);
    assert_eq!(rendered, "import { useEffect, useState } from 'react';");
}
1743
/// Inline `type` prefixes and `as` aliases are kept in the output while the
/// specifiers are ordered `Bar`, `Foo`, `stdin`, `useState`.
#[test]
fn generate_named_import_sorts_by_imported_name() {
    let names = ["useState", "type Foo", "stdin as input", "type Bar"].map(String::from);

    let rendered = generate_import_line(LangId::TypeScript, "x", &names, None, false);
    let expected = "import { type Bar, type Foo, stdin as input, useState } from 'x';";
    assert_eq!(rendered, expected);
}
1763
/// A default-only request renders `import <Name> from '<module>';`.
#[test]
fn generate_default_import() {
    let rendered = generate_import_line(LangId::TypeScript, "react", &[], Some("React"), false);
    let expected = "import React from 'react';";
    assert_eq!(rendered, expected);
}
1769
/// Passing `true` as the final flag renders an `import type { ... }` statement.
#[test]
fn generate_type_import() {
    let names = ["FC".to_string()];

    let rendered = generate_import_line(LangId::TypeScript, "react", &names, None, true);
    assert_eq!(rendered, "import type { FC } from 'react';");
}
1776
/// With no names and no default, only the quoted module path is emitted.
#[test]
fn generate_side_effect_import() {
    let rendered = generate_import_line(LangId::TypeScript, "./styles.css", &[], None, false);
    let expected = "import './styles.css';";
    assert_eq!(rendered, expected);
}
1782
/// A default import and named specifiers are combined into one statement,
/// default first.
#[test]
fn generate_default_and_named() {
    let names = ["useState".to_string()];

    let rendered =
        generate_import_line(LangId::TypeScript, "react", &names, Some("React"), false);
    assert_eq!(rendered, "import React, { useState } from 'react';");
}
1794
/// `import type { ... }` is parsed with `ImportKind::Type`, its names
/// captured, and the bare specifier grouped as `External`.
#[test]
fn parse_ts_type_import() {
    let parsed = parse_ts("import type { FC } from 'react';\n").1;
    assert_eq!(parsed.imports.len(), 1);

    let stmt = &parsed.imports[0];
    assert_eq!(stmt.kind, ImportKind::Type);
    assert!(stmt.names.iter().any(|name| name == "FC"));
    assert_eq!(stmt.group, ImportGroup::External);
}
1805
/// For an empty source the first element of the insertion result is offset 0.
#[test]
fn insertion_empty_file() {
    let source = "";
    let parsed = parse_ts(source).1;

    let insertion = find_insertion_point(source, &parsed, ImportGroup::External, "react", false);
    assert_eq!(insertion.0, 0);
}
1816
/// Inserting `bravo` between `alpha` and `charlie` must target the byte
/// offset where the `charlie` import begins.
#[test]
fn insertion_alphabetical_within_group() {
    let source = "\
import { a } from 'alpha';
import { c } from 'charlie';
";
    let parsed = parse_ts(source).1;

    let insertion = find_insertion_point(source, &parsed, ImportGroup::External, "bravo", false);
    // The new line belongs immediately before the `charlie` import.
    let charlie_start = source.find("import { c }").unwrap();
    assert_eq!(insertion.0, charlie_start);
}
1830
1831 fn parse_py(source: &str) -> (Tree, ImportBlock) {
1834 let grammar = grammar_for(LangId::Python);
1835 let mut parser = Parser::new();
1836 parser.set_language(&grammar).unwrap();
1837 let tree = parser.parse(source, None).unwrap();
1838 let block = parse_imports(source, &tree, LangId::Python);
1839 (tree, block)
1840 }
1841
/// Plain `import x` statements yield one entry each, in source order, and
/// `os` is grouped as `Stdlib`.
#[test]
fn parse_py_import_statement() {
    let parsed = parse_py("import os\nimport sys\n").1;

    let paths: Vec<&str> = parsed
        .imports
        .iter()
        .map(|imp| imp.module_path.as_str())
        .collect();
    assert_eq!(paths, ["os", "sys"]);
    assert_eq!(parsed.imports[0].group, ImportGroup::Stdlib);
}
1851
/// `from m import a, b` records the module path and every imported name.
#[test]
fn parse_py_from_import() {
    let source = "from collections import OrderedDict\nfrom typing import List, Optional\n";
    let parsed = parse_py(source).1;
    assert_eq!(parsed.imports.len(), 2);

    let from_collections = &parsed.imports[0];
    assert_eq!(from_collections.module_path, "collections");
    assert!(from_collections.names.iter().any(|n| n == "OrderedDict"));
    assert_eq!(from_collections.group, ImportGroup::Stdlib);

    let from_typing = &parsed.imports[1];
    assert_eq!(from_typing.module_path, "typing");
    assert!(from_typing.names.iter().any(|n| n == "List"));
    assert!(from_typing.names.iter().any(|n| n == "Optional"));
}
1864
/// Relative `from` imports keep their leading dots in `module_path` and are
/// grouped as `Internal`.
#[test]
fn parse_py_relative_import() {
    let parsed = parse_py("from . import utils\nfrom ..config import Settings\n").1;
    assert_eq!(parsed.imports.len(), 2);

    let sibling = &parsed.imports[0];
    assert_eq!(sibling.module_path, ".");
    assert!(sibling.names.iter().any(|n| n == "utils"));
    assert_eq!(sibling.group, ImportGroup::Internal);

    let parent = &parsed.imports[1];
    assert_eq!(parent.module_path, "..config");
    assert_eq!(parent.group, ImportGroup::Internal);
}
1876
/// Table of Python module paths and the group each must classify into.
#[test]
fn classify_py_groups() {
    let expectations = [
        ("os", ImportGroup::Stdlib),
        ("sys", ImportGroup::Stdlib),
        ("json", ImportGroup::Stdlib),
        ("collections", ImportGroup::Stdlib),
        ("os.path", ImportGroup::Stdlib),
        ("requests", ImportGroup::External),
        ("flask", ImportGroup::External),
        (".", ImportGroup::Internal),
        ("..config", ImportGroup::Internal),
        (".utils", ImportGroup::Internal),
    ];

    for (module, expected_group) in expectations {
        assert_eq!(classify_group_py(module), expected_group);
    }
}
1890
/// A file with stdlib, third-party, and relative imports yields 2 / 1 / 1
/// entries in the respective groups.
#[test]
fn parse_py_three_groups() {
    let source = "import os\nimport sys\n\nimport requests\n\nfrom . import utils\n";
    let parsed = parse_py(source).1;

    // Tally the members of one group without building intermediate vectors.
    let tally = |wanted: ImportGroup| {
        parsed
            .imports
            .iter()
            .filter(|imp| imp.group == wanted)
            .count()
    };
    assert_eq!(tally(ImportGroup::Stdlib), 2);
    assert_eq!(tally(ImportGroup::External), 1);
    assert_eq!(tally(ImportGroup::Internal), 1);
}
1914
/// With no names, Python output is a plain `import <module>` with no terminator.
#[test]
fn generate_py_import() {
    let rendered = generate_import_line(LangId::Python, "os", &[], None, false);
    let expected = "import os";
    assert_eq!(rendered, expected);
}
1920
/// A single name renders as `from <module> import <name>`.
#[test]
fn generate_py_from_import() {
    let names = ["OrderedDict".to_string()];

    let rendered = generate_import_line(LangId::Python, "collections", &names, None, false);
    assert_eq!(rendered, "from collections import OrderedDict");
}
1932
/// Multiple names are emitted comma-separated, with `List` ordered before
/// `Optional` despite the input order.
#[test]
fn generate_py_from_import_multiple() {
    let names = ["Optional", "List"].map(String::from);

    let rendered = generate_import_line(LangId::Python, "typing", &names, None, false);
    assert_eq!(rendered, "from typing import List, Optional");
}
1944
1945 fn parse_rust(source: &str) -> (Tree, ImportBlock) {
1948 let grammar = grammar_for(LangId::Rust);
1949 let mut parser = Parser::new();
1950 parser.set_language(&grammar).unwrap();
1951 let tree = parser.parse(source, None).unwrap();
1952 let block = parse_imports(source, &tree, LangId::Rust);
1953 (tree, block)
1954 }
1955
/// `use std::...` declarations keep the full path as `module_path` and are
/// grouped as `Stdlib`.
#[test]
fn parse_rs_use_std() {
    let parsed = parse_rust("use std::collections::HashMap;\nuse std::io::Read;\n").1;
    assert_eq!(parsed.imports.len(), 2);
    assert_eq!(parsed.imports[0].module_path, "std::collections::HashMap");

    for stmt in &parsed.imports {
        assert_eq!(stmt.group, ImportGroup::Stdlib);
    }
}
1965
/// A braced `use` list from a third-party crate is one `External` import
/// carrying every listed name.
#[test]
fn parse_rs_use_external() {
    let parsed = parse_rust("use serde::{Deserialize, Serialize};\n").1;
    assert_eq!(parsed.imports.len(), 1);

    let stmt = &parsed.imports[0];
    assert_eq!(stmt.group, ImportGroup::External);
    assert!(stmt.names.iter().any(|n| n == "Deserialize"));
    assert!(stmt.names.iter().any(|n| n == "Serialize"));
}
1975
/// Paths rooted at `crate::` or `super::` are grouped as `Internal`.
#[test]
fn parse_rs_use_crate() {
    let parsed = parse_rust("use crate::config::Settings;\nuse super::parent::Thing;\n").1;
    assert_eq!(parsed.imports.len(), 2);

    for stmt in &parsed.imports {
        assert_eq!(stmt.group, ImportGroup::Internal);
    }
}
1984
/// A `pub use` declaration parses to a single import whose `default_import`
/// is `"pub"`.
// NOTE(review): presumably the Rust parser reuses `default_import` to carry
// the visibility marker; confirm against the parser.
#[test]
fn parse_rs_pub_use() {
    let parsed = parse_rust("pub use super::parent::Thing;\n").1;
    assert_eq!(parsed.imports.len(), 1);

    let stmt = &parsed.imports[0];
    assert_eq!(stmt.default_import.as_deref(), Some("pub"));
}
1993
/// Table of Rust `use` paths and the group each must classify into.
#[test]
fn classify_rs_groups() {
    let expectations = [
        ("std::collections::HashMap", ImportGroup::Stdlib),
        ("core::mem", ImportGroup::Stdlib),
        ("alloc::vec", ImportGroup::Stdlib),
        ("serde::Deserialize", ImportGroup::External),
        ("tokio::runtime", ImportGroup::External),
        ("crate::config", ImportGroup::Internal),
        ("self::utils", ImportGroup::Internal),
        ("super::parent", ImportGroup::Internal),
    ];

    for (path, expected_group) in expectations {
        assert_eq!(classify_group_rs(path), expected_group);
    }
}
2011
/// Rust output is `use <path>;` for a full path with no extra names.
#[test]
fn generate_rs_use() {
    let rendered = generate_import_line(LangId::Rust, "std::fmt::Display", &[], None, false);
    let expected = "use std::fmt::Display;";
    assert_eq!(rendered, expected);
}
2017
2018 fn parse_go(source: &str) -> (Tree, ImportBlock) {
2021 let grammar = grammar_for(LangId::Go);
2022 let mut parser = Parser::new();
2023 parser.set_language(&grammar).unwrap();
2024 let tree = parser.parse(source, None).unwrap();
2025 let block = parse_imports(source, &tree, LangId::Go);
2026 (tree, block)
2027 }
2028
/// A standalone `import "fmt"` yields one `Stdlib` entry with the unquoted path.
#[test]
fn parse_go_single_import() {
    let parsed = parse_go("package main\n\nimport \"fmt\"\n").1;
    assert_eq!(parsed.imports.len(), 1);

    let stmt = &parsed.imports[0];
    assert_eq!(stmt.module_path, "fmt");
    assert_eq!(stmt.group, ImportGroup::Stdlib);
}
2037
/// Each spec inside a parenthesised `import ( ... )` becomes its own entry,
/// in source order, with its own group.
#[test]
fn parse_go_grouped_import() {
    let source =
        "package main\n\nimport (\n\t\"fmt\"\n\t\"os\"\n\n\t\"github.com/pkg/errors\"\n)\n";
    let parsed = parse_go(source).1;
    assert_eq!(parsed.imports.len(), 3);

    let expectations = [
        ("fmt", ImportGroup::Stdlib),
        ("os", ImportGroup::Stdlib),
        ("github.com/pkg/errors", ImportGroup::External),
    ];
    for (stmt, (path, group)) in parsed.imports.iter().zip(expectations) {
        assert_eq!(stmt.module_path, path);
        assert_eq!(stmt.group, group);
    }
}
2051
/// A standalone import followed by a grouped import yields three entries in total.
#[test]
fn parse_go_mixed_imports() {
    let source = "package main\n\nimport \"fmt\"\n\nimport (\n\t\"os\"\n\t\"github.com/pkg/errors\"\n)\n";
    let parsed = parse_go(source).1;

    let total = parsed.imports.len();
    assert_eq!(total, 3);
}
2059
/// Table of Go import paths and the group each must classify into.
#[test]
fn classify_go_groups() {
    let expectations = [
        ("fmt", ImportGroup::Stdlib),
        ("os", ImportGroup::Stdlib),
        ("net/http", ImportGroup::Stdlib),
        ("encoding/json", ImportGroup::Stdlib),
        ("github.com/pkg/errors", ImportGroup::External),
        ("golang.org/x/tools", ImportGroup::External),
    ];

    for (path, expected_group) in expectations {
        assert_eq!(classify_group_go(path), expected_group);
    }
}
2075
/// With the final flag `false` and no alias, the output is a full
/// `import "<path>"` line.
#[test]
fn generate_go_standalone() {
    let rendered = generate_go_import_line("fmt", None, false);
    let expected = "import \"fmt\"";
    assert_eq!(rendered, expected);
}
2081
/// With the final flag `true`, the output is a tab-indented quoted path
/// without the `import` keyword.
#[test]
fn generate_go_grouped_spec() {
    let rendered = generate_go_import_line("fmt", None, true);
    let expected = "\t\"fmt\"";
    assert_eq!(rendered, expected);
}
2087
/// An alias is placed between `import` and the quoted path.
#[test]
fn generate_go_with_alias() {
    let rendered = generate_go_import_line("github.com/pkg/errors", Some("errs"), false);
    let expected = "import errs \"github.com/pkg/errors\"";
    assert_eq!(rendered, expected);
}
2093}