use std::{
fmt::{self, Debug, Display},
marker::PhantomData,
ops::Deref,
};
use serde::{Deserialize, Serialize};
/// A `Vec<T>` that always holds at least one element.
///
/// Values are created only through [`NonEmptyVec::try_new`] (which rejects an
/// empty vector) or [`NonEmptyVec::singleton`]. The only mutating method exposed
/// here is [`NonEmptyVec::push`], which can only grow the vector.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NonEmptyVec<T>(Vec<T>);

/// Error returned when an empty vector is offered to [`NonEmptyVec`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EmptyVecError;

impl Display for EmptyVecError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = "vector is empty; NonEmptyVec requires at least one element";
        f.write_str(message)
    }
}

impl std::error::Error for EmptyVecError {}

impl<T> NonEmptyVec<T> {
    /// Wraps `vec`, refusing it when it contains no elements.
    ///
    /// # Errors
    ///
    /// Returns [`EmptyVecError`] when `vec` is empty.
    pub fn try_new(vec: Vec<T>) -> Result<Self, EmptyVecError> {
        if vec.is_empty() {
            return Err(EmptyVecError);
        }
        Ok(Self(vec))
    }

    /// Builds a collection holding exactly `first`; cannot fail.
    pub fn singleton(first: T) -> Self {
        let mut items = Vec::with_capacity(1);
        items.push(first);
        Self(items)
    }

    /// Returns the first element. Indexing cannot panic because the vector is non-empty.
    pub fn first(&self) -> &T {
        &self.0[0]
    }

    /// Returns the last element. Indexing cannot panic because the vector is non-empty.
    pub fn last(&self) -> &T {
        let last_index = self.0.len() - 1;
        &self.0[last_index]
    }

    /// Consumes the wrapper and hands back the underlying vector.
    pub fn into_vec(self) -> Vec<T> {
        let Self(inner) = self;
        inner
    }

    /// Borrows the underlying vector.
    pub fn as_vec(&self) -> &Vec<T> {
        let Self(inner) = self;
        inner
    }

    /// Appends `value`; growing can never violate the non-empty invariant.
    pub fn push(&mut self, value: T) {
        Vec::push(&mut self.0, value);
    }

    /// Number of elements (always at least 1).
    pub fn len(&self) -> usize {
        Vec::len(&self.0)
    }

    /// Always `false`, because the collection is never empty.
    ///
    /// Presumably provided as the conventional companion to `len` (clippy's
    /// `len_without_is_empty`).
    // `unused_self` is allowed deliberately: the answer is fixed by the type's
    // invariant, but the method still takes `&self` to mirror `Vec::is_empty`.
    #[allow(clippy::unused_self)]
    pub const fn is_empty(&self) -> bool {
        false
    }
}
impl<T> Deref for NonEmptyVec<T> {
type Target = [T];
fn deref(&self) -> &[T] {
&self.0
}
}
impl<T> AsRef<[T]> for NonEmptyVec<T> {
fn as_ref(&self) -> &[T] {
&self.0
}
}
impl<T> TryFrom<Vec<T>> for NonEmptyVec<T> {
type Error = EmptyVecError;
fn try_from(vec: Vec<T>) -> Result<Self, Self::Error> {
Self::try_new(vec)
}
}
impl<T> From<NonEmptyVec<T>> for Vec<T> {
fn from(nev: NonEmptyVec<T>) -> Self {
nev.0
}
}
impl<T> IntoIterator for NonEmptyVec<T> {
type Item = T;
type IntoIter = std::vec::IntoIter<T>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
impl<'a, T> IntoIterator for &'a NonEmptyVec<T> {
type Item = &'a T;
type IntoIter = std::slice::Iter<'a, T>;
fn into_iter(self) -> Self::IntoIter {
self.0.iter()
}
}
impl<T: Serialize> Serialize for NonEmptyVec<T> {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
self.0.serialize(serializer)
}
}
impl<'de, T: Deserialize<'de>> Deserialize<'de> for NonEmptyVec<T> {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let vec = Vec::<T>::deserialize(deserializer)?;
Self::try_new(vec).map_err(serde::de::Error::custom)
}
}
/// A validated SQL identifier, or a simple column-matching pattern.
///
/// Accepted text is non-empty ASCII made of letters, digits and `_`, and does
/// not start with a digit. A `*` wildcard may appear only as the first and/or
/// last character, so `*`, `prefix_*` and `*_suffix` are all valid.
/// The original spelling is kept for display and serialisation. Equality,
/// hashing, ordering and matching use the ASCII-lowercased form.
#[derive(Debug, Clone)]
pub struct SqlIdentifier {
    /// Text exactly as supplied to [`SqlIdentifier::try_new`].
    original: String,
    /// ASCII-lowercased copy of `original`.
    normalised: String,
}

/// Reasons [`SqlIdentifier::try_new`] can reject its input.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SqlIdentifierError {
    /// The input was the empty string.
    Empty,
    /// The input contained this character, which is outside `[A-Za-z0-9_]` and
    /// is not a permitted leading/trailing `*`.
    InvalidCharacter(char),
    /// A `*` appeared somewhere other than the first or last position.
    InvalidWildcardPosition,
    /// The first character was an ASCII digit.
    StartsWithDigit,
}

impl Display for SqlIdentifierError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Empty => f.write_str("SQL identifier is empty"),
            Self::InvalidCharacter(c) => {
                write!(f, "SQL identifier contains invalid character {c:?}")
            }
            Self::InvalidWildcardPosition => {
                f.write_str("wildcard '*' may only appear at the start or end of a pattern")
            }
            Self::StartsWithDigit => f.write_str("SQL identifier must not start with a digit"),
        }
    }
}

