use super::message::ShuffleMessage;
use crate::checkpoint::barrier::CheckpointBarrier;
/// Capacity of the bounded inbound queue created by `ShuffleReceiver::bind`.
const SHUFFLE_RECV_QUEUE: usize = 1024;
/// Identifier of a shuffle peer; it is the value carried as `node_id` in Hello frames.
pub type ShufflePeerId = u64;
/// Key under which `ShuffleReceiver::bind_with_kv` writes its listen address and
/// under which `ShuffleSender` looks up a peer's address.
#[cfg(feature = "cluster")]
pub const SHUFFLE_ADDR_KEY: &str = "shuffle:addr";
/// Protobuf/gRPC bindings for the `laminar.shuffle.v1` package (cluster builds only).
// NOTE(review): the lint allowances below presumably exist because the included
// code is generated rather than hand-written — confirm before tightening them.
#[cfg(feature = "cluster")]
#[allow(
clippy::doc_markdown,
clippy::default_trait_access,
clippy::missing_const_for_fn,
clippy::must_use_candidate,
clippy::too_many_lines,
missing_docs
)]
pub(crate) mod shuffle_v1 {
// Splices the tonic output for the proto package into this module.
tonic::include_proto!("laminar.shuffle.v1");
}
/// Side buffers for inbound shuffle traffic that has been pulled off the receive
/// queue (or staged directly) but not yet handed to a consumer.
#[derive(Default)]
struct Holdover {
/// Record batches waiting to be drained, grouped by stage name.
staged: parking_lot::Mutex<rustc_hash::FxHashMap<String, Vec<arrow_array::RecordBatch>>>,
/// Checkpoint barriers waiting to be drained, each paired with the peer it came from.
staged_barriers: parking_lot::Mutex<Vec<(ShufflePeerId, CheckpointBarrier)>>,
}
#[cfg(feature = "cluster")]
mod grpc {
use std::collections::hash_map::Entry;
use std::io;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use arrow_array::RecordBatch;
use crossfire::{mpsc, AsyncRx, MAsyncTx};
use futures::StreamExt as _;
use parking_lot::Mutex;
use rustc_hash::FxHashMap;
use tokio::task::JoinHandle;
use tonic::transport::{Channel, Server};
use tonic::Request;
use super::shuffle_v1::shuffle_frame;
use super::shuffle_v1::shuffle_transport_client::ShuffleTransportClient;
use super::shuffle_v1::shuffle_transport_server::{ShuffleTransport, ShuffleTransportServer};
use super::shuffle_v1::{Barrier, Close, Hello, ShuffleFrame, ShuffleSummary, VnodeData};
use super::{Holdover, ShuffleMessage, ShufflePeerId, SHUFFLE_ADDR_KEY, SHUFFLE_RECV_QUEUE};
use crate::checkpoint::barrier::CheckpointBarrier;
use crate::cluster::control::ClusterKv;
use crate::serialization::{BatchStreamDecoder, BatchStreamEncoder};
/// Capacity of the bounded per-connection outbound queue created in `open_call`.
const SHUFFLE_SEND_QUEUE: usize = 1024;
/// Consumer half of the inbound queue; each item is `(sending peer, message)`.
type InboundRx = AsyncRx<mpsc::Array<(ShufflePeerId, ShuffleMessage)>>;
/// Producer half of the inbound queue; cloned into every accepted stream.
type InboundTx = MAsyncTx<mpsc::Array<(ShufflePeerId, ShuffleMessage)>>;
/// Converts any displayable error into an [`io::Error`] of kind `Other`.
///
/// Only the rendered message is kept; the source error value itself is dropped.
fn io_err<E>(e: E) -> io::Error
where
    E: std::fmt::Display,
{
    let message = e.to_string();
    io::Error::other(message)
}
fn encode_message(
msg: &ShuffleMessage,
encoders: &mut FxHashMap<String, BatchStreamEncoder>,
) -> Result<ShuffleFrame, tonic::Status> {
let kind = match msg {
ShuffleMessage::Hello(node_id) => {
shuffle_frame::Kind::Hello(Hello { node_id: *node_id })
}
ShuffleMessage::Barrier(b) => shuffle_frame::Kind::Barrier(Barrier {
checkpoint_id: b.checkpoint_id,
epoch: b.epoch,
flags: b.flags,
}),
ShuffleMessage::VnodeData(stage, vnode, batch) => {
let encoder = match encoders.entry(stage.clone()) {
Entry::Occupied(e) => {
let enc = e.into_mut();
let schema = batch.schema();
if !Arc::ptr_eq(enc.schema(), &schema) && *enc.schema() != schema {
return Err(tonic::Status::internal(format!(
"shuffle stage '{stage}' changed schema mid-connection",
)));
}
enc
}
Entry::Vacant(v) => {
v.insert(BatchStreamEncoder::new(&batch.schema()).map_err(|e| {
tonic::Status::internal(format!("shuffle ipc encoder init: {e}"))
})?)
}
};
let arrow_ipc = encoder
.encode(batch)
.map_err(|e| tonic::Status::internal(format!("shuffle ipc encode: {e}")))?;
shuffle_frame::Kind::VnodeData(VnodeData {
stage: stage.clone(),
vnode: *vnode,
arrow_ipc,
})
}
ShuffleMessage::Close(reason) => shuffle_frame::Kind::Close(Close {
reason: reason.clone(),
}),
};
Ok(ShuffleFrame { kind: Some(kind) })
}
/// Handle to one outbound connection: a queue feeding a background task that
/// drives the gRPC stream to a peer.
struct PeerConn {
/// Producer side of the outbound queue consumed by the driver task.
tx: MAsyncTx<mpsc::Array<ShuffleMessage>>,
/// Cleared by the driver task when connecting fails or the RPC finishes.
alive: Arc<AtomicBool>,
/// Handle of the spawned driver task; aborted when this value is dropped.
driver: JoinHandle<()>,
}
impl PeerConn {
/// Returns `true` until the driver task has marked the connection as finished.
fn is_alive(&self) -> bool {
self.alive.load(Ordering::Acquire)
}
}
impl Drop for PeerConn {
// Aborting here keeps a discarded connection from leaving its driver task running.
fn drop(&mut self) {
self.driver.abort();
}
}
/// Outbound side of the shuffle transport: resolves peer addresses and keeps one
/// pooled connection per peer.
pub struct ShuffleSender {
/// Id announced to peers in the Hello frame that opens every stream.
local_id: ShufflePeerId,
/// Known peer addresses, filled by `register_peer` and by KV discovery.
peers: Mutex<FxHashMap<ShufflePeerId, SocketAddr>>,
/// Open connections keyed by peer id.
pool: Mutex<FxHashMap<ShufflePeerId, Arc<PeerConn>>>,
/// Optional cluster key-value store consulted for peer addresses.
kv: Option<Arc<dyn ClusterKv>>,
}
impl std::fmt::Debug for ShuffleSender {
// Only `local_id` is printed; the remaining fields are elided.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ShuffleSender")
.field("local_id", &self.local_id)
.finish_non_exhaustive()
}
}
impl ShuffleSender {
    /// Creates a sender with no registered peers and no key-value store for discovery.
    #[must_use]
    pub fn new(local_id: ShufflePeerId) -> Self {
        let peers = Mutex::new(FxHashMap::default());
        let pool = Mutex::new(FxHashMap::default());
        Self {
            local_id,
            peers,
            pool,
            kv: None,
        }
    }

