datafusion_table_providers/util/constraints.rs

use datafusion::arrow::{array::RecordBatch, datatypes::SchemaRef};
use datafusion::{
    common::{Constraint, Constraints},
    execution::context::SessionContext,
    functions_aggregate::count::count,
    logical_expr::{col, lit, utils::COUNT_STAR_EXPANSION},
    prelude::ident,
};
use futures::future;
use snafu::prelude::*;

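/// Errors produced while validating record batches against table constraints.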
#[derive(Debug, Snafu)]
pub enum Error {
    #[snafu(display("Incoming data violates uniqueness constraint on column(s): {}", unique_cols.join(", ")))]
    BatchViolatesUniquenessConstraint { unique_cols: Vec<String> },

    #[snafu(display("{source}"))]
    DataFusion {
        source: datafusion::error::DataFusionError,
    },
}

pub type Result<T, E = Error> = std::result::Result<T, E>;

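/// Validates that every constraint in `constraints` holds across `batches`,
/// returning `Error::BatchViolatesUniquenessConstraint` on the first violation.
///
/// A hypothetical usage sketch (the batches and the constrained column index
/// are illustrative, not defined in this module):
///
/// ```rust,ignore
/// // Suppose column 0 of `batches` is meant to be a unique key.
/// let constraints = Constraints::new_unverified(vec![Constraint::Unique(vec![0])]);
/// validate_batch_with_constraints(&batches, &constraints).await?;
/// ```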
pub async fn validate_batch_with_constraints(
    batches: &[RecordBatch],
    constraints: &Constraints,
) -> Result<()> {
    // Nothing to validate: no data or no constraints.
    if batches.is_empty() || constraints.is_empty() {
        return Ok(());
    }

    // Check every constraint concurrently; `try_join_all` surfaces the first
    // failure and drops the remaining checks.
    let mut futures = Vec::new();
    for constraint in &**constraints {
        let fut = validate_batch_with_constraint(batches.to_vec(), constraint.clone());
        futures.push(fut);
    }

    future::try_join_all(futures).await?;

    Ok(())
}

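/// Validates a single `PrimaryKey` or `Unique` constraint (both reduce to a
/// uniqueness check) by grouping the batches on the constrained columns and
/// counting each group: logically `SELECT cols, COUNT(*) ... GROUP BY cols
/// HAVING COUNT(*) > 1`. Any group that survives the filter holds duplicate
/// key values, so the constraint is violated.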
#[tracing::instrument(level = "debug", skip(batches))]
async fn validate_batch_with_constraint(
    batches: Vec<RecordBatch>,
    constraint: Constraint,
) -> Result<()> {
    let unique_cols = match constraint {
        Constraint::PrimaryKey(cols) | Constraint::Unique(cols) => cols,
    };

    // `batches` is never empty here: the public entry point returns early on
    // empty input. Resolve the constrained column indices to schema fields.
    let schema = batches[0].schema();
    let unique_fields = unique_cols
        .iter()
        .map(|col| schema.field(*col))
        .collect::<Vec<_>>();

    let ctx = SessionContext::new();
    let df = ctx.read_batches(batches).context(DataFusionSnafu)?;

    // Compute the name the COUNT aggregate will have in the output schema so
    // the filter below can reference that column.
    let count_name = count(lit(COUNT_STAR_EXPANSION)).schema_name().to_string();

    // Group by the constrained columns and keep only groups with more than
    // one row; each surviving row represents a duplicated key.
    let num_rows = df
        .aggregate(
            unique_fields.iter().map(|f| ident(f.name())).collect(),
            vec![count(lit(COUNT_STAR_EXPANSION))],
        )
        .context(DataFusionSnafu)?
        .filter(col(count_name).gt(lit(1)))
        .context(DataFusionSnafu)?
        .count()
        .await
        .context(DataFusionSnafu)?;

    if num_rows > 0 {
        // Duplicates found: report which columns were constrained.
        BatchViolatesUniquenessConstraintSnafu {
            unique_cols: unique_fields
                .iter()
                .map(|col| col.name().to_string())
                .collect::<Vec<_>>(),
        }
        .fail()?;
    }

    Ok(())
}

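/// Collects the names of all columns covered by `PrimaryKey` constraints,
/// resolving each stored column index against `schema`. `Unique` constraints
/// are deliberately ignored. Hypothetical sketch: with a primary key over
/// column indices `[0, 2]` of a schema `(id, name, ts)`, this returns
/// `["id", "ts"]`.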
#[must_use]
pub fn get_primary_keys_from_constraints(
    constraints: &Constraints,
    schema: &SchemaRef,
) -> Vec<String> {
    let mut primary_keys: Vec<String> = Vec::new();
    for constraint in constraints.clone() {
        if let Constraint::PrimaryKey(cols) = constraint {
            cols.iter()
                .map(|col| schema.field(*col).name())
                .for_each(|col| {
                    primary_keys.push(col.to_string());
                });
        }
    }
    primary_keys
}

#[cfg(test)]
pub(crate) mod tests {
    use std::sync::Arc;

    use datafusion::arrow::datatypes::SchemaRef;
    use datafusion::{
        common::{Constraint, Constraints},
        parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder,
    };

    #[tokio::test]
    async fn test_validate_batch_with_constraints() -> Result<(), Box<dyn std::error::Error>> {
        // Fetch a sample Parquet file and decode it into record batches.
        let parquet_bytes = reqwest::get("https://public-data.spiceai.org/taxi_sample.parquet")
            .await?
            .bytes()
            .await?;

        let parquet_reader = ParquetRecordBatchReaderBuilder::try_new(parquet_bytes)?.build()?;

        let records =
            parquet_reader.collect::<Result<Vec<_>, datafusion::arrow::error::ArrowError>>()?;
        let schema = records[0].schema();

        // This three-column key is unique in the sample data, so validation
        // should succeed.
        let constraints = get_unique_constraints(
            &["VendorID", "tpep_pickup_datetime", "PULocationID"],
            Arc::clone(&schema),
        );

        let result = super::validate_batch_with_constraints(&records, &constraints).await;
        assert!(
            result.is_ok(),
            "{}",
            result.expect_err("result was not an error")
        );

        // `VendorID` alone is not unique in the sample data, so validation
        // must fail and name the offending column.
        let invalid_constraints = get_unique_constraints(&["VendorID"], Arc::clone(&schema));
        let result = super::validate_batch_with_constraints(&records, &invalid_constraints).await;
        assert!(result.is_err());
        assert_eq!(
            result.expect_err("result was not an error").to_string(),
            "Incoming data violates uniqueness constraint on column(s): VendorID"
        );

        let invalid_constraints =
            get_unique_constraints(&["VendorID", "tpep_pickup_datetime"], Arc::clone(&schema));
        let result = super::validate_batch_with_constraints(&records, &invalid_constraints).await;
        assert!(result.is_err());
        assert_eq!(
            result.expect_err("result was not an error").to_string(),
            "Incoming data violates uniqueness constraint on column(s): VendorID, tpep_pickup_datetime"
        );

        Ok(())
    }

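    /// Builds an unverified `Unique` constraint over the named columns,
    /// silently skipping any name that is missing from `schema`.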
    pub(crate) fn get_unique_constraints(cols: &[&str], schema: SchemaRef) -> Constraints {
        let indices: Vec<usize> = cols
            .iter()
            .filter_map(|&col_name| schema.column_with_name(col_name).map(|(index, _)| index))
            .collect();

        Constraints::new_unverified(vec![Constraint::Unique(indices)])
    }

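    /// Like `get_unique_constraints`, but builds a `PrimaryKey` constraint
    /// instead of a `Unique` one.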
    pub(crate) fn get_pk_constraints(cols: &[&str], schema: SchemaRef) -> Constraints {
        let indices: Vec<usize> = cols
            .iter()
            .filter_map(|&col_name| schema.column_with_name(col_name).map(|(index, _)| index))
            .collect();

        Constraints::new_unverified(vec![Constraint::PrimaryKey(indices)])
    }
}