prost_validate_build/
lib.rs1mod rules;
15
16use crate::rules::IntoFieldAttribute;
17use prost_reflect::prost_types::FileDescriptorProto;
18use prost_reflect::{DescriptorPool, OneofDescriptor};
19use prost_validate_types::{FieldRulesExt, MessageRulesExt, OneofRulesExt};
20use std::collections::HashMap;
21use std::path::{Path, PathBuf};
22use std::rc::Rc;
23use std::{env, fs, io};
24
/// Build-script helper that compiles `.proto` files with `prost_build` and annotates the
/// generated types with `prost_validate` derives and `#[validate(...)]` attributes.
#[derive(Clone, Debug)]
pub struct Builder {
    /// Location of the encoded file descriptor set that `protoc` writes and that is decoded
    /// to discover the validation rules.
    file_descriptor_set_path: std::path::PathBuf,
}
37
38impl Default for Builder {
39 fn default() -> Self {
40 let file_descriptor_set_path = env::var_os("OUT_DIR")
41 .map(PathBuf::from)
42 .unwrap_or_else(|| PathBuf::from("."))
43 .join("file_descriptor_set.bin");
44
45 Self {
46 file_descriptor_set_path,
47 }
48 }
49}
50
51impl Builder {
52 pub fn new() -> Self {
54 Self::default()
55 }
56
57 pub fn file_descriptor_set_path<P>(&mut self, path: P) -> &mut Self
63 where
64 P: Into<PathBuf>,
65 {
66 self.file_descriptor_set_path = path.into();
67 self
68 }
69
70 pub fn configure(
89 &mut self,
90 config: &mut prost_build::Config,
91 protos: &[impl AsRef<Path>],
92 includes: &[impl AsRef<Path>],
93 ) -> io::Result<()> {
94 config
95 .file_descriptor_set_path(&self.file_descriptor_set_path)
96 .compile_protos(protos, includes)?;
97
98 let buf = fs::read(&self.file_descriptor_set_path)?;
99 let descriptor = DescriptorPool::decode(buf.as_ref()).expect("Invalid file descriptor");
100 self.annotate(config, &descriptor);
101 Ok(())
102 }
103
104 pub fn configure_with_file_descriptor_protos(
105 &mut self,
106 config: &mut prost_build::Config,
107 protos: &[FileDescriptorProto],
108 ) -> io::Result<()> {
109 let descriptor = {
110 let mut d = DescriptorPool::new();
111 d.add_file_descriptor_protos(protos.to_owned())
112 .expect("Invalid file descriptor protos");
113 d
114 };
115 self.annotate(config, &descriptor);
116 Ok(())
117 }
118
119 pub fn compile_protos_with_config(
121 &mut self,
122 mut config: prost_build::Config,
123 protos: &[impl AsRef<Path>],
124 includes: &[impl AsRef<Path>],
125 ) -> io::Result<()> {
126 self.configure(&mut config, protos, includes)?;
127
128 config.skip_protoc_run().compile_protos(protos, includes)
129 }
130
131 pub fn compile_protos(
133 &mut self,
134 protos: &[impl AsRef<Path>],
135 includes: &[impl AsRef<Path>],
136 ) -> io::Result<()> {
137 self.compile_protos_with_config(prost_build::Config::new(), protos, includes)
138 }
139
140 pub fn annotate(&self, config: &mut prost_build::Config, descriptor: &DescriptorPool) {
141 for message in descriptor.all_messages() {
142 let full_name = message.full_name();
143 config.type_attribute(full_name, "#[derive(::prost_validate::Validator)]");
144 if message.validation_ignored() || message.validation_disabled() {
145 continue;
146 }
147 let mut oneofs: HashMap<String, Rc<OneofDescriptor>> = HashMap::new();
148 for field in message.fields() {
149 config.field_attribute(
150 field.full_name(),
151 format!("#[validate(name = \"{}\")]", field.full_name()),
152 );
153 let field_rules = match field.validation_rules().unwrap() {
154 Some(r) => r,
155 None => continue,
156 };
157 if oneofs.contains_key(field.full_name()) {
158 continue;
159 }
160 if let Some(ref desc) = field.real_oneof() {
161 config.field_attribute(
162 desc.full_name(),
163 format!("#[validate(name = \"{}\")]", desc.full_name()),
164 );
165 let desc = Rc::new(desc.clone());
166 config
167 .type_attribute(desc.full_name(), "#[derive(::prost_validate::Validator)]");
168 if desc.required() {
169 config.field_attribute(desc.full_name(), "#[validate(required)]");
170 }
171 for field in desc.fields() {
172 let field = field.clone();
173 config.field_attribute(
174 format!("{}.{}", desc.full_name(), field.name()),
175 format!("#[validate(name = \"{}\")]", field.full_name()),
176 );
177 oneofs.insert(field.full_name().to_string(), desc.clone());
178 let field_rules = match field.validation_rules().unwrap() {
179 Some(r) => r,
180 None => continue,
181 };
182 let field_attribute = field_rules.into_field_attribute();
183 if let Some(attribute) = field_attribute {
184 config.field_attribute(
186 format!("{}.{}", desc.full_name(), field.name()),
187 format!("#[validate({})]", attribute),
188 );
189 }
190 }
191 continue;
192 }
193 let field_attribute = field_rules.into_field_attribute();
194 if field.optional() {
195 config.field_attribute(field.full_name(), "#[validate(optional)]");
196 }
197 if let Some(attribute) = field_attribute {
198 config
199 .field_attribute(field.full_name(), format!("#[validate({})]", attribute));
200 }
201 }
202 }
203 }
204}