protoc_gen_prost/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::{
4    borrow::Cow,
5    collections::{BTreeMap, HashSet},
6    fmt, str,
7};
8
9use once_cell::sync::Lazy;
10use prost::Message;
11use prost_build::Module;
12use prost_types::{
13    compiler::{code_generator_response::File, CodeGeneratorRequest},
14    FileDescriptorProto,
15};
16
17use self::generator::{CoreProstGenerator, FileDescriptorSetGenerator};
18
19mod generator;
20
21pub use self::generator::{Error, Generator, GeneratorResultExt, Result};
22
23/// Execute the core _Prost!_ generator from an encoded [`CodeGeneratorRequest`]
24pub fn execute(raw_request: &[u8]) -> generator::Result {
25    let request = CodeGeneratorRequest::decode(raw_request)?;
26    let params = request.parameter().parse::<Parameters>()?;
27
28    let module_request_set = ModuleRequestSet::new(
29        request.file_to_generate,
30        request.proto_file,
31        raw_request,
32        params.prost.default_package_filename(),
33        params.prost.flat_output_dir,
34    )?;
35    let file_descriptor_set_generator = params
36        .file_descriptor_set
37        .then_some(FileDescriptorSetGenerator);
38
39    let files = CoreProstGenerator::new(params.prost.to_prost_config())
40        .chain(file_descriptor_set_generator)
41        .generate(&module_request_set)?;
42
43    Ok(files)
44}
45
46/// A set of requests to generate code for a series of modules
47pub struct ModuleRequestSet {
48    requests: BTreeMap<Module, ModuleRequest>,
49}
50
51impl ModuleRequestSet {
52    /// Construct a new module request set from an encoded [`CodeGeneratorRequest`]
53    ///
54    /// [`CodeGeneratorRequest`]: prost_types::compiler::CodeGeneratorRequest
55    pub fn new<I>(
56        input_protos: I,
57        proto_file: Vec<FileDescriptorProto>,
58        raw_request: &[u8],
59        default_package_filename: Option<&str>,
60        flat_output_dir: bool,
61    ) -> std::result::Result<Self, prost::DecodeError>
62    where
63        I: IntoIterator<Item = String>,
64    {
65        let raw_protos = RawProtos::decode(raw_request)?;
66
67        Ok(Self::new_decoded(
68            input_protos,
69            proto_file,
70            raw_protos,
71            default_package_filename.unwrap_or("_"),
72            flat_output_dir,
73        ))
74    }
75
76    fn new_decoded<I>(
77        input_protos: I,
78        proto_file: Vec<FileDescriptorProto>,
79        raw_protos: RawProtos,
80        default_package_filename: &str,
81        flat_output_dir: bool,
82    ) -> Self
83    where
84        I: IntoIterator<Item = String>,
85    {
86        let input_protos: HashSet<_> = input_protos.into_iter().collect();
87
88        let requests = proto_file.into_iter().zip(raw_protos.proto_file).fold(
89            BTreeMap::new(),
90            |mut acc, (proto, raw)| {
91                let module = Module::from_protobuf_package_name(proto.package());
92                let proto_filename = proto.name();
93                let entry = acc.entry(module.clone()).or_insert_with(|| {
94                    ModuleRequest::new(proto.package().to_owned(), module, flat_output_dir)
95                });
96
97                if entry.output_filename().is_none() && input_protos.contains(proto_filename) {
98                    let filename = match proto.package() {
99                        "" => default_package_filename.to_owned(),
100                        package => format!("{package}.rs"),
101                    };
102                    entry.with_output_filename(filename);
103                }
104
105                entry.push_file_descriptor_proto(proto, raw);
106                acc
107            },
108        );
109
110        Self { requests }
111    }
112
113    /// An ordered iterator of all requests
114    pub fn requests(&self) -> impl Iterator<Item = (&Module, &ModuleRequest)> {
115        self.requests.iter()
116    }
117
118    /// Retrieve the request for the given module
119    pub fn for_module(&self, module: &Module) -> Option<&ModuleRequest> {
120        self.requests.get(module)
121    }
122
123    pub fn modules(&self) -> impl Iterator<Item = &Module> {
124        self.requests.keys()
125    }
126}
127
128/// A code generation request for a specific module
129pub struct ModuleRequest {
130    proto_package_name: String,
131    module: Module,
132    flat_output_dir: bool,
133    output_filename: Option<String>,
134    files: Vec<FileDescriptorProto>,
135    raw: Vec<Vec<u8>>,
136}
137
138impl ModuleRequest {
139    fn new(proto_package_name: String, module: Module, flat_output_dir: bool) -> Self {
140        Self {
141            proto_package_name,
142            module,
143            flat_output_dir,
144            output_filename: None,
145            files: Vec::new(),
146            raw: Vec::new(),
147        }
148    }
149
150    fn with_output_filename(&mut self, filename: String) {
151        self.output_filename = Some(filename);
152    }
153
154    fn push_file_descriptor_proto(&mut self, encoded: FileDescriptorProto, raw: Vec<u8>) {
155        self.files.push(encoded);
156        self.raw.push(raw);
157    }
158
159    /// The protobuf package name for this module
160    pub fn proto_package_name(&self) -> &str {
161        &self.proto_package_name
162    }
163
164    /// The output filename for this module
165    pub fn output_filename(&self) -> Option<&str> {
166        self.output_filename.as_deref()
167    }
168
169    pub fn output_dir(&self) -> String {
170        if self.flat_output_dir {
171            return String::new();
172        }
173        let mut output_dir = self.module.parts().collect::<Vec<_>>().join("/");
174        if !output_dir.is_empty() {
175            output_dir.push('/');
176        }
177        output_dir
178    }
179
180    pub fn output_filepath(&self) -> Option<String> {
181        self.output_filename().map(|f| {
182            let dir = self.output_dir();
183            format!("{dir}{f}")
184        })
185    }
186
187    /// An iterator of the file descriptors
188    pub fn files(&self) -> impl Iterator<Item = &FileDescriptorProto> {
189        self.files.iter()
190    }
191
192    /// An iterator of the encoded [`FileDescriptorProto`]s from [`files()`][Self::files()]
193    pub fn raw_files(&self) -> impl Iterator<Item = &[u8]> {
194        self.raw.iter().map(|b| b.as_slice())
195    }
196
197    /// Creates a code generation file from the output
198    pub(crate) fn write_to_file<F: FnOnce(&mut String)>(&self, f: F) -> Option<File> {
199        self.output_filepath().map(|name| {
200            let mut content = String::with_capacity(8_192);
201            f(&mut content);
202
203            File {
204                name: Some(name),
205                content: Some(content),
206                ..Default::default()
207            }
208        })
209    }
210
211    /// Appends generated code to the end of the main file for this module
212    ///
213    /// This is generally a good way to add includes referencing the output
214    /// of other plugins or to directly append to the main file.
215    pub fn append_to_file<F: FnOnce(&mut String)>(&self, f: F) -> Option<File> {
216        self.output_filepath().map(|name| {
217            let mut content = String::new();
218            f(&mut content);
219
220            File {
221                name: Some(name),
222                content: Some(content),
223                insertion_point: Some("module".to_owned()),
224                ..Default::default()
225            }
226        })
227    }
228}
229
230/// Parameters use to configure [`Generator`]s built into `protoc-gen-prost`
231///
232/// [`Generator`]: crate::Generator
233#[derive(Debug, Default)]
234struct Parameters {
235    /// Prost parameters, used to generate [`prost_build::Config`]
236    prost: ProstParameters,
237
238    /// Whether a file descriptor set has been requested in each module
239    file_descriptor_set: bool,
240}
241
/// Parameters used to configure the underlying Prost generator
///
/// Most fields mirror the [`prost_build::Config`] option of the same name.
/// List-valued options accumulate when a parameter is repeated.
#[derive(Debug, Default)]
struct ProstParameters {
    /// Paths passed to `btree_map`
    btree_map: Vec<String>,
    /// Paths passed to `bytes`
    bytes: Vec<String>,
    /// Paths passed to `boxed`, one at a time
    boxed: Vec<String>,
    /// Paths passed to `disable_comments`
    disable_comments: Vec<String>,
    /// Paths passed to `skip_debug`
    skip_debug: Vec<String>,
    /// Output file name for packages without a name
    default_package_filename: Option<String>,
    /// `(proto path, Rust path)` pairs
    extern_path: Vec<(String, String)>,
    /// `(proto path, attribute)` pairs
    type_attribute: Vec<(String, String)>,
    /// `(proto path, attribute)` pairs
    field_attribute: Vec<(String, String)>,
    /// `(proto path, attribute)` pairs
    enum_attribute: Vec<(String, String)>,
    /// `(proto path, attribute)` pairs
    message_attribute: Vec<(String, String)>,
    compile_well_known_types: bool,
    retain_enum_prefix: bool,
    enable_type_names: bool,
    /// Controls the output file layout; not forwarded to the Prost config
    flat_output_dir: bool,
}
261
262impl ProstParameters {
263    /// Builds a [`prost_build::Config`] from the parameters
264    fn to_prost_config(&self) -> prost_build::Config {
265        let mut config = prost_build::Config::new();
266        config.btree_map(self.btree_map.iter());
267        config.bytes(self.bytes.iter());
268        for b in self.boxed.iter() {
269            config.boxed(b);
270        }
271        config.disable_comments(self.disable_comments.iter());
272        config.skip_debug(self.skip_debug.iter());
273
274        if let Some(filename) = self.default_package_filename.as_deref() {
275            config.default_package_filename(filename);
276        }
277
278        for (proto_path, rust_path) in &self.extern_path {
279            config.extern_path(proto_path, rust_path);
280        }
281        for (proto_path, attribute) in &self.type_attribute {
282            config.type_attribute(proto_path, attribute);
283        }
284        for (proto_path, attribute) in &self.field_attribute {
285            config.field_attribute(proto_path, attribute);
286        }
287        for (proto_path, attribute) in &self.enum_attribute {
288            config.enum_attribute(proto_path, attribute);
289        }
290        for (proto_path, attribute) in &self.message_attribute {
291            config.message_attribute(proto_path, attribute);
292        }
293
294        if self.compile_well_known_types {
295            config.compile_well_known_types();
296        }
297        if self.retain_enum_prefix {
298            config.retain_enum_prefix();
299        }
300        if self.enable_type_names {
301            config.enable_type_names();
302        }
303
304        config
305    }
306
307    fn default_package_filename(&self) -> Option<&str> {
308        self.default_package_filename.as_deref()
309    }
310
311    fn try_handle_parameter<'a>(&mut self, param: Param<'a>) -> std::result::Result<(), Param<'a>> {
312        match param {
313            Param::Value {
314                param: "btree_map",
315                value,
316            } => self.btree_map.push(value.to_string()),
317            Param::Value {
318                param: "bytes",
319                value,
320            } => self.bytes.push(value.to_string()),
321            Param::Value {
322                param: "boxed",
323                value,
324            } => self.boxed.push(value.to_string()),
325            Param::Parameter {
326                param: "default_package_filename",
327            }
328            | Param::Value {
329                param: "default_package_filename",
330                ..
331            } => self.default_package_filename = param.value().map(|s| s.into_owned()),
332            Param::Parameter {
333                param: "compile_well_known_types",
334            }
335            | Param::Value {
336                param: "compile_well_known_types",
337                value: "true",
338            } => self.compile_well_known_types = true,
339            Param::Value {
340                param: "compile_well_known_types",
341                value: "false",
342            } => (),
343            Param::Value {
344                param: "disable_comments",
345                value,
346            } => self.disable_comments.push(value.to_string()),
347            Param::Value {
348                param: "skip_debug",
349                value,
350            } => self.skip_debug.push(value.to_string()),
351            Param::Parameter {
352                param: "retain_enum_prefix",
353            }
354            | Param::Value {
355                param: "retain_enum_prefix",
356                value: "true",
357            } => self.retain_enum_prefix = true,
358            Param::Value {
359                param: "retain_enum_prefix",
360                value: "false",
361            } => (),
362            Param::KeyValue {
363                param: "extern_path",
364                key: prefix,
365                value: module,
366            } => self.extern_path.push((prefix.to_string(), module)),
367            Param::KeyValue {
368                param: "type_attribute",
369                key: prefix,
370                value: module,
371            } => self.type_attribute.push((
372                prefix.to_string(),
373                module.replace(r"\,", ",").replace(r"\\", r"\"),
374            )),
375            Param::KeyValue {
376                param: "field_attribute",
377                key: prefix,
378                value: module,
379            } => self.field_attribute.push((
380                prefix.to_string(),
381                module.replace(r"\,", ",").replace(r"\\", r"\"),
382            )),
383            Param::KeyValue {
384                param: "enum_attribute",
385                key: prefix,
386                value: module,
387            } => self.enum_attribute.push((
388                prefix.to_string(),
389                module.replace(r"\,", ",").replace(r"\\", r"\"),
390            )),
391            Param::KeyValue {
392                param: "message_attribute",
393                key: prefix,
394                value: module,
395            } => self.message_attribute.push((
396                prefix.to_string(),
397                module.replace(r"\,", ",").replace(r"\\", r"\"),
398            )),
399            Param::Parameter {
400                param: "enable_type_names",
401            }
402            | Param::Value {
403                param: "enable_type_names",
404                value: "true",
405            } => self.enable_type_names = true,
406            Param::Value {
407                param: "enable_type_names",
408                value: "false",
409            } => (),
410            Param::Parameter {
411                param: "flat_output_dir",
412            }
413            | Param::Value {
414                param: "flat_output_dir",
415                value: "true",
416            } => self.flat_output_dir = true,
417            Param::Value {
418                param: "flat_output_dir",
419                value: "false",
420            } => (),
421            _ => return Err(param),
422        }
423
424        Ok(())
425    }
426}
427
428/// Standard parameter regular expression
429///
430/// Supports the following forms:
431///
432/// ```text
433/// parameter
434/// parameter=key
435/// parameter=key=value
436/// ```
437///
438/// * `parameter` is terminated on the first `=` or `,`
439/// * If `parameter` is terminated with `=`, then `key` follows, terminated by the first `=` or `,`.
440/// * If `key` is terminated with `=`, then `value` follows. It is terminated only by `,`. However,
441///   if that `,` is prefixed by `\` but not `\\`, then it will not terminate.
442static PARAMETER: Lazy<regex::Regex> = Lazy::new(|| {
443    regex::Regex::new(
444        r"(?:(?P<param>[^,=]+)(?:=(?P<key>[^,=]+)(?:=(?P<value>(?:[^,\\]|\\,|\\\\)+))?)?)",
445    )
446    .unwrap()
447});
448
449pub struct Params<'a> {
450    params: Vec<Param<'a>>,
451}
452
453impl<'a> IntoIterator for Params<'a> {
454    type IntoIter = <Vec<Param<'a>> as IntoIterator>::IntoIter;
455    type Item = <Vec<Param<'a>> as IntoIterator>::Item;
456
457    fn into_iter(self) -> Self::IntoIter {
458        self.params.into_iter()
459    }
460}
461
/// A single entry of the plugin parameter string
#[derive(Debug, PartialEq, Eq)]
pub enum Param<'a> {
    /// A bare flag: `param`
    Parameter { param: &'a str },
    /// A parameter with one value: `param=value`
    Value { param: &'a str, value: &'a str },
    /// A parameter with a key and a value: `param=key=value`
    ///
    /// `value` is owned because its `\,` and `\\` escapes have been resolved.
    KeyValue {
        param: &'a str,
        key: &'a str,
        value: String,
    },
}
477
478impl<'a> Param<'a> {
479    pub fn value(self) -> Option<Cow<'a, str>> {
480        match self {
481            Self::Parameter { .. } => None,
482            Self::Value { value, .. } => Some(Cow::Borrowed(value)),
483            Self::KeyValue { value, .. } => Some(Cow::Owned(value)),
484        }
485    }
486}
487
488impl From<Param<'_>> for InvalidParameter {
489    fn from(param: Param<'_>) -> Self {
490        let message = match param {
491            Param::Parameter { param } => param.to_owned(),
492            Param::Value { param, value } => format!("{param}={value}"),
493            Param::KeyValue { param, key, value } => {
494                let value = value.replace('\\', r"\\").replace(',', r"\,");
495                format!("{param}={key}={value}")
496            }
497        };
498        InvalidParameter(message)
499    }
500}
501
502impl<'a> Params<'a> {
503    pub fn from_protoc_plugin_opts(s: &'a str) -> std::result::Result<Self, InvalidParameter> {
504        let params = PARAMETER
505            .captures_iter(s)
506            .map(|capture| {
507                let param = capture
508                    .get(1)
509                    .expect("any captured group will at least have the param name")
510                    .as_str()
511                    .trim();
512
513                let key = capture.get(2).map(|m| m.as_str());
514                let value = capture.get(3).map(|m| m.as_str());
515
516                match (key, value) {
517                    (None, None) => Ok(Param::Parameter { param }),
518                    (Some(value), None) => Ok(Param::Value { param, value }),
519                    (Some(key), Some(value)) => Ok(Param::KeyValue {
520                        param,
521                        key,
522                        value: value.replace(r"\,", ",").replace(r"\\", r"\"),
523                    }),
524                    _ => Err(InvalidParameter(
525                        capture.get(0).unwrap().as_str().to_string(),
526                    )),
527                }
528            })
529            .collect::<std::result::Result<_, _>>()?;
530        Ok(Self { params })
531    }
532}
533
534impl str::FromStr for Parameters {
535    type Err = InvalidParameter;
536    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
537        let mut ret_val = Self::default();
538        for param in Params::from_protoc_plugin_opts(s)? {
539            if let Err(param) = ret_val.prost.try_handle_parameter(param) {
540                match param {
541                    Param::Parameter {
542                        param: "file_descriptor_set",
543                    }
544                    | Param::Value {
545                        param: "file_descriptor_set",
546                        value: "true",
547                    } => ret_val.file_descriptor_set = true,
548                    Param::Value {
549                        param: "file_descriptor_set",
550                        value: "false",
551                    } => (),
552                    _ => return Err(InvalidParameter::from(param)),
553                }
554            }
555        }
556
557        Ok(ret_val)
558    }
559}
560
/// An invalid parameter
///
/// Holds the offending parameter text and displays as `invalid parameter: <text>`.
#[derive(Debug)]
pub struct InvalidParameter(String);

