lua_protobuf_rs/
protoc.rs

1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3
4use anyhow::{Context, anyhow};
5use mlua::prelude::LuaUserData;
6use mlua::{Lua, Table, UserDataMethods};
7use protobuf::descriptor::FileDescriptorProto;
8use protobuf::reflect::{
9    EnumDescriptor, FileDescriptor, MessageDescriptor, RuntimeFieldType, RuntimeType,
10};
11use protobuf::{CodedInputStream, Message, MessageDyn};
12
13use crate::codec::LuaProtoCodec;
14use crate::descriptor::enum_descriptor::LuaEnumDescriptor;
15use crate::descriptor::file_descriptor::LuaFileDescriptor;
16use crate::descriptor::message_descriptor::LuaMessageDescriptor;
17
18#[derive(Default)]
19pub struct LuaProtoc {
20    pub codec: LuaProtoCodec,
21    pub file_descriptors: HashMap<String, LuaFileDescriptor>,
22    pub message_descriptors: HashMap<String, LuaMessageDescriptor>,
23    pub enum_descriptors: HashMap<String, LuaEnumDescriptor>,
24}
25
26impl LuaProtoc {
27    pub fn new(descriptors: Vec<FileDescriptor>) -> Self {
28        let codec = LuaProtoCodec;
29        let mut file_descriptors = HashMap::new();
30        let mut message_descriptors = HashMap::new();
31        let mut enum_descriptors = HashMap::new();
32        for file_descriptor in descriptors {
33            for message_descriptor in file_descriptor.messages() {
34                message_descriptors.insert(
35                    message_descriptor.full_name().to_string(),
36                    From::from(message_descriptor),
37                );
38            }
39            for enum_descriptor in file_descriptor.enums() {
40                enum_descriptors.insert(
41                    enum_descriptor.full_name().to_string(),
42                    From::from(enum_descriptor),
43                );
44            }
45            file_descriptors.insert(file_descriptor.name().to_string(), file_descriptor.into());
46        }
47        Self {
48            codec,
49            file_descriptors,
50            message_descriptors,
51            enum_descriptors,
52        }
53    }
54
55    pub fn parse_files(
56        inputs: impl IntoIterator<Item = impl AsRef<Path>>,
57        includes: impl IntoIterator<Item = impl AsRef<Path>>,
58    ) -> anyhow::Result<Self> {
59        let mut parser = protobuf_parse::Parser::new();
60        parser.inputs(inputs).includes(includes);
61
62        #[cfg(feature = "google_protoc")]
63        parser.protoc();
64
65        #[cfg(feature = "vendored_protoc")]
66        parser.protoc_path(
67            &protoc_bin_vendored::protoc_bin_path()
68                .context("unable to find protoc bin vendored")?,
69        );
70
71        let file_protos = parser
72            .parse_and_typecheck()
73            .context("parse proto failed")?
74            .file_descriptors;
75        let file_descriptors: Vec<FileDescriptor> =
76            FileDescriptor::new_dynamic_fds(file_protos, &[])?;
77        let protoc = LuaProtoc::new(file_descriptors);
78        Ok(protoc)
79    }
80
81    pub fn parse_proto(proto: impl AsRef<str>) -> anyhow::Result<Self> {
82        let temp_dir = tempfile::tempdir().context("unable to get tempdir")?;
83        let tempfile = temp_dir.path().join("temp.proto");
84        std::fs::write(&tempfile, proto.as_ref()).context("unable to write data to tempfile")?;
85        LuaProtoc::parse_files([&tempfile], [&temp_dir])
86    }
87
88    pub fn parse_pb(path: impl AsRef<Path>) -> anyhow::Result<Self> {
89        let mut protos = vec![];
90        for entry in walkdir::WalkDir::new(path)
91            .into_iter()
92            .filter_map(|file| file.ok())
93        {
94            let pb_path = entry.path();
95            if pb_path.extension().map(|e| e == "pb").unwrap_or(false) {
96                let mut pb_file = std::fs::File::open(pb_path)
97                    .context(format!("failed open {}", pb_path.to_string_lossy()))?;
98                let mut input = CodedInputStream::new(&mut pb_file);
99                let proto = FileDescriptorProto::parse_from(&mut input)?;
100                protos.push(proto);
101            }
102        }
103        let file_descriptors = FileDescriptor::new_dynamic_fds(protos, &[])?;
104        let protoc = LuaProtoc::new(file_descriptors);
105        Ok(protoc)
106    }
107
108    pub fn gen_pb(&self, path: String) -> anyhow::Result<()> {
109        let path = PathBuf::from(path);
110        for file_descriptor in self.file_descriptors.values() {
111            let name = file_descriptor
112                .name()
113                .strip_suffix(".proto")
114                .expect("file descriptor not a proto file");
115            let file_name = format!("{}.pb", name);
116            let file_path = path.join(file_name);
117            std::fs::write(&file_path, file_descriptor.proto().write_to_bytes()?).context(
118                format!("failed write lua to file {}", file_path.to_string_lossy()),
119            )?;
120        }
121        Ok(())
122    }
123
124    pub fn gen_lua(&self, path: String) -> anyhow::Result<()> {
125        let path = PathBuf::from(path);
126        for file_descriptor in self.file_descriptors.values() {
127            let name = file_descriptor
128                .name()
129                .strip_suffix(".proto")
130                .expect("file descriptor not a proto file");
131            let mut message_or_enum = vec![];
132            for message_descriptor in file_descriptor.messages() {
133                let name = message_descriptor.name().to_string();
134                let mut nested_messages_or_enums = vec![];
135                let nested_messages = message_descriptor
136                    .nested_messages()
137                    .map(|m| (m.name().to_string(), name.clone()));
138                nested_messages_or_enums.extend(nested_messages);
139                let nested_enums = message_descriptor
140                    .nested_enums()
141                    .map(|e| (e.name().to_string(), name.clone()));
142                nested_messages_or_enums.extend(nested_enums);
143                let messages =
144                    self.gen_lua_message(None, nested_messages_or_enums, &message_descriptor);
145                message_or_enum.extend(messages);
146            }
147            for enum_descriptor in file_descriptor.enums() {
148                let enum_table = self.gen_lua_enum(None, &enum_descriptor);
149                message_or_enum.push(enum_table);
150            }
151            let code = message_or_enum.join("\n");
152            let file_name = format!("{}.lua", name);
153            let file_path = path.join(file_name);
154            std::fs::write(&file_path, code).context(format!(
155                "failed write lua to file {}",
156                file_path.to_string_lossy()
157            ))?;
158        }
159        Ok(())
160    }
161
162    fn gen_lua_message(
163        &self,
164        parent: Option<String>,
165        nested_messages_or_enums: Vec<(String, String)>,
166        descriptor: &MessageDescriptor,
167    ) -> Vec<String> {
168        if descriptor.is_map_entry() {
169            return vec![];
170        }
171        let mut messages = vec![];
172        let message_name = descriptor.name().to_string();
173        let message_with_parent = self.decorate_with_parent(&parent, message_name.clone());
174        let class = format!("---@class {}", message_with_parent);
175        for nested_message_descriptor in descriptor.nested_messages() {
176            let name = self.decorate_with_parent(
177                &Some(message_with_parent.clone()),
178                nested_message_descriptor.name().to_string(),
179            );
180            let mut child_nested_messages_or_enums = nested_messages_or_enums.clone();
181            child_nested_messages_or_enums.extend(
182                nested_message_descriptor
183                    .nested_messages()
184                    .map(|m| (m.name().to_string(), name.clone())),
185            );
186            child_nested_messages_or_enums.extend(
187                nested_message_descriptor
188                    .nested_enums()
189                    .map(|e| (e.name().to_string(), name.clone())),
190            );
191            let nested_messages = self.gen_lua_message(
192                Some(message_with_parent.clone()),
193                child_nested_messages_or_enums,
194                &nested_message_descriptor,
195            );
196            messages.extend(nested_messages);
197        }
198        for nested_enum_descriptor in descriptor.nested_enums() {
199            let nested_enum =
200                self.gen_lua_enum(Some(message_with_parent.clone()), &nested_enum_descriptor);
201            messages.push(nested_enum);
202        }
203        let mut fields = vec![];
204        for field in descriptor.fields() {
205            let parent = self.decorate_message_type_with_parent(
206                field.runtime_field_type(),
207                &nested_messages_or_enums,
208            );
209            let ty = self.lua_type_of(parent.clone(), field.runtime_field_type());
210            let field = format!("---@field {} {}", field.name(), ty);
211            fields.push(field)
212        }
213        let message_table = if fields.is_empty() {
214            format!("{}\nlocal {} = {{ }}\n", class, message_with_parent)
215        } else {
216            format!(
217                "{}\n{}\nlocal {} = {{ }}\n",
218                class,
219                fields.join("\n"),
220                message_with_parent
221            )
222        };
223        messages.push(message_table);
224        messages
225    }
226
227    fn gen_lua_enum(&self, parent: Option<String>, descriptor: &EnumDescriptor) -> String {
228        let name = descriptor.name();
229        let message = match parent {
230            None => name.to_string(),
231            Some(parent) => {
232                format!("{}_{}", parent, name)
233            }
234        };
235        let class = format!("---@class {}", message);
236        let mut fields = vec![];
237        let mut enum_kv = vec![];
238        for value in descriptor.values() {
239            let field = format!("---@field {} number {}", value.name(), value.value());
240            fields.push(field);
241            enum_kv.push((value.name().to_string(), value.value().to_string()))
242        }
243        let kvs = enum_kv
244            .iter()
245            .map(|(k, v)| format!("{} = {}", k, v))
246            .collect::<Vec<String>>();
247        format!(
248            "{}\n{}\n{} = {{ {} }}\n",
249            class,
250            fields.join("\n"),
251            message,
252            kvs.join(", ")
253        )
254    }
255
256    fn decorate_with_parent(&self, parent: &Option<String>, name: String) -> String {
257        match &parent {
258            None => name.to_string(),
259            Some(parent) => {
260                format!("{}_{}", parent, name)
261            }
262        }
263    }
264
265    fn decorate_message_type_with_parent(
266        &self,
267        runtime_field_type: RuntimeFieldType,
268        nested_messages_or_enums: &[(String, String)],
269    ) -> Option<String> {
270        fn find_message(
271            nested_messages_or_enums: &[(String, String)],
272            name: &str,
273        ) -> Option<String> {
274            nested_messages_or_enums
275                .iter()
276                .rfind(|(n, _)| n == name)
277                .map(|(_, p)| p.clone())
278        }
279
280        fn decorate_message(
281            nested_messages_or_enums: &[(String, String)],
282            rt: RuntimeType,
283        ) -> Option<String> {
284            match rt {
285                RuntimeType::Enum(e) => find_message(nested_messages_or_enums, e.name()),
286                RuntimeType::Message(m) => find_message(nested_messages_or_enums, m.name()),
287                _ => None,
288            }
289        }
290        match runtime_field_type {
291            RuntimeFieldType::Singular(rt) => decorate_message(nested_messages_or_enums, rt),
292            RuntimeFieldType::Repeated(rt) => decorate_message(nested_messages_or_enums, rt),
293            RuntimeFieldType::Map(_, value_rt) => {
294                decorate_message(nested_messages_or_enums, value_rt)
295            }
296        }
297    }
298
299    fn lua_type_of(&self, parent: Option<String>, field_type: RuntimeFieldType) -> String {
300        fn type_of(protoc: &LuaProtoc, parent: Option<String>, rt: RuntimeType) -> String {
301            match rt {
302                RuntimeType::I32
303                | RuntimeType::I64
304                | RuntimeType::U32
305                | RuntimeType::U64
306                | RuntimeType::F32
307                | RuntimeType::F64 => "number".to_string(),
308                RuntimeType::Bool => "boolean".to_string(),
309                RuntimeType::String => "string".to_string(),
310                RuntimeType::VecU8 => "number[]".to_string(),
311                RuntimeType::Enum(e) => {
312                    let name = e.name().to_string();
313                    protoc.decorate_with_parent(&parent, name)
314                }
315                RuntimeType::Message(m) => {
316                    let name = m.name().to_string();
317                    protoc.decorate_with_parent(&parent, name)
318                }
319            }
320        }
321
322        match field_type {
323            RuntimeFieldType::Singular(rt) => type_of(self, parent, rt),
324            RuntimeFieldType::Repeated(rt) => {
325                format!("{}[]", type_of(self, parent, rt))
326            }
327            RuntimeFieldType::Map(key_rt, value_rt) => {
328                format!(
329                    "table<{},{}>",
330                    type_of(self, parent.clone(), key_rt),
331                    type_of(self, parent, value_rt)
332                )
333            }
334        }
335    }
336
337    pub fn encode(
338        &self,
339        message_full_name: &str,
340        lua_message: &Table,
341    ) -> anyhow::Result<Box<dyn MessageDyn>> {
342        let descriptor = self
343            .message_descriptors
344            .get(message_full_name)
345            .ok_or(anyhow!("{} not found", message_full_name))?;
346        let message = self.codec.encode_message(lua_message, descriptor)?;
347        Ok(message)
348    }
349
350    pub fn decode(
351        &self,
352        lua: &Lua,
353        message_full_name: String,
354        message_bytes: &[u8],
355    ) -> anyhow::Result<Table> {
356        let descriptor = self
357            .message_descriptors
358            .get(&message_full_name)
359            .ok_or(anyhow!("{} not found", message_full_name))?;
360        let message = descriptor.parse_from_bytes(message_bytes)?;
361        let lua_message = self.codec.decode_message(lua, message.as_ref())?;
362        Ok(lua_message)
363    }
364
365    pub fn list_protos(paths: impl IntoIterator<Item = impl AsRef<Path>>) -> Vec<PathBuf> {
366        let mut protos = Vec::new();
367        for path in paths {
368            for file in walkdir::WalkDir::new(path)
369                .into_iter()
370                .filter_map(|file| file.ok())
371            {
372                let proto_path = file.path();
373                if proto_path
374                    .extension()
375                    .map(|e| e == "proto")
376                    .unwrap_or(false)
377                {
378                    protos.push(proto_path.to_path_buf());
379                }
380            }
381        }
382        protos
383    }
384}
385
386impl LuaUserData for LuaProtoc {
387    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
388        methods.add_function(
389            "parse_files",
390            |_, (inputs, includes): (Vec<String>, Vec<String>)| {
391                if inputs.is_empty() {
392                    return Err(anyhow!("inputs must not empty").into());
393                }
394                if includes.is_empty() {
395                    return Err(anyhow!("includes must not empty").into());
396                }
397                let protoc =
398                    LuaProtoc::parse_files(inputs, includes).map_err(|e| anyhow!("{e:?}"))?;
399                Ok(protoc)
400            },
401        );
402
403        methods.add_function("parse_proto", |_, proto: String| {
404            let protoc = LuaProtoc::parse_proto(proto).map_err(|e| anyhow!("{e:?}"))?;
405            Ok(protoc)
406        });
407
408        methods.add_function("parse_pb", |_, dir: String| {
409            let protoc = LuaProtoc::parse_pb(dir).map_err(|e| anyhow!("{e:?}"))?;
410            Ok(protoc)
411        });
412
413        methods.add_method("gen_pb", |_, this, path: String| {
414            this.gen_pb(path).map_err(|e| anyhow!("{e:?}"))?;
415            Ok(())
416        });
417
418        methods.add_method("gen_lua", |_, this, path: String| {
419            this.gen_lua(path).map_err(|e| anyhow!("{e:?}"))?;
420            Ok(())
421        });
422
423        methods.add_method(
424            "encode",
425            |_, protoc, (message_full_name, lua_message): (String, Table)| {
426                let message = protoc
427                    .encode(&message_full_name, &lua_message)
428                    .map_err(|e| anyhow!("{e:?}"))?;
429                let mut message_bytes = Vec::with_capacity(message.compute_size_dyn() as usize);
430                message
431                    .write_to_vec_dyn(&mut message_bytes)
432                    .map_err(|e| anyhow!("{e:?}"))?;
433                Ok(message_bytes)
434            },
435        );
436
437        methods.add_method(
438            "decode",
439            |lua, protoc, (message_full_name, message_bytes): (String, Vec<u8>)| {
440                let message = protoc
441                    .decode(lua, message_full_name, message_bytes.as_ref())
442                    .map_err(|e| anyhow!("{e:?}"))?;
443                Ok(message)
444            },
445        );
446
447        methods.add_function("list_protos", |_, paths: Vec<String>| {
448            let protos = LuaProtoc::list_protos(paths)
449                .iter()
450                .map(|p| p.to_string_lossy().to_string())
451                .collect::<Vec<String>>();
452            Ok(protos)
453        });
454
455        methods.add_method("all_file_descriptors", |_, protoc, ()| {
456            let descriptors: Vec<_> = protoc.file_descriptors.values().map(Clone::clone).collect();
457            Ok(descriptors)
458        });
459
460        methods.add_method("file_descriptor_by_name", |_, protoc, name: String| {
461            let descriptor = protoc.file_descriptors.get(&name).cloned();
462            Ok(descriptor)
463        });
464
465        methods.add_method("all_message_descriptors", |_, protoc, ()| {
466            let descriptors: Vec<_> = protoc
467                .message_descriptors
468                .values()
469                .map(Clone::clone)
470                .collect();
471            Ok(descriptors)
472        });
473
474        methods.add_method("message_descriptor_by_name", |_, protoc, name: String| {
475            let descriptor = protoc.message_descriptors.get(&name).cloned();
476            Ok(descriptor)
477        });
478
479        methods.add_method("all_enum_descriptors", |_, protoc, ()| {
480            let descriptors: Vec<_> = protoc.enum_descriptors.values().map(Clone::clone).collect();
481            Ok(descriptors)
482        });
483
484        methods.add_method("enum_descriptor_by_name", |_, protoc, name: String| {
485            let descriptor = protoc.enum_descriptors.get(&name).cloned();
486            Ok(descriptor)
487        });
488    }
489}