1use std::ops::Range;
10
11use tree_sitter::{Node, Parser, Tree};
12
13use crate::parser::{grammar_for, LangId};
14
/// How an import statement is classified by what it binds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ImportKind {
    /// Ordinary import. The Python, Rust and Go parsers in this file always
    /// produce this kind.
    Value,
    /// Type-only import (a TS statement carrying a `type` keyword child and
    /// at least one binding).
    Type,
    /// Import with no named, default, or namespace binding
    /// (e.g. `import './styles.css'`).
    SideEffect,
}
29
/// Coarse origin of an imported module, used to group import statements.
///
/// The derived `Ord` follows declaration order
/// (`Stdlib < External < Internal`); insertion-point logic compares groups
/// with `<` / `>`, so the variant order is significant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ImportGroup {
    /// Module shipped with the language (e.g. Rust `std`, Go paths without a dot).
    Stdlib,
    /// Third-party module.
    External,
    /// Module from the current project (relative paths, Rust `crate`/`self`/`super`).
    Internal,
}
50
51impl ImportGroup {
52 pub fn label(&self) -> &'static str {
54 match self {
55 ImportGroup::Stdlib => "stdlib",
56 ImportGroup::External => "external",
57 ImportGroup::Internal => "internal",
58 }
59 }
60}
61
/// One parsed import statement, normalised across the supported languages.
#[derive(Debug, Clone)]
pub struct ImportStatement {
    /// Module specifier. TS/JS and Go store the string literal with its
    /// quotes trimmed; Python stores the dotted or relative module text;
    /// Rust stores the source text of the `use` path node.
    pub module_path: String,
    /// Named bindings. For TS/JS each entry is the raw specifier text, so it
    /// may carry a `type ` prefix and/or an ` as alias` suffix.
    pub names: Vec<String>,
    /// TS/JS default binding. The Rust parser stores `"pub"` here when the
    /// `use` has a visibility modifier, and the Go parser stores the alias.
    pub default_import: Option<String>,
    /// TS/JS namespace binding (the identifier in `* as ident`).
    pub namespace_import: Option<String>,
    /// Value / type-only / side-effect classification.
    pub kind: ImportKind,
    /// Stdlib / external / internal classification of `module_path`.
    pub group: ImportGroup,
    /// Byte range of the statement's syntax node within the source text.
    pub byte_range: Range<usize>,
    /// Source text covered by `byte_range`.
    pub raw_text: String,
}
82
/// All import statements found in one file.
#[derive(Debug, Clone)]
pub struct ImportBlock {
    /// Parsed statements, in the order the parsers encountered them.
    pub imports: Vec<ImportStatement>,
    /// Span from the start of the first import to the end of the last one;
    /// `None` when the file has no imports.
    pub byte_range: Option<Range<usize>>,
}
92
93impl ImportBlock {
94 pub fn empty() -> Self {
95 ImportBlock {
96 imports: Vec::new(),
97 byte_range: None,
98 }
99 }
100}
101
102fn import_byte_range(imports: &[ImportStatement]) -> Option<Range<usize>> {
103 imports.first().zip(imports.last()).map(|(first, last)| {
104 let start = first.byte_range.start;
105 let end = last.byte_range.end;
106 start..end
107 })
108}
109
/// Returns the name a specifier introduces into local scope.
///
/// Surrounding whitespace and a leading `type ` marker are ignored. For an
/// aliased specifier (`Foo as Bar`) the text after the first ` as ` is
/// returned; otherwise the remaining text is returned as-is.
pub fn specifier_local_name(spec: &str) -> &str {
    let spec = spec.trim();
    let body = match spec.strip_prefix("type ") {
        Some(rest) => rest.trim_start(),
        None => spec,
    };
    match body.split_once(" as ") {
        Some((_, alias)) => alias.trim(),
        None => body,
    }
}
137
/// Returns the name a specifier refers to in the exporting module.
///
/// Surrounding whitespace and a leading `type ` marker are ignored. For an
/// aliased specifier (`Foo as Bar`) the text before the first ` as ` is
/// returned; otherwise the remaining text is returned as-is.
pub fn specifier_imported_name(spec: &str) -> &str {
    let spec = spec.trim();
    let body = match spec.strip_prefix("type ") {
        Some(rest) => rest.trim_start(),
        None => spec,
    };
    match body.split_once(" as ") {
        Some((exported, _)) => exported.trim(),
        None => body,
    }
}
157
158pub fn specifier_matches(spec: &str, target: &str) -> bool {
163 specifier_imported_name(spec) == target || specifier_local_name(spec) == target
164}
165
166pub fn parse_imports(source: &str, tree: &Tree, lang: LangId) -> ImportBlock {
172 match lang {
173 LangId::TypeScript | LangId::Tsx | LangId::JavaScript => parse_ts_imports(source, tree),
174 LangId::Python => parse_py_imports(source, tree),
175 LangId::Rust => parse_rs_imports(source, tree),
176 LangId::Go => parse_go_imports(source, tree),
177 LangId::C
178 | LangId::Cpp
179 | LangId::Zig
180 | LangId::CSharp
181 | LangId::Bash
182 | LangId::Solidity
183 | LangId::Vue
184 | LangId::Json
185 | LangId::Scala => ImportBlock::empty(),
186 LangId::Html | LangId::Markdown => ImportBlock::empty(),
187 }
188}
189
190pub fn is_duplicate(
196 block: &ImportBlock,
197 module_path: &str,
198 names: &[String],
199 default_import: Option<&str>,
200 type_only: bool,
201) -> bool {
202 is_duplicate_with_namespace(block, module_path, names, default_import, None, type_only)
203}
204
205pub fn is_duplicate_with_namespace(
207 block: &ImportBlock,
208 module_path: &str,
209 names: &[String],
210 default_import: Option<&str>,
211 namespace_import: Option<&str>,
212 type_only: bool,
213) -> bool {
214 let target_kind = if type_only {
215 ImportKind::Type
216 } else {
217 ImportKind::Value
218 };
219
220 for imp in &block.imports {
221 if imp.module_path != module_path {
222 continue;
223 }
224
225 if names.is_empty()
231 && default_import.is_none()
232 && namespace_import.is_none()
233 && imp.names.is_empty()
234 && imp.default_import.is_none()
235 && imp.namespace_import.is_none()
236 {
237 return true;
238 }
239
240 if names.is_empty()
242 && default_import.is_none()
243 && namespace_import.is_none()
244 && imp.kind == ImportKind::SideEffect
245 {
246 return true;
247 }
248
249 if names.is_empty()
250 && default_import.is_none()
251 && namespace_import.is_some()
252 && imp.names.is_empty()
253 && imp.default_import.is_none()
254 && imp.namespace_import.as_deref() == namespace_import
255 {
256 return true;
257 }
258
259 if imp.kind != target_kind && imp.kind != ImportKind::SideEffect {
261 continue;
262 }
263
264 if let Some(def) = default_import {
266 if imp.default_import.as_deref() == Some(def) && imp.namespace_import.is_none() {
267 return true;
268 }
269 }
270
271 if !names.is_empty()
277 && names
278 .iter()
279 .all(|n| imp.names.iter().any(|stored| specifier_matches(stored, n)))
280 {
281 return true;
282 }
283 }
284
285 false
286}
287
288fn sort_named_specifiers(names: &mut [String]) {
289 names.sort_by(|a, b| {
290 specifier_imported_name(a)
291 .cmp(specifier_imported_name(b))
292 .then_with(|| a.cmp(b))
293 });
294}
295
296pub fn find_insertion_point(
307 source: &str,
308 block: &ImportBlock,
309 group: ImportGroup,
310 module_path: &str,
311 type_only: bool,
312) -> (usize, bool, bool) {
313 if block.imports.is_empty() {
314 return (0, false, source.is_empty().then_some(false).unwrap_or(true));
316 }
317
318 let target_kind = if type_only {
319 ImportKind::Type
320 } else {
321 ImportKind::Value
322 };
323
324 let group_imports: Vec<&ImportStatement> =
326 block.imports.iter().filter(|i| i.group == group).collect();
327
328 if group_imports.is_empty() {
329 let preceding_last = block.imports.iter().filter(|i| i.group < group).last();
332
333 if let Some(last) = preceding_last {
334 let end = last.byte_range.end;
335 let insert_at = skip_newline(source, end);
336 return (insert_at, true, true);
337 }
338
339 let following_first = block.imports.iter().find(|i| i.group > group);
341
342 if let Some(first) = following_first {
343 return (first.byte_range.start, false, true);
344 }
345
346 let first_byte = import_byte_range(&block.imports)
348 .map(|range| range.start)
349 .unwrap_or(0);
350 return (first_byte, false, true);
351 }
352
353 for imp in &group_imports {
355 let cmp = module_path.cmp(&imp.module_path);
356 match cmp {
357 std::cmp::Ordering::Less => {
358 return (imp.byte_range.start, false, false);
360 }
361 std::cmp::Ordering::Equal => {
362 if target_kind == ImportKind::Type && imp.kind == ImportKind::Value {
364 let end = imp.byte_range.end;
366 let insert_at = skip_newline(source, end);
367 return (insert_at, false, false);
368 }
369 return (imp.byte_range.start, false, false);
371 }
372 std::cmp::Ordering::Greater => continue,
373 }
374 }
375
376 let Some(last) = group_imports.last() else {
378 return (
379 import_byte_range(&block.imports)
380 .map(|range| range.end)
381 .unwrap_or(0),
382 false,
383 false,
384 );
385 };
386 let end = last.byte_range.end;
387 let insert_at = skip_newline(source, end);
388 (insert_at, false, false)
389}
390
391pub fn generate_import_line(
393 lang: LangId,
394 module_path: &str,
395 names: &[String],
396 default_import: Option<&str>,
397 type_only: bool,
398) -> String {
399 generate_import_line_with_namespace(lang, module_path, names, default_import, None, type_only)
400}
401
402pub fn generate_import_line_with_namespace(
405 lang: LangId,
406 module_path: &str,
407 names: &[String],
408 default_import: Option<&str>,
409 namespace_import: Option<&str>,
410 type_only: bool,
411) -> String {
412 match lang {
413 LangId::TypeScript | LangId::Tsx | LangId::JavaScript => generate_ts_import_line(
414 module_path,
415 names,
416 default_import,
417 namespace_import,
418 type_only,
419 ),
420 LangId::Python => generate_py_import_line(module_path, names, default_import),
421 LangId::Rust => generate_rs_import_line(module_path, names, type_only),
422 LangId::Go => generate_go_import_line(module_path, default_import, false),
423 LangId::C
424 | LangId::Cpp
425 | LangId::Zig
426 | LangId::CSharp
427 | LangId::Bash
428 | LangId::Solidity
429 | LangId::Vue
430 | LangId::Json
431 | LangId::Scala => String::new(),
432 LangId::Html | LangId::Markdown => String::new(),
433 }
434}
435
436pub fn is_supported(lang: LangId) -> bool {
438 matches!(
439 lang,
440 LangId::TypeScript
441 | LangId::Tsx
442 | LangId::JavaScript
443 | LangId::Python
444 | LangId::Rust
445 | LangId::Go
446 )
447}
448
449pub fn classify_group_ts(module_path: &str) -> ImportGroup {
451 if module_path.starts_with('.') {
452 ImportGroup::Internal
453 } else {
454 ImportGroup::External
455 }
456}
457
458pub fn classify_group(lang: LangId, module_path: &str) -> ImportGroup {
460 match lang {
461 LangId::TypeScript | LangId::Tsx | LangId::JavaScript => classify_group_ts(module_path),
462 LangId::Python => classify_group_py(module_path),
463 LangId::Rust => classify_group_rs(module_path),
464 LangId::Go => classify_group_go(module_path),
465 LangId::C
466 | LangId::Cpp
467 | LangId::Zig
468 | LangId::CSharp
469 | LangId::Bash
470 | LangId::Solidity
471 | LangId::Vue
472 | LangId::Json
473 | LangId::Scala => ImportGroup::External,
474 LangId::Html | LangId::Markdown => ImportGroup::External,
475 }
476}
477
478pub fn parse_file_imports(
481 path: &std::path::Path,
482 lang: LangId,
483) -> Result<(String, Tree, ImportBlock), crate::error::AftError> {
484 let source =
485 std::fs::read_to_string(path).map_err(|e| crate::error::AftError::FileNotFound {
486 path: format!("{}: {}", path.display(), e),
487 })?;
488
489 let grammar = grammar_for(lang);
490 let mut parser = Parser::new();
491 parser
492 .set_language(&grammar)
493 .map_err(|e| crate::error::AftError::ParseError {
494 message: format!("grammar init failed for {:?}: {}", lang, e),
495 })?;
496
497 let tree = parser
498 .parse(&source, None)
499 .ok_or_else(|| crate::error::AftError::ParseError {
500 message: format!("tree-sitter parse returned None for {}", path.display()),
501 })?;
502
503 let block = parse_imports(&source, &tree, lang);
504 Ok((source, tree, block))
505}
506
507fn parse_ts_imports(source: &str, tree: &Tree) -> ImportBlock {
515 let root = tree.root_node();
516 let mut imports = Vec::new();
517
518 let mut cursor = root.walk();
519 if !cursor.goto_first_child() {
520 return ImportBlock::empty();
521 }
522
523 loop {
524 let node = cursor.node();
525 if node.kind() == "import_statement" {
526 if let Some(imp) = parse_single_ts_import(source, &node) {
527 imports.push(imp);
528 }
529 }
530 if !cursor.goto_next_sibling() {
531 break;
532 }
533 }
534
535 let byte_range = import_byte_range(&imports);
536
537 ImportBlock {
538 imports,
539 byte_range,
540 }
541}
542
543fn parse_single_ts_import(source: &str, node: &Node) -> Option<ImportStatement> {
545 let raw_text = source[node.byte_range()].to_string();
546 let byte_range = node.byte_range();
547
548 let module_path = extract_module_path(source, node)?;
550
551 let is_type_only = has_type_keyword(node);
553
554 let mut names = Vec::new();
556 let mut default_import = None;
557 let mut namespace_import = None;
558
559 let mut child_cursor = node.walk();
560 if child_cursor.goto_first_child() {
561 loop {
562 let child = child_cursor.node();
563 match child.kind() {
564 "import_clause" => {
565 extract_import_clause(
566 source,
567 &child,
568 &mut names,
569 &mut default_import,
570 &mut namespace_import,
571 );
572 }
573 "identifier" => {
575 let text = &source[child.byte_range()];
576 if text != "import" && text != "from" && text != "type" {
577 default_import = Some(text.to_string());
578 }
579 }
580 _ => {}
581 }
582 if !child_cursor.goto_next_sibling() {
583 break;
584 }
585 }
586 }
587
588 let kind = if names.is_empty() && default_import.is_none() && namespace_import.is_none() {
590 ImportKind::SideEffect
591 } else if is_type_only {
592 ImportKind::Type
593 } else {
594 ImportKind::Value
595 };
596
597 let group = classify_group_ts(&module_path);
598
599 Some(ImportStatement {
600 module_path,
601 names,
602 default_import,
603 namespace_import,
604 kind,
605 group,
606 byte_range,
607 raw_text,
608 })
609}
610
611fn extract_module_path(source: &str, node: &Node) -> Option<String> {
615 let mut cursor = node.walk();
616 if !cursor.goto_first_child() {
617 return None;
618 }
619
620 loop {
621 let child = cursor.node();
622 if child.kind() == "string" {
623 let text = &source[child.byte_range()];
625 let stripped = text
626 .trim_start_matches(|c| c == '\'' || c == '"')
627 .trim_end_matches(|c| c == '\'' || c == '"');
628 return Some(stripped.to_string());
629 }
630 if !cursor.goto_next_sibling() {
631 break;
632 }
633 }
634 None
635}
636
637fn has_type_keyword(node: &Node) -> bool {
642 let mut cursor = node.walk();
643 if !cursor.goto_first_child() {
644 return false;
645 }
646
647 loop {
648 let child = cursor.node();
649 if child.kind() == "type" {
650 return true;
651 }
652 if !cursor.goto_next_sibling() {
653 break;
654 }
655 }
656
657 false
658}
659
660fn extract_import_clause(
662 source: &str,
663 node: &Node,
664 names: &mut Vec<String>,
665 default_import: &mut Option<String>,
666 namespace_import: &mut Option<String>,
667) {
668 let mut cursor = node.walk();
669 if !cursor.goto_first_child() {
670 return;
671 }
672
673 loop {
674 let child = cursor.node();
675 match child.kind() {
676 "identifier" => {
677 let text = &source[child.byte_range()];
679 if text != "type" {
680 *default_import = Some(text.to_string());
681 }
682 }
683 "named_imports" => {
684 extract_named_imports(source, &child, names);
686 }
687 "namespace_import" => {
688 extract_namespace_import(source, &child, namespace_import);
690 }
691 _ => {}
692 }
693 if !cursor.goto_next_sibling() {
694 break;
695 }
696 }
697}
698
699fn extract_named_imports(source: &str, node: &Node, names: &mut Vec<String>) {
718 let mut cursor = node.walk();
719 if !cursor.goto_first_child() {
720 return;
721 }
722
723 loop {
724 let child = cursor.node();
725 if child.kind() == "import_specifier" {
726 let raw = source[child.byte_range()].trim().to_string();
731 if !raw.is_empty() {
732 names.push(raw);
733 } else if let Some(name_node) = child.child_by_field_name("name") {
734 names.push(source[name_node.byte_range()].to_string());
735 }
736 }
737 if !cursor.goto_next_sibling() {
738 break;
739 }
740 }
741}
742
743fn extract_namespace_import(source: &str, node: &Node, namespace_import: &mut Option<String>) {
745 let mut cursor = node.walk();
746 if !cursor.goto_first_child() {
747 return;
748 }
749
750 loop {
751 let child = cursor.node();
752 if child.kind() == "identifier" {
753 *namespace_import = Some(source[child.byte_range()].to_string());
754 return;
755 }
756 if !cursor.goto_next_sibling() {
757 break;
758 }
759 }
760}
761
762fn generate_ts_import_line(
764 module_path: &str,
765 names: &[String],
766 default_import: Option<&str>,
767 namespace_import: Option<&str>,
768 type_only: bool,
769) -> String {
770 let type_prefix = if type_only { "type " } else { "" };
771
772 if names.is_empty() && default_import.is_none() && namespace_import.is_none() {
774 return format!("import '{module_path}';");
775 }
776
777 if names.is_empty() && default_import.is_none() {
779 if let Some(namespace) = namespace_import {
780 return format!("import {type_prefix}* as {namespace} from '{module_path}';");
781 }
782 }
783
784 if names.is_empty() {
786 if let (Some(def), Some(namespace)) = (default_import, namespace_import) {
787 return format!("import {type_prefix}{def}, * as {namespace} from '{module_path}';");
788 }
789 }
790
791 if names.is_empty() && namespace_import.is_none() {
793 if let Some(def) = default_import {
794 return format!("import {type_prefix}{def} from '{module_path}';");
795 }
796 }
797
798 if default_import.is_none() && namespace_import.is_none() {
800 let mut sorted_names = names.to_vec();
801 sort_named_specifiers(&mut sorted_names);
802 let names_str = sorted_names.join(", ");
803 return format!("import {type_prefix}{{ {names_str} }} from '{module_path}';");
804 }
805
806 if default_import.is_none() {
808 if let Some(namespace) = namespace_import {
809 let mut sorted_names = names.to_vec();
810 sort_named_specifiers(&mut sorted_names);
811 let names_str = sorted_names.join(", ");
812 return format!(
813 "import {type_prefix}{{ {names_str} }}, * as {namespace} from '{module_path}';"
814 );
815 }
816 }
817
818 if let (Some(def), Some(namespace)) = (default_import, namespace_import) {
820 let mut sorted_names = names.to_vec();
821 sort_named_specifiers(&mut sorted_names);
822 let names_str = sorted_names.join(", ");
823 return format!(
824 "import {type_prefix}{def}, {{ {names_str} }}, * as {namespace} from '{module_path}';"
825 );
826 }
827
828 if let Some(def) = default_import {
830 let mut sorted_names = names.to_vec();
831 sort_named_specifiers(&mut sorted_names);
832 let names_str = sorted_names.join(", ");
833 return format!("import {type_prefix}{def}, {{ {names_str} }} from '{module_path}';");
834 }
835
836 format!("import '{module_path}';")
838}
839
/// Top-level module names of the Python standard library.
///
/// `classify_group_py` looks up the first dotted segment of an import path
/// here to decide between `Stdlib` and `External`. The list deliberately
/// keeps modules that newer Python releases have removed (e.g. `asyncore`,
/// `distutils`, `telnetlib`), so code targeting older interpreters is still
/// classified correctly. Entries are in case-insensitive alphabetical order
/// and must be unique.
const PYTHON_STDLIB: &[&str] = &[
    "__future__",
    "_thread",
    "abc",
    "aifc",
    "argparse",
    "array",
    "ast",
    "asynchat",
    "asyncio",
    "asyncore",
    "atexit",
    "audioop",
    "base64",
    "bdb",
    "binascii",
    "bisect",
    "builtins",
    "bz2",
    "calendar",
    "cgi",
    "cgitb",
    "chunk",
    "cmath",
    "cmd",
    "code",
    "codecs",
    "codeop",
    "collections",
    "colorsys",
    "compileall",
    "concurrent",
    "configparser",
    "contextlib",
    "contextvars",
    "copy",
    "copyreg",
    "cProfile",
    "crypt",
    "csv",
    "ctypes",
    "curses",
    "dataclasses",
    "datetime",
    "dbm",
    "decimal",
    "difflib",
    "dis",
    "distutils",
    "doctest",
    "email",
    "encodings",
    "ensurepip",
    "enum",
    "errno",
    "faulthandler",
    "fcntl",
    "filecmp",
    "fileinput",
    "fnmatch",
    "fractions",
    "ftplib",
    "functools",
    "gc",
    "getopt",
    "getpass",
    "gettext",
    "glob",
    "graphlib",
    "grp",
    "gzip",
    "hashlib",
    "heapq",
    "hmac",
    "html",
    "http",
    "idlelib",
    "imaplib",
    "imghdr",
    "imp",
    "importlib",
    "inspect",
    "io",
    "ipaddress",
    "itertools",
    "json",
    "keyword",
    "lib2to3",
    "linecache",
    "locale",
    "logging",
    "lzma",
    "mailbox",
    "mailcap",
    "marshal",
    "math",
    "mimetypes",
    "mmap",
    "modulefinder",
    "msilib",
    "msvcrt",
    "multiprocessing",
    "netrc",
    "nis",
    "nntplib",
    "numbers",
    "operator",
    "optparse",
    "os",
    "ossaudiodev",
    "pathlib",
    "pdb",
    "pickle",
    "pickletools",
    "pipes",
    "pkgutil",
    "platform",
    "plistlib",
    "poplib",
    "posixpath",
    "pprint",
    "profile",
    "pstats",
    "pty",
    "pwd",
    "py_compile",
    "pyclbr",
    "pydoc",
    "queue",
    "quopri",
    "random",
    "re",
    "readline",
    "reprlib",
    "resource",
    "rlcompleter",
    "runpy",
    "sched",
    "secrets",
    "select",
    "selectors",
    "shelve",
    "shlex",
    "shutil",
    "signal",
    "site",
    "smtpd",
    "smtplib",
    "sndhdr",
    "socket",
    "socketserver",
    "spwd",
    "sqlite3",
    "ssl",
    "stat",
    "statistics",
    "string",
    "stringprep",
    "struct",
    "subprocess",
    "sunau",
    "symtable",
    "sys",
    "sysconfig",
    "syslog",
    "tabnanny",
    "tarfile",
    "telnetlib",
    "tempfile",
    "termios",
    "textwrap",
    "threading",
    "time",
    "timeit",
    "tkinter",
    "token",
    "tokenize",
    "tomllib",
    "trace",
    "traceback",
    "tracemalloc",
    "tty",
    "turtle",
    "types",
    "typing",
    "unicodedata",
    "unittest",
    "urllib",
    "uu",
    "uuid",
    "venv",
    "warnings",
    "wave",
    "weakref",
    "webbrowser",
    "winreg",
    "winsound",
    "wsgiref",
    "xdrlib",
    "xml",
    "xmlrpc",
    "zipapp",
    "zipfile",
    "zipimport",
    "zlib",
    "zoneinfo",
];
1037
1038pub fn classify_group_py(module_path: &str) -> ImportGroup {
1040 if module_path.starts_with('.') {
1042 return ImportGroup::Internal;
1043 }
1044 let top_module = module_path.split('.').next().unwrap_or(module_path);
1046 if PYTHON_STDLIB.contains(&top_module) {
1047 ImportGroup::Stdlib
1048 } else {
1049 ImportGroup::External
1050 }
1051}
1052
1053fn parse_py_imports(source: &str, tree: &Tree) -> ImportBlock {
1055 let root = tree.root_node();
1056 let mut imports = Vec::new();
1057
1058 let mut cursor = root.walk();
1059 if !cursor.goto_first_child() {
1060 return ImportBlock::empty();
1061 }
1062
1063 loop {
1064 let node = cursor.node();
1065 match node.kind() {
1066 "import_statement" => {
1067 if let Some(imp) = parse_py_import_statement(source, &node) {
1068 imports.push(imp);
1069 }
1070 }
1071 "import_from_statement" => {
1072 if let Some(imp) = parse_py_import_from_statement(source, &node) {
1073 imports.push(imp);
1074 }
1075 }
1076 _ => {}
1077 }
1078 if !cursor.goto_next_sibling() {
1079 break;
1080 }
1081 }
1082
1083 let byte_range = import_byte_range(&imports);
1084
1085 ImportBlock {
1086 imports,
1087 byte_range,
1088 }
1089}
1090
1091fn parse_py_import_statement(source: &str, node: &Node) -> Option<ImportStatement> {
1093 let raw_text = source[node.byte_range()].to_string();
1094 let byte_range = node.byte_range();
1095
1096 let mut module_path = String::new();
1098 let mut c = node.walk();
1099 if c.goto_first_child() {
1100 loop {
1101 if c.node().kind() == "dotted_name" {
1102 module_path = source[c.node().byte_range()].to_string();
1103 break;
1104 }
1105 if !c.goto_next_sibling() {
1106 break;
1107 }
1108 }
1109 }
1110 if module_path.is_empty() {
1111 return None;
1112 }
1113
1114 let group = classify_group_py(&module_path);
1115
1116 Some(ImportStatement {
1117 module_path,
1118 names: Vec::new(),
1119 default_import: None,
1120 namespace_import: None,
1121 kind: ImportKind::Value,
1122 group,
1123 byte_range,
1124 raw_text,
1125 })
1126}
1127
1128fn parse_py_import_from_statement(source: &str, node: &Node) -> Option<ImportStatement> {
1130 let raw_text = source[node.byte_range()].to_string();
1131 let byte_range = node.byte_range();
1132
1133 let mut module_path = String::new();
1134 let mut names = Vec::new();
1135
1136 let mut c = node.walk();
1137 if c.goto_first_child() {
1138 loop {
1139 let child = c.node();
1140 match child.kind() {
1141 "dotted_name" => {
1142 if module_path.is_empty()
1147 && !has_seen_import_keyword(source, node, child.start_byte())
1148 {
1149 module_path = source[child.byte_range()].to_string();
1150 } else {
1151 names.push(source[child.byte_range()].to_string());
1153 }
1154 }
1155 "relative_import" => {
1156 module_path = source[child.byte_range()].to_string();
1158 }
1159 _ => {}
1160 }
1161 if !c.goto_next_sibling() {
1162 break;
1163 }
1164 }
1165 }
1166
1167 if module_path.is_empty() {
1169 return None;
1170 }
1171
1172 let group = classify_group_py(&module_path);
1173
1174 Some(ImportStatement {
1175 module_path,
1176 names,
1177 default_import: None,
1178 namespace_import: None,
1179 kind: ImportKind::Value,
1180 group,
1181 byte_range,
1182 raw_text,
1183 })
1184}
1185
1186fn has_seen_import_keyword(_source: &str, parent: &Node, before_byte: usize) -> bool {
1188 let mut c = parent.walk();
1189 if c.goto_first_child() {
1190 loop {
1191 let child = c.node();
1192 if child.kind() == "import" && child.start_byte() < before_byte {
1193 return true;
1194 }
1195 if child.start_byte() >= before_byte {
1196 return false;
1197 }
1198 if !c.goto_next_sibling() {
1199 break;
1200 }
1201 }
1202 }
1203 false
1204}
1205
/// Renders a Python import statement.
///
/// Without names the result is `import <module>`; with names it is
/// `from <module> import <names>` with the names sorted and joined by `", "`.
/// `_default_import` is unused.
fn generate_py_import_line(
    module_path: &str,
    names: &[String],
    _default_import: Option<&str>,
) -> String {
    match names {
        [] => format!("import {module_path}"),
        _ => {
            let mut ordered = names.to_vec();
            ordered.sort();
            let names_str = ordered.join(", ");
            format!("from {module_path} import {names_str}")
        }
    }
}
1223
1224pub fn classify_group_rs(module_path: &str) -> ImportGroup {
1230 let first_seg = module_path.split("::").next().unwrap_or(module_path);
1232 match first_seg {
1233 "std" | "core" | "alloc" => ImportGroup::Stdlib,
1234 "crate" | "self" | "super" => ImportGroup::Internal,
1235 _ => ImportGroup::External,
1236 }
1237}
1238
1239fn parse_rs_imports(source: &str, tree: &Tree) -> ImportBlock {
1241 let root = tree.root_node();
1242 let mut imports = Vec::new();
1243
1244 let mut cursor = root.walk();
1245 if !cursor.goto_first_child() {
1246 return ImportBlock::empty();
1247 }
1248
1249 loop {
1250 let node = cursor.node();
1251 if node.kind() == "use_declaration" {
1252 if let Some(imp) = parse_rs_use_declaration(source, &node) {
1253 imports.push(imp);
1254 }
1255 }
1256 if !cursor.goto_next_sibling() {
1257 break;
1258 }
1259 }
1260
1261 let byte_range = import_byte_range(&imports);
1262
1263 ImportBlock {
1264 imports,
1265 byte_range,
1266 }
1267}
1268
1269fn parse_rs_use_declaration(source: &str, node: &Node) -> Option<ImportStatement> {
1271 let raw_text = source[node.byte_range()].to_string();
1272 let byte_range = node.byte_range();
1273
1274 let mut has_pub = false;
1276 let mut use_path = String::new();
1277 let mut names = Vec::new();
1278
1279 let mut c = node.walk();
1280 if c.goto_first_child() {
1281 loop {
1282 let child = c.node();
1283 match child.kind() {
1284 "visibility_modifier" => {
1285 has_pub = true;
1286 }
1287 "scoped_identifier" | "identifier" | "use_as_clause" => {
1288 use_path = source[child.byte_range()].to_string();
1290 }
1291 "scoped_use_list" => {
1292 use_path = source[child.byte_range()].to_string();
1294 extract_rs_use_list_names(source, &child, &mut names);
1296 }
1297 _ => {}
1298 }
1299 if !c.goto_next_sibling() {
1300 break;
1301 }
1302 }
1303 }
1304
1305 if use_path.is_empty() {
1306 return None;
1307 }
1308
1309 let group = classify_group_rs(&use_path);
1310
1311 Some(ImportStatement {
1312 module_path: use_path,
1313 names,
1314 default_import: if has_pub {
1315 Some("pub".to_string())
1316 } else {
1317 None
1318 },
1319 namespace_import: None,
1320 kind: ImportKind::Value,
1321 group,
1322 byte_range,
1323 raw_text,
1324 })
1325}
1326
1327fn extract_rs_use_list_names(source: &str, node: &Node, names: &mut Vec<String>) {
1329 let mut c = node.walk();
1330 if c.goto_first_child() {
1331 loop {
1332 let child = c.node();
1333 if child.kind() == "use_list" {
1334 let mut lc = child.walk();
1336 if lc.goto_first_child() {
1337 loop {
1338 let lchild = lc.node();
1339 if lchild.kind() == "identifier" || lchild.kind() == "scoped_identifier" {
1340 names.push(source[lchild.byte_range()].to_string());
1341 }
1342 if !lc.goto_next_sibling() {
1343 break;
1344 }
1345 }
1346 }
1347 }
1348 if !c.goto_next_sibling() {
1349 break;
1350 }
1351 }
1352 }
1353}
1354
/// Renders a Rust `use` declaration: `use <module_path>;`.
///
/// NOTE(review): the names list and the type-only flag are accepted for
/// signature parity with the other generators but do not affect the output.
/// The previous `if`/`else` on the names list produced the same string in
/// both branches. Presumably callers pass the complete path (including any
/// `{…}` list) in `module_path`; confirm.
fn generate_rs_import_line(module_path: &str, _names: &[String], _type_only: bool) -> String {
    format!("use {module_path};")
}
1366
1367pub fn classify_group_go(module_path: &str) -> ImportGroup {
1373 if module_path.contains('.') {
1376 ImportGroup::External
1377 } else {
1378 ImportGroup::Stdlib
1379 }
1380}
1381
1382fn parse_go_imports(source: &str, tree: &Tree) -> ImportBlock {
1384 let root = tree.root_node();
1385 let mut imports = Vec::new();
1386
1387 let mut cursor = root.walk();
1388 if !cursor.goto_first_child() {
1389 return ImportBlock::empty();
1390 }
1391
1392 loop {
1393 let node = cursor.node();
1394 if node.kind() == "import_declaration" {
1395 parse_go_import_declaration(source, &node, &mut imports);
1396 }
1397 if !cursor.goto_next_sibling() {
1398 break;
1399 }
1400 }
1401
1402 let byte_range = import_byte_range(&imports);
1403
1404 ImportBlock {
1405 imports,
1406 byte_range,
1407 }
1408}
1409
1410fn parse_go_import_declaration(source: &str, node: &Node, imports: &mut Vec<ImportStatement>) {
1412 let mut c = node.walk();
1413 if c.goto_first_child() {
1414 loop {
1415 let child = c.node();
1416 match child.kind() {
1417 "import_spec" => {
1418 if let Some(imp) = parse_go_import_spec(source, &child) {
1419 imports.push(imp);
1420 }
1421 }
1422 "import_spec_list" => {
1423 let mut lc = child.walk();
1425 if lc.goto_first_child() {
1426 loop {
1427 if lc.node().kind() == "import_spec" {
1428 if let Some(imp) = parse_go_import_spec(source, &lc.node()) {
1429 imports.push(imp);
1430 }
1431 }
1432 if !lc.goto_next_sibling() {
1433 break;
1434 }
1435 }
1436 }
1437 }
1438 _ => {}
1439 }
1440 if !c.goto_next_sibling() {
1441 break;
1442 }
1443 }
1444 }
1445}
1446
1447fn parse_go_import_spec(source: &str, node: &Node) -> Option<ImportStatement> {
1449 let raw_text = source[node.byte_range()].to_string();
1450 let byte_range = node.byte_range();
1451
1452 let mut import_path = String::new();
1453 let mut alias = None;
1454
1455 let mut c = node.walk();
1456 if c.goto_first_child() {
1457 loop {
1458 let child = c.node();
1459 match child.kind() {
1460 "interpreted_string_literal" => {
1461 let text = source[child.byte_range()].to_string();
1463 import_path = text.trim_matches('"').to_string();
1464 }
1465 "identifier" | "blank_identifier" | "dot" => {
1466 alias = Some(source[child.byte_range()].to_string());
1468 }
1469 _ => {}
1470 }
1471 if !c.goto_next_sibling() {
1472 break;
1473 }
1474 }
1475 }
1476
1477 if import_path.is_empty() {
1478 return None;
1479 }
1480
1481 let group = classify_group_go(&import_path);
1482
1483 Some(ImportStatement {
1484 module_path: import_path,
1485 names: Vec::new(),
1486 default_import: alias,
1487 namespace_import: None,
1488 kind: ImportKind::Value,
1489 group,
1490 byte_range,
1491 raw_text,
1492 })
1493}
1494
1495pub fn generate_go_import_line_pub(
1497 module_path: &str,
1498 alias: Option<&str>,
1499 in_group: bool,
1500) -> String {
1501 generate_go_import_line(module_path, alias, in_group)
1502}
1503
/// Renders a Go import.
///
/// * `in_group == true`: a tab-indented spec (`\t"path"` or `\talias "path"`)
///   for use inside an `import ( … )` block.
/// * `in_group == false`: a standalone statement (`import "path"` or
///   `import alias "path"`).
fn generate_go_import_line(module_path: &str, alias: Option<&str>, in_group: bool) -> String {
    match (in_group, alias) {
        (true, Some(a)) => format!("\t{a} \"{module_path}\""),
        (true, None) => format!("\t\"{module_path}\""),
        (false, Some(a)) => format!("import {a} \"{module_path}\""),
        (false, None) => format!("import \"{module_path}\""),
    }
}
1523
1524pub fn go_has_grouped_import(_source: &str, tree: &Tree) -> Option<Range<usize>> {
1527 let root = tree.root_node();
1528 let mut cursor = root.walk();
1529 if !cursor.goto_first_child() {
1530 return None;
1531 }
1532
1533 loop {
1534 let node = cursor.node();
1535 if node.kind() == "import_declaration" {
1536 let mut c = node.walk();
1537 if c.goto_first_child() {
1538 loop {
1539 if c.node().kind() == "import_spec_list" {
1540 return Some(node.byte_range());
1541 }
1542 if !c.goto_next_sibling() {
1543 break;
1544 }
1545 }
1546 }
1547 }
1548 if !cursor.goto_next_sibling() {
1549 break;
1550 }
1551 }
1552 None
1553}
1554
/// Advances `pos` past one line terminator, if there is one at `pos`.
///
/// Recognises `\n`, `\r\n` and a lone `\r`. Returns `pos` unchanged when the
/// byte at `pos` is not a line terminator or `pos` is at or beyond the end of
/// `source`.
fn skip_newline(source: &str, pos: usize) -> usize {
    let rest = match source.as_bytes().get(pos..) {
        Some(rest) => rest,
        None => return pos,
    };

    match rest {
        [b'\r', b'\n', ..] => pos + 2,
        [b'\n', ..] | [b'\r', ..] => pos + 1,
        _ => pos,
    }
}
1571
#[cfg(test)]
mod tests {
    use super::*;

