Skip to main content

dbsp/operator/communication/
exchange.rs

1//! Exchange operators implement a N-to-N communication pattern where
2//! each participant sends exactly one value to and receives exactly one
3//! value from each peer at every clock cycle.
4
5// TODO: We may want to generalize these operators to implement N-to-M
6// communication, including 1-to-N and N-to-1.
7
8use crate::{
9    NumEntries, WeakRuntime,
10    circuit::{
11        Host, LocalStoreMarker, OwnershipPreference, Runtime, Scope,
12        metadata::{
13            BatchSizeStats, EXCHANGE_DESERIALIZATION_TIME_SECONDS, EXCHANGE_DESERIALIZED_BYTES,
14            EXCHANGE_SERIALIZATION_TIME_SECONDS, EXCHANGE_SERIALIZED_BYTES,
15            EXCHANGE_WAIT_TIME_SECONDS, INPUT_BATCHES_STATS, MetaItem, OUTPUT_BATCHES_STATS,
16            OperatorLocation, OperatorMeta,
17        },
18        operator_traits::{Operator, SinkOperator, SourceOperator},
19        tokio::TOKIO,
20    },
21    circuit_cache_key,
22};
23use crossbeam_utils::CachePadded;
24use futures::{future, prelude::*, stream::FuturesUnordered};
25use std::{
26    borrow::Cow,
27    collections::HashMap,
28    marker::PhantomData,
29    net::SocketAddr,
30    ops::Range,
31    sync::{
32        Arc, Mutex, OnceLock, RwLock,
33        atomic::{AtomicPtr, AtomicU64, AtomicUsize, Ordering},
34    },
35    time::{Duration, Instant, SystemTime},
36};
37use tarpc::{
38    client, context,
39    serde_transport::tcp::{connect, listen},
40    server::{self, Channel},
41    tokio_serde::formats::Bincode,
42    tokio_util::sync::{CancellationToken, DropGuard},
43};
44use tokio::{
45    sync::{Notify, OnceCell as TokioOnceCell},
46    time::sleep,
47};
48use typedmap::TypedMapKey;
49
/// Returns the number of microseconds elapsed since the Unix epoch.
///
/// Yields 0 if the system clock reads earlier than the epoch, and saturates
/// at `u64::MAX` if the microsecond count does not fit in a `u64`.
fn current_time_usecs() -> u64 {
    match SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => u64::try_from(since_epoch.as_micros()).unwrap_or(u64::MAX),
        Err(_) => 0,
    }
}
57
// We use the `Runtime::local_store` mechanism to connect multiple workers
// to an `Exchange` instance.  During circuit construction, each worker
// allocates a unique id that happens to be the same across all workers.
// The worker then allocates a new `Exchange` and adds it to the local store
// using the id as a key.  If there already is an `Exchange` with this id in
// the store, created by another worker, a reference to that `Exchange` will
// be used instead.
//
// The key declared here, `ExchangeCacheId<T>`, wraps an `ExchangeId` and is
// associated with the shared `Arc<Exchange<T>>`.
circuit_cache_key!(local ExchangeCacheId<T>(ExchangeId => Arc<Exchange<T>>));
66
/// RPC interface through which one host pushes a round of exchange data to
/// another host.
#[tarpc::service]
trait ExchangeService {
    /// Sends messages in `exchange_id` from all of the worker threads in
    /// `senders` to all of the worker thread receivers in the server that
    /// processes the message.  The Bincode-encoded message from `sender` to
    /// `receiver` is `data[sender - senders.start][receiver -
    /// receivers.start]`, where `receivers` is the `Range<usize>` of worker
    /// thread IDs in the server that processes the message.
    ///
    /// `data` therefore has one outer element per sender, and each of those
    /// has one inner element per receiver on the serving host.
    async fn exchange(exchange_id: usize, senders: Range<usize>, data: Vec<Vec<Vec<u8>>>);
}
77
// Identifies one exchange; used as the key of the `ExchangeDirectory` and
// sent along with every `exchange` RPC.
type ExchangeId = usize;

// Maps from an `exchange_id` to the `InnerExchange` that implements the
// exchange.
type ExchangeDirectory = Arc<RwLock<HashMap<ExchangeId, Arc<InnerExchange>>>>;
82
83#[derive(Clone)]
84struct ExchangeServer(ExchangeDirectory);
85
86impl ExchangeService for ExchangeServer {
87    async fn exchange(
88        self,
89        _: context::Context,
90        exchange_id: ExchangeId,
91        senders: Range<usize>,
92        data: Vec<Vec<Vec<u8>>>,
93    ) {
94        let inner = self.0.read().unwrap().get(&exchange_id).unwrap().clone();
95        inner.received(senders, data).await;
96    }
97}
98
99struct Clients {
100    runtime: WeakRuntime,
101
102    /// Listens for connections from other hosts.
103    ///
104    /// We create this lazily upon the first attempt to connect to other hosts.
105    /// If we create it before we've completely initialized the circuit, then we
106    /// might not have created all of the exchanges yet when some other host
107    /// tries to send data to one.
108    listener: OnceLock<Option<ExchangeListener>>,
109
110    /// Maps from a range of worker IDs to the RPC client used to contact those
111    /// workers.  Only worker IDs for remote workers appear in the map.
112    clients: Vec<(Host, TokioOnceCell<ExchangeServiceClient>)>,
113}
114
115impl Clients {
116    fn new(runtime: &Runtime) -> Clients {
117        Self {
118            runtime: runtime.downgrade(),
119            listener: Default::default(),
120            clients: runtime
121                .layout()
122                .other_hosts()
123                .map(|host| (host.clone(), TokioOnceCell::new()))
124                .collect(),
125        }
126    }
127
128    /// Returns a client for `worker`, which must be a remote worker ID, first
129    /// establishing a connection if there isn't one yet.
130    async fn connect(&self, worker: usize) -> &ExchangeServiceClient {
131        self.listener.get_or_init(|| {
132            if let Some(runtime) = self.runtime.upgrade()
133                && let Some(local_address) = runtime.layout().local_address()
134            {
135                let directory = runtime.local_store().get(&DirectoryId).unwrap().clone();
136                Some(ExchangeListener::new(local_address, directory))
137            } else {
138                None
139            }
140        });
141
142        let (host, cell) = self
143            .clients
144            .iter()
145            .find(|(host, _client)| host.workers.contains(&worker))
146            .unwrap();
147        cell.get_or_init(|| async {
148            let transport = loop {
149                let mut transport = connect(host.address, Bincode::default);
150                transport.config_mut().max_frame_length(usize::MAX);
151                match transport.await {
152                    Ok(transport) => break transport,
153                    Err(error) => println!(
154                        "connection to {} failed ({error}), waiting to retry",
155                        host.address
156                    ),
157                }
158                sleep(std::time::Duration::from_millis(1000)).await;
159            };
160            ExchangeServiceClient::new(client::Config::default(), transport).spawn()
161        })
162        .await
163    }
164}
165
/// Heap-allocated payload of a [`Callback`]: either no callback at all or a
/// boxed closure.
struct CallbackInner {
    cb: Option<Box<dyn Fn() + Send + Sync>>,
}

