1use std::{
16 fmt::{self, Debug, Display},
17 marker::PhantomData,
18 ops::Deref,
19};
20
21use serde::{Deserialize, Serialize};
22
/// A `Vec<T>` that is guaranteed by construction to contain at least one element.
///
/// Values can only be created through [`NonEmptyVec::try_new`] (which rejects an
/// empty vector) or [`NonEmptyVec::singleton`]. None of the inherent methods can
/// remove elements, which is why [`NonEmptyVec::first`] and [`NonEmptyVec::last`]
/// return `&T` rather than `Option<&T>`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NonEmptyVec<T>(Vec<T>);

/// Error returned when an empty vector is offered to [`NonEmptyVec::try_new`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EmptyVecError;

impl Display for EmptyVecError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("vector is empty; NonEmptyVec requires at least one element")
    }
}

impl std::error::Error for EmptyVecError {}

impl<T> NonEmptyVec<T> {
    /// Wraps `vec` after checking that it holds at least one element.
    ///
    /// # Errors
    /// Returns [`EmptyVecError`] when `vec` is empty.
    pub fn try_new(vec: Vec<T>) -> Result<Self, EmptyVecError> {
        if vec.is_empty() {
            return Err(EmptyVecError);
        }
        Ok(Self(vec))
    }

    /// Builds a one-element vector; this can never fail.
    pub fn singleton(first: T) -> Self {
        Self(Vec::from([first]))
    }

    /// The first element. The non-empty invariant guarantees it exists.
    pub fn first(&self) -> &T {
        self.0
            .first()
            .expect("NonEmptyVec invariant: at least one element")
    }

    /// The last element. The non-empty invariant guarantees it exists.
    pub fn last(&self) -> &T {
        self.0
            .last()
            .expect("NonEmptyVec invariant: at least one element")
    }

    /// Consumes the wrapper and returns the underlying vector.
    pub fn into_vec(self) -> Vec<T> {
        let Self(inner) = self;
        inner
    }

    /// Borrows the underlying vector.
    pub fn as_vec(&self) -> &Vec<T> {
        &self.0
    }

    /// Appends `value`. Growing can never break the non-empty invariant.
    pub fn push(&mut self, value: T) {
        self.0.push(value);
    }

    /// Number of elements; always at least 1.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Always `false`: a `NonEmptyVec` cannot be empty.
    ///
    /// Presumably kept as the conventional companion of `len()`. `self` is
    /// intentionally unused, hence the clippy allowance.
    #[allow(clippy::unused_self)]
    pub const fn is_empty(&self) -> bool {
        false
    }
}
113
114impl<T> Deref for NonEmptyVec<T> {
115 type Target = [T];
116
117 fn deref(&self) -> &[T] {
118 &self.0
119 }
120}
121
122impl<T> AsRef<[T]> for NonEmptyVec<T> {
123 fn as_ref(&self) -> &[T] {
124 &self.0
125 }
126}
127
128impl<T> TryFrom<Vec<T>> for NonEmptyVec<T> {
129 type Error = EmptyVecError;
130
131 fn try_from(vec: Vec<T>) -> Result<Self, Self::Error> {
132 Self::try_new(vec)
133 }
134}
135
136impl<T> From<NonEmptyVec<T>> for Vec<T> {
137 fn from(nev: NonEmptyVec<T>) -> Self {
138 nev.0
139 }
140}
141
142impl<T> IntoIterator for NonEmptyVec<T> {
143 type Item = T;
144 type IntoIter = std::vec::IntoIter<T>;
145
146 fn into_iter(self) -> Self::IntoIter {
147 self.0.into_iter()
148 }
149}
150
151impl<'a, T> IntoIterator for &'a NonEmptyVec<T> {
152 type Item = &'a T;
153 type IntoIter = std::slice::Iter<'a, T>;
154
155 fn into_iter(self) -> Self::IntoIter {
156 self.0.iter()
157 }
158}
159
160impl<T: Serialize> Serialize for NonEmptyVec<T> {
161 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
162 self.0.serialize(serializer)
163 }
164}
165
166impl<'de, T: Deserialize<'de>> Deserialize<'de> for NonEmptyVec<T> {
167 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
168 let vec = Vec::<T>::deserialize(deserializer)?;
169 Self::try_new(vec).map_err(serde::de::Error::custom)
170 }
171}
172
/// A validated SQL column identifier, or a single-`*` wildcard pattern.
///
/// Accepted spellings (enforced by [`SqlIdentifier::try_new`]):
/// - a plain identifier of ASCII letters, digits and `_` that does not start with a
///   digit, e.g. `user_email`;
/// - a prefix pattern ending in `*`, e.g. `email_*`;
/// - a suffix pattern starting with `*`, e.g. `*_token`;
/// - the bare wildcard `*`, which matches every column.
///
/// The caller's spelling is kept in `original`. `normalised` is its ASCII-lowercased
/// form, and it is what [`SqlIdentifier::matches`] compares against.
#[derive(Debug, Clone)]
pub struct SqlIdentifier {
    /// Input exactly as supplied to `try_new`.
    original: String,
    /// `original` lowercased. Only ASCII is lowered; validated input contains nothing else.
    normalised: String,
}

