// bifrostlink_macros/lib.rs
use std::collections::HashSet;

use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{
    ext::IdentExt, parse::Parser, parse_macro_input, punctuated::Punctuated, spanned::Spanned,
    Error, Expr, ExprLit, FnArg, Ident, ImplItem, ItemImpl, Lit, Meta, MetaNameValue, Pat,
    ReturnType, Token, Type,
};
8
9fn int_value(nv: &MetaNameValue, expected: &str) -> syn::Result<u16> {
10 if !nv.path.is_ident(expected) {
11 return Err(Error::new(
12 nv.path.span(),
13 format!("expected `{expected} = <int>`"),
14 ));
15 }
16 match &nv.value {
17 Expr::Lit(ExprLit {
18 lit: Lit::Int(i), ..
19 }) => i.base10_parse::<u16>(),
20 other => Err(Error::new(other.span(), "expected integer literal")),
21 }
22}
23
/// Converts a `snake_case` identifier into `PascalCase`.
///
/// Each `_`-separated segment has its first character upper-cased (via
/// [`char::to_uppercase`], so multi-character expansions such as `ß` -> `SS` are kept)
/// and the remainder copied verbatim. Empty segments produced by leading, trailing or
/// doubled underscores contribute nothing.
fn snake_to_pascal(name: &str) -> String {
    name.split('_')
        .filter_map(|segment| {
            let mut chars = segment.chars();
            chars.next().map(|head| (head, chars.as_str()))
        })
        .fold(String::new(), |mut acc, (head, tail)| {
            acc.extend(head.to_uppercase());
            acc.push_str(tail);
            acc
        })
}
35
36fn expand(ns: u16, item: ItemImpl) -> syn::Result<TokenStream> {
37 let self_ty = &item.self_ty;
38 let self_ident = match &**self_ty {
39 Type::Path(p) => p
40 .path
41 .segments
42 .last()
43 .map(|s| s.ident.clone())
44 .ok_or_else(|| Error::new(self_ty.span(), "expected a named type"))?,
45 _ => return Err(Error::new(self_ty.span(), "expected a named type")),
46 };
47 let client_ident = Ident::new(&format!("{self_ident}Client"), self_ident.span());
48
49 let mut generated = Vec::new();
50 let mut client_methods = Vec::new();
51 let mut registrations = Vec::new();
52
53 for impl_item in &item.items {
54 let ImplItem::Fn(method) = impl_item else {
55 continue;
56 };
57
58 let Some(id_attr) = method.attrs.iter().find(|a| a.path().is_ident("endpoints")) else {
59 continue;
60 };
61 let metas = id_attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
62 let id = metas
63 .iter()
64 .find_map(|m| match m {
65 Meta::NameValue(nv) if nv.path.is_ident("id") => Some(int_value(nv, "id")),
66 _ => None,
67 })
68 .ok_or_else(|| Error::new(id_attr.span(), "missing `id = <int>` argument"))??;
69 let cancel = metas
70 .iter()
71 .any(|m| matches!(m, Meta::Path(p) if p.is_ident("cancel")));
72 let endpoint_id: u16 = (ns << 8) | id;
73
74 let method_name = &method.sig.ident;
75 let pascal = snake_to_pascal(&method_name.to_string());
76 let request_ident = Ident::new(&pascal, method_name.span());
77 let response_ident = Ident::new(&format!("{pascal}Response"), method_name.span());
78
79 let mut field_idents = Vec::new();
80 let mut field_types = Vec::new();
81 for arg in &method.sig.inputs {
82 let FnArg::Typed(pat_type) = arg else {
83 continue;
84 };
85 let Pat::Ident(pat_ident) = &*pat_type.pat else {
86 return Err(Error::new(
87 pat_type.pat.span(),
88 "endpoint arguments must be plain identifiers",
89 ));
90 };
91 field_idents.push(pat_ident.ident.clone());
92 field_types.push((*pat_type.ty).clone());
93 }
94
95 let ret_ty: Type = match &method.sig.output {
96 ReturnType::Type(_, ty) => (**ty).clone(),
97 ReturnType::Default => syn::parse_quote!(()),
98 };
99
100 let maybe_await = method.sig.asyncness.map(|_| quote!(.await));
101 let cancellable = cancel;
102
103 generated.push(quote! {
104 #[derive(Serialize, Deserialize)]
105 struct #request_ident {
106 #( #field_idents: #field_types ),*
107 }
108
109 #[derive(Serialize, Deserialize)]
110 struct #response_ident(#ret_ty);
111
112 ::bifrostlink::request!((#endpoint_id) #request_ident => #response_ident);
113 });
114
115 registrations.push(quote! {
116 {
117 let v = v.clone();
118 rpc.register_request_handler::<#request_ident, _>(#cancellable, move |_addr, req| {
119 let v = v.clone();
120 async move {
121 Ok(#response_ident(
122 v.#method_name( #( req.#field_idents ),* ) #maybe_await
123 ))
124 }
125 });
126 }
127 });
128
129 client_methods.push(quote! {
130 pub async fn #method_name(
131 &self,
132 #( #field_idents: #field_types ),*
133 ) -> Result<#ret_ty, C::Error> {
134 Ok(self
135 .remote
136 .request(#request_ident { #( #field_idents ),* })
137 .await?
138 .0)
139 }
140 });
141 }
142
143 let mut clean = item.clone();
144 for impl_item in &mut clean.items {
145 if let ImplItem::Fn(method) = impl_item {
146 method.attrs.retain(|a| !a.path().is_ident("endpoints"));
147 }
148 }
149
150 let (impl_generics, _ty_generics, where_clause) = item.generics.split_for_impl();
151
152 Ok(quote! {
153 #clean
154
155 #( #generated )*
156
157 impl #impl_generics #self_ty #where_clause {
158 pub fn register_endpoints<C: Config>(self, rpc: &mut ::bifrostlink::Rpc<C>) {
159 let v = ::std::sync::Arc::new(self);
160 #( #registrations )*
161 }
162 }
163
164 pub struct #client_ident<C: ::bifrostlink::Config>{
165 remote: ::bifrostlink::Remote<C>,
166 }
167 impl<C: ::bifrostlink::Config> #client_ident<C> {
168 #( #client_methods )*
169 }
170 impl<C: ::bifrostlink::Config> ::bifrostlink::declarative::RemoteEndpoints<C> for #client_ident<C> {
171 fn wrap(remote: ::bifrostlink::Remote<C>) -> Self {
172 Self { remote }
173 }
174 }
175 impl<C: ::bifrostlink::Config> Clone for #client_ident<C> {
176 fn clone(&self) -> Self {
177 Self {
178 remote: self.remote.clone(),
179 }
180 }
181 }
182 })
183}
184
185#[proc_macro_attribute]
186pub fn endpoints(
187 attrs: proc_macro::TokenStream,
188 body: proc_macro::TokenStream,
189) -> proc_macro::TokenStream {
190 let item = parse_macro_input!(body as ItemImpl);
191
192 let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
193 let metas = match parser.parse(attrs) {
194 Ok(m) => m,
195 Err(e) => return e.to_compile_error().into(),
196 };
197 let ns = match metas.iter().find_map(|m| match m {
198 Meta::NameValue(nv) if nv.path.is_ident("ns") => Some(int_value(nv, "ns")),
199 _ => None,
200 }) {
201 Some(Ok(ns)) => ns,
202 Some(Err(e)) => return e.to_compile_error().into(),
203 None => {
204 return Error::new(Span::call_site(), "missing `ns = <int>` argument")
205 .to_compile_error()
206 .into()
207 }
208 };
209
210 match expand(ns, item) {
211 Ok(ts) => ts.into(),
212 Err(e) => e.to_compile_error().into(),
213 }
214}