Skip to main content

vtcode_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Fields, parse_macro_input};
4
5/// Derive macro that generates the same boilerplate as the `string_newtype!`
6/// declarative macro. Apply to a tuple struct wrapping a single `String` field.
7///
8/// Generates:
9/// - Inherent methods: `new()`, `as_str()`, `into_inner()`
10/// - `Deref<Target = str>`
11/// - `Borrow<str>`
12/// - `AsRef<str>`
13/// - `Display`
14/// - `From<String>`, `From<&str>`, `From<Self> for String`
15///
16/// # Example
17///
18/// ```rust,ignore
19/// #[derive(Debug, Clone, Serialize, Deserialize, StringNewtype)]
20/// #[serde(transparent)]
21/// pub struct SessionId(String);
22/// ```
23#[proc_macro_derive(StringNewtype)]
24pub fn derive_string_newtype(input: TokenStream) -> TokenStream {
25    let input = parse_macro_input!(input as DeriveInput);
26    impl_string_newtype(&input).unwrap_or_else(|err| err.to_compile_error().into())
27}
28
29fn impl_string_newtype(input: &DeriveInput) -> syn::Result<TokenStream> {
30    let name = &input.ident;
31
32    // Validate: must be a tuple struct with exactly one String field.
33    let field_type = match &input.data {
34        Data::Struct(data) => match &data.fields {
35            Fields::Unnamed(fields) => {
36                if fields.unnamed.len() != 1 {
37                    return Err(syn::Error::new_spanned(
38                        name,
39                        "StringNewtype requires a tuple struct with exactly one field",
40                    ));
41                }
42                let field = fields.unnamed.first().unwrap();
43                &field.ty
44            }
45            _ => {
46                return Err(syn::Error::new_spanned(
47                    name,
48                    "StringNewtype can only be derived for tuple structs",
49                ));
50            }
51        },
52        _ => {
53            return Err(syn::Error::new_spanned(
54                name,
55                "StringNewtype can only be derived for structs",
56            ));
57        }
58    };
59
60    // Verify the inner type is String.
61    if !is_string_type(field_type) {
62        return Err(syn::Error::new_spanned(
63            field_type,
64            "StringNewtype requires the inner type to be String",
65        ));
66    }
67
68    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
69
70    let output = quote! {
71        impl #impl_generics #name #ty_generics #where_clause {
72            /// Create a new instance from any value that converts to `String`.
73            pub fn new(value: impl Into<String>) -> Self {
74                Self(value.into())
75            }
76
77            /// Borrow the inner string as a `&str`.
78            pub fn as_str(&self) -> &str {
79                &self.0
80            }
81
82            /// Consume the wrapper and return the inner `String`.
83            pub fn into_inner(self) -> String {
84                self.0
85            }
86        }
87
88        impl #impl_generics std::ops::Deref for #name #ty_generics #where_clause {
89            type Target = str;
90
91            fn deref(&self) -> &Self::Target {
92                &self.0
93            }
94        }
95
96        impl #impl_generics std::borrow::Borrow<str> for #name #ty_generics #where_clause {
97            fn borrow(&self) -> &str {
98                &self.0
99            }
100        }
101
102        impl #impl_generics AsRef<str> for #name #ty_generics #where_clause {
103            fn as_ref(&self) -> &str {
104                &self.0
105            }
106        }
107
108        impl #impl_generics std::fmt::Display for #name #ty_generics #where_clause {
109            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110                self.0.fmt(f)
111            }
112        }
113
114        impl #impl_generics From<String> for #name #ty_generics #where_clause {
115            fn from(value: String) -> Self {
116                Self(value)
117            }
118        }
119
120        impl #impl_generics From<&str> for #name #ty_generics #where_clause {
121            fn from(value: &str) -> Self {
122                Self(value.to_string())
123            }
124        }
125
126        impl #impl_generics From<#name #ty_generics> for String #where_clause {
127            fn from(value: #name #ty_generics) -> Self {
128                value.0
129            }
130        }
131    };
132
133    Ok(output.into())
134}
135
136fn is_string_type(ty: &syn::Type) -> bool {
137    if let syn::Type::Path(type_path) = ty {
138        if type_path.qself.is_none() && type_path.path.segments.len() == 1 {
139            return type_path.path.segments[0].ident == "String";
140        }
141    }
142    false
143}