1use crate::core::{Constraint, ConstraintMetadata, ConstraintResult};
4use crate::prelude::*;
5use arrow::array::Array;
6use async_trait::async_trait;
7use datafusion::prelude::*;
8use tracing::instrument;
9
/// Logical value types that column contents can be checked against.
///
/// Every variant maps to a regular expression (see `pattern`) and to a
/// lowercase label (see `name`).
#[derive(Debug, Clone, PartialEq)]
#[allow(dead_code)]
pub enum DataType {
    /// Whole numbers with an optional leading minus sign.
    Integer,
    /// Decimal numbers with an optional fraction and exponent.
    Float,
    /// `true`/`false` in lower, upper or title case, or the digits `0`/`1`.
    Boolean,
    /// Values shaped like `YYYY-MM-DD` (digits only; no calendar check).
    Date,
    /// Values that start with `YYYY-MM-DD`, a space or `T`, then `HH:MM:SS`.
    Timestamp,
    /// Any string.
    String,
}

impl DataType {
    /// Returns the regular expression that recognises values of this type.
    ///
    /// The result is a string literal, so it is `'static` and does not keep
    /// `self` borrowed.
    fn pattern(&self) -> &'static str {
        match self {
            DataType::Integer => r"^-?\d+$",
            DataType::Float => r"^-?\d*\.?\d+([eE][+-]?\d+)?$",
            DataType::Boolean => r"^(true|false|TRUE|FALSE|True|False|0|1)$",
            DataType::Date => r"^\d{4}-\d{2}-\d{2}$",
            // Not anchored at the end: anything following the seconds is accepted.
            DataType::Timestamp => r"^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}",
            // Matches every value.
            DataType::String => r".*",
        }
    }

    /// Returns the lowercase label of this type.
    ///
    /// The result is a string literal, so it is `'static` and does not keep
    /// `self` borrowed.
    fn name(&self) -> &'static str {
        match self {
            DataType::Integer => "integer",
            DataType::Float => "float",
            DataType::Boolean => "boolean",
            DataType::Date => "date",
            DataType::Timestamp => "timestamp",
            DataType::String => "string",
        }
    }
}
53
54#[derive(Debug, Clone)]
69#[allow(dead_code)]
70pub struct DataTypeConstraint {
71 column: String,
72 data_type: DataType,
73 threshold: f64,
74}
75
76#[allow(dead_code)]
77impl DataTypeConstraint {
78 pub fn new(column: impl Into<String>, data_type: DataType, threshold: f64) -> Self {
90 assert!(
91 (0.0..=1.0).contains(&threshold),
92 "Threshold must be between 0.0 and 1.0"
93 );
94 Self {
95 column: column.into(),
96 data_type,
97 threshold,
98 }
99 }
100}
101
102#[async_trait]
103impl Constraint for DataTypeConstraint {
104 #[instrument(skip(self, ctx), fields(column = %self.column, data_type = %self.data_type.name(), threshold = %self.threshold))]
105 async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
106 let pattern = self.data_type.pattern();
107
108 let sql = format!(
109 "SELECT
110 COUNT(CASE WHEN {} ~ '{pattern}' THEN 1 END) as matches,
111 COUNT(*) as total
112 FROM data
113 WHERE {} IS NOT NULL",
114 self.column, self.column
115 );
116
117 let df = ctx.sql(&sql).await?;
118 let batches = df.collect().await?;
119
120 if batches.is_empty() {
121 return Ok(ConstraintResult::skipped("No data to validate"));
122 }
123
124 let batch = &batches[0];
125 if batch.num_rows() == 0 {
126 return Ok(ConstraintResult::skipped("No data to validate"));
127 }
128
129 let matches = batch
130 .column(0)
131 .as_any()
132 .downcast_ref::<arrow::array::Int64Array>()
133 .ok_or_else(|| TermError::Internal("Failed to extract match count".to_string()))?
134 .value(0) as f64;
135
136 let total = batch
137 .column(1)
138 .as_any()
139 .downcast_ref::<arrow::array::Int64Array>()
140 .ok_or_else(|| TermError::Internal("Failed to extract total count".to_string()))?
141 .value(0) as f64;
142
143 if total == 0.0 {
144 return Ok(ConstraintResult::skipped("No non-null data to validate"));
145 }
146
147 let type_ratio = matches / total;
148
149 if type_ratio >= self.threshold {
150 Ok(ConstraintResult::success_with_metric(type_ratio))
151 } else {
152 Ok(ConstraintResult::failure_with_metric(
153 type_ratio,
154 format!(
155 "Data type conformance {type_ratio} is below threshold {}",
156 self.threshold
157 ),
158 ))
159 }
160 }
161
162 fn name(&self) -> &str {
163 "data_type"
164 }
165
166 fn column(&self) -> Option<&str> {
167 Some(&self.column)
168 }
169
170 fn metadata(&self) -> ConstraintMetadata {
171 ConstraintMetadata::for_column(&self.column)
172 .with_description(format!(
173 "Checks that at least {:.1}% of values in '{}' conform to {} type",
174 self.threshold * 100.0,
175 self.column,
176 self.data_type.name()
177 ))
178 .with_custom("data_type", self.data_type.name())
179 .with_custom("threshold", self.threshold.to_string())
180 .with_custom("constraint_type", "data_type")
181 }
182}
183
/// Constraint requiring the values of a column to belong to a fixed set of
/// allowed strings.
#[derive(Debug, Clone)]
pub struct ContainmentConstraint {
    /// Name of the column to inspect.
    column: String,
    /// Values the column is allowed to contain.
    allowed_values: Vec<String>,
}

