use crate::runtime::{self, output, sandbox};
use std::sync::{Arc, Mutex};
/// Errors returned by [`Repl::with_runtime`].
#[derive(Debug)]
pub enum ReplError {
    /// An error raised by the Lua runtime, or returned by the closure passed
    /// to `with_runtime`.
    Lua(mlua::Error),
    /// The mutex guarding the Lua runtime was poisoned, i.e. an earlier holder
    /// of the lock panicked while holding it.
    LockPoisoned,
}
impl std::fmt::Display for ReplError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ReplError::Lua(e) => write!(f, "Lua error: {}", e),
ReplError::LockPoisoned => write!(f, "Runtime lock poisoned"),
}
}
}
impl std::error::Error for ReplError {}
/// Wraps any `mlua::Error` as [`ReplError::Lua`], allowing `?` to lift Lua
/// errors into `ReplError`.
impl From<mlua::Error> for ReplError {
    fn from(err: mlua::Error) -> Self {
        Self::Lua(err)
    }
}
/// A Lua evaluation session.
///
/// The single Lua runtime lives behind a `Mutex`. Each `eval` / `with_runtime`
/// call locks it for exclusive access, and globals defined by one call remain
/// visible to later calls.
pub struct Repl {
    // The Lua state; locked for the duration of every eval/with_runtime call.
    runtime: Mutex<mlua::Lua>,
}
/// The result of evaluating one chunk of Lua code with [`Repl::eval`].
///
/// Deriving `Debug`, `Clone`, `PartialEq` and `Eq` lets callers log,
/// store and compare outcomes; every field is a plain owned std type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EvalOutcome {
    /// On success, one rendered string per value returned by the chunk.
    /// On a Lua-level failure, a rendered description of the error.
    pub result: Result<Vec<String>, String>,
    /// Text captured from the chunk's output (e.g. `print`) during evaluation,
    /// in the order it was produced.
    pub output: Vec<String>,
}
impl Repl {
pub fn new() -> Result<Self, mlua::Error> {
Self::new_with(runtime::default()?)
}
pub fn new_with(runtime: mlua::Lua) -> Result<Self, mlua::Error> {
let runtime = Mutex::new(runtime);
Ok(Self { runtime })
}
pub fn new_with_policy<P: sandbox::policy::Policy + 'static>(
policy: Arc<P>,
) -> Result<Self, mlua::Error> {
let runtime = mlua::Lua::new();
sandbox::apply_with_policy(&runtime, policy, None)?;
Self::new_with(runtime)
}
pub fn eval(&self, code: &str) -> Result<EvalOutcome, mlua::Error> {
let runtime = self.runtime.lock().unwrap();
let (eval_result, output) = output::with_output_capture(&runtime, |runtime| {
runtime.load(code).eval::<mlua::MultiValue>()
})?;
let result = match eval_result {
Ok(values) => Ok(values
.iter()
.map(|v| format!("{:#?}", v))
.collect::<Vec<_>>()),
Err(e) => Err(Self::format_lua_error(&e)),
};
Ok(EvalOutcome { result, output })
}
pub fn with_runtime<F, R>(&self, f: F) -> Result<R, ReplError>
where
F: FnOnce(&mlua::Lua) -> Result<R, mlua::Error>,
{
let runtime = self.runtime.lock().map_err(|_| ReplError::LockPoisoned)?;
f(&runtime).map_err(ReplError::from)
}
fn format_lua_error(error: &mlua::Error) -> String {
match error {
mlua::Error::RuntimeError(msg) => format!("RuntimeError: {}", msg),
mlua::Error::SyntaxError { message, .. } => format!("SyntaxError: {}", message),
mlua::Error::MemoryError(msg) => format!("MemoryError: {}", msg),
mlua::Error::CallbackError { traceback, cause } => {
format!("CallbackError: {}\nTraceback:\n{}", cause, traceback)
}
_ => format!("{}", error),
}
}
}
#[cfg(test)]
mod tests {
    // Unit tests for `Repl`, grouped as:
    // - construction;
    // - value rendering;
    // - output capture;
    // - error reporting;
    // - state persistence;
    // - `with_runtime` access, including lock poisoning.
    use super::*;

