use serde::{Deserialize, Serialize};
use crate::{
error::{ConstraintKind, Error, TypeError},
fragment::Fragment,
value::{
Value,
constraint::{bytes::MaxBytes, precision::Precision, scale::Scale},
dictionary::DictionaryId,
sumtype::SumTypeId,
r#type::Type,
},
};
pub mod bytes;
pub mod precision;
pub mod scale;
/// A value type paired with an optional constraint that narrows it.
///
/// Values are checked against both parts by `TypeConstraint::validate`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct TypeConstraint {
/// The declared type. `validate` accepts none values only when this is an option type.
base_type: Type,
/// The restriction applied on top of `base_type`; `None` means unconstrained.
constraint: Option<Constraint>,
}
/// The restrictions that can accompany a base type inside a `TypeConstraint`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum Constraint {
/// Upper bound on byte length; `TypeConstraint::validate` enforces it for
/// `Utf8`, `Blob`, `Int` and `Uint` base types.
MaxBytes(MaxBytes),
/// Maximum precision and scale; `TypeConstraint::validate` enforces it for
/// the `Decimal` base type.
PrecisionScale(Precision, Scale),
/// A dictionary reference plus the type reported by
/// `TypeConstraint::storage_type` when the base type is `Type::DictionaryId`.
Dictionary(DictionaryId, Type),
/// A sum-type reference; `TypeConstraint::sumtype` pairs it with `Type::Uint1`.
SumType(SumTypeId),
}
/// Flat, `#[repr(C)]` encoding of a `TypeConstraint`, produced by
/// `TypeConstraint::to_ffi` and consumed by `TypeConstraint::from_ffi`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(C)]
pub struct FFITypeConstraint {
/// The base type as returned by `Type::to_u8`.
pub base_type: u8,
/// Constraint tag: 0 = none, 1 = `MaxBytes`, 2 = `PrecisionScale`,
/// 3 = `Dictionary`, 4 = `SumType`. `from_ffi` treats any other value as none.
pub constraint_type: u8,
/// First parameter: the max-bytes value (tag 1), the precision (tag 2), or the
/// dictionary / sum-type id narrowed to `u32` (tags 3 and 4); 0 for tag 0.
pub constraint_param1: u32,
/// Second parameter: the scale (tag 2) or the dictionary id type as `u8`
/// (tag 3); 0 for every other tag.
pub constraint_param2: u32,
}
impl TypeConstraint {
/// Creates a `TypeConstraint` for `ty` that carries no constraint.
///
/// Usable in `const` contexts.
pub const fn unconstrained(ty: Type) -> Self {
    Self {
        constraint: None,
        base_type: ty,
    }
}
/// Creates a `TypeConstraint` for `ty` restricted by `constraint`.
///
/// No compatibility check between `ty` and `constraint` is performed here.
pub fn with_constraint(ty: Type, constraint: Constraint) -> Self {
    let constraint = Some(constraint);
    Self {
        base_type: ty,
        constraint,
    }
}
/// Creates a dictionary-backed `TypeConstraint`.
///
/// The base type is always `Type::DictionaryId`; `dictionary_id` and `id_type`
/// are stored in a `Constraint::Dictionary`.
pub fn dictionary(dictionary_id: DictionaryId, id_type: Type) -> Self {
    let constraint = Some(Constraint::Dictionary(dictionary_id, id_type));
    Self {
        base_type: Type::DictionaryId,
        constraint,
    }
}
/// Creates a sum-type `TypeConstraint`.
///
/// The base type is always `Type::Uint1`; `id` is stored in a
/// `Constraint::SumType`.
pub fn sumtype(id: SumTypeId) -> Self {
    let constraint = Some(Constraint::SumType(id));
    Self {
        base_type: Type::Uint1,
        constraint,
    }
}
pub fn get_type(&self) -> Type {
self.base_type.clone()
}
/// Returns the type used to store values of this constraint.
///
/// For a `Type::DictionaryId` base type carrying a `Constraint::Dictionary`
/// this is the dictionary's id type; in every other case it is a clone of the
/// base type.
pub fn storage_type(&self) -> Type {
    match &self.constraint {
        Some(Constraint::Dictionary(_, id_type))
            if matches!(self.base_type, Type::DictionaryId) =>
        {
            id_type.clone()
        }
        _ => self.base_type.clone(),
    }
}
/// Returns a reference to the optional constraint; `None` means the type is
/// unconstrained.
pub fn constraint(&self) -> &Option<Constraint> {
&self.constraint
}
/// Encodes this constraint as a flat `FFITypeConstraint`.
///
/// Tag values: 0 = no constraint, 1 = `MaxBytes`, 2 = `PrecisionScale`,
/// 3 = `Dictionary`, 4 = `SumType`. Unused parameters are set to 0.
///
/// NOTE: dictionary and sum-type ids are narrowed from `u64` with `as u32`,
/// so ids that do not fit in 32 bits are truncated.
pub fn to_ffi(&self) -> FFITypeConstraint {
    let base_type = self.base_type.to_u8();

    // (tag, first parameter, second parameter) of the constraint part.
    let (constraint_type, constraint_param1, constraint_param2): (u8, u32, u32) =
        match &self.constraint {
            None => (0, 0, 0),
            Some(Constraint::MaxBytes(limit)) => (1, limit.value(), 0),
            Some(Constraint::PrecisionScale(precision, scale)) => {
                (2, precision.value() as u32, scale.value() as u32)
            }
            Some(Constraint::Dictionary(dictionary, id_type)) => {
                (3, dictionary.to_u64() as u32, id_type.to_u8() as u32)
            }
            Some(Constraint::SumType(sum_type)) => (4, sum_type.to_u64() as u32, 0),
        };

    FFITypeConstraint {
        base_type,
        constraint_type,
        constraint_param1,
        constraint_param2,
    }
}
pub fn from_ffi(ffi: FFITypeConstraint) -> Self {
let ty = Type::from_u8(ffi.base_type);
match ffi.constraint_type {
1 => Self::with_constraint(ty, Constraint::MaxBytes(MaxBytes::new(ffi.constraint_param1))),
2 => Self::with_constraint(
ty,
Constraint::PrecisionScale(
Precision::new(ffi.constraint_param1 as u8),
Scale::new(ffi.constraint_param2 as u8),
),
),
3 => Self::with_constraint(
ty,
Constraint::Dictionary(
DictionaryId::from(ffi.constraint_param1 as u64),
Type::from_u8(ffi.constraint_param2 as u8),
),
),
4 => Self::with_constraint(
ty,
Constraint::SumType(SumTypeId::from(ffi.constraint_param1 as u64)),
),
_ => Self::unconstrained(ty),
}
}
pub fn validate(&self, value: &Value) -> Result<(), Error> {
let value_type = value.get_type();
if value_type != self.base_type && !matches!(value, Value::None { .. }) {
if let Type::Option(inner) = &self.base_type {
if value_type != **inner {
unimplemented!()
}
} else {
unimplemented!()
}
}
if matches!(value, Value::None { .. }) {
if self.base_type.is_option() {
return Ok(());
} else {
return Err(TypeError::ConstraintViolation {
kind: ConstraintKind::NoneNotAllowed {
column_type: self.base_type.clone(),
},
message: format!(
"Cannot insert none into non-optional column of type {}. Declare the column as Option({}) to allow none values.",
self.base_type, self.base_type
),
fragment: Fragment::None,
}
.into());
}
}
match (&self.base_type, &self.constraint) {
(Type::Utf8, Some(Constraint::MaxBytes(max))) => {
if let Value::Utf8(s) = value {
let byte_len = s.len();
let max_value: usize = (*max).into();
if byte_len > max_value {
return Err(TypeError::ConstraintViolation {
kind: ConstraintKind::Utf8MaxBytes {
actual: byte_len,
max: max_value,
},
message: format!(
"UTF8 value exceeds maximum byte length: {} bytes (max: {} bytes)",
byte_len, max_value
),
fragment: Fragment::None,
}
.into());
}
}
}
(Type::Blob, Some(Constraint::MaxBytes(max))) => {
if let Value::Blob(blob) = value {
let byte_len = blob.len();
let max_value: usize = (*max).into();
if byte_len > max_value {
return Err(TypeError::ConstraintViolation {
kind: ConstraintKind::BlobMaxBytes {
actual: byte_len,
max: max_value,
},
message: format!(
"BLOB value exceeds maximum byte length: {} bytes (max: {} bytes)",
byte_len, max_value
),
fragment: Fragment::None,
}
.into());
}
}
}
(Type::Int, Some(Constraint::MaxBytes(max))) => {
if let Value::Int(vi) = value {
let str_len = vi.to_string().len();
let byte_len = (str_len * 415 / 1000) + 1; let max_value: usize = (*max).into();
if byte_len > max_value {
return Err(TypeError::ConstraintViolation {
kind: ConstraintKind::IntMaxBytes {
actual: byte_len,
max: max_value,
},
message: format!(
"INT value exceeds maximum byte length: {} bytes (max: {} bytes)",
byte_len, max_value
),
fragment: Fragment::None,
}
.into());
}
}
}
(Type::Uint, Some(Constraint::MaxBytes(max))) => {
if let Value::Uint(vu) = value {
let str_len = vu.to_string().len();
let byte_len = (str_len * 415 / 1000) + 1; let max_value: usize = (*max).into();
if byte_len > max_value {
return Err(TypeError::ConstraintViolation {
kind: ConstraintKind::UintMaxBytes {
actual: byte_len,
max: max_value,
},
message: format!(
"UINT value exceeds maximum byte length: {} bytes (max: {} bytes)",
byte_len, max_value
),
fragment: Fragment::None,
}
.into());
}
}
}
(Type::Decimal, Some(Constraint::PrecisionScale(precision, scale))) => {
if let Value::Decimal(decimal) = value {
let decimal_str = decimal.to_string();
let decimal_scale: u8 = if let Some(dot_pos) = decimal_str.find('.') {
let after_dot = &decimal_str[dot_pos + 1..];
after_dot.len().min(255) as u8
} else {
0
};
let decimal_precision: u8 =
decimal_str.chars().filter(|c| c.is_ascii_digit()).count().min(255)
as u8;
let scale_value: u8 = (*scale).into();
let precision_value: u8 = (*precision).into();
if decimal_scale > scale_value {
return Err(TypeError::ConstraintViolation {
kind: ConstraintKind::DecimalScale {
actual: decimal_scale,
max: scale_value,
},
message: format!(
"DECIMAL value exceeds maximum scale: {} decimal places (max: {} decimal places)",
decimal_scale, scale_value
),
fragment: Fragment::None,
}
.into());
}
if decimal_precision > precision_value {
return Err(TypeError::ConstraintViolation {
kind: ConstraintKind::DecimalPrecision {
actual: decimal_precision,
max: precision_value,
},
message: format!(
"DECIMAL value exceeds maximum precision: {} digits (max: {} digits)",
decimal_precision, precision_value
),
fragment: Fragment::None,
}
.into());
}
}
}
_ => {}
}
Ok(())
}
/// Returns `true` when no constraint is attached to the base type.
pub fn is_unconstrained(&self) -> bool {
    matches!(self.constraint, None)
}
/// Renders the constraint as text.
///
/// Formats produced:
/// - unconstrained: the base type alone, e.g. `Utf8`;
/// - `MaxBytes`: `Base(max)`, e.g. `Utf8(50)`;
/// - `PrecisionScale`: `Base(precision,scale)`, e.g. `Decimal(10,2)`;
/// - `Dictionary`: `DictionaryId(dict=<id>, <id type>)`;
/// - `SumType`: `SumType(<id>)`.
///
/// The base type is not part of the output for the last two forms.
// NOTE(review): this is an inherent method rather than a `Display` impl, hence
// the lint allowance below.
#[allow(clippy::inherent_to_string)]
pub fn to_string(&self) -> String {
    let base = &self.base_type;
    match self.constraint.as_ref() {
        Some(Constraint::MaxBytes(limit)) => format!("{}({})", base, limit),
        Some(Constraint::PrecisionScale(precision, scale)) => {
            format!("{}({},{})", base, precision, scale)
        }
        Some(Constraint::Dictionary(dictionary, id_type)) => {
            format!("DictionaryId(dict={}, {})", dictionary, id_type)
        }
        Some(Constraint::SumType(sum_type)) => format!("SumType({})", sum_type),
        None => format!("{}", base),
    }
}
}
#[cfg(test)]
pub mod tests {
    use super::*;