    // ------------------------------------------------------------------
    // Shared assertion helpers
    // ------------------------------------------------------------------

    /// True when `stmt.names` holds an entry equal to `wanted`.
    fn names_include(stmt: &ImportStatement, wanted: &str) -> bool {
        stmt.names.iter().any(|name| name == wanted)
    }

    /// Number of imports in `block` that were classified into `group`.
    fn count_group(block: &ImportBlock, group: ImportGroup) -> usize {
        block
            .imports
            .iter()
            .filter(|stmt| stmt.group == group)
            .count()
    }

    // ------------------------------------------------------------------
    // TypeScript / JavaScript: parsing
    // ------------------------------------------------------------------

    /// Parses `source` with the TypeScript grammar and extracts its imports.
    ///
    /// Panics if the grammar cannot be installed or no tree is produced.
    fn parse_ts(source: &str) -> (Tree, ImportBlock) {
        let mut ts_parser = Parser::new();
        ts_parser
            .set_language(&grammar_for(LangId::TypeScript))
            .unwrap();
        let syntax_tree = ts_parser.parse(source, None).unwrap();
        let imports = parse_imports(source, &syntax_tree, LangId::TypeScript);
        (syntax_tree, imports)
    }

    /// Parses `source` with the JavaScript grammar and extracts its imports.
    ///
    /// Panics if the grammar cannot be installed or no tree is produced.
    fn parse_js(source: &str) -> (Tree, ImportBlock) {
        let mut js_parser = Parser::new();
        js_parser
            .set_language(&grammar_for(LangId::JavaScript))
            .unwrap();
        let syntax_tree = js_parser.parse(source, None).unwrap();
        let imports = parse_imports(source, &syntax_tree, LangId::JavaScript);
        (syntax_tree, imports)
    }

