haproxy_api/
async.rs

1use std::future::{self, Future};
2use std::net::TcpListener as StdTcpListener;
3use std::pin::Pin;
4use std::sync::atomic::{AtomicU16, Ordering};
5use std::sync::OnceLock;
6use std::task::{Context, Poll};
7
8use dashmap::DashMap;
9use futures_util::future::Either;
10use mlua::{
11    ExternalResult, FromLuaMulti, Function, IntoLuaMulti, Lua, RegistryKey, Result, Table,
12    UserData, UserDataMethods, Value,
13};
14use rustc_hash::FxBuildHasher;
15use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
16use tokio::net::TcpListener;
17use tokio::runtime;
18use tokio::sync::oneshot::{self, Receiver};
19
20// Using `u16` will give us max 65536 receivers to store.
21// If for any reason future was not picked up by the notification listener,
22// receiver will be overwritten on the counter reset (and memory released).
23type FutureId = u16;
24
25// Number of open connections to the notification server
26const PER_WORKER_POOL_SIZE: usize = 512;
27
28// Link between future id and the corresponding receiver (used to signal when the future is ready)
29static FUTURE_RX_MAP: OnceLock<DashMap<FutureId, Receiver<()>, FxBuildHasher>> = OnceLock::new();
30
31/// Returns the global tokio runtime.
32pub fn runtime() -> &'static runtime::Runtime {
33    static RUNTIME: OnceLock<runtime::Runtime> = OnceLock::new();
34    RUNTIME.get_or_init(|| {
35        runtime::Builder::new_multi_thread()
36            .enable_all()
37            .build()
38            .expect("failed to create tokio runtime")
39    })
40}
41
// Picks a free local port for the notification server and remembers it.
//
// The port is discovered by binding a throwaway listener to port 0 on the loopback
// address and reading back the port that was assigned. The probe listener is dropped
// right away; only the port number is cached, so every call returns the same value.
//
// Panics if the probe listener cannot be bound or its local address cannot be read.
fn get_notification_port() -> u16 {
    static NOTIFICATION_PORT: OnceLock<u16> = OnceLock::new();

    let probe_free_port = || {
        let probe = StdTcpListener::bind("127.0.0.1:0").expect("failed to bind to a local port");
        let addr = probe.local_addr().expect("failed to get local address");
        addr.port()
    };

    *NOTIFICATION_PORT.get_or_init(probe_free_port)
}
53
54fn get_rx_by_future_id(future_id: FutureId) -> Option<Receiver<()>> {
55    FUTURE_RX_MAP.get()?.remove(&future_id).map(|(_, rx)| rx)
56}
57
58fn set_rx_by_future_id(future_id: FutureId, rx: Receiver<()>) {
59    FUTURE_RX_MAP
60        .get_or_init(|| DashMap::with_capacity_and_hasher(256, FxBuildHasher))
61        .insert(future_id, rx);
62}
63
64// Returns a next future id (and starts the notification task if it's not running yet)
65fn get_future_id() -> FutureId {
66    static WATCHER: OnceLock<()> = OnceLock::new();
67    WATCHER.get_or_init(|| {
68        let port = get_notification_port();
69
70        // Spawn notification task (it responds to subscribe requests and signal when the future is ready)
71        runtime().spawn(async move {
72            let listener = TcpListener::bind(("127.0.0.1", port))
73                .await
74                .unwrap_or_else(|err| panic!("failed to bind to a port {port}: {err}"));
75
76            while let Ok((mut stream, _)) = listener.accept().await {
77                tokio::task::spawn(async move {
78                    let (reader, mut writer) = stream.split();
79                    let reader = BufReader::new(reader);
80                    let mut lines = reader.lines();
81                    // Read future id from the stream and wait for the future to be ready
82                    while let Ok(Some(line)) = lines.next_line().await {
83                        let line = line.trim();
84                        if line == "PING" {
85                            if writer.write_all(b"PONG\n").await.is_err() {
86                                break;
87                            }
88                            continue;
89                        }
90                        if let Ok(future_id) = line.parse::<FutureId>() {
91                            // Wait for the future to be ready before sending the signal
92                            let resp: &[u8] = match get_rx_by_future_id(future_id) {
93                                Some(rx) => {
94                                    _ = rx.await;
95                                    b"READY\n"
96                                }
97                                None => b"ERR\n",
98                            };
99                            if writer.write_all(resp).await.is_err() {
100                                break;
101                            }
102                        }
103                    }
104                });
105            }
106        });
107    });
108
109    // Future id generator
110    static NEXT_ID: AtomicU16 = AtomicU16::new(1);
111    NEXT_ID.fetch_add(1, Ordering::Relaxed)
112}
113
114/// Creates a new async function that can be used in HAProxy configuration.
115///
116/// Tokio runtime is automatically configured to use multiple threads.
117pub fn create_async_function<F, A, R, FR>(lua: &Lua, func: F) -> Result<Function>
118where
119    F: Fn(A) -> FR + 'static,
120    A: FromLuaMulti + 'static,
121    R: IntoLuaMulti + Send + 'static,
122    FR: Future<Output = Result<R>> + Send + 'static,
123{
124    let port = get_notification_port();
125    let _yield_fixup = YieldFixUp::new(lua, port)?;
126    lua.create_async_function(move |lua, args| {
127        // New future id must be generated on each invocation
128        let future_id = get_future_id();
129
130        // Spawn the future in background
131        let _guard = runtime().enter();
132        let args = match A::from_lua_multi(args, &lua) {
133            Ok(args) => args,
134            Err(err) => return Either::Left(future::ready(Err(err))),
135        };
136        let (tx, rx) = oneshot::channel();
137        set_rx_by_future_id(future_id, rx);
138        let fut = func(args);
139        let result = tokio::task::spawn(async move {
140            let result = fut.await;
141            // Signal that the future is ready
142            let _ = tx.send(());
143            result
144        });
145
146        Either::Right(HaproxyFuture {
147            lua,
148            id: future_id,
149            fut: async move { result.await.into_lua_err()? },
150        })
151    })
152}
153
154struct YieldFixUp<'lua>(&'lua Lua, Function);
155
156impl<'lua> YieldFixUp<'lua> {
157    fn new(lua: &'lua Lua, port: u16) -> Result<Self> {
158        let connection_pool =
159            match lua.named_registry_value::<Value>("__HAPROXY_CONNECTION_POOL")? {
160                Value::Nil => {
161                    let connection_pool = ObjectPool::new(PER_WORKER_POOL_SIZE)?;
162                    let connection_pool = lua.create_userdata(connection_pool)?;
163                    lua.set_named_registry_value("__HAPROXY_CONNECTION_POOL", &connection_pool)?;
164                    Value::UserData(connection_pool)
165                }
166                connection_pool => connection_pool,
167            };
168
169        let coroutine: Table = lua.globals().get("coroutine")?;
170        let orig_yield: Function = coroutine.get("yield")?;
171        let new_yield: Function = lua
172            .load(
173                r#"
174                local port, connection_pool = ...
175                local msleep = core.msleep
176                return function()
177                    -- It's important to cache the future id before first yielding point
178                    local future_id = __RUST_ACTIVE_FUTURE_ID
179                    local ok, err
180
181                    -- Get new or existing connection from the pool
182                    local sock = connection_pool:get()
183                    if not sock then
184                        sock = core.tcp()
185                        ok, err = sock:connect("127.0.0.1", port)
186                        if err ~= nil then
187                            msleep(1)
188                            return
189                        end
190                    end
191
192                    -- Subscribe to the future updates
193                    ok, err = sock:send(future_id .. "\n")
194                    if err ~= nil then
195                        sock:close()
196                        msleep(1)
197                        return
198                    end
199
200                    -- Wait for the future to be ready
201                    ok, err = sock:receive("*l")
202                    if err ~= nil then
203                        sock:close()
204                        msleep(1)
205                        return
206                    end
207                    if ok ~= "READY" then
208                        msleep(1)
209                    end
210
211                    ok = connection_pool:put(sock)
212                    if not ok then
213                        sock:close()
214                    end
215                end
216            "#,
217            )
218            .call((port, connection_pool))?;
219        coroutine.set("yield", new_yield)?;
220        Ok(YieldFixUp(lua, orig_yield))
221    }
222}
223
224impl<'lua> Drop for YieldFixUp<'lua> {
225    fn drop(&mut self) {
226        if let Err(e) = (|| {
227            let coroutine: Table = self.0.globals().get("coroutine")?;
228            coroutine.set("yield", &self.1)
229        })() {
230            eprintln!("Error in YieldFixUp destructor: {e}");
231        }
232    }
233}
234
235struct ObjectPool(Vec<RegistryKey>);
236
237impl ObjectPool {
238    fn new(capacity: usize) -> Result<Self> {
239        Ok(ObjectPool(Vec::with_capacity(capacity)))
240    }
241}
242
243impl UserData for ObjectPool {
244    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
245        methods.add_method_mut("get", |_, this, ()| Ok(this.0.pop()));
246
247        methods.add_method_mut("put", |_, this, obj: RegistryKey| {
248            if this.0.len() == PER_WORKER_POOL_SIZE {
249                return Ok(false);
250            }
251            this.0.push(obj);
252            Ok(true)
253        });
254    }
255}
256
257pin_project_lite::pin_project! {
258    struct HaproxyFuture<F> {
259        lua: Lua,
260        id: FutureId,
261        #[pin]
262        fut: F,
263    }
264}
265
266impl<F, R> Future for HaproxyFuture<F>
267where
268    F: Future<Output = Result<R>>,
269{
270    type Output = Result<R>;
271
272    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
273        let this = self.project();
274        match this.fut.poll(cx) {
275            Poll::Ready(res) => Poll::Ready(res),
276            Poll::Pending => {
277                // Set the active future id so the mlua async helper will be able to wait on it
278                let _ = (this.lua.globals()).raw_set("__RUST_ACTIVE_FUTURE_ID", *this.id);
279                Poll::Pending
280            }
281        }
282    }
283}