agility_macros/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::{DeriveInput, Fields, GenericArgument, PathArguments, Type, TypePath, parse_macro_input};
6
7/// Helper function to check if a type is `Signal<'a, T>` and extract the inner type `T`.
8///
9/// This is used by the `#[derive(Lift)]` macro generator to detect struct fields
10/// that are reactive `Signal` types and to obtain their wrapped inner types so
11/// the generated inner struct contains the unwrapped types.
12fn extract_signal_inner_type(ty: &Type) -> Option<&Type> {
13    if let Type::Path(TypePath { path, .. }) = ty {
14        // Get the last segment of the path (e.g., "Signal" from "crate::signal::Signal")
15        let last_segment = path.segments.last()?;
16
17        // Check if it's named "Signal"
18        if last_segment.ident != "Signal" {
19            return None;
20        }
21
22        // Extract the generic arguments
23        if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
24            // Signal<'a, T> has two generic arguments: lifetime 'a and type T
25            // We want to extract T (the second argument)
26            let mut iter = args.args.iter();
27            iter.next()?; // Skip the lifetime
28
29            if let Some(GenericArgument::Type(inner_ty)) = iter.next() {
30                return Some(inner_ty);
31            }
32        }
33    }
34    None
35}
36
37/// Helper function to check if a type is `SignalSync<'a, T>` and extract the inner type `T`.
38///
39/// This is used by the `#[derive(LiftSync)]` macro generator to detect struct fields
40/// that are thread-safe reactive `SignalSync` types and to obtain their wrapped inner
41/// types so the generated inner struct contains the unwrapped types.
42fn extract_signal_sync_inner_type(ty: &Type) -> Option<&Type> {
43    if let Type::Path(TypePath { path, .. }) = ty {
44        // Get the last segment of the path (e.g., "SignalSync" from "crate::signal_sync::SignalSync")
45        let last_segment = path.segments.last()?;
46
47        // Check if it's named "SignalSync"
48        if last_segment.ident != "SignalSync" {
49            return None;
50        }
51
52        // Extract the generic arguments
53        if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
54            // SignalSync<'a, T> has two generic arguments: lifetime 'a and type T
55            // We want to extract T (the second argument)
56            let mut iter = args.args.iter();
57            iter.next()?; // Skip the lifetime
58
59            if let Some(GenericArgument::Type(inner_ty)) = iter.next() {
60                return Some(inner_ty);
61            }
62        }
63    }
64    None
65}
66
67/// Derive macro to lift a struct into a reactive `Signal`.
68///
69/// Applying `#[derive(Lift)]` to a struct generates an inner unwrapped struct and
70/// a `lift(self)` method that produces a `crate::signal::Signal<'a, _Inner>` where
71/// any fields of type `Signal<'a, T>` are replaced by their inner `T` in the
72/// generated inner struct. The generated `lift` method wires up reactions so that
73/// changes to any signal fields propagate into the resulting lifted `Signal`.
74///
75/// Example:
76/// ```rust
77/// use crate::signal::Signal;
78///
79/// #[derive(Lift)]
80/// struct Example<'a> {
81///     a: Signal<'a, i32>,
82///     b: String,
83/// }
84///
85/// let example = Example { a: Signal::new(1), b: "hi".to_string() };
86/// let lifted = example.lift(); // Signal<'a, _Example>
87/// lifted.with(|inner| println!("a = {}", inner.a));
88/// ```
89#[proc_macro_derive(Lift)]
90pub fn derive_lift(input: TokenStream) -> TokenStream {
91    let input = parse_macro_input!(input as DeriveInput);
92    let name = &input.ident;
93    let generics = &input.generics;
94    let vis = &input.vis;
95
96    // Get the fields
97    let fields = match &input.data {
98        syn::Data::Struct(data) => match &data.fields {
99            Fields::Named(fields) => &fields.named,
100            _ => panic!("Lift only supports structs with named fields"),
101        },
102        _ => panic!("Lift can only be derived for structs"),
103    };
104
105    // Separate signal fields from regular fields by checking the type
106    let mut signal_fields = Vec::new();
107    let mut regular_fields = Vec::new();
108
109    for field in fields {
110        if extract_signal_inner_type(&field.ty).is_some() {
111            signal_fields.push(field);
112        } else {
113            regular_fields.push(field);
114        }
115    }
116
117    // Generate the inner struct name (prefixed with underscore)
118    let inner_name = format_ident!("_{}", name);
119
120    // Generate fields for the inner struct (unwrapped types)
121    let inner_struct_fields = fields.iter().map(|field| {
122        let field_name = &field.ident;
123        let field_vis = &field.vis;
124
125        // If it's a Signal<'a, T>, use T; otherwise use the original type
126        let field_ty = if let Some(inner_ty) = extract_signal_inner_type(&field.ty) {
127            inner_ty
128        } else {
129            &field.ty
130        };
131
132        quote! {
133            #field_vis #field_name: #field_ty
134        }
135    });
136
137    // Generate the reactive setup code for signal fields
138    let reactive_setup = signal_fields.iter().map(|field| {
139        let field_name = &field.ident;
140
141        quote! {
142            {
143                let result_signal_weak = std::rc::Rc::downgrade(&result_signal.0);
144                let source_for_closure = std::rc::Rc::downgrade(&instance.#field_name.0);
145                let react_fn = Box::new(move || {
146                    if let Some(result_sig) = result_signal_weak.upgrade() {
147                        if !*result_sig.explicitly_modified.borrow() {
148                            if let Some(source) = source_for_closure.upgrade() {
149                                // Swap values instead of cloning
150                                std::mem::swap(
151                                    &mut *source.value.borrow_mut(),
152                                    &mut result_sig.value.borrow_mut().#field_name,
153                                );
154                            }
155                        }
156                    }
157                });
158                let cloned_signal = instance.#field_name.clone();
159                cloned_signal.0.react_fns.borrow_mut().push(react_fn);
160                cloned_signal.0.successors.borrow_mut().push(crate::signal::WeakSignalRef::new(&result_signal));
161            }
162        }
163    });
164
165    // Generate the inner struct initialization from main struct (using swap - no clone!)
166    let inner_from_main = signal_fields.iter().map(|field| {
167        let field_name = &field.ident;
168        quote! {
169            #field_name: {
170                let mut temp = unsafe { std::mem::MaybeUninit::uninit().assume_init() };
171                std::mem::swap(&mut *instance.#field_name.0.value.borrow_mut(), &mut temp);
172                temp
173            }
174        }
175    });
176
177    let regular_from_main = regular_fields.iter().map(|field| {
178        let field_name = &field.ident;
179        quote! {
180            #field_name: instance.#field_name.clone()
181        }
182    });
183
184    // Generate restore code for signal fields (swap back after creating result_signal)
185    let restore_values = signal_fields.iter().map(|field| {
186        let field_name = &field.ident;
187        quote! {
188            std::mem::swap(
189                &mut *instance.#field_name.0.value.borrow_mut(),
190                &mut result_signal.0.value.borrow_mut().#field_name,
191            );
192        }
193    });
194
195    // Generate Clone trait bounds for signal fields (using the unwrapped inner type)
196    let signal_clone_bounds = signal_fields.iter().filter_map(|field| {
197        extract_signal_inner_type(&field.ty).map(|inner_ty| {
198            quote! { #inner_ty: Clone }
199        })
200    });
201
202    // Generate Clone trait bounds for regular fields
203    let regular_clone_bounds = regular_fields.iter().map(|field| {
204        let field_ty = &field.ty;
205        quote! { #field_ty: Clone }
206    });
207
208    // Extract generics for impl block
209    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
210
211    // Create a version of generics without lifetimes for the inner struct
212    let type_params = generics.type_params().map(|tp| &tp.ident);
213    let inner_ty_generics = if generics.type_params().count() > 0 {
214        quote! { <#(#type_params),*> }
215    } else {
216        quote! {}
217    };
218
219    let expanded = quote! {
220        // Inner struct (unwrapped types)
221        #[derive(Clone)]
222        #vis struct #inner_name #inner_ty_generics {
223            #(#inner_struct_fields),*
224        }
225
226        impl #impl_generics #name #ty_generics #where_clause {
227            pub fn lift(self) -> crate::signal::Signal<'a, #inner_name #inner_ty_generics>
228            where
229                #(#signal_clone_bounds,)*
230                #(#regular_clone_bounds,)*
231            {
232                let instance = self;
233                let initial_inner = #inner_name {
234                    #(#inner_from_main,)*
235                    #(#regular_from_main),*
236                };
237
238                let result_signal = crate::signal::Signal::new(initial_inner);
239
240                // Restore original values by swapping back
241                #(#restore_values)*
242
243                #(#reactive_setup)*
244
245                result_signal
246            }
247        }
248    };
249
250    TokenStream::from(expanded)
251}
252
253/// Derive macro to lift a struct into a thread-safe reactive `SignalSync`.
254///
255/// Applying `#[derive(LiftSync)]` to a struct generates an inner unwrapped struct and
256/// a `lift(self)` method that produces a `crate::signal_sync::SignalSync<'a, _Inner>` where
257/// any fields of type `SignalSync<'a, T>` are replaced by their inner `T` in the
258/// generated inner struct. The generated `lift` method wires up thread-safe reactions
259/// so that changes to any signal fields propagate into the resulting lifted `SignalSync`.
260///
261/// Example:
262/// ```rust
263/// use crate::signal_sync::SignalSync;
264///
265/// #[derive(LiftSync)]
266/// struct ExampleSync<'a> {
267///     a: SignalSync<'a, i32>,
268///     b: String,
269/// }
270///
271/// let example = ExampleSync { a: SignalSync::new(1), b: "hi".to_string() };
272/// let lifted = example.lift(); // SignalSync<'a, _ExampleSync>
273/// lifted.with(|inner| println!("a = {}", inner.a));
274/// ```
275#[proc_macro_derive(LiftSync)]
276pub fn derive_lift_sync(input: TokenStream) -> TokenStream {
277    let input = parse_macro_input!(input as DeriveInput);
278    let name = &input.ident;
279    let generics = &input.generics;
280    let vis = &input.vis;
281
282    // Get the fields
283    let fields = match &input.data {
284        syn::Data::Struct(data) => match &data.fields {
285            Fields::Named(fields) => &fields.named,
286            _ => panic!("LiftSync only supports structs with named fields"),
287        },
288        _ => panic!("LiftSync can only be derived for structs"),
289    };
290
291    // Separate signal fields from regular fields by checking the type
292    let mut signal_fields = Vec::new();
293    let mut regular_fields = Vec::new();
294
295    for field in fields {
296        if extract_signal_sync_inner_type(&field.ty).is_some() {
297            signal_fields.push(field);
298        } else {
299            regular_fields.push(field);
300        }
301    }
302
303    // Generate the inner struct name (prefixed with underscore)
304    let inner_name = format_ident!("_{}", name);
305
306    // Generate fields for the inner struct (unwrapped types)
307    let inner_struct_fields = fields.iter().map(|field| {
308        let field_name = &field.ident;
309        let field_vis = &field.vis;
310
311        // If it's a SignalSync<'a, T>, use T; otherwise use the original type
312        let field_ty = if let Some(inner_ty) = extract_signal_sync_inner_type(&field.ty) {
313            inner_ty
314        } else {
315            &field.ty
316        };
317
318        quote! {
319            #field_vis #field_name: #field_ty
320        }
321    });
322
323    // Generate the reactive setup code for signal fields (thread-safe version)
324    let reactive_setup = signal_fields.iter().map(|field| {
325        let field_name = &field.ident;
326
327        quote! {
328            {
329                let result_signal_weak = std::sync::Arc::downgrade(&result_signal.0);
330                let source_for_closure = std::sync::Arc::downgrade(&instance.#field_name.0);
331                let react_fn = Box::new(move || {
332                    if let Some(result_sig) = result_signal_weak.upgrade() {
333                        if !result_sig.explicitly_modified.load(std::sync::atomic::Ordering::Acquire) {
334                            if let Some(source) = source_for_closure.upgrade() {
335                                // Swap values instead of cloning
336                                std::mem::swap(
337                                    &mut *source.value.lock().unwrap(),
338                                    &mut result_sig.value.lock().unwrap().#field_name,
339                                );
340                            }
341                        }
342                    }
343                });
344                let cloned_signal = instance.#field_name.clone();
345                cloned_signal.0.react_fns.write().unwrap().push(react_fn);
346                cloned_signal.0.successors.write().unwrap().push(crate::signal_sync::WeakSignalRefSync::new(&result_signal));
347            }
348        }
349    });
350
351    // Generate the inner struct initialization from main struct (using swap - no clone!)
352    let inner_from_main = signal_fields.iter().map(|field| {
353        let field_name = &field.ident;
354        quote! {
355            #field_name: {
356                let mut temp = unsafe { std::mem::MaybeUninit::uninit().assume_init() };
357                std::mem::swap(&mut *instance.#field_name.0.value.lock().unwrap(), &mut temp);
358                temp
359            }
360        }
361    });
362
363    let regular_from_main = regular_fields.iter().map(|field| {
364        let field_name = &field.ident;
365        quote! {
366            #field_name: instance.#field_name.clone()
367        }
368    });
369
370    // Generate restore code for signal fields (swap back after creating result_signal)
371    let restore_values = signal_fields.iter().map(|field| {
372        let field_name = &field.ident;
373        quote! {
374            std::mem::swap(
375                &mut *instance.#field_name.0.value.lock().unwrap(),
376                &mut result_signal.0.value.lock().unwrap().#field_name,
377            );
378        }
379    });
380
381    // Generate Clone + Send + Sync trait bounds for signal fields (using the unwrapped inner type)
382    let signal_clone_bounds = signal_fields.iter().filter_map(|field| {
383        extract_signal_sync_inner_type(&field.ty).map(|inner_ty| {
384            quote! { #inner_ty: Clone + Send + Sync }
385        })
386    });
387
388    // Generate Clone trait bounds for regular fields
389    let regular_clone_bounds = regular_fields.iter().map(|field| {
390        let field_ty = &field.ty;
391        quote! { #field_ty: Clone }
392    });
393
394    // Extract generics for impl block
395    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
396
397    // Create a version of generics without lifetimes for the inner struct
398    let type_params = generics.type_params().map(|tp| &tp.ident);
399    let inner_ty_generics = if generics.type_params().count() > 0 {
400        quote! { <#(#type_params),*> }
401    } else {
402        quote! {}
403    };
404
405    let expanded = quote! {
406        // Inner struct (unwrapped types)
407        #[derive(Clone)]
408        #vis struct #inner_name #inner_ty_generics {
409            #(#inner_struct_fields),*
410        }
411
412        impl #impl_generics #name #ty_generics #where_clause {
413            pub fn lift(self) -> crate::signal_sync::SignalSync<'a, #inner_name #inner_ty_generics>
414            where
415                #(#signal_clone_bounds,)*
416                #(#regular_clone_bounds,)*
417            {
418                let instance = self;
419                let initial_inner = #inner_name {
420                    #(#inner_from_main,)*
421                    #(#regular_from_main),*
422                };
423
424                let result_signal = crate::signal_sync::SignalSync::new(initial_inner);
425
426                // Restore original values by swapping back
427                #(#restore_values)*
428
429                #(#reactive_setup)*
430
431                result_signal
432            }
433        }
434    };
435
436    TokenStream::from(expanded)
437}