1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4 FnArg, ItemFn, Lit, Meta, Token, parse::Parse, parse::ParseStream, parse_macro_input,
5 punctuated::Punctuated,
6};
7
8struct TaskArgs {
9 priority: String,
10 affinity: String,
11 kind: String,
12 stack: String,
13 switcher: String,
14}
15
16impl Parse for TaskArgs {
17 fn parse(input: ParseStream) -> syn::Result<Self> {
18 let vars = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
19 let mut priority = "Normal".to_string();
20 let mut affinity = "SameCore".to_string();
21 let mut kind = "Compute".to_string();
22 let mut stack = "2M".to_string();
23 let mut switcher = "CrossThreadFloat".to_string();
24
25 for var in vars {
26 if let Meta::NameValue(nv) = var {
27 if nv.path.is_ident("priority") {
28 if let syn::Expr::Lit(syn::ExprLit {
29 lit: Lit::Str(s), ..
30 }) = nv.value
31 {
32 priority = s.value();
33 }
34 } else if nv.path.is_ident("affinity") {
35 if let syn::Expr::Lit(syn::ExprLit {
36 lit: Lit::Str(s), ..
37 }) = nv.value
38 {
39 affinity = s.value();
40 }
41 } else if nv.path.is_ident("kind") {
42 if let syn::Expr::Lit(syn::ExprLit {
43 lit: Lit::Str(s), ..
44 }) = nv.value
45 {
46 kind = s.value();
47 }
48 } else if nv.path.is_ident("stack") {
49 if let syn::Expr::Lit(syn::ExprLit {
50 lit: Lit::Str(s), ..
51 }) = nv.value
52 {
53 stack = s.value();
54 }
55 } else if nv.path.is_ident("switcher")
56 && let syn::Expr::Lit(syn::ExprLit {
57 lit: Lit::Str(s), ..
58 }) = nv.value
59 {
60 switcher = s.value();
61 }
62 }
63 }
64
65 Ok(TaskArgs {
66 priority,
67 affinity,
68 kind,
69 stack,
70 switcher,
71 })
72 }
73}
74
75#[proc_macro_attribute]
76pub fn task(args: TokenStream, item: TokenStream) -> TokenStream {
77 let args = parse_macro_input!(args as TaskArgs);
78 let input = parse_macro_input!(item as ItemFn);
79
80 let fn_name = &input.sig.ident;
81 let priority = &args.priority;
82 let affinity = &args.affinity;
83 let kind = &args.kind;
84 let stack = &args.stack;
85
86 let metadata_mod = syn::Ident::new(&format!("dtact_metadata_{}", fn_name), fn_name.span());
87 let priority_ident = syn::Ident::new(priority, fn_name.span());
88 let affinity_ident = syn::Ident::new(affinity, fn_name.span());
89 let kind_ident = syn::Ident::new(kind, fn_name.span());
90 let switcher = &args.switcher;
91 let switcher_ident = syn::Ident::new(switcher, fn_name.span());
92
93 let expanded = quote! {
94 #input
95
96 pub mod #metadata_mod {
97 pub const PRIORITY: dtact::Priority = dtact::Priority::#priority_ident;
98 pub const AFFINITY: dtact::topology::Affinity = dtact::topology::Affinity::#affinity_ident;
99 pub const KIND: dtact::WorkloadKind = dtact::WorkloadKind::#kind_ident;
100 pub const STACK_SIZE: &'static str = #stack;
101 pub type SWITCHER = dtact::#switcher_ident;
102 }
103 };
104
105 TokenStream::from(expanded)
106}
107
108#[proc_macro_attribute]
109pub fn export_async(_args: TokenStream, item: TokenStream) -> TokenStream {
110 let input = parse_macro_input!(item as ItemFn);
111 let fn_name = &input.sig.ident;
112 let wrapper_name = syn::Ident::new(&format!("dtact_export_{}", fn_name), fn_name.span());
113
114 let mut c_params = Vec::new();
115 let mut call_args = Vec::new();
116
117 for input in &input.sig.inputs {
118 if let FnArg::Typed(pat_type) = input {
119 let pat = &pat_type.pat;
120 let ty = &pat_type.ty;
121 c_params.push(quote! { #pat: #ty });
122 call_args.push(quote! { #pat });
123 } else {
124 panic!("export_async does not support 'self' parameters");
125 }
126 }
127
128 let expanded = quote! {
129 #input
130
131 #[unsafe(no_mangle)]
132 pub extern "C" fn #wrapper_name(#(#c_params),*) -> dtact::dtact_handle_t {
133 dtact::spawn(#fn_name(#(#call_args),*))
134 }
135 };
136
137 TokenStream::from(expanded)
138}
139
140#[proc_macro_attribute]
141pub fn export_fiber(_args: TokenStream, item: TokenStream) -> TokenStream {
142 let input = parse_macro_input!(item as ItemFn);
143 let fn_name = &input.sig.ident;
144 let wrapper_name = syn::Ident::new(&format!("dtact_export_fiber_{}", fn_name), fn_name.span());
145
146 let mut c_params = Vec::new();
147 let mut call_args = Vec::new();
148
149 for input in &input.sig.inputs {
150 if let FnArg::Typed(pat_type) = input {
151 let pat = &pat_type.pat;
152 let ty = &pat_type.ty;
153 c_params.push(quote! { #pat: #ty });
154 call_args.push(quote! { #pat });
155 } else {
156 panic!("export_fiber does not support 'self' parameters");
157 }
158 }
159
160 let expanded = quote! {
161 #input
162
163 #[unsafe(no_mangle)]
164 pub extern "C" fn #wrapper_name(#(#c_params),*) -> dtact::dtact_handle_t {
165 dtact::api::fiber::spawn_with_stack("2M", move || {
166 #fn_name(#(#call_args),*);
167 })
168 }
169 };
170
171 TokenStream::from(expanded)
172}
173
174struct InitArgs {
175 topology: String,
176 safety: String,
177 workers: usize,
178 capacity: u32,
179 stack: usize,
180 numa: usize,
181}
182
183impl Parse for InitArgs {
184 fn parse(input: ParseStream) -> syn::Result<Self> {
185 let vars = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
186 let mut topology = "P2PMesh".to_string();
187 let mut safety = "Safety1".to_string();
188 let mut workers = 0;
189 let mut capacity = 16384;
190 let mut stack = 2 * 1024 * 1024;
191 let mut numa = 0;
192
193 for var in vars {
194 if let Meta::NameValue(nv) = var {
195 if nv.path.is_ident("topology") {
196 if let syn::Expr::Lit(syn::ExprLit {
197 lit: Lit::Str(s), ..
198 }) = &nv.value
199 {
200 topology = s.value();
201 }
202 } else if nv.path.is_ident("safety") {
203 if let syn::Expr::Lit(syn::ExprLit {
204 lit: Lit::Str(s), ..
205 }) = &nv.value
206 {
207 safety = s.value();
208 }
209 } else if nv.path.is_ident("workers") {
210 if let syn::Expr::Lit(syn::ExprLit {
211 lit: Lit::Int(i), ..
212 }) = &nv.value
213 {
214 workers = i.base10_parse()?;
215 }
216 } else if nv.path.is_ident("capacity") {
217 if let syn::Expr::Lit(syn::ExprLit {
218 lit: Lit::Int(i), ..
219 }) = &nv.value
220 {
221 capacity = i.base10_parse()?;
222 }
223 } else if nv.path.is_ident("stack") {
224 if let syn::Expr::Lit(syn::ExprLit {
225 lit: Lit::Int(i), ..
226 }) = &nv.value
227 {
228 stack = i.base10_parse()?;
229 }
230 } else if nv.path.is_ident("numa")
231 && let syn::Expr::Lit(syn::ExprLit {
232 lit: Lit::Int(i), ..
233 }) = &nv.value
234 {
235 numa = i.base10_parse()?;
236 }
237 }
238 }
239 Ok(InitArgs {
240 topology,
241 safety,
242 workers,
243 capacity,
244 stack,
245 numa,
246 })
247 }
248}
249
250#[proc_macro_attribute]
251pub fn dtact_init(args: TokenStream, item: TokenStream) -> TokenStream {
252 let args = parse_macro_input!(args as InitArgs);
253 let input = parse_macro_input!(item as ItemFn);
254
255 let topology = &args.topology;
256 let safety = &args.safety;
257 let workers = args.workers;
258 let capacity = args.capacity;
259 let stack = args.stack;
260 let numa = args.numa;
261
262 let topology_ident = syn::Ident::new(topology, input.sig.ident.span());
263 let safety_ident = syn::Ident::new(safety, input.sig.ident.span());
264 let autostart_fn_name = syn::Ident::new("dtact_autostart", input.sig.ident.span());
265
266 let attrs = &input.attrs;
267 let vis = &input.vis;
268 let sig = &input.sig;
269 let block = &input.block;
270
271 let expanded = quote! {
272 #[unsafe(no_mangle)]
273 extern "C" fn #autostart_fn_name() {
274 let runtime = dtact::GLOBAL_RUNTIME.get_or_init(|| {
275 let mut workers_count = #workers;
276 if workers_count == 0 {
277 workers_count = std::thread::available_parallelism().map(|n| n.get()).unwrap_or(1);
278 }
279
280 let scheduler = dtact::dta_scheduler::DtaScheduler::new(
281 workers_count,
282 dtact::dta_scheduler::TopologyMode::#topology_ident
283 );
284
285 let pool = dtact::memory_management::ContextPool::new(
286 #capacity,
287 #stack,
288 dtact::memory_management::SafetyLevel::#safety_ident,
289 #numa
290 ).expect("DTA-V3 Hardware Initialization Failed");
291
292 dtact::Runtime {
293 scheduler,
294 pool,
295 started: core::sync::atomic::AtomicBool::new(false),
296 shutdown: core::sync::atomic::AtomicBool::new(false),
297 }
298 });
299 runtime.start();
300 }
301
302 #(#attrs)* #vis #sig {
303 #autostart_fn_name();
304 #block
305 }
306 };
307
308 TokenStream::from(expanded)
309}