cali_derive/
lib.rs

1extern crate proc_macro;
2use std::path::Path;
3
4use convert_case::{Case, Casing};
5use cali_core::protos::parser::get_proto_data;
6use proc_macro::TokenStream;
7use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
8use quote::quote;
9use syn::{parse_macro_input, DeriveInput};
10
11#[proc_macro]
12pub fn autogen_protos(_item: TokenStream) -> TokenStream {
13    let gen = quote! {
14        let service_files: Vec<String> = std::fs::read_dir("../interface/grpc/services/")
15            .expect("Could not read contents of interface file")
16            .filter(|entry| entry.is_ok())
17            .map(|entry| entry.unwrap().path().to_str().unwrap().to_string())
18            .collect();
19
20        let out_path = std::path::Path::new("src/protos");
21        if !out_path.exists() {
22            let _ = std::fs::create_dir(out_path)
23                .expect(&format!("Unable to create protos folder {:?}", out_path));
24        }
25
26        if service_files.len() > 0 {
27            tonic_build::configure()
28                .build_server(true)
29                .out_dir(out_path.to_str().unwrap())
30                .compile(service_files.as_slice(), &["../interface/grpc/".to_string()])
31                .unwrap();
32        }
33
34        // build the protos mod.rs
35        let path = std::path::Path::new("../interface/grpc/services");
36        let proto_data = cali_core::protos::parser::get_proto_data(&path).expect("Should have worked");
37        let mut mod_contents = "".to_string();
38        proto_data.services.iter().for_each(|service| {
39            let import_line = format!("pub mod {};\n", service.name.to_case(Case::Snake));
40            mod_contents.push_str(&import_line);
41        });
42        let mod_path = std::path::Path::new("src/protos/mod.rs");
43        std::fs::write(mod_path, mod_contents).expect("Could not write main file");
44    };
45    gen.into()
46}
47
48#[proc_macro]
49pub fn controller(input: TokenStream) -> TokenStream {
50    let controller_struct_name = Ident::new(&format!("{}", input)[..], Span::call_site());
51
52    let gen = quote! {
53        #[derive(Clone)]
54        pub struct #controller_struct_name {}
55
56        impl #controller_struct_name {
57            pub fn new() -> Self {
58                #controller_struct_name {}
59            }
60        }
61    };
62    gen.into()
63}
64#[proc_macro_derive(Ensnare)]
65pub fn derive_ensnare(input: TokenStream) -> TokenStream {
66    let input = parse_macro_input!(input as DeriveInput);
67    let struct_name = input.ident;
68
69    let struct_fields = match input.data {
70        syn::Data::Struct(syn::DataStruct {
71            fields: syn::Fields::Named(named_fields),
72            ..
73        }) => named_fields
74            .named
75            .iter()
76            .filter_map(|f| f.ident.clone())
77            .collect::<Vec<Ident>>(),
78        _ => {
79            panic!("Can only Ensnare struct types");
80        }
81    };
82
83    let bind_points = struct_fields
84        .iter()
85        .map(|_| "?".to_string())
86        .collect::<Vec<String>>()
87        .join(",");
88
89    let fields = struct_fields
90        .iter()
91        .map(|f| f.to_string())
92        .collect::<Vec<String>>()
93        .join(",");
94
95    let bindings: Vec<TokenStream2> = struct_fields
96        .iter()
97        .map(|f| quote!(bind(self.#f.clone())))
98        .collect();
99
100    let expanded = quote! {
101    impl cali_core::store::snare::Ensnarable for #struct_name {
102                fn insert_parts(&self) -> (String, String) {
103                    (#fields.to_string(), #bind_points.to_string())
104                }
105
106                fn capture<'a>(
107                    &'a self,
108                    query: sqlx::query::Query<
109                        'a,
110                        sqlx::MySql,
111                        <sqlx::MySql as sqlx::database::HasArguments<'_>>::Arguments,
112                    >,
113                ) -> sqlx::query::Query<
114                    'a,
115                    sqlx::MySql,
116                    <sqlx::MySql as sqlx::database::HasArguments<'_>>::Arguments,
117                > {
118                    query.#(#bindings).*
119                }
120            }
121
122            impl #struct_name {
123                pub fn trap(self, table_name: &str) -> cali_core::store::snare::Snare<#struct_name> {
124                    cali_core::store::snare::Snare {
125                        query: "".to_string(),
126                        table_name: table_name.to_string(),
127                        data: self,
128                    }
129                }
130            }
131        };
132
133    TokenStream::from(expanded)
134}
135
136#[proc_macro]
137pub fn setup_server(input: TokenStream) -> TokenStream {
138    let app_name: String;
139    let version: String;
140    let extentable_context: Ident;
141
142    let input = proc_macro2::TokenStream::from(input);
143    let mut params_stream = input.into_iter();
144
145    if let Some(proc_macro2::TokenTree::Literal(val)) = params_stream.next() {
146        let temp = format!("{}", val);
147        params_stream.next(); // Skip the comma
148        app_name = temp[1..temp.len() - 1].to_string();
149    } else {
150        panic!("Please add an application name")
151    }
152
153    if let Some(proc_macro2::TokenTree::Literal(val)) = params_stream.next() {
154        let temp = format!("{}", val);
155        version = temp[1..temp.len() - 1].to_string();
156        params_stream.next(); // Skip the comma
157    } else {
158        panic!("Please add a version")
159    }
160
161    if let Some(proc_macro2::TokenTree::Ident(val)) = params_stream.next() {
162        extentable_context = val;
163    } else {
164        panic!("An extentable_context has to be provided")
165    }
166
167    let path = Path::new("./interface/grpc/services");
168    let proto_data = get_proto_data(&path).expect("Should have worked");
169
170    let web_crate = Ident::new(&format!("{}_web", app_name)[..], Span::call_site());
171
172    let controllers: Vec<proc_macro2::TokenStream> = proto_data
173        .services
174        .iter()
175        .map(|service| {
176            let controller_var_name = Ident::new(
177                &format!("{}_controller", service.name.to_case(Case::Snake))[..],
178                Span::call_site(),
179            );
180
181            let controller_snake_name = Ident::new(
182                &format!("{}", service.name.to_case(Case::Snake))[..],
183                Span::call_site(),
184            );
185
186            let controller_name = Ident::new(
187                &format!(
188                    "{}Controller",
189                    service.name.to_case(Case::UpperCamel)
190                )[..],
191                Span::call_site(),
192            );
193
194            quote! {
195                let #controller_var_name = #web_crate::controllers::#controller_snake_name::#controller_name::new();
196            }
197        })
198        .collect();
199
200    let services: Vec<proc_macro2::TokenStream> = proto_data
201        .services
202        .iter()
203        .map(|service| {
204            let controller_var_name = Ident::new(
205                &format!("{}_controller", service.name.to_case(Case::Snake))[..],
206                Span::call_site(),
207            );
208            let service_name = Ident::new(
209                &format!("{}Server", service.name.to_case(Case::UpperCamel))[..],
210                Span::call_site(),
211            );
212
213            let controller_snake_name = Ident::new(
214                &format!("{}", service.name.to_case(Case::Snake))[..],
215                Span::call_site(),
216            );
217            let server_snake_name = Ident::new(
218                &format!("{}_server", service.name.to_case(Case::Snake))[..],
219                Span::call_site(),
220            );
221
222            quote! {
223                .add_service(#web_crate::protos::#controller_snake_name::#server_snake_name::#service_name::new(#controller_var_name))
224            }
225        })
226        .collect();
227
228    let mut body = quote! {
229        // Setup logging
230        cali_core::logging::util::setup();
231
232        log::info!("Getting ready...");
233        // Configure CLI App
234        let matches = clap::App::new(#app_name)
235            .version(#version)
236            .arg(
237                clap::Arg::with_name("config")
238                    .short('c')
239                    .long("config")
240                    .value_name("FILE")
241                    .help("Sets a custom config file")
242                    .default_value("./web/config/dev.yml")
243                    .takes_value(true),
244            )
245            .get_matches();
246
247
248        // Setup Config File
249        log::info!("Loading config...");
250        let config_file = std::fs::File::open(matches.value_of("config")
251                            .expect("No value set for config path"))
252                            .expect("Could not open config file at web/config/dev.yml");
253
254        let config: std::sync::Arc<Config> = std::sync::Arc::new({
255            let deserializer = serde_yaml::Deserializer::from_reader(config_file);
256            let config: Config = serde_ignored::deserialize(deserializer, |path| {
257                log::warn!("Unused config field: {}", path);
258            })
259            .expect("Could not deserialize config");
260            // Edit config here if you want to
261            config
262        });
263
264        log::info!("Connecting to DB...");
265        let db_pool = sqlx::mysql::MySqlPoolOptions::new()
266            .max_connections(config.database.num_connections)
267            .test_before_acquire(true)
268            .connect(&config.database.url)
269            .await?;
270
271        let server_ctx : std::sync::Arc<cali_core::ServerContext> = std::sync::Arc::new(cali_core::ServerContext { db_pool });
272
273        let context_layer = cali_core::middleware::server_context::ServerContextLayer {
274            config: config.clone(),
275            extentable_context: #extentable_context.clone(),
276            internal_context: server_ctx.clone()
277        };
278    };
279
280    let grpc_segment = quote! {
281        #(#controllers)*
282
283        let (host, port) = cali_core::helpers::split_host_and_port(&config.bind_address);
284        let addr = format!("{}:{}", host, port);
285
286        let server = tonic::transport::Server::builder()
287            .layer(context_layer)
288            #(#services)*;
289
290        log::info!("GRPC server started, waiting for requests...");
291        let mut interrupt_signal = tokio::signal::ctrl_c();
292        let closer = async move {
293            let _ = interrupt_signal.await;
294            log::info!("Good bye!");
295        };
296
297        server
298            .serve_with_shutdown(
299                std::net::SocketAddr::from_str(&addr[..]).unwrap(),
300                async move {
301                    // Add closers for other processes
302                    let _ = closer.await;
303                },
304            )
305            .await?;
306    };
307
308    if services.len() > 0 {
309        body.extend(grpc_segment);
310    }
311
312    body.into()
313}
314
315#[proc_macro]
316pub fn test_runner(_input: TokenStream) -> TokenStream {
317    // Rather let this return a wrapping type called test context under cali core?
318    // That way I can implement the drop trait on that type and clean up test databases that way?
319    let test_setup_body = quote! {
320         pub async fn run(config_file: &str, test: impl std::future::Future<Output = ()>) -> () {
321        cali_core::logging::util::setup();
322
323        let config_file = std::fs::File::open(config_file).expect("Could not open config file");
324
325        let config = {
326            let deserializer = serde_yaml::Deserializer::from_reader(config_file);
327            let config: Config =
328                serde_ignored::deserialize(deserializer, |_| {}).expect("Could not deserialize config");
329            config
330        };
331
332        let db_url = url::Url::parse(&config.clone().database.url).expect("Unable to parse DB url");
333
334        // Create the database
335        let pool = sqlx::MySqlPool::connect(&db_url[..url::Position::BeforePath])
336            .await
337            .unwrap();
338
339        let db_name = db_url
340            .path_segments()
341            .expect("No database specified")
342            .next()
343            .expect("No database specified");
344
345        // Delete the existing database
346        let drop_query = format!("DROP DATABASE IF EXISTS {}", db_name);
347        sqlx::query(&drop_query).execute(&pool).await.unwrap();
348
349        // Recreate it
350        let create_query = format!("CREATE DATABASE IF NOT EXISTS {}", db_name);
351        sqlx::query(&create_query).execute(&pool).await.unwrap();
352
353        // Run all migrations
354        let pool = sqlx::MySqlPool::connect(&db_url.to_string()).await.unwrap();
355
356        sqlx::migrate!("../store/migrations")
357            .run(&pool)
358            .await
359            .expect("Expected to be able to run migrations");
360
361        let db_pool = sqlx::mysql::MySqlPoolOptions::new()
362            .max_connections(1)
363            .test_before_acquire(true)
364            .connect(&config.clone().database.url)
365            .await
366            .expect("Couldn't connect to test database");
367
368        let mut context: std::collections::HashMap<std::any::TypeId, cali_core::MapKey> =
369            std::collections::HashMap::new();
370
371        context.insert(
372            std::any::TypeId::of::<cali_core::ServerContext>(),
373            std::sync::Arc::new(cali_core::ServerContext { db_pool }),
374        );
375
376
377        context.insert(
378            std::any::TypeId::of::<Config>(),
379            std::sync::Arc::new(config),
380        );
381
382        cali_core::SERVER_CONTEXT
383            .scope(std::sync::Arc::new(context), test)
384            .await
385    }
386            };
387
388    test_setup_body.into()
389}