use std::any::Any;
use std::collections::{HashMap, HashSet};
use std::fmt;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use arrow_array::builder::UInt32Builder;
use arrow_array::{Array, ArrayRef, RecordBatch, UInt64Array};
use arrow_schema::{Field, Schema, SchemaRef};
use datafusion::common::{Result as DFResult, ScalarValue};
use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
use futures::{Stream, TryStreamExt};
use super::common::compute_plan_properties;
use super::scan::GraphScanExec;
/// Maximum number of VIDs handed to the probe-side scan in a single
/// filtered lookup; larger VID sets are split into chunks of this size.
pub(crate) const MAX_VIDS_PER_CHUNK: usize = 10_000;
/// Identifies which input of the join is the probe side (the one looked up
/// by VID); the other input is the build side.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProbeSide {
/// The left input is probed; the right input is the build side.
Left,
/// The right input is probed; the left input is the build side.
Right,
}
/// One equality condition of the join, expressed as a pair of column
/// indices: one into the left input's schema and one into the right's.
#[derive(Debug, Clone, Copy)]
pub struct EquiPair {
/// Column index within the left input.
pub left_col_idx: usize,
/// Column index within the right input.
pub right_col_idx: usize,
}
impl EquiPair {
    /// Returns the column index of this pair on the build side, i.e. the
    /// input that is *not* being probed.
    fn build_col(&self, probe_side: ProbeSide) -> usize {
        let Self {
            left_col_idx,
            right_col_idx,
        } = *self;
        match probe_side {
            ProbeSide::Right => left_col_idx,
            ProbeSide::Left => right_col_idx,
        }
    }

    /// Returns the column index of this pair on the probe side.
    fn probe_col(&self, probe_side: ProbeSide) -> usize {
        let Self {
            left_col_idx,
            right_col_idx,
        } = *self;
        match probe_side {
            ProbeSide::Right => right_col_idx,
            ProbeSide::Left => left_col_idx,
        }
    }
}
/// Join flavour supported by the VID lookup join.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VidJoinKind {
/// Only build rows that found a matching probe row are emitted.
Inner,
/// Build rows without a matching probe row are also emitted, with the
/// probe-side columns filled with nulls.
Left,
}
/// Physical operator that joins two inputs by first materialising the build
/// side, collecting its anchor VIDs, and then fetching only the matching
/// rows from the probe side (a `GraphScanExec`) via a VID-filtered scan.
pub struct VidLookupJoinExec {
// Left input of the join; its columns come first in the output schema.
left: Arc<dyn ExecutionPlan>,
// Right input of the join; its columns follow the left columns.
right: Arc<dyn ExecutionPlan>,
// Which of `left` / `right` is probed by VID.
probe_side: ProbeSide,
// Equality conditions; `pairs[0]` is the anchor (VID) pair, the rest are
// checked row by row after the VID lookup.
pairs: Vec<EquiPair>,
// Inner or left (build-preserving) join.
join_kind: VidJoinKind,
// Left fields followed by right fields.
output_schema: SchemaRef,
// Plan properties derived from `output_schema`.
properties: Arc<PlanProperties>,
// Metrics registry shared by all partitions of this operator.
metrics: ExecutionPlanMetricsSet,
}
impl fmt::Debug for VidLookupJoinExec {
    /// Formats a compact summary: the child plans are omitted and only the
    /// number of equality pairs is shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let pair_count = self.pairs.len();
        let mut summary = f.debug_struct("VidLookupJoinExec");
        summary.field("probe_side", &self.probe_side);
        summary.field("pairs", &pair_count);
        summary.field("join_kind", &self.join_kind);
        summary.finish()
    }
}
impl VidLookupJoinExec {
pub fn try_new(
left: Arc<dyn ExecutionPlan>,
right: Arc<dyn ExecutionPlan>,
probe_side: ProbeSide,
pairs: Vec<EquiPair>,
join_kind: VidJoinKind,
) -> DFResult<Self> {
if pairs.is_empty() {
return Err(datafusion::error::DataFusionError::Plan(
"VidLookupJoinExec: pairs must be non-empty".into(),
));
}
let probe_plan = match probe_side {
ProbeSide::Left => &left,
ProbeSide::Right => &right,
};
if probe_plan
.as_any()
.downcast_ref::<GraphScanExec>()
.is_none()
{
return Err(datafusion::error::DataFusionError::Plan(
"VidLookupJoinExec: probe-side child must be a GraphScanExec".into(),
));
}
let output_schema = concat_schemas(&left.schema(), &right.schema());
let properties = compute_plan_properties(output_schema.clone());
Ok(Self {
left,
right,
probe_side,
pairs,
join_kind,
output_schema,
properties,
metrics: ExecutionPlanMetricsSet::new(),
})
}
fn build_child(&self) -> &Arc<dyn ExecutionPlan> {
match self.probe_side {
ProbeSide::Left => &self.right,
ProbeSide::Right => &self.left,
}
}
fn probe_child(&self) -> &Arc<dyn ExecutionPlan> {
match self.probe_side {
ProbeSide::Left => &self.left,
ProbeSide::Right => &self.right,
}
}
}
impl DisplayAs for VidLookupJoinExec {
    /// Renders the one-line plan description; the same text is produced for
    /// every display format.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            probe_side,
            pairs,
            join_kind,
            ..
        } = self;
        write!(
            f,
            "VidLookupJoinExec: probe={:?}, pairs={}, kind={:?}",
            probe_side,
            pairs.len(),
            join_kind
        )
    }
}
impl ExecutionPlan for VidLookupJoinExec {
    /// Static operator name.
    fn name(&self) -> &str {
        "VidLookupJoinExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Left fields followed by right fields.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.output_schema)
    }

