use ide_db::source_change::SourceChangeBuilder;
use syntax::{
NodeOrToken, SyntaxToken, T, TextRange, algo,
ast::{self, AstNode, edit::AstNodeEdit},
};
use crate::{AssistContext, AssistId, Assists};
/// The target the assist resolved from the cursor position or selection.
enum WrapUnwrapOption {
/// A single entry inside a `derive(...)` attribute: `derive` is the text range picked
/// for that entry (it may include one neighbouring comma) and `attr` is the attribute
/// that contains it.
WrapDerive { derive: TextRange, attr: ast::Attr },
/// Whole attributes: either wrapped together in one `cfg_attr`, or, when it is a single
/// `cfg_attr` attribute, unwrapped.
WrapAttr(Vec<ast::Attr>),
}
/// Tries to isolate the derive-list entry that `ident` belongs to inside `attr`.
///
/// On success returns `WrapDerive` with a text range covering the entry and at most one
/// adjacent comma. Falls back to `WrapAttr(vec![attr])` when `ident` is not a direct child
/// of a token tree, when `attr` is not named `derive`, or when the token walk below bails
/// out with `None`.
fn attempt_get_derive(attr: ast::Attr, ident: SyntaxToken) -> WrapUnwrapOption {
// `?` inside this closure only aborts the closure; the caller at the bottom of the
// function converts a `None` result into the `WrapAttr` fallback.
let attempt_attr = || {
{
// Start with the identifier itself and widen the range as neighbours are absorbed.
let mut derive = ident.text_range();
// Sibling tokens on either side of the identifier, passed through
// `algo::skip_trivia_token` (presumably stepping over whitespace/comments — confirm
// against the `syntax::algo` docs).
let mut prev = algo::skip_trivia_token(
ident.prev_sibling_or_token()?.into_token()?,
syntax::Direction::Prev,
)?;
let mut following = algo::skip_trivia_token(
ident.next_sibling_or_token()?.into_token()?,
syntax::Direction::Next,
)?;
// Delimited on both sides by `,` or a parenthesis: the entry is a lone identifier.
if (prev.kind() == T![,] || prev.kind() == T!['('])
&& (following.kind() == T![,] || following.kind() == T![')'])
{
// Take the trailing comma with the entry when there is one; for the last entry of a
// longer list take the preceding comma instead. A sole entry `(Ident)` takes neither.
if following.kind() == T![,] {
derive = derive.cover(following.text_range());
} else if following.kind() == T![')'] && prev.kind() == T![,] {
derive = derive.cover(prev.text_range());
}
Some(WrapUnwrapOption::WrapDerive { derive, attr: attr.clone() })
} else {
// The identifier is part of a longer token sequence (e.g. a path): grow the range
// outwards one token at a time until a `,` or parenthesis is reached.
// Records that the leading comma was absorbed, so the trailing one is left in place.
let mut consumed_comma = false;
// Walk backwards towards the start of the entry.
while let Some(prev_token) = algo::skip_trivia_token(prev, syntax::Direction::Prev)
{
let kind = prev_token.kind();
if kind == T![,] {
consumed_comma = true;
derive = derive.cover(prev_token.text_range());
break;
} else if kind == T!['('] {
break;
} else {
derive = derive.cover(prev_token.text_range());
}
// Step to the previous sibling; a missing sibling or a non-token aborts the closure.
prev = prev_token.prev_sibling_or_token()?.into_token()?;
}
// Walk forwards towards the end of the entry.
while let Some(next_token) =
algo::skip_trivia_token(following.clone(), syntax::Direction::Next)
{
let kind = next_token.kind();
match kind {
// Absorb the trailing comma only if no leading comma was taken.
T![,] if !consumed_comma => {
derive = derive.cover(next_token.text_range());
break;
}
T![')'] | T![,] => break,
_ => derive = derive.cover(next_token.text_range()),
}
// Step to the next sibling; a missing sibling or a non-token aborts the closure.
following = next_token.next_sibling_or_token()?.into_token()?;
}
Some(WrapUnwrapOption::WrapDerive { derive, attr: attr.clone() })
}
}
};
// Only an identifier whose parent is a token tree, inside an attribute named `derive`,
// is treated as a derive entry; everything else targets the whole attribute.
if ident.parent().and_then(ast::TokenTree::cast).is_none()
|| !attr.simple_name().map(|v| v.eq("derive")).unwrap_or_default()
{
WrapUnwrapOption::WrapAttr(vec![attr])
} else {
attempt_attr().unwrap_or_else(|| WrapUnwrapOption::WrapAttr(vec![attr]))
}
}
/// Assist entry point: wraps attributes (or one entry of a `derive` list) in `cfg_attr`,
/// or unwraps an existing `cfg_attr` attribute.
///
/// Target resolution:
/// - Empty selection: the attribute at the cursor. If the cursor is also on an identifier
///   and the attribute is named `derive`, `attempt_get_derive` picks the target.
/// - Non-empty selection: decided by the covering element. An attribute node is used
///   as-is; any other node must have only attributes among the children intersecting the
///   trimmed selection (and at least one of them); an identifier token is handed to
///   `attempt_get_derive` together with its enclosing attribute.
///
/// Returns `None` when no target can be resolved; otherwise the result of the helper
/// that registers the matching assist.
pub(crate) fn wrap_unwrap_cfg_attr(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
    let target = if ctx.has_empty_selection() {
        let ident = ctx.find_token_syntax_at_offset(T![ident]);
        let attr = ctx.find_node_at_offset::<ast::Attr>()?;
        let is_derive = attr.simple_name().is_some_and(|name| name.eq("derive"));
        match ident {
            Some(ident) if is_derive => attempt_get_derive(attr, ident),
            _ => WrapUnwrapOption::WrapAttr(vec![attr]),
        }
    } else {
        match ctx.covering_element() {
            NodeOrToken::Node(node) => {
                let attrs = match ast::Attr::cast(node.clone()) {
                    Some(attr) => vec![attr],
                    None => {
                        // Every selected child has to be an attribute, otherwise the
                        // `Option` collect yields `None` and the assist is not offered.
                        let selection = ctx.selection_trimmed();
                        let selected = node
                            .children()
                            .filter(|child| child.text_range().intersect(selection).is_some())
                            .map(ast::Attr::cast)
                            .collect::<Option<Vec<_>>>()?;
                        if selected.is_empty() {
                            return None;
                        }
                        selected
                    }
                };
                WrapUnwrapOption::WrapAttr(attrs)
            }
            NodeOrToken::Token(ident) if ident.kind() == T![ident] => {
                let attr = ident.parent_ancestors().find_map(ast::Attr::cast)?;
                attempt_get_derive(attr, ident)
            }
            _ => return None,
        }
    };

    match target {
        WrapUnwrapOption::WrapDerive { derive, attr } => wrap_derive(acc, ctx, attr, derive),
        WrapUnwrapOption::WrapAttr(attrs) => {
            // Exactly one attribute whose meta is a `cfg_attr` means "unwrap";
            // anything else means "wrap".
            let sole_cfg_attr = match attrs.as_slice() {
                [only] => match only.meta() {
                    Some(ast::Meta::CfgAttrMeta(meta)) => Some(meta),
                    _ => None,
                },
                _ => None,
            };
            match sole_cfg_attr {
                Some(meta) => unwrap_cfg_attr(acc, ctx, meta),
                None => wrap_cfg_attrs(acc, ctx, attrs),
            }
        }
    }
}
/// Registers the assist that moves one derive entry out of `attr` into a separate
/// `#[cfg_attr(cfg, derive(..))]` attribute placed on the following line.
///
/// `derive_element` is the text range of the entry to move (possibly including an
/// adjacent comma). Returns `None` when `attr` has no meta, the meta is not a token-tree
/// meta, or it has no token tree; otherwise returns `Some(())`.
fn wrap_derive(
    acc: &mut Assists,
    ctx: &AssistContext<'_>,
    attr: ast::Attr,
    derive_element: TextRange,
) -> Option<()> {
    let target = attr.syntax().text_range();
    let token_tree = match attr.meta()? {
        ast::Meta::TokenTreeMeta(meta) => meta.token_tree()?,
        _ => return None,
    };

    // Split the direct tokens of the derive list into the ones that move into the
    // `cfg_attr` and the ones that stay behind; nested token trees are skipped.
    let mut label_path = String::new();
    let mut extracted = Vec::new();
    let mut remaining = Vec::new();
    for token in token_tree.token_trees_and_tokens().filter_map(|element| element.into_token()) {
        let kind = token.kind();
        if matches!(kind, T!['('] | T![')']) {
            continue;
        }
        if !derive_element.contains_range(token.text_range()) {
            remaining.push(NodeOrToken::Token(token));
        } else if !matches!(kind, T![,] | syntax::SyntaxKind::WHITESPACE) {
            // Commas and whitespace inside the selected range are dropped entirely.
            label_path.push_str(token.text());
            extracted.push(NodeOrToken::Token(token));
        }
    }

    acc.add(
        AssistId::refactor("wrap_unwrap_cfg_attr"),
        format!("Wrap #[derive({label_path})] in `cfg_attr`"),
        target,
        |edit: &mut SourceChangeBuilder| {
            let editor = edit.make_editor(attr.syntax());
            let make = editor.make();

            // `#[derive(<entries that stay>)]`
            let kept_tokens = make.token_tree(T!['('], remaining);
            let kept_meta = make.meta_token_tree(make.ident_path("derive"), kept_tokens);
            let kept_attr = make.attr_outer(kept_meta);

            // `#[cfg_attr(cfg, derive(<moved entry>))]`
            let cfg_flag = make.cfg_flag("cfg");
            let moved_meta = make.meta_token_tree(
                make.ident_path("derive"),
                make.token_tree(T!['('], extracted),
            );
            let cfg_meta = make.cfg_attr_meta(cfg_flag, [moved_meta]);
            let cfg_attr = make.attr_outer(cfg_meta.clone().into());

            let replacement = vec![
                kept_attr.syntax().clone().into(),
                make.whitespace("\n").into(),
                cfg_attr.syntax().clone().into(),
            ];
            editor.replace_with_many(attr.syntax(), replacement);

            // With snippet support, turn the `cfg` predicate into a placeholder.
            if let Some(snippet_cap) = ctx.config.snippet_cap
                && let Some(predicate) = cfg_meta.cfg_predicate()
            {
                let placeholder = edit.make_placeholder_snippet(snippet_cap);
                editor.add_annotation(predicate.syntax(), placeholder);
            }
            edit.add_file_edits(ctx.vfs_file_id(), editor);
        },
    );
    Some(())
}
/// Registers the assist that replaces `attrs` with a single `cfg_attr` attribute that
/// carries all of their metas behind a `cfg` predicate.
///
/// The syntax range from the first to the last attribute is replaced as a whole. The new
/// attribute is an inner attribute (`#![..]`) when the first attribute has a `!` token,
/// and an outer attribute otherwise.
///
/// Returns `None` when `attrs` is empty or when any attribute has no meta. The edit
/// replaces the whole range, so an attribute without a meta would otherwise be deleted
/// without being carried into the `cfg_attr`.
fn wrap_cfg_attrs(acc: &mut Assists, ctx: &AssistContext<'_>, attrs: Vec<ast::Attr>) -> Option<()> {
    let (first_attr, last_attr) = (attrs.first()?, attrs.last()?);
    // Collect up front so the assist is not offered at all if a meta is missing,
    // instead of silently dropping that attribute from the output.
    let metas = attrs.iter().map(|attr| attr.meta()).collect::<Option<Vec<_>>>()?;
    let range = first_attr.syntax().text_range().cover(last_attr.syntax().text_range());
    let handle_source_change = |edit: &mut SourceChangeBuilder| {
        let editor = edit.make_editor(first_attr.syntax());
        let make = editor.make();
        let meta = make.cfg_attr_meta(make.cfg_flag("cfg"), metas);
        let cfg_attr = if first_attr.excl_token().is_some() {
            make.attr_inner(meta.clone().into())
        } else {
            make.attr_outer(meta.clone().into())
        };
        let syntax_range = first_attr.syntax().clone().into()..=last_attr.syntax().clone().into();
        editor.replace_all(syntax_range, vec![cfg_attr.syntax().clone().into()]);
        // With snippet support, turn the `cfg` predicate into a placeholder.
        if let Some(snippet_cap) = ctx.config.snippet_cap
            && let Some(cfg_flag) = meta.cfg_predicate()
        {
            let tabstop = edit.make_placeholder_snippet(snippet_cap);
            editor.add_annotation(cfg_flag.syntax(), tabstop);
        }
        edit.add_file_edits(ctx.vfs_file_id(), editor);
    };
    acc.add(
        AssistId::refactor("wrap_unwrap_cfg_attr"),
        "Convert to `cfg_attr`",
        range,
        handle_source_change,
    );
    Some(())
}
/// Registers the assist that replaces a `cfg_attr` attribute with the attributes it
/// wraps, one per line at the indentation of the original attribute.
///
/// Each inner meta becomes an inner attribute (`#![..]`) when the original attribute has
/// a `!` token, and an outer attribute otherwise. Returns `None` when `meta` has no
/// parent attribute or wraps no metas.
fn unwrap_cfg_attr(
    acc: &mut Assists,
    ctx: &AssistContext<'_>,
    meta: ast::CfgAttrMeta,
) -> Option<()> {
    let top_attr = ast::Meta::from(meta.clone()).parent_attr()?;
    let inner_metas: Vec<ast::Meta> = meta.metas().collect();
    if inner_metas.is_empty() {
        return None;
    }

    let target = top_attr.syntax().text_range();
    let indent = top_attr.indent_level();
    let is_inner = top_attr.excl_token().is_some();

    let apply = |builder: &mut SourceChangeBuilder| {
        let editor = builder.make_editor(top_attr.syntax());
        let make = editor.make();
        let line_break = format!("\n{indent}");

        let mut elements = Vec::new();
        for (index, inner) in inner_metas.into_iter().enumerate() {
            // Separate consecutive attributes with a newline plus the original indent.
            if index != 0 {
                elements.push(make.whitespace(&line_break).into());
            }
            let attr = match is_inner {
                true => make.attr_inner(inner),
                false => make.attr_outer(inner),
            };
            elements.push(attr.syntax().clone().into());
        }

        editor.replace_with_many(top_attr.syntax(), elements);
        builder.add_file_edits(ctx.vfs_file_id(), editor);
    };
    acc.add(
        AssistId::refactor("wrap_unwrap_cfg_attr"),
        "Extract Inner Attributes from `cfg_attr`",
        target,
        apply,
    );
    Some(())
}
// Tests for the `wrap_unwrap_cfg_attr` assist. In the fixtures `$0` marks the cursor (or,
// when it appears twice, the selection), and `${0:cfg}` in the expected text is the
// snippet placeholder emitted for the `cfg` predicate.
#[cfg(test)]
mod tests {
use crate::tests::check_assist;
use super::*;
// A derive with a single entry is wrapped as a whole attribute, and a `cfg_attr` with a
// single inner attribute is unwrapped again.
#[test]
fn test_basic_to_from_cfg_attr() {
check_assist(
wrap_unwrap_cfg_attr,
r#"
#[derive$0(Debug)]
pub struct Test {
test: u32,
}
"#,
r#"
#[cfg_attr(${0:cfg}, derive(Debug))]
pub struct Test {
test: u32,
}
"#,
);
check_assist(
wrap_unwrap_cfg_attr,
r#"
#[cfg_attr(debug_assertions, $0 derive(Debug))]
pub struct Test {
test: u32,
}
"#,
r#"
#[derive(Debug)]
pub struct Test {
test: u32,
}
"#,
);
}
// Path-style attributes (`#[foo]`): wrap, unwrap, wrap a multi-attribute selection into
// one `cfg_attr`, and unwrap a `cfg_attr` holding several attributes.
#[test]
fn to_from_path_attr() {
check_assist(
wrap_unwrap_cfg_attr,
r#"
pub struct Test {
#[foo$0]
test: u32,
}
"#,
r#"
pub struct Test {
#[cfg_attr(${0:cfg}, foo)]
test: u32,
}
"#,
);
check_assist(
wrap_unwrap_cfg_attr,
r#"
pub struct Test {
#[cfg_attr(debug_assertions$0, foo)]
test: u32,
}
"#,
r#"
pub struct Test {
#[foo]
test: u32,
}
"#,
);
// Only the selected attributes are merged; the surrounding ones stay untouched.
check_assist(
wrap_unwrap_cfg_attr,
r#"
pub struct Test {
#[other_attr]
$0#[foo]
#[bar]$0
#[other_attr]
test: u32,
}
"#,
r#"
pub struct Test {
#[other_attr]
#[cfg_attr(${0:cfg}, foo, bar)]
#[other_attr]
test: u32,
}
"#,
);
check_assist(
wrap_unwrap_cfg_attr,
r#"
pub struct Test {
#[cfg_attr(debug_assertions$0, foo, bar)]
test: u32,
}
"#,
r#"
pub struct Test {
#[foo]
#[bar]
test: u32,
}
"#,
);
}
// Key-value attributes (`#[foo = "bar"]`) round-trip through wrap and unwrap.
#[test]
fn to_from_eq_attr() {
check_assist(
wrap_unwrap_cfg_attr,
r#"
pub struct Test {
#[foo = "bar"$0]
test: u32,
}
"#,
r#"
pub struct Test {
#[cfg_attr(${0:cfg}, foo = "bar")]
test: u32,
}
"#,
);
check_assist(
wrap_unwrap_cfg_attr,
r#"
pub struct Test {
#[cfg_attr(debug_assertions$0, foo = "bar")]
test: u32,
}
"#,
r#"
pub struct Test {
#[foo = "bar"]
test: u32,
}
"#,
);
}
// Inner attributes (`#![..]`) stay inner attributes in both directions.
#[test]
fn inner_attrs() {
check_assist(
wrap_unwrap_cfg_attr,
r#"
#![no_std$0]
"#,
r#"
#![cfg_attr(${0:cfg}, no_std)]
"#,
);
check_assist(
wrap_unwrap_cfg_attr,
r#"
#![cfg_attr(not(feature = "std")$0, no_std)]
"#,
r#"
#![no_std]
"#,
);
}
// With the cursor on one entry of a multi-entry derive, only that entry is moved into a
// new `cfg_attr` line; the remaining entries stay in the original derive.
#[test]
fn test_derive_wrap() {
check_assist(
wrap_unwrap_cfg_attr,
r#"
#[derive(Debug$0, Clone, Copy)]
pub struct Test {
test: u32,
}
"#,
r#"
#[derive( Clone, Copy)]
#[cfg_attr(${0:cfg}, derive(Debug))]
pub struct Test {
test: u32,
}
"#,
);
check_assist(
wrap_unwrap_cfg_attr,
r#"
#[derive(Clone, Debug$0, Copy)]
pub struct Test {
test: u32,
}
"#,
r#"
#[derive(Clone, Copy)]
#[cfg_attr(${0:cfg}, derive(Debug))]
pub struct Test {
test: u32,
}
"#,
);
}
// Same as above, but the selected entry is a multi-segment path.
#[test]
fn test_derive_wrap_with_path() {
check_assist(
wrap_unwrap_cfg_attr,
r#"
#[derive(std::fmt::Debug$0, Clone, Copy)]
pub struct Test {
test: u32,
}
"#,
r#"
#[derive( Clone, Copy)]
#[cfg_attr(${0:cfg}, derive(std::fmt::Debug))]
pub struct Test {
test: u32,
}
"#,
);
check_assist(
wrap_unwrap_cfg_attr,
r#"
#[derive(Clone, std::fmt::Debug$0, Copy)]
pub struct Test {
test: u32,
}
"#,
r#"
#[derive(Clone, Copy)]
#[cfg_attr(${0:cfg}, derive(std::fmt::Debug))]
pub struct Test {
test: u32,
}
"#,
);
}
// The selected entry is the last one in the list, both as a plain identifier and as a
// path; the preceding comma is removed from the remaining derive.
#[test]
fn test_derive_wrap_at_end() {
check_assist(
wrap_unwrap_cfg_attr,
r#"
#[derive(std::fmt::Debug, Clone, Cop$0y)]
pub struct Test {
test: u32,
}
"#,
r#"
#[derive(std::fmt::Debug, Clone)]
#[cfg_attr(${0:cfg}, derive(Copy))]
pub struct Test {
test: u32,
}
"#,
);
check_assist(
wrap_unwrap_cfg_attr,
r#"
#[derive(Clone, Copy, std::fmt::D$0ebug)]
pub struct Test {
test: u32,
}
"#,
r#"
#[derive(Clone, Copy)]
#[cfg_attr(${0:cfg}, derive(std::fmt::Debug))]
pub struct Test {
test: u32,
}
"#,
);
}
}