Skip to main content

prost_protovalidate/
validator.rs

1use prost_reflect::ReflectMessage;
2use std::sync::LazyLock;
3
4use crate::config::{NopFilter, ValidationConfig, ValidationOption, ValidatorOption};
5use crate::error::Error;
6
7mod builder;
8mod evaluator;
9mod lookups;
10mod resolve;
11mod rules;
12
13use builder::Builder;
14use evaluator::MessageEvaluator;
15
/// Thread-safe validator for Protocol Buffer messages.
///
/// Validates messages against `buf.validate` rules extracted from proto descriptors.
/// Evaluators are compiled lazily and cached for reuse (unless configured
/// otherwise through `Validator::with_options`).
pub struct Validator {
    // Compiles and caches per-message evaluators; its configuration (laziness,
    // unknown-field handling, extra descriptor sets) is fixed at construction.
    builder: Builder,
    // Validator-wide defaults (fail-fast, filter, clock). Per-call
    // `ValidationOption`s are layered onto a copy of this, never onto it directly.
    config: ValidationConfig,
}
24
25impl Validator {
26    /// Create a new `Validator` with default options.
27    #[must_use]
28    pub fn new() -> Self {
29        Self {
30            builder: Builder::new(),
31            config: ValidationConfig::default(),
32        }
33    }
34
35    /// Create a new `Validator` with the given options.
36    #[must_use]
37    pub fn with_options(options: &[ValidatorOption]) -> Self {
38        let mut fail_fast = false;
39        let mut disable_lazy = false;
40        let mut allow_unknown_fields = false;
41        let mut additional_descriptor_sets = Vec::new();
42        let mut message_descriptors = Vec::new();
43        let mut now_fn: std::sync::Arc<dyn Fn() -> prost_types::Timestamp + Send + Sync> =
44            std::sync::Arc::new(|| {
45                let now = std::time::SystemTime::now()
46                    .duration_since(std::time::UNIX_EPOCH)
47                    .unwrap_or_default();
48                prost_types::Timestamp {
49                    seconds: i64::try_from(now.as_secs()).unwrap_or(i64::MAX),
50                    nanos: i32::try_from(now.subsec_nanos()).unwrap_or(0),
51                }
52            });
53
54        for opt in options {
55            match opt {
56                ValidatorOption::FailFast => fail_fast = true,
57                ValidatorOption::DisableLazy => disable_lazy = true,
58                ValidatorOption::AllowUnknownFields => allow_unknown_fields = true,
59                ValidatorOption::NowFn(f) => now_fn = std::sync::Arc::clone(f),
60                ValidatorOption::AdditionalDescriptorSetBytes(bytes) => {
61                    additional_descriptor_sets.push(bytes.clone());
62                }
63                ValidatorOption::MessageDescriptors(descriptors) => {
64                    message_descriptors.extend(descriptors.iter().cloned());
65                }
66            }
67        }
68
69        let builder = Builder::with_config(
70            !disable_lazy,
71            allow_unknown_fields,
72            &additional_descriptor_sets,
73        );
74        for descriptor in &message_descriptors {
75            builder.preload(descriptor);
76        }
77
78        Self {
79            builder,
80            config: ValidationConfig {
81                fail_fast,
82                filter: std::sync::Arc::new(NopFilter),
83                now_fn,
84            },
85        }
86    }
87
88    /// Validate a message against its `buf.validate` rules.
89    ///
90    /// # Errors
91    ///
92    /// Returns an `Error` containing all constraint violations found, or a
93    /// compilation/runtime error if rule evaluation fails.
94    pub fn validate<M: ReflectMessage>(&self, msg: &M) -> Result<(), Error> {
95        self.validate_with(msg, &[])
96    }
97
98    /// Validate a message with per-call validation options.
99    ///
100    /// # Errors
101    ///
102    /// Returns an `Error` containing all constraint violations found, or a
103    /// compilation/runtime error if rule evaluation fails.
104    pub fn validate_with<M: ReflectMessage>(
105        &self,
106        msg: &M,
107        options: &[ValidationOption],
108    ) -> Result<(), Error> {
109        let dynamic = msg.transcode_to_dynamic();
110        let descriptor = dynamic.descriptor();
111        let eval = self.builder.load_or_build(&descriptor);
112        let cfg = effective_config(&self.config, options);
113        eval.evaluate_message(&dynamic, &cfg)
114    }
115}
116
117fn effective_config(base: &ValidationConfig, options: &[ValidationOption]) -> ValidationConfig {
118    let mut cfg = ValidationConfig {
119        fail_fast: base.fail_fast,
120        filter: std::sync::Arc::clone(&base.filter),
121        now_fn: std::sync::Arc::clone(&base.now_fn),
122    };
123
124    for option in options {
125        match option {
126            ValidationOption::FailFast => cfg.fail_fast = true,
127            ValidationOption::Filter(filter) => cfg.filter = std::sync::Arc::clone(filter),
128            ValidationOption::NowFn(now_fn) => cfg.now_fn = std::sync::Arc::clone(now_fn),
129        }
130    }
131
132    cfg
133}
134
135impl Default for Validator {
136    fn default() -> Self {
137        Self::new()
138    }
139}
140
141static GLOBAL_VALIDATOR: LazyLock<Validator> = LazyLock::new(Validator::new);
142
143/// Validate a message using a global `Validator` instance.
144///
145/// This is a convenience function that uses a shared, lazily-initialized
146/// validator. For lower memory usage, prefer using a single `Validator`
147/// instance rather than creating multiple instances.
148///
149/// # Errors
150///
151/// Returns an `Error` containing all constraint violations found, or a
152/// compilation/runtime error if rule evaluation fails.
153pub fn validate<M: ReflectMessage>(msg: &M) -> Result<(), Error> {
154    GLOBAL_VALIDATOR.validate(msg)
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Filter;
    use prost_reflect::{DynamicMessage, MessageDescriptor, ReflectMessage};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Clock type accepted by the `NowFn` options.
    type ClockFn = Arc<dyn Fn() -> prost_types::Timestamp + Send + Sync>;

    /// Builds a clock that always reports the given instant.
    fn fixed_clock(seconds: i64, nanos: i32) -> ClockFn {
        Arc::new(move || prost_types::Timestamp { seconds, nanos })
    }

    /// Filter that rejects every message.
    struct DenyFilter;

    impl Filter for DenyFilter {
        fn should_validate(
            &self,
            _message: &DynamicMessage,
            _descriptor: &MessageDescriptor,
        ) -> bool {
            false
        }
    }

    /// Filter that skips messages whose `required` field is set to `true`,
    /// recording that it saw one.
    struct RuntimeFilter {
        seen_required_true: Arc<AtomicBool>,
    }

    impl Filter for RuntimeFilter {
        fn should_validate(
            &self,
            message: &DynamicMessage,
            _descriptor: &MessageDescriptor,
        ) -> bool {
            let Some(required) = message.descriptor().get_field_by_name("required") else {
                return true;
            };

            let required_is_true = message.get_field(&required).as_bool() == Some(true);
            if required_is_true {
                self.seen_required_true.store(true, Ordering::Relaxed);
            }
            !required_is_true
        }
    }

    #[test]
    fn validation_options_override_call_config_only() {
        let base = ValidationConfig::default();
        let options = [
            ValidationOption::FailFast,
            ValidationOption::Filter(Arc::new(DenyFilter)),
            ValidationOption::NowFn(fixed_clock(123, 456)),
        ];

        let effective = effective_config(&base, &options);
        let descriptor = prost_protovalidate_types::DESCRIPTOR_POOL
            .get_message_by_name("buf.validate.FieldRules")
            .expect("message descriptor exists");
        let dynamic = DynamicMessage::new(descriptor.clone());

        assert!(effective.fail_fast);
        let now = (effective.now_fn)();
        assert_eq!((now.seconds, now.nanos), (123, 456));
        assert!(!effective.filter.should_validate(&dynamic, &descriptor));

        // Per-call options must leave the validator-wide config untouched.
        assert!(!base.fail_fast);
        assert!(!Arc::ptr_eq(&base.now_fn, &effective.now_fn));
    }

    #[test]
    fn later_validation_options_take_precedence() {
        let base = ValidationConfig::default();
        let options = [
            ValidationOption::NowFn(fixed_clock(1, 0)),
            ValidationOption::NowFn(fixed_clock(2, 0)),
        ];

        let effective = effective_config(&base, &options);

        assert_eq!((effective.now_fn)().seconds, 2);
        assert!(!effective.fail_fast);
    }

    #[test]
    fn validator_options_configure_validator_defaults() {
        let validator = Validator::with_options(&[
            ValidatorOption::FailFast,
            ValidatorOption::NowFn(fixed_clock(7, 8)),
        ]);

        assert!(validator.config.fail_fast);
        let now = (validator.config.now_fn)();
        assert_eq!((now.seconds, now.nanos), (7, 8));
        assert!(!Validator::with_options(&[]).config.fail_fast);
    }

    #[test]
    fn validate_with_keeps_existing_validate_behavior() {
        let validator = Validator::new();
        let msg = prost_protovalidate_types::BoolRules::default();

        assert!(validator.validate(&msg).is_ok());
        assert!(
            validator
                .validate_with(&msg, &[ValidationOption::FailFast])
                .is_ok()
        );
    }

    #[test]
    fn with_options_without_options_behaves_like_new() {
        let msg = prost_protovalidate_types::BoolRules::default();

        assert!(Validator::with_options(&[]).validate(&msg).is_ok());
        assert!(Validator::new().validate(&msg).is_ok());
    }

    #[test]
    fn invalid_additional_descriptor_set_surfaces_compilation_error() {
        let validator =
            Validator::with_options(&[ValidatorOption::AdditionalDescriptorSetBytes(vec![
                0x01, 0x02, 0x03,
            ])]);
        let msg = prost_protovalidate_types::BoolRules::default();

        match validator.validate(&msg) {
            Ok(()) => panic!("invalid descriptor set bytes must surface a compilation error"),
            Err(Error::Compilation(err)) => {
                assert!(
                    err.cause
                        .contains("failed to decode additional descriptor set at index 0")
                );
            }
            Err(other) => panic!("unexpected error type: {other}"),
        }
    }

    #[test]
    fn valid_additional_descriptor_set_keeps_validator_operational() {
        let descriptor_bytes = prost_protovalidate_types::DESCRIPTOR_POOL.encode_to_vec();
        let validator = Validator::with_options(&[ValidatorOption::AdditionalDescriptorSetBytes(
            descriptor_bytes,
        )]);
        let msg = prost_protovalidate_types::BoolRules::default();

        assert!(validator.validate(&msg).is_ok());
    }

    #[test]
    fn message_descriptor_preload_supports_disable_lazy_with_known_messages() {
        let descriptor = prost_protovalidate_types::BoolRules::default().descriptor();
        let validator = Validator::with_options(&[
            ValidatorOption::MessageDescriptors(vec![descriptor]),
            ValidatorOption::DisableLazy,
        ]);
        let msg = prost_protovalidate_types::BoolRules::default();

        // BoolRules has no constraints on itself, so validation should pass
        assert!(validator.validate(&msg).is_ok());
    }

    #[test]
    fn runtime_filter_can_skip_based_on_message_content() {
        let validator = Validator::new();
        let descriptor = prost_protovalidate_types::FieldRules::default().descriptor();
        let mut msg = DynamicMessage::new(descriptor.clone());
        let seen_required_true = Arc::new(AtomicBool::new(false));
        let required = descriptor
            .get_field_by_name("required")
            .expect("required field exists");
        msg.set_field(&required, prost_reflect::Value::Bool(true));

        assert!(
            validator
                .validate_with(
                    &msg,
                    &[ValidationOption::Filter(Arc::new(RuntimeFilter {
                        seen_required_true: Arc::clone(&seen_required_true),
                    }))],
                )
                .is_ok()
        );
        assert!(seen_required_true.load(Ordering::Relaxed));
    }
}