use std::sync::Arc;
use mlua::prelude::*;
use rusqlite::{Connection, InterruptHandle, OptionalExtension};
use crate::sql::{
json_to_lua_preserving_null, lock_conn, lua_to_json_preserving_null, race_timeout,
sql_query_timeout, SqlConfig,
};
/// Checks that `ns` is usable as a key/value namespace.
///
/// A namespace is rejected when it is empty, contains `/`, `\` or a NUL
/// character, or contains the sequence `..` anywhere.
///
/// # Errors
/// Returns `Invalid namespace: '<ns>'` for every rejected input.
fn validate_ns(ns: &str) -> Result<(), String> {
    let has_forbidden_char = ns.chars().any(|c| matches!(c, '/' | '\\' | '\0'));
    let acceptable = !ns.is_empty() && !has_forbidden_char && !ns.contains("..");

    if acceptable {
        Ok(())
    } else {
        Err(format!("Invalid namespace: '{ns}'"))
    }
}
/// Creates the `__kv` backing table on `conn` if it does not exist yet.
///
/// The table stores one `value` per `(ns, key)` pair, with that pair as its
/// primary key.
///
/// # Errors
/// Returns the SQLite failure rendered as text, prefixed with
/// `kv schema init: `.
fn init_schema(conn: &Connection) -> Result<(), String> {
    const CREATE_KV_TABLE: &str = "CREATE TABLE IF NOT EXISTS __kv (\n ns TEXT NOT NULL,\n key TEXT NOT NULL,\n value TEXT NOT NULL,\n PRIMARY KEY (ns, key)\n ) WITHOUT ROWID;";

    match conn.execute_batch(CREATE_KV_TABLE) {
        Ok(()) => Ok(()),
        Err(err) => Err(format!("kv schema init: {err}")),
    }
}
/// Registers the key/value API on `lua` using `SqlConfig::default()`.
///
/// This only builds the default configuration and forwards every argument to
/// [`register_with`]; whatever that function returns is returned unchanged.
pub fn register(
    lua: &Lua,
    conn: Arc<std::sync::Mutex<Connection>>,
    interrupt: Arc<InterruptHandle>,
) -> LuaResult<()> {
    let default_cfg = SqlConfig::default();
    register_with(lua, conn, interrupt, default_cfg)
}
/// Registers the key/value API on `lua` as `std.kv`, backed by `conn`.
///
/// Steps, in order:
/// 1. `cfg` is stored as Lua app data.
/// 2. The `__kv` table is created on the connection if missing.
/// 3. A table with four async functions is built and assigned to the `kv`
///    field of the existing global `std` table:
///    - `get(ns, key)` returns the stored value decoded from JSON, or `nil`
///      when no row matches.
///    - `set(ns, key, value)` encodes `value` as JSON and upserts it.
///    - `delete(ns, key)` returns `true` when at least one row was removed.
///    - `list(ns, prefix?)` returns the namespace's keys in ascending order
///      as a 1-based array, keeping only keys that start with `prefix` when
///      one is given.
///
/// Every function validates `ns` first and runs its SQL on a tokio blocking
/// task while holding the connection lock. The task is handed to
/// `race_timeout` together with the value of `sql_query_timeout` and the
/// interrupt handle.
/// NOTE(review): presumably `race_timeout` interrupts the query when the
/// timeout elapses; confirm against its definition.
///
/// # Errors
/// Fails if schema creation fails, if a Lua function or table cannot be
/// created, or if the global `std` cannot be read as a table.
pub fn register_with(
    lua: &Lua,
    conn: Arc<std::sync::Mutex<Connection>>,
    interrupt: Arc<InterruptHandle>,
    cfg: SqlConfig,
) -> LuaResult<()> {
    lua.set_app_data::<SqlConfig>(cfg);

    // The guard is a temporary, so the lock is released at the end of this statement.
    init_schema(&lock_conn(&conn)).map_err(LuaError::external)?;

    let kv_tbl = lua.create_table()?;

    // kv.get(ns, key)
    let get_fn = {
        let conn = Arc::clone(&conn);
        let interrupt = Arc::clone(&interrupt);
        lua.create_async_function(move |lua, (ns, key): (String, String)| {
            let conn = Arc::clone(&conn);
            let interrupt = Arc::clone(&interrupt);
            // Validate eagerly; the outcome is surfaced once the future is polled.
            let ns_check = validate_ns(&ns).map_err(LuaError::external);
            async move {
                ns_check?;
                let lookup = tokio::task::spawn_blocking(move || {
                    let db = lock_conn(&conn);
                    db.query_row(
                        "SELECT value FROM __kv WHERE ns = ?1 AND key = ?2",
                        rusqlite::params![ns, key],
                        |row| row.get::<_, String>(0),
                    )
                    .optional()
                    .map_err(|e| format!("kv.get sql error: {e}"))
                });
                let stored =
                    race_timeout(lookup, sql_query_timeout(&lua), &interrupt, "kv.get").await?;
                match stored {
                    Some(raw) => {
                        let parsed = serde_json::from_str::<serde_json::Value>(&raw)
                            .map_err(|e| LuaError::external(format!("kv.get json parse: {e}")))?;
                        json_to_lua_preserving_null(&lua, parsed)
                    }
                    None => Ok(LuaValue::Nil),
                }
            }
        })?
    };
    kv_tbl.set("get", get_fn)?;

    // kv.set(ns, key, value)
    let set_fn = {
        let conn = Arc::clone(&conn);
        let interrupt = Arc::clone(&interrupt);
        lua.create_async_function(move |lua, (ns, key, value): (String, String, LuaValue)| {
            let conn = Arc::clone(&conn);
            let interrupt = Arc::clone(&interrupt);
            let ns_check = validate_ns(&ns).map_err(LuaError::external);
            // The Lua value is turned into a JSON string here, outside the future.
            let encoded = lua_to_json_preserving_null(value).and_then(|json| {
                serde_json::to_string(&json)
                    .map_err(|e| LuaError::external(format!("kv.set serialize: {e}")))
            });
            async move {
                // A namespace error is reported before a serialization error.
                ns_check?;
                let payload = encoded?;
                let upsert = tokio::task::spawn_blocking(move || {
                    let db = lock_conn(&conn);
                    db.execute(
                        "INSERT INTO __kv (ns, key, value) VALUES (?1, ?2, ?3) \
                         ON CONFLICT(ns, key) DO UPDATE SET value = excluded.value",
                        rusqlite::params![ns, key, payload],
                    )
                    .map(|_| ())
                    .map_err(|e| format!("kv.set sql error: {e}"))
                });
                race_timeout(upsert, sql_query_timeout(&lua), &interrupt, "kv.set").await
            }
        })?
    };
    kv_tbl.set("set", set_fn)?;

    // kv.delete(ns, key)
    let delete_fn = {
        let conn = Arc::clone(&conn);
        let interrupt = Arc::clone(&interrupt);
        lua.create_async_function(move |lua, (ns, key): (String, String)| {
            let conn = Arc::clone(&conn);
            let interrupt = Arc::clone(&interrupt);
            let ns_check = validate_ns(&ns).map_err(LuaError::external);
            async move {
                ns_check?;
                let removal = tokio::task::spawn_blocking(move || {
                    let db = lock_conn(&conn);
                    db.execute(
                        "DELETE FROM __kv WHERE ns = ?1 AND key = ?2",
                        rusqlite::params![ns, key],
                    )
                    // Report whether any row was actually removed.
                    .map(|affected| affected > 0)
                    .map_err(|e| format!("kv.delete sql error: {e}"))
                });
                race_timeout(removal, sql_query_timeout(&lua), &interrupt, "kv.delete").await
            }
        })?
    };
    kv_tbl.set("delete", delete_fn)?;

    // kv.list(ns, prefix?)
    let list_fn = {
        let conn = Arc::clone(&conn);
        let interrupt = Arc::clone(&interrupt);
        lua.create_async_function(move |lua, (ns, prefix): (String, Option<String>)| {
            let conn = Arc::clone(&conn);
            let interrupt = Arc::clone(&interrupt);
            let ns_check = validate_ns(&ns).map_err(LuaError::external);
            async move {
                ns_check?;
                let scan = tokio::task::spawn_blocking(move || {
                    let db = lock_conn(&conn);
                    let mut stmt = db
                        .prepare("SELECT key FROM __kv WHERE ns = ?1 ORDER BY key")
                        .map_err(|e| format!("kv.list prepare: {e}"))?;
                    let rows = stmt
                        .query_map(rusqlite::params![ns], |row| row.get::<_, String>(0))
                        .map_err(|e| format!("kv.list query: {e}"))?;
                    let keys = rows
                        .collect::<Result<Vec<String>, _>>()
                        .map_err(|e| format!("kv.list row: {e}"))?;
                    Ok::<_, String>(keys)
                });
                let keys =
                    race_timeout(scan, sql_query_timeout(&lua), &interrupt, "kv.list").await?;

                // The prefix filter is applied here, after all keys of the namespace are fetched.
                let wanted = prefix.as_deref();
                let out = lua.create_table()?;
                let matching = keys
                    .iter()
                    .filter(|k| wanted.map_or(true, |p| k.starts_with(p)));
                for (slot, key) in matching.enumerate() {
                    // Lua arrays are 1-based.
                    out.set(slot + 1, key.as_str())?;
                }
                Ok(LuaValue::Table(out))
            }
        })?
    };
    kv_tbl.set("list", list_fn)?;

    // Attach the API under the existing global `std` table.
    let std_tbl: LuaTable = lua.globals().get("std")?;
    std_tbl.set("kv", kv_tbl)
}