prost_validate_build/lib.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201
//! `prost-validate-build` provides a [`Builder`] that configures [`prost_build::Config`]
//! to derive [`prost_validate::Validator`] for every message in a set of protocol buffers.
//!
//! The simplest way to generate the protocol buffer API with validation support:
//!
//! ```no_run
//! // build.rs
//! use prost_validate_build::Builder;
//!
//! Builder::new()
//!     .compile_protos(&["path/to/protobuf.proto"], &["path/to/include"])
//!     .expect("Failed to compile protos");
//! ```
mod rules;
use crate::rules::IntoFieldAttribute;
use prost_reflect::prost_types::FileDescriptorProto;
use prost_reflect::{DescriptorPool, OneofDescriptor};
use prost_validate_types::{FieldRulesExt, MessageRulesExt, OneofRulesExt};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::rc::Rc;
use std::{env, fs, io};
/// Builder that drives prost-validate code generation.
///
/// It decides where the encoded file descriptor set is written and registers the
/// validation attributes on a [`prost_build::Config`].
///
/// ```no_run
/// # use prost_validate_build::Builder;
/// Builder::new()
///     .compile_protos(&["path/to/protobuf.proto"], &["path/to/include"])
///     .unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct Builder {
    /// Location of the encoded `FileDescriptorSet` produced while compiling the protos.
    file_descriptor_set_path: PathBuf,
}
impl Default for Builder {
fn default() -> Self {
let file_descriptor_set_path = env::var_os("OUT_DIR")
.map(PathBuf::from)
.unwrap_or_else(|| PathBuf::from("."))
.join("file_descriptor_set.bin");
Self {
file_descriptor_set_path,
}
}
}
impl Builder {
/// Create a new builder with default parameters.
pub fn new() -> Self {
Self::default()
}
/// Set the path where the encoded file descriptor set is created.
/// By default, it is created at `$OUT_DIR/file_descriptor_set.bin`.
///
/// This overrides the path specified by
/// [`prost_build::Config::file_descriptor_set_path`].
pub fn file_descriptor_set_path<P>(&mut self, path: P) -> &mut Self
where
P: Into<PathBuf>,
{
self.file_descriptor_set_path = path.into();
self
}
/// Configure `config` to derive [`prost_validate::Validator`] for all messages included in `protos`.
/// This method does not generate prost-validate compatible code,
/// but `config` may be used later to compile protocol buffers independently of [`Builder`].
/// `protos` and `includes` should be the same when [`prost_build::Config::compile_protos`] is called on `config`.
///
/// ```ignore
/// let mut config = Config::new();
///
/// // Customize config here
///
/// Builder::new()
/// .configure(&mut config, &["path/to/protobuf.proto"], &["path/to/include"])
/// .expect("Failed to configure for reflection");
///
/// // Custom compilation process with `config`
/// config.compile_protos(&["path/to/protobuf.proto"], &["path/to/includes"])
/// .expect("Failed to compile protocol buffers");
/// ```
pub fn configure(
&mut self,
config: &mut prost_build::Config,
protos: &[impl AsRef<Path>],
includes: &[impl AsRef<Path>],
) -> io::Result<()> {
config
.file_descriptor_set_path(&self.file_descriptor_set_path)
.compile_protos(protos, includes)?;
let buf = fs::read(&self.file_descriptor_set_path)?;
let descriptor = DescriptorPool::decode(buf.as_ref()).expect("Invalid file descriptor");
self.annotate(config, &descriptor);
Ok(())
}
pub fn configure_with_file_descriptor_protos(
&mut self,
config: &mut prost_build::Config,
protos: &[FileDescriptorProto],
) -> io::Result<()> {
let descriptor = {
let mut d = DescriptorPool::new();
d.add_file_descriptor_protos(protos.to_owned())
.expect("Invalid file descriptor protos");
d
};
self.annotate(config, &descriptor);
Ok(())
}
/// Compile protocol buffers into Rust with given [`prost_build::Config`].
pub fn compile_protos_with_config(
&mut self,
mut config: prost_build::Config,
protos: &[impl AsRef<Path>],
includes: &[impl AsRef<Path>],
) -> io::Result<()> {
self.configure(&mut config, protos, includes)?;
config.skip_protoc_run().compile_protos(protos, includes)
}
/// Compile protocol buffers into Rust.
pub fn compile_protos(
&mut self,
protos: &[impl AsRef<Path>],
includes: &[impl AsRef<Path>],
) -> io::Result<()> {
self.compile_protos_with_config(prost_build::Config::new(), protos, includes)
}
pub fn annotate(&self, config: &mut prost_build::Config, descriptor: &DescriptorPool) {
for message in descriptor.all_messages() {
let full_name = message.full_name();
config.type_attribute(full_name, "#[derive(::prost_validate::Validator)]");
if message.validation_ignored() || message.validation_disabled() {
continue;
}
let mut oneofs: HashMap<String, Rc<OneofDescriptor>> = HashMap::new();
for field in message.fields() {
config.field_attribute(
field.full_name(),
format!("#[validate(name = \"{}\")]", field.full_name()),
);
let field_rules = match field.validation_rules().unwrap() {
Some(r) => r,
None => continue,
};
if oneofs.contains_key(field.full_name()) {
continue;
}
if let Some(ref desc) = field.containing_oneof() {
config.field_attribute(
desc.full_name(),
format!("#[validate(name = \"{}\")]", desc.full_name()),
);
let desc = Rc::new(desc.clone());
config
.type_attribute(desc.full_name(), "#[derive(::prost_validate::Validator)]");
if desc.required() {
config.field_attribute(desc.full_name(), "#[validate(required)]");
}
for field in desc.fields() {
let field = field.clone();
config.field_attribute(
format!("{}.{}", desc.full_name(), field.name()),
format!("#[validate(name = \"{}\")]", field.full_name()),
);
oneofs.insert(field.full_name().to_string(), desc.clone());
let field_rules = match field.validation_rules().unwrap() {
Some(r) => r,
None => continue,
};
let field_attribute = field_rules.into_field_attribute();
if let Some(attribute) = field_attribute {
// this is not very protobuf typical, but it is the way it is implemented in prost-build
config.field_attribute(
format!("{}.{}", desc.full_name(), field.name()),
format!("#[validate({})]", attribute),
);
}
}
continue;
}
let field_attribute = field_rules.into_field_attribute();
if let Some(attribute) = field_attribute {
config
.field_attribute(field.full_name(), format!("#[validate({})]", attribute));
}
}
}
}
}