// datafusion_table_providers/util/constraints.rs
use datafusion::arrow::{array::RecordBatch, datatypes::SchemaRef};
use datafusion::{
    common::{Constraint, Constraints},
    execution::context::SessionContext,
    functions_aggregate::count::count,
    logical_expr::{col, lit, utils::COUNT_STAR_EXPANSION},
};
use futures::future;
use snafu::prelude::*;

/// Errors produced while validating record batches against table constraints.
#[derive(Debug, Snafu)]
pub enum Error {
    /// The incoming batches contain duplicate values in the column(s) covered
    /// by a primary-key or unique constraint.
    #[snafu(display("Incoming data violates uniqueness constraint on column(s): {}", unique_cols.join(", ")))]
    BatchViolatesUniquenessConstraint { unique_cols: Vec<String> },

    /// A DataFusion planning or execution error surfaced while running the
    /// validation query.
    #[snafu(display("{source}"))]
    DataFusion {
        source: datafusion::error::DataFusionError,
    },
}
21
22pub type Result<T, E = Error> = std::result::Result<T, E>;
23
24pub async fn validate_batch_with_constraints(
28 batches: &[RecordBatch],
29 constraints: &Constraints,
30) -> Result<()> {
31 if batches.is_empty() || constraints.is_empty() {
32 return Ok(());
33 }
34
35 let mut futures = Vec::new();
36 for constraint in &**constraints {
37 let fut = validate_batch_with_constraint(batches.to_vec(), constraint.clone());
38 futures.push(fut);
39 }
40
41 future::try_join_all(futures).await?;
42
43 Ok(())
44}
45
46#[tracing::instrument(level = "debug", skip(batches))]
47async fn validate_batch_with_constraint(
48 batches: Vec<RecordBatch>,
49 constraint: Constraint,
50) -> Result<()> {
51 let unique_cols = match constraint {
52 Constraint::PrimaryKey(cols) | Constraint::Unique(cols) => cols,
53 };
54
55 let schema = batches[0].schema();
56 let unique_fields = unique_cols
57 .iter()
58 .map(|col| schema.field(*col))
59 .collect::<Vec<_>>();
60
61 let ctx = SessionContext::new();
62 let df = ctx.read_batches(batches).context(DataFusionSnafu)?;
63
64 let count_name = count(lit(COUNT_STAR_EXPANSION)).schema_name().to_string();
65
66 let num_rows = df
71 .aggregate(
72 unique_fields.iter().map(|f| col(f.name())).collect(),
73 vec![count(lit(COUNT_STAR_EXPANSION))],
74 )
75 .context(DataFusionSnafu)?
76 .filter(col(count_name).gt(lit(1)))
77 .context(DataFusionSnafu)?
78 .count()
79 .await
80 .context(DataFusionSnafu)?;
81
82 if num_rows > 0 {
83 BatchViolatesUniquenessConstraintSnafu {
84 unique_cols: unique_fields
85 .iter()
86 .map(|col| col.name().to_string())
87 .collect::<Vec<_>>(),
88 }
89 .fail()?;
90 }
91
92 Ok(())
93}
94
95#[must_use]
96pub fn get_primary_keys_from_constraints(
97 constraints: &Constraints,
98 schema: &SchemaRef,
99) -> Vec<String> {
100 let mut primary_keys: Vec<String> = Vec::new();
101 for constraint in constraints.clone() {
102 if let Constraint::PrimaryKey(cols) = constraint {
103 cols.iter()
104 .map(|col| schema.field(*col).name())
105 .for_each(|col| {
106 primary_keys.push(col.to_string());
107 });
108 }
109 }
110 primary_keys
111}
112
#[cfg(test)]
pub(crate) mod tests {
    use std::sync::Arc;

    use datafusion::arrow::datatypes::SchemaRef;
    use datafusion::{
        common::{Constraint, Constraints},
        parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder,
    };

    /// End-to-end check of `validate_batch_with_constraints` against real
    /// data. NOTE(review): network-dependent — downloads a public parquet
    /// file, so it requires outbound connectivity to run.
    #[tokio::test]
    async fn test_validate_batch_with_constraints() -> Result<(), Box<dyn std::error::Error>> {
        let parquet_bytes = reqwest::get("https://public-data.spiceai.org/eth.recent_logs.parquet")
            .await?
            .bytes()
            .await?;

        let parquet_reader = ParquetRecordBatchReaderBuilder::try_new(parquet_bytes)?.build()?;

        let records =
            parquet_reader.collect::<Result<Vec<_>, datafusion::arrow::error::ArrowError>>()?;
        let schema = records[0].schema();

        // (log_index, transaction_hash) is genuinely unique in this dataset.
        let constraints =
            get_unique_constraints(&["log_index", "transaction_hash"], Arc::clone(&schema));

        let result = super::validate_batch_with_constraints(&records, &constraints).await;
        assert!(
            result.is_ok(),
            "{}",
            result.expect_err("this returned an error")
        );

        // block_number repeats across rows, so these must fail.
        let invalid_constraints = get_unique_constraints(&["block_number"], Arc::clone(&schema));
        let result = super::validate_batch_with_constraints(&records, &invalid_constraints).await;
        assert!(result.is_err());
        assert_eq!(
            result.expect_err("this returned an error").to_string(),
            "Incoming data violates uniqueness constraint on column(s): block_number"
        );

        let invalid_constraints =
            get_unique_constraints(&["block_number", "transaction_hash"], Arc::clone(&schema));
        let result = super::validate_batch_with_constraints(&records, &invalid_constraints).await;
        assert!(result.is_err());
        assert_eq!(
            result.expect_err("this returned an error").to_string(),
            "Incoming data violates uniqueness constraint on column(s): block_number, transaction_hash"
        );

        Ok(())
    }

    /// Shared helper: resolve `cols` to schema indices and wrap them in the
    /// constraint variant produced by `to_constraint`.
    ///
    /// # Panics
    ///
    /// Panics when any column name is not present in `schema`.
    fn constraints_from_cols(
        cols: &[&str],
        schema: &SchemaRef,
        to_constraint: fn(Vec<usize>) -> Constraint,
    ) -> Constraints {
        let indices = cols
            .iter()
            .map(|col| {
                schema
                    .index_of(col)
                    .unwrap_or_else(|_| panic!("[{col}] not found, validated schema: [{}]", schema))
            })
            .collect();

        Constraints::new_unverified(vec![to_constraint(indices)])
    }

    pub(crate) fn get_unique_constraints(cols: &[&str], schema: SchemaRef) -> Constraints {
        constraints_from_cols(cols, &schema, Constraint::Unique)
    }

    pub(crate) fn get_pk_constraints(cols: &[&str], schema: SchemaRef) -> Constraints {
        constraints_from_cols(cols, &schema, Constraint::PrimaryKey)
    }
}