    #[test]
    fn parse_ts_named_imports() {
        let (_, parsed) = parse_ts("import { useState, useEffect } from 'react';\n");
        assert_eq!(parsed.imports.len(), 1);

        let stmt = &parsed.imports[0];
        assert_eq!(stmt.module_path, "react");
        assert!(names_include(stmt, "useState"));
        assert!(names_include(stmt, "useEffect"));
        assert_eq!(stmt.kind, ImportKind::Value);
        assert_eq!(stmt.group, ImportGroup::External);
    }

    #[test]
    fn parse_ts_default_import() {
        let (_, parsed) = parse_ts("import React from 'react';\n");
        assert_eq!(parsed.imports.len(), 1);

        let stmt = &parsed.imports[0];
        assert_eq!(stmt.default_import.as_deref(), Some("React"));
        assert_eq!(stmt.kind, ImportKind::Value);
    }

    #[test]
    fn parse_ts_side_effect_import() {
        let (_, parsed) = parse_ts("import './styles.css';\n");
        assert_eq!(parsed.imports.len(), 1);

        let stmt = &parsed.imports[0];
        assert_eq!(stmt.kind, ImportKind::SideEffect);
        assert_eq!(stmt.module_path, "./styles.css");
    }

    #[test]
    fn parse_ts_relative_import() {
        let (_, parsed) = parse_ts("import { helper } from './utils';\n");
        assert_eq!(parsed.imports.len(), 1);
        assert_eq!(parsed.imports[0].group, ImportGroup::Internal);
    }

