Skip to main content

scheme_rs_macros/
lib.rs

1use proc_macro::{self, TokenStream};
2use proc_macro2::{Literal, Span};
3use quote::{format_ident, quote};
4use syn::{
5    Attribute, DataEnum, DataStruct, DeriveInput, Error, Expr, ExprClosure, Fields, FnArg,
6    GenericParam, Generics, Ident, ItemFn, LitBool, LitStr, Member, Meta, Pat, PatIdent, PatType,
7    Result, Token, Type, TypePath, TypeReference, Visibility, braced, bracketed, parenthesized,
8    parse::{Parse, ParseStream},
9    parse_macro_input, parse_quote,
10    punctuated::Punctuated,
11    spanned::Spanned,
12};
13
14/// The `bridge` proc macro allows one to register Scheme procedures written in
15/// Rust.
16///
17/// Rust functions registered with `bridge` must have the following form:
18/// `async? fn(arg: &T, ... (rest_args: &[Value])?) -> Result<Vec<Value>, Exception>`
19///
20/// The types of the arguments can be any `T` for which `T: TryFrom<&Value>`, or
21/// they can be a `&Value`. Scheme-rs will throw an excpetion if 
22///
23/// Bridge functions can be async if the `async` feature flag is enabled.
24///
25/// The `bridge` proc macro takes two arguments: `def` which specifies the
26/// scheme procedure name and `lib` which specifies the library to register the
27/// procedure to. At time of variable resolution, if the library has no scheme
28/// code associated with it, all bridge functions registered to that library
29/// will be assumed to be public. More control can be given by associating
30/// scheme code with the library, in that case bridge functions will need to be
31/// made public by putting them in the `export` spec.
32#[proc_macro_attribute]
33pub fn bridge(args: TokenStream, item: TokenStream) -> TokenStream {
34    let mut name: Option<LitStr> = None;
35    let mut lib: Option<LitStr> = None;
36    let bridge_attr_parser = syn::meta::parser(|meta| {
37        if meta.path.is_ident("name") {
38            name = Some(meta.value()?.parse()?);
39            Ok(())
40        } else if meta.path.is_ident("lib") {
41            lib = Some(meta.value()?.parse()?);
42            Ok(())
43        } else {
44            Err(meta.error("unsupported bridge property"))
45        }
46    });
47
48    parse_macro_input!(args with bridge_attr_parser);
49
50    let name = name.unwrap().value();
51    let lib = lib.unwrap().value();
52    let bridge = parse_macro_input!(item as ItemFn);
53
54    let impl_name = bridge.sig.ident.clone();
55    let wrapper_name = impl_name.to_string();
56    let wrapper_name = Ident::new(&wrapper_name, Span::call_site());
57
58    let (rest_args, is_variadic) = if let Some(last_arg) = bridge.sig.inputs.last()
59        && is_slice(&last_arg)
60    {
61        (quote!(rest_args), true)
62    } else {
63        (quote!(), false)
64    };
65
66    let num_args = if is_variadic {
67        bridge.sig.inputs.len().saturating_sub(1)
68    } else {
69        bridge.sig.inputs.len()
70    };
71
72    let arg_names: Vec<_> = bridge
73        .sig
74        .inputs
75        .iter()
76        .enumerate()
77        .map(|(i, arg)| {
78            if let FnArg::Typed(PatType { pat, .. }) = arg {
79                if let Pat::Ident(PatIdent { ident, .. }) = pat.as_ref() {
80                    return ident.to_string();
81                }
82            }
83            format!("arg{i}")
84        })
85        .collect();
86
87    let arg_indices: Vec<_> = (0..num_args).collect();
88
89    let visibility = bridge.vis.clone();
90
91    if bridge.sig.asyncness.is_some() {
92        quote! {
93            #visibility fn #wrapper_name<'a>(
94                runtime: &'a ::scheme_rs::runtime::Runtime,
95                _env: &'a [::scheme_rs::value::Value],
96                args: &'a [::scheme_rs::value::Value],
97                rest_args: &'a [::scheme_rs::value::Value],
98                dyn_state: &'a mut ::scheme_rs::proc::DynamicState,
99                k: ::scheme_rs::value::Value,
100            ) -> futures::future::BoxFuture<'a, scheme_rs::proc::Application> {
101                #bridge
102
103                Box::pin(
104                    async move {
105                        let result = #impl_name(
106                            #(
107                                match (&args[#arg_indices]).try_into() {
108                                    Ok(ok) => ok,
109                                    Err(err) => {
110                                        return ::scheme_rs::exceptions::raise(
111                                            runtime.clone(),
112                                            err.into(),
113                                            dyn_state,
114                                        )
115                                    }
116                                },
117                            )*
118                            #rest_args
119                        ).await;
120                        // If the function returned an error, we want to raise
121                        // it.
122                        let result = match result {
123                            Err(err) => return ::scheme_rs::exceptions::raise(
124                                runtime.clone(),
125                                err.into(),
126                                dyn_state,
127                            ),
128                            Ok(result) => result,
129                        };
130                        let k = unsafe { k.try_into().unwrap_unchecked() };
131                        ::scheme_rs::proc::Application::new(k, result)
132                    }
133                )
134            }
135
136            inventory::submit! {
137                ::scheme_rs::registry::BridgeFn::new(
138                    #name,
139                    #lib,
140                    #num_args,
141                    #is_variadic,
142                    ::scheme_rs::registry::Bridge::Async(#wrapper_name),
143                    ::scheme_rs::registry::BridgeFnDebugInfo::new(
144                        ::std::file!(),
145                        ::std::line!(),
146                        ::std::column!(),
147                        0,
148                        &[ #( #arg_names, )* ],
149                    )
150                )
151            }
152        }
153    } else {
154        quote! {
155            #visibility fn #wrapper_name(
156                runtime: &::scheme_rs::runtime::Runtime,
157                _env: &[::scheme_rs::value::Value],
158                args: &[::scheme_rs::value::Value],
159                rest_args: &[::scheme_rs::value::Value],
160                dyn_state: &mut ::scheme_rs::proc::DynamicState,
161                k: ::scheme_rs::value::Value,
162            ) -> scheme_rs::proc::Application {
163                #bridge
164
165                let result = #impl_name(
166                    #(
167                        match (&args[#arg_indices]).try_into() {
168                            Ok(ok) => ok,
169                            Err(err) => {
170                                return ::scheme_rs::exceptions::raise(
171                                    runtime.clone(),
172                                    err.into(),
173                                    dyn_state,
174                                )
175                            }
176                        },
177                    )*
178                    #rest_args
179                );
180
181                // If the function returned an error, we want to raise
182                // it.
183                let result = match result {
184                    Err(err) => return ::scheme_rs::exceptions::raise(
185                        runtime.clone(),
186                        err.into(),
187                        dyn_state,
188                    ),
189                    Ok(result) => result,
190                };
191
192                let k = unsafe { k.try_into().unwrap_unchecked() };
193                ::scheme_rs::proc::Application::new(k, result)
194            }
195
196            inventory::submit! {
197                ::scheme_rs::registry::BridgeFn::new(
198                    #name,
199                    #lib,
200                    #num_args,
201                    #is_variadic,
202                    ::scheme_rs::registry::Bridge::Sync(#wrapper_name),
203                    ::scheme_rs::registry::BridgeFnDebugInfo::new(
204                        ::std::file!(),
205                        ::std::line!(),
206                        ::std::column!(),
207                        0,
208                        &[ #( #arg_names, )* ],
209                    )
210                )
211            }
212        }
213    }
214    .into()
215}
216
217/// The `cps_bridge` proc macro allows one to register Scheme procedureds written
218/// in Rust in a
219/// [continuation-passing style](https://en.wikipedia.org/wiki/Continuation-passing_style).
220///
221/// The main benefit of this is to allow for Rust functions that call scheme
222/// procedures in a tail-context. Essentially every scheme function, including
223/// those that are not possible to express in base scheme, are expressible in
224/// Rust because of this.
225///
226/// Functions registered with `cps_bridge` must take the following arguments:
227///  - `runtime: &Runtime`: The runtime to which the procedure is registered.
228///  - `env: &[Value]`: Environmental variables supplied to the procedure via
229///    `Procedure::new`.
230///  - `args: &[Value]`: The arguments to the procedure.
231///  - `rest_args: &[Value]`: Any variadic arguments provided to the procedure.
232///  - `dyn_state: &mut DynamicState`: The dynamic state of the program.
233///  - `k: Value`: The current continuation.
234///
235/// The `cps_bridge` proc macro takes two arguments: `def` which specifies the
236/// scheme procedure name and arguments and `lib` which specifies the library
237/// to register the procedure to.
238///
239/// `cps_bridge` functions can be async if the `async` feature flag is enabled.
240///
241/// **Note:** `cps_bridge` functions _must_ be public to be registered to a
242/// library! If a `cbs_bridge` function is declared as private, it does not take
243/// a `def` or `lib` argument.
244///
245/// # Example: Scheme `apply` procedure written in Rust:
246///
247/// ```rust,ignore
248/// #[cps_bridge(def = "apply arg1 . args", lib = "(rnrs base builtins (6))")]
249/// pub fn apply(
250///     _runtime: &Runtime,
251///     _env: &[Value],
252///     args: &[Value],
253///     rest_args: &[Value],
254///     _dyn_state: &mut DynamicState,
255///     k: Value,
256/// ) -> Result<Application, Exception> {
257///     if rest_args.is_empty() {
258///         return Err(Exception::wrong_num_of_args(2, args.len()));
259///     }
260///     let op: Procedure = args[0].clone().try_into()?;
261///     let (last, args) = rest_args.split_last().unwrap();
262///     let mut args = args.to_vec();
263///     list_to_vec(last, &mut args);
264///     args.push(k);
265///     Ok(Application::new(op.clone(), args))
266/// }
267/// ```
268
269#[proc_macro_attribute]
270pub fn cps_bridge(args: TokenStream, item: TokenStream) -> TokenStream {
271    let mut def: Option<LitStr> = None;
272    let mut lib: Option<LitStr> = None;
273    let bridge_attr_parser = syn::meta::parser(|meta| {
274        if meta.path.is_ident("def") {
275            def = Some(meta.value()?.parse()?);
276            Ok(())
277        } else if meta.path.is_ident("lib") {
278            lib = Some(meta.value()?.parse()?);
279            Ok(())
280        } else {
281            Err(meta.error("unsupported bridge property"))
282        }
283    });
284
285    parse_macro_input!(args with bridge_attr_parser);
286
287    let mut bridge = parse_macro_input!(item as ItemFn);
288    let wrapper_name = Ident::new(&bridge.sig.ident.to_string(), Span::call_site());
289    bridge.sig.ident = Ident::new("inner", Span::call_site());
290    let impl_name = bridge.sig.ident.clone();
291
292    let (vis, inventory) = if matches!(bridge.vis, Visibility::Public(_)) {
293        let vis = std::mem::replace(&mut bridge.vis, Visibility::Inherited);
294        let lib = lib.unwrap().value();
295        let def = def.unwrap().value();
296        let mut is_variadic = false;
297
298        let mut def = def
299            .split(" ")
300            .filter_map(|x| {
301                if x.is_empty() {
302                    None
303                } else if x == "." {
304                    is_variadic = true;
305                    None
306                } else {
307                    Some(x)
308                }
309            })
310            .collect::<Vec<_>>();
311
312        let name = def.remove(0);
313
314        let num_args = def.len() - is_variadic as usize;
315
316        let bridge_ptr = if bridge.sig.asyncness.is_some() {
317            quote!(::scheme_rs::registry::Bridge::Async(#wrapper_name))
318        } else {
319            quote!(::scheme_rs::registry::Bridge::Sync(#wrapper_name))
320        };
321
322        let inventory = quote! {
323            inventory::submit! {
324                ::scheme_rs::registry::BridgeFn::new(
325                    #name,
326                    #lib,
327                    #num_args,
328                    #is_variadic,
329                    #bridge_ptr,
330                    ::scheme_rs::registry::BridgeFnDebugInfo::new(
331                        ::std::file!(),
332                        ::std::line!(),
333                        ::std::column!(),
334                        0,
335                        &[ #( #def, )* ],
336                    )
337                )
338            }
339        };
340        (vis, inventory)
341    } else {
342        if let Some(def) = def {
343            return Error::new(
344                def.span(),
345                "name attribute is not supported for private functions",
346            )
347            .into_compile_error()
348            .into();
349        }
350        if let Some(lib) = lib {
351            return Error::new(
352                lib.span(),
353                "lib attribute is not supported for private functions",
354            )
355            .into_compile_error()
356            .into();
357        }
358        let vis = std::mem::replace(&mut bridge.vis, Visibility::Inherited);
359        (vis, quote!())
360    };
361
362    if bridge.sig.asyncness.is_some() {
363        quote! {
364            #vis fn #wrapper_name<'a>(
365                runtime: &'a ::scheme_rs::runtime::Runtime,
366                env: &'a [::scheme_rs::value::Value],
367                args: &'a [::scheme_rs::value::Value],
368                rest_args: &'a [::scheme_rs::value::Value],
369                dyn_state: &'a mut ::scheme_rs::proc::DynamicState,
370                k: ::scheme_rs::value::Value,
371            ) -> futures::future::BoxFuture<'a, scheme_rs::proc::Application> {
372                #bridge
373
374                Box::pin(async move {
375                    match #impl_name(
376                        runtime,
377                        env,
378                        args,
379                        rest_args,
380                        dyn_state,
381                        k
382                    ).await {
383                        Ok(app) => app,
384                        Err(err) => ::scheme_rs::exceptions::raise(
385                            runtime.clone(),
386                            err.into(),
387                            dyn_state
388                        ),
389                    }
390                })
391            }
392
393            #inventory
394        }
395    } else {
396        quote! {
397            #vis fn #wrapper_name(
398                runtime: &::scheme_rs::runtime::Runtime,
399                env: &[::scheme_rs::value::Value],
400                args: &[::scheme_rs::value::Value],
401                rest_args: &[::scheme_rs::value::Value],
402                dyn_state: &mut ::scheme_rs::proc::DynamicState,
403                k: ::scheme_rs::value::Value,
404            ) -> scheme_rs::proc::Application {
405                #bridge
406
407                match #impl_name(
408                    runtime,
409                    env,
410                    args,
411                    rest_args,
412                    dyn_state,
413                    k
414                ) {
415                    Ok(app) => app,
416                    Err(err) => ::scheme_rs::exceptions::raise(
417                        runtime.clone(),
418                        err.into(),
419                        dyn_state
420                    )
421                }
422            }
423
424            #inventory
425        }
426    }
427    .into()
428}
429
430fn is_slice(arg: &FnArg) -> bool {
431    matches!(arg, FnArg::Typed(PatType { ty, ..}) if matches!(ty.as_ref(), Type::Reference(TypeReference { elem, .. }) if matches!(elem.as_ref(), Type::Slice(_))))
432}
433
434/// Derive the `Trace` trait for a type.
435///
436/// `Trace` assumes that all fields of the type implement `Trace` or are a `Gc`
437/// type. Occasionally you want to skip the tracing of a field, perhaps because
438/// the type cannot implement `Trace`. The `#[trace(skip)]` attribute specifies
439/// that the collector should ignore that field when tracing the type.
440///
441/// Skipping a field is always safe, but can cause memory leaks if the field
442/// being skipped contains a `Gc`.
443///
444/// `Trace` will also automatically add `Trace` bounds to generic parameters.
445/// To avoid this behavior, use the `#[trace(skip_bounds)]` attribute.
446///
447/// ```rust,ignore
448/// #[derive(Trace)]
449/// #[trace(skip_bounds)]
450/// struct CustomType<T> {
451///     #[trace(skip)]
452///     inner: TypeYouDoNotOwn<T>
453/// }
454/// ```
455#[proc_macro_derive(Trace, attributes(trace))]
456pub fn derive_trace(input: TokenStream) -> TokenStream {
457    let DeriveInput {
458        attrs,
459        ident,
460        data,
461        generics,
462        ..
463    } = parse_macro_input!(input);
464
465    let tokens = match data {
466        syn::Data::Struct(data_struct) => derive_trace_struct(&attrs, ident, data_struct, generics),
467        syn::Data::Enum(data_enum) => derive_trace_enum(&attrs, ident, data_enum, generics),
468        syn::Data::Union(union) => Err(Error::new(
469            union.union_token.span(),
470            "unions are not supported by Trace",
471        )),
472    };
473
474    tokens.unwrap_or_else(syn::Error::into_compile_error).into()
475}
476
477fn derive_trace_struct(
478    attrs: &[Attribute],
479    name: Ident,
480    record: DataStruct,
481    generics: Generics,
482) -> syn::Result<proc_macro2::TokenStream> {
483    let fields = match record.fields {
484        Fields::Named(fields) => fields.named,
485        Fields::Unnamed(fields) => fields.unnamed,
486        _ => {
487            return Ok(quote! {
488                unsafe impl ::scheme_rs::gc::Trace for #name {
489                    unsafe fn visit_children(&self, visitor: &mut dyn FnMut(::scheme_rs::gc::OpaqueGcPtr)) {}
490
491                    unsafe fn finalize(&mut self) {
492                        unsafe {
493                            ::std::ptr::drop_in_place(self as *mut Self)
494                        }
495                    }
496                }
497            });
498        }
499    };
500
501    let Generics {
502        mut params,
503        where_clause,
504        ..
505    } = generics;
506
507    let mut unbound_params = Punctuated::<GenericParam, Token![,]>::new();
508
509    // TODO: Factor this out
510    if !skip_bounds(attrs)? {
511        for param in params.iter_mut() {
512            match param {
513                GenericParam::Type(ty) => {
514                    ty.bounds.push(syn::TypeParamBound::Verbatim(
515                        quote! { ::scheme_rs::gc::Trace },
516                    ));
517                    unbound_params.push(GenericParam::Type(syn::TypeParam::from(ty.ident.clone())));
518                }
519                param => unbound_params.push(param.clone()),
520            }
521        }
522    }
523
524    let (field_visits, field_drops): (Vec<_>, Vec<_>) = fields
525        .iter()
526        .enumerate()
527        .map(|(i, f)| {
528            let ident = f.ident.clone().map_or_else(
529                || {
530                    Member::Unnamed(syn::Index {
531                        index: i as u32,
532                        span: Span::call_site(),
533                    })
534                },
535                Member::Named,
536            );
537            let ty = &f.ty;
538            let is_gc = is_gc(ty);
539            let skip_field = skip_field(&f.attrs)?;
540            let visit = if skip_field {
541                quote! {}
542            } else if is_gc {
543                quote! {
544                    visitor(self.#ident.as_opaque());
545                }
546            } else {
547                quote! {
548                    <#ty as ::scheme_rs::gc::Trace>::visit_children(&self. #ident, visitor);
549                }
550            };
551            let finalize = if skip_field {
552                quote! {
553                    core::ptr::drop_in_place(&mut self. #ident as *mut #ty);
554                }
555            } else if is_gc {
556                quote! {}
557            } else {
558                quote! { <#ty as ::scheme_rs::gc::Trace>::finalize(&mut self. #ident); }
559            };
560            Ok((visit, finalize))
561        })
562        .collect::<syn::Result<Vec<_>>>()?
563        .into_iter()
564        .unzip();
565
566    Ok(quote! {
567        #[automatically_derived]
568        unsafe impl<#params> ::scheme_rs::gc::Trace for #name <#unbound_params>
569        #where_clause
570        {
571            unsafe fn visit_children(&self, visitor: &mut dyn FnMut(::scheme_rs::gc::OpaqueGcPtr)) {
572                #(
573                    #field_visits
574                )*
575            }
576
577            unsafe fn finalize(&mut self) {
578                #(
579                    #field_drops
580                )*
581            }
582        }
583    })
584}
585
586fn derive_trace_enum(
587    attrs: &[Attribute],
588    name: Ident,
589    data_enum: DataEnum,
590    generics: Generics,
591) -> syn::Result<proc_macro2::TokenStream> {
592    let (visit_match_clauses, finalize_match_clauses): (Vec<_>, Vec<_>) = data_enum
593        .variants
594        .into_iter()
595        .flat_map(|variant| {
596            let fields: Vec<_> = match variant.fields {
597                Fields::Named(ref named) => named
598                    .named
599                    .iter()
600                    .map(|field| {
601                        (
602                            field.attrs.clone(),
603                            field.ty.clone(),
604                            field.ident.as_ref().unwrap().clone(),
605                        )
606                    })
607                    .collect(),
608                Fields::Unnamed(ref unnamed) => unnamed
609                    .unnamed
610                    .iter()
611                    .enumerate()
612                    .map(|(i, field)| {
613                        let ident = Ident::new(&format!("t{i}"), Span::call_site());
614                        (field.attrs.clone(), field.ty.clone(), ident)
615                    })
616                    .collect(),
617                _ => return None,
618            };
619            Some((variant, fields))
620        })
621        .map(|(variant, fields)| {
622            let visits = fields
623                .iter()
624                .map(|(attrs, ty, accessor)| {
625                    let skip_field = skip_field(&attrs)?;
626
627                    let visit = if skip_field {
628                        quote! {
629                            let _ = #accessor;
630                        }
631                    } else if is_gc(ty) {
632                        quote! {
633                            visitor(#accessor.as_opaque());
634                        }
635                    } else {
636                        quote! {
637                            <#ty as ::scheme_rs::gc::Trace>::visit_children(#accessor, visitor);
638                        }
639                    };
640                    Ok(visit)
641                })
642                .collect::<syn::Result<Vec<_>>>()?;
643            let drops: Vec<_> = fields
644                .iter()
645                .map(|(attrs, ty, accessor)| {
646                    let skip_field = skip_field(&attrs).unwrap();
647
648                    if skip_field {
649                        quote! {
650                            core::ptr::drop_in_place(#accessor as *mut #ty);
651                        }
652                    } else if is_gc(ty) {
653                        quote! {}
654                    } else {
655                        quote! {
656                            <#ty as ::scheme_rs::gc::Trace>::finalize(#accessor);
657                        }
658                    }
659                })
660                .collect();
661            let field_name = fields.iter().map(|(_, _, field)| field);
662            let fields_destructured = match variant.fields {
663                Fields::Named(..) => quote! { { #( #field_name, )* .. } },
664                _ => quote! { ( #( #field_name ),* ) },
665            };
666            let field_name = fields.iter().map(|(_, _, field)| field);
667            let fields_destructured_mut = match variant.fields {
668                Fields::Named(..) => quote! { { #( #field_name, )* .. } },
669                _ => quote! { ( #( #field_name ),* ) },
670            };
671            let variant_name = variant.ident;
672            Ok((
673                quote! {
674                    Self::#variant_name #fields_destructured => {
675                        #(
676                            #visits
677                        )*
678                    }
679                },
680                quote! {
681                    Self::#variant_name #fields_destructured_mut => {
682                        #(
683                            #drops
684                        )*
685                    }
686                },
687            ))
688        })
689        .collect::<syn::Result<Vec<_>>>()?
690        .into_iter()
691        .unzip();
692
693    let Generics {
694        mut params,
695        where_clause,
696        ..
697    } = generics;
698
699    let mut unbound_params = Punctuated::<GenericParam, Token![,]>::new();
700
701    if !skip_bounds(attrs)? {
702        for param in params.iter_mut() {
703            match param {
704                GenericParam::Type(ty) => {
705                    ty.bounds.push(syn::TypeParamBound::Verbatim(
706                        quote! { ::scheme_rs::gc::Trace },
707                    ));
708                    unbound_params.push(GenericParam::Type(syn::TypeParam::from(ty.ident.clone())));
709                }
710                param => unbound_params.push(param.clone()),
711            }
712        }
713    }
714
715    Ok(quote! {
716        #[automatically_derived]
717        unsafe impl<#params> ::scheme_rs::gc::Trace for #name <#unbound_params>
718        #where_clause
719        {
720            unsafe fn visit_children(&self, visitor: &mut dyn FnMut(::scheme_rs::gc::OpaqueGcPtr)) {
721                match self {
722                    #( #visit_match_clauses, )*
723                    _ => (),
724                }
725            }
726
727            unsafe fn finalize(&mut self) {
728                match self {
729                    #( #finalize_match_clauses, )*
730                    _ => (),
731                }
732            }
733        }
734    })
735}
736
737fn skip_field(attrs: &[Attribute]) -> syn::Result<bool> {
738    let mut skip_field = false;
739
740    for attr in attrs.iter() {
741        if attr.path().is_ident("trace") {
742            let nested = attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
743            for meta in nested.into_iter() {
744                if meta.path().is_ident("skip") {
745                    skip_field = true;
746                } else if meta.path().is_ident("skip_bounds") {
747                    return Err(Error::new(
748                        meta.path().span(),
749                        "skip_bounds attribute is unsupported in this position",
750                    ));
751                } else {
752                    return Err(Error::new(meta.path().span(), "unrecognized attribute"));
753                }
754            }
755        }
756    }
757
758    Ok(skip_field)
759}
760
761fn skip_bounds(attrs: &[Attribute]) -> syn::Result<bool> {
762    let mut skip_bounds = false;
763
764    for attr in attrs.iter() {
765        if attr.path().is_ident("trace") {
766            let nested = attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
767            for meta in nested.into_iter() {
768                if meta.path().is_ident("skip_bounds") {
769                    skip_bounds = true;
770                } else if meta.path().is_ident("skip") {
771                    return Err(Error::new(
772                        meta.path().span(),
773                        "skip attribute is unsupported in this position",
774                    ));
775                } else {
776                    return Err(Error::new(meta.path().span(), "unrecognized attribute"));
777                }
778            }
779        }
780    }
781
782    Ok(skip_bounds)
783}
784
785fn is_gc(arg: &Type) -> bool {
786    if let Type::Path(path) = arg {
787        return path
788            .path
789            .segments
790            .last()
791            .map(|p| p.ident.to_string())
792            .as_deref()
793            == Some("Gc");
794    }
795    false
796}
797
798fn is_primitive(path: &TypePath, ty: &'static str) -> bool {
799    path.path
800        .segments
801        .last()
802        .map(|p| p.ident.to_string())
803        .as_deref()
804        == Some(ty)
805}
806
807fn rust_type_to_cranelift_type(ty: &Type) -> Option<Ident> {
808    match ty {
809        Type::Path(path) if is_primitive(path, "bool") => Some(format_ident!("I8")),
810        Type::Path(path) if is_primitive(path, "i32") => Some(format_ident!("I32")),
811        Type::Path(path) if is_primitive(path, "u32") => Some(format_ident!("I32")),
812        Type::Path(_) => Some(format_ident!("I64")),
813        Type::Ptr(_) => Some(format_ident!("I64")),
814        Type::Tuple(_) => None,
815        _ => unreachable!(),
816    }
817}
818
819#[proc_macro_attribute]
820pub fn runtime_fn(_args: TokenStream, item: TokenStream) -> TokenStream {
821    let runtime_fn = parse_macro_input!(item as ItemFn);
822
823    let name_ident = runtime_fn.sig.ident.clone();
824    let name_lit = Literal::string(&runtime_fn.sig.ident.to_string());
825    let ret = if let Some(ret_type) = match runtime_fn.sig.output {
826        syn::ReturnType::Default => None,
827        syn::ReturnType::Type(_, ref ty) => Some(rust_type_to_cranelift_type(&ty)),
828    }
829    .flatten()
830    {
831        quote! { sig.returns.push(AbiParam::new(types::#ret_type)); }
832    } else {
833        quote! {}
834    };
835    let args: Vec<_> = runtime_fn
836        .sig
837        .inputs
838        .iter()
839        .filter_map(|arg| {
840            let FnArg::Typed(pat) = arg else {
841                unreachable!();
842            };
843            rust_type_to_cranelift_type(&pat.ty)
844        })
845        .collect();
846
847    quote! {
848        #[allow(unused)]
849        inventory::submit!(crate::runtime::RuntimeFn::new(
850            |runtime_fns, module| {
851                use cranelift::prelude::*;
852                use cranelift_module::{Module, Linkage};
853                let mut sig = module.make_signature();
854                #(
855                    sig.params.push(AbiParam::new(types::#args));
856                )*
857                #ret
858                let func = module.declare_function(#name_lit, Linkage::Import, &sig).unwrap();
859                runtime_fns.#name_ident(func);
860            },
861            |jit_builder| {
862                jit_builder.symbol(#name_lit, #name_ident as *const u8);
863            }
864        ));
865
866
867        #runtime_fn
868    }
869    .into()
870}
871
872enum RtdField {
873    Immutable(LitStr),
874    Mutable(LitStr),
875}
876
877impl Parse for RtdField {
878    fn parse(input: ParseStream) -> Result<Self> {
879        if input.peek(LitStr) {
880            Ok(Self::Immutable(input.parse()?))
881        } else {
882            let mutability: Ident = input.parse()?;
883            let constructor = if mutability == "immutable" {
884                RtdField::Immutable
885            } else if mutability == "mutable" {
886                RtdField::Mutable
887            } else {
888                return Err(Error::new(
889                    mutability.span(),
890                    format!("invalid mutability '{mutability}'"),
891                ));
892            };
893            let content;
894            parenthesized!(content in input);
895            let name: LitStr = content.parse()?;
896            Ok((constructor)(name))
897        }
898    }
899}
900
901impl RtdField {
902    fn into_token_stream(self) -> proc_macro2::TokenStream {
903        match self {
904            Self::Immutable(name) => quote! {
905                ::scheme_rs::records::Field::Immutable(::scheme_rs::symbols::Symbol::intern(#name))
906            },
907            Self::Mutable(name) => quote! {
908                ::scheme_rs::records::Field::Mutable(::scheme_rs::symbols::Symbol::intern(#name))
909            },
910        }
911    }
912}
913
914struct Rtd {
915    name: LitStr,
916    parent: Option<Expr>,
917    opaque: Option<Expr>,
918    sealed: Option<LitBool>,
919    uid: Option<LitStr>,
920    constructor: Option<ExprClosure>,
921    fields: Option<Vec<RtdField>>,
922    lib: Option<LitStr>,
923}
924
925impl Parse for Rtd {
926    fn parse(input: ParseStream) -> Result<Self> {
927        let mut name = None;
928        let mut parent = None;
929        let mut opaque = None;
930        let mut sealed = None;
931        let mut fields = None;
932        let mut uid = None;
933        let mut lib = None;
934        let mut constructor: Option<ExprClosure> = None;
935        while !input.is_empty() {
936            let keyword: Ident = input.parse()?;
937            if keyword == "name" {
938                if name.is_some() {
939                    return Err(Error::new(keyword.span(), "duplicate definition of name"));
940                }
941                let _: Token![:] = input.parse()?;
942                name = Some(input.parse()?);
943            } else if keyword == "parent" {
944                if parent.is_some() {
945                    return Err(Error::new(keyword.span(), "duplicate definition of parent"));
946                }
947                let _: Token![:] = input.parse()?;
948                parent = Some(input.parse()?);
949            } else if keyword == "constructor" {
950                if constructor.is_some() {
951                    return Err(Error::new(
952                        keyword.span(),
953                        "duplicate definition of constructor",
954                    ));
955                }
956                let _: Token![:] = input.parse()?;
957                constructor = Some(input.parse()?);
958            } else if keyword == "opaque" {
959                if opaque.is_some() {
960                    return Err(Error::new(keyword.span(), "duplicate definition of opaque"));
961                }
962                let _: Token![:] = input.parse()?;
963                opaque = Some(input.parse()?);
964            } else if keyword == "sealed" {
965                if sealed.is_some() {
966                    return Err(Error::new(keyword.span(), "duplicate definition of sealed"));
967                }
968                let _: Token![:] = input.parse()?;
969                sealed = Some(input.parse()?);
970            } else if keyword == "uid" {
971                if uid.is_some() {
972                    return Err(Error::new(keyword.span(), "duplicate definition of uid"));
973                }
974                let _: Token![:] = input.parse()?;
975                uid = Some(input.parse()?);
976            } else if keyword == "lib" {
977                if lib.is_some() {
978                    return Err(Error::new(keyword.span(), "duplicate definition of lib"));
979                }
980                let _: Token![:] = input.parse()?;
981                lib = Some(input.parse()?);
982            } else if keyword == "fields" {
983                if fields.is_some() {
984                    return Err(Error::new(keyword.span(), "duplicate definition of fields"));
985                }
986                let _: Token![:] = input.parse()?;
987                let content;
988                bracketed!(content in input);
989                let punctuated_fields = content.parse_terminated(RtdField::parse, Token![,])?;
990                fields = Some(punctuated_fields.into_iter().collect());
991            } else {
992                return Err(Error::new(keyword.span(), "unknown field name"));
993            }
994
995            if !input.is_empty() {
996                let _: Token![,] = input.parse()?;
997            }
998        }
999
1000        let Some(name) = name else {
1001            return Err(Error::new(input.span(), "name field is required"));
1002        };
1003
1004        if !sealed.as_ref().map_or(false, LitBool::value) && constructor.is_none() {
1005            return Err(Error::new(
1006                input.span(),
1007                "unsealed records must have a constructor defined",
1008            ));
1009        }
1010
1011        Ok(Rtd {
1012            name,
1013            parent,
1014            opaque,
1015            sealed,
1016            uid,
1017            constructor,
1018            fields,
1019            lib,
1020        })
1021    }
1022}
1023
1024/// Convenience macro for declaring RecordTypeDescriptors
1025#[proc_macro]
1026pub fn rtd(tokens: TokenStream) -> TokenStream {
1027    let Rtd {
1028        name,
1029        parent,
1030        opaque,
1031        sealed,
1032        uid,
1033        constructor,
1034        fields,
1035        lib,
1036    } = parse_macro_input!(tokens as Rtd);
1037
1038    let fields = fields
1039        .into_iter()
1040        .flatten()
1041        .map(RtdField::into_token_stream)
1042        .collect::<Vec<_>>();
1043    let inherits = match parent {
1044        Some(parent) => quote!({
1045            let parent = <#parent as ::scheme_rs::records::SchemeCompatible>::rtd();
1046            let mut inherits = parent.inherits.clone();
1047            inherits.insert(::by_address::ByAddress(parent));
1048            inherits
1049        }),
1050        None => quote!(Default::default()),
1051    };
1052    let rust_parent_constructor = match constructor {
1053        Some(constructor) => {
1054            let num_inputs = constructor.inputs.len();
1055            let inputs = 0..num_inputs;
1056            let types = inputs.clone().map(|_| quote!(::scheme_rs::value::Value));
1057            quote!(Some({
1058                ::scheme_rs::records::RustParentConstructor::new(|vals| {
1059                    if vals.len() != #num_inputs {
1060                        return Err(::scheme_rs::exceptions::Exception::wrong_num_of_args(#num_inputs, vals.len()));
1061                    }
1062                    let constructor: fn(#(#types,)*) -> Result<_, ::scheme_rs::exceptions::Exception> = #constructor;
1063                    Ok(::scheme_rs::records::into_scheme_compatible(::scheme_rs::gc::Gc::new((constructor)(#(vals[#inputs].clone(),)*)?)))
1064                })
1065            }))
1066        }
1067        None => quote!(None),
1068    };
1069    let opaque = opaque.unwrap_or_else(|| parse_quote!(false));
1070    let sealed = sealed.unwrap_or_else(|| parse_quote!(false));
1071    let uid = match uid {
1072        Some(uid) => quote!(Some(::scheme_rs::symbols::Symbol::intern(#uid))),
1073        None => quote!(None),
1074    };
1075
1076    let bridge = lib.map(|lib| {
1077        let name = format!("{}-rtd", name.value());
1078        quote! {
1079            #[::scheme_rs_macros::bridge(name = #name, lib = #lib)]
1080            pub fn rtd() -> Result<Vec<::scheme_rs::value::Value>, ::scheme_rs::exceptions::Exception> {
1081                Ok(vec![::scheme_rs::value::Value::from(RTD.clone())])
1082            }
1083        }
1084    });
1085
1086    quote! {
1087        {
1088            static RTD: std::sync::LazyLock<std::sync::Arc<::scheme_rs::records::RecordTypeDescriptor>> =
1089                std::sync::LazyLock::new(|| {
1090                    std::sync::Arc::new(::scheme_rs::records::RecordTypeDescriptor {
1091                        name: ::scheme_rs::symbols::Symbol::intern(#name),
1092                        inherits: #inherits,
1093                        opaque: #opaque,
1094                        sealed: #sealed,
1095                        uid: #uid,
1096                        field_index_offset: 0,
1097                        fields: vec![ #( #fields, )* ],
1098                        rust_parent_constructor: #rust_parent_constructor,
1099                        rust_type: true
1100                    })
1101                });
1102            #bridge
1103            RTD.clone()
1104        }
1105    }.into()
1106}
1107
1108struct DctField {
1109    name: Ident,
1110    ty: Type,
1111}
1112
1113impl Parse for DctField {
1114    fn parse(input: ParseStream) -> Result<Self> {
1115        let name: Ident = input.parse()?;
1116        let _: Token![:] = input.parse()?;
1117        let ty: Type = input.parse()?;
1118        Ok(Self { name, ty })
1119    }
1120}
1121
1122struct DefineConditionType {
1123    scheme_name: LitStr,
1124    rust_name: Ident,
1125    lib: Option<LitStr>,
1126    parent: Type,
1127    constructor: Option<ExprClosure>,
1128    fields: Option<Vec<DctField>>,
1129    dbg: Option<ExprClosure>,
1130}
1131
1132impl Parse for DefineConditionType {
1133    fn parse(input: ParseStream) -> Result<Self> {
1134        let mut scheme_name = None;
1135        let mut rust_name = None;
1136        let mut parent = None;
1137        let mut constructor = None;
1138        let mut fields = None;
1139        let mut dbg = None;
1140        let mut lib = None;
1141
1142        while !input.is_empty() {
1143            let keyword: Ident = input.parse()?;
1144            if keyword == "scheme_name" {
1145                if scheme_name.is_some() {
1146                    return Err(Error::new(
1147                        keyword.span(),
1148                        "duplicate definition of scheme_name",
1149                    ));
1150                }
1151                let _: Token![:] = input.parse()?;
1152                scheme_name = Some(input.parse()?);
1153            } else if keyword == "lib" {
1154                if lib.is_some() {
1155                    return Err(Error::new(
1156                        keyword.span(),
1157                        "duplicate definition of lib",
1158                    ));
1159                }
1160                let _: Token![:] = input.parse()?;
1161                lib = Some(input.parse()?);
1162            } else if keyword == "rust_name" {
1163                if rust_name.is_some() {
1164                    return Err(Error::new(
1165                        keyword.span(),
1166                        "duplicate definition of rust_name",
1167                    ));
1168                }
1169                let _: Token![:] = input.parse()?;
1170                rust_name = Some(input.parse()?);
1171            } else if keyword == "parent" {
1172                if parent.is_some() {
1173                    return Err(Error::new(keyword.span(), "duplicate definition of parent"));
1174                }
1175                let _: Token![:] = input.parse()?;
1176                parent = Some(input.parse()?);
1177            } else if keyword == "constructor" {
1178                if constructor.is_some() {
1179                    return Err(Error::new(
1180                        keyword.span(),
1181                        "duplicate definition of constructor",
1182                    ));
1183                }
1184                let _: Token![:] = input.parse()?;
1185                constructor = Some(input.parse()?);
1186            } else if keyword == "debug" {
1187                if dbg.is_some() {
1188                    return Err(Error::new(keyword.span(), "duplicate definition of debug"));
1189                }
1190                let _: Token![:] = input.parse()?;
1191                dbg = Some(input.parse()?);
1192            } else if keyword == "fields" {
1193                if fields.is_some() {
1194                    return Err(Error::new(keyword.span(), "duplicate definition of fields"));
1195                }
1196                let _: Token![:] = input.parse()?;
1197                let content;
1198                braced!(content in input);
1199                let punctuated_fields = content.parse_terminated(DctField::parse, Token![,])?;
1200                fields = Some(punctuated_fields.into_iter().collect());
1201            } else {
1202                return Err(Error::new(keyword.span(), "unknown field name"));
1203            }
1204
1205            if !input.is_empty() {
1206                let _: Token![,] = input.parse()?;
1207            }
1208        }
1209
1210        let Some(scheme_name) = scheme_name else {
1211            return Err(Error::new(input.span(), "scheme_name field is required"));
1212        };
1213
1214        let Some(rust_name) = rust_name else {
1215            return Err(Error::new(input.span(), "rust_name field is required"));
1216        };
1217
1218        let Some(parent) = parent else {
1219            return Err(Error::new(input.span(), "parent field is required"));
1220        };
1221
1222        Ok(DefineConditionType {
1223            scheme_name,
1224            rust_name,
1225            parent,
1226            constructor,
1227            fields,
1228            dbg,
1229            lib,
1230        })
1231    }
1232}
1233
1234#[proc_macro]
1235pub fn define_condition_type(tokens: TokenStream) -> TokenStream {
1236    let DefineConditionType {
1237        scheme_name,
1238        rust_name,
1239        parent,
1240        constructor,
1241        fields,
1242        dbg,
1243        lib,
1244    } = parse_macro_input!(tokens as DefineConditionType);
1245
1246    let (field_names, field_tys): (Vec<_>, Vec<_>) = fields
1247        .into_iter()
1248        .flatten()
1249        .map(|field| (field.name, field.ty))
1250        .unzip();
1251
1252    let field_name_strs = field_names
1253        .clone()
1254        .into_iter()
1255        .map(|field_name| LitStr::new(&field_name.to_string(), field_name.span()));
1256
1257    let field_idxs = 0..field_names.len();
1258
1259    let lib = lib.map(|lib| quote!(lib: #lib,));
1260
1261    let constructor =  constructor.map_or_else(
1262        || quote! {
1263            constructor: || Ok(#rust_name::default()),
1264        },
1265        |constructor| {
1266            quote!(
1267                constructor: #constructor,
1268            )
1269        });
1270
1271    let dbg = dbg.map(|dbg| {
1272        quote!(
1273            let dbg: fn(&Self, &mut std::fmt::Formatter<'_>) -> std::fmt::Result = #dbg;
1274            (dbg)(self, f)?;
1275        )
1276    });
1277
1278    quote! {
1279        #[derive(Clone, ::scheme_rs::gc::Trace)]
1280        pub struct #rust_name {
1281            pub parent: ::scheme_rs::gc::Gc<#parent>,
1282            #( pub #field_names: #field_tys, )*
1283
1284        }
1285
1286        impl ::scheme_rs::records::SchemeCompatible for #rust_name {
1287            fn rtd() -> std::sync::Arc<::scheme_rs::records::RecordTypeDescriptor> {
1288                ::scheme_rs::records::rtd!(
1289                    name: #scheme_name,
1290                    parent: #parent,
1291                    #lib
1292                    fields: [#(#field_name_strs,)*],
1293                    #constructor
1294                )
1295            }
1296
1297            fn extract_embedded_record(
1298                &self,
1299                rtd: &std::sync::Arc<::scheme_rs::records::RecordTypeDescriptor>
1300            ) -> Option<::scheme_rs::gc::Gc<dyn ::scheme_rs::records::SchemeCompatible>> {
1301                #parent::rtd()
1302                    .is_subtype_of(rtd)
1303                    .then(|| ::scheme_rs::records::into_scheme_compatible(self.parent.clone()))
1304            }
1305
1306            fn get_field(&self, k: usize) -> Result<::scheme_rs::value::Value, ::scheme_rs::exceptions::Exception> {
1307                match k {
1308                    #(#field_idxs => Ok(::scheme_rs::value::Value::from(self.#field_names.clone())),)*
1309                    _ => Err(Exception::error(format!("invalid record field: {k}"))),
1310                }
1311            }
1312        }
1313
1314        impl std::fmt::Debug for #rust_name {
1315            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1316                #dbg
1317                Ok(())
1318            }
1319        }
1320    }
1321    .into()
1322}
1323
1324// Internal use only:
1325
1326#[proc_macro_attribute]
1327pub fn maybe_async(_args: TokenStream, item: TokenStream) -> TokenStream {
1328    let func = parse_macro_input!(item as ItemFn);
1329    let mut async_func = func.clone();
1330    async_func.sig.asyncness = Some(Token![async](Span::call_site()));
1331    quote! {
1332        #[cfg(not(feature = "async"))]
1333        #func
1334
1335        #[cfg(feature = "async")]
1336        #async_func
1337    }
1338    .into()
1339}
1340
1341#[proc_macro]
1342pub fn maybe_await(tokens: TokenStream) -> TokenStream {
1343    let tokens = proc_macro2::TokenStream::from(tokens);
1344    quote! {
1345        {
1346            #[cfg(not(feature = "async"))]
1347            let result = #tokens ;
1348
1349            #[cfg(feature = "async")]
1350            let result = #tokens .await;
1351
1352            result
1353        }
1354    }
1355    .into()
1356}