impl std::error::Error for SqlIdentifierError {}

impl SqlIdentifier {
    /// Validates `raw` and stores both its original and lowercased forms.
    ///
    /// # Errors
    ///
    /// The first problem found, scanning left to right:
    /// - [`SqlIdentifierError::Empty`] for `""`;
    /// - [`SqlIdentifierError::InvalidWildcardPosition`] for an interior `*`;
    /// - [`SqlIdentifierError::InvalidCharacter`] with the offending character
    ///   (reported correctly even when it is non-ASCII);
    /// - [`SqlIdentifierError::StartsWithDigit`] when the first character is a digit.
    pub fn try_new(raw: impl Into<String>) -> Result<Self, SqlIdentifierError> {
        let original = raw.into();
        if original.is_empty() {
            return Err(SqlIdentifierError::Empty);
        }
        let last_index = original.len() - 1;
        // Walk `char`s rather than bytes so a rejected non-ASCII character is
        // reported as itself, not as the Latin-1 reading of its first UTF-8 byte
        // (e.g. 'é' used to be reported as 'Ã').
        for (i, c) in original.char_indices() {
            let is_leading_wildcard = i == 0 && c == '*';
            // `*` is a single byte, so a trailing `*` starts at the final byte offset.
            let is_trailing_wildcard = i == last_index && c == '*';
            if c == '*' && !is_leading_wildcard && !is_trailing_wildcard {
                return Err(SqlIdentifierError::InvalidWildcardPosition);
            }
            let is_alpha = c.is_ascii_alphabetic();
            let is_digit = c.is_ascii_digit();
            let is_underscore = c == '_';
            if !(is_alpha
                || is_digit
                || is_underscore
                || is_leading_wildcard
                || is_trailing_wildcard)
            {
                return Err(SqlIdentifierError::InvalidCharacter(c));
            }
            if i == 0 && is_digit {
                return Err(SqlIdentifierError::StartsWithDigit);
            }
        }
        // Validation guarantees pure ASCII, so ASCII lowercasing is a full case fold.
        let normalised = original.to_ascii_lowercase();
        Ok(Self {
            original,
            normalised,
        })
    }

    /// The identifier exactly as supplied.
    pub fn original(&self) -> &str {
        &self.original
    }

    /// The ASCII-lowercased identifier used for comparisons.
    pub fn normalised(&self) -> &str {
        &self.normalised
    }

    /// `true` for the bare `*` pattern, which matches every column.
    pub fn is_wildcard(&self) -> bool {
        self.normalised == "*"
    }

    /// For `prefix*`, returns the lowercased `prefix`. Returns `None` otherwise,
    /// including when another `*` remains (e.g. `*x*`).
    pub fn as_prefix_pattern(&self) -> Option<&str> {
        self.normalised
            .strip_suffix('*')
            .filter(|s| !s.is_empty() && !s.contains('*'))
    }

    /// For `*suffix`, returns the lowercased `suffix`. Returns `None` otherwise,
    /// including when another `*` remains (e.g. `*x*`).
    pub fn as_suffix_pattern(&self) -> Option<&str> {
        self.normalised
            .strip_prefix('*')
            .filter(|s| !s.is_empty() && !s.contains('*'))
    }

    /// Tests `column_name` against this identifier/pattern, ignoring ASCII case.
    ///
    /// `*` matches everything. `prefix*` is a starts-with test and `*suffix` is an
    /// ends-with test. Anything else is compared for equality.
    ///
    /// NOTE(review): `try_new` accepts patterns with both a leading and a trailing
    /// `*` (e.g. `*email*`). Those are neither prefix nor suffix patterns, so they
    /// fall through to literal equality and effectively match nothing. Confirm
    /// whether they should mean "contains" or be rejected.
    pub fn matches(&self, column_name: &str) -> bool {
        if self.is_wildcard() {
            return true;
        }
        let lhs = column_name.to_ascii_lowercase();
        if let Some(prefix) = self.as_prefix_pattern() {
            return lhs.starts_with(prefix);
        }
        if let Some(suffix) = self.as_suffix_pattern() {
            return lhs.ends_with(suffix);
        }
        lhs == self.normalised
    }
}
/// Shows the identifier in the caller's original spelling.
impl Display for SqlIdentifier {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.original())
    }
}

/// Case-insensitive equality: only the lowercased form is compared.
impl PartialEq for SqlIdentifier {
    fn eq(&self, other: &Self) -> bool {
        self.normalised() == other.normalised()
    }
}

impl Eq for SqlIdentifier {}

/// Hashes the lowercased form so `Hash` stays consistent with `PartialEq`.
impl std::hash::Hash for SqlIdentifier {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(&self.normalised, state);
    }
}

impl PartialOrd for SqlIdentifier {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

/// Orders by the lowercased form, consistent with `PartialEq`.
impl Ord for SqlIdentifier {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.normalised().cmp(other.normalised())
    }
}
impl Serialize for SqlIdentifier {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
self.original.serialize(serializer)
}
}
impl<'de> Deserialize<'de> for SqlIdentifier {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let raw = String::deserialize(deserializer)?;
Self::try_new(raw).map_err(serde::de::Error::custom)
}
}
/// A `usize` that is at most `MAX` (inclusive).
///
/// Construction fails for larger values. The error message frames such values
/// as a decompression bomb or corrupt input.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct BoundedSize<const MAX: usize>(usize);

