1use std::iter::once;
2
3use imbl::HashSet;
4use syn::{
5 punctuated::{Pair, Punctuated},
6 *,
7};
8
9pub enum Usage<'a> {
10 Type(&'a Type),
11 Lifetime(&'a Lifetime),
12 Expression(&'a Expr),
13 TypeBound(&'a TypeParamBound),
14}
15
16pub fn filter_generics<'a>(
21 base: Generics,
22 usage: impl Iterator<Item = Usage<'a>>,
23 context: impl Iterator<Item = &'a Generics>,
24) -> Generics {
25 #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
26 enum GenericRef {
27 Type(Ident),
28 Lifetime(Lifetime),
29 Const(Ident),
30 }
31
32 impl From<&GenericParam> for GenericRef {
33 fn from(value: &GenericParam) -> Self {
34 match value {
35 GenericParam::Type(type_param) => GenericRef::Type(type_param.ident.clone()),
36 GenericParam::Lifetime(lt) => GenericRef::Lifetime(lt.lifetime.clone()),
37 GenericParam::Const(c) => GenericRef::Const(c.ident.clone()),
38 }
39 }
40 }
41
42 fn add_bound_lifetimes(_bound: &mut HashSet<GenericRef>, b: Option<&BoundLifetimes>) {
43 if let Some(_lifetimes) = b {
44 unimplemented!("bound lifetimes")
52 }
53 }
54
55 fn process_lifetime(used: &mut HashSet<GenericRef>, bound: &HashSet<GenericRef>, lt: Lifetime) {
56 let r = GenericRef::Lifetime(lt);
57 if !bound.contains(&r) {
58 used.insert(r);
59 }
60 }
61
62 fn process_generic_arguments(
63 used: &mut HashSet<GenericRef>,
64 bound: &HashSet<GenericRef>,
65 args: &AngleBracketedGenericArguments,
66 ) {
67 for arg in &args.args {
68 match arg {
69 GenericArgument::Lifetime(lt) => process_lifetime(used, bound, lt.clone()),
70 GenericArgument::Type(ty) => recurse_type(used, ty, bound),
71 GenericArgument::Const(expr) => recurse_expr(used, expr, bound),
72 GenericArgument::AssocType(assoc_type) => recurse_type(used, &assoc_type.ty, bound),
73 GenericArgument::AssocConst(assoc_const) => {
74 recurse_expr(used, &assoc_const.value, bound)
75 }
76 GenericArgument::Constraint(constraint) => {
77 process_bounds(used, bound, constraint.bounds.iter())
78 }
79 _ => unimplemented!("unknown generic argument kind: {:?}", arg),
80 }
81 }
82 }
83
84 fn process_path(
85 used: &mut HashSet<GenericRef>,
86 bound: &HashSet<GenericRef>,
87 path: &Path,
88 unqualified: bool,
89 ) {
90 if let Some(ident) = path.get_ident() {
91 let r = GenericRef::Type(ident.clone());
92 if !bound.contains(&r) {
93 used.insert(r);
94 }
95 } else {
96 let mut first = unqualified && path.leading_colon.is_none();
97 for s in &path.segments {
98 if first && s.arguments.is_empty() {
99 let r = GenericRef::Type(s.ident.clone());
100 if !bound.contains(&r) {
101 used.insert(r);
102 }
103 }
104 first = false;
105 match &s.arguments {
106 PathArguments::None => {}
107 PathArguments::AngleBracketed(args) => {
108 process_generic_arguments(used, bound, args)
109 }
110 PathArguments::Parenthesized(args) => {
111 for ty in &args.inputs {
112 recurse_type(used, ty, bound);
113 }
114 if let ReturnType::Type(_, ty) = &args.output {
115 recurse_type(used, ty, bound)
116 }
117 }
118 }
119 }
120 }
121 }
122
123 fn process_bounds<'a>(
124 used: &mut HashSet<GenericRef>,
125 bound: &HashSet<GenericRef>,
126 b: impl Iterator<Item = &'a TypeParamBound>,
127 ) {
128 for b in b {
129 match b {
130 TypeParamBound::Trait(b) => {
131 let mut bound = bound.clone();
132 add_bound_lifetimes(&mut bound, b.lifetimes.as_ref());
133 process_path(used, &bound, &b.path, true);
134 }
135 TypeParamBound::Lifetime(lt) => process_lifetime(used, bound, lt.clone()),
136 _ => unimplemented!("unknown type param bound kind: {:?}", b),
137 }
138 }
139 }
140
141 fn recurse_type(used: &mut HashSet<GenericRef>, ty: &Type, bound: &HashSet<GenericRef>) {
142 match ty {
143 Type::Array(arr) => recurse_type(used, &arr.elem, bound),
144 Type::BareFn(bare_fn) => {
145 let mut bound = bound.clone();
146 add_bound_lifetimes(&mut bound, bare_fn.lifetimes.as_ref());
147
148 for input in &bare_fn.inputs {
149 recurse_type(used, &input.ty, &bound)
150 }
151
152 if let ReturnType::Type(_, ty) = &bare_fn.output {
153 recurse_type(used, ty, &bound)
154 }
155 }
156 Type::Group(group) => recurse_type(used, &group.elem, bound),
157 Type::ImplTrait(impl_trait) => process_bounds(used, bound, impl_trait.bounds.iter()),
158 Type::Never(_) => {}
161 Type::Paren(paren) => recurse_type(used, &paren.elem, bound),
162 Type::Path(path) => {
163 if let Some(qself) = &path.qself {
164 recurse_type(used, &qself.ty, bound);
165 }
166 process_path(used, bound, &path.path, path.qself.is_none());
167 }
168 Type::Ptr(ptr) => recurse_type(used, &ptr.elem, bound),
169 Type::Reference(reference) => recurse_type(used, &reference.elem, bound),
170 Type::Slice(slice) => recurse_type(used, &slice.elem, bound),
171 Type::TraitObject(trait_object) => {
172 process_bounds(used, bound, trait_object.bounds.iter());
173 }
174 Type::Tuple(tuple) => {
175 for ty in &tuple.elems {
176 recurse_type(used, ty, bound);
177 }
178 }
179 ty => panic!("unsupported type: {:?}", ty),
181 }
182 }
183
184 fn recurse_expr(used: &mut HashSet<GenericRef>, expr: &Expr, bound: &HashSet<GenericRef>) {
185 match expr {
186 Expr::Array(ExprArray { elems, .. }) | Expr::Tuple(ExprTuple { elems, .. }) => {
187 for expr in elems {
188 recurse_expr(used, expr, bound);
189 }
190 }
191 Expr::Assign(ExprAssign { left, right, .. })
192 | Expr::Binary(ExprBinary { left, right, .. })
193 | Expr::Index(ExprIndex {
194 expr: left,
195 index: right,
196 ..
197 })
198 | Expr::Repeat(ExprRepeat {
199 expr: left,
200 len: right,
201 ..
202 }) => {
203 recurse_expr(used, left, bound);
204 recurse_expr(used, right, bound);
205 }
206 Expr::Async(ExprAsync {
207 block: Block { stmts, .. },
208 ..
209 })
210 | Expr::Block(ExprBlock {
211 block: Block { stmts, .. },
212 ..
213 })
214 | Expr::Loop(ExprLoop {
215 body: Block { stmts, .. },
216 ..
217 })
218 | Expr::TryBlock(ExprTryBlock {
219 block: Block { stmts, .. },
220 ..
221 })
222 | Expr::Unsafe(ExprUnsafe {
223 block: Block { stmts, .. },
224 ..
225 }) => {
226 for stmt in stmts {
227 recurse_stmt(used, stmt, bound);
228 }
229 }
230 Expr::Await(ExprAwait { base: expr, .. })
231 | Expr::Field(ExprField { base: expr, .. })
232 | Expr::Group(ExprGroup { expr, .. })
233 | Expr::Paren(ExprParen { expr, .. })
234 | Expr::Reference(ExprReference { expr, .. })
235 | Expr::Try(ExprTry { expr, .. })
236 | Expr::Unary(ExprUnary { expr, .. }) => recurse_expr(used, expr, bound),
237 Expr::Break(ExprBreak { expr, .. }) | Expr::Return(ExprReturn { expr, .. }) => {
238 if let Some(expr) = expr {
239 recurse_expr(used, expr, bound);
240 }
241 }
242 Expr::Call(ExprCall { func, args, .. }) => {
243 recurse_expr(used, func, bound);
244 for arg in args {
245 recurse_expr(used, arg, bound);
246 }
247 }
248 Expr::Cast(ExprCast { expr, ty, .. }) => {
249 recurse_expr(used, expr, bound);
250 recurse_type(used, ty, bound);
251 }
252 Expr::Closure(ExprClosure {
253 inputs,
254 output,
255 body,
256 ..
257 }) => {
258 for pat in inputs {
259 recurse_pat(used, pat, bound);
260 }
261 if let ReturnType::Type(_, ty) = output {
262 recurse_type(used, ty, bound);
263 }
264 recurse_expr(used, body, bound);
265 }
266 Expr::Continue(_) | Expr::Lit(_) => {}
267 Expr::ForLoop(ExprForLoop {
268 pat,
269 expr,
270 body: Block { stmts, .. },
271 ..
272 }) => {
273 recurse_pat(used, pat, bound);
274 recurse_expr(used, expr, bound);
275 for stmt in stmts {
276 recurse_stmt(used, stmt, bound);
277 }
278 }
279 Expr::If(ExprIf {
280 cond,
281 then_branch: Block { stmts, .. },
282 else_branch,
283 ..
284 }) => {
285 recurse_expr(used, cond, bound);
286 for stmt in stmts {
287 recurse_stmt(used, stmt, bound);
288 }
289 if let Some((_, expr)) = else_branch {
290 recurse_expr(used, expr, bound);
291 }
292 }
293 Expr::Let(ExprLet { pat, expr, .. }) => {
294 recurse_pat(used, pat, bound);
295 recurse_expr(used, expr, bound);
296 }
297 Expr::Match(ExprMatch { expr, arms, .. }) => {
298 recurse_expr(used, expr, bound);
299 for Arm {
300 pat, guard, body, ..
301 } in arms
302 {
303 recurse_pat(used, pat, bound);
304 if let Some((_, expr)) = guard {
305 recurse_expr(used, expr, bound);
306 }
307 recurse_expr(used, body, bound);
308 }
309 }
310 Expr::MethodCall(ExprMethodCall {
311 receiver,
312 turbofish,
313 args,
314 ..
315 }) => {
316 recurse_expr(used, receiver, bound);
317 if let Some(args) = turbofish {
318 process_generic_arguments(used, bound, args);
319 }
320 for arg in args {
321 recurse_expr(used, arg, bound);
322 }
323 }
324 Expr::Path(ExprPath { qself, path, .. }) => {
325 process_path(used, bound, path, true);
326 if let Some(QSelf { ty, .. }) = qself {
327 recurse_type(used, ty, bound);
328 }
329 }
330 Expr::Range(ExprRange { start, end, .. }) => {
331 if let Some(expr) = start {
332 recurse_expr(used, expr, bound);
333 }
334 if let Some(expr) = end {
335 recurse_expr(used, expr, bound);
336 }
337 }
338 Expr::Struct(ExprStruct { path, fields, .. }) => {
339 process_path(used, bound, path, true);
340 for FieldValue { expr, .. } in fields {
341 recurse_expr(used, expr, bound);
342 }
343 }
344 Expr::While(ExprWhile {
346 cond,
347 body: Block { stmts, .. },
348 ..
349 }) => {
350 recurse_expr(used, cond, bound);
351 for stmt in stmts {
352 recurse_stmt(used, stmt, bound);
353 }
354 }
355 Expr::Yield(ExprYield { expr, .. }) => {
356 if let Some(expr) = expr {
357 recurse_expr(used, expr, bound);
358 }
359 }
360 expr => panic!("unsupported expression: {:?}", expr),
361 }
362 }
363
364 fn recurse_pat(used: &mut HashSet<GenericRef>, pat: &Pat, bound: &HashSet<GenericRef>) {
365 match pat {
366 Pat::Const(ExprConst {
367 block: Block { stmts, .. },
368 ..
369 }) => {
370 for stmt in stmts {
371 recurse_stmt(used, stmt, bound);
372 }
373 }
374 Pat::Ident(_) | Pat::Lit(_) | Pat::Macro(_) | Pat::Rest(_) | Pat::Wild(_) => {}
375 Pat::Or(PatOr { cases, .. }) => {
376 for pat in cases {
377 recurse_pat(used, pat, bound);
378 }
379 }
380 Pat::Paren(PatParen { pat, .. }) => recurse_pat(used, pat, bound),
381 Pat::Path(ExprPath { qself, path, .. }) => {
382 process_path(used, bound, path, true);
383 if let Some(QSelf { ty, .. }) = qself {
384 recurse_type(used, ty, bound);
385 }
386 }
387 Pat::Range(ExprRange { start, end, .. }) => {
388 if let Some(expr) = start {
389 recurse_expr(used, expr, bound);
390 }
391 if let Some(expr) = end {
392 recurse_expr(used, expr, bound);
393 }
394 }
395 Pat::Reference(PatReference { pat, .. }) => recurse_pat(used, pat, bound),
396 Pat::Slice(PatSlice { elems, .. }) | Pat::Tuple(PatTuple { elems, .. }) => {
397 for pat in elems {
398 recurse_pat(used, pat, bound);
399 }
400 }
401 Pat::Struct(PatStruct {
402 qself,
403 path,
404 fields,
405 ..
406 }) => {
407 process_path(used, bound, path, true);
408 if let Some(QSelf { ty, .. }) = qself {
409 recurse_type(used, ty, bound);
410 }
411 for FieldPat { pat, .. } in fields {
412 recurse_pat(used, pat, bound);
413 }
414 }
415 Pat::TupleStruct(PatTupleStruct {
416 qself, path, elems, ..
417 }) => {
418 process_path(used, bound, path, true);
419 if let Some(QSelf { ty, .. }) = qself {
420 recurse_type(used, ty, bound);
421 }
422 for pat in elems {
423 recurse_pat(used, pat, bound);
424 }
425 }
426 Pat::Type(PatType { pat, ty, .. }) => {
427 recurse_pat(used, pat, bound);
428 recurse_type(used, ty, bound);
429 }
430 _ => unimplemented!("unknown pattern kind: {:?}", pat),
431 }
432 }
433
434 fn recurse_stmt(used: &mut HashSet<GenericRef>, stmt: &Stmt, bound: &HashSet<GenericRef>) {
435 match stmt {
436 Stmt::Local(Local { pat, init, .. }) => {
437 recurse_pat(used, pat, bound);
438 if let Some(LocalInit { expr, diverge, .. }) = init {
439 recurse_expr(used, expr, bound);
440 if let Some((_, diverge)) = diverge {
441 recurse_expr(used, &diverge, bound);
442 }
443 }
444 }
445 Stmt::Expr(expr, _) => recurse_expr(used, expr, bound),
446 _ => unimplemented!("unknown statement kind: {:?}", stmt),
447 }
448 }
449
450 fn finalize(
451 used: &mut HashSet<GenericRef>,
452 bound: &HashSet<GenericRef>,
453 base: Generics,
454 ) -> Generics {
455 let mut args = Vec::new();
456 for arg in base.params.into_pairs().rev() {
457 match arg.value() {
458 GenericParam::Type(type_param) => {
459 if used.contains(&GenericRef::Type(type_param.ident.clone())) {
460 process_bounds(used, bound, type_param.bounds.iter());
461 args.push(arg);
462 }
463 }
464 GenericParam::Lifetime(lt) => {
465 if used.contains(&GenericRef::Lifetime(lt.lifetime.clone())) {
466 for b in <.bounds {
467 process_lifetime(used, bound, b.clone());
468 }
469 args.push(arg);
470 }
471 }
472 GenericParam::Const(_) => todo!(),
473 }
474 }
475
476 if args.is_empty() {
477 Generics::default()
478 } else {
479 Generics {
480 params: Punctuated::from_iter(args.into_iter().rev()),
481 ..base
482 }
483 }
484 }
485
486 let mut used = HashSet::new();
487 let bound = HashSet::from_iter(context.flat_map(|g| g.params.iter()).map(GenericRef::from));
488
489 for u in usage {
490 match u {
491 Usage::Type(ty) => recurse_type(&mut used, ty, &bound),
492 Usage::Lifetime(lt) => process_lifetime(&mut used, &bound, lt.clone()),
493 Usage::Expression(expr) => recurse_expr(&mut used, expr, &bound),
494 Usage::TypeBound(b) => process_bounds(&mut used, &bound, once(b)),
495 }
496 }
497
498 finalize(&mut used, &bound, base)
499}
500
501pub fn generics_as_args(generics: &Generics) -> PathArguments {
512 if generics.params.is_empty() {
513 PathArguments::None
514 } else {
515 PathArguments::AngleBracketed(AngleBracketedGenericArguments {
516 colon2_token: None,
517 lt_token: generics.lt_token.unwrap_or_default(),
518 args: Punctuated::from_iter(generics.params.pairs().map(|p| {
519 let (param, punct) = p.into_tuple();
520 Pair::new(
521 match param {
522 GenericParam::Type(type_param) => {
523 GenericArgument::Type(Type::Path(TypePath {
524 qself: None,
525 path: type_param.ident.clone().into(),
526 }))
527 }
528 GenericParam::Lifetime(lt) => {
529 GenericArgument::Lifetime(lt.lifetime.clone())
530 }
531 GenericParam::Const(_) => todo!(),
532 },
533 punct.cloned(),
534 )
535 })),
536 gt_token: generics.gt_token.unwrap_or_default(),
537 })
538 }
539}