haproxy_api/
async.rs

1use std::future::{self, Future};
2use std::net::TcpListener as StdTcpListener;
3use std::pin::Pin;
4use std::sync::atomic::{AtomicU16, Ordering};
5use std::sync::OnceLock;
6use std::task::{Context, Poll};
7
8use dashmap::DashMap;
9use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
10use tokio::net::TcpListener;
11use tokio::runtime;
12use tokio::sync::oneshot::{self, Receiver};
13
14use futures_util::future::Either;
15use mlua::{
16    ExternalResult, FromLuaMulti, Function, IntoLuaMulti, Lua, RegistryKey, Result, Table,
17    UserData, UserDataMethods, Value,
18};
19use rustc_hash::FxBuildHasher;
20
/// Identifier tying an in-flight async call to its readiness receiver in `FUTURE_RX_MAP`.
///
/// Using `u16` caps the id space, and therefore the number of stored receivers, at 65536.
/// If for any reason a receiver is not picked up by the notification listener, it is
/// overwritten (and its memory released) when the wrapping id counter reuses its id.
type FutureId = u16;

/// Maximum number of *idle* connections to the notification server kept in the connection
/// pool that each Lua state caches in its registry. Connections handed back while the pool is
/// full are closed by the Lua side instead of being pooled, so this does not cap the number of
/// simultaneously open connections.
const PER_WORKER_POOL_SIZE: usize = 512;

