use crate::loader::PersistenceConfig;
use flate2::Compression;
use flate2::read::GzDecoder;
use flate2::write::GzEncoder;
use mlua::{Lua, LuaSerdeExt, Result as LuaResult, Table, Value};
use std::fs::{self, File};
use std::io::{Read, Write};
use std::path::{Component, Path, PathBuf};
use thiserror::Error;
/// Module version, exposed to Lua as `_VERSION`.
const VERSION: &str = "0.1.0";
/// Human-readable module description, exposed to Lua as `_DESCRIPTION`.
const DESCRIPTION: &str = "Persistent data storage (JSON/gzip)";
/// First two bytes of a gzip stream; `load_from_file` checks for them to decide
/// whether a save file must be decompressed before JSON parsing.
const GZIP_MAGIC: [u8; 2] = [0x1f, 0x8b];
/// Errors produced while reading, writing, or converting persisted data.
#[derive(Debug, Error)]
pub enum PersistenceError {
    /// Filesystem failure or gzip stream failure. A missing save file surfaces
    /// here with `ErrorKind::NotFound`.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),
    /// File contents (after optional gzip decompression) are not valid JSON,
    /// or serializing data to JSON failed.
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),
    /// A Lua <-> Rust conversion failed; built from any `mlua::Error` via `From`.
    #[error("Lua conversion error: {0}")]
    LuaConversionError(String),
    /// No persistence configuration was available.
    /// NOTE(review): not constructed in this file; presumably used by callers.
    #[error("Persistence configuration not found")]
    ConfigNotFound,
    /// The Lua VM could not be accessed.
    /// NOTE(review): not constructed in this file; presumably used by callers.
    #[error("Lua VM access error: {0}")]
    LuaAccessError(String),
    /// The save file has an unrecognized format.
    /// NOTE(review): not constructed in this file (gzip decode failures currently
    /// surface as `IoError`); confirm whether anything still produces it.
    #[error("Invalid file format: {0}")]
    InvalidFormat(String),
}
impl From<mlua::Error> for PersistenceError {
fn from(e: mlua::Error) -> Self {
PersistenceError::LuaConversionError(e.to_string())
}
}
/// Settings captured by the `load` and `save` closures that `register` installs.
#[derive(Debug, Clone)]
struct PersistenceState {
    /// Save file location: `base_dir` joined with the configured relative path.
    file_path: PathBuf,
    /// When true, `save` writes gzip-compressed compact JSON; otherwise pretty JSON.
    obfuscate: bool,
    /// When true, successful loads and saves are logged at debug level.
    debug_mode: bool,
}
/// Builds the Lua persistence module table.
///
/// The returned table contains `_VERSION`, `_DESCRIPTION`, and two functions:
/// - `load()` returns the saved data as a table, or an empty table when the file is
///   missing, unreadable, or does not hold a table-shaped value;
/// - `save(data)` returns `true` on success, or `nil, message` on failure.
///
/// The configured path (`config.effective_file_path()`) must be relative. It is
/// resolved against `base_dir`.
///
/// # Errors
/// Returns a Lua runtime error in these cases:
/// - the configured path is absolute, or contains `..`, root, or drive-prefix components;
/// - the path does not name a file (e.g. it is empty or `.`).
///
/// Errors from creating the module table or its functions are propagated.
pub fn register(lua: &Lua, config: &PersistenceConfig, base_dir: &Path) -> LuaResult<Table> {
    let module = lua.create_table()?;
    module.set("_VERSION", VERSION)?;
    module.set("_DESCRIPTION", DESCRIPTION)?;

    let effective_path = config.effective_file_path();
    let relative = Path::new(&effective_path);
    // Anything absolute, rooted, drive-prefixed, or containing `..` could resolve
    // outside `base_dir`.
    if relative.is_absolute()
        || relative.components().any(|c| {
            matches!(
                c,
                Component::ParentDir | Component::RootDir | Component::Prefix(_)
            )
        })
    {
        return Err(mlua::Error::RuntimeError(format!(
            "Invalid persistence file path (directory traversal detected): {}",
            effective_path
        )));
    }
    // An empty path or `.` resolves to `base_dir` itself. Saving there would derive
    // the temporary file from the directory's own name, placing it next to
    // (outside of) `base_dir`.
    if relative.file_name().is_none() {
        return Err(mlua::Error::RuntimeError(format!(
            "Invalid persistence file path (no file name): {}",
            effective_path
        )));
    }

    let file_path = base_dir.join(&effective_path);
    let state = PersistenceState {
        file_path: file_path.clone(),
        obfuscate: config.obfuscate,
        debug_mode: config.debug_mode,
    };
    if config.debug_mode {
        tracing::debug!(
            path = %file_path.display(),
            obfuscate = config.obfuscate,
            "Persistence module initialized"
        );
    }

    // Each closure owns its own copy of the immutable settings.
    let load_state = state.clone();
    module.set(
        "load",
        lua.create_function(move |lua, ()| load_impl(lua, &load_state))?,
    )?;
    let save_state = state;
    module.set(
        "save",
        lua.create_function(move |lua, data: Table| save_impl(lua, &save_state, data))?,
    )?;
    Ok(module)
}
fn load_impl(lua: &Lua, state: &PersistenceState) -> LuaResult<Table> {
match load_from_file(&state.file_path) {
Ok(value) => {
if state.debug_mode {
tracing::debug!(path = %state.file_path.display(), "Loaded persistence data");
}
let lua_value: Value = lua.to_value(&value)?;
match lua_value {
Value::Table(t) => Ok(t),
_ => {
tracing::warn!(
path = %state.file_path.display(),
"Persistence data is not an object, using empty table"
);
lua.create_table()
}
}
}
Err(PersistenceError::IoError(ref e)) if e.kind() == std::io::ErrorKind::NotFound => {
tracing::warn!(path = %state.file_path.display(), "Persistence file not found, using empty table");
lua.create_table()
}
Err(e) => {
tracing::warn!(error = %e, path = %state.file_path.display(), "Failed to load persistence data, using empty table");
lua.create_table()
}
}
}
/// Backs the Lua `save(data)` function.
///
/// Converts `data` to JSON and writes it to the configured file. The result is
/// returned as Lua multiple values: `(true, nil)` on success, or `(nil, message)`
/// when conversion or writing fails. These failures never raise a Lua error.
fn save_impl(
    lua: &Lua,
    state: &PersistenceState,
    data: Table,
) -> LuaResult<(Option<bool>, Option<String>)> {
    let json_value = match lua.from_value::<serde_json::Value>(Value::Table(data)) {
        Ok(converted) => converted,
        Err(conversion_err) => {
            let err_msg = format!("Failed to convert Lua value: {conversion_err}");
            tracing::warn!(error = %err_msg, "Persistence save conversion error");
            return Ok((None, Some(err_msg)));
        }
    };

    if let Err(write_err) = save_to_file(&json_value, &state.file_path, state.obfuscate) {
        let err_msg = format!("Failed to save: {write_err}");
        tracing::error!(error = %err_msg, path = %state.file_path.display(), "Persistence save error");
        return Ok((None, Some(err_msg)));
    }

    if state.debug_mode {
        tracing::debug!(path = %state.file_path.display(), "Saved persistence data");
    }
    Ok((Some(true), None))
}
/// Reads and parses a save file, auto-detecting gzip compression.
///
/// An empty file is treated as an empty JSON object. A file whose first two
/// bytes are the gzip magic number is decompressed before parsing; any other
/// content is parsed as plain JSON.
///
/// # Errors
/// - [`PersistenceError::IoError`] if the file cannot be read (including
///   `NotFound`) or the gzip stream is invalid.
/// - [`PersistenceError::JsonError`] if the (decompressed) bytes are not valid JSON.
pub fn load_from_file(path: &Path) -> Result<serde_json::Value, PersistenceError> {
    let raw = fs::read(path)?;
    if raw.is_empty() {
        return Ok(serde_json::Value::Object(serde_json::Map::new()));
    }

    let parsed: serde_json::Value = if raw.starts_with(&GZIP_MAGIC) {
        let mut inflated = Vec::new();
        GzDecoder::new(raw.as_slice()).read_to_end(&mut inflated)?;
        serde_json::from_slice(&inflated)?
    } else {
        serde_json::from_slice(&raw)?
    };
    Ok(parsed)
}
pub fn save_to_file(
data: &serde_json::Value,
path: &Path,
obfuscate: bool,
) -> Result<(), PersistenceError> {
if let Some(parent) = path.parent()
&& !parent.exists()
{
fs::create_dir_all(parent)?;
tracing::debug!(path = %parent.display(), "Created persistence directory");
}
let bytes = if obfuscate {
let json_bytes = serde_json::to_vec(data)?;
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
encoder.write_all(&json_bytes)?;
encoder.finish()?
} else {
serde_json::to_vec_pretty(data)?
};
let temp_path = path.with_extension("tmp");
let mut file = File::create(&temp_path)?;
file.write_all(&bytes)?;
file.sync_all()?;
drop(file);
if let Err(e) = fs::rename(&temp_path, path) {
let _ = fs::remove_file(&temp_path);
return Err(PersistenceError::IoError(e));
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a config pointing at `save.json` (plain) or `save.dat` (gzip),
    /// paired with `temp_dir` as the base directory for `register`.
    fn create_test_config(temp_dir: &TempDir, obfuscate: bool) -> (PersistenceConfig, PathBuf) {
        let file_name = if obfuscate { "save.dat" } else { "save.json" };
        let config = PersistenceConfig {
            obfuscate,
            file_path: file_name.to_string(),
            debug_mode: true,
        };
        let base_dir = temp_dir.path().to_path_buf();
        (config, base_dir)
    }

    #[test]
    fn test_save_load_json() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("save.json");
        let data = serde_json::json!({
            "player_name": "Alice",
            "play_count": 42,
            "flags": {
                "tutorial_complete": true
            }
        });
        save_to_file(&data, &file_path, false).unwrap();
        let loaded = load_from_file(&file_path).unwrap();
        assert_eq!(loaded, data);
    }

    #[test]
    fn test_save_load_obfuscated() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("save.dat");
        let data = serde_json::json!({
            "player_name": "Bob",
            "inventory": ["sword", "shield"]
        });
        save_to_file(&data, &file_path, true).unwrap();
        // The obfuscated file must really be gzip on disk.
        let raw = fs::read(&file_path).unwrap();
        assert!(raw.len() >= 2);
        assert_eq!(raw[0], GZIP_MAGIC[0]);
        assert_eq!(raw[1], GZIP_MAGIC[1]);
        let loaded = load_from_file(&file_path).unwrap();
        assert_eq!(loaded, data);
    }

    #[test]
    fn test_load_nonexistent_returns_error() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("nonexistent.json");
        let result = load_from_file(&file_path);
        assert!(matches!(result, Err(PersistenceError::IoError(_))));
    }

    #[test]
    fn test_load_corrupted_returns_error() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("corrupted.json");
        fs::write(&file_path, "{ invalid json }").unwrap();
        let result = load_from_file(&file_path);
        assert!(matches!(result, Err(PersistenceError::JsonError(_))));
    }

    #[test]
    fn test_auto_detect_format() {
        let temp_dir = TempDir::new().unwrap();
        let data = serde_json::json!({"key": "value"});
        let json_path = temp_dir.path().join("test.json");
        save_to_file(&data, &json_path, false).unwrap();
        let gzip_path = temp_dir.path().join("test.dat");
        save_to_file(&data, &gzip_path, true).unwrap();
        let loaded_json = load_from_file(&json_path).unwrap();
        let loaded_gzip = load_from_file(&gzip_path).unwrap();
        assert_eq!(loaded_json, data);
        assert_eq!(loaded_gzip, data);
    }

    #[test]
    fn test_atomic_write_creates_directory() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir
            .path()
            .join("subdir")
            .join("nested")
            .join("save.json");
        let data = serde_json::json!({"test": true});
        save_to_file(&data, &file_path, false).unwrap();
        assert!(file_path.exists());
        let loaded = load_from_file(&file_path).unwrap();
        assert_eq!(loaded, data);
    }

    #[test]
    fn test_save_leaves_no_temp_file() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("save.json");
        save_to_file(&serde_json::json!({"a": 1}), &file_path, false).unwrap();
        // Only the final save file may remain after a successful write.
        let entries: Vec<std::ffi::OsString> = fs::read_dir(temp_dir.path())
            .unwrap()
            .map(|entry| entry.unwrap().file_name())
            .collect();
        assert_eq!(entries, vec![std::ffi::OsString::from("save.json")]);
    }

    #[test]
    fn test_save_overwrites_existing_file_across_formats() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("save.json");
        save_to_file(&serde_json::json!({"v": 1}), &file_path, false).unwrap();
        save_to_file(&serde_json::json!({"v": 2}), &file_path, true).unwrap();
        let loaded = load_from_file(&file_path).unwrap();
        assert_eq!(loaded, serde_json::json!({"v": 2}));
    }

    #[test]
    fn test_nested_table_serialization() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("nested.json");
        let data = serde_json::json!({
            "level1": {
                "level2": {
                    "level3": {
                        "value": 123,
                        "array": [1, 2, 3],
                        "bool": true,
                        "string": "nested"
                    }
                }
            }
        });
        save_to_file(&data, &file_path, false).unwrap();
        let loaded = load_from_file(&file_path).unwrap();
        assert_eq!(loaded, data);
    }

    #[test]
    fn test_empty_file_returns_empty_object() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("empty.json");
        fs::write(&file_path, "").unwrap();
        let loaded = load_from_file(&file_path).unwrap();
        assert_eq!(loaded, serde_json::json!({}));
    }

    #[test]
    fn test_lua_module_load() {
        let temp_dir = TempDir::new().unwrap();
        let (config, base_dir) = create_test_config(&temp_dir, false);
        let lua = Lua::new();
        let module = register(&lua, &config, &base_dir).unwrap();
        let load_fn: mlua::Function = module.get("load").unwrap();
        // No save file exists yet, so load falls back to an empty table.
        let result: Table = load_fn.call(()).unwrap();
        assert_eq!(result.len().unwrap(), 0);
    }

    #[test]
    fn test_lua_module_save_and_load() {
        let temp_dir = TempDir::new().unwrap();
        let (config, base_dir) = create_test_config(&temp_dir, false);
        let lua = Lua::new();
        let module = register(&lua, &config, &base_dir).unwrap();
        let data: Table = lua.create_table().unwrap();
        data.set("name", "Test").unwrap();
        data.set("count", 42).unwrap();
        let save_fn: mlua::Function = module.get("save").unwrap();
        // `data` is not used afterwards, so it can be passed by value.
        let (ok, err): (Option<bool>, Option<String>) = save_fn.call(data).unwrap();
        assert_eq!(ok, Some(true));
        assert!(err.is_none());
        let load_fn: mlua::Function = module.get("load").unwrap();
        let result: Table = load_fn.call(()).unwrap();
        let name: String = result.get("name").unwrap();
        let count: i32 = result.get("count").unwrap();
        assert_eq!(name, "Test");
        assert_eq!(count, 42);
    }

    #[test]
    fn test_register_rejects_directory_traversal() {
        let temp_dir = TempDir::new().unwrap();
        let base_dir = temp_dir.path().to_path_buf();
        let config = PersistenceConfig {
            obfuscate: false,
            file_path: "../../etc/malicious.json".to_string(),
            debug_mode: false,
        };
        let lua = Lua::new();
        let result = register(&lua, &config, &base_dir);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("directory traversal"),
            "Expected traversal error, got: {}",
            err_msg
        );
    }

    #[test]
    fn test_register_rejects_rooted_path() {
        let temp_dir = TempDir::new().unwrap();
        let base_dir = temp_dir.path().to_path_buf();
        let config = PersistenceConfig {
            obfuscate: false,
            file_path: "/etc/passwd".to_string(),
            debug_mode: false,
        };
        let lua = Lua::new();
        let result = register(&lua, &config, &base_dir);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("directory traversal"),
            "Expected traversal error, got: {}",
            err_msg
        );
    }
}