#![doc(html_root_url = "https://docs.rs/prost-build/0.14.3")]
use std::io::Result;
use std::path::Path;
use prost_types::FileDescriptorSet;
mod ast;
pub use crate::ast::{Comments, Method, Service};
mod collections;
pub(crate) use collections::{BytesType, MapType};
mod code_generator;
mod context;
mod extern_paths;
mod ident;
mod message_graph;
mod path;
mod config;
pub use config::{
error_message_protoc_not_found, protoc_from_env, protoc_include_from_env, Config,
};
mod module;
pub use module::Module;
/// Hook for emitting custom Rust code for Protobuf service definitions.
///
/// An implementation is handed each [`Service`] together with a string buffer and appends
/// whatever code it wants to that buffer. The tests in this crate install generators through
/// `Config::service_generator`.
pub trait ServiceGenerator {
    /// Appends the code generated for `service` to `buf`.
    fn generate(&mut self, service: Service, buf: &mut String);

    /// Appends trailing output to `buf` after generation.
    ///
    /// The default implementation does nothing.
    ///
    /// NOTE(review): the call site lives in the code generator, not in this file; the tests
    /// here observe this hook running more than once during a single compilation.
    fn finalize(&mut self, _buf: &mut String) {}

    /// Appends package-level trailing output to `buf`; `_package` carries the package name.
    ///
    /// The default implementation does nothing.
    fn finalize_package(&mut self, _package: &str, _buf: &mut String) {}
}
/// Compiles the given `.proto` files using a default-constructed [`Config`].
///
/// `protos` and `includes` are forwarded unchanged to [`Config::compile_protos`], and its
/// result is returned as-is.
pub fn compile_protos(protos: &[impl AsRef<Path>], includes: &[impl AsRef<Path>]) -> Result<()> {
    let mut default_config = Config::new();
    default_config.compile_protos(protos, includes)
}
/// Runs code generation for an already-built `FileDescriptorSet` using a default-constructed
/// [`Config`].
///
/// `fds` is handed over to [`Config::compile_fds`], and its result is returned as-is.
pub fn compile_fds(fds: FileDescriptorSet) -> Result<()> {
    let mut default_config = Config::new();
    default_config.compile_fds(fds)
}
#[cfg(test)]
mod tests {
use std::cell::RefCell;
use std::rc::Rc;
use super::*;
/// Compares the file at `$actual_path` against the fixture at `$expected_path`.
///
/// The actual file is read in full and its `\r\n` line endings are turned into `\n`; the
/// comparison itself is delegated to `assert_eq_fixture_contents!`.
macro_rules! assert_eq_fixture_file {
    ($expected_path:expr, $actual_path:expr) => {{
        let actual_text = std::fs::read_to_string($actual_path)
            .expect("Failed to read actual file")
            .replace("\r\n", "\n");
        assert_eq_fixture_contents!($expected_path, actual_text);
    }};
}
/// Asserts that `$actual` equals the contents of the fixture file at `$expected_path`.
///
/// The fixture is read in full and its `\r\n` line endings are turned into `\n` before the
/// comparison. On a mismatch the fixture file is overwritten with `$actual` and the assertion
/// then fails, so the next run compares against the freshly written contents.
///
/// Both arguments are evaluated exactly once and only borrowed, so passing an expression such
/// as `format!(...)` does not repeat its work or its side effects.
macro_rules! assert_eq_fixture_contents {
    ($expected_path:expr, $actual:expr) => {{
        // Binding through `match` (the same trick `assert_eq!` uses) evaluates each argument
        // once and keeps any temporaries alive for the whole comparison.
        match (&$expected_path, &$actual) {
            (expected_path, actual) => {
                let expected = std::fs::read_to_string(expected_path)
                    .expect("Failed to read expected file")
                    .replace("\r\n", "\n");
                if expected != *actual {
                    // Refresh the fixture before failing so it can be reviewed and committed.
                    std::fs::write(expected_path, actual).expect("Failed to write expected file");
                }
                assert_eq!(expected, *actual);
            }
        }
    }};
}
/// Test generator that renders every service as a Rust trait declaration.
struct ServiceTraitGenerator;

impl ServiceGenerator for ServiceTraitGenerator {
    /// Writes the service comments, a `trait <name> {` header, one `fn` line per method (each
    /// preceded by the method's comments) and the closing brace into `buf`.
    fn generate(&mut self, service: Service, buf: &mut String) {
        service.comments.append_with_indent(0, buf);
        let header = format!("trait {} {{\n", &service.name);
        buf.push_str(&header);

        for rpc in service.methods {
            rpc.comments.append_with_indent(1, buf);
            let signature = format!(
                " fn {}(_: {}) -> {};\n",
                rpc.name, rpc.input_type, rpc.output_type
            );
            buf.push_str(&signature);
        }

        buf.push_str("}\n");
    }

