1use std::collections::HashSet;
2use std::ops::Deref;
3
4use anyhow::{Context, anyhow};
5use mlua::prelude::LuaValue;
6use mlua::{Integer, Lua, Number, Table, Value};
7use protobuf::MessageDyn;
8use protobuf::reflect::{
9 MessageDescriptor, ReflectValueBox, ReflectValueRef, RuntimeFieldType, RuntimeType,
10};
11
/// Converts protobuf messages to and from Lua tables using runtime reflection.
///
/// The codec carries no state (it is a unit struct), so it is `Copy` and can be
/// recreated or shared freely.
#[derive(Copy, Clone, Debug, Default)]
pub struct LuaProtoCodec;
14
15impl LuaProtoCodec {
16 pub fn encode_message(
17 &self,
18 lua_message: &Table,
19 descriptor: &MessageDescriptor,
20 ) -> anyhow::Result<Box<dyn MessageDyn>> {
21 let name = descriptor.name();
22 let mut message = descriptor.new_instance();
23 for pair in lua_message.pairs::<Value, Value>() {
24 let (field_key, field_value) = pair?;
25 let field_key = field_key
26 .as_string()
27 .ok_or(anyhow!("message {} expect a string key", name))?
28 .to_str()?
29 .to_string();
30 let field_descriptor = descriptor.field_by_name(&field_key).ok_or(anyhow!(
31 "field {} not found in message {}",
32 field_key,
33 name
34 ))?;
35 match field_descriptor.runtime_field_type() {
36 RuntimeFieldType::Singular(ty) => {
37 let boxed_value = self.box_value(name, &field_key, &ty, field_value)?;
38 field_descriptor.set_singular_field(message.as_mut(), boxed_value);
39 }
40 RuntimeFieldType::Repeated(ty) => {
41 let mut field_repeated = field_descriptor.mut_repeated(message.as_mut());
42 let table = field_value.as_table().ok_or(anyhow!(
43 "message {} field {} expect a table",
44 name,
45 field_key
46 ))?;
47 for v in table.sequence_values::<Value>() {
48 let v = v?;
49 let boxed_value = self.box_value(name, &field_key, &ty, v)?;
50 field_repeated.push(boxed_value);
51 }
52 }
53 RuntimeFieldType::Map(k_ty, v_ty) => {
54 let mut field_map = field_descriptor.mut_map(message.as_mut());
55 let table = field_value.as_table().ok_or(anyhow!(
56 "message {} field {} expect a table",
57 name,
58 field_key
59 ))?;
60 for pair in table.pairs::<Value, Value>() {
61 let (key, value) = pair?;
62 let key = self.box_value(name, &field_key, &k_ty, key)?;
63 let value = self.box_value(name, &field_key, &v_ty, value)?;
64 field_map.insert(key, value);
65 }
66 }
67 }
68 }
69 Ok(message)
70 }
71
72 pub fn decode_message(&self, lua: &Lua, message: &dyn MessageDyn) -> anyhow::Result<Table> {
73 let lua_message = lua.create_table()?;
74 let descriptor = message.descriptor_dyn();
75 let message_name = descriptor.name();
76 let mut oneof_field = HashSet::new();
77 for oneof_descriptor in descriptor.oneofs() {
78 for field in oneof_descriptor.fields() {
79 oneof_field.insert(field.name().to_string());
80 }
81 }
82 for field in descriptor.fields() {
83 let field_name = field.name();
84 match field.runtime_field_type() {
85 RuntimeFieldType::Singular(_) => {
86 if oneof_field.contains(field_name) {
87 if let Some(value) = field.get_singular(message) {
88 let field_table =
89 self.unbox_value(message_name, field_name, value, lua)?;
90 lua_message.set(field_name, field_table)?;
91 }
92 } else {
93 let value = field.get_singular_field_or_default(message);
94 let field_table = self.unbox_value(message_name, field_name, value, lua)?;
95 lua_message.set(field_name, field_table)?;
96 }
97 }
98 RuntimeFieldType::Repeated(_) => {
99 let field_table = lua.create_table()?;
100 let values = field.get_repeated(message);
101 for value in values {
102 let v = self.unbox_value(message_name, field_name, value, lua)?;
103 field_table.push(v)?;
104 }
105 lua_message.set(field_name, field_table)?;
106 }
107 RuntimeFieldType::Map(_, _) => {
108 let field_table = lua.create_table()?;
109 let maps = field.get_map(message);
110 for (k, v) in maps.into_iter() {
111 let k = self.unbox_value(message_name, field_name, k, lua)?;
112 let v = self.unbox_value(message_name, field_name, v, lua)?;
113 field_table.set(k, v)?;
114 }
115 lua_message.set(field_name, field_table)?;
116 }
117 }
118 }
119 Ok(lua_message)
120 }
121
122 pub fn box_value(
123 &self,
124 name: &str,
125 field: &str,
126 ty: &RuntimeType,
127 value: Value,
128 ) -> anyhow::Result<ReflectValueBox> {
129 fn value_cast_error(message: &str, field: &str, value: &str, ty: &str) -> anyhow::Error {
130 anyhow!(
131 "message {} field {} value {} cannot be cast to {}",
132 message,
133 field,
134 value,
135 ty
136 )
137 }
138 let value_ty = self.fmt_value(&value);
139 let value_box = match ty {
140 RuntimeType::I32 => {
141 let value = value
142 .as_i32()
143 .ok_or(value_cast_error(name, field, value_ty, "i32"))?;
144 ReflectValueBox::I32(value)
145 }
146 RuntimeType::I64 => {
147 let value = value
148 .as_i64()
149 .ok_or(value_cast_error(name, field, value_ty, "i64"))?;
150 ReflectValueBox::I64(value)
151 }
152 RuntimeType::U32 => {
153 let value = value
154 .as_u32()
155 .ok_or(value_cast_error(name, field, value_ty, "u32"))?;
156 ReflectValueBox::U32(value)
157 }
158 RuntimeType::U64 => {
159 let value = value
160 .as_u64()
161 .ok_or(value_cast_error(name, field, value_ty, "u64"))?;
162 ReflectValueBox::U64(value)
163 }
164 RuntimeType::F32 => {
165 let value = value
166 .as_f32()
167 .ok_or(value_cast_error(name, field, value_ty, "f32"))?;
168 ReflectValueBox::F32(value)
169 }
170 RuntimeType::F64 => {
171 let value = value
172 .as_f64()
173 .ok_or(value_cast_error(name, field, value_ty, "f64"))?;
174 ReflectValueBox::F64(value)
175 }
176 RuntimeType::Bool => {
177 let value = value
178 .as_boolean()
179 .ok_or(value_cast_error(name, field, value_ty, "bool"))?;
180 ReflectValueBox::Bool(value)
181 }
182 RuntimeType::String => {
183 let value = value
184 .as_string()
185 .ok_or(value_cast_error(name, field, value_ty, "string"))?
186 .to_str()?
187 .to_string();
188 ReflectValueBox::String(value.to_string())
189 }
190 RuntimeType::VecU8 => {
191 let table = value
192 .as_table()
193 .ok_or(value_cast_error(name, field, value_ty, "table"))?;
194 let len = table.len()?;
195 let mut bytes = Vec::with_capacity(len as usize);
196 for byte in table.sequence_values::<u8>() {
197 let byte = anyhow::Context::context(
198 byte,
199 format!("message {} field {} expect u8 table", name, field),
200 )?;
201 bytes.push(byte);
202 }
203 ReflectValueBox::Bytes(bytes)
204 }
205 RuntimeType::Enum(descriptor) => {
206 let value = value
207 .as_i32()
208 .ok_or(value_cast_error(name, field, value_ty, "i32"))?;
209 descriptor
210 .value_by_number(value)
211 .ok_or(anyhow!("incorrect number of enum {}", descriptor.name()))?;
212 ReflectValueBox::Enum(descriptor.clone(), value)
213 }
214 RuntimeType::Message(descriptor) => {
215 let table = value
216 .as_table()
217 .ok_or(value_cast_error(name, field, value_ty, "i32"))?;
218 let message = self.encode_message(table, descriptor)?;
219 ReflectValueBox::Message(message)
220 }
221 };
222 Ok(value_box)
223 }
224
225 pub fn unbox_value(
226 &self,
227 message_name: &str,
228 field_name: &str,
229 value: ReflectValueRef,
230 lua: &Lua,
231 ) -> anyhow::Result<LuaValue> {
232 let lua_value = match value {
233 ReflectValueRef::U32(u) => Value::Integer(Integer::from(u)),
234 ReflectValueRef::U64(u) => {
235 let u = u32::try_from(u).context(format!(
236 "message {} field {} cannot cast u64 value {} to u32",
237 message_name, field_name, u
238 ))?;
239 Value::Integer(Integer::from(u))
240 }
241 ReflectValueRef::I32(i) => Value::Integer(Integer::from(i)),
242 ReflectValueRef::I64(i) => Value::Integer(Integer::from(i)),
243 ReflectValueRef::F32(f) => Value::Number(Number::from(f)),
244 ReflectValueRef::F64(f) => Value::Number(Number::from(f)),
245 ReflectValueRef::Bool(b) => Value::Boolean(b),
246 ReflectValueRef::String(s) => {
247 let lua_string = lua.create_string(s)?;
248 Value::String(lua_string)
249 }
250 ReflectValueRef::Bytes(bytes) => {
251 let table = lua.create_table()?;
252 for byte in bytes {
253 table.push(*byte)?;
254 }
255 Value::Table(table)
256 }
257 ReflectValueRef::Enum(_, i) => Value::Integer(Integer::from(i)),
258 ReflectValueRef::Message(m) => {
259 let table = self.decode_message(lua, m.deref())?;
260 Value::Table(table)
261 }
262 };
263 Ok(lua_value)
264 }
265
266 fn fmt_value(&self, value: &Value) -> &'static str {
267 match value {
268 Value::Nil => "Nil",
269 Value::Boolean(_) => "Boolean",
270 Value::LightUserData(_) => "LightUserData",
271 Value::Integer(_) => "Integer",
272 Value::Number(_) => "Number",
273 Value::String(_) => "String",
274 Value::Table(_) => "Table",
275 Value::Function(_) => "Function",
276 Value::Thread(_) => "Thread",
277 Value::UserData(_) => "UserData",
278 Value::Error(_) => "Error",
279 #[cfg(any(feature = "luau", doc))]
280 Value::Vector(_) => "Vector",
281 #[cfg(any(feature = "luau", doc))]
282 Value::Buffer(_) => "Buffer",
283 Value::Other(_) => "Other",
284 }
285 }
286}