    /// Creates a sender that additionally looks peer addresses up in `kv`.
    #[must_use]
    pub fn with_kv(local_id: ShufflePeerId, kv: Arc<dyn ClusterKv>) -> Self {
        Self {
            kv: Some(kv),
            ..Self::new(local_id)
        }
    }

    /// Records (or replaces) the address used to reach `peer`.
    #[allow(clippy::unused_async)]
    pub async fn register_peer(&self, peer: ShufflePeerId, addr: SocketAddr) {
        self.peers.lock().insert(peer, addr);
    }

    /// Queues a clone of `msg` on the connection to `peer`, opening one if needed.
    ///
    /// Success means the message was accepted by the outbound queue, not that the
    /// peer has received it.
    ///
    /// # Errors
    /// `NotFound` when no address is known for `peer`, `BrokenPipe` when the
    /// connection's queue is closed, or the error from building the endpoint.
    pub async fn send_to(&self, peer: ShufflePeerId, msg: &ShuffleMessage) -> io::Result<()> {
        let conn = self.connection_for(peer).await?;
        match conn.tx.send(msg.clone()).await {
            Ok(()) => Ok(()),
            Err(_) => Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                format!("shuffle stream to peer {peer} closed"),
            )),
        }
    }

    /// Sends `barrier` to each of `peers` in order, stopping at the first failure.
    ///
    /// # Errors
    /// Propagates the first error returned by [`Self::send_to`].
    pub async fn fan_out_barrier(
        &self,
        peers: &[ShufflePeerId],
        barrier: CheckpointBarrier,
    ) -> io::Result<()> {
        let barrier_msg = ShuffleMessage::Barrier(barrier);
        for target in peers.iter().copied() {
            self.send_to(target, &barrier_msg).await?;
        }
        Ok(())
    }

    /// Reads the address of `peer` from the key-value store and caches it.
    ///
    /// Returns `None` when there is no store, no entry, or the entry does not parse.
    async fn discover_peer(&self, peer: ShufflePeerId) -> Option<SocketAddr> {
        let kv = self.kv.as_ref()?;
        let node = crate::cluster::discovery::NodeId(peer);
        let raw = kv.read_from(node, SHUFFLE_ADDR_KEY).await?;
        let addr = raw.parse::<SocketAddr>().ok()?;
        self.peers.lock().insert(peer, addr);
        Some(addr)
    }

    /// Returns a live pooled connection to `peer`, opening a new one when the pool
    /// has none or only a dead one.
    async fn connection_for(&self, peer: ShufflePeerId) -> io::Result<Arc<PeerConn>> {
        // Fast path: reuse a pooled connection that is still alive.
        let cached = self.pool.lock().get(&peer).cloned();
        if let Some(existing) = cached {
            if existing.is_alive() {
                return Ok(existing);
            }
        }

        // Evict the dead entry for this peer, if any.
        {
            let mut pool = self.pool.lock();
            if pool.get(&peer).is_some_and(|c| !c.is_alive()) {
                pool.remove(&peer);
            }
        }

        // A discovered address wins over a manually registered one.
        let discovered = self.discover_peer(peer).await;
        let addr = if let Some(found) = discovered {
            found
        } else {
            let registered = self.peers.lock().get(&peer).copied();
            registered.ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::NotFound,
                    format!("peer {peer} has no registered shuffle address"),
                )
            })?
        };

        let fresh = Arc::new(open_call(self.local_id, addr)?);

        // Another task may have connected while we were awaiting; prefer its connection.
        let mut pool = self.pool.lock();
        let winner = pool.get(&peer).filter(|c| c.is_alive()).cloned();
        if let Some(winner) = winner {
            return Ok(winner);
        }
        pool.insert(peer, Arc::clone(&fresh));
        Ok(fresh)
    }
}
fn open_call(local_id: ShufflePeerId, addr: SocketAddr) -> io::Result<PeerConn> {
let endpoint = crate::cluster::control::tls::client_endpoint(&addr.to_string())
.map_err(io_err)?
.tcp_nodelay(true);
let (tx, rx) = mpsc::bounded_async::<ShuffleMessage>(SHUFFLE_SEND_QUEUE);
let alive = Arc::new(AtomicBool::new(true));
let alive_for_driver = Arc::clone(&alive);
let hello = ShuffleFrame {
kind: Some(shuffle_frame::Kind::Hello(Hello { node_id: local_id })),
};
let encoders: FxHashMap<String, BatchStreamEncoder> = FxHashMap::default();
let outbound = futures::stream::once(async move { hello }).chain(futures::stream::unfold(
(rx, encoders),
|(rx, mut encoders)| async move {
let msg = rx.recv().await.ok()?;
match encode_message(&msg, &mut encoders) {
Ok(frame) => Some((frame, (rx, encoders))),
Err(e) => {
tracing::warn!(error = %e, "shuffle frame encode failed; closing stream");
None
}
}
},
));
let driver = tokio::spawn(async move {
let Ok(channel) = endpoint.connect().await else {
alive_for_driver.store(false, Ordering::Release);
return;
};
let mut client = ShuffleTransportClient::<Channel>::new(channel);
let _ = client.shuffle(Request::new(outbound)).await;
alive_for_driver.store(false, Ordering::Release);
});
Ok(PeerConn { tx, alive, driver })
}
/// Inbound side of the shuffle transport: a gRPC server whose accepted streams
/// feed one shared queue.
pub struct ShuffleReceiver {
/// Id this receiver was bound with; shown in `Debug` output.
local_id: ShufflePeerId,
/// Address the listener is actually bound to.
local_addr: SocketAddr,
/// Consumer half of the inbound queue; `recv` takes it out of the slot while waiting.
rx: Mutex<Option<InboundRx>>,
/// Signalled whenever the consumer half is put back into `rx`.
rx_returned: Arc<tokio::sync::Notify>,
/// Handle of the spawned server task; aborted on drop.
server: JoinHandle<()>,
/// Buffers for batches and barriers drained from the queue but not yet consumed.
holdover: Arc<Holdover>,
}
impl Drop for ShuffleReceiver {
// Aborts the server task when the receiver goes away.
fn drop(&mut self) {
self.server.abort();
}
}
impl std::fmt::Debug for ShuffleReceiver {
// Prints the id and bound address only; the remaining fields are elided.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ShuffleReceiver")
.field("local_id", &self.local_id)
.field("local_addr", &self.local_addr)
.finish_non_exhaustive()
}
}
impl ShuffleReceiver {
    /// Binds a TCP listener on `addr` and starts serving the shuffle gRPC service.
    ///
    /// # Errors
    /// Fails when the listener cannot be bound, its address cannot be read, or
    /// the configured server TLS settings are rejected.
    pub async fn bind(local_id: ShufflePeerId, addr: SocketAddr) -> io::Result<Self> {
        let listener = tokio::net::TcpListener::bind(addr).await?;
        let local_addr = listener.local_addr()?;
        let (tx, rx) =
            mpsc::bounded_async::<(ShufflePeerId, ShuffleMessage)>(SHUFFLE_RECV_QUEUE);

        // Endless stream of accepted sockets; accept errors are passed through as items.
        let incoming = futures::stream::unfold(listener, |listener| async move {
            let accepted = listener.accept().await.map(|(stream, _remote)| {
                let _ = stream.set_nodelay(true);
                stream
            });
            Some((accepted, listener))
        });

        let mut builder = Server::builder();
        if let Some(tls) = crate::cluster::control::tls::server_tls() {
            builder = builder
                .tls_config(tls.clone())
                .map_err(|e| io::Error::other(format!("cluster shuffle TLS config: {e}")))?;
        }
        let router = builder.add_service(ShuffleTransportServer::new(ShuffleService { tx }));
        let server = tokio::spawn(async move {
            let _ = router.serve_with_incoming(incoming).await;
        });

        Ok(Self {
            local_id,
            local_addr,
            rx: Mutex::new(Some(rx)),
            rx_returned: Arc::new(tokio::sync::Notify::new()),
            server,
            holdover: Arc::new(Holdover::default()),
        })
    }