impl CallbackInner {
    /// Creates a payload with no callback installed.
    fn empty() -> Self {
        Self { cb: None }
    }

    /// Creates a payload that holds `cb`.
    fn new<F>(cb: F) -> Self
    where
        F: Fn() + Send + Sync + 'static,
    {
        let cb = Box::new(cb) as Box<dyn Fn() + Send + Sync>;
        Self { cb: Some(cb) }
    }
}

/// A callback slot that can be invoked through a shared reference without
/// taking a lock.
///
/// The slot always owns exactly one heap-allocated [`CallbackInner`], stored
/// as a raw pointer obtained from `Box::into_raw`.
///
/// NOTE(review): `set_callback` frees the previous payload right after
/// swapping it out, so it must not run concurrently with `call` on the same
/// slot; nothing in this type enforces that.
struct Callback(AtomicPtr<CallbackInner>);

impl Callback {
    /// Creates a slot with no callback installed; `call` is then a no-op.
    fn empty() -> Self {
        Self(AtomicPtr::new(Box::into_raw(Box::new(
            CallbackInner::empty(),
        ))))
    }

    /// Installs `cb`, replacing and freeing whatever was installed before.
    fn set_callback(&self, cb: impl Fn() + Send + Sync + 'static) {
        let old_callback = self.0.swap(
            Box::into_raw(Box::new(CallbackInner::new(cb))),
            Ordering::AcqRel,
        );

        // SAFETY: every pointer stored in `self.0` comes from `Box::into_raw`,
        // and the swap above removed this one from the slot, so it is freed
        // exactly once.  See the type-level note about concurrent `call`s.
        let old_callback = unsafe { Box::from_raw(old_callback) };
        drop(old_callback);
    }

    /// Invokes the installed callback, if there is one.
    fn call(&self) {
        // SAFETY: the slot always holds a pointer obtained from
        // `Box::into_raw`, which stays allocated until it is replaced by
        // `set_callback` or the slot is dropped.
        if let Some(cb) = &unsafe { &*self.0.load(Ordering::Acquire) }.cb {
            cb()
        }
    }
}

