use std::sync::Arc;
use apply_schema::{apply_schema, apply_schema_to};
use evaluate_expression::{evaluate_expression, evaluate_predicate, extract_column};
use itertools::Itertools;
use tracing::debug;
use super::arrow_conversion::{TryFromKernel as _, TryIntoArrow as _};
use crate::arrow::array::{self, ArrayBuilder, ArrayRef, RecordBatch, StructArray};
use crate::arrow::datatypes::{
DataType as ArrowDataType, Field as ArrowField, Schema as ArrowSchema,
};
use crate::engine::arrow_data::{extract_record_batch, ArrowEngineData};
use crate::error::{DeltaResult, Error};
use crate::expressions::{ArrayData, Expression, ExpressionRef, PredicateRef, Scalar};
use crate::schema::{DataType, PrimitiveType, SchemaRef};
use crate::utils::require;
use crate::{EngineData, EvaluationHandler, ExpressionEvaluator, PredicateEvaluator};
mod apply_schema;
pub mod evaluate_expression;
pub mod opaque;
#[cfg(test)]
mod tests;
impl Scalar {
/// Materializes this scalar as an Arrow array with `num_rows` entries, each holding this value.
///
/// The Arrow data type is derived from the scalar's kernel data type and a matching builder is
/// obtained from `array::make_builder`, pre-sized for `num_rows`.
///
/// # Errors
///
/// Fails if the kernel type cannot be converted to an Arrow type, or if `append_to` rejects the
/// builder (see below).
pub fn to_array(&self, num_rows: usize) -> DeltaResult<ArrayRef> {
let data_type = ArrowDataType::try_from_kernel(&self.data_type())?;
let mut builder = array::make_builder(&data_type, num_rows);
self.append_to(&mut builder, num_rows)?;
Ok(builder.finish())
}
/// Appends `num_rows` copies of this scalar to `builder`.
///
/// `builder` is type-erased, so it is downcast to the concrete builder type that corresponds to
/// this scalar's variant. Nested values (struct fields, list elements, map keys/values) recurse
/// into the child builders.
///
/// # Errors
///
/// * `Error::invalid_expression` ("Invalid builder for ...") if `builder` is not of the concrete
///   type expected for this scalar.
/// * `Error::generic` if a struct builder's field count differs from the struct scalar's.
/// * Any error produced by a nested append or by `MapBuilder::append`.
fn append_to(&self, builder: &mut dyn ArrayBuilder, num_rows: usize) -> DeltaResult<()> {
use Scalar::*;
// Downcasts `builder` to the concrete type `$t`; the trailing `?` returns from `append_to`
// with an "Invalid builder" error when the downcast fails.
macro_rules! builder_as {
($t:ty) => {{
builder.as_any_mut().downcast_mut::<$t>().ok_or_else(|| {
Error::invalid_expression(format!("Invalid builder for {}", self.data_type()))
})?
}};
}
// Appends `$val` `num_rows` times through a single `append_value_n` call.
macro_rules! append_val_n_as {
($t:ty, $val:expr) => {{
let builder = builder_as!($t);
builder.append_value_n($val, num_rows);
}};
}
// Appends `$val` one row at a time; used below for the string and binary builders.
macro_rules! append_val_as {
($t:ty, $val:expr) => {{
let builder = builder_as!($t);
for _ in 0..num_rows {
builder.append_value($val);
}
}};
}
match self {
Integer(val) => append_val_n_as!(array::Int32Builder, *val),
Long(val) => append_val_n_as!(array::Int64Builder, *val),
Short(val) => append_val_n_as!(array::Int16Builder, *val),
Byte(val) => append_val_n_as!(array::Int8Builder, *val),
Float(val) => append_val_n_as!(array::Float32Builder, *val),
Double(val) => append_val_n_as!(array::Float64Builder, *val),
String(val) => append_val_as!(array::StringBuilder, val),
// `BooleanBuilder::append_n` takes the repeat count first, then the value.
Boolean(val) => builder_as!(array::BooleanBuilder).append_n(num_rows, *val),
// Both timestamp variants are written through the same microsecond builder.
Timestamp(val) | TimestampNtz(val) => {
append_val_n_as!(array::TimestampMicrosecondBuilder, *val)
}
Date(val) => append_val_n_as!(array::Date32Builder, *val),
Binary(val) => append_val_as!(array::BinaryBuilder, val),
Decimal(val) => append_val_n_as!(array::Decimal128Builder, val.bits()),
Struct(data) => {
let builder = builder_as!(array::StructBuilder);
// `zip` below would silently truncate on a length mismatch, so check up front.
require!(
builder.num_fields() == data.fields().len(),
Error::generic("Struct builder has wrong number of fields")
);
// Fill every child builder with `num_rows` copies of its field value first...
let field_builders = builder.field_builders_mut().iter_mut();
for (builder, value) in field_builders.zip(data.values()) {
value.append_to(builder, num_rows)?;
}
// ...then mark `num_rows` struct slots as valid so parent and children stay the same
// length.
for _ in 0..num_rows {
builder.append(true);
}
}
Array(data) => {
let builder = builder_as!(array::ListBuilder<Box<dyn ArrayBuilder>>);
// One list entry per row: push every element into the values builder, then close the
// entry as non-null.
for _ in 0..num_rows {
for value in data.array_elements() {
value.append_to(builder.values(), 1)?;
}
builder.append(true);
}
}
Map(data) => {
let builder =
builder_as!(array::MapBuilder<Box<dyn ArrayBuilder>, Box<dyn ArrayBuilder>>);
// One map entry per row: push each key/value pair into the key and value builders,
// then close the entry as non-null (`append` is fallible here, hence the `?`).
for _ in 0..num_rows {
for (key, val) in data.pairs() {
key.append_to(builder.keys(), 1)?;
val.append_to(builder.values(), 1)?;
}
builder.append(true)?;
}
}
// Typed nulls are handled separately because there is no value to dispatch on.
Null(data_type) => Self::append_null(builder, data_type, num_rows)?,
}
Ok(())
}
/// Appends `num_rows` null entries of kernel type `data_type` to `builder`.
///
/// The concrete builder type is chosen from `data_type`, mirroring the variant-to-builder
/// mapping used by `append_to`.
///
/// # Errors
///
/// * `Error::invalid_expression` ("Invalid builder for ...") if `builder` is not of the concrete
///   type expected for `data_type`.
/// * `Error::generic` if a struct builder's field count differs from the struct type's.
/// * `Error::unsupported` for `DataType::Variant`.
fn append_null(
builder: &mut dyn ArrayBuilder,
data_type: &DataType,
num_rows: usize,
) -> DeltaResult<()> {
// Downcasts `builder` to the concrete type `$t`; the trailing `?` returns from `append_null`
// with an "Invalid builder" error when the downcast fails.
macro_rules! builder_as {
($t:ty) => {{
builder.as_any_mut().downcast_mut::<$t>().ok_or_else(|| {
Error::invalid_expression(format!("Invalid builder for {data_type}"))
})?
}};
}
// Appends `num_rows` nulls through a single `append_nulls` call.
macro_rules! append_nulls_as {
($t:ty) => {{
let builder = builder_as!($t);
builder.append_nulls(num_rows);
}};
}
match *data_type {
DataType::INTEGER => append_nulls_as!(array::Int32Builder),
DataType::LONG => append_nulls_as!(array::Int64Builder),
DataType::SHORT => append_nulls_as!(array::Int16Builder),
DataType::BYTE => append_nulls_as!(array::Int8Builder),
DataType::FLOAT => append_nulls_as!(array::Float32Builder),
DataType::DOUBLE => append_nulls_as!(array::Float64Builder),
DataType::STRING => append_nulls_as!(array::StringBuilder),
DataType::BOOLEAN => append_nulls_as!(array::BooleanBuilder),
DataType::TIMESTAMP | DataType::TIMESTAMP_NTZ => {
append_nulls_as!(array::TimestampMicrosecondBuilder)
}
DataType::DATE => append_nulls_as!(array::Date32Builder),
DataType::BINARY => append_nulls_as!(array::BinaryBuilder),
DataType::Primitive(PrimitiveType::Decimal(_)) => {
append_nulls_as!(array::Decimal128Builder)
}
DataType::Struct(ref stype) => {
let builder = builder_as!(array::StructBuilder);
// `zip` below would silently truncate on a length mismatch, so check up front.
require!(
builder.num_fields() == stype.num_fields(),
Error::generic("Struct builder has wrong number of fields")
);
// Every child builder also receives `num_rows` nulls, so children and parent end up
// the same length even though the struct slots themselves are null.
let field_builders = builder.field_builders_mut().iter_mut();
for (builder, field) in field_builders.zip(stype.fields()) {
Self::append_null(builder, &field.data_type, num_rows)?;
}
builder.append_nulls(num_rows);
}
DataType::Array(_) => append_nulls_as!(array::ListBuilder<Box<dyn ArrayBuilder>>),
DataType::Map(_) => {
let builder =
builder_as!(array::MapBuilder<Box<dyn ArrayBuilder>, Box<dyn ArrayBuilder>>);
// A null map entry is written as `append(false)` with no keys or values pushed.
// NOTE(review): presumably done row by row because `MapBuilder` offers no bulk-null
// method -- confirm against the Arrow version in use.
for _ in 0..num_rows {
builder.append(false)?;
}
}
DataType::Variant(_) => {
return Err(Error::unsupported(
"Variant is not supported as scalar yet.",
));
}
}
Ok(())
}
}
impl ArrayData {
    /// Converts the elements of this array literal into a flat Arrow array.
    ///
    /// The Arrow element type is derived from the kernel element type of `array_type()`, and each
    /// element is appended exactly once to a builder of that type.
    ///
    /// # Errors
    ///
    /// Fails if the element type has no Arrow equivalent or if any element cannot be appended to
    /// the builder.
    pub fn to_arrow(&self) -> DeltaResult<ArrayRef> {
        let element_type = ArrowDataType::try_from_kernel(self.array_type().element_type())?;
        let values = self.array_elements();
        let mut values_builder = array::make_builder(&element_type, values.len());
        // Stops at the first element that fails to append.
        values
            .iter()
            .try_for_each(|value| value.append_to(&mut values_builder, 1))?;
        Ok(values_builder.finish())
    }
}
/// Arrow-backed [`EvaluationHandler`]: creates expression and predicate evaluators and
/// materializes rows of scalars as Arrow record batches. Carries no state of its own.
#[derive(Debug)]
pub struct ArrowEvaluationHandler;
impl EvaluationHandler for ArrowEvaluationHandler {
    /// Wraps `expression` in a [`DefaultExpressionEvaluator`] that produces `output_type`.
    ///
    /// The input `schema` is stored on the evaluator; construction itself never fails.
    fn new_expression_evaluator(
        &self,
        schema: SchemaRef,
        expression: ExpressionRef,
        output_type: DataType,
    ) -> DeltaResult<Arc<dyn ExpressionEvaluator>> {
        let evaluator = DefaultExpressionEvaluator {
            _input_schema: schema,
            expression,
            output_type,
        };
        Ok(Arc::new(evaluator))
    }

