// accepts_codegen/acceptor/next_acceptors_auto_impl/expand.rs

use proc_macro2::TokenStream;
use quote::ToTokens;
use syn::ext::IdentExt;
use syn::{
    Block, Data, DeriveInput, Field, Fields, GenericArgument, Ident, Lifetime, Meta, PathSegment,
    Type,
    parse::{Parse, ParseStream},
    parse2,
    punctuated::Punctuated,
    spanned::Spanned,
    token::{Comma, PathSep},
};

use crate::{
    acceptor::common::ast::next_acceptors_trait_ast::PartialNextAcceptorsTraitImpl,
    common::{context::CodegenContext, syn::ast::tokens::PathSplitLastArgs},
};

/// Options accepted inside a `#[next_acceptor(...)]` field attribute.
///
/// Every flag starts out `false` and is switched on by the keyword of the same name
/// (`mut_` / `ref_` are set by the `mut` / `ref` keywords).
#[derive(Clone, Debug, Default, Eq, PartialEq)]
struct NextAcceptorOptions {
    /// `once`: the single marked field is yielded through `core::iter::once`.
    once: bool,
    /// `option_once`: the single marked field is an `Option<T>` and is yielded through
    /// `Option::iter` / `Option::iter_mut`.
    option_once: bool,
    /// `mut`: generate the impl that hands out `&mut` references to the acceptors.
    mut_: bool,
    /// `ref`: generate the impl that hands out shared references; `expand` also falls back
    /// to this when neither `ref` nor `mut` is given.
    ref_: bool,
}

