#![doc = include_str!("../README.md")]
use self::generator::{CoreProstGenerator, FileDescriptorSetGenerator};
use once_cell::sync::Lazy;
use prost::Message;
use prost_build::Module;
use prost_types::compiler::code_generator_response::File;
use prost_types::compiler::CodeGeneratorRequest;
use prost_types::FileDescriptorProto;
use std::collections::{BTreeMap, HashSet};
use std::{fmt, str};
mod generator;
pub use self::generator::{Error, Generator, GeneratorResultExt, Result};
/// Execute the core _Prost!_ generator from a raw [`CodeGeneratorRequest`].
///
/// `raw_request` is the encoded request as received from `protoc`. It is decoded into a
/// [`CodeGeneratorRequest`] and its plugin parameter string is parsed. The same raw bytes
/// are also handed to [`ModuleRequestSet::new`], which re-reads them to keep each file
/// descriptor's original encoding.
///
/// The encoded file descriptor set is only emitted when the `file_descriptor_set` plugin
/// parameter is enabled.
///
/// # Errors
///
/// Fails if the request cannot be decoded, if the parameter string contains an
/// unrecognized option, or if any generator fails.
pub fn execute(raw_request: &[u8]) -> generator::Result {
    let request = CodeGeneratorRequest::decode(raw_request)?;
    let params = request.parameter().parse::<Parameters>()?;

    let module_request_set = ModuleRequestSet::new(
        request.file_to_generate,
        request.proto_file,
        raw_request,
        params.prost.default_package_filename(),
    )?;

    let mut core = CoreProstGenerator::new(params.prost.to_prost_config());

    // `file_descriptor_set` is an opt-in parameter (defaults to `false`); honour it
    // instead of unconditionally chaining the descriptor-set generator.
    let files = if params.file_descriptor_set {
        core.chain(FileDescriptorSetGenerator)
            .generate(&module_request_set)?
    } else {
        core.generate(&module_request_set)?
    };

    Ok(files)
}
/// The set of per-module code generation requests derived from a
/// [`CodeGeneratorRequest`], keyed by the Rust [`Module`] of each protobuf package.
pub struct ModuleRequestSet {
    requests: BTreeMap<Module, ModuleRequest>,
}

impl ModuleRequestSet {
    /// Construct a new module request set from the names of the proto files to
    /// generate, the decoded file descriptors, and the raw encoded request.
    ///
    /// The raw request is decoded a second time to recover each file descriptor's
    /// original encoded bytes, which are stored alongside the decoded form.
    /// `default_package_filename` names the output file for protos without a `package`
    /// declaration; it defaults to `"_"`.
    ///
    /// # Errors
    ///
    /// Returns a [`prost::DecodeError`] if `raw_request` cannot be decoded.
    pub fn new<I>(
        input_protos: I,
        proto_file: Vec<FileDescriptorProto>,
        raw_request: &[u8],
        default_package_filename: Option<&str>,
    ) -> std::result::Result<Self, prost::DecodeError>
    where
        I: IntoIterator<Item = String>,
    {
        let raw_protos = RawProtos::decode(raw_request)?;

        Ok(Self::new_decoded(
            input_protos,
            proto_file,
            raw_protos,
            default_package_filename.unwrap_or("_"),
        ))
    }

