Skip to main content

nnn_macros/
lib.rs

1#![doc = include_str!("../README.md")]
2#![no_std]
3/* Modules */
4mod argument;
5mod codegen;
6mod ctx;
7mod utils;
8/* Built-in imports */
9extern crate alloc;
10use alloc::collections::BTreeMap;
11/* Crate imports */
12use argument::{Argument, Arguments};
13use ctx::Context;
14/* Dependencies imports */
15use quote::quote;
16use syn::{parse::Parser as _, punctuated::Punctuated};
17use utils::syn_ext::SynDataExt as _;
18
19#[proc_macro_attribute]
20pub fn nnn(
21    nnn_args: proc_macro::TokenStream,
22    type_definition: proc_macro::TokenStream,
23) -> proc_macro::TokenStream {
24    expand(nnn_args, type_definition)
25        .unwrap_or_else(|err| err.to_compile_error())
26        .into()
27}
28
29fn expand(
30    nnn_args: proc_macro::TokenStream,
31    type_definition: proc_macro::TokenStream,
32) -> syn::Result<proc_macro2::TokenStream> {
33    let input: syn::DeriveInput = syn::parse(type_definition)?;
34    let original_visibility = input.vis.clone();
35
36    let args = Arguments::from(
37        Punctuated::<Argument, syn::Token![,]>::parse_terminated
38            .parse(nnn_args)?,
39    );
40
41    let (type_name, inner_type, generics) = split_derive_input(input.clone())?;
42    let ctx = Context::try_from((input, args))?;
43    let (impl_generics, ty_generics, where_clause) =
44        ctx.generics().split_for_impl();
45
46    let tests = ctx.args().get_tests(&ctx);
47    let impls = ctx.args().get_impls(&ctx);
48    let (
49        impl_blocks,
50        bare_impls,
51        macro_attrs,
52        err_variants,
53        validity_checks,
54        err_display_arm,
55        sanitization_steps,
56        new_enums,
57        custom_test_harness,
58    ) = codegen::Implementation::separate_variants(&impls);
59
60    let dedup_err_variants = err_variants
61        .map(|variant| (variant.ident.clone(), variant))
62        .collect::<BTreeMap<_, _>>()
63        .into_values();
64
65    let error_type = quote::format_ident!("{type_name}Error",);
66    let mod_name = quote::format_ident!("__private_{type_name}",);
67
68    Ok(quote! {
69        #[doc(hidden)]
70        #[allow(non_snake_case, reason = "Includes NNNType name which is probably CamelCase.")]
71        #[allow(clippy::module_name_repetitions, reason = "Includes NNNType which is probably the name of the file.")]
72        mod #mod_name {
73            use super::*;
74
75            #(#macro_attrs)*
76            pub struct #type_name #generics (#inner_type) #where_clause;
77
78            #[derive(Debug, Clone, PartialEq, Eq)]
79            #[non_exhaustive]
80            pub enum #error_type {
81                #(#dedup_err_variants),*
82            }
83
84            impl ::core::error::Error for #error_type {}
85
86            impl ::core::fmt::Display for #error_type {
87                fn fmt(&self, fmt: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
88                    match *self {
89                        #(#err_display_arm)*
90                    }
91                }
92            }
93
94            impl #impl_generics nnn::NNNewType for #type_name #ty_generics #where_clause {
95                type Inner = #inner_type;
96                type Error = #error_type;
97
98                fn sanitize(mut value: Self::Inner) -> Self::Inner {
99                    #(#sanitization_steps;)*
100                    value
101                }
102
103                fn try_new(mut value: Self::Inner) -> Result<Self, Self::Error> {
104                    value = Self::sanitize(value);
105                    #(#validity_checks;)*
106                    Ok(Self(value))
107                }
108
109                fn into_inner(self) -> Self::Inner {
110                    self.0
111                }
112            }
113
114            impl #impl_generics #type_name #ty_generics #where_clause {
115                #(#bare_impls)*
116            }
117
118            #(#impl_blocks)*
119
120            #(#new_enums)*
121
122            #[cfg(test)]
123            #custom_test_harness
124            mod tests {
125                use super::*;
126
127                #(#tests)*
128            }
129        }
130
131        #[allow(clippy::pub_use, reason = "pub use can happen if struct is meant to be public.")]
132        #original_visibility use #mod_name::*;
133    })
134}
135
136fn split_derive_input(
137    input: syn::DeriveInput,
138) -> Result<(syn::Ident, syn::Type, syn::Generics), syn::Error> {
139    if let Some(attr) = input.attrs.first() {
140        return Err(syn::Error::new_spanned(
141            attr,
142            "Attributes are not supported; pass additional parameters via `nnn` instead.",
143        ));
144    }
145
146    let syn::DeriveInput {
147        data,
148        ident: type_name,
149        generics,
150        ..
151    } = input;
152
153    let syn::Data::Struct(data_struct) = data else {
154        return Err(syn::Error::new(
155            data.decl_span(),
156            "nnn is only supported on structs.",
157        ));
158    };
159
160    let syn::Fields::Unnamed(syn::FieldsUnnamed {
161        unnamed: fields, ..
162    }) = data_struct.fields
163    else {
164        return Err(syn::Error::new_spanned(
165            data_struct.fields,
166            "`nnn` can only be used on structs with unnamed fields.",
167        ));
168    };
169
170    let mut fields_iter = fields.iter();
171    let Some(inner_field) = fields_iter.next() else {
172        return Err(syn::Error::new_spanned(
173            fields,
174            "Cannot use `nnn` on empty structs.",
175        ));
176    };
177
178    if !matches!(inner_field.vis, syn::Visibility::Inherited) {
179        return Err(syn::Error::new_spanned(
180            &inner_field.vis,
181            "You can only have a private field here.",
182        ));
183    }
184
185    if let Some(extra_field) = fields_iter.next() {
186        return Err(syn::Error::new_spanned(
187            extra_field,
188            "You cannot have more than one field.",
189        ));
190    }
191
192    Ok((type_name, inner_field.ty.clone(), generics))
193}