generics_util/
lib.rs

use im::HashSet;
use syn::{
    punctuated::{Pair, Punctuated},
    AngleBracketedGenericArguments, BoundLifetimes, Expr, GenericArgument, GenericParam, Generics,
    Ident, Lifetime, Path, PathArguments, ReturnType, Type, TypeParamBound, TypePath,
};
7
8/// Creates a new [`Generics`] object containing only the parameters which are
9/// used (including transitively) anywhere in the given sequence of types
10/// (`usage`), and excluding those which are already in scope (`context`).
11pub fn filter_generics<'a>(
12    base: Generics,
13    usage: impl Iterator<Item = &'a Type>,
14    context: impl Iterator<Item = &'a Generics>,
15) -> Generics {
16    #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
17    enum GenericRef {
18        Lifetime(Lifetime),
19        TypeOrConst(Ident),
20    }
21
22    impl From<&GenericParam> for GenericRef {
23        fn from(value: &GenericParam) -> Self {
24            match value {
25                GenericParam::Type(type_param) => GenericRef::TypeOrConst(type_param.ident.clone()),
26                GenericParam::Lifetime(lt) => GenericRef::Lifetime(lt.lifetime.clone()),
27                GenericParam::Const(_) => todo!(),
28            }
29        }
30    }
31
32    fn add_bound_lifetimes(bound: &mut HashSet<GenericRef>, b: Option<&BoundLifetimes>) {
33        if let Some(lifetimes) = b {
34            bound.extend(
35                lifetimes
36                    .lifetimes
37                    .iter()
38                    .flat_map(|lt| [&lt.lifetime].into_iter().chain(lt.bounds.iter()))
39                    .map(|lt| GenericRef::Lifetime(lt.clone())),
40            );
41        }
42    }
43
44    fn process_lifetime(used: &mut HashSet<GenericRef>, bound: &HashSet<GenericRef>, lt: Lifetime) {
45        let r = GenericRef::Lifetime(lt);
46        if !bound.contains(&r) {
47            used.insert(r);
48        }
49    }
50
51    fn process_path(
52        used: &mut HashSet<GenericRef>,
53        bound: &HashSet<GenericRef>,
54        path: &Path,
55        unqualified: bool,
56    ) {
57        if let Some(ident) = path.get_ident() {
58            let r = GenericRef::TypeOrConst(ident.clone());
59            if !bound.contains(&r) {
60                used.insert(r);
61            }
62        } else {
63            let mut first = unqualified && path.leading_colon.is_none();
64            for s in &path.segments {
65                if first && s.arguments.is_empty() {
66                    let r = GenericRef::TypeOrConst(s.ident.clone());
67                    if !bound.contains(&r) {
68                        used.insert(r);
69                    }
70                }
71                first = false;
72                match &s.arguments {
73                    PathArguments::None => {}
74                    PathArguments::AngleBracketed(args) => {
75                        for arg in &args.args {
76                            match arg {
77                                GenericArgument::Lifetime(lt) => {
78                                    process_lifetime(used, bound, lt.clone())
79                                }
80                                GenericArgument::Type(ty) => recurse(used, ty, bound),
81                                GenericArgument::Const(_) => todo!(),
82                                GenericArgument::Binding(binding) => {
83                                    recurse(used, &binding.ty, bound)
84                                }
85                                GenericArgument::Constraint(constraint) => {
86                                    process_bounds(used, bound, constraint.bounds.iter())
87                                }
88                            }
89                        }
90                    }
91                    PathArguments::Parenthesized(args) => {
92                        for ty in &args.inputs {
93                            recurse(used, ty, bound);
94                        }
95                        if let ReturnType::Type(_, ty) = &args.output {
96                            recurse(used, ty, bound)
97                        }
98                    }
99                }
100            }
101        }
102    }
103
104    fn process_bounds<'a>(
105        used: &mut HashSet<GenericRef>,
106        bound: &HashSet<GenericRef>,
107        b: impl Iterator<Item = &'a TypeParamBound>,
108    ) {
109        for b in b {
110            match b {
111                TypeParamBound::Trait(b) => {
112                    let mut bound = bound.clone();
113                    add_bound_lifetimes(&mut bound, b.lifetimes.as_ref());
114                    process_path(used, &bound, &b.path, true);
115                }
116                TypeParamBound::Lifetime(lt) => process_lifetime(used, bound, lt.clone()),
117            }
118        }
119    }
120
121    fn recurse(used: &mut HashSet<GenericRef>, ty: &Type, bound: &HashSet<GenericRef>) {
122        match ty {
123            Type::Array(arr) => recurse(used, &arr.elem, bound),
124            Type::BareFn(bare_fn) => {
125                let mut bound = bound.clone();
126                add_bound_lifetimes(&mut bound, bare_fn.lifetimes.as_ref());
127
128                for input in &bare_fn.inputs {
129                    recurse(used, &input.ty, &bound)
130                }
131
132                if let ReturnType::Type(_, ty) = &bare_fn.output {
133                    recurse(used, ty, &bound)
134                }
135            }
136            Type::Group(group) => recurse(used, &group.elem, bound),
137            Type::ImplTrait(impl_trait) => process_bounds(used, bound, impl_trait.bounds.iter()),
138            // Type::Infer(_) => todo!(),
139            // Type::Macro(_) => todo!(),
140            Type::Never(_) => {}
141            Type::Paren(paren) => recurse(used, &paren.elem, bound),
142            Type::Path(path) => {
143                if let Some(qself) = &path.qself {
144                    recurse(used, &qself.ty, bound);
145                }
146                process_path(used, bound, &path.path, path.qself.is_none());
147            }
148            Type::Ptr(ptr) => recurse(used, &ptr.elem, bound),
149            Type::Reference(reference) => recurse(used, &reference.elem, bound),
150            Type::Slice(slice) => recurse(used, &slice.elem, bound),
151            Type::TraitObject(trait_object) => {
152                process_bounds(used, bound, trait_object.bounds.iter());
153            }
154            Type::Tuple(tuple) => {
155                for ty in &tuple.elems {
156                    recurse(used, ty, bound);
157                }
158            }
159            // Type::Verbatim(_) => todo!(),
160            ty => panic!("unsupported type: {:?}", ty),
161        }
162    }
163
164    fn finalize(
165        used: &mut HashSet<GenericRef>,
166        bound: &HashSet<GenericRef>,
167        base: Generics,
168    ) -> Generics {
169        let mut args = Vec::new();
170        for arg in base.params.into_pairs().rev() {
171            match arg.value() {
172                GenericParam::Type(type_param) => {
173                    if used.contains(&GenericRef::TypeOrConst(type_param.ident.clone())) {
174                        process_bounds(used, bound, type_param.bounds.iter());
175                        args.push(arg);
176                    }
177                }
178                GenericParam::Lifetime(lt) => {
179                    if used.contains(&GenericRef::Lifetime(lt.lifetime.clone())) {
180                        for b in &lt.bounds {
181                            process_lifetime(used, bound, b.clone());
182                        }
183                        args.push(arg);
184                    }
185                }
186                GenericParam::Const(_) => todo!(),
187            }
188        }
189
190        if args.is_empty() {
191            Generics::default()
192        } else {
193            Generics {
194                params: Punctuated::from_iter(args.into_iter().rev()),
195                ..base
196            }
197        }
198    }
199
200    let mut used = HashSet::new();
201    let bound = HashSet::from_iter(context.flat_map(|g| g.params.iter()).map(GenericRef::from));
202
203    for ty in usage {
204        recurse(&mut used, ty, &bound);
205    }
206
207    finalize(&mut used, &bound, base)
208}
209
210/// Converts a [`Generics`] object to a corresponding [`PathArguments`] object.
211///
212/// For example, could be used to convert `<T: 'static, U: From<T>>` to `<T, U>`
213/// in the following:
214///
215/// ```rust
216/// struct MyType<T: 'static, U: From<T>>(T, U);
217///
218/// impl<T: 'static, U: From<T>> MyType<T, U> {}
219/// ```
220pub fn generics_as_args(generics: &Generics) -> PathArguments {
221    if generics.params.is_empty() {
222        PathArguments::None
223    } else {
224        PathArguments::AngleBracketed(AngleBracketedGenericArguments {
225            colon2_token: None,
226            lt_token: generics.lt_token.unwrap_or_default(),
227            args: Punctuated::from_iter(generics.params.pairs().map(|p| {
228                let (param, punct) = p.into_tuple();
229                Pair::new(
230                    match param {
231                        GenericParam::Type(type_param) => {
232                            GenericArgument::Type(Type::Path(TypePath {
233                                qself: None,
234                                path: type_param.ident.clone().into(),
235                            }))
236                        }
237                        GenericParam::Lifetime(lt) => {
238                            GenericArgument::Lifetime(lt.lifetime.clone())
239                        }
240                        GenericParam::Const(_) => todo!(),
241                    },
242                    punct.cloned(),
243                )
244            })),
245            gt_token: generics.gt_token.unwrap_or_default(),
246        })
247    }
248}