    /// Builds a REPL with the default runtime; panics if construction fails.
    fn create_repl() -> Repl {
        Repl::new().expect("Failed to create REPL")
    }

    /// Asserts that `result` is an error whose message contains `expected`.
    fn assert_error_contains(result: &Result<Vec<String>, String>, expected: &str) {
        assert!(result.is_err(), "Expected error but got success");
        let error = result.as_ref().unwrap_err();
        assert!(
            error.contains(expected),
            "Expected error to contain '{}', but got: {}",
            expected,
            error
        );
    }

    /// Compile-time check that `T` can be shared across threads.
    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn repl_is_send_sync() {
        assert_send_sync::<Repl>();
    }

    #[test]
    fn test_new_creates_repl_successfully() {
        let result = Repl::new();
        assert!(result.is_ok(), "Failed to create REPL: {:?}", result.err());
    }

    #[test]
    fn test_new_with_custom_runtime() {
        let lua = mlua::Lua::new();
        lua.globals().set("test_var", 42).unwrap();
        let repl = Repl::new_with(lua).unwrap();
        let eval = repl.eval("return test_var").unwrap();
        assert!(eval.result.is_ok());
        assert_eq!(eval.result.unwrap()[0], "42");
    }

    #[test]
    fn test_new_applies_sandboxing() {
        let repl = create_repl();
        let eval = repl.eval("return io.open('test.txt', 'r')").unwrap();
        assert!(eval.result.is_ok());
        let result = eval.result.unwrap();
        assert!(result[0].to_lowercase().contains("nil"));
    }

    #[test]
    fn test_eval_simple_expression() {
        let repl = create_repl();
        let eval = repl.eval("1 + 1").unwrap();
        assert!(eval.result.is_ok());
        assert_eq!(eval.result.unwrap()[0], "2");
        assert!(eval.output.is_empty());
    }