impl Drop for Callback {
    /// Frees the payload currently installed in the slot, which would
    /// otherwise leak together with everything its closure captured.
    fn drop(&mut self) {
        // SAFETY: `&mut self` guarantees exclusive access, and the slot
        // always holds a live pointer obtained from `Box::into_raw`.
        drop(unsafe { Box::from_raw(*self.0.get_mut()) });
    }
}
209
210struct InnerExchange {
211    exchange_id: ExchangeId,
212    /// The number of communicating peers.
213    npeers: usize,
214    /// Range of worker IDs on the local host.
215    local_workers: Range<usize>,
216    /// Counts the number of messages yet to be received in the current round of
217    /// communication per receiver.  The receiver must wait until it has all
218    /// `npeers` messages before reading all of them from mailboxes in one
219    /// pass.
220    receiver_counters: Vec<AtomicUsize>,
221    /// Callback invoked when all `npeers` messages are ready for a receiver.
222    receiver_callbacks: Vec<Callback>,
223    /// Counts the number of empty mailboxes ready to accept new data per
224    /// sender. The sender waits until it has `npeers` available mailboxes
225    /// before writing all of them in one pass.
226    sender_counters: Vec<CachePadded<AtomicUsize>>,
227    /// Callback invoked when all `npeers` mailboxes are available.
228    sender_callbacks: Vec<Callback>,
229    /// The number of workers that have already sent their messages in the
230    /// current round.
231    sent: AtomicUsize,
232    /// The RPC clients to contact remote hosts.
233    clients: Arc<Clients>,
234    /// This allows the `exchange` RPC to wait until the receiver has taken its
235    /// data out of the mailbox.  There are `n_remote_workers * n_local_workers`
236    /// elements.
237    sender_notifies: Vec<Notify>,
238    /// A callback that takes the raw data exchanged over RPC and deserializes
239    /// and delivers it to the receiver's mailbox.
240    deliver: Box<dyn Fn(Vec<u8>, usize, usize) + Send + Sync + 'static>,
241    /// The amount of time spent in `deliver`.
242    delivery_usecs: AtomicU64,
243    /// The number of bytes passed to `deliver`.
244    delivered_bytes: AtomicUsize,
245}
246
247impl InnerExchange {
248    fn new(
249        exchange_id: ExchangeId,
250        deliver: impl Fn(Vec<u8>, usize, usize) + Send + Sync + 'static,
251        clients: Arc<Clients>,
252    ) -> InnerExchange {
253        let runtime = Runtime::runtime().unwrap();
254        let npeers = Runtime::num_workers();
255        let local_workers = runtime.layout().local_workers();
256        let n_local_workers = local_workers.len();
257        let n_remote_workers = npeers - n_local_workers;
258        Self {
259            exchange_id,
260            npeers,
261            local_workers,
262            clients,
263            receiver_counters: (0..npeers).map(|_| AtomicUsize::new(0)).collect(),
264            receiver_callbacks: (0..npeers).map(|_| Callback::empty()).collect(),
265            sender_notifies: (0..n_local_workers * n_remote_workers)
266                .map(|_| Notify::new())
267                .collect(),
268            sender_counters: (0..npeers)
269                .map(|_| CachePadded::new(AtomicUsize::new(npeers)))
270                .collect(),
271            sender_callbacks: (0..npeers).map(|_| Callback::empty()).collect(),
272            deliver: Box::new(deliver),
273            delivery_usecs: AtomicU64::new(0),
274            delivered_bytes: AtomicUsize::new(0),
275            sent: AtomicUsize::new(0),
276        }
277    }
278
279    #[allow(dead_code)]
280    fn exchange_id(&self) -> ExchangeId {
281        self.exchange_id
282    }
283
284    /// Returns the `sender_notify` for a sender/receiver pair.  `receiver`
285    /// must be a local worker ID, and `sender` must be a remote worker ID.
286    fn sender_notify(&self, sender: usize, receiver: usize) -> &Notify {
287        debug_assert!(sender < self.npeers && !self.local_workers.contains(&sender));
288        debug_assert!(self.local_workers.contains(&receiver));
289        let n_local_workers = self.local_workers.len();
290        let sender_ofs = if sender >= self.local_workers.start {
291            sender - n_local_workers
292        } else {
293            sender
294        };
295        let receiver_ofs = receiver - self.local_workers.start;
296        &self.sender_notifies[sender_ofs * n_local_workers + receiver_ofs]
297    }
298
299    /// Receives messages sent from all of the worker threads in `senders` to
300    /// all of the local worker threads `receivers` in `self`.  The
301    /// Bincode-encoded `Vec<u8>` message from `sender` to `receiver` is
302    /// `data[sender - senders.start][receiver - receivers.start]`.
303    async fn received(self: &Arc<Self>, senders: Range<usize>, data: Vec<Vec<Vec<u8>>>) {
304        let receivers = &self.local_workers;
305
306        // Deliver all of the data into the exchange's mailboxes.
307        let start = Instant::now();
308        let mut delivered_bytes = 0;
309        for (sender, data) in senders.clone().zip(data.into_iter()) {
310            assert_eq!(data.len(), receivers.len());
311            for (receiver, data) in receivers.clone().zip(data.into_iter()) {
312                delivered_bytes += data.len();
313                (self.deliver)(data, sender, receiver);
314            }
315        }
316        self.delivery_usecs
317            .fetch_add(start.elapsed().as_micros() as u64, Ordering::Relaxed);
318        self.delivered_bytes
319            .fetch_add(delivered_bytes, Ordering::Relaxed);
320
321        // Increment the receiver counters and deliver callbacks if necessary.
322        for receiver in receivers.clone() {
323            let n = senders.len();
324            let old_counter = self.receiver_counters[receiver].fetch_add(n, Ordering::AcqRel);
325            if old_counter >= self.npeers - n {
326                self.receiver_callbacks[receiver].call();
327            }
328        }
329
330        // Wait for the receivers to pick up their mail before returning.
331        for sender in senders {
332            for receiver in receivers.clone() {
333                self.sender_notify(sender, receiver).notified().await;
334            }
335        }
336    }
337
338    /// Returns an index for the sender/receiver pair.
339    fn mailbox_index(&self, sender: usize, receiver: usize) -> usize {
340        debug_assert!(sender < self.npeers);
341        debug_assert!(receiver < self.npeers);
342        sender * self.npeers + receiver
343    }
344
345    fn ready_to_send(&self, sender: usize) -> bool {
346        debug_assert!(self.local_workers.contains(&sender));
347        self.sender_counters[sender].load(Ordering::Acquire) == self.npeers
348    }
349
350    fn ready_to_receive(&self, receiver: usize) -> bool {
351        debug_assert!(receiver < self.npeers);
352        self.receiver_counters[receiver].load(Ordering::Acquire) == self.npeers
353    }
354
355    fn register_sender_callback<F>(&self, sender: usize, cb: F)
356    where
357        F: Fn() + Send + Sync + 'static,
358    {
359        debug_assert!(sender < self.npeers);
360        self.sender_callbacks[sender].set_callback(cb);
361    }
362
363    fn register_receiver_callback<F>(&self, receiver: usize, cb: F)
364    where
365        F: Fn() + Send + Sync + 'static,
366    {
367        debug_assert!(receiver < self.npeers);
368
369        self.receiver_callbacks[receiver].set_callback(cb);
370    }
371}
372
/// `Exchange` is an N-to-N communication primitive that partitions data across
/// multiple concurrent threads.
///
/// An instance of `Exchange` can be shared by multiple threads that communicate
/// in rounds.  In each round each peer _first_ sends exactly one data value to
/// every other peer (and itself) and then receives one value from each peer.
/// The send operation can only proceed when all peers have retrieved data
/// produced at the previous round.  Likewise, the receive operation can proceed
/// once all incoming values are ready for the current round.
///
/// Each worker has one ExchangeServiceClient and ExchangeServer for every
/// worker (including itself), so N*N total.
///
/// In a round, each worker invokes exchange() once on each of its clients.
/// Each server handles N calls to exchange(), once for each other worker and
/// itself.
///
/// Each call to exchange populates a mailbox.  When all the mailboxes for a
/// worker have been populated, it can read and clear them.
///
/// NOTE(review): the two paragraphs about per-worker clients and servers look
/// stale.  `Clients` keeps one RPC client per remote *host*, and local
/// workers exchange data through the mailboxes directly -- confirm and update.
pub(crate) struct Exchange<T> {
    /// Type-independent state: counters, callbacks, and RPC clients.
    inner: Arc<InnerExchange>,
    /// `npeers^2` mailboxes, one for each sender/receiver pair.  Each mailbox
    /// is accessed by exactly two threads, so contention is low.
    ///
    /// We only use the mailboxes where either the sender or the receiver is one
    /// of our local workers. In the diagram below, L is mailboxes used for
    /// local exchange, S mailboxes used for sending RPC exchange, and R
    /// mailboxes used for receiving exchange via RPC:
    ///
    /// ```text
    ///           <-------receivers------->
    ///                  local
    ///                 workers
    /// ^         -------------------------
    /// |         |     |RRRRR|     |     |
    ///           |     |RRRRR|     |     |
    /// s         |-----|-----|-----|-----|
    /// e  local  |SSSSS|LLLLL|SSSSS|SSSSS|
    /// n workers |SSSSS|LLLLL|SSSSS|SSSSS|
    /// d         |-----|-----|-----|-----|
    /// e         |     |RRRRR|     |     |
    /// r         |     |RRRRR|     |     |
    /// s         |-----|-----|-----|-----|
    ///           |     |RRRRR|     |     |
    /// |         |     |RRRRR|     |     |
    /// v         |-----|-----|-----|-----|
    /// ```
    mailboxes: Arc<Vec<Mutex<Option<T>>>>,
    /// Converts a value into the bytes sent to a remote host.
    serialize: Box<dyn Fn(T) -> Vec<u8> + Send + Sync>,

    /// The amount of time we've spent calling `serialize`.
    serialization_usecs: AtomicU64,

