//! constlua 0.1.0
//!
//! Compute `const` values with Lua at compile time ("const fn by Lua").
use std::{io::Read, panic, str::FromStr, vec};

use mlua::{Function, Lua, MultiValue, Table, Value};
use proc_macro2::{Span, TokenStream, TokenTree};
use quote::{quote, quote_spanned, ToTokens};
use syn::{spanned::Spanned, Type, TypeArray, TypeReference};

/// Builds a `&String` describing a mismatch between the expected Rust type and the Lua
/// value that was actually found, suitable for passing to helpers that take `&str`.
macro_rules! type_mismatch {
    ($expected:expr,$value:expr,$found:expr) => {
        &format!(
            "type mismatch:expected {expected} but value {value} is {found}",
            expected = $expected,
            value = $value,
            found = $found,
        )
    };
}

#[proc_macro]
pub fn constlua(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let project_dir =
        std::path::PathBuf::from_str(&std::env::var("CARGO_MANIFEST_DIR").unwrap()).unwrap();
    let input = proc_macro2::TokenStream::from(input);
    let mut input = input.into_iter();
    let name: TokenTree = match input.next() {
        Some(TokenTree::Ident(ident)) => ident.into(),
        Some(other) => return abort(other.span(), "expected const value name"),
        None => panic!("expected const value name"),
    };
    match input.next() {
        Some(TokenTree::Punct(punct)) if punct.as_char() == ':' => (),
        Some(other) => return abort(other.span(), "expected :"),
        None => panic!("expected :"),
    };
    let mut ty_tokens = vec![];
    for token in &mut input {
        match token {
            TokenTree::Punct(punct) if punct.as_char() == '=' => {
                break;
            }
            other => {
                ty_tokens.push(other);
            }
        }
    }
    let ty_tokens = TokenStream::from_iter(ty_tokens.into_iter());
    let ty_span = ty_tokens.span();
    let mut ty: syn::Type = match syn::parse2(ty_tokens) {
        Ok(ty) => ty,
        Err(_) => return abort(ty_span, "Syntax Error: expected type"),
    };
    let mut luacode = vec![];
    luacode.push(quote!(function()).into_iter().collect());
    let mut line = vec![];
    let span0 = get_range(&input);
    let mut spans = vec![span0];
    for token in input {
        match &token {
            TokenTree::Punct(punct) if punct.as_char() == ';' => {
                spans.push(get_range(&line.clone().into_iter()));
                luacode.push(line);
                line = vec![];
                continue;
            }
            _ => (),
        }
        line.push(token);
    }
    spans.push(get_range(&line.clone().into_iter()));
    while line.is_empty() {
        line = luacode.pop().expect("no return value");
    }
    match &line[0] {
        TokenTree::Ident(ident) if ident == "return" => (),
        _ => line.insert(0, quote!(return).into_iter().next().unwrap()),
    }
    luacode.push(line);
    luacode.push(quote!(end).into_iter().collect());
    let lua = Lua::new();
    let lua_path = project_dir.join("lua");
    if lua_path.is_dir() {
        let dir = std::fs::read_dir(lua_path).unwrap();
        for entry in dir.flatten() {
            let path = entry.path();
            if let (Some(name), true, Some("lua")) = (
                path.file_stem(),
                path.is_file(),
                path.extension().and_then(|n| n.to_str()),
            ) {
                let mut mod_code = "function()".to_string();
                std::fs::File::open(&path)
                    .unwrap()
                    .read_to_string(&mut mod_code)
                    .unwrap();
                mod_code.push_str("\nend");
                let m: Function = match lua.load(&mod_code).eval() {
                    Ok(m) => m,
                    Err(e) => return abort2_start_end(span0.0, span0.1, &e.to_string()).into(),
                };
                if let Err(e) =
                    lua.load_from_function::<_, Value>(name.to_string_lossy().as_ref(), m)
                {
                    return abort2_start_end(span0.0, span0.1, &e.to_string()).into();
                };
            }
        }
    }
    let luacode: Vec<_> = luacode
        .into_iter()
        .map(|tokens| TokenStream::from_iter(tokens.into_iter()).to_string())
        .collect();

    let luacode = luacode.join("\n");
    let chuck = lua.load(&luacode).set_name("constlua.lua").unwrap();
    let f = match chuck.eval::<Function>() {
        Ok(f) => f,
        Err(e) => return print_err_with_span(e, &spans),
    };
    let value: MultiValue = match f.call::<(), MultiValue>(()) {
        Ok(value) => value,
        Err(e) => return print_err_with_span(e, &spans),
    };
    let value = match value.into_iter().next() {
        Some(value) => value,
        None => Value::Nil,
    };
    let value = match print_value(value, &mut ty) {
        Ok(val) => val,
        Err(val) => return val.into(),
    };
    let output = quote! {
        const #name : #ty = #value;
    };
    output.into()
}

