1#![doc = include_str!("../README.md")]
2
3use std::{
4 borrow::Cow,
5 collections::{BTreeMap, HashSet},
6 fmt, str,
7};
8
9use once_cell::sync::Lazy;
10use prost::Message;
11use prost_build::Module;
12use prost_types::{
13 compiler::{code_generator_response::File, CodeGeneratorRequest},
14 FileDescriptorProto,
15};
16
17use self::generator::{CoreProstGenerator, FileDescriptorSetGenerator};
18
19mod generator;
20
21pub use self::generator::{Error, Generator, GeneratorResultExt, Result};
22
23pub fn execute(raw_request: &[u8]) -> generator::Result {
25 let request = CodeGeneratorRequest::decode(raw_request)?;
26 let params = request.parameter().parse::<Parameters>()?;
27
28 let module_request_set = ModuleRequestSet::new(
29 request.file_to_generate,
30 request.proto_file,
31 raw_request,
32 params.prost.default_package_filename(),
33 params.prost.flat_output_dir,
34 )?;
35 let file_descriptor_set_generator = params
36 .file_descriptor_set
37 .then_some(FileDescriptorSetGenerator);
38
39 let files = CoreProstGenerator::new(params.prost.to_prost_config())
40 .chain(file_descriptor_set_generator)
41 .generate(&module_request_set)?;
42
43 Ok(files)
44}
45
46pub struct ModuleRequestSet {
48 requests: BTreeMap<Module, ModuleRequest>,
49}
50
51impl ModuleRequestSet {
52 pub fn new<I>(
56 input_protos: I,
57 proto_file: Vec<FileDescriptorProto>,
58 raw_request: &[u8],
59 default_package_filename: Option<&str>,
60 flat_output_dir: bool,
61 ) -> std::result::Result<Self, prost::DecodeError>
62 where
63 I: IntoIterator<Item = String>,
64 {
65 let raw_protos = RawProtos::decode(raw_request)?;
66
67 Ok(Self::new_decoded(
68 input_protos,
69 proto_file,
70 raw_protos,
71 default_package_filename.unwrap_or("_"),
72 flat_output_dir,
73 ))
74 }
75
76 fn new_decoded<I>(
77 input_protos: I,
78 proto_file: Vec<FileDescriptorProto>,
79 raw_protos: RawProtos,
80 default_package_filename: &str,
81 flat_output_dir: bool,
82 ) -> Self
83 where
84 I: IntoIterator<Item = String>,
85 {
86 let input_protos: HashSet<_> = input_protos.into_iter().collect();
87
88 let requests = proto_file.into_iter().zip(raw_protos.proto_file).fold(
89 BTreeMap::new(),
90 |mut acc, (proto, raw)| {
91 let module = Module::from_protobuf_package_name(proto.package());
92 let proto_filename = proto.name();
93 let entry = acc.entry(module.clone()).or_insert_with(|| {
94 ModuleRequest::new(proto.package().to_owned(), module, flat_output_dir)
95 });
96
97 if entry.output_filename().is_none() && input_protos.contains(proto_filename) {
98 let filename = match proto.package() {
99 "" => default_package_filename.to_owned(),
100 package => format!("{package}.rs"),
101 };
102 entry.with_output_filename(filename);
103 }
104
105 entry.push_file_descriptor_proto(proto, raw);
106 acc
107 },
108 );
109
110 Self { requests }
111 }
112
113 pub fn requests(&self) -> impl Iterator<Item = (&Module, &ModuleRequest)> {
115 self.requests.iter()
116 }
117
118 pub fn for_module(&self, module: &Module) -> Option<&ModuleRequest> {
120 self.requests.get(module)
121 }
122
123 pub fn modules(&self) -> impl Iterator<Item = &Module> {
124 self.requests.keys()
125 }
126}
127
128pub struct ModuleRequest {
130 proto_package_name: String,
131 module: Module,
132 flat_output_dir: bool,
133 output_filename: Option<String>,
134 files: Vec<FileDescriptorProto>,
135 raw: Vec<Vec<u8>>,
136}
137
138impl ModuleRequest {
139 fn new(proto_package_name: String, module: Module, flat_output_dir: bool) -> Self {
140 Self {
141 proto_package_name,
142 module,
143 flat_output_dir,
144 output_filename: None,
145 files: Vec::new(),
146 raw: Vec::new(),
147 }
148 }
149
150 fn with_output_filename(&mut self, filename: String) {
151 self.output_filename = Some(filename);
152 }
153
154 fn push_file_descriptor_proto(&mut self, encoded: FileDescriptorProto, raw: Vec<u8>) {
155 self.files.push(encoded);
156 self.raw.push(raw);
157 }
158
159 pub fn proto_package_name(&self) -> &str {
161 &self.proto_package_name
162 }
163
164 pub fn output_filename(&self) -> Option<&str> {
166 self.output_filename.as_deref()
167 }
168
169 pub fn output_dir(&self) -> String {
170 if self.flat_output_dir {
171 return String::new();
172 }
173 let mut output_dir = self.module.parts().collect::<Vec<_>>().join("/");
174 if !output_dir.is_empty() {
175 output_dir.push('/');
176 }
177 output_dir
178 }
179
180 pub fn output_filepath(&self) -> Option<String> {
181 self.output_filename().map(|f| {
182 let dir = self.output_dir();
183 format!("{dir}{f}")
184 })
185 }
186
187 pub fn files(&self) -> impl Iterator<Item = &FileDescriptorProto> {
189 self.files.iter()
190 }
191
192 pub fn raw_files(&self) -> impl Iterator<Item = &[u8]> {
194 self.raw.iter().map(|b| b.as_slice())
195 }
196
197 pub(crate) fn write_to_file<F: FnOnce(&mut String)>(&self, f: F) -> Option<File> {
199 self.output_filepath().map(|name| {
200 let mut content = String::with_capacity(8_192);
201 f(&mut content);
202
203 File {
204 name: Some(name),
205 content: Some(content),
206 ..Default::default()
207 }
208 })
209 }
210
211 pub fn append_to_file<F: FnOnce(&mut String)>(&self, f: F) -> Option<File> {
216 self.output_filepath().map(|name| {
217 let mut content = String::new();
218 f(&mut content);
219
220 File {
221 name: Some(name),
222 content: Some(content),
223 insertion_point: Some("module".to_owned()),
224 ..Default::default()
225 }
226 })
227 }
228}
229
230#[derive(Debug, Default)]
234struct Parameters {
235 prost: ProstParameters,
237
238 file_descriptor_set: bool,
240}
241
242#[derive(Debug, Default)]
244struct ProstParameters {
245 btree_map: Vec<String>,
246 bytes: Vec<String>,
247 boxed: Vec<String>,
248 disable_comments: Vec<String>,
249 skip_debug: Vec<String>,
250 default_package_filename: Option<String>,
251 extern_path: Vec<(String, String)>,
252 type_attribute: Vec<(String, String)>,
253 field_attribute: Vec<(String, String)>,
254 enum_attribute: Vec<(String, String)>,
255 message_attribute: Vec<(String, String)>,
256 compile_well_known_types: bool,
257 retain_enum_prefix: bool,
258 enable_type_names: bool,
259 flat_output_dir: bool,
260}
261
262impl ProstParameters {
263 fn to_prost_config(&self) -> prost_build::Config {
265 let mut config = prost_build::Config::new();
266 config.btree_map(self.btree_map.iter());
267 config.bytes(self.bytes.iter());
268 for b in self.boxed.iter() {
269 config.boxed(b);
270 }
271 config.disable_comments(self.disable_comments.iter());
272 config.skip_debug(self.skip_debug.iter());
273
274 if let Some(filename) = self.default_package_filename.as_deref() {
275 config.default_package_filename(filename);
276 }
277
278 for (proto_path, rust_path) in &self.extern_path {
279 config.extern_path(proto_path, rust_path);
280 }
281 for (proto_path, attribute) in &self.type_attribute {
282 config.type_attribute(proto_path, attribute);
283 }
284 for (proto_path, attribute) in &self.field_attribute {
285 config.field_attribute(proto_path, attribute);
286 }
287 for (proto_path, attribute) in &self.enum_attribute {
288 config.enum_attribute(proto_path, attribute);
289 }
290 for (proto_path, attribute) in &self.message_attribute {
291 config.message_attribute(proto_path, attribute);
292 }
293
294 if self.compile_well_known_types {
295 config.compile_well_known_types();
296 }
297 if self.retain_enum_prefix {
298 config.retain_enum_prefix();
299 }
300 if self.enable_type_names {
301 config.enable_type_names();
302 }
303
304 config
305 }
306
307 fn default_package_filename(&self) -> Option<&str> {
308 self.default_package_filename.as_deref()
309 }
310
311 fn try_handle_parameter<'a>(&mut self, param: Param<'a>) -> std::result::Result<(), Param<'a>> {
312 match param {
313 Param::Value {
314 param: "btree_map",
315 value,
316 } => self.btree_map.push(value.to_string()),
317 Param::Value {
318 param: "bytes",
319 value,
320 } => self.bytes.push(value.to_string()),
321 Param::Value {
322 param: "boxed",
323 value,
324 } => self.boxed.push(value.to_string()),
325 Param::Parameter {
326 param: "default_package_filename",
327 }
328 | Param::Value {
329 param: "default_package_filename",
330 ..
331 } => self.default_package_filename = param.value().map(|s| s.into_owned()),
332 Param::Parameter {
333 param: "compile_well_known_types",
334 }
335 | Param::Value {
336 param: "compile_well_known_types",
337 value: "true",
338 } => self.compile_well_known_types = true,
339 Param::Value {
340 param: "compile_well_known_types",
341 value: "false",
342 } => (),
343 Param::Value {
344 param: "disable_comments",
345 value,
346 } => self.disable_comments.push(value.to_string()),
347 Param::Value {
348 param: "skip_debug",
349 value,
350 } => self.skip_debug.push(value.to_string()),
351 Param::Parameter {
352 param: "retain_enum_prefix",
353 }
354 | Param::Value {
355 param: "retain_enum_prefix",
356 value: "true",
357 } => self.retain_enum_prefix = true,
358 Param::Value {
359 param: "retain_enum_prefix",
360 value: "false",
361 } => (),
362 Param::KeyValue {
363 param: "extern_path",
364 key: prefix,
365 value: module,
366 } => self.extern_path.push((prefix.to_string(), module)),
367 Param::KeyValue {
368 param: "type_attribute",
369 key: prefix,
370 value: module,
371 } => self.type_attribute.push((
372 prefix.to_string(),
373 module.replace(r"\,", ",").replace(r"\\", r"\"),
374 )),
375 Param::KeyValue {
376 param: "field_attribute",
377 key: prefix,
378 value: module,
379 } => self.field_attribute.push((
380 prefix.to_string(),
381 module.replace(r"\,", ",").replace(r"\\", r"\"),
382 )),
383 Param::KeyValue {
384 param: "enum_attribute",
385 key: prefix,
386 value: module,
387 } => self.enum_attribute.push((
388 prefix.to_string(),
389 module.replace(r"\,", ",").replace(r"\\", r"\"),
390 )),
391 Param::KeyValue {
392 param: "message_attribute",
393 key: prefix,
394 value: module,
395 } => self.message_attribute.push((
396 prefix.to_string(),
397 module.replace(r"\,", ",").replace(r"\\", r"\"),
398 )),
399 Param::Parameter {
400 param: "enable_type_names",
401 }
402 | Param::Value {
403 param: "enable_type_names",
404 value: "true",
405 } => self.enable_type_names = true,
406 Param::Value {
407 param: "enable_type_names",
408 value: "false",
409 } => (),
410 Param::Parameter {
411 param: "flat_output_dir",
412 }
413 | Param::Value {
414 param: "flat_output_dir",
415 value: "true",
416 } => self.flat_output_dir = true,
417 Param::Value {
418 param: "flat_output_dir",
419 value: "false",
420 } => (),
421 _ => return Err(param),
422 }
423
424 Ok(())
425 }
426}
427
428static PARAMETER: Lazy<regex::Regex> = Lazy::new(|| {
443 regex::Regex::new(
444 r"(?:(?P<param>[^,=]+)(?:=(?P<key>[^,=]+)(?:=(?P<value>(?:[^,\\]|\\,|\\\\)+))?)?)",
445 )
446 .unwrap()
447});
448
449pub struct Params<'a> {
450 params: Vec<Param<'a>>,
451}
452
453impl<'a> IntoIterator for Params<'a> {
454 type IntoIter = <Vec<Param<'a>> as IntoIterator>::IntoIter;
455 type Item = <Vec<Param<'a>> as IntoIterator>::Item;
456
457 fn into_iter(self) -> Self::IntoIter {
458 self.params.into_iter()
459 }
460}
461
/// A single parsed plugin parameter.
#[derive(Debug, PartialEq, Eq)]
pub enum Param<'a> {
    /// A bare parameter: `param`.
    Parameter {
        /// The parameter name.
        param: &'a str,
    },
    /// A parameter with a value: `param=value`.
    Value {
        /// The parameter name.
        param: &'a str,
        /// The value, borrowed from the input.
        value: &'a str,
    },
    /// A parameter with a key and a value: `param=key=value`.
    KeyValue {
        /// The parameter name.
        param: &'a str,
        /// The key, borrowed from the input.
        key: &'a str,
        /// The value, held as an owned string.
        value: String,
    },
}

