shuttle/sync/
mpsc.rs

1//! Multi-producer, single-consumer FIFO queue communication primitives.
2
3use crate::runtime::execution::ExecutionState;
4use crate::runtime::task::clock::VectorClock;
5use crate::runtime::task::{TaskId, DEFAULT_INLINE_TASKS};
6use crate::runtime::thread;
7use smallvec::SmallVec;
8use std::cell::RefCell;
9use std::fmt::Debug;
10use std::rc::Rc;
11use std::result::Result;
12pub use std::sync::mpsc::{RecvError, RecvTimeoutError, SendError, TryRecvError, TrySendError};
13use std::sync::Arc;
14use std::time::Duration;
15use tracing::trace;
16
/// Inline capacity of the `SmallVec`s holding a channel's pending messages and (for bounded
/// channels) its queue of receiver clocks; larger contents spill to the heap.
const MAX_INLINE_MESSAGES: usize = 32;
18
19/// Create an unbounded channel
20pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
21    let channel = Arc::new(Channel::new(None));
22    let sender = Sender {
23        inner: Arc::clone(&channel),
24    };
25    let receiver = Receiver {
26        inner: Arc::clone(&channel),
27    };
28    (sender, receiver)
29}
30
31/// Create a bounded channel
32pub fn sync_channel<T>(bound: usize) -> (SyncSender<T>, Receiver<T>) {
33    let channel = Arc::new(Channel::new(Some(bound)));
34    let sender = SyncSender {
35        inner: Arc::clone(&channel),
36    };
37    let receiver = Receiver {
38        inner: Arc::clone(&channel),
39    };
40    (sender, receiver)
41}
42
/// Shared state behind a channel's sending halves (`Sender`/`SyncSender`) and its `Receiver`.
#[derive(Debug)]
struct Channel<T> {
    bound: Option<usize>, // None for an unbounded channel, Some(k) for a bounded channel of size k (k = 0: rendezvous)
    state: Rc<RefCell<ChannelState<T>>>, // mutable bookkeeping; see the `unsafe impl Send/Sync` for why Rc<RefCell<_>> is sound here
}
48
49// For tracking causality on channels, we timestamp each message with the clock of the sender.
// When the receiver gets the message, it updates its clock with the associated timestamp.
51// For unbounded channels, that's all the work we need to do.
52//
53// For bounded and rendezvous channels, things get a bit more interesting.
// Consider a bounded channel of depth K.  As soon as the sender successfully sends its (K+1)'th
// message, it knows that the receiver has received at least 1 message.  At this point, the
56// first receive event causally precedes the (K+1)'th send.  By the rule for vector clocks,
57//  (clock of the first receive)  <  (clock of the K+1'th send)
58// In order to ensure this ordering, we add a return queue of depth K to bounded channels.
59// Initially, this queue contains K empty vector clocks.  On each receive, we push the
60// receiver's clock at the time of the receive to the end of this queue.  Whenever the sender
61// successfully sends a message, it pops the clock at the front of the queue, and updates its
62// own clock with this value.  Thus, on the (K+1)'th send, the sender's clock will be updated
63// with the clock at the first receive, as needed.
64//
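// The story is similar for rendezvous channels, except that K=0 requires special handling:
// there is no return queue to pop from, so when the sender hands a message to a waiting
// receiver it updates its own clock directly with that receiver's (pre-incremented) clock.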
67
/// A message paired with the sender's vector clock at the moment it was sent.
struct TimestampedValue<T> {
    value: T,           // the message payload handed to the receiver
    clock: VectorClock, // sender's clock; merged into the receiver's clock on receipt
}
72
73impl<T> TimestampedValue<T> {
74    fn new(value: T, clock: VectorClock) -> Self {
75        Self { value, clock }
76    }
77}
78
79// Note: The channels in std::sync::mpsc only support a single Receiver (which cannot be
80// cloned).  The state below admits a more general use case, where multiple Senders
81// and Receivers can share a single channel.
82struct ChannelState<T> {
83    messages: SmallVec<[TimestampedValue<T>; MAX_INLINE_MESSAGES]>, // messages in the channel
84    receiver_clock: Option<SmallVec<[VectorClock; MAX_INLINE_MESSAGES]>>, // receiver vector clocks for bounded case
85    known_senders: usize,                                           // number of senders referencing this channel
86    known_receivers: usize,                                         // number or receivers referencing this channel
87    waiting_senders: SmallVec<[TaskId; DEFAULT_INLINE_TASKS]>,      // list of currently blocked senders
88    waiting_receivers: SmallVec<[TaskId; DEFAULT_INLINE_TASKS]>,    // list of currently blocked receivers
89}
90
91impl<T> Debug for ChannelState<T> {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        write!(f, "Channel {{ ")?;
94        write!(f, "num_messages: {} ", self.messages.len())?;
95        write!(
96            f,
97            "known_senders {} known_receivers {} ",
98            self.known_senders, self.known_receivers
99        )?;
100        write!(f, "waiting_senders: [{:?}] ", self.waiting_senders)?;
101        write!(f, "waiting_receivers: [{:?}] ", self.waiting_receivers)?;
102        write!(f, "}}")
103    }
104}
105
impl<T> Channel<T> {
    /// Creates the shared state for a new channel.
    ///
    /// `bound` is `None` for an unbounded channel and `Some(k)` for a bounded channel of
    /// capacity `k` (`k == 0` is a rendezvous channel). For bounded channels the receiver-clock
    /// return queue starts with `k` empty vector clocks. The counts start at one sender and one
    /// receiver, matching the pair of handles created by `channel`/`sync_channel`.
    fn new(bound: Option<usize>) -> Self {
        let receiver_clock = if let Some(bound) = bound {
            let mut s = SmallVec::with_capacity(bound);
            for _ in 0..bound {
                s.push(VectorClock::new());
            }
            Some(s)
        } else {
            None
        };
        Self {
            bound,
            state: Rc::new(RefCell::new(ChannelState {
                messages: SmallVec::new(),
                receiver_clock,
                known_senders: 1,
                known_receivers: 1,
                waiting_senders: SmallVec::new(),
                waiting_receivers: SmallVec::new(),
            })),
        }
    }

    /// Sends `message` without blocking.
    ///
    /// Returns `TrySendError::Full` when the send would have to block, and
    /// `TrySendError::Disconnected` when no receivers remain.
    fn try_send(&self, message: T) -> Result<(), TrySendError<T>> {
        self.send_internal(message, false)
    }

    /// Sends `message`, blocking the current task until it can be delivered.
    ///
    /// Returns `SendError` (carrying the message back) when no receivers remain.
    fn send(&self, message: T) -> Result<(), SendError<T>> {
        // With `can_block == true`, `send_internal` blocks instead of ever reporting `Full`.
        self.send_internal(message, true).map_err(|e| match e {
            TrySendError::Full(_) => unreachable!(),
            TrySendError::Disconnected(m) => SendError(m),
        })
    }

    /// Shared implementation of `send` and `try_send`.
    ///
    /// When `can_block` is false and the sender would have to wait, returns
    /// `TrySendError::Full(message)`. Returns `TrySendError::Disconnected(message)` if there are
    /// no receivers, either on entry or after being woken from a blocked state.
    fn send_internal(&self, message: T, can_block: bool) -> Result<(), TrySendError<T>> {
        let me = ExecutionState::me();
        let mut state = self.state.borrow_mut();

        trace!(
            state = ?state,
            "sender {:?} starting send on channel {:p}",
            me,
            self,
        );
        if state.known_receivers == 0 {
            // No receivers are left, so the channel is disconnected.  Stop and return failure.
            return Err(TrySendError::Disconnected(message));
        }

        let (is_rendezvous, is_full) = if let Some(bound) = self.bound {
            // For a rendezvous channel (bound = 0), "is_full" holds when there is a message in the channel.
            // For a non-rendezvous channel (bound > 0), "is_full" holds when the capacity is reached.
            // We cover both these cases at once using max(bound, 1) below.
            (bound == 0, state.messages.len() >= std::cmp::max(bound, 1))
        } else {
            // Unbounded channels are never full and never rendezvous.
            (false, false)
        };

        // The sender should block in any of the following situations:
        //    the channel is full (as defined above)
        //    there are already waiting senders (preserves FIFO order among senders)
        //    this is a rendezvous channel and there are no waiting receivers
        let sender_should_block =
            is_full || !state.waiting_senders.is_empty() || (is_rendezvous && state.waiting_receivers.is_empty());

        if sender_should_block {
            if !can_block {
                return Err(TrySendError::Full(message));
            }

            state.waiting_senders.push(me);
            trace!(
                state = ?state,
                "blocking sender {:?} on channel {:p}",
                me,
                self,
            );
            ExecutionState::with(|s| s.current_mut().block(false));
            // Release the RefCell borrow before switching so other tasks can touch the channel.
            drop(state);

            thread::switch();

            state = self.state.borrow_mut();
            trace!(
                state = ?state,
                "unblocked sender {:?} on channel {:p}",
                me,
                self,
            );

            // Check again that we still have a receiver; if not, return with error.
            // We repeat this check because the receivers may have disconnected while the sender was blocked.
            if state.known_receivers == 0 {
                state.waiting_senders.retain(|t| *t != me);
                // No receivers are left, so the channel is disconnected.  Stop and return failure.
                return Err(TrySendError::Disconnected(message));
            }

            // Only the head of the FIFO queue is ever unblocked, so it must be us.
            let head = state.waiting_senders.remove(0);
            assert_eq!(head, me);
        }

        // Timestamp the message with the sender's freshly incremented clock.
        ExecutionState::with(|s| {
            let clock = s.increment_clock();
            state.messages.push(TimestampedValue::new(message, clock.clone()));
        });

        // The sender has just added a message to the channel, so unblock the first waiting receiver if any
        if let Some(&tid) = state.waiting_receivers.first() {
            ExecutionState::with(|s| {
                s.get_mut(tid).unblock();

                // When a sender successfully sends on a rendezvous channel, it knows that the receiver will perform
                // the matching receive, so we need to update the sender's clock with the receiver's.
                if is_rendezvous {
                    let recv_clock = s.get_clock(tid).clone();
                    s.update_clock(&recv_clock);
                }
            });
        }
        // Check and unblock the next waiting sender, if eligible (never for rendezvous, since bound == 0)
        if let Some(&tid) = state.waiting_senders.first() {
            let bound = self.bound.expect("can't have waiting senders on an unbounded channel");
            if state.messages.len() < bound {
                ExecutionState::with(|s| s.get_mut(tid).unblock());
            }
        }

        // Bounded (non-rendezvous) channels: pop the oldest receiver clock from the return queue
        // so that this send is causally ordered after the matching earlier receive.
        if !is_rendezvous {
            if let Some(receiver_clock) = &mut state.receiver_clock {
                let recv_clock = receiver_clock.remove(0);
                ExecutionState::with(|s| s.update_clock(&recv_clock));
            }
        }

        Ok(())
    }

    /// Receives a message, blocking the current task until one is available.
    ///
    /// Returns `RecvError` when the channel is empty and no senders remain.
    fn recv(&self) -> Result<T, RecvError> {
        // With `can_block == true`, `recv_internal` blocks instead of ever reporting `Empty`.
        self.recv_internal(true).map_err(|e| match e {
            TryRecvError::Disconnected => RecvError,
            TryRecvError::Empty => unreachable!(),
        })
    }

    /// Receives a message if one can be obtained without waiting for new senders.
    ///
    /// Returns `TryRecvError::Empty` when nothing is available and
    /// `TryRecvError::Disconnected` when the channel is empty and no senders remain.
    fn try_recv(&self) -> Result<T, TryRecvError> {
        self.recv_internal(false)
    }

    /// Shared implementation of `recv` and `try_recv`.
    ///
    /// NOTE(review): for a rendezvous channel with a sender already waiting, even
    /// `can_block == false` takes the blocking path below until that sender hands over its message.
    fn recv_internal(&self, can_block: bool) -> Result<T, TryRecvError> {
        let me = ExecutionState::me();
        let mut state = self.state.borrow_mut();

        trace!(
            state = ?state,
            "starting recv on channel {:p}",
            self,
        );
        // Check if there are any senders left; if not, and the channel is empty, fail with error
        // (If there are no senders, but the channel is nonempty, the receiver can successfully consume that message.)
        if state.messages.is_empty() && state.known_senders == 0 {
            return Err(TryRecvError::Disconnected);
        }

        let is_rendezvous = self.bound == Some(0);
        // If this is a rendezvous channel, and the channel is empty, and there are waiting senders,
        // notify the first waiting sender
        if is_rendezvous && state.messages.is_empty() {
            if let Some(&tid) = state.waiting_senders.first() {
                // Note: another receiver may have unblocked the sender already
                ExecutionState::with(|s| s.get_mut(tid).unblock());
            } else if !can_block {
                // Nobody to rendezvous with
                return Err(TryRecvError::Empty);
            }
        }

        // Handle the try_recv case, accounting for the number of msgs available and already waiting receivers.
        if !is_rendezvous && !can_block && state.waiting_receivers.len() >= state.messages.len() {
            return Err(TryRecvError::Empty);
        }

        // Pre-increment the receiver's clock before continuing
        //
        // Note: The reason for pre-incrementing the receiver's clock is to deal properly with rendezvous channels.
        // Here's the scenario we have to handle:
        //   1. the receiver arrives at a rendezvous channel and blocks
        //   2. the sender arrives, sees the receiver is waiting and does not block
        //   3. the sender drops the message in the channel and updates its clock with the receiver's clock and continues
        //   4. later, the receiver unblocks and picks up the message and updates its clock with the sender's
        // Without the pre-increment, in step 3, the sender would update its clock with the receiver's clock before
        // it is incremented.  (The increment records the fact that the receiver arrived at the synchronization point.)
        ExecutionState::with(|s| {
            let _ = s.increment_clock();
        });

        // The receiver should block in any of the following situations:
        //    the channel is empty
        //    there are waiting receivers (preserves FIFO order among receivers)
        let should_block = state.messages.is_empty() || !state.waiting_receivers.is_empty();
        if should_block {
            state.waiting_receivers.push(me);
            trace!(
                state = ?state,
                "blocking receiver {:?} on channel {:p}",
                me,
                self,
            );
            ExecutionState::with(|s| s.current_mut().block(false));
            // Release the RefCell borrow before switching so other tasks can touch the channel.
            drop(state);

            thread::switch();

            state = self.state.borrow_mut();
            trace!(
                state = ?state,
                "unblocked receiver {:?} on channel {:p}",
                me,
                self,
            );

            // Check again if there are any senders left; if not, and the channel is empty, fail with error
            // (If there are no senders, but the channel is nonempty, the receiver can successfully consume that message.)
            // We repeat this check because the senders may have disconnected while the receiver was blocked.
            if state.messages.is_empty() && state.known_senders == 0 {
                state.waiting_receivers.retain(|t| *t != me);
                return Err(TryRecvError::Disconnected);
            }

            // Only the head of the FIFO queue is ever unblocked, so it must be us.
            let head = state.waiting_receivers.remove(0);
            assert_eq!(head, me);
        }

        // Take the oldest message.
        let item = state.messages.remove(0);
        // The receiver has just removed an element from the channel.  Check if any waiting senders
        // need to be notified.
        if let Some(&tid) = state.waiting_senders.first() {
            let bound = self.bound.expect("can't have waiting senders on an unbounded channel");
            // Unblock the first waiting sender provided one of the following conditions hold:
            // - this is a non-rendezvous bounded channel (bound > 0)
            // - this is a rendezvous channel and we have additional waiting receivers
            if bound > 0 || !state.waiting_receivers.is_empty() {
                ExecutionState::with(|s| s.get_mut(tid).unblock());
            }
        }
        // Check and unblock the next waiting receiver, if eligible
        // Note: this is a no-op for mpsc channels, since there can only be one receiver
        if let Some(&tid) = state.waiting_receivers.first() {
            if !state.messages.is_empty() {
                ExecutionState::with(|s| s.get_mut(tid).unblock());
            }
        }

        // Update receiver clock from the clock attached to the message received
        let TimestampedValue { value, clock } = item;
        ExecutionState::with(|s| {
            // Since we already incremented the receiver's clock above, just update it here
            s.get_clock_mut(me).update(&clock);

            // If this is a (non-rendezvous) bounded channel, propagate causality backwards to sender
            if let Some(receiver_clock) = &mut state.receiver_clock {
                let bound = self.bound.expect("unexpected internal error"); // must be defined for bounded channels
                if bound > 0 {
                    // non-rendezvous: sanity-check that the return queue never exceeds `bound` clocks
                    assert!(receiver_clock.len() < bound);
                    receiver_clock.push(s.get_clock(me).clone());
                }
            }
        });
        Ok(value)
    }
}
379
// SAFETY: A Channel is never actually passed across true threads, only across continuations. The
// Rc<RefCell<_>> type therefore can't be preempted mid-bookkeeping-operation.
// Both impls depend on that single-OS-thread assumption: `Rc` is neither `Send` nor `Sync`, and
// `RefCell` is not `Sync`, so these impls would be unsound if a Channel ever crossed real threads.
// TODO We use this workaround in several places in Shuttle.  Maybe there's a cleaner solution.
unsafe impl<T: Send> Send for Channel<T> {}
unsafe impl<T: Send> Sync for Channel<T> {}
385
/// The receiving half of Rust's [`channel`] (or [`sync_channel`]) type.
/// This half can only be owned by one thread.
///
/// Dropping the last `Receiver` disconnects the channel: blocked senders are woken and
/// subsequent sends fail.
#[derive(Debug)]
pub struct Receiver<T> {
    inner: Arc<Channel<T>>, // shared channel state, also referenced by the sending halves
}
392
393impl<T> Receiver<T> {
394    /// Attempts to wait for a value on this receiver, returning an error if the
395    /// corresponding channel has hung up.
396    pub fn recv(&self) -> Result<T, RecvError> {
397        self.inner.recv()
398    }
399
400    /// Attempts to wait for a value on this receiver, returning an error if the
401    /// corresponding channel has hung up.
402    pub fn try_recv(&self) -> Result<T, TryRecvError> {
403        self.inner.try_recv()
404    }
405
406    /// Attempts to wait for a value on this receiver, returning an error if the
407    /// corresponding channel has hung up, or if it waits more than timeout.
408    pub fn recv_timeout(&self, _timeout: Duration) -> Result<T, RecvTimeoutError> {
409        // TODO support the timeout case -- this method never times out
410        self.inner.recv().map_err(|_| RecvTimeoutError::Disconnected)
411    }
412
413    /// Returns an iterator that will block waiting for messages, but never
414    /// [`panic!`]. It will return [`None`] when the channel has hung up.
415    pub fn iter(&self) -> Iter<'_, T> {
416        Iter { rx: self }
417    }
418
419    /// Returns an iterator that will attempt to yield all pending values.
420    /// It will return `None` if there are no more pending values or if the
421    /// channel has hung up. The iterator will never [`panic!`] or block the
422    /// user by waiting for values.
423    pub fn try_iter(&self) -> TryIter<'_, T> {
424        TryIter { rx: self }
425    }
426}
427
428impl<T> Drop for Receiver<T> {
429    fn drop(&mut self) {
430        if ExecutionState::should_stop() {
431            return;
432        }
433        let mut state = self.inner.state.borrow_mut();
434        assert!(state.known_receivers > 0);
435        state.known_receivers -= 1;
436        if state.known_receivers == 0 {
437            // Last receiver was dropped; wake up all senders
438            for &tid in state.waiting_senders.iter() {
439                ExecutionState::with(|s| s.get_mut(tid).unblock());
440            }
441        }
442    }
443}
444
/// An iterator over messages on a [`Receiver`], created by [`iter`].
///
/// This iterator will block whenever [`next`] is called,
/// waiting for a new message, and [`None`] will be returned
/// when the corresponding channel has hung up.
///
/// Each call to [`next`] delegates to [`Receiver::recv`]; any receive error is reported
/// as [`None`].
///
/// [`iter`]: Receiver::iter
/// [`next`]: Iterator::next
#[derive(Debug)]
pub struct Iter<'a, T: 'a> {
    rx: &'a Receiver<T>, // borrowed receiver that `next` reads from
}
457
/// An iterator that attempts to yield all pending values for a [`Receiver`],
/// created by [`try_iter`].
///
/// [`None`] will be returned when there are no pending values remaining or
/// if the corresponding channel has hung up.
///
/// This iterator will never block the caller in order to wait for data to
/// become available. Instead, it will return [`None`].
///
/// Each call to `next` delegates to [`Receiver::try_recv`].
///
/// [`try_iter`]: Receiver::try_iter
#[derive(Debug)]
pub struct TryIter<'a, T: 'a> {
    rx: &'a Receiver<T>, // borrowed receiver that `next` polls with `try_recv`
}
472
/// An owning iterator over messages on a [`Receiver`],
/// created by [`into_iter`].
///
/// This iterator will block whenever [`next`]
/// is called, waiting for a new message, and [`None`] will be
/// returned if the corresponding channel has hung up.
///
/// Because it owns the receiver, dropping this iterator drops the receiver too.
///
/// [`into_iter`]: Receiver::into_iter
/// [`next`]: Iterator::next
#[derive(Debug)]
pub struct IntoIter<T> {
    rx: Receiver<T>, // owned receiver that `next` reads from
}
486
487impl<T> Iterator for Iter<'_, T> {
488    type Item = T;
489
490    fn next(&mut self) -> Option<T> {
491        self.rx.recv().ok()
492    }
493}
494
495impl<T> Iterator for TryIter<'_, T> {
496    type Item = T;
497
498    fn next(&mut self) -> Option<T> {
499        self.rx.try_recv().ok()
500    }
501}
502
503impl<'a, T> IntoIterator for &'a Receiver<T> {
504    type Item = T;
505    type IntoIter = Iter<'a, T>;
506
507    fn into_iter(self) -> Iter<'a, T> {
508        self.iter()
509    }
510}
511
512impl<T> Iterator for IntoIter<T> {
513    type Item = T;
514    fn next(&mut self) -> Option<T> {
515        self.rx.recv().ok()
516    }
517}
518
519impl<T> IntoIterator for Receiver<T> {
520    type Item = T;
521    type IntoIter = IntoIter<T>;
522
523    fn into_iter(self) -> IntoIter<T> {
524        IntoIter { rx: self }
525    }
526}
527
/// The sending-half of Rust's asynchronous [`channel`] type. This half can only be
/// owned by one thread, but it can be cloned to send to other threads.
///
/// Cloning increments the channel's sender count and dropping decrements it; once the last
/// sender is gone, blocked receivers are woken and an empty channel reports disconnection.
#[derive(Debug)]
pub struct Sender<T> {
    inner: Arc<Channel<T>>, // shared channel state, also referenced by the receiver
}
534
535impl<T> Sender<T> {
536    /// Attempts to send a value on this channel, returning it back if it could
537    /// not be sent.
538    pub fn send(&self, t: T) -> Result<(), SendError<T>> {
539        self.inner.send(t)
540    }
541}
542
543impl<T> Clone for Sender<T> {
544    fn clone(&self) -> Self {
545        let mut state = self.inner.state.borrow_mut();
546        state.known_senders += 1;
547        drop(state);
548        Self {
549            inner: self.inner.clone(),
550        }
551    }
552}
553
554impl<T> Drop for Sender<T> {
555    fn drop(&mut self) {
556        if ExecutionState::should_stop() {
557            return;
558        }
559        let mut state = self.inner.state.borrow_mut();
560        assert!(state.known_senders > 0);
561        state.known_senders -= 1;
562        if state.known_senders == 0 {
563            // Last sender was dropped; wake up all receivers
564            for &tid in state.waiting_receivers.iter() {
565                ExecutionState::with(|s| s.get_mut(tid).unblock());
566            }
567        }
568    }
569}
570
571/// The sending-half of Rust's synchronous [`sync_channel`] type.
572///
573/// Messages can be sent through this channel with [`SyncSender::send`] or \[`try_send`\] (TODO)
574///
575/// [`SyncSender::send`] will block if there is no space in the internal buffer.
576#[derive(Debug)]
577pub struct SyncSender<T> {
578    inner: Arc<Channel<T>>,
579}
580
581impl<T> SyncSender<T> {
582    /// Sends a value on this synchronous channel.
583    ///
584    /// This function will *block* until space in the internal buffer becomes
585    /// available or a receiver is available to hand off the message to.
586    pub fn send(&self, t: T) -> Result<(), SendError<T>> {
587        self.inner.send(t)
588    }
589
590    /// Attempts to send a value on this channel without blocking.
591    ///
592    /// This method differs from [`send`] by returning immediately if the
593    /// channel's buffer is full or no receiver is waiting to acquire some
594    /// data. Compared with [`send`], this function has two failure cases
595    /// instead of one (one for disconnection, one for a full buffer).
596    ///
597    /// [`send`]: Self::send
598    pub fn try_send(&self, t: T) -> Result<(), TrySendError<T>> {
599        self.inner.try_send(t)
600    }
601}
602
603impl<T> Clone for SyncSender<T> {
604    fn clone(&self) -> Self {
605        let mut state = self.inner.state.borrow_mut();
606        state.known_senders += 1;
607        drop(state);
608        Self {
609            inner: self.inner.clone(),
610        }
611    }
612}
613
614impl<T> Drop for SyncSender<T> {
615    fn drop(&mut self) {
616        if ExecutionState::should_stop() {
617            return;
618        }
619        let mut state = self.inner.state.borrow_mut();
620        assert!(state.known_senders > 0);
621        state.known_senders -= 1;
622        if state.known_senders == 0 {
623            // Last sender was dropped; wake up any receivers
624            for &tid in state.waiting_receivers.iter() {
625                ExecutionState::with(|s| s.get_mut(tid).unblock());
626            }
627        }
628    }
629}