#![doc(html_root_url = "https://docs.rs/prost-build/0.8.0")]
#![allow(clippy::option_as_ref_deref)]
mod ast;
mod code_generator;
mod extern_paths;
mod ident;
mod message_graph;
mod path;
use std::collections::HashMap;
use std::default;
use std::env;
use std::ffi::{OsStr, OsString};
use std::fmt;
use std::fs;
use std::io::{Error, ErrorKind, Result};
use std::path::{Path, PathBuf};
use std::process::Command;
use log::trace;
use prost::Message;
use prost_types::{FileDescriptorProto, FileDescriptorSet};
pub use crate::ast::{Comments, Method, Service};
use crate::code_generator::CodeGenerator;
use crate::extern_paths::ExternPaths;
use crate::ident::to_snake;
use crate::message_graph::MessageGraph;
use crate::path::PathMap;
/// A Rust module path: the non-empty, snake-cased segments of a protobuf package name.
type Module = Vec<String>;

/// A pluggable code generator for protobuf `service` definitions.
///
/// Install one with [`Config::service_generator`]. Only [`ServiceGenerator::generate`] is
/// required; the two finalization hooks default to doing nothing.
pub trait ServiceGenerator {
    /// Generates code for `service`, appending it to `buf`.
    fn generate(&mut self, service: Service, buf: &mut String);

    /// Hook for appending trailing code to `buf`.
    ///
    /// It is invoked by the code generator rather than by [`Config`] directly. The default
    /// implementation does nothing.
    fn finalize(&mut self, _buf: &mut String) {}

    /// Hook invoked once for every generated module containing at least one service.
    ///
    /// It receives that module's protobuf package name and its output buffer. The default
    /// implementation does nothing.
    fn finalize_package(&mut self, _package: &str, _buf: &mut String) {}
}
/// Collection type chosen for protobuf `map` fields; `Config::btree_map` opts paths into
/// `BTreeMap`.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq)]
enum MapType {
    /// `std::collections::HashMap` — the default.
    HashMap,
    /// `std::collections::BTreeMap`.
    BTreeMap,
}

impl Default for MapType {
    /// Falls back to [`MapType::HashMap`].
    fn default() -> Self {
        Self::HashMap
    }
}
/// Type chosen for protobuf `bytes` fields; `Config::bytes` opts paths into `Bytes`.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq)]
enum BytesType {
    /// `Vec<u8>` — the default.
    Vec,
    /// The `Bytes` buffer type.
    Bytes,
}

impl Default for BytesType {
    /// Falls back to [`BytesType::Vec`].
    fn default() -> Self {
        Self::Vec
    }
}
/// Options controlling how `.proto` files are turned into Rust code.
///
/// Create one with [`Config::new`], adjust it with the chained setter methods, and run it
/// with [`Config::compile_protos`].
pub struct Config {
    // ---- protoc invocation ----
    /// Where `protoc` writes the encoded `FileDescriptorSet`; `None` uses a temporary directory.
    file_descriptor_set_path: Option<PathBuf>,
    /// Extra arguments passed through to `protoc`.
    protoc_args: Vec<OsString>,
    /// Destination for generated `.rs` files; `None` falls back to the `OUT_DIR` variable.
    out_dir: Option<PathBuf>,

    // ---- code generation ----
    /// Optional generator for protobuf `service` definitions.
    service_generator: Option<Box<dyn ServiceGenerator>>,
    /// Path matchers selecting the map type.
    map_type: PathMap<MapType>,
    /// Path matchers selecting the bytes type.
    bytes_type: PathMap<BytesType>,
    /// Attributes registered per type path.
    type_attributes: PathMap<String>,
    /// Attributes registered per field path.
    field_attributes: PathMap<String>,
    /// Flag handed to `ExternPaths::new` alongside `extern_paths`.
    prost_types: bool,
    /// Whether enum variant prefixes are stripped.
    strip_enum_prefix: bool,
    /// `(protobuf path, Rust path)` pairs, in registration order.
    extern_paths: Vec<(String, String)>,
    /// Path matchers for which comments are disabled.
    disable_comments: PathMap<()>,
}
impl Config {
    /// Creates a new configuration with default options.
    ///
    /// Equivalent to `Config::default()`.
    pub fn new() -> Config {
        Config::default()
    }

