1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3
4use anyhow::{Context, anyhow};
5use mlua::prelude::LuaUserData;
6use mlua::{Lua, Table, UserDataMethods};
7use protobuf::descriptor::FileDescriptorProto;
8use protobuf::reflect::{
9 EnumDescriptor, FileDescriptor, MessageDescriptor, RuntimeFieldType, RuntimeType,
10};
11use protobuf::{CodedInputStream, Message, MessageDyn};
12
13use crate::codec::LuaProtoCodec;
14use crate::descriptor::enum_descriptor::LuaEnumDescriptor;
15use crate::descriptor::file_descriptor::LuaFileDescriptor;
16use crate::descriptor::message_descriptor::LuaMessageDescriptor;
17
18#[derive(Default)]
19pub struct LuaProtoc {
20 pub codec: LuaProtoCodec,
21 pub file_descriptors: HashMap<String, LuaFileDescriptor>,
22 pub message_descriptors: HashMap<String, LuaMessageDescriptor>,
23 pub enum_descriptors: HashMap<String, LuaEnumDescriptor>,
24}
25
26impl LuaProtoc {
27 pub fn new(descriptors: Vec<FileDescriptor>) -> Self {
28 let codec = LuaProtoCodec;
29 let mut file_descriptors = HashMap::new();
30 let mut message_descriptors = HashMap::new();
31 let mut enum_descriptors = HashMap::new();
32 for file_descriptor in descriptors {
33 for message_descriptor in file_descriptor.messages() {
34 message_descriptors.insert(
35 message_descriptor.full_name().to_string(),
36 From::from(message_descriptor),
37 );
38 }
39 for enum_descriptor in file_descriptor.enums() {
40 enum_descriptors.insert(
41 enum_descriptor.full_name().to_string(),
42 From::from(enum_descriptor),
43 );
44 }
45 file_descriptors.insert(file_descriptor.name().to_string(), file_descriptor.into());
46 }
47 Self {
48 codec,
49 file_descriptors,
50 message_descriptors,
51 enum_descriptors,
52 }
53 }
54
55 pub fn parse_files(
56 inputs: impl IntoIterator<Item = impl AsRef<Path>>,
57 includes: impl IntoIterator<Item = impl AsRef<Path>>,
58 ) -> anyhow::Result<Self> {
59 let mut parser = protobuf_parse::Parser::new();
60 parser.inputs(inputs).includes(includes);
61
62 #[cfg(feature = "google_protoc")]
63 parser.protoc();
64
65 #[cfg(feature = "vendored_protoc")]
66 parser.protoc_path(
67 &protoc_bin_vendored::protoc_bin_path()
68 .context("unable to find protoc bin vendored")?,
69 );
70
71 let file_protos = parser
72 .parse_and_typecheck()
73 .context("parse proto failed")?
74 .file_descriptors;
75 let file_descriptors: Vec<FileDescriptor> =
76 FileDescriptor::new_dynamic_fds(file_protos, &[])?;
77 let protoc = LuaProtoc::new(file_descriptors);
78 Ok(protoc)
79 }
80
81 pub fn parse_proto(proto: impl AsRef<str>) -> anyhow::Result<Self> {
82 let temp_dir = tempfile::tempdir().context("unable to get tempdir")?;
83 let tempfile = temp_dir.path().join("temp.proto");
84 std::fs::write(&tempfile, proto.as_ref()).context("unable to write data to tempfile")?;
85 LuaProtoc::parse_files([&tempfile], [&temp_dir])
86 }
87
88 pub fn parse_pb(path: impl AsRef<Path>) -> anyhow::Result<Self> {
89 let mut protos = vec![];
90 for entry in walkdir::WalkDir::new(path)
91 .into_iter()
92 .filter_map(|file| file.ok())
93 {
94 let pb_path = entry.path();
95 if pb_path.extension().map(|e| e == "pb").unwrap_or(false) {
96 let mut pb_file = std::fs::File::open(pb_path)
97 .context(format!("failed open {}", pb_path.to_string_lossy()))?;
98 let mut input = CodedInputStream::new(&mut pb_file);
99 let proto = FileDescriptorProto::parse_from(&mut input)?;
100 protos.push(proto);
101 }
102 }
103 let file_descriptors = FileDescriptor::new_dynamic_fds(protos, &[])?;
104 let protoc = LuaProtoc::new(file_descriptors);
105 Ok(protoc)
106 }
107
108 pub fn gen_pb(&self, path: String) -> anyhow::Result<()> {
109 let path = PathBuf::from(path);
110 for file_descriptor in self.file_descriptors.values() {
111 let name = file_descriptor
112 .name()
113 .strip_suffix(".proto")
114 .expect("file descriptor not a proto file");
115 let file_name = format!("{}.pb", name);
116 let file_path = path.join(file_name);
117 std::fs::write(&file_path, file_descriptor.proto().write_to_bytes()?).context(
118 format!("failed write lua to file {}", file_path.to_string_lossy()),
119 )?;
120 }
121 Ok(())
122 }
123
124 pub fn gen_lua(&self, path: String) -> anyhow::Result<()> {
125 let path = PathBuf::from(path);
126 for file_descriptor in self.file_descriptors.values() {
127 let name = file_descriptor
128 .name()
129 .strip_suffix(".proto")
130 .expect("file descriptor not a proto file");
131 let mut message_or_enum = vec![];
132 for message_descriptor in file_descriptor.messages() {
133 let name = message_descriptor.name().to_string();
134 let mut nested_messages_or_enums = vec![];
135 let nested_messages = message_descriptor
136 .nested_messages()
137 .map(|m| (m.name().to_string(), name.clone()));
138 nested_messages_or_enums.extend(nested_messages);
139 let nested_enums = message_descriptor
140 .nested_enums()
141 .map(|e| (e.name().to_string(), name.clone()));
142 nested_messages_or_enums.extend(nested_enums);
143 let messages =
144 self.gen_lua_message(None, nested_messages_or_enums, &message_descriptor);
145 message_or_enum.extend(messages);
146 }
147 for enum_descriptor in file_descriptor.enums() {
148 let enum_table = self.gen_lua_enum(None, &enum_descriptor);
149 message_or_enum.push(enum_table);
150 }
151 let code = message_or_enum.join("\n");
152 let file_name = format!("{}.lua", name);
153 let file_path = path.join(file_name);
154 std::fs::write(&file_path, code).context(format!(
155 "failed write lua to file {}",
156 file_path.to_string_lossy()
157 ))?;
158 }
159 Ok(())
160 }
161
    /// Renders one message as EmmyLua annotation chunks, recursing into its nested
    /// messages and enums.
    ///
    /// * `parent`: the `_`-joined Lua name of the enclosing message, or `None` at top level.
    /// * `nested_messages_or_enums`: the nested types in scope for this message, as
    ///   `(short type name, Lua name of the message that declares it)` pairs. Field types
    ///   matching one of these names are prefixed with that Lua name.
    ///
    /// Returns the chunks in emission order: nested messages (recursively) first, then
    /// nested enums, then this message's own `---@class` block. Map-entry messages yield
    /// no chunks; map fields are rendered as `table<K,V>` by `lua_type_of` instead.
    fn gen_lua_message(
        &self,
        parent: Option<String>,
        nested_messages_or_enums: Vec<(String, String)>,
        descriptor: &MessageDescriptor,
    ) -> Vec<String> {
        if descriptor.is_map_entry() {
            return vec![];
        }
        let mut messages = vec![];
        let message_name = descriptor.name().to_string();
        // Lua identifier for this message, e.g. `Outer_Inner` for a nested message.
        let message_with_parent = self.decorate_with_parent(&parent, message_name.clone());
        let class = format!("---@class {}", message_with_parent);
        for nested_message_descriptor in descriptor.nested_messages() {
            // Lua identifier the child message will get.
            let name = self.decorate_with_parent(
                &Some(message_with_parent.clone()),
                nested_message_descriptor.name().to_string(),
            );
            // The child inherits this scope and appends its own nested types, mapped to the
            // child's Lua name. Appending means inner declarations come later in the list.
            let mut child_nested_messages_or_enums = nested_messages_or_enums.clone();
            child_nested_messages_or_enums.extend(
                nested_message_descriptor
                    .nested_messages()
                    .map(|m| (m.name().to_string(), name.clone())),
            );
            child_nested_messages_or_enums.extend(
                nested_message_descriptor
                    .nested_enums()
                    .map(|e| (e.name().to_string(), name.clone())),
            );
            let nested_messages = self.gen_lua_message(
                Some(message_with_parent.clone()),
                child_nested_messages_or_enums,
                &nested_message_descriptor,
            );
            messages.extend(nested_messages);
        }
        for nested_enum_descriptor in descriptor.nested_enums() {
            let nested_enum =
                self.gen_lua_enum(Some(message_with_parent.clone()), &nested_enum_descriptor);
            messages.push(nested_enum);
        }
        let mut fields = vec![];
        for field in descriptor.fields() {
            // Lua name of the message declaring the field's enum/message type, if that type
            // is one of the nested types in scope; `None` otherwise.
            let parent = self.decorate_message_type_with_parent(
                field.runtime_field_type(),
                &nested_messages_or_enums,
            );
            let ty = self.lua_type_of(parent.clone(), field.runtime_field_type());
            let field = format!("---@field {} {}", field.name(), ty);
            fields.push(field)
        }
        // Messages without fields omit the field-annotation lines entirely.
        let message_table = if fields.is_empty() {
            format!("{}\nlocal {} = {{ }}\n", class, message_with_parent)
        } else {
            format!(
                "{}\n{}\nlocal {} = {{ }}\n",
                class,
                fields.join("\n"),
                message_with_parent
            )
        };
        messages.push(message_table);
        messages
    }
226
227 fn gen_lua_enum(&self, parent: Option<String>, descriptor: &EnumDescriptor) -> String {
228 let name = descriptor.name();
229 let message = match parent {
230 None => name.to_string(),
231 Some(parent) => {
232 format!("{}_{}", parent, name)
233 }
234 };
235 let class = format!("---@class {}", message);
236 let mut fields = vec![];
237 let mut enum_kv = vec![];
238 for value in descriptor.values() {
239 let field = format!("---@field {} number {}", value.name(), value.value());
240 fields.push(field);
241 enum_kv.push((value.name().to_string(), value.value().to_string()))
242 }
243 let kvs = enum_kv
244 .iter()
245 .map(|(k, v)| format!("{} = {}", k, v))
246 .collect::<Vec<String>>();
247 format!(
248 "{}\n{}\n{} = {{ {} }}\n",
249 class,
250 fields.join("\n"),
251 message,
252 kvs.join(", ")
253 )
254 }
255
256 fn decorate_with_parent(&self, parent: &Option<String>, name: String) -> String {
257 match &parent {
258 None => name.to_string(),
259 Some(parent) => {
260 format!("{}_{}", parent, name)
261 }
262 }
263 }
264
265 fn decorate_message_type_with_parent(
266 &self,
267 runtime_field_type: RuntimeFieldType,
268 nested_messages_or_enums: &[(String, String)],
269 ) -> Option<String> {
270 fn find_message(
271 nested_messages_or_enums: &[(String, String)],
272 name: &str,
273 ) -> Option<String> {
274 nested_messages_or_enums
275 .iter()
276 .rfind(|(n, _)| n == name)
277 .map(|(_, p)| p.clone())
278 }
279
280 fn decorate_message(
281 nested_messages_or_enums: &[(String, String)],
282 rt: RuntimeType,
283 ) -> Option<String> {
284 match rt {
285 RuntimeType::Enum(e) => find_message(nested_messages_or_enums, e.name()),
286 RuntimeType::Message(m) => find_message(nested_messages_or_enums, m.name()),
287 _ => None,
288 }
289 }
290 match runtime_field_type {
291 RuntimeFieldType::Singular(rt) => decorate_message(nested_messages_or_enums, rt),
292 RuntimeFieldType::Repeated(rt) => decorate_message(nested_messages_or_enums, rt),
293 RuntimeFieldType::Map(_, value_rt) => {
294 decorate_message(nested_messages_or_enums, value_rt)
295 }
296 }
297 }
298
299 fn lua_type_of(&self, parent: Option<String>, field_type: RuntimeFieldType) -> String {
300 fn type_of(protoc: &LuaProtoc, parent: Option<String>, rt: RuntimeType) -> String {
301 match rt {
302 RuntimeType::I32
303 | RuntimeType::I64
304 | RuntimeType::U32
305 | RuntimeType::U64
306 | RuntimeType::F32
307 | RuntimeType::F64 => "number".to_string(),
308 RuntimeType::Bool => "boolean".to_string(),
309 RuntimeType::String => "string".to_string(),
310 RuntimeType::VecU8 => "number[]".to_string(),
311 RuntimeType::Enum(e) => {
312 let name = e.name().to_string();
313 protoc.decorate_with_parent(&parent, name)
314 }
315 RuntimeType::Message(m) => {
316 let name = m.name().to_string();
317 protoc.decorate_with_parent(&parent, name)
318 }
319 }
320 }
321
322 match field_type {
323 RuntimeFieldType::Singular(rt) => type_of(self, parent, rt),
324 RuntimeFieldType::Repeated(rt) => {
325 format!("{}[]", type_of(self, parent, rt))
326 }
327 RuntimeFieldType::Map(key_rt, value_rt) => {
328 format!(
329 "table<{},{}>",
330 type_of(self, parent.clone(), key_rt),
331 type_of(self, parent, value_rt)
332 )
333 }
334 }
335 }
336
337 pub fn encode(
338 &self,
339 message_full_name: &str,
340 lua_message: &Table,
341 ) -> anyhow::Result<Box<dyn MessageDyn>> {
342 let descriptor = self
343 .message_descriptors
344 .get(message_full_name)
345 .ok_or(anyhow!("{} not found", message_full_name))?;
346 let message = self.codec.encode_message(lua_message, descriptor)?;
347 Ok(message)
348 }
349
350 pub fn decode(
351 &self,
352 lua: &Lua,
353 message_full_name: String,
354 message_bytes: &[u8],
355 ) -> anyhow::Result<Table> {
356 let descriptor = self
357 .message_descriptors
358 .get(&message_full_name)
359 .ok_or(anyhow!("{} not found", message_full_name))?;
360 let message = descriptor.parse_from_bytes(message_bytes)?;
361 let lua_message = self.codec.decode_message(lua, message.as_ref())?;
362 Ok(lua_message)
363 }
364
365 pub fn list_protos(paths: impl IntoIterator<Item = impl AsRef<Path>>) -> Vec<PathBuf> {
366 let mut protos = Vec::new();
367 for path in paths {
368 for file in walkdir::WalkDir::new(path)
369 .into_iter()
370 .filter_map(|file| file.ok())
371 {
372 let proto_path = file.path();
373 if proto_path
374 .extension()
375 .map(|e| e == "proto")
376 .unwrap_or(false)
377 {
378 protos.push(proto_path.to_path_buf());
379 }
380 }
381 }
382 protos
383 }
384}
385
386impl LuaUserData for LuaProtoc {
387 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
388 methods.add_function(
389 "parse_files",
390 |_, (inputs, includes): (Vec<String>, Vec<String>)| {
391 if inputs.is_empty() {
392 return Err(anyhow!("inputs must not empty").into());
393 }
394 if includes.is_empty() {
395 return Err(anyhow!("includes must not empty").into());
396 }
397 let protoc =
398 LuaProtoc::parse_files(inputs, includes).map_err(|e| anyhow!("{e:?}"))?;
399 Ok(protoc)
400 },
401 );
402
403 methods.add_function("parse_proto", |_, proto: String| {
404 let protoc = LuaProtoc::parse_proto(proto).map_err(|e| anyhow!("{e:?}"))?;
405 Ok(protoc)
406 });
407
408 methods.add_function("parse_pb", |_, dir: String| {
409 let protoc = LuaProtoc::parse_pb(dir).map_err(|e| anyhow!("{e:?}"))?;
410 Ok(protoc)
411 });
412
413 methods.add_method("gen_pb", |_, this, path: String| {
414 this.gen_pb(path).map_err(|e| anyhow!("{e:?}"))?;
415 Ok(())
416 });
417
418 methods.add_method("gen_lua", |_, this, path: String| {
419 this.gen_lua(path).map_err(|e| anyhow!("{e:?}"))?;
420 Ok(())
421 });
422
423 methods.add_method(
424 "encode",
425 |_, protoc, (message_full_name, lua_message): (String, Table)| {
426 let message = protoc
427 .encode(&message_full_name, &lua_message)
428 .map_err(|e| anyhow!("{e:?}"))?;
429 let mut message_bytes = Vec::with_capacity(message.compute_size_dyn() as usize);
430 message
431 .write_to_vec_dyn(&mut message_bytes)
432 .map_err(|e| anyhow!("{e:?}"))?;
433 Ok(message_bytes)
434 },
435 );
436
437 methods.add_method(
438 "decode",
439 |lua, protoc, (message_full_name, message_bytes): (String, Vec<u8>)| {
440 let message = protoc
441 .decode(lua, message_full_name, message_bytes.as_ref())
442 .map_err(|e| anyhow!("{e:?}"))?;
443 Ok(message)
444 },
445 );
446
447 methods.add_function("list_protos", |_, paths: Vec<String>| {
448 let protos = LuaProtoc::list_protos(paths)
449 .iter()
450 .map(|p| p.to_string_lossy().to_string())
451 .collect::<Vec<String>>();
452 Ok(protos)
453 });
454
455 methods.add_method("all_file_descriptors", |_, protoc, ()| {
456 let descriptors: Vec<_> = protoc.file_descriptors.values().map(Clone::clone).collect();
457 Ok(descriptors)
458 });
459
460 methods.add_method("file_descriptor_by_name", |_, protoc, name: String| {
461 let descriptor = protoc.file_descriptors.get(&name).cloned();
462 Ok(descriptor)
463 });
464
465 methods.add_method("all_message_descriptors", |_, protoc, ()| {
466 let descriptors: Vec<_> = protoc
467 .message_descriptors
468 .values()
469 .map(Clone::clone)
470 .collect();
471 Ok(descriptors)
472 });
473
474 methods.add_method("message_descriptor_by_name", |_, protoc, name: String| {
475 let descriptor = protoc.message_descriptors.get(&name).cloned();
476 Ok(descriptor)
477 });
478
479 methods.add_method("all_enum_descriptors", |_, protoc, ()| {
480 let descriptors: Vec<_> = protoc.enum_descriptors.values().map(Clone::clone).collect();
481 Ok(descriptors)
482 });
483
484 methods.add_method("enum_descriptor_by_name", |_, protoc, name: String| {
485 let descriptor = protoc.enum_descriptors.get(&name).cloned();
486 Ok(descriptor)
487 });
488 }
489}