1use std::future::{self, Future};
2use std::net::TcpListener as StdTcpListener;
3use std::pin::Pin;
4use std::sync::atomic::{AtomicU16, Ordering};
5use std::sync::OnceLock;
6use std::task::{Context, Poll};
7
8use dashmap::DashMap;
9use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
10use tokio::net::TcpListener;
11use tokio::runtime;
12use tokio::sync::oneshot::{self, Receiver};
13
14use futures_util::future::Either;
15use mlua::{
16 ExternalResult, FromLuaMulti, Function, IntoLuaMulti, Lua, RegistryKey, Result, Table,
17 UserData, UserDataMethods, Value,
18};
19use rustc_hash::FxBuildHasher;
20
/// Identifier assigned to each in-flight async call; published to Lua as the global
/// `__RUST_ACTIVE_FUTURE_ID` while the call is pending.
type FutureId = u16;

/// Maximum number of notification sockets a Lua state's connection pool will keep
/// (`ObjectPool::put` refuses further sockets once this many are stored).
const PER_WORKER_POOL_SIZE: usize = 512;

/// Completion receivers keyed by future id. An entry is inserted when an async call starts
/// and removed (taken) when the notification server is asked about that id.
static FUTURE_RX_MAP: OnceLock<DashMap<FutureId, Receiver<()>, FxBuildHasher>> = OnceLock::new();
31
32pub fn runtime() -> &'static runtime::Runtime {
34 static RUNTIME: OnceLock<runtime::Runtime> = OnceLock::new();
35 RUNTIME.get_or_init(|| {
36 runtime::Builder::new_multi_thread()
37 .enable_all()
38 .build()
39 .expect("failed to create tokio runtime")
40 })
41}
42
/// Returns the loopback TCP port used by the future-notification server.
///
/// The port is chosen once per process by binding `127.0.0.1:0` and reading back the
/// OS-assigned port; later calls return the cached value.
///
/// NOTE(review): the probe listener is dropped as soon as the port is read, so the port is
/// not reserved until the notification server binds it again.
///
/// # Panics
/// Panics if the probe bind fails or its local address cannot be read.
fn get_notification_port() -> u16 {
    static NOTIFICATION_PORT: OnceLock<u16> = OnceLock::new();

    let cached = NOTIFICATION_PORT.get_or_init(|| {
        let probe = StdTcpListener::bind("127.0.0.1:0").expect("failed to bind to a local port");
        let address = probe.local_addr().expect("failed to get local address");
        address.port()
    });
    *cached
}
54
55fn get_rx_by_future_id(future_id: FutureId) -> Option<Receiver<()>> {
56 FUTURE_RX_MAP.get()?.remove(&future_id).map(|(_, rx)| rx)
57}
58
59fn set_rx_by_future_id(future_id: FutureId, rx: Receiver<()>) {
60 FUTURE_RX_MAP
61 .get_or_init(|| DashMap::with_capacity_and_hasher(256, FxBuildHasher))
62 .insert(future_id, rx);
63}
64
65fn get_future_id() -> FutureId {
67 static WATCHER: OnceLock<()> = OnceLock::new();
68 WATCHER.get_or_init(|| {
69 let port = get_notification_port();
70
71 runtime().spawn(async move {
73 let listener = TcpListener::bind(("127.0.0.1", port))
74 .await
75 .unwrap_or_else(|err| panic!("failed to bind to a port {port}: {err}"));
76
77 while let Ok((mut stream, _)) = listener.accept().await {
78 tokio::task::spawn(async move {
79 let (reader, mut writer) = stream.split();
80 let reader = BufReader::new(reader);
81 let mut lines = reader.lines();
82 while let Ok(Some(line)) = lines.next_line().await {
84 let line = line.trim();
85 if line == "PING" {
86 if writer.write_all(b"PONG\n").await.is_err() {
87 break;
88 }
89 continue;
90 }
91 if let Ok(future_id) = line.parse::<FutureId>() {
92 let resp: &[u8] = match get_rx_by_future_id(future_id) {
94 Some(rx) => {
95 _ = rx.await;
96 b"READY\n"
97 }
98 None => b"ERR\n",
99 };
100 if writer.write_all(resp).await.is_err() {
101 break;
102 }
103 }
104 }
105 });
106 }
107 });
108 });
109
110 static NEXT_ID: AtomicU16 = AtomicU16::new(1);
112 NEXT_ID.fetch_add(1, Ordering::Relaxed)
113}
114
115pub fn create_async_function<'lua, A, R, F, FR>(lua: &'lua Lua, func: F) -> Result<Function<'lua>>
119where
120 A: FromLuaMulti<'lua> + 'static,
121 R: IntoLuaMulti<'lua> + Send + 'static,
122 F: Fn(A) -> FR + 'static,
123 FR: Future<Output = Result<R>> + Send + 'static,
124{
125 let port = get_notification_port();
126 let _yield_fixup = YieldFixUp::new(lua, port)?;
127 lua.create_async_function(move |lua, args| {
128 let future_id = get_future_id();
130
131 let _guard = runtime().enter();
133 let args = match A::from_lua_multi(args, lua) {
134 Ok(args) => args,
135 Err(err) => return Either::Left(future::ready(Err(err))),
136 };
137 let (tx, rx) = oneshot::channel();
138 set_rx_by_future_id(future_id, rx);
139 let fut = func(args);
140 let result = tokio::task::spawn(async move {
141 let result = fut.await;
142 let _ = tx.send(());
144 result
145 });
146
147 Either::Right(HaproxyFuture {
148 lua,
149 id: future_id,
150 fut: async move { result.await.into_lua_err()? },
151 })
152 })
153}
154
155struct YieldFixUp<'lua>(&'lua Lua, Function<'lua>);
156
157impl<'lua> YieldFixUp<'lua> {
158 fn new(lua: &'lua Lua, port: u16) -> Result<Self> {
159 let connection_pool =
160 match lua.named_registry_value::<Value>("__HAPROXY_CONNECTION_POOL")? {
161 Value::Nil => {
162 let connection_pool = ObjectPool::new(PER_WORKER_POOL_SIZE)?;
163 let connection_pool = lua.create_userdata(connection_pool)?;
164 lua.set_named_registry_value("__HAPROXY_CONNECTION_POOL", &connection_pool)?;
165 Value::UserData(connection_pool)
166 }
167 connection_pool => connection_pool,
168 };
169
170 let coroutine: Table = lua.globals().get("coroutine")?;
171 let orig_yield: Function = coroutine.get("yield")?;
172 let new_yield: Function = lua
173 .load(
174 r#"
175 local port, connection_pool = ...
176 local msleep = core.msleep
177 return function()
178 -- It's important to cache the future id before first yielding point
179 local future_id = __RUST_ACTIVE_FUTURE_ID
180 local ok, err
181
182 -- Get new or existing connection from the pool
183 local sock = connection_pool:get()
184 if not sock then
185 sock = core.tcp()
186 ok, err = sock:connect("127.0.0.1", port)
187 if err ~= nil then
188 msleep(1)
189 return
190 end
191 end
192
193 -- Subscribe to the future updates
194 ok, err = sock:send(future_id .. "\n")
195 if err ~= nil then
196 sock:close()
197 msleep(1)
198 return
199 end
200
201 -- Wait for the future to be ready
202 ok, err = sock:receive("*l")
203 if err ~= nil then
204 sock:close()
205 msleep(1)
206 return
207 end
208 if ok ~= "READY" then
209 msleep(1)
210 end
211
212 ok = connection_pool:put(sock)
213 if not ok then
214 sock:close()
215 end
216 end
217 "#,
218 )
219 .call((port, connection_pool))?;
220 coroutine.set("yield", new_yield)?;
221 Ok(YieldFixUp(lua, orig_yield))
222 }
223}
224
225impl<'lua> Drop for YieldFixUp<'lua> {
226 fn drop(&mut self) {
227 if let Err(e) = (|| {
228 let coroutine: Table = self.0.globals().get("coroutine")?;
229 coroutine.set("yield", &self.1)
230 })() {
231 println!("Error in YieldFixUp destructor: {}", e);
232 }
233 }
234}
235
236struct ObjectPool(Vec<RegistryKey>);
237
238impl ObjectPool {
239 fn new(capacity: usize) -> Result<Self> {
240 Ok(ObjectPool(Vec::with_capacity(capacity)))
241 }
242}
243
244impl UserData for ObjectPool {
245 fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
246 methods.add_method_mut("get", |_, this, ()| Ok(this.0.pop()));
247
248 methods.add_method_mut("put", |_, this, obj: RegistryKey| {
249 if this.0.len() == PER_WORKER_POOL_SIZE {
250 return Ok(false);
251 }
252 this.0.push(obj);
253 Ok(true)
254 });
255 }
256}
257
pin_project_lite::pin_project! {
    // Future handed to mlua for one async call: it polls `fut` and, while that is still
    // pending, publishes `id` to Lua (see the `Future` impl below).
    struct HaproxyFuture<'lua, F> {
        // Lua state whose globals receive `__RUST_ACTIVE_FUTURE_ID`.
        lua: &'lua Lua,
        // Future id that was registered for this call.
        id: FutureId,
        // Wrapped future; structurally pinned so it can be polled through `project()`.
        #[pin]
        fut: F,
    }
}
266
267impl<F, R> Future for HaproxyFuture<'_, F>
268where
269 F: Future<Output = Result<R>>,
270{
271 type Output = Result<R>;
272
273 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
274 let this = self.project();
275 match this.fut.poll(cx) {
276 Poll::Ready(res) => Poll::Ready(res),
277 Poll::Pending => {
278 let _ = (this.lua.globals()).raw_set("__RUST_ACTIVE_FUTURE_ID", *this.id);
280 Poll::Pending
281 }
282 }
283 }
284}