use convert_case::{Case, Casing};
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{punctuated::Punctuated, token::Comma, GenericArgument, Lifetime, Token, Type};
use crate::{
command_enum_parsing::{Field, FunctionDeclaration, NamespacePath},
generate::{get_input_type_of_bare_fn_field, get_return_type_of_bare_fn_field},
DataCommandEnum,
};
/// Generates the Rust wrapper functions for every field of `input`.
///
/// This is a thin entry point around [`generate_rust_wrapper_functions_rec`], starting at
/// `namespace`. When `namespace` is `None`, a default namespace is used.
pub fn generate_rust_wrapper_functions(
    namespace: Option<&NamespacePath>,
    input: &DataCommandEnum,
) -> TokenStream2 {
    let top_level_fields = &input.fields;
    generate_rust_wrapper_functions_rec(namespace, top_level_fields)
}
pub fn generate_rust_wrapper_functions_rec(
namespace: Option<&NamespacePath>,
input: &Punctuated<Field, Token![,]>,
) -> TokenStream2 {
let wrapped_functions: TokenStream2 = input
.iter()
.map(|field| match field {
Field::Function(fun_field) => {
wrap_lua_function(namespace.unwrap_or(&Default::default()), fun_field)
}
Field::Namespace(nasp) => {
let mut passed_namespace = namespace.unwrap_or(&Default::default()).clone();
nasp.path
.clone()
.into_iter()
.for_each(|val| passed_namespace.push(val));
generate_rust_wrapper_functions_rec(Some(&passed_namespace), &nasp.fields)
}
})
.collect();
quote! {
#wrapped_functions
}
}
/// Generates an `async fn` named after the declared function. The generated function
/// forwards a Lua call as a command event; its body comes from `get_function_body`.
///
/// The generated signature is `(lua: &mlua::Lua, input: <input type>)` and returns
/// `Result<mlua::Value, mlua::Error>`. The input type is the declared input type, or
/// `()` when the declaration has none. Lifetime arguments found on the input and return
/// types are declared as generic parameters of the generated function.
fn wrap_lua_function(namespace: &NamespacePath, field: &FunctionDeclaration) -> TokenStream2 {
    let input_type = get_input_type_of_bare_fn_field(field);
    let return_type = get_return_type_of_bare_fn_field(field);
    let function_name = &field.name;
    let function_body = get_function_body(namespace, field, input_type.is_some(), &return_type);
    let lifetime_args =
        get_and_add_lifetimes_form_inputs_and_outputs(input_type.clone(), return_type);
    // Functions without an input still take an `input` argument, typed as `()`.
    // The fallback is built lazily with `parse_quote!`. Unlike converting to a
    // `proc_macro::TokenStream`, this also works outside a proc-macro invocation.
    let input_type = input_type.unwrap_or_else(|| syn::parse_quote! { () });
    quote! {
        async fn #function_name <#lifetime_args>(
            lua: &mlua::Lua,
            input: #input_type
        ) -> Result<mlua::Value, mlua::Error> {
            #function_body
        }
    }
}
/// Collects the lifetime parameters the generated wrapper function must declare.
///
/// Lifetimes are taken from the generic arguments of the last path segment of the input
/// type, then of the return type. For example, `Foo<'a>` contributes `'a`.
///
/// Each lifetime name is declared only once, and `'static` and `'_` are skipped. A
/// generic parameter list may contain neither duplicate names (E0403) nor these reserved
/// lifetimes.
///
/// # Panics
/// Panics (via `todo!`) on parenthesized path arguments and on type kinds other than
/// paths and tuples, which are not supported yet.
fn get_and_add_lifetimes_form_inputs_and_outputs(
    input_type: Option<syn::Type>,
    return_type: Option<syn::Type>,
) -> Punctuated<Lifetime, Comma> {
    /// Returns the lifetime arguments on the final segment of a path type, or `None`
    /// when there are no generic arguments. Tuple types yield `None`.
    fn get_lifetime_args_from_type(ty: syn::Type) -> Option<Vec<Lifetime>> {
        match ty {
            syn::Type::Path(path) => {
                let args_to_final_path_segment = &path
                    .path
                    .segments
                    .last()
                    .expect("The path should have a last segment")
                    .arguments;
                match args_to_final_path_segment {
                    syn::PathArguments::None => None,
                    syn::PathArguments::AngleBracketed(angle) => Some(
                        angle
                            .args
                            .iter()
                            .filter_map(|arg| match arg {
                                GenericArgument::Lifetime(lifetime) => Some(lifetime.clone()),
                                _ => None,
                            })
                            .collect(),
                    ),
                    syn::PathArguments::Parenthesized(_) => todo!("Parenthesized Life time"),
                }
            }
            // Lifetimes nested inside tuple elements are not collected (yet).
            syn::Type::Tuple(_) => None,
            non_path => todo!("Non path lifetime: {:#?}", non_path),
        }
    }

    let mut output: Punctuated<Lifetime, Comma> = Punctuated::new();
    let collected = input_type
        .into_iter()
        .chain(return_type)
        .flat_map(|ty| get_lifetime_args_from_type(ty).unwrap_or_default());
    for lifetime in collected {
        // `'static` and `'_` cannot be declared as generic lifetime parameters.
        let reserved = lifetime.ident == "static" || lifetime.ident == "_";
        // Input and output frequently share a lifetime; declaring it twice is an error.
        let already_declared = output
            .iter()
            .any(|existing| existing.ident == lifetime.ident);
        if !reserved && !already_declared {
            output.push(lifetime);
        }
    }
    output
}
fn get_function_body(
namespace: &NamespacePath,
field: &FunctionDeclaration,
has_input: bool,
output_type: &Option<Type>,
) -> TokenStream2 {
let command_name = field
.name
.to_string()
.from_case(Case::Snake)
.to_case(Case::Pascal);
let command_ident = {
if has_input {
format!("{}(", command_name)
} else {
command_name.clone()
}
};
let command_namespace: String = {
namespace
.iter()
.map(|path| {
let path_enum_name: String = path
.to_string()
.from_case(Case::Snake)
.to_case(Case::Pascal);
path_enum_name.clone() + "(" + &path_enum_name + "::"
})
.collect::<Vec<String>>()
.join("")
};
let send_output: TokenStream2 = {
let finishing_brackets = {
if has_input {
let mut output = "input.clone()".to_owned();
output.push_str(&(0..namespace.len()).map(|_| ')').collect::<String>());
output
} else {
(0..namespace.len()).map(|_| ')').collect::<String>()
}
};
("Event::CommandEvent( Command::".to_owned()
+ &command_namespace
+ &command_ident
+ &finishing_brackets
+ {if has_input {")"} else {""}}
+ ",Some(callback_tx))")
.parse()
.expect("This code should be valid")
};
let function_return = if let Some(_) = output_type {
quote! {
return Ok(output.into_lua(lua).expect("This conversion should always work"));
}
} else {
quote! {
return Ok(mlua::Value::Nil);
}
};
let does_function_expect_output = if output_type.is_some() {
quote! {
return Err(mlua::Error::ExternalError(std::sync::Arc::new(
err
)));
}
} else {
quote! {
return Ok(mlua::Value::Nil);
}
};
quote! {
let (callback_tx, callback_rx) = tokio::sync::oneshot::channel::<CommandTransferValue>();
let tx: mlua::AppDataRef<tokio::sync::mpsc::Sender<Event>> =
lua.app_data_ref().expect("This should exist, it was set before");
(*tx)
.send(#send_output)
.await
.expect("This should work, as the receiver is not dropped");
cli_log::info!("Sent CommandEvent: `{}`", #command_name);
match callback_rx.await {
Ok(output) => {
cli_log::info!(
"Lua function: `{}` returned output to lua: `{}`", #command_name, &output
);
#function_return
},
Err(err) => {
#does_function_expect_output
}
};
}
}