lua_protobuf_rs/
protoc.rs

1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3
4use anyhow::{anyhow, Context};
5use mlua::prelude::LuaUserData;
6use mlua::{Lua, Table, UserDataMethods};
7use protobuf::descriptor::FileDescriptorProto;
8use protobuf::reflect::{EnumDescriptor, FileDescriptor, MessageDescriptor, RuntimeFieldType, RuntimeType};
9use protobuf::{CodedInputStream, Message, MessageDyn};
10
11use crate::codec::LuaProtoCodec;
12use crate::descriptor::enum_descriptor::LuaEnumDescriptor;
13use crate::descriptor::file_descriptor::LuaFileDescriptor;
14use crate::descriptor::message_descriptor::LuaMessageDescriptor;
15
16#[derive(Default)]
17pub struct LuaProtoc {
18    pub codec: LuaProtoCodec,
19    pub file_descriptors: HashMap<String, LuaFileDescriptor>,
20    pub message_descriptors: HashMap<String, LuaMessageDescriptor>,
21    pub enum_descriptors: HashMap<String, LuaEnumDescriptor>,
22}
23
24impl LuaProtoc {
25    pub fn new(descriptors: Vec<FileDescriptor>) -> Self {
26        let codec = LuaProtoCodec::default();
27        let mut file_descriptors = HashMap::new();
28        let mut message_descriptors = HashMap::new();
29        let mut enum_descriptors = HashMap::new();
30        for file_descriptor in descriptors {
31            for message_descriptor in file_descriptor.messages() {
32                message_descriptors.insert(message_descriptor.full_name().to_string(), From::from(message_descriptor));
33            }
34            for enum_descriptor in file_descriptor.enums() {
35                enum_descriptors.insert(enum_descriptor.full_name().to_string(), From::from(enum_descriptor));
36            }
37            file_descriptors.insert(file_descriptor.name().to_string(), file_descriptor.into());
38        };
39        Self {
40            codec,
41            file_descriptors,
42            message_descriptors,
43            enum_descriptors,
44        }
45    }
46
47    pub fn parse_files(inputs: impl IntoIterator<Item=impl AsRef<Path>>, includes: impl IntoIterator<Item=impl AsRef<Path>>) -> anyhow::Result<Self> {
48        let mut parser = protobuf_parse::Parser::new();
49        parser.inputs(inputs).includes(includes);
50
51        #[cfg(feature = "google_protoc")]
52        parser.protoc();
53
54        #[cfg(feature = "vendored_protoc")]
55        parser.protoc_path(&protoc_bin_vendored::protoc_bin_path().context("unable to find protoc bin vendored")?);
56
57        let file_protos = parser.parse_and_typecheck().context("parse proto failed")?.file_descriptors;
58        let file_descriptors: Vec<FileDescriptor> = FileDescriptor::new_dynamic_fds(file_protos, &[])?;
59        let protoc = LuaProtoc::new(file_descriptors);
60        Ok(protoc)
61    }
62
63    pub fn parse_proto(proto: impl AsRef<str>) -> anyhow::Result<Self> {
64        let temp_dir = tempfile::tempdir().context("unable to get tempdir")?;
65        let tempfile = temp_dir.path().join("temp.proto");
66        std::fs::write(&tempfile, proto.as_ref()).context("unable to write data to tempfile")?;
67        LuaProtoc::parse_files([&tempfile], [&temp_dir])
68    }
69
70    pub fn parse_pb(path: impl AsRef<Path>) -> anyhow::Result<Self> {
71        let mut protos = vec![];
72        for entry in walkdir::WalkDir::new(path).into_iter().filter_map(|file| file.ok()) {
73            let pb_path = entry.path();
74            if pb_path.extension().and_then(|e| Some(e == "pb")).unwrap_or(false) {
75                let mut pb_file = std::fs::File::open(pb_path).context(format!("failed open {}", pb_path.to_string_lossy()))?;
76                let mut input = CodedInputStream::new(&mut pb_file);
77                let proto = FileDescriptorProto::parse_from(&mut input)?;
78                protos.push(proto);
79            }
80        }
81        let file_descriptors = FileDescriptor::new_dynamic_fds(protos, &[])?;
82        let protoc = LuaProtoc::new(file_descriptors);
83        Ok(protoc)
84    }
85
86    pub fn gen_pb(&self, path: String) -> anyhow::Result<()> {
87        let path = PathBuf::from(path);
88        for (_, file_descriptor) in &self.file_descriptors {
89            let name = file_descriptor.name().strip_suffix(".proto").expect("file descriptor not a proto file");
90            let file_name = format!("{}.pb", name);
91            let file_path = path.join(file_name);
92            std::fs::write(&file_path, file_descriptor.proto().write_to_bytes()?).context(format!("failed write lua to file {}", file_path.to_string_lossy()))?;
93        }
94        Ok(())
95    }
96
97    pub fn gen_lua(&self, path: String) -> anyhow::Result<()> {
98        let path = PathBuf::from(path);
99        for (_, file_descriptor) in &self.file_descriptors {
100            let name = file_descriptor.name().strip_suffix(".proto").expect("file descriptor not a proto file");
101            let mut message_or_enum = vec![];
102            for message_descriptor in file_descriptor.messages() {
103                let name = message_descriptor.name().to_string();
104                let mut nested_messages_or_enums = vec![];
105                let nested_messages = message_descriptor.nested_messages().map(|m| (m.name().to_string(), name.clone()));
106                nested_messages_or_enums.extend(nested_messages);
107                let nested_enums = message_descriptor.nested_enums().map(|e| (e.name().to_string(), name.clone()));
108                nested_messages_or_enums.extend(nested_enums);
109                let messages = self.gen_lua_message(None, nested_messages_or_enums, &message_descriptor);
110                message_or_enum.extend(messages);
111            }
112            for enum_descriptor in file_descriptor.enums() {
113                let enum_table = self.gen_lua_enum(None, &enum_descriptor);
114                message_or_enum.push(enum_table);
115            }
116            let code = message_or_enum.join("\n");
117            let file_name = format!("{}.lua", name);
118            let file_path = path.join(file_name);
119            std::fs::write(&file_path, code).context(format!("failed write lua to file {}", file_path.to_string_lossy()))?;
120        }
121        Ok(())
122    }
123
124    fn gen_lua_message(&self, parent: Option<String>, nested_messages_or_enums: Vec<(String, String)>, descriptor: &MessageDescriptor) -> Vec<String> {
125        if descriptor.is_map_entry() {
126            return vec![];
127        }
128        let mut messages = vec![];
129        let message_name = descriptor.name().to_string();
130        let message_with_parent = self.decorate_with_parent(&parent, message_name.clone());
131        let class = format!("---@class {}", message_with_parent);
132        for nested_message_descriptor in descriptor.nested_messages() {
133            let name = self.decorate_with_parent(&Some(message_with_parent.clone()), nested_message_descriptor.name().to_string());
134            let mut child_nested_messages_or_enums = nested_messages_or_enums.clone();
135            child_nested_messages_or_enums.extend(nested_message_descriptor.nested_messages().map(|m| (m.name().to_string(), name.clone())));
136            child_nested_messages_or_enums.extend(nested_message_descriptor.nested_enums().map(|e| (e.name().to_string(), name.clone())));
137            let nested_messages = self.gen_lua_message(Some(message_with_parent.clone()), child_nested_messages_or_enums, &nested_message_descriptor);
138            messages.extend(nested_messages);
139        }
140        for nested_enum_descriptor in descriptor.nested_enums() {
141            let nested_enum = self.gen_lua_enum(Some(message_with_parent.clone()), &nested_enum_descriptor);
142            messages.push(nested_enum);
143        }
144        let mut fields = vec![];
145        for field in descriptor.fields() {
146            let parent = self.decorate_message_type_with_parent(field.runtime_field_type(), &nested_messages_or_enums);
147            let ty = self.lua_type_of(parent.clone(), field.runtime_field_type());
148            let field = format!("---@field {} {}", field.name(), ty);
149            fields.push(field)
150        }
151        let message_table = if fields.is_empty() {
152            format!("{}\nlocal {} = {{ }}\n", class, message_with_parent)
153        } else {
154            format!("{}\n{}\nlocal {} = {{ }}\n", class, fields.join("\n"), message_with_parent)
155        };
156        messages.push(message_table);
157        messages
158    }
159
160    fn gen_lua_enum(&self, parent: Option<String>, descriptor: &EnumDescriptor) -> String {
161        let name = descriptor.name();
162        let message = match parent {
163            None => {
164                format!("{}", name)
165            }
166            Some(parent) => {
167                format!("{}_{}", parent, name)
168            }
169        };
170        let class = format!("---@class {}", message);
171        let mut fields = vec![];
172        let mut enum_kv = vec![];
173        for value in descriptor.values() {
174            let field = format!("---@field {} number {}", value.name(), value.value());
175            fields.push(field);
176            enum_kv.push((value.name().to_string(), value.value().to_string()))
177        }
178        let kvs = enum_kv.iter().map(|(k, v)| format!("{} = {}", k, v)).collect::<Vec<String>>();
179        format!("{}\n{}\n{} = {{ {} }}\n", class, fields.join("\n"), message, kvs.join(", "))
180    }
181
182    fn decorate_with_parent(&self, parent: &Option<String>, name: String) -> String {
183        let message = match &parent {
184            None => {
185                format!("{}", name)
186            }
187            Some(parent) => {
188                format!("{}_{}", parent, name)
189            }
190        };
191        message
192    }
193
194    fn decorate_message_type_with_parent(&self, runtime_field_type: RuntimeFieldType, nested_messages_or_enums: &Vec<(String, String)>) -> Option<String> {
195        fn find_message(nested_messages_or_enums: &Vec<(String, String)>, name: &str) -> Option<String> {
196            match nested_messages_or_enums.iter().rfind(|(n, _)| n == name) {
197                None => {
198                    None
199                }
200                Some((_, p)) => {
201                    Some(p.clone())
202                }
203            }
204        }
205
206        fn decorate_message(nested_messages_or_enums: &Vec<(String, String)>, rt: RuntimeType) -> Option<String> {
207            match rt {
208                RuntimeType::Enum(e) => {
209                    find_message(nested_messages_or_enums, e.name())
210                }
211                RuntimeType::Message(m) => {
212                    find_message(nested_messages_or_enums, m.name())
213                }
214                _ => None
215            }
216        }
217        match runtime_field_type {
218            RuntimeFieldType::Singular(rt) => {
219                decorate_message(nested_messages_or_enums, rt)
220            }
221            RuntimeFieldType::Repeated(rt) => {
222                decorate_message(nested_messages_or_enums, rt)
223            }
224            RuntimeFieldType::Map(_, value_rt) => {
225                decorate_message(nested_messages_or_enums, value_rt)
226            }
227        }
228    }
229
230    fn lua_type_of(&self, parent: Option<String>, field_type: RuntimeFieldType) -> String {
231        fn type_of(protoc: &LuaProtoc, parent: Option<String>, rt: RuntimeType) -> String {
232            let ty = match rt {
233                RuntimeType::I32 |
234                RuntimeType::I64 |
235                RuntimeType::U32 |
236                RuntimeType::U64 |
237                RuntimeType::F32 |
238                RuntimeType::F64 => "number".to_string(),
239                RuntimeType::Bool => "boolean".to_string(),
240                RuntimeType::String => "string".to_string(),
241                RuntimeType::VecU8 => "number[]".to_string(),
242                RuntimeType::Enum(e) => {
243                    let name = e.name().to_string();
244                    protoc.decorate_with_parent(&parent, name)
245                }
246                RuntimeType::Message(m) => {
247                    let name = m.name().to_string();
248                    protoc.decorate_with_parent(&parent, name)
249                }
250            };
251            ty
252        }
253        let ty = match field_type {
254            RuntimeFieldType::Singular(rt) => {
255                type_of(self, parent, rt)
256            }
257            RuntimeFieldType::Repeated(rt) => {
258                format!("{}[]", type_of(self, parent, rt))
259            }
260            RuntimeFieldType::Map(key_rt, value_rt) => {
261                format!("table<{},{}>", type_of(self, parent.clone(), key_rt), type_of(self, parent, value_rt))
262            }
263        };
264        ty
265    }
266
267
268    pub fn encode(&self, message_full_name: &str, lua_message: &Table) -> anyhow::Result<Box<dyn MessageDyn>> {
269        let descriptor = self.message_descriptors.get(message_full_name).ok_or(anyhow!("{} not found",message_full_name))?;
270        let message = self.codec.encode_message(lua_message, descriptor)?;
271        Ok(message)
272    }
273
274    pub fn decode(&self, lua: &Lua, message_full_name: String, message_bytes: &[u8]) -> anyhow::Result<Table> {
275        let descriptor = self.message_descriptors.get(&message_full_name).ok_or(anyhow!("{} not found",message_full_name))?;
276        let message = descriptor.parse_from_bytes(message_bytes)?;
277        let lua_message = self.codec.decode_message(lua, message.as_ref())?;
278        Ok(lua_message)
279    }
280
281    pub fn list_protos(paths: impl IntoIterator<Item=impl AsRef<Path>>) -> Vec<PathBuf> {
282        let mut protos = Vec::new();
283        for path in paths {
284            for file in walkdir::WalkDir::new(path).into_iter().filter_map(|file| file.ok()) {
285                let proto_path = file.path();
286                if proto_path.extension().and_then(|e| Some(e == "proto")).unwrap_or(false) {
287                    protos.push(proto_path.to_path_buf());
288                }
289            }
290        }
291        protos
292    }
293}
294
295impl LuaUserData for LuaProtoc {
296    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
297        methods.add_function("parse_files", |_, (inputs, includes): (Vec<String>, Vec<String>)| {
298            if inputs.is_empty() {
299                return Err(anyhow!("inputs must not empty").into());
300            }
301            if includes.is_empty() {
302                return Err(anyhow!("includes must not empty").into());
303            }
304            let protoc = LuaProtoc::parse_files(inputs, includes).map_err(|e| anyhow!("{e:?}"))?;
305            Ok(protoc)
306        });
307
308        methods.add_function("parse_proto", |_, proto: String| {
309            let protoc = LuaProtoc::parse_proto(proto).map_err(|e| anyhow!("{e:?}"))?;
310            Ok(protoc)
311        });
312
313        methods.add_function("parse_pb", |_, dir: String| {
314            let protoc = LuaProtoc::parse_pb(dir).map_err(|e| anyhow!("{e:?}"))?;
315            Ok(protoc)
316        });
317
318        methods.add_method("gen_pb", |_, this, path: String| {
319            this.gen_pb(path).map_err(|e| anyhow!("{e:?}"))?;
320            Ok(())
321        });
322
323        methods.add_method("gen_lua", |_, this, path: String| {
324            this.gen_lua(path).map_err(|e| anyhow!("{e:?}"))?;
325            Ok(())
326        });
327
328        methods.add_method("encode", |_, protoc, (message_full_name, lua_message): (String, Table)| {
329            let message = protoc.encode(&message_full_name, &lua_message).map_err(|e| anyhow!("{e:?}"))?;
330            let mut message_bytes = Vec::with_capacity(message.compute_size_dyn() as usize);
331            message.write_to_vec_dyn(&mut message_bytes).map_err(|e| anyhow!("{e:?}"))?;
332            Ok(message_bytes)
333        });
334
335        methods.add_method("decode", |lua, protoc, (message_full_name, message_bytes): (String, Vec<u8>)| {
336            let message = protoc.decode(lua, message_full_name, message_bytes.as_ref()).map_err(|e| anyhow!("{e:?}"))?;
337            Ok(message)
338        });
339
340        methods.add_function("list_protos", |_, paths: Vec<String>| {
341            let protos = LuaProtoc::list_protos(paths).iter().map(|p| { p.to_string_lossy().to_string() }).collect::<Vec<String>>();
342            Ok(protos)
343        });
344
345        methods.add_method("all_file_descriptors", |_, protoc, ()| {
346            let descriptors: Vec<_> = protoc.file_descriptors.values().map(Clone::clone).collect();
347            Ok(descriptors)
348        });
349
350        methods.add_method("file_descriptor_by_name", |_, protoc, name: String| {
351            let descriptor = protoc.file_descriptors.get(&name).map(Clone::clone);
352            Ok(descriptor)
353        });
354
355        methods.add_method("all_message_descriptors", |_, protoc, ()| {
356            let descriptors: Vec<_> = protoc.message_descriptors.values().map(Clone::clone).collect();
357            Ok(descriptors)
358        });
359
360        methods.add_method("message_descriptor_by_name", |_, protoc, name: String| {
361            let descriptor = protoc.message_descriptors.get(&name).map(Clone::clone);
362            Ok(descriptor)
363        });
364
365        methods.add_method("all_enum_descriptors", |_, protoc, ()| {
366            let descriptors: Vec<_> = protoc.enum_descriptors.values().map(Clone::clone).collect();
367            Ok(descriptors)
368        });
369
370        methods.add_method("enum_descriptor_by_name", |_, protoc, name: String| {
371            let descriptor = protoc.enum_descriptors.get(&name).map(Clone::clone);
372            Ok(descriptor)
373        });
374    }
375}