/// Reasons [`SqlIdentifier::try_new`] can reject its input.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SqlIdentifierError {
    /// The input was the empty string.
    Empty,
    /// The input contained a character other than an ASCII letter, digit, `_` or a
    /// correctly placed `*`. Carries the offending character (the real `char`, even
    /// when it is non-ASCII).
    InvalidCharacter(char),
    /// A `*` appeared anywhere but the first or last position. Also returned when a
    /// pattern longer than one character has a `*` at both ends (`**`, `*abc*`).
    InvalidWildcardPosition,
    /// The first character was an ASCII digit.
    StartsWithDigit,
}

impl Display for SqlIdentifierError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Empty => f.write_str("SQL identifier is empty"),
            Self::InvalidCharacter(c) => {
                write!(f, "SQL identifier contains invalid character {c:?}")
            }
            Self::InvalidWildcardPosition => f.write_str(
                "wildcard '*' may only appear once, at the start or end of a pattern",
            ),
            Self::StartsWithDigit => f.write_str("SQL identifier must not start with a digit"),
        }
    }
}

impl std::error::Error for SqlIdentifierError {}

impl SqlIdentifier {
    /// Validates `raw` and builds an identifier or wildcard pattern from it.
    ///
    /// Characters are checked left to right and the first problem found is reported.
    ///
    /// # Errors
    /// - [`SqlIdentifierError::Empty`] if `raw` is empty.
    /// - [`SqlIdentifierError::InvalidWildcardPosition`] if a `*` is not at the first or
    ///   last position, or if a pattern longer than one character has `*` at both ends.
    /// - [`SqlIdentifierError::InvalidCharacter`] for the first character that is not
    ///   an ASCII letter, digit, `_` or permitted `*`.
    /// - [`SqlIdentifierError::StartsWithDigit`] if the first character is an ASCII digit.
    pub fn try_new(raw: impl Into<String>) -> Result<Self, SqlIdentifierError> {
        let original = raw.into();
        if original.is_empty() {
            return Err(SqlIdentifierError::Empty);
        }
        let len = original.len();
        // Walk `char`s rather than bytes. Otherwise a non-ASCII character would be
        // reported as the Latin-1 reading of its first UTF-8 byte (e.g. 'Ã' for 'é').
        for (i, c) in original.char_indices() {
            let at_start = i == 0;
            let at_end = i + c.len_utf8() == len;
            if c == '*' {
                if at_start || at_end {
                    continue;
                }
                return Err(SqlIdentifierError::InvalidWildcardPosition);
            }
            if !(c.is_ascii_alphanumeric() || c == '_') {
                return Err(SqlIdentifierError::InvalidCharacter(c));
            }
            if at_start && c.is_ascii_digit() {
                return Err(SqlIdentifierError::StartsWithDigit);
            }
        }
        // `**` and `*x*` pass the per-character check above. However, `matches` only
        // understands a single leading *or* trailing wildcard, so such a pattern would
        // silently degrade to a literal comparison against the string `*x*`. Reject it
        // instead of accepting a pattern that never matches a real column.
        if len > 1 && original.starts_with('*') && original.ends_with('*') {
            return Err(SqlIdentifierError::InvalidWildcardPosition);
        }
        let normalised = original.to_ascii_lowercase();
        Ok(Self {
            original,
            normalised,
        })
    }