/// Error returned when a value exceeds a [`BoundedSize`] bound.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BoundedSizeError {
    /// The rejected value, widened to `u64` so it can also hold `u64` inputs
    /// that do not fit in `usize`.
    pub value: u64,
    /// The bound that was exceeded.
    pub max: usize,
}

impl Display for BoundedSizeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "size {} exceeds bound {} (decompression bomb or corruption)",
            self.value, self.max
        )
    }
}

impl std::error::Error for BoundedSizeError {}

impl<const MAX: usize> BoundedSize<MAX> {
    /// Accepts `value` when `value <= MAX`.
    ///
    /// # Errors
    ///
    /// Returns [`BoundedSizeError`] carrying `value` and `MAX` when `value > MAX`.
    pub const fn try_new(value: usize) -> Result<Self, BoundedSizeError> {
        if value > MAX {
            Err(BoundedSizeError {
                // usize -> u64 is lossless on every target with pointers <= 64 bits;
                // `u64::from` is not usable in a `const fn`.
                value: value as u64,
                max: MAX,
            })
        } else {
            Ok(Self(value))
        }
    }

    /// The wrapped size.
    pub const fn get(self) -> usize {
        self.0
    }

    /// The inclusive upper bound `MAX`.
    pub const fn max() -> usize {
        MAX
    }
}

impl<const MAX: usize> TryFrom<u32> for BoundedSize<MAX> {
    type Error = BoundedSizeError;
    fn try_from(value: u32) -> Result<Self, Self::Error> {
        // A checked conversion: `value as usize` would silently truncate on
        // targets whose `usize` is narrower than 32 bits, letting an oversized
        // value slip under the bound.
        usize::try_from(value)
            .map_err(|_| BoundedSizeError {
                value: u64::from(value),
                max: MAX,
            })
            .and_then(Self::try_new)
    }
}

impl<const MAX: usize> TryFrom<u64> for BoundedSize<MAX> {
    type Error = BoundedSizeError;
    fn try_from(value: u64) -> Result<Self, Self::Error> {
        // Values that do not even fit in `usize` certainly exceed `MAX`.
        usize::try_from(value)
            .map_err(|_| BoundedSizeError { value, max: MAX })
            .and_then(Self::try_new)
    }
}

impl<const MAX: usize> TryFrom<usize> for BoundedSize<MAX> {
    type Error = BoundedSizeError;
    fn try_from(value: usize) -> Result<Self, Self::Error> {
        Self::try_new(value)
    }
}

impl<const MAX: usize> From<BoundedSize<MAX>> for usize {
    fn from(bs: BoundedSize<MAX>) -> Self {
        bs.0
    }
}
impl<const MAX: usize> Serialize for BoundedSize<MAX> {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
self.0.serialize(serializer)
}
}
impl<'de, const MAX: usize> Deserialize<'de> for BoundedSize<MAX> {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let raw = usize::deserialize(deserializer)?;
Self::try_new(raw).map_err(serde::de::Error::custom)
}
}
// Nothing else in this file uses `PhantomData`. This unnamed const appears to
// exist only to keep the `std::marker::PhantomData` import in use. It has no
// runtime effect and exports nothing.
const _: PhantomData<()> = PhantomData;
/// Security clearance levels, ordered from least (`Public`) to most (`TopSecret`)
/// privileged.
///
/// The explicit discriminants are the numeric values used by `TryFrom<u8>`,
/// `From<ClearanceLevel> for u8` and the serde impls.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
pub enum ClearanceLevel {
    /// Level 0; also the `Default`.
    #[default]
    Public = 0,
    /// Level 1.
    Confidential = 1,
    /// Level 2.
    Secret = 2,
    /// Level 3, the highest.
    TopSecret = 3,
}

/// Error for a numeric clearance level outside `0..=3`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ClearanceLevelError {
    /// The rejected byte.
    pub value: u8,
}

impl Display for ClearanceLevelError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "clearance level {value} is out of range (valid: 0..=3)",
            value = self.value
        )
    }
}

impl std::error::Error for ClearanceLevelError {}

impl ClearanceLevel {
    /// The level's numeric discriminant.
    pub const fn as_u8(self) -> u8 {
        self as u8
    }

    /// `true` when `self` is at least as high as `other`; every level dominates itself.
    pub const fn dominates(self, other: Self) -> bool {
        !(self.as_u8() < other.as_u8())
    }
}

impl TryFrom<u8> for ClearanceLevel {
    type Error = ClearanceLevelError;
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        // Indexed by discriminant; anything past the end is out of range.
        let ordered = [Self::Public, Self::Confidential, Self::Secret, Self::TopSecret];
        ordered
            .get(usize::from(value))
            .copied()
            .ok_or(ClearanceLevelError { value })
    }
}

impl From<ClearanceLevel> for u8 {
    fn from(level: ClearanceLevel) -> Self {
        level.as_u8()
    }
}

