prost_protovalidate/
error.rs1use std::fmt;
2
3use crate::violation::Violation;
4
5#[derive(Debug, thiserror::Error)]
7#[non_exhaustive]
8pub enum Error {
9 #[error(transparent)]
11 Validation(#[from] ValidationError),
12
13 #[error(transparent)]
15 Compilation(#[from] CompilationError),
16
17 #[error(transparent)]
19 Runtime(#[from] RuntimeError),
20}
21
22#[derive(Debug)]
24pub struct ValidationError {
25 pub violations: Vec<Violation>,
27}
28
29impl fmt::Display for ValidationError {
30 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31 match self.violations.len() {
32 0 => Ok(()),
33 1 => write!(f, "validation error: {}", self.violations[0]),
34 _ => {
35 write!(f, "validation errors:")?;
36 for v in &self.violations {
37 write!(f, "\n - {v}")?;
38 }
39 Ok(())
40 }
41 }
42 }
43}
44
45impl std::error::Error for ValidationError {}
46
47impl ValidationError {
48 pub(crate) fn new(violations: Vec<Violation>) -> Self {
49 Self { violations }
50 }
51
52 pub(crate) fn single(violation: Violation) -> Self {
53 Self {
54 violations: vec![violation],
55 }
56 }
57
58 #[must_use]
60 pub fn to_proto(&self) -> prost_protovalidate_types::Violations {
61 prost_protovalidate_types::Violations {
62 violations: self.violations.iter().map(|v| v.proto.clone()).collect(),
63 }
64 }
65}
66
67#[derive(Debug, thiserror::Error)]
69#[error("compilation error: {cause}")]
70pub struct CompilationError {
71 pub cause: String,
73}
74
75#[derive(Debug, thiserror::Error)]
77#[error("runtime error: {cause}")]
78pub struct RuntimeError {
79 pub cause: String,
81}
82
83pub(crate) fn merge_violations(
88 acc: Option<Error>,
89 new_err: Result<(), Error>,
90 fail_fast: bool,
91) -> (bool, Option<Error>) {
92 let new_err = match new_err {
93 Ok(()) => return (true, acc),
94 Err(e) => e,
95 };
96
97 match new_err {
98 Error::Compilation(_) | Error::Runtime(_) => (false, Some(new_err)),
99 Error::Validation(new_val) => {
100 if fail_fast {
101 return (false, Some(Error::Validation(new_val)));
102 }
103 match acc {
104 Some(Error::Validation(mut existing)) => {
105 existing.violations.extend(new_val.violations);
106 (true, Some(Error::Validation(existing)))
107 }
108 _ => (true, Some(Error::Validation(new_val))),
109 }
110 }
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::{CompilationError, Error, RuntimeError, ValidationError, merge_violations};
    use crate::violation::Violation;

    /// Builds an `Error::Validation` holding one violation with `rule_id`.
    fn validation_error(rule_id: &str) -> Error {
        Error::Validation(ValidationError::single(Violation::new("", rule_id, "")))
    }

    #[test]
    fn validation_error_display_matches_single_and_multiple_formats() {
        let single = ValidationError::new(vec![Violation::new("one.two", "bar", "foo")]);
        assert_eq!(single.to_string(), "validation error: one.two: foo");

        let multiple = ValidationError::new(vec![
            Violation::new("one.two", "bar", "foo"),
            Violation::new("one.three", "bar", ""),
        ]);
        assert_eq!(
            multiple.to_string(),
            "validation errors:\n - one.two: foo\n - one.three: [bar]"
        );
    }

    #[test]
    fn compilation_and_runtime_errors_display_with_prefix() {
        let compilation = CompilationError {
            cause: "bad rule".to_string(),
        };
        assert_eq!(compilation.to_string(), "compilation error: bad rule");
        // `Error` is transparent, so wrapping must not change the message.
        assert_eq!(
            Error::from(compilation).to_string(),
            "compilation error: bad rule"
        );

        let runtime = RuntimeError {
            cause: "boom".to_string(),
        };
        assert_eq!(runtime.to_string(), "runtime error: boom");
        assert_eq!(Error::from(runtime).to_string(), "runtime error: boom");
    }

    #[test]
    fn merge_violations_handles_non_validation_and_validation_paths() {
        let (cont, acc) = merge_violations(None, Ok(()), true);
        assert!(cont);
        assert!(acc.is_none());

        let runtime = Error::Runtime(RuntimeError {
            cause: "runtime failure".to_string(),
        });
        let (cont, acc) = merge_violations(None, Err(runtime), false);
        assert!(!cont);
        assert!(matches!(acc, Some(Error::Runtime(_))));

        let (cont, acc) = merge_violations(None, Err(validation_error("foo")), true);
        assert!(!cont);
        let Some(Error::Validation(err)) = acc else {
            panic!("expected validation error");
        };
        assert_eq!(err.violations.len(), 1);
        assert_eq!(err.violations[0].rule_id, "foo");

        let base = Some(validation_error("foo"));
        let (cont, acc) = merge_violations(base, Err(validation_error("bar")), false);
        assert!(cont);
        let Some(Error::Validation(err)) = acc else {
            panic!("expected merged validation error");
        };
        assert_eq!(err.violations.len(), 2);
        assert_eq!(err.violations[0].rule_id, "foo");
        assert_eq!(err.violations[1].rule_id, "bar");
    }

    #[test]
    fn merge_violations_compilation_error_halts_and_replaces_accumulated() {
        let base = Some(validation_error("foo"));
        let compilation = Error::Compilation(CompilationError {
            cause: "bad rule".to_string(),
        });
        let (cont, acc) = merge_violations(base, Err(compilation), false);
        assert!(!cont);
        assert!(matches!(acc, Some(Error::Compilation(_))));
    }

    #[test]
    fn merge_violations_ok_keeps_accumulated_violations() {
        let base = Some(validation_error("foo"));
        let (cont, acc) = merge_violations(base, Ok(()), false);
        assert!(cont);
        let Some(Error::Validation(err)) = acc else {
            panic!("expected accumulated validation error to be preserved");
        };
        assert_eq!(err.violations.len(), 1);
        assert_eq!(err.violations[0].rule_id, "foo");
    }
}