prost_validate_build/
lib.rs

1//! `prost-validate-build` contains [`Builder`] to configure [`prost_build::Config`]
2//! to derive [`prost_validate::Validator`] for all messages in protocol buffers.
3//!
//! The simplest way to generate Rust code for your protocol buffers with validation support:
5//!
6//! ```no_run
7//! // build.rs
8//! use prost_validate_build::Builder;
9//!
10//! Builder::new()
11//!     .compile_protos(&["path/to/protobuf.proto"], &["path/to/include"])
12//!     .expect("Failed to compile protos");
13//! ```
14mod rules;
15
16use crate::rules::IntoFieldAttribute;
17use prost_reflect::prost_types::FileDescriptorProto;
18use prost_reflect::{DescriptorPool, OneofDescriptor};
19use prost_validate_types::{FieldRulesExt, MessageRulesExt, OneofRulesExt};
20use std::collections::HashMap;
21use std::path::{Path, PathBuf};
22use std::rc::Rc;
23use std::{env, fs, io};
24
/// Builder that wires [`prost_validate::Validator`] derives into a [`prost_build::Config`].
///
/// The only state it keeps is where the encoded file descriptor set lives; use
/// [`Builder::file_descriptor_set_path`] to change that location.
///
/// ```no_run
/// # use prost_validate_build::Builder;
/// Builder::new()
///     .compile_protos(&["path/to/protobuf.proto"], &["path/to/include"])
///     .unwrap();
/// ```
#[derive(Clone, Debug)]
pub struct Builder {
    // Path the encoded file descriptor set is written to while compiling, and read back
    // from to discover the messages that need annotating.
    file_descriptor_set_path: PathBuf,
}
37
38impl Default for Builder {
39    fn default() -> Self {
40        let file_descriptor_set_path = env::var_os("OUT_DIR")
41            .map(PathBuf::from)
42            .unwrap_or_else(|| PathBuf::from("."))
43            .join("file_descriptor_set.bin");
44
45        Self {
46            file_descriptor_set_path,
47        }
48    }
49}
50
51impl Builder {
52    /// Create a new builder with default parameters.
53    pub fn new() -> Self {
54        Self::default()
55    }
56
57    /// Set the path where the encoded file descriptor set is created.
58    /// By default, it is created at `$OUT_DIR/file_descriptor_set.bin`.
59    ///
60    /// This overrides the path specified by
61    /// [`prost_build::Config::file_descriptor_set_path`].
62    pub fn file_descriptor_set_path<P>(&mut self, path: P) -> &mut Self
63    where
64        P: Into<PathBuf>,
65    {
66        self.file_descriptor_set_path = path.into();
67        self
68    }
69
70    /// Configure `config` to derive [`prost_validate::Validator`] for all messages included in `protos`.
71    /// This method does not generate prost-validate compatible code,
72    /// but `config` may be used later to compile protocol buffers independently of [`Builder`].
73    /// `protos` and `includes` should be the same when [`prost_build::Config::compile_protos`] is called on `config`.
74    ///
75    /// ```ignore
76    /// let mut config = Config::new();
77    ///
78    /// // Customize config here
79    ///
80    /// Builder::new()
81    ///     .configure(&mut config, &["path/to/protobuf.proto"], &["path/to/include"])
82    ///     .expect("Failed to configure for reflection");
83    ///
84    /// // Custom compilation process with `config`
85    /// config.compile_protos(&["path/to/protobuf.proto"], &["path/to/includes"])
86    ///     .expect("Failed to compile protocol buffers");
87    /// ```
88    pub fn configure(
89        &mut self,
90        config: &mut prost_build::Config,
91        protos: &[impl AsRef<Path>],
92        includes: &[impl AsRef<Path>],
93    ) -> io::Result<()> {
94        config
95            .file_descriptor_set_path(&self.file_descriptor_set_path)
96            .compile_protos(protos, includes)?;
97
98        let buf = fs::read(&self.file_descriptor_set_path)?;
99        let descriptor = DescriptorPool::decode(buf.as_ref()).expect("Invalid file descriptor");
100        self.annotate(config, &descriptor);
101        Ok(())
102    }
103
104    pub fn configure_with_file_descriptor_protos(
105        &mut self,
106        config: &mut prost_build::Config,
107        protos: &[FileDescriptorProto],
108    ) -> io::Result<()> {
109        let descriptor = {
110            let mut d = DescriptorPool::new();
111            d.add_file_descriptor_protos(protos.to_owned())
112                .expect("Invalid file descriptor protos");
113            d
114        };
115        self.annotate(config, &descriptor);
116        Ok(())
117    }
118
119    /// Compile protocol buffers into Rust with given [`prost_build::Config`].
120    pub fn compile_protos_with_config(
121        &mut self,
122        mut config: prost_build::Config,
123        protos: &[impl AsRef<Path>],
124        includes: &[impl AsRef<Path>],
125    ) -> io::Result<()> {
126        self.configure(&mut config, protos, includes)?;
127
128        config.skip_protoc_run().compile_protos(protos, includes)
129    }
130
131    /// Compile protocol buffers into Rust.
132    pub fn compile_protos(
133        &mut self,
134        protos: &[impl AsRef<Path>],
135        includes: &[impl AsRef<Path>],
136    ) -> io::Result<()> {
137        self.compile_protos_with_config(prost_build::Config::new(), protos, includes)
138    }
139
140    pub fn annotate(&self, config: &mut prost_build::Config, descriptor: &DescriptorPool) {
141        for message in descriptor.all_messages() {
142            let full_name = message.full_name();
143            config.type_attribute(full_name, "#[derive(::prost_validate::Validator)]");
144            if message.validation_ignored() || message.validation_disabled() {
145                continue;
146            }
147            let mut oneofs: HashMap<String, Rc<OneofDescriptor>> = HashMap::new();
148            for field in message.fields() {
149                config.field_attribute(
150                    field.full_name(),
151                    format!("#[validate(name = \"{}\")]", field.full_name()),
152                );
153                let field_rules = match field.validation_rules().unwrap() {
154                    Some(r) => r,
155                    None => continue,
156                };
157                if oneofs.contains_key(field.full_name()) {
158                    continue;
159                }
160                if let Some(ref desc) = field.real_oneof() {
161                    config.field_attribute(
162                        desc.full_name(),
163                        format!("#[validate(name = \"{}\")]", desc.full_name()),
164                    );
165                    let desc = Rc::new(desc.clone());
166                    config
167                        .type_attribute(desc.full_name(), "#[derive(::prost_validate::Validator)]");
168                    if desc.required() {
169                        config.field_attribute(desc.full_name(), "#[validate(required)]");
170                    }
171                    for field in desc.fields() {
172                        let field = field.clone();
173                        config.field_attribute(
174                            format!("{}.{}", desc.full_name(), field.name()),
175                            format!("#[validate(name = \"{}\")]", field.full_name()),
176                        );
177                        oneofs.insert(field.full_name().to_string(), desc.clone());
178                        let field_rules = match field.validation_rules().unwrap() {
179                            Some(r) => r,
180                            None => continue,
181                        };
182                        let field_attribute = field_rules.into_field_attribute();
183                        if let Some(attribute) = field_attribute {
184                            // this is not very protobuf typical, but it is the way it is implemented in prost-build
185                            config.field_attribute(
186                                format!("{}.{}", desc.full_name(), field.name()),
187                                format!("#[validate({})]", attribute),
188                            );
189                        }
190                    }
191                    continue;
192                }
193                let field_attribute = field_rules.into_field_attribute();
194                if field.optional() {
195                    config.field_attribute(field.full_name(), "#[validate(optional)]");
196                }
197                if let Some(attribute) = field_attribute {
198                    config
199                        .field_attribute(field.full_name(), format!("#[validate({})]", attribute));
200                }
201            }
202        }
203    }
204}