locktree_derive/
lib.rs

1//! See the [locktree crate](https://crates.io/crates/locktree).
2
3#[cfg(test)]
4mod tests;
5
6use proc_macro2::TokenStream;
7use quote::{quote, ToTokens};
8use std::collections::HashMap;
9use syn::{
10    braced, custom_keyword, parenthesized,
11    parse::{Parse, ParseStream, Result},
12    punctuated::Punctuated,
13    token::Paren,
14    AngleBracketedGenericArguments, Ident, Path, Token,
15};
16
17struct LockTree {
18    map: HashMap<proc_macro2::Ident, LockSequence>,
19}
20
21impl Parse for LockTree {
22    fn parse(input: ParseStream) -> Result<Self> {
23        let mut map = HashMap::new();
24        while !input.is_empty() {
25            let name = input.parse::<Ident>()?;
26            let seq;
27            braced!(seq in input);
28            map.insert(name, seq.parse::<LockSequence>()?);
29        }
30
31        Ok(LockTree { map })
32    }
33}
34
35struct LockSequence {
36    seq: Vec<Lock>,
37}
38
39impl Parse for LockSequence {
40    fn parse(input: ParseStream) -> Result<Self> {
41        Ok(Self {
42            seq: Punctuated::<Lock, Token![,]>::parse_terminated(input)?
43                .into_iter()
44                .collect(),
45        })
46    }
47}
48
49struct Lock {
50    name: String,
51    ty: LockType,
52}
53
54impl Lock {
55    fn fragment(&self, struct_prefix: &str) -> Fragment {
56        let forward = self.forward(struct_prefix);
57        let name =
58            proc_macro2::Ident::new(&self.name, proc_macro2::Span::call_site());
59        let type_declaraction = self.ty.declaration();
60        let init_var = proc_macro2::Ident::new(
61            &format!("{}_value", &self.name),
62            proc_macro2::Span::call_site(),
63        );
64        let generics = self.ty.generics();
65
66        Fragment {
67            main_accessors: self
68                .ty
69                .accessor_functions(&self.name, &forward, true),
70            forward_accessors: self
71                .ty
72                .accessor_functions(&self.name, &forward, false),
73            forward,
74            lock_declaration: quote! {
75                #name: #type_declaraction,
76            },
77            init_arg: quote! {
78                #init_var: #generics
79            },
80            init_statement: quote! {
81                #name: ::locktree::New::new(#init_var),
82            },
83        }
84    }
85
86    fn forward(&self, struct_prefix: &str) -> String {
87        format!("{}{}", struct_prefix, snake_to_camel_case(&self.name))
88    }
89}
90
91impl Parse for Lock {
92    fn parse(input: ParseStream) -> Result<Self> {
93        let name = input.parse::<Ident>()?.to_string();
94        input.parse::<Token![:]>()?;
95        let ty = input.parse::<LockType>()?;
96
97        Ok(Self { name, ty })
98    }
99}
100
101struct LockType {
102    is_async: bool,
103    declaration: TokenStream,
104    generics: TokenStream,
105    interface: LockInterface,
106}
107
108impl LockType {
109    fn accessor_functions(
110        &self,
111        name: &str,
112        forward: &str,
113        is_entry_point: bool,
114    ) -> TokenStream {
115        let name =
116            proc_macro2::Ident::new(&name, proc_macro2::Span::call_site());
117        let forward =
118            proc_macro2::Ident::new(&forward, proc_macro2::Span::call_site());
119        let accessor = if is_entry_point {
120            quote! {
121                self
122            }
123        } else {
124            quote! {
125                self.locks
126            }
127        };
128
129        self.interface.accessor_functions(
130            !is_entry_point,
131            self.is_async,
132            &name,
133            &forward,
134            &accessor,
135            &self.declaration,
136        )
137    }
138
139    fn declaration(&self) -> &TokenStream {
140        &self.declaration
141    }
142
143    fn generics(&self) -> &TokenStream {
144        &self.generics
145    }
146}
147
148impl Parse for LockType {
149    fn parse(input: ParseStream) -> Result<Self> {
150        let is_async = input.peek(Token![async]);
151        if is_async {
152            input.parse::<Token![async]>().unwrap();
153        }
154
155        let interface = input.parse::<LockInterface>()?;
156        let hkt = if input.peek(Paren) {
157            let hkt;
158            parenthesized!(hkt in input);
159
160            hkt.parse::<Path>()?.into_token_stream()
161        } else {
162            if is_async {
163                return Err(syn::Error::new(
164                    input.span(),
165                    "async locks must have an explicit HKT",
166                ));
167            }
168
169            interface.default_concrete_type()
170        };
171        let generics = input
172            .parse::<AngleBracketedGenericArguments>()?
173            .args
174            .to_token_stream();
175
176        Ok(Self {
177            is_async,
178            declaration: quote! {
179                #hkt<#generics>
180            },
181            generics,
182            interface,
183        })
184    }
185}
186
187#[derive(Clone, Copy)]
188enum LockInterface {
189    Mutex,
190    RwLock,
191}
192
193impl LockInterface {
194    fn default_concrete_type(&self) -> TokenStream {
195        match self {
196            Self::Mutex => quote! {
197                ::std::sync::Mutex
198            },
199            Self::RwLock => quote! {
200                ::std::sync::RwLock
201            },
202        }
203    }
204
205    fn accessor_functions(
206        &self,
207        use_mut_ref: bool,
208        is_async: bool,
209        name: &proc_macro2::Ident,
210        forward: &proc_macro2::Ident,
211        accessor: &TokenStream,
212        declaration: &TokenStream,
213    ) -> TokenStream {
214        let mut_keyword = if use_mut_ref {
215            Some(proc_macro2::Ident::new(
216                "mut",
217                proc_macro2::Span::call_site(),
218            ))
219        } else {
220            None
221        };
222        match self {
223            Self::Mutex => {
224                let lock_fn_name = proc_macro2::Ident::new(
225                    &format!("lock_{}", name),
226                    proc_macro2::Span::call_site(),
227                );
228                let async_keyword = if is_async { "Async" } else { "" };
229                let guard = proc_macro2::Ident::new(
230                    &format!("Plugged{}MutexGuard", async_keyword),
231                    proc_macro2::Span::call_site(),
232                );
233                let lock = proc_macro2::Ident::new(
234                    &format!("{}Mutex", async_keyword),
235                    proc_macro2::Span::call_site(),
236                );
237
238                quote! {
239                    pub fn #lock_fn_name<'a>(
240                        &'a #mut_keyword self
241                    ) -> (
242                        ::locktree::#guard<'a, #declaration>,
243                        #forward<'a>
244                    ) {
245                        (::locktree::#lock::lock(&#accessor.#name), #forward { locks: #accessor })
246                    }
247                }
248            }
249            Self::RwLock => {
250                let read_fn_name = proc_macro2::Ident::new(
251                    &format!("read_{}", name),
252                    proc_macro2::Span::call_site(),
253                );
254                let write_fn_name = proc_macro2::Ident::new(
255                    &format!("write_{}", name),
256                    proc_macro2::Span::call_site(),
257                );
258                let async_keyword = if is_async { "Async" } else { "" };
259                let read_guard = proc_macro2::Ident::new(
260                    &format!("Plugged{}RwLockReadGuard", async_keyword),
261                    proc_macro2::Span::call_site(),
262                );
263                let write_guard = proc_macro2::Ident::new(
264                    &format!("Plugged{}RwLockWriteGuard", async_keyword),
265                    proc_macro2::Span::call_site(),
266                );
267                let lock = proc_macro2::Ident::new(
268                    &format!("{}RwLock", async_keyword),
269                    proc_macro2::Span::call_site(),
270                );
271
272                quote! {
273                    pub fn #read_fn_name<'a>(
274                        &'a #mut_keyword self
275                    ) -> (
276                        ::locktree::#read_guard<'a, #declaration>,
277                        #forward<'a>
278                    ) {
279                        (::locktree::#lock::read(&#accessor.#name), #forward { locks: #accessor })
280                    }
281
282                    pub fn #write_fn_name<'a>(
283                        &'a #mut_keyword self
284                    ) -> (
285                        ::locktree::#write_guard<'a, #declaration>,
286                        #forward<'a>
287                    ) {
288                        (::locktree::#lock::write(&#accessor.#name), #forward { locks: #accessor })
289                    }
290                }
291            }
292        }
293    }
294}
295
296impl Parse for LockInterface {
297    fn parse(input: ParseStream) -> Result<Self> {
298        custom_keyword!(Mutex);
299        custom_keyword!(RwLock);
300
301        let lookahead = input.lookahead1();
302        if lookahead.peek(Mutex) {
303            input.parse::<Mutex>().unwrap();
304
305            Ok(Self::Mutex)
306        } else if lookahead.peek(RwLock) {
307            input.parse::<RwLock>().unwrap();
308
309            Ok(Self::RwLock)
310        } else {
311            Err(lookahead.error())
312        }
313    }
314}
315
316struct Fragment {
317    main_accessors: TokenStream,
318    forward_accessors: TokenStream,
319    forward: String,
320    lock_declaration: TokenStream,
321    init_arg: TokenStream,
322    init_statement: TokenStream,
323}
324
325#[proc_macro]
326pub fn locktree(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
327    locktree_impl(input.into()).into()
328}
329
330fn locktree_impl(input: TokenStream) -> TokenStream {
331    let map = syn::parse2::<LockTree>(input).unwrap().map;
332    let mut code = TokenStream::new();
333    for (struct_name, LockSequence { seq }) in map {
334        let struct_prefix = format!("{}LockTree", struct_name);
335        let main_struct = proc_macro2::Ident::new(
336            &struct_prefix,
337            proc_macro2::Span::call_site(),
338        );
339        let fragments = seq
340            .into_iter()
341            .map(|x| x.fragment(&struct_prefix))
342            .collect::<Vec<_>>();
343
344        let init_args = fragments.iter().map(|x| &x.init_arg);
345        let init_statements = fragments.iter().map(|x| &x.init_statement);
346        let init_fn = quote! {
347            pub fn new(#(#init_args),*) -> Self {
348                Self {
349                    #(#init_statements)*
350                }
351            }
352        };
353
354        let main_accessors = fragments.iter().map(|x| &x.main_accessors);
355        let lock_declarations = fragments.iter().map(|x| &x.lock_declaration);
356        code.extend(quote! {
357            struct #main_struct {
358                #(#lock_declarations)*
359            }
360
361            impl #main_struct {
362                #init_fn
363
364                #(#main_accessors)*
365            }
366        });
367
368        for (i, fragment) in fragments.iter().enumerate() {
369            let name = proc_macro2::Ident::new(
370                &fragment.forward,
371                proc_macro2::Span::call_site(),
372            );
373            let forward_accessors =
374                fragments[i + 1..].iter().map(|x| &x.forward_accessors);
375            code.extend(quote! {
376                struct #name<'b> {
377                    locks: &'b #main_struct
378                }
379
380                impl<'b> #name<'b> {
381                    #(#forward_accessors)*
382                }
383            });
384        }
385    }
386
387    code
388}
389
/// Converts a `snake_case` identifier to `UpperCamelCase`.
///
/// Each `_`-separated segment has its first character upper-cased and the
/// rest kept verbatim (`"read_lock"` -> `"ReadLock"`). Empty segments, which
/// come from leading, trailing or repeated underscores, are skipped rather
/// than panicking. The first character is handled as a `char`, so a
/// multi-byte leading character does not cause a slicing panic.
fn snake_to_camel_case(x: &str) -> String {
    let mut camel = String::with_capacity(x.len());
    for word in x.split('_') {
        let mut chars = word.chars();
        if let Some(first) = chars.next() {
            camel.extend(first.to_uppercase());
            camel.push_str(chars.as_str());
        }
    }

    camel
}