26impl Parse for NextAcceptorOptions {
27    fn parse(input: ParseStream) -> syn::Result<Self> {
28        let mut opts = Self::default();
29        while !input.is_empty() {
30            let ident: Ident = input.call(Ident::parse_any)?;
31            match &*ident.to_string() {
32                "once" => opts.once = true,
33                "option_once" => opts.option_once = true,
34                "mut" => opts.mut_ = true,
35                "ref" => opts.ref_ = true,
36                other => {
37                    return Err(syn::Error::new(
38                        ident.span(),
39                        format!("unknown option `{}`", other),
40                    ));
41                }
42            }
43            if input.peek(Comma) {
44                let _ = input.parse::<Comma>();
45            }
46        }
47        Ok(opts)
48    }
49}
50
51fn option_inner_type(ty: &Type) -> Option<Type> {
52    if let Type::Path(p) = ty {
53        if let Some(seg) = p.path.segments.last() {
54            if seg.ident == "Option" {
55                if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
56                    if let Some(first) = args.args.first() {
57                        if let GenericArgument::Type(inner) = first {
58                            return Some(inner.clone());
59                        }
60                    }
61                }
62            }
63        }
64    }
65    None
66}
67
68pub fn expand(ctx: &CodegenContext, item: TokenStream) -> TokenStream {
69    let input: DeriveInput = match parse2(item) {
70        Ok(i) => i,
71        Err(e) => return e.to_compile_error(),
72    };
73
74    let fields = match &input.data {
75        Data::Struct(s) => match &s.fields {
76            Fields::Named(f) => &f.named,
77            _ => {
78                return syn::Error::new(
79                    s.struct_token.span(),
80                    "NextAcceptors can only be derived for structs with named fields",
81                )
82                .to_compile_error();
83            }
84        },
85        _ => {
86            return syn::Error::new(
87                input.ident.span(),
88                "NextAcceptors can only be derived for structs",
89            )
90            .to_compile_error();
91        }
92    };
93
94    let mut next_fields = Vec::new();
95    let mut opts: Option<NextAcceptorOptions> = None;
96    for field in fields.iter() {
97        for attr in field
98            .attrs
99            .iter()
100            .filter(|a| a.path().is_ident("next_acceptor"))
101        {
102            let this_opts = match attr.meta.clone() {
103                Meta::Path(_) => NextAcceptorOptions::default(),
104                Meta::List(list) => match syn::parse2::<NextAcceptorOptions>(list.tokens) {
105                    Ok(o) => o,
106                    Err(e) => return e.to_compile_error(),
107                },
108                Meta::NameValue(_) => {
109                    return syn::Error::new(attr.span(), "unsupported attribute format")
110                        .to_compile_error();
111                }
112            };
113            if let Some(existing) = &opts {
114                if existing != &this_opts {
115                    return syn::Error::new(attr.span(), "conflicting #[next_acceptor] options")
116                        .to_compile_error();
117                }
118            } else {
119                opts = Some(this_opts);
120            }
121            if let Some(id) = field.ident.clone() {
122                next_fields.push((id, field.ty.clone()));
123            }
124        }
125    }
126
127    if next_fields.is_empty() {
128        return syn::Error::new(input.ident.span(), "no field with #[next_acceptor] found")
129            .to_compile_error();
130    }
131
132    let mut options = opts.unwrap_or_default();
133    if !options.mut_ && !options.ref_ {
134        options.ref_ = true;
135    }
136    if options.once && options.option_once {
137        return syn::Error::new(
138            input.ident.span(),
139            "conflicting options: once and option_once",
140        )
141        .to_compile_error();
142    }
143
144    let iter_len = next_fields.len();
145    if (options.once || options.option_once) && iter_len != 1 {
146        return syn::Error::new(
147            input.ident.span(),
148            "options once/option_once require exactly one #[next_acceptor] field",
149        )
150        .to_compile_error();
151    }
152
153    let mut acceptor_type = next_fields[0].1.clone();
154    if options.option_once {
155        acceptor_type = match option_inner_type(&acceptor_type) {
156            Some(t) => t,
157            None => {
158                return syn::Error::new(
159                    acceptor_type.span(),
160                    "field with option_once must be Option<T>",
161                )
162                .to_compile_error();
163            }
164        };
165    } else if !next_fields.iter().all(|(_, ty)| *ty == acceptor_type) {
166        return syn::Error::new(
167            next_fields[0].1.span(),
168            "all #[next_acceptor] fields must have the same type",
169        )
170        .to_compile_error();
171    }
172
173    let field_idents: Vec<Ident> = next_fields.into_iter().map(|(id, _)| id).collect();
174    let iter_lifetime: Lifetime = syn::parse_quote!('a);
175
176    let mut impls = Vec::new();
177
178    if options.ref_ {
179        let (iter_type, next_block): (Type, Block) = if options.once {
180            let ident = &field_idents[0];
181            (
182                syn::parse_quote!(core::iter::Once<&'a #acceptor_type>),
183                syn::parse_quote!({ core::iter::once(&self.#ident) }),
184            )
185        } else if options.option_once {
186            let ident = &field_idents[0];
187            (
188                syn::parse_quote!(core::option::Iter<'a, #acceptor_type>),
189                syn::parse_quote!({ self.#ident.iter() }),
190            )
191        } else {
192            (
193                syn::parse_quote!(core::array::IntoIter<&'a #acceptor_type, #iter_len>),
194                syn::parse_quote!({ [#(&self.#field_idents),*].into_iter() }),
195            )
196        };
197
198        let partial = PartialNextAcceptorsTraitImpl::from_types(
199            acceptor_type.clone(),
200            iter_lifetime.clone(),
201            iter_type,
202            false,
203            next_block,
204        );
205        let self_ty_path = PathSplitLastArgs::from_parts(
206            None,
207            Punctuated::<PathSegment, PathSep>::new(),
208            input.ident.clone(),
209        );
210        impls.push(partial.into_item_impl_from_path(ctx, self_ty_path, input.generics.clone()));
211    }
212
213    if options.mut_ {
214        let (iter_type, next_block): (Type, Block) = if options.once {
215            let ident = &field_idents[0];
216            (
217                syn::parse_quote!(core::iter::Once<&'a mut #acceptor_type>),
218                syn::parse_quote!({ core::iter::once(&mut self.#ident) }),
219            )
220        } else if options.option_once {
221            let ident = &field_idents[0];
222            (
223                syn::parse_quote!(core::option::IterMut<'a, #acceptor_type>),
224                syn::parse_quote!({ self.#ident.iter_mut() }),
225            )
226        } else {
227            (
228                syn::parse_quote!(core::array::IntoIter<&'a mut #acceptor_type, #iter_len>),
229                syn::parse_quote!({ [#(&mut self.#field_idents),*].into_iter() }),
230            )
231        };
232
233        let partial = PartialNextAcceptorsTraitImpl::from_types(
234            acceptor_type,
235            iter_lifetime,
236            iter_type,
237            true,
238            next_block,
239        );
240        let self_ty_path = PathSplitLastArgs::from_parts(
241            None,
242            Punctuated::<PathSegment, PathSep>::new(),
243            input.ident.clone(),
244        );
245        impls.push(partial.into_item_impl_from_path(ctx, self_ty_path, input.generics.clone()));
246    }
247
248    let mut tokens = TokenStream::new();
249    for item in impls {
250        tokens.extend(item.into_token_stream());
251    }
252    tokens
253}