use std::collections::VecDeque;
use tokio::sync::mpsc::{UnboundedSender, unbounded_channel};
use super::compute_node_prelude::*;
use crate::async_primitives::wait_group::WaitGroup;
use crate::morsel::SourceToken;
/// Per-output-port storage for morsels that arrive while that output is not
/// connected for the current phase.
///
/// Morsels are pushed to the front and replayed from the back, so the deque is
/// drained in arrival order.
enum BufferedStream {
    /// The output may still want data; holds morsels not yet delivered to it.
    Open(VecDeque<Morsel>),
    /// The output reported `Done`; nothing is stored for it any more.
    Closed,
}

impl BufferedStream {
    /// Creates an open stream with nothing buffered yet.
    fn new() -> Self {
        BufferedStream::Open(VecDeque::default())
    }
}
/// Compute node that fans a single input stream out to any number of outputs,
/// buffering data for outputs that are not currently consuming.
pub struct MultiplexerNode {
    /// One entry per output port; sized lazily in `update_state` once the
    /// number of outputs is known.
    buffers: Vec<BufferedStream>,
}

impl MultiplexerNode {
    /// Creates a multiplexer with no per-output buffers allocated yet.
    pub fn new() -> Self {
        let buffers = Vec::new();
        Self { buffers }
    }
}
impl ComputeNode for MultiplexerNode {
fn name(&self) -> &str {
"multiplexer"
}
fn update_state(
&mut self,
recv: &mut [PortState],
send: &mut [PortState],
_state: &StreamingExecutionState,
) -> PolarsResult<()> {
assert!(recv.len() == 1 && !send.is_empty());
self.buffers.resize_with(send.len(), BufferedStream::new);
for (s, b) in send.iter().zip(&mut self.buffers) {
if *s == PortState::Done {
*b = BufferedStream::Closed;
}
}
let input_done = recv[0] == PortState::Done
&& self.buffers.iter().all(|b| match b {
BufferedStream::Open(v) => v.is_empty(),
BufferedStream::Closed => true,
});
let output_done = send.iter().all(|p| *p == PortState::Done);
if input_done || output_done {
recv[0] = PortState::Done;
for s in send {
*s = PortState::Done;
}
return Ok(());
}
let all_blocked = send.iter().all(|p| *p == PortState::Blocked);
for (i, s) in send.iter_mut().enumerate() {
let buffer_empty = match &self.buffers[i] {
BufferedStream::Open(v) => v.is_empty(),
BufferedStream::Closed => true,
};
*s = if buffer_empty && recv[0] == PortState::Done {
PortState::Done
} else if !buffer_empty || recv[0] == PortState::Ready {
PortState::Ready
} else {
PortState::Blocked
};
}
recv[0] = if all_blocked {
PortState::Blocked
} else {
PortState::Ready
};
Ok(())
}
fn spawn<'env, 's>(
&'env mut self,
scope: &'s TaskScope<'s, 'env>,
recv_ports: &mut [Option<RecvPort<'_>>],
send_ports: &mut [Option<SendPort<'_>>],
_state: &'s StreamingExecutionState,
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
) {
assert!(recv_ports.len() == 1 && !send_ports.is_empty());
assert!(self.buffers.len() == send_ports.len());
enum Listener<'a> {
Active(UnboundedSender<Morsel>),
Buffering(&'a mut VecDeque<Morsel>),
Inactive,
}
let buffered_source_token = SourceToken::new();
let (mut buf_senders, buf_receivers): (Vec<_>, Vec<_>) = self
.buffers
.iter_mut()
.enumerate()
.map(|(port_idx, buffer)| {
if let BufferedStream::Open(buf) = buffer {
if send_ports[port_idx].is_some() {
let (rx, tx) = unbounded_channel();
(Listener::Active(rx), Some((buf, tx)))
} else {
(Listener::Buffering(buf), None)
}
} else {
(Listener::Inactive, None)
}
})
.unzip();
if let Some(mut receiver) = recv_ports[0].take().map(|r| r.serial()) {
let buffered_source_token = buffered_source_token.clone();
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
loop {
let Ok(mut morsel) = receiver.recv().await else {
break;
};
drop(morsel.take_consume_token());
let mut anyone_interested = false;
let mut active_listener_interested = false;
for buf_sender in &mut buf_senders {
match buf_sender {
Listener::Active(s) => match s.send(morsel.clone()) {
Ok(_) => {
anyone_interested = true;
active_listener_interested = true;
},
Err(_) => *buf_sender = Listener::Inactive,
},
Listener::Buffering(b) => {
b.push_front(morsel.clone());
anyone_interested = true;
},
Listener::Inactive => {},
}
}
if !anyone_interested {
break;
}
if !active_listener_interested || buffered_source_token.stop_requested() {
morsel.source_token().stop();
}
}
Ok(())
}));
}
for (send_port, opt_buf_recv) in send_ports.iter_mut().zip(buf_receivers) {
if let Some((buf, mut rx)) = opt_buf_recv {
let mut sender = send_port.take().unwrap().serial();
let wait_group = WaitGroup::default();
let buffered_source_token = buffered_source_token.clone();
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
while let Some(mut morsel) = buf.pop_back() {
morsel.replace_source_token(buffered_source_token.clone());
morsel.set_consume_token(wait_group.token());
if sender.send(morsel).await.is_err()
|| buffered_source_token.stop_requested()
{
break;
}
wait_group.wait().await;
}
while let Some(mut morsel) = rx.recv().await {
morsel.set_consume_token(wait_group.token());
if sender.send(morsel).await.is_err() {
break;
}
wait_group.wait().await;
}
Ok(())
}));
}
}
}
}