use std::collections::{HashMap, HashSet};
use log::Level;
use rustc::hir::def::DefKind;
use rustc::hir::def_id::{DefId};
use rustc::ty::{Instance, TyCtxt, TyKind, Ty};
use syntax::ast::*;
use syntax::ptr::P;
use c2rust_ast_builder::mk;
use crate::ast_manip::{MutVisitNodes, visit_nodes};
use crate::command::{CommandState, Registry};
use crate::driver::{Phase};
use crate::path_edit::fold_resolved_paths_with_id;
use crate::reflect::Reflector;
use crate::resolve;
use crate::transform::Transform;
use crate::RefactorCtxt;
use crate::context::TypeCompare;
/// Insert casts so that uses of replaced definitions keep type-checking after
/// their paths have been redirected from the old to the new definition.
///
/// For every `(old, new)` pair in `replace_map`, the types of the two
/// definitions are compared and each difference is recorded:
///
/// * Non-function items (e.g. statics): the whole type (`TyLoc::Whole`),
///   compared with `TypeCompare::structural_eq_tys`.
/// * Functions: each argument (`TyLoc::Arg(i)`) and the return type
///   (`TyLoc::Ret`), compared with `TypeCompare::eq_tys`.
///
/// Function pairs whose signatures cannot be compared are reported via `warn!`
/// and get no casts. This covers late-bound regions, differing arg counts and
/// differing variadicness.
///
/// Then every expression in `krate` is visited:
///
/// * A path expression whose node id is in `path_ids` and whose whole type
///   changed is cast back to the old type.
/// * A call whose callee's node id is in `path_ids`, and whose signature
///   changed, gets each changed argument cast to the new argument type. An
///   argument that already has that type is left alone. The call itself is then
///   cast back to the old return type.
///
/// `path_ids` maps node ids of rewritten paths to the `DefId` they originally
/// resolved to. `replace_map` is handed to `TypeCompare::new_with_mapping` and
/// `new_paths` to `Reflector::new_with_mapping`. Presumably this makes old/new
/// definitions compare equal and reflected types use the new paths (TODO
/// confirm in those helpers).
pub fn fix_users(
    krate: &mut Crate,
    replace_map: &HashMap<DefId, DefId>,
    path_ids: &HashMap<NodeId, DefId>,
    new_paths: &HashMap<DefId, (Option<QSelf>, Path)>,
    cx: &RefactorCtxt,
) {
    let tcx = cx.ty_ctxt();
    let reflector = Reflector::new_with_mapping(tcx, new_paths);
    let ty_compare = TypeCompare::new_with_mapping(cx, replace_map);

    /// Location within an item's type where an old/new mismatch was found.
    #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)]
    enum TyLoc {
        /// The `i`th function argument.
        Arg(usize),
        /// The function return type.
        Ret,
        /// The entire type of a non-function item.
        Whole,
    }

    let mut ty_replace_map: HashMap<(DefId, TyLoc), (Ty, Ty)> = HashMap::new();
    // Functions with at least one `Arg`/`Ret` entry in `ty_replace_map`; calls to
    // any other function need no rewriting.
    let mut changed_fns: HashSet<DefId> = HashSet::new();

    for (&old_did, &new_did) in replace_map {
        let old_ty = cx.def_type(old_did);
        let new_ty = cx.def_type(new_did);
        if !matches!([old_ty.kind] TyKind::FnDef(..)) {
            if !ty_compare.structural_eq_tys(old_ty, new_ty) {
                ty_replace_map.insert((old_did, TyLoc::Whole), (old_ty, new_ty));
            }
            continue;
        }

        let old_sig = tcx.fn_sig(old_did);
        let new_sig = tcx.fn_sig(new_did);

        macro_rules! bail {
            ($msg:expr) => {{
                warn!(concat!("canonicalize_externs: {:?} -> {:?}: ", $msg, " - skipping"),
                      old_did, new_did);
                continue;
            }};
        }

        let (old_sig, new_sig) = match (old_sig.no_bound_vars(),
                                        new_sig.no_bound_vars()) {
            (Some(x), Some(y)) => (x, y),
            _ => bail!("old or new sig had late-bound regions"),
        };

        if old_sig.inputs().len() != new_sig.inputs().len() {
            bail!("old and new sig differ in arg count");
        }

        if old_sig.c_variadic != new_sig.c_variadic {
            bail!("old and new sig differ in variadicness");
        }

        for (i, (&old_ty, &new_ty)) in old_sig.inputs().iter()
                .zip(new_sig.inputs().iter()).enumerate() {
            if !ty_compare.eq_tys(old_ty, new_ty) {
                ty_replace_map.insert((old_did, TyLoc::Arg(i)), (old_ty, new_ty));
                changed_fns.insert(old_did);
            }
        }

        let old_ty = old_sig.output();
        let new_ty = new_sig.output();
        if !ty_compare.eq_tys(old_ty, new_ty) {
            ty_replace_map.insert((old_did, TyLoc::Ret), (old_ty, new_ty));
            changed_fns.insert(old_did);
        }
    }

    if log_enabled!(Level::Info) {
        for (&k, &v) in replace_map {
            info!("REPLACE {:?} ({:?})", k, cx.def_type(k));
            info!(" WITH {:?} ({:?})", v, cx.def_type(v));
        }

        let mut stuff = ty_replace_map.iter().collect::<Vec<_>>();
        stuff.sort_by_key(|&(&a, _)| a);
        for (&(did, loc), &(old, new)) in stuff {
            info!("TYPE CHANGE: {:?} @{:?}: {:?} -> {:?}", did, loc, old, new);
        }
    }

    MutVisitNodes::visit(krate, |e: &mut P<Expr>| {
        // Uses of a non-function item whose type changed: cast back to the old type.
        if let Some(&old_did) = path_ids.get(&e.id) {
            if let Some(&(old_ty, _new_ty)) = ty_replace_map.get(&(old_did, TyLoc::Whole)) {
                *e = make_cast(cx, &reflector, e.clone(), old_ty);
            }
        }

        // Only calls whose callee is a rewritten path are of interest below.
        let old_did = match &e.kind {
            ExprKind::Call(f, _) => match path_ids.get(&f.id) {
                Some(&did) => did,
                None => return,
            },
            _ => return,
        };
        if !changed_fns.contains(&old_did) {
            return;
        }

        let arg_count = expect!([e.kind] ExprKind::Call(_, ref a) => a.len());
        info!("rewriting call - e = {:?}", e);
        for i in 0 .. arg_count {
            let new_ty = match ty_replace_map.get(&(old_did, TyLoc::Arg(i))) {
                Some(&(_old_ty, new_ty)) => new_ty,
                None => continue,
            };
            let rewrote = expect!([e.kind] ExprKind::Call(_, ref mut args) => {
                // An argument that already has the new type needs no cast. Skip
                // only this argument: the remaining arguments and the return
                // value must still be processed. The previous `return` here
                // aborted the whole rewrite of this call.
                let already_new = cx.opt_node_type(args[i].id)
                    .map_or(false, |ty| ty_compare.eq_tys(ty, new_ty));
                if !already_new {
                    let new_arg = make_cast(cx, &reflector, args[i].clone(), new_ty);
                    args[i] = new_arg;
                }
                !already_new
            });
            if rewrote {
                info!(" arg {} - rewrote e = {:?}", i, e);
            }
        }

        // Callers still expect the old return type.
        if let Some(&(old_ty, _new_ty)) = ty_replace_map.get(&(old_did, TyLoc::Ret)) {
            *e = make_cast(cx, &reflector, e.clone(), old_ty);
            info!(" return - rewrote e = {:?}", e);
        }
    });
}
/// Build an expression that converts `expr` to the type `ty`.
///
/// Normally the result is `expr as T`, where `T` is `ty` reflected back into
/// AST form by `reflector`. Two cases are handled differently:
///
/// * `::std::mem::transmute(expr)` is emitted instead of a cast in two cases.
///   The first is when `ty` is a function pointer. The second is when `expr` is
///   a call to a callee whose last path segment is `Some` and whose first
///   parameter is a function pointer (presumably `Option::Some(fn_ptr)`). The
///   transmute's target type is left to inference.
/// * If `expr` is a borrow (`&x` / `&mut x`), it is first cast to a raw pointer
///   of the same mutability with an inferred pointee (`&x as *const _`). That
///   pointer is then cast to `T`. This presumably avoids invalid direct casts
///   like `&A as *const B`.
fn make_cast<'a, 'tcx>(
    cx: &RefactorCtxt<'a, 'tcx>,
    reflector: &Reflector<'a, 'tcx>,
    expr: P<Expr>,
    ty: Ty<'tcx>
) -> P<Expr> {
    let mut needs_transmute = ty.is_fn_ptr();

    if let Some(info) = cx.opt_callee_info(&expr) {
        if let Some(def_id) = info.def_id {
            let callee_is_some = cx.def_path(def_id).segments.last()
                .map_or(false, |seg| seg.ident.as_str() == "Some");
            // `.first()` instead of `[0]`: a callee named `Some` that takes no
            // arguments must not cause an index-out-of-bounds panic.
            if callee_is_some
                && info.fn_sig.inputs().first().map_or(false, |arg| arg.is_fn_ptr())
            {
                needs_transmute = true;
            }
        }
    }

    if needs_transmute {
        mk().call_expr(mk().path_expr(vec!["", "std", "mem", "transmute"]), vec![expr])
    } else {
        // The reflected type is only needed for the `as` cast.
        let ty_ast = reflector.reflect_ty(ty);
        let expr = if let ExprKind::AddrOf(_, mutability, _) = expr.kind {
            mk().cast_expr(expr, mk().set_mutbl(mutability).ptr_ty(mk().infer_ty()))
        } else {
            expr
        };
        mk().cast_expr(expr, ty_ast)
    }
}
/// The `canonicalize_externs` transform.
///
/// Foreign items marked `target` are replaced by the foreign fns/statics with
/// the same symbol name found in the module named by `path`.
pub struct CanonicalizeExterns {
    /// `::`-separated path of the module holding the canonical declarations.
    path: String,
}
fn is_foreign_symbol(tcx: TyCtxt, did: DefId) -> bool {
tcx.is_foreign_item(did) &&
matches!([tcx.def_kind(did)] Some(DefKind::Fn), Some(DefKind::Static))
}
impl Transform for CanonicalizeExterns {
fn transform(&self, krate: &mut Crate, st: &CommandState, cx: &RefactorCtxt) {
let tcx = cx.ty_ctxt();
let lib_path = self.path.split("::").map(|s| Ident::from_str(s)).collect::<Vec<_>>();
let lib = resolve::resolve_absolute(tcx, &lib_path);
let mut symbol_map = HashMap::new();
for (_sym, def) in resolve::module_children(tcx, lib.def_id()) {
let did = def.def_id();
if is_foreign_symbol(tcx, did) {
let inst = Instance::new(did, tcx.intern_substs(&[]));
let sym = tcx.symbol_name(inst).name;
symbol_map.insert(sym, did);
}
}
for (&sym, &def) in &symbol_map {
info!(" found symbol {} :: {:?} at {:?}", self.path, sym, def);
}
let mut replace_map = HashMap::new();
visit_nodes(krate, |fi: &ForeignItem| {
if !st.marked(fi.id, "target") {
return;
}
let did = cx.node_def_id(fi.id);
if !is_foreign_symbol(tcx, did) {
return;
}
let inst = Instance::new(did, tcx.intern_substs(&[]));
let sym = tcx.symbol_name(inst).name;
if let Some(&repl_did) = symbol_map.get(&sym) {
replace_map.insert(did, repl_did);
}
});
let mut path_ids = HashMap::new();
let new_paths = replace_map.iter().map(|(old_did, new_did)| {
(*old_did, cx.def_qpath(*new_did))
}).collect::<HashMap<_, _>>();
fold_resolved_paths_with_id(krate, cx, |id, qself, path, def| {
let old_did = match_or!([def[0].opt_def_id()] Some(x) => x; return (qself, path));
if let Some(new_path) = new_paths.get(&old_did) {
path_ids.insert(id, old_did);
new_path.clone()
} else {
(qself, path)
}
});
fix_users(krate, &replace_map, &path_ids, &new_paths, cx);
MutVisitNodes::visit(krate, |fm: &mut ForeignMod| {
fm.items.retain(|fi| {
let did = cx.node_def_id(fi.id);
!replace_map.contains_key(&did)
});
});
}
fn min_phase(&self) -> Phase {
Phase::Phase3
}
}
/// Register the `canonicalize_externs` command.
///
/// The command's first argument is the `::`-separated path of the module whose
/// foreign declarations are treated as canonical.
pub fn register_commands(reg: &mut Registry) {
    // The command-builder `mk` from the parent module, shadowing the AST
    // builder `mk` imported at file level.
    use super::mk;

    reg.register("canonicalize_externs", |args| {
        let path = args[0].clone();
        mk(CanonicalizeExterns { path })
    });
}