1use std::collections::HashMap;
22
23use arrow::datatypes::{DataType, Field, TimeUnit as ArrowTimeUnit};
24use serde::Serialize;
25use std::sync::Arc;
26
27use super::{RivetType, SourceColumn, TimeUnit, TypeFidelity};
28
/// Arrow field-metadata key holding the column's type name exactly as the
/// source reported it (copied from `TypeMapping::source_native_type`).
pub const META_NATIVE_TYPE: &str = "rivet.native_type";
/// Arrow field-metadata key holding a logical-type tag (`"json"`, `"uuid"`,
/// `"enum"`, `"interval"`) for columns whose Arrow type is `Utf8`; the key is
/// absent on every other field.
pub const META_LOGICAL_TYPE: &str = "rivet.logical_type";
/// Arrow field-metadata key holding the mapping's `TypeFidelity` label.
pub const META_FIDELITY: &str = "rivet.fidelity";
42
/// Outcome of mapping one source column onto the Rivet type system and Arrow.
///
/// Normally built with [`TypeMapping::from_source`]. When serialized with
/// serde, `arrow_type` is rendered as an optional string through
/// `DataType`'s `Debug` formatting.
#[derive(Debug, Clone, Serialize)]
pub struct TypeMapping {
    /// Column name, copied from the source column.
    pub column_name: String,
    /// Type name as reported by the source (e.g. `"jsonb"`, `"int4"`).
    pub source_native_type: String,
    /// Rivet's intermediate type chosen for this column.
    pub rivet_type: RivetType,
    /// Arrow type derived from `rivet_type`; `None` when that Rivet type has
    /// no Arrow representation (see `rivet_type_to_arrow`).
    #[serde(serialize_with = "serialize_arrow_type_opt")]
    pub arrow_type: Option<DataType>,
    /// How faithfully `arrow_type` represents the source type.
    pub fidelity: TypeFidelity,
    /// Whether the column accepts NULLs; copied from the source column.
    pub nullable: bool,
    /// Free-form diagnostics attached to this mapping (see `with_warning`).
    pub warnings: Vec<String>,
}
75
76impl TypeMapping {
77 pub fn from_source(source: &SourceColumn, rivet_type: RivetType) -> Self {
84 let fidelity = derive_fidelity(&rivet_type);
85 let arrow_type = rivet_type_to_arrow(&rivet_type);
86 Self {
87 column_name: source.name.clone(),
88 source_native_type: source.native_type.clone(),
89 rivet_type,
90 arrow_type,
91 fidelity,
92 nullable: source.nullable,
93 warnings: Vec::new(),
94 }
95 }
96
97 #[allow(dead_code)]
99 pub fn with_warning(mut self, msg: impl Into<String>) -> Self {
100 self.warnings.push(msg.into());
101 self
102 }
103}
104
105fn serialize_arrow_type_opt<S: serde::Serializer>(
106 v: &Option<DataType>,
107 s: S,
108) -> std::result::Result<S::Ok, S::Error> {
109 match v {
110 None => s.serialize_none(),
111 Some(dt) => s.serialize_some(&format!("{dt:?}")),
112 }
113}
114
115pub fn rivet_type_to_arrow(t: &RivetType) -> Option<DataType> {
125 match t {
126 RivetType::Bool => Some(DataType::Boolean),
127 RivetType::Int16 => Some(DataType::Int16),
128 RivetType::Int32 => Some(DataType::Int32),
129 RivetType::Int64 => Some(DataType::Int64),
130 RivetType::UInt64 => Some(DataType::UInt64),
131 RivetType::Float32 => Some(DataType::Float32),
132 RivetType::Float64 => Some(DataType::Float64),
133 RivetType::Decimal { precision, scale } => Some(decimal_arrow_type(*precision, *scale)),
134 RivetType::Date => Some(DataType::Date32),
135 RivetType::Time { unit } => Some(DataType::Time64(arrow_unit(*unit))),
136 RivetType::Timestamp { unit, timezone } => Some(DataType::Timestamp(
137 arrow_unit(*unit),
138 timezone.as_deref().map(Into::into),
139 )),
140 RivetType::String
144 | RivetType::Text
145 | RivetType::Json
146 | RivetType::Uuid
147 | RivetType::Enum => Some(DataType::Utf8),
148
149 RivetType::Binary => Some(DataType::Binary),
150
151 RivetType::Interval => Some(DataType::Utf8),
155
156 RivetType::List { inner } => rivet_type_to_arrow(inner)
159 .map(|inner_dt| DataType::List(Arc::new(Field::new("item", inner_dt, true)))),
160
161 RivetType::Unsupported { .. } => None,
162 }
163}
164
165fn decimal_arrow_type(precision: u8, scale: i8) -> DataType {
171 if precision <= 38 {
172 DataType::Decimal128(precision, scale)
173 } else {
174 DataType::Decimal256(precision, scale)
175 }
176}
177
178fn arrow_unit(u: TimeUnit) -> ArrowTimeUnit {
179 match u {
180 TimeUnit::Second => ArrowTimeUnit::Second,
181 TimeUnit::Millisecond => ArrowTimeUnit::Millisecond,
182 TimeUnit::Microsecond => ArrowTimeUnit::Microsecond,
183 TimeUnit::Nanosecond => ArrowTimeUnit::Nanosecond,
184 }
185}
186
187pub fn derive_fidelity(t: &RivetType) -> TypeFidelity {
193 match t {
194 RivetType::Bool
195 | RivetType::Int16
196 | RivetType::Int32
197 | RivetType::Int64
198 | RivetType::UInt64
199 | RivetType::Float32
200 | RivetType::Float64
201 | RivetType::Decimal { .. }
202 | RivetType::Date
203 | RivetType::Time { .. }
204 | RivetType::Timestamp { .. }
205 | RivetType::String
206 | RivetType::Text
207 | RivetType::Binary => TypeFidelity::Exact,
208
209 RivetType::Uuid => TypeFidelity::Compatible,
212
213 RivetType::Json => TypeFidelity::LogicalString,
216
217 RivetType::Enum => TypeFidelity::Compatible,
220
221 RivetType::Interval => TypeFidelity::Compatible,
224
225 RivetType::List { .. } => TypeFidelity::Compatible,
227
228 RivetType::Unsupported { .. } => TypeFidelity::Unsupported,
229 }
230}
231
232pub fn build_arrow_field(mapping: &TypeMapping) -> Option<Field> {
240 let dt = mapping.arrow_type.clone()?;
241 let mut metadata: HashMap<String, String> = HashMap::new();
242 metadata.insert(META_NATIVE_TYPE.into(), mapping.source_native_type.clone());
243 metadata.insert(META_FIDELITY.into(), mapping.fidelity.label().into());
244 if let Some(logical) = logical_type_label(&mapping.rivet_type) {
245 metadata.insert(META_LOGICAL_TYPE.into(), logical.into());
246 }
247 Some(Field::new(&mapping.column_name, dt, mapping.nullable).with_metadata(metadata))
248}
249
250fn logical_type_label(t: &RivetType) -> Option<&'static str> {
256 match t {
257 RivetType::Json => Some("json"),
258 RivetType::Uuid => Some("uuid"),
259 RivetType::Enum => Some("enum"),
260 RivetType::Interval => Some("interval"),
261 _ => None,
262 }
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    /// Nullable source column with the given name and native type name.
    fn col(name: &str, native: &str) -> SourceColumn {
        SourceColumn::simple(name, native, true)
    }

    #[test]
    fn integer_types_map_one_to_one() {
        for (rt, expected) in [
            (RivetType::Bool, DataType::Boolean),
            (RivetType::Int16, DataType::Int16),
            (RivetType::Int32, DataType::Int32),
            (RivetType::Int64, DataType::Int64),
            (RivetType::UInt64, DataType::UInt64),
            (RivetType::Float32, DataType::Float32),
            (RivetType::Float64, DataType::Float64),
        ] {
            assert_eq!(
                rivet_type_to_arrow(&rt),
                Some(expected),
                "rivet type {rt:?}"
            );
            assert_eq!(derive_fidelity(&rt), TypeFidelity::Exact);
        }
    }

    #[test]
    fn decimal_p38_uses_decimal128() {
        for p in [1u8, 18, 38] {
            let dt = rivet_type_to_arrow(&RivetType::Decimal {
                precision: p,
                scale: 2,
            })
            .expect("decimal must map to an Arrow type");
            assert_eq!(dt, DataType::Decimal128(p, 2), "precision={p}");
        }
    }

    #[test]
    fn decimal_above_38_escalates_to_decimal256() {
        for p in [39u8, 76] {
            let dt = rivet_type_to_arrow(&RivetType::Decimal {
                precision: p,
                scale: 9,
            })
            .expect("decimal must map to an Arrow type");
            assert_eq!(
                dt,
                DataType::Decimal256(p, 9),
                "precision={p} must become Decimal256"
            );
        }
    }

    #[test]
    fn decimal_supports_negative_scale_for_postgres_numeric() {
        let dt = rivet_type_to_arrow(&RivetType::Decimal {
            precision: 5,
            scale: -2,
        })
        .expect("decimal must map to an Arrow type");
        assert_eq!(dt, DataType::Decimal128(5, -2));
    }

    #[test]
    fn timestamp_preserves_timezone_semantics() {
        let naive = RivetType::Timestamp {
            unit: TimeUnit::Microsecond,
            timezone: None,
        };
        let utc = RivetType::Timestamp {
            unit: TimeUnit::Microsecond,
            timezone: Some("UTC".into()),
        };
        assert_eq!(
            rivet_type_to_arrow(&naive),
            Some(DataType::Timestamp(ArrowTimeUnit::Microsecond, None))
        );
        assert_eq!(
            rivet_type_to_arrow(&utc),
            Some(DataType::Timestamp(
                ArrowTimeUnit::Microsecond,
                Some("UTC".into())
            ))
        );
    }

    #[test]
    fn unsupported_returns_no_arrow_type() {
        let t = RivetType::Unsupported {
            native_type: "interval".into(),
            reason: "no mapping yet".into(),
        };
        assert_eq!(rivet_type_to_arrow(&t), None);
        assert_eq!(derive_fidelity(&t), TypeFidelity::Unsupported);
    }

    #[test]
    fn json_is_logical_string_with_metadata() {
        let mapping = TypeMapping::from_source(&col("payload", "jsonb"), RivetType::Json);
        assert_eq!(mapping.fidelity, TypeFidelity::LogicalString);
        assert_eq!(mapping.arrow_type, Some(DataType::Utf8));

        let field = build_arrow_field(&mapping).expect("field");
        assert_eq!(field.data_type(), &DataType::Utf8);
        assert_eq!(
            field.metadata().get(META_NATIVE_TYPE).map(String::as_str),
            Some("jsonb")
        );
        assert_eq!(
            field.metadata().get(META_LOGICAL_TYPE).map(String::as_str),
            Some("json")
        );
        assert_eq!(
            field.metadata().get(META_FIDELITY).map(String::as_str),
            Some("logical_string")
        );
    }

    #[test]
    fn uuid_is_compatible_with_logical_metadata() {
        let mapping = TypeMapping::from_source(&col("id", "uuid"), RivetType::Uuid);
        assert_eq!(mapping.fidelity, TypeFidelity::Compatible);

        let field = build_arrow_field(&mapping).expect("field");
        assert_eq!(
            field.metadata().get(META_LOGICAL_TYPE).map(String::as_str),
            Some("uuid")
        );
        assert_eq!(
            field.metadata().get(META_FIDELITY).map(String::as_str),
            Some("compatible")
        );
    }

    #[test]
    fn plain_string_has_no_logical_type_metadata() {
        let mapping = TypeMapping::from_source(&col("name", "text"), RivetType::String);
        let field = build_arrow_field(&mapping).expect("field");
        assert!(
            !field.metadata().contains_key(META_LOGICAL_TYPE),
            "plain string columns must NOT carry rivet.logical_type so consumers \
             can distinguish them from json/uuid columns"
        );
        assert_eq!(
            field.metadata().get(META_NATIVE_TYPE).map(String::as_str),
            Some("text")
        );
        assert_eq!(
            field.metadata().get(META_FIDELITY).map(String::as_str),
            Some("exact")
        );
    }

    #[test]
    fn binary_stays_binary_not_string() {
        let mapping = TypeMapping::from_source(&col("payload", "bytea"), RivetType::Binary);
        let field = build_arrow_field(&mapping).expect("field");
        assert_eq!(field.data_type(), &DataType::Binary);
        assert_eq!(mapping.fidelity, TypeFidelity::Exact);
    }

    #[test]
    fn unsupported_yields_no_field() {
        let unsupported = RivetType::Unsupported {
            native_type: "interval".into(),
            reason: "no mapping".into(),
        };
        let mapping = TypeMapping::from_source(&col("dur", "interval"), unsupported);
        assert!(
            build_arrow_field(&mapping).is_none(),
            "Unsupported must NOT silently produce a Utf8 field — that's exactly the \
             silent-degradation pattern the roadmap forbids (§5)"
        );
    }

    #[test]
    fn nullable_flag_propagates_from_source_column() {
        let nullable = SourceColumn::simple("a", "int4", true);
        let not_nullable = SourceColumn::simple("b", "int4", false);
        let m_nullable = TypeMapping::from_source(&nullable, RivetType::Int32);
        let m_required = TypeMapping::from_source(&not_nullable, RivetType::Int32);
        assert!(build_arrow_field(&m_nullable).expect("f").is_nullable());
        assert!(!build_arrow_field(&m_required).expect("f").is_nullable());
    }

    #[test]
    fn warnings_are_attachable_via_builder() {
        let mapping = TypeMapping::from_source(&col("x", "int4"), RivetType::Int32)
            .with_warning("autodetect uncertainty");
        assert_eq!(mapping.warnings, vec!["autodetect uncertainty".to_string()]);
    }
}