    /// The identifier exactly as it was supplied.
    pub fn original(&self) -> &str {
        &self.original
    }

    /// The ASCII-lowercased identifier used for matching.
    pub fn normalised(&self) -> &str {
        &self.normalised
    }

    /// `true` only for the bare `*` pattern.
    pub fn is_wildcard(&self) -> bool {
        self.normalised == "*"
    }

    /// For a prefix pattern such as `Email_*`, returns the lowercased prefix (`email_`).
    /// Returns `None` for anything else, including the bare `*`.
    pub fn as_prefix_pattern(&self) -> Option<&str> {
        self.normalised
            .strip_suffix('*')
            .filter(|s| !s.is_empty() && !s.contains('*'))
    }

    /// For a suffix pattern such as `*_Token`, returns the lowercased suffix (`_token`).
    /// Returns `None` for anything else, including the bare `*`.
    pub fn as_suffix_pattern(&self) -> Option<&str> {
        self.normalised
            .strip_prefix('*')
            .filter(|s| !s.is_empty() && !s.contains('*'))
    }

    /// Reports whether `column_name` is selected by this identifier or pattern,
    /// ignoring ASCII case.
    ///
    /// - The bare `*` matches everything.
    /// - A prefix/suffix pattern matches columns that start/end with its fixed part.
    /// - A plain identifier matches only itself.
    pub fn matches(&self, column_name: &str) -> bool {
        if self.is_wildcard() {
            return true;
        }
        // `normalised` is pure lowercase ASCII. A byte-wise ASCII-case-insensitive
        // comparison is therefore equivalent to lowercasing `column_name` first,
        // without allocating. Comparing bytes also avoids char-boundary slicing panics.
        let column = column_name.as_bytes();
        if let Some(prefix) = self.as_prefix_pattern() {
            let prefix = prefix.as_bytes();
            return column.len() >= prefix.len()
                && column[..prefix.len()].eq_ignore_ascii_case(prefix);
        }
        if let Some(suffix) = self.as_suffix_pattern() {
            let suffix = suffix.as_bytes();
            return column.len() >= suffix.len()
                && column[column.len() - suffix.len()..].eq_ignore_ascii_case(suffix);
        }
        column_name.eq_ignore_ascii_case(&self.normalised)
    }
}
319
320impl Display for SqlIdentifier {
321 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
322 f.write_str(&self.original)
323 }
324}
325
326impl PartialEq for SqlIdentifier {
327 fn eq(&self, other: &Self) -> bool {
328 self.normalised == other.normalised
329 }
330}
331
332impl Eq for SqlIdentifier {}
333
334impl std::hash::Hash for SqlIdentifier {
335 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
336 self.normalised.hash(state);
337 }
338}
339
340impl PartialOrd for SqlIdentifier {
341 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
342 Some(self.cmp(other))
343 }
344}
345
346impl Ord for SqlIdentifier {
347 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
348 self.normalised.cmp(&other.normalised)
349 }
350}
351
352impl Serialize for SqlIdentifier {
353 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
354 self.original.serialize(serializer)
355 }
356}
357
358impl<'de> Deserialize<'de> for SqlIdentifier {
359 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
360 let raw = String::deserialize(deserializer)?;
361 Self::try_new(raw).map_err(serde::de::Error::custom)
362 }
363}
364
/// A `usize` that is guaranteed to be `<= MAX` (the bound is inclusive).
///
/// The error text mentions decompression bombs. This suggests the type guards sizes
/// decoded from untrusted input: validate once at the boundary, then carry the
/// proof in the type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct BoundedSize<const MAX: usize>(usize);

/// Error returned when a value exceeds a [`BoundedSize`] bound or does not fit in `usize`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BoundedSizeError {
    /// The rejected value, kept as `u64` so that inputs too large for `usize` can
    /// still be reported faithfully.
    pub value: u64,
    /// The bound that was exceeded.
    pub max: usize,
}

impl Display for BoundedSizeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "size {} exceeds bound {} (decompression bomb or corruption)",
            self.value, self.max
        )
    }
}

impl std::error::Error for BoundedSizeError {}

