#[cfg(feature = "parquet")]
mod parquet;
use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
use crate::arrow::record_batch::RecordBatch;
use crate::arrow::util::pretty;
use crate::datasource::file_format::csv::CsvFormatFactory;
use crate::datasource::file_format::format_as_file_type;
use crate::datasource::file_format::json::JsonFormatFactory;
use crate::datasource::{provider_as_source, MemTable, TableProvider};
use crate::error::Result;
use crate::execution::context::{SessionState, TaskContext};
use crate::execution::FunctionRegistry;
use crate::logical_expr::utils::find_window_exprs;
use crate::logical_expr::{
col, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Partitioning, TableType,
};
use crate::physical_plan::{
collect, collect_partitioned, execute_stream, execute_stream_partitioned,
ExecutionPlan, SendableRecordBatchStream,
};
use crate::prelude::SessionContext;
use arrow::array::{Array, ArrayRef, Int64Array, StringArray};
use arrow::compute::{cast, concat};
use arrow::datatypes::{DataType, Field};
use arrow_schema::{Schema, SchemaRef};
use datafusion_common::config::{CsvOptions, JsonOptions};
use datafusion_common::{
plan_err, Column, DFSchema, DataFusionError, ParamValues, SchemaError, UnnestOptions,
};
use datafusion_expr::{case, is_null, lit};
use datafusion_expr::{
utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE,
};
use datafusion_functions_aggregate::expr_fn::{
avg, count, max, median, min, stddev, sum,
};
use async_trait::async_trait;
use datafusion_catalog::Session;
/// Options controlling how a `DataFrame`'s results are written out
/// (consumed by the `DataFrame::write_*` methods).
pub struct DataFrameWriteOptions {
    /// Overwrite existing data at the destination instead of appending.
    overwrite: bool,
    /// Write all output into a single file rather than one per partition.
    single_file_output: bool,
    /// Column names used to partition the written output.
    partition_by: Vec<String>,
}

impl DataFrameWriteOptions {
    /// Creates options with defaults: append (no overwrite), multiple
    /// output files allowed, no partitioning columns.
    pub fn new() -> Self {
        Self {
            overwrite: false,
            single_file_output: false,
            partition_by: Vec::new(),
        }
    }

    /// Sets whether existing data should be overwritten.
    pub fn with_overwrite(self, overwrite: bool) -> Self {
        Self { overwrite, ..self }
    }

    /// Sets whether the output should go to a single file.
    pub fn with_single_file_output(self, single_file_output: bool) -> Self {
        Self {
            single_file_output,
            ..self
        }
    }

    /// Sets the columns to partition the output by.
    pub fn with_partition_by(self, partition_by: Vec<String>) -> Self {
        Self {
            partition_by,
            ..self
        }
    }
}

