prost_protovalidate/
validator.rs1use prost_reflect::ReflectMessage;
2use std::sync::LazyLock;
3
4use crate::config::{NopFilter, ValidationConfig, ValidationOption, ValidatorOption};
5use crate::error::Error;
6
7mod builder;
8mod evaluator;
9mod lookups;
10mod resolve;
11mod rules;
12
13use builder::Builder;
14use evaluator::MessageEvaluator;
15
16pub struct Validator {
21 builder: Builder,
22 config: ValidationConfig,
23}
24
25impl Validator {
26 #[must_use]
28 pub fn new() -> Self {
29 Self {
30 builder: Builder::new(),
31 config: ValidationConfig::default(),
32 }
33 }
34
35 #[must_use]
37 pub fn with_options(options: &[ValidatorOption]) -> Self {
38 let mut fail_fast = false;
39 let mut disable_lazy = false;
40 let mut allow_unknown_fields = false;
41 let mut additional_descriptor_sets = Vec::new();
42 let mut message_descriptors = Vec::new();
43 let mut now_fn: std::sync::Arc<dyn Fn() -> prost_types::Timestamp + Send + Sync> =
44 std::sync::Arc::new(|| {
45 let now = std::time::SystemTime::now()
46 .duration_since(std::time::UNIX_EPOCH)
47 .unwrap_or_default();
48 prost_types::Timestamp {
49 seconds: i64::try_from(now.as_secs()).unwrap_or(i64::MAX),
50 nanos: i32::try_from(now.subsec_nanos()).unwrap_or(0),
51 }
52 });
53
54 for opt in options {
55 match opt {
56 ValidatorOption::FailFast => fail_fast = true,
57 ValidatorOption::DisableLazy => disable_lazy = true,
58 ValidatorOption::AllowUnknownFields => allow_unknown_fields = true,
59 ValidatorOption::NowFn(f) => now_fn = std::sync::Arc::clone(f),
60 ValidatorOption::AdditionalDescriptorSetBytes(bytes) => {
61 additional_descriptor_sets.push(bytes.clone());
62 }
63 ValidatorOption::MessageDescriptors(descriptors) => {
64 message_descriptors.extend(descriptors.iter().cloned());
65 }
66 }
67 }
68
69 let builder = Builder::with_config(
70 !disable_lazy,
71 allow_unknown_fields,
72 &additional_descriptor_sets,
73 );
74 for descriptor in &message_descriptors {
75 builder.preload(descriptor);
76 }
77
78 Self {
79 builder,
80 config: ValidationConfig {
81 fail_fast,
82 filter: std::sync::Arc::new(NopFilter),
83 now_fn,
84 },
85 }
86 }
87
88 pub fn validate<M: ReflectMessage>(&self, msg: &M) -> Result<(), Error> {
95 self.validate_with(msg, &[])
96 }
97
98 pub fn validate_with<M: ReflectMessage>(
105 &self,
106 msg: &M,
107 options: &[ValidationOption],
108 ) -> Result<(), Error> {
109 let dynamic = msg.transcode_to_dynamic();
110 let descriptor = dynamic.descriptor();
111 let eval = self.builder.load_or_build(&descriptor);
112 let cfg = effective_config(&self.config, options);
113 eval.evaluate_message(&dynamic, &cfg)
114 }
115}
116
117fn effective_config(base: &ValidationConfig, options: &[ValidationOption]) -> ValidationConfig {
118 let mut cfg = ValidationConfig {
119 fail_fast: base.fail_fast,
120 filter: std::sync::Arc::clone(&base.filter),
121 now_fn: std::sync::Arc::clone(&base.now_fn),
122 };
123
124 for option in options {
125 match option {
126 ValidationOption::FailFast => cfg.fail_fast = true,
127 ValidationOption::Filter(filter) => cfg.filter = std::sync::Arc::clone(filter),
128 ValidationOption::NowFn(now_fn) => cfg.now_fn = std::sync::Arc::clone(now_fn),
129 }
130 }
131
132 cfg
133}
134
135impl Default for Validator {
136 fn default() -> Self {
137 Self::new()
138 }
139}
140
141static GLOBAL_VALIDATOR: LazyLock<Validator> = LazyLock::new(Validator::new);
142
143pub fn validate<M: ReflectMessage>(msg: &M) -> Result<(), Error> {
154 GLOBAL_VALIDATOR.validate(msg)
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Filter;
    use prost_reflect::{DynamicMessage, MessageDescriptor, ReflectMessage};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Filter that rejects every message; used to show that a per-call filter
    /// replaces the baseline one.
    struct DenyFilter;

    impl Filter for DenyFilter {
        fn should_validate(
            &self,
            _message: &DynamicMessage,
            _descriptor: &MessageDescriptor,
        ) -> bool {
            false
        }
    }

    /// Filter that decides from message content at validation time.
    ///
    /// Messages with `required == true` are skipped, and that observation is
    /// recorded in `seen_required_true` so the test can confirm the filter ran.
    /// Messages without a `required` field are always validated.
    struct RuntimeFilter {
        seen_required_true: Arc<AtomicBool>,
    }

    impl Filter for RuntimeFilter {
        fn should_validate(
            &self,
            message: &DynamicMessage,
            _descriptor: &MessageDescriptor,
        ) -> bool {
            let Some(required) = message.descriptor().get_field_by_name("required") else {
                return true;
            };

            let required_is_true = message.get_field(&required).as_bool() == Some(true);
            if required_is_true {
                self.seen_required_true.store(true, Ordering::Relaxed);
            }
            !required_is_true
        }
    }

    #[test]
    fn validation_options_override_call_config_only() {
        let base = ValidationConfig::default();
        let now_fn: Arc<dyn Fn() -> prost_types::Timestamp + Send + Sync> =
            Arc::new(|| prost_types::Timestamp {
                seconds: 123,
                nanos: 456,
            });
        let options = [
            ValidationOption::FailFast,
            ValidationOption::Filter(Arc::new(DenyFilter)),
            ValidationOption::NowFn(Arc::clone(&now_fn)),
        ];

        let effective = effective_config(&base, &options);
        let descriptor = prost_protovalidate_types::DESCRIPTOR_POOL
            .get_message_by_name("buf.validate.FieldRules")
            .expect("message descriptor exists");
        let dynamic = DynamicMessage::new(descriptor.clone());

        assert!(effective.fail_fast);
        // Pin the whole timestamp, not just the seconds.
        let now = (effective.now_fn)();
        assert_eq!(now.seconds, 123);
        assert_eq!(now.nanos, 456);
        assert!(!effective.filter.should_validate(&dynamic, &descriptor));

        // The base configuration must be left untouched by per-call options.
        assert!(!base.fail_fast);
    }

    #[test]
    fn validate_with_keeps_existing_validate_behavior() {
        let validator = Validator::new();
        let msg = prost_protovalidate_types::BoolRules::default();

        assert!(validator.validate(&msg).is_ok());
        assert!(
            validator
                .validate_with(&msg, &[ValidationOption::FailFast])
                .is_ok()
        );
    }

    #[test]
    fn invalid_additional_descriptor_set_surfaces_compilation_error() {
        // `with_options` is infallible; the decode failure is reported when a
        // message is validated.
        let validator =
            Validator::with_options(&[ValidatorOption::AdditionalDescriptorSetBytes(vec![
                0x01, 0x02, 0x03,
            ])]);
        let msg = prost_protovalidate_types::BoolRules::default();

        match validator.validate(&msg) {
            Ok(()) => panic!("invalid descriptor set bytes must surface as a validation error"),
            Err(Error::Compilation(err)) => {
                assert!(
                    err.cause
                        .contains("failed to decode additional descriptor set at index 0"),
                    "unexpected compilation error cause: {}",
                    err.cause
                );
            }
            Err(other) => panic!("unexpected error type: {other}"),
        }
    }

    #[test]
    fn valid_additional_descriptor_set_keeps_validator_operational() {
        let descriptor_bytes = prost_protovalidate_types::DESCRIPTOR_POOL.encode_to_vec();
        let validator = Validator::with_options(&[ValidatorOption::AdditionalDescriptorSetBytes(
            descriptor_bytes,
        )]);
        let msg = prost_protovalidate_types::BoolRules::default();

        assert!(validator.validate(&msg).is_ok());
    }

    #[test]
    fn message_descriptor_preload_supports_disable_lazy_with_known_messages() {
        let descriptor = prost_protovalidate_types::BoolRules::default().descriptor();
        let validator = Validator::with_options(&[
            ValidatorOption::MessageDescriptors(vec![descriptor]),
            ValidatorOption::DisableLazy,
        ]);
        let msg = prost_protovalidate_types::BoolRules::default();

        assert!(validator.validate(&msg).is_ok());
    }

    #[test]
    fn runtime_filter_can_skip_based_on_message_content() {
        let validator = Validator::new();
        let descriptor = prost_protovalidate_types::FieldRules::default().descriptor();
        let mut msg = DynamicMessage::new(descriptor.clone());
        let seen_required_true = Arc::new(AtomicBool::new(false));
        let required = descriptor
            .get_field_by_name("required")
            .expect("required field exists");
        msg.set_field(&required, prost_reflect::Value::Bool(true));

        assert!(
            validator
                .validate_with(
                    &msg,
                    &[ValidationOption::Filter(Arc::new(RuntimeFilter {
                        seen_required_true: Arc::clone(&seen_required_true),
                    }))],
                )
                .is_ok()
        );
        assert!(seen_required_true.load(Ordering::Relaxed));
    }
}