1use std::{
16 fmt::{self, Debug, Display},
17 marker::PhantomData,
18 ops::Deref,
19};
20
21use serde::{Deserialize, Serialize};
22
/// A `Vec<T>` that always holds at least one element.
///
/// Values are created through [`NonEmptyVec::try_new`] (which refuses an empty
/// vector) or [`NonEmptyVec::singleton`]. The only length-changing method,
/// [`NonEmptyVec::push`], can only grow the vector. The wrapped `Vec` is a
/// private field, so code outside this module cannot empty it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NonEmptyVec<T>(Vec<T>);

/// Error produced when an empty vector is offered where a [`NonEmptyVec`] is required.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EmptyVecError;

impl Display for EmptyVecError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("vector is empty; NonEmptyVec requires at least one element")
    }
}

impl std::error::Error for EmptyVecError {}

impl<T> NonEmptyVec<T> {
    /// Wraps `vec`, provided it contains at least one element.
    ///
    /// # Errors
    ///
    /// Returns [`EmptyVecError`] when `vec` is empty.
    pub fn try_new(vec: Vec<T>) -> Result<Self, EmptyVecError> {
        if vec.is_empty() {
            return Err(EmptyVecError);
        }
        Ok(Self(vec))
    }

    /// Builds a one-element vector. This cannot fail.
    pub fn singleton(first: T) -> Self {
        Self(vec![first])
    }

    /// Returns the first element. This never panics, because the vector is never empty.
    pub fn first(&self) -> &T {
        self.0
            .first()
            .expect("NonEmptyVec always holds at least one element")
    }

    /// Returns the last element. This never panics, because the vector is never empty.
    pub fn last(&self) -> &T {
        self.0
            .last()
            .expect("NonEmptyVec always holds at least one element")
    }

    /// Consumes the wrapper and returns the inner `Vec`.
    pub fn into_vec(self) -> Vec<T> {
        self.0
    }

    /// Borrows the inner `Vec`.
    pub fn as_vec(&self) -> &Vec<T> {
        &self.0
    }

    /// Appends `value`. Growing the vector cannot violate the invariant.
    pub fn push(&mut self, value: T) {
        self.0.push(value);
    }

    /// Number of elements. The result is always at least 1.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Always `false`: a `NonEmptyVec` can never be empty.
    ///
    /// Provided as the conventional partner of `len` (clippy's
    /// `len_without_is_empty`). `self` is intentionally unused, hence the allow.
    #[allow(clippy::unused_self)]
    pub const fn is_empty(&self) -> bool {
        false
    }
}

impl<T> Deref for NonEmptyVec<T> {
    type Target = [T];

    /// Exposes the elements as a slice, which cannot change the length.
    fn deref(&self) -> &[T] {
        self.0.as_slice()
    }
}

impl<T> AsRef<[T]> for NonEmptyVec<T> {
    fn as_ref(&self) -> &[T] {
        self.0.as_slice()
    }
}

impl<T> TryFrom<Vec<T>> for NonEmptyVec<T> {
    type Error = EmptyVecError;

    /// Same as [`NonEmptyVec::try_new`].
    fn try_from(vec: Vec<T>) -> Result<Self, Self::Error> {
        NonEmptyVec::try_new(vec)
    }
}

impl<T> From<NonEmptyVec<T>> for Vec<T> {
    fn from(nev: NonEmptyVec<T>) -> Self {
        nev.into_vec()
    }
}

impl<T> IntoIterator for NonEmptyVec<T> {
    type Item = T;
    type IntoIter = std::vec::IntoIter<T>;

    fn into_iter(self) -> Self::IntoIter {
        self.into_vec().into_iter()
    }
}

impl<'a, T> IntoIterator for &'a NonEmptyVec<T> {
    type Item = &'a T;
    type IntoIter = std::slice::Iter<'a, T>;

