generics_util/
lib.rs

1use imbl::HashSet;
2use syn::{
3    punctuated::{Pair, Punctuated},
4    *,
5};
6
7pub enum Usage<'a> {
8    Type(&'a Type),
9    Expression(&'a Expr),
10}
11
12/// Creates a new [`Generics`] object containing only the parameters which are
13/// used (including transitively) anywhere in the given sequence of types and
14/// expressions (`usage`), and excluding those which are already in scope
15/// (`context`).
16pub fn filter_generics<'a>(
17    base: Generics,
18    usage: impl Iterator<Item = Usage<'a>>,
19    context: impl Iterator<Item = &'a Generics>,
20) -> Generics {
21    #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
22    enum GenericRef {
23        Type(Ident),
24        Lifetime(Lifetime),
25        Const(Ident),
26    }
27
28    impl From<&GenericParam> for GenericRef {
29        fn from(value: &GenericParam) -> Self {
30            match value {
31                GenericParam::Type(type_param) => GenericRef::Type(type_param.ident.clone()),
32                GenericParam::Lifetime(lt) => GenericRef::Lifetime(lt.lifetime.clone()),
33                GenericParam::Const(c) => GenericRef::Const(c.ident.clone()),
34            }
35        }
36    }
37
38    fn add_bound_lifetimes(_bound: &mut HashSet<GenericRef>, b: Option<&BoundLifetimes>) {
39        if let Some(_lifetimes) = b {
40            // bound.extend(
41            //     lifetimes
42            //         .lifetimes
43            //         .iter()
44            //         .flat_map(|param| [&param].into_iter().chain(param.bounds.iter()))
45            //         .map(|lt| GenericRef::Lifetime(lt.clone())),
46            // );
47            unimplemented!()
48        }
49    }
50
51    fn process_lifetime(used: &mut HashSet<GenericRef>, bound: &HashSet<GenericRef>, lt: Lifetime) {
52        let r = GenericRef::Lifetime(lt);
53        if !bound.contains(&r) {
54            used.insert(r);
55        }
56    }
57
58    fn process_generic_arguments(
59        used: &mut HashSet<GenericRef>,
60        bound: &HashSet<GenericRef>,
61        args: &AngleBracketedGenericArguments,
62    ) {
63        for arg in &args.args {
64            match arg {
65                GenericArgument::Lifetime(lt) => process_lifetime(used, bound, lt.clone()),
66                GenericArgument::Type(ty) => recurse_type(used, ty, bound),
67                GenericArgument::Const(_) => unimplemented!(),
68                GenericArgument::AssocType(assoc_type) => recurse_type(used, &assoc_type.ty, bound),
69                GenericArgument::AssocConst(assoc_const) => {
70                    recurse_expr(used, &assoc_const.value, bound)
71                }
72                GenericArgument::Constraint(constraint) => {
73                    process_bounds(used, bound, constraint.bounds.iter())
74                }
75                _ => unimplemented!(),
76            }
77        }
78    }
79
80    fn process_path(
81        used: &mut HashSet<GenericRef>,
82        bound: &HashSet<GenericRef>,
83        path: &Path,
84        unqualified: bool,
85    ) {
86        if let Some(ident) = path.get_ident() {
87            let r = GenericRef::Type(ident.clone());
88            if !bound.contains(&r) {
89                used.insert(r);
90            }
91        } else {
92            let mut first = unqualified && path.leading_colon.is_none();
93            for s in &path.segments {
94                if first && s.arguments.is_empty() {
95                    let r = GenericRef::Type(s.ident.clone());
96                    if !bound.contains(&r) {
97                        used.insert(r);
98                    }
99                }
100                first = false;
101                match &s.arguments {
102                    PathArguments::None => {}
103                    PathArguments::AngleBracketed(args) => {
104                        process_generic_arguments(used, bound, args)
105                    }
106                    PathArguments::Parenthesized(args) => {
107                        for ty in &args.inputs {
108                            recurse_type(used, ty, bound);
109                        }
110                        if let ReturnType::Type(_, ty) = &args.output {
111                            recurse_type(used, ty, bound)
112                        }
113                    }
114                }
115            }
116        }
117    }
118
119    fn process_bounds<'a>(
120        used: &mut HashSet<GenericRef>,
121        bound: &HashSet<GenericRef>,
122        b: impl Iterator<Item = &'a TypeParamBound>,
123    ) {
124        for b in b {
125            match b {
126                TypeParamBound::Trait(b) => {
127                    let mut bound = bound.clone();
128                    add_bound_lifetimes(&mut bound, b.lifetimes.as_ref());
129                    process_path(used, &bound, &b.path, true);
130                }
131                TypeParamBound::Lifetime(lt) => process_lifetime(used, bound, lt.clone()),
132                _ => unimplemented!(),
133            }
134        }
135    }
136
137    fn recurse_type(used: &mut HashSet<GenericRef>, ty: &Type, bound: &HashSet<GenericRef>) {
138        match ty {
139            Type::Array(arr) => recurse_type(used, &arr.elem, bound),
140            Type::BareFn(bare_fn) => {
141                let mut bound = bound.clone();
142                add_bound_lifetimes(&mut bound, bare_fn.lifetimes.as_ref());
143
144                for input in &bare_fn.inputs {
145                    recurse_type(used, &input.ty, &bound)
146                }
147
148                if let ReturnType::Type(_, ty) = &bare_fn.output {
149                    recurse_type(used, ty, &bound)
150                }
151            }
152            Type::Group(group) => recurse_type(used, &group.elem, bound),
153            Type::ImplTrait(impl_trait) => process_bounds(used, bound, impl_trait.bounds.iter()),
154            // Type::Infer(_) => todo!(),
155            // Type::Macro(_) => todo!(),
156            Type::Never(_) => {}
157            Type::Paren(paren) => recurse_type(used, &paren.elem, bound),
158            Type::Path(path) => {
159                if let Some(qself) = &path.qself {
160                    recurse_type(used, &qself.ty, bound);
161                }
162                process_path(used, bound, &path.path, path.qself.is_none());
163            }
164            Type::Ptr(ptr) => recurse_type(used, &ptr.elem, bound),
165            Type::Reference(reference) => recurse_type(used, &reference.elem, bound),
166            Type::Slice(slice) => recurse_type(used, &slice.elem, bound),
167            Type::TraitObject(trait_object) => {
168                process_bounds(used, bound, trait_object.bounds.iter());
169            }
170            Type::Tuple(tuple) => {
171                for ty in &tuple.elems {
172                    recurse_type(used, ty, bound);
173                }
174            }
175            // Type::Verbatim(_) => todo!(),
176            ty => panic!("unsupported type: {:?}", ty),
177        }
178    }
179
180    fn recurse_expr(used: &mut HashSet<GenericRef>, expr: &Expr, bound: &HashSet<GenericRef>) {
181        match expr {
182            Expr::Array(ExprArray { elems, .. }) | Expr::Tuple(ExprTuple { elems, .. }) => {
183                for expr in elems {
184                    recurse_expr(used, expr, bound);
185                }
186            }
187            Expr::Assign(ExprAssign { left, right, .. })
188            | Expr::Binary(ExprBinary { left, right, .. })
189            | Expr::Index(ExprIndex {
190                expr: left,
191                index: right,
192                ..
193            })
194            | Expr::Repeat(ExprRepeat {
195                expr: left,
196                len: right,
197                ..
198            }) => {
199                recurse_expr(used, left, bound);
200                recurse_expr(used, right, bound);
201            }
202            Expr::Async(ExprAsync {
203                block: Block { stmts, .. },
204                ..
205            })
206            | Expr::Block(ExprBlock {
207                block: Block { stmts, .. },
208                ..
209            })
210            | Expr::Loop(ExprLoop {
211                body: Block { stmts, .. },
212                ..
213            })
214            | Expr::TryBlock(ExprTryBlock {
215                block: Block { stmts, .. },
216                ..
217            })
218            | Expr::Unsafe(ExprUnsafe {
219                block: Block { stmts, .. },
220                ..
221            }) => {
222                for stmt in stmts {
223                    recurse_stmt(used, stmt, bound);
224                }
225            }
226            Expr::Await(ExprAwait { base: expr, .. })
227            | Expr::Field(ExprField { base: expr, .. })
228            | Expr::Group(ExprGroup { expr, .. })
229            | Expr::Paren(ExprParen { expr, .. })
230            | Expr::Reference(ExprReference { expr, .. })
231            | Expr::Try(ExprTry { expr, .. })
232            | Expr::Unary(ExprUnary { expr, .. }) => recurse_expr(used, expr, bound),
233            Expr::Break(ExprBreak { expr, .. }) | Expr::Return(ExprReturn { expr, .. }) => {
234                if let Some(expr) = expr {
235                    recurse_expr(used, expr, bound);
236                }
237            }
238            Expr::Call(ExprCall { func, args, .. }) => {
239                recurse_expr(used, func, bound);
240                for arg in args {
241                    recurse_expr(used, arg, bound);
242                }
243            }
244            Expr::Cast(ExprCast { expr, ty, .. }) => {
245                recurse_expr(used, expr, bound);
246                recurse_type(used, ty, bound);
247            }
248            Expr::Closure(ExprClosure {
249                inputs,
250                output,
251                body,
252                ..
253            }) => {
254                for pat in inputs {
255                    recurse_pat(used, pat, bound);
256                }
257                if let ReturnType::Type(_, ty) = output {
258                    recurse_type(used, ty, bound);
259                }
260                recurse_expr(used, body, bound);
261            }
262            Expr::Continue(_) | Expr::Lit(_) => {}
263            Expr::ForLoop(ExprForLoop {
264                pat,
265                expr,
266                body: Block { stmts, .. },
267                ..
268            }) => {
269                recurse_pat(used, pat, bound);
270                recurse_expr(used, expr, bound);
271                for stmt in stmts {
272                    recurse_stmt(used, stmt, bound);
273                }
274            }
275            Expr::If(ExprIf {
276                cond,
277                then_branch: Block { stmts, .. },
278                else_branch,
279                ..
280            }) => {
281                recurse_expr(used, cond, bound);
282                for stmt in stmts {
283                    recurse_stmt(used, stmt, bound);
284                }
285                if let Some((_, expr)) = else_branch {
286                    recurse_expr(used, expr, bound);
287                }
288            }
289            Expr::Let(ExprLet { pat, expr, .. }) => {
290                recurse_pat(used, pat, bound);
291                recurse_expr(used, expr, bound);
292            }
293            Expr::Match(ExprMatch { expr, arms, .. }) => {
294                recurse_expr(used, expr, bound);
295                for Arm {
296                    pat, guard, body, ..
297                } in arms
298                {
299                    recurse_pat(used, pat, bound);
300                    if let Some((_, expr)) = guard {
301                        recurse_expr(used, expr, bound);
302                    }
303                    recurse_expr(used, body, bound);
304                }
305            }
306            Expr::MethodCall(ExprMethodCall {
307                receiver,
308                turbofish,
309                args,
310                ..
311            }) => {
312                recurse_expr(used, receiver, bound);
313                if let Some(args) = turbofish {
314                    process_generic_arguments(used, bound, args);
315                }
316                for arg in args {
317                    recurse_expr(used, arg, bound);
318                }
319            }
320            Expr::Path(ExprPath { qself, path, .. }) => {
321                process_path(used, bound, path, true);
322                if let Some(QSelf { ty, .. }) = qself {
323                    recurse_type(used, ty, bound);
324                }
325            }
326            Expr::Range(ExprRange { start, end, .. }) => {
327                if let Some(expr) = start {
328                    recurse_expr(used, expr, bound);
329                }
330                if let Some(expr) = end {
331                    recurse_expr(used, expr, bound);
332                }
333            }
334            Expr::Struct(ExprStruct { path, fields, .. }) => {
335                process_path(used, bound, path, true);
336                for FieldValue { expr, .. } in fields {
337                    recurse_expr(used, expr, bound);
338                }
339            }
340            // Expr::Verbatim(_) => todo!(),
341            Expr::While(ExprWhile {
342                cond,
343                body: Block { stmts, .. },
344                ..
345            }) => {
346                recurse_expr(used, cond, bound);
347                for stmt in stmts {
348                    recurse_stmt(used, stmt, bound);
349                }
350            }
351            Expr::Yield(ExprYield { expr, .. }) => {
352                if let Some(expr) = expr {
353                    recurse_expr(used, expr, bound);
354                }
355            }
356            expr => panic!("unsupported expression: {:?}", expr),
357        }
358    }
359
360    fn recurse_pat(used: &mut HashSet<GenericRef>, pat: &Pat, bound: &HashSet<GenericRef>) {
361        match pat {
362            Pat::Const(ExprConst {
363                block: Block { stmts, .. },
364                ..
365            }) => {
366                for stmt in stmts {
367                    recurse_stmt(used, stmt, bound);
368                }
369            }
370            Pat::Ident(_) | Pat::Lit(_) | Pat::Macro(_) | Pat::Rest(_) | Pat::Wild(_) => {}
371            Pat::Or(PatOr { cases, .. }) => {
372                for pat in cases {
373                    recurse_pat(used, pat, bound);
374                }
375            }
376            Pat::Paren(PatParen { pat, .. }) => recurse_pat(used, pat, bound),
377            Pat::Path(ExprPath { qself, path, .. }) => {
378                process_path(used, bound, path, true);
379                if let Some(QSelf { ty, .. }) = qself {
380                    recurse_type(used, ty, bound);
381                }
382            }
383            Pat::Range(ExprRange { start, end, .. }) => {
384                if let Some(expr) = start {
385                    recurse_expr(used, expr, bound);
386                }
387                if let Some(expr) = end {
388                    recurse_expr(used, expr, bound);
389                }
390            }
391            Pat::Reference(PatReference { pat, .. }) => recurse_pat(used, pat, bound),
392            Pat::Slice(PatSlice { elems, .. }) | Pat::Tuple(PatTuple { elems, .. }) => {
393                for pat in elems {
394                    recurse_pat(used, pat, bound);
395                }
396            }
397            Pat::Struct(PatStruct {
398                qself,
399                path,
400                fields,
401                ..
402            }) => {
403                process_path(used, bound, path, true);
404                if let Some(QSelf { ty, .. }) = qself {
405                    recurse_type(used, ty, bound);
406                }
407                for FieldPat { pat, .. } in fields {
408                    recurse_pat(used, pat, bound);
409                }
410            }
411            Pat::TupleStruct(PatTupleStruct {
412                qself, path, elems, ..
413            }) => {
414                process_path(used, bound, path, true);
415                if let Some(QSelf { ty, .. }) = qself {
416                    recurse_type(used, ty, bound);
417                }
418                for pat in elems {
419                    recurse_pat(used, pat, bound);
420                }
421            }
422            Pat::Type(PatType { pat, ty, .. }) => {
423                recurse_pat(used, pat, bound);
424                recurse_type(used, ty, bound);
425            }
426            _ => unimplemented!(),
427        }
428    }
429
430    fn recurse_stmt(used: &mut HashSet<GenericRef>, stmt: &Stmt, bound: &HashSet<GenericRef>) {
431        match stmt {
432            Stmt::Local(Local { pat, init, .. }) => {
433                recurse_pat(used, pat, bound);
434                if let Some(LocalInit { expr, diverge, .. }) = init {
435                    recurse_expr(used, expr, bound);
436                    if let Some((_, diverge)) = diverge {
437                        recurse_expr(used, &diverge, bound);
438                    }
439                }
440            }
441            Stmt::Expr(expr, _) => recurse_expr(used, expr, bound),
442            _ => unimplemented!(),
443        }
444    }
445
446    fn finalize(
447        used: &mut HashSet<GenericRef>,
448        bound: &HashSet<GenericRef>,
449        base: Generics,
450    ) -> Generics {
451        let mut args = Vec::new();
452        for arg in base.params.into_pairs().rev() {
453            match arg.value() {
454                GenericParam::Type(type_param) => {
455                    if used.contains(&GenericRef::Type(type_param.ident.clone())) {
456                        process_bounds(used, bound, type_param.bounds.iter());
457                        args.push(arg);
458                    }
459                }
460                GenericParam::Lifetime(lt) => {
461                    if used.contains(&GenericRef::Lifetime(lt.lifetime.clone())) {
462                        for b in &lt.bounds {
463                            process_lifetime(used, bound, b.clone());
464                        }
465                        args.push(arg);
466                    }
467                }
468                GenericParam::Const(_) => todo!(),
469            }
470        }
471
472        if args.is_empty() {
473            Generics::default()
474        } else {
475            Generics {
476                params: Punctuated::from_iter(args.into_iter().rev()),
477                ..base
478            }
479        }
480    }
481
482    let mut used = HashSet::new();
483    let bound = HashSet::from_iter(context.flat_map(|g| g.params.iter()).map(GenericRef::from));
484
485    for u in usage {
486        match u {
487            Usage::Type(ty) => recurse_type(&mut used, ty, &bound),
488            Usage::Expression(expr) => recurse_expr(&mut used, expr, &bound),
489        }
490    }
491
492    finalize(&mut used, &bound, base)
493}
494
495/// Converts a [`Generics`] object to a corresponding [`PathArguments`] object.
496///
497/// For example, could be used to convert `<T: 'static, U: From<T>>` to `<T, U>`
498/// in the following:
499///
500/// ```rust
501/// struct MyType<T: 'static, U: From<T>>(T, U);
502///
503/// impl<T: 'static, U: From<T>> MyType<T, U> {}
504/// ```
505pub fn generics_as_args(generics: &Generics) -> PathArguments {
506    if generics.params.is_empty() {
507        PathArguments::None
508    } else {
509        PathArguments::AngleBracketed(AngleBracketedGenericArguments {
510            colon2_token: None,
511            lt_token: generics.lt_token.unwrap_or_default(),
512            args: Punctuated::from_iter(generics.params.pairs().map(|p| {
513                let (param, punct) = p.into_tuple();
514                Pair::new(
515                    match param {
516                        GenericParam::Type(type_param) => {
517                            GenericArgument::Type(Type::Path(TypePath {
518                                qself: None,
519                                path: type_param.ident.clone().into(),
520                            }))
521                        }
522                        GenericParam::Lifetime(lt) => {
523                            GenericArgument::Lifetime(lt.lifetime.clone())
524                        }
525                        GenericParam::Const(_) => todo!(),
526                    },
527                    punct.cloned(),
528                )
529            })),
530            gt_token: generics.gt_token.unwrap_or_default(),
531        })
532    }
533}