use async_trait::async_trait;
use datafusion::prelude::*;
use serde::{Deserialize, Serialize};
use tracing::instrument;
use crate::analyzers::{Analyzer, AnalyzerError, AnalyzerResult, AnalyzerState, MetricValue};
use crate::core::current_validation_context;
/// Analyzer that measures how many rows of the current validation table satisfy
/// a user-supplied SQL boolean expression.
#[derive(Clone, Debug)]
pub struct ComplianceAnalyzer {
    /// Identifier for this particular check (exposed via `check_name`).
    name: String,
    /// Raw SQL boolean expression evaluated against every row.
    predicate: String,
}
impl ComplianceAnalyzer {
    /// Creates an analyzer reporting the fraction of rows that satisfy `predicate`.
    ///
    /// `name` identifies this check. `predicate` is a SQL boolean expression that is
    /// spliced verbatim into the generated query, so it is screened against a
    /// keyword denylist before the query is built.
    pub fn new(name: impl Into<String>, predicate: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            predicate: predicate.into(),
        }
    }

    /// Returns the name this check was constructed with.
    pub fn check_name(&self) -> &str {
        &self.name
    }

    /// Returns the raw SQL predicate this analyzer evaluates.
    pub fn predicate(&self) -> &str {
        &self.predicate
    }

    /// Rejects predicates containing statement keywords or SQL comment markers.
    ///
    /// This is a coarse textual denylist, not a SQL parser. Keywords inside string
    /// literals or quoted identifiers are rejected too.
    ///
    /// Alphabetic keywords only count when they are not embedded in a longer
    /// identifier. For example, `updated_at` or `is_deleted` are accepted, while
    /// `update`, `(select` or `1union` are rejected. Comment markers are rejected
    /// wherever they appear.
    ///
    /// # Errors
    /// Returns an invalid-config error naming the first forbidden keyword found,
    /// checked in list order.
    fn validate_predicate(&self) -> AnalyzerResult<()> {
        const DANGEROUS_KEYWORDS: [&str; 15] = [
            "drop", "delete", "insert", "update", "create", "alter", "grant", "revoke", "exec",
            "execute", "union", "select", "--", "/*", "*/",
        ];

        // Fold through uppercase first: characters such as `ſ` (U+017F) and `ı`
        // (U+0131) uppercase to ASCII `S`/`I` but are left untouched by lowercasing.
        // A parser that matches keywords case-insensitively via uppercasing would
        // read `ſelect` as SELECT, so plain `to_lowercase` would let it through.
        let normalized = self.predicate.to_uppercase().to_lowercase();

        match DANGEROUS_KEYWORDS
            .iter()
            .find(|keyword| Self::contains_forbidden(&normalized, keyword))
        {
            Some(keyword) => Err(AnalyzerError::invalid_config(format!(
                "Predicate contains forbidden keyword: {keyword}"
            ))),
            None => Ok(()),
        }
    }

    /// Reports whether `keyword` occurs in `haystack` in a way that could form a
    /// token of its own.
    ///
    /// Non-alphabetic markers (`--`, `/*`, `*/`) match anywhere. An alphabetic
    /// keyword is ignored only when a letter or `_` sits directly before or after
    /// it, i.e. when it is part of a longer identifier. Digits and punctuation still
    /// count as boundaries, which errs on the side of rejecting.
    fn contains_forbidden(haystack: &str, keyword: &str) -> bool {
        if !keyword.chars().all(|c| c.is_ascii_alphabetic()) {
            return haystack.contains(keyword);
        }
        let is_word_char = |c: Option<char>| matches!(c, Some(ch) if ch.is_alphabetic() || ch == '_');
        haystack.match_indices(keyword).any(|(start, _)| {
            // `start` and `start + keyword.len()` are char boundaries because they
            // delimit a substring match.
            let before = haystack[..start].chars().next_back();
            let after = haystack[start + keyword.len()..].chars().next();
            !is_word_char(before) && !is_word_char(after)
        })
    }
}
/// Partial aggregation state for [`ComplianceAnalyzer`]: how many rows matched
/// the predicate, out of how many rows were examined.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ComplianceState {
    /// Rows for which the predicate held.
    pub compliant_count: u64,
    /// All rows examined, compliant or not.
    pub total_count: u64,
}
impl ComplianceState {
    /// Returns `compliant_count / total_count` as a float.
    ///
    /// A state that has seen no rows is treated as fully compliant and yields `1.0`
    /// rather than dividing by zero.
    pub fn compliance_fraction(&self) -> f64 {
        match self.total_count {
            0 => 1.0,
            total => self.compliant_count as f64 / total as f64,
        }
    }
}
impl AnalyzerState for ComplianceState {
    /// Combines partial states by adding up both counters; an empty input yields
    /// a zeroed state.
    fn merge(states: Vec<Self>) -> AnalyzerResult<Self> {
        let (compliant_count, total_count) =
            states
                .iter()
                .fold((0u64, 0u64), |(compliant, total), state| {
                    (compliant + state.compliant_count, total + state.total_count)
                });
        Ok(Self {
            compliant_count,
            total_count,
        })
    }