impl<'a> Param<'a> {
    /// Consumes the parameter and returns its value, if it has one.
    ///
    /// A `Value` yields a borrowed string, a `KeyValue` yields its owned
    /// string, and a bare `Parameter` yields `None`.
    pub fn value(self) -> Option<Cow<'a, str>> {
        let value = match self {
            Self::Parameter { .. } => return None,
            Self::Value { value, .. } => Cow::Borrowed(value),
            Self::KeyValue { value, .. } => Cow::Owned(value),
        };
        Some(value)
    }
}
487
488impl From<Param<'_>> for InvalidParameter {
489 fn from(param: Param<'_>) -> Self {
490 let message = match param {
491 Param::Parameter { param } => param.to_owned(),
492 Param::Value { param, value } => format!("{param}={value}"),
493 Param::KeyValue { param, key, value } => {
494 let value = value.replace('\\', r"\\").replace(',', r"\,");
495 format!("{param}={key}={value}")
496 }
497 };
498 InvalidParameter(message)
499 }
500}
501
502impl<'a> Params<'a> {
503 pub fn from_protoc_plugin_opts(s: &'a str) -> std::result::Result<Self, InvalidParameter> {
504 let params = PARAMETER
505 .captures_iter(s)
506 .map(|capture| {
507 let param = capture
508 .get(1)
509 .expect("any captured group will at least have the param name")
510 .as_str()
511 .trim();
512
513 let key = capture.get(2).map(|m| m.as_str());
514 let value = capture.get(3).map(|m| m.as_str());
515
516 match (key, value) {
517 (None, None) => Ok(Param::Parameter { param }),
518 (Some(value), None) => Ok(Param::Value { param, value }),
519 (Some(key), Some(value)) => Ok(Param::KeyValue {
520 param,
521 key,
522 value: value.replace(r"\,", ",").replace(r"\\", r"\"),
523 }),
524 _ => Err(InvalidParameter(
525 capture.get(0).unwrap().as_str().to_string(),
526 )),
527 }
528 })
529 .collect::<std::result::Result<_, _>>()?;
530 Ok(Self { params })
531 }
532}
533
534impl str::FromStr for Parameters {
535 type Err = InvalidParameter;
536 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
537 let mut ret_val = Self::default();
538 for param in Params::from_protoc_plugin_opts(s)? {
539 if let Err(param) = ret_val.prost.try_handle_parameter(param) {
540 match param {
541 Param::Parameter {
542 param: "file_descriptor_set",
543 }
544 | Param::Value {
545 param: "file_descriptor_set",
546 value: "true",
547 } => ret_val.file_descriptor_set = true,
548 Param::Value {
549 param: "file_descriptor_set",
550 value: "false",
551 } => (),
552 _ => return Err(InvalidParameter::from(param)),
553 }
554 }
555 }
556
557 Ok(ret_val)
558 }
559}
560
/// Error reporting a plugin parameter that could not be accepted.
///
/// Holds the text describing the offending parameter; `Display` prints it
/// after the prefix `invalid parameter: `.
#[derive(Debug)]
pub struct InvalidParameter(String);

