derive_stack_queue/
lib.rs1use darling::{export::NestedMeta, FromMeta};
2use proc_macro2::Span;
3use quote::quote;
4use syn::{parse::Error, parse_macro_input, ImplItem, ItemImpl};
5
/// Smallest `buffer_size` accepted by the `local_queue` attribute.
const MIN_BUFFER_LEN: usize = 64;

/// Largest `buffer_size` accepted by the `local_queue` attribute when the
/// macro is compiled with 64-bit pointers.
#[cfg(target_pointer_width = "64")]
const MAX_BUFFER_LEN: usize = u32::MAX as usize;
/// Largest `buffer_size` accepted on every other pointer width.
///
/// Declared as the complement of the 64-bit case so the constant is always
/// defined; `u16::MAX` fits in `usize` on every supported pointer width.
#[cfg(not(target_pointer_width = "64"))]
const MAX_BUFFER_LEN: usize = u16::MAX as usize;
12
/// The trait whose implementation the `local_queue` attribute was applied to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Variant {
    /// An `impl TaskQueue for ...` block.
    TaskQueue,
    /// An `impl BackgroundQueue for ...` block.
    BackgroundQueue,
    /// An `impl BatchReducer for ...` block.
    BatchReducer,
}
18#[derive(FromMeta)]
19#[darling(default)]
20struct QueueOpt {
21 buffer_size: usize,
22}
23
24impl Default for QueueOpt {
25 fn default() -> Self {
26 QueueOpt { buffer_size: 1024 }
27 }
28}
29
30#[proc_macro_attribute]
37pub fn local_queue(
38 args: proc_macro::TokenStream,
39 input: proc_macro::TokenStream,
40) -> proc_macro::TokenStream {
41 let attr_args = match NestedMeta::parse_meta_list(args.into()) {
42 Ok(v) => v,
43 Err(e) => {
44 return darling::Error::from(e).write_errors().into();
45 }
46 };
47
48 let mut input = parse_macro_input!(input as ItemImpl);
49
50 input.attrs = vec![];
51
52 let ident = &input.self_ty;
53
54 let QueueOpt { buffer_size } = match QueueOpt::from_list(&attr_args) {
55 Ok(attr) => attr,
56 Err(err) => {
57 return err.write_errors().into();
58 }
59 };
60
61 if buffer_size > MAX_BUFFER_LEN {
62 return Error::new(
63 Span::call_site(),
64 format!("buffer_size must not exceed {MAX_BUFFER_LEN}"),
65 )
66 .into_compile_error()
67 .into();
68 }
69
70 if buffer_size < MIN_BUFFER_LEN {
71 return Error::new(
72 Span::call_site(),
73 format!("buffer_size must be at least {MIN_BUFFER_LEN}"),
74 )
75 .into_compile_error()
76 .into();
77 }
78
79 if buffer_size.ne(&buffer_size.next_power_of_two()) {
80 return Error::new(Span::call_site(), "buffer_size must be a power of 2")
81 .into_compile_error()
82 .into();
83 }
84
85 let variant = match &input.trait_ {
86 Some((_, path, _)) => {
87 let segments: Vec<_> = path
88 .segments
89 .iter()
90 .map(|segment| segment.ident.to_string())
91 .collect();
92
93 match *segments
94 .iter()
95 .map(String::as_ref)
96 .collect::<Vec<&str>>()
97 .as_slice()
98 {
99 ["stack_queue", "TaskQueue"] | ["TaskQueue"] => Some(Variant::TaskQueue),
100 ["stack_queue", "BackgroundQueue"] | ["BackgroundQueue"] => Some(Variant::BackgroundQueue),
101 ["stack_queue", "BatchReducer"] | ["BatchReducer"] => Some(Variant::BatchReducer),
102 _ => None,
103 }
104 }
105 None => None,
106 };
107
108 let variant = match variant {
109 Some(variant) => variant,
110 None => {
111 return Error::new(
112 Span::call_site(),
113 "must be used on TaskQueue, BackgroundQueue or BatchReducer impl",
114 )
115 .into_compile_error()
116 .into();
117 }
118 };
119
120 let task = match input
121 .items
122 .iter()
123 .filter_map(|impl_item| {
124 if let ImplItem::Type(impl_type) = impl_item {
125 Some(impl_type)
126 } else {
127 None
128 }
129 })
130 .find(|impl_type| impl_type.ident == "Task")
131 .map(|task_impl| &task_impl.ty)
132 {
133 Some(impl_type) => impl_type,
134 None => {
135 return Error::new(Span::call_site(), "missing `Task` in implementation")
136 .into_compile_error()
137 .into();
138 }
139 };
140
141 let buffer_cell = match &variant {
142 Variant::TaskQueue => quote!(stack_queue::task::TaskRef<#ident>),
143 Variant::BackgroundQueue => quote!(stack_queue::BufferCell<#task>),
144 Variant::BatchReducer => quote!(stack_queue::BufferCell<#task>),
145 };
146
147 let queue = quote!(stack_queue::StackQueue<#buffer_cell, #buffer_size>);
148
149 let expanded = quote!(
150 #input
151
152 #[cfg(not(loom))]
153 impl stack_queue::LocalQueue<#buffer_size> for #ident {
154 type BufferCell = #buffer_cell;
155
156 fn queue() -> &'static std::thread::LocalKey<#queue> {
157 thread_local! {
158 static QUEUE: #queue = stack_queue::StackQueue::default();
159 }
160
161 &QUEUE
162 }
163 }
164
165 #[cfg(loom)]
166 impl stack_queue::LocalQueue<#buffer_size> for #ident {
167 type BufferCell = #buffer_cell;
168
169 fn queue() -> &'static stack_queue::loom::thread::LocalKey<#queue> {
170 stack_queue::loom::thread_local! {
171 static QUEUE: #queue = stack_queue::StackQueue::default();
172 }
173
174 &QUEUE
175 }
176 }
177 );
178
179 expanded.into()
180}