use std::env;
use std::io::{self, BufRead, IsTerminal, Write};
use std::process;
use rilua::{Function, Lua, LuaApiMut, LuaError, StdLib, Val};
/// Unix SIGINT support: while Lua code runs under `with_sigint`, Ctrl-C is turned into an
/// interrupt request for the VM instead of the default "terminate the process" action
/// (modelled on `docall`/`laction` in the reference `lua.c`).
#[cfg(unix)]
mod sigint {
    use rilua::{clear_interrupted, set_interrupted};

    // Signal number of SIGINT, hard-coded instead of taken from a libc binding.
    // It is 2 on mainstream Unix platforms — NOTE(review): confirm for any exotic target.
    const SIGINT: i32 = 2;

    // Minimal binding to the C library's `signal`. The handler is modelled as
    // `Option<extern "C" fn(i32)>` so that `None` is passed as a null pointer, which is how
    // `SIG_DFL` is represented on mainstream libcs (presumably — verify if porting).
    #[expect(unsafe_code, reason = "raw POSIX signal FFI")]
    unsafe extern "C" {
        fn signal(signum: i32, handler: Option<extern "C" fn(i32)>) -> Option<extern "C" fn(i32)>;
    }

    /// SIGINT handler installed while Lua code runs.
    ///
    /// It first resets SIGINT to the null/default disposition, so a second Ctrl-C that
    /// arrives before the interrupt is acted on gets the default action; then it records the
    /// interrupt via `set_interrupted` (assumed to be async-signal-safe, e.g. an atomic
    /// store — confirm in rilua).
    extern "C" fn handler(_sig: i32) {
        #[expect(unsafe_code, reason = "reset signal handler in signal context")]
        // SAFETY: `signal` is on POSIX's list of async-signal-safe functions, and a null
        // handler is an accepted disposition argument.
        unsafe {
            signal(SIGINT, None);
        }
        set_interrupted();
    }

    /// Runs `f` with the SIGINT handler installed and returns its result.
    ///
    /// Any interrupt left over from an earlier run is cleared first, and SIGINT is reset to
    /// the null/default disposition once `f` returns. Note: there is no drop guard, so if `f`
    /// panics the handler is left installed.
    pub(crate) fn with_sigint<F, T>(f: F) -> T
    where
        F: FnOnce() -> T,
    {
        clear_interrupted();
        #[expect(unsafe_code, reason = "POSIX signal handling requires unsafe")]
        // SAFETY: `handler` is an `extern "C" fn(i32)` with the signature `signal` expects.
        unsafe {
            signal(SIGINT, Some(handler));
        }
        let result = f();
        #[expect(unsafe_code, reason = "POSIX signal handling requires unsafe")]
        // SAFETY: passing a null handler only changes the signal disposition.
        unsafe {
            signal(SIGINT, None);
        }
        result
    }
}
/// Windows counterpart of the Unix SIGINT support: a console control handler turns Ctrl-C
/// into an interrupt request for the VM while Lua code runs under `with_sigint`.
#[cfg(windows)]
mod sigint {
    use rilua::{clear_interrupted, set_interrupted};

    // Win32 type aliases and constants needed for `SetConsoleCtrlHandler`.
    type BOOL = i32;
    type DWORD = u32;
    const TRUE: BOOL = 1;
    const FALSE: BOOL = 0;
    const CTRL_C_EVENT: DWORD = 0;

    #[expect(unsafe_code, reason = "Win32 console control handler FFI")]
    unsafe extern "system" {
        fn SetConsoleCtrlHandler(
            handler: Option<extern "system" fn(DWORD) -> BOOL>,
            add: BOOL,
        ) -> BOOL;
    }

    /// Console control handler. On Ctrl-C it records an interrupt and returns `TRUE`, marking
    /// the event as handled. Every other event returns `FALSE` so later handlers in the chain
    /// see it.
    ///
    /// Unlike the Unix handler it does not uninstall itself, so repeated Ctrl-C presses keep
    /// setting the flag rather than terminating the process. Windows runs console handlers on
    /// a separate thread, so `set_interrupted` must be thread-safe (presumably an atomic —
    /// confirm in rilua).
    extern "system" fn handler(ctrl_type: DWORD) -> BOOL {
        if ctrl_type == CTRL_C_EVENT {
            set_interrupted();
            TRUE
        } else {
            FALSE
        }
    }

