use std::borrow::Cow;
use crate::model::Id;
use crate::util::{parse_date_rfc3339, parse_datetime_rfc3339, parse_time_rfc3339};
/// Tag naming the kind of data a [`Value`] carries.
///
/// Every variant has an explicit, fixed `u8` discriminant (`1..=13`), so a tag
/// can be turned into a byte with `as u8` and recovered with
/// [`DataType::from_u8`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum DataType {
    Boolean = 1,
    Integer = 2,
    Float = 3,
    Decimal = 4,
    Text = 5,
    Bytes = 6,
    Date = 7,
    Time = 8,
    Datetime = 9,
    Schedule = 10,
    Point = 11,
    Rect = 12,
    Embedding = 13,
}

impl DataType {
    /// Maps a raw discriminant back to its `DataType`.
    ///
    /// Returns `None` when `v` is not the discriminant of any variant, i.e.
    /// for `0` and for everything above `13`.
    pub fn from_u8(v: u8) -> Option<DataType> {
        // Every variant, listed once; the lookup compares against each
        // variant's own discriminant, so the order of this table is irrelevant.
        const VARIANTS: [DataType; 13] = [
            DataType::Boolean,
            DataType::Integer,
            DataType::Float,
            DataType::Decimal,
            DataType::Text,
            DataType::Bytes,
            DataType::Date,
            DataType::Time,
            DataType::Datetime,
            DataType::Schedule,
            DataType::Point,
            DataType::Rect,
            DataType::Embedding,
        ];

        VARIANTS.iter().copied().find(|kind| *kind as u8 == v)
    }
}
/// Element encoding used by the raw bytes of an embedding.
///
/// Each variant has a fixed `u8` discriminant (`0..=2`); see
/// [`EmbeddingSubType::from_u8`] for the reverse mapping and
/// [`EmbeddingSubType::bytes_for_dims`] for the storage size each encoding
/// implies.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum EmbeddingSubType {
    Float32 = 0,
    Int8 = 1,
    Binary = 2,
}

impl EmbeddingSubType {
    /// Maps a raw discriminant back to its `EmbeddingSubType`.
    ///
    /// Returns `None` for any value other than `0`, `1` or `2`.
    pub fn from_u8(v: u8) -> Option<EmbeddingSubType> {
        match v {
            0 => Some(EmbeddingSubType::Float32),
            1 => Some(EmbeddingSubType::Int8),
            2 => Some(EmbeddingSubType::Binary),
            _ => None,
        }
    }

    /// Returns the number of bytes needed to store `dims` dimensions.
    ///
    /// * `Float32` — four bytes per dimension.
    /// * `Int8` — one byte per dimension.
    /// * `Binary` — one bit per dimension, rounded up to whole bytes.
    ///
    /// The `Float32` size saturates at `usize::MAX` instead of overflowing.
    /// A plain `dims * 4` panics in debug builds and wraps around in release
    /// builds, where a huge `dims` could yield a small "expected" length that
    /// a short buffer would then match. No slice can be `usize::MAX` bytes
    /// long, so a saturated result never equals a real data length.
    pub fn bytes_for_dims(self, dims: usize) -> usize {
        match self {
            EmbeddingSubType::Float32 => dims.saturating_mul(4),
            EmbeddingSubType::Int8 => dims,
            EmbeddingSubType::Binary => dims.div_ceil(8),
        }
    }
}
/// Mantissa of a decimal value.
///
/// Either a plain `i64`, or a byte string (`Big`) that may be borrowed or
/// owned.
/// NOTE(review): the byte order / signedness of the `Big` encoding is not
/// visible in this file — confirm against the encoder before relying on it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DecimalMantissa<'a> {
    I64(i64),
    Big(Cow<'a, [u8]>),
}

