use crate::{
NumEntries, WeakRuntime,
circuit::{
Host, LocalStoreMarker, OwnershipPreference, Runtime, Scope,
metadata::{
BatchSizeStats, EXCHANGE_DESERIALIZATION_TIME_SECONDS, EXCHANGE_DESERIALIZED_BYTES,
EXCHANGE_WAIT_TIME_SECONDS, INPUT_BATCHES_STATS, MetaItem, OUTPUT_BATCHES_STATS,
OperatorLocation, OperatorMeta,
},
operator_traits::{Operator, SinkOperator, SourceOperator},
runtime::{WorkerLocation, WorkerLocations},
tokio::TOKIO,
},
circuit_cache_key,
samply::SamplySpan,
storage::file::format::FixedLen,
};
use binrw::{BinRead, BinWrite};
use crossbeam_utils::CachePadded;
use feldera_storage::fbuf::FBuf;
use futures::{prelude::*, stream::FuturesUnordered};
use itertools::Itertools;
use rkyv::AlignedVec;
use size_of::HumanBytes;
use std::{
borrow::Cow,
collections::HashMap,
io::{Cursor, ErrorKind, IoSlice},
marker::PhantomData,
mem::MaybeUninit,
net::SocketAddr,
ops::Range,
sync::{
Arc, Mutex, OnceLock, RwLock,
atomic::{AtomicPtr, AtomicU64, AtomicUsize, Ordering},
},
time::{Duration, Instant, SystemTime},
};
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream},
sync::{Notify, OnceCell as TokioOnceCell},
time::sleep,
};
use tokio_util::sync::{CancellationToken, DropGuard};
use tracing::warn;
use typedmap::TypedMapKey;
/// Returns the current wall-clock time in microseconds since the Unix epoch.
///
/// The value saturates at `u64::MAX` if the microsecond count does not fit in
/// a `u64`, and is 0 if the system clock reads earlier than the epoch.
fn current_time_usecs() -> u64 {
    match SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => u64::try_from(since_epoch.as_micros()).unwrap_or(u64::MAX),
        Err(_) => 0,
    }
}
// Declares `ExchangeCacheId<T>`, a key built from an `ExchangeId` whose
// associated value is an `Arc<Exchange<T>>`. `Exchange::with_runtime` uses it
// to look up or create the exchange in the runtime's local store.
circuit_cache_key!(local ExchangeCacheId<T>(ExchangeId => Arc<Exchange<T>>));
/// Fixed-size, little-endian header that precedes every payload sent between
/// hosts by `ExchangeServiceClient::exchange` and is parsed by
/// `ExchangeServer::serve`.
#[binrw::binrw]
#[brw(little)]
struct ExchangeHeader {
// Identifies the exchange; the server looks it up in its directory.
exchange_id: ExchangeId,
// Start of the range of sending workers covered by this batch of messages.
sender_start: u32,
// End (exclusive) of the range of sending workers.
sender_end: u32,
// Length in bytes of the payload that follows, excluding padding. The
// `align_after(16)` attribute aligns the end of the header to 16 bytes.
#[brw(align_after(16))]
data_len: u32,
}
// The serialized header is always 16 bytes (four `u32` fields); `to_bytes`
// and `from_bytes` assert that exactly this many bytes are written or read.
impl FixedLen for ExchangeHeader {
const LEN: usize = 16;
}
impl ExchangeHeader {
    /// Encodes this header into its `Self::LEN`-byte little-endian wire form.
    ///
    /// # Panics
    ///
    /// Panics if encoding fails or does not produce exactly `Self::LEN` bytes.
    fn to_bytes(&self) -> [u8; Self::LEN] {
        let mut writer = Cursor::new([0; Self::LEN]);
        self.write_le(&mut writer).unwrap();
        let written = writer.position();
        assert_eq!(written, Self::LEN as u64);
        writer.into_inner()
    }

    /// Decodes a header from its `Self::LEN`-byte little-endian wire form.
    ///
    /// # Panics
    ///
    /// Panics if decoding fails or does not consume exactly `Self::LEN` bytes.
    fn from_bytes(bytes: &[u8; Self::LEN]) -> Self {
        let mut reader = Cursor::new(bytes);
        let header = Self::read_le(&mut reader).unwrap();
        let consumed = reader.position();
        assert_eq!(consumed, Self::LEN as u64);
        header
    }