impl ContainmentConstraint {
    /// Builds a constraint for `column` from any iterable of values that
    /// convert into `String`.
    pub fn new<I, S>(column: impl Into<String>, allowed_values: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let column: String = column.into();
        let allowed_values = allowed_values
            .into_iter()
            .map(S::into)
            .collect::<Vec<String>>();

        Self {
            column,
            allowed_values,
        }
    }
}
223
224#[async_trait]
225impl Constraint for ContainmentConstraint {
226 #[instrument(skip(self, ctx), fields(column = %self.column, allowed_count = %self.allowed_values.len()))]
227 async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
228 let values_list = self
230 .allowed_values
231 .iter()
232 .map(|v| format!("'{}'", v.replace('\'', "''"))) .collect::<Vec<_>>()
234 .join(", ");
235
236 let sql = format!(
237 "SELECT
238 COUNT(CASE WHEN {} IN ({values_list}) THEN 1 END) as valid_values,
239 COUNT(*) as total
240 FROM data
241 WHERE {} IS NOT NULL",
242 self.column, self.column
243 );
244
245 let df = ctx.sql(&sql).await?;
246 let batches = df.collect().await?;
247
248 if batches.is_empty() {
249 return Ok(ConstraintResult::skipped("No data to validate"));
250 }
251
252 let batch = &batches[0];
253 if batch.num_rows() == 0 {
254 return Ok(ConstraintResult::skipped("No data to validate"));
255 }
256
257 let valid_values = batch
258 .column(0)
259 .as_any()
260 .downcast_ref::<arrow::array::Int64Array>()
261 .ok_or_else(|| TermError::Internal("Failed to extract valid count".to_string()))?
262 .value(0) as f64;
263
264 let total = batch
265 .column(1)
266 .as_any()
267 .downcast_ref::<arrow::array::Int64Array>()
268 .ok_or_else(|| TermError::Internal("Failed to extract total count".to_string()))?
269 .value(0) as f64;
270
271 if total == 0.0 {
272 return Ok(ConstraintResult::skipped("No non-null data to validate"));
273 }
274
275 let containment_ratio = valid_values / total;
276
277 if containment_ratio == 1.0 {
278 Ok(ConstraintResult::success_with_metric(containment_ratio))
279 } else {
280 let invalid_count = total - valid_values;
281 Ok(ConstraintResult::failure_with_metric(
282 containment_ratio,
283 format!("{invalid_count} values are not in the allowed set"),
284 ))
285 }
286 }
287
288 fn name(&self) -> &str {
289 "containment"
290 }
291
292 fn column(&self) -> Option<&str> {
293 Some(&self.column)
294 }
295
296 fn metadata(&self) -> ConstraintMetadata {
297 ConstraintMetadata::for_column(&self.column)
298 .with_description(format!(
299 "Checks that all values in '{}' are contained in the allowed set",
300 self.column
301 ))
302 .with_custom(
303 "allowed_values",
304 format!("[{}]", self.allowed_values.join(", ")),
305 )
306 .with_custom("constraint_type", "containment")
307 }
308}
309
/// Constraint requiring the values of a column to be non-negative.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct NonNegativeConstraint {
    /// Name of the column to inspect.
    column: String,
}