    #[test]
    fn test_eval_string_expression() {
        let repl = create_repl();
        let eval = repl.eval(r#"return "hello""#).unwrap();
        assert!(eval.result.is_ok());
        let result = eval.result.unwrap();
        assert_eq!(result.len(), 1);
        assert!(result[0].contains("hello"));
    }

    #[test]
    fn test_eval_multiple_return_values() {
        let repl = create_repl();
        let eval = repl.eval("return 1, 2, 3").unwrap();
        assert!(eval.result.is_ok());
        let result = eval.result.unwrap();
        assert_eq!(result.len(), 3);
        assert_eq!(result[0], "1");
        assert_eq!(result[1], "2");
        assert_eq!(result[2], "3");
    }

    #[test]
    fn test_eval_nil_value() {
        let repl = create_repl();
        let eval = repl.eval("return nil").unwrap();
        assert!(eval.result.is_ok());
        let result = eval.result.unwrap();
        assert_eq!(result.len(), 1);
        assert!(result[0].to_lowercase().contains("nil"));
    }

    #[test]
    fn test_eval_boolean_values() {
        let repl = create_repl();
        let eval_true = repl.eval("return true").unwrap();
        let eval_false = repl.eval("return false").unwrap();
        assert!(eval_true.result.is_ok());
        let result_true = eval_true.result.unwrap();
        assert!(result_true[0].contains("true"));
        assert!(eval_false.result.is_ok());
        let result_false = eval_false.result.unwrap();
        assert!(result_false[0].contains("false"));
    }

    #[test]
    fn test_eval_table_expression() {
        let repl = create_repl();
        let eval = repl.eval("return {x=1, y=2}").unwrap();
        assert!(eval.result.is_ok());
        let result = eval.result.unwrap();
        assert!(!result.is_empty());
    }

    #[test]
    fn test_eval_function_return() {
        let repl = create_repl();
        let eval = repl.eval(r#"return string.upper("hello")"#).unwrap();
        assert!(eval.result.is_ok());
        let result = eval.result.unwrap();
        assert!(result[0].contains("HELLO"));
    }

    #[test]
    fn test_eval_empty_code() {
        let repl = create_repl();
        let eval = repl.eval("").unwrap();
        assert!(eval.result.is_ok());
        let result = eval.result.unwrap();
        assert!(result.is_empty());
        assert!(eval.output.is_empty());
    }

    #[test]
    fn test_eval_assignment_no_return() {
        let repl = create_repl();
        let eval = repl.eval("x = 42").unwrap();
        assert!(eval.result.is_ok());
        let result = eval.result.unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_eval_captures_print_output() {
        let repl = create_repl();
        let eval = repl.eval(r#"print("test")"#).unwrap();
        assert_eq!(eval.output, vec!["test\n"]);
        assert!(eval.result.is_ok());
        assert!(eval.result.unwrap().is_empty());
    }

    #[test]
    fn test_eval_captures_multiple_prints() {
        let repl = create_repl();
        let eval = repl
            .eval(
                r#"
print("line1")
print("line2")
print("line3")
"#,
            )
            .unwrap();
        assert_eq!(eval.output, vec!["line1\n", "line2\n", "line3\n"]);
    }

    #[test]
    fn test_eval_captures_print_with_multiple_args() {
        let repl = create_repl();
        let eval = repl.eval(r#"print("a", "b", "c")"#).unwrap();
        assert_eq!(eval.output, vec!["a\tb\tc\n"]);
    }

    #[test]
    fn test_eval_print_and_return_separate() {
        let repl = create_repl();
        let eval = repl
            .eval(
                r#"
print("output")
return 42
"#,
            )
            .unwrap();
        assert_eq!(eval.output, vec!["output\n"]);
        assert!(eval.result.is_ok());
        assert_eq!(eval.result.unwrap()[0], "42");
    }

    #[test]
    fn test_eval_print_various_types() {
        let repl = create_repl();
        let eval = repl.eval(r#"print(42, nil, true, false)"#).unwrap();
        assert_eq!(eval.output, vec!["42\tnil\ttrue\tfalse\n"]);
    }

    #[test]
    fn test_eval_output_not_accumulated() {
        let repl = create_repl();
        let eval1 = repl.eval(r#"print("first")"#).unwrap();
        assert_eq!(eval1.output, vec!["first\n"]);
        let eval2 = repl.eval(r#"print("second")"#).unwrap();
        assert_eq!(eval2.output, vec!["second\n"]);
    }

    #[test]
    fn test_eval_syntax_error() {
        let repl = create_repl();
        let eval = repl.eval("function end").unwrap();
        assert_error_contains(&eval.result, "SyntaxError:");
    }

    #[test]
    fn test_eval_runtime_error() {
        let repl = create_repl();
        let eval = repl.eval(r#"error("test error")"#).unwrap();
        assert_error_contains(&eval.result, "RuntimeError:");
        assert_error_contains(&eval.result, "test error");
    }

    #[test]
    fn test_eval_undefined_variable_error() {
        let repl = create_repl();
        let eval = repl.eval("undefined_var()").unwrap();
        assert_error_contains(&eval.result, "RuntimeError:");
    }

    #[test]
    fn test_eval_type_error() {
        let repl = create_repl();
        let eval = repl.eval(r#"return "string" + 1"#).unwrap();
        // Arithmetic on a non-numeric string is raised by the Lua VM itself,
        // so it should surface through the same category as other runtime errors.
        assert_error_contains(&eval.result, "RuntimeError:");
    }

    #[test]
    fn test_eval_callback_error() {
        let lua = mlua::Lua::new();
        let error_fn = lua
            .create_function(|_, ()| -> mlua::Result<()> {
                Err(mlua::Error::RuntimeError("callback failed".to_string()))
            })
            .unwrap();
        lua.globals().set("error_fn", error_fn).unwrap();
        let repl = Repl::new_with(lua).unwrap();
        let eval = repl.eval("error_fn()").unwrap();
        assert_error_contains(&eval.result, "CallbackError:");
        assert_error_contains(&eval.result, "callback failed");
    }

    #[test]
    fn test_eval_blocked_function_error() {
        let repl = create_repl();
        let eval = repl.eval(r#"return io.open("file.txt")"#).unwrap();
        assert!(eval.result.is_ok());
        let result = eval.result.unwrap();
        assert!(result[0].to_lowercase().contains("nil"));
    }

    #[test]
    fn test_eval_error_preserves_output() {
        let repl = create_repl();
        let eval = repl
            .eval(
                r#"
print("before error")
error("test error")
"#,
            )
            .unwrap();
        assert_eq!(eval.output, vec!["before error\n"]);
        assert_error_contains(&eval.result, "RuntimeError:");
    }

    #[test]
    fn test_eval_state_persists_between_calls() {
        let repl = create_repl();
        let eval1 = repl.eval("x = 42").unwrap();
        assert!(eval1.result.is_ok());
        let eval2 = repl.eval("return x").unwrap();
        assert!(eval2.result.is_ok());
        assert_eq!(eval2.result.unwrap()[0], "42");
    }

    #[test]
    fn test_eval_function_definition_persists() {
        let repl = create_repl();
        let eval1 = repl.eval("function double(n) return n * 2 end").unwrap();
        assert!(eval1.result.is_ok());
        let eval2 = repl.eval("return double(21)").unwrap();
        assert!(eval2.result.is_ok());
        assert_eq!(eval2.result.unwrap()[0], "42");
    }

    #[test]
    fn test_eval_global_table_persists() {
        let repl = create_repl();
        let eval1 = repl.eval("my_table = {x = 10}").unwrap();
        assert!(eval1.result.is_ok());
        let eval2 = repl.eval("return my_table.x").unwrap();
        assert!(eval2.result.is_ok());
        assert_eq!(eval2.result.unwrap()[0], "10");
    }

    #[test]
    fn test_eval_table_modification_persists() {
        let repl = create_repl();
        repl.eval("my_table = {x = 10}").unwrap();
        repl.eval("my_table.x = 20").unwrap();
        let eval = repl.eval("return my_table.x").unwrap();
        assert!(eval.result.is_ok());
        assert_eq!(eval.result.unwrap()[0], "20");
    }

    #[test]
    fn test_integration_with_safe_os_functions() {
        let repl = create_repl();
        let eval = repl.eval("return os.time()").unwrap();
        assert!(eval.result.is_ok());
        let result = eval.result.unwrap();
        assert!(!result.is_empty());
        assert!(result[0].parse::<i64>().is_ok());
    }

    #[test]
    fn test_integration_math_functions() {
        let repl = create_repl();
        let eval = repl.eval("return math.sqrt(16)").unwrap();
        assert!(eval.result.is_ok());
        let result = eval.result.unwrap()[0].clone();
        assert!(result == "4" || result == "4.0");
    }

    #[test]
    fn test_integration_string_functions() {
        let repl = create_repl();
        let eval = repl.eval(r#"return string.upper("test")"#).unwrap();
        assert!(eval.result.is_ok());
        let result = eval.result.unwrap();
        assert!(result[0].contains("TEST"));
    }

    #[test]
    fn test_integration_table_functions() {
        let repl = create_repl();
        let eval = repl
            .eval(r#"return table.concat({"a", "b", "c"}, ",")"#)
            .unwrap();
        assert!(eval.result.is_ok());
        let result = eval.result.unwrap();
        assert!(result[0].contains("a,b,c"));
    }

    #[test]
    fn test_with_runtime_set_global_variable() {
        let repl = create_repl();
        let result = repl.with_runtime(|lua| {
            lua.globals().set("custom_var", 42)?;
            Ok(())
        });
        assert!(result.is_ok());
        let eval = repl.eval("return custom_var").unwrap();
        assert!(eval.result.is_ok());
        assert_eq!(eval.result.unwrap()[0], "42");
    }

    #[test]
    fn test_with_runtime_register_rust_function() {
        let repl = create_repl();
        let result = repl.with_runtime(|lua| {
            let greet = lua.create_function(|_, name: String| Ok(format!("Hello, {}!", name)))?;
            lua.globals().set("greet", greet)?;
            Ok(())
        });
        assert!(result.is_ok());
        let eval = repl.eval(r#"return greet("World")"#).unwrap();
        assert!(eval.result.is_ok());
        let result = eval.result.unwrap();
        assert!(result[0].contains("Hello, World!"));
    }

    #[test]
    fn test_with_runtime_closure_captures_state() {
        let repl = create_repl();
        let multiplier = 10;
        let result = repl.with_runtime(|lua| {
            let func = lua.create_function(move |_, x: i32| Ok(x * multiplier))?;
            lua.globals().set("multiply", func)?;
            Ok(())
        });
        assert!(result.is_ok());
        let eval = repl.eval("return multiply(5)").unwrap();
        assert!(eval.result.is_ok());
        assert_eq!(eval.result.unwrap()[0], "50");
    }

    #[test]
    fn test_with_runtime_extract_value_from_lua() {
        let repl = create_repl();
        repl.eval("x = 42").unwrap();
        let value: i32 = repl.with_runtime(|lua| lua.globals().get("x")).unwrap();
        assert_eq!(value, 42);
    }

    #[test]
    fn test_with_runtime_extract_string_from_lua() {
        let repl = create_repl();
        repl.eval(r#"name = "Alice""#).unwrap();
        let value: String = repl.with_runtime(|lua| lua.globals().get("name")).unwrap();
        assert_eq!(value, "Alice");
    }

    #[test]
    fn test_with_runtime_returns_custom_type() {
        let repl = create_repl();
        repl.eval("a = 10; b = 20").unwrap();
        let sum: i32 = repl
            .with_runtime(|lua| {
                let a: i32 = lua.globals().get("a")?;
                let b: i32 = lua.globals().get("b")?;
                Ok(a + b)
            })
            .unwrap();
        assert_eq!(sum, 30);
    }

    #[test]
    fn test_with_runtime_error_propagation() {
        let repl = create_repl();
        let result: Result<(), ReplError> = repl.with_runtime(|lua| {
            let _val: i32 = lua.globals().get("nonexistent")?;
            Ok(())
        });
        assert!(result.is_err());
        match result {
            Err(ReplError::Lua(_)) => {}
            _ => panic!("Expected Lua error"),
        }
    }

    #[test]
    fn test_with_runtime_reports_poisoned_lock() {
        use std::panic::{catch_unwind, AssertUnwindSafe};
        let repl = create_repl();
        // Panic while `with_runtime` holds the guard so the mutex becomes poisoned.
        let panicked = catch_unwind(AssertUnwindSafe(|| {
            let _ = repl.with_runtime(|_lua| -> Result<(), mlua::Error> {
                panic!("intentional panic to poison the runtime lock")
            });
        }));
        assert!(panicked.is_err());
        let result: Result<(), ReplError> = repl.with_runtime(|_lua| Ok(()));
        assert!(matches!(result, Err(ReplError::LockPoisoned)));
    }

    #[test]
    fn test_with_runtime_multiple_operations() {
        let repl = create_repl();
        repl.with_runtime(|lua| {
            lua.globals().set("a", 1)?;
            lua.globals().set("b", 2)?;
            lua.globals().set("c", 3)?;
            Ok(())
        })
        .unwrap();
        let eval = repl.eval("return a + b + c").unwrap();
        assert!(eval.result.is_ok());
        assert_eq!(eval.result.unwrap()[0], "6");
    }

    #[test]
    fn test_with_runtime_state_persists_after_call() {
        let repl = create_repl();
        repl.with_runtime(|lua| {
            lua.globals().set("persistent", 99)?;
            Ok(())
        })
        .unwrap();
        let value: i32 = repl
            .with_runtime(|lua| lua.globals().get("persistent"))
            .unwrap();
        assert_eq!(value, 99);
    }
}