lua-protobuf-rs 0.3.0

lua protobuf
Documentation
use std::collections::HashSet;
use std::ops::Deref;

use anyhow::{Context, anyhow};
use mlua::prelude::LuaValue;
use mlua::{Integer, Lua, Number, Table, Value};
use protobuf::MessageDyn;
use protobuf::reflect::{
    MessageDescriptor, ReflectValueBox, ReflectValueRef, RuntimeFieldType, RuntimeType,
};

/// Converts between Lua tables and dynamic (reflection-based) protobuf messages.
///
/// The codec carries no state: it is a `Copy` unit struct, so it can be created
/// with [`Default`] wherever it is needed.
#[derive(Copy, Clone, Debug, Default)]
pub struct LuaProtoCodec;

impl LuaProtoCodec {
    /// Encodes a Lua table into a new dynamic message described by `descriptor`.
    ///
    /// Every key of `lua_message` must be a string naming a field of the message.
    /// - Singular fields take a single value, converted by [`Self::box_value`].
    /// - Repeated fields take a sequence table, read via `sequence_values`.
    /// - Map fields take a table whose keys and values are converted individually.
    ///
    /// Fields missing from the table are left unset.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - a key is not a string;
    /// - a key names a field the message does not have;
    /// - a repeated or map field is not given a table;
    /// - a value cannot be converted to the field's protobuf type.
    pub fn encode_message(
        &self,
        lua_message: &Table,
        descriptor: &MessageDescriptor,
    ) -> anyhow::Result<Box<dyn MessageDyn>> {
        let name = descriptor.name();
        let mut message = descriptor.new_instance();
        for pair in lua_message.pairs::<Value, Value>() {
            let (field_key, field_value) = pair?;
            // Errors are built lazily (`ok_or_else`) so successful fields never allocate one.
            let field_key = field_key
                .as_string()
                .ok_or_else(|| anyhow!("message {} expect a string key", name))?
                .to_str()?
                .to_string();
            let field_descriptor = descriptor
                .field_by_name(&field_key)
                .ok_or_else(|| anyhow!("field {} not found in message {}", field_key, name))?;
            match field_descriptor.runtime_field_type() {
                RuntimeFieldType::Singular(ty) => {
                    let boxed_value = self.box_value(name, &field_key, &ty, field_value)?;
                    field_descriptor.set_singular_field(message.as_mut(), boxed_value);
                }
                RuntimeFieldType::Repeated(ty) => {
                    let table = field_value.as_table().ok_or_else(|| {
                        anyhow!("message {} field {} expect a table", name, field_key)
                    })?;
                    let mut field_repeated = field_descriptor.mut_repeated(message.as_mut());
                    for v in table.sequence_values::<Value>() {
                        let boxed_value = self.box_value(name, &field_key, &ty, v?)?;
                        field_repeated.push(boxed_value);
                    }
                }
                RuntimeFieldType::Map(k_ty, v_ty) => {
                    let table = field_value.as_table().ok_or_else(|| {
                        anyhow!("message {} field {} expect a table", name, field_key)
                    })?;
                    let mut field_map = field_descriptor.mut_map(message.as_mut());
                    for pair in table.pairs::<Value, Value>() {
                        let (key, value) = pair?;
                        let key = self.box_value(name, &field_key, &k_ty, key)?;
                        let value = self.box_value(name, &field_key, &v_ty, value)?;
                        field_map.insert(key, value);
                    }
                }
            }
        }
        Ok(message)
    }

    /// Decodes a dynamic protobuf message into a new Lua table keyed by field name.
    ///
    /// - Singular fields that belong to a oneof are emitted only when set.
    /// - Other singular fields are always emitted, falling back to the field's
    ///   default value when unset.
    /// - An unset message-typed field is expanded from its type's default instance.
    ///   The exception is when that type is already being expanded from defaults
    ///   further up; the field is then left `nil`, because expanding it again would
    ///   never terminate for recursive message types.
    /// - Repeated fields become sequence tables and map fields become tables.
    ///
    /// # Errors
    ///
    /// Returns an error if a Lua value cannot be created or a field value cannot be
    /// represented in Lua (see [`Self::unbox_value`]).
    pub fn decode_message(&self, lua: &Lua, message: &dyn MessageDyn) -> anyhow::Result<Table> {
        self.decode_message_guarded(lua, message, &mut Vec::new())
    }