    /// Selects `BTreeMap` as the map type for every path matcher in `paths`.
    ///
    /// Matchers registered by an earlier call are cleared first, so only the most recent
    /// call's `paths` are in effect.
    pub fn btree_map<I, S>(&mut self, paths: I) -> &mut Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        self.map_type.clear();
        for matcher in paths {
            self.map_type
                .insert(matcher.as_ref().to_string(), MapType::BTreeMap);
        }
        self
    }

    /// Selects `Bytes` as the bytes type for every path matcher in `paths`.
    ///
    /// Matchers registered by an earlier call are cleared first, so only the most recent
    /// call's `paths` are in effect.
    pub fn bytes<I, S>(&mut self, paths: I) -> &mut Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        self.bytes_type.clear();
        for matcher in paths {
            self.bytes_type
                .insert(matcher.as_ref().to_string(), BytesType::Bytes);
        }
        self
    }

    /// Registers `attribute` for the fields matched by `path`.
    ///
    /// The attribute text is presumably emitted verbatim by the code generator. Earlier
    /// registrations are not cleared.
    pub fn field_attribute<P, A>(&mut self, path: P, attribute: A) -> &mut Self
    where
        P: AsRef<str>,
        A: AsRef<str>,
    {
        self.field_attributes
            .insert(path.as_ref().to_string(), attribute.as_ref().to_string());
        self
    }

    /// Registers `attribute` for the types matched by `path`.
    ///
    /// The attribute text is presumably emitted verbatim by the code generator. Earlier
    /// registrations are not cleared.
    pub fn type_attribute<P, A>(&mut self, path: P, attribute: A) -> &mut Self
    where
        P: AsRef<str>,
        A: AsRef<str>,
    {
        self.type_attributes
            .insert(path.as_ref().to_string(), attribute.as_ref().to_string());
        self
    }

    /// Installs `service_generator`, replacing any previously installed generator.
    ///
    /// Its `finalize_package` hook is invoked once per generated module that contains
    /// services.
    pub fn service_generator(&mut self, service_generator: Box<dyn ServiceGenerator>) -> &mut Self {
        self.service_generator = Some(service_generator);
        self
    }

    /// Clears the `prost_types` flag, which is set by default.
    ///
    /// The flag is handed to `ExternPaths::new`. Presumably clearing it makes the well-known
    /// types be generated rather than mapped to the `prost-types` crate — confirm in
    /// `extern_paths`.
    pub fn compile_well_known_types(&mut self) -> &mut Self {
        self.prost_types = false;
        self
    }

    /// Disables comments for the items matched by `paths`.
    ///
    /// Matchers registered by an earlier call are cleared first.
    pub fn disable_comments<I, S>(&mut self, paths: I) -> &mut Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        self.disable_comments.clear();
        for matcher in paths {
            self.disable_comments
                .insert(matcher.as_ref().to_string(), ());
        }
        self
    }

    /// Maps the protobuf path `proto_path` to the existing Rust path `rust_path`.
    ///
    /// Pairs accumulate in call order. They are validated when code is generated; an invalid
    /// pair surfaces as an [`ErrorKind::InvalidInput`] error from [`Config::compile_protos`].
    pub fn extern_path<P1, P2>(&mut self, proto_path: P1, rust_path: P2) -> &mut Self
    where
        P1: Into<String>,
        P2: Into<String>,
    {
        self.extern_paths
            .push((proto_path.into(), rust_path.into()));
        self
    }

    /// Makes `protoc` write its encoded `FileDescriptorSet` to `path`.
    ///
    /// Without this setting, the descriptor set goes to a temporary directory.
    pub fn file_descriptor_set_path<P>(&mut self, path: P) -> &mut Self
    where
        P: Into<PathBuf>,
    {
        self.file_descriptor_set_path = Some(path.into());
        self
    }

    /// Turns off enum-prefix stripping, which is on by default.
    pub fn retain_enum_prefix(&mut self) -> &mut Self {
        self.strip_enum_prefix = false;
        self
    }

    /// Sets the directory that generated `.rs` files are written to.
    ///
    /// This takes precedence over the `OUT_DIR` environment variable.
    pub fn out_dir<P>(&mut self, path: P) -> &mut Self
    where
        P: Into<PathBuf>,
    {
        self.out_dir = Some(path.into());
        self
    }

    /// Appends an extra argument to the `protoc` command line.
    ///
    /// Extra arguments are placed after the `-I` include flags and before the `.proto` file
    /// paths, in the order they were added.
    pub fn protoc_arg<S>(&mut self, arg: S) -> &mut Self
    where
        S: AsRef<OsStr>,
    {
        self.protoc_args.push(arg.as_ref().to_owned());
        self
    }

    /// Compiles `protos` (searched for in `includes`) into Rust source files.
    ///
    /// The steps are:
    /// 1. Run `protoc` with `--include_imports --include_source_info -o <descriptor set>`.
    ///    Each include is passed as `-I`, then [`protoc_include()`], then any extra
    ///    [`Config::protoc_arg`]s, then the proto paths.
    /// 2. Decode the resulting `FileDescriptorSet` and generate one module per protobuf
    ///    package.
    /// 3. Write each module to `<segments joined by '.'>.rs` in the output directory. Files
    ///    whose current contents already match are left untouched.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - no [`Config::out_dir`] is set and `OUT_DIR` is unset;
    /// - `protoc` cannot be spawned or exits unsuccessfully (its stderr is included);
    /// - the descriptor set cannot be read or decoded;
    /// - code generation rejects the input;
    /// - an output file cannot be written.
    pub fn compile_protos(
        &mut self,
        protos: &[impl AsRef<Path>],
        includes: &[impl AsRef<Path>],
    ) -> Result<()> {
        // An explicit `out_dir` wins over the `OUT_DIR` environment variable.
        let target: PathBuf = self.out_dir.clone().map(Ok).unwrap_or_else(|| {
            env::var_os("OUT_DIR")
                .ok_or_else(|| {
                    Error::new(ErrorKind::Other, "OUT_DIR environment variable is not set")
                })
                .map(Into::into)
        })?;

        // Declared in the outer scope so the `TempDir` handle outlives the read of the
        // descriptor set below.
        let tmp;
        let file_descriptor_set_path = match self.file_descriptor_set_path.clone() {
            Some(file_descriptor_set_path) => file_descriptor_set_path,
            None => {
                tmp = tempfile::Builder::new().prefix("prost-build").tempdir()?;
                tmp.path().join("prost-descriptor-set")
            }
        };

        let mut cmd = Command::new(protoc());
        cmd.arg("--include_imports")
            .arg("--include_source_info")
            .arg("-o")
            .arg(&file_descriptor_set_path);
        for include in includes {
            cmd.arg("-I").arg(include.as_ref());
        }
        cmd.arg("-I").arg(protoc_include());
        for arg in &self.protoc_args {
            cmd.arg(arg);
        }
        for proto in protos {
            cmd.arg(proto.as_ref());
        }

        let output = cmd.output().map_err(|error| {
            Error::new(
                error.kind(),
                format!("failed to invoke protoc (hint: https://docs.rs/prost-build/#sourcing-protoc): {}", error),
            )
        })?;
        if !output.status.success() {
            return Err(Error::new(
                ErrorKind::Other,
                format!("protoc failed: {}", String::from_utf8_lossy(&output.stderr)),
            ));
        }

        let buf = fs::read(file_descriptor_set_path)?;
        let file_descriptor_set = FileDescriptorSet::decode(&*buf).map_err(|error| {
            Error::new(
                ErrorKind::InvalidInput,
                // `{}` already uses `Display`; no intermediate `to_string()` is needed.
                format!("invalid FileDescriptorSet: {}", error),
            )
        })?;

        let modules = self.generate(file_descriptor_set.file)?;
        for (module, content) in modules {
            let mut filename = module.join(".");
            filename.push_str(".rs");
            let output_path = target.join(&filename);

            // Avoid rewriting files that are already up to date. A failed read (e.g. the
            // file does not exist yet) is treated as "changed".
            let unchanged = fs::read(&output_path)
                .map(|previous_content| previous_content == content.as_bytes())
                .unwrap_or(false);
            if unchanged {
                trace!("unchanged: {:?}", filename);
            } else {
                trace!("writing: {:?}", filename);
                fs::write(output_path, content)?;
            }
        }

        Ok(())
    }

    /// Generates Rust source for `files`, keyed by module path.
    ///
    /// Files with the same package append into the same module buffer. When a service
    /// generator is installed, its `finalize_package` hook then runs once per module that
    /// contains services. Those modules are iterated from a `HashMap`, so their relative
    /// order is unspecified.
    ///
    /// # Errors
    ///
    /// Returns [`ErrorKind::InvalidInput`] if the message graph or the extern paths cannot be
    /// built.
    fn generate(&mut self, files: Vec<FileDescriptorProto>) -> Result<HashMap<Module, String>> {
        let mut modules = HashMap::new();
        // Module path -> protobuf package name, recorded only for files that declare services.
        let mut packages = HashMap::new();

        let message_graph = MessageGraph::new(&files)
            .map_err(|error| Error::new(ErrorKind::InvalidInput, error))?;
        let extern_paths = ExternPaths::new(&self.extern_paths, self.prost_types)
            .map_err(|error| Error::new(ErrorKind::InvalidInput, error))?;

        for file in files {
            let module = self.module(&file);
            if !file.service.is_empty() {
                packages.insert(module.clone(), file.package().to_string());
            }
            // `buf` is already a `&mut String`; pass it through instead of re-borrowing it.
            let buf = modules.entry(module).or_insert_with(String::new);
            CodeGenerator::generate(self, &message_graph, &extern_paths, file, buf);
        }

        if let Some(ref mut service_generator) = self.service_generator {
            for (module, package) in packages {
                // Every key of `packages` was also inserted into `modules` in the loop above.
                let buf = modules
                    .get_mut(&module)
                    .expect("module containing services must have a generated buffer");
                service_generator.finalize_package(&package, buf);
            }
        }

        Ok(modules)
    }

    /// Converts the file's protobuf package into a module path.
    ///
    /// The package is split on `.`, empty segments are dropped, and each remaining segment
    /// is converted with `to_snake`.
    fn module(&self, file: &FileDescriptorProto) -> Module {
        file.package()
            .split('.')
            .filter(|s| !s.is_empty())
            .map(to_snake)
            .collect()
    }
}
impl default::Default for Config {
fn default() -> Config {
Config {
file_descriptor_set_path: None,
service_generator: None,
map_type: PathMap::default(),
bytes_type: PathMap::default(),
type_attributes: PathMap::default(),
field_attributes: PathMap::default(),
prost_types: true,
strip_enum_prefix: true,
out_dir: None,
extern_paths: Vec::new(),
protoc_args: Vec::new(),
disable_comments: PathMap::default(),
}
}
}
impl fmt::Debug for Config {
    /// Formats the configuration for diagnostics.
    ///
    /// `ServiceGenerator` has no `Debug` bound, so for `service_generator` only whether one
    /// is installed is reported.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("Config")
            .field("file_descriptor_set_path", &self.file_descriptor_set_path)
            // Must reflect the service generator itself, not the descriptor-set path.
            .field("service_generator", &self.service_generator.is_some())
            .field("map_type", &self.map_type)
            .field("bytes_type", &self.bytes_type)
            .field("type_attributes", &self.type_attributes)
            .field("field_attributes", &self.field_attributes)
            .field("prost_types", &self.prost_types)
            .field("strip_enum_prefix", &self.strip_enum_prefix)
            .field("out_dir", &self.out_dir)
            .field("extern_paths", &self.extern_paths)
            .field("protoc_args", &self.protoc_args)
            .field("disable_comments", &self.disable_comments)
            .finish()
    }
}
/// Compiles `.proto` files into Rust using a default [`Config`].
///
/// This is shorthand for `Config::new().compile_protos(protos, includes)`. See
/// [`Config::compile_protos`] for the steps performed and the possible errors.
pub fn compile_protos(protos: &[impl AsRef<Path>], includes: &[impl AsRef<Path>]) -> Result<()> {
    let mut config = Config::new();
    config.compile_protos(protos, includes)
}
/// Returns the path of the `protoc` executable.
///
/// The `PROTOC` environment variable is consulted at run time first. When it is unset, the
/// value `PROTOC` had at compile time (captured with `env!`) is used instead.
pub fn protoc() -> PathBuf {
    env::var_os("PROTOC")
        .map(PathBuf::from)
        .unwrap_or_else(|| PathBuf::from(env!("PROTOC")))
}
/// Returns the include directory that is always passed to `protoc` with `-I`.
///
/// The `PROTOC_INCLUDE` environment variable is consulted at run time first. When it is
/// unset, the value `PROTOC_INCLUDE` had at compile time (captured with `env!`) is used.
pub fn protoc_include() -> PathBuf {
    let include = env::var_os("PROTOC_INCLUDE").unwrap_or_else(|| env!("PROTOC_INCLUDE").into());
    PathBuf::from(include)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::RefCell;
    use std::rc::Rc;

    /// Renders every service as a Rust trait declaration, exercising both the `generate`
    /// and `finalize` hooks.
    struct ServiceTraitGenerator;

    impl ServiceGenerator for ServiceTraitGenerator {
        fn generate(&mut self, service: Service, buf: &mut String) {
            service.comments.append_with_indent(0, buf);
            let header = format!("trait {} {{\n", &service.name);
            buf.push_str(&header);

            service.methods.into_iter().for_each(|method| {
                method.comments.append_with_indent(1, buf);
                let signature = format!(
                    " fn {}({}) -> {};\n",
                    method.name, method.input_type, method.output_type
                );
                buf.push_str(&signature);
            });

            buf.push_str("}\n");
        }

        fn finalize(&mut self, buf: &mut String) {
            buf.push_str("pub mod utils { }\n");
        }
    }

    /// Records every hook invocation into a shared `MockState` so tests can inspect it.
    struct MockServiceGenerator {
        state: Rc<RefCell<MockState>>,
    }

    /// What `MockServiceGenerator` observed.
    #[derive(Default)]
    struct MockState {
        service_names: Vec<String>,
        package_names: Vec<String>,
        finalized: u32,
    }

    impl MockServiceGenerator {
        fn new(state: Rc<RefCell<MockState>>) -> Self {
            MockServiceGenerator { state }
        }
    }

    impl ServiceGenerator for MockServiceGenerator {
        fn generate(&mut self, service: Service, _buf: &mut String) {
            self.state.borrow_mut().service_names.push(service.name);
        }

        fn finalize(&mut self, _buf: &mut String) {
            self.state.borrow_mut().finalized += 1;
        }

        fn finalize_package(&mut self, package: &str, _buf: &mut String) {
            self.state
                .borrow_mut()
                .package_names
                .push(package.to_string());
        }
    }

    #[test]
    fn smoke_test() {
        let _ = env_logger::try_init();
        let mut config = Config::new();
        config.service_generator(Box::new(ServiceTraitGenerator));
        config
            .compile_protos(&["src/smoke_test.proto"], &["src"])
            .unwrap();
    }

    #[test]
    fn finalize_package() {
        let _ = env_logger::try_init();
        let state = Rc::new(RefCell::new(MockState::default()));
        let generator = MockServiceGenerator::new(Rc::clone(&state));

        let mut config = Config::new();
        config.service_generator(Box::new(generator));
        config
            .compile_protos(&["src/hello.proto", "src/goodbye.proto"], &["src"])
            .unwrap();

        let recorded = state.borrow();
        assert_eq!(&recorded.service_names, &["Greeting", "Farewell"]);
        assert_eq!(&recorded.package_names, &["helloworld"]);
        assert_eq!(recorded.finalized, 3);
    }
}