Skip to main content

genetic_rs_macros/
lib.rs

1extern crate proc_macro;
2
3use darling::util::PathList;
4use darling::FromAttributes;
5use darling::FromMeta;
6use proc_macro::TokenStream;
7use proc_macro2::TokenStream as TokenStream2;
8use quote::quote;
9use quote::quote_spanned;
10use quote::ToTokens;
11use syn::parse_quote;
12use syn::spanned::Spanned;
13use syn::{parse_macro_input, Data, DeriveInput, Fields};
14
/// How a derive macro hands context values to each field's trait method.
enum ContextKind {
    /// Neither `create_context` nor `with_context` was given: every field receives
    /// the very same `ctx` reference that was passed to the derived method.
    Shared,
    /// `create_context` or `with_context` was given: the context value has one
    /// member per struct field, and each field receives `&ctx.<field>` (or
    /// `&ctx.<index>` for tuple structs).
    ///
    /// With `create_context` the macro also emits that context type (see
    /// [`ContextInfo::ctx_def`]); with `with_context` it refers to an existing type.
    PerField,
}
24
25/// Result of resolving context for a struct derive.
26struct ContextInfo {
27    /// The context type token stream (the `type Context = ...` part).
28    ctx_type: TokenStream2,
29    /// Optional token stream defining a new context struct (only for `PerField`).
30    ctx_def: Option<TokenStream2>,
31    /// How to pass the context to each field's method.
32    kind: ContextKind,
33}
34
35#[derive(FromMeta)]
36struct ContextArgs {
37    with_context: Option<syn::Path>,
38    create_context: Option<CreateContext>,
39}
40
41#[derive(FromMeta)]
42struct CreateContext {
43    name: syn::Ident,
44    derive: Option<PathList>,
45}
46
47#[derive(FromAttributes)]
48#[darling(attributes(mitosis))]
49struct MitosisSettings {
50    use_randmut: Option<bool>,
51
52    // darling requires all possible fields to be listed to avoid unknown field errors
53    #[darling(rename = "create_context")]
54    _create_context: Option<CreateContext>,
55
56    #[darling(rename = "with_context")]
57    _with_context: Option<syn::Path>,
58}
59
60/// Resolves the context info from the attribute on an AST node.
61///
62/// `trait_name` is the trait whose `Context` associated type is used when
63/// generating a per-field context struct (i.e., for `create_context`).
64/// `attr_path` is the path of the attribute to look for (e.g. `randmut`).
65///
66/// If the attribute is absent, or if it does not contain `create_context` or
67/// `with_context`, returns `None` and the caller should fall back to inferring
68/// the context type from the first struct field.
69fn resolve_context(
70    ast: &DeriveInput,
71    trait_name: syn::Ident,
72    attr_path: syn::Path,
73    fallback_ctx: TokenStream2,
74) -> ContextInfo {
75    let name = &ast.ident;
76    let vis = ast.vis.to_token_stream();
77
78    let attr = ast.attrs.iter().find(|a| a.path() == &attr_path);
79
80    if let Some(attr) = attr {
81        // Try to parse the attribute as ContextArgs; silently ignore if it
82        // doesn't contain context-related fields (e.g., `use_randmut`).
83        if let Ok(args) = ContextArgs::from_meta(&attr.meta) {
84            if args.create_context.is_some() && args.with_context.is_some() {
85                panic!("cannot have both create_context and with_context");
86            }
87
88            if let Some(create_ctx) = args.create_context {
89                let ident = &create_ctx.name;
90                let doc = quote! {
91                    #[doc = concat!("Autogenerated context struct for [`", stringify!(#name), "`]")]
92                };
93                let derives = create_ctx.derive.map(|paths| {
94                    quote! { #[derive(#(#paths,)*)] }
95                });
96
97                let ctx_def = match &ast.data {
98                    Data::Struct(s) => {
99                        let fields: Vec<TokenStream2> = s
100                            .fields
101                            .iter()
102                            .map(|field| {
103                                let ty = &field.ty;
104                                let ty_span = ty.span();
105                                if let Some(field_name) = &field.ident {
106                                    quote_spanned! {ty_span=>
107                                        #vis #field_name: <#ty as genetic_rs_common::prelude::#trait_name>::Context,
108                                    }
109                                } else {
110                                    quote_spanned! {ty_span=>
111                                        #vis <#ty as genetic_rs_common::prelude::#trait_name>::Context,
112                                    }
113                                }
114                            })
115                            .collect();
116                        let fields_ts: TokenStream2 = fields.into_iter().collect();
117
118                        let is_tuple = matches!(s.fields, Fields::Unnamed(_));
119                        if is_tuple {
120                            quote! { #doc #derives #vis struct #ident (#fields_ts); }
121                        } else {
122                            quote! { #doc #derives #vis struct #ident { #fields_ts } }
123                        }
124                    }
125                    Data::Enum(_) => panic!("enums not supported"),
126                    Data::Union(_) => panic!("unions not supported"),
127                };
128
129                return ContextInfo {
130                    ctx_type: ident.to_token_stream(),
131                    ctx_def: Some(ctx_def),
132                    kind: ContextKind::PerField,
133                };
134            }
135
136            if let Some(with_ctx) = args.with_context {
137                return ContextInfo {
138                    ctx_type: with_ctx.to_token_stream(),
139                    ctx_def: None,
140                    kind: ContextKind::PerField,
141                };
142            }
143        }
144    }
145
146    ContextInfo {
147        ctx_type: fallback_ctx,
148        ctx_def: None,
149        kind: ContextKind::Shared,
150    }
151}
152
153#[proc_macro_derive(RandomlyMutable, attributes(randmut))]
154pub fn randmut_derive(input: TokenStream) -> TokenStream {
155    let ast = parse_macro_input!(input as DeriveInput);
156    let name = &ast.ident;
157
158    let Data::Struct(s) = &ast.data else {
159        panic!("enums and unions not yet supported");
160    };
161
162    // Determine the fallback context type from the first field (if any).
163    let fallback_ctx = s.fields.iter().next().map_or_else(
164        || quote! { () },
165        |f| {
166            let ty = &f.ty;
167            quote! { <#ty as genetic_rs_common::prelude::RandomlyMutable>::Context }
168        },
169    );
170
171    let ctx_info = resolve_context(
172        &ast,
173        parse_quote!(RandomlyMutable),
174        parse_quote!(randmut),
175        fallback_ctx,
176    );
177
178    let ctx_type = &ctx_info.ctx_type;
179    let ctx_def = &ctx_info.ctx_def;
180
181    let inner: TokenStream2 = s
182        .fields
183        .iter()
184        .enumerate()
185        .map(|(i, field)| {
186            let ty = &field.ty;
187            let span = ty.span();
188            let idx = syn::Index::from(i);
189            match (&field.ident, &ctx_info.kind) {
190                (Some(field_name), ContextKind::PerField) => quote_spanned! {span=>
191                    <#ty as genetic_rs_common::prelude::RandomlyMutable>::mutate(&mut self.#field_name, &ctx.#field_name, rate, rng);
192                },
193                (Some(field_name), _) => quote_spanned! {span=>
194                    <#ty as genetic_rs_common::prelude::RandomlyMutable>::mutate(&mut self.#field_name, ctx, rate, rng);
195                },
196                (None, ContextKind::PerField) => quote_spanned! {span=>
197                    <#ty as genetic_rs_common::prelude::RandomlyMutable>::mutate(&mut self.#idx, &ctx.#idx, rate, rng);
198                },
199                (None, _) => quote_spanned! {span=>
200                    <#ty as genetic_rs_common::prelude::RandomlyMutable>::mutate(&mut self.#idx, ctx, rate, rng);
201                },
202            }
203        })
204        .collect();
205
206    quote! {
207        #[automatically_derived]
208        impl genetic_rs_common::prelude::RandomlyMutable for #name {
209            type Context = #ctx_type;
210
211            fn mutate(&mut self, ctx: &Self::Context, rate: f32, rng: &mut impl rand::Rng) {
212                #inner
213            }
214        }
215
216        #ctx_def
217    }
218    .into()
219}
220
221#[proc_macro_derive(Mitosis, attributes(mitosis))]
222pub fn mitosis_derive(input: TokenStream) -> TokenStream {
223    let ast = parse_macro_input!(input as DeriveInput);
224    let name = &ast.ident;
225
226    let mitosis_settings = MitosisSettings::from_attributes(&ast.attrs).unwrap();
227    if mitosis_settings.use_randmut.unwrap_or(false) {
228        return quote! {
229            #[automatically_derived]
230            impl genetic_rs_common::prelude::Mitosis for #name {
231                type Context = <Self as genetic_rs_common::prelude::RandomlyMutable>::Context;
232
233                fn divide(&self, ctx: &Self::Context, rate: f32, rng: &mut impl rand::Rng) -> Self {
234                    let mut child = self.clone();
235                    <Self as genetic_rs_common::prelude::RandomlyMutable>::mutate(&mut child, ctx, rate, rng);
236                    child
237                }
238            }
239        }
240        .into();
241    }
242
243    let Data::Struct(s) = &ast.data else {
244        panic!("enums and unions not yet supported");
245    };
246
247    let fallback_ctx = s.fields.iter().next().map_or_else(
248        || quote! { () },
249        |f| {
250            let ty = &f.ty;
251            quote! { <#ty as genetic_rs_common::prelude::Mitosis>::Context }
252        },
253    );
254
255    let ctx_info = resolve_context(
256        &ast,
257        parse_quote!(Mitosis),
258        parse_quote!(mitosis),
259        fallback_ctx,
260    );
261
262    let ctx_type = &ctx_info.ctx_type;
263    let ctx_def = &ctx_info.ctx_def;
264
265    let is_tuple_struct = matches!(s.fields, Fields::Unnamed(_));
266
267    let inner: TokenStream2 = s
268        .fields
269        .iter()
270        .enumerate()
271        .map(|(i, field)| {
272            let ty = &field.ty;
273            let span = ty.span();
274            let idx = syn::Index::from(i);
275            match (&field.ident, &ctx_info.kind) {
276                (Some(field_name), ContextKind::PerField) => quote_spanned! {span=>
277                    #field_name: <#ty as genetic_rs_common::prelude::Mitosis>::divide(&self.#field_name, &ctx.#field_name, rate, rng),
278                },
279                (Some(field_name), _) => quote_spanned! {span=>
280                    #field_name: <#ty as genetic_rs_common::prelude::Mitosis>::divide(&self.#field_name, ctx, rate, rng),
281                },
282                (None, ContextKind::PerField) => quote_spanned! {span=>
283                    <#ty as genetic_rs_common::prelude::Mitosis>::divide(&self.#idx, &ctx.#idx, rate, rng),
284                },
285                (None, _) => quote_spanned! {span=>
286                    <#ty as genetic_rs_common::prelude::Mitosis>::divide(&self.#idx, ctx, rate, rng),
287                },
288            }
289        })
290        .collect();
291
292    let child = if is_tuple_struct {
293        quote! { Self(#inner) }
294    } else {
295        quote! { Self { #inner } }
296    };
297
298    quote! {
299        #[automatically_derived]
300        impl genetic_rs_common::prelude::Mitosis for #name {
301            type Context = #ctx_type;
302
303            fn divide(&self, ctx: &Self::Context, rate: f32, rng: &mut impl rand::Rng) -> Self {
304                #child
305            }
306        }
307
308        #ctx_def
309    }
310    .into()
311}
312
/// Derives `genetic_rs_common::prelude::Crossover` for a struct.
///
/// The generated `crossover` builds the child field by field, calling
/// `Crossover::crossover` on each pair of corresponding fields of `self` and `other`.
///
/// By default the first field's `Context` (or `()` with no fields) is shared by all
/// fields. `#[crossover(create_context(...))]` or `#[crossover(with_context = ...)]`
/// selects a per-field context instead.
///
/// Generic structs are supported. Bounds the field types need must be declared on
/// the struct itself. This macro is only available with the `crossover` feature.
///
/// # Panics
///
/// Panics on enums and unions, and when the context attribute is rejected (e.g.
/// both `create_context` and `with_context`). The compiler reports these panics as
/// errors.
#[cfg(feature = "crossover")]
#[proc_macro_derive(Crossover, attributes(crossover))]
pub fn crossover_derive(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as DeriveInput);
    let name = &ast.ident;
    // Carry the struct's generic parameters and where-clause over to the impl.
    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();

    let Data::Struct(s) = &ast.data else {
        panic!("enums and unions not yet supported");
    };

    // Without a context attribute, every field is handed the first field's context.
    let fallback_ctx = s.fields.iter().next().map_or_else(
        || quote! { () },
        |f| {
            let ty = &f.ty;
            quote! { <#ty as genetic_rs_common::prelude::Crossover>::Context }
        },
    );

    let ctx_info = resolve_context(
        &ast,
        parse_quote!(Crossover),
        parse_quote!(crossover),
        fallback_ctx,
    );

    let ctx_type = &ctx_info.ctx_type;
    let ctx_def = &ctx_info.ctx_def;

    let is_tuple_struct = matches!(s.fields, Fields::Unnamed(_));

    let inner: TokenStream2 = s
        .fields
        .iter()
        .enumerate()
        .map(|(i, field)| {
            let ty = &field.ty;
            let span = ty.span();
            // Named fields are accessed by name, tuple fields by position.
            let member = match &field.ident {
                Some(field_name) => syn::Member::Named(field_name.clone()),
                None => syn::Member::Unnamed(syn::Index::from(i)),
            };
            let field_ctx = match &ctx_info.kind {
                ContextKind::PerField => quote_spanned! {span=> &ctx.#member },
                ContextKind::Shared => quote_spanned! {span=> ctx },
            };
            let value = quote_spanned! {span=>
                <#ty as genetic_rs_common::prelude::Crossover>::crossover(&self.#member, &other.#member, #field_ctx, rate, rng)
            };
            match &field.ident {
                Some(field_name) => quote_spanned! {span=> #field_name: #value, },
                None => quote_spanned! {span=> #value, },
            }
        })
        .collect();

    let child = if is_tuple_struct {
        quote! { Self(#inner) }
    } else {
        quote! { Self { #inner } }
    };

    quote! {
        #[automatically_derived]
        impl #impl_generics genetic_rs_common::prelude::Crossover for #name #ty_generics #where_clause {
            type Context = #ctx_type;

            fn crossover(&self, other: &Self, ctx: &Self::Context, rate: f32, rng: &mut impl rand::Rng) -> Self {
                #child
            }
        }

        #ctx_def
    }
    .into()
}
388
/// Derives `genetic_rs_common::prelude::GenerateRandom` for a struct.
///
/// The generated `gen_random` builds the value field by field, calling
/// `GenerateRandom::gen_random(rng)` for every field in declaration order.
///
/// Generic structs are supported. Bounds the field types need (e.g.
/// `T: GenerateRandom`) must be declared on the struct itself. This macro is only
/// available with the `genrand` feature.
///
/// # Panics
///
/// Panics on enums and unions. The compiler reports this panic as an error.
#[cfg(feature = "genrand")]
#[proc_macro_derive(GenerateRandom)]
pub fn genrand_derive(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as DeriveInput);
    let name = &ast.ident;
    // Carry the struct's generic parameters and where-clause over to the impl.
    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();

    let Data::Struct(s) = &ast.data else {
        panic!("enums and unions not yet supported");
    };

    let is_tuple_struct = matches!(s.fields, Fields::Unnamed(_));

    let inner: TokenStream2 = s
        .fields
        .iter()
        .map(|field| {
            let ty = &field.ty;
            let span = ty.span();
            let value = quote_spanned! {span=>
                <#ty as genetic_rs_common::prelude::GenerateRandom>::gen_random(rng)
            };
            match &field.ident {
                Some(field_name) => quote_spanned! {span=> #field_name: #value, },
                None => quote_spanned! {span=> #value, },
            }
        })
        .collect();

    let body = if is_tuple_struct {
        quote! { Self(#inner) }
    } else {
        quote! { Self { #inner } }
    };

    quote! {
        #[automatically_derived]
        impl #impl_generics genetic_rs_common::prelude::GenerateRandom for #name #ty_generics #where_clause {
            fn gen_random(rng: &mut impl rand::Rng) -> Self {
                #body
            }
        }
    }
    .into()
}