Skip to main content

datafusion_quality/
lib.rs

1pub mod error;
2pub mod rules;
3
4use crate::error::ValidationError;
5use datafusion::{common::DFSchema, logical_expr::ExprSchemable, prelude::*};
6use error::DataFusionSnafu;
7use snafu::ResultExt;
8use std::sync::Arc;
9
10/// The main RuleSet struct that holds the context and rules
11#[derive(Clone, Default)]
12pub struct RuleSet {
13    pub(crate) schema_rules: Vec<Arc<dyn SchemaRule>>,
14    pub(crate) column_rules: Vec<(String, Arc<dyn ColumnRule>)>,
15    pub(crate) table_rules: Vec<(String, Arc<dyn TableRule>)>,
16}
17
18impl std::fmt::Debug for RuleSet {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        f.debug_struct("RuleSet")
21            .field("schema_rules", &self.schema_rules)
22            .field("column_rules", &self.column_rules)
23            .field("table_rules", &self.table_rules)
24            .finish_non_exhaustive()
25    }
26}
27
28/// Trait for schema-level rules
29pub trait SchemaRule: Send + Sync + std::fmt::Debug {
30    /// Validate the schema
31    fn validate_schema(&self, _schema: &DFSchema) -> Result<bool, ValidationError> {
32        unimplemented!("validate_schema must be implemented")
33    }
34
35    /// Validate the schema with access to the RuleSet
36    fn validate_schema_with_ruleset(
37        &self,
38        schema: &DFSchema,
39        _rule_set: &RuleSet,
40    ) -> Result<bool, ValidationError> {
41        self.validate_schema(schema)
42    }
43
44    /// Get the name of the rule
45    fn name(&self) -> &str;
46
47    /// Get the description of the rule
48    fn description(&self) -> &str;
49}
50
51/// Trait for column-level rules
52pub trait ColumnRule: Send + Sync + std::fmt::Debug {
53    /// Apply the rule to a DataFrame, adding a new column
54    fn apply(&self, _df: DataFrame, _column_name: &str) -> Result<DataFrame, ValidationError> {
55        unimplemented!("apply must be implemented")
56    }
57
58    /// Apply the rule to a DataFrame with access to the RuleSet
59    fn apply_with_ruleset(
60        &self,
61        df: DataFrame,
62        column_name: &str,
63        _rule_set: &RuleSet,
64    ) -> Result<DataFrame, ValidationError> {
65        self.apply(df, column_name)
66    }
67
68    /// Get the name of the rule
69    fn name(&self) -> &str;
70
71    /// Get the name of the new column
72    fn new_column_name(&self, column_name: &str) -> String;
73
74    /// Get the description of the rule
75    fn description(&self) -> &str;
76}
77
78/// Trait for table-level aggregate rules
79pub trait TableRule: Send + Sync + std::fmt::Debug {
80    /// Apply the rule to a DataFrame, adding a new column with aggregated results
81    fn apply(&self, _df: DataFrame, _column_name: &str) -> Result<DataFrame, ValidationError> {
82        unimplemented!("apply must be implemented")
83    }
84
85    /// Apply the rule to a DataFrame with access to the RuleSet
86    fn apply_with_ruleset(
87        &self,
88        df: DataFrame,
89        column_name: &str,
90        _rule_set: &RuleSet,
91    ) -> Result<DataFrame, ValidationError> {
92        self.apply(df, column_name)
93    }
94
95    /// Get the name of the rule
96    fn name(&self) -> &str;
97
98    /// Get the name of the new column
99    fn new_column_name(&self, column_name: &str) -> String;
100
101    /// Get the description of the rule
102    fn description(&self) -> &str;
103}
104
105impl RuleSet {
106    /// Create a new RuleSet instance
107    pub fn new() -> Self {
108        Self {
109            schema_rules: Vec::new(),
110            column_rules: Vec::new(),
111            table_rules: Vec::new(),
112        }
113    }
114
115    /// Add a schema rule
116    pub fn with_schema_rule(&mut self, rule: Arc<dyn SchemaRule>) -> &mut Self {
117        self.schema_rules.push(rule);
118        self
119    }
120
121    /// Add a column rule
122    pub fn with_column_rule(
123        &mut self,
124        column_name: impl AsRef<str>,
125        rule: Arc<dyn ColumnRule>,
126    ) -> &mut Self {
127        let column_name = column_name.as_ref().to_string();
128        self.column_rules.push((column_name, rule));
129        self
130    }
131
132    /// Add a table rule
133    pub fn with_table_rule(
134        &mut self,
135        column_name: impl AsRef<str>,
136        table_rule: Arc<dyn TableRule>,
137        check: Option<Arc<dyn ColumnRule>>,
138    ) -> &mut Self {
139        let column_name = column_name.as_ref().to_string();
140        if let Some(check) = check {
141            let column_name = table_rule.new_column_name(&column_name);
142            self.column_rules.push((column_name, check));
143        }
144        self.table_rules.push((column_name, table_rule));
145        self
146    }
147
148    pub async fn apply_table_rules(&self, df: DataFrame) -> Result<DataFrame, ValidationError> {
149        let mut result_df = df;
150        for (column_name, rule) in &self.table_rules {
151            result_df = rule.apply_with_ruleset(result_df, column_name, self)?;
152        }
153        Ok(result_df)
154    }
155
156    /// Apply all rules to a DataFrame
157    pub async fn apply(&self, df: &DataFrame) -> Result<DataFrame, ValidationError> {
158        // First validate schema
159        for rule in &self.schema_rules {
160            if !rule.validate_schema(df.schema())? {
161                return Err(ValidationError::Schema {
162                    message: format!("Schema rule '{}' failed", rule.name()),
163                });
164            }
165        }
166
167        let mut result_df = df.clone();
168        // Apply table calculations
169        result_df = self.apply_table_rules(result_df).await?;
170
171        let mut check_columns = Vec::new();
172
173        // Then apply column rules
174        for (column_name, rule) in &self.column_rules {
175            result_df = rule.apply_with_ruleset(result_df, column_name, self)?;
176            check_columns.push(rule.new_column_name(column_name));
177        }
178
179        let dq_pass_col = check_columns
180            .into_iter()
181            .map(|col_name| {
182                col(col_name)
183                    .cast_to(&arrow::datatypes::DataType::Boolean, result_df.schema())
184                    .map_err(|e| ValidationError::Column {
185                        message: format!("Error casting column to boolean: {}", e),
186                    })
187            })
188            .reduce(|acc, col| Ok(acc?.and(col?)))
189            .unwrap_or(Ok(lit(true)))?;
190
191        result_df = result_df.with_column("dfq_pass", dq_pass_col)?;
192
193        Ok(result_df)
194    }
195
196    pub async fn partition(
197        &self,
198        df: &DataFrame,
199    ) -> Result<(DataFrame, DataFrame), ValidationError> {
200        let dq_df = self.apply(df).await?.cache().await?;
201
202        let pass_expr = col("dfq_pass").eq(lit(true));
203        let pass_df = dq_df.clone().filter(pass_expr.clone())?.select_columns(
204            &df.schema()
205                .fields()
206                .iter()
207                .map(|s| s.name().as_str())
208                .collect::<Vec<&str>>(),
209        )?;
210        let fail_df = dq_df.filter(pass_expr.not())?;
211        Ok((pass_df, fail_df))
212    }
213
214    pub async fn derived_statistics(
215        &self,
216        df: &DataFrame,
217        extra_columns: Option<Vec<&str>>,
218    ) -> Result<DataFrame, ValidationError> {
219        let dq_df = self.apply(df).await?;
220
221        let mut table_rules_names = Vec::new();
222        if let Some(extra_columns) = extra_columns {
223            table_rules_names.extend(extra_columns.iter().map(|s| col(*s)));
224        }
225
226        for (column_name, rule) in &self.table_rules {
227            table_rules_names.push(col(rule.new_column_name(column_name)));
228        }
229
230        dq_df.select(table_rules_names).context(DataFusionSnafu)
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rules::column::*;
    use crate::rules::table::*;
    use arrow::record_batch::RecordBatch;
    use datafusion::arrow::array::{Float64Array, Int32Array, StringArray};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::assert_batches_eq;
    use std::sync::Arc;

    /// Build the shared five-row fixture (`id`, `name`, `age`, `score`).
    ///
    /// Row 3 has a null `name` and the only `score` below 80. The batch is
    /// also registered on the returned context as `test_data`.
    async fn create_test_df() -> (SessionContext, DataFrame) {
        let ids = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let names = StringArray::from(vec![
            Some("Alice"),
            Some("Bob"),
            None,
            Some("Charlie"),
            Some("Dave"),
        ]);
        let ages = Int32Array::from(vec![Some(25), Some(30), Some(15), Some(40), Some(20)]);
        let scores = Float64Array::from(vec![
            Some(85.5),
            Some(92.0),
            Some(78.5),
            Some(95.0),
            Some(88.5),
        ]);

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("age", DataType::Int32, true),
            Field::new("score", DataType::Float64, true),
        ]));

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(ids),
                Arc::new(names),
                Arc::new(ages),
                Arc::new(scores),
            ],
        )
        .unwrap();

        let ctx = SessionContext::new();
        let df = ctx.read_batch(batch.clone()).unwrap();
        ctx.register_batch("test_data", batch).unwrap();

        (ctx, df)
    }

    /// `partition` sends passing rows (original columns only) to the first
    /// frame and failing rows (with all rule columns) to the second.
    #[tokio::test]
    async fn test_partition() {
        let (_ctx, df) = create_test_df().await;

        // Sanity-check the fixture before any rule is applied.
        let expected_input = vec![
            "+----+---------+-----+-------+",
            "| id | name    | age | score |",
            "+----+---------+-----+-------+",
            "| 1  | Alice   | 25  | 85.5  |",
            "| 2  | Bob     | 30  | 92.0  |",
            "| 3  |         | 15  | 78.5  |",
            "| 4  | Charlie | 40  | 95.0  |",
            "| 5  | Dave    | 20  | 88.5  |",
            "+----+---------+-----+-------+",
        ];
        let input_batches = df.clone().collect().await.unwrap();
        assert_batches_eq!(&expected_input, &input_batches);

        // Two column rules, then a table rule with its own check on the
        // aggregated column.
        let mut rules = RuleSet::new();
        rules
            .with_column_rule("name", dfq_not_null())
            .with_column_rule("score", dfq_in_range(80.0, 100.0))
            .with_table_rule("name", dfq_null_count(), Some(dfq_in_range(0.0, 10.0)));

        let (pass_df, fail_df) = rules.partition(&df).await.unwrap();

        // Passing rows: name is not null AND score is between 80 and 100.
        let expected_pass = vec![
            "+----+---------+-----+-------+",
            "| id | name    | age | score |",
            "+----+---------+-----+-------+",
            "| 1  | Alice   | 25  | 85.5  |",
            "| 2  | Bob     | 30  | 92.0  |",
            "| 4  | Charlie | 40  | 95.0  |",
            "| 5  | Dave    | 20  | 88.5  |",
            "+----+---------+-----+-------+",
        ];
        let pass_batches = pass_df.collect().await.unwrap();
        assert_batches_eq!(&expected_pass, &pass_batches);

        // Failing rows: name is null OR score is not between 80 and 100.
        let expected_fail = vec![
            "+----+------+-----+-------+-----------------+---------------+----------------+--------------------------+----------+",
            "| id | name | age | score | name_null_count | name_not_null | score_in_range | name_null_count_in_range | dfq_pass |",
            "+----+------+-----+-------+-----------------+---------------+----------------+--------------------------+----------+",
            "| 3  |      | 15  | 78.5  | 1               | false         | false          | true                     | false    |",
            "+----+------+-----+-------+-----------------+---------------+----------------+--------------------------+----------+",
        ];
        let fail_batches = fail_df.collect().await.unwrap();
        assert_batches_eq!(&expected_fail, &fail_batches);
    }

    /// `derived_statistics` returns the extra columns followed by one column
    /// per table rule.
    #[tokio::test]
    async fn test_derived_statistics() {
        let (_, df) = create_test_df().await;

        // Three table rules, none with an attached check.
        let mut rules = RuleSet::new();
        rules
            .with_table_rule("score", dfq_avg(), None)
            .with_table_rule("score", dfq_stddev(), None)
            .with_table_rule("age", dfq_null_count(), None);

        let stats_df = rules
            .derived_statistics(&df, Some(vec!["id"]))
            .await
            .unwrap();

        let expected_stats = vec![
            "+----+-----------+-------------------+----------------+",
            "| id | score_avg | score_stddev      | age_null_count |",
            "+----+-----------+-------------------+----------------+",
            "| 1  | 87.9      | 6.358065743604732 | 0              |",
            "| 2  | 87.9      | 6.358065743604732 | 0              |",
            "| 3  | 87.9      | 6.358065743604732 | 0              |",
            "| 4  | 87.9      | 6.358065743604732 | 0              |",
            "| 5  | 87.9      | 6.358065743604732 | 0              |",
            "+----+-----------+-------------------+----------------+",
        ];
        let stats_batches = stats_df.collect().await.unwrap();
        assert_batches_eq!(&expected_stats, &stats_batches);
    }
}