//! lua-astra 0.47.0
//!
//! 🔥 Blazingly Fast 🔥 runtime environment for Lua
use bytes::Bytes;
use futures::{SinkExt, TryStreamExt};
use mlua::{ExternalError, UserData};
use reqwest_websocket::{CloseCode, Message, WebSocket};

/// Lua userdata wrapping a `reqwest_websocket` [`WebSocket`] connection.
///
/// The Lua-facing methods are registered in this type's `UserData` implementation.
pub struct AstraWebSocket(pub WebSocket);
impl AstraWebSocket {
    /// Converts a Lua value into a binary WebSocket payload.
    ///
    /// Accepted inputs:
    /// * a table, whose sequence part (as yielded by `Table::sequence_values`) is read as
    ///   one byte per element, e.g. `{104, 105}`;
    /// * a Lua string, whose raw bytes are used unchanged (they need not be valid UTF-8).
    ///
    /// # Errors
    ///
    /// Returns a runtime error for any other value type. It also propagates the conversion
    /// error if a table element cannot be converted to `u8`.
    fn value_to_bytes(value: &mlua::Value) -> Result<Bytes, mlua::Error> {
        if let Some(table) = value.as_table() {
            // Propagate bad elements instead of skipping them: dropping one entry would shift
            // every following byte and silently corrupt the payload.
            let bytes = table
                .sequence_values::<u8>()
                .collect::<mlua::Result<Vec<u8>>>()?;
            Ok(Bytes::from(bytes))
        } else if let Some(string) = value.as_string() {
            // Lua strings are byte strings. Copy the raw bytes rather than converting through
            // a Rust `String`, which would reject non-UTF-8 (i.e. genuinely binary) data.
            Ok(Bytes::from(string.as_bytes().to_vec()))
        } else {
            Err(mlua::Error::runtime("type cannot be accepted as bytes"))
        }
    }
}
impl UserData for AstraWebSocket {
    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
        methods.add_async_method_mut("recv", |lua, mut this, ()| async move {
            match this.0.try_next().await {
                Ok(msg) => match msg {
                    Some(msg) => {
                        let recv = lua.create_table()?;
                        match msg {
                            Message::Text(utf8_bytes) => {
                                recv.set("type", "text")?;
                                recv.set("value", utf8_bytes.to_string())?;
                            }
                            Message::Binary(bytes) => {
                                recv.set("type", "bytes")?;
                                recv.set("value", bytes.to_vec())?;
                            }
                            Message::Close { code, reason } => {
                                recv.set("type", "close")?;
                                let close_frame = lua.create_table()?;
                                close_frame.set("code", code.to_string())?;
                                close_frame.set("reason", reason)?;
                                recv.set("value", close_frame)?;
                            }
                            _ => {}
                        };

                        Ok(recv)
                    }
                    None => Err(mlua::Error::runtime("failed to receive a frame")),
                },
                Err(e) => Err(e.into_lua_err()),
            }
        });

        methods.add_async_method_mut(
            "send",
            |_, mut this, (message_type, message): (String, mlua::Value)| async move {
                let msg = match message_type.to_lowercase().as_str() {
                    "text" => Ok(Message::Text(
                        if let Some(table_message) = message.as_table() {
                            serde_json::to_string(&table_message.clone())
                                .map_err(|e| e.into_lua_err())?
                        } else if let Some(string_message) = message.as_string() {
                            string_message.to_string_lossy()
                        } else {
                            message.to_string()?
                        },
                    )),
                    "bytes" => Ok(Message::Binary(Self::value_to_bytes(&message)?)),
                    "close" => match message {
                        mlua::Value::Integer(close_code) => Ok(Message::Close {
                            code: CloseCode::from(u16::try_from(close_code).unwrap_or(1006)),
                            reason: String::new(),
                        }),
                        mlua::Value::Table(table) => Ok(Message::Close {
                            code: CloseCode::from(table.get::<u16>(1).unwrap_or(1005)),
                            reason: table.get::<String>(2).unwrap_or("".to_string()),
                        }),
                        _ => Ok(Message::Close {
                            code: CloseCode::Normal,
                            reason: String::new(),
                        }),
                    },
                    _ => Err(mlua::Error::runtime("invalid message type")),
                };

                match msg {
                    Ok(msg) => match this.0.send(msg).await {
                        Ok(_) => Ok(()),
                        Err(e) => Err(e.into_lua_err()),
                    },
                    Err(e) => Err(e.into_lua_err()),
                }
            },
        );

        methods.add_async_method_mut("send_text", |_, mut this, message: String| async move {
            match this.0.send(Message::Text(message)).await {
                Ok(_) => Ok(()),
                Err(e) => Err(e.into_lua_err()),
            }
        });

        methods.add_async_method_mut("send_bytes", |_, mut this, bytes: mlua::Value| async move {
            match this
                .0
                .send(Message::Binary(Self::value_to_bytes(&bytes)?))
                .await
            {
                Ok(_) => Ok(()),
                Err(e) => Err(e.into_lua_err()),
            }
        });
        methods.add_async_method_mut(
            "send_close",
            |_, mut this, close_frame: Option<mlua::Value>| async move {
                let close_frame: Message = match close_frame {
                    Some(frame) => match frame {
                        mlua::Value::Integer(close_code) => Message::Close {
                            code: CloseCode::from(u16::try_from(close_code).unwrap_or(1006)),
                            reason: String::new(),
                        },
                        mlua::Value::Table(table) => Message::Close {
                            code: CloseCode::from(table.get::<u16>(1).unwrap_or(1005)),
                            reason: table.get::<String>(2).unwrap_or("".to_string()),
                        },
                        _ => Message::Close {
                            code: CloseCode::Normal,
                            reason: String::new(),
                        },
                    },
                    None => Message::Close {
                        code: CloseCode::Normal,
                        reason: String::new(),
                    },
                };

                match this.0.send(close_frame).await {
                    Ok(_) => Ok(()),
                    Err(e) => Err(e.into_lua_err()),
                }
            },
        );
    }
}