Skip to main content

generics_util/
lib.rs

1use std::iter::once;
2
3use imbl::HashSet;
4use syn::{
5    punctuated::{Pair, Punctuated},
6    *,
7};
8
9pub enum Usage<'a> {
10    Type(&'a Type),
11    Lifetime(&'a Lifetime),
12    Expression(&'a Expr),
13    TypeBound(&'a TypeParamBound),
14}
15
16/// Creates a new [`Generics`] object containing only the parameters which are
17/// used (including transitively) anywhere in the given sequence of types and
18/// expressions (`usage`), and excluding those which are already in scope
19/// (`context`).
20pub fn filter_generics<'a>(
21    base: Generics,
22    usage: impl Iterator<Item = Usage<'a>>,
23    context: impl Iterator<Item = &'a Generics>,
24) -> Generics {
25    #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
26    enum GenericRef {
27        Type(Ident),
28        Lifetime(Lifetime),
29        Const(Ident),
30    }
31
32    impl From<&GenericParam> for GenericRef {
33        fn from(value: &GenericParam) -> Self {
34            match value {
35                GenericParam::Type(type_param) => GenericRef::Type(type_param.ident.clone()),
36                GenericParam::Lifetime(lt) => GenericRef::Lifetime(lt.lifetime.clone()),
37                GenericParam::Const(c) => GenericRef::Const(c.ident.clone()),
38            }
39        }
40    }
41
42    fn add_bound_lifetimes(_bound: &mut HashSet<GenericRef>, b: Option<&BoundLifetimes>) {
43        if let Some(_lifetimes) = b {
44            // bound.extend(
45            //     lifetimes
46            //         .lifetimes
47            //         .iter()
48            //         .flat_map(|param| [&param].into_iter().chain(param.bounds.iter()))
49            //         .map(|lt| GenericRef::Lifetime(lt.clone())),
50            // );
51            unimplemented!("bound lifetimes")
52        }
53    }
54
55    fn process_lifetime(used: &mut HashSet<GenericRef>, bound: &HashSet<GenericRef>, lt: Lifetime) {
56        let r = GenericRef::Lifetime(lt);
57        if !bound.contains(&r) {
58            used.insert(r);
59        }
60    }
61
62    fn process_generic_arguments(
63        used: &mut HashSet<GenericRef>,
64        bound: &HashSet<GenericRef>,
65        args: &AngleBracketedGenericArguments,
66    ) {
67        for arg in &args.args {
68            match arg {
69                GenericArgument::Lifetime(lt) => process_lifetime(used, bound, lt.clone()),
70                GenericArgument::Type(ty) => recurse_type(used, ty, bound),
71                GenericArgument::Const(expr) => recurse_expr(used, expr, bound),
72                GenericArgument::AssocType(assoc_type) => recurse_type(used, &assoc_type.ty, bound),
73                GenericArgument::AssocConst(assoc_const) => {
74                    recurse_expr(used, &assoc_const.value, bound)
75                }
76                GenericArgument::Constraint(constraint) => {
77                    process_bounds(used, bound, constraint.bounds.iter())
78                }
79                _ => unimplemented!("unknown generic argument kind: {:?}", arg),
80            }
81        }
82    }
83
84    fn process_path(
85        used: &mut HashSet<GenericRef>,
86        bound: &HashSet<GenericRef>,
87        path: &Path,
88        unqualified: bool,
89    ) {
90        if let Some(ident) = path.get_ident() {
91            let r = GenericRef::Type(ident.clone());
92            if !bound.contains(&r) {
93                used.insert(r);
94            }
95        } else {
96            let mut first = unqualified && path.leading_colon.is_none();
97            for s in &path.segments {
98                if first && s.arguments.is_empty() {
99                    let r = GenericRef::Type(s.ident.clone());
100                    if !bound.contains(&r) {
101                        used.insert(r);
102                    }
103                }
104                first = false;
105                match &s.arguments {
106                    PathArguments::None => {}
107                    PathArguments::AngleBracketed(args) => {
108                        process_generic_arguments(used, bound, args)
109                    }
110                    PathArguments::Parenthesized(args) => {
111                        for ty in &args.inputs {
112                            recurse_type(used, ty, bound);
113                        }
114                        if let ReturnType::Type(_, ty) = &args.output {
115                            recurse_type(used, ty, bound)
116                        }
117                    }
118                }
119            }
120        }
121    }
122
123    fn process_bounds<'a>(
124        used: &mut HashSet<GenericRef>,
125        bound: &HashSet<GenericRef>,
126        b: impl Iterator<Item = &'a TypeParamBound>,
127    ) {
128        for b in b {
129            match b {
130                TypeParamBound::Trait(b) => {
131                    let mut bound = bound.clone();
132                    add_bound_lifetimes(&mut bound, b.lifetimes.as_ref());
133                    process_path(used, &bound, &b.path, true);
134                }
135                TypeParamBound::Lifetime(lt) => process_lifetime(used, bound, lt.clone()),
136                _ => unimplemented!("unknown type param bound kind: {:?}", b),
137            }
138        }
139    }
140
141    fn recurse_type(used: &mut HashSet<GenericRef>, ty: &Type, bound: &HashSet<GenericRef>) {
142        match ty {
143            Type::Array(arr) => recurse_type(used, &arr.elem, bound),
144            Type::BareFn(bare_fn) => {
145                let mut bound = bound.clone();
146                add_bound_lifetimes(&mut bound, bare_fn.lifetimes.as_ref());
147
148                for input in &bare_fn.inputs {
149                    recurse_type(used, &input.ty, &bound)
150                }
151
152                if let ReturnType::Type(_, ty) = &bare_fn.output {
153                    recurse_type(used, ty, &bound)
154                }
155            }
156            Type::Group(group) => recurse_type(used, &group.elem, bound),
157            Type::ImplTrait(impl_trait) => process_bounds(used, bound, impl_trait.bounds.iter()),
158            // Type::Infer(_) => todo!(),
159            // Type::Macro(_) => todo!(),
160            Type::Never(_) => {}
161            Type::Paren(paren) => recurse_type(used, &paren.elem, bound),
162            Type::Path(path) => {
163                if let Some(qself) = &path.qself {
164                    recurse_type(used, &qself.ty, bound);
165                }
166                process_path(used, bound, &path.path, path.qself.is_none());
167            }
168            Type::Ptr(ptr) => recurse_type(used, &ptr.elem, bound),
169            Type::Reference(reference) => recurse_type(used, &reference.elem, bound),
170            Type::Slice(slice) => recurse_type(used, &slice.elem, bound),
171            Type::TraitObject(trait_object) => {
172                process_bounds(used, bound, trait_object.bounds.iter());
173            }
174            Type::Tuple(tuple) => {
175                for ty in &tuple.elems {
176                    recurse_type(used, ty, bound);
177                }
178            }
179            // Type::Verbatim(_) => todo!(),
180            ty => panic!("unsupported type: {:?}", ty),
181        }
182    }
183
184    fn recurse_expr(used: &mut HashSet<GenericRef>, expr: &Expr, bound: &HashSet<GenericRef>) {
185        match expr {
186            Expr::Array(ExprArray { elems, .. }) | Expr::Tuple(ExprTuple { elems, .. }) => {
187                for expr in elems {
188                    recurse_expr(used, expr, bound);
189                }
190            }
191            Expr::Assign(ExprAssign { left, right, .. })
192            | Expr::Binary(ExprBinary { left, right, .. })
193            | Expr::Index(ExprIndex {
194                expr: left,
195                index: right,
196                ..
197            })
198            | Expr::Repeat(ExprRepeat {
199                expr: left,
200                len: right,
201                ..
202            }) => {
203                recurse_expr(used, left, bound);
204                recurse_expr(used, right, bound);
205            }
206            Expr::Async(ExprAsync {
207                block: Block { stmts, .. },
208                ..
209            })
210            | Expr::Block(ExprBlock {
211                block: Block { stmts, .. },
212                ..
213            })
214            | Expr::Loop(ExprLoop {
215                body: Block { stmts, .. },
216                ..
217            })
218            | Expr::TryBlock(ExprTryBlock {
219                block: Block { stmts, .. },
220                ..
221            })
222            | Expr::Unsafe(ExprUnsafe {
223                block: Block { stmts, .. },
224                ..
225            }) => {
226                for stmt in stmts {
227                    recurse_stmt(used, stmt, bound);
228                }
229            }
230            Expr::Await(ExprAwait { base: expr, .. })
231            | Expr::Field(ExprField { base: expr, .. })
232            | Expr::Group(ExprGroup { expr, .. })
233            | Expr::Paren(ExprParen { expr, .. })
234            | Expr::Reference(ExprReference { expr, .. })
235            | Expr::Try(ExprTry { expr, .. })
236            | Expr::Unary(ExprUnary { expr, .. }) => recurse_expr(used, expr, bound),
237            Expr::Break(ExprBreak { expr, .. }) | Expr::Return(ExprReturn { expr, .. }) => {
238                if let Some(expr) = expr {
239                    recurse_expr(used, expr, bound);
240                }
241            }
242            Expr::Call(ExprCall { func, args, .. }) => {
243                recurse_expr(used, func, bound);
244                for arg in args {
245                    recurse_expr(used, arg, bound);
246                }
247            }
248            Expr::Cast(ExprCast { expr, ty, .. }) => {
249                recurse_expr(used, expr, bound);
250                recurse_type(used, ty, bound);
251            }
252            Expr::Closure(ExprClosure {
253                inputs,
254                output,
255                body,
256                ..
257            }) => {
258                for pat in inputs {
259                    recurse_pat(used, pat, bound);
260                }
261                if let ReturnType::Type(_, ty) = output {
262                    recurse_type(used, ty, bound);
263                }
264                recurse_expr(used, body, bound);
265            }
266            Expr::Continue(_) | Expr::Lit(_) => {}
267            Expr::ForLoop(ExprForLoop {
268                pat,
269                expr,
270                body: Block { stmts, .. },
271                ..
272            }) => {
273                recurse_pat(used, pat, bound);
274                recurse_expr(used, expr, bound);
275                for stmt in stmts {
276                    recurse_stmt(used, stmt, bound);
277                }
278            }
279            Expr::If(ExprIf {
280                cond,
281                then_branch: Block { stmts, .. },
282                else_branch,
283                ..
284            }) => {
285                recurse_expr(used, cond, bound);
286                for stmt in stmts {
287                    recurse_stmt(used, stmt, bound);
288                }
289                if let Some((_, expr)) = else_branch {
290                    recurse_expr(used, expr, bound);
291                }
292            }
293            Expr::Let(ExprLet { pat, expr, .. }) => {
294                recurse_pat(used, pat, bound);
295                recurse_expr(used, expr, bound);
296            }
297            Expr::Match(ExprMatch { expr, arms, .. }) => {
298                recurse_expr(used, expr, bound);
299                for Arm {
300                    pat, guard, body, ..
301                } in arms
302                {
303                    recurse_pat(used, pat, bound);
304                    if let Some((_, expr)) = guard {
305                        recurse_expr(used, expr, bound);
306                    }
307                    recurse_expr(used, body, bound);
308                }
309            }
310            Expr::MethodCall(ExprMethodCall {
311                receiver,
312                turbofish,
313                args,
314                ..
315            }) => {
316                recurse_expr(used, receiver, bound);
317                if let Some(args) = turbofish {
318                    process_generic_arguments(used, bound, args);
319                }
320                for arg in args {
321                    recurse_expr(used, arg, bound);
322                }
323            }
324            Expr::Path(ExprPath { qself, path, .. }) => {
325                process_path(used, bound, path, true);
326                if let Some(QSelf { ty, .. }) = qself {
327                    recurse_type(used, ty, bound);
328                }
329            }
330            Expr::Range(ExprRange { start, end, .. }) => {
331                if let Some(expr) = start {
332                    recurse_expr(used, expr, bound);
333                }
334                if let Some(expr) = end {
335                    recurse_expr(used, expr, bound);
336                }
337            }
338            Expr::Struct(ExprStruct { path, fields, .. }) => {
339                process_path(used, bound, path, true);
340                for FieldValue { expr, .. } in fields {
341                    recurse_expr(used, expr, bound);
342                }
343            }
344            // Expr::Verbatim(_) => todo!(),
345            Expr::While(ExprWhile {
346                cond,
347                body: Block { stmts, .. },
348                ..
349            }) => {
350                recurse_expr(used, cond, bound);
351                for stmt in stmts {
352                    recurse_stmt(used, stmt, bound);
353                }
354            }
355            Expr::Yield(ExprYield { expr, .. }) => {
356                if let Some(expr) = expr {
357                    recurse_expr(used, expr, bound);
358                }
359            }
360            expr => panic!("unsupported expression: {:?}", expr),
361        }
362    }
363
364    fn recurse_pat(used: &mut HashSet<GenericRef>, pat: &Pat, bound: &HashSet<GenericRef>) {
365        match pat {
366            Pat::Const(ExprConst {
367                block: Block { stmts, .. },
368                ..
369            }) => {
370                for stmt in stmts {
371                    recurse_stmt(used, stmt, bound);
372                }
373            }
374            Pat::Ident(_) | Pat::Lit(_) | Pat::Macro(_) | Pat::Rest(_) | Pat::Wild(_) => {}
375            Pat::Or(PatOr { cases, .. }) => {
376                for pat in cases {
377                    recurse_pat(used, pat, bound);
378                }
379            }
380            Pat::Paren(PatParen { pat, .. }) => recurse_pat(used, pat, bound),
381            Pat::Path(ExprPath { qself, path, .. }) => {
382                process_path(used, bound, path, true);
383                if let Some(QSelf { ty, .. }) = qself {
384                    recurse_type(used, ty, bound);
385                }
386            }
387            Pat::Range(ExprRange { start, end, .. }) => {
388                if let Some(expr) = start {
389                    recurse_expr(used, expr, bound);
390                }
391                if let Some(expr) = end {
392                    recurse_expr(used, expr, bound);
393                }
394            }
395            Pat::Reference(PatReference { pat, .. }) => recurse_pat(used, pat, bound),
396            Pat::Slice(PatSlice { elems, .. }) | Pat::Tuple(PatTuple { elems, .. }) => {
397                for pat in elems {
398                    recurse_pat(used, pat, bound);
399                }
400            }
401            Pat::Struct(PatStruct {
402                qself,
403                path,
404                fields,
405                ..
406            }) => {
407                process_path(used, bound, path, true);
408                if let Some(QSelf { ty, .. }) = qself {
409                    recurse_type(used, ty, bound);
410                }
411                for FieldPat { pat, .. } in fields {
412                    recurse_pat(used, pat, bound);
413                }
414            }
415            Pat::TupleStruct(PatTupleStruct {
416                qself, path, elems, ..
417            }) => {
418                process_path(used, bound, path, true);
419                if let Some(QSelf { ty, .. }) = qself {
420                    recurse_type(used, ty, bound);
421                }
422                for pat in elems {
423                    recurse_pat(used, pat, bound);
424                }
425            }
426            Pat::Type(PatType { pat, ty, .. }) => {
427                recurse_pat(used, pat, bound);
428                recurse_type(used, ty, bound);
429            }
430            _ => unimplemented!("unknown pattern kind: {:?}", pat),
431        }
432    }
433
434    fn recurse_stmt(used: &mut HashSet<GenericRef>, stmt: &Stmt, bound: &HashSet<GenericRef>) {
435        match stmt {
436            Stmt::Local(Local { pat, init, .. }) => {
437                recurse_pat(used, pat, bound);
438                if let Some(LocalInit { expr, diverge, .. }) = init {
439                    recurse_expr(used, expr, bound);
440                    if let Some((_, diverge)) = diverge {
441                        recurse_expr(used, &diverge, bound);
442                    }
443                }
444            }
445            Stmt::Expr(expr, _) => recurse_expr(used, expr, bound),
446            _ => unimplemented!("unknown statement kind: {:?}", stmt),
447        }
448    }
449
450    fn finalize(
451        used: &mut HashSet<GenericRef>,
452        bound: &HashSet<GenericRef>,
453        base: Generics,
454    ) -> Generics {
455        let mut args = Vec::new();
456        for arg in base.params.into_pairs().rev() {
457            match arg.value() {
458                GenericParam::Type(type_param) => {
459                    if used.contains(&GenericRef::Type(type_param.ident.clone())) {
460                        process_bounds(used, bound, type_param.bounds.iter());
461                        args.push(arg);
462                    }
463                }
464                GenericParam::Lifetime(lt) => {
465                    if used.contains(&GenericRef::Lifetime(lt.lifetime.clone())) {
466                        for b in &lt.bounds {
467                            process_lifetime(used, bound, b.clone());
468                        }
469                        args.push(arg);
470                    }
471                }
472                GenericParam::Const(_) => todo!(),
473            }
474        }
475
476        if args.is_empty() {
477            Generics::default()
478        } else {
479            Generics {
480                params: Punctuated::from_iter(args.into_iter().rev()),
481                ..base
482            }
483        }
484    }
485
486    let mut used = HashSet::new();
487    let bound = HashSet::from_iter(context.flat_map(|g| g.params.iter()).map(GenericRef::from));
488
489    for u in usage {
490        match u {
491            Usage::Type(ty) => recurse_type(&mut used, ty, &bound),
492            Usage::Lifetime(lt) => process_lifetime(&mut used, &bound, lt.clone()),
493            Usage::Expression(expr) => recurse_expr(&mut used, expr, &bound),
494            Usage::TypeBound(b) => process_bounds(&mut used, &bound, once(b)),
495        }
496    }
497
498    finalize(&mut used, &bound, base)
499}
500
501/// Converts a [`Generics`] object to a corresponding [`PathArguments`] object.
502///
503/// For example, could be used to convert `<T: 'static, U: From<T>>` to `<T, U>`
504/// in the following:
505///
506/// ```rust
507/// struct MyType<T: 'static, U: From<T>>(T, U);
508///
509/// impl<T: 'static, U: From<T>> MyType<T, U> {}
510/// ```
511pub fn generics_as_args(generics: &Generics) -> PathArguments {
512    if generics.params.is_empty() {
513        PathArguments::None
514    } else {
515        PathArguments::AngleBracketed(AngleBracketedGenericArguments {
516            colon2_token: None,
517            lt_token: generics.lt_token.unwrap_or_default(),
518            args: Punctuated::from_iter(generics.params.pairs().map(|p| {
519                let (param, punct) = p.into_tuple();
520                Pair::new(
521                    match param {
522                        GenericParam::Type(type_param) => {
523                            GenericArgument::Type(Type::Path(TypePath {
524                                qself: None,
525                                path: type_param.ident.clone().into(),
526                            }))
527                        }
528                        GenericParam::Lifetime(lt) => {
529                            GenericArgument::Lifetime(lt.lifetime.clone())
530                        }
531                        GenericParam::Const(_) => todo!(),
532                    },
533                    punct.cloned(),
534                )
535            })),
536            gt_token: generics.gt_token.unwrap_or_default(),
537        })
538    }
539}