Skip to main content

dtact_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    FnArg, ItemFn, Lit, Meta, Token, parse::Parse, parse::ParseStream, parse_macro_input,
5    punctuated::Punctuated,
6};
7
/// Arguments accepted by `#[task(...)]`.
///
/// Each field keeps the raw string given in the attribute, or its default when
/// the key is omitted.
struct TaskArgs {
    /// Variant name under `dtact::WorkloadKind` (default `"Compute"`).
    kind: String,
    /// Variant name under `dtact::Priority` (default `"Normal"`).
    priority: String,
    /// Variant name under `dtact::topology::Affinity` (default `"SameCore"`).
    affinity: String,
    /// Item name under `dtact` used for the `SWITCHER` type alias
    /// (default `"CrossThreadFloat"`).
    switcher: String,
    /// Stack size string, emitted verbatim as `STACK_SIZE` (default `"2M"`).
    stack: String,
}
15
16impl Parse for TaskArgs {
17    fn parse(input: ParseStream) -> syn::Result<Self> {
18        let vars = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
19        let mut priority = "Normal".to_string();
20        let mut affinity = "SameCore".to_string();
21        let mut kind = "Compute".to_string();
22        let mut stack = "2M".to_string();
23        let mut switcher = "CrossThreadFloat".to_string();
24
25        for var in vars {
26            if let Meta::NameValue(nv) = var {
27                if nv.path.is_ident("priority") {
28                    if let syn::Expr::Lit(syn::ExprLit {
29                        lit: Lit::Str(s), ..
30                    }) = nv.value
31                    {
32                        priority = s.value();
33                    }
34                } else if nv.path.is_ident("affinity") {
35                    if let syn::Expr::Lit(syn::ExprLit {
36                        lit: Lit::Str(s), ..
37                    }) = nv.value
38                    {
39                        affinity = s.value();
40                    }
41                } else if nv.path.is_ident("kind") {
42                    if let syn::Expr::Lit(syn::ExprLit {
43                        lit: Lit::Str(s), ..
44                    }) = nv.value
45                    {
46                        kind = s.value();
47                    }
48                } else if nv.path.is_ident("stack") {
49                    if let syn::Expr::Lit(syn::ExprLit {
50                        lit: Lit::Str(s), ..
51                    }) = nv.value
52                    {
53                        stack = s.value();
54                    }
55                } else if nv.path.is_ident("switcher")
56                    && let syn::Expr::Lit(syn::ExprLit {
57                        lit: Lit::Str(s), ..
58                    }) = nv.value
59                {
60                    switcher = s.value();
61                }
62            }
63        }
64
65        Ok(TaskArgs {
66            priority,
67            affinity,
68            kind,
69            stack,
70            switcher,
71        })
72    }
73}
74
75#[proc_macro_attribute]
76pub fn task(args: TokenStream, item: TokenStream) -> TokenStream {
77    let args = parse_macro_input!(args as TaskArgs);
78    let input = parse_macro_input!(item as ItemFn);
79
80    let fn_name = &input.sig.ident;
81    let priority = &args.priority;
82    let affinity = &args.affinity;
83    let kind = &args.kind;
84    let stack = &args.stack;
85
86    let metadata_mod = syn::Ident::new(&format!("dtact_metadata_{}", fn_name), fn_name.span());
87    let priority_ident = syn::Ident::new(priority, fn_name.span());
88    let affinity_ident = syn::Ident::new(affinity, fn_name.span());
89    let kind_ident = syn::Ident::new(kind, fn_name.span());
90    let switcher = &args.switcher;
91    let switcher_ident = syn::Ident::new(switcher, fn_name.span());
92
93    let expanded = quote! {
94        #input
95
96        pub mod #metadata_mod {
97            pub const PRIORITY: dtact::Priority = dtact::Priority::#priority_ident;
98            pub const AFFINITY: dtact::topology::Affinity = dtact::topology::Affinity::#affinity_ident;
99            pub const KIND: dtact::WorkloadKind = dtact::WorkloadKind::#kind_ident;
100            pub const STACK_SIZE: &'static str = #stack;
101            pub type SWITCHER = dtact::#switcher_ident;
102        }
103    };
104
105    TokenStream::from(expanded)
106}
107
108#[proc_macro_attribute]
109pub fn export_async(_args: TokenStream, item: TokenStream) -> TokenStream {
110    let input = parse_macro_input!(item as ItemFn);
111    let fn_name = &input.sig.ident;
112    let wrapper_name = syn::Ident::new(&format!("dtact_export_{}", fn_name), fn_name.span());
113
114    let mut c_params = Vec::new();
115    let mut call_args = Vec::new();
116
117    for input in &input.sig.inputs {
118        if let FnArg::Typed(pat_type) = input {
119            let pat = &pat_type.pat;
120            let ty = &pat_type.ty;
121            c_params.push(quote! { #pat: #ty });
122            call_args.push(quote! { #pat });
123        } else {
124            panic!("export_async does not support 'self' parameters");
125        }
126    }
127
128    let expanded = quote! {
129        #input
130
131        #[unsafe(no_mangle)]
132        pub extern "C" fn #wrapper_name(#(#c_params),*) -> dtact::dtact_handle_t {
133            dtact::spawn(#fn_name(#(#call_args),*))
134        }
135    };
136
137    TokenStream::from(expanded)
138}
139
140#[proc_macro_attribute]
141pub fn export_fiber(_args: TokenStream, item: TokenStream) -> TokenStream {
142    let input = parse_macro_input!(item as ItemFn);
143    let fn_name = &input.sig.ident;
144    let wrapper_name = syn::Ident::new(&format!("dtact_export_fiber_{}", fn_name), fn_name.span());
145
146    let mut c_params = Vec::new();
147    let mut call_args = Vec::new();
148
149    for input in &input.sig.inputs {
150        if let FnArg::Typed(pat_type) = input {
151            let pat = &pat_type.pat;
152            let ty = &pat_type.ty;
153            c_params.push(quote! { #pat: #ty });
154            call_args.push(quote! { #pat });
155        } else {
156            panic!("export_fiber does not support 'self' parameters");
157        }
158    }
159
160    let expanded = quote! {
161        #input
162
163        #[unsafe(no_mangle)]
164        pub extern "C" fn #wrapper_name(#(#c_params),*) -> dtact::dtact_handle_t {
165            dtact::api::fiber::spawn_with_stack("2M", move || {
166                #fn_name(#(#call_args),*);
167            })
168        }
169    };
170
171    TokenStream::from(expanded)
172}
173
/// Arguments accepted by `#[dtact_init(...)]`.
///
/// Values are stored as parsed, or as their defaults when the key is omitted,
/// and are spliced into the generated runtime-initialization code.
struct InitArgs {
    /// Worker count for the scheduler; `0` (the default) means "use
    /// `std::thread::available_parallelism()`, falling back to 1".
    workers: usize,
    /// Variant name under `dtact::dta_scheduler::TopologyMode` (default `"P2PMesh"`).
    topology: String,
    /// First argument to `ContextPool::new` (default `16384`).
    capacity: u32,
    /// Second argument to `ContextPool::new` (default `2 * 1024 * 1024`;
    /// presumably bytes — confirm against `ContextPool`).
    stack: usize,
    /// Variant name under `dtact::memory_management::SafetyLevel` (default `"Safety1"`).
    safety: String,
    /// Fourth argument to `ContextPool::new` (default `0`; presumably a NUMA
    /// node index — confirm against `ContextPool`).
    numa: usize,
}
182
183impl Parse for InitArgs {
184    fn parse(input: ParseStream) -> syn::Result<Self> {
185        let vars = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
186        let mut topology = "P2PMesh".to_string();
187        let mut safety = "Safety1".to_string();
188        let mut workers = 0;
189        let mut capacity = 16384;
190        let mut stack = 2 * 1024 * 1024;
191        let mut numa = 0;
192
193        for var in vars {
194            if let Meta::NameValue(nv) = var {
195                if nv.path.is_ident("topology") {
196                    if let syn::Expr::Lit(syn::ExprLit {
197                        lit: Lit::Str(s), ..
198                    }) = &nv.value
199                    {
200                        topology = s.value();
201                    }
202                } else if nv.path.is_ident("safety") {
203                    if let syn::Expr::Lit(syn::ExprLit {
204                        lit: Lit::Str(s), ..
205                    }) = &nv.value
206                    {
207                        safety = s.value();
208                    }
209                } else if nv.path.is_ident("workers") {
210                    if let syn::Expr::Lit(syn::ExprLit {
211                        lit: Lit::Int(i), ..
212                    }) = &nv.value
213                    {
214                        workers = i.base10_parse()?;
215                    }
216                } else if nv.path.is_ident("capacity") {
217                    if let syn::Expr::Lit(syn::ExprLit {
218                        lit: Lit::Int(i), ..
219                    }) = &nv.value
220                    {
221                        capacity = i.base10_parse()?;
222                    }
223                } else if nv.path.is_ident("stack") {
224                    if let syn::Expr::Lit(syn::ExprLit {
225                        lit: Lit::Int(i), ..
226                    }) = &nv.value
227                    {
228                        stack = i.base10_parse()?;
229                    }
230                } else if nv.path.is_ident("numa")
231                    && let syn::Expr::Lit(syn::ExprLit {
232                        lit: Lit::Int(i), ..
233                    }) = &nv.value
234                {
235                    numa = i.base10_parse()?;
236                }
237            }
238        }
239        Ok(InitArgs {
240            topology,
241            safety,
242            workers,
243            capacity,
244            stack,
245            numa,
246        })
247    }
248}
249
250#[proc_macro_attribute]
251pub fn dtact_init(args: TokenStream, item: TokenStream) -> TokenStream {
252    let args = parse_macro_input!(args as InitArgs);
253    let input = parse_macro_input!(item as ItemFn);
254
255    let topology = &args.topology;
256    let safety = &args.safety;
257    let workers = args.workers;
258    let capacity = args.capacity;
259    let stack = args.stack;
260    let numa = args.numa;
261
262    let topology_ident = syn::Ident::new(topology, input.sig.ident.span());
263    let safety_ident = syn::Ident::new(safety, input.sig.ident.span());
264    let autostart_fn_name = syn::Ident::new("dtact_autostart", input.sig.ident.span());
265
266    let attrs = &input.attrs;
267    let vis = &input.vis;
268    let sig = &input.sig;
269    let block = &input.block;
270
271    let expanded = quote! {
272        #[unsafe(no_mangle)]
273        extern "C" fn #autostart_fn_name() {
274            let runtime = dtact::GLOBAL_RUNTIME.get_or_init(|| {
275                let mut workers_count = #workers;
276                if workers_count == 0 {
277                    workers_count = std::thread::available_parallelism().map(|n| n.get()).unwrap_or(1);
278                }
279
280                let scheduler = dtact::dta_scheduler::DtaScheduler::new(
281                    workers_count,
282                    dtact::dta_scheduler::TopologyMode::#topology_ident
283                );
284
285                let pool = dtact::memory_management::ContextPool::new(
286                    #capacity,
287                    #stack,
288                    dtact::memory_management::SafetyLevel::#safety_ident,
289                    #numa
290                ).expect("DTA-V3 Hardware Initialization Failed");
291
292                dtact::Runtime {
293                    scheduler,
294                    pool,
295                    started: core::sync::atomic::AtomicBool::new(false),
296                    shutdown: core::sync::atomic::AtomicBool::new(false),
297                }
298            });
299            runtime.start();
300        }
301
302        #(#attrs)* #vis #sig {
303            #autostart_fn_name();
304            #block
305        }
306    };
307
308    TokenStream::from(expanded)
309}