dbsp/operator/communication/exchange.rs
1//! Exchange operators implement a N-to-N communication pattern where
2//! each participant sends exactly one value to and receives exactly one
3//! value from each peer at every clock cycle.
4
5// TODO: We may want to generalize these operators to implement N-to-M
6// communication, including 1-to-N and N-to-1.
7
8use crate::{
9 NumEntries, WeakRuntime,
10 circuit::{
11 Host, LocalStoreMarker, OwnershipPreference, Runtime, Scope,
12 metadata::{
13 BatchSizeStats, EXCHANGE_DESERIALIZATION_TIME_SECONDS, EXCHANGE_DESERIALIZED_BYTES,
14 EXCHANGE_SERIALIZATION_TIME_SECONDS, EXCHANGE_SERIALIZED_BYTES,
15 EXCHANGE_WAIT_TIME_SECONDS, INPUT_BATCHES_STATS, MetaItem, OUTPUT_BATCHES_STATS,
16 OperatorLocation, OperatorMeta,
17 },
18 operator_traits::{Operator, SinkOperator, SourceOperator},
19 tokio::TOKIO,
20 },
21 circuit_cache_key,
22};
23use crossbeam_utils::CachePadded;
24use futures::{future, prelude::*, stream::FuturesUnordered};
25use std::{
26 borrow::Cow,
27 collections::HashMap,
28 marker::PhantomData,
29 net::SocketAddr,
30 ops::Range,
31 sync::{
32 Arc, Mutex, OnceLock, RwLock,
33 atomic::{AtomicPtr, AtomicU64, AtomicUsize, Ordering},
34 },
35 time::{Duration, Instant, SystemTime},
36};
37use tarpc::{
38 client, context,
39 serde_transport::tcp::{connect, listen},
40 server::{self, Channel},
41 tokio_serde::formats::Bincode,
42 tokio_util::sync::{CancellationToken, DropGuard},
43};
44use tokio::{
45 sync::{Notify, OnceCell as TokioOnceCell},
46 time::sleep,
47};
48use typedmap::TypedMapKey;
49
/// Returns the wall-clock time as microseconds elapsed since the Unix epoch.
///
/// The result saturates at `u64::MAX` if the microsecond count does not fit
/// in a `u64`, and is `0` if the system clock reports a time before the
/// epoch.
fn current_time_usecs() -> u64 {
    match SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => u64::try_from(since_epoch.as_micros()).unwrap_or(u64::MAX),
        // The clock is set before the epoch: report zero rather than failing.
        Err(_) => 0,
    }
}
57
// We use the `Runtime::local_store` mechanism to connect multiple workers
// to an `Exchange` instance.  During circuit construction, each worker
// allocates a unique id that happens to be the same across all workers.
// The worker then allocates a new `Exchange` and adds it to the local store
// using the id as a key.  If there already is an `Exchange` with this id in
// the store, created by another worker, a reference to that `Exchange` will
// be used instead.
//
// The key declared here, `ExchangeCacheId<T>`, wraps an `ExchangeId` and maps
// to an `Arc<Exchange<T>>`.
circuit_cache_key!(local ExchangeCacheId<T>(ExchangeId => Arc<Exchange<T>>));
66
/// RPC interface, declared through `tarpc::service`, that a host uses to push
/// one round of exchange data to another host.
#[tarpc::service]
trait ExchangeService {
    /// Sends messages in `exchange_id` from all of the worker threads in
    /// `senders` to all of the worker thread receivers in the server that
    /// processes the message.  The Bincode-encoded message from `sender` to
    /// `receiver` is `data[sender - senders.start][receiver -
    /// receivers.start]`, where `receivers` is the `Range<usize>` of worker
    /// thread IDs in the server that processes the message.
    async fn exchange(exchange_id: usize, senders: Range<usize>, data: Vec<Vec<Vec<u8>>>);
}
77
/// Identifier of an exchange; used as the key in `ExchangeDirectory` and
/// carried in the `exchange` RPC.
type ExchangeId = usize;

/// Maps from an `exchange_id` to the `InnerExchange` that implements the
/// exchange.
type ExchangeDirectory = Arc<RwLock<HashMap<ExchangeId, Arc<InnerExchange>>>>;
82
83#[derive(Clone)]
84struct ExchangeServer(ExchangeDirectory);
85
86impl ExchangeService for ExchangeServer {
87 async fn exchange(
88 self,
89 _: context::Context,
90 exchange_id: ExchangeId,
91 senders: Range<usize>,
92 data: Vec<Vec<Vec<u8>>>,
93 ) {
94 let inner = self.0.read().unwrap().get(&exchange_id).unwrap().clone();
95 inner.received(senders, data).await;
96 }
97}
98
/// RPC connectivity to the other hosts: the lazily started local listener plus
/// one lazily connected client per remote host.
struct Clients {
    /// Weak handle to the runtime, upgraded when the listener is created.
    runtime: WeakRuntime,

    /// Listens for connections from other hosts.
    ///
    /// We create this lazily upon the first attempt to connect to other hosts.
    /// If we create it before we've completely initialized the circuit, then we
    /// might not have created all of the exchanges yet when some other host
    /// tries to send data to one.
    ///
    /// The cell holds `None` when the runtime is gone or the layout has no
    /// local address at initialization time.
    listener: OnceLock<Option<ExchangeListener>>,

    /// Maps from a range of worker IDs to the RPC client used to contact those
    /// workers.  Only worker IDs for remote workers appear in the map.
    ///
    /// There is one entry per host returned by `layout().other_hosts()`; the
    /// client in each entry is created on first use.
    clients: Vec<(Host, TokioOnceCell<ExchangeServiceClient>)>,
}
114
115impl Clients {
116 fn new(runtime: &Runtime) -> Clients {
117 Self {
118 runtime: runtime.downgrade(),
119 listener: Default::default(),
120 clients: runtime
121 .layout()
122 .other_hosts()
123 .map(|host| (host.clone(), TokioOnceCell::new()))
124 .collect(),
125 }
126 }
127
128 /// Returns a client for `worker`, which must be a remote worker ID, first
129 /// establishing a connection if there isn't one yet.
130 async fn connect(&self, worker: usize) -> &ExchangeServiceClient {
131 self.listener.get_or_init(|| {
132 if let Some(runtime) = self.runtime.upgrade()
133 && let Some(local_address) = runtime.layout().local_address()
134 {
135 let directory = runtime.local_store().get(&DirectoryId).unwrap().clone();
136 Some(ExchangeListener::new(local_address, directory))
137 } else {
138 None
139 }
140 });
141
142 let (host, cell) = self
143 .clients
144 .iter()
145 .find(|(host, _client)| host.workers.contains(&worker))
146 .unwrap();
147 cell.get_or_init(|| async {
148 let transport = loop {
149 let mut transport = connect(host.address, Bincode::default);
150 transport.config_mut().max_frame_length(usize::MAX);
151 match transport.await {
152 Ok(transport) => break transport,
153 Err(error) => println!(
154 "connection to {} failed ({error}), waiting to retry",
155 host.address
156 ),
157 }
158 sleep(std::time::Duration::from_millis(1000)).await;
159 };
160 ExchangeServiceClient::new(client::Config::default(), transport).spawn()
161 })
162 .await
163 }
164}
165
/// Heap-allocated slot holding an optional callback.
struct CallbackInner {
    /// The callback to invoke, or `None` if none has been registered yet.
    cb: Option<Box<dyn Fn() + Send + Sync>>,
}

