1use std::ops::Range;
10
11use tree_sitter::{Node, Parser, Tree};
12
13use crate::parser::{grammar_for, LangId};
14
/// How an import statement binds names.
///
/// Only the TypeScript/JavaScript parser distinguishes all three variants; the
/// Python, Rust and Go parsers in this file always record [`ImportKind::Value`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ImportKind {
    /// A regular import that binds runtime values.
    Value,
    /// A type-only import (`import type ... from '...'`).
    Type,
    /// An import that binds nothing, e.g. `import './styles.css';`.
    SideEffect,
}
29
/// Coarse classification of where an imported module comes from.
///
/// Variant order is significant: the derived `Ord` ranks groups as
/// `Stdlib < External < Internal`, and `find_insertion_point` compares groups
/// with `<`/`>` to locate the neighbouring group for a new import.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ImportGroup {
    /// Standard-library modules (Python stdlib, Rust `std`/`core`/`alloc`,
    /// Go packages without a dot in their path).
    Stdlib,
    /// Third-party packages.
    External,
    /// Project-local modules (relative paths, Rust `crate`/`self`/`super`).
    Internal,
}

impl ImportGroup {
    /// Short lowercase name of the group, suitable for display.
    pub fn label(&self) -> &'static str {
        match *self {
            Self::Stdlib => "stdlib",
            Self::External => "external",
            Self::Internal => "internal",
        }
    }
}
61
/// One parsed import statement (for Go: one import spec, so a grouped
/// `import ( ... )` yields one value per line).
#[derive(Debug, Clone)]
pub struct ImportStatement {
    /// The imported module.
    ///
    /// - TypeScript/JavaScript and Go: the path with its quotes stripped.
    /// - Python: the dotted or relative module name.
    /// - Rust: the whole `use` argument text, e.g. `std::{fmt, io}`.
    pub module_path: String,
    /// Named bindings, stored as written in the source.
    ///
    /// TypeScript specifiers may carry a `type ` prefix or an `as` alias. Empty
    /// when the statement has no named bindings.
    pub names: Vec<String>,
    /// The default binding of a TypeScript/JavaScript import.
    ///
    /// Other parsers reuse this slot: Go stores the package alias here, and Rust
    /// stores `"pub"` for `pub use` declarations.
    pub default_import: Option<String>,
    /// The `* as name` binding of a TypeScript/JavaScript namespace import.
    pub namespace_import: Option<String>,
    /// Value / type-only / side-effect classification.
    pub kind: ImportKind,
    /// The ordering group the module belongs to.
    pub group: ImportGroup,
    /// Byte span of the statement's syntax node within the source.
    pub byte_range: Range<usize>,
    /// Verbatim source text of the statement.
    pub raw_text: String,
}
82
83#[derive(Debug, Clone)]
85pub struct ImportBlock {
86 pub imports: Vec<ImportStatement>,
88 pub byte_range: Option<Range<usize>>,
91}
92
93impl ImportBlock {
94 pub fn empty() -> Self {
95 ImportBlock {
96 imports: Vec::new(),
97 byte_range: None,
98 }
99 }
100}
101
102fn import_byte_range(imports: &[ImportStatement]) -> Option<Range<usize>> {
103 imports.first().zip(imports.last()).map(|(first, last)| {
104 let start = first.byte_range.start;
105 let end = last.byte_range.end;
106 start..end
107 })
108}
109
/// Name a specifier binds locally.
///
/// `Foo` → `Foo`, `Foo as Bar` → `Bar`, `type Foo as Bar` → `Bar`. Surrounding
/// whitespace and a leading `type ` modifier are ignored; only the first
/// `" as "` is treated as the alias separator.
pub fn specifier_local_name(spec: &str) -> &str {
    let trimmed = spec.trim();
    let body = match trimmed.strip_prefix("type ") {
        Some(rest) => rest.trim_start(),
        None => trimmed,
    };
    match body.split_once(" as ") {
        Some((_, alias)) => alias.trim(),
        None => body,
    }
}
137
/// Name a specifier imports from the module, i.e. the part before any `as`.
///
/// `Foo` → `Foo`, `Foo as Bar` → `Foo`, `type Foo as Bar` → `Foo`.
pub fn specifier_imported_name(spec: &str) -> &str {
    let spec = spec.trim();
    let spec = spec.strip_prefix("type ").map_or(spec, str::trim_start);
    if let Some((imported, _alias)) = spec.split_once(" as ") {
        imported.trim()
    } else {
        spec
    }
}
157
158pub fn specifier_matches(spec: &str, target: &str) -> bool {
163 specifier_imported_name(spec) == target || specifier_local_name(spec) == target
164}
165
166pub fn parse_imports(source: &str, tree: &Tree, lang: LangId) -> ImportBlock {
172 match lang {
173 LangId::TypeScript | LangId::Tsx | LangId::JavaScript => parse_ts_imports(source, tree),
174 LangId::Python => parse_py_imports(source, tree),
175 LangId::Rust => parse_rs_imports(source, tree),
176 LangId::Go => parse_go_imports(source, tree),
177 LangId::C
178 | LangId::Cpp
179 | LangId::Zig
180 | LangId::CSharp
181 | LangId::Bash
182 | LangId::Solidity
183 | LangId::Vue
184 | LangId::Json
185 | LangId::Scala => ImportBlock::empty(),
186 LangId::Html | LangId::Markdown => ImportBlock::empty(),
187 }
188}
189
190pub fn is_duplicate(
196 block: &ImportBlock,
197 module_path: &str,
198 names: &[String],
199 default_import: Option<&str>,
200 type_only: bool,
201) -> bool {
202 is_duplicate_with_namespace(block, module_path, names, default_import, None, type_only)
203}
204
205pub fn is_duplicate_with_namespace(
207 block: &ImportBlock,
208 module_path: &str,
209 names: &[String],
210 default_import: Option<&str>,
211 namespace_import: Option<&str>,
212 type_only: bool,
213) -> bool {
214 let target_kind = if type_only {
215 ImportKind::Type
216 } else {
217 ImportKind::Value
218 };
219
220 for imp in &block.imports {
221 if imp.module_path != module_path {
222 continue;
223 }
224
225 if names.is_empty()
231 && default_import.is_none()
232 && namespace_import.is_none()
233 && imp.names.is_empty()
234 && imp.default_import.is_none()
235 && imp.namespace_import.is_none()
236 {
237 return true;
238 }
239
240 if names.is_empty()
242 && default_import.is_none()
243 && namespace_import.is_none()
244 && imp.kind == ImportKind::SideEffect
245 {
246 return true;
247 }
248
249 if names.is_empty()
250 && default_import.is_none()
251 && namespace_import.is_some()
252 && imp.names.is_empty()
253 && imp.default_import.is_none()
254 && imp.namespace_import.as_deref() == namespace_import
255 {
256 return true;
257 }
258
259 if imp.kind != target_kind && imp.kind != ImportKind::SideEffect {
261 continue;
262 }
263
264 if let Some(def) = default_import {
266 if imp.default_import.as_deref() == Some(def) && imp.namespace_import.is_none() {
267 return true;
268 }
269 }
270
271 if !names.is_empty()
277 && names
278 .iter()
279 .all(|n| imp.names.iter().any(|stored| specifier_matches(stored, n)))
280 {
281 return true;
282 }
283 }
284
285 false
286}
287
288fn sort_named_specifiers(names: &mut [String]) {
289 names.sort_by(|a, b| {
290 specifier_imported_name(a)
291 .cmp(specifier_imported_name(b))
292 .then_with(|| a.cmp(b))
293 });
294}
295
296pub fn find_insertion_point(
307 source: &str,
308 block: &ImportBlock,
309 group: ImportGroup,
310 module_path: &str,
311 type_only: bool,
312) -> (usize, bool, bool) {
313 if block.imports.is_empty() {
314 return (0, false, source.is_empty().then_some(false).unwrap_or(true));
316 }
317
318 let target_kind = if type_only {
319 ImportKind::Type
320 } else {
321 ImportKind::Value
322 };
323
324 let group_imports: Vec<&ImportStatement> =
326 block.imports.iter().filter(|i| i.group == group).collect();
327
328 if group_imports.is_empty() {
329 let preceding_last = block.imports.iter().filter(|i| i.group < group).last();
332
333 if let Some(last) = preceding_last {
334 let end = last.byte_range.end;
335 let insert_at = skip_newline(source, end);
336 return (insert_at, true, true);
337 }
338
339 let following_first = block.imports.iter().find(|i| i.group > group);
341
342 if let Some(first) = following_first {
343 return (first.byte_range.start, false, true);
344 }
345
346 let first_byte = import_byte_range(&block.imports)
348 .map(|range| range.start)
349 .unwrap_or(0);
350 return (first_byte, false, true);
351 }
352
353 for imp in &group_imports {
355 let cmp = module_path.cmp(&imp.module_path);
356 match cmp {
357 std::cmp::Ordering::Less => {
358 return (imp.byte_range.start, false, false);
360 }
361 std::cmp::Ordering::Equal => {
362 if target_kind == ImportKind::Type && imp.kind == ImportKind::Value {
364 let end = imp.byte_range.end;
366 let insert_at = skip_newline(source, end);
367 return (insert_at, false, false);
368 }
369 return (imp.byte_range.start, false, false);
371 }
372 std::cmp::Ordering::Greater => continue,
373 }
374 }
375
376 let Some(last) = group_imports.last() else {
378 return (
379 import_byte_range(&block.imports)
380 .map(|range| range.end)
381 .unwrap_or(0),
382 false,
383 false,
384 );
385 };
386 let end = last.byte_range.end;
387 let insert_at = skip_newline(source, end);
388 (insert_at, false, false)
389}
390
391pub fn generate_import_line(
393 lang: LangId,
394 module_path: &str,
395 names: &[String],
396 default_import: Option<&str>,
397 type_only: bool,
398) -> String {
399 match lang {
400 LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
401 generate_ts_import_line(module_path, names, default_import, type_only)
402 }
403 LangId::Python => generate_py_import_line(module_path, names, default_import),
404 LangId::Rust => generate_rs_import_line(module_path, names, type_only),
405 LangId::Go => generate_go_import_line(module_path, default_import, false),
406 LangId::C
407 | LangId::Cpp
408 | LangId::Zig
409 | LangId::CSharp
410 | LangId::Bash
411 | LangId::Solidity
412 | LangId::Vue
413 | LangId::Json
414 | LangId::Scala => String::new(),
415 LangId::Html | LangId::Markdown => String::new(),
416 }
417}
418
419pub fn is_supported(lang: LangId) -> bool {
421 matches!(
422 lang,
423 LangId::TypeScript
424 | LangId::Tsx
425 | LangId::JavaScript
426 | LangId::Python
427 | LangId::Rust
428 | LangId::Go
429 )
430}
431
432pub fn classify_group_ts(module_path: &str) -> ImportGroup {
434 if module_path.starts_with('.') {
435 ImportGroup::Internal
436 } else {
437 ImportGroup::External
438 }
439}
440
441pub fn classify_group(lang: LangId, module_path: &str) -> ImportGroup {
443 match lang {
444 LangId::TypeScript | LangId::Tsx | LangId::JavaScript => classify_group_ts(module_path),
445 LangId::Python => classify_group_py(module_path),
446 LangId::Rust => classify_group_rs(module_path),
447 LangId::Go => classify_group_go(module_path),
448 LangId::C
449 | LangId::Cpp
450 | LangId::Zig
451 | LangId::CSharp
452 | LangId::Bash
453 | LangId::Solidity
454 | LangId::Vue
455 | LangId::Json
456 | LangId::Scala => ImportGroup::External,
457 LangId::Html | LangId::Markdown => ImportGroup::External,
458 }
459}
460
461pub fn parse_file_imports(
464 path: &std::path::Path,
465 lang: LangId,
466) -> Result<(String, Tree, ImportBlock), crate::error::AftError> {
467 let source =
468 std::fs::read_to_string(path).map_err(|e| crate::error::AftError::FileNotFound {
469 path: format!("{}: {}", path.display(), e),
470 })?;
471
472 let grammar = grammar_for(lang);
473 let mut parser = Parser::new();
474 parser
475 .set_language(&grammar)
476 .map_err(|e| crate::error::AftError::ParseError {
477 message: format!("grammar init failed for {:?}: {}", lang, e),
478 })?;
479
480 let tree = parser
481 .parse(&source, None)
482 .ok_or_else(|| crate::error::AftError::ParseError {
483 message: format!("tree-sitter parse returned None for {}", path.display()),
484 })?;
485
486 let block = parse_imports(&source, &tree, lang);
487 Ok((source, tree, block))
488}
489
490fn parse_ts_imports(source: &str, tree: &Tree) -> ImportBlock {
498 let root = tree.root_node();
499 let mut imports = Vec::new();
500
501 let mut cursor = root.walk();
502 if !cursor.goto_first_child() {
503 return ImportBlock::empty();
504 }
505
506 loop {
507 let node = cursor.node();
508 if node.kind() == "import_statement" {
509 if let Some(imp) = parse_single_ts_import(source, &node) {
510 imports.push(imp);
511 }
512 }
513 if !cursor.goto_next_sibling() {
514 break;
515 }
516 }
517
518 let byte_range = import_byte_range(&imports);
519
520 ImportBlock {
521 imports,
522 byte_range,
523 }
524}
525
526fn parse_single_ts_import(source: &str, node: &Node) -> Option<ImportStatement> {
528 let raw_text = source[node.byte_range()].to_string();
529 let byte_range = node.byte_range();
530
531 let module_path = extract_module_path(source, node)?;
533
534 let is_type_only = has_type_keyword(node);
536
537 let mut names = Vec::new();
539 let mut default_import = None;
540 let mut namespace_import = None;
541
542 let mut child_cursor = node.walk();
543 if child_cursor.goto_first_child() {
544 loop {
545 let child = child_cursor.node();
546 match child.kind() {
547 "import_clause" => {
548 extract_import_clause(
549 source,
550 &child,
551 &mut names,
552 &mut default_import,
553 &mut namespace_import,
554 );
555 }
556 "identifier" => {
558 let text = &source[child.byte_range()];
559 if text != "import" && text != "from" && text != "type" {
560 default_import = Some(text.to_string());
561 }
562 }
563 _ => {}
564 }
565 if !child_cursor.goto_next_sibling() {
566 break;
567 }
568 }
569 }
570
571 let kind = if names.is_empty() && default_import.is_none() && namespace_import.is_none() {
573 ImportKind::SideEffect
574 } else if is_type_only {
575 ImportKind::Type
576 } else {
577 ImportKind::Value
578 };
579
580 let group = classify_group_ts(&module_path);
581
582 Some(ImportStatement {
583 module_path,
584 names,
585 default_import,
586 namespace_import,
587 kind,
588 group,
589 byte_range,
590 raw_text,
591 })
592}
593
594fn extract_module_path(source: &str, node: &Node) -> Option<String> {
598 let mut cursor = node.walk();
599 if !cursor.goto_first_child() {
600 return None;
601 }
602
603 loop {
604 let child = cursor.node();
605 if child.kind() == "string" {
606 let text = &source[child.byte_range()];
608 let stripped = text
609 .trim_start_matches(|c| c == '\'' || c == '"')
610 .trim_end_matches(|c| c == '\'' || c == '"');
611 return Some(stripped.to_string());
612 }
613 if !cursor.goto_next_sibling() {
614 break;
615 }
616 }
617 None
618}
619
620fn has_type_keyword(node: &Node) -> bool {
625 let mut cursor = node.walk();
626 if !cursor.goto_first_child() {
627 return false;
628 }
629
630 loop {
631 let child = cursor.node();
632 if child.kind() == "type" {
633 return true;
634 }
635 if !cursor.goto_next_sibling() {
636 break;
637 }
638 }
639
640 false
641}
642
643fn extract_import_clause(
645 source: &str,
646 node: &Node,
647 names: &mut Vec<String>,
648 default_import: &mut Option<String>,
649 namespace_import: &mut Option<String>,
650) {
651 let mut cursor = node.walk();
652 if !cursor.goto_first_child() {
653 return;
654 }
655
656 loop {
657 let child = cursor.node();
658 match child.kind() {
659 "identifier" => {
660 let text = &source[child.byte_range()];
662 if text != "type" {
663 *default_import = Some(text.to_string());
664 }
665 }
666 "named_imports" => {
667 extract_named_imports(source, &child, names);
669 }
670 "namespace_import" => {
671 extract_namespace_import(source, &child, namespace_import);
673 }
674 _ => {}
675 }
676 if !cursor.goto_next_sibling() {
677 break;
678 }
679 }
680}
681
682fn extract_named_imports(source: &str, node: &Node, names: &mut Vec<String>) {
701 let mut cursor = node.walk();
702 if !cursor.goto_first_child() {
703 return;
704 }
705
706 loop {
707 let child = cursor.node();
708 if child.kind() == "import_specifier" {
709 let raw = source[child.byte_range()].trim().to_string();
714 if !raw.is_empty() {
715 names.push(raw);
716 } else if let Some(name_node) = child.child_by_field_name("name") {
717 names.push(source[name_node.byte_range()].to_string());
718 }
719 }
720 if !cursor.goto_next_sibling() {
721 break;
722 }
723 }
724}
725
726fn extract_namespace_import(source: &str, node: &Node, namespace_import: &mut Option<String>) {
728 let mut cursor = node.walk();
729 if !cursor.goto_first_child() {
730 return;
731 }
732
733 loop {
734 let child = cursor.node();
735 if child.kind() == "identifier" {
736 *namespace_import = Some(source[child.byte_range()].to_string());
737 return;
738 }
739 if !cursor.goto_next_sibling() {
740 break;
741 }
742 }
743}
744
745fn generate_ts_import_line(
747 module_path: &str,
748 names: &[String],
749 default_import: Option<&str>,
750 type_only: bool,
751) -> String {
752 let type_prefix = if type_only { "type " } else { "" };
753
754 if names.is_empty() && default_import.is_none() {
756 return format!("import '{module_path}';");
757 }
758
759 if names.is_empty() {
761 if let Some(def) = default_import {
762 return format!("import {type_prefix}{def} from '{module_path}';");
763 }
764 }
765
766 if default_import.is_none() {
768 let mut sorted_names = names.to_vec();
769 sort_named_specifiers(&mut sorted_names);
770 let names_str = sorted_names.join(", ");
771 return format!("import {type_prefix}{{ {names_str} }} from '{module_path}';");
772 }
773
774 if let Some(def) = default_import {
776 let mut sorted_names = names.to_vec();
777 sort_named_specifiers(&mut sorted_names);
778 let names_str = sorted_names.join(", ");
779 return format!("import {type_prefix}{def}, {{ {names_str} }} from '{module_path}';");
780 }
781
782 format!("import '{module_path}';")
784}
785
/// Top-level module names of the CPython standard library.
///
/// The list is sorted case-insensitively and has no duplicates. It deliberately
/// also contains modules that have since been deprecated or removed (e.g.
/// `asynchat`, `distutils`, `imp`, `telnetlib`), so code targeting older Python
/// versions still classifies them as stdlib.
const PYTHON_STDLIB: &[&str] = &[
    "__future__",
    "__main__",
    "_thread",
    "abc",
    "aifc",
    "argparse",
    "array",
    "ast",
    "asynchat",
    "asyncio",
    "asyncore",
    "atexit",
    "audioop",
    "base64",
    "bdb",
    "binascii",
    "binhex",
    "bisect",
    "builtins",
    "bz2",
    "calendar",
    "cgi",
    "cgitb",
    "chunk",
    "cmath",
    "cmd",
    "code",
    "codecs",
    "codeop",
    "collections",
    "colorsys",
    "compileall",
    "concurrent",
    "configparser",
    "contextlib",
    "contextvars",
    "copy",
    "copyreg",
    "cProfile",
    "crypt",
    "csv",
    "ctypes",
    "curses",
    "dataclasses",
    "datetime",
    "dbm",
    "decimal",
    "difflib",
    "dis",
    "distutils",
    "doctest",
    "email",
    "encodings",
    "ensurepip",
    "enum",
    "errno",
    "faulthandler",
    "fcntl",
    "filecmp",
    "fileinput",
    "fnmatch",
    "formatter",
    "fractions",
    "ftplib",
    "functools",
    "gc",
    "getopt",
    "getpass",
    "gettext",
    "glob",
    "graphlib",
    "grp",
    "gzip",
    "hashlib",
    "heapq",
    "hmac",
    "html",
    "http",
    "idlelib",
    "imaplib",
    "imghdr",
    "imp",
    "importlib",
    "inspect",
    "io",
    "ipaddress",
    "itertools",
    "json",
    "keyword",
    "lib2to3",
    "linecache",
    "locale",
    "logging",
    "lzma",
    "mailbox",
    "mailcap",
    "marshal",
    "math",
    "mimetypes",
    "mmap",
    "modulefinder",
    "msilib",
    "msvcrt",
    "multiprocessing",
    "netrc",
    "nis",
    "nntplib",
    "ntpath",
    "numbers",
    "opcode",
    "operator",
    "optparse",
    "os",
    "ossaudiodev",
    "pathlib",
    "pdb",
    "pickle",
    "pickletools",
    "pipes",
    "pkgutil",
    "platform",
    "plistlib",
    "poplib",
    "posix",
    "posixpath",
    "pprint",
    "profile",
    "pstats",
    "pty",
    "pwd",
    "py_compile",
    "pyclbr",
    "pydoc",
    "pyexpat",
    "queue",
    "quopri",
    "random",
    "re",
    "readline",
    "reprlib",
    "resource",
    "rlcompleter",
    "runpy",
    "sched",
    "secrets",
    "select",
    "selectors",
    "shelve",
    "shlex",
    "shutil",
    "signal",
    "site",
    "smtpd",
    "smtplib",
    "sndhdr",
    "socket",
    "socketserver",
    "spwd",
    "sqlite3",
    "ssl",
    "stat",
    "statistics",
    "string",
    "stringprep",
    "struct",
    "subprocess",
    "sunau",
    "symtable",
    "sys",
    "sysconfig",
    "syslog",
    "tabnanny",
    "tarfile",
    "telnetlib",
    "tempfile",
    "termios",
    "textwrap",
    "threading",
    "time",
    "timeit",
    "tkinter",
    "token",
    "tokenize",
    "tomllib",
    "trace",
    "traceback",
    "tracemalloc",
    "tty",
    "turtle",
    "types",
    "typing",
    "unicodedata",
    "unittest",
    "urllib",
    "uu",
    "uuid",
    "venv",
    "warnings",
    "wave",
    "weakref",
    "webbrowser",
    "winreg",
    "winsound",
    "wsgiref",
    "xdrlib",
    "xml",
    "xmlrpc",
    "zipapp",
    "zipfile",
    "zipimport",
    "zlib",
    "zoneinfo",
];
983
984pub fn classify_group_py(module_path: &str) -> ImportGroup {
986 if module_path.starts_with('.') {
988 return ImportGroup::Internal;
989 }
990 let top_module = module_path.split('.').next().unwrap_or(module_path);
992 if PYTHON_STDLIB.contains(&top_module) {
993 ImportGroup::Stdlib
994 } else {
995 ImportGroup::External
996 }
997}
998
999fn parse_py_imports(source: &str, tree: &Tree) -> ImportBlock {
1001 let root = tree.root_node();
1002 let mut imports = Vec::new();
1003
1004 let mut cursor = root.walk();
1005 if !cursor.goto_first_child() {
1006 return ImportBlock::empty();
1007 }
1008
1009 loop {
1010 let node = cursor.node();
1011 match node.kind() {
1012 "import_statement" => {
1013 if let Some(imp) = parse_py_import_statement(source, &node) {
1014 imports.push(imp);
1015 }
1016 }
1017 "import_from_statement" => {
1018 if let Some(imp) = parse_py_import_from_statement(source, &node) {
1019 imports.push(imp);
1020 }
1021 }
1022 _ => {}
1023 }
1024 if !cursor.goto_next_sibling() {
1025 break;
1026 }
1027 }
1028
1029 let byte_range = import_byte_range(&imports);
1030
1031 ImportBlock {
1032 imports,
1033 byte_range,
1034 }
1035}
1036
1037fn parse_py_import_statement(source: &str, node: &Node) -> Option<ImportStatement> {
1039 let raw_text = source[node.byte_range()].to_string();
1040 let byte_range = node.byte_range();
1041
1042 let mut module_path = String::new();
1044 let mut c = node.walk();
1045 if c.goto_first_child() {
1046 loop {
1047 if c.node().kind() == "dotted_name" {
1048 module_path = source[c.node().byte_range()].to_string();
1049 break;
1050 }
1051 if !c.goto_next_sibling() {
1052 break;
1053 }
1054 }
1055 }
1056 if module_path.is_empty() {
1057 return None;
1058 }
1059
1060 let group = classify_group_py(&module_path);
1061
1062 Some(ImportStatement {
1063 module_path,
1064 names: Vec::new(),
1065 default_import: None,
1066 namespace_import: None,
1067 kind: ImportKind::Value,
1068 group,
1069 byte_range,
1070 raw_text,
1071 })
1072}
1073
1074fn parse_py_import_from_statement(source: &str, node: &Node) -> Option<ImportStatement> {
1076 let raw_text = source[node.byte_range()].to_string();
1077 let byte_range = node.byte_range();
1078
1079 let mut module_path = String::new();
1080 let mut names = Vec::new();
1081
1082 let mut c = node.walk();
1083 if c.goto_first_child() {
1084 loop {
1085 let child = c.node();
1086 match child.kind() {
1087 "dotted_name" => {
1088 if module_path.is_empty()
1093 && !has_seen_import_keyword(source, node, child.start_byte())
1094 {
1095 module_path = source[child.byte_range()].to_string();
1096 } else {
1097 names.push(source[child.byte_range()].to_string());
1099 }
1100 }
1101 "relative_import" => {
1102 module_path = source[child.byte_range()].to_string();
1104 }
1105 _ => {}
1106 }
1107 if !c.goto_next_sibling() {
1108 break;
1109 }
1110 }
1111 }
1112
1113 if module_path.is_empty() {
1115 return None;
1116 }
1117
1118 let group = classify_group_py(&module_path);
1119
1120 Some(ImportStatement {
1121 module_path,
1122 names,
1123 default_import: None,
1124 namespace_import: None,
1125 kind: ImportKind::Value,
1126 group,
1127 byte_range,
1128 raw_text,
1129 })
1130}
1131
1132fn has_seen_import_keyword(_source: &str, parent: &Node, before_byte: usize) -> bool {
1134 let mut c = parent.walk();
1135 if c.goto_first_child() {
1136 loop {
1137 let child = c.node();
1138 if child.kind() == "import" && child.start_byte() < before_byte {
1139 return true;
1140 }
1141 if child.start_byte() >= before_byte {
1142 return false;
1143 }
1144 if !c.goto_next_sibling() {
1145 break;
1146 }
1147 }
1148 }
1149 false
1150}
1151
/// Renders a Python import line.
///
/// Produces `import module` when `names` is empty, otherwise
/// `from module import a, b` with the names sorted. `_default_import` is
/// accepted for signature parity and ignored.
fn generate_py_import_line(
    module_path: &str,
    names: &[String],
    _default_import: Option<&str>,
) -> String {
    if !names.is_empty() {
        let mut ordered: Vec<&str> = names.iter().map(String::as_str).collect();
        ordered.sort_unstable();
        let joined = ordered.join(", ");
        return format!("from {module_path} import {joined}");
    }
    format!("import {module_path}")
}
1169
1170pub fn classify_group_rs(module_path: &str) -> ImportGroup {
1176 let first_seg = module_path.split("::").next().unwrap_or(module_path);
1178 match first_seg {
1179 "std" | "core" | "alloc" => ImportGroup::Stdlib,
1180 "crate" | "self" | "super" => ImportGroup::Internal,
1181 _ => ImportGroup::External,
1182 }
1183}
1184
1185fn parse_rs_imports(source: &str, tree: &Tree) -> ImportBlock {
1187 let root = tree.root_node();
1188 let mut imports = Vec::new();
1189
1190 let mut cursor = root.walk();
1191 if !cursor.goto_first_child() {
1192 return ImportBlock::empty();
1193 }
1194
1195 loop {
1196 let node = cursor.node();
1197 if node.kind() == "use_declaration" {
1198 if let Some(imp) = parse_rs_use_declaration(source, &node) {
1199 imports.push(imp);
1200 }
1201 }
1202 if !cursor.goto_next_sibling() {
1203 break;
1204 }
1205 }
1206
1207 let byte_range = import_byte_range(&imports);
1208
1209 ImportBlock {
1210 imports,
1211 byte_range,
1212 }
1213}
1214
1215fn parse_rs_use_declaration(source: &str, node: &Node) -> Option<ImportStatement> {
1217 let raw_text = source[node.byte_range()].to_string();
1218 let byte_range = node.byte_range();
1219
1220 let mut has_pub = false;
1222 let mut use_path = String::new();
1223 let mut names = Vec::new();
1224
1225 let mut c = node.walk();
1226 if c.goto_first_child() {
1227 loop {
1228 let child = c.node();
1229 match child.kind() {
1230 "visibility_modifier" => {
1231 has_pub = true;
1232 }
1233 "scoped_identifier" | "identifier" | "use_as_clause" => {
1234 use_path = source[child.byte_range()].to_string();
1236 }
1237 "scoped_use_list" => {
1238 use_path = source[child.byte_range()].to_string();
1240 extract_rs_use_list_names(source, &child, &mut names);
1242 }
1243 _ => {}
1244 }
1245 if !c.goto_next_sibling() {
1246 break;
1247 }
1248 }
1249 }
1250
1251 if use_path.is_empty() {
1252 return None;
1253 }
1254
1255 let group = classify_group_rs(&use_path);
1256
1257 Some(ImportStatement {
1258 module_path: use_path,
1259 names,
1260 default_import: if has_pub {
1261 Some("pub".to_string())
1262 } else {
1263 None
1264 },
1265 namespace_import: None,
1266 kind: ImportKind::Value,
1267 group,
1268 byte_range,
1269 raw_text,
1270 })
1271}
1272
1273fn extract_rs_use_list_names(source: &str, node: &Node, names: &mut Vec<String>) {
1275 let mut c = node.walk();
1276 if c.goto_first_child() {
1277 loop {
1278 let child = c.node();
1279 if child.kind() == "use_list" {
1280 let mut lc = child.walk();
1282 if lc.goto_first_child() {
1283 loop {
1284 let lchild = lc.node();
1285 if lchild.kind() == "identifier" || lchild.kind() == "scoped_identifier" {
1286 names.push(source[lchild.byte_range()].to_string());
1287 }
1288 if !lc.goto_next_sibling() {
1289 break;
1290 }
1291 }
1292 }
1293 }
1294 if !c.goto_next_sibling() {
1295 break;
1296 }
1297 }
1298 }
1299}
1300
/// Renders `use {module_path};`.
///
/// NOTE(review): `names` and `_type_only` do not affect the output, so callers
/// presumably pass the complete use path (e.g. `std::collections::HashMap` or
/// `std::{fmt, io}`) as `module_path` — confirm against callers.
fn generate_rs_import_line(module_path: &str, names: &[String], _type_only: bool) -> String {
    let _ = names;
    format!("use {module_path};")
}
1312
1313pub fn classify_group_go(module_path: &str) -> ImportGroup {
1319 if module_path.contains('.') {
1322 ImportGroup::External
1323 } else {
1324 ImportGroup::Stdlib
1325 }
1326}
1327
1328fn parse_go_imports(source: &str, tree: &Tree) -> ImportBlock {
1330 let root = tree.root_node();
1331 let mut imports = Vec::new();
1332
1333 let mut cursor = root.walk();
1334 if !cursor.goto_first_child() {
1335 return ImportBlock::empty();
1336 }
1337
1338 loop {
1339 let node = cursor.node();
1340 if node.kind() == "import_declaration" {
1341 parse_go_import_declaration(source, &node, &mut imports);
1342 }
1343 if !cursor.goto_next_sibling() {
1344 break;
1345 }
1346 }
1347
1348 let byte_range = import_byte_range(&imports);
1349
1350 ImportBlock {
1351 imports,
1352 byte_range,
1353 }
1354}
1355
1356fn parse_go_import_declaration(source: &str, node: &Node, imports: &mut Vec<ImportStatement>) {
1358 let mut c = node.walk();
1359 if c.goto_first_child() {
1360 loop {
1361 let child = c.node();
1362 match child.kind() {
1363 "import_spec" => {
1364 if let Some(imp) = parse_go_import_spec(source, &child) {
1365 imports.push(imp);
1366 }
1367 }
1368 "import_spec_list" => {
1369 let mut lc = child.walk();
1371 if lc.goto_first_child() {
1372 loop {
1373 if lc.node().kind() == "import_spec" {
1374 if let Some(imp) = parse_go_import_spec(source, &lc.node()) {
1375 imports.push(imp);
1376 }
1377 }
1378 if !lc.goto_next_sibling() {
1379 break;
1380 }
1381 }
1382 }
1383 }
1384 _ => {}
1385 }
1386 if !c.goto_next_sibling() {
1387 break;
1388 }
1389 }
1390 }
1391}
1392
1393fn parse_go_import_spec(source: &str, node: &Node) -> Option<ImportStatement> {
1395 let raw_text = source[node.byte_range()].to_string();
1396 let byte_range = node.byte_range();
1397
1398 let mut import_path = String::new();
1399 let mut alias = None;
1400
1401 let mut c = node.walk();
1402 if c.goto_first_child() {
1403 loop {
1404 let child = c.node();
1405 match child.kind() {
1406 "interpreted_string_literal" => {
1407 let text = source[child.byte_range()].to_string();
1409 import_path = text.trim_matches('"').to_string();
1410 }
1411 "identifier" | "blank_identifier" | "dot" => {
1412 alias = Some(source[child.byte_range()].to_string());
1414 }
1415 _ => {}
1416 }
1417 if !c.goto_next_sibling() {
1418 break;
1419 }
1420 }
1421 }
1422
1423 if import_path.is_empty() {
1424 return None;
1425 }
1426
1427 let group = classify_group_go(&import_path);
1428
1429 Some(ImportStatement {
1430 module_path: import_path,
1431 names: Vec::new(),
1432 default_import: alias,
1433 namespace_import: None,
1434 kind: ImportKind::Value,
1435 group,
1436 byte_range,
1437 raw_text,
1438 })
1439}
1440
1441pub fn generate_go_import_line_pub(
1443 module_path: &str,
1444 alias: Option<&str>,
1445 in_group: bool,
1446) -> String {
1447 generate_go_import_line(module_path, alias, in_group)
1448}
1449
/// Renders a Go import line.
///
/// Inside a group (`in_group`) the line is tab-indented without the `import`
/// keyword; otherwise it is a stand-alone `import` statement. The optional
/// `alias` precedes the double-quoted path.
fn generate_go_import_line(module_path: &str, alias: Option<&str>, in_group: bool) -> String {
    match (in_group, alias) {
        (true, Some(a)) => format!("\t{a} \"{module_path}\""),
        (true, None) => format!("\t\"{module_path}\""),
        (false, Some(a)) => format!("import {a} \"{module_path}\""),
        (false, None) => format!("import \"{module_path}\""),
    }
}
1469
1470pub fn go_has_grouped_import(_source: &str, tree: &Tree) -> Option<Range<usize>> {
1473 let root = tree.root_node();
1474 let mut cursor = root.walk();
1475 if !cursor.goto_first_child() {
1476 return None;
1477 }
1478
1479 loop {
1480 let node = cursor.node();
1481 if node.kind() == "import_declaration" {
1482 let mut c = node.walk();
1483 if c.goto_first_child() {
1484 loop {
1485 if c.node().kind() == "import_spec_list" {
1486 return Some(node.byte_range());
1487 }
1488 if !c.goto_next_sibling() {
1489 break;
1490 }
1491 }
1492 }
1493 }
1494 if !cursor.goto_next_sibling() {
1495 break;
1496 }
1497 }
1498 None
1499}
1500
/// Advances `pos` past a single line break (`\n`, `\r\n`, or a lone `\r`)
/// starting at `pos`.
///
/// Returns `pos` unchanged if there is no line break there or `pos` is at or
/// beyond the end of `source`.
fn skip_newline(source: &str, pos: usize) -> usize {
    let rest = source.as_bytes().get(pos..).unwrap_or_default();
    match rest {
        [b'\r', b'\n', ..] => pos + 2,
        [b'\n', ..] | [b'\r', ..] => pos + 1,
        _ => pos,
    }
}
1517
1518#[cfg(test)]
1523mod tests {
1524 use super::*;
1525
1526 fn parse_ts(source: &str) -> (Tree, ImportBlock) {
1527 let grammar = grammar_for(LangId::TypeScript);
1528 let mut parser = Parser::new();
1529 parser.set_language(&grammar).unwrap();
1530 let tree = parser.parse(source, None).unwrap();
1531 let block = parse_imports(source, &tree, LangId::TypeScript);
1532 (tree, block)
1533 }
1534
1535 fn parse_js(source: &str) -> (Tree, ImportBlock) {
1536 let grammar = grammar_for(LangId::JavaScript);
1537 let mut parser = Parser::new();
1538 parser.set_language(&grammar).unwrap();
1539 let tree = parser.parse(source, None).unwrap();
1540 let block = parse_imports(source, &tree, LangId::JavaScript);
1541 (tree, block)
1542 }
1543
1544 #[test]
1547 fn parse_ts_named_imports() {
1548 let source = "import { useState, useEffect } from 'react';\n";
1549 let (_, block) = parse_ts(source);
1550 assert_eq!(block.imports.len(), 1);
1551 let imp = &block.imports[0];
1552 assert_eq!(imp.module_path, "react");
1553 assert!(imp.names.contains(&"useState".to_string()));
1554 assert!(imp.names.contains(&"useEffect".to_string()));
1555 assert_eq!(imp.kind, ImportKind::Value);
1556 assert_eq!(imp.group, ImportGroup::External);
1557 }
1558
1559 #[test]
1560 fn parse_ts_default_import() {
1561 let source = "import React from 'react';\n";
1562 let (_, block) = parse_ts(source);
1563 assert_eq!(block.imports.len(), 1);
1564 let imp = &block.imports[0];
1565 assert_eq!(imp.default_import.as_deref(), Some("React"));
1566 assert_eq!(imp.kind, ImportKind::Value);
1567 }
1568
1569 #[test]
1570 fn parse_ts_side_effect_import() {
1571 let source = "import './styles.css';\n";
1572 let (_, block) = parse_ts(source);
1573 assert_eq!(block.imports.len(), 1);
1574 assert_eq!(block.imports[0].kind, ImportKind::SideEffect);
1575 assert_eq!(block.imports[0].module_path, "./styles.css");
1576 }
1577
1578 #[test]
1579 fn parse_ts_relative_import() {
1580 let source = "import { helper } from './utils';\n";
1581 let (_, block) = parse_ts(source);
1582 assert_eq!(block.imports.len(), 1);
1583 assert_eq!(block.imports[0].group, ImportGroup::Internal);
1584 }
1585
1586 #[test]
1587 fn parse_ts_multiple_groups() {
1588 let source = "\
1589import React from 'react';
1590import { useState } from 'react';
1591import { helper } from './utils';
1592import { Config } from '../config';
1593";
1594 let (_, block) = parse_ts(source);
1595 assert_eq!(block.imports.len(), 4);
1596
1597 let external: Vec<_> = block
1598 .imports
1599 .iter()
1600 .filter(|i| i.group == ImportGroup::External)
1601 .collect();
1602 let relative: Vec<_> = block
1603 .imports
1604 .iter()
1605 .filter(|i| i.group == ImportGroup::Internal)
1606 .collect();
1607 assert_eq!(external.len(), 2);
1608 assert_eq!(relative.len(), 2);
1609 }
1610
1611 #[test]
1612 fn parse_ts_namespace_import() {
1613 let source = "import * as path from 'path';\n";
1614 let (_, block) = parse_ts(source);
1615 assert_eq!(block.imports.len(), 1);
1616 let imp = &block.imports[0];
1617 assert_eq!(imp.namespace_import.as_deref(), Some("path"));
1618 assert_eq!(imp.kind, ImportKind::Value);
1619 }
1620
1621 #[test]
1622 fn parse_js_imports() {
1623 let source = "import { readFile } from 'fs';\nimport { helper } from './helper';\n";
1624 let (_, block) = parse_js(source);
1625 assert_eq!(block.imports.len(), 2);
1626 assert_eq!(block.imports[0].group, ImportGroup::External);
1627 assert_eq!(block.imports[1].group, ImportGroup::Internal);
1628 }
1629
1630 #[test]
1633 fn classify_external() {
1634 assert_eq!(classify_group_ts("react"), ImportGroup::External);
1635 assert_eq!(classify_group_ts("@scope/pkg"), ImportGroup::External);
1636 assert_eq!(classify_group_ts("lodash/map"), ImportGroup::External);
1637 }
1638
1639 #[test]
1640 fn classify_relative() {
1641 assert_eq!(classify_group_ts("./utils"), ImportGroup::Internal);
1642 assert_eq!(classify_group_ts("../config"), ImportGroup::Internal);
1643 assert_eq!(classify_group_ts("./"), ImportGroup::Internal);
1644 }
1645
1646 #[test]
1649 fn dedup_detects_same_named_import() {
1650 let source = "import { useState } from 'react';\n";
1651 let (_, block) = parse_ts(source);
1652 assert!(is_duplicate(
1653 &block,
1654 "react",
1655 &["useState".to_string()],
1656 None,
1657 false
1658 ));
1659 }
1660
1661 #[test]
1662 fn dedup_misses_different_name() {
1663 let source = "import { useState } from 'react';\n";
1664 let (_, block) = parse_ts(source);
1665 assert!(!is_duplicate(
1666 &block,
1667 "react",
1668 &["useEffect".to_string()],
1669 None,
1670 false
1671 ));
1672 }
1673
1674 #[test]
1675 fn dedup_detects_default_import() {
1676 let source = "import React from 'react';\n";
1677 let (_, block) = parse_ts(source);
1678 assert!(is_duplicate(&block, "react", &[], Some("React"), false));
1679 }
1680
1681 #[test]
1682 fn dedup_side_effect() {
1683 let source = "import './styles.css';\n";
1684 let (_, block) = parse_ts(source);
1685 assert!(is_duplicate(&block, "./styles.css", &[], None, false));
1686 }
1687
1688 #[test]
1689 fn dedup_namespace_import_distinct_from_side_effect_import() {
1690 let side_effect_source = "import 'fs';\n";
1691 let (_, side_effect_block) = parse_ts(side_effect_source);
1692 assert!(!is_duplicate_with_namespace(
1693 &side_effect_block,
1694 "fs",
1695 &[],
1696 None,
1697 Some("fs"),
1698 false
1699 ));
1700
1701 let namespace_source = "import * as fs from 'fs';\n";
1702 let (_, namespace_block) = parse_ts(namespace_source);
1703 assert!(!is_duplicate(&namespace_block, "fs", &[], None, false));
1704 assert!(is_duplicate_with_namespace(
1705 &namespace_block,
1706 "fs",
1707 &[],
1708 None,
1709 Some("fs"),
1710 false
1711 ));
1712 assert!(!is_duplicate_with_namespace(
1713 &namespace_block,
1714 "fs",
1715 &[],
1716 None,
1717 Some("other"),
1718 false
1719 ));
1720 }
1721
1722 #[test]
1723 fn dedup_type_vs_value() {
1724 let source = "import { FC } from 'react';\n";
1725 let (_, block) = parse_ts(source);
1726 assert!(!is_duplicate(
1728 &block,
1729 "react",
1730 &["FC".to_string()],
1731 None,
1732 true
1733 ));
1734 }
1735
1736 #[test]
1739 fn generate_named_import() {
1740 let line = generate_import_line(
1741 LangId::TypeScript,
1742 "react",
1743 &["useState".to_string(), "useEffect".to_string()],
1744 None,
1745 false,
1746 );
1747 assert_eq!(line, "import { useEffect, useState } from 'react';");
1748 }
1749
1750 #[test]
1751 fn generate_named_import_sorts_by_imported_name() {
1752 let line = generate_import_line(
1753 LangId::TypeScript,
1754 "x",
1755 &[
1756 "useState".to_string(),
1757 "type Foo".to_string(),
1758 "stdin as input".to_string(),
1759 "type Bar".to_string(),
1760 ],
1761 None,
1762 false,
1763 );
1764 assert_eq!(
1765 line,
1766 "import { type Bar, type Foo, stdin as input, useState } from 'x';"
1767 );
1768 }
1769
1770 #[test]
1771 fn generate_default_import() {
1772 let line = generate_import_line(LangId::TypeScript, "react", &[], Some("React"), false);
1773 assert_eq!(line, "import React from 'react';");
1774 }
1775
1776 #[test]
1777 fn generate_type_import() {
1778 let line =
1779 generate_import_line(LangId::TypeScript, "react", &["FC".to_string()], None, true);
1780 assert_eq!(line, "import type { FC } from 'react';");
1781 }
1782
1783 #[test]
1784 fn generate_side_effect_import() {
1785 let line = generate_import_line(LangId::TypeScript, "./styles.css", &[], None, false);
1786 assert_eq!(line, "import './styles.css';");
1787 }
1788
1789 #[test]
1790 fn generate_default_and_named() {
1791 let line = generate_import_line(
1792 LangId::TypeScript,
1793 "react",
1794 &["useState".to_string()],
1795 Some("React"),
1796 false,
1797 );
1798 assert_eq!(line, "import React, { useState } from 'react';");
1799 }
1800
1801 #[test]
1802 fn parse_ts_type_import() {
1803 let source = "import type { FC } from 'react';\n";
1804 let (_, block) = parse_ts(source);
1805 assert_eq!(block.imports.len(), 1);
1806 let imp = &block.imports[0];
1807 assert_eq!(imp.kind, ImportKind::Type);
1808 assert!(imp.names.contains(&"FC".to_string()));
1809 assert_eq!(imp.group, ImportGroup::External);
1810 }
1811
1812 #[test]
1815 fn insertion_empty_file() {
1816 let source = "";
1817 let (_, block) = parse_ts(source);
1818 let (offset, _, _) =
1819 find_insertion_point(source, &block, ImportGroup::External, "react", false);
1820 assert_eq!(offset, 0);
1821 }
1822
1823 #[test]
1824 fn insertion_alphabetical_within_group() {
1825 let source = "\
1826import { a } from 'alpha';
1827import { c } from 'charlie';
1828";
1829 let (_, block) = parse_ts(source);
1830 let (offset, _, _) =
1831 find_insertion_point(source, &block, ImportGroup::External, "bravo", false);
1832 let before_charlie = source.find("import { c }").unwrap();
1834 assert_eq!(offset, before_charlie);
1835 }
1836
1837 fn parse_py(source: &str) -> (Tree, ImportBlock) {
1840 let grammar = grammar_for(LangId::Python);
1841 let mut parser = Parser::new();
1842 parser.set_language(&grammar).unwrap();
1843 let tree = parser.parse(source, None).unwrap();
1844 let block = parse_imports(source, &tree, LangId::Python);
1845 (tree, block)
1846 }
1847
1848 #[test]
1849 fn parse_py_import_statement() {
1850 let source = "import os\nimport sys\n";
1851 let (_, block) = parse_py(source);
1852 assert_eq!(block.imports.len(), 2);
1853 assert_eq!(block.imports[0].module_path, "os");
1854 assert_eq!(block.imports[1].module_path, "sys");
1855 assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
1856 }
1857
1858 #[test]
1859 fn parse_py_from_import() {
1860 let source = "from collections import OrderedDict\nfrom typing import List, Optional\n";
1861 let (_, block) = parse_py(source);
1862 assert_eq!(block.imports.len(), 2);
1863 assert_eq!(block.imports[0].module_path, "collections");
1864 assert!(block.imports[0].names.contains(&"OrderedDict".to_string()));
1865 assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
1866 assert_eq!(block.imports[1].module_path, "typing");
1867 assert!(block.imports[1].names.contains(&"List".to_string()));
1868 assert!(block.imports[1].names.contains(&"Optional".to_string()));
1869 }
1870
1871 #[test]
1872 fn parse_py_relative_import() {
1873 let source = "from . import utils\nfrom ..config import Settings\n";
1874 let (_, block) = parse_py(source);
1875 assert_eq!(block.imports.len(), 2);
1876 assert_eq!(block.imports[0].module_path, ".");
1877 assert!(block.imports[0].names.contains(&"utils".to_string()));
1878 assert_eq!(block.imports[0].group, ImportGroup::Internal);
1879 assert_eq!(block.imports[1].module_path, "..config");
1880 assert_eq!(block.imports[1].group, ImportGroup::Internal);
1881 }
1882
1883 #[test]
1884 fn classify_py_groups() {
1885 assert_eq!(classify_group_py("os"), ImportGroup::Stdlib);
1886 assert_eq!(classify_group_py("sys"), ImportGroup::Stdlib);
1887 assert_eq!(classify_group_py("json"), ImportGroup::Stdlib);
1888 assert_eq!(classify_group_py("collections"), ImportGroup::Stdlib);
1889 assert_eq!(classify_group_py("os.path"), ImportGroup::Stdlib);
1890 assert_eq!(classify_group_py("requests"), ImportGroup::External);
1891 assert_eq!(classify_group_py("flask"), ImportGroup::External);
1892 assert_eq!(classify_group_py("."), ImportGroup::Internal);
1893 assert_eq!(classify_group_py("..config"), ImportGroup::Internal);
1894 assert_eq!(classify_group_py(".utils"), ImportGroup::Internal);
1895 }
1896
1897 #[test]
1898 fn parse_py_three_groups() {
1899 let source = "import os\nimport sys\n\nimport requests\n\nfrom . import utils\n";
1900 let (_, block) = parse_py(source);
1901 let stdlib: Vec<_> = block
1902 .imports
1903 .iter()
1904 .filter(|i| i.group == ImportGroup::Stdlib)
1905 .collect();
1906 let external: Vec<_> = block
1907 .imports
1908 .iter()
1909 .filter(|i| i.group == ImportGroup::External)
1910 .collect();
1911 let internal: Vec<_> = block
1912 .imports
1913 .iter()
1914 .filter(|i| i.group == ImportGroup::Internal)
1915 .collect();
1916 assert_eq!(stdlib.len(), 2);
1917 assert_eq!(external.len(), 1);
1918 assert_eq!(internal.len(), 1);
1919 }
1920
1921 #[test]
1922 fn generate_py_import() {
1923 let line = generate_import_line(LangId::Python, "os", &[], None, false);
1924 assert_eq!(line, "import os");
1925 }
1926
1927 #[test]
1928 fn generate_py_from_import() {
1929 let line = generate_import_line(
1930 LangId::Python,
1931 "collections",
1932 &["OrderedDict".to_string()],
1933 None,
1934 false,
1935 );
1936 assert_eq!(line, "from collections import OrderedDict");
1937 }
1938
1939 #[test]
1940 fn generate_py_from_import_multiple() {
1941 let line = generate_import_line(
1942 LangId::Python,
1943 "typing",
1944 &["Optional".to_string(), "List".to_string()],
1945 None,
1946 false,
1947 );
1948 assert_eq!(line, "from typing import List, Optional");
1949 }
1950
1951 fn parse_rust(source: &str) -> (Tree, ImportBlock) {
1954 let grammar = grammar_for(LangId::Rust);
1955 let mut parser = Parser::new();
1956 parser.set_language(&grammar).unwrap();
1957 let tree = parser.parse(source, None).unwrap();
1958 let block = parse_imports(source, &tree, LangId::Rust);
1959 (tree, block)
1960 }
1961
1962 #[test]
1963 fn parse_rs_use_std() {
1964 let source = "use std::collections::HashMap;\nuse std::io::Read;\n";
1965 let (_, block) = parse_rust(source);
1966 assert_eq!(block.imports.len(), 2);
1967 assert_eq!(block.imports[0].module_path, "std::collections::HashMap");
1968 assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
1969 assert_eq!(block.imports[1].group, ImportGroup::Stdlib);
1970 }
1971
1972 #[test]
1973 fn parse_rs_use_external() {
1974 let source = "use serde::{Deserialize, Serialize};\n";
1975 let (_, block) = parse_rust(source);
1976 assert_eq!(block.imports.len(), 1);
1977 assert_eq!(block.imports[0].group, ImportGroup::External);
1978 assert!(block.imports[0].names.contains(&"Deserialize".to_string()));
1979 assert!(block.imports[0].names.contains(&"Serialize".to_string()));
1980 }
1981
1982 #[test]
1983 fn parse_rs_use_crate() {
1984 let source = "use crate::config::Settings;\nuse super::parent::Thing;\n";
1985 let (_, block) = parse_rust(source);
1986 assert_eq!(block.imports.len(), 2);
1987 assert_eq!(block.imports[0].group, ImportGroup::Internal);
1988 assert_eq!(block.imports[1].group, ImportGroup::Internal);
1989 }
1990
1991 #[test]
1992 fn parse_rs_pub_use() {
1993 let source = "pub use super::parent::Thing;\n";
1994 let (_, block) = parse_rust(source);
1995 assert_eq!(block.imports.len(), 1);
1996 assert_eq!(block.imports[0].default_import.as_deref(), Some("pub"));
1998 }
1999
2000 #[test]
2001 fn classify_rs_groups() {
2002 assert_eq!(
2003 classify_group_rs("std::collections::HashMap"),
2004 ImportGroup::Stdlib
2005 );
2006 assert_eq!(classify_group_rs("core::mem"), ImportGroup::Stdlib);
2007 assert_eq!(classify_group_rs("alloc::vec"), ImportGroup::Stdlib);
2008 assert_eq!(
2009 classify_group_rs("serde::Deserialize"),
2010 ImportGroup::External
2011 );
2012 assert_eq!(classify_group_rs("tokio::runtime"), ImportGroup::External);
2013 assert_eq!(classify_group_rs("crate::config"), ImportGroup::Internal);
2014 assert_eq!(classify_group_rs("self::utils"), ImportGroup::Internal);
2015 assert_eq!(classify_group_rs("super::parent"), ImportGroup::Internal);
2016 }
2017
2018 #[test]
2019 fn generate_rs_use() {
2020 let line = generate_import_line(LangId::Rust, "std::fmt::Display", &[], None, false);
2021 assert_eq!(line, "use std::fmt::Display;");
2022 }
2023
2024 fn parse_go(source: &str) -> (Tree, ImportBlock) {
2027 let grammar = grammar_for(LangId::Go);
2028 let mut parser = Parser::new();
2029 parser.set_language(&grammar).unwrap();
2030 let tree = parser.parse(source, None).unwrap();
2031 let block = parse_imports(source, &tree, LangId::Go);
2032 (tree, block)
2033 }
2034
2035 #[test]
2036 fn parse_go_single_import() {
2037 let source = "package main\n\nimport \"fmt\"\n";
2038 let (_, block) = parse_go(source);
2039 assert_eq!(block.imports.len(), 1);
2040 assert_eq!(block.imports[0].module_path, "fmt");
2041 assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
2042 }
2043
2044 #[test]
2045 fn parse_go_grouped_import() {
2046 let source =
2047 "package main\n\nimport (\n\t\"fmt\"\n\t\"os\"\n\n\t\"github.com/pkg/errors\"\n)\n";
2048 let (_, block) = parse_go(source);
2049 assert_eq!(block.imports.len(), 3);
2050 assert_eq!(block.imports[0].module_path, "fmt");
2051 assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
2052 assert_eq!(block.imports[1].module_path, "os");
2053 assert_eq!(block.imports[1].group, ImportGroup::Stdlib);
2054 assert_eq!(block.imports[2].module_path, "github.com/pkg/errors");
2055 assert_eq!(block.imports[2].group, ImportGroup::External);
2056 }
2057
2058 #[test]
2059 fn parse_go_mixed_imports() {
2060 let source = "package main\n\nimport \"fmt\"\n\nimport (\n\t\"os\"\n\t\"github.com/pkg/errors\"\n)\n";
2062 let (_, block) = parse_go(source);
2063 assert_eq!(block.imports.len(), 3);
2064 }
2065
2066 #[test]
2067 fn classify_go_groups() {
2068 assert_eq!(classify_group_go("fmt"), ImportGroup::Stdlib);
2069 assert_eq!(classify_group_go("os"), ImportGroup::Stdlib);
2070 assert_eq!(classify_group_go("net/http"), ImportGroup::Stdlib);
2071 assert_eq!(classify_group_go("encoding/json"), ImportGroup::Stdlib);
2072 assert_eq!(
2073 classify_group_go("github.com/pkg/errors"),
2074 ImportGroup::External
2075 );
2076 assert_eq!(
2077 classify_group_go("golang.org/x/tools"),
2078 ImportGroup::External
2079 );
2080 }
2081
2082 #[test]
2083 fn generate_go_standalone() {
2084 let line = generate_go_import_line("fmt", None, false);
2085 assert_eq!(line, "import \"fmt\"");
2086 }
2087
2088 #[test]
2089 fn generate_go_grouped_spec() {
2090 let line = generate_go_import_line("fmt", None, true);
2091 assert_eq!(line, "\t\"fmt\"");
2092 }
2093
2094 #[test]
2095 fn generate_go_with_alias() {
2096 let line = generate_go_import_line("github.com/pkg/errors", Some("errs"), false);
2097 assert_eq!(line, "import errs \"github.com/pkg/errors\"");
2098 }
2099}