Skip to main content

nexosim_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{ToTokens, quote, quote_token};
3use syn::{
4    Expr, ExprPath, FnArg, Generics, Ident, ImplItem, ImplItemFn, ItemType, Meta, Path,
5    PathArguments, PathSegment, Signature, Token, Type, TypeTuple,
6    punctuated::Punctuated,
7    spanned::Spanned,
8    token::{Paren, PathSep},
9};
10
/// Argument of `#[nexosim(...)]` that marks the model's initialization method.
const INIT_ATTR: &str = "init";
/// Argument of `#[nexosim(...)]` that marks a schedulable input method.
const SCHEDULABLE_ATTR: &str = "schedulable";
/// Every argument accepted by the `#[nexosim(...)]` method attribute.
const AVAILABLE_ATTRS: &[&str] = &[INIT_ATTR, SCHEDULABLE_ATTR];

/// Unwraps a `Result<T, syn::Error>`, or returns the error rendered as a
/// compile error from the enclosing function.
///
/// Only usable inside functions returning `proc_macro::TokenStream` (the
/// proc-macro entry points).
macro_rules! handle_parse_result {
    ($call:expr) => {
        match $call {
            Ok(data) => data,
            // Use the public `proc_macro` type instead of syn's doc-hidden
            // `__private` re-export, which is not part of syn's stable API.
            Err(err) => return ::proc_macro::TokenStream::from(err.to_compile_error()),
        }
    };
}
23
24#[proc_macro_attribute]
25pub fn __erase(_: TokenStream, _: TokenStream) -> TokenStream {
26    <_>::default()
27}
28
29/// A helper macro that enables schema generation for the server endpoint
30/// data.
31#[proc_macro_derive(Message)]
32pub fn message_derive(input: TokenStream) -> TokenStream {
33    [
34        stringify!(
35            #[
36                ::core::prelude::v1::derive(
37                    ::nexosim::JsonSchema
38                )
39            ]
40            #[schemars(crate = "nexosim::schemars")]
41            #[::nexosim::nexosim_macros::__erase]
42        ),
43        &input.to_string(),
44    ]
45    .concat()
46    .parse()
47    .unwrap()
48}
49
50#[proc_macro]
51pub fn schedulable(input: TokenStream) -> TokenStream {
52    let ast = handle_parse_result!(syn::parse(input));
53    impl_schedulable(&ast).unwrap_or_else(|e| e.to_compile_error().into())
54}
55
56fn impl_schedulable(ast: &Path) -> Result<TokenStream, syn::Error> {
57    if ast.segments.len() != 2 {
58        return Err(syn::Error::new_spanned(
59            ast,
60            "invalid associated method path",
61        ));
62    }
63
64    let ty = ast.segments[0].clone();
65    let hidden_name = Ident::new(&format!("__{}", ast.segments[1].ident), ast.span());
66
67    let mut segments = ast.segments.clone();
68
69    segments[1].ident = hidden_name.clone();
70    let path = Path {
71        leading_colon: None,
72        segments,
73    };
74
75    let err_name = ast.segments[1].ident.to_string();
76    // Argument formatting not possible in the const context as of Rust >= 1.87
77    let err_msg = format!(
78        "method `{err_name}` is not a valid schedulable input for the model! Perhaps you forgot to include the #[nexosim(schedulable)] attribute or are using a method from another model."
79    );
80
81    let is_generic = matches!(ty.arguments, PathArguments::AngleBracketed(_));
82
83    // Generic types cannot be used in a const context. Therefore we are not
84    // able to use our custom error message.
85    let tokens = if !is_generic {
86        quote! {
87            {
88                // Call a hidden method in the array type definition to cast a
89                // custom error during a type-checking compilation phase.
90                let _: [(); { if !#ty::____is_schedulable(stringify!(#hidden_name)) {
91                    panic!(#err_msg)
92                }; 0} ] = [];
93                &#path
94            }
95        }
96    } else {
97        quote! {&#path}
98    };
99    Ok(tokens.into())
100}
101
102#[allow(non_snake_case)]
103#[proc_macro_attribute]
104pub fn Model(attr: TokenStream, input: TokenStream) -> TokenStream {
105    let mut ast: syn::ItemImpl = handle_parse_result!(syn::parse(input.clone()));
106    let env = handle_parse_result!(parse_env(attr));
107    let added_tokens = handle_parse_result!(impl_model(&mut ast, env));
108
109    let mut output: TokenStream = ast.to_token_stream().into();
110    output.extend(added_tokens);
111    output
112}
113
114fn impl_model(ast: &mut syn::ItemImpl, env: ItemType) -> Result<TokenStream, syn::Error> {
115    let name = &ast.self_ty;
116
117    let (init, schedulables) = parse_tagged_methods(&mut ast.items)?;
118
119    let registered_methods = get_registered_method_paths(&schedulables);
120    let mut tokens = get_impl_model_trait(name, &env, &ast.generics, init, registered_methods);
121    let hidden_methods = get_hidden_method_impls(&schedulables);
122
123    // We do not use ty_generics as they're already present in `name`
124    let (impl_generics, _, where_clause) = ast.generics.split_for_impl();
125
126    // Write hidden methods block.
127    tokens.extend(quote! {
128        impl #impl_generics #name #where_clause {
129            #( #hidden_methods )*
130        }
131    });
132
133    Ok(tokens.into())
134}
135
136/// Checks whether Env type is provided by the user.
137/// If not uses `()` as a default.
138fn parse_env(tokens: TokenStream) -> Result<ItemType, syn::Error> {
139    if tokens.is_empty() {
140        // No tokens found -> generate `type Env=();`.
141        let span = proc_macro2::Span::call_site();
142        return Ok(ItemType {
143            attrs: vec![],
144            vis: syn::Visibility::Inherited,
145            type_token: Token![type](span),
146            ident: Ident::new("Env", span),
147            generics: Generics::default(),
148            eq_token: Token![=](span),
149            ty: Box::new(Type::Tuple(TypeTuple {
150                paren_token: Paren(span),
151                elems: Punctuated::new(),
152            })),
153            semi_token: Token![;](span),
154        });
155    }
156
157    // Append semicolon at the end of the found token stream.
158    let mut with_semicolon = tokens.clone().into();
159    quote_token!(; with_semicolon);
160    syn::parse(with_semicolon.into())
161}
162
163/// Get MyModel::input method paths from scheduled inputs.
164fn get_registered_method_paths<'a>(
165    schedulables: &'a [ImplItemFn],
166) -> impl Iterator<Item = Expr> + use<'a> {
167    schedulables.iter().map(|a| {
168        let mut segments = Punctuated::new();
169        segments.push_value(PathSegment {
170            ident: Ident::new("Self", a.span()),
171            arguments: syn::PathArguments::None,
172        });
173        segments.push_punct(PathSep::default());
174        segments.push_value(PathSegment {
175            ident: a.sig.ident.clone(),
176            arguments: syn::PathArguments::None,
177        });
178        Expr::Path(ExprPath {
179            path: Path {
180                leading_colon: None,
181                segments,
182            },
183            attrs: Vec::new(),
184            qself: None,
185        })
186    })
187}
188
189/// Finds methods tagged as `init` or `schedulable`.
190/// Clears found tags from the original token stream.
191#[allow(clippy::type_complexity)]
192fn parse_tagged_methods(
193    items: &mut [ImplItem],
194) -> Result<(Option<proc_macro2::TokenStream>, Vec<ImplItemFn>), syn::Error> {
195    let mut init = None;
196    let mut schedulables = Vec::new();
197
198    // Find tagged methods.
199    for item in items.iter_mut() {
200        if let ImplItem::Fn(f) = item {
201            let attrs = collect_nexosim_attributes(f)?;
202            if attrs.contains(&SCHEDULABLE_ATTR) {
203                schedulables.push(f.clone());
204            }
205            if attrs.contains(&INIT_ATTR) {
206                init = Some(init_fn(&f.sig)?);
207            }
208        }
209    }
210
211    // Wrap init tokens into an Option for conditional rendering.
212    let init = init.and_then(|init| {
213        quote! {
214            fn init(
215                mut self, cx: &nexosim::model::Context<Self>, env: &mut Self::Env,
216            ) -> impl std::future::Future<Output = nexosim::model::InitializedModel<Self>> + Send {
217                async move { #init.await; self.into() }
218            }
219        }
220        .into()
221    });
222
223    Ok((init, schedulables))
224}
225
226/// Renders the impl Model for MyModel block.
227fn get_impl_model_trait(
228    name: &Type,
229    env: &ItemType,
230    generics: &Generics,
231    init: Option<proc_macro2::TokenStream>,
232    registered_methods: impl Iterator<Item = Expr>,
233) -> proc_macro2::TokenStream {
234    // We do not use ty_generics as they're already present in `name`
235    let (impl_generics, _, where_clause) = generics.split_for_impl();
236
237    quote! {
238        #[automatically_derived]
239        impl #impl_generics nexosim::model::Model for #name #where_clause {
240            #env
241
242            fn register_schedulables(
243                cx: &mut nexosim::model::BuildContext<impl nexosim::model::ProtoModel<Model = Self>>
244            ) -> nexosim::model::ModelRegistry {
245                let mut registry = nexosim::model::ModelRegistry::default();
246                #(
247                    registry.add(cx.register_schedulable(#registered_methods));
248                )*
249                registry
250            }
251
252            #init
253        }
254    }
255}
256
257/// Renders MyModel::__input associated consts.
258fn get_hidden_method_impls(schedulables: &[ImplItemFn]) -> Vec<proc_macro2::TokenStream> {
259    let mut hidden_methods = Vec::new();
260    let mut registered_schedulables = Vec::new();
261
262    for (i, func) in schedulables.iter().enumerate() {
263        let fname = Ident::new(&format!("__{}", func.sig.ident), func.sig.ident.span());
264
265        // Find argument type token.
266        let ty = func
267            .sig
268            .inputs
269            .iter()
270            .filter_map(|a| {
271                if let FnArg::Typed(t) = a {
272                    Some(t)
273                } else {
274                    None
275                }
276            })
277            .map(|a| a.ty.clone())
278            .next();
279
280        // If no arg is provided, construct a unit type.
281        let ty = match ty {
282            Some(t) => t,
283            None => Box::new(Type::Tuple(TypeTuple {
284                paren_token: Paren(func.sig.span()),
285                elems: Punctuated::new(),
286            })),
287        };
288
289        hidden_methods.push(quote! {
290            #[doc(hidden)]
291            #[allow(non_upper_case_globals)]
292            const #fname: nexosim::model::SchedulableId<Self, #ty> = nexosim::model::SchedulableId::__from_decorated(#i);
293        });
294        registered_schedulables.push(fname);
295    }
296
297    let byte_literals = registered_schedulables
298        .iter()
299        .map(|a| proc_macro2::Literal::byte_string(a.to_string().as_bytes()));
300
301    // Add a hidden method used for producing more meaningful compilation errors,
302    // when a user tries to schedule an undecorated method.
303    hidden_methods.push(quote! {
304        #[doc(hidden)]
305        const fn ____is_schedulable(fname: &'static str) -> bool {
306            match fname.as_bytes() {
307                #(#byte_literals => true,)*
308                _ => false
309            }
310        }
311    });
312
313    hidden_methods
314}
315
316fn collect_nexosim_attributes(f: &mut ImplItemFn) -> Result<Vec<&'static str>, syn::Error> {
317    let mut attrs = Vec::new();
318    let mut indices = Vec::new();
319
320    'outer: for (i, attr) in f.attrs.iter().enumerate() {
321        if !attr.meta.path().is_ident("nexosim") {
322            continue;
323        }
324        indices.push(i);
325
326        match &attr.meta {
327            Meta::List(meta) => {
328                if let Ok(Expr::Path(path)) = meta.parse_args::<Expr>()
329                    && let Some(segment) = path.path.segments.first()
330                {
331                    for attr in AVAILABLE_ATTRS {
332                        if segment.ident == attr {
333                            attrs.push(*attr);
334                            continue 'outer;
335                        }
336                    }
337                }
338
339                if meta.tokens.clone().into_iter().count() > 1 {
340                    return Err(syn::Error::new_spanned(
341                        meta,
342                        "attribute `nexosim` should have exactly one argument!",
343                    ));
344                }
345                return Err(syn::Error::new_spanned(
346                    meta,
347                    "invalid `nexosim` attribute!",
348                ));
349            }
350            _ => {
351                return Err(syn::Error::new_spanned(
352                    &attr.meta,
353                    "invalid `nexosim` attribute!",
354                ));
355            }
356        }
357    }
358
359    for i in indices.iter().rev() {
360        f.attrs.remove(*i);
361    }
362
363    Ok(attrs)
364}
365
366fn init_fn(sig: &Signature) -> Result<proc_macro2::TokenStream, syn::Error> {
367    let ident = sig.ident.clone();
368    match sig.inputs.len() {
369        1 => Ok(quote!(Self::#ident(&mut self))),
370        2 => Ok(quote!(Self::#ident(&mut self, cx))),
371        3 => Ok(quote!(Self::#ident(&mut self, cx, env))),
372        _ => Err(syn::Error::new_spanned(sig, "invalid number of arguments")),
373    }
374}