    /// Implementation of [`Self::decode_message`] with a recursion guard.
    ///
    /// `defaulted` holds the full names of the message types currently being
    /// expanded from default instances, outermost first. An unset message field
    /// whose type is already on this path is skipped rather than expanded again.
    fn decode_message_guarded(
        &self,
        lua: &Lua,
        message: &dyn MessageDyn,
        defaulted: &mut Vec<String>,
    ) -> anyhow::Result<Table> {
        let lua_message = lua.create_table()?;
        let descriptor = message.descriptor_dyn();
        let message_name = descriptor.name();
        let mut oneof_fields = HashSet::new();
        for oneof_descriptor in descriptor.oneofs() {
            for field in oneof_descriptor.fields() {
                oneof_fields.insert(field.name().to_string());
            }
        }
        for field in descriptor.fields() {
            let field_name = field.name();
            match field.runtime_field_type() {
                RuntimeFieldType::Singular(_) => {
                    if let Some(value) = field.get_singular(message) {
                        let lua_value = self.unbox_value(message_name, field_name, value, lua)?;
                        lua_message.set(field_name, lua_value)?;
                    } else if !oneof_fields.contains(field_name) {
                        // Unset field outside a oneof: emit its default value.
                        match field.get_singular_field_or_default(message) {
                            ReflectValueRef::Message(default_message) => {
                                let type_name =
                                    default_message.descriptor_dyn().full_name().to_string();
                                // Re-expanding a type that is already being expanded from
                                // defaults would recurse forever (e.g. `Node { Node next }`).
                                if !defaulted.contains(&type_name) {
                                    defaulted.push(type_name);
                                    let decoded = self.decode_message_guarded(
                                        lua,
                                        default_message.deref(),
                                        defaulted,
                                    );
                                    defaulted.pop();
                                    lua_message.set(field_name, decoded?)?;
                                }
                            }
                            default_value => {
                                let lua_value = self.unbox_value(
                                    message_name,
                                    field_name,
                                    default_value,
                                    lua,
                                )?;
                                lua_message.set(field_name, lua_value)?;
                            }
                        }
                    }
                }
                RuntimeFieldType::Repeated(_) => {
                    let field_table = lua.create_table()?;
                    for value in field.get_repeated(message) {
                        let v = self.unbox_value(message_name, field_name, value, lua)?;
                        field_table.push(v)?;
                    }
                    lua_message.set(field_name, field_table)?;
                }
                RuntimeFieldType::Map(_, _) => {
                    let field_table = lua.create_table()?;
                    for (k, v) in field.get_map(message).into_iter() {
                        let k = self.unbox_value(message_name, field_name, k, lua)?;
                        let v = self.unbox_value(message_name, field_name, v, lua)?;
                        field_table.set(k, v)?;
                    }
                    lua_message.set(field_name, field_table)?;
                }
            }
        }
        Ok(lua_message)
    }

    /// Converts a Lua value into a boxed protobuf value of runtime type `ty`.
    ///
    /// - Integer and bool fields use the matching `Value::as_*` accessor.
    /// - Float fields accept Lua numbers and Lua integers.
    /// - String fields expect a Lua string.
    /// - Bytes fields expect a sequence table of `u8` values.
    /// - Enum fields expect an integer that the enum defines.
    /// - Message fields expect a table, which is encoded with [`Self::encode_message`].
    ///
    /// `name` (message name) and `field` (field name) are only used in error messages.
    ///
    /// # Errors
    ///
    /// Returns an error if `value` cannot be converted to `ty`.
    pub fn box_value(
        &self,
        name: &str,
        field: &str,
        ty: &RuntimeType,
        value: Value,
    ) -> anyhow::Result<ReflectValueBox> {
        let value_ty = self.fmt_value(&value);
        // Built on demand only, so successful conversions never allocate an error.
        let cast_error = |target: &str| {
            anyhow!(
                "message {} field {} value {} cannot be cast to {}",
                name,
                field,
                value_ty,
                target
            )
        };
        let value_box = match ty {
            RuntimeType::I32 => {
                ReflectValueBox::I32(value.as_i32().ok_or_else(|| cast_error("i32"))?)
            }
            RuntimeType::I64 => {
                ReflectValueBox::I64(value.as_i64().ok_or_else(|| cast_error("i64"))?)
            }
            RuntimeType::U32 => {
                ReflectValueBox::U32(value.as_u32().ok_or_else(|| cast_error("u32"))?)
            }
            RuntimeType::U64 => {
                ReflectValueBox::U64(value.as_u64().ok_or_else(|| cast_error("u64"))?)
            }
            RuntimeType::F32 => {
                // A number without a fractional part may be a Lua integer (e.g. the
                // literal `1` in Lua 5.3+); accept it for float fields too.
                let number = match &value {
                    Value::Integer(i) => Some(*i as f32),
                    other => other.as_f32(),
                };
                ReflectValueBox::F32(number.ok_or_else(|| cast_error("f32"))?)
            }
            RuntimeType::F64 => {
                let number = match &value {
                    Value::Integer(i) => Some(*i as f64),
                    other => other.as_f64(),
                };
                ReflectValueBox::F64(number.ok_or_else(|| cast_error("f64"))?)
            }
            RuntimeType::Bool => {
                ReflectValueBox::Bool(value.as_boolean().ok_or_else(|| cast_error("bool"))?)
            }
            RuntimeType::String => {
                let text = value
                    .as_string()
                    .ok_or_else(|| cast_error("string"))?
                    .to_str()?
                    .to_string();
                ReflectValueBox::String(text)
            }
            RuntimeType::VecU8 => {
                let table = value.as_table().ok_or_else(|| cast_error("table"))?;
                // `Table::len` returns a signed integer; never turn a negative value
                // into an enormous preallocation.
                let capacity = usize::try_from(table.len()?).unwrap_or(0);
                let mut bytes = Vec::with_capacity(capacity);
                for byte in table.sequence_values::<u8>() {
                    let byte = anyhow::Context::with_context(byte, || {
                        format!("message {} field {} expect u8 table", name, field)
                    })?;
                    bytes.push(byte);
                }
                ReflectValueBox::Bytes(bytes)
            }
            RuntimeType::Enum(descriptor) => {
                let number = value.as_i32().ok_or_else(|| cast_error("i32"))?;
                if descriptor.value_by_number(number).is_none() {
                    return Err(anyhow!(
                        "incorrect number {} of enum {}",
                        number,
                        descriptor.name()
                    ));
                }
                ReflectValueBox::Enum(descriptor.clone(), number)
            }
            RuntimeType::Message(descriptor) => {
                let table = value.as_table().ok_or_else(|| cast_error("table"))?;
                ReflectValueBox::Message(self.encode_message(table, descriptor)?)
            }
        };
        Ok(value_box)
    }

