use std::future::{self, Future};
use std::net::TcpListener as StdTcpListener;
use std::pin::Pin;
use std::sync::atomic::{AtomicU16, Ordering};
use std::sync::OnceLock;
use std::task::{Context, Poll};
use dashmap::DashMap;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::TcpListener;
use tokio::runtime;
use tokio::sync::oneshot::{self, Receiver};
use futures_util::future::Either;
use mlua::{
ExternalResult, FromLuaMulti, Function, IntoLuaMulti, Lua, RegistryKey, Result, Table,
UserData, UserDataMethods, Value,
};
use rustc_hash::FxBuildHasher;
/// Identifier assigned to each in-flight async call; the Lua side sends it as a
/// decimal line to the notification server, which parses it back into this type.
type FutureId = u16;
/// Maximum number of entries an `ObjectPool` accepts (checked in its `put` method)
/// and the capacity the pool is created with.
const PER_WORKER_POOL_SIZE: usize = 512;
/// Completion receivers keyed by future id. Created lazily on the first insert;
/// an entry is removed when the notification server receives a subscription for it.
static FUTURE_RX_MAP: OnceLock<DashMap<FutureId, Receiver<()>, FxBuildHasher>> = OnceLock::new();
/// Returns the shared multi-threaded tokio runtime, building it on first use.
///
/// # Panics
///
/// Panics if the runtime cannot be built.
pub fn runtime() -> &'static runtime::Runtime {
    static RUNTIME: OnceLock<runtime::Runtime> = OnceLock::new();
    RUNTIME.get_or_init(|| {
        let mut builder = runtime::Builder::new_multi_thread();
        builder.enable_all();
        builder.build().expect("failed to create tokio runtime")
    })
}
/// Returns the loopback port used by the notification server.
///
/// The port is chosen once by binding a throwaway listener to `127.0.0.1:0` and
/// reading back the port the OS assigned; the probe listener is closed again
/// before this function returns. Later calls return the cached value.
///
/// # Panics
///
/// Panics if the probe listener cannot be bound or its address cannot be read.
fn get_notification_port() -> u16 {
    static NOTIFICATION_PORT: OnceLock<u16> = OnceLock::new();
    let port = NOTIFICATION_PORT.get_or_init(|| {
        let probe =
            StdTcpListener::bind("127.0.0.1:0").expect("failed to bind to a local port");
        let addr = probe.local_addr().expect("failed to get local address");
        addr.port()
    });
    *port
}
/// Removes and returns the completion receiver registered for `future_id`.
///
/// Returns `None` when the map has not been created yet or holds no entry for
/// the id.
fn get_rx_by_future_id(future_id: FutureId) -> Option<Receiver<()>> {
    let map = FUTURE_RX_MAP.get()?;
    let (_, rx) = map.remove(&future_id)?;
    Some(rx)
}
/// Registers `rx` as the completion receiver for `future_id`.
///
/// The map is created on first use; an existing entry under the same id is
/// replaced.
fn set_rx_by_future_id(future_id: FutureId, rx: Receiver<()>) {
    let map =
        FUTURE_RX_MAP.get_or_init(|| DashMap::with_capacity_and_hasher(256, FxBuildHasher));
    map.insert(future_id, rx);
}
fn get_future_id() -> FutureId {
static WATCHER: OnceLock<()> = OnceLock::new();
WATCHER.get_or_init(|| {
let port = get_notification_port();
runtime().spawn(async move {
let listener = TcpListener::bind(("127.0.0.1", port))
.await
.unwrap_or_else(|err| panic!("failed to bind to a port {port}: {err}"));
while let Ok((mut stream, _)) = listener.accept().await {
tokio::task::spawn(async move {
let (reader, mut writer) = stream.split();
let reader = BufReader::new(reader);
let mut lines = reader.lines();
while let Ok(Some(line)) = lines.next_line().await {
let line = line.trim();
if line == "PING" {
if writer.write_all(b"PONG\n").await.is_err() {
break;
}
continue;
}
if let Ok(future_id) = line.parse::<FutureId>() {
let resp: &[u8] = match get_rx_by_future_id(future_id) {
Some(rx) => {
_ = rx.await;
b"READY\n"
}
None => b"ERR\n",
};
if writer.write_all(resp).await.is_err() {
break;
}
}
}
});
}
});
});
static NEXT_ID: AtomicU16 = AtomicU16::new(1);
NEXT_ID.fetch_add(1, Ordering::Relaxed)
}
pub fn create_async_function<'lua, A, R, F, FR>(lua: &'lua Lua, func: F) -> Result<Function<'lua>>
where
A: FromLuaMulti<'lua> + 'static,
R: IntoLuaMulti<'lua> + Send + 'static,
F: Fn(A) -> FR + 'static,
FR: Future<Output = Result<R>> + Send + 'static,
{
let port = get_notification_port();
let _yield_fixup = YieldFixUp::new(lua, port)?;
lua.create_async_function(move |lua, args| {
let future_id = get_future_id();
let _guard = runtime().enter();
let args = match A::from_lua_multi(args, lua) {
Ok(args) => args,
Err(err) => return Either::Left(future::ready(Err(err))),
};
let (tx, rx) = oneshot::channel();
set_rx_by_future_id(future_id, rx);
let fut = func(args);
let result = tokio::task::spawn(async move {
let result = fut.await;
let _ = tx.send(());
result
});
Either::Right(HaproxyFuture {
lua,
id: future_id,
fut: async move { result.await.into_lua_err()? },
})
})
}
/// Guard that swaps `coroutine.yield` for a HAProxy-aware replacement.
///
/// Holds the Lua state and the original `coroutine.yield` function.
struct YieldFixUp<'lua>(&'lua Lua, Function<'lua>);

