// facet_derive/lib.rs

1#![warn(missing_docs)]
2#![doc = include_str!("../README.md")]
3
4mod process_enum;
5mod process_struct;
6
7use unsynn::*;
8
9keyword! {
10    KPub = "pub";
11    KStruct = "struct";
12    KEnum = "enum";
13    KDoc = "doc";
14    KRepr = "repr";
15    KCrate = "crate";
16    KIn = "in";
17    KConst = "const";
18    KWhere = "where";
19    KMut = "mut";
20    KFacet = "facet";
21    KSensitive = "sensitive";
22}
23
24operator! {
25    Eq = "=";
26    Semi = ";";
27    Apostrophe = "'";
28    DoubleSemicolon = "::";
29}
30
31/// Parses tokens and groups until `C` is found on the current token tree level.
32type VerbatimUntil<C> = Many<Cons<Except<C>, AngleTokenTree>>;
33type ModPath = Cons<Option<PathSep>, PathSepDelimited<Ident>>;
34type Bounds = Cons<Colon, VerbatimUntil<Either<Comma, Eq, Gt>>>;
35
36unsynn! {
37    /// Parses either a `TokenTree` or `<...>` grouping (which is not a [`Group`] as far as proc-macros
38    /// are concerned).
39    struct AngleTokenTree(Either<Cons<Lt, Vec<Cons<Except<Gt>, AngleTokenTree>>, Gt>, TokenTree>);
40
41    enum AdtDecl {
42        Struct(Struct),
43        Enum(Enum),
44    }
45
46    enum Vis {
47        Pub(KPub),
48        /// `pub(in? crate::foo::bar)`/`pub(in? ::foo::bar)`
49        PubIn(Cons<KPub, ParenthesisGroupContaining<Cons<Option<KIn>, ModPath>>>),
50    }
51
52    struct Attribute {
53        _pound: Pound,
54        body: BracketGroupContaining<AttributeInner>,
55    }
56
57    enum AttributeInner {
58        Facet(FacetAttr),
59        Doc(DocInner),
60        Repr(ReprInner),
61        Any(Vec<TokenTree>)
62    }
63
64    struct FacetAttr {
65        _facet: KFacet,
66        inner: ParenthesisGroupContaining<FacetInner>,
67    }
68
69    enum FacetInner {
70        Sensitive(KSensitive),
71        Other(Vec<TokenTree>)
72    }
73
74    struct DocInner {
75        _kw_doc: KDoc,
76        _eq: Eq,
77        value: LiteralString,
78    }
79
80    struct ReprInner {
81        _kw_repr: KRepr,
82        attr: ParenthesisGroupContaining<CommaDelimitedVec<Ident>>,
83    }
84
85    struct Struct {
86        attributes: Vec<Attribute>,
87        _vis: Option<Vis>,
88        _kw_struct: KStruct,
89        name: Ident,
90        generics: Option<GenericParams>,
91        kind: StructKind,
92    }
93
94    struct GenericParams {
95        _lt: Lt,
96        params: CommaDelimitedVec<GenericParam>,
97        _gt: Gt,
98    }
99
100    enum GenericParam {
101        Lifetime{
102            name: Lifetime,
103            bounds: Option<Cons<Colon, VerbatimUntil<Either<Comma, Gt>>>>,
104        },
105        Const {
106            _const: KConst,
107            name: Ident,
108            _colon: Colon,
109            typ: VerbatimUntil<Either<Comma, Gt, Eq>>,
110            default: Option<Cons<Eq, VerbatimUntil<Either<Comma, Gt>>>>,
111        },
112        Type {
113            name: Ident,
114            bounds: Option<Bounds>,
115            default: Option<Cons<Eq, VerbatimUntil<Either<Comma, Gt>>>>,
116        },
117    }
118
119    struct WhereClauses {
120        _kw_where: KWhere,
121        clauses: CommaDelimitedVec<WhereClause>,
122    }
123
124    struct WhereClause {
125        // FIXME: This likely breaks for absolute `::` paths
126        _pred: VerbatimUntil<Colon>,
127        _colon: Colon,
128        bounds: VerbatimUntil<Either<Comma, Semicolon, BraceGroup>>,
129    }
130
131    enum StructKind {
132        Struct {
133            clauses: Option<WhereClauses>, fields: BraceGroupContaining<CommaDelimitedVec<StructField>>
134        },
135        TupleStruct {
136            fields: ParenthesisGroupContaining<CommaDelimitedVec<TupleField>>,
137            clauses: Option<WhereClauses>,
138            semi: Semi
139        },
140        UnitStruct {
141            clauses: Option<WhereClauses>,
142            semi: Semi
143        }
144    }
145
146    struct Lifetime {
147        _apostrophe: PunctJoint<'\''>,
148        name: Ident,
149    }
150
151    enum Expr {
152        Integer(LiteralInteger),
153    }
154
155    enum ConstOrMut {
156        Const(KConst),
157        Mut(KMut),
158    }
159
160    struct StructField {
161        attributes: Vec<Attribute>,
162        _vis: Option<Vis>,
163        name: Ident,
164        _colon: Colon,
165        typ: VerbatimUntil<Comma>,
166    }
167
168    struct TupleField {
169        attributes: Vec<Attribute>,
170        vis: Option<Vis>,
171        typ: VerbatimUntil<Comma>,
172    }
173
174    struct Enum {
175        attributes: Vec<Attribute>,
176        _pub: Option<KPub>,
177        _kw_enum: KEnum,
178        name: Ident,
179        generics: Option<GenericParams>,
180        clauses: Option<WhereClauses>,
181        body: BraceGroupContaining<CommaDelimitedVec<EnumVariantLike>>,
182    }
183
184    enum EnumVariantLike {
185        Tuple(TupleVariant),
186        Struct(StructVariant),
187        Unit(UnitVariant),
188    }
189
190    struct UnitVariant {
191        attributes: Vec<Attribute>,
192        name: Ident,
193    }
194
195    struct TupleVariant {
196        attributes: Vec<Attribute>,
197        name: Ident,
198        fields: ParenthesisGroupContaining<CommaDelimitedVec<TupleField>>,
199    }
200
201    struct StructVariant {
202        attributes: Vec<Attribute>,
203        name: Ident,
204        fields: BraceGroupContaining<CommaDelimitedVec<StructField>>,
205    }
206}
207
208/// Derive the Facet trait for structs, tuple structs, and enums.
209///
210/// This uses unsynn, so it's light, but it _will_ choke on some Rust syntax because...
211/// there's a lot of Rust syntax.
212#[proc_macro_derive(Facet, attributes(facet))]
213pub fn facet_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
214    let input = TokenStream::from(input);
215    let mut i = input.to_token_iter();
216
217    // Parse as TypeDecl
218    match i.parse::<Cons<AdtDecl, EndOfStream>>() {
219        Ok(it) => match it.first {
220            AdtDecl::Struct(parsed) => process_struct::process_struct(parsed),
221            AdtDecl::Enum(parsed) => process_enum::process_enum(parsed),
222        },
223        Err(err) => {
224            panic!(
225                "Could not parse type declaration: {}\nError: {}",
226                input, err
227            );
228        }
229    }
230}
231
232impl core::fmt::Display for AngleTokenTree {
233    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
234        match &self.0 {
235            Either::First(it) => {
236                write!(f, "<")?;
237                for it in it.second.iter() {
238                    write!(f, "{}", it.second)?;
239                }
240                write!(f, ">")?;
241            }
242            Either::Second(it) => write!(f, "{}", it)?,
243            Either::Third(Invalid) => unreachable!(),
244            Either::Fourth(Invalid) => unreachable!(),
245        };
246        Ok(())
247    }
248}
249
250struct VerbatimDisplay<'a, C>(&'a VerbatimUntil<C>);
251impl<C> core::fmt::Display for VerbatimDisplay<'_, C> {
252    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
253        for tt in self.0.0.iter() {
254            write!(f, "{}", tt.value.second)?;
255        }
256        Ok(())
257    }
258}
259
260impl core::fmt::Display for ConstOrMut {
261    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
262        match self {
263            ConstOrMut::Const(_) => write!(f, "const"),
264            ConstOrMut::Mut(_) => write!(f, "mut"),
265        }
266    }
267}
268
269impl core::fmt::Display for Lifetime {
270    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
271        write!(f, "'{}", self.name)
272    }
273}
274
275impl core::fmt::Display for WhereClauses {
276    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277        write!(f, "where ")?;
278        for clause in self.clauses.0.iter() {
279            write!(f, "{},", clause.value)?;
280        }
281        Ok(())
282    }
283}
284
285impl core::fmt::Display for WhereClause {
286    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
287        write!(
288            f,
289            "{}: {}",
290            VerbatimDisplay(&self._pred),
291            VerbatimDisplay(&self.bounds)
292        )
293    }
294}
295
296impl core::fmt::Display for Expr {
297    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
298        match self {
299            Expr::Integer(int) => write!(f, "{}", int.value()),
300        }
301    }
302}
303
/// Converts a PascalCase identifier to UPPER_SNAKE_CASE.
///
/// An underscore is inserted before every uppercase character except the first, so runs of
/// capitals are split letter by letter (`HTTPServer` -> `H_T_T_P_SERVER`). ASCII letters are
/// upper-cased; every other character is copied unchanged.
pub(crate) fn to_upper_snake_case(input: &str) -> String {
    let mut out = String::with_capacity(input.len());
    for (position, ch) in input.chars().enumerate() {
        let starts_new_word = position != 0 && ch.is_uppercase();
        if starts_new_word {
            out.push('_');
        }
        out.push(ch.to_ascii_uppercase());
    }
    out
}
321
322/// Generate a static declaration that exports the crate
323pub(crate) fn generate_static_decl(type_name: &str) -> String {
324    format!(
325        "#[used]\nstatic {}_SHAPE: &'static ::facet::Shape = <{} as ::facet::Facet>::SHAPE;",
326        to_upper_snake_case(type_name),
327        type_name
328    )
329}
330
331pub(crate) fn build_maybe_doc(attrs: &[Attribute]) -> String {
332    let doc_lines: Vec<_> = attrs
333        .iter()
334        .filter_map(|attr| match &attr.body.content {
335            AttributeInner::Doc(doc_inner) => Some(doc_inner.value.value()),
336            _ => None,
337        })
338        .collect();
339
340    if doc_lines.is_empty() {
341        String::new()
342    } else {
343        format!(r#".doc(&[{}])"#, doc_lines.join(","))
344    }
345}
346
347pub(crate) fn gen_struct_field(
348    field_name: &str,
349    struct_name: &str,
350    generics: &str,
351    attrs: &[Attribute],
352) -> String {
353    // Determine field flags
354    let mut flags = "::facet::FieldFlags::EMPTY";
355    let mut attribute_list: Vec<String> = vec![];
356    let mut doc_lines: Vec<&str> = vec![];
357    for attr in attrs {
358        match &attr.body.content {
359            AttributeInner::Facet(facet_attr) => match &facet_attr.inner.content {
360                FacetInner::Sensitive(_ksensitive) => {
361                    flags = "::facet::FieldFlags::SENSITIVE";
362                    attribute_list.push("::facet::FieldAttribute::Sensitive".to_string());
363                }
364                FacetInner::Other(tt) => {
365                    attribute_list.push(format!(
366                        r#"::facet::FieldAttribute::Arbitrary({:?})"#,
367                        tt.tokens_to_string()
368                    ));
369                }
370            },
371            AttributeInner::Doc(doc_inner) => doc_lines.push(doc_inner.value.value()),
372            AttributeInner::Repr(_) => {
373                // muffin
374            }
375            AttributeInner::Any(_) => {
376                // muffin two
377            }
378        }
379    }
380    let attributes = attribute_list.join(",");
381
382    let maybe_field_doc = if doc_lines.is_empty() {
383        String::new()
384    } else {
385        format!(r#".doc(&[{}])"#, doc_lines.join(","))
386    };
387
388    // Generate each field definition
389    format!(
390        "::facet::Field::builder()
391    .name(\"{field_name}\")
392    .shape(::facet::shape_of(&|s: {struct_name}<{generics}>| s.{field_name}))
393    .offset(::core::mem::offset_of!({struct_name}<{generics}>, {field_name}))
394    .flags({flags})
395    .attributes(&[{attributes}])
396    {maybe_field_doc}
397    .build()"
398    )
399}
400
401fn generics_split_for_impl(generics: Option<&GenericParams>) -> (String, String) {
402    let Some(generics) = generics else {
403        return ("".to_string(), "".to_string());
404    };
405    let mut generics_impl = Vec::new();
406    let mut generics_target = Vec::new();
407
408    for param in generics.params.0.iter() {
409        match &param.value {
410            GenericParam::Type {
411                name,
412                bounds,
413                default: _,
414            } => {
415                let name = name.to_string();
416                let mut impl_ = name.clone();
417                if let Some(bounds) = bounds {
418                    impl_.push_str(&format!(": {}", VerbatimDisplay(&bounds.second)));
419                }
420                generics_impl.push(impl_);
421                generics_target.push(name);
422            }
423            GenericParam::Lifetime { name, bounds } => {
424                let name = name.to_string();
425                let mut impl_ = name.clone();
426                if let Some(bounds) = bounds {
427                    impl_.push_str(&format!(": {}", VerbatimDisplay(&bounds.second)));
428                }
429                generics_impl.push(impl_);
430                generics_target.push(name);
431            }
432            GenericParam::Const {
433                _const,
434                name,
435                _colon,
436                typ,
437                default: _,
438            } => {
439                let name = name.to_string();
440                generics_impl.push(format!("const {}: {}", name, VerbatimDisplay(typ)));
441                generics_target.push(name);
442            }
443        }
444    }
445    let generics_impl = generics_impl.join(", ");
446    let generics_target = generics_target.join(", ");
447    (generics_impl, generics_target)
448}