gn_redisadapter_derive/
lib.rs

1extern crate proc_macro;
2
3use lazy_static::lazy_static;
4use proc_macro::TokenStream;
5use quote::quote;
6use std::sync::Mutex;
7use syn::DeriveInput;
8use syn::{self, Ident};
9
10struct SafeDataStruct {
11    type_name: String,
12    impl_type: ImplType,
13    db_name: String,
14}
15
/// Tags a registry entry with the derive macro that produced it.
#[derive(PartialEq)]
enum ImplType {
    /// Recorded by `#[derive(RedisInsertWriter)]`.
    InsertWriter,
    /// Recorded by `#[derive(RedisOutputReader)]`.
    OutputReader,
    /// Recorded by `#[derive(RedisIdentifiable)]`.
    Identifiable,
    /// Recorded by `#[derive(RedisUpdater)]`.
    Updater,
}
23
24lazy_static! {
25    static ref DB_STRUCTS: Mutex<Vec<SafeDataStruct>> = Mutex::new(Vec::new());
26}
27
28#[proc_macro_derive(RedisInsertWriter, attributes(name))]
29pub fn insert_writer_derive(input: TokenStream) -> TokenStream {
30    let ast: DeriveInput = syn::parse(input).unwrap();
31    insert_new_struct(&ast, ImplType::InsertWriter);
32    impl_insert_writer(&ast)
33}
34
35#[proc_macro_derive(RedisOutputReader, attributes(uuid))]
36pub fn output_reader_derive(input: TokenStream) -> TokenStream {
37    let ast = syn::parse(input).unwrap();
38    insert_new_struct(&ast, ImplType::OutputReader);
39    impl_output_reader(&ast)
40}
41
42#[proc_macro_derive(RedisIdentifiable, attributes(name, single_instance))]
43pub fn identifiable_derive(input: TokenStream) -> TokenStream {
44    let ast = syn::parse(input).unwrap();
45    insert_new_struct(&ast, ImplType::Identifiable);
46    impl_identifiable(&ast)
47}
48
49#[proc_macro_derive(RedisUpdater, attributes(name))]
50pub fn updater_derive(input: TokenStream) -> TokenStream {
51    let ast = syn::parse(input).unwrap();
52
53    insert_new_struct(&ast, ImplType::Updater);
54    for strct in DB_STRUCTS.lock().unwrap().iter() {
55        if strct.impl_type == ImplType::InsertWriter && strct.db_name == get_name_attr(&ast) {
56            return impl_updater(&ast, &strct);
57        }
58    }
59    panic!("No parent struct found for updater. Please make sure the parent struct has been defined before the updater.");
60}
61
62fn impl_insert_writer(ast: &syn::DeriveInput) -> TokenStream {
63    let name = &ast.ident;
64    let data = match &ast.data {
65        syn::Data::Struct(data) => data,
66        _ => panic!("Only structs are supported"),
67    };
68
69    let sets: Vec<proc_macro2::TokenStream> = data.fields.iter().map(|field| {
70        let field_name = field.ident.as_ref().unwrap();
71        quote! {
72            self.#field_name.write(pipe, format!("{base_key}:{}", stringify!(#field_name)).as_str())?;
73        }
74
75    }).collect();
76
77    let expire_sets: Vec<proc_macro2::TokenStream> = data.fields.iter().map(|field| {
78        let field_name = field.ident.as_ref().unwrap();
79        quote! {
80            pipe.expire(format!("{base_key}:{}", stringify!(#field_name)).as_str(), timeout);
81        }
82
83    }).collect();
84
85
86
87    let gen = quote! {
88            impl gn_matchmaking_state::adapters::redis::RedisInsertWriter for #name {
89                fn write(&self, pipe: &mut gn_matchmaking_state::adapters::redis::Pipeline, base_key: &str) -> Result<(), Box<dyn std::error::Error>> {
90                    #(#sets)*
91                    Ok(())
92                }
93            }
94
95            impl gn_matchmaking_state::adapters::redis::RedisExpireable for #name {
96                fn expire(&self, pipe: &mut gn_matchmaking_state::adapters::redis::Pipeline, base_key: &str, timeout: i64) -> Result<(), Box<dyn std::error::Error>> {
97                    #(#expire_sets)*
98                    Ok(())
99                }
100            }
101    };
102    gen.into()
103}
104
105fn impl_output_reader(ast: &syn::DeriveInput) -> TokenStream {
106    let name = &ast.ident;
107    let data = match &ast.data {
108        syn::Data::Struct(data) => data,
109        _ => panic!("Only structs are supported"),
110    };
111
112    let mut uuid_field = Option::None;
113
114    let found = data
115        .fields
116        .iter()
117        .find(|x| x.attrs.iter().any(|x| x.path.is_ident("uuid")));
118    if let Some(found) = found {
119        uuid_field = Some(found.ident.as_ref().unwrap());
120    }
121
122    let sets: Vec<proc_macro2::TokenStream> = data
123        .fields
124        .iter()
125        .filter(|x| uuid_field == None || x.ident.as_ref().unwrap() != uuid_field.unwrap())
126        .map(|field| {
127            let field_name = field.ident.as_ref().unwrap();
128            let ty = &field.ty;
129            quote! {
130                #field_name: <#ty as gn_matchmaking_state::adapters::redis::RedisOutputReader>::read(connection, &format!("{base_key}:{}", stringify!(#field_name)))?
131            }
132        })
133        .collect();
134
135    let uuid_code = match uuid_field {
136        Some(field) => quote! {
137            #field: base_key.to_owned(),
138        },
139        None => quote! {},
140    };
141
142    let gen = quote! {
143        impl gn_matchmaking_state::adapters::redis::RedisOutputReader for #name {
144            fn read(connection: &mut gn_matchmaking_state::adapters::redis::Connection, base_key: &str) -> Result<Self, Box<dyn std::error::Error>> {
145                Ok(Self {
146                    #uuid_code
147                    #(#sets),*
148                })
149            }
150    }
151    };
152
153    gen.into()
154}
155
156fn impl_identifiable(ast: &syn::DeriveInput) -> TokenStream {
157    let name = &ast.ident;
158
159    let db_name = get_name_attr(ast);
160
161    let mut single_instance = false;
162    ast.attrs.iter().for_each(|attr| {
163        if attr.path.is_ident("single_instance") {
164            single_instance = attr.parse_args::<syn::LitBool>().unwrap().value();
165        }
166    });
167
168    let next_uuid = match single_instance {
169        true => quote! {
170            fn next_uuid(connection: &mut gn_matchmaking_state::adapters::redis::Connection) -> Result<String, Box<dyn std::error::Error>> {
171                Ok(format!("-1:{}", Self::name()))
172            }
173        },
174        false => quote! {},
175    };
176
177    let gen = quote! {
178        impl gn_matchmaking_state::adapters::redis::RedisIdentifiable for #name {
179            fn name() -> String {
180                #db_name.to_owned()
181            }
182
183            #next_uuid
184    }
185    };
186    gen.into()
187}
188
189fn impl_updater(ast: &syn::DeriveInput, parent: &SafeDataStruct) -> TokenStream {
190    let name = &ast.ident;
191    let data = match &ast.data {
192        syn::Data::Struct(data) => data,
193        _ => panic!("Only structs are supported"),
194    };
195
196    let sets: Vec<proc_macro2::TokenStream> = data.fields.iter().map(|field| {
197        let field_name = field.ident.as_ref().unwrap();
198        quote! {
199            if self.#field_name.is_some() {
200                self.#field_name.clone().unwrap().write(pipe, format!("{uuid}:{}", stringify!(#field_name)).as_str())?;
201            }
202        }
203    }).collect();
204
205    let option_conversion: Vec<proc_macro2::TokenStream> = data
206        .fields
207        .iter()
208        .map(|field| {
209            let field_name = field.ident.as_ref().unwrap();
210            quote! {
211                #field_name: Some(parent.#field_name.clone()),
212            }
213        })
214        .collect();
215
216    let mut updater_name = format!("{}Updater", name.to_string());
217    ast.attrs.iter().for_each(|attr| {
218        if attr.path.is_ident("update_struct") {
219            updater_name = attr.parse_args::<syn::LitStr>().unwrap().value();
220        }
221    });
222
223    let parent_ident = Ident::new(&parent.type_name, name.span());
224    let gen = quote! {
225
226            impl gn_matchmaking_state::adapters::redis::RedisUpdater<#parent_ident> for #name {
227                fn update(&self, pipe: &mut gn_matchmaking_state::adapters::redis::Pipeline, uuid: &str) -> Result<(), Box<dyn std::error::Error>> {
228                    use gn_matchmaking_state::adapters::redis::RedisInsertWriter;
229                    #(#sets)*
230                    Ok(())
231                }
232            }
233
234            impl From<#parent_ident> for #name {
235                fn from(parent: #parent_ident) -> Self {
236                    Self {
237                        #(#option_conversion)*
238                    }
239                }
240            }
241    };
242    gen.into()
243}
244
245fn get_name_attr(ast: &syn::DeriveInput) -> String {
246    for attr in ast.attrs.iter() {
247        if attr.path.is_ident("name") {
248            return attr.parse_args::<syn::LitStr>().unwrap().value();
249        }
250    }
251    let name = &ast.ident;
252    format!("{}s", name.to_string().to_lowercase())
253}
254
255fn insert_new_struct(ast: &syn::DeriveInput, impl_type: ImplType) {
256    let mut db_structs = DB_STRUCTS.lock().unwrap();
257    let db_name = get_name_attr(ast);
258
259    let safe_struct = SafeDataStruct {
260        type_name: ast.ident.to_string(),
261        impl_type,
262        db_name,
263    };
264    db_structs.push(safe_struct);
265}