use std::ops::Range;
use tree_sitter::{Node, Parser, Tree};
use crate::parser::{grammar_for, LangId};
/// How an import statement brings names into scope.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ImportKind {
    /// An ordinary import that binds runtime values. Every Python, Rust and Go import parsed by
    /// this module uses this kind, as do TypeScript imports without a `type` keyword.
    Value,
    /// A TypeScript type-only import (`import type ...`).
    Type,
    /// A TypeScript import that binds nothing, e.g. `import './styles.css';`.
    SideEffect,
}
/// Coarse bucket an import is sorted into.
///
/// Variant order is significant: the derived `Ord` yields `Stdlib < External < Internal`, and
/// `find_insertion_point` compares groups with `<` / `>` to locate neighbouring groups.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum ImportGroup {
    Stdlib,
    External,
    Internal,
}

impl ImportGroup {
    /// Lowercase display name of the group (`"stdlib"`, `"external"` or `"internal"`).
    pub fn label(&self) -> &'static str {
        match *self {
            Self::Stdlib => "stdlib",
            Self::External => "external",
            Self::Internal => "internal",
        }
    }
}
/// One parsed import, normalized across the supported languages.
#[derive(Clone, Debug)]
pub struct ImportStatement {
    /// What is being imported.
    /// - TypeScript/JavaScript: the module specifier with its surrounding quotes trimmed.
    /// - Python: the dotted or relative module path (e.g. `os.path`, `..config`).
    /// - Rust: the source text of the `use` tree (e.g. `std::io::Read`,
    ///   `serde::{Deserialize, Serialize}`).
    /// - Go: the import path without quotes.
    pub module_path: String,
    /// Individually imported names. Examples are TypeScript `{ a, b }` specifiers, Python
    /// `from x import a, b` names, and the items of a Rust `use path::{...}` list. Empty
    /// when the statement imports no individual names.
    pub names: Vec<String>,
    /// Per-language extra binding. It holds the TypeScript default import, the Go package
    /// alias when one is captured, and `Some("pub")` for a Rust `use` with a visibility
    /// modifier.
    pub default_import: Option<String>,
    /// Namespace binding, e.g. the `ns` of TypeScript `import * as ns from '...'`.
    pub namespace_import: Option<String>,
    /// Value / type-only / side-effect classification.
    pub kind: ImportKind,
    /// Group used for ordering and insertion.
    pub group: ImportGroup,
    /// Byte span of the parsed node in the source. For Go this is the individual import spec,
    /// not the whole `import (...)` declaration.
    pub byte_range: Range<usize>,
    /// Verbatim source text of `byte_range`.
    pub raw_text: String,
}
/// The top-level imports of one file, in the order the parsers encountered them.
#[derive(Clone, Debug)]
pub struct ImportBlock {
    /// Parsed imports, in source order.
    pub imports: Vec<ImportStatement>,
    /// Span from the start of the first import to the end of the last one; `None` when there
    /// are no imports. Any non-import code between scattered imports falls inside this span.
    pub byte_range: Option<Range<usize>>,
}

