use std::path::Path;
use super::organize_imports;
use crate::context::AppContext;
use crate::edit;
use crate::imports;
use crate::parser::{detect_language, LangId};
use crate::protocol::{RawRequest, Response};
pub fn handle_add_import(req: &RawRequest, ctx: &AppContext) -> Response {
let op_id = crate::backup::new_op_id();
let file = match req.params.get("file").and_then(|v| v.as_str()) {
Some(f) => f,
None => {
return Response::error(
&req.id,
"invalid_request",
"add_import: missing required param 'file'",
);
}
};
let module = match req.params.get("module").and_then(|v| v.as_str()) {
Some(m) => m,
None => {
return Response::error(
&req.id,
"invalid_request",
"add_import: missing required param 'module'",
);
}
};
let names: Vec<String> = req
.params
.get("names")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect()
})
.unwrap_or_default();
let default_import = req
.params
.get("default_import")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let type_only = req
.params
.get("type_only")
.and_then(|v| v.as_bool())
.unwrap_or(false);
let namespace = req
.params
.get("namespace")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let alias = req
.params
.get("alias")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let mut modifiers: Vec<String> = req
.params
.get("modifiers")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect()
})
.unwrap_or_default();
let import_kind = req
.params
.get("import_kind")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let path = match ctx.validate_path(&req.id, Path::new(file)) {
Ok(path) => path,
Err(resp) => return resp,
};
if !path.exists() {
return Response::error(
&req.id,
"file_not_found",
format!("add_import: file not found: {}", file),
);
}
let lang = match detect_language(&path) {
Some(l) => l,
None => {
return Response::error(
&req.id,
"unsupported_language",
format!(
"add_import: unsupported file extension: {}",
path.extension()
.and_then(|e| e.to_str())
.unwrap_or("<none>")
),
);
}
};
if !imports::is_supported(lang) {
return Response::error(
&req.id,
"unsupported_language",
format!(
"add_import: import management not yet supported for {:?}",
lang
),
);
}
let (module_owned, inferred_import_kind) = if matches!(lang, LangId::C | LangId::Cpp) {
imports::normalize_include_module(module)
} else {
(module.to_string(), None)
};
let module = module_owned.as_str();
let import_kind = import_kind.or_else(|| inferred_import_kind.map(String::from));
if let Err(reason) = validate_module_path_for_add(lang, module) {
return Response::error(
&req.id,
"invalid_request",
format!("add_import: invalid module path '{module}': {reason}"),
);
}
let (source, tree, block) = match imports::parse_file_imports(&path, lang) {
Ok(result) => result,
Err(e) => {
return Response::error(&req.id, e.code(), e.to_string());
}
};
if lang == LangId::Vue {
if let Err(err) = imports::vue_single_script_content_range(&tree) {
return Response::error(&req.id, err.code(), err.message("add_import"));
}
}
if matches!(lang, LangId::CSharp | LangId::Php)
&& organize_imports::imports_span_multiple_code_regions(&source, lang, &block.imports)
{
return Response::error_with_data(
&req.id,
"multi_region_imports",
format!(
"add_import: imports in {file} span multiple code regions; refusing to add because the target region is ambiguous"
),
serde_json::json!({ "file": file }),
);
}
if lang == LangId::Php
&& block.imports.iter().any(|imp| {
imports::php_grouped_use_shares_prefix(imp, module)
|| imports::php_grouped_use_matches_module(imp, module)
})
{
return Response::error_with_data(
&req.id,
"unsupported_grouped_import",
format!(
"add_import: PHP grouped use declarations matching '{module}' are not safe to edit member-wise; expand the grouped use first"
),
serde_json::json!({ "file": file, "module": module }),
);
}
if lang == LangId::Scala
&& imports::scala_block_uses_scala2_dialect(&block)
&& !modifiers.iter().any(|modifier| modifier == "scala2")
{
modifiers.push("scala2".to_string());
}
let import_request = imports::ImportRequest {
module_path: module,
names: &names,
default_import: default_import.as_deref(),
namespace: namespace.as_deref(),
alias: alias.as_deref(),
type_only,
modifiers: &modifiers,
import_kind: import_kind.as_deref(),
};
if imports::is_duplicate_import_request(lang, &block, &import_request) {
log::debug!("add_import: {} (already present)", file);
return Response::success(
&req.id,
serde_json::json!({
"file": file,
"added": false,
"module": module,
"already_present": true,
}),
);
}
let group = match (lang, import_kind.as_deref()) {
(LangId::C | LangId::Cpp, Some(kind)) => imports::classify_group_c_import_kind(kind),
_ => imports::classify_group(lang, module),
};
let target_kind = if type_only {
imports::ImportKind::Type
} else {
imports::ImportKind::Value
};
let merge_target = if !names.is_empty()
&& default_import.is_none()
&& matches!(
lang,
LangId::TypeScript | LangId::Tsx | LangId::JavaScript | LangId::Python | LangId::Rust
) {
block.imports.iter().find(|imp| {
imp.module_path == module
&& imp.kind == target_kind
&& imp.namespace_import.is_none()
&& imp.default_import.is_none()
&& !imp.names.is_empty()
})
} else {
None
};
let namespace_merge_target = if names.is_empty()
&& default_import.is_some()
&& namespace.is_some()
&& matches!(
lang,
LangId::TypeScript | LangId::Tsx | LangId::JavaScript | LangId::Vue
) {
block.imports.iter().find(|imp| {
imp.module_path == module
&& imp.kind == target_kind
&& imp.names.is_empty()
&& imp.default_import.as_deref() == default_import.as_deref()
&& imp.namespace_import.is_none()
})
} else {
None
};
let (insert_offset, replace_end, insert_text, merged_into_existing) = if let Some(existing) =
merge_target
{
let mut merged_names: Vec<String> = existing.names.clone();
for name in &names {
if !merged_names
.iter()
.any(|n| imports::specifier_matches(n, name))
{
merged_names.push(name.clone());
}
}
merged_names
.sort_by(|a, b| imports::specifier_local_name(a).cmp(imports::specifier_local_name(b)));
let merged_line = imports::generate_import_line(
lang,
&existing.module_path,
&merged_names,
None,
type_only,
);
(
existing.byte_range.start,
existing.byte_range.end,
merged_line,
true,
)
} else if let Some(existing) = namespace_merge_target {
let merged_line = imports::generate_import_line_with_namespace(
lang,
&existing.module_path,
&existing.names,
existing.default_import.as_deref(),
namespace.as_deref(),
type_only,
);
(
existing.byte_range.start,
existing.byte_range.end,
merged_line,
true,
)
} else {
let bytes = source.as_bytes();
let skip_one_newline = |mut at: usize| {
if at < bytes.len() && bytes[at] == b'\r' {
at += 1;
}
if at < bytes.len() && bytes[at] == b'\n' {
at += 1;
}
at
};
let (insert_offset, needs_blank_before, needs_blank_after) = if !block.imports.is_empty() {
let (off, blank_before, blank_after) =
imports::find_insertion_point(&source, &block, group, module, type_only);
if matches!(lang, LangId::C | LangId::Cpp) {
(off, false, false)
} else {
(off, blank_before, blank_after)
}
} else if lang == LangId::Vue {
match imports::vue_script_content_range(&tree) {
Some((start, _end)) => (skip_one_newline(start), false, false),
None => imports::find_insertion_point(&source, &block, group, module, type_only),
}
} else if let Some(anchor) = shebang_anchor(lang, &source) {
let at = skip_one_newline(anchor);
let next_is_newline = at < bytes.len() && (bytes[at] == b'\n' || bytes[at] == b'\r');
(at, false, !next_is_newline && at < source.len())
} else if let Some(anchor) = header_prologue_anchor(lang, &source, &tree) {
let at = skip_one_newline(anchor);
let next_is_newline = at < bytes.len() && (bytes[at] == b'\n' || bytes[at] == b'\r');
(at, true, !next_is_newline && at < source.len())
} else if let Some(anchor) = spdx_license_anchor(lang, &source) {
let at = skip_one_newline(anchor);
let next_is_newline = at < bytes.len() && (bytes[at] == b'\n' || bytes[at] == b'\r');
(at, false, !next_is_newline && at < source.len())
} else {
imports::find_insertion_point(&source, &block, group, module, type_only)
};
let import_line = if matches!(lang, LangId::Go) {
let in_group = imports::go_offset_is_in_grouped_import(&source, &tree, insert_offset);
imports::generate_go_import_line_pub(module, default_import.as_deref(), in_group)
} else {
imports::generate_import(lang, &import_request)
};
let mut insert_text = String::new();
if needs_blank_before {
insert_text.push('\n');
}
insert_text.push_str(&import_line);
insert_text.push('\n');
if needs_blank_after {
insert_text.push('\n');
}
(insert_offset, insert_offset, insert_text, false)
};
let _ = merged_into_existing;
let backup_id = match edit::auto_backup(
ctx,
req.session(),
&path,
"add_import: pre-edit backup",
Some(&op_id),
) {
Ok(id) => id,
Err(e) => {
return Response::error(&req.id, e.code(), e.to_string());
}
};
let new_source =
match edit::replace_byte_range(&source, insert_offset, replace_end, &insert_text) {
Ok(s) => s,
Err(e) => return Response::error(&req.id, e.code(), e.to_string()),
};
let mut write_result =
match edit::write_format_validate(&path, &new_source, &ctx.config(), &req.params) {
Ok(r) => r,
Err(e) => {
return Response::error(&req.id, e.code(), e.to_string());
}
};
if let Ok(final_content) = std::fs::read_to_string(&path) {
write_result.lsp_outcome = ctx.lsp_post_write(&path, &final_content, &req.params);
}
log::debug!("add_import: {}", file);
if write_result.rolled_back {
return Response::error(
&req.id,
"generated_invalid_syntax",
format!(
"add_import: adding '{module}' to {file} would produce invalid syntax; file left unchanged"
),
);
}
let mut result = serde_json::json!({
"file": file,
"added": true,
"module": module,
"group": group.label(),
"formatted": write_result.formatted,
});
if let Some(valid) = write_result.syntax_valid {
result["syntax_valid"] = serde_json::json!(valid);
}
if let Some(ref reason) = write_result.format_skipped_reason {
result["format_skipped_reason"] = serde_json::json!(reason);
}
if write_result.validate_requested {
result["validation_errors"] = serde_json::json!(write_result.validation_errors);
}
if let Some(ref reason) = write_result.validate_skipped_reason {
result["validate_skipped_reason"] = serde_json::json!(reason);
}
if let Some(ref id) = backup_id {
result["backup_id"] = serde_json::json!(id);
}
write_result.append_lsp_diagnostics_to(&mut result);
Response::success(&req.id, result)
}
fn shebang_anchor(lang: LangId, source: &str) -> Option<usize> {
if !matches!(
lang,
LangId::TypeScript
| LangId::JavaScript
| LangId::Python
| LangId::Ruby
| LangId::Perl
| LangId::Lua
) || !source.starts_with("#!")
{
return None;
}
Some(source.find('\n').unwrap_or(source.len()))
}
fn spdx_license_anchor(lang: LangId, source: &str) -> Option<usize> {
if lang != LangId::Solidity {
return None;
}
let first_line_end = source.find('\n').unwrap_or(source.len());
let first_line = &source[..first_line_end];
first_line
.trim_start()
.starts_with("// SPDX-License-Identifier:")
.then_some(first_line_end)
}
fn header_prologue_anchor(lang: LangId, source: &str, tree: &tree_sitter::Tree) -> Option<usize> {
let (strong, skippable): (&[&str], &[&str]) = match lang {
LangId::Go => (&["package_clause"], &["comment"]),
LangId::Java => (&["package_declaration"], &["comment"]),
LangId::Kotlin => (&["package_header"], &["comment"]),
LangId::Scala => (&["package_clause"], &["comment"]),
LangId::Solidity => (&["pragma_directive"], &["comment"]),
LangId::Php => (&["php_tag", "namespace_definition"], &["comment"]),
LangId::Perl => (&["package_statement"], &["comment"]),
_ => return None,
};
let root = tree.root_node();
let mut cursor = root.walk();
let mut anchor: Option<usize> = None;
for child in root.named_children(&mut cursor) {
let kind = child.kind();
if strong.contains(&kind) {
anchor = if lang == LangId::Php && kind == "namespace_definition" {
php_braced_namespace_anchor(source, child).or(Some(child.end_byte()))
} else {
Some(child.end_byte())
};
} else if skippable.contains(&kind) {
continue;
} else {
break;
}
}
anchor
}
fn validate_module_path_for_add(lang: LangId, module: &str) -> Result<(), &'static str> {
let module = module.trim();
if uses_path_module_validation(lang) {
if is_absolute_module_path(module) {
return Err("absolute paths are not allowed for this language");
}
return Ok(());
}
if uses_namespace_module_validation(lang) {
if module.starts_with('/') || module.contains('/') {
return Err("filesystem paths are not allowed for this language");
}
if (lang != LangId::Php && module.starts_with('\\')) || has_windows_drive_prefix(module) {
return Err("absolute paths are not allowed for this language");
}
if lang != LangId::Php && module.contains('\\') {
return Err("filesystem paths are not allowed for this language");
}
if module.contains("..") || contains_parent_path_segment(module) {
return Err("parent path traversal segments are not allowed for this language");
}
}
Ok(())
}
fn uses_path_module_validation(lang: LangId) -> bool {
matches!(lang, LangId::C | LangId::Cpp | LangId::Solidity)
}
fn uses_namespace_module_validation(lang: LangId) -> bool {
matches!(
lang,
LangId::Php | LangId::Java | LangId::Kotlin | LangId::Scala | LangId::CSharp
)
}
fn is_absolute_module_path(module: &str) -> bool {
module.starts_with('/') || module.starts_with('\\') || has_windows_drive_absolute_path(module)
}
fn has_windows_drive_absolute_path(module: &str) -> bool {
let bytes = module.as_bytes();
bytes.len() >= 3
&& bytes[0].is_ascii_alphabetic()
&& bytes[1] == b':'
&& matches!(bytes[2], b'\\' | b'/')
}
fn has_windows_drive_prefix(module: &str) -> bool {
let bytes = module.as_bytes();
bytes.len() >= 2 && bytes[0].is_ascii_alphabetic() && bytes[1] == b':'
}
fn contains_parent_path_segment(module: &str) -> bool {
module
.split(|ch| ch == '/' || ch == '\\')
.any(|segment| segment == "..")
}
fn php_braced_namespace_anchor(source: &str, node: tree_sitter::Node<'_>) -> Option<usize> {
let mut cursor = node.walk();
if cursor.goto_first_child() {
loop {
let child = cursor.node();
if child.kind() == "{" {
return Some(child.end_byte());
}
if !cursor.goto_next_sibling() {
break;
}
}
}
source[node.start_byte()..node.end_byte()]
.find('{')
.map(|offset| node.start_byte() + offset + 1)
}