shuttle/sync/mpsc.rs
//! Multi-producer, single-consumer FIFO queue communication primitives.

use crate::runtime::execution::ExecutionState;
use crate::runtime::task::clock::VectorClock;
use crate::runtime::task::{TaskId, DEFAULT_INLINE_TASKS};
use crate::runtime::thread;
use smallvec::SmallVec;
use std::cell::RefCell;
use std::fmt::Debug;
use std::rc::Rc;
use std::result::Result;
pub use std::sync::mpsc::{RecvError, RecvTimeoutError, SendError, TryRecvError, TrySendError};
use std::sync::Arc;
use std::time::Duration;
use tracing::trace;

const MAX_INLINE_MESSAGES: usize = 32;

/// Create an unbounded channel
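///
/// A minimal usage sketch (illustrative, not from the original docs): the channel can
/// only be used inside a Shuttle execution, so this assumes a test driver such as
/// `shuttle::check_random` and the crate's `shuttle::thread` shim.
///
/// ```no_run
/// use shuttle::sync::mpsc::channel;
/// use shuttle::thread;
///
/// // Sketch only: `check_random` is assumed to drive the closure under Shuttle's scheduler.
/// shuttle::check_random(|| {
///     let (tx, rx) = channel();
///     thread::spawn(move || {
///         tx.send(42).unwrap();
///     });
///     assert_eq!(rx.recv().unwrap(), 42);
/// }, 10);
/// ```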
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
    let channel = Arc::new(Channel::new(None));
    let sender = Sender {
        inner: Arc::clone(&channel),
    };
    let receiver = Receiver {
        inner: Arc::clone(&channel),
    };
    (sender, receiver)
}

/// Create a bounded channel
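///
/// A minimal usage sketch (illustrative, not from the original docs), again assuming the
/// `shuttle::check_random` driver; `sync_channel(0)` would instead create a rendezvous
/// channel on which every send blocks until a receiver is ready.
///
/// ```no_run
/// use shuttle::sync::mpsc::sync_channel;
/// use shuttle::thread;
///
/// // Sketch only: a bounded channel with room for one in-flight message.
/// shuttle::check_random(|| {
///     let (tx, rx) = sync_channel(1);
///     let handle = thread::spawn(move || {
///         tx.send(1).unwrap(); // fits in the buffer, never blocks
///         tx.send(2).unwrap(); // may block until the receiver drains the buffer
///     });
///     assert_eq!(rx.recv().unwrap(), 1);
///     assert_eq!(rx.recv().unwrap(), 2);
///     handle.join().unwrap();
/// }, 10);
/// ```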
pub fn sync_channel<T>(bound: usize) -> (SyncSender<T>, Receiver<T>) {
    let channel = Arc::new(Channel::new(Some(bound)));
    let sender = SyncSender {
        inner: Arc::clone(&channel),
    };
    let receiver = Receiver {
        inner: Arc::clone(&channel),
    };
    (sender, receiver)
}

#[derive(Debug)]
struct Channel<T> {
    bound: Option<usize>, // None for an unbounded channel, Some(k) for a bounded channel of size k
    state: Rc<RefCell<ChannelState<T>>>,
}

// For tracking causality on channels, we timestamp each message with the clock of the sender.
// When the receiver gets the message, it updates its clock with the associated timestamp.
// For unbounded channels, that's all the work we need to do.
//
// For bounded and rendezvous channels, things get a bit more interesting.
// Consider a bounded channel of depth K. As soon as the sender successfully sends its K+1'th
// message, it knows that the receiver has received at least 1 message. At this point, the
// first receive event causally precedes the (K+1)'th send. By the rule for vector clocks,
//     (clock of the first receive) < (clock of the K+1'th send)
// In order to ensure this ordering, we add a return queue of depth K to bounded channels.
// Initially, this queue contains K empty vector clocks. On each receive, we push the
// receiver's clock at the time of the receive to the end of this queue. Whenever the sender
// successfully sends a message, it pops the clock at the front of the queue, and updates its
// own clock with this value. Thus, on the (K+1)'th send, the sender's clock will be updated
// with the clock at the first receive, as needed.
//
// The story is similar for rendezvous channels, except we have to handle things a bit more
// specially because K=0.
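//
// Worked example (hypothetical clock values, K = 1, sender's component listed first):
//   - The return queue starts as one empty clock.
//   - Send #1 pops the empty clock (a no-op update) and deposits a message stamped
//     with the sender's clock, say [1, 0].
//   - Receive #1 consumes that message and pushes the receiver's clock, say [1, 1],
//     onto the return queue.
//   - Send #2 pops [1, 1] and joins it into the sender's clock, so the clock of
//     receive #1 precedes the clock of send #2, exactly as required above.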

struct TimestampedValue<T> {
    value: T,
    clock: VectorClock,
}

impl<T> TimestampedValue<T> {
    fn new(value: T, clock: VectorClock) -> Self {
        Self { value, clock }
    }
}

// Note: The channels in std::sync::mpsc only support a single Receiver (which cannot be
// cloned). The state below admits a more general use case, where multiple Senders
// and Receivers can share a single channel.
struct ChannelState<T> {
    messages: SmallVec<[TimestampedValue<T>; MAX_INLINE_MESSAGES]>, // messages in the channel
    receiver_clock: Option<SmallVec<[VectorClock; MAX_INLINE_MESSAGES]>>, // receiver vector clocks for bounded case
    known_senders: usize, // number of senders referencing this channel
    known_receivers: usize, // number of receivers referencing this channel
    waiting_senders: SmallVec<[TaskId; DEFAULT_INLINE_TASKS]>, // list of currently blocked senders
    waiting_receivers: SmallVec<[TaskId; DEFAULT_INLINE_TASKS]>, // list of currently blocked receivers
}

impl<T> Debug for ChannelState<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Channel {{ ")?;
        write!(f, "num_messages: {} ", self.messages.len())?;
        write!(
            f,
            "known_senders {} known_receivers {} ",
            self.known_senders, self.known_receivers
        )?;
        write!(f, "waiting_senders: [{:?}] ", self.waiting_senders)?;
        write!(f, "waiting_receivers: [{:?}] ", self.waiting_receivers)?;
        write!(f, "}}")
    }
}

impl<T> Channel<T> {
    fn new(bound: Option<usize>) -> Self {
        let receiver_clock = if let Some(bound) = bound {
            let mut s = SmallVec::with_capacity(bound);
            for _ in 0..bound {
                s.push(VectorClock::new());
            }
            Some(s)
        } else {
            None
        };
        Self {
            bound,
            state: Rc::new(RefCell::new(ChannelState {
                messages: SmallVec::new(),
                receiver_clock,
                known_senders: 1,
                known_receivers: 1,
                waiting_senders: SmallVec::new(),
                waiting_receivers: SmallVec::new(),
            })),
        }
    }

    fn try_send(&self, message: T) -> Result<(), TrySendError<T>> {
        self.send_internal(message, false)
    }

    fn send(&self, message: T) -> Result<(), SendError<T>> {
        self.send_internal(message, true).map_err(|e| match e {
            TrySendError::Full(_) => unreachable!(),
            TrySendError::Disconnected(m) => SendError(m),
        })
    }

    fn send_internal(&self, message: T, can_block: bool) -> Result<(), TrySendError<T>> {
        let me = ExecutionState::me();
        let mut state = self.state.borrow_mut();

        trace!(
            state = ?state,
            "sender {:?} starting send on channel {:p}",
            me,
            self,
        );
        if state.known_receivers == 0 {
            // No receivers are left, so the channel is disconnected. Stop and return failure.
            return Err(TrySendError::Disconnected(message));
        }

        let (is_rendezvous, is_full) = if let Some(bound) = self.bound {
            // For a rendezvous channel (bound = 0), "is_full" holds when there is a message in the channel.
            // For a non-rendezvous channel (bound > 0), "is_full" holds when the capacity is reached.
            // We cover both these cases at once using max(bound, 1) below.
            (bound == 0, state.messages.len() >= std::cmp::max(bound, 1))
        } else {
            (false, false)
        };

        // The sender should block in any of the following situations:
        //     the channel is full (as defined above)
        //     there are already waiting senders
        //     this is a rendezvous channel and there are no waiting receivers
        let sender_should_block =
            is_full || !state.waiting_senders.is_empty() || (is_rendezvous && state.waiting_receivers.is_empty());

        if sender_should_block {
            if !can_block {
                return Err(TrySendError::Full(message));
            }

            state.waiting_senders.push(me);
            trace!(
                state = ?state,
                "blocking sender {:?} on channel {:p}",
                me,
                self,
            );
            ExecutionState::with(|s| s.current_mut().block(false));
            drop(state);

            thread::switch();

            state = self.state.borrow_mut();
            trace!(
                state = ?state,
                "unblocked sender {:?} on channel {:p}",
                me,
                self,
            );

            // Check again that we still have a receiver; if not, return with error.
            // We repeat this check because the receivers may have disconnected while the sender was blocked.
            if state.known_receivers == 0 {
                state.waiting_senders.retain(|t| *t != me);
                // No receivers are left, so the channel is disconnected. Stop and return failure.
                return Err(TrySendError::Disconnected(message));
            }

            let head = state.waiting_senders.remove(0);
            assert_eq!(head, me);
        }

        ExecutionState::with(|s| {
            let clock = s.increment_clock();
            state.messages.push(TimestampedValue::new(message, clock.clone()));
        });

        // The sender has just added a message to the channel, so unblock the first waiting receiver if any
        if let Some(&tid) = state.waiting_receivers.first() {
            ExecutionState::with(|s| {
                s.get_mut(tid).unblock();

                // When a sender successfully sends on a rendezvous channel, it knows that the receiver will perform
                // the matching receive, so we need to update the sender's clock with the receiver's.
                if is_rendezvous {
                    let recv_clock = s.get_clock(tid).clone();
                    s.update_clock(&recv_clock);
                }
            });
        }
        // Check and unblock the next waiting sender, if eligible
        if let Some(&tid) = state.waiting_senders.first() {
            let bound = self.bound.expect("can't have waiting senders on an unbounded channel");
            if state.messages.len() < bound {
                ExecutionState::with(|s| s.get_mut(tid).unblock());
            }
        }

        if !is_rendezvous {
            if let Some(receiver_clock) = &mut state.receiver_clock {
                let recv_clock = receiver_clock.remove(0);
                ExecutionState::with(|s| s.update_clock(&recv_clock));
            }
        }

        Ok(())
    }

    fn recv(&self) -> Result<T, RecvError> {
        self.recv_internal(true).map_err(|e| match e {
            TryRecvError::Disconnected => RecvError,
            TryRecvError::Empty => unreachable!(),
        })
    }

    fn try_recv(&self) -> Result<T, TryRecvError> {
        self.recv_internal(false)
    }

    fn recv_internal(&self, can_block: bool) -> Result<T, TryRecvError> {
        let me = ExecutionState::me();
        let mut state = self.state.borrow_mut();

        trace!(
            state = ?state,
            "starting recv on channel {:p}",
            self,
        );
        // Check if there are any senders left; if not, and the channel is empty, fail with error
        // (If there are no senders, but the channel is nonempty, the receiver can successfully consume that message.)
        if state.messages.is_empty() && state.known_senders == 0 {
            return Err(TryRecvError::Disconnected);
        }

        let is_rendezvous = self.bound == Some(0);
        // If this is a rendezvous channel, and the channel is empty, and there are waiting senders,
        // notify the first waiting sender
        if is_rendezvous && state.messages.is_empty() {
            if let Some(&tid) = state.waiting_senders.first() {
                // Note: another receiver may have unblocked the sender already
                ExecutionState::with(|s| s.get_mut(tid).unblock());
            } else if !can_block {
                // Nobody to rendezvous with
                return Err(TryRecvError::Empty);
            }
        }

        // Handle the try_recv case, accounting for the number of msgs available and already waiting receivers.
        if !is_rendezvous && !can_block && state.waiting_receivers.len() >= state.messages.len() {
            return Err(TryRecvError::Empty);
        }

        // Pre-increment the receiver's clock before continuing
        //
        // Note: The reason for pre-incrementing the receiver's clock is to deal properly with rendezvous channels.
        // Here's the scenario we have to handle:
        //   1. the receiver arrives at a rendezvous channel and blocks
        //   2. the sender arrives, sees the receiver is waiting and does not block
        //   3. the sender drops the message in the channel and updates its clock with the receiver's clock and continues
        //   4. later, the receiver unblocks and picks up the message and updates its clock with the sender's
        // Without the pre-increment, in step 3, the sender would update its clock with the receiver's clock before
        // it is incremented. (The increment records the fact that the receiver arrived at the synchronization point.)
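        //
        // A hypothetical trace with illustrative clock values (sender's component first):
        //   1. the receiver arrives and pre-increments: receiver clock = [0, 1], then blocks
        //   2. the sender sends and joins the waiting receiver's clock: sender clock = [1, 1]
        //   3. the receiver unblocks and joins the message's timestamp [1, 0]: receiver clock = [1, 1]
        // Without the pre-increment, the sender would only see [0, 0] in step 2, so the receiver's
        // arrival at the rendezvous would not be ordered before the sender's subsequent events.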
        ExecutionState::with(|s| {
            let _ = s.increment_clock();
        });

        // The receiver should block in any of the following situations:
        //     the channel is empty
        //     there are waiting receivers
        let should_block = state.messages.is_empty() || !state.waiting_receivers.is_empty();
        if should_block {
            state.waiting_receivers.push(me);
            trace!(
                state = ?state,
                "blocking receiver {:?} on channel {:p}",
                me,
                self,
            );
            ExecutionState::with(|s| s.current_mut().block(false));
            drop(state);

            thread::switch();

            state = self.state.borrow_mut();
            trace!(
                state = ?state,
                "unblocked receiver {:?} on channel {:p}",
                me,
                self,
            );

            // Check again if there are any senders left; if not, and the channel is empty, fail with error
            // (If there are no senders, but the channel is nonempty, the receiver can successfully consume that message.)
            // We repeat this check because the senders may have disconnected while the receiver was blocked.
            if state.messages.is_empty() && state.known_senders == 0 {
                state.waiting_receivers.retain(|t| *t != me);
                return Err(TryRecvError::Disconnected);
            }

            let head = state.waiting_receivers.remove(0);
            assert_eq!(head, me);
        }

        let item = state.messages.remove(0);
        // The receiver has just removed an element from the channel. Check if any waiting senders
        // need to be notified.
        if let Some(&tid) = state.waiting_senders.first() {
            let bound = self.bound.expect("can't have waiting senders on an unbounded channel");
            // Unblock the first waiting sender provided one of the following conditions holds:
            // - this is a non-rendezvous bounded channel (bound > 0)
            // - this is a rendezvous channel and we have additional waiting receivers
            if bound > 0 || !state.waiting_receivers.is_empty() {
                ExecutionState::with(|s| s.get_mut(tid).unblock());
            }
        }
        // Check and unblock the next waiting receiver, if eligible
        // Note: this is a no-op for mpsc channels, since there can only be one receiver
        if let Some(&tid) = state.waiting_receivers.first() {
            if !state.messages.is_empty() {
                ExecutionState::with(|s| s.get_mut(tid).unblock());
            }
        }

        // Update receiver clock from the clock attached to the message received
        let TimestampedValue { value, clock } = item;
        ExecutionState::with(|s| {
            // Since we already incremented the receiver's clock above, just update it here
            s.get_clock_mut(me).update(&clock);

            // If this is a (non-rendezvous) bounded channel, propagate causality backwards to sender
            if let Some(receiver_clock) = &mut state.receiver_clock {
                let bound = self.bound.expect("unexpected internal error"); // must be defined for bounded channels
                if bound > 0 {
                    // non-rendezvous
                    assert!(receiver_clock.len() < bound);
                    receiver_clock.push(s.get_clock(me).clone());
                }
            }
        });
        Ok(value)
    }
}

// Safety: A Channel is never actually passed across true threads, only across continuations. The
// Rc<RefCell<_>> type therefore can't be preempted mid-bookkeeping-operation.
// TODO We use this workaround in several places in Shuttle. Maybe there's a cleaner solution.
unsafe impl<T: Send> Send for Channel<T> {}
unsafe impl<T: Send> Sync for Channel<T> {}

/// The receiving half of Rust's [`channel`] (or [`sync_channel`]) type.
/// This half can only be owned by one thread.
#[derive(Debug)]
pub struct Receiver<T> {
    inner: Arc<Channel<T>>,
}

impl<T> Receiver<T> {
    /// Attempts to wait for a value on this receiver, returning an error if the
    /// corresponding channel has hung up.
    pub fn recv(&self) -> Result<T, RecvError> {
        self.inner.recv()
    }

    /// Attempts to return a pending value on this receiver without blocking,
    /// returning an error if the channel is empty or has hung up.
    pub fn try_recv(&self) -> Result<T, TryRecvError> {
        self.inner.try_recv()
    }

    /// Attempts to wait for a value on this receiver, returning an error if the
    /// corresponding channel has hung up, or if it waits longer than `timeout`.
    pub fn recv_timeout(&self, _timeout: Duration) -> Result<T, RecvTimeoutError> {
        // TODO support the timeout case -- this method never times out
        self.inner.recv().map_err(|_| RecvTimeoutError::Disconnected)
    }

    /// Returns an iterator that will block waiting for messages, but never
    /// [`panic!`]. It will return [`None`] when the channel has hung up.
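    ///
    /// An illustrative sketch (not from the original docs; assumes the `shuttle::check_random`
    /// driver):
    ///
    /// ```no_run
    /// use shuttle::sync::mpsc::channel;
    /// use shuttle::thread;
    ///
    /// shuttle::check_random(|| {
    ///     let (tx, rx) = channel();
    ///     thread::spawn(move || {
    ///         for i in 0..3 {
    ///             tx.send(i).unwrap();
    ///         }
    ///         // `tx` is dropped here, so the iterator below terminates.
    ///     });
    ///     let received: Vec<_> = rx.iter().collect();
    ///     assert_eq!(received, vec![0, 1, 2]);
    /// }, 10);
    /// ```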
    pub fn iter(&self) -> Iter<'_, T> {
        Iter { rx: self }
    }

    /// Returns an iterator that will attempt to yield all pending values.
    /// It will return `None` if there are no more pending values or if the
    /// channel has hung up. The iterator will never [`panic!`] or block the
    /// user by waiting for values.
    pub fn try_iter(&self) -> TryIter<'_, T> {
        TryIter { rx: self }
    }
}

impl<T> Drop for Receiver<T> {
    fn drop(&mut self) {
        if ExecutionState::should_stop() {
            return;
        }
        let mut state = self.inner.state.borrow_mut();
        assert!(state.known_receivers > 0);
        state.known_receivers -= 1;
        if state.known_receivers == 0 {
            // Last receiver was dropped; wake up all senders
            for &tid in state.waiting_senders.iter() {
                ExecutionState::with(|s| s.get_mut(tid).unblock());
            }
        }
    }
}

/// An iterator over messages on a [`Receiver`], created by [`iter`].
///
/// This iterator will block whenever [`next`] is called,
/// waiting for a new message, and [`None`] will be returned
/// when the corresponding channel has hung up.
///
/// [`iter`]: Receiver::iter
/// [`next`]: Iterator::next
#[derive(Debug)]
pub struct Iter<'a, T: 'a> {
    rx: &'a Receiver<T>,
}

/// An iterator that attempts to yield all pending values for a [`Receiver`],
/// created by [`try_iter`].
///
/// [`None`] will be returned when there are no pending values remaining or
/// if the corresponding channel has hung up.
///
/// This iterator will never block the caller in order to wait for data to
/// become available. Instead, it will return [`None`].
///
/// [`try_iter`]: Receiver::try_iter
#[derive(Debug)]
pub struct TryIter<'a, T: 'a> {
    rx: &'a Receiver<T>,
}

/// An owning iterator over messages on a [`Receiver`],
/// created by [`into_iter`].
///
/// This iterator will block whenever [`next`]
/// is called, waiting for a new message, and [`None`] will be
/// returned if the corresponding channel has hung up.
///
/// [`into_iter`]: Receiver::into_iter
/// [`next`]: Iterator::next
#[derive(Debug)]
pub struct IntoIter<T> {
    rx: Receiver<T>,
}

impl<T> Iterator for Iter<'_, T> {
    type Item = T;

    fn next(&mut self) -> Option<T> {
        self.rx.recv().ok()
    }
}

impl<T> Iterator for TryIter<'_, T> {
    type Item = T;

    fn next(&mut self) -> Option<T> {
        self.rx.try_recv().ok()
    }
}

impl<'a, T> IntoIterator for &'a Receiver<T> {
    type Item = T;
    type IntoIter = Iter<'a, T>;

    fn into_iter(self) -> Iter<'a, T> {
        self.iter()
    }
}

impl<T> Iterator for IntoIter<T> {
    type Item = T;
    fn next(&mut self) -> Option<T> {
        self.rx.recv().ok()
    }
}

impl<T> IntoIterator for Receiver<T> {
    type Item = T;
    type IntoIter = IntoIter<T>;

    fn into_iter(self) -> IntoIter<T> {
        IntoIter { rx: self }
    }
}

/// The sending-half of Rust's asynchronous [`channel`] type. This half can only be
/// owned by one thread, but it can be cloned to send to other threads.
#[derive(Debug)]
pub struct Sender<T> {
    inner: Arc<Channel<T>>,
}

impl<T> Sender<T> {
    /// Attempts to send a value on this channel, returning it back if it could
    /// not be sent.
    pub fn send(&self, t: T) -> Result<(), SendError<T>> {
        self.inner.send(t)
    }
}

impl<T> Clone for Sender<T> {
    fn clone(&self) -> Self {
        let mut state = self.inner.state.borrow_mut();
        state.known_senders += 1;
        drop(state);
        Self {
            inner: self.inner.clone(),
        }
    }
}

impl<T> Drop for Sender<T> {
    fn drop(&mut self) {
        if ExecutionState::should_stop() {
            return;
        }
        let mut state = self.inner.state.borrow_mut();
        assert!(state.known_senders > 0);
        state.known_senders -= 1;
        if state.known_senders == 0 {
            // Last sender was dropped; wake up all receivers
            for &tid in state.waiting_receivers.iter() {
                ExecutionState::with(|s| s.get_mut(tid).unblock());
            }
        }
    }
}

/// The sending-half of Rust's synchronous [`sync_channel`] type.
///
/// Messages can be sent through this channel with [`SyncSender::send`] or [`SyncSender::try_send`].
///
/// [`SyncSender::send`] will block if there is no space in the internal buffer.
#[derive(Debug)]
pub struct SyncSender<T> {
    inner: Arc<Channel<T>>,
}

impl<T> SyncSender<T> {
    /// Sends a value on this synchronous channel.
    ///
    /// This function will *block* until space in the internal buffer becomes
    /// available or a receiver is available to hand off the message to.
    pub fn send(&self, t: T) -> Result<(), SendError<T>> {
        self.inner.send(t)
    }

    /// Attempts to send a value on this channel without blocking.
    ///
    /// This method differs from [`send`] by returning immediately if the
    /// channel's buffer is full or no receiver is waiting to acquire some
    /// data. Compared with [`send`], this function has two failure cases
    /// instead of one (one for disconnection, one for a full buffer).
    ///
    /// [`send`]: Self::send
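    ///
    /// An illustrative sketch of the full-buffer case (not from the original docs; assumes
    /// the `shuttle::check_random` driver):
    ///
    /// ```no_run
    /// use shuttle::sync::mpsc::{sync_channel, TrySendError};
    ///
    /// shuttle::check_random(|| {
    ///     let (tx, rx) = sync_channel(1);
    ///     tx.try_send(1).unwrap();
    ///     // The buffer is full, so this attempt fails without blocking.
    ///     assert!(matches!(tx.try_send(2), Err(TrySendError::Full(_))));
    ///     assert_eq!(rx.recv().unwrap(), 1);
    /// }, 10);
    /// ```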
    pub fn try_send(&self, t: T) -> Result<(), TrySendError<T>> {
        self.inner.try_send(t)
    }
}

impl<T> Clone for SyncSender<T> {
    fn clone(&self) -> Self {
        let mut state = self.inner.state.borrow_mut();
        state.known_senders += 1;
        drop(state);
        Self {
            inner: self.inner.clone(),
        }
    }
}

impl<T> Drop for SyncSender<T> {
    fn drop(&mut self) {
        if ExecutionState::should_stop() {
            return;
        }
        let mut state = self.inner.state.borrow_mut();
        assert!(state.known_senders > 0);
        state.known_senders -= 1;
        if state.known_senders == 0 {
            // Last sender was dropped; wake up any receivers
            for &tid in state.waiting_receivers.iter() {
                ExecutionState::with(|s| s.get_mut(tid).unblock());
            }
        }
    }
}