use crate::{
NumEntries, WeakRuntime,
circuit::{
GlobalNodeId, Host, LocalStoreMarker, OwnershipPreference, Runtime, Scope,
metadata::{
BatchSizeStats, EXCHANGE_DESERIALIZATION_TIME_SECONDS, EXCHANGE_DESERIALIZED_BYTES,
EXCHANGE_WAIT_TIME_SECONDS, INPUT_BATCHES_STATS, MetaItem, OUTPUT_BATCHES_STATS,
OperatorLocation, OperatorMeta,
},
operator_traits::{Operator, SinkOperator, SourceOperator},
runtime::{WorkerLocation, WorkerLocations},
tokio::TOKIO,
},
circuit_cache_key,
storage::file::format::FixedLen,
};
use binrw::{BinRead, BinWrite};
use crossbeam_utils::CachePadded;
use feldera_samply::Span;
use feldera_storage::fbuf::FBuf;
use itertools::Itertools;
use rkyv::AlignedVec;
use size_of::HumanBytes;
use std::{
borrow::Cow,
collections::HashMap,
fmt::Debug,
io::{Cursor, ErrorKind, IoSlice},
iter::zip,
marker::PhantomData,
mem::MaybeUninit,
net::SocketAddr,
ops::Range,
pin::Pin,
sync::{
Arc, Mutex, MutexGuard, OnceLock, RwLock,
atomic::{AtomicIsize, AtomicPtr, AtomicU64, AtomicUsize, Ordering},
},
time::{Duration, Instant, SystemTime},
};
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream},
sync::{Notify, OnceCell, mpsc::error::SendError},
time::sleep,
};
use tokio_util::sync::{CancellationToken, DropGuard};
use tracing::{info, warn};
use typedmap::TypedMapKey;
/// Returns the current wall-clock time in microseconds since the Unix epoch.
///
/// Saturates at `u64::MAX` if the microsecond count does not fit in a `u64`,
/// and yields 0 if the system clock reads earlier than the epoch.
fn current_time_usecs() -> u64 {
    match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
        Ok(since_epoch) => u64::try_from(since_epoch.as_micros()).unwrap_or(u64::MAX),
        Err(_) => 0,
    }
}
// Declares the `ExchangeCacheId<T>` key, which maps an `ExchangeId` to the
// shared `Arc<Exchange<T>>`; `Exchange::with_runtime` uses it to look up or
// create the exchange in the runtime's local store.
circuit_cache_key!(local ExchangeCacheId<T>(ExchangeId => Arc<Exchange<T>>));
/// Fixed-size header that precedes every payload on an exchange connection.
///
/// Encoded little-endian through `binrw`.
#[binrw::binrw]
#[brw(little)]
struct ExchangeHeader {
// Exchange that the payload belongs to.
exchange_id: ExchangeId,
// Index of the worker that sent the payload.
sender: u32,
// Index of the worker that the payload is addressed to.
receiver: u32,
// Length of the payload in bytes, not counting the padding that rounds it
// up to a multiple of 16 on the wire.
#[brw(align_after(16))]
data_len: u32,
}
// Encoded size of the header: four `u32` fields, 16 bytes in total.
impl FixedLen for ExchangeHeader {
const LEN: usize = 16;
}
impl ExchangeHeader {
    /// Encodes this header into its fixed-size little-endian wire form.
    ///
    /// Panics if encoding fails or does not produce exactly `Self::LEN` bytes.
    fn to_bytes(&self) -> [u8; Self::LEN] {
        let mut writer = Cursor::new([0u8; Self::LEN]);
        self.write_le(&mut writer).unwrap();
        let written = writer.position();
        assert_eq!(written, Self::LEN as u64);
        writer.into_inner()
    }

    /// Decodes a header from its fixed-size wire form.
    ///
    /// Panics if decoding fails or does not consume exactly `Self::LEN` bytes.
    fn from_bytes(bytes: &[u8; Self::LEN]) -> Self {
        let mut reader = Cursor::new(bytes);
        let header = Self::read_le(&mut reader).unwrap();
        let consumed = reader.position();
        assert_eq!(consumed, Self::LEN as u64);
        header
    }

    /// Reads the next header from `stream`.
    ///
    /// Returns `Ok(None)` if the very first read yields no bytes (the stream
    /// ended on a header boundary). Once at least one byte has arrived, the
    /// rest of the header is required, so a premature end is an error.
    async fn read<S>(stream: &mut S) -> std::io::Result<Option<Self>>
    where
        S: AsyncRead + Unpin,
    {
        let mut raw = [0; Self::LEN];
        let first = stream.read(&mut raw).await?;
        if first == 0 {
            return Ok(None);
        }
        stream.read_exact(&mut raw[first..]).await?;
        Ok(Some(ExchangeHeader::from_bytes(&raw)))
    }
}
/// Total payload size of `message` in bytes, in the signed representation
/// used by the [`ByteBound`] counter.
///
/// The sender and the receiver must both account for a message with this same
/// value, otherwise the budget would drift.
fn exchange_message_bytes(message: &ExchangeMessage) -> isize {
    message
        .data
        .iter()
        .map(|buf| buf.len())
        .sum::<usize>()
        .try_into()
        .unwrap()
}

/// Sending half of a channel whose queued payload is tracked against a byte
/// budget.
struct ByteBoundedSender {
    tx: tokio::sync::mpsc::UnboundedSender<ExchangeMessage>,
    bound: Arc<ByteBound>,
}

impl ByteBoundedSender {
    /// Queues `message` and charges its payload bytes against the budget.
    ///
    /// The message is always queued (the underlying channel is unbounded).
    /// Returns `Ok(Some(bound))` if the budget is overdrawn after this send;
    /// the caller can then await [`ByteBound::wait`] before sending more.
    ///
    /// # Errors
    ///
    /// Returns the message back if the receiving half has been dropped.
    pub fn send(
        &self,
        message: ExchangeMessage,
    ) -> Result<Option<Arc<ByteBound>>, SendError<ExchangeMessage>> {
        // Count payload bytes, not the number of buffers in the message, so
        // that the channel limit really is a byte limit.
        let len = exchange_message_bytes(&message);
        self.tx.send(message)?;
        Ok(self.bound.reserve(len))
    }
}

/// Receiving half of a byte-bounded channel.
struct ByteBoundedReceiver {
    rx: tokio::sync::mpsc::UnboundedReceiver<ExchangeMessage>,
    bound: Arc<ByteBound>,
}

