1#![recursion_limit = "196"]
7
8extern crate proc_macro;
9
10use bae2::FromAttributes;
11use fnv::FnvHasher;
12use proc_macro::TokenStream;
13use proc_macro2::{Span, TokenStream as TokenStream2};
14use quote::quote;
15use syn::punctuated::Punctuated;
16use syn::token::Comma;
17use syn::{Fields, ItemStruct, LitInt, LitStr, Path};
18
19use std::cmp::Ordering;
20use std::hash::{Hash, Hasher};
21
/// Empty field list iterated in place of the fields of a unit struct, so that
/// named, tuple and unit structs all yield the same iterator type.
type UnitFields = Punctuated<syn::Field, Comma>;
23
24struct Field {
25 hash: u64,
26 field: TokenStream2,
27 callback: Option<Path>,
28}
29
30impl PartialEq for Field {
31 fn eq(&self, other: &Field) -> bool {
32 self.hash == other.hash
33 }
34}
35
36impl Eq for Field {}
37
38impl PartialOrd for Field {
39 fn partial_cmp(&self, other: &Field) -> Option<Ordering> {
40 Some(self.cmp(other))
41 }
42}
43
44impl Ord for Field {
45 fn cmp(&self, other: &Field) -> Ordering {
46 self.hash.cmp(&other.hash)
47 }
48}
49
// Per-field options of `#[derive(Content)]`, parsed from a field's attributes
// by the `bae2::FromAttributes` derive (presumably from `#[ramhorns(...)]`,
// matching the `ramhorns` helper attribute registered on the derive — TODO
// confirm the attribute name bae2 derives from this struct's name).
#[derive(FromAttributes)]
struct Ramhorns {
    // Present when the field must not be exposed to templates.
    skip: Option<()>,
    // Present to render the field through `::dysql::encoding::encode_cmark`.
    md: Option<()>,
    // Present to hide the field itself and forward unmatched lookups into it.
    flatten: Option<()>,
    // Name to look the field up by instead of its identifier or tuple index.
    rename: Option<LitStr>,
    // Rendering function for the field; takes precedence over `md`.
    callback: Option<Path>,
}
58
59#[proc_macro_derive(Content, attributes(md, ramhorns))]
60pub fn content_derive(input: TokenStream) -> TokenStream {
61 let item: ItemStruct =
62 syn::parse(input).expect("#[derive(Content)] can be only applied to structs");
63
64 let name = &item.ident;
67 let generics = &item.generics;
68 let type_params = item.generics.type_params();
69 let unit_fields = UnitFields::new();
70
71 let mut errors = Vec::new();
72
73 let fields = match item.fields {
74 Fields::Named(fields) => fields.named.into_iter(),
75 Fields::Unnamed(fields) => fields.unnamed.into_iter(),
76 _ => unit_fields.into_iter(),
77 };
78
79 let mut flatten = Vec::new();
80 let md_callback: Path = syn::parse(quote!(::dysql::encoding::encode_cmark).into()).unwrap();
81 let mut fields = fields
82 .enumerate()
83 .filter_map(|(index, field)| {
84 let mut callback = None;
85 let mut rename = None;
86 let mut skip = false;
87
88 match Ramhorns::try_from_attributes(&field.attrs) {
89 Ok(Some(ramhorns)) => {
90 if ramhorns.skip.is_some() {
91 skip = true;
92 }
93 if ramhorns.md.is_some() {
94 callback = Some(md_callback.clone());
95 }
96 if ramhorns.flatten.is_some() {
97 flatten.push(field.ident.as_ref().map_or_else(
98 || {
99 let index = index.to_string();
100 let lit = LitInt::new(&index, Span::call_site());
101 quote!(#lit)
102 },
103 |ident| quote!(#ident),
104 ));
105 skip = true;
106 }
107 if let Some(lit_str) = ramhorns.rename {
108 rename = Some(lit_str.value());
109 }
110 if let Some(path) = ramhorns.callback {
111 callback = Some(path);
112 }
113 },
114 Ok(None) => (),
115 Err(err) => errors.push(err),
116 };
117
118 if skip {
119 return None;
120 }
121
122 let (name, field) = field.ident.as_ref().map_or_else(
123 || {
124 let index = index.to_string();
125 let lit = LitInt::new(&index, Span::call_site());
126 let name = rename.as_ref().cloned().unwrap_or(index);
127 (name, quote!(#lit))
128 },
129 |ident| {
130 let name = rename
131 .as_ref()
132 .cloned()
133 .unwrap_or_else(|| ident.to_string());
134 (name, quote!(#ident))
135 },
136 );
137
138 let mut hasher = FnvHasher::default();
139 name.hash(&mut hasher);
140 let hash = hasher.finish();
141
142 Some(Field {
143 hash,
144 field,
145 callback,
146 })
147 })
148 .collect::<Vec<_>>();
149
150 if !errors.is_empty() {
151 let errors: Vec<_> = errors.into_iter().map(|e| e.to_compile_error()).collect();
152 return quote! {
153 fn _ramhorns_derive_compile_errors() {
154 #(#errors)*
155 }
156 }
157 .into();
158 }
159
160 fields.sort_unstable();
161
162 let render_field_escaped = fields.iter().map(
163 |Field {
164 field,
165 hash,
166 callback,
167 ..
168 }| {
169 if let Some(callback) = callback {
170 quote! {
171 #hash => #callback(&self.#field, encoder).map(|_| true),
172 }
173 } else {
174 quote! {
175 #hash => self.#field.render_escaped(encoder).map(|_| true),
176 }
177 }
178 },
179 );
180
181 let render_field_unescaped = fields.iter().map(
182 |Field {
183 field,
184 hash,
185 callback,
186 ..
187 }| {
188 if let Some(callback) = callback {
189 quote! {
190 #hash => #callback(&self.#field, encoder).map(|_| true),
191 }
192 } else {
193 quote! {
194 #hash => self.#field.render_unescaped(encoder).map(|_| true),
195 }
196 }
197 },
198 );
199
200 let apply_field_unescaped = fields.iter().map(|Field {field, hash, ..}| {
201 quote! {
202 #hash => self.#field.apply_unescaped(),
203 }
204 },
205 );
206
207
208 let render_field_section = fields.iter().map(|Field { field, hash, .. }| {
209 quote! {
210 #hash => self.#field.render_section(section, encoder, Option::<&()>::None).map(|_| true),
211 }
212 });
213
214 let apply_field_section = fields.iter().map(|Field { field, hash, .. }| {
216 quote! {
217 #hash => self.#field.apply_section(section),
218 }
219 });
220
221 let render_field_inverse = fields.iter().map(|Field { field, hash, .. }| {
222 quote! {
223 #hash => self.#field.render_inverse(section, encoder, Option::<&()>::None).map(|_| true),
224 }
225 });
226
227 let render_field_notnone_section = fields.iter().map(|Field { field, hash, .. }| {
228 quote! {
229 #hash => {
232 self.#field.render_notnone_section(section, encoder, Option::<&()>::None)?;
233 Ok(self.#field.is_truthy())
234 }
235 }
236 });
237
238 let flatten = &*flatten;
239 let fields = fields.iter().map(|Field { field, .. }| field);
240
241 let where_clause = type_params
242 .map(|param| quote!(#param: ::dysql::Content))
243 .collect::<Vec<_>>();
244 let where_clause = if !where_clause.is_empty() {
245 quote!(where #(#where_clause),*)
246 } else {
247 quote!()
248 };
249
250 let tokens = quote! {
252 impl#generics ::dysql::Content for #name#generics #where_clause {
253
254 #[inline]
255 fn capacity_hint(&self, tpl: &::dysql::Template) -> usize {
256 tpl.capacity_hint() #( + self.#fields.capacity_hint(tpl) )*
257 }
258
259 #[inline]
260 fn render_section<C, E, IC>(&self, section: ::dysql::Section<C>, encoder: &mut E, _content: Option<&IC>) -> std::result::Result<(), E::Error>
261 where
262 C: ::dysql::traits::ContentSequence,
263 E: ::dysql::encoding::Encoder,
264 {
265 section.with(self).render(encoder, Option::<&()>::None)
266 }
267
268 #[inline]
269 fn apply_section<C>(&self, section: ::dysql::SimpleSection<C>) -> std::result::Result<::dysql::SimpleValue, ::dysql::SimpleError>
270 where
271 C: ::dysql::traits::ContentSequence,
272 {
273 section.with(self).apply()
274 }
275
276 #[inline]
277 fn render_notnone_section<C, E, IC>(&self, section: ::dysql::Section<C>, encoder: &mut E, _content: Option<&IC>) -> std::result::Result<(), E::Error>
278 where
279 C: ::dysql::traits::ContentSequence,
280 E: ::dysql::encoding::Encoder,
281 {
282 section.with(self).render(encoder, Option::<&()>::None)
283 }
284
285 #[inline]
286 fn render_field_escaped<E>(&self, hash: u64, name: &str, encoder: &mut E) -> std::result::Result<bool, E::Error>
287 where
288 E: ::dysql::encoding::Encoder,
289 {
290 match hash {
291 #( #render_field_escaped )*
292 _ => Ok(
293 #( self.#flatten.render_field_escaped(hash, name, encoder)? ||)*
294 false
295 )
296 }
297 }
298
299 #[inline]
300 fn render_field_unescaped<E>(&self, hash: u64, name: &str, encoder: &mut E) -> std::result::Result<bool, E::Error>
301 where
302 E: ::dysql::encoding::Encoder,
303 {
304 match hash {
305 #( #render_field_unescaped )*
306 _ => Ok(
307 #( self.#flatten.render_field_unescaped(hash, name, encoder)? ||)*
308 false
309 )
310 }
311 }
312
313
314 #[inline]
315 fn apply_field_unescaped(&self, hash: u64, name: &str) -> std::result::Result<dysql::SimpleValue, dysql::SimpleError>
316 {
317 match hash {
318 #( #apply_field_unescaped )*
319 _ => Err(dysql::SimpleInnerError(format!("the data type of field: {} is not supported ", name)).into())
320 }
321 }
322
323 fn render_field_section<P, E>(&self, hash: u64, name: &str, section: ::dysql::Section<P>, encoder: &mut E) -> std::result::Result<bool, E::Error>
324 where
325 P: ::dysql::traits::ContentSequence,
326 E: ::dysql::encoding::Encoder,
327 {
328 match hash {
329 #( #render_field_section )*
330 _ => Ok(
331 #( self.#flatten.render_field_section(hash, name, section, encoder)? ||)*
332 false
333 )
334 }
335 }
336
337 fn apply_field_section<P>(&self, hash: u64, name: &str, section: ::dysql::SimpleSection<P>) -> std::result::Result<dysql::SimpleValue, dysql::SimpleError>
338 where
339 P: ::dysql::traits::ContentSequence,
340 {
341 match hash {
342 #( #apply_field_section )*
343 _ => Err(dysql::SimpleInnerError(format!("the data type of field is not supported")).into())
344 }
345 }
346
347 fn render_field_inverse<P, E>(&self, hash: u64, name: &str, section: ::dysql::Section<P>, encoder: &mut E) -> std::result::Result<bool, E::Error>
348 where
349 P: ::dysql::traits::ContentSequence,
350 E: ::dysql::encoding::Encoder,
351 {
352 match hash {
353 #( #render_field_inverse )*
354 _ => Ok(
355 #( self.#flatten.render_field_inverse(hash, name, section, encoder)? ||)*
356 false
357 )
358 }
359 }
360
361 fn render_field_notnone_section<P, E>(&self, hash: u64, name: &str, section: ::dysql::Section<P>, encoder: &mut E) -> std::result::Result<bool, E::Error>
362 where
363 P: ::dysql::traits::ContentSequence,
364 E: ::dysql::encoding::Encoder,
365 {
366 match hash {
367 #( #render_field_notnone_section )*
368 _ => Ok(
369 #( self.#flatten.render_field_notnone_section(hash, name, section, encoder)? ||)*
370 false
371 )
372 }
373 }
374 }
375 };
376
377 TokenStream::from(tokens)
380}