use std::{
borrow::Cow,
collections::VecDeque,
iter::repeat_n,
ops::Range,
panic::Location,
pin::Pin,
sync::{Arc, Mutex, MutexGuard},
};
use itertools::{Itertools as _, zip_eq};
use rkyv::AlignedVec;
use crate::{
Circuit, Runtime, Scope, Stream,
circuit::{
OwnershipPreference, StepSize, WorkerLocation, WorkerLocations,
circuit_builder::StreamId,
metadata::{BatchSizeStats, INPUT_BATCHES_STATS, OperatorLocation, OperatorMeta},
operator_traits::{Operator, SinkOperator, SourceOperator},
},
circuit_cache_key,
operator::{
communication::{
ExchangeClients, ExchangeDelivery, ExchangeDirectory, ExchangeId, pop_flushed,
},
dynamic::shard_batch,
},
trace::{Batch, Spine, deserialize_indexed_wset},
};
// Cache key used with the runtime's local store: maps an `ExchangeId` to the
// `ShardedAccumulator` created for that exchange (see
// `ShardedAccumulator::with_runtime`).
circuit_cache_key!(local StreamingExchangeCacheId<B: Batch>(ExchangeId => Arc<ShardedAccumulator<B>>));
// Cache key used with the circuit cache: maps the `StreamId` of an input
// stream to the sharded-and-accumulated stream built from it by
// `Stream::dyn_shard_accumulate`.
circuit_cache_key!(ShardedAccumulatorId<C, B: Batch>(StreamId => Stream<C, Option<Spine<B>>>));
impl<C, B> Stream<C, B>
where
    C: Circuit,
    B: Batch,
{
    /// Shards this stream across workers and accumulates the result into a
    /// stream of `Option<Spine<B>>`.
    ///
    /// The strategy depends on the state of the stream and the runtime:
    ///
    /// * If a sharded version of this stream already exists, it is
    ///   accumulated directly.
    /// * If the runtime has a single worker, this stream is accumulated
    ///   directly.
    /// * Otherwise the result is built once per input stream and cached in
    ///   the circuit under [`ShardedAccumulatorId`]:
    ///   * with [`StepSize::Microsteps`], a `ShardedAccumulatorSender` /
    ///     `ShardedAccumulatorReceiver` exchange pair is added to the circuit
    ///     and its output is marked as sharded;
    ///   * with [`StepSize::FullSteps`], the stream is sharded and then
    ///     accumulated as two separate operators.
    ///
    /// # Panics
    ///
    /// In the multi-worker case, panics if no runtime is available or if the
    /// next runtime sequence number cannot be converted into an `ExchangeId`.
    #[track_caller]
    pub fn dyn_shard_accumulate(&self, factories: &B::Factories) -> Stream<C, Option<Spine<B>>>
    where
        B: Batch<Time = ()>,
    {
        // Capture the call site up front. `#[track_caller]` does not extend
        // into closures, so calling `Location::caller()` inside the closure
        // below would report a line in this file rather than the location of
        // the code that invoked this method.
        let location = Location::caller();

        if let Some(sharded) = self.get_sharded_version() {
            sharded.dyn_accumulate(factories)
        } else if Runtime::num_workers() == 1 {
            self.dyn_accumulate(factories)
        } else {
            self.circuit()
                .cache_get_or_insert_with(ShardedAccumulatorId::new(self.stream_id()), || {
                    let runtime = Runtime::runtime().unwrap();
                    match runtime.get_step_size() {
                        StepSize::Microsteps => {
                            let exchange_id: ExchangeId =
                                runtime.sequence_next().try_into().unwrap();
                            let exchange = ShardedAccumulator::<B>::with_runtime(
                                &runtime,
                                Runtime::worker_index(),
                                exchange_id,
                                factories,
                            );
                            self.circuit()
                                .add_exchange(
                                    ShardedAccumulatorSender::new(
                                        Some(location),
                                        exchange.clone(),
                                    ),
                                    ShardedAccumulatorReceiver::new(Some(location), exchange),
                                    self,
                                )
                                .mark_sharded()
                        }
                        StepSize::FullSteps => self.dyn_shard(factories).dyn_accumulate(factories),
                    }
                })
                .clone()
        }
    }
}
/// Shared state of one sharded-accumulator exchange.
///
/// Senders shard their batches and deliver the pieces either directly into
/// the receive queue of a local worker or, through `clients`, to a remote
/// host. Each local worker drains its own receive queue.
struct ShardedAccumulator<B: Batch> {
    /// Identifier under which this exchange is registered and which
    /// accompanies data sent to remote hosts.
    exchange_id: ExchangeId,
    /// Total number of workers in the runtime layout.
    npeers: usize,
    /// Factories used to shard, deserialize and store batches.
    factories: B::Factories,
    /// Indexes of the workers that run in this process.
    local_workers: Range<usize>,
    /// Connections used to send data to workers on other hosts.
    clients: Arc<ExchangeClients>,
    /// One receive queue per local worker, indexed by
    /// `worker - local_workers.start`.
    rxq: Vec<Mutex<Rxq<B>>>,
}
impl<B> ShardedAccumulator<B>
where
    B: Batch<Time = ()>,
{
    /// Returns the exchange stored in `runtime`'s local store under
    /// `exchange_id`, creating and storing it first if it does not exist.
    fn with_runtime(
        runtime: &Runtime,
        worker_index: usize,
        exchange_id: ExchangeId,
        factories: &B::Factories,
    ) -> Arc<Self> {
        let directory = ExchangeDirectory::for_runtime(runtime);
        let clients = ExchangeClients::for_runtime(runtime);
        runtime
            .local_store()
            .entry(StreamingExchangeCacheId::new(exchange_id))
            .or_insert_with(|| {
                ShardedAccumulator::new(
                    runtime,
                    worker_index,
                    clients,
                    exchange_id,
                    &directory,
                    factories,
                )
            })
            .value()
            .clone()
    }

    /// Creates a new exchange with one receive queue per local worker and
    /// registers it in `directory` under `exchange_id`.
    fn new(
        runtime: &Runtime,
        worker_index: usize,
        clients: Arc<ExchangeClients>,
        exchange_id: ExchangeId,
        directory: &ExchangeDirectory,
        factories: &B::Factories,
    ) -> Arc<Self> {
        let layout = runtime.layout();
        let npeers = layout.n_workers();
        let exchange = Arc::new(Self {
            exchange_id,
            npeers,
            local_workers: layout.local_workers(),
            factories: factories.clone(),
            clients,
            // NOTE(review): every queue is created with the index of the
            // worker that constructs the exchange, not the index of the
            // worker the queue serves — confirm this is intended.
            rxq: layout
                .local_workers()
                .map(|_| Mutex::new(Rxq::new(runtime, worker_index, factories, npeers)))
                .collect(),
        });
        directory.insert(exchange_id, exchange.clone());
        exchange
    }

    /// Locks and returns the receive queue of local worker `receiver`.
    ///
    /// # Panics
    ///
    /// Panics if `receiver` is not a local worker or if the queue's mutex is
    /// poisoned.
    fn rxq(&self, receiver: usize) -> MutexGuard<'_, Rxq<B>> {
        assert!(self.local_workers.contains(&receiver));
        self.rxq[receiver - self.local_workers.start]
            .lock()
            .unwrap()
    }

    /// Hands `batch`, sent by worker `sender`, to the receive queue of local
    /// worker `receiver`. `flush` marks the sender's last batch before a
    /// flush.
    fn deliver(
        &self,
        factories: &B::Factories,
        sender: usize,
        receiver: usize,
        batch: B,
        flush: bool,
    ) {
        let batch = Spine::maybe_flush_batch(batch, factories, || (None, None));
        self.rxq(receiver).deliver(factories, sender, batch, flush);
    }

    /// Shards `batch` into one piece per worker and sends each piece to its
    /// worker.
    ///
    /// Pieces for local workers are delivered directly into their receive
    /// queues. Pieces for a remote host are sent together in one message,
    /// each with the `flush` flag appended as a trailing byte.
    ///
    /// # Panics
    ///
    /// Panics if no runtime is available, if sharding yields fewer pieces
    /// than there are workers, or if a piece is not in the form (plain for
    /// local workers, serialized for remote ones) expected for its
    /// destination.
    async fn send(self: &Arc<Self>, batch: B, flush: bool) {
        let sender = Runtime::worker_index();
        let runtime = Runtime::runtime().unwrap();
        let layout = runtime.layout();

        let mut builders = Vec::with_capacity(layout.n_workers());
        let mut batches = Vec::with_capacity(layout.n_workers());
        shard_batch(
            batch,
            &(0..self.npeers),
            &mut builders,
            &mut batches,
            &self.factories,
        );

        let worker_locations = WorkerLocations::for_layout(layout);
        // The pieces are consumed in worker order while walking the hosts.
        let mut data = batches.into_iter();
        for receivers in layout.all_hosts() {
            match worker_locations[receivers.start] {
                WorkerLocation::Local => {
                    for receiver in receivers.clone() {
                        let item = data
                            .next()
                            .expect("data should include one item per peer")
                            .into_plain()
                            .expect("local data should not be serialized");
                        self.deliver(&self.factories, sender, receiver, item, flush);
                    }
                }
                WorkerLocation::Remote => {
                    let items = receivers
                        .clone()
                        .map(|_| {
                            let mut fbuf = data
                                .next()
                                .expect("data should include one item per peer")
                                .into_tx()
                                .expect("remote mailboxes should always be serialized");
                            // The receiving side pops this byte off again to
                            // recover the flush flag.
                            fbuf.push(flush as u8);
                            fbuf
                        })
                        .collect_vec();
                    self.clients
                        .connect(receivers.start)
                        .await
                        .send(self.exchange_id, sender, items)
                        .await;
                }
            }
        }
    }

    /// Returns the next completed spine for the calling worker, if its
    /// receive queue has one.
    fn receive(&self) -> Option<Spine<B>> {
        let receiver = Runtime::worker_index();
        self.rxq(receiver).receive()
    }
}
impl<B: Batch<Time = ()>> ExchangeDelivery for ShardedAccumulator<B> {
    /// Handles a message from remote worker `sender` that carries one
    /// serialized batch per local worker.
    ///
    /// For each buffer, the flush flag is popped off, the remaining bytes are
    /// deserialized into a batch, and the batch is delivered to the
    /// corresponding local worker.
    ///
    /// # Panics
    ///
    /// The returned future panics if `data` does not hold exactly one buffer
    /// per local worker.
    fn received<'a>(
        &'a self,
        sender: usize,
        data: Vec<AlignedVec>,
    ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
        let deliver_all = async move {
            let per_receiver = zip_eq(self.local_workers.clone(), data);
            for (receiver, mut payload) in per_receiver {
                let flush = pop_flushed(&mut payload);
                let batch = deserialize_indexed_wset(&self.factories, &payload);
                self.deliver(&self.factories, sender, receiver, batch, flush);
            }
        };
        Box::pin(deliver_all)
    }
}
/// Receive queue of a single worker.
///
/// Batches delivered by the senders are collected into a queue of spines.
/// A spine is complete once every sender has flushed into it.
struct Rxq<B: Batch> {
    /// Runtime handle passed to every spine created by this queue.
    runtime: Runtime,
    /// Worker index passed to every spine created by this queue.
    worker_index: usize,
    /// Number of senders that must flush before a spine is complete.
    npeers: usize,
    /// Spines under construction, oldest first.
    spines: VecDeque<RxqEntry<B>>,
    /// Number of flushes delivered so far, per sender.
    n_flushes: Vec<usize>,
    /// Number of completed spines already handed out by `receive`.
    n_received: usize,
}
/// One spine in a receive queue, along with its completion state.
struct RxqEntry<B: Batch> {
    /// Number of senders that have not yet flushed into `spine`.
    n_unflushed: usize,
    /// Batches accumulated so far.
    spine: Spine<B>,
}
impl<B: Batch> RxqEntry<B> {
    /// Creates an entry with a fresh spine that still awaits a flush from
    /// each of the `npeers` senders.
    fn new(
        npeers: usize,
        runtime: &Runtime,
        worker_index: usize,
        factories: &B::Factories,
    ) -> Self {
        let spine = Spine::with_runtime(runtime.clone(), worker_index, factories);
        Self {
            n_unflushed: npeers,
            spine,
        }
    }
}
impl<B: Batch> Rxq<B> {
    /// Creates a queue holding a single empty spine, with zero flushes
    /// recorded for each of the `npeers` senders.
    fn new(
        runtime: &Runtime,
        worker_index: usize,
        factories: &B::Factories,
        npeers: usize,
    ) -> Self {
        let first = RxqEntry::new(npeers, runtime, worker_index, factories);
        Self {
            runtime: runtime.clone(),
            worker_index,
            npeers,
            spines: VecDeque::from([first]),
            n_flushes: repeat_n(0, npeers).collect(),
            n_received: 0,
        }
    }

