use crate::error::Result;
use crate::prelude::{CsvReadOptions, SessionContext};
use async_trait::async_trait;
use std::sync::Arc;
/// Registers `csv/aggregate_test_100.csv` from the arrow test data directory in `ctx`
/// under the name `table_name`.
///
/// The table is read with the aggregate test schema, and `infinite` is forwarded to
/// `CsvReadOptions::mark_infinite`, so the same file can be registered as either a bounded
/// or an unbounded source.
async fn register_current_csv(
    ctx: &SessionContext,
    table_name: &str,
    infinite: bool,
) -> Result<()> {
    let data_dir = crate::test_util::arrow_test_data();
    let csv_schema = crate::test_util::aggr_test_schema();
    let csv_path = format!("{data_dir}/csv/aggregate_test_100.csv");
    let read_options = CsvReadOptions::new()
        .schema(&csv_schema)
        .mark_infinite(infinite);

    ctx.register_csv(table_name, &csv_path, read_options).await?;
    Ok(())
}
/// Whether a test table is registered as an unbounded (infinite) or bounded source.
#[derive(Debug, PartialEq, Eq)]
pub enum SourceType {
    /// The table is registered with `mark_infinite(true)`.
    Unbounded,
    /// The table is registered with `mark_infinite(false)`.
    Bounded,
}
/// One table setup that a `QueryCase` runs its SQL against.
#[async_trait]
pub trait SqlTestCase {
    /// Returns `true` if physical planning of the query is expected to fail for this setup.
    fn expect_fail(&self) -> bool;

    /// Registers the tables this setup needs in `ctx`.
    async fn register_table(&self, ctx: &SessionContext) -> Result<()>;
}
/// A test setup with a single table registered under the name `test`.
pub struct UnaryTestCase {
    /// Whether `test` is registered as a bounded or unbounded source.
    pub(crate) source_type: SourceType,
    /// Whether physical planning of the query is expected to fail for this setup.
    pub(crate) expect_fail: bool,
}

#[async_trait]
impl SqlTestCase for UnaryTestCase {
    async fn register_table(&self, ctx: &SessionContext) -> Result<()> {
        let infinite = matches!(self.source_type, SourceType::Unbounded);
        register_current_csv(ctx, "test", infinite).await?;
        Ok(())
    }

    fn expect_fail(&self) -> bool {
        self.expect_fail
    }
}
/// A test setup with two tables, registered under the names `left` and `right`.
pub struct BinaryTestCase {
    /// Source types for the `left` and `right` tables, in that order.
    pub(crate) source_types: (SourceType, SourceType),
    /// Whether physical planning of the query is expected to fail for this setup.
    pub(crate) expect_fail: bool,
}

#[async_trait]
impl SqlTestCase for BinaryTestCase {
    async fn register_table(&self, ctx: &SessionContext) -> Result<()> {
        let (left_source, right_source) = &self.source_types;
        // Register `left` first, then `right`, stopping at the first failure.
        for (name, source) in [("left", left_source), ("right", right_source)] {
            register_current_csv(ctx, name, *source == SourceType::Unbounded).await?;
        }
        Ok(())
    }

    fn expect_fail(&self) -> bool {
        self.expect_fail
    }
}
/// A SQL query checked against several table setups.
///
/// For each setup in `cases`, a fresh `SessionContext` is created, the setup's tables are
/// registered, and a physical plan is created for `sql`. Setups whose `expect_fail()` is
/// `true` must fail physical planning with an error whose message contains `error_operator`.
/// All other setups must plan successfully.
pub struct QueryCase {
    /// The SQL statement to plan.
    pub(crate) sql: String,
    /// Table setups to run `sql` against, each in its own session.
    pub(crate) cases: Vec<Arc<dyn SqlTestCase>>,
    /// Substring expected in the planning error message of setups that should fail.
    pub(crate) error_operator: String,
}

impl QueryCase {
    /// Runs `sql` against every setup in `cases`.
    ///
    /// # Errors
    /// Returns an error if registering a setup's tables or `SessionContext::sql` fails.
    ///
    /// # Panics
    /// Panics if any setup's physical-planning outcome does not match its expectation.
    pub(crate) async fn run(&self) -> Result<()> {
        for case in &self.cases {
            // Each setup gets its own context, so tables registered for one setup do not
            // carry over into the next.
            let ctx = SessionContext::new();
            case.register_table(&ctx).await?;
            let error = case.expect_fail().then_some(&self.error_operator);
            self.run_case(ctx, error).await?;
        }
        Ok(())
    }

    /// Plans `self.sql` in `ctx` and checks the physical-planning outcome.
    ///
    /// When `error` is `Some`, physical planning must fail with a message containing it.
    /// When `error` is `None`, physical planning must succeed. A mismatch panics with a
    /// message naming the query and the unexpected outcome.
    ///
    /// # Errors
    /// Returns an error if `SessionContext::sql` fails. That happens before the
    /// physical-planning check.
    async fn run_case(&self, ctx: SessionContext, error: Option<&String>) -> Result<()> {
        let dataframe = ctx.sql(self.sql.as_str()).await?;
        let plan = dataframe.create_physical_plan().await;
        match (error, plan) {
            (Some(expected), Err(plan_error)) => {
                assert!(
                    plan_error.to_string().contains(expected.as_str()),
                    "plan_error: {:?} doesn't contain message: {:?}",
                    plan_error,
                    expected.as_str()
                );
            }
            // Previously `unwrap_err()` panicked here with an opaque message that dumped
            // the plan instead of naming the query.
            (Some(expected), Ok(_)) => panic!(
                "query {:?} was expected to fail physical planning with {:?}, but it succeeded",
                self.sql, expected
            ),
            // Previously a bare `assert!(plan.is_ok())` hid the actual planning error.
            (None, Err(plan_error)) => panic!(
                "query {:?} was expected to plan successfully, but failed: {:?}",
                self.sql, plan_error
            ),
            (None, Ok(_)) => {}
        }
        Ok(())
    }
}