use std::{io::Read, panic, str::FromStr, vec};
use mlua::{Function, Lua, MultiValue, Table, Value};
use proc_macro2::{Span, TokenStream, TokenTree};
use quote::{quote, quote_spanned, ToTokens};
use syn::{spanned::Spanned, Type, TypeArray, TypeReference};
/// Builds the diagnostic text used whenever a Lua value does not fit the
/// requested Rust type.
///
/// Takes, in order: the expected Rust type, the offending value and the name
/// of the Lua type that was found. All three only need to implement
/// `Display`. Expands to a `&String`, so it can be passed directly wherever a
/// `&str` message is expected.
macro_rules! type_mismatch {
    ($want:expr, $val:expr, $got:expr) => {
        &format!(
            "type mismatch:expected {} but value {} is {}",
            $want, $val, $got
        )
    };
}
#[proc_macro]
pub fn constlua(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let project_dir =
std::path::PathBuf::from_str(&std::env::var("CARGO_MANIFEST_DIR").unwrap()).unwrap();
let input = proc_macro2::TokenStream::from(input);
let mut input = input.into_iter();
let name: TokenTree = match input.next() {
Some(TokenTree::Ident(ident)) => ident.into(),
Some(other) => return abort(other.span(), "expected const value name"),
None => panic!("expected const value name"),
};
match input.next() {
Some(TokenTree::Punct(punct)) if punct.as_char() == ':' => (),
Some(other) => return abort(other.span(), "expected :"),
None => panic!("expected :"),
};
let mut ty_tokens = vec![];
for token in &mut input {
match token {
TokenTree::Punct(punct) if punct.as_char() == '=' => {
break;
}
other => {
ty_tokens.push(other);
}
}
}
let ty_tokens = TokenStream::from_iter(ty_tokens.into_iter());
let ty_span = ty_tokens.span();
let mut ty: syn::Type = match syn::parse2(ty_tokens) {
Ok(ty) => ty,
Err(_) => return abort(ty_span, "Syntax Error: expected type"),
};
let mut luacode = vec![];
luacode.push(quote!(function()).into_iter().collect());
let mut line = vec![];
let span0 = get_range(&input);
let mut spans = vec![span0];
for token in input {
match &token {
TokenTree::Punct(punct) if punct.as_char() == ';' => {
spans.push(get_range(&line.clone().into_iter()));
luacode.push(line);
line = vec![];
continue;
}
_ => (),
}
line.push(token);
}
spans.push(get_range(&line.clone().into_iter()));
while line.is_empty() {
line = luacode.pop().expect("no return value");
}
match &line[0] {
TokenTree::Ident(ident) if ident == "return" => (),
_ => line.insert(0, quote!(return).into_iter().next().unwrap()),
}
luacode.push(line);
luacode.push(quote!(end).into_iter().collect());
let lua = Lua::new();
let lua_path = project_dir.join("lua");
if lua_path.is_dir() {
let dir = std::fs::read_dir(lua_path).unwrap();
for entry in dir.flatten() {
let path = entry.path();
if let (Some(name), true, Some("lua")) = (
path.file_stem(),
path.is_file(),
path.extension().and_then(|n| n.to_str()),
) {
let mut mod_code = "function()".to_string();
std::fs::File::open(&path)
.unwrap()
.read_to_string(&mut mod_code)
.unwrap();
mod_code.push_str("\nend");
let m: Function = match lua.load(&mod_code).eval() {
Ok(m) => m,
Err(e) => return abort2_start_end(span0.0, span0.1, &e.to_string()).into(),
};
if let Err(e) =
lua.load_from_function::<_, Value>(name.to_string_lossy().as_ref(), m)
{
return abort2_start_end(span0.0, span0.1, &e.to_string()).into();
};
}
}
}
let luacode: Vec<_> = luacode
.into_iter()
.map(|tokens| TokenStream::from_iter(tokens.into_iter()).to_string())
.collect();
let luacode = luacode.join("\n");
let chuck = lua.load(&luacode).set_name("constlua.lua").unwrap();
let f = match chuck.eval::<Function>() {
Ok(f) => f,
Err(e) => return print_err_with_span(e, &spans),
};
let value: MultiValue = match f.call::<(), MultiValue>(()) {
Ok(value) => value,
Err(e) => return print_err_with_span(e, &spans),
};
let value = match value.into_iter().next() {
Some(value) => value,
None => Value::Nil,
};
let value = match print_value(value, &mut ty) {
Ok(val) => val,
Err(val) => return val.into(),
};
let output = quote! {
const #name : #ty = #value;
};
output.into()
}
fn print_multi_value(table: Table, ty: &mut Type) -> Result<TokenStream, TokenStream> {
let table_info = format!("{:?}", table);
let val: Vec<_> = table.sequence_values().filter_map(|val| val.ok()).collect();
let sub_ty = match ty {
Type::Array(arr) => arr.elem.as_mut(),
Type::Slice(slice) => {
let lit_int = syn::Lit::Int(syn::LitInt::new(
val.len().to_string().as_str(),
slice.bracket_token.span,
));
let len_lit = syn::Expr::Lit(syn::ExprLit {
attrs: vec![],
lit: lit_int,
});
let bracket_token = slice.bracket_token;
let semi_token_span = bracket_token.span;
let elem = slice.elem.clone();
*ty = Type::Array(TypeArray {
bracket_token,
semi_token: syn::Token!(;)(semi_token_span),
elem,
len: len_lit,
});
match ty {
Type::Array(a) => &mut a.elem,
_ => unreachable!(),
}
}
Type::Tuple(tuple) => {
let val: Result<Vec<TokenStream>, TokenStream> = tuple
.elems
.iter_mut()
.zip(val.into_iter())
.map(|(ty, val)| print_value(val, ty))
.collect();
let val = val?;
return Ok(quote!(
( #( #val ),* )
));
}
_ => {
return Err(abort2(
ty.span(),
type_mismatch!(ty_name(ty), table_info, "table"),
))
}
};
let val: Result<Vec<TokenStream>, TokenStream> = val
.into_iter()
.map(|val| print_value(val, sub_ty))
.collect();
let val = val?;
Ok(quote!(
[ #( #val ),* ]
))
}
fn print_value(val: Value, ty: &mut Type) -> Result<TokenStream, TokenStream> {
let ok = match val {
Value::Boolean(b) => {
if ty_name(ty) != "bool" {
return Err(abort2(ty.span(), type_mismatch!(ty_name(ty), b, "boolean")));
}
quote!(#b)
}
Value::Integer(i) => match ty_name(ty).as_str() {
"i64" => quote! {#i},
"i32" => {
let i = match i32::try_from(i) {
Ok(val) => val,
Err(_) => {
return Err(abort2(ty.span(), type_mismatch!(ty_name(ty), i, "integer")))
}
};
quote! {#i}
}
"i16" => {
let i = match i16::try_from(i) {
Ok(val) => val,
Err(_) => {
return Err(abort2(ty.span(), type_mismatch!(ty_name(ty), i, "integer")))
}
};
quote! {#i}
}
"i8" => {
let i = match i8::try_from(i) {
Ok(val) => val,
Err(_) => {
return Err(abort2(ty.span(), type_mismatch!(ty_name(ty), i, "integer")))
}
};
quote! {#i}
}
"u64" => {
let i = match u64::try_from(i) {
Ok(val) => val,
Err(_) => {
return Err(abort2(ty.span(), type_mismatch!(ty_name(ty), i, "integer")))
}
};
quote! {#i}
}
"u32" => {
let i = match u32::try_from(i) {
Ok(val) => val,
Err(_) => {
return Err(abort2(ty.span(), type_mismatch!(ty_name(ty), i, "integer")))
}
};
quote! {#i}
}
"u16" => {
let i = match u16::try_from(i) {
Ok(val) => val,
Err(_) => {
return Err(abort2(ty.span(), type_mismatch!(ty_name(ty), i, "integer")))
}
};
quote! {#i}
}
"u8" => {
let i = match u8::try_from(i) {
Ok(val) => val,
Err(_) => {
return Err(abort2(ty.span(), type_mismatch!(ty_name(ty), i, "integer")))
}
};
quote! {#i}
}
"f64" => {
let i = i as f64;
quote! {#i}
}
"f32" => {
let i = i as f32;
quote! {#i}
}
_ => return Err(abort2(ty.span(), type_mismatch!(ty_name(ty), i, "integer"))),
},
Value::Number(n) => match ty_name(ty).as_str() {
"f64" => quote! {#n},
"f32" => {
let n = n as f32;
quote! {#n}
}
_ => return Err(abort2(ty.span(), type_mismatch!(ty_name(ty), n, "number"))),
},
Value::String(s) => {
let s = s.to_string_lossy().to_string();
assert_type_str(ty, &s)?;
quote!(#s)
}
Value::Table(t) => print_multi_value(t, ty)?,
val => {
return Err(abort2(
ty.span(),
type_mismatch!(ty_name(ty), format!("{:?}", val), val.type_name()),
))
}
};
Ok(ok)
}
fn ty_name(ty: &Type) -> String {
ty.to_token_stream().to_string()
}
fn assert_type_str(ty: &Type, v: &str) -> std::result::Result<(), TokenStream> {
if let Type::Reference(TypeReference {
lifetime,
mutability: None,
elem,
..
}) = ty
{
if (ty_name(elem) == "str")
&& ((lifetime.is_none()) || (lifetime.as_ref().unwrap().ident == "static"))
{
return Ok(());
}
}
Err(abort2(ty.span(), type_mismatch!(ty_name(ty), v, "string")))
}
/// Produces `compile_error!` tokens carrying `msg`, with `span` used as both
/// the start and the end of the reported range.
fn abort2(span: Span, msg: &str) -> TokenStream {
    let (start, end) = (span, span);
    abort2_start_end(start, end, msg)
}
fn abort(span: Span, msg: &str) -> proc_macro::TokenStream {
abort2(span, msg).into()
}
/// Builds the tokens `compile_error! { "msg" }`.
///
/// The message literal and the surrounding braces carry the `start` span,
/// while the `compile_error!` path carries the `end` span.
fn abort2_start_end(start: Span, end: Span, msg: &str) -> TokenStream {
    let mut literal = proc_macro2::Literal::string(msg);
    literal.set_span(start);
    let body = quote_spanned! {start=> { #literal } };
    quote_spanned! {end=> compile_error! #body }
}
fn print_err_with_span(error: mlua::Error, spans: &[(Span, Span)]) -> proc_macro::TokenStream {
use mlua::Error;
match error {
Error::SyntaxError { message, .. } | Error::RuntimeError(message) => {
let s: String = message
.chars()
.skip_while(|ch| !ch.is_ascii_digit())
.take_while(|ch| ch.is_ascii_digit())
.collect();
let n: usize = usize::from_str(&s).unwrap();
abort2_start_end(spans[n].0, spans[n].1, &message)
}
other => abort2_start_end(spans[0].0, spans[0].1, &other.to_string()),
}
.into()
}
/// Returns the spans of the first and last token the iterator would yield,
/// without advancing the caller's iterator (a clone is consumed instead).
///
/// * no tokens  → both spans are `Span::call_site()`
/// * one token  → both spans are that token's span
/// * otherwise  → `(first.span(), last.span())`
fn get_range<I: Iterator<Item = TokenTree> + Clone>(iter: &I) -> (Span, Span) {
    let mut tokens = iter.clone();
    let start = match tokens.next() {
        Some(first) => first.span(),
        None => {
            let here = Span::call_site();
            return (here, here);
        }
    };
    // With a single token nothing is left for `last()`; the range then ends
    // where it starts rather than jumping to the call site.
    let end = tokens.last().map_or(start, |token| token.span());
    (start, end)
}