use std::sync::Arc;
use mlua::prelude::*;
use rusqlite::{
types::{Value, ValueRef},
Connection,
};
use serde_json::Map;
use tracing::warn;
use crate::host::HostContext;
fn lua_params_to_values(tbl: &LuaTable) -> Result<Vec<Value>, String> {
let len = tbl.raw_len();
let mut result = Vec::with_capacity(len);
for i in 1..=len {
let v: LuaValue = tbl
.raw_get(i)
.map_err(|e| format!("params table access error: {e}"))?;
let sql_val = match v {
LuaValue::Nil => Value::Null,
LuaValue::Boolean(b) => Value::Integer(if b { 1 } else { 0 }),
LuaValue::Integer(n) => Value::Integer(n),
LuaValue::Number(f) => {
if !f.is_finite() {
return Err(format!(
"SQL param #{i} is non-finite ({f}); NaN and ±Inf are not supported"
));
}
Value::Real(f)
}
LuaValue::String(s) => Value::Text(
s.to_str()
.map_err(|e| format!("param string encoding error: {e}"))?
.to_string(),
),
other => return Err(format!("unsupported SQL param type: {}", other.type_name())),
};
result.push(sql_val);
}
Ok(result)
}
fn run_query(
conn: &Connection,
sql: &str,
params: &[Value],
) -> Result<Vec<Map<String, serde_json::Value>>, String> {
let mut stmt = conn.prepare(sql).map_err(|e| format!("sql error: {e}"))?;
let col_names: Vec<String> = stmt.column_names().iter().map(|s| s.to_string()).collect();
let mut rows = stmt
.query(rusqlite::params_from_iter(params.iter()))
.map_err(|e| format!("sql error: {e}"))?;
let mut result = Vec::new();
while let Some(row) = rows.next().map_err(|e| format!("sql error: {e}"))? {
let mut map = serde_json::Map::new();
for (i, name) in col_names.iter().enumerate() {
let val = match row.get_ref(i).map_err(|e| format!("sql error: {e}"))? {
ValueRef::Null => serde_json::Value::Null,
ValueRef::Integer(n) => serde_json::Value::Number(n.into()),
ValueRef::Real(f) => {
serde_json::Number::from_f64(f)
.map(serde_json::Value::Number)
.ok_or_else(|| {
format!(
"non-finite REAL in column '{}' ({f}); \
NaN / ±Inf cannot be represented in JSON/Lua",
col_names[i]
)
})?
}
ValueRef::Text(b) => {
let s = std::str::from_utf8(b)
.map_err(|e| format!("non-UTF-8 TEXT in column '{}': {e}", col_names[i]))?;
serde_json::Value::String(s.to_string())
}
ValueRef::Blob(_) => return Err("blob columns not supported in POC".to_string()),
};
map.insert(name.clone(), val);
}
result.push(map);
}
Ok(result)
}
fn run_exec(conn: &Connection, sql: &str, params: &[Value]) -> Result<(usize, i64), String> {
let affected = conn
.execute(sql, rusqlite::params_from_iter(params.iter()))
.map_err(|e| format!("sql error: {e}"))?;
let last_id = conn.last_insert_rowid();
Ok((affected, last_id))
}
fn rows_to_lua(lua: &Lua, rows: Vec<Map<String, serde_json::Value>>) -> LuaResult<LuaValue> {
let arr = lua.create_table()?;
for (i, row_map) in rows.into_iter().enumerate() {
let row_tbl = lua.create_table()?;
for (col, val) in row_map {
let lua_val = super::json_to_lua(lua, val)?;
row_tbl.set(col.as_str(), lua_val)?;
}
arr.set(i + 1, row_tbl)?;
}
Ok(LuaValue::Table(arr))
}
pub(super) fn lock_conn(
conn: &std::sync::Mutex<rusqlite::Connection>,
) -> std::sync::MutexGuard<'_, rusqlite::Connection> {
conn.lock().unwrap_or_else(|poisoned| {
warn!("sql conn mutex was poisoned; recovering via into_inner");
poisoned.into_inner()
})
}
pub(super) async fn race_timeout<T, F>(
fut: F,
timeout: Option<std::time::Duration>,
interrupt: &rusqlite::InterruptHandle,
op: &'static str,
) -> LuaResult<T>
where
F: std::future::Future<Output = Result<Result<T, String>, tokio::task::JoinError>>,
{
let wait = async {
match timeout {
Some(d) => match tokio::time::timeout(d, fut).await {
Ok(j) => Ok(j),
Err(_) => Err(d),
},
None => Ok(fut.await),
}
};
let wait_result = match super::task::effective_token() {
Some(t) => tokio::select! {
biased;
_ = t.cancelled() => {
interrupt.interrupt();
warn!(op, "cancelled by enclosing task");
return Err(LuaError::external(format!(
"task cancelled during {op}"
)));
}
r = wait => r,
},
None => wait.await,
};
let joined = match wait_result {
Ok(j) => j,
Err(d) => {
interrupt.interrupt();
warn!(op, timeout_ms = d.as_millis() as u64, "operation timeout");
return Err(LuaError::external(format!(
"{op} timeout ({}ms)",
d.as_millis()
)));
}
};
joined
.map_err(|e| {
warn!(op, error = %e, "spawn_blocking join error");
LuaError::external(format!("spawn_blocking: {e}"))
})?
.map_err(|e| {
warn!(op, error = %e, "execution error");
LuaError::external(e)
})
}
pub fn register(lua: &Lua, ctx: &HostContext) -> LuaResult<()> {
let sql_tbl = lua.create_table()?;
sql_tbl.set("null", LuaValue::NULL)?;
{
let ctx_conn = Arc::clone(&ctx.sql_conn);
let ctx_interrupt = Arc::clone(&ctx.sql_interrupt);
sql_tbl.set(
"query",
lua.create_async_function(move |lua, (sql, params): (String, Option<LuaTable>)| {
let conn = Arc::clone(&ctx_conn);
let interrupt = Arc::clone(&ctx_interrupt);
let params_result = params
.map(|t| lua_params_to_values(&t))
.transpose()
.map_err(LuaError::external);
async move {
let params_vec = params_result?.unwrap_or_default();
let fut = tokio::task::spawn_blocking(move || {
let guard = lock_conn(&conn);
run_query(&guard, &sql, ¶ms_vec)
});
let timeout = super::config::sql_query_timeout();
let rows = race_timeout(fut, timeout, &interrupt, "sql.query").await?;
rows_to_lua(&lua, rows)
}
})?,
)?;
}
{
let ctx_conn = Arc::clone(&ctx.sql_conn);
let ctx_interrupt = Arc::clone(&ctx.sql_interrupt);
sql_tbl.set(
"exec",
lua.create_async_function(move |lua, (sql, params): (String, Option<LuaTable>)| {
let conn = Arc::clone(&ctx_conn);
let interrupt = Arc::clone(&ctx_interrupt);
let params_result = params
.map(|t| lua_params_to_values(&t))
.transpose()
.map_err(LuaError::external);
async move {
let params_vec = params_result?.unwrap_or_default();
let fut = tokio::task::spawn_blocking(move || {
let guard = lock_conn(&conn);
run_exec(&guard, &sql, ¶ms_vec)
});
let timeout = super::config::sql_query_timeout();
let (affected, last_id) =
race_timeout(fut, timeout, &interrupt, "sql.exec").await?;
let result_tbl = lua.create_table()?;
result_tbl.set("affected", affected as i64)?;
result_tbl.set("last_id", last_id)?;
Ok(LuaValue::Table(result_tbl))
}
})?,
)?;
}
let std_ns: LuaTable = lua.globals().get("std")?;
std_ns.set("sql", sql_tbl)?;
lua.load(include_str!("sql_tools.lua"))
.set_name("std.sql.register_tools")
.exec()?;
Ok(())
}