    /// Shorthand for a `Utf8` type limited to `max` bytes.
    fn utf8_with_max_bytes(max: u32) -> TypeConstraint {
        TypeConstraint::with_constraint(Type::Utf8, Constraint::MaxBytes(MaxBytes::new(max)))
    }

    /// An unconstrained type keeps its base type and has no constraint.
    #[test]
    fn test_unconstrained_type() {
        let plain = TypeConstraint::unconstrained(Type::Utf8);
        assert_eq!(plain.base_type, Type::Utf8);
        assert_eq!(plain.constraint, None);
        assert!(plain.is_unconstrained());
    }

    /// A max-bytes constraint on `Utf8` is stored as given.
    #[test]
    fn test_constrained_utf8() {
        let bounded = utf8_with_max_bytes(50);
        let expected = Some(Constraint::MaxBytes(MaxBytes::new(50)));
        assert_eq!(bounded.base_type, Type::Utf8);
        assert_eq!(bounded.constraint, expected);
        assert!(!bounded.is_unconstrained());
    }

    /// A precision/scale constraint on `Decimal` is stored as given.
    #[test]
    fn test_constrained_decimal() {
        let decimal = TypeConstraint::with_constraint(
            Type::Decimal,
            Constraint::PrecisionScale(Precision::new(10), Scale::new(2)),
        );
        let expected = Some(Constraint::PrecisionScale(Precision::new(10), Scale::new(2)));
        assert_eq!(decimal.base_type, Type::Decimal);
        assert_eq!(decimal.constraint, expected);
    }