    /// Groups file descriptors (paired positionally with their raw encodings) by the
    /// Rust module derived from their protobuf package.
    ///
    /// A module gets an output filename if *any* of its files is listed in
    /// `input_protos`. Modules made up only of unrequested files are still kept in the
    /// set, but without an output filename.
    fn new_decoded<I>(
        input_protos: I,
        proto_file: Vec<FileDescriptorProto>,
        raw_protos: RawProtos,
        default_package_filename: &str,
    ) -> Self
    where
        I: IntoIterator<Item = String>,
    {
        let input_protos: HashSet<_> = input_protos.into_iter().collect();

        let requests = proto_file.into_iter().zip(raw_protos.proto_file).fold(
            BTreeMap::new(),
            |mut acc, (proto, raw)| {
                let module = Module::from_protobuf_package_name(proto.package());
                let entry = acc
                    .entry(module)
                    .or_insert_with(|| ModuleRequest::new(proto.package().to_owned()));

                // Check every file of the package, not only the first one seen:
                // `proto_file` lists dependencies before the files that import them, so
                // the first file of a package may be an unrequested dependency while a
                // later file in the same package was requested.
                if entry.output_filename().is_none() && input_protos.contains(proto.name()) {
                    let filename = match proto.package() {
                        "" => default_package_filename.to_owned(),
                        package => format!("{package}.rs"),
                    };
                    entry.with_output_filename(filename);
                }

                entry.push_file_descriptor_proto(proto, raw);
                acc
            },
        );

        Self { requests }
    }

    /// Iterates over every module request in ascending [`Module`] order.
    pub fn requests(&self) -> impl Iterator<Item = (&Module, &ModuleRequest)> {
        self.requests.iter()
    }

    /// Looks up the request for `module`, if any file declared that package.
    pub fn for_module(&self, module: &Module) -> Option<&ModuleRequest> {
        self.requests.get(module)
    }
}
/// A request to generate code for a single Rust module (one protobuf package).
pub struct ModuleRequest {
    /// Protobuf package name shared by every file in this request.
    proto_package_name: String,
    /// File to emit this module into; `None` when no file of the package was requested.
    output_filename: Option<String>,
    /// Decoded descriptors, index-aligned with `raw`.
    files: Vec<FileDescriptorProto>,
    /// Original encoded bytes of each entry in `files`.
    raw: Vec<Vec<u8>>,
}

impl ModuleRequest {
    /// Creates an empty request for `proto_package_name` with no output file yet.
    fn new(proto_package_name: String) -> Self {
        Self {
            proto_package_name,
            output_filename: None,
            files: vec![],
            raw: vec![],
        }
    }

    /// Sets (or overwrites) the file this module is emitted into.
    fn with_output_filename(&mut self, filename: String) {
        self.output_filename = Some(filename);
    }

    /// Records a file descriptor together with its original encoding, keeping the
    /// decoded and raw lists index-aligned.
    fn push_file_descriptor_proto(&mut self, encoded: FileDescriptorProto, raw: Vec<u8>) {
        let Self {
            files, raw: raws, ..
        } = self;
        files.push(encoded);
        raws.push(raw);
    }

    /// The protobuf package name of this module.
    pub fn proto_package_name(&self) -> &str {
        self.proto_package_name.as_str()
    }

    /// The output file name, if this module should be emitted.
    pub fn output_filename(&self) -> Option<&str> {
        self.output_filename.as_deref()
    }

    /// The decoded file descriptors belonging to this module, in insertion order.
    pub fn files(&self) -> impl Iterator<Item = &FileDescriptorProto> {
        self.files.iter()
    }

    /// The encoded file descriptors belonging to this module, in insertion order.
    pub fn raw_files(&self) -> impl Iterator<Item = &[u8]> {
        self.raw.iter().map(Vec::as_slice)
    }

    /// Builds a complete output [`File`] whose content is produced by `f`.
    ///
    /// Returns `None` (without calling `f`) when this module has no output filename.
    pub(crate) fn write_to_file<F: FnOnce(&mut String)>(&self, f: F) -> Option<File> {
        let name = self.output_filename.clone()?;
        let mut content = String::with_capacity(8_192);
        f(&mut content);
        Some(File {
            name: Some(name),
            content: Some(content),
            ..Default::default()
        })
    }

    /// Builds an output [`File`] whose content, produced by `f`, targets the `module`
    /// insertion point of this module's output file.
    ///
    /// Returns `None` (without calling `f`) when this module has no output filename.
    pub fn append_to_file<F: FnOnce(&mut String)>(&self, f: F) -> Option<File> {
        let name = self.output_filename.clone()?;
        let mut content = String::new();
        f(&mut content);
        Some(File {
            name: Some(name),
            content: Some(content),
            insertion_point: Some("module".to_owned()),
            ..Default::default()
        })
    }
}
/// Options parsed from the plugin parameter string passed by `protoc`.
#[derive(Debug, Default)]
struct Parameters {
    /// Options forwarded to `prost-build`.
    prost: ProstParameters,
    /// Set by `file_descriptor_set` (bare or `=true`).
    file_descriptor_set: bool,
}