    /// Runs `f` with the console control handler registered and returns its result.
    ///
    /// Any stale interrupt is cleared first, and the handler is unregistered once `f`
    /// returns. The `BOOL` results of `SetConsoleCtrlHandler` are ignored. Note: if `f`
    /// panics the handler stays registered.
    pub(crate) fn with_sigint<F, T>(f: F) -> T
    where
        F: FnOnce() -> T,
    {
        clear_interrupted();
        #[expect(unsafe_code, reason = "Win32 SetConsoleCtrlHandler requires unsafe")]
        // SAFETY: `handler` has the `extern "system" fn(DWORD) -> BOOL` signature the API
        // expects and lives for the whole program.
        unsafe {
            SetConsoleCtrlHandler(Some(handler), TRUE);
        }
        let result = f();
        #[expect(unsafe_code, reason = "Win32 SetConsoleCtrlHandler requires unsafe")]
        // SAFETY: removes the same handler registered above.
        unsafe {
            SetConsoleCtrlHandler(Some(handler), FALSE);
        }
        result
    }
}
/// Fallback for targets that are neither Unix nor Windows: interrupts are not supported and
/// the closure simply runs to completion.
#[cfg(not(any(unix, windows)))]
mod sigint {
    /// Runs `f` and hands back its result; no handler is installed, so Ctrl-C keeps its
    /// platform default behavior.
    pub(crate) fn with_sigint<F: FnOnce() -> T, T>(f: F) -> T {
        f()
    }
}
use sigint::with_sigint;
/// Version banner written to stderr for `-v` (and `-i`, which implies `-v`), and before an
/// interactive session started because stdin is a terminal.
const LUA_VERSION: &str = "Lua 5.1.1 Copyright (C) 1994-2006 Lua.org, PUC-Rio";
/// Writes `msg` to stderr as a single line, prefixed with `"<progname>: "` when a program
/// name is supplied.
fn l_message(progname: Option<&str>, msg: &str) {
    match progname {
        Some(name) => eprintln!("{name}: {msg}"),
        None => eprintln!("{msg}"),
    }
}
/// Prints a Lua error via `l_message`, unless its rendered text is empty.
///
/// Returns whether anything was printed.
fn report(progname: Option<&str>, err: &LuaError) -> bool {
    let text = err.to_string();
    let has_text = !text.is_empty();
    if has_text {
        l_message(progname, &text);
    }
    has_text
}
/// Reports whether `err` is a syntax error caused by input ending too early.
///
/// A syntax error whose message ends in `'<eof>'` means the chunk is unfinished, so the REPL
/// should ask for another line instead of reporting the error.
fn is_incomplete(err: &LuaError) -> bool {
    matches!(err, LuaError::Syntax(syntax) if syntax.message.ends_with("'<eof>'"))
}
/// Options gathered by `collect_args`.
struct Flags {
    /// `-i` was given.
    has_i: bool,
    /// `-v` was given, or implied by `-i`.
    has_v: bool,
    /// At least one `-e` was given.
    has_e: bool,
    /// Index in `argv` of the script name, or 0 when there is no script.
    script: usize,
}