    #[test]
    fn parse_ts_multiple_groups() {
        let source = concat!(
            "import React from 'react';\n",
            "import { useState } from 'react';\n",
            "import { helper } from './utils';\n",
            "import { Config } from '../config';\n",
        );
        let (_, parsed) = parse_ts(source);
        assert_eq!(parsed.imports.len(), 4);

        // Two package imports, two relative ones.
        assert_eq!(count_group(&parsed, ImportGroup::External), 2);
        assert_eq!(count_group(&parsed, ImportGroup::Internal), 2);
    }

    #[test]
    fn parse_ts_namespace_import() {
        let (_, parsed) = parse_ts("import * as path from 'path';\n");
        assert_eq!(parsed.imports.len(), 1);

        let stmt = &parsed.imports[0];
        assert_eq!(stmt.namespace_import.as_deref(), Some("path"));
        assert_eq!(stmt.kind, ImportKind::Value);
    }

    #[test]
    fn parse_js_imports() {
        let (_, parsed) =
            parse_js("import { readFile } from 'fs';\nimport { helper } from './helper';\n");
        assert_eq!(parsed.imports.len(), 2);

        let (first, second) = (&parsed.imports[0], &parsed.imports[1]);
        assert_eq!(first.group, ImportGroup::External);
        assert_eq!(second.group, ImportGroup::Internal);
    }

