facet_derive/
lib.rs

1#![warn(missing_docs)]
2#![doc = include_str!("../README.md")]
3
4mod process_struct;
5
6use unsynn::*;
7
8keyword! {
9    KPub = "pub";
10    KStruct = "struct";
11    KEnum = "enum";
12    KDoc = "doc";
13    KRepr = "repr";
14    KCrate = "crate";
15    KConst = "const";
16    KMut = "mut";
17    KFacet = "facet";
18    KSensitive = "sensitive";
19}
20
21operator! {
22    Eq = "=";
23    Semi = ";";
24    Apostrophe = "'";
25    DoubleSemicolon = "::";
26}
27
28unsynn! {
29    enum Vis {
30        Pub(KPub),
31        PubCrate(Cons<KPub, ParenthesisGroupContaining<KCrate>>),
32    }
33
34    struct Attribute {
35        _pound: Pound,
36        body: BracketGroupContaining<AttributeInner>,
37    }
38
39    enum AttributeInner {
40        Facet(FacetAttr),
41        Doc(DocInner),
42        Repr(ReprInner),
43        Any(Vec<TokenTree>)
44    }
45
46    struct FacetAttr {
47        _facet: KFacet,
48        _sensitive: ParenthesisGroupContaining<FacetInner>,
49    }
50
51    enum FacetInner {
52        Sensitive(KSensitive),
53        Other(Vec<TokenTree>)
54    }
55
56    struct DocInner {
57        _kw_doc: KDoc,
58        _eq: Eq,
59        value: LiteralString,
60    }
61
62    struct ReprInner {
63        _kw_repr: KRepr,
64        attr: ParenthesisGroupContaining<Ident>,
65    }
66
67    struct Struct {
68        // Skip any doc attributes by consuming them
69        attributes: Vec<Attribute>,
70        _vis: Option<Vis>,
71        _kw_struct: KStruct,
72        name: Ident,
73        body: BraceGroupContaining<CommaDelimitedVec<StructField>>,
74    }
75
76    struct Lifetime {
77        _apostrophe: Apostrophe,
78        name: Ident,
79    }
80
81    enum Expr {
82        Integer(LiteralInteger),
83    }
84
85    enum Type {
86        Path(PathType),
87        Tuple(ParenthesisGroupContaining<CommaDelimitedVec<Box<Type>>>),
88        Slice(BracketGroupContaining<Box<Type>>),
89        Bare(BareType),
90    }
91
92    struct PathType {
93        prefix: Ident,
94        _doublesemi: DoubleSemicolon,
95        rest: Box<Type>,
96    }
97
98    struct BareType {
99        name: Ident,
100        generic_params: Option<GenericParams>,
101    }
102
103    struct GenericParams {
104        _lt: Lt,
105        params: CommaDelimitedVec<Type>,
106        _gt: Gt,
107    }
108
109    enum ConstOrMut {
110        Const(KConst),
111        Mut(KMut),
112    }
113
114    struct StructField {
115        attributes: Vec<Attribute>,
116        _vis: Option<Vis>,
117        name: Ident,
118        _colon: Colon,
119        typ: Type,
120    }
121
122    struct TupleStruct {
123        // Skip any doc attributes by consuming them
124        attributes: Vec<Attribute>,
125        _vis: Option<Vis>,
126        _kw_struct: KStruct,
127        name: Ident,
128        body: ParenthesisGroupContaining<CommaDelimitedVec<TupleField>>,
129    }
130
131    struct TupleField {
132        attributes: Vec<Attribute>,
133        vis: Option<Vis>,
134        typ: Type,
135    }
136
137    struct Enum {
138        // Skip any doc attributes by consuming them
139        attributes: Vec<Attribute>,
140        _pub: Option<KPub>,
141        _kw_enum: KEnum,
142        name: Ident,
143        body: BraceGroupContaining<CommaDelimitedVec<EnumVariantLike>>,
144    }
145
146    enum EnumVariantLike {
147        Unit(UnitVariant),
148        Tuple(TupleVariant),
149        Struct(StructVariant),
150    }
151
152    struct UnitVariant {
153        attributes: Vec<Attribute>,
154        name: Ident,
155    }
156
157    struct TupleVariant {
158        // Skip any doc comments on variants
159        attributes: Vec<Attribute>,
160        name: Ident,
161        _paren: ParenthesisGroupContaining<CommaDelimitedVec<TupleField>>,
162    }
163
164    struct StructVariant {
165        // Skip any doc comments on variants
166        _doc_attributes: Vec<Attribute>,
167        name: Ident,
168        _brace: BraceGroupContaining<CommaDelimitedVec<StructField>>,
169    }
170}
171
172/// Derive the Facet trait for structs, tuple structs, and enums.
173///
174/// This uses unsynn, so it's light, but it _will_ choke on some Rust syntax because...
175/// there's a lot of Rust syntax.
176#[proc_macro_derive(Facet, attributes(facet))]
177pub fn facet_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
178    let input = TokenStream::from(input);
179    let mut i = input.to_token_iter();
180
181    // Try to parse as struct first
182    if let Ok(parsed) = i.parse::<Struct>() {
183        return process_struct::process_struct(parsed);
184    }
185    let struct_tokens_left = i.count();
186
187    // Try to parse as tuple struct
188    i = input.to_token_iter(); // Reset iterator
189    if let Ok(parsed) = i.parse::<TupleStruct>() {
190        return process_tuple_struct(parsed);
191    }
192    let tuple_struct_tokens_left = i.count();
193
194    // Try to parse as enum
195    i = input.to_token_iter(); // Reset iterator
196    if let Ok(parsed) = i.parse::<Enum>() {
197        return process_enum(parsed);
198    }
199    let enum_tokens_left = i.count();
200
201    let mut msg = format!(
202        "Could not parse input as struct, tuple struct, or enum: {}",
203        input
204    );
205
206    // Find which parsing left the fewest tokens
207    let min_tokens_left = struct_tokens_left
208        .min(tuple_struct_tokens_left)
209        .min(enum_tokens_left);
210
211    // Parse again for the one with fewest tokens left and show remaining tokens
212    if min_tokens_left == struct_tokens_left {
213        i = input.to_token_iter();
214        let err = i.parse::<Struct>().err();
215        msg = format!(
216            "{}\n====> Error parsing struct: {:?}\n====> Remaining tokens after struct parsing: {}",
217            msg,
218            err,
219            i.collect::<TokenStream>()
220        );
221    } else if min_tokens_left == tuple_struct_tokens_left {
222        i = input.to_token_iter();
223        let err = i.parse::<TupleStruct>().err();
224        msg = format!(
225            "{}\n====> Error parsing tuple struct: {:?}\n====> Remaining tokens after tuple struct parsing: {}",
226            msg,
227            err,
228            i.collect::<TokenStream>()
229        );
230    } else {
231        i = input.to_token_iter();
232        let err = i.parse::<Enum>().err();
233        msg = format!(
234            "{}\n====> Error parsing enum: {:?}\n====> Remaining tokens after enum parsing: {}",
235            msg,
236            err,
237            i.collect::<TokenStream>()
238        );
239    }
240
241    // If we get here, couldn't parse as struct, tuple struct, or enum
242    panic!("{msg}");
243}
244
245/// Processes a tuple struct to implement Facet
246///
247/// Example input:
248/// ```rust
249/// struct Point(f32, f32);
250/// ```
251fn process_tuple_struct(parsed: TupleStruct) -> proc_macro::TokenStream {
252    let struct_name = parsed.name.to_string();
253
254    // Generate field names for tuple elements (0, 1, 2, etc.)
255    let fields = parsed
256        .body
257        .content
258        .0
259        .iter()
260        .enumerate()
261        .map(|(idx, _)| idx.to_string())
262        .collect::<Vec<String>>();
263
264    // Create the fields string for struct_fields! macro
265    let fields_str = fields.join(", ");
266
267    let dummy_fields = (0..parsed.body.content.0.len())
268        .map(|_| String::from("Facet::DUMMY"))
269        .collect::<Vec<String>>()
270        .join(", ");
271
272    // Generate the impl
273    let output = format!(
274        r#"
275#[automatically_derived]
276unsafe impl facet::Facet for {struct_name} {{
277    const DUMMY: Self = Self({dummy_fields});
278    const SHAPE: &'static facet::Shape = &const {{
279        facet::Shape {{
280            layout: std::alloc::Layout::new::<Self>(),
281            vtable: facet::value_vtable!(
282                {struct_name},
283                |f, _opts| std::fmt::Write::write_str(f, "{struct_name}")
284            ),
285            def: facet::Def::Struct(facet::StructDef {{
286                kind: facet::StructKind::TupleStruct,
287                fields: facet::struct_fields!({struct_name}, ({fields_str})),
288            }}),
289        }}
290    }};
291}}
292    "#
293    );
294    output.into_token_stream().into()
295}
296
297/// Processes an enum to implement Facet
298///
299/// Example input:
300/// ```rust
301/// #[repr(u8)]
302/// enum Color {
303///     Red,
304///     Green,
305///     Blue(u8, u8),
306///     Custom { r: u8, g: u8, b: u8 }
307/// }
308/// ```
309fn process_enum(parsed: Enum) -> proc_macro::TokenStream {
310    let enum_name = parsed.name.to_string();
311
312    // Check for explicit repr attribute
313    let has_repr = parsed
314        .attributes
315        .iter()
316        .any(|attr| matches!(attr.body.content, AttributeInner::Repr(_)));
317
318    if !has_repr {
319        return r#"compile_error!("Enums must have an explicit representation (e.g. #[repr(u8)]) to be used with Facet")"#
320            .into_token_stream()
321            .into();
322    }
323
324    // Process each variant
325    let variants = parsed
326        .body
327        .content
328        .0
329        .iter()
330        .map(|var_like| match &var_like.value {
331            EnumVariantLike::Unit(unit) => {
332                let variant_name = unit.name.to_string();
333                format!("facet::enum_unit_variant!({enum_name}, {variant_name})")
334            }
335            EnumVariantLike::Tuple(tuple) => {
336                let variant_name = tuple.name.to_string();
337                let field_types = tuple
338                    ._paren
339                    .content
340                    .0
341                    .iter()
342                    .map(|field| field.value.typ.to_string())
343                    .collect::<Vec<String>>()
344                    .join(", ");
345
346                format!("facet::enum_tuple_variant!({enum_name}, {variant_name}, [{field_types}])")
347            }
348            EnumVariantLike::Struct(struct_var) => {
349                let variant_name = struct_var.name.to_string();
350                let fields = struct_var
351                    ._brace
352                    .content
353                    .0
354                    .iter()
355                    .map(|field| {
356                        let name = field.value.name.to_string();
357                        let typ = field.value.typ.to_string();
358                        format!("{name}: {typ}")
359                    })
360                    .collect::<Vec<String>>()
361                    .join(", ");
362
363                format!("facet::enum_struct_variant!({enum_name}, {variant_name}, {{{fields}}})")
364            }
365        })
366        .collect::<Vec<String>>()
367        .join(", ");
368
369    // Extract the repr type
370    let mut repr_type = "Default"; // Default fallback
371    for attr in &parsed.attributes {
372        if let AttributeInner::Repr(repr_attr) = &attr.body.content {
373            repr_type = match repr_attr.attr.content.to_string().as_str() {
374                "u8" => "U8",
375                "u16" => "U16",
376                "u32" => "U32",
377                "u64" => "U64",
378                "usize" => "USize",
379                "i8" => "I8",
380                "i16" => "I16",
381                "i32" => "I32",
382                "i64" => "I64",
383                "isize" => "ISize",
384                _ => "Default", // Unknown repr type
385            };
386            break;
387        }
388    }
389
390    // Generate the impl
391    let output = format!(
392        r#"
393#[automatically_derived]
394unsafe impl facet::Facet for {enum_name} {{
395    const SHAPE: &'static facet::Shape = &const {{
396        facet::Shape {{
397            layout: std::alloc::Layout::new::<Self>(),
398            vtable: facet::value_vtable!(
399                {enum_name},
400                |f, _opts| std::fmt::Write::write_str(f, "{enum_name}")
401            ),
402            def: facet::Def::Enum(facet::EnumDef {{
403                variants: facet::enum_variants!({enum_name}, [{variants}]),
404                repr: facet::EnumRepr::{repr_type},
405            }}),
406        }}
407    }};
408}}
409        "#
410    );
411    output.into_token_stream().into()
412}
413
414impl std::fmt::Display for Type {
415    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416        match self {
417            Type::Path(path) => {
418                write!(f, "{}::{}", path.prefix, path.rest)
419            }
420            Type::Tuple(tuple) => {
421                write!(f, "(")?;
422                for (i, typ) in tuple.content.0.iter().enumerate() {
423                    if i > 0 {
424                        write!(f, ", ")?;
425                    }
426                    write!(f, "{}", typ.value)?;
427                }
428                write!(f, ")")
429            }
430            Type::Slice(slice) => {
431                write!(f, "[{}]", slice.content)
432            }
433            Type::Bare(ident) => {
434                write!(f, "{}", ident.name)?;
435                if let Some(generic_params) = &ident.generic_params {
436                    write!(f, "<")?;
437                    for (i, param) in generic_params.params.0.iter().enumerate() {
438                        if i > 0 {
439                            write!(f, ", ")?;
440                        }
441                        write!(f, "{}", param.value)?;
442                    }
443                    write!(f, ">")?;
444                }
445                Ok(())
446            }
447        }
448    }
449}
450
451impl std::fmt::Display for ConstOrMut {
452    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
453        match self {
454            ConstOrMut::Const(_) => write!(f, "const"),
455            ConstOrMut::Mut(_) => write!(f, "mut"),
456        }
457    }
458}
459
460impl std::fmt::Display for Lifetime {
461    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
462        write!(f, "'{}", self.name)
463    }
464}
465
466impl std::fmt::Display for Expr {
467    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
468        match self {
469            Expr::Integer(int) => write!(f, "{}", int.value()),
470        }
471    }
472}