1use std::collections::BTreeMap;
2
3use syn::visit::Visit;
4use quote::{ToTokens, quote};
5
/// Hashes `t` into the 64-bit method id shared by the `server` attribute and `build_api!`.
///
/// Both macros must compute the same id for the same function, possibly in separate
/// builds. `std`'s `DefaultHasher` explicitly leaves its algorithm unspecified across Rust
/// releases, so this uses a fixed 64-bit FNV-1a hasher instead. The result still depends
/// on the `Hash` impl of `T` (for syntax trees, the `syn` version in use).
fn quick_hash<T: std::hash::Hash>(t: &T) -> u64 {
    use std::hash::Hasher;

    /// 64-bit FNV-1a: the state is XORed with each byte, then multiplied by the FNV prime.
    struct Fnv1a(u64);

    impl Hasher for Fnv1a {
        fn finish(&self) -> u64 {
            self.0
        }

        fn write(&mut self, bytes: &[u8]) {
            for &byte in bytes {
                self.0 ^= u64::from(byte);
                self.0 = self.0.wrapping_mul(0x0000_0100_0000_01b3);
            }
        }
    }

    // FNV-1a 64-bit offset basis.
    let mut hasher = Fnv1a(0xcbf2_9ce4_8422_2325);
    t.hash(&mut hasher);
    hasher.finish()
}
13
14#[proc_macro_attribute]
15pub fn server(_: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
16 let mut item = syn::parse_macro_input!(item as syn::ItemFn);
17 let hash = quick_hash(&item);
18 let output = match item.sig.output {
19 syn::ReturnType::Default => syn::parse_quote!(()),
20 syn::ReturnType::Type(_, ty) => *ty,
21 };
22 item.sig.output = syn::parse_quote!(-> ::std::result::Result<#output, ::anyhow::Error>);
23 let arg_idents = item.sig.inputs.iter().map(|x| match x {
24 syn::FnArg::Typed(x) => x.pat.clone(),
25 syn::FnArg::Receiver(_) => panic!("Expected typed argument"),
26 });
27 item.block = syn::parse_quote!({
28 const HASH: u64 = #hash;
29
30 let args = (#(#arg_idents),*);
31 let mut serialized = ::std::vec::Vec::with_capacity(::postcard::experimental::serialized_size(&HASH)? + ::postcard::experimental::serialized_size(&args)?);
32 ::postcard::to_io(&HASH, &mut serialized)?;
33 ::postcard::to_io(&args, &mut serialized)?;
34 Ok(::postcard::from_bytes(&crate::api::dispatch(serialized).await?)?)
35 });
36 item.into_token_stream().into()
37}
38
39
40struct Visitor<'a> {
41 root: &'a std::path::Path,
42
43 api_fns: Vec<syn::ItemFn>,
44
45 current_path: (Vec<syn::Ident>, Vec<syn::Attribute>),
46 sub_visitors: BTreeMap<syn::Ident, Self>,
47}
48
49impl<'a> Visitor<'a> {
50 fn new(root: &'a std::path::Path, current_path: (Vec<syn::Ident>, Vec<syn::Attribute>)) -> Self {
51 Self { root, api_fns: Vec::new(), current_path, sub_visitors: BTreeMap::new() }
52 }
53
54 fn write_out(&self, out: &mut Vec<syn::Item>) {
55 for f in &self.api_fns {
56 out.push(syn::parse_quote!(#f));
57 }
58
59 for (module, sub_visitor) in &self.sub_visitors {
60 if sub_visitor.api_fns.is_empty() { continue; }
61 let mut sub_out: Vec<syn::Item> = Vec::with_capacity(sub_visitor.api_fns.len() + sub_visitor.sub_visitors.len());
62 sub_visitor.write_out(&mut sub_out);
63 out.push(syn::parse_quote!(pub mod #module { #(#sub_out)* }));
64 }
65 }
66
67 fn write_arms(&self, out: &mut Vec<syn::Arm>) {
68 for f in &self.api_fns {
69 let hash = quick_hash(&f);
70 let current_path = &self.current_path.0;
71 let fn_ident = &f.sig.ident;
72 let fn_path = quote!(#(#current_path ::)*#fn_ident);
73 let arg_idents = &f.sig.inputs.iter().map(|x| match x {
74 syn::FnArg::Typed(x) => x.pat.clone(),
75 syn::FnArg::Receiver(_) => panic!("Expected typed argument"),
76 }).collect::<Vec<_>>();
77 #[cfg(feature = "trace")] let fn_path_str = fn_path.to_string().replace(" ", "");
78 #[cfg(feature = "trace")] let log_str_pre = format!("{fn_path_str}{{args:?}}");
79 #[cfg(feature = "trace")] let log_str_post = format!("{fn_path_str} -> {{res:?}}");
80 #[cfg(feature = "trace")] let maybe_trace_pre = quote!(log::trace!(#log_str_pre););
81 #[cfg(feature = "trace")] let maybe_trace_post = quote!(log::trace!(#log_str_post););
82 #[cfg(not(feature = "trace"))] let maybe_trace_pre = quote!();
83 #[cfg(not(feature = "trace"))] let maybe_trace_post = quote!();
84 out.push(syn::parse_quote!(#hash => {
85 let args = ::postcard::from_io((&mut bytes, &mut scratch))?.0;
86 #maybe_trace_pre
87 let (#(#arg_idents),*) = args;
88 let res = #fn_path(#(#arg_idents),*).await;
89 #maybe_trace_post
90 Ok(::postcard::to_stdvec(&res)?)
91 }));
92 }
93
94 for sub_visitor in self.sub_visitors.values() {
95 sub_visitor.write_arms(out);
96 }
97 }
98
99 fn total_fns(&self) -> usize {
100 self.api_fns.len() + self.sub_visitors.values().map(Visitor::total_fns).sum::<usize>()
101 }
102}
103
104impl Visit<'_> for Visitor<'_> {
105 fn visit_item_mod(&mut self, node: &syn::ItemMod) {
107 let mut path = self.current_path.0.clone();
108 path.push(node.ident.clone());
109 let mut visitor = Visitor::new(self.root, (path, node.attrs.clone()));
110 if let Some((_, items)) = &node.content {
111 for item in items {
112 visitor.visit_item(item);
113 }
114 } else {
115 let mut path = visitor.root.to_owned();
116 for seg in &visitor.current_path.0 {
117 path.push(seg.to_string());
118 }
119
120 let name_rs = path.with_extension("rs");
121 let mod_rs = path.join("mod.rs");
122
123 let path = if name_rs.exists() { name_rs } else if mod_rs.exists() { mod_rs } else {
124 panic!("No file found for module {}, is there a loose mod declaration that isn't pointing anywhere?", visitor.current_path.0.last().unwrap())
125 };
126
127 let file = std::fs::read_to_string(&path).expect("Error reading file. Is there a loose mod declaration that isn't pointing anywhere?");
128 visitor.visit_file(&syn::parse_file(&file).unwrap());
129 }
130
131 self.sub_visitors.insert(node.ident.clone(), visitor);
132 }
133
134 fn visit_item_fn(&mut self, node: &syn::ItemFn) {
135 let pu_239_server: syn::Path = syn::parse_quote!(pu_239::server);
136 let Some(_api_attr) = node.attrs.iter().find(|attr| *attr.path() == pu_239_server) else { return syn::visit::visit_item_fn(self, node); };
137 let mut node = node.clone();
138 node.attrs.retain(|attr| *attr.path() != pu_239_server);
139 self.api_fns.push(node);
140 }
141}
142
143#[proc_macro]
144pub fn build_api(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
145 let roots = syn::parse_macro_input!(item as syn::ExprArray).elems.into_iter()
146 .map(|elem| {
147 let root: syn::LitStr = syn::parse_quote!(#elem);
148 std::path::PathBuf::from(root.value())
149 })
150 .collect::<Vec<_>>();
151
152 let visitors = roots.iter().map(|root| {
153 let mut visitor = Visitor::new(root.parent().unwrap(), (Vec::new(), Vec::new()));
154 visitor.visit_file(&syn::parse_file(&std::fs::read_to_string(root).unwrap()).unwrap());
155 visitor
156 }).collect::<Vec<_>>();
157
158 let mut out = Vec::<syn::Item>::with_capacity(visitors.iter().map(|visitor| visitor.api_fns.len() + visitor.sub_visitors.len()).sum());
159 let mut arms = Vec::<syn::Arm>::with_capacity(visitors.iter().map(|visitor| visitor.total_fns()).sum());
160
161 for visitor in visitors {
162 visitor.write_out(&mut out);
163 visitor.write_arms(&mut arms);
164 }
165
166 quote!(
167 #(#out)*
168
169 async fn deserialize_api_match(mut bytes: impl ::std::io::Read) -> ::std::result::Result<::std::vec::Vec<u8>, ::anyhow::Error> {
170 let mut scratch = [0u8; 2048];
171 let (hash, (mut bytes, _)) = ::postcard::from_io::<u64, _>((bytes, &mut scratch))?;
172 match hash {
173 #(#arms),*
174 method_id => Err(::anyhow::anyhow!("Unknown method id: {method_id}")),
175 }
176 }
177 ).into()
178}