    /// Reads one header from `stream`.
    ///
    /// Returns `Ok(None)` if the very first read returns zero bytes, that is,
    /// the stream ended cleanly before a new header started. Once at least one
    /// byte has arrived, the rest of the header is read with `read_exact`, and
    /// any error from it is returned to the caller.
    async fn read<S>(stream: &mut S) -> std::io::Result<Option<Self>>
    where
        S: AsyncRead + Unpin,
    {
        let mut raw = [0; Self::LEN];
        let first = stream.read(&mut raw).await?;
        if first == 0 {
            return Ok(None);
        }
        stream.read_exact(&mut raw[first..]).await?;
        Ok(Some(ExchangeHeader::from_bytes(&raw)))
    }
}
/// Sending side of a TCP connection to another host's exchange listener.
struct ExchangeServiceClient {
/// Worker range used by `exchange` only to size its header and I/O-slice
/// buffers. `Clients::connect` sets it to the local worker range.
receivers: Range<usize>,
/// The connection, behind an async mutex so that each `exchange` call writes
/// its whole batch without interleaving with another call.
stream: tokio::sync::Mutex<TcpStream>,
}
impl ExchangeServiceClient {
    /// Sends every buffer in `data` to the remote host over this connection.
    ///
    /// `data` is iterated in nested order (outer vector first). Each buffer is
    /// written as a 16-byte [`ExchangeHeader`], the buffer contents, and enough
    /// zero bytes to pad the contents to a multiple of 16. Every header carries
    /// the same `exchange_id` and `senders` range; only `data_len` varies.
    ///
    /// # Errors
    ///
    /// Returns any I/O error reported while writing to the stream.
    ///
    /// # Panics
    ///
    /// Panics if a sender index or a buffer length does not fit in a `u32`.
    async fn exchange(
        &self,
        exchange_id: ExchangeId,
        senders: Range<usize>,
        data: Vec<Vec<FBuf>>,
    ) -> std::io::Result<()> {
        let expected_messages = senders.len() * self.receivers.len();

        // Encode all headers first so the I/O slices below can borrow them.
        let mut headers = Vec::with_capacity(expected_messages);
        headers.extend(data.iter().flatten().map(|payload| {
            ExchangeHeader {
                exchange_id,
                sender_start: senders.start.try_into().unwrap(),
                sender_end: senders.end.try_into().unwrap(),
                data_len: payload.len().try_into().unwrap(),
            }
            .to_bytes()
        }));

        // Source of padding bytes; a message needs at most 15 of them.
        let zeros = [0u8; 16];

        // Up to three slices per message: header, payload, padding.
        let mut slices = Vec::with_capacity(expected_messages * 3);
        for (header, payload) in headers.iter().zip(data.iter().flatten()) {
            slices.push(IoSlice::new(header));
            if !payload.is_empty() {
                slices.push(IoSlice::new(payload.as_slice()));
            }
            let pad_len = payload.len().next_multiple_of(16) - payload.len();
            if pad_len != 0 {
                slices.push(IoSlice::new(&zeros[..pad_len]));
            }
        }
        let size: usize = slices.iter().map(|slice| slice.len()).sum();

        let mut pending = slices.as_mut_slice();
        let mut stream = self.stream.lock().await;
        let _span = SamplySpan::new("send")
            .with_category("Exchange")
            .with_tooltip(|| format!("send {} for exchange {exchange_id}", HumanBytes::from(size)));

        // Vectored writes may be partial; keep going until everything is sent.
        while !pending.is_empty() {
            let written = stream.write_vectored(pending).await?;
            IoSlice::advance_slices(&mut pending, written);
        }
        Ok(())
    }
}
/// Identifier of an exchange; carried in every `ExchangeHeader`.
type ExchangeId = u32;
/// Shared map from exchange id to the exchange's inner state. `Exchange::new`
/// inserts into it and `ExchangeServer::serve` looks incoming messages up in it.
type ExchangeDirectory = Arc<RwLock<HashMap<ExchangeId, Arc<InnerExchange>>>>;
/// Receiving side of one accepted TCP connection from another host.
struct ExchangeServer {
/// Range of workers that receive the incoming data; the listener sets it to
/// the worker range it was created with.
receivers: Range<usize>,
/// Directory used to resolve the exchange id found in each header.
directory: ExchangeDirectory,
/// The accepted connection.
stream: TcpStream,
}
impl ExchangeServer {
    /// Reads batches of messages from the connection until it closes.
    ///
    /// A batch starts with a header naming the exchange and a range of
    /// senders. It consists of `senders.len() * receivers.len()` messages,
    /// each a header followed by a payload padded to a multiple of 16 bytes.
    /// The payloads are gathered as `data[sender][receiver]` and passed to
    /// `InnerExchange::received` on a spawned task, so that reading the next
    /// batch does not wait for delivery.
    ///
    /// Returns `Ok(())` when the stream ends cleanly at a batch boundary.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from the stream, and `UnexpectedEof` if the peer
    /// closes the connection in the middle of a batch (between messages or
    /// inside a payload).
    ///
    /// # Panics
    ///
    /// Panics if a header names an exchange id that is not in the directory.
    async fn serve(mut self) -> std::io::Result<()> {
        while let Some(header) = ExchangeHeader::read(&mut self.stream).await? {
            let start = Instant::now();
            let exchange = self
                .directory
                .read()
                .unwrap()
                .get(&header.exchange_id)
                .unwrap_or_else(|| panic!("unknown exchange ID {}", header.exchange_id))
                .clone();
            let senders = (header.sender_start as usize)..(header.sender_end as usize);
            let mut bytes = senders.len() * self.receivers.len() * ExchangeHeader::LEN;
            let mut data = Vec::with_capacity(senders.len());

            // The first message reuses the header that was already read.
            let mut header = Some(header);
            for _ in senders.clone() {
                let mut receivers_data = Vec::with_capacity(self.receivers.len());
                for _ in self.receivers.clone() {
                    let header = if let Some(header) = header.take() {
                        header
                    } else {
                        ExchangeHeader::read(&mut self.stream)
                            .await?
                            .ok_or_else(|| std::io::Error::from(ErrorKind::UnexpectedEof))?
                    };
                    let len = header.data_len as usize;
                    let padded_len = len.next_multiple_of(16);
                    let mut payload = AlignedVec::with_capacity(padded_len);
                    let pointer = payload.as_mut_ptr() as *mut MaybeUninit<u8>;
                    // SAFETY: `payload` was allocated with capacity for
                    // `padded_len` bytes, so `pointer` is valid for that many
                    // `MaybeUninit<u8>` writes, and `payload` is not otherwise
                    // accessed while `slice` is alive.
                    let mut slice = unsafe { std::slice::from_raw_parts_mut(pointer, padded_len) };
                    while !slice.is_empty() {
                        // `read_buf` advances `slice` past the bytes it
                        // filled. It returns 0 at end of stream; without this
                        // check a peer that disconnects mid-payload would make
                        // this loop spin forever.
                        if self.stream.read_buf(&mut slice).await? == 0 {
                            return Err(std::io::Error::from(ErrorKind::UnexpectedEof));
                        }
                    }
                    // SAFETY: the loop above initialized the first
                    // `padded_len` bytes, and `len <= padded_len`, which is
                    // within the capacity.
                    unsafe { payload.set_len(len) };
                    receivers_data.push(payload);
                    bytes += padded_len;
                }
                data.push(receivers_data);
            }
            SamplySpan::new("receive")
                .with_start(start)
                .with_category("Exchange")
                .with_tooltip(|| {
                    format!(
                        "exchange {} receive {} from workers {}..{}",
                        exchange.exchange_id,
                        HumanBytes::from(bytes),
                        senders.start,
                        senders.end
                    )
                })
                .record();
            tokio::spawn(async move { exchange.received(senders, data).await });
        }
        Ok(())
    }
}
/// Per-runtime connection state for exchanging data with other hosts.
struct Clients {
/// Weak handle to the runtime, upgraded when the listener is started.
runtime: WeakRuntime,
/// Range of workers that run on this host.
local_workers: Range<usize>,
/// Listener for incoming connections, created on the first call to
/// `connect`. Holds `None` if the runtime is gone or has no local address.
listener: OnceLock<Option<ExchangeListener>>,
/// One lazily connected client per other host in the layout.
clients: Vec<(Host, TokioOnceCell<ExchangeServiceClient>)>,
}
impl Clients {
fn new(runtime: &Runtime) -> Clients {
Self {
local_workers: runtime.layout().local_workers(),
runtime: runtime.downgrade(),
listener: Default::default(),
clients: runtime
.layout()
.other_hosts()
.map(|host| (host.clone(), TokioOnceCell::new()))
.collect(),
}
}
async fn connect(&self, worker: usize) -> &ExchangeServiceClient {
self.listener.get_or_init(|| {
if let Some(runtime) = self.runtime.upgrade()
&& let Some(local_address) = runtime.layout().local_address()
{
let directory = runtime.local_store().get(&DirectoryId).unwrap().clone();
Some(ExchangeListener::new(
local_address,
directory,
self.local_workers.clone(),
))
} else {
None
}
});
let (host, cell) = self
.clients
.iter()
.find(|(host, _client)| host.workers.contains(&worker))
.unwrap();
cell.get_or_init(|| async {
let stream = loop {
match TcpStream::connect(host.address).await {
Ok(stream) => break stream,
Err(error) => println!(
"connection to {} failed ({error}), waiting to retry",
host.address
),
}
sleep(std::time::Duration::from_millis(1000)).await;
};
stream.set_nodelay(true).unwrap();
stream.set_zero_linger().unwrap();
ExchangeServiceClient {
receivers: self.local_workers.clone(),
stream: tokio::sync::Mutex::new(stream),
}
})
.await
}
}
/// Heap-allocated payload of a [`Callback`]: an optional boxed closure.
struct CallbackInner {
    cb: Option<Box<dyn Fn() + Send + Sync>>,
}

