use crate::{
Circuit, Runtime, Stream,
circuit::{
GlobalNodeId, OwnershipPreference, Scope,
circuit_builder::StreamId,
metadata::OperatorLocation,
operator_traits::{Operator, SinkOperator, SourceOperator},
runtime::{WorkerLocation, WorkerLocations},
},
circuit_cache_key,
operator::communication::{Mailbox, new_exchange_operators},
trace::{Batch, deserialize_indexed_wset, merge_batches, serialize_indexed_wset},
};
use arc_swap::ArcSwap;
use crossbeam::atomic::AtomicConsume;
use crossbeam_utils::CachePadded;
use std::{
borrow::Cow,
marker::PhantomData,
mem::MaybeUninit,
panic::Location,
sync::{
Arc,
atomic::{AtomicBool, Ordering},
},
};
/// Callback invoked after a producer deposits data, used to wake the
/// consumer's scheduler.
type NotifyCallback = dyn Fn() + Send + Sync + 'static;

// Cache key for a single-host gathered stream: (source stream, receiver worker).
circuit_cache_key!(GatherId<C, D>((StreamId, usize) => Stream<C, D>));
// Worker-local cache of the shared channel state for one gather node, keyed by
// a per-runtime sequence number so all workers resolve to the same `GatherData`.
circuit_cache_key!(local GatherDataId<T>(usize => Arc<GatherData<(T, bool)>>));
// Cache key for a multihost gathered stream: (source node id, receiver worker).
circuit_cache_key!(MultihostGatherId<C, D>((GlobalNodeId, usize) => Stream<C, D>));
impl<C, B> Stream<C, B>
where
    C: Circuit,
    B: Send + 'static,
{
    /// Gathers the batches produced by all workers into a single stream at
    /// `receiver_worker`: that worker observes the merged contents of every
    /// worker's stream, while all other workers observe an empty stream.
    ///
    /// # Panics
    ///
    /// Panics if `receiver_worker` is not a valid worker index.
    #[track_caller]
    pub fn dyn_gather(&self, factories: &B::Factories, receiver_worker: usize) -> Stream<C, B>
    where
        B: Batch<Time = ()> + Send,
    {
        let workers = Runtime::num_workers();
        assert!(receiver_worker < workers);

        // With a single worker, gathering is a no-op.
        if workers == 1 {
            return self.clone();
        }

        let runtime = Runtime::runtime().unwrap();
        if runtime.layout().is_solo() {
            // All workers live in this process: hand batches over through a
            // shared in-memory channel.
            self.dyn_gather_single_host(workers, runtime, factories, receiver_worker)
        } else {
            // Workers are spread across hosts: route batches through the
            // exchange operators, serializing anything that crosses a host
            // boundary.
            self.dyn_gather_multihost(factories, receiver_worker)
        }
    }

    /// Single-host implementation of [`Self::dyn_gather`]: every worker pushes
    /// into a shared [`GatherData`] channel and the receiver pops and merges.
    #[track_caller]
    fn dyn_gather_single_host(
        &self,
        workers: usize,
        runtime: Runtime,
        factories: &B::Factories,
        receiver_worker: usize,
    ) -> Stream<C, B>
    where
        B: Batch<Time = ()> + Send,
    {
        let location = Location::caller();
        self.circuit()
            .cache_get_or_insert_with(
                // Reuse an existing gather of this stream to the same worker.
                GatherId::new((self.stream_id(), receiver_worker)),
                move || {
                    let current_worker = Runtime::worker_index();
                    // Every worker executes this closure and draws the same
                    // sequence number, so all of them share one `GatherData`
                    // instance from the runtime-local store.
                    let gather_id = runtime.sequence_next();
                    let gather = runtime
                        .local_store()
                        .entry(GatherDataId::new(gather_id))
                        .or_insert_with(|| Arc::new(GatherData::new(workers, location)))
                        .value()
                        .clone();
                    // NOTE(review): each worker pushes only into its own slot
                    // (`current_worker`), which presumably upholds
                    // `GatherProducer::new`'s exclusivity contract — confirm
                    // against the runtime's worker-index guarantees.
                    let producer = unsafe { GatherProducer::new(gather.clone(), current_worker) };
                    // Only the receiver worker consumes the channel; every
                    // other worker pairs its producer with a source that
                    // yields empty batches.
                    if current_worker == receiver_worker {
                        self.circuit().add_exchange(
                            producer,
                            GatherConsumer::new(gather, factories),
                            self,
                        )
                    } else {
                        self.circuit().add_exchange(
                            producer,
                            EmptyGatherConsumer::new(location, factories),
                            self,
                        )
                    }
                    .set_persistent_id(
                        self.get_persistent_id()
                            .map(|name| format!("{name}.gather-{receiver_worker}"))
                            .as_deref(),
                    )
                },
            )
            .clone()
    }

    /// Multihost implementation of [`Self::dyn_gather`]: batches travel
    /// through exchange operators and are serialized when the destination
    /// worker lives on another host.
    #[track_caller]
    fn dyn_gather_multihost(&self, factories: &B::Factories, receiver_worker: usize) -> Stream<C, B>
    where
        B: Batch<Time = ()> + Send,
    {
        let location = Location::caller();
        self.circuit()
            .cache_get_or_insert_with(
                MultihostGatherId::new((self.origin_node_id().clone(), receiver_worker)),
                move || {
                    let factories_clone = factories.clone();
                    let factories_clone2 = factories.clone();
                    let factories_clone3 = factories.clone();
                    let output = self.circuit().region("gather", || {
                        let (sender, receiver) = new_exchange_operators(
                            Some(location),
                            || Vec::new(),
                            // Sender: address this worker's batch to the
                            // receiver and hand empty batches to everyone
                            // else.
                            move |batch: B, batches: &mut Vec<Mailbox<B>>| {
                                let empty = B::dyn_empty(&factories_clone3);
                                // Serializer allocated lazily, only if at
                                // least one destination is remote.
                                let mut serializer_inner = None;
                                // One mailbox per worker, in worker order:
                                // plain batches for workers in this process,
                                // serialized bytes for remote ones.
                                // NOTE(review): assumes the exchange drains
                                // `batches` between steps, since a fresh
                                // mailbox is pushed per worker on every call.
                                for location in WorkerLocations::new() {
                                    match location {
                                        WorkerLocation::Local => {
                                            batches.push(Mailbox::Plain(empty.clone()))
                                        }
                                        WorkerLocation::Remote => {
                                            batches.push(Mailbox::Tx(serialize_indexed_wset(
                                                &empty,
                                                serializer_inner.get_or_insert_default(),
                                            )))
                                        }
                                    }
                                }
                                // Overwrite the receiver's mailbox with the
                                // real data, serializing only if the receiver
                                // runs on another host.
                                if Runtime::runtime()
                                    .unwrap()
                                    .layout()
                                    .local_workers()
                                    .contains(&receiver_worker)
                                {
                                    batches[receiver_worker] = Mailbox::Plain(batch);
                                } else {
                                    batches[receiver_worker] = Mailbox::Tx(serialize_indexed_wset(
                                        &batch,
                                        serializer_inner.get_or_insert_default(),
                                    ));
                                }
                            },
                            // Receiver side: deserialize incoming bytes back
                            // into batches...
                            move |data| deserialize_indexed_wset(&factories_clone, &data),
                            // ...and collect one batch per sending worker.
                            |batches: &mut Vec<B>, batch: B| batches.push(batch),
                        )
                        .unwrap();
                        self.circuit()
                            .add_exchange(sender, receiver, self)
                            // Merge the per-worker batches into one output.
                            .apply_owned_named("gather shards", move |batches| {
                                merge_batches(&factories_clone2, batches, &None, &None)
                            })
                    });
                    output.set_persistent_id(
                        self.get_persistent_id()
                            .map(|name| format!("{name}.gather"))
                            .as_deref(),
                    )
                },
            )
            .clone()
    }
}
/// Shared channel with one slot per worker, used on a single host to hand each
/// worker's batch to the gather receiver.
struct GatherData<T> {
    // One flag per worker; `true` means the matching `values` slot holds an
    // initialized value. Cache-padded to avoid false sharing between workers.
    is_valid: Box<[CachePadded<AtomicBool>]>,
    // One slot per worker; a slot is initialized iff its `is_valid` flag is
    // set.
    values: Box<[CachePadded<MaybeUninit<T>>]>,
    // Callback invoked after every push to wake the consumer; swapped in via
    // `set_notify` when the consumer registers its scheduler callback.
    notify: ArcSwap<Box<NotifyCallback>>,
    // Source location of the `gather` call, surfaced in operator metadata.
    location: &'static Location<'static>,
}
impl<T> GatherData<T> {
    /// Creates channel state with one empty slot per worker.
    ///
    /// `location` is the source location of the `gather` call, kept for
    /// operator metadata and diagnostics.
    fn new(length: usize, location: &'static Location<'static>) -> Self {
        // Placeholder until the consumer installs a real callback via
        // `set_notify`; being notified before that happens is a bug that we
        // surface in debug builds.
        fn noop_notify() {
            if cfg!(debug_assertions) {
                panic!("a notification callback was never set on a gather node");
            }
        }

        let is_valid = (0..length)
            .map(|_| CachePadded::new(AtomicBool::new(false)))
            .collect();
        // `MaybeUninit::uninit()` is a no-op, so building the slots this way
        // is free and — unlike the previous `Vec::with_capacity` +
        // `set_len` over uninitialized storage — needs neither `unsafe` nor
        // a clippy suppression.
        let values = (0..length)
            .map(|_| CachePadded::new(MaybeUninit::uninit()))
            .collect();

        Self {
            is_valid,
            values,
            notify: ArcSwap::new(Arc::new(Box::new(noop_notify))),
            location,
        }
    }