    fn into_iter(self) -> Self::IntoIter {
        self.0.as_slice().iter()
    }
}
159
160impl<T: Serialize> Serialize for NonEmptyVec<T> {
161 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
162 self.0.serialize(serializer)
163 }
164}
165
166impl<'de, T: Deserialize<'de>> Deserialize<'de> for NonEmptyVec<T> {
167 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
168 let vec = Vec::<T>::deserialize(deserializer)?;
169 Self::try_new(vec).map_err(serde::de::Error::custom)
170 }
171}
172
/// A validated SQL identifier, or identifier pattern, compared case-insensitively.
///
/// The caller's spelling is kept for display and serialisation. Equality,
/// hashing and ordering use only the ASCII-lowercased form, so `Email`,
/// `EMAIL` and `email` are equal.
///
/// Accepted input is non-empty and made of ASCII letters, digits and `_`.
/// The first character must not be a digit. A `*` wildcard may appear only as
/// the first and/or last character:
///
/// * `*` matches every column;
/// * `prefix*` matches columns starting with `prefix`;
/// * `*suffix` matches columns ending with `suffix`;
/// * `*infix*` matches columns containing `infix` (so `**` matches everything).
#[derive(Debug, Clone)]
pub struct SqlIdentifier {
    /// The identifier exactly as supplied to [`SqlIdentifier::try_new`].
    original: String,
    /// `original` lowercased with ASCII rules. This is the key for all comparisons.
    normalised: String,
}

/// Reasons [`SqlIdentifier::try_new`] can reject its input.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SqlIdentifierError {
    /// The input string was empty.
    Empty,
    /// The input contained a character outside `[A-Za-z0-9_*]`.
    InvalidCharacter(char),
    /// A `*` appeared somewhere other than the first or last position.
    InvalidWildcardPosition,
    /// The first character was an ASCII digit.
    StartsWithDigit,
}

impl Display for SqlIdentifierError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Empty => f.write_str("SQL identifier is empty"),
            Self::InvalidCharacter(c) => {
                write!(f, "SQL identifier contains invalid character {c:?}")
            }
            Self::InvalidWildcardPosition => f.write_str(
                "wildcard '*' may only appear at the start or end of a pattern",
            ),
            Self::StartsWithDigit => {
                f.write_str("SQL identifier must not start with a digit")
            }
        }
    }
}

impl std::error::Error for SqlIdentifierError {}

impl SqlIdentifier {
    /// Validates `raw` and stores both its original and ASCII-lowercased forms.
    ///
    /// # Errors
    ///
    /// * [`SqlIdentifierError::Empty`] if `raw` is empty.
    /// * [`SqlIdentifierError::InvalidWildcardPosition`] if `*` is not first or last.
    /// * [`SqlIdentifierError::InvalidCharacter`] for any other character outside
    ///   `[A-Za-z0-9_]`. The offending character is reported exactly, including
    ///   non-ASCII characters.
    /// * [`SqlIdentifierError::StartsWithDigit`] if the first character is a digit.
    pub fn try_new(raw: impl Into<String>) -> Result<Self, SqlIdentifierError> {
        let original = raw.into();
        if original.is_empty() {
            return Err(SqlIdentifierError::Empty);
        }
        let len = original.len();
        // Walk `char`s rather than bytes. A multi-byte character must be
        // reported as itself, not as a Latin-1 reading of its first UTF-8 byte.
        for (i, c) in original.char_indices() {
            // `i` is a byte offset and `*` is a single byte, so for `*` the
            // test `i + 1 == len` means "this is the final character".
            let is_leading_wildcard = i == 0 && c == '*';
            let is_trailing_wildcard = i + 1 == len && c == '*';
            if c == '*' && !is_leading_wildcard && !is_trailing_wildcard {
                return Err(SqlIdentifierError::InvalidWildcardPosition);
            }
            let is_digit = c.is_ascii_digit();
            let is_allowed = c.is_ascii_alphabetic()
                || is_digit
                || c == '_'
                || is_leading_wildcard
                || is_trailing_wildcard;
            if !is_allowed {
                return Err(SqlIdentifierError::InvalidCharacter(c));
            }
            if i == 0 && is_digit {
                return Err(SqlIdentifierError::StartsWithDigit);
            }
        }
        let normalised = original.to_ascii_lowercase();
        Ok(Self {
            original,
            normalised,
        })
    }

    /// The identifier exactly as it was supplied.
    pub fn original(&self) -> &str {
        &self.original
    }

    /// The ASCII-lowercased identifier used for comparison and matching.
    pub fn normalised(&self) -> &str {
        &self.normalised
    }

    /// `true` only for the bare `*` pattern.
    pub fn is_wildcard(&self) -> bool {
        self.normalised == "*"
    }