fn print_multi_value(table: Table, ty: &mut Type) -> Result<TokenStream, TokenStream> {
    let table_info = format!("{:?}", table);
    let val: Vec<_> = table.sequence_values().filter_map(|val| val.ok()).collect();

    let sub_ty = match ty {
        Type::Array(arr) => arr.elem.as_mut(),
        Type::Slice(slice) => {
            let lit_int = syn::Lit::Int(syn::LitInt::new(
                val.len().to_string().as_str(),
                slice.bracket_token.span,
            ));
            let len_lit = syn::Expr::Lit(syn::ExprLit {
                attrs: vec![],
                lit: lit_int,
            });
            let bracket_token = slice.bracket_token;
            let semi_token_span = bracket_token.span;
            let elem = slice.elem.clone();
            *ty = Type::Array(TypeArray {
                bracket_token,
                semi_token: syn::Token!(;)(semi_token_span),
                elem,
                len: len_lit,
            });
            match ty {
                Type::Array(a) => &mut a.elem,
                _ => unreachable!(),
            }
        }
        Type::Tuple(tuple) => {
            let val: Result<Vec<TokenStream>, TokenStream> = tuple
                .elems
                .iter_mut()
                .zip(val.into_iter())
                .map(|(ty, val)| print_value(val, ty))
                .collect();
            let val = val?;
            return Ok(quote!(
                ( #( #val ),* )
            ));
        }
        _ => {
            return Err(abort2(
                ty.span(),
                type_mismatch!(ty_name(ty), table_info, "table"),
            ))
        }
    };
    let val: Result<Vec<TokenStream>, TokenStream> = val
        .into_iter()
        .map(|val| print_value(val, sub_ty))
        .collect();
    let val = val?;
    Ok(quote!(
        [ #( #val ),* ]
    ))
}

fn print_value(val: Value, ty: &mut Type) -> Result<TokenStream, TokenStream> {
    let ok = match val {
        Value::Boolean(b) => {
            if ty_name(ty) != "bool" {
                return Err(abort2(ty.span(), type_mismatch!(ty_name(ty), b, "boolean")));
            }
            quote!(#b)
        }
        Value::Integer(i) => match ty_name(ty).as_str() {
            "i64" => quote! {#i},
            "i32" => {
                let i = match i32::try_from(i) {
                    Ok(val) => val,
                    Err(_) => {
                        return Err(abort2(ty.span(), type_mismatch!(ty_name(ty), i, "integer")))
                    }
                };
                quote! {#i}
            }
            "i16" => {
                let i = match i16::try_from(i) {
                    Ok(val) => val,
                    Err(_) => {
                        return Err(abort2(ty.span(), type_mismatch!(ty_name(ty), i, "integer")))
                    }
                };
                quote! {#i}
            }
            "i8" => {
                let i = match i8::try_from(i) {
                    Ok(val) => val,
                    Err(_) => {
                        return Err(abort2(ty.span(), type_mismatch!(ty_name(ty), i, "integer")))
                    }
                };
                quote! {#i}
            }
            "u64" => {
                let i = match u64::try_from(i) {
                    Ok(val) => val,
                    Err(_) => {
                        return Err(abort2(ty.span(), type_mismatch!(ty_name(ty), i, "integer")))
                    }
                };
                quote! {#i}
            }
            "u32" => {
                let i = match u32::try_from(i) {
                    Ok(val) => val,
                    Err(_) => {
                        return Err(abort2(ty.span(), type_mismatch!(ty_name(ty), i, "integer")))
                    }
                };
                quote! {#i}
            }
            "u16" => {
                let i = match u16::try_from(i) {
                    Ok(val) => val,
                    Err(_) => {
                        return Err(abort2(ty.span(), type_mismatch!(ty_name(ty), i, "integer")))
                    }
                };
                quote! {#i}
            }
            "u8" => {
                let i = match u8::try_from(i) {
                    Ok(val) => val,
                    Err(_) => {
                        return Err(abort2(ty.span(), type_mismatch!(ty_name(ty), i, "integer")))
                    }
                };
                quote! {#i}
            }
            "f64" => {
                let i = i as f64;
                quote! {#i}
            }
            "f32" => {
                let i = i as f32;
                quote! {#i}
            }
            _ => return Err(abort2(ty.span(), type_mismatch!(ty_name(ty), i, "integer"))),
        },
        Value::Number(n) => match ty_name(ty).as_str() {
            "f64" => quote! {#n},
            "f32" => {
                let n = n as f32;
                quote! {#n}
            }
            _ => return Err(abort2(ty.span(), type_mismatch!(ty_name(ty), n, "number"))),
        },
        Value::String(s) => {
            let s = s.to_string_lossy().to_string();
            assert_type_str(ty, &s)?;
            quote!(#s)
        }
        Value::Table(t) => print_multi_value(t, ty)?,
        val => {
            return Err(abort2(
                ty.span(),
                type_mismatch!(ty_name(ty), format!("{:?}", val), val.type_name()),
            ))
        }
    };
    Ok(ok)
}

fn ty_name(ty: &Type) -> String {
    ty.to_token_stream().to_string()
}

fn assert_type_str(ty: &Type, v: &str) -> std::result::Result<(), TokenStream> {
    if let Type::Reference(TypeReference {
        lifetime,
        mutability: None,
        elem,
        ..
    }) = ty
    {
        if (ty_name(elem) == "str")
            && ((lifetime.is_none()) || (lifetime.as_ref().unwrap().ident == "static"))
        {
            return Ok(());
        }
    }
    Err(abort2(ty.span(), type_mismatch!(ty_name(ty), v, "string")))
}

/// Builds `compile_error!` tokens for `msg` where both the start and end span are `span`.
fn abort2(span: Span, msg: &str) -> TokenStream {
    let (start, end) = (span, span);
    abort2_start_end(start, end, msg)
}

fn abort(span: Span, msg: &str) -> proc_macro::TokenStream {
    abort2(span, msg).into()
}

/// Produces `compile_error!{ "<msg>" }` tokens.
///
/// The `compile_error` path and the `!` carry the `end` span. The brace group and the
/// message literal carry the `start` span. NOTE(review): presumably this lets rustc report
/// a diagnostic covering both spans — confirm.
fn abort2_start_end(start: Span, end: Span, msg: &str) -> TokenStream {
    let literal = {
        let mut lit = proc_macro2::Literal::string(msg);
        lit.set_span(start);
        lit
    };
    let braced = quote_spanned!(start=> { #literal });
    quote_spanned!(end=> compile_error! #braced)
}

fn print_err_with_span(error: mlua::Error, spans: &[(Span, Span)]) -> proc_macro::TokenStream {
    use mlua::Error;
    match error {
        Error::SyntaxError { message, .. } | Error::RuntimeError(message) => {
            let s: String = message
                .chars()
                .skip_while(|ch| !ch.is_ascii_digit())
                .take_while(|ch| ch.is_ascii_digit())
                .collect();
            let n: usize = usize::from_str(&s).unwrap();
            abort2_start_end(spans[n].0, spans[n].1, &message)
        }
        other => abort2_start_end(spans[0].0, spans[0].1, &other.to_string()),
    }
    .into()
}

/// Returns the spans of the first and last tokens yielded by `iter`. The iterator is
/// cloned, so the caller's iterator is not consumed.
///
/// For a single token, both ends are that token's span. For an empty iterator, both ends
/// are `Span::call_site()`.
fn get_range<I: Iterator<Item = TokenTree> + Clone>(iter: &I) -> (Span, Span) {
    let mut it: I = iter.clone();
    let start = it
        .next()
        .map(|token| token.span())
        .unwrap_or_else(Span::call_site);
    // With exactly one token nothing remains after `next()`. Reuse the start span instead
    // of falling back to the macro call site, so the range stays on that token.
    let end = it.last().map(|token| token.span()).unwrap_or(start);
    (start, end)
}