1use {
2 crate::helpers::parse_attrs,
3 heck::{AsPascalCase, AsSnekCase},
4 helpers::{RouteInfo, get_inner_type, parse_fn_args, preamble, unit},
5 proc_macro::{Span, TokenStream},
6 quote::{ToTokens, format_ident, quote},
7 syn::{DeriveInput, FnArg, parse_macro_input},
8};
9
10#[macro_use]
11mod helpers;
12
13#[proc_macro_attribute]
40pub fn endpoint(annot: TokenStream, item: TokenStream) -> TokenStream {
41 let it = item.clone();
42 let meta = parse_macro_input!(it as syn::ItemFn);
43
44 let name = meta.sig.clone().ident;
45 let ret = meta.sig.clone().output;
46 let block = meta.block;
47
48 let args = parse_fn_args(
49 meta.sig
50 .inputs
51 .iter()
52 .map(|a| {
53 let a = match a {
54 FnArg::Typed(t) => t,
55 _ => panic!("Unexpected self type in endpoint"),
56 };
57
58 let ident = match *a.clone().pat {
59 syn::Pat::Ident(pat_ident) => pat_ident.ident,
60 _ => unreachable!(),
61 };
62
63 let ty = *a.clone().ty;
64
65 (ident, ty)
66 })
67 .collect::<Vec<_>>(),
68 );
69
70 let info = err!(RouteInfo::parse(annot.into()));
71 let (idempotent, auth) = (info.is_idempotent, info.auth);
72
73 let method = match idempotent {
74 true => "PUT",
75 false => "POST",
76 };
77
78 let inner_ret = match meta.sig.clone().output {
79 syn::ReturnType::Type(_, ty) => *ty,
80 _ => unreachable!(),
81 };
82
83 let inner_ret = err!(get_inner_type(inner_ret.clone()).map_err(|e| {
84 syn::Error::new_spanned(ret.to_token_stream(), format!("Unexpected return type (should be anyhow::Result<T>).\n{e}"))
85 }));
86
87 let struct_name = quote::format_ident!("Endpoint{}", AsPascalCase(name.to_string()).to_string());
88
89 let data = args.clone().input.1;
90 let client_type = args.client.clone().map(|c| c.1).unwrap_or(unit());
91 let args = args.to_tokens();
92
93 quote::quote! {
94 #[doc = concat!("Endpoint Struct for [", stringify!(#name) ,"]\n@ ", stringify!(#method), " -> ", stringify!(#struct_name), "::Data ([", stringify!(#ret), "])")]
95 #[derive(Clone)]
96 pub struct #struct_name;
97 impl milrouter::Endpoint<#client_type> for #struct_name {
98 type Data = #data;
99 type Returns = #inner_ret;
100
101 fn is_idempotent() -> bool { #idempotent }
102 }
103
104 #[cfg(target_arch = "x86_64")]
105 impl milrouter::ServerEndpoint<#client_type> for #struct_name {
106
107 fn auth() -> Box<dyn Fn(milrouter::hyper::HeaderMap) -> milrouter::BoxFuture<'static, milrouter::anyhow::Result<#client_type>> + 'static + Send> {
108 Box::new(move |i: milrouter::hyper::HeaderMap| Box::pin(#auth(i)))
109 }
110
111 fn handler() -> Box<dyn Fn(#client_type, milrouter::hyper::HeaderMap, Self::Data) -> milrouter::BoxFuture<'static, milrouter::anyhow::Result<Self::Returns>> + 'static + Send> {
112 Box::new(move |i: #client_type, i2: milrouter::hyper::HeaderMap, i3: Self::Data| Box::pin(#name(i, i2, i3)))
113 }
114 }
115
116
117 #[doc("Endpoint Handler for [#name]\n@ #method -> #struct_name::Data ([#arg])")]
118 #[cfg(target_arch = "x86_64")]
119 pub async fn #name(#args) #ret #block
120
121 }
122 .into()
123}
124
125#[proc_macro_derive(Router, attributes(assets, html))]
137pub fn router(item: TokenStream) -> TokenStream {
138 let (input, name, data) = preamble(parse_macro_input!(item as DeriveInput));
139 let (html, local_assets) = parse_attrs(input.clone());
140
141 let paths: Result<Vec<proc_macro2::TokenStream>, syn::Error> = data.variants.iter().map(|variant| {
142
143 let path = format_ident!("{}", AsSnekCase(variant.ident.to_string()).to_string());
144 let inner = variant.fields.iter()
145 .next()
146 .map(|ty| ty.ty.clone())
147 .ok_or(syn::Error::new_spanned(
148 variant.to_token_stream(),
149 format!("No endpoint specified for {}", variant.ident)
150 ))?;
151
152 let inner_name = &variant.ident;
153
154 Ok(quote::quote! {
155 (stringify!(#path), i) if i == #inner::is_idempotent() => ({
156 let auth = <#inner as milrouter::ServerEndpoint<_>>::auth();
157
158 let error_res = |e, code, label| {
159 milrouter::tracing::info!("[-] {code} {label} /{}", stringify!(#path));
160 milrouter::hyper::Response::builder()
161 .status(code)
162 .body(
163 milrouter::Body::from(format!(
164 "You aren't authorised to access this endpoint\n{e}"
165 ))
166 .full(),
167 )
168 .unwrap()
169 };
170
171 let client = match auth(headers.clone()).await {
172 Ok(c) => c,
173 Err(e) => return error_res(e.to_string(), 401, "Unauthorised"),
174 };
175
176 let body: std::boxed::Box<dyn std::any::Any> = match std::any::type_name::<<#inner as milrouter::Endpoint<_>>::Data>() {
177 "()" => std::boxed::Box::new(()),
178 _ => {
179 let bytes = req.collect().await.expect(&format!("Failed to read incoming bytes for {}", stringify!(#inner_name))).to_bytes();
180 std::boxed::Box::new(milrouter::serde_json::from_str::<<#inner as milrouter::Endpoint<_>>::Data>(&String::from_utf8_lossy(&bytes[..]).to_string()).expect(&format!("Failed to deserialize body for {}", stringify!(#inner_name))))
181 }
182 };
183
184 let body: <#inner as milrouter::Endpoint<_>>::Data = *body.downcast::<<#inner as milrouter::Endpoint<_>>::Data>().unwrap();
185 let handler = <#inner as milrouter::ServerEndpoint<_>>::handler();
186
187 match handler(client, headers, body).await {
188 Ok(response) => {
189 let bytes = milrouter::serde_json::to_vec(&response).expect(&format!("Failed to serialize response for {}", stringify!(#inner_name)));
190
191 let mut compressed_file = Vec::new();
192 milrouter::gz_compress(bytes.as_slice(), &mut compressed_file).unwrap();
193
194 milrouter::tracing::info!(concat!("[+] 200 Ok /", stringify!(#path)));
195 return milrouter::hyper::Response::builder()
196 .status(200)
197 .header("Content-Encoding", "gzip")
198 .body(milrouter::Body::from(compressed_file.as_slice()).full())
199 .unwrap();
200 },
201 Err(e) => {
202 milrouter::tracing::warn!(concat!("[-] 400 Bad Request /", stringify!(#path)));
203 return milrouter::hyper::Response::builder()
204 .status(400)
205 .body(milrouter::Body::from(e.to_string()).full())
206 .unwrap()
207 }
208 };
209 }),
210 })
211 }).collect();
212
213 let paths: Vec<proc_macro2::TokenStream> = err!(paths);
214
215 let into_routers: Result<Vec<proc_macro2::TokenStream>, syn::Error> = data
216 .variants
217 .iter()
218 .map(|variant| {
219 let ident = variant.fields.iter().next().map(|ty| ty.ty.clone()).ok_or(syn::Error::new_spanned(
220 variant.to_token_stream(),
221 format!("No endpoint specified for {}", variant.ident),
222 ))?;
223
224 let variant = variant.ident.clone();
225
226 Ok(quote::quote! {
227 impl milrouter::IntoRouter<#name> for #ident {
228 fn router(self) -> #name {
229 #name::#variant(#ident)
230 }
231 }
232 })
233 })
234 .collect();
235
236 let into_routers: Vec<proc_macro2::TokenStream> = err!(into_routers);
237
238 let as_paths = data
239 .variants
240 .iter()
241 .map(|variant| {
242 let ident = variant.ident.clone();
243 let snake = heck::AsSnekCase(variant.ident.to_string()).to_string();
244 quote::quote! {
245 Self::#ident(..) => f.write_str(#snake),
246 }
247 })
248 .collect::<Vec<_>>();
249
250 let walkdir = |p: std::path::PathBuf| {
251 walkdir::WalkDir::new(&p)
252 .into_iter()
253 .filter_map(|e| match e {
254 Err(_) => None,
255 Ok(f) => f.metadata().unwrap().is_file().then_some(f),
256 })
257 .map(move |entry| {
258 let route =
259 entry.path().display().to_string().strip_prefix(&format!("{}/", p.display())).unwrap().to_string();
260
261 let path = entry.path().display().to_string();
262
263 let mime = mime_guess::from_path(route.clone()).first_or_text_plain().to_string();
264 quote::quote! {
265 assets.insert(#route.to_string(), (#mime.to_string(), include_bytes!(#path)));
266 }
267 })
268 };
269
270 let inserts = match local_assets.clone() {
271 Some(v) => {
272 let root = Span::call_site().local_file().unwrap_or_default();
273 walkdir(root.join(&v)).collect::<Vec<_>>()
274 }
275 _ => Vec::new(),
276 };
277
278 let default_route_case = match html {
279 None => quote::quote!(),
280 Some(html) => quote::quote! {
281 else if path.is_empty() {
282 milrouter::tracing::info!("[#] 200 Ok (HTML) /{}", path);
283 return Ok(
284 milrouter::hyper::Response::builder()
285 .status(200)
286 .header("Content-Type", "text/html")
287 .body(milrouter::Body::from(#html()).full())
288 .unwrap()
289 )
290 }
291 },
292 };
293
294 let assets_serving = match local_assets.clone() {
295 Some(local_assets) => quote::quote! {
296 if let Some(file) = __ASSETS.get(&path) {
297 milrouter::tracing::info!("[#] 200 Ok (File) /{}", path);
298 return Ok(
299 milrouter::hyper::Response::builder()
300 .status(200)
301 .header("Content-Type", file.0.to_string())
302 .header("Content-Encoding", "gzip")
303 .body(match std::env::var("MILROUTER_LOCAL").is_ok() {
304 false => {
305 let mut compressed_file = Vec::new();
306 milrouter::gz_compress(file.1, &mut compressed_file).unwrap();
307 milrouter::Body::from(compressed_file.as_slice()).full()
308 },
309 true => {
310 use std::io::Read;
311 let mut byt = Vec::new();
312
313 let local = std::fs::File::open(std::path::PathBuf::from(#local_assets).join(&path)).and_then(|mut f| f.read_to_end(&mut byt));
314 let mut compressed_file = Vec::new();
315 milrouter::gz_compress(byt.as_slice(), &mut compressed_file).unwrap();
316 milrouter::Body::from(compressed_file.as_slice()).full()
317 }
318 })
319 .unwrap()
320 )
321 }
322 },
323 _ => quote::quote!(),
324 };
325
326 let el = if assets_serving.is_empty() && default_route_case.is_empty() {
327 quote! {}
328 } else {
329 quote! { else }
330 };
331
332 let ts = TokenStream::from(quote::quote! {
333 #[cfg(target_arch = "x86_64")]
334 static __ASSETS: std::sync::LazyLock<std::collections::BTreeMap::<String, (String, &'static [u8])>> = std::sync::LazyLock::new(|| {
335 use std::io::Read;
336 let mut assets = std::collections::BTreeMap::<String, (String, &'static [u8])>::new();
337 #(#inserts)*
338 assets
339 });
340
341 #[cfg(target_arch = "x86_64")]
342 impl #name {
343 pub async fn route(req: milrouter::hyper::Request<milrouter::hyper::body::Incoming>) -> std::result::Result<milrouter::hyper::Response<milrouter::http_body_util::Full<milrouter::bytes::Bytes>>, std::convert::Infallible> {
344 use milrouter::http_body_util::BodyExt;
345 use std::error::Error;
346
347 let path = req.uri().path().to_string();
348 let path = path.strip_prefix("/").map(|v| v.to_string()).unwrap_or(path);
349 let path = path.strip_prefix("static/").map(|v| v.to_string()).unwrap_or(path);
350 let headers = req.headers().clone();
351
352 if req.method() == milrouter::hyper::Method::GET {
353 #assets_serving
354 #default_route_case
355 #el {
356 milrouter::tracing::warn!("[#] 404 Not Found /{}", path);
357 return Ok(
358 milrouter::hyper::Response::builder()
359 .status(404)
360 .body(milrouter::Body::default().full())
361 .unwrap()
362 )
363 }
364 }
365
366 Ok(match milrouter::tokio::task::spawn(async move {
367 match (path.as_str(), req.method().is_idempotent()) {
368 #(#paths)*
369 path => {
370 milrouter::tracing::info!("[?] 404 Not Found /{}", path.0);
371 return milrouter::hyper::Response::builder()
372 .status(404)
373 .body(milrouter::Body::default().full())
374 .unwrap()
375 }
376 }
377 }).await {
378 Ok(inner) => inner,
379 Err(err) => {
380
381 let err = err.into_panic();
382
383 let value = err
384 .downcast_ref::<String>()
385 .cloned()
386 .or(err.downcast_ref::<&str>().map(|s| s.to_string()))
387 .unwrap_or("[Unexpected Error]".to_string());
388
389 milrouter::hyper::Response::builder()
390 .status(500)
391 .body(milrouter::Body::from(format!("{:?}", err)).full())
392 .unwrap()
393
394
395 }
396 })
397
398 }
399 }
400
401 impl std::fmt::Display for #name {
402 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
403 match self {
404 #(#as_paths)*
405 }
406 }
407
408 }
409
410 impl milrouter::Router for #name {}
411
412 #(#into_routers)*
413
414 });
415
416 ts
418}
419
420#[proc_macro_attribute]
429pub fn assets(_: TokenStream, i: TokenStream) -> TokenStream { i }
430
431#[proc_macro_attribute]
437pub fn html(_: TokenStream, i: TokenStream) -> TokenStream { i }