use std::sync::Arc;
use serde::{Deserialize, Deserializer, Serialize};
use smallvec::SmallVec;
use crate::db_string::DbString;
use crate::error::{CoreError, CoreResult};
use crate::extension_type_ids::ExtensionTypeId;
use crate::identity::{BindingTableId, EdgeId, GraphId, NodeId, RecordTypeId};
use crate::json_value::JsonValue;
/// Maximum number of components a [`VectorValue`] may hold (`u16::MAX`, i.e. 65 535).
pub const MAX_VECTOR_DIMENSION: usize = u16::MAX as usize;
/// A single dynamically typed value.
///
/// Equality is not derived; it is implemented by hand in the `PartialEq` impl in this file.
/// NOTE(review): the serde derive passes each variant's index to the serializer, so formats
/// that encode the index would break if variants were reordered — confirm before reordering.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[non_exhaustive]
pub enum Value {
    /// Boolean.
    Bool(bool),
    /// Signed 64-bit integer.
    Int(i64),
    /// Unsigned 64-bit integer.
    Uint(u64),
    /// Signed 128-bit integer, serialized as its 16 little-endian bytes (`serde_i128_le`).
    Int128(#[serde(with = "serde_i128_le")] i128),
    /// Unsigned 128-bit integer, serialized as its 16 little-endian bytes (`serde_u128_le`).
    Uint128(#[serde(with = "serde_u128_le")] u128),
    /// 64-bit float; `Value`'s `PartialEq` treats two NaNs as equal.
    Float(f64),
    /// 32-bit float; `Value`'s `PartialEq` treats two NaNs as equal.
    Float32(f32),
    /// `rust_decimal` decimal, serialized as its string form (`serde_decimal_str`).
    Decimal(#[serde(with = "serde_decimal_str")] rust_decimal::Decimal),
    /// String stored as a [`DbString`].
    String(DbString),
    /// Shared, immutable byte buffer.
    Bytes(Arc<[u8]>),
    /// Ordered list of values.
    List(Vec<Value>),
    /// Record with named fields (see [`Record`]).
    Record(Box<Record>),
    /// Record tied to a [`RecordTypeId`] (see [`RecordTyped`]).
    RecordTyped(Box<RecordTyped>),
    /// Path through a graph (see [`Path`]).
    Path(Box<Path>),
    /// Reference to a node by [`NodeId`].
    NodeRef(NodeId),
    /// Reference to an edge by [`EdgeId`].
    EdgeRef(EdgeId),
    /// Reference to a graph by [`GraphId`].
    GraphRef(GraphId),
    /// Reference to a binding table by [`BindingTableId`].
    TableRef(BindingTableId),
    /// Date-time with a time zone (boxed `jiff::Zoned`).
    ZonedDateTime(Box<jiff::Zoned>),
    /// Civil date-time without a time zone.
    LocalDateTime(jiff::civil::DateTime),
    /// Civil calendar date.
    Date(jiff::civil::Date),
    /// Time of day with a zone, stored as a full boxed `jiff::Zoned` (which includes a date).
    /// NOTE(review): how the date component is interpreted is not defined here — confirm.
    ZonedTime(Box<jiff::Zoned>),
    /// Civil time of day.
    LocalTime(jiff::civil::Time),
    /// Span of time (boxed `jiff::Span`); `Value`'s `PartialEq` compares it via `fieldwise()`.
    Duration(Box<jiff::Span>),
    /// Value of an extension type: the [`ExtensionTypeId`] plus a byte payload that this file
    /// does not interpret.
    Extended {
        type_id: ExtensionTypeId,
        payload: Arc<[u8]>,
    },
    /// The null value.
    Null,
    /// UUID.
    Uuid(uuid::Uuid),
    /// Validated `f32` vector (see [`VectorValue`]).
    Vector(VectorValue),
    /// JSON document (see [`JsonValue`]).
    Json(JsonValue),
}
// Compile-time guard: `Value` must stay at most 32 bytes. NOTE(review): presumably this is why
// the larger payloads (records, paths, zoned times, spans) are boxed — confirm.
const _: () = assert!(core::mem::size_of::<Value>() <= 32);
impl Value {
    /// One constructor per [`Value`] variant, in declaration order, each building a fixed
    /// sample of that variant.
    ///
    /// Every constructor panics only if one of its hard-coded fixtures is invalid, which would
    /// be a bug in this table; the `expect` messages name the offending fixture.
    pub const ALL: &[fn() -> Self] = &[
        || Self::Bool(false),
        || Self::Int(0),
        || Self::Uint(0),
        || Self::Int128(0),
        || Self::Uint128(0),
        || Self::Float(0.0),
        || Self::Float32(0.0),
        || Self::Decimal(rust_decimal::Decimal::ZERO),
        || Self::String(value_variant_string("value.all.string")),
        || Self::Bytes(Arc::from([0_u8])),
        || Self::List(Vec::new()),
        || Self::Record(Box::new(Record::Open(SmallVec::new()))),
        || {
            Self::RecordTyped(Box::new(RecordTyped {
                type_id: RecordTypeId::new(1),
                values: SmallVec::new(),
            }))
        },
        || {
            Self::Path(Box::new(Path {
                graph: GraphId::new(1),
                start: NodeId::new(1),
                segments: SmallVec::new(),
            }))
        },
        || Self::NodeRef(NodeId::new(1)),
        || Self::EdgeRef(EdgeId::new(1)),
        || Self::GraphRef(GraphId::new(1)),
        || Self::TableRef(BindingTableId::new(1)),
        || Self::ZonedDateTime(Box::new(value_variant_zoned())),
        || {
            Self::LocalDateTime(
                "2024-01-01T00:00:00"
                    .parse()
                    .expect("Value::ALL LocalDateTime fixture parses"),
            )
        },
        || Self::Date("2024-01-01".parse().expect("Value::ALL Date fixture parses")),
        || Self::ZonedTime(Box::new(value_variant_zoned())),
        || Self::LocalTime("00:00:00".parse().expect("Value::ALL LocalTime fixture parses")),
        || Self::Duration(Box::new("PT1S".parse().expect("Value::ALL Duration fixture parses"))),
        || Self::Extended {
            type_id: ExtensionTypeId::FIRST_PARTY_MIN,
            payload: Arc::from([0_u8]),
        },
        || Self::Null,
        || Self::Uuid(uuid::Uuid::nil()),
        || Self::Vector(VectorValue::new(vec![0.0]).expect("fixture vector is valid")),
        || {
            Self::Json(
                JsonValue::new(serde_json::json!({"fixture": true}))
                    .expect("Value::ALL Json fixture is valid"),
            )
        },
    ];

    /// Number of constructors in [`Value::ALL`], i.e. one per variant.
    pub const VARIANT_COUNT: usize = Self::ALL.len();

    /// Returns the variant's name exactly as declared on [`Value`] (e.g. `"Int"` for
    /// `Value::Int`), ignoring the payload.
    #[must_use]
    pub fn variant_name(&self) -> &'static str {
        match self {
            Self::Bool(_) => "Bool",
            Self::Int(_) => "Int",
            Self::Uint(_) => "Uint",
            Self::Int128(_) => "Int128",
            Self::Uint128(_) => "Uint128",
            Self::Float(_) => "Float",
            Self::Float32(_) => "Float32",
            Self::Decimal(_) => "Decimal",
            Self::String(_) => "String",
            Self::Bytes(_) => "Bytes",
            Self::List(_) => "List",
            Self::Record(_) => "Record",
            Self::RecordTyped(_) => "RecordTyped",
            Self::Path(_) => "Path",
            Self::NodeRef(_) => "NodeRef",
            Self::EdgeRef(_) => "EdgeRef",
            Self::GraphRef(_) => "GraphRef",
            Self::TableRef(_) => "TableRef",
            Self::ZonedDateTime(_) => "ZonedDateTime",
            Self::LocalDateTime(_) => "LocalDateTime",
            Self::Date(_) => "Date",
            Self::ZonedTime(_) => "ZonedTime",
            Self::LocalTime(_) => "LocalTime",
            Self::Duration(_) => "Duration",
            Self::Extended { .. } => "Extended",
            Self::Null => "Null",
            Self::Uuid(_) => "Uuid",
            Self::Vector(_) => "Vector",
            Self::Json(_) => "Json",
        }
    }
}
/// Turns a fixture name into the [`DbString`] used by the `Value::String` sample in
/// `Value::ALL`.
///
/// # Panics
/// If `crate::db_string` fails for `name`; fixture names are expected to be short, so a
/// failure here indicates a bug in the fixture.
fn value_variant_string(name: &str) -> DbString {
    let converted = crate::db_string(name);
    converted.expect("Value::ALL fixture strings fit DB string cap")
}
/// Builds the `jiff::Zoned` fixture shared by the `ZonedDateTime` and `ZonedTime` samples in
/// `Value::ALL`: the Unix epoch in UTC.
fn value_variant_zoned() -> jiff::Zoned {
    // `UNIX_EPOCH` is a constant, so unlike `Timestamp::new(0, 0)` there is no `Result` to
    // unwrap and no panic path.
    jiff::Timestamp::UNIX_EPOCH.to_zoned(jiff::tz::TimeZone::UTC)
}
impl PartialEq for Value {
    /// Structural equality.
    ///
    /// Values of different variants are never equal. Payloads of the same variant are compared
    /// with their own `==`, except:
    /// - `Float` / `Float32`: two NaNs compare equal (plain IEEE `==` would say they differ),
    ///   so a NaN-carrying `Value` equals itself; otherwise ordinary float `==` applies.
    /// - `Duration`: spans are compared via `jiff::Span::fieldwise`, i.e. unit by unit rather
    ///   than by total length.
    fn eq(&self, rhs: &Self) -> bool {
        // Deliberately exhaustive over `self` with no `_` arm: adding a `Value` variant must
        // fail to compile here instead of silently making the new variant unequal to itself.
        match self {
            Self::Bool(lhs) => matches!(rhs, Self::Bool(other) if lhs == other),
            Self::Int(lhs) => matches!(rhs, Self::Int(other) if lhs == other),
            Self::Uint(lhs) => matches!(rhs, Self::Uint(other) if lhs == other),
            Self::Int128(lhs) => matches!(rhs, Self::Int128(other) if lhs == other),
            Self::Uint128(lhs) => matches!(rhs, Self::Uint128(other) if lhs == other),
            Self::Float(lhs) => matches!(
                rhs,
                Self::Float(other) if lhs == other || (lhs.is_nan() && other.is_nan())
            ),
            Self::Float32(lhs) => matches!(
                rhs,
                Self::Float32(other) if lhs == other || (lhs.is_nan() && other.is_nan())
            ),
            Self::Decimal(lhs) => matches!(rhs, Self::Decimal(other) if lhs == other),
            Self::String(lhs) => matches!(rhs, Self::String(other) if lhs == other),
            Self::Bytes(lhs) => matches!(rhs, Self::Bytes(other) if lhs == other),
            Self::List(lhs) => matches!(rhs, Self::List(other) if lhs == other),
            Self::Record(lhs) => matches!(rhs, Self::Record(other) if lhs == other),
            Self::RecordTyped(lhs) => matches!(rhs, Self::RecordTyped(other) if lhs == other),
            Self::Path(lhs) => matches!(rhs, Self::Path(other) if lhs == other),
            Self::NodeRef(lhs) => matches!(rhs, Self::NodeRef(other) if lhs == other),
            Self::EdgeRef(lhs) => matches!(rhs, Self::EdgeRef(other) if lhs == other),
            Self::GraphRef(lhs) => matches!(rhs, Self::GraphRef(other) if lhs == other),
            Self::TableRef(lhs) => matches!(rhs, Self::TableRef(other) if lhs == other),
            Self::ZonedDateTime(lhs) => {
                matches!(rhs, Self::ZonedDateTime(other) if lhs == other)
            }
            Self::LocalDateTime(lhs) => {
                matches!(rhs, Self::LocalDateTime(other) if lhs == other)
            }
            Self::Date(lhs) => matches!(rhs, Self::Date(other) if lhs == other),
            Self::ZonedTime(lhs) => matches!(rhs, Self::ZonedTime(other) if lhs == other),
            Self::LocalTime(lhs) => matches!(rhs, Self::LocalTime(other) if lhs == other),
            Self::Duration(lhs) => matches!(
                rhs,
                Self::Duration(other) if lhs.fieldwise() == other.fieldwise()
            ),
            Self::Extended { type_id, payload } => matches!(
                rhs,
                Self::Extended {
                    type_id: other_type_id,
                    payload: other_payload,
                } if type_id == other_type_id && payload == other_payload
            ),
            Self::Null => matches!(rhs, Self::Null),
            Self::Uuid(lhs) => matches!(rhs, Self::Uuid(other) if lhs == other),
            Self::Vector(lhs) => matches!(rhs, Self::Vector(other) if lhs == other),
            Self::Json(lhs) => matches!(rhs, Self::Json(other) if lhs == other),
        }
    }
}
/// Dense vector of `f32` components held in a shared `Arc<[f32]>` (clones bump a refcount).
///
/// Values built through [`VectorValue::new`] (and therefore `TryFrom<Vec<f32>>` and
/// `Deserialize`) are non-empty, have at most [`MAX_VECTOR_DIMENSION`] components, and contain
/// only finite numbers. Serialization is transparent: the components are written as a plain
/// sequence.
#[derive(Clone, Debug, PartialEq, Serialize)]
#[serde(transparent)]
pub struct VectorValue {
    components: Arc<[f32]>,
}
impl VectorValue {
    /// Validates `components` and wraps them in a shared, immutable buffer.
    ///
    /// # Errors
    /// Checks run in this order and the first failure is returned:
    /// - [`CoreError::VectorEmpty`] if there are no components;
    /// - [`CoreError::VectorTooLarge`] if there are more than [`MAX_VECTOR_DIMENSION`];
    /// - [`CoreError::VectorComponentNotFinite`] for the first NaN or infinite component,
    ///   carrying its index and value.
    pub fn new(components: impl Into<Vec<f32>>) -> CoreResult<Self> {
        let components: Vec<f32> = components.into();
        let len = components.len();
        if len == 0 {
            return Err(CoreError::VectorEmpty);
        }
        if len > MAX_VECTOR_DIMENSION {
            return Err(CoreError::VectorTooLarge {
                got: len,
                max: MAX_VECTOR_DIMENSION,
            });
        }
        let first_non_finite = components
            .iter()
            .copied()
            .enumerate()
            .find(|(_, component)| !component.is_finite());
        if let Some((index, value)) = first_non_finite {
            return Err(CoreError::VectorComponentNotFinite { index, value });
        }
        Ok(Self {
            components: components.into(),
        })
    }

    /// Number of components in the vector.
    #[must_use]
    pub fn dimension(&self) -> usize {
        self.as_slice().len()
    }

    /// Borrows the components as a slice.
    #[must_use]
    pub fn as_slice(&self) -> &[f32] {
        self.components.as_ref()
    }

    /// Returns another handle to the shared component buffer (a refcount bump, no copy).
    #[must_use]
    pub fn as_arc(&self) -> Arc<[f32]> {
        Arc::clone(&self.components)
    }
}
impl TryFrom<Vec<f32>> for VectorValue {
type Error = CoreError;
fn try_from(value: Vec<f32>) -> Result<Self, Self::Error> {
Self::new(value)
}
}
/// Copies the components into a freshly allocated `Vec<f32>`; the shared buffer is not reused.
impl From<VectorValue> for Vec<f32> {
    fn from(vector: VectorValue) -> Self {
        let components: &[f32] = vector.as_slice();
        components.to_vec()
    }
}
impl<'de> Deserialize<'de> for VectorValue {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
Vec::<f32>::deserialize(deserializer)
.and_then(|components| Self::new(components).map_err(serde::de::Error::custom))
}
}
/// Record value. Only the open (named-field) form exists at present; the enum is
/// `#[non_exhaustive]` so other forms can be added.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[non_exhaustive]
pub enum Record {
    /// Ordered `(field name, value)` pairs; up to four pairs are stored inline by `SmallVec`.
    Open(SmallVec<[(DbString, Value); 4]>),
}
/// Record tied to a record type id, holding one optional value per slot.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct RecordTyped {
    // Identifies the record type the slots belong to.
    pub type_id: RecordTypeId,
    // Slot values in order (up to four stored inline). NOTE(review): presumably slot i maps to
    // field i of `type_id`, and `None` means "unset" as distinct from `Some(Value::Null)` —
    // confirm.
    pub values: SmallVec<[Option<Value>; 4]>,
}
/// A path in one graph: a start node followed by a sequence of [`PathSegment`]s.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct Path {
    // Graph the path belongs to.
    pub graph: GraphId,
    // First node of the path.
    pub start: NodeId,
    // Hops after `start`, in order (up to four stored inline); may be empty, as in the
    // `Value::ALL` fixture.
    pub segments: SmallVec<[PathSegment; 4]>,
}
/// One hop of a [`Path`]: an edge, the direction it is traversed in, and a node.
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
pub struct PathSegment {
    // Edge traversed by this hop.
    pub edge: EdgeId,
    // Direction of traversal along `edge`.
    pub direction: EdgeDirection,
    // NOTE(review): presumably the node reached at the end of this hop — confirm.
    pub node: NodeId,
}
/// Direction in which a [`PathSegment`] traverses its edge.
/// NOTE(review): `Outgoing`/`Incoming` presumably mean following the edge from/to the
/// previous node, and `Undirected` an edge without orientation — confirm.
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
pub enum EdgeDirection {
    Outgoing,
    Incoming,
    Undirected,
}
/// `#[serde(with = "serde_i128_le")]` adapter: an `i128` is written as its 16 little-endian
/// bytes, using serde's `[u8; 16]` implementation, and read back the same way.
mod serde_i128_le {
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    pub(super) fn serialize<S>(value: &i128, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let bytes: [u8; 16] = value.to_le_bytes();
        bytes.serialize(serializer)
    }