impl CallbackInner {
    /// Creates a slot with no callback.
    fn empty() -> Self {
        Self { cb: None }
    }

    /// Creates a slot holding `cb`.
    fn new<F>(cb: F) -> Self
    where
        F: Fn() + Send + Sync + 'static,
    {
        let cb = Box::new(cb) as Box<dyn Fn() + Send + Sync>;
        Self { cb: Some(cb) }
    }
}

/// A callback that can be replaced through a shared reference.
///
/// The pointer always refers to a live `CallbackInner` obtained from
/// `Box::into_raw`; `Callback` owns that allocation and frees it either when
/// the callback is replaced or when the `Callback` itself is dropped.
///
/// NOTE(review): `call` dereferences the current pointer without any
/// synchronization against a concurrent `set_callback`, which frees the
/// previous slot.  Callers must not replace the callback while another thread
/// may be inside `call`.
struct Callback(AtomicPtr<CallbackInner>);

impl Callback {
    /// Creates a `Callback` with nothing registered; `call` is then a no-op.
    fn empty() -> Self {
        Self(AtomicPtr::new(Box::into_raw(Box::new(
            CallbackInner::empty(),
        ))))
    }

    /// Installs `cb`, releasing the previously installed slot.
    fn set_callback(&self, cb: impl Fn() + Send + Sync + 'static) {
        let old_callback = self.0.swap(
            Box::into_raw(Box::new(CallbackInner::new(cb))),
            Ordering::AcqRel,
        );

        // SAFETY: every pointer stored in `self.0` comes from `Box::into_raw`
        // and the swap above transferred ownership of the old one to us, so it
        // is reconstructed and freed exactly once.
        let old_callback = unsafe { Box::from_raw(old_callback) };
        drop(old_callback);
    }

    /// Invokes the registered callback, if any.
    fn call(&self) {
        // SAFETY: the pointer is never null and refers to a live slot as long
        // as no concurrent `set_callback` frees it (see the type-level note).
        if let Some(cb) = &unsafe { &*self.0.load(Ordering::Acquire) }.cb {
            cb()
        }
    }
}

