1use std::ops::Range;
10
11use tree_sitter::{Node, Parser, Tree};
12
13use crate::parser::{grammar_for, LangId};
14
/// How an import statement brings bindings into scope.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ImportKind {
    /// An ordinary import that binds values (the kind used by every non-TS parser here).
    Value,
    /// A TypeScript type-only import (`import type ...`).
    Type,
    /// An import with no bindings at all, e.g. `import './styles.css';`.
    SideEffect,
}
29
/// Coarse origin category of an imported module.
///
/// The declaration order is significant: the derived `Ord` sorts groups as
/// `Stdlib < External < Internal`, and insertion logic compares groups with `<`/`>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ImportGroup {
    /// Modules shipped with the language's standard library.
    Stdlib,
    /// Third-party packages.
    External,
    /// Modules of the current project (relative paths, `crate::`, etc.).
    Internal,
}
50
51impl ImportGroup {
52 pub fn label(&self) -> &'static str {
54 match self {
55 ImportGroup::Stdlib => "stdlib",
56 ImportGroup::External => "external",
57 ImportGroup::Internal => "internal",
58 }
59 }
60}
61
62#[derive(Debug, Clone)]
64pub struct ImportStatement {
65 pub module_path: String,
67 pub names: Vec<String>,
69 pub default_import: Option<String>,
71 pub namespace_import: Option<String>,
73 pub kind: ImportKind,
75 pub group: ImportGroup,
77 pub byte_range: Range<usize>,
79 pub raw_text: String,
81}
82
83#[derive(Debug, Clone)]
85pub struct ImportBlock {
86 pub imports: Vec<ImportStatement>,
88 pub byte_range: Option<Range<usize>>,
91}
92
93impl ImportBlock {
94 pub fn empty() -> Self {
95 ImportBlock {
96 imports: Vec::new(),
97 byte_range: None,
98 }
99 }
100}
101
102fn import_byte_range(imports: &[ImportStatement]) -> Option<Range<usize>> {
103 imports.first().zip(imports.last()).map(|(first, last)| {
104 let start = first.byte_range.start;
105 let end = last.byte_range.end;
106 start..end
107 })
108}
109
/// Returns the name a specifier binds locally.
///
/// A leading `type ` modifier is ignored. For `A as B` the alias `B` is
/// returned; otherwise the (trimmed) specifier itself.
pub fn specifier_local_name(spec: &str) -> &str {
    let spec = spec.trim();
    let body = match spec.strip_prefix("type ") {
        Some(rest) => rest.trim_start(),
        None => spec,
    };
    match body.split_once(" as ") {
        Some((_, alias)) => alias.trim(),
        None => body,
    }
}
137
/// Returns the exported name a specifier refers to.
///
/// A leading `type ` modifier is ignored. For `A as B` the original name `A`
/// is returned; otherwise the (trimmed) specifier itself.
pub fn specifier_imported_name(spec: &str) -> &str {
    let spec = spec.trim();
    let body = spec.strip_prefix("type ").map_or(spec, str::trim_start);
    if let Some((original, _)) = body.split_once(" as ") {
        original.trim()
    } else {
        body
    }
}
157
158pub fn specifier_matches(spec: &str, target: &str) -> bool {
163 specifier_imported_name(spec) == target || specifier_local_name(spec) == target
164}
165
166pub fn parse_imports(source: &str, tree: &Tree, lang: LangId) -> ImportBlock {
172 match lang {
173 LangId::TypeScript | LangId::Tsx | LangId::JavaScript => parse_ts_imports(source, tree),
174 LangId::Python => parse_py_imports(source, tree),
175 LangId::Rust => parse_rs_imports(source, tree),
176 LangId::Go => parse_go_imports(source, tree),
177 LangId::C | LangId::Cpp | LangId::Zig | LangId::CSharp | LangId::Bash => {
178 ImportBlock::empty()
179 }
180 LangId::Html | LangId::Markdown => ImportBlock::empty(),
181 }
182}
183
184pub fn is_duplicate(
190 block: &ImportBlock,
191 module_path: &str,
192 names: &[String],
193 default_import: Option<&str>,
194 type_only: bool,
195) -> bool {
196 is_duplicate_with_namespace(block, module_path, names, default_import, None, type_only)
197}
198
199pub fn is_duplicate_with_namespace(
201 block: &ImportBlock,
202 module_path: &str,
203 names: &[String],
204 default_import: Option<&str>,
205 namespace_import: Option<&str>,
206 type_only: bool,
207) -> bool {
208 let target_kind = if type_only {
209 ImportKind::Type
210 } else {
211 ImportKind::Value
212 };
213
214 for imp in &block.imports {
215 if imp.module_path != module_path {
216 continue;
217 }
218
219 if names.is_empty()
225 && default_import.is_none()
226 && namespace_import.is_none()
227 && imp.names.is_empty()
228 && imp.default_import.is_none()
229 && imp.namespace_import.is_none()
230 {
231 return true;
232 }
233
234 if names.is_empty()
236 && default_import.is_none()
237 && namespace_import.is_none()
238 && imp.kind == ImportKind::SideEffect
239 {
240 return true;
241 }
242
243 if names.is_empty()
244 && default_import.is_none()
245 && namespace_import.is_some()
246 && imp.names.is_empty()
247 && imp.default_import.is_none()
248 && imp.namespace_import.as_deref() == namespace_import
249 {
250 return true;
251 }
252
253 if imp.kind != target_kind && imp.kind != ImportKind::SideEffect {
255 continue;
256 }
257
258 if let Some(def) = default_import {
260 if imp.default_import.as_deref() == Some(def) && imp.namespace_import.is_none() {
261 return true;
262 }
263 }
264
265 if !names.is_empty()
271 && names
272 .iter()
273 .all(|n| imp.names.iter().any(|stored| specifier_matches(stored, n)))
274 {
275 return true;
276 }
277 }
278
279 false
280}
281
282fn sort_named_specifiers(names: &mut [String]) {
283 names.sort_by(|a, b| {
284 specifier_imported_name(a)
285 .cmp(specifier_imported_name(b))
286 .then_with(|| a.cmp(b))
287 });
288}
289
290pub fn find_insertion_point(
301 source: &str,
302 block: &ImportBlock,
303 group: ImportGroup,
304 module_path: &str,
305 type_only: bool,
306) -> (usize, bool, bool) {
307 if block.imports.is_empty() {
308 return (0, false, source.is_empty().then_some(false).unwrap_or(true));
310 }
311
312 let target_kind = if type_only {
313 ImportKind::Type
314 } else {
315 ImportKind::Value
316 };
317
318 let group_imports: Vec<&ImportStatement> =
320 block.imports.iter().filter(|i| i.group == group).collect();
321
322 if group_imports.is_empty() {
323 let preceding_last = block.imports.iter().filter(|i| i.group < group).last();
326
327 if let Some(last) = preceding_last {
328 let end = last.byte_range.end;
329 let insert_at = skip_newline(source, end);
330 return (insert_at, true, true);
331 }
332
333 let following_first = block.imports.iter().find(|i| i.group > group);
335
336 if let Some(first) = following_first {
337 return (first.byte_range.start, false, true);
338 }
339
340 let first_byte = import_byte_range(&block.imports)
342 .map(|range| range.start)
343 .unwrap_or(0);
344 return (first_byte, false, true);
345 }
346
347 for imp in &group_imports {
349 let cmp = module_path.cmp(&imp.module_path);
350 match cmp {
351 std::cmp::Ordering::Less => {
352 return (imp.byte_range.start, false, false);
354 }
355 std::cmp::Ordering::Equal => {
356 if target_kind == ImportKind::Type && imp.kind == ImportKind::Value {
358 let end = imp.byte_range.end;
360 let insert_at = skip_newline(source, end);
361 return (insert_at, false, false);
362 }
363 return (imp.byte_range.start, false, false);
365 }
366 std::cmp::Ordering::Greater => continue,
367 }
368 }
369
370 let Some(last) = group_imports.last() else {
372 return (
373 import_byte_range(&block.imports)
374 .map(|range| range.end)
375 .unwrap_or(0),
376 false,
377 false,
378 );
379 };
380 let end = last.byte_range.end;
381 let insert_at = skip_newline(source, end);
382 (insert_at, false, false)
383}
384
385pub fn generate_import_line(
387 lang: LangId,
388 module_path: &str,
389 names: &[String],
390 default_import: Option<&str>,
391 type_only: bool,
392) -> String {
393 match lang {
394 LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
395 generate_ts_import_line(module_path, names, default_import, type_only)
396 }
397 LangId::Python => generate_py_import_line(module_path, names, default_import),
398 LangId::Rust => generate_rs_import_line(module_path, names, type_only),
399 LangId::Go => generate_go_import_line(module_path, default_import, false),
400 LangId::C | LangId::Cpp | LangId::Zig | LangId::CSharp | LangId::Bash => String::new(),
401 LangId::Html | LangId::Markdown => String::new(),
402 }
403}
404
405pub fn is_supported(lang: LangId) -> bool {
407 matches!(
408 lang,
409 LangId::TypeScript
410 | LangId::Tsx
411 | LangId::JavaScript
412 | LangId::Python
413 | LangId::Rust
414 | LangId::Go
415 )
416}
417
418pub fn classify_group_ts(module_path: &str) -> ImportGroup {
420 if module_path.starts_with('.') {
421 ImportGroup::Internal
422 } else {
423 ImportGroup::External
424 }
425}
426
427pub fn classify_group(lang: LangId, module_path: &str) -> ImportGroup {
429 match lang {
430 LangId::TypeScript | LangId::Tsx | LangId::JavaScript => classify_group_ts(module_path),
431 LangId::Python => classify_group_py(module_path),
432 LangId::Rust => classify_group_rs(module_path),
433 LangId::Go => classify_group_go(module_path),
434 LangId::C | LangId::Cpp | LangId::Zig | LangId::CSharp | LangId::Bash => {
435 ImportGroup::External
436 }
437 LangId::Html | LangId::Markdown => ImportGroup::External,
438 }
439}
440
441pub fn parse_file_imports(
444 path: &std::path::Path,
445 lang: LangId,
446) -> Result<(String, Tree, ImportBlock), crate::error::AftError> {
447 let source =
448 std::fs::read_to_string(path).map_err(|e| crate::error::AftError::FileNotFound {
449 path: format!("{}: {}", path.display(), e),
450 })?;
451
452 let grammar = grammar_for(lang);
453 let mut parser = Parser::new();
454 parser
455 .set_language(&grammar)
456 .map_err(|e| crate::error::AftError::ParseError {
457 message: format!("grammar init failed for {:?}: {}", lang, e),
458 })?;
459
460 let tree = parser
461 .parse(&source, None)
462 .ok_or_else(|| crate::error::AftError::ParseError {
463 message: format!("tree-sitter parse returned None for {}", path.display()),
464 })?;
465
466 let block = parse_imports(&source, &tree, lang);
467 Ok((source, tree, block))
468}
469
470fn parse_ts_imports(source: &str, tree: &Tree) -> ImportBlock {
478 let root = tree.root_node();
479 let mut imports = Vec::new();
480
481 let mut cursor = root.walk();
482 if !cursor.goto_first_child() {
483 return ImportBlock::empty();
484 }
485
486 loop {
487 let node = cursor.node();
488 if node.kind() == "import_statement" {
489 if let Some(imp) = parse_single_ts_import(source, &node) {
490 imports.push(imp);
491 }
492 }
493 if !cursor.goto_next_sibling() {
494 break;
495 }
496 }
497
498 let byte_range = import_byte_range(&imports);
499
500 ImportBlock {
501 imports,
502 byte_range,
503 }
504}
505
506fn parse_single_ts_import(source: &str, node: &Node) -> Option<ImportStatement> {
508 let raw_text = source[node.byte_range()].to_string();
509 let byte_range = node.byte_range();
510
511 let module_path = extract_module_path(source, node)?;
513
514 let is_type_only = has_type_keyword(node);
516
517 let mut names = Vec::new();
519 let mut default_import = None;
520 let mut namespace_import = None;
521
522 let mut child_cursor = node.walk();
523 if child_cursor.goto_first_child() {
524 loop {
525 let child = child_cursor.node();
526 match child.kind() {
527 "import_clause" => {
528 extract_import_clause(
529 source,
530 &child,
531 &mut names,
532 &mut default_import,
533 &mut namespace_import,
534 );
535 }
536 "identifier" => {
538 let text = &source[child.byte_range()];
539 if text != "import" && text != "from" && text != "type" {
540 default_import = Some(text.to_string());
541 }
542 }
543 _ => {}
544 }
545 if !child_cursor.goto_next_sibling() {
546 break;
547 }
548 }
549 }
550
551 let kind = if names.is_empty() && default_import.is_none() && namespace_import.is_none() {
553 ImportKind::SideEffect
554 } else if is_type_only {
555 ImportKind::Type
556 } else {
557 ImportKind::Value
558 };
559
560 let group = classify_group_ts(&module_path);
561
562 Some(ImportStatement {
563 module_path,
564 names,
565 default_import,
566 namespace_import,
567 kind,
568 group,
569 byte_range,
570 raw_text,
571 })
572}
573
574fn extract_module_path(source: &str, node: &Node) -> Option<String> {
578 let mut cursor = node.walk();
579 if !cursor.goto_first_child() {
580 return None;
581 }
582
583 loop {
584 let child = cursor.node();
585 if child.kind() == "string" {
586 let text = &source[child.byte_range()];
588 let stripped = text
589 .trim_start_matches(|c| c == '\'' || c == '"')
590 .trim_end_matches(|c| c == '\'' || c == '"');
591 return Some(stripped.to_string());
592 }
593 if !cursor.goto_next_sibling() {
594 break;
595 }
596 }
597 None
598}
599
600fn has_type_keyword(node: &Node) -> bool {
605 let mut cursor = node.walk();
606 if !cursor.goto_first_child() {
607 return false;
608 }
609
610 loop {
611 let child = cursor.node();
612 if child.kind() == "type" {
613 return true;
614 }
615 if !cursor.goto_next_sibling() {
616 break;
617 }
618 }
619
620 false
621}
622
623fn extract_import_clause(
625 source: &str,
626 node: &Node,
627 names: &mut Vec<String>,
628 default_import: &mut Option<String>,
629 namespace_import: &mut Option<String>,
630) {
631 let mut cursor = node.walk();
632 if !cursor.goto_first_child() {
633 return;
634 }
635
636 loop {
637 let child = cursor.node();
638 match child.kind() {
639 "identifier" => {
640 let text = &source[child.byte_range()];
642 if text != "type" {
643 *default_import = Some(text.to_string());
644 }
645 }
646 "named_imports" => {
647 extract_named_imports(source, &child, names);
649 }
650 "namespace_import" => {
651 extract_namespace_import(source, &child, namespace_import);
653 }
654 _ => {}
655 }
656 if !cursor.goto_next_sibling() {
657 break;
658 }
659 }
660}
661
662fn extract_named_imports(source: &str, node: &Node, names: &mut Vec<String>) {
681 let mut cursor = node.walk();
682 if !cursor.goto_first_child() {
683 return;
684 }
685
686 loop {
687 let child = cursor.node();
688 if child.kind() == "import_specifier" {
689 let raw = source[child.byte_range()].trim().to_string();
694 if !raw.is_empty() {
695 names.push(raw);
696 } else if let Some(name_node) = child.child_by_field_name("name") {
697 names.push(source[name_node.byte_range()].to_string());
698 }
699 }
700 if !cursor.goto_next_sibling() {
701 break;
702 }
703 }
704}
705
706fn extract_namespace_import(source: &str, node: &Node, namespace_import: &mut Option<String>) {
708 let mut cursor = node.walk();
709 if !cursor.goto_first_child() {
710 return;
711 }
712
713 loop {
714 let child = cursor.node();
715 if child.kind() == "identifier" {
716 *namespace_import = Some(source[child.byte_range()].to_string());
717 return;
718 }
719 if !cursor.goto_next_sibling() {
720 break;
721 }
722 }
723}
724
725fn generate_ts_import_line(
727 module_path: &str,
728 names: &[String],
729 default_import: Option<&str>,
730 type_only: bool,
731) -> String {
732 let type_prefix = if type_only { "type " } else { "" };
733
734 if names.is_empty() && default_import.is_none() {
736 return format!("import '{module_path}';");
737 }
738
739 if names.is_empty() {
741 if let Some(def) = default_import {
742 return format!("import {type_prefix}{def} from '{module_path}';");
743 }
744 }
745
746 if default_import.is_none() {
748 let mut sorted_names = names.to_vec();
749 sort_named_specifiers(&mut sorted_names);
750 let names_str = sorted_names.join(", ");
751 return format!("import {type_prefix}{{ {names_str} }} from '{module_path}';");
752 }
753
754 if let Some(def) = default_import {
756 let mut sorted_names = names.to_vec();
757 sort_named_specifiers(&mut sorted_names);
758 let names_str = sorted_names.join(", ");
759 return format!("import {type_prefix}{def}, {{ {names_str} }} from '{module_path}';");
760 }
761
762 format!("import '{module_path}';")
764}
765
/// Top-level module names of the Python standard library, used to classify
/// imports as `ImportGroup::Stdlib`.
///
/// Includes platform-specific modules (e.g. `winreg`, `msvcrt`, `posix`) and
/// modules removed in recent Python versions (e.g. `imp`, `telnetlib`), since
/// code targeting any supported version may still import them. Compare with
/// `sys.stdlib_module_names` when updating.
const PYTHON_STDLIB: &[&str] = &[
    "__future__",
    "_thread",
    "abc",
    "aifc",
    "argparse",
    "array",
    "ast",
    "asynchat",
    "asyncio",
    "asyncore",
    "atexit",
    "audioop",
    "base64",
    "bdb",
    "binascii",
    "bisect",
    "builtins",
    "bz2",
    "calendar",
    "cgi",
    "cgitb",
    "chunk",
    "cmath",
    "cmd",
    "code",
    "codecs",
    "codeop",
    "collections",
    "colorsys",
    "compileall",
    "concurrent",
    "configparser",
    "contextlib",
    "contextvars",
    "copy",
    "copyreg",
    "cProfile",
    "crypt",
    "csv",
    "ctypes",
    "curses",
    "dataclasses",
    "datetime",
    "dbm",
    "decimal",
    "difflib",
    "dis",
    "distutils",
    "doctest",
    "email",
    "encodings",
    "ensurepip",
    "enum",
    "errno",
    "faulthandler",
    "fcntl",
    "filecmp",
    "fileinput",
    "fnmatch",
    "fractions",
    "ftplib",
    "functools",
    "gc",
    "getopt",
    "getpass",
    "gettext",
    "glob",
    "graphlib",
    "grp",
    "gzip",
    "hashlib",
    "heapq",
    "hmac",
    "html",
    "http",
    "idlelib",
    "imaplib",
    "imghdr",
    "imp",
    "importlib",
    "inspect",
    "io",
    "ipaddress",
    "itertools",
    "json",
    "keyword",
    "lib2to3",
    "linecache",
    "locale",
    "logging",
    "lzma",
    "mailbox",
    "mailcap",
    "marshal",
    "math",
    "mimetypes",
    "mmap",
    "modulefinder",
    "msilib",
    "msvcrt",
    "multiprocessing",
    "netrc",
    "nis",
    "nntplib",
    "ntpath",
    "numbers",
    "opcode",
    "operator",
    "optparse",
    "os",
    "ossaudiodev",
    "pathlib",
    "pdb",
    "pickle",
    "pickletools",
    "pipes",
    "pkgutil",
    "platform",
    "plistlib",
    "poplib",
    "posix",
    "posixpath",
    "pprint",
    "profile",
    "pstats",
    "pty",
    "pwd",
    "py_compile",
    "pyclbr",
    "pydoc",
    "queue",
    "quopri",
    "random",
    "re",
    "readline",
    "reprlib",
    "resource",
    "rlcompleter",
    "runpy",
    "sched",
    "secrets",
    "select",
    "selectors",
    "shelve",
    "shlex",
    "shutil",
    "signal",
    "site",
    "smtpd",
    "smtplib",
    "sndhdr",
    "socket",
    "socketserver",
    "spwd",
    "sqlite3",
    "ssl",
    "stat",
    "statistics",
    "string",
    "stringprep",
    "struct",
    "subprocess",
    "sunau",
    "symtable",
    "sys",
    "sysconfig",
    "syslog",
    "tabnanny",
    "tarfile",
    "telnetlib",
    "tempfile",
    "termios",
    "textwrap",
    "threading",
    "time",
    "timeit",
    "tkinter",
    "token",
    "tokenize",
    "tomllib",
    "trace",
    "traceback",
    "tracemalloc",
    "tty",
    "turtle",
    "types",
    "typing",
    "unicodedata",
    "unittest",
    "urllib",
    "uu",
    "uuid",
    "venv",
    "warnings",
    "wave",
    "weakref",
    "webbrowser",
    "winreg",
    "winsound",
    "wsgiref",
    "xdrlib",
    "xml",
    "xmlrpc",
    "zipapp",
    "zipfile",
    "zipimport",
    "zlib",
    "zoneinfo",
];
963
964pub fn classify_group_py(module_path: &str) -> ImportGroup {
966 if module_path.starts_with('.') {
968 return ImportGroup::Internal;
969 }
970 let top_module = module_path.split('.').next().unwrap_or(module_path);
972 if PYTHON_STDLIB.contains(&top_module) {
973 ImportGroup::Stdlib
974 } else {
975 ImportGroup::External
976 }
977}
978
979fn parse_py_imports(source: &str, tree: &Tree) -> ImportBlock {
981 let root = tree.root_node();
982 let mut imports = Vec::new();
983
984 let mut cursor = root.walk();
985 if !cursor.goto_first_child() {
986 return ImportBlock::empty();
987 }
988
989 loop {
990 let node = cursor.node();
991 match node.kind() {
992 "import_statement" => {
993 if let Some(imp) = parse_py_import_statement(source, &node) {
994 imports.push(imp);
995 }
996 }
997 "import_from_statement" => {
998 if let Some(imp) = parse_py_import_from_statement(source, &node) {
999 imports.push(imp);
1000 }
1001 }
1002 _ => {}
1003 }
1004 if !cursor.goto_next_sibling() {
1005 break;
1006 }
1007 }
1008
1009 let byte_range = import_byte_range(&imports);
1010
1011 ImportBlock {
1012 imports,
1013 byte_range,
1014 }
1015}
1016
1017fn parse_py_import_statement(source: &str, node: &Node) -> Option<ImportStatement> {
1019 let raw_text = source[node.byte_range()].to_string();
1020 let byte_range = node.byte_range();
1021
1022 let mut module_path = String::new();
1024 let mut c = node.walk();
1025 if c.goto_first_child() {
1026 loop {
1027 if c.node().kind() == "dotted_name" {
1028 module_path = source[c.node().byte_range()].to_string();
1029 break;
1030 }
1031 if !c.goto_next_sibling() {
1032 break;
1033 }
1034 }
1035 }
1036 if module_path.is_empty() {
1037 return None;
1038 }
1039
1040 let group = classify_group_py(&module_path);
1041
1042 Some(ImportStatement {
1043 module_path,
1044 names: Vec::new(),
1045 default_import: None,
1046 namespace_import: None,
1047 kind: ImportKind::Value,
1048 group,
1049 byte_range,
1050 raw_text,
1051 })
1052}
1053
1054fn parse_py_import_from_statement(source: &str, node: &Node) -> Option<ImportStatement> {
1056 let raw_text = source[node.byte_range()].to_string();
1057 let byte_range = node.byte_range();
1058
1059 let mut module_path = String::new();
1060 let mut names = Vec::new();
1061
1062 let mut c = node.walk();
1063 if c.goto_first_child() {
1064 loop {
1065 let child = c.node();
1066 match child.kind() {
1067 "dotted_name" => {
1068 if module_path.is_empty()
1073 && !has_seen_import_keyword(source, node, child.start_byte())
1074 {
1075 module_path = source[child.byte_range()].to_string();
1076 } else {
1077 names.push(source[child.byte_range()].to_string());
1079 }
1080 }
1081 "relative_import" => {
1082 module_path = source[child.byte_range()].to_string();
1084 }
1085 _ => {}
1086 }
1087 if !c.goto_next_sibling() {
1088 break;
1089 }
1090 }
1091 }
1092
1093 if module_path.is_empty() {
1095 return None;
1096 }
1097
1098 let group = classify_group_py(&module_path);
1099
1100 Some(ImportStatement {
1101 module_path,
1102 names,
1103 default_import: None,
1104 namespace_import: None,
1105 kind: ImportKind::Value,
1106 group,
1107 byte_range,
1108 raw_text,
1109 })
1110}
1111
1112fn has_seen_import_keyword(_source: &str, parent: &Node, before_byte: usize) -> bool {
1114 let mut c = parent.walk();
1115 if c.goto_first_child() {
1116 loop {
1117 let child = c.node();
1118 if child.kind() == "import" && child.start_byte() < before_byte {
1119 return true;
1120 }
1121 if child.start_byte() >= before_byte {
1122 return false;
1123 }
1124 if !c.goto_next_sibling() {
1125 break;
1126 }
1127 }
1128 }
1129 false
1130}
1131
/// Renders a Python import: `import m` when `names` is empty, otherwise
/// `from m import a, b` with names sorted. The default import is ignored.
fn generate_py_import_line(
    module_path: &str,
    names: &[String],
    _default_import: Option<&str>,
) -> String {
    if names.is_empty() {
        return format!("import {module_path}");
    }
    let mut ordered: Vec<&str> = names.iter().map(String::as_str).collect();
    ordered.sort_unstable();
    let joined = ordered.join(", ");
    format!("from {module_path} import {joined}")
}
1149
1150pub fn classify_group_rs(module_path: &str) -> ImportGroup {
1156 let first_seg = module_path.split("::").next().unwrap_or(module_path);
1158 match first_seg {
1159 "std" | "core" | "alloc" => ImportGroup::Stdlib,
1160 "crate" | "self" | "super" => ImportGroup::Internal,
1161 _ => ImportGroup::External,
1162 }
1163}
1164
1165fn parse_rs_imports(source: &str, tree: &Tree) -> ImportBlock {
1167 let root = tree.root_node();
1168 let mut imports = Vec::new();
1169
1170 let mut cursor = root.walk();
1171 if !cursor.goto_first_child() {
1172 return ImportBlock::empty();
1173 }
1174
1175 loop {
1176 let node = cursor.node();
1177 if node.kind() == "use_declaration" {
1178 if let Some(imp) = parse_rs_use_declaration(source, &node) {
1179 imports.push(imp);
1180 }
1181 }
1182 if !cursor.goto_next_sibling() {
1183 break;
1184 }
1185 }
1186
1187 let byte_range = import_byte_range(&imports);
1188
1189 ImportBlock {
1190 imports,
1191 byte_range,
1192 }
1193}
1194
1195fn parse_rs_use_declaration(source: &str, node: &Node) -> Option<ImportStatement> {
1197 let raw_text = source[node.byte_range()].to_string();
1198 let byte_range = node.byte_range();
1199
1200 let mut has_pub = false;
1202 let mut use_path = String::new();
1203 let mut names = Vec::new();
1204
1205 let mut c = node.walk();
1206 if c.goto_first_child() {
1207 loop {
1208 let child = c.node();
1209 match child.kind() {
1210 "visibility_modifier" => {
1211 has_pub = true;
1212 }
1213 "scoped_identifier" | "identifier" | "use_as_clause" => {
1214 use_path = source[child.byte_range()].to_string();
1216 }
1217 "scoped_use_list" => {
1218 use_path = source[child.byte_range()].to_string();
1220 extract_rs_use_list_names(source, &child, &mut names);
1222 }
1223 _ => {}
1224 }
1225 if !c.goto_next_sibling() {
1226 break;
1227 }
1228 }
1229 }
1230
1231 if use_path.is_empty() {
1232 return None;
1233 }
1234
1235 let group = classify_group_rs(&use_path);
1236
1237 Some(ImportStatement {
1238 module_path: use_path,
1239 names,
1240 default_import: if has_pub {
1241 Some("pub".to_string())
1242 } else {
1243 None
1244 },
1245 namespace_import: None,
1246 kind: ImportKind::Value,
1247 group,
1248 byte_range,
1249 raw_text,
1250 })
1251}
1252
1253fn extract_rs_use_list_names(source: &str, node: &Node, names: &mut Vec<String>) {
1255 let mut c = node.walk();
1256 if c.goto_first_child() {
1257 loop {
1258 let child = c.node();
1259 if child.kind() == "use_list" {
1260 let mut lc = child.walk();
1262 if lc.goto_first_child() {
1263 loop {
1264 let lchild = lc.node();
1265 if lchild.kind() == "identifier" || lchild.kind() == "scoped_identifier" {
1266 names.push(source[lchild.byte_range()].to_string());
1267 }
1268 if !lc.goto_next_sibling() {
1269 break;
1270 }
1271 }
1272 }
1273 }
1274 if !c.goto_next_sibling() {
1275 break;
1276 }
1277 }
1278 }
1279}
1280
/// Renders a Rust `use` declaration: always `use {module_path};`.
///
/// NOTE(review): `names` and `_type_only` do not affect the output. Callers
/// presumably pass the complete path, including any `{..}` list, in
/// `module_path` — confirm before relying on `names`.
fn generate_rs_import_line(module_path: &str, names: &[String], _type_only: bool) -> String {
    let _ = names;
    format!("use {module_path};")
}
1292
1293pub fn classify_group_go(module_path: &str) -> ImportGroup {
1299 if module_path.contains('.') {
1302 ImportGroup::External
1303 } else {
1304 ImportGroup::Stdlib
1305 }
1306}
1307
1308fn parse_go_imports(source: &str, tree: &Tree) -> ImportBlock {
1310 let root = tree.root_node();
1311 let mut imports = Vec::new();
1312
1313 let mut cursor = root.walk();
1314 if !cursor.goto_first_child() {
1315 return ImportBlock::empty();
1316 }
1317
1318 loop {
1319 let node = cursor.node();
1320 if node.kind() == "import_declaration" {
1321 parse_go_import_declaration(source, &node, &mut imports);
1322 }
1323 if !cursor.goto_next_sibling() {
1324 break;
1325 }
1326 }
1327
1328 let byte_range = import_byte_range(&imports);
1329
1330 ImportBlock {
1331 imports,
1332 byte_range,
1333 }
1334}
1335
1336fn parse_go_import_declaration(source: &str, node: &Node, imports: &mut Vec<ImportStatement>) {
1338 let mut c = node.walk();
1339 if c.goto_first_child() {
1340 loop {
1341 let child = c.node();
1342 match child.kind() {
1343 "import_spec" => {
1344 if let Some(imp) = parse_go_import_spec(source, &child) {
1345 imports.push(imp);
1346 }
1347 }
1348 "import_spec_list" => {
1349 let mut lc = child.walk();
1351 if lc.goto_first_child() {
1352 loop {
1353 if lc.node().kind() == "import_spec" {
1354 if let Some(imp) = parse_go_import_spec(source, &lc.node()) {
1355 imports.push(imp);
1356 }
1357 }
1358 if !lc.goto_next_sibling() {
1359 break;
1360 }
1361 }
1362 }
1363 }
1364 _ => {}
1365 }
1366 if !c.goto_next_sibling() {
1367 break;
1368 }
1369 }
1370 }
1371}
1372
1373fn parse_go_import_spec(source: &str, node: &Node) -> Option<ImportStatement> {
1375 let raw_text = source[node.byte_range()].to_string();
1376 let byte_range = node.byte_range();
1377
1378 let mut import_path = String::new();
1379 let mut alias = None;
1380
1381 let mut c = node.walk();
1382 if c.goto_first_child() {
1383 loop {
1384 let child = c.node();
1385 match child.kind() {
1386 "interpreted_string_literal" => {
1387 let text = source[child.byte_range()].to_string();
1389 import_path = text.trim_matches('"').to_string();
1390 }
1391 "identifier" | "blank_identifier" | "dot" => {
1392 alias = Some(source[child.byte_range()].to_string());
1394 }
1395 _ => {}
1396 }
1397 if !c.goto_next_sibling() {
1398 break;
1399 }
1400 }
1401 }
1402
1403 if import_path.is_empty() {
1404 return None;
1405 }
1406
1407 let group = classify_group_go(&import_path);
1408
1409 Some(ImportStatement {
1410 module_path: import_path,
1411 names: Vec::new(),
1412 default_import: alias,
1413 namespace_import: None,
1414 kind: ImportKind::Value,
1415 group,
1416 byte_range,
1417 raw_text,
1418 })
1419}
1420
1421pub fn generate_go_import_line_pub(
1423 module_path: &str,
1424 alias: Option<&str>,
1425 in_group: bool,
1426) -> String {
1427 generate_go_import_line(module_path, alias, in_group)
1428}
1429
/// Renders a Go import.
///
/// With `in_group` set, the result is a tab-indented spec for a
/// parenthesized import block. Otherwise it is a standalone `import` statement.
/// An alias, when present, precedes the quoted path.
fn generate_go_import_line(module_path: &str, alias: Option<&str>, in_group: bool) -> String {
    match (in_group, alias) {
        (true, Some(a)) => format!("\t{a} \"{module_path}\""),
        (true, None) => format!("\t\"{module_path}\""),
        (false, Some(a)) => format!("import {a} \"{module_path}\""),
        (false, None) => format!("import \"{module_path}\""),
    }
}
1449
1450pub fn go_has_grouped_import(_source: &str, tree: &Tree) -> Option<Range<usize>> {
1453 let root = tree.root_node();
1454 let mut cursor = root.walk();
1455 if !cursor.goto_first_child() {
1456 return None;
1457 }
1458
1459 loop {
1460 let node = cursor.node();
1461 if node.kind() == "import_declaration" {
1462 let mut c = node.walk();
1463 if c.goto_first_child() {
1464 loop {
1465 if c.node().kind() == "import_spec_list" {
1466 return Some(c.node().byte_range());
1467 }
1468 if !c.goto_next_sibling() {
1469 break;
1470 }
1471 }
1472 }
1473 }
1474 if !cursor.goto_next_sibling() {
1475 break;
1476 }
1477 }
1478 None
1479}
1480
/// Advances `pos` past a single line terminator (`\n`, `\r\n` or a lone
/// `\r`) that starts exactly at `pos`.
///
/// Returns `pos` unchanged when there is no terminator there or `pos` is out
/// of range.
fn skip_newline(source: &str, pos: usize) -> usize {
    let bytes = source.as_bytes();
    match bytes.get(pos).copied() {
        Some(b'\n') => pos + 1,
        Some(b'\r') if bytes.get(pos + 1).copied() == Some(b'\n') => pos + 2,
        Some(b'\r') => pos + 1,
        _ => pos,
    }
}
1497
1498#[cfg(test)]
1503mod tests {
1504 use super::*;
1505
1506 fn parse_ts(source: &str) -> (Tree, ImportBlock) {
1507 let grammar = grammar_for(LangId::TypeScript);
1508 let mut parser = Parser::new();
1509 parser.set_language(&grammar).unwrap();
1510 let tree = parser.parse(source, None).unwrap();
1511 let block = parse_imports(source, &tree, LangId::TypeScript);
1512 (tree, block)
1513 }
1514
1515 fn parse_js(source: &str) -> (Tree, ImportBlock) {
1516 let grammar = grammar_for(LangId::JavaScript);
1517 let mut parser = Parser::new();
1518 parser.set_language(&grammar).unwrap();
1519 let tree = parser.parse(source, None).unwrap();
1520 let block = parse_imports(source, &tree, LangId::JavaScript);
1521 (tree, block)
1522 }
1523
1524 #[test]
1527 fn parse_ts_named_imports() {
1528 let source = "import { useState, useEffect } from 'react';\n";
1529 let (_, block) = parse_ts(source);
1530 assert_eq!(block.imports.len(), 1);
1531 let imp = &block.imports[0];
1532 assert_eq!(imp.module_path, "react");
1533 assert!(imp.names.contains(&"useState".to_string()));
1534 assert!(imp.names.contains(&"useEffect".to_string()));
1535 assert_eq!(imp.kind, ImportKind::Value);
1536 assert_eq!(imp.group, ImportGroup::External);
1537 }
1538
1539 #[test]
1540 fn parse_ts_default_import() {
1541 let source = "import React from 'react';\n";
1542 let (_, block) = parse_ts(source);
1543 assert_eq!(block.imports.len(), 1);
1544 let imp = &block.imports[0];
1545 assert_eq!(imp.default_import.as_deref(), Some("React"));
1546 assert_eq!(imp.kind, ImportKind::Value);
1547 }
1548
1549 #[test]
1550 fn parse_ts_side_effect_import() {
1551 let source = "import './styles.css';\n";
1552 let (_, block) = parse_ts(source);
1553 assert_eq!(block.imports.len(), 1);
1554 assert_eq!(block.imports[0].kind, ImportKind::SideEffect);
1555 assert_eq!(block.imports[0].module_path, "./styles.css");
1556 }
1557
1558 #[test]
1559 fn parse_ts_relative_import() {
1560 let source = "import { helper } from './utils';\n";
1561 let (_, block) = parse_ts(source);
1562 assert_eq!(block.imports.len(), 1);
1563 assert_eq!(block.imports[0].group, ImportGroup::Internal);
1564 }
1565
1566 #[test]
1567 fn parse_ts_multiple_groups() {
1568 let source = "\
1569import React from 'react';
1570import { useState } from 'react';
1571import { helper } from './utils';
1572import { Config } from '../config';
1573";
1574 let (_, block) = parse_ts(source);
1575 assert_eq!(block.imports.len(), 4);
1576
1577 let external: Vec<_> = block
1578 .imports
1579 .iter()
1580 .filter(|i| i.group == ImportGroup::External)
1581 .collect();
1582 let relative: Vec<_> = block
1583 .imports
1584 .iter()
1585 .filter(|i| i.group == ImportGroup::Internal)
1586 .collect();
1587 assert_eq!(external.len(), 2);
1588 assert_eq!(relative.len(), 2);
1589 }
1590
1591 #[test]
1592 fn parse_ts_namespace_import() {
1593 let source = "import * as path from 'path';\n";
1594 let (_, block) = parse_ts(source);
1595 assert_eq!(block.imports.len(), 1);
1596 let imp = &block.imports[0];
1597 assert_eq!(imp.namespace_import.as_deref(), Some("path"));
1598 assert_eq!(imp.kind, ImportKind::Value);
1599 }
1600
1601 #[test]
1602 fn parse_js_imports() {
1603 let source = "import { readFile } from 'fs';\nimport { helper } from './helper';\n";
1604 let (_, block) = parse_js(source);
1605 assert_eq!(block.imports.len(), 2);
1606 assert_eq!(block.imports[0].group, ImportGroup::External);
1607 assert_eq!(block.imports[1].group, ImportGroup::Internal);
1608 }
1609
1610 #[test]
1613 fn classify_external() {
1614 assert_eq!(classify_group_ts("react"), ImportGroup::External);
1615 assert_eq!(classify_group_ts("@scope/pkg"), ImportGroup::External);
1616 assert_eq!(classify_group_ts("lodash/map"), ImportGroup::External);
1617 }
1618
1619 #[test]
1620 fn classify_relative() {
1621 assert_eq!(classify_group_ts("./utils"), ImportGroup::Internal);
1622 assert_eq!(classify_group_ts("../config"), ImportGroup::Internal);
1623 assert_eq!(classify_group_ts("./"), ImportGroup::Internal);
1624 }
1625
1626 #[test]
1629 fn dedup_detects_same_named_import() {
1630 let source = "import { useState } from 'react';\n";
1631 let (_, block) = parse_ts(source);
1632 assert!(is_duplicate(
1633 &block,
1634 "react",
1635 &["useState".to_string()],
1636 None,
1637 false
1638 ));
1639 }
1640
1641 #[test]
1642 fn dedup_misses_different_name() {
1643 let source = "import { useState } from 'react';\n";
1644 let (_, block) = parse_ts(source);
1645 assert!(!is_duplicate(
1646 &block,
1647 "react",
1648 &["useEffect".to_string()],
1649 None,
1650 false
1651 ));
1652 }
1653
1654 #[test]
1655 fn dedup_detects_default_import() {
1656 let source = "import React from 'react';\n";
1657 let (_, block) = parse_ts(source);
1658 assert!(is_duplicate(&block, "react", &[], Some("React"), false));
1659 }
1660
1661 #[test]
1662 fn dedup_side_effect() {
1663 let source = "import './styles.css';\n";
1664 let (_, block) = parse_ts(source);
1665 assert!(is_duplicate(&block, "./styles.css", &[], None, false));
1666 }
1667
1668 #[test]
1669 fn dedup_namespace_import_distinct_from_side_effect_import() {
1670 let side_effect_source = "import 'fs';\n";
1671 let (_, side_effect_block) = parse_ts(side_effect_source);
1672 assert!(!is_duplicate_with_namespace(
1673 &side_effect_block,
1674 "fs",
1675 &[],
1676 None,
1677 Some("fs"),
1678 false
1679 ));
1680
1681 let namespace_source = "import * as fs from 'fs';\n";
1682 let (_, namespace_block) = parse_ts(namespace_source);
1683 assert!(!is_duplicate(&namespace_block, "fs", &[], None, false));
1684 assert!(is_duplicate_with_namespace(
1685 &namespace_block,
1686 "fs",
1687 &[],
1688 None,
1689 Some("fs"),
1690 false
1691 ));
1692 assert!(!is_duplicate_with_namespace(
1693 &namespace_block,
1694 "fs",
1695 &[],
1696 None,
1697 Some("other"),
1698 false
1699 ));
1700 }
1701
1702 #[test]
1703 fn dedup_type_vs_value() {
1704 let source = "import { FC } from 'react';\n";
1705 let (_, block) = parse_ts(source);
1706 assert!(!is_duplicate(
1708 &block,
1709 "react",
1710 &["FC".to_string()],
1711 None,
1712 true
1713 ));
1714 }
1715
1716 #[test]
1719 fn generate_named_import() {
1720 let line = generate_import_line(
1721 LangId::TypeScript,
1722 "react",
1723 &["useState".to_string(), "useEffect".to_string()],
1724 None,
1725 false,
1726 );
1727 assert_eq!(line, "import { useEffect, useState } from 'react';");
1728 }
1729
1730 #[test]
1731 fn generate_named_import_sorts_by_imported_name() {
1732 let line = generate_import_line(
1733 LangId::TypeScript,
1734 "x",
1735 &[
1736 "useState".to_string(),
1737 "type Foo".to_string(),
1738 "stdin as input".to_string(),
1739 "type Bar".to_string(),
1740 ],
1741 None,
1742 false,
1743 );
1744 assert_eq!(
1745 line,
1746 "import { type Bar, type Foo, stdin as input, useState } from 'x';"
1747 );
1748 }
1749
1750 #[test]
1751 fn generate_default_import() {
1752 let line = generate_import_line(LangId::TypeScript, "react", &[], Some("React"), false);
1753 assert_eq!(line, "import React from 'react';");
1754 }
1755
1756 #[test]
1757 fn generate_type_import() {
1758 let line =
1759 generate_import_line(LangId::TypeScript, "react", &["FC".to_string()], None, true);
1760 assert_eq!(line, "import type { FC } from 'react';");
1761 }
1762
1763 #[test]
1764 fn generate_side_effect_import() {
1765 let line = generate_import_line(LangId::TypeScript, "./styles.css", &[], None, false);
1766 assert_eq!(line, "import './styles.css';");
1767 }
1768
1769 #[test]
1770 fn generate_default_and_named() {
1771 let line = generate_import_line(
1772 LangId::TypeScript,
1773 "react",
1774 &["useState".to_string()],
1775 Some("React"),
1776 false,
1777 );
1778 assert_eq!(line, "import React, { useState } from 'react';");
1779 }
1780
1781 #[test]
1782 fn parse_ts_type_import() {
1783 let source = "import type { FC } from 'react';\n";
1784 let (_, block) = parse_ts(source);
1785 assert_eq!(block.imports.len(), 1);
1786 let imp = &block.imports[0];
1787 assert_eq!(imp.kind, ImportKind::Type);
1788 assert!(imp.names.contains(&"FC".to_string()));
1789 assert_eq!(imp.group, ImportGroup::External);
1790 }
1791
1792 #[test]
1795 fn insertion_empty_file() {
1796 let source = "";
1797 let (_, block) = parse_ts(source);
1798 let (offset, _, _) =
1799 find_insertion_point(source, &block, ImportGroup::External, "react", false);
1800 assert_eq!(offset, 0);
1801 }
1802
1803 #[test]
1804 fn insertion_alphabetical_within_group() {
1805 let source = "\
1806import { a } from 'alpha';
1807import { c } from 'charlie';
1808";
1809 let (_, block) = parse_ts(source);
1810 let (offset, _, _) =
1811 find_insertion_point(source, &block, ImportGroup::External, "bravo", false);
1812 let before_charlie = source.find("import { c }").unwrap();
1814 assert_eq!(offset, before_charlie);
1815 }
1816
1817 fn parse_py(source: &str) -> (Tree, ImportBlock) {
1820 let grammar = grammar_for(LangId::Python);
1821 let mut parser = Parser::new();
1822 parser.set_language(&grammar).unwrap();
1823 let tree = parser.parse(source, None).unwrap();
1824 let block = parse_imports(source, &tree, LangId::Python);
1825 (tree, block)
1826 }
1827
1828 #[test]
1829 fn parse_py_import_statement() {
1830 let source = "import os\nimport sys\n";
1831 let (_, block) = parse_py(source);
1832 assert_eq!(block.imports.len(), 2);
1833 assert_eq!(block.imports[0].module_path, "os");
1834 assert_eq!(block.imports[1].module_path, "sys");
1835 assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
1836 }
1837
1838 #[test]
1839 fn parse_py_from_import() {
1840 let source = "from collections import OrderedDict\nfrom typing import List, Optional\n";
1841 let (_, block) = parse_py(source);
1842 assert_eq!(block.imports.len(), 2);
1843 assert_eq!(block.imports[0].module_path, "collections");
1844 assert!(block.imports[0].names.contains(&"OrderedDict".to_string()));
1845 assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
1846 assert_eq!(block.imports[1].module_path, "typing");
1847 assert!(block.imports[1].names.contains(&"List".to_string()));
1848 assert!(block.imports[1].names.contains(&"Optional".to_string()));
1849 }
1850
1851 #[test]
1852 fn parse_py_relative_import() {
1853 let source = "from . import utils\nfrom ..config import Settings\n";
1854 let (_, block) = parse_py(source);
1855 assert_eq!(block.imports.len(), 2);
1856 assert_eq!(block.imports[0].module_path, ".");
1857 assert!(block.imports[0].names.contains(&"utils".to_string()));
1858 assert_eq!(block.imports[0].group, ImportGroup::Internal);
1859 assert_eq!(block.imports[1].module_path, "..config");
1860 assert_eq!(block.imports[1].group, ImportGroup::Internal);
1861 }
1862
1863 #[test]
1864 fn classify_py_groups() {
1865 assert_eq!(classify_group_py("os"), ImportGroup::Stdlib);
1866 assert_eq!(classify_group_py("sys"), ImportGroup::Stdlib);
1867 assert_eq!(classify_group_py("json"), ImportGroup::Stdlib);
1868 assert_eq!(classify_group_py("collections"), ImportGroup::Stdlib);
1869 assert_eq!(classify_group_py("os.path"), ImportGroup::Stdlib);
1870 assert_eq!(classify_group_py("requests"), ImportGroup::External);
1871 assert_eq!(classify_group_py("flask"), ImportGroup::External);
1872 assert_eq!(classify_group_py("."), ImportGroup::Internal);
1873 assert_eq!(classify_group_py("..config"), ImportGroup::Internal);
1874 assert_eq!(classify_group_py(".utils"), ImportGroup::Internal);
1875 }
1876
1877 #[test]
1878 fn parse_py_three_groups() {
1879 let source = "import os\nimport sys\n\nimport requests\n\nfrom . import utils\n";
1880 let (_, block) = parse_py(source);
1881 let stdlib: Vec<_> = block
1882 .imports
1883 .iter()
1884 .filter(|i| i.group == ImportGroup::Stdlib)
1885 .collect();
1886 let external: Vec<_> = block
1887 .imports
1888 .iter()
1889 .filter(|i| i.group == ImportGroup::External)
1890 .collect();
1891 let internal: Vec<_> = block
1892 .imports
1893 .iter()
1894 .filter(|i| i.group == ImportGroup::Internal)
1895 .collect();
1896 assert_eq!(stdlib.len(), 2);
1897 assert_eq!(external.len(), 1);
1898 assert_eq!(internal.len(), 1);
1899 }
1900
1901 #[test]
1902 fn generate_py_import() {
1903 let line = generate_import_line(LangId::Python, "os", &[], None, false);
1904 assert_eq!(line, "import os");
1905 }
1906
1907 #[test]
1908 fn generate_py_from_import() {
1909 let line = generate_import_line(
1910 LangId::Python,
1911 "collections",
1912 &["OrderedDict".to_string()],
1913 None,
1914 false,
1915 );
1916 assert_eq!(line, "from collections import OrderedDict");
1917 }
1918
1919 #[test]
1920 fn generate_py_from_import_multiple() {
1921 let line = generate_import_line(
1922 LangId::Python,
1923 "typing",
1924 &["Optional".to_string(), "List".to_string()],
1925 None,
1926 false,
1927 );
1928 assert_eq!(line, "from typing import List, Optional");
1929 }
1930
1931 fn parse_rust(source: &str) -> (Tree, ImportBlock) {
1934 let grammar = grammar_for(LangId::Rust);
1935 let mut parser = Parser::new();
1936 parser.set_language(&grammar).unwrap();
1937 let tree = parser.parse(source, None).unwrap();
1938 let block = parse_imports(source, &tree, LangId::Rust);
1939 (tree, block)
1940 }
1941
1942 #[test]
1943 fn parse_rs_use_std() {
1944 let source = "use std::collections::HashMap;\nuse std::io::Read;\n";
1945 let (_, block) = parse_rust(source);
1946 assert_eq!(block.imports.len(), 2);
1947 assert_eq!(block.imports[0].module_path, "std::collections::HashMap");
1948 assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
1949 assert_eq!(block.imports[1].group, ImportGroup::Stdlib);
1950 }
1951
1952 #[test]
1953 fn parse_rs_use_external() {
1954 let source = "use serde::{Deserialize, Serialize};\n";
1955 let (_, block) = parse_rust(source);
1956 assert_eq!(block.imports.len(), 1);
1957 assert_eq!(block.imports[0].group, ImportGroup::External);
1958 assert!(block.imports[0].names.contains(&"Deserialize".to_string()));
1959 assert!(block.imports[0].names.contains(&"Serialize".to_string()));
1960 }
1961
1962 #[test]
1963 fn parse_rs_use_crate() {
1964 let source = "use crate::config::Settings;\nuse super::parent::Thing;\n";
1965 let (_, block) = parse_rust(source);
1966 assert_eq!(block.imports.len(), 2);
1967 assert_eq!(block.imports[0].group, ImportGroup::Internal);
1968 assert_eq!(block.imports[1].group, ImportGroup::Internal);
1969 }
1970
1971 #[test]
1972 fn parse_rs_pub_use() {
1973 let source = "pub use super::parent::Thing;\n";
1974 let (_, block) = parse_rust(source);
1975 assert_eq!(block.imports.len(), 1);
1976 assert_eq!(block.imports[0].default_import.as_deref(), Some("pub"));
1978 }
1979
1980 #[test]
1981 fn classify_rs_groups() {
1982 assert_eq!(
1983 classify_group_rs("std::collections::HashMap"),
1984 ImportGroup::Stdlib
1985 );
1986 assert_eq!(classify_group_rs("core::mem"), ImportGroup::Stdlib);
1987 assert_eq!(classify_group_rs("alloc::vec"), ImportGroup::Stdlib);
1988 assert_eq!(
1989 classify_group_rs("serde::Deserialize"),
1990 ImportGroup::External
1991 );
1992 assert_eq!(classify_group_rs("tokio::runtime"), ImportGroup::External);
1993 assert_eq!(classify_group_rs("crate::config"), ImportGroup::Internal);
1994 assert_eq!(classify_group_rs("self::utils"), ImportGroup::Internal);
1995 assert_eq!(classify_group_rs("super::parent"), ImportGroup::Internal);
1996 }
1997
1998 #[test]
1999 fn generate_rs_use() {
2000 let line = generate_import_line(LangId::Rust, "std::fmt::Display", &[], None, false);
2001 assert_eq!(line, "use std::fmt::Display;");
2002 }
2003
2004 fn parse_go(source: &str) -> (Tree, ImportBlock) {
2007 let grammar = grammar_for(LangId::Go);
2008 let mut parser = Parser::new();
2009 parser.set_language(&grammar).unwrap();
2010 let tree = parser.parse(source, None).unwrap();
2011 let block = parse_imports(source, &tree, LangId::Go);
2012 (tree, block)
2013 }
2014
2015 #[test]
2016 fn parse_go_single_import() {
2017 let source = "package main\n\nimport \"fmt\"\n";
2018 let (_, block) = parse_go(source);
2019 assert_eq!(block.imports.len(), 1);
2020 assert_eq!(block.imports[0].module_path, "fmt");
2021 assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
2022 }
2023
2024 #[test]
2025 fn parse_go_grouped_import() {
2026 let source =
2027 "package main\n\nimport (\n\t\"fmt\"\n\t\"os\"\n\n\t\"github.com/pkg/errors\"\n)\n";
2028 let (_, block) = parse_go(source);
2029 assert_eq!(block.imports.len(), 3);
2030 assert_eq!(block.imports[0].module_path, "fmt");
2031 assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
2032 assert_eq!(block.imports[1].module_path, "os");
2033 assert_eq!(block.imports[1].group, ImportGroup::Stdlib);
2034 assert_eq!(block.imports[2].module_path, "github.com/pkg/errors");
2035 assert_eq!(block.imports[2].group, ImportGroup::External);
2036 }
2037
2038 #[test]
2039 fn parse_go_mixed_imports() {
2040 let source = "package main\n\nimport \"fmt\"\n\nimport (\n\t\"os\"\n\t\"github.com/pkg/errors\"\n)\n";
2042 let (_, block) = parse_go(source);
2043 assert_eq!(block.imports.len(), 3);
2044 }
2045
2046 #[test]
2047 fn classify_go_groups() {
2048 assert_eq!(classify_group_go("fmt"), ImportGroup::Stdlib);
2049 assert_eq!(classify_group_go("os"), ImportGroup::Stdlib);
2050 assert_eq!(classify_group_go("net/http"), ImportGroup::Stdlib);
2051 assert_eq!(classify_group_go("encoding/json"), ImportGroup::Stdlib);
2052 assert_eq!(
2053 classify_group_go("github.com/pkg/errors"),
2054 ImportGroup::External
2055 );
2056 assert_eq!(
2057 classify_group_go("golang.org/x/tools"),
2058 ImportGroup::External
2059 );
2060 }
2061
2062 #[test]
2063 fn generate_go_standalone() {
2064 let line = generate_go_import_line("fmt", None, false);
2065 assert_eq!(line, "import \"fmt\"");
2066 }
2067
2068 #[test]
2069 fn generate_go_grouped_spec() {
2070 let line = generate_go_import_line("fmt", None, true);
2071 assert_eq!(line, "\t\"fmt\"");
2072 }
2073
2074 #[test]
2075 fn generate_go_with_alias() {
2076 let line = generate_go_import_line("github.com/pkg/errors", Some("errs"), false);
2077 assert_eq!(line, "import errs \"github.com/pkg/errors\"");
2078 }
2079}