/// Lowercase, snake_case labels (`public`, ..., `top_secret`).
impl Display for ClearanceLevel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::Public => "public",
            Self::Confidential => "confidential",
            Self::Secret => "secret",
            Self::TopSecret => "top_secret",
        };
        f.write_str(label)
    }
}
impl Serialize for ClearanceLevel {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
(*self as u8).serialize(serializer)
}
}
impl<'de> Deserialize<'de> for ClearanceLevel {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let byte = u8::deserialize(deserializer)?;
Self::try_from(byte).map_err(serde::de::Error::custom)
}
}
// Unit tests for NonEmptyVec, SqlIdentifier, BoundedSize and ClearanceLevel,
// including serde_json round-trips that exercise the validating Deserialize impls.
#[cfg(test)]
mod tests {
use super::*;
// --- NonEmptyVec: construction, accessors, push, serde ---
#[test]
fn non_empty_vec_rejects_empty() {
assert_eq!(NonEmptyVec::<u8>::try_new(vec![]), Err(EmptyVecError));
}
#[test]
fn non_empty_vec_accepts_single_element() {
let v = NonEmptyVec::singleton(42u8);
assert_eq!(v.len(), 1);
assert_eq!(*v.first(), 42);
assert_eq!(*v.last(), 42);
assert!(!v.is_empty());
}
#[test]
fn non_empty_vec_push_preserves_invariant() {
let mut v = NonEmptyVec::singleton(1u8);
v.push(2);
v.push(3);
assert_eq!(v.len(), 3);
// `&*v` goes through Deref to compare as a slice.
assert_eq!(&*v, &[1, 2, 3]);
}
#[test]
fn non_empty_vec_serde_roundtrip() {
let v = NonEmptyVec::try_new(vec![1, 2, 3]).expect("non-empty");
let json = serde_json::to_string(&v).expect("serialize");
assert_eq!(json, "[1,2,3]");
let back: NonEmptyVec<i32> = serde_json::from_str(&json).expect("deserialize");
assert_eq!(back, v);
}
#[test]
fn non_empty_vec_serde_rejects_empty() {
let err = serde_json::from_str::<NonEmptyVec<i32>>("[]");
assert!(err.is_err(), "deserializing empty should fail");
}
// --- SqlIdentifier: validation, case-insensitive equality, wildcard matching, serde ---
#[test]
fn sql_identifier_normalises_case() {
let a = SqlIdentifier::try_new("Email").expect("valid");
let b = SqlIdentifier::try_new("EMAIL").expect("valid");
let c = SqlIdentifier::try_new("email").expect("valid");
assert_eq!(a, b);
assert_eq!(b, c);
assert_eq!(a.original(), "Email");
assert_eq!(a.normalised(), "email");
}
#[test]
fn sql_identifier_rejects_empty() {
assert_eq!(SqlIdentifier::try_new(""), Err(SqlIdentifierError::Empty));
}
#[test]
fn sql_identifier_rejects_leading_digit() {
assert_eq!(
SqlIdentifier::try_new("1col"),
Err(SqlIdentifierError::StartsWithDigit)
);
}
#[test]
fn sql_identifier_rejects_invalid_char() {
match SqlIdentifier::try_new("col-name") {
Err(SqlIdentifierError::InvalidCharacter(c)) => assert_eq!(c, '-'),
other => panic!("expected InvalidCharacter, got {other:?}"),
}
}
#[test]
fn sql_identifier_accepts_wildcard_patterns() {
SqlIdentifier::try_new("*").expect("bare wildcard");
SqlIdentifier::try_new("email_*").expect("prefix pattern");
SqlIdentifier::try_new("*_token").expect("suffix pattern");
}
#[test]
fn sql_identifier_rejects_middle_wildcard() {
assert_eq!(
SqlIdentifier::try_new("em*ail"),
Err(SqlIdentifierError::InvalidWildcardPosition)
);
}
#[test]
fn sql_identifier_matches_case_insensitively() {
let pat = SqlIdentifier::try_new("Email").expect("valid");
assert!(pat.matches("email"));
assert!(pat.matches("EMAIL"));
assert!(pat.matches("Email"));
assert!(!pat.matches("name"));
}
#[test]
fn sql_identifier_prefix_suffix_wildcard_match() {
let prefix = SqlIdentifier::try_new("user_*").expect("valid");
assert!(prefix.matches("user_id"));
assert!(prefix.matches("USER_NAME"));
assert!(!prefix.matches("id"));
let suffix = SqlIdentifier::try_new("*_id").expect("valid");
assert!(suffix.matches("user_id"));
assert!(suffix.matches("ORDER_ID"));
assert!(!suffix.matches("user"));
let wildcard = SqlIdentifier::try_new("*").expect("valid");
assert!(wildcard.matches("anything"));
}
// Serialisation emits the original spelling, so case survives a round-trip.
#[test]
fn sql_identifier_serde_roundtrip() {
let id = SqlIdentifier::try_new("User_Email").expect("valid");
let json = serde_json::to_string(&id).expect("serialize");
assert_eq!(json, "\"User_Email\"");
let back: SqlIdentifier = serde_json::from_str(&json).expect("deserialize");
assert_eq!(back, id);
assert_eq!(back.original(), "User_Email");
}
// --- BoundedSize: inclusive bound, conversions, serde enforcement ---
#[test]
fn bounded_size_accepts_within_bound() {
let bs: BoundedSize<1024> = BoundedSize::try_new(512).expect("within bound");
assert_eq!(bs.get(), 512);
assert_eq!(BoundedSize::<1024>::max(), 1024);
}
#[test]
fn bounded_size_accepts_exact_max() {
let bs: BoundedSize<1024> = BoundedSize::try_new(1024).expect("exact max permitted");
assert_eq!(bs.get(), 1024);
}
#[test]
fn bounded_size_rejects_over_bound() {
let err = BoundedSize::<1024>::try_new(1025).unwrap_err();
assert_eq!(err.value, 1025);
assert_eq!(err.max, 1024);
}
#[test]
fn bounded_size_tryfrom_u32() {
let bs: BoundedSize<1024> = 512u32.try_into().expect("within bound");
assert_eq!(bs.get(), 512);
let err: Result<BoundedSize<1024>, _> = 2048u32.try_into();
assert!(err.is_err());
}
// Despite its name, this only checks an in-range u64 against a usize::MAX bound.
// The "doesn't fit in usize" path of TryFrom<u64> is unreachable on 64-bit targets.
#[test]
fn bounded_size_tryfrom_u64_overflow_on_32bit_safe() {
let bs: BoundedSize<{ usize::MAX }> = 42u64.try_into().expect("within bound");
assert_eq!(bs.get(), 42);
}
#[test]
fn bounded_size_serde_enforces_on_deserialize() {
let bs: BoundedSize<100> = 50usize.try_into().expect("valid");
let json = serde_json::to_string(&bs).expect("serialize");
assert_eq!(json, "50");
let ok: BoundedSize<100> = serde_json::from_str(&json).expect("deserialize");
assert_eq!(ok.get(), 50);
let err = serde_json::from_str::<BoundedSize<100>>("200");
assert!(err.is_err(), "deserialising over-bound should fail");
}
// --- ClearanceLevel: numeric conversions, ordering, default, serde ---
#[test]
fn clearance_level_tryfrom_valid() {
assert_eq!(ClearanceLevel::try_from(0), Ok(ClearanceLevel::Public));
assert_eq!(
ClearanceLevel::try_from(1),
Ok(ClearanceLevel::Confidential)
);
assert_eq!(ClearanceLevel::try_from(2), Ok(ClearanceLevel::Secret));
assert_eq!(ClearanceLevel::try_from(3), Ok(ClearanceLevel::TopSecret));
}
#[test]
fn clearance_level_tryfrom_invalid() {
let err = ClearanceLevel::try_from(4).unwrap_err();
assert_eq!(err.value, 4);
let err = ClearanceLevel::try_from(255).unwrap_err();
assert_eq!(err.value, 255);
}
#[test]
fn clearance_level_dominates() {
assert!(ClearanceLevel::TopSecret.dominates(ClearanceLevel::Public));
assert!(ClearanceLevel::Secret.dominates(ClearanceLevel::Confidential));
assert!(ClearanceLevel::Public.dominates(ClearanceLevel::Public));
assert!(!ClearanceLevel::Public.dominates(ClearanceLevel::Secret));
}
#[test]
fn clearance_level_default_is_public() {
assert_eq!(ClearanceLevel::default(), ClearanceLevel::Public);
}
#[test]
fn clearance_level_serde_roundtrip() {
for level in [
ClearanceLevel::Public,
ClearanceLevel::Confidential,
ClearanceLevel::Secret,
ClearanceLevel::TopSecret,
] {
let json = serde_json::to_string(&level).expect("serialize");
let back: ClearanceLevel = serde_json::from_str(&json).expect("deserialize");
assert_eq!(back, level);
}
}
#[test]
fn clearance_level_serde_rejects_out_of_range() {
let err = serde_json::from_str::<ClearanceLevel>("7");
assert!(err.is_err(), "deserialising 7 should fail");
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct AggregateMemoryBudget(u64);
/// Smallest aggregate memory budget `AggregateMemoryBudget::try_new` accepts: 64 KiB.
pub const AGGREGATE_BUDGET_MIN_BYTES: u64 = 64 << 10;
/// Budget used by `AggregateMemoryBudget::DEFAULT`: 256 MiB.
pub const AGGREGATE_BUDGET_DEFAULT_BYTES: u64 = 256 << 20;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AggregateMemoryBudgetError {
pub observed: u64,
pub minimum: u64,
}
impl Display for AggregateMemoryBudgetError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"aggregate memory budget {} bytes is below the minimum {} bytes",
self.observed, self.minimum
)
}
}
impl std::error::Error for AggregateMemoryBudgetError {}
impl AggregateMemoryBudget {
pub const DEFAULT: Self = Self(AGGREGATE_BUDGET_DEFAULT_BYTES);
pub fn try_new(bytes: u64) -> Result<Self, AggregateMemoryBudgetError> {
if bytes < AGGREGATE_BUDGET_MIN_BYTES {
return Err(AggregateMemoryBudgetError {
observed: bytes,
minimum: AGGREGATE_BUDGET_MIN_BYTES,
});
}
let budget = Self(bytes);
debug_assert_eq!(budget.bytes(), bytes, "AggregateMemoryBudget round-trip");
Ok(budget)
}
pub const fn bytes(&self) -> u64 {
self.0
}
}
impl Default for AggregateMemoryBudget {
fn default() -> Self {
Self::DEFAULT
}
}
impl TryFrom<u64> for AggregateMemoryBudget {
type Error = AggregateMemoryBudgetError;
fn try_from(value: u64) -> Result<Self, Self::Error> {
Self::try_new(value)
}
}
impl Display for AggregateMemoryBudget {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{} bytes", self.0)
}
}
/// A date/time field keyword, presumably for `EXTRACT` / `DATE_TRUNC`-style
/// SQL functions (confirm against callers).
///
/// Note: the derived serde representation uses the Rust variant names (e.g.
/// `"DayOfWeek"`), not the SQL keywords accepted by [`DateField::parse`] (e.g. `DOW`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DateField {
    Year,
    Month,
    Day,
    Hour,
    Minute,
    Second,
    Millisecond,
    Microsecond,
    DayOfWeek,
    DayOfYear,
    Quarter,
    Week,
    Epoch,
}