    // ------------------------------------------------------------------
    // TypeScript / JavaScript: group classification
    // ------------------------------------------------------------------

    #[test]
    fn classify_external() {
        for module_path in ["react", "@scope/pkg", "lodash/map"] {
            assert_eq!(classify_group_ts(module_path), ImportGroup::External);
        }
    }

    #[test]
    fn classify_relative() {
        for module_path in ["./utils", "../config", "./"] {
            assert_eq!(classify_group_ts(module_path), ImportGroup::Internal);
        }
    }

    // ------------------------------------------------------------------
    // Duplicate detection
    // ------------------------------------------------------------------

    #[test]
    fn dedup_detects_same_named_import() {
        let (_, parsed) = parse_ts("import { useState } from 'react';\n");
        let requested = ["useState".to_string()];
        assert!(is_duplicate(&parsed, "react", &requested, None, false));
    }

    #[test]
    fn dedup_misses_different_name() {
        let (_, parsed) = parse_ts("import { useState } from 'react';\n");
        let requested = ["useEffect".to_string()];
        assert!(!is_duplicate(&parsed, "react", &requested, None, false));
    }

    #[test]
    fn dedup_detects_default_import() {
        let (_, parsed) = parse_ts("import React from 'react';\n");
        let already_present = is_duplicate(&parsed, "react", &[], Some("React"), false);
        assert!(already_present);
    }

