1use imbl::HashSet;
2use syn::{
3 punctuated::{Pair, Punctuated},
4 *,
5};
6
7pub enum Usage<'a> {
8 Type(&'a Type),
9 Expression(&'a Expr),
10}
11
12pub fn filter_generics<'a>(
17 base: Generics,
18 usage: impl Iterator<Item = Usage<'a>>,
19 context: impl Iterator<Item = &'a Generics>,
20) -> Generics {
21 #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
22 enum GenericRef {
23 Type(Ident),
24 Lifetime(Lifetime),
25 Const(Ident),
26 }
27
28 impl From<&GenericParam> for GenericRef {
29 fn from(value: &GenericParam) -> Self {
30 match value {
31 GenericParam::Type(type_param) => GenericRef::Type(type_param.ident.clone()),
32 GenericParam::Lifetime(lt) => GenericRef::Lifetime(lt.lifetime.clone()),
33 GenericParam::Const(c) => GenericRef::Const(c.ident.clone()),
34 }
35 }
36 }
37
38 fn add_bound_lifetimes(_bound: &mut HashSet<GenericRef>, b: Option<&BoundLifetimes>) {
39 if let Some(_lifetimes) = b {
40 unimplemented!()
48 }
49 }
50
51 fn process_lifetime(used: &mut HashSet<GenericRef>, bound: &HashSet<GenericRef>, lt: Lifetime) {
52 let r = GenericRef::Lifetime(lt);
53 if !bound.contains(&r) {
54 used.insert(r);
55 }
56 }
57
58 fn process_generic_arguments(
59 used: &mut HashSet<GenericRef>,
60 bound: &HashSet<GenericRef>,
61 args: &AngleBracketedGenericArguments,
62 ) {
63 for arg in &args.args {
64 match arg {
65 GenericArgument::Lifetime(lt) => process_lifetime(used, bound, lt.clone()),
66 GenericArgument::Type(ty) => recurse_type(used, ty, bound),
67 GenericArgument::Const(_) => unimplemented!(),
68 GenericArgument::AssocType(assoc_type) => recurse_type(used, &assoc_type.ty, bound),
69 GenericArgument::AssocConst(assoc_const) => {
70 recurse_expr(used, &assoc_const.value, bound)
71 }
72 GenericArgument::Constraint(constraint) => {
73 process_bounds(used, bound, constraint.bounds.iter())
74 }
75 _ => unimplemented!(),
76 }
77 }
78 }
79
80 fn process_path(
81 used: &mut HashSet<GenericRef>,
82 bound: &HashSet<GenericRef>,
83 path: &Path,
84 unqualified: bool,
85 ) {
86 if let Some(ident) = path.get_ident() {
87 let r = GenericRef::Type(ident.clone());
88 if !bound.contains(&r) {
89 used.insert(r);
90 }
91 } else {
92 let mut first = unqualified && path.leading_colon.is_none();
93 for s in &path.segments {
94 if first && s.arguments.is_empty() {
95 let r = GenericRef::Type(s.ident.clone());
96 if !bound.contains(&r) {
97 used.insert(r);
98 }
99 }
100 first = false;
101 match &s.arguments {
102 PathArguments::None => {}
103 PathArguments::AngleBracketed(args) => {
104 process_generic_arguments(used, bound, args)
105 }
106 PathArguments::Parenthesized(args) => {
107 for ty in &args.inputs {
108 recurse_type(used, ty, bound);
109 }
110 if let ReturnType::Type(_, ty) = &args.output {
111 recurse_type(used, ty, bound)
112 }
113 }
114 }
115 }
116 }
117 }
118
119 fn process_bounds<'a>(
120 used: &mut HashSet<GenericRef>,
121 bound: &HashSet<GenericRef>,
122 b: impl Iterator<Item = &'a TypeParamBound>,
123 ) {
124 for b in b {
125 match b {
126 TypeParamBound::Trait(b) => {
127 let mut bound = bound.clone();
128 add_bound_lifetimes(&mut bound, b.lifetimes.as_ref());
129 process_path(used, &bound, &b.path, true);
130 }
131 TypeParamBound::Lifetime(lt) => process_lifetime(used, bound, lt.clone()),
132 _ => unimplemented!(),
133 }
134 }
135 }
136
137 fn recurse_type(used: &mut HashSet<GenericRef>, ty: &Type, bound: &HashSet<GenericRef>) {
138 match ty {
139 Type::Array(arr) => recurse_type(used, &arr.elem, bound),
140 Type::BareFn(bare_fn) => {
141 let mut bound = bound.clone();
142 add_bound_lifetimes(&mut bound, bare_fn.lifetimes.as_ref());
143
144 for input in &bare_fn.inputs {
145 recurse_type(used, &input.ty, &bound)
146 }
147
148 if let ReturnType::Type(_, ty) = &bare_fn.output {
149 recurse_type(used, ty, &bound)
150 }
151 }
152 Type::Group(group) => recurse_type(used, &group.elem, bound),
153 Type::ImplTrait(impl_trait) => process_bounds(used, bound, impl_trait.bounds.iter()),
154 Type::Never(_) => {}
157 Type::Paren(paren) => recurse_type(used, &paren.elem, bound),
158 Type::Path(path) => {
159 if let Some(qself) = &path.qself {
160 recurse_type(used, &qself.ty, bound);
161 }
162 process_path(used, bound, &path.path, path.qself.is_none());
163 }
164 Type::Ptr(ptr) => recurse_type(used, &ptr.elem, bound),
165 Type::Reference(reference) => recurse_type(used, &reference.elem, bound),
166 Type::Slice(slice) => recurse_type(used, &slice.elem, bound),
167 Type::TraitObject(trait_object) => {
168 process_bounds(used, bound, trait_object.bounds.iter());
169 }
170 Type::Tuple(tuple) => {
171 for ty in &tuple.elems {
172 recurse_type(used, ty, bound);
173 }
174 }
175 ty => panic!("unsupported type: {:?}", ty),
177 }
178 }
179
180 fn recurse_expr(used: &mut HashSet<GenericRef>, expr: &Expr, bound: &HashSet<GenericRef>) {
181 match expr {
182 Expr::Array(ExprArray { elems, .. }) | Expr::Tuple(ExprTuple { elems, .. }) => {
183 for expr in elems {
184 recurse_expr(used, expr, bound);
185 }
186 }
187 Expr::Assign(ExprAssign { left, right, .. })
188 | Expr::Binary(ExprBinary { left, right, .. })
189 | Expr::Index(ExprIndex {
190 expr: left,
191 index: right,
192 ..
193 })
194 | Expr::Repeat(ExprRepeat {
195 expr: left,
196 len: right,
197 ..
198 }) => {
199 recurse_expr(used, left, bound);
200 recurse_expr(used, right, bound);
201 }
202 Expr::Async(ExprAsync {
203 block: Block { stmts, .. },
204 ..
205 })
206 | Expr::Block(ExprBlock {
207 block: Block { stmts, .. },
208 ..
209 })
210 | Expr::Loop(ExprLoop {
211 body: Block { stmts, .. },
212 ..
213 })
214 | Expr::TryBlock(ExprTryBlock {
215 block: Block { stmts, .. },
216 ..
217 })
218 | Expr::Unsafe(ExprUnsafe {
219 block: Block { stmts, .. },
220 ..
221 }) => {
222 for stmt in stmts {
223 recurse_stmt(used, stmt, bound);
224 }
225 }
226 Expr::Await(ExprAwait { base: expr, .. })
227 | Expr::Field(ExprField { base: expr, .. })
228 | Expr::Group(ExprGroup { expr, .. })
229 | Expr::Paren(ExprParen { expr, .. })
230 | Expr::Reference(ExprReference { expr, .. })
231 | Expr::Try(ExprTry { expr, .. })
232 | Expr::Unary(ExprUnary { expr, .. }) => recurse_expr(used, expr, bound),
233 Expr::Break(ExprBreak { expr, .. }) | Expr::Return(ExprReturn { expr, .. }) => {
234 if let Some(expr) = expr {
235 recurse_expr(used, expr, bound);
236 }
237 }
238 Expr::Call(ExprCall { func, args, .. }) => {
239 recurse_expr(used, func, bound);
240 for arg in args {
241 recurse_expr(used, arg, bound);
242 }
243 }
244 Expr::Cast(ExprCast { expr, ty, .. }) => {
245 recurse_expr(used, expr, bound);
246 recurse_type(used, ty, bound);
247 }
248 Expr::Closure(ExprClosure {
249 inputs,
250 output,
251 body,
252 ..
253 }) => {
254 for pat in inputs {
255 recurse_pat(used, pat, bound);
256 }
257 if let ReturnType::Type(_, ty) = output {
258 recurse_type(used, ty, bound);
259 }
260 recurse_expr(used, body, bound);
261 }
262 Expr::Continue(_) | Expr::Lit(_) => {}
263 Expr::ForLoop(ExprForLoop {
264 pat,
265 expr,
266 body: Block { stmts, .. },
267 ..
268 }) => {
269 recurse_pat(used, pat, bound);
270 recurse_expr(used, expr, bound);
271 for stmt in stmts {
272 recurse_stmt(used, stmt, bound);
273 }
274 }
275 Expr::If(ExprIf {
276 cond,
277 then_branch: Block { stmts, .. },
278 else_branch,
279 ..
280 }) => {
281 recurse_expr(used, cond, bound);
282 for stmt in stmts {
283 recurse_stmt(used, stmt, bound);
284 }
285 if let Some((_, expr)) = else_branch {
286 recurse_expr(used, expr, bound);
287 }
288 }
289 Expr::Let(ExprLet { pat, expr, .. }) => {
290 recurse_pat(used, pat, bound);
291 recurse_expr(used, expr, bound);
292 }
293 Expr::Match(ExprMatch { expr, arms, .. }) => {
294 recurse_expr(used, expr, bound);
295 for Arm {
296 pat, guard, body, ..
297 } in arms
298 {
299 recurse_pat(used, pat, bound);
300 if let Some((_, expr)) = guard {
301 recurse_expr(used, expr, bound);
302 }
303 recurse_expr(used, body, bound);
304 }
305 }
306 Expr::MethodCall(ExprMethodCall {
307 receiver,
308 turbofish,
309 args,
310 ..
311 }) => {
312 recurse_expr(used, receiver, bound);
313 if let Some(args) = turbofish {
314 process_generic_arguments(used, bound, args);
315 }
316 for arg in args {
317 recurse_expr(used, arg, bound);
318 }
319 }
320 Expr::Path(ExprPath { qself, path, .. }) => {
321 process_path(used, bound, path, true);
322 if let Some(QSelf { ty, .. }) = qself {
323 recurse_type(used, ty, bound);
324 }
325 }
326 Expr::Range(ExprRange { start, end, .. }) => {
327 if let Some(expr) = start {
328 recurse_expr(used, expr, bound);
329 }
330 if let Some(expr) = end {
331 recurse_expr(used, expr, bound);
332 }
333 }
334 Expr::Struct(ExprStruct { path, fields, .. }) => {
335 process_path(used, bound, path, true);
336 for FieldValue { expr, .. } in fields {
337 recurse_expr(used, expr, bound);
338 }
339 }
340 Expr::While(ExprWhile {
342 cond,
343 body: Block { stmts, .. },
344 ..
345 }) => {
346 recurse_expr(used, cond, bound);
347 for stmt in stmts {
348 recurse_stmt(used, stmt, bound);
349 }
350 }
351 Expr::Yield(ExprYield { expr, .. }) => {
352 if let Some(expr) = expr {
353 recurse_expr(used, expr, bound);
354 }
355 }
356 expr => panic!("unsupported expression: {:?}", expr),
357 }
358 }
359
360 fn recurse_pat(used: &mut HashSet<GenericRef>, pat: &Pat, bound: &HashSet<GenericRef>) {
361 match pat {
362 Pat::Const(ExprConst {
363 block: Block { stmts, .. },
364 ..
365 }) => {
366 for stmt in stmts {
367 recurse_stmt(used, stmt, bound);
368 }
369 }
370 Pat::Ident(_) | Pat::Lit(_) | Pat::Macro(_) | Pat::Rest(_) | Pat::Wild(_) => {}
371 Pat::Or(PatOr { cases, .. }) => {
372 for pat in cases {
373 recurse_pat(used, pat, bound);
374 }
375 }
376 Pat::Paren(PatParen { pat, .. }) => recurse_pat(used, pat, bound),
377 Pat::Path(ExprPath { qself, path, .. }) => {
378 process_path(used, bound, path, true);
379 if let Some(QSelf { ty, .. }) = qself {
380 recurse_type(used, ty, bound);
381 }
382 }
383 Pat::Range(ExprRange { start, end, .. }) => {
384 if let Some(expr) = start {
385 recurse_expr(used, expr, bound);
386 }
387 if let Some(expr) = end {
388 recurse_expr(used, expr, bound);
389 }
390 }
391 Pat::Reference(PatReference { pat, .. }) => recurse_pat(used, pat, bound),
392 Pat::Slice(PatSlice { elems, .. }) | Pat::Tuple(PatTuple { elems, .. }) => {
393 for pat in elems {
394 recurse_pat(used, pat, bound);
395 }
396 }
397 Pat::Struct(PatStruct {
398 qself,
399 path,
400 fields,
401 ..
402 }) => {
403 process_path(used, bound, path, true);
404 if let Some(QSelf { ty, .. }) = qself {
405 recurse_type(used, ty, bound);
406 }
407 for FieldPat { pat, .. } in fields {
408 recurse_pat(used, pat, bound);
409 }
410 }
411 Pat::TupleStruct(PatTupleStruct {
412 qself, path, elems, ..
413 }) => {
414 process_path(used, bound, path, true);
415 if let Some(QSelf { ty, .. }) = qself {
416 recurse_type(used, ty, bound);
417 }
418 for pat in elems {
419 recurse_pat(used, pat, bound);
420 }
421 }
422 Pat::Type(PatType { pat, ty, .. }) => {
423 recurse_pat(used, pat, bound);
424 recurse_type(used, ty, bound);
425 }
426 _ => unimplemented!(),
427 }
428 }
429
430 fn recurse_stmt(used: &mut HashSet<GenericRef>, stmt: &Stmt, bound: &HashSet<GenericRef>) {
431 match stmt {
432 Stmt::Local(Local { pat, init, .. }) => {
433 recurse_pat(used, pat, bound);
434 if let Some(LocalInit { expr, diverge, .. }) = init {
435 recurse_expr(used, expr, bound);
436 if let Some((_, diverge)) = diverge {
437 recurse_expr(used, &diverge, bound);
438 }
439 }
440 }
441 Stmt::Expr(expr, _) => recurse_expr(used, expr, bound),
442 _ => unimplemented!(),
443 }
444 }
445
446 fn finalize(
447 used: &mut HashSet<GenericRef>,
448 bound: &HashSet<GenericRef>,
449 base: Generics,
450 ) -> Generics {
451 let mut args = Vec::new();
452 for arg in base.params.into_pairs().rev() {
453 match arg.value() {
454 GenericParam::Type(type_param) => {
455 if used.contains(&GenericRef::Type(type_param.ident.clone())) {
456 process_bounds(used, bound, type_param.bounds.iter());
457 args.push(arg);
458 }
459 }
460 GenericParam::Lifetime(lt) => {
461 if used.contains(&GenericRef::Lifetime(lt.lifetime.clone())) {
462 for b in <.bounds {
463 process_lifetime(used, bound, b.clone());
464 }
465 args.push(arg);
466 }
467 }
468 GenericParam::Const(_) => todo!(),
469 }
470 }
471
472 if args.is_empty() {
473 Generics::default()
474 } else {
475 Generics {
476 params: Punctuated::from_iter(args.into_iter().rev()),
477 ..base
478 }
479 }
480 }
481
482 let mut used = HashSet::new();
483 let bound = HashSet::from_iter(context.flat_map(|g| g.params.iter()).map(GenericRef::from));
484
485 for u in usage {
486 match u {
487 Usage::Type(ty) => recurse_type(&mut used, ty, &bound),
488 Usage::Expression(expr) => recurse_expr(&mut used, expr, &bound),
489 }
490 }
491
492 finalize(&mut used, &bound, base)
493}
494
495pub fn generics_as_args(generics: &Generics) -> PathArguments {
506 if generics.params.is_empty() {
507 PathArguments::None
508 } else {
509 PathArguments::AngleBracketed(AngleBracketedGenericArguments {
510 colon2_token: None,
511 lt_token: generics.lt_token.unwrap_or_default(),
512 args: Punctuated::from_iter(generics.params.pairs().map(|p| {
513 let (param, punct) = p.into_tuple();
514 Pair::new(
515 match param {
516 GenericParam::Type(type_param) => {
517 GenericArgument::Type(Type::Path(TypePath {
518 qself: None,
519 path: type_param.ident.clone().into(),
520 }))
521 }
522 GenericParam::Lifetime(lt) => {
523 GenericArgument::Lifetime(lt.lifetime.clone())
524 }
525 GenericParam::Const(_) => todo!(),
526 },
527 punct.cloned(),
528 )
529 })),
530 gt_token: generics.gt_token.unwrap_or_default(),
531 })
532 }
533}