/// Links a future id to the receiver that resolves once the corresponding spawned task ends.
///
/// The notification server removes the receiver and awaits it before replying `READY`.
static FUTURE_RX_MAP: OnceLock<DashMap<FutureId, Receiver<()>, FxBuildHasher>> = OnceLock::new();
31
32/// Returns the global tokio runtime.
33pub fn runtime() -> &'static runtime::Runtime {
34    static RUNTIME: OnceLock<runtime::Runtime> = OnceLock::new();
35    RUNTIME.get_or_init(|| {
36        runtime::Builder::new_multi_thread()
37            .enable_all()
38            .build()
39            .expect("failed to create tokio runtime")
40    })
41}
42
/// Returns the loopback port reserved for the notification server, choosing it on first call.
///
/// The port is found by binding a probe listener to `127.0.0.1:0` and reading back the port
/// the OS assigned. The probe is dropped right away and the notification server binds the
/// port again later, so in principle another process could grab it in between.
///
/// # Panics
///
/// Panics if the probe socket cannot be bound or its local address cannot be read.
fn get_notification_port() -> u16 {
    static NOTIFICATION_PORT: OnceLock<u16> = OnceLock::new();
    *NOTIFICATION_PORT.get_or_init(|| {
        let probe = StdTcpListener::bind("127.0.0.1:0").expect("failed to bind to a local port");
        let addr = probe.local_addr().expect("failed to get local address");
        addr.port()
    })
}
54
55fn get_rx_by_future_id(future_id: FutureId) -> Option<Receiver<()>> {
56    FUTURE_RX_MAP.get()?.remove(&future_id).map(|(_, rx)| rx)
57}
58
59fn set_rx_by_future_id(future_id: FutureId, rx: Receiver<()>) {
60    FUTURE_RX_MAP
61        .get_or_init(|| DashMap::with_capacity_and_hasher(256, FxBuildHasher))
62        .insert(future_id, rx);
63}
64
65// Returns a next future id (and starts the notification task if it's not running yet)
66fn get_future_id() -> FutureId {
67    static WATCHER: OnceLock<()> = OnceLock::new();
68    WATCHER.get_or_init(|| {
69        let port = get_notification_port();
70
71        // Spawn notification task (it responds to subscribe requests and signal when the future is ready)
72        runtime().spawn(async move {
73            let listener = TcpListener::bind(("127.0.0.1", port))
74                .await
75                .unwrap_or_else(|err| panic!("failed to bind to a port {port}: {err}"));
76
77            while let Ok((mut stream, _)) = listener.accept().await {
78                tokio::task::spawn(async move {
79                    let (reader, mut writer) = stream.split();
80                    let reader = BufReader::new(reader);
81                    let mut lines = reader.lines();
82                    // Read future id from the stream and wait for the future to be ready
83                    while let Ok(Some(line)) = lines.next_line().await {
84                        let line = line.trim();
85                        if line == "PING" {
86                            if writer.write_all(b"PONG\n").await.is_err() {
87                                break;
88                            }
89                            continue;
90                        }
91                        if let Ok(future_id) = line.parse::<FutureId>() {
92                            // Wait for the future to be ready before sending the signal
93                            let resp: &[u8] = match get_rx_by_future_id(future_id) {
94                                Some(rx) => {
95                                    _ = rx.await;
96                                    b"READY\n"
97                                }
98                                None => b"ERR\n",
99                            };
100                            if writer.write_all(resp).await.is_err() {
101                                break;
102                            }
103                        }
104                    }
105                });
106            }
107        });
108    });
109
110    // Future id generator
111    static NEXT_ID: AtomicU16 = AtomicU16::new(1);
112    NEXT_ID.fetch_add(1, Ordering::Relaxed)
113}
114
115/// Creates a new async function that can be used in HAProxy configuration.
116///
117/// Tokio runtime is automatically configured to use multiple threads.
118pub fn create_async_function<'lua, A, R, F, FR>(lua: &'lua Lua, func: F) -> Result<Function<'lua>>
119where
120    A: FromLuaMulti<'lua> + 'static,
121    R: IntoLuaMulti<'lua> + Send + 'static,
122    F: Fn(A) -> FR + 'static,
123    FR: Future<Output = Result<R>> + Send + 'static,
124{
125    let port = get_notification_port();
126    let _yield_fixup = YieldFixUp::new(lua, port)?;
127    lua.create_async_function(move |lua, args| {
128        // New future id must be generated on each invocation
129        let future_id = get_future_id();
130
131        // Spawn the future in background
132        let _guard = runtime().enter();
133        let args = match A::from_lua_multi(args, lua) {
134            Ok(args) => args,
135            Err(err) => return Either::Left(future::ready(Err(err))),
136        };
137        let (tx, rx) = oneshot::channel();
138        set_rx_by_future_id(future_id, rx);
139        let fut = func(args);
140        let result = tokio::task::spawn(async move {
141            let result = fut.await;
142            // Signal that the future is ready
143            let _ = tx.send(());
144            result
145        });
146
147        Either::Right(HaproxyFuture {
148            lua,
149            id: future_id,
150            fut: async move { result.await.into_lua_err()? },
151        })
152    })
153}
154
/// Guard that keeps a replacement `coroutine.yield` installed and puts back the saved original
/// (the `Function` field) in the global `coroutine` table when dropped.
struct YieldFixUp<'lua>(&'lua Lua, Function<'lua>);
156
157impl<'lua> YieldFixUp<'lua> {
158    fn new(lua: &'lua Lua, port: u16) -> Result<Self> {
159        let connection_pool =
160            match lua.named_registry_value::<Value>("__HAPROXY_CONNECTION_POOL")? {
161                Value::Nil => {
162                    let connection_pool = ObjectPool::new(PER_WORKER_POOL_SIZE)?;
163                    let connection_pool = lua.create_userdata(connection_pool)?;
164                    lua.set_named_registry_value("__HAPROXY_CONNECTION_POOL", &connection_pool)?;
165                    Value::UserData(connection_pool)
166                }
167                connection_pool => connection_pool,
168            };
169
170        let coroutine: Table = lua.globals().get("coroutine")?;
171        let orig_yield: Function = coroutine.get("yield")?;
172        let new_yield: Function = lua
173            .load(
174                r#"
175                local port, connection_pool = ...
176                local msleep = core.msleep
177                return function()
178                    -- It's important to cache the future id before first yielding point
179                    local future_id = __RUST_ACTIVE_FUTURE_ID
180                    local ok, err
181
182                    -- Get new or existing connection from the pool
183                    local sock = connection_pool:get()
184                    if not sock then
185                        sock = core.tcp()
186                        ok, err = sock:connect("127.0.0.1", port)
187                        if err ~= nil then
188                            msleep(1)
189                            return
190                        end
191                    end
192
193                    -- Subscribe to the future updates
194                    ok, err = sock:send(future_id .. "\n")
195                    if err ~= nil then
196                        sock:close()
197                        msleep(1)
198                        return
199                    end
200
201                    -- Wait for the future to be ready
202                    ok, err = sock:receive("*l")
203                    if err ~= nil then
204                        sock:close()
205                        msleep(1)
206                        return
207                    end
208                    if ok ~= "READY" then
209                        msleep(1)
210                    end
211
212                    ok = connection_pool:put(sock)
213                    if not ok then
214                        sock:close()
215                    end
216                end
217            "#,
218            )
219            .call((port, connection_pool))?;
220        coroutine.set("yield", new_yield)?;
221        Ok(YieldFixUp(lua, orig_yield))
222    }
223}
224
225impl<'lua> Drop for YieldFixUp<'lua> {
226    fn drop(&mut self) {
227        if let Err(e) = (|| {
228            let coroutine: Table = self.0.globals().get("coroutine")?;
229            coroutine.set("yield", &self.1)
230        })() {
231            println!("Error in YieldFixUp destructor: {}", e);
232        }
233    }
234}
235
236struct ObjectPool(Vec<RegistryKey>);
237
238impl ObjectPool {
239    fn new(capacity: usize) -> Result<Self> {
240        Ok(ObjectPool(Vec::with_capacity(capacity)))
241    }
242}
243
244impl UserData for ObjectPool {
245    fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
246        methods.add_method_mut("get", |_, this, ()| Ok(this.0.pop()));
247
248        methods.add_method_mut("put", |_, this, obj: RegistryKey| {
249            if this.0.len() == PER_WORKER_POOL_SIZE {
250                return Ok(false);
251            }
252            this.0.push(obj);
253            Ok(true)
254        });
255    }
256}
257
// Future handed to mlua for each async call. It polls `fut`, which in `create_async_function`
// awaits the join handle of the spawned tokio task, and carries `id`, the key of that task's
// readiness receiver in `FUTURE_RX_MAP`.
pin_project_lite::pin_project! {
    struct HaproxyFuture<'lua, F> {
        // Lua state whose globals receive `__RUST_ACTIVE_FUTURE_ID` while `fut` is pending.
        lua: &'lua Lua,
        // Id under which the task's readiness receiver was registered.
        id: FutureId,
        // Wrapped future; structurally pinned because it is polled in place.
        #[pin]
        fut: F,
    }
}
266
267impl<F, R> Future for HaproxyFuture<'_, F>
268where
269    F: Future<Output = Result<R>>,
270{
271    type Output = Result<R>;
272
273    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
274        let this = self.project();
275        match this.fut.poll(cx) {
276            Poll::Ready(res) => Poll::Ready(res),
277            Poll::Pending => {
278                // Set the active future id so the mlua async helper will be able to wait on it
279                let _ = (this.lua.globals()).raw_set("__RUST_ACTIVE_FUTURE_ID", *this.id);
280                Poll::Pending
281            }
282        }
283    }
284}