use crate::cst::{CabalCst, CstNodeKind};
use crate::span::{NodeId, Span};
/// Layout convention used by a list-valued field.
///
/// Produced by [`detect_list_style`] and used so that inserted items are
/// formatted the same way as their neighbours.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ListStyle {
    /// The field has no continuation (value) lines; everything is on the
    /// field line itself.
    SingleLine,
    /// Continuation lines start with the separating comma: `, text`.
    LeadingComma,
    /// Continuation lines use commas that are not at line starts: `text,`.
    TrailingComma,
    /// Continuation lines with no commas at all.
    NoComma,
}
/// One replacement of source text.
///
/// The bytes covered by `range` are replaced by `replacement`. An empty range
/// is a pure insertion; an empty replacement is a pure deletion.
#[derive(Clone, Debug)]
pub struct TextEdit {
    /// Byte range of the original source that is replaced.
    pub range: Span,
    /// Text written in place of `range`.
    pub replacement: String,
}
/// A set of [`TextEdit`]s that are applied together against one original
/// source text.
#[derive(Clone, Debug)]
pub struct EditBatch {
    /// Queued edits, in the order they were added.
    edits: Vec<TextEdit>,
}
impl Default for EditBatch {
fn default() -> Self {
Self::new()
}
}
impl EditBatch {
    /// Creates an empty batch.
    pub fn new() -> Self {
        Self { edits: Vec::new() }
    }

    /// Queues a single edit.
    pub fn add(&mut self, edit: TextEdit) {
        self.edits.push(edit);
    }

    /// Queues every edit yielded by `edits`.
    ///
    /// Accepts any iterable of edits; a `Vec<TextEdit>` works as before.
    pub fn add_all(&mut self, edits: impl IntoIterator<Item = TextEdit>) {
        self.edits.extend(edits);
    }

    /// Returns `true` when no edits have been queued.
    pub fn is_empty(&self) -> bool {
        self.edits.is_empty()
    }