/// Plugin options that map onto `prost_build::Config` settings.
#[derive(Debug, Default)]
struct ProstParameters {
    /// Paths given via `btree_map=<path>`.
    btree_map: Vec<String>,
    /// Paths given via `bytes=<path>`.
    bytes: Vec<String>,
    /// Paths given via `disable_comments=<path>`.
    disable_comments: Vec<String>,
    /// Value of `default_package_filename=<name>`; a bare option resets it to `None`.
    default_package_filename: Option<String>,
    /// `(proto path, rust path)` pairs from `extern_path=<proto>=<rust>`.
    extern_path: Vec<(String, String)>,
    /// `(proto path, attribute)` pairs from `type_attribute=<proto>=<attr>`, with `\,` unescaped.
    type_attribute: Vec<(String, String)>,
    /// `(proto path, attribute)` pairs from `field_attribute=<proto>=<attr>`, with `\,` unescaped.
    field_attribute: Vec<(String, String)>,
    /// Set by `compile_well_known_types` (bare or `=true`).
    compile_well_known_types: bool,
    /// Set by `retain_enum_prefix` (bare or `=true`).
    retain_enum_prefix: bool,
}

impl ProstParameters {
    /// Builds a fresh `prost_build::Config` with every collected option applied.
    /// Repeated options are applied in the order they were given.
    fn to_prost_config(&self) -> prost_build::Config {
        let mut config = prost_build::Config::new();

        config
            .btree_map(self.btree_map.iter())
            .bytes(self.bytes.iter())
            .disable_comments(self.disable_comments.iter());

        if let Some(name) = self.default_package_filename.as_deref() {
            config.default_package_filename(name);
        }

        for (proto, rust) in self.extern_path.iter() {
            config.extern_path(proto, rust);
        }
        for (proto, attr) in self.type_attribute.iter() {
            config.type_attribute(proto, attr);
        }
        for (proto, attr) in self.field_attribute.iter() {
            config.field_attribute(proto, attr);
        }

        if self.compile_well_known_types {
            config.compile_well_known_types();
        }
        if self.retain_enum_prefix {
            config.retain_enum_prefix();
        }

        config
    }

    /// The configured default package filename, if one was given.
    fn default_package_filename(&self) -> Option<&str> {
        self.default_package_filename.as_deref()
    }

