use std::collections::HashSet;
use std::time::Instant;
use crate::features::storage::api as storage_api;
use plexus_serde::{Op, Plan};
use super::super::super::{ExecuteParams, ExplainError, Result, Row, RowStream};
use super::aggregate::{apply_aggregate, validate_supported_aggregate};
use super::execute_deserialized_plan;
use super::plan::validate_sort;
use super::types::{sort_rows, DirectUnion, UnionChain, UnionWrapper};
/// Walks the operator chain below the plan root looking for a `Union`.
///
/// Starting at `plan.root_op` (which must be a `Return`), the walk follows each
/// op's `input` through `Return`, `Sort`, `Limit`, `Aggregate` and
/// `PathConstruct` ops, recording a [`UnionWrapper`] for every op except
/// `Return`. Wrappers are collected root-first (outermost op first).
///
/// # Returns
///
/// * `Ok(Some(chain))` when the walk reaches a `Union`; `chain` holds the
///   union's branch indices and the wrappers seen on the way down.
/// * `Ok(None)` when the walk reaches any other kind of op.
///
/// # Errors
///
/// * `SerializedPlanMalformed` if the root op is not a `Return`, an `input`
///   index is out of range, a `Limit` count is negative, or the chain loops
///   back onto itself.
/// * `UnsupportedSerializedOperator` if a `Limit` carries a non-zero `skip`.
/// * Any error produced by `validate_sort` or `validate_supported_aggregate`.
pub(super) fn union_chain_from_root(plan: &Plan) -> Result<Option<UnionChain>> {
    let root = plan.root_op as usize;
    let Some(Op::Return { .. }) = plan.ops.get(root) else {
        return Err(ExplainError::SerializedPlanMalformed(
            "root op must be Return".to_string(),
        ));
    };
    let mut idx = root;
    let mut wrappers = Vec::new();
    // An acyclic chain visits each op at most once, so `plan.ops.len()` steps
    // are always enough to reach a terminal op. Bounding the walk keeps a
    // malformed plan whose inputs form a cycle from spinning forever (and from
    // growing `wrappers` without limit).
    for _ in 0..plan.ops.len() {
        let Some(op) = plan.ops.get(idx) else {
            return Err(ExplainError::SerializedPlanMalformed(format!(
                "invalid root chain op index {}",
                idx
            )));
        };
        match op {
            Op::Return { input } => idx = *input as usize,
            Op::Sort { input, keys, dirs } => {
                wrappers.push(UnionWrapper::Sort(validate_sort(keys, dirs)?));
                idx = *input as usize;
            }
            Op::Limit { input, count, skip } => {
                if *skip != 0 {
                    return Err(ExplainError::UnsupportedSerializedOperator(
                        "Limit with skip != 0".to_string(),
                    ));
                }
                if *count < 0 {
                    return Err(ExplainError::SerializedPlanMalformed(
                        "Limit count must be non-negative".to_string(),
                    ));
                }
                // `count` was just checked to be non-negative.
                wrappers.push(UnionWrapper::Limit(*count as usize));
                idx = *input as usize;
            }
            Op::Aggregate {
                input, keys, aggs, ..
            } => {
                wrappers.push(UnionWrapper::Aggregate(validate_supported_aggregate(
                    keys, aggs,
                )?));
                idx = *input as usize;
            }
            Op::PathConstruct { input, .. } => {
                wrappers.push(UnionWrapper::PathConstruct);
                idx = *input as usize;
            }
            Op::Union { lhs, rhs, all, .. } => {
                return Ok(Some(UnionChain {
                    union: DirectUnion {
                        lhs: *lhs,
                        rhs: *rhs,
                        all: *all,
                    },
                    wrappers,
                }));
            }
            // Any other operator means the root chain is not a union chain.
            _ => return Ok(None),
        }
        if idx >= plan.ops.len() {
            return Err(ExplainError::SerializedPlanMalformed(format!(
                "invalid root chain reference index {}",
                idx
            )));
        }
    }
    // More steps than ops without reaching a terminal op: some op was visited
    // twice, i.e. the chain is cyclic.
    Err(ExplainError::SerializedPlanMalformed(
        "root chain contains a cycle".to_string(),
    ))
}
/// Executes both branches of a union chain and merges their results.
///
/// Each branch is turned into a standalone plan via `branch_plan_with_return`
/// and run with `execute_deserialized_plan`, left branch first. The right
/// branch's rows are appended to the left branch's rows. When the union is not
/// `all`, only the first row for each `node_id` is kept. The chain's wrappers
/// are then applied to the merged rows.
///
/// The returned stream sums `scanned_nodes`, `morsels_processed` and
/// `rerank_batches` (saturating), takes the larger `parallel_workers`, and
/// reports the time elapsed inside this function as `latency_micros`.
///
/// # Errors
///
/// Propagates any error from building or executing either branch plan.
pub(super) fn execute_union_chain(
    plan: &Plan,
    chain: &UnionChain,
    params: &ExecuteParams,
    handle: &mut storage_api::StorageHandle,
) -> Result<RowStream> {
    let started_at = Instant::now();
    let spec = &chain.union;

    // Both branch plans are built before either is executed, so an invalid
    // branch index is reported without touching storage.
    let left_plan = branch_plan_with_return(plan, spec.lhs)?;
    let right_plan = branch_plan_with_return(plan, spec.rhs)?;
    let left = execute_deserialized_plan(&left_plan, params, handle)?;
    let right = execute_deserialized_plan(&right_plan, params, handle)?;

    // Combine the per-branch counters.
    let scanned_nodes = left.scanned_nodes.saturating_add(right.scanned_nodes);
    let morsels_processed = left
        .morsels_processed
        .saturating_add(right.morsels_processed);
    let rerank_batches = left.rerank_batches.saturating_add(right.rerank_batches);
    let parallel_workers = left.parallel_workers.max(right.parallel_workers);

    // Left rows first, then right rows.
    let mut rows = left.rows;
    rows.extend(right.rows);

    if !spec.all {
        // Distinct union: keep the first row seen for each node id.
        let mut seen_ids = HashSet::new();
        rows.retain(|candidate| seen_ids.insert(candidate.node_id));
    }

    apply_union_wrappers(&mut rows, &chain.wrappers);

    Ok(RowStream {
        rows,
        scanned_nodes,
        latency_micros: started_at.elapsed().as_micros(),
        morsels_processed,
        rerank_batches,
        parallel_workers,
    })
}
/// Applies the wrapper operators of a union chain to the merged rows, in place.
///
/// `wrappers` is walked back-to-front. `union_chain_from_root` collects them
/// from the root downward, so the wrapper closest to the union is applied
/// first.
///
/// * `Sort` reorders the rows via `sort_rows`.
/// * `Limit` keeps at most the given number of leading rows.
/// * `Aggregate` replaces the rows with the output of `apply_aggregate`.
/// * `PathConstruct` leaves the rows untouched.
fn apply_union_wrappers(rows: &mut Vec<Row>, wrappers: &[UnionWrapper]) {
    for step in wrappers.iter().rev() {
        match step {
            UnionWrapper::PathConstruct => {}
            // `truncate` is a no-op when the vector is already short enough.
            UnionWrapper::Limit(max_rows) => rows.truncate(*max_rows),
            UnionWrapper::Sort(spec) => sort_rows(rows, *spec),
            UnionWrapper::Aggregate(spec) => {
                // Hand the current rows to the aggregator by value, leaving an
                // empty vector behind until the result is stored back.
                let input = std::mem::take(rows);
                *rows = apply_aggregate(input, *spec);
            }
        }
    }
}
/// Builds a standalone plan for one union branch.
///
/// The new plan contains a copy of every op in `plan`, followed by one extra
/// `Return` op whose input is `root_input`. That appended `Return` becomes the
/// new plan's root; the version is copied from `plan`.
///
/// # Errors
///
/// Returns `SerializedPlanMalformed` when `root_input` does not index an
/// existing op in `plan`.
fn branch_plan_with_return(plan: &Plan, root_input: u32) -> Result<Plan> {
    if plan.ops.get(root_input as usize).is_none() {
        return Err(ExplainError::SerializedPlanMalformed(format!(
            "invalid union branch index {}",
            root_input
        )));
    }

    // The appended `Return` lands right after the existing ops.
    let return_idx = plan.ops.len() as u32;
    let mut branch_ops = plan.ops.clone();
    branch_ops.push(Op::Return { input: root_input });

    Ok(Plan {
        version: plan.version.clone(),
        ops: branch_ops,
        root_op: return_idx,
    })
}