    /// The number of bytes produced by `serialize`.
    serialized_bytes: AtomicUsize,
}
429
430async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
431    tokio::spawn(fut);
432}
433
434// Stop Rust from complaining about unused field.
435#[allow(dead_code)]
436struct ExchangeListener(DropGuard);
437
438impl ExchangeListener {
439    fn new(address: SocketAddr, directory: ExchangeDirectory) -> Self {
440        let token = CancellationToken::new();
441        let drop = token.clone().drop_guard();
442        TOKIO.spawn(async move {
443            println!("listening on {address}");
444            let mut listener = listen(address, Bincode::default).await.unwrap();
445            listener.config_mut().max_frame_length(usize::MAX);
446            let incoming = listener
447                .filter_map(|r| future::ready(r.ok()))
448                .map(server::BaseChannel::with_defaults)
449                .map(move |channel| {
450                    let server = ExchangeServer(directory.clone());
451                    channel.execute(server.serve()).for_each(spawn)
452                })
453                .buffer_unordered(10)
454                .for_each(|_| async {});
455            tokio::select! {
456                _ = incoming => {}
457                _ = token.cancelled() => {}
458            }
459        });
460        Self(drop)
461    }
462}
463
464impl<T> Exchange<T>
465where
466    T: Clone + Send + 'static,
467{
468    /// Create a new exchange operator for `npeers` communicating threads.
469    fn new(
470        exchange_id: ExchangeId,
471        clients: Arc<Clients>,
472        directory: ExchangeDirectory,
473        serialize: Box<dyn Fn(T) -> Vec<u8> + Send + Sync>,
474        deserialize: Box<dyn Fn(Vec<u8>) -> T + Send + Sync>,
475    ) -> Self {
476        let npeers = Runtime::num_workers();
477        let mailboxes: Arc<Vec<Mutex<Option<T>>>> =
478            Arc::new((0..npeers * npeers).map(|_| Mutex::new(None)).collect());
479        let mailboxes2: Arc<Vec<Mutex<Option<T>>>> = mailboxes.clone();
480        let deliver = move |data: Vec<u8>, sender, receiver| {
481            let index: usize = sender * npeers + receiver;
482            let data = deserialize(data);
483            let mut mailbox = mailboxes2[index].lock().unwrap();
484            assert!((*mailbox).is_none());
485            *mailbox = Some(data);
486        };
487
488        let inner = Arc::new(InnerExchange::new(exchange_id, deliver, clients));
489        directory
490            .write()
491            .unwrap()
492            .entry(exchange_id)
493            .and_modify(|_| panic!())
494            .or_insert(inner.clone());
495        Self {
496            inner,
497            mailboxes,
498            serialize,
499            serialization_usecs: AtomicU64::new(0),
500            serialized_bytes: AtomicUsize::new(0),
501        }
502    }
503
504    #[allow(dead_code)]
505    fn exchange_id(&self) -> ExchangeId {
506        self.inner.exchange_id()
507    }
508
509    /// Returns a reference to a mailbox for the sender/receiver pair.
510    fn mailbox(&self, sender: usize, receiver: usize) -> &Mutex<Option<T>> {
511        &self.mailboxes[self.inner.mailbox_index(sender, receiver)]
512    }
513
514    /// Create a new `Exchange` instance if an instance with the same id
515    /// (created by another thread) does not yet exist within `runtime`.
516    /// The number of peers will be set to `runtime.num_workers()`.
517    pub(crate) fn with_runtime(
518        runtime: &Runtime,
519        exchange_id: ExchangeId,
520        serialize: Box<dyn Fn(T) -> Vec<u8> + Send + Sync>,
521        deserialize: Box<dyn Fn(Vec<u8>) -> T + Send + Sync>,
522    ) -> Arc<Self> {
523        let directory = runtime
524            .local_store()
525            .entry(DirectoryId)
526            .or_insert_with(|| Arc::new(RwLock::new(HashMap::new())))
527            .clone();
528
529        let clients = runtime
530            .local_store()
531            .entry(ClientsId)
532            .or_insert_with(|| {
533                // Create clients for remote exchange.
534                Arc::new(Clients::new(runtime))
535            })
536            .clone();
537
538        runtime
539            .local_store()
540            .entry(ExchangeCacheId::new(exchange_id))
541            .or_insert_with(|| {
542                Arc::new(Exchange::new(
543                    exchange_id,
544                    clients.clone(),
545                    directory,
546                    serialize,
547                    deserialize,
548                ))
549            })
550            .value()
551            .clone()
552    }
553
554    /// True if all `sender`'s outgoing mailboxes are free and ready to accept
555    /// data.
556    ///
557    /// Once this function returns true, a subsequent `try_send_all` operation
558    /// is guaranteed to succeed for `sender`.
559    pub fn ready_to_send(&self, sender: usize) -> bool {
560        self.inner.ready_to_send(sender)
561    }
562
563    /// Write all outgoing messages for `sender` to mailboxes.
564    ///
565    /// Values to be sent are retrieved from the `data` iterator, with the
566    /// first value delivered to receiver 0, second value delivered to receiver
567    /// 1, and so on.
568    ///
569    /// # Errors
570    ///
571    /// Fails if at least one of the sender's outgoing mailboxes is not empty.
572    ///
573    /// # Panics
574    ///
575    /// Panics if `data` yields fewer than `self.npeers` items.
576    pub(crate) fn try_send_all<I>(self: &Arc<Self>, sender: usize, data: &mut I) -> bool
577    where
578        I: Iterator<Item = T> + Send,
579    {
580        let npeers = self.inner.npeers;
581        if self.inner.sender_counters[sender]
582            .compare_exchange(npeers, 0, Ordering::AcqRel, Ordering::Acquire)
583            .is_err()
584        {
585            return false;
586        }
587
588        // Deliver all of the data to local mailboxes.
589        let local_workers = &self.inner.local_workers;
590        for receiver in 0..npeers {
591            *self.mailbox(sender, receiver).lock().unwrap() = data.next();
592
593            if local_workers.contains(&receiver) {
594                let old_counter =
595                    self.inner.receiver_counters[receiver].fetch_add(1, Ordering::AcqRel);
596                if old_counter >= npeers - 1 {
597                    self.inner.receiver_callbacks[receiver].call();
598                }
599            }
600        }
601
602        // In a single-host layout, or if some of our local workers haven't yet
603        // sent in this round, we're all done for now.
604        if npeers == local_workers.len()
605            || self.inner.sent.fetch_add(1, Ordering::AcqRel) + 1 != local_workers.len()
606        {
607            return true;
608        }
609        self.inner.sent.store(0, Ordering::Release);
610
611        // All of the local workers have sent their data in this round.  Take
612        // all of their data and send it to the remote hosts.
613        let this = self.clone();
614        let runtime = Runtime::runtime().unwrap();
615        TOKIO.spawn(async move {
616            let mut futures = FuturesUnordered::new();
617
618            // For each range of worker IDs `receivers` on a remote host,
619            // accumulate all of the data from our local `senders` to all
620            // of the `receivers` on that host.
621            let senders = &this.inner.local_workers;
622            let start = Instant::now();
623            for host in runtime.layout().other_hosts() {
624                let receivers = &host.workers;
625                let mut serialized_bytes = 0;
626                let items: Vec<Vec<_>> = senders
627                    .clone()
628                    .map(|sender| {
629                        receivers
630                            .clone()
631                            .map(|receiver| {
632                                let item = this
633                                    .mailbox(sender, receiver)
634                                    .lock()
635                                    .unwrap()
636                                    .take()
637                                    .unwrap();
638                                let serialized = (this.serialize)(item);
639                                serialized_bytes += serialized.len();
640                                serialized
641                            })
642                            .collect()
643                    })
644                    .collect();
645                this.serialization_usecs
646                    .fetch_add(start.elapsed().as_micros() as u64, Ordering::Relaxed);
647                this.serialized_bytes
648                    .fetch_add(serialized_bytes, Ordering::Relaxed);
649
650                let client = this.inner.clients.connect(receivers.start).await;
651
652                // Send it.
653                let mut context = context::current();
654                context.deadline = Instant::now() + Duration::from_hours(1);
655                futures.push(client.exchange(
656                    context,
657                    this.inner.exchange_id,
658                    senders.clone(),
659                    items,
660                ));
661            }
662
663            // Wait for all the sends to complete.
664            while let Some(result) = futures.next().await {
665                result.unwrap();
666            }
667
668            // Record that the sends completed.
669            let n = npeers - senders.len();
670            for sender in senders.clone() {
671                let old_counter = this.inner.sender_counters[sender].fetch_add(n, Ordering::AcqRel);
672                if old_counter >= npeers - n {
673                    this.inner.sender_callbacks[sender].call();
674                }
675            }
676        });
677
678        true
679    }
680
681    /// True if all `receiver`'s incoming mailboxes contain data.
682    ///
683    /// Once this function returns true, a subsequent `try_receive_all`
684    /// operation is guaranteed for `receiver`.
685    pub(crate) fn ready_to_receive(&self, receiver: usize) -> bool {
686        self.inner.ready_to_receive(receiver)
687    }
688
689    /// Read all incoming messages for `receiver`.
690    ///
691    /// Values are passed to callback function `cb` in the order of worker indexes.
692    ///
693    /// # Errors
694    ///
695    /// Fails if at least one of the receiver's incoming mailboxes is empty.
696    pub(crate) fn try_receive_all<F>(&self, receiver: usize, mut cb: F) -> bool
697    where
698        F: FnMut(T),
699    {
700        let npeers = self.inner.npeers;
701        if self.inner.receiver_counters[receiver]
702            .compare_exchange(npeers, 0, Ordering::AcqRel, Ordering::Acquire)
703            .is_err()
704        {
705            return false;
706        }
707
708        for sender in 0..self.inner.npeers {
709            let data = self
710                .mailbox(sender, receiver)
711                .lock()
712                .unwrap()
713                .take()
714                .unwrap();
715            cb(data);
716            if self.inner.local_workers.contains(&sender) {
717                let old_counter = self.inner.sender_counters[sender].fetch_add(1, Ordering::AcqRel);
718                if old_counter >= self.inner.npeers - 1 {
719                    self.inner.sender_callbacks[sender].call();
720                }
721            } else {
722                self.inner.sender_notify(sender, receiver).notify_one();
723            }
724        }
725        true
726    }
727
728    /// Register callback to be invoked whenever the `ready_to_send` condition
729    /// becomes true.
730    ///
731    /// The callback can be setup at most once (e.g., when a scheduler attaches
732    /// to the circuit) and cannot be unregistered.  Notifications delivered
733    /// before the callback is registered are lost.  The client should call
734    /// `ready_to_send` after installing the callback to check the status.
735    ///
736    /// After the callback has been registered, notifications are delivered with
737    /// at-least-once semantics: a notification is generated whenever the
738    /// status changes from not ready to ready, but spurious notifications
739    /// can occur occasionally.  Therefore, the user must check the status
740    /// explicitly by calling `ready_to_send` or be prepared that `try_send_all`
741    /// can fail.
742    pub(crate) fn register_sender_callback<F>(&self, sender: usize, cb: F)
743    where
744        F: Fn() + Send + Sync + 'static,
745    {
746        self.inner.register_sender_callback(sender, cb)
747    }
748
749    /// Register callback to be invoked whenever the `ready_to_receive`
750    /// condition becomes true.
751    ///
752    /// The callback can be setup at most once (e.g., when a scheduler attaches
753    /// to the circuit) and cannot be unregistered.  Notifications delivered
754    /// before the callback is registered are lost.  The client should call
755    /// `ready_to_receive` after installing the callback to check
756    /// the status.
757    ///
758    /// After the callback has been registered, notifications are delivered with
759    /// at-least-once semantics: a notification is generated whenever the
760    /// status changes from not ready to ready, but spurious notifications
761    /// can occur occasionally.  The user must check the status explicitly
762    /// by calling `ready_to_receive` or be prepared that `try_receive_all`
763    /// can fail.
764    pub(crate) fn register_receiver_callback<F>(&self, receiver: usize, cb: F)
765    where
766        F: Fn() + Send + Sync + 'static,
767    {
768        self.inner.register_receiver_callback(receiver, cb)
769    }
770}
771
/// Operator that partitions incoming data across all workers.
///
/// This operator works in tandem with [`ExchangeReceiver`], which reassembles
/// the data on the receiving side.  Together they implement an all-to-all
/// communication mechanism, where at every clock cycle each worker partitions
/// its incoming data into `N` values, one for each worker, using a
/// user-provided closure.  It then reads values sent to it by all peers and
/// reassembles them into a single value using another user-provided closure.
///
/// The exchange mechanism is split into two operators, so that after sending
/// the data the circuit does not need to block waiting for its peers to finish
/// sending and can instead schedule other operators.
///
/// ```text
///                    ExchangeSender  ExchangeReceiver
///                       ┌───────┐      ┌───────┐
///                       │       │      │       │
///        ┌───────┐      │       │      │       │          ┌───────┐
///        │source ├─────►│       │      │       ├─────────►│ sink  │
///        └───────┘      │       │      │       │          └───────┘
///                       │       ├───┬─►│       │
///                       │       │   │  │       │
///                       └───────┘   │  └───────┘
/// WORKER 1                          │
/// ──────────────────────────────────┼──────────────────────────────
/// WORKER 2                          │
///                                   │
///                       ┌───────┐   │  ┌───────┐
///                       │       ├───┴─►│       │
///        ┌───────┐      │       │      │       │          ┌───────┐
///        │source ├─────►│       │      │       ├─────────►│ sink  │
///        └───────┘      │       │      │       │          └───────┘
///                       │       │      │       │
///                       │       │      │       │
///                       └───────┘      └───────┘
///                    ExchangeSender  ExchangeReceiver
/// ```
///
/// `ExchangeSender` is an asynchronous operator, i.e.,
/// [`ExchangeSender::is_async`] returns `true`.  It becomes schedulable
/// ([`ExchangeSender::ready`] returns `true`) once all peers have retrieved
/// values written by the operator in the previous clock cycle.  The scheduler
/// should use [`ExchangeSender::register_ready_callback`] to get notified when
/// the operator becomes schedulable.
///
/// `ExchangeSender` doesn't have a public constructor and must be instantiated
/// using the [`new_exchange_operators`] function, which creates an
/// [`ExchangeSender`]/[`ExchangeReceiver`] pair of operators and connects them
/// to their counterparts in other workers as in the diagram above.
///
/// An [`ExchangeSender`]/[`ExchangeReceiver`] pair is added to a circuit using
/// the [`Circuit::add_exchange`](`crate::circuit::Circuit::add_exchange`)
/// method, which registers a dependency between them, making sure that
/// `ExchangeSender` is evaluated before `ExchangeReceiver`.
///
/// # Examples
///
/// The following example instantiates the circuit in the diagram above.
///
/// ```
/// # #[cfg(miri)]
/// # fn main() {}
///
/// # #[cfg(not(miri))]
/// # fn main() {
/// use dbsp::{
///     operator::{communication::new_exchange_operators, Generator},
///     Circuit, RootCircuit, Runtime,
///     storage::file::to_bytes,
///     trace::unaligned_deserialize,
/// };
///
/// const WORKERS: usize = 16;
/// const ROUNDS: usize = 10;
///
/// let hruntime = Runtime::run(WORKERS, |_parker| {
///     let circuit = RootCircuit::build(|circuit| {
///         // Create a data source that generates numbers 0, 1, 2, ...
///         let mut n: usize = 0;
///         let source = circuit.add_source(Generator::new(move || {
///             let result = n;
///             n += 1;
///             result
///         }));
///
///         // Create an `ExchangeSender`/`ExchangeReceiver` pair.
///         let (sender, receiver) = new_exchange_operators(
///             None,
///             || Vec::new(),
///             // Partitioning function sends a copy of the input `n` to each peer.
///             |n, output| {
///                 for _ in 0..WORKERS {
///                     output.push(n)
///                 }
///             },
///             |value| to_bytes(&value).unwrap().into_vec(),
///             |data| unaligned_deserialize(&data[..]),
///             // Reassemble received values into a vector.
///             |v: &mut Vec<usize>, n| v.push(n),
///         ).unwrap();
///
///         // Add exchange operators to the circuit.
///         let combined = circuit.add_exchange(sender, receiver, &source);
///         let mut round = 0;
///
///         // Expected output stream of `ExchangeReceiver`:
///         // [0,0,0,...]
///         // [1,1,1,...]
///         // [2,2,2,...]
///         // ...
///         combined.inspect(move |v| {
///             assert_eq!(&vec![round; WORKERS], v);
///             round += 1;
///         });
///         Ok(())
///     })
///     .unwrap()
///     .0;
///
///     for _ in 1..ROUNDS {
///         circuit.step();
///     }
/// }).expect("failed to start runtime");
///
/// hruntime.join().unwrap();
/// # }
/// ```
pub struct ExchangeSender<D, T, L>
where
    T: Send + 'static + Clone,
{
    // Index of the worker this operator belongs to; used as the sender
    // index in all calls into `exchange`.
    worker_index: usize,
    // Value reported by `Operator::location`.
    location: OperatorLocation,
    // User-provided closure that splits an input value into per-worker
    // values by pushing them into `outputs`.
    partition: L,
    // Scratch buffer filled by `partition` and drained when sending.
    outputs: Vec<T>,
    // Exchange object shared with the paired receiver.  Each exchanged value
    // travels together with a `bool` flush flag.
    exchange: Arc<Exchange<(T, bool)>>,

    // Input batch sizes.
    input_batch_stats: BatchSizeStats,

    // Set by `flush`; attached to every value of the next send and then
    // cleared.
    flushed: bool,

    // The instant when the sender produced its outputs, and the
    // receiver starts waiting for all other workers to produce their
    // outputs.
    start_wait_usecs: Arc<AtomicU64>,

    // Ties the otherwise unused input type `D` to the struct.
    phantom: PhantomData<D>,
}
920
921impl<D, T, L> ExchangeSender<D, T, L>
922where
923    T: Send + 'static + Clone,
924{
925    fn new(
926        worker_index: usize,
927        location: OperatorLocation,
928        exchange: Arc<Exchange<(T, bool)>>,
929        start_wait_usecs: Arc<AtomicU64>,
930        partition: L,
931    ) -> Self {
932        debug_assert!(worker_index < Runtime::num_workers());
933        Self {
934            worker_index,
935            location,
936            partition,
937            outputs: Vec::with_capacity(Runtime::num_workers()),
938            exchange,
939            input_batch_stats: BatchSizeStats::new(),
940            flushed: false,
941            start_wait_usecs,
942            phantom: PhantomData,
943        }
944    }
945}
946
947impl<D, T, L> Operator for ExchangeSender<D, T, L>
948where
949    D: 'static,
950    T: Send + 'static + Clone,
951    L: 'static,
952{
953    fn name(&self) -> Cow<'static, str> {
954        Cow::from("ExchangeSender")
955    }
956
957    fn metadata(&self, meta: &mut OperatorMeta) {
958        meta.extend(metadata! {
959            INPUT_BATCHES_STATS => self.input_batch_stats.metadata(),
960            EXCHANGE_SERIALIZATION_TIME_SECONDS => MetaItem::Duration(Duration::from_micros(self.exchange.serialization_usecs.load(Ordering::Acquire))),
961            EXCHANGE_SERIALIZED_BYTES => MetaItem::bytes(self.exchange.serialized_bytes.load(Ordering::Acquire))
962        });
963    }
964
965    fn location(&self) -> OperatorLocation {
966        self.location
967    }
968
969    fn clock_start(&mut self, _scope: Scope) {}
970    fn clock_end(&mut self, _scope: Scope) {}
971
972    fn is_async(&self) -> bool {
973        true
974    }
975
976    fn register_ready_callback<F>(&mut self, cb: F)
977    where
978        F: Fn() + Send + Sync + 'static,
979    {
980        self.exchange
981            .register_sender_callback(self.worker_index, cb)
982    }
983
984    fn ready(&self) -> bool {
985        self.exchange.ready_to_send(self.worker_index)
986    }
987
988    fn fixedpoint(&self, _scope: Scope) -> bool {
989        true
990    }
991
992    fn flush(&mut self) {
993        self.flushed = true;
994    }
995}
996
997impl<D, T, L> SinkOperator<D> for ExchangeSender<D, T, L>
998where
999    D: Clone + NumEntries + 'static,
1000    T: Clone + Send + 'static,
1001    L: FnMut(D, &mut Vec<T>) + 'static,
1002{
1003    async fn eval(&mut self, input: &D) {
1004        self.eval_owned(input.clone()).await
1005    }
1006
1007    async fn eval_owned(&mut self, input: D) {
1008        self.input_batch_stats.add_batch(input.num_entries_deep());
1009
1010        debug_assert!(self.ready());
1011        self.outputs.clear();
1012        (self.partition)(input, &mut self.outputs);
1013        self.start_wait_usecs
1014            .store(current_time_usecs(), Ordering::Release);
1015
1016        let res = self.exchange.try_send_all(
1017            self.worker_index,
1018            &mut self.outputs.drain(..).map(|x| (x, self.flushed)),
1019        );
1020        self.flushed = false;
1021        debug_assert!(res);
1022    }
1023
1024    fn input_preference(&self) -> OwnershipPreference {
1025        OwnershipPreference::PREFER_OWNED
1026    }
1027}
1028
/// Operator that receives values sent by the `ExchangeSender` operator and
/// assembles them into a single output value.
///
/// The `init` closure returns the initial value for the result.  This value
/// is updated by the `combine` closure with each value received from a remote
/// peer.
///
/// See [`ExchangeSender`] documentation for details.
///
/// `ExchangeReceiver` is an asynchronous operator, i.e.,
/// [`ExchangeReceiver::is_async`] returns `true`.  It becomes schedulable
/// ([`ExchangeReceiver::ready`] returns `true`) once all peers have sent values
/// for this worker in the current clock cycle.  The scheduler should use
/// [`ExchangeReceiver::register_ready_callback`] to get notified when the
/// operator becomes schedulable.
pub struct ExchangeReceiver<IF, T, L>
where
    T: Send + 'static + Clone,
{
    // Index of the worker this operator belongs to; used as the receiver
    // index in all calls into `exchange`.
    worker_index: usize,
    // Value reported by `Operator::location`.
    location: OperatorLocation,
    // Produces the initial output value for each evaluation.
    init: IF,
    // Folds one received value into the output value.
    combine: L,
    // Exchange object shared with the paired sender.  Each exchanged value
    // travels together with a `bool` flush flag.
    exchange: Arc<Exchange<(T, bool)>>,
    // Number of received values whose flush flag was set since the counter
    // was last reset; reset to 0 once it equals the number of workers.
    flush_count: usize,
    // Set when `flush_count` reaches the number of workers; cleared by
    // `flush`.
    flush_complete: bool,
    // Timestamp (microseconds) stored by the paired sender when it produced
    // its outputs; the ready callback swaps it back to 0 when it accounts
    // for the wait.
    start_wait_usecs: Arc<AtomicU64>,
    // Accumulated wait time in microseconds, reported as
    // `EXCHANGE_WAIT_TIME_SECONDS`.
    total_wait_time: Arc<AtomicU64>,

    // Output batch sizes.
    output_batch_stats: BatchSizeStats,
}
1061
1062impl<IF, T, L> ExchangeReceiver<IF, T, L>
1063where
1064    T: Send + 'static + Clone,
1065{
1066    pub(crate) fn new(
1067        worker_index: usize,
1068        location: OperatorLocation,
1069        exchange: Arc<Exchange<(T, bool)>>,
1070        init: IF,
1071        start_wait_usecs: Arc<AtomicU64>,
1072        combine: L,
1073    ) -> Self {
1074        debug_assert!(worker_index < Runtime::num_workers());
1075
1076        Self {
1077            worker_index,
1078            location,
1079            init,
1080            combine,
1081            exchange,
1082            flush_count: 0,
1083            flush_complete: false,
1084            output_batch_stats: BatchSizeStats::new(),
1085            start_wait_usecs,
1086            total_wait_time: Arc::new(AtomicU64::new(0)),
1087        }
1088    }
1089}
1090
1091impl<D, T, L> Operator for ExchangeReceiver<D, T, L>
1092where
1093    D: 'static,
1094    T: Send + 'static + Clone,
1095    L: 'static,
1096{
1097    fn name(&self) -> Cow<'static, str> {
1098        Cow::from("ExchangeReceiver")
1099    }
1100
1101    fn location(&self) -> OperatorLocation {
1102        self.location
1103    }
1104
1105    fn metadata(&self, meta: &mut OperatorMeta) {
1106        meta.extend(metadata! {
1107            OUTPUT_BATCHES_STATS => self.output_batch_stats.metadata(),
1108            EXCHANGE_WAIT_TIME_SECONDS => MetaItem::Duration(Duration::from_micros(self.total_wait_time.load(Ordering::Acquire))),
1109            EXCHANGE_DESERIALIZATION_TIME_SECONDS => MetaItem::Duration(Duration::from_micros(self.exchange.inner.delivery_usecs.load(Ordering::Acquire))),
1110            EXCHANGE_DESERIALIZED_BYTES => MetaItem::bytes(self.exchange.inner.delivered_bytes.load(Ordering::Acquire)),
1111        });
1112    }
1113
1114    fn is_async(&self) -> bool {
1115        true
1116    }
1117
1118    fn register_ready_callback<F>(&mut self, cb: F)
1119    where
1120        F: Fn() + Send + Sync + 'static,
1121    {
1122        let start_wait_usecs = self.start_wait_usecs.clone();
1123        let total_wait_time = self.total_wait_time.clone();
1124        let exchange = self.exchange.clone();
1125        let worker_index = self.worker_index;
1126
1127        let cb = move || {
1128            if exchange.ready_to_receive(worker_index) {
1129                // The callback can be invoked multiple times per step.
1130                // Reset start_wait_usecs to 0 to make sure we don't double-count.
1131                let start = start_wait_usecs.swap(0, Ordering::Acquire);
1132                if start != 0 {
1133                    let end = current_time_usecs();
1134                    if end > start {
1135                        let wait_time_usecs = end - start;
1136                        // if worker_index == 0 {
1137                        //     info!(
1138                        //         "{worker_index}: {} +{wait_time_usecs}",
1139                        //         exchange.exchange_id()
1140                        //     );
1141                        // }
1142                        total_wait_time.fetch_add(wait_time_usecs, Ordering::AcqRel);
1143                    }
1144                }
1145            }
1146            cb()
1147        };
1148        self.exchange
1149            .register_receiver_callback(self.worker_index, cb)
1150    }
1151
1152    fn ready(&self) -> bool {
1153        self.exchange.ready_to_receive(self.worker_index)
1154    }
1155
1156    fn fixedpoint(&self, _scope: Scope) -> bool {
1157        true
1158    }
1159
1160    fn flush(&mut self) {
1161        // println!("{} exchange_receiver::flush", Runtime::worker_index());
1162        self.flush_complete = false;
1163    }
1164
1165    fn is_flush_complete(&self) -> bool {
1166        // println!(
1167        //     "{} exchange_receiver::is_flush_complete (flush_complete = {})",
1168        //     Runtime::worker_index(),
1169        //     self.flush_complete
1170        // );
1171        self.flush_complete
1172    }
1173}
1174
1175impl<D, IF, T, L> SourceOperator<D> for ExchangeReceiver<IF, T, L>
1176where
1177    D: NumEntries + 'static,
1178    T: Clone + Send + 'static,
1179    IF: Fn() -> D + 'static,
1180    L: Fn(&mut D, T) + 'static,
1181{
1182    async fn eval(&mut self) -> D {
1183        debug_assert!(self.ready());
1184        let mut combined = (self.init)();
1185        let res = self
1186            .exchange
1187            .try_receive_all(self.worker_index, |(x, flushed)| {
1188                // println!(
1189                //     "{} exchange_receiver::eval received input with flushed={:?}",
1190                //     Runtime::worker_index(),
1191                //     flushed
1192                // );
1193                if flushed {
1194                    self.flush_count += 1;
1195                }
1196                (self.combine)(&mut combined, x)
1197            });
1198        if self.flush_count == Runtime::num_workers() {
1199            // println!(
1200            //     "{} exchange_receiver::eval received all inputs",
1201            //     Runtime::worker_index()
1202            // );
1203
1204            self.flush_complete = true;
1205            self.flush_count = 0;
1206        }
1207
1208        debug_assert!(res);
1209
1210        self.output_batch_stats
1211            .add_batch(combined.num_entries_deep());
1212        combined
1213    }
1214}
1215
/// Zero-sized key type used to look up the shared [`Clients`] object in a
/// typed map tagged with [`LocalStoreMarker`] (presumably the runtime's local
/// store; verify against the call sites).
#[derive(Hash, PartialEq, Eq)]
struct ClientsId;