    /// Applies one parsed plugin parameter (`param`, optional `key`, optional `value`) if
    /// it is a prost-build option in the expected form.
    ///
    /// Returns `Err(())`, leaving `self` untouched, when `param` is not a prost-build
    /// option or its arguments do not fit. Examples are `btree_map` without a path, or a
    /// boolean option set to anything other than `true`/`false`. Boolean options can only
    /// be switched on; `=false` is accepted but leaves the current setting as it is.
    fn try_handle_parameter(
        &mut self,
        param: &str,
        key: Option<&str>,
        value: Option<&str>,
    ) -> std::result::Result<(), ()> {
        // A boolean switch may be given bare, as `=true`, or as `=false`.
        fn parse_switch(flag: Option<&str>) -> std::result::Result<bool, ()> {
            match flag {
                None | Some("true") => Ok(true),
                Some("false") => Ok(false),
                Some(_) => Err(()),
            }
        }

        // The parameter regex keeps `\,` inside attribute values; turn it back into `,`.
        let unescape = |raw: &str| raw.replace(r"\,", ",");

        match (param, key, value) {
            ("btree_map", Some(path), None) => self.btree_map.push(path.to_owned()),
            ("bytes", Some(path), None) => self.bytes.push(path.to_owned()),
            ("disable_comments", Some(path), None) => {
                self.disable_comments.push(path.to_owned())
            }
            ("default_package_filename", name, None) => {
                self.default_package_filename = name.map(str::to_owned)
            }
            ("compile_well_known_types", flag, None) => {
                if parse_switch(flag)? {
                    self.compile_well_known_types = true;
                }
            }
            ("retain_enum_prefix", flag, None) => {
                if parse_switch(flag)? {
                    self.retain_enum_prefix = true;
                }
            }
            ("extern_path", Some(proto_path), Some(rust_path)) => self
                .extern_path
                .push((proto_path.to_owned(), rust_path.to_owned())),
            ("type_attribute", Some(proto_path), Some(attribute)) => self
                .type_attribute
                .push((proto_path.to_owned(), unescape(attribute))),
            ("field_attribute", Some(proto_path), Some(attribute)) => self
                .field_attribute
                .push((proto_path.to_owned(), unescape(attribute))),
            _ => return Err(()),
        }

        Ok(())
    }
}
/// Matches one comma-separated plugin parameter of the form `param`, `param=key`, or
/// `param=key=value`.
///
/// `param` and `key` may not contain `,` or `=`. `value` may not contain `=` or a bare
/// `,`, but may contain commas escaped as `\,` (and other backslashes).
static PARAMETER: Lazy<regex::Regex> = Lazy::new(|| {
    const PATTERN: &str =
        r"(?:(?P<param>[^,=]+)(?:=(?P<key>[^,=]+)(?:=(?P<value>(?:[^,=\\]|\\,|\\)+))?)?)";
    regex::Regex::new(PATTERN).unwrap()
});
impl str::FromStr for Parameters {
    type Err = InvalidParameter;

    /// Parses the plugin parameter string supplied by `protoc`.
    ///
    /// Each entry matched by `PARAMETER` is first offered to the prost-build options
    /// (`ProstParameters::try_handle_parameter`). Entries that prost-build rejects are
    /// checked against this plugin's own `file_descriptor_set` option. Parsing stops at
    /// the first entry nothing accepts, reported as an [`InvalidParameter`] holding that
    /// entry's full matched text.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let mut parsed = Self::default();

        for caps in PARAMETER.captures_iter(s) {
            // Whitespace is trimmed from the option name only, not from key or value.
            let param = caps
                .name("param")
                .map(|m| m.as_str().trim())
                .expect("any captured group will at least have the param name");
            let key = caps.name("key").map(|m| m.as_str());
            let value = caps.name("value").map(|m| m.as_str());

            if parsed.prost.try_handle_parameter(param, key, value).is_ok() {
                continue;
            }

            parsed.file_descriptor_set = match (param, key, value) {
                ("file_descriptor_set", Some("true") | None, None) => true,
                ("file_descriptor_set", Some("false"), None) => parsed.file_descriptor_set,
                _ => {
                    let entry = caps.get(0).unwrap().as_str();
                    return Err(InvalidParameter(entry.to_string()));
                }
            };
        }

        Ok(parsed)
    }
}
/// Error for a plugin parameter that is not recognized, or whose arguments do not
/// match the form its option expects.
///
/// Wraps the full text of the offending parameter as it was matched.
#[derive(Debug)]
struct InvalidParameter(String);

impl fmt::Display for InvalidParameter {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let InvalidParameter(raw) = self;
        write!(f, "invalid parameter: {raw}")
    }
}

impl std::error::Error for InvalidParameter {}
/// Minimal view of an encoded `CodeGeneratorRequest` that keeps each `proto_file`
/// entry as raw, still-encoded bytes instead of a decoded `FileDescriptorProto`.
///
/// Tag 15 is the `proto_file` field of `google.protobuf.compiler.CodeGeneratorRequest`.
/// The request's other fields are not declared here, so they are skipped when decoding.
#[derive(Clone, PartialEq, ::prost::Message)]
struct RawProtos {
    /// Encoded `FileDescriptorProto` messages, in the order they appear in the request.
    #[prost(bytes = "vec", repeated, tag = "15")]
    proto_file: Vec<Vec<u8>>,
}