// proto_types/protovalidate/numeric_rules.rs

1use std::{
2  collections::HashSet,
3  fmt::{Debug, Display},
4  hash::Hash,
5};
6
7use quote::{format_ident, quote, ToTokens};
8
9use super::{
10  comparable_rules::ComparableRules, containing_rules::ContainingRules,
11  into_comparable::IntoComparable,
12};
13use crate::protovalidate::{
14  containing_rules::{hashset_to_tokens, ItemList},
15  ConstRule, DoubleRules, Fixed32Rules, Fixed64Rules, FloatRules, Int32Rules, Int64Rules,
16  SFixed32Rules, SFixed64Rules, SInt32Rules, SInt64Rules, UInt32Rules, UInt64Rules,
17};
18
19pub trait NumericRules<HashableType>
20where
21  HashableType: Debug + Copy + ToTokens + Eq + PartialOrd + Hash,
22{
23  type Unit: ToTokens + PartialEq + PartialOrd + Debug + Display;
24  fn constant(&self) -> Option<ConstRule<Self::Unit>>;
25  fn num_containing_rules(&self, field_full_name: &str)
26    -> Result<ContainingRules, Vec<Self::Unit>>;
27  fn finite(&self) -> bool;
28  fn comparable_rules(&self) -> Result<ComparableRules<Self::Unit>, &'static str>;
29}
30
31impl NumericRules<u32> for FloatRules {
32  type Unit = f32;
33
34  fn constant(&self) -> Option<ConstRule<Self::Unit>> {
35    self.const_rule()
36  }
37  fn finite(&self) -> bool {
38    self.finite()
39  }
40
41  fn num_containing_rules(
42    &self,
43    field_full_name: &str,
44  ) -> Result<ContainingRules, Vec<Self::Unit>> {
45    let in_list_slice = &self.r#in;
46    let not_in_list_slice = &self.not_in;
47
48    let in_list_hashset: HashSet<u32> = in_list_slice.iter().map(|n| n.to_bits()).collect();
49    let not_in_list_hashset: HashSet<u32> = not_in_list_slice.iter().map(|n| n.to_bits()).collect();
50
51    let invalid_items: Vec<f32> = in_list_hashset
52      .intersection(&not_in_list_hashset)
53      .map(|n| f32::from_bits(*n))
54      .collect();
55
56    if !invalid_items.is_empty() {
57      return Err(invalid_items);
58    }
59
60    let in_list = (!in_list_hashset.is_empty()).then(|| {
61      let stringified_list = self
62        .r#in
63        .iter()
64        .map(|b| b.to_string())
65        .collect::<Vec<String>>()
66        .join(", ");
67
68      let error_message = format!("must be one of these values: [ {} ]", stringified_list);
69
70      if in_list_slice.len() >= 16 {
71        let static_ident = format_ident!("__{}_IN_LIST", field_full_name);
72        ItemList::HashSet {
73          error_message,
74          tokens: hashset_to_tokens(in_list_hashset, quote! { u32 }, &static_ident),
75          static_ident,
76        }
77      } else {
78        ItemList::Slice {
79          error_message,
80          tokens: quote! { [ #(#in_list_slice),* ] },
81        }
82      }
83    });
84
85    let not_in_list = (!not_in_list_hashset.is_empty()).then(|| {
86      let stringified_list = self
87        .r#in
88        .iter()
89        .map(|b| b.to_string())
90        .collect::<Vec<String>>()
91        .join(", ");
92
93      let error_message = format!("cannot be one of these values: [ {} ]", stringified_list);
94
95      if not_in_list_slice.len() >= 16 {
96        let static_ident = format_ident!("__{}_NOT_IN_LIST", field_full_name);
97        ItemList::HashSet {
98          error_message,
99          tokens: hashset_to_tokens(not_in_list_hashset, quote! { u32 }, &static_ident),
100          static_ident,
101        }
102      } else {
103        ItemList::Slice {
104          error_message,
105          tokens: quote! { [ #(#not_in_list_slice),* ] },
106        }
107      }
108    });
109
110    Ok(ContainingRules {
111      in_list_rule: in_list,
112      not_in_list_rule: not_in_list,
113    })
114  }
115
116  fn comparable_rules(&self) -> Result<ComparableRules<Self::Unit>, &'static str> {
117    let rules = ComparableRules {
118      greater_than: self.greater_than.map(|gt| gt.into_comparable()),
119      less_than: self.less_than.map(|lt| lt.into_comparable()),
120    };
121    rules.validate()
122  }
123}
124
125impl NumericRules<u64> for DoubleRules {
126  type Unit = f64;
127
128  fn constant(&self) -> Option<ConstRule<Self::Unit>> {
129    self.const_rule()
130  }
131  fn finite(&self) -> bool {
132    self.finite()
133  }
134  fn comparable_rules(&self) -> Result<ComparableRules<Self::Unit>, &'static str> {
135    let rules = ComparableRules {
136      greater_than: self.greater_than.map(|gt| gt.into_comparable()),
137      less_than: self.less_than.map(|lt| lt.into_comparable()),
138    };
139    rules.validate()
140  }
141
142  fn num_containing_rules(
143    &self,
144    field_full_name: &str,
145  ) -> Result<ContainingRules, Vec<Self::Unit>> {
146    let in_list_slice = &self.r#in;
147    let not_in_list_slice = &self.not_in;
148
149    let in_list_hashset: HashSet<u64> = in_list_slice.iter().map(|n| n.to_bits()).collect();
150    let not_in_list_hashset: HashSet<u64> = not_in_list_slice.iter().map(|n| n.to_bits()).collect();
151
152    let invalid_items: Vec<f64> = in_list_hashset
153      .intersection(&not_in_list_hashset)
154      .map(|n| f64::from_bits(*n))
155      .collect();
156
157    if !invalid_items.is_empty() {
158      return Err(invalid_items);
159    }
160
161    let in_list = (!in_list_hashset.is_empty()).then(|| {
162      let stringified_list = self
163        .r#in
164        .iter()
165        .map(|b| b.to_string())
166        .collect::<Vec<String>>()
167        .join(", ");
168
169      let error_message = format!("must be one of these values: [ {} ]", stringified_list);
170
171      if in_list_slice.len() >= 16 {
172        let static_ident = format_ident!("__{}_IN_LIST", field_full_name);
173        ItemList::HashSet {
174          error_message,
175          tokens: hashset_to_tokens(in_list_hashset, quote! { u64 }, &static_ident),
176          static_ident,
177        }
178      } else {
179        ItemList::Slice {
180          error_message,
181          tokens: quote! { [ #(#in_list_slice),* ] },
182        }
183      }
184    });
185
186    let not_in_list = (!not_in_list_hashset.is_empty()).then(|| {
187      let stringified_list = self
188        .r#in
189        .iter()
190        .map(|b| b.to_string())
191        .collect::<Vec<String>>()
192        .join(", ");
193
194      let error_message = format!("cannot be one of these values: [ {} ]", stringified_list);
195
196      if not_in_list_slice.len() >= 16 {
197        let static_ident = format_ident!("__{}_NOT_IN_LIST", field_full_name);
198        ItemList::HashSet {
199          error_message,
200          tokens: hashset_to_tokens(not_in_list_hashset, quote! { u64 }, &static_ident),
201          static_ident,
202        }
203      } else {
204        ItemList::Slice {
205          error_message,
206          tokens: quote! { [ #(#not_in_list_slice),* ] },
207        }
208      }
209    });
210
211    Ok(ContainingRules {
212      in_list_rule: in_list,
213      not_in_list_rule: not_in_list,
214    })
215  }
216}
217
218impl NumericRules<i64> for Int64Rules {
219  type Unit = i64;
220
221  fn constant(&self) -> Option<ConstRule<Self::Unit>> {
222    self.const_rule()
223  }
224  fn finite(&self) -> bool {
225    false
226  }
227  fn comparable_rules(&self) -> Result<ComparableRules<Self::Unit>, &'static str> {
228    let rules = ComparableRules {
229      greater_than: self.greater_than.map(|gt| gt.into_comparable()),
230      less_than: self.less_than.map(|lt| lt.into_comparable()),
231    };
232    rules.validate()
233  }
234  fn num_containing_rules(
235    &self,
236    field_full_name: &str,
237  ) -> Result<ContainingRules, Vec<Self::Unit>> {
238    self.containing_rules(field_full_name)
239  }
240}
241
242impl NumericRules<i64> for SInt64Rules {
243  type Unit = i64;
244
245  fn constant(&self) -> Option<ConstRule<Self::Unit>> {
246    self.const_rule()
247  }
248  fn finite(&self) -> bool {
249    false
250  }
251  fn comparable_rules(&self) -> Result<ComparableRules<Self::Unit>, &'static str> {
252    let rules = ComparableRules {
253      greater_than: self.greater_than.map(|gt| gt.into_comparable()),
254      less_than: self.less_than.map(|lt| lt.into_comparable()),
255    };
256    rules.validate()
257  }
258  fn num_containing_rules(
259    &self,
260    field_full_name: &str,
261  ) -> Result<ContainingRules, Vec<Self::Unit>> {
262    self.containing_rules(field_full_name)
263  }
264}
265
266impl NumericRules<i64> for SFixed64Rules {
267  type Unit = i64;
268
269  fn constant(&self) -> Option<ConstRule<Self::Unit>> {
270    self.const_rule()
271  }
272  fn finite(&self) -> bool {
273    false
274  }
275  fn comparable_rules(&self) -> Result<ComparableRules<Self::Unit>, &'static str> {
276    let rules = ComparableRules {
277      greater_than: self.greater_than.map(|gt| gt.into_comparable()),
278      less_than: self.less_than.map(|lt| lt.into_comparable()),
279    };
280    rules.validate()
281  }
282  fn num_containing_rules(
283    &self,
284    field_full_name: &str,
285  ) -> Result<ContainingRules, Vec<Self::Unit>> {
286    self.containing_rules(field_full_name)
287  }
288}
289
290impl NumericRules<i32> for Int32Rules {
291  type Unit = i32;
292
293  fn constant(&self) -> Option<ConstRule<Self::Unit>> {
294    self.const_rule()
295  }
296  fn finite(&self) -> bool {
297    false
298  }
299  fn comparable_rules(&self) -> Result<ComparableRules<Self::Unit>, &'static str> {
300    let rules = ComparableRules {
301      greater_than: self.greater_than.map(|gt| gt.into_comparable()),
302      less_than: self.less_than.map(|lt| lt.into_comparable()),
303    };
304    rules.validate()
305  }
306  fn num_containing_rules(
307    &self,
308    field_full_name: &str,
309  ) -> Result<ContainingRules, Vec<Self::Unit>> {
310    self.containing_rules(field_full_name)
311  }
312}
313
314impl NumericRules<i32> for SInt32Rules {
315  type Unit = i32;
316
317  fn constant(&self) -> Option<ConstRule<Self::Unit>> {
318    self.const_rule()
319  }
320  fn finite(&self) -> bool {
321    false
322  }
323  fn comparable_rules(&self) -> Result<ComparableRules<Self::Unit>, &'static str> {
324    let rules = ComparableRules {
325      greater_than: self.greater_than.map(|gt| gt.into_comparable()),
326      less_than: self.less_than.map(|lt| lt.into_comparable()),
327    };
328    rules.validate()
329  }
330  fn num_containing_rules(
331    &self,
332    field_full_name: &str,
333  ) -> Result<ContainingRules, Vec<Self::Unit>> {
334    self.containing_rules(field_full_name)
335  }
336}
337
338impl NumericRules<i32> for SFixed32Rules {
339  type Unit = i32;
340
341  fn constant(&self) -> Option<ConstRule<Self::Unit>> {
342    self.const_rule()
343  }
344  fn finite(&self) -> bool {
345    false
346  }
347  fn comparable_rules(&self) -> Result<ComparableRules<Self::Unit>, &'static str> {
348    let rules = ComparableRules {
349      greater_than: self.greater_than.map(|gt| gt.into_comparable()),
350      less_than: self.less_than.map(|lt| lt.into_comparable()),
351    };
352    rules.validate()
353  }
354  fn num_containing_rules(
355    &self,
356    field_full_name: &str,
357  ) -> Result<ContainingRules, Vec<Self::Unit>> {
358    self.containing_rules(field_full_name)
359  }
360}
361
362impl NumericRules<u64> for UInt64Rules {
363  type Unit = u64;
364
365  fn constant(&self) -> Option<ConstRule<Self::Unit>> {
366    self.const_rule()
367  }
368  fn finite(&self) -> bool {
369    false
370  }
371  fn comparable_rules(&self) -> Result<ComparableRules<Self::Unit>, &'static str> {
372    let rules = ComparableRules {
373      greater_than: self.greater_than.map(|gt| gt.into_comparable()),
374      less_than: self.less_than.map(|lt| lt.into_comparable()),
375    };
376    rules.validate()
377  }
378  fn num_containing_rules(
379    &self,
380    field_full_name: &str,
381  ) -> Result<ContainingRules, Vec<Self::Unit>> {
382    self.containing_rules(field_full_name)
383  }
384}
385
386impl NumericRules<u64> for Fixed64Rules {
387  type Unit = u64;
388
389  fn constant(&self) -> Option<ConstRule<Self::Unit>> {
390    self.const_rule()
391  }
392  fn finite(&self) -> bool {
393    false
394  }
395  fn comparable_rules(&self) -> Result<ComparableRules<Self::Unit>, &'static str> {
396    let rules = ComparableRules {
397      greater_than: self.greater_than.map(|gt| gt.into_comparable()),
398      less_than: self.less_than.map(|lt| lt.into_comparable()),
399    };
400    rules.validate()
401  }
402  fn num_containing_rules(
403    &self,
404    field_full_name: &str,
405  ) -> Result<ContainingRules, Vec<Self::Unit>> {
406    self.containing_rules(field_full_name)
407  }
408}
409
410impl NumericRules<u32> for UInt32Rules {
411  type Unit = u32;
412
413  fn constant(&self) -> Option<ConstRule<Self::Unit>> {
414    self.const_rule()
415  }
416  fn finite(&self) -> bool {
417    false
418  }
419  fn comparable_rules(&self) -> Result<ComparableRules<Self::Unit>, &'static str> {
420    let rules = ComparableRules {
421      greater_than: self.greater_than.map(|gt| gt.into_comparable()),
422      less_than: self.less_than.map(|lt| lt.into_comparable()),
423    };
424    rules.validate()
425  }
426  fn num_containing_rules(
427    &self,
428    field_full_name: &str,
429  ) -> Result<ContainingRules, Vec<Self::Unit>> {
430    self.containing_rules(field_full_name)
431  }
432}
433
434impl NumericRules<u32> for Fixed32Rules {
435  type Unit = u32;
436
437  fn constant(&self) -> Option<ConstRule<Self::Unit>> {
438    self.const_rule()
439  }
440  fn finite(&self) -> bool {
441    false
442  }
443  fn comparable_rules(&self) -> Result<ComparableRules<Self::Unit>, &'static str> {
444    let rules = ComparableRules {
445      greater_than: self.greater_than.map(|gt| gt.into_comparable()),
446      less_than: self.less_than.map(|lt| lt.into_comparable()),
447    };
448    rules.validate()
449  }
450  fn num_containing_rules(
451    &self,
452    field_full_name: &str,
453  ) -> Result<ContainingRules, Vec<Self::Unit>> {
454    self.containing_rules(field_full_name)
455  }
456}