impl TypedMapKey<LocalStoreMarker> for ClientsId {
    // The value stored under this key is a reference-counted `Clients`.
    type Value = Arc<Clients>;
}
1222
/// Zero-sized key type used to look up the [`ExchangeDirectory`] in a typed
/// map tagged with [`LocalStoreMarker`] (presumably the runtime's local store;
/// verify against the call sites).
#[derive(Hash, PartialEq, Eq)]
struct DirectoryId;

impl TypedMapKey<LocalStoreMarker> for DirectoryId {
    // The value stored under this key is the `ExchangeDirectory` itself.
    type Value = ExchangeDirectory;
}
1229
1230/// Create an [`ExchangeSender`]/[`ExchangeReceiver`] operator pair.
1231///
1232/// See [`ExchangeSender`] documentation for details and example usage.
1233///
1234/// # Arguments
1235///
1236/// * `runtime` - [`Runtime`](`crate::circuit::Runtime`) within which operators
1237///   are created.
1238/// * `worker_index` - index of the current worker.
1239/// * `partition` - partitioning logic that must push exactly
1240///   `runtime.num_workers()` values into its vector argument
1241/// * `serialize` - serializes exchanged data for transmission across a network
1242/// * `deserialize` - deserializes exchanged data that was transmitted across a network
1243/// * `combine` - re-assemble logic that combines values received from all peers
1244///   into a single output value.
1245///
1246/// # Type arguments
1247/// * `TI` - Type of values in the input stream consumed by `ExchangeSender`.
1248/// * `TO` - Type of values in the output stream produced by `ExchangeReceiver`.
1249/// * `TE` - Type of values sent across workers.
1250/// * `PL` - Type of closure that splits a value of type `TI` into
1251///   `runtime.num_workers()` values of type `TE`.
1252/// * `I` - Iterator returned by `PL`.
1253/// * `IF` - Type of closure used to initialize the output value of type `TO`.
1254/// * `CL` - Type of closure that folds `num_workers` values of type `TE` into a
1255///   value of type `TO`.
1256pub fn new_exchange_operators<TI, TO, TE, IF, PL, CL, S, D>(
1257    location: OperatorLocation,
1258    init: IF,
1259    partition: PL,
1260    serialize: S,
1261    deserialize: D,
1262    combine: CL,
1263) -> Option<(ExchangeSender<TI, TE, PL>, ExchangeReceiver<IF, TE, CL>)>
1264where
1265    TO: Clone,
1266    TE: Send + 'static + Clone,
1267    IF: Fn() -> TO + 'static,
1268    PL: FnMut(TI, &mut Vec<TE>) + 'static,
1269    S: Fn(TE) -> Vec<u8> + Send + Sync + 'static,
1270    D: Fn(Vec<u8>) -> TE + Send + Sync + 'static,
1271    CL: Fn(&mut TO, TE) + 'static,
1272{
1273    if Runtime::num_workers() == 1 {
1274        return None;
1275    }
1276    let runtime = Runtime::runtime().unwrap();
1277    let worker_index = Runtime::worker_index();
1278
1279    let exchange_id = runtime.sequence_next();
1280    let start_wait_usecs = Arc::new(AtomicU64::new(0));
1281    let exchange = Exchange::with_runtime(
1282        &runtime,
1283        exchange_id,
1284        Box::new(move |(value, flush)| {
1285            let mut vec = serialize(value);
1286            vec.push(flush as u8);
1287            vec
1288        }),
1289        Box::new(move |mut vec| {
1290            let flush = match vec.pop().unwrap() {
1291                0 => false,
1292                1 => true,
1293                _ => unreachable!(),
1294            };
1295            (deserialize(vec), flush)
1296        }),
1297    );
1298    let sender = ExchangeSender::new(
1299        worker_index,
1300        location,
1301        exchange.clone(),
1302        start_wait_usecs.clone(),
1303        partition,
1304    );
1305    let receiver = ExchangeReceiver::new(
1306        worker_index,
1307        location,
1308        exchange,
1309        init,
1310        start_wait_usecs,
1311        combine,
1312    );
1313    Some((sender, receiver))
1314}
1315
#[cfg(test)]
mod tests {
    use super::Exchange;
    use crate::{
        Circuit, RootCircuit,
        circuit::{
            Runtime,
            schedule::{DynamicScheduler, Scheduler},
        },
        operator::{Generator, communication::new_exchange_operators},
        storage::file::{to_bytes, to_bytes_dyn},
        trace::unaligned_deserialize,
    };
    use std::thread::yield_now;

