1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3
4use anyhow::{anyhow, Context};
5use mlua::prelude::LuaUserData;
6use mlua::{Lua, Table, UserDataMethods};
7use protobuf::descriptor::FileDescriptorProto;
8use protobuf::reflect::{EnumDescriptor, FileDescriptor, MessageDescriptor, RuntimeFieldType, RuntimeType};
9use protobuf::{CodedInputStream, Message, MessageDyn};
10
11use crate::codec::LuaProtoCodec;
12use crate::descriptor::enum_descriptor::LuaEnumDescriptor;
13use crate::descriptor::file_descriptor::LuaFileDescriptor;
14use crate::descriptor::message_descriptor::LuaMessageDescriptor;
15
16#[derive(Default)]
17pub struct LuaProtoc {
18 pub codec: LuaProtoCodec,
19 pub file_descriptors: HashMap<String, LuaFileDescriptor>,
20 pub message_descriptors: HashMap<String, LuaMessageDescriptor>,
21 pub enum_descriptors: HashMap<String, LuaEnumDescriptor>,
22}
23
24impl LuaProtoc {
25 pub fn new(descriptors: Vec<FileDescriptor>) -> Self {
26 let codec = LuaProtoCodec::default();
27 let mut file_descriptors = HashMap::new();
28 let mut message_descriptors = HashMap::new();
29 let mut enum_descriptors = HashMap::new();
30 for file_descriptor in descriptors {
31 for message_descriptor in file_descriptor.messages() {
32 message_descriptors.insert(message_descriptor.full_name().to_string(), From::from(message_descriptor));
33 }
34 for enum_descriptor in file_descriptor.enums() {
35 enum_descriptors.insert(enum_descriptor.full_name().to_string(), From::from(enum_descriptor));
36 }
37 file_descriptors.insert(file_descriptor.name().to_string(), file_descriptor.into());
38 };
39 Self {
40 codec,
41 file_descriptors,
42 message_descriptors,
43 enum_descriptors,
44 }
45 }
46
47 pub fn parse_files(inputs: impl IntoIterator<Item=impl AsRef<Path>>, includes: impl IntoIterator<Item=impl AsRef<Path>>) -> anyhow::Result<Self> {
48 let mut parser = protobuf_parse::Parser::new();
49 parser.inputs(inputs).includes(includes);
50
51 #[cfg(feature = "google_protoc")]
52 parser.protoc();
53
54 #[cfg(feature = "vendored_protoc")]
55 parser.protoc_path(&protoc_bin_vendored::protoc_bin_path().context("unable to find protoc bin vendored")?);
56
57 let file_protos = parser.parse_and_typecheck().context("parse proto failed")?.file_descriptors;
58 let file_descriptors: Vec<FileDescriptor> = FileDescriptor::new_dynamic_fds(file_protos, &[])?;
59 let protoc = LuaProtoc::new(file_descriptors);
60 Ok(protoc)
61 }
62
63 pub fn parse_proto(proto: impl AsRef<str>) -> anyhow::Result<Self> {
64 let temp_dir = tempfile::tempdir().context("unable to get tempdir")?;
65 let tempfile = temp_dir.path().join("temp.proto");
66 std::fs::write(&tempfile, proto.as_ref()).context("unable to write data to tempfile")?;
67 LuaProtoc::parse_files([&tempfile], [&temp_dir])
68 }
69
70 pub fn parse_pb(path: impl AsRef<Path>) -> anyhow::Result<Self> {
71 let mut protos = vec![];
72 for entry in walkdir::WalkDir::new(path).into_iter().filter_map(|file| file.ok()) {
73 let pb_path = entry.path();
74 if pb_path.extension().and_then(|e| Some(e == "pb")).unwrap_or(false) {
75 let mut pb_file = std::fs::File::open(pb_path).context(format!("failed open {}", pb_path.to_string_lossy()))?;
76 let mut input = CodedInputStream::new(&mut pb_file);
77 let proto = FileDescriptorProto::parse_from(&mut input)?;
78 protos.push(proto);
79 }
80 }
81 let file_descriptors = FileDescriptor::new_dynamic_fds(protos, &[])?;
82 let protoc = LuaProtoc::new(file_descriptors);
83 Ok(protoc)
84 }
85
86 pub fn gen_pb(&self, path: String) -> anyhow::Result<()> {
87 let path = PathBuf::from(path);
88 for (_, file_descriptor) in &self.file_descriptors {
89 let name = file_descriptor.name().strip_suffix(".proto").expect("file descriptor not a proto file");
90 let file_name = format!("{}.pb", name);
91 let file_path = path.join(file_name);
92 std::fs::write(&file_path, file_descriptor.proto().write_to_bytes()?).context(format!("failed write lua to file {}", file_path.to_string_lossy()))?;
93 }
94 Ok(())
95 }
96
97 pub fn gen_lua(&self, path: String) -> anyhow::Result<()> {
98 let path = PathBuf::from(path);
99 for (_, file_descriptor) in &self.file_descriptors {
100 let name = file_descriptor.name().strip_suffix(".proto").expect("file descriptor not a proto file");
101 let mut message_or_enum = vec![];
102 for message_descriptor in file_descriptor.messages() {
103 let name = message_descriptor.name().to_string();
104 let mut nested_messages_or_enums = vec![];
105 let nested_messages = message_descriptor.nested_messages().map(|m| (m.name().to_string(), name.clone()));
106 nested_messages_or_enums.extend(nested_messages);
107 let nested_enums = message_descriptor.nested_enums().map(|e| (e.name().to_string(), name.clone()));
108 nested_messages_or_enums.extend(nested_enums);
109 let messages = self.gen_lua_message(None, nested_messages_or_enums, &message_descriptor);
110 message_or_enum.extend(messages);
111 }
112 for enum_descriptor in file_descriptor.enums() {
113 let enum_table = self.gen_lua_enum(None, &enum_descriptor);
114 message_or_enum.push(enum_table);
115 }
116 let code = message_or_enum.join("\n");
117 let file_name = format!("{}.lua", name);
118 let file_path = path.join(file_name);
119 std::fs::write(&file_path, code).context(format!("failed write lua to file {}", file_path.to_string_lossy()))?;
120 }
121 Ok(())
122 }
123
124 fn gen_lua_message(&self, parent: Option<String>, nested_messages_or_enums: Vec<(String, String)>, descriptor: &MessageDescriptor) -> Vec<String> {
125 if descriptor.is_map_entry() {
126 return vec![];
127 }
128 let mut messages = vec![];
129 let message_name = descriptor.name().to_string();
130 let message_with_parent = self.decorate_with_parent(&parent, message_name.clone());
131 let class = format!("---@class {}", message_with_parent);
132 for nested_message_descriptor in descriptor.nested_messages() {
133 let name = self.decorate_with_parent(&Some(message_with_parent.clone()), nested_message_descriptor.name().to_string());
134 let mut child_nested_messages_or_enums = nested_messages_or_enums.clone();
135 child_nested_messages_or_enums.extend(nested_message_descriptor.nested_messages().map(|m| (m.name().to_string(), name.clone())));
136 child_nested_messages_or_enums.extend(nested_message_descriptor.nested_enums().map(|e| (e.name().to_string(), name.clone())));
137 let nested_messages = self.gen_lua_message(Some(message_with_parent.clone()), child_nested_messages_or_enums, &nested_message_descriptor);
138 messages.extend(nested_messages);
139 }
140 for nested_enum_descriptor in descriptor.nested_enums() {
141 let nested_enum = self.gen_lua_enum(Some(message_with_parent.clone()), &nested_enum_descriptor);
142 messages.push(nested_enum);
143 }
144 let mut fields = vec![];
145 for field in descriptor.fields() {
146 let parent = self.decorate_message_type_with_parent(field.runtime_field_type(), &nested_messages_or_enums);
147 let ty = self.lua_type_of(parent.clone(), field.runtime_field_type());
148 let field = format!("---@field {} {}", field.name(), ty);
149 fields.push(field)
150 }
151 let message_table = if fields.is_empty() {
152 format!("{}\nlocal {} = {{ }}\n", class, message_with_parent)
153 } else {
154 format!("{}\n{}\nlocal {} = {{ }}\n", class, fields.join("\n"), message_with_parent)
155 };
156 messages.push(message_table);
157 messages
158 }
159
160 fn gen_lua_enum(&self, parent: Option<String>, descriptor: &EnumDescriptor) -> String {
161 let name = descriptor.name();
162 let message = match parent {
163 None => {
164 format!("{}", name)
165 }
166 Some(parent) => {
167 format!("{}_{}", parent, name)
168 }
169 };
170 let class = format!("---@class {}", message);
171 let mut fields = vec![];
172 let mut enum_kv = vec![];
173 for value in descriptor.values() {
174 let field = format!("---@field {} number {}", value.name(), value.value());
175 fields.push(field);
176 enum_kv.push((value.name().to_string(), value.value().to_string()))
177 }
178 let kvs = enum_kv.iter().map(|(k, v)| format!("{} = {}", k, v)).collect::<Vec<String>>();
179 format!("{}\n{}\n{} = {{ {} }}\n", class, fields.join("\n"), message, kvs.join(", "))
180 }
181
182 fn decorate_with_parent(&self, parent: &Option<String>, name: String) -> String {
183 let message = match &parent {
184 None => {
185 format!("{}", name)
186 }
187 Some(parent) => {
188 format!("{}_{}", parent, name)
189 }
190 };
191 message
192 }
193
194 fn decorate_message_type_with_parent(&self, runtime_field_type: RuntimeFieldType, nested_messages_or_enums: &Vec<(String, String)>) -> Option<String> {
195 fn find_message(nested_messages_or_enums: &Vec<(String, String)>, name: &str) -> Option<String> {
196 match nested_messages_or_enums.iter().rfind(|(n, _)| n == name) {
197 None => {
198 None
199 }
200 Some((_, p)) => {
201 Some(p.clone())
202 }
203 }
204 }
205
206 fn decorate_message(nested_messages_or_enums: &Vec<(String, String)>, rt: RuntimeType) -> Option<String> {
207 match rt {
208 RuntimeType::Enum(e) => {
209 find_message(nested_messages_or_enums, e.name())
210 }
211 RuntimeType::Message(m) => {
212 find_message(nested_messages_or_enums, m.name())
213 }
214 _ => None
215 }
216 }
217 match runtime_field_type {
218 RuntimeFieldType::Singular(rt) => {
219 decorate_message(nested_messages_or_enums, rt)
220 }
221 RuntimeFieldType::Repeated(rt) => {
222 decorate_message(nested_messages_or_enums, rt)
223 }
224 RuntimeFieldType::Map(_, value_rt) => {
225 decorate_message(nested_messages_or_enums, value_rt)
226 }
227 }
228 }
229
230 fn lua_type_of(&self, parent: Option<String>, field_type: RuntimeFieldType) -> String {
231 fn type_of(protoc: &LuaProtoc, parent: Option<String>, rt: RuntimeType) -> String {
232 let ty = match rt {
233 RuntimeType::I32 |
234 RuntimeType::I64 |
235 RuntimeType::U32 |
236 RuntimeType::U64 |
237 RuntimeType::F32 |
238 RuntimeType::F64 => "number".to_string(),
239 RuntimeType::Bool => "boolean".to_string(),
240 RuntimeType::String => "string".to_string(),
241 RuntimeType::VecU8 => "number[]".to_string(),
242 RuntimeType::Enum(e) => {
243 let name = e.name().to_string();
244 protoc.decorate_with_parent(&parent, name)
245 }
246 RuntimeType::Message(m) => {
247 let name = m.name().to_string();
248 protoc.decorate_with_parent(&parent, name)
249 }
250 };
251 ty
252 }
253 let ty = match field_type {
254 RuntimeFieldType::Singular(rt) => {
255 type_of(self, parent, rt)
256 }
257 RuntimeFieldType::Repeated(rt) => {
258 format!("{}[]", type_of(self, parent, rt))
259 }
260 RuntimeFieldType::Map(key_rt, value_rt) => {
261 format!("table<{},{}>", type_of(self, parent.clone(), key_rt), type_of(self, parent, value_rt))
262 }
263 };
264 ty
265 }
266
267
268 pub fn encode(&self, message_full_name: &str, lua_message: &Table) -> anyhow::Result<Box<dyn MessageDyn>> {
269 let descriptor = self.message_descriptors.get(message_full_name).ok_or(anyhow!("{} not found",message_full_name))?;
270 let message = self.codec.encode_message(lua_message, descriptor)?;
271 Ok(message)
272 }
273
274 pub fn decode(&self, lua: &Lua, message_full_name: String, message_bytes: &[u8]) -> anyhow::Result<Table> {
275 let descriptor = self.message_descriptors.get(&message_full_name).ok_or(anyhow!("{} not found",message_full_name))?;
276 let message = descriptor.parse_from_bytes(message_bytes)?;
277 let lua_message = self.codec.decode_message(lua, message.as_ref())?;
278 Ok(lua_message)
279 }
280
281 pub fn list_protos(paths: impl IntoIterator<Item=impl AsRef<Path>>) -> Vec<PathBuf> {
282 let mut protos = Vec::new();
283 for path in paths {
284 for file in walkdir::WalkDir::new(path).into_iter().filter_map(|file| file.ok()) {
285 let proto_path = file.path();
286 if proto_path.extension().and_then(|e| Some(e == "proto")).unwrap_or(false) {
287 protos.push(proto_path.to_path_buf());
288 }
289 }
290 }
291 protos
292 }
293}
294
295impl LuaUserData for LuaProtoc {
296 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
297 methods.add_function("parse_files", |_, (inputs, includes): (Vec<String>, Vec<String>)| {
298 if inputs.is_empty() {
299 return Err(anyhow!("inputs must not empty").into());
300 }
301 if includes.is_empty() {
302 return Err(anyhow!("includes must not empty").into());
303 }
304 let protoc = LuaProtoc::parse_files(inputs, includes).map_err(|e| anyhow!("{e:?}"))?;
305 Ok(protoc)
306 });
307
308 methods.add_function("parse_proto", |_, proto: String| {
309 let protoc = LuaProtoc::parse_proto(proto).map_err(|e| anyhow!("{e:?}"))?;
310 Ok(protoc)
311 });
312
313 methods.add_function("parse_pb", |_, dir: String| {
314 let protoc = LuaProtoc::parse_pb(dir).map_err(|e| anyhow!("{e:?}"))?;
315 Ok(protoc)
316 });
317
318 methods.add_method("gen_pb", |_, this, path: String| {
319 this.gen_pb(path).map_err(|e| anyhow!("{e:?}"))?;
320 Ok(())
321 });
322
323 methods.add_method("gen_lua", |_, this, path: String| {
324 this.gen_lua(path).map_err(|e| anyhow!("{e:?}"))?;
325 Ok(())
326 });
327
328 methods.add_method("encode", |_, protoc, (message_full_name, lua_message): (String, Table)| {
329 let message = protoc.encode(&message_full_name, &lua_message).map_err(|e| anyhow!("{e:?}"))?;
330 let mut message_bytes = Vec::with_capacity(message.compute_size_dyn() as usize);
331 message.write_to_vec_dyn(&mut message_bytes).map_err(|e| anyhow!("{e:?}"))?;
332 Ok(message_bytes)
333 });
334
335 methods.add_method("decode", |lua, protoc, (message_full_name, message_bytes): (String, Vec<u8>)| {
336 let message = protoc.decode(lua, message_full_name, message_bytes.as_ref()).map_err(|e| anyhow!("{e:?}"))?;
337 Ok(message)
338 });
339
340 methods.add_function("list_protos", |_, paths: Vec<String>| {
341 let protos = LuaProtoc::list_protos(paths).iter().map(|p| { p.to_string_lossy().to_string() }).collect::<Vec<String>>();
342 Ok(protos)
343 });
344
345 methods.add_method("all_file_descriptors", |_, protoc, ()| {
346 let descriptors: Vec<_> = protoc.file_descriptors.values().map(Clone::clone).collect();
347 Ok(descriptors)
348 });
349
350 methods.add_method("file_descriptor_by_name", |_, protoc, name: String| {
351 let descriptor = protoc.file_descriptors.get(&name).map(Clone::clone);
352 Ok(descriptor)
353 });
354
355 methods.add_method("all_message_descriptors", |_, protoc, ()| {
356 let descriptors: Vec<_> = protoc.message_descriptors.values().map(Clone::clone).collect();
357 Ok(descriptors)
358 });
359
360 methods.add_method("message_descriptor_by_name", |_, protoc, name: String| {
361 let descriptor = protoc.message_descriptors.get(&name).map(Clone::clone);
362 Ok(descriptor)
363 });
364
365 methods.add_method("all_enum_descriptors", |_, protoc, ()| {
366 let descriptors: Vec<_> = protoc.enum_descriptors.values().map(Clone::clone).collect();
367 Ok(descriptors)
368 });
369
370 methods.add_method("enum_descriptor_by_name", |_, protoc, name: String| {
371 let descriptor = protoc.enum_descriptors.get(&name).map(Clone::clone);
372 Ok(descriptor)
373 });
374 }
375}