1use serde::Serialize;
8
9use crate::config::{Config, ExportConfig, FormatType, SourceType};
10use crate::error::Result;
11use crate::source;
12use crate::types::{
13 ColumnOverrides, TypeFidelity,
14 policy::{PolicyAction, PolicyViolation, TypePolicy},
15 target::{ExportTarget, TargetInput, TargetStatus},
16};
17
/// One column's entry in a type report, built by `collect_report` from a
/// single source type mapping (plus the target's resolution of that column
/// when a target was supplied).
#[derive(Serialize)]
pub struct TypeReportRow {
    /// Column name, copied from the mapping.
    pub column: String,
    /// The mapping's native source type string.
    pub source_type: String,
    /// Short label for the mapping's Rivet type (see `rivet_type_label`).
    pub rivet_type: String,
    /// `Debug` rendering of the mapping's Arrow type, or `"-"` when the
    /// mapping has none.
    pub arrow_type: String,
    /// Fidelity recorded on the mapping.
    pub fidelity: TypeFidelity,
    /// Warnings copied from the mapping; left out of serialized output when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub warnings: Vec<String>,
    /// Type the target resolved for this column; `None` when no target was given.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub target_type: Option<String>,
    /// Status the target resolved for this column; `None` when no target was given.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub target_status: Option<TargetStatus>,
    /// Optional note attached by the target's column resolution.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub target_note: Option<String>,
    /// The target's autoload type, recorded only when it differs from
    /// `target_type`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub autoload_type: Option<String>,
    /// Optional cast SQL supplied by the target's column resolution; the table
    /// printer shows it under the `recover:` label.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cast_sql: Option<String>,
}
44
/// Type report for a single export: its per-column rows, the policy
/// violations found, and target-level results.
#[derive(Serialize)]
pub struct ExportTypeReport {
    /// Name of the export this report describes.
    pub export: String,
    /// One row per mapped column.
    pub columns: Vec<TypeReportRow>,
    /// Policy violations collected for the export.
    pub violations: Vec<PolicyViolation>,
    /// `true` when at least one column resolved to `TargetStatus::Fail`.
    /// Skipped during serialization when `false` (`Not::not` yields `true`
    /// for a `false` flag, which tells serde to omit the field).
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub target_failures: bool,
    /// Recovery SQL produced by the target, if any; omitted from serialized
    /// output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub recovery_sql: Option<String>,
}
61
62impl ExportTypeReport {
63 pub fn has_fatal(&self) -> bool {
64 self.violations.iter().any(|v| v.fatal)
65 }
66
67 pub fn has_target_fail(&self) -> bool {
68 self.target_failures
69 }
70}
71
72pub fn collect_report(
74 config: &Config,
75 export: &ExportConfig,
76 column_overrides: &ColumnOverrides,
77 policy: &TypePolicy,
78 target: Option<ExportTarget>,
79 config_dir: &std::path::Path,
80 params: Option<&std::collections::HashMap<String, String>>,
81) -> Result<ExportTypeReport> {
82 let url = config.source.resolve_url()?;
83 let tls = config.source.tls.as_ref();
84 let query = export.resolve_query(config_dir, params)?;
88
89 let mut src: Box<dyn source::Source> = match config.source.source_type {
90 SourceType::Postgres => Box::new(source::postgres::PostgresSource::connect_with_tls(
91 &url, tls,
92 )?),
93 SourceType::Mysql => Box::new(source::mysql::MysqlSource::connect_with_tls(&url, tls)?),
94 SourceType::Mssql => Box::new(source::mssql::MssqlSource::connect_with_tls(&url, tls)?),
95 };
96
97 let mappings = src.type_mappings(&query, column_overrides)?;
98 let mut violations = policy.validate(&mappings);
99
100 if export.format == FormatType::Csv {
107 let fatal = policy.on_unsupported_type == PolicyAction::Fail;
108 for m in &mappings {
109 if let Some(dt) = m.arrow_type.as_ref()
110 && !crate::format::csv::csv_serializable(dt)
111 {
112 violations.push(PolicyViolation {
113 column_name: m.column_name.clone(),
114 fidelity: TypeFidelity::Unsupported,
115 message: format!(
116 "column '{}' (Arrow {dt:?}) cannot be serialized to CSV — \
117 use `format: parquet` or drop it from the query",
118 m.column_name
119 ),
120 fatal,
121 });
122 }
123 }
124 }
125
126 let mut target_failures = false;
127 let rows = mappings
128 .iter()
129 .map(|m| {
130 let (target_type, target_status, target_note, autoload_type, cast_sql) =
131 if let Some(tgt) = target {
132 let spec = tgt.resolve_column(TargetInput::from(m));
133 if spec.status == TargetStatus::Fail {
134 target_failures = true;
135 }
136 let autoload =
139 (spec.autoload_type != spec.target_type).then_some(spec.autoload_type);
140 (
141 Some(spec.target_type),
142 Some(spec.status),
143 spec.note,
144 autoload,
145 spec.cast_sql,
146 )
147 } else {
148 (None, None, None, None, None)
149 };
150 TypeReportRow {
151 column: m.column_name.clone(),
152 source_type: m.source_native_type.clone(),
153 rivet_type: rivet_type_label(&m.rivet_type),
154 arrow_type: m
155 .arrow_type
156 .as_ref()
157 .map(|t| format!("{t:?}"))
158 .unwrap_or_else(|| "-".into()),
159 fidelity: m.fidelity,
160 warnings: m.warnings.clone(),
161 target_type,
162 target_status,
163 target_note,
164 autoload_type,
165 cast_sql,
166 }
167 })
168 .collect();
169
170 let recovery_sql =
174 target.and_then(|t| t.recovery_sql(&t.resolve_table(&mappings), &export.name));
175
176 Ok(ExportTypeReport {
177 export: export.name.clone(),
178 columns: rows,
179 violations,
180 target_failures,
181 recovery_sql,
182 })
183}
184
185pub fn print_table(report: &ExportTypeReport, target: Option<ExportTarget>) {
187 let col_w = col_width(&report.columns, |r| r.column.len());
188 let src_w = col_width(&report.columns, |r| r.source_type.len()).max("Source type".len());
189 let rv_w = col_width(&report.columns, |r| r.rivet_type.len()).max("Rivet type".len());
190 let arr_w = col_width(&report.columns, |r| r.arrow_type.len()).max("Arrow type".len());
191 let fid_w = "logical_string".len();
192
193 println!();
194 if let Some(tgt) = target {
195 println!("Export: {} [target: {}]", report.export, tgt.label());
196 } else {
197 println!("Export: {}", report.export);
198 }
199
200 if target.is_some() {
201 let tgt_w = col_width(&report.columns, |r| {
202 r.target_type.as_deref().unwrap_or("-").len()
203 })
204 .max("Target type".len());
205 let sta_w = "Status".len();
206
207 println!(
208 " {:<col_w$} {:<src_w$} {:<rv_w$} {:<arr_w$} {:<fid_w$} {:<tgt_w$} {:<sta_w$}",
209 "Column",
210 "Source type",
211 "Rivet type",
212 "Arrow type",
213 "Fidelity",
214 "Target type",
215 "Status"
216 );
217 println!(
218 " {:-<col_w$} {:-<src_w$} {:-<rv_w$} {:-<arr_w$} {:-<fid_w$} {:-<tgt_w$} {:-<sta_w$}",
219 "", "", "", "", "", "", ""
220 );
221 for row in &report.columns {
222 let status_label = row.target_status.as_ref().map(|s| s.label()).unwrap_or("-");
223 let tgt_type = row.target_type.as_deref().unwrap_or("-");
224 let status_marker = match &row.target_status {
225 Some(TargetStatus::Fail) => " ✗",
226 Some(TargetStatus::Warn) => " ~",
227 _ => "",
228 };
229 println!(
230 " {:<col_w$} {:<src_w$} {:<rv_w$} {:<arr_w$} {}{:<rest$} {:<tgt_w$} {}{}",
231 row.column,
232 row.source_type,
233 row.rivet_type,
234 row.arrow_type,
235 row.fidelity.label(),
236 "",
237 tgt_type,
238 status_label,
239 status_marker,
240 rest = fid_w - row.fidelity.label().len(),
241 );
242 if let Some(autoload) = &row.autoload_type {
243 println!(" {:<col_w$} autoload: {}", "", autoload);
244 }
245 if let Some(note) = &row.target_note {
246 println!(" {:<col_w$} note: {}", "", note);
247 }
248 if let Some(cast) = &row.cast_sql {
249 println!(" {:<col_w$} recover: {}", "", cast);
250 }
251 for w in &row.warnings {
252 println!(" {:<col_w$} warning: {}", "", w);
253 }
254 }
255 } else {
256 println!(
257 " {:<col_w$} {:<src_w$} {:<rv_w$} {:<arr_w$} {:<fid_w$}",
258 "Column", "Source type", "Rivet type", "Arrow type", "Fidelity"
259 );
260 println!(
261 " {:-<col_w$} {:-<src_w$} {:-<rv_w$} {:-<arr_w$} {:-<fid_w$}",
262 "", "", "", "", ""
263 );
264 for row in &report.columns {
265 println!(
266 " {:<col_w$} {:<src_w$} {:<rv_w$} {:<arr_w$} {}{}",
267 row.column,
268 row.source_type,
269 row.rivet_type,
270 row.arrow_type,
271 row.fidelity.label(),
272 fidelity_marker(row.fidelity),
273 );
274 for w in &row.warnings {
275 println!(" {:<col_w$} warning: {}", "", w);
276 }
277 }
278 }
279
280 if !report.violations.is_empty() {
281 println!();
282 for v in &report.violations {
283 let prefix = if v.fatal { " FAIL" } else { " WARN" };
284 println!("{}: {}", prefix, v.message);
285 }
286 }
287
288 if let Some(sql) = &report.recovery_sql {
289 println!();
290 println!(
291 " {} type recovery — bare autoload degrades JSON/UUID→BYTES, naive",
292 target.map(|t| t.label()).unwrap_or("target")
293 );
294 println!(" timestamp→TIMESTAMP, array→RECORD; load with --autodetect then run:");
295 for line in sql.lines() {
296 println!(" {line}");
297 }
298 }
299}
300
301pub fn print_json(report: &ExportTypeReport) -> Result<()> {
303 let s = serde_json::to_string(report)?;
304 println!("{}", s);
305 Ok(())
306}
307
308fn col_width(rows: &[TypeReportRow], f: impl Fn(&TypeReportRow) -> usize) -> usize {
309 rows.iter().map(f).max().unwrap_or(8).max(8)
310}
311
312fn fidelity_marker(f: TypeFidelity) -> &'static str {
313 match f {
314 TypeFidelity::Lossy | TypeFidelity::Unsupported => " ✗",
315 TypeFidelity::LogicalString => " ~",
316 _ => "",
317 }
318}
319
320fn rivet_type_label(t: &crate::types::RivetType) -> String {
321 use crate::types::RivetType::*;
322 match t {
323 Bool => "bool".into(),
324 Int16 => "int2".into(),
325 Int32 => "int4".into(),
326 Int64 => "int8".into(),
327 UInt64 => "uint8".into(),
328 Float32 => "float4".into(),
329 Float64 => "float8".into(),
330 Decimal { precision, scale } => format!("decimal({precision},{scale})"),
331 Date => "date".into(),
332 Time { .. } => "time".into(),
333 Timestamp {
334 timezone: Some(_), ..
335 } => "timestamp_tz".into(),
336 Timestamp { timezone: None, .. } => "timestamp".into(),
337 String => "text".into(),
338 Text => "text".into(),
339 Binary => "binary".into(),
340 Json => "json".into(),
341 Uuid => "uuid".into(),
342 Enum => "enum".into(),
343 Interval => "interval".into(),
344 List { inner } => format!("list<{}>", rivet_type_label(inner)),
345 Unsupported { native_type, .. } => format!("unsupported({native_type})"),
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{RivetType, TypeFidelity};

    /// Builds a target-less row from its four textual cells; the remaining
    /// fields are empty / `None` and fidelity is `Exact`.
    fn plain_row(
        column: &str,
        source_type: &str,
        rivet_type: &str,
        arrow_type: &str,
    ) -> TypeReportRow {
        TypeReportRow {
            column: column.to_owned(),
            source_type: source_type.to_owned(),
            rivet_type: rivet_type.to_owned(),
            arrow_type: arrow_type.to_owned(),
            fidelity: TypeFidelity::Exact,
            warnings: Vec::new(),
            target_type: None,
            target_status: None,
            target_note: None,
            autoload_type: None,
            cast_sql: None,
        }
    }

    // ── fidelity_marker ────────────────────────────────────────────────

    #[test]
    fn fidelity_marker_lossy_is_cross() {
        let marker = fidelity_marker(TypeFidelity::Lossy);
        assert_eq!(marker, " ✗");
    }

    #[test]
    fn fidelity_marker_unsupported_is_cross() {
        let marker = fidelity_marker(TypeFidelity::Unsupported);
        assert_eq!(marker, " ✗");
    }

    #[test]
    fn fidelity_marker_logical_string_is_tilde() {
        let marker = fidelity_marker(TypeFidelity::LogicalString);
        assert_eq!(marker, " ~");
    }

    #[test]
    fn fidelity_marker_exact_is_empty() {
        let marker = fidelity_marker(TypeFidelity::Exact);
        assert_eq!(marker, "");
    }

    #[test]
    fn fidelity_marker_compatible_is_empty() {
        let marker = fidelity_marker(TypeFidelity::Compatible);
        assert_eq!(marker, "");
    }

    // ── rivet_type_label ───────────────────────────────────────────────

    #[test]
    fn label_bool() {
        let label = rivet_type_label(&RivetType::Bool);
        assert_eq!(label, "bool");
    }

    #[test]
    fn label_int64() {
        let label = rivet_type_label(&RivetType::Int64);
        assert_eq!(label, "int8");
    }

    #[test]
    fn label_float64() {
        let label = rivet_type_label(&RivetType::Float64);
        assert_eq!(label, "float8");
    }

    #[test]
    fn label_decimal_with_precision_and_scale() {
        let decimal = RivetType::Decimal {
            precision: 18,
            scale: 2,
        };
        assert_eq!(rivet_type_label(&decimal), "decimal(18,2)");
    }

    #[test]
    fn label_text() {
        let label = rivet_type_label(&RivetType::Text);
        assert_eq!(label, "text");
    }

    #[test]
    fn label_uuid() {
        let label = rivet_type_label(&RivetType::Uuid);
        assert_eq!(label, "uuid");
    }

    #[test]
    fn label_list_of_int64() {
        let list = RivetType::List {
            inner: Box::new(RivetType::Int64),
        };
        assert_eq!(rivet_type_label(&list), "list<int8>");
    }

    #[test]
    fn label_unsupported_native_type() {
        let unsupported = RivetType::Unsupported {
            native_type: "tsvector".into(),
            reason: "not supported".into(),
        };
        assert_eq!(rivet_type_label(&unsupported), "unsupported(tsvector)");
    }

    // ── col_width ──────────────────────────────────────────────────────

    #[test]
    fn col_width_empty_returns_minimum_8() {
        let no_rows: Vec<TypeReportRow> = Vec::new();
        assert_eq!(col_width(&no_rows, |_r| 0), 8);
    }

    #[test]
    fn col_width_short_values_returns_minimum_8() {
        let rows = [plain_row("a", "b", "c", "d")];
        assert_eq!(col_width(&rows, |r| r.column.len()), 8);
    }

    #[test]
    fn col_width_long_value_returns_that_length() {
        let rows = [plain_row("a_very_long_column_name", "int8", "int8", "Int64")];
        let width = col_width(&rows, |r| r.column.len());
        assert_eq!(width, "a_very_long_column_name".len());
    }
}