    /// For a `prefix*` pattern, returns the lowercased, non-empty `prefix`.
    ///
    /// Returns `None` for any other shape, including `*infix*`.
    pub fn as_prefix_pattern(&self) -> Option<&str> {
        self.normalised
            .strip_suffix('*')
            .filter(|s| !s.is_empty() && !s.contains('*'))
    }

    /// For a `*suffix` pattern, returns the lowercased, non-empty `suffix`.
    ///
    /// Returns `None` for any other shape, including `*infix*`.
    pub fn as_suffix_pattern(&self) -> Option<&str> {
        self.normalised
            .strip_prefix('*')
            .filter(|s| !s.is_empty() && !s.contains('*'))
    }

    /// For a `*infix*` pattern (wildcards at both ends), returns the lowercased
    /// `infix`. The infix may be empty, for `**`.
    ///
    /// Returns `None` for the bare `*` and for every other shape.
    pub fn as_infix_pattern(&self) -> Option<&str> {
        self.normalised
            .strip_prefix('*')
            .and_then(|s| s.strip_suffix('*'))
    }

    /// Reports whether `column_name` is matched by this identifier or pattern,
    /// ignoring ASCII case.
    ///
    /// Checks in order: bare wildcard, `prefix*`, `*suffix`, `*infix*`, then an
    /// exact comparison. Without the infix case, a double-ended pattern such as
    /// `*email*`, which validation accepts, would silently match no real column.
    pub fn matches(&self, column_name: &str) -> bool {
        if self.is_wildcard() {
            return true;
        }
        let candidate = column_name.to_ascii_lowercase();
        if let Some(prefix) = self.as_prefix_pattern() {
            return candidate.starts_with(prefix);
        }
        if let Some(suffix) = self.as_suffix_pattern() {
            return candidate.ends_with(suffix);
        }
        if let Some(infix) = self.as_infix_pattern() {
            return candidate.contains(infix);
        }
        candidate == self.normalised
    }
}

impl Display for SqlIdentifier {
    /// Shows the caller's original spelling.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.original)
    }
}

// Equality, hashing and ordering all key on `normalised` only, so they agree
// with one another: equal identifiers hash equally and compare `Equal`.
impl PartialEq for SqlIdentifier {
    fn eq(&self, other: &Self) -> bool {
        self.normalised == other.normalised
    }
}

impl Eq for SqlIdentifier {}

impl std::hash::Hash for SqlIdentifier {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.normalised.hash(state);
    }
}

impl PartialOrd for SqlIdentifier {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SqlIdentifier {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.normalised.cmp(&other.normalised)
    }
}
353
354impl Serialize for SqlIdentifier {
355 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
356 self.original.serialize(serializer)
357 }
358}
359
360impl<'de> Deserialize<'de> for SqlIdentifier {
361 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
362 let raw = String::deserialize(deserializer)?;
363 Self::try_new(raw).map_err(serde::de::Error::custom)
364 }
365}
366
/// A `usize` that is known to be at most `MAX`. `MAX` itself is allowed.
///
/// Every conversion into this type checks the bound and reports a violation
/// as a [`BoundedSizeError`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct BoundedSize<const MAX: usize>(usize);

/// A value that exceeded a [`BoundedSize`] bound.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BoundedSizeError {
    /// The rejected value, widened to `u64` so that out-of-`usize` inputs fit.
    pub value: u64,
    /// The bound that was exceeded.
    pub max: usize,
}

impl Display for BoundedSizeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "size {} exceeds bound {} (decompression bomb or corruption)",
            self.value, self.max
        )
    }
}

impl std::error::Error for BoundedSizeError {}

impl<const MAX: usize> BoundedSize<MAX> {
    /// Accepts `value` if `value <= MAX`.
    ///
    /// # Errors
    ///
    /// Returns [`BoundedSizeError`] when `value > MAX`.
    pub const fn try_new(value: usize) -> Result<Self, BoundedSizeError> {
        if value > MAX {
            Err(BoundedSizeError {
                // `usize` is at most 64 bits on every target Rust supports, so
                // this widening is lossless. `u64::from` is not usable in `const fn`.
                value: value as u64,
                max: MAX,
            })
        } else {
            Ok(Self(value))
        }
    }