    #[test]
    fn dedup_side_effect() {
        let (_, parsed) = parse_ts("import './styles.css';\n");
        let already_present = is_duplicate(&parsed, "./styles.css", &[], None, false);
        assert!(already_present);
    }

    #[test]
    fn dedup_namespace_import_distinct_from_side_effect_import() {
        // A bare side-effect import must not satisfy a namespace request.
        let (_, bare) = parse_ts("import 'fs';\n");
        assert!(!is_duplicate_with_namespace(
            &bare,
            "fs",
            &[],
            None,
            Some("fs"),
            false
        ));

        let (_, starred) = parse_ts("import * as fs from 'fs';\n");

        // Conversely, a namespace import must not satisfy a side-effect request.
        assert!(!is_duplicate(&starred, "fs", &[], None, false));

        // Same namespace alias: duplicate.
        assert!(is_duplicate_with_namespace(
            &starred,
            "fs",
            &[],
            None,
            Some("fs"),
            false
        ));

        // Different namespace alias: not a duplicate.
        assert!(!is_duplicate_with_namespace(
            &starred,
            "fs",
            &[],
            None,
            Some("other"),
            false
        ));
    }

    #[test]
    fn dedup_type_vs_value() {
        let (_, parsed) = parse_ts("import { FC } from 'react';\n");
        let requested = ["FC".to_string()];
        // The existing statement is a plain value import; the request sets the
        // final flag, so the two are expected not to match.
        assert!(!is_duplicate(&parsed, "react", &requested, None, true));
    }