/// Scans the leading options in `argv` (element 0 is the program name).
///
/// Scanning stops at the first argument that is not an option (including a bare `-`, which
/// means "stdin") and records its index as the script. `--` ends option processing, and the
/// argument after it (if any) is the script. `-e` and `-l` take their value either attached
/// (`-lname`) or as the following argument.
///
/// Returns `Err(())` for an unknown option, for `-i`/`-v`/`--` with trailing characters, or
/// for `-e`/`-l` missing their value.
fn collect_args(argv: &[String]) -> Result<Flags, ()> {
    let mut flags = Flags {
        has_i: false,
        has_v: false,
        has_e: false,
        script: 0,
    };
    let mut idx = 1;
    while let Some(arg) = argv.get(idx) {
        let opt = arg.as_bytes();
        if opt.first() != Some(&b'-') || opt.len() < 2 {
            flags.script = idx;
            return Ok(flags);
        }
        let attached = opt.len() > 2;
        match opt[1] {
            b'-' if attached => return Err(()),
            b'-' => {
                flags.script = if idx + 1 < argv.len() { idx + 1 } else { 0 };
                return Ok(flags);
            }
            b'i' | b'v' if attached => return Err(()),
            b'i' => {
                flags.has_i = true;
                flags.has_v = true;
            }
            b'v' => flags.has_v = true,
            b'e' | b'l' => {
                if opt[1] == b'e' {
                    flags.has_e = true;
                }
                if !attached {
                    // The value is the next argument; it must exist.
                    idx += 1;
                    if idx >= argv.len() {
                        return Err(());
                    }
                }
            }
            _ => return Err(()),
        }
        idx += 1;
    }
    Ok(flags)
}
/// Prints the command-line synopsis and the option summary to stderr.
fn print_usage(progname: &str) {
    const OPTION_LINES: [&str; 6] = [
        " -e stat execute string 'stat'",
        " -l name require library 'name'",
        " -i enter interactive mode after executing 'script'",
        " -v show version information",
        " -- stop handling options",
        " - execute stdin and stop handling options",
    ];
    eprintln!("usage: {progname} [options] [script [args]].");
    eprintln!("Available options are:");
    for line in OPTION_LINES {
        eprintln!("{line}");
    }
}
fn handle_lua_init(lua: &mut Lua) -> Result<(), ()> {
let Ok(init) = env::var("LUA_INIT") else {
return Ok(());
};
let result = if let Some(path) = init.strip_prefix('@') {
lua.exec_file(path)
} else {
lua.exec_bytes(init.as_bytes(), "=LUA_INIT")
};
match result {
Ok(()) => Ok(()),
Err(e) => {
report(None, &e);
Err(())
}
}
}
fn run_args(lua: &mut Lua, argv: &[String], script_idx: usize, progname: Option<&str>) -> bool {
let limit = if script_idx > 0 {
script_idx
} else {
argv.len()
};
let mut i = 1;
while i < limit {
let arg = &argv[i];
if !arg.starts_with('-') {
break;
}
let bytes = arg.as_bytes();
if bytes.len() < 2 {
break;
}
match bytes[1] {
b'e' => {
let chunk = if bytes.len() > 2 {
&arg[2..]
} else {
i += 1;
&argv[i]
};
match lua.load_bytes(chunk.as_bytes(), "=(command line)") {
Ok(func) => {
if let Err(e) = lua.call_function_traced(&func, &[]) {
report(progname, &e);
return true;
}
}
Err(e) => {
report(progname, &e);
return true;
}
}
}
b'l' => {
let lib_name = if bytes.len() > 2 {
&arg[2..]
} else {
i += 1;
&argv[i]
};
if do_library(lua, lib_name, progname).is_err() {
return true;
}
}
_ => {}
}
i += 1;
}
false
}
/// Loads library `name` by calling the global `require` (the `-l` option).
///
/// Returns `Err(())` after printing a message if `require` cannot be fetched as a function,
/// or if the call itself fails.
fn do_library(lua: &mut Lua, name: &str, progname: Option<&str>) -> Result<(), ()> {
    let require_fn = match lua.global::<Function>("require") {
        Ok(func) => func,
        Err(_) => {
            l_message(progname, "require not available");
            return Err(());
        }
    };
    let module_name = lua.create_string(name.as_bytes());
    if let Err(e) = lua.call_function(&require_fn, &[module_name]) {
        report(progname, &e);
        return Err(());
    }
    Ok(())
}
/// Builds the global `arg` table and returns the script's own arguments.
///
/// Every element of `argv` is stored at index `position - script_idx`, so the script name
/// lands at 0, the arguments after it at 1, 2, ..., and the interpreter and its options at
/// negative indices. Failures to set table entries or the global are ignored. The returned
/// vector holds `argv[script_idx + 1..]` as Lua strings.
fn build_arg_table(lua: &mut Lua, argv: &[String], script_idx: usize) -> Vec<Val> {
    let table = lua.create_table();
    let offset = script_idx as f64;
    for (pos, text) in argv.iter().enumerate() {
        let s = lua.create_string(text.as_bytes());
        let _ = lua.table_raw_set(&table, Val::Num(pos as f64 - offset), s);
    }
    let _ = lua.set_global("arg", Val::Table(table.gc_ref()));
    argv.iter()
        .skip(script_idx + 1)
        .map(|text| lua.create_string(text.as_bytes()))
        .collect()
}
/// Loads and runs the script named by `argv[script_idx]`, passing it the following arguments.
///
/// The global `arg` table is set up first (see `build_arg_table`). A script name of `-`
/// means standard input, unless the preceding argument is `--`. Load and runtime errors are
/// reported with the program-name prefix. The call runs with tracebacks and with the
/// interrupt handler installed.
///
/// Returns `true` when an error was reported.
fn handle_script(
    lua: &mut Lua,
    argv: &[String],
    script_idx: usize,
    progname: Option<&str>,
) -> bool {
    let script_args = build_arg_table(lua, argv, script_idx);
    let fname = &argv[script_idx];
    let use_stdin = fname == "-" && (script_idx == 0 || argv[script_idx - 1] != "--");
    let loaded = if use_stdin {
        lua.load_file(None)
    } else {
        lua.load_file(Some(fname))
    };
    let func = match loaded {
        Ok(func) => func,
        Err(e) => {
            report(progname, &e);
            return true;
        }
    };
    with_sigint(|| {
        let Err(e) = lua.call_function_traced(&func, &script_args) else {
            return false;
        };
        report(progname, &e);
        true
    })
}
/// Interactive read-eval-print loop (the `dotty` function of the reference `lua.c`).
///
/// Each statement is read after the `_PROMPT` prompt; continuation lines use `_PROMPT2`.
/// A leading `=` is rewritten to `return `. When a chunk fails to compile only because input
/// ended early (see `is_incomplete`), more lines are requested and appended. Successful
/// calls run with the interrupt handler installed, and their results are passed to
/// `print_results`. Errors are reported without a program-name prefix.
///
/// End of input ends the session, whether it arrives at the main prompt or in the middle of
/// an incomplete statement. A final newline is then printed.
fn dotty(lua: &mut Lua) {
    let stdin = io::stdin();
    'repl: loop {
        let prompt = get_prompt(lua, true);
        print!("{prompt}");
        let _ = io::stdout().flush();
        let Some(mut input) = read_line(&stdin) else {
            break;
        };
        if input.starts_with('=') {
            input = format!("return {}", &input[1..]);
        }
        let func = loop {
            match lua.load_bytes(input.as_bytes(), "=stdin") {
                Ok(f) => break Some(f),
                Err(e) if is_incomplete(&e) => {
                    let prompt2 = get_prompt(lua, false);
                    print!("{prompt2}");
                    let _ = io::stdout().flush();
                    // lua.c's `loadline` ends the whole session when input runs out
                    // mid-statement. Merely dropping the chunk would print a stray prompt
                    // and, on a terminal, keep the REPL running after Ctrl-D.
                    let Some(line) = read_line(&stdin) else {
                        break 'repl;
                    };
                    input.push('\n');
                    input.push_str(&line);
                }
                Err(e) => {
                    report(None, &e);
                    break None;
                }
            }
        };
        let Some(func) = func else {
            continue;
        };
        with_sigint(|| match lua.call_function_traced(&func, &[]) {
            Ok(results) => {
                if !results.is_empty() {
                    print_results(lua, &results);
                }
            }
            Err(e) => {
                report(None, &e);
            }
        });
    }
    println!();
}
/// Returns the REPL prompt: the global `_PROMPT` for a first line, or `_PROMPT2` for a
/// continuation line.
///
/// Falls back to `"> "` / `">> "` when the global is unset or cannot be read as a string.
fn get_prompt(lua: &mut Lua, first_line: bool) -> String {
    let (global_name, fallback) = if first_line {
        ("_PROMPT", "> ")
    } else {
        ("_PROMPT2", ">> ")
    };
    lua.global::<Option<String>>(global_name)
        .ok()
        .flatten()
        .unwrap_or_else(|| fallback.to_string())
}
/// Reads one line of REPL input from standard input.
///
/// Returns `None` at end of input or on a read error. See `read_line_from` for the details
/// of line-ending handling and decoding.
fn read_line(stdin: &io::Stdin) -> Option<String> {
    read_line_from(&mut stdin.lock())
}