    // Miri is far too slow for the full run, so it gets a much smaller
    // number of rounds.
    const ROUNDS: usize = if cfg!(miri) { 128 } else { 2048 };

    // Drives a raw `Exchange` with `WORKERS` concurrent senders/receivers.
    // In round `N` every sender sends the value `N` to every receiver.
    // Sends and receives are retried until they succeed; each receiver must
    // end up with all values, in the correct order.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_exchange() {
        const WORKERS: usize = 16;

        let hruntime = Runtime::run(WORKERS, |_parker| {
            let exchange = Exchange::with_runtime(
                &Runtime::runtime().unwrap(),
                0,
                Box::new(|value| to_bytes(&value).unwrap().into_vec()),
                Box::new(|data| unaligned_deserialize(&data[..])),
            );

            for round in 0..ROUNDS {
                let expected = vec![round; WORKERS];

                // Keep retrying with the same iterator until the exchange
                // accepts the whole batch.
                let mut outgoing = expected.clone().into_iter();
                while !exchange.try_send_all(Runtime::worker_index(), &mut outgoing) {
                    yield_now();
                }

                // Likewise, poll until values from all peers have arrived.
                let mut received = Vec::with_capacity(WORKERS);
                while !exchange.try_receive_all(Runtime::worker_index(), |x| received.push(x)) {
                    yield_now();
                }

                assert_eq!(received, expected);
            }
        })
        .expect("failed to start runtime");

