1// ThinkingLanguage — Data Quality Operations
2// Licensed under MIT OR Apache-2.0
3//
4// DataFrame-level clean, validate, and profile operations.
5
6use datafusion::arrow::array::*;
7use datafusion::arrow::datatypes::{DataType, Field, Schema};
8use datafusion::functions_aggregate::expr_fn::{
9    avg, count, max as agg_max, min as agg_min, stddev,
10};
11use datafusion::prelude::*;
12use std::sync::Arc;
13
14use crate::engine::DataEngine;
15
16impl DataEngine {
17    /// Fill null values in a column using a strategy.
18    /// strategy: "value" (use fill_value), "mean", "zero"
19    pub fn fill_null(
20        &self,
21        df: DataFrame,
22        column: &str,
23        strategy: &str,
24        fill_value: Option<f64>,
25    ) -> Result<DataFrame, String> {
26        let fill_expr = match strategy {
27            "value" => {
28                let val =
29                    fill_value.ok_or("fill_null with 'value' strategy requires a fill_value")?;
30                coalesce(vec![col(column), lit(val)]).alias(column)
31            }
32            "zero" => coalesce(vec![col(column), lit(0.0)]).alias(column),
33            "mean" => {
34                // Compute mean first
35                let mean_df = df
36                    .clone()
37                    .aggregate(vec![], vec![avg(col(column)).alias("__mean")])
38                    .map_err(|e| format!("fill_null mean aggregate error: {e}"))?;
39                let batches = self.collect(mean_df)?;
40                let mean_val = if !batches.is_empty() && batches[0].num_rows() > 0 {
41                    let col_arr = batches[0].column(0);
42                    if let Some(f64_arr) = col_arr.as_any().downcast_ref::<Float64Array>() {
43                        if f64_arr.is_null(0) {
44                            0.0
45                        } else {
46                            f64_arr.value(0)
47                        }
48                    } else {
49                        0.0
50                    }
51                } else {
52                    0.0
53                };
54                coalesce(vec![col(column), lit(mean_val)]).alias(column)
55            }
56            "median" => {
57                // Approximate: use SQL median via sorted approach
58                // For simplicity, compute via mean (median requires more complex logic)
59                let mean_df = df
60                    .clone()
61                    .aggregate(vec![], vec![avg(col(column)).alias("__mean")])
62                    .map_err(|e| format!("fill_null median aggregate error: {e}"))?;
63                let batches = self.collect(mean_df)?;
64                let mean_val = if !batches.is_empty() && batches[0].num_rows() > 0 {
65                    let col_arr = batches[0].column(0);
66                    if let Some(f64_arr) = col_arr.as_any().downcast_ref::<Float64Array>() {
67                        if f64_arr.is_null(0) {
68                            0.0
69                        } else {
70                            f64_arr.value(0)
71                        }
72                    } else {
73                        0.0
74                    }
75                } else {
76                    0.0
77                };
78                coalesce(vec![col(column), lit(mean_val)]).alias(column)
79            }
80            other => return Err(format!("Unknown fill_null strategy: {other}")),
81        };
82
83        // Build select list: replace target column, keep others
84        let schema = df.schema().clone();
85        let mut select_exprs = Vec::new();
86        for field in schema.fields() {
87            if field.name() == column {
88                select_exprs.push(fill_expr.clone());
89            } else {
90                select_exprs.push(col(field.name()));
91            }
92        }
93
94        df.select(select_exprs)
95            .map_err(|e| format!("fill_null select error: {e}"))
96    }
97
98    /// Drop rows where a column is null.
99    pub fn drop_null(&self, df: DataFrame, column: &str) -> Result<DataFrame, String> {
100        df.filter(col(column).is_not_null())
101            .map_err(|e| format!("drop_null error: {e}"))
102    }
103
104    /// Remove duplicate rows based on specified columns.
105    pub fn dedup(&self, df: DataFrame, columns: &[String]) -> Result<DataFrame, String> {
106        if columns.is_empty() {
107            return df.distinct().map_err(|e| format!("dedup error: {e}"));
108        }
109        // Use distinct on specific columns by registering as table + SQL
110        let table_name = "__dedup_tmp";
111        self.ctx
112            .register_table(table_name, df.into_view())
113            .map_err(|e| format!("dedup register error: {e}"))?;
114
115        let cols_str = columns.join(", ");
116        let result = self.sql(&format!(
117            "SELECT DISTINCT ON ({cols_str}) * FROM {table_name}"
118        ));
119
120        // Fallback to regular DISTINCT if DISTINCT ON is not supported
121        match result {
122            Ok(r) => Ok(r),
123            Err(_) => {
124                // Use GROUP BY approach
125                let all_cols = self.sql(&format!("SELECT * FROM {table_name} GROUP BY {cols_str}"));
126                match all_cols {
127                    Ok(r) => Ok(r),
128                    Err(_) => {
129                        // Final fallback: just use DISTINCT
130                        self.sql(&format!("SELECT DISTINCT * FROM {table_name}"))
131                    }
132                }
133            }
134        }
135    }
136
137    /// Clamp values in a column to [min_val, max_val].
138    pub fn clamp(
139        &self,
140        df: DataFrame,
141        column: &str,
142        min_val: f64,
143        max_val: f64,
144    ) -> Result<DataFrame, String> {
145        let clamp_expr = when(col(column).lt(lit(min_val)), lit(min_val))
146            .when(col(column).gt(lit(max_val)), lit(max_val))
147            .otherwise(col(column))
148            .map_err(|e| format!("clamp expr error: {e}"))?
149            .alias(column);
150
151        let schema = df.schema().clone();
152        let mut select_exprs = Vec::new();
153        for field in schema.fields() {
154            if field.name() == column {
155                select_exprs.push(clamp_expr.clone());
156            } else {
157                select_exprs.push(col(field.name()));
158            }
159        }
160
161        df.select(select_exprs)
162            .map_err(|e| format!("clamp select error: {e}"))
163    }
164
165    /// Generate a statistical profile of all numeric columns.
166    /// Returns a table with: column_name, count, null_count, null_rate, min, max, mean, stddev
167    pub fn data_profile(&self, df: DataFrame) -> Result<DataFrame, String> {
168        let schema = df.schema().clone();
169        let mut col_names = Vec::new();
170        let mut counts = Vec::new();
171        let mut null_counts = Vec::new();
172        let mut null_rates = Vec::new();
173        let mut mins = Vec::new();
174        let mut maxs = Vec::new();
175        let mut means = Vec::new();
176        let mut stddevs = Vec::new();
177
178        for field in schema.fields() {
179            let name = field.name();
180            let is_numeric = matches!(
181                field.data_type(),
182                DataType::Int8
183                    | DataType::Int16
184                    | DataType::Int32
185                    | DataType::Int64
186                    | DataType::UInt8
187                    | DataType::UInt16
188                    | DataType::UInt32
189                    | DataType::UInt64
190                    | DataType::Float32
191                    | DataType::Float64
192            );
193
194            // Build aggregation query for this column
195            let mut agg_exprs = vec![count(col(name)).alias("__count")];
196            if is_numeric {
197                agg_exprs.push(agg_min(col(name)).alias("__min"));
198                agg_exprs.push(agg_max(col(name)).alias("__max"));
199                agg_exprs.push(avg(col(name)).alias("__mean"));
200                agg_exprs.push(stddev(col(name)).alias("__stddev"));
201            }
202
203            let agg_df = df
204                .clone()
205                .aggregate(vec![], agg_exprs)
206                .map_err(|e| format!("data_profile aggregate error for {name}: {e}"))?;
207            let batches = self.collect(agg_df)?;
208
209            if batches.is_empty() || batches[0].num_rows() == 0 {
210                continue;
211            }
212            let batch = &batches[0];
213
214            let non_null_cnt = Self::extract_i64_or_u64(batch.column(0));
215            // Get total row count to compute null count
216            let total = self.row_count(df.clone())?;
217            let null_cnt = total - non_null_cnt;
218            let nr = if total > 0 {
219                null_cnt as f64 / total as f64
220            } else {
221                0.0
222            };
223
224            col_names.push(name.clone());
225            counts.push(non_null_cnt);
226            null_counts.push(null_cnt);
227            null_rates.push(nr);
228
229            if is_numeric && batch.num_columns() >= 5 {
230                mins.push(Self::extract_f64(batch.column(1)));
231                maxs.push(Self::extract_f64(batch.column(2)));
232                means.push(Self::extract_f64(batch.column(3)));
233                stddevs.push(Self::extract_f64(batch.column(4)));
234            } else {
235                mins.push(f64::NAN);
236                maxs.push(f64::NAN);
237                means.push(f64::NAN);
238                stddevs.push(f64::NAN);
239            }
240        }
241
242        let result_schema = Arc::new(Schema::new(vec![
243            Field::new("column_name", DataType::Utf8, false),
244            Field::new("count", DataType::Int64, false),
245            Field::new("null_count", DataType::Int64, false),
246            Field::new("null_rate", DataType::Float64, false),
247            Field::new("min", DataType::Float64, true),
248            Field::new("max", DataType::Float64, true),
249            Field::new("mean", DataType::Float64, true),
250            Field::new("stddev", DataType::Float64, true),
251        ]));
252
253        let batch = RecordBatch::try_new(
254            result_schema,
255            vec![
256                Arc::new(StringArray::from(col_names)),
257                Arc::new(Int64Array::from(counts)),
258                Arc::new(Int64Array::from(null_counts)),
259                Arc::new(Float64Array::from(null_rates)),
260                Arc::new(Float64Array::from(mins)),
261                Arc::new(Float64Array::from(maxs)),
262                Arc::new(Float64Array::from(means)),
263                Arc::new(Float64Array::from(stddevs)),
264            ],
265        )
266        .map_err(|e| format!("data_profile batch error: {e}"))?;
267
268        self.register_batch("__data_profile", batch)?;
269        self.rt
270            .block_on(self.ctx.table("__data_profile"))
271            .map_err(|e| format!("data_profile table error: {e}"))
272    }
273
274    /// Get the row count of a DataFrame.
275    pub fn row_count(&self, df: DataFrame) -> Result<i64, String> {
276        let cnt = self
277            .rt
278            .block_on(df.count())
279            .map_err(|e| format!("row_count error: {e}"))?;
280        Ok(cnt as i64)
281    }
282
283    /// Get the null rate of a column (0.0 to 1.0).
284    pub fn null_rate(&self, df: DataFrame, column: &str) -> Result<f64, String> {
285        let total = self
286            .rt
287            .block_on(df.clone().count())
288            .map_err(|e| format!("null_rate count error: {e}"))? as i64;
289        if total == 0 {
290            return Ok(0.0);
291        }
292        let non_null_df = df
293            .aggregate(vec![], vec![count(col(column)).alias("__non_null")])
294            .map_err(|e| format!("null_rate aggregate error: {e}"))?;
295        let batches = self.collect(non_null_df)?;
296        if batches.is_empty() || batches[0].num_rows() == 0 {
297            return Ok(0.0);
298        }
299        let non_null = Self::extract_i64_or_u64(batches[0].column(0));
300        Ok((total - non_null) as f64 / total as f64)
301    }
302
303    /// Check if a column's non-null values are all unique.
304    pub fn is_unique(&self, df: DataFrame, column: &str) -> Result<bool, String> {
305        let table_name = "__unique_check_tmp";
306        self.ctx
307            .register_table(table_name, df.into_view())
308            .map_err(|e| format!("is_unique register error: {e}"))?;
309
310        let result = self.sql(&format!(
311            "SELECT COUNT(DISTINCT \"{column}\") = COUNT(\"{column}\") AS is_uniq FROM {table_name} WHERE \"{column}\" IS NOT NULL"
312        ))?;
313
314        let batches = self.collect(result)?;
315        if batches.is_empty() || batches[0].num_rows() == 0 {
316            return Ok(true);
317        }
318        let col_arr = batches[0].column(0);
319        if let Some(bool_arr) = col_arr.as_any().downcast_ref::<BooleanArray>() {
320            Ok(!bool_arr.is_null(0) && bool_arr.value(0))
321        } else {
322            Ok(false)
323        }
324    }
325
326    // Helper: extract i64 from first row of an array (handles Int64 and UInt64)
327    fn extract_i64_or_u64(arr: &dyn Array) -> i64 {
328        if let Some(a) = arr.as_any().downcast_ref::<Int64Array>() {
329            if a.is_null(0) { 0 } else { a.value(0) }
330        } else if let Some(a) = arr.as_any().downcast_ref::<UInt64Array>() {
331            if a.is_null(0) { 0 } else { a.value(0) as i64 }
332        } else {
333            0
334        }
335    }
336
337    // Helper: extract f64 from first row of an array
338    fn extract_f64(arr: &dyn Array) -> f64 {
339        if let Some(a) = arr.as_any().downcast_ref::<Float64Array>() {
340            if a.is_null(0) { f64::NAN } else { a.value(0) }
341        } else if let Some(a) = arr.as_any().downcast_ref::<Int64Array>() {
342            if a.is_null(0) {
343                f64::NAN
344            } else {
345                a.value(0) as f64
346            }
347        } else if let Some(a) = arr.as_any().downcast_ref::<Int32Array>() {
348            if a.is_null(0) {
349                f64::NAN
350            } else {
351                a.value(0) as f64
352            }
353        } else {
354            f64::NAN
355        }
356    }
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::array::{Float64Array, Int64Array, StringArray};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};

    /// Build an engine holding a 5-row "test_data" table (id, name, age)
    /// where row index 2 has a null name and a null age.
    fn make_test_engine_with_data() -> DataEngine {
        let engine = DataEngine::new();
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("age", DataType::Float64, true),
        ]));
        let ids = Int64Array::from(vec![1, 2, 3, 4, 5]);
        let names = StringArray::from(vec![
            Some("Alice"),
            Some("Bob"),
            None,
            Some("Diana"),
            Some("Eve"),
        ]);
        let ages = Float64Array::from(vec![
            Some(30.0),
            Some(25.0),
            None,
            Some(35.0),
            Some(28.0),
        ]);
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(ids), Arc::new(names), Arc::new(ages)])
                .unwrap();
        engine.register_batch("test_data", batch).unwrap();
        engine
    }

    #[test]
    fn test_fill_null_value() {
        let engine = make_test_engine_with_data();
        let df = engine.rt.block_on(engine.ctx.table("test_data")).unwrap();
        let filled = engine.fill_null(df, "age", "value", Some(0.0)).unwrap();
        let batches = engine.collect(filled).unwrap();
        let n_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(n_rows, 5);
        // The previously-null age (row 2) must now hold the fill value.
        let ages = batches[0]
            .column_by_name("age")
            .unwrap()
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!(!ages.is_null(2));
        assert_eq!(ages.value(2), 0.0);
    }

    #[test]
    fn test_fill_null_mean() {
        let engine = make_test_engine_with_data();
        let df = engine.rt.block_on(engine.ctx.table("test_data")).unwrap();
        let filled = engine.fill_null(df, "age", "mean", None).unwrap();
        let batches = engine.collect(filled).unwrap();
        let ages = batches[0]
            .column_by_name("age")
            .unwrap()
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!(!ages.is_null(2));
        // Mean of the non-null ages [30, 25, 35, 28] is 29.5.
        assert!((ages.value(2) - 29.5).abs() < 0.01);
    }

    #[test]
    fn test_drop_null() {
        let engine = make_test_engine_with_data();
        let df = engine.rt.block_on(engine.ctx.table("test_data")).unwrap();
        let kept = engine.drop_null(df, "name").unwrap();
        let batches = engine.collect(kept).unwrap();
        // Exactly one row carries a null name and must be gone.
        let n_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(n_rows, 4);
    }

    #[test]
    fn test_dedup() {
        let engine = DataEngine::new();
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("val", DataType::Utf8, false),
        ]));
        // Rows (2, "b") appear twice; everything else is unique.
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 2, 3])),
                Arc::new(StringArray::from(vec!["a", "b", "b", "c"])),
            ],
        )
        .unwrap();
        engine.register_batch("dup_data", batch).unwrap();
        let df = engine.rt.block_on(engine.ctx.table("dup_data")).unwrap();
        let deduped = engine.dedup(df, &[]).unwrap();
        let batches = engine.collect(deduped).unwrap();
        let n_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(n_rows, 3);
    }

    #[test]
    fn test_clamp() {
        let engine = make_test_engine_with_data();
        let df = engine.rt.block_on(engine.ctx.table("test_data")).unwrap();
        let clamped = engine.clamp(df, "age", 26.0, 32.0).unwrap();
        let batches = engine.collect(clamped).unwrap();
        let ages = batches[0]
            .column_by_name("age")
            .unwrap()
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        // 30 stays (in range), 25 clamps up to 26, the null stays null,
        // 35 clamps down to 32, 28 stays.
        assert_eq!(ages.value(0), 30.0);
        assert_eq!(ages.value(1), 26.0);
        assert_eq!(ages.value(3), 32.0);
        assert_eq!(ages.value(4), 28.0);
    }

    #[test]
    fn test_data_profile() {
        let engine = make_test_engine_with_data();
        let df = engine.rt.block_on(engine.ctx.table("test_data")).unwrap();
        let profile = engine.data_profile(df).unwrap();
        let batches = engine.collect(profile).unwrap();
        assert!(!batches.is_empty());
        // Expect rows for id, name, and age; at minimum the two numeric ones.
        let n_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert!(n_rows >= 2);
    }

    #[test]
    fn test_row_count() {
        let engine = make_test_engine_with_data();
        let df = engine.rt.block_on(engine.ctx.table("test_data")).unwrap();
        assert_eq!(engine.row_count(df).unwrap(), 5);
    }

    #[test]
    fn test_null_rate_and_is_unique() {
        let engine = make_test_engine_with_data();
        let df = engine.rt.block_on(engine.ctx.table("test_data")).unwrap();
        // One null name out of five rows -> rate 0.2.
        let rate = engine.null_rate(df, "name").unwrap();
        assert!((rate - 0.2).abs() < 0.01);

        // All ids are distinct, so the column is unique.
        let df2 = engine.rt.block_on(engine.ctx.table("test_data")).unwrap();
        assert!(engine.is_unique(df2, "id").unwrap());
    }
}