proto_types/protovalidate/
containing_rules.rs

1use std::{
2  collections::HashSet,
3  fmt::{Debug, Display, Write},
4  hash::Hash,
5};
6
7use proc_macro2::{Span, TokenStream};
8use quote::{format_ident, quote, ToTokens};
9use syn::{Ident, LitByteStr};
10
11use crate::{
12  protovalidate::{
13    AnyRules, BytesRules, DurationRules, EnumRules, Fixed32Rules, Fixed64Rules, Int32Rules,
14    Int64Rules, SFixed32Rules, SFixed64Rules, SInt32Rules, SInt64Rules, StringRules, UInt32Rules,
15    UInt64Rules,
16  },
17  Duration,
18};
19
20#[derive(Debug, Clone)]
21pub enum ItemList {
22  Slice {
23    error_message: String,
24    tokens: TokenStream,
25  },
26  HashSet {
27    error_message: String,
28    tokens: TokenStream,
29    static_ident: Ident,
30  },
31}
32
33#[derive(Debug, Clone)]
34pub struct ContainingRules {
35  pub in_list_rule: Option<ItemList>,
36  pub not_in_list_rule: Option<ItemList>,
37}
38
/// Renders `items` with their `Display` form, separated by `", "`.
fn format_items_list<T>(items: &[T]) -> String
where
  T: Display,
{
  let mut rendered = String::new();
  let mut remaining = items.iter();

  if let Some(first) = remaining.next() {
    write!(rendered, "{}", first).expect("a Display implementation returned an error");
    for item in remaining {
      write!(rendered, ", {}", item).expect("a Display implementation returned an error");
    }
  }

  rendered
}
49
/// Renders every item of `items` wrapped in single quotes, separated by `", "`.
fn format_items_list_wrapped_in_quotes<T>(items: &[T]) -> String
where
  T: Display,
{
  let mut joined = String::new();

  for (position, item) in items.iter().enumerate() {
    if position != 0 {
      joined.push_str(", ");
    }
    write!(joined, "'{}'", item).expect("a Display implementation returned an error");
  }

  joined
}
60
61pub(crate) fn get_list_kind<T, Hashable>(
62  rule_name: &str,
63  slice: &[T],
64  set: HashSet<Hashable>,
65  error_prefix: &str,
66  wrap_items_in_quotes: bool,
67  type_tokens: TokenStream,
68  field_full_name: &str,
69) -> Option<ItemList>
70where
71  T: Debug + ToTokens + Display,
72  Hashable: Debug + ToTokens,
73{
74  if slice.is_empty() && set.is_empty() {
75    None
76  } else {
77    let stringified_list = if wrap_items_in_quotes {
78      format_items_list_wrapped_in_quotes(slice)
79    } else {
80      format_items_list(slice)
81    };
82
83    let error_message = format!("{}: [ {} ]", error_prefix, stringified_list);
84
85    if slice.len() >= 16 {
86      let static_ident = Ident::new(
87        &format!("__{}_{}_LIST", field_full_name, rule_name.to_uppercase()),
88        Span::call_site(),
89      );
90
91      Some(ItemList::HashSet {
92        error_message,
93        tokens: hashset_to_tokens(set, type_tokens, &static_ident),
94        static_ident,
95      })
96    } else {
97      Some(ItemList::Slice {
98        error_message,
99        tokens: quote! { [ #(#slice),* ] },
100      })
101    }
102  }
103}
104
// Shorthand for `containing_rules!` when the list element type is also the type spelled in the
// generated code.
macro_rules! standard_containing_rules {
  ($rules:ident, $item:ty, $quoted:expr $(, $prefix:ident)?) => {
    containing_rules!($rules, $item, $item, $quoted $(, $prefix)?);
  };
}
110
// Selects the error-message prefix for a list rule: with the `type_url` marker the subject is
// named explicitly ("the type url ..."), otherwise the constant text is used unchanged.
macro_rules! _create_error_message {
  (type_url, $constant_part:literal) => {
    concat!("the type url ", $constant_part)
  };
  (, $constant_part:literal) => {
    $constant_part
  };
}

// Implements `containing_rules(&self, field_full_name)` for a rules struct with `in` and
// `not_in` list fields.
// * `$rules` - the rules struct.
// * `$item` - element type of the `in` / `not_in` lists.
// * `$item_tokens` - element type spelled in the generated `HashSet`.
// * `$quoted` - whether values are wrapped in single quotes in the error message.
// * optional `$prefix` - marker forwarded to `_create_error_message!`.
macro_rules! containing_rules {
  ($rules:ident, $item:ty, $item_tokens:ty, $quoted:expr $(, $prefix:ident)?) => {
    impl $rules {
      pub fn containing_rules(&self, field_full_name: &str) -> Result<ContainingRules, Vec<$item>> {
        let (in_set, not_in_set) = get_validated_lists(&self.r#in, &self.not_in)?;

        let in_list_rule = get_list_kind(
          "in",
          &self.r#in,
          in_set,
          _create_error_message!($($prefix)?, "must be one of these values"),
          $quoted,
          quote! { $item_tokens },
          field_full_name,
        );
        let not_in_list_rule = get_list_kind(
          "not_in",
          &self.not_in,
          not_in_set,
          _create_error_message!($($prefix)?, "cannot be one of these values"),
          $quoted,
          quote! { $item_tokens },
          field_full_name,
        );

        Ok(ContainingRules {
          in_list_rule,
          not_in_list_rule,
        })
      }
    }
  };
}
135
136containing_rules!(DurationRules, Duration, ::protocheck::types::Duration, true);
137containing_rules!(AnyRules, String, &'static str, true, type_url);
138standard_containing_rules!(EnumRules, i32, false);
139standard_containing_rules!(StringRules, String, true);
140standard_containing_rules!(Int64Rules, i64, false);
141standard_containing_rules!(Int32Rules, i32, false);
142standard_containing_rules!(SInt64Rules, i64, false);
143standard_containing_rules!(SInt32Rules, i32, false);
144standard_containing_rules!(SFixed64Rules, i64, false);
145standard_containing_rules!(SFixed32Rules, i32, false);
146standard_containing_rules!(UInt64Rules, u64, false);
147standard_containing_rules!(UInt32Rules, u32, false);
148standard_containing_rules!(Fixed64Rules, u64, false);
149standard_containing_rules!(Fixed32Rules, u32, false);
150
151impl BytesRules {
152  pub fn containing_rules(
153    &self,
154    field_full_name: &str,
155  ) -> Result<ContainingRules, Vec<LitByteStr>> {
156    let in_list_slice = &self.r#in;
157    let not_in_list_slice = &self.not_in;
158
159    let in_list_hashset: HashSet<LitByteStr> = in_list_slice
160      .iter()
161      .map(|b| LitByteStr::new(b, Span::call_site()))
162      .collect();
163
164    let not_in_list_hashset: HashSet<LitByteStr> = not_in_list_slice
165      .iter()
166      .map(|b| LitByteStr::new(b, Span::call_site()))
167      .collect();
168
169    let invalid_items: Vec<LitByteStr> = in_list_hashset
170      .intersection(&not_in_list_hashset)
171      .cloned()
172      .collect();
173
174    if !invalid_items.is_empty() {
175      return Err(invalid_items);
176    }
177
178    let in_list = (!in_list_hashset.is_empty()).then(|| {
179      let stringified_list = self
180        .r#in
181        .iter()
182        .map(|b| format_bytes(b))
183        .collect::<Vec<String>>()
184        .join(", ");
185
186      let error_message = format!("must be one of these values: [ {} ]", stringified_list);
187
188      if in_list_slice.len() >= 16 {
189        let static_ident = format_ident!("__{}_IN_LIST", field_full_name);
190        ItemList::HashSet {
191          error_message,
192          tokens: byte_lit_hashset_to_tokens(in_list_hashset, &static_ident),
193          static_ident,
194        }
195      } else {
196        let lit_byte_vec: Vec<&LitByteStr> = in_list_hashset.iter().collect();
197        ItemList::Slice {
198          error_message,
199          tokens: quote! { [ #(#lit_byte_vec),* ] },
200        }
201      }
202    });
203
204    let not_in_list = (!not_in_list_hashset.is_empty()).then(|| {
205      let stringified_list = self
206        .not_in
207        .iter()
208        .map(|b| format_bytes(b))
209        .collect::<Vec<String>>()
210        .join(", ");
211
212      let error_message = format!("cannot be one of these values: [ {} ]", stringified_list);
213
214      if not_in_list_slice.len() >= 16 {
215        let static_ident = format_ident!("__{}_NOT_IN_LIST", field_full_name);
216
217        ItemList::HashSet {
218          error_message,
219          tokens: byte_lit_hashset_to_tokens(not_in_list_hashset, &static_ident),
220          static_ident,
221        }
222      } else {
223        let lit_byte_vec: Vec<&LitByteStr> = not_in_list_hashset.iter().collect();
224        ItemList::Slice {
225          error_message,
226          tokens: quote! { [ #(#lit_byte_vec),* ] },
227        }
228      }
229    });
230
231    Ok(ContainingRules {
232      in_list_rule: in_list,
233      not_in_list_rule: not_in_list,
234    })
235  }
236}
237
/// Renders `bytes` as a single-quoted, escaped string for use in validation error messages.
///
/// Printable ASCII (0x20..=0x7e) is emitted verbatim. `\n`, `\r`, `\t`, `\\`, `"` and the `'`
/// delimiter get backslash escapes. Every other byte is written as `\xNN` in lowercase hex.
pub(crate) fn format_bytes(bytes: &[u8]) -> String {
  let mut s = String::with_capacity(bytes.len() * 2);
  s.push('\'');

  for &byte in bytes {
    match byte {
      b'\n' => s.push_str("\\n"),
      b'\r' => s.push_str("\\r"),
      b'\t' => s.push_str("\\t"),
      b'\\' => s.push_str("\\\\"),
      b'"' => s.push_str("\\\""),
      // The rendered value is delimited by single quotes, so an embedded one must be escaped
      // to keep the output unambiguous.
      b'\'' => s.push_str("\\'"),
      32..=126 => s.push(byte as char),
      _ => {
        write!(s, "\\x{:02x}", byte).expect("writing to a String cannot fail");
      }
    }
  }

  s.push('\'');
  s
}
261
/// Deduplicates the `in` and `not_in` lists and checks that no value appears in both.
///
/// On success returns `(in_set, not_in_set)`.
///
/// # Errors
///
/// Returns every value present in both lists, each once, in the order it first appears in
/// `in_list`.
pub(crate) fn get_validated_lists<T>(
  in_list: &[T],
  not_in_list: &[T],
) -> Result<(HashSet<T>, HashSet<T>), Vec<T>>
where
  T: Clone + Hash + Eq,
{
  let in_list_hashset: HashSet<T> = in_list.iter().cloned().collect();
  let not_in_list_hashset: HashSet<T> = not_in_list.iter().cloned().collect();

  // Scan `in_list` in declaration order instead of iterating the set intersection. `HashSet`
  // iteration order is randomized per process, which made the reported conflicts
  // nondeterministic.
  let mut reported: HashSet<&T> = HashSet::new();
  let invalid_items: Vec<T> = in_list
    .iter()
    .filter(|item| not_in_list_hashset.contains(*item) && reported.insert(*item))
    .cloned()
    .collect();

  if invalid_items.is_empty() {
    Ok((in_list_hashset, not_in_list_hashset))
  } else {
    Err(invalid_items)
  }
}
283
284fn wrap_hashset_tokens(
285  set_ident: Ident,
286  hashset_tokens: TokenStream,
287  type_tokens: TokenStream,
288  static_ident: &Ident,
289) -> TokenStream {
290  quote! {
291    static #static_ident: ::std::sync::LazyLock<std::collections::HashSet<#type_tokens>> = ::std::sync::LazyLock::new(||{
292      let mut #set_ident: ::std::collections::HashSet<#type_tokens> = ::std::collections::HashSet::new();
293      #hashset_tokens
294      #set_ident
295    });
296  }
297}
298
299pub(crate) fn hashset_to_tokens<T>(
300  hashset: HashSet<T>,
301  type_tokens: TokenStream,
302  static_ident: &Ident,
303) -> TokenStream
304where
305  T: ToTokens,
306{
307  let set_ident = Ident::new("set", Span::call_site());
308
309  let mut hashset_tokens = TokenStream::new();
310
311  for item in hashset {
312    hashset_tokens.extend(quote! {
313      #set_ident.insert(#item);
314    });
315  }
316
317  wrap_hashset_tokens(set_ident, hashset_tokens, type_tokens, static_ident)
318}
319
320pub(crate) fn byte_lit_hashset_to_tokens(
321  hashset: HashSet<LitByteStr>,
322  static_ident: &Ident,
323) -> TokenStream {
324  let set_ident = Ident::new("set", Span::call_site());
325  let mut hashset_tokens = TokenStream::new();
326
327  for item in hashset {
328    hashset_tokens.extend(quote! {
329      #set_ident.insert(::bytes::Bytes::from_static(#item));
330    });
331  }
332
333  wrap_hashset_tokens(
334    set_ident,
335    hashset_tokens,
336    quote! { ::bytes::Bytes },
337    static_ident,
338  )
339}