1extern crate proc_macro;
2use std::path::Path;
3
4use convert_case::{Case, Casing};
5use cali_core::protos::parser::get_proto_data;
6use proc_macro::TokenStream;
7use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
8use quote::quote;
9use syn::{parse_macro_input, DeriveInput};
10
11#[proc_macro]
12pub fn autogen_protos(_item: TokenStream) -> TokenStream {
13 let gen = quote! {
14 let service_files: Vec<String> = std::fs::read_dir("../interface/grpc/services/")
15 .expect("Could not read contents of interface file")
16 .filter(|entry| entry.is_ok())
17 .map(|entry| entry.unwrap().path().to_str().unwrap().to_string())
18 .collect();
19
20 let out_path = std::path::Path::new("src/protos");
21 if !out_path.exists() {
22 let _ = std::fs::create_dir(out_path)
23 .expect(&format!("Unable to create protos folder {:?}", out_path));
24 }
25
26 if service_files.len() > 0 {
27 tonic_build::configure()
28 .build_server(true)
29 .out_dir(out_path.to_str().unwrap())
30 .compile(service_files.as_slice(), &["../interface/grpc/".to_string()])
31 .unwrap();
32 }
33
34 let path = std::path::Path::new("../interface/grpc/services");
36 let proto_data = cali_core::protos::parser::get_proto_data(&path).expect("Should have worked");
37 let mut mod_contents = "".to_string();
38 proto_data.services.iter().for_each(|service| {
39 let import_line = format!("pub mod {};\n", service.name.to_case(Case::Snake));
40 mod_contents.push_str(&import_line);
41 });
42 let mod_path = std::path::Path::new("src/protos/mod.rs");
43 std::fs::write(mod_path, mod_contents).expect("Could not write main file");
44 };
45 gen.into()
46}
47
48#[proc_macro]
49pub fn controller(input: TokenStream) -> TokenStream {
50 let controller_struct_name = Ident::new(&format!("{}", input)[..], Span::call_site());
51
52 let gen = quote! {
53 #[derive(Clone)]
54 pub struct #controller_struct_name {}
55
56 impl #controller_struct_name {
57 pub fn new() -> Self {
58 #controller_struct_name {}
59 }
60 }
61 };
62 gen.into()
63}
64#[proc_macro_derive(Ensnare)]
65pub fn derive_ensnare(input: TokenStream) -> TokenStream {
66 let input = parse_macro_input!(input as DeriveInput);
67 let struct_name = input.ident;
68
69 let struct_fields = match input.data {
70 syn::Data::Struct(syn::DataStruct {
71 fields: syn::Fields::Named(named_fields),
72 ..
73 }) => named_fields
74 .named
75 .iter()
76 .filter_map(|f| f.ident.clone())
77 .collect::<Vec<Ident>>(),
78 _ => {
79 panic!("Can only Ensnare struct types");
80 }
81 };
82
83 let bind_points = struct_fields
84 .iter()
85 .map(|_| "?".to_string())
86 .collect::<Vec<String>>()
87 .join(",");
88
89 let fields = struct_fields
90 .iter()
91 .map(|f| f.to_string())
92 .collect::<Vec<String>>()
93 .join(",");
94
95 let bindings: Vec<TokenStream2> = struct_fields
96 .iter()
97 .map(|f| quote!(bind(self.#f.clone())))
98 .collect();
99
100 let expanded = quote! {
101 impl cali_core::store::snare::Ensnarable for #struct_name {
102 fn insert_parts(&self) -> (String, String) {
103 (#fields.to_string(), #bind_points.to_string())
104 }
105
106 fn capture<'a>(
107 &'a self,
108 query: sqlx::query::Query<
109 'a,
110 sqlx::MySql,
111 <sqlx::MySql as sqlx::database::HasArguments<'_>>::Arguments,
112 >,
113 ) -> sqlx::query::Query<
114 'a,
115 sqlx::MySql,
116 <sqlx::MySql as sqlx::database::HasArguments<'_>>::Arguments,
117 > {
118 query.#(#bindings).*
119 }
120 }
121
122 impl #struct_name {
123 pub fn trap(self, table_name: &str) -> cali_core::store::snare::Snare<#struct_name> {
124 cali_core::store::snare::Snare {
125 query: "".to_string(),
126 table_name: table_name.to_string(),
127 data: self,
128 }
129 }
130 }
131 };
132
133 TokenStream::from(expanded)
134}
135
136#[proc_macro]
137pub fn setup_server(input: TokenStream) -> TokenStream {
138 let app_name: String;
139 let version: String;
140 let extentable_context: Ident;
141
142 let input = proc_macro2::TokenStream::from(input);
143 let mut params_stream = input.into_iter();
144
145 if let Some(proc_macro2::TokenTree::Literal(val)) = params_stream.next() {
146 let temp = format!("{}", val);
147 params_stream.next(); app_name = temp[1..temp.len() - 1].to_string();
149 } else {
150 panic!("Please add an application name")
151 }
152
153 if let Some(proc_macro2::TokenTree::Literal(val)) = params_stream.next() {
154 let temp = format!("{}", val);
155 version = temp[1..temp.len() - 1].to_string();
156 params_stream.next(); } else {
158 panic!("Please add a version")
159 }
160
161 if let Some(proc_macro2::TokenTree::Ident(val)) = params_stream.next() {
162 extentable_context = val;
163 } else {
164 panic!("An extentable_context has to be provided")
165 }
166
167 let path = Path::new("./interface/grpc/services");
168 let proto_data = get_proto_data(&path).expect("Should have worked");
169
170 let web_crate = Ident::new(&format!("{}_web", app_name)[..], Span::call_site());
171
172 let controllers: Vec<proc_macro2::TokenStream> = proto_data
173 .services
174 .iter()
175 .map(|service| {
176 let controller_var_name = Ident::new(
177 &format!("{}_controller", service.name.to_case(Case::Snake))[..],
178 Span::call_site(),
179 );
180
181 let controller_snake_name = Ident::new(
182 &format!("{}", service.name.to_case(Case::Snake))[..],
183 Span::call_site(),
184 );
185
186 let controller_name = Ident::new(
187 &format!(
188 "{}Controller",
189 service.name.to_case(Case::UpperCamel)
190 )[..],
191 Span::call_site(),
192 );
193
194 quote! {
195 let #controller_var_name = #web_crate::controllers::#controller_snake_name::#controller_name::new();
196 }
197 })
198 .collect();
199
200 let services: Vec<proc_macro2::TokenStream> = proto_data
201 .services
202 .iter()
203 .map(|service| {
204 let controller_var_name = Ident::new(
205 &format!("{}_controller", service.name.to_case(Case::Snake))[..],
206 Span::call_site(),
207 );
208 let service_name = Ident::new(
209 &format!("{}Server", service.name.to_case(Case::UpperCamel))[..],
210 Span::call_site(),
211 );
212
213 let controller_snake_name = Ident::new(
214 &format!("{}", service.name.to_case(Case::Snake))[..],
215 Span::call_site(),
216 );
217 let server_snake_name = Ident::new(
218 &format!("{}_server", service.name.to_case(Case::Snake))[..],
219 Span::call_site(),
220 );
221
222 quote! {
223 .add_service(#web_crate::protos::#controller_snake_name::#server_snake_name::#service_name::new(#controller_var_name))
224 }
225 })
226 .collect();
227
228 let mut body = quote! {
229 cali_core::logging::util::setup();
231
232 log::info!("Getting ready...");
233 let matches = clap::App::new(#app_name)
235 .version(#version)
236 .arg(
237 clap::Arg::with_name("config")
238 .short('c')
239 .long("config")
240 .value_name("FILE")
241 .help("Sets a custom config file")
242 .default_value("./web/config/dev.yml")
243 .takes_value(true),
244 )
245 .get_matches();
246
247
248 log::info!("Loading config...");
250 let config_file = std::fs::File::open(matches.value_of("config")
251 .expect("No value set for config path"))
252 .expect("Could not open config file at web/config/dev.yml");
253
254 let config: std::sync::Arc<Config> = std::sync::Arc::new({
255 let deserializer = serde_yaml::Deserializer::from_reader(config_file);
256 let config: Config = serde_ignored::deserialize(deserializer, |path| {
257 log::warn!("Unused config field: {}", path);
258 })
259 .expect("Could not deserialize config");
260 config
262 });
263
264 log::info!("Connecting to DB...");
265 let db_pool = sqlx::mysql::MySqlPoolOptions::new()
266 .max_connections(config.database.num_connections)
267 .test_before_acquire(true)
268 .connect(&config.database.url)
269 .await?;
270
271 let server_ctx : std::sync::Arc<cali_core::ServerContext> = std::sync::Arc::new(cali_core::ServerContext { db_pool });
272
273 let context_layer = cali_core::middleware::server_context::ServerContextLayer {
274 config: config.clone(),
275 extentable_context: #extentable_context.clone(),
276 internal_context: server_ctx.clone()
277 };
278 };
279
280 let grpc_segment = quote! {
281 #(#controllers)*
282
283 let (host, port) = cali_core::helpers::split_host_and_port(&config.bind_address);
284 let addr = format!("{}:{}", host, port);
285
286 let server = tonic::transport::Server::builder()
287 .layer(context_layer)
288 #(#services)*;
289
290 log::info!("GRPC server started, waiting for requests...");
291 let mut interrupt_signal = tokio::signal::ctrl_c();
292 let closer = async move {
293 let _ = interrupt_signal.await;
294 log::info!("Good bye!");
295 };
296
297 server
298 .serve_with_shutdown(
299 std::net::SocketAddr::from_str(&addr[..]).unwrap(),
300 async move {
301 let _ = closer.await;
303 },
304 )
305 .await?;
306 };
307
308 if services.len() > 0 {
309 body.extend(grpc_segment);
310 }
311
312 body.into()
313}
314
315#[proc_macro]
316pub fn test_runner(_input: TokenStream) -> TokenStream {
317 let test_setup_body = quote! {
320 pub async fn run(config_file: &str, test: impl std::future::Future<Output = ()>) -> () {
321 cali_core::logging::util::setup();
322
323 let config_file = std::fs::File::open(config_file).expect("Could not open config file");
324
325 let config = {
326 let deserializer = serde_yaml::Deserializer::from_reader(config_file);
327 let config: Config =
328 serde_ignored::deserialize(deserializer, |_| {}).expect("Could not deserialize config");
329 config
330 };
331
332 let db_url = url::Url::parse(&config.clone().database.url).expect("Unable to parse DB url");
333
334 let pool = sqlx::MySqlPool::connect(&db_url[..url::Position::BeforePath])
336 .await
337 .unwrap();
338
339 let db_name = db_url
340 .path_segments()
341 .expect("No database specified")
342 .next()
343 .expect("No database specified");
344
345 let drop_query = format!("DROP DATABASE IF EXISTS {}", db_name);
347 sqlx::query(&drop_query).execute(&pool).await.unwrap();
348
349 let create_query = format!("CREATE DATABASE IF NOT EXISTS {}", db_name);
351 sqlx::query(&create_query).execute(&pool).await.unwrap();
352
353 let pool = sqlx::MySqlPool::connect(&db_url.to_string()).await.unwrap();
355
356 sqlx::migrate!("../store/migrations")
357 .run(&pool)
358 .await
359 .expect("Expected to be able to run migrations");
360
361 let db_pool = sqlx::mysql::MySqlPoolOptions::new()
362 .max_connections(1)
363 .test_before_acquire(true)
364 .connect(&config.clone().database.url)
365 .await
366 .expect("Couldn't connect to test database");
367
368 let mut context: std::collections::HashMap<std::any::TypeId, cali_core::MapKey> =
369 std::collections::HashMap::new();
370
371 context.insert(
372 std::any::TypeId::of::<cali_core::ServerContext>(),
373 std::sync::Arc::new(cali_core::ServerContext { db_pool }),
374 );
375
376
377 context.insert(
378 std::any::TypeId::of::<Config>(),
379 std::sync::Arc::new(config),
380 );
381
382 cali_core::SERVER_CONTEXT
383 .scope(std::sync::Arc::new(context), test)
384 .await
385 }
386 };
387
388 test_setup_body.into()
389}