impl ImportBlock {
    /// A block with no imports and no span.
    pub fn empty() -> Self {
        Self {
            imports: vec![],
            byte_range: None,
        }
    }
}
/// Span covering every statement in `imports`, from the start of the first element to the end
/// of the last. It assumes `imports` is in source order and returns `None` for an empty slice.
fn import_byte_range(imports: &[ImportStatement]) -> Option<Range<usize>> {
    let head = imports.first()?;
    let tail = imports.last()?;
    Some(head.byte_range.start..tail.byte_range.end)
}
pub fn parse_imports(source: &str, tree: &Tree, lang: LangId) -> ImportBlock {
match lang {
LangId::TypeScript | LangId::Tsx | LangId::JavaScript => parse_ts_imports(source, tree),
LangId::Python => parse_py_imports(source, tree),
LangId::Rust => parse_rs_imports(source, tree),
LangId::Go => parse_go_imports(source, tree),
LangId::C | LangId::Cpp | LangId::Zig | LangId::CSharp => ImportBlock::empty(),
LangId::Html | LangId::Markdown => ImportBlock::empty(),
}
}
/// Returns `true` when `block` already contains an import that satisfies the requested one.
///
/// Only imports with exactly the same `module_path` are considered. A request is treated as
/// satisfied in any of these cases:
/// - It is bare (no names, no default) and the existing import is also bare or side-effect only.
/// - The existing import has a compatible kind (same value/type kind, or side-effect) and
///   binds the requested default.
/// - The existing import has a compatible kind and already lists every requested name.
pub fn is_duplicate(
    block: &ImportBlock,
    module_path: &str,
    names: &[String],
    default_import: Option<&str>,
    type_only: bool,
) -> bool {
    let wanted_kind = if type_only {
        ImportKind::Type
    } else {
        ImportKind::Value
    };
    let bare_request = names.is_empty() && default_import.is_none();

    block
        .imports
        .iter()
        .filter(|existing| existing.module_path == module_path)
        .any(|existing| {
            let existing_is_bare =
                existing.names.is_empty() && existing.default_import.is_none();
            if bare_request && (existing_is_bare || existing.kind == ImportKind::SideEffect) {
                return true;
            }
            // A value import never satisfies a type-only request and vice versa.
            if existing.kind != wanted_kind && existing.kind != ImportKind::SideEffect {
                return false;
            }
            let default_matches =
                default_import.is_some() && existing.default_import.as_deref() == default_import;
            let names_covered =
                !names.is_empty() && names.iter().all(|name| existing.names.contains(name));
            default_matches || names_covered
        })
}
/// Chooses the byte offset at which a new import of `module_path` (in `group`) should go.
///
/// Returns `(offset, sep_before, sep_after)`.
/// NOTE(review): the two flags appear to mean "emit a blank line before / after the new import"
/// (they are only `true` when a new group is started or the file had no imports) — confirm
/// against callers.
///
/// Placement rules:
/// - No imports at all: offset 0. `sep_after` is set only when the file already has content.
///   No preamble (e.g. a Go `package` clause or a Python docstring) is skipped here.
/// - The group has no imports yet: the new import goes right after the last import of an
///   earlier group, or else right before the first import of a later group.
/// - Otherwise the import is placed alphabetically by module path within its group, assuming
///   the group is already sorted. For an equal path, a type-only import goes after an
///   existing value import and anything else goes before the existing import.
///
/// Offsets placed after an existing import skip one trailing line break (`\n`, `\r\n` or `\r`).
pub fn find_insertion_point(
    source: &str,
    block: &ImportBlock,
    group: ImportGroup,
    module_path: &str,
    type_only: bool,
) -> (usize, bool, bool) {
    if block.imports.is_empty() {
        return (0, false, !source.is_empty());
    }

    let group_imports: Vec<&ImportStatement> =
        block.imports.iter().filter(|i| i.group == group).collect();

    let Some(&last_in_group) = group_imports.last() else {
        // Start a new group next to its closest neighbour.
        if let Some(prev) = block.imports.iter().rev().find(|i| i.group < group) {
            return (skip_newline(source, prev.byte_range.end), true, true);
        }
        if let Some(next) = block.imports.iter().find(|i| i.group > group) {
            return (next.byte_range.start, false, true);
        }
        // Defensive only: with a non-empty block and a total order on groups, one of the two
        // searches above always succeeds.
        let first_start = block.imports.first().map_or(0, |i| i.byte_range.start);
        return (first_start, false, true);
    };

    let wanted_kind = if type_only {
        ImportKind::Type
    } else {
        ImportKind::Value
    };

    for imp in &group_imports {
        match module_path.cmp(imp.module_path.as_str()) {
            std::cmp::Ordering::Less => return (imp.byte_range.start, false, false),
            std::cmp::Ordering::Equal
                if wanted_kind == ImportKind::Type && imp.kind == ImportKind::Value =>
            {
                // Keep `import type` lines after the value import of the same module.
                return (skip_newline(source, imp.byte_range.end), false, false);
            }
            std::cmp::Ordering::Equal => return (imp.byte_range.start, false, false),
            std::cmp::Ordering::Greater => {}
        }
    }

    // Sorts after every import in the group: append after the group's last import.
    (skip_newline(source, last_in_group.byte_range.end), false, false)
}
/// Renders a single import statement for `lang`, without a trailing newline.
///
/// Languages without an import model here (C, C++, Zig, C#, HTML, Markdown) yield an empty
/// string.
pub fn generate_import_line(
    lang: LangId,
    module_path: &str,
    names: &[String],
    default_import: Option<&str>,
    type_only: bool,
) -> String {
    match lang {
        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
            generate_ts_import_line(module_path, names, default_import, type_only)
        }
        LangId::Python => generate_py_import_line(module_path, names, default_import),
        LangId::Rust => generate_rs_import_line(module_path, names, type_only),
        // Go reuses `default_import` as the optional package alias and always renders a
        // standalone (ungrouped) `import` declaration here.
        LangId::Go => generate_go_import_line(module_path, default_import, false),
        LangId::C
        | LangId::Cpp
        | LangId::Zig
        | LangId::CSharp
        | LangId::Html
        | LangId::Markdown => String::new(),
    }
}
/// Whether imports can be parsed and generated for `lang`.
///
/// The languages answering `false` are the ones that `parse_imports` maps to an empty block and
/// `generate_import_line` maps to an empty string. The match is exhaustive on purpose, so a new
/// `LangId` variant forces a decision here.
pub fn is_supported(lang: LangId) -> bool {
    match lang {
        LangId::TypeScript
        | LangId::Tsx
        | LangId::JavaScript
        | LangId::Python
        | LangId::Rust
        | LangId::Go => true,
        LangId::C
        | LangId::Cpp
        | LangId::Zig
        | LangId::CSharp
        | LangId::Html
        | LangId::Markdown => false,
    }
}
/// Groups a TypeScript/JavaScript module specifier.
///
/// Relative specifiers (starting with `.`) are `Internal`; everything else, including bare
/// package names and scoped packages, is `External`. No specifier is classified as `Stdlib`.
pub fn classify_group_ts(module_path: &str) -> ImportGroup {
    match module_path.as_bytes().first() {
        Some(b'.') => ImportGroup::Internal,
        _ => ImportGroup::External,
    }
}
/// Classifies `module_path` into an [`ImportGroup`] using the rules for `lang`.
///
/// Languages without an import model here report every path as `External`.
pub fn classify_group(lang: LangId, module_path: &str) -> ImportGroup {
    match lang {
        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => classify_group_ts(module_path),
        LangId::Python => classify_group_py(module_path),
        LangId::Rust => classify_group_rs(module_path),
        LangId::Go => classify_group_go(module_path),
        LangId::C
        | LangId::Cpp
        | LangId::Zig
        | LangId::CSharp
        | LangId::Html
        | LangId::Markdown => ImportGroup::External,
    }
}
pub fn parse_file_imports(
path: &std::path::Path,
lang: LangId,
) -> Result<(String, Tree, ImportBlock), crate::error::AftError> {
let source =
std::fs::read_to_string(path).map_err(|e| crate::error::AftError::FileNotFound {
path: format!("{}: {}", path.display(), e),
})?;
let grammar = grammar_for(lang);
let mut parser = Parser::new();
parser
.set_language(&grammar)
.map_err(|e| crate::error::AftError::ParseError {
message: format!("grammar init failed for {:?}: {}", lang, e),
})?;
let tree = parser
.parse(&source, None)
.ok_or_else(|| crate::error::AftError::ParseError {
message: format!("tree-sitter parse returned None for {}", path.display()),
})?;
let block = parse_imports(&source, &tree, lang);
Ok((source, tree, block))
}
fn parse_ts_imports(source: &str, tree: &Tree) -> ImportBlock {
let root = tree.root_node();
let mut imports = Vec::new();
let mut cursor = root.walk();
if !cursor.goto_first_child() {
return ImportBlock::empty();
}
loop {
let node = cursor.node();
if node.kind() == "import_statement" {
if let Some(imp) = parse_single_ts_import(source, &node) {
imports.push(imp);
}
}
if !cursor.goto_next_sibling() {
break;
}
}
let byte_range = import_byte_range(&imports);
ImportBlock {
imports,
byte_range,
}
}
/// Converts one TypeScript/JavaScript `import_statement` node into an [`ImportStatement`].
///
/// Returns `None` when the node has no direct `string` child, i.e. no module specifier. The
/// kind is derived from the bindings: an import with no names, default or namespace binding is
/// `SideEffect`; otherwise it is `Type` if a `type` keyword child is present, else `Value`.
fn parse_single_ts_import(source: &str, node: &Node) -> Option<ImportStatement> {
    let span = node.byte_range();
    let module_path = extract_module_path(source, node)?;
    let type_only = has_type_keyword(node);

    let mut names = Vec::new();
    let mut default_import = None;
    let mut namespace_import = None;

    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        match child.kind() {
            "import_clause" => extract_import_clause(
                source,
                &child,
                &mut names,
                &mut default_import,
                &mut namespace_import,
            ),
            // A bare identifier directly under the statement is treated as a default binding.
            "identifier" => {
                let text = &source[child.byte_range()];
                if !matches!(text, "import" | "from" | "type") {
                    default_import = Some(text.to_string());
                }
            }
            _ => {}
        }
    }

    let binds_nothing =
        names.is_empty() && default_import.is_none() && namespace_import.is_none();
    let kind = match (binds_nothing, type_only) {
        (true, _) => ImportKind::SideEffect,
        (false, true) => ImportKind::Type,
        (false, false) => ImportKind::Value,
    };

    Some(ImportStatement {
        raw_text: source[span.clone()].to_string(),
        group: classify_group_ts(&module_path),
        module_path,
        names,
        default_import,
        namespace_import,
        kind,
        byte_range: span,
    })
}
/// Returns the text of the first direct `string` child of `node`, with any leading and
/// trailing `'` / `"` characters removed. Returns `None` when there is no such child.
fn extract_module_path(source: &str, node: &Node) -> Option<String> {
    let mut cursor = node.walk();
    let string_node = node
        .children(&mut cursor)
        .find(|child| child.kind() == "string")?;
    let is_quote = |c: char| c == '\'' || c == '"';
    let literal = &source[string_node.byte_range()];
    Some(
        literal
            .trim_start_matches(is_quote)
            .trim_end_matches(is_quote)
            .to_string(),
    )
}
/// Whether `node` has a direct child of kind `type`, i.e. the `type` keyword in
/// `import type ...`.
fn has_type_keyword(node: &Node) -> bool {
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        if child.kind() == "type" {
            return true;
        }
    }
    false
}
/// Reads the bindings of an `import_clause` node into the output parameters.
///
/// An `identifier` child (other than `type`) becomes the default import. `named_imports`
/// children are appended to `names`. A `namespace_import` child sets `namespace_import`.
fn extract_import_clause(
    source: &str,
    node: &Node,
    names: &mut Vec<String>,
    default_import: &mut Option<String>,
    namespace_import: &mut Option<String>,
) {
    let mut cursor = node.walk();
    for part in node.children(&mut cursor) {
        let kind = part.kind();
        if kind == "identifier" {
            let ident = &source[part.byte_range()];
            if ident != "type" {
                *default_import = Some(ident.to_string());
            }
        } else if kind == "named_imports" {
            extract_named_imports(source, &part, names);
        } else if kind == "namespace_import" {
            extract_namespace_import(source, &part, namespace_import);
        }
    }
}
/// Appends the name of every `import_specifier` inside a `named_imports` node to `names`.
///
/// It uses the specifier's `name` field (for `{ a as b }` that is `a`). When that field is
/// missing, it falls back to the first `identifier` / `type_identifier` child.
fn extract_named_imports(source: &str, node: &Node, names: &mut Vec<String>) {
    let mut cursor = node.walk();
    for spec in node.children(&mut cursor) {
        if spec.kind() != "import_specifier" {
            continue;
        }
        let name_node = spec.child_by_field_name("name").or_else(|| {
            let mut inner = spec.walk();
            let fallback = spec
                .children(&mut inner)
                .find(|c| matches!(c.kind(), "identifier" | "type_identifier"));
            fallback
        });
        if let Some(name_node) = name_node {
            names.push(source[name_node.byte_range()].to_string());
        }
    }
}
/// Stores the first `identifier` child of a `namespace_import` node (the `ns` in `* as ns`)
/// into `namespace_import`. It leaves `namespace_import` untouched if there is none.
fn extract_namespace_import(source: &str, node: &Node, namespace_import: &mut Option<String>) {
    let mut cursor = node.walk();
    let binding = node
        .children(&mut cursor)
        .find(|child| child.kind() == "identifier");
    if let Some(binding) = binding {
        *namespace_import = Some(source[binding.byte_range()].to_string());
    }
}
/// Renders a TypeScript/JavaScript import with single-quoted specifier and trailing `;`.
///
/// - No names and no default: side-effect import `import 'm';`. `type_only` is ignored
///   because a bindingless import cannot be type-only.
/// - Named bindings are emitted sorted, as `{ a, b }`.
/// - `type_only` adds the `type` modifier. TypeScript rejects a type-only import that has both a
///   default and named bindings (error TS1363), so that combination is emitted as two
///   `import type` statements separated by `\n`.
fn generate_ts_import_line(
    module_path: &str,
    names: &[String],
    default_import: Option<&str>,
    type_only: bool,
) -> String {
    let type_prefix = if type_only { "type " } else { "" };
    let named_list = || {
        let mut sorted = names.to_vec();
        sorted.sort();
        sorted.join(", ")
    };

    match (default_import, names.is_empty()) {
        (None, true) => format!("import '{module_path}';"),
        (Some(def), true) => format!("import {type_prefix}{def} from '{module_path}';"),
        (None, false) => format!(
            "import {type_prefix}{{ {} }} from '{module_path}';",
            named_list()
        ),
        (Some(def), false) if type_only => format!(
            "import type {def} from '{module_path}';\nimport type {{ {} }} from '{module_path}';",
            named_list()
        ),
        (Some(def), false) => format!(
            "import {def}, {{ {} }} from '{module_path}';",
            named_list()
        ),
    }
}
/// Top-level module names of the CPython standard library, used by `classify_group_py`.
///
/// This is the union over recent Python 3 releases. It includes platform-specific modules
/// (e.g. `winreg`, `msvcrt`) and modules removed in newer versions (e.g. `telnetlib`, `uu`),
/// so that code targeting any of those versions classifies correctly. Generic names that are
/// commonly reused by user packages (`test`, `parser`, `formatter`, `symbol`) are deliberately
/// omitted. The authoritative per-version list is `sys.stdlib_module_names` (Python 3.10+).
const PYTHON_STDLIB: &[&str] = &[
    "__future__",
    "_thread",
    "abc",
    "aifc",
    "argparse",
    "array",
    "ast",
    "asynchat",
    "asyncio",
    "asyncore",
    "atexit",
    "audioop",
    "base64",
    "bdb",
    "binascii",
    "binhex",
    "bisect",
    "builtins",
    "bz2",
    "calendar",
    "cgi",
    "cgitb",
    "chunk",
    "cmath",
    "cmd",
    "code",
    "codecs",
    "codeop",
    "collections",
    "colorsys",
    "compileall",
    "concurrent",
    "configparser",
    "contextlib",
    "contextvars",
    "copy",
    "copyreg",
    "cProfile",
    "crypt",
    "csv",
    "ctypes",
    "curses",
    "dataclasses",
    "datetime",
    "dbm",
    "decimal",
    "difflib",
    "dis",
    "distutils",
    "doctest",
    "email",
    "encodings",
    "ensurepip",
    "enum",
    "errno",
    "faulthandler",
    "fcntl",
    "filecmp",
    "fileinput",
    "fnmatch",
    "fractions",
    "ftplib",
    "functools",
    "gc",
    "genericpath",
    "getopt",
    "getpass",
    "gettext",
    "glob",
    "graphlib",
    "grp",
    "gzip",
    "hashlib",
    "heapq",
    "hmac",
    "html",
    "http",
    "idlelib",
    "imaplib",
    "imghdr",
    "imp",
    "importlib",
    "inspect",
    "io",
    "ipaddress",
    "itertools",
    "json",
    "keyword",
    "lib2to3",
    "linecache",
    "locale",
    "logging",
    "lzma",
    "mailbox",
    "mailcap",
    "marshal",
    "math",
    "mimetypes",
    "mmap",
    "modulefinder",
    "msilib",
    "msvcrt",
    "multiprocessing",
    "netrc",
    "nis",
    "nntplib",
    "ntpath",
    "numbers",
    "opcode",
    "operator",
    "optparse",
    "os",
    "ossaudiodev",
    "pathlib",
    "pdb",
    "pickle",
    "pickletools",
    "pipes",
    "pkgutil",
    "platform",
    "plistlib",
    "poplib",
    "posix",
    "posixpath",
    "pprint",
    "profile",
    "pstats",
    "pty",
    "pwd",
    "py_compile",
    "pyclbr",
    "pydoc",
    "pyexpat",
    "queue",
    "quopri",
    "random",
    "re",
    "readline",
    "reprlib",
    "resource",
    "rlcompleter",
    "runpy",
    "sched",
    "secrets",
    "select",
    "selectors",
    "shelve",
    "shlex",
    "shutil",
    "signal",
    "site",
    "smtpd",
    "smtplib",
    "sndhdr",
    "socket",
    "socketserver",
    "spwd",
    "sqlite3",
    "sre_compile",
    "sre_constants",
    "sre_parse",
    "ssl",
    "stat",
    "statistics",
    "string",
    "stringprep",
    "struct",
    "subprocess",
    "sunau",
    "symtable",
    "sys",
    "sysconfig",
    "syslog",
    "tabnanny",
    "tarfile",
    "telnetlib",
    "tempfile",
    "termios",
    "textwrap",
    "threading",
    "time",
    "timeit",
    "tkinter",
    "token",
    "tokenize",
    "tomllib",
    "trace",
    "traceback",
    "tracemalloc",
    "tty",
    "turtle",
    "types",
    "typing",
    "unicodedata",
    "unittest",
    "urllib",
    "uu",
    "uuid",
    "venv",
    "warnings",
    "wave",
    "weakref",
    "webbrowser",
    "winreg",
    "winsound",
    "wsgiref",
    "xdrlib",
    "xml",
    "xmlrpc",
    "zipapp",
    "zipfile",
    "zipimport",
    "zlib",
    "zoneinfo",
];
/// Groups a Python module path.
///
/// Relative paths (leading `.`) are `Internal`. Otherwise the first dotted segment is looked up
/// in `PYTHON_STDLIB`: a hit is `Stdlib`, anything else is `External`.
pub fn classify_group_py(module_path: &str) -> ImportGroup {
    if module_path.starts_with('.') {
        return ImportGroup::Internal;
    }
    let top_level = match module_path.split_once('.') {
        Some((head, _rest)) => head,
        None => module_path,
    };
    if PYTHON_STDLIB.contains(&top_level) {
        ImportGroup::Stdlib
    } else {
        ImportGroup::External
    }
}
fn parse_py_imports(source: &str, tree: &Tree) -> ImportBlock {
let root = tree.root_node();
let mut imports = Vec::new();
let mut cursor = root.walk();
if !cursor.goto_first_child() {
return ImportBlock::empty();
}
loop {
let node = cursor.node();
match node.kind() {
"import_statement" => {
if let Some(imp) = parse_py_import_statement(source, &node) {
imports.push(imp);
}
}
"import_from_statement" => {
if let Some(imp) = parse_py_import_from_statement(source, &node) {
imports.push(imp);
}
}
_ => {}
}
if !cursor.goto_next_sibling() {
break;
}
}
let byte_range = import_byte_range(&imports);
ImportBlock {
imports,
byte_range,
}
}
/// Converts a Python `import x` / `import x as y` node into an [`ImportStatement`].
///
/// The first imported module is used: either a plain `dotted_name` child or the `name` field of
/// an `aliased_import` child. When an alias is present, it is recorded in `namespace_import`,
/// since `import numpy as np` binds the module object under `np`. Returns `None` when no module
/// can be found.
///
/// NOTE(review): a multi-module statement such as `import os, sys` is represented only by its
/// first module.
fn parse_py_import_statement(source: &str, node: &Node) -> Option<ImportStatement> {
    let mut module_node = None;
    let mut alias_node = None;
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        match child.kind() {
            "dotted_name" => {
                module_node = Some(child);
                break;
            }
            // Previously ignored, which made `import numpy as np` disappear from the block.
            "aliased_import" => {
                module_node = child.child_by_field_name("name");
                alias_node = child.child_by_field_name("alias");
                break;
            }
            _ => {}
        }
    }

    let module_path = source[module_node?.byte_range()].to_string();
    if module_path.is_empty() {
        return None;
    }
    let namespace_import = alias_node.map(|alias| source[alias.byte_range()].to_string());

    Some(ImportStatement {
        group: classify_group_py(&module_path),
        module_path,
        names: Vec::new(),
        default_import: None,
        namespace_import,
        kind: ImportKind::Value,
        byte_range: node.byte_range(),
        raw_text: source[node.byte_range()].to_string(),
    })
}
/// Converts a Python `from m import a, b` node into an [`ImportStatement`].
///
/// The module is the `relative_import` child, or the first `dotted_name` that appears before
/// the `import` keyword. Names are the `dotted_name` children after it, plus the original name
/// (`a`) of each `from m import a as b` alias; the alias itself is not tracked. Returns `None`
/// when no module can be found.
fn parse_py_import_from_statement(source: &str, node: &Node) -> Option<ImportStatement> {
    let mut module_path = String::new();
    let mut names = Vec::new();
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        match child.kind() {
            "dotted_name" => {
                let text = source[child.byte_range()].to_string();
                if module_path.is_empty()
                    && !has_seen_import_keyword(source, node, child.start_byte())
                {
                    module_path = text;
                } else {
                    names.push(text);
                }
            }
            "aliased_import" => {
                if let Some(name) = child.child_by_field_name("name") {
                    names.push(source[name.byte_range()].to_string());
                }
            }
            "relative_import" => module_path = source[child.byte_range()].to_string(),
            _ => {}
        }
    }

    if module_path.is_empty() {
        return None;
    }

    Some(ImportStatement {
        group: classify_group_py(&module_path),
        module_path,
        names,
        default_import: None,
        namespace_import: None,
        kind: ImportKind::Value,
        byte_range: node.byte_range(),
        raw_text: source[node.byte_range()].to_string(),
    })
}
/// Whether `parent` has an `import` keyword child that starts before byte `before_byte`.
///
/// Only children starting before `before_byte` are examined; `_source` is unused.
fn has_seen_import_keyword(_source: &str, parent: &Node, before_byte: usize) -> bool {
    let mut cursor = parent.walk();
    let seen = parent
        .children(&mut cursor)
        .take_while(|child| child.start_byte() < before_byte)
        .any(|child| child.kind() == "import");
    seen
}
/// Renders a Python import.
///
/// With no names this is `import m`. Otherwise it is `from m import a, b` with the names
/// sorted. `_default_import` is ignored.
fn generate_py_import_line(
    module_path: &str,
    names: &[String],
    _default_import: Option<&str>,
) -> String {
    if !names.is_empty() {
        let mut ordered: Vec<&str> = names.iter().map(String::as_str).collect();
        ordered.sort_unstable();
        return format!("from {module_path} import {}", ordered.join(", "));
    }
    format!("import {module_path}")
}
/// Groups a Rust `use` path by its first `::` segment.
///
/// `std`, `core` and `alloc` are `Stdlib`. `crate`, `self` and `super` are `Internal`.
/// Anything else (a dependency crate name) is `External`.
pub fn classify_group_rs(module_path: &str) -> ImportGroup {
    let root = match module_path.find("::") {
        Some(end) => &module_path[..end],
        None => module_path,
    };
    if matches!(root, "std" | "core" | "alloc") {
        ImportGroup::Stdlib
    } else if matches!(root, "crate" | "self" | "super") {
        ImportGroup::Internal
    } else {
        ImportGroup::External
    }
}
fn parse_rs_imports(source: &str, tree: &Tree) -> ImportBlock {
let root = tree.root_node();
let mut imports = Vec::new();
let mut cursor = root.walk();
if !cursor.goto_first_child() {
return ImportBlock::empty();
}
loop {
let node = cursor.node();
if node.kind() == "use_declaration" {
if let Some(imp) = parse_rs_use_declaration(source, &node) {
imports.push(imp);
}
}
if !cursor.goto_next_sibling() {
break;
}
}
let byte_range = import_byte_range(&imports);
ImportBlock {
imports,
byte_range,
}
}
/// Converts a Rust `use_declaration` node into an [`ImportStatement`].
///
/// `module_path` is the source text of the use tree. That covers a plain path, an `as` clause,
/// a glob such as `std::io::*`, or a braced list such as `serde::{A, B}`; for a braced list, the
/// top-level items are also collected into `names`. A visibility modifier (`pub use`,
/// `pub(crate) use`, …) is recorded as `default_import = Some("pub")`. Returns `None` when no
/// recognised use tree is found.
fn parse_rs_use_declaration(source: &str, node: &Node) -> Option<ImportStatement> {
    let mut has_pub = false;
    let mut use_path = String::new();
    let mut names = Vec::new();
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        match child.kind() {
            "visibility_modifier" => has_pub = true,
            // `use_wildcard` (glob imports like `use std::io::*;`) was previously not matched,
            // so glob imports vanished from the block.
            "scoped_identifier" | "identifier" | "use_as_clause" | "use_wildcard" => {
                use_path = source[child.byte_range()].to_string();
            }
            "scoped_use_list" => {
                use_path = source[child.byte_range()].to_string();
                extract_rs_use_list_names(source, &child, &mut names);
            }
            _ => {}
        }
    }

    if use_path.is_empty() {
        return None;
    }

    Some(ImportStatement {
        group: classify_group_rs(&use_path),
        module_path: use_path,
        names,
        default_import: has_pub.then(|| "pub".to_string()),
        namespace_import: None,
        kind: ImportKind::Value,
        byte_range: node.byte_range(),
        raw_text: source[node.byte_range()].to_string(),
    })
}
/// Appends the `identifier` / `scoped_identifier` items of each `use_list` child of `node` to
/// `names`.
///
/// Other items are not collected: `self`, `as` clauses, nested `{...}` lists and globs.
fn extract_rs_use_list_names(source: &str, node: &Node, names: &mut Vec<String>) {
    let mut outer = node.walk();
    for list in node
        .children(&mut outer)
        .filter(|child| child.kind() == "use_list")
    {
        let mut inner = list.walk();
        names.extend(
            list.children(&mut inner)
                .filter(|item| matches!(item.kind(), "identifier" | "scoped_identifier"))
                .map(|item| source[item.byte_range()].to_string()),
        );
    }
}
/// Renders a Rust `use` line: always `use {module_path};`.
///
/// `names` and `type_only` do not influence the output, so the complete path (for example
/// `std::collections::HashMap`) must already be in `module_path`.
/// NOTE(review): appending `names` as `::{...}` would change output for existing callers —
/// confirm the calling convention before doing so.
fn generate_rs_import_line(module_path: &str, names: &[String], _type_only: bool) -> String {
    // Names are intentionally not rendered.
    let _ = names;
    format!("use {module_path};")
}
/// Groups a Go import path.
///
/// Heuristic: a path containing a `.` (typically a domain such as `github.com/...`) is
/// `External`; every other path is `Stdlib`. No path is classified as `Internal`.
pub fn classify_group_go(module_path: &str) -> ImportGroup {
    match module_path.find('.') {
        Some(_) => ImportGroup::External,
        None => ImportGroup::Stdlib,
    }
}
/// Collects every import spec from the top-level `import_declaration` nodes, in source order.
/// Grouped declarations contribute one entry per spec.
fn parse_go_imports(source: &str, tree: &Tree) -> ImportBlock {
    let root = tree.root_node();
    let mut imports = Vec::new();
    let mut cursor = root.walk();
    for decl in root
        .children(&mut cursor)
        .filter(|child| child.kind() == "import_declaration")
    {
        parse_go_import_declaration(source, &decl, &mut imports);
    }
    let byte_range = import_byte_range(&imports);
    ImportBlock {
        imports,
        byte_range,
    }
}
/// Appends the imports of one Go `import_declaration` to `imports`.
///
/// Handles both the single form (`import "fmt"`, a direct `import_spec` child) and the grouped
/// form (`import ( ... )`, whose specs live in an `import_spec_list`). Specs that
/// `parse_go_import_spec` rejects are skipped.
fn parse_go_import_declaration(source: &str, node: &Node, imports: &mut Vec<ImportStatement>) {
    let mut record = |spec: Node<'_>| {
        if let Some(imp) = parse_go_import_spec(source, &spec) {
            imports.push(imp);
        }
    };

    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        match child.kind() {
            "import_spec" => record(child),
            "import_spec_list" => {
                let mut list_cursor = child.walk();
                for spec in child.children(&mut list_cursor) {
                    if spec.kind() == "import_spec" {
                        record(spec);
                    }
                }
            }
            _ => {}
        }
    }
}
/// Converts one Go `import_spec` node into an [`ImportStatement`].
///
/// The path comes from an interpreted (`"fmt"`) or raw (`` `fmt` ``) string literal, with its
/// delimiters trimmed. The optional name — a package alias, `_` or `.` — is stored in
/// `default_import`. Returns `None` when no path literal is found.
fn parse_go_import_spec(source: &str, node: &Node) -> Option<ImportStatement> {
    let mut import_path = String::new();
    let mut alias = None;
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        let text = &source[child.byte_range()];
        match child.kind() {
            "interpreted_string_literal" => import_path = text.trim_matches('"').to_string(),
            // Raw string import paths are legal Go and were previously dropped.
            "raw_string_literal" => import_path = text.trim_matches('`').to_string(),
            // tree-sitter-go exposes a package alias as `package_identifier`; `identifier` is
            // kept for grammar versions that do not alias it.
            "package_identifier" | "identifier" | "blank_identifier" | "dot" => {
                alias = Some(text.to_string());
            }
            _ => {}
        }
    }

    if import_path.is_empty() {
        return None;
    }

    Some(ImportStatement {
        group: classify_group_go(&import_path),
        module_path: import_path,
        names: Vec::new(),
        default_import: alias,
        namespace_import: None,
        kind: ImportKind::Value,
        byte_range: node.byte_range(),
        raw_text: source[node.byte_range()].to_string(),
    })
}
/// Public entry point for rendering one Go import.
///
/// With `in_group`, the result is a tab-indented spec meant for an `import ( ... )` block.
/// Otherwise it is a standalone `import` declaration. `alias`, when given, is placed before
/// the quoted path.
pub fn generate_go_import_line_pub(
    module_path: &str,
    alias: Option<&str>,
    in_group: bool,
) -> String {
    // Delegate so the formatting rules stay in a single private function.
    generate_go_import_line(module_path, alias, in_group)
}
/// Renders a Go import of `module_path`, optionally aliased.
///
/// Inside a grouped `import ( ... )` block the spec is tab-indented; otherwise it is a
/// standalone `import` declaration. No trailing newline is added.
fn generate_go_import_line(module_path: &str, alias: Option<&str>, in_group: bool) -> String {
    let lead = if in_group { "\t" } else { "import " };
    let quoted = format!("\"{module_path}\"");
    match alias {
        Some(name) => format!("{lead}{name} {quoted}"),
        None => format!("{lead}{quoted}"),
    }
}
/// Returns the byte span of the first `import_spec_list` (the parenthesised part of
/// `import ( ... )`) found in a top-level `import_declaration`, or `None` when the file has no
/// grouped import. `_source` is unused.
pub fn go_has_grouped_import(_source: &str, tree: &Tree) -> Option<Range<usize>> {
    let root = tree.root_node();
    let mut decl_cursor = root.walk();
    for decl in root.children(&mut decl_cursor) {
        if decl.kind() != "import_declaration" {
            continue;
        }
        let mut spec_cursor = decl.walk();
        let list = decl
            .children(&mut spec_cursor)
            .find(|child| child.kind() == "import_spec_list");
        if let Some(list) = list {
            return Some(list.byte_range());
        }
    }
    None
}
/// Advances `pos` past one line break starting exactly at `pos`.
///
/// It skips `\r\n` (2 bytes) or a lone `\n` / `\r` (1 byte). Any other byte, or a position at
/// or beyond the end of `source`, returns `pos` unchanged.
fn skip_newline(source: &str, pos: usize) -> usize {
    let rest: &[u8] = source.as_bytes().get(pos..).unwrap_or(&[]);
    if rest.starts_with(b"\r\n") {
        pos + 2
    } else if matches!(rest.first(), Some(b'\n' | b'\r')) {
        pos + 1
    } else {
        pos
    }
}
#[cfg(test)]
mod tests {
use super::*;
fn parse_ts(source: &str) -> (Tree, ImportBlock) {
let grammar = grammar_for(LangId::TypeScript);
let mut parser = Parser::new();
parser.set_language(&grammar).unwrap();
let tree = parser.parse(source, None).unwrap();
let block = parse_imports(source, &tree, LangId::TypeScript);
(tree, block)
}
fn parse_js(source: &str) -> (Tree, ImportBlock) {
let grammar = grammar_for(LangId::JavaScript);
let mut parser = Parser::new();
parser.set_language(&grammar).unwrap();
let tree = parser.parse(source, None).unwrap();
let block = parse_imports(source, &tree, LangId::JavaScript);
(tree, block)
}
#[test]
fn parse_ts_named_imports() {
let source = "import { useState, useEffect } from 'react';\n";
let (_, block) = parse_ts(source);
assert_eq!(block.imports.len(), 1);
let imp = &block.imports[0];
assert_eq!(imp.module_path, "react");
assert!(imp.names.contains(&"useState".to_string()));
assert!(imp.names.contains(&"useEffect".to_string()));
assert_eq!(imp.kind, ImportKind::Value);
assert_eq!(imp.group, ImportGroup::External);
}
#[test]
fn parse_ts_default_import() {
let source = "import React from 'react';\n";
let (_, block) = parse_ts(source);
assert_eq!(block.imports.len(), 1);
let imp = &block.imports[0];
assert_eq!(imp.default_import.as_deref(), Some("React"));
assert_eq!(imp.kind, ImportKind::Value);
}
#[test]
fn parse_ts_side_effect_import() {
let source = "import './styles.css';\n";
let (_, block) = parse_ts(source);
assert_eq!(block.imports.len(), 1);
assert_eq!(block.imports[0].kind, ImportKind::SideEffect);
assert_eq!(block.imports[0].module_path, "./styles.css");
}
#[test]
fn parse_ts_relative_import() {
let source = "import { helper } from './utils';\n";
let (_, block) = parse_ts(source);
assert_eq!(block.imports.len(), 1);
assert_eq!(block.imports[0].group, ImportGroup::Internal);
}
#[test]
fn parse_ts_multiple_groups() {
let source = "\
import React from 'react';
import { useState } from 'react';
import { helper } from './utils';
import { Config } from '../config';
";
let (_, block) = parse_ts(source);
assert_eq!(block.imports.len(), 4);
let external: Vec<_> = block
.imports
.iter()
.filter(|i| i.group == ImportGroup::External)
.collect();
let relative: Vec<_> = block
.imports
.iter()
.filter(|i| i.group == ImportGroup::Internal)
.collect();
assert_eq!(external.len(), 2);
assert_eq!(relative.len(), 2);
}
#[test]
fn parse_ts_namespace_import() {
let source = "import * as path from 'path';\n";
let (_, block) = parse_ts(source);
assert_eq!(block.imports.len(), 1);
let imp = &block.imports[0];
assert_eq!(imp.namespace_import.as_deref(), Some("path"));
assert_eq!(imp.kind, ImportKind::Value);
}
#[test]
fn parse_js_imports() {
let source = "import { readFile } from 'fs';\nimport { helper } from './helper';\n";
let (_, block) = parse_js(source);
assert_eq!(block.imports.len(), 2);
assert_eq!(block.imports[0].group, ImportGroup::External);
assert_eq!(block.imports[1].group, ImportGroup::Internal);
}
#[test]
fn classify_external() {
assert_eq!(classify_group_ts("react"), ImportGroup::External);
assert_eq!(classify_group_ts("@scope/pkg"), ImportGroup::External);
assert_eq!(classify_group_ts("lodash/map"), ImportGroup::External);
}
#[test]
fn classify_relative() {
assert_eq!(classify_group_ts("./utils"), ImportGroup::Internal);
assert_eq!(classify_group_ts("../config"), ImportGroup::Internal);
assert_eq!(classify_group_ts("./"), ImportGroup::Internal);
}
#[test]
fn dedup_detects_same_named_import() {
let source = "import { useState } from 'react';\n";
let (_, block) = parse_ts(source);
assert!(is_duplicate(
&block,
"react",
&["useState".to_string()],
None,
false
));
}
#[test]
fn dedup_misses_different_name() {
let source = "import { useState } from 'react';\n";
let (_, block) = parse_ts(source);
assert!(!is_duplicate(
&block,
"react",
&["useEffect".to_string()],
None,
false
));
}
#[test]
fn dedup_detects_default_import() {
let source = "import React from 'react';\n";
let (_, block) = parse_ts(source);
assert!(is_duplicate(&block, "react", &[], Some("React"), false));
}
#[test]
fn dedup_side_effect() {
let source = "import './styles.css';\n";
let (_, block) = parse_ts(source);
assert!(is_duplicate(&block, "./styles.css", &[], None, false));
}
#[test]
fn dedup_type_vs_value() {
let source = "import { FC } from 'react';\n";
let (_, block) = parse_ts(source);
assert!(!is_duplicate(
&block,
"react",
&["FC".to_string()],
None,
true
));
}
#[test]
fn generate_named_import() {
let line = generate_import_line(
LangId::TypeScript,
"react",
&["useState".to_string(), "useEffect".to_string()],
None,
false,
);
assert_eq!(line, "import { useEffect, useState } from 'react';");
}
#[test]
fn generate_default_import() {
let line = generate_import_line(LangId::TypeScript, "react", &[], Some("React"), false);
assert_eq!(line, "import React from 'react';");
}
#[test]
fn generate_type_import() {
let line =
generate_import_line(LangId::TypeScript, "react", &["FC".to_string()], None, true);
assert_eq!(line, "import type { FC } from 'react';");
}
#[test]
fn generate_side_effect_import() {
let line = generate_import_line(LangId::TypeScript, "./styles.css", &[], None, false);
assert_eq!(line, "import './styles.css';");
}
#[test]
fn generate_default_and_named() {
let line = generate_import_line(
LangId::TypeScript,
"react",
&["useState".to_string()],
Some("React"),
false,
);
assert_eq!(line, "import React, { useState } from 'react';");
}
#[test]
fn parse_ts_type_import() {
let source = "import type { FC } from 'react';\n";
let (_, block) = parse_ts(source);
assert_eq!(block.imports.len(), 1);
let imp = &block.imports[0];
assert_eq!(imp.kind, ImportKind::Type);
assert!(imp.names.contains(&"FC".to_string()));
assert_eq!(imp.group, ImportGroup::External);
}
#[test]
fn insertion_empty_file() {
let source = "";
let (_, block) = parse_ts(source);
let (offset, _, _) =
find_insertion_point(source, &block, ImportGroup::External, "react", false);
assert_eq!(offset, 0);
}
#[test]
fn insertion_alphabetical_within_group() {
let source = "\
import { a } from 'alpha';
import { c } from 'charlie';
";
let (_, block) = parse_ts(source);
let (offset, _, _) =
find_insertion_point(source, &block, ImportGroup::External, "bravo", false);
let before_charlie = source.find("import { c }").unwrap();
assert_eq!(offset, before_charlie);
}
fn parse_py(source: &str) -> (Tree, ImportBlock) {
let grammar = grammar_for(LangId::Python);
let mut parser = Parser::new();
parser.set_language(&grammar).unwrap();
let tree = parser.parse(source, None).unwrap();
let block = parse_imports(source, &tree, LangId::Python);
(tree, block)
}
#[test]
fn parse_py_import_statement() {
let source = "import os\nimport sys\n";
let (_, block) = parse_py(source);
assert_eq!(block.imports.len(), 2);
assert_eq!(block.imports[0].module_path, "os");
assert_eq!(block.imports[1].module_path, "sys");
assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
}
#[test]
fn parse_py_from_import() {
let source = "from collections import OrderedDict\nfrom typing import List, Optional\n";
let (_, block) = parse_py(source);
assert_eq!(block.imports.len(), 2);
assert_eq!(block.imports[0].module_path, "collections");
assert!(block.imports[0].names.contains(&"OrderedDict".to_string()));
assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
assert_eq!(block.imports[1].module_path, "typing");
assert!(block.imports[1].names.contains(&"List".to_string()));
assert!(block.imports[1].names.contains(&"Optional".to_string()));
}
#[test]
fn parse_py_relative_import() {
let source = "from . import utils\nfrom ..config import Settings\n";
let (_, block) = parse_py(source);
assert_eq!(block.imports.len(), 2);
assert_eq!(block.imports[0].module_path, ".");
assert!(block.imports[0].names.contains(&"utils".to_string()));
assert_eq!(block.imports[0].group, ImportGroup::Internal);
assert_eq!(block.imports[1].module_path, "..config");
assert_eq!(block.imports[1].group, ImportGroup::Internal);
}
#[test]
fn classify_py_groups() {
assert_eq!(classify_group_py("os"), ImportGroup::Stdlib);
assert_eq!(classify_group_py("sys"), ImportGroup::Stdlib);
assert_eq!(classify_group_py("json"), ImportGroup::Stdlib);
assert_eq!(classify_group_py("collections"), ImportGroup::Stdlib);
assert_eq!(classify_group_py("os.path"), ImportGroup::Stdlib);
assert_eq!(classify_group_py("requests"), ImportGroup::External);
assert_eq!(classify_group_py("flask"), ImportGroup::External);
assert_eq!(classify_group_py("."), ImportGroup::Internal);
assert_eq!(classify_group_py("..config"), ImportGroup::Internal);
assert_eq!(classify_group_py(".utils"), ImportGroup::Internal);
}
#[test]
fn parse_py_three_groups() {
let source = "import os\nimport sys\n\nimport requests\n\nfrom . import utils\n";
let (_, block) = parse_py(source);
let stdlib: Vec<_> = block
.imports
.iter()
.filter(|i| i.group == ImportGroup::Stdlib)
.collect();
let external: Vec<_> = block
.imports
.iter()
.filter(|i| i.group == ImportGroup::External)
.collect();
let internal: Vec<_> = block
.imports
.iter()
.filter(|i| i.group == ImportGroup::Internal)
.collect();
assert_eq!(stdlib.len(), 2);
assert_eq!(external.len(), 1);
assert_eq!(internal.len(), 1);
}
#[test]
fn generate_py_import() {
let line = generate_import_line(LangId::Python, "os", &[], None, false);
assert_eq!(line, "import os");
}
#[test]
fn generate_py_from_import() {
let line = generate_import_line(
LangId::Python,
"collections",
&["OrderedDict".to_string()],
None,
false,
);
assert_eq!(line, "from collections import OrderedDict");
}
#[test]
fn generate_py_from_import_multiple() {
let line = generate_import_line(
LangId::Python,
"typing",
&["Optional".to_string(), "List".to_string()],
None,
false,
);
assert_eq!(line, "from typing import List, Optional");
}
/// Parses `source` as Rust with tree-sitter and extracts its import block.
///
/// Returns the syntax tree alongside the block so callers may inspect it further.
///
/// # Panics
///
/// Panics if the Rust grammar cannot be installed into the parser, or if
/// `Parser::parse` returns no tree. Both are test-setup failures, not
/// import-parsing failures, so the messages say which step went wrong.
fn parse_rust(source: &str) -> (Tree, ImportBlock) {
    let grammar = grammar_for(LangId::Rust);
    let mut parser = Parser::new();
    parser
        .set_language(&grammar)
        .expect("Rust grammar should be loadable into the tree-sitter parser");
    let tree = parser
        .parse(source, None)
        .expect("tree-sitter should return a tree once a language is set");
    let block = parse_imports(source, &tree, LangId::Rust);
    (tree, block)
}
/// `use std::...` declarations are parsed in order and grouped as stdlib.
#[test]
fn parse_rs_use_std() {
    let (_, block) = parse_rust("use std::collections::HashMap;\nuse std::io::Read;\n");
    let imports = &block.imports;
    assert_eq!(imports.len(), 2);
    let first = &imports[0];
    assert_eq!(first.module_path, "std::collections::HashMap");
    assert_eq!(first.group, ImportGroup::Stdlib);
    assert_eq!(imports[1].group, ImportGroup::Stdlib);
}
/// A braced `use` of a third-party crate becomes one external import listing every name.
#[test]
fn parse_rs_use_external() {
    let (_, block) = parse_rust("use serde::{Deserialize, Serialize};\n");
    assert_eq!(block.imports.len(), 1);
    let import = &block.imports[0];
    assert_eq!(import.group, ImportGroup::External);
    assert!(import.names.iter().any(|name| name == "Deserialize"));
    assert!(import.names.iter().any(|name| name == "Serialize"));
}
/// Paths rooted at `crate` or `super` are grouped as internal.
#[test]
fn parse_rs_use_crate() {
    let (_, block) = parse_rust("use crate::config::Settings;\nuse super::parent::Thing;\n");
    assert_eq!(block.imports.len(), 2);
    for import in &block.imports {
        assert_eq!(import.group, ImportGroup::Internal);
    }
}
/// A `pub use` is still parsed as one import; per this test's expectation the
/// `pub` visibility is surfaced through `default_import`.
#[test]
fn parse_rs_pub_use() {
    let (_, block) = parse_rust("pub use super::parent::Thing;\n");
    assert_eq!(block.imports.len(), 1);
    let visibility = block.imports[0].default_import.as_deref();
    assert_eq!(visibility, Some("pub"));
}
/// Rust paths are bucketed by their first segment: std/core/alloc, external crates,
/// or crate-relative (`crate`, `self`, `super`).
#[test]
fn classify_rs_groups() {
    let cases = [
        ("std::collections::HashMap", ImportGroup::Stdlib),
        ("core::mem", ImportGroup::Stdlib),
        ("alloc::vec", ImportGroup::Stdlib),
        ("serde::Deserialize", ImportGroup::External),
        ("tokio::runtime", ImportGroup::External),
        ("crate::config", ImportGroup::Internal),
        ("self::utils", ImportGroup::Internal),
        ("super::parent", ImportGroup::Internal),
    ];
    for &(path, expected) in &cases {
        assert_eq!(classify_group_rs(path), expected);
    }
}
/// With no names, Rust generation yields a `use` of the full path terminated by `;`.
#[test]
fn generate_rs_use() {
    assert_eq!(
        generate_import_line(LangId::Rust, "std::fmt::Display", &[], None, false),
        "use std::fmt::Display;"
    );
}
/// Parses `source` as Go with tree-sitter and extracts its import block.
///
/// Returns the syntax tree alongside the block so callers may inspect it further.
///
/// # Panics
///
/// Panics if the Go grammar cannot be installed into the parser, or if
/// `Parser::parse` returns no tree. Both are test-setup failures, not
/// import-parsing failures, so the messages say which step went wrong.
fn parse_go(source: &str) -> (Tree, ImportBlock) {
    let grammar = grammar_for(LangId::Go);
    let mut parser = Parser::new();
    parser
        .set_language(&grammar)
        .expect("Go grammar should be loadable into the tree-sitter parser");
    let tree = parser
        .parse(source, None)
        .expect("tree-sitter should return a tree once a language is set");
    let block = parse_imports(source, &tree, LangId::Go);
    (tree, block)
}
/// A lone `import "fmt"` yields one stdlib import with the quotes stripped from the path.
#[test]
fn parse_go_single_import() {
    let (_, block) = parse_go("package main\n\nimport \"fmt\"\n");
    assert_eq!(block.imports.len(), 1);
    let only = &block.imports[0];
    assert_eq!(only.module_path, "fmt");
    assert_eq!(only.group, ImportGroup::Stdlib);
}
/// A parenthesised import group yields one import per spec, in source order,
/// each classified independently.
#[test]
fn parse_go_grouped_import() {
    let source =
        "package main\n\nimport (\n\t\"fmt\"\n\t\"os\"\n\n\t\"github.com/pkg/errors\"\n)\n";
    let (_, block) = parse_go(source);
    let expected = [
        ("fmt", ImportGroup::Stdlib),
        ("os", ImportGroup::Stdlib),
        ("github.com/pkg/errors", ImportGroup::External),
    ];
    assert_eq!(block.imports.len(), expected.len());
    for (import, &(path, group)) in block.imports.iter().zip(expected.iter()) {
        assert_eq!(import.module_path, path);
        assert_eq!(import.group, group);
    }
}
/// A standalone import followed by a grouped import declaration yields all three
/// imports, each with its own path and group.
#[test]
fn parse_go_mixed_imports() {
    let source = "package main\n\nimport \"fmt\"\n\nimport (\n\t\"os\"\n\t\"github.com/pkg/errors\"\n)\n";
    let (_, block) = parse_go(source);
    assert_eq!(block.imports.len(), 3);
    // A bare count would also accept e.g. three copies of `fmt`; look each path up
    // explicitly (order-independent) and check its classification.
    let group_of = |path: &str| {
        block
            .imports
            .iter()
            .find(|import| import.module_path == path)
            .map(|import| import.group)
    };
    assert_eq!(group_of("fmt"), Some(ImportGroup::Stdlib));
    assert_eq!(group_of("os"), Some(ImportGroup::Stdlib));
    assert_eq!(group_of("github.com/pkg/errors"), Some(ImportGroup::External));
}
/// Go paths without a domain-like first segment are stdlib; hosted module paths are external.
#[test]
fn classify_go_groups() {
    let cases = [
        ("fmt", ImportGroup::Stdlib),
        ("os", ImportGroup::Stdlib),
        ("net/http", ImportGroup::Stdlib),
        ("encoding/json", ImportGroup::Stdlib),
        ("github.com/pkg/errors", ImportGroup::External),
        ("golang.org/x/tools", ImportGroup::External),
    ];
    for &(path, expected) in &cases {
        assert_eq!(classify_group_go(path), expected);
    }
}
/// Outside an import group, the generated line is a full `import "path"` declaration.
#[test]
fn generate_go_standalone() {
    assert_eq!(generate_go_import_line("fmt", None, false), "import \"fmt\"");
}
/// Inside an import group, only a tab-indented quoted path is generated.
#[test]
fn generate_go_grouped_spec() {
    assert_eq!(generate_go_import_line("fmt", None, true), "\t\"fmt\"");
}
/// An alias is placed between `import` and the quoted path.
#[test]
fn generate_go_with_alias() {
    assert_eq!(
        generate_go_import_line("github.com/pkg/errors", Some("errs"), false),
        "import errs \"github.com/pkg/errors\""
    );
}
}