impl InvalidParameter {
    /// Creates an error carrying the given description of the bad parameter.
    pub fn new(message: String) -> Self {
        InvalidParameter(message)
    }
}

impl fmt::Display for InvalidParameter {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "invalid parameter: {}", self.0)
    }
}

impl std::error::Error for InvalidParameter {}
579
580/// A wire-compatible reader of a [`CodeGeneratorRequest`]
581///
582/// This type treats the proto files contained in the request as raw byte
583/// arrays so that we can round-trip those bytes into the generated files
584/// as an encoded [`FileDescriptorSet`].
585///
586/// [`CodeGeneratorRequest`]: prost_types::compiler::CodeGeneratorRequest
587/// [`FileDescriptorSet`]: prost_types::FileDescriptorSet
588#[derive(Clone, PartialEq, ::prost::Message)]
589struct RawProtos {
590    #[prost(bytes = "vec", repeated, tag = "15")]
591    proto_file: Vec<Vec<u8>>,
592}
593
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn compiler_option_string_with_three_plus_equals_parses_correctly() {
        const INPUT: &str = r#"flat_output_dir,enable_type_names,compile_well_known_types,disable_comments=.,skip_debug=.,extern_path=.google.protobuf=::pbjson_types,type_attribute=.=#[cfg(all(feature = "test"\, feature = "orange"))]"#;

        let expected: &[Param] = &[
            Param::Parameter {
                param: "flat_output_dir",
            },
            Param::Parameter {
                param: "enable_type_names",
            },
            Param::Parameter {
                param: "compile_well_known_types",
            },
            Param::Value {
                param: "disable_comments",
                value: ".",
            },
            Param::Value {
                param: "skip_debug",
                value: ".",
            },
            Param::KeyValue {
                param: "extern_path",
                key: ".google.protobuf",
                value: "::pbjson_types".into(),
            },
            Param::KeyValue {
                param: "type_attribute",
                key: ".",
                value: r#"#[cfg(all(feature = "test", feature = "orange"))]"#.into(),
            },
        ];

        let actual = Params::from_protoc_plugin_opts(INPUT).unwrap();
        assert_eq!(actual.params, expected);
    }

    #[test]
    fn empty_option_string_yields_no_params() {
        let actual = Params::from_protoc_plugin_opts("").unwrap();
        assert!(actual.params.is_empty());
    }

    #[test]
    fn escaped_backslashes_and_commas_are_resolved_once() {
        // `\\` -> `\` and `\,` -> `,`, including when they are adjacent.
        const INPUT: &str = r"type_attribute=.=a\\b\,c,field_attribute=.x=a\\\,b";

        let expected: &[Param] = &[
            Param::KeyValue {
                param: "type_attribute",
                key: ".",
                value: r"a\b,c".into(),
            },
            Param::KeyValue {
                param: "field_attribute",
                key: ".x",
                value: r"a\,b".into(),
            },
        ];

        let actual = Params::from_protoc_plugin_opts(INPUT).unwrap();
        assert_eq!(actual.params, expected);
    }

    #[test]
    fn invalid_parameter_message_round_trips_key_value() {
        let value = r"a\,b".to_owned();
        let message = InvalidParameter::from(Param::KeyValue {
            param: "type_attribute",
            key: ".",
            value: value.clone(),
        })
        .to_string();
        assert_eq!(message, r"invalid parameter: type_attribute=.=a\\\,b");

        // The re-escaped form parses back to the original parameter.
        let encoded = message.strip_prefix("invalid parameter: ").unwrap();
        let reparsed = Params::from_protoc_plugin_opts(encoded).unwrap();
        let expected: &[Param] = &[Param::KeyValue {
            param: "type_attribute",
            key: ".",
            value,
        }];
        assert_eq!(reparsed.params, expected);
    }

    #[test]
    fn parameters_parse_flags_and_values() {
        let params: Parameters =
            "file_descriptor_set,flat_output_dir=true,retain_enum_prefix=false,btree_map=."
                .parse()
                .unwrap();

        assert!(params.file_descriptor_set);
        assert!(params.prost.flat_output_dir);
        assert!(!params.prost.retain_enum_prefix);
        assert_eq!(params.prost.btree_map, ["."]);
    }

    #[test]
    fn unknown_or_malformed_parameters_are_rejected() {
        let err = "not_a_real_option".parse::<Parameters>().unwrap_err();
        assert_eq!(err.to_string(), "invalid parameter: not_a_real_option");

        let err = "compile_well_known_types=maybe"
            .parse::<Parameters>()
            .unwrap_err();
        assert_eq!(
            err.to_string(),
            "invalid parameter: compile_well_known_types=maybe"
        );
    }
}