use std::{
    collections::{HashMap, HashSet},
    sync::Arc,
};

use arrow::datatypes::{Field as ArrowField, Schema as ArrowSchema};
use arrow_array::RecordBatch;
use arrow_schema::SchemaRef;
use datafusion::{
    execution::SendableRecordBatchStream, logical_expr::Expr,
    physical_plan::projection::ProjectionExec,
};
use datafusion_common::DFSchema;
use datafusion_physical_expr::{expressions, PhysicalExpr};
use futures::TryStreamExt;
use snafu::{location, Location};

use lance_core::{datatypes::Schema, Error, Result};

use crate::{
    exec::{execute_plan, LanceExecutionOptions, OneShotExec},
    planner::Planner,
};
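
/// Describes how physical columns read from a dataset are turned into the
/// output columns requested by the caller.
///
/// A plan is built from `(output name, expression)` pairs; it tracks which
/// physical columns those expressions reference and can apply the expressions
/// to record batches or streams via DataFusion.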
pub struct ProjectionPlan {
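    /// The Lance schema of the physical columns that must be read to evaluate the projection.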
pub physical_schema: Arc<Schema>,
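    /// The same physical schema as a DataFusion [`DFSchema`], used when planning physical expressions.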
pub physical_df_schema: Arc<DFSchema>,
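    /// The output expressions paired with their output names, in the caller's requested order.
    /// `None` means the physical columns are passed through unchanged.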
pub requested_output_expr: Option<Vec<(Expr, String)>>,
}

impl ProjectionPlan {
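    /// Create a projection plan from `(output name, expression)` pairs.
    ///
    /// Each expression is parsed against `base_schema`; the columns it references
    /// become part of the physical (read) schema. Returns an error if an output
    /// name is repeated or an expression cannot be resolved against `base_schema`.
    ///
    /// A minimal usage sketch (hypothetical column names; assumes `schema` is a
    /// Lance [`Schema`] containing numeric columns `a` and `b`):
    ///
    /// ```ignore
    /// let plan = ProjectionPlan::try_new(&schema, &[("a", "a"), ("sum", "a + b")])?;
    /// // Only `a` and `b` need to be read from storage.
    /// assert_eq!(plan.physical_schema.fields.len(), 2);
    /// ```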
pub fn try_new(
base_schema: &Schema,
columns: &[(impl AsRef<str>, impl AsRef<str>)],
) -> Result<Self> {
let arrow_schema = Arc::new(ArrowSchema::from(base_schema));
let planner = Planner::new(arrow_schema);
let mut output = HashMap::new();
let mut physical_cols_set = HashSet::new();
let mut physical_cols = vec![];
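        // Parse every requested expression and record the physical columns it
        // references, preserving first-seen order and skipping duplicates.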
for (output_name, raw_expr) in columns {
if output.contains_key(output_name.as_ref()) {
return Err(Error::io(
format!("Duplicate column name: {}", output_name.as_ref()),
location!(),
));
}
let expr = planner.parse_expr(raw_expr.as_ref())?;
for col in Planner::column_names_in_expr(&expr) {
if physical_cols_set.contains(&col) {
continue;
}
physical_cols.push(col.clone());
physical_cols_set.insert(col);
}
output.insert(output_name.as_ref().to_string(), expr);
}
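        // Narrow the base schema to just the referenced columns, then emit the
        // output expressions in the order the caller requested them.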
let physical_schema = Arc::new(base_schema.project(&physical_cols)?);
let mut output_cols = vec![];
for (name, _) in columns {
output_cols.push((output[name.as_ref()].clone(), name.as_ref().to_string()));
}
let requested_output_expr = Some(output_cols);
let physical_arrow_schema = ArrowSchema::from(physical_schema.as_ref());
let physical_df_schema = Arc::new(DFSchema::try_from(physical_arrow_schema).unwrap());
Ok(Self {
physical_schema,
physical_df_schema,
requested_output_expr,
})
}
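
    /// Create a pass-through plan that exposes every column of `base_schema`
    /// without applying any output expressions.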
pub fn new_empty(base_schema: Arc<Schema>) -> Self {
let physical_arrow_schema = ArrowSchema::from(base_schema.as_ref());
let physical_df_schema = Arc::new(DFSchema::try_from(physical_arrow_schema).unwrap());
Self {
physical_schema: base_schema,
physical_df_schema,
requested_output_expr: None,
}
}
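
    /// The physical (read) schema as an Arrow schema.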
pub fn arrow_schema(&self) -> &ArrowSchema {
self.physical_df_schema.as_arrow()
}
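
    /// The physical (read) schema as an owned [`SchemaRef`].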
pub fn arrow_schema_ref(&self) -> SchemaRef {
Arc::new(self.physical_df_schema.as_arrow().clone())
}
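
    /// Lower the projection into DataFusion physical expressions, one
    /// `(expr, output name)` pair per output column.
    ///
    /// When no output expressions were requested this is an identity projection:
    /// a plain column reference for every field in the physical schema.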
pub fn to_physical_exprs(&self) -> Result<Vec<(Arc<dyn PhysicalExpr>, String)>> {
if let Some(output_expr) = &self.requested_output_expr {
output_expr
.iter()
.map(|(expr, name)| {
Ok((
datafusion::physical_expr::create_physical_expr(
expr,
self.physical_df_schema.as_ref(),
&Default::default(),
)?,
name.clone(),
))
})
.collect::<Result<Vec<_>>>()
} else {
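            // No expressions requested: pass every physical column through by name.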
self.physical_schema
.fields
.iter()
.map(|f| {
Ok((
                        expressions::col(f.name.as_str(), self.physical_df_schema.as_arrow())?,
f.name.clone(),
))
})
.collect::<Result<Vec<_>>>()
}
}
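
    /// Compute the Arrow schema produced by this projection by asking each
    /// physical expression for its output data type and nullability.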
pub fn output_schema(&self) -> Result<ArrowSchema> {
let exprs = self.to_physical_exprs()?;
let fields = exprs
.iter()
.map(|(expr, name)| {
Ok(ArrowField::new(
name,
expr.data_type(self.arrow_schema())?,
expr.nullable(self.arrow_schema())?,
))
})
.collect::<Result<Vec<_>>>()?;
Ok(ArrowSchema::new(fields))
}
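
    /// Apply the projection to a single [`RecordBatch`].
    ///
    /// Returns the batch unchanged when no output expressions were requested.
    /// Otherwise the batch is run through a DataFusion [`ProjectionExec`] and
    /// exactly one output batch is expected back.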
pub async fn project_batch(&self, batch: RecordBatch) -> Result<RecordBatch> {
if self.requested_output_expr.is_none() {
return Ok(batch);
}
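        // Run the projection through DataFusion by wrapping the batch in a one-shot source.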
let src = Arc::new(OneShotExec::from_batch(batch));
let physical_exprs = self.to_physical_exprs()?;
let projection = Arc::new(ProjectionExec::try_new(physical_exprs, src)?);
let stream = execute_plan(projection, LanceExecutionOptions::default())?;
let batches = stream.try_collect::<Vec<_>>().await?;
if batches.len() != 1 {
Err(Error::Internal {
message: "Expected exactly one batch".to_string(),
location: location!(),
})
} else {
Ok(batches.into_iter().next().unwrap())
}
}
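
    /// Apply the projection to a stream of record batches.
    ///
    /// Returns the input stream unchanged when no output expressions were requested.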
pub fn project_stream(
&self,
stream: SendableRecordBatchStream,
) -> Result<SendableRecordBatchStream> {
if self.requested_output_expr.is_none() {
return Ok(stream);
}
let src = Arc::new(OneShotExec::new(stream));
let physical_exprs = self.to_physical_exprs()?;
let projection = Arc::new(ProjectionExec::try_new(physical_exprs, src)?);
execute_plan(projection, LanceExecutionOptions::default())
}
}