    // ------------------------------------------------------------------
    // TypeScript / JavaScript: line generation
    // ------------------------------------------------------------------

    #[test]
    fn generate_named_import() {
        let names = ["useState", "useEffect"].map(String::from);
        let line = generate_import_line(LangId::TypeScript, "react", &names, None, false);
        assert_eq!(line, "import { useEffect, useState } from 'react';");
    }

    #[test]
    fn generate_named_import_sorts_by_imported_name() {
        let names = ["useState", "type Foo", "stdin as input", "type Bar"].map(String::from);
        let line = generate_import_line(LangId::TypeScript, "x", &names, None, false);
        assert_eq!(
            line,
            "import { type Bar, type Foo, stdin as input, useState } from 'x';"
        );
    }

    #[test]
    fn generate_default_import() {
        assert_eq!(
            generate_import_line(LangId::TypeScript, "react", &[], Some("React"), false),
            "import React from 'react';"
        );
    }

    #[test]
    fn generate_type_import() {
        let names = ["FC".to_string()];
        let line = generate_import_line(LangId::TypeScript, "react", &names, None, true);
        assert_eq!(line, "import type { FC } from 'react';");
    }

    #[test]
    fn generate_side_effect_import() {
        assert_eq!(
            generate_import_line(LangId::TypeScript, "./styles.css", &[], None, false),
            "import './styles.css';"
        );
    }

    #[test]
    fn generate_default_and_named() {
        let names = ["useState".to_string()];
        let line = generate_import_line(LangId::TypeScript, "react", &names, Some("React"), false);
        assert_eq!(line, "import React, { useState } from 'react';");
    }

    #[test]
    fn parse_ts_type_import() {
        let (_, parsed) = parse_ts("import type { FC } from 'react';\n");
        assert_eq!(parsed.imports.len(), 1);

        let stmt = &parsed.imports[0];
        assert_eq!(stmt.kind, ImportKind::Type);
        assert!(names_include(stmt, "FC"));
        assert_eq!(stmt.group, ImportGroup::External);
    }

    // ------------------------------------------------------------------
    // Insertion point
    // ------------------------------------------------------------------

    #[test]
    fn insertion_empty_file() {
        let source = "";
        let (_, parsed) = parse_ts(source);
        let insertion =
            find_insertion_point(source, &parsed, ImportGroup::External, "react", false);
        assert_eq!(insertion.0, 0);
    }

    #[test]
    fn insertion_alphabetical_within_group() {
        let source = concat!(
            "import { a } from 'alpha';\n",
            "import { c } from 'charlie';\n",
        );
        let (_, parsed) = parse_ts(source);
        let insertion =
            find_insertion_point(source, &parsed, ImportGroup::External, "bravo", false);

        // "bravo" sorts between "alpha" and "charlie", so the new line belongs
        // at the start of the 'charlie' statement.
        let charlie_start = source.find("import { c }").unwrap();
        assert_eq!(insertion.0, charlie_start);
    }

    // ------------------------------------------------------------------
    // Python
    // ------------------------------------------------------------------

    /// Parses `source` with the Python grammar and extracts its imports.
    ///
    /// Panics if the grammar cannot be installed or no tree is produced.
    fn parse_py(source: &str) -> (Tree, ImportBlock) {
        let mut py_parser = Parser::new();
        py_parser
            .set_language(&grammar_for(LangId::Python))
            .unwrap();
        let syntax_tree = py_parser.parse(source, None).unwrap();
        let imports = parse_imports(source, &syntax_tree, LangId::Python);
        (syntax_tree, imports)
    }

    #[test]
    fn parse_py_import_statement() {
        let (_, parsed) = parse_py("import os\nimport sys\n");
        assert_eq!(parsed.imports.len(), 2);

        let (first, second) = (&parsed.imports[0], &parsed.imports[1]);
        assert_eq!(first.module_path, "os");
        assert_eq!(second.module_path, "sys");
        assert_eq!(first.group, ImportGroup::Stdlib);
    }

    #[test]
    fn parse_py_from_import() {
        let (_, parsed) =
            parse_py("from collections import OrderedDict\nfrom typing import List, Optional\n");
        assert_eq!(parsed.imports.len(), 2);

        let collections = &parsed.imports[0];
        assert_eq!(collections.module_path, "collections");
        assert!(names_include(collections, "OrderedDict"));
        assert_eq!(collections.group, ImportGroup::Stdlib);

        let typing = &parsed.imports[1];
        assert_eq!(typing.module_path, "typing");
        assert!(names_include(typing, "List"));
        assert!(names_include(typing, "Optional"));
    }

    #[test]
    fn parse_py_relative_import() {
        let (_, parsed) = parse_py("from . import utils\nfrom ..config import Settings\n");
        assert_eq!(parsed.imports.len(), 2);

        let sibling = &parsed.imports[0];
        assert_eq!(sibling.module_path, ".");
        assert!(names_include(sibling, "utils"));
        assert_eq!(sibling.group, ImportGroup::Internal);

        let parent = &parsed.imports[1];
        assert_eq!(parent.module_path, "..config");
        assert_eq!(parent.group, ImportGroup::Internal);
    }

    #[test]
    fn classify_py_groups() {
        let cases = [
            ("os", ImportGroup::Stdlib),
            ("sys", ImportGroup::Stdlib),
            ("json", ImportGroup::Stdlib),
            ("collections", ImportGroup::Stdlib),
            ("os.path", ImportGroup::Stdlib),
            ("requests", ImportGroup::External),
            ("flask", ImportGroup::External),
            (".", ImportGroup::Internal),
            ("..config", ImportGroup::Internal),
            (".utils", ImportGroup::Internal),
        ];
        for (module_path, expected) in cases {
            assert_eq!(classify_group_py(module_path), expected);
        }
    }