    /// Wraps `predicate` in a [`DefaultPredicateEvaluator`].
    ///
    /// The input `schema` is stored on the evaluator; construction itself never fails.
    fn new_predicate_evaluator(
        &self,
        schema: SchemaRef,
        predicate: PredicateRef,
    ) -> DeltaResult<Arc<dyn PredicateEvaluator>> {
        let evaluator = DefaultPredicateEvaluator {
            _input_schema: schema,
            predicate,
        };
        Ok(Arc::new(evaluator))
    }

    /// Builds a single-row batch matching `output_schema` in which every column is null.
    ///
    /// # Errors
    ///
    /// Fails if a field's null array cannot be built, if the schema cannot be converted to Arrow,
    /// or if Arrow rejects the resulting batch.
    fn null_row(&self, output_schema: SchemaRef) -> DeltaResult<Box<dyn EngineData>> {
        let null_columns: Vec<ArrayRef> = output_schema
            .fields()
            .map(|field| {
                let typed_null = Scalar::Null(field.data_type().clone());
                typed_null.to_array(1)
            })
            .try_collect()?;
        let arrow_schema: ArrowSchema = output_schema.as_ref().try_into_arrow()?;
        let null_batch = RecordBatch::try_new(Arc::new(arrow_schema), null_columns)?;
        Ok(Box::new(ArrowEngineData::new(null_batch)))
    }

