use std::collections::HashMap;
use arrow::datatypes::{DataType, Field, TimeUnit as ArrowTimeUnit};
use arrow_schema::extension::{Json as ArrowJson, Uuid as ArrowUuid};
use serde::Serialize;
use std::sync::Arc;
use super::{RivetType, SourceColumn, TimeUnit, TypeFidelity};
/// Field-metadata key under which `build_arrow_field` stores the mapping's
/// `source_native_type`.
pub const META_NATIVE_TYPE: &str = "rivet.native_type";
/// Field-metadata key for the logical type label (`"json"`, `"uuid"`,
/// `"enum"`, `"interval"`); `build_arrow_field` only sets it when
/// `logical_type_label` returns a label for the column's `RivetType`.
pub const META_LOGICAL_TYPE: &str = "rivet.logical_type";
/// Field-metadata key under which `build_arrow_field` stores
/// `TypeFidelity::label()` for the column.
pub const META_FIDELITY: &str = "rivet.fidelity";
/// The resolved mapping of one source column onto a `RivetType` and, where
/// one exists, an Arrow `DataType`.
///
/// The field descriptions below describe the values produced by
/// `TypeMapping::from_source`; the fields are public, so other constructors
/// may populate them differently.
#[derive(Debug, Clone, Serialize)]
pub struct TypeMapping {
/// Column name, copied from `SourceColumn::name`.
pub column_name: String,
/// Native type string, copied from `SourceColumn::native_type`.
pub source_native_type: String,
/// The Rivet type chosen for this column.
pub rivet_type: RivetType,
/// Arrow type from `rivet_type_to_arrow`; `None` when the Rivet type has no
/// Arrow mapping. Serialized as the `Debug` rendering of the `DataType`.
#[serde(serialize_with = "serialize_arrow_type_opt")]
pub arrow_type: Option<DataType>,
/// Fidelity classification from `derive_fidelity`.
pub fidelity: TypeFidelity,
/// Nullability, copied from `SourceColumn::nullable`.
pub nullable: bool,
/// Free-form warnings; empty after `from_source`, appended to by
/// `with_warning`.
pub warnings: Vec<String>,
}
impl TypeMapping {
    /// Builds the mapping for `source` once its `rivet_type` has been decided.
    ///
    /// The Arrow type and the fidelity are both derived from `rivet_type`;
    /// name, native type and nullability are copied from `source`. The
    /// warning list starts out empty.
    pub fn from_source(source: &SourceColumn, rivet_type: RivetType) -> Self {
        Self {
            column_name: source.name.clone(),
            source_native_type: source.native_type.clone(),
            fidelity: derive_fidelity(&rivet_type),
            arrow_type: rivet_type_to_arrow(&rivet_type),
            nullable: source.nullable,
            warnings: Vec::new(),
            // Moved last so the two borrows above are finished before the
            // value is handed over to the struct.
            rivet_type,
        }
    }

