1use std::collections::HashSet;
2use std::ops::Deref;
3
4use anyhow::{anyhow, Context};
5use mlua::prelude::LuaValue;
6use mlua::{Integer, Lua, Number, Table, Value};
7use protobuf::reflect::{MessageDescriptor, ReflectValueBox, ReflectValueRef, RuntimeFieldType, RuntimeType};
8use protobuf::MessageDyn;
9
/// Stateless converter between Lua tables and dynamically typed protobuf messages.
///
/// The type carries no data, so it is free to copy and can be created with
/// `LuaProtoCodec::default()` or simply `LuaProtoCodec`.
#[derive(Debug, Copy, Clone, Default)]
pub struct LuaProtoCodec;
12
13impl LuaProtoCodec {
14 pub fn encode_message(&self, lua_message: &Table, descriptor: &MessageDescriptor) -> anyhow::Result<Box<dyn MessageDyn>> {
15 let name = descriptor.name();
16 let mut message = descriptor.new_instance();
17 for pair in lua_message.pairs::<Value, Value>() {
18 let (field_key, field_value) = pair?;
19 let field_key = field_key.as_str().ok_or(anyhow!("message {} expect a string key",name))?;
20 let field_descriptor = descriptor.field_by_name(&field_key).ok_or(anyhow!("field {} not found in message {}",field_key,name))?;
21 match field_descriptor.runtime_field_type() {
22 RuntimeFieldType::Singular(ty) => {
23 let boxed_value = self.box_value(name, &field_key, &ty, field_value)?;
24 field_descriptor.set_singular_field(message.as_mut(), boxed_value);
25 }
26 RuntimeFieldType::Repeated(ty) => {
27 let mut field_repeated = field_descriptor.mut_repeated(message.as_mut());
28 let table = field_value.as_table().ok_or(anyhow!("message {} field {} expect a table",name,field_key))?;
29 for v in table.sequence_values::<Value>() {
30 let v = v?;
31 let boxed_value = self.box_value(name, &field_key, &ty, v)?;
32 field_repeated.push(boxed_value);
33 }
34 }
35 RuntimeFieldType::Map(k_ty, v_ty) => {
36 let mut field_map = field_descriptor.mut_map(message.as_mut());
37 let table = field_value.as_table().ok_or(anyhow!("message {} field {} expect a table",name,field_key))?;
38 for pair in table.pairs::<Value, Value>() {
39 let (key, value) = pair?;
40 let key = self.box_value(name, &field_key, &k_ty, key)?;
41 let value = self.box_value(name, &field_key, &v_ty, value)?;
42 field_map.insert(key, value);
43 }
44 }
45 }
46 }
47 Ok(message)
48 }
49
50 pub fn decode_message(&self, lua: &Lua, message: &dyn MessageDyn) -> anyhow::Result<Table> {
51 let lua_message = lua.create_table()?;
52 let descriptor = message.descriptor_dyn();
53 let message_name = descriptor.name();
54 let mut oneof_field = HashSet::new();
55 for oneof_descriptor in descriptor.oneofs() {
56 for field in oneof_descriptor.fields() {
57 oneof_field.insert(field.name().to_string());
58 }
59 }
60 for field in descriptor.fields() {
61 let field_name = field.name();
62 match field.runtime_field_type() {
63 RuntimeFieldType::Singular(_) => {
64 if oneof_field.contains(field_name) {
65 if let Some(value) = field.get_singular(message) {
66 let field_table = self.unbox_value(message_name, field_name, value, lua)?;
67 lua_message.set(field_name, field_table)?;
68 }
69 } else {
70 let value = field.get_singular_field_or_default(message);
71 let field_table = self.unbox_value(message_name, field_name, value, lua)?;
72 lua_message.set(field_name, field_table)?;
73 }
74 }
75 RuntimeFieldType::Repeated(_) => {
76 let field_table = lua.create_table()?;
77 let values = field.get_repeated(message);
78 for value in values {
79 let v = self.unbox_value(message_name, field_name, value, lua)?;
80 field_table.push(v)?;
81 }
82 lua_message.set(field_name, field_table)?;
83 }
84 RuntimeFieldType::Map(_, _) => {
85 let field_table = lua.create_table()?;
86 let maps = field.get_map(message);
87 for (k, v) in maps.into_iter() {
88 let k = self.unbox_value(message_name, field_name, k, lua)?;
89 let v = self.unbox_value(message_name, field_name, v, lua)?;
90 field_table.set(k, v)?;
91 }
92 lua_message.set(field_name, field_table)?;
93 }
94 }
95 }
96 Ok(lua_message)
97 }
98
99 pub fn box_value(&self, name: &str, field: &str, ty: &RuntimeType, value: Value) -> anyhow::Result<ReflectValueBox> {
100 fn value_cast_error(message: &str, field: &str, value: &str, ty: &str) -> anyhow::Error {
101 anyhow!("message {} field {} value {} cannot be cast to {}",message,field,value,ty)
102 }
103 let value_ty = self.fmt_value(&value);
104 let value_box = match ty {
105 RuntimeType::I32 => {
106 let value = value.as_i32().ok_or(value_cast_error(name, field, value_ty, "i32"))?;
107 ReflectValueBox::I32(value)
108 }
109 RuntimeType::I64 => {
110 let value = value.as_i64().ok_or(value_cast_error(name, field, value_ty, "i64"))?;
111 ReflectValueBox::I64(value)
112 }
113 RuntimeType::U32 => {
114 let value = value.as_u32().ok_or(value_cast_error(name, field, value_ty, "u32"))?;
115 ReflectValueBox::U32(value)
116 }
117 RuntimeType::U64 => {
118 let value = value.as_u64().ok_or(value_cast_error(name, field, value_ty, "u64"))?;
119 ReflectValueBox::U64(value)
120 }
121 RuntimeType::F32 => {
122 let value = value.as_f32().ok_or(value_cast_error(name, field, value_ty, "f32"))?;
123 ReflectValueBox::F32(value)
124 }
125 RuntimeType::F64 => {
126 let value = value.as_f64().ok_or(value_cast_error(name, field, value_ty, "f64"))?;
127 ReflectValueBox::F64(value)
128 }
129 RuntimeType::Bool => {
130 let value = value.as_boolean().ok_or(value_cast_error(name, field, value_ty, "bool"))?;
131 ReflectValueBox::Bool(value)
132 }
133 RuntimeType::String => {
134 let value = value.as_string_lossy().ok_or(value_cast_error(name, field, value_ty, "string"))?;
135 ReflectValueBox::String(value.to_string())
136 }
137 RuntimeType::VecU8 => {
138 let table = value.as_table().ok_or(value_cast_error(name, field, value_ty, "table"))?;
139 let len = table.len()?;
140 let mut bytes = Vec::with_capacity(len as usize);
141 for byte in table.sequence_values::<u8>() {
142 let byte = anyhow::Context::context(byte, format!("message {} field {} expect u8 table", name, field))?;
143 bytes.push(byte);
144 }
145 ReflectValueBox::Bytes(bytes)
146 }
147 RuntimeType::Enum(descriptor) => {
148 let value = value.as_i32().ok_or(value_cast_error(name, field, &value_ty, "i32"))?;
149 descriptor.value_by_number(value).ok_or(anyhow!("incorrect number of enum {}",descriptor.name()))?;
150 ReflectValueBox::Enum(descriptor.clone(), value)
151 }
152 RuntimeType::Message(descriptor) => {
153 let table = value.as_table().ok_or(value_cast_error(name, field, &value_ty, "i32"))?;
154 let message = self.encode_message(table, descriptor)?;
155 ReflectValueBox::Message(message)
156 }
157 };
158 Ok(value_box)
159 }
160
161 pub fn unbox_value(&self, message_name: &str, field_name: &str, value: ReflectValueRef, lua: &Lua) -> anyhow::Result<LuaValue> {
162 let lua_value = match value {
163 ReflectValueRef::U32(u) => { Value::Integer(Integer::from(u)) }
164 ReflectValueRef::U64(u) => {
165 let u = u32::try_from(u).context(format!("message {} field {} cannot cast u64 value {} to u32", message_name, field_name, u))?;
166 Value::Integer(Integer::from(u))
167 }
168 ReflectValueRef::I32(i) => { Value::Integer(Integer::from(i)) }
169 ReflectValueRef::I64(i) => { Value::Integer(Integer::from(i)) }
170 ReflectValueRef::F32(f) => { Value::Number(Number::from(f)) }
171 ReflectValueRef::F64(f) => { Value::Number(Number::from(f)) }
172 ReflectValueRef::Bool(b) => { Value::Boolean(b) }
173 ReflectValueRef::String(s) => {
174 let lua_string = lua.create_string(s)?;
175 Value::String(lua_string)
176 }
177 ReflectValueRef::Bytes(bytes) => {
178 let table = lua.create_table()?;
179 for byte in bytes {
180 table.push(*byte)?;
181 }
182 Value::Table(table)
183 }
184 ReflectValueRef::Enum(_, i) => {
185 Value::Integer(Integer::from(i))
186 }
187 ReflectValueRef::Message(m) => {
188 let table = self.decode_message(lua, m.deref())?;
189 Value::Table(table)
190 }
191 };
192 Ok(lua_value)
193 }
194
195 fn fmt_value(&self, value: &Value) -> &'static str {
196 match value {
197 Value::Nil => { "Nil" }
198 Value::Boolean(_) => { "Boolean" }
199 Value::LightUserData(_) => { "LightUserData" }
200 Value::Integer(_) => { "Integer" }
201 Value::Number(_) => { "Number" }
202 Value::String(_) => { "String" }
203 Value::Table(_) => { "Table" }
204 Value::Function(_) => { "Function" }
205 Value::Thread(_) => { "Thread" }
206 Value::UserData(_) => { "UserData" }
207 Value::Error(_) => { "Error" }
208 #[cfg(any(feature = "luau", doc))]
209 Value::Vector(_) => { "Vector" }
210 #[cfg(any(feature = "luau", doc))]
211 Value::Buffer(_) => { "Buffer" }
212 Value::Other(_) => { "Other" }
213 }
214 }
215}