use std::any::Any;
use std::fmt;
use std::sync::Arc;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::execution::SendableRecordBatchStream;
use datafusion_common::Result as DfResult;
use datafusion_execution::TaskContext;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
use futures::StreamExt;
use tracing::debug;
use crate::model::fao::FaoOperator;
/// Physical plan node that passes every record batch produced by a single
/// child plan through a [`FaoOperator`] and emits the operator's results.
#[derive(Debug)]
pub struct NeuralScanExec {
/// Input plan whose output batches are handed to `operator`.
child: Arc<dyn ExecutionPlan>,
/// Operator invoked once per input batch during execution.
operator: Arc<dyn FaoOperator>,
/// Schema reported for this node; taken from `operator.output_schema()`
/// at construction time.
schema: SchemaRef,
/// Plan properties returned by `ExecutionPlan::properties`; derived from
/// the child's properties at construction time.
properties: PlanProperties,
}
impl NeuralScanExec {
pub fn new(child: Arc<dyn ExecutionPlan>, operator: Arc<dyn FaoOperator>) -> Self {
let schema = operator.output_schema().clone();
let properties = child.properties().clone();
Self {
child,
operator,
schema,
properties,
}
}
}
impl DisplayAs for NeuralScanExec {
    /// Writes a one-line description of this node for plan display.
    ///
    /// The requested display format is ignored: the same text, built from the
    /// operator's function id, version and model id, is produced for every
    /// format.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let op = &self.operator;
        let function_id = op.function_id();
        let version = op.version();
        let model_id = op.model_id();
        write!(
            f,
            "NeuralScanExec: fao={}@{}, model={}",
            function_id, version, model_id
        )
    }
}
impl ExecutionPlan for NeuralScanExec {
    /// Static name of this plan node.
    fn name(&self) -> &str {
        "NeuralScanExec"
    }

    /// Exposes the node as `Any` so callers can downcast to `NeuralScanExec`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Schema of the batches this node reports as its output.
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    /// Plan properties computed when the node was constructed.
    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    /// The single input plan of this node.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.child]
    }

    /// Rebuilds this node on top of a replacement child, sharing the operator.
    ///
    /// # Errors
    ///
    /// Returns `DataFusionError::Internal` unless exactly one child is
    /// supplied. Indexing blindly would panic on an empty vector and would
    /// silently discard any extra children.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DfResult<Arc<dyn ExecutionPlan>> {
        let supplied = children.len();
        let mut children = children.into_iter();
        match (children.next(), children.next()) {
            (Some(child), None) => Ok(Arc::new(NeuralScanExec::new(
                child,
                Arc::clone(&self.operator),
            ))),
            _ => Err(datafusion_common::DataFusionError::Internal(format!(
                "NeuralScanExec expects exactly one child, got {supplied}"
            ))),
        }
    }

    /// Executes the child for `partition` and runs the operator on each batch.
    ///
    /// Batches are processed one at a time in the order the child yields
    /// them; the next child batch is not requested until the operator has
    /// finished the current one.
    ///
    /// # Errors
    ///
    /// Fails immediately if the child cannot be executed. Within the returned
    /// stream, child errors are forwarded unchanged and operator errors are
    /// wrapped in `DataFusionError::External`.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> DfResult<SendableRecordBatchStream> {
        let child_stream = self.child.execute(partition, context)?;
        let operator = Arc::clone(&self.operator);
        let schema = self.schema.clone();

        let output_stream = child_stream.then(move |batch_result| {
            // Each per-batch future needs its own handle to the operator,
            // because the closure is called repeatedly.
            let op = Arc::clone(&operator);
            async move {
                match batch_result {
                    Ok(batch) => {
                        debug!(
                            rows = batch.num_rows(),
                            fao = op.function_id(),
                            "NeuralScanExec: processing batch"
                        );
                        op.execute(batch)
                            .await
                            .map_err(|e| datafusion_common::DataFusionError::External(Box::new(e)))
                    }
                    Err(e) => Err(e),
                }
            }
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            schema,
            output_stream,
        )))
    }
}