impl<const MAX: usize> BoundedSize<MAX> {
    /// Accepts `value` if `value <= MAX`.
    ///
    /// # Errors
    /// Returns [`BoundedSizeError`] carrying `value` and `MAX` when `value > MAX`.
    pub const fn try_new(value: usize) -> Result<Self, BoundedSizeError> {
        if value > MAX {
            Err(BoundedSizeError {
                // Lossless on every target whose pointer width is at most 64 bits.
                value: value as u64,
                max: MAX,
            })
        } else {
            Ok(Self(value))
        }
    }

    /// The validated value.
    pub const fn get(self) -> usize {
        self.0
    }

    /// The inclusive upper bound `MAX`.
    pub const fn max() -> usize {
        MAX
    }
}

impl<const MAX: usize> TryFrom<u32> for BoundedSize<MAX> {
    type Error = BoundedSizeError;

    /// Widens losslessly to `u64` and reuses the checked `u64` conversion.
    ///
    /// A plain `value as usize` would silently wrap on targets with a 16-bit `usize`.
    /// That could let an oversized length slip under the bound.
    fn try_from(value: u32) -> Result<Self, Self::Error> {
        <Self as TryFrom<u64>>::try_from(u64::from(value))
    }
}

impl<const MAX: usize> TryFrom<u64> for BoundedSize<MAX> {
    type Error = BoundedSizeError;

    /// Rejects values that do not fit in `usize` before applying the bound. The
    /// error keeps the full `u64` either way.
    fn try_from(value: u64) -> Result<Self, Self::Error> {
        match usize::try_from(value) {
            Ok(narrowed) => Self::try_new(narrowed),
            Err(_) => Err(BoundedSizeError { value, max: MAX }),
        }
    }
}

impl<const MAX: usize> TryFrom<usize> for BoundedSize<MAX> {
    type Error = BoundedSizeError;

    /// Same as [`BoundedSize::try_new`].
    fn try_from(value: usize) -> Result<Self, Self::Error> {
        Self::try_new(value)
    }
}

impl<const MAX: usize> From<BoundedSize<MAX>> for usize {
    fn from(bs: BoundedSize<MAX>) -> Self {
        bs.0
    }
}
461
462impl<const MAX: usize> Serialize for BoundedSize<MAX> {
464 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
465 self.0.serialize(serializer)
466 }
467}
468
469impl<'de, const MAX: usize> Deserialize<'de> for BoundedSize<MAX> {
470 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
471 let raw = usize::deserialize(deserializer)?;
472 Self::try_new(raw).map_err(serde::de::Error::custom)
473 }
474}
475
// Compile-time-only item that references `PhantomData`. Nothing else visible in this
// file uses that import, so without this item it would trigger `unused_imports`.
// NOTE(review): consider removing the import and this item together.
const _: PhantomData<()> = PhantomData;
484
/// Security clearance, from least (`Public`) to most (`TopSecret`) privileged.
///
/// The discriminants are fixed at 0..=3 (`repr(u8)`). They are the values used by
/// [`ClearanceLevel::as_u8`] and by the `u8` conversions. The derived ordering follows
/// declaration order, which here coincides with those discriminants.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
pub enum ClearanceLevel {
    /// Least privileged level; also the `Default`.
    #[default]
    Public = 0,
    Confidential = 1,
    Secret = 2,
    /// Most privileged level.
    TopSecret = 3,
}

/// Error returned when a `u8` outside `0..=3` is converted into a [`ClearanceLevel`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ClearanceLevelError {
    /// The rejected byte.
    pub value: u8,
}

impl Display for ClearanceLevelError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "clearance level {} is out of range (valid: 0..=3)",
            self.value
        )
    }
}

impl std::error::Error for ClearanceLevelError {}

impl ClearanceLevel {
    /// The level's numeric value (its `repr(u8)` discriminant).
    pub const fn as_u8(self) -> u8 {
        self as u8
    }

    /// `true` when `self` is at least as privileged as `other`. Every level
    /// dominates itself.
    pub const fn dominates(self, other: Self) -> bool {
        self.as_u8() >= other.as_u8()
    }
}

impl TryFrom<u8> for ClearanceLevel {
    type Error = ClearanceLevelError;