impl ByteBoundedReceiver {
    /// Receives the next message and returns its payload bytes to the budget.
    ///
    /// Wakes waiting senders when this refund takes the budget from negative
    /// back to non-negative. Returns `None` once the channel is closed and
    /// drained.
    pub async fn recv(&mut self) -> Option<ExchangeMessage> {
        let message = self.rx.recv().await?;
        // Must mirror the amount charged by `ByteBoundedSender::send`.
        let len = exchange_message_bytes(&message);
        let before = self.bound.remaining.fetch_add(len, Ordering::AcqRel);
        let after = before + len;
        if before < 0 && after >= 0 {
            self.bound.notify.notify_waiters();
        }
        Some(message)
    }
}
pub struct ByteBound {
remaining: AtomicIsize,
notify: Notify,
}
impl ByteBound {
fn reserve(self: &Arc<Self>, len: isize) -> Option<Arc<Self>> {
let remaining = self.remaining.fetch_sub(len, Ordering::AcqRel) - len;
(remaining < 0).then(|| self.clone())
}
pub async fn wait(&self) {
while let notified = self.notify.notified()
&& self.remaining.load(Ordering::Acquire) < 0
{
notified.await;
}
}
}
/// Creates a sender/receiver pair that share one budget, initially `limit`.
///
/// A `limit` too large for `isize` is clamped to `isize::MAX`.
fn byte_bounded_channel(limit: usize) -> (ByteBoundedSender, ByteBoundedReceiver) {
    let budget = isize::try_from(limit).unwrap_or(isize::MAX);
    let bound = Arc::new(ByteBound {
        remaining: AtomicIsize::new(budget),
        notify: Notify::new(),
    });
    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
    let sender = ByteBoundedSender {
        tx,
        bound: bound.clone(),
    };
    let receiver = ByteBoundedReceiver { rx, bound };
    (sender, receiver)
}
/// One batch of payloads queued for transmission to a remote host.
struct ExchangeMessage {
// Name shown in the "send" profiling span.
global_node_id: Arc<String>,
// Exchange the payloads belong to; copied into every header.
exchange_id: ExchangeId,
// Index of the sending worker; copied into every header.
sender: usize,
// Payloads, paired in order with the remote host's worker range when the
// message is written to the connection.
data: Vec<FBuf>,
}
pub struct ExchangeClient {
tx: ByteBoundedSender,
}
impl ExchangeClient {
async fn new(remote_address: SocketAddr, remote_workers: &Range<usize>) -> Self {
let (tx, rx) = byte_bounded_channel(10_000_000);
TOKIO.spawn(Self::run(remote_address, remote_workers.clone(), rx));
Self { tx }
}
async fn run(
remote_address: SocketAddr,
remote_workers: Range<usize>,
mut rx: ByteBoundedReceiver,
) {
let mut connection = loop {
match TcpStream::connect(remote_address).await {
Ok(stream) => break stream,
Err(error) => {
info!("connection to {remote_address} failed ({error}), waiting to retry")
}
}
sleep(std::time::Duration::from_millis(1000)).await;
};
connection.set_nodelay(true).unwrap();
connection.set_zero_linger().unwrap();
while let Some(message) = rx.recv().await {
let n = remote_workers.len();
let mut headers = Vec::with_capacity(n);
for (data, receiver) in zip(&message.data, remote_workers.clone()) {
headers.push(
ExchangeHeader {
exchange_id: message.exchange_id,
sender: message.sender as u32,
receiver: receiver as u32,
data_len: data.len().try_into().unwrap(),
}
.to_bytes(),
);
}
let zeros = [0; 16];
let mut slices = Vec::with_capacity(n * 3);
let mut header = headers.iter();
for data in &message.data {
slices.push(IoSlice::new(header.next().unwrap()));
if !data.is_empty() {
slices.push(IoSlice::new(data.as_slice()));
}
let pad = &zeros[..data.len().next_multiple_of(16) - data.len()];
if !pad.is_empty() {
slices.push(IoSlice::new(pad));
}
}
let size = slices.iter().map(|slice| slice.len()).sum::<usize>();
let mut bufs = slices.as_mut_slice();
let _span = Span::new("send")
.with_category("Exchange")
.with_tooltip(|| {
format!(
"{} send {}",
&message.global_node_id,
HumanBytes::from(size),
)
});
while !bufs.is_empty() {
let n = connection
.write_vectored(bufs)
.await
.expect("lost connection to remote host");
IoSlice::advance_slices(&mut bufs, n);
}
}
}
pub fn send(
&self,
global_node_id: Arc<String>,
exchange_id: ExchangeId,
sender: usize,
data: Vec<FBuf>,
) -> Option<Arc<ByteBound>> {
self.tx
.send(ExchangeMessage {
global_node_id,
exchange_id,
sender,
data,
})
.expect("remote exchange failed")
}
}
/// Identifier of an exchange; carried in every `ExchangeHeader` and used as
/// the key of the `ExchangeDirectory`.
pub type ExchangeId = u32;
/// Type-erased destination for payloads that arrive from a remote host.
pub trait ExchangeDelivery {
/// Name used to label the "receive" profiling span.
fn name(&self) -> &str;
/// Hands over the payloads in `data` that were sent by worker `sender`.
/// The returned future completes once they have been delivered.
fn received<'a>(
&'a self,
sender: usize,
data: Vec<AlignedVec>,
) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
}
/// Shared map from exchange id to the object that accepts data received for
/// that exchange. Cloning yields another handle to the same map.
#[derive(Clone, Default)]
pub struct ExchangeDirectory(
    Arc<RwLock<HashMap<ExchangeId, Arc<dyn ExchangeDelivery + Send + Sync>>>>,
);

impl ExchangeDirectory {
    /// Returns the directory stored in `runtime`'s local store, creating an
    /// empty one on first use.
    pub fn for_runtime(runtime: &Runtime) -> Self {
        runtime
            .local_store()
            .entry(DirectoryId)
            .or_insert_with(Self::default)
            .clone()
    }

    /// Looks up the delivery target registered for `exchange_id`, if any.
    pub fn get(&self, exchange_id: ExchangeId) -> Option<Arc<dyn ExchangeDelivery + Send + Sync>> {
        let map = self.0.read().unwrap();
        map.get(&exchange_id).map(Arc::clone)
    }

    /// Registers `exchange` under `exchange_id`.
    ///
    /// Panics if the id is already registered.
    pub fn insert(
        &self,
        exchange_id: ExchangeId,
        exchange: Arc<dyn ExchangeDelivery + Send + Sync>,
    ) {
        let mut map = self.0.write().unwrap();
        if map.contains_key(&exchange_id) {
            panic!();
        }
        map.insert(exchange_id, exchange);
    }
}
/// Serves one inbound connection from a remote host, handing received
/// payloads to the exchanges registered in `directory`.
struct ExchangeServer {
    /// Local workers that each inbound message carries one payload for.
    receivers: Range<usize>,
    directory: ExchangeDirectory,
    stream: TcpStream,
}

