datafusion_table_providers/util/constraints.rs

use datafusion::arrow::{array::RecordBatch, datatypes::SchemaRef};
use datafusion::{
    common::{Constraint, Constraints},
    execution::context::SessionContext,
    functions_aggregate::count::count,
    logical_expr::{col, lit, utils::COUNT_STAR_EXPANSION},
};
use futures::future;
use snafu::prelude::*;

#[derive(Debug, Snafu)]
pub enum Error {
    #[snafu(display("Incoming data violates uniqueness constraint on column(s): {}", unique_cols.join(", ")))]
    BatchViolatesUniquenessConstraint { unique_cols: Vec<String> },

    #[snafu(display("{source}"))]
    DataFusion {
        source: datafusion::error::DataFusionError,
    },
}

pub type Result<T, E = Error> = std::result::Result<T, E>;

/// Determines whether all of the data in `batches` conforms to the constraints described in `constraints`.
///
/// It does this by creating a memory table from the record batches and then running a query against the table to validate the constraints.
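///
/// A minimal usage sketch, run inside an async context (the batch and constraint below are
/// illustrative, not taken from this crate):
///
/// ```ignore
/// use std::sync::Arc;
/// use datafusion::arrow::array::{ArrayRef, Int32Array, RecordBatch};
/// use datafusion::arrow::datatypes::{DataType, Field, Schema};
/// use datafusion::common::{Constraint, Constraints};
///
/// let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
/// let ids: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 2]));
/// let batch = RecordBatch::try_new(schema, vec![ids]).expect("valid batch");
///
/// // Primary key on column index 0 (`id`); the duplicate value 2 makes validation fail.
/// let constraints = Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0])]);
/// assert!(validate_batch_with_constraints(&[batch], &constraints).await.is_err());
/// ```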
pub async fn validate_batch_with_constraints(
    batches: &[RecordBatch],
    constraints: &Constraints,
) -> Result<()> {
    if batches.is_empty() || constraints.is_empty() {
        return Ok(());
    }

    let mut futures = Vec::new();
    for constraint in &**constraints {
        let fut = validate_batch_with_constraint(batches.to_vec(), constraint.clone());
        futures.push(fut);
    }

    future::try_join_all(futures).await?;

    Ok(())
}

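/// Validates a single `constraint` against `batches` by registering them as an
/// in-memory table and counting groups of the constrained columns that occur
/// more than once.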
#[tracing::instrument(level = "debug", skip(batches))]
async fn validate_batch_with_constraint(
    batches: Vec<RecordBatch>,
    constraint: Constraint,
) -> Result<()> {
    let unique_cols = match constraint {
        Constraint::PrimaryKey(cols) | Constraint::Unique(cols) => cols,
    };

    // `batches` is guaranteed non-empty by the caller, so indexing the first batch is safe.
    let schema = batches[0].schema();
    let unique_fields = unique_cols
        .iter()
        .map(|col| schema.field(*col))
        .collect::<Vec<_>>();

    let ctx = SessionContext::new();
    let df = ctx.read_batches(batches).context(DataFusionSnafu)?;

    // Resolve the output column name of the COUNT aggregate so the filter below can reference it.
    let count_name = count(lit(COUNT_STAR_EXPANSION)).schema_name().to_string();

    // This is equivalent to:
    // ```sql
    // SELECT COUNT(1), <unique_field_names> FROM mem_table GROUP BY <unique_field_names> HAVING COUNT(1) > 1
    // ```
    let num_rows = df
        .aggregate(
            unique_fields.iter().map(|f| col(f.name())).collect(),
            vec![count(lit(COUNT_STAR_EXPANSION))],
        )
        .context(DataFusionSnafu)?
        .filter(col(count_name).gt(lit(1)))
        .context(DataFusionSnafu)?
        .count()
        .await
        .context(DataFusionSnafu)?;

    if num_rows > 0 {
        BatchViolatesUniquenessConstraintSnafu {
            unique_cols: unique_fields
                .iter()
                .map(|col| col.name().to_string())
                .collect::<Vec<_>>(),
        }
        .fail()?;
    }

    Ok(())
}

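/// Returns the names of all columns that participate in a `Constraint::PrimaryKey`,
/// resolved against `schema`, in the order the constraints list them.
///
/// A sketch of intended usage (the schema and constraint below are illustrative,
/// not taken from this crate; `Schema`, `Field`, and `DataType` come from
/// `datafusion::arrow::datatypes`):
///
/// ```ignore
/// let schema: SchemaRef = Arc::new(Schema::new(vec![
///     Field::new("id", DataType::Int64, false),
///     Field::new("value", DataType::Utf8, true),
/// ]));
/// let constraints = Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0])]);
/// assert_eq!(get_primary_keys_from_constraints(&constraints, &schema), vec!["id"]);
/// ```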
#[must_use]
pub fn get_primary_keys_from_constraints(
    constraints: &Constraints,
    schema: &SchemaRef,
) -> Vec<String> {
    let mut primary_keys: Vec<String> = Vec::new();
    for constraint in constraints.clone() {
        if let Constraint::PrimaryKey(cols) = constraint {
            cols.iter()
                .map(|col| schema.field(*col).name())
                .for_each(|col| {
                    primary_keys.push(col.to_string());
                });
        }
    }
    primary_keys
}

#[cfg(test)]
pub(crate) mod tests {
    use std::sync::Arc;

    use datafusion::arrow::datatypes::SchemaRef;
    use datafusion::{
        common::{Constraint, Constraints},
        parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder,
    };

    #[tokio::test]
    async fn test_validate_batch_with_constraints() -> Result<(), Box<dyn std::error::Error>> {
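        // Fetch a public Ethereum logs dataset. The assertions below rely on
        // (`log_index`, `transaction_hash`) being unique per row while
        // `block_number` repeats across rows.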
        let parquet_bytes = reqwest::get("https://public-data.spiceai.org/eth.recent_logs.parquet")
            .await?
            .bytes()
            .await?;

        let parquet_reader = ParquetRecordBatchReaderBuilder::try_new(parquet_bytes)?.build()?;

        let records =
            parquet_reader.collect::<Result<Vec<_>, datafusion::arrow::error::ArrowError>>()?;
        let schema = records[0].schema();

        let constraints =
            get_unique_constraints(&["log_index", "transaction_hash"], Arc::clone(&schema));

        let result = super::validate_batch_with_constraints(&records, &constraints).await;
        assert!(
            result.is_ok(),
            "{}",
            result.expect_err("this returned an error")
        );

        let invalid_constraints = get_unique_constraints(&["block_number"], Arc::clone(&schema));
        let result = super::validate_batch_with_constraints(&records, &invalid_constraints).await;
        assert!(result.is_err());
        assert_eq!(
            result.expect_err("this returned an error").to_string(),
            "Incoming data violates uniqueness constraint on column(s): block_number"
        );

        let invalid_constraints =
            get_unique_constraints(&["block_number", "transaction_hash"], Arc::clone(&schema));
        let result = super::validate_batch_with_constraints(&records, &invalid_constraints).await;
        assert!(result.is_err());
        assert_eq!(
            result.expect_err("this returned an error").to_string(),
            "Incoming data violates uniqueness constraint on column(s): block_number, transaction_hash"
        );

        Ok(())
    }

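    /// Builds a `Constraints` holding a single `Unique` constraint over the named
    /// columns, panicking if any column is missing from `schema`.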
    pub(crate) fn get_unique_constraints(cols: &[&str], schema: SchemaRef) -> Constraints {
        let indices = cols
            .iter()
            .map(|col| {
                schema
                    .index_of(col)
                    .unwrap_or_else(|_| panic!("[{col}] not found, validated schema: [{}]", schema))
            })
            .collect();

        Constraints::new_unverified(vec![Constraint::Unique(indices)])
    }

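    /// Same as `get_unique_constraints`, but wraps the column indices in a
    /// `PrimaryKey` constraint instead.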
    pub(crate) fn get_pk_constraints(cols: &[&str], schema: SchemaRef) -> Constraints {
        let indices = cols
            .iter()
            .map(|col| {
                schema
                    .index_of(col)
                    .unwrap_or_else(|_| panic!("[{col}] not found, validated schema: [{}]", schema))
            })
            .collect();

        Constraints::new_unverified(vec![Constraint::PrimaryKey(indices)])
    }
}