// mspc_ipc/lib.rs

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Algorithm:
//
// Setup: the ring buffer is made of `num_entries` (must be a power of 2), each entry being `entry_size` bytes
// (rounded up to 8 bytes). We also hold `num_entries` control words (AtomicU64). Everything is initialized
// to zero.
//
// We hold a header with immutable info (first cache line), followed by two atomic indices (ridx and widx)
// in a separate cache line. Both indices are in the same cache line because both the consumer and the producer
// access them. The indices are monotonically increasing, being masked by `(num_entries-1)` when accessing
// the arrays.
//
// No-std: the code itself uses std, for things like `anyhow`, but these are all used in the setup phase,
// when opening and mmaping files. The algorithm itself does not require any runtime or OS support like futex.
//
// Producer: when pushing, the producer checks for room (widx < ridx+num_entries), and if there is room,
// it attempts to advance widx by 1, using a CAS. If that fails, it tries again. Upon success, it owns the slot,
// and writes the control word using a CAS. The control word has a pid part that is unique to each producer,
// and the length of the data (which must be <= entry_size, and also fit in 32 bits). This CAS write may fail --
// which means the consumer has given up on us, in which case someone else might be holding the slot. The
// pid ensures we'd be aware of it and bail out. If we succeed in writing the control word, we proceed to
// writing the entry's data (which may take time), followed by rewriting the control word, this time setting the
// FINISHED bit (only if it matches the expected value).
//
// Consumer: when popping, the consumer checks if there are entries to read (ridx < widx). If there are, it reads
// the control word and checks the finished bit. If it's set, it's safe to read the entry, clear the control word
// and advance ridx. If the finished bit is not set, it means either the producer died in the middle of writing
// the entry, or the consumer has run into a producer that's still busy writing the entry. In this case,
// the consumer should stall a little and retry (either using `sched_yield`/`nanosleep` or spinning), after which
// either the entry became finished, or the consumer gives up on this entry by clearing the control word and
// advancing ridx.
//
// The only open issue is a "sleepy producer" that started writing a large entry, hanged until the consumer gave
// up, and then woke up and continued the memcpy. In this case, it will corrupt an entry that's already taken by
// some other producer. To solve that, the stall function takes the producer's pid. It is allowed to wait for
// any duration of time (returning `Retry`), or to skip the entry while leaving it occupied (`Skip`), as well as
// clearing the entry (`Reclaim`). Note that `Reclaim` is only safe to use if the producer is dead (or if you
// can ensure it will never wake up).
//
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
41
42use anyhow::{ensure, Result};
43use memmap::{MmapMut, MmapOptions};
44use std::sync::atomic::{
45    AtomicU64,
46    Ordering::{Relaxed, Release, SeqCst},
47};
48use std::{fs::OpenOptions, path::Path, slice};
49
// Marks an initialized ring-buffer region; a freshly zeroed region has magic == 0.
const MAGIC: u32 = 0x2d06_9f03;
// Layout version, validated whenever an existing region is opened.
const VERSION: u32 = 1;
// A producer id occupies the 31 bits of the control word between `len` and the FINISHED bit.
const PID_MASK: u32 = 0x7fff_ffff;
53
/// Rounds `n` up to the next multiple of `alignment` (which must be non-zero).
fn align(n: u64, alignment: u64) -> u64 {
    // next_multiple_of panics on overflow (and on alignment == 0) instead of silently
    // wrapping like the manual ((n + a - 1) / a) * a formulation would in release builds
    n.next_multiple_of(alignment)
}
57
/// Immutable geometry of the ring, written once when the region is first initialized.
/// `align(128)` keeps it in its own cache line(s), separate from the mutable indices.
#[repr(C, align(128))]
struct RingBufParams {
    // MAGIC once initialized; 0 in a freshly created region
    magic: u32,
    // layout VERSION, checked on open
    version: u32,
    // per-slot payload capacity in bytes, rounded up to a multiple of 8
    entry_size: u64,
    // number of slots; always a power of two
    num_entries: u64,
    // byte offset of the AtomicU64 control-word array from the region start
    control_offset: u64,
    // byte offset of the (page-aligned) entries array from the region start
    entries_offset: u64,
}
67
/// Lives at offset 0 of the shared region: the immutable params in the first cache
/// line, then both ring indices together in a second cache line (producer and consumer
/// each read both indices, so splitting them across lines would not help).
#[repr(C, align(128))]
struct RingBufHeader {
    // first cache line: immutable once initialized
    params: RingBufParams,
    // second cache line: monotonically increasing logical indices; they are masked by
    // (num_entries - 1) when used to address a slot, never wrapped in place
    read_idx: AtomicU64,
    write_idx: AtomicU64,
}
76
77// +----------+-----------+-------------+
78// | finished |    pid    |     len     |
79// |    (1)   |    (31)   |     (32)    |
80// +----------+-----------+-------------+
81struct ControlWord(u64);
82
83impl ControlWord {
84    fn new(pid: u32, len: u32) -> Self {
85        Self(((pid as u64) << 32) | (len as u64))
86    }
87    fn load(atomic: &AtomicU64) -> Self {
88        Self(atomic.load(Relaxed))
89    }
90    fn claim(&self, atomic: &AtomicU64) -> bool {
91        atomic.compare_exchange(0, self.0, SeqCst, Relaxed).is_ok()
92    }
93
94    fn len(&self) -> usize {
95        (self.0 as u32) as usize
96    }
97    fn pid(&self) -> u32 {
98        ((self.0 >> 32) as u32) & PID_MASK
99    }
100    fn is_finished(&self) -> bool {
101        self.0 >> 63 != 0
102    }
103    fn mark_finished(&self, atomic: &AtomicU64) -> bool {
104        atomic
105            .compare_exchange(self.0, (1 << 63) | self.0, SeqCst, Relaxed)
106            .is_ok()
107    }
108}
109
/// Raw view over a mapped ring-buffer region (header + control words + entries).
/// Shared by the consumer and producer endpoints; keeps the backing mmap alive
/// when the region came from a file.
struct RingBuf {
    // base of the region; the RingBufHeader lives here
    ptr: *const u8,
    // start of the num_entries control words (one AtomicU64 per slot)
    control_ptr: *const AtomicU64,
    // start of the entries array (num_entries * entry_size bytes)
    entries_ptr: *const u8,
    // power of two; logical indices are masked by (num_entries - 1)
    num_entries: u64,
    // per-slot payload capacity, rounded up to a multiple of 8 bytes
    entry_size: u64,
    // keeps the file mapping alive for the pointers above; None for from_buffer()
    _mmap: Option<MmapMut>,
}
118
119impl RingBuf {
120    #[inline]
121    fn header(&self) -> &RingBufHeader {
122        unsafe { &*(self.ptr as *const RingBufHeader) }
123    }
124
125    #[inline]
126    fn control_word(&self, idx: u64) -> &AtomicU64 {
127        unsafe {
128            &*self
129                .control_ptr
130                .add((idx & (self.num_entries - 1)) as usize)
131        }
132    }
133
134    #[inline]
135    fn entry(&self, idx: u64) -> &[u8] {
136        let mask = self.num_entries - 1;
137        unsafe {
138            slice::from_raw_parts(
139                self.entries_ptr
140                    .byte_add(((idx & mask) * self.entry_size) as usize),
141                self.entry_size as usize,
142            )
143        }
144    }
145
146    #[inline]
147    fn entry_mut(&self, idx: u64) -> &mut [u8] {
148        let mask = self.num_entries - 1;
149        unsafe {
150            slice::from_raw_parts_mut(
151                self.entries_ptr
152                    .byte_add(((idx & mask) * self.entry_size) as usize) as *mut _,
153                self.entry_size as usize,
154            )
155        }
156    }
157}
158
/// The single consumer endpoint of the ring; see the algorithm notes at the top of the file.
/// There must be at most one live consumer per region.
pub struct SingleConsumer {
    ring: RingBuf,
}
162
/// What the consumer's stall callback tells `pop` to do about an entry whose
/// FINISHED bit is not (yet) set.
pub enum StallResult {
    Retry,                // stall (re-check the status of the entry)
    SkipAndKeepTombstone, // skip this entry, it will be "tombstoned"
    SkipAndReclaim,       // reclaim (clear) this entry (allow it to be reused) --
                          // ONLY DO THIS IF THE PRODUCER IS SURELY DEAD
}
169
170impl SingleConsumer {
171    fn _open_or_create(
172        path: impl AsRef<Path>,
173        entry_size: u64,
174        num_entries: u64,
175        truncate: bool,
176    ) -> Result<Self> {
177        ensure!(
178            num_entries.is_power_of_two(),
179            "num_entries must be a power of 2, got {num_entries}"
180        );
181        ensure!(
182            entry_size <= u32::MAX as u64,
183            "entry_size must fit in 32 bits, got {entry_size}"
184        );
185
186        let pgsz = unsafe { libc::sysconf(libc::_SC_PAGESIZE) as u64 };
187        ensure!(pgsz >= 1, "_SC_PAGESIZE failed");
188
189        let entry_size = align(entry_size, size_of::<u64>() as u64);
190        let control_offset = size_of::<RingBufHeader>() as u64;
191        let entries_offset = align(
192            control_offset + (size_of::<ControlWord>() as u64) * num_entries,
193            pgsz,
194        );
195        let sz = align(entries_offset + entry_size * num_entries, pgsz);
196
197        let file = OpenOptions::new()
198            .read(true)
199            .write(true)
200            .create(true)
201            .truncate(truncate)
202            .open(path)?;
203        file.set_len(sz as u64)?;
204
205        let mut mmap = unsafe { MmapOptions::new().map_mut(&file) }?;
206        let header = unsafe { &mut *(mmap.as_mut_ptr() as *mut RingBufHeader) };
207        if header.params.magic == 0 {
208            header.params.magic = MAGIC;
209            header.params.version = VERSION;
210            header.params.entry_size = entry_size;
211            header.params.num_entries = num_entries;
212            header.params.control_offset = control_offset;
213            header.params.entries_offset = entries_offset;
214            header.read_idx = AtomicU64::new(0);
215            header.write_idx = AtomicU64::new(0);
216        }
217        ensure!(
218            header.params.magic == MAGIC
219                && header.params.version == VERSION
220                && header.params.entry_size == entry_size
221        );
222
223        Ok(Self {
224            ring: RingBuf {
225                ptr: mmap.as_ptr(),
226                num_entries,
227                entry_size,
228                control_ptr: unsafe {
229                    mmap.as_ptr().byte_add(control_offset as usize) as *const AtomicU64
230                },
231                entries_ptr: unsafe { mmap.as_ptr().byte_add(entries_offset as usize) },
232                _mmap: Some(mmap),
233            },
234        })
235    }
236
237    pub fn create(path: impl AsRef<Path>, entry_size: u64, num_entries: u64) -> Result<Self> {
238        Self::_open_or_create(path, entry_size, num_entries, true)
239    }
240
241    pub fn open_or_create(
242        path: impl AsRef<Path>,
243        entry_size: u64,
244        num_entries: u64,
245    ) -> Result<Self> {
246        Self::_open_or_create(path, entry_size, num_entries, false)
247    }
248
249    pub fn from_buffer(
250        buf: &mut [u8],
251        entry_size: u64,
252        num_entries: u64,
253        clear: bool,
254    ) -> Result<Self> {
255        ensure!(
256            num_entries.is_power_of_two(),
257            "num_entries must be a power of 2, got {num_entries}"
258        );
259        ensure!(
260            entry_size <= u32::MAX as u64,
261            "entry_size must fit in 32 bits, got {entry_size}"
262        );
263
264        let pgsz = unsafe { libc::sysconf(libc::_SC_PAGESIZE) as u64 };
265        ensure!(pgsz >= 1, "_SC_PAGESIZE failed");
266
267        let entry_size = align(entry_size, size_of::<u64>() as u64);
268        let control_offset = size_of::<RingBufHeader>() as u64;
269        let entries_offset = align(
270            control_offset + (size_of::<ControlWord>() as u64) * num_entries,
271            pgsz,
272        );
273
274        if clear {
275            buf.fill(0);
276        }
277
278        let header = unsafe { &mut *(buf.as_mut_ptr() as *mut RingBufHeader) };
279        if header.params.magic == 0 {
280            header.params.magic = MAGIC;
281            header.params.version = VERSION;
282            header.params.entry_size = entry_size;
283            header.params.num_entries = num_entries;
284            header.params.control_offset = control_offset;
285            header.params.entries_offset = entries_offset;
286            header.read_idx = AtomicU64::new(0);
287            header.write_idx = AtomicU64::new(0);
288        }
289        ensure!(
290            header.params.magic == MAGIC
291                && header.params.version == VERSION
292                && header.params.entry_size == entry_size
293        );
294
295        Ok(Self {
296            ring: RingBuf {
297                ptr: buf.as_ptr(),
298                num_entries,
299                entry_size,
300                control_ptr: unsafe {
301                    buf.as_ptr().byte_add(control_offset as usize) as *const AtomicU64
302                },
303                entries_ptr: unsafe { buf.as_ptr().byte_add(entries_offset as usize) },
304                _mmap: None,
305            },
306        })
307    }
308
309    pub fn pop(&self, buf: &mut [u8], mut stall: impl FnMut(u32, usize) -> StallResult) -> bool {
310        debug_assert!(buf.len() >= self.ring.entry_size as usize);
311        let header = self.ring.header();
312        let mut attempt = 0;
313        loop {
314            let ridx = header.read_idx.load(Relaxed);
315            let widx = header.write_idx.load(Relaxed);
316            debug_assert!(ridx <= widx, "ridx={ridx} widx={widx}");
317            if ridx == widx {
318                return false;
319            }
320            let ctrl = self.ring.control_word(ridx);
321            let ctrl_word = ControlWord::load(ctrl);
322
323            if !ctrl_word.is_finished() {
324                match stall(ctrl_word.pid(), attempt) {
325                    StallResult::Retry => { // keep waiting
326                    }
327                    StallResult::SkipAndKeepTombstone => {
328                        // leave this entry occupied and move to the next one
329                        header.read_idx.fetch_add(1, Release);
330                    }
331                    StallResult::SkipAndReclaim => {
332                        // forcefully clear the entry and move to the next one -- should only be done if the caller
333                        // is sure the producer is dead, otherwise the producer might wake up in the future and
334                        // corrupt the entry's buffer (the memcpy part is not atomic)
335                        ctrl.store(0, SeqCst);
336                        header.read_idx.fetch_add(1, Release);
337                    }
338                }
339                attempt += 1;
340                continue;
341            }
342
343            let entry = self.ring.entry(ridx);
344            let len = ctrl_word.len();
345            debug_assert!(len <= entry.len(), "len={len} entry_size={}", entry.len());
346            buf[..len].copy_from_slice(&entry[..len]);
347
348            ctrl.store(0, SeqCst);
349            header.read_idx.fetch_add(1, Relaxed);
350            return true;
351        }
352    }
353}
354
/// A producer endpoint; any number of these (across threads and processes) may push
/// into the same ring concurrently.
pub struct MultiProducer {
    ring: RingBuf,
    // this producer's thread id, stamped into the pid field of control words it claims
    tid: u32,
}
359
360impl MultiProducer {
361    fn gettid() -> Result<u32> {
362        let tid = unsafe { libc::gettid() };
363        ensure!(tid > 0, "gettid failed");
364        let tid = tid as u32;
365        ensure!(
366            tid & PID_MASK == tid,
367            "PIDs are expected to have only 24 meaningful bits"
368        );
369        Ok(tid & PID_MASK)
370    }
371
372    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
373        let file = OpenOptions::new().read(true).write(true).open(path)?;
374        let mmap = unsafe { MmapOptions::new().map_mut(&file) }?;
375        let header = unsafe { &*(mmap.as_ptr() as *const RingBufHeader) };
376        ensure!(header.params.magic == MAGIC && header.params.version == VERSION);
377
378        Ok(Self {
379            ring: RingBuf {
380                ptr: mmap.as_ptr(),
381                control_ptr: unsafe {
382                    mmap.as_ptr()
383                        .byte_add(header.params.control_offset as usize)
384                        as *const AtomicU64
385                },
386                entries_ptr: unsafe {
387                    mmap.as_ptr()
388                        .byte_add(header.params.entries_offset as usize)
389                },
390                num_entries: header.params.num_entries,
391                entry_size: header.params.entry_size,
392                _mmap: Some(mmap),
393            },
394            tid: Self::gettid()?,
395        })
396    }
397
398    pub fn from_buffer(buf: &mut [u8]) -> Result<Self> {
399        let header = unsafe { &*(buf.as_ptr() as *const RingBufHeader) };
400        ensure!(header.params.magic == MAGIC && header.params.version == VERSION);
401
402        Ok(Self {
403            ring: RingBuf {
404                ptr: buf.as_ptr(),
405                control_ptr: unsafe {
406                    buf.as_ptr().byte_add(header.params.control_offset as usize) as *const AtomicU64
407                },
408                entries_ptr: unsafe {
409                    buf.as_ptr().byte_add(header.params.entries_offset as usize)
410                },
411                num_entries: header.params.num_entries,
412                entry_size: header.params.entry_size,
413                _mmap: None,
414            },
415            tid: Self::gettid()?,
416        })
417    }
418
419    pub fn push(&self, data: &[u8]) -> bool {
420        debug_assert!(!data.is_empty() && data.len() <= self.ring.entry_size as usize);
421        let header = self.ring.header();
422        loop {
423            let ridx = header.read_idx.load(Relaxed);
424            let widx = header.write_idx.load(Relaxed);
425            if widx >= ridx + self.ring.num_entries {
426                // no room
427                return false;
428            }
429            if header
430                .write_idx
431                .compare_exchange(widx, widx + 1, SeqCst, Relaxed)
432                .is_err()
433            {
434                continue;
435            }
436
437            let ctrl = self.ring.control_word(widx);
438            let ctrl_word = ControlWord::new(self.tid, data.len() as u32);
439            if !ctrl_word.claim(ctrl) {
440                // this entry is taken (due to another process still holding it), skip for now
441                continue;
442            }
443            self.ring.entry_mut(widx)[..data.len()].copy_from_slice(data);
444
445            if !ctrl_word.mark_finished(ctrl) {
446                // we may have corrupted an entry now belonging to another producer during the memcpy above
447                // all we can do is signal this case by overwriting the control word to `FINISHED|0` so the
448                // consumer will not read anything from it
449                return false;
450            }
451
452            return true;
453        }
454    }
455}
456
#[test]
fn test_ring() -> Result<()> {
    // 16 producer threads, 100 pushes each, through a 128-slot ring of 8-byte entries.
    let sc = SingleConsumer::create("/tmp/myring", 8, 128)?;

    let pushes = std::sync::Arc::new(AtomicU64::new(0));
    let attempts = std::sync::Arc::new(AtomicU64::new(0));

    let mut handles = vec![];
    for i in 0..16usize {
        let pushes = pushes.clone();
        let attempts = attempts.clone();
        handles.push(std::thread::spawn(move || {
            let mp = MultiProducer::open("/tmp/myring").unwrap();
            for j in i * 1000..i * 1000 + 100 {
                while !mp.push(&j.to_ne_bytes()) {
                    attempts.fetch_add(1, SeqCst);
                    std::thread::yield_now();
                }
                pushes.fetch_add(1, SeqCst);
            }
        }));
        std::thread::yield_now();
    }

    let mut res = vec![];
    loop {
        // Sample "all producers done" *before* draining: anything pushed before the
        // producers exited is then guaranteed to be drained below, so we can't break
        // while entries are still in flight.
        let done = handles.iter().all(|h| h.is_finished());
        let mut buf = [0u8; 8];
        while sc.pop(&mut buf, |_, _| {
            std::thread::yield_now();
            StallResult::Retry
        }) {
            // decode via from_ne_bytes -- a raw `*(ptr as *const usize)` read of a [u8; 8]
            // is not guaranteed to be aligned and would be UB
            res.push(usize::from_ne_bytes(buf));
        }
        if done {
            break;
        }
    }

    assert_eq!(res.len(), pushes.load(SeqCst) as usize);
    // every pushed value must have arrived exactly once
    res.sort_unstable();
    let expected: Vec<usize> = (0..16).flat_map(|i| i * 1000..i * 1000 + 100).collect();
    assert_eq!(res, expected);
    println!("{} failed push attempts", attempts.load(SeqCst));

    Ok(())
}