dionysos_derives/
lib.rs

1use proc_macro::{self, TokenStream};
2use quote::quote;
3
4mod find_fields;
5use find_fields::*;
6
7#[proc_macro_derive(FileProvider, attributes(consumers_list))]
8pub fn derive_file_provider(input: TokenStream) -> TokenStream {
9    let ast: syn::DeriveInput = syn::parse(input).unwrap();
10    let ident = &ast.ident;
11    
12    let fields = find_fields_by_attrname(&ast, "consumers_list");
13    let consumers_list = match fields.len() {
14        0 => panic!("no field with attribute consumers_list found"),
15        1 => &fields[0],
16        _ => panic!("multiple fields with #[consumers_list] defined")
17    };
18
19    let cl_ident = &consumers_list.ident;
20
21    let output = quote! {
22        impl FileProvider for #ident {
23            fn register_consumer(&mut self, consumer: Box<dyn FileConsumer>) {
24                self.#cl_ident.push(consumer);
25            }
26        }
27    };
28    output.into()
29}
30
31#[proc_macro_derive(FileConsumer, attributes(consumer_data, thread_handle))]
32pub fn derive_file_consumer(input: TokenStream) -> TokenStream {
33    let ast: syn::DeriveInput = syn::parse(input).unwrap();
34    let ident = &ast.ident;
35
36    let fields = find_fields_by_attrname(&ast, "consumers_list");
37    let consumers_list = match fields.len() {
38        0 => None,
39        1 => fields.into_iter().next(),
40        _ => panic!("multiple fields with #[consumers_list] defined")
41    };
42
43    let fields = find_fields_by_attrname(&ast, "thread_handle");
44    let thread_handle = match fields.len() {
45        0 => panic!("no field with attribute thread_handle found"),
46        1 => &fields[0],
47        _ => panic!("multiple fields with #[thread_handle] defined")
48    };
49
50
51    let fields = find_fields_by_attrname(&ast, "consumer_data");
52    let mut consumer_data = match fields.len() {
53        0 => None,
54        1 =>  {
55            let field = &fields[0];
56            Some((field.ident.clone().unwrap(), field.ty.clone()))
57        }
58        _ => panic!("multiple fields with #[consumer_data] defined")
59    };
60
61    if let Some(cd) = consumer_data.take() {
62        let outer_type = cd.1.clone();
63        match outer_type {
64            syn::Type::Path(path) => {
65                'outer: for segment in path.path.segments.iter() {
66                    match &segment.arguments {
67                        syn::PathArguments::AngleBracketed(args) => {
68                            for arg in args.args.iter() {
69                                match arg {
70                                    syn::GenericArgument::Type(t) => {
71                                        consumer_data = Some((cd.0, t.clone()));
72                                        break 'outer;
73                                    }
74                                    _ => ()
75                                }
76                            }
77                        }
78                        _ => ()
79                    }
80                }
81            }
82            _ => ()
83        }
84    }
85
86    let has_worker = match consumer_data {
87        None => {
88            quote!{
89                impl HasWorker<()> for #ident {}
90            }
91        }
92        Some(ref cd) => {
93            let consumerdata_type = &cd.1;
94            quote! {
95                impl HasWorker<#consumerdata_type> for #ident {}
96            }
97        }
98    };
99
100    let consumers_ref = match consumers_list {
101        Some(cl) => {
102            let cl_ident = &cl.ident;
103            quote! {
104                std::mem::take(&mut self.#cl_ident)
105            }
106        }
107        None => {
108            quote!{
109                Vec::new()
110            }
111        }
112    };
113
114    let start_with = match consumer_data {
115        None => {
116            quote!{
117                fn start_with(&mut self, receiver: std::sync::mpsc::Receiver<std::sync::Arc<ScannerResult>>) {
118                    let dummy = Arc::new(());
119                    let consumers = #consumers_ref;
120                    let handle = std::thread::spawn(|| Self::worker(receiver, consumers, dummy));
121                    self.thread_handle = Some(handle);
122                }
123            }
124        }
125        Some(ref cd) => {
126            let consumerdata_name = &cd.0;
127            quote! {
128                fn start_with(&mut self, receiver: std::sync::mpsc::Receiver<std::sync::Arc<ScannerResult>>) {
129                    let data = Arc::clone(&self.#consumerdata_name);
130                    let consumers = std::mem::take(&mut self.consumers);
131                    let handle = std::thread::spawn(|| Self::worker(receiver, consumers, data));
132                    self.thread_handle = Some(handle);
133                }
134            }
135        }
136    };
137
138    let th_ident = &thread_handle.ident;
139
140    let output = quote! {
141        #has_worker
142
143        impl FileConsumer for #ident {
144            fn join(&mut self) {
145                if let Some(th) = self.#th_ident.take() {
146                    match th.join() {
147                        Err(why) => {
148                            log::error!("join: {:?}", why);
149                            // do not abort, instead also join() the remaining threads
150                            // return Err(Box::new(why));
151                        }
152                        Ok(_) => ()
153                    }
154                }
155            }
156            #start_with
157        }
158    };
159    output.into()
160}