    /// The wrapped value.
    pub const fn get(self) -> usize {
        self.0
    }

    /// The inclusive upper bound `MAX`.
    pub const fn max() -> usize {
        MAX
    }
}

impl<const MAX: usize> TryFrom<u32> for BoundedSize<MAX> {
    type Error = BoundedSizeError;

    /// Delegates to the checked `u64` conversion. On targets where `usize` is
    /// narrower than 32 bits (e.g. 16-bit), a plain `as usize` cast would
    /// truncate, and a too-large value could then slip under the bound.
    fn try_from(value: u32) -> Result<Self, Self::Error> {
        <Self as TryFrom<u64>>::try_from(u64::from(value))
    }
}

impl<const MAX: usize> TryFrom<u64> for BoundedSize<MAX> {
    type Error = BoundedSizeError;

    /// Rejects values that do not fit in `usize` on this target, then checks `MAX`.
    fn try_from(value: u64) -> Result<Self, Self::Error> {
        usize::try_from(value)
            .map_err(|_| BoundedSizeError { value, max: MAX })
            .and_then(Self::try_new)
    }
}

impl<const MAX: usize> TryFrom<usize> for BoundedSize<MAX> {
    type Error = BoundedSizeError;

    fn try_from(value: usize) -> Result<Self, Self::Error> {
        Self::try_new(value)
    }
}

impl<const MAX: usize> From<BoundedSize<MAX>> for usize {
    fn from(bs: BoundedSize<MAX>) -> Self {
        bs.0
    }
}
466
467impl<const MAX: usize> Serialize for BoundedSize<MAX> {
469 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
470 self.0.serialize(serializer)
471 }
472}
473
474impl<'de, const MAX: usize> Deserialize<'de> for BoundedSize<MAX> {
475 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
476 let raw = usize::deserialize(deserializer)?;
477 Self::try_new(raw).map_err(serde::de::Error::custom)
478 }
479}
480
// Anonymous const: nothing inside is reachable from outside it, and the
// generic function below is never called. It only names `PhantomData` next to
// a const-generic `usize` parameter.
// NOTE(review): this appears to exist solely to keep the `PhantomData` import
// referenced. Confirm, and delete both if they are unused elsewhere.
const _: () = {
    fn _phantom_marker<const N: usize>() -> PhantomData<[(); N]> {
        PhantomData::<[(); N]>
    }
};
489
/// Security clearance, ordered from least (`Public`) to most (`TopSecret`)
/// privileged.
///
/// The explicit `u8` discriminants are the values accepted and produced by the
/// `TryFrom<u8>` / `From<ClearanceLevel> for u8` conversions.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
pub enum ClearanceLevel {
    /// Lowest level. Also the `Default`.
    #[default]
    Public = 0,
    Confidential = 1,
    Secret = 2,
    /// Highest level.
    TopSecret = 3,
}

/// A `u8` that does not correspond to any [`ClearanceLevel`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ClearanceLevelError {
    /// The rejected byte.
    pub value: u8,
}

impl Display for ClearanceLevelError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "clearance level {} is out of range (valid: 0..=3)",
            self.value
        )
    }
}

impl std::error::Error for ClearanceLevelError {}

impl ClearanceLevel {
    /// The numeric discriminant (`0..=3`).
    pub const fn as_u8(self) -> u8 {
        self as u8
    }

    /// `true` when `self` is at least as high as `other`. Every level dominates itself.
    pub const fn dominates(self, other: Self) -> bool {
        self.as_u8() >= other.as_u8()
    }
}

impl TryFrom<u8> for ClearanceLevel {
    type Error = ClearanceLevelError;

    /// Maps `0..=3` to a level and rejects every other byte.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        Ok(match value {
            0 => Self::Public,
            1 => Self::Confidential,
            2 => Self::Secret,
            3 => Self::TopSecret,
            _ => return Err(ClearanceLevelError { value }),
        })
    }
}

impl From<ClearanceLevel> for u8 {
    fn from(level: ClearanceLevel) -> Self {
        level.as_u8()
    }
}