impl Drop for Callback {
    /// Frees the slot that is still installed, so that neither the allocation
    /// nor the values captured by the closure are leaked.
    fn drop(&mut self) {
        let inner = *self.0.get_mut();
        if !inner.is_null() {
            // SAFETY: `inner` came from `Box::into_raw` and `&mut self`
            // guarantees that nobody else can observe or free it anymore.
            drop(unsafe { Box::from_raw(inner) });
        }
    }
}
209
/// Type-erased state of one exchange: readiness counters, callbacks, and the
/// plumbing used to move serialized data to and from remote hosts.
struct InnerExchange {
    /// Identifier of this exchange; sent along with every `exchange` RPC.
    exchange_id: ExchangeId,
    /// The number of communicating peers.
    npeers: usize,
    /// Range of worker IDs on the local host.
    local_workers: Range<usize>,
    /// Counts the number of messages yet to be received in the current round of
    /// communication per receiver.  The receiver must wait until it has all
    /// `npeers` messages before reading all of them from mailboxes in one
    /// pass.
    ///
    /// NOTE(review): as used in this file the counter starts at 0, is
    /// incremented as messages arrive, and is reset to 0 by
    /// `try_receive_all` — i.e. it counts messages already received.
    receiver_counters: Vec<AtomicUsize>,
    /// Callback invoked when all `npeers` messages are ready for a receiver.
    receiver_callbacks: Vec<Callback>,
    /// Counts the number of empty mailboxes ready to accept new data per
    /// sender.  The sender waits until it has `npeers` available mailboxes
    /// before writing all of them in one pass.
    sender_counters: Vec<CachePadded<AtomicUsize>>,
    /// Callback invoked when all `npeers` mailboxes are available.
    sender_callbacks: Vec<Callback>,
    /// The number of workers that have already sent their messages in the
    /// current round.
    sent: AtomicUsize,
    /// The RPC clients to contact remote hosts.
    clients: Arc<Clients>,
    /// This allows the `exchange` RPC to wait until the receiver has taken its
    /// data out of the mailbox.  There are `n_remote_workers * n_local_workers`
    /// elements.
    sender_notifies: Vec<Notify>,
    /// A callback that takes the raw data exchanged over RPC and deserializes
    /// and delivers it to the receiver's mailbox.  Its arguments are the
    /// serialized bytes, the sender ID and the receiver ID, in that order.
    deliver: Box<dyn Fn(Vec<u8>, usize, usize) + Send + Sync + 'static>,
    /// The amount of time spent in `deliver`.
    delivery_usecs: AtomicU64,
    /// The number of bytes passed to `deliver`.
    delivered_bytes: AtomicUsize,
}
246
247impl InnerExchange {
248 fn new(
249 exchange_id: ExchangeId,
250 deliver: impl Fn(Vec<u8>, usize, usize) + Send + Sync + 'static,
251 clients: Arc<Clients>,
252 ) -> InnerExchange {
253 let runtime = Runtime::runtime().unwrap();
254 let npeers = Runtime::num_workers();
255 let local_workers = runtime.layout().local_workers();
256 let n_local_workers = local_workers.len();
257 let n_remote_workers = npeers - n_local_workers;
258 Self {
259 exchange_id,
260 npeers,
261 local_workers,
262 clients,
263 receiver_counters: (0..npeers).map(|_| AtomicUsize::new(0)).collect(),
264 receiver_callbacks: (0..npeers).map(|_| Callback::empty()).collect(),
265 sender_notifies: (0..n_local_workers * n_remote_workers)
266 .map(|_| Notify::new())
267 .collect(),
268 sender_counters: (0..npeers)
269 .map(|_| CachePadded::new(AtomicUsize::new(npeers)))
270 .collect(),
271 sender_callbacks: (0..npeers).map(|_| Callback::empty()).collect(),
272 deliver: Box::new(deliver),
273 delivery_usecs: AtomicU64::new(0),
274 delivered_bytes: AtomicUsize::new(0),
275 sent: AtomicUsize::new(0),
276 }
277 }
278
    /// Returns the identifier of this exchange.
    #[allow(dead_code)]
    fn exchange_id(&self) -> ExchangeId {
        self.exchange_id
    }
283
284 /// Returns the `sender_notify` for a sender/receiver pair. `receiver`
285 /// must be a local worker ID, and `sender` must be a remote worker ID.
286 fn sender_notify(&self, sender: usize, receiver: usize) -> &Notify {
287 debug_assert!(sender < self.npeers && !self.local_workers.contains(&sender));
288 debug_assert!(self.local_workers.contains(&receiver));
289 let n_local_workers = self.local_workers.len();
290 let sender_ofs = if sender >= self.local_workers.start {
291 sender - n_local_workers
292 } else {
293 sender
294 };
295 let receiver_ofs = receiver - self.local_workers.start;
296 &self.sender_notifies[sender_ofs * n_local_workers + receiver_ofs]
297 }
298
    /// Receives messages sent from all of the worker threads in `senders` to
    /// all of the local worker threads `receivers` in `self`.  The
    /// Bincode-encoded `Vec<u8>` message from `sender` to `receiver` is
    /// `data[sender - senders.start][receiver - receivers.start]`.
    ///
    /// The future completes only after every local receiver has taken the
    /// delivered messages out of its mailboxes.
    ///
    /// # Panics
    ///
    /// Panics if a row of `data` does not contain exactly one entry per local
    /// worker.
    async fn received(self: &Arc<Self>, senders: Range<usize>, data: Vec<Vec<Vec<u8>>>) {
        let receivers = &self.local_workers;

        // Deliver all of the data into the exchange's mailboxes.
        let start = Instant::now();
        let mut delivered_bytes = 0;
        for (sender, data) in senders.clone().zip(data.into_iter()) {
            assert_eq!(data.len(), receivers.len());
            for (receiver, data) in receivers.clone().zip(data.into_iter()) {
                delivered_bytes += data.len();
                (self.deliver)(data, sender, receiver);
            }
        }
        // Account for the time and volume of this delivery in the metrics.
        self.delivery_usecs
            .fetch_add(start.elapsed().as_micros() as u64, Ordering::Relaxed);
        self.delivered_bytes
            .fetch_add(delivered_bytes, Ordering::Relaxed);

        // Increment the receiver counters and deliver callbacks if necessary.
        for receiver in receivers.clone() {
            let n = senders.len();
            let old_counter = self.receiver_counters[receiver].fetch_add(n, Ordering::AcqRel);
            // These `n` messages complete the round for `receiver`.
            if old_counter >= self.npeers - n {
                self.receiver_callbacks[receiver].call();
            }
        }

        // Wait for the receivers to pick up their mail before returning.
        for sender in senders {
            for receiver in receivers.clone() {
                self.sender_notify(sender, receiver).notified().await;
            }
        }
    }
337
    /// Returns an index for the sender/receiver pair.
    ///
    /// Mailboxes are laid out row-major with one row of `npeers` entries per
    /// sender.
    fn mailbox_index(&self, sender: usize, receiver: usize) -> usize {
        debug_assert!(sender < self.npeers);
        debug_assert!(receiver < self.npeers);
        sender * self.npeers + receiver
    }
344
    /// True if all `npeers` outgoing mailboxes of the local worker `sender`
    /// are currently free.
    fn ready_to_send(&self, sender: usize) -> bool {
        debug_assert!(self.local_workers.contains(&sender));
        self.sender_counters[sender].load(Ordering::Acquire) == self.npeers
    }
349
    /// True if messages from all `npeers` senders have arrived for
    /// `receiver`.
    fn ready_to_receive(&self, receiver: usize) -> bool {
        debug_assert!(receiver < self.npeers);
        self.receiver_counters[receiver].load(Ordering::Acquire) == self.npeers
    }
354
    /// Installs `cb` as the callback for `sender`, replacing any callback
    /// installed earlier.  It is invoked when all of the sender's mailboxes
    /// become available.
    fn register_sender_callback<F>(&self, sender: usize, cb: F)
    where
        F: Fn() + Send + Sync + 'static,
    {
        debug_assert!(sender < self.npeers);
        self.sender_callbacks[sender].set_callback(cb);
    }
362
    /// Installs `cb` as the callback for `receiver`, replacing any callback
    /// installed earlier.  It is invoked when all of the receiver's messages
    /// have arrived.
    fn register_receiver_callback<F>(&self, receiver: usize, cb: F)
    where
        F: Fn() + Send + Sync + 'static,
    {
        debug_assert!(receiver < self.npeers);

        self.receiver_callbacks[receiver].set_callback(cb);
    }
371}
372
/// `Exchange` is an N-to-N communication primitive that partitions data across
/// multiple concurrent threads.
///
/// An instance of `Exchange` can be shared by multiple threads that communicate
/// in rounds.  In each round each peer _first_ sends exactly one data value to
/// every other peer (and itself) and then receives one value from each peer.
/// The send operation can only proceed when all peers have retrieved data
/// produced at the previous round.  Likewise, the receive operation can proceed
/// once all incoming values are ready for the current round.
///
/// Each worker has one ExchangeServiceClient and ExchangeServer for every
/// worker (including itself), so N*N total.
///
/// In a round, each worker invokes exchange() once on each of its clients.
/// Each server handles N calls to exchange(), once for each other worker and
/// itself.
///
/// Each call to exchange populates a mailbox.  When all the mailboxes for a
/// worker have been populated, it can read and clear them.
///
/// NOTE(review): the two paragraphs about per-worker clients and servers look
/// older than the code in this file, where `Clients` keeps one RPC client per
/// remote host and a single `exchange` call carries the data of all local
/// senders — confirm and update.
pub(crate) struct Exchange<T> {
    /// Type-independent part of the exchange (counters, callbacks, RPC).
    inner: Arc<InnerExchange>,
    /// `npeers^2` mailboxes, one for each sender/receiver pair.  Each mailbox
    /// is accessed by exactly two threads, so contention is low.
    ///
    /// We only use the mailboxes where either the sender or the receiver is one
    /// of our local workers.  In the diagram below, L is mailboxes used for
    /// local exchange, S mailboxes used for sending RPC exchange, and R
    /// mailboxes used for receiving exchange via RPC:
    ///
    /// ```text
    ///                      <-------receivers------->
    ///                            local
    ///                           workers
    ///           ^          -------------------------
    ///           |          |     |RRRRR|     |     |
    ///                      |     |RRRRR|     |     |
    ///           s          |-----|-----|-----|-----|
    ///           e  local   |SSSSS|LLLLL|SSSSS|SSSSS|
    ///           n  workers |SSSSS|LLLLL|SSSSS|SSSSS|
    ///           d          |-----|-----|-----|-----|
    ///           e          |     |RRRRR|     |     |
    ///           r          |     |RRRRR|     |     |
    ///           s          |-----|-----|-----|-----|
    ///                      |     |RRRRR|     |     |
    ///           |          |     |RRRRR|     |     |
    ///           v          |-----|-----|-----|-----|
    /// ```
    mailboxes: Arc<Vec<Mutex<Option<T>>>>,
    /// Converts a value into the bytes sent to a remote host.
    serialize: Box<dyn Fn(T) -> Vec<u8> + Send + Sync>,

    /// The amount of time we've spent calling `serialize`.
    serialization_usecs: AtomicU64,

    /// The number of bytes produced by `serialize`.
    serialized_bytes: AtomicUsize,
}
429
430async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
431 tokio::spawn(fut);
432}
433
/// Handle for the background task that accepts `exchange` RPC connections
/// from other hosts.  The wrapped `DropGuard` belongs to the cancellation
/// token that the task waits on, so dropping the handle is what ends the
/// task (the guard cancels its token on drop).
// Stop Rust from complaining about unused field.
#[allow(dead_code)]
struct ExchangeListener(DropGuard);

