Skip to main content

bifrostlink_macros/
lib.rs

1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use syn::{
4	parse::Parser, parse_macro_input, punctuated::Punctuated, spanned::Spanned, Error, Expr,
5	ExprLit, FnArg, Ident, ImplItem, ItemImpl, Lit, Meta, MetaNameValue, Pat, ReturnType, Token,
6	Type,
7};
8
9fn int_value(nv: &MetaNameValue, expected: &str) -> syn::Result<u16> {
10	if !nv.path.is_ident(expected) {
11		return Err(Error::new(
12			nv.path.span(),
13			format!("expected `{expected} = <int>`"),
14		));
15	}
16	match &nv.value {
17		Expr::Lit(ExprLit {
18			lit: Lit::Int(i), ..
19		}) => i.base10_parse::<u16>(),
20		other => Err(Error::new(other.span(), "expected integer literal")),
21	}
22}
23
/// Converts a `snake_case` name to `PascalCase`.
///
/// The input is split on `_`; empty segments (leading, trailing or doubled
/// underscores) contribute nothing. The first character of every remaining
/// segment is upper-cased (which may expand to several characters) and the
/// rest of the segment is copied unchanged.
fn snake_to_pascal(name: &str) -> String {
	name.split('_')
		.filter_map(|segment| {
			let mut rest = segment.chars();
			// `?` drops empty segments, which have no first character.
			let initial = rest.next()?;
			Some(initial.to_uppercase().chain(rest))
		})
		.flatten()
		.collect()
}
35
36fn expand(ns: u16, item: ItemImpl) -> syn::Result<TokenStream> {
37	let self_ty = &item.self_ty;
38	let self_ident = match &**self_ty {
39		Type::Path(p) => p
40			.path
41			.segments
42			.last()
43			.map(|s| s.ident.clone())
44			.ok_or_else(|| Error::new(self_ty.span(), "expected a named type"))?,
45		_ => return Err(Error::new(self_ty.span(), "expected a named type")),
46	};
47	let client_ident = Ident::new(&format!("{self_ident}Client"), self_ident.span());
48
49	let mut generated = Vec::new();
50	let mut client_methods = Vec::new();
51	let mut registrations = Vec::new();
52
53	for impl_item in &item.items {
54		let ImplItem::Fn(method) = impl_item else {
55			continue;
56		};
57
58		let Some(id_attr) = method.attrs.iter().find(|a| a.path().is_ident("endpoints")) else {
59			continue;
60		};
61		let metas = id_attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
62		let id = metas
63			.iter()
64			.find_map(|m| match m {
65				Meta::NameValue(nv) if nv.path.is_ident("id") => Some(int_value(nv, "id")),
66				_ => None,
67			})
68			.ok_or_else(|| Error::new(id_attr.span(), "missing `id = <int>` argument"))??;
69		let cancel = metas
70			.iter()
71			.any(|m| matches!(m, Meta::Path(p) if p.is_ident("cancel")));
72		let endpoint_id: u16 = (ns << 8) | id;
73
74		let method_name = &method.sig.ident;
75		let pascal = snake_to_pascal(&method_name.to_string());
76		let request_ident = Ident::new(&pascal, method_name.span());
77		let response_ident = Ident::new(&format!("{pascal}Response"), method_name.span());
78
79		let mut field_idents = Vec::new();
80		let mut field_types = Vec::new();
81		for arg in &method.sig.inputs {
82			let FnArg::Typed(pat_type) = arg else {
83				continue;
84			};
85			let Pat::Ident(pat_ident) = &*pat_type.pat else {
86				return Err(Error::new(
87					pat_type.pat.span(),
88					"endpoint arguments must be plain identifiers",
89				));
90			};
91			field_idents.push(pat_ident.ident.clone());
92			field_types.push((*pat_type.ty).clone());
93		}
94
95		let ret_ty: Type = match &method.sig.output {
96			ReturnType::Type(_, ty) => (**ty).clone(),
97			ReturnType::Default => syn::parse_quote!(()),
98		};
99
100		let maybe_await = method.sig.asyncness.map(|_| quote!(.await));
101		let cancellable = cancel;
102
103		generated.push(quote! {
104			#[derive(Serialize, Deserialize)]
105			struct #request_ident {
106				#( #field_idents: #field_types ),*
107			}
108
109			#[derive(Serialize, Deserialize)]
110			struct #response_ident(#ret_ty);
111
112			::bifrostlink::request!((#endpoint_id) #request_ident => #response_ident);
113		});
114
115		registrations.push(quote! {
116			{
117				let v = v.clone();
118				rpc.register_request_handler::<#request_ident, _>(#cancellable, move |_addr, req| {
119					let v = v.clone();
120					async move {
121						Ok(#response_ident(
122							v.#method_name( #( req.#field_idents ),* ) #maybe_await
123						))
124					}
125				});
126			}
127		});
128
129		client_methods.push(quote! {
130			pub async fn #method_name(
131				&self,
132				#( #field_idents: #field_types ),*
133			) -> Result<#ret_ty, C::Error> {
134				Ok(self
135					.remote
136					.request(#request_ident { #( #field_idents ),* })
137					.await?
138					.0)
139			}
140		});
141	}
142
143	let mut clean = item.clone();
144	for impl_item in &mut clean.items {
145		if let ImplItem::Fn(method) = impl_item {
146			method.attrs.retain(|a| !a.path().is_ident("endpoints"));
147		}
148	}
149
150	let (impl_generics, _ty_generics, where_clause) = item.generics.split_for_impl();
151
152	Ok(quote! {
153		#clean
154
155		#( #generated )*
156
157		impl #impl_generics #self_ty #where_clause {
158			pub fn register_endpoints<C: Config>(self, rpc: &mut ::bifrostlink::Rpc<C>) {
159				let v = ::std::sync::Arc::new(self);
160				#( #registrations )*
161			}
162		}
163
164		pub struct #client_ident<C: ::bifrostlink::Config>{
165			remote: ::bifrostlink::Remote<C>,
166		}
167		impl<C: ::bifrostlink::Config> #client_ident<C> {
168			#( #client_methods )*
169		}
170		impl<C: ::bifrostlink::Config> ::bifrostlink::declarative::RemoteEndpoints<C> for #client_ident<C> {
171			fn wrap(remote: ::bifrostlink::Remote<C>) -> Self {
172				Self { remote }
173			}
174		}
175		impl<C: ::bifrostlink::Config> Clone for #client_ident<C> {
176			fn clone(&self) -> Self {
177				Self {
178					remote: self.remote.clone(),
179				}
180			}
181		}
182	})
183}
184
185#[proc_macro_attribute]
186pub fn endpoints(
187	attrs: proc_macro::TokenStream,
188	body: proc_macro::TokenStream,
189) -> proc_macro::TokenStream {
190	let item = parse_macro_input!(body as ItemImpl);
191
192	let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
193	let metas = match parser.parse(attrs) {
194		Ok(m) => m,
195		Err(e) => return e.to_compile_error().into(),
196	};
197	let ns = match metas.iter().find_map(|m| match m {
198		Meta::NameValue(nv) if nv.path.is_ident("ns") => Some(int_value(nv, "ns")),
199		_ => None,
200	}) {
201		Some(Ok(ns)) => ns,
202		Some(Err(e)) => return e.to_compile_error().into(),
203		None => {
204			return Error::new(Span::call_site(), "missing `ns = <int>` argument")
205				.to_compile_error()
206				.into()
207		}
208	};
209
210	match expand(ns, item) {
211		Ok(ts) => ts.into(),
212		Err(e) => e.to_compile_error().into(),
213	}
214}