Skip to main content

facet_macros_impl/
derive.rs

1use crate::{ToTokens, *};
2use proc_macro2::Delimiter;
3use quote::{TokenStreamExt as _, quote};
4
5use crate::plugin::{extract_derive_plugins, generate_plugin_chain};
6use crate::{LifetimeName, RenameRule, process_enum, process_struct};
7
8/// Recursively flattens transparent groups (groups with `Delimiter::None`) in a token stream.
9///
10/// When macros like `macro_rules_attribute` process metavariables like `$vis:vis`, they wrap
11/// the captured tokens in a `Group` with `Delimiter::None`. This function unwraps such groups
12/// so that the inner tokens can be parsed normally.
13///
14/// For example, if a `$vis:vis` captures `pub`, the token stream might contain:
15/// ```text
16/// Group { delimiter: None, stream: TokenStream [Ident { ident: "pub" }] }
17/// ```
18///
19/// After flattening, this becomes just:
20/// ```text
21/// Ident { ident: "pub" }
22/// ```
23fn flatten_transparent_groups(input: TokenStream) -> TokenStream {
24    input
25        .into_iter()
26        .flat_map(|tt| match tt {
27            TokenTree::Group(group) if group.delimiter() == Delimiter::None => {
28                // Recursively flatten the contents of the transparent group
29                flatten_transparent_groups(group.stream())
30            }
31            TokenTree::Group(group) => {
32                // For non-transparent groups, recursively flatten their contents
33                // but keep the group structure
34                let flattened_stream = flatten_transparent_groups(group.stream());
35                let mut new_group = proc_macro2::Group::new(group.delimiter(), flattened_stream);
36                new_group.set_span(group.span());
37                std::iter::once(TokenTree::Group(new_group)).collect()
38            }
39            other => std::iter::once(other).collect(),
40        })
41        .collect()
42}
43
44/// Generate a static declaration that pre-evaluates `<T as Facet>::SHAPE`.
45/// Only emitted in release builds to avoid slowing down debug compile times.
46/// Skipped for generic types since we can't create a static for an unmonomorphized type.
47pub(crate) fn generate_static_decl(
48    type_name: &Ident,
49    facet_crate: &TokenStream,
50    has_type_or_const_generics: bool,
51) -> TokenStream {
52    // Can't generate a static for generic types - the type parameters aren't concrete
53    if has_type_or_const_generics {
54        return quote! {};
55    }
56
57    let type_name_str = type_name.to_string();
58    let screaming_snake_name = RenameRule::ScreamingSnakeCase.apply(&type_name_str);
59
60    let static_name_ident = quote::format_ident!("{}_SHAPE", screaming_snake_name);
61
62    quote! {
63        #[cfg(not(debug_assertions))]
64        static #static_name_ident: &'static #facet_crate::Shape = <#type_name as #facet_crate::Facet>::SHAPE;
65    }
66}
67
68/// Main entry point for the `#[derive(Facet)]` macro. Parses type declarations and generates Facet trait implementations.
69///
70/// If `#[facet(derive(...))]` is present, chains to plugins before generating.
71pub fn facet_macros(input: TokenStream) -> TokenStream {
72    // Flatten transparent groups (Delimiter::None) before parsing.
73    // This handles macros like `macro_rules_attribute` that wrap metavariables
74    // like `$vis:vis` in transparent groups.
75    let input = flatten_transparent_groups(input);
76    let mut i = input.clone().to_token_iter();
77
78    // Parse as TypeDecl
79    match i.parse::<Cons<AdtDecl, EndOfStream>>() {
80        Ok(it) => {
81            // Extract attributes to check for plugins
82            let attrs = match &it.first {
83                AdtDecl::Struct(s) => &s.attributes,
84                AdtDecl::Enum(e) => &e.attributes,
85            };
86
87            // Check for #[facet(derive(...))] plugins
88            let plugins = extract_derive_plugins(attrs);
89
90            if !plugins.is_empty() {
91                // Get the facet crate path from attributes
92                let facet_crate = {
93                    let parsed_attrs = PAttrs::parse(attrs);
94                    parsed_attrs.facet_crate()
95                };
96
97                // Generate plugin chain
98                if let Some(chain) = generate_plugin_chain(&input, &plugins, &facet_crate) {
99                    return chain;
100                }
101            }
102
103            // No plugins, proceed with normal codegen
104            match it.first {
105                AdtDecl::Struct(parsed) => process_struct::process_struct(parsed),
106                AdtDecl::Enum(parsed) => process_enum::process_enum(parsed),
107            }
108        }
109        Err(err) => {
110            panic!("Could not parse type declaration: {input}\nError: {err}");
111        }
112    }
113}
114
115pub(crate) fn build_where_clauses(
116    where_clauses: Option<&WhereClauses>,
117    generics: Option<&GenericParams>,
118    opaque: bool,
119    facet_crate: &TokenStream,
120    custom_bounds: &[TokenStream],
121) -> TokenStream {
122    let mut where_clause_tokens = TokenStream::new();
123    let mut has_clauses = false;
124
125    if let Some(wc) = where_clauses {
126        for c in wc.clauses.iter() {
127            if has_clauses {
128                where_clause_tokens.extend(quote! { , });
129            }
130            where_clause_tokens.extend(c.value.to_token_stream());
131            has_clauses = true;
132        }
133    }
134
135    if let Some(generics) = generics {
136        for p in generics.params.iter() {
137            match &p.value {
138                GenericParam::Lifetime { name, .. } => {
139                    let facet_lifetime = LifetimeName(quote::format_ident!("{}", "ʄ"));
140                    let lifetime = LifetimeName(name.name.clone());
141                    if has_clauses {
142                        where_clause_tokens.extend(quote! { , });
143                    }
144                    where_clause_tokens
145                        .extend(quote! { #lifetime: #facet_lifetime, #facet_lifetime: #lifetime });
146
147                    has_clauses = true;
148                }
149                GenericParam::Const { .. } => {
150                    // ignore for now
151                }
152                GenericParam::Type { name, .. } => {
153                    if has_clauses {
154                        where_clause_tokens.extend(quote! { , });
155                    }
156                    // Only specify lifetime bound for opaque containers
157                    if opaque {
158                        where_clause_tokens.extend(quote! { #name: });
159                    } else {
160                        where_clause_tokens.extend(quote! { #name: #facet_crate::Facet<> });
161                    }
162                    has_clauses = true;
163                }
164            }
165        }
166    }
167
168    // Add custom bounds from #[facet(bound = "...")]
169    for bound in custom_bounds {
170        if has_clauses {
171            where_clause_tokens.extend(quote! { , });
172        }
173        where_clause_tokens.extend(bound.clone());
174        has_clauses = true;
175    }
176
177    if !has_clauses {
178        quote! {}
179    } else {
180        quote! { where #where_clause_tokens }
181    }
182}
183
184/// Build the `.type_params(...)` builder call, returning empty if no type params.
185pub(crate) fn build_type_params_call(
186    generics: Option<&GenericParams>,
187    opaque: bool,
188    facet_crate: &TokenStream,
189) -> TokenStream {
190    if opaque {
191        return quote! {};
192    }
193
194    let mut type_params = Vec::new();
195    if let Some(generics) = generics {
196        for p in generics.params.iter() {
197            match &p.value {
198                GenericParam::Lifetime { .. } => {
199                    // ignore for now
200                }
201                GenericParam::Const { .. } => {
202                    // handled by build_const_params_call
203                }
204                GenericParam::Type { name, .. } => {
205                    let name_str = name.to_string();
206                    type_params.push(quote! {
207                        #facet_crate::TypeParam {
208                            name: #name_str,
209                            shape: <#name as #facet_crate::Facet>::SHAPE
210                        }
211                    });
212                }
213            }
214        }
215    }
216
217    if type_params.is_empty() {
218        quote! {}
219    } else {
220        quote! { .type_params(&[#(#type_params),*]) }
221    }
222}
223
224/// Build the `.const_params(...)` builder call, returning empty if no supported const params.
225pub(crate) fn build_const_params_call(
226    generics: Option<&GenericParams>,
227    opaque: bool,
228    facet_crate: &TokenStream,
229) -> TokenStream {
230    if opaque {
231        return quote! {};
232    }
233
234    let mut const_params = Vec::new();
235    if let Some(generics) = generics {
236        for p in generics.params.iter() {
237            if let GenericParam::Const { name, typ, .. } = &p.value {
238                let name_str = name.to_string();
239                let typ = typ.to_token_stream().to_string().replace(' ', "");
240                let primitive = typ.rsplit("::").next().unwrap_or(&typ);
241
242                let (kind, value) = match primitive {
243                    "bool" => (
244                        quote! { #facet_crate::ConstParamKind::Bool },
245                        quote! { if #name { 1u64 } else { 0u64 } },
246                    ),
247                    "char" => (
248                        quote! { #facet_crate::ConstParamKind::Char },
249                        quote! { #name as u32 as u64 },
250                    ),
251                    "u8" => (
252                        quote! { #facet_crate::ConstParamKind::U8 },
253                        quote! { #name as u64 },
254                    ),
255                    "u16" => (
256                        quote! { #facet_crate::ConstParamKind::U16 },
257                        quote! { #name as u64 },
258                    ),
259                    "u32" => (
260                        quote! { #facet_crate::ConstParamKind::U32 },
261                        quote! { #name as u64 },
262                    ),
263                    "u64" => (
264                        quote! { #facet_crate::ConstParamKind::U64 },
265                        quote! { #name as u64 },
266                    ),
267                    "usize" => (
268                        quote! { #facet_crate::ConstParamKind::Usize },
269                        quote! { #name as u64 },
270                    ),
271                    "i8" => (
272                        quote! { #facet_crate::ConstParamKind::I8 },
273                        quote! { (#name as i64) as u64 },
274                    ),
275                    "i16" => (
276                        quote! { #facet_crate::ConstParamKind::I16 },
277                        quote! { (#name as i64) as u64 },
278                    ),
279                    "i32" => (
280                        quote! { #facet_crate::ConstParamKind::I32 },
281                        quote! { (#name as i64) as u64 },
282                    ),
283                    "i64" => (
284                        quote! { #facet_crate::ConstParamKind::I64 },
285                        quote! { (#name as i64) as u64 },
286                    ),
287                    "isize" => (
288                        quote! { #facet_crate::ConstParamKind::Isize },
289                        quote! { (#name as i64) as u64 },
290                    ),
291                    _ => continue,
292                };
293
294                const_params.push(quote! {
295                    #facet_crate::ConstParam {
296                        name: #name_str,
297                        value: #value,
298                        kind: #kind,
299                    }
300                });
301            }
302        }
303    }
304
305    if const_params.is_empty() {
306        quote! {}
307    } else {
308        quote! { .const_params(&[#(#const_params),*]) }
309    }
310}
311
312/// Generate the `type_name` function for the `ValueVTable`,
313/// displaying realized generics if present.
314pub(crate) fn generate_type_name_fn(
315    type_name: &Ident,
316    generics: Option<&GenericParams>,
317    opaque: bool,
318    facet_crate: &TokenStream,
319) -> TokenStream {
320    let type_name_str = type_name.to_string();
321
322    let write_generics = (!opaque)
323        .then_some(generics)
324        .flatten()
325        .and_then(|generics| {
326            let params = generics.params.iter();
327            let write_each = params.filter_map(|param| match &param.value {
328                // Lifetimes not shown by `std::any::type_name`, this is parity.
329                GenericParam::Lifetime { .. } => None,
330                GenericParam::Const { name, .. } => Some(quote! {
331                    write!(f, "{:?}", #name)?;
332                }),
333                GenericParam::Type { name, .. } => Some(quote! {
334                    <#name as #facet_crate::Facet>::SHAPE.write_type_name(f, opts)?;
335                }),
336            });
337            // TODO: is there a way to construct a DelimitedVec from an iterator?
338            let mut tokens = TokenStream::new();
339            tokens.append_separated(write_each, quote! { write!(f, ", ")?; });
340            if tokens.is_empty() {
341                None
342            } else {
343                Some(tokens)
344            }
345        });
346
347    match write_generics {
348        Some(write_generics) => {
349            quote! {
350                |_shape, f, opts| {
351                    write!(f, #type_name_str)?;
352                    if let Some(opts) = opts.for_children() {
353                        write!(f, "<")?;
354                        #write_generics
355                        write!(f, ">")?;
356                    } else {
357                        write!(f, "<…>")?;
358                    }
359                    Ok(())
360                }
361            }
362        }
363        None => quote! { |_shape, f, _opts| ::core::fmt::Write::write_str(f, #type_name_str) },
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `stream` in an invisible (`Delimiter::None`) group, the way `macro_rules!`
    /// wraps a captured `$vis:vis` fragment.
    fn transparent(stream: TokenStream) -> TokenTree {
        TokenTree::Group(proc_macro2::Group::new(
            proc_macro2::Delimiter::None,
            stream,
        ))
    }

    /// Builds a stream made of `head` followed by all tokens of `tail`.
    fn prepend(head: TokenTree, tail: TokenStream) -> TokenStream {
        let mut out: TokenStream = std::iter::once(head).collect();
        out.extend(tail);
        out
    }

    /// Compares two streams by their rendered source text.
    fn assert_same(actual: TokenStream, expected: TokenStream) {
        assert_eq!(actual.to_string(), expected.to_string());
    }

    #[test]
    fn test_flatten_transparent_groups_simple() {
        // Ordinary tokens must come out untouched.
        let input: TokenStream = quote::quote! { pub struct Foo; };
        assert_same(flatten_transparent_groups(input.clone()), input);
    }

    #[test]
    fn test_flatten_transparent_groups_with_none_delimiter() {
        // An invisible group around `pub`, as produced for `$vis:vis`.
        let input = prepend(transparent(quote::quote! { pub }), quote::quote! { struct Cat; });
        assert_same(
            flatten_transparent_groups(input),
            quote::quote! { pub struct Cat; },
        );
    }

    #[test]
    fn test_flatten_transparent_groups_preserves_braces() {
        // Visible delimiters are kept.
        let input: TokenStream = quote::quote! { struct Foo { x: u32 } };
        assert_same(flatten_transparent_groups(input.clone()), input);
    }

    #[test]
    fn test_flatten_transparent_groups_nested() {
        // An invisible group inside another invisible group.
        let inner = transparent(quote::quote! { pub });
        let outer = transparent(std::iter::once(inner).collect());
        let input = prepend(outer, quote::quote! { struct Cat; });
        assert_same(
            flatten_transparent_groups(input),
            quote::quote! { pub struct Cat; },
        );
    }

    #[test]
    fn test_flatten_transparent_groups_inside_braces() {
        // Invisible groups inside a `{ ... }` body are removed too.
        let body = prepend(transparent(quote::quote! { pub }), quote::quote! { x: u32 });
        let braces = TokenTree::Group(proc_macro2::Group::new(
            proc_macro2::Delimiter::Brace,
            body,
        ));

        let mut input: TokenStream = quote::quote! { struct Foo };
        input.extend(std::iter::once(braces));

        assert_same(
            flatten_transparent_groups(input),
            quote::quote! { struct Foo { pub x: u32 } },
        );
    }

    #[test]
    fn test_parse_struct_with_transparent_group_visibility() {
        // End to end: `$vis:vis struct Cat;` must parse once flattened.
        let input = prepend(transparent(quote::quote! { pub }), quote::quote! { struct Cat; });
        let flattened = flatten_transparent_groups(input);

        let mut iter = flattened.to_token_iter();
        let parsed = iter.parse::<Cons<AdtDecl, EndOfStream>>();

        assert!(
            parsed.is_ok(),
            "Parsing should succeed after flattening transparent groups"
        );
    }
}