use std::collections::HashSet;
use syntax::ast::*;
use syntax::ptr::P;
use syntax::symbol::Symbol;
use smallvec::smallvec;
use crate::ast_manip::{FlatMapNodes, MutVisitNodes};
use crate::command::{CommandState, Registry};
use crate::driver::{parse_ty};
use crate::path_edit::fold_resolved_paths_with_id;
use crate::transform::Transform;
use crate::RefactorCtxt;
use c2rust_ast_builder::{mk, IntoSymbol};
/// Command state for `generalize_items`: adds a new type parameter to every
/// item marked `target` and rewrites types and paths to use it.
pub struct GeneralizeItems {
    /// Name of the type parameter pushed onto each `target` item's generics and
    /// substituted for `target` types inside those items.
    ty_var_name: Symbol,
    /// Source text of the type argument used for references to generalized
    /// items from outside `target` items. When `None`, the first rewritten
    /// `target` type is used instead.
    replacement_ty: Option<String>,
}
impl Transform for GeneralizeItems {
    /// Makes every item marked `target` generic over a new type parameter.
    ///
    /// The rewrite runs in three passes over `krate`:
    ///
    /// 1. Each type node marked `target` whose enclosing item is also marked
    ///    `target` is replaced by the type variable `self.ty_var_name`. If no
    ///    replacement type was supplied, the first such type visited is kept
    ///    (before rewriting) as the replacement type.
    /// 2. Each item marked `target` gets `self.ty_var_name` appended to its
    ///    generic parameters.
    /// 3. Each resolved path that refers to one of those items gets one more
    ///    angle-bracketed type argument. That argument is the type variable when
    ///    the path's enclosing item is marked `target`, and the replacement
    ///    type otherwise.
    ///
    /// # Panics
    ///
    /// - if a `target` item is not a fn, enum, struct, union, trait or impl;
    /// - if no replacement type was supplied and pass 1 rewrote no type;
    /// - if a path to a generalized item ends in a segment that carries
    ///   parenthesized generic arguments.
    fn transform(&self, krate: &mut Crate, st: &CommandState, cx: &RefactorCtxt) {
        // Node id of the item enclosing node `id`. It decides whether a type or
        // path lives inside one of the items being generalized.
        let parent_item_of = |id| {
            let hir_map = cx.hir_map();
            let hir_id = hir_map.node_to_hir_id(id);
            hir_map.hir_to_node_id(hir_map.get_parent_item(hir_id))
        };

        let mut replacement_ty = self
            .replacement_ty
            .as_ref()
            .map(|s| parse_ty(cx.session(), s));

        // Pass 1: swap marked types inside marked items for the type variable.
        MutVisitNodes::visit(krate, |ty: &mut P<Ty>| {
            // `||` short-circuits, so the parent lookup only happens for marked
            // types, exactly as before.
            if !st.marked(ty.id, "target") || !st.marked(parent_item_of(ty.id), "target") {
                return;
            }
            if replacement_ty.is_none() {
                replacement_ty = Some(ty.clone());
            }
            *ty = mk().ident_ty(self.ty_var_name);
        });

        // Pass 2: add the new type parameter to every marked item.
        let mut item_def_ids = HashSet::new();
        FlatMapNodes::visit(krate, |i: P<Item>| {
            if !st.marked(i.id, "target") {
                return smallvec![i];
            }
            item_def_ids.insert(cx.node_def_id(i.id));
            smallvec![i.map(|mut i| {
                let gen = match i.kind {
                    ItemKind::Fn(_, ref mut gen, _) => gen,
                    ItemKind::Enum(_, ref mut gen) => gen,
                    ItemKind::Struct(_, ref mut gen) => gen,
                    ItemKind::Union(_, ref mut gen) => gen,
                    ItemKind::Trait(_, _, ref mut gen, _, _) => gen,
                    ItemKind::Impl(_, _, _, ref mut gen, _, _, _) => gen,
                    _ => panic!("item has no room for generics"),
                };
                gen.params.push(mk().ty_param(self.ty_var_name));
                i
            })]
        });

        // Pass 3: supply the extra type argument at every use of a generalized item.
        let replacement_ty =
            replacement_ty.expect("must provide a replacement type argument or mark");
        fold_resolved_paths_with_id(krate, cx, |path_id, qself, mut path, def| {
            // A path with no resolution cannot refer to a generalized item.
            // Leave it untouched instead of panicking on `def[0]`.
            let refers_to_target = def
                .first()
                .and_then(|res| res.opt_def_id())
                .map_or(false, |def_id| item_def_ids.contains(&def_id));
            if !refers_to_target {
                return (qself, path);
            }

            let arg = if st.marked(parent_item_of(path_id), "target") {
                mk().ident_ty(self.ty_var_name)
            } else {
                replacement_ty.clone()
            };

            let seg = path
                .segments
                .last_mut()
                .expect("resolved path has no segments");
            if let Some(ref mut args) = seg.args {
                // `P` derefs mutably, so the existing argument list is extended
                // in place instead of being cloned and rebuilt.
                match &mut **args {
                    GenericArgs::AngleBracketed(abpd) => abpd.args.push(mk().generic_arg(arg)),
                    GenericArgs::Parenthesized(..) => {
                        panic!("expected angle bracketed params, but found parenthesized")
                    }
                }
            } else {
                let abpd = mk().angle_bracketed_args(vec![arg]);
                seg.args = Some(P(GenericArgs::AngleBracketed(abpd)));
            }
            (qself, path)
        });
    }
}
/// Registers the `generalize_items` command with `reg`.
///
/// The command takes up to two arguments:
/// the name of the new type parameter (defaults to `T`), and
/// optionally the source text of the replacement type used at use sites
/// outside the generalized items.
pub fn register_commands(reg: &mut Registry) {
    use super::mk;
    reg.register("generalize_items", |args| {
        let ty_var_name = args.first().map_or("T", String::as_str).into_symbol();
        let replacement_ty = args.get(1).cloned();
        mk(GeneralizeItems {
            ty_var_name,
            replacement_ty,
        })
    });
}