use anyhow::{ensure, Result};
use memmap::{MmapMut, MmapOptions};
use std::sync::atomic::{
    AtomicU64,
    Ordering::{Relaxed, Release, SeqCst},
};
use std::{fs::OpenOptions, mem::size_of, path::Path, slice};

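// `PID_MASK` keeps 31 bits of the producer PID: the PID occupies bits
// 32..=62 of a control word, leaving bit 63 free for the "finished" flag.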
const MAGIC: u32 = 0x2d06_9f03;
const VERSION: u32 = 1;
const PID_MASK: u32 = 0x7fff_ffff;

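/// Rounds `n` up to the next multiple of `alignment`.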
fn align(n: u64, alignment: u64) -> u64 {
    ((n + alignment - 1) / alignment) * alignment
}

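/// Immutable geometry of the ring. Written once when the ring is created and
/// validated on every subsequent open.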
#[repr(C, align(128))]
struct RingBufParams {
    magic: u32,
    version: u32,
    entry_size: u64,
    num_entries: u64,
    control_offset: u64,
    entries_offset: u64,
}

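/// Shared header at the start of the mapping. `read_idx` and `write_idx` grow
/// monotonically; the slot for an index is `idx & (num_entries - 1)`.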
#[repr(C, align(128))]
struct RingBufHeader {
    params: RingBufParams,
    read_idx: AtomicU64,
    write_idx: AtomicU64,
}

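/// Per-slot state packed into one atomic u64: bit 63 is the "finished" flag,
/// bits 32..=62 hold the producer PID, and bits 0..=31 hold the entry length.
/// A value of zero marks the slot as free.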
struct ControlWord(u64);

impl ControlWord {
    fn new(pid: u32, len: u32) -> Self {
        Self(((pid as u64) << 32) | (len as u64))
    }
    fn load(atomic: &AtomicU64) -> Self {
        Self(atomic.load(Relaxed))
    }
    // Atomically takes ownership of a free (all-zero) slot.
    fn claim(&self, atomic: &AtomicU64) -> bool {
        atomic.compare_exchange(0, self.0, SeqCst, Relaxed).is_ok()
    }

    fn len(&self) -> usize {
        (self.0 as u32) as usize
    }
    fn pid(&self) -> u32 {
        ((self.0 >> 32) as u32) & PID_MASK
    }
    fn is_finished(&self) -> bool {
        self.0 >> 63 != 0
    }
    // Sets the "finished" bit; fails if the word no longer matches what this
    // producer wrote (e.g. the consumer reclaimed the slot in the meantime).
    fn mark_finished(&self, atomic: &AtomicU64) -> bool {
        atomic
            .compare_exchange(self.0, (1 << 63) | self.0, SeqCst, Relaxed)
            .is_ok()
    }
}

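/// Raw view over the shared region: header, control-word array, and entry
/// array. Holding `_mmap` keeps a file-backed mapping alive.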
struct RingBuf {
    ptr: *const u8,
    control_ptr: *const AtomicU64,
    entries_ptr: *const u8,
    num_entries: u64,
    entry_size: u64,
    _mmap: Option<MmapMut>,
}

impl RingBuf {
    #[inline]
    fn header(&self) -> &RingBufHeader {
        unsafe { &*(self.ptr as *const RingBufHeader) }
    }

    // Maps a monotonically increasing index to its slot's control word.
    #[inline]
    fn control_word(&self, idx: u64) -> &AtomicU64 {
        unsafe {
            &*self
                .control_ptr
                .add((idx & (self.num_entries - 1)) as usize)
        }
    }

    #[inline]
    fn entry(&self, idx: u64) -> &[u8] {
        let mask = self.num_entries - 1;
        unsafe {
            slice::from_raw_parts(
                self.entries_ptr
                    .byte_add(((idx & mask) * self.entry_size) as usize),
                self.entry_size as usize,
            )
        }
    }

    #[inline]
    fn entry_mut(&self, idx: u64) -> &mut [u8] {
        let mask = self.num_entries - 1;
        unsafe {
            slice::from_raw_parts_mut(
                self.entries_ptr
                    .byte_add(((idx & mask) * self.entry_size) as usize) as *mut _,
                self.entry_size as usize,
            )
        }
    }
}

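/// The single consumer side of the ring. Exactly one consumer may drain a
/// given ring at a time; producers attach via `MultiProducer`.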
pub struct SingleConsumer {
    ring: RingBuf,
}

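/// Tells `pop` what to do when the entry at the head of the ring was claimed
/// by a producer that has not finished writing it (e.g. the producer is slow
/// or died mid-write).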
pub enum StallResult {
    /// Keep waiting: re-check the same entry.
    Retry,
    /// Advance past the entry but leave its control word in place, so
    /// producers cannot reuse the slot.
    SkipAndKeepTombstone,
    /// Advance past the entry and zero its control word, making the slot
    /// available to producers again (a producer still writing to it will
    /// fail `mark_finished`).
    SkipAndReclaim,
}

impl SingleConsumer {
    fn _open_or_create(
        path: impl AsRef<Path>,
        entry_size: u64,
        num_entries: u64,
        truncate: bool,
    ) -> Result<Self> {
        ensure!(
            num_entries.is_power_of_two(),
            "num_entries must be a power of 2, got {num_entries}"
        );
        ensure!(
            entry_size <= u32::MAX as u64,
            "entry_size must fit in 32 bits, got {entry_size}"
        );

        let pgsz = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
        ensure!(pgsz >= 1, "_SC_PAGESIZE failed");
        let pgsz = pgsz as u64;

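        // Layout: header, then the control-word array, then (page-aligned)
        // the entry array; entry_size is rounded up so entries stay
        // u64-aligned.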
        let entry_size = align(entry_size, size_of::<u64>() as u64);
        let control_offset = size_of::<RingBufHeader>() as u64;
        let entries_offset = align(
            control_offset + (size_of::<ControlWord>() as u64) * num_entries,
            pgsz,
        );
        let sz = align(entries_offset + entry_size * num_entries, pgsz);

        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .truncate(truncate)
            .open(path)?;
        file.set_len(sz)?;

        let mut mmap = unsafe { MmapOptions::new().map_mut(&file) }?;
        let header = unsafe { &mut *(mmap.as_mut_ptr() as *mut RingBufHeader) };
        // A zero magic means the file is fresh (or was just truncated);
        // initialize the parameters exactly once.
        if header.params.magic == 0 {
            header.params.magic = MAGIC;
            header.params.version = VERSION;
            header.params.entry_size = entry_size;
            header.params.num_entries = num_entries;
            header.params.control_offset = control_offset;
            header.params.entries_offset = entries_offset;
            header.read_idx = AtomicU64::new(0);
            header.write_idx = AtomicU64::new(0);
        }
        ensure!(
            header.params.magic == MAGIC
                && header.params.version == VERSION
                && header.params.entry_size == entry_size
                && header.params.num_entries == num_entries
        );

        Ok(Self {
            ring: RingBuf {
                ptr: mmap.as_ptr(),
                num_entries,
                entry_size,
                control_ptr: unsafe {
                    mmap.as_ptr().byte_add(control_offset as usize) as *const AtomicU64
                },
                entries_ptr: unsafe { mmap.as_ptr().byte_add(entries_offset as usize) },
                _mmap: Some(mmap),
            },
        })
    }

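    /// Creates a fresh ring at `path`, truncating any existing file.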
    pub fn create(path: impl AsRef<Path>, entry_size: u64, num_entries: u64) -> Result<Self> {
        Self::_open_or_create(path, entry_size, num_entries, true)
    }

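    /// Opens an existing ring at `path`, creating it if the file is absent or
    /// uninitialized; existing contents are preserved.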
    pub fn open_or_create(
        path: impl AsRef<Path>,
        entry_size: u64,
        num_entries: u64,
    ) -> Result<Self> {
        Self::_open_or_create(path, entry_size, num_entries, false)
    }

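    /// Builds a ring inside a caller-provided buffer (e.g. a shared-memory
    /// region) instead of a file-backed mapping. With `clear` set, the buffer
    /// is zeroed and the header re-initialized.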
    pub fn from_buffer(
        buf: &mut [u8],
        entry_size: u64,
        num_entries: u64,
        clear: bool,
    ) -> Result<Self> {
        ensure!(
            num_entries.is_power_of_two(),
            "num_entries must be a power of 2, got {num_entries}"
        );
        ensure!(
            entry_size <= u32::MAX as u64,
            "entry_size must fit in 32 bits, got {entry_size}"
        );

        let pgsz = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
        ensure!(pgsz >= 1, "_SC_PAGESIZE failed");
        let pgsz = pgsz as u64;

        let entry_size = align(entry_size, size_of::<u64>() as u64);
        let control_offset = size_of::<RingBufHeader>() as u64;
        let entries_offset = align(
            control_offset + (size_of::<ControlWord>() as u64) * num_entries,
            pgsz,
        );
        // The buffer must cover the full geometry computed above.
        ensure!(
            buf.len() as u64 >= entries_offset + entry_size * num_entries,
            "buffer too small for the requested ring geometry"
        );

        if clear {
            buf.fill(0);
        }

        // Initialize the header exactly once (a zero magic means a fresh
        // region).
        let header = unsafe { &mut *(buf.as_mut_ptr() as *mut RingBufHeader) };
        if header.params.magic == 0 {
            header.params.magic = MAGIC;
            header.params.version = VERSION;
            header.params.entry_size = entry_size;
            header.params.num_entries = num_entries;
            header.params.control_offset = control_offset;
            header.params.entries_offset = entries_offset;
            header.read_idx = AtomicU64::new(0);
            header.write_idx = AtomicU64::new(0);
        }
        ensure!(
            header.params.magic == MAGIC
                && header.params.version == VERSION
                && header.params.entry_size == entry_size
                && header.params.num_entries == num_entries
        );

        Ok(Self {
            ring: RingBuf {
                ptr: buf.as_ptr(),
                num_entries,
                entry_size,
                control_ptr: unsafe {
                    buf.as_ptr().byte_add(control_offset as usize) as *const AtomicU64
                },
                entries_ptr: unsafe { buf.as_ptr().byte_add(entries_offset as usize) },
                _mmap: None,
            },
        })
    }

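    /// Pops the entry at the head of the ring into `buf`, returning `false`
    /// if the ring is empty. When the head entry is claimed but not yet
    /// finished, `stall` is called with the writer's PID and the attempt
    /// count; its result decides whether to retry, skip, or reclaim.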
    pub fn pop(&self, buf: &mut [u8], mut stall: impl FnMut(u32, usize) -> StallResult) -> bool {
        debug_assert!(buf.len() >= self.ring.entry_size as usize);
        let header = self.ring.header();
        let mut attempt = 0;
        loop {
            let ridx = header.read_idx.load(Relaxed);
            let widx = header.write_idx.load(Relaxed);
            debug_assert!(ridx <= widx, "ridx={ridx} widx={widx}");
            if ridx == widx {
                return false;
            }
            let ctrl = self.ring.control_word(ridx);
            let ctrl_word = ControlWord::load(ctrl);

            // The head slot has been claimed by a producer that has not yet
            // finished writing; ask the caller how to proceed.
            if !ctrl_word.is_finished() {
                match stall(ctrl_word.pid(), attempt) {
                    StallResult::Retry => {}
                    StallResult::SkipAndKeepTombstone => {
                        header.read_idx.fetch_add(1, Release);
                    }
                    StallResult::SkipAndReclaim => {
                        ctrl.store(0, SeqCst);
                        header.read_idx.fetch_add(1, Release);
                    }
                }
                attempt += 1;
                continue;
            }

            let entry = self.ring.entry(ridx);
            let len = ctrl_word.len();
            debug_assert!(len <= entry.len(), "len={len} entry_size={}", entry.len());
            buf[..len].copy_from_slice(&entry[..len]);

            // Release the slot back to producers, then advance the head.
            ctrl.store(0, SeqCst);
            header.read_idx.fetch_add(1, Relaxed);
            return true;
        }
    }
}

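/// A producer handle. Any number of producers, across threads or processes,
/// may push into the same ring concurrently.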
pub struct MultiProducer {
    ring: RingBuf,
    tid: u32,
}

impl MultiProducer {
    fn gettid() -> Result<u32> {
        let tid = unsafe { libc::gettid() };
        ensure!(tid > 0, "gettid failed");
        let tid = tid as u32;
        // The control word reserves bit 63 for the "finished" flag, so the
        // PID stored in bits 32..=62 must fit in 31 bits.
        ensure!(
            tid & PID_MASK == tid,
            "PIDs are expected to have at most 31 meaningful bits"
        );
        Ok(tid & PID_MASK)
    }

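    /// Attaches to an existing file-backed ring; the geometry is read from
    /// the ring's header.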
    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
        let file = OpenOptions::new().read(true).write(true).open(path)?;
        let mmap = unsafe { MmapOptions::new().map_mut(&file) }?;
        let header = unsafe { &*(mmap.as_ptr() as *const RingBufHeader) };
        ensure!(header.params.magic == MAGIC && header.params.version == VERSION);

        Ok(Self {
            ring: RingBuf {
                ptr: mmap.as_ptr(),
                control_ptr: unsafe {
                    mmap.as_ptr()
                        .byte_add(header.params.control_offset as usize)
                        as *const AtomicU64
                },
                entries_ptr: unsafe {
                    mmap.as_ptr()
                        .byte_add(header.params.entries_offset as usize)
                },
                num_entries: header.params.num_entries,
                entry_size: header.params.entry_size,
                _mmap: Some(mmap),
            },
            tid: Self::gettid()?,
        })
    }

    pub fn from_buffer(buf: &mut [u8]) -> Result<Self> {
        let header = unsafe { &*(buf.as_ptr() as *const RingBufHeader) };
        ensure!(header.params.magic == MAGIC && header.params.version == VERSION);

        Ok(Self {
            ring: RingBuf {
                ptr: buf.as_ptr(),
                control_ptr: unsafe {
                    buf.as_ptr().byte_add(header.params.control_offset as usize) as *const AtomicU64
                },
                entries_ptr: unsafe {
                    buf.as_ptr().byte_add(header.params.entries_offset as usize)
                },
                num_entries: header.params.num_entries,
                entry_size: header.params.entry_size,
                _mmap: None,
            },
            tid: Self::gettid()?,
        })
    }

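    /// Pushes one entry of at most `entry_size` bytes. Returns `false` if the
    /// ring is full, or if the consumer reclaimed the slot before the write
    /// was marked finished (in which case the entry is dropped).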
    pub fn push(&self, data: &[u8]) -> bool {
        debug_assert!(!data.is_empty() && data.len() <= self.ring.entry_size as usize);
        let header = self.ring.header();
        loop {
            let ridx = header.read_idx.load(Relaxed);
            let widx = header.write_idx.load(Relaxed);
            // Ring is full: the writer would lap the reader.
            if widx >= ridx + self.ring.num_entries {
                return false;
            }
            // Reserve a slot index; retry on contention with other producers.
            if header
                .write_idx
                .compare_exchange(widx, widx + 1, SeqCst, Relaxed)
                .is_err()
            {
                continue;
            }

            // The slot may still hold a tombstone left by the consumer; if we
            // cannot claim it, abandon this index and reserve another.
            let ctrl = self.ring.control_word(widx);
            let ctrl_word = ControlWord::new(self.tid, data.len() as u32);
            if !ctrl_word.claim(ctrl) {
                continue;
            }
            self.ring.entry_mut(widx)[..data.len()].copy_from_slice(data);

            // If the consumer reclaimed the slot while we were writing
            // (StallResult::SkipAndReclaim), the entry is lost.
            if !ctrl_word.mark_finished(ctrl) {
                return false;
            }

            return true;
        }
    }
}

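// Smoke test: 16 producer threads each push 100 distinct values while the
// consumer drains the ring; the total number of popped entries must equal
// the number of successful pushes.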
#[test]
fn test_ring() -> Result<()> {
    let sc = SingleConsumer::create("/tmp/myring", 8, 128)?;

    let pushes = std::sync::Arc::new(AtomicU64::new(0));
    let attempts = std::sync::Arc::new(AtomicU64::new(0));

    let mut handles = vec![];
    for i in 0..16usize {
        let pushes = pushes.clone();
        let attempts = attempts.clone();
        handles.push(std::thread::spawn(move || {
            let mp = MultiProducer::open("/tmp/myring").unwrap();
            for j in i * 1000..i * 1000 + 100 {
                while !mp.push(&(j.to_ne_bytes()[..])) {
                    attempts.fetch_add(1, SeqCst);
                    std::thread::yield_now();
                }
                pushes.fetch_add(1, SeqCst);
            }
        }));
        std::thread::yield_now();
    }

    let mut res = vec![];
    loop {
        // Snapshot completion *before* draining so that entries pushed
        // between the final drain and the check are not missed.
        let done = handles.iter().all(|h| h.is_finished());
        let mut buf = [0u8; 8];
        while sc.pop(&mut buf, |_, _| {
            std::thread::yield_now();
            StallResult::Retry
        }) {
            res.push(usize::from_ne_bytes(buf));
        }
        if done {
            break;
        }
    }

    println!("{:?}", res);
    assert_eq!(res.len(), pushes.load(SeqCst) as _);
    println!("{}", attempts.load(SeqCst));

    Ok(())
}