impl Default for DataFrameWriteOptions {
    fn default() -> Self {
        Self::new()
    }
}
/// A logical query plan plus the session state needed to resolve,
/// optimize, and execute it. All transformation methods consume `self`
/// and return a new `DataFrame` wrapping the rewritten plan.
#[derive(Debug, Clone)]
pub struct DataFrame {
    // NOTE(review): boxed — presumably to keep this frequently
    // moved/cloned struct small; confirm.
    session_state: Box<SessionState>,
    plan: LogicalPlan,
}
impl DataFrame {
/// Creates a `DataFrame` from a session state and a logical plan.
pub fn new(session_state: SessionState, plan: LogicalPlan) -> Self {
    let session_state = Box::new(session_state);
    Self { session_state, plan }
}
/// Parses `sql` into an [`Expr`], resolving names against this
/// `DataFrame`'s schema.
pub fn parse_sql_expr(&self, sql: &str) -> Result<Expr> {
    let schema = self.schema();
    self.session_state.create_logical_expr(sql, schema)
}
pub async fn create_physical_plan(self) -> Result<Arc<dyn ExecutionPlan>> {
self.session_state.create_physical_plan(&self.plan).await
}
/// Projects the `DataFrame` down to the named columns, in the given
/// order. Fails if any name does not resolve to a column.
pub fn select_columns(self, columns: &[&str]) -> Result<DataFrame> {
    let schema = self.plan.schema();
    let mut expr = Vec::with_capacity(columns.len());
    for name in columns {
        // Propagate the first resolution error, as the collect-based
        // pipeline would.
        let (qualifier, field) =
            schema.qualified_field_with_unqualified_name(name)?;
        expr.push(Expr::Column(Column::from((qualifier, field))));
    }
    self.select(expr)
}
pub fn select(self, expr_list: Vec<Expr>) -> Result<DataFrame> {
let window_func_exprs = find_window_exprs(&expr_list);
let plan = if window_func_exprs.is_empty() {
self.plan
} else {
LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)?
};
let project_plan = LogicalPlanBuilder::from(plan).project(expr_list)?.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan: project_plan,
})
}
/// Returns a new `DataFrame` without the named columns.
///
/// Names that do not resolve to a column are silently ignored, so
/// dropping a nonexistent (or duplicated) name is not an error.
pub fn drop_columns(self, columns: &[&str]) -> Result<DataFrame> {
    // Resolve each requested name, skipping names that are not present.
    // `filter_map(.ok())` expresses the intent directly; the previous
    // `filter(is_ok)` + `collect::<Result<_>>()?` could never fail.
    let fields_to_drop: Vec<_> = columns
        .iter()
        .filter_map(|name| {
            self.plan
                .schema()
                .qualified_field_with_unqualified_name(name)
                .ok()
        })
        .collect();
    // Keep every schema field that is not scheduled for dropping.
    let expr: Vec<Expr> = self
        .plan
        .schema()
        .fields()
        .into_iter()
        .enumerate()
        .map(|(idx, _)| self.plan.schema().qualified_field(idx))
        .filter(|(qualifier, f)| !fields_to_drop.contains(&(*qualifier, f)))
        .map(|(qualifier, field)| Expr::Column(Column::from((qualifier, field))))
        .collect();
    self.select(expr)
}
/// Unnests a single column (see [`Self::unnest_columns_with_options`]).
#[deprecated(since = "37.0.0", note = "use unnest_columns instead")]
pub fn unnest_column(self, column: &str) -> Result<DataFrame> {
    self.unnest_columns(&[column])
}
/// Unnests a single column with explicit [`UnnestOptions`].
#[deprecated(since = "37.0.0", note = "use unnest_columns_with_options instead")]
pub fn unnest_column_with_options(
    self,
    column: &str,
    options: UnnestOptions,
) -> Result<DataFrame> {
    self.unnest_columns_with_options(&[column], options)
}
/// Unnests the named columns using default [`UnnestOptions`].
pub fn unnest_columns(self, columns: &[&str]) -> Result<DataFrame> {
    self.unnest_columns_with_options(columns, UnnestOptions::new())
}
pub fn unnest_columns_with_options(
self,
columns: &[&str],
options: UnnestOptions,
) -> Result<DataFrame> {
let columns = columns.iter().map(|c| Column::from(*c)).collect();
let plan = LogicalPlanBuilder::from(self.plan)
.unnest_columns_with_options(columns, options)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
})
}
pub fn filter(self, predicate: Expr) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
.filter(predicate)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
})
}
pub fn aggregate(
self,
group_expr: Vec<Expr>,
aggr_expr: Vec<Expr>,
) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
.aggregate(group_expr, aggr_expr)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
})
}
pub fn window(self, window_exprs: Vec<Expr>) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
.window(window_exprs)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
})
}
pub fn limit(self, skip: usize, fetch: Option<usize>) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
.limit(skip, fetch)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
})
}
pub fn union(self, dataframe: DataFrame) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
.union(dataframe.plan)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
})
}
pub fn union_distinct(self, dataframe: DataFrame) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
.union_distinct(dataframe.plan)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
})
}
pub fn distinct(self) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan).distinct()?.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
})
}
pub fn distinct_on(
self,
on_expr: Vec<Expr>,
select_expr: Vec<Expr>,
sort_expr: Option<Vec<Expr>>,
) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
.distinct_on(on_expr, select_expr, sort_expr)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
})
}
/// Computes per-column summary statistics.
///
/// The result has one row per statistic (`count`, `null_count`, `mean`,
/// `std`, `min`, `max`, `median`) and one column per input column, plus
/// a leading `describe` column holding the statistic label. Numeric
/// columns are reported as `Float64`; other columns are cast to `Utf8`,
/// with the string "null" standing in for statistics that do not apply
/// to a column's type.
pub async fn describe(self) -> Result<Self> {
    // Row labels — must stay in the same order as `describe_record_batch`
    // below, since the two are zipped positionally.
    let supported_describe_functions =
        vec!["count", "null_count", "mean", "std", "min", "max", "median"];
    let original_schema_fields = self.schema().fields().iter();
    // Output schema: Utf8 "describe" label column, then one column per
    // input field (Float64 for numeric fields, Utf8 otherwise).
    let mut describe_schemas = vec![Field::new("describe", DataType::Utf8, false)];
    describe_schemas.extend(original_schema_fields.clone().map(|field| {
        if field.data_type().is_numeric() {
            Field::new(field.name(), DataType::Float64, true)
        } else {
            Field::new(field.name(), DataType::Utf8, true)
        }
    }));
    // One single-row global aggregate per statistic. Entries are kept as
    // `Result`s so a statistic that cannot be planned for this schema is
    // handled per-column below instead of failing the whole describe.
    let describe_record_batch = vec![
        // count
        self.clone().aggregate(
            vec![],
            original_schema_fields
                .clone()
                .map(|f| count(col(f.name())).alias(f.name()))
                .collect::<Vec<_>>(),
        ),
        // null_count: sum of a per-row 1/0 "is null" flag
        self.clone().aggregate(
            vec![],
            original_schema_fields
                .clone()
                .map(|f| {
                    sum(case(is_null(col(f.name())))
                        .when(lit(true), lit(1))
                        .otherwise(lit(0))
                        .unwrap())
                    .alias(f.name())
                })
                .collect::<Vec<_>>(),
        ),
        // mean — numeric columns only
        self.clone().aggregate(
            vec![],
            original_schema_fields
                .clone()
                .filter(|f| f.data_type().is_numeric())
                .map(|f| avg(col(f.name())).alias(f.name()))
                .collect::<Vec<_>>(),
        ),
        // std — numeric columns only
        self.clone().aggregate(
            vec![],
            original_schema_fields
                .clone()
                .filter(|f| f.data_type().is_numeric())
                .map(|f| stddev(col(f.name())).alias(f.name()))
                .collect::<Vec<_>>(),
        ),
        // min — excluded for Binary/Boolean columns
        self.clone().aggregate(
            vec![],
            original_schema_fields
                .clone()
                .filter(|f| {
                    !matches!(f.data_type(), DataType::Binary | DataType::Boolean)
                })
                .map(|f| min(col(f.name())).alias(f.name()))
                .collect::<Vec<_>>(),
        ),
        // max — excluded for Binary/Boolean columns
        self.clone().aggregate(
            vec![],
            original_schema_fields
                .clone()
                .filter(|f| {
                    !matches!(f.data_type(), DataType::Binary | DataType::Boolean)
                })
                .map(|f| max(col(f.name())).alias(f.name()))
                .collect::<Vec<_>>(),
        ),
        // median — numeric columns only
        self.clone().aggregate(
            vec![],
            original_schema_fields
                .clone()
                .filter(|f| f.data_type().is_numeric())
                .map(|f| median(col(f.name())).alias(f.name()))
                .collect::<Vec<_>>(),
        ),
    ];
    // First output column: the statistic labels themselves.
    let mut array_ref_vec: Vec<ArrayRef> = vec![Arc::new(StringArray::from(
        supported_describe_functions.clone(),
    ))];
    // For each input column, gather its value for every statistic and
    // stack them into a single 7-row array.
    for field in original_schema_fields {
        let mut array_datas = vec![];
        for result in describe_record_batch.iter() {
            let array_ref = match result {
                Ok(df) => {
                    let batchs = df.clone().collect().await;
                    match batchs {
                        // Exactly one batch containing this column: cast
                        // it to the output type chosen for this field.
                        Ok(batchs)
                            if batchs.len() == 1
                                && batchs[0]
                                    .column_by_name(field.name())
                                    .is_some() =>
                        {
                            let column =
                                batchs[0].column_by_name(field.name()).unwrap();
                            if field.data_type().is_numeric() {
                                cast(column, &DataType::Float64)?
                            } else {
                                cast(column, &DataType::Utf8)?
                            }
                        }
                        // Statistic was not produced for this column
                        // (e.g. mean of a string column).
                        _ => Arc::new(StringArray::from(vec!["null"])),
                    }
                }
                // Planning failed because the statistic applied to no
                // column at all (empty aggregate list): report "null".
                Err(err)
                    if err.to_string().contains(
                        "Error during planning: \
                        Aggregate requires at least one grouping \
                        or aggregate expression",
                    ) =>
                {
                    Arc::new(StringArray::from(vec!["null"]))
                }
                // NOTE(review): any other error panics rather than being
                // returned as `Err` — confirm this is intended.
                Err(other_err) => {
                    panic!("{other_err}")
                }
            };
            array_datas.push(array_ref);
        }
        array_ref_vec.push(concat(
            array_datas
                .iter()
                .map(|af| af.as_ref())
                .collect::<Vec<_>>()
                .as_slice(),
        )?);
    }
    let describe_record_batch =
        RecordBatch::try_new(Arc::new(Schema::new(describe_schemas)), array_ref_vec)?;
    // Wrap the single result batch in a MemTable scan so the caller gets
    // an ordinary DataFrame back.
    let provider = MemTable::try_new(
        describe_record_batch.schema(),
        vec![vec![describe_record_batch]],
    )?;
    let plan = LogicalPlanBuilder::scan(
        UNNAMED_TABLE,
        provider_as_source(Arc::new(provider)),
        None,
    )?
    .build()?;
    Ok(DataFrame {
        session_state: self.session_state,
        plan,
    })
}
pub fn sort(self, expr: Vec<Expr>) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan).sort(expr)?.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
})
}
pub fn join(
self,
right: DataFrame,
join_type: JoinType,
left_cols: &[&str],
right_cols: &[&str],
filter: Option<Expr>,
) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
.join(
right.plan,
join_type,
(left_cols.to_vec(), right_cols.to_vec()),
filter,
)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
})
}
pub fn join_on(
self,
right: DataFrame,
join_type: JoinType,
on_exprs: impl IntoIterator<Item = Expr>,
) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
.join_on(right.plan, join_type, on_exprs)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
})
}
pub fn repartition(self, partitioning_scheme: Partitioning) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
.repartition(partitioning_scheme)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
})
}
/// Executes the plan and returns the total number of rows.
///
/// Internally plans a global `count(*)` aggregate and reads back its
/// single `Int64` value.
pub async fn count(self) -> Result<usize> {
    let rows = self
        .aggregate(vec![], vec![count(Expr::Literal(COUNT_STAR_EXPANSION))])?
        .collect()
        .await?;
    let len = *rows
        .first()
        .and_then(|r| r.columns().first())
        .and_then(|c| c.as_any().downcast_ref::<Int64Array>())
        .and_then(|a| a.values().first())
        // `ok_or_else` so the error message is only allocated on failure.
        .ok_or_else(|| {
            DataFusionError::Internal(
                "Unexpected output when collecting for count()".to_string(),
            )
        })? as usize;
    Ok(len)
}
/// Executes the plan and buffers all result batches in memory.
pub async fn collect(self) -> Result<Vec<RecordBatch>> {
    let task_ctx = Arc::new(self.task_ctx());
    let physical_plan = self.create_physical_plan().await?;
    collect(physical_plan, task_ctx).await
}
/// Executes the plan and pretty-prints all result batches.
pub async fn show(self) -> Result<()> {
    let batches = self.collect().await?;
    pretty::print_batches(&batches)?;
    Ok(())
}
/// Executes the plan and pretty-prints at most `num` rows.
pub async fn show_limit(self, num: usize) -> Result<()> {
    let batches = self.limit(0, Some(num))?.collect().await?;
    pretty::print_batches(&batches)?;
    Ok(())
}
/// Builds a new [`TaskContext`] from this `DataFrame`'s session state.
pub fn task_ctx(&self) -> TaskContext {
    TaskContext::from(self.session_state.as_ref())
}
/// Executes the plan and returns a single stream of record batches.
pub async fn execute_stream(self) -> Result<SendableRecordBatchStream> {
    let task_ctx = Arc::new(self.task_ctx());
    let physical_plan = self.create_physical_plan().await?;
    execute_stream(physical_plan, task_ctx)
}
/// Executes the plan and buffers the results, preserving the output
/// partitioning (one `Vec<RecordBatch>` per partition).
pub async fn collect_partitioned(self) -> Result<Vec<Vec<RecordBatch>>> {
    let task_ctx = Arc::new(self.task_ctx());
    let physical_plan = self.create_physical_plan().await?;
    collect_partitioned(physical_plan, task_ctx).await
}
/// Executes the plan and returns one record-batch stream per output
/// partition.
pub async fn execute_stream_partitioned(
    self,
) -> Result<Vec<SendableRecordBatchStream>> {
    let task_ctx = Arc::new(self.task_ctx());
    let physical_plan = self.create_physical_plan().await?;
    execute_stream_partitioned(physical_plan, task_ctx)
}
/// Returns the [`DFSchema`] describing this `DataFrame`'s output.
pub fn schema(&self) -> &DFSchema {
    self.plan.schema()
}
/// Borrows the logical plan as-is (without running the optimizer).
pub fn logical_plan(&self) -> &LogicalPlan {
    &self.plan
}
/// Decomposes the `DataFrame` into its session state and logical plan.
pub fn into_parts(self) -> (SessionState, LogicalPlan) {
    (*self.session_state, self.plan)
}
/// Consumes the `DataFrame`, returning the plan without optimizing it.
pub fn into_unoptimized_plan(self) -> LogicalPlan {
    self.plan
}
/// Consumes the `DataFrame`, returning the plan after the session's
/// optimizer has rewritten it.
pub fn into_optimized_plan(self) -> Result<LogicalPlan> {
    self.session_state.optimize(&self.plan)
}
/// Wraps the plan in a [`TableProvider`] (a view) so it can be
/// registered with a context and queried like a table.
pub fn into_view(self) -> Arc<dyn TableProvider> {
    Arc::new(DataFrameTableProvider { plan: self.plan })
}
pub fn explain(self, verbose: bool, analyze: bool) -> Result<DataFrame> {
if matches!(self.plan, LogicalPlan::Explain(_)) {
return plan_err!("Nested EXPLAINs are not supported");
}
let plan = LogicalPlanBuilder::from(self.plan)
.explain(verbose, analyze)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
})
}
/// Returns the function registry backing this `DataFrame`'s session.
pub fn registry(&self) -> &dyn FunctionRegistry {
    self.session_state.as_ref()
}
pub fn intersect(self, dataframe: DataFrame) -> Result<DataFrame> {
let left_plan = self.plan;
let right_plan = dataframe.plan;
let plan = LogicalPlanBuilder::intersect(left_plan, right_plan, true)?;
Ok(DataFrame {
session_state: self.session_state,
plan,
})
}
pub fn except(self, dataframe: DataFrame) -> Result<DataFrame> {
let left_plan = self.plan;
let right_plan = dataframe.plan;
let plan = LogicalPlanBuilder::except(left_plan, right_plan, true)?;
Ok(DataFrame {
session_state: self.session_state,
plan,
})
}
pub async fn write_table(
self,
table_name: &str,
write_options: DataFrameWriteOptions,
) -> Result<Vec<RecordBatch>, DataFusionError> {
let arrow_schema = Schema::from(self.schema());
let plan = LogicalPlanBuilder::insert_into(
self.plan,
table_name.to_owned(),
&arrow_schema,
write_options.overwrite,
)?
.build()?;
DataFrame {
session_state: self.session_state,
plan,
}
.collect()
.await
}
pub async fn write_csv(
self,
path: &str,
options: DataFrameWriteOptions,
writer_options: Option<CsvOptions>,
) -> Result<Vec<RecordBatch>, DataFusionError> {
if options.overwrite {
return Err(DataFusionError::NotImplemented(
"Overwrites are not implemented for DataFrame::write_csv.".to_owned(),
));
}
let format = if let Some(csv_opts) = writer_options {
Arc::new(CsvFormatFactory::new_with_options(csv_opts))
} else {
Arc::new(CsvFormatFactory::new())
};
let file_type = format_as_file_type(format);
let plan = LogicalPlanBuilder::copy_to(
self.plan,
path.into(),
file_type,
HashMap::new(),
options.partition_by,
)?
.build()?;
DataFrame {
session_state: self.session_state,
plan,
}
.collect()
.await
}
pub async fn write_json(
self,
path: &str,
options: DataFrameWriteOptions,
writer_options: Option<JsonOptions>,
) -> Result<Vec<RecordBatch>, DataFusionError> {
if options.overwrite {
return Err(DataFusionError::NotImplemented(
"Overwrites are not implemented for DataFrame::write_json.".to_owned(),
));
}
let format = if let Some(json_opts) = writer_options {
Arc::new(JsonFormatFactory::new_with_options(json_opts))
} else {
Arc::new(JsonFormatFactory::new())
};
let file_type = format_as_file_type(format);
let plan = LogicalPlanBuilder::copy_to(
self.plan,
path.into(),
file_type,
Default::default(),
options.partition_by,
)?
.build()?;
DataFrame {
session_state: self.session_state,
plan,
}
.collect()
.await
}
/// Adds a column named `name` computed by `expr`, or replaces an
/// existing column of that name in place.
///
/// Window expressions inside `expr` are first materialized via a window
/// plan, mirroring [`Self::select`].
pub fn with_column(self, name: &str, expr: Expr) -> Result<DataFrame> {
    let window_func_exprs = find_window_exprs(&[expr.clone()]);
    let plan = if window_func_exprs.is_empty() {
        self.plan
    } else {
        LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)?
    };
    let new_column = expr.alias(name);
    let mut col_exists = false;
    // Rebuild the projection: substitute the new expression for a
    // same-named column, keep every other column unchanged.
    let mut fields: Vec<Expr> = plan
        .schema()
        .iter()
        .map(|(qualifier, field)| {
            if field.name() == name {
                col_exists = true;
                new_column.clone()
            } else {
                col(Column::from((qualifier, field)))
            }
        })
        .collect();
    // No existing column matched — append the new column at the end.
    if !col_exists {
        fields.push(new_column);
    }
    let project_plan = LogicalPlanBuilder::from(plan).project(fields)?.build()?;
    Ok(DataFrame {
        session_state: self.session_state,
        plan: project_plan,
    })
}
/// Renames the column `old_name` to `new_name`; returns `self`
/// unchanged if `old_name` does not resolve to a column.
///
/// `old_name` is parsed as a (possibly qualified) column name, honoring
/// the session's identifier-normalization setting.
pub fn with_column_renamed(
    self,
    old_name: impl Into<String>,
    new_name: &str,
) -> Result<DataFrame> {
    let ident_opts = self
        .session_state
        .config_options()
        .sql_parser
        .enable_ident_normalization;
    let old_column: Column = if ident_opts {
        Column::from_qualified_name(old_name)
    } else {
        Column::from_qualified_name_ignore_case(old_name)
    };
    // Resolve the column: an unknown field is a silent no-op, any other
    // resolution error is propagated to the caller.
    let (qualifier_rename, field_rename) =
        match self.plan.schema().qualified_field_from_column(&old_column) {
            Ok(qualifier_and_field) => qualifier_and_field,
            Err(DataFusionError::SchemaError(
                SchemaError::FieldNotFound { .. },
                _,
            )) => return Ok(self),
            Err(err) => return Err(err),
        };
    // Project every column, aliasing only the matching one.
    let projection = self
        .plan
        .schema()
        .iter()
        .map(|(qualifier, field)| {
            if qualifier.eq(&qualifier_rename) && field.as_ref() == field_rename {
                col(Column::from((qualifier, field))).alias(new_name)
            } else {
                col(Column::from((qualifier, field)))
            }
        })
        .collect::<Vec<_>>();
    let project_plan = LogicalPlanBuilder::from(self.plan)
        .project(projection)?
        .build()?;
    Ok(DataFrame {
        session_state: self.session_state,
        plan: project_plan,
    })
}
pub fn with_param_values(self, query_values: impl Into<ParamValues>) -> Result<Self> {
let plan = self.plan.with_param_values(query_values)?;
Ok(DataFrame {
session_state: self.session_state,
plan,
})
}
/// Executes the plan once and returns a new `DataFrame` backed by the
/// buffered results (a [`MemTable`]), so subsequent operations reuse
/// the materialized data instead of re-executing the plan.
pub async fn cache(self) -> Result<DataFrame> {
    let context = SessionContext::new_with_state((*self.session_state).clone());
    let physical = self.clone().create_physical_plan().await?;
    let schema = physical.schema();
    let task_ctx = Arc::new(self.task_ctx());
    let partitions = collect_partitioned(physical, task_ctx).await?;
    let cached = MemTable::try_new(schema, partitions)?;
    context.read_table(Arc::new(cached))
}
}
/// Adapter exposing a `DataFrame`'s logical plan as a [`TableProvider`]
/// view (created by [`DataFrame::into_view`]).
struct DataFrameTableProvider {
    plan: LogicalPlan,
}
#[async_trait]
impl TableProvider for DataFrameTableProvider {
    fn as_any(&self) -> &dyn Any {
        self
    }
    // Expose the underlying logical plan (e.g. so planners can inline
    // the view).
    fn get_logical_plan(&self) -> Option<&LogicalPlan> {
        Some(&self.plan)
    }
    // Every filter is reported as `Exact`: `scan` applies them verbatim
    // on top of the stored plan.
    fn supports_filters_pushdown(
        &self,
        filters: &[&Expr],
    ) -> Result<Vec<TableProviderFilterPushDown>> {
        Ok(vec![TableProviderFilterPushDown::Exact; filters.len()])
    }
    fn schema(&self) -> SchemaRef {
        let schema: Schema = self.plan.schema().as_ref().into();
        Arc::new(schema)
    }
    fn table_type(&self) -> TableType {
        TableType::View
    }
    // Builds a physical plan for the view: the stored logical plan plus
    // any pushed-down filters (AND-ed together), projection, and limit.
    async fn scan(
        &self,
        state: &dyn Session,
        projection: Option<&Vec<usize>>,
        filters: &[Expr],
        limit: Option<usize>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let mut expr = LogicalPlanBuilder::from(self.plan.clone());
        // Fold all pushed filters into a single conjunction.
        let filter = filters.iter().cloned().reduce(|acc, new| acc.and(new));
        if let Some(filter) = filter {
            expr = expr.filter(filter)?
        }
        if let Some(p) = projection {
            expr = expr.select(p.iter().copied())?
        }
        if let Some(l) = limit {
            expr = expr.limit(0, Some(l))?
        }
        let plan = expr.build()?;
        state.create_physical_plan(&plan).await
    }
}
#[cfg(test)]
mod tests {
use std::vec;
use super::*;
use crate::assert_batches_sorted_eq;
use crate::execution::context::SessionConfig;
use crate::physical_plan::{ColumnarValue, Partitioning, PhysicalExpr};
use crate::test_util::{register_aggregate_csv, test_table, test_table_with_name};
use arrow::array::{self, Int32Array};
use datafusion_common::{Constraint, Constraints, ScalarValue};
use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::{
cast, create_udf, expr, lit, BuiltInWindowFunction, ExprFunctionExt,
ScalarFunctionImplementation, Volatility, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties};
async fn assert_physical_plan(df: &DataFrame, expected: Vec<&str>) {
let physical_plan = df
.clone()
.create_physical_plan()
.await
.expect("Error creating physical plan");
let actual = get_plan_string(&physical_plan);
assert_eq!(
expected, actual,
"\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
);
}
pub fn table_with_constraints() -> Arc<dyn TableProvider> {
let dual_schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("name", DataType::Utf8, false),
]));
let batch = RecordBatch::try_new(
dual_schema.clone(),
vec![
Arc::new(array::Int32Array::from(vec![1])),
Arc::new(array::StringArray::from(vec!["a"])),
],
)
.unwrap();
let provider = MemTable::try_new(dual_schema, vec![vec![batch]])
.unwrap()
.with_constraints(Constraints::new_unverified(vec![Constraint::PrimaryKey(
vec![0],
)]));
Arc::new(provider)
}
async fn assert_logical_expr_schema_eq_physical_expr_schema(
df: DataFrame,
) -> Result<()> {
let logical_expr_dfschema = df.schema();
let logical_expr_schema = SchemaRef::from(logical_expr_dfschema.to_owned());
let batches = df.collect().await?;
let physical_expr_schema = batches[0].schema();
assert_eq!(logical_expr_schema, physical_expr_schema);
Ok(())
}
#[tokio::test]
async fn test_array_agg_ord_schema() -> Result<()> {
let ctx = SessionContext::new();
let create_table_query = r#"
CREATE TABLE test_table (
"double_field" DOUBLE,
"string_field" VARCHAR
) AS VALUES
(1.0, 'a'),
(2.0, 'b'),
(3.0, 'c')
"#;
ctx.sql(create_table_query).await?;
let query = r#"SELECT
array_agg("double_field" ORDER BY "string_field") as "double_field",
array_agg("string_field" ORDER BY "string_field") as "string_field"
FROM test_table"#;
let result = ctx.sql(query).await?;
assert_logical_expr_schema_eq_physical_expr_schema(result).await?;
Ok(())
}
#[tokio::test]
async fn test_array_agg_schema() -> Result<()> {
let ctx = SessionContext::new();
let create_table_query = r#"
CREATE TABLE test_table (
"double_field" DOUBLE,
"string_field" VARCHAR
) AS VALUES
(1.0, 'a'),
(2.0, 'b'),
(3.0, 'c')
"#;
ctx.sql(create_table_query).await?;
let query = r#"SELECT
array_agg("double_field") as "double_field",
array_agg("string_field") as "string_field"
FROM test_table"#;
let result = ctx.sql(query).await?;
assert_logical_expr_schema_eq_physical_expr_schema(result).await?;
Ok(())
}
#[tokio::test]
async fn test_array_agg_distinct_schema() -> Result<()> {
let ctx = SessionContext::new();
let create_table_query = r#"
CREATE TABLE test_table (
"double_field" DOUBLE,
"string_field" VARCHAR
) AS VALUES
(1.0, 'a'),
(2.0, 'b'),
(2.0, 'a')
"#;
ctx.sql(create_table_query).await?;
let query = r#"SELECT
array_agg(distinct "double_field") as "double_field",
array_agg(distinct "string_field") as "string_field"
FROM test_table"#;
let result = ctx.sql(query).await?;
assert_logical_expr_schema_eq_physical_expr_schema(result).await?;
Ok(())
}
#[tokio::test]
async fn select_columns() -> Result<()> {
let t = test_table().await?;
let t2 = t.select_columns(&["c1", "c2", "c11"])?;
let plan = t2.plan.clone();
let sql_plan = create_plan("SELECT c1, c2, c11 FROM aggregate_test_100").await?;
assert_same_plan(&plan, &sql_plan);
Ok(())
}
#[tokio::test]
async fn select_expr() -> Result<()> {
let t = test_table().await?;
let t2 = t.select(vec![col("c1"), col("c2"), col("c11")])?;
let plan = t2.plan.clone();
let sql_plan = create_plan("SELECT c1, c2, c11 FROM aggregate_test_100").await?;
assert_same_plan(&plan, &sql_plan);
Ok(())
}
#[tokio::test]
async fn select_with_window_exprs() -> Result<()> {
let t = test_table().await?;
let first_row = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::BuiltInWindowFunction(
BuiltInWindowFunction::FirstValue,
),
vec![col("aggregate_test_100.c1")],
))
.partition_by(vec![col("aggregate_test_100.c2")])
.build()
.unwrap();
let t2 = t.select(vec![col("c1"), first_row])?;
let plan = t2.plan.clone();
let sql_plan = create_plan(
"select c1, first_value(c1) over (partition by c2) from aggregate_test_100",
)
.await?;
assert_same_plan(&plan, &sql_plan);
Ok(())
}
#[tokio::test]
async fn select_with_periods() -> Result<()> {
let array: Int32Array = [1, 10].into_iter().collect();
let batch = RecordBatch::try_from_iter(vec![("f.c1", Arc::new(array) as _)])?;
let ctx = SessionContext::new();
ctx.register_batch("t", batch)?;
let df = ctx.table("t").await?.select_columns(&["f.c1"])?;
let df_results = df.collect().await?;
assert_batches_sorted_eq!(
["+------+", "| f.c1 |", "+------+", "| 1 |", "| 10 |", "+------+"],
&df_results
);
Ok(())
}
#[tokio::test]
async fn drop_columns() -> Result<()> {
let t = test_table().await?;
let t2 = t.drop_columns(&["c2", "c11"])?;
let plan = t2.plan.clone();
let sql_plan = create_plan(
"SELECT c1,c3,c4,c5,c6,c7,c8,c9,c10,c12,c13 FROM aggregate_test_100",
)
.await?;
assert_same_plan(&plan, &sql_plan);
Ok(())
}
#[tokio::test]
async fn drop_columns_with_duplicates() -> Result<()> {
let t = test_table().await?;
let t2 = t.drop_columns(&["c2", "c11", "c2", "c2"])?;
let plan = t2.plan.clone();
let sql_plan = create_plan(
"SELECT c1,c3,c4,c5,c6,c7,c8,c9,c10,c12,c13 FROM aggregate_test_100",
)
.await?;
assert_same_plan(&plan, &sql_plan);
Ok(())
}
#[tokio::test]
async fn drop_columns_with_nonexistent_columns() -> Result<()> {
let t = test_table().await?;
let t2 = t.drop_columns(&["canada", "c2", "rocks"])?;
let plan = t2.plan.clone();
let sql_plan = create_plan(
"SELECT c1,c3,c4,c5,c6,c7,c8,c9,c10,c11,c12,c13 FROM aggregate_test_100",
)
.await?;
assert_same_plan(&plan, &sql_plan);
Ok(())
}
#[tokio::test]
async fn drop_columns_with_empty_array() -> Result<()> {
let t = test_table().await?;
let t2 = t.drop_columns(&[])?;
let plan = t2.plan.clone();
let sql_plan = create_plan(
"SELECT c1,c2,c3,c4,c5,c6,c7,c8,c9,c10,c11,c12,c13 FROM aggregate_test_100",
)
.await?;
assert_same_plan(&plan, &sql_plan);
Ok(())
}
#[tokio::test]
async fn drop_with_quotes() -> Result<()> {
let array1: Int32Array = [1, 10].into_iter().collect();
let array2: Int32Array = [2, 11].into_iter().collect();
let batch = RecordBatch::try_from_iter(vec![
("f\"c1", Arc::new(array1) as _),
("f\"c2", Arc::new(array2) as _),
])?;
let ctx = SessionContext::new();
ctx.register_batch("t", batch)?;
let df = ctx.table("t").await?.drop_columns(&["f\"c1"])?;
let df_results = df.collect().await?;
assert_batches_sorted_eq!(
[
"+------+",
"| f\"c2 |",
"+------+",
"| 2 |",
"| 11 |",
"+------+"
],
&df_results
);
Ok(())
}
#[tokio::test]
async fn drop_with_periods() -> Result<()> {
let array1: Int32Array = [1, 10].into_iter().collect();
let array2: Int32Array = [2, 11].into_iter().collect();
let batch = RecordBatch::try_from_iter(vec![
("f.c1", Arc::new(array1) as _),
("f.c2", Arc::new(array2) as _),
])?;
let ctx = SessionContext::new();
ctx.register_batch("t", batch)?;
let df = ctx.table("t").await?.drop_columns(&["f.c1"])?;
let df_results = df.collect().await?;
assert_batches_sorted_eq!(
["+------+", "| f.c2 |", "+------+", "| 2 |", "| 11 |", "+------+"],
&df_results
);
Ok(())
}
#[tokio::test]
async fn aggregate() -> Result<()> {
let df = test_table().await?;
let group_expr = vec![col("c1")];
let aggr_expr = vec![
min(col("c12")),
max(col("c12")),
avg(col("c12")),
sum(col("c12")),
count(col("c12")),
count_distinct(col("c12")),
];
let df: Vec<RecordBatch> = df.aggregate(group_expr, aggr_expr)?.collect().await?;
assert_batches_sorted_eq!(
["+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
"| c1 | min(aggregate_test_100.c12) | max(aggregate_test_100.c12) | avg(aggregate_test_100.c12) | sum(aggregate_test_100.c12) | count(aggregate_test_100.c12) | count(DISTINCT aggregate_test_100.c12) |",
"+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
"| a | 0.02182578039211991 | 0.9800193410444061 | 0.48754517466109415 | 10.238448667882977 | 21 | 21 |",
"| b | 0.04893135681998029 | 0.9185813970744787 | 0.41040709263815384 | 7.797734760124923 | 19 | 19 |",
"| c | 0.0494924465469434 | 0.991517828651004 | 0.6600456536439784 | 13.860958726523545 | 21 | 21 |",
"| d | 0.061029375346466685 | 0.9748360509016578 | 0.48855379387549824 | 8.793968289758968 | 18 | 18 |",
"| e | 0.01479305307777301 | 0.9965400387585364 | 0.48600669271341534 | 10.206140546981722 | 21 | 21 |",
"+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+"],
&df
);
Ok(())
}
#[tokio::test]
async fn test_aggregate_with_pk() -> Result<()> {
let config = SessionConfig::new().with_target_partitions(1);
let ctx = SessionContext::new_with_config(config);
let df = ctx.read_table(table_with_constraints())?;
let group_expr = vec![col("id")];
let aggr_expr = vec![];
let df = df.aggregate(group_expr, aggr_expr)?;
let df = df.select(vec![col("id"), col("name")])?;
assert_physical_plan(
&df,
vec![
"AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]",
" MemoryExec: partitions=1, partition_sizes=[1]",
],
)
.await;
let df_results = df.collect().await?;
#[rustfmt::skip]
assert_batches_sorted_eq!([
"+----+------+",
"| id | name |",
"+----+------+",
"| 1 | a |",
"+----+------+"
],
&df_results
);
Ok(())
}
#[tokio::test]
async fn test_aggregate_with_pk2() -> Result<()> {
let config = SessionConfig::new().with_target_partitions(1);
let ctx = SessionContext::new_with_config(config);
let df = ctx.read_table(table_with_constraints())?;
let group_expr = vec![col("id")];
let aggr_expr = vec![];
let df = df.aggregate(group_expr, aggr_expr)?;
let predicate = col("id").eq(lit(1i32)).and(col("name").eq(lit("a")));
let df = df.filter(predicate)?;
assert_physical_plan(
&df,
vec![
"CoalesceBatchesExec: target_batch_size=8192",
" FilterExec: id@0 = 1 AND name@1 = a",
" AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]",
" MemoryExec: partitions=1, partition_sizes=[1]",
],
)
.await;
let df_results = df.collect().await?;
#[rustfmt::skip]
assert_batches_sorted_eq!(
["+----+------+",
"| id | name |",
"+----+------+",
"| 1 | a |",
"+----+------+",],
&df_results
);
Ok(())
}
#[tokio::test]
async fn test_aggregate_with_pk3() -> Result<()> {
let config = SessionConfig::new().with_target_partitions(1);
let ctx = SessionContext::new_with_config(config);
let df = ctx.read_table(table_with_constraints())?;
let group_expr = vec![col("id")];
let aggr_expr = vec![];
let df = df.aggregate(group_expr, aggr_expr)?;
let predicate = col("id").eq(lit(1i32));
let df = df.filter(predicate)?;
let df = df.select(vec![col("id"), col("name")])?;
assert_physical_plan(
&df,
vec![
"CoalesceBatchesExec: target_batch_size=8192",
" FilterExec: id@0 = 1",
" AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]",
" MemoryExec: partitions=1, partition_sizes=[1]",
],
)
.await;
let df_results = df.collect().await?;
#[rustfmt::skip]
assert_batches_sorted_eq!(
["+----+------+",
"| id | name |",
"+----+------+",
"| 1 | a |",
"+----+------+",],
&df_results
);
Ok(())
}
// Same setup as `test_aggregate_with_pk3`, but only `id` is projected after
// the aggregate: the physical plan's group-by then contains just `id`.
#[tokio::test]
async fn test_aggregate_with_pk4() -> Result<()> {
// Single target partition keeps the expected physical plan deterministic.
let config = SessionConfig::new().with_target_partitions(1);
let ctx = SessionContext::new_with_config(config);
let df = ctx.read_table(table_with_constraints())?;
// GROUP BY the key column with no aggregate expressions.
let group_expr = vec![col("id")];
let aggr_expr = vec![];
let df = df.aggregate(group_expr, aggr_expr)?;
// Filter on the key and project only the grouped column.
let predicate = col("id").eq(lit(1i32));
let df = df.filter(predicate)?;
let df = df.select(vec![col("id")])?;
// Unlike pk3, gby here is only `id` because `name` is never referenced.
assert_physical_plan(
&df,
vec![
"CoalesceBatchesExec: target_batch_size=8192",
" FilterExec: id@0 = 1",
" AggregateExec: mode=Single, gby=[id@0 as id], aggr=[]",
" MemoryExec: partitions=1, partition_sizes=[1]",
],
)
.await;
// Execute and verify the single matching row.
let df_results = df.collect().await?;
#[rustfmt::skip]
assert_batches_sorted_eq!([
"+----+",
"| id |",
"+----+",
"| 1 |",
"+----+",],
&df_results
);
Ok(())
}
// Aliasing a group expression (`c2 + 1`) back to a plain column name (`c2`)
// and then aggregating again must resolve correctly — the second aggregate
// groups on the aliased projection, not the original table column.
#[tokio::test]
async fn test_aggregate_alias() -> Result<()> {
let df = test_table().await?;
let df = df
// First aggregate groups on the computed expression `c2 + 1`.
.aggregate(vec![col("c2") + lit(1)], vec![])?
// Re-expose the computed expression under the simple name `c2`.
.select(vec![(col("c2") + lit(1)).alias("c2")])?
// Second aggregate groups on that alias.
.aggregate(vec![col("c2").alias("c2")], vec![])?;
let df_results = df.collect().await?;
// c2 in the source ranges over 1..=5, so c2 + 1 yields 2..=6.
#[rustfmt::skip]
assert_batches_sorted_eq!([
"+----+",
"| c2 |",
"+----+",
"| 2 |",
"| 3 |",
"| 4 |",
"| 5 |",
"| 6 |",
"+----+",
],
&df_results
);
Ok(())
}
// Selecting expressions built on top of an aggregate's group/aggregate
// expressions: the select re-uses the same expression trees (with aliases)
// and adds further arithmetic on top of them.
#[tokio::test]
async fn test_aggregate_subexpr() -> Result<()> {
let df = test_table().await?;
// Group on `c2 + 1`, aggregate `sum(c3 + 2)`.
let group_expr = col("c2") + lit(1);
let aggr_expr = sum(col("c3") + lit(2));
let df = df
.aggregate(vec![group_expr.clone()], vec![aggr_expr.clone()])?
.select(vec![
// Alias applies to the inner group expression only, so the output
// column is named `c2 + Int32(10)` (see expected header below).
group_expr.alias("c2") + lit(10),
(aggr_expr + lit(20)).alias("sum"),
])?;
let df_results = df.collect().await?;
#[rustfmt::skip]
assert_batches_sorted_eq!([
"+----------------+------+",
"| c2 + Int32(10) | sum |",
"+----------------+------+",
"| 12 | 431 |",
"| 13 | 248 |",
"| 14 | 453 |",
"| 15 | 95 |",
"| 16 | -146 |",
"+----------------+------+",
],
&df_results
);
Ok(())
}
// An alias that textually collides with the display form of a real expression
// (`aggregate_test_100.c2 + aggregate_test_100.c3`) must not shadow the
// underlying columns: selecting the real expression afterwards fails, because
// the aggregate's output schema only contains the aliased literal.
#[tokio::test]
async fn test_aggregate_name_collision() -> Result<()> {
let df = test_table().await?;
let collided_alias = "aggregate_test_100.c2 + aggregate_test_100.c3";
let group_expr = lit(1).alias(collided_alias);
let df = df
.aggregate(vec![group_expr], vec![])?
.select(vec![
(col("aggregate_test_100.c2") + col("aggregate_test_100.c3")),
])
.expect_err("Expected error");
// The only valid field is the quoted alias, not the original columns.
let expected = "Schema error: No field named aggregate_test_100.c2. \
Valid fields are \"aggregate_test_100.c2 + aggregate_test_100.c3\".";
assert_eq!(df.strip_backtrace(), expected);
Ok(())
}
// Selecting an aggregate output (`array_agg(...) AS c`) by name must yield a
// one-field schema whose type is the aggregate's List type, i.e. the select
// resolves to the aggregate column and not the pre-aggregate `c`.
#[tokio::test]
async fn test_select_over_aggregate_schema() -> Result<()> {
let df = test_table()
.await?
// `c` exists both before the aggregate (copy of c1) and after (the
// array_agg alias); the post-aggregate one must win.
.with_column("c", col("c1"))?
.aggregate(vec![], vec![array_agg(col("c")).alias("c")])?
.select(vec![col("c")])?;
assert_eq!(df.schema().fields().len(), 1);
let field = df.schema().field(0);
// array_agg produces a List-typed column.
assert!(matches!(field.data_type(), DataType::List(_)));
Ok(())
}
// `DataFrame::distinct()` must build the same logical plan as the SQL
// `SELECT DISTINCT` form of the query.
#[tokio::test]
async fn test_distinct() -> Result<()> {
    let table = test_table().await?;
    let df_plan = table
        .select(vec![col("c1")])
        .unwrap()
        .distinct()
        .unwrap()
        .plan
        .clone();
    let sql_plan = create_plan("select distinct c1 from aggregate_test_100").await?;
    assert_same_plan(&df_plan, &sql_plan);
    Ok(())
}
// Sorting a DISTINCT result by a column that IS in the select list succeeds
// and returns the five distinct values of c1 in order.
#[tokio::test]
async fn test_distinct_sort_by() -> Result<()> {
let t = test_table().await?;
let plan = t
.select(vec![col("c1")])
.unwrap()
.distinct()
.unwrap()
// `c1` is projected, so ordering by it is allowed for DISTINCT.
.sort(vec![col("c1").sort(true, true)])
.unwrap();
let df_results = plan.clone().collect().await?;
#[rustfmt::skip]
assert_batches_sorted_eq!(
["+----+",
"| c1 |",
"+----+",
"| a |",
"| b |",
"| c |",
"| d |",
"| e |",
"+----+"],
&df_results
);
Ok(())
}
// Sorting a DISTINCT result by a column that is NOT in the select list is a
// planning error (standard SQL DISTINCT/ORDER BY restriction).
#[tokio::test]
async fn test_distinct_sort_by_unprojected() -> Result<()> {
let t = test_table().await?;
let err = t
.select(vec![col("c1")])
.unwrap()
.distinct()
.unwrap()
// `c2` is not projected, so the sort must be rejected at plan time.
.sort(vec![col("c2").sort(true, true)])
.unwrap_err();
assert_eq!(err.strip_backtrace(), "Error during planning: For SELECT DISTINCT, ORDER BY expressions c2 must appear in select list");
Ok(())
}
// `DataFrame::distinct_on` must build the same logical plan as the SQL
// `SELECT DISTINCT ON (...)` form, and produce the distinct c1 values.
#[tokio::test]
async fn test_distinct_on() -> Result<()> {
let t = test_table().await?;
let plan = t
// DISTINCT ON (c1), selecting the fully-qualified c1, no explicit sort.
.distinct_on(vec![col("c1")], vec![col("aggregate_test_100.c1")], None)
.unwrap();
let sql_plan =
create_plan("select distinct on (c1) c1 from aggregate_test_100").await?;
assert_same_plan(&plan.plan.clone(), &sql_plan);
let df_results = plan.clone().collect().await?;
#[rustfmt::skip]
assert_batches_sorted_eq!(
["+----+",
"| c1 |",
"+----+",
"| a |",
"| b |",
"| c |",
"| d |",
"| e |",
"+----+"],
&df_results
);
Ok(())
}
// DISTINCT ON with an internal sort expression plus an outer sort on the same
// (projected) column is accepted and yields the distinct c1 values in order.
#[tokio::test]
async fn test_distinct_on_sort_by() -> Result<()> {
let t = test_table().await?;
let plan = t
.select(vec![col("c1")])
.unwrap()
.distinct_on(
vec![col("c1")],
vec![col("c1")],
// Sort inside the DISTINCT ON picks which row per group survives.
Some(vec![col("c1").sort(true, true)]),
)
.unwrap()
// Outer sort on the projected column is allowed.
.sort(vec![col("c1").sort(true, true)])
.unwrap();
let df_results = plan.clone().collect().await?;
#[rustfmt::skip]
assert_batches_sorted_eq!(
["+----+",
"| c1 |",
"+----+",
"| a |",
"| b |",
"| c |",
"| d |",
"| e |",
"+----+"],
&df_results
);
Ok(())
}
// An outer sort on a column not in the DISTINCT ON select list is rejected.
// Note the error message uses the generic "SELECT DISTINCT" wording even for
// the DISTINCT ON variant.
#[tokio::test]
async fn test_distinct_on_sort_by_unprojected() -> Result<()> {
let t = test_table().await?;
let err = t
.select(vec![col("c1")])
.unwrap()
.distinct_on(
vec![col("c1")],
vec![col("c1")],
Some(vec![col("c1").sort(true, true)]),
)
.unwrap()
// `c2` is not projected by the DISTINCT ON, so this must fail.
.sort(vec![col("c2").sort(true, true)])
.unwrap_err();
assert_eq!(err.strip_backtrace(), "Error during planning: For SELECT DISTINCT, ORDER BY expressions c2 must appear in select list");
Ok(())
}
// Basic equi-join: both sides are the 100-row test table joined on c1; the
// many-to-many match on c1 expands the result to 2008 rows.
#[tokio::test]
async fn join() -> Result<()> {
let left = test_table().await?.select_columns(&["c1", "c2"])?;
// Register the right side under a different table name ("c2") so the two
// sides have distinct qualifiers.
let right = test_table_with_name("c2")
.await?
.select_columns(&["c1", "c3"])?;
let left_rows = left.clone().collect().await?;
let right_rows = right.clone().collect().await?;
let join = left.join(right, JoinType::Inner, &["c1"], &["c1"], None)?;
let join_rows = join.collect().await?;
assert_eq!(100, left_rows.iter().map(|x| x.num_rows()).sum::<usize>());
assert_eq!(100, right_rows.iter().map(|x| x.num_rows()).sum::<usize>());
// c1 has repeated values on both sides, hence 2008 joined rows from 100x100.
assert_eq!(2008, join_rows.iter().map(|x| x.num_rows()).sum::<usize>());
Ok(())
}
// `join_on` accepts arbitrary boolean expressions (here a != plus an =) and
// records them as a join filter in the logical plan.
#[tokio::test]
async fn join_on() -> Result<()> {
let left = test_table_with_name("a")
.await?
.select_columns(&["c1", "c2"])?;
let right = test_table_with_name("b")
.await?
.select_columns(&["c1", "c2"])?;
let join = left.join_on(
right,
JoinType::Inner,
// Mixed conditions: inequality on c1, equality on c2.
[col("a.c1").not_eq(col("b.c1")), col("a.c2").eq(col("b.c2"))],
)?;
// Only the unoptimized logical plan is checked; the join is not executed.
let expected_plan = "Inner Join: Filter: a.c1 != b.c1 AND a.c2 = b.c2\
\n Projection: a.c1, a.c2\
\n TableScan: a\
\n Projection: b.c1, b.c2\
\n TableScan: b";
assert_eq!(expected_plan, format!("{}", join.logical_plan()));
Ok(())
}
// Join conditions must be boolean-typed: a NULL literal is accepted (coerced
// to Boolean(NULL) and pushed down as a filter), but a Utf8 literal "TRUE"
// fails type coercion during optimization.
#[tokio::test]
async fn join_on_filter_datatype() -> Result<()> {
let left = test_table_with_name("a").await?.select_columns(&["c1"])?;
let right = test_table_with_name("b").await?.select_columns(&["c1"])?;
// NULL condition: optimizer turns the join into a cross join with a
// Boolean(NULL) pushed-down filter on the left scan.
let join = left.clone().join_on(
right.clone(),
JoinType::Inner,
Some(Expr::Literal(ScalarValue::Null)),
)?;
let expected_plan = "CrossJoin:\
\n TableScan: a projection=[c1], full_filters=[Boolean(NULL)]\
\n TableScan: b projection=[c1]";
assert_eq!(expected_plan, format!("{}", join.into_optimized_plan()?));
// String condition: lit("TRUE") is Utf8, not Boolean, so optimization errs.
let join = left.join_on(right, JoinType::Inner, Some(lit("TRUE")))?;
let expected = join.into_optimized_plan().unwrap_err();
assert_eq!(
expected.strip_backtrace(),
"type_coercion\ncaused by\nError during planning: Join condition must be boolean type, but got Utf8"
);
Ok(())
}
// An unqualified column reference that exists on both join sides must be
// rejected as ambiguous when building the join condition.
#[tokio::test]
async fn join_ambiguous_filter() -> Result<()> {
let left = test_table_with_name("a")
.await?
.select_columns(&["c1", "c2"])?;
let right = test_table_with_name("b")
.await?
.select_columns(&["c1", "c2"])?;
// `c1` without a qualifier matches both a.c1 and b.c1.
let join = left
.join_on(right, JoinType::Inner, [col("c1").eq(col("c1"))])
.expect_err("join didn't fail check");
let expected = "Schema error: Ambiguous reference to unqualified field c1";
assert_eq!(join.strip_backtrace(), expected);
Ok(())
}
// `DataFrame::limit(0, Some(10))` must build the same logical plan as the
// SQL `LIMIT 10` form of the query.
#[tokio::test]
async fn limit() -> Result<()> {
    let limited = test_table()
        .await?
        .select_columns(&["c1", "c2", "c11"])?
        .limit(0, Some(10))?;
    let expected =
        create_plan("SELECT c1, c2, c11 FROM aggregate_test_100 LIMIT 10").await?;
    assert_same_plan(&limited.plan, &expected);
    Ok(())
}
// `DataFrame::count()` over the 100-row test table returns exactly 100.
#[tokio::test]
async fn df_count() -> Result<()> {
    let rows = test_table().await?.count().await?;
    assert_eq!(rows, 100);
    Ok(())
}
// `DataFrame::explain(false, false)` must build the same logical plan as the
// SQL `EXPLAIN ...` form of the query.
#[tokio::test]
async fn explain() -> Result<()> {
let df = test_table().await?;
let df = df
.select_columns(&["c1", "c2", "c11"])?
.limit(0, Some(10))?
// verbose = false, analyze = false.
.explain(false, false)?;
let plan = df.plan.clone();
let sql_plan =
create_plan("EXPLAIN SELECT c1, c2, c11 FROM aggregate_test_100 LIMIT 10")
.await?;
assert_same_plan(&plan, &sql_plan);
Ok(())
}
// A UDF registered on the SessionContext must be resolvable through the
// DataFrame's function registry, and the resulting plan must match the SQL
// form. The UDF body is `unimplemented!` — the plan is compared, not executed.
#[tokio::test]
async fn registry() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx, "aggregate_test_100").await?;
// Stub implementation: never invoked because the plan is not collected.
let my_fn: ScalarFunctionImplementation =
Arc::new(|_: &[ColumnarValue]| unimplemented!("my_fn is not implemented"));
ctx.register_udf(create_udf(
"my_fn",
vec![DataType::Float64],
Arc::new(DataType::Float64),
Volatility::Immutable,
my_fn,
));
let df = ctx.table("aggregate_test_100").await?;
// Look the UDF up via the DataFrame's registry rather than the context.
let expr = df.registry().udf("my_fn")?.call(vec![col("c12")]);
let df = df.select(vec![expr])?;
let sql_plan = ctx.sql("SELECT my_fn(c12) FROM aggregate_test_100").await?;
assert_same_plan(&df.plan, &sql_plan.plan);
Ok(())
}
// A DataFrame must be Send: moving it into a spawned task and calling a
// method there must compile and complete successfully.
#[tokio::test]
async fn sendable() {
let df = test_table().await.unwrap();
// The move into the async task is the actual Send check.
let task = SpawnedTask::spawn(async move {
df.select_columns(&["c1"])
.expect("should be usable in a task")
});
task.join().await.expect("task completed successfully");
}
// `DataFrame::intersect` must build the same logical plan as the SQL
// `INTERSECT ALL` form of the query (self-intersect of the test table).
#[tokio::test]
async fn intersect() -> Result<()> {
let df = test_table().await?.select_columns(&["c1", "c3"])?;
let d2 = df.clone();
let plan = df.intersect(d2)?;
let result = plan.plan.clone();
let expected = create_plan(
"SELECT c1, c3 FROM aggregate_test_100
INTERSECT ALL SELECT c1, c3 FROM aggregate_test_100",
)
.await?;
assert_same_plan(&result, &expected);
Ok(())
}
// `DataFrame::except` must build the same logical plan as the SQL
// `EXCEPT ALL` form of the query (self-except of the test table).
#[tokio::test]
async fn except() -> Result<()> {
let df = test_table().await?.select_columns(&["c1", "c3"])?;
let d2 = df.clone();
let plan = df.except(d2)?;
let result = plan.plan.clone();
let expected = create_plan(
"SELECT c1, c3 FROM aggregate_test_100
EXCEPT ALL SELECT c1, c3 FROM aggregate_test_100",
)
.await?;
assert_same_plan(&result, &expected);
Ok(())
}
// A DataFrame registered as a view ("test_table") must produce the same
// aggregate results as the original DataFrame; only the qualified column
// name in the output header differs (aggregate_test_100 vs test_table).
#[tokio::test]
async fn register_table() -> Result<()> {
let df = test_table().await?.select_columns(&["c1", "c12"])?;
let ctx = SessionContext::new();
let df_impl = DataFrame::new(ctx.state(), df.plan.clone());
// Register the same plan as a named view and query it back.
ctx.register_table("test_table", df_impl.clone().into_view())?;
let table = ctx.table("test_table").await?;
let group_expr = vec![col("c1")];
let aggr_expr = vec![sum(col("c12"))];
// Aggregate through the original DataFrame...
let df_results = df_impl
.aggregate(group_expr.clone(), aggr_expr.clone())?
.collect()
.await?;
// ...and through the registered view.
let table_results = &table.aggregate(group_expr, aggr_expr)?.collect().await?;
assert_batches_sorted_eq!(
[
"+----+-----------------------------+",
"| c1 | sum(aggregate_test_100.c12) |",
"+----+-----------------------------+",
"| a | 10.238448667882977 |",
"| b | 7.797734760124923 |",
"| c | 13.860958726523545 |",
"| d | 8.793968289758968 |",
"| e | 10.206140546981722 |",
"+----+-----------------------------+"
],
&df_results
);
// Same values; the sum column is qualified by the view name instead.
assert_batches_sorted_eq!(
[
"+----+---------------------+",
"| c1 | sum(test_table.c12) |",
"+----+---------------------+",
"| a | 10.238448667882977 |",
"| b | 7.797734760124923 |",
"| c | 13.860958726523545 |",
"| d | 8.793968289758968 |",
"| e | 10.206140546981722 |",
"+----+---------------------+"
],
table_results
);
Ok(())
}
// Two logical plans are considered identical when their Debug renderings
// compare equal.
fn assert_same_plan(plan1: &LogicalPlan, plan2: &LogicalPlan) {
    let left = format!("{plan1:?}");
    let right = format!("{plan2:?}");
    assert_eq!(left, right);
}
// Helper: plan `sql` against a fresh context with the aggregate_test_100 CSV
// registered, returning the UNoptimized logical plan for comparison.
async fn create_plan(sql: &str) -> Result<LogicalPlan> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx, "aggregate_test_100").await?;
Ok(ctx.sql(sql).await?.into_unoptimized_plan())
}
// `with_column` in three modes: appending a brand-new column, overwriting an
// existing column with a different expression, and overwriting a column with
// an expression that references the column itself.
#[tokio::test]
async fn with_column() -> Result<()> {
let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?;
let ctx = SessionContext::new();
let df_impl = DataFrame::new(ctx.state(), df.plan.clone());
// Append a new column `sum = c2 + c3`.
let df = df_impl
.filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))?
.with_column("sum", col("c2") + col("c3"))?;
let df_results = df.clone().collect().await?;
assert_batches_sorted_eq!(
[
"+----+----+-----+-----+",
"| c1 | c2 | c3 | sum |",
"+----+----+-----+-----+",
"| a | 3 | -12 | -9 |",
"| a | 3 | -72 | -69 |",
"| a | 3 | 13 | 16 |",
"| a | 3 | 13 | 16 |",
"| a | 3 | 14 | 17 |",
"| a | 3 | 17 | 20 |",
"+----+----+-----+-----+"
],
&df_results
);
// Overwrite existing column `c1` with a different expression; the column
// keeps its position but takes the new values.
let df_results_overwrite = df
.clone()
.with_column("c1", col("c2") + col("c3"))?
.collect()
.await?;
assert_batches_sorted_eq!(
[
"+-----+----+-----+-----+",
"| c1 | c2 | c3 | sum |",
"+-----+----+-----+-----+",
"| -69 | 3 | -72 | -69 |",
"| -9 | 3 | -12 | -9 |",
"| 16 | 3 | 13 | 16 |",
"| 16 | 3 | 13 | 16 |",
"| 17 | 3 | 14 | 17 |",
"| 20 | 3 | 17 | 20 |",
"+-----+----+-----+-----+"
],
&df_results_overwrite
);
// Overwrite `c2` using an expression that reads `c2` itself (c2 + 1);
// note `sum` still reflects the original c2 + c3 values.
let df_results_overwrite_self = df
.clone()
.with_column("c2", col("c2") + lit(1))?
.collect()
.await?;
assert_batches_sorted_eq!(
[
"+----+----+-----+-----+",
"| c1 | c2 | c3 | sum |",
"+----+----+-----+-----+",
"| a | 4 | -12 | -9 |",
"| a | 4 | -72 | -69 |",
"| a | 4 | 13 | 16 |",
"| a | 4 | 13 | 16 |",
"| a | 4 | 14 | 17 |",
"| a | 4 | 17 | 20 |",
"+----+----+-----+-----+"
],
&df_results_overwrite_self
);
Ok(())
}
// `with_column` after a join whose two sides expose the same column name
// (t1.c1 / t2.c1): adding a new, uniquely-named column succeeds because the
// existing columns stay fully qualified in the projection.
#[tokio::test]
async fn with_column_join_same_columns() -> Result<()> {
let df = test_table().await?.select_columns(&["c1"])?;
let ctx = SessionContext::new();
// Register the same view twice under different names to join it to itself.
let table = df.into_view();
ctx.register_table("t1", table.clone())?;
ctx.register_table("t2", table)?;
let df = ctx
.table("t1")
.await?
.join(
ctx.table("t2").await?,
JoinType::Inner,
&["c1"],
&["c1"],
None,
)?
.sort(vec![
col("t1.c1").sort(true, true),
])?
.limit(0, Some(1))?;
let df_results = df.clone().collect().await?;
assert_batches_sorted_eq!(
[
"+----+----+",
"| c1 | c1 |",
"+----+----+",
"| a | a |",
"+----+----+",
],
&df_results
);
// Adding a distinct new column name works despite the duplicate c1 pair.
let df_with_column = df.clone().with_column("new_column", lit(true))?;
assert_eq!(
"\
Projection: t1.c1, t2.c1, Boolean(true) AS new_column\
\n Limit: skip=0, fetch=1\
\n Sort: t1.c1 ASC NULLS FIRST\
\n Inner Join: t1.c1 = t2.c1\
\n TableScan: t1\
\n TableScan: t2",
format!("{}", df_with_column.logical_plan())
);
// Optimized plan: views expand to SubqueryAlias over the CSV scan, and the
// limit is pushed into the sort as fetch=1.
assert_eq!(
"\
Projection: t1.c1, t2.c1, Boolean(true) AS new_column\
\n Limit: skip=0, fetch=1\
\n Sort: t1.c1 ASC NULLS FIRST, fetch=1\
\n Inner Join: t1.c1 = t2.c1\
\n SubqueryAlias: t1\
\n TableScan: aggregate_test_100 projection=[c1]\
\n SubqueryAlias: t2\
\n TableScan: aggregate_test_100 projection=[c1]",
format!("{}", df_with_column.clone().into_optimized_plan()?)
);
let df_results = df_with_column.collect().await?;
assert_batches_sorted_eq!(
[
"+----+----+------------+",
"| c1 | c1 | new_column |",
"+----+----+------------+",
"| a | a | true |",
"+----+----+------------+",
],
&df_results
);
Ok(())
}
// `with_column` after a true self-join (same table name on both sides): both
// join columns are qualified as t1.c1, so the implied projection would have
// duplicate names and must fail with a planning error.
#[tokio::test]
async fn with_column_self_join() -> Result<()> {
let df = test_table().await?.select_columns(&["c1"])?;
let ctx = SessionContext::new();
ctx.register_table("t1", df.into_view())?;
// Join t1 to itself — both sides carry the t1 qualifier.
let df = ctx
.table("t1")
.await?
.join(
ctx.table("t1").await?,
JoinType::Inner,
&["c1"],
&["c1"],
None,
)?
.sort(vec![
col("t1.c1").sort(true, true),
])?
.limit(0, Some(1))?;
// The join itself executes fine...
let df_results = df.clone().collect().await?;
assert_batches_sorted_eq!(
[
"+----+----+",
"| c1 | c1 |",
"+----+----+",
"| a | a |",
"+----+----+",
],
&df_results
);
// ...but with_column cannot build a projection containing t1.c1 twice.
let actual_err = df.clone().with_column("new_column", lit(true)).unwrap_err();
let expected_err = "Error during planning: Projections require unique expression names \
but the expression \"t1.c1\" at position 0 and \"t1.c1\" at position 1 have the same name. \
Consider aliasing (\"AS\") one of them.";
assert_eq!(actual_err.strip_backtrace(), expected_err);
Ok(())
}
// `with_column_renamed` by bare name, by fully-qualified name, and for a
// column that does not exist ("c4") — the last is a silent no-op rather
// than an error (no "boom" column appears in the output).
#[tokio::test]
async fn with_column_renamed() -> Result<()> {
let df = test_table()
.await?
.select_columns(&["c1", "c2", "c3"])?
.filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))?
.limit(0, Some(1))?
.sort(vec![
col("c1").sort(true, true),
col("c2").sort(true, true),
col("c3").sort(true, true),
])?
.with_column("sum", col("c2") + col("c3"))?;
let df_sum_renamed = df
// Rename the derived column and two source columns; `c2` is addressed
// by its fully-qualified name here.
.with_column_renamed("sum", "total")?
.with_column_renamed("c1", "one")?
.with_column_renamed("aggregate_test_100.c2", "two")?
// Nonexistent column: succeeds but changes nothing.
.with_column_renamed("c4", "boom")?
.collect()
.await?;
assert_batches_sorted_eq!(
[
"+-----+-----+----+-------+",
"| one | two | c3 | total |",
"+-----+-----+----+-------+",
"| a | 3 | 13 | 16 |",
"+-----+-----+----+-------+"
],
&df_sum_renamed
);
Ok(())
}
// Renaming by an unqualified name that matches columns on both sides of a
// join must fail with an ambiguity error.
#[tokio::test]
async fn with_column_renamed_ambiguous() -> Result<()> {
let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?;
let ctx = SessionContext::new();
// Same view registered under two names, then joined.
let table = df.into_view();
ctx.register_table("t1", table.clone())?;
ctx.register_table("t2", table)?;
let actual_err = ctx
.table("t1")
.await?
.join(
ctx.table("t2").await?,
JoinType::Inner,
&["c1"],
&["c1"],
None,
)?
// Unqualified "c2" matches both t1.c2 and t2.c2.
.with_column_renamed("c2", "AAA")
.unwrap_err();
let expected_err = "Schema error: Ambiguous reference to unqualified field c2";
assert_eq!(actual_err.strip_backtrace(), expected_err);
Ok(())
}
// Renaming by a fully-qualified name ("t1.c1") after a join disambiguates
// correctly: only the left-side column is renamed (to AAA), the right side's
// c1 is untouched, and the rename appears as a projection alias in the plan.
#[tokio::test]
async fn with_column_renamed_join() -> Result<()> {
let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?;
let ctx = SessionContext::new();
let table = df.into_view();
ctx.register_table("t1", table.clone())?;
ctx.register_table("t2", table)?;
let df = ctx
.table("t1")
.await?
.join(
ctx.table("t2").await?,
JoinType::Inner,
&["c1"],
&["c1"],
None,
)?
// Sort on every column to make the limit-1 result deterministic.
.sort(vec![
col("t1.c1").sort(true, true),
col("t1.c2").sort(true, true),
col("t1.c3").sort(true, true),
col("t2.c1").sort(true, true),
col("t2.c2").sort(true, true),
col("t2.c3").sort(true, true),
])?
.limit(0, Some(1))?;
let df_results = df.clone().collect().await?;
assert_batches_sorted_eq!(
[
"+----+----+-----+----+----+-----+",
"| c1 | c2 | c3 | c1 | c2 | c3 |",
"+----+----+-----+----+----+-----+",
"| a | 1 | -85 | a | 1 | -85 |",
"+----+----+-----+----+----+-----+"
],
&df_results
);
// Qualified rename targets only t1.c1.
let df_renamed = df.clone().with_column_renamed("t1.c1", "AAA")?;
assert_eq!("\
Projection: t1.c1 AS AAA, t1.c2, t1.c3, t2.c1, t2.c2, t2.c3\
\n Limit: skip=0, fetch=1\
\n Sort: t1.c1 ASC NULLS FIRST, t1.c2 ASC NULLS FIRST, t1.c3 ASC NULLS FIRST, t2.c1 ASC NULLS FIRST, t2.c2 ASC NULLS FIRST, t2.c3 ASC NULLS FIRST\
\n Inner Join: t1.c1 = t2.c1\
\n TableScan: t1\
\n TableScan: t2",
format!("{}", df_renamed.logical_plan())
);
// Optimized: views expand to SubqueryAlias scans, limit pushed into sort.
assert_eq!("\
Projection: t1.c1 AS AAA, t1.c2, t1.c3, t2.c1, t2.c2, t2.c3\
\n Limit: skip=0, fetch=1\
\n Sort: t1.c1 ASC NULLS FIRST, t1.c2 ASC NULLS FIRST, t1.c3 ASC NULLS FIRST, t2.c1 ASC NULLS FIRST, t2.c2 ASC NULLS FIRST, t2.c3 ASC NULLS FIRST, fetch=1\
\n Inner Join: t1.c1 = t2.c1\
\n SubqueryAlias: t1\
\n TableScan: aggregate_test_100 projection=[c1, c2, c3]\
\n SubqueryAlias: t2\
\n TableScan: aggregate_test_100 projection=[c1, c2, c3]",
format!("{}", df_renamed.clone().into_optimized_plan()?)
);
let df_results = df_renamed.collect().await?;
assert_batches_sorted_eq!(
[
"+-----+----+-----+----+----+-----+",
"| AAA | c2 | c3 | c1 | c2 | c3 |",
"+-----+----+-----+----+----+-----+",
"| a | 1 | -85 | a | 1 | -85 |",
"+-----+----+-----+----+----+-----+"
],
&df_results
);
Ok(())
}
// With SQL identifier normalization disabled, a mixed-case rename target
// ("CoLuMn1") keeps its exact casing and can be renamed back round-trip.
#[tokio::test]
async fn with_column_renamed_case_sensitive() -> Result<()> {
// Turn off ident normalization so identifiers are treated case-sensitively.
let config =
SessionConfig::from_string_hash_map(std::collections::HashMap::from([(
"datafusion.sql_parser.enable_ident_normalization".to_owned(),
"false".to_owned(),
)]))?;
let ctx = SessionContext::new_with_config(config);
let name = "aggregate_test_100";
register_aggregate_csv(&ctx, name).await?;
let df = ctx.table(name);
let df = df
.await?
.filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))?
.limit(0, Some(1))?
.sort(vec![
col("c1").sort(true, true),
col("c2").sort(true, true),
col("c3").sort(true, true),
])?
.select_columns(&["c1"])?;
// Rename to a mixed-case name; the header must preserve the casing.
let df_renamed = df.clone().with_column_renamed("c1", "CoLuMn1")?;
let res = &df_renamed.clone().collect().await?;
assert_batches_sorted_eq!(
[
"+---------+",
"| CoLuMn1 |",
"+---------+",
"| a |",
"+---------+"
],
res
);
// Rename back by the exact mixed-case name.
let df_renamed = df_renamed
.with_column_renamed("CoLuMn1", "c1")?
.collect()
.await?;
assert_batches_sorted_eq!(
["+----+", "| c1 |", "+----+", "| a |", "+----+"],
&df_renamed
);
Ok(())
}
// `with_column` combined with a `cast` expression: c2 + c3 cast to Int64
// becomes a new `sum` column; also exercises `show()` on the same frame.
#[tokio::test]
async fn cast_expr_test() -> Result<()> {
let df = test_table()
.await?
.select_columns(&["c2", "c3"])?
.limit(0, Some(1))?
.with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?;
let df_results = df.clone().collect().await?;
// show() is called purely to exercise the display path; output not checked.
df.clone().show().await?;
assert_batches_sorted_eq!(
[
"+----+----+-----+",
"| c2 | c3 | sum |",
"+----+----+-----+",
"| 2 | 1 | 3 |",
"+----+----+-----+"
],
&df_results
);
Ok(())
}
// Regression-style test: GROUP BY on very long string keys forces the
// row writer to resize its buffers; the test only requires that execution
// completes without error (no output assertion).
#[tokio::test]
async fn row_writer_resize_test() -> Result<()> {
let schema = Arc::new(Schema::new(vec![arrow::datatypes::Field::new(
"column_1",
DataType::Utf8,
false,
)]));
// Two rows with strings far longer than a default-sized row buffer.
let data = RecordBatch::try_new(
schema,
vec![
Arc::new(arrow::array::StringArray::from(vec![
Some("2a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"),
Some("3a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800"),
]))
],
)?;
let ctx = SessionContext::new();
ctx.register_batch("test", data)?;
let sql = r#"
SELECT
count(1)
FROM
test
GROUP BY
column_1"#;
let df = ctx.sql(sql).await?;
// Success of show_limit is the assertion; results are not compared.
df.show_limit(10).await?;
Ok(())
}
// Column names containing a dot ("f.c1", "f.c2") are treated as literal
// names, not as qualifier.column references, by `with_column`.
#[tokio::test]
async fn with_column_name() -> Result<()> {
let array: Int32Array = [1, 10].into_iter().collect();
// Source batch already has a dotted column name.
let batch = RecordBatch::try_from_iter(vec![("f.c1", Arc::new(array) as _)])?;
let ctx = SessionContext::new();
ctx.register_batch("t", batch)?;
let df = ctx
.table("t")
.await?
// The dotted name must create a column literally called "f.c2".
.with_column("f.c2", lit("hello"))?;
let df_results = df.collect().await?;
assert_batches_sorted_eq!(
[
"+------+-------+",
"| f.c1 | f.c2 |",
"+------+-------+",
"| 1 | hello |",
"| 10 | hello |",
"+------+-------+"
],
&df_results
);
Ok(())
}
// Caching a query whose output type comes from a CASE producing NULL must
// not fail with a schema mismatch between plan and materialized batches.
#[tokio::test]
async fn test_cache_mismatch() -> Result<()> {
    let ctx = SessionContext::new();
    let df = ctx
        .sql("SELECT CASE WHEN true THEN NULL ELSE 1 END")
        .await?;
    assert!(df.cache().await.is_ok());
    Ok(())
}
// `DataFrame::cache()` materializes the result: the cached frame's optimized
// plan collapses to a single in-memory table scan, and its collected batches
// equal the original frame's.
#[tokio::test]
async fn cache_test() -> Result<()> {
let df = test_table()
.await?
.select_columns(&["c2", "c3"])?
.limit(0, Some(1))?
.with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?;
let cached_df = df.clone().cache().await?;
// The cached plan is just a scan of the anonymous materialized table.
assert_eq!(
"TableScan: ?table? projection=[c2, c3, sum]",
format!("{}", cached_df.clone().into_optimized_plan()?)
);
let df_results = df.collect().await?;
let cached_df_results = cached_df.collect().await?;
assert_batches_sorted_eq!(
[
"+----+----+-----+",
"| c2 | c3 | sum |",
"+----+----+-----+",
"| 2 | 1 | 3 |",
"+----+----+-----+"
],
&cached_df_results
);
// Cached and uncached results must be identical batch-for-batch.
assert_eq!(&df_results, &cached_df_results);
Ok(())
}
// Union of two joins that share identical join keys: the physical union
// preserves the children's partitioning — the output partition count equals
// the default target partitions and every child matches the parent.
#[tokio::test]
async fn partition_aware_union() -> Result<()> {
let left = test_table().await?.select_columns(&["c1", "c2"])?;
// Right side renamed so its join key has a distinct name (c2_c1).
let right = test_table_with_name("c2")
.await?
.select_columns(&["c1", "c3"])?
.with_column_renamed("c2.c1", "c2_c1")?;
let left_rows = left.clone().collect().await?;
let right_rows = right.clone().collect().await?;
// Two structurally identical joins, then unioned.
let join1 = left.clone().join(
right.clone(),
JoinType::Inner,
&["c1"],
&["c2_c1"],
None,
)?;
let join2 = left.join(right, JoinType::Inner, &["c1"], &["c2_c1"], None)?;
let union = join1.union(join2)?;
let union_rows = union.clone().collect().await?;
assert_eq!(100, left_rows.iter().map(|x| x.num_rows()).sum::<usize>());
assert_eq!(100, right_rows.iter().map(|x| x.num_rows()).sum::<usize>());
// 2 x 2008 rows: the union duplicates the single-join result.
assert_eq!(4016, union_rows.iter().map(|x| x.num_rows()).sum::<usize>());
let physical_plan = union.create_physical_plan().await?;
let default_partition_count = SessionConfig::new().target_partitions();
// Partition-aware union: output keeps the children's partition count...
assert_eq!(
physical_plan.output_partitioning().partition_count(),
default_partition_count
);
// ...and each child's partitioning equals the union's.
for child in physical_plan.children() {
assert_eq!(
physical_plan.output_partitioning(),
child.output_partitioning()
);
}
Ok(())
}
// Union of two joins whose key ORDER differs ([c1,c2] vs [c2,c1]): the
// children's hash partitionings are incompatible, so the union's output is
// UnknownPartitioning with the combined (2x) partition count.
#[tokio::test]
async fn non_partition_aware_union() -> Result<()> {
let left = test_table().await?.select_columns(&["c1", "c2"])?;
let right = test_table_with_name("c2")
.await?
.select_columns(&["c1", "c2"])?
.with_column_renamed("c2.c1", "c2_c1")?
.with_column_renamed("c2.c2", "c2_c2")?;
let left_rows = left.clone().collect().await?;
let right_rows = right.clone().collect().await?;
// First join keys: (c1, c2).
let join1 = left.clone().join(
right.clone(),
JoinType::Inner,
&["c1", "c2"],
&["c2_c1", "c2_c2"],
None,
)?;
// Second join: same key set but reversed order, so a different hash layout.
let join2 = left.join(
right,
JoinType::Inner,
&["c2", "c1"],
&["c2_c2", "c2_c1"],
None,
)?;
let union = join1.union(join2)?;
let union_rows = union.clone().collect().await?;
assert_eq!(100, left_rows.iter().map(|x| x.num_rows()).sum::<usize>());
assert_eq!(100, right_rows.iter().map(|x| x.num_rows()).sum::<usize>());
assert_eq!(916, union_rows.iter().map(|x| x.num_rows()).sum::<usize>());
let physical_plan = union.create_physical_plan().await?;
let default_partition_count = SessionConfig::new().target_partitions();
// Union cannot align the two hash partitionings, so it reports Unknown
// with the sum of both children's partitions.
assert!(matches!(
physical_plan.output_partitioning(),
Partitioning::UnknownPartitioning(partition_count) if *partition_count == default_partition_count * 2));
Ok(())
}
// For every join type, verify which side's hash partitioning the join output
// reports: left-preserving types hash on the left keys, right-preserving
// types on the right keys, and Full joins report UnknownPartitioning.
#[tokio::test]
async fn verify_join_output_partitioning() -> Result<()> {
let left = test_table().await?.select_columns(&["c1", "c2"])?;
let right = test_table_with_name("c2")
.await?
.select_columns(&["c1", "c2"])?
.with_column_renamed("c2.c1", "c2_c1")?
.with_column_renamed("c2.c2", "c2_c2")?;
let all_join_types = vec![
JoinType::Inner,
JoinType::Left,
JoinType::Right,
JoinType::Full,
JoinType::LeftSemi,
JoinType::RightSemi,
JoinType::LeftAnti,
JoinType::RightAnti,
];
let default_partition_count = SessionConfig::new().target_partitions();
for join_type in all_join_types {
let join = left.clone().join(
right.clone(),
join_type,
&["c1", "c2"],
&["c2_c1", "c2_c2"],
None,
)?;
let physical_plan = join.create_physical_plan().await?;
let out_partitioning = physical_plan.output_partitioning();
let join_schema = physical_plan.schema();
match join_type {
// Left-preserving joins: output hashed on the left-side keys.
JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => {
let left_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
Arc::new(Column::new_with_schema("c1", &join_schema)?),
Arc::new(Column::new_with_schema("c2", &join_schema)?),
];
assert_eq!(
out_partitioning,
&Partitioning::Hash(left_exprs, default_partition_count)
);
}
// Inner and right-preserving joins: hashed on the right-side keys.
JoinType::Inner
| JoinType::Right
| JoinType::RightSemi
| JoinType::RightAnti => {
let right_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
Arc::new(Column::new_with_schema("c2_c1", &join_schema)?),
Arc::new(Column::new_with_schema("c2_c2", &join_schema)?),
];
assert_eq!(
out_partitioning,
&Partitioning::Hash(right_exprs, default_partition_count)
);
}
// Full joins preserve neither side's partitioning.
JoinType::Full => {
assert!(matches!(
out_partitioning,
&Partitioning::UnknownPartitioning(partition_count) if partition_count == default_partition_count));
}
}
}
Ok(())
}
// EXCEPT over a table containing a nested struct column: two batches that
// differ in one row of the scalar `value` column (2 vs 12, struct identical)
// must leave exactly one batch of difference rows.
#[tokio::test]
async fn test_except_nested_struct() -> Result<()> {
use arrow::array::StructArray;
let nested_schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, true),
Field::new("lat", DataType::Int32, true),
Field::new("long", DataType::Int32, true),
]));
let schema = Arc::new(Schema::new(vec![
Field::new("value", DataType::Int32, true),
Field::new(
"nested",
DataType::Struct(nested_schema.fields.clone()),
true,
),
]));
// Original data: values 1, 2, 3 with identical struct contents per row.
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])),
Arc::new(StructArray::from(vec![
(
Arc::new(Field::new("id", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
(
Arc::new(Field::new("lat", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
(
Arc::new(Field::new("long", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
])),
],
)
.unwrap();
// Updated data: only the middle `value` changes (2 -> 12).
let updated_batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int32Array::from(vec![Some(1), Some(12), Some(3)])),
Arc::new(StructArray::from(vec![
(
Arc::new(Field::new("id", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
(
Arc::new(Field::new("lat", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
(
Arc::new(Field::new("long", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
])),
],
)
.unwrap();
let ctx = SessionContext::new();
let before = ctx.read_batch(batch).expect("Failed to make DataFrame");
let after = ctx
.read_batch(updated_batch)
.expect("Failed to make DataFrame");
let diff = before
.except(after)
.expect("Failed to except")
.collect()
.await?;
// Note: this counts result BATCHES (one batch of differing rows), not rows.
assert_eq!(diff.len(), 1);
Ok(())
}
// Nesting EXPLAIN is rejected both ways: calling `explain()` on a plan that
// is already an EXPLAIN, and writing `EXPLAIN EXPLAIN` in SQL.
#[tokio::test]
async fn nested_explain_should_fail() -> Result<()> {
    let ctx = SessionContext::new();
    // Programmatic API: explain-of-an-explain is an error.
    let via_api = ctx.sql("explain select 1").await?.explain(false, false);
    assert!(via_api.is_err());
    // SQL front end: nested EXPLAIN fails while planning the statement.
    let via_sql = ctx.sql("explain explain select 1").await;
    assert!(via_sql.is_err());
    Ok(())
}
}