    /// Maps 0..=3 onto the levels in order. Any other byte is an error.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        // Index == discriminant, so a checked lookup replaces a per-value match.
        const LEVELS: [ClearanceLevel; 4] = [
            ClearanceLevel::Public,
            ClearanceLevel::Confidential,
            ClearanceLevel::Secret,
            ClearanceLevel::TopSecret,
        ];
        LEVELS
            .get(usize::from(value))
            .copied()
            .ok_or(ClearanceLevelError { value })
    }
}

impl From<ClearanceLevel> for u8 {
    fn from(level: ClearanceLevel) -> Self {
        level.as_u8()
    }
}

/// Lower-snake-case name of the level, e.g. `top_secret`.
impl Display for ClearanceLevel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Public => "public",
            Self::Confidential => "confidential",
            Self::Secret => "secret",
            Self::TopSecret => "top_secret",
        };
        f.write_str(name)
    }
}
572
573impl Serialize for ClearanceLevel {
574 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
575 (*self as u8).serialize(serializer)
576 }
577}
578
579impl<'de> Deserialize<'de> for ClearanceLevel {
580 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
581 let byte = u8::deserialize(deserializer)?;
582 Self::try_from(byte).map_err(serde::de::Error::custom)
583 }
584}
585
#[cfg(test)]
mod tests {
    //! Unit tests for the domain newtypes: construction invariants, conversions,
    //! matching semantics and serde round-trips.

    use super::*;

    // ---------------------------------------------------------------- NonEmptyVec

    #[test]
    fn non_empty_vec_rejects_empty() {
        assert_eq!(NonEmptyVec::<u8>::try_new(vec![]), Err(EmptyVecError));
    }

    #[test]
    fn non_empty_vec_accepts_single_element() {
        let v = NonEmptyVec::singleton(42u8);
        assert_eq!(v.len(), 1);
        assert_eq!(*v.first(), 42);
        assert_eq!(*v.last(), 42);
        assert!(!v.is_empty());
    }

    #[test]
    fn non_empty_vec_push_preserves_invariant() {
        let mut v = NonEmptyVec::singleton(1u8);
        v.push(2);
        v.push(3);
        assert_eq!(v.len(), 3);
        assert_eq!(&*v, &[1, 2, 3]);
    }

    #[test]
    fn non_empty_vec_std_conversions() {
        let v = NonEmptyVec::try_from(vec![1u8, 2, 3]).expect("non-empty");
        let slice: &[u8] = v.as_ref();
        assert_eq!(slice, &[1, 2, 3]);
        let borrowed: Vec<u8> = IntoIterator::into_iter(&v).copied().collect();
        assert_eq!(borrowed, vec![1, 2, 3]);
        let owned: Vec<u8> = v.clone().into_iter().collect();
        assert_eq!(owned, vec![1, 2, 3]);
        assert_eq!(Vec::from(v), vec![1, 2, 3]);
        assert_eq!(NonEmptyVec::<u8>::try_from(Vec::new()), Err(EmptyVecError));
    }

    #[test]
    fn non_empty_vec_serde_roundtrip() {
        let v = NonEmptyVec::try_new(vec![1, 2, 3]).expect("non-empty");
        let json = serde_json::to_string(&v).expect("serialize");
        assert_eq!(json, "[1,2,3]");
        let back: NonEmptyVec<i32> = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(back, v);
    }

    #[test]
    fn non_empty_vec_serde_rejects_empty() {
        let err = serde_json::from_str::<NonEmptyVec<i32>>("[]");
        assert!(err.is_err(), "deserializing empty should fail");
    }

    // ------------------------------------------------------------- SqlIdentifier

    #[test]
    fn sql_identifier_normalises_case() {
        let a = SqlIdentifier::try_new("Email").expect("valid");
        let b = SqlIdentifier::try_new("EMAIL").expect("valid");
        let c = SqlIdentifier::try_new("email").expect("valid");
        assert_eq!(a, b);
        assert_eq!(b, c);
        assert_eq!(a.original(), "Email");
        assert_eq!(a.normalised(), "email");
    }

    #[test]
    fn sql_identifier_rejects_empty() {
        assert_eq!(SqlIdentifier::try_new(""), Err(SqlIdentifierError::Empty));
    }

    #[test]
    fn sql_identifier_rejects_leading_digit() {
        assert_eq!(
            SqlIdentifier::try_new("1col"),
            Err(SqlIdentifierError::StartsWithDigit)
        );
    }