impl ExchangeServer {
    /// Reads messages from the connection until the peer closes it on a
    /// message boundary.
    ///
    /// A message is one (header, padded payload) pair per local receiver. The
    /// exchange id and sender are taken from the first header of the message.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if reading fails, or `UnexpectedEof` if the
    /// connection ends in the middle of a message.
    ///
    /// # Panics
    ///
    /// Panics if no exchange is registered for a received exchange id.
    async fn serve(mut self) -> std::io::Result<()> {
        while let Some(header) = ExchangeHeader::read(&mut self.stream).await? {
            let start = Instant::now();
            let exchange_id = header.exchange_id;
            let sender = header.sender as usize;
            let mut bytes = self.receivers.len() * ExchangeHeader::LEN;
            // The first header was already consumed by the loop condition.
            let mut header = Some(header);
            let mut data = Vec::with_capacity(self.receivers.len());
            for _ in self.receivers.clone() {
                let header = if let Some(header) = header.take() {
                    header
                } else {
                    ExchangeHeader::read(&mut self.stream)
                        .await?
                        .ok_or_else(|| std::io::Error::from(ErrorKind::UnexpectedEof))?
                };
                let len = header.data_len as usize;
                // Payloads are padded on the wire to a multiple of 16 bytes.
                let padded_len = len.next_multiple_of(16);
                let mut payload = AlignedVec::with_capacity(padded_len);
                let pointer = payload.as_mut_ptr() as *mut MaybeUninit<u8>;
                // SAFETY: `payload` was allocated with capacity for
                // `padded_len` bytes, and `MaybeUninit<u8>` places no
                // requirement on the contents of that memory.
                let mut slice = unsafe { std::slice::from_raw_parts_mut(pointer, padded_len) };
                while !slice.is_empty() {
                    // `read_buf` returns 0 at end of stream. Without this
                    // check, a peer that disconnects mid-payload would make
                    // this loop spin forever.
                    if self.stream.read_buf(&mut slice).await? == 0 {
                        return Err(std::io::Error::from(ErrorKind::UnexpectedEof));
                    }
                }
                // SAFETY: the loop above filled all `padded_len >= len` bytes.
                unsafe { payload.set_len(len) };
                data.push(payload);
                bytes += padded_len;
            }
            let receiver = self
                .directory
                .get(exchange_id)
                .expect("should have exchange for received data");
            Span::new("receive")
                .with_start(start)
                .with_category("Exchange")
                .with_tooltip(|| {
                    format!(
                        "{} receive {} from worker {sender}",
                        receiver.name(),
                        HumanBytes::from(bytes),
                    )
                })
                .record();
            receiver.received(sender, data).await;
        }
        Ok(())
    }
}
/// Per-runtime set of lazily created connections to the other hosts, plus
/// the lazily started listener for inbound connections.
pub struct ExchangeClients {
    runtime: WeakRuntime,
    local_workers: Range<usize>,
    listener: OnceCell<Option<ExchangeListener>>,
    clients: Vec<(Host, OnceCell<ExchangeClient>)>,
}

impl ExchangeClients {
    /// Returns the instance stored in `runtime`'s local store, creating it on
    /// first use.
    pub fn for_runtime(runtime: &Runtime) -> Arc<ExchangeClients> {
        runtime
            .local_store()
            .entry(ClientsId)
            .or_insert_with(|| Arc::new(Self::new(runtime)))
            .clone()
    }

    /// Builds an instance with one unconnected client slot per other host in
    /// `runtime`'s layout.
    fn new(runtime: &Runtime) -> ExchangeClients {
        let layout = runtime.layout();
        let clients = layout
            .other_hosts()
            .map(|host| (host.clone(), OnceCell::new()))
            .collect();
        Self {
            runtime: runtime.downgrade(),
            local_workers: layout.local_workers(),
            listener: OnceCell::new(),
            clients,
        }
    }

    /// Returns the client for the host that owns `worker`, creating it on
    /// first use.
    ///
    /// The first call also starts the local listener, provided the runtime is
    /// still alive and its layout has a local address.
    ///
    /// Panics if no other host owns `worker`.
    pub async fn connect(&self, worker: usize) -> &ExchangeClient {
        self.listener
            .get_or_init(|| async {
                let Some(runtime) = self.runtime.upgrade() else {
                    return None;
                };
                let Some(local_address) = runtime.layout().local_address() else {
                    return None;
                };
                let directory = ExchangeDirectory::for_runtime(&runtime);
                Some(ExchangeListener::new(
                    local_address,
                    runtime.take_exchange_listener(),
                    directory,
                    self.local_workers.clone(),
                ))
            })
            .await;

        let (host, cell) = self
            .clients
            .iter()
            .find(|(candidate, _)| candidate.workers.contains(&worker))
            .unwrap();
        cell.get_or_init(|| ExchangeClient::new(host.address, &host.workers))
            .await
    }
}
/// Heap-allocated holder for an optional callback.
struct CallbackInner {
    cb: Option<Box<dyn Fn() + Send + Sync>>,
}

impl CallbackInner {
    /// Holder with no callback installed.
    fn empty() -> Self {
        Self { cb: None }
    }

    /// Holder that wraps `cb`.
    fn new<F>(cb: F) -> Self
    where
        F: Fn() + Send + Sync + 'static,
    {
        let cb = Box::new(cb) as Box<dyn Fn() + Send + Sync>;
        Self { cb: Some(cb) }
    }
}

/// Callback slot that can be replaced through a shared reference.
///
/// The slot always points to a live `CallbackInner` obtained from
/// `Box::into_raw`.
///
/// NOTE(review): `call` dereferences the pointer without any guard against a
/// concurrent `set_callback` freeing it; this presumably relies on callbacks
/// being registered before they can be invoked. Confirm with callers.
struct Callback(AtomicPtr<CallbackInner>);

impl Callback {
    /// Creates a slot with no callback installed.
    fn empty() -> Self {
        Self(AtomicPtr::new(Box::into_raw(Box::new(
            CallbackInner::empty(),
        ))))
    }

    /// Installs `cb`, freeing whatever was installed before.
    fn set_callback(&self, cb: impl Fn() + Send + Sync + 'static) {
        let old_callback = self.0.swap(
            Box::into_raw(Box::new(CallbackInner::new(cb))),
            Ordering::AcqRel,
        );
        // SAFETY: every pointer stored in the slot comes from
        // `Box::into_raw`, and the swap removed this one from the slot.
        let old_callback = unsafe { Box::from_raw(old_callback) };
        drop(old_callback);
    }

    /// Invokes the installed callback, if there is one.
    fn call(&self) {
        // SAFETY: the slot always holds a pointer from `Box::into_raw` that
        // has not been freed (see the review note on the type).
        if let Some(cb) = &unsafe { &*self.0.load(Ordering::Acquire) }.cb {
            cb()
        }
    }
}

impl Drop for Callback {
    /// Frees the currently installed `CallbackInner`; without this, the last
    /// callback and everything it captures would leak.
    fn drop(&mut self) {
        // SAFETY: the pointer comes from `Box::into_raw` in `empty` or
        // `set_callback`, and `&mut self` rules out any other user.
        drop(unsafe { Box::from_raw(*self.0.get_mut()) });
    }
}
/// Contents of one mailbox slot.
#[derive(Clone, Debug)]
pub enum Mailbox<T> {
    /// Serialized value on its way to a remote worker.
    Tx(FBuf),
    /// Serialized value received from a remote worker.
    Rx(AlignedVec),
    /// In-memory value passed between workers on the same host.
    Plain(T),
}

