1use arrow::datatypes::DataType;
32use serde::Serialize;
33
34use super::{RivetType, TypeFidelity, TypeMapping};
35
/// Engine/warehouse a Parquet export is being prepared for.
///
/// Each target has its own native-type mapping and its own idea of what the
/// autoloaded (schema-inferred) column type will be; see
/// [`ExportTarget::resolve_column`] and [`ExportTarget::recovery_sql`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExportTarget {
    /// DuckDB. The mapping in this file reports autoload == native for every type,
    /// and `recovery_sql` returns `None` for it.
    DuckDb,
    /// Google BigQuery (`bq load --autodetect`).
    BigQuery,
    /// Snowflake (`INFER_SCHEMA` + `COPY INTO`).
    Snowflake,
}
50
51impl ExportTarget {
52 pub fn parse(s: &str) -> Option<Self> {
53 match s.to_lowercase().as_str() {
54 "bigquery" | "bq" => Some(Self::BigQuery),
55 "duckdb" | "duck" => Some(Self::DuckDb),
56 "snowflake" | "sf" => Some(Self::Snowflake),
57 _ => None,
58 }
59 }
60
61 pub fn label(self) -> &'static str {
62 match self {
63 Self::BigQuery => "bigquery",
64 Self::DuckDb => "duckdb",
65 Self::Snowflake => "snowflake",
66 }
67 }
68
69 pub fn resolve_column(self, input: TargetInput<'_>) -> TargetColumnSpec {
71 let mut spec = match self {
72 ExportTarget::BigQuery => bigquery::resolve(&input),
73 ExportTarget::DuckDb => duckdb::resolve(&input),
74 ExportTarget::Snowflake => snowflake::resolve(&input),
75 };
76 if input.fidelity.is_unsafe_for_strict_mode() && spec.status == TargetStatus::Ok {
80 spec.status = TargetStatus::Warn;
81 }
82 spec
83 }
84
85 pub fn resolve_table(self, mappings: &[TypeMapping]) -> Vec<TargetColumnSpec> {
90 mappings
91 .iter()
92 .map(|m| self.resolve_column(TargetInput::from(m)))
93 .collect()
94 }
95
96 pub fn recovery_sql(self, specs: &[TargetColumnSpec], table: &str) -> Option<String> {
104 match self {
105 ExportTarget::BigQuery => Some(bigquery_recovery_sql(specs, table)),
106 ExportTarget::Snowflake => Some(snowflake_recovery_sql(specs, table)),
107 ExportTarget::DuckDb => None,
108 }
109 }
110}
111
/// Verdict for one column on one export target.
///
/// Serializes as `"ok"` / `"warn"` / `"fail"` (snake_case), the same strings
/// returned by `TargetStatus::label`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum TargetStatus {
    /// Mapping is clean: native and autoloaded types agree.
    Ok,
    /// Loads, but needs attention: a native/autoload divergence, a caveat, or a
    /// source fidelity that is unsafe for strict mode.
    Warn,
    /// No usable mapping on this target.
    Fail,
}
120
121impl TargetStatus {
122 pub fn label(&self) -> &'static str {
123 match self {
124 Self::Ok => "ok",
125 Self::Warn => "warn",
126 Self::Fail => "fail",
127 }
128 }
129}
130
/// Borrowed view of one source column: everything a target resolver reads.
///
/// Usually built from a `TypeMapping` via the `From<&TypeMapping>` impl.
#[derive(Debug, Clone, Copy)]
pub struct TargetInput<'a> {
    /// Column name; substituted for `{col}` in cast templates when building the spec.
    pub column_name: &'a str,
    /// Rivet logical type; the per-target resolvers dispatch on this.
    pub rivet_type: &'a RivetType,
    /// Arrow type of the column, when known.
    // Not read by any resolver in this file, hence the dead_code allowance.
    // NOTE(review): confirm whether callers still need it.
    #[allow(dead_code)]
    pub arrow_type: Option<&'a DataType>,
    /// Fidelity of the source -> Rivet mapping. If it is unsafe for strict mode,
    /// `ExportTarget::resolve_column` downgrades an `Ok` status to `Warn`.
    pub fidelity: TypeFidelity,
}
145
146impl<'a> From<&'a TypeMapping> for TargetInput<'a> {
147 fn from(m: &'a TypeMapping) -> Self {
148 TargetInput {
149 column_name: &m.column_name,
150 rivet_type: &m.rivet_type,
151 arrow_type: m.arrow_type.as_ref(),
152 fidelity: m.fidelity,
153 }
154 }
155}
156
/// Per-column result of resolving a Rivet type against an [`ExportTarget`].
///
/// Field declaration order is the serialized key order; `note` and `cast_sql` are
/// omitted from the output when `None`.
#[derive(Debug, Clone, Serialize)]
pub struct TargetColumnSpec {
    /// Source column name.
    pub column_name: String,
    /// Recommended native type on the target; `"-"` when the column cannot be mapped.
    pub target_type: String,
    /// Type the target's schema autodetection is expected to produce. Equal to
    /// `target_type` unless the mapping diverges.
    pub autoload_type: String,
    /// Overall verdict for the column.
    pub status: TargetStatus,
    /// Explanation attached to `Warn`/`Fail` results.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub note: Option<String>,
    /// SQL expression (column name already substituted) that converts the autoloaded
    /// column back to `target_type`; `None` when no post-load cast is offered.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cast_sql: Option<String>,
}
177
/// Target-resolver result before the column name is attached (see `into_spec`).
struct Resolved {
    /// Recommended native type on the target.
    target_type: String,
    /// Type produced by the target's autoload/schema inference.
    autoload_type: String,
    status: TargetStatus,
    /// Explanation for `Warn`/`Fail`.
    note: Option<String>,
    /// Cast template; every `{col}` is replaced with the column name in `into_spec`.
    cast: Option<String>,
}
188
189impl Resolved {
190 fn ok(t: impl Into<String>) -> Self {
191 let t = t.into();
192 Self {
193 autoload_type: t.clone(),
194 target_type: t,
195 status: TargetStatus::Ok,
196 note: None,
197 cast: None,
198 }
199 }
200 fn diverge(
203 native: impl Into<String>,
204 autoload: impl Into<String>,
205 note: impl Into<String>,
206 cast: Option<&str>,
207 ) -> Self {
208 Self {
209 target_type: native.into(),
210 autoload_type: autoload.into(),
211 status: TargetStatus::Warn,
212 note: Some(note.into()),
213 cast: cast.map(str::to_string),
214 }
215 }
216 fn warn(t: impl Into<String>, note: impl Into<String>) -> Self {
217 let t = t.into();
218 Self {
219 autoload_type: t.clone(),
220 target_type: t,
221 status: TargetStatus::Warn,
222 note: Some(note.into()),
223 cast: None,
224 }
225 }
226 fn fail(note: impl Into<String>) -> Self {
227 Self {
228 target_type: "-".into(),
229 autoload_type: "-".into(),
230 status: TargetStatus::Fail,
231 note: Some(note.into()),
232 cast: None,
233 }
234 }
235 fn into_spec(self, input: &TargetInput<'_>) -> TargetColumnSpec {
236 TargetColumnSpec {
237 column_name: input.column_name.to_string(),
238 target_type: self.target_type,
239 autoload_type: self.autoload_type,
240 status: self.status,
241 note: self.note,
242 cast_sql: self.cast.map(|t| t.replace("{col}", input.column_name)),
243 }
244 }
245}
246
247fn unsupported_reason(t: &RivetType) -> String {
248 match t {
249 RivetType::Unsupported { reason, .. } => reason.clone(),
250 _ => "no target mapping".into(),
251 }
252}
253
254fn recovery_projection(specs: &[TargetColumnSpec], passthrough: impl Fn(&str) -> String) -> String {
269 specs
270 .iter()
271 .map(|s| match &s.cast_sql {
272 Some(cast) => format!(" {cast} AS {name}", name = s.column_name),
273 None => passthrough(&s.column_name),
274 })
275 .collect::<Vec<_>>()
276 .join(",\n")
277}
278
279fn bigquery_recovery_sql(specs: &[TargetColumnSpec], table: &str) -> String {
280 let cols = recovery_projection(specs, |name| format!(" {name}"));
281 format!(
282 "-- 1) bq load --autodetect --parquet_enable_list_inference \
283 --source_format=PARQUET {table}__staging <parquet>\n\
284 -- 2) recover native types:\n\
285 CREATE OR REPLACE TABLE `{table}` AS\n\
286 SELECT\n{cols}\n\
287 FROM `{table}__staging`;"
288 )
289}
290
291fn snowflake_recovery_sql(specs: &[TargetColumnSpec], table: &str) -> String {
297 let cols = recovery_projection(specs, |name| format!(" \"{name}\" AS {name}"));
298 format!(
299 "-- 1) ALTER SESSION SET TIMEZONE='UTC';\n\
300 -- 2) CREATE OR REPLACE FILE FORMAT rivet_pq TYPE=PARQUET BINARY_AS_TEXT=FALSE;\n\
301 -- 3) PUT file://<parquet> @<stage> AUTO_COMPRESS=FALSE;\n\
302 -- 4) CREATE OR REPLACE TABLE {table}__staging USING TEMPLATE (SELECT ARRAY_AGG(\n\
303 -- OBJECT_CONSTRUCT(*)) FROM TABLE(INFER_SCHEMA(LOCATION=>'@<stage>', FILE_FORMAT=>'rivet_pq')));\n\
304 -- COPY INTO {table}__staging FROM @<stage> FILE_FORMAT=(FORMAT_NAME='rivet_pq') MATCH_BY_COLUMN_NAME=CASE_INSENSITIVE;\n\
305 -- 5) recover native types:\n\
306 CREATE OR REPLACE TABLE {table} AS\n\
307 SELECT\n{cols}\n\
308 FROM {table}__staging;"
309 )
310}
311
312mod bigquery {
315 use super::*;
316
317 const NUMERIC_MAX_P: u8 = 29;
319 const NUMERIC_MAX_S: i8 = 9;
320 const BIGNUMERIC_MAX_P: u8 = 76;
322 const BIGNUMERIC_MAX_S: i8 = 38;
323
324 pub(super) fn resolve(input: &TargetInput<'_>) -> TargetColumnSpec {
325 native(input.rivet_type).into_spec(input)
326 }
327
328 fn native(t: &RivetType) -> Resolved {
329 match t {
330 RivetType::Bool => Resolved::ok("BOOL"),
331 RivetType::Int16 | RivetType::Int32 | RivetType::Int64 => Resolved::ok("INT64"),
332 RivetType::UInt64 => Resolved::diverge(
337 "NUMERIC",
338 "INT64",
339 "UINT64 > INT64_MAX overflows the INT64 autoload and cannot be recovered after \
340 load — map the column to decimal(20,0) with a source column override",
341 None,
342 ),
343 RivetType::Float32 | RivetType::Float64 => Resolved::ok("FLOAT64"),
344 RivetType::Decimal { precision, scale } => decimal(*precision, *scale),
345 RivetType::Date => Resolved::ok("DATE"),
346 RivetType::Time { .. } => Resolved::ok("TIME"),
347 RivetType::Timestamp {
349 timezone: Some(_), ..
350 } => Resolved::ok("TIMESTAMP"),
351 RivetType::Timestamp { timezone: None, .. } => Resolved::diverge(
355 "DATETIME",
356 "TIMESTAMP",
357 "naive timestamp autoloads as TIMESTAMP (an instant); recover wall-clock with \
358 DATETIME(col) after load",
359 Some("DATETIME({col})"),
360 ),
361 RivetType::String | RivetType::Text | RivetType::Enum => Resolved::ok("STRING"),
362 RivetType::Binary => Resolved::ok("BYTES"),
363 RivetType::Json => Resolved::diverge(
366 "JSON",
367 "BYTES",
368 "Parquet JSON logical type autoloads as BYTES in BigQuery; recover native JSON \
369 with PARSE_JSON(SAFE_CONVERT_BYTES_TO_STRING(col)) after load",
370 Some("PARSE_JSON(SAFE_CONVERT_BYTES_TO_STRING({col}))"),
371 ),
372 RivetType::Uuid => Resolved::diverge(
376 "STRING",
377 "BYTES",
378 "UUID autoloads as 16-byte BYTES in BigQuery; recover hex text with TO_HEX(col) \
379 after load (or keep BYTES)",
380 Some("TO_HEX({col})"),
381 ),
382 RivetType::Interval => Resolved::ok("STRING"),
383 RivetType::List { inner } => list(inner),
384 RivetType::Unsupported { .. } => Resolved::fail(unsupported_reason(t)),
385 }
386 }
387
388 fn decimal(p: u8, s: i8) -> Resolved {
389 if s < 0 {
390 return Resolved::fail(format!(
391 "BigQuery has no negative scale; decimal({p},{s}) needs a STRING/INT64 cast"
392 ));
393 }
394 let native = if p <= NUMERIC_MAX_P && s <= NUMERIC_MAX_S {
395 "NUMERIC"
396 } else if p <= BIGNUMERIC_MAX_P && s <= BIGNUMERIC_MAX_S {
397 "BIGNUMERIC"
398 } else {
399 return Resolved::fail(format!(
400 "decimal({p},{s}) exceeds BigQuery BIGNUMERIC limits (max 76,38)"
401 ));
402 };
403 Resolved::ok(native)
404 }
405
406 fn list(inner: &RivetType) -> Resolved {
407 let inner_r = native(inner);
408 if inner_r.status == TargetStatus::Fail {
409 return Resolved::fail(format!(
410 "REPEATED of unsupported element: {}",
411 inner_r.target_type
412 ));
413 }
414 Resolved::diverge(
421 format!("REPEATED {}", inner_r.target_type),
422 format!("REPEATED RECORD{{item {}}}", inner_r.autoload_type),
423 "arrays load as REPEATED RECORD{item}; load the staging table with \
424 --parquet_enable_list_inference, then flatten with UNNEST after load",
425 Some("ARRAY(SELECT el.item FROM UNNEST({col}) AS el)"),
426 )
427 }
428}
429
430mod duckdb {
433 use super::*;
434
435 pub(super) fn resolve(input: &TargetInput<'_>) -> TargetColumnSpec {
436 native(input.rivet_type).into_spec(input)
437 }
438
439 fn native(t: &RivetType) -> Resolved {
442 match t {
443 RivetType::Bool => Resolved::ok("BOOLEAN"),
444 RivetType::Int16 => Resolved::ok("SMALLINT"),
445 RivetType::Int32 => Resolved::ok("INTEGER"),
446 RivetType::Int64 => Resolved::ok("BIGINT"),
447 RivetType::UInt64 => Resolved::ok("UBIGINT"),
448 RivetType::Float32 => Resolved::ok("FLOAT"),
449 RivetType::Float64 => Resolved::ok("DOUBLE"),
450 RivetType::Decimal { precision, scale } => {
451 if *scale < 0 {
452 Resolved::warn(
453 "DECIMAL",
454 format!(
455 "DuckDB has no negative scale; decimal({precision},{scale}) loads via cast"
456 ),
457 )
458 } else if *precision <= 38 {
459 Resolved::ok(format!("DECIMAL({precision},{scale})"))
460 } else {
461 Resolved::warn(
463 "DECIMAL(38,*)",
464 format!("decimal({precision},{scale}) exceeds DuckDB DECIMAL(38); widens"),
465 )
466 }
467 }
468 RivetType::Date => Resolved::ok("DATE"),
469 RivetType::Time { .. } => Resolved::ok("TIME"),
470 RivetType::Timestamp {
471 timezone: Some(_), ..
472 } => Resolved::ok("TIMESTAMPTZ"),
473 RivetType::Timestamp { timezone: None, .. } => Resolved::ok("TIMESTAMP"),
474 RivetType::String | RivetType::Text | RivetType::Enum => Resolved::ok("VARCHAR"),
475 RivetType::Binary => Resolved::ok("BLOB"),
476 RivetType::Json => Resolved::ok("JSON"),
477 RivetType::Uuid => Resolved::ok("UUID"),
478 RivetType::Interval => Resolved::ok("INTERVAL"),
479 RivetType::List { inner } => {
480 let inner_r = native(inner);
481 if inner_r.status == TargetStatus::Fail {
482 Resolved::fail(format!(
483 "LIST of unsupported element: {}",
484 inner_r.target_type
485 ))
486 } else {
487 Resolved::ok(format!("{}[]", inner_r.target_type))
488 }
489 }
490 RivetType::Unsupported { .. } => Resolved::fail(unsupported_reason(t)),
491 }
492 }
493}
494
495mod snowflake {
498 use super::*;
499
500 pub(super) fn resolve(input: &TargetInput<'_>) -> TargetColumnSpec {
501 native(input.rivet_type).into_spec(input)
502 }
503
504 fn native(t: &RivetType) -> Resolved {
509 match t {
510 RivetType::Bool => Resolved::ok("BOOLEAN"),
511 RivetType::Int16 | RivetType::Int32 | RivetType::Int64 => Resolved::ok("NUMBER(38,0)"),
512 RivetType::UInt64 => Resolved::diverge(
514 "NUMBER(20,0)",
515 "NUMBER(38,0)",
516 "UINT64 > INT64_MAX overflows the Parquet read; map to decimal(20,0) at source",
517 None,
518 ),
519 RivetType::Float32 | RivetType::Float64 => Resolved::ok("FLOAT"),
520 RivetType::Decimal { precision, scale } => {
521 if *scale < 0 {
522 Resolved::warn(
523 "NUMBER",
524 format!(
525 "Snowflake NUMBER has no negative scale; decimal({precision},{scale}) loads via cast"
526 ),
527 )
528 } else {
529 Resolved::ok(format!("NUMBER({precision},{scale})"))
530 }
531 }
532 RivetType::Date => Resolved::ok("DATE"),
533 RivetType::Time { .. } => Resolved::diverge(
535 "TIME",
536 "NUMBER(38,0)",
537 "TIME autoloads as NUMBER (µs of day); recover with TIME_FROM_PARTS after load",
538 Some(r#"TIME_FROM_PARTS(0,0,FLOOR("{col}"/1000000),MOD("{col}",1000000)*1000)"#),
539 ),
540 RivetType::Timestamp {
542 timezone: Some(_), ..
543 } => Resolved::diverge(
544 "TIMESTAMP_TZ",
545 "TIMESTAMP_NTZ",
546 "tz timestamp autoloads as TIMESTAMP_NTZ — ALTER SESSION SET TIMEZONE='UTC' before COPY so the instant matches",
547 None,
548 ),
549 RivetType::Timestamp { timezone: None, .. } => Resolved::diverge(
551 "TIMESTAMP_NTZ",
552 "NUMBER(38,0)",
553 "naive timestamp autoloads as NUMBER (µs since epoch); recover with TO_TIMESTAMP_NTZ after load",
554 Some(r#"TO_TIMESTAMP_NTZ("{col}", 6)"#),
555 ),
556 RivetType::String | RivetType::Text | RivetType::Enum => Resolved::ok("TEXT"),
557 RivetType::Binary => Resolved::warn(
559 "BINARY",
560 "set BINARY_AS_TEXT=FALSE in the Parquet FILE FORMAT or non-UTF8 bytes fail to load",
561 ),
562 RivetType::Json => Resolved::diverge(
564 "VARIANT",
565 "TEXT",
566 "JSON autoloads as TEXT; recover native VARIANT with PARSE_JSON after load",
567 Some(r#"PARSE_JSON("{col}")"#),
568 ),
569 RivetType::Uuid => Resolved::diverge(
571 "TEXT",
572 "BINARY",
573 "UUID autoloads as 16-byte BINARY; recover canonical text with HEX_ENCODE + REGEXP after load",
574 Some(
575 r#"REGEXP_REPLACE(LOWER(HEX_ENCODE("{col}")),'^(.{8})(.{4})(.{4})(.{4})(.{12})$','\\1-\\2-\\3-\\4-\\5')"#,
576 ),
577 ),
578 RivetType::Interval => Resolved::ok("TEXT"),
579 RivetType::List { inner } => {
584 let inner_r = native(inner);
585 if inner_r.status == TargetStatus::Fail {
586 Resolved::fail(format!(
587 "ARRAY of unsupported element: {}",
588 inner_r.target_type
589 ))
590 } else {
591 Resolved::diverge(
592 "ARRAY",
593 "VARIANT",
594 "list autoloads as VARIANT (the JSON array); recover native ARRAY with ::ARRAY after load",
595 Some(r#""{col}"::ARRAY"#),
596 )
597 }
598 }
599 RivetType::Unsupported { .. } => Resolved::fail(unsupported_reason(t)),
600 }
601 }
602}
603
#[cfg(test)]
mod tests {
    use super::*;

    // Column `c` with no Arrow type and `TypeFidelity::Exact`. Assumes Exact is not
    // "unsafe for strict mode", so statuses below reflect the target mapping alone.
    fn input<'a>(rt: &'a RivetType) -> TargetInput<'a> {
        TargetInput {
            column_name: "c",
            rivet_type: rt,
            arrow_type: None,
            fidelity: TypeFidelity::Exact,
        }
    }

    // Per-target shorthands over `ExportTarget::resolve_column`.
    fn bq(rt: &RivetType) -> TargetColumnSpec {
        ExportTarget::BigQuery.resolve_column(input(rt))
    }
    fn duck(rt: &RivetType) -> TargetColumnSpec {
        ExportTarget::DuckDb.resolve_column(input(rt))
    }
    fn sf(rt: &RivetType) -> TargetColumnSpec {
        ExportTarget::Snowflake.resolve_column(input(rt))
    }

    // UUID maps (STRING native / BYTES autoload, Warn) instead of failing, and the
    // cast mentions the column.
    #[test]
    fn bq_uuid_resolves_not_fails() {
        let s = bq(&RivetType::Uuid);
        assert_eq!(s.target_type, "STRING");
        assert_eq!(s.autoload_type, "BYTES");
        assert_eq!(s.status, TargetStatus::Warn);
        assert!(s.cast_sql.unwrap().contains("c"));
    }

    // JSON: native JSON, autoloaded as BYTES, recovered through PARSE_JSON.
    #[test]
    fn bq_json_native_is_json_autoload_is_bytes() {
        let s = bq(&RivetType::Json);
        assert_eq!(s.target_type, "JSON");
        assert_eq!(s.autoload_type, "BYTES");
        assert_eq!(s.status, TargetStatus::Warn);
        assert!(s.cast_sql.unwrap().starts_with("PARSE_JSON"));
    }

    // Naive timestamps should be DATETIME natively but autoload as TIMESTAMP.
    #[test]
    fn bq_naive_timestamp_is_datetime_native_timestamp_autoload() {
        let naive = RivetType::Timestamp {
            unit: super::super::TimeUnit::Microsecond,
            timezone: None,
        };
        let s = bq(&naive);
        assert_eq!(s.target_type, "DATETIME");
        assert_eq!(s.autoload_type, "TIMESTAMP");
        assert_eq!(s.status, TargetStatus::Warn);
    }

    // Zoned timestamps map cleanly to TIMESTAMP on both sides.
    #[test]
    fn bq_tz_timestamp_is_timestamp_ok() {
        let tz = RivetType::Timestamp {
            unit: super::super::TimeUnit::Microsecond,
            timezone: Some("UTC".into()),
        };
        let s = bq(&tz);
        assert_eq!(s.target_type, "TIMESTAMP");
        assert_eq!(s.autoload_type, "TIMESTAMP");
        assert_eq!(s.status, TargetStatus::Ok);
    }

    // A small decimal fits NUMERIC.
    #[test]
    fn bq_decimal_within_numeric_is_numeric() {
        let s = bq(&RivetType::Decimal {
            precision: 18,
            scale: 2,
        });
        assert_eq!(s.target_type, "NUMERIC");
        assert_eq!(s.status, TargetStatus::Ok);
    }

    // decimal(38,9) is routed to BIGNUMERIC by the (conservative) NUMERIC rule.
    #[test]
    fn bq_decimal_escalates_to_bignumeric() {
        let s = bq(&RivetType::Decimal {
            precision: 38,
            scale: 9,
        });
        assert_eq!(s.target_type, "BIGNUMERIC");
        assert_eq!(s.status, TargetStatus::Ok);
    }

    // BigQuery has no negative-scale decimals.
    #[test]
    fn bq_decimal_negative_scale_fails() {
        let s = bq(&RivetType::Decimal {
            precision: 5,
            scale: -2,
        });
        assert_eq!(s.status, TargetStatus::Fail);
    }

    // UINT64: recommend NUMERIC, warn that the INT64 autoload can overflow.
    #[test]
    fn bq_uint64_recommends_numeric_warns_overflow() {
        let s = bq(&RivetType::UInt64);
        assert_eq!(s.target_type, "NUMERIC");
        assert_eq!(s.autoload_type, "INT64");
        assert_eq!(s.status, TargetStatus::Warn);
    }

    // Lists: REPEATED natively, REPEATED RECORD{item} on autoload.
    #[test]
    fn bq_list_is_repeated_native_record_autoload() {
        let t = RivetType::List {
            inner: Box::new(RivetType::String),
        };
        let s = bq(&t);
        assert_eq!(s.target_type, "REPEATED STRING");
        assert!(s.autoload_type.contains("REPEATED RECORD"));
        assert_eq!(s.status, TargetStatus::Warn);
    }

    // Unsupported types yield a Fail row with the "-" placeholder, not a panic.
    #[test]
    fn bq_unsupported_is_fail_row_not_panic() {
        let t = RivetType::Unsupported {
            native_type: "geometry".into(),
            reason: "no mapping".into(),
        };
        let s = bq(&t);
        assert_eq!(s.status, TargetStatus::Fail);
        assert_eq!(s.target_type, "-");
    }

    // Plain scalars: native == autoload, status Ok.
    #[test]
    fn bq_standard_scalars_ok() {
        for (rt, native) in [
            (RivetType::Bool, "BOOL"),
            (RivetType::Int64, "INT64"),
            (RivetType::Float64, "FLOAT64"),
            (RivetType::Date, "DATE"),
            (RivetType::String, "STRING"),
            (RivetType::Binary, "BYTES"),
            (RivetType::Enum, "STRING"),
        ] {
            let s = bq(&rt);
            assert_eq!(s.target_type, native, "{rt:?}");
            assert_eq!(s.autoload_type, native, "{rt:?}");
            assert_eq!(s.status, TargetStatus::Ok, "{rt:?}");
        }
    }

    // Types that diverge on BigQuery must not diverge (or fail) on DuckDB.
    #[test]
    fn duckdb_reads_everything_natively() {
        let naive = RivetType::Timestamp {
            unit: super::super::TimeUnit::Microsecond,
            timezone: None,
        };
        for rt in [
            RivetType::Json,
            RivetType::Uuid,
            RivetType::UInt64,
            naive,
            RivetType::List {
                inner: Box::new(RivetType::Int64),
            },
        ] {
            let s = duck(&rt);
            assert_eq!(
                s.target_type, s.autoload_type,
                "DuckDB autoload must equal native for {rt:?}"
            );
            assert_ne!(s.status, TargetStatus::Fail, "{rt:?}");
        }
    }

    // Exact DuckDB type names for a few representative types.
    #[test]
    fn duckdb_native_type_names() {
        assert_eq!(duck(&RivetType::Json).target_type, "JSON");
        assert_eq!(duck(&RivetType::Uuid).target_type, "UUID");
        assert_eq!(duck(&RivetType::UInt64).target_type, "UBIGINT");
        assert_eq!(
            duck(&RivetType::Decimal {
                precision: 18,
                scale: 2
            })
            .target_type,
            "DECIMAL(18,2)"
        );
        assert_eq!(
            duck(&RivetType::List {
                inner: Box::new(RivetType::Int64)
            })
            .target_type,
            "BIGINT[]"
        );
    }

    // Short aliases and case-insensitive names parse; unknown names do not.
    #[test]
    fn parse_accepts_aliases() {
        assert_eq!(ExportTarget::parse("bq"), Some(ExportTarget::BigQuery));
        assert_eq!(
            ExportTarget::parse("BigQuery"),
            Some(ExportTarget::BigQuery)
        );
        assert_eq!(ExportTarget::parse("duckdb"), Some(ExportTarget::DuckDb));
        assert_eq!(ExportTarget::parse("nope"), None);
    }

    // resolve_table keeps input order and column names.
    #[test]
    fn resolve_table_preserves_order_and_names() {
        use super::super::SourceColumn;
        let mappings = vec![
            TypeMapping::from_source(&SourceColumn::simple("a", "int8", true), RivetType::Int64),
            TypeMapping::from_source(&SourceColumn::simple("b", "jsonb", true), RivetType::Json),
        ];
        let specs = ExportTarget::BigQuery.resolve_table(&mappings);
        assert_eq!(specs.len(), 2);
        assert_eq!(specs[0].column_name, "a");
        assert_eq!(specs[1].column_name, "b");
        assert_eq!(specs[1].target_type, "JSON");
    }

    // UINT64 overflow cannot be undone after load: no cast_sql, and the note must
    // point at the source-side override instead.
    #[test]
    fn cast_sql_is_none_when_post_load_recovery_is_impossible() {
        let u = bq(&RivetType::UInt64);
        assert!(
            u.cast_sql.is_none(),
            "overflowed UINT64 has no lossless post-load recovery"
        );
        let note = u.note.unwrap().to_lowercase();
        assert!(
            note.contains("override"),
            "UINT64 note must point to the source-side override, got: {note}"
        );
    }

    // Losslessly recoverable divergences carry the expected recovery functions.
    #[test]
    fn cast_sql_present_only_when_lossless_post_load() {
        assert!(
            bq(&RivetType::Json)
                .cast_sql
                .unwrap()
                .contains("PARSE_JSON")
        );
        assert!(bq(&RivetType::Uuid).cast_sql.unwrap().contains("TO_HEX"));
        let naive = RivetType::Timestamp {
            unit: super::super::TimeUnit::Microsecond,
            timezone: None,
        };
        assert!(bq(&naive).cast_sql.unwrap().contains("DATETIME"));
    }

    // Every divergent BigQuery mapping offers either a cast or a recovery note
    // ("after load" or "override").
    #[test]
    fn every_divergence_offers_a_recovery_path() {
        let naive = RivetType::Timestamp {
            unit: super::super::TimeUnit::Microsecond,
            timezone: None,
        };
        let cases = [
            RivetType::Json,
            RivetType::Uuid,
            RivetType::UInt64,
            naive,
            RivetType::List {
                inner: Box::new(RivetType::String),
            },
        ];
        for rt in cases {
            let s = bq(&rt);
            assert_ne!(s.autoload_type, s.target_type, "case must diverge: {rt:?}");
            let has_cast = s.cast_sql.is_some();
            let note = s.note.as_deref().unwrap_or("").to_lowercase();
            let describes_recovery = note.contains("after load") || note.contains("override");
            assert!(
                has_cast || describes_recovery,
                "divergent {rt:?} must offer a recovery (cast_sql or a recovery note)"
            );
        }
    }

    // Edges of BIGNUMERIC: (76,38) fits; one more digit or one more scale fails;
    // 30 integer digits is past NUMERIC and lands on BIGNUMERIC.
    #[test]
    fn bq_decimal_limit_boundaries() {
        assert_eq!(
            bq(&RivetType::Decimal {
                precision: 76,
                scale: 38
            })
            .status,
            TargetStatus::Ok
        );
        assert_eq!(
            bq(&RivetType::Decimal {
                precision: 77,
                scale: 38
            })
            .status,
            TargetStatus::Fail
        );
        assert_eq!(
            bq(&RivetType::Decimal {
                precision: 76,
                scale: 39
            })
            .status,
            TargetStatus::Fail
        );
        assert_eq!(
            bq(&RivetType::Decimal {
                precision: 30,
                scale: 0
            })
            .target_type,
            "BIGNUMERIC"
        );
    }

    // DuckDB decimals beyond precision 38 must warn rather than pass silently.
    #[test]
    fn duckdb_decimal_over_38_warns_not_silently_clamps() {
        let s = duck(&RivetType::Decimal {
            precision: 40,
            scale: 2,
        });
        assert_eq!(s.status, TargetStatus::Warn);
    }

    // End-to-end BigQuery recovery SQL: each divergent column is cast and aliased,
    // passthrough columns are bare, and table names are backtick-quoted.
    #[test]
    fn bq_recovery_sql_casts_native_types() {
        use super::super::{SourceColumn, TimeUnit};
        let naive = RivetType::Timestamp {
            unit: TimeUnit::Microsecond,
            timezone: None,
        };
        let mappings = vec![
            TypeMapping::from_source(&SourceColumn::simple("id", "int8", true), RivetType::Int64),
            TypeMapping::from_source(
                &SourceColumn::simple("attrs", "jsonb", true),
                RivetType::Json,
            ),
            TypeMapping::from_source(&SourceColumn::simple("uid", "uuid", true), RivetType::Uuid),
            TypeMapping::from_source(
                &SourceColumn::simple("created_at", "timestamp", true),
                naive,
            ),
            TypeMapping::from_source(
                &SourceColumn::simple("tags", "_text", true),
                RivetType::List {
                    inner: Box::new(RivetType::String),
                },
            ),
        ];
        let specs = ExportTarget::BigQuery.resolve_table(&mappings);
        let sql = ExportTarget::BigQuery
            .recovery_sql(&specs, "payments")
            .expect("BigQuery has a recovery SQL");
        assert!(sql.contains("PARSE_JSON(SAFE_CONVERT_BYTES_TO_STRING(attrs)) AS attrs"));
        assert!(sql.contains("TO_HEX(uid) AS uid"));
        assert!(sql.contains("DATETIME(created_at) AS created_at"));
        assert!(sql.contains("ARRAY(SELECT el.item FROM UNNEST(tags) AS el) AS tags"));
        assert!(sql.contains("--parquet_enable_list_inference"));
        assert!(sql.contains("SELECT\n id"));
        assert!(sql.contains("CREATE OR REPLACE TABLE `payments`"));
        assert!(sql.contains("FROM `payments__staging`"));
    }

    // DuckDB produces no recovery SQL.
    #[test]
    fn duckdb_needs_no_recovery() {
        let mappings = vec![TypeMapping::from_source(
            &super::super::SourceColumn::simple("attrs", "json", true),
            RivetType::Json,
        )];
        let specs = ExportTarget::DuckDb.resolve_table(&mappings);
        assert!(
            ExportTarget::DuckDb.recovery_sql(&specs, "t").is_none(),
            "DuckDB autoloads every logical type natively — no recovery needed"
        );
    }

    // The SELECT body has exactly one projection per column; only divergent columns
    // get an `AS` alias.
    #[test]
    fn recovery_sql_projects_every_column_once_and_only_casts_divergent() {
        use super::super::{SourceColumn, TimeUnit};
        let naive = RivetType::Timestamp {
            unit: TimeUnit::Microsecond,
            timezone: None,
        };
        let cols: [(&str, RivetType); 6] = [
            ("id", RivetType::Int64),
            (
                "amount",
                RivetType::Decimal {
                    precision: 18,
                    scale: 2,
                },
            ),
            ("attrs", RivetType::Json),
            ("uid", RivetType::Uuid),
            ("created_at", naive),
            (
                "tags",
                RivetType::List {
                    inner: Box::new(RivetType::String),
                },
            ),
        ];
        let mappings: Vec<_> = cols
            .iter()
            .cloned()
            .map(|(n, rt)| TypeMapping::from_source(&SourceColumn::simple(n, "x", true), rt))
            .collect();
        let specs = ExportTarget::BigQuery.resolve_table(&mappings);
        let sql = ExportTarget::BigQuery.recovery_sql(&specs, "t").unwrap();

        // Slice out the projection list between "SELECT\n" and "\nFROM".
        let body = sql
            .split("SELECT\n")
            .nth(1)
            .and_then(|s| s.split("\nFROM").next())
            .expect("recovery SQL has a SELECT … FROM body");
        assert_eq!(
            body.split(",\n").count(),
            cols.len(),
            "one projection per column, got:\n{body}"
        );
        for (name, _) in &cols {
            assert!(body.contains(name), "column {name} missing:\n{body}");
        }
        assert!(body.contains(" id,") && !body.contains("AS id"));
        assert!(body.contains(" amount,") && !body.contains("AS amount"));
        assert!(body.contains("PARSE_JSON(SAFE_CONVERT_BYTES_TO_STRING(attrs)) AS attrs"));
        assert!(body.contains("TO_HEX(uid) AS uid"));
        assert!(body.contains("DATETIME(created_at) AS created_at"));
        assert!(body.contains("UNNEST(tags) AS el) AS tags"));
    }

    // Snowflake: the autoload degradations and the native-type recovery casts.
    #[test]
    fn snowflake_autoload_degradations_and_native_casts() {
        let j = sf(&RivetType::Json);
        assert_eq!(j.target_type, "VARIANT");
        assert_eq!(j.autoload_type, "TEXT");
        assert!(j.cast_sql.unwrap().starts_with("PARSE_JSON"));
        let u = sf(&RivetType::Uuid);
        assert_eq!(u.target_type, "TEXT");
        assert_eq!(u.autoload_type, "BINARY");
        assert!(u.cast_sql.unwrap().contains("HEX_ENCODE"));
        let naive = RivetType::Timestamp {
            unit: super::super::TimeUnit::Microsecond,
            timezone: None,
        };
        let t = sf(&naive);
        assert_eq!(t.target_type, "TIMESTAMP_NTZ");
        assert_eq!(t.autoload_type, "NUMBER(38,0)");
        assert!(t.cast_sql.unwrap().contains("TO_TIMESTAMP_NTZ"));
        let tm = sf(&RivetType::Time {
            unit: super::super::TimeUnit::Microsecond,
        });
        assert_eq!(tm.target_type, "TIME");
        assert!(tm.cast_sql.unwrap().contains("TIME_FROM_PARTS"));
        let d = sf(&RivetType::Decimal {
            precision: 18,
            scale: 2,
        });
        assert_eq!(d.target_type, "NUMBER(18,2)");
        assert!(d.cast_sql.is_none());
        let l = sf(&RivetType::List {
            inner: Box::new(RivetType::Int64),
        });
        assert_eq!(l.target_type, "ARRAY");
        assert_eq!(l.autoload_type, "VARIANT");
        assert!(l.cast_sql.unwrap().ends_with("::ARRAY"));
    }

    // Snowflake recovery SQL reads quoted staging columns, aliases them unquoted,
    // and includes the file-format / COPY setup steps.
    #[test]
    fn snowflake_recovery_sql_quotes_columns_and_casts() {
        use super::super::{SourceColumn, TimeUnit};
        let naive = RivetType::Timestamp {
            unit: TimeUnit::Microsecond,
            timezone: None,
        };
        let mappings = vec![
            TypeMapping::from_source(&SourceColumn::simple("id", "int8", true), RivetType::Int64),
            TypeMapping::from_source(
                &SourceColumn::simple("attrs", "jsonb", true),
                RivetType::Json,
            ),
            TypeMapping::from_source(&SourceColumn::simple("uid", "uuid", true), RivetType::Uuid),
            TypeMapping::from_source(
                &SourceColumn::simple("created_at", "timestamp", true),
                naive,
            ),
        ];
        let specs = ExportTarget::Snowflake.resolve_table(&mappings);
        let sql = ExportTarget::Snowflake.recovery_sql(&specs, "t").unwrap();
        assert!(sql.contains("\"id\" AS id"));
        assert!(sql.contains("PARSE_JSON(\"attrs\") AS attrs"));
        assert!(sql.contains("HEX_ENCODE(\"uid\")"));
        assert!(sql.contains("TO_TIMESTAMP_NTZ(\"created_at\", 6) AS created_at"));
        assert!(sql.contains("BINARY_AS_TEXT=FALSE"));
        assert!(sql.contains("MATCH_BY_COLUMN_NAME"));
        assert!(sql.contains("FROM t__staging"));
    }

    // Snowflake long name and its `sf` alias both parse.
    #[test]
    fn parse_accepts_snowflake() {
        assert_eq!(
            ExportTarget::parse("snowflake"),
            Some(ExportTarget::Snowflake)
        );
        assert_eq!(ExportTarget::parse("sf"), Some(ExportTarget::Snowflake));
    }
}