#[cfg(test)]
mod tests {
    // Round-trip tests for Substrait serialization plus checks on the
    // `Emit` remapping the producer attaches to Project relations.
    use datafusion::datasource::provider_as_source;
    use datafusion::logical_expr::LogicalPlanBuilder;
    use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
    use datafusion_substrait::logical_plan::producer::to_substrait_plan;
    use datafusion_substrait::serializer;

    use datafusion::error::Result;
    use datafusion::prelude::*;

    use insta::assert_snapshot;
    use std::fs;
    use substrait::proto::plan_rel::RelType;
    use substrait::proto::rel_common::{Emit, EmitKind};
    use substrait::proto::{Plan, Rel, RelCommon, rel};

    #[tokio::test]
    async fn serialize_to_file() -> Result<()> {
        let ctx = create_context().await?;
        let path = "tests/serialize_to_file.bin";
        let _cleanup = scratch_file(path);
        let sql = "SELECT a, b FROM data";

        // Serializing to a non-existing file succeeds and can be read back.
        serializer::serialize(sql, &ctx, path).await?;
        serializer::deserialize(path).await?;

        // Serializing to an existing file must fail rather than overwrite it.
        let got = serializer::serialize(sql, &ctx, path)
            .await
            .expect_err("serializing to an existing file should fail")
            .to_string();
        // "os error 80" is the Windows ERROR_FILE_EXISTS code, whose message
        // text differs from the Unix "File exists".
        assert!(
            ["File exists", "os error 80"]
                .iter()
                .any(|s| got.contains(s)),
            "unexpected error when target file exists: {got}"
        );

        fs::remove_file(path)?;
        Ok(())
    }

    #[tokio::test]
    async fn serialize_simple_select() -> Result<()> {
        let ctx = create_context().await?;
        let path = "tests/simple_select.bin";
        let _cleanup = scratch_file(path);
        let sql = "SELECT a, b FROM data";

        // Reference plan produced directly by DataFusion.
        let plan_ref = ctx.sql(sql).await?.into_optimized_plan()?;

        // Plan after a serialize -> file -> deserialize -> consume round trip.
        serializer::serialize(sql, &ctx, path).await?;
        let proto = serializer::deserialize(path).await?;
        let plan = from_substrait_plan(&ctx.state(), &proto).await?;

        assert_eq!(format!("{plan_ref}"), format!("{plan}"));

        fs::remove_file(path)?;
        Ok(())
    }

    #[tokio::test]
    async fn table_scan_without_projection() -> Result<()> {
        let ctx = create_context().await?;
        let table = provider_as_source(ctx.table_provider("data").await?);
        let table_scan = LogicalPlanBuilder::scan("data", table, None)?.build()?;

        // Propagate the error (instead of asserting `is_ok()`) so a failure
        // reports why the conversion failed.
        to_substrait_plan(&table_scan, &ctx.state())?;
        Ok(())
    }

    #[tokio::test]
    async fn include_remaps_for_projects() -> Result<()> {
        let ctx = create_context().await?;
        let df = ctx.sql("SELECT b, a + a, a FROM data").await?;
        let datafusion_plan = df.into_optimized_plan()?;

        assert_snapshot!(
            datafusion_plan,
            @r"
        Projection: data.b, data.a + data.a, data.a
          TableScan: data projection=[a, b]
        "
        );

        let plan = to_substrait_plan(&datafusion_plan, &ctx.state())?;
        let root_rel = root_input(&plan);

        let Some(rel::RelType::Project(project)) = root_rel.rel_type.as_ref() else {
            panic!("expected Project under Root");
        };
        // The Read below outputs two columns (indices 0 and 1), so the emit
        // selects only the three projected expressions that follow them.
        assert_emit(project.common.as_ref(), &[2, 3, 4]);

        let Some(rel::RelType::Read(read)) = project
            .input
            .as_deref()
            .and_then(|input| input.rel_type.as_ref())
        else {
            panic!("expected Read under Project");
        };
        let select = read
            .projection
            .as_ref()
            .and_then(|mask| mask.select.as_ref())
            .expect("Read has a projection mask with a struct select");
        assert_eq!(
            2,
            select.struct_items.len(),
            "Read outputs two columns: a, b"
        );
        Ok(())
    }

