Skip to main content

actix_web_schema_macro/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{parse_macro_input, Attribute, ItemStruct, ItemTrait, Meta, MetaList, MetaNameValue};
4
5#[proc_macro_attribute]
6pub fn service(_attr: TokenStream, input: TokenStream) -> TokenStream {
7    let input_trait = parse_macro_input!(input as ItemTrait);
8    let trait_name = &input_trait.ident;
9    let service_name = format_ident!("{}Service", trait_name);
10
11    // Extract route information from methods and filter attributes
12    let mut routes = Vec::new();
13    let mut function_items = Vec::new();
14
15    for item in &input_trait.items {
16        if let syn::TraitItem::Fn(m) = item {
17            let method_name = &m.sig.ident;
18
19            // Look for HTTP method attributes and filter them out
20            let mut http_method_attr: Option<(String, String)> = None;
21            let mut filtered_attrs = Vec::new();
22
23            for attr in &m.attrs {
24                if let Some(pair) = parse_route_attr(attr) {
25                    http_method_attr = Some(pair);
26                } else {
27                    filtered_attrs.push(attr.clone());
28                }
29            }
30
31            if let Some((http_method, path)) = http_method_attr {
32                let method_fn = match http_method.as_str() {
33                    "get" => quote! { get },
34                    "post" => quote! { post },
35                    "put" => quote! { put },
36                    "delete" => quote! { delete },
37                    "patch" => quote! { patch },
38                    "head" => quote! { head },
39                    "options" => quote! { options },
40                    _ => continue,
41                };
42
43                routes.push(quote! {
44                    ::actix_web::web::resource(#path).#method_fn(Self::#method_name)
45                });
46
47                // Create a new method without HTTP method attributes
48                let mut new_method = m.clone();
49                new_method.attrs = filtered_attrs;
50                function_items.push(syn::TraitItem::Fn(new_method));
51                continue;
52            }
53        }
54
55        function_items.push(item.clone());
56    }
57
58    // Build the original trait with 'static bound
59    let original_trait = {
60        let vis = &input_trait.vis;
61        let generics = &input_trait.generics;
62
63        quote! {
64            #vis trait #trait_name #generics where Self: 'static,
65            {
66                #(#function_items)*
67            }
68        }
69    };
70
71    // Extract doc attributes from the trait
72    let doc_attrs: Vec<_> = input_trait
73        .attrs
74        .iter()
75        .filter(|attr| attr.path().is_ident("doc"))
76        .collect();
77
78    // Build the Service struct with doc comments from the trait
79    let vis = &input_trait.vis;
80    let service_struct = quote! {
81        #(#doc_attrs)*
82        #vis struct #service_name;
83    };
84
85    // Build the HttpServiceFactory impl with proper spacing
86    let factory_impl = quote! {
87        impl ::actix_web::dev::HttpServiceFactory for #service_name {
88            fn register(self, config: &mut ::actix_web::dev::AppService) {
89                #(#routes.register(config);)*
90            }
91        }
92    };
93
94    let expanded = quote! {
95        #original_trait
96        #service_struct
97        #factory_impl
98    };
99
100    TokenStream::from(expanded)
101}
102
103fn parse_route_attr(attr: &Attribute) -> Option<(String, String)> {
104    let path = attr.path();
105    let last = path.segments.last()?;
106    let ident = last.ident.to_string();
107
108    // Only support lowercase HTTP method attributes (get, post, etc.)
109    let http_method = match ident.as_str() {
110        "get" | "post" | "put" | "delete" | "patch" | "head" | "options" => ident,
111        _ => return None,
112    };
113
114    // Parse the attribute value to get the path
115    let meta = attr.meta.clone();
116
117    // Handle #[get("/path")]
118    let path_str = match meta {
119        Meta::Path(_) => return None,
120        Meta::List(MetaList { tokens, .. }) => {
121            // Get the first string literal from the list
122            let tokens_str = tokens.to_string();
123            // Remove quotes and any extra formatting
124            tokens_str.trim_matches('"').to_string()
125        }
126        Meta::NameValue(MetaNameValue { value, .. }) => {
127            // For MetaNameValue, get the string literal value
128            let value_str = quote::ToTokens::to_token_stream(&value).to_string();
129            value_str.trim_matches('"').to_string()
130        }
131    };
132
133    Some((http_method, path_str))
134}
135
136#[proc_macro_attribute]
137pub fn response(attr: TokenStream, input: TokenStream) -> TokenStream {
138    let input_struct = parse_macro_input!(input as ItemStruct);
139    let struct_name = &input_struct.ident;
140    let vis = &input_struct.vis;
141    let generics = &input_struct.generics;
142    let fields = &input_struct.fields;
143
144    // Parse the attribute to check for "raw" flag
145    let attr_str = attr.to_string();
146    let is_raw = attr_str.trim() == "raw";
147
148    // Extract doc attributes from the struct
149    let doc_attrs: Vec<_> = input_struct
150        .attrs
151        .iter()
152        .filter(|attr| attr.path().is_ident("doc"))
153        .collect();
154
155    // Reconstruct the original struct with #[derive(Serialize)] and doc comments
156    // Handle different field types (named, unnamed, unit)
157    let struct_fields = match fields {
158        syn::Fields::Named(named) => {
159            let fields = named.named.iter();
160            quote! { { #(#fields),* } }
161        }
162        syn::Fields::Unnamed(unnamed) => {
163            let fields = unnamed.unnamed.iter();
164            quote! { ( #(#fields),* ); }
165        }
166        syn::Fields::Unit => quote! { ; },
167    };
168
169    let original_struct = quote! {
170        #[allow(unused)]
171        #[derive(::serde::Serialize)]
172        #(#doc_attrs)*
173        #vis struct #struct_name #generics #struct_fields
174    };
175
176    // Generate the Responder implementation based on whether raw is specified
177    let responder_impl = if is_raw {
178        quote! {
179            impl ::actix_web::Responder for #struct_name #generics {
180                type Body = ::actix_web::body::BoxBody;
181
182                fn respond_to(self, _req: &::actix_web::HttpRequest) -> ::actix_web::HttpResponse<Self::Body> {
183                    ::actix_web::HttpResponse::Ok().json(self)
184                }
185            }
186        }
187    } else {
188        quote! {
189            impl ::actix_web::Responder for #struct_name #generics {
190                type Body = ::actix_web::body::BoxBody;
191
192                fn respond_to(self, _req: &::actix_web::HttpRequest) -> ::actix_web::HttpResponse<Self::Body> {
193                    ::actix_web::HttpResponse::Ok().json(::serde_json::json!({"code": 0, "data": self}))
194                }
195            }
196        }
197    };
198
199    let expanded = quote! {
200        #original_struct
201        #responder_impl
202    };
203
204    TokenStream::from(expanded)
205}
206
207#[proc_macro_attribute]
208pub fn request(_attr: TokenStream, input: TokenStream) -> TokenStream {
209    let input_struct = parse_macro_input!(input as ItemStruct);
210    let struct_name = &input_struct.ident;
211    let vis = &input_struct.vis;
212    let generics = &input_struct.generics;
213    let fields = &input_struct.fields;
214
215    // Extract doc attributes from the struct
216    let doc_attrs: Vec<_> = input_struct
217        .attrs
218        .iter()
219        .filter(|attr| attr.path().is_ident("doc"))
220        .collect();
221
222    // Handle different field types (named, unnamed, unit)
223    let struct_fields = match fields {
224        syn::Fields::Named(named) => {
225            let fields = named.named.iter();
226            quote! { { #(#fields),* } }
227        }
228        syn::Fields::Unnamed(unnamed) => {
229            let fields = unnamed.unnamed.iter();
230            quote! { ( #(#fields),* ); }
231        }
232        syn::Fields::Unit => quote! { ; },
233    };
234
235    // Reconstruct the original struct with #[derive(Deserialize)] and doc comments
236    let original_struct = quote! {
237        #[allow(unused)]
238        #[derive(::serde::Deserialize)]
239        #(#doc_attrs)*
240        #vis struct #struct_name #generics #struct_fields
241    };
242
243    TokenStream::from(original_struct)
244}