use crate::constraints::Assertion;
use crate::core::{Constraint, ConstraintMetadata, ConstraintResult, ConstraintStatus};
use crate::prelude::*;
use async_trait::async_trait;
use datafusion::prelude::*;
use tracing::instrument;
/// Constraint that checks a dataset's number of columns against an [`Assertion`].
#[derive(Debug, Clone)]
pub struct ColumnCountConstraint {
    /// Assertion the observed column count is evaluated with.
    assertion: Assertion,
}
impl ColumnCountConstraint {
pub fn new(assertion: Assertion) -> Self {
Self { assertion }
}
}
#[async_trait]
impl Constraint for ColumnCountConstraint {
    /// Counts the fields in the schema of the table registered as `data` and
    /// evaluates the configured assertion against that count.
    ///
    /// The count is always reported as the result metric; a message is
    /// attached only when the assertion is not satisfied.
    ///
    /// # Errors
    ///
    /// Returns a constraint-evaluation error when the `data` table cannot be
    /// accessed through `ctx`.
    #[instrument(skip(self, ctx), fields(assertion = ?self.assertion))]
    async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
        let table = ctx.table("data").await.map_err(|e| {
            TermError::constraint_evaluation(
                self.name(),
                format!("Failed to access table 'data': {e}"),
            )
        })?;

        // Only the schema is inspected; no rows are read to obtain the count.
        let column_count = table.schema().fields().len() as f64;

        // Derive status and message together so they cannot disagree.
        let (status, message) = if self.assertion.evaluate(column_count) {
            (ConstraintStatus::Success, None)
        } else {
            let explanation = format!(
                "Column count {column_count} does not satisfy assertion {}",
                self.assertion.description()
            );
            (ConstraintStatus::Failure, Some(explanation))
        };

        Ok(ConstraintResult {
            status,
            metric: Some(column_count),
            message,
        })
    }

    /// Identifier of this constraint, also used when building evaluation errors.
    fn name(&self) -> &str {
        "column_count"
    }

    /// Always `None`: the check concerns the whole schema, not a single column.
    fn column(&self) -> Option<&str> {
        None
    }

    /// Describes the constraint, recording the assertion text and the
    /// `schema` constraint type as custom metadata entries.
    fn metadata(&self) -> ConstraintMetadata {
        let summary = format!(
            "Checks that the dataset has {} columns",
            self.assertion.description()
        );

        ConstraintMetadata::default()
            .with_description(summary)
            .with_custom("assertion", self.assertion.description())
            .with_custom("constraint_type", "schema")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use crate::test_helpers::evaluate_constraint_with_context;
    use arrow::array::Int64Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    /// Builds a session whose `data` table has `num_columns` nullable Int64
    /// columns named `col_0`, `col_1`, ... each holding the rows 1, 2, 3.
    async fn create_test_context_with_columns(num_columns: usize) -> SessionContext {
        let mut fields: Vec<Field> = Vec::with_capacity(num_columns);
        let mut columns: Vec<Arc<dyn arrow::array::Array>> = Vec::with_capacity(num_columns);
        for i in 0..num_columns {
            fields.push(Field::new(format!("col_{i}"), DataType::Int64, true));
            let values = Int64Array::from(vec![Some(1), Some(2), Some(3)]);
            columns.push(Arc::new(values) as Arc<dyn arrow::array::Array>);
        }

        let schema = Arc::new(Schema::new(fields));
        let batch = RecordBatch::try_new(Arc::clone(&schema), columns).unwrap();
        let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();

        let ctx = SessionContext::new();
        ctx.register_table("data", Arc::new(table)).unwrap();
        ctx
    }

    /// Evaluates a column-count constraint built from `assertion` against the
    /// `data` table of `ctx` and asserts that it ends with status `expected`.
    async fn assert_status(ctx: &SessionContext, assertion: Assertion, expected: ConstraintStatus) {
        let constraint = ColumnCountConstraint::new(assertion);
        let outcome = evaluate_constraint_with_context(&constraint, ctx, "data")
            .await
            .unwrap();
        assert_eq!(outcome.status, expected);
    }

    #[tokio::test]
    async fn test_column_count_equals() {
        let ctx = create_test_context_with_columns(5).await;

        // Matching count: success, and the metric carries the observed count.
        let matching = ColumnCountConstraint::new(Assertion::Equals(5.0));
        let outcome = evaluate_constraint_with_context(&matching, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(5.0));

        // Mismatching count: failure, same metric, and an explanatory message.
        let mismatching = ColumnCountConstraint::new(Assertion::Equals(10.0));
        let outcome = evaluate_constraint_with_context(&mismatching, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert_eq!(outcome.metric, Some(5.0));
        assert!(outcome.message.is_some());
    }

    #[tokio::test]
    async fn test_column_count_greater_than() {
        let ctx = create_test_context_with_columns(8).await;
        assert_status(&ctx, Assertion::GreaterThan(5.0), ConstraintStatus::Success).await;
        assert_status(&ctx, Assertion::GreaterThan(10.0), ConstraintStatus::Failure).await;
    }

    #[tokio::test]
    async fn test_column_count_less_than() {
        let ctx = create_test_context_with_columns(3).await;
        assert_status(&ctx, Assertion::LessThan(5.0), ConstraintStatus::Success).await;
        assert_status(&ctx, Assertion::LessThan(2.0), ConstraintStatus::Failure).await;
    }

    #[tokio::test]
    async fn test_column_count_between() {
        let ctx = create_test_context_with_columns(7).await;
        assert_status(&ctx, Assertion::Between(5.0, 10.0), ConstraintStatus::Success).await;
        assert_status(&ctx, Assertion::Between(10.0, 15.0), ConstraintStatus::Failure).await;
    }

    #[tokio::test]
    async fn test_single_column_dataset() {
        let ctx = create_test_context_with_columns(1).await;
        let constraint = ColumnCountConstraint::new(Assertion::Equals(1.0));
        let outcome = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(1.0));
    }

    #[tokio::test]
    async fn test_large_column_count() {
        let ctx = create_test_context_with_columns(100).await;
        let constraint = ColumnCountConstraint::new(Assertion::GreaterThanOrEqual(100.0));
        let outcome = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(100.0));
    }

    #[tokio::test]
    async fn test_missing_table() {
        // No table is registered, so evaluation must report an error.
        let empty_ctx = SessionContext::new();
        let constraint = ColumnCountConstraint::new(Assertion::Equals(5.0));
        let outcome = evaluate_constraint_with_context(&constraint, &empty_ctx, "data").await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_metadata() {
        let metadata = ColumnCountConstraint::new(Assertion::Between(10.0, 20.0)).metadata();

        let description = metadata.description.unwrap_or_default();
        assert!(description.contains("between 10 and 20"));

        assert_eq!(
            metadata.custom.get("constraint_type"),
            Some(&"schema".to_string())
        );
    }
}