1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{parse_macro_input, DeriveInput, Fields, ItemImpl, ImplItem, FnArg, ReturnType, Pat};
4
5#[proc_macro_derive(Service)]
27pub fn derive_service(input: TokenStream) -> TokenStream {
28 let input = parse_macro_input!(input as DeriveInput);
29 let name = &input.ident;
30
31 let fields = match &input.data {
32 syn::Data::Struct(data) => match &data.fields {
33 Fields::Named(fields) => &fields.named,
34 _ => return syn::Error::new_spanned(&input, "#[derive(Service)] only supports structs with named fields")
35 .to_compile_error()
36 .into(),
37 },
38 _ => return syn::Error::new_spanned(&input, "#[derive(Service)] only supports structs")
39 .to_compile_error()
40 .into(),
41 };
42
43 let field_inits = fields.iter().map(|f| {
44 let field_name = f.ident.as_ref().unwrap();
45 quote! { #field_name: app.extract() }
46 });
47
48 let expanded = quote! {
49 impl ::archy::ServiceFactory for #name {
50 fn create(app: &::archy::App) -> Self {
51 #name {
52 #(#field_inits),*
53 }
54 }
55 }
56 };
57
58 TokenStream::from(expanded)
59}
60
61#[proc_macro_attribute]
77pub fn service(_attr: TokenStream, item: TokenStream) -> TokenStream {
78 let input = parse_macro_input!(item as ItemImpl);
79 let service_name = match &*input.self_ty {
80 syn::Type::Path(type_path) => type_path.path.segments.last().unwrap().ident.clone(),
81 _ => return syn::Error::new_spanned(&input.self_ty, "#[service] must be applied to an impl block for a named type")
82 .to_compile_error()
83 .into(),
84 };
85
86 let msg_enum_name = format_ident!("{}Msg", service_name);
87 let client_trait_name = format_ident!("{}Client", service_name);
88
89 let mut methods = Vec::new();
91 for item in &input.items {
92 if let ImplItem::Fn(method) = item {
93 let is_pub = matches!(method.vis, syn::Visibility::Public(_));
94 let is_async = method.sig.asyncness.is_some();
95 let has_self = method.sig.inputs.first().map_or(false, |arg| matches!(arg, FnArg::Receiver(_)));
96
97 if is_pub && is_async && has_self {
98 methods.push(method);
99 }
100 }
101 }
102
103 let msg_variants = methods.iter().map(|method| {
105 let method_name = &method.sig.ident;
106 let variant_name = to_pascal_case(&method_name.to_string());
107 let variant_ident = format_ident!("{}", variant_name);
108
109 let params: Vec<_> = method.sig.inputs.iter().skip(1).filter_map(|arg| {
111 if let FnArg::Typed(pat_type) = arg {
112 if let Pat::Ident(pat_ident) = &*pat_type.pat {
113 let name = &pat_ident.ident;
114 let ty = &pat_type.ty;
115 return Some(quote! { #name: #ty });
116 }
117 }
118 None
119 }).collect();
120
121 if is_unit_return(&method.sig.output) {
123 quote! {
124 #variant_ident { #(#params),* }
125 }
126 } else {
127 let return_type = match &method.sig.output {
128 ReturnType::Default => quote! { () },
129 ReturnType::Type(_, ty) => quote! { #ty },
130 };
131 quote! {
132 #variant_ident { #(#params,)* respond: ::archy::tokio::sync::oneshot::Sender<#return_type> }
133 }
134 }
135 });
136
137 let match_arms = methods.iter().map(|method| {
139 let method_name = &method.sig.ident;
140 let variant_name = to_pascal_case(&method_name.to_string());
141 let variant_ident = format_ident!("{}", variant_name);
142
143 let param_names: Vec<_> = method.sig.inputs.iter().skip(1).filter_map(|arg| {
145 if let FnArg::Typed(pat_type) = arg {
146 if let Pat::Ident(pat_ident) = &*pat_type.pat {
147 return Some(&pat_ident.ident);
148 }
149 }
150 None
151 }).collect();
152
153 let method_call = if param_names.is_empty() {
154 quote! { self.#method_name().await }
155 } else {
156 quote! { self.#method_name(#(#param_names),*).await }
157 };
158
159 if is_unit_return(&method.sig.output) {
161 let param_pattern = if param_names.is_empty() {
162 quote! {}
163 } else {
164 quote! { #(#param_names),* }
165 };
166 quote! {
167 #msg_enum_name::#variant_ident { #param_pattern } => {
168 #method_call;
169 }
170 }
171 } else {
172 let param_pattern = if param_names.is_empty() {
173 quote! { respond }
174 } else {
175 quote! { #(#param_names,)* respond }
176 };
177 quote! {
178 #msg_enum_name::#variant_ident { #param_pattern } => {
179 let result = #method_call;
180 let _ = respond.send(result);
181 }
182 }
183 }
184 });
185
186 let client_trait_methods = methods.iter().map(|method| {
188 let method_name = &method.sig.ident;
189
190 let params: Vec<_> = method.sig.inputs.iter().skip(1).filter_map(|arg| {
192 if let FnArg::Typed(pat_type) = arg {
193 if let Pat::Ident(pat_ident) = &*pat_type.pat {
194 let name = &pat_ident.ident;
195 let ty = &pat_type.ty;
196 return Some(quote! { #name: #ty });
197 }
198 }
199 None
200 }).collect();
201
202 let return_type = match &method.sig.output {
204 ReturnType::Default => quote! { () },
205 ReturnType::Type(_, ty) => quote! { #ty },
206 };
207
208 quote! {
209 async fn #method_name(&self, #(#params),*) -> ::std::result::Result<#return_type, ::archy::ServiceError>;
210 }
211 });
212
213 let client_impl_methods = methods.iter().map(|method| {
215 let method_name = &method.sig.ident;
216 let variant_name = to_pascal_case(&method_name.to_string());
217 let variant_ident = format_ident!("{}", variant_name);
218
219 let params: Vec<_> = method.sig.inputs.iter().skip(1).filter_map(|arg| {
221 if let FnArg::Typed(pat_type) = arg {
222 if let Pat::Ident(pat_ident) = &*pat_type.pat {
223 let name = &pat_ident.ident;
224 let ty = &pat_type.ty;
225 return Some((name.clone(), quote! { #ty }));
226 }
227 }
228 None
229 }).collect();
230
231 let param_decls: Vec<_> = params.iter().map(|(name, ty)| quote! { #name: #ty }).collect();
232 let param_names: Vec<_> = params.iter().map(|(name, _)| name).collect();
233
234 let return_type = match &method.sig.output {
236 ReturnType::Default => quote! { () },
237 ReturnType::Type(_, ty) => quote! { #ty },
238 };
239
240 if is_unit_return(&method.sig.output) {
242 let msg_construction = if param_names.is_empty() {
243 quote! { #msg_enum_name::#variant_ident {} }
244 } else {
245 quote! { #msg_enum_name::#variant_ident { #(#param_names),* } }
246 };
247
248 quote! {
249 async fn #method_name(&self, #(#param_decls),*) -> ::std::result::Result<#return_type, ::archy::ServiceError> {
250 self.sender.send(#msg_construction).await
251 .map_err(|_| ::archy::ServiceError::ChannelClosed)?;
252 Ok(())
253 }
254 }
255 } else {
256 let msg_construction = if param_names.is_empty() {
257 quote! { #msg_enum_name::#variant_ident { respond: tx } }
258 } else {
259 quote! { #msg_enum_name::#variant_ident { #(#param_names,)* respond: tx } }
260 };
261
262 quote! {
263 async fn #method_name(&self, #(#param_decls),*) -> ::std::result::Result<#return_type, ::archy::ServiceError> {
264 let (tx, rx) = ::archy::tokio::sync::oneshot::channel();
265 self.sender.send(#msg_construction).await
266 .map_err(|_| ::archy::ServiceError::ChannelClosed)?;
267 rx.await.map_err(|_| ::archy::ServiceError::ServiceDropped)
268 }
269 }
270 }
271 });
272
273 let expanded = quote! {
274 #input
276
277 pub enum #msg_enum_name {
279 #(#msg_variants),*
280 }
281
282 impl ::archy::Service for #service_name {
284 type Message = #msg_enum_name;
285
286 fn create(app: &::archy::App) -> Self {
287 <Self as ::archy::ServiceFactory>::create(app)
288 }
289
290 fn handle(self: ::std::sync::Arc<Self>, msg: Self::Message) -> impl ::std::future::Future<Output = ()> + Send {
291 async move {
292 match msg {
293 #(#match_arms)*
294 }
295 }
296 }
297 }
298
299 #[allow(async_fn_in_trait)]
301 pub trait #client_trait_name {
302 #(#client_trait_methods)*
303 }
304
305 impl #client_trait_name for ::archy::Client<#service_name> {
307 #(#client_impl_methods)*
308 }
309 };
310
311 TokenStream::from(expanded)
312}
313
/// Converts a `snake_case` identifier into `PascalCase`.
///
/// Underscores are dropped and the character following each run of underscores (as
/// well as the very first character) is upper-cased; all other characters are kept
/// as-is. For example `get_user` becomes `GetUser` and `_private` becomes `Private`.
fn to_pascal_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut capitalize_next = true;
    for c in s.chars() {
        if c == '_' {
            capitalize_next = true;
        } else if capitalize_next {
            // Unicode-aware upper-casing so non-ASCII identifiers such as `émettre`
            // become `Émettre`; a single char may map to several (e.g. `ß` -> `SS`).
            result.extend(c.to_uppercase());
            capitalize_next = false;
        } else {
            result.push(c);
        }
    }
    result
}
329
330fn is_unit_return(output: &ReturnType) -> bool {
332 match output {
333 ReturnType::Default => true,
334 ReturnType::Type(_, ty) => {
335 if let syn::Type::Tuple(tuple) = &**ty {
336 tuple.elems.is_empty()
337 } else {
338 false
339 }
340 }
341 }
342}