/// Error from [`DateField::parse`]; holds the unrecognised input as given.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DateFieldParseError(pub String);

impl Display for DateFieldParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "unknown date field '{}' (expected one of: YEAR, MONTH, DAY, HOUR, MINUTE, SECOND, MILLISECOND, MICROSECOND, DOW, DOY, QUARTER, WEEK, EPOCH)",
            self.0
        )
    }
}

impl std::error::Error for DateFieldParseError {}

impl DateField {
    /// Parses a field keyword, ignoring ASCII case.
    ///
    /// Accepted aliases: `MILLISECONDS`, `MICROSECONDS`, `DAYOFWEEK` (for `DOW`)
    /// and `DAYOFYEAR` (for `DOY`).
    ///
    /// # Errors
    ///
    /// Returns [`DateFieldParseError`] containing the original `s` for any other input.
    pub fn parse(s: &str) -> Result<Self, DateFieldParseError> {
        match s.to_ascii_uppercase().as_str() {
            "YEAR" => Ok(Self::Year),
            "MONTH" => Ok(Self::Month),
            "DAY" => Ok(Self::Day),
            "HOUR" => Ok(Self::Hour),
            "MINUTE" => Ok(Self::Minute),
            "SECOND" => Ok(Self::Second),
            "MILLISECOND" | "MILLISECONDS" => Ok(Self::Millisecond),
            "MICROSECOND" | "MICROSECONDS" => Ok(Self::Microsecond),
            "DOW" | "DAYOFWEEK" => Ok(Self::DayOfWeek),
            "DOY" | "DAYOFYEAR" => Ok(Self::DayOfYear),
            "QUARTER" => Ok(Self::Quarter),
            "WEEK" => Ok(Self::Week),
            "EPOCH" => Ok(Self::Epoch),
            _ => Err(DateFieldParseError(s.to_string())),
        }
    }

