Skip to main content

milrouter_macros/
lib.rs

1use {
2    crate::helpers::parse_attrs,
3    heck::{AsPascalCase, AsSnekCase},
4    helpers::{RouteInfo, get_inner_type, parse_fn_args, preamble, unit},
5    proc_macro::{Span, TokenStream},
6    quote::{ToTokens, format_ident, quote},
7    syn::{DeriveInput, FnArg, parse_macro_input},
8};
9
10#[macro_use]
11mod helpers;
12
13/// ### Arguments
14/// Takes 2 k=v arguments:
15/// - `is_idempotent` — __Optional__<br>
16///   Idempotency is defalted to false<br>
17///   Providing `is_idempotent` is sufficient (no `= true` needed)<br>
18///   See HTTP spec (https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.2)
19///
20///
21/// - `auth` — __Required__<br>
22///   Function to determine which request are allowed through based on headers.<br>
23///   The inner type returned (e.g unit in this case) is
24///   passed to the `client` variable of the endpoint, if it exists
25///
26/// ### Function Signature
27/// - #### Parameters (All optional):
28///   - `client`<br>
29///     Type derived from headers via `auth`<br>
30///   - param (name derived from use)<br>
31///     Type must be serializable, and representable in JSON (via serde-json).<br>
32///     If undefined, becomes unit `()`
33///
34/// ### Example:
35/// ```rust
36/// #[endpoint(is_idempotent = false, auth = auth_handler)]
37/// fn route(client: (), param1: ()) -> anyhow::Result<String> {}
38/// ```
39#[proc_macro_attribute]
40pub fn endpoint(annot: TokenStream, item: TokenStream) -> TokenStream {
41    let it = item.clone();
42    let meta = parse_macro_input!(it as syn::ItemFn);
43
44    let name = meta.sig.clone().ident;
45    let ret = meta.sig.clone().output;
46    let block = meta.block;
47
48    let args = parse_fn_args(
49        meta.sig
50            .inputs
51            .iter()
52            .map(|a| {
53                let a = match a {
54                    FnArg::Typed(t) => t,
55                    _ => panic!("Unexpected self type in endpoint"),
56                };
57
58                let ident = match *a.clone().pat {
59                    syn::Pat::Ident(pat_ident) => pat_ident.ident,
60                    _ => unreachable!(),
61                };
62
63                let ty = *a.clone().ty;
64
65                (ident, ty)
66            })
67            .collect::<Vec<_>>(),
68    );
69
70    let info = err!(RouteInfo::parse(annot.into()));
71    let (idempotent, auth) = (info.is_idempotent, info.auth);
72
73    let method = match idempotent {
74        true => "PUT",
75        false => "POST",
76    };
77
78    let inner_ret = match meta.sig.clone().output {
79        syn::ReturnType::Type(_, ty) => *ty,
80        _ => unreachable!(),
81    };
82
83    let inner_ret = err!(get_inner_type(inner_ret.clone()).map_err(|e| {
84        syn::Error::new_spanned(ret.to_token_stream(), format!("Unexpected return type (should be anyhow::Result<T>).\n{e}"))
85    }));
86
87    let struct_name = quote::format_ident!("Endpoint{}", AsPascalCase(name.to_string()).to_string());
88
89    let data = args.clone().input.1;
90    let client_type = args.client.clone().map(|c| c.1).unwrap_or(unit());
91    let args = args.to_tokens();
92
93    quote::quote! {
94        #[doc = concat!("Endpoint Struct for [", stringify!(#name) ,"]\n@ ", stringify!(#method), " -> ", stringify!(#struct_name), "::Data ([", stringify!(#ret), "])")]
95        #[derive(Clone)]
96        pub struct #struct_name;
97        impl milrouter::Endpoint<#client_type> for #struct_name {
98            type Data = #data;
99            type Returns = #inner_ret;
100
101            fn is_idempotent() -> bool { #idempotent }
102        }
103
104        #[cfg(target_arch = "x86_64")]
105        impl milrouter::ServerEndpoint<#client_type> for #struct_name {
106
107            fn auth() -> Box<dyn Fn(milrouter::hyper::HeaderMap) -> milrouter::BoxFuture<'static, milrouter::anyhow::Result<#client_type>> + 'static + Send> {
108                Box::new(move |i: milrouter::hyper::HeaderMap| Box::pin(#auth(i)))
109            }
110
111            fn handler() -> Box<dyn Fn(#client_type, milrouter::hyper::HeaderMap, Self::Data) -> milrouter::BoxFuture<'static, milrouter::anyhow::Result<Self::Returns>> + 'static + Send> {
112                Box::new(move |i: #client_type, i2: milrouter::hyper::HeaderMap, i3: Self::Data| Box::pin(#name(i, i2, i3)))
113            }
114        }
115
116
117        #[doc("Endpoint Handler for [#name]\n@ #method -> #struct_name::Data ([#arg])")]
118        #[cfg(target_arch = "x86_64")]
119        pub async fn #name(#args) #ret #block
120
121    }
122    .into()
123}
124
125/// Apply to an enum. <br>
126/// Variants' snake_case names are used as paths, and inner type's used as endpoint handlers.
127/// ### Example:
128/// ```rust
129/// #[derive(Router)]
130/// #[assets("./example/static")]
131/// #[html(super_awesome_html_generator)]
132/// pub enum DemoRouter {
133///     Greet(EndpointGreet),
134/// }
135/// ```
136#[proc_macro_derive(Router, attributes(assets, html))]
137pub fn router(item: TokenStream) -> TokenStream {
138    let (input, name, data) = preamble(parse_macro_input!(item as DeriveInput));
139    let (html, local_assets) = parse_attrs(input.clone());
140
141    let paths: Result<Vec<proc_macro2::TokenStream>, syn::Error> = data.variants.iter().map(|variant| {
142
143        let path = format_ident!("{}", AsSnekCase(variant.ident.to_string()).to_string());
144        let inner = variant.fields.iter()
145            .next()
146            .map(|ty| ty.ty.clone())
147            .ok_or(syn::Error::new_spanned(
148                variant.to_token_stream(),
149                format!("No endpoint specified for {}", variant.ident)
150            ))?;
151
152        let inner_name = &variant.ident;
153
154        Ok(quote::quote! {
155            (stringify!(#path), i) if i == #inner::is_idempotent() => ({
156                let auth = <#inner as milrouter::ServerEndpoint<_>>::auth();
157
158                let error_res = |e, code, label| {
159                    milrouter::tracing::info!("[-] {code} {label} /{}", stringify!(#path));
160                    milrouter::hyper::Response::builder()
161                        .status(code)
162                        .body(
163                            milrouter::Body::from(format!(
164                                "You aren't authorised to access this endpoint\n{e}"
165                            ))
166                            .full(),
167                        )
168                        .unwrap()
169                };
170
171                let client = match auth(headers.clone()).await {
172                    Ok(c) => c,
173                    Err(e) => return error_res(e.to_string(), 401, "Unauthorised"),
174                };
175
176                let body: std::boxed::Box<dyn std::any::Any> = match std::any::type_name::<<#inner as milrouter::Endpoint<_>>::Data>() {
177                    "()" => std::boxed::Box::new(()),
178                    _ => {
179                        let bytes = req.collect().await.expect(&format!("Failed to read incoming bytes for {}", stringify!(#inner_name))).to_bytes();
180                        std::boxed::Box::new(milrouter::serde_json::from_str::<<#inner as milrouter::Endpoint<_>>::Data>(&String::from_utf8_lossy(&bytes[..]).to_string()).expect(&format!("Failed to deserialize body for {}", stringify!(#inner_name))))
181                    }
182                };
183
184                let body: <#inner as milrouter::Endpoint<_>>::Data = *body.downcast::<<#inner as milrouter::Endpoint<_>>::Data>().unwrap();
185                let handler = <#inner as milrouter::ServerEndpoint<_>>::handler();
186
187                match handler(client, headers, body).await {
188                    Ok(response) => {
189                        let bytes = milrouter::serde_json::to_vec(&response).expect(&format!("Failed to serialize response for {}", stringify!(#inner_name)));
190
191                        let mut compressed_file = Vec::new();
192                        milrouter::gz_compress(bytes.as_slice(), &mut compressed_file).unwrap();
193
194                        milrouter::tracing::info!(concat!("[+] 200 Ok /", stringify!(#path)));
195                        return milrouter::hyper::Response::builder()
196                            .status(200)
197                            .header("Content-Encoding", "gzip")
198                            .body(milrouter::Body::from(compressed_file.as_slice()).full())
199                            .unwrap();
200                    },
201                    Err(e) => {
202                        milrouter::tracing::warn!(concat!("[-] 400 Bad Request /", stringify!(#path)));
203                        return milrouter::hyper::Response::builder()
204                            .status(400)
205                            .body(milrouter::Body::from(e.to_string()).full())
206                            .unwrap()
207                    }
208                };
209            }),
210        })
211    }).collect();
212
213    let paths: Vec<proc_macro2::TokenStream> = err!(paths);
214
215    let into_routers: Result<Vec<proc_macro2::TokenStream>, syn::Error> = data
216        .variants
217        .iter()
218        .map(|variant| {
219            let ident = variant.fields.iter().next().map(|ty| ty.ty.clone()).ok_or(syn::Error::new_spanned(
220                variant.to_token_stream(),
221                format!("No endpoint specified for {}", variant.ident),
222            ))?;
223
224            let variant = variant.ident.clone();
225
226            Ok(quote::quote! {
227                impl milrouter::IntoRouter<#name> for #ident {
228                    fn router(self) -> #name {
229                        #name::#variant(#ident)
230                    }
231                }
232            })
233        })
234        .collect();
235
236    let into_routers: Vec<proc_macro2::TokenStream> = err!(into_routers);
237
238    let as_paths = data
239        .variants
240        .iter()
241        .map(|variant| {
242            let ident = variant.ident.clone();
243            let snake = heck::AsSnekCase(variant.ident.to_string()).to_string();
244            quote::quote! {
245               Self::#ident(..) => f.write_str(#snake),
246            }
247        })
248        .collect::<Vec<_>>();
249
250    let walkdir = |p: std::path::PathBuf| {
251        walkdir::WalkDir::new(&p)
252            .into_iter()
253            .filter_map(|e| match e {
254                Err(_) => None,
255                Ok(f) => f.metadata().unwrap().is_file().then_some(f),
256            })
257            .map(move |entry| {
258                let route =
259                    entry.path().display().to_string().strip_prefix(&format!("{}/", p.display())).unwrap().to_string();
260
261                let path = entry.path().display().to_string();
262
263                let mime = mime_guess::from_path(route.clone()).first_or_text_plain().to_string();
264                quote::quote! {
265                    assets.insert(#route.to_string(), (#mime.to_string(), include_bytes!(#path)));
266                }
267            })
268    };
269
270    let inserts = match local_assets.clone() {
271        Some(v) => {
272            let root = Span::call_site().local_file().unwrap_or_default();
273            walkdir(root.join(&v)).collect::<Vec<_>>()
274        }
275        _ => Vec::new(),
276    };
277
278    let default_route_case = match html {
279        None => quote::quote!(),
280        Some(html) => quote::quote! {
281            else if path.is_empty() {
282                milrouter::tracing::info!("[#] 200 Ok (HTML) /{}", path);
283                return Ok(
284                    milrouter::hyper::Response::builder()
285                        .status(200)
286                        .header("Content-Type", "text/html")
287                        .body(milrouter::Body::from(#html()).full())
288                        .unwrap()
289                )
290            }
291        },
292    };
293
294    let assets_serving = match local_assets.clone() {
295        Some(local_assets) => quote::quote! {
296             if let Some(file) = __ASSETS.get(&path) {
297                milrouter::tracing::info!("[#] 200 Ok (File) /{}", path);
298                return Ok(
299                    milrouter::hyper::Response::builder()
300                        .status(200)
301                        .header("Content-Type", file.0.to_string())
302                        .header("Content-Encoding", "gzip")
303                        .body(match std::env::var("MILROUTER_LOCAL").is_ok() {
304                            false => {
305                                let mut compressed_file = Vec::new();
306                                milrouter::gz_compress(file.1, &mut compressed_file).unwrap();
307                                milrouter::Body::from(compressed_file.as_slice()).full()
308                            },
309                            true => {
310                                use std::io::Read;
311                                let mut byt = Vec::new();
312
313                                let local = std::fs::File::open(std::path::PathBuf::from(#local_assets).join(&path)).and_then(|mut f| f.read_to_end(&mut byt));
314                                let mut compressed_file = Vec::new();
315                                milrouter::gz_compress(byt.as_slice(), &mut compressed_file).unwrap();
316                                milrouter::Body::from(compressed_file.as_slice()).full()
317                            }
318                        })
319                        .unwrap()
320                )
321            }
322        },
323        _ => quote::quote!(),
324    };
325
326    let el = if assets_serving.is_empty() && default_route_case.is_empty() {
327        quote! {}
328    } else {
329        quote! { else }
330    };
331
332    let ts = TokenStream::from(quote::quote! {
333        #[cfg(target_arch = "x86_64")]
334        static __ASSETS: std::sync::LazyLock<std::collections::BTreeMap::<String, (String, &'static [u8])>> = std::sync::LazyLock::new(|| {
335            use std::io::Read;
336            let mut assets = std::collections::BTreeMap::<String, (String, &'static [u8])>::new();
337            #(#inserts)*
338            assets
339        });
340
341        #[cfg(target_arch = "x86_64")]
342        impl #name {
343            pub async fn route(req: milrouter::hyper::Request<milrouter::hyper::body::Incoming>) -> std::result::Result<milrouter::hyper::Response<milrouter::http_body_util::Full<milrouter::bytes::Bytes>>, std::convert::Infallible> {
344                use milrouter::http_body_util::BodyExt;
345                use std::error::Error;
346
347                let path = req.uri().path().to_string();
348                let path = path.strip_prefix("/").map(|v| v.to_string()).unwrap_or(path);
349                let path = path.strip_prefix("static/").map(|v| v.to_string()).unwrap_or(path);
350                let headers = req.headers().clone();
351
352                if req.method() == milrouter::hyper::Method::GET {
353                    #assets_serving
354                    #default_route_case
355                    #el {
356                        milrouter::tracing::warn!("[#] 404 Not Found /{}", path);
357                        return Ok(
358                            milrouter::hyper::Response::builder()
359                                .status(404)
360                                .body(milrouter::Body::default().full())
361                                .unwrap()
362                        )
363                    }
364                }
365
366                Ok(match milrouter::tokio::task::spawn(async move {
367                    match (path.as_str(), req.method().is_idempotent()) {
368                        #(#paths)*
369                        path => {
370                            milrouter::tracing::info!("[?] 404 Not Found /{}", path.0);
371                            return milrouter::hyper::Response::builder()
372                                .status(404)
373                                .body(milrouter::Body::default().full())
374                                .unwrap()
375                        }
376                    }
377                }).await {
378                    Ok(inner) => inner,
379                    Err(err) => {
380
381                        let err = err.into_panic();
382
383                        let value = err
384                            .downcast_ref::<String>()
385                            .cloned()
386                            .or(err.downcast_ref::<&str>().map(|s| s.to_string()))
387                            .unwrap_or("[Unexpected Error]".to_string());
388
389                        milrouter::hyper::Response::builder()
390                            .status(500)
391                            .body(milrouter::Body::from(format!("{:?}", err)).full())
392                            .unwrap()
393
394
395                    }
396                })
397
398            }
399        }
400
401        impl std::fmt::Display for #name {
402            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
403                match self {
404                    #(#as_paths)*
405                }
406            }
407
408        }
409
410        impl milrouter::Router for #name {}
411
412        #(#into_routers)*
413
414    });
415
416    // dbg!(ts.to_string());
417    ts
418}
419
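// For docs: pass-through attribute macros (they return their input unchanged). The `Router`
// derive reads `#[assets]` and `#[html]` itself; these exist to give the attributes doc pages.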
421
422/// __Optional__
423///  
424/// Serves static assets (relative to the file in which its invoked)
425///
426/// If `MILROUTER_LOCAL` is set, will read from disk every request,
427/// otherwise, will load into LazyLock
428#[proc_macro_attribute]
429pub fn assets(_: TokenStream, i: TokenStream) -> TokenStream { i }
430
431/// __Optional__
432///  
433/// Serves static HTML governed from a function
434///
435/// If this is not provided, `/` will give a `400`
436#[proc_macro_attribute]
437pub fn html(_: TokenStream, i: TokenStream) -> TokenStream { i }