use std::ops::Range;
use tree_sitter::{Node, Parser, Tree};
use crate::parser::{grammar_for, LangId};
mod c;
pub(crate) use c::{classify_group_c_import_kind, normalize_include_module};
mod csharp;
mod java;
mod kotlin;
mod lua;
mod perl;
mod php;
pub(crate) use php::{php_grouped_use_matches_module, php_grouped_use_shares_prefix};
mod ruby;
mod scala;
pub(crate) use scala::scala_block_uses_scala2_dialect;
mod swift;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ImportKind {
Value,
Type,
SideEffect,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ImportGroup {
Stdlib,
External,
Internal,
}
impl ImportGroup {
pub fn label(&self) -> &'static str {
match self {
ImportGroup::Stdlib => "stdlib",
ImportGroup::External => "external",
ImportGroup::Internal => "internal",
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ImportForm {
Es {
default_import: Option<String>,
namespace_import: Option<String>,
named: Vec<String>,
type_only: bool,
side_effect: bool,
},
Python {
from_import: bool,
named: Vec<String>,
},
RustUse {
visibility: Option<String>,
named: Vec<String>,
},
Go { alias: Option<String> },
Solidity {
named: Vec<String>,
namespace: Option<String>,
alias: Option<String>,
},
Structured {
named: Vec<String>,
namespace: Option<String>,
alias: Option<String>,
modifiers: Vec<String>,
import_kind: Option<String>,
},
}
#[derive(Debug, Clone)]
pub struct ImportRequest<'a> {
pub module_path: &'a str,
pub names: &'a [String],
pub default_import: Option<&'a str>,
pub namespace: Option<&'a str>,
pub alias: Option<&'a str>,
pub type_only: bool,
pub modifiers: &'a [String],
pub import_kind: Option<&'a str>,
}
const NO_MODIFIERS: &[String] = &[];
impl<'a> ImportRequest<'a> {
pub fn legacy(
module_path: &'a str,
names: &'a [String],
default_import: Option<&'a str>,
namespace: Option<&'a str>,
type_only: bool,
) -> Self {
ImportRequest {
module_path,
names,
default_import,
namespace,
alias: None,
type_only,
modifiers: NO_MODIFIERS,
import_kind: None,
}
}
}
#[derive(Debug, Clone)]
pub struct ImportStatement {
pub module_path: String,
pub names: Vec<String>,
pub default_import: Option<String>,
pub namespace_import: Option<String>,
pub kind: ImportKind,
pub group: ImportGroup,
pub byte_range: Range<usize>,
pub raw_text: String,
pub form: ImportForm,
}
#[derive(Debug, Clone)]
pub struct ImportBlock {
pub imports: Vec<ImportStatement>,
pub byte_range: Option<Range<usize>>,
}
impl ImportBlock {
pub fn empty() -> Self {
ImportBlock {
imports: Vec::new(),
byte_range: None,
}
}
}
pub(crate) fn import_byte_range(imports: &[ImportStatement]) -> Option<Range<usize>> {
imports.first().zip(imports.last()).map(|(first, last)| {
let start = first.byte_range.start;
let end = last.byte_range.end;
start..end
})
}
pub fn specifier_local_name(spec: &str) -> &str {
let trimmed = spec.trim();
let after_type = trimmed
.strip_prefix("type ")
.unwrap_or(trimmed)
.trim_start();
if let Some(idx) = after_type.find(" as ") {
after_type[idx + 4..].trim()
} else {
after_type
}
}
pub fn specifier_imported_name(spec: &str) -> &str {
let trimmed = spec.trim();
let after_type = trimmed
.strip_prefix("type ")
.unwrap_or(trimmed)
.trim_start();
after_type
.find(" as ")
.map(|idx| after_type[..idx].trim())
.unwrap_or(after_type)
}
pub fn specifier_matches(spec: &str, target: &str) -> bool {
specifier_imported_name(spec) == target || specifier_local_name(spec) == target
}
pub trait ImportSyntax: Sync {
fn parse(&self, source: &str, tree: &Tree) -> ImportBlock;
fn generate_line(&self, req: &ImportRequest) -> String;
fn classify_group(&self, module_path: &str) -> ImportGroup;
}
struct EsSyntax;
impl ImportSyntax for EsSyntax {
fn parse(&self, source: &str, tree: &Tree) -> ImportBlock {
parse_ts_imports(source, tree)
}
fn generate_line(&self, req: &ImportRequest) -> String {
generate_ts_import_line(
req.module_path,
req.names,
req.default_import,
req.namespace,
req.type_only,
)
}
fn classify_group(&self, module_path: &str) -> ImportGroup {
classify_group_ts(module_path)
}
}
struct PythonSyntax;
impl ImportSyntax for PythonSyntax {
fn parse(&self, source: &str, tree: &Tree) -> ImportBlock {
parse_py_imports(source, tree)
}
fn generate_line(&self, req: &ImportRequest) -> String {
generate_py_import_line(req.module_path, req.names, req.default_import)
}
fn classify_group(&self, module_path: &str) -> ImportGroup {
classify_group_py(module_path)
}
}
struct RustSyntax;
impl ImportSyntax for RustSyntax {
fn parse(&self, source: &str, tree: &Tree) -> ImportBlock {
parse_rs_imports(source, tree)
}
fn generate_line(&self, req: &ImportRequest) -> String {
generate_rs_import_line(req.module_path, req.names, req.type_only)
}
fn classify_group(&self, module_path: &str) -> ImportGroup {
classify_group_rs(module_path)
}
}
struct GoSyntax;
impl ImportSyntax for GoSyntax {
fn parse(&self, source: &str, tree: &Tree) -> ImportBlock {
parse_go_imports(source, tree)
}
fn generate_line(&self, req: &ImportRequest) -> String {
generate_go_import_line(req.module_path, req.default_import, false)
}
fn classify_group(&self, module_path: &str) -> ImportGroup {
classify_group_go(module_path)
}
}
struct SoliditySyntax;
impl ImportSyntax for SoliditySyntax {
fn parse(&self, source: &str, tree: &Tree) -> ImportBlock {
parse_solidity_imports(source, tree)
}
fn generate_line(&self, req: &ImportRequest) -> String {
generate_solidity_import_line(req)
}
fn classify_group(&self, module_path: &str) -> ImportGroup {
classify_group_solidity(module_path)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum VueScriptRangeError {
MissingScript,
MultipleScripts,
}
impl VueScriptRangeError {
pub(crate) fn code(self) -> &'static str {
match self {
VueScriptRangeError::MissingScript => "missing_vue_script",
VueScriptRangeError::MultipleScripts => "ambiguous_vue_script",
}
}
pub(crate) fn message(self, command: &str) -> String {
match self {
VueScriptRangeError::MissingScript => format!(
"{command}: Vue import management requires exactly one <script> block; found none"
),
VueScriptRangeError::MultipleScripts => format!(
"{command}: Vue import management requires exactly one <script> block; found multiple"
),
}
}
}
pub(crate) fn vue_single_script_content_range(
tree: &Tree,
) -> Result<(usize, usize), VueScriptRangeError> {
let root = tree.root_node();
let mut ranges = Vec::new();
let mut cursor = root.walk();
for child in root.named_children(&mut cursor) {
if child.kind() == "script_element" {
ranges.push(vue_script_element_content_range(&child));
}
}
match ranges.len() {
0 => Err(VueScriptRangeError::MissingScript),
1 => Ok(ranges[0]),
_ => Err(VueScriptRangeError::MultipleScripts),
}
}
pub(crate) fn vue_script_content_range(tree: &Tree) -> Option<(usize, usize)> {
vue_single_script_content_range(tree).ok()
}
fn vue_script_element_content_range(child: &Node) -> (usize, usize) {
let mut inner = child.walk();
for sub in child.named_children(&mut inner) {
if sub.kind() == "raw_text" {
return (sub.start_byte(), sub.end_byte());
}
}
let mut inner2 = child.walk();
for sub in child.named_children(&mut inner2) {
if sub.kind() == "start_tag" {
return (sub.end_byte(), sub.end_byte());
}
}
(child.end_byte(), child.end_byte())
}
fn parse_vue_imports(source: &str, tree: &Tree) -> ImportBlock {
let Ok((start, end)) = vue_single_script_content_range(tree) else {
return ImportBlock {
imports: Vec::new(),
byte_range: None,
};
};
let inner = &source[start..end];
let mut parser = Parser::new();
if parser
.set_language(&grammar_for(LangId::TypeScript))
.is_err()
{
return ImportBlock {
imports: Vec::new(),
byte_range: None,
};
}
let Some(inner_tree) = parser.parse(inner, None) else {
return ImportBlock {
imports: Vec::new(),
byte_range: None,
};
};
let mut block = parse_ts_imports(inner, &inner_tree);
for imp in &mut block.imports {
imp.byte_range = (imp.byte_range.start + start)..(imp.byte_range.end + start);
}
block.byte_range = block.byte_range.map(|r| (r.start + start)..(r.end + start));
block
}
struct VueSyntax;
impl ImportSyntax for VueSyntax {
fn parse(&self, source: &str, tree: &Tree) -> ImportBlock {
parse_vue_imports(source, tree)
}
fn generate_line(&self, req: &ImportRequest) -> String {
generate_ts_import_line(
req.module_path,
req.names,
req.default_import,
req.namespace,
req.type_only,
)
}
fn classify_group(&self, module_path: &str) -> ImportGroup {
classify_group_ts(module_path)
}
}
static ES_SYNTAX: EsSyntax = EsSyntax;
static PYTHON_SYNTAX: PythonSyntax = PythonSyntax;
static RUST_SYNTAX: RustSyntax = RustSyntax;
static GO_SYNTAX: GoSyntax = GoSyntax;
static SOLIDITY_SYNTAX: SoliditySyntax = SoliditySyntax;
static VUE_SYNTAX: VueSyntax = VueSyntax;
pub fn syntax_for(lang: LangId) -> Option<&'static dyn ImportSyntax> {
match lang {
LangId::TypeScript | LangId::Tsx | LangId::JavaScript => Some(&ES_SYNTAX),
LangId::Python => Some(&PYTHON_SYNTAX),
LangId::Rust => Some(&RUST_SYNTAX),
LangId::Go => Some(&GO_SYNTAX),
LangId::Solidity => Some(&SOLIDITY_SYNTAX),
LangId::Vue => Some(&VUE_SYNTAX),
LangId::C => Some(&c::C_SYNTAX),
LangId::Cpp => Some(&c::C_SYNTAX),
LangId::Java => Some(&java::JAVA_SYNTAX),
LangId::Kotlin => Some(&kotlin::KOTLIN_SYNTAX),
LangId::Lua => Some(&lua::LUA_SYNTAX),
LangId::CSharp => Some(&csharp::CSHARP_SYNTAX),
LangId::Php => Some(&php::PHP_SYNTAX),
LangId::Perl => Some(&perl::PERL_SYNTAX),
LangId::Ruby => Some(&ruby::RUBY_SYNTAX),
LangId::Scala => Some(&scala::SCALA_SYNTAX),
LangId::Swift => Some(&swift::SWIFT_SYNTAX),
LangId::Zig
| LangId::Bash
| LangId::Scss
| LangId::Json
| LangId::Html
| LangId::Markdown
| LangId::Yaml => None,
}
}
pub fn parse_imports(source: &str, tree: &Tree, lang: LangId) -> ImportBlock {
match syntax_for(lang) {
Some(engine) => engine.parse(source, tree),
None => ImportBlock::empty(),
}
}
pub fn is_duplicate(
block: &ImportBlock,
module_path: &str,
names: &[String],
default_import: Option<&str>,
type_only: bool,
) -> bool {
is_duplicate_with_namespace(block, module_path, names, default_import, None, type_only)
}
pub fn is_duplicate_with_namespace(
block: &ImportBlock,
module_path: &str,
names: &[String],
default_import: Option<&str>,
namespace_import: Option<&str>,
type_only: bool,
) -> bool {
let target_kind = if type_only {
ImportKind::Type
} else {
ImportKind::Value
};
for imp in &block.imports {
if imp.module_path != module_path {
continue;
}
if names.is_empty()
&& default_import.is_none()
&& namespace_import.is_none()
&& imp.names.is_empty()
&& imp.default_import.is_none()
&& imp.namespace_import.is_none()
{
return true;
}
if names.is_empty()
&& default_import.is_none()
&& namespace_import.is_none()
&& imp.kind == ImportKind::SideEffect
{
return true;
}
if imp.kind != target_kind && imp.kind != ImportKind::SideEffect {
continue;
}
if let (Some(def), Some(namespace)) = (default_import, namespace_import) {
if imp.default_import.as_deref() == Some(def)
&& imp.namespace_import.as_deref() == Some(namespace)
&& names
.iter()
.all(|n| imp.names.iter().any(|stored| specifier_matches(stored, n)))
{
return true;
}
continue;
}
if names.is_empty()
&& default_import.is_none()
&& namespace_import.is_some()
&& imp.namespace_import.as_deref() == namespace_import
{
return true;
}
if let Some(def) = default_import {
if namespace_import.is_none() && imp.default_import.as_deref() == Some(def) {
return true;
}
}
if !names.is_empty()
&& names
.iter()
.all(|n| imp.names.iter().any(|stored| specifier_matches(stored, n)))
{
return true;
}
}
false
}
pub(crate) fn is_duplicate_import_request(
lang: LangId,
block: &ImportBlock,
req: &ImportRequest<'_>,
) -> bool {
if !uses_form_aware_dedup(lang) {
return is_duplicate_with_namespace(
block,
req.module_path,
req.names,
req.default_import,
req.namespace,
req.type_only,
);
}
let target = request_dedup_key(lang, req);
block
.imports
.iter()
.map(|imp| statement_dedup_key(lang, imp))
.any(|key| key == target)
}
fn uses_form_aware_dedup(lang: LangId) -> bool {
matches!(
lang,
LangId::Solidity
| LangId::C
| LangId::Cpp
| LangId::Java
| LangId::CSharp
| LangId::Php
| LangId::Kotlin
| LangId::Scala
| LangId::Swift
| LangId::Ruby
| LangId::Lua
| LangId::Perl
)
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct ImportDedupKey {
module_path: String,
kind: ImportKind,
form: ImportForm,
}
fn statement_dedup_key(lang: LangId, imp: &ImportStatement) -> ImportDedupKey {
canonical_dedup_key(
lang,
ImportDedupKey {
module_path: imp.module_path.clone(),
kind: imp.kind,
form: imp.form.clone(),
},
)
}
fn request_dedup_key(lang: LangId, req: &ImportRequest<'_>) -> ImportDedupKey {
let key = match lang {
LangId::Solidity => {
let kind = if req.names.is_empty() && req.namespace.is_none() && req.alias.is_none() {
ImportKind::SideEffect
} else {
ImportKind::Value
};
ImportDedupKey {
module_path: req.module_path.to_string(),
kind,
form: ImportForm::Solidity {
named: req.names.to_vec(),
namespace: req.namespace.map(str::to_string),
alias: req.alias.map(str::to_string),
},
}
}
LangId::C | LangId::Cpp => structured_dedup_key(
req.module_path,
ImportKind::SideEffect,
&[],
None,
None,
&[],
Some(req.import_kind.or(req.default_import).unwrap_or("system")),
),
LangId::Java => {
let (mut module_path, modifiers) = wildcard_suffix_request(
req.module_path,
req.modifiers,
req.default_import == Some("*"),
);
let mut names = req.names.to_vec();
normalize_java_static_member_key(&mut module_path, &modifiers, &mut names);
structured_dedup_key(
&module_path,
ImportKind::Value,
&names,
None,
None,
&modifiers,
None,
)
}
LangId::CSharp => structured_dedup_key(
req.module_path,
ImportKind::Value,
&[],
None,
req.alias,
req.modifiers,
None,
),
LangId::Php => structured_dedup_key(
req.module_path,
ImportKind::Value,
&[],
None,
req.alias,
req.modifiers,
req.import_kind,
),
LangId::Kotlin => {
let wildcard = req.default_import == Some("*") || req.module_path.ends_with(".*");
let (module_path, modifiers) =
wildcard_suffix_request(req.module_path, req.modifiers, wildcard);
let alias = req
.alias
.or(req.default_import.filter(|value| *value != "*"));
structured_dedup_key(
&module_path,
ImportKind::Value,
&[],
None,
alias,
&modifiers,
None,
)
}
LangId::Scala => scala_request_dedup_key(req),
LangId::Swift => structured_dedup_key(
req.module_path,
ImportKind::Value,
&[],
None,
None,
req.modifiers,
req.import_kind,
),
LangId::Ruby => {
let mut modifiers = req.modifiers.to_vec();
if !modifiers
.iter()
.any(|modifier| modifier == "quote:single" || modifier == "quote:double")
{
modifiers.push("quote:single".to_string());
}
structured_dedup_key(
req.module_path,
ImportKind::SideEffect,
&[],
None,
None,
&modifiers,
Some(req.import_kind.unwrap_or("require")),
)
}
LangId::Lua => {
let alias = req.default_import.or(req.alias);
let kind = if alias.is_some() {
ImportKind::Value
} else {
ImportKind::SideEffect
};
structured_dedup_key(req.module_path, kind, &[], None, alias, req.modifiers, None)
}
LangId::Perl => structured_dedup_key(
req.module_path,
ImportKind::SideEffect,
&[],
None,
None,
req.modifiers,
Some(req.import_kind.unwrap_or("use")),
),
_ => structured_dedup_key(
req.module_path,
if req.type_only {
ImportKind::Type
} else {
ImportKind::Value
},
req.names,
req.namespace,
req.alias,
req.modifiers,
req.import_kind,
),
};
canonical_dedup_key(lang, key)
}
fn structured_dedup_key(
module_path: &str,
kind: ImportKind,
named: &[String],
namespace: Option<&str>,
alias: Option<&str>,
modifiers: &[String],
import_kind: Option<&str>,
) -> ImportDedupKey {
ImportDedupKey {
module_path: module_path.to_string(),
kind,
form: ImportForm::Structured {
named: named.to_vec(),
namespace: namespace.map(str::to_string),
alias: alias.map(str::to_string),
modifiers: modifiers.to_vec(),
import_kind: import_kind.map(str::to_string),
},
}
}
fn wildcard_suffix_request(
module_path: &str,
modifiers: &[String],
wildcard: bool,
) -> (String, Vec<String>) {
let stripped = module_path.strip_suffix(".*").unwrap_or(module_path);
let mut modifiers = modifiers.to_vec();
if (wildcard || stripped.len() != module_path.len())
&& !modifiers.iter().any(|modifier| modifier == "wildcard")
{
modifiers.push("wildcard".to_string());
}
(stripped.to_string(), modifiers)
}
fn normalize_java_static_member_key(
module_path: &mut String,
modifiers: &[String],
names: &mut Vec<String>,
) {
let is_static = modifiers.iter().any(|modifier| modifier == "static");
let is_wildcard = modifiers.iter().any(|modifier| modifier == "wildcard");
if !is_static || is_wildcard || !names.is_empty() {
return;
}
if let Some((prefix, member)) = module_path.rsplit_once('.') {
if !prefix.is_empty() && !member.is_empty() {
names.push(member.to_string());
*module_path = prefix.to_string();
}
}
}
fn scala_request_dedup_key(req: &ImportRequest<'_>) -> ImportDedupKey {
let mut module_path = req.module_path.to_string();
let mut names: Vec<String> = req
.names
.iter()
.map(|name| normalize_scala_selector_for_dedup(name))
.collect();
let mut modifiers = req.modifiers.to_vec();
let mut import_kind = req.import_kind.map(str::to_string);
if req.default_import == Some("given") || module_path.ends_with(".given") {
import_kind.get_or_insert_with(|| "given".to_string());
if let Some(stripped) = module_path.strip_suffix(".given") {
module_path = stripped.to_string();
}
}
if matches!(req.default_import, Some("*") | Some("_"))
|| matches!(req.namespace, Some("*") | Some("_"))
|| module_path.ends_with(".*")
|| module_path.ends_with("._")
{
if !modifiers.iter().any(|modifier| modifier == "wildcard") {
modifiers.push("wildcard".to_string());
}
module_path = module_path
.strip_suffix(".*")
.or_else(|| module_path.strip_suffix("._"))
.unwrap_or(&module_path)
.to_string();
}
if names.is_empty() {
if let Some(alias) = req.alias.filter(|alias| !alias.is_empty()) {
if let Some((prefix, leaf)) = module_path.rsplit_once('.') {
names.push(format!("{leaf} as {alias}"));
module_path = prefix.to_string();
}
}
}
structured_dedup_key(
&module_path,
ImportKind::Value,
&names,
None,
None,
&modifiers,
import_kind.as_deref(),
)
}
fn normalize_scala_selector_for_dedup(name: &str) -> String {
let trimmed = name.trim();
if let Some((from, to)) = trimmed.split_once("=>") {
format!("{} as {}", from.trim(), to.trim())
} else {
trimmed.to_string()
}
}
fn canonical_dedup_key(lang: LangId, mut key: ImportDedupKey) -> ImportDedupKey {
match &mut key.form {
ImportForm::Structured { named, .. } | ImportForm::Solidity { named, .. } => {
sort_named_specifiers(named);
}
ImportForm::Es { named, .. } | ImportForm::Python { named, .. } => {
sort_named_specifiers(named);
}
ImportForm::RustUse { named, .. } => {
sort_named_specifiers(named);
}
ImportForm::Go { .. } => {}
}
if matches!(lang, LangId::Java | LangId::Kotlin) {
if let Some(stripped) = key.module_path.strip_suffix(".*") {
key.module_path = stripped.to_string();
}
if matches!(lang, LangId::Java) {
if let ImportForm::Structured {
named, modifiers, ..
} = &mut key.form
{
normalize_java_static_member_key(&mut key.module_path, modifiers, named);
}
}
} else if matches!(lang, LangId::Scala) {
key.module_path = key
.module_path
.strip_suffix(".given")
.or_else(|| key.module_path.strip_suffix(".*"))
.or_else(|| key.module_path.strip_suffix("._"))
.unwrap_or(&key.module_path)
.to_string();
}
key
}
fn sort_named_specifiers(names: &mut [String]) {
names.sort_by(|a, b| {
specifier_imported_name(a)
.cmp(specifier_imported_name(b))
.then_with(|| a.cmp(b))
});
}
pub fn find_insertion_point(
source: &str,
block: &ImportBlock,
group: ImportGroup,
module_path: &str,
type_only: bool,
) -> (usize, bool, bool) {
if block.imports.is_empty() {
return (0, false, source.is_empty().then_some(false).unwrap_or(true));
}
let target_kind = if type_only {
ImportKind::Type
} else {
ImportKind::Value
};
let group_imports: Vec<&ImportStatement> =
block.imports.iter().filter(|i| i.group == group).collect();
if group_imports.is_empty() {
let preceding_last = block.imports.iter().filter(|i| i.group < group).last();
if let Some(last) = preceding_last {
let end = last.byte_range.end;
let insert_at = skip_newline(source, end);
return (insert_at, true, true);
}
let following_first = block.imports.iter().find(|i| i.group > group);
if let Some(first) = following_first {
return (first.byte_range.start, false, true);
}
let first_byte = import_byte_range(&block.imports)
.map(|range| range.start)
.unwrap_or(0);
return (first_byte, false, true);
}
for imp in &group_imports {
let cmp = module_path.cmp(&imp.module_path);
match cmp {
std::cmp::Ordering::Less => {
return (imp.byte_range.start, false, false);
}
std::cmp::Ordering::Equal => {
if target_kind == ImportKind::Type && imp.kind == ImportKind::Value {
let end = imp.byte_range.end;
let insert_at = skip_newline(source, end);
return (insert_at, false, false);
}
return (imp.byte_range.start, false, false);
}
std::cmp::Ordering::Greater => continue,
}
}
let Some(last) = group_imports.last() else {
return (
import_byte_range(&block.imports)
.map(|range| range.end)
.unwrap_or(0),
false,
false,
);
};
let end = last.byte_range.end;
let insert_at = skip_newline(source, end);
(insert_at, false, false)
}
pub fn generate_import(lang: LangId, req: &ImportRequest) -> String {
match syntax_for(lang) {
Some(engine) => engine.generate_line(req),
None => String::new(),
}
}
pub fn generate_import_line(
lang: LangId,
module_path: &str,
names: &[String],
default_import: Option<&str>,
type_only: bool,
) -> String {
generate_import(
lang,
&ImportRequest::legacy(module_path, names, default_import, None, type_only),
)
}
pub fn generate_import_line_with_namespace(
lang: LangId,
module_path: &str,
names: &[String],
default_import: Option<&str>,
namespace_import: Option<&str>,
type_only: bool,
) -> String {
generate_import(
lang,
&ImportRequest::legacy(
module_path,
names,
default_import,
namespace_import,
type_only,
),
)
}
pub fn is_supported(lang: LangId) -> bool {
syntax_for(lang).is_some()
}
pub fn classify_group_ts(module_path: &str) -> ImportGroup {
if module_path.starts_with('.') {
ImportGroup::Internal
} else {
ImportGroup::External
}
}
pub fn classify_group(lang: LangId, module_path: &str) -> ImportGroup {
match syntax_for(lang) {
Some(engine) => engine.classify_group(module_path),
None => ImportGroup::External,
}
}
pub fn parse_file_imports(
path: &std::path::Path,
lang: LangId,
) -> Result<(String, Tree, ImportBlock), crate::error::AftError> {
let source =
std::fs::read_to_string(path).map_err(|e| crate::error::AftError::FileNotFound {
path: format!("{}: {}", path.display(), e),
})?;
let grammar = grammar_for(lang);
let mut parser = Parser::new();
parser
.set_language(&grammar)
.map_err(|e| crate::error::AftError::ParseError {
message: format!("grammar init failed for {:?}: {}", lang, e),
})?;
let tree = parser
.parse(&source, None)
.ok_or_else(|| crate::error::AftError::ParseError {
message: format!("tree-sitter parse returned None for {}", path.display()),
})?;
let block = parse_imports(&source, &tree, lang);
Ok((source, tree, block))
}
fn parse_ts_imports(source: &str, tree: &Tree) -> ImportBlock {
let root = tree.root_node();
let mut imports = Vec::new();
let mut cursor = root.walk();
if !cursor.goto_first_child() {
return ImportBlock::empty();
}
loop {
let node = cursor.node();
if node.kind() == "import_statement" {
if let Some(imp) = parse_single_ts_import(source, &node) {
imports.push(imp);
}
}
if !cursor.goto_next_sibling() {
break;
}
}
let byte_range = import_byte_range(&imports);
ImportBlock {
imports,
byte_range,
}
}
fn parse_single_ts_import(source: &str, node: &Node) -> Option<ImportStatement> {
let raw_text = source[node.byte_range()].to_string();
let byte_range = node.byte_range();
let module_path = extract_module_path(source, node)?;
let is_type_only = has_type_keyword(node);
let mut names = Vec::new();
let mut default_import = None;
let mut namespace_import = None;
let mut child_cursor = node.walk();
if child_cursor.goto_first_child() {
loop {
let child = child_cursor.node();
match child.kind() {
"import_clause" => {
extract_import_clause(
source,
&child,
&mut names,
&mut default_import,
&mut namespace_import,
);
}
"identifier" => {
let text = &source[child.byte_range()];
if text != "import" && text != "from" && text != "type" {
default_import = Some(text.to_string());
}
}
_ => {}
}
if !child_cursor.goto_next_sibling() {
break;
}
}
}
let kind = if names.is_empty() && default_import.is_none() && namespace_import.is_none() {
ImportKind::SideEffect
} else if is_type_only {
ImportKind::Type
} else {
ImportKind::Value
};
let group = classify_group_ts(&module_path);
let form = ImportForm::Es {
default_import: default_import.clone(),
namespace_import: namespace_import.clone(),
named: names.clone(),
type_only: is_type_only,
side_effect: matches!(kind, ImportKind::SideEffect),
};
Some(ImportStatement {
module_path,
names,
default_import,
namespace_import,
kind,
group,
byte_range,
raw_text,
form,
})
}
fn extract_module_path(source: &str, node: &Node) -> Option<String> {
let mut cursor = node.walk();
if !cursor.goto_first_child() {
return None;
}
loop {
let child = cursor.node();
if child.kind() == "string" {
let text = &source[child.byte_range()];
let stripped = text
.trim_start_matches(|c| c == '\'' || c == '"')
.trim_end_matches(|c| c == '\'' || c == '"');
return Some(stripped.to_string());
}
if !cursor.goto_next_sibling() {
break;
}
}
None
}
fn has_type_keyword(node: &Node) -> bool {
let mut cursor = node.walk();
if !cursor.goto_first_child() {
return false;
}
loop {
let child = cursor.node();
if child.kind() == "type" {
return true;
}
if !cursor.goto_next_sibling() {
break;
}
}
false
}
fn extract_import_clause(
source: &str,
node: &Node,
names: &mut Vec<String>,
default_import: &mut Option<String>,
namespace_import: &mut Option<String>,
) {
let mut cursor = node.walk();
if !cursor.goto_first_child() {
return;
}
loop {
let child = cursor.node();
match child.kind() {
"identifier" => {
let text = &source[child.byte_range()];
if text != "type" {
*default_import = Some(text.to_string());
}
}
"named_imports" => {
extract_named_imports(source, &child, names);
}
"namespace_import" => {
extract_namespace_import(source, &child, namespace_import);
}
_ => {}
}
if !cursor.goto_next_sibling() {
break;
}
}
}
fn extract_named_imports(source: &str, node: &Node, names: &mut Vec<String>) {
let mut cursor = node.walk();
if !cursor.goto_first_child() {
return;
}
loop {
let child = cursor.node();
if child.kind() == "import_specifier" {
let raw = source[child.byte_range()].trim().to_string();
if !raw.is_empty() {
names.push(raw);
} else if let Some(name_node) = child.child_by_field_name("name") {
names.push(source[name_node.byte_range()].to_string());
}
}
if !cursor.goto_next_sibling() {
break;
}
}
}
fn extract_namespace_import(source: &str, node: &Node, namespace_import: &mut Option<String>) {
let mut cursor = node.walk();
if !cursor.goto_first_child() {
return;
}
loop {
let child = cursor.node();
if child.kind() == "identifier" {
*namespace_import = Some(source[child.byte_range()].to_string());
return;
}
if !cursor.goto_next_sibling() {
break;
}
}
}
fn generate_ts_import_line(
module_path: &str,
names: &[String],
default_import: Option<&str>,
namespace_import: Option<&str>,
type_only: bool,
) -> String {
let type_prefix = if type_only { "type " } else { "" };
if names.is_empty() && default_import.is_none() && namespace_import.is_none() {
return format!("import '{module_path}';");
}
if names.is_empty() && default_import.is_none() {
if let Some(namespace) = namespace_import {
return format!("import {type_prefix}* as {namespace} from '{module_path}';");
}
}
if names.is_empty() {
if let (Some(def), Some(namespace)) = (default_import, namespace_import) {
return format!("import {type_prefix}{def}, * as {namespace} from '{module_path}';");
}
}
if names.is_empty() && namespace_import.is_none() {
if let Some(def) = default_import {
return format!("import {type_prefix}{def} from '{module_path}';");
}
}
if default_import.is_none() && namespace_import.is_none() {
let mut sorted_names = names.to_vec();
sort_named_specifiers(&mut sorted_names);
let names_str = sorted_names.join(", ");
return format!("import {type_prefix}{{ {names_str} }} from '{module_path}';");
}
if default_import.is_none() {
if let Some(namespace) = namespace_import {
let mut sorted_names = names.to_vec();
sort_named_specifiers(&mut sorted_names);
let names_str = sorted_names.join(", ");
return format!(
"import {type_prefix}{{ {names_str} }}, * as {namespace} from '{module_path}';"
);
}
}
if let (Some(def), Some(namespace)) = (default_import, namespace_import) {
let mut sorted_names = names.to_vec();
sort_named_specifiers(&mut sorted_names);
let names_str = sorted_names.join(", ");
return format!(
"import {type_prefix}{def}, {{ {names_str} }}, * as {namespace} from '{module_path}';"
);
}
if let Some(def) = default_import {
let mut sorted_names = names.to_vec();
sort_named_specifiers(&mut sorted_names);
let names_str = sorted_names.join(", ");
return format!("import {type_prefix}{def}, {{ {names_str} }} from '{module_path}';");
}
format!("import '{module_path}';")
}
const PYTHON_STDLIB: &[&str] = &[
"__future__",
"_thread",
"abc",
"aifc",
"argparse",
"array",
"ast",
"asynchat",
"asyncio",
"asyncore",
"atexit",
"audioop",
"base64",
"bdb",
"binascii",
"bisect",
"builtins",
"bz2",
"calendar",
"cgi",
"cgitb",
"chunk",
"cmath",
"cmd",
"code",
"codecs",
"codeop",
"collections",
"colorsys",
"compileall",
"concurrent",
"configparser",
"contextlib",
"contextvars",
"copy",
"copyreg",
"cProfile",
"crypt",
"csv",
"ctypes",
"curses",
"dataclasses",
"datetime",
"dbm",
"decimal",
"difflib",
"dis",
"distutils",
"doctest",
"email",
"encodings",
"enum",
"errno",
"faulthandler",
"fcntl",
"filecmp",
"fileinput",
"fnmatch",
"fractions",
"ftplib",
"functools",
"gc",
"getopt",
"getpass",
"gettext",
"glob",
"grp",
"gzip",
"hashlib",
"heapq",
"hmac",
"html",
"http",
"idlelib",
"imaplib",
"imghdr",
"importlib",
"inspect",
"io",
"ipaddress",
"itertools",
"json",
"keyword",
"lib2to3",
"linecache",
"locale",
"logging",
"lzma",
"mailbox",
"mailcap",
"marshal",
"math",
"mimetypes",
"mmap",
"modulefinder",
"multiprocessing",
"netrc",
"numbers",
"operator",
"optparse",
"os",
"pathlib",
"pdb",
"pickle",
"pickletools",
"pipes",
"pkgutil",
"platform",
"plistlib",
"poplib",
"posixpath",
"pprint",
"profile",
"pstats",
"pty",
"pwd",
"py_compile",
"pyclbr",
"pydoc",
"queue",
"quopri",
"random",
"re",
"readline",
"reprlib",
"resource",
"rlcompleter",
"runpy",
"sched",
"secrets",
"select",
"selectors",
"shelve",
"shlex",
"shutil",
"signal",
"site",
"smtplib",
"sndhdr",
"socket",
"socketserver",
"sqlite3",
"ssl",
"stat",
"statistics",
"string",
"stringprep",
"struct",
"subprocess",
"symtable",
"sys",
"sysconfig",
"syslog",
"tabnanny",
"tarfile",
"tempfile",
"termios",
"textwrap",
"threading",
"time",
"timeit",
"tkinter",
"token",
"tokenize",
"tomllib",
"trace",
"traceback",
"tracemalloc",
"tty",
"turtle",
"types",
"typing",
"unicodedata",
"unittest",
"urllib",
"uuid",
"venv",
"warnings",
"wave",
"weakref",
"webbrowser",
"wsgiref",
"xml",
"xmlrpc",
"zipapp",
"zipfile",
"zipimport",
"zlib",
];
pub fn classify_group_py(module_path: &str) -> ImportGroup {
if module_path.starts_with('.') {
return ImportGroup::Internal;
}
let top_module = module_path.split('.').next().unwrap_or(module_path);
if PYTHON_STDLIB.contains(&top_module) {
ImportGroup::Stdlib
} else {
ImportGroup::External
}
}
fn parse_py_imports(source: &str, tree: &Tree) -> ImportBlock {
let root = tree.root_node();
let mut imports = Vec::new();
let mut cursor = root.walk();
if !cursor.goto_first_child() {
return ImportBlock::empty();
}
loop {
let node = cursor.node();
match node.kind() {
"import_statement" => {
if let Some(imp) = parse_py_import_statement(source, &node) {
imports.push(imp);
}
}
"import_from_statement" => {
if let Some(imp) = parse_py_import_from_statement(source, &node) {
imports.push(imp);
}
}
_ => {}
}
if !cursor.goto_next_sibling() {
break;
}
}
let byte_range = import_byte_range(&imports);
ImportBlock {
imports,
byte_range,
}
}
fn parse_py_import_statement(source: &str, node: &Node) -> Option<ImportStatement> {
let raw_text = source[node.byte_range()].to_string();
let byte_range = node.byte_range();
let mut module_path = String::new();
let mut c = node.walk();
if c.goto_first_child() {
loop {
if c.node().kind() == "dotted_name" {
module_path = source[c.node().byte_range()].to_string();
break;
}
if !c.goto_next_sibling() {
break;
}
}
}
if module_path.is_empty() {
return None;
}
let group = classify_group_py(&module_path);
Some(ImportStatement {
module_path,
names: Vec::new(),
default_import: None,
namespace_import: None,
kind: ImportKind::Value,
group,
byte_range,
raw_text,
form: ImportForm::Python {
from_import: false,
named: Vec::new(),
},
})
}
fn parse_py_import_from_statement(source: &str, node: &Node) -> Option<ImportStatement> {
let raw_text = source[node.byte_range()].to_string();
let byte_range = node.byte_range();
let mut module_path = String::new();
let mut names = Vec::new();
let mut c = node.walk();
if c.goto_first_child() {
loop {
let child = c.node();
match child.kind() {
"dotted_name" => {
if module_path.is_empty()
&& !has_seen_import_keyword(source, node, child.start_byte())
{
module_path = source[child.byte_range()].to_string();
} else {
names.push(source[child.byte_range()].to_string());
}
}
"relative_import" => {
module_path = source[child.byte_range()].to_string();
}
_ => {}
}
if !c.goto_next_sibling() {
break;
}
}
}
if module_path.is_empty() {
return None;
}
let group = classify_group_py(&module_path);
Some(ImportStatement {
module_path,
names: names.clone(),
default_import: None,
namespace_import: None,
kind: ImportKind::Value,
group,
byte_range,
raw_text,
form: ImportForm::Python {
from_import: true,
named: names,
},
})
}
fn has_seen_import_keyword(_source: &str, parent: &Node, before_byte: usize) -> bool {
let mut c = parent.walk();
if c.goto_first_child() {
loop {
let child = c.node();
if child.kind() == "import" && child.start_byte() < before_byte {
return true;
}
if child.start_byte() >= before_byte {
return false;
}
if !c.goto_next_sibling() {
break;
}
}
}
false
}
fn generate_py_import_line(
module_path: &str,
names: &[String],
_default_import: Option<&str>,
) -> String {
if names.is_empty() {
format!("import {module_path}")
} else {
let mut sorted = names.to_vec();
sorted.sort();
let names_str = sorted.join(", ");
format!("from {module_path} import {names_str}")
}
}
pub fn classify_group_rs(module_path: &str) -> ImportGroup {
let first_seg = module_path.split("::").next().unwrap_or(module_path);
match first_seg {
"std" | "core" | "alloc" => ImportGroup::Stdlib,
"crate" | "self" | "super" => ImportGroup::Internal,
_ => ImportGroup::External,
}
}
fn parse_rs_imports(source: &str, tree: &Tree) -> ImportBlock {
let root = tree.root_node();
let mut imports = Vec::new();
let mut cursor = root.walk();
if !cursor.goto_first_child() {
return ImportBlock::empty();
}
loop {
let node = cursor.node();
if node.kind() == "use_declaration" {
if let Some(imp) = parse_rs_use_declaration(source, &node) {
imports.push(imp);
}
}
if !cursor.goto_next_sibling() {
break;
}
}
let byte_range = import_byte_range(&imports);
ImportBlock {
imports,
byte_range,
}
}
fn parse_rs_use_declaration(source: &str, node: &Node) -> Option<ImportStatement> {
let raw_text = source[node.byte_range()].to_string();
let byte_range = node.byte_range();
let mut has_pub = false;
let mut use_path = String::new();
let mut names = Vec::new();
let mut c = node.walk();
if c.goto_first_child() {
loop {
let child = c.node();
match child.kind() {
"visibility_modifier" => {
has_pub = true;
}
"scoped_identifier" | "identifier" | "use_as_clause" => {
use_path = source[child.byte_range()].to_string();
}
"scoped_use_list" => {
use_path = source[child.byte_range()].to_string();
extract_rs_use_list_names(source, &child, &mut names);
}
_ => {}
}
if !c.goto_next_sibling() {
break;
}
}
}
if use_path.is_empty() {
return None;
}
let group = classify_group_rs(&use_path);
Some(ImportStatement {
module_path: use_path,
names: names.clone(),
default_import: if has_pub {
Some("pub".to_string())
} else {
None
},
namespace_import: None,
kind: ImportKind::Value,
group,
byte_range,
raw_text,
form: ImportForm::RustUse {
visibility: has_pub.then(|| "pub".to_string()),
named: names,
},
})
}
fn extract_rs_use_list_names(source: &str, node: &Node, names: &mut Vec<String>) {
let mut c = node.walk();
if c.goto_first_child() {
loop {
let child = c.node();
if child.kind() == "use_list" {
let mut lc = child.walk();
if lc.goto_first_child() {
loop {
let lchild = lc.node();
if lchild.kind() == "identifier" || lchild.kind() == "scoped_identifier" {
names.push(source[lchild.byte_range()].to_string());
}
if !lc.goto_next_sibling() {
break;
}
}
}
}
if !c.goto_next_sibling() {
break;
}
}
}
}
fn generate_rs_import_line(module_path: &str, names: &[String], _type_only: bool) -> String {
if names.is_empty() {
format!("use {module_path};")
} else {
let mut sorted_names = names.to_vec();
sort_named_specifiers(&mut sorted_names);
format!("use {module_path}::{{{}}};", sorted_names.join(", "))
}
}
pub fn classify_group_go(module_path: &str) -> ImportGroup {
if module_path.contains('.') {
ImportGroup::External
} else {
ImportGroup::Stdlib
}
}
fn parse_go_imports(source: &str, tree: &Tree) -> ImportBlock {
let root = tree.root_node();
let mut imports = Vec::new();
let mut cursor = root.walk();
if !cursor.goto_first_child() {
return ImportBlock::empty();
}
loop {
let node = cursor.node();
if node.kind() == "import_declaration" {
parse_go_import_declaration(source, &node, &mut imports);
}
if !cursor.goto_next_sibling() {
break;
}
}
let byte_range = import_byte_range(&imports);
ImportBlock {
imports,
byte_range,
}
}
fn parse_go_import_declaration(source: &str, node: &Node, imports: &mut Vec<ImportStatement>) {
let mut c = node.walk();
if c.goto_first_child() {
loop {
let child = c.node();
match child.kind() {
"import_spec" => {
if let Some(imp) = parse_go_import_spec(source, &child) {
imports.push(imp);
}
}
"import_spec_list" => {
let mut lc = child.walk();
if lc.goto_first_child() {
loop {
if lc.node().kind() == "import_spec" {
if let Some(imp) = parse_go_import_spec(source, &lc.node()) {
imports.push(imp);
}
}
if !lc.goto_next_sibling() {
break;
}
}
}
}
_ => {}
}
if !c.goto_next_sibling() {
break;
}
}
}
}
fn parse_go_import_spec(source: &str, node: &Node) -> Option<ImportStatement> {
let raw_text = source[node.byte_range()].to_string();
let byte_range = node.byte_range();
let mut import_path = String::new();
let mut alias = None;
let mut c = node.walk();
if c.goto_first_child() {
loop {
let child = c.node();
match child.kind() {
"interpreted_string_literal" => {
let text = source[child.byte_range()].to_string();
import_path = text.trim_matches('"').to_string();
}
"identifier" | "blank_identifier" | "dot" => {
alias = Some(source[child.byte_range()].to_string());
}
_ => {}
}
if !c.goto_next_sibling() {
break;
}
}
}
if import_path.is_empty() {
return None;
}
let group = classify_group_go(&import_path);
Some(ImportStatement {
module_path: import_path,
names: Vec::new(),
default_import: alias.clone(),
namespace_import: None,
kind: ImportKind::Value,
group,
byte_range,
raw_text,
form: ImportForm::Go { alias },
})
}
pub fn generate_go_import_line_pub(
module_path: &str,
alias: Option<&str>,
in_group: bool,
) -> String {
generate_go_import_line(module_path, alias, in_group)
}
fn generate_go_import_line(module_path: &str, alias: Option<&str>, in_group: bool) -> String {
if in_group {
match alias {
Some(a) => format!("\t{a} \"{module_path}\""),
None => format!("\t\"{module_path}\""),
}
} else {
match alias {
Some(a) => format!("import {a} \"{module_path}\""),
None => format!("import \"{module_path}\""),
}
}
}
pub fn go_has_grouped_import(_source: &str, tree: &Tree) -> Option<Range<usize>> {
let root = tree.root_node();
let mut cursor = root.walk();
if !cursor.goto_first_child() {
return None;
}
loop {
let node = cursor.node();
if node.kind() == "import_declaration" && go_import_declaration_is_grouped(&node) {
return Some(node.byte_range());
}
if !cursor.goto_next_sibling() {
break;
}
}
None
}
pub fn go_import_declarations_range(_source: &str, tree: &Tree) -> Option<Range<usize>> {
let root = tree.root_node();
let mut cursor = root.walk();
let mut range: Option<Range<usize>> = None;
if !cursor.goto_first_child() {
return None;
}
loop {
let node = cursor.node();
if node.kind() == "import_declaration" {
let node_range = node.byte_range();
range = Some(match range {
Some(existing) => {
existing.start.min(node_range.start)..existing.end.max(node_range.end)
}
None => node_range,
});
}
if !cursor.goto_next_sibling() {
break;
}
}
range
}
pub fn go_offset_is_in_grouped_import(_source: &str, tree: &Tree, offset: usize) -> bool {
let root = tree.root_node();
let mut cursor = root.walk();
if !cursor.goto_first_child() {
return false;
}
loop {
let node = cursor.node();
if node.kind() == "import_declaration"
&& node.start_byte() < offset
&& offset < node.end_byte()
&& go_import_declaration_is_grouped(&node)
{
return true;
}
if !cursor.goto_next_sibling() {
break;
}
}
false
}
fn go_import_declaration_is_grouped(node: &Node) -> bool {
let mut c = node.walk();
if c.goto_first_child() {
loop {
if c.node().kind() == "import_spec_list" {
return true;
}
if !c.goto_next_sibling() {
break;
}
}
}
false
}
pub fn classify_group_solidity(module_path: &str) -> ImportGroup {
if module_path.starts_with('.') {
ImportGroup::Internal
} else {
ImportGroup::External
}
}
fn parse_solidity_imports(source: &str, tree: &Tree) -> ImportBlock {
let root = tree.root_node();
let mut imports = Vec::new();
let mut cursor = root.walk();
if cursor.goto_first_child() {
loop {
let node = cursor.node();
if node.kind() == "import_directive" {
if let Some(imp) = parse_solidity_import_directive(source, &node) {
imports.push(imp);
}
}
if !cursor.goto_next_sibling() {
break;
}
}
}
let byte_range = import_byte_range(&imports);
ImportBlock {
imports,
byte_range,
}
}
fn parse_solidity_import_directive(source: &str, node: &Node) -> Option<ImportStatement> {
let raw_text = source[node.byte_range()].to_string();
let byte_range = node.byte_range();
let mut children: Vec<(String, String)> = Vec::new();
let mut c = node.walk();
if c.goto_first_child() {
loop {
let ch = c.node();
children.push((ch.kind().to_string(), source[ch.byte_range()].to_string()));
if !c.goto_next_sibling() {
break;
}
}
}
let module_path = children
.iter()
.find(|(k, _)| k == "string")
.map(|(_, t)| t.trim_matches('"').to_string())?;
if module_path.is_empty() {
return None;
}
let has_brace = children.iter().any(|(k, _)| k == "{");
let has_star = children.iter().any(|(k, _)| k == "*");
let mut named: Vec<String> = Vec::new();
let mut namespace: Option<String> = None;
let mut alias: Option<String> = None;
if has_brace {
named = parse_solidity_named_specifiers(&children);
} else if has_star {
namespace = solidity_identifier_after_as(&children);
} else {
alias = solidity_identifier_after_as(&children);
}
let kind = if named.is_empty() && namespace.is_none() && alias.is_none() {
ImportKind::SideEffect
} else {
ImportKind::Value
};
let group = classify_group_solidity(&module_path);
Some(ImportStatement {
module_path,
names: named.clone(),
default_import: None,
namespace_import: namespace.clone(),
kind,
group,
byte_range,
raw_text,
form: ImportForm::Solidity {
named,
namespace,
alias,
},
})
}
fn solidity_identifier_after_as(children: &[(String, String)]) -> Option<String> {
let as_pos = children.iter().position(|(k, _)| k == "as")?;
children[as_pos + 1..]
.iter()
.find(|(k, _)| k == "identifier")
.map(|(_, t)| t.clone())
}
fn parse_solidity_named_specifiers(children: &[(String, String)]) -> Vec<String> {
let mut names = Vec::new();
let mut in_braces = false;
let mut current: Option<String> = None;
let mut expect_alias = false;
for (k, t) in children {
match k.as_str() {
"{" => in_braces = true,
"}" => {
if let Some(n) = current.take() {
names.push(n);
}
in_braces = false;
}
_ if !in_braces => {}
"identifier" => {
if expect_alias {
if let Some(n) = current.take() {
names.push(format!("{n} as {t}"));
}
expect_alias = false;
} else {
if let Some(n) = current.take() {
names.push(n);
}
current = Some(t.clone());
}
}
"as" => expect_alias = true,
"," => {
if let Some(n) = current.take() {
names.push(n);
}
expect_alias = false;
}
_ => {}
}
}
names
}
fn generate_solidity_import_line(req: &ImportRequest) -> String {
if !req.names.is_empty() {
format!(
"import {{ {} }} from \"{}\";",
req.names.join(", "),
req.module_path
)
} else if let Some(ns) = req.namespace {
format!("import * as {} from \"{}\";", ns, req.module_path)
} else if let Some(al) = req.alias {
format!("import \"{}\" as {};", req.module_path, al)
} else {
format!("import \"{}\";", req.module_path)
}
}
fn skip_newline(source: &str, pos: usize) -> usize {
if pos < source.len() {
let bytes = source.as_bytes();
if bytes[pos] == b'\n' {
return pos + 1;
}
if bytes[pos] == b'\r' {
if pos + 1 < source.len() && bytes[pos + 1] == b'\n' {
return pos + 2;
}
return pos + 1;
}
}
pos
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn form_es_mirrors_flat_fields() {
let (_, block) = parse_ts(
"import Default, { a, b as c } from \"ext\";\nimport type { T } from \"./t\";\nimport \"./side\";\nimport * as ns from \"nspkg\";\n",
);
match &block.imports[0].form {
ImportForm::Es {
default_import,
namespace_import,
named,
type_only,
side_effect,
} => {
assert_eq!(default_import.as_deref(), Some("Default"));
assert_eq!(namespace_import, &None);
assert_eq!(named, &block.imports[0].names);
assert!(!type_only);
assert!(!side_effect);
}
other => panic!("expected Es, got {other:?}"),
}
match &block.imports[1].form {
ImportForm::Es {
type_only, named, ..
} => {
assert!(type_only);
assert_eq!(named, &block.imports[1].names);
}
other => panic!("expected Es type-only, got {other:?}"),
}
match &block.imports[2].form {
ImportForm::Es { side_effect, .. } => assert!(side_effect),
other => panic!("expected Es side-effect, got {other:?}"),
}
match &block.imports[3].form {
ImportForm::Es {
namespace_import, ..
} => assert_eq!(namespace_import.as_deref(), Some("ns")),
other => panic!("expected Es namespace, got {other:?}"),
}
}
#[test]
fn form_python_mirrors_flat_fields() {
let (_, block) = parse_py("import os\nfrom sys import argv, path\n");
match &block.imports[0].form {
ImportForm::Python { from_import, named } => {
assert!(!from_import, "`import os` is not a from-import");
assert!(named.is_empty());
}
other => panic!("expected Python import, got {other:?}"),
}
match &block.imports[1].form {
ImportForm::Python { from_import, named } => {
assert!(from_import, "`from sys import ...` is a from-import");
assert_eq!(named, &block.imports[1].names);
}
other => panic!("expected Python from-import, got {other:?}"),
}
}
#[test]
fn form_rust_de_overloads_pub_from_default_import() {
let (_, block) = parse_rust("pub use crate::a::Exported;\nuse std::fmt::Debug;\n");
match &block.imports[0].form {
ImportForm::RustUse { visibility, named } => {
assert_eq!(visibility.as_deref(), Some("pub"));
assert_eq!(named, &block.imports[0].names);
}
other => panic!("expected RustUse, got {other:?}"),
}
assert_eq!(
block.imports[0].default_import.as_deref(),
Some("pub"),
"flat field unchanged during additive migration"
);
match &block.imports[1].form {
ImportForm::RustUse { visibility, .. } => assert_eq!(visibility, &None),
other => panic!("expected RustUse, got {other:?}"),
}
assert_eq!(block.imports[1].default_import, None);
}
#[test]
fn form_go_de_overloads_alias_from_default_import() {
let (_, block) =
parse_go("package main\n\nimport (\n\t_ \"github.com/x/y\"\n\t\"fmt\"\n)\n");
let blank = block
.imports
.iter()
.find(|i| i.module_path == "github.com/x/y")
.expect("blank import parsed");
match &blank.form {
ImportForm::Go { alias } => assert_eq!(alias.as_deref(), Some("_")),
other => panic!("expected Go blank-aliased, got {other:?}"),
}
assert_eq!(
blank.default_import.as_deref(),
Some("_"),
"form.alias mirrors the flat default_import field exactly"
);
let plain = block
.imports
.iter()
.find(|i| i.module_path == "fmt")
.expect("plain import parsed");
match &plain.form {
ImportForm::Go { alias } => assert_eq!(alias, &None),
other => panic!("expected Go plain, got {other:?}"),
}
assert_eq!(plain.default_import, None);
}
fn parse_ts(source: &str) -> (Tree, ImportBlock) {
let grammar = grammar_for(LangId::TypeScript);
let mut parser = Parser::new();
parser.set_language(&grammar).unwrap();
let tree = parser.parse(source, None).unwrap();
let block = parse_imports(source, &tree, LangId::TypeScript);
(tree, block)
}
fn parse_js(source: &str) -> (Tree, ImportBlock) {
let grammar = grammar_for(LangId::JavaScript);
let mut parser = Parser::new();
parser.set_language(&grammar).unwrap();
let tree = parser.parse(source, None).unwrap();
let block = parse_imports(source, &tree, LangId::JavaScript);
(tree, block)
}
fn parse_vue(source: &str) -> (Tree, ImportBlock) {
let grammar = grammar_for(LangId::Vue);
let mut parser = Parser::new();
parser.set_language(&grammar).unwrap();
let tree = parser.parse(source, None).unwrap();
let block = parse_imports(source, &tree, LangId::Vue);
(tree, block)
}
#[test]
fn vue_grammar_node_kinds_are_stable() {
let src = "<template>\n <div />\n</template>\n\n<script setup lang=\"ts\">\nimport { ref } from 'vue'\n</script>\n";
let grammar = grammar_for(LangId::Vue);
let mut parser = Parser::new();
parser.set_language(&grammar).unwrap();
let tree = parser.parse(src, None).unwrap();
let root = tree.root_node();
let mut cursor = root.walk();
let script = root
.named_children(&mut cursor)
.find(|n| n.kind() == "script_element")
.expect("expected a script_element node");
let mut inner = script.walk();
assert!(
script
.named_children(&mut inner)
.any(|n| n.kind() == "raw_text"),
"expected script body exposed as raw_text"
);
}
#[test]
fn vue_parses_script_imports_with_whole_file_offsets() {
let src = "<template>\n <div />\n</template>\n\n<script setup lang=\"ts\">\nimport { ref } from 'vue'\nimport Foo from './Foo.vue'\nconst x = ref(0)\n</script>\n";
let (_tree, block) = parse_vue(src);
assert_eq!(block.imports.len(), 2, "should find both script imports");
for imp in &block.imports {
assert_eq!(&src[imp.byte_range.clone()], imp.raw_text);
assert!(
imp.byte_range.start > src.find("<script").unwrap(),
"import offset must fall inside the script block"
);
}
assert_eq!(block.imports[0].module_path, "vue");
assert_eq!(block.imports[1].module_path, "./Foo.vue");
}
#[test]
fn vue_without_script_block_has_no_imports() {
let src = "<template>\n <div />\n</template>\n\n<style>.x{}</style>\n";
let (_tree, block) = parse_vue(src);
assert!(block.imports.is_empty());
assert!(block.byte_range.is_none());
}
#[test]
fn parse_ts_named_imports() {
let source = "import { useState, useEffect } from 'react';\n";
let (_, block) = parse_ts(source);
assert_eq!(block.imports.len(), 1);
let imp = &block.imports[0];
assert_eq!(imp.module_path, "react");
assert!(imp.names.contains(&"useState".to_string()));
assert!(imp.names.contains(&"useEffect".to_string()));
assert_eq!(imp.kind, ImportKind::Value);
assert_eq!(imp.group, ImportGroup::External);
}
#[test]
fn parse_ts_default_import() {
let source = "import React from 'react';\n";
let (_, block) = parse_ts(source);
assert_eq!(block.imports.len(), 1);
let imp = &block.imports[0];
assert_eq!(imp.default_import.as_deref(), Some("React"));
assert_eq!(imp.kind, ImportKind::Value);
}
#[test]
fn parse_ts_side_effect_import() {
let source = "import './styles.css';\n";
let (_, block) = parse_ts(source);
assert_eq!(block.imports.len(), 1);
assert_eq!(block.imports[0].kind, ImportKind::SideEffect);
assert_eq!(block.imports[0].module_path, "./styles.css");
}
#[test]
fn parse_ts_relative_import() {
let source = "import { helper } from './utils';\n";
let (_, block) = parse_ts(source);
assert_eq!(block.imports.len(), 1);
assert_eq!(block.imports[0].group, ImportGroup::Internal);
}
#[test]
fn parse_ts_multiple_groups() {
let source = "\
import React from 'react';
import { useState } from 'react';
import { helper } from './utils';
import { Config } from '../config';
";
let (_, block) = parse_ts(source);
assert_eq!(block.imports.len(), 4);
let external: Vec<_> = block
.imports
.iter()
.filter(|i| i.group == ImportGroup::External)
.collect();
let relative: Vec<_> = block
.imports
.iter()
.filter(|i| i.group == ImportGroup::Internal)
.collect();
assert_eq!(external.len(), 2);
assert_eq!(relative.len(), 2);
}
#[test]
fn parse_ts_namespace_import() {
let source = "import * as path from 'path';\n";
let (_, block) = parse_ts(source);
assert_eq!(block.imports.len(), 1);
let imp = &block.imports[0];
assert_eq!(imp.namespace_import.as_deref(), Some("path"));
assert_eq!(imp.kind, ImportKind::Value);
}
#[test]
fn parse_js_imports() {
let source = "import { readFile } from 'fs';\nimport { helper } from './helper';\n";
let (_, block) = parse_js(source);
assert_eq!(block.imports.len(), 2);
assert_eq!(block.imports[0].group, ImportGroup::External);
assert_eq!(block.imports[1].group, ImportGroup::Internal);
}
#[test]
fn classify_external() {
assert_eq!(classify_group_ts("react"), ImportGroup::External);
assert_eq!(classify_group_ts("@scope/pkg"), ImportGroup::External);
assert_eq!(classify_group_ts("lodash/map"), ImportGroup::External);
}
#[test]
fn classify_relative() {
assert_eq!(classify_group_ts("./utils"), ImportGroup::Internal);
assert_eq!(classify_group_ts("../config"), ImportGroup::Internal);
assert_eq!(classify_group_ts("./"), ImportGroup::Internal);
}
#[test]
fn dedup_detects_same_named_import() {
let source = "import { useState } from 'react';\n";
let (_, block) = parse_ts(source);
assert!(is_duplicate(
&block,
"react",
&["useState".to_string()],
None,
false
));
}
#[test]
fn dedup_misses_different_name() {
let source = "import { useState } from 'react';\n";
let (_, block) = parse_ts(source);
assert!(!is_duplicate(
&block,
"react",
&["useEffect".to_string()],
None,
false
));
}
#[test]
fn dedup_detects_default_import() {
let source = "import React from 'react';\n";
let (_, block) = parse_ts(source);
assert!(is_duplicate(&block, "react", &[], Some("React"), false));
}
#[test]
fn dedup_side_effect() {
let source = "import './styles.css';\n";
let (_, block) = parse_ts(source);
assert!(is_duplicate(&block, "./styles.css", &[], None, false));
}
#[test]
fn dedup_namespace_import_distinct_from_side_effect_import() {
let side_effect_source = "import 'fs';\n";
let (_, side_effect_block) = parse_ts(side_effect_source);
assert!(!is_duplicate_with_namespace(
&side_effect_block,
"fs",
&[],
None,
Some("fs"),
false
));
let namespace_source = "import * as fs from 'fs';\n";
let (_, namespace_block) = parse_ts(namespace_source);
assert!(!is_duplicate(&namespace_block, "fs", &[], None, false));
assert!(is_duplicate_with_namespace(
&namespace_block,
"fs",
&[],
None,
Some("fs"),
false
));
assert!(!is_duplicate_with_namespace(
&namespace_block,
"fs",
&[],
None,
Some("other"),
false
));
}
#[test]
fn dedup_type_vs_value() {
let source = "import { FC } from 'react';\n";
let (_, block) = parse_ts(source);
assert!(!is_duplicate(
&block,
"react",
&["FC".to_string()],
None,
true
));
}
#[test]
fn generate_named_import() {
let line = generate_import_line(
LangId::TypeScript,
"react",
&["useState".to_string(), "useEffect".to_string()],
None,
false,
);
assert_eq!(line, "import { useEffect, useState } from 'react';");
}
#[test]
fn generate_named_import_sorts_by_imported_name() {
let line = generate_import_line(
LangId::TypeScript,
"x",
&[
"useState".to_string(),
"type Foo".to_string(),
"stdin as input".to_string(),
"type Bar".to_string(),
],
None,
false,
);
assert_eq!(
line,
"import { type Bar, type Foo, stdin as input, useState } from 'x';"
);
}
#[test]
fn generate_default_import() {
let line = generate_import_line(LangId::TypeScript, "react", &[], Some("React"), false);
assert_eq!(line, "import React from 'react';");
}
#[test]
fn generate_type_import() {
let line =
generate_import_line(LangId::TypeScript, "react", &["FC".to_string()], None, true);
assert_eq!(line, "import type { FC } from 'react';");
}
#[test]
fn generate_side_effect_import() {
let line = generate_import_line(LangId::TypeScript, "./styles.css", &[], None, false);
assert_eq!(line, "import './styles.css';");
}
#[test]
fn generate_default_and_named() {
let line = generate_import_line(
LangId::TypeScript,
"react",
&["useState".to_string()],
Some("React"),
false,
);
assert_eq!(line, "import React, { useState } from 'react';");
}
#[test]
fn parse_ts_type_import() {
let source = "import type { FC } from 'react';\n";
let (_, block) = parse_ts(source);
assert_eq!(block.imports.len(), 1);
let imp = &block.imports[0];
assert_eq!(imp.kind, ImportKind::Type);
assert!(imp.names.contains(&"FC".to_string()));
assert_eq!(imp.group, ImportGroup::External);
}
#[test]
fn insertion_empty_file() {
let source = "";
let (_, block) = parse_ts(source);
let (offset, _, _) =
find_insertion_point(source, &block, ImportGroup::External, "react", false);
assert_eq!(offset, 0);
}
#[test]
fn insertion_alphabetical_within_group() {
let source = "\
import { a } from 'alpha';
import { c } from 'charlie';
";
let (_, block) = parse_ts(source);
let (offset, _, _) =
find_insertion_point(source, &block, ImportGroup::External, "bravo", false);
let before_charlie = source.find("import { c }").unwrap();
assert_eq!(offset, before_charlie);
}
fn parse_py(source: &str) -> (Tree, ImportBlock) {
let grammar = grammar_for(LangId::Python);
let mut parser = Parser::new();
parser.set_language(&grammar).unwrap();
let tree = parser.parse(source, None).unwrap();
let block = parse_imports(source, &tree, LangId::Python);
(tree, block)
}
#[test]
fn parse_py_import_statement() {
let source = "import os\nimport sys\n";
let (_, block) = parse_py(source);
assert_eq!(block.imports.len(), 2);
assert_eq!(block.imports[0].module_path, "os");
assert_eq!(block.imports[1].module_path, "sys");
assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
}
#[test]
fn parse_py_from_import() {
let source = "from collections import OrderedDict\nfrom typing import List, Optional\n";
let (_, block) = parse_py(source);
assert_eq!(block.imports.len(), 2);
assert_eq!(block.imports[0].module_path, "collections");
assert!(block.imports[0].names.contains(&"OrderedDict".to_string()));
assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
assert_eq!(block.imports[1].module_path, "typing");
assert!(block.imports[1].names.contains(&"List".to_string()));
assert!(block.imports[1].names.contains(&"Optional".to_string()));
}
#[test]
fn parse_py_relative_import() {
let source = "from . import utils\nfrom ..config import Settings\n";
let (_, block) = parse_py(source);
assert_eq!(block.imports.len(), 2);
assert_eq!(block.imports[0].module_path, ".");
assert!(block.imports[0].names.contains(&"utils".to_string()));
assert_eq!(block.imports[0].group, ImportGroup::Internal);
assert_eq!(block.imports[1].module_path, "..config");
assert_eq!(block.imports[1].group, ImportGroup::Internal);
}
#[test]
fn classify_py_groups() {
assert_eq!(classify_group_py("os"), ImportGroup::Stdlib);
assert_eq!(classify_group_py("sys"), ImportGroup::Stdlib);
assert_eq!(classify_group_py("json"), ImportGroup::Stdlib);
assert_eq!(classify_group_py("collections"), ImportGroup::Stdlib);
assert_eq!(classify_group_py("os.path"), ImportGroup::Stdlib);
assert_eq!(classify_group_py("requests"), ImportGroup::External);
assert_eq!(classify_group_py("flask"), ImportGroup::External);
assert_eq!(classify_group_py("."), ImportGroup::Internal);
assert_eq!(classify_group_py("..config"), ImportGroup::Internal);
assert_eq!(classify_group_py(".utils"), ImportGroup::Internal);
}
#[test]
fn parse_py_three_groups() {
let source = "import os\nimport sys\n\nimport requests\n\nfrom . import utils\n";
let (_, block) = parse_py(source);
let stdlib: Vec<_> = block
.imports
.iter()
.filter(|i| i.group == ImportGroup::Stdlib)
.collect();
let external: Vec<_> = block
.imports
.iter()
.filter(|i| i.group == ImportGroup::External)
.collect();
let internal: Vec<_> = block
.imports
.iter()
.filter(|i| i.group == ImportGroup::Internal)
.collect();
assert_eq!(stdlib.len(), 2);
assert_eq!(external.len(), 1);
assert_eq!(internal.len(), 1);
}
#[test]
fn generate_py_import() {
let line = generate_import_line(LangId::Python, "os", &[], None, false);
assert_eq!(line, "import os");
}
#[test]
fn generate_py_from_import() {
let line = generate_import_line(
LangId::Python,
"collections",
&["OrderedDict".to_string()],
None,
false,
);
assert_eq!(line, "from collections import OrderedDict");
}
#[test]
fn generate_py_from_import_multiple() {
let line = generate_import_line(
LangId::Python,
"typing",
&["Optional".to_string(), "List".to_string()],
None,
false,
);
assert_eq!(line, "from typing import List, Optional");
}
fn parse_rust(source: &str) -> (Tree, ImportBlock) {
let grammar = grammar_for(LangId::Rust);
let mut parser = Parser::new();
parser.set_language(&grammar).unwrap();
let tree = parser.parse(source, None).unwrap();
let block = parse_imports(source, &tree, LangId::Rust);
(tree, block)
}
#[test]
fn parse_rs_use_std() {
let source = "use std::collections::HashMap;\nuse std::io::Read;\n";
let (_, block) = parse_rust(source);
assert_eq!(block.imports.len(), 2);
assert_eq!(block.imports[0].module_path, "std::collections::HashMap");
assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
assert_eq!(block.imports[1].group, ImportGroup::Stdlib);
}
#[test]
fn parse_rs_use_external() {
let source = "use serde::{Deserialize, Serialize};\n";
let (_, block) = parse_rust(source);
assert_eq!(block.imports.len(), 1);
assert_eq!(block.imports[0].group, ImportGroup::External);
assert!(block.imports[0].names.contains(&"Deserialize".to_string()));
assert!(block.imports[0].names.contains(&"Serialize".to_string()));
}
#[test]
fn parse_rs_use_crate() {
let source = "use crate::config::Settings;\nuse super::parent::Thing;\n";
let (_, block) = parse_rust(source);
assert_eq!(block.imports.len(), 2);
assert_eq!(block.imports[0].group, ImportGroup::Internal);
assert_eq!(block.imports[1].group, ImportGroup::Internal);
}
#[test]
fn parse_rs_pub_use() {
let source = "pub use super::parent::Thing;\n";
let (_, block) = parse_rust(source);
assert_eq!(block.imports.len(), 1);
assert_eq!(block.imports[0].default_import.as_deref(), Some("pub"));
}
#[test]
fn classify_rs_groups() {
assert_eq!(
classify_group_rs("std::collections::HashMap"),
ImportGroup::Stdlib
);
assert_eq!(classify_group_rs("core::mem"), ImportGroup::Stdlib);
assert_eq!(classify_group_rs("alloc::vec"), ImportGroup::Stdlib);
assert_eq!(
classify_group_rs("serde::Deserialize"),
ImportGroup::External
);
assert_eq!(classify_group_rs("tokio::runtime"), ImportGroup::External);
assert_eq!(classify_group_rs("crate::config"), ImportGroup::Internal);
assert_eq!(classify_group_rs("self::utils"), ImportGroup::Internal);
assert_eq!(classify_group_rs("super::parent"), ImportGroup::Internal);
}
#[test]
fn generate_rs_use() {
let line = generate_import_line(LangId::Rust, "std::fmt::Display", &[], None, false);
assert_eq!(line, "use std::fmt::Display;");
}
fn parse_go(source: &str) -> (Tree, ImportBlock) {
let grammar = grammar_for(LangId::Go);
let mut parser = Parser::new();
parser.set_language(&grammar).unwrap();
let tree = parser.parse(source, None).unwrap();
let block = parse_imports(source, &tree, LangId::Go);
(tree, block)
}
#[test]
fn parse_go_single_import() {
let source = "package main\n\nimport \"fmt\"\n";
let (_, block) = parse_go(source);
assert_eq!(block.imports.len(), 1);
assert_eq!(block.imports[0].module_path, "fmt");
assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
}
#[test]
fn parse_go_grouped_import() {
let source =
"package main\n\nimport (\n\t\"fmt\"\n\t\"os\"\n\n\t\"github.com/pkg/errors\"\n)\n";
let (_, block) = parse_go(source);
assert_eq!(block.imports.len(), 3);
assert_eq!(block.imports[0].module_path, "fmt");
assert_eq!(block.imports[0].group, ImportGroup::Stdlib);
assert_eq!(block.imports[1].module_path, "os");
assert_eq!(block.imports[1].group, ImportGroup::Stdlib);
assert_eq!(block.imports[2].module_path, "github.com/pkg/errors");
assert_eq!(block.imports[2].group, ImportGroup::External);
}
#[test]
fn parse_go_mixed_imports() {
let source = "package main\n\nimport \"fmt\"\n\nimport (\n\t\"os\"\n\t\"github.com/pkg/errors\"\n)\n";
let (_, block) = parse_go(source);
assert_eq!(block.imports.len(), 3);
}
#[test]
fn classify_go_groups() {
assert_eq!(classify_group_go("fmt"), ImportGroup::Stdlib);
assert_eq!(classify_group_go("os"), ImportGroup::Stdlib);
assert_eq!(classify_group_go("net/http"), ImportGroup::Stdlib);
assert_eq!(classify_group_go("encoding/json"), ImportGroup::Stdlib);
assert_eq!(
classify_group_go("github.com/pkg/errors"),
ImportGroup::External
);
assert_eq!(
classify_group_go("golang.org/x/tools"),
ImportGroup::External
);
}
#[test]
fn generate_go_standalone() {
let line = generate_go_import_line("fmt", None, false);
assert_eq!(line, "import \"fmt\"");
}
#[test]
fn generate_go_grouped_spec() {
let line = generate_go_import_line("fmt", None, true);
assert_eq!(line, "\t\"fmt\"");
}
#[test]
fn generate_go_with_alias() {
let line = generate_go_import_line("github.com/pkg/errors", Some("errs"), false);
assert_eq!(line, "import errs \"github.com/pkg/errors\"");
}
fn parse_solidity(source: &str) -> (Tree, ImportBlock) {
let grammar = grammar_for(LangId::Solidity);
let mut parser = Parser::new();
parser.set_language(&grammar).unwrap();
let tree = parser.parse(source, None).unwrap();
let block = parse_imports(source, &tree, LangId::Solidity);
(tree, block)
}
#[test]
fn solidity_grammar_node_kinds_are_stable() {
let grammar = grammar_for(LangId::Solidity);
let mut parser = Parser::new();
parser.set_language(&grammar).unwrap();
let src = "import { Foo, Bar as Baz } from \"./A.sol\";\nimport * as N from \"./B.sol\";\nimport \"./C.sol\" as C;\nimport \"./D.sol\";\n";
let tree = parser.parse(src, None).unwrap();
let mut kinds: std::collections::BTreeSet<String> = std::collections::BTreeSet::new();
fn walk(node: tree_sitter::Node, kinds: &mut std::collections::BTreeSet<String>) {
kinds.insert(node.kind().to_string());
let mut c = node.walk();
if c.goto_first_child() {
loop {
walk(c.node(), kinds);
if !c.goto_next_sibling() {
break;
}
}
}
}
walk(tree.root_node(), &mut kinds);
for required in [
"import_directive",
"string",
"identifier",
"as",
"from",
"*",
"{",
"}",
] {
assert!(
kinds.contains(required),
"solidity grammar missing node kind {required:?}; present: {kinds:?}"
);
}
}
#[test]
fn parse_solidity_all_four_forms() {
let (_, block) = parse_solidity(
"import \"./A.sol\";\nimport \"./B.sol\" as B;\nimport * as C from \"./C.sol\";\nimport { Foo, Bar as Baz } from \"./D.sol\";\n",
);
assert_eq!(block.imports.len(), 4);
assert_eq!(block.imports[0].module_path, "./A.sol");
assert_eq!(block.imports[0].kind, ImportKind::SideEffect);
assert_eq!(
block.imports[0].form,
ImportForm::Solidity {
named: vec![],
namespace: None,
alias: None
}
);
assert_eq!(
block.imports[1].form,
ImportForm::Solidity {
named: vec![],
namespace: None,
alias: Some("B".to_string())
}
);
match &block.imports[2].form {
ImportForm::Solidity { namespace, .. } => assert_eq!(namespace.as_deref(), Some("C")),
other => panic!("expected Solidity namespace, got {other:?}"),
}
assert_eq!(block.imports[2].namespace_import.as_deref(), Some("C"));
match &block.imports[3].form {
ImportForm::Solidity { named, .. } => {
assert_eq!(named, &vec!["Foo".to_string(), "Bar as Baz".to_string()]);
}
other => panic!("expected Solidity named, got {other:?}"),
}
assert_eq!(
block.imports[3].names,
vec!["Foo".to_string(), "Bar as Baz".to_string()]
);
}
#[test]
fn generate_solidity_all_forms() {
assert_eq!(
generate_import(
LangId::Solidity,
&ImportRequest::legacy("./A.sol", &[], None, None, false)
),
"import \"./A.sol\";"
);
let names = vec!["Foo".to_string(), "Bar as Baz".to_string()];
assert_eq!(
generate_import(
LangId::Solidity,
&ImportRequest::legacy("./D.sol", &names, None, None, false)
),
"import { Foo, Bar as Baz } from \"./D.sol\";"
);
assert_eq!(
generate_import(
LangId::Solidity,
&ImportRequest::legacy("./C.sol", &[], None, Some("C"), false)
),
"import * as C from \"./C.sol\";"
);
assert_eq!(
generate_import(
LangId::Solidity,
&ImportRequest {
module_path: "./B.sol",
names: &[],
default_import: None,
namespace: None,
alias: Some("B"),
type_only: false,
modifiers: &[],
import_kind: None,
}
),
"import \"./B.sol\" as B;"
);
}
#[test]
fn solidity_round_trips_through_parse_generate() {
for src in [
"import \"./A.sol\";",
"import \"./B.sol\" as B;",
"import * as C from \"./C.sol\";",
"import { Foo, Bar as Baz } from \"./D.sol\";",
] {
let (_, block) = parse_solidity(src);
assert_eq!(block.imports.len(), 1, "parse {src:?}");
let imp = &block.imports[0];
let (namespace, alias) = match &imp.form {
ImportForm::Solidity {
namespace, alias, ..
} => (namespace.as_deref(), alias.as_deref()),
other => panic!("expected Solidity, got {other:?}"),
};
let regenerated = generate_import(
LangId::Solidity,
&ImportRequest {
module_path: &imp.module_path,
names: &imp.names,
default_import: None,
namespace,
alias,
type_only: false,
modifiers: &[],
import_kind: None,
},
);
assert_eq!(regenerated, src, "round-trip mismatch for {src:?}");
}
}
#[test]
fn classify_group_solidity_relative_vs_external() {
assert_eq!(classify_group_solidity("./A.sol"), ImportGroup::Internal);
assert_eq!(
classify_group_solidity("../lib/B.sol"),
ImportGroup::Internal
);
assert_eq!(
classify_group_solidity("@openzeppelin/contracts/token/ERC20/ERC20.sol"),
ImportGroup::External
);
}
}