impl CallbackInner {
    /// Creates a payload with no closure installed.
    fn empty() -> Self {
        Self { cb: None }
    }

    /// Creates a payload that wraps `cb`.
    fn new<F>(cb: F) -> Self
    where
        F: Fn() + Send + Sync + 'static,
    {
        let cb = Box::new(cb) as Box<dyn Fn() + Send + Sync>;
        Self { cb: Some(cb) }
    }
}

/// A callback slot that can be replaced and invoked through `&self`.
///
/// The slot always owns exactly one boxed [`CallbackInner`], stored as a raw
/// pointer in an `AtomicPtr`.
struct Callback(AtomicPtr<CallbackInner>);

impl Callback {
    /// Creates a slot with no closure installed; `call` on it does nothing.
    fn empty() -> Self {
        Self(AtomicPtr::new(Box::into_raw(Box::new(
            CallbackInner::empty(),
        ))))
    }

    /// Installs `cb`, dropping whatever was installed before.
    ///
    /// NOTE(review): the previous payload is freed immediately, so a `call`
    /// running concurrently on another thread could still be reading it.
    /// Callers presumably register callbacks before the slot is invoked
    /// concurrently; confirm.
    fn set_callback(&self, cb: impl Fn() + Send + Sync + 'static) {
        let old_callback = self.0.swap(
            Box::into_raw(Box::new(CallbackInner::new(cb))),
            Ordering::AcqRel,
        );
        // SAFETY: every pointer stored in the slot comes from `Box::into_raw`,
        // and the swap transferred ownership of the old one to us.
        let old_callback = unsafe { Box::from_raw(old_callback) };
        drop(old_callback);
    }

    /// Invokes the installed closure, if any.
    fn call(&self) {
        // SAFETY: the slot always holds a pointer obtained from
        // `Box::into_raw`, which is only freed when it is replaced or when the
        // slot is dropped.
        if let Some(cb) = &unsafe { &*self.0.load(Ordering::Acquire) }.cb {
            cb()
        }
    }
}