    /// Like [`Self::bind`], then writes the bound address to `kv` under
    /// `SHUFFLE_ADDR_KEY`.
    ///
    /// # Errors
    /// Propagates any error from [`Self::bind`].
    pub async fn bind_with_kv(
        local_id: ShufflePeerId,
        addr: SocketAddr,
        kv: Arc<dyn ClusterKv>,
    ) -> io::Result<Self> {
        let recv = Self::bind(local_id, addr).await?;
        let advertised = recv.local_addr.to_string();
        kv.write(SHUFFLE_ADDR_KEY, advertised).await;
        Ok(recv)
    }

    /// Address the listener is bound to.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        self.local_addr
    }

    /// Waits for the next inbound message and the peer that sent it.
    ///
    /// The queue's consumer half is taken out of its slot for the duration of the
    /// wait and put back by a guard, including when this future is dropped early.
    /// Returns `None` when receiving from the queue fails.
    pub async fn recv(&self) -> Option<(ShufflePeerId, ShuffleMessage)> {
        let rx = loop {
            // The lock is released at the end of this statement, before any await.
            let taken = self.rx.lock().take();
            match taken {
                Some(rx) => break rx,
                None => self.rx_returned.notified().await,
            }
        };
        let mut guard = RxReturnGuard {
            slot: &self.rx,
            notify: &self.rx_returned,
            rx: Some(rx),
        };
        let rx = guard.rx.as_mut()?;
        rx.recv().await.ok()
    }

    /// Returns every message that is immediately available on the inbound queue.
    ///
    /// Yields an empty vector while a pending `recv` holds the consumer half.
    #[must_use]
    pub fn drain_available(&self) -> Vec<(ShufflePeerId, ShuffleMessage)> {
        let slot = self.rx.lock();
        let mut drained = Vec::new();
        if let Some(rx) = slot.as_ref() {
            drained.extend(std::iter::from_fn(|| rx.try_recv().ok()));
        }
        drained
    }

    /// Moves everything immediately available on the inbound queue into the holdover:
    /// vnode batches into `staged`, barriers into the barrier list. Other message
    /// kinds are discarded.
    fn drain_inbound_into(&self, staged: &mut FxHashMap<String, Vec<RecordBatch>>) {
        let slot = self.rx.lock();
        let Some(rx) = slot.as_ref() else {
            return;
        };
        for (from, msg) in std::iter::from_fn(|| rx.try_recv().ok()) {
            match msg {
                ShuffleMessage::VnodeData(stage, _vnode, batch) => {
                    staged.entry(stage).or_default().push(batch);
                }
                ShuffleMessage::Barrier(barrier) => {
                    self.holdover.staged_barriers.lock().push((from, barrier));
                }
                _ => {}
            }
        }
    }

    /// Pulls in pending inbound messages, then removes and returns all batches
    /// staged for exactly `stage`.
    #[must_use]
    pub fn drain_vnode_data_for(&self, stage: &str) -> Vec<RecordBatch> {
        let mut staged = self.holdover.staged.lock();
        self.drain_inbound_into(&mut staged);
        match staged.remove(stage) {
            Some(batches) => batches,
            None => Vec::new(),
        }
    }

    /// Pulls in pending inbound messages, then removes and returns every staged
    /// stage whose name starts with `prefix`.
    #[must_use]
    pub fn drain_staged_with_prefix(
        &self,
        prefix: &str,
    ) -> FxHashMap<String, Vec<RecordBatch>> {
        let mut staged = self.holdover.staged.lock();
        self.drain_inbound_into(&mut staged);
        let matching: Vec<String> = staged
            .keys()
            .filter(|stage| stage.starts_with(prefix))
            .cloned()
            .collect();
        matching
            .into_iter()
            .filter_map(|stage| staged.remove_entry(&stage))
            .collect()
    }

    /// Appends `batch` to the batches staged for `stage`.
    pub fn stage_batch(&self, stage: String, batch: RecordBatch) {
        let mut staged = self.holdover.staged.lock();
        staged.entry(stage).or_default().push(batch);
    }

    /// Removes and returns all staged barriers with their sending peers.
    #[must_use]
    pub fn drain_staged_barriers(&self) -> Vec<(ShufflePeerId, CheckpointBarrier)> {
        let mut barriers = self.holdover.staged_barriers.lock();
        std::mem::take(&mut *barriers)
    }

    /// Removes and returns every staged batch, each paired with its stage name.
    ///
    /// This does not pull from the inbound queue first.
    #[must_use]
    pub fn drain_all_staged(&self) -> Vec<(String, RecordBatch)> {
        let mut staged = self.holdover.staged.lock();
        let mut flat = Vec::new();
        for (stage, batches) in staged.drain() {
            flat.extend(batches.into_iter().map(|batch| (stage.clone(), batch)));
        }
        flat
    }
}
/// Holds the inbound queue's consumer half on behalf of `recv` and puts it back
/// into the shared slot when dropped.
struct RxReturnGuard<'a> {
/// Slot the consumer half is returned to.
slot: &'a Mutex<Option<InboundRx>>,
/// Notified after the consumer half has been returned.
notify: &'a tokio::sync::Notify,
/// The borrowed consumer half; `None` once it has been handed back.
rx: Option<InboundRx>,
}
impl Drop for RxReturnGuard<'_> {
// Runs on every exit from `recv`, including when its future is dropped mid-wait.
fn drop(&mut self) {
if let Some(rx) = self.rx.take() {
*self.slot.lock() = Some(rx);
self.notify.notify_one();
}
}
}
/// gRPC service implementation that forwards every accepted stream into the
/// receiver's shared inbound queue.
struct ShuffleService {
    /// Producer half of the inbound queue; cloned for each incoming stream.
    tx: InboundTx,
}