    fn properties(&self) -> &Arc<PlanProperties> {
        &self.properties
    }

    /// Exposes only the build-side child; the probe-side scan is not
    /// reported as a child of this node.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        let build_side = self.build_child();
        vec![build_side]
    }

    /// Rebuilds the operator with a replacement build-side child, keeping
    /// the existing probe-side child.
    ///
    /// Fails with a `Plan` error unless exactly one child is supplied.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        let supplied = children.len();
        let mut remaining = children.into_iter();
        let (Some(new_build), None) = (remaining.next(), remaining.next()) else {
            return Err(datafusion::error::DataFusionError::Plan(format!(
                "VidLookupJoinExec expects exactly one child (the build side); got {}",
                supplied
            )));
        };

        // Slot the new build child into whichever position is not probed.
        let (new_left, new_right) = match self.probe_side {
            ProbeSide::Right => (new_build, self.right.clone()),
            ProbeSide::Left => (self.left.clone(), new_build),
        };
        let rebuilt = Self::try_new(
            new_left,
            new_right,
            self.probe_side,
            self.pairs.clone(),
            self.join_kind,
        )?;
        Ok(Arc::new(rebuilt))
    }

    /// Starts the join for `partition` and returns a stream that yields at
    /// most one batch containing the entire join result.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let baseline = BaselineMetrics::new(&self.metrics, partition);

        // Calling the async fn only constructs the future; no work happens
        // until the stream is first polled.
        let join_future = run_join(
            Arc::clone(self.build_child()),
            Arc::clone(self.probe_child()),
            self.probe_side,
            self.pairs.clone(),
            self.join_kind,
            self.left.schema(),
            self.right.schema(),
            Arc::clone(&self.output_schema),
            partition,
            context,
        );

        let stream = VidLookupJoinStream {
            state: VidLookupJoinStreamState::Running(Box::pin(join_future)),
            schema: Arc::clone(&self.output_schema),
            metrics: baseline,
        };
        Ok(Box::pin(stream))
    }

    /// Snapshot of the metrics recorded so far.
    fn metrics(&self) -> Option<MetricsSet> {
        let snapshot = self.metrics.clone_inner();
        Some(snapshot)
    }
}
/// Progress of a `VidLookupJoinStream`.
enum VidLookupJoinStreamState {
// The join future is still being driven; it resolves to the single
// result batch (or an error).
Running(Pin<Box<dyn std::future::Future<Output = DFResult<RecordBatch>> + Send>>),
// The result (or error) has been yielded; the stream is exhausted.
Done,
}
/// Stream adapter that drives the join future and yields its result as a
/// single item.
struct VidLookupJoinStream {
// Whether the join future is still running or has already been consumed.
state: VidLookupJoinStreamState,
// Schema reported through `RecordBatchStream::schema`.
schema: SchemaRef,
// Per-partition baseline metrics (compute time, output rows).
metrics: BaselineMetrics,
}
impl Stream for VidLookupJoinStream {
    type Item = DFResult<RecordBatch>;