    /// Adds `batch` from `sender` to the spine that sender is currently
    /// filling.
    ///
    /// When `flush` is set, the sender is recorded as done with that spine,
    /// and a new spine is appended if the sender would otherwise have none to
    /// fill next.
    fn deliver(&mut self, factories: &B::Factories, sender: usize, batch: Arc<B>, flush: bool) {
        // A sender's flush count, less the spines already handed out, is the
        // position of the spine it is currently filling.
        let slot = self.n_flushes[sender] - self.n_received;
        let current = &mut self.spines[slot];
        current.spine.insert_without_blocking(batch);
        if !flush {
            return;
        }

        current.n_unflushed -= 1;
        self.n_flushes[sender] += 1;
        let needs_next = slot + 1 >= self.spines.len();
        if needs_next {
            let next = RxqEntry::new(self.npeers, &self.runtime, self.worker_index, factories);
            self.spines.push_back(next);
        }
    }

    /// Removes and returns the oldest spine if every sender has flushed into
    /// it; otherwise returns `None` and leaves the queue unchanged.
    fn receive(&mut self) -> Option<Spine<B>> {
        let done = self
            .spines
            .pop_front_if(|entry| entry.n_unflushed == 0)?;
        self.n_received += 1;
        Some(done.spine)
    }
}
/// Sink operator that feeds the batches it receives into a
/// `ShardedAccumulator` exchange.
struct ShardedAccumulatorSender<B: Batch> {
    /// Source location reported for this operator.
    location: OperatorLocation,
    /// Exchange that the batches are sent through.
    exchange: Arc<ShardedAccumulator<B>>,
    /// Size statistics of the input batches, reported as operator metadata.
    input_batch_stats: BatchSizeStats,
    /// Set by `flush`; sent along with the next batch and then cleared.
    flushed: bool,
}
impl<B: Batch> ShardedAccumulatorSender<B> {
    /// Creates a sender for `exchange` with empty input statistics and no
    /// pending flush.
    fn new(location: OperatorLocation, exchange: Arc<ShardedAccumulator<B>>) -> Self {
        let input_batch_stats = BatchSizeStats::new();
        Self {
            location,
            exchange,
            input_batch_stats,
            flushed: false,
        }
    }
}
impl<B: Batch> Operator for ShardedAccumulatorSender<B> {
    /// Name of this operator.
    fn name(&self) -> Cow<'static, str> {
        Cow::Borrowed("ShardedAccumulatorSender")
    }

    /// Adds the input batch size statistics to `meta`.
    fn metadata(&self, meta: &mut OperatorMeta) {
        let stats = self.input_batch_stats.metadata();
        meta.extend(metadata! {
            INPUT_BATCHES_STATS => stats,
        });
    }

    /// Source location recorded when the operator was created.
    fn location(&self) -> OperatorLocation {
        self.location
    }

    fn clock_start(&mut self, _scope: Scope) {}

    fn clock_end(&mut self, _scope: Scope) {}

    /// Always reports a fixed point.
    fn fixedpoint(&self, _scope: Scope) -> bool {
        true
    }

    /// Remembers that the next batch sent must carry the flush flag.
    fn flush(&mut self) {
        self.flushed = true;
    }
}
impl<B: Batch<Time = ()>> SinkOperator<B> for ShardedAccumulatorSender<B> {
    /// Clones `batch` and forwards it to `eval_owned`.
    async fn eval(&mut self, batch: &B) {
        let owned = batch.clone();
        self.eval_owned(owned).await
    }

