1use proc_macro::TokenStream;
2use quote::{ToTokens, quote};
3use std::collections::HashSet;
4use syn::{
5 DeriveInput, FnArg, GenericArgument, ItemFn, PathArguments, ReturnType, Type,
6 parse_macro_input,
7};
8
9#[proc_macro_attribute]
11pub fn transition(attr: TokenStream, item: TokenStream) -> TokenStream {
12 let mut input_fn = parse_macro_input!(item as ItemFn);
13 let original_ident = input_fn.sig.ident.clone();
14 let vis = &input_fn.vis;
15 let block = &input_fn.block;
16 let inputs = &input_fn.sig.inputs;
17
18 let internal_fn_ident = quote::format_ident!("__ranvier_fn_{}", original_ident);
22 input_fn.sig.ident = internal_fn_ident.clone();
23
24 let mut res_override = None;
26 let mut bus_allow_types: Vec<Type> = Vec::new();
27 let mut bus_deny_types: Vec<Type> = Vec::new();
28 let mut bus_allow_specified = false;
29 let mut bus_deny_specified = false;
30 let mut x_pos = None;
31 let mut y_pos = None;
32 let mut schema_flag = false;
33 if !attr.is_empty() {
34 let parser = syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated;
35 if let Ok(metas) = syn::parse::Parser::parse2(parser, attr.into()) {
36 for meta in metas {
37 match meta {
38 syn::Meta::Path(path) if path.is_ident("schema") => {
39 schema_flag = true;
40 }
41 syn::Meta::NameValue(nv) => {
42 if nv.path.is_ident("res") {
43 res_override = Some(nv.value);
44 } else if nv.path.is_ident("bus_allow") {
45 bus_allow_specified = true;
46 match parse_type_array_expr(&nv.value) {
47 Ok(types) => bus_allow_types = types,
48 Err(err) => return err.to_compile_error().into(),
49 }
50 } else if nv.path.is_ident("bus_deny") {
51 bus_deny_specified = true;
52 match parse_type_array_expr(&nv.value) {
53 Ok(types) => bus_deny_types = types,
54 Err(err) => return err.to_compile_error().into(),
55 }
56 } else if nv.path.is_ident("x") {
57 x_pos = Some(nv.value);
58 } else if nv.path.is_ident("y") {
59 y_pos = Some(nv.value);
60 }
61 }
62 _ => {}
63 }
64 }
65 }
66 }
67
68 if let Err(err) = validate_bus_policy_types(&bus_allow_types, &bus_deny_types) {
69 return err.to_compile_error().into();
70 }
71
72 let input_type = if let Some(FnArg::Typed(pat_type)) = inputs.first() {
74 let ty = &pat_type.ty;
75 quote! { #ty }
76 } else {
77 quote! { () }
78 };
79
80 let second_is_bus = inputs.get(1).map(is_bus_argument).unwrap_or(false);
82 let res_type = if let Some(res) = res_override {
83 quote! { #res }
84 } else if second_is_bus {
85 quote! { () }
86 } else if let Some(FnArg::Typed(pat_type)) = inputs.get(1) {
87 let ty = &pat_type.ty;
88 if let Type::Reference(type_ref) = &**ty {
89 let elem = &type_ref.elem;
90 quote! { #elem }
91 } else {
92 quote! { #ty }
93 }
94 } else {
95 quote! { () }
96 };
97
98 let (output_type, error_type) = if let ReturnType::Type(_, ty) = &input_fn.sig.output {
100 extract_outcome_types(ty).unwrap_or((quote! { () }, quote! { anyhow::Error }))
101 } else {
102 (quote! { () }, quote! { anyhow::Error })
103 };
104
105 let arg_count = inputs.len();
107 let run_body = match arg_count {
108 1 => {
109 if let Some(FnArg::Typed(pat_type)) = inputs.first() {
110 let pat = &pat_type.pat;
111 quote! { let #pat = input; #block }
112 } else {
113 quote! { #block }
114 }
115 }
116 2 => {
117 let mut bindings = quote! {};
118 if let Some(FnArg::Typed(pat_type)) = inputs.first() {
119 let pat = &pat_type.pat;
120 bindings.extend(quote! { let #pat = input; });
121 }
122 if second_is_bus {
123 if let Some(FnArg::Typed(pat_type)) = inputs.get(1) {
124 let pat = &pat_type.pat;
125 bindings.extend(quote! { let #pat = bus; });
126 }
127 } else if let Some(FnArg::Typed(pat_type)) = inputs.get(1) {
128 let pat = &pat_type.pat;
129 bindings.extend(quote! { let #pat = resources; });
130 }
131 quote! { #bindings #block }
132 }
133 3 => {
134 let mut bindings = quote! {};
135 if let Some(FnArg::Typed(pat_type)) = inputs.first() {
136 let pat = &pat_type.pat;
137 bindings.extend(quote! { let #pat = input; });
138 }
139 if let Some(FnArg::Typed(pat_type)) = inputs.get(1) {
140 let pat = &pat_type.pat;
141 bindings.extend(quote! { let #pat = resources; });
142 }
143 if let Some(FnArg::Typed(pat_type)) = inputs.get(2) {
144 let pat = &pat_type.pat;
145 bindings.extend(quote! { let #pat = bus; });
146 }
147 quote! { #bindings #block }
148 }
149 _ => quote! { #block },
150 };
151
152 let bus_policy_method = if bus_allow_specified || bus_deny_specified {
153 let allow_expr = if bus_allow_specified {
154 quote! {
155 Some(vec![#(ranvier_core::bus::BusTypeRef::of::<#bus_allow_types>()),*])
156 }
157 } else {
158 quote! { None }
159 };
160 let deny_expr = if bus_deny_specified {
161 quote! {
162 vec![#(ranvier_core::bus::BusTypeRef::of::<#bus_deny_types>()),*]
163 }
164 } else {
165 quote! { Vec::new() }
166 };
167 quote! {
168 fn bus_access_policy(&self) -> Option<ranvier_core::bus::BusAccessPolicy> {
169 Some(ranvier_core::bus::BusAccessPolicy {
170 allow: #allow_expr,
171 deny: #deny_expr,
172 })
173 }
174 }
175 } else {
176 quote! {}
177 };
178
179 let position_method = if let (Some(x), Some(y)) = (x_pos, y_pos) {
180 quote! {
181 fn position(&self) -> Option<(f32, f32)> {
182 Some((#x as f32, #y as f32))
183 }
184 }
185 } else {
186 quote! {}
187 };
188
189 let schema_method = if schema_flag {
190 quote! {
191 fn input_schema(&self) -> Option<serde_json::Value> {
192 let schema = schemars::schema_for!(#input_type);
193 serde_json::to_value(schema).ok()
194 }
195 }
196 } else {
197 quote! {}
198 };
199
200 let expanded = quote! {
201 #[derive(Clone, Default)]
202 #[allow(non_camel_case_types)]
203 #vis struct #original_ident;
204
205 #[::async_trait::async_trait]
206 impl ranvier_core::transition::Transition<#input_type, #output_type> for #original_ident {
207 type Error = #error_type;
208 type Resources = #res_type;
209
210 #bus_policy_method
211 #position_method
212 #schema_method
213
214 async fn run(
215 &self,
216 input: #input_type,
217 resources: &Self::Resources,
218 bus: &mut ranvier_core::bus::Bus,
219 ) -> ranvier_core::outcome::Outcome<#output_type, Self::Error> {
220 #run_body
221 }
222 }
223
224 #input_fn
225 };
226
227 TokenStream::from(expanded)
228}
229
230#[proc_macro_attribute]
232pub fn route(attr: TokenStream, item: TokenStream) -> TokenStream {
233 let input_fn = parse_macro_input!(item as ItemFn);
234 let original_ident = input_fn.sig.ident.clone();
235 let vis = &input_fn.vis;
236
237 let parser = syn::punctuated::Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated;
238 let attr_args = parse_macro_input!(attr with parser);
239
240 if attr_args.len() < 2 {
241 return TokenStream::from(
242 quote! { compile_error!("route attribute requires method and path"); },
243 );
244 }
245
246 let method = &attr_args[0];
247 let path = &attr_args[1];
248
249 let struct_name = quote::format_ident!("Route_{}", original_ident);
251
252 let expanded = quote! {
253 #input_fn
254
255 #[allow(non_camel_case_types)]
256 #vis struct #struct_name;
257
258 impl #struct_name {
259 pub const METHOD: &'static str = stringify!(#method);
260 pub const PATH: &'static str = #path;
261 }
262 };
263
264 TokenStream::from(expanded)
265}
266
267#[proc_macro]
269pub fn ranvier_router(input: TokenStream) -> TokenStream {
270 let parser = syn::punctuated::Punctuated::<syn::Ident, syn::Token![,]>::parse_terminated;
271 let idents = parse_macro_input!(input with parser);
272
273 let mut registrations = quote! {};
274
275 for ident in idents {
276 let route_struct = quote::format_ident!("Route_{}", ident);
277 registrations.extend(quote! {
278 let method_str = #route_struct::METHOD;
279 let method = match method_str {
280 "GET" => http::Method::GET,
281 "POST" => http::Method::POST,
282 "PUT" => http::Method::PUT,
283 "DELETE" => http::Method::DELETE,
284 _ => http::Method::GET,
285 };
286 ingress = ingress.route_method(method, #route_struct::PATH, #ident().await);
287 });
288 }
289
290 let expanded = quote! {
291 {
292 let mut ingress = ranvier_http::HttpIngress::new();
293 #registrations
294 ingress
295 }
296 };
297
298 TokenStream::from(expanded)
299}
300
301fn extract_outcome_types(
302 ty: &Type,
303) -> Option<(quote::__private::TokenStream, quote::__private::TokenStream)> {
304 if let Type::Path(type_path) = ty
305 && let Some(segment) = type_path.path.segments.last()
306 && segment.ident == "Outcome"
307 && let PathArguments::AngleBracketed(args) = &segment.arguments
308 {
309 let mut type_args = args.args.iter();
310 if let (Some(GenericArgument::Type(to)), Some(GenericArgument::Type(err))) =
311 (type_args.next(), type_args.next())
312 {
313 return Some((quote! { #to }, quote! { #err }));
314 }
315 }
316 None
317}
318
319fn is_bus_argument(arg: &FnArg) -> bool {
320 let FnArg::Typed(pat_type) = arg else {
321 return false;
322 };
323 let Type::Reference(type_ref) = &*pat_type.ty else {
324 return false;
325 };
326 let Type::Path(type_path) = &*type_ref.elem else {
327 return false;
328 };
329 type_path
330 .path
331 .segments
332 .last()
333 .map(|segment| segment.ident == "Bus")
334 .unwrap_or(false)
335}
336
337fn parse_type_array_expr(expr: &syn::Expr) -> syn::Result<Vec<Type>> {
338 let syn::Expr::Array(array) = expr else {
339 return Err(syn::Error::new_spanned(
340 expr,
341 "expected array syntax: [TypeA, TypeB]",
342 ));
343 };
344
345 array
346 .elems
347 .iter()
348 .map(|elem| syn::parse2::<Type>(elem.to_token_stream()))
349 .collect()
350}
351
352#[proc_macro_derive(ResourceRequirement)]
369pub fn derive_resource_requirement(input: TokenStream) -> TokenStream {
370 let input = parse_macro_input!(input as DeriveInput);
371 let name = &input.ident;
372 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
373
374 let expanded = quote! {
375 impl #impl_generics ranvier_core::transition::ResourceRequirement for #name #ty_generics #where_clause {}
376 };
377
378 TokenStream::from(expanded)
379}
380
381fn validate_bus_policy_types(allow: &[Type], deny: &[Type]) -> syn::Result<()> {
382 let mut allow_keys = HashSet::new();
383 for ty in allow {
384 let key = ty.to_token_stream().to_string();
385 if !allow_keys.insert(key) {
386 return Err(syn::Error::new_spanned(
387 ty,
388 "duplicate type in bus_allow list",
389 ));
390 }
391 }
392
393 let mut deny_keys = HashSet::new();
394 for ty in deny {
395 let key = ty.to_token_stream().to_string();
396 if !deny_keys.insert(key) {
397 return Err(syn::Error::new_spanned(
398 ty,
399 "duplicate type in bus_deny list",
400 ));
401 }
402 }
403
404 for ty in allow {
405 let key = ty.to_token_stream().to_string();
406 if deny_keys.contains(&key) {
407 return Err(syn::Error::new_spanned(
408 ty,
409 "same type cannot be present in both bus_allow and bus_deny",
410 ));
411 }
412 }
413
414 Ok(())
415}
416
#[cfg(test)]
mod tests {
    //! Unit tests for the syntactic helpers behind `#[transition]`.

    use super::{is_bus_argument, parse_type_array_expr, validate_bus_policy_types};
    use syn::{Expr, FnArg, Type, parse_quote};

    #[test]
    fn detects_mut_bus_reference_argument() {
        let arg: FnArg = parse_quote!(bus: &mut Bus);
        assert!(is_bus_argument(&arg));
    }

    #[test]
    fn detects_fully_qualified_bus_reference_argument() {
        let arg: FnArg = parse_quote!(bus: &mut ranvier_core::bus::Bus);
        assert!(is_bus_argument(&arg));
    }

    #[test]
    fn detects_shared_bus_reference_argument() {
        let arg: FnArg = parse_quote!(bus: &Bus);
        assert!(is_bus_argument(&arg));
    }

    #[test]
    fn rejects_bus_taken_by_value() {
        let arg: FnArg = parse_quote!(bus: Bus);
        assert!(!is_bus_argument(&arg));
    }

    #[test]
    fn rejects_non_bus_argument() {
        let arg: FnArg = parse_quote!(res: &MyResources);
        assert!(!is_bus_argument(&arg));
    }

    #[test]
    fn parses_type_array_expr_for_bus_policy() {
        let expr: Expr = parse_quote!([i32, alloc::string::String]);
        let parsed = parse_type_array_expr(&expr).expect("type array should parse");
        assert_eq!(parsed.len(), 2);
    }

    #[test]
    fn parses_empty_type_array() {
        let expr: Expr = parse_quote!([]);
        let parsed = parse_type_array_expr(&expr).expect("empty array should parse");
        assert!(parsed.is_empty());
    }

    #[test]
    fn rejects_non_array_bus_policy_expr() {
        let expr: Expr = parse_quote!(i32);
        // `.err()` avoids requiring `Debug` on `syn::Type`.
        let err = parse_type_array_expr(&expr)
            .err()
            .expect("non-array expression should fail");
        assert!(err.to_string().contains("expected array syntax"));
    }

    #[test]
    fn validates_bus_policy_rejects_duplicate_allow() {
        let allow = vec![parse_quote!(i32), parse_quote!(i32)];
        let deny = Vec::new();
        let err = validate_bus_policy_types(&allow, &deny).expect_err("should fail");
        assert!(err.to_string().contains("duplicate type in bus_allow"));
    }

    #[test]
    fn validates_bus_policy_rejects_duplicate_deny() {
        let allow: Vec<Type> = Vec::new();
        let deny: Vec<Type> = vec![parse_quote!(String), parse_quote!(String)];
        let err = validate_bus_policy_types(&allow, &deny).expect_err("should fail");
        assert!(err.to_string().contains("duplicate type in bus_deny"));
    }

    #[test]
    fn validates_bus_policy_rejects_allow_deny_conflict() {
        let allow = vec![parse_quote!(i32)];
        let deny = vec![parse_quote!(i32)];
        let err = validate_bus_policy_types(&allow, &deny).expect_err("should fail");
        assert!(
            err.to_string()
                .contains("same type cannot be present in both bus_allow and bus_deny")
        );
    }

    #[test]
    fn validates_bus_policy_accepts_disjoint_lists() {
        let allow: Vec<Type> = vec![parse_quote!(i32)];
        let deny: Vec<Type> = vec![parse_quote!(String)];
        assert!(validate_bus_policy_types(&allow, &deny).is_ok());
    }
}