virtue_next/parse/
generics.rs

1use super::utils::*;
2use crate::generate::StreamBuilder;
3use crate::prelude::{Ident, TokenTree};
4use crate::{Error, Result};
5use std::iter::Peekable;
6use std::ops::{Deref, DerefMut};
7
/// The list of generic parameters declared on a struct or enum.
///
/// The parameters are stored in declaration order and can be accessed through
/// the public inner `Vec` (or via `Deref`).
///
/// ```
/// use std::marker::PhantomData;
/// use std::fmt::Display;
///
/// // Generics will be `[Generic::Generic("F")]`
/// struct Foo<F> {
///     f: PhantomData<F>
/// }
/// // Generics will be `[Generic::Generic("F: Display")]`
/// struct Bar<F: Display> {
///     f: PhantomData<F>
/// }
/// // Generics will be `[Generic::Lifetime("a"), Generic::Generic("F")]`
/// struct Baz<'a, F> {
///     f: PhantomData<&'a F>
/// }
/// ```
#[derive(Debug, Clone)]
pub struct Generics(pub Vec<Generic>);
29
30impl Generics {
31    pub(crate) fn try_take(
32        input: &mut Peekable<impl Iterator<Item = TokenTree>>,
33    ) -> Result<Option<Generics>> {
34        let maybe_punct = input.peek();
35        if let Some(TokenTree::Punct(punct)) = maybe_punct {
36            if punct.as_char() == '<' {
37                let punct = assume_punct(input.next(), '<');
38                let mut result = Generics(Vec::new());
39                loop {
40                    match input.peek() {
41                        Some(TokenTree::Punct(punct)) if punct.as_char() == '\'' => {
42                            result.push(Lifetime::take(input)?.into());
43                            consume_punct_if(input, ',');
44                        }
45                        Some(TokenTree::Punct(punct)) if punct.as_char() == '>' => {
46                            assume_punct(input.next(), '>');
47                            break;
48                        }
49                        Some(TokenTree::Ident(ident)) if ident_eq(ident, "const") => {
50                            result.push(ConstGeneric::take(input)?.into());
51                            consume_punct_if(input, ',');
52                        }
53                        Some(TokenTree::Ident(_)) => {
54                            result.push(SimpleGeneric::take(input)?.into());
55                            consume_punct_if(input, ',');
56                        }
57                        x => {
58                            return Err(Error::InvalidRustSyntax {
59                                span: x.map(|x| x.span()).unwrap_or_else(|| punct.span()),
60                                expected: format!("', > or an ident, got {:?}", x),
61                            });
62                        }
63                    }
64                }
65                return Ok(Some(result));
66            }
67        }
68        Ok(None)
69    }
70
71    /// Returns `true` if any of the generics is a [`Generic::Lifetime`]
72    pub fn has_lifetime(&self) -> bool {
73        self.iter().any(|lt| lt.is_lifetime())
74    }
75
76    /// Returns an iterator which contains only the simple type generics
77    pub fn iter_generics(&self) -> impl Iterator<Item = &SimpleGeneric> {
78        self.iter().filter_map(|g| match g {
79            Generic::Generic(s) => Some(s),
80            _ => None,
81        })
82    }
83
84    /// Returns an iterator which contains only the lifetimes
85    pub fn iter_lifetimes(&self) -> impl Iterator<Item = &Lifetime> {
86        self.iter().filter_map(|g| match g {
87            Generic::Lifetime(s) => Some(s),
88            _ => None,
89        })
90    }
91
92    /// Returns an iterator which contains only the const generics
93    pub fn iter_consts(&self) -> impl Iterator<Item = &ConstGeneric> {
94        self.iter().filter_map(|g| match g {
95            Generic::Const(s) => Some(s),
96            _ => None,
97        })
98    }
99
100    pub(crate) fn impl_generics(&self) -> StreamBuilder {
101        let mut result = StreamBuilder::new();
102        result.punct('<');
103
104        for (idx, generic) in self.iter().enumerate() {
105            if idx > 0 {
106                result.punct(',');
107            }
108
109            generic.append_to_result_with_constraints(&mut result);
110        }
111
112        result.punct('>');
113
114        result
115    }
116
117    pub(crate) fn impl_generics_with_additional(
118        &self,
119        lifetimes: &[String],
120        types: &[String],
121    ) -> StreamBuilder {
122        let mut result = StreamBuilder::new();
123        result.punct('<');
124        let mut is_first = true;
125        for lt in lifetimes.iter() {
126            if !is_first {
127                result.punct(',');
128            } else {
129                is_first = false;
130            }
131            result.lifetime_str(lt);
132        }
133
134        for generic in self.iter() {
135            if !is_first {
136                result.punct(',');
137            } else {
138                is_first = false;
139            }
140            generic.append_to_result_with_constraints(&mut result);
141        }
142        for ty in types {
143            if !is_first {
144                result.punct(',');
145            } else {
146                is_first = false;
147            }
148            result.ident_str(ty);
149        }
150
151        result.punct('>');
152
153        result
154    }
155
156    pub(crate) fn type_generics(&self) -> StreamBuilder {
157        let mut result = StreamBuilder::new();
158        result.punct('<');
159
160        for (idx, generic) in self.iter().enumerate() {
161            if idx > 0 {
162                result.punct(',');
163            }
164            if generic.is_lifetime() {
165                result.lifetime(generic.ident().clone());
166            } else {
167                result.ident(generic.ident().clone());
168            }
169        }
170
171        result.punct('>');
172        result
173    }
174}
175
176impl Deref for Generics {
177    type Target = Vec<Generic>;
178
179    fn deref(&self) -> &Self::Target {
180        &self.0
181    }
182}
183
184impl DerefMut for Generics {
185    fn deref_mut(&mut self) -> &mut Self::Target {
186        &mut self.0
187    }
188}
189
/// A single generic parameter declared on a type.
///
/// Each variant wraps the parsed data for one kind of parameter: a lifetime,
/// a type parameter, or a const parameter.
#[derive(Debug, Clone)]
// NOTE(review): presumably needed because the `Generic` variant repeats the enum's name.
#[allow(clippy::enum_variant_names)]
#[non_exhaustive]
pub enum Generic {
    /// A lifetime generic
    ///
    /// ```
    /// # use std::marker::PhantomData;
    /// struct Foo<'a> { // will be Generic::Lifetime("a")
    /// #   a: PhantomData<&'a ()>,
    /// }
    /// ```
    Lifetime(Lifetime),
    /// A simple (type) generic
    ///
    /// ```
    /// # use std::marker::PhantomData;
    /// struct Foo<F> { // will be Generic::Generic("F")
    /// #   a: PhantomData<F>,
    /// }
    /// ```
    Generic(SimpleGeneric),
    /// A const generic
    ///
    /// ```
    /// struct Foo<const N: usize> { // will be Generic::Const("N")
    /// #   a: [u8; N],
    /// }
    /// ```
    Const(ConstGeneric),
}
222
223impl Generic {
224    fn is_lifetime(&self) -> bool {
225        matches!(self, Generic::Lifetime(_))
226    }
227
228    /// The ident of this generic
229    pub fn ident(&self) -> &Ident {
230        match self {
231            Self::Lifetime(lt) => &lt.ident,
232            Self::Generic(gen) => &gen.ident,
233            Self::Const(gen) => &gen.ident,
234        }
235    }
236
237    fn has_constraints(&self) -> bool {
238        match self {
239            Self::Lifetime(lt) => !lt.constraint.is_empty(),
240            Self::Generic(gen) => !gen.constraints.is_empty(),
241            Self::Const(_) => true, // const generics always have a constraint
242        }
243    }
244
245    fn constraints(&self) -> Vec<TokenTree> {
246        match self {
247            Self::Lifetime(lt) => lt.constraint.clone(),
248            Self::Generic(gen) => gen.constraints.clone(),
249            Self::Const(gen) => gen.constraints.clone(),
250        }
251    }
252
253    fn append_to_result_with_constraints(&self, builder: &mut StreamBuilder) {
254        match self {
255            Self::Lifetime(lt) => builder.lifetime(lt.ident.clone()),
256            Self::Generic(gen) => builder.ident(gen.ident.clone()),
257            Self::Const(gen) => {
258                builder.ident(gen.const_token.clone());
259                builder.ident(gen.ident.clone())
260            }
261        };
262        if self.has_constraints() {
263            builder.punct(':');
264            builder.extend(self.constraints());
265        }
266    }
267}
268
269impl From<Lifetime> for Generic {
270    fn from(lt: Lifetime) -> Self {
271        Self::Lifetime(lt)
272    }
273}
274
275impl From<SimpleGeneric> for Generic {
276    fn from(gen: SimpleGeneric) -> Self {
277        Self::Generic(gen)
278    }
279}
280
281impl From<ConstGeneric> for Generic {
282    fn from(gen: ConstGeneric) -> Self {
283        Self::Const(gen)
284    }
285}
286
/// Exercises [`Generics::try_take`] on inputs without generics, on several
/// well-formed generic lists, and on an invalid list.
///
/// Every case parses its own generics. Previously the
/// `for<'a> Bar<'a> + 'static` case never called `try_take` and only
/// re-checked the result of the case before it.
#[test]
fn test_generics_try_take() {
    use crate::token_stream;

    // Inputs that do not start with `<` yield no generics.
    for &source in ["", "foo", "()"].iter() {
        assert!(Generics::try_take(&mut token_stream(source))
            .unwrap()
            .is_none());
    }

    // Consume the `struct <name>` header, then hand the rest to `try_take`.
    let parse_after_header = |source: &str, expected_name: &str| {
        let stream = &mut token_stream(source);
        let (data_type, ident) = super::DataType::take(stream).unwrap();
        assert_eq!(data_type, super::DataType::Struct);
        assert_eq!(ident, expected_name);
        Generics::try_take(stream)
    };

    // Check the number of parsed generics and the ident of each one.
    let assert_idents = |generics: &Generics, expected: &[&str]| {
        assert_eq!(generics.len(), expected.len());
        for (generic, &name) in generics.iter().zip(expected) {
            assert_eq!(generic.ident(), name);
        }
    };

    let generics = parse_after_header("struct Foo<'a, T>()", "Foo")
        .unwrap()
        .unwrap();
    assert_idents(&generics, &["a", "T"]);

    let generics = parse_after_header("struct Foo<A, B>()", "Foo")
        .unwrap()
        .unwrap();
    assert_idents(&generics, &["A", "B"]);

    let generics = parse_after_header("struct Foo<'a, T: Display>()", "Foo")
        .unwrap()
        .unwrap();
    assert_idents(&generics, &["a", "T"]);

    // Higher-ranked bound followed by a lifetime bound.
    let generics =
        parse_after_header("struct Foo<'a, T: for<'a> Bar<'a> + 'static>()", "Foo")
            .unwrap()
            .unwrap();
    assert_idents(&generics, &["a", "T"]);

    // Commas and `>` nested inside the bound must not end the parameter early.
    let generics = parse_after_header(
        "struct Baz<T: for<'a> Bar<'a, for<'b> Bar<'b, for<'c> Bar<'c, u32>>>> {}",
        "Baz",
    )
    .unwrap()
    .unwrap();
    assert_idents(&generics, &["T"]);

    // A group is not a valid start of a generic parameter.
    assert!(parse_after_header("struct Baz<()> {}", "Baz")
        .unwrap_err()
        .is_invalid_rust_syntax());

    // The `->` of a closure bound must not be mistaken for the closing `>`.
    let generics = parse_after_header(
        "struct Bar<A: FnOnce(&'static str) -> SomeStruct, B>",
        "Bar",
    )
    .unwrap()
    .unwrap();
    assert_idents(&generics, &["A", "B"]);

    // A default value without any bound.
    let generics = parse_after_header("struct Bar<A = ()>", "Bar")
        .unwrap()
        .unwrap();
    assert_eq!(generics.len(), 1);
    if let Generic::Generic(generic) = &generics[0] {
        assert_eq!(generic.ident, "A");
        assert_eq!(generic.default_value.len(), 1);
        assert_eq!(generic.default_value[0].to_string(), "()");
    } else {
        panic!("Expected simple generic, got {:?}", generics[0]);
    }
}
380
/// A lifetime generic parameter, e.g. the `'a` in `struct Foo<'a> { ... }`
#[derive(Debug, Clone)]
pub struct Lifetime {
    /// The ident of this lifetime, i.e. the name after the `'`
    pub ident: Ident,
    /// Any constraints that this lifetime may have: the tokens after the `:`,
    /// or empty when there is no `:`
    pub constraint: Vec<TokenTree>,
}
389
390impl Lifetime {
391    pub(crate) fn take(input: &mut Peekable<impl Iterator<Item = TokenTree>>) -> Result<Self> {
392        let start = assume_punct(input.next(), '\'');
393        let ident = match input.peek() {
394            Some(TokenTree::Ident(_)) => assume_ident(input.next()),
395            Some(t) => return Err(Error::ExpectedIdent(t.span())),
396            None => return Err(Error::ExpectedIdent(start.span())),
397        };
398
399        let mut constraint = Vec::new();
400        if let Some(TokenTree::Punct(p)) = input.peek() {
401            if p.as_char() == ':' {
402                assume_punct(input.next(), ':');
403                constraint = read_tokens_until_punct(input, &[',', '>'])?;
404            }
405        }
406
407        Ok(Self { ident, constraint })
408    }
409
410    #[cfg(test)]
411    fn is_ident(&self, s: &str) -> bool {
412        self.ident == s
413    }
414}
415
/// Checks [`Lifetime::take`] on a plain lifetime, on inputs that are expected
/// to panic, and on a lifetime with a bound.
#[test]
fn test_lifetime_take() {
    use crate::token_stream;
    use std::panic::catch_unwind;

    let plain = Lifetime::take(&mut token_stream("'a")).unwrap();
    assert!(plain.is_ident("a"));

    // Each of these must panic somewhere while tokenizing or parsing;
    // a returned `Err` would not satisfy `catch_unwind(..).is_err()`.
    for &invalid in ["'0", "'(", "')", "'0'"].iter() {
        assert!(catch_unwind(|| Lifetime::take(&mut token_stream(invalid))).is_err());
    }

    // The bound is read up to, but not including, the closing `>`.
    let tokens = &mut token_stream("'a: 'b>");
    let bounded = Lifetime::take(tokens).unwrap();
    assert_eq!(bounded.ident, "a");
    assert_eq!(bounded.constraint.len(), 2);
    assume_punct(tokens.next(), '>');
    assert!(tokens.next().is_none());
}
435
/// A simple (type) generic parameter, e.g. the `F` in `struct Foo<F> { .. }`
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct SimpleGeneric {
    /// The ident of this generic
    pub ident: Ident,
    /// The constraints of this generic, e.g. the `SomeTrait` in `F: SomeTrait`.
    /// Empty when the generic has no `:` part.
    pub constraints: Vec<TokenTree>,
    /// The default value of this generic, e.g. the `()` in `F = ()`.
    /// Empty when no default was parsed.
    pub default_value: Vec<TokenTree>,
}
447
448impl SimpleGeneric {
449    pub(crate) fn take(input: &mut Peekable<impl Iterator<Item = TokenTree>>) -> Result<Self> {
450        let ident = assume_ident(input.next());
451        let mut constraints = Vec::new();
452        let mut default_value = Vec::new();
453        if let Some(TokenTree::Punct(punct)) = input.peek() {
454            let punct_char = punct.as_char();
455            if punct_char == ':' {
456                assume_punct(input.next(), ':');
457                constraints = read_tokens_until_punct(input, &['>', ','])?;
458            }
459            if punct_char == '=' {
460                assume_punct(input.next(), '=');
461                default_value = read_tokens_until_punct(input, &['>', ','])?;
462            }
463        }
464        Ok(Self {
465            ident,
466            constraints,
467            default_value,
468        })
469    }
470
471    /// The name of this generic, e.g. `T`
472    pub fn name(&self) -> Ident {
473        self.ident.clone()
474    }
475}
476
/// A const generic parameter, e.g. the `const N: usize` in
/// `struct Foo<const N: usize> { .. }`
#[derive(Debug, Clone)]
pub struct ConstGeneric {
    /// The `const` token for this generic
    pub const_token: Ident,
    /// The ident of this generic, e.g. the `N` from `const N: usize`
    pub ident: Ident,
    /// The "constraints" (type) of this generic, e.g. the `usize` from `const N: usize`
    pub constraints: Vec<TokenTree>,
}
487
488impl ConstGeneric {
489    pub(crate) fn take(input: &mut Peekable<impl Iterator<Item = TokenTree>>) -> Result<Self> {
490        let const_token = assume_ident(input.next());
491        let ident = assume_ident(input.next());
492        let mut constraints = Vec::new();
493        if let Some(TokenTree::Punct(punct)) = input.peek() {
494            if punct.as_char() == ':' {
495                assume_punct(input.next(), ':');
496                constraints = read_tokens_until_punct(input, &['>', ','])?;
497            }
498        }
499        Ok(Self {
500            const_token,
501            ident,
502            constraints,
503        })
504    }
505}
506
/// Constraints on generic types, i.e. the contents of a `where` clause.
///
/// The `Default` value holds no constraints.
///
/// ```
/// # use std::marker::PhantomData;
/// # use std::fmt::Display;
///
/// struct Foo<F>
///     where F: Display // These are `GenericConstraints`
/// {
///     f: PhantomData<F>
/// }
/// ```
#[derive(Debug, Clone, Default)]
pub struct GenericConstraints {
    // The tokens that follow the `where` keyword (the keyword itself is not stored).
    constraints: Vec<TokenTree>,
}
522
523impl GenericConstraints {
524    pub(crate) fn try_take(
525        input: &mut Peekable<impl Iterator<Item = TokenTree>>,
526    ) -> Result<Option<Self>> {
527        match input.peek() {
528            Some(TokenTree::Ident(ident)) => {
529                if !ident_eq(ident, "where") {
530                    return Ok(None);
531                }
532            }
533            _ => {
534                return Ok(None);
535            }
536        }
537        input.next();
538        let constraints = read_tokens_until_punct(input, &['{', '('])?;
539        Ok(Some(Self { constraints }))
540    }
541
542    pub(crate) fn where_clause(&self) -> StreamBuilder {
543        let mut result = StreamBuilder::new();
544        result.ident_str("where");
545        result.extend(self.constraints.clone());
546        result
547    }
548
549    /// Push the given constraint onto this stream.
550    ///
551    /// ```ignore
552    /// let mut generic_constraints = GenericConstraints::parse("T: Foo"); // imaginary function
553    /// let mut generic = SimpleGeneric::new("U"); // imaginary function
554    ///
555    /// generic_constraints.push_constraint(&generic, "Bar");
556    ///
557    /// // generic_constraints is now:
558    /// // `T: Foo, U: Bar`
559    /// ```
560    pub fn push_constraint(
561        &mut self,
562        generic: &SimpleGeneric,
563        constraint: impl AsRef<str>,
564    ) -> Result<()> {
565        let mut builder = StreamBuilder::new();
566        let last_constraint_was_comma = self
567            .constraints
568            .last()
569            .is_some_and(|l| matches!(l, TokenTree::Punct(c) if c.as_char() == ','));
570        if !self.constraints.is_empty() && !last_constraint_was_comma {
571            builder.punct(',');
572        }
573        builder.ident(generic.ident.clone());
574        builder.punct(':');
575        builder.push_parsed(constraint)?;
576        self.constraints.extend(builder.stream);
577
578        Ok(())
579    }
580
581    /// Push the given constraint onto this stream.
582    ///
583    /// ```ignore
584    /// let mut generic_constraints = GenericConstraints::parse("T: Foo"); // imaginary function
585    ///
586    /// generic_constraints.push_parsed_constraint("u32: SomeTrait");
587    ///
588    /// // generic_constraints is now:
589    /// // `T: Foo, u32: SomeTrait`
590    /// ```
591    pub fn push_parsed_constraint(&mut self, constraint: impl AsRef<str>) -> Result<()> {
592        let mut builder = StreamBuilder::new();
593        if !self.constraints.is_empty() {
594            builder.punct(',');
595        }
596        builder.push_parsed(constraint)?;
597        self.constraints.extend(builder.stream);
598
599        Ok(())
600    }
601
602    /// Clear the constraints
603    pub fn clear(&mut self) {
604        self.constraints.clear();
605    }
606}
607
/// Checks when [`GenericConstraints::try_take`] does and does not find a
/// `where` clause, and that a generic struct header parses end to end.
#[test]
fn test_generic_constraints_try_take() {
    use super::{DataType, StructBody, Visibility};
    use crate::parse::body::Fields;
    use crate::token_stream;

    // (source, whether a where clause is expected) once the
    // `struct <name>` header has been consumed.
    let after_header = [
        ("struct Foo where Foo: Bar { }", true),
        ("struct Foo { }", false),
        ("struct Foo where Foo: Bar(Foo)", true),
        ("struct Foo()", false),
    ];
    for &(source, expect_where) in after_header.iter() {
        let tokens = &mut token_stream(source);
        DataType::take(tokens).unwrap();
        let parsed = GenericConstraints::try_take(tokens).unwrap();
        assert_eq!(parsed.is_some(), expect_where);
    }

    // Without consuming a header first, none of these start with `where`.
    for &source in ["struct Foo()", "{}", ""].iter() {
        let tokens = &mut token_stream(source);
        assert!(GenericConstraints::try_take(tokens).unwrap().is_none());
    }

    // Full header with visibility, generics and an empty body.
    let tokens = &mut token_stream("pub(crate) struct Test<T: Encode> {}");
    assert_eq!(Visibility::Pub, Visibility::try_take(tokens).unwrap());
    let (data_type, ident) = DataType::take(tokens).unwrap();
    assert_eq!(data_type, DataType::Struct);
    assert_eq!(ident, "Test");

    let generics = Generics::try_take(tokens).unwrap().unwrap();
    assert_eq!(generics.len(), 1);
    assert_eq!(generics[0].ident(), "T");

    let body = StructBody::take(tokens).unwrap();
    match &body.fields {
        Some(Fields::Struct(fields)) => assert!(fields.is_empty()),
        other => panic!("wrong fields {:?}", other),
    }
}
654
/// A `where` clause that ends in a trailing comma must parse, and the struct
/// body after it must still be readable.
#[test]
fn test_generic_constraints_trailing_comma() {
    use crate::parse::{
        Attribute, AttributeLocation, DataType, GenericConstraints, Generics, StructBody,
        Visibility,
    };
    use crate::token_stream;

    let tokens = &mut token_stream("pub struct MyStruct<T> where T: Clone, { }");

    // Walk through the item piece by piece, in declaration order.
    Attribute::try_take(AttributeLocation::Container, tokens).unwrap();
    Visibility::try_take(tokens).unwrap();
    DataType::take(tokens).unwrap();

    let generics = Generics::try_take(tokens).unwrap();
    assert!(generics.is_some());
    let constraints = GenericConstraints::try_take(tokens).unwrap();
    assert!(constraints.is_some());

    StructBody::take(tokens).unwrap();
}
670}