pu-239 0.2.1

serverside fns in the client
Documentation
use std::collections::BTreeMap;

use syn::visit::Visit;
use quote::{ToTokens, quote};

fn quick_hash<T: std::hash::Hash>(t: &T) -> u64 {
	use std::{collections::hash_map::DefaultHasher, hash::Hasher};

	let mut hasher = DefaultHasher::new();
	t.hash(&mut hasher);
	hasher.finish()
}

#[proc_macro_attribute]
pub fn server(_: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
	let mut item = syn::parse_macro_input!(item as syn::ItemFn);
	let hash = quick_hash(&item);
	let output = match item.sig.output {
		syn::ReturnType::Default => syn::parse_quote!(()),
		syn::ReturnType::Type(_, ty) => *ty,
	};
	item.sig.output = syn::parse_quote!(-> ::std::result::Result<#output, ::anyhow::Error>);
	let arg_idents = item.sig.inputs.iter().map(|x| match x {
		syn::FnArg::Typed(x) => x.pat.clone(),
		syn::FnArg::Receiver(_) => panic!("Expected typed argument"),
	});
	item.block = syn::parse_quote!({
		const HASH: u64 = #hash;

		let args = (#(#arg_idents),*);
		let mut serialized = ::std::vec::Vec::with_capacity(::postcard::experimental::serialized_size(&HASH)? + ::postcard::experimental::serialized_size(&args)?);
		::postcard::to_io(&HASH, &mut serialized)?;
		::postcard::to_io(&args, &mut serialized)?;
		Ok(::postcard::from_bytes(&crate::api::dispatch(serialized).await?)?)
	});
	item.into_token_stream().into()
}


struct Visitor<'a> {
	root: &'a std::path::Path,

	api_fns: Vec<syn::ItemFn>,

	current_path: (Vec<syn::Ident>, Vec<syn::Attribute>),
	sub_visitors: BTreeMap<syn::Ident, Self>,
}

impl<'a> Visitor<'a> {
	fn new(root: &'a std::path::Path, current_path: (Vec<syn::Ident>, Vec<syn::Attribute>)) -> Self {
		Self { root, api_fns: Vec::new(), current_path, sub_visitors: BTreeMap::new() }
	}

	fn write_out(&self, out: &mut Vec<syn::Item>) {
		for f in &self.api_fns {
			out.push(syn::parse_quote!(#f));
		}

		for (module, sub_visitor) in &self.sub_visitors {
			if sub_visitor.api_fns.is_empty() { continue; }
			let mut sub_out: Vec<syn::Item> = Vec::with_capacity(sub_visitor.api_fns.len() + sub_visitor.sub_visitors.len());
			sub_visitor.write_out(&mut sub_out);
			out.push(syn::parse_quote!(pub mod #module { #(#sub_out)* }));
		}
	}

	fn write_arms(&self, out: &mut Vec<syn::Arm>) {
		for f in &self.api_fns {
			let hash = quick_hash(&f);
			let current_path = &self.current_path.0;
			let fn_ident = &f.sig.ident;
			let fn_path = quote!(#(#current_path ::)*#fn_ident);
			let arg_idents = &f.sig.inputs.iter().map(|x| match x {
				syn::FnArg::Typed(x) => x.pat.clone(),
				syn::FnArg::Receiver(_) => panic!("Expected typed argument"),
			}).collect::<Vec<_>>();
			#[cfg(feature = "trace")] let fn_path_str = fn_path.to_string().replace(" ", "");
			#[cfg(feature = "trace")] let log_str_pre = format!("{fn_path_str}{{args:?}}");
			#[cfg(feature = "trace")] let log_str_post = format!("{fn_path_str} -> {{res:?}}");
			#[cfg(feature = "trace")] let maybe_trace_pre = quote!(log::trace!(#log_str_pre););
			#[cfg(feature = "trace")] let maybe_trace_post = quote!(log::trace!(#log_str_post););
			#[cfg(not(feature = "trace"))] let maybe_trace_pre = quote!();
			#[cfg(not(feature = "trace"))] let maybe_trace_post = quote!();
			out.push(syn::parse_quote!(#hash => {
				let args = ::postcard::from_io((&mut bytes, &mut scratch))?.0;
				#maybe_trace_pre
				let (#(#arg_idents),*) = args;
				let res = #fn_path(#(#arg_idents),*).await;
				#maybe_trace_post
				Ok(::postcard::to_stdvec(&res)?)
			}));
		}

		for sub_visitor in self.sub_visitors.values() {
			sub_visitor.write_arms(out);
		}
	}

	fn total_fns(&self) -> usize {
		self.api_fns.len() + self.sub_visitors.values().map(Visitor::total_fns).sum::<usize>()
	}
}

impl Visit<'_> for Visitor<'_> {
	// create a visitor for each api module or file, recursive
	fn visit_item_mod(&mut self, node: &syn::ItemMod) {
		let mut path = self.current_path.0.clone();
		path.push(node.ident.clone());
		let mut visitor = Visitor::new(self.root, (path, node.attrs.clone()));
		if let Some((_, items)) = &node.content {
			for item in items {
				visitor.visit_item(item);
			}
		} else {
			let mut path = visitor.root.to_owned();
			for seg in &visitor.current_path.0 {
				path.push(seg.to_string());
			}

			let name_rs = path.with_extension("rs");
			let mod_rs = path.join("mod.rs");

			let path = if name_rs.exists() { name_rs } else if mod_rs.exists() { mod_rs } else {
				panic!("No file found for module {}, is there a loose mod declaration that isn't pointing anywhere?", visitor.current_path.0.last().unwrap())
			};

			let file = std::fs::read_to_string(&path).expect("Error reading file. Is there a loose mod declaration that isn't pointing anywhere?");
			visitor.visit_file(&syn::parse_file(&file).unwrap());
		}

		self.sub_visitors.insert(node.ident.clone(), visitor);
	}

	fn visit_item_fn(&mut self, node: &syn::ItemFn) {
		let pu_239_server: syn::Path = syn::parse_quote!(pu_239::server);
		let Some(_api_attr) = node.attrs.iter().find(|attr| *attr.path() == pu_239_server) else { return syn::visit::visit_item_fn(self, node); };
		let mut node = node.clone();
		node.attrs.retain(|attr| *attr.path() != pu_239_server);
		self.api_fns.push(node);
	}
}

#[proc_macro]
pub fn build_api(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
	let roots = syn::parse_macro_input!(item as syn::ExprArray).elems.into_iter()
		.map(|elem| {
			let root: syn::LitStr = syn::parse_quote!(#elem);
			std::path::PathBuf::from(root.value())
		})
		.collect::<Vec<_>>();

	let visitors = roots.iter().map(|root| {
		let mut visitor = Visitor::new(root.parent().unwrap(), (Vec::new(), Vec::new()));
		visitor.visit_file(&syn::parse_file(&std::fs::read_to_string(root).unwrap()).unwrap());
		visitor
	}).collect::<Vec<_>>();

	let mut out = Vec::<syn::Item>::with_capacity(visitors.iter().map(|visitor| visitor.api_fns.len() + visitor.sub_visitors.len()).sum());
	let mut arms = Vec::<syn::Arm>::with_capacity(visitors.iter().map(|visitor| visitor.total_fns()).sum());

	for visitor in visitors {
		visitor.write_out(&mut out);
		visitor.write_arms(&mut arms);
	}

	quote!(
		#(#out)*

		async fn deserialize_api_match(mut bytes: impl ::std::io::Read) -> ::std::result::Result<::std::vec::Vec<u8>, ::anyhow::Error> {
			let mut scratch = [0u8; 2048];
			let (hash, (mut bytes, _)) = ::postcard::from_io::<u64, _>((bytes, &mut scratch))?;
			match hash {
				#(#arms),*
				method_id => Err(::anyhow::anyhow!("Unknown method id: {method_id}")),
			}
		}
	).into()
}