    /// Number of worker slots in the channel.
    const fn workers(&self) -> usize {
        self.is_valid.len()
    }

    /// Installs the callback invoked after every `push`.
    fn set_notify(&self, notify: Box<NotifyCallback>) {
        self.notify.store(Arc::new(notify));
    }

    /// Returns `true` once every worker's slot holds a value.
    unsafe fn all_channels_ready(&self) -> bool {
        self.is_valid.iter().all(|is_valid| is_valid.load_consume())
    }

    /// Stores `value` into `worker`'s slot and notifies the consumer.
    ///
    /// # Safety
    ///
    /// `worker` must be in bounds and its slot must currently be empty; no
    /// other thread may access that slot concurrently.
    unsafe fn push(&self, worker: usize, value: T) {
        if cfg!(debug_assertions) {
            assert!(worker < self.values.len());
            let currently_filled = self.is_valid[worker].load_consume();
            assert!(!currently_filled);
        }
        unsafe {
            // SAFETY: the caller guarantees exclusive access to this in-bounds
            // slot, so casting away the shared reference to write is sound.
            (*(self.values.as_ptr().add(worker) as *mut CachePadded<MaybeUninit<T>>)).write(value);
            // The release store publishes the write above to the consumer's
            // consume loads in `all_channels_ready`/`pop`.
            self.is_valid
                .get_unchecked(worker)
                .store(true, Ordering::Release);
        }
        (self.notify.load())();
    }