impl<'lua> YieldFixUp<'lua> {
    /// Installs the replacement `coroutine.yield` and returns a guard holding the
    /// original one.
    ///
    /// A connection pool userdata is looked up in the named registry under
    /// `__HAPROXY_CONNECTION_POOL` and created there if missing. The replacement
    /// yield reads `__RUST_ACTIVE_FUTURE_ID`, takes a socket from the pool (or
    /// connects to `127.0.0.1:port`), sends the id, and waits for one response
    /// line. On any socket error, or when the response is not `READY`, it sleeps
    /// for one millisecond via `core.msleep`.
    ///
    /// # Errors
    ///
    /// Returns an error if any registry, global-table, or chunk operation fails.
    fn new(lua: &'lua Lua, port: u16) -> Result<Self> {
        let mut connection_pool: Value = lua.named_registry_value("__HAPROXY_CONNECTION_POOL")?;
        if matches!(connection_pool, Value::Nil) {
            // First use in this Lua state: create the pool and remember it
            let pool = ObjectPool::new(PER_WORKER_POOL_SIZE)?;
            let pool = lua.create_userdata(pool)?;
            lua.set_named_registry_value("__HAPROXY_CONNECTION_POOL", &pool)?;
            connection_pool = Value::UserData(pool);
        }

        let coroutine: Table = lua.globals().get("coroutine")?;
        let orig_yield: Function = coroutine.get("yield")?;

        // The chunk text is kept at column 0 so the string handed to Lua is unchanged
        let chunk = lua.load(
            r#"
local port, connection_pool = ...
local msleep = core.msleep
return function()
-- It's important to cache the future id before first yielding point
local future_id = __RUST_ACTIVE_FUTURE_ID
local ok, err
-- Get new or existing connection from the pool
local sock = connection_pool:get()
if not sock then
sock = core.tcp()
ok, err = sock:connect("127.0.0.1", port)
if err ~= nil then
msleep(1)
return
end
end
-- Subscribe to the future updates
ok, err = sock:send(future_id .. "\n")
if err ~= nil then
sock:close()
msleep(1)
return
end
-- Wait for the future to be ready
ok, err = sock:receive("*l")
if err ~= nil then
sock:close()
msleep(1)
return
end
if ok ~= "READY" then
msleep(1)
end
ok = connection_pool:put(sock)
if not ok then
sock:close()
end
end
"#,
        );
        let new_yield: Function = chunk.call((port, connection_pool))?;
        coroutine.set("yield", new_yield)?;

        Ok(Self(lua, orig_yield))
    }
}
impl<'lua> Drop for YieldFixUp<'lua> {
fn drop(&mut self) {
if let Err(e) = (|| {
let coroutine: Table = self.0.globals().get("coroutine")?;
coroutine.set("yield", &self.1)
})() {
println!("Error in YieldFixUp destructor: {}", e);
}
}
}
/// Stack of Lua values, each held through a registry key.
struct ObjectPool(Vec<RegistryKey>);

impl ObjectPool {
    /// Creates an empty pool with storage preallocated for `capacity` entries.
    ///
    /// This never fails; the `Result` return type is kept for callers that use `?`.
    fn new(capacity: usize) -> Result<Self> {
        let slots = Vec::with_capacity(capacity);
        Ok(Self(slots))
    }
}
/// Lua-facing methods of the pool.
///
/// * `get()` pops the most recently stored entry, or yields nothing when the
///   pool is empty.
/// * `put(obj)` stores `obj` and returns `true`, or returns `false` without
///   storing it when the pool already holds `PER_WORKER_POOL_SIZE` entries.
impl UserData for ObjectPool {
    fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
        methods.add_method_mut("get", |_, pool, ()| {
            let entry = pool.0.pop();
            Ok(entry)
        });

        methods.add_method_mut("put", |_, pool, obj: RegistryKey| {
            let has_room = pool.0.len() != PER_WORKER_POOL_SIZE;
            if has_room {
                pool.0.push(obj);
            }
            Ok(has_room)
        });
    }
}
// Future wrapper that publishes its id to Lua whenever the inner future is pending
// (see the `Future` impl: it writes the id to the `__RUST_ACTIVE_FUTURE_ID` global).
pin_project_lite::pin_project! {
struct HaproxyFuture<'lua, F> {
// Lua state whose globals receive the active future id.
lua: &'lua Lua,
// Id under which the completion receiver for this future was registered.
id: FutureId,
// The wrapped future; structurally pinned so it can be polled in place.
#[pin]
fut: F,
}
}
/// Polls the wrapped future and, while it is still pending, records this
/// future's id in the Lua global `__RUST_ACTIVE_FUTURE_ID`.
impl<F, R> Future for HaproxyFuture<'_, F>
where
    F: Future<Output = Result<R>>,
{
    type Output = Result<R>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        let outcome = this.fut.poll(cx);
        if outcome.is_pending() {
            // Best effort: a failure to set the global is deliberately ignored
            let globals = this.lua.globals();
            let _ = globals.raw_set("__RUST_ACTIVE_FUTURE_ID", *this.id);
        }
        outcome
    }
}