agility_macros/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::{DeriveInput, Fields, GenericArgument, PathArguments, Type, TypePath, parse_macro_input};
6
7/// Helper function to check if a type is `Signal<'a, T>` and extract the inner type `T`.
8///
9/// This is used by the `#[derive(Lift)]` macro generator to detect struct fields
10/// that are reactive `Signal` types and to obtain their wrapped inner types so
11/// the generated inner struct contains the unwrapped types.
12fn extract_signal_inner_type(ty: &Type) -> Option<&Type> {
13    if let Type::Path(TypePath { path, .. }) = ty {
14        // Get the last segment of the path (e.g., "Signal" from "crate::signal::Signal")
15        let last_segment = path.segments.last()?;
16
17        // Check if it's named "Signal"
18        if last_segment.ident != "Signal" {
19            return None;
20        }
21
22        // Extract the generic arguments
23        if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
24            // Signal<'a, T> has two generic arguments: lifetime 'a and type T
25            // We want to extract T (the second argument)
26            let mut iter = args.args.iter();
27            iter.next()?; // Skip the lifetime
28
29            if let Some(GenericArgument::Type(inner_ty)) = iter.next() {
30                return Some(inner_ty);
31            }
32        }
33    }
34    None
35}
36
37/// Helper function to check if a type is `SignalSync<'a, T>` and extract the inner type `T`.
38///
39/// This is used by the `#[derive(LiftSync)]` macro generator to detect struct fields
40/// that are thread-safe reactive `SignalSync` types and to obtain their wrapped inner
41/// types so the generated inner struct contains the unwrapped types.
42fn extract_signal_sync_inner_type(ty: &Type) -> Option<&Type> {
43    if let Type::Path(TypePath { path, .. }) = ty {
44        // Get the last segment of the path (e.g., "SignalSync" from "crate::signal_sync::SignalSync")
45        let last_segment = path.segments.last()?;
46
47        // Check if it's named "SignalSync"
48        if last_segment.ident != "SignalSync" {
49            return None;
50        }
51
52        // Extract the generic arguments
53        if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
54            // SignalSync<'a, T> has two generic arguments: lifetime 'a and type T
55            // We want to extract T (the second argument)
56            let mut iter = args.args.iter();
57            iter.next()?; // Skip the lifetime
58
59            if let Some(GenericArgument::Type(inner_ty)) = iter.next() {
60                return Some(inner_ty);
61            }
62        }
63    }
64    None
65}
66
67/// Derive macro to lift a struct into a reactive `Signal`.
68///
69/// Applying `#[derive(Lift)]` to a struct generates an inner unwrapped struct and
70/// a `lift(self)` method that produces a `crate::signal::Signal<'a, _Inner>` where
71/// any fields of type `Signal<'a, T>` are replaced by their inner `T` in the
72/// generated inner struct. The generated `lift` method wires up reactions so that
73/// changes to any signal fields propagate into the resulting lifted `Signal`.
74///
75/// Example:
76/// ```rust
77/// use crate::signal::Signal;
78///
79/// #[derive(Lift)]
80/// struct Example<'a> {
81///     a: Signal<'a, i32>,
82///     b: String,
83/// }
84///
85/// let example = Example { a: Signal::new(1), b: "hi".to_string() };
86/// let lifted = example.lift(); // Signal<'a, _Example>
87/// lifted.with(|inner| println!("a = {}", inner.a));
88/// ```
89#[proc_macro_derive(Lift)]
90pub fn derive_lift(input: TokenStream) -> TokenStream {
91    let input = parse_macro_input!(input as DeriveInput);
92    let name = &input.ident;
93    let generics = &input.generics;
94    let vis = &input.vis;
95
96    // Get the fields
97    let fields = match &input.data {
98        syn::Data::Struct(data) => match &data.fields {
99            Fields::Named(fields) => &fields.named,
100            _ => panic!("Lift only supports structs with named fields"),
101        },
102        _ => panic!("Lift can only be derived for structs"),
103    };
104
105    // Separate signal fields from regular fields by checking the type
106    let mut signal_fields = Vec::new();
107    let mut regular_fields = Vec::new();
108
109    for field in fields {
110        if extract_signal_inner_type(&field.ty).is_some() {
111            signal_fields.push(field);
112        } else {
113            regular_fields.push(field);
114        }
115    }
116
117    // Generate the inner struct name (prefixed with underscore)
118    let inner_name = format_ident!("_{}", name);
119
120    // Generate fields for the inner struct (unwrapped types)
121    let inner_struct_fields = fields.iter().map(|field| {
122        let field_name = &field.ident;
123        let field_vis = &field.vis;
124
125        // If it's a Signal<'a, T>, use T; otherwise use the original type
126        let field_ty = if let Some(inner_ty) = extract_signal_inner_type(&field.ty) {
127            inner_ty
128        } else {
129            &field.ty
130        };
131
132        quote! {
133            #field_vis #field_name: #field_ty
134        }
135    });
136
137    // Generate the reactive setup code for signal fields
138    let reactive_setup = signal_fields.iter().map(|field| {
139        let field_name = &field.ident;
140
141        quote! {
142            {
143                let result_signal_weak = std::rc::Rc::downgrade(&result_signal.0);
144                let source_for_closure = std::rc::Rc::downgrade(&instance.#field_name.0);
145                let react_fn = Box::new(move || {
146                    if let Some(result_sig) = result_signal_weak.upgrade() {
147                        if !*result_sig.explicitly_modified.borrow() {
148                            if let Some(source) = source_for_closure.upgrade() {
149                                result_sig.value.borrow_mut().#field_name = source.value.borrow().clone();
150                            }
151                        }
152                    }
153                });
154                let cloned_signal = instance.#field_name.clone();
155                cloned_signal.0.react_fns.borrow_mut().push(react_fn);
156                cloned_signal.0.successors.borrow_mut().push(crate::signal::WeakSignalRef::new(&result_signal));
157            }
158        }
159    });
160
161    // Generate the inner struct initialization from main struct
162    let inner_from_main = signal_fields.iter().map(|field| {
163        let field_name = &field.ident;
164        quote! {
165            #field_name: instance.#field_name.0.value.borrow().clone()
166        }
167    });
168
169    let regular_from_main = regular_fields.iter().map(|field| {
170        let field_name = &field.ident;
171        quote! {
172            #field_name: instance.#field_name.clone()
173        }
174    });
175
176    // Generate Clone trait bounds for signal fields (using the unwrapped inner type)
177    let signal_clone_bounds = signal_fields.iter().filter_map(|field| {
178        extract_signal_inner_type(&field.ty).map(|inner_ty| {
179            quote! { #inner_ty: Clone }
180        })
181    });
182
183    // Generate Clone trait bounds for regular fields
184    let regular_clone_bounds = regular_fields.iter().map(|field| {
185        let field_ty = &field.ty;
186        quote! { #field_ty: Clone }
187    });
188
189    // Extract generics for impl block
190    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
191
192    // Create a version of generics without lifetimes for the inner struct
193    let type_params = generics.type_params().map(|tp| &tp.ident);
194    let inner_ty_generics = if generics.type_params().count() > 0 {
195        quote! { <#(#type_params),*> }
196    } else {
197        quote! {}
198    };
199
200    let expanded = quote! {
201        // Inner struct (unwrapped types)
202        #[derive(Clone)]
203        #vis struct #inner_name #inner_ty_generics {
204            #(#inner_struct_fields),*
205        }
206
207        impl #impl_generics #name #ty_generics #where_clause {
208            pub fn lift(self) -> crate::signal::Signal<'a, #inner_name #inner_ty_generics>
209            where
210                #(#signal_clone_bounds,)*
211                #(#regular_clone_bounds,)*
212            {
213                let instance = self;
214                let initial_inner = #inner_name {
215                    #(#inner_from_main,)*
216                    #(#regular_from_main),*
217                };
218
219                let result_signal = crate::signal::Signal::new(initial_inner);
220
221                #(#reactive_setup)*
222
223                result_signal
224            }
225        }
226    };
227
228    TokenStream::from(expanded)
229}
230
231/// Derive macro to lift a struct into a thread-safe reactive `SignalSync`.
232///
233/// Applying `#[derive(LiftSync)]` to a struct generates an inner unwrapped struct and
234/// a `lift(self)` method that produces a `crate::signal_sync::SignalSync<'a, _Inner>` where
235/// any fields of type `SignalSync<'a, T>` are replaced by their inner `T` in the
236/// generated inner struct. The generated `lift` method wires up thread-safe reactions
237/// so that changes to any signal fields propagate into the resulting lifted `SignalSync`.
238///
239/// Example:
240/// ```rust
241/// use crate::signal_sync::SignalSync;
242///
243/// #[derive(LiftSync)]
244/// struct ExampleSync<'a> {
245///     a: SignalSync<'a, i32>,
246///     b: String,
247/// }
248///
249/// let example = ExampleSync { a: SignalSync::new(1), b: "hi".to_string() };
250/// let lifted = example.lift(); // SignalSync<'a, _ExampleSync>
251/// lifted.with(|inner| println!("a = {}", inner.a));
252/// ```
253#[proc_macro_derive(LiftSync)]
254pub fn derive_lift_sync(input: TokenStream) -> TokenStream {
255    let input = parse_macro_input!(input as DeriveInput);
256    let name = &input.ident;
257    let generics = &input.generics;
258    let vis = &input.vis;
259
260    // Get the fields
261    let fields = match &input.data {
262        syn::Data::Struct(data) => match &data.fields {
263            Fields::Named(fields) => &fields.named,
264            _ => panic!("LiftSync only supports structs with named fields"),
265        },
266        _ => panic!("LiftSync can only be derived for structs"),
267    };
268
269    // Separate signal fields from regular fields by checking the type
270    let mut signal_fields = Vec::new();
271    let mut regular_fields = Vec::new();
272
273    for field in fields {
274        if extract_signal_sync_inner_type(&field.ty).is_some() {
275            signal_fields.push(field);
276        } else {
277            regular_fields.push(field);
278        }
279    }
280
281    // Generate the inner struct name (prefixed with underscore)
282    let inner_name = format_ident!("_{}", name);
283
284    // Generate fields for the inner struct (unwrapped types)
285    let inner_struct_fields = fields.iter().map(|field| {
286        let field_name = &field.ident;
287        let field_vis = &field.vis;
288
289        // If it's a SignalSync<'a, T>, use T; otherwise use the original type
290        let field_ty = if let Some(inner_ty) = extract_signal_sync_inner_type(&field.ty) {
291            inner_ty
292        } else {
293            &field.ty
294        };
295
296        quote! {
297            #field_vis #field_name: #field_ty
298        }
299    });
300
301    // Generate the reactive setup code for signal fields (thread-safe version)
302    let reactive_setup = signal_fields.iter().map(|field| {
303        let field_name = &field.ident;
304
305        quote! {
306            {
307                let result_signal_weak = std::sync::Arc::downgrade(&result_signal.0);
308                let source_for_closure = std::sync::Arc::downgrade(&instance.#field_name.0);
309                let react_fn = Box::new(move || {
310                    if let Some(result_sig) = result_signal_weak.upgrade() {
311                        if !result_sig.explicitly_modified.load(std::sync::atomic::Ordering::Acquire) {
312                            if let Some(source) = source_for_closure.upgrade() {
313                                result_sig.value.lock().unwrap().#field_name = source.value.lock().unwrap().clone();
314                            }
315                        }
316                    }
317                });
318                let cloned_signal = instance.#field_name.clone();
319                cloned_signal.0.react_fns.write().unwrap().push(react_fn);
320                cloned_signal.0.successors.write().unwrap().push(crate::signal_sync::WeakSignalRefSync::new(&result_signal));
321            }
322        }
323    });
324
325    // Generate the inner struct initialization from main struct
326    let inner_from_main = signal_fields.iter().map(|field| {
327        let field_name = &field.ident;
328        quote! {
329            #field_name: instance.#field_name.0.value.lock().unwrap().clone()
330        }
331    });
332
333    let regular_from_main = regular_fields.iter().map(|field| {
334        let field_name = &field.ident;
335        quote! {
336            #field_name: instance.#field_name.clone()
337        }
338    });
339
340    // Generate Clone + Send + Sync trait bounds for signal fields (using the unwrapped inner type)
341    let signal_clone_bounds = signal_fields.iter().filter_map(|field| {
342        extract_signal_sync_inner_type(&field.ty).map(|inner_ty| {
343            quote! { #inner_ty: Clone + Send + Sync }
344        })
345    });
346
347    // Generate Clone trait bounds for regular fields
348    let regular_clone_bounds = regular_fields.iter().map(|field| {
349        let field_ty = &field.ty;
350        quote! { #field_ty: Clone }
351    });
352
353    // Extract generics for impl block
354    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
355
356    // Create a version of generics without lifetimes for the inner struct
357    let type_params = generics.type_params().map(|tp| &tp.ident);
358    let inner_ty_generics = if generics.type_params().count() > 0 {
359        quote! { <#(#type_params),*> }
360    } else {
361        quote! {}
362    };
363
364    let expanded = quote! {
365        // Inner struct (unwrapped types)
366        #[derive(Clone)]
367        #vis struct #inner_name #inner_ty_generics {
368            #(#inner_struct_fields),*
369        }
370
371        impl #impl_generics #name #ty_generics #where_clause {
372            pub fn lift(self) -> crate::signal_sync::SignalSync<'a, #inner_name #inner_ty_generics>
373            where
374                #(#signal_clone_bounds,)*
375                #(#regular_clone_bounds,)*
376            {
377                let instance = self;
378                let initial_inner = #inner_name {
379                    #(#inner_from_main,)*
380                    #(#regular_from_main),*
381                };
382
383                let result_signal = crate::signal_sync::SignalSync::new(initial_inner);
384
385                #(#reactive_setup)*
386
387                result_signal
388            }
389        }
390    };
391
392    TokenStream::from(expanded)
393}