1#[cfg(test)]
4mod tests;
5
6use proc_macro2::TokenStream;
7use quote::{quote, ToTokens};
8use std::collections::HashMap;
9use syn::{
10 braced, custom_keyword, parenthesized,
11 parse::{Parse, ParseStream, Result},
12 punctuated::Punctuated,
13 token::Paren,
14 AngleBracketedGenericArguments, Ident, Path, Token,
15};
16
17struct LockTree {
18 map: HashMap<proc_macro2::Ident, LockSequence>,
19}
20
21impl Parse for LockTree {
22 fn parse(input: ParseStream) -> Result<Self> {
23 let mut map = HashMap::new();
24 while !input.is_empty() {
25 let name = input.parse::<Ident>()?;
26 let seq;
27 braced!(seq in input);
28 map.insert(name, seq.parse::<LockSequence>()?);
29 }
30
31 Ok(LockTree { map })
32 }
33}
34
35struct LockSequence {
36 seq: Vec<Lock>,
37}
38
39impl Parse for LockSequence {
40 fn parse(input: ParseStream) -> Result<Self> {
41 Ok(Self {
42 seq: Punctuated::<Lock, Token![,]>::parse_terminated(input)?
43 .into_iter()
44 .collect(),
45 })
46 }
47}
48
49struct Lock {
50 name: String,
51 ty: LockType,
52}
53
54impl Lock {
55 fn fragment(&self, struct_prefix: &str) -> Fragment {
56 let forward = self.forward(struct_prefix);
57 let name =
58 proc_macro2::Ident::new(&self.name, proc_macro2::Span::call_site());
59 let type_declaraction = self.ty.declaration();
60 let init_var = proc_macro2::Ident::new(
61 &format!("{}_value", &self.name),
62 proc_macro2::Span::call_site(),
63 );
64 let generics = self.ty.generics();
65
66 Fragment {
67 main_accessors: self
68 .ty
69 .accessor_functions(&self.name, &forward, true),
70 forward_accessors: self
71 .ty
72 .accessor_functions(&self.name, &forward, false),
73 forward,
74 lock_declaration: quote! {
75 #name: #type_declaraction,
76 },
77 init_arg: quote! {
78 #init_var: #generics
79 },
80 init_statement: quote! {
81 #name: ::locktree::New::new(#init_var),
82 },
83 }
84 }
85
86 fn forward(&self, struct_prefix: &str) -> String {
87 format!("{}{}", struct_prefix, snake_to_camel_case(&self.name))
88 }
89}
90
91impl Parse for Lock {
92 fn parse(input: ParseStream) -> Result<Self> {
93 let name = input.parse::<Ident>()?.to_string();
94 input.parse::<Token![:]>()?;
95 let ty = input.parse::<LockType>()?;
96
97 Ok(Self { name, ty })
98 }
99}
100
101struct LockType {
102 is_async: bool,
103 declaration: TokenStream,
104 generics: TokenStream,
105 interface: LockInterface,
106}
107
108impl LockType {
109 fn accessor_functions(
110 &self,
111 name: &str,
112 forward: &str,
113 is_entry_point: bool,
114 ) -> TokenStream {
115 let name =
116 proc_macro2::Ident::new(&name, proc_macro2::Span::call_site());
117 let forward =
118 proc_macro2::Ident::new(&forward, proc_macro2::Span::call_site());
119 let accessor = if is_entry_point {
120 quote! {
121 self
122 }
123 } else {
124 quote! {
125 self.locks
126 }
127 };
128
129 self.interface.accessor_functions(
130 !is_entry_point,
131 self.is_async,
132 &name,
133 &forward,
134 &accessor,
135 &self.declaration,
136 )
137 }
138
139 fn declaration(&self) -> &TokenStream {
140 &self.declaration
141 }
142
143 fn generics(&self) -> &TokenStream {
144 &self.generics
145 }
146}
147
148impl Parse for LockType {
149 fn parse(input: ParseStream) -> Result<Self> {
150 let is_async = input.peek(Token![async]);
151 if is_async {
152 input.parse::<Token![async]>().unwrap();
153 }
154
155 let interface = input.parse::<LockInterface>()?;
156 let hkt = if input.peek(Paren) {
157 let hkt;
158 parenthesized!(hkt in input);
159
160 hkt.parse::<Path>()?.into_token_stream()
161 } else {
162 if is_async {
163 return Err(syn::Error::new(
164 input.span(),
165 "async locks must have an explicit HKT",
166 ));
167 }
168
169 interface.default_concrete_type()
170 };
171 let generics = input
172 .parse::<AngleBracketedGenericArguments>()?
173 .args
174 .to_token_stream();
175
176 Ok(Self {
177 is_async,
178 declaration: quote! {
179 #hkt<#generics>
180 },
181 generics,
182 interface,
183 })
184 }
185}
186
187#[derive(Clone, Copy)]
188enum LockInterface {
189 Mutex,
190 RwLock,
191}
192
193impl LockInterface {
194 fn default_concrete_type(&self) -> TokenStream {
195 match self {
196 Self::Mutex => quote! {
197 ::std::sync::Mutex
198 },
199 Self::RwLock => quote! {
200 ::std::sync::RwLock
201 },
202 }
203 }
204
205 fn accessor_functions(
206 &self,
207 use_mut_ref: bool,
208 is_async: bool,
209 name: &proc_macro2::Ident,
210 forward: &proc_macro2::Ident,
211 accessor: &TokenStream,
212 declaration: &TokenStream,
213 ) -> TokenStream {
214 let mut_keyword = if use_mut_ref {
215 Some(proc_macro2::Ident::new(
216 "mut",
217 proc_macro2::Span::call_site(),
218 ))
219 } else {
220 None
221 };
222 match self {
223 Self::Mutex => {
224 let lock_fn_name = proc_macro2::Ident::new(
225 &format!("lock_{}", name),
226 proc_macro2::Span::call_site(),
227 );
228 let async_keyword = if is_async { "Async" } else { "" };
229 let guard = proc_macro2::Ident::new(
230 &format!("Plugged{}MutexGuard", async_keyword),
231 proc_macro2::Span::call_site(),
232 );
233 let lock = proc_macro2::Ident::new(
234 &format!("{}Mutex", async_keyword),
235 proc_macro2::Span::call_site(),
236 );
237
238 quote! {
239 pub fn #lock_fn_name<'a>(
240 &'a #mut_keyword self
241 ) -> (
242 ::locktree::#guard<'a, #declaration>,
243 #forward<'a>
244 ) {
245 (::locktree::#lock::lock(&#accessor.#name), #forward { locks: #accessor })
246 }
247 }
248 }
249 Self::RwLock => {
250 let read_fn_name = proc_macro2::Ident::new(
251 &format!("read_{}", name),
252 proc_macro2::Span::call_site(),
253 );
254 let write_fn_name = proc_macro2::Ident::new(
255 &format!("write_{}", name),
256 proc_macro2::Span::call_site(),
257 );
258 let async_keyword = if is_async { "Async" } else { "" };
259 let read_guard = proc_macro2::Ident::new(
260 &format!("Plugged{}RwLockReadGuard", async_keyword),
261 proc_macro2::Span::call_site(),
262 );
263 let write_guard = proc_macro2::Ident::new(
264 &format!("Plugged{}RwLockWriteGuard", async_keyword),
265 proc_macro2::Span::call_site(),
266 );
267 let lock = proc_macro2::Ident::new(
268 &format!("{}RwLock", async_keyword),
269 proc_macro2::Span::call_site(),
270 );
271
272 quote! {
273 pub fn #read_fn_name<'a>(
274 &'a #mut_keyword self
275 ) -> (
276 ::locktree::#read_guard<'a, #declaration>,
277 #forward<'a>
278 ) {
279 (::locktree::#lock::read(&#accessor.#name), #forward { locks: #accessor })
280 }
281
282 pub fn #write_fn_name<'a>(
283 &'a #mut_keyword self
284 ) -> (
285 ::locktree::#write_guard<'a, #declaration>,
286 #forward<'a>
287 ) {
288 (::locktree::#lock::write(&#accessor.#name), #forward { locks: #accessor })
289 }
290 }
291 }
292 }
293 }
294}
295
296impl Parse for LockInterface {
297 fn parse(input: ParseStream) -> Result<Self> {
298 custom_keyword!(Mutex);
299 custom_keyword!(RwLock);
300
301 let lookahead = input.lookahead1();
302 if lookahead.peek(Mutex) {
303 input.parse::<Mutex>().unwrap();
304
305 Ok(Self::Mutex)
306 } else if lookahead.peek(RwLock) {
307 input.parse::<RwLock>().unwrap();
308
309 Ok(Self::RwLock)
310 } else {
311 Err(lookahead.error())
312 }
313 }
314}
315
316struct Fragment {
317 main_accessors: TokenStream,
318 forward_accessors: TokenStream,
319 forward: String,
320 lock_declaration: TokenStream,
321 init_arg: TokenStream,
322 init_statement: TokenStream,
323}
324
325#[proc_macro]
326pub fn locktree(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
327 locktree_impl(input.into()).into()
328}
329
330fn locktree_impl(input: TokenStream) -> TokenStream {
331 let map = syn::parse2::<LockTree>(input).unwrap().map;
332 let mut code = TokenStream::new();
333 for (struct_name, LockSequence { seq }) in map {
334 let struct_prefix = format!("{}LockTree", struct_name);
335 let main_struct = proc_macro2::Ident::new(
336 &struct_prefix,
337 proc_macro2::Span::call_site(),
338 );
339 let fragments = seq
340 .into_iter()
341 .map(|x| x.fragment(&struct_prefix))
342 .collect::<Vec<_>>();
343
344 let init_args = fragments.iter().map(|x| &x.init_arg);
345 let init_statements = fragments.iter().map(|x| &x.init_statement);
346 let init_fn = quote! {
347 pub fn new(#(#init_args),*) -> Self {
348 Self {
349 #(#init_statements)*
350 }
351 }
352 };
353
354 let main_accessors = fragments.iter().map(|x| &x.main_accessors);
355 let lock_declarations = fragments.iter().map(|x| &x.lock_declaration);
356 code.extend(quote! {
357 struct #main_struct {
358 #(#lock_declarations)*
359 }
360
361 impl #main_struct {
362 #init_fn
363
364 #(#main_accessors)*
365 }
366 });
367
368 for (i, fragment) in fragments.iter().enumerate() {
369 let name = proc_macro2::Ident::new(
370 &fragment.forward,
371 proc_macro2::Span::call_site(),
372 );
373 let forward_accessors =
374 fragments[i + 1..].iter().map(|x| &x.forward_accessors);
375 code.extend(quote! {
376 struct #name<'b> {
377 locks: &'b #main_struct
378 }
379
380 impl<'b> #name<'b> {
381 #(#forward_accessors)*
382 }
383 });
384 }
385 }
386
387 code
388}
389
/// Converts a `snake_case` identifier to `CamelCase`.
///
/// Each `_`-separated segment has its first character uppercased and the
/// rest copied unchanged. Empty segments are skipped; these come from
/// leading, trailing or repeated underscores, or from empty input. The
/// segment tail is taken on a char boundary, so a multi-byte first
/// character does not cause a slicing panic.
fn snake_to_camel_case(x: &str) -> String {
    let mut camel = String::with_capacity(x.len());
    for word in x.split('_') {
        let mut chars = word.chars();
        if let Some(first) = chars.next() {
            camel.extend(first.to_uppercase());
            camel.push_str(chars.as_str());
        }
    }

    camel
}