    #[test]
    fn sql_identifier_rejects_invalid_char() {
        match SqlIdentifier::try_new("col-name") {
            Err(SqlIdentifierError::InvalidCharacter(c)) => assert_eq!(c, '-'),
            other => panic!("expected InvalidCharacter, got {other:?}"),
        }
    }

    #[test]
    fn sql_identifier_accepts_wildcard_patterns() {
        SqlIdentifier::try_new("*").expect("bare wildcard");
        SqlIdentifier::try_new("email_*").expect("prefix pattern");
        SqlIdentifier::try_new("*_token").expect("suffix pattern");
    }

    #[test]
    fn sql_identifier_rejects_middle_wildcard() {
        assert_eq!(
            SqlIdentifier::try_new("em*ail"),
            Err(SqlIdentifierError::InvalidWildcardPosition)
        );
    }

    #[test]
    fn sql_identifier_matches_case_insensitively() {
        let pat = SqlIdentifier::try_new("Email").expect("valid");
        assert!(pat.matches("email"));
        assert!(pat.matches("EMAIL"));
        assert!(pat.matches("Email"));
        assert!(!pat.matches("name"));
    }

    #[test]
    fn sql_identifier_prefix_suffix_wildcard_match() {
        let prefix = SqlIdentifier::try_new("user_*").expect("valid");
        assert!(prefix.matches("user_id"));
        assert!(prefix.matches("USER_NAME"));
        assert!(!prefix.matches("id"));

        let suffix = SqlIdentifier::try_new("*_id").expect("valid");
        assert!(suffix.matches("user_id"));
        assert!(suffix.matches("ORDER_ID"));
        assert!(!suffix.matches("user"));

        let wildcard = SqlIdentifier::try_new("*").expect("valid");
        assert!(wildcard.matches("anything"));
    }

    #[test]
    fn sql_identifier_pattern_accessors() {
        let prefix = SqlIdentifier::try_new("User_*").expect("valid");
        assert_eq!(prefix.as_prefix_pattern(), Some("user_"));
        assert_eq!(prefix.as_suffix_pattern(), None);

        let suffix = SqlIdentifier::try_new("*_Token").expect("valid");
        assert_eq!(suffix.as_suffix_pattern(), Some("_token"));
        assert_eq!(suffix.as_prefix_pattern(), None);

        let bare = SqlIdentifier::try_new("*").expect("valid");
        assert!(bare.is_wildcard());
        assert_eq!(bare.as_prefix_pattern(), None);
        assert_eq!(bare.as_suffix_pattern(), None);
    }

    #[test]
    fn sql_identifier_eq_hash_ord_use_normalised_form() {
        use std::collections::HashSet;

        let set: HashSet<SqlIdentifier> = ["Email", "EMAIL", "email"]
            .iter()
            .map(|s| SqlIdentifier::try_new(*s).expect("valid"))
            .collect();
        assert_eq!(set.len(), 1, "case variants must hash and compare equal");

        // Byte-wise "Beta" < "alpha", but ordering uses the lowercased form.
        let alpha = SqlIdentifier::try_new("alpha").expect("valid");
        let beta = SqlIdentifier::try_new("Beta").expect("valid");
        assert!(alpha < beta);
    }

    #[test]
    fn sql_identifier_display_preserves_original_spelling() {
        let id = SqlIdentifier::try_new("User_Email").expect("valid");
        assert_eq!(id.to_string(), "User_Email");
    }

    #[test]
    fn sql_identifier_serde_roundtrip() {
        let id = SqlIdentifier::try_new("User_Email").expect("valid");
        let json = serde_json::to_string(&id).expect("serialize");
        assert_eq!(json, "\"User_Email\"");
        let back: SqlIdentifier = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(back, id);
        assert_eq!(back.original(), "User_Email");
    }

    // --------------------------------------------------------------- BoundedSize

    #[test]
    fn bounded_size_accepts_within_bound() {
        let bs: BoundedSize<1024> = BoundedSize::try_new(512).expect("within bound");
        assert_eq!(bs.get(), 512);
        assert_eq!(BoundedSize::<1024>::max(), 1024);
    }

    #[test]
    fn bounded_size_accepts_exact_max() {
        let bs: BoundedSize<1024> = BoundedSize::try_new(1024).expect("exact max permitted");
        assert_eq!(bs.get(), 1024);
    }

