datafusion_table_providers/util/constraints.rs

use datafusion::arrow::{array::RecordBatch, datatypes::SchemaRef};
use datafusion::{
    common::{Constraint, Constraints},
    execution::context::SessionContext,
    functions_aggregate::count::count,
    logical_expr::{col, lit, utils::COUNT_STAR_EXPANSION},
    prelude::ident,
};
use futures::future;
use snafu::prelude::*;

#[derive(Debug, Snafu)]
pub enum Error {
    #[snafu(display("Incoming data violates uniqueness constraint on column(s): {}", unique_cols.join(", ")))]
    BatchViolatesUniquenessConstraint { unique_cols: Vec<String> },

    #[snafu(display("{source}"))]
    DataFusion {
        source: datafusion::error::DataFusionError,
    },
}

pub type Result<T, E = Error> = std::result::Result<T, E>;

/// Determines whether all of the data in `batches` conforms to the constraints in `constraints`.
///
/// It does this by creating an in-memory table from the record batches and running an aggregation query against that table to validate each constraint.
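///
/// A minimal usage sketch (not a doctest; assumes a Tokio runtime and an existing
/// `RecordBatch` named `batch` — the unique constraint on column index 0 is illustrative only):
///
/// ```ignore
/// use datafusion::common::{Constraint, Constraints};
///
/// let constraints = Constraints::new_unverified(vec![Constraint::Unique(vec![0])]);
/// validate_batch_with_constraints(&[batch], &constraints).await?;
/// ```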
pub async fn validate_batch_with_constraints(
    batches: &[RecordBatch],
    constraints: &Constraints,
) -> Result<()> {
    if batches.is_empty() || constraints.is_empty() {
        return Ok(());
    }

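    // Each constraint is validated independently; run the checks concurrently and
    // propagate the first error encountered.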
    let mut futures = Vec::new();
    for constraint in &**constraints {
        let fut = validate_batch_with_constraint(batches.to_vec(), constraint.clone());
        futures.push(fut);
    }

    future::try_join_all(futures).await?;

    Ok(())
}

#[tracing::instrument(level = "debug", skip(batches))]
async fn validate_batch_with_constraint(
    batches: Vec<RecordBatch>,
    constraint: Constraint,
) -> Result<()> {
    let unique_cols = match constraint {
        Constraint::PrimaryKey(cols) | Constraint::Unique(cols) => cols,
    };

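    // Constraint columns are stored as indices into the schema; resolve them to their fields.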
    let schema = batches[0].schema();
    let unique_fields = unique_cols
        .iter()
        .map(|col| schema.field(*col))
        .collect::<Vec<_>>();

    let ctx = SessionContext::new();
    let df = ctx.read_batches(batches).context(DataFusionSnafu)?;

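    // Capture the output column name of `COUNT(1)` so the filter below can reference the
    // aggregate result by name.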
    let count_name = count(lit(COUNT_STAR_EXPANSION)).schema_name().to_string();

    // This is equivalent to:
    // ```sql
    // SELECT COUNT(1), <unique_field_names> FROM mem_table GROUP BY <unique_field_names> HAVING COUNT(1) > 1
    // ```
    let num_rows = df
        .aggregate(
            unique_fields.iter().map(|f| ident(f.name())).collect(),
            vec![count(lit(COUNT_STAR_EXPANSION))],
        )
        .context(DataFusionSnafu)?
        .filter(col(count_name).gt(lit(1)))
        .context(DataFusionSnafu)?
        .count()
        .await
        .context(DataFusionSnafu)?;

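    // Any group that survives the filter contains duplicate key values, so the incoming
    // data violates the uniqueness constraint.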
    if num_rows > 0 {
        BatchViolatesUniquenessConstraintSnafu {
            unique_cols: unique_fields
                .iter()
                .map(|col| col.name().to_string())
                .collect::<Vec<_>>(),
        }
        .fail()?;
    }

    Ok(())
}

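/// Returns the names of the primary key columns declared in `constraints`, resolved against
/// `schema`. If no primary key constraint is present, an empty `Vec` is returned.
///
/// A minimal usage sketch (not a doctest; assumes `constraints` and `schema` come from an
/// existing table definition):
///
/// ```ignore
/// let pk_cols = get_primary_keys_from_constraints(&constraints, &schema);
/// ```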
#[must_use]
pub fn get_primary_keys_from_constraints(
    constraints: &Constraints,
    schema: &SchemaRef,
) -> Vec<String> {
    let mut primary_keys: Vec<String> = Vec::new();
    for constraint in constraints.clone() {
        if let Constraint::PrimaryKey(cols) = constraint {
            cols.iter()
                .map(|col| schema.field(*col).name())
                .for_each(|col| {
                    primary_keys.push(col.to_string());
                });
        }
    }
    primary_keys
}

#[cfg(test)]
pub(crate) mod tests {
    use std::sync::Arc;

    use datafusion::arrow::datatypes::SchemaRef;
    use datafusion::{
        common::{Constraint, Constraints},
        parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder,
    };

    #[tokio::test]
    async fn test_validate_batch_with_constraints() -> Result<(), Box<dyn std::error::Error>> {
        let parquet_bytes = reqwest::get("https://public-data.spiceai.org/taxi_sample.parquet")
            .await?
            .bytes()
            .await?;

        let parquet_reader = ParquetRecordBatchReaderBuilder::try_new(parquet_bytes)?.build()?;

        let records =
            parquet_reader.collect::<Result<Vec<_>, datafusion::arrow::error::ArrowError>>()?;
        let schema = records[0].schema();

        let constraints = get_unique_constraints(
            &["VendorID", "tpep_pickup_datetime", "PULocationID"],
            Arc::clone(&schema),
        );

        let result = super::validate_batch_with_constraints(&records, &constraints).await;
        assert!(
            result.is_ok(),
            "{}",
            result.expect_err("this returned an error")
        );

        let invalid_constraints = get_unique_constraints(&["VendorID"], Arc::clone(&schema));
        let result = super::validate_batch_with_constraints(&records, &invalid_constraints).await;
        assert!(result.is_err());
        assert_eq!(
            result.expect_err("this returned an error").to_string(),
            "Incoming data violates uniqueness constraint on column(s): VendorID"
        );

        let invalid_constraints =
            get_unique_constraints(&["VendorID", "tpep_pickup_datetime"], Arc::clone(&schema));
        let result = super::validate_batch_with_constraints(&records, &invalid_constraints).await;
        assert!(result.is_err());
        assert_eq!(
            result.expect_err("this returned an error").to_string(),
            "Incoming data violates uniqueness constraint on column(s): VendorID, tpep_pickup_datetime"
        );

        Ok(())
    }

    pub(crate) fn get_unique_constraints(cols: &[&str], schema: SchemaRef) -> Constraints {
        // Convert column names to their corresponding indices
        let indices: Vec<usize> = cols
            .iter()
            .filter_map(|&col_name| schema.column_with_name(col_name).map(|(index, _)| index))
            .collect();

        Constraints::new_unverified(vec![Constraint::Unique(indices)])
    }

    pub(crate) fn get_pk_constraints(cols: &[&str], schema: SchemaRef) -> Constraints {
        // Convert column names to their corresponding indices
        let indices: Vec<usize> = cols
            .iter()
            .filter_map(|&col_name| schema.column_with_name(col_name).map(|(index, _)| index))
            .collect();

        Constraints::new_unverified(vec![Constraint::PrimaryKey(indices)])
    }
}