    /// Appends an empty `utils` module to `buf`.
    fn finalize(&mut self, buf: &mut String) {
        buf.push_str("pub mod utils { }\n");
    }
}
/// Everything a [`MockServiceGenerator`] has observed so far.
#[derive(Default)]
struct MockState {
    /// Names of the services passed to `generate`, in call order.
    service_names: Vec<String>,
    /// Package names passed to `finalize_package`, in call order.
    package_names: Vec<String>,
    /// Number of times `finalize` has run.
    finalized: u32,
}

/// Service generator that produces no code and only records its calls in shared state, so a
/// test can inspect them after the generator itself has been handed off.
struct MockServiceGenerator {
    state: Rc<RefCell<MockState>>,
}

impl MockServiceGenerator {
    /// Creates a generator that records into `state`.
    fn new(state: Rc<RefCell<MockState>>) -> Self {
        MockServiceGenerator { state }
    }
}
/// Records every callback in the shared `MockState`; nothing is ever written to the buffer.
impl ServiceGenerator for MockServiceGenerator {
    /// Remembers the name of the generated service.
    fn generate(&mut self, service: Service, _buf: &mut String) {
        self.state.borrow_mut().service_names.push(service.name);
    }

    /// Counts one more `finalize` call.
    fn finalize(&mut self, _buf: &mut String) {
        self.state.borrow_mut().finalized += 1;
    }

