1use std::collections::HashMap;
22
23use arrow::datatypes::{DataType, Field, TimeUnit as ArrowTimeUnit};
24use arrow_schema::extension::{Json as ArrowJson, Uuid as ArrowUuid};
25use serde::Serialize;
26use std::sync::Arc;
27
28use super::{RivetType, SourceColumn, TimeUnit, TypeFidelity};
29
30pub const META_NATIVE_TYPE: &str = "rivet.native_type";
35pub const META_LOGICAL_TYPE: &str = "rivet.logical_type";
39pub const META_FIDELITY: &str = "rivet.fidelity";
43
44#[derive(Debug, Clone, Serialize)]
50pub struct TypeMapping {
51 pub column_name: String,
53 pub source_native_type: String,
56 pub rivet_type: RivetType,
58 #[serde(serialize_with = "serialize_arrow_type_opt")]
65 pub arrow_type: Option<DataType>,
66 pub fidelity: TypeFidelity,
68 pub nullable: bool,
72 pub warnings: Vec<String>,
75}
76
77impl TypeMapping {
78 pub fn from_source(source: &SourceColumn, rivet_type: RivetType) -> Self {
85 let fidelity = derive_fidelity(&rivet_type);
86 let arrow_type = rivet_type_to_arrow(&rivet_type);
87 Self {
88 column_name: source.name.clone(),
89 source_native_type: source.native_type.clone(),
90 rivet_type,
91 arrow_type,
92 fidelity,
93 nullable: source.nullable,
94 warnings: Vec::new(),
95 }
96 }
97
98 #[allow(dead_code)]
100 pub fn with_warning(mut self, msg: impl Into<String>) -> Self {
101 self.warnings.push(msg.into());
102 self
103 }
104}
105
106fn serialize_arrow_type_opt<S: serde::Serializer>(
107 v: &Option<DataType>,
108 s: S,
109) -> std::result::Result<S::Ok, S::Error> {
110 match v {
111 None => s.serialize_none(),
112 Some(dt) => s.serialize_some(&format!("{dt:?}")),
113 }
114}
115
116pub fn rivet_type_to_arrow(t: &RivetType) -> Option<DataType> {
126 match t {
127 RivetType::Bool => Some(DataType::Boolean),
128 RivetType::Int16 => Some(DataType::Int16),
129 RivetType::Int32 => Some(DataType::Int32),
130 RivetType::Int64 => Some(DataType::Int64),
131 RivetType::UInt64 => Some(DataType::UInt64),
132 RivetType::Float32 => Some(DataType::Float32),
133 RivetType::Float64 => Some(DataType::Float64),
134 RivetType::Decimal { precision, scale } => Some(decimal_arrow_type(*precision, *scale)),
135 RivetType::Date => Some(DataType::Date32),
136 RivetType::Time { unit } => Some(DataType::Time64(arrow_unit(*unit))),
137 RivetType::Timestamp { unit, timezone } => Some(DataType::Timestamp(
138 arrow_unit(*unit),
139 timezone.as_deref().map(Into::into),
140 )),
141 RivetType::Uuid => Some(DataType::FixedSizeBinary(16)),
148
149 RivetType::String | RivetType::Text | RivetType::Json | RivetType::Enum => {
153 Some(DataType::Utf8)
154 }
155
156 RivetType::Binary => Some(DataType::Binary),
157
158 RivetType::Interval => Some(DataType::Utf8),
162
163 RivetType::List { inner } => rivet_type_to_arrow(inner)
166 .map(|inner_dt| DataType::List(Arc::new(Field::new("item", inner_dt, true)))),
167
168 RivetType::Unsupported { .. } => None,
169 }
170}
171
172fn decimal_arrow_type(precision: u8, scale: i8) -> DataType {
178 if precision <= 38 {
179 DataType::Decimal128(precision, scale)
180 } else {
181 DataType::Decimal256(precision, scale)
182 }
183}
184
185fn arrow_unit(u: TimeUnit) -> ArrowTimeUnit {
186 match u {
187 TimeUnit::Second => ArrowTimeUnit::Second,
188 TimeUnit::Millisecond => ArrowTimeUnit::Millisecond,
189 TimeUnit::Microsecond => ArrowTimeUnit::Microsecond,
190 TimeUnit::Nanosecond => ArrowTimeUnit::Nanosecond,
191 }
192}
193
194pub fn derive_fidelity(t: &RivetType) -> TypeFidelity {
200 match t {
201 RivetType::Bool
202 | RivetType::Int16
203 | RivetType::Int32
204 | RivetType::Int64
205 | RivetType::UInt64
206 | RivetType::Float32
207 | RivetType::Float64
208 | RivetType::Decimal { .. }
209 | RivetType::Date
210 | RivetType::Time { .. }
211 | RivetType::Timestamp { .. }
212 | RivetType::String
213 | RivetType::Text
214 | RivetType::Binary => TypeFidelity::Exact,
215
216 RivetType::Uuid => TypeFidelity::Exact,
222
223 RivetType::Json => TypeFidelity::LogicalString,
226
227 RivetType::Enum => TypeFidelity::Compatible,
230
231 RivetType::Interval => TypeFidelity::Compatible,
234
235 RivetType::List { .. } => TypeFidelity::Compatible,
237
238 RivetType::Unsupported { .. } => TypeFidelity::Unsupported,
239 }
240}
241
242pub fn build_arrow_field(mapping: &TypeMapping) -> Option<Field> {
250 let dt = mapping.arrow_type.clone()?;
251 let mut metadata: HashMap<String, String> = HashMap::new();
252 metadata.insert(META_NATIVE_TYPE.into(), mapping.source_native_type.clone());
253 metadata.insert(META_FIDELITY.into(), mapping.fidelity.label().into());
254 if let Some(logical) = logical_type_label(&mapping.rivet_type) {
255 metadata.insert(META_LOGICAL_TYPE.into(), logical.into());
256 }
257 let mut field = Field::new(&mapping.column_name, dt, mapping.nullable).with_metadata(metadata);
258
259 match mapping.rivet_type {
271 RivetType::Json => {
272 field
273 .try_with_extension_type(ArrowJson::default())
274 .expect("Json extension only valid on Utf8/LargeUtf8 — invariant in mapping");
275 }
276 RivetType::Uuid => {
277 field
278 .try_with_extension_type(ArrowUuid)
279 .expect("Uuid extension only valid on FixedSizeBinary(16) — invariant in mapping");
280 }
281 _ => {}
282 }
283 Some(field)
284}
285
286fn logical_type_label(t: &RivetType) -> Option<&'static str> {
292 match t {
293 RivetType::Json => Some("json"),
294 RivetType::Uuid => Some("uuid"),
295 RivetType::Enum => Some("enum"),
296 RivetType::Interval => Some("interval"),
297 _ => None,
298 }
299}
300
301#[cfg(test)]
302mod tests {
303 use super::*;
304
305 fn col(name: &str, native: &str) -> SourceColumn {
306 SourceColumn::simple(name, native, true)
307 }
308
309 #[test]
310 fn integer_types_map_one_to_one() {
311 for (rt, expected) in [
312 (RivetType::Bool, DataType::Boolean),
313 (RivetType::Int16, DataType::Int16),
314 (RivetType::Int32, DataType::Int32),
315 (RivetType::Int64, DataType::Int64),
316 (RivetType::UInt64, DataType::UInt64),
317 (RivetType::Float32, DataType::Float32),
318 (RivetType::Float64, DataType::Float64),
319 ] {
320 assert_eq!(
321 rivet_type_to_arrow(&rt),
322 Some(expected),
323 "rivet type {rt:?}"
324 );
325 assert_eq!(derive_fidelity(&rt), TypeFidelity::Exact);
326 }
327 }
328
329 #[test]
330 fn decimal_p38_uses_decimal128() {
331 for p in [1u8, 18, 38] {
332 let dt = rivet_type_to_arrow(&RivetType::Decimal {
333 precision: p,
334 scale: 2,
335 })
336 .expect("decimal must map to an Arrow type");
337 assert_eq!(dt, DataType::Decimal128(p, 2), "precision={p}");
338 }
339 }
340
341 #[test]
345 fn decimal_above_38_escalates_to_decimal256() {
346 for p in [39u8, 76] {
347 let dt = rivet_type_to_arrow(&RivetType::Decimal {
348 precision: p,
349 scale: 9,
350 })
351 .expect("decimal must map to an Arrow type");
352 assert_eq!(
353 dt,
354 DataType::Decimal256(p, 9),
355 "precision={p} must become Decimal256"
356 );
357 }
358 }
359
360 #[test]
363 fn decimal_supports_negative_scale_for_postgres_numeric() {
364 let dt = rivet_type_to_arrow(&RivetType::Decimal {
365 precision: 5,
366 scale: -2,
367 })
368 .expect("decimal must map to an Arrow type");
369 assert_eq!(dt, DataType::Decimal128(5, -2));
370 }
371
372 #[test]
373 fn timestamp_preserves_timezone_semantics() {
374 let naive = RivetType::Timestamp {
375 unit: TimeUnit::Microsecond,
376 timezone: None,
377 };
378 let utc = RivetType::Timestamp {
379 unit: TimeUnit::Microsecond,
380 timezone: Some("UTC".into()),
381 };
382 assert_eq!(
383 rivet_type_to_arrow(&naive),
384 Some(DataType::Timestamp(ArrowTimeUnit::Microsecond, None))
385 );
386 assert_eq!(
387 rivet_type_to_arrow(&utc),
388 Some(DataType::Timestamp(
389 ArrowTimeUnit::Microsecond,
390 Some("UTC".into())
391 ))
392 );
393 }
394
395 #[test]
396 fn unsupported_returns_no_arrow_type() {
397 let t = RivetType::Unsupported {
398 native_type: "interval".into(),
399 reason: "no mapping yet".into(),
400 };
401 assert_eq!(rivet_type_to_arrow(&t), None);
402 assert_eq!(derive_fidelity(&t), TypeFidelity::Unsupported);
403 }
404
405 #[test]
406 fn json_is_logical_string_with_metadata() {
407 let mapping = TypeMapping::from_source(&col("payload", "jsonb"), RivetType::Json);
408 assert_eq!(mapping.fidelity, TypeFidelity::LogicalString);
409 assert_eq!(mapping.arrow_type, Some(DataType::Utf8));
410
411 let field = build_arrow_field(&mapping).expect("field");
412 assert_eq!(field.data_type(), &DataType::Utf8);
413 assert_eq!(
414 field.metadata().get(META_NATIVE_TYPE).map(String::as_str),
415 Some("jsonb")
416 );
417 assert_eq!(
418 field.metadata().get(META_LOGICAL_TYPE).map(String::as_str),
419 Some("json")
420 );
421 assert_eq!(
422 field.metadata().get(META_FIDELITY).map(String::as_str),
423 Some("logical_string")
424 );
425 }
426
427 #[test]
428 fn uuid_is_exact_fixed_size_binary_with_logical_metadata() {
429 let mapping = TypeMapping::from_source(&col("id", "uuid"), RivetType::Uuid);
435 assert_eq!(mapping.fidelity, TypeFidelity::Exact);
436 assert_eq!(mapping.arrow_type, Some(DataType::FixedSizeBinary(16)));
437
438 let field = build_arrow_field(&mapping).expect("field");
439 assert_eq!(
440 field.metadata().get(META_LOGICAL_TYPE).map(String::as_str),
441 Some("uuid")
442 );
443 assert_eq!(
444 field.metadata().get(META_FIDELITY).map(String::as_str),
445 Some("exact")
446 );
447 assert_eq!(
449 field
450 .metadata()
451 .get("ARROW:extension:name")
452 .map(String::as_str),
453 Some("arrow.uuid")
454 );
455 }
456
457 #[test]
458 fn plain_string_has_no_logical_type_metadata() {
459 let mapping = TypeMapping::from_source(&col("name", "text"), RivetType::String);
460 let field = build_arrow_field(&mapping).expect("field");
461 assert!(
462 !field.metadata().contains_key(META_LOGICAL_TYPE),
463 "plain string columns must NOT carry rivet.logical_type so consumers \
464 can distinguish them from json/uuid columns"
465 );
466 assert_eq!(
467 field.metadata().get(META_NATIVE_TYPE).map(String::as_str),
468 Some("text")
469 );
470 assert_eq!(
471 field.metadata().get(META_FIDELITY).map(String::as_str),
472 Some("exact")
473 );
474 }
475
476 #[test]
477 fn binary_stays_binary_not_string() {
478 let mapping = TypeMapping::from_source(&col("payload", "bytea"), RivetType::Binary);
480 let field = build_arrow_field(&mapping).expect("field");
481 assert_eq!(field.data_type(), &DataType::Binary);
482 assert_eq!(mapping.fidelity, TypeFidelity::Exact);
483 }
484
485 #[test]
486 fn unsupported_yields_no_field() {
487 let unsupported = RivetType::Unsupported {
488 native_type: "interval".into(),
489 reason: "no mapping".into(),
490 };
491 let mapping = TypeMapping::from_source(&col("dur", "interval"), unsupported);
492 assert!(
493 build_arrow_field(&mapping).is_none(),
494 "Unsupported must NOT silently produce a Utf8 field — that's exactly the \
495 silent-degradation pattern the roadmap forbids (§5)"
496 );
497 }
498
499 #[test]
500 fn nullable_flag_propagates_from_source_column() {
501 let nullable = SourceColumn::simple("a", "int4", true);
502 let not_nullable = SourceColumn::simple("b", "int4", false);
503 let m_nullable = TypeMapping::from_source(&nullable, RivetType::Int32);
504 let m_required = TypeMapping::from_source(¬_nullable, RivetType::Int32);
505 assert!(build_arrow_field(&m_nullable).expect("f").is_nullable());
506 assert!(!build_arrow_field(&m_required).expect("f").is_nullable());
507 }
508
509 #[test]
510 fn warnings_are_attachable_via_builder() {
511 let mapping = TypeMapping::from_source(&col("x", "int4"), RivetType::Int32)
512 .with_warning("autodetect uncertainty");
513 assert_eq!(mapping.warnings, vec!["autodetect uncertainty".to_string()]);
514 }
515}