rinf_proc/
lib.rs

1use heck::{ToShoutySnakeCase, ToSnakeCase};
2use proc_macro::TokenStream;
3use quote::quote;
4use syn::{
5  Data, DataEnum, DataStruct, DeriveInput, Error, Fields, Ident, Index,
6  parse_macro_input,
7};
8
/// Prefix that a signal type name may not start with.
///
/// The derives lowercase the type name before comparing against this,
/// so the check is effectively case-insensitive (`Rinf…`, `RINF…`, …).
/// It is a `const` rather than a `static` because only its value is
/// used and no fixed address is needed.
const BANNED_LOWER_PREFIX: &str = "rinf";
10
11/// Marks the struct as a signal
12/// that can be nested within other signals.
13/// A `SignalPiece` cannot be sent independently
14/// and is only a partial component of `DartSignal` or `RustSignal`.
15#[proc_macro_derive(SignalPiece)]
16pub fn derive_signal_piece(input: TokenStream) -> TokenStream {
17  // Collect information about the item.
18  let ast = parse_macro_input!(input as DeriveInput);
19  let name = &ast.ident;
20  let name_lit = name.to_string();
21
22  // Check the name.
23  if name_lit.to_lowercase().starts_with(BANNED_LOWER_PREFIX) {
24    return create_name_error(ast);
25  }
26
27  // Ban generic types.
28  if ast.generics.params.iter().count() != 0 {
29    return create_generic_error(ast);
30  }
31
32  // Enforce all fields to implement the foreign signal trait.
33  let expanded = match &ast.data {
34    Data::Struct(data_struct) => get_struct_signal_impl(data_struct, name),
35    Data::Enum(data_enum) => get_enum_signal_impl(data_enum, name),
36    _ => return TokenStream::new(),
37  };
38
39  // Convert the generated code into token stream and return it.
40  TokenStream::from(expanded)
41}
42
43/// Marks the struct as a signal endpoint
44/// that contains a message from Dart to Rust.
45/// This can be marked on any type that implements `Deserialize`.
46#[proc_macro_derive(DartSignal)]
47pub fn derive_dart_signal(input: TokenStream) -> TokenStream {
48  derive_dart_signal_real(input, false)
49}
50
51/// Marks the struct as a signal endpoint
52/// that contains a message and binary from Dart to Rust.
53/// This can be marked on any type that implements `Deserialize`.
54#[proc_macro_derive(DartSignalBinary)]
55pub fn derive_dart_signal_binary(input: TokenStream) -> TokenStream {
56  derive_dart_signal_real(input, true)
57}
58
59fn derive_dart_signal_real(
60  input: TokenStream,
61  include_binary: bool,
62) -> TokenStream {
63  // Collect information about the item.
64  let ast = parse_macro_input!(input as DeriveInput);
65  let name = &ast.ident;
66  let name_lit = name.to_string();
67  let snake_name = name_lit.to_snake_case();
68  let upper_snake_name = name_lit.to_shouty_snake_case();
69
70  // Check the name.
71  if name_lit.to_lowercase().starts_with(BANNED_LOWER_PREFIX) {
72    return create_name_error(ast);
73  }
74
75  // Ban generic types.
76  if ast.generics.params.iter().count() != 0 {
77    return create_generic_error(ast);
78  }
79
80  // Enforce all fields to implement the foreign signal trait.
81  let where_clause = match &ast.data {
82    Data::Struct(data_struct) => get_struct_where_clause(data_struct),
83    Data::Enum(data_enum) => get_enum_where_clause(data_enum),
84    _ => return TokenStream::new(),
85  };
86
87  // Collect identifiers and names.
88  let channel_type_ident = Ident::new(&format!("{}Channel", name), name.span());
89  let channel_const_ident =
90    Ident::new(&format!("{}_CHANNEL", upper_snake_name), name.span());
91  let extern_fn_name = &format!("rinf_send_dart_signal_{}", snake_name);
92  let extern_fn_ident = Ident::new(extern_fn_name, name.span());
93
94  // Implement methods and extern functions.
95  let signal_trait = if include_binary {
96    quote! { rinf::DartSignalBinary }
97  } else {
98    quote! { rinf::DartSignal }
99  };
100  let expanded = quote! {
101    impl #signal_trait for #name #where_clause {
102      fn get_dart_signal_receiver(
103      ) -> rinf::SignalReceiver<rinf::DartSignalPack<Self>> {
104        #channel_const_ident.1.clone()
105      }
106    }
107
108    impl #name #where_clause {
109      fn send_dart_signal(message_bytes: &[u8], binary: &[u8]) {
110        use rinf::{AppError, DartSignalPack, debug_print, deserialize};
111        let message_result: Result<#name, AppError> =
112          deserialize(message_bytes)
113          .map_err(|_| AppError::CannotDecodeMessage);
114        let message = match message_result {
115          Ok(inner) => inner,
116          Err(err) => {
117            let type_name = #name_lit;
118            debug_print!("{}: \n{}", type_name, err);
119            return;
120          }
121        };
122        let dart_signal = DartSignalPack {
123          message,
124          binary: binary.to_vec(),
125        };
126        #channel_const_ident.0.send(dart_signal);
127      }
128    }
129
130    type #channel_type_ident = std::sync::LazyLock<(
131      rinf::SignalSender<rinf::DartSignalPack<#name>>,
132      rinf::SignalReceiver<rinf::DartSignalPack<#name>>,
133    )>;
134
135    static #channel_const_ident: #channel_type_ident =
136      std::sync::LazyLock::new(rinf::signal_channel);
137
138    #[cfg(not(target_family = "wasm"))]
139    #[unsafe(no_mangle)]
140    unsafe extern "C" fn #extern_fn_ident(
141      message_pointer: *const u8,
142      message_size: usize,
143      binary_pointer: *const u8,
144      binary_size: usize,
145    ) {
146      use std::slice::from_raw_parts;
147      let message_bytes = from_raw_parts(message_pointer, message_size);
148      let binary = from_raw_parts(binary_pointer, binary_size);
149      #name::send_dart_signal(message_bytes, binary);
150    }
151
152    #[cfg(target_family = "wasm")]
153    #[wasm_bindgen::prelude::wasm_bindgen]
154    pub fn #extern_fn_ident(message_bytes: &[u8], binary: &[u8]) {
155      #name::send_dart_signal(message_bytes, binary);
156    }
157  };
158
159  // Convert the generated code into token stream and return it.
160  TokenStream::from(expanded)
161}
162
163/// Marks the struct as a signal endpoint
164/// that contains a message from Rust to Dart.
165/// This can be marked on any type that implements `Serialize`.
166#[proc_macro_derive(RustSignal)]
167pub fn derive_rust_signal(input: TokenStream) -> TokenStream {
168  derive_rust_signal_real(input, false)
169}
170
171/// Marks the struct as a signal endpoint
172/// that contains a message and binary from Rust to Dart.
173/// This can be marked on any type that implements `Serialize`.
174#[proc_macro_derive(RustSignalBinary)]
175pub fn derive_rust_signal_binary(input: TokenStream) -> TokenStream {
176  derive_rust_signal_real(input, true)
177}
178
179fn derive_rust_signal_real(
180  input: TokenStream,
181  include_binary: bool,
182) -> TokenStream {
183  // Collect information about the item.
184  let ast = parse_macro_input!(input as DeriveInput);
185  let name = &ast.ident;
186  let name_lit = name.to_string();
187
188  // Check the name.
189  if name_lit.to_lowercase().starts_with(BANNED_LOWER_PREFIX) {
190    return create_name_error(ast);
191  }
192
193  // Ban generic types.
194  if ast.generics.params.iter().count() != 0 {
195    return create_generic_error(ast);
196  }
197
198  // Enforce all fields to implement the foreign signal trait.
199  let where_clause = match &ast.data {
200    Data::Struct(data_struct) => get_struct_where_clause(data_struct),
201    Data::Enum(data_enum) => get_enum_where_clause(data_enum),
202    _ => return TokenStream::new(),
203  };
204
205  // Implement methods and extern functions.
206  let expanded = if include_binary {
207    quote! {
208      impl rinf::RustSignalBinary for #name #where_clause {
209        fn send_signal_to_dart(&self, binary: Vec<u8>) {
210          use rinf::{AppError, debug_print, send_rust_signal, serialize};
211          let type_name = #name_lit;
212          let message_result: Result<Vec<u8>, AppError> =
213            serialize(&self)
214            .map_err(|_| AppError::CannotEncodeMessage);
215          let message_bytes = match message_result {
216            Ok(inner) => inner,
217            Err(err) => {
218              debug_print!("{}: \n{}", type_name, err);
219              return;
220            }
221          };
222          let result = send_rust_signal(type_name, message_bytes, binary);
223          if let Err(err) = result {
224            debug_print!("{}: \n{}", type_name, err);
225          }
226        }
227      }
228    }
229  } else {
230    quote! {
231      impl rinf::RustSignal for #name #where_clause {
232        fn send_signal_to_dart(&self) {
233          use rinf::{AppError, debug_print, send_rust_signal, serialize};
234          let type_name = #name_lit;
235          let message_result: Result<Vec<u8>, AppError> =
236            serialize(&self)
237            .map_err(|_| AppError::CannotEncodeMessage);
238          let message_bytes = match message_result {
239            Ok(inner) => inner,
240            Err(err) => {
241              debug_print!("{}: \n{}", type_name, err);
242              return;
243            }
244          };
245          let result = send_rust_signal(type_name, message_bytes, Vec::new());
246          if let Err(err) = result {
247            debug_print!("{}: \n{}", type_name, err);
248          }
249        }
250      }
251    }
252  };
253
254  // Convert the generated code into token stream and return it.
255  TokenStream::from(expanded)
256}
257
258/// Enforces all fields of a struct to have the foreign signal trait.
259/// This assists with type-safe development.
260fn get_struct_where_clause(
261  data_struct: &DataStruct,
262) -> proc_macro2::TokenStream {
263  let field_types: Vec<_> = match &data_struct.fields {
264    // For named structs (struct-like), extract the field types.
265    Fields::Named(all) => all.named.iter().map(|f| &f.ty).collect(),
266    // For unnamed structs (tuple-like), extract the field types.
267    Fields::Unnamed(all) => all.unnamed.iter().map(|f| &f.ty).collect(),
268    // For unit-like structs (without any inner data), do nothing.
269    Fields::Unit => Vec::new(),
270  };
271  quote! {
272    where #(#field_types: rinf::SignalPiece),*
273  }
274}
275
276/// Enforces all fields of an enum variant to have the foreign signal trait.
277/// This assists with type-safe development.
278fn get_enum_where_clause(data_enum: &DataEnum) -> proc_macro2::TokenStream {
279  let variant_types: Vec<_> = data_enum
280    .variants
281    .iter()
282    .flat_map(|variant| {
283      match &variant.fields {
284        // For named variants (struct-like), extract the field types.
285        Fields::Named(all) => all.named.iter().map(|f| &f.ty).collect(),
286        // For unnamed variants (tuple-like), extract the field types.
287        Fields::Unnamed(all) => all.unnamed.iter().map(|f| &f.ty).collect(),
288        // For unit-like variants (without any inner data), do nothing.
289        Fields::Unit => Vec::new(),
290      }
291    })
292    .collect();
293
294  quote! {
295    where #(#variant_types: rinf::SignalPiece),*
296  }
297}
298
299fn get_struct_signal_impl(
300  data_struct: &DataStruct,
301  name: &Ident,
302) -> proc_macro2::TokenStream {
303  match &data_struct.fields {
304    Fields::Named(named_fields) => {
305      let fields = named_fields
306        .named
307        .iter()
308        .filter_map(|field| field.ident.clone());
309      quote! {
310        impl rinf::SignalPiece for #name {
311          fn be_signal_piece(&self) {
312            use rinf::SignalPiece;
313            #(SignalPiece::be_signal_piece(&self.#fields);)*
314          }
315        }
316      }
317    }
318    Fields::Unnamed(unnamed_fields) => {
319      let field_indices: Vec<Index> =
320        (0..unnamed_fields.unnamed.len()).map(Index::from).collect();
321      quote! {
322        impl rinf::SignalPiece for #name {
323          fn be_signal_piece(&self) {
324            use rinf::SignalPiece;
325            #(SignalPiece::be_signal_piece(&self.#field_indices);)*
326          }
327        }
328      }
329    }
330    Fields::Unit => {
331      quote! {
332        impl rinf::SignalPiece for #name {
333          fn be_signal_piece(&self) {
334            // Unit struct has no fields to check
335          }
336        }
337      }
338    }
339  }
340}
341
342fn get_enum_signal_impl(
343  data_enum: &DataEnum,
344  name: &Ident,
345) -> proc_macro2::TokenStream {
346  let variants = data_enum.variants.iter().map(|variant| {
347    let variant_ident = &variant.ident;
348    match &variant.fields {
349      Fields::Named(named_fields) => {
350        let fields: Vec<Ident> = named_fields
351          .named
352          .iter()
353          .filter_map(|field| field.ident.clone())
354          .collect();
355        quote! {
356          Self::#variant_ident { #(#fields),* } => {
357            use rinf::SignalPiece;
358            #(SignalPiece::be_signal_piece(#fields);)*
359          }
360        }
361      }
362      Fields::Unnamed(unnamed_fields) => {
363        let field_indices: Vec<Index> =
364          (0..unnamed_fields.unnamed.len()).map(Index::from).collect();
365        let field_vars: Vec<Ident> = field_indices
366          .iter()
367          .map(|i| {
368            Ident::new(&format!("field_{}", i.index), variant_ident.span())
369          })
370          .collect();
371        quote! {
372          Self::#variant_ident(#(#field_vars),*) => {
373            use rinf::SignalPiece;
374            #(SignalPiece::be_signal_piece(#field_vars);)*
375          }
376        }
377      }
378      Fields::Unit => {
379        quote! {
380          Self::#variant_ident => {}
381        }
382      }
383    }
384  });
385  quote! {
386    impl rinf::SignalPiece for #name {
387      fn be_signal_piece(&self) {
388        match self {
389          #( #variants )*
390        }
391      }
392    }
393  }
394}
395
396fn create_generic_error(ast: DeriveInput) -> TokenStream {
397  Error::new_spanned(ast.generics, "A foreign signal type cannot be generic")
398    .to_compile_error()
399    .into()
400}
401
402fn create_name_error(ast: DeriveInput) -> TokenStream {
403  Error::new_spanned(
404    ast.ident,
405    format!(
406      "The name of a foreign signal cannot start with `{}`",
407      BANNED_LOWER_PREFIX
408    ),
409  )
410  .to_compile_error()
411  .into()
412}