1use std::ops::Range;
10
11use tree_sitter::{Node, Parser, Tree};
12
13use crate::parser::{grammar_for, LangId};
14
/// How an import statement brings bindings into scope.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ImportKind {
    /// A regular (runtime) import, e.g. `import { a } from 'm'`.
    Value,
    /// A type-only import, e.g. TypeScript `import type { T } from 'm'`.
    Type,
    /// An import with no bindings at all, e.g. `import './styles.css'`.
    SideEffect,
}
29
/// Coarse classification of an import used to group and order imports.
///
/// The declaration order of the variants defines the derived `Ord`:
/// `Stdlib < External < Internal`. Insertion logic compares groups with
/// `<` / `>`, so the variants must not be reordered.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum ImportGroup {
    /// The language's standard library (e.g. `std`, Python `os`, Go `fmt`).
    Stdlib,
    /// Third-party packages.
    External,
    /// Modules of the current project (relative paths, `crate::`, ...).
    Internal,
}
50
51impl ImportGroup {
52 pub fn label(&self) -> &'static str {
54 match self {
55 ImportGroup::Stdlib => "stdlib",
56 ImportGroup::External => "external",
57 ImportGroup::Internal => "internal",
58 }
59 }
60}
61
62#[derive(Debug, Clone)]
64pub struct ImportStatement {
65 pub module_path: String,
67 pub names: Vec<String>,
69 pub default_import: Option<String>,
71 pub namespace_import: Option<String>,
73 pub kind: ImportKind,
75 pub group: ImportGroup,
77 pub byte_range: Range<usize>,
79 pub raw_text: String,
81}
82
83#[derive(Debug, Clone)]
85pub struct ImportBlock {
86 pub imports: Vec<ImportStatement>,
88 pub byte_range: Option<Range<usize>>,
91}
92
93impl ImportBlock {
94 pub fn empty() -> Self {
95 ImportBlock {
96 imports: Vec::new(),
97 byte_range: None,
98 }
99 }
100}
101
102fn import_byte_range(imports: &[ImportStatement]) -> Option<Range<usize>> {
103 imports.first().zip(imports.last()).map(|(first, last)| {
104 let start = first.byte_range.start;
105 let end = last.byte_range.end;
106 start..end
107 })
108}
109
/// Returns the name a specifier binds locally.
///
/// Surrounding whitespace and a leading `type ` modifier are ignored. For
/// `imported as local` the part after ` as ` is returned; otherwise the
/// specifier itself is returned. For example, `"type Foo as Bar"` gives
/// `"Bar"` and `" Foo "` gives `"Foo"`.
pub fn specifier_local_name(spec: &str) -> &str {
    let spec = spec.trim();
    let body = match spec.strip_prefix("type ") {
        Some(rest) => rest.trim_start(),
        None => spec,
    };
    match body.find(" as ") {
        Some(pos) => body[pos + " as ".len()..].trim(),
        None => body,
    }
}
137
/// Returns the exported name a specifier refers to.
///
/// Surrounding whitespace and a leading `type ` modifier are ignored. For
/// `imported as local` the part before ` as ` is returned; otherwise the
/// specifier itself is returned. For example, `"type Foo as Bar"` gives
/// `"Foo"`.
pub fn specifier_imported_name(spec: &str) -> &str {
    let spec = spec.trim();
    let body = spec.strip_prefix("type ").map_or(spec, str::trim_start);
    if let Some(pos) = body.find(" as ") {
        body[..pos].trim()
    } else {
        body
    }
}
157
158pub fn specifier_matches(spec: &str, target: &str) -> bool {
163 specifier_imported_name(spec) == target || specifier_local_name(spec) == target
164}
165
166pub fn parse_imports(source: &str, tree: &Tree, lang: LangId) -> ImportBlock {
172 match lang {
173 LangId::TypeScript | LangId::Tsx | LangId::JavaScript => parse_ts_imports(source, tree),
174 LangId::Python => parse_py_imports(source, tree),
175 LangId::Rust => parse_rs_imports(source, tree),
176 LangId::Go => parse_go_imports(source, tree),
177 LangId::C
178 | LangId::Cpp
179 | LangId::Zig
180 | LangId::CSharp
181 | LangId::Bash
182 | LangId::Solidity
183 | LangId::Vue
184 | LangId::Json
185 | LangId::Scala
186 | LangId::Java
187 | LangId::Ruby
188 | LangId::Kotlin
189 | LangId::Swift
190 | LangId::Php
191 | LangId::Lua
192 | LangId::Perl => ImportBlock::empty(),
193 LangId::Html | LangId::Markdown => ImportBlock::empty(),
194 }
195}
196
197pub fn is_duplicate(
203 block: &ImportBlock,
204 module_path: &str,
205 names: &[String],
206 default_import: Option<&str>,
207 type_only: bool,
208) -> bool {
209 is_duplicate_with_namespace(block, module_path, names, default_import, None, type_only)
210}
211
/// Returns `true` if `block` already contains an import equivalent to the one
/// described by `module_path`, `names`, `default_import`, `namespace_import`
/// and `type_only`.
///
/// Only existing imports whose `module_path` is exactly equal are considered.
/// One of them counts as a duplicate when any of these holds:
/// - neither the request nor the existing import has names, a default or a
///   namespace binding (bare imports);
/// - the request has no bindings and the existing import is an
///   [`ImportKind::SideEffect`] import;
/// - both are namespace-only imports with the same namespace name;
/// - the existing import's kind matches the requested kind (`Type` when
///   `type_only`, otherwise `Value`) or is `SideEffect`, and either:
///   - the requested default equals the existing default and the existing
///     import has no namespace binding, or
///   - every requested name matches some existing specifier by its imported
///     or local name (see [`specifier_matches`]).
///
/// NOTE(review): a matching default binding is reported as a duplicate even
/// when `names` or `namespace_import` request further bindings that the
/// existing import lacks. Confirm that this is intended.
pub fn is_duplicate_with_namespace(
    block: &ImportBlock,
    module_path: &str,
    names: &[String],
    default_import: Option<&str>,
    namespace_import: Option<&str>,
    type_only: bool,
) -> bool {
    // The kind an import generated from this request would have.
    let target_kind = if type_only {
        ImportKind::Type
    } else {
        ImportKind::Value
    };

    for imp in &block.imports {
        // Only imports of the very same module can duplicate the request.
        if imp.module_path != module_path {
            continue;
        }

        // Bare request vs. bare existing import (e.g. `import 'x'`, Python `import os`).
        if names.is_empty()
            && default_import.is_none()
            && namespace_import.is_none()
            && imp.names.is_empty()
            && imp.default_import.is_none()
            && imp.namespace_import.is_none()
        {
            return true;
        }

        // Bare request vs. an existing import classified as side-effect only.
        if names.is_empty()
            && default_import.is_none()
            && namespace_import.is_none()
            && imp.kind == ImportKind::SideEffect
        {
            return true;
        }

        // Namespace-only request vs. a namespace-only import with the same name.
        if names.is_empty()
            && default_import.is_none()
            && namespace_import.is_some()
            && imp.names.is_empty()
            && imp.default_import.is_none()
            && imp.namespace_import.as_deref() == namespace_import
        {
            return true;
        }

        // Value and type-only imports do not satisfy each other; side-effect
        // imports fall through to the binding checks below.
        if imp.kind != target_kind && imp.kind != ImportKind::SideEffect {
            continue;
        }

        // Requested default binding already present (without a namespace binding).
        if let Some(def) = default_import {
            if imp.default_import.as_deref() == Some(def) && imp.namespace_import.is_none() {
                return true;
            }
        }

        // Every requested named specifier is already imported, matched by
        // either its imported or its local name.
        if !names.is_empty()
            && names
                .iter()
                .all(|n| imp.names.iter().any(|stored| specifier_matches(stored, n)))
        {
            return true;
        }
    }

    false
}
294
295fn sort_named_specifiers(names: &mut [String]) {
296 names.sort_by(|a, b| {
297 specifier_imported_name(a)
298 .cmp(specifier_imported_name(b))
299 .then_with(|| a.cmp(b))
300 });
301}
302
303pub fn find_insertion_point(
314 source: &str,
315 block: &ImportBlock,
316 group: ImportGroup,
317 module_path: &str,
318 type_only: bool,
319) -> (usize, bool, bool) {
320 if block.imports.is_empty() {
321 return (0, false, source.is_empty().then_some(false).unwrap_or(true));
323 }
324
325 let target_kind = if type_only {
326 ImportKind::Type
327 } else {
328 ImportKind::Value
329 };
330
331 let group_imports: Vec<&ImportStatement> =
333 block.imports.iter().filter(|i| i.group == group).collect();
334
335 if group_imports.is_empty() {
336 let preceding_last = block.imports.iter().filter(|i| i.group < group).last();
339
340 if let Some(last) = preceding_last {
341 let end = last.byte_range.end;
342 let insert_at = skip_newline(source, end);
343 return (insert_at, true, true);
344 }
345
346 let following_first = block.imports.iter().find(|i| i.group > group);
348
349 if let Some(first) = following_first {
350 return (first.byte_range.start, false, true);
351 }
352
353 let first_byte = import_byte_range(&block.imports)
355 .map(|range| range.start)
356 .unwrap_or(0);
357 return (first_byte, false, true);
358 }
359
360 for imp in &group_imports {
362 let cmp = module_path.cmp(&imp.module_path);
363 match cmp {
364 std::cmp::Ordering::Less => {
365 return (imp.byte_range.start, false, false);
367 }
368 std::cmp::Ordering::Equal => {
369 if target_kind == ImportKind::Type && imp.kind == ImportKind::Value {
371 let end = imp.byte_range.end;
373 let insert_at = skip_newline(source, end);
374 return (insert_at, false, false);
375 }
376 return (imp.byte_range.start, false, false);
378 }
379 std::cmp::Ordering::Greater => continue,
380 }
381 }
382
383 let Some(last) = group_imports.last() else {
385 return (
386 import_byte_range(&block.imports)
387 .map(|range| range.end)
388 .unwrap_or(0),
389 false,
390 false,
391 );
392 };
393 let end = last.byte_range.end;
394 let insert_at = skip_newline(source, end);
395 (insert_at, false, false)
396}
397
398pub fn generate_import_line(
400 lang: LangId,
401 module_path: &str,
402 names: &[String],
403 default_import: Option<&str>,
404 type_only: bool,
405) -> String {
406 generate_import_line_with_namespace(lang, module_path, names, default_import, None, type_only)
407}
408
409pub fn generate_import_line_with_namespace(
412 lang: LangId,
413 module_path: &str,
414 names: &[String],
415 default_import: Option<&str>,
416 namespace_import: Option<&str>,
417 type_only: bool,
418) -> String {
419 match lang {
420 LangId::TypeScript | LangId::Tsx | LangId::JavaScript => generate_ts_import_line(
421 module_path,
422 names,
423 default_import,
424 namespace_import,
425 type_only,
426 ),
427 LangId::Python => generate_py_import_line(module_path, names, default_import),
428 LangId::Rust => generate_rs_import_line(module_path, names, type_only),
429 LangId::Go => generate_go_import_line(module_path, default_import, false),
430 LangId::C
431 | LangId::Cpp
432 | LangId::Zig
433 | LangId::CSharp
434 | LangId::Bash
435 | LangId::Solidity
436 | LangId::Vue
437 | LangId::Json
438 | LangId::Scala
439 | LangId::Java
440 | LangId::Ruby
441 | LangId::Kotlin
442 | LangId::Swift
443 | LangId::Php
444 | LangId::Lua
445 | LangId::Perl => String::new(),
446 LangId::Html | LangId::Markdown => String::new(),
447 }
448}
449
450pub fn is_supported(lang: LangId) -> bool {
452 matches!(
453 lang,
454 LangId::TypeScript
455 | LangId::Tsx
456 | LangId::JavaScript
457 | LangId::Python
458 | LangId::Rust
459 | LangId::Go
460 )
461}
462
463pub fn classify_group_ts(module_path: &str) -> ImportGroup {
465 if module_path.starts_with('.') {
466 ImportGroup::Internal
467 } else {
468 ImportGroup::External
469 }
470}
471
472pub fn classify_group(lang: LangId, module_path: &str) -> ImportGroup {
474 match lang {
475 LangId::TypeScript | LangId::Tsx | LangId::JavaScript => classify_group_ts(module_path),
476 LangId::Python => classify_group_py(module_path),
477 LangId::Rust => classify_group_rs(module_path),
478 LangId::Go => classify_group_go(module_path),
479 LangId::C
480 | LangId::Cpp
481 | LangId::Zig
482 | LangId::CSharp
483 | LangId::Bash
484 | LangId::Solidity
485 | LangId::Vue
486 | LangId::Json
487 | LangId::Scala
488 | LangId::Java
489 | LangId::Ruby
490 | LangId::Kotlin
491 | LangId::Swift
492 | LangId::Php
493 | LangId::Lua
494 | LangId::Perl => ImportGroup::External,
495 LangId::Html | LangId::Markdown => ImportGroup::External,
496 }
497}
498
499pub fn parse_file_imports(
502 path: &std::path::Path,
503 lang: LangId,
504) -> Result<(String, Tree, ImportBlock), crate::error::AftError> {
505 let source =
506 std::fs::read_to_string(path).map_err(|e| crate::error::AftError::FileNotFound {
507 path: format!("{}: {}", path.display(), e),
508 })?;
509
510 let grammar = grammar_for(lang);
511 let mut parser = Parser::new();
512 parser
513 .set_language(&grammar)
514 .map_err(|e| crate::error::AftError::ParseError {
515 message: format!("grammar init failed for {:?}: {}", lang, e),
516 })?;
517
518 let tree = parser
519 .parse(&source, None)
520 .ok_or_else(|| crate::error::AftError::ParseError {
521 message: format!("tree-sitter parse returned None for {}", path.display()),
522 })?;
523
524 let block = parse_imports(&source, &tree, lang);
525 Ok((source, tree, block))
526}
527
528fn parse_ts_imports(source: &str, tree: &Tree) -> ImportBlock {
536 let root = tree.root_node();
537 let mut imports = Vec::new();
538
539 let mut cursor = root.walk();
540 if !cursor.goto_first_child() {
541 return ImportBlock::empty();
542 }
543
544 loop {
545 let node = cursor.node();
546 if node.kind() == "import_statement" {
547 if let Some(imp) = parse_single_ts_import(source, &node) {
548 imports.push(imp);
549 }
550 }
551 if !cursor.goto_next_sibling() {
552 break;
553 }
554 }
555
556 let byte_range = import_byte_range(&imports);
557
558 ImportBlock {
559 imports,
560 byte_range,
561 }
562}
563
564fn parse_single_ts_import(source: &str, node: &Node) -> Option<ImportStatement> {
566 let raw_text = source[node.byte_range()].to_string();
567 let byte_range = node.byte_range();
568
569 let module_path = extract_module_path(source, node)?;
571
572 let is_type_only = has_type_keyword(node);
574
575 let mut names = Vec::new();
577 let mut default_import = None;
578 let mut namespace_import = None;
579
580 let mut child_cursor = node.walk();
581 if child_cursor.goto_first_child() {
582 loop {
583 let child = child_cursor.node();
584 match child.kind() {
585 "import_clause" => {
586 extract_import_clause(
587 source,
588 &child,
589 &mut names,
590 &mut default_import,
591 &mut namespace_import,
592 );
593 }
594 "identifier" => {
596 let text = &source[child.byte_range()];
597 if text != "import" && text != "from" && text != "type" {
598 default_import = Some(text.to_string());
599 }
600 }
601 _ => {}
602 }
603 if !child_cursor.goto_next_sibling() {
604 break;
605 }
606 }
607 }
608
609 let kind = if names.is_empty() && default_import.is_none() && namespace_import.is_none() {
611 ImportKind::SideEffect
612 } else if is_type_only {
613 ImportKind::Type
614 } else {
615 ImportKind::Value
616 };
617
618 let group = classify_group_ts(&module_path);
619
620 Some(ImportStatement {
621 module_path,
622 names,
623 default_import,
624 namespace_import,
625 kind,
626 group,
627 byte_range,
628 raw_text,
629 })
630}
631
632fn extract_module_path(source: &str, node: &Node) -> Option<String> {
636 let mut cursor = node.walk();
637 if !cursor.goto_first_child() {
638 return None;
639 }
640
641 loop {
642 let child = cursor.node();
643 if child.kind() == "string" {
644 let text = &source[child.byte_range()];
646 let stripped = text
647 .trim_start_matches(|c| c == '\'' || c == '"')
648 .trim_end_matches(|c| c == '\'' || c == '"');
649 return Some(stripped.to_string());
650 }
651 if !cursor.goto_next_sibling() {
652 break;
653 }
654 }
655 None
656}
657
658fn has_type_keyword(node: &Node) -> bool {
663 let mut cursor = node.walk();
664 if !cursor.goto_first_child() {
665 return false;
666 }
667
668 loop {
669 let child = cursor.node();
670 if child.kind() == "type" {
671 return true;
672 }
673 if !cursor.goto_next_sibling() {
674 break;
675 }
676 }
677
678 false
679}
680
681fn extract_import_clause(
683 source: &str,
684 node: &Node,
685 names: &mut Vec<String>,
686 default_import: &mut Option<String>,
687 namespace_import: &mut Option<String>,
688) {
689 let mut cursor = node.walk();
690 if !cursor.goto_first_child() {
691 return;
692 }
693
694 loop {
695 let child = cursor.node();
696 match child.kind() {
697 "identifier" => {
698 let text = &source[child.byte_range()];
700 if text != "type" {
701 *default_import = Some(text.to_string());
702 }
703 }
704 "named_imports" => {
705 extract_named_imports(source, &child, names);
707 }
708 "namespace_import" => {
709 extract_namespace_import(source, &child, namespace_import);
711 }
712 _ => {}
713 }
714 if !cursor.goto_next_sibling() {
715 break;
716 }
717 }
718}
719
720fn extract_named_imports(source: &str, node: &Node, names: &mut Vec<String>) {
739 let mut cursor = node.walk();
740 if !cursor.goto_first_child() {
741 return;
742 }
743
744 loop {
745 let child = cursor.node();
746 if child.kind() == "import_specifier" {
747 let raw = source[child.byte_range()].trim().to_string();
752 if !raw.is_empty() {
753 names.push(raw);
754 } else if let Some(name_node) = child.child_by_field_name("name") {
755 names.push(source[name_node.byte_range()].to_string());
756 }
757 }
758 if !cursor.goto_next_sibling() {
759 break;
760 }
761 }
762}
763
764fn extract_namespace_import(source: &str, node: &Node, namespace_import: &mut Option<String>) {
766 let mut cursor = node.walk();
767 if !cursor.goto_first_child() {
768 return;
769 }
770
771 loop {
772 let child = cursor.node();
773 if child.kind() == "identifier" {
774 *namespace_import = Some(source[child.byte_range()].to_string());
775 return;
776 }
777 if !cursor.goto_next_sibling() {
778 break;
779 }
780 }
781}
782
783fn generate_ts_import_line(
785 module_path: &str,
786 names: &[String],
787 default_import: Option<&str>,
788 namespace_import: Option<&str>,
789 type_only: bool,
790) -> String {
791 let type_prefix = if type_only { "type " } else { "" };
792
793 if names.is_empty() && default_import.is_none() && namespace_import.is_none() {
795 return format!("import '{module_path}';");
796 }
797
798 if names.is_empty() && default_import.is_none() {
800 if let Some(namespace) = namespace_import {
801 return format!("import {type_prefix}* as {namespace} from '{module_path}';");
802 }
803 }
804
805 if names.is_empty() {
807 if let (Some(def), Some(namespace)) = (default_import, namespace_import) {
808 return format!("import {type_prefix}{def}, * as {namespace} from '{module_path}';");
809 }
810 }
811
812 if names.is_empty() && namespace_import.is_none() {
814 if let Some(def) = default_import {
815 return format!("import {type_prefix}{def} from '{module_path}';");
816 }
817 }
818
819 if default_import.is_none() && namespace_import.is_none() {
821 let mut sorted_names = names.to_vec();
822 sort_named_specifiers(&mut sorted_names);
823 let names_str = sorted_names.join(", ");
824 return format!("import {type_prefix}{{ {names_str} }} from '{module_path}';");
825 }
826
827 if default_import.is_none() {
829 if let Some(namespace) = namespace_import {
830 let mut sorted_names = names.to_vec();
831 sort_named_specifiers(&mut sorted_names);
832 let names_str = sorted_names.join(", ");
833 return format!(
834 "import {type_prefix}{{ {names_str} }}, * as {namespace} from '{module_path}';"
835 );
836 }
837 }
838
839 if let (Some(def), Some(namespace)) = (default_import, namespace_import) {
841 let mut sorted_names = names.to_vec();
842 sort_named_specifiers(&mut sorted_names);
843 let names_str = sorted_names.join(", ");
844 return format!(
845 "import {type_prefix}{def}, {{ {names_str} }}, * as {namespace} from '{module_path}';"
846 );
847 }
848
849 if let Some(def) = default_import {
851 let mut sorted_names = names.to_vec();
852 sort_named_specifiers(&mut sorted_names);
853 let names_str = sorted_names.join(", ");
854 return format!("import {type_prefix}{def}, {{ {names_str} }} from '{module_path}';");
855 }
856
857 format!("import '{module_path}';")
859}
860
/// Top-level module names of the Python standard library.
///
/// `classify_group_py` uses this list to place imports in the stdlib group.
/// It includes modules removed in recent Python releases (e.g. `asynchat`,
/// `distutils`, `imp`), so older code is still classified correctly, and
/// platform-specific modules (`msvcrt`, `winreg`, `posix`, ...). Names that
/// commonly collide with project-local modules (`test`, `parser`, ...) are
/// deliberately left out. Entries are kept in case-insensitive alphabetical
/// order.
const PYTHON_STDLIB: &[&str] = &[
    "__future__",
    "_thread",
    "abc",
    "aifc",
    "argparse",
    "array",
    "ast",
    "asynchat",
    "asyncio",
    "asyncore",
    "atexit",
    "audioop",
    "base64",
    "bdb",
    "binascii",
    "binhex",
    "bisect",
    "builtins",
    "bz2",
    "calendar",
    "cgi",
    "cgitb",
    "chunk",
    "cmath",
    "cmd",
    "code",
    "codecs",
    "codeop",
    "collections",
    "colorsys",
    "compileall",
    "concurrent",
    "configparser",
    "contextlib",
    "contextvars",
    "copy",
    "copyreg",
    "cProfile",
    "crypt",
    "csv",
    "ctypes",
    "curses",
    "dataclasses",
    "datetime",
    "dbm",
    "decimal",
    "difflib",
    "dis",
    "distutils",
    "doctest",
    "email",
    "encodings",
    "ensurepip",
    "enum",
    "errno",
    "faulthandler",
    "fcntl",
    "filecmp",
    "fileinput",
    "fnmatch",
    "fractions",
    "ftplib",
    "functools",
    "gc",
    "getopt",
    "getpass",
    "gettext",
    "glob",
    "graphlib",
    "grp",
    "gzip",
    "hashlib",
    "heapq",
    "hmac",
    "html",
    "http",
    "idlelib",
    "imaplib",
    "imghdr",
    "imp",
    "importlib",
    "inspect",
    "io",
    "ipaddress",
    "itertools",
    "json",
    "keyword",
    "lib2to3",
    "linecache",
    "locale",
    "logging",
    "lzma",
    "mailbox",
    "mailcap",
    "marshal",
    "math",
    "mimetypes",
    "mmap",
    "modulefinder",
    "msilib",
    "msvcrt",
    "multiprocessing",
    "netrc",
    "nis",
    "nntplib",
    "ntpath",
    "numbers",
    "opcode",
    "operator",
    "optparse",
    "os",
    "ossaudiodev",
    "pathlib",
    "pdb",
    "pickle",
    "pickletools",
    "pipes",
    "pkgutil",
    "platform",
    "plistlib",
    "poplib",
    "posix",
    "posixpath",
    "pprint",
    "profile",
    "pstats",
    "pty",
    "pwd",
    "py_compile",
    "pyclbr",
    "pydoc",
    "pyexpat",
    "queue",
    "quopri",
    "random",
    "re",
    "readline",
    "reprlib",
    "resource",
    "rlcompleter",
    "runpy",
    "sched",
    "secrets",
    "select",
    "selectors",
    "shelve",
    "shlex",
    "shutil",
    "signal",
    "site",
    "smtpd",
    "smtplib",
    "sndhdr",
    "socket",
    "socketserver",
    "spwd",
    "sqlite3",
    "ssl",
    "stat",
    "statistics",
    "string",
    "stringprep",
    "struct",
    "subprocess",
    "sunau",
    "symtable",
    "sys",
    "sysconfig",
    "syslog",
    "tabnanny",
    "tarfile",
    "telnetlib",
    "tempfile",
    "termios",
    "textwrap",
    "threading",
    "time",
    "timeit",
    "tkinter",
    "token",
    "tokenize",
    "tomllib",
    "trace",
    "traceback",
    "tracemalloc",
    "tty",
    "turtle",
    "types",
    "typing",
    "unicodedata",
    "unittest",
    "urllib",
    "uu",
    "uuid",
    "venv",
    "warnings",
    "wave",
    "weakref",
    "webbrowser",
    "winreg",
    "winsound",
    "wsgiref",
    "xdrlib",
    "xml",
    "xmlrpc",
    "zipapp",
    "zipfile",
    "zipimport",
    "zlib",
    "zoneinfo",
];
1058
1059pub fn classify_group_py(module_path: &str) -> ImportGroup {
1061 if module_path.starts_with('.') {
1063 return ImportGroup::Internal;
1064 }
1065 let top_module = module_path.split('.').next().unwrap_or(module_path);
1067 if PYTHON_STDLIB.contains(&top_module) {
1068 ImportGroup::Stdlib
1069 } else {
1070 ImportGroup::External
1071 }
1072}
1073
1074fn parse_py_imports(source: &str, tree: &Tree) -> ImportBlock {
1076 let root = tree.root_node();
1077 let mut imports = Vec::new();
1078
1079 let mut cursor = root.walk();
1080 if !cursor.goto_first_child() {
1081 return ImportBlock::empty();
1082 }
1083
1084 loop {
1085 let node = cursor.node();
1086 match node.kind() {
1087 "import_statement" => {
1088 if let Some(imp) = parse_py_import_statement(source, &node) {
1089 imports.push(imp);
1090 }
1091 }
1092 "import_from_statement" => {
1093 if let Some(imp) = parse_py_import_from_statement(source, &node) {
1094 imports.push(imp);
1095 }
1096 }
1097 _ => {}
1098 }
1099 if !cursor.goto_next_sibling() {
1100 break;
1101 }
1102 }
1103
1104 let byte_range = import_byte_range(&imports);
1105
1106 ImportBlock {
1107 imports,
1108 byte_range,
1109 }
1110}
1111
1112fn parse_py_import_statement(source: &str, node: &Node) -> Option<ImportStatement> {
1114 let raw_text = source[node.byte_range()].to_string();
1115 let byte_range = node.byte_range();
1116
1117 let mut module_path = String::new();
1119 let mut c = node.walk();
1120 if c.goto_first_child() {
1121 loop {
1122 if c.node().kind() == "dotted_name" {
1123 module_path = source[c.node().byte_range()].to_string();
1124 break;
1125 }
1126 if !c.goto_next_sibling() {
1127 break;
1128 }
1129 }
1130 }
1131 if module_path.is_empty() {
1132 return None;
1133 }
1134
1135 let group = classify_group_py(&module_path);
1136
1137 Some(ImportStatement {
1138 module_path,
1139 names: Vec::new(),
1140 default_import: None,
1141 namespace_import: None,
1142 kind: ImportKind::Value,
1143 group,
1144 byte_range,
1145 raw_text,
1146 })
1147}
1148
1149fn parse_py_import_from_statement(source: &str, node: &Node) -> Option<ImportStatement> {
1151 let raw_text = source[node.byte_range()].to_string();
1152 let byte_range = node.byte_range();
1153
1154 let mut module_path = String::new();
1155 let mut names = Vec::new();
1156
1157 let mut c = node.walk();
1158 if c.goto_first_child() {
1159 loop {
1160 let child = c.node();
1161 match child.kind() {
1162 "dotted_name" => {
1163 if module_path.is_empty()
1168 && !has_seen_import_keyword(source, node, child.start_byte())
1169 {
1170 module_path = source[child.byte_range()].to_string();
1171 } else {
1172 names.push(source[child.byte_range()].to_string());
1174 }
1175 }
1176 "relative_import" => {
1177 module_path = source[child.byte_range()].to_string();
1179 }
1180 _ => {}
1181 }
1182 if !c.goto_next_sibling() {
1183 break;
1184 }
1185 }
1186 }
1187
1188 if module_path.is_empty() {
1190 return None;
1191 }
1192
1193 let group = classify_group_py(&module_path);
1194
1195 Some(ImportStatement {
1196 module_path,
1197 names,
1198 default_import: None,
1199 namespace_import: None,
1200 kind: ImportKind::Value,
1201 group,
1202 byte_range,
1203 raw_text,
1204 })
1205}
1206
1207fn has_seen_import_keyword(_source: &str, parent: &Node, before_byte: usize) -> bool {
1209 let mut c = parent.walk();
1210 if c.goto_first_child() {
1211 loop {
1212 let child = c.node();
1213 if child.kind() == "import" && child.start_byte() < before_byte {
1214 return true;
1215 }
1216 if child.start_byte() >= before_byte {
1217 return false;
1218 }
1219 if !c.goto_next_sibling() {
1220 break;
1221 }
1222 }
1223 }
1224 false
1225}
1226
/// Renders a Python import.
///
/// With no names this is `import module`. Otherwise it is
/// `from module import a, b`, with the names sorted lexicographically.
/// `_default_import` is ignored.
fn generate_py_import_line(
    module_path: &str,
    names: &[String],
    _default_import: Option<&str>,
) -> String {
    if names.is_empty() {
        return format!("import {module_path}");
    }
    let mut ordered: Vec<&str> = names.iter().map(String::as_str).collect();
    ordered.sort_unstable();
    format!("from {module_path} import {}", ordered.join(", "))
}
1244
1245pub fn classify_group_rs(module_path: &str) -> ImportGroup {
1251 let first_seg = module_path.split("::").next().unwrap_or(module_path);
1253 match first_seg {
1254 "std" | "core" | "alloc" => ImportGroup::Stdlib,
1255 "crate" | "self" | "super" => ImportGroup::Internal,
1256 _ => ImportGroup::External,
1257 }
1258}
1259
1260fn parse_rs_imports(source: &str, tree: &Tree) -> ImportBlock {
1262 let root = tree.root_node();
1263 let mut imports = Vec::new();
1264
1265 let mut cursor = root.walk();
1266 if !cursor.goto_first_child() {
1267 return ImportBlock::empty();
1268 }
1269
1270 loop {
1271 let node = cursor.node();
1272 if node.kind() == "use_declaration" {
1273 if let Some(imp) = parse_rs_use_declaration(source, &node) {
1274 imports.push(imp);
1275 }
1276 }
1277 if !cursor.goto_next_sibling() {
1278 break;
1279 }
1280 }
1281
1282 let byte_range = import_byte_range(&imports);
1283
1284 ImportBlock {
1285 imports,
1286 byte_range,
1287 }
1288}
1289
1290fn parse_rs_use_declaration(source: &str, node: &Node) -> Option<ImportStatement> {
1292 let raw_text = source[node.byte_range()].to_string();
1293 let byte_range = node.byte_range();
1294
1295 let mut has_pub = false;
1297 let mut use_path = String::new();
1298 let mut names = Vec::new();
1299
1300 let mut c = node.walk();
1301 if c.goto_first_child() {
1302 loop {
1303 let child = c.node();
1304 match child.kind() {
1305 "visibility_modifier" => {
1306 has_pub = true;
1307 }
1308 "scoped_identifier" | "identifier" | "use_as_clause" => {
1309 use_path = source[child.byte_range()].to_string();
1311 }
1312 "scoped_use_list" => {
1313 use_path = source[child.byte_range()].to_string();
1315 extract_rs_use_list_names(source, &child, &mut names);
1317 }
1318 _ => {}
1319 }
1320 if !c.goto_next_sibling() {
1321 break;
1322 }
1323 }
1324 }
1325
1326 if use_path.is_empty() {
1327 return None;
1328 }
1329
1330 let group = classify_group_rs(&use_path);
1331
1332 Some(ImportStatement {
1333 module_path: use_path,
1334 names,
1335 default_import: if has_pub {
1336 Some("pub".to_string())
1337 } else {
1338 None
1339 },
1340 namespace_import: None,
1341 kind: ImportKind::Value,
1342 group,
1343 byte_range,
1344 raw_text,
1345 })
1346}
1347
1348fn extract_rs_use_list_names(source: &str, node: &Node, names: &mut Vec<String>) {
1350 let mut c = node.walk();
1351 if c.goto_first_child() {
1352 loop {
1353 let child = c.node();
1354 if child.kind() == "use_list" {
1355 let mut lc = child.walk();
1357 if lc.goto_first_child() {
1358 loop {
1359 let lchild = lc.node();
1360 if lchild.kind() == "identifier" || lchild.kind() == "scoped_identifier" {
1361 names.push(source[lchild.byte_range()].to_string());
1362 }
1363 if !lc.goto_next_sibling() {
1364 break;
1365 }
1366 }
1367 }
1368 }
1369 if !c.goto_next_sibling() {
1370 break;
1371 }
1372 }
1373 }
1374}
1375
/// Renders a Rust `use` declaration: `use {module_path};`.
///
/// `module_path` is emitted verbatim, so it must already be the complete use
/// argument (e.g. `std::collections::HashMap` or `std::io::{Read, Write}`).
/// The names list is not rendered, and the output is the same whether or not
/// names are given.
///
/// NOTE(review): confirm whether callers with names expect
/// `use {module_path}::{names};` instead. Type-only imports have no meaning
/// in Rust, so the flag is ignored.
fn generate_rs_import_line(module_path: &str, _names: &[String], _type_only: bool) -> String {
    format!("use {module_path};")
}
1387
1388pub fn classify_group_go(module_path: &str) -> ImportGroup {
1394 if module_path.contains('.') {
1397 ImportGroup::External
1398 } else {
1399 ImportGroup::Stdlib
1400 }
1401}
1402
1403fn parse_go_imports(source: &str, tree: &Tree) -> ImportBlock {
1405 let root = tree.root_node();
1406 let mut imports = Vec::new();
1407
1408 let mut cursor = root.walk();
1409 if !cursor.goto_first_child() {
1410 return ImportBlock::empty();
1411 }
1412
1413 loop {
1414 let node = cursor.node();
1415 if node.kind() == "import_declaration" {
1416 parse_go_import_declaration(source, &node, &mut imports);
1417 }
1418 if !cursor.goto_next_sibling() {
1419 break;
1420 }
1421 }
1422
1423 let byte_range = import_byte_range(&imports);
1424
1425 ImportBlock {
1426 imports,
1427 byte_range,
1428 }
1429}
1430
1431fn parse_go_import_declaration(source: &str, node: &Node, imports: &mut Vec<ImportStatement>) {
1433 let mut c = node.walk();
1434 if c.goto_first_child() {
1435 loop {
1436 let child = c.node();
1437 match child.kind() {
1438 "import_spec" => {
1439 if let Some(imp) = parse_go_import_spec(source, &child) {
1440 imports.push(imp);
1441 }
1442 }
1443 "import_spec_list" => {
1444 let mut lc = child.walk();
1446 if lc.goto_first_child() {
1447 loop {
1448 if lc.node().kind() == "import_spec" {
1449 if let Some(imp) = parse_go_import_spec(source, &lc.node()) {
1450 imports.push(imp);
1451 }
1452 }
1453 if !lc.goto_next_sibling() {
1454 break;
1455 }
1456 }
1457 }
1458 }
1459 _ => {}
1460 }
1461 if !c.goto_next_sibling() {
1462 break;
1463 }
1464 }
1465 }
1466}
1467
1468fn parse_go_import_spec(source: &str, node: &Node) -> Option<ImportStatement> {
1470 let raw_text = source[node.byte_range()].to_string();
1471 let byte_range = node.byte_range();
1472
1473 let mut import_path = String::new();
1474 let mut alias = None;
1475
1476 let mut c = node.walk();
1477 if c.goto_first_child() {
1478 loop {
1479 let child = c.node();
1480 match child.kind() {
1481 "interpreted_string_literal" => {
1482 let text = source[child.byte_range()].to_string();
1484 import_path = text.trim_matches('"').to_string();
1485 }
1486 "identifier" | "blank_identifier" | "dot" => {
1487 alias = Some(source[child.byte_range()].to_string());
1489 }
1490 _ => {}
1491 }
1492 if !c.goto_next_sibling() {
1493 break;
1494 }
1495 }
1496 }
1497
1498 if import_path.is_empty() {
1499 return None;
1500 }
1501
1502 let group = classify_group_go(&import_path);
1503
1504 Some(ImportStatement {
1505 module_path: import_path,
1506 names: Vec::new(),
1507 default_import: alias,
1508 namespace_import: None,
1509 kind: ImportKind::Value,
1510 group,
1511 byte_range,
1512 raw_text,
1513 })
1514}
1515
1516pub fn generate_go_import_line_pub(
1518 module_path: &str,
1519 alias: Option<&str>,
1520 in_group: bool,
1521) -> String {
1522 generate_go_import_line(module_path, alias, in_group)
1523}
1524
/// Renders a Go import of `module_path`, optionally aliased.
///
/// Inside a group the line is a tab-indented entry (`\t[alias ]"path"`).
/// Otherwise it is a standalone `import [alias ]"path"`. No trailing newline
/// is added.
fn generate_go_import_line(module_path: &str, alias: Option<&str>, in_group: bool) -> String {
    let lead = if in_group { "\t" } else { "import " };
    match alias {
        Some(a) => format!("{lead}{a} \"{module_path}\""),
        None => format!("{lead}\"{module_path}\""),
    }
}
1544
1545pub fn go_has_grouped_import(_source: &str, tree: &Tree) -> Option<Range<usize>> {
1548 let root = tree.root_node();
1549 let mut cursor = root.walk();
1550 if !cursor.goto_first_child() {
1551 return None;
1552 }
1553
1554 loop {
1555 let node = cursor.node();
1556 if node.kind() == "import_declaration" {
1557 let mut c = node.walk();
1558 if c.goto_first_child() {
1559 loop {
1560 if c.node().kind() == "import_spec_list" {
1561 return Some(node.byte_range());
1562 }
1563 if !c.goto_next_sibling() {
1564 break;
1565 }
1566 }
1567 }
1568 }
1569 if !cursor.goto_next_sibling() {
1570 break;
1571 }
1572 }
1573 None
1574}
1575
/// Returns the offset just past a line break starting at `pos`.
///
/// Recognised breaks are `\r\n`, `\n`, or a lone `\r`. If there is no break
/// at `pos`, or `pos` is at or past the end of `source`, `pos` is returned
/// unchanged.
fn skip_newline(source: &str, pos: usize) -> usize {
    let rest = source.as_bytes().get(pos..).unwrap_or(&[]);
    match rest {
        [b'\r', b'\n', ..] => pos + 2,
        [b'\n', ..] | [b'\r', ..] => pos + 1,
        _ => pos,
    }
}
1592
/// Unit tests for import parsing, group classification, duplicate detection,
/// import-line generation, insertion-point lookup, and the small byte/tree
/// helpers, across TypeScript/JavaScript, Python, Rust and Go.
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` as TypeScript and returns the syntax tree plus its import block.
    fn parse_ts(source: &str) -> (Tree, ImportBlock) {
        let grammar = grammar_for(LangId::TypeScript);
        let mut parser = Parser::new();
        parser.set_language(&grammar).unwrap();
        let tree = parser.parse(source, None).unwrap();
        let block = parse_imports(source, &tree, LangId::TypeScript);
        (tree, block)
    }

    /// Parses `source` as JavaScript and returns the syntax tree plus its import block.
    fn parse_js(source: &str) -> (Tree, ImportBlock) {
        let grammar = grammar_for(LangId::JavaScript);
        let mut parser = Parser::new();
        parser.set_language(&grammar).unwrap();
        let tree = parser.parse(source, None).unwrap();
        let block = parse_imports(source, &tree, LangId::JavaScript);
        (tree, block)
    }

    // --- TypeScript / JavaScript parsing ---

    #[test]
    fn parse_ts_named_imports() {
        let source = "import { useState, useEffect } from 'react';\n";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 1);
        let imp = &block.imports[0];
        assert_eq!(imp.module_path, "react");
        assert!(imp.names.contains(&"useState".to_string()));
        assert!(imp.names.contains(&"useEffect".to_string()));
        assert_eq!(imp.kind, ImportKind::Value);
        assert_eq!(imp.group, ImportGroup::External);
    }

    #[test]
    fn parse_ts_default_import() {
        let source = "import React from 'react';\n";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 1);
        let imp = &block.imports[0];
        assert_eq!(imp.default_import.as_deref(), Some("React"));
        assert_eq!(imp.kind, ImportKind::Value);
    }

    #[test]
    fn parse_ts_side_effect_import() {
        let source = "import './styles.css';\n";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 1);
        assert_eq!(block.imports[0].kind, ImportKind::SideEffect);
        assert_eq!(block.imports[0].module_path, "./styles.css");
    }

    #[test]
    fn parse_ts_relative_import() {
        let source = "import { helper } from './utils';\n";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 1);
        assert_eq!(block.imports[0].group, ImportGroup::Internal);
    }

    #[test]
    fn parse_ts_multiple_groups() {
        let source = "\
import React from 'react';
import { useState } from 'react';
import { helper } from './utils';
import { Config } from '../config';
";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 4);

        let external: Vec<_> = block
            .imports
            .iter()
            .filter(|i| i.group == ImportGroup::External)
            .collect();
        let relative: Vec<_> = block
            .imports
            .iter()
            .filter(|i| i.group == ImportGroup::Internal)
            .collect();
        assert_eq!(external.len(), 2);
        assert_eq!(relative.len(), 2);
    }

    #[test]
    fn parse_ts_namespace_import() {
        let source = "import * as path from 'path';\n";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 1);
        let imp = &block.imports[0];
        assert_eq!(imp.namespace_import.as_deref(), Some("path"));
        assert_eq!(imp.kind, ImportKind::Value);
    }

    #[test]
    fn parse_js_imports() {
        let source = "import { readFile } from 'fs';\nimport { helper } from './helper';\n";
        let (_, block) = parse_js(source);
        assert_eq!(block.imports.len(), 2);
        assert_eq!(block.imports[0].group, ImportGroup::External);
        assert_eq!(block.imports[1].group, ImportGroup::Internal);
    }

    // --- TypeScript group classification ---

    #[test]
    fn classify_external() {
        assert_eq!(classify_group_ts("react"), ImportGroup::External);
        assert_eq!(classify_group_ts("@scope/pkg"), ImportGroup::External);
        assert_eq!(classify_group_ts("lodash/map"), ImportGroup::External);
    }

    #[test]
    fn classify_relative() {
        assert_eq!(classify_group_ts("./utils"), ImportGroup::Internal);
        assert_eq!(classify_group_ts("../config"), ImportGroup::Internal);
        assert_eq!(classify_group_ts("./"), ImportGroup::Internal);
    }

    // --- Duplicate detection ---

    #[test]
    fn dedup_detects_same_named_import() {
        let source = "import { useState } from 'react';\n";
        let (_, block) = parse_ts(source);
        assert!(is_duplicate(
            &block,
            "react",
            &["useState".to_string()],
            None,
            false
        ));
    }

    #[test]
    fn dedup_misses_different_name() {
        let source = "import { useState } from 'react';\n";
        let (_, block) = parse_ts(source);
        assert!(!is_duplicate(
            &block,
            "react",
            &["useEffect".to_string()],
            None,
            false
        ));
    }

    #[test]
    fn dedup_detects_default_import() {
        let source = "import React from 'react';\n";
        let (_, block) = parse_ts(source);
        assert!(is_duplicate(&block, "react", &[], Some("React"), false));
    }

    #[test]
    fn dedup_side_effect() {
        let source = "import './styles.css';\n";
        let (_, block) = parse_ts(source);
        assert!(is_duplicate(&block, "./styles.css", &[], None, false));
    }

    #[test]
    fn dedup_namespace_import_distinct_from_side_effect_import() {
        let side_effect_source = "import 'fs';\n";
        let (_, side_effect_block) = parse_ts(side_effect_source);
        assert!(!is_duplicate_with_namespace(
            &side_effect_block,
            "fs",
            &[],
            None,
            Some("fs"),
            false
        ));

        let namespace_source = "import * as fs from 'fs';\n";
        let (_, namespace_block) = parse_ts(namespace_source);
        assert!(!is_duplicate(&namespace_block, "fs", &[], None, false));
        assert!(is_duplicate_with_namespace(
            &namespace_block,
            "fs",
            &[],
            None,
            Some("fs"),
            false
        ));
        assert!(!is_duplicate_with_namespace(
            &namespace_block,
            "fs",
            &[],
            None,
            Some("other"),
            false
        ));
    }

    #[test]
    fn dedup_type_vs_value() {
        let source = "import { FC } from 'react';\n";
        let (_, block) = parse_ts(source);
        // A value import of `FC` must not be treated as a duplicate of a type import.
        assert!(!is_duplicate(
            &block,
            "react",
            &["FC".to_string()],
            None,
            true
        ));
    }

    // --- TypeScript import-line generation ---

    #[test]
    fn generate_named_import() {
        let line = generate_import_line(
            LangId::TypeScript,
            "react",
            &["useState".to_string(), "useEffect".to_string()],
            None,
            false,
        );
        assert_eq!(line, "import { useEffect, useState } from 'react';");
    }

    #[test]
    fn generate_named_import_sorts_by_imported_name() {
        let line = generate_import_line(
            LangId::TypeScript,
            "x",
            &[
                "useState".to_string(),
                "type Foo".to_string(),
                "stdin as input".to_string(),
                "type Bar".to_string(),
            ],
            None,
            false,
        );
        assert_eq!(
            line,
            "import { type Bar, type Foo, stdin as input, useState } from 'x';"
        );
    }

    #[test]
    fn generate_default_import() {
        let line = generate_import_line(LangId::TypeScript, "react", &[], Some("React"), false);
        assert_eq!(line, "import React from 'react';");
    }

    #[test]
    fn generate_type_import() {
        let line =
            generate_import_line(LangId::TypeScript, "react", &["FC".to_string()], None, true);
        assert_eq!(line, "import type { FC } from 'react';");
    }

    #[test]
    fn generate_side_effect_import() {
        let line = generate_import_line(LangId::TypeScript, "./styles.css", &[], None, false);
        assert_eq!(line, "import './styles.css';");
    }

    #[test]
    fn generate_default_and_named() {
        let line = generate_import_line(
            LangId::TypeScript,
            "react",
            &["useState".to_string()],
            Some("React"),
            false,
        );
        assert_eq!(line, "import React, { useState } from 'react';");
    }

    #[test]
    fn parse_ts_type_import() {
        let source = "import type { FC } from 'react';\n";
        let (_, block) = parse_ts(source);
        assert_eq!(block.imports.len(), 1);
        let imp = &block.imports[0];
        assert_eq!(imp.kind, ImportKind::Type);
        assert!(imp.names.contains(&"FC".to_string()));
        assert_eq!(imp.group, ImportGroup::External);
    }

    // --- Insertion point ---

    #[test]
    fn insertion_empty_file() {
        let source = "";
        let (_, block) = parse_ts(source);
        let (offset, _, _) =
            find_insertion_point(source, &block, ImportGroup::External, "react", false);
        assert_eq!(offset, 0);
    }

    #[test]
    fn insertion_alphabetical_within_group() {
        let source = "\
import { a } from 'alpha';
import { c } from 'charlie';
";
        let (_, block) = parse_ts(source);
        let (offset, _, _) =
            find_insertion_point(source, &block, ImportGroup::External, "bravo", false);
        let before_charlie = source.find("import { c }").unwrap();
        // `bravo` sorts between `alpha` and `charlie`, so it goes right before `charlie`.
        assert_eq!(offset, before_charlie);
    }

    // --- Python ---

    /// Parses `source` as Python and returns the syntax tree plus its import block.
    fn parse_py(source: &str) -> (Tree, ImportBlock) {
        let grammar = grammar_for(LangId::Python);
        let mut parser = Parser::new();
        parser.set_language(&grammar).unwrap();
        let tree = parser.parse(source, None).unwrap();
        let block = parse_imports(source, &tree, LangId::Python);
        (tree, block)
    }

    #[test]
    fn parse_py_import_statement() {
        let source = "import os\nimport sys\n";
        let (_, block) = parse_py(source);
        assert_eq!(block.imports.len(), 2);
        assert_eq!(block.imports[0].module_path, "os");
        assert_eq!(block.imports[1].module_path, "sys");
        assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
    }

    #[test]
    fn parse_py_from_import() {
        let source = "from collections import OrderedDict\nfrom typing import List, Optional\n";
        let (_, block) = parse_py(source);
        assert_eq!(block.imports.len(), 2);
        assert_eq!(block.imports[0].module_path, "collections");
        assert!(block.imports[0].names.contains(&"OrderedDict".to_string()));
        assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
        assert_eq!(block.imports[1].module_path, "typing");
        assert!(block.imports[1].names.contains(&"List".to_string()));
        assert!(block.imports[1].names.contains(&"Optional".to_string()));
    }

    #[test]
    fn parse_py_relative_import() {
        let source = "from . import utils\nfrom ..config import Settings\n";
        let (_, block) = parse_py(source);
        assert_eq!(block.imports.len(), 2);
        assert_eq!(block.imports[0].module_path, ".");
        assert!(block.imports[0].names.contains(&"utils".to_string()));
        assert_eq!(block.imports[0].group, ImportGroup::Internal);
        assert_eq!(block.imports[1].module_path, "..config");
        assert_eq!(block.imports[1].group, ImportGroup::Internal);
    }

    #[test]
    fn classify_py_groups() {
        assert_eq!(classify_group_py("os"), ImportGroup::Stdlib);
        assert_eq!(classify_group_py("sys"), ImportGroup::Stdlib);
        assert_eq!(classify_group_py("json"), ImportGroup::Stdlib);
        assert_eq!(classify_group_py("collections"), ImportGroup::Stdlib);
        assert_eq!(classify_group_py("os.path"), ImportGroup::Stdlib);
        assert_eq!(classify_group_py("requests"), ImportGroup::External);
        assert_eq!(classify_group_py("flask"), ImportGroup::External);
        assert_eq!(classify_group_py("."), ImportGroup::Internal);
        assert_eq!(classify_group_py("..config"), ImportGroup::Internal);
        assert_eq!(classify_group_py(".utils"), ImportGroup::Internal);
    }

    #[test]
    fn parse_py_three_groups() {
        let source = "import os\nimport sys\n\nimport requests\n\nfrom . import utils\n";
        let (_, block) = parse_py(source);
        let stdlib: Vec<_> = block
            .imports
            .iter()
            .filter(|i| i.group == ImportGroup::Stdlib)
            .collect();
        let external: Vec<_> = block
            .imports
            .iter()
            .filter(|i| i.group == ImportGroup::External)
            .collect();
        let internal: Vec<_> = block
            .imports
            .iter()
            .filter(|i| i.group == ImportGroup::Internal)
            .collect();
        assert_eq!(stdlib.len(), 2);
        assert_eq!(external.len(), 1);
        assert_eq!(internal.len(), 1);
    }

    #[test]
    fn generate_py_import() {
        let line = generate_import_line(LangId::Python, "os", &[], None, false);
        assert_eq!(line, "import os");
    }

    #[test]
    fn generate_py_from_import() {
        let line = generate_import_line(
            LangId::Python,
            "collections",
            &["OrderedDict".to_string()],
            None,
            false,
        );
        assert_eq!(line, "from collections import OrderedDict");
    }

    #[test]
    fn generate_py_from_import_multiple() {
        let line = generate_import_line(
            LangId::Python,
            "typing",
            &["Optional".to_string(), "List".to_string()],
            None,
            false,
        );
        assert_eq!(line, "from typing import List, Optional");
    }

    // --- Rust ---

    /// Parses `source` as Rust and returns the syntax tree plus its import block.
    fn parse_rust(source: &str) -> (Tree, ImportBlock) {
        let grammar = grammar_for(LangId::Rust);
        let mut parser = Parser::new();
        parser.set_language(&grammar).unwrap();
        let tree = parser.parse(source, None).unwrap();
        let block = parse_imports(source, &tree, LangId::Rust);
        (tree, block)
    }

    #[test]
    fn parse_rs_use_std() {
        let source = "use std::collections::HashMap;\nuse std::io::Read;\n";
        let (_, block) = parse_rust(source);
        assert_eq!(block.imports.len(), 2);
        assert_eq!(block.imports[0].module_path, "std::collections::HashMap");
        assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
        assert_eq!(block.imports[1].group, ImportGroup::Stdlib);
    }

    #[test]
    fn parse_rs_use_external() {
        let source = "use serde::{Deserialize, Serialize};\n";
        let (_, block) = parse_rust(source);
        assert_eq!(block.imports.len(), 1);
        assert_eq!(block.imports[0].group, ImportGroup::External);
        assert!(block.imports[0].names.contains(&"Deserialize".to_string()));
        assert!(block.imports[0].names.contains(&"Serialize".to_string()));
    }

    #[test]
    fn parse_rs_use_crate() {
        let source = "use crate::config::Settings;\nuse super::parent::Thing;\n";
        let (_, block) = parse_rust(source);
        assert_eq!(block.imports.len(), 2);
        assert_eq!(block.imports[0].group, ImportGroup::Internal);
        assert_eq!(block.imports[1].group, ImportGroup::Internal);
    }

    #[test]
    fn parse_rs_pub_use() {
        let source = "pub use super::parent::Thing;\n";
        let (_, block) = parse_rust(source);
        assert_eq!(block.imports.len(), 1);
        // The parser records the `pub` visibility in `default_import` for Rust `use` items.
        assert_eq!(block.imports[0].default_import.as_deref(), Some("pub"));
    }

    #[test]
    fn classify_rs_groups() {
        assert_eq!(
            classify_group_rs("std::collections::HashMap"),
            ImportGroup::Stdlib
        );
        assert_eq!(classify_group_rs("core::mem"), ImportGroup::Stdlib);
        assert_eq!(classify_group_rs("alloc::vec"), ImportGroup::Stdlib);
        assert_eq!(
            classify_group_rs("serde::Deserialize"),
            ImportGroup::External
        );
        assert_eq!(classify_group_rs("tokio::runtime"), ImportGroup::External);
        assert_eq!(classify_group_rs("crate::config"), ImportGroup::Internal);
        assert_eq!(classify_group_rs("self::utils"), ImportGroup::Internal);
        assert_eq!(classify_group_rs("super::parent"), ImportGroup::Internal);
    }

    #[test]
    fn generate_rs_use() {
        let line = generate_import_line(LangId::Rust, "std::fmt::Display", &[], None, false);
        assert_eq!(line, "use std::fmt::Display;");
    }

    // --- Go ---

    /// Parses `source` as Go and returns the syntax tree plus its import block.
    fn parse_go(source: &str) -> (Tree, ImportBlock) {
        let grammar = grammar_for(LangId::Go);
        let mut parser = Parser::new();
        parser.set_language(&grammar).unwrap();
        let tree = parser.parse(source, None).unwrap();
        let block = parse_imports(source, &tree, LangId::Go);
        (tree, block)
    }

    #[test]
    fn parse_go_single_import() {
        let source = "package main\n\nimport \"fmt\"\n";
        let (_, block) = parse_go(source);
        assert_eq!(block.imports.len(), 1);
        assert_eq!(block.imports[0].module_path, "fmt");
        assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
    }

    #[test]
    fn parse_go_grouped_import() {
        let source =
            "package main\n\nimport (\n\t\"fmt\"\n\t\"os\"\n\n\t\"github.com/pkg/errors\"\n)\n";
        let (_, block) = parse_go(source);
        assert_eq!(block.imports.len(), 3);
        assert_eq!(block.imports[0].module_path, "fmt");
        assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
        assert_eq!(block.imports[1].module_path, "os");
        assert_eq!(block.imports[1].group, ImportGroup::Stdlib);
        assert_eq!(block.imports[2].module_path, "github.com/pkg/errors");
        assert_eq!(block.imports[2].group, ImportGroup::External);
    }

    #[test]
    fn parse_go_mixed_imports() {
        let source = "package main\n\nimport \"fmt\"\n\nimport (\n\t\"os\"\n\t\"github.com/pkg/errors\"\n)\n";
        let (_, block) = parse_go(source);
        assert_eq!(block.imports.len(), 3);
    }

    #[test]
    fn classify_go_groups() {
        assert_eq!(classify_group_go("fmt"), ImportGroup::Stdlib);
        assert_eq!(classify_group_go("os"), ImportGroup::Stdlib);
        assert_eq!(classify_group_go("net/http"), ImportGroup::Stdlib);
        assert_eq!(classify_group_go("encoding/json"), ImportGroup::Stdlib);
        assert_eq!(
            classify_group_go("github.com/pkg/errors"),
            ImportGroup::External
        );
        assert_eq!(
            classify_group_go("golang.org/x/tools"),
            ImportGroup::External
        );
    }

    #[test]
    fn generate_go_standalone() {
        let line = generate_go_import_line("fmt", None, false);
        assert_eq!(line, "import \"fmt\"");
    }

    #[test]
    fn generate_go_grouped_spec() {
        let line = generate_go_import_line("fmt", None, true);
        assert_eq!(line, "\t\"fmt\"");
    }

    #[test]
    fn generate_go_with_alias() {
        let line = generate_go_import_line("github.com/pkg/errors", Some("errs"), false);
        assert_eq!(line, "import errs \"github.com/pkg/errors\"");
    }

    #[test]
    fn go_grouped_import_detection() {
        // Only single-spec imports: nothing grouped to report.
        let single = "package main\n\nimport \"fmt\"\n";
        let (tree, _) = parse_go(single);
        assert!(go_has_grouped_import(single, &tree).is_none());

        // The ungrouped `import "fmt"` is skipped; the grouped declaration's
        // full range (keyword through closing paren) is returned.
        let mixed = "package main\n\nimport \"fmt\"\n\nimport (\n\t\"os\"\n)\n";
        let (tree, _) = parse_go(mixed);
        let range = go_has_grouped_import(mixed, &tree).expect("grouped import present");
        assert_eq!(&mixed[range], "import (\n\t\"os\"\n)");
    }

    // --- Byte helpers ---

    #[test]
    fn skip_newline_consumes_one_line_ending() {
        assert_eq!(skip_newline("a\nb", 1), 2);
        assert_eq!(skip_newline("a\r\nb", 1), 3);
        assert_eq!(skip_newline("a\rb", 1), 2);
        assert_eq!(skip_newline("a\r", 1), 2);
        // Not at a line ending, at EOF, or past EOF: position is unchanged.
        assert_eq!(skip_newline("ab", 1), 1);
        assert_eq!(skip_newline("ab", 2), 2);
        assert_eq!(skip_newline("ab", 10), 10);
        assert_eq!(skip_newline("", 0), 0);
    }
}