    /// `true` for the fields in this set: `Year`, `Month`, `Day`, `Hour`,
    /// `Minute`, `Second`.
    pub const fn is_truncatable(&self) -> bool {
        matches!(
            self,
            Self::Year | Self::Month | Self::Day | Self::Hour | Self::Minute | Self::Second
        )
    }
}

/// Enables `"dow".parse::<DateField>()`; identical to [`DateField::parse`].
impl std::str::FromStr for DateField {
    type Err = DateFieldParseError;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::parse(s)
    }
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SubstringRange {
pub start: i64,
pub length: Option<i64>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NegativeSubstringLength(pub i64);
impl Display for NegativeSubstringLength {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "SUBSTRING length must be non-negative, got {}", self.0)
}
}
impl std::error::Error for NegativeSubstringLength {}
impl SubstringRange {
pub const fn from_start(start: i64) -> Self {
Self {
start,
length: None,
}
}
pub fn try_new(start: i64, length: i64) -> Result<Self, NegativeSubstringLength> {
if length < 0 {
return Err(NegativeSubstringLength(length));
}
Ok(Self {
start,
length: Some(length),
})
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Interval {
months: i32,
days: i32,
nanos: i64,
}
/// Nanoseconds in a 24-hour day: 86 400 s × 10⁹ ns.
pub const NANOS_PER_DAY: i64 = 24 * 60 * 60 * 1_000_000_000;

/// Overflow reported while building an interval.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IntervalConstructionError {
    /// The day component (including days carried from nanoseconds) left the `i32` range.
    DayOverflow,
    /// The month component left the `i32` range.
    MonthOverflow,
}

impl Display for IntervalConstructionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match *self {
            Self::DayOverflow => "interval day component overflow",
            Self::MonthOverflow => "interval month component overflow",
        };
        f.write_str(text)
    }
}

impl std::error::Error for IntervalConstructionError {}
impl Interval {
    /// The empty interval.
    pub const ZERO: Self = Self {
        months: 0,
        days: 0,
        nanos: 0,
    };

    /// Builds an interval, carrying whole days out of `nanos` into `days`.
    ///
    /// Division truncates toward zero, so the stored `nanos` keeps the sign of
    /// the input and satisfies `|nanos| < NANOS_PER_DAY`.
    ///
    /// # Errors
    ///
    /// [`IntervalConstructionError::DayOverflow`] if `days` plus the carried days
    /// does not fit in `i32`.
    pub fn try_from_components(
        months: i32,
        days: i32,
        nanos: i64,
    ) -> Result<Self, IntervalConstructionError> {
        let (carried_days, leftover_nanos) = (nanos / NANOS_PER_DAY, nanos % NANOS_PER_DAY);
        let total_days = i32::try_from(carried_days)
            .ok()
            .and_then(|carry| days.checked_add(carry))
            .ok_or(IntervalConstructionError::DayOverflow)?;
        let interval = Self {
            months,
            days: total_days,
            nanos: leftover_nanos,
        };
        debug_assert!(
            interval.nanos.unsigned_abs() < NANOS_PER_DAY as u64,
            "Interval normalisation failed: |nanos|={} >= {}",
            interval.nanos.unsigned_abs(),
            NANOS_PER_DAY
        );
        Ok(interval)
    }

    /// An interval of exactly `months` calendar months.
    pub const fn from_months(months: i32) -> Self {
        Self {
            months,
            ..Self::ZERO
        }
    }

    /// An interval of exactly `days` days.
    pub const fn from_days(days: i32) -> Self {
        Self {
            days,
            ..Self::ZERO
        }
    }

    /// An interval of `nanos` nanoseconds, normalised into days plus remainder.
    ///
    /// # Errors
    ///
    /// Same as [`Interval::try_from_components`].
    pub fn from_nanos(nanos: i64) -> Result<Self, IntervalConstructionError> {
        Self::try_from_components(0, 0, nanos)
    }

    /// Month component.
    pub const fn months(&self) -> i32 {
        self.months
    }

    /// Day component.
    pub const fn days(&self) -> i32 {
        self.days
    }

    /// Sub-day nanosecond component.
    pub const fn nanos(&self) -> i64 {
        self.nanos
    }