        hruntime.join().unwrap();
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_exchange_operators_dynamic() {
        test_exchange_operators::<DynamicScheduler>();
    }

    // Builds, for several worker counts, a circuit of the shape
    // `Generator - ExchangeSender -> ExchangeReceiver -> Inspect`:
    // `Generator` - yields sequential numbers 0, 1, 2, ...
    // `ExchangeSender` - sends each number to all peers.
    // `ExchangeReceiver` - combines all received numbers in a vector.
    // `Inspect` - validates the output of the receiver.
    fn test_exchange_operators<S>()
    where
        S: Scheduler + 'static,
    {
        fn do_test<S>(workers: usize)
        where
            S: Scheduler + 'static,
        {
            let hruntime = Runtime::run(workers, move |_parker| {
                let circuit = RootCircuit::build_with_scheduler::<_, _, S>(move |circuit| {
                    // Counter-backed source.
                    let mut next: usize = 0;
                    let source = circuit.add_source(Generator::new(move || {
                        let current = next;
                        next += 1;
                        current
                    }));

                    let (sender, receiver) = new_exchange_operators(
                        None,
                        Vec::new,
                        // Broadcast: one copy of the input per worker.
                        move |n, vals| {
                            for _ in 0..workers {
                                vals.push(n)
                            }
                        },
                        |value| to_bytes_dyn(&value).unwrap().into_vec(),
                        |data| unaligned_deserialize(&data[..]),
                        |v: &mut Vec<usize>, n| v.push(n),
                    )
                    .unwrap();

                    // In every round each worker must see `workers` copies
                    // of the round number.
                    let mut round = 0;
                    let combined = circuit.add_exchange(sender, receiver, &source);
                    combined.inspect(move |v| {
                        assert_eq!(&vec![round; workers], v);
                        round += 1;
                    });
                    Ok(())
                })
                .unwrap()
                .0;

                for _ in 1..ROUNDS {
                    circuit.transaction().unwrap();
                }
            })
            .expect("failed to start runtime");

            hruntime.join().unwrap();
        }

        for workers in [2, 16, 32] {
            do_test::<S>(workers);
        }
    }
}