1use proc_macro::{self, TokenStream};
2use quote::quote;
3
4mod find_fields;
5use find_fields::*;
6
7#[proc_macro_derive(FileProvider, attributes(consumers_list))]
8pub fn derive_file_provider(input: TokenStream) -> TokenStream {
9 let ast: syn::DeriveInput = syn::parse(input).unwrap();
10 let ident = &ast.ident;
11
12 let fields = find_fields_by_attrname(&ast, "consumers_list");
13 let consumers_list = match fields.len() {
14 0 => panic!("no field with attribute consumers_list found"),
15 1 => &fields[0],
16 _ => panic!("multiple fields with #[consumers_list] defined")
17 };
18
19 let cl_ident = &consumers_list.ident;
20
21 let output = quote! {
22 impl FileProvider for #ident {
23 fn register_consumer(&mut self, consumer: Box<dyn FileConsumer>) {
24 self.#cl_ident.push(consumer);
25 }
26 }
27 };
28 output.into()
29}
30
31#[proc_macro_derive(FileConsumer, attributes(consumer_data, thread_handle))]
32pub fn derive_file_consumer(input: TokenStream) -> TokenStream {
33 let ast: syn::DeriveInput = syn::parse(input).unwrap();
34 let ident = &ast.ident;
35
36 let fields = find_fields_by_attrname(&ast, "consumers_list");
37 let consumers_list = match fields.len() {
38 0 => None,
39 1 => fields.into_iter().next(),
40 _ => panic!("multiple fields with #[consumers_list] defined")
41 };
42
43 let fields = find_fields_by_attrname(&ast, "thread_handle");
44 let thread_handle = match fields.len() {
45 0 => panic!("no field with attribute thread_handle found"),
46 1 => &fields[0],
47 _ => panic!("multiple fields with #[thread_handle] defined")
48 };
49
50
51 let fields = find_fields_by_attrname(&ast, "consumer_data");
52 let mut consumer_data = match fields.len() {
53 0 => None,
54 1 => {
55 let field = &fields[0];
56 Some((field.ident.clone().unwrap(), field.ty.clone()))
57 }
58 _ => panic!("multiple fields with #[consumer_data] defined")
59 };
60
61 if let Some(cd) = consumer_data.take() {
62 let outer_type = cd.1.clone();
63 match outer_type {
64 syn::Type::Path(path) => {
65 'outer: for segment in path.path.segments.iter() {
66 match &segment.arguments {
67 syn::PathArguments::AngleBracketed(args) => {
68 for arg in args.args.iter() {
69 match arg {
70 syn::GenericArgument::Type(t) => {
71 consumer_data = Some((cd.0, t.clone()));
72 break 'outer;
73 }
74 _ => ()
75 }
76 }
77 }
78 _ => ()
79 }
80 }
81 }
82 _ => ()
83 }
84 }
85
86 let has_worker = match consumer_data {
87 None => {
88 quote!{
89 impl HasWorker<()> for #ident {}
90 }
91 }
92 Some(ref cd) => {
93 let consumerdata_type = &cd.1;
94 quote! {
95 impl HasWorker<#consumerdata_type> for #ident {}
96 }
97 }
98 };
99
100 let consumers_ref = match consumers_list {
101 Some(cl) => {
102 let cl_ident = &cl.ident;
103 quote! {
104 std::mem::take(&mut self.#cl_ident)
105 }
106 }
107 None => {
108 quote!{
109 Vec::new()
110 }
111 }
112 };
113
114 let start_with = match consumer_data {
115 None => {
116 quote!{
117 fn start_with(&mut self, receiver: std::sync::mpsc::Receiver<std::sync::Arc<ScannerResult>>) {
118 let dummy = Arc::new(());
119 let consumers = #consumers_ref;
120 let handle = std::thread::spawn(|| Self::worker(receiver, consumers, dummy));
121 self.thread_handle = Some(handle);
122 }
123 }
124 }
125 Some(ref cd) => {
126 let consumerdata_name = &cd.0;
127 quote! {
128 fn start_with(&mut self, receiver: std::sync::mpsc::Receiver<std::sync::Arc<ScannerResult>>) {
129 let data = Arc::clone(&self.#consumerdata_name);
130 let consumers = std::mem::take(&mut self.consumers);
131 let handle = std::thread::spawn(|| Self::worker(receiver, consumers, data));
132 self.thread_handle = Some(handle);
133 }
134 }
135 }
136 };
137
138 let th_ident = &thread_handle.ident;
139
140 let output = quote! {
141 #has_worker
142
143 impl FileConsumer for #ident {
144 fn join(&mut self) {
145 if let Some(th) = self.#th_ident.take() {
146 match th.join() {
147 Err(why) => {
148 log::error!("join: {:?}", why);
149 }
152 Ok(_) => ()
153 }
154 }
155 }
156 #start_with
157 }
158 };
159 output.into()
160}