//! bifrostlink-macros 0.2.1
//!
//! Codegen helper for bifrostlink.
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{
	parse::Parser, parse_macro_input, punctuated::Punctuated, spanned::Spanned, Error, Expr,
	ExprLit, FnArg, Ident, ImplItem, ItemImpl, Lit, Meta, MetaNameValue, Pat, ReturnType, Token,
	Type,
};

fn int_value(nv: &MetaNameValue, expected: &str) -> syn::Result<u16> {
	if !nv.path.is_ident(expected) {
		return Err(Error::new(
			nv.path.span(),
			format!("expected `{expected} = <int>`"),
		));
	}
	match &nv.value {
		Expr::Lit(ExprLit {
			lit: Lit::Int(i), ..
		}) => i.base10_parse::<u16>(),
		other => Err(Error::new(other.span(), "expected integer literal")),
	}
}

/// Converts a `snake_case` identifier into `PascalCase` (e.g. `get_value` -> `GetValue`).
///
/// A leading `r#` raw-identifier prefix is dropped first. Keeping it would produce a name like
/// `R#type`, which is not a valid identifier and makes `Ident::new` panic. Empty segments from
/// leading, trailing or doubled underscores are skipped, so `__get__value_` -> `GetValue`.
fn snake_to_pascal(name: &str) -> String {
	let name = name.strip_prefix("r#").unwrap_or(name);
	let mut out = String::with_capacity(name.len());
	for part in name.split('_') {
		let mut chars = part.chars();
		if let Some(first) = chars.next() {
			// `to_uppercase` may yield several chars (e.g. 'ß' -> "SS").
			out.extend(first.to_uppercase());
			out.push_str(chars.as_str());
		}
	}
	out
}

fn expand(ns: u16, item: ItemImpl) -> syn::Result<TokenStream> {
	let self_ty = &item.self_ty;
	let self_ident = match &**self_ty {
		Type::Path(p) => p
			.path
			.segments
			.last()
			.map(|s| s.ident.clone())
			.ok_or_else(|| Error::new(self_ty.span(), "expected a named type"))?,
		_ => return Err(Error::new(self_ty.span(), "expected a named type")),
	};
	let client_ident = Ident::new(&format!("{self_ident}Client"), self_ident.span());

	let mut generated = Vec::new();
	let mut client_methods = Vec::new();
	let mut registrations = Vec::new();

	for impl_item in &item.items {
		let ImplItem::Fn(method) = impl_item else {
			continue;
		};

		let Some(id_attr) = method.attrs.iter().find(|a| a.path().is_ident("endpoints")) else {
			continue;
		};
		let metas = id_attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
		let id = metas
			.iter()
			.find_map(|m| match m {
				Meta::NameValue(nv) if nv.path.is_ident("id") => Some(int_value(nv, "id")),
				_ => None,
			})
			.ok_or_else(|| Error::new(id_attr.span(), "missing `id = <int>` argument"))??;
		let cancel = metas
			.iter()
			.any(|m| matches!(m, Meta::Path(p) if p.is_ident("cancel")));
		let endpoint_id: u16 = (ns << 8) | id;

		let method_name = &method.sig.ident;
		let pascal = snake_to_pascal(&method_name.to_string());
		let request_ident = Ident::new(&pascal, method_name.span());
		let response_ident = Ident::new(&format!("{pascal}Response"), method_name.span());

		let mut field_idents = Vec::new();
		let mut field_types = Vec::new();
		for arg in &method.sig.inputs {
			let FnArg::Typed(pat_type) = arg else {
				continue;
			};
			let Pat::Ident(pat_ident) = &*pat_type.pat else {
				return Err(Error::new(
					pat_type.pat.span(),
					"endpoint arguments must be plain identifiers",
				));
			};
			field_idents.push(pat_ident.ident.clone());
			field_types.push((*pat_type.ty).clone());
		}

		let ret_ty: Type = match &method.sig.output {
			ReturnType::Type(_, ty) => (**ty).clone(),
			ReturnType::Default => syn::parse_quote!(()),
		};

		let maybe_await = method.sig.asyncness.map(|_| quote!(.await));
		let cancellable = cancel;

		generated.push(quote! {
			#[derive(Serialize, Deserialize)]
			struct #request_ident {
				#( #field_idents: #field_types ),*
			}

			#[derive(Serialize, Deserialize)]
			struct #response_ident(#ret_ty);

			::bifrostlink::request!((#endpoint_id) #request_ident => #response_ident);
		});

		registrations.push(quote! {
			{
				let v = v.clone();
				rpc.register_request_handler::<#request_ident, _>(#cancellable, move |_addr, req| {
					let v = v.clone();
					async move {
						Ok(#response_ident(
							v.#method_name( #( req.#field_idents ),* ) #maybe_await
						))
					}
				});
			}
		});

		client_methods.push(quote! {
			pub async fn #method_name(
				&self,
				#( #field_idents: #field_types ),*
			) -> Result<#ret_ty, C::Error> {
				Ok(self
					.remote
					.request(#request_ident { #( #field_idents ),* })
					.await?
					.0)
			}
		});
	}

	let mut clean = item.clone();
	for impl_item in &mut clean.items {
		if let ImplItem::Fn(method) = impl_item {
			method.attrs.retain(|a| !a.path().is_ident("endpoints"));
		}
	}

	let (impl_generics, _ty_generics, where_clause) = item.generics.split_for_impl();

	Ok(quote! {
		#clean

		#( #generated )*

		impl #impl_generics #self_ty #where_clause {
			pub fn register_endpoints<C: Config>(self, rpc: &mut ::bifrostlink::Rpc<C>) {
				let v = ::std::sync::Arc::new(self);
				#( #registrations )*
			}
		}

		pub struct #client_ident<C: ::bifrostlink::Config>{
			remote: ::bifrostlink::Remote<C>,
		}
		impl<C: ::bifrostlink::Config> #client_ident<C> {
			#( #client_methods )*
		}
		impl<C: ::bifrostlink::Config> ::bifrostlink::declarative::RemoteEndpoints<C> for #client_ident<C> {
			fn wrap(remote: ::bifrostlink::Remote<C>) -> Self {
				Self { remote }
			}
		}
		impl<C: ::bifrostlink::Config> Clone for #client_ident<C> {
			fn clone(&self) -> Self {
				Self {
					remote: self.remote.clone(),
				}
			}
		}
	})
}

#[proc_macro_attribute]
pub fn endpoints(
	attrs: proc_macro::TokenStream,
	body: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
	let item = parse_macro_input!(body as ItemImpl);

	let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
	let metas = match parser.parse(attrs) {
		Ok(m) => m,
		Err(e) => return e.to_compile_error().into(),
	};
	let ns = match metas.iter().find_map(|m| match m {
		Meta::NameValue(nv) if nv.path.is_ident("ns") => Some(int_value(nv, "ns")),
		_ => None,
	}) {
		Some(Ok(ns)) => ns,
		Some(Err(e)) => return e.to_compile_error().into(),
		None => {
			return Error::new(Span::call_site(), "missing `ns = <int>` argument")
				.to_compile_error()
				.into()
		}
	};

	match expand(ns, item) {
		Ok(ts) => ts.into(),
		Err(e) => e.to_compile_error().into(),
	}
}