use rustc::ty::{self, ParamEnv, TyKind};
use syntax::ast::*;
use syntax::token;
use syntax::ptr::P;
use syntax_pos::Symbol;
use crate::command::{CommandState, Registry};
use crate::driver::Phase;
use crate::matcher::{mut_visit_match_with, replace_expr, MatchCtxt};
use crate::transform::Transform;
use crate::RefactorCtxt;
use c2rust_ast_builder::mk;
#[cfg(test)]
mod tests;
pub struct RemoveRedundantCasts;
impl Transform for RemoveRedundantCasts {
fn transform(&self, krate: &mut Crate, st: &CommandState, cx: &RefactorCtxt) {
let tcx = cx.ty_ctxt();
let mut mcx = MatchCtxt::new(st, cx);
let pat = mcx.parse_expr("$oe:Expr as $ot:Ty");
mut_visit_match_with(mcx, pat, krate, |ast, mcx| {
let oe = mcx.bindings.get::<_, P<Expr>>("$oe").unwrap();
let oe_ty = cx.node_type(oe.id);
let oe_ty = tcx.normalize_erasing_regions(ParamEnv::empty(), oe_ty);
let ot = mcx.bindings.get::<_, P<Ty>>("$ot").unwrap();
let ot_ty = cx.node_type(ot.id);
let ot_ty = tcx.normalize_erasing_regions(ParamEnv::empty(), ot_ty);
debug!("checking cast: {:?}, types: {:?} => {:?}",
ast, oe_ty, ot_ty);
let ast_mk = mk().id(ast.id).span(ast.span);
match oe.kind {
ExprKind::Cast(ref ie, ref it) => {
let ie_ty = cx.node_type(ie.id);
let ie_ty = tcx.normalize_erasing_regions(ParamEnv::empty(), ie_ty);
let it_ty = cx.node_type(it.id);
let it_ty = tcx.normalize_erasing_regions(ParamEnv::empty(), it_ty);
debug!("inner cast: {:?} => {:?}", ie_ty, it_ty);
match check_double_cast(ie_ty.into(), it_ty.into(), ot_ty.into()) {
DoubleCastAction::RemoveBoth => {
debug!("redundant cast => removing both");
*ast = ie.clone();
return;
}
DoubleCastAction::RemoveInner => {
debug!("redundant cast => removing inner");
*ast = ast_mk.cast_expr(ie, ot);
return;
}
DoubleCastAction::KeepBoth => {}
}
}
ExprKind::Lit(ref lit) => {
let new_lit = replace_suffix(lit, SimpleTy::from(ot_ty));
if let Some(nl) = new_lit {
let new_expr = ast_mk.lit_expr(nl);
let ast_const = eval_const(ast.clone(), cx);
let new_const = eval_const(new_expr.clone(), cx);
debug!(
"checking {:?} == {:?}: {:?} == {:?}",
*ast, new_expr, ast_const, new_const
);
if new_const.is_some() && new_const == ast_const {
*ast = new_expr;
return;
}
}
if lit.kind.is_unsuffixed() {
return;
}
}
ExprKind::Unary(UnOp::Neg, ref expr) => match expr.kind {
ExprKind::Lit(ref lit) => {
let new_lit = replace_suffix(lit, SimpleTy::from(ot_ty));
if let Some(nl) = new_lit {
let expr_mk = mk().id(expr.id).span(expr.span);
let new_expr = ast_mk.unary_expr(UnOp::Neg, expr_mk.lit_expr(nl));
let ast_const = eval_const(ast.clone(), cx);
let new_const = eval_const(new_expr.clone(), cx);
debug!(
"checking {:?} == {:?}: {:?} == {:?}",
*ast, new_expr, ast_const, new_const
);
if new_const.is_some() && new_const == ast_const {
*ast = new_expr;
return;
}
}
if lit.kind.is_unsuffixed() {
return;
}
}
_ => {}
},
_ => {}
}
if oe_ty == ot_ty {
debug!("no-op cast");
*ast = oe.clone();
return;
}
})
}
fn min_phase(&self) -> Phase {
Phase::Phase3
}
}
enum DoubleCastAction {
RemoveBoth,
RemoveInner,
KeepBoth,
}
fn check_double_cast<'tcx>(e_ty: SimpleTy, t1_ty: SimpleTy, t2_ty: SimpleTy) -> DoubleCastAction {
use CastKind::*;
let inner_cast = cast_kind(e_ty, t1_ty);
let outer_cast = cast_kind(t1_ty, t2_ty);
match (inner_cast, outer_cast) {
(Required, _) | (_, Required) => DoubleCastAction::KeepBoth,
(SameWidth, SameWidth) if t2_ty == SimpleTy::Pointer => {
DoubleCastAction::RemoveInner
}
(SameWidth, SameWidth) | (Extend(_), Truncate) if e_ty == t2_ty => {
DoubleCastAction::RemoveBoth
}
(Extend(_), Extend(s))
| (SameWidth, Extend(s))
| (SameWidth, FromPointer(s))
| (SameWidth, ToPointer(s))
if s == e_ty.is_signed() =>
{
DoubleCastAction::RemoveInner
}
(_, SameWidth) | (_, Truncate) => DoubleCastAction::RemoveInner,
_ => DoubleCastAction::KeepBoth,
}
}
enum CastKind {
Extend(bool),
Truncate,
SameWidth,
FromPointer(bool),
ToPointer(bool),
Required,
Unknown,
}
fn cast_kind(from_ty: SimpleTy, to_ty: SimpleTy) -> CastKind {
use SimpleTy::*;
match (from_ty, to_ty) {
(Int(fw, fs), Int(tw, _)) if fw < tw => CastKind::Extend(fs),
(Int(fw, _), Int(tw, _)) if fw > tw => CastKind::Truncate,
(Int(..), Int(..)) => CastKind::SameWidth,
(Int(fw, fs), Size(_)) | (Int(fw, fs), Pointer) if fw <= 16 => CastKind::Extend(fs),
(Int(fw, _), Size(_)) | (Int(fw, _), Pointer) if fw >= 64 => CastKind::Truncate,
(Int(..), Size(ts)) => CastKind::ToPointer(ts),
(Int(..), Pointer) => CastKind::ToPointer(false),
(Size(fs), Int(tw, _)) if tw >= 64 => CastKind::Extend(fs),
(Pointer, Int(tw, _)) if tw >= 64 => CastKind::Extend(false),
(Size(_), Int(tw, _)) | (Pointer, Int(tw, _)) if tw <= 16 => CastKind::Truncate,
(Size(fs), Int(..)) => CastKind::FromPointer(fs),
(Pointer, Int(..)) => CastKind::FromPointer(false),
(Pointer, Pointer) | (Pointer, Size(_)) | (Size(_), Pointer) | (Size(_), Size(_)) => {
CastKind::SameWidth
}
(Ref, Pointer) | (Array, Pointer) => CastKind::Required,
(Float32, Float32) => CastKind::SameWidth,
(Float32, Float64) => CastKind::Extend(true),
(Float64, Float32) => CastKind::Truncate,
(Float64, Float64) => CastKind::SameWidth,
(_, _) => CastKind::Unknown,
}
}
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum SimpleTy {
Int(usize, bool),
Size(bool),
Float32,
Float64,
Pointer,
Ref,
Array,
Other,
}
impl SimpleTy {
fn is_signed(&self) -> bool {
match self {
SimpleTy::Int(_, s) => *s,
SimpleTy::Size(s) => *s,
SimpleTy::Float32 => true,
SimpleTy::Float64 => true,
_ => false,
}
}
fn ast_lit_int_type(&self) -> LitIntType {
match self {
SimpleTy::Int(8, false) => LitIntType::Unsigned(UintTy::U8),
SimpleTy::Int(16, false) => LitIntType::Unsigned(UintTy::U16),
SimpleTy::Int(32, false) => LitIntType::Unsigned(UintTy::U32),
SimpleTy::Int(64, false) => LitIntType::Unsigned(UintTy::U64),
SimpleTy::Int(128, false) => LitIntType::Unsigned(UintTy::U128),
SimpleTy::Int(8, true) => LitIntType::Signed(IntTy::I8),
SimpleTy::Int(16, true) => LitIntType::Signed(IntTy::I16),
SimpleTy::Int(32, true) => LitIntType::Signed(IntTy::I32),
SimpleTy::Int(64, true) => LitIntType::Signed(IntTy::I64),
SimpleTy::Int(128, true) => LitIntType::Signed(IntTy::I128),
SimpleTy::Size(false) => LitIntType::Unsigned(UintTy::Usize),
SimpleTy::Size(true) => LitIntType::Signed(IntTy::Isize),
_ => panic!("ast_lit_int_type() called with non-integer type")
}
}
fn ast_float_ty(&self) -> FloatTy {
match self {
SimpleTy::Float32 => FloatTy::F32,
SimpleTy::Float64 => FloatTy::F64,
_ => panic!("as_float_ty() called with non-float type")
}
}
fn max_int_value(&self) -> u128 {
match self {
SimpleTy::Int(8, false) => u8::max_value() as u128,
SimpleTy::Int(16, false) => u16::max_value() as u128,
SimpleTy::Int(32, false) => u32::max_value() as u128,
SimpleTy::Int(64, false) => u64::max_value() as u128,
SimpleTy::Int(128, false) => u128::max_value() as u128,
SimpleTy::Int(8, true) => i8::max_value() as u128,
SimpleTy::Int(16, true) => i16::max_value() as u128,
SimpleTy::Int(32, true) => i32::max_value() as u128,
SimpleTy::Int(64, true) => i64::max_value() as u128,
SimpleTy::Int(128, true) => i128::max_value() as u128,
_ => panic!("max_int_value() called with non-integer type")
}
}
}
impl<'tcx> From<ty::Ty<'tcx>> for SimpleTy {
fn from(ty: ty::Ty<'tcx>) -> Self {
use SimpleTy::*;
match ty.kind {
TyKind::Int(IntTy::Isize) => Size(true),
TyKind::Uint(UintTy::Usize) => Size(false),
TyKind::Int(int_ty) => Int(int_ty.bit_width().unwrap(), true),
TyKind::Uint(uint_ty) => Int(uint_ty.bit_width().unwrap(), false),
TyKind::Float(FloatTy::F32) => Float32,
TyKind::Float(FloatTy::F64) => Float64,
TyKind::Ref(_, ty, _mutbl) => match ty.kind {
TyKind::Array(..) => Array,
_ => Ref,
}
TyKind::RawPtr(_) | TyKind::FnPtr(_) => Pointer,
_ => Other,
}
}
}
pub(crate) fn sym_token_kind(sym: Symbol) -> token::LitKind {
if sym.as_str().parse::<u128>().is_ok() {
token::LitKind::Integer
} else {
token::LitKind::Float
}
}
fn replace_suffix<'tcx>(lit: &Lit, ty: SimpleTy) -> Option<Lit> {
let mk_int = |i, ty| {
let new_suffix = match ty {
LitIntType::Signed(ty) => Some(ty.name()),
LitIntType::Unsigned(ty) => Some(ty.name()),
LitIntType::Unsuffixed => None
};
let new_lit = Lit {
kind: LitKind::Int(i, ty),
span: lit.span,
token: token::Lit {
kind: token::LitKind::Integer,
symbol: lit.token.symbol,
suffix: new_suffix,
},
};
Some(new_lit)
};
let mk_float = |fs: String, ty: FloatTy| {
let fsym = Symbol::intern(&fs);
Some(Lit {
kind: LitKind::Float(fsym, LitFloatType::Suffixed(ty)),
span: lit.span,
token: token::Lit {
kind: sym_token_kind(fsym),
symbol: fsym,
suffix: Some(Symbol::intern(ty.name_str())),
}
})
};
let lit_mk = mk().span(lit.span);
match (&lit.kind, &ty) {
(LitKind::Int(i, _), SimpleTy::Size(true)) if *i <= i16::max_value() as u128 => {
mk_int(*i, LitIntType::Signed(IntTy::Isize))
}
(LitKind::Int(i, _), SimpleTy::Size(false)) if *i <= u16::max_value() as u128 => {
mk_int(*i, LitIntType::Unsigned(UintTy::Usize))
}
(LitKind::Int(i, _), SimpleTy::Int(..)) if *i <= ty.max_int_value() => {
mk_int(*i, ty.ast_lit_int_type())
}
(LitKind::Int(i, _), SimpleTy::Float32)
| (LitKind::Int(i, _), SimpleTy::Float64) => {
mk_float(i.to_string(), ty.ast_float_ty())
}
(LitKind::Float(f, LitFloatType::Suffixed(FloatTy::F32)), SimpleTy::Int(..)) => {
let fv = f.as_str().parse::<f32>().ok()?;
Some(lit_mk.int_lit(fv as u128, ty.ast_lit_int_type()))
}
(LitKind::Float(f, LitFloatType::Suffixed(FloatTy::F64)), SimpleTy::Int(..))
| (LitKind::Float(f, LitFloatType::Unsuffixed), SimpleTy::Int(..)) => {
let fv = f.as_str().parse::<f64>().ok()?;
Some(lit_mk.int_lit(fv as u128, ty.ast_lit_int_type()))
}
(LitKind::Float(f, LitFloatType::Suffixed(FloatTy::F32)), SimpleTy::Float32)
| (LitKind::Float(f, LitFloatType::Suffixed(FloatTy::F32)), SimpleTy::Float64) => {
let fv = f.as_str().parse::<f32>().ok()?;
mk_float(fv.to_string(), ty.ast_float_ty())
}
(LitKind::Float(f, LitFloatType::Suffixed(FloatTy::F64)), SimpleTy::Float32)
| (LitKind::Float(f, LitFloatType::Suffixed(FloatTy::F64)), SimpleTy::Float64)
| (LitKind::Float(f, LitFloatType::Unsuffixed), SimpleTy::Float32)
| (LitKind::Float(f, LitFloatType::Unsuffixed), SimpleTy::Float64) => {
let fv = f.as_str().parse::<f64>().ok()?;
mk_float(fv.to_string(), ty.ast_float_ty())
}
_ => None,
}
}
#[derive(Debug, Copy, Clone, PartialEq)]
enum ConstantValue {
Int(i128),
Uint(u128),
Float32(f32),
Float64(f64),
}
impl ConstantValue {
fn cast(self, ty: SimpleTy) -> Self {
use ConstantValue::*;
macro_rules! match_ty {
($($pat:pat => $const_ty:ident[$($as_ty:ty),*]),*) => {
match (self, &ty) {
$(
(Int(v), $pat) => return $const_ty(v $(as $as_ty)*),
(Uint(v), $pat) => return $const_ty(v $(as $as_ty)*),
(Float32(v), $pat) => return $const_ty(v $(as $as_ty)*),
(Float64(v), $pat) => return $const_ty(v $(as $as_ty)*),
)*
_ => panic!("Unexpected SimpleTy: {:?}", ty)
}
}
};
match_ty! {
SimpleTy::Int(8, false) => Uint[u8, u128],
SimpleTy::Int(16, false) => Uint[u16, u128],
SimpleTy::Int(32, false) => Uint[u32, u128],
SimpleTy::Int(64, false) => Uint[u64, u128],
SimpleTy::Int(128, false) => Uint[u128],
SimpleTy::Int(8, true) => Int[i8, i128],
SimpleTy::Int(16, true) => Int[i16, i128],
SimpleTy::Int(32, true) => Int[i32, i128],
SimpleTy::Int(64, true) => Int[i64, i128],
SimpleTy::Int(128, true) => Int[i128],
SimpleTy::Size(false) => Uint[usize, u128],
SimpleTy::Size(true) => Int[isize, i128],
SimpleTy::Float32 => Float32[f32],
SimpleTy::Float64 => Float64[f64]
};
}
}
fn eval_const<'tcx>(e: P<Expr>, cx: &RefactorCtxt) -> Option<ConstantValue> {
match e.kind {
ExprKind::Lit(ref lit) => {
match lit.kind {
LitKind::Int(i, LitIntType::Unsuffixed) => Some(ConstantValue::Uint(i)),
LitKind::Int(i, LitIntType::Signed(IntTy::Isize)) => {
Some(ConstantValue::Int(i as i16 as i128))
}
LitKind::Int(i, LitIntType::Signed(IntTy::I8)) => {
Some(ConstantValue::Int(i as i8 as i128))
}
LitKind::Int(i, LitIntType::Signed(IntTy::I16)) => {
Some(ConstantValue::Int(i as i16 as i128))
}
LitKind::Int(i, LitIntType::Signed(IntTy::I32)) => {
Some(ConstantValue::Int(i as i32 as i128))
}
LitKind::Int(i, LitIntType::Signed(IntTy::I64)) => {
Some(ConstantValue::Int(i as i64 as i128))
}
LitKind::Int(i, LitIntType::Signed(IntTy::I128)) => {
Some(ConstantValue::Int(i as i128))
}
LitKind::Int(i, LitIntType::Unsigned(UintTy::Usize)) => {
Some(ConstantValue::Uint(i as u16 as u128))
}
LitKind::Int(i, LitIntType::Unsigned(UintTy::U8)) => {
Some(ConstantValue::Uint(i as u8 as u128))
}
LitKind::Int(i, LitIntType::Unsigned(UintTy::U16)) => {
Some(ConstantValue::Uint(i as u16 as u128))
}
LitKind::Int(i, LitIntType::Unsigned(UintTy::U32)) => {
Some(ConstantValue::Uint(i as u32 as u128))
}
LitKind::Int(i, LitIntType::Unsigned(UintTy::U64)) => {
Some(ConstantValue::Uint(i as u64 as u128))
}
LitKind::Int(i, LitIntType::Unsigned(UintTy::U128)) => {
Some(ConstantValue::Uint(i as u128))
}
LitKind::Float(f, LitFloatType::Suffixed(FloatTy::F32)) => {
let fv = f.as_str().parse::<f32>().ok()?;
Some(ConstantValue::Float32(fv))
}
LitKind::Float(f, LitFloatType::Suffixed(FloatTy::F64))
| LitKind::Float(f, LitFloatType::Unsuffixed) => {
let fv = f.as_str().parse::<f64>().ok()?;
Some(ConstantValue::Float64(fv))
}
_ => None,
}
}
ExprKind::Unary(UnOp::Neg, ref ie) => {
let ic = eval_const(ie.clone(), cx)?;
use ConstantValue::*;
match ic {
Uint(i) if i > (i128::max_value() as u128) => None,
Uint(i) => Some(Int(-(i as i128))),
Int(i) => Some(Int(-i)),
Float32(f) => Some(Float32(-f)),
Float64(f) => Some(Float64(-f)),
}
}
ExprKind::Cast(ref ie, ref ty) => {
let tcx = cx.ty_ctxt();
let ty_ty = cx.node_type(ty.id);
let ty_ty = tcx.normalize_erasing_regions(ParamEnv::empty(), ty_ty);
let ic = eval_const(ie.clone(), cx)?;
Some(ic.cast(SimpleTy::from(ty_ty)))
}
_ => unreachable!("Unexpected ExprKind"),
}
}
pub struct ConvertCastAsPtr;
impl Transform for ConvertCastAsPtr {
fn transform(&self, krate: &mut Crate, st: &CommandState, cx: &RefactorCtxt) {
replace_expr(
st,
cx,
krate,
"typed!($expr:Expr, &[$ty:Ty]) as *const $ty",
"$expr.as_ptr()",
);
replace_expr(
st,
cx,
krate,
"typed!($expr:Expr, &[$ty:Ty]) as *mut $ty",
"$expr.as_mut_ptr()",
);
replace_expr(
st,
cx,
krate,
"typed!($expr:Expr, &[$ty:Ty; $len]) as *const $ty",
"$expr.as_ptr()",
);
replace_expr(
st,
cx,
krate,
"typed!($expr:Expr, &[$ty:Ty; $len]) as *mut $ty",
"$expr.as_mut_ptr()",
);
}
fn min_phase(&self) -> Phase {
Phase::Phase3
}
}
pub fn register_commands(reg: &mut Registry) {
use super::mk;
reg.register("remove_redundant_casts", |_| mk(RemoveRedundantCasts));
reg.register("convert_cast_as_ptr", |_| mk(ConvertCastAsPtr));
}