dysql_tpl_derive/
lib.rs

1//! ## Dysql-tpl
2//!
//! This is a `#[derive]` macro crate; for documentation, see the [main crate](https://docs.rs/dysql-tpl).
4
5// The `quote!` macro requires deep recursion.
6#![recursion_limit = "196"]
7
8extern crate proc_macro;
9
10use bae2::FromAttributes;
11use fnv::FnvHasher;
12use proc_macro::TokenStream;
13use proc_macro2::{Span, TokenStream as TokenStream2};
14use quote::quote;
15use syn::punctuated::Punctuated;
16use syn::token::Comma;
17use syn::{Fields, ItemStruct, LitInt, LitStr, Path};
18
19use std::cmp::Ordering;
20use std::hash::{Hash, Hasher};
21
22type UnitFields = Punctuated<syn::Field, Comma>;
23
24struct Field {
25    hash: u64,
26    field: TokenStream2,
27    callback: Option<Path>,
28}
29
30impl PartialEq for Field {
31    fn eq(&self, other: &Field) -> bool {
32        self.hash == other.hash
33    }
34}
35
36impl Eq for Field {}
37
38impl PartialOrd for Field {
39    fn partial_cmp(&self, other: &Field) -> Option<Ordering> {
40        Some(self.cmp(other))
41    }
42}
43
44impl Ord for Field {
45    fn cmp(&self, other: &Field) -> Ordering {
46        self.hash.cmp(&other.hash)
47    }
48}
49
// Per-field options for `#[derive(Content)]`.
//
// They are parsed from a field's attributes by the `bae2::FromAttributes`
// derive; `content_derive` calls `Ramhorns::try_from_attributes(&field.attrs)`.
// Each `Option<()>` is a flag that `content_derive` treats as set when it is
// `Some`.
//
// Plain `//` comments are used on purpose, so that no extra `#[doc]`
// attributes are handed to the derive.
#[derive(FromAttributes)]
struct Ramhorns {
    // Do not expose the field to templates.
    skip: Option<()>,
    // Render the field through `::dysql::encoding::encode_cmark`.
    md: Option<()>,
    // Do not expose the field itself. Names that match none of the struct's
    // own fields are forwarded to it by the generated `render_field_*` methods.
    flatten: Option<()>,
    // Name templates use for this field instead of its identifier/index.
    rename: Option<LitStr>,
    // Function used to render the field; it overrides `md` when both are set.
    callback: Option<Path>,
}
58
59#[proc_macro_derive(Content, attributes(md, ramhorns))]
60pub fn content_derive(input: TokenStream) -> TokenStream {
61    let item: ItemStruct =
62        syn::parse(input).expect("#[derive(Content)] can be only applied to structs");
63
64    // panic!("{:#?}", item);
65
66    let name = &item.ident;
67    let generics = &item.generics;
68    let type_params = item.generics.type_params();
69    let unit_fields = UnitFields::new();
70
71    let mut errors = Vec::new();
72
73    let fields = match item.fields {
74        Fields::Named(fields) => fields.named.into_iter(),
75        Fields::Unnamed(fields) => fields.unnamed.into_iter(),
76        _ => unit_fields.into_iter(),
77    };
78
79    let mut flatten = Vec::new();
80    let md_callback: Path = syn::parse(quote!(::dysql::encoding::encode_cmark).into()).unwrap();
81    let mut fields = fields
82        .enumerate()
83        .filter_map(|(index, field)| {
84            let mut callback = None;
85            let mut rename = None;
86            let mut skip = false;
87
88            match Ramhorns::try_from_attributes(&field.attrs) {
89                Ok(Some(ramhorns)) => {
90                    if ramhorns.skip.is_some() {
91                        skip = true;
92                    }
93                    if ramhorns.md.is_some() {
94                        callback = Some(md_callback.clone());
95                    }
96                    if ramhorns.flatten.is_some() {
97                        flatten.push(field.ident.as_ref().map_or_else(
98                            || {
99                                let index = index.to_string();
100                                let lit = LitInt::new(&index, Span::call_site());
101                                quote!(#lit)
102                            },
103                            |ident| quote!(#ident),
104                        ));
105                        skip = true;
106                    }
107                    if let Some(lit_str) = ramhorns.rename {
108                        rename = Some(lit_str.value());
109                    }
110                    if let Some(path) = ramhorns.callback {
111                        callback = Some(path);
112                    }
113                },
114                Ok(None) => (),
115                Err(err) => errors.push(err),
116            };
117
118            if skip {
119                return None;
120            }
121
122            let (name, field) = field.ident.as_ref().map_or_else(
123                || {
124                    let index = index.to_string();
125                    let lit = LitInt::new(&index, Span::call_site());
126                    let name = rename.as_ref().cloned().unwrap_or(index);
127                    (name, quote!(#lit))
128                },
129                |ident| {
130                    let name = rename
131                        .as_ref()
132                        .cloned()
133                        .unwrap_or_else(|| ident.to_string());
134                    (name, quote!(#ident))
135                },
136            );
137
138            let mut hasher = FnvHasher::default();
139            name.hash(&mut hasher);
140            let hash = hasher.finish();
141
142            Some(Field {
143                hash,
144                field,
145                callback,
146            })
147        })
148        .collect::<Vec<_>>();
149
150    if !errors.is_empty() {
151        let errors: Vec<_> = errors.into_iter().map(|e| e.to_compile_error()).collect();
152        return quote! {
153            fn _ramhorns_derive_compile_errors() {
154                #(#errors)*
155            }
156        }
157        .into();
158    }
159
160    fields.sort_unstable();
161
162    let render_field_escaped = fields.iter().map(
163        |Field {
164             field,
165             hash,
166             callback,
167             ..
168         }| {
169            if let Some(callback) = callback {
170                quote! {
171                    #hash => #callback(&self.#field, encoder).map(|_| true),
172                }
173            } else {
174                quote! {
175                    #hash => self.#field.render_escaped(encoder).map(|_| true),
176                }
177            }
178        },
179    );
180
181    let render_field_unescaped = fields.iter().map(
182        |Field {
183             field,
184             hash,
185             callback,
186             ..
187         }| {
188            if let Some(callback) = callback {
189                quote! {
190                    #hash => #callback(&self.#field, encoder).map(|_| true),
191                }
192            } else {
193                quote! {
194                    #hash => self.#field.render_unescaped(encoder).map(|_| true),
195                }
196            }
197        },
198    );
199
200    let apply_field_unescaped = fields.iter().map(|Field {field, hash, ..}| {
201            quote! {
202                #hash => self.#field.apply_unescaped(),
203            }
204        },
205    );
206
207
208    let render_field_section = fields.iter().map(|Field { field, hash, .. }| {
209        quote! {
210            #hash => self.#field.render_section(section, encoder, Option::<&()>::None).map(|_| true),
211        }
212    });
213
214    // dto 获取字段值
215    let apply_field_section = fields.iter().map(|Field { field, hash, .. }| {
216        quote! {
217            #hash => self.#field.apply_section(section),
218        }
219    });
220
221    let render_field_inverse = fields.iter().map(|Field { field, hash, .. }| {
222        quote! {
223            #hash => self.#field.render_inverse(section, encoder, Option::<&()>::None).map(|_| true),
224        }
225    });
226
227    let render_field_notnone_section = fields.iter().map(|Field { field, hash, .. }| {
228        quote! {
229            // #hash => self.#field.render_notnone_section(section, encoder, Option::<&()>::None).map(|_| true),
230            // #hash => Ok(self.#field.is_truthy()),
231            #hash => {
232                self.#field.render_notnone_section(section, encoder, Option::<&()>::None)?;
233                Ok(self.#field.is_truthy())
234            }
235        }
236    });
237
238    let flatten = &*flatten;
239    let fields = fields.iter().map(|Field { field, .. }| field);
240
241    let where_clause = type_params
242        .map(|param| quote!(#param: ::dysql::Content))
243        .collect::<Vec<_>>();
244    let where_clause = if !where_clause.is_empty() {
245        quote!(where #(#where_clause),*)
246    } else {
247        quote!()
248    };
249
250    // FIXME: decouple lifetimes from actual generics with trait boundaries
251    let tokens = quote! {
252        impl#generics ::dysql::Content for #name#generics #where_clause {
253            
254            #[inline]
255            fn capacity_hint(&self, tpl: &::dysql::Template) -> usize {
256                tpl.capacity_hint() #( + self.#fields.capacity_hint(tpl) )*
257            }
258
259            #[inline]
260            fn render_section<C, E, IC>(&self, section: ::dysql::Section<C>, encoder: &mut E, _content: Option<&IC>) -> std::result::Result<(), E::Error>
261            where
262                C: ::dysql::traits::ContentSequence,
263                E: ::dysql::encoding::Encoder,
264            {
265                section.with(self).render(encoder, Option::<&()>::None)
266            }
267
268            #[inline]
269            fn apply_section<C>(&self, section: ::dysql::SimpleSection<C>) -> std::result::Result<::dysql::SimpleValue, ::dysql::SimpleError>
270            where
271                C: ::dysql::traits::ContentSequence,
272            {
273                section.with(self).apply()
274            }
275
276            #[inline]
277            fn render_notnone_section<C, E, IC>(&self, section: ::dysql::Section<C>, encoder: &mut E, _content: Option<&IC>) -> std::result::Result<(), E::Error>
278            where
279                C: ::dysql::traits::ContentSequence,
280                E: ::dysql::encoding::Encoder,
281            {
282                section.with(self).render(encoder, Option::<&()>::None)
283            }
284
285            #[inline]
286            fn render_field_escaped<E>(&self, hash: u64, name: &str, encoder: &mut E) -> std::result::Result<bool, E::Error>
287            where
288                E: ::dysql::encoding::Encoder,
289            {
290                match hash {
291                    #( #render_field_escaped )*
292                    _ => Ok(
293                        #( self.#flatten.render_field_escaped(hash, name, encoder)? ||)*
294                        false
295                    )
296                }
297            }
298
299            #[inline]
300            fn render_field_unescaped<E>(&self, hash: u64, name: &str, encoder: &mut E) -> std::result::Result<bool, E::Error>
301            where
302                E: ::dysql::encoding::Encoder,
303            {
304                match hash {
305                    #( #render_field_unescaped )*
306                    _ => Ok(
307                        #( self.#flatten.render_field_unescaped(hash, name, encoder)? ||)*
308                        false
309                    )
310                }
311            }
312
313
314            #[inline]
315            fn apply_field_unescaped(&self, hash: u64, name: &str) -> std::result::Result<dysql::SimpleValue, dysql::SimpleError>
316            {
317                match hash {
318                    #( #apply_field_unescaped )*
319                    _ => Err(dysql::SimpleInnerError(format!("the data type of field: {} is not supported ", name)).into())
320                }
321            }
322
323            fn render_field_section<P, E>(&self, hash: u64, name: &str, section: ::dysql::Section<P>, encoder: &mut E) -> std::result::Result<bool, E::Error>
324            where
325                P: ::dysql::traits::ContentSequence,
326                E: ::dysql::encoding::Encoder,
327            {
328                match hash {
329                    #( #render_field_section )*
330                    _ => Ok(
331                        #( self.#flatten.render_field_section(hash, name, section, encoder)? ||)*
332                        false
333                    )
334                }
335            }
336
337            fn apply_field_section<P>(&self, hash: u64, name: &str, section: ::dysql::SimpleSection<P>) -> std::result::Result<dysql::SimpleValue, dysql::SimpleError>
338            where
339                P: ::dysql::traits::ContentSequence,
340            {
341                match hash {
342                    #( #apply_field_section )*
343                    _ => Err(dysql::SimpleInnerError(format!("the data type of field is not supported")).into())
344                }
345            }
346
347            fn render_field_inverse<P, E>(&self, hash: u64, name: &str, section: ::dysql::Section<P>, encoder: &mut E) -> std::result::Result<bool, E::Error>
348            where
349                P: ::dysql::traits::ContentSequence,
350                E: ::dysql::encoding::Encoder,
351            {
352                match hash {
353                    #( #render_field_inverse )*
354                    _ => Ok(
355                        #( self.#flatten.render_field_inverse(hash, name, section, encoder)? ||)*
356                        false
357                    )
358                }
359            }
360
361            fn render_field_notnone_section<P, E>(&self, hash: u64, name: &str, section: ::dysql::Section<P>, encoder: &mut E) -> std::result::Result<bool, E::Error>
362            where
363                P: ::dysql::traits::ContentSequence,
364                E: ::dysql::encoding::Encoder,
365            {
366                match hash {
367                    #( #render_field_notnone_section )*
368                    _ => Ok(
369                        #( self.#flatten.render_field_notnone_section(hash, name, section, encoder)? ||)*
370                        false
371                    )
372                }
373            }
374        }
375    };
376
377    // panic!("{}", tokens);
378
379    TokenStream::from(tokens)
380}