neuro_sama_derive/
lib.rs

1use proc_macro2::{Group, Span, TokenStream, TokenTree};
2use quote::{quote, ToTokens};
3use syn::{spanned::Spanned, token::Mut, Data, DeriveInput, Fields, Ident, Item, Path};
4
5fn derive_actions2(input: TokenStream) -> TokenStream {
6    let data: DeriveInput = syn::parse2(input).unwrap();
7    let name = data.ident;
8    let Data::Enum(data) = data.data else {
9        panic!("#[derive(Actions)] is only supported on enums")
10    };
11    let mut ret = TokenStream::new();
12    let mut ret1 = TokenStream::new();
13    let mut meta = TokenStream::new();
14    let mut names = TokenStream::new();
15    for variant in data.variants {
16        let field = match variant.fields {
17            Fields::Unit => None,
18            Fields::Unnamed(a) => {
19                if a.unnamed.len() > 1 {
20                    panic!(
21                        "#[derive(Actions)] doesn't support enum variants with more than one field"
22                    );
23                }
24                a.unnamed.into_iter().next()
25            }
26            Fields::Named(_) => panic!("#[derive(Actions)] doesn't support named fields"),
27        };
28        if let Some(field) = field {
29            let ty = field.ty;
30            let ident = variant.ident;
31            let mut desc = String::new();
32            let mut name = None;
33            for attr in variant.attrs {
34                match attr.meta.path().to_token_stream().to_string().as_str() {
35                    "doc" => {
36                        let x = attr.meta.require_name_value().unwrap();
37                        match &x.value {
38                            syn::Expr::Lit(lit) => match &lit.lit {
39                                syn::Lit::Str(s) => {
40                                    if !desc.is_empty() {
41                                        desc.push('\n');
42                                    }
43                                    desc += s.value().trim();
44                                }
45                                _ => panic!("doc comment value is not a string literal???"),
46                            },
47                            _ => panic!("doc comment value is not a string literal???"),
48                        }
49                    }
50                    "name" => {
51                        let x = attr.meta.require_name_value().unwrap();
52                        name = Some(x.value.clone());
53                    }
54                    _ => {}
55                }
56            }
57            if desc.is_empty() {
58                panic!("expected variant {} to have a doc comment", ident)
59            }
60            let name = name
61                .ok_or_else(|| {
62                    panic!(
63                        "expected variant {} to have a #[name = ...] attribute",
64                        ident
65                    )
66                })
67                .unwrap();
68            ret.extend(quote! {
69                impl neuro_sama::game::Action for #ty {
70                    fn name() -> &'static str {
71                        #name
72                    }
73                    fn description() -> &'static str {
74                        #desc.trim()
75                    }
76                }
77            });
78            ret1.extend(quote! {
79                #name => <#ty as neuro_sama::serde::Deserialize<'_>>::deserialize(de).map(Self::#ident),
80            });
81            meta.extend(quote! {
82                neuro_sama::schema::Action {
83                    name: #name.into(),
84                    description: #desc.trim().into(),
85                    schema: neuro_sama::schemars::schema_for!(#ty),
86                },
87            });
88            names.extend(quote! { #name.into(), });
89        } else {
90            panic!("#[derive(Actions)] doesn't support empty variants, since each variant has to be a separate type as well");
91        }
92    }
93    ret.extend(quote! {
94        impl<'de> neuro_sama::game::Actions<'de> for #name where Self: 'de  {
95            fn deserialize<D: neuro_sama::serde::Deserializer<'de>>(discriminant: &str, de: D) -> Result<Self, D::Error> {
96                use neuro_sama::serde::de::Error as _;
97                match discriminant {
98                    #ret1
99                    _ => Err(D::Error::custom(format!("unexpected action: `{discriminant}`"))),
100                }
101            }
102        }
103        impl neuro_sama::game::ActionMetadata for #name {
104            fn actions() -> Vec<neuro_sama::schema::Action> {
105                vec![#meta]
106            }
107            fn names() -> Vec<std::borrow::Cow<'static, str>> {
108                vec![#names]
109            }
110        }
111    });
112    ret
113}
114
115fn generic_mutability2(attr: TokenStream, input: TokenStream) -> TokenStream {
116    let inp: Item = syn::parse2(input).unwrap();
117    let mut attr = attr.into_iter();
118    let ident = Ident::new(&attr.next().unwrap().to_string(), Span::call_site());
119    let (ident, out) = match &inp {
120        Item::Struct(inp) => {
121            let mut out = inp.clone();
122            let ident2 = Ident::new(&attr.nth(1).unwrap().to_string(), Span::call_site());
123            match out
124                .generics
125                .type_params_mut()
126                .next()
127                .unwrap()
128                .bounds
129                .first_mut()
130                .unwrap()
131            {
132                syn::TypeParamBound::Trait(tr) => {
133                    tr.path.segments.first_mut().unwrap().ident = ident2
134                }
135                _ => panic!(),
136            }
137            out.ident = ident;
138            (Some(inp.ident.clone()), out.to_token_stream())
139        }
140        Item::Impl(inp) => {
141            let mut out = inp.clone();
142            let ident2 = Ident::new(&attr.nth(1).unwrap().to_string(), Span::call_site());
143            match out
144                .generics
145                .type_params_mut()
146                .next()
147                .unwrap()
148                .bounds
149                .first_mut()
150                .unwrap()
151            {
152                syn::TypeParamBound::Trait(tr) => {
153                    tr.path.segments.first_mut().unwrap().ident = ident2
154                }
155                _ => panic!(),
156            }
157            match &mut *out.self_ty {
158                syn::Type::Path(x) => {
159                    let seg = x.path.segments.first_mut().unwrap();
160                    seg.ident = ident;
161                }
162                _ => panic!(),
163            }
164            (None, out.to_token_stream())
165        }
166        Item::Trait(inp) => {
167            let mut out = inp.clone();
168            out.ident = ident;
169            out.attrs.retain(|x| {
170                !matches!(
171                    x.path().to_token_stream().to_string().as_str(),
172                    "generic_mutability" | "doc",
173                )
174            });
175            if attr.next().is_some() {
176                if let syn::TypeParamBound::Trait(t) = out.supertraits.first_mut().unwrap() {
177                    t.path =
178                        Path::from(Ident::new(&attr.next().unwrap().to_string(), t.path.span()))
179                }
180            }
181            for item in &mut out.items {
182                if let syn::TraitItem::Fn(x) = item {
183                    if let Some(arg) = x.sig.inputs.first_mut() {
184                        match arg {
185                            syn::FnArg::Receiver(x)
186                                if x.mutability.is_none() && x.reference.is_some() =>
187                            {
188                                x.mutability = Some(Mut {
189                                    span: x.reference.as_ref().unwrap().0.span,
190                                });
191                                if let syn::Type::Reference(x) = &mut *x.ty {
192                                    x.mutability = Some(Mut {
193                                        span: x.elem.span(),
194                                    })
195                                }
196                            }
197                            _ => {}
198                        }
199                    }
200                    // panic!("{}", x.to_token_stream().to_string());
201                }
202            }
203            (Some(inp.ident.clone()), out.to_token_stream())
204        }
205        _ => panic!(),
206    };
207    fn hack_stream(x: TokenStream) -> TokenStream {
208        x.into_iter().map(hack_tree).collect()
209    }
210    fn hack_tree(x: TokenTree) -> TokenTree {
211        match x {
212            TokenTree::Group(x) => {
213                let del = x.delimiter();
214                let st = x.stream();
215                let st = hack_stream(st);
216                TokenTree::Group(Group::new(del, st))
217            }
218            TokenTree::Ident(ref ident) => match ident.to_string().as_str() {
219                "ForceActionsBuilder" => {
220                    TokenTree::Ident(Ident::new("ForceActionsBuilderMut", Span::call_site()))
221                }
222                "send_ws_command" => {
223                    TokenTree::Ident(Ident::new("send_ws_command_mut", Span::call_site()))
224                }
225                _ => x,
226            },
227            TokenTree::Punct(_) => x,
228            TokenTree::Literal(_) => x,
229        }
230    }
231    let out = hack_stream(out);
232
233    if let Some(ident) = ident {
234        let doc = format!(
235            "A mutable version of [`{}`]. See [`{}`] docs for examples.",
236            ident, ident
237        );
238        quote! {
239            #[doc = #doc]
240            #out
241            #inp
242        }
243    } else {
244        quote! {
245            #out
246            #inp
247        }
248    }
249}
250
251/// See the `neuro_sama` crate for more info.
252#[proc_macro_derive(Actions, attributes(name))]
253pub fn derive_actions(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
254    derive_actions2(input.into()).into()
255}
256
257#[proc_macro_attribute]
258#[doc(hidden)]
259pub fn generic_mutability(
260    attr: proc_macro::TokenStream,
261    input: proc_macro::TokenStream,
262) -> proc_macro::TokenStream {
263    generic_mutability2(attr.into(), input.into()).into()
264}