Skip to main content

tarantool_proc/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use proc_macro_error2::{proc_macro_error, SpanRange};
4use quote::{quote, ToTokens};
5use syn::{
6    parse_macro_input, parse_quote, punctuated::Punctuated, Attribute, AttributeArgs, DeriveInput,
7    FnArg, Ident, Item, ItemFn, Signature, Token,
8};
9
// https://git.picodata.io/picodata/picodata/tarantool-module/-/merge_requests/505#note_78473
//
// Evaluates a `Result`-valued expression and yields the value inside `Ok`. On `Err`, the error
// is turned into tokens with `to_compile_error()` and returned (after `.into()`) from the
// enclosing function or closure, so a single bad input only affects its own piece of output.
macro_rules! unwrap_or_compile_error {
    ($result:expr) => {
        match $result {
            Err(error) => return error.to_compile_error().into(),
            Ok(value) => value,
        }
    };
}
19
20fn default_tarantool_crate_path() -> syn::Path {
21    parse_quote! { tarantool }
22}
23
24mod test;
25
26/// Mark a function as a test.
27///
28/// See `tarantool::test` doc-comments in tarantool crate for details.
29#[proc_macro_attribute]
30pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream {
31    test::impl_macro_attribute(attr, item)
32}
33
34mod msgpack {
35
36    use darling::FromDeriveInput;
37    use proc_macro2::TokenStream;
38    use proc_macro_error2::{abort, SpanRange};
39    use quote::{format_ident, quote, quote_spanned, ToTokens};
40    use syn::{
41        parse_quote, spanned::Spanned, Data, Field, Fields, FieldsNamed, FieldsUnnamed,
42        GenericParam, Generics, Ident, Index, Path, Type, Variant,
43    };
44
    /// Container-level options of the msgpack derive, read by darling from `#[encode(...)]`
    /// attributes on the deriving type.
    ///
    /// Because of `#[darling(default)]` every option may be omitted; omitted options keep their
    /// `Default` values (`false` for flags, `None` for `tarantool`).
    #[derive(Default, FromDeriveInput)]
    #[darling(attributes(encode), default)]
    pub struct Args {
        /// Whether this struct should be serialized as MP_MAP instead of MP_ARRAY.
        ///
        /// Only accepted for structs with named fields (`encode_fields` aborts for tuple structs
        /// and enums). At runtime the generated code lets `context.struct_style()` override it.
        pub as_map: bool,
        /// Path to tarantool crate.
        ///
        /// NOTE(review): kept as a string here; presumably parsed into a `syn::Path` by the
        /// caller and defaulted to `tarantool` when absent — confirm at the use site.
        pub tarantool: Option<String>,
        /// Allows optional fields to be decoded if their values are not present in the array.
        ///
        /// Consulted for both named and unnamed fields. When enabled, every `Option<_>` field
        /// must come after all required fields, otherwise a compile error is emitted.
        pub allow_array_optionals: bool,
        /// <https://serde.rs/enum-representations.html#untagged>
        ///
        /// Only accepted for enums: the variant name wrapper is omitted and only the variant's
        /// fields are written.
        pub untagged: bool,
    }
57
58    pub fn add_trait_bounds(mut generics: Generics, tarantool_crate: &Path) -> Generics {
59        for param in &mut generics.params {
60            if let GenericParam::Type(ref mut type_param) = *param {
61                type_param
62                    .bounds
63                    .push(parse_quote!(#tarantool_crate::msgpack::Encode));
64            }
65        }
66        generics
67    }
68
69    trait TypeExt {
70        fn is_option(&self) -> bool;
71    }
72
73    impl TypeExt for Type {
74        fn is_option(&self) -> bool {
75            if let Type::Path(ref typepath) = self {
76                typepath
77                    .path
78                    .segments
79                    .last()
80                    .map(|segment| segment.ident == "Option")
81                    .unwrap_or(false)
82            } else {
83                false
84            }
85        }
86    }
87
    /// Defines how field will be encoded or decoded according to the `#[encode(...)]` attribute
    /// on it (see `FieldAttr::from_field` for the accepted spellings).
    enum FieldAttr {
        /// `#[encode(as_raw)]`: the field (which must be a `Vec<u8>`) is written to the output
        /// as-is, without any check of internal value whatsoever; it is expected to already
        /// hold encoded msgpack.
        Raw,
        /// `#[encode(as_map)]`. TODO: Field should be serialized as MP_MAP, ignoring
        /// struct-level serialization type. Not currently supported.
        Map,
        /// `#[encode(as_vec)]`. TODO: Field should be serialized as MP_ARRAY, ignoring
        /// struct-level serialization type. Not currently supported.
        Vec,
        /// `#[encode(default)]`: field should use `Default::default()` if it is missing or
        /// cannot be decoded during decoding.
        Default,
    }
99
100    impl FieldAttr {
101        /// Returns appropriate `Some(FieldAttr)` for this field according to attribute on it, `None` if
102        /// no attribute was on a field, or errors if attribute encoding type is empty/multiple/wrong.
103        #[inline]
104        fn from_field(field: &Field) -> Result<Option<Self>, syn::Error> {
105            let attrs = &field.attrs;
106
107            let mut encode_attr = None;
108
109            for attr in attrs.iter().filter(|attr| attr.path.is_ident("encode")) {
110                if encode_attr.is_some() {
111                    return Err(syn::Error::new(
112                        attr.span(),
113                        "multiple encoding types are not allowed",
114                    ));
115                }
116
117                encode_attr = Some(attr);
118            }
119
120            match encode_attr {
121                Some(attr) => attr.parse_args_with(|input: syn::parse::ParseStream| {
122                    if input.is_empty() {
123                        return Err(syn::Error::new(
124                            input.span(),
125                            "empty encoding type is not allowed",
126                        ));
127                    }
128
129                    let ident: Ident = input.parse()?;
130
131                    if !input.is_empty() {
132                        return Err(syn::Error::new(
133                            ident.span(),
134                            "multiple encoding types are not allowed",
135                        ));
136                    }
137
138                    if ident == "as_raw" {
139                        let mut field_type_name = proc_macro2::TokenStream::new();
140                        field.ty.to_tokens(&mut field_type_name);
141                        if field_type_name.to_string() != "Vec < u8 >" {
142                            Err(syn::Error::new(
143                                ident.span(),
144                                "only `Vec<u8>` is supported for `as_raw`",
145                            ))
146                        } else {
147                            Ok(Some(Self::Raw))
148                        }
149                    } else if ident == "as_map" {
150                        Ok(Some(Self::Map))
151                    } else if ident == "as_vec" {
152                        Ok(Some(Self::Vec))
153                    } else if ident == "default" {
154                        Ok(Some(Self::Default))
155                    } else {
156                        Err(syn::Error::new(ident.span(), "unknown encoding type"))
157                    }
158                }),
159                None => Ok(None),
160            }
161        }
162    }
163
    /// Defines how an enum variant will be encoded or decoded according to attributes on it.
    enum VariantAttr {
        /// `#[encode(rename = "name")]`: rename a variant, i.e. write `name` instead of the
        /// variant identifier as the variant key in the tagged representation.
        Rename(String),
    }
169
170    impl VariantAttr {
171        #[inline]
172        fn from_variant(variant: &Variant) -> Result<Option<Self>, syn::Error> {
173            let attrs = &variant.attrs;
174
175            let mut encode_attr = None;
176
177            for attr in attrs.iter().filter(|attr| attr.path.is_ident("encode")) {
178                if encode_attr.is_some() {
179                    return Err(syn::Error::new(
180                        attr.span(),
181                        "multiple encoding types are not allowed",
182                    ));
183                }
184
185                encode_attr = Some(attr);
186            }
187
188            match encode_attr {
189                Some(attr) => attr.parse_args_with(|input: syn::parse::ParseStream| {
190                    let ident: Ident = input.parse()?;
191                    input.parse::<syn::Token![=]>()?;
192
193                    if ident != "rename" {
194                        return Err(syn::Error::new(ident.span(), "expected `rename`"));
195                    }
196
197                    let lit: syn::LitStr = input.parse()?;
198
199                    Ok(Some(Self::Rename(lit.value())))
200                }),
201                None => Ok(None),
202            }
203        }
204    }
205
206    fn encode_named_fields(
207        fields: &FieldsNamed,
208        tarantool_crate: &Path,
209        add_self: bool,
210    ) -> proc_macro2::TokenStream {
211        fields
212            .named
213            .iter()
214            .flat_map(|f| {
215                let field_name = f.ident.as_ref().expect("only named fields here");
216                let field_repr = format_ident!("{}", field_name).to_string();
217                let field_attr = unwrap_or_compile_error!(FieldAttr::from_field(f));
218
219                let s = if add_self {
220                    quote! {&self.}
221                } else {
222                    quote! {}
223                };
224
225                let write_key = quote_spanned! {f.span()=>
226                    if as_map {
227                        #tarantool_crate::msgpack::rmp::encode::write_str(w, #field_repr)?;
228                    }
229                };
230                if let Some(attr) = field_attr {
231                    match attr {
232                        FieldAttr::Raw => quote_spanned! {f.span()=>
233                            #write_key
234                            w.write_all(#s #field_name)?;
235                        },
236                        // TODO: encode with `#[encode(as_map)]` and `#[encode(as_vec)]`
237                        FieldAttr::Map => {
238                            syn::Error::new(f.span(), "`as_map` is not currently supported")
239                                .to_compile_error()
240                        }
241                        FieldAttr::Vec => {
242                            syn::Error::new(f.span(), "`as_vec` is not currently supported")
243                                .to_compile_error()
244                        }
245                        FieldAttr::Default => quote_spanned! {f.span()=>
246                            #write_key
247                            #tarantool_crate::msgpack::Encode::encode(#s #field_name, w, context)?;
248                        },
249                    }
250                } else {
251                    quote_spanned! {f.span()=>
252                        #write_key
253                        #tarantool_crate::msgpack::Encode::encode(#s #field_name, w, context)?;
254                    }
255                }
256            })
257            .collect()
258    }
259
260    fn encode_unnamed_fields(
261        fields: &FieldsUnnamed,
262        tarantool_crate: &Path,
263    ) -> proc_macro2::TokenStream {
264        fields
265            .unnamed
266            .iter()
267            .enumerate()
268            .flat_map(|(i, f)| {
269                let index = Index::from(i);
270                let field_attr = unwrap_or_compile_error!(FieldAttr::from_field(f));
271
272                if let Some(field) = field_attr {
273                    match field {
274                        FieldAttr::Raw => quote_spanned! {f.span()=>
275                            w.write_all(&self.#index)?;
276                        },
277                        // TODO: encode with `#[encode(as_map)]` and `#[encode(as_vec)]`
278                        FieldAttr::Map => {
279                            syn::Error::new(f.span(), "`as_map` is not currently supported")
280                                .to_compile_error()
281                        }
282                        FieldAttr::Vec => {
283                            syn::Error::new(f.span(), "`as_vec` is not currently supported")
284                                .to_compile_error()
285                        }
286                        FieldAttr::Default => quote_spanned! {f.span()=>
287                            #tarantool_crate::msgpack::Encode::encode(&self.#index, w, context)?;
288                        },
289                    }
290                } else {
291                    quote_spanned! {f.span()=>
292                        #tarantool_crate::msgpack::Encode::encode(&self.#index, w, context)?;
293                    }
294                }
295            })
296            .collect()
297    }
298
299    pub fn encode_fields(
300        data: &Data,
301        tarantool_crate: &Path,
302        attrs_span: impl Fn() -> SpanRange,
303        args: &Args,
304    ) -> proc_macro2::TokenStream {
305        let as_map = args.as_map;
306        let is_untagged = args.untagged;
307        match *data {
308            Data::Struct(ref data) => {
309                if is_untagged {
310                    abort!(
311                        attrs_span(),
312                        "untagged encode representation is allowed only for enums"
313                    );
314                }
315                match data.fields {
316                    Fields::Named(ref fields) => {
317                        let field_count = fields.named.len() as u32;
318                        let fields = encode_named_fields(fields, tarantool_crate, true);
319                        quote! {
320                            let as_map = match context.struct_style() {
321                                StructStyle::Default => #as_map,
322                                StructStyle::ForceAsMap => true,
323                                StructStyle::ForceAsArray => false,
324                            };
325                            if as_map {
326                                #tarantool_crate::msgpack::rmp::encode::write_map_len(w, #field_count)?;
327                            } else {
328                                #tarantool_crate::msgpack::rmp::encode::write_array_len(w, #field_count)?;
329                            }
330                            #fields
331                        }
332                    }
333                    Fields::Unnamed(ref fields) => {
334                        if as_map {
335                            abort!(
336                                attrs_span(),
337                                "`as_map` attribute can be specified only for structs with named fields"
338                            );
339                        }
340                        let field_count = fields.unnamed.len() as u32;
341                        let fields = encode_unnamed_fields(fields, tarantool_crate);
342                        quote! {
343                            #tarantool_crate::msgpack::rmp::encode::write_array_len(w, #field_count)?;
344                            #fields
345                        }
346                    }
347                    Fields::Unit => {
348                        quote!(#tarantool_crate::msgpack::Encode::encode(&(), w, context)?;)
349                    }
350                }
351            }
352            Data::Enum(ref variants) => {
353                if as_map {
354                    abort!(
355                        attrs_span(),
356                        "`as_map` attribute can be specified only for structs"
357                    );
358                }
359                let variants: proc_macro2::TokenStream = variants
360                    .variants
361                    .iter()
362                    .flat_map(|variant| {
363                        let variant_name = &variant.ident;
364                        let attr = unwrap_or_compile_error!(VariantAttr::from_variant(variant));
365                        let variant_repr = if let Some(VariantAttr::Rename(new_name)) = attr {
366                            new_name
367                        } else {
368                            format_ident!("{}", variant_name).to_string()
369                        };
370                        match variant.fields {
371                            Fields::Named(ref fields) => {
372                                let field_count = fields.named.len() as u32;
373                                let field_names = fields.named.iter().map(|field| field.ident.clone());
374                                let fields = encode_named_fields(fields, tarantool_crate, false);
375                                // TODO: allow `#[encode(as_map)]` for struct variants
376                                if is_untagged {
377                                    quote! {
378                                        Self::#variant_name { #(#field_names),*} => {
379                                            #tarantool_crate::msgpack::rmp::encode::write_array_len(w, #field_count)?;
380                                            let as_map = false;
381                                            #fields
382                                        }
383                                    }
384                                } else {
385                                    quote! {
386                                        Self::#variant_name { #(#field_names),*} => {
387                                            #tarantool_crate::msgpack::rmp::encode::write_str(w, #variant_repr)?;
388                                            #tarantool_crate::msgpack::rmp::encode::write_array_len(w, #field_count)?;
389                                            let as_map = false;
390                                            #fields
391                                        }
392                                    }
393                                }
394                            },
395                            Fields::Unnamed(ref fields) => {
396                                let field_count = fields.unnamed.len() as u32;
397                                let field_names = fields.unnamed.iter().enumerate().map(|(i, _)| format_ident!("_field_{}", i));
398                                let fields: proc_macro2::TokenStream = field_names.clone()
399                                    .flat_map(|field_name| quote! {
400                                        #tarantool_crate::msgpack::Encode::encode(#field_name, w, context)?;
401                                    })
402                                    .collect();
403                                if is_untagged {
404                                    quote! {
405                                        Self::#variant_name ( #(#field_names),*) => {
406                                            #tarantool_crate::msgpack::rmp::encode::write_array_len(w, #field_count)?;
407                                            #fields
408                                        }
409                                    }
410                                } else {
411                                    quote! {
412                                        Self::#variant_name ( #(#field_names),*) => {
413                                            #tarantool_crate::msgpack::rmp::encode::write_str(w, #variant_repr)?;
414                                            #tarantool_crate::msgpack::rmp::encode::write_array_len(w, #field_count)?;
415                                            #fields
416                                        }
417                                    }
418                                }
419                            }
420                            Fields::Unit => {
421                                if is_untagged {
422                                    quote! {
423                                        Self::#variant_name => #tarantool_crate::msgpack::Encode::encode(&(), w, context)?,
424                                    }
425                                } else {
426                                    quote! {
427                                        Self::#variant_name => {
428                                            #tarantool_crate::msgpack::rmp::encode::write_str(w, #variant_repr)?;
429                                            #tarantool_crate::msgpack::Encode::encode(&(), w, context)?;
430                                        }
431                                    }
432                                }
433                            },
434                        }
435                    })
436                    .collect();
437                if is_untagged {
438                    quote! {
439                        match self {
440                            #variants
441                        }
442                    }
443                } else {
444                    quote! {
445                        #tarantool_crate::msgpack::rmp::encode::write_map_len(w, 1)?;
446                        match self {
447                            #variants
448                        }
449                    }
450                }
451            }
452            Data::Union(_) => unimplemented!(),
453        }
454    }
455
456    fn decode_named_fields(
457        fields: &FieldsNamed,
458        tarantool_crate: &Path,
459        enum_variant: Option<&syn::Ident>,
460        args: &Args,
461    ) -> TokenStream {
462        let allow_array_optionals = args.allow_array_optionals;
463
464        let mut var_names = Vec::with_capacity(fields.named.len());
465        let mut met_option = false;
466        let fields_amount = fields.named.len();
467        let mut fields_passed = fields_amount;
468        let code: TokenStream = fields
469            .named
470            .iter()
471            .map(|f| {
472                if f.ty.is_option() {
473                    met_option = true;
474                    fields_passed -= 1;
475                    decode_named_optional_field(f, tarantool_crate, &mut var_names, allow_array_optionals, fields_amount, fields_passed)
476                } else {
477                    if met_option && allow_array_optionals {
478                        return syn::Error::new(
479                            f.span(),
480                            "optional fields must be the last in the parameter list if allow_array_optionals is enabled",
481                        )
482                        .to_compile_error();
483                    }
484                    fields_passed -= 1;
485                    decode_named_required_field(f, tarantool_crate, &mut var_names)
486                }
487            })
488            .collect();
489        let field_names = fields.named.iter().map(|f| &f.ident);
490        let enum_variant = if let Some(variant) = enum_variant {
491            quote! { ::#variant }
492        } else {
493            quote! {}
494        };
495        quote! {
496            #code
497            Ok(Self #enum_variant {
498                #(#field_names: #var_names),*
499            })
500        }
501    }
502
503    #[inline]
504    fn decode_named_optional_field(
505        field: &Field,
506        tarantool_crate: &Path,
507        names: &mut Vec<Ident>,
508        allow_array_optionals: bool,
509        fields_amount: usize,
510        fields_passed: usize,
511    ) -> TokenStream {
512        let field_type = &field.ty;
513        let field_attr = unwrap_or_compile_error!(FieldAttr::from_field(field));
514
515        let field_ident = field.ident.as_ref().expect("only named fields here");
516        let field_repr = format_ident!("{}", field_ident).to_string();
517        let field_name = proc_macro2::Literal::byte_string(field_repr.as_bytes());
518        let var_name = format_ident!("_field_{}", field_ident);
519
520        let read_key = quote_spanned! {field.span()=>
521            if as_map {
522                use #tarantool_crate::msgpack::str_bounds;
523
524                let (byte_len, field_name_len_spaced) = str_bounds(r)
525                    .map_err(|err| #tarantool_crate::msgpack::DecodeError::new::<Self>(err).with_part("field name"))?;
526                let decoded_field_name = r.get(byte_len..field_name_len_spaced).unwrap();
527                if decoded_field_name != #field_name {
528                    is_none = true;
529                } else {
530                    let len = rmp::decode::read_str_len(r).unwrap();
531                    *r = &r[(len as usize)..]; // advance if matches field name
532                }
533            }
534        };
535
536        // TODO: allow `#[encode(as_map)]` and `#[encode(as_vec)]` for struct fields
537        let out = match field_attr {
538            Some(FieldAttr::Map) => unimplemented!("`as_map` is not currently supported"),
539            Some(FieldAttr::Vec) => unimplemented!("`as_vec` is not currently supported"),
540            Some(FieldAttr::Default) => {
541                panic!("optional fields are marked `#[encode(default)]` by default")
542            }
543            Some(FieldAttr::Raw) => quote_spanned! {field.span()=>
544                    let mut #var_name: #field_type = None;
545                    let mut is_none = false;
546
547                    #read_key
548                    if !is_none {
549                        #var_name = Some(#tarantool_crate::msgpack::preserve_read(r).expect("only valid msgpack here"));
550                    }
551            },
552            None => quote_spanned! {field.span()=>
553                let mut #var_name: #field_type = None;
554                let mut is_none = false;
555
556                #read_key
557                if !is_none {
558                    match #tarantool_crate::msgpack::Decode::decode(r, context) {
559                        Ok(val) => #var_name = Some(val),
560                        Err(err) => {
561                            let markered = err.source.get(err.source.len() - 33..).unwrap_or("") == "failed to read MessagePack marker";
562                            let nulled = if err.part.is_some() {
563                                err.part.as_ref().expect("Can't fail after a conditional check") == "got Null"
564                            } else {
565                                false
566                            };
567
568                            if !nulled && !#allow_array_optionals && !as_map {
569                                let message = format!("not enough fields, expected {}, got {} (note: optional fields must be explicitly null unless `allow_array_optionals` attribute is passed)", #fields_amount, #fields_passed);
570                                Err(#tarantool_crate::msgpack::DecodeError::new::<Self>(message))?;
571                            } else if !nulled && !markered && #allow_array_optionals {
572                                Err(err)?;
573                            }
574                        },
575                    }
576                }
577            },
578        };
579
580        names.push(var_name);
581        out
582    }
583
584    #[inline]
585    fn decode_named_required_field(
586        field: &Field,
587        tarantool_crate: &Path,
588        names: &mut Vec<Ident>,
589    ) -> TokenStream {
590        let field_attr = unwrap_or_compile_error!(FieldAttr::from_field(field));
591
592        let field_ident = field.ident.as_ref().expect("only named fields here");
593        let field_repr = format_ident!("{}", field_ident).to_string();
594        let field_name = proc_macro2::Literal::byte_string(field_repr.as_bytes());
595        let var_name = format_ident!("_field_{}", field_ident);
596
597        let mut read_key = quote_spanned! {field.span()=>
598            if as_map {
599                let len = rmp::decode::read_str_len(r)
600                    .map_err(|err| #tarantool_crate::msgpack::DecodeError::from_vre::<Self>(err).with_part("field name"))?;
601                let decoded_field_name = r.get(0..(len as usize))
602                    .ok_or_else(|| #tarantool_crate::msgpack::DecodeError::new::<Self>("not enough data").with_part("field name"))?;
603                if decoded_field_name != #field_name {
604                    let field_name = String::from_utf8(#field_name.to_vec()).expect("is valid utf8");
605                    let err = if let Ok(decoded_field_name) = String::from_utf8(decoded_field_name.to_vec()) {
606                        format!("expected field {}, got {}", field_name, decoded_field_name)
607                    } else {
608                        format!("expected field {}, got invalid utf8 {:?}", field_name, decoded_field_name)
609                    };
610                    return Err(#tarantool_crate::msgpack::DecodeError::new::<Self>(err));
611                } else {
612                    *r = &r[(len as usize)..]; // advance
613                }
614            }
615        };
616
617        // TODO: allow `#[encode(as_map)]` and `#[encode(as_vec)]` for struct fields
618        let out = if let Some(FieldAttr::Raw) = field_attr {
619            quote_spanned! {field.span()=>
620                #read_key
621                let #var_name = #tarantool_crate::msgpack::preserve_read(r).expect("only valid msgpack here");
622            }
623        } else if let Some(FieldAttr::Map) = field_attr {
624            unimplemented!("`as_map` is not currently supported");
625        } else if let Some(FieldAttr::Vec) = field_attr {
626            unimplemented!("`as_vec` is not currently supported");
627        } else if let Some(FieldAttr::Default) = field_attr {
628            read_key = quote_spanned! {field.span()=>
629                let mut skip = false;
630                if as_map {
631                    let mut tmp = *r;
632                    let len = rmp::decode::read_str_len(&mut tmp)
633                        .map_err(|err| #tarantool_crate::msgpack::DecodeError::from_vre::<Self>(err).with_part("field name"))?;
634                    let decoded_field_name = tmp.get(0..(len as usize))
635                        .ok_or_else(|| #tarantool_crate::msgpack::DecodeError::new::<Self>("not enough data").with_part("field name"))?;
636                    if decoded_field_name != #field_name {
637                        skip = true;
638                    } else {
639                        *r = &tmp[(len as usize)..]; // advance
640                    }
641                }
642            };
643
644            quote_spanned! {field.span()=>
645                #read_key
646                let #var_name = if skip {
647                    Default::default()
648                } else {
649                    let mut tmp = *r;
650                    match #tarantool_crate::msgpack::Decode::decode(&mut tmp, context) {
651                        Ok(value) => {
652                            *r = tmp;
653                            value
654                        },
655                        Err(_) => Default::default(),
656                    }
657                };
658            }
659        } else {
660            quote_spanned! {field.span()=>
661                #read_key
662                let #var_name = #tarantool_crate::msgpack::Decode::decode(r, context)
663                    .map_err(|err| #tarantool_crate::msgpack::DecodeError::new::<Self>(err).with_part(format!("field {}", stringify!(#field_ident))))?;
664            }
665        };
666
667        names.push(var_name);
668        out
669    }
670
671    fn decode_unnamed_fields(
672        fields: &FieldsUnnamed,
673        tarantool_crate: &Path,
674        enum_variant: Option<&syn::Ident>,
675        args: &Args,
676    ) -> proc_macro2::TokenStream {
677        let allow_array_optionals = args.allow_array_optionals;
678
679        let mut var_names = Vec::with_capacity(fields.unnamed.len());
680        let mut met_option = false;
681        let code: proc_macro2::TokenStream = fields
682            .unnamed
683            .iter()
684            .enumerate()
685            .map(|(i, f)| {
686                let is_option = f.ty.is_option();
687                if is_option {
688                    met_option = true;
689                    decode_unnamed_optional_field(f, i, tarantool_crate, &mut var_names)
690                } else if met_option && allow_array_optionals {
691                    syn::Error::new(
692                        f.span(),
693                        "optional fields must be the last in the parameter list with `allow_array_optionals` attribute",
694                    )
695                    .to_compile_error()
696                } else {
697                    decode_unnamed_required_field(f, i, tarantool_crate, &mut var_names)
698                }
699            })
700            .collect();
701        let enum_variant = if let Some(variant) = enum_variant {
702            quote! { ::#variant }
703        } else {
704            quote! {}
705        };
706        quote! {
707            #code
708            Ok(Self #enum_variant (
709                #(#var_names),*
710            ))
711        }
712    }
713
714    fn decode_unnamed_optional_field(
715        field: &Field,
716        index: usize,
717        tarantool_crate: &Path,
718        names: &mut Vec<Ident>,
719    ) -> TokenStream {
720        let field_attr = unwrap_or_compile_error!(FieldAttr::from_field(field));
721        let field_type = &field.ty;
722
723        let field_index = Index::from(index);
724        let var_name = quote::format_ident!("_field_{}", field_index);
725
726        let out = match field_attr {
727            Some(FieldAttr::Map) => unimplemented!("`as_map` is not currently supported"),
728            Some(FieldAttr::Vec) => unimplemented!("`as_vec` is not currently supported"),
729            Some(FieldAttr::Default) => {
730                panic!("optional fields are marked `#[encode(default)]` by default")
731            }
732            Some(FieldAttr::Raw) => quote_spanned! {field.span()=>
733                let #var_name = #tarantool_crate::msgpack::preserve_read(r).expect("only valid msgpack here");
734            },
735            None => quote_spanned! {field.span()=>
736                let mut #var_name: #field_type = None;
737                match #tarantool_crate::msgpack::Decode::decode(r, context) {
738                    Ok(val) => #var_name = Some(val),
739                    Err(err) => {
740                        let markered = err.source.get(err.source.len() - 33..).unwrap_or("")== "failed to read MessagePack marker";
741                        let nulled = if err.part.is_some() {
742                            err.part.as_ref().expect("Can't fail after a conditional check") == "got Null"
743                        } else {
744                            false
745                        };
746
747                        if !nulled && !markered {
748                            Err(#tarantool_crate::msgpack::DecodeError::new::<Self>(err).with_part(format!("{}", stringify!(#field_index))))?;
749                        }
750                    },
751                }
752            },
753        };
754
755        names.push(var_name);
756        out
757    }
758
759    fn decode_unnamed_required_field(
760        field: &Field,
761        index: usize,
762        tarantool_crate: &Path,
763        names: &mut Vec<Ident>,
764    ) -> TokenStream {
765        let field_attr = unwrap_or_compile_error!(FieldAttr::from_field(field));
766
767        let field_index = Index::from(index);
768        let var_name = quote::format_ident!("_field_{}", field_index);
769
770        let out = if let Some(FieldAttr::Raw) = field_attr {
771            quote_spanned! {field.span()=>
772                let #var_name = #tarantool_crate::msgpack::preserve_read(r).expect("only valid msgpack here");
773            }
774        } else if let Some(FieldAttr::Map) = field_attr {
775            unimplemented!("`as_map` is not currently supported");
776        } else if let Some(FieldAttr::Vec) = field_attr {
777            unimplemented!("`as_vec` is not currently supported");
778        } else if let Some(FieldAttr::Default) = field_attr {
779            quote_spanned! {field.span()=>
780                let mut tmp = *r;
781                let res = #tarantool_crate::msgpack::Decode::decode(&mut tmp, context);
782                let #var_name = match res {
783                    Ok(v) => {
784                        *r = &tmp;
785                        v
786                    },
787                    Err(_) => Default::default(),
788                };
789            }
790        } else {
791            quote_spanned! {field.span()=>
792                let #var_name = #tarantool_crate::msgpack::Decode::decode(r, context)
793                    .map_err(|err| #tarantool_crate::msgpack::DecodeError::new::<Self>(err).with_part(format!("field {}", #index)))?;
794            }
795        };
796
797        names.push(var_name);
798        out
799    }
800
801    pub fn decode_fields(
802        data: &Data,
803        tarantool_crate: &Path,
804        attrs_span: impl Fn() -> SpanRange,
805        args: &Args,
806    ) -> TokenStream {
807        let as_map = args.as_map;
808        let is_untagged = args.untagged;
809
810        if is_untagged {
811            return decode_untagged(data, tarantool_crate, attrs_span);
812        }
813
814        match *data {
815            Data::Struct(ref data) => {
816                match data.fields {
817                    Fields::Named(ref fields) => {
818                        let first_field_name = fields
819                            .named
820                            .first()
821                            .expect("not a unit struct")
822                            .ident
823                            .as_ref()
824                            .expect("not an unnamed struct")
825                            .to_string();
826                        let fields = decode_named_fields(fields, tarantool_crate, None, args);
827                        quote! {
828                            let as_map = match context.struct_style() {
829                                StructStyle::Default => #as_map,
830                                StructStyle::ForceAsMap => true,
831                                StructStyle::ForceAsArray => false,
832                            };
833                            // TODO: Assert map and array len with number of struct fields
834                            if as_map {
835                                #tarantool_crate::msgpack::rmp::decode::read_map_len(r)
836                                    .map_err(|err| #tarantool_crate::msgpack::DecodeError::from_vre::<Self>(err))?;
837                            } else {
838                                #tarantool_crate::msgpack::rmp::decode::read_array_len(r)
839                                    .map_err(|err| #tarantool_crate::msgpack::DecodeError::from_vre_with_field::<Self>(err, #first_field_name))?;
840                            }
841                            #fields
842                        }
843                    }
844                    Fields::Unnamed(ref fields) => {
845                        if as_map {
846                            abort!(
847                                attrs_span(),
848                                "`as_map` attribute can be specified only for structs with named fields"
849                            );
850                        }
851
852                        let mut option_key = TokenStream::new();
853                        if fields.unnamed.len() == 1 {
854                            let first_field = fields.unnamed.first().expect("len is sufficient");
855                            let is_option = first_field.ty.is_option();
856                            if is_option {
857                                option_key = quote! {
858                                    if r.is_empty() {
859                                        return Ok(Self(None));
860                                    }
861                                };
862                            }
863                        }
864
865                        let fields = decode_unnamed_fields(fields, tarantool_crate, None, args);
866                        quote! {
867                            #option_key
868                            #tarantool_crate::msgpack::rmp::decode::read_array_len(r)
869                                .map_err(|err| #tarantool_crate::msgpack::DecodeError::from_vre::<Self>(err))?;
870                            #fields
871                        }
872                    }
873                    Fields::Unit => {
874                        quote! {
875                            let () = #tarantool_crate::msgpack::Decode::decode(r, context)?;
876                            Ok(Self)
877                        }
878                    }
879                }
880            }
881            Data::Enum(ref variants) => {
882                if as_map {
883                    abort!(
884                        attrs_span(),
885                        "`as_map` attribute can be specified only for structs"
886                    );
887                }
888                let mut variant_reprs = Vec::new();
889                let variants: proc_macro2::TokenStream = variants
890                    .variants
891                    .iter()
892                    .flat_map(|variant| {
893                        let variant_name = &variant.ident;
894                        let attr = unwrap_or_compile_error!(VariantAttr::from_variant(variant));
895                        let variant_repr = if let Some(VariantAttr::Rename(new_name)) = attr {
896                            new_name
897                        } else {
898                            format_ident!("{}", variant_name).to_string()
899                        };
900                        variant_reprs.push(variant_repr.clone());
901                        let variant_repr = proc_macro2::Literal::byte_string(variant_repr.as_bytes());
902
903                        match variant.fields {
904                            Fields::Named(ref fields) => {
905                                let fields = decode_named_fields(fields, tarantool_crate, Some(&variant.ident), args);
906                                // TODO: allow `#[encode(as_map)]` for struct variants
907                                quote! {
908                                    #variant_repr => {
909                                        #tarantool_crate::msgpack::rmp::decode::read_array_len(r)
910                                            .map_err(|err| #tarantool_crate::msgpack::DecodeError::from_vre::<Self>(err))?;
911                                        let as_map = false;
912                                        #fields
913                                    }
914                                }
915                            },
916                            Fields::Unnamed(ref fields) => {
917                                let fields = decode_unnamed_fields(fields, tarantool_crate, Some(&variant.ident), args);
918                                quote! {
919                                    #variant_repr => {
920                                        #tarantool_crate::msgpack::rmp::decode::read_array_len(r)
921                                            .map_err(|err| #tarantool_crate::msgpack::DecodeError::from_vre::<Self>(err))?;
922                                        let as_map = false;
923                                        #fields
924                                    }
925                                }
926                            }
927                            Fields::Unit => {
928                                quote! {
929                                    #variant_repr => {
930                                        let () = #tarantool_crate::msgpack::Decode::decode(r, context)
931                                            .map_err(|err| #tarantool_crate::msgpack::DecodeError::new::<Self>(err))?;
932                                        Ok(Self::#variant_name)
933                                    }
934                                }
935                            },
936                        }
937                    })
938                    .collect();
939                quote! {
940                    // TODO: assert map len 1
941                    #tarantool_crate::msgpack::rmp::decode::read_map_len(r)
942                        .map_err(|err| #tarantool_crate::msgpack::DecodeError::from_vre::<Self>(err))?;
943                    let len = rmp::decode::read_str_len(r)
944                        .map_err(|err| #tarantool_crate::msgpack::DecodeError::from_vre::<Self>(err).with_part("variant name"))?;
945                    let variant_name = r.get(0..(len as usize))
946                        .ok_or_else(|| #tarantool_crate::msgpack::DecodeError::new::<Self>("not enough data").with_part("variant name"))?;
947                    *r = &r[(len as usize)..]; // advance
948                    match variant_name {
949                        #variants
950                        other => {
951                            let err = if let Ok(other) = String::from_utf8(other.to_vec()) {
952                                format!("enum variant {} does not exist", other)
953                            } else {
954                                format!("enum variant {:?} is invalid utf8", other)
955                            };
956                            return Err(#tarantool_crate::msgpack::DecodeError::new::<Self>(err));
957                        }
958                    }
959                }
960            }
961            Data::Union(_) => unimplemented!(),
962        }
963    }
964
965    pub fn decode_untagged(
966        data: &Data,
967        tarantool_crate: &Path,
968        attrs_span: impl Fn() -> SpanRange,
969    ) -> TokenStream {
970        let out = match *data {
971            Data::Struct(_) => abort!(
972                attrs_span(),
973                "untagged decode representation is allowed only for enums"
974            ),
975            Data::Union(_) => unimplemented!(),
976            Data::Enum(ref variants) => {
977                let variants = variants.variants.iter();
978                let variants_amount = variants.len();
979                if variants_amount == 0 {
980                    abort!(
981                        attrs_span(),
982                        "deserialization of enum with no variants is not possible"
983                    );
984                }
985
986                variants
987                    .flat_map(|variant| {
988                        let variant_ident = &variant.ident;
989
990                        match variant.fields {
991                            Fields::Unit => {
992                                quote! {
993                                    // https://doc.rust-lang.org/beta/unstable-book/language-features/try-blocks.html (2016 -_-)
994                                    let mut r_try = *r;
995                                    let mut try_unit = || -> Result<(), #tarantool_crate::msgpack::DecodeError> {
996                                        let () = #tarantool_crate::msgpack::Decode::decode(&mut r_try, context)
997                                            .map_err(|err| #tarantool_crate::msgpack::DecodeError::new::<Self>(err))?;
998                                        checker = Some(Self::#variant_ident);
999                                        Ok(())
1000                                    };
1001
1002                                    if try_unit().is_ok() {
1003                                        *r = r_try;
1004                                        return Result::<Self, #tarantool_crate::msgpack::DecodeError>::Ok(checker.unwrap());
1005                                    }
1006                                }
1007                            },
1008                            Fields::Unnamed(ref fields) => {
1009                                let fields = &fields.unnamed;
1010                                let fields_amount = fields.len();
1011                                let mut var_names = Vec::with_capacity(fields.len());
1012                                let code: TokenStream = fields
1013                                    .iter()
1014                                    .enumerate()
1015                                    .map(|(index, field)| {
1016                                        let field_index = Index::from(index);
1017                                        let var_name = quote::format_ident!("_field_{}", field_index);
1018                                        let var_type = &field.ty;
1019
1020                                        let out = quote_spanned! {field.span()=>
1021                                            let #var_name: #var_type = #tarantool_crate::msgpack::Decode::decode(&mut r_try, context)
1022                                                .map_err(|err| #tarantool_crate::msgpack::DecodeError::new::<Self>(err).with_part(format!("field {}", #index)))?;
1023                                        };
1024
1025                                        var_names.push(var_name);
1026                                        out
1027                                    })
1028                                    .collect();
1029                                quote! {
1030                                    // https://doc.rust-lang.org/beta/unstable-book/language-features/try-blocks.html (2016 -_-)
1031                                    let mut r_try = *r;
1032                                    let mut try_unnamed = || -> Result<(), #tarantool_crate::msgpack::DecodeError> {
1033                                        let amount = #tarantool_crate::msgpack::rmp::decode::read_array_len(&mut r_try)
1034                                            .map_err(|err| #tarantool_crate::msgpack::DecodeError::new::<Self>(err))?;
1035                                        if amount as usize != #fields_amount {
1036                                            Err(#tarantool_crate::msgpack::DecodeError::new::<Self>("non-equal amount of type fields"))?;
1037                                        }
1038                                        #code
1039                                        checker = Some(Self::#variant_ident(
1040                                            #(#var_names),*
1041                                        ));
1042                                        Ok(())
1043                                    };
1044
1045                                    if try_unnamed().is_ok() {
1046                                        *r = r_try;
1047                                        return Result::<Self, #tarantool_crate::msgpack::DecodeError>::Ok(checker.unwrap());
1048                                    }
1049                                }
1050                            },
1051                            Fields::Named(ref fields) => {
1052                                let fields = &fields.named;
1053                                let fields_amount = fields.len();
1054                                let field_names = fields.iter().map(|field| &field.ident);
1055                                let mut var_names = Vec::with_capacity(fields.len());
1056                                let code: TokenStream = fields
1057                                    .iter()
1058                                    .map(|field| {
1059                                        let field_ident = field.ident.as_ref().expect("only named fields here");
1060                                        let var_name = format_ident!("_field_{}", field_ident);
1061                                        let var_type = &field.ty;
1062
1063                                        let out = quote_spanned! {field.span()=>
1064                                            let #var_name: #var_type = #tarantool_crate::msgpack::Decode::decode(&mut r_try, context)
1065                                                .map_err(|err| #tarantool_crate::msgpack::DecodeError::new::<Self>(err).with_part(format!("field {}", stringify!(#field_ident))))?;
1066                                        };
1067
1068                                        var_names.push(var_name);
1069                                        out
1070                                    })
1071                                    .collect();
1072                                quote! {
1073                                    // https://doc.rust-lang.org/beta/unstable-book/language-features/try-blocks.html (2016 -_-)
1074                                    let mut r_try = *r;
1075                                    let mut try_named = || -> Result<(), #tarantool_crate::msgpack::DecodeError> {
1076                                        let amount = #tarantool_crate::msgpack::rmp::decode::read_array_len(&mut r_try)
1077                                            .map_err(|err| #tarantool_crate::msgpack::DecodeError::new::<Self>(err))?;
1078                                        if amount as usize != #fields_amount {
1079                                            Err(#tarantool_crate::msgpack::DecodeError::new::<Self>("non-equal amount of type fields"))?;
1080                                        }
1081                                        #code
1082                                        checker = Some(Self::#variant_ident {
1083                                            #(#field_names: #var_names),*
1084                                        });
1085                                        Ok(())
1086                                    };
1087
1088                                    if try_named().is_ok() {
1089                                        *r = r_try;
1090                                        return Result::<Self, #tarantool_crate::msgpack::DecodeError>::Ok(checker.unwrap());
1091                                    }
1092                                }
1093                            },
1094                        }
1095                    })
1096                    .collect::<TokenStream>()
1097            }
1098        };
1099        quote! {
1100            let mut checker: Option<Self> = None;
1101            #out
1102            Result::<Self, #tarantool_crate::msgpack::DecodeError>::Err(#tarantool_crate::msgpack::DecodeError::new::<Self>("received stream didn't match any enum variant"))
1103        }
1104    }
1105}
1106
1107/// Utility function to get a span range of the attributes.
1108fn attrs_span<'a>(attrs: impl IntoIterator<Item = &'a Attribute>) -> SpanRange {
1109    SpanRange::from_tokens(
1110        &attrs
1111            .into_iter()
1112            .flat_map(ToTokens::into_token_stream)
1113            .collect::<TokenStream2>(),
1114    )
1115}
1116
1117/// Collects all lifetimes from `syn::Generic` into `syn::Punctuated` iterator
1118/// in a format like: `'a + 'b + 'c` and so on.
1119#[inline]
1120fn collect_lifetimes(generics: &syn::Generics) -> Punctuated<syn::Lifetime, Token![+]> {
1121    let mut lifetimes = Punctuated::new();
1122    let mut unique_lifetimes = std::collections::HashSet::new();
1123
1124    for param in &generics.params {
1125        if let syn::GenericParam::Lifetime(lifetime_def) = param {
1126            if unique_lifetimes.insert(lifetime_def.lifetime.clone()) {
1127                lifetimes.push(lifetime_def.lifetime.clone());
1128            }
1129        }
1130    }
1131
1132    lifetimes
1133}
1134
1135/// Macro to automatically derive `tarantool::msgpack::Encode`
1136/// Deriving this trait will make this struct encodable into msgpack format.
1137/// It is meant as a replacement for serde + rmp_serde
1138/// allowing us to customize it for tarantool case and hopefully also decreasing compile-time due to its simplicity.
1139///
1140/// For more information see `tarantool::msgpack::Encode`
1141#[proc_macro_error]
1142#[proc_macro_derive(Encode, attributes(encode))]
1143pub fn derive_encode(input: TokenStream) -> TokenStream {
1144    let input = parse_macro_input!(input as DeriveInput);
1145    let name = &input.ident;
1146
1147    // Get attribute arguments
1148    let args: msgpack::Args = darling::FromDeriveInput::from_derive_input(&input).unwrap();
1149    let tarantool_crate = args
1150        .tarantool
1151        .as_deref()
1152        .map(syn::parse_str)
1153        .transpose()
1154        .unwrap()
1155        .unwrap_or_else(default_tarantool_crate_path);
1156
1157    // Add a bound to every type parameter.
1158    let generics = msgpack::add_trait_bounds(input.generics, &tarantool_crate);
1159    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
1160    let encode_fields = msgpack::encode_fields(
1161        &input.data,
1162        &tarantool_crate,
1163        // Use a closure as the function might be costly, but is only used for errors
1164        // and we don't want to slow down compilation.
1165        || attrs_span(&input.attrs),
1166        &args,
1167    );
1168    let expanded = quote! {
1169        // The generated impl.
1170        impl #impl_generics #tarantool_crate::msgpack::Encode for #name #ty_generics #where_clause {
1171            fn encode(&self, w: &mut impl ::std::io::Write, context: &#tarantool_crate::msgpack::Context)
1172                -> std::result::Result<(), #tarantool_crate::msgpack::EncodeError>
1173            {
1174                use #tarantool_crate::msgpack::StructStyle;
1175                #encode_fields
1176                Ok(())
1177            }
1178        }
1179    };
1180
1181    expanded.into()
1182}
1183
1184/// Macro to automatically derive `tarantool::msgpack::Decode`
1185/// Deriving this trait will allow decoding this struct from msgpack format.
1186/// It is meant as a replacement for serde + rmp_serde
1187/// allowing us to customize it for tarantool case and hopefully also decreasing compile-time due to its simplicity.
1188///
1189/// For more information see `tarantool::msgpack::Decode`
1190#[proc_macro_error]
1191#[proc_macro_derive(Decode, attributes(encode))]
1192pub fn derive_decode(input: TokenStream) -> TokenStream {
1193    let input = parse_macro_input!(input as DeriveInput);
1194    let name = &input.ident;
1195
1196    // Get attribute arguments
1197    let args: msgpack::Args = darling::FromDeriveInput::from_derive_input(&input).unwrap();
1198    let tarantool_crate = args.tarantool.as_deref().unwrap_or("tarantool");
1199    let tarantool_crate = Ident::new(tarantool_crate, Span::call_site()).into();
1200
1201    // Add a bound to every type parameter.
1202    let generics = msgpack::add_trait_bounds(input.generics.clone(), &tarantool_crate);
1203    let mut impl_generics = input.generics;
1204    impl_generics.params.insert(
1205        0,
1206        syn::GenericParam::Lifetime(syn::LifetimeDef {
1207            attrs: vec![],
1208            lifetime: syn::Lifetime::new("'de", Span::call_site()),
1209            colon_token: Some(syn::token::Colon::default()),
1210            bounds: collect_lifetimes(&generics),
1211        }),
1212    );
1213    // let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
1214    let (impl_generics, _, where_clause) = impl_generics.split_for_impl();
1215    let (_, ty_generics, _) = generics.split_for_impl();
1216    let decode_fields = msgpack::decode_fields(
1217        &input.data,
1218        &tarantool_crate,
1219        // Use a closure as the function might be costly, but is only used for errors
1220        // and we don't want to slow down compilation.
1221        || attrs_span(&input.attrs),
1222        &args,
1223    );
1224    let expanded = quote! {
1225        // The generated impl.
1226        impl #impl_generics #tarantool_crate::msgpack::Decode<'de> for #name #ty_generics #where_clause {
1227            fn decode(r: &mut &'de [u8], context: &#tarantool_crate::msgpack::Context)
1228                -> std::result::Result<Self, #tarantool_crate::msgpack::DecodeError>
1229            {
1230                use #tarantool_crate::msgpack::StructStyle;
1231                #decode_fields
1232            }
1233        }
1234    };
1235
1236    expanded.into()
1237}
1238
1239/// Create a tarantool stored procedure.
1240///
1241/// See `tarantool::proc` doc-comments in tarantool crate for details.
1242#[proc_macro_attribute]
1243pub fn stored_proc(attr: TokenStream, item: TokenStream) -> TokenStream {
1244    let args = parse_macro_input!(attr as AttributeArgs);
1245    let ctx = Context::from_args(args);
1246
1247    let input = parse_macro_input!(item as Item);
1248
1249    #[rustfmt::skip]
1250    let ItemFn { vis, sig, block, attrs, .. } = match input {
1251        Item::Fn(f) => f,
1252        _ => panic!("only `fn` items can be stored procedures"),
1253    };
1254
1255    let (ident, inputs, output, generics) = match sig {
1256        Signature {
1257            asyncness: Some(_), ..
1258        } => {
1259            panic!("async stored procedures are not supported yet")
1260        }
1261        Signature {
1262            variadic: Some(_), ..
1263        } => {
1264            panic!("variadic stored procedures are not supported yet")
1265        }
1266        Signature {
1267            ident,
1268            inputs,
1269            output,
1270            generics,
1271            ..
1272        } => (ident, inputs, output, generics),
1273    };
1274
1275    let Inputs {
1276        inputs,
1277        input_pattern,
1278        input_idents,
1279        inject_inputs,
1280        n_actual_arguments,
1281    } = Inputs::parse(&ctx, inputs);
1282
1283    if ctx.is_packed && n_actual_arguments > 1 {
1284        panic!("proc with 'packed_args' can only have a single parameter")
1285    }
1286
1287    let Context {
1288        tarantool,
1289        linkme,
1290        section,
1291        debug_tuple,
1292        wrap_ret,
1293        ..
1294    } = ctx;
1295
1296    let inner_fn_name = syn::Ident::new("__tp_inner", ident.span());
1297    let desc_name = ident.to_string();
1298    let desc_ident = syn::Ident::new(&desc_name.to_uppercase(), ident.span());
1299    let mut public = matches!(vis, syn::Visibility::Public(_));
1300    if let Some(override_public) = ctx.public {
1301        public = override_public;
1302    }
1303
1304    // Only add tarantool::proc-annotated function to the distributed slice
1305    // if the `stored_procs_slice` feature is active. We need this to combat
1306    // the runtime panics introduced in linkme 0.3.1.
1307    let attrs_distributed_slice = if cfg!(feature = "stored_procs_slice") {
1308        quote! {
1309            #[#linkme::distributed_slice(#section)]
1310            #[linkme(crate = #linkme)]
1311        }
1312    } else {
1313        quote! {}
1314    };
1315
1316    quote! {
1317        #attrs_distributed_slice
1318        #[cfg(not(test))]
1319        #[allow(deprecated)]
1320        static #desc_ident: #tarantool::proc::Proc = #tarantool::proc::Proc::new(
1321            #desc_name,
1322            #ident,
1323        ).with_public(#public);
1324
1325        #(#attrs)*
1326        #[no_mangle]
1327        pub unsafe extern "C" fn #ident (
1328            __tp_ctx: #tarantool::tuple::FunctionCtx,
1329            __tp_args: #tarantool::tuple::FunctionArgs,
1330        ) -> ::std::os::raw::c_int {
1331            #debug_tuple
1332            let #input_pattern =
1333                match __tp_args.decode() {
1334                    ::std::result::Result::Ok(__tp_args) => __tp_args,
1335                    ::std::result::Result::Err(__tp_err) => {
1336                        #tarantool::set_error!(
1337                            #tarantool::error::TarantoolErrorCode::ProcC,
1338                            "{}",
1339                            __tp_err
1340                        );
1341                        return -1;
1342                    }
1343                };
1344
1345            #inject_inputs
1346
1347            fn #inner_fn_name #generics (#inputs) #output {
1348                #block
1349            }
1350
1351            let __tp_res = __tp_inner(#(#input_idents),*);
1352
1353            #wrap_ret
1354
1355            #tarantool::proc::Return::ret(__tp_res, __tp_ctx)
1356        }
1357    }
1358    .into()
1359}
1360
1361struct Context {
1362    tarantool: syn::Path,
1363    section: syn::Path,
1364    linkme: syn::Path,
1365    debug_tuple: TokenStream2,
1366    is_packed: bool,
1367    public: Option<bool>,
1368    wrap_ret: TokenStream2,
1369}
1370
1371impl Context {
1372    fn from_args(args: AttributeArgs) -> Self {
1373        let mut tarantool: syn::Path = default_tarantool_crate_path();
1374        let mut linkme = None;
1375        let mut section = None;
1376        let mut debug_tuple_needed = false;
1377        let mut is_packed = false;
1378        let mut public = None;
1379        let mut wrap_ret = quote! {};
1380
1381        for arg in args {
1382            if let Some(path) = imp::parse_lit_str_with_key(&arg, "tarantool") {
1383                tarantool = path;
1384                continue;
1385            }
1386            if let Some(path) = imp::parse_lit_str_with_key(&arg, "linkme") {
1387                linkme = Some(path);
1388                continue;
1389            }
1390            if let Some(path) = imp::parse_lit_str_with_key(&arg, "section") {
1391                section = Some(path);
1392                continue;
1393            }
1394            if imp::is_path_eq_to(&arg, "custom_ret") {
1395                wrap_ret = quote! {
1396                    let __tp_res = #tarantool::proc::ReturnMsgpack(__tp_res);
1397                };
1398                continue;
1399            }
1400            if imp::is_path_eq_to(&arg, "packed_args") {
1401                is_packed = true;
1402                continue;
1403            }
1404            if imp::is_path_eq_to(&arg, "debug") {
1405                debug_tuple_needed = true;
1406                continue;
1407            }
1408            if let Some(v) = imp::parse_bool_with_key(&arg, "public") {
1409                public = Some(v);
1410                continue;
1411            }
1412            panic!("unsuported attribute argument `{}`", quote!(#arg))
1413        }
1414
1415        let section = section.unwrap_or_else(|| {
1416            imp::path_from_ts2(quote! { #tarantool::proc::TARANTOOL_MODULE_STORED_PROCS })
1417        });
1418        let linkme = linkme.unwrap_or_else(|| imp::path_from_ts2(quote! { #tarantool::linkme }));
1419
1420        let debug_tuple = if debug_tuple_needed {
1421            quote! {
1422                ::std::dbg!(#tarantool::tuple::Tuple::from(&__tp_args));
1423            }
1424        } else {
1425            quote! {}
1426        };
1427        Self {
1428            tarantool,
1429            linkme,
1430            section,
1431            debug_tuple,
1432            is_packed,
1433            wrap_ret,
1434            public,
1435        }
1436    }
1437}
1438
1439struct Inputs {
1440    inputs: Punctuated<FnArg, Token![,]>,
1441    input_pattern: TokenStream2,
1442    input_idents: Vec<syn::Pat>,
1443    inject_inputs: TokenStream2,
1444    n_actual_arguments: usize,
1445}
1446
1447impl Inputs {
1448    fn parse(ctx: &Context, mut inputs: Punctuated<FnArg, Token![,]>) -> Self {
1449        let mut input_idents = vec![];
1450        let mut actual_inputs = vec![];
1451        let mut injected_inputs = vec![];
1452        let mut injected_exprs = vec![];
1453        for i in &mut inputs {
1454            let syn::PatType {
1455                ref pat,
1456                ref mut attrs,
1457                ..
1458            } = match i {
1459                FnArg::Receiver(_) => {
1460                    panic!("`self` receivers aren't supported in stored procedures")
1461                }
1462                FnArg::Typed(pat_ty) => pat_ty,
1463            };
1464            let mut inject_expr = None;
1465            attrs.retain(|attr| {
1466                let path = &attr.path;
1467                if path.is_ident("inject") {
1468                    match attr.parse_args() {
1469                        Ok(AttrInject { expr, .. }) => {
1470                            inject_expr = Some(expr);
1471                            false
1472                        }
1473                        Err(e) => panic!("attribute argument error: {e}"),
1474                    }
1475                } else {
1476                    // Skip doc comments as they are not allowed for inner functions
1477                    !path.is_ident("doc")
1478                }
1479            });
1480            if let Some(expr) = inject_expr {
1481                injected_inputs.push(pat.clone());
1482                injected_exprs.push(expr);
1483            } else {
1484                actual_inputs.push(pat.clone());
1485            }
1486            input_idents.push((**pat).clone());
1487        }
1488
1489        let input_pattern = if inputs.is_empty() {
1490            quote! { []: [(); 0] }
1491        } else if ctx.is_packed {
1492            quote! { #(#actual_inputs)* }
1493        } else {
1494            quote! { ( #(#actual_inputs,)* ) }
1495        };
1496
1497        let inject_inputs = quote! {
1498            #( let #injected_inputs = #injected_exprs; )*
1499        };
1500
1501        Self {
1502            inputs,
1503            input_pattern,
1504            input_idents,
1505            inject_inputs,
1506            n_actual_arguments: actual_inputs.len(),
1507        }
1508    }
1509}
1510
1511#[derive(Debug)]
1512struct AttrInject {
1513    expr: syn::Expr,
1514}
1515
1516impl syn::parse::Parse for AttrInject {
1517    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
1518        Ok(AttrInject {
1519            expr: input.parse()?,
1520        })
1521    }
1522}
1523
1524mod kw {
1525    syn::custom_keyword! {inject}
1526}
1527
1528mod imp {
1529    use proc_macro2::{Group, Span, TokenStream, TokenTree};
1530    use syn::parse::{self, Parse};
1531
1532    #[track_caller]
1533    pub(crate) fn parse_lit_str_with_key<T>(nm: &syn::NestedMeta, key: &str) -> Option<T>
1534    where
1535        T: Parse,
1536    {
1537        match nm {
1538            syn::NestedMeta::Meta(syn::Meta::NameValue(syn::MetaNameValue {
1539                path, lit, ..
1540            })) if path.is_ident(key) => match &lit {
1541                syn::Lit::Str(s) => Some(crate::imp::parse_lit_str(s).unwrap()),
1542                _ => panic!("{key} value must be a string literal"),
1543            },
1544            _ => None,
1545        }
1546    }
1547
1548    #[track_caller]
1549    pub(crate) fn parse_bool_with_key(nm: &syn::NestedMeta, key: &str) -> Option<bool> {
1550        match nm {
1551            syn::NestedMeta::Meta(syn::Meta::NameValue(syn::MetaNameValue {
1552                path, lit, ..
1553            })) if path.is_ident(key) => match &lit {
1554                syn::Lit::Bool(b) => Some(b.value),
1555                _ => panic!("value for attribute '{key}' must be a bool literal (true | false)"),
1556            },
1557            syn::NestedMeta::Meta(syn::Meta::Path(path)) if path.is_ident(key) => {
1558                panic!("expected ({key} = true|false), got just {key}");
1559            }
1560            _ => None,
1561        }
1562    }
1563
1564    #[track_caller]
1565    pub(crate) fn is_path_eq_to(nm: &syn::NestedMeta, expected: &str) -> bool {
1566        matches!(
1567            nm,
1568            syn::NestedMeta::Meta(syn::Meta::Path(path)) if path.is_ident(expected)
1569        )
1570    }
1571
1572    pub(crate) fn path_from_ts2(ts: TokenStream) -> syn::Path {
1573        syn::parse2(ts).unwrap()
1574    }
1575
1576    // stolen from serde
1577
1578    pub(crate) fn parse_lit_str<T>(s: &syn::LitStr) -> parse::Result<T>
1579    where
1580        T: Parse,
1581    {
1582        let tokens = spanned_tokens(s)?;
1583        syn::parse2(tokens)
1584    }
1585
1586    fn spanned_tokens(s: &syn::LitStr) -> parse::Result<TokenStream> {
1587        let stream = syn::parse_str(&s.value())?;
1588        Ok(respan(stream, s.span()))
1589    }
1590
1591    fn respan(stream: TokenStream, span: Span) -> TokenStream {
1592        stream
1593            .into_iter()
1594            .map(|token| respan_token(token, span))
1595            .collect()
1596    }
1597
1598    fn respan_token(mut token: TokenTree, span: Span) -> TokenTree {
1599        if let TokenTree::Group(g) = &mut token {
1600            *g = Group::new(g.delimiter(), respan(g.stream(), span));
1601        }
1602        token.set_span(span);
1603        token
1604    }
1605}