use im::HashSet;
use syn::{
    punctuated::{Pair, Punctuated},
    AngleBracketedGenericArguments, BoundLifetimes, Expr, GenericArgument, GenericParam, Generics,
    Ident, Lifetime, Path, PathArguments, ReturnType, Type, TypeParamBound, TypePath,
    WherePredicate,
};
7
8pub fn filter_generics<'a>(
12 base: Generics,
13 usage: impl Iterator<Item = &'a Type>,
14 context: impl Iterator<Item = &'a Generics>,
15) -> Generics {
16 #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
17 enum GenericRef {
18 Lifetime(Lifetime),
19 TypeOrConst(Ident),
20 }
21
22 impl From<&GenericParam> for GenericRef {
23 fn from(value: &GenericParam) -> Self {
24 match value {
25 GenericParam::Type(type_param) => GenericRef::TypeOrConst(type_param.ident.clone()),
26 GenericParam::Lifetime(lt) => GenericRef::Lifetime(lt.lifetime.clone()),
27 GenericParam::Const(_) => todo!(),
28 }
29 }
30 }
31
32 fn add_bound_lifetimes(bound: &mut HashSet<GenericRef>, b: Option<&BoundLifetimes>) {
33 if let Some(lifetimes) = b {
34 bound.extend(
35 lifetimes
36 .lifetimes
37 .iter()
38 .flat_map(|lt| [<.lifetime].into_iter().chain(lt.bounds.iter()))
39 .map(|lt| GenericRef::Lifetime(lt.clone())),
40 );
41 }
42 }
43
44 fn process_lifetime(used: &mut HashSet<GenericRef>, bound: &HashSet<GenericRef>, lt: Lifetime) {
45 let r = GenericRef::Lifetime(lt);
46 if !bound.contains(&r) {
47 used.insert(r);
48 }
49 }
50
51 fn process_path(
52 used: &mut HashSet<GenericRef>,
53 bound: &HashSet<GenericRef>,
54 path: &Path,
55 unqualified: bool,
56 ) {
57 if let Some(ident) = path.get_ident() {
58 let r = GenericRef::TypeOrConst(ident.clone());
59 if !bound.contains(&r) {
60 used.insert(r);
61 }
62 } else {
63 let mut first = unqualified && path.leading_colon.is_none();
64 for s in &path.segments {
65 if first && s.arguments.is_empty() {
66 let r = GenericRef::TypeOrConst(s.ident.clone());
67 if !bound.contains(&r) {
68 used.insert(r);
69 }
70 }
71 first = false;
72 match &s.arguments {
73 PathArguments::None => {}
74 PathArguments::AngleBracketed(args) => {
75 for arg in &args.args {
76 match arg {
77 GenericArgument::Lifetime(lt) => {
78 process_lifetime(used, bound, lt.clone())
79 }
80 GenericArgument::Type(ty) => recurse(used, ty, bound),
81 GenericArgument::Const(_) => todo!(),
82 GenericArgument::Binding(binding) => {
83 recurse(used, &binding.ty, bound)
84 }
85 GenericArgument::Constraint(constraint) => {
86 process_bounds(used, bound, constraint.bounds.iter())
87 }
88 }
89 }
90 }
91 PathArguments::Parenthesized(args) => {
92 for ty in &args.inputs {
93 recurse(used, ty, bound);
94 }
95 if let ReturnType::Type(_, ty) = &args.output {
96 recurse(used, ty, bound)
97 }
98 }
99 }
100 }
101 }
102 }
103
104 fn process_bounds<'a>(
105 used: &mut HashSet<GenericRef>,
106 bound: &HashSet<GenericRef>,
107 b: impl Iterator<Item = &'a TypeParamBound>,
108 ) {
109 for b in b {
110 match b {
111 TypeParamBound::Trait(b) => {
112 let mut bound = bound.clone();
113 add_bound_lifetimes(&mut bound, b.lifetimes.as_ref());
114 process_path(used, &bound, &b.path, true);
115 }
116 TypeParamBound::Lifetime(lt) => process_lifetime(used, bound, lt.clone()),
117 }
118 }
119 }
120
121 fn recurse(used: &mut HashSet<GenericRef>, ty: &Type, bound: &HashSet<GenericRef>) {
122 match ty {
123 Type::Array(arr) => recurse(used, &arr.elem, bound),
124 Type::BareFn(bare_fn) => {
125 let mut bound = bound.clone();
126 add_bound_lifetimes(&mut bound, bare_fn.lifetimes.as_ref());
127
128 for input in &bare_fn.inputs {
129 recurse(used, &input.ty, &bound)
130 }
131
132 if let ReturnType::Type(_, ty) = &bare_fn.output {
133 recurse(used, ty, &bound)
134 }
135 }
136 Type::Group(group) => recurse(used, &group.elem, bound),
137 Type::ImplTrait(impl_trait) => process_bounds(used, bound, impl_trait.bounds.iter()),
138 Type::Never(_) => {}
141 Type::Paren(paren) => recurse(used, &paren.elem, bound),
142 Type::Path(path) => {
143 if let Some(qself) = &path.qself {
144 recurse(used, &qself.ty, bound);
145 }
146 process_path(used, bound, &path.path, path.qself.is_none());
147 }
148 Type::Ptr(ptr) => recurse(used, &ptr.elem, bound),
149 Type::Reference(reference) => recurse(used, &reference.elem, bound),
150 Type::Slice(slice) => recurse(used, &slice.elem, bound),
151 Type::TraitObject(trait_object) => {
152 process_bounds(used, bound, trait_object.bounds.iter());
153 }
154 Type::Tuple(tuple) => {
155 for ty in &tuple.elems {
156 recurse(used, ty, bound);
157 }
158 }
159 ty => panic!("unsupported type: {:?}", ty),
161 }
162 }
163
164 fn finalize(
165 used: &mut HashSet<GenericRef>,
166 bound: &HashSet<GenericRef>,
167 base: Generics,
168 ) -> Generics {
169 let mut args = Vec::new();
170 for arg in base.params.into_pairs().rev() {
171 match arg.value() {
172 GenericParam::Type(type_param) => {
173 if used.contains(&GenericRef::TypeOrConst(type_param.ident.clone())) {
174 process_bounds(used, bound, type_param.bounds.iter());
175 args.push(arg);
176 }
177 }
178 GenericParam::Lifetime(lt) => {
179 if used.contains(&GenericRef::Lifetime(lt.lifetime.clone())) {
180 for b in <.bounds {
181 process_lifetime(used, bound, b.clone());
182 }
183 args.push(arg);
184 }
185 }
186 GenericParam::Const(_) => todo!(),
187 }
188 }
189
190 if args.is_empty() {
191 Generics::default()
192 } else {
193 Generics {
194 params: Punctuated::from_iter(args.into_iter().rev()),
195 ..base
196 }
197 }
198 }
199
200 let mut used = HashSet::new();
201 let bound = HashSet::from_iter(context.flat_map(|g| g.params.iter()).map(GenericRef::from));
202
203 for ty in usage {
204 recurse(&mut used, ty, &bound);
205 }
206
207 finalize(&mut used, &bound, base)
208}
209
210pub fn generics_as_args(generics: &Generics) -> PathArguments {
221 if generics.params.is_empty() {
222 PathArguments::None
223 } else {
224 PathArguments::AngleBracketed(AngleBracketedGenericArguments {
225 colon2_token: None,
226 lt_token: generics.lt_token.unwrap_or_default(),
227 args: Punctuated::from_iter(generics.params.pairs().map(|p| {
228 let (param, punct) = p.into_tuple();
229 Pair::new(
230 match param {
231 GenericParam::Type(type_param) => {
232 GenericArgument::Type(Type::Path(TypePath {
233 qself: None,
234 path: type_param.ident.clone().into(),
235 }))
236 }
237 GenericParam::Lifetime(lt) => {
238 GenericArgument::Lifetime(lt.lifetime.clone())
239 }
240 GenericParam::Const(_) => todo!(),
241 },
242 punct.cloned(),
243 )
244 })),
245 gt_token: generics.gt_token.unwrap_or_default(),
246 })
247 }
248}