use imbl::HashSet;
use syn::{
punctuated::{Pair, Punctuated},
*,
};
/// A piece of syntax to scan for references to generic parameters; see [`filter_generics`].
///
/// Both variants only borrow the syntax tree, so the enum is cheap to copy.
#[derive(Clone, Copy, Debug)]
pub enum Usage<'a> {
    /// A type whose mentioned generic parameters should be collected.
    Type(&'a Type),
    /// An expression whose mentioned generic parameters should be collected.
    Expression(&'a Expr),
}
pub fn filter_generics<'a>(
base: Generics,
usage: impl Iterator<Item = Usage<'a>>,
context: impl Iterator<Item = &'a Generics>,
) -> Generics {
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum GenericRef {
Type(Ident),
Lifetime(Lifetime),
Const(Ident),
}
impl From<&GenericParam> for GenericRef {
fn from(value: &GenericParam) -> Self {
match value {
GenericParam::Type(type_param) => GenericRef::Type(type_param.ident.clone()),
GenericParam::Lifetime(lt) => GenericRef::Lifetime(lt.lifetime.clone()),
GenericParam::Const(c) => GenericRef::Const(c.ident.clone()),
}
}
}
fn add_bound_lifetimes(_bound: &mut HashSet<GenericRef>, b: Option<&BoundLifetimes>) {
if let Some(_lifetimes) = b {
unimplemented!()
}
}
fn process_lifetime(used: &mut HashSet<GenericRef>, bound: &HashSet<GenericRef>, lt: Lifetime) {
let r = GenericRef::Lifetime(lt);
if !bound.contains(&r) {
used.insert(r);
}
}
fn process_generic_arguments(
used: &mut HashSet<GenericRef>,
bound: &HashSet<GenericRef>,
args: &AngleBracketedGenericArguments,
) {
for arg in &args.args {
match arg {
GenericArgument::Lifetime(lt) => process_lifetime(used, bound, lt.clone()),
GenericArgument::Type(ty) => recurse_type(used, ty, bound),
GenericArgument::Const(_) => unimplemented!(),
GenericArgument::AssocType(assoc_type) => recurse_type(used, &assoc_type.ty, bound),
GenericArgument::AssocConst(assoc_const) => {
recurse_expr(used, &assoc_const.value, bound)
}
GenericArgument::Constraint(constraint) => {
process_bounds(used, bound, constraint.bounds.iter())
}
_ => unimplemented!(),
}
}
}
fn process_path(
used: &mut HashSet<GenericRef>,
bound: &HashSet<GenericRef>,
path: &Path,
unqualified: bool,
) {
if let Some(ident) = path.get_ident() {
let r = GenericRef::Type(ident.clone());
if !bound.contains(&r) {
used.insert(r);
}
} else {
let mut first = unqualified && path.leading_colon.is_none();
for s in &path.segments {
if first && s.arguments.is_empty() {
let r = GenericRef::Type(s.ident.clone());
if !bound.contains(&r) {
used.insert(r);
}
}
first = false;
match &s.arguments {
PathArguments::None => {}
PathArguments::AngleBracketed(args) => {
process_generic_arguments(used, bound, args)
}
PathArguments::Parenthesized(args) => {
for ty in &args.inputs {
recurse_type(used, ty, bound);
}
if let ReturnType::Type(_, ty) = &args.output {
recurse_type(used, ty, bound)
}
}
}
}
}
}
fn process_bounds<'a>(
used: &mut HashSet<GenericRef>,
bound: &HashSet<GenericRef>,
b: impl Iterator<Item = &'a TypeParamBound>,
) {
for b in b {
match b {
TypeParamBound::Trait(b) => {
let mut bound = bound.clone();
add_bound_lifetimes(&mut bound, b.lifetimes.as_ref());
process_path(used, &bound, &b.path, true);
}
TypeParamBound::Lifetime(lt) => process_lifetime(used, bound, lt.clone()),
_ => unimplemented!(),
}
}
}
fn recurse_type(used: &mut HashSet<GenericRef>, ty: &Type, bound: &HashSet<GenericRef>) {
match ty {
Type::Array(arr) => recurse_type(used, &arr.elem, bound),
Type::BareFn(bare_fn) => {
let mut bound = bound.clone();
add_bound_lifetimes(&mut bound, bare_fn.lifetimes.as_ref());
for input in &bare_fn.inputs {
recurse_type(used, &input.ty, &bound)
}
if let ReturnType::Type(_, ty) = &bare_fn.output {
recurse_type(used, ty, &bound)
}
}
Type::Group(group) => recurse_type(used, &group.elem, bound),
Type::ImplTrait(impl_trait) => process_bounds(used, bound, impl_trait.bounds.iter()),
Type::Never(_) => {}
Type::Paren(paren) => recurse_type(used, &paren.elem, bound),
Type::Path(path) => {
if let Some(qself) = &path.qself {
recurse_type(used, &qself.ty, bound);
}
process_path(used, bound, &path.path, path.qself.is_none());
}
Type::Ptr(ptr) => recurse_type(used, &ptr.elem, bound),
Type::Reference(reference) => recurse_type(used, &reference.elem, bound),
Type::Slice(slice) => recurse_type(used, &slice.elem, bound),
Type::TraitObject(trait_object) => {
process_bounds(used, bound, trait_object.bounds.iter());
}
Type::Tuple(tuple) => {
for ty in &tuple.elems {
recurse_type(used, ty, bound);
}
}
ty => panic!("unsupported type: {:?}", ty),
}
}
fn recurse_expr(used: &mut HashSet<GenericRef>, expr: &Expr, bound: &HashSet<GenericRef>) {
match expr {
Expr::Array(ExprArray { elems, .. }) | Expr::Tuple(ExprTuple { elems, .. }) => {
for expr in elems {
recurse_expr(used, expr, bound);
}
}
Expr::Assign(ExprAssign { left, right, .. })
| Expr::Binary(ExprBinary { left, right, .. })
| Expr::Index(ExprIndex {
expr: left,
index: right,
..
})
| Expr::Repeat(ExprRepeat {
expr: left,
len: right,
..
}) => {
recurse_expr(used, left, bound);
recurse_expr(used, right, bound);
}
Expr::Async(ExprAsync {
block: Block { stmts, .. },
..
})
| Expr::Block(ExprBlock {
block: Block { stmts, .. },
..
})
| Expr::Loop(ExprLoop {
body: Block { stmts, .. },
..
})
| Expr::TryBlock(ExprTryBlock {
block: Block { stmts, .. },
..
})
| Expr::Unsafe(ExprUnsafe {
block: Block { stmts, .. },
..
}) => {
for stmt in stmts {
recurse_stmt(used, stmt, bound);
}
}
Expr::Await(ExprAwait { base: expr, .. })
| Expr::Field(ExprField { base: expr, .. })
| Expr::Group(ExprGroup { expr, .. })
| Expr::Paren(ExprParen { expr, .. })
| Expr::Reference(ExprReference { expr, .. })
| Expr::Try(ExprTry { expr, .. })
| Expr::Unary(ExprUnary { expr, .. }) => recurse_expr(used, expr, bound),
Expr::Break(ExprBreak { expr, .. }) | Expr::Return(ExprReturn { expr, .. }) => {
if let Some(expr) = expr {
recurse_expr(used, expr, bound);
}
}
Expr::Call(ExprCall { func, args, .. }) => {
recurse_expr(used, func, bound);
for arg in args {
recurse_expr(used, arg, bound);
}
}
Expr::Cast(ExprCast { expr, ty, .. }) => {
recurse_expr(used, expr, bound);
recurse_type(used, ty, bound);
}
Expr::Closure(ExprClosure {
inputs,
output,
body,
..
}) => {
for pat in inputs {
recurse_pat(used, pat, bound);
}
if let ReturnType::Type(_, ty) = output {
recurse_type(used, ty, bound);
}
recurse_expr(used, body, bound);
}
Expr::Continue(_) | Expr::Lit(_) => {}
Expr::ForLoop(ExprForLoop {
pat,
expr,
body: Block { stmts, .. },
..
}) => {
recurse_pat(used, pat, bound);
recurse_expr(used, expr, bound);
for stmt in stmts {
recurse_stmt(used, stmt, bound);
}
}
Expr::If(ExprIf {
cond,
then_branch: Block { stmts, .. },
else_branch,
..
}) => {
recurse_expr(used, cond, bound);
for stmt in stmts {
recurse_stmt(used, stmt, bound);
}
if let Some((_, expr)) = else_branch {
recurse_expr(used, expr, bound);
}
}
Expr::Let(ExprLet { pat, expr, .. }) => {
recurse_pat(used, pat, bound);
recurse_expr(used, expr, bound);
}
Expr::Match(ExprMatch { expr, arms, .. }) => {
recurse_expr(used, expr, bound);
for Arm {
pat, guard, body, ..
} in arms
{
recurse_pat(used, pat, bound);
if let Some((_, expr)) = guard {
recurse_expr(used, expr, bound);
}
recurse_expr(used, body, bound);
}
}
Expr::MethodCall(ExprMethodCall {
receiver,
turbofish,
args,
..
}) => {
recurse_expr(used, receiver, bound);
if let Some(args) = turbofish {
process_generic_arguments(used, bound, args);
}
for arg in args {
recurse_expr(used, arg, bound);
}
}
Expr::Path(ExprPath { qself, path, .. }) => {
process_path(used, bound, path, true);
if let Some(QSelf { ty, .. }) = qself {
recurse_type(used, ty, bound);
}
}
Expr::Range(ExprRange { start, end, .. }) => {
if let Some(expr) = start {
recurse_expr(used, expr, bound);
}
if let Some(expr) = end {
recurse_expr(used, expr, bound);
}
}
Expr::Struct(ExprStruct { path, fields, .. }) => {
process_path(used, bound, path, true);
for FieldValue { expr, .. } in fields {
recurse_expr(used, expr, bound);
}
}
Expr::While(ExprWhile {
cond,
body: Block { stmts, .. },
..
}) => {
recurse_expr(used, cond, bound);
for stmt in stmts {
recurse_stmt(used, stmt, bound);
}
}
Expr::Yield(ExprYield { expr, .. }) => {
if let Some(expr) = expr {
recurse_expr(used, expr, bound);
}
}
expr => panic!("unsupported expression: {:?}", expr),
}
}
fn recurse_pat(used: &mut HashSet<GenericRef>, pat: &Pat, bound: &HashSet<GenericRef>) {
match pat {
Pat::Const(ExprConst {
block: Block { stmts, .. },
..
}) => {
for stmt in stmts {
recurse_stmt(used, stmt, bound);
}
}
Pat::Ident(_) | Pat::Lit(_) | Pat::Macro(_) | Pat::Rest(_) | Pat::Wild(_) => {}
Pat::Or(PatOr { cases, .. }) => {
for pat in cases {
recurse_pat(used, pat, bound);
}
}
Pat::Paren(PatParen { pat, .. }) => recurse_pat(used, pat, bound),
Pat::Path(ExprPath { qself, path, .. }) => {
process_path(used, bound, path, true);
if let Some(QSelf { ty, .. }) = qself {
recurse_type(used, ty, bound);
}
}
Pat::Range(ExprRange { start, end, .. }) => {
if let Some(expr) = start {
recurse_expr(used, expr, bound);
}
if let Some(expr) = end {
recurse_expr(used, expr, bound);
}
}
Pat::Reference(PatReference { pat, .. }) => recurse_pat(used, pat, bound),
Pat::Slice(PatSlice { elems, .. }) | Pat::Tuple(PatTuple { elems, .. }) => {
for pat in elems {
recurse_pat(used, pat, bound);
}
}
Pat::Struct(PatStruct {
qself,
path,
fields,
..
}) => {
process_path(used, bound, path, true);
if let Some(QSelf { ty, .. }) = qself {
recurse_type(used, ty, bound);
}
for FieldPat { pat, .. } in fields {
recurse_pat(used, pat, bound);
}
}
Pat::TupleStruct(PatTupleStruct {
qself, path, elems, ..
}) => {
process_path(used, bound, path, true);
if let Some(QSelf { ty, .. }) = qself {
recurse_type(used, ty, bound);
}
for pat in elems {
recurse_pat(used, pat, bound);
}
}
Pat::Type(PatType { pat, ty, .. }) => {
recurse_pat(used, pat, bound);
recurse_type(used, ty, bound);
}
_ => unimplemented!(),
}
}
fn recurse_stmt(used: &mut HashSet<GenericRef>, stmt: &Stmt, bound: &HashSet<GenericRef>) {
match stmt {
Stmt::Local(Local { pat, init, .. }) => {
recurse_pat(used, pat, bound);
if let Some(LocalInit { expr, diverge, .. }) = init {
recurse_expr(used, expr, bound);
if let Some((_, diverge)) = diverge {
recurse_expr(used, &diverge, bound);
}
}
}
Stmt::Expr(expr, _) => recurse_expr(used, expr, bound),
_ => unimplemented!(),
}
}
fn finalize(
used: &mut HashSet<GenericRef>,
bound: &HashSet<GenericRef>,
base: Generics,
) -> Generics {
let mut args = Vec::new();
for arg in base.params.into_pairs().rev() {
match arg.value() {
GenericParam::Type(type_param) => {
if used.contains(&GenericRef::Type(type_param.ident.clone())) {
process_bounds(used, bound, type_param.bounds.iter());
args.push(arg);
}
}
GenericParam::Lifetime(lt) => {
if used.contains(&GenericRef::Lifetime(lt.lifetime.clone())) {
for b in <.bounds {
process_lifetime(used, bound, b.clone());
}
args.push(arg);
}
}
GenericParam::Const(_) => todo!(),
}
}
if args.is_empty() {
Generics::default()
} else {
Generics {
params: Punctuated::from_iter(args.into_iter().rev()),
..base
}
}
}
let mut used = HashSet::new();
let bound = HashSet::from_iter(context.flat_map(|g| g.params.iter()).map(GenericRef::from));
for u in usage {
match u {
Usage::Type(ty) => recurse_type(&mut used, ty, &bound),
Usage::Expression(expr) => recurse_expr(&mut used, expr, &bound),
}
}
finalize(&mut used, &bound, base)
}
pub fn generics_as_args(generics: &Generics) -> PathArguments {
if generics.params.is_empty() {
PathArguments::None
} else {
PathArguments::AngleBracketed(AngleBracketedGenericArguments {
colon2_token: None,
lt_token: generics.lt_token.unwrap_or_default(),
args: Punctuated::from_iter(generics.params.pairs().map(|p| {
let (param, punct) = p.into_tuple();
Pair::new(
match param {
GenericParam::Type(type_param) => {
GenericArgument::Type(Type::Path(TypePath {
qself: None,
path: type_param.ident.clone().into(),
}))
}
GenericParam::Lifetime(lt) => {
GenericArgument::Lifetime(lt.lifetime.clone())
}
GenericParam::Const(_) => todo!(),
},
punct.cloned(),
)
})),
gt_token: generics.gt_token.unwrap_or_default(),
})
}
}