#[tonic::async_trait]
impl ShuffleTransport for ShuffleService {
    /// Consumes one client stream via `run_stream` and replies with its summary.
    async fn shuffle(
        &self,
        request: Request<tonic::Streaming<ShuffleFrame>>,
    ) -> Result<tonic::Response<ShuffleSummary>, tonic::Status> {
        run_stream(self.tx.clone(), request.into_inner())
            .await
            .map(tonic::Response::new)
    }
}
/// Reads one client stream to completion and forwards its frames to `tx`.
///
/// The first frame must be a Hello; its `node_id` is the peer id attached to
/// everything forwarded afterwards. Reading stops at end of stream, at a Close
/// frame, or once `tx` no longer accepts messages. `frames_received` counts the
/// Hello, Barrier and VnodeData frames seen after the opening Hello.
///
/// # Errors
/// Returns `invalid_argument` when the stream ends before the Hello, when the first
/// frame is not a Hello, when a frame has no kind, or when a payload fails to
/// decode. Errors from the stream itself and from `forward_vnode_batch` are
/// propagated.
async fn run_stream(
    tx: InboundTx,
    mut stream: tonic::Streaming<ShuffleFrame>,
) -> Result<ShuffleSummary, tonic::Status> {
    let Some(first) = stream.message().await? else {
        return Err(tonic::Status::invalid_argument(
            "shuffle stream closed before Hello",
        ));
    };
    let Some(shuffle_frame::Kind::Hello(hello)) = first.kind else {
        return Err(tonic::Status::invalid_argument(
            "first shuffle frame must be Hello",
        ));
    };
    let peer = hello.node_id;

    let mut decoders: FxHashMap<String, BatchStreamDecoder> = FxHashMap::default();
    let mut frames_received = 0u64;

    while let Some(frame) = stream.message().await? {
        let Some(kind) = frame.kind else {
            return Err(tonic::Status::invalid_argument("empty shuffle frame"));
        };
        // `keep_going` turns false on Close or when the inbound queue is gone.
        let keep_going = match kind {
            shuffle_frame::Kind::Close(_) => false,
            shuffle_frame::Kind::Hello(h) => {
                frames_received += 1;
                let msg = ShuffleMessage::Hello(h.node_id);
                tx.send((peer, msg)).await.is_ok()
            }
            shuffle_frame::Kind::Barrier(b) => {
                frames_received += 1;
                let barrier = CheckpointBarrier {
                    checkpoint_id: b.checkpoint_id,
                    epoch: b.epoch,
                    flags: b.flags,
                };
                tx.send((peer, ShuffleMessage::Barrier(barrier)))
                    .await
                    .is_ok()
            }
            shuffle_frame::Kind::VnodeData(v) => {
                frames_received += 1;
                // One decoder per stage, kept for the lifetime of the stream.
                let decoder = decoders.entry(v.stage.clone()).or_default();
                let batches = decoder.decode_chunk(v.arrow_ipc).map_err(|e| {
                    tonic::Status::invalid_argument(format!("shuffle ipc: {e}"))
                })?;
                let mut all_forwarded = true;
                for batch in batches {
                    let forwarded =
                        forward_vnode_batch(&tx, peer, &v.stage, v.vnode, batch).await?;
                    if !forwarded {
                        all_forwarded = false;
                        break;
                    }
                }
                all_forwarded
            }
        };
        if !keep_going {
            break;
        }
    }
    Ok(ShuffleSummary { frames_received })
}
/// Forwards one decoded batch to the inbound queue as `VnodeData` message(s).
///
/// A batch without a `__laminar_vnode` column is forwarded whole under
/// `default_vnode`. Otherwise that column is dropped and the remaining columns are
/// split with `slice_batch_by_vnodes`, one message per returned slice.
///
/// Returns `Ok(false)` as soon as the queue refuses a message, `Ok(true)` otherwise.
///
/// # Errors
/// `invalid_argument` when the vnode column is not a `UInt32Array`; `internal` when
/// projecting the column away fails.
async fn forward_vnode_batch(
    tx: &InboundTx,
    peer: ShufflePeerId,
    stage: &str,
    default_vnode: u32,
    batch: RecordBatch,
) -> Result<bool, tonic::Status> {
    let schema = batch.schema();
    let col_idx = match schema.column_with_name("__laminar_vnode") {
        Some((idx, _field)) => idx,
        None => {
            let whole = ShuffleMessage::VnodeData(stage.to_string(), default_vnode, batch);
            return Ok(tx.send((peer, whole)).await.is_ok());
        }
    };

    let row_vnodes: Vec<u32> = batch
        .column(col_idx)
        .as_any()
        .downcast_ref::<arrow_array::UInt32Array>()
        .ok_or_else(|| {
            tonic::Status::invalid_argument("vnode metadata column is not UInt32Array")
        })?
        .values()
        .to_vec();

    // Keep every column except the vnode metadata column.
    let projection: Vec<usize> = (0..schema.fields().len())
        .filter(|&idx| idx != col_idx)
        .collect();
    let batch_without_vnode = batch.project(&projection).map_err(|e| {
        tonic::Status::internal(format!("Failed to project out vnode metadata: {e}"))
    })?;

    let slices =
        crate::shuffle::routing::slice_batch_by_vnodes(&batch_without_vnode, &row_vnodes);
    for (vnode, slice) in slices {
        let part = ShuffleMessage::VnodeData(stage.to_string(), vnode, slice);
        let delivered = tx.send((peer, part)).await.is_ok();
        if !delivered {
            return Ok(false);
        }
    }
    Ok(true)
}
// Unit tests for `encode_message`'s per-stage schema tracking.
#[cfg(test)]
mod encode_tests {
use super::*;
use arrow_array::Int64Array;
use arrow_schema::{DataType, Field, Schema};
// A stage keeps the schema of its first batch; a different schema on the same
// stage is an error, while the same schema on a new stage is accepted.
#[test]
fn schema_change_on_a_stage_is_rejected() {
// Builds a one-row batch whose single Int64 column is called `name`.
let batch = |name: &str| {
let schema = Arc::new(Schema::new(vec![Field::new(name, DataType::Int64, false)]));
arrow_array::RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(vec![1]))])
.unwrap()
};
let mut encoders = FxHashMap::default();
// First batch on stage "s" creates the encoder with column "a".
let msg = ShuffleMessage::VnodeData("s".into(), 0, batch("a"));
encode_message(&msg, &mut encoders).unwrap();
// Same stage, different column name: must be rejected.
let changed = ShuffleMessage::VnodeData("s".into(), 0, batch("b"));
let err = encode_message(&changed, &mut encoders).unwrap_err();
assert!(err.message().contains("changed schema"), "{err}");
// A different stage gets its own encoder, so schema "b" is fine there.
let other = ShuffleMessage::VnodeData("t".into(), 0, batch("b"));
encode_message(&other, &mut encoders).unwrap();
}
}
}
#[cfg(feature = "cluster")]
pub use grpc::{ShuffleReceiver, ShuffleSender};
#[cfg(not(feature = "cluster"))]
mod shim {
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use arrow_array::RecordBatch;
use crossfire::{mpsc, AsyncRx, MAsyncTx};
use parking_lot::Mutex;
use rustc_hash::FxHashMap;
use super::{Holdover, ShuffleMessage, ShufflePeerId, SHUFFLE_RECV_QUEUE};
use crate::checkpoint::barrier::CheckpointBarrier;
/// Consumer half of the inbound queue; each item is `(sending peer, message)`.
type InboundRx = AsyncRx<mpsc::Array<(ShufflePeerId, ShuffleMessage)>>;
/// Producer half of the inbound queue; the shim stores it but never sends on it.
type InboundTx = MAsyncTx<mpsc::Array<(ShufflePeerId, ShuffleMessage)>>;
/// Stand-in sender used when the `cluster` feature is disabled: it tracks
/// registered peers but performs no network I/O.
pub struct ShuffleSender {
/// Id this sender was created with; shown in `Debug` output.
local_id: ShufflePeerId,
/// Addresses recorded through `register_peer`.
peers: Mutex<FxHashMap<ShufflePeerId, SocketAddr>>,
}
impl std::fmt::Debug for ShuffleSender {
// Only `local_id` is printed; the peer table is elided.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ShuffleSender")
.field("local_id", &self.local_id)
.finish_non_exhaustive()
}
}
impl ShuffleSender {
    /// Creates a sender with an empty peer table.
    #[must_use]
    pub fn new(local_id: ShufflePeerId) -> Self {
        let peers = Mutex::new(FxHashMap::default());
        Self { local_id, peers }
    }

