use proc_macro2::Span;
use quote::ToTokens;
use crate::ast::BinOp;
fn expr_precedence(expr: &syn::Expr) -> u8 {
match expr {
syn::Expr::Lit(_) | syn::Expr::Path(_) | syn::Expr::Paren(_) => 100,
syn::Expr::MethodCall(_) | syn::Expr::Field(_) | syn::Expr::Index(_) |
syn::Expr::Call(_) => 90,
syn::Expr::Try(_) => 85,
syn::Expr::Unary(_) => 80, syn::Expr::Cast(_) => 75, syn::Expr::Binary(b) => syn_binop_precedence(&b.op),
syn::Expr::Range(_) => 15,
syn::Expr::Assign(_) => 10,
syn::Expr::Return(_) | syn::Expr::Break(_) | syn::Expr::Closure(_) => 5,
syn::Expr::If(_) | syn::Expr::Match(_) => 1,
syn::Expr::Block(_) | syn::Expr::Unsafe(_) |
syn::Expr::Loop(_) | syn::Expr::While(_) | syn::Expr::ForLoop(_) => 100,
_ => 50, }
}
fn syn_binop_precedence(op: &syn::BinOp) -> u8 {
match op {
syn::BinOp::Mul(_) | syn::BinOp::Div(_) | syn::BinOp::Rem(_) => 70,
syn::BinOp::Add(_) | syn::BinOp::Sub(_) => 65,
syn::BinOp::Shl(_) | syn::BinOp::Shr(_) => 60,
syn::BinOp::BitAnd(_) => 55,
syn::BinOp::BitXor(_) => 50,
syn::BinOp::BitOr(_) => 45,
syn::BinOp::Lt(_) | syn::BinOp::Gt(_) | syn::BinOp::Le(_) | syn::BinOp::Ge(_) |
syn::BinOp::Eq(_) | syn::BinOp::Ne(_) => 40,
syn::BinOp::And(_) => 35,
syn::BinOp::Or(_) => 30,
_ => 50,
}
}
pub fn parenthesize(expr: syn::Expr) -> syn::Expr {
match expr {
syn::Expr::Binary(mut binary) => {
let parent_prec = syn_binop_precedence(&binary.op);
*binary.left = parenthesize_child(*binary.left, parent_prec, true);
*binary.right = parenthesize_child(*binary.right, parent_prec, false);
syn::Expr::Binary(binary)
}
syn::Expr::Cast(mut cast) => {
let child = parenthesize(*cast.expr);
let child_prec = expr_precedence(&child);
*cast.expr = if child_prec < 75 {
wrap_paren(child)
} else {
child
};
syn::Expr::Cast(cast)
}
syn::Expr::Unary(mut unary) => {
let child = parenthesize(*unary.expr);
let child_prec = expr_precedence(&child);
*unary.expr = if child_prec < 80 {
wrap_paren(child)
} else {
child
};
syn::Expr::Unary(unary)
}
syn::Expr::Field(mut field) => {
let child = parenthesize(*field.base);
let child_prec = expr_precedence(&child);
*field.base = if child_prec < 90 {
wrap_paren(child)
} else {
child
};
syn::Expr::Field(field)
}
syn::Expr::MethodCall(mut mc) => {
let child = parenthesize(*mc.receiver);
let child_prec = expr_precedence(&child);
*mc.receiver = if child_prec < 90 {
wrap_paren(child)
} else {
child
};
syn::Expr::MethodCall(mc)
}
syn::Expr::If(mut if_expr) => {
*if_expr.cond = parenthesize(*if_expr.cond);
parenthesize_block(&mut if_expr.then_branch);
if let Some((_, ref mut else_branch)) = if_expr.else_branch {
*else_branch = Box::new(parenthesize(*else_branch.clone()));
}
syn::Expr::If(if_expr)
}
syn::Expr::Paren(mut paren) => {
*paren.expr = parenthesize(*paren.expr);
syn::Expr::Paren(paren)
}
syn::Expr::Assign(mut assign) => {
*assign.left = parenthesize(*assign.left);
*assign.right = parenthesize(*assign.right);
syn::Expr::Assign(assign)
}
syn::Expr::Call(mut call) => {
*call.func = parenthesize(*call.func);
for arg in call.args.iter_mut() {
*arg = parenthesize(arg.clone());
}
syn::Expr::Call(call)
}
syn::Expr::Return(mut ret) => {
if let Some(ref mut expr) = ret.expr {
*expr = Box::new(parenthesize(*expr.clone()));
}
syn::Expr::Return(ret)
}
syn::Expr::Block(mut b) => {
parenthesize_block(&mut b.block);
syn::Expr::Block(b)
}
syn::Expr::Reference(mut r) => {
r.expr = Box::new(parenthesize(*r.expr));
syn::Expr::Reference(r)
}
syn::Expr::Index(mut i) => {
i.expr = Box::new(parenthesize(*i.expr));
i.index = Box::new(parenthesize(*i.index));
syn::Expr::Index(i)
}
other => other,
}
}
/// Parenthesizes one operand of a binary expression.
///
/// `child` is rewritten recursively and then wrapped when any of these hold:
///
/// * it binds more loosely than the parent operator;
/// * it has the same precedence and sits on the right (the operators are
///   treated as left-associative);
/// * it has the same precedence as a comparison parent, on either side:
///   comparisons cannot be chained, so `(a == b) == c` must keep its
///   parentheses;
/// * it is a left operand that begins with a block-like expression
///   (see `starts_with_block`).
fn parenthesize_child(child: syn::Expr, parent_prec: u8, is_left: bool) -> syn::Expr {
    // The value `syn_binop_precedence` assigns to `<`, `>`, `<=`, `>=`, `==`
    // and `!=`; no other expression kind uses it.
    const COMPARISON_PREC: u8 = 40;

    let child = parenthesize(child);
    let child_prec = expr_precedence(&child);

    let same_level_conflict =
        child_prec == parent_prec && (!is_left || parent_prec == COMPARISON_PREC);
    let needs_parens = child_prec < parent_prec
        || same_level_conflict
        || (is_left && starts_with_block(&child));

    if needs_parens {
        wrap_paren(child)
    } else {
        child
    }
}
fn starts_with_block(expr: &syn::Expr) -> bool {
matches!(expr,
syn::Expr::Block(_) | syn::Expr::Unsafe(_) | syn::Expr::If(_) |
syn::Expr::Match(_) | syn::Expr::Loop(_) | syn::Expr::While(_) |
syn::Expr::ForLoop(_)
)
}
/// Wraps `expr` in a parenthesized expression node with no attributes.
fn wrap_paren(expr: syn::Expr) -> syn::Expr {
    let wrapped = syn::ExprParen {
        attrs: Vec::new(),
        paren_token: Default::default(),
        expr: Box::new(expr),
    };
    syn::Expr::Paren(wrapped)
}
/// Applies `parenthesize` to every expression statement and every `let`
/// initializer in `block`. Other statement kinds are left untouched.
fn parenthesize_block(block: &mut syn::Block) {
    /// Moves the expression out of `slot`, rewrites it and stores it back.
    /// An empty verbatim expression briefly occupies the slot, which avoids
    /// deep-cloning the subtree at every nesting level.
    fn rewrite_in_place(slot: &mut syn::Expr) {
        let taken = std::mem::replace(slot, syn::Expr::Verbatim(Default::default()));
        *slot = parenthesize(taken);
    }

    for stmt in block.stmts.iter_mut() {
        match stmt {
            syn::Stmt::Expr(expr, _) => rewrite_in_place(expr),
            syn::Stmt::Local(local) => {
                if let Some(init) = local.init.as_mut() {
                    rewrite_in_place(&mut init.expr);
                }
            }
            _ => {}
        }
    }
}
pub fn to_syn_binop(op: BinOp) -> syn::BinOp {
match op {
BinOp::Add => syn::BinOp::Add(Default::default()),
BinOp::Sub => syn::BinOp::Sub(Default::default()),
BinOp::Mul => syn::BinOp::Mul(Default::default()),
BinOp::Div => syn::BinOp::Div(Default::default()),
BinOp::Mod => syn::BinOp::Rem(Default::default()),
BinOp::BitAnd => syn::BinOp::BitAnd(Default::default()),
BinOp::BitOr => syn::BinOp::BitOr(Default::default()),
BinOp::BitXor => syn::BinOp::BitXor(Default::default()),
BinOp::Shl => syn::BinOp::Shl(Default::default()),
BinOp::Shr => syn::BinOp::Shr(Default::default()),
BinOp::Eq => syn::BinOp::Eq(Default::default()),
BinOp::Ne => syn::BinOp::Ne(Default::default()),
BinOp::Lt => syn::BinOp::Lt(Default::default()),
BinOp::Gt => syn::BinOp::Gt(Default::default()),
BinOp::Le => syn::BinOp::Le(Default::default()),
BinOp::Ge => syn::BinOp::Ge(Default::default()),
BinOp::LogAnd => syn::BinOp::And(Default::default()),
BinOp::LogOr => syn::BinOp::Or(Default::default()),
}
}
/// Renders `expr` as a token-stream string after running a copy of it
/// through `parenthesize`.
pub fn expr_to_string(expr: &syn::Expr) -> String {
    parenthesize(expr.clone()).to_token_stream().to_string()
}
/// Builds an identifier from `name` without panicking on awkward input.
///
/// * A leading `r#` requests a raw identifier; the remainder is validated
///   like any other name.
/// * Names that are empty, start with an ASCII digit, or contain characters
///   other than alphanumerics and `_` yield the placeholder
///   `__invalid_ident__`.
/// * Keywords (see `is_rust_keyword`) are emitted as raw identifiers, except
///   `self`, `Self`, `super`, `crate` and `_`: those cannot be written with
///   `r#` (constructing them as raw identifiers panics), so they are emitted
///   as plain identifiers instead.
pub fn ident(name: &str) -> syn::Ident {
    let span = Span::call_site();

    let (base, raw_requested) = match name.strip_prefix("r#") {
        Some(rest) => (rest, true),
        None => (name, false),
    };

    let well_formed = !base.is_empty()
        && !base.starts_with(|c: char| c.is_ascii_digit())
        && base.chars().all(|c| c.is_alphanumeric() || c == '_');
    if !well_formed {
        return syn::Ident::new("__invalid_ident__", span);
    }

    // These are rejected by `Ident::new_raw`.
    if matches!(base, "self" | "Self" | "super" | "crate" | "_") {
        return syn::Ident::new(base, span);
    }

    if raw_requested || is_rust_keyword(base) {
        syn::Ident::new_raw(base, span)
    } else {
        syn::Ident::new(base, span)
    }
}
/// Reports whether `name` exactly matches (case-sensitively) one of the
/// keywords in this module's list.
fn is_rust_keyword(name: &str) -> bool {
    const KEYWORDS: &[&str] = &[
        "as", "break", "const", "continue", "crate", "else", "enum", "extern", "false", "fn",
        "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", "ref",
        "return", "self", "Self", "static", "struct", "super", "trait", "true", "type", "unsafe",
        "use", "where", "while", "async", "await", "dyn", "abstract", "become", "box", "do",
        "final", "macro", "override", "priv", "typeof", "unsized", "virtual", "yield", "gen",
        "try",
    ];
    KEYWORDS.contains(&name)
}
/// Rewrites the expression source `s` so that it carries only the
/// parentheses `parenthesize` considers necessary.
///
/// The input is parsed, stripped with `strip_all_parens`, re-parenthesized
/// and rendered through `pretty_expr`. If that rendering is empty or spans
/// several lines, the raw token string is returned instead. Input that does
/// not parse as an expression (or that renders to nothing) goes through
/// `fallback_strip_outer_parens`.
pub fn normalize_parens(s: &str) -> String {
    let parsed = match syn::parse_str::<syn::Expr>(s) {
        Ok(expr) => expr,
        Err(_) => return fallback_strip_outer_parens(s),
    };

    let normalized = parenthesize(strip_all_parens(parsed));

    let pretty = pretty_expr(&normalized);
    if !(pretty.is_empty() || pretty.contains('\n')) {
        return pretty;
    }

    let tokens = quote::quote! { #normalized }.to_string();
    if tokens.is_empty() {
        fallback_strip_outer_parens(s)
    } else {
        tokens
    }
}
/// Text-level fallback that removes one pair of parentheses enclosing the
/// whole (trimmed) string.
///
/// The trimmed input is returned unchanged when:
///
/// * it is not of the form `( ... )`;
/// * the first `(` and the last `)` do not belong together (`(a) + (b)`);
/// * the content is empty (`()` is the unit value, not a grouping);
/// * the content starts with `{`;
/// * the content has a comma outside any nested bracket, i.e. it looks like
///   a tuple such as `(a, b)` or `(a,)`.
///
/// Brackets are counted textually; string and char literals are not
/// recognised, and a comma inside `<...>` counts as top-level, which only
/// makes the function keep the parentheses.
fn fallback_strip_outer_parens(s: &str) -> String {
    let s = s.trim();
    let inner = match s
        .strip_prefix('(')
        .and_then(|rest| rest.strip_suffix(')'))
    {
        Some(inner) => inner.trim(),
        None => return s.to_string(),
    };
    if inner.is_empty() || inner.starts_with('{') {
        return s.to_string();
    }

    let mut depth = 0i32;
    for ch in inner.chars() {
        match ch {
            '(' | '{' | '[' => depth += 1,
            ')' | '}' | ']' => {
                depth -= 1;
                // The opening parenthesis was closed before the end.
                if depth < 0 {
                    return s.to_string();
                }
            }
            // A top-level comma means the parentheses delimit a tuple.
            ',' if depth == 0 => return s.to_string(),
            _ => {}
        }
    }

    if depth == 0 {
        inner.to_string()
    } else {
        s.to_string()
    }
}
pub fn strip_all_parens(expr: syn::Expr) -> syn::Expr {
match expr {
syn::Expr::Paren(p) => strip_all_parens(*p.expr),
syn::Expr::Binary(mut b) => {
*b.left = strip_all_parens(*b.left);
*b.right = strip_all_parens(*b.right);
syn::Expr::Binary(b)
}
syn::Expr::Unary(mut u) => {
*u.expr = strip_all_parens(*u.expr);
syn::Expr::Unary(u)
}
syn::Expr::Cast(mut c) => {
*c.expr = strip_all_parens(*c.expr);
syn::Expr::Cast(c)
}
syn::Expr::Field(mut f) => {
*f.base = strip_all_parens(*f.base);
syn::Expr::Field(f)
}
syn::Expr::MethodCall(mut m) => {
*m.receiver = strip_all_parens(*m.receiver);
for arg in m.args.iter_mut() {
*arg = strip_all_parens(arg.clone());
}
syn::Expr::MethodCall(m)
}
syn::Expr::Call(mut c) => {
*c.func = strip_all_parens(*c.func);
for arg in c.args.iter_mut() {
*arg = strip_all_parens(arg.clone());
}
syn::Expr::Call(c)
}
syn::Expr::If(mut i) => {
*i.cond = strip_all_parens(*i.cond);
strip_parens_in_block(&mut i.then_branch);
if let Some((_, ref mut else_branch)) = i.else_branch {
*else_branch = Box::new(strip_all_parens(*else_branch.clone()));
}
syn::Expr::If(i)
}
syn::Expr::Index(mut i) => {
*i.expr = strip_all_parens(*i.expr);
*i.index = strip_all_parens(*i.index);
syn::Expr::Index(i)
}
syn::Expr::Assign(mut a) => {
*a.left = strip_all_parens(*a.left);
*a.right = strip_all_parens(*a.right);
syn::Expr::Assign(a)
}
syn::Expr::Return(mut r) => {
if let Some(ref mut e) = r.expr {
*e = Box::new(strip_all_parens(*e.clone()));
}
syn::Expr::Return(r)
}
syn::Expr::Block(mut b) => {
strip_parens_in_block(&mut b.block);
syn::Expr::Block(b)
}
syn::Expr::Reference(mut r) => {
*r.expr = strip_all_parens(*r.expr);
syn::Expr::Reference(r)
}
syn::Expr::Unsafe(mut u) => {
strip_parens_in_block(&mut u.block);
syn::Expr::Unsafe(u)
}
other => other,
}
}
/// Applies `strip_all_parens` to every expression statement and every `let`
/// initializer in `block`. Other statement kinds are left untouched.
fn strip_parens_in_block(block: &mut syn::Block) {
    /// Moves the expression out of `slot`, strips it and stores it back.
    /// An empty verbatim expression briefly occupies the slot, which avoids
    /// deep-cloning the subtree at every nesting level.
    fn strip_in_place(slot: &mut syn::Expr) {
        let taken = std::mem::replace(slot, syn::Expr::Verbatim(Default::default()));
        *slot = strip_all_parens(taken);
    }

    for stmt in block.stmts.iter_mut() {
        match stmt {
            syn::Stmt::Expr(e, _) => strip_in_place(e),
            syn::Stmt::Local(l) => {
                if let Some(init) = l.init.as_mut() {
                    strip_in_place(&mut init.expr);
                }
            }
            _ => {}
        }
    }
}
/// Formats `expr` with `prettyplease` by embedding it as the body of a
/// throwaway function and extracting that body from the formatted output.
///
/// If the wrapper fails to parse as a file, the plain token string of `expr`
/// is returned instead.
fn pretty_expr(expr: &syn::Expr) -> String {
    let wrapper = quote::quote! {
        fn __() -> __T {
            #expr
        }
    };

    match syn::parse2::<syn::File>(wrapper) {
        Ok(file) => extract_fn_body(&prettyplease::unparse(&file)),
        Err(_) => expr.to_token_stream().to_string(),
    }
}
/// Returns the lines between the first and last line of `formatted`, with one
/// level of indentation removed from each.
///
/// `formatted` is expected to be a single formatted function: a signature
/// line, the body, and a closing-brace line. Body lines are indented by four
/// spaces, so that whole prefix is removed; deeper nesting keeps its relative
/// indentation and lines without the prefix are kept as they are. Input with
/// fewer than three lines is returned unchanged.
fn extract_fn_body(formatted: &str) -> String {
    // One indentation level in the formatted output.
    const INDENT: &str = "    ";

    let lines: Vec<&str> = formatted.lines().collect();
    if lines.len() < 3 {
        return formatted.to_string();
    }

    lines[1..lines.len() - 1]
        .iter()
        .map(|&line| line.strip_prefix(INDENT).unwrap_or(line))
        .collect::<Vec<&str>>()
        .join("\n")
}
pub fn is_bool_syn_expr(expr: &syn::Expr) -> bool {
match expr {
syn::Expr::Binary(b) => matches!(b.op,
syn::BinOp::Eq(_) | syn::BinOp::Ne(_) |
syn::BinOp::Lt(_) | syn::BinOp::Gt(_) |
syn::BinOp::Le(_) | syn::BinOp::Ge(_) |
syn::BinOp::And(_) | syn::BinOp::Or(_)
),
syn::Expr::Unary(u) => matches!(u.op, syn::UnOp::Not(_)) && is_bool_syn_expr(&u.expr),
syn::Expr::Lit(lit) => matches!(lit.lit, syn::Lit::Bool(_)),
syn::Expr::Paren(p) => is_bool_syn_expr(&p.expr),
syn::Expr::MethodCall(mc) => mc.method == "is_null",
_ => false,
}
}
/// Syntactic heuristic for expressions that evaluate to a raw pointer.
///
/// Returns `true` for:
///
/// * casts whose target type, printed as tokens, contains `* mut` or
///   `* const`;
/// * calls to a method named `offset`, `wrapping_add`, `wrapping_sub`,
///   `as_ptr` or `as_mut_ptr`;
/// * calls whose callee is a path ending in exactly `null` or `null_mut`
///   (for example `std::ptr::null_mut()`); callees that merely contain
///   "null" in their name do not count;
/// * any of the above wrapped in parentheses.
pub fn looks_like_pointer(expr: &syn::Expr) -> bool {
    match expr {
        syn::Expr::Cast(cast) => {
            let ty_str = cast.ty.to_token_stream().to_string();
            ty_str.contains("* mut") || ty_str.contains("* const")
        }
        syn::Expr::MethodCall(mc) => {
            let method = mc.method.to_string();
            matches!(
                method.as_str(),
                "offset" | "wrapping_add" | "wrapping_sub" | "as_ptr" | "as_mut_ptr"
            )
        }
        syn::Expr::Call(call) => match &*call.func {
            syn::Expr::Path(callee) => callee
                .path
                .segments
                .last()
                .map_or(false, |segment| {
                    segment.ident == "null" || segment.ident == "null_mut"
                }),
            _ => false,
        },
        syn::Expr::Paren(p) => looks_like_pointer(&p.expr),
        _ => false,
    }
}
pub fn wrap_as_bool(expr: syn::Expr) -> syn::Expr {
if is_bool_syn_expr(&expr) {
return expr;
}
if looks_like_pointer(&expr) {
let is_null_call = syn::Expr::MethodCall(syn::ExprMethodCall {
attrs: vec![],
receiver: Box::new(expr),
dot_token: Default::default(),
method: ident("is_null"),
turbofish: None,
paren_token: Default::default(),
args: syn::punctuated::Punctuated::new(),
});
return syn::Expr::Unary(syn::ExprUnary {
attrs: vec![],
op: syn::UnOp::Not(Default::default()),
expr: Box::new(is_null_call),
});
}
syn::Expr::Binary(syn::ExprBinary {
attrs: vec![],
left: Box::new(expr),
op: syn::BinOp::Ne(Default::default()),
right: Box::new(int_lit(0)),
})
}
/// Builds an unsuffixed integer literal expression for `n`.
///
/// Non-negative values produce a plain literal. Negative values produce a
/// unary negation applied to the literal magnitude, the same shape the parser
/// gives `-1`. This lets `parenthesize` see the minus sign: used as a method
/// receiver, `int_lit(-1)` prints as `(-1).f()` rather than `-1.f()`, which
/// would mean `-(1.f())`. The printed tokens of the value itself are
/// unchanged.
pub fn int_lit(n: i64) -> syn::Expr {
    // `unsigned_abs` also covers `i64::MIN`, whose magnitude does not fit in
    // an `i64`.
    let digits = n.unsigned_abs().to_string();
    let magnitude = syn::Expr::Lit(syn::ExprLit {
        attrs: vec![],
        lit: syn::Lit::Int(syn::LitInt::new(&digits, Span::call_site())),
    });

    if n < 0 {
        syn::Expr::Unary(syn::ExprUnary {
            attrs: vec![],
            op: syn::UnOp::Neg(Default::default()),
            expr: Box::new(magnitude),
        })
    } else {
        magnitude
    }
}
/// Builds `expr as <type>`, where the type is obtained by passing `ty_str`
/// to `parse_type`.
pub fn cast_syn_expr(expr: syn::Expr, ty_str: &str) -> syn::Expr {
    let target = parse_type(ty_str);
    insert_cast(expr, target)
}
/// Builds the cast expression `expr as ty`. No parentheses are added around
/// `expr`.
pub fn insert_cast(expr: syn::Expr, ty: syn::Type) -> syn::Expr {
    let cast = syn::ExprCast {
        attrs: Vec::new(),
        expr: Box::new(expr),
        as_token: Default::default(),
        ty: Box::new(ty),
    };
    syn::Expr::Cast(cast)
}
/// Parses `ty_str` as a type, falling back to the type `c_int` when the
/// string is not a valid type.
pub fn parse_type(ty_str: &str) -> syn::Type {
    match syn::parse_str::<syn::Type>(ty_str) {
        Ok(ty) => ty,
        // The fallback text is a constant that always parses.
        Err(_) => syn::parse_str("c_int").unwrap(),
    }
}
/// Returns a null/zero value expression for the type written in `ty_str`.
///
/// * Types whose outermost form is `*const ...` yield `std::ptr::null()`.
/// * Any other type containing `*` yields `std::ptr::null_mut()`.
/// * Everything else yields the integer literal `0`.
///
/// Whitespace is ignored, so both `*const T` and the token-stream spelling
/// `* const T` are recognised. Only the outermost pointer decides
/// mutability: `*mut *const T` yields `null_mut()`.
pub fn null_for_type(ty_str: &str) -> syn::Expr {
    let compact: String = ty_str.split_whitespace().collect();

    if compact.starts_with("*const") {
        syn::parse_str("std::ptr::null()").unwrap()
    } else if compact.contains('*') {
        syn::parse_str("std::ptr::null_mut()").unwrap()
    } else {
        int_lit(0)
    }
}
pub fn as_ptr(expr: syn::Expr) -> syn::Expr {
syn::Expr::MethodCall(syn::ExprMethodCall {
attrs: vec![],
receiver: Box::new(expr),
dot_token: Default::default(),
method: ident("as_ptr"),
turbofish: None,
paren_token: Default::default(),
args: syn::punctuated::Punctuated::new(),
})
}
/// Builds the named field access `expr.field_name`; the field name goes
/// through `ident`.
pub fn field_access(expr: syn::Expr, field_name: &str) -> syn::Expr {
    let member = syn::Member::Named(ident(field_name));
    let access = syn::ExprField {
        attrs: Vec::new(),
        base: Box::new(expr),
        dot_token: Default::default(),
        member,
    };
    syn::Expr::Field(access)
}
/// Builds the dereference `*expr`.
pub fn deref(expr: syn::Expr) -> syn::Expr {
    let star = syn::UnOp::Deref(Default::default());
    syn::Expr::Unary(syn::ExprUnary {
        attrs: Vec::new(),
        op: star,
        expr: Box::new(expr),
    })
}
/// Builds the raw mutable borrow `&raw mut expr`.
pub fn addr_of_mut(expr: syn::Expr) -> syn::Expr {
    let mutability = syn::PointerMutability::Mut(Default::default());
    let raw_borrow = syn::ExprRawAddr {
        attrs: Vec::new(),
        and_token: Default::default(),
        raw: Default::default(),
        mutability,
        expr: Box::new(expr),
    };
    syn::Expr::RawAddr(raw_borrow)
}
/// Builds the call `func_name(args...)`, where the callee is a
/// single-segment path produced by `ident`.
pub fn call(func_name: &str, args: Vec<syn::Expr>) -> syn::Expr {
    let callee = syn::Expr::Path(syn::ExprPath {
        attrs: Vec::new(),
        qself: None,
        path: syn::Path::from(ident(func_name)),
    });

    syn::Expr::Call(syn::ExprCall {
        attrs: Vec::new(),
        func: Box::new(callee),
        paren_token: Default::default(),
        args: args.into_iter().collect(),
    })
}
/// Builds a path expression consisting of the single identifier `name`
/// (as produced by `ident`).
pub fn ident_expr(name: &str) -> syn::Expr {
    let path = syn::Path::from(ident(name));
    syn::Expr::Path(syn::ExprPath {
        attrs: Vec::new(),
        qself: None,
        path,
    })
}
/// Builds the method call `receiver.method(args...)` without a turbofish.
/// The method name goes through `ident`; no parentheses are added around the
/// receiver.
pub fn method_call(receiver: syn::Expr, method: &str, args: Vec<syn::Expr>) -> syn::Expr {
    let invocation = syn::ExprMethodCall {
        attrs: Vec::new(),
        receiver: Box::new(receiver),
        dot_token: Default::default(),
        method: ident(method),
        turbofish: None,
        paren_token: Default::default(),
        args: args.into_iter().collect(),
    };
    syn::Expr::MethodCall(invocation)
}
/// Builds the plain assignment `lhs = rhs`.
pub fn assign_expr(lhs: syn::Expr, rhs: syn::Expr) -> syn::Expr {
    let assignment = syn::ExprAssign {
        attrs: Vec::new(),
        left: Box::new(lhs),
        eq_token: Default::default(),
        right: Box::new(rhs),
    };
    syn::Expr::Assign(assignment)
}
/// Builds the binary expression `lhs <op> rhs` using the operator supplied by
/// the caller (for example one returned by `c_assign_op_to_syn_compound`).
pub fn assign_op_expr(lhs: syn::Expr, op: syn::BinOp, rhs: syn::Expr) -> syn::Expr {
    let (left, right) = (Box::new(lhs), Box::new(rhs));
    syn::Expr::Binary(syn::ExprBinary {
        attrs: Vec::new(),
        left,
        op,
        right,
    })
}
/// Turns `expr` into an expression statement terminated by a semicolon.
pub fn semi_stmt(expr: syn::Expr) -> syn::Stmt {
    let semicolon = Some(Default::default());
    syn::Stmt::Expr(expr, semicolon)
}
/// Builds the statement `let name = value;` with an immutable, by-value
/// identifier pattern and no `else` branch.
pub fn let_stmt(name: &str, value: syn::Expr) -> syn::Stmt {
    let init = syn::LocalInit {
        eq_token: Default::default(),
        expr: Box::new(value),
        diverge: None,
    };
    let binding = syn::PatIdent {
        attrs: Vec::new(),
        by_ref: None,
        mutability: None,
        ident: ident(name),
        subpat: None,
    };

    syn::Stmt::Local(syn::Local {
        attrs: Vec::new(),
        let_token: Default::default(),
        pat: syn::Pat::Ident(binding),
        init: Some(init),
        semi_token: Default::default(),
    })
}
/// Builds the unlabeled block expression `{ stmts...; value }`, where `value`
/// is appended as a trailing expression without a semicolon.
pub fn block_with_value(stmts: Vec<syn::Stmt>, value: syn::Expr) -> syn::Expr {
    let tail = syn::Stmt::Expr(value, None);
    let body: Vec<syn::Stmt> = stmts.into_iter().chain(std::iter::once(tail)).collect();

    syn::Expr::Block(syn::ExprBlock {
        attrs: Vec::new(),
        label: None,
        block: syn::Block {
            brace_token: Default::default(),
            stmts: body,
        },
    })
}
pub fn c_assign_op_to_syn_compound(op: crate::ast::AssignOp) -> Option<syn::BinOp> {
use crate::ast::AssignOp;
Some(match op {
AssignOp::Assign => return None,
AssignOp::AddAssign => syn::BinOp::AddAssign(Default::default()),
AssignOp::SubAssign => syn::BinOp::SubAssign(Default::default()),
AssignOp::MulAssign => syn::BinOp::MulAssign(Default::default()),
AssignOp::DivAssign => syn::BinOp::DivAssign(Default::default()),
AssignOp::ModAssign => syn::BinOp::RemAssign(Default::default()),
AssignOp::AndAssign => syn::BinOp::BitAndAssign(Default::default()),
AssignOp::OrAssign => syn::BinOp::BitOrAssign(Default::default()),
AssignOp::XorAssign => syn::BinOp::BitXorAssign(Default::default()),
AssignOp::ShlAssign => syn::BinOp::ShlAssign(Default::default()),
AssignOp::ShrAssign => syn::BinOp::ShrAssign(Default::default()),
})
}
pub fn if_else(cond: syn::Expr, then_expr: syn::Expr, else_expr: syn::Expr) -> syn::Expr {
syn::Expr::If(syn::ExprIf {
attrs: vec![],
if_token: Default::default(),
cond: Box::new(cond),
then_branch: syn::Block {
brace_token: Default::default(),
stmts: vec![syn::Stmt::Expr(then_expr, None)],
},
else_branch: Some((
Default::default(),
Box::new(syn::Expr::Block(syn::ExprBlock {
attrs: vec![],
label: None,
block: syn::Block {
brace_token: Default::default(),
stmts: vec![syn::Stmt::Expr(else_expr, None)],
},
})),
)),
})
}
// Unit tests for the expression helpers in this file.
//
// Tests that go through `expr_to_string` compare against raw token-stream
// text (tokens separated by spaces, e.g. "(* a) . field"); tests that go
// through `normalize_parens` compare against the formatted spelling
// (e.g. "(*a).field").
#[cfg(test)]
mod tests {
use super::*;
use syn::parse_quote;
// A tighter-binding right operand needs no parentheses; a hand-built
// `(a + b) * c` tree must get them around the addition.
#[test]
fn test_parenthesize_binary_precedence() {
let expr: syn::Expr = parse_quote!(a + b * c);
let result = expr_to_string(&expr);
assert_eq!(result, "a + b * c");
let a: syn::Expr = parse_quote!(a);
let b: syn::Expr = parse_quote!(b);
let c: syn::Expr = parse_quote!(c);
let add: syn::Expr = parse_quote!(#a + #b);
// Built by hand so the tree contains no Paren node around `a + b`.
let mul = syn::Expr::Binary(syn::ExprBinary {
attrs: vec![],
left: Box::new(add),
op: syn::BinOp::Mul(Default::default()),
right: Box::new(c),
});
let result = expr_to_string(&mul);
assert_eq!(result, "(a + b) * c");
}
// A binary operand of a cast must be wrapped.
#[test]
fn test_parenthesize_cast() {
let a: syn::Expr = parse_quote!(a);
let mask: syn::Expr = parse_quote!(MASK);
let bitand = syn::Expr::Binary(syn::ExprBinary {
attrs: vec![],
left: Box::new(a),
op: syn::BinOp::BitAnd(Default::default()),
right: Box::new(mask),
});
let cast = syn::Expr::Cast(syn::ExprCast {
attrs: vec![],
expr: Box::new(bitand),
as_token: Default::default(),
ty: Box::new(parse_quote!(u32)),
});
let result = expr_to_string(&cast);
assert_eq!(result, "(a & MASK) as u32");
}
// An `if` expression as the left operand of `!=`; only checks that the
// `if` survives in the output, not the exact parenthesization.
#[test]
fn test_parenthesize_if_ne() {
let if_expr: syn::Expr = parse_quote!(if cond { A } else { B });
let ne = syn::Expr::Binary(syn::ExprBinary {
attrs: vec![],
left: Box::new(if_expr),
op: syn::BinOp::Ne(Default::default()),
right: Box::new(parse_quote!(0)),
});
let result = expr_to_string(&ne);
assert!(result.contains("if cond"));
}
// A dereference used as a field base must be wrapped.
#[test]
fn test_deref_field() {
let a: syn::Expr = parse_quote!(a);
let deref = syn::Expr::Unary(syn::ExprUnary {
attrs: vec![],
op: syn::UnOp::Deref(Default::default()),
expr: Box::new(a),
});
let field = syn::Expr::Field(syn::ExprField {
attrs: vec![],
base: Box::new(deref),
dot_token: Default::default(),
member: syn::Member::Named(ident("field")),
});
let result = expr_to_string(&field);
assert_eq!(result, "(* a) . field");
}
// Keywords are emitted as raw identifiers.
#[test]
fn test_ident_keyword() {
let i = ident("type");
assert_eq!(i.to_string(), "r#type");
}
// The next group checks that `normalize_parens` drops redundant outer
// parentheses.
#[test]
fn test_normalize_cast_removes_outer_parens() {
assert_eq!(normalize_parens("(x as i32)"), "x as i32");
}
#[test]
fn test_normalize_deref_removes_outer_parens() {
assert_eq!(normalize_parens("(*ptr)"), "*ptr");
}
#[test]
fn test_normalize_addr_of_removes_outer_parens() {
assert_eq!(normalize_parens("(&mut x)"), "&mut x");
}
#[test]
fn test_normalize_binary_removes_outer_parens() {
assert_eq!(normalize_parens("(a + b)"), "a + b");
}
// Parentheses required by precedence must survive normalization.
#[test]
fn test_normalize_deref_field_preserves_needed_parens() {
assert_eq!(normalize_parens("(*a).field"), "(*a).field");
}
#[test]
fn test_normalize_cast_in_binary_preserves_needed_parens() {
assert_eq!(normalize_parens("(a & MASK) as u32"), "(a & MASK) as u32");
}
// Doubled redundant parentheses are removed entirely.
#[test]
fn test_normalize_nested_unnecessary_parens() {
assert_eq!(normalize_parens("((x as i32))"), "x as i32");
}
#[test]
fn test_normalize_preserves_precedence() {
assert_eq!(normalize_parens("(a + b) * c"), "(a + b) * c");
}
// Inputs without parentheses to remove come back unchanged.
#[test]
fn test_normalize_no_change_needed() {
assert_eq!(normalize_parens("x"), "x");
assert_eq!(normalize_parens("42"), "42");
assert_eq!(normalize_parens("foo(a, b)"), "foo(a, b)");
}
// Parentheses around a simple method receiver are redundant.
#[test]
fn test_normalize_method_call() {
assert_eq!(normalize_parens("(ptr).is_null()"), "ptr.is_null()");
}
#[test]
fn test_normalize_logical_ops() {
assert_eq!(normalize_parens("(a && b)"), "a && b");
assert_eq!(normalize_parens("(a || b)"), "a || b");
}
// Outer parentheses go away while the needed inner pair stays.
#[test]
fn test_normalize_complex_nested() {
assert_eq!(
normalize_parens("((*sv).sv_flags as u32)"),
"(*sv).sv_flags as u32"
);
}
#[test]
fn test_normalize_unary_minus() {
assert_eq!(normalize_parens("(-x)"), "-x");
}
#[test]
fn test_normalize_not() {
assert_eq!(normalize_parens("(!cond)"), "!cond");
}
// Block input: accepts either the unchanged text or any single-line result.
#[test]
fn test_normalize_block_expr_passthrough() {
let s = "{ x += 1; x }";
let result = normalize_parens(s);
assert!(result == s || !result.contains('\n'));
}
// `is_bool_syn_expr`: comparison operators count as boolean.
#[test]
fn test_is_bool_syn_expr_comparison() {
let expr: syn::Expr = parse_quote!(a == b);
assert!(is_bool_syn_expr(&expr));
let expr: syn::Expr = parse_quote!(a != 0);
assert!(is_bool_syn_expr(&expr));
let expr: syn::Expr = parse_quote!(a < b);
assert!(is_bool_syn_expr(&expr));
}
// `&&` and `||` count as boolean.
#[test]
fn test_is_bool_syn_expr_logical() {
let expr: syn::Expr = parse_quote!(a && b);
assert!(is_bool_syn_expr(&expr));
let expr: syn::Expr = parse_quote!(a || b);
assert!(is_bool_syn_expr(&expr));
}
// `!e` is boolean only when `e` is; a bare `!x` is not.
#[test]
fn test_is_bool_syn_expr_not() {
let expr: syn::Expr = parse_quote!(!(a == b));
assert!(is_bool_syn_expr(&expr));
let expr: syn::Expr = parse_quote!(!x);
assert!(!is_bool_syn_expr(&expr));
}
// Arithmetic, integer literals and ordinary calls are not boolean.
#[test]
fn test_is_bool_syn_expr_non_bool() {
let expr: syn::Expr = parse_quote!(a + b);
assert!(!is_bool_syn_expr(&expr));
let expr: syn::Expr = parse_quote!(42);
assert!(!is_bool_syn_expr(&expr));
let expr: syn::Expr = parse_quote!(foo(x));
assert!(!is_bool_syn_expr(&expr));
}
#[test]
fn test_is_bool_syn_expr_bool_lit() {
let expr: syn::Expr = parse_quote!(true);
assert!(is_bool_syn_expr(&expr));
let expr: syn::Expr = parse_quote!(false);
assert!(is_bool_syn_expr(&expr));
}
// A call to a method named `is_null` counts as boolean.
#[test]
fn test_is_bool_syn_expr_is_null() {
let expr: syn::Expr = parse_quote!(ptr.is_null());
assert!(is_bool_syn_expr(&expr));
}
// Parentheses are looked through.
#[test]
fn test_is_bool_syn_expr_paren() {
let expr: syn::Expr = parse_quote!((a == b));
assert!(is_bool_syn_expr(&expr));
}
// `looks_like_pointer`: casts to raw pointer types qualify, other casts do not.
#[test]
fn test_looks_like_pointer_cast() {
let expr: syn::Expr = parse_quote!(x as *mut i32);
assert!(looks_like_pointer(&expr));
let expr: syn::Expr = parse_quote!(x as *const u8);
assert!(looks_like_pointer(&expr));
let expr: syn::Expr = parse_quote!(x as i32);
assert!(!looks_like_pointer(&expr));
}
// Calls to the listed pointer-style methods qualify.
#[test]
fn test_looks_like_pointer_method() {
let expr: syn::Expr = parse_quote!(p.offset(1));
assert!(looks_like_pointer(&expr));
let expr: syn::Expr = parse_quote!(p.wrapping_add(n));
assert!(looks_like_pointer(&expr));
let expr: syn::Expr = parse_quote!(arr.as_ptr());
assert!(looks_like_pointer(&expr));
}
#[test]
fn test_looks_like_pointer_null() {
let expr: syn::Expr = parse_quote!(std::ptr::null_mut());
assert!(looks_like_pointer(&expr));
}
// `wrap_as_bool`: an expression that is already boolean is returned as is.
#[test]
fn test_wrap_as_bool_already_bool() {
let expr: syn::Expr = parse_quote!(a == b);
let result = wrap_as_bool(expr);
let s = expr_to_string(&result);
assert_eq!(s, "a == b");
}
// A plain value is compared against zero.
#[test]
fn test_wrap_as_bool_integer() {
let expr: syn::Expr = parse_quote!(x);
let result = wrap_as_bool(expr);
let s = expr_to_string(&result);
assert_eq!(s, "x != 0");
}
// A pointer-looking value gets an `is_null` check.
#[test]
fn test_wrap_as_bool_pointer() {
let expr: syn::Expr = parse_quote!(p as *mut i32);
let result = wrap_as_bool(expr);
let s = expr_to_string(&result);
assert!(s.contains("is_null"), "expected is_null in: {}", s);
}
// Negative values print as a separate minus token followed by the digits.
#[test]
fn test_int_lit() {
let expr = int_lit(42);
let s = expr_to_string(&expr);
assert_eq!(s, "42");
let expr = int_lit(0);
let s = expr_to_string(&expr);
assert_eq!(s, "0");
let expr = int_lit(-1);
let s = expr_to_string(&expr);
assert_eq!(s, "- 1"); }
#[test]
fn test_insert_cast() {
let expr: syn::Expr = parse_quote!(x);
let ty = parse_type("u32");
let result = insert_cast(expr, ty);
let s = expr_to_string(&result);
assert_eq!(s, "x as u32");
}
// The parentheses come from `expr_to_string`, not from `insert_cast`.
#[test]
fn test_insert_cast_complex_expr() {
let expr: syn::Expr = parse_quote!(a + b);
let ty = parse_type("i32");
let result = insert_cast(expr, ty);
let s = expr_to_string(&result);
assert_eq!(s, "(a + b) as i32");
}
#[test]
fn test_parse_type_basic() {
let ty = parse_type("i32");
assert_eq!(ty.to_token_stream().to_string(), "i32");
}
#[test]
fn test_parse_type_pointer() {
let ty = parse_type("*mut u8");
assert_eq!(ty.to_token_stream().to_string(), "* mut u8");
}
// Unparseable input falls back to `c_int`.
#[test]
fn test_parse_type_fallback() {
let ty = parse_type("not a valid type!!!");
assert_eq!(ty.to_token_stream().to_string(), "c_int");
}
#[test]
fn test_null_for_type_mut() {
let expr = null_for_type("*mut SV");
let s = expr_to_string(&expr);
assert!(s.contains("null_mut"), "expected null_mut in: {}", s);
}
// A const pointer must get `null`, not `null_mut`.
#[test]
fn test_null_for_type_const() {
let expr = null_for_type("*const c_char");
let s = expr_to_string(&expr);
assert!(s.contains("null"), "expected null in: {}", s);
assert!(!s.contains("null_mut"), "should not contain null_mut in: {}", s);
}
// Non-pointer types get the literal 0.
#[test]
fn test_null_for_type_non_pointer() {
let expr = null_for_type("i32");
let s = expr_to_string(&expr);
assert_eq!(s, "0");
}
#[test]
fn test_as_ptr() {
let expr: syn::Expr = parse_quote!(PL_Yes);
let result = as_ptr(expr);
let s = expr_to_string(&result);
assert!(s.contains("as_ptr"), "expected as_ptr in: {}", s);
assert!(s.contains("PL_Yes"), "expected PL_Yes in: {}", s);
}
#[test]
fn test_field_access() {
let expr: syn::Expr = parse_quote!(sv);
let result = field_access(expr, "sv_flags");
let s = expr_to_string(&result);
assert_eq!(s, "sv . sv_flags");
}
#[test]
fn test_deref_simple() {
let expr: syn::Expr = parse_quote!(ptr);
let result = deref(expr);
let s = expr_to_string(&result);
assert_eq!(s, "* ptr");
}
// Same shape as `test_deref_field`, built with the helper constructors.
#[test]
fn test_deref_field_parenthesized() {
let ptr: syn::Expr = parse_quote!(ptr);
let d = deref(ptr);
let f = field_access(d, "field");
let s = expr_to_string(&f);
assert_eq!(s, "(* ptr) . field");
}
#[test]
fn test_addr_of_mut() {
let expr: syn::Expr = parse_quote!(x);
let result = addr_of_mut(expr);
let s = expr_to_string(&result);
assert_eq!(s, "& raw mut x");
}
// Spaces are removed before comparing so token spacing does not matter.
#[test]
fn test_call_no_args() {
let result = call("foo", vec![]);
let s = expr_to_string(&result);
assert!(s.contains("foo"), "expected foo in: {}", s);
let normalized = s.replace(' ', "");
assert_eq!(normalized, "foo()");
}
#[test]
fn test_call_with_args() {
let a: syn::Expr = parse_quote!(x);
let b: syn::Expr = parse_quote!(y);
let result = call("bar", vec![a, b]);
let s = expr_to_string(&result);
let normalized = s.replace(' ', "");
assert_eq!(normalized, "bar(x,y)");
}
// Only checks that both keywords appear in the output.
#[test]
fn test_if_else() {
let cond: syn::Expr = parse_quote!(x > 0);
let then_expr: syn::Expr = parse_quote!(a);
let else_expr: syn::Expr = parse_quote!(b);
let result = if_else(cond, then_expr, else_expr);
let s = expr_to_string(&result);
assert!(s.contains("if"), "expected if in: {}", s);
assert!(s.contains("else"), "expected else in: {}", s);
}
// `+` binds tighter than `!=`, so no parentheses are needed.
#[test]
fn test_wrap_as_bool_with_binary() {
let expr: syn::Expr = parse_quote!(a + b);
let result = wrap_as_bool(expr);
let s = expr_to_string(&result);
assert_eq!(s, "a + b != 0");
}
// A cast binds tighter than `!=`, so no parentheses are needed.
#[test]
fn test_combined_cast_and_bool() {
let x: syn::Expr = parse_quote!(flags);
let cast = insert_cast(x, parse_type("u32"));
let bool_expr = wrap_as_bool(cast);
let s = expr_to_string(&bool_expr);
assert_eq!(s, "flags as u32 != 0");
}
}