use std::sync::Arc;
use polars_core::schema::Schema;
use super::compute_node_prelude::*;
/// Compute node that forwards its inputs to a single output one input at a
/// time, in input order: the current input is passed through while every
/// later input is held back until the earlier ones are done.
pub struct OrderedUnionNode {
    /// Index of the input port currently being forwarded. Advanced past
    /// inputs whose port state is `Done`; equals the number of inputs once
    /// all of them are exhausted.
    cur_input_idx: usize,
    /// Largest (already offset) morsel sequence number reported back by the
    /// forwarding tasks spawned so far.
    max_morsel_seq_sent: MorselSeq,
    /// Offset applied to the sequence number of every morsel forwarded in the
    /// current phase. Recomputed on each state update as the successor of
    /// `max_morsel_seq_sent`.
    morsel_offset: MorselSeq,
    /// Schema that every forwarded morsel's frame is run through
    /// `ensure_matches_schema` against before it is sent on.
    output_schema: Arc<Schema>,
}
impl OrderedUnionNode {
    /// Creates a node positioned at its first input, with both the recorded
    /// maximum sequence number and the sequence offset starting at zero.
    ///
    /// `output_schema` is the schema each forwarded morsel is matched against.
    pub fn new(output_schema: Arc<Schema>) -> Self {
        // Both sequence trackers start from the same zero value.
        let initial_seq = MorselSeq::new(0);
        Self {
            cur_input_idx: 0,
            max_morsel_seq_sent: initial_seq,
            morsel_offset: initial_seq,
            output_schema,
        }
    }
}
impl ComputeNode for OrderedUnionNode {
fn name(&self) -> &str {
"ordered-union"
}
fn update_state(
&mut self,
recv: &mut [PortState],
send: &mut [PortState],
_state: &StreamingExecutionState,
) -> PolarsResult<()> {
assert!(self.cur_input_idx <= recv.len() && send.len() == 1);
while self.cur_input_idx < recv.len() && recv[self.cur_input_idx] == PortState::Done {
self.cur_input_idx += 1;
}
if self.cur_input_idx < recv.len() {
core::mem::swap(&mut recv[self.cur_input_idx], &mut send[0]);
} else {
send[0] = PortState::Done;
}
for r in recv.iter_mut().skip(self.cur_input_idx + 1) {
*r = PortState::Blocked;
}
self.morsel_offset = self.max_morsel_seq_sent.successor();
Ok(())
}
fn spawn<'env, 's>(
&'env mut self,
scope: &'s TaskScope<'s, 'env>,
recv_ports: &mut [Option<RecvPort<'_>>],
send_ports: &mut [Option<SendPort<'_>>],
_state: &'s StreamingExecutionState,
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
) {
let ready_count = recv_ports.iter().filter(|r| r.is_some()).count();
assert!(ready_count == 1 && send_ports.len() == 1);
let receivers = recv_ports[self.cur_input_idx].take().unwrap().parallel();
let senders = send_ports[0].take().unwrap().parallel();
let mut inner_handles = Vec::new();
for (mut recv, mut send) in receivers.into_iter().zip(senders) {
let output_schema = self.output_schema.clone();
let morsel_offset = self.morsel_offset;
inner_handles.push(scope.spawn_task(TaskPriority::High, async move {
let mut max_seq = MorselSeq::new(0);
while let Ok(mut morsel) = recv.recv().await {
morsel.df_mut().ensure_matches_schema(&output_schema)?;
let seq = morsel.seq().offset_by(morsel_offset);
max_seq = max_seq.max(seq);
morsel.set_seq(seq);
if send.send(morsel).await.is_err() {
break;
}
}
PolarsResult::Ok(max_seq)
}));
}
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
for handle in inner_handles {
self.max_morsel_seq_sent = self.max_morsel_seq_sent.max(handle.await?);
}
Ok(())
}));
}
}