impl ExchangeListener {
    /// Spawns a task on the `TOKIO` runtime that listens on `address` and
    /// serves incoming `exchange` calls against `directory`.
    ///
    /// Binding happens inside the spawned task, so a failure to bind panics
    /// there rather than in the caller.
    fn new(address: SocketAddr, directory: ExchangeDirectory) -> Self {
        let token = CancellationToken::new();
        let drop = token.clone().drop_guard();
        TOKIO.spawn(async move {
            println!("listening on {address}");
            let mut listener = listen(address, Bincode::default).await.unwrap();
            // Exchanged batches can be large: lift the frame size limit.
            listener.config_mut().max_frame_length(usize::MAX);
            let incoming = listener
                // Ignore connections that failed while being accepted.
                .filter_map(|r| future::ready(r.ok()))
                .map(server::BaseChannel::with_defaults)
                .map(move |channel| {
                    // Each request on a channel runs as its own task.
                    let server = ExchangeServer(directory.clone());
                    channel.execute(server.serve()).for_each(spawn)
                })
                // Drive at most 10 channels concurrently.
                .buffer_unordered(10)
                .for_each(|_| async {});
            // Serve until the listener ends or the guard is dropped.
            tokio::select! {
                _ = incoming => {}
                _ = token.cancelled() => {}
            }
        });
        Self(drop)
    }
}
463
464impl<T> Exchange<T>
465where
466 T: Clone + Send + 'static,
467{
468 /// Create a new exchange operator for `npeers` communicating threads.
469 fn new(
470 exchange_id: ExchangeId,
471 clients: Arc<Clients>,
472 directory: ExchangeDirectory,
473 serialize: Box<dyn Fn(T) -> Vec<u8> + Send + Sync>,
474 deserialize: Box<dyn Fn(Vec<u8>) -> T + Send + Sync>,
475 ) -> Self {
476 let npeers = Runtime::num_workers();
477 let mailboxes: Arc<Vec<Mutex<Option<T>>>> =
478 Arc::new((0..npeers * npeers).map(|_| Mutex::new(None)).collect());
479 let mailboxes2: Arc<Vec<Mutex<Option<T>>>> = mailboxes.clone();
480 let deliver = move |data: Vec<u8>, sender, receiver| {
481 let index: usize = sender * npeers + receiver;
482 let data = deserialize(data);
483 let mut mailbox = mailboxes2[index].lock().unwrap();
484 assert!((*mailbox).is_none());
485 *mailbox = Some(data);
486 };
487
488 let inner = Arc::new(InnerExchange::new(exchange_id, deliver, clients));
489 directory
490 .write()
491 .unwrap()
492 .entry(exchange_id)
493 .and_modify(|_| panic!())
494 .or_insert(inner.clone());
495 Self {
496 inner,
497 mailboxes,
498 serialize,
499 serialization_usecs: AtomicU64::new(0),
500 serialized_bytes: AtomicUsize::new(0),
501 }
502 }
503
    /// Returns the identifier of this exchange.
    #[allow(dead_code)]
    fn exchange_id(&self) -> ExchangeId {
        self.inner.exchange_id()
    }
508
    /// Returns a reference to a mailbox for the sender/receiver pair.
    ///
    /// Both IDs must be smaller than the number of peers.
    fn mailbox(&self, sender: usize, receiver: usize) -> &Mutex<Option<T>> {
        &self.mailboxes[self.inner.mailbox_index(sender, receiver)]
    }
513
    /// Create a new `Exchange` instance if an instance with the same id
    /// (created by another thread) does not yet exist within `runtime`.
    /// The number of peers will be set to `runtime.num_workers()`.
    ///
    /// The exchange directory and the RPC clients are shared through the
    /// runtime's local store and are created on first use.  When an exchange
    /// with this id already exists, `serialize` and `deserialize` are
    /// dropped unused.
    pub(crate) fn with_runtime(
        runtime: &Runtime,
        exchange_id: ExchangeId,
        serialize: Box<dyn Fn(T) -> Vec<u8> + Send + Sync>,
        deserialize: Box<dyn Fn(Vec<u8>) -> T + Send + Sync>,
    ) -> Arc<Self> {
        // Directory used by the RPC server to find exchanges by id.
        let directory = runtime
            .local_store()
            .entry(DirectoryId)
            .or_insert_with(|| Arc::new(RwLock::new(HashMap::new())))
            .clone();

        let clients = runtime
            .local_store()
            .entry(ClientsId)
            .or_insert_with(|| {
                // Create clients for remote exchange.
                Arc::new(Clients::new(runtime))
            })
            .clone();

        // Look up the exchange itself, creating it if this is the first
        // worker to ask for `exchange_id`.
        runtime
            .local_store()
            .entry(ExchangeCacheId::new(exchange_id))
            .or_insert_with(|| {
                Arc::new(Exchange::new(
                    exchange_id,
                    clients.clone(),
                    directory,
                    serialize,
                    deserialize,
                ))
            })
            .value()
            .clone()
    }
553
    /// True if all `sender`'s outgoing mailboxes are free and ready to accept
    /// data.
    ///
    /// Once this function returns true, a subsequent `try_send_all` operation
    /// is guaranteed to succeed for `sender`.
    ///
    /// `sender` must be a local worker ID.
    pub fn ready_to_send(&self, sender: usize) -> bool {
        self.inner.ready_to_send(sender)
    }
562
563 /// Write all outgoing messages for `sender` to mailboxes.
564 ///
565 /// Values to be sent are retrieved from the `data` iterator, with the
566 /// first value delivered to receiver 0, second value delivered to receiver
567 /// 1, and so on.
568 ///
569 /// # Errors
570 ///
571 /// Fails if at least one of the sender's outgoing mailboxes is not empty.
572 ///
573 /// # Panics
574 ///
575 /// Panics if `data` yields fewer than `self.npeers` items.
576 pub(crate) fn try_send_all<I>(self: &Arc<Self>, sender: usize, data: &mut I) -> bool
577 where
578 I: Iterator<Item = T> + Send,
579 {
580 let npeers = self.inner.npeers;
581 if self.inner.sender_counters[sender]
582 .compare_exchange(npeers, 0, Ordering::AcqRel, Ordering::Acquire)
583 .is_err()
584 {
585 return false;
586 }
587
588 // Deliver all of the data to local mailboxes.
589 let local_workers = &self.inner.local_workers;
590 for receiver in 0..npeers {
591 *self.mailbox(sender, receiver).lock().unwrap() = data.next();
592
593 if local_workers.contains(&receiver) {
594 let old_counter =
595 self.inner.receiver_counters[receiver].fetch_add(1, Ordering::AcqRel);
596 if old_counter >= npeers - 1 {
597 self.inner.receiver_callbacks[receiver].call();
598 }
599 }
600 }
601
602 // In a single-host layout, or if some of our local workers haven't yet
603 // sent in this round, we're all done for now.
604 if npeers == local_workers.len()
605 || self.inner.sent.fetch_add(1, Ordering::AcqRel) + 1 != local_workers.len()
606 {
607 return true;
608 }
609 self.inner.sent.store(0, Ordering::Release);
610
611 // All of the local workers have sent their data in this round. Take
612 // all of their data and send it to the remote hosts.
613 let this = self.clone();
614 let runtime = Runtime::runtime().unwrap();
615 TOKIO.spawn(async move {
616 let mut futures = FuturesUnordered::new();
617
618 // For each range of worker IDs `receivers` on a remote host,
619 // accumulate all of the data from our local `senders` to all
620 // of the `receivers` on that host.
621 let senders = &this.inner.local_workers;
622 let start = Instant::now();
623 for host in runtime.layout().other_hosts() {
624 let receivers = &host.workers;
625 let mut serialized_bytes = 0;
626 let items: Vec<Vec<_>> = senders
627 .clone()
628 .map(|sender| {
629 receivers
630 .clone()
631 .map(|receiver| {
632 let item = this
633 .mailbox(sender, receiver)
634 .lock()
635 .unwrap()
636 .take()
637 .unwrap();
638 let serialized = (this.serialize)(item);
639 serialized_bytes += serialized.len();
640 serialized
641 })
642 .collect()
643 })
644 .collect();
645 this.serialization_usecs
646 .fetch_add(start.elapsed().as_micros() as u64, Ordering::Relaxed);
647 this.serialized_bytes
648 .fetch_add(serialized_bytes, Ordering::Relaxed);
649
650 let client = this.inner.clients.connect(receivers.start).await;
651
652 // Send it.
653 let mut context = context::current();
654 context.deadline = Instant::now() + Duration::from_hours(1);
655 futures.push(client.exchange(
656 context,
657 this.inner.exchange_id,
658 senders.clone(),
659 items,
660 ));
661 }
662
663 // Wait for all the sends to complete.
664 while let Some(result) = futures.next().await {
665 result.unwrap();
666 }
667
668 // Record that the sends completed.
669 let n = npeers - senders.len();
670 for sender in senders.clone() {
671 let old_counter = this.inner.sender_counters[sender].fetch_add(n, Ordering::AcqRel);
672 if old_counter >= npeers - n {
673 this.inner.sender_callbacks[sender].call();
674 }
675 }
676 });
677
678 true
679 }
680
    /// True if all `receiver`'s incoming mailboxes contain data.
    ///
    /// Once this function returns true, a subsequent `try_receive_all`
    /// operation is guaranteed to succeed for `receiver`.
    pub(crate) fn ready_to_receive(&self, receiver: usize) -> bool {
        self.inner.ready_to_receive(receiver)
    }
688
    /// Read all incoming messages for `receiver`.
    ///
    /// Values are passed to callback function `cb` in the order of worker indexes.
    ///
    /// Taking a message frees its mailbox: a local sender gets its counter
    /// incremented (and its callback invoked once all of its mailboxes are
    /// free), while for a remote sender the pending `exchange` RPC is
    /// notified.
    ///
    /// # Errors
    ///
    /// Returns `false`, without calling `cb`, if at least one of the
    /// receiver's incoming mailboxes is empty.
    ///
    /// # Panics
    ///
    /// Panics if a mailbox turns out to be empty even though the receiver's
    /// counter said that all messages had arrived.
    pub(crate) fn try_receive_all<F>(&self, receiver: usize, mut cb: F) -> bool
    where
        F: FnMut(T),
    {
        let npeers = self.inner.npeers;
        // Claim the whole round: succeeds only when all messages are there.
        if self.inner.receiver_counters[receiver]
            .compare_exchange(npeers, 0, Ordering::AcqRel, Ordering::Acquire)
            .is_err()
        {
            return false;
        }

        for sender in 0..self.inner.npeers {
            let data = self
                .mailbox(sender, receiver)
                .lock()
                .unwrap()
                .take()
                .unwrap();
            cb(data);
            if self.inner.local_workers.contains(&sender) {
                // Local sender: hand the mailbox back directly.
                let old_counter = self.inner.sender_counters[sender].fetch_add(1, Ordering::AcqRel);
                if old_counter >= self.inner.npeers - 1 {
                    self.inner.sender_callbacks[sender].call();
                }
            } else {
                // Remote sender: let the waiting `exchange` RPC make progress.
                self.inner.sender_notify(sender, receiver).notify_one();
            }
        }
        true
    }
727
    /// Register callback to be invoked whenever the `ready_to_send` condition
    /// becomes true.
    ///
    /// The callback can be set up at most once (e.g., when a scheduler attaches
    /// to the circuit) and cannot be unregistered.  Notifications delivered
    /// before the callback is registered are lost.  The client should call
    /// `ready_to_send` after installing the callback to check the status.
    ///
    /// After the callback has been registered, notifications are delivered with
    /// at-least-once semantics: a notification is generated whenever the
    /// status changes from not ready to ready, but spurious notifications
    /// can occur occasionally.  Therefore, the user must check the status
    /// explicitly by calling `ready_to_send` or be prepared that `try_send_all`
    /// can fail.
    pub(crate) fn register_sender_callback<F>(&self, sender: usize, cb: F)
    where
        F: Fn() + Send + Sync + 'static,
    {
        self.inner.register_sender_callback(sender, cb)
    }
748
    /// Register callback to be invoked whenever the `ready_to_receive`
    /// condition becomes true.
    ///
    /// The callback can be set up at most once (e.g., when a scheduler attaches
    /// to the circuit) and cannot be unregistered.  Notifications delivered
    /// before the callback is registered are lost.  The client should call
    /// `ready_to_receive` after installing the callback to check
    /// the status.
    ///
    /// After the callback has been registered, notifications are delivered with
    /// at-least-once semantics: a notification is generated whenever the
    /// status changes from not ready to ready, but spurious notifications
    /// can occur occasionally.  The user must check the status explicitly
    /// by calling `ready_to_receive` or be prepared that `try_receive_all`
    /// can fail.
    pub(crate) fn register_receiver_callback<F>(&self, receiver: usize, cb: F)
    where
        F: Fn() + Send + Sync + 'static,
    {
        self.inner.register_receiver_callback(receiver, cb)
    }
770}
771
772/// Operator that partitions incoming data across all workers.
773///
774/// This operator works in tandem with [`ExchangeReceiver`], which reassembles
775/// the data on the receiving side. Together they implement an all-to-all
776/// communication mechanism, where at every clock cycle each worker partitions
777/// its incoming data into `N` values, one for each worker, using a
778/// user-provided closure. It then reads values sent to it by all peers and
779/// reassembles them into a single value using another user-provided closure.
780///
781/// The exchange mechanism is split into two operators, so that after sending
782/// the data the circuit does not need to block waiting for its peers to finish
783/// sending and can instead schedule other operators.
784///
785/// ```text
///                    ExchangeSender ExchangeReceiver
///                       ┌───────┐      ┌───────┐
///                       │       │      │       │
///        ┌───────┐      │       │      │       │          ┌───────┐
///        │source ├─────►│       │      │       ├─────────►│ sink  │
///        └───────┘      │       │      │       │          └───────┘
///                       │       ├───┬─►│       │
///                       │       │   │  │       │
///                       └───────┘   │  └───────┘
/// WORKER 1                          │
/// ──────────────────────────────────┼──────────────────────────────
/// WORKER 2                          │
///                                   │
///                       ┌───────┐   │  ┌───────┐
///                       │       ├───┴─►│       │
///        ┌───────┐      │       │      │       │          ┌───────┐
///        │source ├─────►│       │      │       ├─────────►│ sink  │
///        └───────┘      │       │      │       │          └───────┘
///                       │       │      │       │
///                       │       │      │       │
///                       └───────┘      └───────┘
///                    ExchangeSender ExchangeReceiver
808/// ```
809///
/// `ExchangeSender` is an asynchronous operator, i.e.,
811/// [`ExchangeSender::is_async`] returns `true`. It becomes schedulable
812/// ([`ExchangeSender::ready`] returns `true`) once all peers have retrieved
813/// values written by the operator in the previous clock cycle. The scheduler
814/// should use [`ExchangeSender::register_ready_callback`] to get notified when
815/// the operator becomes schedulable.
816///
817/// `ExchangeSender` doesn't have a public constructor and must be instantiated
818/// using the [`new_exchange_operators`] function, which creates an
819/// [`ExchangeSender`]/[`ExchangeReceiver`] pair of operators and connects them
820/// to their counterparts in other workers as in the diagram above.
821///
822/// An [`ExchangeSender`]/[`ExchangeReceiver`] pair is added to a circuit using
823/// the [`Circuit::add_exchange`](`crate::circuit::Circuit::add_exchange`)
824/// method, which registers a dependency between them, making sure that
825/// `ExchangeSender` is evaluated before `ExchangeReceiver`.
826///
827/// # Examples
828///
829/// The following example instantiates the circuit in the diagram above.
830///
831/// ```
832/// # #[cfg(miri)]
833/// # fn main() {}
834///
835/// # #[cfg(not(miri))]
836/// # fn main() {
837/// use dbsp::{
838/// operator::{communication::new_exchange_operators, Generator},
839/// Circuit, RootCircuit, Runtime,
840/// storage::file::to_bytes,
841/// trace::unaligned_deserialize,
842/// };
843///
844/// const WORKERS: usize = 16;
845/// const ROUNDS: usize = 10;
846///
847/// let hruntime = Runtime::run(WORKERS, |_parker| {
848/// let circuit = RootCircuit::build(|circuit| {
849/// // Create a data source that generates numbers 0, 1, 2, ...
850/// let mut n: usize = 0;
851/// let source = circuit.add_source(Generator::new(move || {
852/// let result = n;
853/// n += 1;
854/// result
855/// }));
856///
///         // Create an `ExchangeSender`/`ExchangeReceiver` pair.
858/// let (sender, receiver) = new_exchange_operators(
859/// None,
860/// || Vec::new(),
861/// // Partitioning function sends a copy of the input `n` to each peer.
862/// |n, output| {
863/// for _ in 0..WORKERS {
864/// output.push(n)
865/// }
866/// },
867/// |value| to_bytes(&value).unwrap().into_vec(),
///             |data| unaligned_deserialize(&data[..]),
///             // Reassemble received values into a vector.
869/// |v: &mut Vec<usize>, n| v.push(n),
870/// ).unwrap();
871///
872/// // Add exchange operators to the circuit.
873/// let combined = circuit.add_exchange(sender, receiver, &source);
874/// let mut round = 0;
875///
///         // Expected output stream of `ExchangeReceiver`:
877/// // [0,0,0,...]
878/// // [1,1,1,...]
879/// // [2,2,2,...]
880/// // ...
881/// combined.inspect(move |v| {
882/// assert_eq!(&vec![round; WORKERS], v);
883/// round += 1;
884/// });
885/// Ok(())
886/// })
887/// .unwrap()
888/// .0;
889///
890/// for _ in 1..ROUNDS {
891/// circuit.step();
892/// }
893/// }).expect("failed to start runtime");
894///
895/// hruntime.join().unwrap();
896/// # }
897/// ```
pub struct ExchangeSender<D, T, L>
where
    T: Send + 'static + Clone,
{
    /// Index of the worker that owns this operator; used as the sender ID in
    /// the exchange.
    worker_index: usize,
    /// Value reported by `Operator::location`.
    location: OperatorLocation,
    /// Closure that splits an input value into one output per peer.
    partition: L,
    /// Scratch buffer that `partition` fills on every evaluation.
    outputs: Vec<T>,
    /// The shared exchange.  Each message is sent together with the value of
    /// `flushed` at the time of sending.
    exchange: Arc<Exchange<(T, bool)>>,

    // Input batch sizes.
    input_batch_stats: BatchSizeStats,

    /// Set by `flush`; attached to the next batch of messages and then
    /// cleared.
    flushed: bool,

    // The instant when the sender produced its outputs, and the
    // receiver starts waiting for all other workers to produce their
    // outputs.
    start_wait_usecs: Arc<AtomicU64>,

    /// Ties the input type `D` to the operator without storing a value.
    phantom: PhantomData<D>,
}
920
921impl<D, T, L> ExchangeSender<D, T, L>
922where
923 T: Send + 'static + Clone,
924{
925 fn new(
926 worker_index: usize,
927 location: OperatorLocation,
928 exchange: Arc<Exchange<(T, bool)>>,
929 start_wait_usecs: Arc<AtomicU64>,
930 partition: L,
931 ) -> Self {
932 debug_assert!(worker_index < Runtime::num_workers());
933 Self {
934 worker_index,
935 location,
936 partition,
937 outputs: Vec::with_capacity(Runtime::num_workers()),
938 exchange,
939 input_batch_stats: BatchSizeStats::new(),
940 flushed: false,
941 start_wait_usecs,
942 phantom: PhantomData,
943 }
944 }
945}
946
947impl<D, T, L> Operator for ExchangeSender<D, T, L>
948where
949 D: 'static,
950 T: Send + 'static + Clone,
951 L: 'static,
952{
953 fn name(&self) -> Cow<'static, str> {
954 Cow::from("ExchangeSender")
955 }
956
957 fn metadata(&self, meta: &mut OperatorMeta) {
958 meta.extend(metadata! {
959 INPUT_BATCHES_STATS => self.input_batch_stats.metadata(),
960 EXCHANGE_SERIALIZATION_TIME_SECONDS => MetaItem::Duration(Duration::from_micros(self.exchange.serialization_usecs.load(Ordering::Acquire))),
961 EXCHANGE_SERIALIZED_BYTES => MetaItem::bytes(self.exchange.serialized_bytes.load(Ordering::Acquire))
962 });
963 }
964
965 fn location(&self) -> OperatorLocation {
966 self.location
967 }
968
969 fn clock_start(&mut self, _scope: Scope) {}
970 fn clock_end(&mut self, _scope: Scope) {}
971
972 fn is_async(&self) -> bool {
973 true
974 }
975
976 fn register_ready_callback<F>(&mut self, cb: F)
977 where
978 F: Fn() + Send + Sync + 'static,
979 {
980 self.exchange
981 .register_sender_callback(self.worker_index, cb)
982 }
983
984 fn ready(&self) -> bool {
985 self.exchange.ready_to_send(self.worker_index)
986 }
987
988 fn fixedpoint(&self, _scope: Scope) -> bool {
989 true
990 }
991
992 fn flush(&mut self) {
993 self.flushed = true;
994 }
995}
996
997impl<D, T, L> SinkOperator<D> for ExchangeSender<D, T, L>
998where
999 D: Clone + NumEntries + 'static,
1000 T: Clone + Send + 'static,
1001 L: FnMut(D, &mut Vec<T>) + 'static,
1002{
1003 async fn eval(&mut self, input: &D) {
1004 self.eval_owned(input.clone()).await
1005 }
1006
1007 async fn eval_owned(&mut self, input: D) {
1008 self.input_batch_stats.add_batch(input.num_entries_deep());
1009
1010 debug_assert!(self.ready());
1011 self.outputs.clear();
1012 (self.partition)(input, &mut self.outputs);
1013 self.start_wait_usecs
1014 .store(current_time_usecs(), Ordering::Release);
1015
1016 let res = self.exchange.try_send_all(
1017 self.worker_index,
1018 &mut self.outputs.drain(..).map(|x| (x, self.flushed)),
1019 );
1020 self.flushed = false;
1021 debug_assert!(res);
1022 }
1023
1024 fn input_preference(&self) -> OwnershipPreference {
1025 OwnershipPreference::PREFER_OWNED
1026 }
1027}
1028
1029/// Operator that receives values sent by the `ExchangeSender` operator and
1030/// assembles them into a single output value.
1031///
1032/// The `init` closure returns the initial value for the result. This value
1033/// is updated by the `combine` closure with each value received from a remote
1034/// peer.
1035///
1036/// See [`ExchangeSender`] documentation for details.
1037///
1038/// `ExchangeReceiver` is an asynchronous operator., i.e.,
1039/// [`ExchangeReceiver::is_async`] returns `true`. It becomes schedulable
1040/// ([`ExchangeReceiver::ready`] returns `true`) once all peers have sent values
1041/// for this worker in the current clock cycle. The scheduler should use
1042/// [`ExchangeReceiver::register_ready_callback`] to get notified when the
1043/// operator becomes schedulable.
1044pub struct ExchangeReceiver<IF, T, L>
1045where
1046 T: Send + 'static + Clone,
1047{
1048 worker_index: usize,
1049 location: OperatorLocation,
1050 init: IF,
1051 combine: L,
1052 exchange: Arc<Exchange<(T, bool)>>,
1053 flush_count: usize,
1054 flush_complete: bool,
1055 start_wait_usecs: Arc<AtomicU64>,
1056 total_wait_time: Arc<AtomicU64>,
1057
1058 // Output batch sizes.
1059 output_batch_stats: BatchSizeStats,
1060}
1061
1062impl<IF, T, L> ExchangeReceiver<IF, T, L>
1063where
1064 T: Send + 'static + Clone,
1065{
1066 pub(crate) fn new(
1067 worker_index: usize,
1068 location: OperatorLocation,
1069 exchange: Arc<Exchange<(T, bool)>>,
1070 init: IF,
1071 start_wait_usecs: Arc<AtomicU64>,
1072 combine: L,
1073 ) -> Self {
1074 debug_assert!(worker_index < Runtime::num_workers());
1075
1076 Self {
1077 worker_index,
1078 location,
1079 init,
1080 combine,
1081 exchange,
1082 flush_count: 0,
1083 flush_complete: false,
1084 output_batch_stats: BatchSizeStats::new(),
1085 start_wait_usecs,
1086 total_wait_time: Arc::new(AtomicU64::new(0)),
1087 }
1088 }
1089}
1090
1091impl<D, T, L> Operator for ExchangeReceiver<D, T, L>
1092where
1093 D: 'static,
1094 T: Send + 'static + Clone,
1095 L: 'static,
1096{
1097 fn name(&self) -> Cow<'static, str> {
1098 Cow::from("ExchangeReceiver")
1099 }
1100
1101 fn location(&self) -> OperatorLocation {
1102 self.location
1103 }
1104
1105 fn metadata(&self, meta: &mut OperatorMeta) {
1106 meta.extend(metadata! {
1107 OUTPUT_BATCHES_STATS => self.output_batch_stats.metadata(),
1108 EXCHANGE_WAIT_TIME_SECONDS => MetaItem::Duration(Duration::from_micros(self.total_wait_time.load(Ordering::Acquire))),
1109 EXCHANGE_DESERIALIZATION_TIME_SECONDS => MetaItem::Duration(Duration::from_micros(self.exchange.inner.delivery_usecs.load(Ordering::Acquire))),
1110 EXCHANGE_DESERIALIZED_BYTES => MetaItem::bytes(self.exchange.inner.delivered_bytes.load(Ordering::Acquire)),
1111 });
1112 }
1113
1114 fn is_async(&self) -> bool {
1115 true
1116 }
1117
1118 fn register_ready_callback<F>(&mut self, cb: F)
1119 where
1120 F: Fn() + Send + Sync + 'static,
1121 {
1122 let start_wait_usecs = self.start_wait_usecs.clone();
1123 let total_wait_time = self.total_wait_time.clone();
1124 let exchange = self.exchange.clone();
1125 let worker_index = self.worker_index;
1126
1127 let cb = move || {
1128 if exchange.ready_to_receive(worker_index) {
1129 // The callback can be invoked multiple times per step.
1130 // Reset start_wait_usecs to 0 to make sure we don't double-count.
1131 let start = start_wait_usecs.swap(0, Ordering::Acquire);
1132 if start != 0 {
1133 let end = current_time_usecs();
1134 if end > start {
1135 let wait_time_usecs = end - start;
1136 // if worker_index == 0 {
1137 // info!(
1138 // "{worker_index}: {} +{wait_time_usecs}",
1139 // exchange.exchange_id()
1140 // );
1141 // }
1142 total_wait_time.fetch_add(wait_time_usecs, Ordering::AcqRel);
1143 }
1144 }
1145 }
1146 cb()
1147 };
1148 self.exchange
1149 .register_receiver_callback(self.worker_index, cb)
1150 }
1151
1152 fn ready(&self) -> bool {
1153 self.exchange.ready_to_receive(self.worker_index)
1154 }
1155
1156 fn fixedpoint(&self, _scope: Scope) -> bool {
1157 true
1158 }
1159
1160 fn flush(&mut self) {
1161 // println!("{} exchange_receiver::flush", Runtime::worker_index());
1162 self.flush_complete = false;
1163 }
1164
1165 fn is_flush_complete(&self) -> bool {
1166 // println!(
1167 // "{} exchange_receiver::is_flush_complete (flush_complete = {})",
1168 // Runtime::worker_index(),
1169 // self.flush_complete
1170 // );
1171 self.flush_complete
1172 }
1173}
1174
1175impl<D, IF, T, L> SourceOperator<D> for ExchangeReceiver<IF, T, L>
1176where
1177 D: NumEntries + 'static,
1178 T: Clone + Send + 'static,
1179 IF: Fn() -> D + 'static,
1180 L: Fn(&mut D, T) + 'static,
1181{
1182 async fn eval(&mut self) -> D {
1183 debug_assert!(self.ready());
1184 let mut combined = (self.init)();
1185 let res = self
1186 .exchange
1187 .try_receive_all(self.worker_index, |(x, flushed)| {
1188 // println!(
1189 // "{} exchange_receiver::eval received input with flushed={:?}",
1190 // Runtime::worker_index(),
1191 // flushed
1192 // );
1193 if flushed {
1194 self.flush_count += 1;
1195 }
1196 (self.combine)(&mut combined, x)
1197 });
1198 if self.flush_count == Runtime::num_workers() {
1199 // println!(
1200 // "{} exchange_receiver::eval received all inputs",
1201 // Runtime::worker_index()
1202 // );
1203
1204 self.flush_complete = true;
1205 self.flush_count = 0;
1206 }
1207
1208 debug_assert!(res);
1209
1210 self.output_batch_stats
1211 .add_batch(combined.num_entries_deep());
1212 combined
1213 }
1214}
1215
1216#[derive(Hash, PartialEq, Eq)]
1217struct ClientsId;
1218
1219impl TypedMapKey<LocalStoreMarker> for ClientsId {
1220 type Value = Arc<Clients>;
1221}
1222
1223#[derive(Hash, PartialEq, Eq)]
1224struct DirectoryId;
1225
1226impl TypedMapKey<LocalStoreMarker> for DirectoryId {
1227 type Value = ExchangeDirectory;
1228}
1229
1230/// Create an [`ExchangeSender`]/[`ExchangeReceiver`] operator pair.
1231///
1232/// See [`ExchangeSender`] documentation for details and example usage.
1233///
1234/// # Arguments
1235///
1236/// * `runtime` - [`Runtime`](`crate::circuit::Runtime`) within which operators
1237/// are created.
1238/// * `worker_index` - index of the current worker.
1239/// * `partition` - partitioning logic that must push exactly
1240/// `runtime.num_workers()` values into its vector argument
1241/// * `serialize` - serializes exchanged data for transmission across a network
1242/// * `deserialize` - deserializes exchanged data that was transmitted across a network
1243/// * `combine` - re-assemble logic that combines values received from all peers
1244/// into a single output value.
1245///
1246/// # Type arguments
1247/// * `TI` - Type of values in the input stream consumed by `ExchangeSender`.
1248/// * `TO` - Type of values in the output stream produced by `ExchangeReceiver`.
1249/// * `TE` - Type of values sent across workers.
1250/// * `PL` - Type of closure that splits a value of type `TI` into
1251/// `runtime.num_workers()` values of type `TE`.
1252/// * `I` - Iterator returned by `PL`.
1253/// * `IF` - Type of closure used to initialize the output value of type `TO`.
1254/// * `CL` - Type of closure that folds `num_workers` values of type `TE` into a
1255/// value of type `TO`.
1256pub fn new_exchange_operators<TI, TO, TE, IF, PL, CL, S, D>(
1257 location: OperatorLocation,
1258 init: IF,
1259 partition: PL,
1260 serialize: S,
1261 deserialize: D,
1262 combine: CL,
1263) -> Option<(ExchangeSender<TI, TE, PL>, ExchangeReceiver<IF, TE, CL>)>
1264where
1265 TO: Clone,
1266 TE: Send + 'static + Clone,
1267 IF: Fn() -> TO + 'static,
1268 PL: FnMut(TI, &mut Vec<TE>) + 'static,
1269 S: Fn(TE) -> Vec<u8> + Send + Sync + 'static,
1270 D: Fn(Vec<u8>) -> TE + Send + Sync + 'static,
1271 CL: Fn(&mut TO, TE) + 'static,
1272{
1273 if Runtime::num_workers() == 1 {
1274 return None;
1275 }
1276 let runtime = Runtime::runtime().unwrap();
1277 let worker_index = Runtime::worker_index();
1278
1279 let exchange_id = runtime.sequence_next();
1280 let start_wait_usecs = Arc::new(AtomicU64::new(0));
1281 let exchange = Exchange::with_runtime(
1282 &runtime,
1283 exchange_id,
1284 Box::new(move |(value, flush)| {
1285 let mut vec = serialize(value);
1286 vec.push(flush as u8);
1287 vec
1288 }),
1289 Box::new(move |mut vec| {
1290 let flush = match vec.pop().unwrap() {
1291 0 => false,
1292 1 => true,
1293 _ => unreachable!(),
1294 };
1295 (deserialize(vec), flush)
1296 }),
1297 );
1298 let sender = ExchangeSender::new(
1299 worker_index,
1300 location,
1301 exchange.clone(),
1302 start_wait_usecs.clone(),
1303 partition,
1304 );
1305 let receiver = ExchangeReceiver::new(
1306 worker_index,
1307 location,
1308 exchange,
1309 init,
1310 start_wait_usecs,
1311 combine,
1312 );
1313 Some((sender, receiver))
1314}
1315
#[cfg(test)]
mod tests {
    use super::Exchange;
    use crate::{
        Circuit, RootCircuit,
        circuit::{
            Runtime,
            schedule::{DynamicScheduler, Scheduler},
        },
        operator::{Generator, communication::new_exchange_operators},
        storage::file::{to_bytes, to_bytes_dyn},
        trace::unaligned_deserialize,
    };
    use std::thread::yield_now;