    /// Moves the value out of `worker`'s slot, leaving it empty.
    ///
    /// # Safety
    ///
    /// `worker` must be in bounds and its slot must currently hold a value.
    unsafe fn pop(&self, worker: usize) -> T {
        debug_assert!(worker < self.values.len());
        unsafe {
            let slot_is_valid = self.is_valid.get_unchecked(worker);
            let is_valid = slot_is_valid.load_consume();
            debug_assert!(is_valid);
            // SAFETY: the flag (checked above in debug builds, guaranteed by
            // the caller otherwise) says this slot is initialized.
            let value = self.values.get_unchecked(worker).assume_init_read();
            // NOTE(review): the relaxed store is presumably sufficient because
            // the circuit scheduler orders this pop before the producer's next
            // push — confirm against the exchange protocol.
            slot_is_valid.store(false, Ordering::Relaxed);
            value
        }
    }
}
impl<T> Drop for GatherData<T> {
    /// Releases any values still parked in the channel so they are not leaked.
    fn drop(&mut self) {
        // A non-empty channel at drop time indicates a bug; flag it in debug
        // builds unless we are already unwinding from a panic.
        if cfg!(debug_assertions) && !std::thread::panicking() {
            assert!(
                !self.is_valid.iter().any(|is_valid| is_valid.load_consume()),
                "dropped a GatherData with values stored in its channel",
            );
        }
        assert!(self.is_valid.len() == self.values.len());
        // `MaybeUninit` never drops its contents, so any slot still flagged
        // valid must be dropped by hand.
        for (flag, slot) in self.is_valid.iter().zip(self.values.iter_mut()) {
            if flag.load_consume() {
                // SAFETY: the flag guarantees this slot was initialized.
                unsafe { slot.assume_init_drop() };
            }
        }
    }
}
// SAFETY: cross-thread access to the `values` slots is gated by the `is_valid`
// atomic flags (release stores in `push` paired with consume loads in `pop`),
// so the container may move between threads whenever `T: Send`.
unsafe impl<T: Send> Send for GatherData<T> {}
// SAFETY: same flag-based protocol as above; a shared `&GatherData` only
// exposes slot contents through those synchronized methods.
unsafe impl<T: Send> Sync for GatherData<T> {}
/// Sink operator run by every worker: writes the worker's batch, tagged with a
/// flush flag, into its own slot of the shared [`GatherData`].
struct GatherProducer<T> {
    // Channel shared with the consumer on the receiver worker.
    gather: Arc<GatherData<(T, bool)>>,
    // Index of the slot this producer writes to.
    worker: usize,
    // Set by `flush()`, attached to (and reset by) the next `eval`.
    flushed: bool,
}
impl<T> GatherProducer<T> {
    /// Creates a producer writing into `worker`'s slot of `gather`.
    ///
    /// # Safety
    ///
    /// The constructor itself performs no unsafe operation, but `eval` pushes
    /// into slot `worker` unchecked: the caller must guarantee `worker` is a
    /// valid slot index of `gather` and that no other producer shares it.
    const unsafe fn new(gather: Arc<GatherData<(T, bool)>>, worker: usize) -> Self {
        Self {
            gather,
            worker,
            flushed: false,
        }
    }
}
impl<T> Operator for GatherProducer<T>
where
    T: 'static,
{
    fn name(&self) -> Cow<'static, str> {
        Cow::Borrowed("GatherProducer")
    }

    /// Reports the `gather` call site rather than this file.
    fn location(&self) -> OperatorLocation {
        Some(self.gather.location)
    }

    fn fixedpoint(&self, _scope: Scope) -> bool {
        true
    }

    /// Records that a flush was requested; the flag rides along with the next
    /// batch pushed by `eval` so the consumer can track flush completion.
    fn flush(&mut self) {
        self.flushed = true;
    }
}
impl<T> SinkOperator<T> for GatherProducer<T>
where
    T: Clone + Send + 'static,
{
    /// Pushes a clone of the input into this worker's slot, together with the
    /// pending flush flag.
    async fn eval(&mut self, input: &T) {
        // SAFETY(review): relies on the slot being empty — presumably the
        // scheduler runs the consumer's pop before this producer's next step;
        // confirm against the exchange protocol.
        unsafe { self.gather.push(self.worker, (input.clone(), self.flushed)) };
        self.flushed = false;
    }

    /// Like `eval`, but moves the owned input into the slot without cloning.
    async fn eval_owned(&mut self, input: T) {
        unsafe { self.gather.push(self.worker, (input, self.flushed)) };
        self.flushed = false;
    }

    /// Owned inputs are preferred: they avoid the clone in `eval`.
    fn input_preference(&self) -> OwnershipPreference {
        OwnershipPreference::PREFER_OWNED
    }
}
/// Source operator run on the receiver worker: waits until every worker has
/// pushed a batch, then pops and merges them all.
struct GatherConsumer<T: Batch> {
    // Factories used to merge the collected batches.
    factories: T::Factories,
    // Channel shared with all producers.
    gather: Arc<GatherData<(T, bool)>>,
    // Number of flush-flagged batches seen since the last completed flush.
    flush_count: usize,
    // True once flush-flagged batches have been received from every worker.
    flush_complete: bool,
}
impl<T: Batch> GatherConsumer<T> {
    /// Creates a consumer draining every worker slot of `gather`, merging the
    /// collected batches with `factories`.
    fn new(gather: Arc<GatherData<(T, bool)>>, factories: &T::Factories) -> Self {
        let factories = factories.clone();
        Self {
            factories,
            gather,
            flush_count: 0,
            flush_complete: false,
        }
    }
}
impl<T: Batch + 'static> Operator for GatherConsumer<T> {
    fn name(&self) -> Cow<'static, str> {
        Cow::Borrowed("GatherConsumer")
    }

    /// Reports the `gather` call site rather than this file.
    fn location(&self) -> OperatorLocation {
        Some(self.gather.location)
    }

    /// Async operator: it can only run once data from all workers is present.
    fn is_async(&self) -> bool {
        true
    }

    /// Routes the scheduler's readiness callback into the shared channel so
    /// each producer `push` wakes this consumer.
    fn register_ready_callback<F>(&mut self, callback: F)
    where
        F: Fn() + Send + Sync + 'static,
    {
        self.gather.set_notify(Box::new(callback));
    }

    /// Ready once every worker has deposited its batch for this step.
    fn ready(&self) -> bool {
        unsafe { self.gather.all_channels_ready() }
    }

    fn fixedpoint(&self, _scope: Scope) -> bool {
        true
    }

    /// Starts a new flush cycle; completion is reported after `eval` has seen
    /// flush-flagged batches from every worker.
    fn flush(&mut self) {
        self.flush_complete = false;
    }

    fn is_flush_complete(&self) -> bool {
        self.flush_complete
    }
}
impl<T> SourceOperator<T> for GatherConsumer<T>
where
    T: Batch<Time = ()> + 'static,
{
    /// Pops one batch from every worker's slot and returns their merge.
    ///
    /// Each popped batch carries a flag saying whether its producer had a
    /// pending flush; once flush-flagged batches from all workers have been
    /// counted (possibly across multiple steps), the flush is marked complete.
    async fn eval(&mut self) -> T {
        // The scheduler only runs us once `ready()` reported true.
        debug_assert!(unsafe { self.gather.all_channels_ready() });
        let result = merge_batches(
            &self.factories,
            (0..self.gather.workers()).map(|worker| {
                // SAFETY(review): every slot is full per the assertion above —
                // assumes readiness cannot be revoked between the check and
                // these pops.
                let (batch, flushed) = unsafe { self.gather.pop(worker) };
                if flushed {
                    self.flush_count += 1;
                }
                batch
            }),
            &None,
            &None,
        );
        if self.flush_count == Runtime::num_workers() {
            self.flush_count = 0;
            self.flush_complete = true;
        }
        result
    }
}
/// Source operator used on every worker other than the gather receiver; it
/// only ever produces empty batches.
struct EmptyGatherConsumer<T: Batch> {
    // Factories needed to construct the empty output batches.
    factories: T::Factories,
    // Source location of the `gather` call, surfaced in operator metadata.
    location: &'static Location<'static>,
    // Ties `T` to the struct even though no `T` value is stored.
    __type: PhantomData<T>,
}
impl<T: Batch> EmptyGatherConsumer<T> {
fn new(location: &'static Location<'static>, factories: &T::Factories) -> Self {
Self {
factories: factories.clone(),
location,
__type: PhantomData,
}
}
}
impl<T: Batch + 'static> Operator for EmptyGatherConsumer<T> {
    fn name(&self) -> Cow<'static, str> {
        Cow::Borrowed("EmptyGatherConsumer")
    }

    /// Reports the `gather` call site rather than this file.
    fn location(&self) -> OperatorLocation {
        Some(self.location)
    }

    fn fixedpoint(&self, _scope: Scope) -> bool {
        true
    }
}
impl<T> SourceOperator<T> for EmptyGatherConsumer<T>
where
    T: Batch<Time = ()> + 'static,
{
    /// Always yields an empty batch; the gathered data only surfaces at the
    /// receiver worker.
    async fn eval(&mut self) -> T {
        T::dyn_empty(&self.factories)
    }
}