use hir::TraitRef;
use syntax::ast::{
self, AstNode, BinaryOp, CmpOp, HasName, LogicOp, edit::AstNodeEdit,
syntax_factory::SyntaxFactory,
};
/// Generates a method body for an impl of one of the supported traits.
///
/// The trait is picked by the name of the segment returned by
/// `trait_path.segment()`; `Clone`, `Debug`, `Default`, `Hash`, `PartialEq`
/// and `PartialOrd` are recognised. Returns `None` when `func` has no body,
/// when the trait name cannot be read or is not one of the above, or when the
/// selected generator gives up for this `adt`.
pub(crate) fn gen_trait_fn_body(
    make: &SyntaxFactory,
    func: &ast::Fn,
    trait_path: &ast::Path,
    adt: &ast::Adt,
    trait_ref: Option<TraitRef<'_>>,
) -> Option<ast::BlockExpr> {
    // Functions without a body are not handled.
    func.body()?;

    let trait_name = trait_path.segment()?.name_ref()?;
    match trait_name.text().as_str() {
        // These two generators are dispatched without a function-name sanity check.
        "Debug" => gen_debug_impl(make, adt),
        "Default" => gen_default_impl(make, adt),

        // The remaining traits first sanity-check that `func` is the method
        // the generator is written for.
        "Clone" => {
            stdx::always!(func.name().is_some_and(|name| name.text() == "clone"));
            gen_clone_impl(make, adt)
        }
        "Hash" => {
            stdx::always!(func.name().is_some_and(|name| name.text() == "hash"));
            gen_hash_impl(make, adt)
        }
        "PartialEq" => {
            stdx::always!(func.name().is_some_and(|name| name.text() == "eq"));
            gen_partial_eq(make, adt, trait_ref)
        }
        "PartialOrd" => {
            stdx::always!(func.name().is_some_and(|name| name.text() == "partial_cmp"));
            gen_partial_ord(make, adt, trait_ref)
        }

        _ => None,
    }
}
/// Generates the body of a `clone` method for `adt`.
///
/// Structs become a `Self` constructor whose fields are cloned from `self`;
/// enums become a `match self` with one arm per variant. Returns `None` for
/// unions and whenever a variant list, name, or path is unavailable.
fn gen_clone_impl(make: &SyntaxFactory, adt: &ast::Adt) -> Option<ast::BlockExpr> {
    // => <receiver>.clone()
    let clone_of = |receiver: ast::Expr| -> ast::Expr {
        make.expr_method_call(receiver, make.name_ref("clone"), make.arg_list([])).into()
    };
    // => self
    let self_expr = || make.expr_path(make.ident_path("self"));

    let cloned: ast::Expr = match adt {
        ast::Adt::Union(_) => return None,
        ast::Adt::Enum(enum_) => {
            let mut arms = Vec::new();
            for variant in enum_.variant_list()?.variants() {
                let variant_path =
                    make.path_from_idents(["Self", &variant.name()?.to_string()])?;
                let arm = match variant.field_list() {
                    // => Self::Name { x } => Self::Name { x: x.clone() }
                    Some(ast::FieldList::RecordFieldList(record)) => {
                        let mut field_pats = Vec::new();
                        let mut field_inits = Vec::new();
                        for field in record.fields() {
                            let field_name = field.name()?;
                            let field_text = field_name.to_string();
                            let binding = make.ident_pat(false, false, field_name);
                            field_pats.push(make.record_pat_field_shorthand(binding.into()));
                            let value = clone_of(make.expr_path(make.ident_path(&field_text)));
                            let init =
                                make.record_expr_field(make.name_ref(&field_text), Some(value));
                            field_inits.push(init);
                        }
                        let pat = make.record_pat_with_fields(
                            variant_path.clone(),
                            make.record_pat_field_list(field_pats, None),
                        );
                        let init = make
                            .record_expr(variant_path, make.record_expr_field_list(field_inits));
                        make.match_arm(pat.into(), None, init.into())
                    }
                    // => Self::Name(arg0) => Self::Name(arg0.clone())
                    Some(ast::FieldList::TupleFieldList(tuple)) => {
                        let mut bindings: Vec<ast::Pat> = Vec::new();
                        let mut args: Vec<ast::Expr> = Vec::new();
                        for idx in 0..tuple.fields().count() {
                            let arg_name = format!("arg{idx}");
                            let binding = make.ident_pat(false, false, make.name(&arg_name));
                            bindings.push(binding.into());
                            args.push(clone_of(make.expr_path(make.ident_path(&arg_name))));
                        }
                        let pat = make.tuple_struct_pat(variant_path.clone(), bindings);
                        let ctor =
                            make.expr_call(make.expr_path(variant_path), make.arg_list(args));
                        make.match_arm(pat.into(), None, ctor.into())
                    }
                    // => Self::Name => Self::Name
                    None => {
                        let pat = make.path_pat(variant_path.clone());
                        make.match_arm(pat, None, make.expr_path(variant_path))
                    }
                };
                arms.push(arm);
            }

            let scrutinee = self_expr();
            let arm_list = make.match_arm_list(arms).indent(ast::edit::IndentLevel(1));
            make.expr_match(scrutinee, arm_list).into()
        }
        ast::Adt::Struct(strukt) => match strukt.field_list() {
            // => Self { name: self.name.clone() }
            Some(ast::FieldList::RecordFieldList(record)) => {
                let field_inits = record
                    .fields()
                    .map(|field| {
                        let field_text = field.name()?.to_string();
                        let value = clone_of(make.expr_field(self_expr(), &field_text).into());
                        Some(make.record_expr_field(make.name_ref(&field_text), Some(value)))
                    })
                    .collect::<Option<Vec<_>>>()?;
                let self_path = make.ident_path("Self");
                make.record_expr(self_path, make.record_expr_field_list(field_inits)).into()
            }
            // => Self(self.0.clone(), self.1.clone())
            Some(ast::FieldList::TupleFieldList(tuple)) => {
                let args: Vec<ast::Expr> = (0..tuple.fields().count())
                    .map(|idx| clone_of(make.expr_field(self_expr(), &idx.to_string()).into()))
                    .collect();
                let ctor = make.expr_path(make.ident_path("Self"));
                make.expr_call(ctor, make.arg_list(args)).into()
            }
            // => Self { }
            None => {
                let self_path = make.ident_path("Self");
                make.record_expr(self_path, make.record_expr_field_list([])).into()
            }
        },
    };

    Some(make.block_expr(None::<ast::Stmt>, Some(cloned)).indent(ast::edit::IndentLevel(1)))
}
/// Generates the body of a `Debug` formatting method for `adt`.
///
/// The generated code refers to the formatter by the identifier `f`.
/// Returns `None` for unions, and whenever the ADT name, an enum's variant
/// list, a variant name, a field name, or a `Self::Variant` path is
/// unavailable.
fn gen_debug_impl(make: &SyntaxFactory, adt: &ast::Adt) -> Option<ast::BlockExpr> {
let annotated_name = adt.name()?;
match adt {
// No body is generated for unions.
ast::Adt::Union(_) => None,
// Enums: `match self { .. }` with one arm per variant.
ast::Adt::Enum(enum_) => {
let list = enum_.variant_list()?;
let mut arms = vec![];
for variant in list.variants() {
let name = variant.name()?;
let variant_name = make.path_from_idents(["Self", &format!("{name}")])?;
// Only the field-less (`None`) arm below consumes this `f`; the other
// arms shadow it with their own.
let target = make.expr_path(make.ident_path("f"));
match variant.field_list() {
// => Self::Name { x } => f.debug_struct("Name").field("x", x).finish()
Some(ast::FieldList::RecordFieldList(list)) => {
// => f.debug_struct("Name")
let target = make.expr_path(make.ident_path("f"));
let method = make.name_ref("debug_struct");
let struct_name = format!("\"{name}\"");
let args = make.arg_list([make.expr_literal(&struct_name).into()]);
let mut expr = make.expr_method_call(target, method, args).into();
let mut pats = vec![];
for field in list.fields() {
let field_name = field.name()?;
// Shorthand field pattern binding the field under its own name.
let pat = make.ident_pat(false, false, field_name.clone());
pats.push(make.record_pat_field_shorthand(pat.into()));
// => <expr>.field("field_name", field_name)
let method_name = make.name_ref("field");
let name = make.expr_literal(&(format!("\"{field_name}\""))).into();
let path = &format!("{field_name}");
let path = make.expr_path(make.ident_path(path));
let args = make.arg_list([name, path]);
expr = make.expr_method_call(expr, method_name, args).into();
}
// => <expr>.finish()
let method = make.name_ref("finish");
let expr = make.expr_method_call(expr, method, make.arg_list([])).into();
let pat_field_list = make.record_pat_field_list(pats, None);
let pat = make.record_pat_with_fields(variant_name.clone(), pat_field_list);
arms.push(make.match_arm(pat.into(), None, expr));
}
// => Self::Name(arg0) => f.debug_tuple("Name").field(arg0).finish()
Some(ast::FieldList::TupleFieldList(list)) => {
// => f.debug_tuple("Name")
let target = make.expr_path(make.ident_path("f"));
let method = make.name_ref("debug_tuple");
let struct_name = format!("\"{name}\"");
let args = make.arg_list([make.expr_literal(&struct_name).into()]);
let mut expr = make.expr_method_call(target, method, args).into();
let mut pats = vec![];
for (i, _) in list.fields().enumerate() {
// Positional fields are bound as `arg0`, `arg1`, ...
let name = format!("arg{i}");
let field_name = make.name(&name);
let pat = make.ident_pat(false, false, field_name.clone());
pats.push(pat.into());
// => <expr>.field(argN)
let method_name = make.name_ref("field");
let field_path = &name.to_string();
let field_path = make.expr_path(make.ident_path(field_path));
let args = make.arg_list([field_path]);
expr = make.expr_method_call(expr, method_name, args).into();
}
// => <expr>.finish()
let method = make.name_ref("finish");
let expr = make.expr_method_call(expr, method, make.arg_list([])).into();
let pat = make.tuple_struct_pat(variant_name.clone(), pats.into_iter());
arms.push(make.match_arm(pat.into(), None, expr));
}
// => Self::Name => write!(f, "Name")
None => {
let fmt_string = make.expr_literal(&(format!("\"{name}\""))).into();
// The macro arguments are built as an argument list and then
// converted to a token tree.
let args =
make.token_tree_from_node(make.arg_list([target, fmt_string]).syntax());
let macro_name = make.ident_path("write");
let macro_call = make.expr_macro(macro_name, args);
let variant_name = make.path_pat(variant_name);
arms.push(make.match_arm(variant_name, None, macro_call.into()));
}
}
}
// Wrap the arms in `match self { .. }` and use it as the block's tail.
let match_target = make.expr_path(make.ident_path("self"));
let list = make.match_arm_list(arms).indent(ast::edit::IndentLevel(1));
let match_expr = make.expr_match(match_target, list);
let body = make.block_expr(None::<ast::Stmt>, Some(match_expr.into()));
let body = body.indent(ast::edit::IndentLevel(1));
Some(body)
}
// Structs: a single builder chain on `f`, labelled with the ADT's name.
ast::Adt::Struct(strukt) => {
let name = format!("\"{annotated_name}\"");
let args = make.arg_list([make.expr_literal(&name).into()]);
let target = make.expr_path(make.ident_path("f"));
let expr = match strukt.field_list() {
// => f.debug_struct("Name")
None => make.expr_method_call(target, make.name_ref("debug_struct"), args).into(),
// => f.debug_struct("Name").field("foo", &self.foo)
Some(ast::FieldList::RecordFieldList(field_list)) => {
let method = make.name_ref("debug_struct");
let mut expr = make.expr_method_call(target, method, args).into();
for field in field_list.fields() {
let name = field.name()?;
let f_name = make.expr_literal(&(format!("\"{name}\""))).into();
let f_path = make.expr_path(make.ident_path("self"));
let f_path = make.expr_field(f_path, &format!("{name}")).into();
// Pass the field through `expr_ref` (presumably `&self.<name>`).
let f_path = make.expr_ref(f_path, false);
let args = make.arg_list([f_name, f_path]);
expr = make.expr_method_call(expr, make.name_ref("field"), args).into();
}
expr
}
// => f.debug_tuple("Name").field(&self.0)
Some(ast::FieldList::TupleFieldList(field_list)) => {
let method = make.name_ref("debug_tuple");
let mut expr = make.expr_method_call(target, method, args).into();
for (i, _) in field_list.fields().enumerate() {
let f_path = make.expr_path(make.ident_path("self"));
let f_path = make.expr_field(f_path, &format!("{i}")).into();
let f_path = make.expr_ref(f_path, false);
let method = make.name_ref("field");
expr = make.expr_method_call(expr, method, make.arg_list([f_path])).into();
}
expr
}
};
// => <expr>.finish()
let method = make.name_ref("finish");
let expr = make.expr_method_call(expr, method, make.arg_list([])).into();
let body =
make.block_expr(None::<ast::Stmt>, Some(expr)).indent(ast::edit::IndentLevel(1));
Some(body)
}
}
}
/// Generates the body of a `default` method for `adt`.
///
/// Only structs are supported: every field is initialised with
/// `Default::default()`. Returns `None` for unions and enums, and whenever a
/// field name or the `Default::default` path is unavailable.
fn gen_default_impl(make: &SyntaxFactory, adt: &ast::Adt) -> Option<ast::BlockExpr> {
    // => Default::default()
    fn default_call(make: &SyntaxFactory) -> Option<ast::Expr> {
        let callee = make.expr_path(make.path_from_idents(["Default", "default"])?);
        Some(make.expr_call(callee, make.arg_list([])).into())
    }

    let ast::Adt::Struct(strukt) = adt else {
        return None;
    };

    let expr = match strukt.field_list() {
        // => Self { name: Default::default() }
        Some(ast::FieldList::RecordFieldList(record)) => {
            let field_inits = record
                .fields()
                .map(|field| {
                    let value = default_call(make)?;
                    let name_ref = make.name_ref(&field.name()?.to_string());
                    Some(make.record_expr_field(name_ref, Some(value)))
                })
                .collect::<Option<Vec<_>>>()?;
            let self_path = make.ident_path("Self");
            make.record_expr(self_path, make.record_expr_field_list(field_inits)).into()
        }
        // => Self(Default::default(), Default::default())
        Some(ast::FieldList::TupleFieldList(tuple)) => {
            let ctor = make.expr_path(make.ident_path("Self"));
            let mut args: Vec<ast::Expr> = Vec::new();
            for _ in tuple.fields() {
                args.push(default_call(make)?);
            }
            make.expr_call(ctor, make.arg_list(args)).into()
        }
        // => Self { }
        None => {
            let self_path = make.ident_path("Self");
            make.record_expr(self_path, make.record_expr_field_list([])).into()
        }
    };

    Some(make.block_expr(None::<ast::Stmt>, Some(expr)).indent(ast::edit::IndentLevel(1)))
}
/// Generates the body of a `hash` method for `adt`.
///
/// The generated statements feed values into a hasher named `state`: enums
/// hash `core::mem::discriminant(self)`, structs hash each field in order.
/// Returns `None` for unions, for structs without a field list, and whenever
/// a field name or the discriminant path is unavailable.
fn gen_hash_impl(make: &SyntaxFactory, adt: &ast::Adt) -> Option<ast::BlockExpr> {
    // => <receiver>.hash(state);
    let hash_stmt = |receiver: ast::Expr| -> ast::Stmt {
        let hash = make.name_ref("hash");
        let state = make.expr_path(make.ident_path("state"));
        let call = make.expr_method_call(receiver, hash, make.arg_list([state]));
        make.expr_stmt(call.into()).into()
    };
    // => self
    let self_expr = || make.expr_path(make.ident_path("self"));

    let stmts: Vec<ast::Stmt> = match adt {
        ast::Adt::Union(_) => return None,
        // => core::mem::discriminant(self).hash(state);
        ast::Adt::Enum(_) => {
            let discriminant = make_discriminant(make)?;
            let call: ast::Expr =
                make.expr_call(discriminant, make.arg_list([self_expr()])).into();
            vec![hash_stmt(call)]
        }
        // A struct without a field list yields no body at all.
        ast::Adt::Struct(strukt) => match strukt.field_list()? {
            // => self.<name>.hash(state);
            ast::FieldList::RecordFieldList(record) => record
                .fields()
                .map(|field| {
                    let target =
                        make.expr_field(self_expr(), &field.name()?.to_string()).into();
                    Some(hash_stmt(target))
                })
                .collect::<Option<Vec<_>>>()?,
            // => self.<index>.hash(state);
            ast::FieldList::TupleFieldList(tuple) => (0..tuple.fields().count())
                .map(|idx| hash_stmt(make.expr_field(self_expr(), &idx.to_string()).into()))
                .collect(),
        },
    };

    Some(make.block_expr(stmts, None).indent(ast::edit::IndentLevel(1)))
}
/// Generates the body of an `eq` method comparing `self` with `other`.
///
/// Returns `None` for unions, when `trait_ref` is given and its type
/// argument at index 1 is missing or differs from the self type, and
/// whenever a variant list, name, or path is unavailable.
fn gen_partial_eq(
make: &SyntaxFactory,
adt: &ast::Adt,
trait_ref: Option<TraitRef<'_>>,
) -> Option<ast::BlockExpr> {
// Extends a running conjunction: `<expr> && <cmp>`, or just `<cmp>` for the first comparison.
let gen_eq_chain = |expr: Option<ast::Expr>, cmp: ast::Expr| -> Option<ast::Expr> {
match expr {
Some(expr) => Some(make.expr_bin_op(expr, BinaryOp::LogicOp(LogicOp::And), cmp)),
None => Some(cmp),
}
};
// => <field_name>: <pat_name>
let gen_record_pat_field = |field_name: &str, pat_name: &str| -> ast::RecordPatField {
let pat = make.ident_pat(false, false, make.name(pat_name));
let name_ref = make.name_ref(field_name);
make.record_pat_field(name_ref, pat.into())
};
// => <record_name> { <fields> }
let gen_record_pat =
|record_name: ast::Path, fields: Vec<ast::RecordPatField>| -> ast::RecordPat {
let list = make.record_pat_field_list(fields, None);
make.record_pat_with_fields(record_name, list)
};
// => Self::<Variant>
let gen_variant_path = |variant: &ast::Variant| -> Option<ast::Path> {
make.path_from_idents(["Self", &variant.name()?.to_string()])
};
// Identifier pattern used as one positional binding of a tuple variant.
let gen_tuple_field = |field_name: &str| -> ast::Pat {
ast::Pat::IdentPat(make.ident_pat(false, false, make.name(field_name)))
};
// Only generate a body when the right-hand-side type equals the self type.
if let Some(trait_ref) = trait_ref {
let self_ty = trait_ref.self_ty();
let rhs_ty = trait_ref.get_type_argument(1)?;
if self_ty != rhs_ty {
return None;
}
}
let body = match adt {
// No body is generated for unions.
ast::Adt::Union(_) => return None,
ast::Adt::Enum(enum_) => {
// => core::mem::discriminant(self) == core::mem::discriminant(other)
let lhs_name = make.expr_path(make.ident_path("self"));
let lhs =
make.expr_call(make_discriminant(make)?, make.arg_list([lhs_name.clone()])).into();
let rhs_name = make.expr_path(make.ident_path("other"));
let rhs =
make.expr_call(make_discriminant(make)?, make.arg_list([rhs_name.clone()])).into();
let eq_check =
make.expr_bin_op(lhs, BinaryOp::CmpOp(CmpOp::Eq { negated: false }), rhs);
// `n_cases` counts every variant; `arms` only receives variants that
// have at least one field to compare.
let mut n_cases = 0;
let mut arms = vec![];
for variant in enum_.variant_list()?.variants() {
n_cases += 1;
match variant.field_list() {
// => (Self::Bar { bin: l_bin }, Self::Bar { bin: r_bin }) => l_bin == r_bin,
Some(ast::FieldList::RecordFieldList(list)) => {
let mut expr = None;
let mut l_fields = vec![];
let mut r_fields = vec![];
for field in list.fields() {
let field_name = field.name()?.to_string();
let l_name = &format!("l_{field_name}");
l_fields.push(gen_record_pat_field(&field_name, l_name));
let r_name = &format!("r_{field_name}");
r_fields.push(gen_record_pat_field(&field_name, r_name));
let lhs = make.expr_path(make.ident_path(l_name));
let rhs = make.expr_path(make.ident_path(r_name));
let cmp = make.expr_bin_op(
lhs,
BinaryOp::CmpOp(CmpOp::Eq { negated: false }),
rhs,
);
expr = gen_eq_chain(expr, cmp);
}
let left = gen_record_pat(gen_variant_path(&variant)?, l_fields);
let right = gen_record_pat(gen_variant_path(&variant)?, r_fields);
let tuple = make.tuple_pat(vec![left.into(), right.into()]);
// An empty field list leaves `expr` as `None`, so no arm is emitted.
if let Some(expr) = expr {
arms.push(make.match_arm(tuple.into(), None, expr));
}
}
// => (Self::Bar(l0), Self::Bar(r0)) => l0 == r0,
Some(ast::FieldList::TupleFieldList(list)) => {
let mut expr = None;
let mut l_fields = vec![];
let mut r_fields = vec![];
for (i, _) in list.fields().enumerate() {
let field_name = format!("{i}");
let l_name = format!("l{field_name}");
l_fields.push(gen_tuple_field(&l_name));
let r_name = format!("r{field_name}");
r_fields.push(gen_tuple_field(&r_name));
let lhs = make.expr_path(make.ident_path(&l_name));
let rhs = make.expr_path(make.ident_path(&r_name));
let cmp = make.expr_bin_op(
lhs,
BinaryOp::CmpOp(CmpOp::Eq { negated: false }),
rhs,
);
expr = gen_eq_chain(expr, cmp);
}
let left = make.tuple_struct_pat(gen_variant_path(&variant)?, l_fields);
let right = make.tuple_struct_pat(gen_variant_path(&variant)?, r_fields);
let tuple = make.tuple_pat(vec![left.into(), right.into()]);
// An empty field list leaves `expr` as `None`, so no arm is emitted.
if let Some(expr) = expr {
arms.push(make.match_arm(tuple.into(), None, expr));
}
}
// Variants without a field list get no arm of their own.
None => continue,
}
}
let expr = match arms.len() {
// No variant produced an arm: the discriminant comparison is the whole body.
0 => eq_check,
arms_len => {
// With more than one variant a fallback arm is appended:
// `_ => false` when every variant produced an arm, otherwise
// `_ => <discriminant comparison>`.
if n_cases > 1 {
let lhs = make.wildcard_pat().into();
let rhs = if arms_len == n_cases {
make.expr_literal("false").into()
} else {
eq_check
};
arms.push(make.match_arm(lhs, None, rhs));
}
// => match (self, other) { .. }
let match_target = make.expr_tuple([lhs_name, rhs_name]).into();
let list = make.match_arm_list(arms).indent(ast::edit::IndentLevel(1));
make.expr_match(match_target, list).into()
}
};
make.block_expr(None::<ast::Stmt>, Some(expr)).indent(ast::edit::IndentLevel(1))
}
ast::Adt::Struct(strukt) => match strukt.field_list() {
// => self.<name> == other.<name> && ...
Some(ast::FieldList::RecordFieldList(field_list)) => {
let mut expr = None;
for field in field_list.fields() {
let lhs = make.expr_path(make.ident_path("self"));
let lhs = make.expr_field(lhs, &field.name()?.to_string()).into();
let rhs = make.expr_path(make.ident_path("other"));
let rhs = make.expr_field(rhs, &field.name()?.to_string()).into();
let cmp =
make.expr_bin_op(lhs, BinaryOp::CmpOp(CmpOp::Eq { negated: false }), rhs);
expr = gen_eq_chain(expr, cmp);
}
// With zero fields `expr` is still `None` and the block has no tail expression.
make.block_expr(None, expr).indent(ast::edit::IndentLevel(1))
}
// => self.<index> == other.<index> && ...
Some(ast::FieldList::TupleFieldList(field_list)) => {
let mut expr = None;
for (i, _) in field_list.fields().enumerate() {
let idx = format!("{i}");
let lhs = make.expr_path(make.ident_path("self"));
let lhs = make.expr_field(lhs, &idx).into();
let rhs = make.expr_path(make.ident_path("other"));
let rhs = make.expr_field(rhs, &idx).into();
let cmp =
make.expr_bin_op(lhs, BinaryOp::CmpOp(CmpOp::Eq { negated: false }), rhs);
expr = gen_eq_chain(expr, cmp);
}
// With zero fields `expr` is still `None` and the block has no tail expression.
make.block_expr(None::<ast::Stmt>, expr).indent(ast::edit::IndentLevel(1))
}
// No field list means there is nothing to compare: the body is `true`.
None => {
let expr = make.expr_literal("true").into();
make.block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
}
},
};
Some(body)
}
/// Generates the body of a `partial_cmp` method comparing `self` with `other`.
///
/// Struct fields are compared in declaration order: every field but the last
/// becomes an early-return `match`, and the last comparison is the block's
/// tail expression. A struct with nothing to compare (no field list, or an
/// empty one) yields `Some(core::cmp::Ordering::Equal)`, so the generated
/// body always has the `Option<Ordering>` type that `partial_cmp` returns.
///
/// Returns `None` for unions and enums, when `trait_ref` is given and its
/// type argument at index 1 is missing or differs from the self type, and
/// whenever a field name or a required path is unavailable.
fn gen_partial_ord(
    make: &SyntaxFactory,
    adt: &ast::Adt,
    trait_ref: Option<TraitRef<'_>>,
) -> Option<ast::BlockExpr> {
    // => match <match_target> { Some(core::cmp::Ordering::Equal) => {} ord => return ord }
    let gen_partial_eq_match = |match_target: ast::Expr| -> Option<ast::Stmt> {
        let mut arms = vec![];
        let variant_name =
            make.path_pat(make.path_from_idents(["core", "cmp", "Ordering", "Equal"])?);
        let lhs = make.tuple_struct_pat(make.path_from_idents(["Some"])?, [variant_name]);
        arms.push(make.match_arm(lhs.into(), None, make.expr_empty_block().into()));
        arms.push(make.match_arm(
            make.ident_pat(false, false, make.name("ord")).into(),
            None,
            make.expr_return(Some(make.expr_path(make.ident_path("ord")))).into(),
        ));
        let list = make.match_arm_list(arms).indent(ast::edit::IndentLevel(1));
        Some(make.expr_stmt(make.expr_match(match_target, list).into()).into())
    };
    // => <lhs>.partial_cmp(&<rhs>)
    let gen_partial_cmp_call = |lhs: ast::Expr, rhs: ast::Expr| -> ast::Expr {
        let rhs = make.expr_ref(rhs, false);
        let method = make.name_ref("partial_cmp");
        make.expr_method_call(lhs, method, make.arg_list([rhs])).into()
    };
    // => self.<field>.partial_cmp(&other.<field>)
    let gen_field_cmp = |field: &str| -> ast::Expr {
        let lhs: ast::Expr =
            make.expr_field(make.expr_path(make.ident_path("self")), field).into();
        let rhs: ast::Expr =
            make.expr_field(make.expr_path(make.ident_path("other")), field).into();
        gen_partial_cmp_call(lhs, rhs)
    };

    // Only generate a body when the right-hand-side type equals the self type.
    if let Some(trait_ref) = trait_ref {
        let self_ty = trait_ref.self_ty();
        let rhs_ty = trait_ref.get_type_argument(1)?;
        if self_ty != rhs_ty {
            return None;
        }
    }

    // One `partial_cmp` call per field, in declaration order.
    let mut exprs: Vec<ast::Expr> = match adt {
        ast::Adt::Union(_) => return None,
        ast::Adt::Enum(_) => return None,
        ast::Adt::Struct(strukt) => match strukt.field_list() {
            Some(ast::FieldList::RecordFieldList(field_list)) => {
                let mut exprs = vec![];
                for field in field_list.fields() {
                    exprs.push(gen_field_cmp(&field.name()?.to_string()));
                }
                exprs
            }
            Some(ast::FieldList::TupleFieldList(field_list)) => field_list
                .fields()
                .enumerate()
                .map(|(i, _)| gen_field_cmp(&i.to_string()))
                .collect(),
            None => Vec::new(),
        },
    };

    // The last comparison is the block's value. With nothing to compare, the
    // values are trivially equal; a `bool` literal or an empty block would not
    // type-check against `Option<Ordering>`.
    let tail = match exprs.pop() {
        Some(last) => last,
        None => {
            let equal =
                make.expr_path(make.path_from_idents(["core", "cmp", "Ordering", "Equal"])?);
            let some = make.expr_path(make.ident_path("Some"));
            make.expr_call(some, make.arg_list([equal])).into()
        }
    };
    // Every earlier comparison returns early unless it is `Some(Equal)`.
    let stmts =
        exprs.into_iter().map(gen_partial_eq_match).collect::<Option<Vec<ast::Stmt>>>()?;

    Some(make.block_expr(stmts, Some(tail)).indent(ast::edit::IndentLevel(1)))
}
/// Builds the path expression `core::mem::discriminant`.
///
/// Returns `None` if the path cannot be built from its identifiers.
fn make_discriminant(make: &SyntaxFactory) -> Option<ast::Expr> {
    let path = make.path_from_idents(["core", "mem", "discriminant"])?;
    Some(make.expr_path(path))
}