    // Use far fewer rounds under miri; with the full count the tests would
    // take forever there.
    const ROUNDS: usize = if cfg!(miri) { 128 } else { 2048 };

    // Drive a bare exchange object with `WORKERS` concurrent
    // senders/receivers for `ROUNDS` rounds. In round `N` every sender sends
    // the value `N` to every receiver. Sending and receiving may each need
    // several attempts, but every receiver must end up with all values of the
    // round, in order.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_exchange() {
        const WORKERS: usize = 16;

        let hruntime = Runtime::run(WORKERS, |_parker| {
            let exchange = Exchange::with_runtime(
                &Runtime::runtime().unwrap(),
                0,
                Box::new(|value| to_bytes(&value).unwrap().into_vec()),
                Box::new(|data| unaligned_deserialize(&data[..])),
            );
            let worker = Runtime::worker_index();

            for round in 0..ROUNDS {
                let expected = vec![round; WORKERS];

                // Retry until the exchange accepts this round's values.
                let mut outgoing = expected.clone().into_iter();
                while !exchange.try_send_all(worker, &mut outgoing) {
                    yield_now();
                }

                // Retry until the values from all peers have arrived.
                let mut received = Vec::with_capacity(WORKERS);
                while !exchange.try_receive_all(worker, |x| received.push(x)) {
                    yield_now();
                }

                assert_eq!(received, expected);
            }
        })
        .expect("failed to start runtime");