#[allow(dead_code)]
impl NonNegativeConstraint {
    /// Builds a constraint for the given column name.
    pub fn new(column: impl Into<String>) -> Self {
        let column: String = column.into();
        Self { column }
    }
}
344
345#[async_trait]
346impl Constraint for NonNegativeConstraint {
347 #[instrument(skip(self, ctx), fields(column = %self.column))]
348 async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
349 let sql = format!(
351 "SELECT
352 COUNT(CASE WHEN CAST({} AS DOUBLE) >= 0 THEN 1 END) as non_negative,
353 COUNT(*) as total
354 FROM data
355 WHERE {} IS NOT NULL",
356 self.column, self.column
357 );
358
359 let df = ctx.sql(&sql).await?;
360 let batches = df.collect().await?;
361
362 if batches.is_empty() {
363 return Ok(ConstraintResult::skipped("No data to validate"));
364 }
365
366 let batch = &batches[0];
367 if batch.num_rows() == 0 {
368 return Ok(ConstraintResult::skipped("No data to validate"));
369 }
370
371 let non_negative = batch
372 .column(0)
373 .as_any()
374 .downcast_ref::<arrow::array::Int64Array>()
375 .ok_or_else(|| TermError::Internal("Failed to extract non-negative count".to_string()))?
376 .value(0) as f64;
377
378 let total = batch
379 .column(1)
380 .as_any()
381 .downcast_ref::<arrow::array::Int64Array>()
382 .ok_or_else(|| TermError::Internal("Failed to extract total count".to_string()))?
383 .value(0) as f64;
384
385 if total == 0.0 {
386 return Ok(ConstraintResult::skipped("No non-null data to validate"));
387 }
388
389 let non_negative_ratio = non_negative / total;
390
391 if non_negative_ratio == 1.0 {
392 Ok(ConstraintResult::success_with_metric(non_negative_ratio))
393 } else {
394 let negative_count = total - non_negative;
395 Ok(ConstraintResult::failure_with_metric(
396 non_negative_ratio,
397 format!("{negative_count} values are negative"),
398 ))
399 }
400 }
401
402 fn name(&self) -> &str {
403 "non_negative"
404 }
405
406 fn column(&self) -> Option<&str> {
407 Some(&self.column)
408 }
409
410 fn metadata(&self) -> ConstraintMetadata {
411 ConstraintMetadata::for_column(&self.column)
412 .with_description(format!(
413 "Checks that all values in '{}' are non-negative",
414 self.column
415 ))
416 .with_custom("constraint_type", "value_range")
417 }
418}
419
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use arrow::array::{Float64Array, StringArray};
    use arrow::datatypes::{DataType as ArrowDataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    /// Registers a single-batch, single-column, nullable in-memory table
    /// under the name `data` and returns the session that owns it.
    fn context_with_column(
        name: &str,
        data_type: ArrowDataType,
        values: Arc<dyn Array>,
    ) -> SessionContext {
        let schema = Arc::new(Schema::new(vec![Field::new(name, data_type, true)]));
        let batch = RecordBatch::try_new(schema.clone(), vec![values]).unwrap();
        let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();

        let ctx = SessionContext::new();
        ctx.register_table("data", Arc::new(table)).unwrap();
        ctx
    }

    /// Session whose `data` table has one Utf8 column named `text_col`.
    async fn create_string_test_context(values: Vec<Option<&str>>) -> SessionContext {
        context_with_column(
            "text_col",
            ArrowDataType::Utf8,
            Arc::new(StringArray::from(values)),
        )
    }

    /// Session whose `data` table has one Float64 column named `num_col`.
    async fn create_numeric_test_context(values: Vec<Option<f64>>) -> SessionContext {
        context_with_column(
            "num_col",
            ArrowDataType::Float64,
            Arc::new(Float64Array::from(values)),
        )
    }

    /// Three of four strings are integers, which clears a 0.7 threshold.
    #[tokio::test]
    async fn test_data_type_integer() {
        let ctx = create_string_test_context(vec![
            Some("123"),
            Some("456"),
            Some("not_number"),
            Some("789"),
        ])
        .await;
        let constraint = DataTypeConstraint::new("text_col", DataType::Integer, 0.7);

        let result = constraint.evaluate(&ctx).await.unwrap();

        assert_eq!(result.status, ConstraintStatus::Success);
        assert_eq!(result.metric, Some(0.75));
    }

    /// Three of four strings are floats, which clears a 0.7 threshold.
    #[tokio::test]
    async fn test_data_type_float() {
        let ctx = create_string_test_context(vec![
            Some("123.45"),
            Some("67.89"),
            Some("invalid"),
            Some("100"),
        ])
        .await;
        let constraint = DataTypeConstraint::new("text_col", DataType::Float, 0.7);

        let result = constraint.evaluate(&ctx).await.unwrap();

        assert_eq!(result.status, ConstraintStatus::Success);
        assert_eq!(result.metric, Some(0.75));
    }

    /// Three of four strings are booleans, which clears a 0.7 threshold.
    #[tokio::test]
    async fn test_data_type_boolean() {
        let ctx = create_string_test_context(vec![
            Some("true"),
            Some("false"),
            Some("invalid"),
            Some("1"),
        ])
        .await;
        let constraint = DataTypeConstraint::new("text_col", DataType::Boolean, 0.7);

        let result = constraint.evaluate(&ctx).await.unwrap();

        assert_eq!(result.status, ConstraintStatus::Success);
        assert_eq!(result.metric, Some(0.75));
    }

    /// One value outside the allowed set makes the constraint fail.
    #[tokio::test]
    async fn test_containment_constraint() {
        let ctx = create_string_test_context(vec![
            Some("active"),
            Some("inactive"),
            Some("pending"),
            Some("invalid_status"),
        ])
        .await;
        let allowed = vec!["active", "inactive", "pending", "archived"];
        let constraint = ContainmentConstraint::new("text_col", allowed);

        let result = constraint.evaluate(&ctx).await.unwrap();

        assert_eq!(result.status, ConstraintStatus::Failure);
        assert_eq!(result.metric, Some(0.75));
    }

    /// Every value is in the allowed set.
    #[tokio::test]
    async fn test_containment_all_valid() {
        let ctx = create_string_test_context(vec![
            Some("active"),
            Some("inactive"),
            Some("pending"),
        ])
        .await;
        let allowed = vec!["active", "inactive", "pending", "archived"];
        let constraint = ContainmentConstraint::new("text_col", allowed);

        let result = constraint.evaluate(&ctx).await.unwrap();

        assert_eq!(result.status, ConstraintStatus::Success);
        assert_eq!(result.metric, Some(1.0));
    }

    /// Zero counts as non-negative.
    #[tokio::test]
    async fn test_non_negative_constraint() {
        let ctx =
            create_numeric_test_context(vec![Some(1.0), Some(0.0), Some(5.5), Some(100.0)]).await;
        let constraint = NonNegativeConstraint::new("num_col");

        let result = constraint.evaluate(&ctx).await.unwrap();

        assert_eq!(result.status, ConstraintStatus::Success);
        assert_eq!(result.metric, Some(1.0));
    }

    /// A single negative value makes the constraint fail.
    #[tokio::test]
    async fn test_non_negative_with_negative() {
        let ctx =
            create_numeric_test_context(vec![Some(1.0), Some(-2.0), Some(5.5), Some(100.0)]).await;
        let constraint = NonNegativeConstraint::new("num_col");

        let result = constraint.evaluate(&ctx).await.unwrap();

        assert_eq!(result.status, ConstraintStatus::Failure);
        assert_eq!(result.metric, Some(0.75));
    }

    /// Null rows are excluded before the ratio is computed.
    #[tokio::test]
    async fn test_with_nulls() {
        let ctx =
            create_string_test_context(vec![Some("active"), None, Some("inactive"), None]).await;
        let constraint = ContainmentConstraint::new("text_col", vec!["active", "inactive"]);

        let result = constraint.evaluate(&ctx).await.unwrap();

        assert_eq!(result.status, ConstraintStatus::Success);
        assert_eq!(result.metric, Some(1.0));
    }

    /// Thresholds above 1.0 are rejected at construction time.
    #[test]
    #[should_panic(expected = "Threshold must be between 0.0 and 1.0")]
    fn test_invalid_threshold() {
        let _constraint = DataTypeConstraint::new("col", DataType::Integer, 1.5);
    }
}