// actix_prost_build/request.rs
use crate::config::HttpRule;
use proc_macro2::{Ident, TokenStream};
use std::{collections::HashSet, iter::FromIterator};
use syn::{ext::IdentExt, PathArguments};

/// One group of request-message fields (path, query or body) plus the name of the actix
/// extractor that reads it.
///
/// `name` is used verbatim as the `::actix_web::web::<name>` extractor type and as the suffix of
/// the generated helper struct; lowercased, it is the local binding the extracted value is
/// stored in by the generated code.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RequestFields {
    /// Extractor name: `"Path"`, `"Query"` or `"Json"`.
    name: String,
    /// Field names exactly as printed from the message struct's identifiers, in declaration order.
    fields: Vec<String>,
}

11pub struct Request {
12 message: syn::ItemStruct,
13 method_name: Ident,
14 path: RequestFields,
15 query: RequestFields,
16 body: RequestFields,
17}
18
19impl Request {
20 pub fn new(message: syn::ItemStruct, method_name: Ident, config: &HttpRule) -> Request {
21 let fields: Vec<String> = config
22 .pattern
23 .path()
24 .split('{')
25 .skip(1)
26 .filter_map(|q| q.split('}').next())
27 .map(|x| x.to_owned())
28 .collect();
29
30 let (path, query, body) = Self::split_fields(&message, &fields, &config.body);
31
32 Request {
33 message,
34 method_name,
35 path: RequestFields {
36 name: "Path".into(),
37 fields: path,
38 },
39 query: RequestFields {
40 name: "Query".into(),
41 fields: query,
42 },
43 body: RequestFields {
44 name: "Json".into(),
45 fields: body,
46 },
47 }
48 }
49
50 fn split_fields(
51 message: &syn::ItemStruct,
52 path_fields: &[String],
53 body_fields: &Option<String>,
54 ) -> (Vec<String>, Vec<String>, Vec<String>) {
55 let fields = if let syn::Fields::Named(fields) = &message.fields {
56 fields
57 } else {
58 panic!("non named fields aren't supported");
59 };
60
61 let path_filter: HashSet<&str> = HashSet::from_iter(path_fields.iter().map(|s| s.as_ref()));
62 let (path, non_path): (Vec<_>, Vec<_>) = fields
63 .named
64 .iter()
65 .map(|field| field.ident.as_ref().unwrap().to_string())
66 .partition(|field| path_filter.contains(field.as_str()));
67
68 if path_fields.len() != path.len() {
69 let found: HashSet<String> = HashSet::from_iter(path);
70 panic!(
71 "some path fields were not found: {:?}",
72 path_fields
73 .iter()
74 .filter(|f| !found.contains(f.as_str()))
75 .collect::<Vec<_>>()
76 )
77 }
78
79 let (body, query) = if let Some(body_fields) = body_fields {
80 if body_fields != "*" {
81 non_path.into_iter().partition(|f| f == body_fields)
82 } else {
83 (non_path, Vec::default())
84 }
85 } else {
86 (Vec::default(), non_path)
87 };
88
89 if path.len() + query.len() + body.len() != message.fields.len() {
90 panic!("could not map all message fields to path, query and body parts")
91 }
92
93 (path, query, body)
94 }
95
96 pub fn filter_fields(&self, req: &RequestFields) -> syn::Fields {
97 fn update_type_super_path(ty: &mut syn::Type) {
102 if let syn::Type::Path(type_path) = ty {
103 let mut super_segment_data = None;
104 for (i, segment) in type_path.path.segments.iter_mut().enumerate() {
105 if segment.ident.to_string().as_str() == "super" {
106 super_segment_data = Some((i, segment.clone()));
109 break;
110 }
111 match &mut segment.arguments {
113 PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
114 args,
115 ..
116 }) => args.iter_mut().for_each(|arg| {
117 if let syn::GenericArgument::Type(ty) = arg {
118 update_type_super_path(ty)
119 }
120 }),
121 PathArguments::Parenthesized(syn::ParenthesizedGenericArguments {
122 inputs,
123 ..
124 }) => inputs.iter_mut().for_each(update_type_super_path),
125 PathArguments::None => {}
126 }
127 }
128
129 if let Some((index, segment)) = super_segment_data {
132 type_path.path.segments.insert(index, segment)
133 }
134 }
135 }
136
137 let filter: HashSet<&str> = HashSet::from_iter(req.fields.iter().map(|x| x.as_ref()));
138 let fields = self
139 .message
140 .fields
141 .iter()
142 .filter(|&field| filter.contains(field.ident.as_ref().unwrap().to_string().as_str()))
143 .cloned()
144 .map(|mut field| {
145 update_type_super_path(&mut field.ty);
146 field
147 })
148 .collect();
149 let brace_token = if let syn::Fields::Named(named) = &self.message.fields {
150 named.brace_token
151 } else {
152 panic!("not named fields not supported");
153 };
154 syn::Fields::Named(syn::FieldsNamed {
155 brace_token,
156 named: fields,
157 })
158 }
159
160 pub fn path(&self) -> &RequestFields {
161 &self.path
162 }
163
164 pub fn body(&self) -> &RequestFields {
165 &self.body
166 }
167
168 pub fn query(&self) -> &RequestFields {
169 &self.query
170 }
171
172 pub fn has_sub(&self, req: &RequestFields) -> bool {
173 !req.fields.is_empty()
174 }
175
176 pub fn sub_name(&self, req: &RequestFields) -> Option<Ident> {
177 if self.has_sub(req) {
178 Some(quote::format_ident!("{}{}", self.method_name, req.name))
179 } else {
180 None
181 }
182 }
183
184 fn generate_struct(
185 &self,
186 req: &RequestFields,
187 attrs: Option<TokenStream>,
188 ) -> Option<TokenStream> {
189 self.sub_name(req).map(|name| {
190 let mut generated = self.message.clone();
191 generated.ident = name;
192 if let Some(attrs) = attrs {
193 generated.attrs.retain(|attr| {
194 let serde: syn::Path = syn::parse_quote!(actix_prost_macros::serde);
195 attr.path() != &serde
196 });
197 generated
198 .attrs
199 .push(syn::parse_quote!(#[actix_prost_macros::serde(#attrs)]));
200 }
201 generated.fields = self.filter_fields(req);
202 quote::quote!(#generated)
203 })
204 }
205
206 pub fn generate_structs(&self) -> TokenStream {
207 let path = self.generate_struct(&self.path, Some(quote::quote!(rename_all = "snake_case")));
208 let query = self.generate_struct(&self.query, None);
209 let body = self.generate_struct(&self.body, None);
210 quote::quote!(#path #query #body)
211 }
212
213 pub fn generate_fields_init(req: &RequestFields) -> impl Iterator<Item = TokenStream> + '_ {
214 req.fields
215 .iter()
216 .map(|f| quote::format_ident!("{}", f))
217 .map(|f| {
218 let field_name = quote::format_ident!("{}", req.name.to_lowercase());
219 quote::quote!(
220 #f: #field_name.#f,
221 )
222 })
223 }
224
225 pub fn generate_new_request(&self) -> TokenStream {
226 let name = &self.message.ident;
227 let path_fields = Self::generate_fields_init(&self.path);
228 let query_fields = Self::generate_fields_init(&self.query);
229 let body_fields = Self::generate_fields_init(&self.body);
230 quote::quote!(
231 #name {
232 #(#path_fields)*
233 #(#query_fields)*
234 #(#body_fields)*
235 }
236 )
237 }
238
239 fn generate_extract(&self, req: &RequestFields) -> Option<TokenStream> {
240 let field_name = quote::format_ident!("{}", req.name.to_lowercase());
241 let extractor = quote::format_ident!("{}", req.name);
242 self.sub_name(req)
243 .map(|name| quote::quote!(
244 let #field_name = <::actix_web::web::#extractor::<#name> as ::actix_web::FromRequest>::extract(&http_request)
245 .await
246 .map_err(|err| ::actix_prost::Error::from_actix(err, ::tonic::Code::InvalidArgument))?
247 .into_inner();
248 ))
249 }
250
251 fn generate_from_request(&self, req: &RequestFields) -> Option<TokenStream> {
252 let field_name = quote::format_ident!("{}", req.name.to_lowercase());
253 let extractor = quote::format_ident!("{}", req.name);
254 self.sub_name(req)
255 .map(|name| quote::quote!(
256 let #field_name = <::actix_web::web::#extractor::<#name> as ::actix_web::FromRequest>::from_request(&http_request, &mut payload)
257 .await
258 .map_err(|err| ::actix_prost::Error::from_actix(err, ::tonic::Code::InvalidArgument))?
259 .into_inner();
260 ))
261 }
262
263 pub fn generate_extractors(&self) -> TokenStream {
264 let path = self.generate_extract(&self.path);
265 let query = self.generate_extract(&self.query);
266 let body = self.generate_from_request(&self.body);
267 quote::quote!(#path #query #body)
268 }
269}