1use heck::{ToShoutySnakeCase, ToSnakeCase};
2use proc_macro::TokenStream;
3use quote::quote;
4use syn::{
5 Data, DataEnum, DataStruct, DeriveInput, Error, Fields, Ident, Index,
6 parse_macro_input,
7};
8
/// Lowercase prefix that derived signal type names must not start with.
///
/// The derive entry points lowercase the type name before comparing, so the
/// check is case-insensitive. Declared as `const` because it is a plain
/// compile-time string: no fixed address or interior mutability is needed.
const BANNED_LOWER_PREFIX: &str = "rinf";
10
11#[proc_macro_derive(SignalPiece)]
16pub fn derive_signal_piece(input: TokenStream) -> TokenStream {
17 let ast = parse_macro_input!(input as DeriveInput);
19 let name = &ast.ident;
20 let name_lit = name.to_string();
21
22 if name_lit.to_lowercase().starts_with(BANNED_LOWER_PREFIX) {
24 return create_name_error(ast);
25 }
26
27 if ast.generics.params.iter().count() != 0 {
29 return create_generic_error(ast);
30 }
31
32 let expanded = match &ast.data {
34 Data::Struct(data_struct) => get_struct_signal_impl(data_struct, name),
35 Data::Enum(data_enum) => get_enum_signal_impl(data_enum, name),
36 _ => return TokenStream::new(),
37 };
38
39 TokenStream::from(expanded)
41}
42
43#[proc_macro_derive(DartSignal)]
47pub fn derive_dart_signal(input: TokenStream) -> TokenStream {
48 derive_dart_signal_real(input, false)
49}
50
51#[proc_macro_derive(DartSignalBinary)]
55pub fn derive_dart_signal_binary(input: TokenStream) -> TokenStream {
56 derive_dart_signal_real(input, true)
57}
58
59fn derive_dart_signal_real(
60 input: TokenStream,
61 include_binary: bool,
62) -> TokenStream {
63 let ast = parse_macro_input!(input as DeriveInput);
65 let name = &ast.ident;
66 let name_lit = name.to_string();
67 let snake_name = name_lit.to_snake_case();
68 let upper_snake_name = name_lit.to_shouty_snake_case();
69
70 if name_lit.to_lowercase().starts_with(BANNED_LOWER_PREFIX) {
72 return create_name_error(ast);
73 }
74
75 if ast.generics.params.iter().count() != 0 {
77 return create_generic_error(ast);
78 }
79
80 let where_clause = match &ast.data {
82 Data::Struct(data_struct) => get_struct_where_clause(data_struct),
83 Data::Enum(data_enum) => get_enum_where_clause(data_enum),
84 _ => return TokenStream::new(),
85 };
86
87 let channel_type_ident = Ident::new(&format!("{}Channel", name), name.span());
89 let channel_const_ident =
90 Ident::new(&format!("{}_CHANNEL", upper_snake_name), name.span());
91 let extern_fn_name = &format!("rinf_send_dart_signal_{}", snake_name);
92 let extern_fn_ident = Ident::new(extern_fn_name, name.span());
93
94 let signal_trait = if include_binary {
96 quote! { rinf::DartSignalBinary }
97 } else {
98 quote! { rinf::DartSignal }
99 };
100 let expanded = quote! {
101 impl #signal_trait for #name #where_clause {
102 fn get_dart_signal_receiver(
103 ) -> rinf::SignalReceiver<rinf::DartSignalPack<Self>> {
104 #channel_const_ident.1.clone()
105 }
106 }
107
108 impl #name #where_clause {
109 fn send_dart_signal(message_bytes: &[u8], binary: &[u8]) {
110 use rinf::{AppError, DartSignalPack, debug_print, deserialize};
111 let message_result: Result<#name, AppError> =
112 deserialize(message_bytes)
113 .map_err(|_| AppError::CannotDecodeMessage);
114 let message = match message_result {
115 Ok(inner) => inner,
116 Err(err) => {
117 let type_name = #name_lit;
118 debug_print!("{}: \n{}", type_name, err);
119 return;
120 }
121 };
122 let dart_signal = DartSignalPack {
123 message,
124 binary: binary.to_vec(),
125 };
126 #channel_const_ident.0.send(dart_signal);
127 }
128 }
129
130 type #channel_type_ident = std::sync::LazyLock<(
131 rinf::SignalSender<rinf::DartSignalPack<#name>>,
132 rinf::SignalReceiver<rinf::DartSignalPack<#name>>,
133 )>;
134
135 static #channel_const_ident: #channel_type_ident =
136 std::sync::LazyLock::new(rinf::signal_channel);
137
138 #[cfg(not(target_family = "wasm"))]
139 #[unsafe(no_mangle)]
140 unsafe extern "C" fn #extern_fn_ident(
141 message_pointer: *const u8,
142 message_size: usize,
143 binary_pointer: *const u8,
144 binary_size: usize,
145 ) {
146 use std::slice::from_raw_parts;
147 let message_bytes = from_raw_parts(message_pointer, message_size);
148 let binary = from_raw_parts(binary_pointer, binary_size);
149 #name::send_dart_signal(message_bytes, binary);
150 }
151
152 #[cfg(target_family = "wasm")]
153 #[wasm_bindgen::prelude::wasm_bindgen]
154 pub fn #extern_fn_ident(message_bytes: &[u8], binary: &[u8]) {
155 #name::send_dart_signal(message_bytes, binary);
156 }
157 };
158
159 TokenStream::from(expanded)
161}
162
163#[proc_macro_derive(RustSignal)]
167pub fn derive_rust_signal(input: TokenStream) -> TokenStream {
168 derive_rust_signal_real(input, false)
169}
170
171#[proc_macro_derive(RustSignalBinary)]
175pub fn derive_rust_signal_binary(input: TokenStream) -> TokenStream {
176 derive_rust_signal_real(input, true)
177}
178
179fn derive_rust_signal_real(
180 input: TokenStream,
181 include_binary: bool,
182) -> TokenStream {
183 let ast = parse_macro_input!(input as DeriveInput);
185 let name = &ast.ident;
186 let name_lit = name.to_string();
187
188 if name_lit.to_lowercase().starts_with(BANNED_LOWER_PREFIX) {
190 return create_name_error(ast);
191 }
192
193 if ast.generics.params.iter().count() != 0 {
195 return create_generic_error(ast);
196 }
197
198 let where_clause = match &ast.data {
200 Data::Struct(data_struct) => get_struct_where_clause(data_struct),
201 Data::Enum(data_enum) => get_enum_where_clause(data_enum),
202 _ => return TokenStream::new(),
203 };
204
205 let expanded = if include_binary {
207 quote! {
208 impl rinf::RustSignalBinary for #name #where_clause {
209 fn send_signal_to_dart(&self, binary: Vec<u8>) {
210 use rinf::{AppError, debug_print, send_rust_signal, serialize};
211 let type_name = #name_lit;
212 let message_result: Result<Vec<u8>, AppError> =
213 serialize(&self)
214 .map_err(|_| AppError::CannotEncodeMessage);
215 let message_bytes = match message_result {
216 Ok(inner) => inner,
217 Err(err) => {
218 debug_print!("{}: \n{}", type_name, err);
219 return;
220 }
221 };
222 let result = send_rust_signal(type_name, message_bytes, binary);
223 if let Err(err) = result {
224 debug_print!("{}: \n{}", type_name, err);
225 }
226 }
227 }
228 }
229 } else {
230 quote! {
231 impl rinf::RustSignal for #name #where_clause {
232 fn send_signal_to_dart(&self) {
233 use rinf::{AppError, debug_print, send_rust_signal, serialize};
234 let type_name = #name_lit;
235 let message_result: Result<Vec<u8>, AppError> =
236 serialize(&self)
237 .map_err(|_| AppError::CannotEncodeMessage);
238 let message_bytes = match message_result {
239 Ok(inner) => inner,
240 Err(err) => {
241 debug_print!("{}: \n{}", type_name, err);
242 return;
243 }
244 };
245 let result = send_rust_signal(type_name, message_bytes, Vec::new());
246 if let Err(err) = result {
247 debug_print!("{}: \n{}", type_name, err);
248 }
249 }
250 }
251 }
252 };
253
254 TokenStream::from(expanded)
256}
257
258fn get_struct_where_clause(
261 data_struct: &DataStruct,
262) -> proc_macro2::TokenStream {
263 let field_types: Vec<_> = match &data_struct.fields {
264 Fields::Named(all) => all.named.iter().map(|f| &f.ty).collect(),
266 Fields::Unnamed(all) => all.unnamed.iter().map(|f| &f.ty).collect(),
268 Fields::Unit => Vec::new(),
270 };
271 quote! {
272 where #(#field_types: rinf::SignalPiece),*
273 }
274}
275
276fn get_enum_where_clause(data_enum: &DataEnum) -> proc_macro2::TokenStream {
279 let variant_types: Vec<_> = data_enum
280 .variants
281 .iter()
282 .flat_map(|variant| {
283 match &variant.fields {
284 Fields::Named(all) => all.named.iter().map(|f| &f.ty).collect(),
286 Fields::Unnamed(all) => all.unnamed.iter().map(|f| &f.ty).collect(),
288 Fields::Unit => Vec::new(),
290 }
291 })
292 .collect();
293
294 quote! {
295 where #(#variant_types: rinf::SignalPiece),*
296 }
297}
298
299fn get_struct_signal_impl(
300 data_struct: &DataStruct,
301 name: &Ident,
302) -> proc_macro2::TokenStream {
303 match &data_struct.fields {
304 Fields::Named(named_fields) => {
305 let fields = named_fields
306 .named
307 .iter()
308 .filter_map(|field| field.ident.clone());
309 quote! {
310 impl rinf::SignalPiece for #name {
311 fn be_signal_piece(&self) {
312 use rinf::SignalPiece;
313 #(SignalPiece::be_signal_piece(&self.#fields);)*
314 }
315 }
316 }
317 }
318 Fields::Unnamed(unnamed_fields) => {
319 let field_indices: Vec<Index> =
320 (0..unnamed_fields.unnamed.len()).map(Index::from).collect();
321 quote! {
322 impl rinf::SignalPiece for #name {
323 fn be_signal_piece(&self) {
324 use rinf::SignalPiece;
325 #(SignalPiece::be_signal_piece(&self.#field_indices);)*
326 }
327 }
328 }
329 }
330 Fields::Unit => {
331 quote! {
332 impl rinf::SignalPiece for #name {
333 fn be_signal_piece(&self) {
334 }
336 }
337 }
338 }
339 }
340}
341
342fn get_enum_signal_impl(
343 data_enum: &DataEnum,
344 name: &Ident,
345) -> proc_macro2::TokenStream {
346 let variants = data_enum.variants.iter().map(|variant| {
347 let variant_ident = &variant.ident;
348 match &variant.fields {
349 Fields::Named(named_fields) => {
350 let fields: Vec<Ident> = named_fields
351 .named
352 .iter()
353 .filter_map(|field| field.ident.clone())
354 .collect();
355 quote! {
356 Self::#variant_ident { #(#fields),* } => {
357 use rinf::SignalPiece;
358 #(SignalPiece::be_signal_piece(#fields);)*
359 }
360 }
361 }
362 Fields::Unnamed(unnamed_fields) => {
363 let field_indices: Vec<Index> =
364 (0..unnamed_fields.unnamed.len()).map(Index::from).collect();
365 let field_vars: Vec<Ident> = field_indices
366 .iter()
367 .map(|i| {
368 Ident::new(&format!("field_{}", i.index), variant_ident.span())
369 })
370 .collect();
371 quote! {
372 Self::#variant_ident(#(#field_vars),*) => {
373 use rinf::SignalPiece;
374 #(SignalPiece::be_signal_piece(#field_vars);)*
375 }
376 }
377 }
378 Fields::Unit => {
379 quote! {
380 Self::#variant_ident => {}
381 }
382 }
383 }
384 });
385 quote! {
386 impl rinf::SignalPiece for #name {
387 fn be_signal_piece(&self) {
388 match self {
389 #( #variants )*
390 }
391 }
392 }
393 }
394}
395
396fn create_generic_error(ast: DeriveInput) -> TokenStream {
397 Error::new_spanned(ast.generics, "A foreign signal type cannot be generic")
398 .to_compile_error()
399 .into()
400}
401
402fn create_name_error(ast: DeriveInput) -> TokenStream {
403 Error::new_spanned(
404 ast.ident,
405 format!(
406 "The name of a foreign signal cannot start with `{}`",
407 BANNED_LOWER_PREFIX
408 ),
409 )
410 .to_compile_error()
411 .into()
412}