Skip to main content

prost_protovalidate/
validator.rs

1use std::sync::LazyLock;
2
3use prost_reflect::ReflectMessage;
4
5use crate::config::{NopFilter, ValidationConfig, ValidationOption, ValidatorOption};
6use crate::error::Error;
7
8mod builder;
9pub(crate) mod editions;
10mod evaluator;
11mod lookups;
12mod rules;
13
14use builder::Builder;
15use evaluator::MessageEvaluator;
16
/// Thread-safe validator for Protocol Buffer messages.
///
/// Validates messages against `buf.validate` rules extracted from proto descriptors.
/// Evaluators are compiled lazily and cached for reuse (when built with
/// `ValidatorOption::DisableLazy`, the builder is created in non-lazy mode
/// instead — NOTE(review): confirm exactly what non-lazy mode permits).
///
/// A single instance can be shared: the module keeps one in a `static`
/// `LazyLock`, which requires the type to be `Sync`.
pub struct Validator {
    // Compiles (or looks up) the evaluator for each message descriptor; see
    // `Builder::load_or_build` / `Builder::preload`.
    builder: Builder,
    // Base configuration for every validation call; per-call
    // `ValidationOption`s are applied to a copy of it, never to this value.
    config: ValidationConfig,
}
25
26impl Validator {
27    /// Create a new `Validator` with default options.
28    #[must_use]
29    pub fn new() -> Self {
30        Self {
31            builder: Builder::new(),
32            config: ValidationConfig::default(),
33        }
34    }
35
36    /// Create a new `Validator` with the given options.
37    #[must_use]
38    pub fn with_options(options: &[ValidatorOption]) -> Self {
39        let mut fail_fast = false;
40        let mut disable_lazy = false;
41        let mut allow_unknown_fields = false;
42        let mut additional_descriptor_sets = Vec::new();
43        let mut message_descriptors = Vec::new();
44        let mut now_fn = crate::config::default_now_fn();
45
46        for opt in options {
47            match opt {
48                ValidatorOption::FailFast => fail_fast = true,
49                ValidatorOption::DisableLazy => disable_lazy = true,
50                ValidatorOption::AllowUnknownFields => allow_unknown_fields = true,
51                ValidatorOption::NowFn(f) => now_fn = std::sync::Arc::clone(f),
52                ValidatorOption::AdditionalDescriptorSetBytes(bytes) => {
53                    additional_descriptor_sets.push(bytes.clone());
54                }
55                ValidatorOption::MessageDescriptors(descriptors) => {
56                    message_descriptors.extend(descriptors.iter().cloned());
57                }
58            }
59        }
60
61        let builder = Builder::with_config(
62            !disable_lazy,
63            allow_unknown_fields,
64            &additional_descriptor_sets,
65        );
66        for descriptor in &message_descriptors {
67            builder.preload(descriptor);
68        }
69
70        Self {
71            builder,
72            config: ValidationConfig {
73                fail_fast,
74                filter: std::sync::Arc::new(NopFilter),
75                now_fn,
76            },
77        }
78    }
79
80    /// Validate a message against its `buf.validate` rules.
81    ///
82    /// # Errors
83    ///
84    /// Returns an `Error` containing all constraint violations found, or a
85    /// compilation/runtime error if rule evaluation fails.
86    pub fn validate<M: ReflectMessage>(&self, msg: &M) -> Result<(), Error> {
87        self.validate_with(msg, &[])
88    }
89
90    /// Validate a message with per-call validation options.
91    ///
92    /// # Errors
93    ///
94    /// Returns an `Error` containing all constraint violations found, or a
95    /// compilation/runtime error if rule evaluation fails.
96    pub fn validate_with<M: ReflectMessage>(
97        &self,
98        msg: &M,
99        options: &[ValidationOption],
100    ) -> Result<(), Error> {
101        let dynamic = msg.transcode_to_dynamic();
102        let descriptor = dynamic.descriptor();
103        let eval = self.builder.load_or_build(&descriptor);
104        let cfg = effective_config(&self.config, options);
105        eval.evaluate_message(&dynamic, &cfg)
106    }
107}
108
109fn effective_config(base: &ValidationConfig, options: &[ValidationOption]) -> ValidationConfig {
110    let mut cfg = ValidationConfig {
111        fail_fast: base.fail_fast,
112        filter: std::sync::Arc::clone(&base.filter),
113        now_fn: std::sync::Arc::clone(&base.now_fn),
114    };
115
116    for option in options {
117        match option {
118            ValidationOption::FailFast => cfg.fail_fast = true,
119            ValidationOption::Filter(filter) => cfg.filter = std::sync::Arc::clone(filter),
120            ValidationOption::NowFn(now_fn) => cfg.now_fn = std::sync::Arc::clone(now_fn),
121        }
122    }
123
124    cfg
125}
126
127impl Default for Validator {
128    fn default() -> Self {
129        Self::new()
130    }
131}
132
// Process-wide validator used by the free `validate` function. It is
// constructed with `Validator::new` (default options) the first time it is
// accessed, then reused by all later calls.
static GLOBAL_VALIDATOR: LazyLock<Validator> = LazyLock::new(Validator::new);
134
135/// Validate a message using a global `Validator` instance.
136///
137/// This is a convenience function that uses a shared, lazily-initialized
138/// validator. For lower memory usage, prefer using a single `Validator`
139/// instance rather than creating multiple instances.
140///
141/// # Errors
142///
143/// Returns an `Error` containing all constraint violations found, or a
144/// compilation/runtime error if rule evaluation fails.
145pub fn validate<M: ReflectMessage>(msg: &M) -> Result<(), Error> {
146    GLOBAL_VALIDATOR.validate(msg)
147}
148
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    use pretty_assertions::assert_eq;
    use prost_reflect::{DynamicMessage, MessageDescriptor, ReflectMessage, Value};

    use super::*;
    use crate::config::Filter;

    /// Bytes these tests expect to be rejected when decoded as an additional
    /// descriptor set.
    const INVALID_DESCRIPTOR_SET: [u8; 3] = [0x01, 0x02, 0x03];

    /// Substring expected in the compilation error reported when the first
    /// additional descriptor set cannot be decoded.
    const DECODE_FAILURE: &str = "failed to decode additional descriptor set at index 0";

    /// Filter that rejects every message; used to prove a per-call filter
    /// replaces the base one.
    struct DenyFilter;

    impl Filter for DenyFilter {
        fn should_validate(
            &self,
            _message: &DynamicMessage,
            _descriptor: &MessageDescriptor,
        ) -> bool {
            false
        }
    }

    /// Filter that skips any message whose `required` field is `true`, and
    /// records that it has seen such a message.
    struct RuntimeFilter {
        seen_required_true: Arc<AtomicBool>,
    }

    impl Filter for RuntimeFilter {
        fn should_validate(
            &self,
            message: &DynamicMessage,
            _descriptor: &MessageDescriptor,
        ) -> bool {
            let Some(required) = message.descriptor().get_field_by_name("required") else {
                return true;
            };

            let required_is_true = message.get_field(&required).as_bool() == Some(true);
            if required_is_true {
                self.seen_required_true.store(true, Ordering::Relaxed);
            }
            !required_is_true
        }
    }

    /// Asserts that `result` is a compilation error reporting that the
    /// additional descriptor set at index 0 could not be decoded.
    fn assert_descriptor_set_decode_error(result: Result<(), Error>) {
        match result {
            // Construction succeeds; the decode failure surfaces on validation.
            Ok(()) => panic!("validation must fail when a descriptor set cannot be decoded"),
            Err(Error::Compilation(err)) => assert!(err.cause.contains(DECODE_FAILURE)),
            Err(other) => panic!("unexpected error type: {other}"),
        }
    }

    #[test]
    fn validation_options_override_call_config_only() {
        let base = ValidationConfig::default();
        let now_fn: Arc<dyn Fn() -> prost_types::Timestamp + Send + Sync> =
            Arc::new(|| prost_types::Timestamp {
                seconds: 123,
                nanos: 456,
            });
        let options = vec![
            ValidationOption::FailFast,
            ValidationOption::Filter(Arc::new(DenyFilter)),
            ValidationOption::NowFn(Arc::clone(&now_fn)),
        ];

        let effective = effective_config(&base, &options);
        let descriptor = prost_protovalidate_types::DESCRIPTOR_POOL
            .get_message_by_name("buf.validate.FieldRules")
            .expect("message descriptor exists");
        let dynamic = DynamicMessage::new(descriptor.clone());

        assert!(effective.fail_fast);
        let now = (effective.now_fn)();
        assert_eq!(now.seconds, 123);
        assert_eq!(now.nanos, 456);
        assert!(!effective.filter.should_validate(&dynamic, &descriptor));

        // The base configuration must be left untouched.
        assert!(!base.fail_fast);
    }

    #[test]
    fn validate_with_keeps_existing_validate_behavior() {
        let validator = Validator::new();
        let msg = prost_protovalidate_types::BoolRules::default();

        assert!(validator.validate(&msg).is_ok());
        assert!(
            validator
                .validate_with(&msg, &[ValidationOption::FailFast])
                .is_ok()
        );
    }

    #[test]
    fn invalid_additional_descriptor_set_surfaces_compilation_error() {
        let validator = Validator::with_options(&[ValidatorOption::AdditionalDescriptorSetBytes(
            INVALID_DESCRIPTOR_SET.to_vec(),
        )]);
        let msg = prost_protovalidate_types::BoolRules::default();

        assert_descriptor_set_decode_error(validator.validate(&msg));
    }

    #[test]
    fn invalid_additional_descriptor_set_never_panics() {
        let result = std::panic::catch_unwind(|| {
            let validator = Validator::with_options(&[
                ValidatorOption::AdditionalDescriptorSetBytes(INVALID_DESCRIPTOR_SET.to_vec()),
            ]);
            let msg = prost_protovalidate_types::BoolRules::default();
            validator.validate(&msg)
        });

        assert_descriptor_set_decode_error(result.expect("invalid descriptor sets must not panic"));
    }

    #[test]
    fn valid_additional_descriptor_set_keeps_validator_operational() {
        // An empty byte string decodes as an empty descriptor set.
        let descriptor_bytes = Vec::new();
        let validator = Validator::with_options(&[ValidatorOption::AdditionalDescriptorSetBytes(
            descriptor_bytes,
        )]);
        let msg = prost_protovalidate_types::BoolRules::default();

        assert!(validator.validate(&msg).is_ok());
    }

    #[test]
    fn message_descriptor_preload_supports_disable_lazy_with_known_messages() {
        let descriptor = prost_protovalidate_types::BoolRules::default().descriptor();
        let validator = Validator::with_options(&[
            ValidatorOption::MessageDescriptors(vec![descriptor]),
            ValidatorOption::DisableLazy,
        ]);
        let msg = prost_protovalidate_types::BoolRules::default();

        // BoolRules has no constraints on itself, so validation should pass
        assert!(validator.validate(&msg).is_ok());
    }

    #[test]
    fn runtime_filter_can_skip_based_on_message_content() {
        let validator = Validator::new();
        let descriptor = prost_protovalidate_types::FieldRules::default().descriptor();
        let mut msg = DynamicMessage::new(descriptor.clone());
        let seen_required_true = Arc::new(AtomicBool::new(false));
        let required = descriptor
            .get_field_by_name("required")
            .expect("required field exists");
        msg.set_field(&required, Value::Bool(true));

        assert!(
            validator
                .validate_with(
                    &msg,
                    &[ValidationOption::Filter(Arc::new(RuntimeFilter {
                        seen_required_true: Arc::clone(&seen_required_true),
                    }))],
                )
                .is_ok()
        );
        // The filter must actually have been consulted with the message content.
        assert!(seen_required_true.load(Ordering::Relaxed));
    }
}