    /// Applies every queued edit to `source` and returns the rewritten text.
    ///
    /// All ranges are byte offsets into the *original* `source`. Edits are
    /// applied back to front so that earlier offsets stay valid.
    ///
    /// Edits may touch at a boundary. In particular, a zero-width insertion
    /// may share its offset with other insertions, or with the start of a
    /// replaced range:
    /// - insertions at the same offset appear in the order they were added;
    /// - all of them come before the text of a replacement that starts there.
    ///
    /// # Panics
    ///
    /// Panics if two edits overlap. `String::replace_range` also panics if a
    /// range is out of bounds or not on a `char` boundary of `source`.
    pub fn apply(self, source: &str) -> String {
        if self.edits.is_empty() {
            return source.to_owned();
        }
        let mut ordered: Vec<(usize, TextEdit)> = self.edits.into_iter().enumerate().collect();
        // Sort descending by start, then by end, then by insertion sequence:
        // - the end key puts a replacement before a zero-width insertion at its
        //   start (previously this panicked when the insertion was added first);
        // - the sequence key makes same-offset insertions come out in add order.
        ordered.sort_by_key(|(seq, edit)| {
            std::cmp::Reverse((edit.range.start, edit.range.end, *seq))
        });
        for pair in ordered.windows(2) {
            let (later, earlier) = (&pair[0].1, &pair[1].1);
            assert!(
                later.range.start >= earlier.range.end,
                "overlapping edits: {:?} and {:?}",
                earlier.range,
                later.range,
            );
        }
        let mut result = source.to_owned();
        for (_, edit) in &ordered {
            result.replace_range(edit.range.start..edit.range.end, &edit.replacement);
        }
        result
    }
}
/// Detects how the items of a list-valued field are laid out.
///
/// - No continuation (value) lines → [`ListStyle::SingleLine`].
/// - No comma anywhere in the value → [`ListStyle::NoComma`].
/// - Some continuation line starts with a comma → [`ListStyle::LeadingComma`].
/// - Otherwise, i.e. commas only at line ends or between items on a line →
///   [`ListStyle::TrailingComma`].
pub fn detect_list_style(cst: &CabalCst, field_node: NodeId) -> ListStyle {
    let node = cst.node(field_node);
    debug_assert_eq!(node.kind, CstNodeKind::Field);

    let mut saw_value_line = false;
    let mut has_leading_comma = false;
    let mut has_any_comma = false;
    for &child_id in &node.children {
        let child = cst.node(child_id);
        if child.kind != CstNodeKind::ValueLine {
            continue;
        }
        saw_value_line = true;
        let trimmed = child.content_span.slice(&cst.source).trim();
        has_leading_comma |= trimmed.starts_with(',');
        // Interior commas count too. A line such as `base, text` is still a
        // comma-separated list and must not be extended without separators.
        has_any_comma |= trimmed.contains(',');
    }
    if !saw_value_line {
        return ListStyle::SingleLine;
    }
    if let Some(fv) = node.field_value {
        has_any_comma |= fv.slice(&cst.source).contains(',');
    }

    if !has_any_comma {
        ListStyle::NoComma
    } else if has_leading_comma {
        ListStyle::LeadingComma
    } else {
        ListStyle::TrailingComma
    }
}
/// Extracts the bare name of a list item.
///
/// Leading commas and whitespace are skipped. The name ends at the first
/// whitespace or version-operator character (`>`, `<`, `=`, `^`), and any
/// trailing commas are dropped.
/// For example, `" , text >=2.0"` → `"text"` and `"base,"` → `"base"`.
fn item_name(item: &str) -> &str {
    let body = item.trim().trim_start_matches(',').trim();
    let is_boundary = |c: char| c.is_whitespace() || matches!(c, '>' | '<' | '=' | '^');
    let name = body.split(is_boundary).next().unwrap_or(body);
    name.trim_end_matches(',')
}
/// Strips surrounding whitespace and leading/trailing separator commas from a
/// list item. For example, `" , text >=2.0"` → `"text >=2.0"`.
fn clean_item_text(text: &str) -> &str {
    let without_commas = text.trim().trim_matches(',');
    without_commas.trim()
}
/// Collects a field's list items as `(text, span)` pairs, in source order.
///
/// - A non-blank inline value (on the field line) comes first, paired with
///   its own span.
/// - Each continuation value line follows. Its text is the line's
///   `content_span`, but its span is the line's full `span`.
fn gather_items(cst: &CabalCst, field_node: NodeId) -> Vec<(String, Span)> {
    let node = cst.node(field_node);
    let inline = node
        .field_value
        .filter(|fv| !fv.slice(&cst.source).trim().is_empty())
        .map(|fv| (fv.slice(&cst.source).to_owned(), fv));
    let continuation = node
        .children
        .iter()
        .map(|&id| cst.node(id))
        .filter(|child| child.kind == CstNodeKind::ValueLine)
        .map(|child| (child.content_span.slice(&cst.source).to_owned(), child.span));
    inline.into_iter().chain(continuation).collect()
}
/// Returns the indentation to use for a new continuation line of a field.
///
/// The indentation is chosen as follows:
/// - if the first value line has text before its content, that text is
///   copied exactly;
/// - otherwise that line's `indent` width in spaces is used;
/// - with no value lines at all, the field's own indent plus two spaces.
fn detect_item_indent(cst: &CabalCst, field_node: NodeId) -> String {
    let node = cst.node(field_node);
    let first_value_line = node
        .children
        .iter()
        .map(|&id| cst.node(id))
        .find(|child| child.kind == CstNodeKind::ValueLine);
    match first_value_line {
        Some(line) if line.content_span.start > line.span.start => {
            cst.source[line.span.start..line.content_span.start].to_owned()
        }
        Some(line) => " ".repeat(line.indent),
        None => " ".repeat(node.indent + 2),
    }
}
/// Byte offset just past a field's value.
///
/// This is the end of the field's last value line, or the end of the field
/// node itself when it has no value lines.
fn field_value_end(cst: &CabalCst, field_node: NodeId) -> usize {
    let node = cst.node(field_node);
    node.children
        .iter()
        .rev()
        .map(|&id| cst.node(id))
        .find(|child| child.kind == CstNodeKind::ValueLine)
        .map_or(node.span.end, |last| last.span.end)
}
/// Index at which `new_item` keeps `items` in case-insensitive name order.
///
/// Returns the position of the first existing item whose name (as returned by
/// `item_name`) sorts strictly after the new item's name, or `items.len()`
/// to append.
fn find_sorted_insert_index(items: &[(String, Span)], new_item: &str) -> usize {
    let wanted = item_name(new_item).to_lowercase();
    items
        .iter()
        .position(|(text, _)| wanted < item_name(clean_item_text(text)).to_lowercase())
        .unwrap_or(items.len())
}
pub fn add_list_item(cst: &CabalCst, field_node: NodeId, item: &str, sort: bool) -> Vec<TextEdit> {
let style = detect_list_style(cst, field_node);
let items = gather_items(cst, field_node);
if items.is_empty() {
return add_item_to_empty_field(cst, field_node, item, style);
}
let insert_idx = if sort {
find_sorted_insert_index(&items, item)
} else {
items.len()
};
match style {
ListStyle::SingleLine => add_item_single_line(cst, field_node, &items, item, insert_idx),
ListStyle::LeadingComma => {
add_item_leading_comma(cst, field_node, &items, item, insert_idx)
}
ListStyle::TrailingComma => {
add_item_trailing_comma(cst, field_node, &items, item, insert_idx)
}
ListStyle::NoComma => add_item_no_comma(cst, field_node, &items, item, insert_idx),
}
}
/// Builds the edit that gives an item-less field its first item.
///
/// - Single-line fields get ` item` right after the field line's content.
/// - Every other style gets a new indented continuation line at the end of
///   the field.
fn add_item_to_empty_field(
    cst: &CabalCst,
    field_node: NodeId,
    item: &str,
    style: ListStyle,
) -> Vec<TextEdit> {
    let (at, replacement) = if style == ListStyle::SingleLine {
        (cst.node(field_node).content_span.end, format!(" {item}"))
    } else {
        let indent = detect_item_indent(cst, field_node);
        (field_value_end(cst, field_node), format!("{indent}{item}\n"))
    };
    vec![TextEdit {
        range: Span::new(at, at),
        replacement,
    }]
}
/// Builds the edit that inserts `item` into a single-line, comma-separated
/// field value at entry position `insert_idx`.
///
/// - An index at or past the end appends `, item`, or just ` item` when the
///   value already ends with a comma.
/// - Otherwise `item, ` is inserted at the first non-blank character of the
///   target entry, which keeps the surrounding `, ` spacing intact.
/// - A field without an inline value falls back to `add_item_to_empty_field`.
fn add_item_single_line(
    cst: &CabalCst,
    field_node: NodeId,
    items: &[(String, Span)],
    item: &str,
    insert_idx: usize,
) -> Vec<TextEdit> {
    let node = cst.node(field_node);
    let Some(fv) = node.field_value else {
        return add_item_to_empty_field(cst, field_node, item, ListStyle::SingleLine);
    };
    let fv_text = fv.slice(&cst.source);
    let parts: Vec<&str> = fv_text.split(',').collect();
    let (at, replacement) = if insert_idx >= parts.len() || insert_idx >= items.len() {
        let content = fv_text.trim_end();
        if content.ends_with(',') {
            // Reuse the existing trailing separator instead of doubling it.
            (fv.start + content.len(), format!(" {item}"))
        } else {
            (fv.end, format!(", {item}"))
        }
    } else {
        // Bytes taken by the preceding entries, each plus its comma.
        let preceding: usize = parts[..insert_idx].iter().map(|part| part.len() + 1).sum();
        let target = parts[insert_idx];
        // Land on the entry's first non-blank byte, after the space that
        // follows the previous comma.
        let leading_ws = target.len() - target.trim_start().len();
        (fv.start + preceding + leading_ws, format!("{item}, "))
    };
    vec![TextEdit {
        range: Span::new(at, at),
        replacement,
    }]
}
/// Builds the edits that insert `item` into a leading-comma list at position
/// `insert_idx`. In this style each continuation line reads `, item`.
///
/// Inserting at the front needs special handling, because the first entry
/// usually carries no comma:
/// - the new item takes over the first slot;
/// - the previous first entry is re-emitted as a `, entry` line;
/// - if the first value line already starts with a comma, the new item is
///   simply inserted before it as another `, item` line.
///
/// Other positions get a `, item` line before the target, or at the end.
fn add_item_leading_comma(
    cst: &CabalCst,
    field_node: NodeId,
    items: &[(String, Span)],
    item: &str,
    insert_idx: usize,
) -> Vec<TextEdit> {
    let indent = detect_item_indent(cst, field_node);
    let node = cst.node(field_node);
    let value_lines: Vec<NodeId> = node
        .children
        .iter()
        .copied()
        .filter(|&id| cst.node(id).kind == CstNodeKind::ValueLine)
        .collect();
    let inline_value = node
        .field_value
        .filter(|span| !span.slice(&cst.source).trim().is_empty());
    let comma_prefix = find_leading_comma_prefix(cst, &value_lines, &indent);

    if insert_idx == 0 {
        if let Some(fv) = inline_value {
            // New item replaces the inline value; the old one moves to the
            // first continuation line.
            let old_first = clean_item_text(fv.slice(&cst.source)).to_owned();
            let after_field_line = value_lines.first().map_or_else(
                || field_value_end(cst, field_node),
                |&id| cst.node(id).span.start,
            );
            return vec![
                TextEdit {
                    range: fv,
                    replacement: item.to_owned(),
                },
                TextEdit {
                    range: Span::new(after_field_line, after_field_line),
                    replacement: format!("{comma_prefix}{old_first}\n"),
                },
            ];
        }
        let Some(&first_id) = value_lines.first() else {
            return add_item_to_empty_field(cst, field_node, item, ListStyle::LeadingComma);
        };
        let first_vl = cst.node(first_id);
        let first_text = first_vl.content_span.slice(&cst.source);
        if first_text.trim_start().starts_with(',') {
            // The first line already has a leading comma; keep that shape.
            let at = first_vl.span.start;
            return vec![TextEdit {
                range: Span::new(at, at),
                replacement: format!("{comma_prefix}{item}\n"),
            }];
        }
        let old_first = clean_item_text(first_text);
        let first_item_indent = &cst.source[first_vl.span.start..first_vl.content_span.start];
        return vec![TextEdit {
            range: first_vl.span,
            replacement: format!("{first_item_indent}{item}\n{comma_prefix}{old_first}\n"),
        }];
    }

    let at = match items.get(insert_idx) {
        Some((_, target_span)) => target_span.start,
        None => field_value_end(cst, field_node),
    };
    vec![TextEdit {
        range: Span::new(at, at),
        replacement: format!("{comma_prefix}{item}\n"),
    }]
}
/// Returns the text that precedes the item on a leading-comma continuation
/// line: the line's indentation, the comma(s), and the whitespace after them
/// (e.g. `"    , "`).
///
/// The prefix is copied from the first value line that starts with a comma,
/// so new lines match the existing layout exactly. When no such line exists,
/// `default_indent` followed by `", "` is used.
fn find_leading_comma_prefix(
    cst: &CabalCst,
    value_lines: &[NodeId],
    default_indent: &str,
) -> String {
    for &vl_id in value_lines {
        let vl = cst.node(vl_id);
        // `text` starts exactly at `content_span.start`; only the end is trimmed.
        let text = vl.content_span.slice(&cst.source).trim_end();
        let body = text.trim_start();
        if !body.starts_with(',') {
            continue;
        }
        // Everything before the item itself: leading blanks, the comma(s), and
        // the blanks separating the comma from the item. Previously only the
        // commas were counted, so new lines came out as `  ,item`.
        let item_text = body.trim_start_matches(',').trim_start();
        let prefix_end = vl.content_span.start + (text.len() - item_text.len());
        return cst.source[vl.span.start..prefix_end].to_owned();
    }
    format!("{default_indent}, ")
}
/// Builds the edits that insert `item` into a trailing-comma list at position
/// `insert_idx`. In this style each line reads `item,`.
///
/// When appending:
/// - a missing comma is added to the current last item;
/// - the new item gets a terminating comma only if the old last item had one.
///
/// When inserting before an existing item, `item,` is emitted on its own line.
/// If the target is the inline value on the field line itself, the new item
/// takes that slot and the old value continues on a fresh indented line.
fn add_item_trailing_comma(
    cst: &CabalCst,
    field_node: NodeId,
    items: &[(String, Span)],
    item: &str,
    insert_idx: usize,
) -> Vec<TextEdit> {
    let indent = detect_item_indent(cst, field_node);
    if insert_idx >= items.len() {
        let Some((last_text, last_span)) = items.last() else {
            return Vec::new();
        };
        let last_has_comma = last_text.trim().ends_with(',');
        let mut edits = Vec::new();
        if !last_has_comma {
            let last_content_end = find_content_end_in_span(cst, *last_span);
            edits.push(TextEdit {
                range: Span::new(last_content_end, last_content_end),
                replacement: ",".to_owned(),
            });
        }
        let terminator = if last_has_comma { "," } else { "" };
        let end = field_value_end(cst, field_node);
        edits.push(TextEdit {
            range: Span::new(end, end),
            replacement: format!("{indent}{item}{terminator}\n"),
        });
        return edits;
    }

    let target_start = items[insert_idx].1.start;
    let inline_start = cst
        .node(field_node)
        .field_value
        .filter(|fv| !fv.slice(&cst.source).trim().is_empty())
        .map(|fv| fv.start);
    let replacement = if inline_start == Some(target_start) {
        // The target sits on the field line after the colon. Inserting an
        // indented line there would push it to column 0, so continue it on a
        // new indented line instead.
        format!("{item},\n{indent}")
    } else {
        format!("{indent}{item},\n")
    };
    vec![TextEdit {
        range: Span::new(target_start, target_start),
        replacement,
    }]
}
/// Builds the edit that inserts `item` into a comma-less list at position
/// `insert_idx` (one item per line, no separators).
///
/// - An index at or past the end appends an indented line after the last
///   value line.
/// - Otherwise the new line goes before the target item.
/// - If the target is the inline value on the field line, the new item takes
///   that slot and the old value continues on a fresh indented line.
fn add_item_no_comma(
    cst: &CabalCst,
    field_node: NodeId,
    items: &[(String, Span)],
    item: &str,
    insert_idx: usize,
) -> Vec<TextEdit> {
    let indent = detect_item_indent(cst, field_node);
    let Some((_, target_span)) = items.get(insert_idx) else {
        let end = field_value_end(cst, field_node);
        return vec![TextEdit {
            range: Span::new(end, end),
            replacement: format!("{indent}{item}\n"),
        }];
    };
    let at = target_span.start;
    let inline_start = cst
        .node(field_node)
        .field_value
        .filter(|fv| !fv.slice(&cst.source).trim().is_empty())
        .map(|fv| fv.start);
    let replacement = if inline_start == Some(at) {
        // Inserting `{indent}{item}\n` mid-line would move the inline value to
        // column 0, so the new item takes the inline slot instead.
        format!("{item}\n{indent}")
    } else {
        format!("{indent}{item}\n")
    };
    vec![TextEdit {
        range: Span::new(at, at),
        replacement,
    }]
}
/// Offset just past the last non-whitespace byte inside `span`, or
/// `span.start` when the span is entirely whitespace.
fn find_content_end_in_span(cst: &CabalCst, span: Span) -> usize {
    let covered = &cst.source[span.start..span.end];
    let trailing_ws = covered.len() - covered.trim_end().len();
    span.end - trailing_ws
}
/// Builds the edits that remove one item from a list-valued field, following
/// the field's [`ListStyle`].
///
/// Items are matched case-insensitively on their name. An item whose name
/// equals `item_prefix` is preferred. Otherwise the first item whose name
/// starts with it is chosen; this stops `text` from removing an earlier
/// `text-show`.
///
/// Returns no edits when nothing matches or when `item_prefix` is empty
/// (an empty prefix would otherwise match, and delete, the first item).
pub fn remove_list_item(cst: &CabalCst, field_node: NodeId, item_prefix: &str) -> Vec<TextEdit> {
    let style = detect_list_style(cst, field_node);
    if style == ListStyle::SingleLine {
        return remove_item_single_line(cst, field_node, item_prefix);
    }
    let prefix_lower = item_prefix.to_lowercase();
    if prefix_lower.is_empty() {
        return Vec::new();
    }
    let items = gather_items(cst, field_node);
    let names: Vec<String> = items
        .iter()
        .map(|(text, _)| item_name(clean_item_text(text)).to_lowercase())
        .collect();
    let Some(remove_idx) = names
        .iter()
        .position(|name| *name == prefix_lower)
        .or_else(|| names.iter().position(|name| name.starts_with(&prefix_lower)))
    else {
        return Vec::new();
    };
    match style {
        ListStyle::SingleLine => unreachable!("single-line fields are handled above"),
        ListStyle::LeadingComma => remove_item_leading_comma(cst, field_node, &items, remove_idx),
        ListStyle::TrailingComma => remove_item_trailing_comma(cst, field_node, &items, remove_idx),
        ListStyle::NoComma => remove_item_no_comma(&items, remove_idx),
    }
}
/// Builds the edit that removes the matching entry from a single-line,
/// comma-separated field value.
///
/// Matching is case-insensitive on the entry's name, and an exact match wins
/// over a prefix match.
/// - Removing the only entry clears everything after the field's colon.
/// - Otherwise the remaining entries are re-joined with `", "`.
///
/// Returns no edits when the field has no inline value, nothing matches, or
/// `item_prefix` is empty.
fn remove_item_single_line(cst: &CabalCst, field_node: NodeId, item_prefix: &str) -> Vec<TextEdit> {
    let node = cst.node(field_node);
    let Some(fv) = node.field_value else {
        return Vec::new();
    };
    let prefix_lower = item_prefix.to_lowercase();
    if prefix_lower.is_empty() {
        return Vec::new();
    }
    let fv_text = fv.slice(&cst.source);
    let parts: Vec<&str> = fv_text.split(',').collect();
    let names: Vec<String> = parts
        .iter()
        .map(|part| item_name(part.trim()).to_lowercase())
        .collect();
    let Some(part_idx) = names
        .iter()
        .position(|name| *name == prefix_lower)
        .or_else(|| names.iter().position(|name| name.starts_with(&prefix_lower)))
    else {
        return Vec::new();
    };

    if parts.len() <= 1 {
        // Removing the only entry: clear from just after the colon to the
        // end of the value.
        let name_end = node
            .field_name
            .map_or(node.content_span.start, |s| s.end);
        let after_name = &cst.source[name_end..node.content_span.end];
        let colon_end = after_name
            .find(':')
            .map_or(node.content_span.end, |pos| name_end + pos + 1);
        return vec![TextEdit {
            range: Span::new(colon_end, fv.end),
            replacement: String::new(),
        }];
    }

    let remaining: Vec<&str> = parts
        .iter()
        .enumerate()
        .filter(|&(i, _)| i != part_idx)
        .map(|(_, part)| part.trim())
        .collect();
    vec![TextEdit {
        range: fv,
        replacement: remaining.join(", "),
    }]
}
/// Builds the edits that remove item `remove_idx` from a leading-comma list.
///
/// Removing the first entry needs special handling, because that entry
/// usually has no comma:
/// - an inline first value is replaced by the second entry, whose line is
///   then deleted;
/// - a comma-less first value line is deleted, and the second line is
///   rewritten without its comma at the first line's indentation;
/// - a first line that already carries a comma is simply deleted.
///
/// Any other entry's whole line is deleted. An out-of-range index yields no
/// edits.
fn remove_item_leading_comma(
    cst: &CabalCst,
    field_node: NodeId,
    items: &[(String, Span)],
    remove_idx: usize,
) -> Vec<TextEdit> {
    let delete = |range: Span| TextEdit {
        range,
        replacement: String::new(),
    };
    let Some((_, target_span)) = items.get(remove_idx) else {
        return Vec::new();
    };
    let node = cst.node(field_node);
    let inline_value = node
        .field_value
        .filter(|span| !span.slice(&cst.source).trim().is_empty());

    if items.len() == 1 {
        // Sole item: when it is inline, its span is the field value itself.
        return vec![delete(inline_value.unwrap_or(*target_span))];
    }
    if remove_idx != 0 {
        return vec![delete(*target_span)];
    }

    let (second_text, second_span) = &items[1];
    if let Some(fv) = inline_value {
        // Promote the second entry onto the field line.
        return vec![
            TextEdit {
                range: fv,
                replacement: clean_item_text(second_text).to_owned(),
            },
            delete(*second_span),
        ];
    }

    let Some(first_vl) = node
        .children
        .iter()
        .map(|&id| cst.node(id))
        .find(|child| child.kind == CstNodeKind::ValueLine)
    else {
        return Vec::new();
    };
    let first_text = first_vl.content_span.slice(&cst.source);
    if first_text.trim_start().starts_with(',') {
        // Every line carries a comma, so the second line already has the
        // right shape; just drop the first one.
        return vec![delete(items[0].1)];
    }
    let first_indent = &cst.source[first_vl.span.start..first_vl.content_span.start];
    let promoted = clean_item_text(second_text);
    vec![
        delete(items[0].1),
        TextEdit {
            range: *second_span,
            replacement: format!("{first_indent}{promoted}\n"),
        },
    ]
}
/// Builds the edits that remove item `remove_idx` from a trailing-comma list.
///
/// - Removing the last item also drops the now-dangling comma on the new last
///   item, unless the list used a terminating comma on every item.
/// - Removing an inline first value (on the field line) promotes the second
///   item onto the field line instead of leaving `field: ` empty.
/// - Any other item's span is deleted.
///
/// An out-of-range index yields no edits.
fn remove_item_trailing_comma(
    cst: &CabalCst,
    field_node: NodeId,
    items: &[(String, Span)],
    remove_idx: usize,
) -> Vec<TextEdit> {
    let delete = |range: Span| TextEdit {
        range,
        replacement: String::new(),
    };
    let Some((target_text, target_span)) = items.get(remove_idx) else {
        return Vec::new();
    };
    if items.len() == 1 {
        return vec![delete(*target_span)];
    }

    let last_idx = items.len() - 1;
    if remove_idx == last_idx {
        let (prev_text, prev_span) = &items[last_idx - 1];
        if !target_text.trim().ends_with(',') && prev_text.trim().ends_with(',') {
            // The previous item becomes last; drop its separator comma.
            let content_end = find_content_end_in_span(cst, *prev_span);
            return vec![
                delete(Span::new(content_end.saturating_sub(1), content_end)),
                delete(*target_span),
            ];
        }
        return vec![delete(*target_span)];
    }

    if remove_idx == 0 {
        let inline_value = cst
            .node(field_node)
            .field_value
            .filter(|fv| !fv.slice(&cst.source).trim().is_empty());
        if let Some(fv) = inline_value {
            // Move the second item, with its own comma if any, onto the
            // field line and delete its old line.
            let (second_text, second_span) = &items[1];
            return vec![
                TextEdit {
                    range: fv,
                    replacement: second_text.trim().to_owned(),
                },
                delete(*second_span),
            ];
        }
    }
    vec![delete(*target_span)]
}
/// Builds the edit that deletes item `remove_idx` of a comma-less list by
/// blanking its span. An out-of-range index yields no edits.
fn remove_item_no_comma(items: &[(String, Span)], remove_idx: usize) -> Vec<TextEdit> {
    items
        .get(remove_idx)
        .map(|&(_, range)| {
            vec![TextEdit {
                range,
                replacement: String::new(),
            }]
        })
        .unwrap_or_default()
}
/// Builds the edit that sets a field's value to `value`.
///
/// - A field with an inline value has that value replaced.
/// - A field with no value gets ` value` appended after its content.
/// - A field with continuation value lines has everything from the start of
///   its value through its last value line replaced. This prevents stale
///   lines from surviving beneath the new value. The original line ending is
///   kept.
pub fn set_field_value(cst: &CabalCst, field_node: NodeId, value: &str) -> TextEdit {
    let node = cst.node(field_node);
    debug_assert_eq!(node.kind, CstNodeKind::Field);

    let last_value_line_end = node
        .children
        .iter()
        .rev()
        .map(|&id| cst.node(id))
        .find(|child| child.kind == CstNodeKind::ValueLine)
        .map(|child| child.span.end);
    if let Some(end) = last_value_line_end {
        let start = node
            .field_value
            .map_or(node.content_span.end, |fv| fv.start);
        // Separate the value from a bare `name:` with one space.
        let lead = if cst.source[..start].ends_with(':') { " " } else { "" };
        let replaced = &cst.source[start..end];
        let eol = if replaced.ends_with("\r\n") {
            "\r\n"
        } else if replaced.ends_with('\n') {
            "\n"
        } else {
            ""
        };
        return TextEdit {
            range: Span::new(start, end),
            replacement: format!("{lead}{value}{eol}"),
        };
    }

    if let Some(fv) = node.field_value {
        TextEdit {
            range: fv,
            replacement: value.to_owned(),
        }
    } else {
        let insert_at = node.content_span.end;
        TextEdit {
            range: Span::new(insert_at, insert_at),
            replacement: format!(" {value}"),
        }
    }
}
/// Builds the edit that adds a top-level `field_name: field_value` line.
///
/// The line goes right after the last non-blank top-level node that precedes
/// the first section. Blank lines separating the header from the first
/// section therefore stay between the new field and that section. A leading
/// newline is added if the preceding line has no line terminator.
pub fn add_field_to_root(cst: &CabalCst, field_name: &str, field_value: &str) -> TextEdit {
    let root = cst.node(cst.root);
    let mut insert_at = 0usize;
    for &child_id in &root.children {
        let child = cst.node(child_id);
        match child.kind {
            CstNodeKind::Section => break,
            // Do not advance past blank lines, so the separator stays below
            // the new field.
            CstNodeKind::BlankLine => {}
            _ => insert_at = child.span.end,
        }
    }
    // Never glue the new field onto a last line that lacks a newline.
    let separator = if insert_at > 0 && !cst.source[..insert_at].ends_with('\n') {
        "\n"
    } else {
        ""
    };
    TextEdit {
        range: Span::new(insert_at, insert_at),
        replacement: format!("{separator}{field_name}: {field_value}\n"),
    }
}
/// Builds the edit that adds `field_name: field_value` to a section.
///
/// The new line is indented like the section's existing fields or imports,
/// and placed where `find_field_insertion_point` says.
pub fn add_field_to_section(
    cst: &CabalCst,
    section_node: NodeId,
    field_name: &str,
    field_value: &str,
) -> TextEdit {
    debug_assert_eq!(cst.node(section_node).kind, CstNodeKind::Section);
    let at = find_field_insertion_point(cst, section_node);
    let indent = find_section_field_indent(cst, section_node);
    TextEdit {
        range: Span::new(at, at),
        replacement: format!("{indent}{field_name}: {field_value}\n"),
    }
}
/// Indentation for a new field in a section.
///
/// Uses the indent of the section's first field or import; if it has none,
/// uses the section's own indent plus two spaces.
fn find_section_field_indent(cst: &CabalCst, section_node: NodeId) -> String {
    let section = cst.node(section_node);
    // The original iterated `§ion.children`, a mis-encoded `&section`
    // that did not compile.
    let width = section
        .children
        .iter()
        .map(|&id| cst.node(id))
        .find(|child| matches!(child.kind, CstNodeKind::Field | CstNodeKind::Import))
        .map_or(section.indent + 2, |child| child.indent);
    " ".repeat(width)
}
/// Byte offset at which a new field is inserted into a section.
///
/// - An empty section uses the end of the section node.
/// - Otherwise the offset is just after the last non-blank child that
///   precedes the first conditional.
/// - If there is no such child, the start of the first child is used: the
///   first body line, never the header itself.
///
/// Trailing blank lines therefore stay below the new field.
fn find_field_insertion_point(cst: &CabalCst, section_node: NodeId) -> usize {
    let section = cst.node(section_node);
    // The original iterated `§ion.children`, a mis-encoded `&section`
    // that did not compile.
    let Some(&first_child) = section.children.first() else {
        return section.span.end;
    };
    // Previously this started at `section.span.start`, so a body that opened
    // with a conditional received the field above the section header.
    let mut insert_at = cst.node(first_child).span.start;
    for &child_id in &section.children {
        let child = cst.node(child_id);
        match child.kind {
            CstNodeKind::Conditional => break,
            CstNodeKind::BlankLine => {}
            _ => insert_at = child.span.end,
        }
    }
    insert_at
}
/// Builds the edit that appends a new section at the end of the file.
///
/// The header is `keyword` optionally followed by ` name`, and each pair in
/// `fields` becomes an `indent`-space-indented `name: value` line. When the
/// file is non-empty, its last line is terminated and one blank line is left
/// before the header.
pub fn add_section(
    cst: &CabalCst,
    keyword: &str,
    name: Option<&str>,
    fields: &[(&str, &str)],
    indent: usize,
) -> TextEdit {
    let source = &cst.source;
    let mut text = String::new();
    if !source.is_empty() {
        if !source.ends_with('\n') {
            text.push('\n');
        }
        if !source.ends_with("\n\n") {
            text.push('\n');
        }
    }

    text.push_str(keyword);
    if let Some(section_name) = name {
        text.push(' ');
        text.push_str(section_name);
    }
    text.push('\n');

    let indent_str = " ".repeat(indent);
    for &(field_name, field_value) in fields {
        text.push_str(&indent_str);
        text.push_str(field_name);
        text.push_str(": ");
        text.push_str(field_value);
        text.push('\n');
    }

    let end = source.len();
    TextEdit {
        range: Span::new(end, end),
        replacement: text,
    }
}
/// Finds a direct child field of `parent_node` by name.
///
/// Names are compared after `normalize_field_name`, which lowercases and
/// maps `_` to `-`. Returns the first match.
pub fn find_field(cst: &CabalCst, parent_node: NodeId, field_name: &str) -> Option<NodeId> {
    let wanted = normalize_field_name(field_name);
    let parent = cst.node(parent_node);
    parent.children.iter().copied().find(|&child_id| {
        let child = cst.node(child_id);
        child.kind == CstNodeKind::Field
            && matches!(
                child.field_name,
                Some(name_span) if normalize_field_name(name_span.slice(&cst.source)) == wanted
            )
    })
}
/// Finds a top-level section by keyword and, optionally, by name.
///
/// Both comparisons are ASCII case-insensitive. With `name == None`, the
/// first section with a matching keyword is returned.
pub fn find_section(cst: &CabalCst, keyword: &str, name: Option<&str>) -> Option<NodeId> {
    let root = cst.node(cst.root);
    root.children.iter().copied().find(|&child_id| {
        let child = cst.node(child_id);
        if child.kind != CstNodeKind::Section {
            return false;
        }
        let Some(kw_span) = child.section_keyword else {
            return false;
        };
        if !kw_span.slice(&cst.source).eq_ignore_ascii_case(keyword) {
            return false;
        }
        match name {
            None => true,
            Some(wanted) => matches!(
                child.section_arg,
                Some(arg_span) if arg_span.slice(&cst.source).eq_ignore_ascii_case(wanted)
            ),
        }
    })
}
/// Canonical form of a field name for comparisons: lowercase, with every `_`
/// replaced by `-`.
fn normalize_field_name(name: &str) -> String {
    name.to_lowercase()
        .chars()
        .map(|c| if c == '_' { '-' } else { c })
        .collect()
}
// Unit tests for the edit helpers above.
//
// Most tests follow the same pattern:
// - parse a small cabal snippet with `parse::parse`;
// - run one helper and apply its edits via `EditBatch`;
// - check the resulting text.
// Several tests also check that re-parsing and rendering reproduces the new
// text byte-for-byte.
//
// NOTE(review): in this copy of the file the multi-line fixture strings carry
// no indentation, although cabal section fields are normally indented under
// their section header. Confirm the fixtures against the canonical source.
#[cfg(test)]
mod tests {
use super::*;
use crate::parse;
// Applies `edits` to `source` through an `EditBatch`.
fn apply_edits(source: &str, edits: Vec<TextEdit>) -> String {
let mut batch = EditBatch::new();
batch.add_all(edits);
batch.apply(source)
}
// --- EditBatch: application order and overlap detection ---
#[test]
fn edit_batch_empty() {
let source = "hello world";
let batch = EditBatch::new();
assert_eq!(batch.apply(source), "hello world");
}
#[test]
fn edit_batch_single_insert() {
let source = "hello world";
let mut batch = EditBatch::new();
batch.add(TextEdit {
range: Span::new(5, 5),
replacement: ",".to_owned(),
});
assert_eq!(batch.apply(source), "hello, world");
}
#[test]
fn edit_batch_single_replace() {
let source = "hello world";
let mut batch = EditBatch::new();
batch.add(TextEdit {
range: Span::new(6, 11),
replacement: "rust".to_owned(),
});
assert_eq!(batch.apply(source), "hello rust");
}
#[test]
fn edit_batch_single_delete() {
let source = "hello world";
let mut batch = EditBatch::new();
batch.add(TextEdit {
range: Span::new(5, 6),
replacement: String::new(),
});
assert_eq!(batch.apply(source), "helloworld");
}
#[test]
fn edit_batch_multiple_non_overlapping() {
let source = "aaa bbb ccc";
let mut batch = EditBatch::new();
batch.add(TextEdit {
range: Span::new(0, 3),
replacement: "xxx".to_owned(),
});
batch.add(TextEdit {
range: Span::new(8, 11),
replacement: "zzz".to_owned(),
});
assert_eq!(batch.apply(source), "xxx bbb zzz");
}
// Ranges 0..7 and 5..11 share bytes 5..7, so `apply` must panic.
#[test]
#[should_panic(expected = "overlapping edits")]
fn edit_batch_overlapping_panics() {
let source = "hello world";
let mut batch = EditBatch::new();
batch.add(TextEdit {
range: Span::new(0, 7),
replacement: "hi".to_owned(),
});
batch.add(TextEdit {
range: Span::new(5, 11),
replacement: "there".to_owned(),
});
batch.apply(source);
}
// --- detect_list_style: one fixture per ListStyle variant ---
#[test]
fn detect_style_single_line() {
let src = "\
library
build-depends: base >=4.14, text >=2.0, aeson ^>=2.2
";
let result = parse::parse(src);
let section = result.cst.node(result.cst.root).children[0];
let field = find_field(&result.cst, section, "build-depends").unwrap();
assert_eq!(detect_list_style(&result.cst, field), ListStyle::SingleLine);
}
#[test]
fn detect_style_leading_comma() {
let src = "\
library
build-depends:
base >=4.14
, text >=2.0
, aeson ^>=2.2
";
let result = parse::parse(src);
let section = result.cst.node(result.cst.root).children[0];
let field = find_field(&result.cst, section, "build-depends").unwrap();
assert_eq!(
detect_list_style(&result.cst, field),
ListStyle::LeadingComma
);
}
#[test]
fn detect_style_trailing_comma() {
let src = "\
library
build-depends:
base >=4.14,
text >=2.0,
aeson ^>=2.2
";
let result = parse::parse(src);
let section = result.cst.node(result.cst.root).children[0];
let field = find_field(&result.cst, section, "build-depends").unwrap();
assert_eq!(
detect_list_style(&result.cst, field),
ListStyle::TrailingComma
);
}
#[test]
fn detect_style_no_comma() {
let src = "\
library
exposed-modules:
Data.Map
Data.Set
";
let result = parse::parse(src);
let section = result.cst.node(result.cst.root).children[0];
let field = find_field(&result.cst, section, "exposed-modules").unwrap();
assert_eq!(detect_list_style(&result.cst, field), ListStyle::NoComma);
}
// --- set_field_value on scalar fields ---
#[test]
fn set_scalar_field_value() {
let src = "name: foo\nversion: 0.1.0.0\n";
let result = parse::parse(src);
let field = find_field(&result.cst, result.cst.root, "version").unwrap();
let edit = set_field_value(&result.cst, field, "1.0.0.0");
let new_src = apply_edits(src, vec![edit]);
assert_eq!(new_src, "name: foo\nversion: 1.0.0.0\n");
}
#[test]
fn set_field_value_empty_field() {
let src = "name:\nversion: 0.1.0.0\n";
let result = parse::parse(src);
let field = find_field(&result.cst, result.cst.root, "name").unwrap();
let edit = set_field_value(&result.cst, field, "my-package");
let new_src = apply_edits(src, vec![edit]);
assert_eq!(new_src, "name: my-package\nversion: 0.1.0.0\n");
}
// --- add_list_item, per list style ---
#[test]
fn add_module_no_comma() {
let src = "\
library
exposed-modules:
Data.Map
Data.Set
";
let result = parse::parse(src);
let section = result.cst.node(result.cst.root).children[0];
let field = find_field(&result.cst, section, "exposed-modules").unwrap();
let edits = add_list_item(&result.cst, field, "Data.List", true);
let new_src = apply_edits(src, edits);
assert!(new_src.contains("Data.List"));
let re_parsed = parse::parse(&new_src);
assert_eq!(re_parsed.cst.render(), new_src);
}
// "Data.Text" sorts after both existing modules, so it must be appended.
#[test]
fn add_module_no_comma_end() {
let src = "\
library
exposed-modules:
Data.Map
Data.Set
";
let result = parse::parse(src);
let section = result.cst.node(result.cst.root).children[0];
let field = find_field(&result.cst, section, "exposed-modules").unwrap();
let edits = add_list_item(&result.cst, field, "Data.Text", true);
let new_src = apply_edits(src, edits);
let map_pos = new_src.find("Data.Map").unwrap();
let set_pos = new_src.find("Data.Set").unwrap();
let text_pos = new_src.find("Data.Text").unwrap();
assert!(map_pos < set_pos);
assert!(set_pos < text_pos);
}
// Appending to a trailing-comma list must add the missing comma to the old
// last item.
#[test]
fn add_dep_trailing_comma_end() {
let src = "\
library
build-depends:
base >=4.14,
text >=2.0,
aeson ^>=2.2
";
let result = parse::parse(src);
let section = result.cst.node(result.cst.root).children[0];
let field = find_field(&result.cst, section, "build-depends").unwrap();
let edits = add_list_item(&result.cst, field, "zlib ^>=0.7", true);
let new_src = apply_edits(src, edits);
assert!(new_src.contains("zlib ^>=0.7"));
assert!(new_src.contains("aeson ^>=2.2,"));
let re_parsed = parse::parse(&new_src);
assert_eq!(re_parsed.cst.render(), new_src);
}
#[test]
fn add_dep_leading_comma_end() {
let src = "\
library
build-depends:
base >=4.14
, text >=2.0
, aeson ^>=2.2
";
let result = parse::parse(src);
let section = result.cst.node(result.cst.root).children[0];
let field = find_field(&result.cst, section, "build-depends").unwrap();
let edits = add_list_item(&result.cst, field, "zlib ^>=0.7", true);
let new_src = apply_edits(src, edits);
assert!(new_src.contains("zlib ^>=0.7"));
let re_parsed = parse::parse(&new_src);
assert_eq!(re_parsed.cst.render(), new_src);
}
#[test]
fn add_dep_single_line_end() {
let src = "\
library
build-depends: base >=4.14, text >=2.0
";
let result = parse::parse(src);
let section = result.cst.node(result.cst.root).children[0];
let field = find_field(&result.cst, section, "build-depends").unwrap();
let edits = add_list_item(&result.cst, field, "aeson ^>=2.2", false);
let new_src = apply_edits(src, edits);
assert!(new_src.contains("aeson ^>=2.2"));
assert!(new_src.contains("text >=2.0, aeson ^>=2.2"));
}
// --- remove_list_item, per list style ---
#[test]
fn remove_module_no_comma() {
let src = "\
library
exposed-modules:
Data.Map
Data.Set
Data.Text
";
let result = parse::parse(src);
let section = result.cst.node(result.cst.root).children[0];
let field = find_field(&result.cst, section, "exposed-modules").unwrap();
let edits = remove_list_item(&result.cst, field, "Data.Set");
let new_src = apply_edits(src, edits);
assert!(!new_src.contains("Data.Set"));
assert!(new_src.contains("Data.Map"));
assert!(new_src.contains("Data.Text"));
let re_parsed = parse::parse(&new_src);
assert_eq!(re_parsed.cst.render(), new_src);
}
#[test]
fn remove_dep_trailing_comma_middle() {
let src = "\
library
build-depends:
base >=4.14,
text >=2.0,
aeson ^>=2.2
";
let result = parse::parse(src);
let section = result.cst.node(result.cst.root).children[0];
let field = find_field(&result.cst, section, "build-depends").unwrap();
let edits = remove_list_item(&result.cst, field, "text");
let new_src = apply_edits(src, edits);
assert!(!new_src.contains("text"));
assert!(new_src.contains("base"));
assert!(new_src.contains("aeson"));
let re_parsed = parse::parse(&new_src);
assert_eq!(re_parsed.cst.render(), new_src);
}
#[test]
fn remove_dep_trailing_comma_last() {
let src = "\
library
build-depends:
base >=4.14,
text >=2.0,
aeson ^>=2.2
";
let result = parse::parse(src);
let section = result.cst.node(result.cst.root).children[0];
let field = find_field(&result.cst, section, "build-depends").unwrap();
let edits = remove_list_item(&result.cst, field, "aeson");
let new_src = apply_edits(src, edits);
assert!(!new_src.contains("aeson"));
assert!(new_src.contains("base"));
assert!(new_src.contains("text"));
let re_parsed = parse::parse(&new_src);
assert_eq!(re_parsed.cst.render(), new_src);
}
#[test]
fn remove_dep_single_line_middle() {
let src = "\
library
build-depends: base >=4.14, text >=2.0, aeson ^>=2.2
";
let result = parse::parse(src);
let section = result.cst.node(result.cst.root).children[0];
let field = find_field(&result.cst, section, "build-depends").unwrap();
let edits = remove_list_item(&result.cst, field, "text");
let new_src = apply_edits(src, edits);
assert!(!new_src.contains("text"));
assert!(new_src.contains("base >=4.14, aeson ^>=2.2"));
}
// --- Adding fields and sections ---
#[test]
fn add_field_to_section_basic() {
let src = "\
library
exposed-modules: Foo
build-depends: base
";
let result = parse::parse(src);
let section = result.cst.node(result.cst.root).children[0];
let edit = add_field_to_section(&result.cst, section, "default-language", "GHC2021");
let new_src = apply_edits(src, vec![edit]);
assert!(new_src.contains("default-language: GHC2021"));
let re_parsed = parse::parse(&new_src);
assert_eq!(re_parsed.cst.render(), new_src);
}
#[test]
fn add_new_section() {
let src = "\
cabal-version: 3.0
name: foo
version: 0.1.0.0
";
let result = parse::parse(src);
let edit = add_section(
&result.cst,
"library",
None,
&[
("exposed-modules", "Foo"),
("build-depends", "base"),
("hs-source-dirs", "src"),
],
2,
);
let new_src = apply_edits(src, vec![edit]);
assert!(new_src.contains("library\n"));
assert!(new_src.contains(" exposed-modules: Foo\n"));
assert!(new_src.contains(" build-depends: base\n"));
let re_parsed = parse::parse(&new_src);
assert_eq!(re_parsed.cst.render(), new_src);
}
#[test]
fn add_named_section() {
let src = "\
cabal-version: 3.0
name: foo
version: 0.1.0.0
";
let result = parse::parse(src);
let edit = add_section(
&result.cst,
"executable",
Some("my-exe"),
&[("main-is", "Main.hs"), ("build-depends", "base, foo")],
2,
);
let new_src = apply_edits(src, vec![edit]);
assert!(new_src.contains("executable my-exe\n"));
assert!(new_src.contains(" main-is: Main.hs\n"));
}
// --- Lookup helpers: case-insensitive, `_`/`-` tolerant ---
#[test]
fn find_field_case_insensitive() {
let src = "Name: foo\nVersion: 0.1.0.0\n";
let result = parse::parse(src);
assert!(find_field(&result.cst, result.cst.root, "name").is_some());
assert!(find_field(&result.cst, result.cst.root, "NAME").is_some());
}
#[test]
fn find_field_underscore_hyphen() {
let src = "build-depends: base\n";
let result = parse::parse(src);
assert!(find_field(&result.cst, result.cst.root, "build_depends").is_some());
assert!(find_field(&result.cst, result.cst.root, "build-depends").is_some());
}
#[test]
fn find_section_library() {
let src = "\
cabal-version: 3.0
name: foo
version: 0.1.0.0
library
exposed-modules: Foo
";
let result = parse::parse(src);
assert!(find_section(&result.cst, "library", None).is_some());
}
#[test]
fn find_section_named_executable() {
let src = "\
executable my-exe
main-is: Main.hs
";
let result = parse::parse(src);
assert!(find_section(&result.cst, "executable", Some("my-exe")).is_some());
assert!(find_section(&result.cst, "executable", Some("other")).is_none());
}
// --- Round trips: adding then removing the same item must restore the
// original text exactly ---
#[test]
fn round_trip_add_remove_no_comma() {
let src = "\
library
exposed-modules:
Data.Map
Data.Set
";
let result = parse::parse(src);
let section = result.cst.node(result.cst.root).children[0];
let field = find_field(&result.cst, section, "exposed-modules").unwrap();
let edits = add_list_item(&result.cst, field, "Data.List", true);
let added_src = apply_edits(src, edits);
assert!(added_src.contains("Data.List"));
let result2 = parse::parse(&added_src);
let section2 = result2.cst.node(result2.cst.root).children[0];
let field2 = find_field(&result2.cst, section2, "exposed-modules").unwrap();
let edits2 = remove_list_item(&result2.cst, field2, "Data.List");
let removed_src = apply_edits(&added_src, edits2);
assert_eq!(
removed_src, src,
"round-trip add+remove should restore original"
);
}
#[test]
fn round_trip_add_remove_trailing_comma() {
let src = "\
library
build-depends:
base >=4.14,
aeson ^>=2.2
";
let result = parse::parse(src);
let section = result.cst.node(result.cst.root).children[0];
let field = find_field(&result.cst, section, "build-depends").unwrap();
let edits = add_list_item(&result.cst, field, "text >=2.0", true);
let added_src = apply_edits(src, edits);
assert!(added_src.contains("text >=2.0"));
let result2 = parse::parse(&added_src);
let section2 = result2.cst.node(result2.cst.root).children[0];
let field2 = find_field(&result2.cst, section2, "build-depends").unwrap();
let edits2 = remove_list_item(&result2.cst, field2, "text");
let removed_src = apply_edits(&added_src, edits2);
assert_eq!(
removed_src, src,
"round-trip add+remove should restore original"
);
}
// --- Small helpers and edge cases ---
#[test]
fn item_name_basic() {
assert_eq!(item_name("base >=4.14"), "base");
assert_eq!(item_name("aeson ^>=2.2"), "aeson");
assert_eq!(item_name(" , text >=2.0"), "text");
assert_eq!(item_name("Data.Map"), "Data.Map");
assert_eq!(item_name("base,"), "base");
}
#[test]
fn add_to_empty_field_single_line() {
let src = "\
library
build-depends:
";
let result = parse::parse(src);
let section = result.cst.node(result.cst.root).children[0];
let field = find_field(&result.cst, section, "build-depends").unwrap();
let edits = add_list_item(&result.cst, field, "base >=4.14", false);
let new_src = apply_edits(src, edits);
assert!(new_src.contains("base >=4.14"));
let re_parsed = parse::parse(&new_src);
assert_eq!(re_parsed.cst.render(), new_src);
}
// "Data.Aeson" sorts before every existing module, so it must come first.
#[test]
fn add_list_item_sorted_beginning() {
let src = "\
library
exposed-modules:
Data.Map
Data.Set
";
let result = parse::parse(src);
let section = result.cst.node(result.cst.root).children[0];
let field = find_field(&result.cst, section, "exposed-modules").unwrap();
let edits = add_list_item(&result.cst, field, "Data.Aeson", true);
let new_src = apply_edits(src, edits);
let aeson_pos = new_src.find("Data.Aeson").unwrap();
let map_pos = new_src.find("Data.Map").unwrap();
assert!(aeson_pos < map_pos);
}
}