    #[tokio::test]
    async fn include_remaps_for_windows() -> Result<()> {
        let ctx = create_context().await?;
        let df = ctx
            .sql("SELECT b, RANK() OVER (PARTITION BY a), c FROM data;")
            .await?;
        let datafusion_plan = df.into_optimized_plan()?;

        assert_snapshot!(
            datafusion_plan,
            @r"
        Projection: data.b, rank() PARTITION BY [data.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, data.c
          WindowAggr: windowExpr=[[rank() PARTITION BY [data.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
            TableScan: data projection=[a, b, c]
        "
        );

        let plan = to_substrait_plan(&datafusion_plan, &ctx.state())?;
        let root_rel = root_input(&plan);

        // The window aggregate is expected to be expressed as a Project nested
        // inside the outer Project.
        let Some(rel::RelType::Project(outer)) = root_rel.rel_type.as_ref() else {
            panic!("expected Project under Root");
        };
        assert_emit(outer.common.as_ref(), &[4, 5, 6]);

        let Some(rel::RelType::Project(inner)) = outer
            .input
            .as_deref()
            .and_then(|input| input.rel_type.as_ref())
        else {
            panic!("expected Project under the outer Project");
        };
        assert_emit(inner.common.as_ref(), &[3, 4, 5, 6]);

        let Some(rel::RelType::Read(read)) = inner
            .input
            .as_deref()
            .and_then(|input| input.rel_type.as_ref())
        else {
            panic!("expected Read under the inner Project");
        };
        let select = read
            .projection
            .as_ref()
            .and_then(|mask| mask.select.as_ref())
            .expect("Read has a projection mask with a struct select");
        assert_eq!(
            3,
            select.struct_items.len(),
            "Read outputs three columns: a, b, c"
        );
        Ok(())
    }

    /// Returns the input relation of the plan's first relation, panicking
    /// unless that relation is a `Root` with an input.
    fn root_input(plan: &Plan) -> &Rel {
        let relation = plan
            .relations
            .first()
            .expect("plan has at least one relation");
        match relation.rel_type.as_ref() {
            Some(RelType::Root(root)) => root.input.as_ref().expect("Root has an input"),
            _ => panic!("expected Root"),
        }
    }

    /// Asserts that `rel_common` is present and carries a direct `Emit`
    /// whose output mapping is exactly `output_mapping`.
    fn assert_emit(rel_common: Option<&RelCommon>, output_mapping: &[i32]) {
        let common = rel_common.expect("relation has a RelCommon");
        assert_eq!(
            common.emit_kind,
            Some(EmitKind::Emit(Emit {
                output_mapping: output_mapping.to_vec(),
            }))
        );
    }

    /// Deletes the file at the wrapped path when dropped, so a test's output
    /// file is cleaned up even if an assertion panics partway through.
    struct RemoveOnDrop<'a>(&'a str);

    impl Drop for RemoveOnDrop<'_> {
        fn drop(&mut self) {
            // Best-effort: the file may already have been removed by the test,
            // or never created if the test failed early.
            let _ = fs::remove_file(self.0);
        }
    }

    /// Prepares `path` as a scratch output file: removes anything left over
    /// from an interrupted earlier run (serialization refuses to overwrite an
    /// existing file) and returns a guard that removes it again on drop.
    fn scratch_file(path: &str) -> RemoveOnDrop<'_> {
        let _ = fs::remove_file(path);
        RemoveOnDrop(path)
    }

    /// Creates a session with the test CSV registered as both `data` and `data2`.
    async fn create_context() -> Result<SessionContext> {
        let ctx = SessionContext::new();
        ctx.register_csv("data", "tests/testdata/data.csv", CsvReadOptions::new())
            .await?;
        ctx.register_csv("data2", "tests/testdata/data.csv", CsvReadOptions::new())
            .await?;
        Ok(ctx)
    }
}