use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{DeriveInput, parse_macro_input};
#[proc_macro_derive(SwampExport, attributes(swamp))]
pub fn derive_swamp_export(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let name = &input.ident;
let fields = match input.data {
syn::Data::Struct(ref data) => &data.fields,
_ => panic!("SwampExport can only be derived for structs"),
};
let from_field_extractions = fields.iter().enumerate().map(|(index, f)| {
let field_name = &f.ident.as_ref().unwrap();
let field_type = &f.ty;
quote! {
let #field_name = <#field_type>::from_swamp_value(&values[#index])?;
}
});
let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
let field_types: Vec<_> = fields.iter().map(|f| &f.ty).collect();
let expanded = quote! {
impl SwampExport for #name {
fn get_resolved_type(registry: &TypeRegistry) -> ResolvedType {
let fields = vec![
#((stringify!(#field_names), <#field_types>::get_resolved_type(registry))),*
];
registry.register_derived_struct(stringify!(#name), fields)
}
fn to_swamp_value(&self, registry: &TypeRegistry) -> Value {
let mut values = Vec::new();
#(values.push(self.#field_names.to_swamp_value(registry));)*
let resolved_type = Self::get_resolved_type(registry);
match &resolved_type {
ResolvedType::Struct(struct_type) => {
Value::Struct(struct_type.clone(), values, resolved_type)
},
_ => unreachable!("get_resolved_type returned non-struct type")
}
}
fn from_swamp_value(value: &Value) -> Result<Self, String> {
match value {
Value::Struct(struct_type_ref, values, _) => {
if struct_type_ref.borrow().name.text != stringify!(#name) {
return Err(format!(
"Expected {} struct, got {}",
stringify!(#name),
struct_type_ref.borrow().name.text
));
}
#(#from_field_extractions)*
Ok(Self {
#(#field_names),*
})
}
_ => Err(format!("Expected {} struct", stringify!(#name)))
}
}
}
};
TokenStream::from(expanded)
}
#[proc_macro_attribute]
pub fn swamp_fn(_attr: TokenStream, item: TokenStream) -> TokenStream {
let input_fn = parse_macro_input!(item as syn::ItemFn);
let fn_name = &input_fn.sig.ident;
let module_name = format_ident!("swamp_{}", fn_name.to_string().to_lowercase());
let context_type = match &input_fn.sig.inputs[0] {
syn::FnArg::Typed(pat_type) => &*pat_type.ty,
_ => panic!("First parameter must be the context type"),
};
let context_inner_type = match context_type {
syn::Type::Reference(type_ref) => &*type_ref.elem,
_ => panic!("Context parameter must be a mutable reference"),
};
let return_type = match &input_fn.sig.output {
syn::ReturnType::Default => quote!(<()>::get_resolved_type(registry)),
syn::ReturnType::Type(_, ty) => quote!(<#ty>::get_resolved_type(registry)),
};
let args = input_fn
.sig
.inputs
.iter()
.skip(1)
.map(|arg| {
if let syn::FnArg::Typed(pat_type) = arg {
let pat = &pat_type.pat;
let ty = &pat_type.ty;
(pat, ty)
} else {
panic!("self parameters not supported yet")
}
})
.collect::<Vec<_>>();
let arg_count = args.len();
let arg_indices = 0..arg_count;
let (patterns, types): (Vec<_>, Vec<_>) = args.iter().copied().unzip();
let expanded = quote! {
#input_fn
mod #module_name {
use super::*;
use swamp_script_core_extra::prelude::*;
pub struct Function {
pub name: &'static str,
pub function_id: ExternalFunctionId,
}
impl Function {
pub fn new(function_id: ExternalFunctionId) -> Self {
Self {
name: stringify!(#fn_name),
function_id,
}
}
pub fn handler<'a>(
&'a self,
registry: &'a TypeRegistry,
) -> Box<dyn FnMut(&[Value], &mut #context_inner_type) -> Result<Value, ValueError> + 'a> {
Box::new(move |args: &[Value], ctx: &mut #context_inner_type| {
if args.len() != #arg_count {
return Err(ValueError::WrongNumberOfArguments {
expected: #arg_count,
got: args.len(),
});
}
#(
let #patterns = <#types>::from_swamp_value(&args[#arg_indices])
.map_err(|e| ValueError::TypeError(e))?;
)*
let result = super::#fn_name(ctx, #(#patterns),*);
Ok(result.to_swamp_value(registry))
})
}
pub fn get_definition(&self, registry: &TypeRegistry) -> ResolvedExternalFunctionDefinition {
ResolvedExternalFunctionDefinition {
name: LocalIdentifier::from_str(self.name),
signature: ResolvedFunctionSignature {
parameters: vec![
#(ResolvedParameter {
name: stringify!(#patterns).to_string(),
resolved_type: <#types>::get_resolved_type(registry),
ast_parameter: Parameter::default(),
is_mutable: false,
},)*
],
return_type: #return_type,
},
id: self.function_id,
}
}
}
}
};
TokenStream::from(expanded)
}
#[proc_macro_derive(SwampExportEnum, attributes(swamp))]
pub fn derive_swamp_export_enum(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let name = &input.ident;
let expanded = match input.data {
syn::Data::Enum(ref data) => {
let variant_matches = data.variants.iter().enumerate().map(|(variant_index, variant)| {
let variant_name = &variant.ident;
match &variant.fields {
syn::Fields::Unit => {
quote! {
#name::#variant_name => {
let variant_type = ResolvedEnumVariantType {
owner: enum_type.clone(),
data: ResolvedEnumVariantContainerType::Nothing,
name: LocalTypeIdentifier::from_str(stringify!(#variant_name)),
number: #variant_index as TypeNumber,
};
Value::EnumVariantSimple(Rc::new(variant_type))
}
}
}
syn::Fields::Named(fields) => {
let field_names: Vec<_> = fields.named.iter().map(|f| &f.ident).collect();
let field_types: Vec<_> = fields.named.iter().map(|f| &f.ty).collect();
let field_type_conversions = field_types.iter().map(|ty| {
match quote!(#ty).to_string().as_str() {
"f32" => quote! { registry.get_float_type() },
"i32" => quote! { registry.get_int_type() },
"bool" => quote! { registry.get_bool_type() },
"String" => quote! { registry.get_string_type() },
ty => quote! { panic!("Unsupported type: {}", #ty) },
}
});
let field_value_conversions = field_names.iter().zip(field_types.iter()).map(|(name, ty)| {
match quote!(#ty).to_string().as_str() {
"f32" => quote! { Value::Float(Fp::from(*#name)) },
"i32" => quote! { Value::Int(*#name) },
"bool" => quote! { Value::Bool(*#name) },
"String" => quote! { Value::String(#name.clone()) },
ty => quote! { panic!("Unsupported type: {}", #ty) },
}
});
quote! {
#name::#variant_name { #(ref #field_names),* } => {
let mut fields = SeqMap::new();
#(
fields.insert(
IdentifierName(stringify!(#field_names).to_string()),
#field_type_conversions
);
)*
let common = CommonEnumVariantType {
number: #variant_index as TypeNumber,
module_path: ModulePath::new(),
variant_name: LocalTypeIdentifier::from_str(stringify!(#variant_name)),
enum_ref: enum_type.clone(),
};
let variant_struct = Rc::new(ResolvedEnumVariantStructType {
common,
fields,
ast_struct: AnonymousStruct::default(),
});
let values = vec![
#(#field_value_conversions),*
];
Value::EnumVariantStruct(variant_struct, values)
}
}
}
syn::Fields::Unnamed(fields) => {
let field_types: Vec<_> = fields.unnamed.iter().map(|f| &f.ty).collect();
let field_names: Vec<_> = (0..field_types.len())
.map(|i| format_ident!("field_{}", i))
.collect::<Vec<_>>();
let field_type_conversions = field_types.iter().map(|ty| {
match quote!(#ty).to_string().as_str() {
"f32" => quote! { registry.get_float_type() },
"i32" => quote! { registry.get_int_type() },
"bool" => quote! { registry.get_bool_type() },
"String" => quote! { registry.get_string_type() },
ty => quote! { panic!("Unsupported type: {}", #ty) },
}
});
let field_value_conversions = field_names.iter().zip(field_types.iter()).map(|(name, ty)| {
match quote!(#ty).to_string().as_str() {
"f32" => quote! { Value::Float(Fp::from(*#name)) },
"i32" => quote! { Value::Int(*#name) },
"bool" => quote! { Value::Bool(*#name) },
"String" => quote! { Value::String(#name.clone()) },
ty => quote! { panic!("Unsupported type: {}", #ty) },
}
});
quote! {
#name::#variant_name(#(ref #field_names),*) => {
let fields_in_order = vec![
#(#field_type_conversions),*
];
let common = CommonEnumVariantType {
number: #variant_index as TypeNumber,
module_path: ModulePath::new(),
variant_name: LocalTypeIdentifier::from_str(stringify!(#variant_name)),
enum_ref: enum_type.clone(),
};
let variant_tuple = Rc::new(ResolvedEnumVariantTupleType {
common,
fields_in_order,
});
let values = vec![
#(#field_value_conversions),*
];
Value::EnumVariantTuple(variant_tuple, values)
}
}
}
}
});
quote! {
impl SwampExport for #name {
fn get_resolved_type(registry: &TypeRegistry) -> ResolvedType {
let enum_type = Rc::new(ResolvedEnumType {
name: LocalTypeIdentifier::from_str(stringify!(#name)),
number: registry.allocate_type_number(),
module_path: ModulePath(vec![]),
});
ResolvedType::Enum(enum_type)
}
fn to_swamp_value(&self, registry: &TypeRegistry) -> Value {
let enum_type = match Self::get_resolved_type(registry) {
ResolvedType::Enum(t) => t,
_ => unreachable!(),
};
match self {
#(#variant_matches),*
}
}
fn from_swamp_value(value: &Value) -> Result<Self, String> {
match value {
Value::EnumVariantSimple(_) |
Value::EnumVariantTuple(_, _) |
Value::EnumVariantStruct(_, _) => {
todo!("Implement from_swamp_value for enums") }
_ => Err(format!("Expected enum variant, got {:?}", value))
}
}
}
}
}
_ => panic!("SwampExportEnum can only be derived for enums"),
};
TokenStream::from(expanded)
}