use super::validator::{ScriptValidationError, ScriptValidator};
use mlua::Lua;
use std::error::Error;
use std::fmt;
/// Failure categories produced while validating a Lua script.
///
/// Every variant wraps a human-readable message describing the failure.
#[derive(Debug, Clone)]
pub enum LuaValidationError {
    /// The script text is not well-formed Lua.
    SyntaxError(String),
    /// The script does not contain a required `return` statement.
    MissingReturnStatement(String),
    /// The script could not be compiled.
    CompilationError(String),
    /// The script could not be loaded.
    LoadError(String),
}
impl fmt::Display for LuaValidationError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
LuaValidationError::SyntaxError(msg) => write!(f, "Syntax error: {msg}"),
LuaValidationError::MissingReturnStatement(msg) => {
write!(f, "Missing return statement: {msg}")
}
LuaValidationError::CompilationError(msg) => write!(f, "Compilation error: {msg}"),
LuaValidationError::LoadError(msg) => write!(f, "Load error: {msg}"),
}
}
}
// Empty impl: no `Error` methods are overridden, so the trait's provided
// defaults apply. The `Debug` and `Display` impls this requires are the
// derive and the manual `fmt::Display` impl on the type.
impl Error for LuaValidationError {}
impl From<LuaValidationError> for ScriptValidationError {
fn from(err: LuaValidationError) -> Self {
match err {
LuaValidationError::SyntaxError(msg) => ScriptValidationError::SyntaxError {
engine: "lua".to_string(),
message: msg,
},
LuaValidationError::MissingReturnStatement(msg) => ScriptValidationError::SyntaxError {
engine: "lua".to_string(),
message: format!("Missing return statement: {msg}"),
},
LuaValidationError::CompilationError(msg) => ScriptValidationError::CompilationError {
engine: "lua".to_string(),
message: msg,
},
LuaValidationError::LoadError(msg) => ScriptValidationError::LoadError {
engine: "lua".to_string(),
message: msg,
},
}
}
}
/// Script validator backed by an embedded Lua state.
///
/// The state is created once in [`LuaValidator::new`] and reused for every
/// script passed to the validator.
pub struct LuaValidator {
    /// Lua state that scripts are loaded into during validation.
    lua: Lua,
}
impl LuaValidator {
    /// Creates a validator that owns a freshly constructed Lua state.
    pub fn new() -> Self {
        let lua = Lua::new();
        Self { lua }
    }
}
impl Default for LuaValidator {
fn default() -> Self {
Self::new()
}
}
impl ScriptValidator for LuaValidator {
type Error = LuaValidationError;
fn validate(&self, script: &str) -> Result<(), Self::Error> {
match self.lua.load(script).eval::<mlua::Value>() {
Ok(_) => {
Ok(())
}
Err(e) => {
let err_str = e.to_string();
if err_str.contains("syntax error")
|| err_str.contains("unexpected")
|| err_str.contains("'end' expected")
{
Err(LuaValidationError::SyntaxError(err_str))
} else {
Ok(())
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A script that reads `request` and `flow_store` (neither is defined by
    /// the test) and returns tables on every path is accepted.
    #[test]
    fn test_valid_script() {
        let source = r#"
local flow_id = request.headers["x-flow-id"]
if flow_id == nil then
return { inject = false }
end
local count = flow_store:increment(flow_id, "count")
if count > 5 then
return {
inject = true,
fault = "error",
status = 503,
body = "Too many requests"
}
end
return { inject = false }
"#;
        let outcome = LuaValidator::new().validate(source);
        assert!(outcome.is_ok(), "Valid script should pass validation");
    }

    /// An unclosed index bracket must be reported as a syntax error.
    #[test]
    fn test_syntax_error() {
        let source = r#"
local flow_id = request.headers["x-flow-id"
-- Missing closing bracket
return { inject = false }
"#;
        let outcome = LuaValidator::new().validate(source);
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(LuaValidationError::SyntaxError(_))));
    }

    /// A longer script with several statements and method calls is accepted.
    #[test]
    fn test_complex_valid_script() {
        let source = r#"
-- Circuit breaker pattern
local flow_id = request.headers["x-flow-id"]
if flow_id == nil then
return { inject = false }
end
local failures = flow_store:increment(flow_id, "failures")
flow_store:set_ttl(flow_id, 300)
if failures > 3 then
return {
inject = true,
fault = "error",
status = 503,
body = "Circuit breaker open"
}
end
return { inject = false }
"#;
        let outcome = LuaValidator::new().validate(source);
        assert!(outcome.is_ok(), "Complex valid script should pass");
    }

    /// Batch validation yields one entry per script, in input order.
    #[test]
    fn test_batch_validation() {
        let validator = LuaValidator::new();
        let scripts = vec![
            ("script1", r#"return { inject = false }"#),
            // Table constructor is never closed.
            ("script2", r#"return { inject = true "#),
            (
                "script3",
                r#"local x = flow_store:increment("flow1", "key") return { inject = false }"#,
            ),
        ];
        let verdicts = validator.validate_batch(&scripts);
        assert_eq!(verdicts.len(), 3);
        assert!(verdicts[0].1.is_ok(), "script1 should be valid");
        assert!(verdicts[1].1.is_err(), "script2 should be invalid");
        assert!(verdicts[2].1.is_ok(), "script3 should be valid");
    }

    /// A script that returns a latency fault description is accepted.
    #[test]
    fn test_latency_fault_script() {
        let source = r#"
if request.path:find("/slow") then
return {
inject = true,
fault = "latency",
duration_ms = 1000
}
end
return { inject = false }
"#;
        let outcome = LuaValidator::new().validate(source);
        assert!(outcome.is_ok(), "Latency fault script should be valid");
    }

    /// Converting to the unified error keeps the variant and tags it `"lua"`.
    #[test]
    fn test_error_conversion() {
        let lua_err = LuaValidationError::SyntaxError("unexpected token".to_string());
        let unified_err = ScriptValidationError::from(lua_err);
        assert!(matches!(
            unified_err,
            ScriptValidationError::SyntaxError { engine, .. } if engine == "lua"
        ));
    }
}