/// Reads one line from `reader`, removing a trailing `\n` and a `\r` directly before it.
///
/// The line is read as raw bytes and decoded lossily: invalid UTF-8 becomes U+FFFD. With
/// `BufRead::read_line`, such bytes produced an `InvalidData` error, which was indistinguishable
/// from end of input and silently ended the interactive session.
///
/// Returns `None` when nothing could be read (end of input) or on an I/O error.
fn read_line_from<R: BufRead>(reader: &mut R) -> Option<String> {
    let mut raw = Vec::new();
    match reader.read_until(b'\n', &mut raw) {
        Ok(0) | Err(_) => return None,
        Ok(_) => {}
    }
    if raw.last() == Some(&b'\n') {
        raw.pop();
        if raw.last() == Some(&b'\r') {
            raw.pop();
        }
    }
    Some(String::from_utf8_lossy(&raw).into_owned())
}
/// Prints REPL results by calling the global `print` with them.
///
/// Failures are reported as `error calling 'print' (<reason>)` without a program-name
/// prefix. This covers both a `print` global that cannot be fetched as a function (e.g. the
/// user set `print = nil`) and an error raised by the call. The reference `lua.c` reports
/// both cases instead of silently discarding the values.
fn print_results(lua: &mut Lua, results: &[Val]) {
    let print_fn = match lua.global::<Function>("print") {
        Ok(func) => func,
        Err(e) => {
            let msg = format!("error calling 'print' ({e})");
            l_message(None, &msg);
            return;
        }
    };
    if let Err(e) = lua.call_function(&print_fn, results) {
        let msg = format!("error calling 'print' ({e})");
        l_message(None, &msg);
    }
}
/// Entry point: runs the interpreter and exits with its status.
///
/// The work happens in `run_interpreter`, which owns the Lua state. The state is therefore
/// dropped (like `lua_close` in the reference `lua.c`) before `process::exit` is called,
/// because `process::exit` does not run destructors.
fn main() {
    let argv: Vec<String> = env::args().collect();
    let status = run_interpreter(&argv);
    if status != 0 {
        process::exit(status);
    }
}

