Skip to main content

rexlang_proc_macro/
lib.rs

1#![cfg_attr(not(test), deny(clippy::unwrap_used, clippy::expect_used))]
2
3use proc_macro::TokenStream;
4
5use proc_macro2::{Span, TokenStream as TokenStream2};
6use quote::{format_ident, quote};
7use std::collections::HashMap;
8use syn::{
9    Attribute, Data, DeriveInput, Error, Fields, GenericArgument, Generics, Ident, LitStr,
10    PathArguments, Type, parse_quote, spanned::Spanned,
11};
12
13#[proc_macro_derive(Rex, attributes(rex, serde))]
14pub fn derive_rex(input: TokenStream) -> TokenStream {
15    let ast: DeriveInput = match syn::parse(input) {
16        Ok(ast) => ast,
17        Err(e) => return e.to_compile_error().into(),
18    };
19    match expand(&ast) {
20        Ok(ts) => ts.into(),
21        Err(e) => e.to_compile_error().into(),
22    }
23}
24
/// Container-level settings for a single `#[derive(Rex)]` invocation.
struct DeriveOptions {
    /// Name the type is known by on the Rex side: the `#[rex(name = "...")]`
    /// value when one is given, otherwise the Rust identifier of the item.
    name: String,
}
28
29fn expand(ast: &DeriveInput) -> Result<TokenStream2, Error> {
30    if ast.generics.lifetimes().next().is_some() || ast.generics.const_params().next().is_some() {
31        return Err(Error::new(
32            ast.generics.span(),
33            "`#[derive(Rex)]` only supports type parameters (no lifetimes or const generics)",
34        ));
35    }
36
37    let opts = DeriveOptions {
38        name: rex_name_from_attrs(&ast.attrs)?.unwrap_or_else(|| ast.ident.to_string()),
39    };
40
41    let rust_ident = &ast.ident;
42    let type_name = opts.name;
43    let type_param_idents: Vec<Ident> = ast
44        .generics
45        .type_params()
46        .map(|p| p.ident.clone())
47        .collect();
48    let type_param_count = type_param_idents.len();
49
50    let mut rex_type_generics = ast.generics.clone();
51    add_bound_to_type_params(&mut rex_type_generics, parse_quote!(::rexlang::RexType));
52    let (rex_type_impl_generics, rex_type_ty_generics, rex_type_where_clause) =
53        rex_type_generics.split_for_impl();
54    let rex_type_params = type_param_idents.iter().map(|ident| {
55        quote! { <#ident as ::rexlang::RexType>::rex_type() }
56    });
57    let rex_type_impl = quote! {
58        impl #rex_type_impl_generics ::rexlang::RexType for #rust_ident #rex_type_ty_generics #rex_type_where_clause {
59            fn rex_type() -> ::rexlang::Type {
60                let mut ty = ::rexlang::Type::con(#type_name, #type_param_count);
61                #( ty = ::rexlang::Type::app(ty, #rex_type_params); )*
62                ty
63            }
64        }
65    };
66    let adt_decl_fn = adt_decl_fn(ast, &type_name, &type_param_idents)?;
67    let mut rex_adt_generics = ast.generics.clone();
68    add_bound_to_type_params(&mut rex_adt_generics, parse_quote!(::rexlang::RexType));
69    let (rex_adt_impl_generics, rex_adt_ty_generics, rex_adt_where_clause) =
70        rex_adt_generics.split_for_impl();
71    let rex_adt_impl = quote! {
72        impl #rex_adt_impl_generics ::rexlang::RexAdt for #rust_ident #rex_adt_ty_generics #rex_adt_where_clause {
73            fn rex_adt_decl<State: Clone + Send + Sync + 'static>(
74                engine: &mut ::rexlang::Engine<State>,
75            ) -> Result<::rexlang::AdtDecl, ::rexlang::EngineError> {
76                #adt_decl_fn
77            }
78        }
79    };
80    let inject_fn = quote! {
81        impl #rex_adt_impl_generics #rust_ident #rex_adt_ty_generics #rex_adt_where_clause {
82            pub fn inject_rex<State: Clone + Send + Sync + 'static>(
83                engine: &mut ::rexlang::Engine<State>,
84            ) -> Result<(), ::rexlang::EngineError> {
85                <Self as ::rexlang::RexAdt>::inject_rex(engine)
86            }
87
88            pub fn rex_adt_decl<State: Clone + Send + Sync + 'static>(
89                engine: &mut ::rexlang::Engine<State>,
90            ) -> Result<::rexlang::AdtDecl, ::rexlang::EngineError> {
91                <Self as ::rexlang::RexAdt>::rex_adt_decl(engine)
92            }
93
94            pub fn inject_rex_with_default<State: Clone + Send + Sync + 'static>(
95                engine: &mut ::rexlang::Engine<State>,
96            ) -> Result<(), ::rexlang::EngineError>
97            where
98                Self: ::rexlang::RexDefault<State>,
99            {
100                <Self as ::rexlang::RexAdt>::inject_rex(engine)?;
101                engine.inject_rex_default_instance::<Self>()
102            }
103
104            pub fn inject_rex_with_constructor<State, Sig, H>(
105                engine: &mut ::rexlang::Engine<State>,
106                constructor: H,
107            ) -> Result<(), ::rexlang::EngineError>
108            where
109                State: Clone + Send + Sync + 'static,
110                H: ::rexlang::Handler<State, Sig>,
111            {
112                <Self as ::rexlang::RexAdt>::inject_rex(engine)?;
113                engine.export(#type_name, constructor)
114            }
115        }
116    };
117
118    let into_value_impl = into_value_impl(ast, &type_name)?;
119    let from_value_impl = from_value_impl(ast, &type_name)?;
120
121    Ok(quote! {
122        #rex_type_impl
123        #rex_adt_impl
124        #inject_fn
125        #into_value_impl
126        #from_value_impl
127    })
128}
129
130fn rex_name_from_attrs(attrs: &[Attribute]) -> Result<Option<String>, Error> {
131    for attr in attrs {
132        if !attr.path().is_ident("rex") {
133            continue;
134        }
135        let mut name: Option<String> = None;
136        attr.parse_nested_meta(|meta| {
137            if meta.path.is_ident("name") {
138                let value = meta.value()?;
139                let lit: LitStr = value.parse()?;
140                name = Some(lit.value());
141            }
142            Ok(())
143        })?;
144        return Ok(name);
145    }
146    Ok(None)
147}
148
149fn serde_rename_from_attrs(attrs: &[Attribute]) -> Result<Option<String>, Error> {
150    for attr in attrs {
151        if !attr.path().is_ident("serde") {
152            continue;
153        }
154        let mut rename: Option<String> = None;
155        attr.parse_nested_meta(|meta| {
156            if meta.path.is_ident("rename") {
157                let value = meta.value()?;
158                let lit: LitStr = value.parse()?;
159                rename = Some(lit.value());
160            } else if meta.path.is_ident("alias") {
161                // Consume and ignore aliases so serde meta parsing doesn't fail.
162                let value = meta.value()?;
163                let _lit: LitStr = value.parse()?;
164            } else if meta.path.is_ident("default") {
165                // Consume and ignore defaults (function path as string literal).
166                let value = meta.value()?;
167                let _lit: LitStr = value.parse()?;
168            }
169            Ok(())
170        })?;
171        if rename.is_some() {
172            return Ok(rename);
173        }
174    }
175    Ok(None)
176}
177
178fn adt_decl_fn(
179    ast: &DeriveInput,
180    type_name: &str,
181    type_params: &[Ident],
182) -> Result<TokenStream2, Error> {
183    let param_names: Vec<LitStr> = type_params
184        .iter()
185        .map(|p| LitStr::new(&p.to_string(), Span::call_site()))
186        .collect();
187    let param_count = param_names.len();
188    let adt_decl = if param_names.is_empty() {
189        quote! {
190            let head = ::rexlang::Type::con(#type_name, 0);
191            let mut adt = engine.adt_decl_from_type_with_params(&head, &[])?;
192        }
193    } else {
194        quote! {
195            let head = ::rexlang::Type::con(#type_name, #param_count);
196            let mut adt = engine.adt_decl_from_type_with_params(&head, &[#(#param_names,)*])?;
197        }
198    };
199
200    let mut param_bindings = Vec::new();
201    let mut param_map: HashMap<String, TokenStream2> = HashMap::new();
202    for p in type_params {
203        let p_name = p.to_string();
204        let p_lit = LitStr::new(&p_name, Span::call_site());
205        let p_ident = format_ident!("__rex_param_{p_name}", span = Span::call_site());
206        param_bindings.push(quote! {
207            let #p_ident = adt
208                .param_type(&::rexlang::intern(#p_lit))
209                .ok_or_else(|| ::rexlang::EngineError::UnknownType(::rexlang::intern(#type_name)))?;
210        });
211        param_map.insert(p_name, quote!(#p_ident.clone()));
212    }
213
214    match &ast.data {
215        Data::Struct(data) => match &data.fields {
216            Fields::Named(fields) => {
217                let ctor = type_name;
218                let mut field_inits = Vec::new();
219                for field in &fields.named {
220                    let field_ident = field
221                        .ident
222                        .as_ref()
223                        .ok_or_else(|| Error::new(field.span(), "expected named field"))?;
224                    let mut field_name = field_ident.to_string();
225                    if let Some(rename) = serde_rename_from_attrs(&field.attrs)? {
226                        field_name = rename;
227                    }
228                    let field_ty = rex_type_expr(&field.ty, &param_map)?;
229                    field_inits.push(quote! {
230                        ( ::rexlang::intern(#field_name), #field_ty )
231                    });
232                }
233                Ok(quote! {{
234                    #adt_decl
235                    #(#param_bindings)*
236                    let record = ::rexlang::Type::record(::std::vec![#(#field_inits,)*]);
237                    adt.add_variant(::rexlang::intern(#ctor), ::std::vec![record]);
238                    Ok(adt)
239                }})
240            }
241            Fields::Unnamed(fields) => {
242                let ctor = type_name;
243                let mut args = Vec::new();
244                for field in &fields.unnamed {
245                    let ty = rex_type_expr(&field.ty, &param_map)?;
246                    args.push(ty);
247                }
248                Ok(quote! {{
249                    #adt_decl
250                    #(#param_bindings)*
251                    adt.add_variant(::rexlang::intern(#ctor), ::std::vec![#(#args,)*]);
252                    Ok(adt)
253                }})
254            }
255            Fields::Unit => Ok(quote! {{
256                #adt_decl
257                #(#param_bindings)*
258                adt.add_variant(::rexlang::intern(#type_name), ::std::vec![]);
259                Ok(adt)
260            }}),
261        },
262        Data::Enum(data) => {
263            let mut variants = Vec::new();
264            for variant in &data.variants {
265                let mut variant_name = variant.ident.to_string();
266                if let Some(rename) = serde_rename_from_attrs(&variant.attrs)? {
267                    variant_name = rename;
268                }
269                let args = match &variant.fields {
270                    Fields::Unit => Vec::new(),
271                    Fields::Unnamed(fields) => {
272                        let mut out = Vec::new();
273                        for field in &fields.unnamed {
274                            out.push(rex_type_expr(&field.ty, &param_map)?);
275                        }
276                        out
277                    }
278                    Fields::Named(fields) => {
279                        let mut field_inits = Vec::new();
280                        for field in &fields.named {
281                            let field_ident = field
282                                .ident
283                                .as_ref()
284                                .ok_or_else(|| Error::new(field.span(), "expected named field"))?;
285                            let mut field_name = field_ident.to_string();
286                            if let Some(rename) = serde_rename_from_attrs(&field.attrs)? {
287                                field_name = rename;
288                            }
289                            let field_ty = rex_type_expr(&field.ty, &param_map)?;
290                            field_inits.push(quote! {
291                                ( ::rexlang::intern(#field_name), #field_ty )
292                            });
293                        }
294                        let record = quote! {
295                            ::rexlang::Type::record(::std::vec![#(#field_inits,)*])
296                        };
297                        vec![record]
298                    }
299                };
300                variants.push(quote! {
301                    adt.add_variant(::rexlang::intern(#variant_name), ::std::vec![#(#args,)*]);
302                });
303            }
304            Ok(quote! {{
305                #adt_decl
306                #(#param_bindings)*
307                #(#variants)*
308                Ok(adt)
309            }})
310        }
311        Data::Union(_) => Err(Error::new(
312            ast.span(),
313            "`#[derive(Rex)]` only supports structs and enums",
314        )),
315    }
316}
317
318fn rex_type_expr(
319    ty: &Type,
320    adt_params: &HashMap<String, TokenStream2>,
321) -> Result<TokenStream2, Error> {
322    match ty {
323        Type::Tuple(tuple) => {
324            let elems = tuple
325                .elems
326                .iter()
327                .map(|t| rex_type_expr(t, adt_params))
328                .collect::<Result<Vec<_>, _>>()?;
329            Ok(quote! { ::rexlang::Type::tuple(::std::vec![#(#elems,)*]) })
330        }
331        Type::Path(type_path) => {
332            if type_path.qself.is_none() && type_path.path.segments.len() == 1 {
333                let seg = type_path
334                    .path
335                    .segments
336                    .last()
337                    .ok_or_else(|| Error::new(type_path.span(), "unsupported type path"))?;
338                let ident = seg.ident.to_string();
339                if let Some(param_ty) = adt_params.get(&ident) {
340                    return Ok(param_ty.clone());
341                }
342            }
343
344            let seg = type_path
345                .path
346                .segments
347                .last()
348                .ok_or_else(|| Error::new(type_path.span(), "unsupported type path"))?;
349            let ident = seg.ident.to_string();
350            let args = match &seg.arguments {
351                PathArguments::AngleBracketed(args) => args
352                    .args
353                    .iter()
354                    .filter_map(|a| match a {
355                        GenericArgument::Type(t) => Some(t),
356                        _ => None,
357                    })
358                    .collect::<Vec<_>>(),
359                _ => Vec::new(),
360            };
361
362            match ident.as_str() {
363                "Vec" => {
364                    let [inner] = args.as_slice() else {
365                        return Err(Error::new(seg.span(), "expected `Vec<T>`"));
366                    };
367                    let inner = rex_type_expr(inner, adt_params)?;
368                    Ok(quote! {
369                        ::rexlang::Type::app(
370                            ::rexlang::Type::builtin(::rexlang::BuiltinTypeId::List),
371                            #inner
372                        )
373                    })
374                }
375                "HashMap" | "BTreeMap" => {
376                    let [k, v] = args.as_slice() else {
377                        return Err(Error::new(seg.span(), "expected `HashMap<K, V>`"));
378                    };
379                    if !is_string_type(k) {
380                        return Err(Error::new(
381                            k.span(),
382                            "only `HashMap<String, V>` is supported for Rex dictionaries",
383                        ));
384                    }
385                    let v = rex_type_expr(v, adt_params)?;
386                    Ok(quote! {
387                        ::rexlang::Type::app(
388                            ::rexlang::Type::builtin(::rexlang::BuiltinTypeId::Dict),
389                            #v
390                        )
391                    })
392                }
393                "Option" => {
394                    let [inner] = args.as_slice() else {
395                        return Err(Error::new(seg.span(), "expected `Option<T>`"));
396                    };
397                    let inner = rex_type_expr(inner, adt_params)?;
398                    Ok(quote! {
399                        ::rexlang::Type::app(
400                            ::rexlang::Type::builtin(::rexlang::BuiltinTypeId::Option),
401                            #inner
402                        )
403                    })
404                }
405                "Result" => {
406                    let [ok, err] = args.as_slice() else {
407                        return Err(Error::new(seg.span(), "expected `Result<T, E>`"));
408                    };
409                    let ok = rex_type_expr(ok, adt_params)?;
410                    let err = rex_type_expr(err, adt_params)?;
411                    Ok(quote! {
412                        ::rexlang::Type::app(
413                            ::rexlang::Type::app(
414                                ::rexlang::Type::builtin(::rexlang::BuiltinTypeId::Result),
415                                #err
416                            ),
417                            #ok
418                        )
419                    })
420                }
421                _ => Ok(quote! { <#type_path as ::rexlang::RexType>::rex_type() }),
422            }
423        }
424        other => Err(Error::new(
425            other.span(),
426            "unsupported field type for Rex mapping",
427        )),
428    }
429}
430
431fn into_value_expr(expr: TokenStream2, ty: &Type) -> Result<TokenStream2, Error> {
432    match ty {
433        Type::Tuple(tuple) => {
434            let vars: Vec<Ident> = (0..tuple.elems.len())
435                .map(|i| format_ident!("__rex_t{i}", span = Span::call_site()))
436                .collect();
437            let encs = vars
438                .iter()
439                .zip(tuple.elems.iter())
440                .map(|(v, t)| into_value_expr(quote!(#v), t))
441                .collect::<Result<Vec<_>, _>>()?;
442            Ok(quote! {{
443                let (#(#vars,)*) = #expr;
444                heap.alloc_tuple(::std::vec![#(#encs,)*])?
445            }})
446        }
447        Type::Path(type_path) => {
448            let seg = type_path
449                .path
450                .segments
451                .last()
452                .ok_or_else(|| Error::new(type_path.span(), "unsupported type path"))?;
453            let ident = seg.ident.to_string();
454            let args = match &seg.arguments {
455                PathArguments::AngleBracketed(args) => args
456                    .args
457                    .iter()
458                    .filter_map(|a| match a {
459                        GenericArgument::Type(t) => Some(t),
460                        _ => None,
461                    })
462                    .collect::<Vec<_>>(),
463                _ => Vec::new(),
464            };
465
466            match ident.as_str() {
467                "Vec" => {
468                    let [inner] = args.as_slice() else {
469                        return Err(Error::new(seg.span(), "expected `Vec<T>`"));
470                    };
471                    let inner_encode = into_value_expr(quote!(item), inner)?;
472                    Ok(quote! {{
473                        let mut out =
474                            heap.alloc_adt(::rexlang::intern("Empty"), ::std::vec::Vec::new())?;
475                        for item in #expr.into_iter().rev() {
476                            out = heap
477                                .alloc_adt(
478                                    ::rexlang::intern("Cons"),
479                                    ::std::vec![#inner_encode, out],
480                                )?;
481                        }
482                        out
483                    }})
484                }
485                "HashMap" | "BTreeMap" => {
486                    let [k, v] = args.as_slice() else {
487                        return Err(Error::new(seg.span(), "expected `HashMap<K, V>`"));
488                    };
489                    if !is_string_type(k) {
490                        return Err(Error::new(
491                            k.span(),
492                            "only `HashMap<String, V>` is supported for Rex dictionaries",
493                        ));
494                    }
495                    let v_encode = into_value_expr(quote!(v), v)?;
496                    Ok(quote! {{
497                        let mut out = ::std::collections::BTreeMap::new();
498                        for (k, v) in #expr {
499                            out.insert(::rexlang::intern(&k), #v_encode);
500                        }
501                        heap.alloc_dict(out)?
502                    }})
503                }
504                "Option" => {
505                    let [inner] = args.as_slice() else {
506                        return Err(Error::new(seg.span(), "expected `Option<T>`"));
507                    };
508                    let inner_encode = into_value_expr(quote!(v), inner)?;
509                    Ok(quote! {{
510                        match #expr {
511                            Some(v) => heap
512                                .alloc_adt(::rexlang::intern("Some"), ::std::vec![#inner_encode])?,
513                            None => heap
514                                .alloc_adt(::rexlang::intern("None"), ::std::vec::Vec::new())?,
515                        }
516                    }})
517                }
518                "Result" => {
519                    let [ok_ty, err_ty] = args.as_slice() else {
520                        return Err(Error::new(seg.span(), "expected `Result<T, E>`"));
521                    };
522                    let ok_encode = into_value_expr(quote!(v), ok_ty)?;
523                    let err_encode = into_value_expr(quote!(e), err_ty)?;
524                    Ok(quote! {{
525                        match #expr {
526                            Ok(v) => heap
527                                .alloc_adt(::rexlang::intern("Ok"), ::std::vec![#ok_encode])?,
528                            Err(e) => heap
529                                .alloc_adt(::rexlang::intern("Err"), ::std::vec![#err_encode])?,
530                        }
531                    }})
532                }
533                _ => Ok(quote! { ::rexlang::IntoPointer::into_pointer(#expr, heap)? }),
534            }
535        }
536        other => Err(Error::new(
537            other.span(),
538            "unsupported field type for Rex encoding",
539        )),
540    }
541}
542
543fn from_value_expr(
544    value_expr: TokenStream2,
545    ty: &Type,
546    name_expr: TokenStream2,
547) -> Result<TokenStream2, Error> {
548    match ty {
549        Type::Tuple(tuple) => {
550            let elem_tys = tuple.elems.iter().collect::<Vec<_>>();
551            let indices: Vec<usize> = (0..elem_tys.len()).collect();
552            let decs = elem_tys
553                .iter()
554                .zip(indices.iter())
555                .map(|(t, i)| {
556                    from_value_expr(
557                        quote!(&heap.get(&items[#i])?.as_ref().clone()),
558                        t,
559                        name_expr.clone(),
560                    )
561                })
562                .collect::<Result<Vec<_>, _>>()?;
563            let len = elem_tys.len();
564            Ok(quote! {{
565                match #value_expr {
566                    ::rexlang::Value::Tuple(items) if items.len() == #len => {
567                        Ok((#(#decs?,)*))
568                    }
569                    other => Err(::rexlang::EngineError::NativeType { expected: "tuple".into(),
570                        got: ::rexlang::value_debug(heap, &other)
571                            .unwrap_or_else(|err| format!("<display error: {err}>")),
572                    }),
573                }
574            }})
575        }
576        Type::Path(type_path) => {
577            let seg = type_path
578                .path
579                .segments
580                .last()
581                .ok_or_else(|| Error::new(type_path.span(), "unsupported type path"))?;
582            let ident = seg.ident.to_string();
583            let args = match &seg.arguments {
584                PathArguments::AngleBracketed(args) => args
585                    .args
586                    .iter()
587                    .filter_map(|a| match a {
588                        GenericArgument::Type(t) => Some(t),
589                        _ => None,
590                    })
591                    .collect::<Vec<_>>(),
592                _ => Vec::new(),
593            };
594
595            match ident.as_str() {
596                "Vec" => {
597                    let [inner] = args.as_slice() else {
598                        return Err(Error::new(seg.span(), "expected `Vec<T>`"));
599                    };
600                    let inner_decode = from_value_expr(
601                        quote!(&heap.get(&args[0])?.as_ref().clone()),
602                        inner,
603                        name_expr.clone(),
604                    )?;
605                    Ok(quote! {{
606                        let mut out = ::std::vec::Vec::new();
607                        let mut cur = (#value_expr).clone();
608                        loop {
609                            match &cur {
610                                ::rexlang::Value::Adt(tag, args) if tag.as_ref() == "Empty" && args.is_empty() => {
611                                    break Ok(out);
612                                }
613                                ::rexlang::Value::Adt(tag, args) if tag.as_ref() == "Cons" && args.len() == 2 => {
614                                    let v = #inner_decode?;
615                                    out.push(v);
616                                    cur = heap.get(&args[1])?.as_ref().clone();
617                                }
618                                other => {
619                                    break Err(::rexlang::EngineError::NativeType { expected: "list".into(),
620                                        got: ::rexlang::value_debug(heap, &other)
621                                            .unwrap_or_else(|err| format!("<display error: {err}>")),
622                                    });
623                                }
624                            }
625                        }
626                    }})
627                }
628                "HashMap" | "BTreeMap" => {
629                    let [k, v] = args.as_slice() else {
630                        return Err(Error::new(seg.span(), "expected `HashMap<K, V>`"));
631                    };
632                    if !is_string_type(k) {
633                        return Err(Error::new(
634                            k.span(),
635                            "only `HashMap<String, V>` is supported for Rex dictionaries",
636                        ));
637                    }
638                    let v_decode = from_value_expr(
639                        quote!(&heap.get(&v)?.as_ref().clone()),
640                        v,
641                        name_expr.clone(),
642                    )?;
643                    Ok(quote! {{
644                        match #value_expr {
645                            ::rexlang::Value::Dict(map) => {
646                                let mut out = ::std::collections::HashMap::new();
647                                for (k, v) in map {
648                                    let decoded = #v_decode?;
649                                    out.insert(k.as_ref().to_string(), decoded);
650                                }
651                                Ok(out)
652                            }
653                            other => Err(::rexlang::EngineError::NativeType { expected: "dict".into(),
654                                got: ::rexlang::value_debug(heap, &other)
655                                    .unwrap_or_else(|err| format!("<display error: {err}>")),
656                            }),
657                        }
658                    }})
659                }
660                "Option" => {
661                    let [inner] = args.as_slice() else {
662                        return Err(Error::new(seg.span(), "expected `Option<T>`"));
663                    };
664                    let inner_decode = from_value_expr(
665                        quote!(&heap.get(&args[0])?.as_ref().clone()),
666                        inner,
667                        name_expr.clone(),
668                    )?;
669                    Ok(quote! {{
670                        match #value_expr {
671                            ::rexlang::Value::Adt(tag, args) if tag.as_ref() == "None" && args.is_empty() => Ok(None),
672                            ::rexlang::Value::Adt(tag, args) if tag.as_ref() == "Some" && args.len() == 1 => Ok(Some(#inner_decode?)),
673                            other => Err(::rexlang::EngineError::NativeType { expected: "option".into(),
674                                got: ::rexlang::value_debug(heap, &other)
675                                    .unwrap_or_else(|err| format!("<display error: {err}>")),
676                            }),
677                        }
678                    }})
679                }
680                "Result" => {
681                    let [ok_ty, err_ty] = args.as_slice() else {
682                        return Err(Error::new(seg.span(), "expected `Result<T, E>`"));
683                    };
684                    let ok_decode = from_value_expr(
685                        quote!(&heap.get(&args[0])?.as_ref().clone()),
686                        ok_ty,
687                        name_expr.clone(),
688                    )?;
689                    let err_decode = from_value_expr(
690                        quote!(&heap.get(&args[0])?.as_ref().clone()),
691                        err_ty,
692                        name_expr.clone(),
693                    )?;
694                    Ok(quote! {{
695                        match #value_expr {
696                            ::rexlang::Value::Adt(tag, args) if tag.as_ref() == "Ok" && args.len() == 1 => Ok(Ok(#ok_decode?)),
697                            ::rexlang::Value::Adt(tag, args) if tag.as_ref() == "Err" && args.len() == 1 => Ok(Err(#err_decode?)),
698                            other => Err(::rexlang::EngineError::NativeType { expected: "result".into(),
699                                got: ::rexlang::value_debug(heap, &other)
700                                    .unwrap_or_else(|err| format!("<display error: {err}>")),
701                            }),
702                        }
703                    }})
704                }
705                _ => Ok(quote! {{
706                    let __rex_value: ::rexlang::Value = (#value_expr).clone();
707                    let __rex_ptr = heap.alloc_value(__rex_value)?;
708                    <#type_path as ::rexlang::FromPointer>::from_pointer(heap, &__rex_ptr)
709                }}),
710            }
711        }
712        other => Err(Error::new(
713            other.span(),
714            "unsupported field type for Rex decoding",
715        )),
716    }
717}
718
719fn is_string_type(ty: &Type) -> bool {
720    match ty {
721        Type::Path(p) => p
722            .path
723            .segments
724            .last()
725            .map(|s| s.ident == "String")
726            .unwrap_or(false),
727        _ => false,
728    }
729}
730
731fn add_bound_to_type_params(generics: &mut Generics, bound: syn::TypeParamBound) {
732    for param in generics.type_params_mut() {
733        param.bounds.push(bound.clone());
734    }
735}
736
737fn into_value_impl(ast: &DeriveInput, type_name: &str) -> Result<TokenStream2, Error> {
738    let rust_ident = &ast.ident;
739    let ctor = type_name;
740
741    let body = match &ast.data {
742        Data::Struct(data) => match &data.fields {
743            Fields::Named(fields) => {
744                let mut inserts = Vec::new();
745                for field in &fields.named {
746                    let ident = field
747                        .ident
748                        .as_ref()
749                        .ok_or_else(|| Error::new(field.span(), "expected named field"))?;
750                    let mut name = ident.to_string();
751                    if let Some(rename) = serde_rename_from_attrs(&field.attrs)? {
752                        name = rename;
753                    }
754                    let enc = into_value_expr(quote!(self.#ident), &field.ty)?;
755                    inserts.push(quote! {
756                        map.insert(::rexlang::intern(#name), #enc);
757                    });
758                }
759                quote! {{
760                    let mut map = ::std::collections::BTreeMap::new();
761                    #(#inserts)*
762                    let dict = heap.alloc_dict(map)?;
763                    heap.alloc_adt(::rexlang::intern(#ctor), ::std::vec![dict])?
764                }}
765            }
766            Fields::Unnamed(fields) => {
767                let mut args = Vec::new();
768                let mut bindings = Vec::new();
769                for (idx, field) in fields.unnamed.iter().enumerate() {
770                    let v = format_ident!("__rex_f{idx}", span = Span::call_site());
771                    bindings.push(v.clone());
772                    args.push(into_value_expr(quote!(#v), &field.ty)?);
773                }
774                quote! {{
775                    let Self(#(#bindings,)*) = self;
776                    heap.alloc_adt(::rexlang::intern(#ctor), ::std::vec![#(#args,)*])?
777                }}
778            }
779            Fields::Unit => quote! {
780                heap.alloc_adt(::rexlang::intern(#ctor), ::std::vec::Vec::new())?
781            },
782        },
783        Data::Enum(data) => {
784            let mut arms = Vec::new();
785            for variant in &data.variants {
786                let variant_ident = &variant.ident;
787                let mut variant_name = variant_ident.to_string();
788                if let Some(rename) = serde_rename_from_attrs(&variant.attrs)? {
789                    variant_name = rename;
790                }
791                let arm = match &variant.fields {
792                    Fields::Unit => quote! {
793                        Self::#variant_ident => heap
794                            .alloc_adt(::rexlang::intern(#variant_name), ::std::vec::Vec::new())?
795                    },
796                    Fields::Unnamed(fields) => {
797                        let vars: Vec<Ident> = (0..fields.unnamed.len())
798                            .map(|i| format_ident!("__rex_v{i}", span = Span::call_site()))
799                            .collect();
800                        let encs = vars
801                            .iter()
802                            .zip(fields.unnamed.iter())
803                            .map(|(v, f)| into_value_expr(quote!(#v), &f.ty))
804                            .collect::<Result<Vec<_>, _>>()?;
805                        quote! {
806                            Self::#variant_ident(#(#vars,)*) => heap
807                                .alloc_adt(::rexlang::intern(#variant_name), ::std::vec![#(#encs,)*])?
808                        }
809                    }
810                    Fields::Named(fields) => {
811                        let mut vars = Vec::new();
812                        let mut inserts = Vec::new();
813                        for field in &fields.named {
814                            let ident = field
815                                .ident
816                                .as_ref()
817                                .ok_or_else(|| Error::new(field.span(), "expected named field"))?;
818                            vars.push(ident.clone());
819                            let mut name = ident.to_string();
820                            if let Some(rename) = serde_rename_from_attrs(&field.attrs)? {
821                                name = rename;
822                            }
823                            let enc = into_value_expr(quote!(#ident), &field.ty)?;
824                            inserts.push(quote! {
825                                map.insert(::rexlang::intern(#name), #enc);
826                            });
827                        }
828                        quote! {
829                            Self::#variant_ident { #(#vars,)* } => {
830                                let mut map = ::std::collections::BTreeMap::new();
831                                #(#inserts)*
832                                let dict = heap.alloc_dict(map)?;
833                                heap.alloc_adt(::rexlang::intern(#variant_name), ::std::vec![dict])?
834                            }
835                        }
836                    }
837                };
838                arms.push(arm);
839            }
840            quote! {{
841                match self {
842                    #(#arms,)*
843                }
844            }}
845        }
846        Data::Union(_) => {
847            return Err(Error::new(
848                ast.span(),
849                "`#[derive(Rex)]` only supports structs and enums",
850            ));
851        }
852    };
853
854    let mut generics = ast.generics.clone();
855    add_bound_to_type_params(&mut generics, parse_quote!(::rexlang::IntoPointer));
856    let (impl_generics, _, where_clause) = generics.split_for_impl();
857    let (_, ty_generics, _) = generics.split_for_impl();
858
859    Ok(quote! {
860        impl #impl_generics ::rexlang::IntoPointer for #rust_ident #ty_generics #where_clause {
861            fn into_pointer(
862                self,
863                heap: &::rexlang::Heap,
864            ) -> ::std::result::Result<::rexlang::Pointer, ::rexlang::EngineError> {
865                Ok(#body)
866            }
867        }
868    })
869}
870
871fn from_value_impl(ast: &DeriveInput, type_name: &str) -> Result<TokenStream2, Error> {
872    let rust_ident = &ast.ident;
873    let name_expr = quote!(name);
874
875    let body = match &ast.data {
876        Data::Struct(data) => match &data.fields {
877            Fields::Named(fields) => {
878                let mut field_decodes = Vec::new();
879                let mut field_idents = Vec::new();
880                for field in &fields.named {
881                    let ident = field
882                        .ident
883                        .as_ref()
884                        .ok_or_else(|| Error::new(field.span(), "expected named field"))?;
885                    field_idents.push(ident.clone());
886                    let mut name = ident.to_string();
887                    if let Some(rename) = serde_rename_from_attrs(&field.attrs)? {
888                        name = rename;
889                    }
890                    let key = quote!(::rexlang::intern(#name));
891                    let decode = from_value_expr(
892                        quote!(&heap.get(&v)?.as_ref().clone()),
893                        &field.ty,
894                        name_expr.clone(),
895                    )?;
896                    field_decodes.push(quote! {
897                        let v = map.get(&#key).ok_or_else(|| ::rexlang::EngineError::NativeType { expected: format!("missing field `{}`", #name),
898                            got: "dict".into(),
899                        })?;
900                        let #ident = #decode?;
901                    });
902                }
903                Ok(quote! {{
904                    match value {
905                        ::rexlang::Value::Adt(tag, args)
906                            if (tag.as_ref() == #type_name
907                                || tag.as_ref().rsplit('.').next() == Some(#type_name))
908                                && args.len() == 1 =>
909                        {
910                            match heap.get(&args[0])?.as_ref().clone() {
911                                ::rexlang::Value::Dict(map) => {
912                                    #(#field_decodes)*
913                                    Ok(Self { #(#field_idents,)* })
914                                }
915                                other => Err(::rexlang::EngineError::NativeType { expected: "dict".into(),
916                                    got: ::rexlang::value_debug(heap, &other)
917                                        .unwrap_or_else(|err| format!("<display error: {err}>")),
918                                }),
919                            }
920                        }
921                        other => Err(::rexlang::EngineError::NativeType { expected: #type_name.into(),
922                            got: ::rexlang::value_debug(heap, &other)
923                                .unwrap_or_else(|err| format!("<display error: {err}>")),
924                        }),
925                    }
926                }})
927            }
928            Fields::Unnamed(fields) => {
929                let mut decs = Vec::new();
930                for (idx, field) in fields.unnamed.iter().enumerate() {
931                    let decode = from_value_expr(
932                        quote!(&heap.get(&args[#idx])?.as_ref().clone()),
933                        &field.ty,
934                        name_expr.clone(),
935                    )?;
936                    decs.push(quote!(#decode?));
937                }
938                let len = fields.unnamed.len();
939                Ok(quote! {{
940                    match value {
941                        ::rexlang::Value::Adt(tag, args)
942                            if (tag.as_ref() == #type_name
943                                || tag.as_ref().rsplit('.').next() == Some(#type_name))
944                                && args.len() == #len =>
945                        {
946                            Ok(Self(#(#decs,)*))
947                        }
948                        other => Err(::rexlang::EngineError::NativeType { expected: #type_name.into(),
949                            got: ::rexlang::value_debug(heap, &other)
950                                .unwrap_or_else(|err| format!("<display error: {err}>")),
951                        }),
952                    }
953                }})
954            }
955            Fields::Unit => Ok(quote! {{
956                match value {
957                    ::rexlang::Value::Adt(tag, args)
958                        if (tag.as_ref() == #type_name
959                            || tag.as_ref().rsplit('.').next() == Some(#type_name))
960                            && args.is_empty() =>
961                    {
962                        Ok(Self)
963                    }
964                    other => Err(::rexlang::EngineError::NativeType { expected: #type_name.into(),
965                        got: ::rexlang::value_debug(heap, &other)
966                            .unwrap_or_else(|err| format!("<display error: {err}>")),
967                    }),
968                }
969            }}),
970        },
971        Data::Enum(data) => {
972            let mut arms = Vec::new();
973            for variant in &data.variants {
974                let variant_ident = &variant.ident;
975                let mut variant_name = variant_ident.to_string();
976                if let Some(rename) = serde_rename_from_attrs(&variant.attrs)? {
977                    variant_name = rename;
978                }
979                let arm = match &variant.fields {
980                    Fields::Unit => quote! {
981                        ::rexlang::Value::Adt(tag, args)
982                            if (tag.as_ref() == #variant_name
983                                || tag.as_ref().rsplit('.').next() == Some(#variant_name))
984                                && args.is_empty() =>
985                        {
986                            Ok(Self::#variant_ident)
987                        }
988                    },
989                    Fields::Unnamed(fields) => {
990                        let len = fields.unnamed.len();
991                        let vals = fields
992                            .unnamed
993                            .iter()
994                            .enumerate()
995                            .map(|(i, f)| {
996                                from_value_expr(
997                                    quote!(&heap.get(&args[#i])?.as_ref().clone()),
998                                    &f.ty,
999                                    name_expr.clone(),
1000                                )
1001                            })
1002                            .collect::<Result<Vec<_>, _>>()?
1003                            .into_iter()
1004                            .map(|d| quote!(#d?))
1005                            .collect::<Vec<_>>();
1006                        quote! {
1007                            ::rexlang::Value::Adt(tag, args)
1008                                if (tag.as_ref() == #variant_name
1009                                    || tag.as_ref().rsplit('.').next() == Some(#variant_name))
1010                                    && args.len() == #len =>
1011                            {
1012                                Ok(Self::#variant_ident(#(#vals,)*))
1013                            }
1014                        }
1015                    }
1016                    Fields::Named(fields) => {
1017                        let mut field_decodes = Vec::new();
1018                        let mut fields_init = Vec::new();
1019                        for field in &fields.named {
1020                            let ident = field
1021                                .ident
1022                                .as_ref()
1023                                .ok_or_else(|| Error::new(field.span(), "expected named field"))?;
1024                            fields_init.push(ident.clone());
1025                            let mut name = ident.to_string();
1026                            if let Some(rename) = serde_rename_from_attrs(&field.attrs)? {
1027                                name = rename;
1028                            }
1029                            let key = quote!(::rexlang::intern(#name));
1030                            let decode = from_value_expr(
1031                                quote!(&heap.get(&v)?.as_ref().clone()),
1032                                &field.ty,
1033                                name_expr.clone(),
1034                            )?;
1035                            field_decodes.push(quote! {
1036                                let v = map.get(&#key).ok_or_else(|| ::rexlang::EngineError::NativeType { expected: format!("missing field `{}`", #name),
1037                                    got: "dict".into(),
1038                                })?;
1039                                let #ident = #decode?;
1040                            });
1041                        }
1042                        quote! {
1043                            ::rexlang::Value::Adt(tag, args)
1044                                if (tag.as_ref() == #variant_name
1045                                    || tag.as_ref().rsplit('.').next() == Some(#variant_name))
1046                                    && args.len() == 1 =>
1047                            {
1048                                match heap.get(&args[0])?.as_ref().clone() {
1049                                    ::rexlang::Value::Dict(map) => {
1050                                        #(#field_decodes)*
1051                                        Ok(Self::#variant_ident { #(#fields_init,)* })
1052                                    }
1053                                    other => Err(::rexlang::EngineError::NativeType { expected: "dict".into(),
1054                                        got: ::rexlang::value_debug(heap, &other)
1055                                            .unwrap_or_else(|err| format!("<display error: {err}>")),
1056                                    }),
1057                                }
1058                            }
1059                        }
1060                    }
1061                };
1062                arms.push(arm);
1063            }
1064
1065            Ok(quote! {{
1066                match value {
1067                    #(#arms,)*
1068                    other => Err(::rexlang::EngineError::NativeType { expected: #type_name.into(),
1069                        got: ::rexlang::value_debug(heap, &other)
1070                            .unwrap_or_else(|err| format!("<display error: {err}>")),
1071                    }),
1072                }
1073            }})
1074        }
1075        Data::Union(_) => Err(Error::new(
1076            ast.span(),
1077            "`#[derive(Rex)]` only supports structs and enums",
1078        )),
1079    }?;
1080
1081    let mut generics = ast.generics.clone();
1082    add_bound_to_type_params(&mut generics, parse_quote!(::rexlang::FromPointer));
1083    let (impl_generics, _, where_clause) = generics.split_for_impl();
1084    let (_, ty_generics, _) = generics.split_for_impl();
1085
1086    Ok(quote! {
1087        impl #impl_generics ::rexlang::FromPointer for #rust_ident #ty_generics #where_clause {
1088            fn from_pointer(
1089                heap: &::rexlang::Heap,
1090                pointer: &::rexlang::Pointer,
1091            ) -> Result<Self, ::rexlang::EngineError> {
1092                let value = heap.get(&pointer)?.as_ref().clone();
1093                #body
1094            }
1095        }
1096    })
1097}