impl<T> Mailbox<T> {
    /// Returns the in-memory value, or `None` for serialized contents.
    pub fn into_plain(self) -> Option<T> {
        if let Self::Plain(item) = self {
            Some(item)
        } else {
            None
        }
    }

    /// Returns the outgoing serialized bytes, or `None` for other variants.
    pub fn into_tx(self) -> Option<FBuf> {
        if let Self::Tx(bytes) = self {
            Some(bytes)
        } else {
            None
        }
    }

    /// Produces the value, running `deserialize` on received bytes.
    ///
    /// Must not be called on a `Tx` mailbox.
    fn deserialize<D>(self, deserialize: D) -> T
    where
        D: Fn(AlignedVec) -> T,
    {
        match self {
            Self::Rx(bytes) => deserialize(bytes),
            Self::Plain(item) => item,
            Self::Tx(_) => unreachable!(),
        }
    }
}
/// State of one exchange, shared by all workers of the local runtime.
///
/// Each sender places one item per receiver into a mailbox; a receiver
/// collects once every sender has delivered to it.
pub(crate) struct Exchange<T> {
// Identifier sent in message headers and used as the directory key.
exchange_id: ExchangeId,
// Name of the receiving operator; set by `ExchangeReceiver::init` and
// returned by `ExchangeDelivery::name`.
receiver_global_node_id: OnceLock<Arc<String>>,
// Number of workers in the layout.
npeers: usize,
// Range of worker indexes that run on this host.
local_workers: Range<usize>,
// Per receiver: how many mailboxes are filled. A receiver may collect when
// its counter equals `npeers`.
receiver_counters: Vec<AtomicUsize>,
// Per receiver: callback invoked by `deliver` once the counter reaches
// `npeers`.
receiver_callbacks: Vec<Callback>,
// Per receiver: wakes a pending `receive_all`.
receiver_notifies: Vec<Notify>,
// Per sender: how many local receivers have emptied that sender's
// mailboxes. A sender may deliver again when its counter equals
// `local_workers.len()`.
sender_counters: Vec<CachePadded<AtomicUsize>>,
// Per sender: callback invoked by `receive_all` once the counter reaches
// `local_workers.len()`.
sender_callbacks: Vec<Callback>,
// Per sender: wakes a pending `wait_for_ready_to_send`.
sender_notifies: Vec<Notify>,
// Connections to the other hosts.
clients: Arc<ExchangeClients>,
// One slot per (sender, receiver) pair, indexed `sender * npeers + receiver`.
mailboxes: Vec<Mutex<Option<Mailbox<T>>>>,
// Reported as metadata by `ExchangeReceiver`; nothing in this file updates
// these two counters.
deserialization_usecs: AtomicU64,
deserialized_bytes: AtomicUsize,
}
/// Owns the task that accepts inbound exchange connections; dropping it
/// cancels that task.
// The guard is held only for its drop effect and is never read.
#[allow(dead_code)]
struct ExchangeListener(DropGuard);

