1use std::future::{self, Future};
2use std::net::TcpListener as StdTcpListener;
3use std::pin::Pin;
4use std::sync::atomic::{AtomicU16, Ordering};
5use std::sync::OnceLock;
6use std::task::{Context, Poll};
7
8use dashmap::DashMap;
9use futures_util::future::Either;
10use mlua::{
11 ExternalResult, FromLuaMulti, Function, IntoLuaMulti, Lua, RegistryKey, Result, Table,
12 UserData, UserDataMethods, Value,
13};
14use rustc_hash::FxBuildHasher;
15use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
16use tokio::net::TcpListener;
17use tokio::runtime;
18use tokio::sync::oneshot::{self, Receiver};
19
/// Identifier for one in-flight async call, published to Lua so the replacement yield can
/// ask the notification watcher about that call.
///
/// NOTE(review): ids are `u16` and are handed out by a wrapping counter, so they are only
/// unique among the most recent 65536 calls.
type FutureId = u16;

/// Maximum number of idle notification sockets kept in each Lua state's connection pool.
const PER_WORKER_POOL_SIZE: usize = 512;

/// Completion receivers keyed by future id.
///
/// Entries are inserted when an async call is started and removed when the notification
/// watcher is asked to wait on that id. Created lazily on first insert.
static FUTURE_RX_MAP: OnceLock<DashMap<FutureId, Receiver<()>, FxBuildHasher>> = OnceLock::new();
30
31pub fn runtime() -> &'static runtime::Runtime {
33 static RUNTIME: OnceLock<runtime::Runtime> = OnceLock::new();
34 RUNTIME.get_or_init(|| {
35 runtime::Builder::new_multi_thread()
36 .enable_all()
37 .build()
38 .expect("failed to create tokio runtime")
39 })
40}
41
/// Returns the loopback TCP port used for future-completion notifications.
///
/// The port is chosen once per process: an ephemeral `127.0.0.1:0` listener is bound, its
/// port is read back, and the probe listener is closed again right away.
///
/// NOTE(review): nothing reserves the port between this probe and the moment the watcher
/// binds it, so another socket could claim it in between.
///
/// # Panics
///
/// Panics if the probe bind or the local-address lookup fails.
fn get_notification_port() -> u16 {
    static NOTIFICATION_PORT: OnceLock<u16> = OnceLock::new();
    let port = NOTIFICATION_PORT.get_or_init(|| {
        let probe = StdTcpListener::bind("127.0.0.1:0").expect("failed to bind to a local port");
        let addr = probe.local_addr().expect("failed to get local address");
        addr.port()
    });
    *port
}
53
54fn get_rx_by_future_id(future_id: FutureId) -> Option<Receiver<()>> {
55 FUTURE_RX_MAP.get()?.remove(&future_id).map(|(_, rx)| rx)
56}
57
58fn set_rx_by_future_id(future_id: FutureId, rx: Receiver<()>) {
59 FUTURE_RX_MAP
60 .get_or_init(|| DashMap::with_capacity_and_hasher(256, FxBuildHasher))
61 .insert(future_id, rx);
62}
63
64fn get_future_id() -> FutureId {
66 static WATCHER: OnceLock<()> = OnceLock::new();
67 WATCHER.get_or_init(|| {
68 let port = get_notification_port();
69
70 runtime().spawn(async move {
72 let listener = TcpListener::bind(("127.0.0.1", port))
73 .await
74 .unwrap_or_else(|err| panic!("failed to bind to a port {port}: {err}"));
75
76 while let Ok((mut stream, _)) = listener.accept().await {
77 tokio::task::spawn(async move {
78 let (reader, mut writer) = stream.split();
79 let reader = BufReader::new(reader);
80 let mut lines = reader.lines();
81 while let Ok(Some(line)) = lines.next_line().await {
83 let line = line.trim();
84 if line == "PING" {
85 if writer.write_all(b"PONG\n").await.is_err() {
86 break;
87 }
88 continue;
89 }
90 if let Ok(future_id) = line.parse::<FutureId>() {
91 let resp: &[u8] = match get_rx_by_future_id(future_id) {
93 Some(rx) => {
94 _ = rx.await;
95 b"READY\n"
96 }
97 None => b"ERR\n",
98 };
99 if writer.write_all(resp).await.is_err() {
100 break;
101 }
102 }
103 }
104 });
105 }
106 });
107 });
108
109 static NEXT_ID: AtomicU16 = AtomicU16::new(1);
111 NEXT_ID.fetch_add(1, Ordering::Relaxed)
112}
113
114pub fn create_async_function<F, A, R, FR>(lua: &Lua, func: F) -> Result<Function>
118where
119 F: Fn(A) -> FR + 'static,
120 A: FromLuaMulti + 'static,
121 R: IntoLuaMulti + Send + 'static,
122 FR: Future<Output = Result<R>> + Send + 'static,
123{
124 let port = get_notification_port();
125 let _yield_fixup = YieldFixUp::new(lua, port)?;
126 lua.create_async_function(move |lua, args| {
127 let future_id = get_future_id();
129
130 let _guard = runtime().enter();
132 let args = match A::from_lua_multi(args, &lua) {
133 Ok(args) => args,
134 Err(err) => return Either::Left(future::ready(Err(err))),
135 };
136 let (tx, rx) = oneshot::channel();
137 set_rx_by_future_id(future_id, rx);
138 let fut = func(args);
139 let result = tokio::task::spawn(async move {
140 let result = fut.await;
141 let _ = tx.send(());
143 result
144 });
145
146 Either::Right(HaproxyFuture {
147 lua,
148 id: future_id,
149 fut: async move { result.await.into_lua_err()? },
150 })
151 })
152}
153
154struct YieldFixUp<'lua>(&'lua Lua, Function);
155
156impl<'lua> YieldFixUp<'lua> {
157 fn new(lua: &'lua Lua, port: u16) -> Result<Self> {
158 let connection_pool =
159 match lua.named_registry_value::<Value>("__HAPROXY_CONNECTION_POOL")? {
160 Value::Nil => {
161 let connection_pool = ObjectPool::new(PER_WORKER_POOL_SIZE)?;
162 let connection_pool = lua.create_userdata(connection_pool)?;
163 lua.set_named_registry_value("__HAPROXY_CONNECTION_POOL", &connection_pool)?;
164 Value::UserData(connection_pool)
165 }
166 connection_pool => connection_pool,
167 };
168
169 let coroutine: Table = lua.globals().get("coroutine")?;
170 let orig_yield: Function = coroutine.get("yield")?;
171 let new_yield: Function = lua
172 .load(
173 r#"
174 local port, connection_pool = ...
175 local msleep = core.msleep
176 return function()
177 -- It's important to cache the future id before first yielding point
178 local future_id = __RUST_ACTIVE_FUTURE_ID
179 local ok, err
180
181 -- Get new or existing connection from the pool
182 local sock = connection_pool:get()
183 if not sock then
184 sock = core.tcp()
185 ok, err = sock:connect("127.0.0.1", port)
186 if err ~= nil then
187 msleep(1)
188 return
189 end
190 end
191
192 -- Subscribe to the future updates
193 ok, err = sock:send(future_id .. "\n")
194 if err ~= nil then
195 sock:close()
196 msleep(1)
197 return
198 end
199
200 -- Wait for the future to be ready
201 ok, err = sock:receive("*l")
202 if err ~= nil then
203 sock:close()
204 msleep(1)
205 return
206 end
207 if ok ~= "READY" then
208 msleep(1)
209 end
210
211 ok = connection_pool:put(sock)
212 if not ok then
213 sock:close()
214 end
215 end
216 "#,
217 )
218 .call((port, connection_pool))?;
219 coroutine.set("yield", new_yield)?;
220 Ok(YieldFixUp(lua, orig_yield))
221 }
222}
223
224impl<'lua> Drop for YieldFixUp<'lua> {
225 fn drop(&mut self) {
226 if let Err(e) = (|| {
227 let coroutine: Table = self.0.globals().get("coroutine")?;
228 coroutine.set("yield", &self.1)
229 })() {
230 eprintln!("Error in YieldFixUp destructor: {e}");
231 }
232 }
233}
234
235struct ObjectPool(Vec<RegistryKey>);
236
237impl ObjectPool {
238 fn new(capacity: usize) -> Result<Self> {
239 Ok(ObjectPool(Vec::with_capacity(capacity)))
240 }
241}
242
243impl UserData for ObjectPool {
244 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
245 methods.add_method_mut("get", |_, this, ()| Ok(this.0.pop()));
246
247 methods.add_method_mut("put", |_, this, obj: RegistryKey| {
248 if this.0.len() == PER_WORKER_POOL_SIZE {
249 return Ok(false);
250 }
251 this.0.push(obj);
252 Ok(true)
253 });
254 }
255}
256
257pin_project_lite::pin_project! {
258 struct HaproxyFuture<F> {
259 lua: Lua,
260 id: FutureId,
261 #[pin]
262 fut: F,
263 }
264}
265
266impl<F, R> Future for HaproxyFuture<F>
267where
268 F: Future<Output = Result<R>>,
269{
270 type Output = Result<R>;
271
272 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
273 let this = self.project();
274 match this.fut.poll(cx) {
275 Poll::Ready(res) => Poll::Ready(res),
276 Poll::Pending => {
277 let _ = (this.lua.globals()).raw_set("__RUST_ACTIVE_FUTURE_ID", *this.id);
279 Poll::Pending
280 }
281 }
282 }
283}