use serde::{Deserialize, Serialize};
/// Time unit carried by `RivetType::Time` and `RivetType::Timestamp`.
///
/// Serialized by serde as the snake_case variant name (e.g. `"microsecond"`),
/// which is the same string `TimeUnit::label()` returns for each variant.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TimeUnit {
Second,
Millisecond,
Microsecond,
Nanosecond,
}
impl TimeUnit {
#[allow(dead_code)]
pub fn label(self) -> &'static str {
match self {
TimeUnit::Second => "second",
TimeUnit::Millisecond => "millisecond",
TimeUnit::Microsecond => "microsecond",
TimeUnit::Nanosecond => "nanosecond",
}
}
}
/// A data type descriptor.
///
/// Serialized with serde as an internally tagged object: the snake_case
/// variant name is written to a `"kind"` field and any struct-variant fields
/// sit alongside it, e.g. `{"kind":"decimal","precision":10,"scale":3}`.
/// Unit variants serialize as just `{"kind":"..."}`.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum RivetType {
Bool,
Int16,
Int32,
Int64,
/// NOTE(review): serde's `snake_case` rule inserts `_` before each interior
/// uppercase letter, so this variant's `kind` tag is `"u_int64"`, whereas
/// `label()` returns `"uint64"`. Confirm this mismatch is intended before
/// relying on the two agreeing.
UInt64,
Float32,
Float64,
/// Decimal with the given `precision` and `scale`. `scale` is an `i8`, so
/// negative values are representable.
Decimal {
precision: u8,
scale: i8,
},
Date,
/// Time value expressed in `unit`.
Time { unit: TimeUnit },
/// Timestamp expressed in `unit`, with an optional timezone name.
/// `timezone: None` and `timezone: Some(..)` are distinct types: they
/// compare unequal and produce different `label()` output.
Timestamp {
unit: TimeUnit,
timezone: Option<String>,
},
String,
Text,
Binary,
Json,
Uuid,
Enum,
Interval,
/// List whose elements all have type `inner` (boxed because the enum is
/// recursive). `is_unsupported()` looks through it to the element type.
List { inner: Box<RivetType> },
/// Placeholder for a type that is not supported. `native_type` presumably
/// names the type in the originating system (TODO confirm); `reason`
/// explains why it is unsupported. `label()` shows only `native_type`.
Unsupported { native_type: String, reason: String },
}
impl RivetType {
    /// Returns a compact, human-readable name for this type.
    ///
    /// Parameterised variants embed their parameters:
    /// - `decimal(18,2)` for `Decimal { precision: 18, scale: 2 }`;
    /// - `time(millisecond)` for `Time`;
    /// - `timestamp(microsecond)` for a timestamp without a timezone and
    ///   `timestamp_tz(microsecond,UTC)` for one with a timezone;
    /// - `list<...>` for `List`, with the element type labelled recursively;
    /// - `unsupported(<native_type>)` for `Unsupported` (the `reason` is
    ///   deliberately not included).
    #[allow(dead_code)]
    pub fn label(&self) -> String {
        match self {
            RivetType::Bool => "bool".into(),
            RivetType::Int16 => "int16".into(),
            RivetType::Int32 => "int32".into(),
            RivetType::Int64 => "int64".into(),
            RivetType::UInt64 => "uint64".into(),
            RivetType::Float32 => "float32".into(),
            RivetType::Float64 => "float64".into(),
            RivetType::Decimal { precision, scale } => format!("decimal({precision},{scale})"),
            RivetType::Date => "date".into(),
            RivetType::Time { unit } => format!("time({})", unit.label()),
            RivetType::Timestamp {
                unit,
                timezone: None,
            } => format!("timestamp({})", unit.label()),
            RivetType::Timestamp {
                unit,
                timezone: Some(tz),
            } => format!("timestamp_tz({},{tz})", unit.label()),
            RivetType::String => "string".into(),
            RivetType::Text => "text".into(),
            RivetType::Binary => "binary".into(),
            RivetType::Json => "json".into(),
            RivetType::Uuid => "uuid".into(),
            RivetType::Enum => "enum".into(),
            RivetType::Interval => "interval".into(),
            RivetType::List { inner } => format!("list<{}>", inner.label()),
            RivetType::Unsupported { native_type, .. } => format!("unsupported({native_type})"),
        }
    }

    /// Returns `true` if this type is `Unsupported`, or is a `List` (at any
    /// nesting depth) whose innermost element type is `Unsupported`.
    #[allow(dead_code)]
    pub fn is_unsupported(&self) -> bool {
        match self {
            RivetType::Unsupported { .. } => true,
            RivetType::List { inner } => inner.is_unsupported(),
            // Listed exhaustively instead of a `_` catch-all: adding a new
            // variant (especially a container such as a map or struct that can
            // wrap an unsupported type) must force this match to be revisited
            // rather than silently reporting the new type as supported.
            RivetType::Bool
            | RivetType::Int16
            | RivetType::Int32
            | RivetType::Int64
            | RivetType::UInt64
            | RivetType::Float32
            | RivetType::Float64
            | RivetType::Decimal { .. }
            | RivetType::Date
            | RivetType::Time { .. }
            | RivetType::Timestamp { .. }
            | RivetType::String
            | RivetType::Text
            | RivetType::Binary
            | RivetType::Json
            | RivetType::Uuid
            | RivetType::Enum
            | RivetType::Interval => false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn label_includes_decimal_precision_and_scale() {
        let decimal = RivetType::Decimal {
            precision: 18,
            scale: 2,
        };
        assert_eq!(decimal.label(), "decimal(18,2)");
    }

    #[test]
    fn label_distinguishes_timestamp_with_and_without_timezone() {
        let unit = TimeUnit::Microsecond;
        let naive = RivetType::Timestamp {
            unit,
            timezone: None,
        };
        let zoned = RivetType::Timestamp {
            unit,
            timezone: Some("UTC".into()),
        };
        assert_eq!(naive.label(), "timestamp(microsecond)");
        assert_eq!(zoned.label(), "timestamp_tz(microsecond,UTC)");
        // Equality must take the timezone into account, not just the unit.
        assert_ne!(naive, zoned, "tz=None and tz=Some(\"UTC\") must NOT be equal");
    }

    #[test]
    fn unsupported_carries_actionable_context() {
        let unsupported = RivetType::Unsupported {
            native_type: "interval".into(),
            reason: "Arrow Interval mapping not implemented yet".into(),
        };
        assert!(unsupported.is_unsupported());
        assert_eq!(unsupported.label(), "unsupported(interval)");
    }

    #[test]
    fn json_serialization_uses_kind_tag() {
        let decimal = RivetType::Decimal {
            precision: 10,
            scale: 3,
        };
        // Round-trip through a generic JSON value so the assertions inspect
        // the wire shape rather than the Rust type.
        let encoded = serde_json::to_string(&decimal).expect("serialize");
        let value: serde_json::Value = serde_json::from_str(&encoded).expect("parse");
        assert_eq!(value["kind"], "decimal");
        assert_eq!(value["precision"], 10);
        assert_eq!(value["scale"], 3);
    }

    #[test]
    fn time_unit_labels_are_stable() {
        let cases = [
            (TimeUnit::Second, "second"),
            (TimeUnit::Millisecond, "millisecond"),
            (TimeUnit::Microsecond, "microsecond"),
            (TimeUnit::Nanosecond, "nanosecond"),
        ];
        for (unit, expected) in cases {
            assert_eq!(unit.label(), expected);
        }
    }
}