/// Runs the stand-alone interpreter and returns the process exit status (0 or 1).
///
/// Order of work: create the state (adding the test library when `RILUA_TEST_LIB=1`), run
/// `LUA_INIT`, parse options, print the banner for `-v`, execute `-e`/`-l`, then run the
/// script if one was given. Afterwards it enters the REPL for `-i`. With no script, `-e` or
/// `-v`, it enters the REPL when stdin is a terminal and otherwise executes stdin as a chunk.
fn run_interpreter(argv: &[String]) -> i32 {
    // lua.c falls back to "lua" when argv[0] is missing or empty.
    let name = argv
        .first()
        .map(String::as_str)
        .filter(|s| !s.is_empty())
        .unwrap_or("lua");
    let progname = Some(name);
    let libs = if env::var("RILUA_TEST_LIB").as_deref() == Ok("1") {
        StdLib::ALL | StdLib::TEST
    } else {
        StdLib::ALL
    };
    let Ok(mut lua) = Lua::new_with(libs) else {
        l_message(progname, "cannot create state");
        return 1;
    };
    if handle_lua_init(&mut lua).is_err() {
        return 1;
    }
    let Ok(flags) = collect_args(argv) else {
        print_usage(name);
        return 1;
    };
    if flags.has_v {
        eprintln!("{LUA_VERSION}");
    }
    if run_args(&mut lua, argv, flags.script, progname) {
        return 1;
    }
    if flags.script > 0 && handle_script(&mut lua, argv, flags.script, progname) {
        return 1;
    }
    if flags.has_i {
        dotty(&mut lua);
    } else if flags.script == 0 && !flags.has_e && !flags.has_v {
        if io::stdin().is_terminal() {
            eprintln!("{LUA_VERSION}");
            dotty(&mut lua);
        } else {
            match lua.load_file(None) {
                Ok(func) => {
                    if let Err(e) = with_sigint(|| lua.call_function_traced(&func, &[])) {
                        report(progname, &e);
                        return 1;
                    }
                }
                Err(e) => {
                    report(progname, &e);
                    return 1;
                }
            }
        }
    }
    0
}