1pub mod error;
2pub mod rules;
3
4use crate::error::ValidationError;
5use datafusion::{common::DFSchema, logical_expr::ExprSchemable, prelude::*};
6use error::DataFusionSnafu;
7use snafu::ResultExt;
8use std::sync::Arc;
9
10#[derive(Clone, Default)]
12pub struct RuleSet {
13 pub(crate) schema_rules: Vec<Arc<dyn SchemaRule>>,
14 pub(crate) column_rules: Vec<(String, Arc<dyn ColumnRule>)>,
15 pub(crate) table_rules: Vec<(String, Arc<dyn TableRule>)>,
16}
17
18impl std::fmt::Debug for RuleSet {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 f.debug_struct("RuleSet")
21 .field("schema_rules", &self.schema_rules)
22 .field("column_rules", &self.column_rules)
23 .field("table_rules", &self.table_rules)
24 .finish_non_exhaustive()
25 }
26}
27
28pub trait SchemaRule: Send + Sync + std::fmt::Debug {
30 fn validate_schema(&self, _schema: &DFSchema) -> Result<bool, ValidationError> {
32 unimplemented!("validate_schema must be implemented")
33 }
34
35 fn validate_schema_with_ruleset(
37 &self,
38 schema: &DFSchema,
39 _rule_set: &RuleSet,
40 ) -> Result<bool, ValidationError> {
41 self.validate_schema(schema)
42 }
43
44 fn name(&self) -> &str;
46
47 fn description(&self) -> &str;
49}
50
51pub trait ColumnRule: Send + Sync + std::fmt::Debug {
53 fn apply(&self, _df: DataFrame, _column_name: &str) -> Result<DataFrame, ValidationError> {
55 unimplemented!("apply must be implemented")
56 }
57
58 fn apply_with_ruleset(
60 &self,
61 df: DataFrame,
62 column_name: &str,
63 _rule_set: &RuleSet,
64 ) -> Result<DataFrame, ValidationError> {
65 self.apply(df, column_name)
66 }
67
68 fn name(&self) -> &str;
70
71 fn new_column_name(&self, column_name: &str) -> String;
73
74 fn description(&self) -> &str;
76}
77
78pub trait TableRule: Send + Sync + std::fmt::Debug {
80 fn apply(&self, _df: DataFrame, _column_name: &str) -> Result<DataFrame, ValidationError> {
82 unimplemented!("apply must be implemented")
83 }
84
85 fn apply_with_ruleset(
87 &self,
88 df: DataFrame,
89 column_name: &str,
90 _rule_set: &RuleSet,
91 ) -> Result<DataFrame, ValidationError> {
92 self.apply(df, column_name)
93 }
94
95 fn name(&self) -> &str;
97
98 fn new_column_name(&self, column_name: &str) -> String;
100
101 fn description(&self) -> &str;
103}
104
105impl RuleSet {
106 pub fn new() -> Self {
108 Self {
109 schema_rules: Vec::new(),
110 column_rules: Vec::new(),
111 table_rules: Vec::new(),
112 }
113 }
114
115 pub fn with_schema_rule(&mut self, rule: Arc<dyn SchemaRule>) -> &mut Self {
117 self.schema_rules.push(rule);
118 self
119 }
120
121 pub fn with_column_rule(
123 &mut self,
124 column_name: impl AsRef<str>,
125 rule: Arc<dyn ColumnRule>,
126 ) -> &mut Self {
127 let column_name = column_name.as_ref().to_string();
128 self.column_rules.push((column_name, rule));
129 self
130 }
131
132 pub fn with_table_rule(
134 &mut self,
135 column_name: impl AsRef<str>,
136 table_rule: Arc<dyn TableRule>,
137 check: Option<Arc<dyn ColumnRule>>,
138 ) -> &mut Self {
139 let column_name = column_name.as_ref().to_string();
140 if let Some(check) = check {
141 let column_name = table_rule.new_column_name(&column_name);
142 self.column_rules.push((column_name, check));
143 }
144 self.table_rules.push((column_name, table_rule));
145 self
146 }
147
148 pub async fn apply_table_rules(&self, df: DataFrame) -> Result<DataFrame, ValidationError> {
149 let mut result_df = df;
150 for (column_name, rule) in &self.table_rules {
151 result_df = rule.apply_with_ruleset(result_df, column_name, self)?;
152 }
153 Ok(result_df)
154 }
155
156 pub async fn apply(&self, df: &DataFrame) -> Result<DataFrame, ValidationError> {
158 for rule in &self.schema_rules {
160 if !rule.validate_schema(df.schema())? {
161 return Err(ValidationError::Schema {
162 message: format!("Schema rule '{}' failed", rule.name()),
163 });
164 }
165 }
166
167 let mut result_df = df.clone();
168 result_df = self.apply_table_rules(result_df).await?;
170
171 let mut check_columns = Vec::new();
172
173 for (column_name, rule) in &self.column_rules {
175 result_df = rule.apply_with_ruleset(result_df, column_name, self)?;
176 check_columns.push(rule.new_column_name(column_name));
177 }
178
179 let dq_pass_col = check_columns
180 .into_iter()
181 .map(|col_name| {
182 col(col_name)
183 .cast_to(&arrow::datatypes::DataType::Boolean, result_df.schema())
184 .map_err(|e| ValidationError::Column {
185 message: format!("Error casting column to boolean: {}", e),
186 })
187 })
188 .reduce(|acc, col| Ok(acc?.and(col?)))
189 .unwrap_or(Ok(lit(true)))?;
190
191 result_df = result_df.with_column("dfq_pass", dq_pass_col)?;
192
193 Ok(result_df)
194 }
195
196 pub async fn partition(
197 &self,
198 df: &DataFrame,
199 ) -> Result<(DataFrame, DataFrame), ValidationError> {
200 let dq_df = self.apply(df).await?.cache().await?;
201
202 let pass_expr = col("dfq_pass").eq(lit(true));
203 let pass_df = dq_df.clone().filter(pass_expr.clone())?.select_columns(
204 &df.schema()
205 .fields()
206 .iter()
207 .map(|s| s.name().as_str())
208 .collect::<Vec<&str>>(),
209 )?;
210 let fail_df = dq_df.filter(pass_expr.not())?;
211 Ok((pass_df, fail_df))
212 }
213
214 pub async fn derived_statistics(
215 &self,
216 df: &DataFrame,
217 extra_columns: Option<Vec<&str>>,
218 ) -> Result<DataFrame, ValidationError> {
219 let dq_df = self.apply(df).await?;
220
221 let mut table_rules_names = Vec::new();
222 if let Some(extra_columns) = extra_columns {
223 table_rules_names.extend(extra_columns.iter().map(|s| col(*s)));
224 }
225
226 for (column_name, rule) in &self.table_rules {
227 table_rules_names.push(col(rule.new_column_name(column_name)));
228 }
229
230 dq_df.select(table_rules_names).context(DataFusionSnafu)
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rules::column::*;
    use crate::rules::table::*;
    use arrow::record_batch::RecordBatch;
    use datafusion::arrow::array::{Float64Array, Int32Array, StringArray};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::assert_batches_eq;
    use std::sync::Arc;

    // Expected tables below must match arrow's pretty printer byte-for-byte: every
    // cell is left-aligned and padded to the width of its `+----+` border segment.

    /// Builds a five-row fixture with columns `id`, `name`, `age` and `score`. Row
    /// `id = 3` has a NULL `name`. Returns the session alongside the `DataFrame` read
    /// from it; the batch is also registered as `test_data`.
    async fn create_test_df() -> (SessionContext, DataFrame) {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("age", DataType::Int32, true),
            Field::new("score", DataType::Float64, true),
        ]);

        let id_data = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let name_data = StringArray::from(vec![
            Some("Alice"),
            Some("Bob"),
            None,
            Some("Charlie"),
            Some("Dave"),
        ]);
        let age_data = Int32Array::from(vec![Some(25), Some(30), Some(15), Some(40), Some(20)]);
        let score_data = Float64Array::from(vec![
            Some(85.5),
            Some(92.0),
            Some(78.5),
            Some(95.0),
            Some(88.5),
        ]);

        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![
                Arc::new(id_data),
                Arc::new(name_data),
                Arc::new(age_data),
                Arc::new(score_data),
            ],
        )
        .unwrap();

        let ctx = SessionContext::new();
        let df = ctx.read_batch(batch.clone()).unwrap();
        ctx.register_batch("test_data", batch).unwrap();

        (ctx, df)
    }

    /// `partition` keeps fully passing rows, with original columns only, in the first
    /// frame. The remaining rows go to the second frame with every derived column,
    /// every check column and `dfq_pass`.
    #[tokio::test]
    async fn test_partition() {
        let (_ctx, df) = create_test_df().await;

        let expected_data = vec![
            "+----+---------+-----+-------+",
            "| id | name    | age | score |",
            "+----+---------+-----+-------+",
            "| 1  | Alice   | 25  | 85.5  |",
            "| 2  | Bob     | 30  | 92.0  |",
            "| 3  |         | 15  | 78.5  |",
            "| 4  | Charlie | 40  | 95.0  |",
            "| 5  | Dave    | 20  | 88.5  |",
            "+----+---------+-----+-------+",
        ];
        assert_batches_eq!(&expected_data, &df.clone().collect().await.unwrap());

        let mut rule_set = RuleSet::new();

        rule_set
            .with_column_rule("name", dfq_not_null())
            .with_column_rule("score", dfq_in_range(80.0, 100.0));

        rule_set.with_table_rule("name", dfq_null_count(), Some(dfq_in_range(0.0, 10.0)));

        let (pass_df, fail_df) = rule_set.partition(&df).await.unwrap();

        let expected_pass = vec![
            "+----+---------+-----+-------+",
            "| id | name    | age | score |",
            "+----+---------+-----+-------+",
            "| 1  | Alice   | 25  | 85.5  |",
            "| 2  | Bob     | 30  | 92.0  |",
            "| 4  | Charlie | 40  | 95.0  |",
            "| 5  | Dave    | 20  | 88.5  |",
            "+----+---------+-----+-------+",
        ];

        let expected_fail = vec![
            "+----+------+-----+-------+-----------------+---------------+----------------+--------------------------+----------+",
            "| id | name | age | score | name_null_count | name_not_null | score_in_range | name_null_count_in_range | dfq_pass |",
            "+----+------+-----+-------+-----------------+---------------+----------------+--------------------------+----------+",
            "| 3  |      | 15  | 78.5  | 1               | false         | false          | true                     | false    |",
            "+----+------+-----+-------+-----------------+---------------+----------------+--------------------------+----------+",
        ];

        assert_batches_eq!(&expected_pass, &pass_df.collect().await.unwrap());
        assert_batches_eq!(&expected_fail, &fail_df.collect().await.unwrap());
    }

    /// `derived_statistics` returns the requested extra columns followed by one column
    /// per table rule, with the same statistic value on every row.
    #[tokio::test]
    async fn test_derived_statistics() {
        let (_, df) = create_test_df().await;

        let mut rule_set = RuleSet::new();

        rule_set
            .with_table_rule("score", dfq_avg(), None)
            .with_table_rule("score", dfq_stddev(), None)
            .with_table_rule("age", dfq_null_count(), None);

        let stats_df = rule_set
            .derived_statistics(&df, Some(vec!["id"]))
            .await
            .unwrap();

        let expected = vec![
            "+----+-----------+-------------------+----------------+",
            "| id | score_avg | score_stddev      | age_null_count |",
            "+----+-----------+-------------------+----------------+",
            "| 1  | 87.9      | 6.358065743604732 | 0              |",
            "| 2  | 87.9      | 6.358065743604732 | 0              |",
            "| 3  | 87.9      | 6.358065743604732 | 0              |",
            "| 4  | 87.9      | 6.358065743604732 | 0              |",
            "| 5  | 87.9      | 6.358065743604732 | 0              |",
            "+----+-----------+-------------------+----------------+",
        ];

        assert_batches_eq!(&expected, &stats_df.collect().await.unwrap());
    }
}