pu_239/
lib.rs

1use std::collections::BTreeMap;
2
3use syn::visit::Visit;
4use quote::{ToTokens, quote};
5
/// Returns the 64-bit digest of `t` produced by a freshly constructed
/// [`std::collections::hash_map::DefaultHasher`].
fn quick_hash<T: std::hash::Hash>(t: &T) -> u64 {
	use std::hash::Hasher as _;

	let mut state = std::collections::hash_map::DefaultHasher::new();
	std::hash::Hash::hash(t, &mut state);
	state.finish()
}
13
14#[proc_macro_attribute]
15pub fn server(_: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
16	let mut item = syn::parse_macro_input!(item as syn::ItemFn);
17	let hash = quick_hash(&item);
18	let output = match item.sig.output {
19		syn::ReturnType::Default => syn::parse_quote!(()),
20		syn::ReturnType::Type(_, ty) => *ty,
21	};
22	item.sig.output = syn::parse_quote!(-> ::std::result::Result<#output, ::anyhow::Error>);
23	let arg_idents = item.sig.inputs.iter().map(|x| match x {
24		syn::FnArg::Typed(x) => x.pat.clone(),
25		syn::FnArg::Receiver(_) => panic!("Expected typed argument"),
26	});
27	item.block = syn::parse_quote!({
28		const HASH: u64 = #hash;
29
30		let args = (#(#arg_idents),*);
31		let mut serialized = ::std::vec::Vec::with_capacity(::postcard::experimental::serialized_size(&HASH)? + ::postcard::experimental::serialized_size(&args)?);
32		::postcard::to_io(&HASH, &mut serialized)?;
33		::postcard::to_io(&args, &mut serialized)?;
34		Ok(::postcard::from_bytes(&crate::api::dispatch(serialized).await?)?)
35	});
36	item.into_token_stream().into()
37}
38
39
40struct Visitor<'a> {
41	root: &'a std::path::Path,
42
43	api_fns: Vec<syn::ItemFn>,
44
45	current_path: (Vec<syn::Ident>, Vec<syn::Attribute>),
46	sub_visitors: BTreeMap<syn::Ident, Self>,
47}
48
49impl<'a> Visitor<'a> {
50	fn new(root: &'a std::path::Path, current_path: (Vec<syn::Ident>, Vec<syn::Attribute>)) -> Self {
51		Self { root, api_fns: Vec::new(), current_path, sub_visitors: BTreeMap::new() }
52	}
53
54	fn write_out(&self, out: &mut Vec<syn::Item>) {
55		for f in &self.api_fns {
56			out.push(syn::parse_quote!(#f));
57		}
58
59		for (module, sub_visitor) in &self.sub_visitors {
60			if sub_visitor.api_fns.is_empty() { continue; }
61			let mut sub_out: Vec<syn::Item> = Vec::with_capacity(sub_visitor.api_fns.len() + sub_visitor.sub_visitors.len());
62			sub_visitor.write_out(&mut sub_out);
63			out.push(syn::parse_quote!(pub mod #module { #(#sub_out)* }));
64		}
65	}
66
67	fn write_arms(&self, out: &mut Vec<syn::Arm>) {
68		for f in &self.api_fns {
69			let hash = quick_hash(&f);
70			let current_path = &self.current_path.0;
71			let fn_ident = &f.sig.ident;
72			let fn_path = quote!(#(#current_path ::)*#fn_ident);
73			let arg_idents = &f.sig.inputs.iter().map(|x| match x {
74				syn::FnArg::Typed(x) => x.pat.clone(),
75				syn::FnArg::Receiver(_) => panic!("Expected typed argument"),
76			}).collect::<Vec<_>>();
77			#[cfg(feature = "trace")] let fn_path_str = fn_path.to_string().replace(" ", "");
78			#[cfg(feature = "trace")] let log_str_pre = format!("{fn_path_str}{{args:?}}");
79			#[cfg(feature = "trace")] let log_str_post = format!("{fn_path_str} -> {{res:?}}");
80			#[cfg(feature = "trace")] let maybe_trace_pre = quote!(log::trace!(#log_str_pre););
81			#[cfg(feature = "trace")] let maybe_trace_post = quote!(log::trace!(#log_str_post););
82			#[cfg(not(feature = "trace"))] let maybe_trace_pre = quote!();
83			#[cfg(not(feature = "trace"))] let maybe_trace_post = quote!();
84			out.push(syn::parse_quote!(#hash => {
85				let args = ::postcard::from_io((&mut bytes, &mut scratch))?.0;
86				#maybe_trace_pre
87				let (#(#arg_idents),*) = args;
88				let res = #fn_path(#(#arg_idents),*).await;
89				#maybe_trace_post
90				Ok(::postcard::to_stdvec(&res)?)
91			}));
92		}
93
94		for sub_visitor in self.sub_visitors.values() {
95			sub_visitor.write_arms(out);
96		}
97	}
98
99	fn total_fns(&self) -> usize {
100		self.api_fns.len() + self.sub_visitors.values().map(Visitor::total_fns).sum::<usize>()
101	}
102}
103
104impl Visit<'_> for Visitor<'_> {
105	// create a visitor for each api module or file, recursive
106	fn visit_item_mod(&mut self, node: &syn::ItemMod) {
107		let mut path = self.current_path.0.clone();
108		path.push(node.ident.clone());
109		let mut visitor = Visitor::new(self.root, (path, node.attrs.clone()));
110		if let Some((_, items)) = &node.content {
111			for item in items {
112				visitor.visit_item(item);
113			}
114		} else {
115			let mut path = visitor.root.to_owned();
116			for seg in &visitor.current_path.0 {
117				path.push(seg.to_string());
118			}
119
120			let name_rs = path.with_extension("rs");
121			let mod_rs = path.join("mod.rs");
122
123			let path = if name_rs.exists() { name_rs } else if mod_rs.exists() { mod_rs } else {
124				panic!("No file found for module {}, is there a loose mod declaration that isn't pointing anywhere?", visitor.current_path.0.last().unwrap())
125			};
126
127			let file = std::fs::read_to_string(&path).expect("Error reading file. Is there a loose mod declaration that isn't pointing anywhere?");
128			visitor.visit_file(&syn::parse_file(&file).unwrap());
129		}
130
131		self.sub_visitors.insert(node.ident.clone(), visitor);
132	}
133
134	fn visit_item_fn(&mut self, node: &syn::ItemFn) {
135		let pu_239_server: syn::Path = syn::parse_quote!(pu_239::server);
136		let Some(_api_attr) = node.attrs.iter().find(|attr| *attr.path() == pu_239_server) else { return syn::visit::visit_item_fn(self, node); };
137		let mut node = node.clone();
138		node.attrs.retain(|attr| *attr.path() != pu_239_server);
139		self.api_fns.push(node);
140	}
141}
142
143#[proc_macro]
144pub fn build_api(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
145	let roots = syn::parse_macro_input!(item as syn::ExprArray).elems.into_iter()
146		.map(|elem| {
147			let root: syn::LitStr = syn::parse_quote!(#elem);
148			std::path::PathBuf::from(root.value())
149		})
150		.collect::<Vec<_>>();
151
152	let visitors = roots.iter().map(|root| {
153		let mut visitor = Visitor::new(root.parent().unwrap(), (Vec::new(), Vec::new()));
154		visitor.visit_file(&syn::parse_file(&std::fs::read_to_string(root).unwrap()).unwrap());
155		visitor
156	}).collect::<Vec<_>>();
157
158	let mut out = Vec::<syn::Item>::with_capacity(visitors.iter().map(|visitor| visitor.api_fns.len() + visitor.sub_visitors.len()).sum());
159	let mut arms = Vec::<syn::Arm>::with_capacity(visitors.iter().map(|visitor| visitor.total_fns()).sum());
160
161	for visitor in visitors {
162		visitor.write_out(&mut out);
163		visitor.write_arms(&mut arms);
164	}
165
166	quote!(
167		#(#out)*
168
169		async fn deserialize_api_match(mut bytes: impl ::std::io::Read) -> ::std::result::Result<::std::vec::Vec<u8>, ::anyhow::Error> {
170			let mut scratch = [0u8; 2048];
171			let (hash, (mut bytes, _)) = ::postcard::from_io::<u64, _>((bytes, &mut scratch))?;
172			match hash {
173				#(#arms),*
174				method_id => Err(::anyhow::anyhow!("Unknown method id: {method_id}")),
175			}
176		}
177	).into()
178}