impl ExchangeListener {
    /// Spawns the accept loop.
    ///
    /// Uses `exchange_listener` if one is supplied and binds `local_address`
    /// otherwise. Each accepted connection gets its own `ExchangeServer`
    /// task serving the local `receivers`.
    fn new(
        local_address: SocketAddr,
        exchange_listener: Option<std::net::TcpListener>,
        directory: ExchangeDirectory,
        receivers: Range<usize>,
    ) -> Self {
        let token = CancellationToken::new();
        let guard = token.clone().drop_guard();
        TOKIO.spawn(async move {
            info!("listening on {local_address}");
            let listener = if let Some(std_listener) = exchange_listener {
                std_listener
                    .set_nonblocking(true)
                    .expect("should be able to set nonblocking mode");
                TcpListener::from_std(std_listener).unwrap()
            } else {
                TcpListener::bind(local_address).await.unwrap()
            };
            loop {
                let accepted = tokio::select! {
                    accepted = listener.accept() => accepted,
                    _ = token.cancelled() => break,
                };
                match accepted {
                    Err(error) => warn!("Error accepting connection: {error}"),
                    Ok((stream, _peer)) => {
                        let server = ExchangeServer {
                            receivers: receivers.clone(),
                            directory: directory.clone(),
                            stream,
                        };
                        tokio::spawn(server.serve());
                    }
                }
            }
        });
        Self(guard)
    }
}
impl<T> Exchange<T>
where
T: Clone + Debug + Send + 'static,
{
fn new(
runtime: &Runtime,
clients: Arc<ExchangeClients>,
exchange_id: ExchangeId,
directory: &ExchangeDirectory,
) -> Arc<Self> {
let npeers = Runtime::num_workers();
let mailboxes: Vec<Mutex<Option<Mailbox<T>>>> =
(0..npeers * npeers).map(|_| Default::default()).collect();
let layout = runtime.layout();
let npeers = layout.n_workers();
let exchange = Arc::new(Self {
exchange_id,
receiver_global_node_id: Default::default(),
npeers,
local_workers: layout.local_workers(),
clients,
receiver_counters: (0..npeers).map(|_| AtomicUsize::new(0)).collect(),
receiver_callbacks: (0..npeers).map(|_| Callback::empty()).collect(),
receiver_notifies: (0..npeers).map(|_| Notify::new()).collect(),
sender_counters: (0..npeers)
.map(|_| CachePadded::new(AtomicUsize::new(layout.local_workers().len())))
.collect(),
sender_notifies: (0..npeers).map(|_| Notify::new()).collect(),
sender_callbacks: (0..npeers).map(|_| Callback::empty()).collect(),
mailboxes,
deserialization_usecs: AtomicU64::new(0),
deserialized_bytes: AtomicUsize::new(0),
});
directory.insert(exchange_id, exchange.clone());
exchange
}
/// Returns the identifier this exchange was created with.
pub fn exchange_id(&self) -> ExchangeId {
self.exchange_id
}
/// Maps a (sender, receiver) pair to its position in `mailboxes`; the grid
/// is laid out with one row of `npeers` slots per sender.
fn mailbox_index(&self, sender: usize, receiver: usize) -> usize {
    debug_assert!(sender < self.npeers);
    debug_assert!(receiver < self.npeers);
    let row_start = sender * self.npeers;
    row_start + receiver
}
/// Reports whether every sender has delivered to `receiver`.
fn ready_to_receive(&self, receiver: usize) -> bool {
    debug_assert!(receiver < self.npeers);
    let delivered = self.receiver_counters[receiver].load(Ordering::Acquire);
    delivered == self.npeers
}
/// Installs `cb` as the callback that `deliver` invokes once all senders
/// have delivered to `receiver`, replacing any earlier callback.
pub(crate) fn register_receiver_callback<F>(&self, receiver: usize, cb: F)
where
    F: Fn() + Send + Sync + 'static,
{
    debug_assert!(receiver < self.npeers);
    let slot = &self.receiver_callbacks[receiver];
    slot.set_callback(cb);
}
/// Locks and returns the mailbox for the (sender, receiver) pair.
///
/// Panics if the mailbox mutex is poisoned.
fn mailbox(&self, sender: usize, receiver: usize) -> MutexGuard<'_, Option<Mailbox<T>>> {
    let index = self.mailbox_index(sender, receiver);
    self.mailboxes[index].lock().unwrap()
}
/// Returns the exchange with id `exchange_id` from `runtime`'s local store,
/// creating and registering it if it does not exist yet.
pub(crate) fn with_runtime(runtime: &Runtime, exchange_id: ExchangeId) -> Arc<Self> {
    let directory = ExchangeDirectory::for_runtime(runtime);
    let clients = ExchangeClients::for_runtime(runtime);
    let key = ExchangeCacheId::new(exchange_id);
    let create = || Exchange::new(runtime, clients, exchange_id, &directory);
    runtime
        .local_store()
        .entry(key)
        .or_insert_with(create)
        .value()
        .clone()
}
/// Installs `cb` as the callback that `receive_all` invokes once the
/// mailboxes of `sender` have been emptied, replacing any earlier callback.
pub(crate) fn register_sender_callback<F>(&self, sender: usize, cb: F)
where
    F: Fn() + Send + Sync + 'static,
{
    debug_assert!(sender < self.npeers);
    let slot = &self.sender_callbacks[sender];
    slot.set_callback(cb);
}
/// Reports whether all local receivers have emptied the mailboxes that
/// `sender` filled last time.
pub fn ready_to_send(&self, sender: usize) -> bool {
    let drained = self.sender_counters[sender].load(Ordering::Acquire);
    drained == self.local_workers.len()
}
async fn wait_for_ready_to_send(&self, sender: usize) {
fn ready_to_send<T>(this: &Exchange<T>, sender: usize) -> bool {
this.sender_counters[sender]
.compare_exchange(
this.local_workers.len(),
0,
Ordering::AcqRel,
Ordering::Acquire,
)
.is_ok()
}
if !ready_to_send(self, sender) {
loop {
let notify = self.sender_notifies[sender].notified();
if ready_to_send(self, sender) {
break;
}
notify.await;
}
}
}
/// Sends one item from `data` to each worker, serializing with `serialize`
/// only the items destined for remote workers.
///
/// Items are paired in order with the locations yielded by
/// `WorkerLocations::new()`.
pub(crate) async fn send_all_with_serializer<F>(
    self: &Arc<Self>,
    global_node_id: &Arc<String>,
    data: impl Iterator<Item = T>,
    mut serialize: F,
) where
    F: FnMut(T) -> FBuf + Send + Sync,
{
    let mailboxes = data
        .zip(WorkerLocations::new())
        .map(|(item, location)| match location {
            WorkerLocation::Remote => Mailbox::Tx(serialize(item)),
            WorkerLocation::Local => Mailbox::Plain(item),
        });
    self.send_all(global_node_id, mailboxes).await
}
/// Places `item` into the (sender, receiver) mailbox and, once all senders
/// have delivered to `receiver`, runs its callback and wakes its waiter.
///
/// Panics if the mailbox is still occupied.
fn deliver(&self, sender: usize, receiver: usize, item: Mailbox<T>) {
    // The guard is intentionally kept until the end of the function.
    let mut slot = self.mailbox(sender, receiver);
    assert!(slot.is_none());
    *slot = Some(item);

    let delivered_before = self.receiver_counters[receiver].fetch_add(1, Ordering::AcqRel);
    let all_delivered = delivered_before >= self.npeers - 1;
    if all_delivered {
        self.receiver_callbacks[receiver].call();
        self.receiver_notifies[receiver].notify_waiters();
    }
}
/// Sends one mailbox from `data` to every worker on behalf of the calling
/// worker.
///
/// Waits until the caller may send, then walks the hosts in layout order:
/// items for local workers go straight into mailboxes, and items for a
/// remote host are queued as one message on that host's client.
///
/// Panics if `data` yields fewer items than there are workers, or if an
/// item for a remote worker is not `Mailbox::Tx`.
pub(crate) async fn send_all(
    self: &Arc<Self>,
    global_node_id: &Arc<String>,
    mut data: impl Iterator<Item = Mailbox<T>>,
) {
    let sender = Runtime::worker_index();
    self.wait_for_ready_to_send(sender).await;

    let runtime = Runtime::runtime().unwrap();
    let layout = runtime.layout();
    let worker_locations = WorkerLocations::for_layout(layout);
    for receivers in layout.all_hosts() {
        match worker_locations[receivers.start] {
            WorkerLocation::Remote => {
                let mut items = Vec::with_capacity(receivers.len());
                for _ in receivers.clone() {
                    let serialized = data
                        .next()
                        .expect("data should include one item per peer")
                        .into_tx()
                        .expect("remote mailboxes should always be serialized");
                    items.push(serialized);
                }
                let client = self.clients.connect(receivers.start).await;
                // NOTE(review): the returned `ByteBound` is discarded, so
                // this path never waits on the channel budget. Confirm that
                // this is intended.
                let _ = client.send(global_node_id.clone(), self.exchange_id, sender, items);
            }
            WorkerLocation::Local => {
                for receiver in receivers.clone() {
                    let item = data.next().expect("data should include one item per peer");
                    self.deliver(sender, receiver, item);
                }
            }
        }
    }
}
/// Waits until every sender has delivered to the calling worker, then
/// empties its mailboxes and returns the values in sender order.
///
/// Received bytes are decoded with `deserialize`. After each mailbox is
/// emptied the sender's counter is incremented; once it reaches the number
/// of local workers, the sender's callback runs and its waiter is woken.
pub(crate) async fn receive_all<D>(&self, deserialize: D) -> Vec<T>
where
    D: Fn(AlignedVec) -> T,
{
    let receiver = Runtime::worker_index();
    let counter = &self.receiver_counters[receiver];
    // Succeeds only when all `npeers` mailboxes are filled; resets to 0.
    let try_claim = || {
        counter
            .compare_exchange(self.npeers, 0, Ordering::AcqRel, Ordering::Acquire)
            .is_ok()
    };
    if !try_claim() {
        loop {
            // Register for the wakeup before re-checking so that a
            // notification sent in between is not missed.
            let notified = self.receiver_notifies[receiver].notified();
            if try_claim() {
                break;
            }
            notified.await;
        }
    }

    let n_local = self.local_workers.len();
    let mut received = Vec::with_capacity(self.npeers);
    for sender in 0..self.npeers {
        let contents = self.mailbox(sender, receiver).take().unwrap();
        received.push(contents.deserialize(&deserialize));
        let drained_before = self.sender_counters[sender].fetch_add(1, Ordering::AcqRel);
        if drained_before + 1 >= n_local {
            self.sender_callbacks[sender].call();
            self.sender_notifies[sender].notify_waiters();
        }
    }
    received
}
}
impl<T> ExchangeDelivery for Exchange<T>
where
    T: Clone + Debug + Send + 'static,
{
    /// Returns the receiving operator's name.
    ///
    /// Panics if `ExchangeReceiver::init` has not set it yet.
    fn name(&self) -> &str {
        let id = self.receiver_global_node_id.get().unwrap();
        id.as_str()
    }

    /// Delivers payloads received from remote worker `sender`, pairing them
    /// in order with the local workers, once that sender's previous
    /// delivery has been consumed.
    fn received<'a>(
        &'a self,
        sender: usize,
        data: Vec<AlignedVec>,
    ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
        Box::pin(async move {
            self.wait_for_ready_to_send(sender).await;
            for (receiver, payload) in self.local_workers.clone().zip(data) {
                self.deliver(sender, receiver, Mailbox::Rx(payload));
            }
        })
    }
}
/// Sink operator that partitions its input and sends one part to every
/// worker through an `Exchange`.
pub struct ExchangeSender<D, T, L>
where
T: Send + 'static + Clone,
{
// Name passed along with outgoing messages; replaced in `init`.
global_node_id: Arc<String>,
location: OperatorLocation,
// Splits an input into one mailbox per worker.
partition: L,
// Scratch buffer filled by `partition` on every evaluation.
outputs: Vec<Mailbox<T>>,
// Each exchanged item is paired with a flag saying whether the sender had
// been flushed.
exchange: Arc<Exchange<(T, bool)>>,
input_batch_stats: BatchSizeStats,
// Set by `flush`, cleared after the next evaluation.
flushed: bool,
// Time at which the most recent send started, shared with the receiver
// for its wait-time measurement.
start_wait_usecs: Arc<AtomicU64>,
phantom: PhantomData<D>,
}
impl<D, T, L> ExchangeSender<D, T, L>
where
    T: Send + 'static + Clone,
{
    /// Creates a sender for `exchange` that splits its inputs with
    /// `partition`.
    ///
    /// The node name is initially derived from the exchange id.
    fn new(
        location: OperatorLocation,
        exchange: Arc<Exchange<(T, bool)>>,
        start_wait_usecs: Arc<AtomicU64>,
        partition: L,
    ) -> Self {
        let global_node_id = Arc::new(format!("ExchangeSender {}", exchange.exchange_id));
        let outputs = Vec::with_capacity(Runtime::num_workers());
        Self {
            global_node_id,
            location,
            partition,
            outputs,
            exchange,
            input_batch_stats: BatchSizeStats::new(),
            flushed: false,
            start_wait_usecs,
            phantom: PhantomData,
        }
    }
}
impl<D, T, L> Operator for ExchangeSender<D, T, L>
where
    D: 'static,
    T: Send + 'static + Clone + Debug,
    L: 'static,
{
    fn name(&self) -> Cow<'static, str> {
        Cow::Borrowed("ExchangeSender")
    }

    /// Renames the operator after its node identifier in the circuit.
    fn init(&mut self, global_id: &GlobalNodeId) {
        let name = format!("ExchangeSender {}", global_id.node_identifier());
        self.global_node_id = Arc::new(name);
    }

    /// Reports the input batch size statistics.
    fn metadata(&self, meta: &mut OperatorMeta) {
        meta.extend(metadata! {
            INPUT_BATCHES_STATS => self.input_batch_stats.metadata(),
        });
    }

    fn location(&self) -> OperatorLocation {
        self.location
    }

    fn clock_start(&mut self, _scope: Scope) {}

    fn clock_end(&mut self, _scope: Scope) {}

    fn fixedpoint(&self, _scope: Scope) -> bool {
        true
    }

    /// Marks the next evaluation's outputs as flushed.
    fn flush(&mut self) {
        self.flushed = true;
    }
}
impl<D, T, L> SinkOperator<D> for ExchangeSender<D, T, L>
where
    D: Clone + Debug + NumEntries + 'static,
    T: Clone + Debug + Send + 'static,
    L: FnMut(D, &mut Vec<Mailbox<T>>) + 'static,
{
    /// Evaluates on a clone of `input`.
    async fn eval(&mut self, input: &D) {
        let owned = input.clone();
        self.eval_owned(owned).await
    }

    /// Partitions `input` and sends one part to every worker, tagging each
    /// part with the current flush flag; the flag is cleared afterwards.
    async fn eval_owned(&mut self, input: D) {
        self.input_batch_stats.add_batch(input.num_entries_deep());
        debug_assert!(self.ready());

        self.outputs.clear();
        (self.partition)(input, &mut self.outputs);

        // Record when the send started, for the receiver's wait-time metric.
        self.start_wait_usecs
            .store(current_time_usecs(), Ordering::Release);

        let flushed = self.flushed;
        let tagged = self.outputs.drain(..).map(|mailbox| match mailbox {
            Mailbox::Plain(item) => Mailbox::Plain((item, flushed)),
            Mailbox::Tx(mut bytes) => {
                // Serialized parts carry the flag as one trailing byte.
                bytes.push(flushed as u8);
                Mailbox::Tx(bytes)
            }
            Mailbox::Rx(_) => unreachable!(),
        });
        self.exchange.send_all(&self.global_node_id, tagged).await;
        self.flushed = false;
    }

    fn input_preference(&self) -> OwnershipPreference {
        OwnershipPreference::PREFER_OWNED
    }
}
/// Source operator that collects what all workers sent to this worker
/// through an `Exchange` and combines it into one output.
pub struct ExchangeReceiver<IF, T, L, D>
where
T: Send + 'static + Clone,
{
// Index of the worker this receiver belongs to.
worker_index: usize,
location: OperatorLocation,
// Creates the empty output value for each evaluation.
init: IF,
// Decodes a value received from a remote worker.
deserialize: D,
// Folds one received value into the output.
combine: L,
exchange: Arc<Exchange<(T, bool)>>,
// Number of received items marked as flushed since the last completed
// flush.
flush_count: usize,
// Set once `flush_count` reaches the number of workers; cleared by `flush`.
flush_complete: bool,
// Start time of the most recent send, written by the paired sender.
start_wait_usecs: Arc<AtomicU64>,
// Accumulated wait time in microseconds, reported as metadata.
total_wait_time: Arc<AtomicU64>,
output_batch_stats: BatchSizeStats,
}
impl<IF, T, L, D> ExchangeReceiver<IF, T, L, D>
where
    T: Send + 'static + Clone + Debug,
{
    /// Creates the receiver for worker `worker_index` of `exchange`.
    ///
    /// Flush tracking and the accumulated wait time start at zero.
    pub(crate) fn new(
        worker_index: usize,
        location: OperatorLocation,
        exchange: Arc<Exchange<(T, bool)>>,
        init: IF,
        start_wait_usecs: Arc<AtomicU64>,
        deserialize: D,
        combine: L,
    ) -> Self {
        debug_assert!(worker_index < Runtime::num_workers());
        let total_wait_time = Arc::new(AtomicU64::new(0));
        Self {
            worker_index,
            location,
            init,
            deserialize,
            combine,
            exchange,
            flush_count: 0,
            flush_complete: false,
            start_wait_usecs,
            total_wait_time,
            output_batch_stats: BatchSizeStats::new(),
        }
    }

    /// Reports whether every sender has delivered to this worker.
    fn is_ready(&self) -> bool {
        let exchange = &self.exchange;
        exchange.ready_to_receive(self.worker_index)
    }
}
impl<IF, T, L, D> Operator for ExchangeReceiver<IF, T, L, D>
where
    IF: 'static,
    T: Send + 'static + Clone + Debug,
    L: 'static,
    D: 'static,
{
    fn name(&self) -> Cow<'static, str> {
        Cow::Borrowed("ExchangeReceiver")
    }

    /// Records the receiver's name on the shared exchange. The name is kept
    /// from the first call; later attempts to set it are ignored.
    fn init(&mut self, global_id: &GlobalNodeId) {
        let name = Arc::new(format!(
            "ExchangeReceiver {}",
            global_id.node_identifier()
        ));
        let _ = self.exchange.receiver_global_node_id.set(name);
    }

    fn location(&self) -> OperatorLocation {
        self.location
    }

    /// Reports output batch statistics, accumulated wait time, and the
    /// exchange's deserialization counters.
    fn metadata(&self, meta: &mut OperatorMeta) {
        let wait_usecs = self.total_wait_time.load(Ordering::Acquire);
        meta.extend(metadata! {
            OUTPUT_BATCHES_STATS => self.output_batch_stats.metadata(),
            EXCHANGE_WAIT_TIME_SECONDS => MetaItem::Duration(Duration::from_micros(wait_usecs)),
            EXCHANGE_DESERIALIZATION_TIME_SECONDS => MetaItem::Duration(Duration::from_micros(self.exchange.deserialization_usecs.load(Ordering::Acquire))),
            EXCHANGE_DESERIALIZED_BYTES => MetaItem::bytes(self.exchange.deserialized_bytes.load(Ordering::Acquire)),
        });
    }

    fn is_async(&self) -> bool {
        true
    }

    /// Registers `cb` to run whenever the exchange signals this receiver.
    ///
    /// Before `cb` runs, if the receiver is ready and a send start time is
    /// recorded, the elapsed time is added to the wait-time total and the
    /// start time is reset so that it is counted only once.
    fn register_ready_callback<F>(&mut self, cb: F)
    where
        F: Fn() + Send + Sync + 'static,
    {
        let wait_started = self.start_wait_usecs.clone();
        let wait_total = self.total_wait_time.clone();
        let exchange = self.exchange.clone();
        let worker = self.worker_index;
        let wrapped = move || {
            if exchange.ready_to_receive(worker) {
                let start = wait_started.swap(0, Ordering::Acquire);
                if start != 0 {
                    let waited = current_time_usecs()
                        .checked_sub(start)
                        .filter(|&usecs| usecs > 0);
                    if let Some(usecs) = waited {
                        wait_total.fetch_add(usecs, Ordering::AcqRel);
                    }
                }
            }
            cb()
        };
        self.exchange
            .register_receiver_callback(self.worker_index, wrapped)
    }

    fn ready(&self) -> bool {
        self.is_ready()
    }

    fn fixedpoint(&self, _scope: Scope) -> bool {
        true
    }

    /// Starts a new flush; it completes once every worker's flushed item
    /// has been received.
    fn flush(&mut self) {
        self.flush_complete = false;
    }

    fn is_flush_complete(&self) -> bool {
        self.flush_complete
    }
}
/// Removes the trailing flush-flag byte from `vec` and returns it as a bool.
///
/// Panics if `vec` is empty. Any byte other than 0 or 1 is treated as
/// unreachable.
pub fn pop_flushed(vec: &mut AlignedVec) -> bool {
    let flag = vec.pop().unwrap();
    if flag > 1 {
        unreachable!();
    }
    flag == 1
}
impl<O, IF, T, L, D> SourceOperator<O> for ExchangeReceiver<IF, T, L, D>
where
    O: NumEntries + 'static,
    T: Clone + Debug + Send + 'static,
    IF: Fn() -> O + 'static,
    L: Fn(&mut O, T) + 'static,
    D: Fn(AlignedVec) -> T + Send + Sync + 'static,
{
    /// Collects one item from every sender and folds them into a fresh
    /// output value.
    ///
    /// Counts the items marked as flushed; when the count equals the number
    /// of workers, the flush is marked complete and the count is reset.
    async fn eval(&mut self) -> O {
        debug_assert!(self.ready());

        // Received bytes end with the flush flag; strip it before decoding.
        let decode = |mut bytes: AlignedVec| {
            let flushed = pop_flushed(&mut bytes);
            ((self.deserialize)(bytes), flushed)
        };
        let mut combined = (self.init)();
        let received = self.exchange.receive_all(decode).await;
        for (item, flushed) in received {
            self.flush_count += usize::from(flushed);
            (self.combine)(&mut combined, item);
        }

        if self.flush_count == Runtime::num_workers() {
            self.flush_complete = true;
            self.flush_count = 0;
        }
        self.output_batch_stats
            .add_batch(combined.num_entries_deep());
        combined
    }
}
// Local-store key under which the runtime's `ExchangeClients` is kept.
#[derive(Hash, PartialEq, Eq)]
struct ClientsId;
impl TypedMapKey<LocalStoreMarker> for ClientsId {
type Value = Arc<ExchangeClients>;
}
// Local-store key under which the runtime's `ExchangeDirectory` is kept.
#[derive(Hash, PartialEq, Eq)]
struct DirectoryId;
impl TypedMapKey<LocalStoreMarker> for DirectoryId {
type Value = ExchangeDirectory;
}
/// Creates a connected sender/receiver operator pair for the calling worker.
///
/// Returns `None` when the runtime has a single worker. Otherwise a fresh
/// exchange id is drawn from the runtime's sequence, the exchange is looked
/// up or created, and both operators share it along with the wait-time
/// start stamp.
///
/// * `init` creates the receiver's empty output value.
/// * `partition` splits a sender input into one mailbox per worker.
/// * `deserialize` decodes a value received from a remote worker.
/// * `combine` folds a received value into the receiver's output.
///
/// Panics if called outside a runtime, or if the sequence number does not
/// fit in an `ExchangeId`.
pub fn new_exchange_operators<TI, TO, TE, IF, PL, CL, D>(
    location: OperatorLocation,
    init: IF,
    partition: PL,
    deserialize: D,
    combine: CL,
) -> Option<(ExchangeSender<TI, TE, PL>, ExchangeReceiver<IF, TE, CL, D>)>
where
    TO: Clone,
    TE: Send + 'static + Clone + Debug,
    IF: Fn() -> TO + 'static,
    PL: FnMut(TI, &mut Vec<Mailbox<TE>>) + 'static,
    D: Fn(AlignedVec) -> TE + Send + Sync + 'static,
    CL: Fn(&mut TO, TE) + 'static,
{
    let single_worker = Runtime::num_workers() == 1;
    if single_worker {
        return None;
    }

    let runtime = Runtime::runtime().unwrap();
    let worker_index = Runtime::worker_index();
    let exchange_id = runtime.sequence_next().try_into().unwrap();
    let wait_start = Arc::new(AtomicU64::new(0));
    let exchange = Exchange::with_runtime(&runtime, exchange_id);

    let sender = ExchangeSender::new(location, exchange.clone(), wait_start.clone(), partition);
    let receiver = ExchangeReceiver::new(
        worker_index,
        location,
        exchange,
        init,
        wait_start,
        deserialize,
        combine,
    );
    Some((sender, receiver))
}
#[cfg(test)]
mod tests {
use feldera_storage::tokio::TOKIO;
use itertools::Itertools;
use super::Exchange;
use crate::{
Circuit, RootCircuit,
circuit::{
CircuitConfig, Layout, Runtime,
runtime::{WorkerLocation, WorkerLocations},
schedule::{DynamicScheduler, Scheduler},
},
operator::{
Generator,
communication::{Mailbox, new_exchange_operators},
},
storage::file::{to_bytes, to_bytes_dyn},
trace::aligned_deserialize,
};
use std::{
iter::{repeat, zip},
net::TcpListener,
sync::Arc,
};
// Number of exchange rounds per test; reduced when running under Miri.
const ROUNDS: usize = if cfg!(miri) { 128 } else { 2048 };
/// Worker body: for `ROUNDS` rounds, sends `(worker, round)` to every worker
/// through exchange 0 and checks that one such pair arrives from each
/// worker, in worker order.
fn circuit() {
    let exchange = Exchange::with_runtime(&Runtime::runtime().unwrap(), 0);
    TOKIO.block_on(async {
        let n_workers = Runtime::num_workers();
        let sender = Runtime::worker_index();
        let global_node_id = Arc::new("test_global_node_id".to_string());

        for round in 0..ROUNDS {
            let outgoing = repeat((sender, round));
            exchange
                .send_all_with_serializer(&global_node_id, outgoing, |data| {
                    to_bytes(&data).unwrap()
                })
                .await;

            let received = exchange
                .receive_all(|data| aligned_deserialize(&data[..]))
                .await;
            let expected: Vec<_> = (0..n_workers).map(|worker| (worker, round)).collect();
            assert_eq!(received, expected);
        }
    });
}
fn test_circuit(
workers: usize,
hosts: usize,
circuit: impl FnOnce() + Copy + Clone + Send + Sync + 'static,
) {
match hosts {
0 => unreachable!(),
1 => {
let hruntime = Runtime::run(workers, move |_parker| circuit())
.expect("failed to start runtime");
hruntime.join().unwrap();
}
_ => {
assert!(workers >= hosts);
let exchange_listeners = (0..hosts)
.map(|_| {
TcpListener::bind("127.0.0.1:0")
.expect("should be able to bind a port on localhost")
})
.collect_vec();
let params = exchange_listeners
.iter()
.enumerate()
.map(|(index, listener)| {
(
listener
.local_addr()
.expect("should be able to get local address"),
workers / hosts + (index < workers % hosts) as usize,
)
})
.collect_vec();
let mut runtimes = Vec::with_capacity(hosts);
for ((local_address, _), exchange_listener) in
zip(params.iter(), exchange_listeners)
{
let cconf = CircuitConfig::from(
Layout::new_multihost(¶ms, *local_address).unwrap(),
)
.with_exchange_listener(exchange_listener);
runtimes.push(
Runtime::run(cconf, move |_parker| circuit())
.expect("failed to start runtime"),
);
}
for runtime in runtimes {
runtime.join().unwrap();
}
}
}
}
/// Exercises the raw exchange with several worker counts on one host.
#[test]
#[cfg_attr(miri, ignore)]
fn single_host() {
    [2, 4, 8]
        .into_iter()
        .for_each(|workers| test_circuit(workers, 1, circuit));
}
/// Exercises the raw exchange across several simulated hosts.
#[test]
#[cfg_attr(miri, ignore)]
fn multihost() {
    let configurations = [(2, 2), (4, 2), (8, 2), (3, 3), (4, 4), (16, 4)];
    for (workers, hosts) in configurations {
        test_circuit(workers, hosts, circuit);
    }
}
// Worker body for the operator-level tests: builds a circuit in which a
// counter source feeds an exchange sender/receiver pair, then runs it.
fn operator_circuit<S>()
where
S: Scheduler + 'static,
{
let circuit = RootCircuit::build_with_scheduler::<_, _, S>(move |circuit| {
// Source that yields 0, 1, 2, ... on successive evaluations.
let mut n: usize = 0;
let source = circuit.add_source(Generator::new(move || {
let result = n;
n += 1;
result
}));
let (sender, receiver) = new_exchange_operators(
None,
Vec::new,
// Send the same number to every worker, serializing it only for
// remote workers.
move |n, vals| {
for location in WorkerLocations::new() {
match location {
WorkerLocation::Local => vals.push(Mailbox::Plain(n)),
WorkerLocation::Remote => {
vals.push(Mailbox::Tx(to_bytes_dyn(&n).unwrap()))
}
}
}
},
|data| aligned_deserialize(&data[..]),
|v: &mut Vec<usize>, n| v.push(n),
)
.unwrap();
// Each evaluation, every worker must receive the current round number
// once from each worker.
let mut round = 0;
circuit
.add_exchange(sender, receiver, &source)
.inspect(move |v| {
assert_eq!(&vec![round; Runtime::num_workers()], v);
round += 1;
});
Ok(())
})
.unwrap()
.0;
for _ in 1..ROUNDS {
circuit.transaction().unwrap();
}
}
/// Runs `circuit` with several worker counts on one host.
fn test_operators_single_host(circuit: impl FnOnce() + Copy + Clone + Send + Sync + 'static) {
    [2, 16, 32]
        .into_iter()
        .for_each(|workers| test_circuit(workers, 1, circuit));
}
/// Runs `circuit` across several simulated-host configurations.
fn test_operators_multihost(circuit: impl FnOnce() + Copy + Clone + Send + Sync + 'static) {
    let configurations = [(2, 2), (4, 2), (8, 2), (3, 3), (4, 4), (16, 4)];
    for (workers, hosts) in configurations {
        test_circuit(workers, hosts, circuit);
    }
}
// Exchange operators under the dynamic scheduler, on one host.
#[test]
#[cfg_attr(miri, ignore)]
fn operators_single_host_dynamic() {
test_operators_single_host(operator_circuit::<DynamicScheduler>);
}
// Exchange operators under the dynamic scheduler, across several hosts.
#[test]
#[cfg_attr(miri, ignore)]
fn operators_multihost_dynamic() {
test_operators_multihost(operator_circuit::<DynamicScheduler>);
}
}