    /// A state is empty when it has not observed any rows.
    fn is_empty(&self) -> bool {
        self.total_count == 0
    }
}
#[async_trait]
impl Analyzer for ComplianceAnalyzer {
    type State = ComplianceState;
    type Metric = MetricValue;

    /// Runs one aggregate query over the current validation table. The query counts
    /// the rows for which the predicate holds, alongside the total row count.
    ///
    /// # Errors
    /// - invalid-config when the predicate is rejected by `validate_predicate` or
    ///   DataFusion cannot plan the generated SQL;
    /// - invalid-data when either count column is not an `Int64Array`;
    /// - any error raised while collecting the query results.
    #[instrument(skip(ctx), fields(analyzer = "compliance", name = %self.name))]
    async fn compute_state_from_data(&self, ctx: &SessionContext) -> AnalyzerResult<Self::State> {
        self.validate_predicate()?;

        // Keep the context bound: `table_name()` may borrow from it.
        let validation = current_validation_context();
        let table_name = validation.table_name();

        // CASE WHEN without ELSE yields NULL for non-matching rows, and COUNT skips
        // NULLs, so the first column counts only rows where the predicate is true.
        let query = format!(
            concat!(
                "SELECT\n",
                "COUNT(CASE WHEN ({predicate}) THEN 1 END) as compliant_count,\n",
                "COUNT(*) as total_count\n",
                "FROM {table}",
            ),
            predicate = self.predicate,
            table = table_name,
        );

        let frame = ctx.sql(&query).await.map_err(|err| {
            AnalyzerError::invalid_config(format!("Invalid predicate '{}': {err}", self.predicate))
        })?;
        let batches = frame.collect().await?;

        let (compliant_count, total_count) = match batches.first() {
            Some(batch) if batch.num_rows() > 0 => {
                // The query has no GROUP BY, so each count sits in row 0 of its column.
                let read_count = |column: usize, mismatch: &'static str| -> AnalyzerResult<u64> {
                    let counts = batch
                        .column(column)
                        .as_any()
                        .downcast_ref::<arrow::array::Int64Array>()
                        .ok_or_else(|| AnalyzerError::invalid_data(mismatch))?;
                    Ok(counts.value(0) as u64)
                };
                (
                    read_count(0, "Expected Int64 array for compliant count")?,
                    read_count(1, "Expected Int64 array for total count")?,
                )
            }
            _ => (0, 0),
        };

        Ok(ComplianceState {
            compliant_count,
            total_count,
        })
    }

    /// Reports the compliance fraction of the accumulated state as a double.
    fn compute_metric_from_state(&self, state: &Self::State) -> AnalyzerResult<Self::Metric> {
        let fraction = state.compliance_fraction();
        Ok(MetricValue::Double(fraction))
    }

    /// Analyzer type identifier (distinct from the per-check `check_name`).
    fn name(&self) -> &str {
        "compliance"
    }

    /// Human-readable summary of what this analyzer measures.
    fn description(&self) -> &str {
        "Evaluates compliance with custom SQL expressions"
    }
}