prost_protovalidate/
validator.rs1use std::sync::LazyLock;
2
3use prost_reflect::ReflectMessage;
4
5use crate::config::{NopFilter, ValidationConfig, ValidationOption, ValidatorOption};
6use crate::error::Error;
7
8mod builder;
9pub(crate) mod editions;
10mod evaluator;
11mod lookups;
12mod rules;
13
14use builder::Builder;
15use evaluator::MessageEvaluator;
16
/// Validates protobuf messages against the protovalidate rules attached to
/// their descriptors.
///
/// Build one with [`Validator::new`] (default configuration) or
/// [`Validator::with_options`], then call [`Validator::validate`] or
/// [`Validator::validate_with`].
pub struct Validator {
    /// Produces the rule evaluator for each message type (via `load_or_build`);
    /// presumably caches evaluators between calls — confirm in `builder`.
    builder: Builder,
    /// Baseline per-call configuration; `validate_with` options are layered on
    /// top of a copy of it and never modify it.
    config: ValidationConfig,
}
25
26impl Validator {
27 #[must_use]
29 pub fn new() -> Self {
30 Self {
31 builder: Builder::new(),
32 config: ValidationConfig::default(),
33 }
34 }
35
36 #[must_use]
38 pub fn with_options(options: &[ValidatorOption]) -> Self {
39 let mut fail_fast = false;
40 let mut disable_lazy = false;
41 let mut allow_unknown_fields = false;
42 let mut additional_descriptor_sets = Vec::new();
43 let mut message_descriptors = Vec::new();
44 let mut now_fn = crate::config::default_now_fn();
45
46 for opt in options {
47 match opt {
48 ValidatorOption::FailFast => fail_fast = true,
49 ValidatorOption::DisableLazy => disable_lazy = true,
50 ValidatorOption::AllowUnknownFields => allow_unknown_fields = true,
51 ValidatorOption::NowFn(f) => now_fn = std::sync::Arc::clone(f),
52 ValidatorOption::AdditionalDescriptorSetBytes(bytes) => {
53 additional_descriptor_sets.push(bytes.clone());
54 }
55 ValidatorOption::MessageDescriptors(descriptors) => {
56 message_descriptors.extend(descriptors.iter().cloned());
57 }
58 }
59 }
60
61 let builder = Builder::with_config(
62 !disable_lazy,
63 allow_unknown_fields,
64 &additional_descriptor_sets,
65 );
66 for descriptor in &message_descriptors {
67 builder.preload(descriptor);
68 }
69
70 Self {
71 builder,
72 config: ValidationConfig {
73 fail_fast,
74 filter: std::sync::Arc::new(NopFilter),
75 now_fn,
76 },
77 }
78 }
79
80 pub fn validate<M: ReflectMessage>(&self, msg: &M) -> Result<(), Error> {
87 self.validate_with(msg, &[])
88 }
89
90 pub fn validate_with<M: ReflectMessage>(
97 &self,
98 msg: &M,
99 options: &[ValidationOption],
100 ) -> Result<(), Error> {
101 let dynamic = msg.transcode_to_dynamic();
102 let descriptor = dynamic.descriptor();
103 let eval = self.builder.load_or_build(&descriptor);
104 let cfg = effective_config(&self.config, options);
105 eval.evaluate_message(&dynamic, &cfg)
106 }
107}
108
109fn effective_config(base: &ValidationConfig, options: &[ValidationOption]) -> ValidationConfig {
110 let mut cfg = ValidationConfig {
111 fail_fast: base.fail_fast,
112 filter: std::sync::Arc::clone(&base.filter),
113 now_fn: std::sync::Arc::clone(&base.now_fn),
114 };
115
116 for option in options {
117 match option {
118 ValidationOption::FailFast => cfg.fail_fast = true,
119 ValidationOption::Filter(filter) => cfg.filter = std::sync::Arc::clone(filter),
120 ValidationOption::NowFn(now_fn) => cfg.now_fn = std::sync::Arc::clone(now_fn),
121 }
122 }
123
124 cfg
125}
126
127impl Default for Validator {
128 fn default() -> Self {
129 Self::new()
130 }
131}
132
133static GLOBAL_VALIDATOR: LazyLock<Validator> = LazyLock::new(Validator::new);
134
135pub fn validate<M: ReflectMessage>(msg: &M) -> Result<(), Error> {
146 GLOBAL_VALIDATOR.validate(msg)
147}
148
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    use pretty_assertions::assert_eq;
    use prost_reflect::{DynamicMessage, MessageDescriptor, ReflectMessage};

    use super::*;
    use crate::config::Filter;

    /// Filter that rejects every message; used to prove call-level filters apply.
    struct DenyFilter;

    impl Filter for DenyFilter {
        fn should_validate(
            &self,
            _message: &DynamicMessage,
            _descriptor: &MessageDescriptor,
        ) -> bool {
            false
        }
    }

    /// Filter that skips any message whose `required` field is `true`, recording
    /// in `seen_required_true` that it observed such a message.
    struct RuntimeFilter {
        seen_required_true: Arc<AtomicBool>,
    }

    impl Filter for RuntimeFilter {
        fn should_validate(
            &self,
            message: &DynamicMessage,
            _descriptor: &MessageDescriptor,
        ) -> bool {
            let Some(required) = message.descriptor().get_field_by_name("required") else {
                return true;
            };

            let required_is_true = message.get_field(&required).as_bool() == Some(true);
            if required_is_true {
                self.seen_required_true.store(true, Ordering::Relaxed);
            }
            !required_is_true
        }
    }

    /// Builds a validator whose only additional descriptor set cannot be decoded.
    fn validator_with_invalid_descriptor_set() -> Validator {
        Validator::with_options(&[ValidatorOption::AdditionalDescriptorSetBytes(vec![
            0x01, 0x02, 0x03,
        ])])
    }

    /// Asserts that `result` is the compilation error reported for the
    /// undecodable descriptor set at index 0.
    fn assert_descriptor_set_decode_error(result: Result<(), Error>) {
        match result {
            Ok(()) => panic!("invalid descriptor set bytes must fail validator initialization"),
            Err(Error::Compilation(err)) => {
                assert!(
                    err.cause
                        .contains("failed to decode additional descriptor set at index 0"),
                    "unexpected compilation error cause: {:?}",
                    err.cause
                );
            }
            Err(other) => panic!("unexpected error type: {other}"),
        }
    }

    #[test]
    fn validation_options_override_call_config_only() {
        let base = ValidationConfig::default();
        let now_fn: Arc<dyn Fn() -> prost_types::Timestamp + Send + Sync> =
            Arc::new(|| prost_types::Timestamp {
                seconds: 123,
                nanos: 456,
            });
        let options = [
            ValidationOption::FailFast,
            ValidationOption::Filter(Arc::new(DenyFilter)),
            ValidationOption::NowFn(Arc::clone(&now_fn)),
        ];

        let effective = effective_config(&base, &options);
        let descriptor = prost_protovalidate_types::DESCRIPTOR_POOL
            .get_message_by_name("buf.validate.FieldRules")
            .expect("message descriptor exists");
        let dynamic = DynamicMessage::new(descriptor.clone());

        assert!(effective.fail_fast);
        let now = (effective.now_fn)();
        assert_eq!((now.seconds, now.nanos), (123, 456));
        assert!(!effective.filter.should_validate(&dynamic, &descriptor));

        // Per-call options must not leak into the base configuration.
        assert!(!base.fail_fast);
    }

    #[test]
    fn validate_with_keeps_existing_validate_behavior() {
        let validator = Validator::new();
        let msg = prost_protovalidate_types::BoolRules::default();

        assert!(validator.validate(&msg).is_ok());
        assert!(
            validator
                .validate_with(&msg, &[ValidationOption::FailFast])
                .is_ok()
        );
    }

    #[test]
    fn invalid_additional_descriptor_set_surfaces_compilation_error() {
        let validator = validator_with_invalid_descriptor_set();
        let msg = prost_protovalidate_types::BoolRules::default();

        assert_descriptor_set_decode_error(validator.validate(&msg));
    }

    #[test]
    fn invalid_additional_descriptor_set_never_panics() {
        let result = std::panic::catch_unwind(|| {
            let validator = validator_with_invalid_descriptor_set();
            let msg = prost_protovalidate_types::BoolRules::default();
            validator.validate(&msg)
        });

        assert_descriptor_set_decode_error(
            result.expect("invalid descriptor sets must not panic"),
        );
    }

    #[test]
    fn valid_additional_descriptor_set_keeps_validator_operational() {
        // An empty buffer is a valid (empty) encoded descriptor set.
        let descriptor_bytes = Vec::new();
        let validator = Validator::with_options(&[ValidatorOption::AdditionalDescriptorSetBytes(
            descriptor_bytes,
        )]);
        let msg = prost_protovalidate_types::BoolRules::default();

        assert!(validator.validate(&msg).is_ok());
    }

    #[test]
    fn message_descriptor_preload_supports_disable_lazy_with_known_messages() {
        let descriptor = prost_protovalidate_types::BoolRules::default().descriptor();
        let validator = Validator::with_options(&[
            ValidatorOption::MessageDescriptors(vec![descriptor]),
            ValidatorOption::DisableLazy,
        ]);
        let msg = prost_protovalidate_types::BoolRules::default();

        assert!(validator.validate(&msg).is_ok());
    }

    #[test]
    fn runtime_filter_can_skip_based_on_message_content() {
        let validator = Validator::new();
        let descriptor = prost_protovalidate_types::FieldRules::default().descriptor();
        let mut msg = DynamicMessage::new(descriptor.clone());
        let seen_required_true = Arc::new(AtomicBool::new(false));
        let required = descriptor
            .get_field_by_name("required")
            .expect("required field exists");
        msg.set_field(&required, prost_reflect::Value::Bool(true));

        assert!(
            validator
                .validate_with(
                    &msg,
                    &[ValidationOption::Filter(Arc::new(RuntimeFilter {
                        seen_required_true: Arc::clone(&seen_required_true),
                    }))],
                )
                .is_ok()
        );
        assert!(seen_required_true.load(Ordering::Relaxed));
    }
}