    /// Records the batch size, sends the batch through the exchange together
    /// with the pending flush flag, and then clears that flag.
    async fn eval_owned(&mut self, batch: B) {
        let entries = batch.num_entries_deep();
        self.input_batch_stats.add_batch(entries);

        let flush = self.flushed;
        self.exchange.send(batch, flush).await;
        // Clear the flag only once the send has completed.
        self.flushed = false;
    }

    /// This operator prefers to receive its input by value.
    fn input_preference(&self) -> OwnershipPreference {
        OwnershipPreference::PREFER_OWNED
    }
}
/// Source operator that outputs the completed spines a
/// `ShardedAccumulator` exchange has collected for the current worker.
struct ShardedAccumulatorReceiver<B: Batch> {
    /// Exchange that the spines are taken from.
    exchange: Arc<ShardedAccumulator<B>>,
    /// Source location reported for this operator.
    location: OperatorLocation,
    /// Cleared by `flush`; set once `eval` has produced a spine.
    flushed: bool,
}
impl<B: Batch> ShardedAccumulatorReceiver<B> {
    /// Creates a receiver for `exchange` whose flush-complete flag starts
    /// out cleared.
    fn new(location: OperatorLocation, exchange: Arc<ShardedAccumulator<B>>) -> Self {
        let flushed = false;
        Self {
            exchange,
            location,
            flushed,
        }
    }
}
impl<B> Operator for ShardedAccumulatorReceiver<B>
where
B: Batch,
{
fn name(&self) -> std::borrow::Cow<'static, str> {
Cow::Borrowed("ShardedAccumulatorReceiver")
}
fn location(&self) -> OperatorLocation {
self.location
}
fn fixedpoint(&self, _scope: crate::circuit::Scope) -> bool {
true
}
fn flush(&mut self) {
self.flushed = false;
}
fn is_flush_complete(&self) -> bool {
self.flushed
}
fn ready(&self) -> bool {
true
}
}
impl<B: Batch<Time = ()>> SourceOperator<Option<Spine<B>>> for ShardedAccumulatorReceiver<B> {
    /// Outputs the next completed spine for this worker, or `None` if no
    /// spine is complete yet.
    ///
    /// When a spine is available, this first waits on the spine's
    /// backpressure and then marks the current flush as complete.
    async fn eval(&mut self) -> Option<Spine<B>> {
        match self.exchange.receive() {
            Some(spine) => {
                spine.backpressure_wait().await;
                self.flushed = true;
                Some(spine)
            }
            None => None,
        }
    }
}
#[cfg(test)]
mod tests {
    use crossbeam::thread;
    use itertools::Itertools;

