use std::sync::Arc;
use async_trait::async_trait;
use futures::future::{BoxFuture, Shared, WeakShared};
use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use vortex_error::{
SharedVortexResult, VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err,
};
use vortex_utils::aliases::hash_map::HashMap;
use crate::Canonical;
use crate::operator::{BatchBindCtx, BatchExecution, BatchExecutionRef, OperatorKey, OperatorRef};
use crate::pipeline::operator::PipelineOperator;
/// Executes operator trees, reusing in-flight executions of operators that map to the same
/// `OperatorKey`.
///
/// The executor keeps a cache from operator key to a *weak* handle on the shared execution
/// future for that operator, so the cache never keeps an execution alive on its own.
#[derive(Default)]
pub struct Executor {
    // Weak handles to the shared execution futures, keyed by operator. An entry can only be
    // upgraded back into a usable `Shared` future while that future is still held elsewhere;
    // otherwise it is stale and gets evicted on the next lookup.
    execution_cache: HashMap<
        OperatorKey<OperatorRef>,
        WeakShared<BoxFuture<'static, SharedVortexResult<Canonical>>>,
    >,
}
impl Executor {
    /// Executes `operator` and resolves to its canonical result.
    ///
    /// The execution graph is built eagerly, when this method is called, not when the returned
    /// future is first polled. Any error raised while building it is reported when the future
    /// is awaited. For example, an operator that supports neither batch nor pipelined execution
    /// produces such an error.
    ///
    /// Operators whose execution is still held by a previously returned future are shared
    /// rather than executed again.
    pub fn execute(
        &mut self,
        operator: OperatorRef,
    ) -> BoxFuture<'static, VortexResult<Canonical>> {
        let execution = self.batch_execution(&operator);
        async move { execution?.execute().await }.boxed()
    }

    /// Builds, or reuses from the cache, the batch execution for `operator`.
    ///
    /// If a live shared execution is cached for `operator`, a new handle to it is returned.
    /// Otherwise the operator is replaced by a `PipelineOperator` when it can be pipelined.
    /// Its children are then resolved recursively through this same cache and bound. The
    /// resulting execution is wrapped in a shared future and cached.
    ///
    /// The cache entry is always keyed by `operator` exactly as passed in, never by the
    /// pipeline wrapper. Lookups are only ever made with the operator as passed in, so keying
    /// by the wrapper would make the entry unreachable.
    ///
    /// # Errors
    ///
    /// Returns an error if any operator in the tree supports neither batch nor pipelined
    /// execution, or if binding an operator to its children fails.
    fn batch_execution(&mut self, operator: &OperatorRef) -> VortexResult<BatchExecutionRef> {
        let key = OperatorKey(operator.clone());
        if let Some(weak_shared) = self.execution_cache.get(&key) {
            if let Some(shared) = weak_shared.upgrade() {
                return Ok(Box::new(SharedBatchExecution(shared)));
            }
            // No one holds the shared future any more: the entry is stale.
            self.execution_cache.remove(&key);
        }

        // Prefer pipelined execution when the operator supports it.
        let operator = match PipelineOperator::new(operator.clone()) {
            None => operator.clone(),
            Some(pipeline_op) => Arc::new(pipeline_op),
        };
        log::info!("Executing operator: {}", operator.display_tree());

        // Resolve children first, through the cache, so repeated sub-trees are shared. Each
        // child is wrapped in `Some` so `bind` can take ownership of it exactly once via
        // `BatchBindCtx::child`.
        let mut children: Vec<_> = operator
            .children()
            .iter()
            .map(|child| self.batch_execution(child))
            .map_ok(Some)
            .try_collect()?;

        let execution = operator
            .as_batch()
            .ok_or_else(|| {
                vortex_err!(
                    "Operator does not support batch execution OR pipelined execution: {:?}",
                    operator
                )
            })?
            .bind(&mut children)?;

        // Errors are wrapped in `Arc` so the result of the shared future is cloneable.
        let shared_future = execution.execute().map_err(Arc::new).boxed().shared();
        // Key by the operator as originally passed in (not the pipeline wrapper): that is the
        // key every future lookup is built from.
        self.execution_cache.insert(
            key,
            shared_future.downgrade().vortex_expect("just created"),
        );
        Ok(Box::new(SharedBatchExecution(shared_future)))
    }
}
impl BatchBindCtx for Vec<Option<BatchExecutionRef>> {
    /// Hands over ownership of the child execution stored at position `idx`.
    ///
    /// Each slot can be claimed only once. The first call moves the execution out and leaves
    /// `None` behind, so asking for the same index again is reported as an error.
    ///
    /// # Errors
    ///
    /// Returns an error if `idx` is past the end of the vector, or if the child at `idx` has
    /// already been taken.
    fn child(&mut self, idx: usize) -> VortexResult<BatchExecutionRef> {
        let Some(slot) = self.get_mut(idx) else {
            vortex_bail!("Child index {} out of bounds", idx);
        };
        match slot.take() {
            Some(child_execution) => Ok(child_execution),
            None => Err(vortex_err!("Child already consumed")),
        }
    }
}
/// A `BatchExecution` backed by a cloneable `Shared` future.
///
/// Multiple consumers can await the result of a single underlying execution. The error type
/// inside the future is `Arc`-wrapped so the output can be cloned for each consumer.
struct SharedBatchExecution(Shared<BoxFuture<'static, SharedVortexResult<Canonical>>>);
#[async_trait]
impl BatchExecution for SharedBatchExecution {
async fn execute(self: Box<Self>) -> VortexResult<Canonical> {
self.0.await.map_err(VortexError::from)
}
}
// Tests for `Executor`: plain batch execution, pipelined execution, and sharing of a repeated
// input operator.
#[cfg(test)]
mod tests {
    use futures::executor::block_on;
    use vortex_buffer::buffer;
    use vortex_metrics::VortexMetrics;

    use super::*;
    use crate::compute::Operator as Op;
    use crate::operator::compare::CompareOperator;
    use crate::operator::metrics::MetricsOperator;
    use crate::{IntoArray, ToCanonical};

    // Executing a canonical primitive array as an operator yields the same values back.
    #[test]
    fn test_basic_execution() {
        let array = buffer![1i32, 2, 3, 4].into_array().to_primitive();
        let mut executor = Executor::default();
        let result = block_on(executor.execute(Arc::new(array.clone()))).unwrap();
        assert_eq!(
            result.into_primitive().as_slice::<i32>(),
            array.as_slice::<i32>()
        );
    }

    // Element-wise `lhs > rhs` over two primitive inputs: 1>3, 2>2, 3>1.
    // NOTE(review): presumably this exercises the `PipelineOperator` path; confirm that
    // `CompareOperator` is pipelinable.
    #[test]
    fn test_pipelined_execution() {
        let lhs = buffer![1i32, 2, 3].into_array().to_primitive();
        let rhs = buffer![3i32, 2, 1].into_array().to_primitive();
        let compare =
            Arc::new(CompareOperator::try_new(Arc::new(lhs), Arc::new(rhs), Op::Gt).unwrap());
        let mut executor = Executor::default();
        let result = block_on(executor.execute(compare)).unwrap();
        assert_eq!(
            result.into_bool().bool_vec().unwrap(),
            vec![false, false, true]
        );
    }

    // The same `array` operator (one `Arc`, cloned) is used as both inputs of a compare. Both
    // the input and the compare are wrapped in `MetricsOperator`s so their timers can be
    // counted after execution.
    #[test]
    fn test_common_subtree_elimination() {
        let array = buffer![1i32, 2, 3, 4].into_array().to_primitive();
        let array = Arc::new(MetricsOperator::new(
            Arc::new(array),
            VortexMetrics::default(),
        ));
        let compare =
            Arc::new(CompareOperator::try_new(array.clone(), array.clone(), Op::Gt).unwrap());
        let compare = Arc::new(MetricsOperator::new(compare, VortexMetrics::default()));
        let mut executor = Executor::default();
        let result = block_on(executor.execute(compare.clone())).unwrap();
        // x > x is false for every element.
        assert_eq!(
            result.into_bool().bool_vec().unwrap(),
            vec![false, false, false, false]
        );
        // Each timer fired exactly once. Presumably this shows the repeated input was executed
        // once rather than once per use.
        assert_eq!(compare.metrics().timer("operator.operator.step").count(), 1);
        assert_eq!(array.metrics().timer("operator.batch.execute").count(), 1);
    }
}