use crate::constants::NVIM_POLL_TIMEOUT;
use mlua::{chunk, prelude::*};
use once_cell::sync::OnceCell;
use oorandom::Rand32;
use std::{
sync::Mutex,
time::{SystemTime, SystemTimeError, UNIX_EPOCH},
};
/// Builds the `utils` Lua table that exposes this module's helpers to Lua.
///
/// Fields set on the returned table:
/// - `nvim_wrap_async(async_fn, millis?)` — forwards to [`nvim_wrap_async`]; when `millis`
///   is `nil` it falls back to `NVIM_POLL_TIMEOUT`.
/// - `wrap_async(async_fn, schedule_fn)` — forwards to [`wrap_async`].
/// - `rand_u32()` — forwards to [`rand_u32`].
///
/// # Errors
/// Propagates any mlua error raised while creating the table or the functions, or while
/// setting a field.
pub fn make_utils_tbl(lua: &Lua) -> LuaResult<LuaTable> {
    let utils = lua.create_table()?;

    let nvim_wrap = lua.create_function(
        |lua, (async_fn, millis): (LuaFunction, Option<u64>)| {
            let poll_millis = millis.unwrap_or(NVIM_POLL_TIMEOUT);
            nvim_wrap_async(lua, async_fn, poll_millis)
        },
    )?;
    utils.set("nvim_wrap_async", nvim_wrap)?;

    let generic_wrap = lua.create_function(|lua, (async_fn, schedule_fn)| {
        wrap_async(lua, async_fn, schedule_fn)
    })?;
    utils.set("wrap_async", generic_wrap)?;

    let random = lua.create_function(|_, ()| rand_u32())?;
    utils.set("rand_u32", random)?;

    Ok(utils)
}
/// Neovim-flavoured variant of [`wrap_async`].
///
/// The scheduler it hands to [`wrap_async`] is a Lua function `function(cb)` that calls
/// `vim.defer_fn(cb, millis)`. Each poll step of `async_fn` is therefore deferred by
/// `millis` milliseconds on Neovim's event loop.
///
/// # Errors
/// Returns any error from loading or evaluating the scheduler chunk, or from
/// [`wrap_async`].
pub fn nvim_wrap_async<'a>(
    lua: &'a Lua,
    async_fn: LuaFunction<'a>,
    millis: u64,
) -> LuaResult<LuaFunction<'a>> {
    let defer_chunk = lua.load(chunk! {
        function(cb)
            return vim.defer_fn(cb, $millis)
        end
    });
    let deferred_scheduler: LuaFunction<'a> = defer_chunk.eval()?;
    wrap_async(lua, async_fn, deferred_scheduler)
}
/// Wraps `async_fn` into a callback-style Lua function driven by `schedule_fn`.
///
/// The returned Lua function takes the arguments for `async_fn` followed by a final
/// callback `cb`. It raises `"Invalid type for cb"` if the last argument is not a function.
///
/// How a call proceeds:
/// - `async_fn` is started immediately in a fresh coroutine with the leading arguments.
/// - After that, progress is driven by calling `schedule_fn` with a zero-argument
///   continuation.
/// - While the coroutine yields the value returned by [`pending`], the continuation resumes
///   it and schedules itself again.
/// - When the coroutine returns, `cb(true, result)` is called with its first return value.
/// - If the coroutine raises, `cb(false, err)` is called.
/// - `cb` is only ever invoked from a scheduled continuation, never synchronously.
///
/// Arguments are counted with `select("#", ...)` instead of the `#` length operator. `#` is
/// unspecified on tables containing `nil` holes, so `nil` arguments (including one directly
/// before `cb`) would otherwise be dropped, or `cb` mis-detected. `table.unpack` is
/// preferred when present so the code also runs on Lua 5.2+; LuaJIT only has the global
/// `unpack`.
///
/// # Errors
/// Returns any error from creating the pending sentinel or from loading/evaluating the Lua
/// wrapper chunk.
pub fn wrap_async<'a>(
    lua: &'a Lua,
    async_fn: LuaFunction<'a>,
    schedule_fn: LuaFunction<'a>,
) -> LuaResult<LuaFunction<'a>> {
    let pending_sentinel = pending(lua)?;
    lua.load(chunk! {
        return function(...)
            local argc = select("#", ...) - 1
            local args = { ... }
            local cb = args[argc + 1]
            assert(type(cb) == "function", "Invalid type for cb")
            local unpack_args = table.unpack or unpack
            local schedule = function(...) return $schedule_fn(...) end
            local thread = coroutine.create(function(...) return $async_fn(...) end)
            local status, res = coroutine.resume(thread, unpack_args(args, 1, argc))
            local inner_fn
            inner_fn = function()
                if not status then
                    cb(false, res)
                elseif res == $pending_sentinel then
                    status, res = coroutine.resume(thread)
                    schedule(inner_fn)
                else
                    cb(true, res)
                end
            end
            schedule(inner_fn)
        end
    })
    .eval()
}
/// Obtains the value an mlua async function yields while its future is still pending.
///
/// This builds a throwaway async function whose future yields once, via
/// `tokio::task::yield_now`. It then calls that function through `coroutine.wrap` and
/// returns whatever the first call produces. [`wrap_async`] compares coroutine yields
/// against this value to tell "not finished yet, resume later" apart from a final result.
///
/// NOTE(review): this appears to rely on mlua yielding a dedicated sentinel from the
/// coroutine when an async function returns `Poll::Pending`. That is mlua-internal
/// behaviour, so confirm it still holds after mlua upgrades.
///
/// # Errors
/// Returns any error from creating the async function or from loading/evaluating the chunk.
pub(super) fn pending(lua: &Lua) -> LuaResult<LuaValue> {
    let yield_once = lua.create_async_function(|_, ()| async {
        tokio::task::yield_now().await;
        Ok(())
    })?;
    let sentinel_chunk = lua.load(chunk! {
        (coroutine.wrap($yield_once))()
    });
    sentinel_chunk.eval()
}
/// Returns a pseudo-random `u32` from a process-wide `Rand32` generator.
///
/// The generator is created lazily on the first call, stored in a `static OnceCell`, and
/// guarded by a `Mutex`. It is seeded from the current wall-clock time in nanoseconds since
/// the UNIX epoch. Whole seconds were used previously, which gave every process started
/// within the same second an identical sequence.
///
/// This is not cryptographically secure.
///
/// # Errors
/// - On the first call, if the system clock reports a time before the UNIX epoch.
/// - If the mutex has been poisoned.
pub fn rand_u32() -> LuaResult<u32> {
    static RAND: OnceCell<Mutex<Rand32>> = OnceCell::new();

    let rng = RAND
        .get_or_try_init::<_, SystemTimeError>(|| {
            let nanos = SystemTime::now().duration_since(UNIX_EPOCH)?.as_nanos();
            // Truncating u128 -> u64 keeps the low, fast-changing bits. Epoch nanoseconds
            // still fit in u64 until roughly the year 2554, so nothing is actually lost today.
            Ok(Mutex::new(Rand32::new(nanos as u64)))
        })
        .to_lua_err()?;

    let mut guard = rng.lock().map_err(|x| x.to_string()).to_lua_err()?;
    Ok(guard.rand_u32())
}