        hruntime.join().unwrap();
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_exchange_operators_dynamic() {
        test_exchange_operators::<DynamicScheduler>();
    }

    // Build, for several worker counts, a circuit of the shape
    // `Generator - ExchangeSender -> ExchangeReceiver -> Inspect`:
    // `Generator` - yields sequential numbers 0, 1, 2, ...
    // `ExchangeSender` - sends each number to all peers.
    // `ExchangeReceiver` - combines all received numbers in a vector.
    // `Inspect` - validates the output of the receiver.
    fn test_exchange_operators<S>()
    where
        S: Scheduler + 'static,
    {
        fn do_test<S>(workers: usize)
        where
            S: Scheduler + 'static,
        {
            let hruntime = Runtime::run(workers, move |_parker| {
                let circuit = RootCircuit::build_with_scheduler::<_, _, S>(move |circuit| {
                    // Source of the sequence 0, 1, 2, ...
                    let mut next: usize = 0;
                    let source = circuit.add_source(Generator::new(move || {
                        let current = next;
                        next += 1;
                        current
                    }));

                    let (sender, receiver) = new_exchange_operators(
                        None,
                        Vec::new,
                        // Send a copy of the number to every worker.
                        move |n, vals| {
                            for _ in 0..workers {
                                vals.push(n)
                            }
                        },
                        |value| to_bytes_dyn(&value).unwrap().into_vec(),
                        |data| unaligned_deserialize(&data[..]),
                        |v: &mut Vec<usize>, n| v.push(n),
                    )
                    .unwrap();

                    // In round `N` the receiver must output `workers` copies
                    // of `N`.
                    let mut round = 0;
                    let combined = circuit.add_exchange(sender, receiver, &source);
                    combined.inspect(move |v| {
                        assert_eq!(&vec![round; workers], v);
                        round += 1;
                    });
                    Ok(())
                })
                .unwrap()
                .0;

                for _ in 1..ROUNDS {
                    circuit.transaction().unwrap();
                }
            })
            .expect("failed to start runtime");

            hruntime.join().unwrap();
        }

        for workers in [2, 16, 32] {
            do_test::<S>(workers);
        }
    }
}