    pub(super) fn deserialize<'de, D>(deserializer: D) -> Result<i128, D::Error>
    where
        D: Deserializer<'de>,
    {
        let bytes = <[u8; 16]>::deserialize(deserializer)?;
        Ok(i128::from_le_bytes(bytes))
    }
}
/// `#[serde(with = "serde_u128_le")]` adapter: a `u128` is written as its 16 little-endian
/// bytes, using serde's `[u8; 16]` implementation, and read back the same way.
mod serde_u128_le {
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    pub(super) fn serialize<S>(value: &u128, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let bytes: [u8; 16] = value.to_le_bytes();
        bytes.serialize(serializer)
    }

    pub(super) fn deserialize<'de, D>(deserializer: D) -> Result<u128, D::Error>
    where
        D: Deserializer<'de>,
    {
        let bytes = <[u8; 16]>::deserialize(deserializer)?;
        Ok(u128::from_le_bytes(bytes))
    }
}
/// `#[serde(with = "serde_decimal_str")]` adapter: a `rust_decimal::Decimal` is written as
/// its `Display` string and parsed back with `FromStr`.
mod serde_decimal_str {
    use std::str::FromStr;

    use serde::{Deserialize, Deserializer, Serializer};

    /// Writes `value` as a string.
    pub(super) fn serialize<S>(
        value: &rust_decimal::Decimal,
        serializer: S,
    ) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // `collect_str` lets serializers that override it format straight into their output
        // instead of allocating a temporary `String`. Its default implementation is
        // `serialize_str(&value.to_string())`, so the encoded output is unchanged.
        serializer.collect_str(value)
    }

    /// Reads a string and parses it as a decimal.
    ///
    /// # Errors
    /// Fails if the input is not a string, or if `Decimal::from_str` rejects it; the parse
    /// error is reported through the deserializer's custom error.
    pub(super) fn deserialize<'de, D>(deserializer: D) -> Result<rust_decimal::Decimal, D::Error>
    where
        D: Deserializer<'de>,
    {
        let value = String::deserialize(deserializer)?;
        rust_decimal::Decimal::from_str(&value).map_err(serde::de::Error::custom)
    }
}
// Unit tests for this module, declared as an out-of-line `tests` submodule and compiled
// only for test builds.
#[cfg(test)]
mod tests;