    /// Polls the join future. Yields its outcome exactly once (a batch or an
    /// error) and reports end-of-stream on every later poll.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Time spent inside this poll is attributed to elapsed compute.
        let metrics = self.metrics.clone();
        let _timer = metrics.elapsed_compute().timer();

        let this = &mut *self;
        let VidLookupJoinStreamState::Running(join_future) = &mut this.state else {
            return Poll::Ready(None);
        };
        let Poll::Ready(outcome) = join_future.as_mut().poll(cx) else {
            return Poll::Pending;
        };

        // Success or failure, the future is finished and must not be polled again.
        if let Ok(batch) = &outcome {
            this.metrics.record_output(batch.num_rows());
        }
        this.state = VidLookupJoinStreamState::Done;
        Poll::Ready(Some(outcome))
    }
}
impl RecordBatchStream for VidLookupJoinStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
#[allow(clippy::too_many_arguments)]
async fn run_join(
build: Arc<dyn ExecutionPlan>,
probe: Arc<dyn ExecutionPlan>,
probe_side: ProbeSide,
pairs: Vec<EquiPair>,
join_kind: VidJoinKind,
left_schema: SchemaRef,
right_schema: SchemaRef,
output_schema: SchemaRef,
partition: usize,
context: Arc<TaskContext>,
) -> DFResult<RecordBatch> {
let build_stream = build.execute(partition, context)?;
let build_batches: Vec<RecordBatch> = build_stream.try_collect().await?;
if build_batches.is_empty() {
return Ok(RecordBatch::new_empty(output_schema));
}
let anchor = pairs[0];
let build_anchor_col_idx = anchor.build_col(probe_side);
let mut vid_set: HashSet<u64> = HashSet::new();
for batch in &build_batches {
let arr = batch.column(build_anchor_col_idx);
let u64_arr = arr.as_any().downcast_ref::<UInt64Array>().ok_or_else(|| {
datafusion::error::DataFusionError::Plan(format!(
"VidLookupJoinExec: build anchor column at idx {} is not UInt64 (got {:?})",
build_anchor_col_idx,
arr.data_type()
))
})?;
for i in 0..u64_arr.len() {
if !u64_arr.is_null(i) {
vid_set.insert(u64_arr.value(i));
}
}
}
let probe_scan = probe
.as_any()
.downcast_ref::<GraphScanExec>()
.expect("planner ensured probe is GraphScanExec");
let probe_batch = if vid_set.is_empty() {
RecordBatch::new_empty(probe_scan.schema())
} else {
let vids: Vec<u64> = vid_set.iter().copied().collect();
let mut chunks: Vec<RecordBatch> = Vec::new();
for chunk in vids.chunks(MAX_VIDS_PER_CHUNK) {
let batch = probe_scan.execute_with_vid_filter(chunk).await?;
if batch.num_rows() > 0 {
chunks.push(batch);
}
}
if chunks.is_empty() {
RecordBatch::new_empty(probe_scan.schema())
} else if chunks.len() == 1 {
chunks.into_iter().next().unwrap()
} else {
let schema = chunks[0].schema();
arrow::compute::concat_batches(&schema, &chunks)
.map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?
}
};
let probe_vid_idx = locate_vid_column(&probe_batch.schema())?;
let probe_anchor_col_idx = anchor.probe_col(probe_side);
if probe_anchor_col_idx != probe_vid_idx {
return Err(datafusion::error::DataFusionError::Plan(format!(
"VidLookupJoinExec: anchor probe column idx {} != probe schema's _vid idx {} \
(planner pre-check should have aligned these)",
probe_anchor_col_idx, probe_vid_idx
)));
}
let probe_index = build_probe_vid_index(&probe_batch, probe_vid_idx)?;
let n_non_anchor = pairs.len() - 1;
let mut matches: Vec<JoinMatch> = Vec::new();
let mut unmatched: Vec<(usize, usize)> = Vec::new();
for (build_batch_idx, build_batch) in build_batches.iter().enumerate() {
let build_anchor_arr = build_batch
.column(build_anchor_col_idx)
.as_any()
.downcast_ref::<UInt64Array>()
.expect("validated above");
for build_row_idx in 0..build_anchor_arr.len() {
if build_anchor_arr.is_null(build_row_idx) {
if join_kind == VidJoinKind::Left {
unmatched.push((build_batch_idx, build_row_idx));
}
continue;
}
let key = build_anchor_arr.value(build_row_idx);
let Some(probe_rows) = probe_index.get(&key) else {
if join_kind == VidJoinKind::Left {
unmatched.push((build_batch_idx, build_row_idx));
}
continue;
};
let mut had_match_for_this_build_row = false;
for &probe_row_idx in probe_rows {
let mut all_match = true;
for pair in &pairs[1..1 + n_non_anchor] {
let build_col_idx = pair.build_col(probe_side);
let probe_col_idx = pair.probe_col(probe_side);
if !values_equal(
build_batch.column(build_col_idx),
build_row_idx,
probe_batch.column(probe_col_idx),
probe_row_idx,
)? {
all_match = false;
break;
}
}
if all_match {
matches.push(JoinMatch {
build_batch_idx,
build_row_idx,
probe_row_idx,
});
had_match_for_this_build_row = true;
}
}
if !had_match_for_this_build_row && join_kind == VidJoinKind::Left {
unmatched.push((build_batch_idx, build_row_idx));
}
}
}
emit_joined_batch(
&build_batches,
&probe_batch,
&matches,
&unmatched,
probe_side,
&left_schema,
&right_schema,
&output_schema,
)
}
/// One matched (build row, probe row) combination produced by the join.
#[derive(Clone, Copy)]
struct JoinMatch {
// Index of the build batch the build row lives in.
build_batch_idx: usize,
// Row index within that build batch.
build_row_idx: usize,
// Row index within the (single, concatenated) probe batch.
probe_row_idx: usize,
}
fn build_probe_vid_index(
probe_batch: &RecordBatch,
probe_vid_idx: usize,
) -> DFResult<HashMap<u64, Vec<usize>>> {
let arr = probe_batch
.column(probe_vid_idx)
.as_any()
.downcast_ref::<UInt64Array>()
.ok_or_else(|| {
datafusion::error::DataFusionError::Plan(
"VidLookupJoinExec: probe `_vid` column is not UInt64".into(),
)
})?;
let mut index: HashMap<u64, Vec<usize>> = HashMap::with_capacity(arr.len());
for i in 0..arr.len() {
if !arr.is_null(i) {
index.entry(arr.value(i)).or_default().push(i);
}
}
Ok(index)
}
/// Compares one cell of `a_col` with one cell of `b_col` under equi-join
/// semantics.
///
/// A null on either side never matches — not even another null. This mirrors
/// the anchor column, where null VIDs are never joined.
///
/// # Errors
///
/// Propagates failures from converting either cell into a `ScalarValue`.
fn values_equal(a_col: &ArrayRef, a_row: usize, b_col: &ArrayRef, b_row: usize) -> DFResult<bool> {
    // `ScalarValue` equality would treat two nulls as equal, so nulls are
    // rejected before the comparison.
    if a_col.is_null(a_row) || b_col.is_null(b_row) {
        return Ok(false);
    }
    let a = ScalarValue::try_from_array(a_col, a_row)?;
    let b = ScalarValue::try_from_array(b_col, b_row)?;
    Ok(a == b)
}
fn locate_vid_column(schema: &SchemaRef) -> DFResult<usize> {
schema
.fields()
.iter()
.enumerate()
.find_map(|(i, f)| {
if f.name() == "_vid" || f.name().ends_with("._vid") {
Some(i)
} else {
None
}
})
.ok_or_else(|| {
datafusion::error::DataFusionError::Plan(
"VidLookupJoinExec: probe schema has no _vid column".into(),
)
})
}
/// Builds a new schema containing all of `left`'s fields followed by all of
/// `right`'s fields. Only the fields are carried over.
fn concat_schemas(left: &SchemaRef, right: &SchemaRef) -> SchemaRef {
    let combined: Vec<Field> = left
        .fields()
        .iter()
        .chain(right.fields().iter())
        .map(|field| field.as_ref().clone())
        .collect();
    Arc::new(Schema::new(combined))
}
#[allow(clippy::too_many_arguments)]
fn emit_joined_batch(
build_batches: &[RecordBatch],
probe_batch: &RecordBatch,
matches: &[JoinMatch],
unmatched: &[(usize, usize)],
probe_side: ProbeSide,
left_schema: &SchemaRef,
right_schema: &SchemaRef,
output_schema: &SchemaRef,
) -> DFResult<RecordBatch> {
let total_rows = matches.len() + unmatched.len();
if total_rows == 0 {
return Ok(RecordBatch::new_empty(output_schema.clone()));
}
let n_build_batches = build_batches.len();
let mut match_take_per_build_batch: Vec<Vec<u32>> =
(0..n_build_batches).map(|_| Vec::new()).collect();
let mut match_probe_take: Vec<u32> = Vec::with_capacity(matches.len());
for m in matches {
match_take_per_build_batch[m.build_batch_idx].push(m.build_row_idx as u32);
match_probe_take.push(m.probe_row_idx as u32);
}
let mut unmatched_take_per_build_batch: Vec<Vec<u32>> =
(0..n_build_batches).map(|_| Vec::new()).collect();
for &(bb_idx, br_idx) in unmatched {
unmatched_take_per_build_batch[bb_idx].push(br_idx as u32);
}
let n_build_cols = build_batches[0].num_columns();
let mut build_columns: Vec<ArrayRef> = Vec::with_capacity(n_build_cols);
for col_idx in 0..n_build_cols {
let mut chunks: Vec<ArrayRef> = Vec::new();
for batch_idx in 0..n_build_batches {
if !match_take_per_build_batch[batch_idx].is_empty() {
chunks.push(take_indices(
build_batches[batch_idx].column(col_idx),
&match_take_per_build_batch[batch_idx],
)?);
}
if !unmatched_take_per_build_batch[batch_idx].is_empty() {
chunks.push(take_indices(
build_batches[batch_idx].column(col_idx),
&unmatched_take_per_build_batch[batch_idx],
)?);
}
}
build_columns.push(concat_arrays(&chunks)?);
}
let n_probe_cols = probe_batch.num_columns();
let mut probe_columns: Vec<ArrayRef> = Vec::with_capacity(n_probe_cols);
let probe_match_arr = take_indices_u32_slice(&match_probe_take);
let n_unmatched = unmatched.len();
for col_idx in 0..n_probe_cols {
let probe_col = probe_batch.column(col_idx);
let matched_part = if match_probe_take.is_empty() {
arrow_array::new_empty_array(probe_col.data_type())
} else {
arrow::compute::take(probe_col.as_ref(), &probe_match_arr, None)
.map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?
};
if n_unmatched == 0 {
probe_columns.push(matched_part);
} else {
let null_part = arrow_array::new_null_array(probe_col.data_type(), n_unmatched);
probe_columns.push(concat_arrays(&[matched_part, null_part])?);
}
}
let (left_columns, right_columns) = match probe_side {
ProbeSide::Left => (probe_columns, build_columns),
ProbeSide::Right => (build_columns, probe_columns),
};
let _ = (left_schema, right_schema);
let mut all_columns = left_columns;
all_columns.extend(right_columns);
RecordBatch::try_new(output_schema.clone(), all_columns)
.map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))
}
fn take_indices(col: &ArrayRef, indices: &[u32]) -> DFResult<ArrayRef> {
let take_array = take_indices_u32_slice(indices);
arrow::compute::take(col.as_ref(), &take_array, None)
.map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))
}
/// Copies a slice of row indices into a `UInt32Array` (no nulls) suitable
/// for use as take indices.
fn take_indices_u32_slice(indices: &[u32]) -> arrow_array::UInt32Array {
    let mut builder = UInt32Builder::with_capacity(indices.len());
    builder.append_slice(indices);
    builder.finish()
}
/// Concatenates `arrays` end to end into one array.
///
/// A single input is returned as-is (a cheap reference-count clone) without
/// invoking the concat kernel.
///
/// # Errors
///
/// Wraps any failure of the Arrow `concat` kernel in a DataFusion error.
fn concat_arrays(arrays: &[ArrayRef]) -> DFResult<ArrayRef> {
    match arrays {
        [only] => Ok(only.clone()),
        _ => {
            let parts: Vec<&dyn Array> = arrays.iter().map(|array| array.as_ref()).collect();
            arrow::compute::concat(&parts)
                .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))
        }
    }
}