impl DecimalMantissa<'_> {
    /// Reports whether the mantissa is considered to end in a zero.
    ///
    /// * `I64`: `true` for any non-zero value divisible by ten; zero itself
    ///   yields `false`.
    /// * `Big`: `true` when the byte string is non-empty and its last byte is
    ///   `0`.
    ///   NOTE(review): this inspects a single byte, not decimal digits, so it
    ///   may not correspond to a decimal trailing zero — confirm intent.
    pub fn has_trailing_zeros(&self) -> bool {
        match self {
            DecimalMantissa::I64(0) => false,
            DecimalMantissa::I64(n) => *n % 10 == 0,
            // `last()` is `None` for an empty slice, which compares unequal.
            DecimalMantissa::Big(raw) => raw.last() == Some(&0),
        }
    }

    /// Reports whether the mantissa is zero.
    ///
    /// A `Big` mantissa counts as zero when it contains no non-zero byte,
    /// which includes the empty byte string.
    pub fn is_zero(&self) -> bool {
        match self {
            DecimalMantissa::I64(n) => *n == 0,
            DecimalMantissa::Big(raw) => !raw.iter().any(|&byte| byte != 0),
        }
    }
}
/// A single typed value.
///
/// Each variant corresponds to exactly one [`DataType`] (see
/// [`Value::data_type`]). String and byte payloads are held in a [`Cow`], so
/// a value can either borrow from data living for `'a` or own its contents.
///
/// Constructing a variant performs no checks; call [`Value::validate`] to
/// verify the constraints noted on the individual variants.
#[derive(Debug, Clone, PartialEq)]
pub enum Value<'a> {
/// A boolean.
Boolean(bool),
/// A signed 64-bit integer with an optional `unit` id.
Integer {
value: i64,
unit: Option<Id>,
},
/// A 64-bit float with an optional `unit` id.
/// `validate` rejects NaN; infinities are accepted.
Float {
value: f64,
unit: Option<Id>,
},
/// A decimal given as `mantissa` and `exponent`, with an optional `unit` id.
/// NOTE(review): presumably the value is `mantissa * 10^exponent` — confirm.
/// `validate` requires `exponent == 0` when the mantissa is zero and rejects
/// a non-zero mantissa that `DecimalMantissa::has_trailing_zeros` flags.
Decimal {
exponent: i32,
mantissa: DecimalMantissa<'a>,
unit: Option<Id>,
},
/// A string with an optional `language` id.
Text {
value: Cow<'a, str>,
language: Option<Id>,
},
/// Arbitrary bytes.
Bytes(Cow<'a, [u8]>),
/// A date kept as a string; `validate` checks it with `parse_date_rfc3339`.
Date(Cow<'a, str>),
/// A time kept as a string; `validate` checks it with `parse_time_rfc3339`.
Time(Cow<'a, str>),
/// A datetime kept as a string; `validate` checks it with
/// `parse_datetime_rfc3339`.
Datetime(Cow<'a, str>),
/// A schedule kept as a string; `validate` does not inspect its contents.
Schedule(Cow<'a, str>),
/// A geographic point with optional altitude.
/// `validate` requires `lat` in `[-90, 90]`, `lon` in `[-180, 180]`, and no
/// NaN in any of the three fields.
Point {
lat: f64,
lon: f64,
alt: Option<f64>,
},
/// A rectangle given by minimum and maximum latitude/longitude.
/// `validate` applies the same latitude/longitude ranges as for `Point` to
/// all four fields and rejects NaN; it does not compare `min_*` with `max_*`.
Rect {
min_lat: f64,
min_lon: f64,
max_lat: f64,
max_lon: f64,
},
/// An embedding vector stored as raw bytes.
/// `validate` requires `data.len()` to equal `sub_type.bytes_for_dims(dims)`
/// and, for `Float32`, reads `data` as little-endian 4-byte floats and
/// rejects NaN.
Embedding {
sub_type: EmbeddingSubType,
dims: usize,
data: Cow<'a, [u8]>,
},
}
impl Value<'_> {
    /// Returns the [`DataType`] tag matching this value's variant.
    pub fn data_type(&self) -> DataType {
        match self {
            // Tuple-style variants.
            Value::Boolean(_) => DataType::Boolean,
            Value::Bytes(_) => DataType::Bytes,
            Value::Date(_) => DataType::Date,
            Value::Time(_) => DataType::Time,
            Value::Datetime(_) => DataType::Datetime,
            Value::Schedule(_) => DataType::Schedule,
            // Struct-style variants.
            Value::Integer { .. } => DataType::Integer,
            Value::Float { .. } => DataType::Float,
            Value::Decimal { .. } => DataType::Decimal,
            Value::Text { .. } => DataType::Text,
            Value::Point { .. } => DataType::Point,
            Value::Rect { .. } => DataType::Rect,
            Value::Embedding { .. } => DataType::Embedding,
        }
    }

    /// Checks the per-variant constraints of this value.
    ///
    /// Returns `None` when the value is acceptable, otherwise a static
    /// message describing the first violated constraint. `Boolean`,
    /// `Integer`, `Text`, `Bytes` and `Schedule` have no constraints and
    /// always yield `None`.
    pub fn validate(&self) -> Option<&'static str> {
        const BAD_LATITUDE: &str = "latitude out of range [-90, +90]";
        const BAD_LONGITUDE: &str = "longitude out of range [-180, +180]";

        // Written as two comparisons on purpose: both are false for NaN, so a
        // NaN coordinate falls through to the dedicated NaN message below
        // instead of being reported as out of range.
        let lat_out_of_range = |deg: f64| deg < -90.0 || deg > 90.0;
        let lon_out_of_range = |deg: f64| deg < -180.0 || deg > 180.0;

        match self {
            Value::Boolean(_)
            | Value::Integer { .. }
            | Value::Text { .. }
            | Value::Bytes(_)
            | Value::Schedule(_) => None,

            Value::Float { value, .. } => value.is_nan().then_some("NaN is not allowed in Float"),

            Value::Decimal { exponent, mantissa, .. } => {
                if mantissa.is_zero() {
                    (*exponent != 0).then_some("zero DECIMAL must have exponent 0")
                } else {
                    mantissa
                        .has_trailing_zeros()
                        .then_some("DECIMAL mantissa has trailing zeros (not normalized)")
                }
            }

            Value::Point { lat, lon, alt } => {
                if lat_out_of_range(*lat) {
                    Some(BAD_LATITUDE)
                } else if lon_out_of_range(*lon) {
                    Some(BAD_LONGITUDE)
                } else if lat.is_nan() || lon.is_nan() {
                    Some("NaN is not allowed in Point coordinates")
                } else if matches!(alt, Some(height) if height.is_nan()) {
                    Some("NaN is not allowed in Point altitude")
                } else {
                    None
                }
            }

            Value::Rect { min_lat, min_lon, max_lat, max_lon } => {
                let latitudes = [*min_lat, *max_lat];
                let longitudes = [*min_lon, *max_lon];
                if latitudes.iter().any(|deg| lat_out_of_range(*deg)) {
                    Some(BAD_LATITUDE)
                } else if longitudes.iter().any(|deg| lon_out_of_range(*deg)) {
                    Some(BAD_LONGITUDE)
                } else if latitudes.iter().chain(longitudes.iter()).any(|deg| deg.is_nan()) {
                    Some("NaN is not allowed in Rect coordinates")
                } else {
                    None
                }
            }

            Value::Date(s) => parse_date_rfc3339(s)
                .is_err()
                .then_some("Invalid RFC 3339 date format"),
            Value::Time(s) => parse_time_rfc3339(s)
                .is_err()
                .then_some("Invalid RFC 3339 time format"),
            Value::Datetime(s) => parse_datetime_rfc3339(s)
                .is_err()
                .then_some("Invalid RFC 3339 datetime format"),

            Value::Embedding { sub_type, dims, data } => {
                if data.len() != sub_type.bytes_for_dims(*dims) {
                    return Some("embedding data length doesn't match dims");
                }
                // Only the float32 encoding can contain NaN; its elements are
                // decoded as consecutive little-endian 4-byte groups.
                let contains_nan = matches!(sub_type, EmbeddingSubType::Float32)
                    && data.chunks_exact(4).any(|quad| {
                        f32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]).is_nan()
                    });
                contains_nan.then_some("NaN is not allowed in float32 embedding")
            }
        }
    }
}
/// A [`Value`] paired with the `Id` of the property it belongs to.
///
/// The lifetime `'a` is the one carried by the contained [`Value`].
#[derive(Debug, Clone, PartialEq)]
pub struct PropertyValue<'a> {
/// Id of the property.
pub property: Id,
/// The value held for that property.
pub value: Value<'a>,
}
/// A property: an `Id` together with a [`DataType`].
/// NOTE(review): presumably `data_type` is the type that values of this
/// property must have — nothing in this file enforces that; confirm at the
/// call sites.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Property {
/// Id of the property.
pub id: Id,
/// Data type associated with the property.
pub data_type: DataType,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Storage size per sub-type for a handful of dimension counts, including
    /// the round-up behaviour of the binary encoding.
    #[test]
    fn test_embedding_bytes_for_dims() {
        let cases = [
            (EmbeddingSubType::Float32, 10, 40),
            (EmbeddingSubType::Int8, 10, 10),
            (EmbeddingSubType::Binary, 10, 2),
            (EmbeddingSubType::Binary, 8, 1),
            (EmbeddingSubType::Binary, 9, 2),
        ];
        for (sub_type, dims, expected) in cases {
            assert_eq!(sub_type.bytes_for_dims(dims), expected);
        }
    }

    /// NaN floats are rejected; infinities and ordinary numbers pass.
    #[test]
    fn test_value_validation_nan() {
        let float = |value: f64| Value::Float { value, unit: None };

        assert!(float(f64::NAN).validate().is_some());
        for accepted in [f64::INFINITY, -f64::INFINITY, 42.0] {
            assert!(float(accepted).validate().is_none());
        }
    }

    /// Latitude/longitude bounds are inclusive and a NaN altitude is rejected.
    #[test]
    fn test_value_validation_point() {
        let point = |lat: f64, lon: f64, alt: Option<f64>| Value::Point { lat, lon, alt };

        for (lat, lon) in [(91.0, 0.0), (-91.0, 0.0), (0.0, 181.0), (0.0, -181.0)] {
            assert!(point(lat, lon, None).validate().is_some());
        }
        for (lat, lon) in [(90.0, 180.0), (-90.0, -180.0)] {
            assert!(point(lat, lon, None).validate().is_none());
        }
        assert!(point(0.0, 0.0, Some(1000.0)).validate().is_none());
        assert!(point(0.0, 0.0, Some(f64::NAN)).validate().is_some());
    }

    /// Each corner coordinate is range-checked, and NaN is rejected.
    #[test]
    fn test_value_validation_rect() {
        let rect = |min_lat: f64, min_lon: f64, max_lat: f64, max_lon: f64| Value::Rect {
            min_lat,
            min_lon,
            max_lat,
            max_lon,
        };

        let rejected = [
            (-91.0, 0.0, 0.0, 0.0),
            (0.0, 0.0, 91.0, 0.0),
            (0.0, -181.0, 0.0, 0.0),
            (0.0, 0.0, 0.0, 181.0),
        ];
        for (min_lat, min_lon, max_lat, max_lon) in rejected {
            assert!(rect(min_lat, min_lon, max_lat, max_lon).validate().is_some());
        }
        assert!(rect(24.5, -125.0, 49.4, -66.9).validate().is_none());
        assert!(rect(f64::NAN, 0.0, 0.0, 0.0).validate().is_some());
    }

    /// Zero must use exponent 0, and mantissas may not end in a zero digit.
    #[test]
    fn test_decimal_normalization() {
        let decimal = |exponent: i32, mantissa: i64| Value::Decimal {
            exponent,
            mantissa: DecimalMantissa::I64(mantissa),
            unit: None,
        };

        // Zero with a non-zero exponent.
        assert!(decimal(1, 0).validate().is_some());
        // Mantissa ending in a zero digit.
        assert!(decimal(0, 1230).validate().is_some());
        // Properly normalized value.
        assert!(decimal(-2, 1234).validate().is_none());
    }
}