    #[test]
    fn parse_py_three_groups() {
        let (_, parsed) =
            parse_py("import os\nimport sys\n\nimport requests\n\nfrom . import utils\n");

        assert_eq!(count_group(&parsed, ImportGroup::Stdlib), 2);
        assert_eq!(count_group(&parsed, ImportGroup::External), 1);
        assert_eq!(count_group(&parsed, ImportGroup::Internal), 1);
    }

    #[test]
    fn generate_py_import() {
        assert_eq!(
            generate_import_line(LangId::Python, "os", &[], None, false),
            "import os"
        );
    }

    #[test]
    fn generate_py_from_import() {
        let names = ["OrderedDict".to_string()];
        let line = generate_import_line(LangId::Python, "collections", &names, None, false);
        assert_eq!(line, "from collections import OrderedDict");
    }

    #[test]
    fn generate_py_from_import_multiple() {
        let names = ["Optional", "List"].map(String::from);
        let line = generate_import_line(LangId::Python, "typing", &names, None, false);
        assert_eq!(line, "from typing import List, Optional");
    }

    // ------------------------------------------------------------------
    // Rust
    // ------------------------------------------------------------------

    /// Parses `source` with the Rust grammar and extracts its imports.
    ///
    /// Panics if the grammar cannot be installed or no tree is produced.
    fn parse_rust(source: &str) -> (Tree, ImportBlock) {
        let mut rs_parser = Parser::new();
        rs_parser.set_language(&grammar_for(LangId::Rust)).unwrap();
        let syntax_tree = rs_parser.parse(source, None).unwrap();
        let imports = parse_imports(source, &syntax_tree, LangId::Rust);
        (syntax_tree, imports)
    }

    #[test]
    fn parse_rs_use_std() {
        let (_, parsed) = parse_rust("use std::collections::HashMap;\nuse std::io::Read;\n");
        assert_eq!(parsed.imports.len(), 2);

        let (first, second) = (&parsed.imports[0], &parsed.imports[1]);
        assert_eq!(first.module_path, "std::collections::HashMap");
        assert_eq!(first.group, ImportGroup::Stdlib);
        assert_eq!(second.group, ImportGroup::Stdlib);
    }

    #[test]
    fn parse_rs_use_external() {
        let (_, parsed) = parse_rust("use serde::{Deserialize, Serialize};\n");
        assert_eq!(parsed.imports.len(), 1);

        let stmt = &parsed.imports[0];
        assert_eq!(stmt.group, ImportGroup::External);
        assert!(names_include(stmt, "Deserialize"));
        assert!(names_include(stmt, "Serialize"));
    }

    #[test]
    fn parse_rs_use_crate() {
        let (_, parsed) = parse_rust("use crate::config::Settings;\nuse super::parent::Thing;\n");
        assert_eq!(parsed.imports.len(), 2);

        let (first, second) = (&parsed.imports[0], &parsed.imports[1]);
        assert_eq!(first.group, ImportGroup::Internal);
        assert_eq!(second.group, ImportGroup::Internal);
    }

    #[test]
    fn parse_rs_pub_use() {
        let (_, parsed) = parse_rust("pub use super::parent::Thing;\n");
        assert_eq!(parsed.imports.len(), 1);
        // This test expects a `pub use` to surface "pub" in `default_import`.
        assert_eq!(parsed.imports[0].default_import.as_deref(), Some("pub"));
    }

    #[test]
    fn classify_rs_groups() {
        let cases = [
            ("std::collections::HashMap", ImportGroup::Stdlib),
            ("core::mem", ImportGroup::Stdlib),
            ("alloc::vec", ImportGroup::Stdlib),
            ("serde::Deserialize", ImportGroup::External),
            ("tokio::runtime", ImportGroup::External),
            ("crate::config", ImportGroup::Internal),
            ("self::utils", ImportGroup::Internal),
            ("super::parent", ImportGroup::Internal),
        ];
        for (use_path, expected) in cases {
            assert_eq!(classify_group_rs(use_path), expected);
        }
    }

    #[test]
    fn generate_rs_use() {
        assert_eq!(
            generate_import_line(LangId::Rust, "std::fmt::Display", &[], None, false),
            "use std::fmt::Display;"
        );
    }

    // ------------------------------------------------------------------
    // Go
    // ------------------------------------------------------------------

    /// Parses `source` with the Go grammar and extracts its imports.
    ///
    /// Panics if the grammar cannot be installed or no tree is produced.
    fn parse_go(source: &str) -> (Tree, ImportBlock) {
        let mut go_parser = Parser::new();
        go_parser.set_language(&grammar_for(LangId::Go)).unwrap();
        let syntax_tree = go_parser.parse(source, None).unwrap();
        let imports = parse_imports(source, &syntax_tree, LangId::Go);
        (syntax_tree, imports)
    }

    #[test]
    fn parse_go_single_import() {
        let (_, parsed) = parse_go("package main\n\nimport \"fmt\"\n");
        assert_eq!(parsed.imports.len(), 1);

        let stmt = &parsed.imports[0];
        assert_eq!(stmt.module_path, "fmt");
        assert_eq!(stmt.group, ImportGroup::Stdlib);
    }

    #[test]
    fn parse_go_grouped_import() {
        let source = concat!(
            "package main\n",
            "\n",
            "import (\n",
            "\t\"fmt\"\n",
            "\t\"os\"\n",
            "\n",
            "\t\"github.com/pkg/errors\"\n",
            ")\n",
        );
        let (_, parsed) = parse_go(source);

        let expected = [
            ("fmt", ImportGroup::Stdlib),
            ("os", ImportGroup::Stdlib),
            ("github.com/pkg/errors", ImportGroup::External),
        ];
        assert_eq!(parsed.imports.len(), 3);
        for (stmt, (module_path, group)) in parsed.imports.iter().zip(expected) {
            assert_eq!(stmt.module_path, module_path);
            assert_eq!(stmt.group, group);
        }
    }

    #[test]
    fn parse_go_mixed_imports() {
        // One standalone import declaration followed by a grouped one.
        let source = concat!(
            "package main\n",
            "\n",
            "import \"fmt\"\n",
            "\n",
            "import (\n",
            "\t\"os\"\n",
            "\t\"github.com/pkg/errors\"\n",
            ")\n",
        );
        let (_, parsed) = parse_go(source);
        assert_eq!(parsed.imports.len(), 3);
    }

    #[test]
    fn classify_go_groups() {
        let cases = [
            ("fmt", ImportGroup::Stdlib),
            ("os", ImportGroup::Stdlib),
            ("net/http", ImportGroup::Stdlib),
            ("encoding/json", ImportGroup::Stdlib),
            ("github.com/pkg/errors", ImportGroup::External),
            ("golang.org/x/tools", ImportGroup::External),
        ];
        for (import_path, expected) in cases {
            assert_eq!(classify_group_go(import_path), expected);
        }
    }

    #[test]
    fn generate_go_standalone() {
        assert_eq!(
            generate_go_import_line("fmt", None, false),
            "import \"fmt\""
        );
    }

    #[test]
    fn generate_go_grouped_spec() {
        assert_eq!(generate_go_import_line("fmt", None, true), "\t\"fmt\"");
    }

    #[test]
    fn generate_go_with_alias() {
        assert_eq!(
            generate_go_import_line("github.com/pkg/errors", Some("errs"), false),
            "import errs \"github.com/pkg/errors\""
        );
    }
}