Skip to main content

rivet/types/
mapping.rs

1//! `RivetType` → `arrow::DataType` + Arrow metadata.
2//!
//! See `rivet_roadmap.md`, Epic 14: §5 (Type Mapping Pipeline) and §14
//! ("Binary, UUID, JSON" — metadata example). This module is intentionally
5//! the *only* place where `RivetType` becomes an `arrow::DataType`. Source
6//! drivers must not poke at `arrow::DataType` directly any more — they
7//! produce a [`SourceColumn`], call into a vendor-specific
8//! `<vendor>_to_rivet()` (Chunks 2/3), then hand the resulting
9//! [`TypeMapping`] to [`build_arrow_field`] here.
10//!
11//! Why funnel everything through one function:
12//!
13//! - It guarantees the metadata key set ([`META_NATIVE_TYPE`],
14//!   [`META_LOGICAL_TYPE`], [`META_FIDELITY`]) is identical regardless of
15//!   the source database, so downstream consumers (e.g. BigQuery target
16//!   check) can rely on it.
17//! - It keeps Arrow as a *target language*, not a public API — Chunks 4–8
18//!   add policy and overrides without each one needing to know about
19//!   `Field::with_metadata`.
20
21use std::collections::HashMap;
22
23use arrow::datatypes::{DataType, Field, TimeUnit as ArrowTimeUnit};
24use serde::Serialize;
25use std::sync::Arc;
26
27use super::{RivetType, SourceColumn, TimeUnit, TypeFidelity};
28
/// Field-metadata key under which the source database's own type name is
/// stored (for example `numeric(18,2)` or `jsonb`).
///
/// Read by the type-report CLI (Chunk 5) and by future BigQuery / Snowflake
/// target checks, so they can produce hints like
/// "this column came from `numeric(18,2)`".
pub const META_NATIVE_TYPE: &str = "rivet.native_type";

/// Field-metadata key under which the Rivet logical type is stored.
///
/// Only meaningful for columns whose physical Arrow type is a generic
/// `Utf8` while the source meaning is more specific (e.g. `json`, `uuid`,
/// `enum`, `interval`); plain text columns do not carry it.
pub const META_LOGICAL_TYPE: &str = "rivet.logical_type";

