use std::hash::Hash;
use std::sync::Arc;
use crate::physical_expr::PhysicalExpr;
use arrow::datatypes::FieldRef;
use arrow::{
datatypes::{DataType, Schema},
record_batch::RecordBatch,
};
use datafusion_common::{Result, exec_err, internal_err};
use datafusion_expr::ColumnarValue;
/// Physical expression that refers to a lambda variable by its zero-based
/// position (`index`) in the [`RecordBatch`] it is evaluated against, together
/// with the [`FieldRef`] expected to be found at that position.
#[derive(Clone, Debug)]
pub struct LambdaVariable {
    /// Zero-based column position read from the batch during evaluation.
    index: usize,
    /// Field (name, data type, nullability, metadata) expected at `index`.
    field: FieldRef,
}
/// Marker impl: equality (the `PartialEq` impl that follows) compares only the
/// index and the field.
impl Eq for LambdaVariable {}

/// Two variables are equal when both their position and their field match.
/// The index is compared first, so the field comparison is skipped when the
/// positions already differ.
impl PartialEq for LambdaVariable {
    fn eq(&self, other: &Self) -> bool {
        (self.index, &self.field) == (other.index, &other.field)
    }
}
/// Hashes exactly the components that equality compares (index, then field),
/// keeping `Hash` consistent with `Eq`.
impl Hash for LambdaVariable {
    fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
        let Self { index, field } = self;
        index.hash(hasher);
        field.hash(hasher);
    }
}
impl LambdaVariable {
    /// Creates a variable that reads column `index` and expects it to have `field`.
    pub fn new(index: usize, field: FieldRef) -> Self {
        Self { field, index }
    }

    /// Name of the variable, taken from its field.
    pub fn name(&self) -> &str {
        self.field.name().as_str()
    }

    /// Zero-based column position this variable reads from.
    pub fn index(&self) -> usize {
        self.index
    }

    /// Field expected at [`Self::index`] in the evaluated batch.
    pub fn field(&self) -> &FieldRef {
        &self.field
    }
}
/// Formats the variable as `name@index`, e.g. `x@0`.
impl std::fmt::Display for LambdaVariable {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let (name, position) = (self.name(), self.index);
        write!(f, "{name}@{position}")
    }
}
impl PhysicalExpr for LambdaVariable {
    /// Returns the data type stored in this variable's field; the input schema
    /// is not consulted.
    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        Ok(self.field.data_type().clone())
    }

    /// Returns the nullability stored in this variable's field; the input
    /// schema is not consulted.
    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
        Ok(self.field.is_nullable())
    }

    /// Returns column `index` of `batch` (a cheap `Arc` clone, no data copy).
    ///
    /// # Errors
    ///
    /// * Internal error if `index` is out of bounds for `batch`.
    /// * Execution error if the batch's field at `index` is not equal to the
    ///   field this variable was created with.
    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
        let schema = batch.schema_ref();
        if self.index >= batch.num_columns() {
            return internal_err!(
                "PhysicalExpr LambdaVariable references column '{}' at index {} (zero-based) but batch only has {} columns: {:?}",
                self.name(),
                self.index,
                batch.num_columns(),
                schema
                    .fields()
                    .iter()
                    .map(|f| f.name())
                    .collect::<Vec<_>>()
            );
        }
        // Look the field up once; the bounds check above guarantees it exists.
        let batch_field = schema.field(self.index);
        if self.field.as_ref() != batch_field {
            return exec_err!(
                "Field of physical LambdaVariable with index {} doesn't match batch field during evaluation {} != {}",
                self.index,
                self.field,
                batch_field
            );
        }
        Ok(ColumnarValue::Array(Arc::clone(batch.column(self.index))))
    }

    /// Returns the stored field (shared, not copied).
    fn return_field(&self, _input_schema: &Schema) -> Result<FieldRef> {
        Ok(Arc::clone(&self.field))
    }

    /// A lambda variable is a leaf expression: it has no children.
    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![]
    }

    /// Leaf expression: there are no children to replace, so `self` is returned unchanged.
    fn with_new_children(
        self: Arc<Self>,
        _children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        Ok(self)
    }

    /// SQL-like rendering: just the variable name. The `@index` suffix is a
    /// physical-plan detail, and it stays in the `Display` output.
    fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.name())
    }
}
/// Builds a [`LambdaVariable`] for the field called `name` in `schema`, using
/// that field's position as the variable's index.
///
/// # Errors
///
/// Returns an error if `schema` has no field named `name`; the error comes from
/// `Schema::index_of` and is converted by `?`.
pub fn lambda_variable(name: &str, schema: &Schema) -> Result<Arc<dyn PhysicalExpr>> {
    let position = schema.index_of(name)?;
    let field_ref = Arc::clone(&schema.fields()[position]);
    let expr = LambdaVariable::new(position, field_ref);
    Ok(Arc::new(expr))
}