lua_protobuf_rs/
codec.rs

1use std::collections::HashSet;
2use std::ops::Deref;
3
4use anyhow::{anyhow, Context};
5use mlua::prelude::LuaValue;
6use mlua::{Integer, Lua, Number, Table, Value};
7use protobuf::reflect::{MessageDescriptor, ReflectValueBox, ReflectValueRef, RuntimeFieldType, RuntimeType};
8use protobuf::MessageDyn;
9
/// Stateless converter between Lua tables and dynamic protobuf messages.
///
/// The type carries no data, so it is free to copy and can be created with
/// `LuaProtoCodec::default()` or the unit literal `LuaProtoCodec`.
#[derive(Debug, Copy, Clone, Default)]
pub struct LuaProtoCodec;
12
13impl LuaProtoCodec {
14    pub fn encode_message(&self, lua_message: &Table, descriptor: &MessageDescriptor) -> anyhow::Result<Box<dyn MessageDyn>> {
15        let name = descriptor.name();
16        let mut message = descriptor.new_instance();
17        for pair in lua_message.pairs::<Value, Value>() {
18            let (field_key, field_value) = pair?;
19            let field_key = field_key.as_str().ok_or(anyhow!("message {} expect a string key",name))?;
20            let field_descriptor = descriptor.field_by_name(&field_key).ok_or(anyhow!("field {} not found in message {}",field_key,name))?;
21            match field_descriptor.runtime_field_type() {
22                RuntimeFieldType::Singular(ty) => {
23                    let boxed_value = self.box_value(name, &field_key, &ty, field_value)?;
24                    field_descriptor.set_singular_field(message.as_mut(), boxed_value);
25                }
26                RuntimeFieldType::Repeated(ty) => {
27                    let mut field_repeated = field_descriptor.mut_repeated(message.as_mut());
28                    let table = field_value.as_table().ok_or(anyhow!("message {} field {} expect a table",name,field_key))?;
29                    for v in table.sequence_values::<Value>() {
30                        let v = v?;
31                        let boxed_value = self.box_value(name, &field_key, &ty, v)?;
32                        field_repeated.push(boxed_value);
33                    }
34                }
35                RuntimeFieldType::Map(k_ty, v_ty) => {
36                    let mut field_map = field_descriptor.mut_map(message.as_mut());
37                    let table = field_value.as_table().ok_or(anyhow!("message {} field {} expect a table",name,field_key))?;
38                    for pair in table.pairs::<Value, Value>() {
39                        let (key, value) = pair?;
40                        let key = self.box_value(name, &field_key, &k_ty, key)?;
41                        let value = self.box_value(name, &field_key, &v_ty, value)?;
42                        field_map.insert(key, value);
43                    }
44                }
45            }
46        }
47        Ok(message)
48    }
49
50    pub fn decode_message(&self, lua: &Lua, message: &dyn MessageDyn) -> anyhow::Result<Table> {
51        let lua_message = lua.create_table()?;
52        let descriptor = message.descriptor_dyn();
53        let message_name = descriptor.name();
54        let mut oneof_field = HashSet::new();
55        for oneof_descriptor in descriptor.oneofs() {
56            for field in oneof_descriptor.fields() {
57                oneof_field.insert(field.name().to_string());
58            }
59        }
60        for field in descriptor.fields() {
61            let field_name = field.name();
62            match field.runtime_field_type() {
63                RuntimeFieldType::Singular(_) => {
64                    if oneof_field.contains(field_name) {
65                        if let Some(value) = field.get_singular(message) {
66                            let field_table = self.unbox_value(message_name, field_name, value, lua)?;
67                            lua_message.set(field_name, field_table)?;
68                        }
69                    } else {
70                        let value = field.get_singular_field_or_default(message);
71                        let field_table = self.unbox_value(message_name, field_name, value, lua)?;
72                        lua_message.set(field_name, field_table)?;
73                    }
74                }
75                RuntimeFieldType::Repeated(_) => {
76                    let field_table = lua.create_table()?;
77                    let values = field.get_repeated(message);
78                    for value in values {
79                        let v = self.unbox_value(message_name, field_name, value, lua)?;
80                        field_table.push(v)?;
81                    }
82                    lua_message.set(field_name, field_table)?;
83                }
84                RuntimeFieldType::Map(_, _) => {
85                    let field_table = lua.create_table()?;
86                    let maps = field.get_map(message);
87                    for (k, v) in maps.into_iter() {
88                        let k = self.unbox_value(message_name, field_name, k, lua)?;
89                        let v = self.unbox_value(message_name, field_name, v, lua)?;
90                        field_table.set(k, v)?;
91                    }
92                    lua_message.set(field_name, field_table)?;
93                }
94            }
95        }
96        Ok(lua_message)
97    }
98
99    pub fn box_value(&self, name: &str, field: &str, ty: &RuntimeType, value: Value) -> anyhow::Result<ReflectValueBox> {
100        fn value_cast_error(message: &str, field: &str, value: &str, ty: &str) -> anyhow::Error {
101            anyhow!("message {} field {} value {} cannot be cast to {}",message,field,value,ty)
102        }
103        let value_ty = self.fmt_value(&value);
104        let value_box = match ty {
105            RuntimeType::I32 => {
106                let value = value.as_i32().ok_or(value_cast_error(name, field, value_ty, "i32"))?;
107                ReflectValueBox::I32(value)
108            }
109            RuntimeType::I64 => {
110                let value = value.as_i64().ok_or(value_cast_error(name, field, value_ty, "i64"))?;
111                ReflectValueBox::I64(value)
112            }
113            RuntimeType::U32 => {
114                let value = value.as_u32().ok_or(value_cast_error(name, field, value_ty, "u32"))?;
115                ReflectValueBox::U32(value)
116            }
117            RuntimeType::U64 => {
118                let value = value.as_u64().ok_or(value_cast_error(name, field, value_ty, "u64"))?;
119                ReflectValueBox::U64(value)
120            }
121            RuntimeType::F32 => {
122                let value = value.as_f32().ok_or(value_cast_error(name, field, value_ty, "f32"))?;
123                ReflectValueBox::F32(value)
124            }
125            RuntimeType::F64 => {
126                let value = value.as_f64().ok_or(value_cast_error(name, field, value_ty, "f64"))?;
127                ReflectValueBox::F64(value)
128            }
129            RuntimeType::Bool => {
130                let value = value.as_boolean().ok_or(value_cast_error(name, field, value_ty, "bool"))?;
131                ReflectValueBox::Bool(value)
132            }
133            RuntimeType::String => {
134                let value = value.as_string_lossy().ok_or(value_cast_error(name, field, value_ty, "string"))?;
135                ReflectValueBox::String(value.to_string())
136            }
137            RuntimeType::VecU8 => {
138                let table = value.as_table().ok_or(value_cast_error(name, field, value_ty, "table"))?;
139                let len = table.len()?;
140                let mut bytes = Vec::with_capacity(len as usize);
141                for byte in table.sequence_values::<u8>() {
142                    let byte = anyhow::Context::context(byte, format!("message {} field {} expect u8 table", name, field))?;
143                    bytes.push(byte);
144                }
145                ReflectValueBox::Bytes(bytes)
146            }
147            RuntimeType::Enum(descriptor) => {
148                let value = value.as_i32().ok_or(value_cast_error(name, field, &value_ty, "i32"))?;
149                descriptor.value_by_number(value).ok_or(anyhow!("incorrect number of enum {}",descriptor.name()))?;
150                ReflectValueBox::Enum(descriptor.clone(), value)
151            }
152            RuntimeType::Message(descriptor) => {
153                let table = value.as_table().ok_or(value_cast_error(name, field, &value_ty, "i32"))?;
154                let message = self.encode_message(table, descriptor)?;
155                ReflectValueBox::Message(message)
156            }
157        };
158        Ok(value_box)
159    }
160
161    pub fn unbox_value(&self, message_name: &str, field_name: &str, value: ReflectValueRef, lua: &Lua) -> anyhow::Result<LuaValue> {
162        let lua_value = match value {
163            ReflectValueRef::U32(u) => { Value::Integer(Integer::from(u)) }
164            ReflectValueRef::U64(u) => {
165                let u = u32::try_from(u).context(format!("message {} field {} cannot cast u64 value {} to u32", message_name, field_name, u))?;
166                Value::Integer(Integer::from(u))
167            }
168            ReflectValueRef::I32(i) => { Value::Integer(Integer::from(i)) }
169            ReflectValueRef::I64(i) => { Value::Integer(Integer::from(i)) }
170            ReflectValueRef::F32(f) => { Value::Number(Number::from(f)) }
171            ReflectValueRef::F64(f) => { Value::Number(Number::from(f)) }
172            ReflectValueRef::Bool(b) => { Value::Boolean(b) }
173            ReflectValueRef::String(s) => {
174                let lua_string = lua.create_string(s)?;
175                Value::String(lua_string)
176            }
177            ReflectValueRef::Bytes(bytes) => {
178                let table = lua.create_table()?;
179                for byte in bytes {
180                    table.push(*byte)?;
181                }
182                Value::Table(table)
183            }
184            ReflectValueRef::Enum(_, i) => {
185                Value::Integer(Integer::from(i))
186            }
187            ReflectValueRef::Message(m) => {
188                let table = self.decode_message(lua, m.deref())?;
189                Value::Table(table)
190            }
191        };
192        Ok(lua_value)
193    }
194
195    fn fmt_value(&self, value: &Value) -> &'static str {
196        match value {
197            Value::Nil => { "Nil" }
198            Value::Boolean(_) => { "Boolean" }
199            Value::LightUserData(_) => { "LightUserData" }
200            Value::Integer(_) => { "Integer" }
201            Value::Number(_) => { "Number" }
202            Value::String(_) => { "String" }
203            Value::Table(_) => { "Table" }
204            Value::Function(_) => { "Function" }
205            Value::Thread(_) => { "Thread" }
206            Value::UserData(_) => { "UserData" }
207            Value::Error(_) => { "Error" }
208            #[cfg(any(feature = "luau", doc))]
209            Value::Vector(_) => { "Vector" }
210            #[cfg(any(feature = "luau", doc))]
211            Value::Buffer(_) => { "Buffer" }
212            Value::Other(_) => { "Other" }
213        }
214    }
215}