    /// Appends `msg` to the warning list and hands the mapping back, so calls
    /// can be chained builder-style.
    // Within this file the method is only exercised by tests; presumably the
    // lint override exists because non-test callers are not always compiled in.
    #[allow(dead_code)]
    pub fn with_warning(mut self, msg: impl Into<String>) -> Self {
        let message: String = msg.into();
        self.warnings.push(message);
        self
    }
}
/// Serde `serialize_with` adapter for `Option<DataType>`.
///
/// `Some(dt)` is emitted as an optional string holding the `Debug` rendering
/// of `dt`; `None` is emitted as the serializer's "none" value.
fn serialize_arrow_type_opt<S: serde::Serializer>(
    v: &Option<DataType>,
    s: S,
) -> std::result::Result<S::Ok, S::Error> {
    if let Some(dt) = v {
        let rendered = format!("{dt:?}");
        return s.serialize_some(&rendered);
    }
    s.serialize_none()
}
/// Translates a [`RivetType`] into the Arrow [`DataType`] that carries it.
///
/// Returns `None` when there is no Arrow representation: for
/// [`RivetType::Unsupported`], and for a [`RivetType::List`] whose element
/// type itself maps to `None`.
///
/// Mapping notes:
/// - `Decimal` is delegated to `decimal_arrow_type`, which picks the decimal
///   width from the precision.
/// - `Time` becomes `Time32` for second/millisecond units and `Time64` for
///   microsecond/nanosecond units. Arrow defines no other combinations, so a
///   `Time64` with a coarse unit would be an invalid type.
/// - `Uuid` is `FixedSizeBinary(16)`.
/// - `String`, `Text`, `Json`, `Enum` and `Interval` are all carried as `Utf8`.
/// - List elements are declared as a nullable child field named `"item"`.
pub fn rivet_type_to_arrow(t: &RivetType) -> Option<DataType> {
    match t {
        RivetType::Bool => Some(DataType::Boolean),
        RivetType::Int16 => Some(DataType::Int16),
        RivetType::Int32 => Some(DataType::Int32),
        RivetType::Int64 => Some(DataType::Int64),
        RivetType::UInt64 => Some(DataType::UInt64),
        RivetType::Float32 => Some(DataType::Float32),
        RivetType::Float64 => Some(DataType::Float64),
        RivetType::Decimal { precision, scale } => Some(decimal_arrow_type(*precision, *scale)),
        RivetType::Date => Some(DataType::Date32),
        RivetType::Time { unit } => {
            let arrow = arrow_unit(*unit);
            // Arrow only permits Time32 with s/ms and Time64 with us/ns.
            Some(match *unit {
                TimeUnit::Second | TimeUnit::Millisecond => DataType::Time32(arrow),
                TimeUnit::Microsecond | TimeUnit::Nanosecond => DataType::Time64(arrow),
            })
        }
        RivetType::Timestamp { unit, timezone } => Some(DataType::Timestamp(
            arrow_unit(*unit),
            timezone.as_deref().map(Into::into),
        )),
        RivetType::Uuid => Some(DataType::FixedSizeBinary(16)),
        RivetType::String | RivetType::Text | RivetType::Json | RivetType::Enum => {
            Some(DataType::Utf8)
        }
        RivetType::Binary => Some(DataType::Binary),
        RivetType::Interval => Some(DataType::Utf8),
        // A list is only representable when its element type is.
        RivetType::List { inner } => rivet_type_to_arrow(inner)
            .map(|inner_dt| DataType::List(Arc::new(Field::new("item", inner_dt, true)))),
        RivetType::Unsupported { .. } => None,
    }
}
/// Chooses the Arrow decimal width for the given precision and scale.
///
/// Precisions up to and including 38 use `Decimal128`; anything larger uses
/// `Decimal256`. Neither `precision` nor `scale` is validated here.
fn decimal_arrow_type(precision: u8, scale: i8) -> DataType {
    match precision {
        0..=38 => DataType::Decimal128(precision, scale),
        _ => DataType::Decimal256(precision, scale),
    }
}
/// Converts Rivet's `TimeUnit` into the same-named Arrow time unit.
fn arrow_unit(u: TimeUnit) -> ArrowTimeUnit {
    // Listed finest-to-coarsest; the arms are disjoint, so order is cosmetic.
    match u {
        TimeUnit::Nanosecond => ArrowTimeUnit::Nanosecond,
        TimeUnit::Microsecond => ArrowTimeUnit::Microsecond,
        TimeUnit::Millisecond => ArrowTimeUnit::Millisecond,
        TimeUnit::Second => ArrowTimeUnit::Second,
    }
}
/// Classifies how faithfully a [`RivetType`] is represented once mapped.
///
/// - Scalars, temporal types, strings, binary and UUIDs are `Exact`.
/// - `Json` is `LogicalString`.
/// - `Enum` and `Interval` are `Compatible`.
/// - A `List` inherits `Unsupported` or `Lossy` from its element type and is
///   `Compatible` otherwise.
/// - `Unsupported` stays `Unsupported`.
pub fn derive_fidelity(t: &RivetType) -> TypeFidelity {
    match t {
        RivetType::Unsupported { .. } => TypeFidelity::Unsupported,
        RivetType::Json => TypeFidelity::LogicalString,
        RivetType::Enum | RivetType::Interval => TypeFidelity::Compatible,
        RivetType::List { inner } => {
            let element = derive_fidelity(inner);
            if matches!(element, TypeFidelity::Unsupported | TypeFidelity::Lossy) {
                // The list cannot be better than a degraded element type.
                element
            } else {
                TypeFidelity::Compatible
            }
        }
        RivetType::Bool
        | RivetType::Int16
        | RivetType::Int32
        | RivetType::Int64
        | RivetType::UInt64
        | RivetType::Float32
        | RivetType::Float64
        | RivetType::Decimal { .. }
        | RivetType::Date
        | RivetType::Time { .. }
        | RivetType::Timestamp { .. }
        | RivetType::String
        | RivetType::Text
        | RivetType::Binary
        | RivetType::Uuid => TypeFidelity::Exact,
    }
}
/// Builds the Arrow [`Field`] for a mapping, or `None` when the mapping has no
/// Arrow type.
///
/// The field carries the native type and the fidelity label as metadata, plus
/// a logical-type label when `logical_type_label` provides one. `Json` and
/// `Uuid` columns are additionally tagged with the matching Arrow extension
/// type.
///
/// # Panics
///
/// Panics if the mapping's `rivet_type` is `Json` or `Uuid` but its
/// `arrow_type` is one the corresponding extension type rejects.
pub fn build_arrow_field(mapping: &TypeMapping) -> Option<Field> {
    let data_type = mapping.arrow_type.clone()?;

    let mut metadata: HashMap<String, String> = HashMap::new();
    metadata.insert(
        META_NATIVE_TYPE.to_owned(),
        mapping.source_native_type.clone(),
    );
    metadata.insert(META_FIDELITY.to_owned(), mapping.fidelity.label().into());
    // `Option` iterates over zero or one entries, so this adds the key only
    // when a logical label exists.
    metadata.extend(
        logical_type_label(&mapping.rivet_type)
            .map(|label| (META_LOGICAL_TYPE.to_owned(), label.to_owned())),
    );

    let mut field =
        Field::new(mapping.column_name.as_str(), data_type, mapping.nullable).with_metadata(metadata);

    if matches!(mapping.rivet_type, RivetType::Json) {
        field
            .try_with_extension_type(ArrowJson::default())
            .expect("Json extension only valid on Utf8/LargeUtf8 — invariant in mapping");
    } else if matches!(mapping.rivet_type, RivetType::Uuid) {
        field
            .try_with_extension_type(ArrowUuid)
            .expect("Uuid extension only valid on FixedSizeBinary(16) — invariant in mapping");
    }

    Some(field)
}
/// Returns the logical-type label for Rivet types that are carried by a more
/// generic Arrow type (`json`, `uuid`, `enum`, `interval`), and `None` for
/// every other type.
fn logical_type_label(t: &RivetType) -> Option<&'static str> {
    let label = match t {
        RivetType::Json => "json",
        RivetType::Uuid => "uuid",
        RivetType::Enum => "enum",
        RivetType::Interval => "interval",
        _ => return None,
    };
    Some(label)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A list inherits `Unsupported` from its element type and is
    /// `Compatible` when the element maps cleanly.
    #[test]
    fn list_fidelity_propagates_unsupported_element() {
        let bad = RivetType::List {
            inner: Box::new(RivetType::Unsupported {
                native_type: "numeric".into(),
                reason: "precision unavailable".into(),
            }),
        };
        assert_eq!(derive_fidelity(&bad), TypeFidelity::Unsupported);
        assert!(bad.is_unsupported());
        let good = RivetType::List {
            inner: Box::new(RivetType::Int32),
        };
        assert_eq!(derive_fidelity(&good), TypeFidelity::Compatible);
        assert!(!good.is_unsupported());
    }

    /// Shorthand for a nullable source column with the given name and native type.
    fn col(name: &str, native: &str) -> SourceColumn {
        SourceColumn::simple(name, native, true)
    }

    /// Boolean, integer and float types map to the same-named Arrow type with
    /// exact fidelity.
    #[test]
    fn integer_types_map_one_to_one() {
        for (rt, expected) in [
            (RivetType::Bool, DataType::Boolean),
            (RivetType::Int16, DataType::Int16),
            (RivetType::Int32, DataType::Int32),
            (RivetType::Int64, DataType::Int64),
            (RivetType::UInt64, DataType::UInt64),
            (RivetType::Float32, DataType::Float32),
            (RivetType::Float64, DataType::Float64),
        ] {
            assert_eq!(
                rivet_type_to_arrow(&rt),
                Some(expected),
                "rivet type {rt:?}"
            );
            assert_eq!(derive_fidelity(&rt), TypeFidelity::Exact);
        }
    }

    /// Precisions up to 38 stay in `Decimal128`.
    #[test]
    fn decimal_p38_uses_decimal128() {
        for p in [1u8, 18, 38] {
            let dt = rivet_type_to_arrow(&RivetType::Decimal {
                precision: p,
                scale: 2,
            })
            .expect("decimal must map to an Arrow type");
            assert_eq!(dt, DataType::Decimal128(p, 2), "precision={p}");
        }
    }

    /// Precisions above 38 switch to `Decimal256`.
    #[test]
    fn decimal_above_38_escalates_to_decimal256() {
        for p in [39u8, 76] {
            let dt = rivet_type_to_arrow(&RivetType::Decimal {
                precision: p,
                scale: 9,
            })
            .expect("decimal must map to an Arrow type");
            assert_eq!(
                dt,
                DataType::Decimal256(p, 9),
                "precision={p} must become Decimal256"
            );
        }
    }

    /// A negative scale is passed through unchanged.
    #[test]
    fn decimal_supports_negative_scale_for_postgres_numeric() {
        let dt = rivet_type_to_arrow(&RivetType::Decimal {
            precision: 5,
            scale: -2,
        })
        .expect("decimal must map to an Arrow type");
        assert_eq!(dt, DataType::Decimal128(5, -2));
    }

    /// The presence or absence of a timezone survives the mapping.
    #[test]
    fn timestamp_preserves_timezone_semantics() {
        let naive = RivetType::Timestamp {
            unit: TimeUnit::Microsecond,
            timezone: None,
        };
        let utc = RivetType::Timestamp {
            unit: TimeUnit::Microsecond,
            timezone: Some("UTC".into()),
        };
        assert_eq!(
            rivet_type_to_arrow(&naive),
            Some(DataType::Timestamp(ArrowTimeUnit::Microsecond, None))
        );
        assert_eq!(
            rivet_type_to_arrow(&utc),
            Some(DataType::Timestamp(
                ArrowTimeUnit::Microsecond,
                Some("UTC".into())
            ))
        );
    }

    /// `Unsupported` has no Arrow type and reports `Unsupported` fidelity.
    #[test]
    fn unsupported_returns_no_arrow_type() {
        let t = RivetType::Unsupported {
            native_type: "interval".into(),
            reason: "no mapping yet".into(),
        };
        assert_eq!(rivet_type_to_arrow(&t), None);
        assert_eq!(derive_fidelity(&t), TypeFidelity::Unsupported);
    }

    /// JSON is carried as `Utf8` and described through field metadata.
    #[test]
    fn json_is_logical_string_with_metadata() {
        let mapping = TypeMapping::from_source(&col("payload", "jsonb"), RivetType::Json);
        assert_eq!(mapping.fidelity, TypeFidelity::LogicalString);
        assert_eq!(mapping.arrow_type, Some(DataType::Utf8));
        let field = build_arrow_field(&mapping).expect("field");
        assert_eq!(field.data_type(), &DataType::Utf8);
        assert_eq!(
            field.metadata().get(META_NATIVE_TYPE).map(String::as_str),
            Some("jsonb")
        );
        assert_eq!(
            field.metadata().get(META_LOGICAL_TYPE).map(String::as_str),
            Some("json")
        );
        assert_eq!(
            field.metadata().get(META_FIDELITY).map(String::as_str),
            Some("logical_string")
        );
    }

    /// UUIDs are exact, 16-byte fixed binary, and tagged with the Arrow UUID
    /// extension name.
    #[test]
    fn uuid_is_exact_fixed_size_binary_with_logical_metadata() {
        let mapping = TypeMapping::from_source(&col("id", "uuid"), RivetType::Uuid);
        assert_eq!(mapping.fidelity, TypeFidelity::Exact);
        assert_eq!(mapping.arrow_type, Some(DataType::FixedSizeBinary(16)));
        let field = build_arrow_field(&mapping).expect("field");
        assert_eq!(
            field.metadata().get(META_LOGICAL_TYPE).map(String::as_str),
            Some("uuid")
        );
        assert_eq!(
            field.metadata().get(META_FIDELITY).map(String::as_str),
            Some("exact")
        );
        assert_eq!(
            field
                .metadata()
                .get("ARROW:extension:name")
                .map(String::as_str),
            Some("arrow.uuid")
        );
    }

    /// Plain strings carry native-type and fidelity metadata but no logical type.
    #[test]
    fn plain_string_has_no_logical_type_metadata() {
        let mapping = TypeMapping::from_source(&col("name", "text"), RivetType::String);
        let field = build_arrow_field(&mapping).expect("field");
        assert!(
            !field.metadata().contains_key(META_LOGICAL_TYPE),
            "plain string columns must NOT carry rivet.logical_type so consumers \
             can distinguish them from json/uuid columns"
        );
        assert_eq!(
            field.metadata().get(META_NATIVE_TYPE).map(String::as_str),
            Some("text")
        );
        assert_eq!(
            field.metadata().get(META_FIDELITY).map(String::as_str),
            Some("exact")
        );
    }

    /// Binary columns keep the Arrow `Binary` type.
    #[test]
    fn binary_stays_binary_not_string() {
        let mapping = TypeMapping::from_source(&col("payload", "bytea"), RivetType::Binary);
        let field = build_arrow_field(&mapping).expect("field");
        assert_eq!(field.data_type(), &DataType::Binary);
        assert_eq!(mapping.fidelity, TypeFidelity::Exact);
    }

    /// No field is produced for an unsupported type.
    #[test]
    fn unsupported_yields_no_field() {
        let unsupported = RivetType::Unsupported {
            native_type: "interval".into(),
            reason: "no mapping".into(),
        };
        let mapping = TypeMapping::from_source(&col("dur", "interval"), unsupported);
        assert!(
            build_arrow_field(&mapping).is_none(),
            "Unsupported must NOT silently produce a Utf8 field — that's exactly the \
             silent-degradation pattern the roadmap forbids (§5)"
        );
    }

    /// The source column's nullability reaches the Arrow field.
    #[test]
    fn nullable_flag_propagates_from_source_column() {
        let nullable = SourceColumn::simple("a", "int4", true);
        let not_nullable = SourceColumn::simple("b", "int4", false);
        let m_nullable = TypeMapping::from_source(&nullable, RivetType::Int32);
        let m_required = TypeMapping::from_source(&not_nullable, RivetType::Int32);
        assert!(build_arrow_field(&m_nullable).expect("f").is_nullable());
        assert!(!build_arrow_field(&m_required).expect("f").is_nullable());
    }

    /// `with_warning` appends to the mapping's warning list.
    #[test]
    fn warnings_are_attachable_via_builder() {
        let mapping = TypeMapping::from_source(&col("x", "int4"), RivetType::Int32)
            .with_warning("autodetect uncertainty");
        assert_eq!(mapping.warnings, vec!["autodetect uncertainty".to_string()]);
    }
}