impl InvalidParameter {
    /// Creates the error from the given message.
    pub fn new(message: String) -> Self {
        InvalidParameter(message)
    }
}

impl fmt::Display for InvalidParameter {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Self(message) = self;
        write!(f, "invalid parameter: {message}")
    }
}

impl std::error::Error for InvalidParameter {}
579
580#[derive(Clone, PartialEq, ::prost::Message)]
589struct RawProtos {
590 #[prost(bytes = "vec", repeated, tag = "15")]
591 proto_file: Vec<Vec<u8>>,
592}
593
#[cfg(test)]
mod tests {
    use super::*;

    /// An option string mixing bare flags, `param=value` pairs and
    /// `param=key=value` triples (one with an escaped comma and extra `=`
    /// signs in the value) parses into the expected parameter sequence.
    #[test]
    fn compiler_option_string_with_three_plus_equals_parses_correctly() {
        const INPUT: &str = r#"flat_output_dir,enable_type_names,compile_well_known_types,disable_comments=.,skip_debug=.,extern_path=.google.protobuf=::pbjson_types,type_attribute=.=#[cfg(all(feature = "test"\, feature = "orange"))]"#;

        let expected = vec![
            Param::Parameter {
                param: "flat_output_dir",
            },
            Param::Parameter {
                param: "enable_type_names",
            },
            Param::Parameter {
                param: "compile_well_known_types",
            },
            Param::Value {
                param: "disable_comments",
                value: ".",
            },
            Param::Value {
                param: "skip_debug",
                value: ".",
            },
            Param::KeyValue {
                param: "extern_path",
                key: ".google.protobuf",
                value: String::from("::pbjson_types"),
            },
            Param::KeyValue {
                param: "type_attribute",
                key: ".",
                value: String::from(r#"#[cfg(all(feature = "test", feature = "orange"))]"#),
            },
        ];

        let parsed: Vec<Param> = Params::from_protoc_plugin_opts(INPUT)
            .unwrap()
            .into_iter()
            .collect();

        assert_eq!(parsed, expected);
    }
}