use crate::error::DbxResult;
use crate::sql::executor::operators::PhysicalOperator;
use crate::sql::planner::types::ShuffleSalting;
use arrow::array::RecordBatch;
use arrow::datatypes::Schema;
use rand::Rng;
use tokio::sync::mpsc;
/// Physical operator that pulls batches from `input`, serializes each non-empty batch with
/// `crate::grid::protocol::serialize_batch_to_ipc`, and pushes the resulting bytes onto one or
/// more of `target_senders`, choosing the target(s) according to `salting`.
pub struct GridShuffleWriterOperator {
// Upstream operator that supplies the batches to shuffle; its schema is this operator's schema.
input: Box<dyn PhysicalOperator>,
// Stored but never read in this file (underscore-prefixed).
// NOTE(review): presumably the column indices for hash partitioning — confirm.
_hash_params: Vec<usize>,
// Stored but never read in this file; presumably identifies the exchange — confirm.
_exchange_id: usize,
// Distribution strategy that `next` matches on to pick the target channel(s).
salting: ShuffleSalting,
// One sender per downstream target. `next` only ever sends `Ok(Some(bytes))`; the `None`
// payload is presumably an end-of-stream marker sent elsewhere — verify against receivers.
target_senders: Vec<mpsc::Sender<DbxResult<Option<Vec<u8>>>>>,
}
impl GridShuffleWriterOperator {
    /// Creates a shuffle writer over `input` that will distribute serialized batches across
    /// `target_senders` using the given `salting` strategy.
    ///
    /// `hash_params` and `exchange_id` are kept on the operator but are not consulted by the
    /// logic in this file.
    pub fn new(
        input: Box<dyn PhysicalOperator>,
        hash_params: Vec<usize>,
        exchange_id: usize,
        salting: ShuffleSalting,
        target_senders: Vec<mpsc::Sender<DbxResult<Option<Vec<u8>>>>>,
    ) -> Self {
        Self {
            target_senders,
            salting,
            _exchange_id: exchange_id,
            _hash_params: hash_params,
            input,
        }
    }

    /// Encodes `batch` into bytes via the grid protocol's IPC serializer.
    ///
    /// # Errors
    /// Returns whatever error `serialize_batch_to_ipc` reports, unchanged.
    fn serialize_batch(&self, batch: &RecordBatch) -> DbxResult<Vec<u8>> {
        use crate::grid::protocol::serialize_batch_to_ipc as encode_ipc;
        encode_ipc(batch)
    }
}
impl PhysicalOperator for GridShuffleWriterOperator {
    /// Returns the input operator's schema; the writer does not reshape batches.
    fn schema(&self) -> &Schema {
        self.input.schema()
    }

    /// Pulls one batch from the input and ships it to the target channel(s) chosen by
    /// `self.salting`.
    ///
    /// - Input exhausted: returns `Ok(None)`.
    /// - No targets, or a zero-row batch: the batch is passed through unchanged.
    /// - Otherwise: the batch is serialized and sent, and an empty batch with the input schema
    ///   is returned so the caller keeps pulling.
    ///
    /// Send failures (e.g. a dropped receiver) are currently ignored; the batch is then lost
    /// without any error being reported.
    ///
    /// # Errors
    /// Propagates errors from the input operator and from batch serialization.
    ///
    /// # Panics
    /// Uses `mpsc::Sender::blocking_send`, which panics if called from within an async
    /// execution context.
    fn next(&mut self) -> DbxResult<Option<RecordBatch>> {
        let Some(batch) = self.input.next()? else {
            return Ok(None);
        };

        let num_targets = self.target_senders.len();
        if num_targets == 0 || batch.num_rows() == 0 {
            return Ok(Some(batch));
        }

        match &self.salting {
            ShuffleSalting::ReplicateProbe { factor: _ } => {
                let bytes = self.serialize_batch(&batch)?;
                // Each target needs its own buffer: clone for all but the last target and move
                // the original into the last one, saving one copy of the serialized batch.
                // Send order (index 0 .. n-1) is unchanged.
                if let Some((last, rest)) = self.target_senders.split_last() {
                    for sender in rest {
                        let _ = sender.blocking_send(Ok(Some(bytes.clone())));
                    }
                    let _ = last.blocking_send(Ok(Some(bytes)));
                }
            }
            ShuffleSalting::RandomDistributed { factor: _ } => {
                let target_idx = rand::thread_rng().gen_range(0..num_targets);
                let bytes = self.serialize_batch(&batch)?;
                let _ = self.target_senders[target_idx].blocking_send(Ok(Some(bytes)));
            }
            ShuffleSalting::None => {
                // NOTE(review): every batch goes to target 0; `_hash_params` is not used for
                // partitioning here — confirm this is intended.
                let bytes = self.serialize_batch(&batch)?;
                // Indexing is safe: `num_targets > 0` was checked above.
                let _ = self.target_senders[0].blocking_send(Ok(Some(bytes)));
            }
        }

        Ok(Some(RecordBatch::new_empty(std::sync::Arc::new(
            self.input.schema().clone(),
        ))))
    }

    /// Resets the underlying input operator.
    fn reset(&mut self) -> DbxResult<()> {
        self.input.reset()
    }
}