Skip to main content

rivet_cli/
quality.rs

1// Functions in this module are called from pipeline::sink, pipeline::mod, and integration tests
2// via the library crate. The binary re-declares this module but does not call all functions
3// directly, producing dead_code warnings only for the bin target.
4#![allow(dead_code)]
5
6use std::collections::{HashMap, HashSet};
7
8use arrow::array::Array;
9use arrow::record_batch::RecordBatch;
10
11use crate::config::QualityConfig;
12
13#[derive(Debug, Clone)]
14pub struct QualityIssue {
15    pub severity: Severity,
16    pub message: String,
17}
18
/// How serious a [`QualityIssue`] is.
///
/// This is a fieldless enum, so it is `Copy` (cheap to pass by value) and `Hash`
/// (usable as a map/set key, e.g. to tally issues per severity). How each level is
/// acted on is decided by the caller. NOTE(review): `Warn` presumably marks
/// non-blocking findings and `Fail` blocking ones; confirm against the callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Severity {
    /// Advisory finding.
    Warn,
    /// Check failure.
    Fail,
}
24
25pub fn check_row_count(actual: usize, config: &QualityConfig) -> Vec<QualityIssue> {
26    let mut issues = Vec::new();
27    if let Some(min) = config.row_count_min
28        && actual < min
29    {
30        issues.push(QualityIssue {
31            severity: Severity::Fail,
32            message: format!("row_count {} below minimum {}", actual, min),
33        });
34    }
35    if let Some(max) = config.row_count_max
36        && actual > max
37    {
38        issues.push(QualityIssue {
39            severity: Severity::Fail,
40            message: format!("row_count {} exceeds maximum {}", actual, max),
41        });
42    }
43    issues
44}
45
46pub fn check_null_ratios(
47    batches: &[RecordBatch],
48    thresholds: &HashMap<String, f64>,
49) -> Vec<QualityIssue> {
50    if thresholds.is_empty() {
51        return Vec::new();
52    }
53
54    let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
55    if total_rows == 0 {
56        return Vec::new();
57    }
58
59    let mut null_counts: HashMap<String, usize> = HashMap::new();
60    for batch in batches {
61        let schema = batch.schema();
62        for (i, field) in schema.fields().iter().enumerate() {
63            if thresholds.contains_key(field.name().as_str()) {
64                let col = batch.column(i);
65                *null_counts.entry(field.name().clone()).or_default() += col.null_count();
66            }
67        }
68    }
69
70    let mut issues = Vec::new();
71    for (col_name, max_ratio) in thresholds {
72        let nulls = null_counts.get(col_name.as_str()).copied().unwrap_or(0);
73        let ratio = nulls as f64 / total_rows as f64;
74        if ratio > *max_ratio {
75            issues.push(QualityIssue {
76                severity: Severity::Fail,
77                message: format!(
78                    "column '{}': null ratio {:.4} exceeds threshold {:.4}",
79                    col_name, ratio, max_ratio
80                ),
81            });
82        }
83    }
84    issues
85}
86
87pub fn check_uniqueness(batches: &[RecordBatch], columns: &[String]) -> Vec<QualityIssue> {
88    if columns.is_empty() {
89        return Vec::new();
90    }
91
92    let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
93    if total_rows == 0 {
94        return Vec::new();
95    }
96
97    let mut issues = Vec::new();
98
99    for col_name in columns {
100        let mut seen = HashSet::new();
101        let mut duplicates = 0usize;
102
103        for batch in batches {
104            if let Ok(idx) = batch.schema().index_of(col_name) {
105                let col = batch.column(idx);
106                let string_col = arrow::util::display::ArrayFormatter::try_new(
107                    col.as_ref(),
108                    &arrow::util::display::FormatOptions::default(),
109                );
110                if let Ok(formatter) = string_col {
111                    for row in 0..col.len() {
112                        let val = formatter.value(row).to_string();
113                        if !seen.insert(val) {
114                            duplicates += 1;
115                        }
116                    }
117                }
118            }
119        }
120
121        if duplicates > 0 {
122            issues.push(QualityIssue {
123                severity: Severity::Fail,
124                message: format!(
125                    "column '{}': {} duplicate values out of {} rows",
126                    col_name, duplicates, total_rows
127                ),
128            });
129        }
130    }
131
132    issues
133}
134
135/// Run all configured quality checks. Returns issues found.
136pub fn run_checks(
137    config: &QualityConfig,
138    batches: &[RecordBatch],
139    total_rows: usize,
140) -> Vec<QualityIssue> {
141    let mut all = Vec::new();
142    all.extend(check_row_count(total_rows, config));
143    all.extend(check_null_ratios(batches, &config.null_ratio_max));
144    all.extend(check_uniqueness(batches, &config.unique_columns));
145    all
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use std::sync::Arc;

    /// Fixture schema: nullable `id: Int64` followed by nullable `name: Utf8`.
    fn fixture_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, true),
            Field::new("name", DataType::Utf8, true),
        ]))
    }

    /// Builds an `(id, name)` batch; panics if `RecordBatch::try_new` rejects the
    /// columns (e.g. the slices differ in length).
    fn make_batch(ids: &[Option<i64>], names: &[Option<&str>]) -> RecordBatch {
        let columns: Vec<ArrayRef> = vec![
            Arc::new(Int64Array::from(ids.to_vec())),
            Arc::new(StringArray::from(names.to_vec())),
        ];
        RecordBatch::try_new(fixture_schema(), columns).unwrap()
    }

    /// Zero-row batch with the fixture schema.
    fn empty_batch() -> RecordBatch {
        make_batch(&[], &[])
    }

    /// Config that only sets row-count bounds; null-ratio and uniqueness checks are off.
    fn bounds(min: Option<usize>, max: Option<usize>) -> QualityConfig {
        QualityConfig {
            row_count_min: min,
            row_count_max: max,
            null_ratio_max: HashMap::new(),
            unique_columns: vec![],
        }
    }

    /// One-entry null-ratio threshold map.
    fn null_limit(column: &str, max_ratio: f64) -> HashMap<String, f64> {
        HashMap::from([(column.to_string(), max_ratio)])
    }

    #[test]
    fn row_count_within_bounds() {
        let cfg = bounds(Some(5), Some(100));
        assert!(check_row_count(50, &cfg).is_empty());
    }

    #[test]
    fn row_count_below_min() {
        let issues = check_row_count(50, &bounds(Some(100), None));
        assert_eq!(issues.len(), 1);
        let issue = &issues[0];
        assert_eq!(issue.severity, Severity::Fail);
        assert!(issue.message.contains("below minimum"));
    }

    #[test]
    fn row_count_above_max() {
        let issues = check_row_count(50, &bounds(None, Some(10)));
        assert_eq!(issues.len(), 1);
        assert!(issues[0].message.contains("exceeds maximum"));
    }

    #[test]
    fn null_ratio_passes() {
        let batch = make_batch(
            &[Some(1), Some(2), Some(3)],
            &[Some("a"), Some("b"), Some("c")],
        );
        assert!(check_null_ratios(&[batch], &null_limit("name", 0.5)).is_empty());
    }

    #[test]
    fn null_ratio_fails() {
        let batch = make_batch(&[Some(1), Some(2), Some(3)], &[None, None, Some("c")]);
        let issues = check_null_ratios(&[batch], &null_limit("name", 0.5));
        assert_eq!(issues.len(), 1);
        assert!(issues[0].message.contains("null ratio"));
    }

    #[test]
    fn uniqueness_passes() {
        let batch = make_batch(
            &[Some(1), Some(2), Some(3)],
            &[Some("a"), Some("b"), Some("c")],
        );
        assert!(check_uniqueness(&[batch], &["id".into()]).is_empty());
    }

    #[test]
    fn uniqueness_fails() {
        let batch = make_batch(
            &[Some(1), Some(2), Some(1)],
            &[Some("a"), Some("b"), Some("c")],
        );
        let issues = check_uniqueness(&[batch], &["id".into()]);
        assert_eq!(issues.len(), 1);
        assert!(issues[0].message.contains("duplicate"));
    }

    // ─── regression: multi-batch aggregation ─────────────────

    #[test]
    fn null_ratio_multi_batch_aggregates() {
        let batches = [
            make_batch(&[Some(1), Some(2)], &[None, Some("b")]),
            make_batch(&[Some(3), Some(4)], &[None, None]),
        ];
        let issues = check_null_ratios(&batches, &null_limit("name", 0.5));
        assert_eq!(issues.len(), 1, "3/4 nulls > 0.5 threshold");
    }

    #[test]
    fn null_ratio_multi_batch_passes_when_sparse() {
        let batches = [
            make_batch(&[Some(1), Some(2)], &[Some("a"), Some("b")]),
            make_batch(&[Some(3)], &[None]),
        ];
        let issues = check_null_ratios(&batches, &null_limit("name", 0.5));
        assert!(issues.is_empty(), "1/3 nulls < 0.5 threshold");
    }

    #[test]
    fn uniqueness_multi_batch_detects_cross_batch_dupes() {
        let batches = [
            make_batch(&[Some(1), Some(2)], &[Some("a"), Some("b")]),
            make_batch(&[Some(2), Some(3)], &[Some("c"), Some("d")]),
        ];
        let issues = check_uniqueness(&batches, &["id".into()]);
        assert_eq!(issues.len(), 1, "id=2 duplicated across batches");
    }

    #[test]
    fn uniqueness_empty_batches() {
        let issues = check_uniqueness(&[empty_batch()], &["id".into()]);
        assert!(issues.is_empty(), "empty batch → no duplicates");
    }

    #[test]
    fn null_ratio_empty_batches() {
        let issues = check_null_ratios(&[empty_batch()], &null_limit("name", 0.0));
        assert!(issues.is_empty(), "0 rows → skip");
    }

    // ─── regression: run_checks integration ──────────────────

    #[test]
    fn run_checks_combines_all_results() {
        let batch = make_batch(&[Some(1), Some(1), Some(1)], &[None, None, Some("c")]);
        let cfg = QualityConfig {
            row_count_min: Some(100),
            row_count_max: None,
            null_ratio_max: null_limit("name", 0.1),
            unique_columns: vec!["id".into()],
        };
        let issues = run_checks(&cfg, &[batch], 3);
        assert!(
            issues.len() >= 3,
            "row_count + null_ratio + uniqueness, got: {}",
            issues.len()
        );
    }

    #[test]
    fn run_checks_no_issues_when_clean() {
        let batch = make_batch(
            &[Some(1), Some(2), Some(3)],
            &[Some("a"), Some("b"), Some("c")],
        );
        let cfg = QualityConfig {
            row_count_min: Some(1),
            row_count_max: Some(10),
            null_ratio_max: null_limit("name", 0.5),
            unique_columns: vec!["id".into()],
        };
        let issues = run_checks(&cfg, &[batch], 3);
        assert!(issues.is_empty(), "all clean: {:?}", issues);
    }

    #[test]
    fn row_count_exact_boundary() {
        let cfg = bounds(Some(5), Some(5));
        assert!(check_row_count(5, &cfg).is_empty(), "exactly on boundary");
        assert!(!check_row_count(4, &cfg).is_empty(), "one below min");
        assert!(!check_row_count(6, &cfg).is_empty(), "one above max");
    }

    #[test]
    fn null_ratio_exact_threshold_passes() {
        let batch = make_batch(&[Some(1), Some(2)], &[None, Some("b")]);
        let issues = check_null_ratios(&[batch], &null_limit("name", 0.5));
        assert!(issues.is_empty(), "0.5 == 0.5, not >, so should pass");
    }
}