    /// Converts a single protobuf value into a Lua value.
    ///
    /// - Integers and enum numbers become Lua integers.
    /// - Floats become Lua numbers.
    /// - Strings become Lua strings.
    /// - Bytes become sequence tables of byte values.
    /// - Messages are decoded with [`Self::decode_message`].
    ///
    /// `message_name` and `field_name` are only used in error messages.
    ///
    /// # Errors
    ///
    /// Returns an error if a `u64` value does not fit in a Lua integer, or if a Lua
    /// value cannot be created.
    pub fn unbox_value(
        &self,
        message_name: &str,
        field_name: &str,
        value: ReflectValueRef,
        lua: &Lua,
    ) -> anyhow::Result<LuaValue> {
        let lua_value = match value {
            ReflectValueRef::U32(u) => Value::Integer(Integer::from(u)),
            ReflectValueRef::U64(u) => {
                // A Lua integer holds any i64 (see the I64 arm), so every u64 up to
                // i64::MAX is representable; only larger values are rejected.
                let i = Integer::try_from(u).with_context(|| {
                    format!(
                        "message {} field {} cannot cast u64 value {} to lua integer",
                        message_name, field_name, u
                    )
                })?;
                Value::Integer(i)
            }
            ReflectValueRef::I32(i) => Value::Integer(Integer::from(i)),
            ReflectValueRef::I64(i) => Value::Integer(Integer::from(i)),
            ReflectValueRef::F32(f) => Value::Number(Number::from(f)),
            ReflectValueRef::F64(f) => Value::Number(Number::from(f)),
            ReflectValueRef::Bool(b) => Value::Boolean(b),
            ReflectValueRef::String(s) => Value::String(lua.create_string(s)?),
            ReflectValueRef::Bytes(bytes) => {
                let table = lua.create_table()?;
                for byte in bytes {
                    table.push(*byte)?;
                }
                Value::Table(table)
            }
            ReflectValueRef::Enum(_, i) => Value::Integer(Integer::from(i)),
            ReflectValueRef::Message(m) => Value::Table(self.decode_message(lua, m.deref())?),
        };
        Ok(lua_value)
    }

    /// Returns the name of the Lua value's variant; used in conversion error messages.
    fn fmt_value(&self, value: &Value) -> &'static str {
        match value {
            Value::Nil => "Nil",
            Value::Boolean(_) => "Boolean",
            Value::LightUserData(_) => "LightUserData",
            Value::Integer(_) => "Integer",
            Value::Number(_) => "Number",
            Value::String(_) => "String",
            Value::Table(_) => "Table",
            Value::Function(_) => "Function",
            Value::Thread(_) => "Thread",
            Value::UserData(_) => "UserData",
            Value::Error(_) => "Error",
            #[cfg(any(feature = "luau", doc))]
            Value::Vector(_) => "Vector",
            #[cfg(any(feature = "luau", doc))]
            Value::Buffer(_) => "Buffer",
            Value::Other(_) => "Other",
        }
    }
}