    /// Records (or replaces) the address associated with `peer`.
    #[allow(clippy::unused_async)]
    pub async fn register_peer(&self, peer: ShufflePeerId, addr: SocketAddr) {
        self.peers.lock().insert(peer, addr);
    }

    /// Checks that `peer` is registered; the message itself is discarded.
    ///
    /// # Errors
    /// `NotFound` when `peer` has never been registered.
    #[allow(clippy::unused_async)]
    pub async fn send_to(&self, peer: ShufflePeerId, _msg: &ShuffleMessage) -> io::Result<()> {
        let known = self.peers.lock().contains_key(&peer);
        if !known {
            return Err(io::Error::new(
                io::ErrorKind::NotFound,
                format!("peer {peer} has no registered shuffle address"),
            ));
        }
        Ok(())
    }

    /// Calls [`Self::send_to`] with `barrier` for each of `peers` in order, stopping
    /// at the first failure.
    ///
    /// # Errors
    /// Propagates the first error returned by [`Self::send_to`].
    pub async fn fan_out_barrier(
        &self,
        peers: &[ShufflePeerId],
        barrier: CheckpointBarrier,
    ) -> io::Result<()> {
        let barrier_msg = ShuffleMessage::Barrier(barrier);
        for target in peers.iter().copied() {
            self.send_to(target, &barrier_msg).await?;
        }
        Ok(())
    }
}
/// Stand-in receiver used when the `cluster` feature is disabled: it owns an
/// inbound queue and holdover buffers but runs no server.
pub struct ShuffleReceiver {
/// Id this receiver was bound with; shown in `Debug` output.
local_id: ShufflePeerId,
/// Address obtained from the listener that `bind` opens and immediately drops.
local_addr: SocketAddr,
/// Producer half of the inbound queue; stored but not otherwise used in this module.
#[allow(dead_code)]
tx: InboundTx,
/// Consumer half of the inbound queue; `recv` takes it out of the slot while waiting.
rx: Mutex<Option<InboundRx>>,
/// Signalled whenever the consumer half is put back into `rx`.
rx_returned: Arc<tokio::sync::Notify>,
/// Buffers for batches and barriers drained from the queue but not yet consumed.
holdover: Arc<Holdover>,
}
impl std::fmt::Debug for ShuffleReceiver {
// Prints the id and address only; the remaining fields are elided.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ShuffleReceiver")
.field("local_id", &self.local_id)
.field("local_addr", &self.local_addr)
.finish_non_exhaustive()
}
}
impl ShuffleReceiver {
    /// Resolves `addr` by briefly binding a TCP listener, then builds a receiver
    /// that owns an inbound queue but serves nothing.
    ///
    /// # Errors
    /// Fails when the listener cannot be bound or its address cannot be read.
    pub async fn bind(local_id: ShufflePeerId, addr: SocketAddr) -> io::Result<Self> {
        let listener = tokio::net::TcpListener::bind(addr).await?;
        let local_addr = listener.local_addr()?;
        // The listener is only needed to learn the concrete address.
        drop(listener);
        let (tx, rx) =
            mpsc::bounded_async::<(ShufflePeerId, ShuffleMessage)>(SHUFFLE_RECV_QUEUE);
        Ok(Self {
            local_id,
            local_addr,
            tx,
            rx: Mutex::new(Some(rx)),
            rx_returned: Arc::new(tokio::sync::Notify::new()),
            holdover: Arc::new(Holdover::default()),
        })
    }