    /// A string shorter than the limit passes validation.
    #[test]
    fn test_validate_utf8_within_limit() {
        let bounded = utf8_with_max_bytes(10);
        let short = Value::Utf8("hello".to_string());
        assert!(bounded.validate(&short).is_ok());
    }

    /// A string longer than the limit fails validation.
    #[test]
    fn test_validate_utf8_exceeds_limit() {
        let bounded = utf8_with_max_bytes(5);
        let long = Value::Utf8("hello world".to_string());
        assert!(bounded.validate(&long).is_err());
    }

    /// Without a constraint, any string length is accepted.
    #[test]
    fn test_validate_unconstrained() {
        let plain = TypeConstraint::unconstrained(Type::Utf8);
        let text = Value::Utf8("any length string is fine here".to_string());
        assert!(plain.validate(&text).is_ok());
    }

    /// A none value is rejected when the base type is not an option.
    #[test]
    fn test_validate_none_rejected_for_non_option() {
        let bounded = utf8_with_max_bytes(5);
        let none = Value::none();
        assert!(bounded.validate(&none).is_err());
    }

    /// A none value is accepted when the base type is an option.
    #[test]
    fn test_validate_none_accepted_for_option() {
        let optional = TypeConstraint::unconstrained(Type::Option(Box::new(Type::Utf8)));
        let none = Value::none();
        assert!(optional.validate(&none).is_ok());
    }

    /// `to_string` renders the unconstrained, max-bytes and precision/scale forms.
    #[test]
    fn test_to_string() {
        let plain = TypeConstraint::unconstrained(Type::Utf8);
        assert_eq!(plain.to_string(), "Utf8");

        let bounded = utf8_with_max_bytes(50);
        assert_eq!(bounded.to_string(), "Utf8(50)");

        let decimal = TypeConstraint::with_constraint(
            Type::Decimal,
            Constraint::PrecisionScale(Precision::new(10), Scale::new(2)),
        );
        assert_eq!(decimal.to_string(), "Decimal(10,2)");
    }
}