embedded_executor_macros/
lib.rs

1//! Proc macros for the embedded-executor crate.
2//!
3//! Provides `#[main]` and `#[test]` attributes for running async functions
4//! with a specified IO implementation.
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{ItemFn, Path, ReturnType, parse_macro_input, parse_quote};
9
10/// Marks an async function as the main entry point.
11///
12/// # Usage
13///
14/// ```ignore
15/// use embedded_executor::PollingIo;
16///
17/// #[embedded_executor::main(PollingIo)]
18/// async fn main(executor: _, io: _) -> std::io::Result<()> {
19///     // Your async code here
20///     executor.spawn(async { /* ... */ });
21///     Ok(())
22/// }
23/// ```
24///
25/// The IO type must implement `embedded_executor::runtime::Runtime`.
26#[proc_macro_attribute]
27pub fn main(args: TokenStream, item: TokenStream) -> TokenStream {
28    let io_type = if args.is_empty() {
29        // Default to PollingIo if not specified
30        parse_quote!(::embedded_executor::PollingIo)
31    } else {
32        parse_macro_input!(args as Path)
33    };
34
35    let input = parse_macro_input!(item as ItemFn);
36
37    if input.sig.asyncness.is_none() {
38        return syn::Error::new_spanned(input.sig.fn_token, "main function must be async")
39            .to_compile_error()
40            .into();
41    }
42
43    let name = &input.sig.ident;
44    if name != "main" {
45        return syn::Error::new_spanned(name, "function must be named `main`")
46            .to_compile_error()
47            .into();
48    }
49
50    let body = &input.block;
51    let return_type = match &input.sig.output {
52        ReturnType::Default => quote!(()),
53        ReturnType::Type(_, ty) => quote!(#ty),
54    };
55
56    // Check parameter count
57    let param_count = input.sig.inputs.len();
58
59    let expanded = match param_count {
60        0 => quote! {
61            fn main() -> #return_type {
62                use ::embedded_executor::Runtime;
63                let io = ::std::rc::Rc::new(
64                    <#io_type as Runtime>::create().expect("failed to create IO runtime")
65                );
66                let executor = ::std::rc::Rc::new(::embedded_executor::Executor::new());
67                io.block_on(executor, |_executor, _io| async move #body)
68            }
69        },
70        1 => quote! {
71            fn main() -> #return_type {
72                use ::embedded_executor::Runtime;
73                let io = ::std::rc::Rc::new(
74                    <#io_type as Runtime>::create().expect("failed to create IO runtime")
75                );
76                let executor = ::std::rc::Rc::new(::embedded_executor::Executor::new());
77                io.block_on(executor, |executor, _io| async move #body)
78            }
79        },
80        _ => quote! {
81            fn main() -> #return_type {
82                use ::embedded_executor::Runtime;
83                let io = ::std::rc::Rc::new(
84                    <#io_type as Runtime>::create().expect("failed to create IO runtime")
85                );
86                let executor = ::std::rc::Rc::new(::embedded_executor::Executor::new());
87                io.block_on(executor, |executor, io| async move #body)
88            }
89        },
90    };
91
92    expanded.into()
93}
94
95/// Marks an async function as a test.
96///
97/// # Usage
98///
99/// ```ignore
100/// use embedded_executor::MemIo;
101///
102/// #[embedded_executor::test]  // Defaults to MemIo
103/// async fn test_basic(executor: _, io: _) {
104///     // Your test code
105/// }
106///
107/// #[embedded_executor::test(PollingIo)]
108/// async fn test_with_real_io(executor: _, io: _) {
109///     // Your test code
110/// }
111/// ```
112#[proc_macro_attribute]
113pub fn test(args: TokenStream, item: TokenStream) -> TokenStream {
114    let io_type = if args.is_empty() {
115        // Default to MemIo for tests
116        parse_quote!(::embedded_executor::MemIo)
117    } else {
118        parse_macro_input!(args as Path)
119    };
120
121    let input = parse_macro_input!(item as ItemFn);
122
123    if input.sig.asyncness.is_none() {
124        return syn::Error::new_spanned(input.sig.fn_token, "test function must be async")
125            .to_compile_error()
126            .into();
127    }
128
129    let name = &input.sig.ident;
130    let body = &input.block;
131    let attrs = &input.attrs;
132
133    // Check parameter count
134    let param_count = input.sig.inputs.len();
135
136    let expanded = match param_count {
137        0 => quote! {
138            #[::core::prelude::v1::test]
139            #(#attrs)*
140            fn #name() {
141                use ::embedded_executor::Runtime;
142                let io = ::std::rc::Rc::new(
143                    <#io_type as Runtime>::create().expect("failed to create IO runtime")
144                );
145                let executor = ::std::rc::Rc::new(::embedded_executor::Executor::new());
146                io.block_on(executor, |_executor, _io| async move #body)
147            }
148        },
149        1 => quote! {
150            #[::core::prelude::v1::test]
151            #(#attrs)*
152            fn #name() {
153                use ::embedded_executor::Runtime;
154                let io = ::std::rc::Rc::new(
155                    <#io_type as Runtime>::create().expect("failed to create IO runtime")
156                );
157                let executor = ::std::rc::Rc::new(::embedded_executor::Executor::new());
158                io.block_on(executor, |executor, _io| async move #body)
159            }
160        },
161        _ => quote! {
162            #[::core::prelude::v1::test]
163            #(#attrs)*
164            fn #name() {
165                use ::embedded_executor::Runtime;
166                let io = ::std::rc::Rc::new(
167                    <#io_type as Runtime>::create().expect("failed to create IO runtime")
168                );
169                let executor = ::std::rc::Rc::new(::embedded_executor::Executor::new());
170                io.block_on(executor, |executor, io| async move #body)
171            }
172        },
173    };
174
175    expanded.into()
176}