ocaml_derive/
lib.rs

1#![allow(clippy::manual_map)]
2extern crate proc_macro;
3
4use proc_macro::TokenStream;
5use quote::quote;
6
7fn check_func(item_fn: &mut syn::ItemFn) {
8    if item_fn.sig.asyncness.is_some() {
9        panic!("OCaml functions cannot be async");
10    }
11
12    if item_fn.sig.variadic.is_some() {
13        panic!("OCaml functions cannot be variadic");
14    }
15
16    match item_fn.vis {
17        syn::Visibility::Public(_) => (),
18        _ => panic!("OCaml functions must be public"),
19    }
20
21    if !item_fn.sig.generics.params.is_empty() {
22        panic!("OCaml functions may not contain generics")
23    }
24
25    item_fn.sig.abi = Some(syn::Abi {
26        extern_token: syn::token::Extern::default(),
27        name: Some(syn::LitStr::new("C", item_fn.sig.ident.span())),
28    });
29}
30
/// Kind of item an `ocaml::sig` attribute is attached to; selects which signature check
/// `ocaml_sig` performs.
#[derive(Eq, PartialEq, Debug)]
enum Mode {
    /// A function: top-level `->` arrows are compared with the Rust argument count.
    Func,
    /// A struct: `:` separators are compared with the field count.
    Struct,
    /// An enum: top-level `|`-separated constructors are compared with the variant count.
    Enum,
    /// A type alias: the signature is not checked.
    Type,
}
38
39#[proc_macro_attribute]
40pub fn ocaml_sig(attribute: TokenStream, item: TokenStream) -> TokenStream {
41    let (name, mode, n) = if let Ok(item) = syn::parse::<syn::ItemStruct>(item.clone()) {
42        let name = &item.ident;
43        let n_fields = match item.fields {
44            syn::Fields::Named(x) => x.named.iter().count(),
45            syn::Fields::Unit => 0,
46            syn::Fields::Unnamed(x) => x.unnamed.iter().count(),
47        };
48        (name.to_string().to_lowercase(), Mode::Struct, n_fields)
49    } else if let Ok(item) = syn::parse::<syn::ItemEnum>(item.clone()) {
50        let name = &item.ident;
51        let n = item.variants.iter().count();
52        (name.to_string().to_lowercase(), Mode::Enum, n)
53    } else if let Ok(item_fn) = syn::parse::<syn::ItemFn>(item.clone()) {
54        let name = &item_fn.sig.ident;
55        let n_args = item_fn.sig.inputs.iter().count();
56        (name.to_string(), Mode::Func, n_args)
57    } else if let Ok(item) = syn::parse::<syn::ItemType>(item.clone()) {
58        let name = &item.ident;
59        (name.to_string(), Mode::Type, 0)
60    } else {
61        panic!("Invalid use of ocaml::sig macro: {item}")
62    };
63
64    if attribute.is_empty() && mode != Mode::Func {
65        // Ok
66    } else if let Ok(sig) = syn::parse::<syn::LitStr>(attribute) {
67        let s = sig.value();
68        match mode {
69            Mode::Func => {
70                let mut n_args = 0;
71                let mut prev = None;
72                let mut paren_level = 0;
73                let iter = s.chars();
74                for ch in iter {
75                    if ch == '(' {
76                        paren_level += 1;
77                    } else if ch == ')' {
78                        paren_level -= 1;
79                    }
80
81                    if ch == '>' && prev == Some('-') && paren_level == 0 {
82                        n_args += 1;
83                    }
84
85                    prev = Some(ch);
86                }
87
88                if n == 0 && !s.trim().starts_with("unit") {
89                    panic!("{name}: Expected a single unit argument");
90                }
91
92                if n != n_args && (n == 0 && n_args > 1) {
93                    panic!(
94                        "{name}: Signature and function do not have the same number of arguments (expected: {n}, got {n_args})"
95                    );
96                }
97            }
98            Mode::Enum => {
99                if !s.is_empty() {
100                    let mut n_variants = 1;
101                    let mut bracket_level = 0;
102                    let iter = s.chars();
103                    for ch in iter {
104                        if ch == '[' {
105                            bracket_level += 1;
106                        } else if ch == ']' {
107                            bracket_level -= 1;
108                        }
109
110                        if ch == '|' && bracket_level == 0 {
111                            n_variants += 1;
112                        }
113                    }
114                    if s.trim().starts_with('|') {
115                        n_variants -= 1;
116                    }
117                    if n != n_variants {
118                        panic!("{name}: Signature and enum do not have the same number of variants (expected: {n}, got {n_variants})")
119                    }
120                }
121            }
122            Mode::Struct => {
123                if !s.is_empty() {
124                    let n_fields = s.matches(':').count();
125                    if n != n_fields {
126                        panic!("{name}: Signature and struct do not have the same number of fields (expected: {n}, got {n_fields})")
127                    }
128                }
129            }
130            Mode::Type => {}
131        }
132    } else {
133        panic!("OCaml sig accepts a str literal");
134    }
135
136    item
137}
138
139/// `func` is used export Rust functions to OCaml, performing the necessary wrapping/unwrapping
140/// automatically.
141///
142/// - Wraps the function body using `ocaml::body`
143/// - Automatic type conversion for arguments/return value (including Result types)
144/// - Defines a bytecode function automatically for functions that take more than 5 arguments. The
145///   bytecode function for `my_func` would be `my_func_bytecode`
146/// - Allows for an optional ident argument specifying the name of the `gc` handle parameter
147#[proc_macro_attribute]
148pub fn ocaml_func(attribute: TokenStream, item: TokenStream) -> TokenStream {
149    let mut item_fn: syn::ItemFn = syn::parse(item).unwrap();
150    check_func(&mut item_fn);
151
152    let name = &item_fn.sig.ident;
153    let unsafety = &item_fn.sig.unsafety;
154    let constness = &item_fn.sig.constness;
155    let mut gc_name = syn::Ident::new("gc", name.span());
156    let mut use_gc = quote!({let _ = &#gc_name;});
157
158    if let Ok(ident) = syn::parse::<syn::Ident>(attribute) {
159        gc_name = ident;
160        use_gc = quote!();
161    }
162
163    let (returns, rust_return_type) = match &item_fn.sig.output {
164        syn::ReturnType::Default => (false, None),
165        syn::ReturnType::Type(_, t) => (true, Some(t)),
166    };
167
168    let rust_args: Vec<_> = item_fn.sig.inputs.iter().collect();
169
170    let args: Vec<_> = item_fn
171        .sig
172        .inputs
173        .iter()
174        .map(|arg| match arg {
175            syn::FnArg::Receiver(_) => panic!("OCaml functions cannot take a self argument"),
176            syn::FnArg::Typed(t) => match t.pat.as_ref() {
177                syn::Pat::Ident(ident) => Some(ident),
178                _ => None,
179            },
180        })
181        .collect();
182
183    let mut ocaml_args: Vec<_> = args
184        .iter()
185        .map(|t| match t {
186            Some(ident) => {
187                let ident = &ident.ident;
188                quote! { #ident: ocaml::Raw }
189            }
190            None => quote! { _: ocaml::Raw },
191        })
192        .collect();
193
194    let param_names: syn::punctuated::Punctuated<syn::Ident, syn::token::Comma> = args
195        .iter()
196        .filter_map(|arg| match arg {
197            Some(ident) => Some(ident.ident.clone()),
198            None => None,
199        })
200        .collect();
201
202    let convert_params: Vec<_> = args
203        .iter()
204        .filter_map(|arg| match arg {
205            Some(ident) => {
206                let ident = ident.ident.clone();
207                Some(quote! { let #ident = ocaml::FromValue::from_value(unsafe { ocaml::Value::new(#ident).root() }); })
208            }
209            None => None,
210        })
211        .collect();
212
213    if ocaml_args.is_empty() {
214        ocaml_args.push(quote! { _: ocaml::Raw});
215    }
216
217    let body = &item_fn.block;
218
219    let inner = if returns {
220        quote! {
221            #[inline(always)]
222            #constness #unsafety fn inner(#gc_name: &ocaml::Runtime, #(#rust_args),*) -> #rust_return_type {
223                #use_gc
224                #body
225            }
226        }
227    } else {
228        quote! {
229            #[inline(always)]
230            #constness #unsafety fn inner(#gc_name: &ocaml::Runtime, #(#rust_args),*)  {
231                #use_gc
232                #body
233            }
234        }
235    };
236
237    let where_clause = &item_fn.sig.generics.where_clause;
238    let attr: Vec<_> = item_fn.attrs.iter().collect();
239
240    let gen = quote! {
241        #[no_mangle]
242        #(
243            #attr
244        )*
245        pub #constness #unsafety extern "C" fn #name(#(#ocaml_args),*) -> ocaml::Raw #where_clause {
246            #inner
247
248            ocaml::body!(#gc_name: {
249                #(#convert_params);*
250                let res = inner(#gc_name, #param_names);
251                #[allow(unused_unsafe)]
252                let mut gc_ = unsafe { ocaml::Runtime::recover_handle() };
253                unsafe { ocaml::ToValue::to_value(&res, &gc_).raw() }
254            })
255        }
256    };
257
258    if ocaml_args.len() > 5 {
259        let bytecode = {
260            let mut bc = item_fn.clone();
261            bc.attrs.retain(|x| {
262                let s = x
263                    .path()
264                    .segments
265                    .iter()
266                    .map(|x| x.ident.to_string())
267                    .collect::<Vec<_>>()
268                    .join("::");
269                s != "ocaml::sig" && s != "sig"
270            });
271            bc.sig.ident = syn::Ident::new(&format!("{}_bytecode", name), name.span());
272            ocaml_bytecode_func_impl(bc, gc_name, use_gc, Some(name))
273        };
274
275        let r = quote! {
276            #gen
277
278            #bytecode
279        };
280        return r.into();
281    }
282
283    gen.into()
284}
285
286/// `native_func` is used export Rust functions to OCaml, it has much lower overhead than `func`
287/// and expects all arguments and return type to to be `Value`.
288///
289/// - Wraps the function body using `ocaml::body`
290/// - Allows for an optional ident argument specifying the name of the `gc` handle parameter
291#[proc_macro_attribute]
292pub fn ocaml_native_func(attribute: TokenStream, item: TokenStream) -> TokenStream {
293    let mut item_fn: syn::ItemFn = syn::parse(item).unwrap();
294    check_func(&mut item_fn);
295
296    let name = &item_fn.sig.ident;
297    let unsafety = &item_fn.sig.unsafety;
298    let constness = &item_fn.sig.constness;
299
300    let mut gc_name = syn::Ident::new("gc", name.span());
301    let mut use_gc = quote!({let _ = &#gc_name;});
302    if let Ok(ident) = syn::parse::<syn::Ident>(attribute) {
303        gc_name = ident;
304        use_gc = quote!();
305    }
306
307    let where_clause = &item_fn.sig.generics.where_clause;
308    let attr: Vec<_> = item_fn.attrs.iter().collect();
309
310    let rust_args = &item_fn.sig.inputs;
311
312    let args: Vec<_> = item_fn
313        .sig
314        .inputs
315        .iter()
316        .map(|arg| match arg {
317            syn::FnArg::Receiver(_) => panic!("OCaml functions cannot take a self argument"),
318            syn::FnArg::Typed(t) => match t.pat.as_ref() {
319                syn::Pat::Ident(ident) => Some(ident),
320                _ => None,
321            },
322        })
323        .collect();
324
325    let mut ocaml_args: Vec<_> = args
326        .iter()
327        .map(|t| match t {
328            Some(ident) => quote! { #ident: ocaml::Raw },
329            None => quote! { _: ocaml::Raw },
330        })
331        .collect();
332
333    if ocaml_args.is_empty() {
334        ocaml_args.push(quote! { _: ocaml::Raw});
335    }
336
337    let body = &item_fn.block;
338
339    let (_, rust_return_type) = match &item_fn.sig.output {
340        syn::ReturnType::Default => (false, None),
341        syn::ReturnType::Type(_, _t) => (true, Some(quote! {ocaml::Raw})),
342    };
343
344    let gen = quote! {
345        #[no_mangle]
346        #(
347            #attr
348        )*
349        pub #constness #unsafety extern "C" fn #name (#rust_args) -> #rust_return_type #where_clause {
350            let r = ocaml::body!(#gc_name: {
351                #use_gc
352                #body
353            });
354            r.raw()
355        }
356    };
357    gen.into()
358}
359
360/// `bytecode_func` is used export Rust functions to OCaml, performing the necessary wrapping/unwrapping
361/// automatically.
362///
363/// Since this is automatically applied to `func` functions, this is primarily be used when working with
364/// unboxed functions, or `native_func`s directly. `ocaml::body` is not applied since this is
365/// typically used to call the native function, which is wrapped with `ocaml::body` or performs the
366/// equivalent work to register values with the garbage collector
367///
368/// - Automatic type conversion for arguments/return value
369/// - Allows for an optional ident argument specifying the name of the `gc` handle parameter
370#[proc_macro_attribute]
371pub fn ocaml_bytecode_func(attribute: TokenStream, item: TokenStream) -> TokenStream {
372    let item_fn: syn::ItemFn = syn::parse(item).unwrap();
373    let mut gc_name = syn::Ident::new("gc", item_fn.sig.ident.span());
374    let mut use_gc = quote!({let _ = &#gc_name;});
375    if let Ok(ident) = syn::parse::<syn::Ident>(attribute) {
376        gc_name = ident;
377        use_gc = quote!();
378    }
379    ocaml_bytecode_func_impl(item_fn, gc_name, use_gc, None).into()
380}
381
382fn ocaml_bytecode_func_impl(
383    mut item_fn: syn::ItemFn,
384    gc_name: syn::Ident,
385    use_gc: impl quote::ToTokens,
386    original: Option<&proc_macro2::Ident>,
387) -> proc_macro2::TokenStream {
388    check_func(&mut item_fn);
389
390    let name = &item_fn.sig.ident;
391    let unsafety = &item_fn.sig.unsafety;
392    let constness = &item_fn.sig.constness;
393
394    let (returns, rust_return_type) = match &item_fn.sig.output {
395        syn::ReturnType::Default => (false, None),
396        syn::ReturnType::Type(_, t) => (true, Some(t)),
397    };
398
399    let rust_args: Vec<_> = item_fn.sig.inputs.iter().collect();
400
401    let args: Vec<_> = item_fn
402        .sig
403        .inputs
404        .clone()
405        .into_iter()
406        .map(|arg| match arg {
407            syn::FnArg::Receiver(_) => panic!("OCaml functions cannot take a self argument"),
408            syn::FnArg::Typed(mut t) => match t.pat.as_mut() {
409                syn::Pat::Ident(ident) => {
410                    ident.mutability = None;
411                    Some(ident.clone())
412                }
413                _ => None,
414            },
415        })
416        .collect();
417
418    let mut ocaml_args: Vec<_> = args
419        .iter()
420        .map(|t| match t {
421            Some(ident) => {
422                quote! { #ident: ocaml::Raw }
423            }
424            None => quote! { _: ocaml::Raw },
425        })
426        .collect();
427
428    let mut param_names: syn::punctuated::Punctuated<syn::Ident, syn::token::Comma> = args
429        .iter()
430        .filter_map(|arg| match arg {
431            Some(ident) => Some(ident.ident.clone()),
432            None => None,
433        })
434        .collect();
435
436    if ocaml_args.is_empty() {
437        ocaml_args.push(quote! { _unit: ocaml::Raw});
438        param_names.push(syn::Ident::new("__ocaml_unit", name.span()));
439    }
440
441    let body = &item_fn.block;
442
443    let inner = match original {
444        Some(o) => {
445            quote! {
446                #[allow(unused)]
447                let __ocaml_unit = ocaml::Value::unit();
448                let inner = #o;
449            }
450        }
451        None => {
452            if returns {
453                quote! {
454                    #[inline(always)]
455                    #constness #unsafety fn inner(#(#rust_args),*) -> #rust_return_type {
456                        #[allow(unused_variables)]
457                        let #gc_name = unsafe { ocaml::Runtime::recover_handle() };
458                        #use_gc
459                        #body
460                    }
461                }
462            } else {
463                quote! {
464                    #[inline(always)]
465                    #constness #unsafety fn inner(#(#rust_args),*)  {
466                        #[allow(unused_variables)]
467                        let #gc_name = unsafe { ocaml::Runtime::recover_handle() };
468                        #use_gc
469                        #body
470                    }
471                }
472            }
473        }
474    };
475
476    let where_clause = &item_fn.sig.generics.where_clause;
477    let attr: Vec<_> = item_fn.attrs.iter().collect();
478
479    let len = ocaml_args.len();
480
481    if len > 5 {
482        let convert_params: Vec<_> = args
483            .iter()
484            .filter_map(|arg| match arg {
485                Some(ident) => Some(quote! {
486                    #[allow(clippy::not_unsafe_ptr_arg_deref)]
487                    let #ident = ocaml::FromValue::from_value(unsafe {
488                        Value::new(core::ptr::read(__ocaml_argv.add(__ocaml_arg_index as usize))).root()
489                    });
490                    __ocaml_arg_index += 1 ;
491                }),
492                None => None,
493            })
494            .collect();
495        quote! {
496            #[no_mangle]
497            #(
498                #attr
499            )*
500            pub #constness unsafe extern "C" fn #name(__ocaml_argv: *mut ocaml::Raw, __ocaml_argc: i32) -> ocaml::Raw #where_clause {
501                assert!(#len <= __ocaml_argc as usize, "len: {}, argc: {}", #len, __ocaml_argc);
502
503                let #gc_name = unsafe { ocaml::Runtime::recover_handle() };
504
505                #inner
506
507                let mut __ocaml_arg_index = 0;
508                #(#convert_params);*
509                let res = inner(#param_names);
510                ocaml::ToValue::to_value(&res, &#gc_name).raw()
511            }
512        }
513    } else {
514        let convert_params: Vec<_> = args
515            .iter()
516            .filter_map(|arg| match arg {
517                Some(ident) => {
518                    let ident = ident.ident.clone();
519                    Some(quote! { let #ident = ocaml::FromValue::from_value(unsafe { ocaml::Value::new(#ident).root() }); })
520                }
521                None => None,
522            })
523            .collect();
524        quote! {
525            #[no_mangle]
526            #(
527                #attr
528            )*
529            pub #constness #unsafety extern "C" fn #name(#(#ocaml_args),*) -> ocaml::Raw #where_clause {
530                #[allow(unused_variables)]
531                let #gc_name = unsafe { ocaml::Runtime::recover_handle() };
532
533                #inner
534
535                #(#convert_params);*
536                let res = inner(#param_names);
537                ocaml::ToValue::to_value(&res, &#gc_name).raw()
538            }
539        }
540    }
541}
542
543// Derive macros for ToValue/FromValue
544
545fn is_double_array_struct(fields: &syn::Fields) -> bool {
546    fields.iter().all(|field| match &field.ty {
547        syn::Type::Path(p) => {
548            let s = p.path.segments.iter().map(|x| x.ident.to_string()).fold(
549                String::new(),
550                |mut acc, x| {
551                    if !acc.is_empty() {
552                        acc += "::";
553                        acc += &x;
554                        acc
555                    } else {
556                        x
557                    }
558                },
559            );
560            s == "ocaml::Float" || s == "Float" || s == "f64" || s == "f32"
561        }
562        _ => false,
563    })
564}
565
/// Container-level layout options, collected from `#[float_array]` / `#[unboxed]` by `attrs`.
#[derive(Default)]
struct Attrs {
    /// Set by `#[float_array]`: fields are stored in an OCaml double array.
    float_array: bool,
    /// Set by `#[unboxed]`: the item is represented directly by its single field.
    unboxed: bool,
}
571
572// Get struct-level attributes
573fn attrs(attrs: &[syn::Attribute]) -> Attrs {
574    let mut acc = Attrs::default();
575    attrs.iter().for_each(|attr| {
576        if let syn::Meta::Path(p) = &attr.meta {
577            if let Some(ident) = p.get_ident() {
578                if ident == "float_array" {
579                    if acc.unboxed {
580                        panic!("cannot use float_array and unboxed");
581                    }
582                    acc.float_array = true;
583                } else if ident == "unboxed" {
584                    if acc.float_array {
585                        panic!("cannot use float_array and unboxed");
586                    }
587                    acc.unboxed = true;
588                }
589            }
590        }
591    });
592    acc
593}
594
595/// Derive `ocaml::FromValue`
596#[proc_macro_derive(FromValue, attributes(float_array, unboxed))]
597pub fn derive_from_value(item: TokenStream) -> TokenStream {
598    if let Ok(item_struct) = syn::parse::<syn::ItemStruct>(item.clone()) {
599        let attrs = attrs(&item_struct.attrs);
600        let g = item_struct.generics;
601        let name = item_struct.ident;
602
603        // Tuple structs have unnamed fields
604        let tuple_struct = item_struct.fields.is_empty()
605            || item_struct.fields.iter().take(1).all(|x| x.ident.is_none());
606
607        // This is true when all struct fields are `float`s
608        let is_double_array_struct =
609            attrs.float_array || is_double_array_struct(&item_struct.fields);
610
611        if attrs.unboxed && item_struct.fields.len() > 1 {
612            panic!("cannot unbox structs with more than 1 field")
613        }
614
615        let fields =
616            item_struct
617                .fields
618                .iter()
619                .enumerate()
620                .map(|(index, field)| match &field.ident {
621                    Some(name) => {
622                        // Named fields
623                        if is_double_array_struct {
624                            let ty = &field.ty;
625                            quote!(#name: value.double_field(#index) as #ty)
626                        } else if attrs.unboxed {
627                            quote!(#name: ocaml::FromValue::from_value(value))
628                        } else {
629                            quote!(#name: ocaml::FromValue::from_value(value.field(#index)))
630                        }
631                    }
632                    None => {
633                        // Unnamed fields, tuple struct
634                        if is_double_array_struct {
635                            let ty = &field.ty;
636                            quote!(value.double_field(#index) as #ty)
637                        } else if attrs.unboxed {
638                            quote!(ocaml::FromValue::from_value(value))
639                        } else {
640                            quote!(ocaml::FromValue::from_value(value.field(#index)))
641                        }
642                    }
643                });
644
645        let inner = if tuple_struct {
646            quote!(Self(#(#fields),*))
647        } else {
648            quote!(Self{#(#fields),*})
649        };
650
651        let (g_impl, g_ty, g_wh) = g.split_for_impl();
652
653        // Generate FromValue for structs
654        quote! {
655            unsafe impl #g_impl ocaml::FromValue for #name #g_ty #g_wh {
656                fn from_value(value: ocaml::Value) -> Self {
657                    unsafe {
658                        #inner
659                    }
660                }
661            }
662        }
663        .into()
664    } else if let Ok(item_enum) = syn::parse::<syn::ItemEnum>(item) {
665        let g = item_enum.generics;
666        let name = item_enum.ident;
667        let attrs = attrs(&item_enum.attrs);
668        let mut unit_tag = 0u8;
669        let mut non_unit_tag = 0u8;
670        if attrs.unboxed && item_enum.variants.len() > 1 {
671            panic!("cannot unbox enums with more than 1 variant")
672        }
673        let variants =
674            item_enum.variants.iter().map(|variant| {
675                let arity = variant.fields.len();
676                let is_block = arity != 0;
677                let tag_ref = if arity > 0 {
678                    &mut non_unit_tag
679                } else {
680                    &mut unit_tag
681                };
682
683                // Get current tag index
684                let tag = *tag_ref;
685
686                // Increment the tag for next time
687                *tag_ref += 1;
688
689                let v_name = &variant.ident;
690                let n_fields = variant.fields.len();
691
692                // Tuple enums have unnamed fields
693                let tuple_enum = variant.fields.is_empty()
694                    || variant.fields.iter().take(1).all(|x| x.ident.is_none());
695
696                // Handle enums with no fields first
697                if n_fields == 0 {
698                    quote! {
699                        (#is_block, #tag) => {
700                            #name::#v_name
701                        }
702                    }
703                } else {
704                    let fields = variant.fields.iter().enumerate().map(
705                        |(index, field)| match &field.ident {
706                            Some(name) => {
707                                // Struct enum variant
708                                if attrs.unboxed {
709                                    quote!(#name: ocaml::FromValue::from_value(value))
710                                } else {
711                                    quote!(#name: ocaml::FromValue::from_value(value.field(#index)))
712                                }
713                            }
714                            None => {
715                                // Tuple enum variant
716                                if attrs.unboxed {
717                                    quote!(#name: ocaml::FromValue::from_value(value))
718                                } else {
719                                    quote!(ocaml::FromValue::from_value(value.field(#index)))
720                                }
721                            }
722                        },
723                    );
724                    let inner = if tuple_enum {
725                        quote!(#name::#v_name(#(#fields),*))
726                    } else {
727                        quote!(#name::#v_name{#(#fields),*})
728                    };
729
730                    // Generate match case
731                    quote! {
732                        (#is_block, #tag) => {
733                            #inner
734                        }
735                    }
736                }
737            });
738
739        let (g_impl, g_ty, g_wh) = g.split_for_impl();
740
741        // Generate FromValue for enums
742        quote! {
743            unsafe impl #g_impl ocaml::FromValue for #name #g_ty #g_wh {
744                fn from_value(value: ocaml::Value) -> Self {
745                    unsafe {
746                        let is_block = value.is_block();
747                        let tag = if !is_block { value.int_val() as u8 } else { value.tag().0 as u8 };
748                        match (is_block, tag) {
749                            #(#variants),*,
750                            _ => panic!("invalid variant, tag: {}", tag)
751                        }
752                    }
753                }
754           }
755        }
756        .into()
757    } else {
758        panic!("invalid type for FromValue");
759    }
760}
761
762/// Derive `ocaml::ToValue`
763#[proc_macro_derive(ToValue, attributes(float_array, unboxed))]
764pub fn derive_to_value(item: TokenStream) -> TokenStream {
765    if let Ok(item_struct) = syn::parse::<syn::ItemStruct>(item.clone()) {
766        let attrs = attrs(&item_struct.attrs);
767        let g = item_struct.generics;
768        let name = item_struct.ident;
769
770        // Double array structs occur when all fields are `float`s
771        let is_double_array_struct =
772            attrs.float_array || is_double_array_struct(&item_struct.fields);
773        if attrs.unboxed && item_struct.fields.len() > 1 {
774            panic!("cannot unbox structs with more than 1 field")
775        }
776        let fields: Vec<_> = item_struct
777            .fields
778            .iter()
779            .enumerate()
780            .map(|(index, field)| {
781                let index = syn::Index::from(index);
782                match &field.ident {
783                    Some(name) => {
784                        // Named fields
785                        if is_double_array_struct {
786                            quote!(value.store_double_field(#index, self.#name as f64))
787                        } else if attrs.unboxed {
788                            quote!(value = self.#name.to_value(rt))
789                        } else {
790                            quote!(value.store_field(rt, #index, &self.#name))
791                        }
792                    }
793                    None => {
794                        // Tuple struct
795                        if is_double_array_struct {
796                            quote!(value.store_double_field(#index, self.#index as f64))
797                        } else if attrs.unboxed {
798                            quote!(value = self.#index.to_value(rt))
799                        } else {
800                            quote!(value.store_field(rt, #index, &self.#index))
801                        }
802                    }
803                }
804            })
805            .collect();
806
807        let tag = if is_double_array_struct {
808            quote!(ocaml::Tag::DOUBLE_ARRAY)
809        } else {
810            quote!(0.into())
811        };
812        let n = fields.len();
813        let (g_impl, g_ty, g_wh) = g.split_for_impl();
814
815        let value_decl = if attrs.unboxed {
816            // Only allocate a singlue value for unboxed structs
817            quote!(
818                let mut value = ocaml::Value::unit();
819            )
820        } else {
821            quote!(
822                let mut value = ocaml::Value::alloc(#n, #tag);
823            )
824        };
825
826        // Generate ToValue for structs
827        quote! {
828            unsafe impl #g_impl ocaml::ToValue for #name #g_ty #g_wh {
829                fn to_value(&self, rt: &ocaml::Runtime) -> ocaml::Value {
830                    unsafe {
831                        #value_decl
832                        #(#fields);*;
833                        value
834                    }
835                }
836            }
837        }
838        .into()
839    } else if let Ok(item_enum) = syn::parse::<syn::ItemEnum>(item) {
840        let g = item_enum.generics;
841        let name = item_enum.ident;
842        let attrs = attrs(&item_enum.attrs);
843        let mut unit_tag = 0u8;
844        let mut non_unit_tag = 0u8;
845
846        if attrs.unboxed && item_enum.variants.len() != 1 {
847            panic!("cannot unbox enums with more than 1 variant")
848        }
849
850        let variants = item_enum.variants.iter().map(|variant| {
851            let arity = variant.fields.len();
852            let tag_ref = if arity > 0 {
853                &mut non_unit_tag
854            } else {
855                &mut unit_tag
856            };
857
858            // Get current tag and increment for next iteration
859            let tag = *tag_ref;
860            *tag_ref += 1;
861
862            let v_name = &variant.ident;
863
864            let n_fields = variant.fields.len();
865
866            if n_fields == 0 {
867                // A variant with no fields is represented by an int value
868                quote! {
869                    #name::#v_name => {
870                        ocaml::Value::int(#tag as ocaml::Int)
871                    }
872                }
873            } else {
874                // Generate conversion for the fields of each variant
875                let fields: Vec<_> = variant
876                    .fields
877                    .iter()
878                    .enumerate()
879                    .map(|(index, field)| match &field.ident {
880                        Some(name) => {
881                            // Struct-like variant
882                            if attrs.unboxed {
883                                quote!(value = #name.to_value(rt);)
884                            } else {
885                                quote!(value.store_field(rt, #index, &#name))
886                            }
887                        }
888                        None => {
889                            // Tuple-like variant
890                            let x = format!("x{index}");
891                            let x = syn::Ident::new(&x, proc_macro2::Span::call_site());
892                            if attrs.unboxed {
893                                quote!(value = #x.to_value(rt);)
894                            } else {
895                                quote!(value.store_field(rt, #index, &#x))
896                            }
897                        }
898                    })
899                    .collect();
900
901                let n = variant.fields.len();
902                let tuple_enum = variant.fields.is_empty()
903                    || variant.fields.iter().take(1).all(|x| x.ident.is_none());
904
905                // Generate fields
906                let mut v = quote!();
907                for (index, field) in variant.fields.iter().enumerate() {
908                    let xindex = format!("x{index}");
909                    let i = syn::Ident::new(&xindex, proc_macro2::Span::call_site());
910                    let f_name = field.ident.as_ref().unwrap_or(&i);
911                    if index == 0 {
912                        v = quote!(#f_name)
913                    } else {
914                        v = quote!(#v, #f_name);
915                    }
916                }
917
918                let match_fields = if tuple_enum {
919                    quote!(#name::#v_name(#v))
920                } else {
921                    quote!(#name::#v_name{#v})
922                };
923
924                let value_decl = if attrs.unboxed {
925                    quote!(let mut value = ocaml::Value::unit())
926                } else {
927                    quote!(
928                        let mut value = ocaml::Value::alloc(#n, #tag.into());
929                    )
930                };
931                quote!(#match_fields => {
932                    #value_decl
933                    #(#fields);*;
934                    value
935                })
936            }
937        });
938
939        let (g_impl, g_ty, g_wh) = g.split_for_impl();
940
941        // Generate ToValue implementation for enums
942        quote! {
943            unsafe impl #g_impl ocaml::ToValue for #name #g_ty #g_wh {
944                fn to_value(&self, rt: &ocaml::Runtime) -> ocaml::Value {
945                    unsafe {
946                        match self {
947                            #(#variants),*,
948                        }
949                    }
950                }
951           }
952        }
953        .into()
954    } else {
955        panic!("invalid type for ToValue");
956    }
957}