    /// `true` when every component is zero.
    pub const fn is_zero(&self) -> bool {
        matches!(
            self,
            Self {
                months: 0,
                days: 0,
                nanos: 0
            }
        )
    }

    /// Component-wise sum with day carry from the nanosecond sum.
    /// Returns `None` on any `i32`/`i64` overflow.
    pub fn checked_add(&self, rhs: &Self) -> Option<Self> {
        let months = self.months.checked_add(rhs.months)?;
        let nanos_sum = self.nanos.checked_add(rhs.nanos)?;
        let carry = i32::try_from(nanos_sum / NANOS_PER_DAY).ok()?;
        let days = self.days.checked_add(rhs.days)?.checked_add(carry)?;
        Some(Self {
            months,
            days,
            nanos: nanos_sum % NANOS_PER_DAY,
        })
    }

    /// `self - rhs`, computed as `self + (-rhs)`. `None` on overflow.
    pub fn checked_sub(&self, rhs: &Self) -> Option<Self> {
        rhs.checked_neg()
            .and_then(|negated| self.checked_add(&negated))
    }

    /// Negates every component. `None` if any component is its type's `MIN`.
    pub fn checked_neg(&self) -> Option<Self> {
        let months = self.months.checked_neg()?;
        let days = self.days.checked_neg()?;
        let nanos = self.nanos.checked_neg()?;
        Some(Self {
            months,
            days,
            nanos,
        })
    }
}
impl Default for Interval {
fn default() -> Self {
Self::ZERO
}
}
impl Display for Interval {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"{} months {} days {} ns",
self.months, self.days, self.nanos
)
}
}
/// Kani model-checking harnesses for `Interval` arithmetic.
#[cfg(kani)]
mod kani_harnesses {
    use super::*;

    /// Signed length in nanoseconds of a month-free interval.
    ///
    /// Normalisation truncates toward zero, so `days` and `nanos` may carry
    /// opposite signs. Two structurally different values can then describe the
    /// same span. For example, `(0.6d + 0.6d) + (-0.6d)` yields `1 day, -0.4d`,
    /// while `0.6d + (0.6d + -0.6d)` yields `0 days, 0.6d`. Span equality is
    /// therefore the property that actually holds.
    fn span_nanos(i: &Interval) -> i128 {
        i128::from(i.days()) * i128::from(NANOS_PER_DAY) + i128::from(i.nanos())
    }

    /// `checked_add` is associative up to the represented span for month-free
    /// intervals with bounded components.
    #[kani::proof]
    fn interval_add_associative_no_month() {
        let a_days: i32 = kani::any();
        let a_nanos: i64 = kani::any();
        let b_days: i32 = kani::any();
        let b_nanos: i64 = kani::any();
        let c_days: i32 = kani::any();
        let c_nanos: i64 = kani::any();
        // `unsigned_abs` rather than `abs`: `i32::MIN.abs()` / `i64::MIN.abs()`
        // overflow, and Kani reports that as a failure before the assumption applies.
        kani::assume(a_days.unsigned_abs() < 1_000_000);
        kani::assume(b_days.unsigned_abs() < 1_000_000);
        kani::assume(c_days.unsigned_abs() < 1_000_000);
        kani::assume(a_nanos.unsigned_abs() < NANOS_PER_DAY as u64);
        kani::assume(b_nanos.unsigned_abs() < NANOS_PER_DAY as u64);
        kani::assume(c_nanos.unsigned_abs() < NANOS_PER_DAY as u64);
        let a = Interval::try_from_components(0, a_days, a_nanos).unwrap();
        let b = Interval::try_from_components(0, b_days, b_nanos).unwrap();
        let c = Interval::try_from_components(0, c_days, c_nanos).unwrap();
        let lhs = a.checked_add(&b).and_then(|ab| ab.checked_add(&c));
        let rhs = b.checked_add(&c).and_then(|bc| a.checked_add(&bc));
        match (lhs, rhs) {
            (Some(l), Some(r)) => {
                assert_eq!(l.months(), r.months());
                assert_eq!(span_nanos(&l), span_nanos(&r));
            }
            (None, None) => {}
            // Within the bounds above neither grouping can overflow, so a
            // one-sided failure is a real defect and must fail the proof,
            // not merely be recorded as covered.
            _ => panic!("checked_add overflowed for only one grouping"),
        }
    }

