// lua-astra 0.47.0
//
// 🔥 Blazingly Fast 🔥 runtime environment for Lua
use axum::extract::ws::{CloseFrame, Message, Utf8Bytes, WebSocket};
use bytes::Bytes;
use mlua::{ExternalError, UserData};

pub struct AstraWebSocket(pub WebSocket);
impl AstraWebSocket {
    fn value_to_bytes(value: &mlua::Value) -> Result<Bytes, mlua::Error> {
        if let Some(table) = value.as_table() {
            Ok(Bytes::from_iter(
                table.sequence_values::<u8>().filter_map(|x| x.ok()),
            ))
        } else if value.is_string() {
            Ok(Bytes::from(value.to_string()?))
        } else {
            Err(mlua::Error::runtime("type cannot be accepted as bytes"))
        }
    }
}
impl UserData for AstraWebSocket {
    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
        methods.add_async_method_mut("recv", |lua, mut this, ()| async move {
            match this.0.recv().await {
                Some(msg) => match msg {
                    Ok(msg) => {
                        let recv = lua.create_table()?;
                        match msg {
                            Message::Text(utf8_bytes) => {
                                recv.set("type", "text")?;
                                recv.set("value", utf8_bytes.to_string())?;
                            }
                            Message::Binary(bytes) => {
                                recv.set("type", "bytes")?;
                                recv.set("value", bytes.to_vec())?;
                            }
                            Message::Close(close_frame) => match close_frame {
                                Some(frame) => {
                                    recv.set("type", "close")?;
                                    let close_frame = lua.create_table()?;
                                    close_frame.set("code", frame.code)?;
                                    close_frame.set("reason", frame.reason.to_string())?;
                                    recv.set("value", close_frame)?;
                                }
                                None => {
                                    recv.set("type", "Close")?;
                                    recv.set("value", mlua::Value::Nil)?;
                                }
                            },
                            _ => {}
                        };

                        Ok(recv)
                    }
                    Err(e) => Err(mlua::Error::runtime(format!(
                        "failed to receive a frame: {e}",
                    ))),
                },
                None => Err(mlua::Error::runtime("No message received!")),
            }
        });

        methods.add_async_method_mut(
            "send",
            |_, mut this, (message_type, message): (String, mlua::Value)| async move {
                let msg = match message_type.to_lowercase().as_str() {
                    "text" => Ok(Message::Text(Utf8Bytes::from(
                        if let Some(table_message) = message.as_table() {
                            serde_json::to_string(&table_message.clone())
                                .map_err(|e| e.into_lua_err())?
                        } else if let Some(string_message) = message.as_string() {
                            string_message.to_string_lossy()
                        } else {
                            message.to_string()?
                        },
                    ))),
                    "bytes" => Ok(Message::Binary(Self::value_to_bytes(&message)?)),
                    "close" => match message {
                        mlua::Value::Integer(close_code) => Ok(Message::Close(Some(CloseFrame {
                            code: u16::try_from(close_code).unwrap_or(1006),
                            reason: Utf8Bytes::from_static(""),
                        }))),
                        mlua::Value::Table(table) => Ok(Message::Close(Some(CloseFrame {
                            code: table.get::<u16>(1).unwrap_or(1005),
                            reason: Utf8Bytes::from(
                                table.get::<String>(2).unwrap_or("".to_string()),
                            ),
                        }))),
                        _ => Ok(Message::Close(None)),
                    },
                    _ => Err(mlua::Error::runtime("invalid message type")),
                };

                match msg {
                    Ok(msg) => match this.0.send(msg).await {
                        Ok(_) => Ok(()),
                        Err(e) => Err(e.into_lua_err()),
                    },
                    Err(e) => Err(e.into_lua_err()),
                }
            },
        );

        methods.add_async_method_mut("send_text", |_, mut this, message: String| async move {
            match this.0.send(Message::Text(Utf8Bytes::from(message))).await {
                Ok(_) => Ok(()),
                Err(e) => Err(e.into_lua_err()),
            }
        });

        methods.add_async_method_mut("send_bytes", |_, mut this, bytes: mlua::Value| async move {
            match this
                .0
                .send(Message::Binary(Self::value_to_bytes(&bytes)?))
                .await
            {
                Ok(_) => Ok(()),
                Err(e) => Err(e.into_lua_err()),
            }
        });

        methods.add_async_method_mut("send_ping", |_, mut this, bytes: mlua::Value| async move {
            match this
                .0
                .send(Message::Ping(Self::value_to_bytes(&bytes)?))
                .await
            {
                Ok(_) => Ok(()),
                Err(e) => Err(e.into_lua_err()),
            }
        });

        methods.add_async_method_mut("send_pong", |_, mut this, bytes: mlua::Value| async move {
            match this
                .0
                .send(Message::Pong(Self::value_to_bytes(&bytes)?))
                .await
            {
                Ok(_) => Ok(()),
                Err(e) => Err(e.into_lua_err()),
            }
        });

        methods.add_async_method_mut(
            "send_close",
            |_, mut this, close_frame: Option<mlua::Value>| async move {
                let close_frame: Message = match close_frame {
                    Some(frame) => match frame {
                        mlua::Value::Integer(close_code) => Message::Close(Some(CloseFrame {
                            code: u16::try_from(close_code).unwrap_or(1006),
                            reason: Utf8Bytes::from_static(""),
                        })),
                        mlua::Value::Table(table) => Message::Close(Some(CloseFrame {
                            code: table.get::<u16>(1).unwrap_or(1005),
                            reason: Utf8Bytes::from(
                                table.get::<String>(2).unwrap_or("".to_string()),
                            ),
                        })),
                        _ => Message::Close(None),
                    },
                    None => Message::Close(None),
                };

                match this.0.send(close_frame).await {
                    Ok(_) => Ok(()),
                    Err(e) => Err(e.into_lua_err()),
                }
            },
        );
    }
}