    use crate::{
        DBSPHandle, OutputHandle, RootCircuit, ZSetHandle, ZWeight,
        circuit::{CircuitConfig, Layout, Runtime},
        dynamic::{Data, DowncastTrait, DynWeightTyped},
        trace::{BatchReader, Cursor, FallbackWSet, Spine},
        typed_batch::TypedBatch,
    };
    use std::{collections::BTreeMap, iter::zip, net::TcpListener};

    /// Number of transactions run by each test circuit.
    const STREAMING_ROUNDS: usize = 64;

    /// Runs the test circuit on `workers` workers spread over `hosts` hosts
    /// and checks the accumulated output of every transaction.
    ///
    /// In round `r`, a transaction is started on every host, the keys
    /// `0..=r` are pushed one per step (key `i` through the input handle of
    /// host `i % hosts`), and the transaction is committed. The combined
    /// output of all hosts must then be exactly each key `0..=r` with
    /// weight 1.
    fn test_circuit(workers: usize, hosts: usize) {
        let (mut dbsp_handles, input_handles, output_handles) = match hosts {
            0 => unreachable!(),
            1 => {
                let (dbsp_handle, (input_handle, output_handle)) =
                    Runtime::init_circuit(workers, circuit).expect("failed to start runtime");
                (vec![dbsp_handle], vec![input_handle], vec![output_handle])
            }
            _ => {
                assert!(workers >= hosts);
                // Bind to port 0 so that the OS picks a free port per host.
                let exchange_listeners = (0..hosts)
                    .map(|_| {
                        TcpListener::bind("127.0.0.1:0")
                            .expect("should be able to bind a port on localhost")
                    })
                    .collect_vec();
                // One `(address, worker count)` pair per host. The workers
                // are split as evenly as possible, with the first
                // `workers % hosts` hosts taking one extra worker.
                let params = exchange_listeners
                    .iter()
                    .enumerate()
                    .map(|(index, listener)| {
                        (
                            listener
                                .local_addr()
                                .expect("should be able to get local address"),
                            workers / hosts + (index < workers % hosts) as usize,
                        )
                    })
                    .collect_vec();
                let mut handles = Vec::with_capacity(params.len());
                for ((local_address, _), exchange_listener) in
                    zip(params.iter(), exchange_listeners)
                {
                    let cconf = CircuitConfig::from(
                        Layout::new_multihost(&params, *local_address).unwrap(),
                    )
                    .with_exchange_listener(exchange_listener);
                    let (dbsp_handle, (input_handle, output_handle)) =
                        Runtime::init_circuit(cconf, circuit).expect("failed to start runtime");
                    handles.push((dbsp_handle, input_handle, output_handle));
                }
                handles.into_iter().multiunzip()
            }
        };

        /// Runs `f` on every host's handle, one scoped thread per host, and
        /// waits for all of them to finish.
        fn for_each_host<F>(dbsp_handles: &mut [DBSPHandle], f: F)
        where
            F: Fn(&mut DBSPHandle) + Send + Sync + 'static,
        {
            thread::scope(|s| {
                dbsp_handles
                    .iter_mut()
                    .map(|h| s.spawn(|_| f(h)))
                    // Spawn every thread before joining any of them.
                    .collect_vec()
                    .into_iter()
                    .for_each(|h| h.join().unwrap())
            })
            .unwrap();
        }

        for round in 0..STREAMING_ROUNDS {
            for_each_host(&mut dbsp_handles, |h| h.start_transaction().unwrap());
            for i in 0..=round {
                input_handles[i % hosts].push(i, 1);
                for_each_host(&mut dbsp_handles, |h| {
                    h.step().unwrap();
                });
            }
            for_each_host(&mut dbsp_handles, |h| h.commit_transaction().unwrap());

            // Sum the weights of every key over the output of all workers on
            // all hosts.
            let mut results = BTreeMap::<usize, ZWeight>::new();
            for spine in output_handles
                .iter()
                .flat_map(|handle| handle.take_from_all())
            {
                let mut cursor = spine.cursor();
                while let Some(key) = cursor.get_key() {
                    // SAFETY: the output batch of `circuit` is declared with
                    // `usize` keys.
                    let key = *unsafe { key.downcast() };
                    // SAFETY: the output batch of `circuit` is declared with
                    // `i64` weights; assumes `ZWeight` is `i64` — TODO
                    // confirm.
                    let weight = *unsafe { cursor.weight().downcast::<ZWeight>() };
                    *results.entry(key).or_default() += weight;
                    cursor.step_key();
                }
            }
            let results = results.into_iter().collect_vec();
            let expected = (0..=round).map(|i| (i, 1)).collect_vec();
            assert_eq!(&results, &expected);
        }
    }

    /// Builds the circuit under test: a Z-set input whose stream goes
    /// through `shard_accumulate`, with the latest output exposed through an
    /// output handle.
    fn circuit(
        circuit: &mut RootCircuit,
    ) -> anyhow::Result<(
        ZSetHandle<usize>,
        OutputHandle<
            TypedBatch<
                usize,
                (),
                i64,
                Spine<FallbackWSet<dyn Data + 'static, DynWeightTyped<i64>>>,
            >,
        >,
    )> {
        let (input, input_handle) = circuit.add_input_zset::<usize>();
        let output_handle = input.shard_accumulate().latest_output();
        Ok((input_handle, output_handle))
    }

    #[test]
    fn sharded_accumulator_single_host() {
        for workers in [2, 16, 32] {
            test_circuit(workers, 1);
        }
    }

    #[test]
    fn sharded_accumulator_multihost() {
        for (workers, hosts) in [(2, 2), (4, 2), (8, 2), (3, 3), (4, 4), (16, 4)] {
            test_circuit(workers, hosts);
        }
    }
}