    /// `ZERO` is a two-sided identity for `checked_add` on normalised intervals.
    #[kani::proof]
    fn interval_zero_identity() {
        let m: i32 = kani::any();
        let d: i32 = kani::any();
        let n: i64 = kani::any();
        // `unsigned_abs` avoids the `i64::MIN.abs()` overflow.
        kani::assume(n.unsigned_abs() < NANOS_PER_DAY as u64);
        let i = Interval::try_from_components(m, d, n).unwrap();
        let z = Interval::ZERO;
        assert_eq!(i.checked_add(&z), Some(i));
        assert_eq!(z.checked_add(&i), Some(i));
    }
}
// Unit tests for AggregateMemoryBudget, DateField, SubstringRange and Interval.
#[cfg(test)]
mod v07_tests {
use super::*;
// --- AggregateMemoryBudget: default value, inclusive floor, TryFrom ---
#[test]
fn aggregate_budget_default_is_256_mib() {
assert_eq!(
AggregateMemoryBudget::DEFAULT.bytes(),
256 * 1024 * 1024,
"default budget drift breaks consumers tuning against the documented constant"
);
}
#[test]
fn aggregate_budget_try_new_rejects_below_floor() {
let err = AggregateMemoryBudget::try_new(1024).expect_err("1 KiB is below the floor");
assert_eq!(err.observed, 1024);
assert_eq!(err.minimum, AGGREGATE_BUDGET_MIN_BYTES);
}
#[test]
fn aggregate_budget_accepts_floor_exactly() {
let b = AggregateMemoryBudget::try_new(AGGREGATE_BUDGET_MIN_BYTES).expect("at-floor ok");
assert_eq!(b.bytes(), AGGREGATE_BUDGET_MIN_BYTES);
}
#[test]
fn aggregate_budget_round_trips_through_tryfrom() {
let b: AggregateMemoryBudget = (8 * 1024 * 1024_u64).try_into().expect("8 MiB ok");
assert_eq!(b.bytes(), 8 * 1024 * 1024);
}
// --- DateField: case-insensitive keyword parsing, aliases, truncatable subset ---
#[test]
fn datefield_parse_canonical_keywords() {
assert_eq!(DateField::parse("YEAR").unwrap(), DateField::Year);
assert_eq!(DateField::parse("year").unwrap(), DateField::Year);
assert_eq!(DateField::parse("Year").unwrap(), DateField::Year);
assert_eq!(DateField::parse("DOW").unwrap(), DateField::DayOfWeek);
assert_eq!(DateField::parse("DAYOFWEEK").unwrap(), DateField::DayOfWeek);
assert_eq!(DateField::parse("EPOCH").unwrap(), DateField::Epoch);
}
// The error message echoes the rejected input.
#[test]
fn datefield_parse_rejects_unknown() {
let err = DateField::parse("DECADE").expect_err("DECADE not supported");
assert!(err.to_string().contains("DECADE"));
}
#[test]
fn datefield_truncatable_subset() {
for tf in [
DateField::Year,
DateField::Month,
DateField::Day,
DateField::Hour,
DateField::Minute,
DateField::Second,
] {
assert!(tf.is_truncatable(), "{tf:?} must be truncatable");
}
for nontf in [
DateField::DayOfWeek,
DateField::Quarter,
DateField::Week,
DateField::Epoch,
DateField::Millisecond,
DateField::Microsecond,
] {
assert!(!nontf.is_truncatable(), "{nontf:?} must NOT be truncatable");
}
}
// --- SubstringRange: only `length` is validated; `start` may be negative ---
#[test]
fn substring_range_two_arg_form() {
let r = SubstringRange::from_start(3);
assert_eq!(r.start, 3);
assert!(r.length.is_none());
}
#[test]
fn substring_range_three_arg_form_accepts_zero_length() {
let r = SubstringRange::try_new(1, 0).expect("zero length is legal");
assert_eq!(r.start, 1);
assert_eq!(r.length, Some(0));
}
#[test]
fn substring_range_rejects_negative_length() {
let err = SubstringRange::try_new(1, -3).expect_err("negative length rejected");
assert_eq!(err.0, -3);
}
#[test]
fn substring_range_accepts_negative_start() {
let r = SubstringRange::try_new(-2, 5).expect("negative start allowed");
assert_eq!(r.start, -2);
}
// --- Interval: normalisation, arithmetic identities, serde ---
#[test]
fn interval_zero_is_zero() {
assert!(Interval::ZERO.is_zero());
assert_eq!(Interval::ZERO.months(), 0);
assert_eq!(Interval::ZERO.days(), 0);
assert_eq!(Interval::ZERO.nanos(), 0);
}
#[test]
fn interval_normalises_nanos_overflow_into_days() {
let i = Interval::try_from_components(0, 0, 2 * NANOS_PER_DAY + 5).unwrap();
assert_eq!(i.days(), 2);
assert_eq!(i.nanos(), 5);
}
// Truncating division: -(D + 1) ns carries -1 day and leaves -1 ns, so
// 1 day + that = 0 days, -1 ns (the remainder keeps the input's sign).
#[test]
fn interval_handles_negative_nanos() {
let i = Interval::try_from_components(0, 1, -NANOS_PER_DAY - 1).unwrap();
assert_eq!(i.days(), 0);
assert_eq!(i.nanos(), -1);
}
#[test]
fn interval_round_trip_components() {
let i = Interval::try_from_components(13, 7, 60_000_000_000).unwrap();
assert_eq!(i.months(), 13);
assert_eq!(i.days(), 7);
assert_eq!(i.nanos(), 60_000_000_000);
}
#[test]
fn interval_zero_identity() {
let a = Interval::try_from_components(2, 5, 1_000_000_000).unwrap();
assert_eq!(a.checked_add(&Interval::ZERO), Some(a));
assert_eq!(Interval::ZERO.checked_add(&a), Some(a));
}
// Whole days only, so no nanosecond carry is involved; structural equality is
// meaningful here.
#[test]
fn interval_associativity_no_month() {
let a = Interval::from_days(3);
let b = Interval::from_days(7);
let c = Interval::from_days(11);
let lhs = a.checked_add(&b).and_then(|ab| ab.checked_add(&c));
let rhs = b.checked_add(&c).and_then(|bc| a.checked_add(&bc));
assert_eq!(lhs, rhs);
assert_eq!(lhs.unwrap().days(), 21);
}
#[test]
fn interval_negation_is_self_inverse() {
let i = Interval::try_from_components(1, 2, 3).unwrap();
let neg = i.checked_neg().unwrap();
let zero = i.checked_add(&neg).unwrap();
assert!(zero.is_zero());
}
#[test]
fn interval_serde_roundtrip() {
let i = Interval::try_from_components(5, 10, 1234).unwrap();
let json = serde_json::to_string(&i).expect("serialize");
let back: Interval = serde_json::from_str(&json).expect("deserialize");
assert_eq!(back, i);
}
}