1use serde::Serialize;
8
9use crate::config::{Config, ExportConfig, SourceType};
10use crate::error::Result;
11use crate::source;
12use crate::types::{
13 ColumnOverrides, TypeFidelity,
14 policy::{PolicyViolation, TypePolicy},
15 target::{ExportTarget, TargetInput, TargetStatus},
16};
17
/// One column of a type report: how a source column maps through the type
/// pipeline and, when an export target was supplied, how that target treats it.
///
/// Rows are built by `collect_report`, one per type mapping returned by the source.
#[derive(Serialize)]
pub struct TypeReportRow {
    /// Column name, copied from the mapping's `column_name`.
    pub column: String,
    /// Native type name reported by the source (the mapping's `source_native_type`).
    pub source_type: String,
    /// Short label for the intermediate type, as produced by `rivet_type_label`.
    pub rivet_type: String,
    /// `Debug` rendering of the mapped Arrow type, or `"-"` when the mapping has none.
    pub arrow_type: String,
    /// Fidelity classification copied from the mapping.
    pub fidelity: TypeFidelity,
    /// Warnings copied from the mapping; left out of serialized output when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub warnings: Vec<String>,
    /// Type resolved for the export target; `None` (and not serialized) when
    /// no target was supplied.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub target_type: Option<String>,
    /// Status the target resolution assigned to this column; `None` (and not
    /// serialized) when no target was supplied.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub target_status: Option<TargetStatus>,
    /// Optional note attached by the target resolution; not serialized when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub target_note: Option<String>,
    /// The target's autoload type, recorded only when it differs from
    /// `target_type`; not serialized when `None`.
    // NOTE(review): presumably the type the target infers without an explicit
    // schema -- confirm against the target implementation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub autoload_type: Option<String>,
    /// Optional SQL cast supplied by the target resolution (shown under
    /// "recover:" by `print_table`); not serialized when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cast_sql: Option<String>,
}
44
/// Type report for a single export: per-column rows plus policy violations
/// and target-level results.
#[derive(Serialize)]
pub struct ExportTypeReport {
    /// Name of the export this report describes.
    pub export: String,
    /// One row per mapped column, in the order the source returned the mappings.
    pub columns: Vec<TypeReportRow>,
    /// Violations returned by the type policy for these mappings.
    pub violations: Vec<PolicyViolation>,
    /// `true` when at least one column resolved to `TargetStatus::Fail`.
    /// Left out of serialized output when `false` (`Not::not` on `&bool`
    /// yields `true` for `false`, which triggers the skip).
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub target_failures: bool,
    /// Recovery SQL produced by the target for the whole table, when the
    /// target provides any; not serialized when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub recovery_sql: Option<String>,
}
61
62impl ExportTypeReport {
63 pub fn has_fatal(&self) -> bool {
64 self.violations.iter().any(|v| v.fatal)
65 }
66
67 pub fn has_target_fail(&self) -> bool {
68 self.target_failures
69 }
70}
71
72pub fn collect_report(
74 config: &Config,
75 export: &ExportConfig,
76 column_overrides: &ColumnOverrides,
77 policy: &TypePolicy,
78 target: Option<ExportTarget>,
79 config_dir: &std::path::Path,
80 params: Option<&std::collections::HashMap<String, String>>,
81) -> Result<ExportTypeReport> {
82 let url = config.source.resolve_url()?;
83 let tls = config.source.tls.as_ref();
84 let query = export.resolve_query(config_dir, params)?;
88
89 let mut src: Box<dyn source::Source> = match config.source.source_type {
90 SourceType::Postgres => Box::new(source::postgres::PostgresSource::connect_with_tls(
91 &url, tls,
92 )?),
93 SourceType::Mysql => Box::new(source::mysql::MysqlSource::connect_with_tls(&url, tls)?),
94 };
95
96 let mappings = src.type_mappings(&query, column_overrides)?;
97 let violations = policy.validate(&mappings);
98
99 let mut target_failures = false;
100 let rows = mappings
101 .iter()
102 .map(|m| {
103 let (target_type, target_status, target_note, autoload_type, cast_sql) =
104 if let Some(tgt) = target {
105 let spec = tgt.resolve_column(TargetInput::from(m));
106 if spec.status == TargetStatus::Fail {
107 target_failures = true;
108 }
109 let autoload =
112 (spec.autoload_type != spec.target_type).then_some(spec.autoload_type);
113 (
114 Some(spec.target_type),
115 Some(spec.status),
116 spec.note,
117 autoload,
118 spec.cast_sql,
119 )
120 } else {
121 (None, None, None, None, None)
122 };
123 TypeReportRow {
124 column: m.column_name.clone(),
125 source_type: m.source_native_type.clone(),
126 rivet_type: rivet_type_label(&m.rivet_type),
127 arrow_type: m
128 .arrow_type
129 .as_ref()
130 .map(|t| format!("{t:?}"))
131 .unwrap_or_else(|| "-".into()),
132 fidelity: m.fidelity,
133 warnings: m.warnings.clone(),
134 target_type,
135 target_status,
136 target_note,
137 autoload_type,
138 cast_sql,
139 }
140 })
141 .collect();
142
143 let recovery_sql =
147 target.and_then(|t| t.recovery_sql(&t.resolve_table(&mappings), &export.name));
148
149 Ok(ExportTypeReport {
150 export: export.name.clone(),
151 columns: rows,
152 violations,
153 target_failures,
154 recovery_sql,
155 })
156}
157
158pub fn print_table(report: &ExportTypeReport, target: Option<ExportTarget>) {
160 let col_w = col_width(&report.columns, |r| r.column.len());
161 let src_w = col_width(&report.columns, |r| r.source_type.len()).max("Source type".len());
162 let rv_w = col_width(&report.columns, |r| r.rivet_type.len()).max("Rivet type".len());
163 let arr_w = col_width(&report.columns, |r| r.arrow_type.len()).max("Arrow type".len());
164 let fid_w = "logical_string".len();
165
166 println!();
167 if let Some(tgt) = target {
168 println!("Export: {} [target: {}]", report.export, tgt.label());
169 } else {
170 println!("Export: {}", report.export);
171 }
172
173 if target.is_some() {
174 let tgt_w = col_width(&report.columns, |r| {
175 r.target_type.as_deref().unwrap_or("-").len()
176 })
177 .max("Target type".len());
178 let sta_w = "Status".len();
179
180 println!(
181 " {:<col_w$} {:<src_w$} {:<rv_w$} {:<arr_w$} {:<fid_w$} {:<tgt_w$} {:<sta_w$}",
182 "Column",
183 "Source type",
184 "Rivet type",
185 "Arrow type",
186 "Fidelity",
187 "Target type",
188 "Status"
189 );
190 println!(
191 " {:-<col_w$} {:-<src_w$} {:-<rv_w$} {:-<arr_w$} {:-<fid_w$} {:-<tgt_w$} {:-<sta_w$}",
192 "", "", "", "", "", "", ""
193 );
194 for row in &report.columns {
195 let status_label = row.target_status.as_ref().map(|s| s.label()).unwrap_or("-");
196 let tgt_type = row.target_type.as_deref().unwrap_or("-");
197 let status_marker = match &row.target_status {
198 Some(TargetStatus::Fail) => " ✗",
199 Some(TargetStatus::Warn) => " ~",
200 _ => "",
201 };
202 println!(
203 " {:<col_w$} {:<src_w$} {:<rv_w$} {:<arr_w$} {}{:<rest$} {:<tgt_w$} {}{}",
204 row.column,
205 row.source_type,
206 row.rivet_type,
207 row.arrow_type,
208 row.fidelity.label(),
209 "",
210 tgt_type,
211 status_label,
212 status_marker,
213 rest = fid_w - row.fidelity.label().len(),
214 );
215 if let Some(autoload) = &row.autoload_type {
216 println!(" {:<col_w$} autoload: {}", "", autoload);
217 }
218 if let Some(note) = &row.target_note {
219 println!(" {:<col_w$} note: {}", "", note);
220 }
221 if let Some(cast) = &row.cast_sql {
222 println!(" {:<col_w$} recover: {}", "", cast);
223 }
224 for w in &row.warnings {
225 println!(" {:<col_w$} warning: {}", "", w);
226 }
227 }
228 } else {
229 println!(
230 " {:<col_w$} {:<src_w$} {:<rv_w$} {:<arr_w$} {:<fid_w$}",
231 "Column", "Source type", "Rivet type", "Arrow type", "Fidelity"
232 );
233 println!(
234 " {:-<col_w$} {:-<src_w$} {:-<rv_w$} {:-<arr_w$} {:-<fid_w$}",
235 "", "", "", "", ""
236 );
237 for row in &report.columns {
238 println!(
239 " {:<col_w$} {:<src_w$} {:<rv_w$} {:<arr_w$} {}{}",
240 row.column,
241 row.source_type,
242 row.rivet_type,
243 row.arrow_type,
244 row.fidelity.label(),
245 fidelity_marker(row.fidelity),
246 );
247 for w in &row.warnings {
248 println!(" {:<col_w$} warning: {}", "", w);
249 }
250 }
251 }
252
253 if !report.violations.is_empty() {
254 println!();
255 for v in &report.violations {
256 let prefix = if v.fatal { " FAIL" } else { " WARN" };
257 println!("{}: {}", prefix, v.message);
258 }
259 }
260
261 if let Some(sql) = &report.recovery_sql {
262 println!();
263 println!(
264 " {} type recovery — bare autoload degrades JSON/UUID→BYTES, naive",
265 target.map(|t| t.label()).unwrap_or("target")
266 );
267 println!(" timestamp→TIMESTAMP, array→RECORD; load with --autodetect then run:");
268 for line in sql.lines() {
269 println!(" {line}");
270 }
271 }
272}
273
274pub fn print_json(report: &ExportTypeReport) -> Result<()> {
276 let s = serde_json::to_string(report)?;
277 println!("{}", s);
278 Ok(())
279}
280
281fn col_width(rows: &[TypeReportRow], f: impl Fn(&TypeReportRow) -> usize) -> usize {
282 rows.iter().map(f).max().unwrap_or(8).max(8)
283}
284
285fn fidelity_marker(f: TypeFidelity) -> &'static str {
286 match f {
287 TypeFidelity::Lossy | TypeFidelity::Unsupported => " ✗",
288 TypeFidelity::LogicalString => " ~",
289 _ => "",
290 }
291}
292
293fn rivet_type_label(t: &crate::types::RivetType) -> String {
294 use crate::types::RivetType::*;
295 match t {
296 Bool => "bool".into(),
297 Int16 => "int2".into(),
298 Int32 => "int4".into(),
299 Int64 => "int8".into(),
300 UInt64 => "uint8".into(),
301 Float32 => "float4".into(),
302 Float64 => "float8".into(),
303 Decimal { precision, scale } => format!("decimal({precision},{scale})"),
304 Date => "date".into(),
305 Time { .. } => "time".into(),
306 Timestamp {
307 timezone: Some(_), ..
308 } => "timestamp_tz".into(),
309 Timestamp { timezone: None, .. } => "timestamp".into(),
310 String => "text".into(),
311 Text => "text".into(),
312 Binary => "binary".into(),
313 Json => "json".into(),
314 Uuid => "uuid".into(),
315 Enum => "enum".into(),
316 Interval => "interval".into(),
317 List { inner } => format!("list<{}>", rivet_type_label(inner)),
318 Unsupported { native_type, .. } => format!("unsupported({native_type})"),
319 }
320}
321
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{RivetType, TypeFidelity};

    /// Builds a row with exact fidelity, no warnings and no target data;
    /// only the four text columns vary between tests.
    fn bare_row(
        column: &str,
        source_type: &str,
        rivet_type: &str,
        arrow_type: &str,
    ) -> TypeReportRow {
        TypeReportRow {
            column: column.into(),
            source_type: source_type.into(),
            rivet_type: rivet_type.into(),
            arrow_type: arrow_type.into(),
            fidelity: TypeFidelity::Exact,
            warnings: Vec::new(),
            target_type: None,
            target_status: None,
            target_note: None,
            autoload_type: None,
            cast_sql: None,
        }
    }

    // --- fidelity_marker ---

    #[test]
    fn fidelity_marker_lossy_is_cross() {
        let marker = fidelity_marker(TypeFidelity::Lossy);
        assert_eq!(marker, " ✗");
    }

    #[test]
    fn fidelity_marker_unsupported_is_cross() {
        let marker = fidelity_marker(TypeFidelity::Unsupported);
        assert_eq!(marker, " ✗");
    }

    #[test]
    fn fidelity_marker_logical_string_is_tilde() {
        let marker = fidelity_marker(TypeFidelity::LogicalString);
        assert_eq!(marker, " ~");
    }

    #[test]
    fn fidelity_marker_exact_is_empty() {
        let marker = fidelity_marker(TypeFidelity::Exact);
        assert_eq!(marker, "");
    }

    #[test]
    fn fidelity_marker_compatible_is_empty() {
        let marker = fidelity_marker(TypeFidelity::Compatible);
        assert_eq!(marker, "");
    }

    // --- rivet_type_label ---

    #[test]
    fn label_bool() {
        let label = rivet_type_label(&RivetType::Bool);
        assert_eq!(label, "bool");
    }

    #[test]
    fn label_int64() {
        let label = rivet_type_label(&RivetType::Int64);
        assert_eq!(label, "int8");
    }

    #[test]
    fn label_float64() {
        let label = rivet_type_label(&RivetType::Float64);
        assert_eq!(label, "float8");
    }

    #[test]
    fn label_decimal_with_precision_and_scale() {
        let decimal = RivetType::Decimal {
            precision: 18,
            scale: 2,
        };
        assert_eq!(rivet_type_label(&decimal), "decimal(18,2)");
    }

    #[test]
    fn label_text() {
        let label = rivet_type_label(&RivetType::Text);
        assert_eq!(label, "text");
    }

    #[test]
    fn label_uuid() {
        let label = rivet_type_label(&RivetType::Uuid);
        assert_eq!(label, "uuid");
    }

    #[test]
    fn label_list_of_int64() {
        let list = RivetType::List {
            inner: Box::new(RivetType::Int64),
        };
        assert_eq!(rivet_type_label(&list), "list<int8>");
    }

    #[test]
    fn label_unsupported_native_type() {
        let unsupported = RivetType::Unsupported {
            native_type: "tsvector".into(),
            reason: "not supported".into(),
        };
        assert_eq!(rivet_type_label(&unsupported), "unsupported(tsvector)");
    }

    // --- col_width ---

    #[test]
    fn col_width_empty_returns_minimum_8() {
        let no_rows: Vec<TypeReportRow> = Vec::new();
        assert_eq!(col_width(&no_rows, |_r| 0), 8);
    }

    #[test]
    fn col_width_short_values_returns_minimum_8() {
        let rows = [bare_row("a", "b", "c", "d")];
        assert_eq!(col_width(&rows, |r| r.column.len()), 8);
    }

    #[test]
    fn col_width_long_value_returns_that_length() {
        let rows = [bare_row("a_very_long_column_name", "int8", "int8", "Int64")];
        let width = col_width(&rows, |r| r.column.len());
        assert_eq!(width, "a_very_long_column_name".len());
    }
}