    /// Address obtained while binding.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        self.local_addr
    }

    /// Waits for the next inbound message and the peer that sent it.
    ///
    /// The queue's consumer half is taken out of its slot for the duration of the
    /// wait and put back by a guard, including when this future is dropped early.
    /// Returns `None` when receiving from the queue fails.
    pub async fn recv(&self) -> Option<(ShufflePeerId, ShuffleMessage)> {
        loop {
            // The lock is released at the end of this statement, before any await.
            let taken = { self.rx.lock().take() };
            let Some(rx) = taken else {
                self.rx_returned.notified().await;
                continue;
            };
            let mut guard = RxReturnGuard {
                slot: &self.rx,
                notify: &self.rx_returned,
                rx: Some(rx),
            };
            let rx = guard.rx.as_mut()?;
            return rx.recv().await.ok();
        }
    }

    /// Returns every message that is immediately available on the inbound queue.
    ///
    /// Yields an empty vector while a pending `recv` holds the consumer half.
    #[must_use]
    pub fn drain_available(&self) -> Vec<(ShufflePeerId, ShuffleMessage)> {
        let mut out = Vec::new();
        let slot = self.rx.lock();
        if let Some(rx) = slot.as_ref() {
            while let Ok(item) = rx.try_recv() {
                out.push(item);
            }
        }
        out
    }

    /// Moves everything immediately available on the inbound queue into the holdover:
    /// vnode batches into `staged`, barriers into the barrier list. Other message
    /// kinds are discarded.
    ///
    /// Shared by the draining accessors so the staging rules live in one place.
    fn drain_inbound_into(&self, staged: &mut FxHashMap<String, Vec<RecordBatch>>) {
        let slot = self.rx.lock();
        if let Some(rx) = slot.as_ref() {
            while let Ok((from, msg)) = rx.try_recv() {
                match msg {
                    ShuffleMessage::VnodeData(s, _vnode, batch) => {
                        staged.entry(s).or_default().push(batch);
                    }
                    ShuffleMessage::Barrier(b) => {
                        self.holdover.staged_barriers.lock().push((from, b));
                    }
                    _ => {}
                }
            }
        }
    }

    /// Pulls in pending inbound messages, then removes and returns all batches
    /// staged for exactly `stage`.
    #[must_use]
    pub fn drain_vnode_data_for(&self, stage: &str) -> Vec<RecordBatch> {
        let mut staged = self.holdover.staged.lock();
        self.drain_inbound_into(&mut staged);
        staged.remove(stage).unwrap_or_default()
    }

    /// Pulls in pending inbound messages, then removes and returns every staged
    /// stage whose name starts with `prefix`.
    #[must_use]
    pub fn drain_staged_with_prefix(
        &self,
        prefix: &str,
    ) -> FxHashMap<String, Vec<RecordBatch>> {
        let mut staged = self.holdover.staged.lock();
        self.drain_inbound_into(&mut staged);
        let mut out: FxHashMap<String, Vec<RecordBatch>> = FxHashMap::default();
        staged.retain(|stage, batches| {
            if stage.starts_with(prefix) {
                out.insert(stage.clone(), std::mem::take(batches));
                false
            } else {
                true
            }
        });
        out
    }

    /// Appends `batch` to the batches staged for `stage`.
    pub fn stage_batch(&self, stage: String, batch: RecordBatch) {
        self.holdover
            .staged
            .lock()
            .entry(stage)
            .or_default()
            .push(batch);
    }

    /// Removes and returns all staged barriers with their sending peers.
    #[must_use]
    pub fn drain_staged_barriers(&self) -> Vec<(ShufflePeerId, CheckpointBarrier)> {
        std::mem::take(&mut self.holdover.staged_barriers.lock())
    }

    /// Removes and returns every staged batch, each paired with its stage name.
    ///
    /// This does not pull from the inbound queue first.
    #[must_use]
    pub fn drain_all_staged(&self) -> Vec<(String, RecordBatch)> {
        let mut staged = self.holdover.staged.lock();
        staged
            .drain()
            .flat_map(|(stage, batches)| batches.into_iter().map(move |b| (stage.clone(), b)))
            .collect()
    }
}
/// Holds the inbound queue's consumer half on behalf of `recv` and puts it back
/// into the shared slot when dropped.
struct RxReturnGuard<'a> {
/// Slot the consumer half is returned to.
slot: &'a Mutex<Option<InboundRx>>,
/// Notified after the consumer half has been returned.
notify: &'a tokio::sync::Notify,
/// The borrowed consumer half; `None` once it has been handed back.
rx: Option<InboundRx>,
}
impl Drop for RxReturnGuard<'_> {
// Runs on every exit from `recv`, including when its future is dropped mid-wait.
fn drop(&mut self) {
if let Some(rx) = self.rx.take() {
*self.slot.lock() = Some(rx);
self.notify.notify_one();
}
}
}
}
#[cfg(not(feature = "cluster"))]
pub use shim::{ShuffleReceiver, ShuffleSender};
#[cfg(all(test, feature = "cluster"))]
mod tests {
use std::io;
use std::sync::Arc;
use super::*;
/// Binds a receiver for `local_id` on `127.0.0.1` with an OS-assigned port.
async fn bind_on_loopback(local_id: ShufflePeerId) -> ShuffleReceiver {
    let addr = "127.0.0.1:0".parse().unwrap();
    let bound = ShuffleReceiver::bind(local_id, addr).await;
    bound.expect("bind")
}
/// A message sent by peer 1 arrives at the receiver tagged with sender id 1.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn sender_to_receiver_delivers_with_peer_attribution() {
    let recv = bind_on_loopback(2).await;
    let sender = ShuffleSender::new(1);
    sender.register_peer(2, recv.local_addr()).await;

