Skip to main content

rivet/preflight/
type_report.rs

//! `rivet check --type-report` — tabular and JSON output.
//!
//! Roadmap §9 ("Type Fidelity Report") and §16 ("BigQuery Compatibility Layer").
//! Collects a `Vec<TypeMapping>` from a live source connection, validates it
//! against the type policy, and renders it plus any `PolicyViolation`s as either
//! a fixed-width terminal table or newline-delimited JSON.
6
7use serde::Serialize;
8
9use crate::config::{Config, ExportConfig, SourceType};
10use crate::error::Result;
11use crate::source;
12use crate::types::{
13    ColumnOverrides, TypeFidelity,
14    policy::{PolicyViolation, TypePolicy},
15    target::{ExportTarget, TargetCompat, TargetStatus, check_target_compat},
16};
17
/// One row in the type report (and the JSON output — roadmap §9).
///
/// Serde serializes fields in declaration order, so the field order below is
/// also the key order of each JSON row object.
#[derive(Serialize)]
pub struct TypeReportRow {
    /// Column name as reported by the source.
    pub column: String,
    /// Native type name on the source database.
    pub source_type: String,
    /// Short label for the Rivet type (e.g. `int8`, `decimal(18,2)`, `list<int8>`).
    pub rivet_type: String,
    /// `Debug` rendering of the Arrow data type, or `-` when there is none.
    pub arrow_type: String,
    /// Fidelity classification of the mapping (see `TypeFidelity`).
    pub fidelity: TypeFidelity,
    /// Mapping warnings; omitted from the JSON when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub warnings: Vec<String>,
    /// Present when `--target` is set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub target_type: Option<String>,
    /// Result of the target-compatibility check; present when `--target` is set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub target_status: Option<TargetStatus>,
    /// Optional explanation from the target-compatibility check.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub target_note: Option<String>,
}
36
/// One export's type-report data.
///
/// Serialized as one JSON object per export; field order is the key order.
#[derive(Serialize)]
pub struct ExportTypeReport {
    /// Export name from the configuration.
    pub export: String,
    /// One row per type mapping returned by the source.
    pub columns: Vec<TypeReportRow>,
    /// Type-policy violations; always serialized, even when empty.
    pub violations: Vec<PolicyViolation>,
    /// True when any column failed target-compatibility.
    /// Omitted from the JSON when `false`.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub target_failures: bool,
}
47
48impl ExportTypeReport {
49    pub fn has_fatal(&self) -> bool {
50        self.violations.iter().any(|v| v.fatal)
51    }
52
53    pub fn has_target_fail(&self) -> bool {
54        self.target_failures
55    }
56}
57
58/// Collect type mappings for one export from a live connection.
59pub fn collect_report(
60    config: &Config,
61    export: &ExportConfig,
62    column_overrides: &ColumnOverrides,
63    policy: &TypePolicy,
64    target: Option<ExportTarget>,
65) -> Result<ExportTypeReport> {
66    let url = config.source.resolve_url()?;
67    let tls = config.source.tls.as_ref();
68    let query = export
69        .query
70        .as_deref()
71        .or(export.query_file.as_deref())
72        .unwrap_or("");
73
74    let mut src: Box<dyn source::Source> = match config.source.source_type {
75        SourceType::Postgres => Box::new(source::postgres::PostgresSource::connect_with_tls(
76            &url, tls,
77        )?),
78        SourceType::Mysql => Box::new(source::mysql::MysqlSource::connect_with_tls(&url, tls)?),
79    };
80
81    let mappings = src.type_mappings(query, column_overrides)?;
82    let violations = policy.validate(&mappings);
83
84    let mut target_failures = false;
85    let rows = mappings
86        .iter()
87        .map(|m| {
88            let (target_type, target_status, target_note) = if let Some(tgt) = target {
89                let compat: TargetCompat = check_target_compat(m.arrow_type.as_ref(), tgt);
90                if compat.status == TargetStatus::Fail {
91                    target_failures = true;
92                }
93                (Some(compat.target_type), Some(compat.status), compat.note)
94            } else {
95                (None, None, None)
96            };
97            TypeReportRow {
98                column: m.column_name.clone(),
99                source_type: m.source_native_type.clone(),
100                rivet_type: rivet_type_label(&m.rivet_type),
101                arrow_type: m
102                    .arrow_type
103                    .as_ref()
104                    .map(|t| format!("{t:?}"))
105                    .unwrap_or_else(|| "-".into()),
106                fidelity: m.fidelity,
107                warnings: m.warnings.clone(),
108                target_type,
109                target_status,
110                target_note,
111            }
112        })
113        .collect();
114
115    Ok(ExportTypeReport {
116        export: export.name.clone(),
117        columns: rows,
118        violations,
119        target_failures,
120    })
121}
122
123/// Print the report as a human-readable table to stdout.
124pub fn print_table(report: &ExportTypeReport, target: Option<ExportTarget>) {
125    let col_w = col_width(&report.columns, |r| r.column.len());
126    let src_w = col_width(&report.columns, |r| r.source_type.len()).max("Source type".len());
127    let rv_w = col_width(&report.columns, |r| r.rivet_type.len()).max("Rivet type".len());
128    let arr_w = col_width(&report.columns, |r| r.arrow_type.len()).max("Arrow type".len());
129    let fid_w = "logical_string".len();
130
131    println!();
132    if let Some(tgt) = target {
133        println!("Export: {}  [target: {}]", report.export, tgt.label());
134    } else {
135        println!("Export: {}", report.export);
136    }
137
138    if target.is_some() {
139        let tgt_w = col_width(&report.columns, |r| {
140            r.target_type.as_deref().unwrap_or("-").len()
141        })
142        .max("Target type".len());
143        let sta_w = "Status".len();
144
145        println!(
146            "  {:<col_w$}  {:<src_w$}  {:<rv_w$}  {:<arr_w$}  {:<fid_w$}  {:<tgt_w$}  {:<sta_w$}",
147            "Column",
148            "Source type",
149            "Rivet type",
150            "Arrow type",
151            "Fidelity",
152            "Target type",
153            "Status"
154        );
155        println!(
156            "  {:-<col_w$}  {:-<src_w$}  {:-<rv_w$}  {:-<arr_w$}  {:-<fid_w$}  {:-<tgt_w$}  {:-<sta_w$}",
157            "", "", "", "", "", "", ""
158        );
159        for row in &report.columns {
160            let status_label = row.target_status.as_ref().map(|s| s.label()).unwrap_or("-");
161            let tgt_type = row.target_type.as_deref().unwrap_or("-");
162            let status_marker = match &row.target_status {
163                Some(TargetStatus::Fail) => " ✗",
164                Some(TargetStatus::Warn) => " ~",
165                _ => "",
166            };
167            println!(
168                "  {:<col_w$}  {:<src_w$}  {:<rv_w$}  {:<arr_w$}  {}{:<rest$}  {:<tgt_w$}  {}{}",
169                row.column,
170                row.source_type,
171                row.rivet_type,
172                row.arrow_type,
173                row.fidelity.label(),
174                "",
175                tgt_type,
176                status_label,
177                status_marker,
178                rest = fid_w - row.fidelity.label().len(),
179            );
180            if let Some(note) = &row.target_note {
181                println!("  {:<col_w$}    note: {}", "", note);
182            }
183            for w in &row.warnings {
184                println!("  {:<col_w$}    warning: {}", "", w);
185            }
186        }
187    } else {
188        println!(
189            "  {:<col_w$}  {:<src_w$}  {:<rv_w$}  {:<arr_w$}  {:<fid_w$}",
190            "Column", "Source type", "Rivet type", "Arrow type", "Fidelity"
191        );
192        println!(
193            "  {:-<col_w$}  {:-<src_w$}  {:-<rv_w$}  {:-<arr_w$}  {:-<fid_w$}",
194            "", "", "", "", ""
195        );
196        for row in &report.columns {
197            println!(
198                "  {:<col_w$}  {:<src_w$}  {:<rv_w$}  {:<arr_w$}  {}{}",
199                row.column,
200                row.source_type,
201                row.rivet_type,
202                row.arrow_type,
203                row.fidelity.label(),
204                fidelity_marker(row.fidelity),
205            );
206            for w in &row.warnings {
207                println!("  {:<col_w$}    warning: {}", "", w);
208            }
209        }
210    }
211
212    if !report.violations.is_empty() {
213        println!();
214        for v in &report.violations {
215            let prefix = if v.fatal { "  FAIL" } else { "  WARN" };
216            println!("{}: {}", prefix, v.message);
217        }
218    }
219}
220
221/// Emit newline-delimited JSON (one object per export).
222pub fn print_json(report: &ExportTypeReport) -> Result<()> {
223    let s = serde_json::to_string(report)?;
224    println!("{}", s);
225    Ok(())
226}
227
228fn col_width(rows: &[TypeReportRow], f: impl Fn(&TypeReportRow) -> usize) -> usize {
229    rows.iter().map(f).max().unwrap_or(8).max(8)
230}
231
232fn fidelity_marker(f: TypeFidelity) -> &'static str {
233    match f {
234        TypeFidelity::Lossy | TypeFidelity::Unsupported => " ✗",
235        TypeFidelity::LogicalString => " ~",
236        _ => "",
237    }
238}
239
240fn rivet_type_label(t: &crate::types::RivetType) -> String {
241    use crate::types::RivetType::*;
242    match t {
243        Bool => "bool".into(),
244        Int16 => "int2".into(),
245        Int32 => "int4".into(),
246        Int64 => "int8".into(),
247        UInt64 => "uint8".into(),
248        Float32 => "float4".into(),
249        Float64 => "float8".into(),
250        Decimal { precision, scale } => format!("decimal({precision},{scale})"),
251        Date => "date".into(),
252        Time { .. } => "time".into(),
253        Timestamp {
254            timezone: Some(_), ..
255        } => "timestamp_tz".into(),
256        Timestamp { timezone: None, .. } => "timestamp".into(),
257        String => "text".into(),
258        Text => "text".into(),
259        Binary => "binary".into(),
260        Json => "json".into(),
261        Uuid => "uuid".into(),
262        Enum => "enum".into(),
263        Interval => "interval".into(),
264        List { inner } => format!("list<{}>", rivet_type_label(inner)),
265        Unsupported { native_type, .. } => format!("unsupported({native_type})"),
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{RivetType, TypeFidelity};

    /// Builds an exact-fidelity report row with the given column name and no
    /// target data; only the column name matters to the width tests.
    fn row_named(column: &str) -> TypeReportRow {
        TypeReportRow {
            column: column.to_owned(),
            source_type: "int8".to_owned(),
            rivet_type: "int8".to_owned(),
            arrow_type: "Int64".to_owned(),
            fidelity: TypeFidelity::Exact,
            warnings: Vec::new(),
            target_type: None,
            target_status: None,
            target_note: None,
        }
    }

    // ── fidelity_marker ──────────────────────────────────────────────────────

    #[test]
    fn fidelity_marker_lossy_is_cross() {
        let marker = fidelity_marker(TypeFidelity::Lossy);
        assert_eq!(marker, " ✗");
    }

    #[test]
    fn fidelity_marker_unsupported_is_cross() {
        let marker = fidelity_marker(TypeFidelity::Unsupported);
        assert_eq!(marker, " ✗");
    }

    #[test]
    fn fidelity_marker_logical_string_is_tilde() {
        let marker = fidelity_marker(TypeFidelity::LogicalString);
        assert_eq!(marker, " ~");
    }

    #[test]
    fn fidelity_marker_exact_is_empty() {
        assert!(fidelity_marker(TypeFidelity::Exact).is_empty());
    }

    #[test]
    fn fidelity_marker_compatible_is_empty() {
        assert!(fidelity_marker(TypeFidelity::Compatible).is_empty());
    }

    // ── rivet_type_label ─────────────────────────────────────────────────────

    #[test]
    fn label_bool() {
        let label = rivet_type_label(&RivetType::Bool);
        assert_eq!(label, "bool");
    }

    #[test]
    fn label_int64() {
        let label = rivet_type_label(&RivetType::Int64);
        assert_eq!(label, "int8");
    }

    #[test]
    fn label_float64() {
        let label = rivet_type_label(&RivetType::Float64);
        assert_eq!(label, "float8");
    }

    #[test]
    fn label_decimal_with_precision_and_scale() {
        let decimal = RivetType::Decimal {
            precision: 18,
            scale: 2,
        };
        assert_eq!(rivet_type_label(&decimal), "decimal(18,2)");
    }

    #[test]
    fn label_text() {
        let label = rivet_type_label(&RivetType::Text);
        assert_eq!(label, "text");
    }

    #[test]
    fn label_uuid() {
        let label = rivet_type_label(&RivetType::Uuid);
        assert_eq!(label, "uuid");
    }

    #[test]
    fn label_list_of_int64() {
        let list = RivetType::List {
            inner: Box::new(RivetType::Int64),
        };
        assert_eq!(rivet_type_label(&list), "list<int8>");
    }

    #[test]
    fn label_unsupported_native_type() {
        let unsupported = RivetType::Unsupported {
            native_type: "tsvector".into(),
            reason: "not supported".into(),
        };
        assert_eq!(rivet_type_label(&unsupported), "unsupported(tsvector)");
    }

    // ── col_width ────────────────────────────────────────────────────────────

    #[test]
    fn col_width_empty_returns_minimum_8() {
        assert_eq!(col_width(&[], |_| 0), 8);
    }

    #[test]
    fn col_width_short_values_returns_minimum_8() {
        let rows = [row_named("a")];
        assert_eq!(col_width(&rows, |r| r.column.len()), 8);
    }

    #[test]
    fn col_width_long_value_returns_that_length() {
        let name = "a_very_long_column_name";
        let rows = [row_named(name)];
        assert_eq!(col_width(&rows, |r| r.column.len()), name.len());
    }
}