    /// Remembers the name of the finalized package.
    fn finalize_package(&mut self, package: &str, _buf: &mut String) {
        self.state.borrow_mut().package_names.push(package.to_owned());
    }
}
/// Compiles the smoke-test proto with `ServiceTraitGenerator` and compares every generated file
/// against the checked-in fixture.
#[test]
fn smoke_test() {
    let _ = env_logger::try_init();

    // The fixture only depends on the compile-time `format` feature, so pick it once.
    let expected_fixture = if cfg!(feature = "format") {
        "src/fixtures/smoke_test/_expected_smoke_test_formatted.rs"
    } else {
        "src/fixtures/smoke_test/_expected_smoke_test.rs"
    };

    let out_dir = tempfile::tempdir().unwrap();
    Config::new()
        .service_generator(Box::new(ServiceTraitGenerator))
        .out_dir(out_dir.path())
        .compile_protos(&["src/fixtures/smoke_test/smoke_test.proto"], &["src"])
        .unwrap();

    for entry in std::fs::read_dir(out_dir.path()).unwrap() {
        let output = entry.unwrap();
        let output_name = output.file_name().into_string().unwrap();
        // Only `smoke_test.rs` may appear in the output directory.
        assert_eq!(output_name, "smoke_test.rs");
        assert_eq_fixture_file!(expected_fixture, output.path());
    }
}
/// Compiles two protos from the `helloworld` fixtures with a recording generator and checks
/// which services, packages and `finalize` calls it observed.
#[test]
fn finalize_package() {
    let _ = env_logger::try_init();

    let out_dir = tempfile::tempdir().unwrap();
    let recorded = Rc::new(RefCell::new(MockState::default()));
    let protos = [
        "src/fixtures/helloworld/hello.proto",
        "src/fixtures/helloworld/goodbye.proto",
    ];

    Config::new()
        .service_generator(Box::new(MockServiceGenerator::new(Rc::clone(&recorded))))
        .include_file("_protos.rs")
        .out_dir(out_dir.path())
        .compile_protos(&protos, &["src/fixtures/helloworld"])
        .unwrap();

    let seen = recorded.borrow();
    assert_eq!(&seen.service_names, &["Greeting", "Farewell"]);
    assert_eq!(&seen.package_names, &["helloworld"]);
    assert_eq!(seen.finalized, 3);
}
/// Adds message and enum attributes (globally, plus per RPC input/output type discovered from
/// the loaded descriptors) and compares each generated file against its fixture.
#[test]
fn test_generate_message_attributes() {
    let _ = env_logger::try_init();
    let out_dir = tempfile::tempdir().unwrap();

    let mut config = Config::new();
    config
        .out_dir(out_dir.path())
        .message_attribute(".", "#[derive(derive_builder::Builder)]")
        .enum_attribute(".", "#[some_enum_attr(u8)]");

    let fds = config
        .load_fds(
            &["src/fixtures/helloworld/hello.proto"],
            &["src/fixtures/helloworld"],
        )
        .unwrap();

    // Walk every method of every service in every file, in declaration order.
    let rpc_methods = fds
        .file
        .iter()
        .flat_map(|file| file.service.iter())
        .flat_map(|service| service.method.iter());
    for rpc in rpc_methods {
        if let Some(input) = rpc.input_type.as_ref() {
            config.message_attribute(input, "#[derive(custom_proto::Input)]");
        }
        if let Some(output) = rpc.output_type.as_ref() {
            config.message_attribute(output, "#[derive(custom_proto::Output)]");
        }
    }

    config.compile_fds(fds).unwrap();

    for entry in std::fs::read_dir(out_dir.path()).unwrap() {
        let output = entry.unwrap();
        let file_name = output.file_name().into_string().unwrap();
        let fixture = format!("src/fixtures/helloworld/_expected_{file_name}");
        assert_eq_fixture_file!(fixture, output.path());
    }
}
/// Compiles a proto that imports an otherwise empty package and checks that no
/// `google.protobuf.rs` file is produced and that only the expected files exist.
#[test]
fn test_generate_no_empty_outputs() {
    let _ = env_logger::try_init();

    let recorded = Rc::new(RefCell::new(MockState::default()));
    let include_file = "_include.rs";
    let out_dir = tempfile::tempdir().unwrap();
    let previously_empty_proto_path = out_dir.path().join("google.protobuf.rs");

    Config::new()
        .service_generator(Box::new(MockServiceGenerator::new(Rc::clone(&recorded))))
        .include_file(include_file)
        .out_dir(out_dir.path())
        .compile_protos(
            &["src/fixtures/imports_empty/imports_empty.proto"],
            &["src/fixtures/imports_empty"],
        )
        .unwrap();

    assert!(!std::fs::exists(previously_empty_proto_path).unwrap());

    for entry in std::fs::read_dir(out_dir.path()).unwrap() {
        let output = entry.unwrap();
        let file_name = output.file_name().into_string().unwrap();
        match file_name.as_str() {
            // The include file is checked first, exactly as an `if` chain would.
            name if name == include_file => {
                assert_eq_fixture_file!(
                    "src/fixtures/imports_empty/_expected_include.rs",
                    output.path()
                );
            }
            "com.prost_test.test.v1.rs" => {
                let content = std::fs::read_to_string(output.path()).unwrap();
                assert!(content.contains("struct TestConfig"));
                assert!(content.contains("struct GetTestResponse"));
            }
            _ => panic!("Found unexpected file: {}", file_name),
        }
    }
}
/// Boxes two fields via `Config::boxed` and compares the generated file against the fixture
/// matching the `format` feature.
#[test]
fn test_generate_field_attributes() {
    let _ = env_logger::try_init();

    let expected_fixture = if cfg!(feature = "format") {
        "src/fixtures/field_attributes/_expected_field_attributes_formatted.rs"
    } else {
        "src/fixtures/field_attributes/_expected_field_attributes.rs"
    };

    let out_dir = tempfile::tempdir().unwrap();
    Config::new()
        .out_dir(out_dir.path())
        .boxed("Container.data.foo")
        .boxed("Bar.qux")
        .compile_protos(
            &["src/fixtures/field_attributes/field_attributes.proto"],
            &["src/fixtures/field_attributes"],
        )
        .unwrap();

    let generated_path = out_dir.path().join("field_attributes.rs");
    assert_eq_fixture_file!(expected_fixture, generated_path);
}
/// Repeats the same compilation nine times and checks that the include file matches the
/// fixture on every round.
#[test]
fn deterministic_include_file() {
    let _ = env_logger::try_init();

    // Inputs are identical for every round, so they are set up once.
    let include_file = "_include.rs";
    let protos = [
        "src/fixtures/alphabet/a.proto",
        "src/fixtures/alphabet/b.proto",
        "src/fixtures/alphabet/c.proto",
        "src/fixtures/alphabet/d.proto",
        "src/fixtures/alphabet/e.proto",
        "src/fixtures/alphabet/f.proto",
    ];

    for _round in 1..10 {
        let recorded = Rc::new(RefCell::new(MockState::default()));
        let out_dir = tempfile::tempdir().unwrap();

        Config::new()
            .service_generator(Box::new(MockServiceGenerator::new(Rc::clone(&recorded))))
            .include_file(include_file)
            .out_dir(out_dir.path())
            .compile_protos(&protos, &["src/fixtures/alphabet"])
            .unwrap();

        let generated_include = out_dir.path().join(Path::new(include_file));
        assert_eq_fixture_file!(
            "src/fixtures/alphabet/_expected_include.rs",
            generated_include
        );
    }
}
/// Feeds a fixed set of modules to `Config::write_includes` and compares the produced text
/// against the checked-in fixture.
#[test]
fn write_includes() {
    // Package names are listed in the exact order the modules are handed to the writer.
    let modules = [
        "foo.bar.baz",
        "",
        "foo.bar",
        "bar",
        "foo",
        "foo.bar.qux",
        "foo.bar.a.b.c",
    ]
    .map(Module::from_protobuf_package_name);

    let file_names = modules
        .iter()
        .map(|module| (module.clone(), module.to_file_name_or("_.default")))
        .collect();

    let mut output = Vec::new();
    Config::new()
        .default_package_filename("_.default")
        .write_includes(modules.iter().collect(), &mut output, None, &file_names)
        .unwrap();

    let rendered = String::from_utf8(output).unwrap();
    assert_eq_fixture_contents!("src/fixtures/write_includes/_.includes.rs", rendered);
}
/// Compiles the `all_deprecated` proto and compares the generated file against its fixture.
#[test]
fn test_generate_deprecated() {
    let _ = env_logger::try_init();

    let out_dir = tempfile::tempdir().unwrap();
    let generated_path = out_dir.path().join("all_deprecated.rs");

    Config::new()
        .out_dir(out_dir.path())
        .compile_protos(
            &["src/fixtures/deprecated/all_deprecated.proto"],
            &["src/fixtures/deprecated"],
        )
        .unwrap();

    assert_eq_fixture_file!("src/fixtures/deprecated/_all_deprecated.rs", generated_path);
}
}