    let sent = ShuffleMessage::Hello(1234);
    sender.send_to(2, &sent).await.unwrap();

    let (from, msg) = recv.recv().await.unwrap();
    assert_eq!(from, 1, "receiver attributes frame to sender id");
    assert_eq!(msg, sent);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn sender_reuses_stream_across_sends() {
let recv = bind_on_loopback(2).await;
let sender = ShuffleSender::new(1);
sender.register_peer(2, recv.local_addr()).await;
for delta in [10u64, 20, 30, 40] {
sender
.send_to(2, &ShuffleMessage::Hello(delta))
.await
.unwrap();
}
let mut got = Vec::new();
for _ in 0..4 {
got.push(recv.recv().await.unwrap().1);
}
assert_eq!(
got,
vec![
ShuffleMessage::Hello(10),
ShuffleMessage::Hello(20),
ShuffleMessage::Hello(30),
ShuffleMessage::Hello(40),
]
);
}
/// Sending to a peer with no known address fails with `NotFound`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn send_to_unregistered_peer_errors() {
    let sender = ShuffleSender::new(1);
    let outcome = sender.send_to(99, &ShuffleMessage::Hello(1)).await;
    let err = outcome.unwrap_err();
    assert_eq!(err.kind(), io::ErrorKind::NotFound);
}
/// A sender with a KV store finds the peer's address there without `register_peer`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn send_discovers_peer_address_from_kv() {
    use crate::cluster::control::{ClusterKv, InMemoryKv};
    use crate::cluster::discovery::NodeId;

    let recv = bind_on_loopback(2).await;
    let advertised = recv.local_addr().to_string();

    // Seed node 2's shuffle address as if the receiver had published it.
    let store = Arc::new(InMemoryKv::new(NodeId(1)));
    store.seed(NodeId(2), SHUFFLE_ADDR_KEY, advertised);
    let kv: Arc<dyn ClusterKv> = store;

    let sender = ShuffleSender::with_kv(1, kv);
    sender.send_to(2, &ShuffleMessage::Hello(7)).await.unwrap();

    let (from, msg) = recv.recv().await.unwrap();
    assert_eq!(from, 1);
    assert_eq!(msg, ShuffleMessage::Hello(7));
}
// After the first receiver is dropped and the peer is re-registered at a new
// address, the sender eventually delivers to the replacement receiver.
#[cfg(not(windows))]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn send_reconnects_after_peer_restart_at_new_address() {
let recv_v1 = bind_on_loopback(2).await;
let addr_v1 = recv_v1.local_addr();
let sender = ShuffleSender::new(1);
sender.register_peer(2, addr_v1).await;
// Establish the first connection and confirm it works.
sender
.send_to(2, &ShuffleMessage::Hello(111))
.await
.unwrap();
let (from, msg) = recv_v1.recv().await.unwrap();
assert_eq!(from, 1);
assert_eq!(msg, ShuffleMessage::Hello(111));
// Simulate a restart: drop the old receiver and bind a new one.
drop(recv_v1);
let recv_v2 = bind_on_loopback(2).await;
let addr_v2 = recv_v2.local_addr();
assert_ne!(addr_v1, addr_v2, "ephemeral rebind must pick a new port");
sender.register_peer(2, addr_v2).await;
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(30);
// Retry until a send lands on the new receiver; early sends may still go to
// the stale connection, so their result is deliberately ignored.
loop {
let _ = sender.send_to(2, &ShuffleMessage::Hello(222)).await;
if let Some((from, ShuffleMessage::Hello(222))) =
tokio::time::timeout(std::time::Duration::from_millis(200), recv_v2.recv())
.await
.ok()
.flatten()
{
assert_eq!(from, 1);
return;
}
assert!(
std::time::Instant::now() < deadline,
"did not deliver to restarted peer within 30s",
);
}
}
// `drain_staged_with_prefix` returns only the stages matching the prefix, leaves
// other stages staged for `drain_vnode_data_for`, and barriers are collected
// separately with their sending peer.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn drain_staged_with_prefix_lifts_subs_and_keeps_operator_stages() {
use arrow_array::{Int64Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use rustc_hash::FxHashMap;
use crate::checkpoint::barrier::CheckpointBarrier;
// Builds a batch with one Int64 column "x" holding `values`.
fn batch(values: Vec<i64>) -> RecordBatch {
let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(values))]).unwrap()
}
// Reads column 0 of `b` back as a plain vector.
fn col(b: &RecordBatch) -> Vec<i64> {
b.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap()
.values()
.to_vec()
}
let recv = bind_on_loopback(2).await;
let sender = ShuffleSender::new(1);
sender.register_peer(2, recv.local_addr()).await;
// Two prefixed stages and one stage without the prefix.
for (stage, vals) in [
("__sub::alpha", vec![1, 2, 3]),
("__sub::beta", vec![4, 5, 6]),
("op_stage", vec![7, 8, 9]),
] {
sender
.send_to(2, &ShuffleMessage::VnodeData(stage.into(), 0, batch(vals)))
.await
.unwrap();
}
sender
.send_to(
2,
&ShuffleMessage::Barrier(CheckpointBarrier {
checkpoint_id: 7,
epoch: 3,
flags: 0,
}),
)
.await
.unwrap();
let mut subs: FxHashMap<String, Vec<RecordBatch>> = FxHashMap::default();
let mut barriers = Vec::new();
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(2);
// Delivery is asynchronous, so poll until both prefixed stages and the barrier arrive.
while subs.len() < 2 || barriers.is_empty() {
for (k, v) in recv.drain_staged_with_prefix("__sub::") {
subs.entry(k).or_default().extend(v);
}
barriers.extend(recv.drain_staged_barriers());
assert!(
std::time::Instant::now() < deadline,
"frames not delivered within 2s",
);
tokio::time::sleep(std::time::Duration::from_millis(5)).await;
}
assert_eq!(subs.len(), 2, "only the two __sub:: stages are returned");
assert_eq!(col(&subs["__sub::alpha"][0]), vec![1, 2, 3]);
assert_eq!(col(&subs["__sub::beta"][0]), vec![4, 5, 6]);
assert_eq!(barriers.len(), 1);
assert_eq!(barriers[0].0, 1, "barrier attributed to sender peer 1");
assert_eq!(barriers[0].1.checkpoint_id, 7);
// The non-prefixed stage stayed in the holdover and is still retrievable.
let op = recv.drain_vnode_data_for("op_stage");
assert_eq!(op.len(), 1);
assert_eq!(col(&op[0]), vec![7, 8, 9]);
// Nothing with the prefix is left after the earlier drains.
assert!(recv.drain_staged_with_prefix("__sub::").is_empty());
}
}