assay-lua 0.11.2

General-purpose enhanced Lua runtime. Batteries-included scripting, automation, and web services.
Documentation
use futures_util::{SinkExt, StreamExt};
use mlua::{Lua, UserData, Value};
use std::rc::Rc;
use tokio_tungstenite::MaybeTlsStream;

/// The concrete WebSocket transport returned by `connect_async`: a TCP stream
/// (optionally wrapped in TLS) framed by tungstenite.
type WsTransport = tokio_tungstenite::WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;

/// Shared, async-locked write half of a WebSocket connection.
type WsSink = Rc<
    tokio::sync::Mutex<
        futures_util::stream::SplitSink<WsTransport, tokio_tungstenite::tungstenite::Message>,
    >,
>;

/// Shared, async-locked read half of a WebSocket connection.
type WsStream = Rc<tokio::sync::Mutex<futures_util::stream::SplitStream<WsTransport>>>;

/// Lua-side handle for one WebSocket connection.
///
/// The connection is held as two independently locked halves. The
/// `ws.send`/`ws.close` path locks `sink` and the `ws.recv` path locks `stream`,
/// so neither path waits on the other's lock.
struct WsConn {
    /// Write half of the connection.
    sink: WsSink,
    /// Read half of the connection.
    stream: WsStream,
}
// Empty impl: no Lua-visible methods or fields are defined here. As far as this
// file shows, the handle is only passed back into the `ws.*` functions.
impl UserData for WsConn {}

fn extract_ws_conn(val: &Value, fn_name: &str) -> mlua::Result<(WsSink, WsStream)> {
    let ud = match val {
        Value::UserData(ud) => ud,
        _ => {
            return Err(mlua::Error::runtime(format!(
                "{fn_name}: first argument must be a ws connection"
            )));
        }
    };
    let ws = ud.borrow::<WsConn>().map_err(|_| {
        mlua::Error::runtime(format!("{fn_name}: first argument must be a ws connection"))
    })?;
    Ok((ws.sink.clone(), ws.stream.clone()))
}

pub fn register_ws(lua: &Lua) -> mlua::Result<()> {
    let ws_table = lua.create_table()?;

    let connect_fn = lua.create_async_function(|lua, url: String| async move {
        let (stream, _response) = tokio_tungstenite::connect_async(&url)
            .await
            .map_err(|e| mlua::Error::runtime(format!("ws.connect: {e}")))?;
        let (sink, read) = stream.split();
        lua.create_any_userdata(WsConn {
            sink: Rc::new(tokio::sync::Mutex::new(sink)),
            stream: Rc::new(tokio::sync::Mutex::new(read)),
        })
    })?;
    ws_table.set("connect", connect_fn)?;

    let send_fn = lua.create_async_function(|_, (conn, msg): (Value, String)| async move {
        let (sink, _stream) = extract_ws_conn(&conn, "ws.send")?;
        sink.lock()
            .await
            .send(tokio_tungstenite::tungstenite::Message::Text(msg.into()))
            .await
            .map_err(|e| mlua::Error::runtime(format!("ws.send: {e}")))?;
        Ok(())
    })?;
    ws_table.set("send", send_fn)?;

    let recv_fn = lua.create_async_function(|_, conn: Value| async move {
        let (_sink, stream) = extract_ws_conn(&conn, "ws.recv")?;
        loop {
            let msg = stream
                .lock()
                .await
                .next()
                .await
                .ok_or_else(|| mlua::Error::runtime("ws.recv: connection closed"))?
                .map_err(|e| mlua::Error::runtime(format!("ws.recv: {e}")))?;
            match msg {
                tokio_tungstenite::tungstenite::Message::Text(t) => {
                    return Ok(t.to_string());
                }
                tokio_tungstenite::tungstenite::Message::Binary(b) => {
                    return String::from_utf8(b.into())
                        .map_err(|e| mlua::Error::runtime(format!("ws.recv: invalid UTF-8: {e}")));
                }
                tokio_tungstenite::tungstenite::Message::Close(_) => {
                    return Err(mlua::Error::runtime("ws.recv: connection closed"));
                }
                _ => continue,
            }
        }
    })?;
    ws_table.set("recv", recv_fn)?;

    let close_fn = lua.create_async_function(|_, conn: Value| async move {
        let (sink, _stream) = extract_ws_conn(&conn, "ws.close")?;
        sink.lock()
            .await
            .close()
            .await
            .map_err(|e| mlua::Error::runtime(format!("ws.close: {e}")))?;
        Ok(())
    })?;
    ws_table.set("close", close_fn)?;

    lua.globals().set("ws", ws_table)?;
    Ok(())
}