impl Drop for Callback {
    /// Frees the payload the slot currently owns. Without this, the box and
    /// everything the closure captured would leak when the slot is dropped.
    fn drop(&mut self) {
        let current = *self.0.get_mut();
        // SAFETY: `current` came from `Box::into_raw`, and `&mut self`
        // guarantees nobody else can observe the pointer any more.
        drop(unsafe { Box::from_raw(current) });
    }
}
/// Type-independent state of an exchange: readiness counters, callbacks and
/// the hook that stores payloads received from other hosts.
struct InnerExchange {
/// Identifier of this exchange.
exchange_id: ExchangeId,
/// Total number of workers, as reported by `Runtime::num_workers()` when the
/// exchange was created.
npeers: usize,
/// Range of workers that run on this host.
local_workers: Range<usize>,
/// Per receiver: number of senders whose data has arrived. A receiver is
/// ready when its counter equals `npeers`; `try_receive_all` resets it to 0.
receiver_counters: Vec<AtomicUsize>,
/// Per receiver: callback invoked when its counter reaches `npeers`.
receiver_callbacks: Vec<Callback>,
/// Per sender: starts at `npeers` (ready to send), is reset to 0 by
/// `try_send_all`, and is incremented again as the sender's data is consumed
/// locally or sent to remote hosts.
sender_counters: Vec<CachePadded<AtomicUsize>>,
/// Per sender: callback invoked when its counter returns to `npeers`.
sender_callbacks: Vec<Callback>,
/// Number of local workers that have called `try_send_all` since the last
/// remote send was started.
sent: AtomicUsize,
/// Connections to the other hosts.
clients: Arc<Clients>,
/// One `Notify` per (remote sender, local receiver) pair, indexed by
/// `sender_notify`.
sender_notifies: Vec<Notify>,
/// Called as `deliver(data, sender, receiver)` for each payload received from
/// another host.
deliver: Box<dyn Fn(AlignedVec, usize, usize) + Send + Sync + 'static>,
}
impl InnerExchange {
/// Creates the state for exchange `exchange_id`.
///
/// All receivers start with nothing received (counter 0) and all senders
/// start ready to send (counter `npeers`). `deliver` is stored and later
/// called for each payload received from another host.
///
/// Panics if `Runtime::runtime()` returns `None`.
fn new(
exchange_id: ExchangeId,
deliver: impl Fn(AlignedVec, usize, usize) + Send + Sync + 'static,
clients: Arc<Clients>,
) -> InnerExchange {
let runtime = Runtime::runtime().unwrap();
let npeers = Runtime::num_workers();
let local_workers = runtime.layout().local_workers();
let n_local_workers = local_workers.len();
let n_remote_workers = npeers - n_local_workers;
Self {
exchange_id,
npeers,
local_workers,
clients,
receiver_counters: (0..npeers).map(|_| AtomicUsize::new(0)).collect(),
receiver_callbacks: (0..npeers).map(|_| Callback::empty()).collect(),
// One slot per (remote sender, local receiver) pair.
sender_notifies: (0..n_local_workers * n_remote_workers)
.map(|_| Notify::new())
.collect(),
// A counter equal to `npeers` means "ready to send".
sender_counters: (0..npeers)
.map(|_| CachePadded::new(AtomicUsize::new(npeers)))
.collect(),
sender_callbacks: (0..npeers).map(|_| Callback::empty()).collect(),
deliver: Box::new(deliver),
sent: AtomicUsize::new(0),
}
}
/// Returns this exchange's identifier.
#[allow(dead_code)]
fn exchange_id(&self) -> ExchangeId {
self.exchange_id
}
/// Returns the `Notify` for remote `sender` and local `receiver`.
///
/// `try_receive_all` signals it once the receiver has consumed that sender's
/// data, and `received` waits on it.
fn sender_notify(&self, sender: usize, receiver: usize) -> &Notify {
debug_assert!(sender < self.npeers && !self.local_workers.contains(&sender));
debug_assert!(self.local_workers.contains(&receiver));
let n_local_workers = self.local_workers.len();
// Number the remote senders densely by skipping over the local range:
// senders past the local range shift down by the number of local workers.
let sender_ofs = if sender >= self.local_workers.start {
sender - n_local_workers
} else {
sender
};
let receiver_ofs = receiver - self.local_workers.start;
&self.sender_notifies[sender_ofs * n_local_workers + receiver_ofs]
}
/// Handles a batch of payloads received from the remote workers `senders`.
///
/// `data[i][j]` is the payload from the `i`th sender in `senders` for the
/// `j`th local worker. Each payload is passed to `deliver`, then every local
/// receiver's counter is advanced by the number of senders, and finally the
/// function waits until every (sender, receiver) pair has been signalled
/// through its `Notify`.
///
/// Panics if an inner vector does not have one entry per local worker.
async fn received(self: &Arc<Self>, senders: Range<usize>, data: Vec<Vec<AlignedVec>>) {
let _span = SamplySpan::new("deliver")
.with_category("Exchange")
.with_tooltip(|| {
format!(
"exchange {} deliver from workers {}..{}",
self.exchange_id, senders.start, senders.end
)
});
let receivers = &self.local_workers;
for (sender, data) in senders.clone().zip(data.into_iter()) {
assert_eq!(data.len(), receivers.len());
for (receiver, data) in receivers.clone().zip(data.into_iter()) {
(self.deliver)(data, sender, receiver);
}
}
for receiver in receivers.clone() {
let n = senders.len();
let old_counter = self.receiver_counters[receiver].fetch_add(n, Ordering::AcqRel);
// The counter has now reached `npeers`: every sender's data is in place.
if old_counter >= self.npeers - n {
self.receiver_callbacks[receiver].call();
}
}
// Wait for the local receivers to consume what was just delivered.
for sender in senders {
for receiver in receivers.clone() {
self.sender_notify(sender, receiver).notified().await;
}
}
}
/// Returns the index of the mailbox from `sender` to `receiver` in a flat,
/// sender-major array of `npeers * npeers` mailboxes.
fn mailbox_index(&self, sender: usize, receiver: usize) -> usize {
debug_assert!(sender < self.npeers);
debug_assert!(receiver < self.npeers);
sender * self.npeers + receiver
}
/// Returns true if local worker `sender` may send, that is, its counter
/// equals `npeers`.
fn ready_to_send(&self, sender: usize) -> bool {
debug_assert!(self.local_workers.contains(&sender));
self.sender_counters[sender].load(Ordering::Acquire) == self.npeers
}
/// Returns true if `receiver` has data from all `npeers` senders.
fn ready_to_receive(&self, receiver: usize) -> bool {
debug_assert!(receiver < self.npeers);
self.receiver_counters[receiver].load(Ordering::Acquire) == self.npeers
}
/// Installs `cb` as the callback for `sender`, replacing any earlier one.
fn register_sender_callback<F>(&self, sender: usize, cb: F)
where
F: Fn() + Send + Sync + 'static,
{
debug_assert!(sender < self.npeers);
self.sender_callbacks[sender].set_callback(cb);
}
/// Installs `cb` as the callback for `receiver`, replacing any earlier one.
fn register_receiver_callback<F>(&self, receiver: usize, cb: F)
where
F: Fn() + Send + Sync + 'static,
{
debug_assert!(receiver < self.npeers);
self.receiver_callbacks[receiver].set_callback(cb);
}
}
/// Contents of one sender-to-receiver mailbox.
#[derive(Clone, Debug)]
pub enum Mailbox<T> {
/// Serialized data waiting to be sent to a worker on another host.
Tx(FBuf),
/// Serialized data received from a worker on another host; it is
/// deserialized when the receiver collects it.
Rx(AlignedVec),
/// A value passed between workers on the same host without serialization.
Plain(T),
}
impl<T> Mailbox<T> {
    /// Extracts the serialized buffer of a `Tx` mailbox.
    ///
    /// Returns `None` for the `Rx` and `Plain` variants.
    fn into_tx(self) -> Option<FBuf> {
        if let Mailbox::Tx(bytes) = self {
            Some(bytes)
        } else {
            None
        }
    }
}
/// An all-to-all exchange of values of type `T` between workers.
pub(crate) struct Exchange<T> {
/// Type-independent state (counters, callbacks, connections).
inner: Arc<InnerExchange>,
/// `npeers * npeers` mailboxes, indexed by `InnerExchange::mailbox_index`.
mailboxes: Arc<Vec<Mutex<Option<Mailbox<T>>>>>,
/// Converts bytes received from another host back into a `T`.
deserialize: Box<dyn Fn(AlignedVec) -> T + Send + Sync>,
/// Total time spent in `try_receive_all` collecting and deserializing
/// mailboxes, in microseconds.
deserialization_usecs: AtomicU64,
/// Total number of bytes deserialized by `try_receive_all`.
deserialized_bytes: AtomicUsize,
}
/// Handle to the background task that accepts exchange connections.
///
/// The guard is never read; it is the drop guard of the cancellation token the
/// accept loop watches, so the handle is kept only for what happens on drop.
#[allow(dead_code)]
struct ExchangeListener(DropGuard);
impl ExchangeListener {
fn new(address: SocketAddr, directory: ExchangeDirectory, receivers: Range<usize>) -> Self {
let token = CancellationToken::new();
let drop = token.clone().drop_guard();
TOKIO.spawn(async move {
println!("listening on {address}");
let listener = TcpListener::bind(address).await.unwrap();
while let Some(stream) = tokio::select! {
stream = listener.accept() => Some(stream),
_ = token.cancelled() => None,
} {
match stream {
Ok((stream, _address)) => {
tokio::spawn(
ExchangeServer {
receivers: receivers.clone(),
directory: directory.clone(),
stream,
}
.serve(),
);
}
Err(error) => warn!("Error accepting connection from {address}: {error}"),
}
}
});
Self(drop)
}
}
impl<T> Exchange<T>
where
T: Clone + Send + 'static,
{
/// Creates exchange `exchange_id` and registers its inner state in
/// `directory`.
///
/// Payloads received from other hosts are stored as `Mailbox::Rx` in the
/// mailbox for their (sender, receiver) pair.
///
/// Panics if `directory` already has an entry for `exchange_id`.
fn new(
exchange_id: ExchangeId,
clients: Arc<Clients>,
directory: ExchangeDirectory,
deserialize: Box<dyn Fn(AlignedVec) -> T + Send + Sync>,
) -> Self {
let npeers = Runtime::num_workers();
let mailboxes: Arc<Vec<Mutex<Option<Mailbox<T>>>>> =
Arc::new((0..npeers * npeers).map(|_| Mutex::new(None)).collect());
let mailboxes2 = mailboxes.clone();
// Hook that `InnerExchange::received` calls for each remote payload.
let deliver = move |data, sender, receiver| {
let index: usize = sender * npeers + receiver;
let mut mailbox = mailboxes2[index].lock().unwrap();
// The previous contents must have been consumed already.
assert!((*mailbox).is_none());
*mailbox = Some(Mailbox::Rx(data));
};
let inner = Arc::new(InnerExchange::new(exchange_id, deliver, clients));
directory
.write()
.unwrap()
.entry(exchange_id)
.and_modify(|_| panic!())
.or_insert(inner.clone());
Self {
inner,
mailboxes,
deserialize,
deserialization_usecs: AtomicU64::new(0),
deserialized_bytes: AtomicUsize::new(0),
}
}
/// Returns this exchange's identifier.
#[allow(dead_code)]
fn exchange_id(&self) -> ExchangeId {
self.inner.exchange_id()
}
/// Returns the mailbox that carries data from `sender` to `receiver`.
fn mailbox(&self, sender: usize, receiver: usize) -> &Mutex<Option<Mailbox<T>>> {
&self.mailboxes[self.inner.mailbox_index(sender, receiver)]
}
/// Returns the exchange with id `exchange_id` from `runtime`'s local store,
/// creating it (together with the shared directory and client set) if it is
/// not there yet. `deserialize` is only used when the exchange is created.
pub(crate) fn with_runtime(
runtime: &Runtime,
exchange_id: ExchangeId,
deserialize: Box<dyn Fn(AlignedVec) -> T + Send + Sync>,
) -> Arc<Self> {
let directory = runtime
.local_store()
.entry(DirectoryId)
.or_insert_with(|| Arc::new(RwLock::new(HashMap::new())))
.clone();
let clients = runtime
.local_store()
.entry(ClientsId)
.or_insert_with(|| {
Arc::new(Clients::new(runtime))
})
.clone();
runtime
.local_store()
.entry(ExchangeCacheId::new(exchange_id))
.or_insert_with(|| {
Arc::new(Exchange::new(
exchange_id,
clients.clone(),
directory,
deserialize,
))
})
.value()
.clone()
}
/// Returns true if local worker `sender` may call `try_send_all`.
pub fn ready_to_send(&self, sender: usize) -> bool {
self.inner.ready_to_send(sender)
}
/// Like `try_send_all`, but takes plain values and wraps them itself.
///
/// `data` is paired with `WorkerLocations::new()`: items for local workers
/// are sent as `Mailbox::Plain`, and items for remote workers are passed
/// through `serialize` and sent as `Mailbox::Tx`.
pub(crate) fn try_send_all_with_serializer<F>(
self: &Arc<Self>,
sender: usize,
data: impl Iterator<Item = T>,
mut serialize: F,
) -> bool
where
F: FnMut(T) -> FBuf + Send + Sync,
{
self.try_send_all(
sender,
data.zip(WorkerLocations::new())
.map(|(data, location)| match location {
WorkerLocation::Local => Mailbox::Plain(data),
WorkerLocation::Remote => Mailbox::Tx(serialize(data)),
}),
)
}
/// Sends one item from worker `sender` to every worker, taking the items from
/// `data` in receiver order.
///
/// Returns false without consuming `data` if `sender` is not ready to send.
/// Otherwise stores the items in the mailboxes and returns true. Items for
/// local receivers are available at once. Items for remote receivers are sent
/// by a background task that starts when the last local worker of the round
/// has called this function.
///
/// Panics if `data` yields fewer than `npeers` items, or if an item is
/// `Mailbox::Plain` for a remote receiver or anything else for a local one.
pub(crate) fn try_send_all(
self: &Arc<Self>,
sender: usize,
data: impl Iterator<Item = Mailbox<T>>,
) -> bool {
let npeers = self.inner.npeers;
// Claim the send slot: the counter goes from "ready" (`npeers`) to 0.
if self.inner.sender_counters[sender]
.compare_exchange(npeers, 0, Ordering::AcqRel, Ordering::Acquire)
.is_err()
{
return false;
}
let local_workers = &self.inner.local_workers;
for (receiver, item) in (0..npeers).zip_eq(data.take(npeers)) {
let is_local = local_workers.contains(&receiver);
let mailbox_is_local = matches!(&item, Mailbox::Plain(_));
assert_eq!(is_local, mailbox_is_local);
*self.mailbox(sender, receiver).lock().unwrap() = Some(item);
if is_local {
let old_counter =
self.inner.receiver_counters[receiver].fetch_add(1, Ordering::AcqRel);
// This item made the receiver's counter reach `npeers`.
if old_counter >= npeers - 1 {
self.inner.receiver_callbacks[receiver].call();
}
}
}
// Nothing more to do if every worker is local, or if some local workers
// have not sent yet; the last one to arrive starts the remote send.
if npeers == local_workers.len()
|| self.inner.sent.fetch_add(1, Ordering::AcqRel) + 1 != local_workers.len()
{
return true;
}
// Reset the arrival count for the next round.
self.inner.sent.store(0, Ordering::Release);
let this = self.clone();
let runtime = Runtime::runtime().unwrap();
TOKIO.spawn(async move {
let mut futures = FuturesUnordered::new();
let senders = &this.inner.local_workers;
for host in runtime.layout().other_hosts() {
let receivers = &host.workers;
let mut serialized_bytes = 0;
// items[sender][receiver]: the buffers from every local sender for every
// worker on this host, taken out of their mailboxes.
let items: Vec<Vec<FBuf>> = senders
.clone()
.map(|sender| {
receivers
.clone()
.map(|receiver| {
let serialized = this
.mailbox(sender, receiver)
.lock()
.unwrap()
.take()
.unwrap()
.into_tx()
.expect("remote mailboxes should always be serialized");
serialized_bytes += serialized.len();
serialized
})
.collect()
})
.collect();
let client = this.inner.clients.connect(receivers.start).await;
futures.push(client.exchange(this.inner.exchange_id, senders.clone(), items));
}
// Wait for every host's send to finish; an I/O error panics this task.
while let Some(result) = futures.next().await {
result.unwrap();
}
// Credit each local sender for all remote receivers at once.
let n = npeers - senders.len();
for sender in senders.clone() {
let old_counter = this.inner.sender_counters[sender].fetch_add(n, Ordering::AcqRel);
if old_counter >= npeers - n {
this.inner.sender_callbacks[sender].call();
}
}
});
true
}
/// Returns true if `receiver` has data from every sender.
pub(crate) fn ready_to_receive(&self, receiver: usize) -> bool {
self.inner.ready_to_receive(receiver)
}
/// Receives one item from every sender for worker `receiver` and passes the
/// items to `cb` in sender order.
///
/// Returns false without calling `cb` if some sender's data is still missing.
/// Data that arrived from another host is deserialized first. Once an item
/// has been passed to `cb`, its sender is told that the mailbox is free:
/// local senders through their counter and callback, remote senders through
/// their `Notify`.
pub(crate) fn try_receive_all<F>(&self, receiver: usize, mut cb: F) -> bool
where
F: FnMut(T),
{
let npeers = self.inner.npeers;
// Claim the receive slot: the counter goes from `npeers` back to 0.
if self.inner.receiver_counters[receiver]
.compare_exchange(npeers, 0, Ordering::AcqRel, Ordering::Acquire)
.is_err()
{
return false;
}
let start = Instant::now();
let mut deserialized_bytes = 0;
let data = (0..self.inner.npeers)
.map(|sender| {
let mailbox = self
.mailbox(sender, receiver)
.lock()
.unwrap()
.take()
.unwrap();
match mailbox {
Mailbox::Plain(item) => item,
// `Tx` buffers are taken by the sending task and never reach a receiver.
Mailbox::Tx(_) => unreachable!(),
Mailbox::Rx(bytes) => {
deserialized_bytes += bytes.len();
(self.deserialize)(bytes)
}
}
})
.collect_vec();
self.deserialized_bytes
.fetch_add(deserialized_bytes, Ordering::Relaxed);
self.deserialization_usecs
.fetch_add(start.elapsed().as_micros() as u64, Ordering::Relaxed);
SamplySpan::new("deserialize")
.with_category("Exchange")
.with_start(start)
.with_tooltip(|| {
format!(
"exchange {} deserialize {}",
self.exchange_id(),
HumanBytes::from(deserialized_bytes)
)
})
.record();
for (sender, data) in data.into_iter().enumerate() {
cb(data);
if self.inner.local_workers.contains(&sender) {
let old_counter = self.inner.sender_counters[sender].fetch_add(1, Ordering::AcqRel);
// The sender's counter is back at `npeers`: it may send again.
if old_counter >= self.inner.npeers - 1 {
self.inner.sender_callbacks[sender].call();
}
} else {
// Wakes the `received` task that is waiting for this pair.
self.inner.sender_notify(sender, receiver).notify_one();
}
}
true
}
/// Installs the callback invoked when `sender` becomes ready to send again.
pub(crate) fn register_sender_callback<F>(&self, sender: usize, cb: F)
where
F: Fn() + Send + Sync + 'static,
{
self.inner.register_sender_callback(sender, cb)
}
/// Installs the callback invoked when `receiver` has data from all senders.
pub(crate) fn register_receiver_callback<F>(&self, receiver: usize, cb: F)
where
F: Fn() + Send + Sync + 'static,
{
self.inner.register_receiver_callback(receiver, cb)
}
}
/// Sink operator that partitions its input and sends one piece to every
/// worker through an [`Exchange`].
pub struct ExchangeSender<D, T, L>
where
T: Send + 'static + Clone,
{
/// Index of the worker this operator runs on.
worker_index: usize,
/// Value returned by `Operator::location`.
location: OperatorLocation,
/// Splits an input into one mailbox per worker, appending them to `outputs`.
partition: L,
/// Scratch buffer that `partition` fills on every evaluation.
outputs: Vec<Mailbox<T>>,
/// The exchange; each item travels together with a `bool` flush flag.
exchange: Arc<Exchange<(T, bool)>>,
/// Statistics about the sizes of the input batches.
input_batch_stats: BatchSizeStats,
/// Set by `flush`; attached to the items of the next evaluation, then
/// cleared.
flushed: bool,
/// Timestamp (microseconds since the epoch) of the latest send, shared with
/// the receiver, which uses it to measure waiting time.
start_wait_usecs: Arc<AtomicU64>,
phantom: PhantomData<D>,
}
impl<D, T, L> ExchangeSender<D, T, L>
where
    T: Send + 'static + Clone,
{
    /// Creates the sending half of an exchange for worker `worker_index`.
    ///
    /// The output buffer is preallocated with room for one mailbox per worker.
    fn new(
        worker_index: usize,
        location: OperatorLocation,
        exchange: Arc<Exchange<(T, bool)>>,
        start_wait_usecs: Arc<AtomicU64>,
        partition: L,
    ) -> Self {
        let num_workers = Runtime::num_workers();
        debug_assert!(worker_index < num_workers);
        let outputs = Vec::with_capacity(num_workers);
        let input_batch_stats = BatchSizeStats::new();
        Self {
            worker_index,
            location,
            partition,
            outputs,
            exchange,
            input_batch_stats,
            flushed: false,
            start_wait_usecs,
            phantom: PhantomData,
        }
    }
}
impl<D, T, L> Operator for ExchangeSender<D, T, L>
where
    D: 'static,
    T: Send + 'static + Clone,
    L: 'static,
{
    /// Operator name: `"ExchangeSender"`.
    fn name(&self) -> Cow<'static, str> {
        Cow::Borrowed("ExchangeSender")
    }

    /// Reports the input batch size statistics.
    fn metadata(&self, meta: &mut OperatorMeta) {
        meta.extend(metadata! {
            INPUT_BATCHES_STATS => self.input_batch_stats.metadata(),
        });
    }

    /// Returns the location the operator was created with.
    fn location(&self) -> OperatorLocation {
        self.location
    }

    /// Does nothing.
    fn clock_start(&mut self, _scope: Scope) {}

    /// Does nothing.
    fn clock_end(&mut self, _scope: Scope) {}

    /// The sender is asynchronous: it is only ready once its previous output
    /// has been consumed.
    fn is_async(&self) -> bool {
        true
    }

    /// Registers `cb` to be called when this worker may send again.
    fn register_ready_callback<F>(&mut self, cb: F)
    where
        F: Fn() + Send + Sync + 'static,
    {
        let Self {
            exchange,
            worker_index,
            ..
        } = self;
        exchange.register_sender_callback(*worker_index, cb)
    }

    /// Returns true if the exchange will accept data from this worker.
    fn ready(&self) -> bool {
        let worker = self.worker_index;
        self.exchange.ready_to_send(worker)
    }

    /// Always true.
    fn fixedpoint(&self, _scope: Scope) -> bool {
        true
    }

    /// Marks the next evaluation's output as flushed.
    fn flush(&mut self) {
        self.flushed = true;
    }
}
impl<D, T, L> SinkOperator<D> for ExchangeSender<D, T, L>
where
    D: Clone + NumEntries + 'static,
    T: Clone + Send + 'static,
    L: FnMut(D, &mut Vec<Mailbox<T>>) + 'static,
{
    /// Clones `input` and forwards the copy to `eval_owned`.
    async fn eval(&mut self, input: &D) {
        let owned = input.clone();
        self.eval_owned(owned).await
    }

    /// Partitions `input` and hands one mailbox to the exchange per worker.
    ///
    /// The current flush flag travels with every item: it is appended as a
    /// trailing byte to serialized (`Tx`) buffers and paired with plain
    /// values. The flag is cleared afterwards. The send time is published
    /// first so that the receiver can measure how long it waited.
    async fn eval_owned(&mut self, input: D) {
        self.input_batch_stats.add_batch(input.num_entries_deep());
        debug_assert!(self.ready());

        self.outputs.clear();
        (self.partition)(input, &mut self.outputs);

        self.start_wait_usecs
            .store(current_time_usecs(), Ordering::Release);

        let flushed = self.flushed;
        let tagged = self.outputs.drain(..).map(move |mailbox| match mailbox {
            Mailbox::Plain(item) => Mailbox::Plain((item, flushed)),
            Mailbox::Tx(mut bytes) => {
                bytes.push(flushed as u8);
                Mailbox::Tx(bytes)
            }
            Mailbox::Rx(_) => unreachable!(),
        });
        let res = self.exchange.try_send_all(self.worker_index, tagged);
        self.flushed = false;
        debug_assert!(res);
    }

    /// The sender prefers to receive its input by value.
    fn input_preference(&self) -> OwnershipPreference {
        OwnershipPreference::PREFER_OWNED
    }
}
/// Source operator that collects one item from every worker through an
/// [`Exchange`] and combines them into a single output.
pub struct ExchangeReceiver<IF, T, L>
where
T: Send + 'static + Clone,
{
/// Index of the worker this operator runs on.
worker_index: usize,
/// Value returned by `Operator::location`.
location: OperatorLocation,
/// Produces the initial output value for each evaluation.
init: IF,
/// Folds one received item into the output value.
combine: L,
/// The exchange; each item arrives together with a `bool` flush flag.
exchange: Arc<Exchange<(T, bool)>>,
/// Number of received items whose flush flag was set since the count was
/// last reset.
flush_count: usize,
/// Set once `flush_count` reaches the number of workers; cleared by `flush`.
flush_complete: bool,
/// Timestamp of the latest send, written by the paired sender.
start_wait_usecs: Arc<AtomicU64>,
/// Accumulated time, in microseconds, between a send and the moment the
/// receiver became ready.
total_wait_time: Arc<AtomicU64>,
/// Statistics about the sizes of the output batches.
output_batch_stats: BatchSizeStats,
}
impl<IF, T, L> ExchangeReceiver<IF, T, L>
where
    T: Send + 'static + Clone,
{
    /// Creates the receiving half of an exchange for worker `worker_index`.
    ///
    /// The flush state starts cleared and the accumulated wait time at zero.
    pub(crate) fn new(
        worker_index: usize,
        location: OperatorLocation,
        exchange: Arc<Exchange<(T, bool)>>,
        init: IF,
        start_wait_usecs: Arc<AtomicU64>,
        combine: L,
    ) -> Self {
        debug_assert!(worker_index < Runtime::num_workers());
        let output_batch_stats = BatchSizeStats::new();
        let total_wait_time = Arc::new(AtomicU64::new(0));
        Self {
            worker_index,
            location,
            init,
            combine,
            exchange,
            flush_count: 0,
            flush_complete: false,
            start_wait_usecs,
            total_wait_time,
            output_batch_stats,
        }
    }
}
impl<D, T, L> Operator for ExchangeReceiver<D, T, L>
where
    D: 'static,
    T: Send + 'static + Clone,
    L: 'static,
{
    /// Operator name: `"ExchangeReceiver"`.
    fn name(&self) -> Cow<'static, str> {
        Cow::Borrowed("ExchangeReceiver")
    }

    /// Returns the location the operator was created with.
    fn location(&self) -> OperatorLocation {
        self.location
    }

    /// Reports output batch statistics, accumulated wait time, and the
    /// exchange's deserialization time and byte count.
    fn metadata(&self, meta: &mut OperatorMeta) {
        meta.extend(metadata! {
            OUTPUT_BATCHES_STATS => self.output_batch_stats.metadata(),
            EXCHANGE_WAIT_TIME_SECONDS => MetaItem::Duration(Duration::from_micros(self.total_wait_time.load(Ordering::Acquire))),
            EXCHANGE_DESERIALIZATION_TIME_SECONDS => MetaItem::Duration(Duration::from_micros(self.exchange.deserialization_usecs.load(Ordering::Acquire))),
            EXCHANGE_DESERIALIZED_BYTES => MetaItem::bytes(self.exchange.deserialized_bytes.load(Ordering::Acquire)),
        });
    }

    /// The receiver is asynchronous: it is only ready once every sender's
    /// data has arrived.
    fn is_async(&self) -> bool {
        true
    }

    /// Registers `cb` to be called when this worker's data is complete.
    ///
    /// `cb` is wrapped so that, when the exchange is ready to receive, the
    /// time elapsed since the sender's recorded timestamp is added to the
    /// total wait time. The timestamp is reset to 0 in the process, so each
    /// send is counted at most once.
    fn register_ready_callback<F>(&mut self, cb: F)
    where
        F: Fn() + Send + Sync + 'static,
    {
        let send_time = self.start_wait_usecs.clone();
        let wait_total = self.total_wait_time.clone();
        let exchange = self.exchange.clone();
        let worker_index = self.worker_index;
        let on_ready = move || {
            if exchange.ready_to_receive(worker_index) {
                let started = send_time.swap(0, Ordering::Acquire);
                if started != 0 {
                    // Only count the interval if the clock moved forward.
                    let waited = current_time_usecs()
                        .checked_sub(started)
                        .filter(|&usecs| usecs > 0);
                    if let Some(usecs) = waited {
                        wait_total.fetch_add(usecs, Ordering::AcqRel);
                    }
                }
            }
            cb()
        };
        self.exchange
            .register_receiver_callback(self.worker_index, on_ready)
    }

    /// Returns true if data from every sender is available.
    fn ready(&self) -> bool {
        let worker = self.worker_index;
        self.exchange.ready_to_receive(worker)
    }

    /// Always true.
    fn fixedpoint(&self, _scope: Scope) -> bool {
        true
    }

    /// Starts a new flush: completion is reported again only after every
    /// worker's flushed item has been received.
    fn flush(&mut self) {
        self.flush_complete = false;
    }

    /// Returns true once flushed items from all workers have been received.
    fn is_flush_complete(&self) -> bool {
        self.flush_complete
    }
}
impl<D, IF, T, L> SourceOperator<D> for ExchangeReceiver<IF, T, L>
where
    D: NumEntries + 'static,
    T: Clone + Send + 'static,
    IF: Fn() -> D + 'static,
    L: Fn(&mut D, T) + 'static,
{
    /// Receives one item from every worker and folds them into a fresh output.
    ///
    /// Items whose flush flag is set are counted. When the count reaches the
    /// number of workers, the flush is marked complete and the count is reset.
    async fn eval(&mut self) -> D {
        debug_assert!(self.ready());

        let mut output = (self.init)();
        let worker = self.worker_index;
        let res = self.exchange.try_receive_all(worker, |(item, flushed)| {
            self.flush_count += usize::from(flushed);
            (self.combine)(&mut output, item)
        });

        let all_flushed = self.flush_count == Runtime::num_workers();
        if all_flushed {
            self.flush_complete = true;
            self.flush_count = 0;
        }
        debug_assert!(res);

        let entries = output.num_entries_deep();
        self.output_batch_stats.add_batch(entries);
        output
    }
}
/// Local-store key for the `Clients` shared by all exchanges in a runtime.
#[derive(Hash, PartialEq, Eq)]
struct ClientsId;
impl TypedMapKey<LocalStoreMarker> for ClientsId {
type Value = Arc<Clients>;
}
/// Local-store key for the `ExchangeDirectory` shared by all exchanges in a
/// runtime.
#[derive(Hash, PartialEq, Eq)]
struct DirectoryId;
impl TypedMapKey<LocalStoreMarker> for DirectoryId {
type Value = ExchangeDirectory;
}
/// Creates a connected [`ExchangeSender`] / [`ExchangeReceiver`] pair for the
/// current worker.
///
/// Returns `None` when the runtime has a single worker. Otherwise a new
/// exchange id is taken from the runtime's sequence and the exchange is looked
/// up or created in the runtime's local store.
///
/// * `init` produces the receiver's initial output value.
/// * `partition` splits a sender input into one mailbox per worker.
/// * `deserialize` decodes bytes received from another host. The sender
///   appends the flush flag as a trailing byte, which is removed before
///   `deserialize` is called.
/// * `combine` folds a received item into the receiver's output.
///
/// # Panics
///
/// Panics if there is no current runtime or if the next sequence number does
/// not fit in an `ExchangeId`.
pub fn new_exchange_operators<TI, TO, TE, IF, PL, CL, D>(
    location: OperatorLocation,
    init: IF,
    partition: PL,
    deserialize: D,
    combine: CL,
) -> Option<(ExchangeSender<TI, TE, PL>, ExchangeReceiver<IF, TE, CL>)>
where
    TO: Clone,
    TE: Send + 'static + Clone,
    IF: Fn() -> TO + 'static,
    PL: FnMut(TI, &mut Vec<Mailbox<TE>>) + 'static,
    D: Fn(AlignedVec) -> TE + Send + Sync + 'static,
    CL: Fn(&mut TO, TE) + 'static,
{
    // With a single worker there is nobody to exchange data with.
    if Runtime::num_workers() == 1 {
        return None;
    }

    let runtime = Runtime::runtime().unwrap();
    let worker_index = Runtime::worker_index();
    let exchange_id = runtime.sequence_next().try_into().unwrap();

    // Shared by both halves: the sender stamps it, the receiver reads it.
    let send_time = Arc::new(AtomicU64::new(0));

    let exchange = Exchange::with_runtime(
        &runtime,
        exchange_id,
        Box::new(move |mut vec| {
            // The last byte is the flush flag; the rest is the payload.
            let flag = vec.pop().unwrap();
            let flush = if flag == 0 {
                false
            } else if flag == 1 {
                true
            } else {
                unreachable!()
            };
            (deserialize(vec), flush)
        }),
    );

    let sender = ExchangeSender::new(
        worker_index,
        location,
        exchange.clone(),
        send_time.clone(),
        partition,
    );
    let receiver =
        ExchangeReceiver::new(worker_index, location, exchange, init, send_time, combine);
    Some((sender, receiver))
}
// Tests for the exchange primitive and for the sender/receiver operator pair.
#[cfg(test)]
mod tests {
use super::Exchange;
use crate::{
Circuit, RootCircuit,
circuit::{
Runtime,
runtime::{WorkerLocation, WorkerLocations},
schedule::{DynamicScheduler, Scheduler},
},
operator::{
Generator,
communication::{Mailbox, new_exchange_operators},
},
storage::file::{to_bytes, to_bytes_dyn},
trace::aligned_deserialize,
};
use std::{iter::repeat, thread::yield_now};
// Number of rounds per test; smaller under Miri.
const ROUNDS: usize = if cfg!(miri) { 128 } else { 2048 };
// Uses `Exchange` directly: in every round each of 16 workers sends the round
// number to all workers, then checks that it received it from all of them.
#[test]
#[cfg_attr(miri, ignore)]
fn test_exchange() {
const WORKERS: usize = 16;
let hruntime = Runtime::run(WORKERS, |_parker| {
let exchange = Exchange::with_runtime(
&Runtime::runtime().unwrap(),
0,
Box::new(|data| aligned_deserialize(&data[..])),
);
for round in 0..ROUNDS {
let output_data = vec![round; WORKERS];
// Retry until the exchange accepts this worker's data.
loop {
if exchange.try_send_all_with_serializer(
Runtime::worker_index(),
repeat(round),
|round| to_bytes(&round).unwrap(),
) {
break;
}
yield_now();
}
let mut input_data = Vec::with_capacity(WORKERS);
// Retry until data from every worker has arrived.
loop {
if exchange.try_receive_all(Runtime::worker_index(), |x| input_data.push(x)) {
break;
}
yield_now();
}
assert_eq!(input_data, output_data);
}
})
.expect("failed to start runtime");
hruntime.join().unwrap();
}
// Runs the operator test with the dynamic scheduler.
#[test]
#[cfg_attr(miri, ignore)]
fn test_exchange_operators_dynamic() {
test_exchange_operators::<DynamicScheduler>();
}
// Builds a circuit in which each worker generates 0, 1, 2, ... and sends each
// value to all workers through an exchange; every worker must then observe a
// vector holding the current round number once per worker.
fn test_exchange_operators<S>()
where
S: Scheduler + 'static,
{
fn do_test<S>(workers: usize)
where
S: Scheduler + 'static,
{
let hruntime = Runtime::run(workers, move |_parker| {
let circuit = RootCircuit::build_with_scheduler::<_, _, S>(move |circuit| {
let mut n: usize = 0;
let source = circuit.add_source(Generator::new(move || {
let result = n;
n += 1;
result
}));
let (sender, receiver) = new_exchange_operators(
None,
Vec::new,
// One mailbox per worker: plain for local workers, serialized for remote.
move |n, vals| {
for location in WorkerLocations::new() {
match location {
WorkerLocation::Local => vals.push(Mailbox::Plain(n)),
WorkerLocation::Remote => {
vals.push(Mailbox::Tx(to_bytes_dyn(&n).unwrap()))
}
}
}
},
|data| aligned_deserialize(&data[..]),
|v: &mut Vec<usize>, n| v.push(n),
)
.unwrap();
let mut round = 0;
circuit
.add_exchange(sender, receiver, &source)
.inspect(move |v| {
assert_eq!(&vec![round; workers], v);
round += 1;
});
Ok(())
})
.unwrap()
.0;
// The range starts at 1, so this runs ROUNDS - 1 transactions.
for _ in 1..ROUNDS {
circuit.transaction().unwrap();
}
})
.expect("failed to start runtime");
hruntime.join().unwrap();
}
do_test::<S>(2);
do_test::<S>(16);
do_test::<S>(32);
}
}