facet_derive/
lib.rs

1#![warn(missing_docs)]
2#![doc = include_str!("../README.md")]
3
4use unsynn::*;
5
// Keyword tokens referenced by the grammar definitions in this file.
// Each entry introduces a parser type (left) for the keyword text (right).
keyword! {
    // Visibility / item keywords.
    KPub = "pub";
    KStruct = "struct";
    KEnum = "enum";
    // Attribute names that get dedicated handling (`#[doc = ...]`, `#[repr(...)]`).
    KDoc = "doc";
    KRepr = "repr";
    // Used for the `pub(crate)` visibility form.
    KCrate = "crate";
    // Pointer mutability qualifiers (see `ConstOrMut`).
    KConst = "const";
    KMut = "mut";
}
16
// Punctuation tokens referenced by the grammar definitions in this file.
// Each entry introduces a parser type (left) for the punctuation text (right).
operator! {
    Eq = "=";
    Semi = ";";
    // Leading tick of a lifetime (see `Lifetime`).
    Apostrophe = "'";
    // Path separator `::` (see `PathType`).
    DoubleSemicolon = "::";
}
23
24unsynn! {
25    enum Vis {
26        Pub(KPub),
27        PubCrate(Cons<KPub, ParenthesisGroupContaining<KCrate>>),
28    }
29
30    struct Attribute {
31        _pound: Pound,
32        body: BracketGroupContaining<AttributeInner>,
33    }
34
35    enum AttributeInner {
36        Doc(DocInner),
37        Repr(ReprInner),
38        Any(Vec<TokenTree>)
39    }
40
41    struct DocInner {
42        _kw_doc: KDoc,
43        _eq: Eq,
44        value: LiteralString,
45    }
46
47    struct ReprInner {
48        _kw_repr: KRepr,
49        attr: ParenthesisGroupContaining<Ident>,
50    }
51
52    struct Struct {
53        // Skip any doc attributes by consuming them
54        attributes: Vec<Attribute>,
55        _vis: Option<Vis>,
56        _kw_struct: KStruct,
57        name: Ident,
58        body: BraceGroupContaining<CommaDelimitedVec<StructField>>,
59    }
60
61    struct Lifetime {
62        _apostrophe: Apostrophe,
63        name: Ident,
64    }
65
66    enum Expr {
67        Integer(LiteralInteger),
68    }
69
70    enum Type {
71        Path(PathType),
72        Tuple(ParenthesisGroupContaining<CommaDelimitedVec<Box<Type>>>),
73        Slice(BracketGroupContaining<Box<Type>>),
74        Bare(BareType),
75    }
76
77    struct PathType {
78        prefix: Ident,
79        _doublesemi: DoubleSemicolon,
80        rest: Box<Type>,
81    }
82
83    struct BareType {
84        name: Ident,
85        generic_params: Option<GenericParams>,
86    }
87
88    struct GenericParams {
89        _lt: Lt,
90        params: CommaDelimitedVec<Type>,
91        _gt: Gt,
92    }
93
94    enum ConstOrMut {
95        Const(KConst),
96        Mut(KMut),
97    }
98
99    struct StructField {
100        attributes: Vec<Attribute>,
101        _vis: Option<Vis>,
102        name: Ident,
103        _colon: Colon,
104        typ: Type,
105    }
106
107    struct TupleStruct {
108        // Skip any doc attributes by consuming them
109        attributes: Vec<Attribute>,
110        _vis: Option<Vis>,
111        _kw_struct: KStruct,
112        name: Ident,
113        body: ParenthesisGroupContaining<CommaDelimitedVec<TupleField>>,
114    }
115
116    struct TupleField {
117        attributes: Vec<Attribute>,
118        vis: Option<Vis>,
119        typ: Type,
120    }
121
122    struct Enum {
123        // Skip any doc attributes by consuming them
124        attributes: Vec<Attribute>,
125        _pub: Option<KPub>,
126        _kw_enum: KEnum,
127        name: Ident,
128        body: BraceGroupContaining<CommaDelimitedVec<EnumVariantLike>>,
129    }
130
131    enum EnumVariantLike {
132        Unit(UnitVariant),
133        Tuple(TupleVariant),
134        Struct(StructVariant),
135    }
136
137    struct UnitVariant {
138        attributes: Vec<Attribute>,
139        name: Ident,
140    }
141
142    struct TupleVariant {
143        // Skip any doc comments on variants
144        attributes: Vec<Attribute>,
145        name: Ident,
146        _paren: ParenthesisGroupContaining<CommaDelimitedVec<TupleField>>,
147    }
148
149    struct StructVariant {
150        // Skip any doc comments on variants
151        _doc_attributes: Vec<Attribute>,
152        name: Ident,
153        _brace: BraceGroupContaining<CommaDelimitedVec<StructField>>,
154    }
155}
156
157/// Derive the Facet trait for structs, tuple structs, and enums.
158///
159/// This uses unsynn, so it's light, but it _will_ choke on some Rust syntax because...
160/// there's a lot of Rust syntax.
161#[proc_macro_derive(Facet)]
162pub fn facet_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
163    let input = TokenStream::from(input);
164    let mut i = input.to_token_iter();
165
166    // Try to parse as struct first
167    if let Ok(parsed) = i.parse::<Struct>() {
168        return process_struct(parsed);
169    }
170    let struct_tokens_left = i.count();
171
172    // Try to parse as tuple struct
173    i = input.to_token_iter(); // Reset iterator
174    if let Ok(parsed) = i.parse::<TupleStruct>() {
175        return process_tuple_struct(parsed);
176    }
177    let tuple_struct_tokens_left = i.count();
178
179    // Try to parse as enum
180    i = input.to_token_iter(); // Reset iterator
181    if let Ok(parsed) = i.parse::<Enum>() {
182        return process_enum(parsed);
183    }
184    let enum_tokens_left = i.count();
185
186    let mut msg = format!(
187        "Could not parse input as struct, tuple struct, or enum: {}",
188        input
189    );
190
191    // Find which parsing left the fewest tokens
192    let min_tokens_left = struct_tokens_left
193        .min(tuple_struct_tokens_left)
194        .min(enum_tokens_left);
195
196    // Parse again for the one with fewest tokens left and show remaining tokens
197    if min_tokens_left == struct_tokens_left {
198        i = input.to_token_iter();
199        let err = i.parse::<Struct>().err();
200        msg = format!(
201            "{}\n====> Error parsing struct: {:?}\n====> Remaining tokens after struct parsing: {}",
202            msg,
203            err,
204            i.collect::<TokenStream>()
205        );
206    } else if min_tokens_left == tuple_struct_tokens_left {
207        i = input.to_token_iter();
208        let err = i.parse::<TupleStruct>().err();
209        msg = format!(
210            "{}\n====> Error parsing tuple struct: {:?}\n====> Remaining tokens after tuple struct parsing: {}",
211            msg,
212            err,
213            i.collect::<TokenStream>()
214        );
215    } else {
216        i = input.to_token_iter();
217        let err = i.parse::<Enum>().err();
218        msg = format!(
219            "{}\n====> Error parsing enum: {:?}\n====> Remaining tokens after enum parsing: {}",
220            msg,
221            err,
222            i.collect::<TokenStream>()
223        );
224    }
225
226    // If we get here, couldn't parse as struct, tuple struct, or enum
227    panic!("{msg}");
228}
229
230/// Processes a regular struct to implement Facet
231///
232/// Example input:
233/// ```rust
234/// struct Blah {
235///     foo: u32,
236///     bar: String,
237/// }
238/// ```
239fn process_struct(parsed: Struct) -> proc_macro::TokenStream {
240    let struct_name = parsed.name.to_string();
241    let fields = parsed
242        .body
243        .content
244        .0
245        .iter()
246        .map(|field| field.value.name.to_string())
247        .collect::<Vec<String>>()
248        .join(", ");
249
250    // Generate dummy fields
251    let dummy_fields = parsed
252        .body
253        .content
254        .0
255        .iter()
256        .map(|field| field.value.name.to_string())
257        .map(|field| format!("{field}: Facet::DUMMY"))
258        .collect::<Vec<String>>()
259        .join(", ");
260
261    // Generate the impl
262    let output = format!(
263        r#"
264#[automatically_derived]
265unsafe impl facet::Facet for {struct_name} {{
266    const DUMMY: Self = Self {{
267        {dummy_fields}
268    }};
269    const SHAPE: &'static facet::Shape = &const {{
270        facet::Shape {{
271            layout: std::alloc::Layout::new::<Self>(),
272            vtable: facet::value_vtable!(
273                {struct_name},
274                |f, _opts| std::fmt::Write::write_str(f, "{struct_name}")
275            ),
276            def: facet::Def::Struct(facet::StructDef {{
277                fields: facet::struct_fields!({struct_name}, ({fields})),
278            }}),
279        }}
280    }};
281}}
282        "#
283    );
284    output.into_token_stream().into()
285}
286
287/// Processes a tuple struct to implement Facet
288///
289/// Example input:
290/// ```rust
291/// struct Point(f32, f32);
292/// ```
293fn process_tuple_struct(parsed: TupleStruct) -> proc_macro::TokenStream {
294    let struct_name = parsed.name.to_string();
295
296    // Generate field names for tuple elements (0, 1, 2, etc.)
297    let fields = parsed
298        .body
299        .content
300        .0
301        .iter()
302        .enumerate()
303        .map(|(idx, _)| idx.to_string())
304        .collect::<Vec<String>>();
305
306    // Create the fields string for struct_fields! macro
307    let fields_str = fields.join(", ");
308
309    let dummy_fields = (0..parsed.body.content.0.len())
310        .map(|_| String::from("Facet::DUMMY"))
311        .collect::<Vec<String>>()
312        .join(", ");
313
314    // Generate the impl
315    let output = format!(
316        r#"
317#[automatically_derived]
318unsafe impl facet::Facet for {struct_name} {{
319    const DUMMY: Self = Self({dummy_fields});
320    const SHAPE: &'static facet::Shape = &const {{
321        facet::Shape {{
322            layout: std::alloc::Layout::new::<Self>(),
323            vtable: facet::value_vtable!(
324                {struct_name},
325                |f, _opts| std::fmt::Write::write_str(f, "{struct_name}")
326            ),
327            def: facet::Def::TupleStruct(facet::StructDef {{
328                fields: facet::struct_fields!({struct_name}, ({fields_str})),
329            }}),
330        }}
331    }};
332}}
333    "#
334    );
335    output.into_token_stream().into()
336}
337
338/// Processes an enum to implement Facet
339///
340/// Example input:
341/// ```rust
342/// #[repr(u8)]
343/// enum Color {
344///     Red,
345///     Green,
346///     Blue(u8, u8),
347///     Custom { r: u8, g: u8, b: u8 }
348/// }
349/// ```
350fn process_enum(parsed: Enum) -> proc_macro::TokenStream {
351    let enum_name = parsed.name.to_string();
352
353    // Check for explicit repr attribute
354    let has_repr = parsed
355        .attributes
356        .iter()
357        .any(|attr| matches!(attr.body.content, AttributeInner::Repr(_)));
358
359    if !has_repr {
360        return r#"compile_error!("Enums must have an explicit representation (e.g. #[repr(u8)]) to be used with Facet")"#
361            .into_token_stream()
362            .into();
363    }
364
365    // Process each variant
366    let variants = parsed
367        .body
368        .content
369        .0
370        .iter()
371        .map(|var_like| match &var_like.value {
372            EnumVariantLike::Unit(unit) => {
373                let variant_name = unit.name.to_string();
374                format!("facet::enum_unit_variant!({enum_name}, {variant_name})")
375            }
376            EnumVariantLike::Tuple(tuple) => {
377                let variant_name = tuple.name.to_string();
378                let field_types = tuple
379                    ._paren
380                    .content
381                    .0
382                    .iter()
383                    .map(|field| field.value.typ.to_string())
384                    .collect::<Vec<String>>()
385                    .join(", ");
386
387                format!("facet::enum_tuple_variant!({enum_name}, {variant_name}, [{field_types}])")
388            }
389            EnumVariantLike::Struct(struct_var) => {
390                let variant_name = struct_var.name.to_string();
391                let fields = struct_var
392                    ._brace
393                    .content
394                    .0
395                    .iter()
396                    .map(|field| {
397                        let name = field.value.name.to_string();
398                        let typ = field.value.typ.to_string();
399                        format!("{name}: {typ}")
400                    })
401                    .collect::<Vec<String>>()
402                    .join(", ");
403
404                format!("facet::enum_struct_variant!({enum_name}, {variant_name}, {{{fields}}})")
405            }
406        })
407        .collect::<Vec<String>>()
408        .join(", ");
409
410    // Extract the repr type
411    let mut repr_type = "Default"; // Default fallback
412    for attr in &parsed.attributes {
413        if let AttributeInner::Repr(repr_attr) = &attr.body.content {
414            repr_type = match repr_attr.attr.content.to_string().as_str() {
415                "u8" => "U8",
416                "u16" => "U16",
417                "u32" => "U32",
418                "u64" => "U64",
419                "usize" => "USize",
420                "i8" => "I8",
421                "i16" => "I16",
422                "i32" => "I32",
423                "i64" => "I64",
424                "isize" => "ISize",
425                _ => "Default", // Unknown repr type
426            };
427            break;
428        }
429    }
430
431    // Generate the impl
432    let output = format!(
433        r#"
434#[automatically_derived]
435unsafe impl facet::Facet for {enum_name} {{
436    const SHAPE: &'static facet::Shape = &const {{
437        facet::Shape {{
438            layout: std::alloc::Layout::new::<Self>(),
439            vtable: facet::value_vtable!(
440                {enum_name},
441                |f, _opts| std::fmt::Write::write_str(f, "{enum_name}")
442            ),
443            def: facet::Def::Enum(facet::EnumDef {{
444                variants: facet::enum_variants!({enum_name}, [{variants}]),
445                repr: facet::EnumRepr::{repr_type},
446            }}),
447        }}
448    }};
449}}
450        "#
451    );
452    output.into_token_stream().into()
453}
454
455impl std::fmt::Display for Type {
456    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
457        match self {
458            Type::Path(path) => {
459                write!(f, "{}::{}", path.prefix, path.rest)
460            }
461            Type::Tuple(tuple) => {
462                write!(f, "(")?;
463                for (i, typ) in tuple.content.0.iter().enumerate() {
464                    if i > 0 {
465                        write!(f, ", ")?;
466                    }
467                    write!(f, "{}", typ.value)?;
468                }
469                write!(f, ")")
470            }
471            Type::Slice(slice) => {
472                write!(f, "[{}]", slice.content)
473            }
474            Type::Bare(ident) => {
475                write!(f, "{}", ident.name)?;
476                if let Some(generic_params) = &ident.generic_params {
477                    write!(f, "<")?;
478                    for (i, param) in generic_params.params.0.iter().enumerate() {
479                        if i > 0 {
480                            write!(f, ", ")?;
481                        }
482                        write!(f, "{}", param.value)?;
483                    }
484                    write!(f, ">")?;
485                }
486                Ok(())
487            }
488        }
489    }
490}
491
492impl std::fmt::Display for ConstOrMut {
493    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
494        match self {
495            ConstOrMut::Const(_) => write!(f, "const"),
496            ConstOrMut::Mut(_) => write!(f, "mut"),
497        }
498    }
499}
500
501impl std::fmt::Display for Lifetime {
502    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
503        write!(f, "'{}", self.name)
504    }
505}
506
507impl std::fmt::Display for Expr {
508    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
509        match self {
510            Expr::Integer(int) => write!(f, "{}", int.value()),
511        }
512    }
513}