impl Display for ClearanceLevel {
    /// Lower-case snake_case name of the level.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Public => "public",
            Self::Confidential => "confidential",
            Self::Secret => "secret",
            Self::TopSecret => "top_secret",
        };
        f.write_str(name)
    }
}
577
578impl Serialize for ClearanceLevel {
579 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
580 (*self as u8).serialize(serializer)
581 }
582}
583
584impl<'de> Deserialize<'de> for ClearanceLevel {
585 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
586 let byte = u8::deserialize(deserializer)?;
587 Self::try_from(byte).map_err(serde::de::Error::custom)
588 }
589}
590
#[cfg(test)]
mod tests {
    use super::*;

    // --- NonEmptyVec -----------------------------------------------------

    #[test]
    fn non_empty_vec_rejects_empty() {
        let attempt = NonEmptyVec::<u8>::try_new(Vec::new());
        assert_eq!(attempt, Err(EmptyVecError));
    }

    #[test]
    fn non_empty_vec_accepts_single_element() {
        let only = NonEmptyVec::singleton(42u8);
        assert_eq!(only.len(), 1);
        assert_eq!((*only.first(), *only.last()), (42, 42));
        assert!(!only.is_empty());
    }

    #[test]
    fn non_empty_vec_push_preserves_invariant() {
        let mut grown = NonEmptyVec::singleton(1u8);
        for next in [2, 3] {
            grown.push(next);
        }
        assert_eq!(grown.len(), 3);
        assert_eq!(grown.to_vec(), vec![1, 2, 3]);
    }

