1use std::{
2 collections::HashSet,
3 fmt::{Debug, Display, Write},
4 hash::Hash,
5};
6
7use proc_macro2::{Span, TokenStream};
8use quote::{format_ident, quote, ToTokens};
9use syn::{Ident, LitByteStr};
10
11use crate::{
12 protovalidate::{
13 AnyRules, BytesRules, DurationRules, EnumRules, Fixed32Rules, Fixed64Rules, Int32Rules,
14 Int64Rules, SFixed32Rules, SFixed64Rules, SInt32Rules, SInt64Rules, StringRules, UInt32Rules,
15 UInt64Rules,
16 },
17 Duration,
18};
19
20#[derive(Debug, Clone)]
21pub enum ItemList {
22 Slice {
23 error_message: String,
24 tokens: TokenStream,
25 },
26 HashSet {
27 error_message: String,
28 tokens: TokenStream,
29 static_ident: Ident,
30 },
31}
32
/// The compiled `in` and `not_in` list rules for a single field.
///
/// In this file each rule is `None` when the corresponding list is empty.
#[derive(Debug, Clone)]
pub struct ContainingRules {
    /// Rule built from the `in` list (value must be one of the items).
    pub in_list_rule: Option<ItemList>,
    /// Rule built from the `not_in` list (value must not be one of the items).
    pub not_in_list_rule: Option<ItemList>,
}
38
/// Renders `items` with their `Display` impl, separated by `", "`.
///
/// An empty slice yields an empty string.
fn format_items_list<T>(items: &[T]) -> String
where
    T: Display,
{
    let mut joined = String::new();
    for (position, entry) in items.iter().enumerate() {
        if position > 0 {
            joined.push_str(", ");
        }
        joined.push_str(&entry.to_string());
    }
    joined
}
49
/// Renders each item as `'item'` (via `Display`), separated by `", "`.
///
/// Quote characters inside an item are not escaped. An empty slice yields an empty string.
fn format_items_list_wrapped_in_quotes<T>(items: &[T]) -> String
where
    T: Display,
{
    let mut rendered = String::new();
    let mut separator = "";
    for entry in items {
        rendered.push_str(separator);
        rendered.push('\'');
        rendered.push_str(&entry.to_string());
        rendered.push('\'');
        separator = ", ";
    }
    rendered
}
60
61pub(crate) fn get_list_kind<T, Hashable>(
62 rule_name: &str,
63 slice: &[T],
64 set: HashSet<Hashable>,
65 error_prefix: &str,
66 wrap_items_in_quotes: bool,
67 type_tokens: TokenStream,
68 field_full_name: &str,
69) -> Option<ItemList>
70where
71 T: Debug + ToTokens + Display,
72 Hashable: Debug + ToTokens,
73{
74 if slice.is_empty() && set.is_empty() {
75 None
76 } else {
77 let stringified_list = if wrap_items_in_quotes {
78 format_items_list_wrapped_in_quotes(slice)
79 } else {
80 format_items_list(slice)
81 };
82
83 let error_message = format!("{}: [ {} ]", error_prefix, stringified_list);
84
85 if slice.len() >= 16 {
86 let static_ident = Ident::new(
87 &format!("__{}_{}_LIST", field_full_name, rule_name.to_uppercase()),
88 Span::call_site(),
89 );
90
91 Some(ItemList::HashSet {
92 error_message,
93 tokens: hashset_to_tokens(set, type_tokens, &static_ident),
94 static_ident,
95 })
96 } else {
97 Some(ItemList::Slice {
98 error_message,
99 tokens: quote! { [ #(#slice),* ] },
100 })
101 }
102 }
103}
104
/// Shorthand for `containing_rules!` when the field's element type is also the type spliced
/// into generated code.
///
/// Arguments: the rules struct, the element type, whether items are single-quoted in error
/// messages, and an optional error-message prefix identifier forwarded unchanged.
macro_rules! standard_containing_rules {
    ($rules_struct:ident, $elem:ty, $quote_items:expr $(, $prefix:ident)?) => {
        containing_rules!($rules_struct, $elem, $elem, $quote_items $(, $prefix)?);
    };
}
110
/// Implements `containing_rules(&self, field_full_name)` on a rules struct that has `r#in`
/// and `not_in` list fields.
///
/// Arguments:
/// - `$struct_target`: the rules struct to implement on.
/// - `$type`: element type of `r#in` / `not_in`; also the item type of the `Err` vector.
/// - `$type_tokens`: the type spliced into generated code (passed to `get_list_kind`).
/// - `$wrap_with_quotes`: whether items are wrapped in single quotes in error messages.
/// - `$error_prefix` (optional): only `type_url` is matched by the inner helper macro, which
///   prefixes messages with `"the type url "`; any other identifier fails to expand.
macro_rules! containing_rules {
    ($struct_target:ident, $type:ty, $type_tokens:ty, $wrap_with_quotes:expr $(, $error_prefix:ident)?) => {
        // Builds the constant part of the error message, optionally prefixed.
        // `$constant_part` is not a metavariable of the outer macro, so it is emitted verbatim
        // into this inner definition. The helper is re-emitted on every expansion, right
        // before the impl that uses it.
        macro_rules! _create_error_message {
            (type_url, $constant_part:literal) => { concat!("the type url ", $constant_part) };
            (, $constant_part:literal) => { $constant_part };
        }
        impl $struct_target {
            // Returns the compiled `in` / `not_in` rules, or the values that appear in both
            // lists (propagated from `get_validated_lists` via `?`).
            pub fn containing_rules(&self, field_full_name: &str) -> Result<ContainingRules, Vec<$type>> {
                let in_list_slice = &self.r#in;
                let not_in_list_slice = &self.not_in;

                let (in_list_hashset, not_in_list_hashset) = get_validated_lists(&self.r#in, &self.not_in)?;

                // `$($error_prefix)?` expands to nothing when no prefix was given, selecting
                // the `(, $constant_part)` arm of the helper.
                let in_list_rule = get_list_kind("in", in_list_slice, in_list_hashset, _create_error_message!($($error_prefix)?, "must be one of these values"), $wrap_with_quotes, quote! { $type_tokens }, field_full_name );
                let not_in_list_rule = get_list_kind("not_in", not_in_list_slice, not_in_list_hashset, _create_error_message!($($error_prefix)?, "cannot be one of these values"), $wrap_with_quotes, quote! { $type_tokens }, field_full_name );

                Ok(ContainingRules {
                    in_list_rule,
                    not_in_list_rule,
                })
            }
        }
    };
}
135
136containing_rules!(DurationRules, Duration, ::protocheck::types::Duration, true);
137containing_rules!(AnyRules, String, &'static str, true, type_url);
138standard_containing_rules!(EnumRules, i32, false);
139standard_containing_rules!(StringRules, String, true);
140standard_containing_rules!(Int64Rules, i64, false);
141standard_containing_rules!(Int32Rules, i32, false);
142standard_containing_rules!(SInt64Rules, i64, false);
143standard_containing_rules!(SInt32Rules, i32, false);
144standard_containing_rules!(SFixed64Rules, i64, false);
145standard_containing_rules!(SFixed32Rules, i32, false);
146standard_containing_rules!(UInt64Rules, u64, false);
147standard_containing_rules!(UInt32Rules, u32, false);
148standard_containing_rules!(Fixed64Rules, u64, false);
149standard_containing_rules!(Fixed32Rules, u32, false);
150
151impl BytesRules {
152 pub fn containing_rules(
153 &self,
154 field_full_name: &str,
155 ) -> Result<ContainingRules, Vec<LitByteStr>> {
156 let in_list_slice = &self.r#in;
157 let not_in_list_slice = &self.not_in;
158
159 let in_list_hashset: HashSet<LitByteStr> = in_list_slice
160 .iter()
161 .map(|b| LitByteStr::new(b, Span::call_site()))
162 .collect();
163
164 let not_in_list_hashset: HashSet<LitByteStr> = not_in_list_slice
165 .iter()
166 .map(|b| LitByteStr::new(b, Span::call_site()))
167 .collect();
168
169 let invalid_items: Vec<LitByteStr> = in_list_hashset
170 .intersection(¬_in_list_hashset)
171 .cloned()
172 .collect();
173
174 if !invalid_items.is_empty() {
175 return Err(invalid_items);
176 }
177
178 let in_list = (!in_list_hashset.is_empty()).then(|| {
179 let stringified_list = self
180 .r#in
181 .iter()
182 .map(|b| format_bytes(b))
183 .collect::<Vec<String>>()
184 .join(", ");
185
186 let error_message = format!("must be one of these values: [ {} ]", stringified_list);
187
188 if in_list_slice.len() >= 16 {
189 let static_ident = format_ident!("__{}_IN_LIST", field_full_name);
190 ItemList::HashSet {
191 error_message,
192 tokens: byte_lit_hashset_to_tokens(in_list_hashset, &static_ident),
193 static_ident,
194 }
195 } else {
196 let lit_byte_vec: Vec<&LitByteStr> = in_list_hashset.iter().collect();
197 ItemList::Slice {
198 error_message,
199 tokens: quote! { [ #(#lit_byte_vec),* ] },
200 }
201 }
202 });
203
204 let not_in_list = (!not_in_list_hashset.is_empty()).then(|| {
205 let stringified_list = self
206 .not_in
207 .iter()
208 .map(|b| format_bytes(b))
209 .collect::<Vec<String>>()
210 .join(", ");
211
212 let error_message = format!("cannot be one of these values: [ {} ]", stringified_list);
213
214 if not_in_list_slice.len() >= 16 {
215 let static_ident = format_ident!("__{}_NOT_IN_LIST", field_full_name);
216
217 ItemList::HashSet {
218 error_message,
219 tokens: byte_lit_hashset_to_tokens(not_in_list_hashset, &static_ident),
220 static_ident,
221 }
222 } else {
223 let lit_byte_vec: Vec<&LitByteStr> = not_in_list_hashset.iter().collect();
224 ItemList::Slice {
225 error_message,
226 tokens: quote! { [ #(#lit_byte_vec),* ] },
227 }
228 }
229 });
230
231 Ok(ContainingRules {
232 in_list_rule: in_list,
233 not_in_list_rule: not_in_list,
234 })
235 }
236}
237
/// Renders `bytes` as a single-quoted, escaped string for error messages.
///
/// `\n`, `\r`, `\t`, `\\` and `"` get backslash escapes. Other bytes in `32..=126` are copied
/// as ASCII characters, and every remaining byte is written as `\xNN` in lowercase hex.
///
/// NOTE(review): the output is wrapped in single quotes, but `'` itself is not escaped while
/// `"` is — confirm this matches the intended message format.
pub(crate) fn format_bytes(bytes: &[u8]) -> String {
    let mut out = String::with_capacity(bytes.len() * 2);
    out.push('\'');

    for byte in bytes.iter().copied() {
        let escape = match byte {
            b'\n' => Some("\\n"),
            b'\r' => Some("\\r"),
            b'\t' => Some("\\t"),
            b'\\' => Some("\\\\"),
            b'"' => Some("\\\""),
            _ => None,
        };

        if let Some(escaped) = escape {
            out.push_str(escaped);
        } else if (32..=126).contains(&byte) {
            out.push(char::from(byte));
        } else {
            // Writing into a String cannot fail.
            write!(out, "\\x{:02x}", byte).unwrap();
        }
    }

    out.push('\'');
    out
}
261
/// Deduplicates the `in` and `not_in` lists into sets and checks that they are disjoint.
///
/// # Errors
///
/// Returns every value present in both lists, deduplicated and in the order it first appears
/// in `in_list`. This keeps the resulting diagnostic stable across builds, since HashSet
/// iteration order is randomized.
pub(crate) fn get_validated_lists<T>(
    in_list: &[T],
    not_in_list: &[T],
) -> Result<(HashSet<T>, HashSet<T>), Vec<T>>
where
    T: Clone + Hash + Eq,
{
    let in_list_hashset: HashSet<T> = in_list.iter().cloned().collect();
    let not_in_list_hashset: HashSet<T> = not_in_list.iter().cloned().collect();

    // Scan the ordered slice (not the set intersection) for a deterministic result order;
    // `seen` drops repeated entries of the same conflicting value.
    let mut seen: HashSet<&T> = HashSet::new();
    let invalid_items: Vec<T> = in_list
        .iter()
        .filter(|item| not_in_list_hashset.contains(*item) && seen.insert(*item))
        .cloned()
        .collect();

    if !invalid_items.is_empty() {
        return Err(invalid_items);
    }

    Ok((in_list_hashset, not_in_list_hashset))
}
283
284fn wrap_hashset_tokens(
285 set_ident: Ident,
286 hashset_tokens: TokenStream,
287 type_tokens: TokenStream,
288 static_ident: &Ident,
289) -> TokenStream {
290 quote! {
291 static #static_ident: ::std::sync::LazyLock<std::collections::HashSet<#type_tokens>> = ::std::sync::LazyLock::new(||{
292 let mut #set_ident: ::std::collections::HashSet<#type_tokens> = ::std::collections::HashSet::new();
293 #hashset_tokens
294 #set_ident
295 });
296 }
297}
298
299pub(crate) fn hashset_to_tokens<T>(
300 hashset: HashSet<T>,
301 type_tokens: TokenStream,
302 static_ident: &Ident,
303) -> TokenStream
304where
305 T: ToTokens,
306{
307 let set_ident = Ident::new("set", Span::call_site());
308
309 let mut hashset_tokens = TokenStream::new();
310
311 for item in hashset {
312 hashset_tokens.extend(quote! {
313 #set_ident.insert(#item);
314 });
315 }
316
317 wrap_hashset_tokens(set_ident, hashset_tokens, type_tokens, static_ident)
318}
319
320pub(crate) fn byte_lit_hashset_to_tokens(
321 hashset: HashSet<LitByteStr>,
322 static_ident: &Ident,
323) -> TokenStream {
324 let set_ident = Ident::new("set", Span::call_site());
325 let mut hashset_tokens = TokenStream::new();
326
327 for item in hashset {
328 hashset_tokens.extend(quote! {
329 #set_ident.insert(::bytes::Bytes::from_static(#item));
330 });
331 }
332
333 wrap_hashset_tokens(
334 set_ident,
335 hashset_tokens,
336 quote! { ::bytes::Bytes },
337 static_ident,
338 )
339}