lua_protobuf_rs/
codec.rs

1use std::collections::HashSet;
2use std::ops::Deref;
3
4use anyhow::{Context, anyhow};
5use mlua::prelude::LuaValue;
6use mlua::{Integer, Lua, Number, Table, Value};
7use protobuf::MessageDyn;
8use protobuf::reflect::{
9    MessageDescriptor, ReflectValueBox, ReflectValueRef, RuntimeFieldType, RuntimeType,
10};
11
/// Stateless converter between Lua tables and dynamic protobuf messages.
///
/// The codec holds no data, so it is `Copy` and can be freely passed around or
/// created with [`Default`].
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq)]
pub struct LuaProtoCodec;
14
15impl LuaProtoCodec {
16    pub fn encode_message(
17        &self,
18        lua_message: &Table,
19        descriptor: &MessageDescriptor,
20    ) -> anyhow::Result<Box<dyn MessageDyn>> {
21        let name = descriptor.name();
22        let mut message = descriptor.new_instance();
23        for pair in lua_message.pairs::<Value, Value>() {
24            let (field_key, field_value) = pair?;
25            let field_key = field_key
26                .as_string()
27                .ok_or(anyhow!("message {} expect a string key", name))?
28                .to_str()?
29                .to_string();
30            let field_descriptor = descriptor.field_by_name(&field_key).ok_or(anyhow!(
31                "field {} not found in message {}",
32                field_key,
33                name
34            ))?;
35            match field_descriptor.runtime_field_type() {
36                RuntimeFieldType::Singular(ty) => {
37                    let boxed_value = self.box_value(name, &field_key, &ty, field_value)?;
38                    field_descriptor.set_singular_field(message.as_mut(), boxed_value);
39                }
40                RuntimeFieldType::Repeated(ty) => {
41                    let mut field_repeated = field_descriptor.mut_repeated(message.as_mut());
42                    let table = field_value.as_table().ok_or(anyhow!(
43                        "message {} field {} expect a table",
44                        name,
45                        field_key
46                    ))?;
47                    for v in table.sequence_values::<Value>() {
48                        let v = v?;
49                        let boxed_value = self.box_value(name, &field_key, &ty, v)?;
50                        field_repeated.push(boxed_value);
51                    }
52                }
53                RuntimeFieldType::Map(k_ty, v_ty) => {
54                    let mut field_map = field_descriptor.mut_map(message.as_mut());
55                    let table = field_value.as_table().ok_or(anyhow!(
56                        "message {} field {} expect a table",
57                        name,
58                        field_key
59                    ))?;
60                    for pair in table.pairs::<Value, Value>() {
61                        let (key, value) = pair?;
62                        let key = self.box_value(name, &field_key, &k_ty, key)?;
63                        let value = self.box_value(name, &field_key, &v_ty, value)?;
64                        field_map.insert(key, value);
65                    }
66                }
67            }
68        }
69        Ok(message)
70    }
71
72    pub fn decode_message(&self, lua: &Lua, message: &dyn MessageDyn) -> anyhow::Result<Table> {
73        let lua_message = lua.create_table()?;
74        let descriptor = message.descriptor_dyn();
75        let message_name = descriptor.name();
76        let mut oneof_field = HashSet::new();
77        for oneof_descriptor in descriptor.oneofs() {
78            for field in oneof_descriptor.fields() {
79                oneof_field.insert(field.name().to_string());
80            }
81        }
82        for field in descriptor.fields() {
83            let field_name = field.name();
84            match field.runtime_field_type() {
85                RuntimeFieldType::Singular(_) => {
86                    if oneof_field.contains(field_name) {
87                        if let Some(value) = field.get_singular(message) {
88                            let field_table =
89                                self.unbox_value(message_name, field_name, value, lua)?;
90                            lua_message.set(field_name, field_table)?;
91                        }
92                    } else {
93                        let value = field.get_singular_field_or_default(message);
94                        let field_table = self.unbox_value(message_name, field_name, value, lua)?;
95                        lua_message.set(field_name, field_table)?;
96                    }
97                }
98                RuntimeFieldType::Repeated(_) => {
99                    let field_table = lua.create_table()?;
100                    let values = field.get_repeated(message);
101                    for value in values {
102                        let v = self.unbox_value(message_name, field_name, value, lua)?;
103                        field_table.push(v)?;
104                    }
105                    lua_message.set(field_name, field_table)?;
106                }
107                RuntimeFieldType::Map(_, _) => {
108                    let field_table = lua.create_table()?;
109                    let maps = field.get_map(message);
110                    for (k, v) in maps.into_iter() {
111                        let k = self.unbox_value(message_name, field_name, k, lua)?;
112                        let v = self.unbox_value(message_name, field_name, v, lua)?;
113                        field_table.set(k, v)?;
114                    }
115                    lua_message.set(field_name, field_table)?;
116                }
117            }
118        }
119        Ok(lua_message)
120    }
121
122    pub fn box_value(
123        &self,
124        name: &str,
125        field: &str,
126        ty: &RuntimeType,
127        value: Value,
128    ) -> anyhow::Result<ReflectValueBox> {
129        fn value_cast_error(message: &str, field: &str, value: &str, ty: &str) -> anyhow::Error {
130            anyhow!(
131                "message {} field {} value {} cannot be cast to {}",
132                message,
133                field,
134                value,
135                ty
136            )
137        }
138        let value_ty = self.fmt_value(&value);
139        let value_box = match ty {
140            RuntimeType::I32 => {
141                let value = value
142                    .as_i32()
143                    .ok_or(value_cast_error(name, field, value_ty, "i32"))?;
144                ReflectValueBox::I32(value)
145            }
146            RuntimeType::I64 => {
147                let value = value
148                    .as_i64()
149                    .ok_or(value_cast_error(name, field, value_ty, "i64"))?;
150                ReflectValueBox::I64(value)
151            }
152            RuntimeType::U32 => {
153                let value = value
154                    .as_u32()
155                    .ok_or(value_cast_error(name, field, value_ty, "u32"))?;
156                ReflectValueBox::U32(value)
157            }
158            RuntimeType::U64 => {
159                let value = value
160                    .as_u64()
161                    .ok_or(value_cast_error(name, field, value_ty, "u64"))?;
162                ReflectValueBox::U64(value)
163            }
164            RuntimeType::F32 => {
165                let value = value
166                    .as_f32()
167                    .ok_or(value_cast_error(name, field, value_ty, "f32"))?;
168                ReflectValueBox::F32(value)
169            }
170            RuntimeType::F64 => {
171                let value = value
172                    .as_f64()
173                    .ok_or(value_cast_error(name, field, value_ty, "f64"))?;
174                ReflectValueBox::F64(value)
175            }
176            RuntimeType::Bool => {
177                let value = value
178                    .as_boolean()
179                    .ok_or(value_cast_error(name, field, value_ty, "bool"))?;
180                ReflectValueBox::Bool(value)
181            }
182            RuntimeType::String => {
183                let value = value
184                    .as_string()
185                    .ok_or(value_cast_error(name, field, value_ty, "string"))?
186                    .to_str()?
187                    .to_string();
188                ReflectValueBox::String(value.to_string())
189            }
190            RuntimeType::VecU8 => {
191                let table = value
192                    .as_table()
193                    .ok_or(value_cast_error(name, field, value_ty, "table"))?;
194                let len = table.len()?;
195                let mut bytes = Vec::with_capacity(len as usize);
196                for byte in table.sequence_values::<u8>() {
197                    let byte = anyhow::Context::context(
198                        byte,
199                        format!("message {} field {} expect u8 table", name, field),
200                    )?;
201                    bytes.push(byte);
202                }
203                ReflectValueBox::Bytes(bytes)
204            }
205            RuntimeType::Enum(descriptor) => {
206                let value = value
207                    .as_i32()
208                    .ok_or(value_cast_error(name, field, value_ty, "i32"))?;
209                descriptor
210                    .value_by_number(value)
211                    .ok_or(anyhow!("incorrect number of enum {}", descriptor.name()))?;
212                ReflectValueBox::Enum(descriptor.clone(), value)
213            }
214            RuntimeType::Message(descriptor) => {
215                let table = value
216                    .as_table()
217                    .ok_or(value_cast_error(name, field, value_ty, "i32"))?;
218                let message = self.encode_message(table, descriptor)?;
219                ReflectValueBox::Message(message)
220            }
221        };
222        Ok(value_box)
223    }
224
225    pub fn unbox_value(
226        &self,
227        message_name: &str,
228        field_name: &str,
229        value: ReflectValueRef,
230        lua: &Lua,
231    ) -> anyhow::Result<LuaValue> {
232        let lua_value = match value {
233            ReflectValueRef::U32(u) => Value::Integer(Integer::from(u)),
234            ReflectValueRef::U64(u) => {
235                let u = u32::try_from(u).context(format!(
236                    "message {} field {} cannot cast u64 value {} to u32",
237                    message_name, field_name, u
238                ))?;
239                Value::Integer(Integer::from(u))
240            }
241            ReflectValueRef::I32(i) => Value::Integer(Integer::from(i)),
242            ReflectValueRef::I64(i) => Value::Integer(Integer::from(i)),
243            ReflectValueRef::F32(f) => Value::Number(Number::from(f)),
244            ReflectValueRef::F64(f) => Value::Number(Number::from(f)),
245            ReflectValueRef::Bool(b) => Value::Boolean(b),
246            ReflectValueRef::String(s) => {
247                let lua_string = lua.create_string(s)?;
248                Value::String(lua_string)
249            }
250            ReflectValueRef::Bytes(bytes) => {
251                let table = lua.create_table()?;
252                for byte in bytes {
253                    table.push(*byte)?;
254                }
255                Value::Table(table)
256            }
257            ReflectValueRef::Enum(_, i) => Value::Integer(Integer::from(i)),
258            ReflectValueRef::Message(m) => {
259                let table = self.decode_message(lua, m.deref())?;
260                Value::Table(table)
261            }
262        };
263        Ok(lua_value)
264    }
265
266    fn fmt_value(&self, value: &Value) -> &'static str {
267        match value {
268            Value::Nil => "Nil",
269            Value::Boolean(_) => "Boolean",
270            Value::LightUserData(_) => "LightUserData",
271            Value::Integer(_) => "Integer",
272            Value::Number(_) => "Number",
273            Value::String(_) => "String",
274            Value::Table(_) => "Table",
275            Value::Function(_) => "Function",
276            Value::Thread(_) => "Thread",
277            Value::UserData(_) => "UserData",
278            Value::Error(_) => "Error",
279            #[cfg(any(feature = "luau", doc))]
280            Value::Vector(_) => "Vector",
281            #[cfg(any(feature = "luau", doc))]
282            Value::Buffer(_) => "Buffer",
283            Value::Other(_) => "Other",
284        }
285    }
286}