    #[test]
    fn bounded_size_rejects_over_bound() {
        let err = BoundedSize::<1024>::try_new(1025).unwrap_err();
        assert_eq!(err.value, 1025);
        assert_eq!(err.max, 1024);
    }

    #[test]
    fn bounded_size_tryfrom_u32() {
        let bs: BoundedSize<1024> = 512u32.try_into().expect("within bound");
        assert_eq!(bs.get(), 512);
        let err: Result<BoundedSize<1024>, _> = 2048u32.try_into();
        assert!(err.is_err());
    }

    #[test]
    fn bounded_size_tryfrom_u64_overflow_on_32bit_safe() {
        let bs: BoundedSize<{ usize::MAX }> = 42u64.try_into().expect("within bound");
        assert_eq!(bs.get(), 42);

        // The largest possible u64 is out of bound everywhere. On 32-bit targets it
        // also exceeds usize. Either way the full value must survive in the error.
        let err = BoundedSize::<1024>::try_from(u64::MAX).unwrap_err();
        assert_eq!(err.value, u64::MAX);
        assert_eq!(err.max, 1024);
    }

    #[test]
    fn bounded_size_serde_enforces_on_deserialize() {
        let bs: BoundedSize<100> = 50usize.try_into().expect("valid");
        let json = serde_json::to_string(&bs).expect("serialize");
        assert_eq!(json, "50");
        let ok: BoundedSize<100> = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(ok.get(), 50);
        let err = serde_json::from_str::<BoundedSize<100>>("200");
        assert!(err.is_err(), "deserialising over-bound should fail");
    }

    // ------------------------------------------------------------ ClearanceLevel

    #[test]
    fn clearance_level_tryfrom_valid() {
        assert_eq!(ClearanceLevel::try_from(0), Ok(ClearanceLevel::Public));
        assert_eq!(
            ClearanceLevel::try_from(1),
            Ok(ClearanceLevel::Confidential)
        );
        assert_eq!(ClearanceLevel::try_from(2), Ok(ClearanceLevel::Secret));
        assert_eq!(ClearanceLevel::try_from(3), Ok(ClearanceLevel::TopSecret));
    }

    #[test]
    fn clearance_level_tryfrom_invalid() {
        let err = ClearanceLevel::try_from(4).unwrap_err();
        assert_eq!(err.value, 4);
        let err = ClearanceLevel::try_from(255).unwrap_err();
        assert_eq!(err.value, 255);
    }

    #[test]
    fn clearance_level_dominates() {
        assert!(ClearanceLevel::TopSecret.dominates(ClearanceLevel::Public));
        assert!(ClearanceLevel::Secret.dominates(ClearanceLevel::Confidential));
        assert!(ClearanceLevel::Public.dominates(ClearanceLevel::Public));
        assert!(!ClearanceLevel::Public.dominates(ClearanceLevel::Secret));
    }

    #[test]
    fn clearance_level_default_is_public() {
        assert_eq!(ClearanceLevel::default(), ClearanceLevel::Public);
    }

    #[test]
    fn clearance_level_display_names() {
        assert_eq!(ClearanceLevel::Public.to_string(), "public");
        assert_eq!(ClearanceLevel::Confidential.to_string(), "confidential");
        assert_eq!(ClearanceLevel::Secret.to_string(), "secret");
        assert_eq!(ClearanceLevel::TopSecret.to_string(), "top_secret");
    }

    #[test]
    fn clearance_level_serialises_as_integer() {
        let json = serde_json::to_string(&ClearanceLevel::Secret).expect("serialize");
        assert_eq!(json, "2");
    }

    #[test]
    fn clearance_level_serde_roundtrip() {
        for level in [
            ClearanceLevel::Public,
            ClearanceLevel::Confidential,
            ClearanceLevel::Secret,
            ClearanceLevel::TopSecret,
        ] {
            let json = serde_json::to_string(&level).expect("serialize");
            let back: ClearanceLevel = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(back, level);
        }
    }

    #[test]
    fn clearance_level_serde_rejects_out_of_range() {
        let err = serde_json::from_str::<ClearanceLevel>("7");
        assert!(err.is_err(), "deserialising 7 should fail");
    }
}
819}