cglue_gen/
generics.rs

1use crate::util::{parse_punctuated, recurse_type_to_path};
2use proc_macro2::TokenStream;
3use quote::*;
4use std::collections::{HashMap, HashSet};
5use syn::parse::{Parse, ParseStream};
6use syn::punctuated::Punctuated;
7use syn::token::{Comma, Gt, Lt};
8use syn::*;
9
10fn ident_path(ident: Ident) -> Type {
11    let mut path = Path {
12        leading_colon: None,
13        segments: Punctuated::new(),
14    };
15
16    path.segments.push_value(PathSegment {
17        ident,
18        arguments: Default::default(),
19    });
20
21    Type::Path(TypePath { qself: None, path })
22}
23
24fn ty_ident(ty: &Type) -> Option<&Ident> {
25    if let Type::Path(path) = ty {
26        if path.qself.is_none() {
27            path.path.get_ident()
28        } else {
29            None
30        }
31    } else {
32        None
33    }
34}
35
36#[derive(Clone)]
37pub struct ParsedGenerics {
38    /// Lifetime declarations on the left side of the type/trait.
39    ///
40    /// This may include any bounds it contains, for instance: `'a: 'b,`.
41    pub life_declare: Punctuated<LifetimeDef, Comma>,
42    /// Declarations "using" the lifetimes i.e. has bounds stripped.
43    ///
44    /// For instance: `'a: 'b,` becomes just `'a,`.
45    pub life_use: Punctuated<Lifetime, Comma>,
46    /// Type declarations on the left side of the type/trait.
47    ///
48    /// This may include any trait bounds it contains, for instance: `T: Clone,`.
49    pub gen_declare: Punctuated<TypeParam, Comma>,
50    /// Declarations that "use" the traits i.e. has bounds stripped.
51    ///
52    /// For instance: `T: Clone,` becomes just `T,`.
53    pub gen_use: Punctuated<Type, Comma>,
54    /// All where predicates, without the `where` keyword.
55    pub gen_where_bounds: Punctuated<WherePredicate, Comma>,
56    /// Remap generic T to a particular type using T = type syntax.
57    ///
58    /// Then, when generics get cross referenced, all concrete T declarations get removed, and T
59    /// uses get replaced with concrete types.
60    pub gen_remaps: HashMap<Ident, Type>,
61}
62
63impl ParsedGenerics {
64    pub fn declare_without_nonstatic_bounds(&self) -> Punctuated<TypeParam, Comma> {
65        let mut ret = self.gen_declare.clone();
66
67        for p in ret.iter_mut() {
68            p.bounds = std::mem::take(&mut p.bounds)
69                .into_iter()
70                .filter(|b| {
71                    if let TypeParamBound::Lifetime(lt) = b {
72                        lt.ident == "static"
73                    } else {
74                        true
75                    }
76                })
77                .collect();
78        }
79
80        ret
81    }
82
83    #[cfg(feature = "layout_checks")]
84    pub fn declare_lt_for_all(&self, lt: &TokenStream) -> TokenStream {
85        let mut ts = TokenStream::new();
86
87        for p in &self.gen_use {
88            ts.extend(quote!(#p: #lt,));
89        }
90
91        ts
92    }
93
94    #[cfg(not(feature = "layout_checks"))]
95    pub fn declare_lt_for_all(&self, _: &TokenStream) -> TokenStream {
96        Default::default()
97    }
98
99    pub fn declare_sabi_for_all(&self, crate_path: &TokenStream) -> TokenStream {
100        let mut ts = TokenStream::new();
101
102        for p in &self.gen_use {
103            ts.extend(quote!(#p: #crate_path::trait_group::GenericTypeBounds,));
104        }
105
106        ts
107    }
108
109    /// This function cross references input lifetimes and returns a new Self
110    /// that only contains generic type information about those types.
111    pub fn cross_ref<'a>(&self, input: impl IntoIterator<Item = &'a ParsedGenerics>) -> Self {
112        let mut applied_lifetimes = HashSet::<&Ident>::new();
113        let mut applied_typenames = HashSet::<&Type>::new();
114
115        let mut life_declare = Punctuated::new();
116        let mut life_use = Punctuated::new();
117        let mut gen_declare = Punctuated::new();
118        let mut gen_use = Punctuated::<Type, _>::new();
119        let mut gen_where_bounds = Punctuated::new();
120
121        for ParsedGenerics {
122            life_use: in_lu,
123            gen_use: in_gu,
124            ..
125        } in input
126        {
127            for lt in in_lu.iter() {
128                if applied_lifetimes.contains(&lt.ident) {
129                    continue;
130                }
131
132                let decl = self
133                    .life_declare
134                    .iter()
135                    .find(|ld| ld.lifetime.ident == lt.ident)
136                    .expect("Gen 1");
137
138                life_declare.push_value(decl.clone());
139                life_declare.push_punct(Default::default());
140                life_use.push_value(decl.lifetime.clone());
141                life_use.push_punct(Default::default());
142
143                applied_lifetimes.insert(&lt.ident);
144            }
145
146            for ty in in_gu.iter() {
147                if applied_typenames.contains(&ty) {
148                    continue;
149                }
150
151                let (decl, ident) = self
152                    .gen_declare
153                    .iter()
154                    .zip(self.gen_use.iter())
155                    .find(|(_, ident)| *ident == ty)
156                    .expect("Gen 2");
157
158                gen_declare.push_value(decl.clone());
159                gen_declare.push_punct(Default::default());
160                gen_use.push_value(ident_path(decl.ident.clone()));
161                gen_use.push_punct(Default::default());
162
163                applied_typenames.insert(ident);
164            }
165        }
166
167        for wb in self.gen_where_bounds.iter() {
168            if match wb {
169                WherePredicate::Type(ty) => applied_typenames.contains(&ty.bounded_ty),
170                WherePredicate::Lifetime(lt) => applied_lifetimes.contains(&lt.lifetime.ident),
171                _ => false,
172            } {
173                gen_where_bounds.push_value(wb.clone());
174                gen_where_bounds.push_punct(Default::default());
175            }
176        }
177
178        Self {
179            life_declare,
180            life_use,
181            gen_declare,
182            gen_use,
183            gen_where_bounds,
184            gen_remaps: Default::default(),
185        }
186    }
187
188    pub fn merge_remaps(&mut self, other: &mut ParsedGenerics) {
189        self.gen_remaps
190            .extend(std::mem::take(&mut other.gen_remaps));
191        other.gen_remaps = self.gen_remaps.clone();
192    }
193
194    pub fn merge_and_remap(&mut self, other: &mut ParsedGenerics) {
195        self.merge_remaps(other);
196        self.remap_types();
197        other.remap_types();
198    }
199
200    pub fn remap_types(&mut self) {
201        let old_gen_declare = std::mem::take(&mut self.gen_declare);
202        let old_gen_use = std::mem::take(&mut self.gen_use);
203
204        for val in old_gen_declare.into_pairs() {
205            match val {
206                punctuated::Pair::Punctuated(p, punc) => {
207                    if !self.gen_remaps.contains_key(&p.ident) {
208                        self.gen_declare.push_value(p);
209                        self.gen_declare.push_punct(punc);
210                    }
211                }
212                punctuated::Pair::End(p) => {
213                    if !self.gen_remaps.contains_key(&p.ident) {
214                        self.gen_declare.push_value(p);
215                    }
216                }
217            }
218        }
219
220        for val in old_gen_use.into_pairs() {
221            match val {
222                punctuated::Pair::Punctuated(p, punc) => {
223                    if let Some(ident) = ty_ident(&p) {
224                        self.gen_use
225                            .push_value(self.gen_remaps.get(ident).cloned().unwrap_or(p));
226                    } else {
227                        self.gen_use.push_value(p);
228                    }
229                    self.gen_use.push_punct(punc);
230                }
231                punctuated::Pair::End(p) => {
232                    if let Some(ident) = ty_ident(&p) {
233                        self.gen_use
234                            .push_value(self.gen_remaps.get(ident).cloned().unwrap_or(p));
235                    } else {
236                        self.gen_use.push_value(p);
237                    }
238                }
239            }
240        }
241    }
242
243    /// Generate phantom data definitions for all lifetimes and types used.
244    pub fn phantom_data_definitions(&self) -> TokenStream {
245        let mut stream = TokenStream::new();
246
247        for ty in self.gen_declare.iter() {
248            let ty_ident = format_ident!("_ty_{}", ty.ident.to_string().to_lowercase());
249            let ty = &ty.ident;
250            stream.extend(quote!(#ty_ident: ::core::marker::PhantomData<#ty>,));
251        }
252
253        stream
254    }
255
256    /// Generate phantom data initializations for all lifetimes and types used.
257    pub fn phantom_data_init(&self) -> TokenStream {
258        let mut stream = TokenStream::new();
259
260        for ty in self.gen_declare.iter() {
261            let ty_ident = format_ident!("_ty_{}", ty.ident.to_string().to_lowercase());
262            stream.extend(quote!(#ty_ident: ::core::marker::PhantomData{},));
263        }
264
265        stream
266    }
267
268    /// Replace generic arguments on the type with ones stored within Self.
269    ///
270    /// The same generic args are replaced as the ones extracted from `util::recurse_type_to_path`.
271    pub fn replace_on_type(&self, ty: &mut Type) {
272        recurse_type_to_path(ty, |path| {
273            let mut generics = None;
274            for part in path.segments.pairs_mut() {
275                if let punctuated::Pair::End(p) = part {
276                    if let PathArguments::AngleBracketed(arg) = &mut p.arguments {
277                        generics = Some(arg);
278                    }
279                }
280            }
281
282            let life_use = &self.life_use;
283            let gen_use = &self.gen_use;
284
285            if let Some(generics) = generics {
286                *generics = syn::parse2(quote!(<#life_use #gen_use>)).unwrap();
287            }
288
289            Some(())
290        });
291    }
292
293    pub fn extract_lifetimes(&mut self, ty: &Type) {
294        fn extract_nonpath_lifetimes(ty: &Type, out: &mut HashSet<Lifetime>) {
295            match ty {
296                Type::Array(TypeArray { elem, .. }) => extract_nonpath_lifetimes(elem, out),
297                Type::Group(TypeGroup { elem, .. }) => extract_nonpath_lifetimes(elem, out),
298                Type::Paren(TypeParen { elem, .. }) => extract_nonpath_lifetimes(elem, out),
299                Type::Ptr(TypePtr { elem, .. }) => extract_nonpath_lifetimes(elem, out),
300                Type::Reference(TypeReference { elem, lifetime, .. }) => {
301                    if let Some(lifetime) = lifetime {
302                        out.insert(lifetime.clone());
303                    }
304                    extract_nonpath_lifetimes(elem, out)
305                }
306                Type::Slice(TypeSlice { elem, .. }) => extract_nonpath_lifetimes(elem, out),
307                _ => (),
308            }
309        }
310
311        let mut lifetimes = HashSet::new();
312        extract_nonpath_lifetimes(ty, &mut lifetimes);
313
314        let existing_lifetimes = self
315            .life_declare
316            .iter()
317            .map(|l| &l.lifetime)
318            .collect::<HashSet<&Lifetime>>();
319
320        for lt in existing_lifetimes {
321            lifetimes.remove(lt);
322        }
323
324        for lt in lifetimes {
325            self.life_use.push_value(lt.clone());
326            self.life_use.push_punct(Default::default());
327            self.life_declare.push_value(LifetimeDef::new(lt));
328            self.life_declare.push_punct(Default::default());
329        }
330    }
331}
332
333impl<'a> std::iter::FromIterator<&'a ParsedGenerics> for ParsedGenerics {
334    fn from_iter<I: IntoIterator<Item = &'a ParsedGenerics>>(input: I) -> Self {
335        let mut life_declare = Punctuated::new();
336        let mut life_declared = HashSet::<&Ident>::new();
337
338        let mut life_use = Punctuated::new();
339        let mut gen_use = Punctuated::new();
340
341        let mut gen_declare = Punctuated::new();
342        let mut gen_declared = HashSet::<&Ident>::new();
343
344        let mut gen_where_bounds = Punctuated::new();
345
346        let mut gen_remaps = HashMap::default();
347
348        for val in input {
349            life_use.extend(val.life_use.clone());
350            gen_use.extend(val.gen_use.clone());
351
352            for life in val.life_declare.pairs() {
353                let (val, punct) = life.into_tuple();
354                if life_declared.contains(&val.lifetime.ident) {
355                    continue;
356                }
357                life_declare.push_value(val.clone());
358                if let Some(punct) = punct {
359                    life_declare.push_punct(*punct);
360                }
361                life_declared.insert(&val.lifetime.ident);
362            }
363
364            for gen in val.gen_declare.pairs() {
365                let (val, punct) = gen.into_tuple();
366                if gen_declared.contains(&val.ident) {
367                    continue;
368                }
369                gen_declare.push_value(val.clone());
370                if let Some(punct) = punct {
371                    gen_declare.push_punct(*punct);
372                }
373                gen_declared.insert(&val.ident);
374            }
375
376            gen_where_bounds.extend(val.gen_where_bounds.clone());
377            gen_remaps.extend(val.gen_remaps.clone());
378        }
379
380        if !gen_where_bounds.empty_or_trailing() {
381            gen_where_bounds.push_punct(Default::default());
382        }
383
384        Self {
385            life_declare,
386            life_use,
387            gen_declare,
388            gen_use,
389            gen_where_bounds,
390            gen_remaps,
391        }
392    }
393}
394
395impl From<Option<&Punctuated<GenericArgument, Comma>>> for ParsedGenerics {
396    fn from(input: Option<&Punctuated<GenericArgument, Comma>>) -> Self {
397        match input {
398            Some(input) => Self::from(input),
399            _ => Self {
400                life_declare: Punctuated::new(),
401                life_use: Punctuated::new(),
402                gen_declare: Punctuated::new(),
403                gen_use: Punctuated::new(),
404                gen_where_bounds: Punctuated::new(),
405                gen_remaps: Default::default(),
406            },
407        }
408    }
409}
410
411impl From<&Punctuated<GenericArgument, Comma>> for ParsedGenerics {
412    fn from(input: &Punctuated<GenericArgument, Comma>) -> Self {
413        let mut life_declare = Punctuated::new();
414        let mut life_use = Punctuated::new();
415        let mut gen_declare = Punctuated::new();
416        let mut gen_use = Punctuated::new();
417        let mut gen_remaps = HashMap::new();
418
419        for param in input {
420            match param {
421                GenericArgument::Type(ty) => {
422                    if let Some(ident) = ty_ident(ty).cloned() {
423                        gen_declare.push_value(TypeParam {
424                            attrs: vec![],
425                            ident,
426                            colon_token: None,
427                            bounds: Punctuated::new(),
428                            eq_token: None,
429                            default: None,
430                        });
431                        gen_declare.push_punct(Default::default());
432                    }
433                    gen_use.push_value(ty.clone());
434                    gen_use.push_punct(Default::default());
435                }
436                GenericArgument::Const(_cn) => {
437                    // TODO
438                }
439                GenericArgument::Lifetime(lifetime) => {
440                    life_use.push_value(lifetime.clone());
441                    life_use.push_punct(Default::default());
442                    life_declare.push_value(LifetimeDef {
443                        attrs: vec![],
444                        lifetime: lifetime.clone(),
445                        colon_token: None,
446                        bounds: Punctuated::new(),
447                    });
448                    life_declare.push_punct(Default::default());
449                }
450                GenericArgument::Constraint(constraint) => {
451                    gen_use.push_value(ident_path(constraint.ident.clone()));
452                    gen_use.push_punct(Default::default());
453                    gen_declare.push_value(TypeParam {
454                        attrs: vec![],
455                        ident: constraint.ident.clone(),
456                        colon_token: None,
457                        bounds: constraint.bounds.clone(),
458                        eq_token: None,
459                        default: None,
460                    });
461                    gen_declare.push_punct(Default::default());
462                }
463                GenericArgument::Binding(bind) => {
464                    gen_use.push_value(bind.ty.clone());
465                    gen_use.push_punct(Default::default());
466                    gen_remaps.insert(bind.ident.clone(), bind.ty.clone());
467                }
468            }
469        }
470
471        Self {
472            life_declare,
473            life_use,
474            gen_declare,
475            gen_use,
476            gen_where_bounds: Punctuated::new(),
477            gen_remaps,
478        }
479    }
480}
481
482impl From<&Generics> for ParsedGenerics {
483    fn from(input: &Generics) -> Self {
484        let gen_where = &input.where_clause;
485        let gen_where_bounds = gen_where.as_ref().map(|w| &w.predicates);
486
487        let mut life_declare = Punctuated::new();
488        let mut life_use = Punctuated::new();
489        let mut gen_declare = Punctuated::new();
490        let mut gen_use = Punctuated::new();
491
492        for param in input.params.iter() {
493            match param {
494                GenericParam::Type(ty) => {
495                    gen_use.push_value(ident_path(ty.ident.clone()));
496                    gen_use.push_punct(Default::default());
497                    gen_declare.push_value(ty.clone());
498                    gen_declare.push_punct(Default::default());
499                }
500                GenericParam::Const(_cn) => {
501                    // TODO
502                }
503                GenericParam::Lifetime(lt) => {
504                    let lifetime = &lt.lifetime;
505                    life_use.push_value(lifetime.clone());
506                    life_use.push_punct(Default::default());
507                    life_declare.push_value(lt.clone());
508                    life_declare.push_punct(Default::default());
509                }
510            }
511        }
512
513        Self {
514            life_declare,
515            life_use,
516            gen_declare,
517            gen_use,
518            gen_where_bounds: gen_where_bounds.cloned().unwrap_or_else(Punctuated::new),
519            gen_remaps: Default::default(),
520        }
521    }
522}
523
524fn parse_generic_arguments(input: ParseStream) -> Punctuated<GenericArgument, Comma> {
525    parse_punctuated(input)
526}
527
528impl Parse for ParsedGenerics {
529    fn parse(input: ParseStream) -> Result<Self> {
530        let gens = match input.parse::<Lt>() {
531            Ok(_) => {
532                let punct = parse_generic_arguments(input);
533                input.parse::<Gt>()?;
534                Some(punct)
535            }
536            _ => None,
537        };
538
539        let ret = Self::from(gens.as_ref());
540
541        if let Ok(mut clause) = input.parse::<WhereClause>() {
542            if !clause.predicates.trailing_punct() {
543                clause.predicates.push_punct(Default::default());
544            }
545
546            let predicates = &clause.predicates;
547
548            Ok(Self {
549                gen_where_bounds: predicates.clone(),
550                ..ret
551            })
552        } else {
553            Ok(ret)
554        }
555    }
556}
557
558pub struct GenericCastType {
559    pub expr: Box<Expr>,
560    pub target: GenericType,
561    pub ident: TokenStream,
562}
563
564impl Parse for GenericCastType {
565    fn parse(input: ParseStream) -> Result<Self> {
566        let cast: ExprCast = input.parse()?;
567
568        let expr = cast.expr;
569        let target = GenericType::from_type(&cast.ty, true);
570        let ident = GenericType::from_type(&cast.ty, false).target;
571
572        Ok(Self {
573            expr,
574            target,
575            ident,
576        })
577    }
578}
579
580pub struct GroupCastType {
581    pub expr: Box<Expr>,
582    pub target: GenericType,
583    pub ident: TokenStream,
584}
585
586impl Parse for GroupCastType {
587    fn parse(input: ParseStream) -> Result<Self> {
588        let cast: ExprCast = input.parse()?;
589
590        let expr = cast.expr;
591        let target = GenericType::from_type(&cast.ty, true);
592        let ident = GenericType::from_type(&cast.ty, false).target;
593
594        Ok(Self {
595            expr,
596            target,
597            ident,
598        })
599    }
600}
601
602#[derive(Clone)]
603pub struct GenericType {
604    /// Path to type (core:: in core::Option<T>)
605    pub path: Path,
606    /// Separator to use, this depends on `cast_to_group` parameter
607    pub gen_separator: TokenStream,
608    /// Generic lifetime parameters (there isn't an example in core::Option<T>)
609    pub generic_lifetimes: Punctuated<Lifetime, Comma>,
610    /// Generic type parameters (T in core::Option<T>)
611    pub generic_types: Punctuated<Type, Comma>,
612    /// The resulting type (Option in core::Option<T>)
613    pub target: TokenStream,
614}
615
616impl GenericType {
617    pub fn push_lifetime_start(&mut self, lifetime: &Lifetime) {
618        self.generic_lifetimes.insert(0, lifetime.clone());
619        if !self.generic_lifetimes.trailing_punct() {
620            self.generic_lifetimes.push_punct(Default::default());
621        }
622    }
623
624    pub fn push_types_start(&mut self, types: TokenStream) {
625        let typestr = types.to_string();
626        let mut types =
627            syn::parse::Parser::parse2(Punctuated::<Type, Comma>::parse_terminated, types)
628                .expect(&format!("Invalid types provided: {}", typestr));
629
630        if !types.trailing_punct() {
631            types.push_punct(Default::default());
632        }
633
634        // Swap here, because types becomes the start
635        std::mem::swap(&mut self.generic_types, &mut types);
636
637        self.generic_types.extend(types.into_iter());
638
639        if !self.generic_types.trailing_punct() {
640            self.generic_types.push_punct(Default::default());
641        }
642    }
643
644    pub fn push_types_end(&mut self, types: TokenStream) {
645        let typestr = types.to_string();
646        let types = syn::parse::Parser::parse2(Punctuated::<Type, Comma>::parse_terminated, types)
647            .expect(&format!("Invalid types provided: {}", typestr));
648
649        // Unlike in push_types_start, we do not swap them
650
651        if !self.generic_types.trailing_punct() {
652            self.generic_types.push_punct(Default::default());
653        }
654
655        self.generic_types.extend(types.into_iter());
656
657        if !self.generic_types.trailing_punct() {
658            self.generic_types.push_punct(Default::default());
659        }
660    }
661
662    fn from_type(target: &Type, cast_to_obj: bool) -> Self {
663        let (path, mut target, generics) = match target {
664            Type::Path(ty) => {
665                let (path, target, generics) =
666                    crate::util::split_path_ident(&ty.path).expect("Gen 3");
667                (path, quote!(#target), generics)
668            }
669            x => (
670                Path {
671                    leading_colon: None,
672                    segments: Default::default(),
673                },
674                quote!(#x),
675                None,
676            ),
677        };
678
679        let (generic_lifetimes, mut generic_types) = match &generics {
680            Some(params) => {
681                let pg = ParsedGenerics::from(params);
682                (pg.life_use, pg.gen_use)
683            }
684            _ => Default::default(),
685        };
686
687        let gen_separator = if cast_to_obj {
688            if generics.is_some() {
689                let infer = Type::Infer(TypeInfer {
690                    underscore_token: Default::default(),
691                });
692                generic_types.insert(0, infer.clone());
693                generic_types.insert(0, infer);
694
695                if !generic_types.trailing_punct() {
696                    generic_types.push_punct(Default::default());
697                }
698            }
699            target = format_ident!("{}Base", target.to_string()).to_token_stream();
700
701            quote!(::)
702        } else {
703            quote!()
704        };
705
706        Self {
707            path,
708            gen_separator,
709            generic_lifetimes,
710            generic_types,
711            target,
712        }
713    }
714}
715
716impl ToTokens for GenericType {
717    fn to_tokens(&self, tokens: &mut TokenStream) {
718        tokens.extend(self.path.to_token_stream());
719        tokens.extend(self.target.clone());
720        let generic_lifetimes = &self.generic_lifetimes;
721        let generic_types = &self.generic_types;
722        if !generic_lifetimes.is_empty() || !generic_types.is_empty() {
723            tokens.extend(self.gen_separator.clone());
724            tokens.extend(quote!(<#generic_lifetimes #generic_types>));
725        }
726    }
727}
728
729impl Parse for GenericType {
730    fn parse(input: ParseStream) -> Result<Self> {
731        let target: Type = input.parse()?;
732        Ok(Self::from_type(&target, false))
733    }
734}