/// Field-metadata key under which the [`TypeFidelity`] label is stored.
///
/// CI / strict-mode tooling can inspect it to assert that no field in a
/// produced Parquet schema is `lossy` or `unsupported`.
pub const META_FIDELITY: &str = "rivet.fidelity";
42
/// One row of the Type Mapping Pipeline (roadmap §6 `TypeMapping`).
///
/// Carries the full provenance from a source-DB column to its eventual
/// Arrow representation. The struct is what the type-report CLI prints
/// and what `TypePolicy` validates against.
///
/// Normally built with [`TypeMapping::from_source`], which derives
/// `arrow_type` and `fidelity` from `rivet_type`, and turned into an Arrow
/// field with [`build_arrow_field`].
#[derive(Debug, Clone, Serialize)]
pub struct TypeMapping {
    /// Column name (matches `SourceColumn::name`); becomes the Arrow field
    /// name.
    pub column_name: String,
    /// Native source-DB type identifier (`numeric(18,2)`, `timestamptz`,
    /// `jsonb`, …). Copied verbatim into the `rivet.native_type` field
    /// metadata by [`build_arrow_field`].
    pub source_native_type: String,
    /// The canonical Rivet type produced by the vendor mapper.
    pub rivet_type: RivetType,
    /// Resolved Arrow type, or `None` when [`rivet_type_to_arrow`] has no
    /// mapping — i.e. for [`RivetType::Unsupported`] and for a
    /// [`RivetType::List`] whose element type is unsupported — until a
    /// policy turns it into something exportable.
    ///
    /// Kept as `arrow::DataType` (not a stringly-typed name) so the
    /// pipeline can build an `arrow::Schema` directly from a
    /// `Vec<TypeMapping>`. When serialized, `Some(dt)` is rendered as the
    /// `Debug` string of `dt` (see `serialize_arrow_type_opt`).
    #[serde(serialize_with = "serialize_arrow_type_opt")]
    pub arrow_type: Option<DataType>,
    /// Fidelity classification — see [`TypeFidelity`]. Its `label()` is
    /// written to the `rivet.fidelity` field metadata.
    pub fidelity: TypeFidelity,
    /// True when the source schema declares the column nullable. Threaded
    /// from `SourceColumn::nullable` so [`build_arrow_field`] doesn't need
    /// the original column.
    pub nullable: bool,
    /// Diagnostic strings emitted by the mapper or the policy. Surfaced by
    /// the type-report and the strict-mode failure message.
    pub warnings: Vec<String>,
}
75
76impl TypeMapping {
77    /// Build a mapping from a [`SourceColumn`] and an already-resolved
78    /// [`RivetType`]. The Arrow type is computed by [`rivet_type_to_arrow`]
79    /// and the fidelity by [`derive_fidelity`].
80    ///
81    /// This is the canonical constructor used by every vendor mapper —
82    /// Chunks 2/3 will call this once per source column.
83    pub fn from_source(source: &SourceColumn, rivet_type: RivetType) -> Self {
84        let fidelity = derive_fidelity(&rivet_type);
85        let arrow_type = rivet_type_to_arrow(&rivet_type);
86        Self {
87            column_name: source.name.clone(),
88            source_native_type: source.native_type.clone(),
89            rivet_type,
90            arrow_type,
91            fidelity,
92            nullable: source.nullable,
93            warnings: Vec::new(),
94        }
95    }
96
97    /// Append a warning visible to the type-report and to logs.
98    #[allow(dead_code)]
99    pub fn with_warning(mut self, msg: impl Into<String>) -> Self {
100        self.warnings.push(msg.into());
101        self
102    }
103}
104
105fn serialize_arrow_type_opt<S: serde::Serializer>(
106    v: &Option<DataType>,
107    s: S,
108) -> std::result::Result<S::Ok, S::Error> {
109    match v {
110        None => s.serialize_none(),
111        Some(dt) => s.serialize_some(&format!("{dt:?}")),
112    }
113}
114
115/// Map [`RivetType`] → [`arrow::DataType`].
116///
117/// This is the *only* place where Arrow types are constructed from Rivet
118/// types. Source drivers must not duplicate this logic; they go through
119/// [`TypeMapping::from_source`] instead.
120///
121/// Returns `None` for [`RivetType::Unsupported`] — the policy layer
122/// (Chunk 4) is responsible for either failing the run or rewriting the
123/// `RivetType` into something supported (e.g. `Unsupported -> String`).
124pub fn rivet_type_to_arrow(t: &RivetType) -> Option<DataType> {
125    match t {
126        RivetType::Bool => Some(DataType::Boolean),
127        RivetType::Int16 => Some(DataType::Int16),
128        RivetType::Int32 => Some(DataType::Int32),
129        RivetType::Int64 => Some(DataType::Int64),
130        RivetType::UInt64 => Some(DataType::UInt64),
131        RivetType::Float32 => Some(DataType::Float32),
132        RivetType::Float64 => Some(DataType::Float64),
133        RivetType::Decimal { precision, scale } => Some(decimal_arrow_type(*precision, *scale)),
134        RivetType::Date => Some(DataType::Date32),
135        RivetType::Time { unit } => Some(DataType::Time64(arrow_unit(*unit))),
136        RivetType::Timestamp { unit, timezone } => Some(DataType::Timestamp(
137            arrow_unit(*unit),
138            timezone.as_deref().map(Into::into),
139        )),
140        // Logical-string types: physical Arrow is Utf8; the metadata
141        // attached by `build_arrow_field` records that the source meant
142        // something more specific (json/uuid/enum).
143        RivetType::String
144        | RivetType::Text
145        | RivetType::Json
146        | RivetType::Uuid
147        | RivetType::Enum => Some(DataType::Utf8),
148
149        RivetType::Binary => Some(DataType::Binary),
150
151        // Interval → Utf8 (ISO 8601 duration string, e.g. "P1Y2M3D").
152        // Arrow's Interval(MonthDayNano) cannot be written to Parquet, so we
153        // serialise to a lossless text representation in the source driver.
154        RivetType::Interval => Some(DataType::Utf8),
155
156        // One-dimensional array: recursively resolve the inner element type.
157        // Returns None if the inner type itself is Unsupported.
158        RivetType::List { inner } => rivet_type_to_arrow(inner)
159            .map(|inner_dt| DataType::List(Arc::new(Field::new("item", inner_dt, true)))),
160
161        RivetType::Unsupported { .. } => None,
162    }
163}
164
165/// Decimal128 vs Decimal256 selection per roadmap §12 ("Exact Decimal Support"):
166/// `Decimal128(p,s)` when `p <= 38`, `Decimal256(p,s)` otherwise.
167///
168/// Negative scale is allowed by PostgreSQL `numeric(p,-s)` and is forwarded
169/// through unchanged — Arrow / Parquet accept it on Decimal128/256.
170fn decimal_arrow_type(precision: u8, scale: i8) -> DataType {
171    if precision <= 38 {
172        DataType::Decimal128(precision, scale)
173    } else {
174        DataType::Decimal256(precision, scale)
175    }
176}
177
178fn arrow_unit(u: TimeUnit) -> ArrowTimeUnit {
179    match u {
180        TimeUnit::Second => ArrowTimeUnit::Second,
181        TimeUnit::Millisecond => ArrowTimeUnit::Millisecond,
182        TimeUnit::Microsecond => ArrowTimeUnit::Microsecond,
183        TimeUnit::Nanosecond => ArrowTimeUnit::Nanosecond,
184    }
185}
186
187/// Compute the [`TypeFidelity`] for a freshly-resolved [`RivetType`].
188///
189/// The output of every vendor mapper goes through this so the fidelity
190/// label is computed in *exactly one place* and the type-report stays
191/// consistent across PostgreSQL / MySQL / future drivers.
192pub fn derive_fidelity(t: &RivetType) -> TypeFidelity {
193    match t {
194        RivetType::Bool
195        | RivetType::Int16
196        | RivetType::Int32
197        | RivetType::Int64
198        | RivetType::UInt64
199        | RivetType::Float32
200        | RivetType::Float64
201        | RivetType::Decimal { .. }
202        | RivetType::Date
203        | RivetType::Time { .. }
204        | RivetType::Timestamp { .. }
205        | RivetType::String
206        | RivetType::Text
207        | RivetType::Binary => TypeFidelity::Exact,
208
209        // UUID round-trips losslessly as text but the physical type is not
210        // the canonical FixedSizeBinary(16) — call it `compatible`.
211        RivetType::Uuid => TypeFidelity::Compatible,
212
213        // JSON is preserved byte-for-byte but its native semantics
214        // (object/array tree) are not — call it `logical_string`.
215        RivetType::Json => TypeFidelity::LogicalString,
216
217        // Enum labels are text — value preserved, but native enum semantics
218        // (ordered labels, constraint) are not enforced in Arrow.
219        RivetType::Enum => TypeFidelity::Compatible,
220
221        // Interval: Arrow IntervalMonthDayNano preserves all three components
222        // exactly; downstream tools may interpret it differently.
223        RivetType::Interval => TypeFidelity::Compatible,
224
225        // List: Arrow List preserves element values; 1-D only currently.
226        RivetType::List { .. } => TypeFidelity::Compatible,
227
228        RivetType::Unsupported { .. } => TypeFidelity::Unsupported,
229    }
230}
231
232/// Build an `arrow::Field` from a [`TypeMapping`], attaching the standard
233/// metadata keys ([`META_NATIVE_TYPE`], [`META_LOGICAL_TYPE`],
234/// [`META_FIDELITY`]).
235///
236/// Returns `None` if the mapping has no resolved Arrow type (i.e. the
237/// `RivetType` is `Unsupported` and no policy has rewritten it). Callers
238/// must surface this as a type-policy decision, not as a panic.
239pub fn build_arrow_field(mapping: &TypeMapping) -> Option<Field> {
240    let dt = mapping.arrow_type.clone()?;
241    let mut metadata: HashMap<String, String> = HashMap::new();
242    metadata.insert(META_NATIVE_TYPE.into(), mapping.source_native_type.clone());
243    metadata.insert(META_FIDELITY.into(), mapping.fidelity.label().into());
244    if let Some(logical) = logical_type_label(&mapping.rivet_type) {
245        metadata.insert(META_LOGICAL_TYPE.into(), logical.into());
246    }
247    Some(Field::new(&mapping.column_name, dt, mapping.nullable).with_metadata(metadata))
248}
249
250/// Return the `rivet.logical_type` value for types whose physical Arrow
251/// representation is a generic container (Utf8, Binary) but whose source
252/// semantic is more specific (`json`, `uuid`). `None` when the physical
253/// type already encodes the semantic (e.g. `Decimal128(18,2)` is already
254/// "decimal" in Arrow / Parquet).
255fn logical_type_label(t: &RivetType) -> Option<&'static str> {
256    match t {
257        RivetType::Json => Some("json"),
258        RivetType::Uuid => Some("uuid"),
259        RivetType::Enum => Some("enum"),
260        RivetType::Interval => Some("interval"),
261        _ => None,
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    /// Nullable source column with the given name and native type.
    fn col(name: &str, native: &str) -> SourceColumn {
        SourceColumn::simple(name, native, true)
    }

    #[test]
    fn integer_types_map_one_to_one() {
        for (rt, expected) in [
            (RivetType::Bool, DataType::Boolean),
            (RivetType::Int16, DataType::Int16),
            (RivetType::Int32, DataType::Int32),
            (RivetType::Int64, DataType::Int64),
            (RivetType::UInt64, DataType::UInt64),
            (RivetType::Float32, DataType::Float32),
            (RivetType::Float64, DataType::Float64),
        ] {
            assert_eq!(
                rivet_type_to_arrow(&rt),
                Some(expected),
                "rivet type {rt:?}"
            );
            assert_eq!(derive_fidelity(&rt), TypeFidelity::Exact);
        }
    }

    #[test]
    fn decimal_p38_uses_decimal128() {
        for p in [1u8, 18, 38] {
            let dt = rivet_type_to_arrow(&RivetType::Decimal {
                precision: p,
                scale: 2,
            })
            .expect("decimal must map to an Arrow type");
            assert_eq!(dt, DataType::Decimal128(p, 2), "precision={p}");
        }
    }

    /// Roadmap §12: precision >38 must escalate to Decimal256, not silently
    /// truncate or fall back to Float64. This is the single most important
    /// invariant of the whole Type Safety Foundation.
    #[test]
    fn decimal_above_38_escalates_to_decimal256() {
        for p in [39u8, 76] {
            let dt = rivet_type_to_arrow(&RivetType::Decimal {
                precision: p,
                scale: 9,
            })
            .expect("decimal must map to an Arrow type");
            assert_eq!(
                dt,
                DataType::Decimal256(p, 9),
                "precision={p} must become Decimal256"
            );
        }
    }

    /// Roadmap §12: PostgreSQL `numeric(5,-2)` rounds to hundreds; the type
    /// system must round-trip the negative scale.
    #[test]
    fn decimal_supports_negative_scale_for_postgres_numeric() {
        let dt = rivet_type_to_arrow(&RivetType::Decimal {
            precision: 5,
            scale: -2,
        })
        .expect("decimal must map to an Arrow type");
        assert_eq!(dt, DataType::Decimal128(5, -2));
    }

    #[test]
    fn date_and_time_map_to_arrow_temporal_types() {
        let time = RivetType::Time {
            unit: TimeUnit::Microsecond,
        };
        assert_eq!(rivet_type_to_arrow(&RivetType::Date), Some(DataType::Date32));
        assert_eq!(
            rivet_type_to_arrow(&time),
            Some(DataType::Time64(ArrowTimeUnit::Microsecond))
        );
        assert_eq!(derive_fidelity(&RivetType::Date), TypeFidelity::Exact);
        assert_eq!(derive_fidelity(&time), TypeFidelity::Exact);
    }

    #[test]
    fn timestamp_preserves_timezone_semantics() {
        let naive = RivetType::Timestamp {
            unit: TimeUnit::Microsecond,
            timezone: None,
        };
        let utc = RivetType::Timestamp {
            unit: TimeUnit::Microsecond,
            timezone: Some("UTC".into()),
        };
        assert_eq!(
            rivet_type_to_arrow(&naive),
            Some(DataType::Timestamp(ArrowTimeUnit::Microsecond, None))
        );
        assert_eq!(
            rivet_type_to_arrow(&utc),
            Some(DataType::Timestamp(
                ArrowTimeUnit::Microsecond,
                Some("UTC".into())
            ))
        );
    }

    // `tsvector` stands in for a native type with no mapping; `interval`
    // is no longer a good example because it maps to `RivetType::Interval`.
    #[test]
    fn unsupported_returns_no_arrow_type() {
        let t = RivetType::Unsupported {
            native_type: "tsvector".into(),
            reason: "no mapping yet".into(),
        };
        assert_eq!(rivet_type_to_arrow(&t), None);
        assert_eq!(derive_fidelity(&t), TypeFidelity::Unsupported);
    }

    #[test]
    fn json_is_logical_string_with_metadata() {
        let mapping = TypeMapping::from_source(&col("payload", "jsonb"), RivetType::Json);
        assert_eq!(mapping.fidelity, TypeFidelity::LogicalString);
        assert_eq!(mapping.arrow_type, Some(DataType::Utf8));

        let field = build_arrow_field(&mapping).expect("field");
        assert_eq!(field.data_type(), &DataType::Utf8);
        assert_eq!(
            field.metadata().get(META_NATIVE_TYPE).map(String::as_str),
            Some("jsonb")
        );
        assert_eq!(
            field.metadata().get(META_LOGICAL_TYPE).map(String::as_str),
            Some("json")
        );
        assert_eq!(
            field.metadata().get(META_FIDELITY).map(String::as_str),
            Some("logical_string")
        );
    }

    #[test]
    fn uuid_is_compatible_with_logical_metadata() {
        let mapping = TypeMapping::from_source(&col("id", "uuid"), RivetType::Uuid);
        assert_eq!(mapping.fidelity, TypeFidelity::Compatible);

        let field = build_arrow_field(&mapping).expect("field");
        assert_eq!(
            field.metadata().get(META_LOGICAL_TYPE).map(String::as_str),
            Some("uuid")
        );
        assert_eq!(
            field.metadata().get(META_FIDELITY).map(String::as_str),
            Some("compatible")
        );
    }

    #[test]
    fn enum_is_compatible_utf8_with_logical_metadata() {
        let mapping = TypeMapping::from_source(&col("status", "mood"), RivetType::Enum);
        assert_eq!(mapping.arrow_type, Some(DataType::Utf8));
        assert_eq!(mapping.fidelity, TypeFidelity::Compatible);

        let field = build_arrow_field(&mapping).expect("field");
        assert_eq!(
            field.metadata().get(META_LOGICAL_TYPE).map(String::as_str),
            Some("enum")
        );
    }

    /// Interval is exported as an ISO 8601 duration string (Arrow intervals
    /// cannot be written to Parquet) and tagged so consumers can recover
    /// the semantic.
    #[test]
    fn interval_is_utf8_with_logical_metadata() {
        let mapping = TypeMapping::from_source(&col("dur", "interval"), RivetType::Interval);
        assert_eq!(mapping.arrow_type, Some(DataType::Utf8));
        assert_eq!(mapping.fidelity, TypeFidelity::Compatible);

        let field = build_arrow_field(&mapping).expect("field");
        assert_eq!(
            field.metadata().get(META_LOGICAL_TYPE).map(String::as_str),
            Some("interval")
        );
        assert_eq!(
            field.metadata().get(META_FIDELITY).map(String::as_str),
            Some("compatible")
        );
    }

    #[test]
    fn plain_string_has_no_logical_type_metadata() {
        let mapping = TypeMapping::from_source(&col("name", "text"), RivetType::String);
        let field = build_arrow_field(&mapping).expect("field");
        assert!(
            !field.metadata().contains_key(META_LOGICAL_TYPE),
            "plain string columns must NOT carry rivet.logical_type so consumers \
             can distinguish them from json/uuid columns"
        );
        assert_eq!(
            field.metadata().get(META_NATIVE_TYPE).map(String::as_str),
            Some("text")
        );
        assert_eq!(
            field.metadata().get(META_FIDELITY).map(String::as_str),
            Some("exact")
        );
    }

    #[test]
    fn binary_stays_binary_not_string() {
        // Roadmap §14: binary columns must never be silently exported as Utf8.
        let mapping = TypeMapping::from_source(&col("payload", "bytea"), RivetType::Binary);
        let field = build_arrow_field(&mapping).expect("field");
        assert_eq!(field.data_type(), &DataType::Binary);
        assert_eq!(mapping.fidelity, TypeFidelity::Exact);
    }

    #[test]
    fn list_wraps_element_type_in_arrow_list() {
        let list = RivetType::List {
            inner: Box::new(RivetType::Int32),
        };
        assert_eq!(
            rivet_type_to_arrow(&list),
            Some(DataType::List(Arc::new(Field::new(
                "item",
                DataType::Int32,
                true
            ))))
        );
        assert_eq!(derive_fidelity(&list), TypeFidelity::Compatible);
    }

    #[test]
    fn list_of_unsupported_has_no_arrow_type() {
        let list = RivetType::List {
            inner: Box::new(RivetType::Unsupported {
                native_type: "tsvector".into(),
                reason: "no mapping".into(),
            }),
        };
        assert_eq!(rivet_type_to_arrow(&list), None);

        let mapping = TypeMapping::from_source(&col("docs", "tsvector[]"), list);
        assert!(
            build_arrow_field(&mapping).is_none(),
            "a list of an unsupported element must not produce a field"
        );
    }

    #[test]
    fn unsupported_yields_no_field() {
        let unsupported = RivetType::Unsupported {
            native_type: "tsvector".into(),
            reason: "no mapping".into(),
        };
        let mapping = TypeMapping::from_source(&col("doc", "tsvector"), unsupported);
        assert!(
            build_arrow_field(&mapping).is_none(),
            "Unsupported must NOT silently produce a Utf8 field — that's exactly the \
             silent-degradation pattern the roadmap forbids (§5)"
        );
    }

    #[test]
    fn nullable_flag_propagates_from_source_column() {
        let nullable = SourceColumn::simple("a", "int4", true);
        let not_nullable = SourceColumn::simple("b", "int4", false);
        let m_nullable = TypeMapping::from_source(&nullable, RivetType::Int32);
        let m_required = TypeMapping::from_source(&not_nullable, RivetType::Int32);
        assert!(build_arrow_field(&m_nullable).expect("f").is_nullable());
        assert!(!build_arrow_field(&m_required).expect("f").is_nullable());
    }

    #[test]
    fn warnings_are_attachable_via_builder() {
        let mapping = TypeMapping::from_source(&col("x", "int4"), RivetType::Int32)
            .with_warning("autodetect uncertainty");
        assert_eq!(mapping.warnings, vec!["autodetect uncertainty".to_string()]);
    }
}
465}