actix_prost_build/
request.rs

1use crate::config::HttpRule;
2use proc_macro2::{Ident, TokenStream};
3use std::{collections::HashSet, iter::FromIterator};
4use syn::PathArguments;
5
/// One part of an HTTP request together with the message fields that are read from it.
///
/// Derives the common value-type traits so the description can be printed, copied and
/// compared while debugging the generator.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RequestFields {
    /// Name of the part (`"Path"`, `"Query"` or `"Json"`). It is used as the suffix of the
    /// generated struct name, as the name of the `::actix_web::web` extractor and, lowercased,
    /// as the name of the local variable holding the extracted value.
    name: String,
    /// Names of the message fields that belong to this part.
    fields: Vec<String>,
}
10
/// Describes how the fields of one request message are spread over the parts of an HTTP
/// request (URL path, query string and JSON body).
pub struct Request {
    /// The message struct the per-part structs are derived from.
    message: syn::ItemStruct,
    /// Name of the method; used as the prefix of the generated struct names.
    method_name: Ident,
    /// Fields extracted with `::actix_web::web::Path`.
    path: RequestFields,
    /// Fields extracted with `::actix_web::web::Query`.
    query: RequestFields,
    /// Fields extracted with `::actix_web::web::Json`.
    body: RequestFields,
}
18
19impl Request {
20    pub fn new(message: syn::ItemStruct, method_name: Ident, config: &HttpRule) -> Request {
21        let fields: Vec<String> = config
22            .pattern
23            .path()
24            .split('{')
25            .skip(1)
26            .filter_map(|q| q.split('}').next())
27            .map(|x| x.to_owned())
28            .collect();
29
30        let (path, query, body) = Self::split_fields(&message, &fields, &config.body);
31
32        Request {
33            message,
34            method_name,
35            path: RequestFields {
36                name: "Path".into(),
37                fields: path,
38            },
39            query: RequestFields {
40                name: "Query".into(),
41                fields: query,
42            },
43            body: RequestFields {
44                name: "Json".into(),
45                fields: body,
46            },
47        }
48    }
49
50    fn split_fields(
51        message: &syn::ItemStruct,
52        path_fields: &[String],
53        body_fields: &Option<String>,
54    ) -> (Vec<String>, Vec<String>, Vec<String>) {
55        let fields = if let syn::Fields::Named(fields) = &message.fields {
56            fields
57        } else {
58            panic!("non named fields aren't supported");
59        };
60
61        let path_filter: HashSet<&str> = HashSet::from_iter(path_fields.iter().map(|s| s.as_ref()));
62        let (path, non_path): (Vec<_>, Vec<_>) = fields
63            .named
64            .iter()
65            .map(|field| field.ident.as_ref().unwrap().to_string())
66            .partition(|field| path_filter.contains(field.as_str()));
67
68        if path_fields.len() != path.len() {
69            let found: HashSet<String> = HashSet::from_iter(path);
70            panic!(
71                "some path fields were not found: {:?}",
72                path_fields
73                    .iter()
74                    .filter(|f| !found.contains(f.as_str()))
75                    .collect::<Vec<_>>()
76            )
77        }
78
79        let (body, query) = if let Some(body_fields) = body_fields {
80            if body_fields != "*" {
81                non_path.into_iter().partition(|f| f == body_fields)
82            } else {
83                (non_path, Vec::default())
84            }
85        } else {
86            (Vec::default(), non_path)
87        };
88
89        if path.len() + query.len() + body.len() != message.fields.len() {
90            panic!("could not map all message fields to path, query and body parts")
91        }
92
93        (path, query, body)
94    }
95
96    pub fn filter_fields(&self, req: &RequestFields) -> syn::Fields {
97        // Is called from `generate_struct` method. The method generates structs for the actix module
98        // and those structs will be located inside `mod *_actix`. The problem with `super::` paths
99        // occurs because proto structures are located in the main module. Thus, we need to add a one more
100        // `super::` path segment for those paths to make them out of `mod *_actix`.
101        fn update_type_super_path(ty: &mut syn::Type) {
102            if let syn::Type::Path(type_path) = ty {
103                let mut super_segment_data = None;
104                for (i, segment) in type_path.path.segments.iter_mut().enumerate() {
105                    if segment.ident.to_string().as_str() == "super" {
106                        // We need to add only one additional `super` segment,
107                        // thus we are looking only the first inclusion.
108                        super_segment_data = Some((i, segment.clone()));
109                        break;
110                    }
111                    // Update segment paths in the arguments, if there are any
112                    match &mut segment.arguments {
113                        PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
114                            args,
115                            ..
116                        }) => args.iter_mut().for_each(|arg| {
117                            if let syn::GenericArgument::Type(ty) = arg {
118                                update_type_super_path(ty)
119                            }
120                        }),
121                        PathArguments::Parenthesized(syn::ParenthesizedGenericArguments {
122                            inputs,
123                            ..
124                        }) => inputs.iter_mut().for_each(update_type_super_path),
125                        PathArguments::None => {}
126                    }
127                }
128
129                // Make the actual update. We cannot do that inside `for` cycle,
130                // because `type_path.path.segments` are mutually borrowed to update arguments.
131                if let Some((index, segment)) = super_segment_data {
132                    type_path.path.segments.insert(index, segment)
133                }
134            }
135        }
136
137        let filter: HashSet<&str> = HashSet::from_iter(req.fields.iter().map(|x| x.as_ref()));
138        let fields = self
139            .message
140            .fields
141            .iter()
142            .filter(|&field| filter.contains(field.ident.as_ref().unwrap().to_string().as_str()))
143            .cloned()
144            .map(|mut field| {
145                update_type_super_path(&mut field.ty);
146                field
147            })
148            .collect();
149        let brace_token = if let syn::Fields::Named(named) = &self.message.fields {
150            named.brace_token
151        } else {
152            panic!("not named fields not supported");
153        };
154        syn::Fields::Named(syn::FieldsNamed {
155            brace_token,
156            named: fields,
157        })
158    }
159
160    pub fn path(&self) -> &RequestFields {
161        &self.path
162    }
163
164    pub fn body(&self) -> &RequestFields {
165        &self.body
166    }
167
168    pub fn query(&self) -> &RequestFields {
169        &self.query
170    }
171
172    pub fn has_sub(&self, req: &RequestFields) -> bool {
173        !req.fields.is_empty()
174    }
175
176    pub fn sub_name(&self, req: &RequestFields) -> Option<Ident> {
177        if self.has_sub(req) {
178            Some(quote::format_ident!("{}{}", self.method_name, req.name))
179        } else {
180            None
181        }
182    }
183
184    fn generate_struct(
185        &self,
186        req: &RequestFields,
187        attrs: Option<TokenStream>,
188    ) -> Option<TokenStream> {
189        self.sub_name(req).map(|name| {
190            let mut generated = self.message.clone();
191            generated.ident = name;
192            if let Some(attrs) = attrs {
193                generated.attrs.retain(|attr| {
194                    let serde: syn::Path = syn::parse_quote!(actix_prost_macros::serde);
195                    attr.path() != &serde
196                });
197                generated
198                    .attrs
199                    .push(syn::parse_quote!(#[actix_prost_macros::serde(#attrs)]));
200            }
201            generated.fields = self.filter_fields(req);
202            quote::quote!(#generated)
203        })
204    }
205
206    pub fn generate_structs(&self) -> TokenStream {
207        let path = self.generate_struct(&self.path, Some(quote::quote!(rename_all = "snake_case")));
208        let query = self.generate_struct(&self.query, None);
209        let body = self.generate_struct(&self.body, None);
210        quote::quote!(#path #query #body)
211    }
212
213    pub fn generate_fields_init(req: &RequestFields) -> impl Iterator<Item = TokenStream> + '_ {
214        req.fields
215            .iter()
216            .map(|f| quote::format_ident!("{}", f))
217            .map(|f| {
218                let field_name = quote::format_ident!("{}", req.name.to_lowercase());
219                quote::quote!(
220                    #f: #field_name.#f,
221                )
222            })
223    }
224
225    pub fn generate_new_request(&self) -> TokenStream {
226        let name = &self.message.ident;
227        let path_fields = Self::generate_fields_init(&self.path);
228        let query_fields = Self::generate_fields_init(&self.query);
229        let body_fields = Self::generate_fields_init(&self.body);
230        quote::quote!(
231            #name {
232                #(#path_fields)*
233                #(#query_fields)*
234                #(#body_fields)*
235            }
236        )
237    }
238
239    fn generate_extract(&self, req: &RequestFields) -> Option<TokenStream> {
240        let field_name = quote::format_ident!("{}", req.name.to_lowercase());
241        let extractor = quote::format_ident!("{}", req.name);
242        self.sub_name(req)
243            .map(|name| quote::quote!(
244                let #field_name = <::actix_web::web::#extractor::<#name> as ::actix_web::FromRequest>::extract(&http_request)
245                    .await
246                    .map_err(|err| ::actix_prost::Error::from_actix(err, ::tonic::Code::InvalidArgument))?
247                    .into_inner();
248            ))
249    }
250
251    fn generate_from_request(&self, req: &RequestFields) -> Option<TokenStream> {
252        let field_name = quote::format_ident!("{}", req.name.to_lowercase());
253        let extractor = quote::format_ident!("{}", req.name);
254        self.sub_name(req)
255            .map(|name| quote::quote!(
256                let #field_name = <::actix_web::web::#extractor::<#name> as ::actix_web::FromRequest>::from_request(&http_request, &mut payload)
257                    .await
258                    .map_err(|err| ::actix_prost::Error::from_actix(err, ::tonic::Code::InvalidArgument))?
259                    .into_inner();
260            ))
261    }
262
263    pub fn generate_extractors(&self) -> TokenStream {
264        let path = self.generate_extract(&self.path);
265        let query = self.generate_extract(&self.query);
266        let body = self.generate_from_request(&self.body);
267        quote::quote!(#path #query #body)
268    }
269}