    #[test]
    fn non_empty_vec_serde_roundtrip() {
        let original = NonEmptyVec::try_new(vec![1, 2, 3]).expect("non-empty");
        let encoded = serde_json::to_string(&original).expect("serialize");
        assert_eq!(encoded, "[1,2,3]");
        let decoded: NonEmptyVec<i32> = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded, original);
    }

    #[test]
    fn non_empty_vec_serde_rejects_empty() {
        let parsed = serde_json::from_str::<NonEmptyVec<i32>>("[]");
        assert!(parsed.is_err(), "deserializing empty should fail");
    }

    // --- SqlIdentifier ---------------------------------------------------

    #[test]
    fn sql_identifier_normalises_case() {
        let [mixed, upper, lower] = ["Email", "EMAIL", "email"]
            .map(|raw| SqlIdentifier::try_new(raw).expect("valid"));
        assert_eq!(mixed, upper);
        assert_eq!(upper, lower);
        assert_eq!(mixed.original(), "Email");
        assert_eq!(mixed.normalised(), "email");
    }

    #[test]
    fn sql_identifier_rejects_empty() {
        let err = SqlIdentifier::try_new("").unwrap_err();
        assert_eq!(err, SqlIdentifierError::Empty);
    }

    #[test]
    fn sql_identifier_rejects_leading_digit() {
        let err = SqlIdentifier::try_new("1col").unwrap_err();
        assert_eq!(err, SqlIdentifierError::StartsWithDigit);
    }

    #[test]
    fn sql_identifier_rejects_invalid_char() {
        let err = SqlIdentifier::try_new("col-name").expect_err("expected InvalidCharacter");
        assert_eq!(err, SqlIdentifierError::InvalidCharacter('-'));
    }

    #[test]
    fn sql_identifier_accepts_wildcard_patterns() {
        let cases = [
            ("*", "bare wildcard"),
            ("email_*", "prefix pattern"),
            ("*_token", "suffix pattern"),
        ];
        for (raw, label) in cases {
            SqlIdentifier::try_new(raw).expect(label);
        }
    }

    #[test]
    fn sql_identifier_rejects_middle_wildcard() {
        let err = SqlIdentifier::try_new("em*ail").unwrap_err();
        assert_eq!(err, SqlIdentifierError::InvalidWildcardPosition);
    }

    #[test]
    fn sql_identifier_matches_case_insensitively() {
        let pattern = SqlIdentifier::try_new("Email").expect("valid");
        for column in ["email", "EMAIL", "Email"] {
            assert!(pattern.matches(column), "{column} should match");
        }
        assert!(!pattern.matches("name"));
    }

    #[test]
    fn sql_identifier_prefix_suffix_wildcard_match() {
        let cases: [(&str, &str, bool); 7] = [
            ("user_*", "user_id", true),
            ("user_*", "USER_NAME", true),
            ("user_*", "id", false),
            ("*_id", "user_id", true),
            ("*_id", "ORDER_ID", true),
            ("*_id", "user", false),
            ("*", "anything", true),
        ];
        for (pattern, column, expected) in cases {
            let pat = SqlIdentifier::try_new(pattern).expect("valid");
            assert_eq!(
                pat.matches(column),
                expected,
                "pattern {pattern:?} vs column {column:?}"
            );
        }
    }

    #[test]
    fn sql_identifier_serde_roundtrip() {
        let id = SqlIdentifier::try_new("User_Email").expect("valid");
        let encoded = serde_json::to_string(&id).expect("serialize");
        assert_eq!(encoded, "\"User_Email\"");
        let decoded: SqlIdentifier = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded, id);
        assert_eq!(decoded.original(), "User_Email");
    }

    // --- BoundedSize -----------------------------------------------------

    #[test]
    fn bounded_size_accepts_within_bound() {
        let size = BoundedSize::<1024>::try_new(512).expect("within bound");
        assert_eq!(size.get(), 512);
        assert_eq!(BoundedSize::<1024>::max(), 1024);
    }

    #[test]
    fn bounded_size_accepts_exact_max() {
        let size = BoundedSize::<1024>::try_new(1024).expect("exact max permitted");
        assert_eq!(size.get(), 1024);
    }

    #[test]
    fn bounded_size_rejects_over_bound() {
        let err = BoundedSize::<1024>::try_new(1025).unwrap_err();
        assert_eq!((err.value, err.max), (1025, 1024));
    }

    #[test]
    fn bounded_size_tryfrom_u32() {
        let size = BoundedSize::<1024>::try_from(512u32).expect("within bound");
        assert_eq!(size.get(), 512);
        assert!(BoundedSize::<1024>::try_from(2048u32).is_err());
    }

    #[test]
    fn bounded_size_tryfrom_u64_overflow_on_32bit_safe() {
        let size = BoundedSize::<{ usize::MAX }>::try_from(42u64).expect("within bound");
        assert_eq!(size.get(), 42);
    }

    #[test]
    fn bounded_size_serde_enforces_on_deserialize() {
        let size = BoundedSize::<100>::try_from(50usize).expect("valid");
        let encoded = serde_json::to_string(&size).expect("serialize");
        assert_eq!(encoded, "50");
        let decoded: BoundedSize<100> = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.get(), 50);
        let rejected = serde_json::from_str::<BoundedSize<100>>("200");
        assert!(rejected.is_err(), "deserialising over-bound should fail");
    }

    // --- ClearanceLevel --------------------------------------------------

    #[test]
    fn clearance_level_tryfrom_valid() {
        let expected = [
            ClearanceLevel::Public,
            ClearanceLevel::Confidential,
            ClearanceLevel::Secret,
            ClearanceLevel::TopSecret,
        ];
        for (raw, level) in (0u8..).zip(expected) {
            assert_eq!(ClearanceLevel::try_from(raw), Ok(level));
        }
    }

    #[test]
    fn clearance_level_tryfrom_invalid() {
        for raw in [4u8, 255] {
            let err = ClearanceLevel::try_from(raw).unwrap_err();
            assert_eq!(err.value, raw);
        }
    }

    #[test]
    fn clearance_level_dominates() {
        use ClearanceLevel::{Confidential, Public, Secret, TopSecret};
        assert!(TopSecret.dominates(Public));
        assert!(Secret.dominates(Confidential));
        assert!(Public.dominates(Public));
        assert!(!Public.dominates(Secret));
    }

    #[test]
    fn clearance_level_default_is_public() {
        let level: ClearanceLevel = Default::default();
        assert_eq!(level, ClearanceLevel::Public);
    }

    #[test]
    fn clearance_level_serde_roundtrip() {
        let levels = [
            ClearanceLevel::Public,
            ClearanceLevel::Confidential,
            ClearanceLevel::Secret,
            ClearanceLevel::TopSecret,
        ];
        for level in levels {
            let encoded = serde_json::to_string(&level).expect("serialize");
            let decoded: ClearanceLevel = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, level);
        }
    }

    #[test]
    fn clearance_level_serde_rejects_out_of_range() {
        let parsed = serde_json::from_str::<ClearanceLevel>("7");
        assert!(parsed.is_err(), "deserialising 7 should fail");
    }
}
830}