1use std::collections::HashMap;
22
23use arrow::datatypes::{DataType, Field, TimeUnit as ArrowTimeUnit};
24use arrow_schema::extension::{Json as ArrowJson, Uuid as ArrowUuid};
25use serde::Serialize;
26use std::sync::Arc;
27
28use super::{RivetType, SourceColumn, TimeUnit, TypeFidelity};
29
/// Arrow field-metadata key holding the column's type name exactly as the
/// source reported it (e.g. `jsonb`, `text`, `int4`).
pub const META_NATIVE_TYPE: &str = "rivet.native_type";
/// Arrow field-metadata key holding the logical-type label (`json`, `uuid`,
/// `enum`, `interval`) for columns whose Arrow physical type does not convey it.
/// The key is absent for plain columns, so consumers can tell them apart.
pub const META_LOGICAL_TYPE: &str = "rivet.logical_type";
/// Arrow field-metadata key holding `TypeFidelity::label()` for the column
/// (e.g. `exact`, `logical_string`).
pub const META_FIDELITY: &str = "rivet.fidelity";
43
/// The resolved mapping for one source column: where it came from, which
/// [`RivetType`] it was normalized to, and how that lands in Arrow.
#[derive(Debug, Clone, Serialize)]
pub struct TypeMapping {
    /// Column name, copied from the source column.
    pub column_name: String,
    /// The source's own type name; written to the field metadata under
    /// `rivet.native_type` by `build_arrow_field`.
    pub source_native_type: String,
    /// The normalized type this column was mapped to.
    pub rivet_type: RivetType,
    /// Arrow type derived from `rivet_type`; `None` when no faithful Arrow type
    /// exists (e.g. `RivetType::Unsupported`). Serialized as the `Debug`
    /// rendering of the `DataType` via `serialize_arrow_type_opt`.
    #[serde(serialize_with = "serialize_arrow_type_opt")]
    pub arrow_type: Option<DataType>,
    /// How faithfully `rivet_type` survives the trip to Arrow.
    pub fidelity: TypeFidelity,
    /// Whether the resulting Arrow field is nullable; copied from the source column.
    pub nullable: bool,
    /// Free-form notes attached during mapping (see `with_warning`).
    pub warnings: Vec<String>,
}
76
77impl TypeMapping {
78 pub fn from_source(source: &SourceColumn, rivet_type: RivetType) -> Self {
85 let fidelity = derive_fidelity(&rivet_type);
86 let arrow_type = rivet_type_to_arrow(&rivet_type);
87 Self {
88 column_name: source.name.clone(),
89 source_native_type: source.native_type.clone(),
90 rivet_type,
91 arrow_type,
92 fidelity,
93 nullable: source.nullable,
94 warnings: Vec::new(),
95 }
96 }
97
98 #[allow(dead_code)]
100 pub fn with_warning(mut self, msg: impl Into<String>) -> Self {
101 self.warnings.push(msg.into());
102 self
103 }
104}
105
106fn serialize_arrow_type_opt<S: serde::Serializer>(
107 v: &Option<DataType>,
108 s: S,
109) -> std::result::Result<S::Ok, S::Error> {
110 match v {
111 None => s.serialize_none(),
112 Some(dt) => s.serialize_some(&format!("{dt:?}")),
113 }
114}
115
116pub fn rivet_type_to_arrow(t: &RivetType) -> Option<DataType> {
126 match t {
127 RivetType::Bool => Some(DataType::Boolean),
128 RivetType::Int16 => Some(DataType::Int16),
129 RivetType::Int32 => Some(DataType::Int32),
130 RivetType::Int64 => Some(DataType::Int64),
131 RivetType::UInt64 => Some(DataType::UInt64),
132 RivetType::Float32 => Some(DataType::Float32),
133 RivetType::Float64 => Some(DataType::Float64),
134 RivetType::Decimal { precision, scale } => Some(decimal_arrow_type(*precision, *scale)),
135 RivetType::Date => Some(DataType::Date32),
136 RivetType::Time { unit } => Some(DataType::Time64(arrow_unit(*unit))),
137 RivetType::Timestamp { unit, timezone } => Some(DataType::Timestamp(
138 arrow_unit(*unit),
139 timezone.as_deref().map(Into::into),
140 )),
141 RivetType::Uuid => Some(DataType::FixedSizeBinary(16)),
148
149 RivetType::String | RivetType::Text | RivetType::Json | RivetType::Enum => {
153 Some(DataType::Utf8)
154 }
155
156 RivetType::Binary => Some(DataType::Binary),
157
158 RivetType::Interval => Some(DataType::Utf8),
162
163 RivetType::List { inner } => rivet_type_to_arrow(inner)
166 .map(|inner_dt| DataType::List(Arc::new(Field::new("item", inner_dt, true)))),
167
168 RivetType::Unsupported { .. } => None,
169 }
170}
171
172fn decimal_arrow_type(precision: u8, scale: i8) -> DataType {
178 if precision <= 38 {
179 DataType::Decimal128(precision, scale)
180 } else {
181 DataType::Decimal256(precision, scale)
182 }
183}
184
185fn arrow_unit(u: TimeUnit) -> ArrowTimeUnit {
186 match u {
187 TimeUnit::Second => ArrowTimeUnit::Second,
188 TimeUnit::Millisecond => ArrowTimeUnit::Millisecond,
189 TimeUnit::Microsecond => ArrowTimeUnit::Microsecond,
190 TimeUnit::Nanosecond => ArrowTimeUnit::Nanosecond,
191 }
192}
193
194pub fn derive_fidelity(t: &RivetType) -> TypeFidelity {
200 match t {
201 RivetType::Bool
202 | RivetType::Int16
203 | RivetType::Int32
204 | RivetType::Int64
205 | RivetType::UInt64
206 | RivetType::Float32
207 | RivetType::Float64
208 | RivetType::Decimal { .. }
209 | RivetType::Date
210 | RivetType::Time { .. }
211 | RivetType::Timestamp { .. }
212 | RivetType::String
213 | RivetType::Text
214 | RivetType::Binary => TypeFidelity::Exact,
215
216 RivetType::Uuid => TypeFidelity::Exact,
222
223 RivetType::Json => TypeFidelity::LogicalString,
226
227 RivetType::Enum => TypeFidelity::Compatible,
230
231 RivetType::Interval => TypeFidelity::Compatible,
234
235 RivetType::List { inner } => match derive_fidelity(inner) {
245 f @ (TypeFidelity::Unsupported | TypeFidelity::Lossy) => f,
246 _ => TypeFidelity::Compatible,
247 },
248
249 RivetType::Unsupported { .. } => TypeFidelity::Unsupported,
250 }
251}
252
253pub fn build_arrow_field(mapping: &TypeMapping) -> Option<Field> {
261 let dt = mapping.arrow_type.clone()?;
262 let mut metadata: HashMap<String, String> = HashMap::new();
263 metadata.insert(META_NATIVE_TYPE.into(), mapping.source_native_type.clone());
264 metadata.insert(META_FIDELITY.into(), mapping.fidelity.label().into());
265 if let Some(logical) = logical_type_label(&mapping.rivet_type) {
266 metadata.insert(META_LOGICAL_TYPE.into(), logical.into());
267 }
268 let mut field = Field::new(&mapping.column_name, dt, mapping.nullable).with_metadata(metadata);
269
270 match mapping.rivet_type {
282 RivetType::Json => {
283 field
284 .try_with_extension_type(ArrowJson::default())
285 .expect("Json extension only valid on Utf8/LargeUtf8 — invariant in mapping");
286 }
287 RivetType::Uuid => {
288 field
289 .try_with_extension_type(ArrowUuid)
290 .expect("Uuid extension only valid on FixedSizeBinary(16) — invariant in mapping");
291 }
292 _ => {}
293 }
294 Some(field)
295}
296
297fn logical_type_label(t: &RivetType) -> Option<&'static str> {
303 match t {
304 RivetType::Json => Some("json"),
305 RivetType::Uuid => Some("uuid"),
306 RivetType::Enum => Some("enum"),
307 RivetType::Interval => Some("interval"),
308 _ => None,
309 }
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    /// A `List` must surface an `Unsupported` element instead of masking it as
    /// `Compatible`; a supported element yields a `Compatible` list.
    #[test]
    fn list_fidelity_propagates_unsupported_element() {
        let bad = RivetType::List {
            inner: Box::new(RivetType::Unsupported {
                native_type: "numeric".into(),
                reason: "precision unavailable".into(),
            }),
        };
        assert_eq!(derive_fidelity(&bad), TypeFidelity::Unsupported);
        assert!(bad.is_unsupported());

        let good = RivetType::List {
            inner: Box::new(RivetType::Int32),
        };
        assert_eq!(derive_fidelity(&good), TypeFidelity::Compatible);
        assert!(!good.is_unsupported());
    }

    /// Shorthand for a source column with the given name and native type,
    /// passing `true` as the nullability flag.
    fn col(name: &str, native: &str) -> SourceColumn {
        SourceColumn::simple(name, native, true)
    }

    /// Booleans, integers and floats map onto their same-named Arrow types, exactly.
    #[test]
    fn integer_types_map_one_to_one() {
        for (rt, expected) in [
            (RivetType::Bool, DataType::Boolean),
            (RivetType::Int16, DataType::Int16),
            (RivetType::Int32, DataType::Int32),
            (RivetType::Int64, DataType::Int64),
            (RivetType::UInt64, DataType::UInt64),
            (RivetType::Float32, DataType::Float32),
            (RivetType::Float64, DataType::Float64),
        ] {
            assert_eq!(
                rivet_type_to_arrow(&rt),
                Some(expected),
                "rivet type {rt:?}"
            );
            assert_eq!(derive_fidelity(&rt), TypeFidelity::Exact);
        }
    }

    /// Precision up to and including 38 stays in `Decimal128`.
    #[test]
    fn decimal_p38_uses_decimal128() {
        for p in [1u8, 18, 38] {
            let dt = rivet_type_to_arrow(&RivetType::Decimal {
                precision: p,
                scale: 2,
            })
            .expect("decimal must map to an Arrow type");
            assert_eq!(dt, DataType::Decimal128(p, 2), "precision={p}");
        }
    }

    /// Precision above 38 is promoted to `Decimal256`.
    #[test]
    fn decimal_above_38_escalates_to_decimal256() {
        for p in [39u8, 76] {
            let dt = rivet_type_to_arrow(&RivetType::Decimal {
                precision: p,
                scale: 9,
            })
            .expect("decimal must map to an Arrow type");
            assert_eq!(
                dt,
                DataType::Decimal256(p, 9),
                "precision={p} must become Decimal256"
            );
        }
    }

    /// Negative scales are passed through unchanged.
    #[test]
    fn decimal_supports_negative_scale_for_postgres_numeric() {
        let dt = rivet_type_to_arrow(&RivetType::Decimal {
            precision: 5,
            scale: -2,
        })
        .expect("decimal must map to an Arrow type");
        assert_eq!(dt, DataType::Decimal128(5, -2));
    }

    /// A naive timestamp keeps `None` as its zone; a zoned one keeps its zone string.
    #[test]
    fn timestamp_preserves_timezone_semantics() {
        let naive = RivetType::Timestamp {
            unit: TimeUnit::Microsecond,
            timezone: None,
        };
        let utc = RivetType::Timestamp {
            unit: TimeUnit::Microsecond,
            timezone: Some("UTC".into()),
        };
        assert_eq!(
            rivet_type_to_arrow(&naive),
            Some(DataType::Timestamp(ArrowTimeUnit::Microsecond, None))
        );
        assert_eq!(
            rivet_type_to_arrow(&utc),
            Some(DataType::Timestamp(
                ArrowTimeUnit::Microsecond,
                Some("UTC".into())
            ))
        );
    }

    /// `Unsupported` has no Arrow type and reports `Unsupported` fidelity.
    #[test]
    fn unsupported_returns_no_arrow_type() {
        let t = RivetType::Unsupported {
            native_type: "interval".into(),
            reason: "no mapping yet".into(),
        };
        assert_eq!(rivet_type_to_arrow(&t), None);
        assert_eq!(derive_fidelity(&t), TypeFidelity::Unsupported);
    }

    /// JSON travels as `Utf8` and carries native-type, logical-type and fidelity metadata.
    #[test]
    fn json_is_logical_string_with_metadata() {
        let mapping = TypeMapping::from_source(&col("payload", "jsonb"), RivetType::Json);
        assert_eq!(mapping.fidelity, TypeFidelity::LogicalString);
        assert_eq!(mapping.arrow_type, Some(DataType::Utf8));

        let field = build_arrow_field(&mapping).expect("field");
        assert_eq!(field.data_type(), &DataType::Utf8);
        assert_eq!(
            field.metadata().get(META_NATIVE_TYPE).map(String::as_str),
            Some("jsonb")
        );
        assert_eq!(
            field.metadata().get(META_LOGICAL_TYPE).map(String::as_str),
            Some("json")
        );
        assert_eq!(
            field.metadata().get(META_FIDELITY).map(String::as_str),
            Some("logical_string")
        );
    }

    /// UUID is an exact `FixedSizeBinary(16)` with rivet metadata and the canonical
    /// `arrow.uuid` extension name.
    #[test]
    fn uuid_is_exact_fixed_size_binary_with_logical_metadata() {
        let mapping = TypeMapping::from_source(&col("id", "uuid"), RivetType::Uuid);
        assert_eq!(mapping.fidelity, TypeFidelity::Exact);
        assert_eq!(mapping.arrow_type, Some(DataType::FixedSizeBinary(16)));

        let field = build_arrow_field(&mapping).expect("field");
        assert_eq!(
            field.metadata().get(META_LOGICAL_TYPE).map(String::as_str),
            Some("uuid")
        );
        assert_eq!(
            field.metadata().get(META_FIDELITY).map(String::as_str),
            Some("exact")
        );
        assert_eq!(
            field
                .metadata()
                .get("ARROW:extension:name")
                .map(String::as_str),
            Some("arrow.uuid")
        );
    }

    /// Plain strings carry native-type and fidelity metadata but no logical type.
    #[test]
    fn plain_string_has_no_logical_type_metadata() {
        let mapping = TypeMapping::from_source(&col("name", "text"), RivetType::String);
        let field = build_arrow_field(&mapping).expect("field");
        assert!(
            !field.metadata().contains_key(META_LOGICAL_TYPE),
            "plain string columns must NOT carry rivet.logical_type so consumers \
             can distinguish them from json/uuid columns"
        );
        assert_eq!(
            field.metadata().get(META_NATIVE_TYPE).map(String::as_str),
            Some("text")
        );
        assert_eq!(
            field.metadata().get(META_FIDELITY).map(String::as_str),
            Some("exact")
        );
    }

    /// Binary columns stay `Binary` rather than being re-encoded as text.
    #[test]
    fn binary_stays_binary_not_string() {
        let mapping = TypeMapping::from_source(&col("payload", "bytea"), RivetType::Binary);
        let field = build_arrow_field(&mapping).expect("field");
        assert_eq!(field.data_type(), &DataType::Binary);
        assert_eq!(mapping.fidelity, TypeFidelity::Exact);
    }

    /// An unsupported mapping produces no field at all.
    #[test]
    fn unsupported_yields_no_field() {
        let unsupported = RivetType::Unsupported {
            native_type: "interval".into(),
            reason: "no mapping".into(),
        };
        let mapping = TypeMapping::from_source(&col("dur", "interval"), unsupported);
        assert!(
            build_arrow_field(&mapping).is_none(),
            "Unsupported must NOT silently produce a Utf8 field — that's exactly the \
             silent-degradation pattern the roadmap forbids (§5)"
        );
    }

    /// The source column's nullability flag ends up on the Arrow field.
    #[test]
    fn nullable_flag_propagates_from_source_column() {
        let nullable = SourceColumn::simple("a", "int4", true);
        let not_nullable = SourceColumn::simple("b", "int4", false);
        let m_nullable = TypeMapping::from_source(&nullable, RivetType::Int32);
        let m_required = TypeMapping::from_source(&not_nullable, RivetType::Int32);
        assert!(build_arrow_field(&m_nullable).expect("f").is_nullable());
        assert!(!build_arrow_field(&m_required).expect("f").is_nullable());
    }

    /// `with_warning` appends to the mapping's warning list.
    #[test]
    fn warnings_are_attachable_via_builder() {
        let mapping = TypeMapping::from_source(&col("x", "int4"), RivetType::Int32)
            .with_warning("autodetect uncertainty");
        assert_eq!(mapping.warnings, vec!["autodetect uncertainty".to_string()]);
    }
}