    /// Builds a batch with one row per entry of `rows`, laid out according to `schema`.
    ///
    /// Each inner slice supplies one scalar per schema field, in field order. An empty `rows`
    /// yields an empty batch with the converted schema.
    ///
    /// # Errors
    ///
    /// Fails if the schema cannot be converted to Arrow, if any row's length differs from the
    /// number of schema fields, if a scalar cannot be appended to its column's builder, or if
    /// Arrow rejects the assembled batch.
    fn create_many(
        &self,
        schema: SchemaRef,
        rows: &[&[Scalar]],
    ) -> DeltaResult<Box<dyn EngineData>> {
        let arrow_schema: Arc<ArrowSchema> = Arc::new(schema.as_ref().try_into_arrow()?);
        if rows.is_empty() {
            let empty_batch = RecordBatch::new_empty(arrow_schema);
            return Ok(Box::new(ArrowEngineData::new(empty_batch)));
        }

        let row_count = rows.len();
        let expected_width = schema.fields().len();

        // Reject the first row whose width disagrees with the schema before touching any builder.
        let first_bad_row = rows
            .iter()
            .enumerate()
            .find(|(_, row)| row.len() != expected_width);
        if let Some((row_idx, row)) = first_bad_row {
            return Err(Error::generic(format!(
                "Row {} has {} scalars but schema has {} fields",
                row_idx,
                row.len(),
                expected_width
            )));
        }

        // All column builders are created up front, one per Arrow field.
        let mut column_builders: Vec<Box<dyn ArrayBuilder>> = arrow_schema
            .fields()
            .iter()
            .map(|arrow_field| array::make_builder(arrow_field.data_type(), row_count))
            .collect();

        // The input is row-major, the builders are column-major: walk each column over all rows.
        let kernel_fields: Vec<_> = schema.fields().collect();
        for (col_idx, column_builder) in column_builders.iter_mut().enumerate() {
            let kernel_field = kernel_fields[col_idx];
            let field_name = kernel_field.name();
            for (row_idx, row) in rows.iter().enumerate() {
                let scalar = &row[col_idx];
                scalar
                    .append_to(column_builder.as_mut(), 1)
                    .map_err(|e| {
                        Error::generic(format!(
                            "Row {row_idx}, field '{field_name}' \
                             (expected type {}, got {}): {e}",
                            kernel_field.data_type(),
                            scalar.data_type()
                        ))
                    })?;
            }
        }

        let columns: Vec<ArrayRef> = column_builders
            .iter_mut()
            .map(|column_builder| column_builder.finish())
            .collect();
        let batch = RecordBatch::try_new(arrow_schema, columns)?;
        Ok(Box::new(ArrowEngineData::new(batch)))
    }
}
/// Evaluates one expression against Arrow-backed batches and shapes the result to `output_type`.
#[derive(Debug)]
pub struct DefaultExpressionEvaluator {
/// Schema supplied when the evaluator was created; stored but not read in this file.
_input_schema: SchemaRef,
/// The expression evaluated against each input batch.
expression: ExpressionRef,
/// Kernel data type the evaluation result is converted to.
output_type: DataType,
}
impl ExpressionEvaluator for DefaultExpressionEvaluator {
fn evaluate(&self, batch: &dyn EngineData) -> DeltaResult<Box<dyn EngineData>> {
debug!("Arrow evaluator evaluating: {:#?}", self.expression);
let batch = extract_record_batch(batch)?;
let batch = match (self.expression.as_ref(), &self.output_type) {
(Expression::Transform(transform), DataType::Struct(_)) if transform.is_identity() => {
let array = match transform.input_path() {
None => Arc::new(StructArray::from(batch.clone())),
Some(path) => extract_column(batch, path)?,
};
apply_schema(&array, &self.output_type)?
}
(expr, output_type @ DataType::Struct(_)) => {
let array_ref = evaluate_expression(expr, batch, Some(output_type))?;
apply_schema(&array_ref, output_type)?
}
(expr, output_type) => {
let array_ref = evaluate_expression(expr, batch, Some(output_type))?;
let array_ref = apply_schema_to(&array_ref, output_type)?;
let arrow_type = ArrowDataType::try_from_kernel(output_type)?;
let schema = ArrowSchema::new(vec![ArrowField::new("output", arrow_type, true)]);
RecordBatch::try_new(Arc::new(schema), vec![array_ref])?
}
};
Ok(Box::new(ArrowEngineData::new(batch)))
}
}
/// Evaluates one predicate against Arrow-backed batches, producing a boolean column.
#[derive(Debug)]
pub struct DefaultPredicateEvaluator {
/// Schema supplied when the evaluator was created; stored but not read in this file.
_input_schema: SchemaRef,
/// The predicate evaluated against each input batch.
predicate: PredicateRef,
}
impl PredicateEvaluator for DefaultPredicateEvaluator {
fn evaluate(&self, batch: &dyn EngineData) -> DeltaResult<Box<dyn EngineData>> {
debug!("Arrow evaluator evaluating: {:#?}", self.predicate);
let batch = extract_record_batch(batch)?;
let array = evaluate_predicate(&self.predicate, batch, false)?;
let schema = ArrowSchema::new(vec![ArrowField::new(
"output",
ArrowDataType::Boolean,
true,
)]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)])?;
Ok(Box::new(ArrowEngineData::new(batch)))
}
}