use serde::{Deserialize, Serialize};
use std::fmt;
/// Logical type of a value: scalars, temporal types, composite containers
/// (list, map, struct), graph elements, and fixed-dimension vectors.
///
/// Type relationships are defined by `can_coerce_from` and `common_type`;
/// the `Display` impl renders the upper-case spelling (e.g. `LIST<STRING>`).
///
/// `#[non_exhaustive]`: code outside this crate must include a wildcard arm
/// when matching on this enum, so new variants can be added compatibly.
/// NOTE(review): with index-based serde formats (e.g. bincode) the variant
/// order is part of the encoded form — append new variants rather than
/// reordering existing ones.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum LogicalType {
/// Wildcard type: every type coerces into it, and `common_type` resolves it
/// to the other operand. This is also the `Default` value.
Any,
/// The null type. It coerces into every type; `common_type` resolves it to
/// the other operand, except against `Any`, where the result stays `Null`.
Null,
/// Boolean.
Bool,
/// 8-bit integer; widens to the larger integer types and to both floats.
Int8,
/// 16-bit integer; widens to `Int32`, `Int64` and to both floats.
Int16,
/// 32-bit integer; widens to `Int64` and to both floats.
Int32,
/// 64-bit integer; widens only to `Float64`.
Int64,
/// 32-bit floating point; widens to `Float64`.
Float32,
/// 64-bit floating point.
Float64,
/// Character string.
String,
/// Byte string.
Bytes,
/// Calendar date.
Date,
/// Time without a zone; widens to `ZonedTime`.
Time,
/// Date-time without a zone; widens to `ZonedDatetime`.
Timestamp,
/// Length of time.
Duration,
/// Time with a zone (displayed as `ZONED TIME`).
ZonedTime,
/// Date-time with a zone (displayed as `ZONED DATETIME`).
ZonedDatetime,
/// List whose elements all have the boxed element type.
List(Box<LogicalType>),
/// Map from `key` type to `value` type.
Map {
/// Type of the map keys.
key: Box<LogicalType>,
/// Type of the map values.
value: Box<LogicalType>,
},
/// Record of named fields. Field order is significant: it affects both
/// equality and the rendered `STRUCT<...>` text.
Struct(Vec<(String, LogicalType)>),
/// Graph node.
Node,
/// Graph edge.
Edge,
/// Graph path.
Path,
/// Vector whose payload is its dimension count (see `vector_dimensions`).
/// It is not numeric, and equals only vectors of the same dimension.
Vector(usize),
}
impl LogicalType {
#[must_use]
pub const fn is_numeric(&self) -> bool {
matches!(
self,
LogicalType::Int8
| LogicalType::Int16
| LogicalType::Int32
| LogicalType::Int64
| LogicalType::Float32
| LogicalType::Float64
)
}
#[must_use]
pub const fn is_integer(&self) -> bool {
matches!(
self,
LogicalType::Int8 | LogicalType::Int16 | LogicalType::Int32 | LogicalType::Int64
)
}
#[must_use]
pub const fn is_float(&self) -> bool {
matches!(self, LogicalType::Float32 | LogicalType::Float64)
}
#[must_use]
pub const fn is_temporal(&self) -> bool {
matches!(
self,
LogicalType::Date
| LogicalType::Time
| LogicalType::Timestamp
| LogicalType::Duration
| LogicalType::ZonedTime
| LogicalType::ZonedDatetime
)
}
#[must_use]
pub const fn is_graph_element(&self) -> bool {
matches!(
self,
LogicalType::Node | LogicalType::Edge | LogicalType::Path
)
}
#[must_use]
pub const fn is_nullable(&self) -> bool {
true
}
#[must_use]
pub fn list_element_type(&self) -> Option<&LogicalType> {
match self {
LogicalType::List(elem) => Some(elem),
_ => None,
}
}
#[must_use]
pub const fn is_vector(&self) -> bool {
matches!(self, LogicalType::Vector(_))
}
#[must_use]
pub const fn vector_dimensions(&self) -> Option<usize> {
match self {
LogicalType::Vector(dim) => Some(*dim),
_ => None,
}
}
#[must_use]
pub fn can_coerce_from(&self, other: &LogicalType) -> bool {
if self == other {
return true;
}
if matches!(self, LogicalType::Any) {
return true;
}
if matches!(other, LogicalType::Null) && self.is_nullable() {
return true;
}
match (other, self) {
(LogicalType::Int8, LogicalType::Int16 | LogicalType::Int32 | LogicalType::Int64) => {
true
}
(LogicalType::Int16, LogicalType::Int32 | LogicalType::Int64) => true,
(LogicalType::Int32, LogicalType::Int64) => true,
(LogicalType::Float32, LogicalType::Float64) => true,
(
LogicalType::Int8 | LogicalType::Int16 | LogicalType::Int32,
LogicalType::Float32 | LogicalType::Float64,
) => true,
(LogicalType::Int64, LogicalType::Float64) => true,
(LogicalType::Time, LogicalType::ZonedTime) => true,
(LogicalType::Timestamp, LogicalType::ZonedDatetime) => true,
_ => false,
}
}
#[must_use]
pub fn common_type(&self, other: &LogicalType) -> Option<LogicalType> {
if self == other {
return Some(self.clone());
}
if matches!(self, LogicalType::Any) {
return Some(other.clone());
}
if matches!(other, LogicalType::Any) {
return Some(self.clone());
}
if matches!(self, LogicalType::Null) {
return Some(other.clone());
}
if matches!(other, LogicalType::Null) {
return Some(self.clone());
}
if self.is_numeric() && other.is_numeric() {
if self.is_float() || other.is_float() {
return Some(LogicalType::Float64);
}
return Some(LogicalType::Int64);
}
match (self, other) {
(LogicalType::Time, LogicalType::ZonedTime)
| (LogicalType::ZonedTime, LogicalType::Time) => {
return Some(LogicalType::ZonedTime);
}
(LogicalType::Timestamp, LogicalType::ZonedDatetime)
| (LogicalType::ZonedDatetime, LogicalType::Timestamp) => {
return Some(LogicalType::ZonedDatetime);
}
_ => {}
}
None
}
}
impl fmt::Display for LogicalType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
LogicalType::Any => write!(f, "ANY"),
LogicalType::Null => write!(f, "NULL"),
LogicalType::Bool => write!(f, "BOOL"),
LogicalType::Int8 => write!(f, "INT8"),
LogicalType::Int16 => write!(f, "INT16"),
LogicalType::Int32 => write!(f, "INT32"),
LogicalType::Int64 => write!(f, "INT64"),
LogicalType::Float32 => write!(f, "FLOAT32"),
LogicalType::Float64 => write!(f, "FLOAT64"),
LogicalType::String => write!(f, "STRING"),
LogicalType::Bytes => write!(f, "BYTES"),
LogicalType::Date => write!(f, "DATE"),
LogicalType::Time => write!(f, "TIME"),
LogicalType::Timestamp => write!(f, "TIMESTAMP"),
LogicalType::Duration => write!(f, "DURATION"),
LogicalType::ZonedTime => write!(f, "ZONED TIME"),
LogicalType::ZonedDatetime => write!(f, "ZONED DATETIME"),
LogicalType::List(elem) => write!(f, "LIST<{elem}>"),
LogicalType::Map { key, value } => write!(f, "MAP<{key}, {value}>"),
LogicalType::Struct(fields) => {
write!(f, "STRUCT<")?;
for (i, (name, ty)) in fields.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{name}: {ty}")?;
}
write!(f, ">")
}
LogicalType::Node => write!(f, "NODE"),
LogicalType::Edge => write!(f, "EDGE"),
LogicalType::Path => write!(f, "PATH"),
LogicalType::Vector(dim) => write!(f, "VECTOR({dim})"),
}
}
}
impl Default for LogicalType {
fn default() -> Self {
LogicalType::Any
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_numeric_checks() {
        assert!(LogicalType::Int64.is_numeric());
        assert!(LogicalType::Float64.is_numeric());
        assert!(!LogicalType::String.is_numeric());
        assert!(LogicalType::Int64.is_integer());
        assert!(!LogicalType::Float64.is_integer());
        assert!(LogicalType::Float64.is_float());
        assert!(!LogicalType::Int64.is_float());
    }

    /// Temporal, graph-element and nullability predicates.
    #[test]
    fn test_category_checks() {
        assert!(LogicalType::Date.is_temporal());
        assert!(LogicalType::Duration.is_temporal());
        assert!(LogicalType::ZonedDatetime.is_temporal());
        assert!(!LogicalType::Int64.is_temporal());
        assert!(LogicalType::Node.is_graph_element());
        assert!(LogicalType::Edge.is_graph_element());
        assert!(LogicalType::Path.is_graph_element());
        assert!(!LogicalType::Struct(Vec::new()).is_graph_element());
        assert!(LogicalType::Bytes.is_nullable());
        assert!(LogicalType::Null.is_nullable());
    }

    #[test]
    fn test_list_element_type() {
        let list = LogicalType::List(Box::new(LogicalType::Int32));
        assert_eq!(list.list_element_type(), Some(&LogicalType::Int32));
        assert_eq!(LogicalType::Int32.list_element_type(), None);
        assert_eq!(LogicalType::Vector(3).list_element_type(), None);
    }

    #[test]
    fn test_coercion() {
        assert!(LogicalType::Int64.can_coerce_from(&LogicalType::Int64));
        assert!(LogicalType::Int64.can_coerce_from(&LogicalType::Null));
        assert!(LogicalType::String.can_coerce_from(&LogicalType::Null));
        assert!(LogicalType::Int64.can_coerce_from(&LogicalType::Int32));
        assert!(LogicalType::Int32.can_coerce_from(&LogicalType::Int16));
        assert!(!LogicalType::Int32.can_coerce_from(&LogicalType::Int64));
        assert!(LogicalType::Float64.can_coerce_from(&LogicalType::Float32));
        assert!(LogicalType::Float64.can_coerce_from(&LogicalType::Int64));
        assert!(LogicalType::Float32.can_coerce_from(&LogicalType::Int32));
    }

    /// `Any` targets, rejected narrowings, temporal widening and composite identity.
    #[test]
    fn test_coercion_edge_cases() {
        assert!(LogicalType::Any.can_coerce_from(&LogicalType::Int64));
        assert!(LogicalType::Any.can_coerce_from(&LogicalType::Null));
        assert!(!LogicalType::Float32.can_coerce_from(&LogicalType::Int64));
        assert!(!LogicalType::Float32.can_coerce_from(&LogicalType::Float64));
        assert!(!LogicalType::Int64.can_coerce_from(&LogicalType::Float64));
        assert!(LogicalType::ZonedTime.can_coerce_from(&LogicalType::Time));
        assert!(!LogicalType::Time.can_coerce_from(&LogicalType::ZonedTime));
        assert!(LogicalType::ZonedDatetime.can_coerce_from(&LogicalType::Timestamp));
        assert!(!LogicalType::Timestamp.can_coerce_from(&LogicalType::Date));
        assert!(!LogicalType::Int64.can_coerce_from(&LogicalType::String));
        assert!(!LogicalType::Vector(384).can_coerce_from(&LogicalType::Vector(768)));
        let ints = LogicalType::List(Box::new(LogicalType::Int64));
        assert!(ints.can_coerce_from(&ints));
        assert!(ints.can_coerce_from(&LogicalType::Null));
    }

    #[test]
    fn test_common_type() {
        assert_eq!(
            LogicalType::Int64.common_type(&LogicalType::Int64),
            Some(LogicalType::Int64)
        );
        assert_eq!(
            LogicalType::Int32.common_type(&LogicalType::Int64),
            Some(LogicalType::Int64)
        );
        assert_eq!(
            LogicalType::Int64.common_type(&LogicalType::Float64),
            Some(LogicalType::Float64)
        );
        assert_eq!(
            LogicalType::Null.common_type(&LogicalType::String),
            Some(LogicalType::String)
        );
        assert_eq!(LogicalType::String.common_type(&LogicalType::Int64), None);
    }

    /// `Any`/`Null` resolution, small-int promotion and temporal unification.
    #[test]
    fn test_common_type_edge_cases() {
        assert_eq!(
            LogicalType::Any.common_type(&LogicalType::Bool),
            Some(LogicalType::Bool)
        );
        assert_eq!(
            LogicalType::Bool.common_type(&LogicalType::Any),
            Some(LogicalType::Bool)
        );
        assert_eq!(
            LogicalType::Any.common_type(&LogicalType::Null),
            Some(LogicalType::Null)
        );
        assert_eq!(
            LogicalType::String.common_type(&LogicalType::Null),
            Some(LogicalType::String)
        );
        assert_eq!(
            LogicalType::Int8.common_type(&LogicalType::Int16),
            Some(LogicalType::Int64)
        );
        assert_eq!(
            LogicalType::Float32.common_type(&LogicalType::Int8),
            Some(LogicalType::Float64)
        );
        assert_eq!(
            LogicalType::Time.common_type(&LogicalType::ZonedTime),
            Some(LogicalType::ZonedTime)
        );
        assert_eq!(
            LogicalType::ZonedTime.common_type(&LogicalType::Time),
            Some(LogicalType::ZonedTime)
        );
        assert_eq!(
            LogicalType::ZonedDatetime.common_type(&LogicalType::Timestamp),
            Some(LogicalType::ZonedDatetime)
        );
        assert_eq!(LogicalType::Date.common_type(&LogicalType::Timestamp), None);
        assert_eq!(LogicalType::Vector(3).common_type(&LogicalType::Vector(4)), None);
    }

    #[test]
    fn test_display() {
        assert_eq!(LogicalType::Int64.to_string(), "INT64");
        assert_eq!(
            LogicalType::List(Box::new(LogicalType::String)).to_string(),
            "LIST<STRING>"
        );
        assert_eq!(
            LogicalType::Map {
                key: Box::new(LogicalType::String),
                value: Box::new(LogicalType::Int64)
            }
            .to_string(),
            "MAP<STRING, INT64>"
        );
    }

    /// Multi-word keywords and `STRUCT` rendering (empty and with nested fields).
    #[test]
    fn test_display_zoned_and_struct() {
        assert_eq!(LogicalType::ZonedTime.to_string(), "ZONED TIME");
        assert_eq!(LogicalType::ZonedDatetime.to_string(), "ZONED DATETIME");
        assert_eq!(LogicalType::Struct(Vec::new()).to_string(), "STRUCT<>");
        let row = LogicalType::Struct(vec![
            ("id".to_string(), LogicalType::Int64),
            (
                "tags".to_string(),
                LogicalType::List(Box::new(LogicalType::String)),
            ),
        ]);
        assert_eq!(row.to_string(), "STRUCT<id: INT64, tags: LIST<STRING>>");
    }

    #[test]
    fn test_default_is_any() {
        assert_eq!(LogicalType::default(), LogicalType::Any);
    }

    #[test]
    fn test_vector_type() {
        let v384 = LogicalType::Vector(384);
        let v768 = LogicalType::Vector(768);
        let v1536 = LogicalType::Vector(1536);
        assert!(v384.is_vector());
        assert!(v768.is_vector());
        assert!(!LogicalType::Float64.is_vector());
        assert!(!LogicalType::List(Box::new(LogicalType::Float32)).is_vector());
        assert_eq!(v384.vector_dimensions(), Some(384));
        assert_eq!(v768.vector_dimensions(), Some(768));
        assert_eq!(v1536.vector_dimensions(), Some(1536));
        assert_eq!(LogicalType::Float64.vector_dimensions(), None);
        assert_eq!(v384.to_string(), "VECTOR(384)");
        assert_eq!(v768.to_string(), "VECTOR(768)");
        assert_eq!(v1536.to_string(), "VECTOR(1536)");
        assert_eq!(LogicalType::Vector(384), LogicalType::Vector(384));
        assert_ne!(LogicalType::Vector(384), LogicalType::Vector(768));
        assert!(!v384.is_numeric());
        assert!(!v384.is_integer());
        assert!(!v384.is_float());
    }
}